In [None]:
from IPython.display import display, HTML
# Inject a CSS rule that stretches the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
# Flag handed to setup_syspath_disentangle(); its exact effect is defined outside
# this notebook (presumably in nb_core/root_dirs.ipynb — TODO confirm).
DEBUG = False

In [None]:
# Execute the helper notebook inside this kernel; presumably it defines
# setup_syspath_disentangle — TODO confirm.
%run ./nb_core/root_dirs.ipynb
# Must run before the disentangle imports so the package can be found on sys.path
# (assumption based on the function name — verify against root_dirs.ipynb).
setup_syspath_disentangle(DEBUG)
# Presumably provides shared imports for the later cells — TODO confirm.
%run ./nb_core/disentangle_imports.ipynb

In [None]:
import os

import matplotlib.pyplot as plt
import ml_collections
import numpy as np

from disentangle.analysis.plot_utils import clean_ax
from disentangle.config_utils import load_config, get_configdir_from_saved_predictionfile
from disentangle.core.data_split_type import DataSplitType
from disentangle.core.tiff_reader import load_tiff
from disentangle.scripts.evaluate import get_highsnr_data

In [None]:
# Noise levels that are actually checked and plotted in this notebook.
noise_levels = [8900]
# Root directory under which all prediction tiffs listed below are stored.
pred_dir = '/group/jug/ashesh/data/paper_stats/'

# Each dict maps a noise level to a prediction file (path relative to pred_dir).
# sanity_check_config() verifies that the key equals the
# `data.synthetic_gaussian_scale` stored in the config of the corresponding run.
usplit_fname = {4450: 'Test_P64_G32_M5_Sk44/pred_disentangle_2402_D16-M3-S0-L0_42.tif',
                8900: 'Test_P64_G32_M5_Sk44/pred_disentangle_2402_D16-M3-S0-L0_43.tif'}

denoiSplit_fname = {4450: 'Test_P256_G64_M5_Sk44/pred_disentangle_2402_D16-M3-S0-L0_16.tif',
                    8900: 'Test_P256_G64_M5_Sk44/pred_disentangle_2402_D16-M3-S0-L0_15.tif'}

denoiSplitNM_fname = {4450: 'Test_P256_G64_M5_Sk44/pred_disentangle_2402_D16-M3-S0-L0_62.tif',
                      8900: 'Test_P256_G64_M5_Sk44/pred_disentangle_2402_D16-M3-S0-L0_61.tif'}
# Currently empty, so no HDN+uSplit predictions are loaded; the previous entry is kept for reference.
hdn_usplit = {} #{4450: 'pred_disentangle_2402_D23-M3-S0-L0_34.tif'}

In [None]:
def sanity_check_config():
    """Check that each prediction file is filed under the noise level it was run with.

    For every (noise, filename) entry of ``usplit_fname``, ``denoiSplit_fname``
    and ``denoiSplitNM_fname``, the config located through
    ``get_configdir_from_saved_predictionfile`` is loaded and its
    ``data.synthetic_gaussian_scale`` is compared with the dictionary key.

    Raises:
        AssertionError: if the config has no ``synthetic_gaussian_scale`` entry,
            or if its value differs from the key. The message of the second
            assertion names the position of the offending dict, the file and
            both values.
    """
    fname_dicts = (usplit_fname, denoiSplit_fname, denoiSplitNM_fname)
    for dict_idx, fname_dict in enumerate(fname_dicts):
        for noise in fname_dict:
            fname = fname_dict[noise]
            data_cfg = load_config(get_configdir_from_saved_predictionfile(fname)).data
            assert 'synthetic_gaussian_scale' in data_cfg
            configured_scale = data_cfg.synthetic_gaussian_scale
            assert configured_scale == noise, f'{dict_idx} {fname}: noise: {noise}, config: {configured_scale}'
    


In [None]:
# Fail early if any prediction file is filed under a noise level that its config disagrees with.
sanity_check_config()

### Loading target

In [None]:
# Locate the config of the denoiSplitNM run at the first configured noise level.
configdir  = get_configdir_from_saved_predictionfile(denoiSplitNM_fname[noise_levels[0]])
# Wrap the loaded config in a ConfigDict before handing it to get_highsnr_data.
config = ml_collections.ConfigDict(load_config(configdir))
# Fetch the Test split of the high-SNR data, used below as the target.
# NOTE(review): the layout of the returned array is defined by get_highsnr_data — confirm there.
highsnr_data = get_highsnr_data(config, config.datadir, DataSplitType.Test)

### Loading predictions

In [None]:
def _load_predictions(fname_dict):
    """Load every prediction tiff listed in ``fname_dict``.

    Args:
        fname_dict: mapping from noise level to a file path relative to the
            notebook-level ``pred_dir``.

    Returns:
        dict with the same keys, whose values are whatever ``load_tiff``
        returns for ``os.path.join(pred_dir, path)``. An empty mapping yields
        an empty dict.
    """
    return {noise: load_tiff(os.path.join(pred_dir, fname)) for noise, fname in fname_dict.items()}


# One dict of loaded predictions per method, keyed by noise level.
usplit_data = _load_predictions(usplit_fname)
denoiSplit_data = _load_predictions(denoiSplit_fname)
denoiSplitNM_data = _load_predictions(denoiSplitNM_fname)
hdn_usplit_data = _load_predictions(hdn_usplit)


### Cropping the target to match the shape of the predictions

In [None]:
# Reference shape: the uSplit prediction at the first configured noise level.
shape = usplit_data[noise_levels[0]].shape
# Crop axes 1 and 2 of the target to the reference extent; axis 0 and any trailing
# axes are left untouched. The copy detaches the result from the uncropped array.
highsnr_data = highsnr_data[:, :shape[1], :shape[2]].copy()

In [None]:
def sanity_check_data():
    """Assert that predictions and target agree in shape at every noise level.

    For each entry of ``noise_levels`` the uSplit prediction provides the
    reference shape. The denoiSplit, denoiSplitNM and HDN+uSplit predictions
    are compared against it only when they contain that noise level; the
    high-SNR target is always compared.

    Raises:
        KeyError: if ``usplit_data`` lacks one of the noise levels.
        AssertionError: on the first shape mismatch.
    """
    optional_preds = (denoiSplit_data, denoiSplitNM_data, hdn_usplit_data)
    for level in noise_levels:
        ref_shape = usplit_data[level].shape
        for preds in optional_preds:
            if level not in preds:
                continue
            assert ref_shape == preds[level].shape
        assert ref_shape == highsnr_data.shape
            

In [None]:
# Abort here if any loaded prediction or the cropped target disagrees in shape.
sanity_check_data()

In [None]:
import numpy as np
def get_noisy_data(noise_level, data=None, seed=None):
    """Return the data with additive zero-mean Gaussian noise.

    Args:
        noise_level: standard deviation of the noise (passed as the ``scale``
            of the normal distribution).
        data: array to corrupt. Defaults to the notebook-level ``highsnr_data``,
            which keeps existing ``get_noisy_data(noise_level)`` calls working.
        seed: optional seed. When given, the noise comes from a dedicated
            ``np.random.default_rng(seed)`` so the result is reproducible.
            When ``None`` the global NumPy RNG is used, exactly as before.

    Returns:
        ``data + noise``, where ``noise`` has the same shape as ``data``.
    """
    if data is None:
        data = highsnr_data
    if seed is None:
        # Legacy path: honours any np.random.seed() set by the caller.
        noise = np.random.normal(0, noise_level, data.shape)
    else:
        noise = np.random.default_rng(seed).normal(0, noise_level, data.shape)
    return data + noise

In [None]:
# Noise level used for the figures below; it is also used as the key into the prediction dicts.
noise_level = 8900
# NOTE(review): the noise is drawn from NumPy's global RNG, so this input differs
# between runs unless a seed has been set beforehand.
noisy_data = get_noisy_data(noise_level)

In [None]:
# Quick-look grid: one row per index of the last axis (0 and 1), one column per data source.
data_idx = 0
ncols = 3 + 2
nrows = 2
img_sz = 4
_, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * img_sz, nrows * img_sz))

# Column 0, top row: mean of the noisy frame over its last axis. The bottom-left panel stays empty.
ax[0, 0].imshow(np.mean(noisy_data[data_idx], axis=-1), cmap='magma')
ax[0, 0].set_title('Input')

# Columns 1-4: noisy data, uSplit prediction, denoiSplit prediction, high-SNR target.
_panel_columns = (
    noisy_data,
    usplit_data[noise_level],
    denoiSplit_data[noise_level],
    highsnr_data,
)
for _col, _stack in enumerate(_panel_columns, start=1):
    for _ch in (0, 1):
        ax[_ch, _col].imshow(_stack[data_idx, ..., _ch], cmap='magma')
clean_ax(ax)


In [None]:
from matplotlib.gridspec import GridSpec
import matplotlib.pyplot as plt

# Figure geometry: every image occupies a square of `grid_img_sz` GridSpec cells.
img_sz = 3
ncol_imgs = 5
nrow_imgs = 2
example_spacing = 1
grid_factor = 5
nimgs = 1
fig_w = ncol_imgs * img_sz
fig_h = img_sz * nrow_imgs + example_spacing * (nimgs - 1) / grid_factor
fig = plt.figure(figsize=(fig_w, fig_h))
gs = GridSpec(nrows=int(grid_factor * fig_h), ncols=int(grid_factor * fig_w), hspace=0.2, wspace=0.2)
grid_img_sz = img_sz * grid_factor

# First column: mean of the noisy frame over its last axis, shifted down by half
# an image so that it straddles the two rows.
_half = grid_img_sz // 2
ax_temp = fig.add_subplot(gs[_half:grid_img_sz + _half, :grid_img_sz])
ax_temp.imshow(np.mean(noisy_data[data_idx], axis=-1), cmap='magma')
clean_ax(ax_temp)

# Remaining columns: noisy data, uSplit prediction, denoiSplitNM prediction, high-SNR target.
_column_stacks = (
    noisy_data,
    usplit_data[noise_level],
    denoiSplitNM_data[noise_level],
    highsnr_data,
)
for _col, _stack in enumerate(_column_stacks, start=1):
    _col_span = slice(grid_img_sz * _col, grid_img_sz * (_col + 1))
    for _ch in (0, 1):
        _row_span = slice(grid_img_sz * _ch, grid_img_sz * (_ch + 1))
        _imshow_kwargs = {}
        if _col == 1 and _ch == 1:
            # Only the noisy index-1 panel is clipped to the intensity range of the matching target image.
            _target = highsnr_data[data_idx, ..., 1]
            _imshow_kwargs = {'vmin': _target.min(), 'vmax': _target.max()}
        ax_temp = fig.add_subplot(gs[_row_span, _col_span])
        ax_temp.imshow(_stack[data_idx, ..., _ch], cmap='magma', **_imshow_kwargs)
        clean_ax(ax_temp)


# ax_temp = fig.add_subplot(gs[row_s:row_s + grid_img_sz, grid_img_sz * i:grid_img_sz * i + grid_img_sz])
# ax_temp.imshow(ch0_pred, vmax=vmax0)

In [None]:
# Value range at index 0 of the last axis: first the target, then the noisy input.
print(highsnr_data[...,0].min(), highsnr_data[...,0].max())
print(noisy_data[...,0].min(), noisy_data[...,0].max())

In [None]:
# Same range check for the uSplit prediction at the selected noise level (displayed as a (min, max) tuple).
usplit_data[noise_level][...,0].min(), usplit_data[noise_level][...,0].max()

In [None]:
# For each saved prediction, report its min/max and how many values fall into the
# bands (64000, 64500], (64500, 65000] and (65000, inf).
# NOTE(review): presumably this probes saturation near the uint16 maximum — confirm.
idx_list = [58,61,65, 30,31,33, 34,29,63, 32,35,36, 48,42,50,51,54,59, 55,43,49, 68,56,62,74,80,78, 75,79,82, 76,81,83, 77,85, 84]
for idx in idx_list:
    data  = load_tiff(f'/group/jug/ashesh/data/paper_stats/All_P128_G64_M50_Sk44/pred_disentangle_2402_D16-M23-S0-L0_{idx}.tif')
    # Half-open bands: with strict comparisons on both sides, values exactly equal
    # to 64500 or 65000 were counted in no band at all.
    r0 = np.logical_and(data > 64000, data <= 64500)
    r1 = np.logical_and(data > 64500, data <= 65000)
    r2 = data > 65000
    print(idx, data.min(), data.max(), r0.sum(), r1.sum(), r2.sum())

In [None]:
# Probe what happens when 10 is added to 65535 held as uint16.
# NOTE(review): expected to wrap around rather than raise — confirm on the installed NumPy version.
a = np.uint16([65535]) + 10
a