In [None]:
import os
import socket
from pathlib import Path
from typing import Optional, Literal, Union

import numpy as np
import torch 
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

from careamics.config import VAEAlgorithmConfig
from careamics.config.architectures import LVAEModel
from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel
from careamics.config.likelihood_model import (
    GaussianLikelihoodConfig,
    NMLikelihoodConfig,
)
from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
from careamics.lightning import VAEModule
from careamics.models.lvae.noise_models import noise_model_factory

Set the parameters for the current training run.

In [None]:
# NOTE: these descriptions are `#` comments rather than bare strings on purpose:
# a bare string is an expression statement, and one placed on the last line of
# a notebook cell is echoed as that cell's output.

# Spatial size of the input image.
img_size: int = 64
# Number of channels in the target image.
target_channels: int = 2
# The number of LC inputs plus one (the actual input).
multiscale_count: int = 5
# Whether to compute also the log-variance as LVAE output.
predict_logvar: Optional[Literal["pixelwise"]] = "pixelwise"
# The type of reconstruction loss (i.e., likelihood) to use.
loss_type: Optional[Literal["musplit", "denoisplit", "denoisplit_musplit"]] = "musplit"

## 1. Create `Dataset` and `DataLoader`

### 1.1. Dummy Data

In [None]:
class DummyDataset(Dataset):
    """Synthetic dataset that yields random (input, target) tensor pairs.

    Every item is freshly sampled from a standard normal distribution, so the
    same index returns different values on each access.

    Parameters
    ----------
    img_size : int
        Height and width of every generated image.
    target_ch : int
        Number of channels of the target tensor.
    multiscale_count : int
        Number of channels of the input tensor.
    num_samples : int
        Number of items reported by ``len()``.
    """

    def __init__(
        self,
        img_size: int = 64,
        target_ch: int = 1,
        multiscale_count: int = 1,
        num_samples: int = 100,
    ):
        self.num_samples = num_samples
        self.img_size = img_size
        self.target_ch = target_ch
        self.multiscale_count = multiscale_count

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx: int):
        """Return a random ``(input, target)`` pair.

        Returns
        -------
        tuple of torch.Tensor
            Input of shape ``(multiscale_count, img_size, img_size)`` and
            target of shape ``(target_ch, img_size, img_size)``.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-num_samples, num_samples)``.
        """
        # Without this check plain iteration (`for x in dataset`) would never
        # stop, because Python's sequence protocol relies on IndexError.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"Index {idx} is out of range for a dataset of size {self.num_samples}"
            )
        input_ = torch.randn(self.multiscale_count, self.img_size, self.img_size)
        target = torch.randn(self.target_ch, self.img_size, self.img_size)
        return input_, target

def dummy_dataloader(
    batch_size: int = 1,
    img_size: int = 64,
    target_ch: int = 1,
    multiscale_count: int = 1,
    num_workers: int = 3,
):
    """Build a non-shuffling ``DataLoader`` over a ``DummyDataset``.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch.
    img_size : int
        Height and width of the generated images.
    target_ch : int
        Number of channels of the target tensors.
    multiscale_count : int
        Number of channels of the input tensors.
    num_workers : int
        Number of worker processes used for loading. Pass ``0`` to load in the
        main process, e.g. when worker processes cannot be used in the current
        environment.

    Returns
    -------
    torch.utils.data.DataLoader
        Loader yielding ``(input, target)`` batches.
    """
    dataset = DummyDataset(
        img_size=img_size,
        target_ch=target_ch,
        multiscale_count=multiscale_count,
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

In [None]:
# Build the dummy loader from the notebook-level parameters (default batch size).
dloader = dummy_dataloader(
    multiscale_count=multiscale_count,
    target_ch=target_channels,
    img_size=img_size,
)
# Pull a single batch and display the input/target shapes as a sanity check.
first_batch = next(iter(dloader))
input_, target = first_batch
input_.shape, target.shape

### 1.2. Real Data

## 2. Instantiate the lightning module

In [None]:
def create_dummy_noise_model(
    save_path: Optional[Union[Path, str]] = None,
    n_gaussians: int = 3,
    n_coeffs: int = 3,
) -> Path:
    """Write a random noise-model file to disk and return its path.

    The file ``dummy_noise_model.npz`` contains the keys ``trained_weight``
    (random array of shape ``(3 * n_gaussians, n_coeffs)``), ``min_signal``,
    ``max_signal`` and ``min_sigma``.

    Parameters
    ----------
    save_path : Path or str, optional
        Directory in which the file is written. It is created if missing.
        When ``None``, the current working directory is used.
    n_gaussians : int
        Sets the first dimension of the weights to ``3 * n_gaussians``.
    n_coeffs : int
        Second dimension of the weights.

    Returns
    -------
    Path
        Path of the written ``.npz`` file.
    """
    # `Path(None)` raises TypeError, so the documented default needs a fallback.
    out_dir = Path(".") if save_path is None else Path(save_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    weights = np.random.rand(3 * n_gaussians, n_coeffs)
    nm_dict = {
        "trained_weight": weights,
        "min_signal": np.array([0]),
        "max_signal": np.array([2**16 - 1]),
        "min_sigma": 0.125,
    }
    out_path = out_dir / "dummy_noise_model.npz"
    np.savez(out_path, **nm_dict)
    return out_path

In [None]:
def create_split_lightning_model(
    algorithm: str,
    loss_type: str,
    multiscale_count: int = 1,
    predict_logvar: Optional[Literal["pixelwise"]] = None,
    target_ch: int = 1,
    NM_path: Optional[Path] = None,
    img_size: int = 64,
) -> VAEModule:
    """Instantiate the Lightning module for a muSplit/denoiSplit LVAE.

    Parameters
    ----------
    algorithm : str
        Forwarded as ``algorithm`` to ``VAEAlgorithmConfig``.
    loss_type : str
        Forwarded as ``loss`` to ``VAEAlgorithmConfig``. It also selects which
        likelihood configs are built: a Gaussian likelihood for ``"musplit"``
        and ``"denoisplit_musplit"``, a noise-model likelihood for
        ``"denoisplit"`` and ``"denoisplit_musplit"``. Configs that are not
        selected are passed as ``None``.
    multiscale_count : int
        Forwarded to ``LVAEModel``.
    predict_logvar : {"pixelwise"}, optional
        Forwarded to both ``LVAEModel`` and ``GaussianLikelihoodConfig``.
    target_ch : int
        Number of output channels of the model. Also the number of per-channel
        noise models when a noise-model likelihood is used.
    NM_path : Path, optional
        Path to a noise-model file. If ``None`` and a noise model is needed,
        a dummy one is written to the current directory.
    img_size : int
        Forwarded as ``input_shape`` to ``LVAEModel``.

    Returns
    -------
    VAEModule
        Module built from the assembled ``VAEAlgorithmConfig``.
    """
    lvae_config = LVAEModel(
        architecture="LVAE",
        input_shape=img_size,
        multiscale_count=multiscale_count,
        z_dims=[128, 128, 128, 128],
        output_channels=target_ch,
        predict_logvar=predict_logvar,
    )

    # gaussian likelihood
    if loss_type in ["musplit", "denoisplit_musplit"]:
        gaussian_lik_config = GaussianLikelihoodConfig(
            predict_logvar=predict_logvar,
            logvar_lowerbound=0.0,
        )
    else:
        gaussian_lik_config = None

    # noise model likelihood
    if loss_type in ["denoisplit", "denoisplit_musplit"]:
        if NM_path is None:
            NM_path = create_dummy_noise_model(Path("./"), 3, 3)
        gmm = GaussianMixtureNMConfig(
            model_type="GaussianMixtureNoiseModel",
            path=NM_path,
        )
        # The same noise-model config object is reused for every target channel.
        noise_model_config = MultiChannelNMConfig(noise_models=[gmm] * target_ch)
        nm = noise_model_factory(noise_model_config)
        nm_lik_config = NMLikelihoodConfig(noise_model=nm)
    else:
        noise_model_config = None
        nm_lik_config = None

    vae_config = VAEAlgorithmConfig(
        algorithm_type="vae",
        algorithm=algorithm,
        loss=loss_type,
        model=lvae_config,
        gaussian_likelihood_model=gaussian_lik_config,
        noise_model=noise_model_config,
        noise_model_likelihood_model=nm_lik_config,
    )

    return VAEModule(
        algorithm_config=vae_config,
    )

In [None]:
# Only the pure "musplit" loss maps to the "musplit" algorithm; every other
# loss type is trained with the "denoisplit" algorithm.
if loss_type == "musplit":
    algo = "musplit"
else:
    algo = "denoisplit"

# No noise-model path is supplied here (NM_path=None).
lightning_model = create_split_lightning_model(
    algo,
    loss_type,
    multiscale_count=multiscale_count,
    predict_logvar=predict_logvar,
    target_ch=target_channels,
    NM_path=None,
)

## 3. Set up utilities for training

In [None]:
from datetime import datetime

from careamics.lvae_training.train_utils import get_new_model_version

def get_new_model_version(model_dir: Union[Path, str]) -> str:
    """Create a unique version ID for a new model run.

    Every entry of ``model_dir`` must be named with an integer. The returned
    ID is the largest existing integer plus one, or ``"0"`` when the directory
    is empty.

    NOTE: this local definition shadows the function of the same name imported
    from ``careamics.lvae_training.train_utils`` just above it.

    Parameters
    ----------
    model_dir : Path or str
        Directory whose entries are the existing versions.

    Returns
    -------
    str
        The new version ID, as a string so it can be used as a path component.

    Raises
    ------
    ValueError
        If an entry of ``model_dir`` is not named with an integer.
    """
    versions = []
    for version_dir in os.listdir(model_dir):
        try:
            versions.append(int(version_dir))
        except ValueError as err:
            # Raise instead of calling `exit()`: terminating the interpreter
            # (or the notebook kernel) hides the actual problem from the user.
            raise ValueError(
                f"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed"
            ) from err
    # `default=-1` makes an empty directory yield version "0".
    return str(max(versions, default=-1) + 1)

def get_workdir(
    root_dir: str,
    model_name: str,
) -> tuple[str, str]:
    """Get the workdir for the current model.

    It has the following structure: "root_dir/YYMM/model_name/version".
    Missing parent directories (including ``root_dir`` itself) are created.

    Parameters
    ----------
    root_dir : str
        Root directory under which all runs are stored.
    model_name : str
        Name of the model; used as a subdirectory of the "YYMM" folder.

    Returns
    -------
    tuple[str, str]
        The full path of the new workdir, and the same path relative to
        ``root_dir`` ("YYMM/model_name/version").
    """
    rel_path = os.path.join(datetime.now().strftime("%y%m"), model_name)
    model_dir = os.path.join(root_dir, rel_path)
    # parents=True creates root_dir and the YYMM folder in one go, so the call
    # also works on a machine where root_dir does not exist yet.
    Path(model_dir).mkdir(parents=True, exist_ok=True)

    rel_path = os.path.join(rel_path, get_new_model_version(model_dir))
    cur_workdir = os.path.join(root_dir, rel_path)
    try:
        Path(cur_workdir).mkdir(exist_ok=False)
    except FileExistsError:
        # Only report the clash; the path is still returned to the caller.
        print(
            f"Workdir {cur_workdir} already exists."
        )
    return cur_workdir, rel_path

In [None]:
# Root directory for all training runs. The default is a machine-specific
# path; set the CAREAMICS_ROOT_DIR environment variable to use another
# location without editing the notebook.
ROOT_DIR = os.environ.get(
    "CAREAMICS_ROOT_DIR",
    "/group/jug/federico/careamics_training/refac_v2/",
)
workdir, exp_tag = get_workdir(ROOT_DIR, "dummy_debugging")
print(f"Current workdir: {workdir}")

In [None]:
# Weights & Biases logger. The run is named "<hostname>/<experiment tag>" and
# its files are saved inside the current workdir.
run_name = os.path.join(socket.gethostname(), exp_tag)
custom_logger = WandbLogger(
    project="careamics_debugging_LVAE",
    save_dir=workdir,
    name=run_name,
)

In [None]:
# Callback settings are declared through the careamics config models and then
# dumped to plain dicts to build the Lightning callbacks.
early_stopping_config = EarlyStoppingModel(
    mode="min",
    monitor="val_loss",
    patience=10,
    min_delta=1e-6,
    verbose=True,
)
checkpoint_config = CheckpointModel(
    mode="min",
    monitor="val_loss",
    save_top_k=2,
)

early_stopping_kwargs = early_stopping_config.model_dump()
checkpoint_kwargs = checkpoint_config.model_dump()

# Early stopping, checkpointing, and per-epoch learning-rate logging.
custom_callbacks = [
    EarlyStopping(**early_stopping_kwargs),
    ModelCheckpoint(**checkpoint_kwargs),
    LearningRateMonitor(logging_interval="epoch"),
]

In [None]:
# Persist the algorithm configuration twice: as a JSON file in the workdir,
# and as part of the config of the logger's experiment.
algorithm_config = lightning_model.algorithm_config
config_path = os.path.join(workdir, "algorithm_config.json")
with open(config_path, "w") as f:
    f.write(algorithm_config.model_dump_json())

custom_logger.experiment.config.update(algorithm_config.model_dump())

## 4. Train the model

In [None]:
# Trainer settings: a short CPU run wired to the logger and callbacks
# defined in the cells above.
trainer_kwargs = dict(
    accelerator="cpu",
    max_epochs=10,
    enable_progress_bar=True,
    callbacks=custom_callbacks,
    logger=custom_logger,
)
trainer = Trainer(**trainer_kwargs)

In [None]:
# The same dummy loader serves as both training and validation data.
trainer.fit(
    lightning_model,
    val_dataloaders=dloader,
    train_dataloaders=dloader,
)