### Evaluation code for the dataset Bla bla

Short description of metrics and panels

#### General imports

In [None]:
import os
from typing import Callable, Optional

import numpy as np
import torch
import wandb
from careamics.lightning import VAEModule
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

import configs
from configs.factory import (
    get_algorithm_config,
    get_likelihood_config,
    get_loss_config,
    get_model_config,
    get_optimizer_config,
    get_training_config,
    get_lr_scheduler_config,
)
from datasets import create_train_val_datasets
from utils.callbacks import get_callbacks
from utils.io import get_workdir, log_configs

#### Experiment-specific imports

In [None]:
from configs.parameters import get_denoisplit_parameters
from configs.data import get_data_config

### Get configs

In [None]:
# TODO refactor, all functions should come from careamics
# Data configs for the train/val/test splits. The call uses the name that is
# actually imported from `configs.data`; the former `get_data_configs` was
# never imported and raised a NameError.
# NOTE(review): assumes `get_data_config` returns three configs — confirm
# against `configs.data` (or fix the import if the plural name is the real one).
train_data_config, val_data_config, test_data_configs = get_data_config()

# Every remaining config is built from the same experiment parameter dict.
params = get_denoisplit_parameters()
loss_config = get_loss_config(**params)
model_config = get_model_config(**params)
gaussian_lik_config, noise_model_config, nm_lik_config = get_likelihood_config(
    **params
)
training_config = get_training_config(**params)
lr_scheduler_config = get_lr_scheduler_config(**params)
optimizer_config = get_optimizer_config(**params)

# Bundle the individual configs into the algorithm-level config.
algo_config = get_algorithm_config(
    algorithm=params["algorithm"],
    loss_config=loss_config,
    model_config=model_config,
    gaussian_lik_config=gaussian_lik_config,
    nm_config=noise_model_config,
    nm_lik_config=nm_lik_config,
    lr_scheduler_config=lr_scheduler_config,
    optimizer_config=optimizer_config,
)

### Create dataset

In [None]:
# TODO add mode train/test to return different datasets
# NOTE(review): `test_config` receives the validation config, while
# `test_data_configs` is not used in the visible cells — confirm this is intended.
train_dset, val_dset, test_dset, data_stats = create_train_val_datasets(
    datapath=data_path,
    train_config=train_data_config,
    val_config=val_data_config,
    test_config=val_data_config,
    load_data_func=load_data_fn,
)

# Both loaders share batch size and worker count; only shuffling differs.
loader_kwargs = {
    "batch_size": params["batch_size"],
    "num_workers": params["num_workers"],
}
train_dloader = DataLoader(train_dset, shuffle=True, **loader_kwargs)
val_dloader = DataLoader(val_dset, shuffle=False, **loader_kwargs)

### Create model and load checkpoint

In [None]:
# Model hyper-parameters forwarded to the model factory.
model_parameters = {
    "img_size": img_size,
    "multiscale_count": multiscale_count,
    "predict_logvar": predict_logvar,
    "target_ch": target_channels,
    "nm_paths": nm_paths,
}

# Build the Lightning module; the dataset statistics are passed as the data config.
lightning_model = create_split_lightning_model(
    algorithm="denoisplit",
    loss="denoisplit_musplit",
    model_parameters=model_parameters,
    data_config={"data_stats": data_stats},
    training_config=training_config,
)

In [None]:
# Resolve the checkpoint file: `ckpt_dir` may point either to a directory
# (a checkpoint is then selected according to `which_ckpt`) or directly to a file.
if os.path.isdir(ckpt_dir):
    ckpt_fpath = get_model_checkpoint(ckpt_dir, mode=which_ckpt)
elif os.path.isfile(ckpt_dir):
    ckpt_fpath = ckpt_dir
else:
    # Explicit error instead of `assert`, which is stripped when running with -O.
    raise FileNotFoundError(
        f"Checkpoint path is neither a directory nor a file: '{ckpt_dir}'"
    )

print(f"Loading checkpoint from: '{ckpt_fpath}'")

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook does not crash on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: `torch.load` unpickles the file; only load checkpoints from trusted sources.
# Loading onto the CPU first makes the checkpoint readable regardless of the
# device it was saved from.
checkpoint = torch.load(ckpt_fpath, map_location="cpu")

# `strict=True` requires the checkpoint keys to match the model exactly.
lightning_model.load_state_dict(checkpoint['state_dict'], strict=True)
lightning_model.eval()
lightning_model.to(device)

print('Loading weights from epoch', checkpoint['epoch'])

### Perform evaluation

In [None]:
# NOTE: PSNR is computed patch-wise here, hence the results are not trustworthy
# TODO rename, put stitching inside
prediction_kwargs = {
    "model": lightning_model,
    "dset": test_dset,
    "batch_size": batch_size,
    "num_workers": num_workers,
    "mmse_count": mmse_count,
    "loss_type": algo_config["loss"],
}
# Tiled predictions over the test dataset.
pred_tiled = get_dset_predictions(**prediction_kwargs)

In [None]:
# Pad the predicted tiles up to the tile size reported by the test dataset so
# that they can be stitched. The same dataset is used for the check and for the
# padding amount (previously the amount was taken from `val_dset`).
# NOTE(review): the pad spec has four axis entries, so `pred_tiled` must be 4-D;
# presumably the last two axes are spatial — confirm.
tile_size = test_dset.get_img_sz()
if pred_tiled.shape[-1] != tile_size:
    missing = tile_size - pred_tiled.shape[-1]
    # Split the padding so that an odd difference still reaches `tile_size`.
    pad_before = missing // 2
    pad_after = missing - pad_before
    pred_tiled = np.pad(
        pred_tiled,
        ((0, 0), (0, 0), (pad_before, pad_after), (pad_before, pad_after)),
    )

# Stitch tiled predictions
pred = stitch_predictions_new(
    pred_tiled,
    test_dset,
)

In [None]:
# TODO discuss visualizing individual samples vs. MMSE. Do both!

### Visualize results


### Panel 1 ...