### Imports

In [None]:
# Inject a CSS rule that stretches the notebook's `.container` element to 100% width.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import random
import sys
import os 

import numpy as np
import torch
import pickle
import ml_collections
import glob
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib
from tqdm import tqdm
from copy import deepcopy

In [None]:
# Root locations used by the rest of the notebook.
# NOTE: these are absolute, machine-specific paths; adapt them before running elsewhere.
DATA_ROOT = '/group/jug/federico/careamics_training/data'  # prefix of every `data_dir` chosen below
OUT_ROOT = '/group/jug/federico/careamics_training/training'  # prefix of `ckpt_dir` and of the calibration output dir
CODE_ROOT = '/home/federico.carrara/'  # prefix of the careamics source folder appended to sys.path
DEBUG = False  # when True, the data-directory selection uses its reduced DEBUG mapping

In [None]:
sys.path.append(os.path.join(CODE_ROOT, 'Documents/projects/careamics/src'))

from careamics.lvae_training.train_lvae import create_dataset
from careamics.models.lvae.utils import (
    ModelType, LossType
)
from careamics.models.lvae import get_config
from careamics.lvae_training.data_utils import DataType, DataSplitType, GridAlignement, load_tiff
from careamics.lvae_training.metrics import (
    PSNR, 
    RangeInvariantPsnr,
    avg_psnr,
    avg_range_inv_psnr,
    avg_ssim,
    compute_masked_psnr,
    compute_multiscale_ssim
)
from careamics.lvae_training.train_utils import get_mean_std_dict_for_model
from careamics.lvae_training.lightning_module import LadderVAELight
from careamics.lvae_training.eval_utils import (
    show_for_one, 
    get_plots_output_dir,
    get_dset_predictions,
    stitch_predictions,
    Calibration,
    get_calibrated_factor_for_stdev,
    plot_calibration,
    clean_ax,
    plot_error
)
# from disentangle.analysis.lvae_utils import get_img_from_forward_output
# from disentangle.analysis.plot_utils import get_k_largest_indices,plot_imgs_from_idx
# from disentangle.analysis.critic_notebook_utils import get_mmse_dict, get_label_separated_loss
# from disentangle.sampler.random_sampler import RandomSampler

torch.multiprocessing.set_sharing_strategy('file_system')

In [None]:
def fix_seeds():
    """Seed every random number generator used in this notebook with 0.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and
    Python's `random` module, and configures cuDNN for reproducible runs.
    """
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    # Also cover every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(0)
    np.random.seed(0)
    random.seed(0)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may auto-select a different (possibly
    # non-deterministic) convolution algorithm from run to run.
    torch.backends.cudnn.benchmark = False

In [None]:
# Location of the trained model to evaluate; later cells accept either a directory
# (containing config.pkl and the *.ckpt files) or a checkpoint file.
ckpt_dir = os.path.join(OUT_ROOT, '2406/LVAE_denoiSplit/53')
# Fail early if the path does not exist on this machine.
assert os.path.exists(ckpt_dir)

In [None]:
# def get_dtype(ckpt_fpath):
#     if os.path.isdir(ckpt_fpath):
#         ckpt_fpath = ckpt_fpath[:-1] if ckpt_fpath[-1] == '/' else ckpt_fpath
#     elif os.path.isfile(ckpt_fpath):
#         ckpt_fpath = os.path.dirname(ckpt_fpath)
#     assert ckpt_fpath[-1] != '/'
#     return int(ckpt_fpath.split('/')[-2].split('-')[0][1:])

In [None]:
# dtype = get_dtype(ckpt_dir)

### Set Evaluation Parameters

In [None]:
# Set parameters
mmse_count = 10 # forwarded to get_dset_predictions; pred_std is computed over this many predictions
image_size_for_grid_centers = 32 # what we retain from inner padding/tiling
eval_patch_size = None # actual patch size --> if not specified data.image_size
data_t_list = None # list of indexes of the data to be used
model_type = ModelType.LadderVae # forwarded to get_dset_predictions
eval_datasplit_type = DataSplitType.Val    # split to evaluate; calibration factors are fitted on Val and loaded on Test
psnr_type = 'range_invariant' #'simple', 'range_invariant' -- NOTE(review): not read anywhere in the visible cells
enable_calibration = True # toggles the calibration cells further down
which_ckpt = 'last' # 'best', 'last'

save_comparative_plots = False # enables the baseline-comparison cell (which requires the Test split)
batch_size = 32 # forwarded to get_dset_predictions
num_workers = 4 # forwarded to get_dset_predictions
COMPUTE_LOSS = False # NOTE(review): not read anywhere in the visible cells
use_deterministic_grid = None # for training -> get one 64x64 patch at random (not from the grid); only referenced by commented-out code

# threshold = None # 0.02
# compute_kl_loss = False
# evaluate_train = False # inspect training performance
# val_repeat_factor = None

### Load config 

In [None]:
def get_model_checkpoint(ckpt_dir, mode='best'):
    """Return the path of the single checkpoint in `ckpt_dir` for the given mode.

    Parameters
    ----------
    ckpt_dir : str
        Directory that is searched (non-recursively) for checkpoint files.
    mode : str
        Either 'best' (matches ``*_best.ckpt``) or 'last' (matches ``*_last.ckpt``).

    Returns
    -------
    str
        Path of the unique matching checkpoint file.

    Raises
    ------
    ValueError
        If `mode` is neither 'best' nor 'last'.
    AssertionError
        If zero or more than one file matches; the message lists the pattern
        and every match so the problem can be diagnosed.
    """
    # Validate first so a typo in `mode` is reported before touching the filesystem.
    if mode not in ('best', 'last'):
        raise ValueError(f"Mode can be either 'best' or 'last', while you selected {mode}.")

    pattern = ckpt_dir + f"/*_{mode}.ckpt"
    output = glob.glob(pattern)
    assert len(output) == 1, (
        f"Expected exactly one checkpoint matching '{pattern}', found {len(output)}:\n"
        + '\n'.join(output)
    )
    return output[0]

In [None]:
def load_config(config_fpath):
    """Unpickle a training configuration.

    `config_fpath` may be a directory, in which case ``config.pkl`` inside it
    is read, or a path ending in ``.pkl`` (anything else trips an assertion).

    NOTE: `pickle.load` can execute arbitrary code; only use this on
    configuration files you produced yourself.
    """
    if not os.path.isdir(config_fpath):
        assert config_fpath[-4:] == '.pkl', f'{config_fpath} is not a pickle file. Aborting'
    else:
        config_fpath = os.path.join(config_fpath, 'config.pkl')

    with open(config_fpath, 'rb') as handle:
        return pickle.load(handle)

In [None]:
# `ckpt_dir` may point at the run directory or at a checkpoint file inside it;
# in both cases the config is read from the run directory.
if os.path.isdir(ckpt_dir):
    config = load_config(ckpt_dir)
else:
    config = load_config(os.path.dirname(ckpt_dir))

# Wrap in a ConfigDict so the cells below can use attribute access and `unlocked()`.
config = ml_collections.ConfigDict(config)

In [None]:
# Show the configuration the model was trained with.
print(config)

Changing config parameters should no longer be needed, since only a few model parameters are customizable now.

In [None]:
# Adapt the training config for evaluation.
# The training-time values that get overridden are remembered so that later cells
# can restore them (`training_image_size` is restored before the model is built).
training_image_size = None
training_grid_size = None
with config.unlocked():
    # Provide defaults for keys that may be missing from the loaded config.
    if 'datadir' not in config:
        config.datadir = ''

    if 'predict_logvar' not in config.model:
        config.model.predict_logvar = None

    # Optionally evaluate on a patch size different from the training one.
    if eval_patch_size is not None:
        training_image_size = config.data.image_size
        config.data.image_size = eval_patch_size

    # Optionally override the grid size (the inner part of each tile that is retained).
    if image_size_for_grid_centers is not None:
        training_grid_size = config.data.get('grid_size', "grid_size not present")
        config.data.grid_size = image_size_for_grid_centers

# print(config)

In [None]:
# Pick the dataset folder that corresponds to the data type stored in the config.
dtype = config.data.data_type

if DEBUG:
    if dtype == DataType.CustomSinosoid:
        data_dir = f'{DATA_ROOT}/sinosoid/'
    elif dtype == DataType.OptiMEM100_014:
        data_dir = f'{DATA_ROOT}/microscopy/'
    else:
        # Without this branch `data_dir` would stay undefined and the print below
        # would fail with an unrelated NameError.
        raise ValueError(f'No DEBUG data directory is configured for data_type={dtype}.')
else:
    if dtype in [DataType.CustomSinosoid, DataType.CustomSinosoidThreeCurve]:
        data_dir = f'{DATA_ROOT}/sinosoid_without_test/sinosoid/'
    elif dtype == DataType.OptiMEM100_014:
        data_dir = f'{DATA_ROOT}/microscopy/'
    elif dtype == DataType.Prevedel_EMBL:
        data_dir = f'{DATA_ROOT}/Prevedel_EMBL/PKG_3P_dualcolor_stacks/NoAverage_NoRegistration/'
    elif dtype == DataType.AllenCellMito:
        data_dir = f'{DATA_ROOT}/allencell/2017_03_08_Struct_First_Pass_Seg/AICS-11/'
    elif dtype == DataType.SeparateTiffData:
        data_dir = f'{DATA_ROOT}/ventura_gigascience'
    elif dtype == DataType.SemiSupBloodVesselsEMBL:
        data_dir = f'{DATA_ROOT}/EMBL_halfsupervised/Demixing_3P'
    elif dtype == DataType.Pavia2VanillaSplitting:
        data_dir = f'{DATA_ROOT}/pavia2'
    elif dtype == DataType.ExpansionMicroscopyMitoTub:
        data_dir = f'{DATA_ROOT}/expansion_microscopy_Nick/'
    elif dtype == DataType.ShroffMitoEr:
        data_dir = f'{DATA_ROOT}/shrofflab/'
    elif dtype == DataType.HTIba1Ki67:
        data_dir = f'{DATA_ROOT}/Stefania/20230327_Ki67_and_Iba1_trainingdata/'
    elif dtype == DataType.BioSR_MRC:
        data_dir = f'{DATA_ROOT}/BioSR/'
    elif dtype == DataType.ExpMicroscopyV2:
        data_dir = f'{DATA_ROOT}/expansion_microscopy_v2/'
    elif dtype == DataType.TavernaSox2GolgiV2:
        data_dir = f'{DATA_ROOT}/TavernaSox2Golgi/acquisition2/'
    elif dtype == DataType.Pavia3SeqData:
        data_dir = f'{DATA_ROOT}/pavia3_sequential/'
    elif dtype == DataType.NicolaData:
        data_dir = f'{DATA_ROOT}/nikola_data/raw'
    else:
        # Fail with an explicit message instead of a NameError on `data_dir`.
        raise ValueError(f'No data directory is configured for data_type={dtype}.')

print(data_dir)

### Load data and model

In [None]:
# Constant-value padding; the fill value comes from the config and defaults to 0.
padding_kwargs = {
    'mode': 'constant',
    'constant_values': config.data.get('padding_value', 0),
}

# Extra keyword arguments handed to `create_dataset` through `kwargs_dict`.
dloader_kwargs = dict(
    overlapping_padding_kwargs=padding_kwargs,
    grid_alignment=GridAlignement.Center,
)

In [None]:
# Build the training dataset and the evaluation dataset for the selected split.
# Despite its name, `val_dset` holds whichever split `eval_datasplit_type` selects.
train_dset, val_dset = create_dataset(
    config, 
    data_dir, 
    eval_datasplit_type=eval_datasplit_type,
    kwargs_dict=dloader_kwargs
)
# NOTE(review): data_mean / data_std are not referenced again in the visible cells.
data_mean, data_std = train_dset.get_mean_std()

In [None]:
# create dataset without poisson noise as ground truth
new_config = deepcopy(ml_collections.ConfigDict(config))
if 'poisson_noise_factor' in new_config.data:
    new_config.data.poisson_noise_factor = -1
if 'enable_gaussian_noise' in new_config.data:
    new_config.data.enable_gaussian_noise = False  
    
_, highsnr_val_dset = create_dataset(
    new_config, 
    data_dir, 
    eval_datasplit_type=eval_datasplit_type,
    kwargs_dict=dloader_kwargs
)

In [None]:
# Restore the training-time patch size (if it was overridden for evaluation)
# before the model is instantiated from the config.
with config.unlocked():
    if training_image_size is not None:
        config.data.image_size = training_image_size
        
# Normalisation statistics handed to the model, derived from the training dataset.
mean_dict, std_dict = get_mean_std_dict_for_model(config, train_dset)
  
model = LadderVAELight(
    config, 
    mean_dict, 
    std_dict,
    target_ch=config.data.num_channels
)

In [None]:
# Resolve the checkpoint file: pick it from the run directory according to
# `which_ckpt`, or use `ckpt_dir` directly when it already is a file.
if os.path.isdir(ckpt_dir):
    ckpt_fpath = get_model_checkpoint(ckpt_dir, mode=which_ckpt)
else:
    assert os.path.isfile(ckpt_dir)
    ckpt_fpath = ckpt_dir

print('Loading checkpoint from', ckpt_fpath)
checkpoint = torch.load(ckpt_fpath)

# strict=False tolerates key mismatches between checkpoint and model. Report
# them instead of discarding the result, otherwise a partially loaded model
# goes unnoticed.
load_result = model.load_state_dict(checkpoint['state_dict'], strict=False)
if load_result.missing_keys:
    print(f'WARNING: {len(load_result.missing_keys)} model keys missing from the checkpoint:',
          load_result.missing_keys)
if load_result.unexpected_keys:
    print(f'WARNING: {len(load_result.unexpected_keys)} unexpected keys in the checkpoint:',
          load_result.unexpected_keys)

# Inference mode, on the GPU.
model.eval()
_ = model.cuda()

# Hand the model a CUDA tensor so it can move its extra parameters to that device.
model.set_params_to_same_device_as(torch.Tensor(1).cuda())

print('Loading from epoch', checkpoint['epoch'])

In [None]:
def count_parameters(model):
    """Return the total number of elements across the trainable parameters of `model`.

    Parameters with ``requires_grad == False`` are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the trainable-parameter count in millions.
print(f'Model has {count_parameters(model)/1000_000:.3f}M parameters')

In [None]:
# Only relevant when evaluating with a custom patch size on a model configured
# with `multiscale_lowres_count`: let the model adapt to the new output size.
if config.data.multiscale_lowres_count is not None and eval_patch_size is not None:
    model.reset_for_different_output_size(eval_patch_size)

### From here on we perform evaluation

Visualize Data: noisy & ground truth

In [None]:
# Print input (first row) and target (second row) of the val_dset
idx = np.random.randint(len(val_dset))
inp_tmp, tar_tmp, *_ = val_dset[idx]
gt_inp_tmp, gt_tar_tmp, *_ = highsnr_val_dset[idx]

# Noisy
ncols = len(tar_tmp)
nrows = 2
_, ax = plt.subplots(figsize=(4*ncols,4*nrows), ncols=ncols, nrows=nrows)
plt.suptitle("Noisy patches")
for i in range(min(ncols, len(inp_tmp))):
    ax[0,i].imshow(inp_tmp[i])

for channel_id in range(ncols):
    ax[1,channel_id].imshow(tar_tmp[channel_id])
    
# Ground truth
ncols = len(gt_tar_tmp)
_, ax = plt.subplots(figsize=(4*ncols,4*nrows), ncols=ncols, nrows=nrows)
plt.suptitle("Ground Truth patches")
for i in range(min(ncols, len(gt_inp_tmp))):
    ax[0,i].imshow(gt_inp_tmp[i])

for channel_id in range(ncols):
    ax[1,channel_id].imshow(gt_tar_tmp[channel_id])

In [None]:
# Optionally restrict the evaluation dataset to the selected indexes.
# Only `val_dset` is reduced here; the high-SNR data is indexed with
# `data_t_list` later, when the metrics are prepared.
if data_t_list is not None:
    val_dset.reduce_data(t_list=data_t_list)

In [None]:
def get_full_input_frame(idx, dset):
    """Build the full input frame for sample `idx` as the mean of its channels.

    The channels (and optional noise) come from ``dset._load_img(idx)``. When
    noise is present, the first noise entry is added to every channel, scaled
    by sqrt(2) if ``dset._input_is_sum`` is truthy.

    Returns
    -------
    tuple
        ``(frame, h_start, w_start)`` where the last two values come from
        ``dset._get_deterministic_hw(idx)``.
    """
    channels, noise = dset._load_img(idx)
    has_noise = len(noise) > 0
    if has_noise:
        if dset._input_is_sum:
            scale = np.sqrt(2)
        else:
            scale = 1.0
        channels = [channel + noise[0] * scale for channel in channels]

    # Accumulate the channel average one term at a time.
    n_channels = len(channels)
    inp = 0
    for channel in channels:
        inp += channel/n_channels

    h_start, w_start = dset._get_deterministic_hw(idx)
    return inp, h_start, w_start

In [None]:
# Pick a random sample and rebuild the full frame it was cropped from.
index = np.random.randint(len(val_dset))
# The dataset may return more than (input, target), as the `*_` unpacking in the
# visualisation cell suggests, so tolerate extra items here as well.
inp, tar, *_ = val_dset[index]
frame, h_start, w_start = get_full_input_frame(index, val_dset)
print(h_start, w_start)

#### Plot predictions against a baseline for specific indexes

In [None]:
def get_hwt_start(idx):
    """Return the (h, w, t) start position of sample `idx` of `val_dset`.

    The position reported by the dataset's index manager is shifted up/left by
    the per-side overlap so that it refers to the start of the padded patch.

    NOTE(review): grid_size is hard-coded to 64 here, while the notebook sets
    `image_size_for_grid_centers = 32` -- confirm this is intended.
    """
    row, col, frame = val_dset.idx_manager.hwt_from_idx(idx, grid_size=64)
    print(row, col, frame)
    margin = val_dset.per_side_overlap_pixelcount()
    return row - margin, col - margin, frame

def get_crop_from_fulldset_prediction(full_dset_pred, idx, patch_size=256):
    """Cut the patch belonging to sample `idx` out of a full-dataset prediction.

    The crop ``full_dset_pred[t, h:h+patch_size, w:w+patch_size]`` is cast to
    float32 and its last axis is moved to the front.
    (assumes `full_dset_pred` is indexed as (t, h, w, channel) -- TODO confirm)
    """
    h, w, t = get_hwt_start(idx)
    crop = full_dset_pred[t, h:h + patch_size, w:w + patch_size].astype(np.float32)
    # Add a leading axis, swap it with the last axis, then drop the now-trailing
    # singleton axis: (H, W, C) -> (1, H, W, C) -> (C, H, W, 1) -> (C, H, W).
    expanded = crop[None]
    return expanded.swapaxes(0, 3)[..., 0]

# Compare the model prediction for one sample against a baseline prediction
# loaded from disk, and save the resulting figure.
# NOTE(review): this branch overwrites the notebook-level `mmse_count` (and `idx`),
# so enabling it also changes the MMSE count used by every later cell.
if save_comparative_plots: # disabled by default in the parameter cell
    # The baseline predictions below were produced on the test split.
    assert eval_datasplit_type == DataSplitType.Test
    # CCP vs Microtubules: 925, 659, 502
    # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G16_M3_Sk0/pred_disentangle_2402_D23-M3-S0-L0_67.tif')
    hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G32_M5_Sk0/pred_disentangle_2403_D23-M3-S0-L0_29.tif')

    # ER vs Microtubule 853, 859, 332
    # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G16_M3_Sk0/pred_disentangle_2402_D23-M3-S0-L0_60.tif')

    #  ER vs CCP 327, 479, 637, 568
    # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G16_M3_Sk0/pred_disentangle_2402_D23-M3-S0-L0_59.tif')

    #  F-actin vs ER 797
    # hdn_usplitdata = load_tiff('/group/jug/ashesh/data/paper_stats/Test_PNone_G32_M10_Sk0/pred_disentangle_2403_D23-M3-S0-L0_15.tif')

    idx = 10 #np.random.randint(len(val_dset))
    patch_size = 500
    mmse_count = 50
    print(idx)
    # NOTE(review): the baseline crop uses get_crop_from_fulldset_prediction's default
    # patch_size (256), not the `patch_size` of 500 passed to show_for_one -- confirm.
    show_for_one(
        idx, val_dset, 
        highsnr_val_dset, 
        model, 
        None, 
        mmse_count=mmse_count, 
        patch_size=patch_size, 
        baseline_preds=[
            get_crop_from_fulldset_prediction(hdn_usplitdata, idx).astype(np.float32),
        ],
        num_samples=0
    )

    plotsdir = get_plots_output_dir(
        ckpt_dir, 
        patch_size, 
        mmse_count=mmse_count
    )
    
    # The last path component of the run directory identifies the model.
    model_id = ckpt_dir.strip('/').split('/')[-1]
    fname = f'patch_comparison_{idx}_{model_id}.png'
    fpath = os.path.join(plotsdir, fname)
    plt.savefig(fpath, dpi=200, bbox_inches='tight')
    print(f'Saved to {fpath}')

#### Compute predictions and related metrics (PSNR) for the entire validation set

In [None]:
# patch-wise PSNR here

# Run the model over every tile of the evaluation dataset.
# Returns the tiled predictions, the reconstruction loss, the tiled logvar,
# per-channel patch-wise PSNR values and the tiled std over the MMSE predictions.
pred_tiled, rec_loss, logvar_tiled, patch_psnr_tuple, pred_std_tiled = get_dset_predictions(
  model, 
  val_dset,
  batch_size,
  num_workers=num_workers,
  mmse_count=mmse_count,
  model_type = model_type,
)
# Each PSNR entry is converted to a Python scalar before rounding.
tmp = np.round([x.item() for x in patch_psnr_tuple],2)
print('Patch wise PSNR, as computed during training', tmp, np.mean(tmp))
print(f'Number of predicted tiles: {pred_tiled.shape[0]}, channels: {pred_tiled.shape[1]}, shape: {pred_tiled.shape[2:]}')
# Selected quantiles (min, 1%, median, 90%, 99%, 99.9%, max) of the reconstruction loss.
print(f'Reconstruction loss distrib: {np.quantile(rec_loss, [0,0.01,0.5, 0.9,0.99,0.999,1]).round(2)}')

In [None]:
# Print tiles in which the logvar is very low
idx_list = np.where(logvar_tiled.squeeze() < -6)[0]
if len(idx_list) > 0:
    plt.imshow(val_dset[idx_list[0]][1][1])

Get full image predictions by stitching the predicted tiles

In [None]:
# If the predicted tiles are smaller than the dataset's tile size, zero-pad them
# symmetrically along the two spatial axes so the sizes match.
if pred_tiled.shape[-1] != val_dset.get_img_sz():
    pad = (val_dset.get_img_sz() - pred_tiled.shape[-1] )//2
    pred_tiled = np.pad(pred_tiled, ((0,0),(0,0),(pad,pad),(pad,pad)))

# Stitch tiled predictions
pred = stitch_predictions(
    pred_tiled, 
    val_dset, 
    smoothening_pixelcount=0
)

# Stitch predicted tiled logvar
# A logvar array holding a single unique value is treated as "no logvar available".
if len(np.unique(logvar_tiled)) == 1:
    logvar = None
else:
    logvar = stitch_predictions(logvar_tiled, val_dset, smoothening_pixelcount=0)

# Stitch the std of the predictions (i.e., std computed on the mmse_count predictions)
pred_std = stitch_predictions(pred_std_tiled, val_dset, smoothening_pixelcount=0)

In [None]:
# When only a subset of channels are targets, keep the first
# len(target_idx_list) channels of the prediction and of its std.
# NOTE(review): `logvar` is not sliced here.
if 'target_idx_list' in config.data and config.data.target_idx_list is not None:
    pred = pred[...,:len(config.data.target_idx_list)]
    pred_std = pred_std[...,:len(config.data.target_idx_list)]

Ignore (and remove) the pixels in the last few rows and columns (present when the image size is not a multiple of patch_size):
1. They are not part of any batch, so in the prediction they are simply zeros. They are therefore ignored for now.
2. For the border pixels at the top and on the left, overlapping yields worse performance. This is because there is nothing to overlap with on one side, so those tiles are essentially zero-padded, which degrades the prediction.

In [None]:
def get_ignored_pixels(arr=None):
    """Count how many trailing rows and columns of the first image are all zero.

    The check grows a square anchored at the bottom-right corner of ``arr[0]``
    and stops as soon as that square contains a non-zero value.

    Parameters
    ----------
    arr : array-like, optional
        Stitched prediction with the image index on axis 0 and the two spatial
        axes on axes 1 and 2. Defaults to the notebook-level `pred`.

    Returns
    -------
    int
        Side of the largest all-zero bottom-right square, between 0 and
        ``min(arr.shape[1], arr.shape[2])``.
    """
    if arr is None:
        arr = pred

    # Bound the search by the image size: a negative slice start larger than the
    # axis silently selects the whole axis, so an all-zero image would otherwise
    # never terminate.
    max_pixels = min(arr.shape[1], arr.shape[2])
    ignored_pixels = 0
    while ignored_pixels < max_pixels:
        corner = arr[0, -(ignored_pixels + 1):, -(ignored_pixels + 1):]
        # Test for actual zeros. A `std() == 0` test also accepts any constant
        # region, including every single-pixel, single-channel corner.
        if corner.any():
            break
        ignored_pixels += 1

    print(f'In {arr.shape}, last {ignored_pixels} many rows and columns are all zero.')
    return ignored_pixels

# Measured size of the empty bottom-right border of the stitched prediction;
# compared against the configured `ignored_last_pixels` below.
actual_ignored_pixels = get_ignored_pixels()

In [None]:
# Number of trailing rows/columns to drop before computing metrics, chosen per
# data type (hard-coded values).
if config.data.data_type in [
    DataType.OptiMEM100_014,
    DataType.SemiSupBloodVesselsEMBL, 
    DataType.Pavia2VanillaSplitting,
    DataType.ExpansionMicroscopyMitoTub,
    DataType.ShroffMitoEr,
    DataType.HTIba1Ki67
]:
    ignored_last_pixels = 32 
elif config.data.data_type == DataType.BioSR_MRC:
    ignored_last_pixels = 44
    # A tile size of 128 needs a larger margin for this data type.
    if val_dset.get_img_sz() == 128:
        ignored_last_pixels = 108
elif config.data.data_type == DataType.NicolaData:
    ignored_last_pixels = 8
else:
    ignored_last_pixels = 0

# No rows/columns are dropped at the top/left.
ignore_first_pixels = 0
# ignored_last_pixels = 160
# The configured margin must cover at least the border measured on the prediction;
# the message tells which value to set otherwise.
assert actual_ignored_pixels <= ignored_last_pixels, f'Set ignored_last_pixels={actual_ignored_pixels}'
print(ignored_last_pixels)

In [None]:
# Raw data of the evaluation dataset, used as the (noisy) target.
tar = val_dset._data
# Keep only the channels that are prediction targets, if a subset is configured.
if 'target_idx_list' in config.data and config.data.target_idx_list is not None:
    tar = tar[...,config.data.target_idx_list]

def ignore_pixels(arr, patch_size):
    """Crop the border rows/columns of `arr` that are excluded from evaluation.

    Nothing is cropped when ``arr.shape[2]`` is a multiple of `patch_size`.
    Otherwise the notebook-level `ignore_first_pixels` are removed from the
    start, and `ignored_last_pixels` from the end, of axes 1 and 2.
    """
    if arr.shape[2] % patch_size == 0:
        return arr

    if ignore_first_pixels:
        arr = arr[:, ignore_first_pixels:, ignore_first_pixels:]
    if ignored_last_pixels:
        arr = arr[:, :-ignored_last_pixels, :-ignored_last_pixels]
    return arr

# Apply the same border crop to prediction, target and prediction std.
# NOTE: `pred`, `tar` and `pred_std` are overwritten in place, so re-running this
# cell crops them a second time whenever the cropped width is still not a
# multiple of the tile size.
pred = ignore_pixels(pred, val_dset.get_img_sz())
tar = ignore_pixels(tar, val_dset.get_img_sz())
if pred_std is not None:
    pred_std = ignore_pixels(pred_std, val_dset.get_img_sz())
    
print(pred.shape)

#### Perform Calibration

In [None]:
# Fetch the normalisation statistics stored on the model.
sep_mean, sep_std = model.data_mean, model.data_std
# When they are stored as dicts, the statistics of the target are used.
if isinstance(sep_mean, dict):
    sep_mean = sep_mean['target']
    sep_std = sep_std['target']

if isinstance(sep_mean, int):
    # Plain integer statistics are used as they are.
    pass
else:
    # Add three leading singleton axes so the per-channel values broadcast
    # against channel-last arrays, then convert to NumPy.
    sep_mean = sep_mean.squeeze()[None,None,None]
    sep_std = sep_std.squeeze()[None,None,None]
    sep_mean = sep_mean.cpu().numpy() 
    sep_std = sep_std.cpu().numpy()

# Bring the target into the normalised space the predictions live in.
tar_normalized = (tar - sep_mean)/ sep_std

# Check if normalization is correct (i.e., not already applied on tar)
print(f"Channelwise means: tar -> {tar.mean(axis=(0,1,2))}, normalized -> {tar_normalized.mean(axis=(0,1,2))}")

Plot RMV vs. RMSE without Calibration

In [None]:
# NOTE: Recall the `pred_std` here is the pixel-wise std of the mmse_count many predictions
if enable_calibration:
    calib = Calibration(
        num_bins=30, 
        mode='pixelwise'
    )
    native_stats = calib.compute_stats(
        pred=pred, 
        pred_logvar=pred_std, 
        target=tar_normalized
    )
    count = np.array(native_stats[0]['bin_count'])
    count = count / count.sum()
    # print(count.cumsum()[:-1])
    plt.plot(native_stats[0]['rmv'][1:-1], native_stats[0]['rmse'][1:-1], 'o')
    plt.title("RMV vs. RMSE plot - Not Calibrated")
    plt.xlabel('RMV'), plt.ylabel('RMSE')

Observe that the plot is far from resembling y = x!

In [None]:
def get_calibration_fnames(ckpt_dir):
    """Derive the calibration file names from the last three parts of `ckpt_dir`.

    The path is expected to end in ``<monthyear>/<model_specs>/<modelid>`` where
    `modelid` is an integer; dashes are removed from `model_specs`.

    Returns
    -------
    dict
        ``{'stats': <stats file name>, 'factor': <factor file name>}``.
    """
    parts = ckpt_dir.strip('/').split('/')
    run_id = int(parts[-1])
    specs = parts[-2].replace('-', '')
    period = parts[-3]
    suffix = f'{period}_{specs}_{run_id}'
    return {
        'stats': f'calibration_stats_{suffix}.pkl.npy',
        'factor': f'calibration_factor_{suffix}.npy',
    }

def get_calibration_factor_fname(ckpt_dir):
    """Return the file name under which the calibration factor is stored."""
    fnames = get_calibration_fnames(ckpt_dir)
    return fnames['factor']

def get_calibration_stats_fname(ckpt_dir):
    """Return the file name under which the calibration stats are stored."""
    fnames = get_calibration_fnames(ckpt_dir)
    return fnames['stats']

In [None]:
if enable_calibration:
    # Only the input patch is needed (for its spatial size); tolerate datasets
    # that return more than two items per sample.
    inp, *_ = val_dset[0]
    plotsdir = get_plots_output_dir(OUT_ROOT, inp.shape[1], mmse_count=mmse_count)
    fname = get_calibration_factor_fname(ckpt_dir)
    factor_fpath = os.path.join(plotsdir, fname)

    if eval_datasplit_type == DataSplitType.Val:
        # Fit one calibration factor per predicted channel (any channel count).
        channel_factors = [
            get_calibrated_factor_for_stdev(
                pred[..., ch],
                np.log(pred_std[..., ch]**2),
                tar_normalized[..., ch],
                batch_size=8,
                lr=0.1,
            )
            for ch in range(pred.shape[-1])
        ]
        print(*channel_factors)
        # Shape (1, 1, 1, C) so the factors broadcast against channel-last arrays.
        calib_factor = np.array(channel_factors).reshape(1, 1, 1, len(channel_factors))
        np.save(factor_fpath, calib_factor)
        print(f'Saved calibration factor fitted on validation set to {factor_fpath}')

    elif eval_datasplit_type == DataSplitType.Test:
        # Use the factor previously fitted on the validation set.
        print('Loading the calibration factor from the file', factor_fpath)
        calib_factor = np.load(factor_fpath)

    else:
        # Without a factor the code below would fail with an unrelated NameError.
        raise ValueError(
            'Calibration is only supported for the Val (fit) and Test (load) splits, '
            f'got eval_datasplit_type={eval_datasplit_type}.'
        )

    # Given the calibration factor, plot RMV vs. RMSE
    calib = Calibration(num_bins=30, mode='pixelwise')
    pred_logvar = 2* np.log(pred_std * calib_factor)
    stats = calib.compute_stats(
        pred,
        pred_logvar, 
        tar_normalized
    )
    _,ax = plt.subplots(figsize=(5,5))
    plt.title("RMV vs. RMSE plot - Calibrated")
    plot_calibration(ax, stats)

    # `plotsdir` and `stats` only exist when calibration ran, so the stats are
    # saved inside this branch.
    if eval_datasplit_type == DataSplitType.Test:
        stats_fpath = os.path.join(plotsdir, get_calibration_stats_fname(ckpt_dir))
        np.save(stats_fpath, stats)
        print('Saved stats of Test set to ', stats_fpath)

A fancier Calibration Plot with multiple calibration factors:

In [None]:
def get_last_index(bin_count, quantile):
    """Return how many trailing bins lie at or above `quantile` of the cumulative count.

    The normalised cumulative sum of `bin_count` is scanned from the last bin
    towards the first (bin 0 itself is never inspected). The result is the
    number of bins seen before the first one whose cumulative fraction is below
    `quantile`, or None when no such bin is found.
    """
    cdf = np.cumsum(bin_count)
    cdf = cdf / cdf[-1]
    # cdf[:0:-1] walks from the last element down to (and including) index 1.
    for offset, fraction in enumerate(cdf[:0:-1], start=1):
        if fraction < quantile:
            return offset - 1
    return None


def get_first_index(bin_count, quantile):
    """Return the index of the first bin whose cumulative fraction exceeds `quantile`.

    The cumulative sum of `bin_count` is normalised by its last value. None is
    returned when no bin exceeds `quantile`.
    """
    cdf = np.cumsum(bin_count)
    cdf = cdf / cdf[-1]
    above = np.flatnonzero(cdf > quantile)
    if above.size == 0:
        return None
    return int(above[0])

In [None]:
# Load previously saved calibration *stats* (one file per model) to overlay them.
# NOTE: the directory and file names below are placeholders and must be adapted.
# NOTE: allow_pickle=True unpickles arbitrary objects; only load files you created.
try:
    calib_factors = [
        np.load(os.path.join('/path/to/calibration/factors/dir/', fpath), allow_pickle=True) 
        for fpath in [
            'calibration_stats_1.pkl.npy',
            'calibration_stats_2.pkl.npy',
            'calibration_stats_3.pkl.npy', 
        ]
    ]
    labels = ['w=0.5', 'w=0.9', 'w=1']
except FileNotFoundError:
    print('Calibration factors not found. Skipping the plot.')
    calib_factors = []

if len(calib_factors) > 0:
    _,ax = plt.subplots(figsize=(5,2.5))
    for i, calibration_stats in enumerate(calib_factors):
        # `[()]` unwraps the 0-d object array produced by np.save; `[0]` selects entry 0.
        channel_stats = calibration_stats[()][0]
        # Trim the sparsely populated bins at both ends of the distribution.
        first_idx = get_first_index(channel_stats['bin_count'], 0.0001)
        last_idx = get_last_index(channel_stats['bin_count'], 0.9999)
        # get_last_index may return None (nothing to trim); `-None` would raise
        # a TypeError, so fall back to an open-ended slice.
        stop = -last_idx if last_idx else None
        ax.plot(
            channel_stats['rmv'][first_idx:stop],
            channel_stats['rmse'][first_idx:stop],
            '-+',
            label=labels[i]
        )

    ax.yaxis.grid(color='gray', linestyle='dashed')
    ax.xaxis.grid(color='gray', linestyle='dashed')
    # Reference line y = x (perfect calibration).
    ax.plot(np.arange(0,1.5, 0.01), np.arange(0,1.5, 0.01), 'k--')
    ax.set_facecolor('xkcd:light grey')
    plt.legend(loc='lower right')
    # plt.xlim(0,3)
    # plt.ylim(0,1.25)
    plt.xlabel('RMV')
    plt.ylabel('RMSE')
    ax.set_axisbelow(True)


    plotsdir = get_plots_output_dir(ckpt_dir, 0, mmse_count=0)
    model_id = ckpt_dir.strip('/').split('/')[-1]
    fname = f'calibration_plot_{model_id}.png'
    fpath = os.path.join(plotsdir, fname)
    # plt.savefig(fpath, dpi=200, bbox_inches='tight')
    # Saving is disabled above, so do not claim that the file was written.
    print(f'Figure not saved (savefig is commented out); intended path: {fpath}')


#### Visually compare Targets and Predictions

In [None]:
# One random target vs predicted image (patch of shape [sz x sz])
# One random crop of target vs. prediction: one row per channel,
# target in the left column and prediction in the right column.
n_channels = tar.shape[-1]
# The grid is (n_channels, 2) to match the `ax[channel, column]` indexing below;
# squeeze=False keeps `ax` 2-D for single-channel data.
_,ax = plt.subplots(figsize=(2*5, n_channels*5), nrows=n_channels, ncols=2, squeeze=False)
img_idx = 0
# Never ask for a crop larger than the image itself.
sz = min(800, tar.shape[1], tar.shape[2])
# +1 because randint's upper bound is exclusive (and must be > 0).
hs = np.random.randint(tar.shape[1] - sz + 1)
ws = np.random.randint(tar.shape[2] - sz + 1)
for i in range(n_channels):
    ax[i,0].set_title(f'Target Channel {i+1}')
    ax[i,0].imshow(tar[img_idx, hs:hs+sz, ws:ws+sz, i])
    ax[i,1].set_title(f'Predicted Channel {i+1}')
    ax[i,1].imshow(pred[img_idx, hs:hs+sz, ws:ws+sz, i])

# plt.subplots_adjust(wspace=0.1, hspace=0.1)
# clean_ax(ax)

In [None]:
# For one random image show, per channel: target, prediction, the error map of
# the full image and the error map of a random crop (outlined in red on the
# full error map).
nrows = pred.shape[-1]
img_sz = 3
# squeeze=False keeps `ax` 2-D even when there is a single channel.
_,ax = plt.subplots(figsize=(4*img_sz,nrows*img_sz), ncols=4, nrows=nrows, squeeze=False)
idx = np.random.randint(len(pred))
print(idx)
# Never ask for a crop larger than the image itself.
cropsz = min(256, tar_normalized.shape[1], tar_normalized.shape[2])
for ch_id in range(nrows):
    ax[ch_id,0].set_title(f'Target Channel {ch_id+1}')
    ax[ch_id,0].imshow(tar_normalized[idx,..., ch_id], cmap='magma')
    ax[ch_id,1].set_title(f'Predicted Channel {ch_id+1}')
    ax[ch_id,1].imshow(pred[idx,:,:,ch_id], cmap='magma')
    plot_error(
        tar_normalized[idx,...,ch_id], 
        pred[idx,:,:,ch_id], 
        cmap = matplotlib.cm.coolwarm, 
        ax = ax[ch_id,2], 
        max_val = None
    )

    # Random crop position; +1 because randint's upper bound is exclusive.
    h_s = np.random.randint(0, tar_normalized.shape[1] - cropsz + 1)
    h_e = h_s + cropsz
    w_s = np.random.randint(0, tar_normalized.shape[2] - cropsz + 1)
    w_e = w_s + cropsz

    plot_error(
        tar_normalized[idx,h_s:h_e,w_s:w_e, ch_id], 
        pred[idx,h_s:h_e,w_s:w_e,ch_id], 
        cmap = matplotlib.cm.coolwarm, 
        ax = ax[ch_id,3], 
        max_val = None
    )

    # Add rectangle to the region
    rect = patches.Rectangle((w_s, h_s), w_e-w_s, h_e-h_s, linewidth=1, edgecolor='r', facecolor='none')
    ax[ch_id,2].add_patch(rect)


#### Compute metrics between predicted data and high-SNR (ground truth) data

Prepare data:

In [None]:
# ch1_pred_unnorm = pred[...,0]*sep_std[...,0].cpu().numpy() + sep_mean[...,0].cpu().numpy()
# ch2_pred_unnorm = pred[...,1]*sep_std[...,1].cpu().numpy() + sep_mean[...,1].cpu().numpy()
# Undo the normalisation channel by channel. When the statistics hold a single
# channel, that one entry is used for every predicted channel.
pred_unnorm = [
    pred[..., i] * sep_std[..., 0 if sep_std.shape[-1] == 1 else i]
    + sep_mean[..., 0 if sep_std.shape[-1] == 1 else i]
    for i in range(pred.shape[-1])
]

In [None]:
# Get & process high-SNR data from previously loaded dataset
highres_data = highsnr_val_dset._data
if highres_data is not None:
    highres_data = ignore_pixels(highres_data, highsnr_val_dset.get_img_sz()).copy()
    if data_t_list is not None:
        highres_data = highres_data[data_t_list].copy()
    
    if 'target_idx_list' in config.data and config.data.target_idx_list is not None:
        highres_data = highres_data[...,config.data.target_idx_list]

Compute metrics:

In [None]:
# Metrics of the prediction against the high-SNR (noise-free) data.
if highres_data is not None:
    # Tag summarising the evaluation settings (split, patch size, grid size, MMSE count, skipped pixels).
    print(f'{DataSplitType.name(eval_datasplit_type)}_P{eval_patch_size}_G{image_size_for_grid_centers}_M{mmse_count}_Sk{ignored_last_pixels}')
    # Range-invariant PSNR per channel, computed on un-normalised data.
    psnr_list = [avg_range_inv_psnr(highres_data[...,k], pred_unnorm[k]) for k in range(len(pred_unnorm))]
    # Multiscale SSIM is computed in normalised space.
    # NOTE: this reuses (overwrites) the name `tar_tmp` from the visualisation cell.
    tar_tmp = (highres_data - sep_mean) /sep_std
    # tar0_tmp = (highres_data[...,0] - sep_mean[...,0]) /sep_std[...,0]
    ssim_list = compute_multiscale_ssim(tar_tmp, pred)
    # ssim1_hres_mean, ssim1_hres_std = avg_ssim(highres_data[...,0], pred_unnorm[0])
    # ssim2_hres_mean, ssim2_hres_std = avg_ssim(highres_data[...,1], pred_unnorm[1])
    print('PSNR on Highres', ' '.join([str(x) for x in psnr_list]))
    print('SSIM on Highres', ' '.join([str(np.round(x,3)) for x in ssim_list]))

In [None]:
# Per-channel metrics of the prediction against the (noisy) evaluation target.
rmse_arr = []
psnr_arr = []
rinv_psnr_arr = []
ssim_arr = []
for ch_id in range(pred.shape[-1]):
    # One RMSE value per image: squared error averaged over all pixels of that image.
    rmse =np.sqrt(((pred[...,ch_id] - tar_normalized[...,ch_id])**2).reshape(len(pred),-1).mean(axis=1))
    rmse_arr.append(rmse)
    # PSNR variants are computed in normalised space; copies are passed to the metric functions.
    psnr = avg_psnr(tar_normalized[...,ch_id].copy(), pred[...,ch_id].copy()) 
    rinv_psnr = avg_range_inv_psnr(tar_normalized[...,ch_id].copy(), pred[...,ch_id].copy())
    # SSIM is computed on un-normalised data.
    ssim_mean, ssim_std = avg_ssim(tar[...,ch_id], pred_unnorm[ch_id])
    psnr_arr.append(psnr)
    rinv_psnr_arr.append(rinv_psnr)
    ssim_arr.append((ssim_mean,ssim_std))

In [None]:
# Summary: settings tag followed by one line per metric, channels separated by ' <--> '.
print(f'{DataSplitType.name(eval_datasplit_type)}_P{eval_patch_size}_G{image_size_for_grid_centers}_M{mmse_count}_Sk{ignored_last_pixels}')
print('Rec Loss: ', np.round(rec_loss.mean(),3) )
# RMSE is averaged over the images of each channel.
print('RMSE: ', ' <--> '.join([str(np.mean(x).round(3)) for x in rmse_arr]))
print('PSNR: ', ' <--> '.join([str(x) for x in psnr_arr]))
print('RangeInvPSNR: ',' <--> '.join([str(x) for x in rinv_psnr_arr]))
# SSIM is reported as mean±std.
print('SSIM: ',' <--> '.join([f'{round(x,3)}±{round(y,4)}' for (x,y) in ssim_arr]))
print()