In [None]:
from pathlib import Path
from typing import Optional, Literal, Union

import numpy as np
import torch 
from torch.utils.data import Dataset, DataLoader

from careamics.config import VAEAlgorithmConfig
from careamics.config.architectures import LVAEModel
from careamics.config.likelihood_model import (
    GaussianLikelihoodConfig,
    NMLikelihoodConfig,
)
from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
from careamics.lightning import VAEModule
from careamics.models.lvae.noise_models import noise_model_factory

## 1. Create the `Dataset` and `DataLoader`

### 1.1. Dummy Data

In [None]:
class DummyDataset(Dataset):
    """Map-style dataset yielding random ``(input, target)`` tensor pairs.

    Every item is freshly drawn with ``torch.randn``, so the data carries no
    signal; it only exercises tensor shapes and the training loop.

    Parameters
    ----------
    num_samples : int
        Number of items reported by ``len(dataset)``.
    input_shape : tuple of int
        Shape of each input tensor, e.g. ``(C_1, X, Y)``.
    target_shape : tuple of int
        Shape of each target tensor, e.g. ``(C_2, X, Y)``.
    """

    def __init__(self, num_samples, input_shape, target_shape):
        self.num_samples = num_samples
        self.input_shape = input_shape
        self.target_shape = target_shape

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Enforce the bounds advertised by __len__. A map-style Dataset has no
        # __iter__, so plain iteration (``for x in ds`` / ``list(ds)``) uses the
        # sequence protocol, which only stops when IndexError is raised.
        if not -self.num_samples <= idx < self.num_samples:
            raise IndexError(
                f"index {idx} out of range for dataset of length {self.num_samples}"
            )
        input_tensor = torch.randn(*self.input_shape)
        target_tensor = torch.randn(*self.target_shape)
        return input_tensor, target_tensor

def dummy_dataloader(request=None):
    """Build a shuffled ``DataLoader`` over a ``DummyDataset``.

    The original signature reads its options from ``request.param`` (pytest
    fixture style). This version also accepts a plain ``dict`` or ``None``, so
    it can be called directly from the notebook.

    Parameters
    ----------
    request : object with a ``param`` dict, dict, or None, optional
        Source of the options. Any missing key falls back to its default:
        ``batch_size`` (2), ``num_samples`` (10),
        ``input_shape`` ((3, 64, 64), i.e. C_1, X, Y) and
        ``target_shape`` ((1, 64, 64), i.e. C_2, X, Y).

    Returns
    -------
    DataLoader
        Loader over ``(input, target)`` pairs with ``shuffle=True``.
    """
    if request is None:
        params = {}
    elif isinstance(request, dict):
        params = request
    else:
        # Request-like object: options live in ``.param``; use defaults if absent.
        params = getattr(request, "param", None) or {}

    batch_size = params.get("batch_size", 2)
    num_samples = params.get("num_samples", 10)
    input_shape = params.get("input_shape", (3, 64, 64))  # C_1, X, Y
    target_shape = params.get("target_shape", (1, 64, 64))  # C_2, X, Y

    dataset = DummyDataset(num_samples, input_shape, target_shape)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

### 1.2. Real Data

## 2. Instantiate the Lightning module

In [None]:
def create_dummy_noise_model(
    save_path: Optional[Union[Path, str]] = None,
    n_gaussians: int = 3,
    n_coeffs: int = 3,
    seed: Optional[int] = None,
) -> Path:
    """Write a randomly initialised GMM noise model to ``dummy_noise_model.npz``.

    The file holds ``trained_weight`` (uniform random values of shape
    ``(3 * n_gaussians, n_coeffs)``), ``min_signal`` = [0],
    ``max_signal`` = [2**16 - 1] and ``min_sigma`` = 0.125. It is meant only as
    a stand-in so a denoiSplit-style model can be built without a trained
    noise model.

    Parameters
    ----------
    save_path : Path or str, optional
        Directory to write into. It is created if missing. Defaults to the
        current working directory.
    n_gaussians : int, default 3
        Number of mixture components. The weight matrix gets three rows per
        component (presumably one each for weight/mean/sigma; confirm against
        the noise-model loader).
    n_coeffs : int, default 3
        Number of columns of the weight matrix.
    seed : int, optional
        If given, weights come from ``np.random.default_rng(seed)`` and are
        reproducible. If ``None``, the global NumPy RNG is used (as before).

    Returns
    -------
    Path
        Path of the written ``.npz`` file.
    """
    # Path(None) raises TypeError, so the default must be resolved explicitly.
    save_dir = Path(".") if save_path is None else Path(save_path)
    save_dir.mkdir(parents=True, exist_ok=True)

    shape = (3 * n_gaussians, n_coeffs)
    if seed is None:
        weights = np.random.rand(*shape)
    else:
        weights = np.random.default_rng(seed).random(shape)

    nm_dict = {
        "trained_weight": weights,
        "min_signal": np.array([0]),
        "max_signal": np.array([2**16 - 1]),
        "min_sigma": 0.125,
    }
    out_path = save_dir / "dummy_noise_model.npz"
    np.savez(out_path, **nm_dict)
    return out_path

In [None]:
def create_split_lightning_model(
    algorithm: str,
    loss_type: str,
    multiscale_count: int = 1,
    predict_logvar: Optional[Literal["pixelwise"]] = None,
    target_ch: int = 1,
    NM_path: Optional[Path] = None,
    input_shape: int = 64,
) -> VAEModule:
    """Instantiate a muSplit / denoiSplit ``VAEModule`` with an LVAE backbone.

    Parameters
    ----------
    algorithm : str
        Forwarded to ``VAEAlgorithmConfig(algorithm=...)``.
    loss_type : str
        Forwarded as ``loss``. It also selects which likelihood configs are built:
        ``"musplit"`` builds the Gaussian likelihood only, ``"denoisplit"`` the
        noise-model likelihood only, and ``"denoisplit_musplit"`` both. Any
        other value builds neither; this function does not check it.
    multiscale_count : int, default 1
        Forwarded to ``LVAEModel``.
    predict_logvar : {"pixelwise"} or None, default None
        Forwarded to both ``LVAEModel`` and ``GaussianLikelihoodConfig``.
    target_ch : int, default 1
        Number of output channels. The same noise model config is repeated once
        per channel.
    NM_path : Path, optional
        Noise-model file for the noise-model likelihood. If ``None`` and such a
        likelihood is needed, a random dummy model is written to the current
        directory.
    input_shape : int, default 64
        Forwarded to ``LVAEModel(input_shape=...)``. Presumably this must match
        the spatial patch size of the data (the dummy data is 64 x 64).

    Returns
    -------
    VAEModule
        The Lightning module built from the assembled ``VAEAlgorithmConfig``.
    """
    lvae_config = LVAEModel(
        architecture="LVAE",
        input_shape=input_shape,
        multiscale_count=multiscale_count,
        z_dims=[128, 128, 128, 128],
        output_channels=target_ch,
        predict_logvar=predict_logvar,
    )

    # Gaussian likelihood: used by the muSplit losses.
    if loss_type in ["musplit", "denoisplit_musplit"]:
        gaussian_lik_config = GaussianLikelihoodConfig(
            predict_logvar=predict_logvar,
            logvar_lowerbound=0.0,
        )
    else:
        gaussian_lik_config = None

    # Noise-model likelihood: used by the denoiSplit losses.
    if loss_type in ["denoisplit", "denoisplit_musplit"]:
        if NM_path is None:
            NM_path = create_dummy_noise_model(Path("./"), 3, 3)
        gmm = GaussianMixtureNMConfig(
            model_type="GaussianMixtureNoiseModel",
            path=NM_path,
        )
        noise_model_config = MultiChannelNMConfig(noise_models=[gmm] * target_ch)
        nm = noise_model_factory(noise_model_config)
        nm_lik_config = NMLikelihoodConfig(noise_model=nm)
    else:
        noise_model_config = None
        nm_lik_config = None

    vae_config = VAEAlgorithmConfig(
        algorithm_type="vae",
        algorithm=algorithm,
        loss=loss_type,
        model=lvae_config,
        gaussian_likelihood_model=gaussian_lik_config,
        noise_model=noise_model_config,
        noise_model_likelihood_model=nm_lik_config,
    )

    return VAEModule(
        algorithm_config=vae_config,
    )

In [None]:
# muSplit model with two output channels and pixel-wise log-variance.
# The "musplit" loss builds no noise-model likelihood, so no NM file is needed.
lightning_model = create_split_lightning_model(
    algorithm="musplit",
    loss_type="musplit",
    target_ch=2,
    multiscale_count=1,
    predict_logvar="pixelwise",
    NM_path=None,
)

## 3. Define training utilities

## 4. Train the model