### Training code for dataset Bla Bla

Description of the dataset, experiment etc

#### General imports

In [None]:
from typing import Callable, Optional

import wandb
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from careamics.lightning import VAEModule

import configs

from configs.factory import (
    get_algorithm_config,
    get_likelihood_config,
    get_loss_config,
    get_model_config,
    get_optimizer_config,
    get_training_config,
    get_lr_scheduler_config,
)
from datasets import create_train_val_datasets
from utils.callbacks import get_callbacks
from utils.io import get_workdir, log_configs

#### Experiment-specific imports

In [None]:
from configs.parameters import get_denoisplit_parameters
from configs.data import get_data_configs

### Get configs

Example training code (5 epochs); switch between full training, short training, and fine-tuning.

In [None]:
# TODO refactor, wrap in one function. Only if the function can be generalized

# Data configs (bound to train / val / test names) and the experiment
# parameter mapping that is forwarded to every config factory below.
train_data_config, val_data_config, test_data_configs = get_data_configs()
params = get_denoisplit_parameters()

# Each factory receives the full parameter mapping as keyword arguments.
loss_config = get_loss_config(**params)
model_config = get_model_config(**params)
likelihood_configs = get_likelihood_config(**params)
gaussian_lik_config, noise_model_config, nm_lik_config = likelihood_configs
training_config = get_training_config(**params)
lr_scheduler_config = get_lr_scheduler_config(**params)
optimizer_config = get_optimizer_config(**params)

# Assemble the algorithm config from the individual pieces built above.
algo_config_kwargs = {
    "algorithm": params["algorithm"],
    "loss_config": loss_config,
    "model_config": model_config,
    "gaussian_lik_config": gaussian_lik_config,
    "nm_config": noise_model_config,
    "nm_lik_config": nm_lik_config,
    "lr_scheduler_config": lr_scheduler_config,
    "optimizer_config": optimizer_config,
}
algo_config = get_algorithm_config(**algo_config_kwargs)

### Visualize configs

In [None]:
# TODO: implement the config visualization code; to be discussed

### Create dataset

In [None]:
# NOTE(review): `data_path` and `load_data_fn` are not defined in the visible
# part of this notebook -- they must be set before this cell is run.
# NOTE(review): the validation config is also passed as `test_config`, and the
# third returned dataset is discarded; `test_data_configs` is not used here.
# Confirm this is intended.
train_dset, val_dset, _, data_stats = create_train_val_datasets(
    datapath=data_path,
    train_config=train_data_config,
    val_config=val_data_config,
    test_config=val_data_config,
    load_data_func=load_data_fn,
)

# Batch size and worker count are shared by both loaders; only the training
# loader shuffles its samples.
loader_kwargs = {
    "batch_size": params["batch_size"],
    "num_workers": params["num_workers"],
}
train_dloader = DataLoader(train_dset, shuffle=True, **loader_kwargs)
val_dloader = DataLoader(val_dset, shuffle=False, **loader_kwargs)

### Train the model

Only 5 epochs, for the sake of the example.

In [None]:
# Wrap the algorithm config in the Lightning module that will be trained.
lightning_model = VAEModule(algorithm_config=algo_config)

# NOTE(review): `logdir` is not defined in the visible part of this notebook --
# it must be set before this cell is run.
custom_callbacks = get_callbacks(logdir)

# Trainer settings: epoch count, precision and gradient clipping are read from
# the training config; the accelerator is fixed to GPU.
trainer_kwargs = {
    "max_epochs": training_config.num_epochs,
    "accelerator": "gpu",
    "enable_progress_bar": True,
    "callbacks": custom_callbacks,
    "precision": training_config.precision,
    "gradient_clip_val": training_config.gradient_clip_val,
    "gradient_clip_algorithm": training_config.gradient_clip_algorithm,
}
trainer = Trainer(**trainer_kwargs)

# Fit on the training loader, validating on the validation loader.
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_dloader,
    val_dataloaders=val_dloader,
)

### Training logs

In [None]:
# TODO: visualize losses, CSV logger as pandas df, plots, location of checkpoint