### Training code for dataset HT_P23B 2D

Training is done in a supervised way. For every input patch, we have the two corresponding target patches using which we train our VSE (Variational Splitting Encoder decoder Network) with KL-divergence loss and a per-channel likelihood loss, following denoiSplit[ref]. In the likelihood computation, Noise models are used. Besides the primary input patch, we also feed LC inputs, originally introduced in uSplit[ref], to the network to make available information about larger spatial context. 

### Important ! 

This step can be skipped! Only run this notebook if you want to train the microsplit model from scratch or finetune. Pretrained model checkpoint is available 
By default this notebook runs example training for 5 epochs. 

#### General imports

In [1]:
import pooch
from pathlib import Path
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from careamics.lightning import VAEModule

from microsplit_reproducibility.configs.factory import (
    create_algorithm_config,
    get_likelihood_config,
    get_loss_config,
    get_model_config,
    get_optimizer_config,
    get_training_config,
    get_lr_scheduler_config,
)
from microsplit_reproducibility.utils.callbacks import get_callbacks
from microsplit_reproducibility.utils.io import load_checkpoint, load_checkpoint_path
from microsplit_reproducibility.datasets import create_train_val_datasets
from microsplit_reproducibility.utils.utils import (
    plot_training_metrics,
    plot_input_patches,
    plot_training_outputs,
)

#### Experiments specific imports

In [2]:
from microsplit_reproducibility.configs.parameters.HT_P23B_2D import (
    get_microsplit_parameters,
)
from microsplit_reproducibility.configs.data.HT_P23B_2D import get_data_configs
from microsplit_reproducibility.datasets.HT_P23B import get_train_val_data

### Download the data

In [3]:
DATA = pooch.create(
    path="./data",
    base_url="https://download.fht.org/jug/microsplit/ht_p23b",
    registry={"ht_p23b.zip": None},
)

NOISE_MODELS = pooch.create(
    path="./noise_models",
    base_url="https://download.fht.org/jug/microsplit/ht_p23b/2d",
    registry={
        "nm_ht_p23b_2d_ch1.npz": None,
        "nm_ht_p23b_2d_ch2.npz": None,
    },
)

MODEL_CHECKPOINTS = pooch.create(
    path="./checkpoints",
    base_url="https://download.fht.org/jug/microsplit/ht_p23b_2d",
    registry={"best.ckpt": None, "last.ckpt": None},
)

In [4]:
for i, f in enumerate(NOISE_MODELS.registry):
    NOISE_MODELS.fetch(f"nm_ht_p23b_2d_ch{i+1}.npz")

DATA.fetch("ht_p23b.zip", processor=pooch.Unzip())

for f in MODEL_CHECKPOINTS.registry:
    MODEL_CHECKPOINTS.fetch(f"{f}")

Downloading file 'best.ckpt' from 'https://download.fht.org/jug/microsplit/ht_p23b_2d/best.ckpt' to '/home/igor.zubarev/projects/microSplit-reproducibility/examples/2D/HT_P23B/checkpoints'.


HTTPError: 404 Client Error: Not Found for url: https://download.fht.org/jug/microsplit/ht_p23b_2d/best.ckpt

### Get data and experiment parameters

In [8]:
train_data_config, val_data_config, test_data_configs = get_data_configs()
experiment_params = get_microsplit_parameters()

Downloading nm_ht_p23b_2d_channel_1.npz
Downloading nm_ht_p23b_2d_channel_2.npz


In [9]:
for k in train_data_config:
    print(k)

('data_type', <DataType.ExpMicroscopyV2: 6>)
('depth3D', 1)
('datasplit_type', <DataSplitType.Train: 1>)
('num_channels', 2)
('ch1_fname', None)
('ch2_fname', None)
('ch_input_fname', None)
('input_is_sum', False)
('input_idx', None)
('target_idx_list', None)
('start_alpha', None)
('end_alpha', None)
('image_size', (64, 64))
('grid_size', 32)
('empty_patch_replacement_enabled', False)
('empty_patch_replacement_channel_idx', None)
('empty_patch_replacement_probab', None)
('empty_patch_max_val_threshold', None)
('uncorrelated_channels', False)
('uncorrelated_channel_probab', 0.5)
('poisson_noise_factor', -1.0)
('synthetic_gaussian_scale', 228.0)
('input_has_dependant_noise', True)
('enable_gaussian_noise', False)
('allow_generation', False)
('training_validtarget_fraction', None)
('deterministic_grid', None)
('enable_rotation_aug', False)
('max_val', None)
('overlapping_padding_kwargs', {'mode': 'reflect'})
('print_vars', False)
('normalized_input', True)
('use_one_mu_std', True)
('train

In [10]:
experiment_params

{'algorithm': 'denoisplit',
 'loss_type': 'denoisplit_musplit',
 'img_size': (64, 64),
 'target_channels': 2,
 'multiscale_count': 1,
 'predict_logvar': 'pixelwise',
 'nm_paths': ['noise_models/nm_ht_p23b_2d_channel_1.npz',
  'noise_models/nm_ht_p23b_2d_channel_2.npz'],
 'kl_type': 'kl_restricted',
 'batch_size': 32,
 'lr': 0.001,
 'lr_scheduler_patience': 30,
 'earlystop_patience': 200,
 'num_epochs': 400,
 'num_workers': 0,
 'mmse_count': 10,
 'grid_size': 32}

### Create dataset

In [12]:
train_dset, val_dset, _, data_stats = create_train_val_datasets(
    datapath=DATA.path / "ht_p23b.zip.unzip/datafiles",
    train_config=train_data_config,
    val_config=val_data_config,
    test_config=val_data_config,
    load_data_func=get_train_val_data,
)


Padding is not used with this alignement style

Padding is not used with this alignement style
MultiFileDset avg height: 1584, avg width: 1584, count: 58

Padding is not used with this alignement style
MultiFileDset avg height: 1584, avg width: 1584, count: 12

Padding is not used with this alignement style
MultiFileDset avg height: 1584, avg width: 1584, count: 12


### Create dataloaders

In [None]:
train_dloader = DataLoader(
    train_dset,
    batch_size=experiment_params["batch_size"],
    num_workers=experiment_params["num_workers"],
    shuffle=True,
)
val_dloader = DataLoader(
    val_dset,
    batch_size=experiment_params["batch_size"],
    num_workers=experiment_params["num_workers"],
    shuffle=False,
)

### Get experiment configs

In [None]:
experiment_params["data_stats"] = data_stats  # TODO rethink

loss_config = get_loss_config(**experiment_params)
model_config = get_model_config(**experiment_params)
gaussian_lik_config, noise_model_config, nm_lik_config = get_likelihood_config(
    **experiment_params
)
training_config = get_training_config(**experiment_params)
lr_scheduler_config = get_lr_scheduler_config(**experiment_params)
optimizer_config = get_optimizer_config(**experiment_params)

experiment_config = create_algorithm_config(
    algorithm=experiment_params["algorithm"],
    loss_config=loss_config,
    model_config=model_config,
    gaussian_lik_config=gaussian_lik_config,
    nm_config=noise_model_config,
    nm_lik_config=nm_lik_config,
    lr_scheduler_config=lr_scheduler_config,
    optimizer_config=optimizer_config,
)

### Initialize the model

In [None]:
model = VAEModule(algorithm_config=experiment_config)

### Load checkpoint (optional)

It's possible to load a checkpoint to continue training

In [None]:
ckpt_path = load_checkpoint_path("checkpoints", best=True)
model = VAEModule.load_from_checkpoint(
    ckpt_path, algorithm_config=experiment_config
)

### Visualize input data


In [None]:
plot_input_patches(dataset=train_dset, num_channels=3, num_samples=3)

### Train the model

Only 5 epochs for the sake of the example

In [None]:
trainer = Trainer(
    max_epochs=training_config.num_epochs,
    accelerator="gpu",
    enable_progress_bar=True,
    callbacks=get_callbacks("/checkpoints"),
    precision=training_config.precision,
    gradient_clip_val=training_config.gradient_clip_val,
    gradient_clip_algorithm=training_config.gradient_clip_algorithm,
)
trainer.fit(
    model=model,
    train_dataloaders=train_dloader,
    val_dataloaders=val_dloader,
)

### Training logs

In [None]:
plot_training_metrics(list(Path("csv_logs").rglob("metrics.csv"))[0])

In [None]:
plot_training_outputs(val_dset, trainer.model, num_channels=3)