### Training code for the Pavia P24 dataset

TODO: describe the dataset, the experiment, etc.

#### General imports

In [None]:
import pooch
from careamics.lightning import VAEModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from torch.utils.data import DataLoader

from microsplit_reproducibility.configs.factory import (
    get_algorithm_config,
    get_likelihood_config,
    get_loss_config,
    get_model_config,
    get_optimizer_config,
    get_training_config,
    get_lr_scheduler_config,
)
from microsplit_reproducibility.datasets import create_train_val_datasets
from microsplit_reproducibility.utils.utils import plot_training_metrics

#### Experiment-specific imports

In [None]:
from microsplit_reproducibility.configs.parameters.pavia_p24 import get_denoisplit_parameters
from microsplit_reproducibility.configs.data.pavia_p24 import get_data_configs
from microsplit_reproducibility.datasets.pavia_p24 import get_train_val_data

### Get configs

Example training code (5 epochs). Switch between full training, short training, and fine-tuning.

In [None]:
# Load the data configs for the train / validation / test splits and the
# denoiSplit training parameters for this experiment.
# TODO: wrap both calls in one function, but only if that function can be
# generalized to other datasets.
train_data_config, val_data_config, test_data_configs = get_data_configs()
training_params = get_denoisplit_parameters()
# TODO: obtain the noise-model path via pooch?

In [None]:
# Print each item yielded by iterating the training data config, one per line.
for entry in train_data_config:
    print(entry)

In [None]:
# Display the training parameters (the notebook renders the cell's last expression).
training_params

### Create dataset

In [None]:
import os

# Local directory holding the dataset. The default is machine-specific, so it
# can be overridden through the MICROSPLIT_PAVIA_P24_DATA environment variable
# instead of editing this cell.
# TODO: download and cache the data with pooch instead of relying on a local path.
tmp_local_path = os.environ.get(
    "MICROSPLIT_PAVIA_P24_DATA", "/localscratch/data/pavia3_sequential_cropped"
)

In [None]:
# Placeholder pooch registry rooted at the local data directory. Both the
# registry and the base URL are empty, so ``DATA`` cannot fetch any real file yet.
# TODO: download the data and cache it locally, e.g. under
# ``pooch.os_cache("microsplit_reproducibility_pavia_p24")``.
DATA = pooch.create(
    registry={"": ""},
    base_url="",
    path=tmp_local_path,
)

In [None]:
# Build the train / validation datasets from the local data directory, plus the
# data statistics that are injected into the training parameters further down.
# NOTE(review): the test split reuses ``val_data_config`` (``test_data_configs``
# is not used in this cell) and the returned test dataset is discarded.
train_dset, val_dset, _, data_stats = create_train_val_datasets(
    load_data_func=get_train_val_data,
    datapath=tmp_local_path,
    train_config=train_data_config,
    val_config=val_data_config,
    test_config=val_data_config,
)

# TODO: creating a dataloader currently requires a config; simplify this API.

### Create dataloaders

In [None]:
# One loader per split, sharing batch size and worker count: training batches
# are shuffled, while the validation order stays deterministic.
train_dloader, val_dloader = (
    DataLoader(
        dataset,
        batch_size=training_params["batch_size"],
        num_workers=training_params["num_workers"],
        shuffle=shuffle,
    )
    for dataset, shuffle in ((train_dset, True), (val_dset, False))
)

In [None]:
# The config factories below all read the same flat parameter dict, so the
# dataset statistics are injected into it first.
# TODO: rethink how data statistics are passed to the factories.
training_params["data_stats"] = data_stats

loss_config = get_loss_config(**training_params)
model_config = get_model_config(**training_params)
(
    gaussian_lik_config,
    noise_model_config,
    nm_lik_config,
) = get_likelihood_config(**training_params)
training_config = get_training_config(**training_params)
lr_scheduler_config = get_lr_scheduler_config(**training_params)
optimizer_config = get_optimizer_config(**training_params)

# Combine the component configs into the algorithm config that is handed to the
# Lightning module. ``training_config`` is not part of it; it configures the
# Trainer instead.
# TODO: rename get_algorithm_config to a ``create_*`` name.
experiment_config = get_algorithm_config(
    algorithm=training_params["algorithm"],
    model_config=model_config,
    loss_config=loss_config,
    optimizer_config=optimizer_config,
    lr_scheduler_config=lr_scheduler_config,
    gaussian_lik_config=gaussian_lik_config,
    nm_config=noise_model_config,
    nm_lik_config=nm_lik_config,
)

### Visualize configs

In [None]:
# TODO: add code to visualize the configs (to be discussed).

### Train the model

Only 5 epochs are run, for the sake of the example (the number of epochs is set by `num_epochs` in the training config).

In [None]:
# init lightning model
lightning_model = VAEModule(algorithm_config=experiment_config)

# train the model
# custom_callbacks = get_callbacks(logdir)


### Training logs

In [None]:
# Configure the Lightning Trainer from the training config and fit the model.
# A CSVLogger is attached so that metrics are written under "csv_logs", the
# directory read by plot_training_metrics("csv_logs") in the next cell. Without
# it, Lightning's default logger writes under "lightning_logs/" and nothing
# lands in "csv_logs".
# NOTE(review): CSVLogger writes csv_logs/lightning_logs/version_<n>/metrics.csv;
# confirm this is the layout plot_training_metrics expects.
trainer = Trainer(
    max_epochs=training_config.num_epochs,
    accelerator="gpu",  # hard requirement: fails on machines without a GPU
    enable_progress_bar=True,
    logger=CSVLogger(save_dir="csv_logs"),
    # callbacks=custom_callbacks,
    precision=training_config.precision,
    gradient_clip_val=training_config.gradient_clip_val,
    gradient_clip_algorithm=training_config.gradient_clip_algorithm,
    # num_sanity_val_steps=0
)
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_dloader,
    val_dataloaders=val_dloader,
)

In [None]:
# Plot the training metrics stored under the "csv_logs" directory.
# NOTE(review): this only has data if the Trainer was configured with a CSV
# logger writing under "csv_logs" — confirm the trainer cell sets one up.
plot_training_metrics("csv_logs")