In [None]:
from disentangle.core.tiff_reader import load_tiff
from skimage.transform import resize

output_data_dir = '/group/jug/ashesh/naturemethods/hhmi'
# Frame index into the prediction stacks; also selects the kth{k_idx} GT folder.
k_idx = 3

# orig: prediction on raw inputs (run 26); half: prediction on binned inputs (run 27).
# The GT file is the one exported alongside run 27.
orig_res_fpath = '/group/jug/ashesh/data/paper_stats/Test_P64_G32_M40_Sk0/pred_training_disentangle_2505_D32-M3-S0-L8_26_1.tif'
half_res_fpath = '/group/jug/ashesh/data/paper_stats/Test_P64_G32_M50_Sk0/pred_training_disentangle_2505_D32-M3-S0-L8_27_1.tif'
gt_fpath = f'/group/jug/ashesh/kth_data/D32/kth{k_idx}/gt_for_pred_training_disentangle_2505_D32-M3-S0-L8_27_1.tif'

gt = load_tiff(gt_fpath)
# Drop singleton axes, then keep only frame k_idx of each prediction stack.
orig = load_tiff(orig_res_fpath).squeeze()[k_idx]
half = load_tiff(half_res_fpath).squeeze()[k_idx]

# Halve both spatial axes of the full-resolution prediction; the last axis is left untouched.
# `* 1.0` guarantees a floating-point array before resizing.
target_shape = (orig.shape[0] // 2, orig.shape[1] // 2, orig.shape[2])
orig_resized = resize(orig * 1.0, target_shape, anti_aliasing=True)
orig.shape, half.shape, gt.shape, orig_resized.shape

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from disentangle.analysis.plot_utils import clean_ax, add_text
savefig = True  # when True, the figure cell below also writes its PNG under output_data_dir

def add_title(ax, title, offset=0, alpha=0.9, x=65, y=85):
    """Draw ``title`` as a white-boxed text label inside ``ax``.

    Args:
        ax: Matplotlib axes to annotate.
        title: Text to draw.
        offset: Extra horizontal shift added to ``x``; lets callers nudge the
            label per panel without changing the base position.
        alpha: Opacity of the white background box.
        x, y: Base anchor position in data coordinates (pixel indices when
            the axes shows an ``imshow`` image). Defaults reproduce the
            original hard-coded position (65, 85).
    """
    ax.text(x + offset, y, title, bbox=dict(facecolor='white', alpha=alpha))

fontsize = 13
q = 0.92  # quantile threshold: pixels strictly above it form each channel's mask
# Hand-tuned x-shifts for the average-intensity label in each column.
avg_title_offsets = [0, 15, 30]

# Per-channel threshold and boolean mask of the brightest (> q-quantile) pixels.
q_vals = [np.quantile(gt[..., ch], q) for ch in range(3)]
masks = [gt[..., ch] > q_vals[ch] for ch in range(3)]
# Keep the per-channel names that earlier versions of this cell exposed.
q_val0, q_val1, q_val2 = q_vals
mask0, mask1, mask2 = masks

_, ax = plt.subplots(figsize=(9, 9), ncols=3, nrows=3)
avgs = []
for ch in range(3):
    channel = gt[..., ch]
    mask = masks[ch]
    # Rows: target, mask, target restricted to the mask.
    ax[0, ch].imshow(channel, cmap='magma')
    ax[1, ch].imshow(mask, cmap='gray')
    ax[2, ch].imshow(channel * mask, cmap='magma')
    # Mean over masked pixels only: the sum over the masked image divided by the mask size.
    avg = np.sum(channel * mask) / np.sum(mask)
    avgs.append(avg)
    add_title(ax[2, ch], f' Avg. Non-zero Intensity: {avg:.0f}', offset=avg_title_offsets[ch])
    ax[0, ch].set_title(f'Ch{ch}', fontsize=fontsize)
avg0, avg1, avg2 = avgs

ax[0, 0].set_ylabel('Target', fontsize=fontsize)
# Derive the label from q so it cannot drift out of sync with the threshold actually used.
ax[1, 0].set_ylabel(f'Mask ({q * 100:.0f} percentile)', fontsize=fontsize)
ax[2, 0].set_ylabel('Masked Target', fontsize=fontsize)

clean_ax(ax)
# Reduce the spacing between subplots to 0.02.
plt.subplots_adjust(wspace=0.02, hspace=0.02)

print(*(np.mean(gt[masks[ch], ch]) for ch in range(3)))
if savefig:
    fpath = f'{output_data_dir}/hhmi_kth{k_idx}_skew.png'
    print(f'Saving figure to {fpath}')
    plt.savefig(fpath, dpi=100, bbox_inches='tight')

In [None]:
import numpy as np  
# Print upper-tail intensity quantiles of every target channel (one line per channel).
quantile_levels = [0.92, 0.99, 0.995]
for ch in range(gt.shape[-1]):
    print(str(ch), np.quantile(gt[..., ch], quantile_levels).round())

In [None]:
from disentangle.core.psnr import RangeInvariantPsnr
ch_idx = 2
# Compare the same target channel against the binned-input prediction and
# against the 2x-downsampled full-resolution prediction.
gt_ch = gt[None, ..., ch_idx]
psnr_binned = RangeInvariantPsnr(gt_ch, half[None, ..., ch_idx]).item()
psnr_raw = RangeInvariantPsnr(gt_ch, orig_resized[None, ..., ch_idx]).item()
print(psnr_binned, psnr_raw)

In [None]:
from disentangle.analysis.plot_utils import clean_ax, add_text
import matplotlib.pyplot as plt
from matplotlib_scalebar.scalebar import ScaleBar
from disentangle.core.psnr import RangeInvariantPsnr
import os

output_data_dir = '/group/jug/ashesh/naturemethods/hhmi'
savefig = True
# Defined here too (same value as the skew figure) so this cell does not
# silently depend on an earlier cell having been executed.
fontsize = 13
pixel_size_nm = 45  # physical pixel size handed to the scale bar

_, ax = plt.subplots(figsize=(16, 12), ncols=4, nrows=3)

scalebar = ScaleBar(pixel_size_nm,
                    "nm",
                    box_alpha=0.6, frameon=True, location='upper right', font_properties={'size': 12})
ax[0, 0].add_artist(scalebar)

# Row 0: one target image per channel in columns 1..3.
for i in range(3):
    ax[0, i + 1].imshow(gt[..., i], cmap='magma')
    add_text(ax[0, i + 1], 'Target', gt.shape[-3:-1], place='TOP_RIGHT', alpha=0.7)

# Column 0 of row 0 shows the channel-averaged target as the "input".
ax[0, 0].imshow(gt.mean(axis=-1), cmap='magma')

# Rows 1-2: predictions. Each entry is (row, image displayed, image PSNR is computed on).
# The full-resolution prediction is scored on its 2x-downsampled version (orig_resized).
pred_rows = [(1, orig, orig_resized), (2, half, half)]
for row, shown, scored in pred_rows:
    for i in range(3):
        ax[row, i + 1].imshow(shown[..., i], cmap='magma')
        add_text(ax[row, i + 1], ' Pred ', shown.shape[-3:-1], place='TOP_RIGHT', alpha=0.7)
        psnr = RangeInvariantPsnr(gt[None, ..., i], scored[None, ..., i]).item()
        psnr_str = f'PSNR: {psnr:.1f}'
        add_text(ax[row, i + 1], psnr_str, shown.shape[-3:-1], place='TOP_LEFT', alpha=0.7)

# Column 0 is only used in the first row.
ax[1, 0].axis('off')
ax[2, 0].axis('off')

ax[0, 0].set_title('Input', fontsize=fontsize)
# NOTE(review): channel titles here are 1-based, while the skew figure uses 'Ch0'..'Ch2';
# confirm which convention the paper figures should use.
ax[0, 1].set_title('Ch1', fontsize=fontsize)
ax[0, 2].set_title('Ch2', fontsize=fontsize)
ax[0, 3].set_title('Ch3', fontsize=fontsize)

ax[1, 1].set_ylabel('On Raw Inputs', fontsize=fontsize)
ax[2, 1].set_ylabel('On Binned Inputs', fontsize=fontsize)

# Reduce the spacing between subplots.
plt.subplots_adjust(wspace=0.02, hspace=0.02)
clean_ax(ax)
if savefig:
    fpath = os.path.join(output_data_dir, f'hhmi_comparison_{k_idx}.png')
    print(fpath)
    plt.savefig(fpath, dpi=150, bbox_inches='tight')


In [None]:
from disentangle.data_loader.multitiffsamesized_raw_dloader import get_train_val_data as _loadmultitiff
from disentangle.config_utils import load_config
from disentangle.core.data_split_type import DataSplitType
from disentangle.core.tiff_reader import load_tiff

datadir = '/group/jug/ashesh/data/HHMI25/'
config = load_config('/group/jug/ashesh/training/disentangle/2505/D32-M3-S0-L8/26')
# NOTE: despite its name, `val_data` holds the *Test* split.
val_data = _loadmultitiff(datadir, config.data, datasplit_type=DataSplitType.Test)
target_raw = val_data[0][0]
tar = target_raw * 1.0  # floating-point copy of the target data
# Prediction file loaded in the first cell; the config above is run 26, matching its filename.
pred = load_tiff(orig_res_fpath).squeeze()
print(tar.shape, pred.shape)

In [None]:
# Maximum of the unconverted target stack next to the maximum of the prediction.
(val_data[0][0].max(), pred.max())

In [None]:
from disentangle.scripts.evaluate import compute_high_snr_stats
n_preview = 5  # only the first entries along axis 0, for a quick check
compute_high_snr_stats(tar[:n_preview], pred[:n_preview])

In [None]:
from skimage.transform import resize

# Halve the two spatial axes (1 and 2); axes 0 and 3 keep their size.
# preserve_range=True: without it skimage rescales integer inputs to [0, 1], which would
# silently put `pred` on a different scale than the float `tar`. For float inputs it is a no-op.
tar_resized = resize(tar, (tar.shape[0], tar.shape[1] // 2, tar.shape[2] // 2, tar.shape[3]),
                     anti_aliasing=True, preserve_range=True)
pred_resized = resize(pred, (pred.shape[0], pred.shape[1] // 2, pred.shape[2] // 2, pred.shape[3]),
                      anti_aliasing=True, preserve_range=True)
compute_high_snr_stats(tar_resized[:5], pred_resized[:5])


In [None]:
# Same metrics as above, but on every entry of the 2x-downsampled stacks.
compute_high_snr_stats(tar_resized, pred_resized)


In [None]:
from disentangle.core.tiff_reader import load_tiff
# Prediction and matching ground truth for the kth5 split of run 2507/D33 (index 19).
pred_tif = '/group/jug/ashesh/data/paper_stats/Test_P64_G3-32-32_M50_Sk0/kth_5/pred_training_disentangle_2507_D33-M3-S0-L8_19_1.tif'
gt_tif = '/group/jug/ashesh/kth_data/D33/kth5/gt_for_pred_training_disentangle_2507_D33-M3-S0-L8_19_1.tif'
tmp, tmp_gt = load_tiff(pred_tif), load_tiff(gt_tif)

In [None]:
# Inspect the layout of the loaded prediction (loaded without squeeze).
tmp.shape

In [None]:
import matplotlib.pyplot as plt
# Top row = prediction, bottom row = ground truth, one column per channel.
# The channel count is read from the data instead of being hard-coded to 6.
n_ch = tmp.shape[-1]
crop = 1000  # display only the top-left crop x crop region
# squeeze=False keeps `ax` 2-D even when there is a single channel.
_, ax = plt.subplots(ncols=n_ch, nrows=2, figsize=(2 * n_ch, 4), squeeze=False)
for i in range(n_ch):
    # assumes tmp / tmp_gt are laid out as (H, W, C) — TODO confirm against tmp.shape
    ax[0, i].imshow(tmp[:crop, :crop, i], cmap='magma')
    ax[1, i].imshow(tmp_gt[:crop, :crop, i], cmap='magma')
    ax[0, i].set_title(f'Pred Ch{i}')
    ax[1, i].set_title(f'GT Ch{i}')
    for panel in ax[:, i]:
        panel.axis('off')

In [None]:
# Re-displays the prediction shape; same check as the earlier `tmp.shape` cell.
tmp.shape