In [None]:
from disentangle.config_utils import get_configdir_from_saved_predictionfile
import ml_collections
import os
from disentangle.config_utils import load_config
from disentangle.core.data_type import DataType
from disentangle.scripts.evaluate import * 
from disentangle.core.data_split_type import DataSplitType
from disentangle.core.tiff_reader import load_tiff
# Denoiser prediction to visualize. Its basename is used to recover the saved config directory.
denoised_fpath = '/group/jug/ashesh/data/paper_stats/All_P128_G64_M50_Sk44/pred_disentangle_2403_D16-M23-S0-L0_17.tif'
paper_figures_dir = '/group/jug/ashesh/data/paper_figures'
# Seed NumPy's global RNG so the synthetic noise and the random crops drawn later are reproducible.
SEED = 0
np.random.seed(SEED)

denoised_data = load_tiff(denoised_fpath)
denoiser_configdir = get_configdir_from_saved_predictionfile(os.path.basename(denoised_fpath))
denoiser_config = ml_collections.ConfigDict(load_config(denoiser_configdir))
eval_datasplit_type = DataSplitType.Test

# Raw-data location for each dataset type this notebook supports.
if denoiser_config.data.data_type == DataType.BioSR_MRC:
    denoiser_input_dir = '/group/jug/ashesh/data/BioSR/'
elif denoiser_config.data.data_type == DataType.OptiMEM100_014:
    denoiser_input_dir = '/group/jug/ashesh/data/microscopy/OptiMEM100x014.tif'
elif denoiser_config.data.data_type == DataType.SeparateTiffData:
    denoiser_input_dir = '/group/jug/ashesh/data/ventura_gigascience/'
    # Load the 'highsnr' files instead of the 'lowsnr' ones referenced by the training config.
    denoiser_config.data.ch1_fname = denoiser_config.data.ch1_fname.replace('lowsnr', 'highsnr')
    denoiser_config.data.ch2_fname = denoiser_config.data.ch2_fname.replace('lowsnr', 'highsnr')
else:
    # Fail here with a clear message instead of a NameError on `denoiser_input_dir` below.
    raise ValueError(f'Unsupported data type for this notebook: {denoiser_config.data.data_type}')
with denoiser_config.unlocked():
    highres_data = get_data_without_synthetic_noise(denoiser_input_dir, denoiser_config, eval_datasplit_type)

# Reduce the last axis (presumably channels — TODO confirm) to the target the denoiser was trained on.
if denoiser_config.model.denoise_channel == 'Ch1':
    highres_data = highres_data[..., 0]
elif denoiser_config.model.denoise_channel == 'Ch2':
    highres_data = highres_data[..., 1]
elif denoiser_config.model.denoise_channel == 'input':
    highres_data = np.mean(highres_data, axis=-1)
else:
    raise ValueError('Invalid denoise channel')


In [None]:
def get_noisy_data(highres_data, config=None, rng=None):
    """Simulate a noisy acquisition of ``highres_data``.

    Poisson noise is applied by dividing the signal by
    ``config.data.poisson_noise_factor``, drawing Poisson samples and multiplying
    back by the same factor. If ``config.data.enable_gaussian_noise`` is truthy,
    zero-mean Gaussian noise with standard deviation
    ``config.data.synthetic_gaussian_scale`` (default 0.1) is added on top.

    Args:
        highres_data: array of intensities; expected non-negative, since
            Poisson sampling raises on negative rates.
        config: object exposing ``.data`` with attribute ``poisson_noise_factor``
            and a dict-style ``get``. Defaults to the notebook-global
            ``denoiser_config`` (the original behavior).
        rng: ``np.random.Generator`` for reproducible noise. Defaults to the
            legacy global ``np.random`` state (the original behavior).

    Returns:
        float32 array with the same shape as ``highres_data``.
    """
    if config is None:
        config = denoiser_config
    if rng is None:
        # The np.random module and Generator share the poisson/normal call signatures used below.
        rng = np.random
    poisson_noise_factor = config.data.poisson_noise_factor
    noisy_data = (rng.poisson(highres_data / poisson_noise_factor) * poisson_noise_factor).astype(np.float32)

    if config.data.get('enable_gaussian_noise', False):
        synthetic_scale = config.data.get('synthetic_gaussian_scale', 0.1)
        noisy_data += rng.normal(0, synthetic_scale, highres_data.shape)
    return noisy_data


In [None]:
# Synthesize noisy observations from the reference images shown as 'High SNR' in the figure.
noisy_data = get_noisy_data(highres_data=highres_data)

In [None]:
import matplotlib.pyplot as plt
from disentangle.analysis.plot_utils import clean_ax
# Figure: `nimgs` rows of random crops; columns are noisy input, denoised prediction, high-SNR reference.
nimgs = 3
imgsz = 2  # height of each row, in inches
factor = 1.2  # panels (and crops) are `factor` times taller than they are wide
# squeeze=False keeps `ax` 2-D so ax[i, j] indexing also works when nimgs == 1.
fig, ax = plt.subplots(figsize=(imgsz * 3 / factor, nimgs * imgsz), ncols=3, nrows=nimgs, squeeze=False)
h = 256
w = int(h / factor)
for i in range(nimgs):
    # randint's upper bound is exclusive: +1 makes the last valid offset reachable and
    # avoids a ValueError when the image is exactly h (or w) pixels.
    hs = np.random.randint(0, highres_data.shape[1] - h + 1)
    ws = np.random.randint(0, highres_data.shape[2] - w + 1)
    # Log the crop origin so a chosen panel can be regenerated.
    print(hs, ws)
    # NOTE(review): assumes denoised_data is spatially aligned with highres_data — TODO confirm.
    ax[i, 0].imshow(noisy_data[0, hs:hs + h, ws:ws + w], cmap='magma')
    ax[i, 1].imshow(denoised_data[0, hs:hs + h, ws:ws + w, 0], cmap='magma')
    ax[i, 2].imshow(highres_data[0, hs:hs + h, ws:ws + w], cmap='magma')

ax[0, 0].set_title('Noisy')
ax[0, 1].set_title('Denoised')
ax[0, 2].set_title('High SNR')
clean_ax(ax)
fig.subplots_adjust(wspace=0.02, hspace=0.02)
# Strip the extension with splitext rather than str.replace, which could also match mid-name or mangle '.tiff'.
postfix = os.path.splitext(os.path.basename(denoised_fpath))[0].replace('pred_disentangle_', '')
fpath = os.path.join(paper_figures_dir, f'denoising_{postfix}.png')
fig.savefig(fpath, bbox_inches='tight', dpi=200)
print(fpath)


In [None]:
# Inspect the dimensions of the reference data used for cropping.
np.shape(highres_data)

In [None]:
# Crop height and width used by the figure cell.
(h, w)