> **Note:** Use the provided YAML config exactly as-is to replicate reported results.

In [None]:
!pip install -q -U kagglehub lpips monai monai-generative torchio wandb

## Imports & Global configs.

In [None]:
import math
import os
import random
import time
from collections import defaultdict
from pathlib import Path
from pprint import pprint
from typing import List, Optional

import kagglehub
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
import torchio as tio
import wandb
import yaml

from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import AutoencoderKL, PatchDiscriminator
from monai.networks.layers import Act
from monai.utils import first, set_determinism
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
# Seed the global RNGs once via MONAI's helper so runs are intended to be
# reproducible. NOTE: the seed (42) is hard-coded here and also used as the
# default `seed` of get_dataloaders(); keep them in sync if you change it.
set_determinism(42)

def load_config(config_path):
    """
    Load a YAML configuration file into a Python object (normally a dict).

    Parameters:
        config_path (str | os.PathLike): Path to the YAML file.

    Returns:
        The parsed document as returned by ``yaml.safe_load``.
    """
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

config = load_config('/content/config_aekl.yaml')

dataset_source = config["data"]["dataset_source"]
# Only required in "preprocessed" mode, so tolerate its absence otherwise.
kaggle_preproccesed_dataset  = config["data"].get("kaggle_preproccesed_dataset")
preprocessed_path = None
# Where cached/resampled .npy files go. Only "original" mode reads it from the
# config, but the name must exist in every mode (the dataset class treats None
# as '<root_dir>/processed').
processed_path = None

if dataset_source == "original":
    print(">>>Original mode<<<\n")
    dataset_path = kagglehub.dataset_download(
        "awsaf49/brats20-dataset-training-validation"
    )
    processed_path = config['experiment'].get("processed_path", None) # Where to store processed files in case of store_locally=True
    print("Original BraTS dataset downloaded.")
    print("Path to dataset files:", dataset_path)
    print("Path where processed/resampled .npy files will be stored [None means '<root_dir>/processed']:", processed_path)

elif dataset_source == "preprocessed":
    print(">>>Preprocessed mode<<<\n")
    if not kaggle_preproccesed_dataset:
        raise ValueError(
            "config['data']['kaggle_preproccesed_dataset'] must be set when dataset_source == 'preprocessed'"
        )
    dataset_path = kagglehub.dataset_download(kaggle_preproccesed_dataset)
    # The preprocessed archive nests its data under a 'content' folder.
    dataset_path = os.path.join(dataset_path, "content")
    print(f"{kaggle_preproccesed_dataset} dataset downloaded.")
    print("Path to dataset files:", dataset_path)
else:
    raise ValueError(f"Unknown dataset mode: {dataset_source}")

config['hardware']['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
device = config['hardware']['device']
pprint(config, indent=4, width=80, compact=False)

## Utilities

In [None]:
def _center_crop(volume: torch.Tensor, target_shape: tuple) -> torch.Tensor:
    """
    Center-crops a 3D volume to the target shape in (H, W, D) order.

    Parameters:
        volume (torch.Tensor): Input volume of shape (..., H, W, D).
        target_shape (tuple): Target shape as (tH, tW, tD)

    Returns:
        torch.Tensor: Cropped volume
    """
    h, w, d = volume.shape[-3:]
    tH, tW, tD = target_shape

    h_start = (h - tH) // 2
    w_start = (w - tW) // 2
    d_start = (d - tD) // 2

    return volume[...,
                  h_start:h_start+tH,
                  w_start:w_start+tW,
                  d_start:d_start+tD]

def apply_augmentation(
    multimodal_tensor: torch.Tensor,
    segmentation_tensor: torch.Tensor = None, # Made optional for Autoencoder use
    modalities_list = ('t1', 't1ce', 't2', 'flair'),
    geometric_transforms=None,
    intensity_transforms=None
):
    """
    Apply a TorchIO augmentation pipeline to a multi-modal volume.

    Parameters:
        multimodal_tensor: Tensor whose dim 0 holds one channel per entry of
            ``modalities_list`` (assumed (C, H, W, D) — TODO confirm for callers
            other than BraTSAutoEncoderDataset).
        segmentation_tensor: Optional label volume; wrapped as a ``tio.LabelMap``
            so it goes through the geometric transforms alongside the images.
        modalities_list: Names used as Subject keys, in channel order. The
            default is a tuple to avoid a shared mutable default.
        geometric_transforms: Optional TorchIO transform applied to the Subject first.
        intensity_transforms: Optional TorchIO transform applied afterwards.
            The background (voxels <= 0 before this step) is re-zeroed after it.

    Returns:
        Tuple of (transformed_multimodal, transformed_segmentation, foreground_mask).
        ``transformed_segmentation`` is None when no segmentation was given.
        ``foreground_mask`` is a boolean tensor of shape (1, H, W, D), taken
        from the first modality in every code path.
    """
    # One TorchIO image per modality inside a single Subject, so each transform
    # call processes all modalities together. Slice (not index) to keep dim 0.
    subject_dict = {
        mod_name: tio.ScalarImage(tensor=multimodal_tensor[i:i + 1])
        for i, mod_name in enumerate(modalities_list)
    }

    # Only add segmentation to subject if it is provided
    if segmentation_tensor is not None:
        subject_dict['seg'] = tio.LabelMap(tensor=segmentation_tensor.unsqueeze(0))

    subject = tio.Subject(**subject_dict)

    if geometric_transforms is not None:
        subject = geometric_transforms(subject)

    # Foreground mask is taken AFTER geometric transforms so it matches the
    # warped images. BraTS MRIs are co-registered, so the first modality suffices.
    mask = None
    first_mod = next(iter(modalities_list), None)
    if intensity_transforms is not None and first_mod is not None:
        mask = subject[first_mod].data > 0
        subject = intensity_transforms(subject)

        # Intensity transforms may shift the background; force it back to zero.
        for mod_name in modalities_list:
            # set_data avoids the DeprecationWarning of assigning `.data` directly.
            subject[mod_name].set_data(subject[mod_name].data * mask)

    transformed_multimodal = torch.cat([subject[mod].data for mod in modalities_list], dim=0)

    transformed_segmentation = subject['seg'].data.squeeze(0) if 'seg' in subject else None

    if mask is None:
        # Fallback uses the first channel only, so the mask is always
        # (1, H, W, D) — the same shape as the intensity-transform path above.
        # (Previously this was a (C, H, W, D) mask when C > 1.)
        mask = transformed_multimodal[:1] > 0

    return transformed_multimodal, transformed_segmentation, mask

In [None]:
def show_mri_slices(x: torch.Tensor, x_recon: torch.Tensor = None, slice_indices=None, cmap='gray', title=None):
    """
    Plot axial, coronal and sagittal slices of an MRI volume.

    If ``x_recon`` is provided, the ground truth is shown in the top row and
    the reconstruction in the bottom row.

    Inputs may carry leading batch/channel dims. They are reduced to 3D by
    repeatedly taking index 0, i.e. the first sample and then the first
    modality. Spatial axes are read as (H, W, D), with the same view naming
    as ``Trainer._log_visuals``:
    fixed D -> axial, fixed W -> coronal, fixed H -> sagittal.

    Parameters:
        x: Ground-truth tensor of shape (..., H, W, D).
        x_recon: Optional reconstruction with the same layout.
        slice_indices: Optional (h, w, d) slice index per spatial axis.
            Defaults to the center of each axis.
        cmap: Matplotlib colormap name.
        title: Optional figure title.
    """
    def to_volume(t: torch.Tensor) -> np.ndarray:
        # Unlike squeeze(), this works for batch > 1 with several modalities
        # and never drops a spatial axis of size 1.
        vol = t.detach().cpu()
        while vol.ndim > 3:
            vol = vol[0]
        return vol.numpy()

    volumes = [to_volume(x)]
    if x_recon is not None:
        volumes.append(to_volume(x_recon))

    H, W, D = volumes[0].shape
    if slice_indices is None:
        slice_indices = (H // 2, W // 2, D // 2)
    h_idx, w_idx, d_idx = slice_indices

    num_rows = len(volumes)
    # squeeze=False always yields a 2-D array of axes, even for a single row.
    fig, axes = plt.subplots(num_rows, 3, figsize=(12, 4 * num_rows), squeeze=False)

    if title:
        fig.suptitle(title, fontsize=16)

    col_titles = ['Axial', 'Coronal', 'Sagittal']
    row_labels = ['GT', 'Recon'] if num_rows > 1 else ['']

    for r_idx, vol in enumerate(volumes):
        views = [vol[:, :, d_idx], vol[:, w_idx, :], vol[h_idx, :, :]]
        for c_idx, (view, col_title) in enumerate(zip(views, col_titles)):
            ax = axes[r_idx, c_idx]
            ax.imshow(view, cmap=cmap)
            if r_idx == 0:
                ax.set_title(col_title)
            if c_idx == 0:
                ax.set_ylabel(row_labels[r_idx])
            ax.set_xticks([])
            ax.set_yticks([])

    plt.tight_layout()
    plt.show()

## Dataset Definition
### Notes

The pipeline supports two data sources controlled by `dataset_source`:

- **`"original"`**: Downloads the original BraTS dataset from Kaggle
- **`"preprocessed"`**: Downloads a preprocessed Kaggle dataset (default: my 150×150×119 BraTS)

The selected dataset becomes the `root_dir` for the `BraTSAutoEncoderDataset` instance.

**Preprocessing & Caching:**
- `processed_root_dir` can be explicitly set or defaults to `<root_dir>/processed`
- If `store_locally=True` **and** `resample_shape` is provided:
  - Data is resampled once and cached as `.npy` files in `processed_root_dir`
  - Subsequent runs load from cache -> faster training
- If `store_locally=False` **and** `resample_shape` is provided:
  - Data is resampled on-the-fly every time a sample is loaded. This is much slower, because each volume is re-interpolated every epoch, but it uses no extra disk space.
- If `resample_shape=None`:
  - Data is loaded as-is from source without resampling

In [None]:
class BraTSAutoEncoderDataset(Dataset):
    """
    BraTS2020 Dataset for 3D multi-modal MRI brain tumor reconstruction.

    Each item is a float tensor of shape (C, *spatial), where C = len(modalities_list).
    It is normalized channel-wise.

    Args:
        split (str):
            Dataset split to load. Must be either "train" (369 samples)
            or "test" (125 samples). Augmentations are applied only when split="train".

        root_dir (str):
            Root directory containing the downloaded dataset. Must not be None.

        processed_root_dir (Optional[str]):
            Root directory used to store/read cached `.npy` files when store_locally=True.
            Defaults to '<root_dir>/processed' if None.

        modalities_list (Optional[List[str]]):
            List of MRI modalities to load (e.g., ["t1ce"]). Determines the number
            of input channels. Defaults to ["t1ce"].

        resample_shape (Optional[tuple]):
            Target spatial shape (3 ints) for resampling. If None, volumes are
            loaded at their original resolution.

        store_locally (bool):
            If True and `resample_shape` is specified, resampled volumes are saved
            as `.npy` for faster future loading.

        geometric_transforms:
            TorchIO spatial transforms applied during training, before center cropping.

        intensity_transforms:
            TorchIO intensity transforms applied during training, after the geometric
            transforms and also before center cropping.

        output_shape (Optional[tuple]):
            Target spatial shape (3 ints) for center cropping the final volume.
            If None, the full (resampled) volume is returned.

        norm_type (str):
            Intensity normalization strategy applied channel-wise. Supported values:
                - "z": Z-score normalization computed on the foreground.
                - "minmax": Percentile-based min–max normalization to [0, 1].

        random_seed (int):
            Stored on the instance; not currently used by this class.

    Usage Modes:
        1) Raw:
            Load the volume as-is, from either NIfTI or existing `.npy`, without
            resampling. Only optional cropping or augmentation is applied.

        2) Raw + Resample:
            Load the volume and resample on-the-fly to `resample_shape`. Source
            can be NIfTI or `.npy`. No caching unless `store_locally=True`.

        3) Raw + Resample + Store Locally:
            Load the volume, resample to `resample_shape`, and store it as `.npy`
            for faster subsequent access. Useful for training with fixed shapes
            and high I/O efficiency.
    """

    def __init__(
        self,
        split: str = "train",
        root_dir: str = "../kaggle/input/brats20-dataset-training-validation/",
        processed_root_dir: Optional[str] = None,
        modalities_list: Optional[List[str]] = None,
        resample_shape: Optional[tuple] = None,
        store_locally: bool = False,
        geometric_transforms=None,
        intensity_transforms=None,
        output_shape: Optional[tuple] = None,
        norm_type: str = "minmax",
        random_seed: int = 42,
    ):
        # Validate before Path() sees it: Path(None) would raise an unhelpful TypeError.
        if root_dir is None:
            raise ValueError("root_dir must be provided")

        self.root_dir = Path(root_dir)
        self.split = split
        # A fresh default list per instance (a list literal default is shared).
        self.modalities_list = modalities_list if modalities_list is not None else ["t1ce"]
        # YAML configs produce lists; normalize to a tuple.
        self.resample_shape = tuple(resample_shape) if resample_shape is not None else None
        self.store_locally = store_locally
        self.geometric_transforms = geometric_transforms
        self.intensity_transforms = intensity_transforms
        self.output_shape = output_shape
        self.norm_type = norm_type
        self.random_seed = random_seed  # kept for reference; not consumed here

        self.apply_augmentations = (
            self.split == "train"
            and (self.geometric_transforms is not None or self.intensity_transforms is not None)
        )

        if split == "train":
            self.total_samples = 369
            self.sample_name = "Training"
        elif split == "test":
            self.total_samples = 125
            self.sample_name = "Validation"
        else:
            raise ValueError("split must be 'train' or 'test'")
        self.data_subdir = f"BraTS2020_{self.sample_name}Data"
        self.data_subsubdir = f"MICCAI_BraTS2020_{self.sample_name}Data"

        self.processed_root = (Path(processed_root_dir) if processed_root_dir is not None else self.root_dir / "processed")

        # Default file-structure of the released BraTS20 dataset.
        self.raw_samples_path = self.root_dir / self.data_subdir / self.data_subsubdir

        if not self.raw_samples_path.exists():
            raise FileNotFoundError(f"Raw data not found at {self.raw_samples_path}")

        # Cached files mirror the original file-structure.
        self.cached_samples_path = self.processed_root / self.data_subdir / self.data_subsubdir

        self.use_resampling = self.resample_shape is not None

        # If store_locally is True, perform preprocessing at instantiation
        if self.store_locally and self.use_resampling:
            voxels_per_sample = np.prod(self.resample_shape)
            bytes_per_sample = voxels_per_sample * len(self.modalities_list) * 4 # Assume fp32, this should be adaptive
            total_gb = (bytes_per_sample * self.total_samples) / (1024**3)
            print(f"[INFO] Estimated storage: {total_gb:.2f} GB")
            self._preprocess_dataset()

    # ------------------------------------------------------------------ helpers
    def _sample_filename(self, idx: int) -> str:
        """Return the BraTS sample stem for a 0-based index, e.g. 'BraTS20_Training_001'."""
        return f"BraTS20_{self.sample_name}_{idx + 1:03d}"

    def _cached_npy_path(self, sample_filename: str) -> Path:
        """Location of the cached resampled volume for one sample."""
        return self.cached_samples_path / sample_filename / f"{sample_filename}_resampled.npy"

    def _preprocess_dataset(self):
        """
        Resample every sample once and cache it as .npy (store_locally=True).

        Samples whose cache file already exists are skipped. To force
        re-processing, delete the cache directory and restart the runtime.
        """
        self.cached_samples_path.mkdir(parents=True, exist_ok=True)

        print(f"[INFO] Preprocessing dataset: {self.split}")
        for idx in tqdm(range(self.total_samples), desc=f"Preprocessing {self.split}", unit="sample"):
            sample_filename = self._sample_filename(idx)
            if self._cached_exists(sample_filename):
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f"[INFO] Sample {idx + 1}/{self.total_samples} already cached: {sample_filename}")
                continue
            x = self._load_from_root(sample_filename)
            self._save_npy(self._resample(x), sample_filename)

    def __len__(self):
        return self.total_samples

    def __getitem__(self, idx):
        """Load, optionally augment, crop and normalize one sample."""
        # Support negative indices, and raise IndexError past the end so plain
        # iteration (`for x in dataset`) stops cleanly instead of failing on a
        # non-existent file.
        requested = idx
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(f"index {requested} out of range for dataset of size {len(self)}")

        sample_filename = self._sample_filename(idx)

        if self.store_locally and self.use_resampling:
            x = self._load_npy(sample_filename)
        else:
            x = self._load_from_root(sample_filename)
            if self.use_resampling:
                # On-the-fly resampling: redundant, but needs no extra disk space.
                x = self._resample(x)

        foreground_mask = None

        if self.apply_augmentations:
            x, _, foreground_mask = apply_augmentation(
                x,
                modalities_list=self.modalities_list,
                geometric_transforms=self.geometric_transforms,
                intensity_transforms=self.intensity_transforms,
            )
            if foreground_mask is not None and foreground_mask.ndim == x.ndim:
                # Reduce a channel-first mask to the first channel. For a (1, ...)
                # mask this equals squeeze(0). For a (C, ...) mask it matches the
                # `x[0] > 0` fallback below, so _normalize always gets a spatial mask.
                foreground_mask = foreground_mask[0]

        if self.output_shape is not None:
            x = _center_crop(x, self.output_shape)
            if foreground_mask is not None:
                foreground_mask = _center_crop(foreground_mask, self.output_shape)

        if foreground_mask is None:
            foreground_mask = x[0] > 0

        return self._normalize(x, foreground_mask)

    def _cached_exists(self, sample_filename: str) -> bool:
        return self._cached_npy_path(sample_filename).exists()

    def _load_npy(self, sample_filename: str) -> torch.Tensor:
        return torch.from_numpy(np.load(self._cached_npy_path(sample_filename)))

    def _save_npy(self, tensor: torch.Tensor, sample_filename: str):
        npy_file = self._cached_npy_path(sample_filename)
        npy_file.parent.mkdir(parents=True, exist_ok=True)
        np.save(npy_file, tensor.numpy())

    def _load_from_root(self, sample_filename: str) -> torch.Tensor:
        """Load data from root_dir: an existing resampled .npy if present, else NIfTI."""
        npy_file = self.raw_samples_path / sample_filename / f"{sample_filename}_resampled.npy"
        if npy_file.exists():
            return torch.from_numpy(np.load(npy_file))
        return self._load_nifti(sample_filename)

    def _load_nifti(self, sample_filename: str) -> torch.Tensor:
        """Stack one NIfTI volume per modality along a new leading channel axis."""
        sample_path = self.raw_samples_path / sample_filename
        volumes = [
            self._load_nifti_volume(sample_path / f"{sample_filename}_{mod}.nii")
            for mod in self.modalities_list
        ]
        return torch.from_numpy(np.stack(volumes, axis=0))

    def _resample(self, x: torch.Tensor) -> torch.Tensor:
        """Resample a (C, ...) volume to self.resample_shape with B-spline interpolation."""
        subject = tio.Subject(image=tio.ScalarImage(tensor=x))
        original_shape = subject.image.shape[-3:]
        original_spacing = subject.image.spacing

        # Choose the spacing that maps the original extent onto the target grid.
        new_spacing = tuple(
            o_sp * o_sh / t_sh
            for o_sp, o_sh, t_sh in zip(original_spacing, original_shape, self.resample_shape)
        )

        transform = tio.Compose(
            [tio.ToCanonical(), tio.Resample(new_spacing, image_interpolation="bspline")]
        )

        return transform(subject).image.data

    def _normalize(self, tensor: torch.Tensor, foreground_mask: torch.Tensor) -> torch.Tensor:
        """
        Normalize each channel independently.

        "z": standardize the voxels selected by ``foreground_mask`` (spatial
        shape, same as one channel). Background is left untouched, and a channel
        with zero/undefined std is copied unchanged.
        "minmax": TorchIO RescaleIntensity to [0, 1] using the 0–99.5 percentiles.

        Raises:
            ValueError: For an unknown ``norm_type``.
        """
        out = torch.empty_like(tensor)

        for c in range(tensor.shape[0]):
            channel = tensor[c]

            if self.norm_type == "z":
                mask = foreground_mask
                if mask.any():
                    mean = channel[mask].mean()
                    std = channel[mask].std()
                    if std > 0:
                        channel = channel.clone()
                        channel[mask] = (channel[mask] - mean) / std
                out[c] = channel

            elif self.norm_type == "minmax":
                subject = tio.Subject(
                    image=tio.ScalarImage(tensor=channel.unsqueeze(0))
                )
                transform = tio.RescaleIntensity(
                    out_min_max=(0, 1),
                    percentiles=(0.0, 99.5),
                )
                out[c] = transform(subject).image.data[0]

            else:
                raise ValueError(f"Unknown normalization type: {self.norm_type}")

        return out

    @staticmethod
    def _load_nifti_volume(path: Path) -> np.ndarray:
        nii = nib.load(str(path))
        return np.asarray(nii.dataobj, dtype=np.float32)

## Split

In [None]:
def get_dataloaders(
    train_batch_size=1,
    test_batch_size=1,
    geometric_transforms=None,
    intensity_transforms=None,
    output_shape=(112, 112, 96),
    modalities_list=None,
    seed=42,
    path='./content/',
    processed_path=None,
    num_workers=0,
    norm_type='minmax',
    resample_shape=(128, 128, 128),  # None means load as it is
    store_locally=False  # Flag for caching resampled data, if resample_shape is None, this is not considered.
):
    """
    Create train and test datasets/dataloaders for BraTSAutoEncoderDataset.

    Parameters:
        train_batch_size, test_batch_size: Batch sizes for the two loaders.
        geometric_transforms, intensity_transforms: Augmentations for the
            train split only. The test split never receives them.
        output_shape: Center-crop size passed to both datasets.
        modalities_list: Modalities to load. Defaults to ['t1ce'] (a fresh
            list per call rather than a shared mutable default).
        seed: Seeds the generator that drives train-set shuffling.
        path: Dataset root directory.
        processed_path: Cache directory for resampled .npy files (None means
            '<path>/processed').
        num_workers: DataLoader worker processes.
        norm_type: "z" or "minmax" normalization.
        resample_shape: Spatial shape to resample to, or None to load as-is.
        store_locally: Cache resampled volumes to disk (ignored when
            resample_shape is None).

    Returns:
        (train_dataset, test_dataset, train_dataloader, test_dataloader)
    """
    if modalities_list is None:
        modalities_list = ['t1ce']

    g = torch.Generator()
    g.manual_seed(seed)

    # Settings shared by both splits; only augmentations differ.
    common = dict(
        root_dir=path,
        processed_root_dir=processed_path,
        modalities_list=modalities_list,
        resample_shape=resample_shape,
        store_locally=store_locally,
        output_shape=output_shape,
        norm_type=norm_type,
    )

    train_dataset = BraTSAutoEncoderDataset(
        split='train',
        geometric_transforms=geometric_transforms,
        intensity_transforms=intensity_transforms,
        **common,
    )

    test_dataset = BraTSAutoEncoderDataset(
        split='test',
        geometric_transforms=None,  # no augmentation for test
        intensity_transforms=None,
        **common,
    )

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
        generator=g
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    return train_dataset, test_dataset, train_dataloader, test_dataloader


In [None]:
geometric_transform=None
intensity_transform=None

data_cfg = config["data"]
training_cfg = config["training"]

# Use the cache location from the config (the setup cell prints it as the
# place where resampled .npy files go) instead of silently ignoring it.
# Falls back to the previous hard-coded default when it is unset.
cache_dir = config.get("experiment", {}).get("processed_path") or "/content/processed"

train_dataset, test_dataset, train_dataloader, test_dataloader = get_dataloaders(
    path=dataset_path,
    geometric_transforms=geometric_transform,
    intensity_transforms=intensity_transform,
    train_batch_size=training_cfg["train_batch_size"],
    test_batch_size=training_cfg["test_batch_size"],
    output_shape=tuple(data_cfg["output_shape"]),
    modalities_list=data_cfg["modalities_list"],
    norm_type=data_cfg["norm_type"],
    processed_path=cache_dir,
    resample_shape=data_cfg["resample_shape"],
    store_locally=data_cfg["store_locally"]
)

In [None]:
# Sanity check: pull one training batch and eyeball its slices.
# `first` is already imported from monai.utils in the imports cell.
batch = first(train_dataloader)
print("Batch shape:", batch.shape)
show_mri_slices(batch)

## Model Architecture

In [None]:
# Build the KL autoencoder (generator) and the patch discriminator from the
# YAML config and move both to the selected device.
ae_cfg = config["model"]["autoencoder"]
model = AutoencoderKL(
    spatial_dims=config["model"]["spatial_dims"],
    in_channels=ae_cfg["in_channels"],
    out_channels=ae_cfg["out_channels"],
    num_channels=ae_cfg["num_channels"],
    latent_channels=ae_cfg["latent_channels"],
    num_res_blocks=ae_cfg["num_res_blocks"],
    norm_num_groups=ae_cfg["norm_num_groups"],
    attention_levels=ae_cfg["attention_levels"],
)
model.to(device)

disc_cfg = config["model"]["discriminator"]
discriminator = PatchDiscriminator(
    spatial_dims=config["model"]["spatial_dims"],
    num_layers_d=disc_cfg["num_layers_d"],
    num_channels=disc_cfg["num_channels"],
    in_channels=disc_cfg["in_channels"],
    out_channels=disc_cfg["out_channels"],
    kernel_size=disc_cfg["kernel_size"],
    padding=disc_cfg["padding"],  # candidate for removal
    activation=(Act.LEAKYRELU, {"negative_slope": 0.2}),
    norm="BATCH",
    bias=False,
)
discriminator.to(device)

In [None]:
# from torchinfo import summary
# summary(model, input_size=(1, 1, 128, 128, 112))

## Loss Functions

In [None]:
import torch
import torch.nn.functional as F
from generative.losses import PatchAdversarialLoss, PerceptualLoss

# Criteria (reconstruction, perceptual, adversarial) and the two Adam
# optimizers, all configured from the YAML file.
loss_cfg = config["loss"]
recon_type = loss_cfg["reconstruction"]["type"]

if recon_type not in ("l1", "l2"):
    raise ValueError(f"Unsupported reconstruction loss: {recon_type}")
recon_loss_fn = {"l1": F.l1_loss, "l2": F.mse_loss}[recon_type]

perceptual_cfg = loss_cfg["perceptual"]
perceptual_loss_fn = PerceptualLoss(
    spatial_dims=perceptual_cfg["spatial_dims"],
    network_type=perceptual_cfg["network"],
    fake_3d_ratio=perceptual_cfg["fake_3d_ratio"],
).to(device)

adv_loss_fn = PatchAdversarialLoss(criterion=loss_cfg["adversarial"]["type"])

# float(): YAML may hand learning rates back as strings (e.g. "1e-4").
optimizer_g = torch.optim.Adam(model.parameters(), lr=float(config["training"]["lr_g"]))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=float(config["training"]["lr_d"]))

# No LR schedules for now; the Trainer skips scheduler steps when these are None.
scheduler_g = scheduler_d = None

In [None]:
class MetricAccumulator:
    """
    Accumulate scalar metrics by name and report their running means.

    ``add`` accepts either scalars or lists/tuples of scalars per key.
    ``mean``/``all_means`` average everything added since the last ``clear``.
    """

    def __init__(self):
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)

    def add(self, metrics: dict):
        """Add one batch of metrics. Sequence values contribute one sample per element."""
        for key, value in metrics.items():
            if isinstance(value, (list, tuple)):
                self.sums[key] += sum(value)
                self.counts[key] += len(value)
            else:
                self.sums[key] += value
                self.counts[key] += 1

    def mean(self, key: str):
        """
        Return the mean of ``key``.

        Raises:
            KeyError: If ``key`` was never added. Looking it up in the
                defaultdicts would otherwise insert it and divide by zero.

        A key that only ever received empty sequences yields NaN.
        """
        if key not in self.counts:
            raise KeyError(key)
        count = self.counts[key]
        return self.sums[key] / count if count else float("nan")

    def all_means(self):
        """Return ``{key: mean}`` for every accumulated key."""
        return {key: self.mean(key) for key in self.sums}

    def clear(self):
        """Forget all accumulated values."""
        self.sums.clear()
        self.counts.clear()

## Trainer

In [None]:
class Trainer:
    """
    Adversarial training loop for an autoencoder (generator) plus a patch discriminator.

    The generator objective is reconstruction + KL + perceptual + adversarial.
    Both networks accumulate gradients over ``accumulation_steps`` batches
    before their optimizers step. Metrics go to W&B when enabled; otherwise
    validation summaries are printed to the console.
    """

    def __init__(
        self,
        model,
        discriminator,
        optimizer_g,
        optimizer_d,
        scheduler_g,
        scheduler_d,
        recon_loss_fn,
        perceptual_loss_fn,
        adv_loss_fn,
        config,
        run,
        device='cuda',
        modalities_list=None,
        checkpoint_dir="./checkpoints"
    ):
        """
        Parameters:
            model: Autoencoder; called as ``model(x)`` and returning
                (reconstruction, z_mu, z_sigma).
            discriminator: Called as ``discriminator(x)``; its last output
                (``[-1]``) is used as the logits.
            optimizer_g, optimizer_d: Optimizers for model / discriminator.
            scheduler_g, scheduler_d: Optional schedulers, stepped with their
                optimizer (may be None).
            recon_loss_fn, perceptual_loss_fn, adv_loss_fn: Loss callables.
            config: Parsed YAML config dict.
            run: W&B run object, or None.
            device: Device that batches are moved to.
            modalities_list: Modality names used for visual logging. Defaults
                to ``config["data"]["modalities_list"]``.
            checkpoint_dir: Directory for checkpoints; created if missing.
        """
        self.model = model
        self.discriminator = discriminator
        self.optimizer_g = optimizer_g
        self.optimizer_d = optimizer_d
        self.scheduler_g = scheduler_g
        self.scheduler_d = scheduler_d

        self.recon_loss_fn = recon_loss_fn
        self.perceptual_loss_fn = perceptual_loss_fn
        self.adv_loss_fn = adv_loss_fn

        self.config = config
        self.run = run
        self.device = device

        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Honour an explicit argument (it used to be silently ignored).
        self.modalities_list = (
            modalities_list if modalities_list is not None else config["data"]["modalities_list"]
        )
        # float(): PyYAML can return exponent literals such as 1e-6 as strings.
        # The notebook already casts its learning rates for the same reason.
        self.kl_weight = float(config["loss"]["kl_loss"]["weight"])
        self.adv_weight = float(config["loss"]["adversarial"]["gen_weight"])
        self.perceptual_weight = float(config["loss"]["perceptual"]["weight"])
        # At least 1, so the accumulation modulo below can never divide by zero.
        self.grad_accum_steps = max(1, int(config["training"]["accumulation_steps"]))
        self.use_wandb = bool(config.get("experiment", {}).get("wandb_logging", False))

        self.best_val_loss = float('inf')
        self.best_model_path = None

    @staticmethod
    def _set_requires_grad(module, flag):
        """Enable/disable gradient tracking for every parameter of ``module``."""
        for param in module.parameters():
            param.requires_grad_(flag)

    @staticmethod
    def _kl_divergence(z_mu, z_sigma):
        """
        KL(N(mu, sigma^2) || N(0, 1)), summed over all non-batch dims and
        averaged over the batch (dim 0). Works for any latent rank.
        """
        reduce_dims = list(range(1, z_mu.ndim))
        kl = 0.5 * torch.sum(
            z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1,
            dim=reduce_dims,
        )
        return torch.mean(kl)

    def train_epoch(self, dataloader, epoch, global_step):
        """
        Run one training epoch and return the updated ``global_step``.

        ``global_step`` counts optimizer updates: once per accumulation window,
        plus one for a final partial window.
        """
        self.model.train()
        self.discriminator.train()

        acc = MetricAccumulator()
        step_times = []
        epoch_start = time.time()
        num_batches = len(dataloader)

        pbar = tqdm(enumerate(dataloader), total=num_batches, desc=f"Epoch {epoch}")

        for step, batch in pbar:
            step_start = time.time()
            images = batch.to(self.device)

            # ---- Generator objective ----
            reconstruction, z_mu, z_sigma = self.model(images)

            # Freeze D's parameters for the generator pass. Gradients still
            # flow through D into `reconstruction`, but D itself receives none.
            # Otherwise the generator's "fake -> real" gradient would be added
            # to D's accumulated gradients before optimizer_d.step(), largely
            # cancelling D's own fake term.
            self._set_requires_grad(self.discriminator, False)
            try:
                logits_fake = self.discriminator(reconstruction.contiguous())[-1]

                recons_loss = self.recon_loss_fn(reconstruction, images)
                p_loss = self.perceptual_loss_fn(reconstruction, images)
                gen_adv_loss = self.adv_loss_fn(logits_fake, target_is_real=True, for_discriminator=False)
                kl_loss = self._kl_divergence(z_mu, z_sigma)

                loss_g = (
                    recons_loss
                    + self.kl_weight * kl_loss
                    + self.perceptual_weight * p_loss
                    + self.adv_weight * gen_adv_loss
                )
                (loss_g / self.grad_accum_steps).backward()
            finally:
                self._set_requires_grad(self.discriminator, True)

            # ---- Discriminator objective (reconstruction detached: no G grads) ----
            logits_fake_d = self.discriminator(reconstruction.detach().contiguous())[-1]
            loss_d_fake = self.adv_loss_fn(logits_fake_d, target_is_real=False, for_discriminator=True)

            logits_real = self.discriminator(images.contiguous())[-1]
            loss_d_real = self.adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)

            discriminator_loss = (loss_d_fake + loss_d_real) * 0.5
            loss_d = (self.adv_weight * discriminator_loss) / self.grad_accum_steps
            loss_d.backward()

            acc.add({
                "train/recons_loss":        recons_loss.item(),
                "train/perceptual_loss":    p_loss.item(),
                "train/kl_loss":            kl_loss.item(),
                "train/generator_loss":     gen_adv_loss.item(),
                "train/discriminator_loss": discriminator_loss.item(),
                "train/total_gen_loss":     loss_g.item(),
            })

            is_last_batch = (step + 1) == num_batches
            is_acc_step = (step + 1) % self.grad_accum_steps == 0

            if is_acc_step or is_last_batch:
                self.optimizer_g.step()
                if self.scheduler_g is not None:
                    self.scheduler_g.step()
                self.optimizer_g.zero_grad(set_to_none=True)

                self.optimizer_d.step()
                if self.scheduler_d is not None:
                    self.scheduler_d.step()
                self.optimizer_d.zero_grad(set_to_none=True)

                global_step += 1

                log_dict = acc.all_means()
                log_dict.update({
                    "main/lr_g": self.optimizer_g.param_groups[0]["lr"],
                    "main/lr_d": self.optimizer_d.param_groups[0]["lr"],
                    "main/epoch": epoch,
                })
                self.log(log_dict, step=global_step, split='train')
                acc.clear()

            pbar.set_postfix({"recon": f"{recons_loss.item():.3f}", "disc": f"{discriminator_loss.item():.3f}"})
            step_times.append(time.time() - step_start)

        self.log({
            "timing/epoch_train_sec": time.time() - epoch_start,
            # Guard against an empty dataloader.
            "timing/avg_step_sec": sum(step_times) / len(step_times) if step_times else 0.0,
        }, step=global_step)

        return global_step

    @torch.no_grad()
    def validate(self, dataloader, epoch, global_step):
        """
        Evaluate the autoencoder and return the mean reconstruction loss.

        The first batch's first sample is visualised: logged to W&B when a run
        exists, otherwise plotted inline. Returns NaN when the dataloader is empty.
        """
        self.model.eval()
        val_start = time.time()

        acc = MetricAccumulator()
        pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc=f"Validation Epoch {epoch}")

        for step, batch in pbar:
            images = batch.to(self.device)
            reconstruction, _, _ = self.model(images)

            if step == 0:
                if self.run:
                    self._log_visuals(images[0], reconstruction[0], step=global_step)
                else:
                    title = f"Validation Epoch {epoch} | Step {global_step}"
                    show_mri_slices(images[0], reconstruction[0], title=title)

            recons_loss = self.recon_loss_fn(reconstruction, images)
            acc.add({"validation/recons_loss": recons_loss.item()})

        log_dict = acc.all_means()
        log_dict["timing/validation_sec"] = time.time() - val_start

        self.log(log_dict, step=global_step, split='val')

        return log_dict.get("validation/recons_loss", float("nan"))

    def log(self, metrics_dict, step, split='train'):
        """Send metrics to W&B if enabled; otherwise print validation summaries only."""
        if self.use_wandb and self.run is not None:
            self.run.log(metrics_dict, step=step)
        elif split == 'val' and "validation/recons_loss" in metrics_dict:
            summary = (
                f"\n### Validation [Step {step}] ###\n"
                f"Recon Loss: {metrics_dict['validation/recons_loss']:.4f}\n"
                f"{'#' * 40}"
            )
            tqdm.write(summary)

    def _log_visuals(self, orig_tensor, recon_tensor, step):
        """
        Log central axial/coronal/sagittal slices of one (C, H, W, D) sample
        and its reconstruction to W&B, one pair of panels per modality.
        """
        orig = orig_tensor.detach().float().cpu().numpy()
        recon = recon_tensor.detach().float().cpu().numpy()

        _, H, W, D = orig.shape
        center_h, center_w, center_d = H // 2, W // 2, D // 2

        panels = {}
        for m, mod in enumerate(self.modalities_list):
            img, img_rec = orig[m], recon[m]
            views = {
                "axial": (img[:, :, center_d], img_rec[:, :, center_d]),
                "coronal": (img[:, center_w, :], img_rec[:, center_w, :]),
                "sagittal": (img[center_h, :, :], img_rec[center_h, :, :]),
            }
            for cap, (slice_img, slice_rec) in views.items():
                panels[f"validation_plot/{mod}/{cap}_real"] = wandb.Image(slice_img)
                panels[f"validation_plot/{mod}/{cap}_recon"] = wandb.Image(slice_rec)

        # One call per validation, all panels at the same step.
        self.run.log(panels, step=step)

    def save_checkpoint(self, epoch, global_step, val_loss, is_best=False):
        """
        Save training state. Always overwrites 'last_checkpoint.pth'. When
        ``is_best``, also writes 'best_model_loss_<val>.pth' and removes the
        previous best file.
        """
        state = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': self.model.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'optimizer_g_state_dict': self.optimizer_g.state_dict(),
            'optimizer_d_state_dict': self.optimizer_d.state_dict(),
            'scheduler_g_state_dict': self.scheduler_g.state_dict() if self.scheduler_g is not None else None,
            'scheduler_d_state_dict': self.scheduler_d.state_dict() if self.scheduler_d is not None else None,
            'config': self.config,
        }

        last_path = self.checkpoint_dir / "last_checkpoint.pth"
        torch.save(state, last_path)

        if is_best:
            new_best_name = f"best_model_loss_{val_loss:.4f}.pth"
            new_best_path = self.checkpoint_dir / new_best_name
            # Write the new best BEFORE deleting the old one, so a failed save
            # never leaves us without any best checkpoint.
            torch.save(state, new_best_path)

            previous = self.best_model_path
            self.best_model_path = new_best_path
            # Same formatted loss -> same filename; don't delete what we just wrote.
            if previous is not None and Path(previous) != new_best_path:
                try:
                    os.remove(previous)
                except OSError:
                    pass  # best-effort cleanup (already removed, permissions, ...)

            print(f"New best model saved: {new_best_name}")
        else:
            print(f"Checkpoint saved to {last_path}")

## Training

If `wandb_logging` is enabled in the config, W&B will ask you to log in. Select option (2), "Use an existing W&B account", and paste your API key.

In [None]:
run = None


def fmt(val):
    """
    Format a hyper-parameter compactly for the run name.

    Values below 1e-3 use one-digit scientific notation, others are printed
    as-is. String values are converted with float() first: PyYAML can return
    literals such as ``1e-4`` as strings, and comparing a str with 0.001 would
    raise TypeError.
    """
    num = float(val) if isinstance(val, str) else val
    return f"{num:.0e}" if num < 0.001 else f"{num}"


run_name = (
    f"aekl_"
    f"KL{fmt(config['loss']['kl_loss']['weight'])}_"
    f"Adv{fmt(config['loss']['adversarial']['gen_weight'])}_"
    f"P{fmt(config['loss']['perceptual']['weight'])}_"
    f"bs{config['training']['train_batch_size']}_"
    f"LRg{fmt(config['training']['lr_g'])}_"
    f"LRd{fmt(config['training']['lr_d'])}_"
    f"{time.strftime('%d%b_%H-%M')}"
)
if config['experiment'].get('wandb_logging', False):
    print("WandB logging is enabled. Run is getting started...")
    run = wandb.init(project=config["project"], name=run_name, config=config)
else:
    print("WandB logging is disabled. Standard console output will be used.")

trainer = Trainer(
    model=model,
    discriminator=discriminator,
    optimizer_g=optimizer_g,
    optimizer_d=optimizer_d,
    scheduler_g=scheduler_g,
    scheduler_d=scheduler_d,
    recon_loss_fn=recon_loss_fn,
    adv_loss_fn=adv_loss_fn,
    perceptual_loss_fn=perceptual_loss_fn,
    config=config,
    run=run,
    device=device
)

global_step = 0
patience_counter = 0
test_freq = config['training'].get('val_interval', 2)
patience = config['training'].get('early_stopping_patience', 5)
#current_loss = trainer.validate(test_dataloader, epoch=0, global_step=0) # Testing purposes
for epoch in range(1, config['training']['epochs'] + 1):

    global_step = trainer.train_epoch(train_dataloader, epoch, global_step)

    if (epoch + 1) % test_freq == 0:
        current_loss = trainer.validate(test_dataloader, epoch, global_step)

        is_best = current_loss < trainer.best_val_loss

        if is_best:
            trainer.best_val_loss = current_loss
            if trainer.run:
                trainer.run.summary["best_val_loss"] = current_loss
            # Reset counter if we found a better model
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"No improvement for {patience_counter} checks. (Best: {trainer.best_val_loss:.4f})")

        trainer.save_checkpoint(
            epoch=epoch,
            global_step=global_step,
            val_loss=current_loss,
            is_best=is_best
        )

        # Trigger early stopping
        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch} epochs.")
            break

if run:
    wandb.finish()