> **Note:** Use the provided YAML config exactly as-is to replicate reported results.

In [None]:
!pip install -q -U kagglehub torchio wandb

## Imports & Global configs.

In [None]:
import os
import random
from pathlib import Path
from typing import List, Literal

import kagglehub
import nibabel as nib
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchio as tio
import wandb
import yaml

from collections import defaultdict
from pprint import pprint
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from typing import Optional

# Fetch the BraTS2020 dataset through kagglehub; the returned value is used
# further down as the dataset root directory.
awsaf49_brats20_dataset_training_validation_path = kagglehub.dataset_download('awsaf49/brats20-dataset-training-validation')
print('Data source import complete.')
print("Path to dataset files:", awsaf49_brats20_dataset_training_validation_path)


def load_config(config_path):
    """
    Load a YAML configuration file.

    Parameters:
        config_path (str | os.PathLike): Path of the YAML file to read.

    Returns:
        The parsed YAML document as returned by ``yaml.safe_load``.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit UTF-8 so the result does not depend on the platform's locale
    # encoding (non-ASCII characters would otherwise be mis-decoded).
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)


config = load_config('./config_unet.yaml')
# The device is resolved at runtime and overrides whatever the YAML file holds.
config['hardware']['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
pprint(config, indent=4, width=80, compact=False)

In [None]:
def _center_crop(volume: torch.Tensor, target_shape: tuple) -> torch.Tensor:
    """
    Center-crops a 3D volume to the target shape in (H, W, D) order.

    Parameters:
        volume (torch.Tensor): Input volume of shape (..., H, W, D) or similar
        target_shape (tuple): Target shape as (tH, tW, tD)

    Returns:
        torch.Tensor: Cropped volume
    """
    h, w, d = volume.shape[-3:]
    tH, tW, tD = target_shape

    h_start = (h - tH) // 2
    w_start = (w - tW) // 2
    d_start = (d - tD) // 2

    return volume[...,
                  h_start:h_start+tH,
                  w_start:w_start+tW,
                  d_start:d_start+tD]

def apply_augmentation(
    multimodal_tensor: torch.Tensor,
    segmentation_tensor: torch.Tensor = None, # Made optional for Autoencoder use
    modalities_list = ('t1', 't1ce', 't2', 'flair'),
    geometric_transforms=None,
    intensity_transforms=None
):
    """
    Apply the augmentation pipeline to one multi-modal sample.

    Each modality (one channel of ``multimodal_tensor``) and, when given, the
    segmentation are wrapped in a single TorchIO subject so that the geometric
    transforms act on all of them together. Intensity transforms are applied
    afterwards, and the background of every modality is zeroed again.

    Parameters:
        multimodal_tensor (torch.Tensor): Image tensor whose first dimension
            indexes the modalities in the order given by ``modalities_list``.
        segmentation_tensor (torch.Tensor, optional): Label volume without a
            channel dimension (one is added internally). May be None.
        modalities_list: Names of the modalities, one per channel.
        geometric_transforms: Callable applied to the subject first, or None.
        intensity_transforms: Callable applied to the subject after the
            geometric transforms, or None.

    Returns:
        Tuple of (transformed_multimodal, transformed_segmentation, foreground_mask).
        ``transformed_segmentation`` is None when no segmentation was given.
        ``foreground_mask`` is a boolean tensor with a single leading channel,
        computed as ``first_modality > 0`` after the geometric transforms.

    Raises:
        ValueError: If ``modalities_list`` is empty.
    """
    modalities = list(modalities_list)
    if not modalities:
        raise ValueError("modalities_list must contain at least one modality")

    # Create TorchIO Subject
    subject_dict = {
        # Slicing instead of indexing to preserve dim=0
        mod_name: tio.ScalarImage(tensor=multimodal_tensor[i:i+1])
        for i, mod_name in enumerate(modalities)
    }

    # Only add segmentation to subject if it is provided
    if segmentation_tensor is not None:
        subject_dict['seg'] = tio.LabelMap(tensor=segmentation_tensor.unsqueeze(0))

    subject = tio.Subject(**subject_dict)

    # Apply geometric transforms
    if geometric_transforms is not None:
        subject = geometric_transforms(subject)

    # Extract the foreground mask after the geometric transforms and before any
    # intensity change. BraTS MRIs are co-registered, so the first modality is
    # enough. The mask always has one leading channel, whichever transforms are
    # enabled; it used to be (C, ...) when no intensity transform was given,
    # which the caller could not use to index a single-channel volume.
    mask = subject[modalities[0]].data > 0

    # Apply intensity transforms
    if intensity_transforms is not None:
        subject = intensity_transforms(subject)

        # Reapply mask to zero out background
        for mod_name in modalities:
            # Direct assignment to .data raises a DeprecationWarning; use set_data.
            subject[mod_name].set_data(subject[mod_name].data * mask)

    # Extract transformed tensors
    transformed_multimodal = torch.cat(
        [subject[mod_name].data for mod_name in modalities], dim=0
    )

    # Handle the optional segmentation extraction
    transformed_segmentation = None
    if 'seg' in subject:
        transformed_segmentation = subject['seg'].data.squeeze(0)

    return transformed_multimodal, transformed_segmentation, mask

## Dataset Definition
### Helpful info

The pipeline always downloads the **original BraTS2020 dataset** from Kaggle as the primary data source.

**Segmentation loading modes (`seg_mode`):**

The `seg_mode` parameter controls how segmentation masks are loaded, offering a trade-off between speed and storage:

- **`"raw"`** (slowest, minimal storage):
  - Loads compressed NIfTI segmentation masks (`.nii`) directly from the original dataset
  - No preprocessing directory required

- **`"preprocessed"`** (fastest, requires storage):
  - Loads pre-generated uncompressed `.npy` segmentation masks from `preprocessed_root_dir`
  - If preprocessed files are missing or incomplete, automatically triggers preprocessing
  - Preprocessing: converts `.nii` to `.npy`, applies label remapping, and optionally assigns the extra non-tumor tissue label (when `assign_label5` is enabled)

- **`"force_generate"`** (for updates):
  - Regenerates `.npy` masks, even if they already exist in `preprocessed_root_dir`
  - Use this when preprocessing logic changes (e.g., toggling `assign_label5`)
  - Training-wise, identical to `preprocessed`

**Preprocessing Directory:**
- `preprocessed_root_dir` can be explicitly set or defaults to `<root_dir>/processed`
- Stores only segmentation masks as `.npy` files (MRI images are always loaded from original `.nii` to save storage)

In [None]:
class BraTSDataset(Dataset):
    """
    BraTS2020 Dataset for 3D multi-modal MRI brain tumor segmentation.

    Each item is a tuple ``(img, seg)``: the stacked, normalized modalities and
    the segmentation volume with label 4 remapped to 3.

    Args:
        split (str):
            Dataset split to load. Must be either "train" or "test".
            Augmentations are applied only when split="train".

        train_ratio (float):
            Fraction of the dataset used for training. The remainder is used
            for testing after shuffling with a fixed random seed.

        root_dir (str):
            Root directory containing the BraTS2020 dataset. Required.

        preprocessed_root_dir (Optional[str]):
            Root directory used to store and/or load preprocessed segmentation
            masks saved as .npy files. If None, defaults to '<root_dir>/processed'.

        modalities_list (Optional[List[str]]):
            List of MRI modalities to load for each subject. The order defines
            the channel order in the output tensor. If None, defaults to
            ["t1", "t1ce", "t2", "flair"]. The first modality is the one used
            to derive foreground / non-tumor tissue masks.

        seg_mode (Literal["raw", "preprocessed", "force_generate"]):
            Controls how segmentation masks are loaded:
                1) "raw": Load the original NIfTI segmentation masks [slowest].
                2) "preprocessed": Load pre-generated .npy segmentation masks [fastest].
                  If any .npy mask of this split is missing, the masks of
                  this split are generated and locally stored.
                3) "force_generate": Always regenerate the .npy segmentation
                  masks of this split before loading.
                  [Required if preprocessing logic is updated.]

        assign_label5 (bool):
            If True, assigns a separate label to non-tumor brain tissue
            (segmentation == 0 where the first modality is non-zero). It is
            referred to as "label 5" to follow BraTS conventions (labels
            {0,1,2,4}), but after remapping (4 >> 3) it is stored internally
            as label 4. In "raw" mode it is computed on the fly at every
            access; in the other modes it is baked into the .npy files.

        geometric_transforms:
            TorchIO spatial transforms applied jointly to image and
            segmentation volumes during training. Applied before cropping.

        intensity_transforms:
            TorchIO intensity transforms applied to image volumes during
            training, after the geometric transforms and before cropping.

        output_shape (Tuple[int, int, int]):
            Target spatial shape (H, W, D) for center cropping of images,
            segmentations, and foreground masks. None disables cropping.

        norm_type (str):
            Intensity normalization strategy. Supported values are:
                - "z": Z-normalization within the foreground mask.
                - "minmax": Rescaling to [0, 1] with TorchIO
                  RescaleIntensity (percentiles 0 to 99.5).

        random_seed (int):
            Random seed used for reproducible shuffling and train/test
            splitting. NOTE: this seeds Python's global `random` module.
    """

    def __init__(
        self,
        split: str = "train",
        train_ratio: float = 0.8,
        root_dir: str = "../kaggle/input/brats20-dataset-training-validation/",
        preprocessed_root_dir: Optional[str] = None,
        modalities_list: Optional[List[str]] = None,
        seg_mode: Literal["raw", "preprocessed", "force_generate"] = "raw",
        assign_label5: bool = True,
        geometric_transforms=None,
        intensity_transforms=None,
        output_shape: tuple = (128, 128, 128),
        norm_type: str = "z",
        random_seed: int = 42,
    ):
        self.split = split
        # Always keep a private copy so that no list is shared between
        # instances (or with the caller).
        self.modalities_list = (
            list(modalities_list)
            if modalities_list is not None
            else ["t1", "t1ce", "t2", "flair"]
        )
        self.seg_mode = seg_mode
        self.assign_label5 = assign_label5
        self.geometric_transforms = geometric_transforms
        self.intensity_transforms = intensity_transforms
        self.output_shape = output_shape
        self.norm_type = norm_type

        if root_dir is None:
            raise ValueError("root_dir must be provided")

        self.root_dir = Path(root_dir)

        self.preprocessed_root = (
            Path(preprocessed_root_dir)
            if preprocessed_root_dir is not None
            else self.root_dir / "processed"
        )

        self.samples_path = ( # Default path
            self.root_dir
            / "BraTS2020_TrainingData"
            / "MICCAI_BraTS2020_TrainingData"
        )

        if not self.samples_path.exists():
            raise FileNotFoundError(f"Raw data not found at {self.samples_path}")

        self.preprocessed_path = ( # Preprocessed path, this is either passed or built from the default path
            self.preprocessed_root
            / "BraTS2020_TrainingData"
            / "MICCAI_BraTS2020_TrainingData"
        )

        total_samples = 369
        all_indices = list(range(1, total_samples + 1))
        # The global RNG is seeded on purpose: the split, and anything drawn
        # from `random` afterwards, stays reproducible.
        random.seed(random_seed)
        random.shuffle(all_indices)

        split_idx = int(train_ratio * total_samples)
        if split == "train":
            self.sample_indices = all_indices[:split_idx]
        elif split == "test":
            self.sample_indices = all_indices[split_idx:]
        else:
            raise ValueError("split must be 'train' or 'test'")

        self.use_npy = self.seg_mode in ("preprocessed", "force_generate") # Preprocessing works with .npy

        if self.seg_mode == "force_generate":
            print("Force generating preprocessed data...")
            self._preprocess_dataset()
        elif self.seg_mode == "preprocessed":
            # Preprocessing only ever writes the masks of this split, so the
            # completeness check looks at exactly those files. Counting
            # directories against the full dataset size would make a split
            # used on its own regenerate its masks on every run.
            num_expected = len(self.sample_indices)
            num_found = sum(
                self._preprocessed_seg_path(idx).is_file()
                for idx in self.sample_indices
            )

            if num_found < num_expected:
                print(f"Preprocessed files incomplete ({num_found}/{num_expected}). Generating now...")
                self._preprocess_dataset()
            else:
                print(f"All preprocessed data verified ({num_expected} samples).")

    @staticmethod
    def _sample_name(idx: int) -> str:
        """Return the directory / file-name stem of sample number `idx`."""
        return f"BraTS20_Training_{idx:03d}"

    def _raw_seg_path(self, idx: int) -> Path:
        """Return the path of the original NIfTI segmentation of sample `idx`."""
        name = self._sample_name(idx)
        if idx == 355:  # Handle known naming inconsistency in dataset [Colab doesn't allow to rename it.]
            return self.samples_path / name / "W39_1998.09.19_Segm.nii"
        return self.samples_path / name / f"{name}_seg.nii"

    def _preprocessed_seg_path(self, idx: int) -> Path:
        """Return the path of the preprocessed .npy segmentation of sample `idx`."""
        name = self._sample_name(idx)
        return self.preprocessed_path / name / f"{name}_seg.npy"

    def _preprocess_dataset(self):
        """
        Converts the NIfTI segmentations of this split to .npy files for faster loading.
        Remaps labels (4 -> 3) and optionally generates non-tumor tissue [label 5 (label 4 after remapping)].
        """
        self.preprocessed_path.mkdir(exist_ok=True, parents=True)

        for idx in tqdm(self.sample_indices, desc="Preprocessing"):
            name = self._sample_name(idx)
            out_file = self._preprocessed_seg_path(idx)
            out_file.parent.mkdir(exist_ok=True, parents=True)

            seg = self._load_nifti_volume(self._raw_seg_path(idx))
            seg[seg == 4] = 3  # Remap 4 >> 3

            if self.assign_label5:
                mri = self._load_nifti_volume(
                    self.samples_path / name / f"{name}_{self.modalities_list[0]}.nii"
                )
                # Background of the segmentation that is still brain tissue.
                seg[(seg == 0) & (mri != 0)] = 4

            np.save(out_file, seg)

    def __len__(self):
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """
        Load, augment (train split only), crop and normalize one sample.

        Returns:
            Tuple (img, seg): the image tensor with one channel per modality
            and the segmentation tensor.
        """
        idx_val = self.sample_indices[idx]
        sample_name = self._sample_name(idx_val)

        raw_sample_path = self.samples_path / sample_name

        img = self._load_raw_mri(raw_sample_path, sample_name) # MRI is always loaded as .nii
        # Applying the same pre-processing logic here would mean to save uncompressed
        # 4-modalities 3D MRIs, fp32 precision, this blows memory.

        # Unified logic: Load .npy if we are in 'preprocessed' or 'force_generate' mode
        if self.use_npy:
            seg = torch.from_numpy(np.load(self._preprocessed_seg_path(idx_val)))
        else:
            # Fallback: Load original .nii segmentation
            # Note: this is slower and inefficient. But requires much less storage.
            seg_np = self._load_nifti_volume(self._raw_seg_path(idx_val))
            seg_np[seg_np == 4] = 3

            # assign_label5 in "raw" mode is computed on the fly. This repeats
            # the same work for the same sample at every epoch, but the
            # original .nii files cannot be overwritten on Colab.
            if self.assign_label5:
                mri_np = img[0].numpy()
                seg_np[(seg_np == 0) & (mri_np != 0)] = 4
            seg = torch.from_numpy(seg_np)

        foreground_mask = None
        if self.split == "train" and (self.geometric_transforms or self.intensity_transforms):
            img, seg, foreground_mask = apply_augmentation(
                img, seg, self.modalities_list, self.geometric_transforms, self.intensity_transforms
            )
            # Keep the first channel only: this gives a (H, W, D) mask whether
            # the returned mask has one channel or one per modality, matching
            # the img[0]-based fallback below.
            foreground_mask = foreground_mask[0]

        if self.output_shape is not None:
            img = _center_crop(img, self.output_shape)
            seg = _center_crop(seg, self.output_shape)
            if foreground_mask is not None:
                foreground_mask = _center_crop(foreground_mask, self.output_shape)

        if foreground_mask is None:
            foreground_mask = img[0] > 0

        img = self._normalize(img, foreground_mask)
        return img, seg

    def _load_raw_mri(self, sample_path: Path, sample_name: str) -> torch.Tensor:
        """Load every modality of a sample and stack them along dimension 0."""
        volumes = [
            self._load_nifti_volume(sample_path / f"{sample_name}_{mod}.nii")
            for mod in self.modalities_list
        ]
        return torch.from_numpy(np.stack(volumes, axis=0))

    def _normalize(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Normalize the intensities of `tensor` according to `self.norm_type`.

        For "z" the tensor is modified in place: each channel is standardized
        using the mean / std of the voxels selected by `mask`; voxels outside
        the mask are left untouched. Channels whose std is not positive are
        skipped.

        Raises:
            ValueError: If `self.norm_type` is not supported.
        """
        if self.norm_type == "z":
            if mask.any():  # Same mask for every channel: test it once.
                for c in range(tensor.shape[0]):
                    voxels = tensor[c][mask]
                    mean = voxels.mean()
                    std = voxels.std()
                    if std > 0:
                        tensor[c][mask] = (voxels - mean) / std
        elif self.norm_type == "minmax":
            subject = tio.Subject(image=tio.ScalarImage(tensor=tensor))
            transform = tio.RescaleIntensity(
                out_min_max=(0, 1),
                percentiles=(0.0, 99.5),
            )
            tensor = transform(subject).image.data
        else:
            raise ValueError(f"Unknown norm_type: {self.norm_type}")
        return tensor

    @staticmethod
    def _load_nifti_volume(path: Path) -> np.ndarray:
        """Read a NIfTI file and return its voxel data as a float32 array."""
        nii = nib.load(str(path))
        return np.asarray(nii.dataobj, dtype=np.float32)

## Split

In [None]:
def get_dataloaders(
    train_ratio: float = 0.8,
    train_batch_size: int = 1,
    test_batch_size: int = 1,
    geometric_transforms=None,
    intensity_transforms=None,
    output_shape: tuple = (128, 128, 128),
    norm_type: str = "z",
    seed: int = 42,
    root_dir: str = "./",
    preprocessed_root_dir: str | None = None,
    seg_mode: Literal["raw", "preprocessed", "force_generate"] = "raw",
    assign_label5: bool = False,
    num_workers: int = 0,
):
    """
    Build the train and test dataloaders on top of BraTSDataset.

    Both datasets share every setting except the split and the augmentation
    transforms, which are only handed to the train dataset. The train loader
    shuffles its samples, the test loader does not.

    Returns:
        train_dataloader, test_dataloader
    """
    # Settings common to both splits.
    shared_dataset_kwargs = dict(
        train_ratio=train_ratio,
        root_dir=root_dir,
        preprocessed_root_dir=preprocessed_root_dir,
        seg_mode=seg_mode,
        assign_label5=assign_label5,
        output_shape=output_shape,
        norm_type=norm_type,
        random_seed=seed,
    )

    train_dataset = BraTSDataset(
        split="train",
        geometric_transforms=geometric_transforms,
        intensity_transforms=intensity_transforms,
        **shared_dataset_kwargs,
    )
    # The test split never receives augmentation transforms.
    test_dataset = BraTSDataset(
        split="test",
        geometric_transforms=None,
        intensity_transforms=None,
        **shared_dataset_kwargs,
    )

    print(f"\nTrain dataset size: {len(train_dataset)}")
    print(f"Test dataset size: {len(test_dataset)}")

    shared_loader_kwargs = dict(drop_last=False, num_workers=num_workers)

    train_dataloader = DataLoader(
        train_dataset, batch_size=train_batch_size, shuffle=True, **shared_loader_kwargs
    )
    test_dataloader = DataLoader(
        test_dataset, batch_size=test_batch_size, shuffle=False, **shared_loader_kwargs
    )

    return train_dataloader, test_dataloader

In [None]:
# Spatial augmentations: a left-right flip followed by a random affine.
geometric_transform = tio.Compose(
    [
        tio.RandomFlip(axes=('LR',), flip_probability=0.5),
        tio.RandomAffine(
            scales=(0.6, 1.4),
            degrees=(15, 15, 15),
            translation=(5, 5, 5),
            isotropic=False,
            p=0.35,
        ),
    ]
)

# Intensity augmentations: random gamma only.
intensity_transform = tio.Compose(
    [tio.RandomGamma(log_gamma=(-0.3, 0.3), p=0.3)]
)

data_cfg, train_cfg = config['data'], config['training']

# Everything not passed here (seed, num_workers, ...) keeps the defaults of
# get_dataloaders.
train_dataloader, test_dataloader = get_dataloaders(
    train_ratio=data_cfg["train_ratio"],
    train_batch_size=train_cfg["train_batch_size"],
    test_batch_size=train_cfg["test_batch_size"],
    output_shape=tuple(data_cfg["output_shape"]),  # YAML list > tuple
    norm_type=data_cfg["norm_type"],
    root_dir=awsaf49_brats20_dataset_training_validation_path,
    preprocessed_root_dir=data_cfg["preprocessed_dir"],
    seg_mode=data_cfg["seg_mode"],
    assign_label5=data_cfg["assign_label5"],
    geometric_transforms=geometric_transform,
    intensity_transforms=intensity_transform,
)

## Model Architecture


In [None]:
class SEBlock3d(nn.Module):
    """
    Squeeze-and-Excitation block for 5D inputs (B, C, H, W, D).

    Every channel is averaged over its spatial dimensions, the resulting
    (B, C) descriptor goes through a two-layer bottleneck ending in a sigmoid,
    and the input is rescaled channel-wise by those gates.

    Args:
        channels (int): Number of input (and output) channels.
        reduction (int): Divisor applied to `channels` to size the bottleneck.
        min_bottleneck (int): Lower bound for the bottleneck width.
    """
    def __init__(self, channels, reduction=8, min_bottleneck=4):
        super().__init__()
        hidden = max(channels // reduction, min_bottleneck)

        # Global average over the spatial dims -> (B, C, 1, 1, 1)
        self.squeeze = nn.AdaptiveAvgPool3d(1)

        gate_layers = [
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        ]
        self.excitation = nn.Sequential(*gate_layers)

    def forward(self, x):
        batch, chans, _, _, _ = x.shape
        # Drop the three singleton dims before the linear layers ...
        descriptor = self.squeeze(x).reshape(batch, chans)
        gates = self.excitation(descriptor)
        # ... and put them back so the gates broadcast over x.
        return x * gates[:, :, None, None, None]

class nConvBlock3d(nn.Module):
    """
    Residual 3D block made of N stacked Conv3d -> GroupNorm -> LeakyReLU stages.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Size of the 3D convolution kernel. Default is 3.
        n_convs (int): Number of convolutional stages in the block.
        use_se (bool): Whether to apply a Squeeze-and-Excitation block after
            the convolutions (before the residual sum).
        mid_channels (int, optional): Channel count between the stages.
            Defaults to out_channels; a different value gives a gradual
            transition from in_channels to out_channels.

    Output:
        Tensor: Output tensor of shape (B, out_channels, H, W, D).

    Note:
        The residual sum needs the skip tensor to match the output shape, so
        pass `custom_skip_tensor` whenever in_channels != out_channels.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, n_convs=2, use_se=False, mid_channels=None):
        super().__init__()
        self.n_convs = n_convs

        if mid_channels is None:
            mid_channels = out_channels

        stages = []
        final_pos = n_convs - 1
        for pos in range(n_convs):
            # First stage reads in_channels, last stage writes out_channels,
            # everything in between works on mid_channels.
            src = mid_channels if pos else in_channels
            dst = mid_channels if pos != final_pos else out_channels
            stages += [
                nn.Conv3d(src, dst, kernel_size=kernel_size, padding='same'),
                nn.GroupNorm(num_groups=8, num_channels=dst),
                nn.LeakyReLU(0.01, inplace=True),
            ]

        self.conv_sequence = nn.Sequential(*stages)
        self.se = SEBlock3d(out_channels) if use_se else nn.Identity()

    def forward(self, x, custom_skip_tensor=None):
        # The residual branch is the block input unless the caller supplies one.
        residual = x if custom_skip_tensor is None else custom_skip_tensor
        features = self.se(self.conv_sequence(x))
        return features + residual

class Down3d(nn.Module):
    """
    Channel projection with optional spatial downsampling by a factor of 2.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        downsample_mode (str, optional):
            - 'conv': learnable downsampling (3x3x3 convolution, stride 2).
            - 'pool': 2x2x2 max pooling followed by a 1x1x1 convolution.
            - None (or any falsy value): 1x1x1 convolution only, no spatial
              downsampling; used by the initial encoder layer for the
              channel expansion.

    Every variant ends with GroupNorm (8 groups) and LeakyReLU.

    Raises:
        ValueError: If `downsample_mode` is set but not 'conv' or 'pool'.
    """
    def __init__(self, in_channels, out_channels, downsample_mode=None):
        super().__init__()

        if not downsample_mode:
            head = [nn.Conv3d(in_channels, out_channels, kernel_size=1)]
        elif downsample_mode == 'conv':
            head = [nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)]
        elif downsample_mode == 'pool':
            head = [
                nn.MaxPool3d(kernel_size=2, stride=2),
                nn.Conv3d(in_channels, out_channels, kernel_size=1),
            ]
        else:
            raise ValueError(f"Unknown downsample_mode: {downsample_mode}")

        # Shared tail: normalization + activation.
        self.downsample = nn.Sequential(
            *head,
            nn.GroupNorm(num_groups=8, num_channels=out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )

    def forward(self, x):
        return self.downsample(x)


class Up3d(nn.Module):
    """
    Transposed-convolution upsampling followed by concatenation with a skip.

    Args:
        in_channels (int): Channels of the tensor to upsample.
        out_channels (int): Channels after the transposed convolution.
        kernel_size (int): Kernel size of the transposed convolution.
        stride (int): Stride of the transposed convolution.

    forward(x, skip) returns a tuple:
        - the upsampled tensor concatenated with the skip along dim 1
          (upsampled channels first),
        - the upsampled tensor alone.
    When the spatial sizes differ, a centered window of the upsampled size is
    taken from the skip tensor.
    """
    def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
        super().__init__()
        self.upconv = nn.ConvTranspose3d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride
        )

    def forward(self, x, skip):
        upsampled = self.upconv(x)

        if upsampled.shape[2:] != skip.shape[2:]:
            _, _, skip_h, skip_w, skip_d = skip.shape
            _, _, up_h, up_w, up_d = upsampled.shape
            # One centered slice per spatial axis.
            window = tuple(
                slice((full - part) // 2, (full - part) // 2 + part)
                for full, part in ((skip_h, up_h), (skip_w, up_w), (skip_d, up_d))
            )
            skip = skip[(slice(None), slice(None), *window)]

        return torch.cat([upsampled, skip], dim=1), upsampled


class EncoderLayer3d(nn.Module):
    """
    One encoder stage: Down3d projection, then a residual nConvBlock3d.

    forward(x) returns a tuple (features, skip):
        - features: output of the convolution block, fed to the next stage;
        - skip: the same features passed through an SE block when
          `use_skip_se` is True, otherwise returned unchanged.
    """
    def __init__(self, in_channels, out_channels, downsample_mode='conv', n_convs=2, use_skip_se=False, mid_channels=None):
        super().__init__()

        self.downsample = Down3d(in_channels, out_channels, downsample_mode=downsample_mode)
        # No SE inside the block on the encoder path; SE only acts on the skip.
        self.n_convblock = nConvBlock3d(
            out_channels,
            out_channels,
            n_convs=n_convs,
            use_se=False,
            mid_channels=mid_channels,
        )
        self.skip_se = nn.Identity() if not use_skip_se else SEBlock3d(out_channels)

    def forward(self, x):
        features = self.n_convblock(self.downsample(x))
        # The downward path keeps the pre-SE features.
        return features, self.skip_se(features)

class DecoderLayer3d(nn.Module):
    """
    One decoder stage: Up3d (upsampling + skip concatenation), then a
    residual nConvBlock3d.

    The convolution block reads 2 * out_channels channels because the
    upsampled tensor is concatenated with the skip first; its residual branch
    is the upsampled tensor alone.
    """
    def __init__(self, in_channels, out_channels, n_convs=2, use_se=False, mid_channels=None):
        super().__init__()
        self.n_convs = n_convs
        # Spatial upsampling and channel reduction.
        self.up = Up3d(in_channels, out_channels, kernel_size=2, stride=2)
        self.n_convblock = nConvBlock3d(
            2 * out_channels,
            out_channels,
            n_convs=n_convs,
            use_se=use_se,
            mid_channels=mid_channels,
        )

    def forward(self, x, skip):
        merged, upsampled = self.up(x, skip)
        return self.n_convblock(merged, custom_skip_tensor=upsampled)


class cvNet2(nn.Module): # cv -> Colle Vincenzo (^_~)
    """
    3D U-Net + SE + encoder/decoder input-output residuals.

    Args:
        in_channels (int): Number of input channels (MRI modalities).
        out_channels (int): Number of output channels (one logit per class).
        features (sequence of int): Channel width of each resolution level,
            from the full-resolution level down to the bottleneck. At least
            two levels are required.
        n_convs (sequence of int): Number of convolutions per level. When
            shorter than `features`, it is padded with its last value.
        downsample_mode (str): Downsampling used between encoder levels
            (see Down3d).
        use_skip_se (bool): Apply an SE block on every skip connection.
        use_decoder_se (bool): Apply an SE block in every decoder block.

    Raises:
        ValueError: If `features` has fewer than two entries, or `n_convs` is
            empty or longer than `features`.
    """
    def __init__(self, in_channels=4, out_channels=4,
                 features=(32, 64, 128, 256, 320),
                 n_convs=(1, 2, 3, 3, 3),
                 downsample_mode='conv',
                 use_skip_se=True,
                 use_decoder_se=True):
        super().__init__()

        # Work on private lists: accepts any sequence and never touches the
        # caller's (or the default) objects.
        features = list(features)
        n_convs = list(n_convs)

        # Explicit exceptions instead of `assert`, which is stripped under -O.
        if len(features) < 2:
            raise ValueError("features must contain at least two levels")
        if not n_convs:
            raise ValueError("n_convs must contain at least one value")

        skip_se_config = [use_skip_se and f >= 0 for f in features]  # >= n: tinkering, idea was to apply SE only for deeper layers.

        if len(n_convs) < len(features):
            n_convs = n_convs + [n_convs[-1]] * (len(features) - len(n_convs))
        if len(features) != len(n_convs):
            raise ValueError(
                f"n_convs has {len(n_convs)} entries but features has {len(features)}"
            )

        encoder_n_convs = n_convs[1:] # exclude initial encoder layer
        decoder_n_convs = encoder_n_convs[::-1][:-1] # reversed, last entry dropped
        # Example of ^^^
        # n_convs = [1, 2, 3, 3, 3]
        # encoder_n_convs = [2, 3, 3, 3] -> There is one more because it handles the bottleneck as well.
        # decoder_n_convs = [3, 3, 3]
        # NOTE(review): this keeps the bottleneck's value and drops n_convs[1],
        # so it is not a strict mirror of the encoder (a mirror would give
        # [3, 3, 2]). Left as is: changing it would alter the architecture.

        self.initial_encoder_layer = EncoderLayer3d(in_channels, features[0],
                                        downsample_mode=None,
                                        use_skip_se=skip_se_config[0],
                                        n_convs=n_convs[0]) # No spatial downsampling

        self.encoder = nn.ModuleList([
            EncoderLayer3d(features[i], features[i+1],
                    downsample_mode=downsample_mode,
                    use_skip_se=skip_se_config[i+1], # i+1 since initial encoder layer is not defined here.
                    n_convs=n_convs[i+1])
            for i in range(len(features)-1)
        ])

        self.decoder = nn.ModuleList([
            DecoderLayer3d(features[i], features[i-1],
                    use_se=use_decoder_se,
                    n_convs=decoder_n_convs[idx])
            for idx, i in enumerate(reversed(range(2, len(features))))
        ])

        self.final_decoder_layer = DecoderLayer3d(features[1], features[0],
                                    use_se=use_decoder_se,
                                    n_convs=n_convs[0]) # Split from the rest for symmetry.

        self.output = nn.Conv3d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        x, skip = self.initial_encoder_layer(x)
        skip_connections = [skip]

        for enc_layer in self.encoder:
            x, skip = enc_layer(x)
            skip_connections.append(skip)

        # Drop the bottleneck's own skip and go from deepest to shallowest.
        skip_connections = skip_connections[:-1][::-1]

        # zip stops at the shorter list: the shallowest skip is left for the
        # final decoder layer.
        for dec_layer, skip in zip(self.decoder, skip_connections):
            x = dec_layer(x, skip)

        x = self.final_decoder_layer(x, skip_connections[-1])
        return self.output(x)

## Loss Functions & Metrics

In [None]:
def dice_score(
    pred: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-7,
):
    """
    Per-sample Dice score between two masks of identical shape.

    Every sample is flattened, so any number of trailing dimensions is
    accepted. `pred` may hold soft values (probabilities).

    Args:
        pred (Tensor): Prediction mask of shape (B, ...)
        target (Tensor): Target mask of shape (B, ...)
        eps (float): Added to numerator and denominator for numerical
            stability; two empty masks therefore score 1.

    Returns:
        Tensor: Dice score per sample, shape (B,)
    """
    assert pred.shape == target.shape

    batch = pred.shape[0]
    flat_pred = pred.reshape(batch, -1)
    flat_target = target.reshape(batch, -1)

    overlap = torch.sum(flat_pred * flat_target, dim=1)
    total = torch.sum(flat_pred, dim=1) + torch.sum(flat_target, dim=1)

    numerator = 2.0 * overlap + eps
    denominator = total + eps
    return numerator / denominator

def dice_loss(pred, target):
    """
    Soft Dice loss for binary segmentation: one minus the batch-mean Dice score.
    """
    per_sample_dice = dice_score(pred, target)
    return 1.0 - per_sample_dice.mean()

def multiclass_soft_dice_loss(logits, target, ignore_idxs=None):
    """
    Soft Dice loss averaged over the classes that are not ignored.

    Softmax is taken over dim 1 of `logits`; for every kept class the loss is
    one minus the batch-mean Dice score between that class' probability map
    and the one-hot target of the class. Returns 0 when every class is ignored.
    """
    skipped = [] if ignore_idxs is None else ignore_idxs
    probabilities = F.softmax(logits, dim=1)

    per_class_losses = [
        1.0 - dice_score(probabilities[:, cls], (target == cls).float()).mean()
        for cls in range(probabilities.shape[1])
        if cls not in skipped
    ]

    # Edge case: every class was ignored.
    if len(per_class_losses) == 0:
        return torch.tensor(0.0, device=logits.device)

    return sum(per_class_losses) / len(per_class_losses)

def multiclass_soft_dice_ce_loss(
    logits: torch.Tensor,
    target: torch.Tensor,
    ignore_idxs_dice: list[int] | None = None,
    ignore_idxs_ce: list[int] | None = None,
    alpha: float = 1.0,
    beta: float = 0.0,
    ce_class_weights: torch.Tensor | None = None
):
    """
    Computes a combination of multi-class soft Dice loss and Cross-Entropy loss.

    Args:
        logits (Tensor): shape (B, C, ...), e.g. (B, C, H, W, D) for volumes
        target (Tensor): class indices of shape (B, ...), matching the
            spatial dimensions of `logits`
        ignore_idxs_dice (list[int] | None): classes to ignore for Dice loss
        ignore_idxs_ce (list[int] | None): classes to ignore for CE loss
        alpha (float): weight for Dice loss
        beta (float): weight for CE loss; CE is only computed when beta > 0
        ce_class_weights (Tensor | None): optional class weights for CE loss.
            The tensor is never modified.

    Returns:
        Tuple[Tensor, Tensor, Tensor]: total_loss, dice_loss, ce_loss
    """
    if ignore_idxs_dice is None:
        ignore_idxs_dice = []
    if ignore_idxs_ce is None:
        ignore_idxs_ce = []

    dice = multiclass_soft_dice_loss(logits, target, ignore_idxs_dice)

    if beta > 0.0:
        n_classes = logits.shape[1]
        if ce_class_weights is None:
            # Default to unweighted if none provided
            weights = torch.ones(n_classes, device=logits.device)
        else:
            # Work on a copy: zeroing the ignored classes below must not leak
            # into the caller's tensor.
            weights = ce_class_weights.to(logits.device).clone()

        # Zero out weights for ignored CE indices (the ignore_index parameter
        # of PyTorch's cross-entropy does not accept a list)
        for idx in ignore_idxs_ce:
            if 0 <= idx < n_classes:
                weights[idx] = 0.0

        # Cross-entropy needs int64 class indices; label volumes stored as
        # float (or smaller integer types) are converted here. Targets with
        # the same rank as the logits are left untouched.
        ce_target = target
        if target.dim() == logits.dim() - 1 and target.dtype != torch.long:
            ce_target = target.long()

        ce_loss = F.cross_entropy(logits, ce_target, weight=weights)
    else:
        ce_loss = torch.tensor(0.0, device=logits.device)

    total = alpha * dice + beta * ce_loss
    return total, dice, ce_loss

In [None]:
class ComboLoss(nn.Module):
    """
    nn.Module front-end for multiclass_soft_dice_ce_loss.

    All loss hyperparameters (Dice/CE weights, ignored classes, CE class weights)
    are captured at construction time, so a training loop only ever calls
    ``criterion(logits, target)`` and never needs to know which loss it is driving.
    """
    def __init__(self, dice_weight=1.0, ce_weight=0.0, ignore_idxs_dice=None, ignore_idxs_ce=None, ce_class_weights=None):
        super().__init__()
        self.alpha, self.beta = dice_weight, ce_weight

        # Falsy values (None or an empty list) are normalised to a fresh empty list.
        self.ignore_idxs_dice = ignore_idxs_dice if ignore_idxs_dice else []
        self.ignore_idxs_ce = ignore_idxs_ce if ignore_idxs_ce else []

        # Class weights are kept as a float32 tensor (or None when not supplied).
        self.ce_class_weights = (
            None
            if ce_class_weights is None
            else torch.tensor(ce_class_weights, dtype=torch.float32)
        )

        # Human-readable label, used only for the configuration printout below.
        dice_only = self.alpha > 0 and self.beta == 0
        ce_only = self.alpha == 0 and self.beta > 0
        if dice_only:
            self.loss_type = "Dice Loss"
        elif ce_only:
            self.loss_type = "Cross Entropy Loss"
        else:
            self.loss_type = "Combo Loss (Dice + CE)"

        print(f"Loss configured: {self.loss_type} | Alpha: {self.alpha}, Beta: {self.beta}")

    def forward(self, logits, target):
        """Return whatever multiclass_soft_dice_ce_loss yields for this configuration."""
        weights = self.ce_class_weights
        if weights is not None:
            # The weights are created on CPU; follow the logits to their device.
            weights = weights.to(logits.device)

        return multiclass_soft_dice_ce_loss(
            logits,
            target,
            ignore_idxs_dice=self.ignore_idxs_dice,
            ignore_idxs_ce=self.ignore_idxs_ce,
            alpha=self.alpha,
            beta=self.beta,
            ce_class_weights=weights,
        )

In [None]:
def per_class_metrics(input, target, split='train'):
    """
    Per-class evaluation of a batch of logits against integer labels.

    For every class index along dim 1 of `input`, the samples whose `target`
    contains that class are selected; classes absent from the whole batch are
    skipped. For the remaining ones the result holds:
      * 'labels_{split}/class_{c}_dice' and '..._jaccard': lists produced by
        `dice_score` on the hard (argmax) prediction, with Jaccard derived from Dice
      * 'entropies_{split}/class_{c}': mean softmax entropy over voxels of that class
      * 'avg_prob_{split}/class_{c}': mean probability assigned to the true class
        over voxels of that class
    """
    eps = 1e-7
    batch_size, num_classes = input.shape[0], input.shape[1]  # class count is read off the logits
    results = {}

    probs = F.softmax(input, dim=1)
    predicted = input.argmax(dim=1)

    # Shannon entropy of the predictive distribution at every voxel.
    voxel_entropy = -(probs * torch.log(probs + eps)).sum(dim=1)
    # Probability the model assigns to the ground-truth class at every voxel.
    true_class_prob = probs.gather(1, target.unsqueeze(1)).squeeze(1)

    for cls in range(num_classes):
        # True for samples in which at least one voxel carries this class.
        has_class = (target == cls).view(batch_size, -1).sum(dim=1) > 0
        if not has_class.any():
            continue  # class missing from the entire batch: nothing to score

        target_mask = target[has_class] == cls
        pred_mask = predicted[has_class] == cls

        dice = dice_score(pred_mask, target_mask, eps)
        jaccard = dice / (2.0 - dice + eps)  # J = D / (2 - D)

        results[f'labels_{split}/class_{cls}_dice'] = dice.tolist()
        results[f'labels_{split}/class_{cls}_jaccard'] = jaccard.tolist()

        # Uncertainty / confidence restricted to voxels that truly belong to the class.
        results[f'entropies_{split}/class_{cls}'] = voxel_entropy[has_class][target_mask].mean().item()
        results[f'avg_prob_{split}/class_{cls}'] = true_class_prob[has_class][target_mask].mean().item()

    return results

def bra_ts_region_metrics(input, target, split='train'):
    """
    Dice and Jaccard for the BraTS composite regions WT, TC and ET.

    Regions are unions of label ids taken on the argmax prediction and on
    `target`: WT = {1, 2, 3}, TC = {1, 3}, ET = {3}. Only samples whose target
    contains the region are scored; a region absent from the whole batch
    produces no entry. Values are stored as lists under
    'brats_{split}/{region}_dice' and 'brats_{split}/{region}_jaccard'.
    """
    eps = 1e-7
    pred = input.argmax(dim=1)

    region_labels = {
        'wt': (1, 2, 3),
        'tc': (1, 3),
        'et': (3,),
    }

    def _label_union(volume, label_ids):
        # Boolean mask of voxels whose label is any of `label_ids`.
        mask = volume == label_ids[0]
        for label in label_ids[1:]:
            mask = mask | (volume == label)
        return mask

    metrics = {}
    for name, label_ids in region_labels.items():
        pred_region = _label_union(pred, label_ids)
        true_region = _label_union(target, label_ids)

        # Keep only samples where the region actually occurs in the ground truth.
        present = true_region.view(true_region.shape[0], -1).sum(dim=1) > 0
        if not present.any():
            continue

        dice = dice_score(pred_region[present], true_region[present], eps)
        jaccard = dice / (2.0 - dice + eps)  # J = D / (2 - D)

        metrics[f'brats_{split}/{name}_dice'] = dice.tolist()
        metrics[f'brats_{split}/{name}_jaccard'] = jaccard.tolist()

    return metrics

In [None]:
class MetricAccumulator:
    """
    Running sum / count per metric name, used to average metrics over many batches.

    A metric value may be a single number (counted once) or a list/tuple of
    numbers (each element counted once), so per-sample and per-batch metrics
    can be mixed in the same accumulator.
    """

    def __init__(self):
        self.sums = defaultdict(float)
        self.counts = defaultdict(int)

    def add(self, metrics: dict):
        """Fold a ``{name: number | sequence of numbers}`` dict into the running totals."""
        for key, value in metrics.items():
            if isinstance(value, (list, tuple)):
                if not value:
                    # An empty sequence carries no observations; registering the
                    # key with a zero count would make all_means() divide by zero.
                    continue
                self.sums[key] += sum(value)
                self.counts[key] += len(value)
            else:
                self.sums[key] += value
                self.counts[key] += 1

    def mean(self, key: str):
        """
        Return the average of everything added under `key`.

        Raises:
            KeyError: if nothing has been accumulated for `key`.
        """
        # Use .get so a lookup of an unknown key does not insert a phantom
        # entry into the defaultdicts.
        count = self.counts.get(key, 0)
        if count == 0:
            raise KeyError(f"No values accumulated for metric {key!r}")
        return self.sums[key] / count

    def all_means(self):
        """Return ``{name: mean}`` for every metric that has at least one observation."""
        return {
            key: total / self.counts[key]
            for key, total in self.sums.items()
            if self.counts.get(key, 0) > 0
        }

    def clear(self):
        """Forget all accumulated values."""
        self.sums.clear()
        self.counts.clear()

## Trainer

In [None]:
class Trainer:
    """
    Runs the train / validate / log / checkpoint cycle for a segmentation model.

    The criterion is called as ``criterion(outputs, labels)`` and must return a
    ``(total_loss, dice_loss, ce_loss)`` triple. The two metric callables are
    called as ``fn(outputs, labels, split=...)`` and must return dicts that
    ``MetricAccumulator.add`` can consume.
    """
    def __init__(self, model, optimizer, scheduler, criterion, config, run,
                 metrics_fn, brats_fn, device):
        """
        Store the training components and read the options used by the loops.

        Config keys read here: ``training.accumulation_steps`` (default 1) and
        ``experiment.wandb_logging`` (default False). `run` is the object whose
        ``log(dict, step=...)`` receives metrics when W&B logging is enabled;
        it may be None.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.config = config
        self.run = run
        self.device = device
        self.acc_steps = self.config['training'].get('accumulation_steps', 1)
        self.use_wandb = config['experiment'].get('wandb_logging', False)

        self.metrics_fn = metrics_fn
        self.brats_fn = brats_fn

        # Best composite Dice seen so far and the file currently holding that model.
        self.best_dice_composite = 0.0
        self.best_model_path = None

    def train_epoch(self, dataloader, epoch, global_step):
        """
        Train for one pass over `dataloader` with gradient accumulation.

        The optimizer and scheduler step every `acc_steps` batches and once more
        on the final batch, so a trailing partial window is not dropped. Metrics
        are averaged over each accumulation window and logged per optimizer step.

        Returns:
            The updated `global_step` (incremented once per optimizer step).
        """
        self.model.train()
        acc = MetricAccumulator()

        step_times = []
        epoch_start = time.time()

        pbar = tqdm(dataloader, desc=f"Epoch {epoch}/{self.config['training']['epochs']}")

        for i, (batch, labels) in enumerate(pbar):
            step_start = time.time()

            batch = batch.to(self.device)
            labels = labels.to(self.device, dtype=torch.long)

            outputs = self.model(batch)
            loss, dice_comp, ce_comp = self.criterion(outputs, labels)  # loss is a single scalar

            is_last_batch = (i + 1) == len(dataloader)
            is_acc_step = (i + 1) % self.acc_steps == 0

            # Scale so the accumulated gradient is the mean over a full window.
            # NOTE: a trailing partial window is still divided by acc_steps.
            (loss / self.acc_steps).backward()

            with torch.no_grad():
                pc_m = self.metrics_fn(outputs.detach(), labels, split='train')
                br_m = self.brats_fn(outputs.detach(), labels, split='train')

                acc.add({
                    "main/train_loss": [loss.item()],
                    "loss_components/train_dice_loss": [dice_comp.item()],
                    "loss_components/train_ce_loss": [ce_comp.item()]
                })

                # Accumulate metrics
                acc.add(pc_m)
                acc.add(br_m)

            # Step when the window is full, or when the epoch has no more batches.
            if is_acc_step or is_last_batch:
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()

                global_step += 1

                log_dict = acc.all_means()
                log_dict["loss_components/train_ce_dice_ratio"] = (
                    log_dict["loss_components/train_ce_loss"] /
                    (log_dict["loss_components/train_dice_loss"] + 1e-8)
                )
                log_dict["main/lr"] = self.optimizer.param_groups[0]["lr"]
                log_dict["main/epoch"] = epoch
                self.log(log_dict, step=global_step, split="train")

                acc.clear()

                # Without W&B the progress bar is the only live training feedback.
                if not self.use_wandb:
                    pbar.set_postfix({
                        "Loss": f"{log_dict['main/train_loss']:.3f}",
                        "WT": f"{log_dict.get('brats_train/wt_dice', 0.0):.3f}",
                        "TC": f"{log_dict.get('brats_train/tc_dice', 0.0):.3f}",
                        "ET": f"{log_dict.get('brats_train/et_dice', 0.0):.3f}",
                    })

            step_times.append(time.time() - step_start)

        self.log({
            "timing/epoch_train_sec": time.time() - epoch_start,
            # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
            "timing/avg_step_sec": sum(step_times) / max(len(step_times), 1),
        }, step=global_step, split="train")

        return global_step

    def validate(self, dataloader, global_step, split='test'):
        """
        Evaluate the model on `dataloader` without gradients and log the averages.

        Returns:
            The composite Dice: the mean of the WT, TC and ET Dice averages,
            where a region that never occurred contributes 0.0.
        """
        self.model.eval()
        test_start = time.time()

        acc = MetricAccumulator()
        with torch.no_grad():
            for tbatch, tlabels in tqdm(dataloader, leave=False, desc=f"Validating {split}"):
                tbatch = tbatch.to(self.device)
                tlabels = tlabels.to(self.device, dtype=torch.long)

                tout = self.model(tbatch)
                tloss, tdice, tce = self.criterion(tout, tlabels)

                # Accumulate losses
                acc.add({
                    f"main/{split}_loss": [tloss.item()],
                    f"loss_components/{split}_dice_loss": [tdice.item()],
                    f"loss_components/{split}_ce_loss": [tce.item()]
                })

                # Accumulate metrics
                pc_m = self.metrics_fn(tout, tlabels, split=split)
                br_m = self.brats_fn(tout, tlabels, split=split)
                acc.add(pc_m)
                acc.add(br_m)

        log_dict = acc.all_means()

        # Composite Dice
        comp_dice = (
            log_dict.get(f'brats_{split}/wt_dice', 0.0) +
            log_dict.get(f'brats_{split}/tc_dice', 0.0) +
            log_dict.get(f'brats_{split}/et_dice', 0.0)
        ) / 3

        log_dict[f"{split}/composite_dice"] = comp_dice
        log_dict[f"timing/{split}_sec"] = time.time() - test_start

        self.log(log_dict, step=global_step, split=split)
        return comp_dice

    def log(self, metrics_dict, step, split='train'):
        """
        Unified logging: W&B gets everything; the console only gets validation.

        When W&B logging is off (or no run exists), a summary is printed only
        for dicts that carry ``'{split}/composite_dice'`` - i.e. the output of
        `validate` for whichever split name it was called with.
        """
        if self.use_wandb and self.run is not None:
            self.run.log(metrics_dict, step=step)
            return

        # Keyed on the actual split name so validate(split='val') is reported too,
        # not just the default 'test' split.
        composite_key = f"{split}/composite_dice"
        if composite_key in metrics_dict:
            print(
                f"\n### Validation [Step {step}] ###\n"
                f"Composite Dice: {metrics_dict[composite_key]:.4f} | "
                f"WT Dice: {metrics_dict.get(f'brats_{split}/wt_dice', 0.0):.4f} | "
                f"TC Dice: {metrics_dict.get(f'brats_{split}/tc_dice', 0.0):.4f} | "
                f"ET Dice: {metrics_dict.get(f'brats_{split}/et_dice', 0.0):.4f}\n"
                f"{'#' * 45}"
            )

    def save_checkpoint(self, epoch, global_step, composite_dice, is_best=False):
        """
        Saves a 'last_checkpoint' and keeps only the single best model.

        The checkpoint holds epoch, global step, model/optimizer/scheduler state
        dicts and the config. When `is_best` is set, the previously saved best
        file (if any) is deleted and a new one named after `composite_dice` is
        written to the current working directory.
        """
        checkpoint = {
            'epoch': epoch,
            'global_step': global_step,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
        }

        last_filename = "./last_checkpoint.pth"
        torch.save(checkpoint, last_filename)

        if is_best:
            # Delete the previous best file if it exists; a failure here is
            # reported but must not abort training.
            if self.best_model_path and os.path.exists(self.best_model_path):
                try:
                    os.remove(self.best_model_path)
                except OSError as e:
                    print(f"Error deleting old best: {e}")

            new_best_name = f"best_model_dice_{composite_dice:.4f}.pth"
            torch.save(checkpoint, new_best_name)

            self.best_model_path = new_best_name
            print(f"New best model saved: {new_best_name}")
        else:
            print(f"Progress saved to {last_filename}")

## Training

When prompted, select option (2): Use an existing W&B account and paste your API key.

In [None]:
# --- Build model, loss, optimizer and LR schedule from the config -------------
model = cvNet2(**config['model']).to(config['hardware']['device'])
criterion = ComboLoss(dice_weight=config['loss']['dice_weight'], ce_weight=config['loss']['ce_weight'],
                      ignore_idxs_dice=config['loss']['ignore_idxs_dice'], ignore_idxs_ce=config['loss']['ignore_idxs_ce'],
                      ce_class_weights=config['loss']['ce_class_weights'])

optimizer = optim.AdamW(
    model.parameters(),
    lr=config['training']['lr'],
    weight_decay=config['training']['weight_decay']
)

# The Trainer steps the scheduler once per accumulation window AND once more for
# a trailing partial window, i.e. ceil(batches / accumulation_steps) times per
# epoch. T_max must use the same count, otherwise the cosine schedule runs past
# its end whenever the batch count is not a multiple of accumulation_steps.
# Same default as the Trainer uses for a missing 'accumulation_steps' key.
accumulation_steps = config['training'].get('accumulation_steps', 1)
steps_per_epoch = -(-len(train_dataloader) // accumulation_steps)  # ceil division
total_steps = max(1, steps_per_epoch * config['training']['epochs'])
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=5e-5)

# --- Experiment tracking ------------------------------------------------------
run_name = (
    f"{type(model).__name__}_{type(criterion).__name__}_"
    f"α{config['loss']['dice_weight']}_β{config['loss']['ce_weight']}_"
    f"{type(optimizer).__name__}_{type(scheduler).__name__}_"
    f"num_classes{config['model']['out_channels']}_"
    f"bs{config['training']['train_batch_size']}_"
    f"lr{config['training']['lr']}_{time.strftime('%d-%m-%Y_%H.%M')}"
)
if config['experiment'].get('wandb_logging', False):
    run = wandb.init(project=config['project'], name=run_name, config=config)
else:
    run = None
    print("WandB logging is disabled. Standard console output will be used.")

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    config=config,
    run=run,
    metrics_fn=per_class_metrics,
    brats_fn=bra_ts_region_metrics,
    device=config['hardware']['device']
)

# --- Training loop with periodic validation and early stopping ----------------
global_step = 0
test_freq = config['training'].get('test_freq', 2)
# Patience is counted in validation checks (one every `test_freq` epochs), not epochs.
patience = config['training'].get('early_stopping_patience', 5)
patience_counter = 0

# Optional pre-training baseline (disabled):
#current_dice = trainer.validate(test_dataloader, global_step, split='test')
for epoch in range(1, config['training']['epochs'] + 1):
    global_step = trainer.train_epoch(train_dataloader, epoch, global_step)

    if epoch % test_freq == 0:
        # Get the score from validation
        current_dice = trainer.validate(test_dataloader, global_step, split='test')
        is_best = current_dice > trainer.best_dice_composite

        if is_best:
            trainer.best_dice_composite = current_dice
            if trainer.run is not None:
                trainer.run.summary["best_val_dice"] = current_dice
            patience_counter = 0
        else:
            patience_counter += 1
            print(f"No improvement for {patience_counter} checks. (Best: {trainer.best_dice_composite:.4f})")

        # Always refresh the 'last' checkpoint; the 'best' file only on improvement.
        trainer.save_checkpoint(
            epoch=epoch,
            global_step=global_step,
            composite_dice=current_dice,
            is_best=is_best
        )

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch} epochs.")
            break

# Close the W&B run only if one was actually started.
if run is not None:
    run.finish()