### Evaluation code for the Pavia P24 dataset

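This notebook builds the datasets and the experiment configuration, then evaluates the calibration of the predicted uncertainty (RMV vs. RMSE) and visualizes the results.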

#### General imports

In [6]:
import os
import pooch
import tifffile
import numpy as np
import matplotlib as mpl
import matplotlib.patches as patches
import matplotlib.pyplot as plt

from microsplit_reproducibility.configs.factory import (
    get_algorithm_config,
    get_likelihood_config,
    get_loss_config,
    get_model_config,
    get_optimizer_config,
    get_training_config,
    get_lr_scheduler_config,
)
from microsplit_reproducibility.utils.io import load_checkpoint
from microsplit_reproducibility.datasets import create_train_val_datasets

from careamics.lightning import VAEModule
from careamics.lvae_training.eval_utils import get_predictions, get_samples, plot_error
from careamics.utils.metrics import avg_range_invariant_psnr

from careamics.lvae_training.calibration import (
    Calibration,
    get_calibrated_factor_for_stdev,
    plot_calibration,
)

#### Experiment-specific imports

In [None]:
from microsplit_reproducibility.configs.parameters.pavia_p24 import (
    get_denoisplit_parameters,
)
from microsplit_reproducibility.configs.data.pavia_p24 import get_data_configs
from microsplit_reproducibility.datasets.pavia_p24 import get_train_val_data

### Get configs

In [None]:
# Train / validation / test data configurations, followed by the experiment
# parameters returned by `get_denoisplit_parameters()`.
(
    train_data_config,
    val_data_config,
    test_data_configs,
) = get_data_configs()
experiment_params = get_denoisplit_parameters()

### Create dataset

In [None]:
# Local data directory; used as the pooch cache `path` and as `datapath` when
# the datasets are created.
tmp_local_path = "/localscratch/data/pavia3_sequential_cropped"

In [None]:
# Pooch handle rooted at the local data directory. The base URL and the
# registry are left as empty placeholders for now.
# TODO: the data should be downloaded and stored locally, i.e. use
# path=pooch.os_cache("microsplit_reproducibility_pavia_p24") instead.
DATA = pooch.create(
    registry={"": ""},
    base_url="",
    path=tmp_local_path,
)

In [None]:
# Build the train / validation / test datasets and the data statistics from the
# local data directory, using the Pavia P24 loading function.
train_dset, val_dset, test_dset, data_stats = create_train_val_datasets(
    datapath=tmp_local_path,
    train_config=train_data_config,
    val_config=val_data_config,
    # Use the dedicated test config returned by `get_data_configs()`. It was
    # previously left unused while the validation config was passed here.
    test_config=test_data_configs,
    load_data_func=get_train_val_data,
)

# TODO: creating a dataloader currently requires a config; simplify this API.

### Get experiment configs

In [None]:
# Every factory below is fed from `experiment_params`, so the dataset
# statistics are injected into it first.  TODO rethink
experiment_params["data_stats"] = data_stats

# Component configs, each built from the same parameter mapping.
loss_config = get_loss_config(**experiment_params)
model_config = get_model_config(**experiment_params)
(
    gaussian_lik_config,
    noise_model_config,
    nm_lik_config,
) = get_likelihood_config(**experiment_params)
training_config = get_training_config(**experiment_params)
lr_scheduler_config = get_lr_scheduler_config(**experiment_params)
optimizer_config = get_optimizer_config(**experiment_params)

# Assemble the algorithm-level config from the pieces above
# (`training_config` is not one of its inputs).
# TODO rename to create
experiment_config = get_algorithm_config(
    optimizer_config=optimizer_config,
    lr_scheduler_config=lr_scheduler_config,
    nm_lik_config=nm_lik_config,
    nm_config=noise_model_config,
    gaussian_lik_config=gaussian_lik_config,
    model_config=model_config,
    loss_config=loss_config,
    algorithm=experiment_params["algorithm"],
)

### Load predictions

In [None]:
# RMV vs. RMSE before calibration.
# NOTE: Recall the `pred_std` here is the pixel-wise std of the mmse_count many
# predictions. `compute_stats` takes a log-variance (`pred_logvar`), so the std
# is converted with log(std ** 2), the same conversion the calibration-factor
# cell applies before fitting.
# NOTE(review): this cell reads `predictions` / `target_normalized`, whereas the
# calibration-factor cell reads `pred` / `tar_normalized`. Confirm which names
# the prediction-loading step actually defines.
calib = Calibration(num_bins=30, mode="pixelwise")
native_stats = calib.compute_stats(
    pred=predictions,
    pred_logvar=np.log(pred_std ** 2),
    target=target_normalized,
)

# Normalized bin counts of entry 0 (presumably the first channel); kept for
# inspection, e.g. `count.cumsum()[:-1]`.
count = np.array(native_stats[0]["bin_count"])
count = count / count.sum()

# The first and the last bin are left out of the plot.
plt.plot(native_stats[0]["rmv"][1:-1], native_stats[0]["rmse"][1:-1], "o")
plt.title("RMV vs. RMSE plot - Not Calibrated")
plt.xlabel("RMV")
plt.ylabel("RMSE")

In [None]:
# Fit (validation split) or load (test split) the per-channel calibration
# factors for the predicted std, then plot the calibrated RMV vs. RMSE curve.
inp, _ = val_dset[0]
out_dir = get_eval_output_dir(ckpt_dir, inp.shape[1], mmse_count=mmse_count)
fname = "calibration_factor.npy"
factor_fpath = os.path.join(out_dir, fname)

if eval_datasplit_type == DataSplitType.Val:
    # Fit one factor per channel, i.e. per index of the last axis of `pred`.
    calib_factors = [
        get_calibrated_factor_for_stdev(
            pred=pred[..., i],
            pred_logvar=np.log(pred_std[..., i] ** 2),
            target=tar_normalized[..., i],
            batch_size=8,
            lr=0.1,
        )
        for i in range(pred.shape[-1])
    ]
    print(f"Calibration factors: {calib_factors}")
    # Trailing axis is sized from the number of fitted factors instead of a
    # hard-coded 2, so it always matches the channel axis looped over above.
    # (assumes `pred_std` is 4-D with channels last — TODO confirm)
    calib_factor = np.array(calib_factors).reshape(1, 1, 1, -1)
    np.save(factor_fpath, calib_factor)
    print(f'Saved calibration factor fitted on validation set to {factor_fpath}')
elif eval_datasplit_type == DataSplitType.Test:
    # Re-use the factor that was fitted on the validation split.
    print('Loading the calibration factor from the file', factor_fpath)
    calib_factor = np.load(factor_fpath)
else:
    # Without this guard a stale `calib_factor` from an earlier run of the
    # cell would silently be reused (or a NameError raised further down).
    raise ValueError(
        f"Unsupported eval_datasplit_type {eval_datasplit_type!r}: "
        "expected DataSplitType.Val or DataSplitType.Test"
    )

# Given the calibration factor, plot RMV vs. RMSE.
calib = Calibration(num_bins=30, mode="pixelwise")
# log-variance of the rescaled std: log((std * factor) ** 2)
pred_logvar = 2 * np.log(pred_std * calib_factor)
stats = calib.compute_stats(pred, pred_logvar, tar_normalized)
_, ax = plt.subplots(figsize=(5, 5))
plt.title("RMV vs. RMSE plot - Calibrated")
plot_calibration(ax, stats)

if eval_datasplit_type == DataSplitType.Test:
    # `stats` is not a plain array, so np.save stores it as a pickled object;
    # it has to be read back with `allow_pickle=True`.
    stats_fpath = os.path.join(out_dir, "calibration_stats.pkl.npy")
    np.save(stats_fpath, stats)
    print('Saved stats of Test set to ', stats_fpath)

In [None]:
# Overlay the calibration curves (RMV vs. RMSE) of several runs in one figure.
# NOTE: despite its name, `calib_factors` holds the saved calibration *stats*
# objects (one per file), not the scalar calibration factors.
labels = ['w=0.5', 'w=0.9', 'w=1']
try:
    # SECURITY: `allow_pickle=True` unpickles the file contents; only load
    # files produced by a trusted run of this notebook.
    calib_factors = [
        np.load(os.path.join('/path/to/calibration/factors/dir/', fpath), allow_pickle=True)
        for fpath in [
            'calibration_stats_1.pkl.npy',
            'calibration_stats_2.pkl.npy',
            'calibration_stats_3.pkl.npy',
        ]
    ]
except FileNotFoundError:
    print('Calibration factors not found. Skipping the plot.')
    calib_factors = []

if calib_factors:
    _, ax = plt.subplots(figsize=(5, 2.5))
    for i, calibration_stats in enumerate(calib_factors):
        # `[()]` unwraps the 0-d object array written by np.save; only entry 0
        # of the stats is plotted.
        entry_stats = calibration_stats[()][0]
        # Presumably these trim the sparsely populated bins at both ends of
        # the histogram; the helpers are not defined in this file
        # — TODO confirm their exact semantics.
        first_idx = get_first_index(entry_stats['bin_count'], 0.0001)
        last_idx = get_last_index(entry_stats['bin_count'], 0.9999)
        ax.plot(
            entry_stats['rmv'][first_idx:-last_idx],
            entry_stats['rmse'][first_idx:-last_idx],
            '-+',
            label=labels[i],
        )

    ax.yaxis.grid(color='gray', linestyle='dashed')
    ax.xaxis.grid(color='gray', linestyle='dashed')
    # Identity line (RMSE == RMV) as a reference.
    ax.plot(np.arange(0, 1.5, 0.01), np.arange(0, 1.5, 0.01), 'k--')
    ax.set_facecolor('xkcd:light grey')
    plt.legend(loc='lower right')
    # plt.xlim(0,3)
    # plt.ylim(0,1.25)
    plt.xlabel('RMV')
    plt.ylabel('RMSE')
    ax.set_axisbelow(True)

    plotsdir = get_plots_output_dir(ckpt_dir, 0, mmse_count=0)
    model_id = ckpt_dir.strip('/').split('/')[-1]
    fname = f'calibration_plot_{model_id}.png'
    fpath = os.path.join(plotsdir, fname)
    # NOTE(review): saving is currently disabled, yet the message below is
    # still printed. Re-enable `savefig` or adjust the message.
    # plt.savefig(fpath, dpi=200, bbox_inches='tight')
    print(f'Saved to {fpath}')

### Visualize results


### Panel 1 ...