# Brain Tumor Multiclass Segmentation (2D MRI) — Warm-Up Phase

This notebook presents the *warm-up phase* of a deep learning pipeline for multiclass brain tumor segmentation using 2D MRI scans. The task is formulated as a **semantic segmentation** problem, where each pixel is classified into one of several tumor-related classes: background, edema, enhancing tumor, or necrotic core.

We implement a U-Net architecture in **PyTorch**, with an **EfficientNet-B1** encoder pretrained on ImageNet. During this phase, the encoder is kept **frozen**, allowing the decoder and segmentation head to adapt first — a strategy that helps stabilize early training and prevent overfitting when using pretrained weights.

The dataset is a 2D preprocessed variant of the **BraTS** dataset, re-hosted on [Kaggle](https://www.kaggle.com/code/balakrishcodes/brain-mri-2d/input). It comprises axial MRI slices along with their corresponding multiclass segmentation masks.

This notebook is part of a personal research project focused on exploring **medical image segmentation** through the lens of **transfer learning**, **modular model design**, and **systematic experimentation**.

> **Disclaimer:** This project is intended for research and educational purposes only. It is not designed for clinical or diagnostic use.


## 0. Environment Setup (Google Colab)

This notebook is intended to be executed on **Google Colab**. To access the dataset and save outputs such as trained models or logs, we mount **Google Drive**.

> **Note:** If you are running this notebook outside of Colab (e.g., on a local machine or cloud server), ensure that the dataset is available in the expected directory structure and **skip this step**.


In [None]:
# Mount Google Drive to access datasets and save model outputs
from google.colab import drive  # imported here because this cell runs before Section 1.2

drive.mount('/content/drive')

## 1. Install and Import Dependencies

This section sets up the environment by installing and importing all necessary libraries for building a **2D brain tumor segmentation model** using PyTorch and supporting frameworks.

The following libraries are used throughout this project:

- **Albumentations**: A fast and flexible image augmentation library that supports mask-aware transformations, commonly used in computer vision pipelines.
- **Segmentation Models PyTorch (SMP)**: Provides a collection of high-level segmentation architectures (e.g., U-Net, DeepLabV3, FPN), often with pretrained encoders on ImageNet.
- **TorchMetrics**: Offers standardized and modular evaluation metrics (e.g., Intersection over Union, Dice Score) that integrate seamlessly with PyTorch.
- **MONAI**: A domain-specific framework for medical imaging, offering specialized components such as loss functions (e.g., Dice Loss), transforms, and dataset utilities.
- **Scikit-learn**: Useful for classical ML utilities such as train-validation splitting, class balancing, and metric computation.
- **Other utilities**: Libraries such as NumPy, OpenCV, tqdm, `h5py`, and `pickle` are used for I/O operations, preprocessing, and visualization.

> This modular setup ensures that the pipeline remains **scalable**, **reproducible**, and **extensible** across different stages of experimentation and deployment.


### 1.1 Install Required Packages

This step installs the essential Python packages needed to support the brain tumor segmentation pipeline.

- **Albumentations**: A fast and flexible library for image augmentation, including mask-aware transformations suitable for segmentation tasks.
- **Segmentation Models PyTorch (SMP)**: Provides high-level segmentation architectures such as U-Net, DeepLabV3, and FPN, with pretrained encoder support (e.g., ImageNet).
- **TorchMetrics**: Offers reliable and standardized evaluation metrics (e.g., IoU, Dice) compatible with PyTorch workflows.
- **MONAI**: A domain-specific framework for medical imaging, offering specialized components like loss functions, transforms, and dataset utilities.

> **Note:** If you're running this notebook outside of Google Colab, make sure to install these packages manually in your local environment.


In [None]:
!pip install albumentations segmentation-models-pytorch torchmetrics monai

### 1.2 Import Libraries

This section imports all required Python libraries for model development, training, evaluation, and data processing throughout the notebook.

- **PyTorch**: Core deep learning framework used to define the model architecture, loss functions, optimizers, and training logic.
- **Albumentations**: High-performance library for data augmentation — supports image-mask pair transformations crucial for segmentation.
- **TorchMetrics**: Provides standardized evaluation metrics such as IoU (Jaccard Index) for multiclass segmentation tasks.
- **MONAI**: Specialized medical imaging toolkit offering domain-specific loss functions, transforms, and utilities.
- **Segmentation Models PyTorch (SMP)**: Pretrained segmentation architectures (e.g., U-Net, FPN) with ImageNet-initialized encoders.
- **scikit-learn**: Tools for data splitting, computing class weights, and basic evaluation utilities.
- **Automatic Mixed Precision (AMP)**: PyTorch utilities (`autocast`, `GradScaler`) to accelerate training and reduce memory usage.
- **General-purpose utilities**: Libraries such as NumPy, OpenCV, tqdm, h5py, and pickle for data manipulation, visualization, and file I/O.

> **Note:** This environment is designed for use in **Google Colab**. Functions like Google Drive mounting and file uploads assume a Colab context. If running elsewhere, please adjust the setup accordingly.


In [None]:
# -------------------------
# Core PyTorch Framework
# -------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

# -------------------------
# Evaluation Metrics & Class Handling
# -------------------------
from torchmetrics.classification import MulticlassJaccardIndex  # IoU metric for multiclass segmentation
from sklearn.utils.class_weight import compute_class_weight      # Compute class weights for imbalance

# -------------------------
# Data Augmentation
# -------------------------
import albumentations as A                         # General image augmentation
from albumentations.pytorch import ToTensorV2      # Convert images to PyTorch tensors

# -------------------------
# Medical Imaging Tools
# -------------------------
from monai.losses import DiceLoss                  # Dice loss for segmentation tasks
import segmentation_models_pytorch as smp          # Predefined segmentation architectures (e.g., U-Net)

# -------------------------
# Training Optimization (AMP)
# -------------------------
from torch.cuda.amp import autocast, GradScaler    # Automatic mixed precision for faster training

# -------------------------
# Data Handling & I/O
# -------------------------
from sklearn.model_selection import train_test_split
import numpy as np
import cv2
import h5py
import pickle
import glob
import os
import copy
import random
from tqdm import tqdm                               # Progress bar for loops

# -------------------------
# Google Colab Utilities
# -------------------------
from google.colab import drive                      # Mount Google Drive for I/O
from google.colab import files                      # File upload/download support


## 2. Load and Prepare Dataset

In this section, we prepare the brain tumor segmentation dataset for model training and validation.

The dataset is stored in `.h5` (HDF5) format — a compact and structured format suitable for storing multi-dimensional medical imaging data.  
We use standard Python utilities (`glob`, `os`) to recursively search for all `.h5` files within the specified root directory.

After collecting all available files, we partition the dataset into training and validation subsets using `train_test_split` from `scikit-learn`.

> **Note:** No image or label preprocessing is applied at this stage — only file discovery and path-based dataset splitting are performed.


### 2.1 Define Dataset Path

We define the root directory that contains the dataset files.  
Each sample is stored as an `.h5` file — a common format for storing multi-dimensional medical imaging data, such as MRI slices paired with segmentation masks.

> **Note:** Update this path according to your environment (e.g., Google Drive, local machine, or cloud storage).


In [None]:
dataset_path = "/content/drive/MyDrive/your/dataset/path"

### 2.2 Load `.h5` Files from Dataset Directory

We use the `glob` module to recursively search for all `.h5` files within the dataset directory and its subdirectories.

> This assumes each `.h5` file stores MRI slice data under the key `x` and the corresponding segmentation masks under the key `y`, which is how the files are read in Section 3.1.


In [None]:
all_files = glob.glob(os.path.join(dataset_path, "**", "*.h5"), recursive=True)

### 2.3 Split Dataset into Training and Validation Sets

We split the list of `.h5` files into training and validation sets using `train_test_split` from **scikit-learn**.

- An 80/20 split is used: 80% for training, 20% for validation.
- A fixed `random_state` ensures reproducibility.

> Splitting is performed at the file level, so all slices stored in one `.h5` file end up in the same subset.


In [None]:
train_path, val_path = train_test_split(all_files, test_size=0.2, random_state=42)
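
To make the file-level split concrete, here is a minimal pure-Python sketch of a seeded 80/20 split. It is not scikit-learn's implementation (the shuffling differs, so the exact assignment will not match `train_test_split`), and the helper `split_files` and the file names are hypothetical. It only illustrates the two properties relied on here: the subsets are disjoint, and the same seed reproduces the same partition.

```python
import random


def split_files(paths, val_fraction=0.2, seed=42):
    """Shuffle a copy of `paths` with a fixed seed and cut it into (train, val)."""
    shuffled = list(paths)
    random.Random(seed).shuffle(shuffled)
    n_val = int(round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


# Hypothetical file names, used only for illustration
example_files = [f"volume_{i:03d}.h5" for i in range(10)]
train_files, val_files = split_files(example_files)

print(len(train_files), len(val_files))    # 8 2
print(set(train_files) & set(val_files))   # set()
```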

## 3. Custom Utilities and Dataset Preparation

This section defines the core components of the 2D brain tumor segmentation pipeline.  
The implementation is modular to support reusability and easier maintenance:

- **Indexing Utilities**: Functions for organizing and indexing dataset content efficiently.
- **Custom Dataset Class**: Loads `.h5` files containing MRI slices and segmentation masks using a hybrid patch extraction strategy.
- **Data Augmentation**: Albumentations-based pipelines applied to both images and masks.
- **Preprocessing Helpers**: Includes mask post-processing and class weight computation to address class imbalance.
- **Loss Functions**: Custom loss implementations (e.g., Dice + CrossEntropy) tailored for multiclass medical segmentation.

> **Note:** All components are modular and extensible for use in other medical imaging applications.


### 3.1 Build Index from `.h5` Files

We build an index from a list of `.h5` files, where each entry corresponds to a valid (non-empty) image–mask pair.

This step improves data loading efficiency by:

- Reducing I/O overhead during training
- Skipping background-only slices
- Enabling optional caching for reuse

> **Note:** The core logic is described in the function’s docstring.  
This step is recommended for large `.h5`-based datasets to optimize runtime performance.


In [None]:
def build_index(file_path, save_path=None, filter_empty=True, force_rebuild=False, verbose=True):
  """
  Builds an index of valid (non-empty) slices from a list of `.h5` files.

  Parameters:
    file_path (List[str]): List of `.h5` file paths to process.
    save_path (str, optional): If provided, saves the resulting index as a pickle file.
    filter_empty (bool): If True, slices with only background are skipped.
    force_rebuild (bool): If True, forces rebuilding the index even if a saved one exists.
    verbose (bool): If True, prints progress, skipped files, and error logs.

  Returns:
    List[Tuple[str, int]]: A list of (file_path, slice_index) tuples for valid slices.
  """

  # Load from existing index if available and rebuild not forced
  if save_path and os.path.exists(save_path) and not force_rebuild:
    if verbose:
      print(f"[INFO]: Index file found... Loading from {save_path}")
    with open(save_path, 'rb') as f:
      return pickle.load(f)

  index_list = []
  skipped_empty = 0
  error_files = 0

  if verbose:
    print(f"[INFO]: Building index from {len(file_path)} files")

  for path in tqdm(file_path, desc="Indexing .h5 files"):
    try:
      with h5py.File(path, 'r') as f:
        image = f['x'] # shape: (slices, H, W)
        mask = f['y'] # shape: (slices, H, W)
        for i in range(image.shape[0]):
          mask_slice = mask[i]
          if filter_empty and mask_slice.sum() == 0:
            skipped_empty +=1
            continue
          index_list.append((path, i))

    except Exception as e:
      error_files += 1
      if verbose:
        print(f"[Warning]: Error file skipped {path} -> {e}")

  if verbose:
    print(f"[INFO]: Total valid slices: {len(index_list)}")
    if filter_empty:
      print(f"[INFO]: Total empty slices skipped: {skipped_empty}")
    if error_files:
      print(f"[INFO]: Total files skipped due to errors: {error_files}")

  # Save index to pickle file
  if save_path:
    try:
      # dirname is empty when save_path has no directory part; fall back to the current directory
      os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
      with open(save_path, 'wb') as f:
        pickle.dump(index_list, f)
      if verbose:
        print(f"[INFO]: Index saved to {save_path} ({len(index_list)} slices)")
    except Exception as e:
      print(f"[Error]: Failed to save index -> {e}")

  return index_list

**Example usage**:
We call `build_index()` on both the training and validation paths, and save the index as `.pkl` files.


In [None]:
save_idx_train_path = "/content/drive/MyDrive/train_index_dataset.pkl"
save_idx_val_path = "/content/drive/MyDrive/val_index_dataset.pkl"

train_index = build_index(train_path, save_idx_train_path, force_rebuild=False)
val_index = build_index(val_path, save_idx_val_path, force_rebuild=False)
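
The filtering rule inside `build_index` can be illustrated without any HDF5 files. The sketch below applies the same idea (keep `(file, slice)` pairs whose mask has at least one non-zero pixel) to small in-memory masks. The helper `index_non_empty` and the toy data are hypothetical stand-ins, not part of the notebook's pipeline.

```python
def index_non_empty(volumes):
    """Return (name, slice_index) pairs for slices whose mask is not all background.

    `volumes` maps a file name to a list of 2D masks (lists of rows).
    """
    index = []
    for name, masks in volumes.items():
        for i, mask in enumerate(masks):
            if any(value != 0 for row in mask for value in row):
                index.append((name, i))
    return index


# Toy data: two "files" with 2x2 masks; label values follow the 0/50/100/150 convention
toy_volumes = {
    "a.h5": [[[0, 0], [0, 0]], [[0, 50], [0, 0]]],
    "b.h5": [[[100, 0], [0, 150]], [[0, 0], [0, 0]]],
}

toy_index = index_non_empty(toy_volumes)
print(toy_index)  # [('a.h5', 1), ('b.h5', 0)]
```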

### 3.2 Dataset Class: BrainTumorHybrid

This custom dataset class is designed to handle `.h5`-based brain tumor segmentation data with hybrid sampling, caching, remapping, normalization, and augmentation. It supports two sampling modes: full slice and random cropped patch, useful for training segmentation models efficiently.

- Supports per-epoch subsampling via `max_sample_per_epoch`
- Uses patch-based hybrid sampling with control via `hybrid_prob`
- Normalizes grayscale image channels and duplicates to 3 channels
- Remaps label values from [0, 50, 100, 150] → [0, 1, 2, 3]
- Caches image/mask slices for fast loading

> **Note:** Caching is automatically enabled when a valid `cache_path` is provided.
If the cache file exists, it will be loaded to reduce repeated disk I/O.
Otherwise, the dataset will preload all slices into memory and save the cache for future use.
This mechanism can greatly improve performance during training by minimizing disk latency, especially for large datasets.




In [None]:
class BrainTumorHybrid(Dataset):
  """
  Custom PyTorch Dataset for 2D brain tumor slices with hybrid sampling and caching.

  This dataset supports:
  - On-disk caching of preloaded data.
  - Hybrid sampling between full-slice and random patch extraction.
  - Optional Albumentations transforms.
  - Label remapping for tumor segmentation tasks.

  Args:
    index (List[Tuple[str, int]]): List of (file_path, slice_index) pairs.
    cache_path (str, optional): Path to save/load the preloaded cache.
    max_sample_per_epoch (int, optional): Maximum number of samples per epoch.
    transform (callable, optional): Albumentations transform pipeline.
    patch_size (int): Size of square patches to sample.
    hybrid_prob (float): Probability of using the full slice instead of patching.
  """

  # Save the current cache dictionary to disk
  def save_cache(self, cache_path):
    with open(cache_path, 'wb') as f:
      pickle.dump(self.cache, f)
    print(f"[INFO]: Saved cache to {cache_path}")

  # Load cache from disk if exists
  def load_cache(self, cache_path):
    if os.path.exists(cache_path):
      print(f"[INFO]: Loaded cache from {cache_path}")
      with open(cache_path, 'rb') as f:
        self.cache = pickle.load(f)
      print(f"[INFO]: Total cache loaded: {len(self.cache)} slices")
      return True
    return False

  # Remap original mask labels to consecutive class indices
  def remap_labels(self, mask):
    remap = {0:0, 50:1, 100:2, 150:3}
    return np.vectorize(remap.get)(mask)

  # Load & preprocess function
  def load_and_preprocess(self, path, i):
    with h5py.File(path, 'r') as f:
      image = f['x'][i]
      mask = f['y'][i]
    return image, mask

  # Preload cache
  def preload_cache(self):
    print("[INFO]: Preloading all data into cache...")
    for path, i in tqdm(self.full_index, desc="Caching Data"):
      key = (path, i)
      if key not in self.cache:
        image, mask = self.load_and_preprocess(path, i)
        self.cache[key] = (image, mask)
    print(f"[INFO]: Total cache: {len(self.cache)} slices")

  # Update index list for each epoch (with optional sampling)
  def update_epoch_index(self):
    if self.max_sample_per_epoch is None:
      self.index = self.full_index
    else:
      self.index = random.sample(self.full_index, min(self.max_sample_per_epoch, len(self.full_index)))

  # Perform random patch extraction (with minimum tumor content)
  def random_crop(self, image, mask):
    H, W = image.shape[:2]
    ps = self.patch_size
    if H < ps or W < ps:
      return image, mask

    for _ in range(10):
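      # +1 because the upper bound is exclusive; without it, W == ps or H == ps raises ValueError
      x = np.random.randint(0, W - ps + 1)
      y = np.random.randint(0, H - ps + 1)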
      img_patch = image[y:y+ps, x:x+ps]
      mask_patch = mask[y:y+ps, x:x+ps]

      if np.sum(mask_patch > 0) > 10:
        img_patch = cv2.resize(img_patch, (240, 240), interpolation=cv2.INTER_LINEAR)
        mask_patch = cv2.resize(mask_patch, (240, 240), interpolation=cv2.INTER_NEAREST)
        return img_patch, mask_patch

    img_patch = image[0:ps, 0:ps]
    mask_patch = mask[0:ps, 0:ps]
    img_patch = cv2.resize(img_patch, (240, 240), interpolation=cv2.INTER_LINEAR)
    mask_patch = cv2.resize(mask_patch, (240, 240), interpolation=cv2.INTER_NEAREST)
    return img_patch, mask_patch

  # Initialize dataset
  def __init__(self, index, cache_path=None, max_sample_per_epoch=None, transform=None, patch_size=160, hybrid_prob=0.3):
    self.full_index = index
    self.cache_path = cache_path
    self.cache = dict()
    self.max_sample_per_epoch = max_sample_per_epoch
    self.transform = transform
    self.patch_size = patch_size # Patch size used in random cropping
    self.hybrid_prob = hybrid_prob # Probability of using full slice (vs patch)

    if cache_path:
      if not self.load_cache(cache_path):
        self.preload_cache()
        self.save_cache(cache_path)

    self.update_epoch_index()

  # Return dataset length (after sampling)
  def __len__(self):
    return len(self.index)

  # Retrieve a single sample from the dataset
  def __getitem__(self, idx):
    path, i = self.index[idx]
    key = (path, i)

    # Retrieve from cache if available, otherwise load from file
    if key in self.cache:
      image, mask = self.cache[key]
    else:
      image, mask = self.load_and_preprocess(path, i)

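    # Remap mask to class indices; uint8 is enough for four classes and is a dtype
    # that cv2.resize and Albumentations handle (the tensor is cast to long after the transform)
    mask = self.remap_labels(mask).astype(np.uint8)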

    # Resize early to ensure consistent shape before augmentation
    image = cv2.resize(image, (240, 240), interpolation=cv2.INTER_LINEAR)
    mask = cv2.resize(mask, (240, 240), interpolation=cv2.INTER_NEAREST)

    # Normalize image using z-score normalization
    image = image.astype(np.float32)
    image = (image - image.mean()) / (image.std() + 1e-5)

    # Expand grayscale image to 3-channel to match pretrained model expectations
    image = np.expand_dims(image, axis=-1)
    image = np.repeat(image, 3, axis=-1)

    # Hybrid sampling: randomly choose between full slice or patch
    if random.random() < self.hybrid_prob:
      img_patch, mask_patch = image, mask
    else:
      img_patch, mask_patch = self.random_crop(image, mask)

    # Apply augmentation if provided
    if self.transform:
      augmentation = self.transform(image=img_patch, mask=mask_patch)
      img_patch = augmentation['image']
      mask_patch = augmentation['mask']

      # Drop a leading singleton channel dimension if the transform added one
      mask_patch = mask_patch.squeeze(0) if mask_patch.ndim == 3 else mask_patch
      mask_patch = mask_patch.long()

    return img_patch, mask_patch
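
Two of the per-sample steps above, label remapping and z-score normalization, can be checked in isolation. The sketch below assumes masks are stored as `uint8` with exactly the values 0, 50, 100, and 150. It uses a NumPy lookup table as an alternative to `np.vectorize`, which is a swapped-in technique rather than the dataset class's own code. The array contents are made up for illustration.

```python
import numpy as np

# Lookup table: index = original label value, entry = class index.
# Assumes uint8 masks; values other than 0/50/100/150 map to background (0) here.
LUT = np.zeros(256, dtype=np.uint8)
LUT[[0, 50, 100, 150]] = [0, 1, 2, 3]

toy_mask = np.array([[0, 50], [100, 150]], dtype=np.uint8)
remapped = LUT[toy_mask]
print(remapped.tolist())  # [[0, 1], [2, 3]]

# Z-score normalization, as in __getitem__ (the epsilon guards against constant slices)
toy_image = np.array([[0.0, 2.0], [4.0, 6.0]], dtype=np.float32)
normalized = (toy_image - toy_image.mean()) / (toy_image.std() + 1e-5)
print(abs(float(normalized.mean())) < 1e-6)  # True
```

A lookup table indexes the whole array in one vectorized step, which tends to be much faster than calling a Python function per pixel.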

### 3.3 Data Augmentation Pipelines

This section defines the data augmentation strategies applied during training and validation.  
Augmentation is crucial in medical image segmentation to increase model robustness by simulating real-world variability.

We use **Albumentations** to construct flexible and composable pipelines:

- **`train_T`**: Includes geometric augmentations (flips, grid distortion) and intensity-based augmentations (brightness/contrast, Gaussian noise) to promote generalization and reduce overfitting.
- **`val_T`**: Applies only resizing and tensor conversion, keeping validation deterministic. Z-score normalization is performed inside the dataset class for both splits.

> **Note:** These augmentations are applied on-the-fly within the dataset class, ensuring efficient and randomized augmentation for each training epoch.


In [None]:
# Transform for training data (with augmentation)
# Note: ToTensorV2 is used via its direct import from albumentations.pytorch;
# `A.ToTensorV2` is not exposed at the top level in all Albumentations versions.
train_T = A.Compose([
    A.Resize(240, 240),                       # Resize input to 240x240
    A.GridDistortion(p=0.2),                  # Apply grid distortion
    A.RandomBrightnessContrast(p=0.3),        # Adjust brightness and contrast
    A.GaussNoise(p=0.2),                      # Add Gaussian noise
    A.HorizontalFlip(p=0.3),                  # Random horizontal flip
    A.VerticalFlip(p=0.3),                    # Random vertical flip
    ToTensorV2()                              # Convert to PyTorch tensor
])

# Transform for validation data (no augmentation)
val_T = A.Compose([
    A.Resize(240, 240),
    ToTensorV2()
])

### 3.4 Data Preparation and Loading

We instantiate the custom `BrainTumorHybrid` dataset for both training and validation phases, configuring caching, transformation, and sampling strategies.

- **Training Set**  
  Uses *hybrid sampling* (a mix of full slices and random patches) to increase sample diversity and focus on lesion regions.  
  Augmentations are applied on-the-fly using the `train_T` pipeline.

- **Validation Set**  
  Uses *only full slices* by setting `hybrid_prob=1.0`, ensuring that patch sampling is disabled for fair evaluation.  
  Augmentations are minimal (`val_T`) to maintain consistency.

We also define **PyTorch DataLoaders** with appropriate batch sizes and multiple worker processes (`num_workers`) for efficient data streaming during training.

> **Note:** Dataset caching is automatically handled — if a cache file exists, it will be loaded; otherwise, it will be generated from scratch and saved for future runs.


In [None]:
# Cache path
cache_train_dataset_path = "/content/drive/MyDrive/cache_train_dataset.pkl"
cache_val_dataset_path = "/content/drive/MyDrive/cache_val_dataset.pkl"

# Instantiate datasets
train_dataset = BrainTumorHybrid(
    train_index,
    transform=train_T,
    cache_path=cache_train_dataset_path
)

val_dataset = BrainTumorHybrid(
    val_index,
    transform=val_T,
    hybrid_prob=1.0,
    cache_path=cache_val_dataset_path
)

# DataLoader
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

### 3.5 Class Weights Computation

To address the issue of **class imbalance** in the segmentation masks, we compute class weights based on the **pixel distribution** in the training dataset.  
These weights are especially useful when using loss functions such as `CrossEntropyLoss` that support per-class weighting via the `weight` argument.

- **`collect_mask_from_index`**  
  Extracts flattened segmentation masks from the full-slice training dataset.  
  This step bypasses any patch sampling or data augmentations to reflect the true class distribution.

- **`calculate_class_weight`**  
  Computes class weights using `sklearn.utils.class_weight.compute_class_weight` and saves the result as a PyTorch tensor in `.pth` format for reuse.

> **Note:** Accurate class weights can improve training stability and performance on underrepresented classes (e.g., necrotic core).


In [None]:
# Collect masks from index
def collect_mask_from_index(index):
  """
  Collects masks directly from the HDF5 files using the provided index.

  This bypasses augmentations and patch sampling done in the Dataset class.
  The mask labels are remapped to class indices (0-3) and flattened for use
  in class weight computation.

  Args:
    index (List[Tuple[str, int]]): List of (file_path, slice_index) pairs.

  Returns:
    List[int]: Flattened list of remapped mask values.
  """

  all_labels = []

  # Loop through each indexed file and slice
  for path, i in tqdm(index, desc="Collecting mask"):
    with h5py.File(path, 'r') as f:
      mask = f['y'][i]

      # Remap original pixel values to class indices
      remap = {0: 0, 50: 1, 100: 2, 150: 3}
      mask = np.vectorize(remap.get)(mask)

      # Flatten and collect all label values
      all_labels.extend(mask.flatten())
  return all_labels

# Compute class weight
def calculate_class_weight(index, num_classes=4, save_path=None, load_if_available=True):
  """
  Calculates class weights from the ground-truth masks for use in loss balancing.

  This function computes balanced class weights based on the frequency of each class
  label found in the dataset. It supports saving and loading from a cache file.

  Args:
    index (List[Tuple[str, int]]): List of (file_path, slice_index) to extract masks from.
    num_classes (int): Number of classes to consider for weight computation.
    save_path (str, optional): Path to save the computed weights.
    load_if_available (bool): If True, loads precomputed weights if available.

  Returns:
    torch.Tensor: Class weights as a float tensor of shape (num_classes,).
  """

  # Load from cache if available
  if load_if_available and save_path and os.path.exists(save_path):
    print(f"[INFO]: Class weights found... Loading from {save_path}")
    return torch.load(save_path)

  # Collect flattened ground-truth labels from the dataset
  labels = collect_mask_from_index(index)

  # Compute class weights using scikit-learn
  weight = compute_class_weight('balanced', classes=np.arange(num_classes), y=labels)
  weight_tensor = torch.tensor(weight, dtype=torch.float32)

  # Save to cache if needed
  if save_path:
    torch.save(weight_tensor, save_path)
    print(f"[INFO]: Class weight file saved to {save_path}")

  return weight_tensor


In [None]:
save_path_weight_calculate = "/content/drive/MyDrive/weight_calculate.pth"

# call function calculate_class_weight
class_weight = calculate_class_weight(
    train_index,
    num_classes=4,
    save_path=save_path_weight_calculate,
    load_if_available=True)
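
For reference, scikit-learn's `'balanced'` heuristic assigns each class the weight `n_samples / (n_classes * count_c)`. The sketch below reproduces that formula from per-class pixel counts in plain Python. The counts are invented for illustration and the helper `balanced_weights` is hypothetical. Working from counts also suggests a lighter-weight variant of the collection step, since only one running total per class is needed rather than a list of every pixel label.

```python
def balanced_weights(counts):
    """Balanced class weights: n_samples / (n_classes * count) for each class."""
    n_samples = sum(counts)
    n_classes = len(counts)
    return [n_samples / (n_classes * c) for c in counts]


# Invented pixel counts for four classes (class 0 is background)
toy_counts = [900, 60, 30, 10]
toy_weights = balanced_weights(toy_counts)
print([round(w, 3) for w in toy_weights])  # [0.278, 4.167, 8.333, 25.0]
```

The rarest class receives the largest weight, which is the behavior the weighted `CrossEntropyLoss` relies on.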

### 3.6 Combined Loss Function (CrossEntropy + Dice)

To effectively train the segmentation model on imbalanced medical data, we define a custom loss function that combines **CrossEntropyLoss** (pixel-wise classification) and **DiceLoss** (region-based overlap measure).  
This hybrid approach balances *local accuracy* (via CrossEntropy) and *global structure preservation* (via Dice), which is critical for robust tumor segmentation.

> **Note:** DiceLoss helps the model better capture smaller or less frequent tumor subregions, which are often underrepresented in the dataset.

The `ComboLoss` class internally handles reshaping and formatting, making it compatible with both **MONAI** and native **PyTorch** segmentation pipelines.


In [None]:
# Combined Loss: CrossEntropyLoss + DiceLoss
class ComboLoss(nn.Module):
  """
  Combines CrossEntropyLoss and DiceLoss for multi-class segmentation.

  Args:
    ce_weight (Tensor or None): Optional tensor of class weights for CrossEntropyLoss.
    dice_weight (float): Scaling factor for DiceLoss. Default is 1.0.
    ce_scale (float): Scaling factor for CrossEntropyLoss. Default is 1.0.
  """

  def __init__(self, ce_weight=None, dice_weight=1.0, ce_scale=1.0): # ce=CrossEntropy
    super().__init__()
    self.ce = nn.CrossEntropyLoss(weight=ce_weight) # Cross-entropy with optional class weights
    self.dice = DiceLoss(to_onehot_y=True, softmax=True) # MONAI Dice loss on softmax probabilities
    self.dice_weight = dice_weight
    self.ce_scale = ce_scale

  def forward(self, preds, targets):
    """
    Computes the combined loss.

    Args:
      preds (Tensor): Predicted logits of shape (B, C, H, W).
      targets (Tensor): Ground truth mask of shape (B, H, W) or (B, 1, H, W).

    Returns:
      Tensor: Combined loss value.
    """

    # Prepare targets for CrossEntropyLoss
    targets_ce = targets.squeeze(1) if targets.ndim == 4 else targets

    # Prepare targets for DiceLoss
    targets_dice = targets.unsqueeze(1) if targets.ndim == 3 else targets

    # Compute individual loss components
    loss_ce = self.ce(preds, targets_ce)
    loss_dice = self.dice(preds, targets_dice)

    # Return combined weighted loss
    return self.ce_scale * loss_ce + self.dice_weight * loss_dice
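
As a worked example of the Dice term, the sketch below computes a soft Dice loss for a single class on a handful of pixels in plain Python. It uses the common form `1 - 2*sum(p*g) / (sum(p) + sum(g))` with a small smoothing constant. MONAI's `DiceLoss` has additional options (reduction across classes, its own smoothing defaults), so the numbers are illustrative rather than a reproduction of its output. The helper `soft_dice_loss` is hypothetical.

```python
def soft_dice_loss(probs, target, eps=1e-5):
    """Soft Dice loss for one class: 1 - 2*sum(p*g) / (sum(p) + sum(g))."""
    intersection = sum(p * g for p, g in zip(probs, target))
    denominator = sum(probs) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


target = [1, 1, 0, 0]             # ground truth for four pixels
perfect = [1.0, 1.0, 0.0, 0.0]    # confident and correct
partial = [0.5, 0.5, 0.5, 0.5]    # maximally uncertain

print(round(soft_dice_loss(perfect, target), 4))  # 0.0
print(round(soft_dice_loss(partial, target), 4))  # 0.5
```

Because the loss is a ratio over the whole region, a small structure contributes as much as a large one, which is why it complements the pixel-wise cross-entropy term.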

## 4. Model Initialization & Training Configuration

This section defines the segmentation model architecture and sets up the full training configuration pipeline.  
The model leverages a pre-trained encoder for efficiency, and all components are organized for clarity and reproducibility.

> **Note:** `EfficientNet-B1` is selected as the encoder due to its strong balance between accuracy and computational efficiency. It is initialized with ImageNet weights.

**Key components:**
- **Model**: U-Net-style architecture with an `EfficientNet-B1` encoder backbone.
- **Loss Function**: Combination of `CrossEntropyLoss` and `DiceLoss`, designed to handle class imbalance and spatial overlap effectively.
- **Optimizer**: Adam with weight decay for regularization.
- **Learning Rate Scheduler**: `ReduceLROnPlateau` to adaptively reduce the learning rate based on validation loss.
- **Early Stopping**: Monitors validation loss to prevent overfitting.
- **Metric**: Multiclass Jaccard Index (IoU) for evaluating segmentation quality across all classes.
- **AMP (Automatic Mixed Precision)**: Enabled via `GradScaler` for faster training and reduced memory usage on compatible GPUs.

All configurations are defined prior to training to ensure clarity, reproducibility, and compatibility with different hardware environments.


### 4.1 Device Setup and Model Initialization

We begin by defining the computation device (GPU or CPU), and then initialize the segmentation model.  
The model follows a U-Net architecture with an `EfficientNet-B1` encoder, pre-trained on ImageNet for better feature extraction and faster convergence.

To accelerate training and reduce overfitting during early epochs, all encoder layers are initially frozen.

> **Note:** Freezing the encoder enables the decoder to learn effectively at the start. Fine-tuning can be performed later by unfreezing the encoder layers once the decoder has stabilized.


In [None]:
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build a U-Net from smp with an ImageNet-pretrained EfficientNet-B1 encoder
model = smp.Unet(
    encoder_name='efficientnet-b1',
    encoder_weights='imagenet',
    in_channels=3,
    classes=4
).to(device)

# Freeze encoder layers
for param in model.encoder.parameters():
  param.requires_grad = False
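
One way to confirm the freeze took effect is to count trainable parameters and hand only those to the optimizer. The sketch below uses a tiny stand-in module with `encoder` and `decoder` attributes instead of the real `smp.Unet`, so it runs without downloading weights. The class `TinySegNet`, the layer sizes, and the learning rate are all assumptions for illustration. Note that `requires_grad = False` stops weight updates but does not stop BatchNorm layers from updating their running statistics while the module is in training mode.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Stand-in with the same encoder/decoder split as a U-Net style model."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(8, 4, kernel_size=1)

    def forward(self, x):
        return self.decoder(torch.relu(self.encoder(x)))


toy_model = TinySegNet()

# Freeze the encoder, mirroring the loop used for the real model
for param in toy_model.encoder.parameters():
    param.requires_grad = False

trainable = [p for p in toy_model.parameters() if p.requires_grad]
n_trainable = sum(p.numel() for p in trainable)
n_total = sum(p.numel() for p in toy_model.parameters())
print(n_trainable, n_total)  # 36 260

# Only the still-trainable parameters are given to the optimizer
toy_optimizer = torch.optim.Adam(trainable, lr=1e-3)
```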

### 4.2 Early Stopping

To avoid overfitting and reduce unnecessary computation, we define a custom **EarlyStopping** utility.  
This mechanism monitors the validation metric (e.g., loss or accuracy) and halts training if no improvement is observed over a defined patience period.

It supports both `'min'` and `'max'` monitoring modes, depending on whether the target metric is expected to decrease (e.g., loss) or increase (e.g., accuracy or IoU).

> **Note:** Early stopping is particularly useful when training high-capacity models on limited datasets, since it halts training once the monitored validation score has stopped improving rather than continuing into overfitting.
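
The patience rule can be traced on a short loss sequence. The sketch below is a simplified, function-style version of the `'min'` mode, not the class defined in this section. The name `stop_epoch` and the loss values are hypothetical.

```python
def stop_epoch(val_losses, patience=3, delta=0.0):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = float("inf")
    counter = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - delta:   # improvement: record it and reset the counter
            best = loss
            counter = 0
        else:                     # no improvement this epoch
            counter += 1
            if counter >= patience:
                return epoch
    return None


toy_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.65, 0.64]
print(stop_epoch(toy_losses))              # 6
print(stop_epoch(toy_losses, patience=4))  # None
```

With `patience=3`, training stops at epoch 6 and never sees the improvement at epoch 7, which shows the trade-off between saving compute and tolerating short plateaus.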


In [None]:
# Custom early stopping class to monitor validation performance
# Stops training if no improvement is observed over 'patience' epochs
class EarlyStopping:
    def __init__(self, monitor='val_loss', mode='min', patience=3, delta=0.0, verbose=True):
        """
        Args:
            monitor (str): Metric to monitor ('val_loss' or 'val_acc')
            mode (str): 'min' → lower is better, 'max' → higher is better
            patience (int): # of epochs with no improvement before stopping
            delta (float): Minimum change to qualify as improvement
            verbose (bool): Print status each epoch if True
        """
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        self.delta = delta
        self.verbose = verbose

        self.best_score = None
        self.counter = 0
        self.early_stop = False

        # Set comparison function and initial best value
        if self.mode == 'min':
            self.monitor_op = lambda current, best: current < best - self.delta
            self.best_score = np.inf
        elif self.mode == 'max':
            self.monitor_op = lambda current, best: current > best + self.delta
            self.best_score = -np.inf
        else:
            raise ValueError("mode must be 'min' or 'max'")

    def __call__(self, current_score):
        # best_score starts at +inf or -inf (set in __init__), so the first
        # call always registers as an improvement
        if self.monitor_op(current_score, self.best_score):
            self.best_score = current_score
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Improved {self.monitor}: {self.best_score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"[EarlyStopping] No improvement in {self.monitor} for {self.counter}/{self.patience} epochs.")
            # Stop training if performance has not improved for 'patience' epochs
            if self.counter >= self.patience:
                if self.verbose:
                    print(f"[EarlyStopping] Stopping training. Best {self.monitor}: {self.best_score:.4f}")
                self.early_stop = True


#### 4.2.1 EarlyStopping Configuration

After defining the `EarlyStopping` class, we instantiate it by specifying:

- `monitor`: The validation metric to observe (`val_loss`)
- `mode`: Whether to minimize or maximize the metric (`'min'` for loss)
- `patience`: The number of consecutive epochs to wait without improvement

This configuration ensures that training is terminated when the model stops improving, thereby avoiding overfitting and reducing computational cost.


In [None]:
early_stopping = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=3,
    verbose=True
)

### 4.3 Loss Function

We define a composite loss function by combining **CrossEntropyLoss** and **DiceLoss**, which is particularly suitable for multi-class medical image segmentation tasks characterized by class imbalance and irregular anatomical structures.

**Key characteristics:**
- **CrossEntropyLoss** handles pixel-wise classification and supports class weights to account for imbalanced label distributions.
- **DiceLoss** emphasizes spatial overlap between prediction and ground truth, making it effective for capturing small or sparse tumor regions.

The custom `ComboLoss` implementation integrates both components. Calling `.to(device)` on it moves the loss module to the same computation device (`CPU` or `GPU`) as the model.

> **Note:** This hybrid loss function is designed to balance region-based similarity and classification accuracy, providing more stable and meaningful learning signals across all tumor classes.
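
To make the two terms concrete, the sketch below computes a cross-entropy term and a soft Dice term on four hypothetical pixels in plain Python. It is not the notebook's `ComboLoss`: the equal-weight sum, the helper names, and the logits are all assumptions chosen for illustration.

```python
import math

def softmax(logits):
    """Convert one pixel's class logits into probabilities."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, targets, class_weights=None):
    """Class-weighted mean of -log p[target] over all pixels."""
    weights = class_weights or [1.0] * len(probs[0])
    weighted_nll = sum(-weights[t] * math.log(p[t]) for p, t in zip(probs, targets))
    return weighted_nll / sum(weights[t] for t in targets)

def soft_dice_loss(probs, targets, eps=1e-6):
    """1 minus the soft Dice coefficient, averaged over classes."""
    num_classes = len(probs[0])
    dice_scores = []
    for c in range(num_classes):
        intersection = sum(p[c] for p, t in zip(probs, targets) if t == c)
        pred_mass = sum(p[c] for p in probs)
        true_mass = sum(1 for t in targets if t == c)
        dice_scores.append((2 * intersection + eps) / (pred_mass + true_mass + eps))
    return 1.0 - sum(dice_scores) / num_classes

def combo_loss(logits, targets, class_weights=None):
    """Assumed combination: plain sum of the cross-entropy and Dice terms."""
    probs = [softmax(pixel) for pixel in logits]
    return cross_entropy(probs, targets, class_weights) + soft_dice_loss(probs, targets)

# Four hypothetical pixels, one per class
targets = [0, 1, 2, 3]
good_logits = [[4.0 if c == t else 0.0 for c in range(4)] for t in targets]
bad_logits = [[4.0 if c == 0 else 0.0 for c in range(4)] for _ in targets]  # predicts class 0 everywhere

print(f"good prediction: {combo_loss(good_logits, targets):.4f}")
print(f"bad prediction:  {combo_loss(bad_logits, targets):.4f}")
```

Predicting class 0 everywhere is penalized by both terms: cross-entropy grows on every misclassified pixel, and the Dice term drops for each class whose region is missed.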


In [None]:
criterion = ComboLoss(ce_weight=class_weight).to(device)

### 4.4 Optimizer Configuration

We configure the optimizer using **AdamW**, a variant of Adam that keeps the adaptive per-parameter learning rates and applies weight decay as a separate step.  
A base learning rate of `1e-3` is used as a starting point for training the decoder on top of the frozen encoder.

> **Note:** `AdamW` decouples weight decay from the gradient update, so `weight_decay=1e-5` shrinks the weights directly instead of adding an L2 term to the gradients. Parameters of the frozen encoder receive no gradients and are therefore left unchanged by the optimizer.
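
The code cell below uses `torch.optim.AdamW`, whose weight decay is decoupled from the adaptive gradient step. The sketch below contrasts this with an L2 penalty folded into the gradient, using a single scalar weight and the first Adam step in plain Python. The function name and the numeric values are hypothetical, and the decay is set much larger than the notebook's `1e-5` so the difference is visible.

```python
import math

def adam_first_step(w, grad, lr=1e-3, weight_decay=1e-2, decoupled=True,
                    betas=(0.9, 0.999), eps=1e-8):
    """One Adam update from zero-initialized moments for a scalar weight."""
    beta1, beta2 = betas
    if decoupled:
        # AdamW style: shrink the weight directly, outside the adaptive step
        w = w * (1 - lr * weight_decay)
    else:
        # L2 style: fold the decay term into the gradient before the moments
        grad = grad + weight_decay * w
    m = (1 - beta1) * grad
    v = (1 - beta2) * grad * grad
    m_hat = m / (1 - beta1)  # bias correction at step 1
    v_hat = v / (1 - beta2)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps)

w_adamw = adam_first_step(2.0, 0.5, decoupled=True)
w_l2 = adam_first_step(2.0, 0.5, decoupled=False)
print(f"decoupled decay: {w_adamw:.6f}")
print(f"L2 in gradient:  {w_l2:.6f}")
```

On the first step the adaptive normalization cancels most of the L2 term, so only the decoupled variant applies the extra shrinkage of `lr * weight_decay * w`.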


In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)

### 4.5 Learning Rate Scheduling

To enable dynamic learning rate adjustment during training, we configure **ReduceLROnPlateau**, a scheduler that monitors the validation loss and reduces the learning rate when no significant improvement is observed.

**Key characteristics:**
- **Mode**: `'min'` — activates when the validation loss stops decreasing.
- **Factor**: `0.1` — reduces the learning rate by a factor of 10.
- **Patience**: `2` — waits for two stagnant epochs before reducing.
- **Logging**: `verbose` is not set in the cell below; the current learning rate can be read from `optimizer.param_groups[0]['lr']` after each epoch.

> **Note:** This scheduler is particularly suitable for medical segmentation tasks, where validation loss may fluctuate due to class imbalance or noisy annotations.  
> `ReduceLROnPlateau` only lowers the learning rate: training keeps the initial step size while the validation loss improves and switches to smaller steps once it plateaus, which improves convergence stability.
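
The plateau rule can be traced by hand. The sketch below is a simplified plain-Python imitation of the scheduler in `'min'` mode with a relative improvement threshold and no cooldown; the function name and the loss values are hypothetical, and the real scheduler has further options that are ignored here.

```python
def simulate_reduce_on_plateau(val_losses, lr=1e-3, factor=0.1, patience=2,
                               threshold=1e-4, min_lr=0.0):
    """Return the learning rate in effect after each epoch."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best * (1 - threshold):
            # Counts as an improvement only if it beats the best by the threshold
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:
            # Patience exceeded: scale the learning rate and restart the count
            lr = max(lr * factor, min_lr)
            bad_epochs = 0
        history.append(lr)
    return history

# Hypothetical validation losses that stall after epoch 2
losses = [0.80, 0.60, 0.60, 0.61, 0.60, 0.59, 0.59]
for epoch, lr in enumerate(simulate_reduce_on_plateau(losses), start=1):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

Under these simplified rules the first reduction lands on the third consecutive stagnant epoch. That is the same epoch on which an `EarlyStopping` with `patience=3` and `delta=0.0` would stop, so a larger early-stopping patience may be needed for the lowered rate to take effect.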


In [None]:
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=2
)

### 4.6 Evaluation Metric

For segmentation performance evaluation, we use the **Multiclass Jaccard Index** (also known as Intersection over Union, IoU). This metric is particularly effective for multi-class problems, as it measures the overlap between predicted and ground truth masks across all classes.

The Jaccard Index is computed as:

$$
IoU = \frac{TP}{TP + FP + FN}
$$

Where:
- **TP**: True Positives
- **FP**: False Positives
- **FN**: False Negatives

This metric provides a more interpretable performance indicator than accuracy in segmentation tasks, especially when class imbalance is present.

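> **Note:** When the Jaccard Index is averaged per class, each class contributes equally to the score, so errors on small tumor classes lower it noticeably even when background pixels dominate. It is widely used in medical image segmentation benchmarks.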
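
As a worked instance of the formula, the sketch below counts TP, FP and FN per class on ten hypothetical pixels and averages the per-class IoU. The helper name is hypothetical, and scoring a class that is absent from both prediction and target as `0.0` is an assumption; libraries differ on that convention.

```python
def per_class_iou(preds, targets, num_classes):
    """IoU = TP / (TP + FP + FN), computed separately for each class."""
    ious = []
    for c in range(num_classes):
        tp = sum(1 for p, t in zip(preds, targets) if p == c and t == c)
        fp = sum(1 for p, t in zip(preds, targets) if p == c and t != c)
        fn = sum(1 for p, t in zip(preds, targets) if p != c and t == c)
        denom = tp + fp + fn
        ious.append(tp / denom if denom else 0.0)  # assumed convention for empty classes
    return ious

# Ten hypothetical pixels: mostly background (0), one pixel of each tumor class
targets = [0, 0, 0, 0, 0, 0, 0, 1, 2, 3]
preds = [0] * 10  # a model that predicts background everywhere

accuracy = sum(p == t for p, t in zip(preds, targets)) / len(targets)
ious = per_class_iou(preds, targets, num_classes=4)
mean_iou = sum(ious) / len(ious)

print(f"accuracy = {accuracy:.3f}")  # 0.700
print(f"mean IoU = {mean_iou:.3f}")  # 0.175
```

The all-background prediction reaches 70% pixel accuracy but a mean IoU of only 0.175, because the three tumor classes each score zero.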



In [None]:
matric = MulticlassJaccardIndex(num_classes=4).to(device)

### 4.7 Mixed Precision Training Setup

To accelerate training and reduce GPU memory usage, we enable **Automatic Mixed Precision (AMP)** using PyTorch’s `GradScaler`. AMP performs selected operations in half precision while maintaining stability via dynamic gradient scaling, making it well-suited for large models and high-resolution data.

> **Note:** While AMP can significantly improve performance, it may introduce instability in certain architectures or loss functions. Careful monitoring during training is recommended.
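
The reason for gradient scaling can be shown without a GPU. The sketch below uses the standard library's half-precision packing to round values to float16; the gradient value is hypothetical, and `2 ** 16` is used because it is the default initial scale of `GradScaler`. This only illustrates the numeric effect and is not how PyTorch implements scaling internally.

```python
import struct

def to_float16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8          # hypothetical gradient, smaller than float16 can represent
scale = 2.0 ** 16    # default initial scale of GradScaler

unscaled = to_float16(grad)        # 0.0: the value underflows
scaled = to_float16(grad * scale)  # representable once scaled up
recovered = scaled / scale         # unscaling happens in higher precision

print(f"float16 without scaling: {unscaled}")
print(f"float16 with scaling:    {scaled:.6e}")
print(f"after unscaling:         {recovered:.6e}")
```

Scaling the loss multiplies every gradient by the same factor, so small gradients survive the float16 backward pass and are divided back before the optimizer step.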


In [None]:
scaler = GradScaler()

## 5. Model Training

This section implements the training loop for our segmentation model, incorporating **Automatic Mixed Precision (AMP)** for computational efficiency and GPU memory optimization. The training is guided by both loss and Intersection over Union (IoU) metrics on training and validation sets, with **early stopping** and **learning rate scheduling** to enhance convergence and avoid overfitting.

**Key components:**

- **Automatic Mixed Precision (AMP):** Accelerates training by using float16 where appropriate, while preserving numerical stability via `GradScaler`.
- **IoU Metric:** Evaluates segmentation performance more robustly than accuracy.
- **Early Stopping:** Terminates training when validation loss no longer improves after a patience threshold.
- **ReduceLROnPlateau Scheduler:** Dynamically lowers the learning rate when performance stagnates.

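> **Note:** Maintain the order of operations used in the training cell (`forward → zero_grad() → scaler.scale(loss).backward() → scaler.step(optimizer) → scaler.update()`) when using AMP. Calling `zero_grad()` between `backward()` and `scaler.step(optimizer)` would discard the gradients and halt learning.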


### 5.1 Training Loop

We define the full training loop across a fixed number of epochs. Each epoch is composed of two distinct phases: **training** and **validation**.

- During the **training phase**, the model performs a forward and backward pass using **Automatic Mixed Precision (AMP)** with dynamic gradient scaling (`GradScaler`).
- In the **validation phase**, the model is evaluated without gradient updates to track performance metrics.

After each validation:
- The learning rate scheduler (`ReduceLROnPlateau`) is updated based on validation loss.
- **Early stopping** checks whether the validation loss has improved.
- If the current model achieves a higher validation IoU, its weights are stored in memory and saved to disk.

> **Note:** Always ensure that gradient scaling, optimizer stepping, and zeroing are correctly ordered. When using AMP, an incorrect sequence may silently result in no learning progress.
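
The training cell below reports per-epoch values by accumulating `value * image.size(0)` for each batch and dividing by `len(loader.dataset)`. The sketch below shows why the batch size enters the sum, using hypothetical batch losses with a smaller final batch; the helper name is an assumption.

```python
def epoch_average(batch_means, batch_sizes):
    """Sample-weighted average of per-batch mean values."""
    total = sum(mean * size for mean, size in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical epoch: two full batches of 8 and a final batch of 2
batch_losses = [0.50, 0.50, 1.00]
batch_sizes = [8, 8, 2]

weighted = epoch_average(batch_losses, batch_sizes)
naive = sum(batch_losses) / len(batch_losses)

print(f"sample-weighted mean: {weighted:.4f}")
print(f"mean of batch means:  {naive:.4f}")
```

The naive mean lets the small last batch count as much as a full one. For the loss, the weighted form equals the true per-sample mean. For IoU it remains an approximation, since the IoU of a whole dataset is not the average of per-batch IoU values.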


In [None]:
best_model_wts = copy.deepcopy(model.state_dict())
best_val_iou = 0.0

num_epochs = 200

for epoch in range(num_epochs):
  print("-" * 50)
  print(f"Epoch {epoch+1}/{num_epochs}")
  print("-" * 50)

  # Optional: dynamically update training set (e.g. patch-based sampling)
  # If applicable, reinitialize DataLoader per epoch to reflect updated sampling
  # -------------------------------------------------------------------------------------------------------------
  # train_dataset.update_epoch_index()
  # train_loader = DataLoader(train_dataset, batch_size=x, shuffle=True, num_workers=4, pin_memory=True)
  # -------------------------------------------------------------------------------------------------------------

  # --- Training phase ---
  model.train()
  train_loss, train_iou = 0.0, 0.0

  for image, mask in tqdm(train_loader, desc="Training"):
    image, mask = image.to(device), mask.to(device)

    # Forward pass with mixed precision (autocast increases speed and reduces memory usage)
    with autocast():
      outputs = model(image)
      loss = criterion(outputs, mask)

    # Clear gradients left over from the previous iteration before the backward pass
    optimizer.zero_grad()

    # Backward pass with scaled gradients (helps prevent underflow in float16 training)
    scaler.scale(loss).backward()

    # Update model parameters using scaled gradients
    scaler.step(optimizer)

    # Update the scaler for dynamic loss scaling
    scaler.update()

    # Track total loss and IoU for monitoring training progress
    train_loss += loss.item() * image.size(0)
    preds = torch.argmax(outputs, dim=1)
    train_iou += matric(preds, mask).item() * image.size(0)

  avg_train_loss = train_loss / len(train_loader.dataset)
  avg_train_iou = train_iou / len(train_loader.dataset)

  # --- Validation phase ---
  model.eval()
  val_loss, val_iou = 0.0, 0.0

  with torch.no_grad():
    for image, mask in tqdm(val_loader, desc="Validation"):
      image, mask = image.to(device), mask.to(device)

      outputs = model(image)
      loss = criterion(outputs, mask)

      val_loss += loss.item() * image.size(0)
      preds = torch.argmax(outputs, dim=1)
      val_iou += matric(preds, mask).item() * image.size(0)

  avg_val_loss = val_loss / len(val_loader.dataset)
  avg_val_iou = val_iou / len(val_loader.dataset)

  print(f"Train Loss: {avg_train_loss:.4f} | Train IoU: {avg_train_iou:.4f}")
  print(f"Val Loss: {avg_val_loss:.4f}     | Val IoU: {avg_val_iou:.4f}")

  # Adjust learning rate based on validation loss
  scheduler.step(avg_val_loss)

  # Check early stopping condition
  early_stopping(avg_val_loss)

  # Save best model based on validation IoU
  if avg_val_iou > best_val_iou:
    best_val_iou = avg_val_iou
    best_model_wts = copy.deepcopy(model.state_dict())
    torch.save(model.state_dict(), "best_warmup_segmentation_model.pth")
    print("[INFO]: New best model saved!")

  if early_stopping.early_stop:
    print("[INFO]: Early stopping triggered!")
    break

model.load_state_dict(best_model_wts)
print("Best warm-up model weights loaded!")



### 5.2 Saving the Best Model

The model weights corresponding to the **highest validation IoU** are stored in memory (`best_model_wts`) during training. This ensures that we retain the most performant version of the model regardless of when training ends (either by early stopping or epoch limit).

At the end of training:
- The model is reloaded with these best weights.
- The weights are then saved to disk for later evaluation or deployment.

> **Note:** Reloading the best weights before saving ensures that the final model reflects the best validation performance achieved during training.


In [None]:
# Save the best model to disk for future use (e.g., inference or further evaluation)
torch.save(model.state_dict(), "/content/drive/MyDrive/braTS_earlyTrain_segModel_01.pth")

## Conclusion & Reflection

This notebook represents a focused exploration into 2D brain tumor segmentation using deep learning — developed as a personal learning project and portfolio showcase.

Through this process, I designed and implemented a modular pipeline that includes dataset preprocessing, class imbalance handling, mixed-precision training (AMP), metric evaluation with IoU, and best-model saving strategies. Every stage was crafted with reproducibility and clarity in mind, helping me deepen my understanding of practical model development in the context of medical image segmentation.

This work not only enhanced my technical skills in PyTorch and segmentation workflows, but also trained my ability to think structurally — planning each section of the notebook to build toward a cohesive and explainable pipeline. Additionally, it highlighted key challenges in medical AI, such as class imbalance and efficient resource usage.

>💡 As part of an ongoing journey, this notebook lays the foundation for future work involving 3D volumetric segmentation, explainable AI (XAI), and deployment through interactive tools — gradually advancing toward more robust and human-centric AI systems in medical imaging.
