### Training code for the HT_LIF24 dataset

Training is done in a supervised way. For every input patch, we have the two corresponding target patches, which we use to train our VSE (Variational Splitting Encoder-Decoder) network with a KL-divergence loss and a per-channel likelihood loss, following denoiSplit[ref]. Noise models are used in the likelihood computation. Besides the primary input patch, we also feed LC (lateral context) inputs, originally introduced in uSplit[ref], to the network to give it information about the larger spatial context.

### Important ! 

This step can be skipped! Only run this notebook if you want to train the MicroSplit model from scratch or fine-tune it. A pretrained model checkpoint is available.

#### General imports

In [1]:
import pooch
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from careamics.lightning import VAEModule

from microsplit_reproducibility.configs.factory import (
    create_algorithm_config,
    get_likelihood_config,
    get_loss_config,
    get_model_config,
    get_optimizer_config,
    get_training_config,
    get_lr_scheduler_config,
)
from microsplit_reproducibility.utils.callbacks import get_callbacks
from microsplit_reproducibility.utils.io import load_checkpoint
from microsplit_reproducibility.datasets import create_train_val_datasets
from microsplit_reproducibility.utils.utils import plot_training_metrics

#### Experiment-specific imports

In [2]:
from microsplit_reproducibility.configs.parameters.HT_LIF24 import get_microsplit_parameters
from microsplit_reproducibility.configs.data.HT_LIF24 import get_data_configs
from microsplit_reproducibility.datasets.HT_LIF24 import get_train_val_data

### Get data and experiment parameters

Example training setup (5 epochs). TODO: allow switching between full training, short training, and fine-tuning.

In [5]:
from typing import Literal

from careamics.lvae_training.dataset import DatasetConfig, DataSplitType, DataType


class NikolaDataConfig(DatasetConfig):
    """Dataset configuration extended with two extra, Literal-constrained fields.

    Adds ``dset_type`` and ``channel_idx_list`` on top of whatever fields are
    inherited from ``DatasetConfig``.

    NOTE(review): nothing in the visible part of this notebook instantiates
    this class -- the configs used later come from ``get_data_configs``.
    The saved output also spells the data type ``DataType.NicolaData``
    ("Nicola", not "Nikola"); confirm the intended spelling and whether this
    class is still needed here.
    """

    dset_type: Literal[
        "high", "mid", "low", "verylow", "2ms", "3ms", "5ms", "20ms", "500ms"
    ]
    # Selects the dataset variant; only the listed strings are accepted.
    # Presumably these name power levels ("high" ... "verylow") and exposure
    # times ("2ms" ... "500ms") -- TODO confirm and document.

    channel_idx_list: list[Literal[1, 2, 3, 17]]
    # Channel indices to use; each entry must be one of 1, 2, 3 or 17.
    # Presumably indices into the channel axis of the raw files -- TODO confirm.


In [13]:
# Single source of truth for the dataset variant and the channels to split, so
# the data configs and the experiment parameters cannot drift apart.
# NOTE: the noise-model file names downloaded further down in this notebook are
# also specific to the "20ms" variant and channels 1-3.
DSET_TYPE = "20ms"
CHANNEL_IDX_LIST = [1, 2, 3]

# Each call gets its own copy of the list so neither can affect the other.
train_data_config, val_data_config, test_data_configs = get_data_configs(
    dset_type=DSET_TYPE, channel_idx_list=list(CHANNEL_IDX_LIST)
)
experiment_params = get_microsplit_parameters(
    dset_type=DSET_TYPE, channel_idx_list=list(CHANNEL_IDX_LIST)
)

In [14]:
# Iterating over the config yields one (field_name, value) pair per field;
# print each pair on its own line to inspect the training data settings.
for field_entry in train_data_config:
    print(field_entry)

('data_type', <DataType.NicolaData: 1>)
('depth3D', 1)
('datasplit_type', <DataSplitType.Train: 1>)
('num_channels', 3)
('ch1_fname', None)
('ch2_fname', None)
('ch_input_fname', None)
('input_is_sum', False)
('input_idx', 2)
('target_idx_list', [0, 1])
('start_alpha', None)
('end_alpha', None)
('image_size', (64, 64))
('grid_size', 32)
('empty_patch_replacement_enabled', False)
('empty_patch_replacement_channel_idx', None)
('empty_patch_replacement_probab', None)
('empty_patch_max_val_threshold', None)
('uncorrelated_channels', False)
('uncorrelated_channel_probab', 0.5)
('poisson_noise_factor', -1.0)
('synthetic_gaussian_scale', 100.0)
('input_has_dependant_noise', True)
('enable_gaussian_noise', False)
('allow_generation', False)
('training_validtarget_fraction', None)
('deterministic_grid', None)
('enable_rotation_aug', False)
('max_val', None)
('overlapping_padding_kwargs', {'mode': 'reflect'})
('print_vars', False)
('normalized_input', True)
('use_one_mu_std', True)
('train_aug_r

In [15]:
# Display the experiment parameters returned by get_microsplit_parameters.
# NOTE(review): in the saved output, 'nm_paths' holds absolute paths under
# /group/jug/...; confirm they are replaced by the locally downloaded noise
# models before the likelihood config is built.
experiment_params

{'algorithm': 'denoisplit',
 'loss_type': 'denoisplit_musplit',
 'img_size': (64, 64),
 'target_channels': 3,
 'multiscale_count': 3,
 'predict_logvar': 'pixelwise',
 'nm_paths': ['/group/jug/ashesh/training/noise_model/2406/13/nm_ht_lif24_ch1_20ms.npz',
  '/group/jug/ashesh/training/noise_model/2406/14/nm_ht_lif24_ch2_20ms.npz',
  '/group/jug/ashesh/training/noise_model/2406/15/nm_ht_lif24_ch3_20ms.npz'],
 'kl_type': 'kl_restricted',
 'batch_size': 32,
 'lr': 0.001,
 'lr_scheduler_patience': 30,
 'earlystop_patience': 200,
 'num_epochs': 400,
 'num_workers': 0,
 'mmse_count': 10,
 'grid_size': 32}

### Download the data

In [16]:
# Every HT_LIF24 artefact (data, noise models, checkpoints) is served from the
# same base URL; only the local cache folder and the file list differ.
_HT_LIF24_URL = "https://download.fht.org/jug/ht_lif24"

# dict.fromkeys maps every file name to None, i.e. no known hash is registered
# for any of the files.
DATA = pooch.create(
    path="./data",
    base_url=_HT_LIF24_URL,
    registry=dict.fromkeys(["ht_lif24.zip"]),
)

# One noise model per channel (ch1..ch3) for the 20ms acquisition.
NOISE_MODELS = pooch.create(
    path="./noise_models",
    base_url=_HT_LIF24_URL,
    registry=dict.fromkeys(f"nm_ht_lif24_ch{ch}_20ms.npz" for ch in (1, 2, 3)),
)

# Pretrained checkpoints: the best and the last one.
MODEL_CHECKPOINTS = pooch.create(
    path="./checkpoints",
    base_url=_HT_LIF24_URL,
    registry=dict.fromkeys(["best.ckpt", "last.ckpt"]),
)

In [17]:
# Download every registered noise model, using the registry keys directly
# instead of rebuilding the file names from a loop index. `fetch` returns the
# local path of each file (registry order: ch1, ch2, ch3).
noise_model_paths = [NOISE_MODELS.fetch(fname) for fname in NOISE_MODELS.registry]

# Point the experiment at the downloaded noise models. The default 'nm_paths'
# returned by get_microsplit_parameters are absolute paths on the authors'
# cluster and do not exist on other machines.
experiment_params["nm_paths"] = noise_model_paths

# Download the dataset archive and unpack it next to the zip file.
DATA.fetch("ht_lif24.zip", processor=pooch.Unzip())

# Download the pretrained checkpoints (only needed for fine-tuning).
for ckpt_name in MODEL_CHECKPOINTS.registry:
    MODEL_CHECKPOINTS.fetch(ckpt_name)


Downloading file 'ht_lif24.zip' from 'https://download.fht.org/jug/ht_lif24/ht_lif24.zip' to '/home/igor.zubarev/projects/microSplit-reproducibility/examples/2D/HT_LIF24/data'.


KeyboardInterrupt: 

### Create dataset

In [18]:
# Folder with the raw files inside the archive that pooch.Unzip unpacked.
unzipped_data_dir = DATA.path / "ht_lif24.zip.unzip/ht_lif24"

# The validation config is also passed as the test config; the dataset built
# from it (third return value) is not needed for training and is discarded.
train_dset, val_dset, _, data_stats = create_train_val_datasets(
    datapath=unzipped_data_dir,
    load_data_func=get_train_val_data,
    train_config=train_data_config,
    val_config=val_data_config,
    test_config=val_data_config,
)

Loading from data/ht_lif24.zip.unzip/ht_lif24/Set1/uSplit_20ms.nd2
ND2 dimensions: {'P': 20, 'C': 19, 'Y': 1608, 'X': 1608}; RGB: False; datatype: uint16; legacy: False
Loading from data/ht_lif24.zip.unzip/ht_lif24/Set2/uSplit_20ms.nd2
ND2 dimensions: {'P': 11, 'C': 19, 'Y': 1608, 'X': 1608}; RGB: False; datatype: uint16; legacy: False
Loading from data/ht_lif24.zip.unzip/ht_lif24/Set3/uSplit_20ms.nd2
ND2 dimensions: {'P': 20, 'C': 19, 'Y': 1608, 'X': 1608}; RGB: False; datatype: uint16; legacy: False
Loading from data/ht_lif24.zip.unzip/ht_lif24/Set4/uSplit_20ms.nd2
ND2 dimensions: {'P': 20, 'C': 19, 'Y': 1608, 'X': 1608}; RGB: False; datatype: uint16; legacy: False
Loading from data/ht_lif24.zip.unzip/ht_lif24/Set5/uSplit_20ms.nd2
ND2 dimensions: {'P': 21, 'C': 19, 'Y': 1608, 'X': 1608}; RGB: False; datatype: uint16; legacy: False
Loading from data/ht_lif24.zip.unzip/ht_lif24/Set6/uSplit_20ms.nd2
ND2 dimensions: {'P': 20, 'C': 19, 'Y': 1608, 'X': 1608}; RGB: False; datatype: uint16; 

### Create dataloaders

In [None]:
# Settings shared by both loaders; only shuffling differs between them.
common_loader_kwargs = {
    "batch_size": experiment_params["batch_size"],
    "num_workers": experiment_params["num_workers"],
}

# Shuffle the training patches every epoch; keep validation order fixed.
train_dloader = DataLoader(train_dset, shuffle=True, **common_loader_kwargs)
val_dloader = DataLoader(val_dset, shuffle=False, **common_loader_kwargs)

### Get experiment configs

In [None]:
# The dataset statistics computed during dataset creation are passed to the
# config factories through the shared parameter dict.
experiment_params["data_stats"] = data_stats  # TODO rethink

# Each factory is called with the full parameter dict as keyword arguments.
loss_config = get_loss_config(**experiment_params)
model_config = get_model_config(**experiment_params)
likelihood_configs = get_likelihood_config(**experiment_params)
gaussian_lik_config, noise_model_config, nm_lik_config = likelihood_configs
training_config = get_training_config(**experiment_params)
lr_scheduler_config = get_lr_scheduler_config(**experiment_params)
optimizer_config = get_optimizer_config(**experiment_params)

# Bundle the individual configs into the algorithm config used to build the
# model. The training config is not part of it; it is used for the Trainer.
experiment_config = create_algorithm_config(
    algorithm=experiment_params["algorithm"],
    model_config=model_config,
    loss_config=loss_config,
    gaussian_lik_config=gaussian_lik_config,
    nm_config=noise_model_config,
    nm_lik_config=nm_lik_config,
    optimizer_config=optimizer_config,
    lr_scheduler_config=lr_scheduler_config,
)

In [None]:
# Display the assembled algorithm config for inspection.
experiment_config


### Initialize the model

In [None]:
# Instantiate the VAEModule from the assembled algorithm config; this is the
# object later passed to `trainer.fit`.
model = VAEModule(algorithm_config=experiment_config)

### Load checkpoint (optional)

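It is possible to load a checkpoint to continue training (fine-tuning). Skip the next cell if you want to train from scratch: running it replaces the freshly initialized model with the one restored from the checkpoint.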

In [None]:
# Look up a checkpoint in the local "checkpoints" folder.
# Presumably best=False selects the last rather than the best checkpoint --
# TODO confirm against load_checkpoint.
ckpt = load_checkpoint(
    "checkpoints",
    best=False,
)

# Rebuild the module from the checkpoint; this overwrites the `model` created
# in the previous section.
# TODO fix seek error
model = VAEModule.load_from_checkpoint(
    ckpt,
    algorithm_config=experiment_config,
)

### Visualize input data


### Train the model

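For the sake of the example, about 5 epochs are enough. Note that the cell below trains for `training_config.num_epochs` epochs, so reduce that value if you only want a short example run.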

In [None]:
# Configure the Lightning trainer from the training config.
trainer = Trainer(
    # Number of epochs comes from the training config, not a fixed small value.
    max_epochs=training_config.num_epochs,
    # "auto" selects an available accelerator (e.g. a GPU) and falls back to
    # the CPU, so the cell no longer fails on machines without a GPU.
    accelerator="auto",
    enable_progress_bar=True,
    # Callbacks are created with the current directory as their working folder.
    callbacks=get_callbacks("."),
    precision=training_config.precision,
    gradient_clip_val=training_config.gradient_clip_val,
    gradient_clip_algorithm=training_config.gradient_clip_algorithm,
)

# Train on the training loader and validate on the validation loader.
trainer.fit(
    model=model,
    train_dataloaders=train_dloader,
    val_dataloaders=val_dloader,
)

### Training logs

In [None]:
# Plot the training metrics found in the given log folder.
# NOTE(review): the `experiment_params` dict displayed earlier in this notebook
# has no 'experiment_name' key, and the Trainer above is created without an
# explicit logger writing to "csv_logs/" -- confirm where the metrics are
# actually written and that this lookup cannot raise a KeyError.
plot_training_metrics(f"csv_logs/{experiment_params['experiment_name']}/version_0/")

In [None]:
# TODO: plot a grid of predictions from the last epoch (input / k channels)