# MM-GAN Training: Missing Modality Imputation on IXI Dataset

This notebook implements **Phase 2** of the 2-phase MM-GAN training pipeline for brain MRI missing modality imputation.

**System Overview:**
- **Phase 1** (local): Data download, preprocessing, and slice extraction from IXI brain MRI volumes.
- **Phase 2** (this notebook, Kaggle GPU): Train MM-GAN on the pre-extracted slices with GPU acceleration.

**Key Features:**
- Supports **resumable training** across 2-hour Kaggle session chunks — saves checkpoints before timeout and auto-resumes on re-run.
- Trains a UNet Generator + PatchGAN Discriminator on 3 modalities: **T1, T2, PD**.
- Uses **implicit conditioning** and **curriculum learning** for stable training.

**Run this notebook twice:**
1. `EXPERIMENT_NAME = "baseline"` with baseline preprocessed data.
2. `EXPERIMENT_NAME = "optimized"` with N4 bias-corrected (optimized) data.

Then compare the results to evaluate the impact of preprocessing quality on imputation performance.

In [None]:
# Install dependencies (Kaggle already has PyTorch, torchvision, numpy, etc.)
!pip install -q scikit-image tensorboard tqdm

import os, sys, time, random, json, glob
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

# Environment report: show the installed PyTorch version and whether CUDA is
# usable; the GPU name is only queried when a CUDA device is available.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## Configuration

Change `EXPERIMENT_NAME` and `DATA_PATH` for each run:
- **Session 1**: Starts fresh training from epoch 0.
- **Subsequent sessions**: Automatically resumes from the latest checkpoint.

| Run | `EXPERIMENT_NAME` | `DATA_PATH` |
|-----|-------------------|--------------|
| 1   | `"baseline"`      | Path to baseline slices dataset |
| 2   | `"optimized"`     | Path to N4 bias-corrected slices dataset |

In [None]:
# ============================================================
# CONFIGURATION - Change these for each run
# ============================================================
# EXPERIMENT_NAME namespaces the checkpoint/log/result directories below, so
# the two runs never overwrite each other's artifacts.
EXPERIMENT_NAME = "baseline"  # "baseline" or "optimized"
DATA_PATH = "/kaggle/input/ixi-slices-baseline"  # Update with your dataset name

# Training hyperparameters (MUST be identical for both experiments)
N_EPOCHS = 60
BATCH_SIZE = 8
LR_G = 2e-4             # Adam learning rate for the generator
LR_D = 2e-4             # Adam learning rate for the discriminator
LAMBDA_PIXEL = 0.9      # Generator loss = (1 - LAMBDA_PIXEL) * GAN + LAMBDA_PIXEL * pixel
BETA1 = 0.5             # Adam beta1 (both optimizers)
BETA2 = 0.999           # Adam beta2 (both optimizers)
SEED = 42
IMG_SIZE = 256          # Slices are resized to IMG_SIZE x IMG_SIZE by the dataset
IN_CHANNELS = 3         # Modality channels per slice: T1, T2, PD
IMPUTE_TYPE = "zeros"   # Placeholder for missing channels: "zeros", "noise" or "average"
USE_IC = True           # Implicit conditioning
USE_CURRICULUM = True   # Curriculum learning

# Session management
SAVE_INTERVAL = 5       # Save checkpoint every N epochs
VAL_INTERVAL = 2        # Validate every N epochs
# Session budget in seconds (1h50m): 10 minutes short of the 2h = 7200s
# session limit this notebook is designed around.
MAX_SESSION_TIME = 6600

# Paths (all under /kaggle/working, one subdirectory per experiment)
CHECKPOINT_DIR = f"/kaggle/working/checkpoints/{EXPERIMENT_NAME}"
LOG_DIR = f"/kaggle/working/logs/{EXPERIMENT_NAME}"
RESULTS_DIR = f"/kaggle/working/results/{EXPERIMENT_NAME}"

# exist_ok=True keeps re-runs (resumed sessions) from failing on existing dirs.
for d in [CHECKPOINT_DIR, LOG_DIR, RESULTS_DIR]:
    os.makedirs(d, exist_ok=True)

print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Data path: {DATA_PATH}")
print(f"Checkpoint dir: {CHECKPOINT_DIR}")

## MM-GAN Architecture

**Generator**: UNet with 8 encoder blocks and 7 decoder blocks with skip connections. Uses InstanceNorm and ReLU final activation (data is [0,1] normalized).

**Discriminator**: PatchGAN with 4 downsampling blocks. Outputs per-patch real/fake predictions.

Adapted for 3 modalities (T1, T2, PD) from the original 4-modality BRATS setup.

In [None]:
# ============================================================
# MM-GAN Model Architecture
# Adapted for IXI Dataset (3 modalities: T1, T2, PD)
# Based on: https://github.com/trane293/mm-gan
# ============================================================


def set_seed(seed=42):
    """Seed the NumPy, PyTorch and Python RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to all three random number generators.
    """
    # Same seeding order as always: NumPy, then PyTorch, then Python's random.
    for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
        seed_fn(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def weights_init_normal(m):
    """Re-initialise a module in place; intended for ``model.apply(...)``.

    Modules whose class name contains "Conv" get weights drawn from
    N(0, 0.02). Modules whose class name contains "BatchNorm2d" get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left as is.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm2d" in layer_kind:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# ---- U-Net Generator ----

class UNetDown(nn.Module):
    """Encoder stage that halves the spatial resolution.

    Layer order: 4x4 stride-2 Conv (no bias) -> optional InstanceNorm ->
    LeakyReLU(0.2) -> optional Dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        conv = nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)
        # Optional stages are spliced in as zero- or one-element lists so the
        # Sequential keeps the same layer indices as a plain append chain.
        norm_stage = [nn.InstanceNorm2d(out_size)] if normalize else []
        drop_stage = [nn.Dropout(dropout)] if dropout else []
        self.model = nn.Sequential(conv, *norm_stage, nn.LeakyReLU(0.2), *drop_stage)

    def forward(self, x):
        """Apply the stage to ``x`` and return the downsampled feature map."""
        return self.model(x)


class UNetUp(nn.Module):
    """Decoder stage that doubles the spatial resolution and appends a skip.

    Layer order: 4x4 stride-2 ConvTranspose (no bias) -> InstanceNorm ->
    ReLU -> optional Dropout. The result is concatenated with the skip
    tensor along the channel axis.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            stages += [nn.Dropout(dropout)]
        self.model = nn.Sequential(*stages)

    def forward(self, x, skip_input):
        """Upsample ``x`` and return it concatenated with ``skip_input``."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), 1)


class GeneratorUNet(nn.Module):
    """
    U-Net generator: 8 downsampling stages, 7 upsampling stages with skip
    connections, and a final upsample + conv head ending in ReLU.

    For IXI: in_channels=3 (T1, T2, PD), out_channels=3
    Input size: (B, 3, 256, 256) -> Output: (B, 3, 256, 256)
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        drop = 0.2  # Dropout rate shared by the deep encoder/decoder stages

        # Contracting path. Attribute names and registration order are kept
        # stable because checkpoints and optimizer state depend on them.
        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=drop)
        self.down5 = UNetDown(512, 512, dropout=drop)
        self.down6 = UNetDown(512, 512, dropout=drop)
        self.down7 = UNetDown(512, 512, dropout=drop)
        self.down8 = UNetDown(512, 512, normalize=False, dropout=drop)

        # Expanding path. From up2 onwards the input width is doubled by the
        # skip tensor concatenated in the previous stage.
        self.up1 = UNetUp(512, 512, dropout=drop)
        self.up2 = UNetUp(1024, 512, dropout=drop)
        self.up3 = UNetUp(1024, 512, dropout=drop)
        self.up4 = UNetUp(1024, 512, dropout=drop)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)

        # Output head; the ReLU keeps predictions non-negative
        # (data is [0,1] normalized).
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(128, out_channels, 4, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        """Run the encoder, then the decoder fed by skips in reverse order."""
        encoders = (
            self.down1, self.down2, self.down3, self.down4,
            self.down5, self.down6, self.down7,
        )
        decoders = (
            self.up1, self.up2, self.up3, self.up4,
            self.up5, self.up6, self.up7,
        )

        # Collect every encoder output except the bottleneck as a skip.
        skips = []
        features = x
        for encode in encoders:
            features = encode(features)
            skips.append(features)

        features = self.down8(features)  # Bottleneck

        # Deepest skip first: up1 pairs with down7, ..., up7 pairs with down1.
        for decode in decoders:
            features = decode(features, skips.pop())

        return self.final(features)


# ---- PatchGAN Discriminator ----

class Discriminator(nn.Module):
    """
    PatchGAN discriminator operating on an (image, condition) pair.

    For IXI: in_channels=3, so the network sees the channel-wise concat of
    the real/fake image and the condition image (6 channels).
    Output: (B, out_channels, H/16, W/16) patch predictions
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Four stride-2 stages: Conv -> [InstanceNorm] -> LeakyReLU.
        # The first stage has no normalization.
        stage_widths = (64, 128, 256, 512)
        layers = []
        prev_width = in_channels * 2  # img_A and img_B are concatenated
        for stage, width in enumerate(stage_widths):
            layers.append(nn.Conv2d(prev_width, width, 4, stride=2, padding=1))
            if stage > 0:
                layers.append(nn.InstanceNorm2d(width))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            prev_width = width

        # Patch prediction head.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, out_channels, 4, padding=1, bias=False))

        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """
        Args:
            img_A: Generated/real image (B, C, H, W)
            img_B: Condition image (B, C, H, W)
        Returns:
            Patch predictions (B, out_channels, H/16, W/16)
        """
        stacked = torch.cat((img_A, img_B), 1)
        return self.model(stacked)


# ---- Missing Modality Logic ----

# All possible missing modality scenarios for 3 modalities (T1, T2, PD)
# 0 = missing, 1 = available
# Sorted by difficulty: fewer available = harder (first)
# NOTE: the order is load-bearing. get_curriculum_scenarios() returns slice
# bounds into this list, and indices 3..5 are the one-missing scenarios.
ALL_SCENARIOS_3MOD = [
    [1, 0, 0],  # Only T1 available       (hardest: 2 missing)
    [0, 1, 0],  # Only T2 available
    [0, 0, 1],  # Only PD available
    [1, 1, 0],  # T1+T2 available          (medium: 1 missing)
    [1, 0, 1],  # T1+PD available
    [0, 1, 1],  # T2+PD available          (easiest: 1 missing)
]

# Modality names (indexed): position i names channel i of every slice and
# entry i of every scenario above.
MODALITY_NAMES = ["T1", "T2", "PD"]


def get_curriculum_scenarios(epoch, total_epochs):
    """
    Pick the slice of ALL_SCENARIOS_3MOD to sample from at a given epoch.

    During the first 30% of training only the one-missing-modality
    scenarios are used; afterwards every scenario is available.

    Args:
        epoch: Current epoch index.
        total_epochs: Planned number of epochs (values below 1 are treated as 1).

    Returns: (low_idx, high_idx) range into ALL_SCENARIOS_3MOD
    """
    progress = epoch / max(total_epochs, 1)
    warmup_fraction = 0.3

    # Warm-up phase: easy scenarios only (indices 3..5, one modality missing).
    # After that the full scenario list is used for the rest of training.
    return (3, 6) if progress <= warmup_fraction else (0, 6)


def impute_missing(x_real, scenario, impute_type="zeros"):
    """
    Replace missing modality channels with imputation values.

    Args:
        x_real: (B, C, H, W) tensor of real images
        scenario: list of 0/1 indicating missing/available
        impute_type: 'zeros', 'noise', or 'average'

    Returns:
        x_imputed: a new tensor with missing channels replaced (x_real is
        not modified)

    Raises:
        ValueError: if impute_type is not one of the supported strategies.
    """
    # Fail loudly on a typo: silently returning the input would hand the
    # ground-truth "missing" channels straight to the generator.
    if impute_type not in ("zeros", "noise", "average"):
        raise ValueError(
            f"Unknown impute_type {impute_type!r}; "
            "expected 'zeros', 'noise' or 'average'"
        )

    x_imputed = x_real.clone()
    B, C, H, W = x_imputed.shape
    missing_idx = [i for i, s in enumerate(scenario) if s == 0]

    if impute_type == "zeros":
        for idx in missing_idx:
            x_imputed[:, idx, ...] = 0.0
    elif impute_type == "noise":
        # Fresh standard-normal noise per missing channel.
        for idx in missing_idx:
            x_imputed[:, idx, ...] = torch.randn(B, H, W, device=x_real.device)
    else:  # "average"
        # Mean over the available channels; zeros when nothing is available.
        avail_idx = [i for i, s in enumerate(scenario) if s == 1]
        if avail_idx:
            avg = torch.mean(x_real[:, avail_idx, ...], dim=1)
        else:
            avg = torch.zeros(B, H, W, device=x_real.device)
        for idx in missing_idx:
            x_imputed[:, idx, ...] = avg

    return x_imputed


def impute_reals_into_fake(x_input, fake_x, scenario):
    """
    Implicit conditioning: overwrite the generator's output for every
    available modality with the corresponding real channel, so that only
    the synthesized channels differ from the real data.

    Args:
        x_input: (B, C, H, W) tensor holding the real available channels
        fake_x: (B, C, H, W) generator output
        scenario: list of 0/1 indicating missing/available
    Returns:
        A copy of fake_x with available channels replaced by real values
    """
    merged = fake_x.clone()
    available_channels = (ch for ch, flag in enumerate(scenario) if flag == 1)
    for ch in available_channels:
        merged[:, ch, ...] = x_input[:, ch, ...].clone()
    return merged


def compute_missing_loss(fake_x, real_x, scenario, loss_fn):
    """
    Evaluate loss_fn on the missing modality channels only and average the
    per-channel results (used with implicit conditioning).

    Args:
        fake_x: Generator output (B, C, H, W)
        real_x: Ground truth (B, C, H, W)
        scenario: list of 0/1
        loss_fn: loss function (e.g., nn.L1Loss())

    Returns:
        Mean of the per-channel losses, or a zero tensor on fake_x's device
        when no channel is missing
    """
    per_channel = [
        loss_fn(fake_x[:, ch, ...], real_x[:, ch, ...])
        for ch, flag in enumerate(scenario)
        if flag == 0
    ]

    if not per_channel:
        return torch.tensor(0.0, device=fake_x.device)
    return sum(per_channel) / len(per_channel)


# Cell-level confirmation: report how many missing-modality scenarios and
# which modality names were defined above.
print("Model architecture loaded.")
print(f"  Scenarios: {len(ALL_SCENARIOS_3MOD)}")
print(f"  Modalities: {MODALITY_NAMES}")

## IXI Slice Dataset

Loads pre-extracted `.npy` slices with shape `(3, H, W)` = `[T1, T2, PD]`.

Data is already normalized to `[0, 1]` by the Phase 1 extraction pipeline. Expects directory structure:
```
DATA_PATH/
  train/
    subject001_slice050.npy
    ...
  val/
    ...
  test/
    ...
```

In [None]:
# ============================================================
# IXI Dataset Loader
# ============================================================


class IXISliceDataset(Dataset):
    """
    Dataset for loading pre-extracted IXI axial slices.
    Each .npy file is shape (3, H, W) with channels [T1, T2, PD].
    Data is already normalized to [0, 1] by the extraction pipeline.
    """

    def __init__(self, data_dir, split="train", target_size=(256, 256), augment=False):
        """
        Args:
            data_dir: Root directory containing split subdirs (train/val/test)
            split: 'train', 'val', or 'test'
            target_size: Resize slices to this size (H, W); None disables resizing
            augment: Apply data augmentation (horizontal flip)

        Raises:
            FileNotFoundError: if the split directory does not exist.
            RuntimeError: if the split directory contains no .npy files.
        """
        self.data_dir = Path(data_dir) / split
        # Store as a tuple so a list such as [256, 256] still compares equal
        # to the (h, w) tuple in __getitem__. Otherwise every slice would be
        # pushed through a pointless same-size resize.
        self.target_size = tuple(target_size) if target_size is not None else None
        self.augment = augment

        # Collect all .npy files
        if not self.data_dir.exists():
            raise FileNotFoundError(f"Split directory not found: {self.data_dir}")

        # Sorted so that sample indices are stable across sessions.
        self.file_list = sorted(self.data_dir.glob("*.npy"))

        if not self.file_list:
            raise RuntimeError(f"No .npy files found in {self.data_dir}")

        print(f"[IXISliceDataset] {split}: {len(self.file_list)} slices from {self.data_dir}")

    def __len__(self):
        """Number of slice files in this split."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """
        Returns:
            dict with:
                image: (3, H, W) float32 tensor, channels = [T1, T2, PD]
                filename: str, name of the .npy file (for tracking)

        Raises:
            ValueError: if the stored array is not 3-dimensional.
        """
        filepath = self.file_list[idx]
        data = np.load(filepath).astype(np.float32)  # (3, H, W)
        if data.ndim != 3:
            raise ValueError(
                f"Expected a (C, H, W) array in {filepath}, got shape {data.shape}"
            )

        # Convert to tensor
        image = torch.from_numpy(data)

        # Resize only when the stored resolution differs from the target
        if self.target_size is not None:
            h, w = image.shape[1], image.shape[2]
            if (h, w) != self.target_size:
                image = TF.resize(
                    image, list(self.target_size),
                    interpolation=TF.InterpolationMode.BILINEAR,
                    antialias=True,
                )

        # Augmentation: horizontal flip with probability 0.5
        if self.augment and torch.rand(1).item() > 0.5:
            image = TF.hflip(image)

        return {
            "image": image,
            "filename": filepath.stem,
        }


def create_dataloaders(
    data_dir,
    batch_size=8,
    target_size=(256, 256),
    num_workers=2,
    augment_train=True,
):
    """
    Build one DataLoader per split directory that exists under data_dir.

    Only the training loader is shuffled, augmented (when augment_train is
    set) and drops its last incomplete batch. Missing splits are reported
    with a warning and left out of the result.

    Args:
        data_dir: Root directory with train/val/test subdirs
        batch_size: Batch size
        target_size: Resize to (H, W)
        num_workers: DataLoader workers
        augment_train: Apply augmentation to training set

    Returns:
        dict mapping each available split name ('train', 'val', 'test') to
        its DataLoader
    """
    root = Path(data_dir)
    loaders = {}

    for split in ("train", "val", "test"):
        split_dir = root / split
        if not split_dir.exists():
            print(f"[WARN] Split directory missing: {split_dir}")
            continue

        is_train = split == "train"
        dataset = IXISliceDataset(
            data_dir=data_dir,
            split=split,
            target_size=target_size,
            augment=augment_train if is_train else False,
        )
        loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=is_train,
            num_workers=num_workers,
            pin_memory=True,
            drop_last=is_train,
        )

    return loaders


# Cell-level confirmation that the dataset utilities above were defined.
print("Dataset classes loaded.")

## Metrics (PSNR, SSIM) and Checkpoint Management

- **Metrics**: NumPy-based (for final evaluation) and PyTorch-based (for training loop) implementations of PSNR and SSIM.
- **Checkpoints**: Saves full training state (model, optimizer, scheduler, epoch, metrics, RNG states) for seamless resume across Kaggle sessions.

In [None]:
# ============================================================
# Metrics: PSNR and SSIM
# ============================================================


# --- NumPy-based (for final evaluation) ---

def psnr_numpy(pred, gt, data_range=1.0):
    """
    Peak signal-to-noise ratio between two arrays, computed in float64.

    PSNR = 10 * log10(data_range^2 / MSE). When the MSE is below 1e-10 the
    arrays are treated as identical and 100.0 is returned instead.
    """
    residual = pred.astype(np.float64) - gt.astype(np.float64)
    mse = np.mean(residual ** 2)
    if mse < 1e-10:
        return 100.0
    peak_power = data_range**2
    return 10.0 * np.log10(peak_power / mse)


def ssim_numpy(pred, gt, data_range=1.0):
    """Return scikit-image's structural similarity for two 2D images."""
    # Imported lazily so that scikit-image is only required when this
    # function is actually called.
    from skimage.metrics import structural_similarity as _structural_similarity

    score = _structural_similarity(pred, gt, data_range=data_range)
    return score


def compute_metrics_batch(pred_batch, gt_batch, data_range=1.0):
    """
    Per-sample PSNR and SSIM for a batch, each averaged over the channels.

    Args:
        pred_batch: (B, C, H, W) numpy array
        gt_batch: (B, C, H, W) numpy array
        data_range: max pixel value

    Returns:
        dict with "psnr" and "ssim" lists holding one channel-averaged value
        per sample
    """
    n_samples, n_channels, _, _ = pred_batch.shape
    psnr_vals, ssim_vals = [], []

    for sample in range(n_samples):
        # Pair up the 2D prediction / ground-truth planes of this sample.
        planes = [
            (pred_batch[sample, ch], gt_batch[sample, ch])
            for ch in range(n_channels)
        ]
        psnr_vals.append(np.mean([psnr_numpy(p, g, data_range) for p, g in planes]))
        ssim_vals.append(np.mean([ssim_numpy(p, g, data_range) for p, g in planes]))

    return {"psnr": psnr_vals, "ssim": ssim_vals}


# --- PyTorch-based (for use during training) ---

def psnr_torch(pred, gt, data_range=1.0):
    """
    Compute PSNR using PyTorch tensors.

    The MSE is taken over every element of the inputs, so for batched input
    this is the PSNR of the whole batch rather than a mean of per-sample
    PSNRs.

    Args:
        pred: Predicted tensor (any shape).
        gt: Ground-truth tensor, broadcastable against pred.
        data_range: Maximum possible pixel value.

    Returns:
        0-dim tensor on the same device as the inputs. When the MSE is below
        1e-10 the inputs are treated as identical and 100.0 is returned.
    """
    mse = torch.mean((pred.float() - gt.float()) ** 2)
    if mse.item() < 1e-10:
        # Keep the sentinel on the input device so both branches return
        # tensors that callers can combine without device mismatches.
        return torch.tensor(100.0, device=mse.device)
    return 10.0 * torch.log10(data_range**2 / mse)


def ssim_torch(pred, gt, window_size=11, data_range=1.0):
    """
    Simple SSIM implementation in PyTorch using a Gaussian window.

    Each channel is filtered independently (grouped convolution), so both
    single-channel and multi-channel inputs are supported. Borders are
    zero-padded, so the SSIM map has the same spatial size as the input.

    Args:
        pred: (B, C, H, W) tensor
        gt: (B, C, H, W) tensor with the same shape as pred
        window_size: Gaussian window size
        data_range: max pixel value

    Returns:
        Scalar SSIM value (mean of the SSIM map over all elements)

    Raises:
        ValueError: if the inputs are not 4D or their shapes differ.
    """
    if pred.dim() != 4 or pred.shape != gt.shape:
        raise ValueError(
            "ssim_torch expects two (B, C, H, W) tensors of equal shape, "
            f"got {tuple(pred.shape)} and {tuple(gt.shape)}"
        )

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2

    # Normalised 1D Gaussian (sigma = 1.5), then its outer product as the
    # 2D window.
    sigma = 1.5
    half = window_size // 2
    gauss = torch.tensor(
        [np.exp(-((x - half) ** 2) / (2 * sigma**2)) for x in range(window_size)],
        dtype=torch.float32,
    )
    gauss = gauss / gauss.sum()
    window_2d = gauss.unsqueeze(1).mm(gauss.unsqueeze(0))

    # One copy of the window per channel; groups=channels keeps the channels
    # from mixing.
    channels = pred.shape[1]
    window = window_2d.expand(channels, 1, window_size, window_size).contiguous()
    window = window.to(device=pred.device, dtype=pred.dtype)

    def _local_mean(t):
        return torch.nn.functional.conv2d(t, window, padding=half, groups=channels)

    mu1 = _local_mean(pred)
    mu2 = _local_mean(gt)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = _local_mean(pred * pred) - mu1_sq
    sigma2_sq = _local_mean(gt * gt) - mu2_sq
    sigma12 = _local_mean(pred * gt) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )

    return ssim_map.mean()


# ============================================================
# Checkpoint Management
# ============================================================

from datetime import datetime


def _atomic_torch_save(state, path):
    """Write ``state`` to a temp file, then rename it over ``path``."""
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(state, tmp_path)
    # os.replace is an atomic rename on the same filesystem, so an
    # interrupted save can never leave a truncated file under the final name.
    os.replace(tmp_path, path)


def _json_scalar(value):
    """json ``default`` hook: unwrap NumPy / torch scalars to builtins."""
    if hasattr(value, "item"):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_checkpoint(
    state,
    checkpoint_dir,
    epoch,
    is_best=False,
    prefix="mmgan",
    max_keep=3,
):
    """
    Save training checkpoint with metadata.

    Writes {prefix}_epoch_{epoch:04d}.pth and {prefix}_latest.pth, plus
    {prefix}_best.pth when is_best is set, and {prefix}_meta.json. Only the
    max_keep most recent epoch checkpoints are then kept.

    Every file is written to a temporary name and renamed into place, so a
    session killed mid-save cannot corrupt the checkpoint used for resuming.

    Args:
        state: Checkpoint dict to serialize with torch.save.
        checkpoint_dir: Target directory (created if missing).
        epoch: Epoch index used in the filename and metadata.
        is_best: Whether to also store this state as the best model.
        prefix: Filename prefix.
        max_keep: Number of epoch checkpoints to retain.
    """
    ckpt_dir = Path(checkpoint_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    # Save epoch checkpoint
    filename = ckpt_dir / f"{prefix}_epoch_{epoch:04d}.pth"
    _atomic_torch_save(state, filename)
    print(f"  [CKPT] Saved: {filename}")

    # Save latest pointer (this is the file resume logic loads)
    _atomic_torch_save(state, ckpt_dir / f"{prefix}_latest.pth")

    # Save best model
    if is_best:
        _atomic_torch_save(state, ckpt_dir / f"{prefix}_best.pth")
        print(f"  [CKPT] New best model saved!")

    # Save metadata. Metrics and flags often arrive as NumPy scalars (e.g.
    # from np.mean or comparisons on it), which json cannot encode natively.
    meta = {
        "epoch": epoch,
        "timestamp": datetime.now().isoformat(),
        "is_best": is_best,
        "best_psnr": state.get("best_psnr", 0.0),
        "best_ssim": state.get("best_ssim", 0.0),
    }
    meta_path = ckpt_dir / f"{prefix}_meta.json"
    meta_tmp = meta_path.with_name(meta_path.name + ".tmp")
    with open(meta_tmp, "w") as f:
        json.dump(meta, f, indent=2, default=_json_scalar)
    os.replace(meta_tmp, meta_path)

    # Cleanup old checkpoints (keep latest N); zero-padded epoch numbers make
    # lexicographic order match chronological order.
    all_ckpts = sorted(glob.glob(str(ckpt_dir / f"{prefix}_epoch_*.pth")))
    if len(all_ckpts) > max_keep:
        for old_ckpt in all_ckpts[:-max_keep]:
            os.remove(old_ckpt)


def load_checkpoint(checkpoint_dir, prefix="mmgan", which="latest"):
    """
    Load a checkpoint file onto the CPU.

    Args:
        checkpoint_dir: Directory containing checkpoints
        prefix: Filename prefix
        which: an epoch number (int) selects that epoch's file, 'best'
            selects the best model, anything else selects the latest one

    Returns:
        The deserialized state dict, or None if the file does not exist
    """
    if isinstance(which, int):
        stem = f"{prefix}_epoch_{which:04d}"
    elif which == "best":
        stem = f"{prefix}_best"
    else:
        stem = f"{prefix}_latest"
    path = Path(checkpoint_dir) / f"{stem}.pth"

    if path.exists():
        print(f"  [CKPT] Loading: {path}")
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only load checkpoint files produced by this notebook.
        return torch.load(path, map_location="cpu", weights_only=False)

    print(f"  [CKPT] No checkpoint found at: {path}")
    return None


def resume_training(
    generator, discriminator,
    optimizer_G, optimizer_D,
    checkpoint_dir, prefix="mmgan",
    scheduler_G=None, scheduler_D=None,
):
    """
    Resume training from latest checkpoint.

    Restores model, optimizer and (when given and stored) scheduler state in
    place, then restores every RNG state found in the checkpoint: PyTorch
    CPU, NumPy, Python ``random`` and CUDA. The ``random`` module drives the
    scenario sampling during training, so without restoring it a resumed
    session would replay the scenario sequence from the initial seed.

    Returns:
        start_epoch: Epoch to resume from (0 if no checkpoint)
        best_psnr: Best PSNR so far
        best_ssim: Best SSIM so far
        history: Training history dict
    """
    state = load_checkpoint(checkpoint_dir, prefix, which="latest")

    if state is None:
        return 0, 0.0, 0.0, {"train_loss_G": [], "train_loss_D": [], "val_psnr": [], "val_ssim": []}

    # Restore model states
    generator.load_state_dict(state["generator_state_dict"])
    discriminator.load_state_dict(state["discriminator_state_dict"])
    optimizer_G.load_state_dict(state["optimizer_G_state_dict"])
    optimizer_D.load_state_dict(state["optimizer_D_state_dict"])

    if scheduler_G is not None and "scheduler_G_state_dict" in state:
        scheduler_G.load_state_dict(state["scheduler_G_state_dict"])
    if scheduler_D is not None and "scheduler_D_state_dict" in state:
        scheduler_D.load_state_dict(state["scheduler_D_state_dict"])

    # Restore RNG states for reproducibility. Every key is optional so that
    # checkpoints without a given RNG state still load.
    if "rng_state" in state:
        torch.set_rng_state(state["rng_state"])
    if "numpy_rng_state" in state:
        np.random.set_state(state["numpy_rng_state"])
    if "python_rng_state" in state:
        random.setstate(state["python_rng_state"])
    cuda_states = state.get("cuda_rng_state_all")
    if (
        cuda_states is not None
        and torch.cuda.is_available()
        # Only restore when the device layout matches the saved one.
        and len(cuda_states) == torch.cuda.device_count()
    ):
        torch.cuda.set_rng_state_all(cuda_states)

    # The checkpoint stores the last completed epoch; continue with the next.
    start_epoch = state["epoch"] + 1
    best_psnr = state.get("best_psnr", 0.0)
    best_ssim = state.get("best_ssim", 0.0)

    history = state.get("history", {
        "train_loss_G": [], "train_loss_D": [],
        "val_psnr": [], "val_ssim": [],
    })

    print(f"  [CKPT] Resuming from epoch {start_epoch}")
    print(f"  [CKPT] Best PSNR: {best_psnr:.4f}, Best SSIM: {best_ssim:.4f}")

    return start_epoch, best_psnr, best_ssim, history


def build_checkpoint_state(
    epoch, generator, discriminator,
    optimizer_G, optimizer_D,
    best_psnr, best_ssim, history,
    scheduler_G=None, scheduler_D=None,
):
    """
    Build a checkpoint state dict.

    Captures model and optimizer state, the best metrics, the training
    history and the state of every RNG used by the pipeline: PyTorch CPU,
    NumPy, Python ``random`` (used for scenario sampling) and, when CUDA is
    available, all CUDA devices. Scheduler state is added only for the
    schedulers that are passed in.

    Returns:
        dict suitable for torch.save.
    """
    state = {
        "epoch": epoch,
        "generator_state_dict": generator.state_dict(),
        "discriminator_state_dict": discriminator.state_dict(),
        "optimizer_G_state_dict": optimizer_G.state_dict(),
        "optimizer_D_state_dict": optimizer_D.state_dict(),
        "best_psnr": best_psnr,
        "best_ssim": best_ssim,
        "history": history,
        "rng_state": torch.get_rng_state(),
        "numpy_rng_state": np.random.get_state(),
        "python_rng_state": random.getstate(),
    }

    # The CUDA generators are separate from the CPU one (they drive e.g.
    # dropout on the GPU), so store them too.
    if torch.cuda.is_available():
        state["cuda_rng_state_all"] = torch.cuda.get_rng_state_all()

    if scheduler_G is not None:
        state["scheduler_G_state_dict"] = scheduler_G.state_dict()
    if scheduler_D is not None:
        state["scheduler_D_state_dict"] = scheduler_D.state_dict()

    return state


# Cell-level confirmation that the metric and checkpoint helpers were defined.
print("Metrics and checkpoint utilities loaded.")

## Training Loop

Training loop with:
- **Curriculum Learning**: Gradually introduces harder missing-modality scenarios.
- **Implicit Conditioning**: Copies real available modalities into generator output.
- **Session Time Management**: Monitors elapsed time and saves checkpoint before Kaggle's 2-hour limit.

In [None]:
# Wall-clock reference taken when this cell runs; the session budget is
# measured from here.
SESSION_START = time.time()


def time_remaining():
    """Return the seconds left of MAX_SESSION_TIME (negative once exceeded)."""
    return MAX_SESSION_TIME - (time.time() - SESSION_START)


def validate(generator, val_loader, device, scenarios=None, impute_type="zeros"):
    """
    Run validation and compute PSNR/SSIM metrics on the synthesized channels.

    For every batch and every scenario, the missing channels are imputed,
    the generator is run, the available channels are copied back from the
    real data, and PSNR/SSIM are measured on each missing channel.

    Args:
        generator: Generator network; evaluated in eval mode under no_grad.
        val_loader: Iterable of batches with an "image" tensor (B, C, H, W).
        device: Device the batches are moved to.
        scenarios: Missing-modality scenarios to evaluate; defaults to
            ALL_SCENARIOS_3MOD.
        impute_type: Placeholder strategy for the missing input channels.
            Pass the value used in training to keep validation consistent.

    Returns:
        (mean PSNR, mean SSIM) over all evaluated (batch, scenario, channel)
        combinations, or (0.0, 0.0) if nothing was evaluated.
    """
    if scenarios is None:
        scenarios = ALL_SCENARIOS_3MOD
    all_psnr, all_ssim = [], []

    # Remember the caller's mode so a generator that was already in eval mode
    # is not silently switched to train mode (which would enable dropout).
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for batch in val_loader:
                x_real = batch["image"].to(device)
                for scenario in scenarios:
                    x_input = impute_missing(x_real, scenario, impute_type=impute_type)
                    fake_x = generator(x_input)
                    fake_x = impute_reals_into_fake(x_real, fake_x, scenario)
                    for idx, available in enumerate(scenario):
                        if available == 0:
                            # Slice with idx:idx+1 to keep the channel axis
                            # expected by ssim_torch.
                            pred = fake_x[:, idx:idx+1, ...]
                            gt = x_real[:, idx:idx+1, ...]
                            all_psnr.append(psnr_torch(pred, gt).item())
                            all_ssim.append(ssim_torch(pred, gt).item())
    finally:
        # Restore the previous mode even if validation raised.
        generator.train(was_training)

    return np.mean(all_psnr) if all_psnr else 0.0, np.mean(all_ssim) if all_ssim else 0.0


def train_one_epoch(generator, discriminator, optimizer_G, optimizer_D,
                    train_loader, device, criterion_GAN, criterion_pixel,
                    lambda_pixel, epoch, n_epochs, impute_type, use_ic, use_curriculum,
                    writer=None, global_step=0):
    """
    Train the generator and discriminator for one epoch.

    For each batch one missing-modality scenario is drawn (shared by the
    whole batch), the generator is updated first, then the discriminator is
    updated on the real batch and on the detached generator output. The
    epoch stops early when fewer than 300 seconds of session time remain.

    Args:
        generator, discriminator: Networks to train (put into train mode).
        optimizer_G, optimizer_D: Their optimizers.
        train_loader: Iterable of batches with an "image" tensor.
        device: Device the batches are moved to.
        criterion_GAN: Adversarial loss applied to discriminator outputs.
        criterion_pixel: Reconstruction loss between fake and real images.
        lambda_pixel: Weight of the pixel loss; the adversarial loss is
            weighted by (1 - lambda_pixel).
        epoch, n_epochs: Current epoch and total epochs (used for the
            curriculum and the progress bar label).
        impute_type: Placeholder strategy passed to impute_missing.
        use_ic: If True, real available channels are copied into the
            generator output and the pixel loss covers missing channels only.
        use_curriculum: If True, scenarios are restricted by
            get_curriculum_scenarios; otherwise all scenarios are used.
        writer: Optional TensorBoard writer; losses are logged every 50 batches.
        global_step: Running batch counter used as the logging step.

    Returns:
        (mean generator loss, mean discriminator loss, updated global_step),
        with means taken over the batches actually processed.
    """
    generator.train()
    discriminator.train()
    epoch_loss_G, epoch_loss_D, n_batches = 0.0, 0.0, 0

    # Scenario pool for this epoch (fixed for all of its batches).
    if use_curriculum:
        low, high = get_curriculum_scenarios(epoch, n_epochs)
        available_scenarios = ALL_SCENARIOS_3MOD[low:high]
    else:
        available_scenarios = ALL_SCENARIOS_3MOD

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{n_epochs}", leave=False):
        # Check time
        if time_remaining() < 300:  # Less than 5 min left
            print(f"WARNING: Less than 5 min remaining, stopping epoch early")
            break

        x_real = batch["image"].to(device)
        # One scenario per batch, drawn with Python's `random` module.
        scenario = random.choice(available_scenarios)
        x_input = impute_missing(x_real, scenario, impute_type=impute_type)

        # Train Generator
        optimizer_G.zero_grad()
        fake_x = generator(x_input)
        if use_ic:
            fake_x_ic = impute_reals_into_fake(x_real, fake_x, scenario)
        else:
            fake_x_ic = fake_x
        # NOTE(review): the discriminator's condition input is the full
        # ground truth x_real, not the imputed x_input — confirm this is intended.
        pred_fake = discriminator(fake_x_ic, x_real)
        valid = torch.ones_like(pred_fake, device=device)
        loss_GAN = criterion_GAN(pred_fake, valid)
        if use_ic:
            loss_pixel = compute_missing_loss(fake_x_ic, x_real, scenario, criterion_pixel)
        else:
            loss_pixel = criterion_pixel(fake_x, x_real)
        loss_G = (1 - lambda_pixel) * loss_GAN + lambda_pixel * loss_pixel
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        pred_real = discriminator(x_real, x_real)
        valid = torch.ones_like(pred_real, device=device)
        loss_real = criterion_GAN(pred_real, valid)
        # detach(): the discriminator update must not backpropagate into the
        # generator.
        pred_fake = discriminator(fake_x_ic.detach(), x_real)
        fake_label = torch.zeros_like(pred_fake, device=device)
        loss_fake = criterion_GAN(pred_fake, fake_label)
        loss_D = 0.5 * (loss_real + loss_fake)
        loss_D.backward()
        optimizer_D.step()

        epoch_loss_G += loss_G.item()
        epoch_loss_D += loss_D.item()
        n_batches += 1
        global_step += 1

        # Log the current batch's losses every 50 batches.
        if writer and n_batches % 50 == 0:
            writer.add_scalar("train/loss_G", loss_G.item(), global_step)
            writer.add_scalar("train/loss_D", loss_D.item(), global_step)

    # max(..., 1) avoids a division by zero when no batch was processed.
    return epoch_loss_G / max(n_batches, 1), epoch_loss_D / max(n_batches, 1), global_step


# Cell-level confirmation, plus the session budget left at this point.
print("Training functions loaded.")
print(f"Session time remaining: {time_remaining():.0f}s")

## Execute Training Loop

Auto-resumes from checkpoint if available. If the session time limit approaches, training stops and saves a checkpoint. Re-run this notebook to continue from where it left off.

In [None]:
# Set seed (before building models so weight initialisation is reproducible)
set_seed(SEED)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Data
print("Loading data...")
loaders = create_dataloaders(
    data_dir=DATA_PATH,
    batch_size=BATCH_SIZE,
    target_size=(IMG_SIZE, IMG_SIZE),
    num_workers=2,
    augment_train=True,
)
# The train split is mandatory (KeyError if its directory is missing);
# validation is optional and is skipped when there is no "val" loader.
train_loader = loaders["train"]
val_loader = loaders.get("val", None)
print(f"Train batches: {len(train_loader)}")
if val_loader:
    print(f"Val batches: {len(val_loader)}")

# Models
print("Building models...")
generator = GeneratorUNet(in_channels=IN_CHANNELS, out_channels=IN_CHANNELS).to(device)
discriminator = Discriminator(in_channels=IN_CHANNELS, out_channels=IN_CHANNELS).to(device)
# Fresh initialisation; overwritten below if a checkpoint is restored.
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)

# Optimizers (Adam) with a step schedule that halves the LR every 20 epochs
optimizer_G = torch.optim.Adam(generator.parameters(), lr=LR_G, betas=(BETA1, BETA2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=LR_D, betas=(BETA1, BETA2))
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, step_size=20, gamma=0.5)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, step_size=20, gamma=0.5)

# Loss: MSE on discriminator patch outputs, L1 on pixels
criterion_GAN = nn.MSELoss().to(device)
criterion_pixel = nn.L1Loss().to(device)

# Resume? Defaults below describe a fresh run.
start_epoch = 0
best_psnr, best_ssim = 0.0, 0.0
history = {"train_loss_G": [], "train_loss_D": [], "val_psnr": [], "val_ssim": []}

# load_checkpoint() is used here only to detect an existing checkpoint;
# resume_training() loads the same file again and applies it.
existing_ckpt = load_checkpoint(CHECKPOINT_DIR)
if existing_ckpt is not None:
    print("Resuming from checkpoint...")
    start_epoch, best_psnr, best_ssim, history = resume_training(
        generator, discriminator, optimizer_G, optimizer_D,
        CHECKPOINT_DIR, scheduler_G=scheduler_G, scheduler_D=scheduler_D,
    )
    print(f"Resumed from epoch {start_epoch}, best PSNR: {best_psnr:.2f}")
else:
    print("Starting fresh training")

# TensorBoard
writer = SummaryWriter(LOG_DIR)

# Training loop
print("=" * 60)
print(f"Experiment: {EXPERIMENT_NAME}")
print(f"Epochs: {start_epoch} -> {N_EPOCHS}")
print(f"Batch size: {BATCH_SIZE}, Lambda: {LAMBDA_PIXEL}")
print(f"IC: {USE_IC}, Curriculum: {USE_CURRICULUM}")
print("=" * 60)

# Approximate the logging step on resume: assumes every earlier epoch ran
# all len(train_loader) batches.
global_step = start_epoch * len(train_loader)
training_start = time.time()

for epoch in range(start_epoch, N_EPOCHS):
    # Check session time
    if time_remaining() < 600:  # Less than 10 min
        print(f"\nSession time limit approaching ({time_remaining():.0f}s remaining). Saving and stopping.")
        state = build_checkpoint_state(
            epoch, generator, discriminator, optimizer_G, optimizer_D,
            best_psnr, best_ssim, history, scheduler_G, scheduler_D,
        )
        save_checkpoint(state, CHECKPOINT_DIR, epoch, is_best=False)
        break

    epoch_start = time.time()

    avg_loss_G, avg_loss_D, global_step = train_one_epoch(
        generator, discriminator, optimizer_G, optimizer_D,
        train_loader, device, criterion_GAN, criterion_pixel,
        LAMBDA_PIXEL, epoch, N_EPOCHS, IMPUTE_TYPE, USE_IC, USE_CURRICULUM,
        writer=writer, global_step=global_step,
    )

    scheduler_G.step()
    scheduler_D.step()

    history["train_loss_G"].append(avg_loss_G)
    history["train_loss_D"].append(avg_loss_D)

    epoch_time = time.time() - epoch_start
    print(f"Epoch [{epoch+1}/{N_EPOCHS}] Loss_G: {avg_loss_G:.4f} | Loss_D: {avg_loss_D:.4f} | Time: {epoch_time:.1f}s")

    # Validation
    if val_loader and (epoch + 1) % VAL_INTERVAL == 0:
        avg_psnr, avg_ssim = validate(generator, val_loader, device)
        history["val_psnr"].append(avg_psnr)
        history["val_ssim"].append(avg_ssim)
        writer.add_scalar("val/psnr", avg_psnr, epoch)
        writer.add_scalar("val/ssim", avg_ssim, epoch)
        print(f"  Val PSNR: {avg_psnr:.4f} | Val SSIM: {avg_ssim:.4f}")
        is_best = avg_psnr > best_psnr
        if is_best:
            best_psnr = avg_psnr
            best_ssim = avg_ssim
    else:
        is_best = False

    # Save checkpoint
    if (epoch + 1) % SAVE_INTERVAL == 0 or is_best or (epoch + 1) == N_EPOCHS:
        state = build_checkpoint_state(
            epoch, generator, discriminator, optimizer_G, optimizer_D,
            best_psnr, best_ssim, history, scheduler_G, scheduler_D,
        )
        save_checkpoint(state, CHECKPOINT_DIR, epoch, is_best=is_best)

total_time = time.time() - training_start
writer.close()

print("\n" + "=" * 60)
print("SESSION COMPLETE")
print(f"  Time this session: {total_time:.1f}s ({total_time/3600:.2f}h)")
print(f"  Completed epochs: {epoch+1}/{N_EPOCHS}")
print(f"  Best PSNR: {best_psnr:.4f}")
print(f"  Best SSIM: {best_ssim:.4f}")
print(f"  Checkpoints: {CHECKPOINT_DIR}")
if epoch + 1 < N_EPOCHS:
    print(f"\n  >>> Re-run this notebook to continue training from epoch {epoch+1} <<<")
print("=" * 60)

## Evaluation

Run this section **after all training epochs are complete**. Loads the best model checkpoint and evaluates on the test set across all 6 missing modality scenarios.

Produces a per-scenario metrics table (PSNR, SSIM) and saves results to `metrics.json`.

In [None]:
# ---------------------------------------------------------------------------
# Evaluation driver.
#
# Loads the best checkpoint (falling back to the latest one), runs the
# generator over the test split (falling back to the validation split) for
# every scenario in ALL_SCENARIOS_3MOD, and reports PSNR/SSIM computed only on
# the channels marked missing (0) in each scenario. Results are printed as a
# table and written to RESULTS_DIR/metrics.json.
# ---------------------------------------------------------------------------

# Only run this after training is complete!
print("Loading best model for evaluation...")

# Load best checkpoint
generator_eval = GeneratorUNet(in_channels=IN_CHANNELS, out_channels=IN_CHANNELS).to(device)
best_ckpt = load_checkpoint(CHECKPOINT_DIR, which="best")
if best_ckpt is None:
    best_ckpt = load_checkpoint(CHECKPOINT_DIR, which="latest")
if best_ckpt is None:
    # Fail with a clear message instead of a TypeError on None below.
    raise FileNotFoundError(
        f"No checkpoint found in {CHECKPOINT_DIR}; run the training cell first."
    )
generator_eval.load_state_dict(best_ckpt["generator_state_dict"])
print(f"Loaded checkpoint from epoch {best_ckpt.get('epoch', '?')}")

# Load test data
test_loader = loaders.get("test", None)
if test_loader is None:
    print("No test split found, using validation set")
    test_loader = val_loader
if test_loader is None:
    raise RuntimeError("Neither a test nor a validation split is available for evaluation.")

# Evaluate
# plt is not used in this cell; the imports and backend switch are kept
# unchanged in case later code relies on them.
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

generator_eval.eval()
results = {}
with torch.no_grad():
    for scenario in ALL_SCENARIOS_3MOD:
        scenario_str = "".join(str(s) for s in scenario)
        missing_idx = [i for i, s in enumerate(scenario) if s == 0]
        missing_mods = [MODALITY_NAMES[i] for i in missing_idx]
        avail_mods = [MODALITY_NAMES[i] for i, s in enumerate(scenario) if s == 1]
        psnr_list, ssim_list = [], []
        for batch in tqdm(test_loader, desc=f"Eval {scenario_str}", leave=False):
            x_real = batch["image"].to(device)
            # NOTE(review): evaluation always uses "zeros" imputation, whereas
            # training is driven by IMPUTE_TYPE -- confirm this is intended.
            x_input = impute_missing(x_real, scenario, impute_type="zeros")
            fake_x = generator_eval(x_input)
            fake_x = impute_reals_into_fake(x_real, fake_x, scenario)
            fake_np = fake_x.cpu().numpy()
            real_np = x_real.cpu().numpy()
            # Score only the synthesized (missing) channels of every sample.
            for b in range(x_real.size(0)):
                for idx in missing_idx:
                    psnr_list.append(psnr_numpy(fake_np[b, idx], real_np[b, idx]))
                    ssim_list.append(ssim_numpy(fake_np[b, idx], real_np[b, idx]))
        if not psnr_list:
            # np.mean([]) would yield NaN, which json.dump writes as invalid JSON.
            print(f"Warning: no samples evaluated for scenario {scenario_str}; skipping.")
            continue
        results[scenario_str] = {
            "missing": missing_mods, "available": avail_mods,
            "psnr_mean": float(np.mean(psnr_list)), "psnr_std": float(np.std(psnr_list)),
            "ssim_mean": float(np.mean(ssim_list)), "ssim_std": float(np.std(ssim_list)),
            "n_samples": len(psnr_list),
        }

if not results:
    raise RuntimeError("Evaluation produced no samples; check that the evaluation split is not empty.")

# Overall (unweighted mean of the per-scenario means)
all_p = [v["psnr_mean"] for v in results.values()]
all_s = [v["ssim_mean"] for v in results.values()]
overall = {"psnr_mean": float(np.mean(all_p)), "ssim_mean": float(np.mean(all_s))}

# Print table
print(f"\n{'='*70}")
print(f"Results: {EXPERIMENT_NAME}")
print(f"{'='*70}")
print(f"{'Scenario':<12} {'Missing':<15} {'Available':<15} {'PSNR':>10} {'SSIM':>10}")
print(f"{'-'*70}")
for key, val in results.items():
    print(f"{key:<12} {','.join(val['missing']):<15} {','.join(val['available']):<15} "
          f"{val['psnr_mean']:>7.2f}+-{val['psnr_std']:.2f} "
          f"{val['ssim_mean']:>7.4f}+-{val['ssim_std']:.4f}")
print(f"{'-'*70}")
print(f"{'OVERALL':<42} {overall['psnr_mean']:>7.2f}       {overall['ssim_mean']:>7.4f}")
print(f"{'='*70}")

# Save
os.makedirs(RESULTS_DIR, exist_ok=True)  # make sure the output directory exists
metrics_path = os.path.join(RESULTS_DIR, "metrics.json")
with open(metrics_path, "w") as f:
    json.dump({"per_scenario": results, "overall": overall}, f, indent=2)
print(f"\nMetrics saved to: {metrics_path}")

## Download Artifacts

Package checkpoints and results into a zip file for download. Use this to:
- Transfer checkpoints to the next session (if training is not complete).
- Download final results for comparison between baseline and optimized experiments.

In [None]:
# Package artifacts for download.
#
# Zips the whole /kaggle/working directory into
# /kaggle/working/mmgan_<EXPERIMENT_NAME>_artifacts.zip.
import os
import shutil
import tempfile

artifact_name = f"mmgan_{EXPERIMENT_NAME}_artifacts"
working_dir = "/kaggle/working"
archive_path = os.path.join(working_dir, f"{artifact_name}.zip")

# Drop the archive left by a previous run of this cell so it is not nested
# inside the new one.
if os.path.exists(archive_path):
    os.remove(archive_path)

# Build the zip outside the directory being archived: writing it directly into
# working_dir lets the half-written archive be picked up as one of its own
# members. The finished file is then moved into place.
with tempfile.TemporaryDirectory() as staging_dir:
    staged_zip = shutil.make_archive(
        os.path.join(staging_dir, artifact_name),
        "zip",
        working_dir,
        "."
    )
    shutil.move(staged_zip, archive_path)

print(f"Artifacts packaged: {archive_path}")
print("Download this file from the Output tab before session ends!")