# BratsMamba: 3D Brain Tumor Segmentation with State Space Model



## Introduction

Our main goal is to efficiently capture both fine-grained tumor boundaries (e.g., necrosis) and global contextual features (e.g., edema) in brain-tumor MRI scans. This is a heavily researched area that is vital to the biomedical engineering field. The problem belongs to computer vision, specifically to semantic segmentation.

Semantic segmentation is a critical task in computer vision, requiring the precise classification of every pixel (or voxel, in 3D) in an image. Traditional Convolutional Neural Networks (CNNs) like U-Net excel at capturing local features but often struggle with long-range dependencies due to their limited receptive fields. Transformers addressed this with self-attention, but at the cost of quadratic computational complexity ($O(N^2)$) in the sequence length, making them heavy for high-resolution tasks.

**Why Mamba?** We address this trade-off using **Mamba-SSM (State Space Models)**. Mamba offers linear complexity ($O(N)$) with respect to sequence length, allowing us to model global context (long-range dependencies) without the massive memory overhead of Transformers.

**The Solution: BratsMamba** 
This notebook implements BratsMamba, a hybrid architecture that combines the hierarchical structure of a U-Net with Mamba blocks. This allows us to capture:
1. **Local Texture Details**: Via convolutional stems and decoder blocks.
2. **Global Semantic Context**: Via Mamba encoders that scan the image as a sequence, understanding the "whole picture" efficiently.

### Key Technical Features:
* **Architecture:** Dual-Path Conv Stem + Mamba Encoder + Convolutional Decoder with U-Net Skip Connections.
* **Data Pipeline:** Lazy loading from internal disk (`/tmp`) to handle large datasets without RAM explosion.
* **Robustness:** Uses `SpatialPadd` and `DivisiblePadd` to handle variable MRI volume sizes and prevent shape mismatches.
* **Evaluation:** Clinical metrics (Dice & HD95) calculated on Whole Tumor (WT), Tumor Core (TC), and Enhancing Tumor (ET).

## Imports

In [None]:
# =============================================================================
# general imports
# =============================================================================

import sys
import os
import time
import json
import tarfile
import subprocess
import warnings
import random
import shutil
import glob as gb
from tqdm import tqdm

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings(action="ignore")

# -----------------------------------------------------------------------------
# mamba-ssm, monai, nibabel and einops
# -----------------------------------------------------------------------------
print("‚öôÔ∏è Checking and Installing Dependencies...")
start_install = time.time()  # wall-clock start; elapsed time is reported after installs

# Helper to install if missing
def install_package(package_name, pip_name=None):
    """Pip-install ``pip_name`` unless ``package_name`` can already be imported.

    Args:
        package_name: module name used with ``__import__`` to test availability.
        pip_name: distribution name passed to pip; defaults to ``package_name``.

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    if pip_name is None:
        pip_name = package_name
    try:
        __import__(package_name)
    except ImportError:
        print(f"   Installing {pip_name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", pip_name, "--quiet"])
        # The import system caches directory listings of sys.path entries; a
        # package installed while this interpreter is running may not be found
        # by later imports unless those caches are invalidated.
        import importlib
        importlib.invalidate_caches()

# 1. Medical / 3D imaging libraries
install_package("nibabel")
install_package("monai")
install_package("einops")  # tensor-rearranging helpers (einops.rearrange / repeat)

# 2. Mamba-SSM and its causal-conv1d dependency.
# If mamba_ssm already imports, nothing is installed. Otherwise causal-conv1d
# is installed first, then mamba-ssm. A failure is reported, not raised, so the
# notebook keeps going.
# NOTE(review): despite the message text, no exact versions are pinned here.
try:
    import mamba_ssm
except ImportError:
    print("   ‚ö†Ô∏è Mamba-SSM not found. Installing specific versions for Kaggle T4...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "causal-conv1d>=1.2.0"])
        subprocess.check_call([sys.executable, "-m", "pip", "install", "mamba-ssm"])
        print("   ‚úÖ Mamba-SSM installed successfully.")
    except Exception as e:
        print(f"   ‚ùå Error installing Mamba: {e}")
        print("   -> Ensure you are using GPU T4 and Internet is ON.")
else:
    print("   ‚úÖ Mamba-SSM already installed.")

print(f"‚úÖ Dependencies ready in {time.time() - start_install:.1f}s")

# -----------------------------------------------------------------------------
# 1.2 other libraries
# -----------------------------------------------------------------------------
print("üìÇ Importing Libraries...")

# > standard data science
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image

# > PyTorch & DL
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler  # mixed Precision
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau

# > Medical / Specialized Imaging (MONAI & Nibabel)
import nibabel as nib

# MONAI & Medical Imaging Imports
from monai.utils import set_determinism
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.data import DataLoader, Dataset, decollate_batch
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, 
    ScaleIntensityd, RandCropByPosNegLabeld, RandFlipd, 
    RandShiftIntensityd, SpatialPadd, DivisiblePadd, AsDiscrete
)

# > Math & Tensor Manipulation
from einops import rearrange, repeat

# > Scikit-Learn (Metrics & Splitting)
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, jaccard_score, f1_score

# -----------------------------------------------------------------------------
# 1.3 CONFIGURATION & SEEDING
# -----------------------------------------------------------------------------
# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# =============================================================================
# CONFIGURATION
# =============================================================================
class Config:
    """Notebook-wide settings, read directly as class attributes (``Config.X``)."""
    SEED = 42  # passed to monai.utils.set_determinism
    # Data extraction path (Using /tmp for speed and to avoid /working limits)
    EXTRACT_PATH = "/tmp/brats2021_data" 
    # Pre-processed archive of paired *_x.npy (image) / *_y.npy (label) volumes
    TAR_PATH = "/kaggle/input/brats2021-task1-converting-to-processed-dataset/BraTS2021_1000_Samples.tar"
    
    # Dataset Limits (Set N_SAMPLES to None for full dataset)
    # NOTE(review): the archive name suggests 1000 samples, so 1400 effectively
    # means "use everything" — confirm this is intended.
    N_SAMPLES = 1400       
    VAL_SPLIT = 0.2  # fraction of the (sorted) file list held out for validation
    
    # Training Hyperparameters
    IMG_SIZE = (128, 128, 128)  # training crop size; also the minimum padded volume size
    BATCH_SIZE_TRAIN = 4   # samples per optimizer step; nn.DataParallel splits this batch across GPUs (2 per GPU on 2 GPUs), it does not multiply it
    BATCH_SIZE_VAL = 1     # MUST be 1 to handle variable image sizes
    NUM_WORKERS = 8        # High worker count for efficient lazy loading
    NUM_EPOCHS = 50        # also the T_max of the cosine LR schedule
    LEARNING_RATE = 3e-4   # initial AdamW learning rate
    MAX_RUNTIME = 11.5 * 3600 # Safety buffer for Kaggle timeout (seconds of wall-clock training time)
    
    # Paths
    ARTIFACTS_DIR = "/kaggle/working/artifacts"
    CHECKPOINT_DIR = os.path.join(ARTIFACTS_DIR, "checkpoints")
    RESULTS_DIR = os.path.join(ARTIFACTS_DIR, "results")
    BEST_MODEL_DIR = os.path.join(ARTIFACTS_DIR, "best_model")

# Make sure every output/extraction directory exists (no error if already present).
for d in (Config.EXTRACT_PATH, Config.CHECKPOINT_DIR, Config.RESULTS_DIR, Config.BEST_MODEL_DIR):
    os.makedirs(d, exist_ok=True)

# Fix the global seed via MONAI for reproducible runs.
set_determinism(seed=Config.SEED)
print("‚úÖ Configuration Complete. Artifacts will be saved to:", Config.ARTIFACTS_DIR)

# Environment summary.
print("\nüöÄ System Ready.")
print(f"   PyTorch: {torch.__version__}")
print(f"   Device:  {DEVICE}")
if torch.cuda.is_available():
    print(f"   GPU:     {torch.cuda.get_device_name(0)}")
    print(f"   Memory:  {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

‚öôÔ∏è Checking and Installing Dependencies...
   Installing monai...
   ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ 2.7/2.7 MB 28.4 MB/s eta 0:00:00
   ‚ö†Ô∏è Mamba-SSM not found. Installing specific versions for Kaggle T4...
Collecting causal-conv1d>=1.2.0
  Downloading causal_conv1d-1.5.3.post1.tar.gz (24 kB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: causal-conv1d
  Building wheel for causal-conv1d (pyproject.toml): started
  Building wheel for causal-conv1d (pyproject.toml): finished with status 'done'
  Created wheel for causal-conv1d: filename=causal_conv1d-1.5.3.post1-cp312-cp312-l

## DataSet - BraTS (Brain Tumor Segmentation)

**Why this Dataset?**

We utilize the **BraTS Challenge Dataset**, the global benchmark for medical 3D segmentation. This dataset is uniquely suited for **BratMamba** because:

1. **3D Volumetric Data**: Unlike 2D datasets, BraTS provides full 3D MRI volumes ($240 \times 240 \times 155$). This justifies the use of Mamba-SSM, which excels at modeling the extremely long sequences created by flattening 3D volumes ($N = H \times W \times D$), a task where Transformers typically run out of memory.
2. **Multi-Modal Complexity**: Each patient has 4 modalities (T1, T1c, T2, FLAIR). Our "Dual CNN Stem" is designed specifically to fuse these heterogeneous signals locally before global processing.
3. **Class Imbalance**: The tumor sub-regions (Necrosis, Edema, Enhancing) vary wildly in size, requiring overlap-based loss functions (e.g., Dice or Focal; this notebook uses a combined Dice + Cross-Entropy loss) rather than plain accuracy.
  

*Segmentation Classes & Labels*

We follow the standard BraTS protocol:
* Label 0: Background
* Label 1 (NCR): Necrotic Tumor Core (Hypointense on T1-Gd)
* Label 2 (ED): Peritumoral Edema (Hyperintense on FLAIR)
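* Label 4 (ET): Enhancing Tumor (Hyperintense on T1-Gd). Note: the code below one-hot encodes labels into 4 classes (0–3) and treats channel 3 as ET, so the processed dataset is assumed to store ET as label 3.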


*Dataset Paper Citation*:

[1] U. Baid et al., “The RSNA-ASNR-MICCAI BraTS 2021 Benchmark on Brain Tumor Segmentation and Radiogenomic Classification,” arXiv:2107.02314 [cs], Sep. 2021. Available: https://arxiv.org/abs/2107.02314

*Evaluation Metrics*

To ensure fair comparison with SOTA, we track:
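1. **Dice Similarity Coefficient (DSC)**: Measures volumetric overlap between prediction and ground truth.
2. **95th-percentile Hausdorff Distance (HD95)**: Measures boundary error at the 95th percentile of surface distances. This is close to the worst case but robust to a few outlier voxels, which matters for surgical planning.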

## Importing Dataset & Preparation for training


We already pre-processed the BraTS 2021 Task 1 dataset in the `brats2021-task1-converting-to-processed-dataset` notebook and import the processed data here to save disk space.

In [None]:
# =============================================================================
# üì¶ DATA UNPACKING & DISCOVERY
# =============================================================================
def unpack_dataset(tar_path, extract_path, limit=None):
    """Extract the dataset TAR to fast local storage.

    With ``limit`` set, whole samples are extracted: the first ``limit``
    ``*_x.npy`` members (sorted by name, the same order ``get_file_lists``
    uses) together with their ``*_y.npy`` partners. Counting raw TAR members
    instead would also count directories and could extract an input without
    its label. With ``limit=None`` (or 0) every member is extracted.

    Extraction is skipped when at least ``limit`` ``*_x.npy`` files already
    exist under ``extract_path``.

    Args:
        tar_path: path to the TAR archive.
        extract_path: destination directory.
        limit: maximum number of samples to extract, or None for everything.
    """
    print(f"üì¶ Source: {tar_path}")

    # Check existing files to avoid re-extracting
    existing = gb.glob(os.path.join(extract_path, "**", "*_x.npy"), recursive=True)
    if limit and len(existing) >= limit:
        print(f"‚úÖ Data already unpacked ({len(existing)} samples). Skipping...")
        return

    print("‚è≥ Unpacking... (This utilizes internal disk IO)")
    with tarfile.open(tar_path, "r") as tar:
        members = tar.getmembers()
        if limit:
            members = _select_sample_members(members, limit)
        count = 0
        for member in tqdm(members, desc="Extracting"):
            tar.extract(member, path=extract_path)  # parent dirs are created as needed
            count += 1
    print(f"‚úÖ Extracted {count} files.")


def _select_sample_members(members, limit):
    """Return the regular-file members (input + label) of the first ``limit`` samples."""
    x_suffix, y_suffix = "_x.npy", "_y.npy"
    # Sort full names (not stems) so the order matches sorted(glob("*_x.npy")).
    inputs = sorted(m.name for m in members if m.isfile() and m.name.endswith(x_suffix))
    wanted = set()
    for name in inputs[:limit]:
        wanted.add(name)
        wanted.add(name[: -len(x_suffix)] + y_suffix)
    return [m for m in members if m.isfile() and m.name in wanted]

def get_file_lists(data_dir):
    """Pair every input volume (``*_x.npy``) with its label volume (``*_y.npy``).

    Pairs are matched by filename (``<stem>_x.npy`` -> ``<stem>_y.npy`` in the
    same directory) instead of zipping two independently sorted lists. This
    way a single missing or extra label cannot shift every later input onto
    the wrong label. Inputs without a label are skipped with a warning.

    Args:
        data_dir: directory searched recursively.

    Returns:
        List of ``{"image": path, "label": path}`` dicts sorted by input path,
        truncated to ``Config.N_SAMPLES`` (``None`` keeps all).

    Raises:
        ValueError: if no ``*_x.npy`` inputs, or no complete pairs, are found.
    """
    print(f"üîç Scanning: {data_dir}")
    input_files = sorted(gb.glob(os.path.join(data_dir, "**", "*_x.npy"), recursive=True))

    if not input_files:
        raise ValueError("‚ùå No .npy files found! Check dataset path.")

    data_dicts = []
    unpaired = 0
    for image_path in input_files:
        label_path = image_path[: -len("_x.npy")] + "_y.npy"
        if os.path.isfile(label_path):
            data_dicts.append({"image": image_path, "label": label_path})
        else:
            unpaired += 1

    if unpaired:
        print(f"   WARNING: skipped {unpaired} input(s) with no matching *_y.npy label.")
    if not data_dicts:
        raise ValueError(f"No (*_x.npy, *_y.npy) pairs found under {data_dir}.")
    return data_dicts[:Config.N_SAMPLES]

# Unpack the archive, then discover (image, label) pairs on local disk.
unpack_dataset(Config.TAR_PATH, Config.EXTRACT_PATH, limit=Config.N_SAMPLES)
all_files = get_file_lists(Config.EXTRACT_PATH)

# Hold out the first VAL_SPLIT fraction of the sorted list (never fewer than one case).
val_count = max(1, int(len(all_files) * Config.VAL_SPLIT))
val_files = all_files[:val_count]
train_files = all_files[val_count:]

print(f"üìä Dataset Split: Train={len(train_files)} | Val={len(val_files)}")

# =============================================================================
# TRANSFORMS & LOADERS
# =============================================================================
def _base_transforms():
    """Return fresh instances of the load / normalise / pad steps shared by train and val."""
    return [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image"], channel_dim=0),
        EnsureTyped(keys=["image", "label"]),
        ScaleIntensityd(keys=["image"]),
        # Safety: pad volumes smaller than IMG_SIZE so later crops/downsampling never fail
        SpatialPadd(keys=["image", "label"], spatial_size=Config.IMG_SIZE, method='symmetric'),
    ]


# Training: fixed-size crops (foreground/background centres 1:1) + flips on every axis + intensity shift
train_transforms = Compose(_base_transforms() + [
    RandCropByPosNegLabeld(
        keys=["image", "label"], label_key="label",
        spatial_size=Config.IMG_SIZE, pos=1, neg=1, num_samples=1,
        image_key="image", image_threshold=0,
    ),
    *(RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis) for axis in (0, 1, 2)),
    RandShiftIntensityd(keys=["image"], offsets=0.1, prob=0.5),
])

# Validation: whole volume (no crop), padded so every spatial size is a multiple of 16
val_transforms = Compose(_base_transforms() + [DivisiblePadd(keys=["image", "label"], k=16)])

print("‚è≥ Initializing Loaders (Lazy Loading)...")
train_ds = Dataset(data=train_files, transform=train_transforms)
val_ds = Dataset(data=val_files, transform=val_transforms)

train_loader = DataLoader(
    train_ds,
    batch_size=Config.BATCH_SIZE_TRAIN,
    shuffle=True,
    num_workers=Config.NUM_WORKERS,
    pin_memory=True,
)
# Validation volumes are un-cropped and may differ in size, hence batch size 1.
val_loader = DataLoader(
    val_ds,
    batch_size=Config.BATCH_SIZE_VAL,
    shuffle=False,
    num_workers=Config.NUM_WORKERS,
    pin_memory=True,
)
print("‚úÖ Loaders Ready.")

## Visualization

In [None]:
# =============================================================================
# üëÅÔ∏è DATA SANITY CHECK
# =============================================================================
def visualize_batch(loader, save_path=None):
    """Plot the four image channels and the label of one sample from ``loader``.

    The first batch is fetched; the first sample whose label has any non-zero
    voxel is used (sample 0 otherwise). The shown slice is taken along the
    first spatial axis, at the index with the most non-zero label voxels (the
    middle index if the label is empty). If fetching fails, the error is
    printed and nothing is plotted.

    Args:
        loader: iterable yielding dicts with "image" and "label" tensors.
        save_path: optional file path; when given the figure is also saved.
    """
    print("‚è≥ Fetching batch for visualization...")
    try:
        batch = next(iter(loader))
    except Exception as e:
        print(f"‚ùå Error fetching batch: {e}")
        return

    images, masks = batch["image"], batch["label"]
    print(f"   Batch Shape: {images.shape}")

    # First sample that contains tumour voxels; default to sample 0.
    sample_idx = next((i for i in range(len(images)) if masks[i].sum() > 0), 0)

    volume = images[sample_idx]
    mask = masks[sample_idx]
    if mask.ndim == 4:  # drop the label's channel axis
        mask = mask[0]

    # Tumour voxels per slice along the first spatial axis.
    per_slice = torch.sum(mask > 0, dim=(1, 2))
    if per_slice.max() == 0:
        slice_idx = volume.shape[1] // 2
    else:
        slice_idx = torch.argmax(per_slice).item()

    print(f"‚úÖ Visualizing Sample {sample_idx}, Slice {slice_idx}")

    slice_img = volume[:, slice_idx, :, :].cpu().numpy()
    slice_msk = mask[slice_idx, :, :].cpu().numpy()

    fig, axes = plt.subplots(1, 5, figsize=(20, 5))
    titles = ["T1", "T1ce", "T2", "FLAIR", "Mask"]
    for c in range(4):
        axes[c].imshow(slice_img[c], cmap="gray")
        axes[c].set_title(titles[c])
        axes[c].axis("off")

    axes[4].imshow(slice_msk, cmap="jet", vmin=0, vmax=3)
    axes[4].set_title("Ground Truth")
    axes[4].axis("off")

    if save_path:
        plt.savefig(save_path)
    plt.show()

visualize_batch(train_loader, save_path=os.path.join(Config.ARTIFACTS_DIR, "sanity_check.png"))

## Model Architecture: BratMamba

*Fusing Local Texture with Global Context*

Our architecture, **BratMamba**, addresses the limitations of purely Convolutional networks (like nnU-Net) and purely Transformer-based networks (like Swin UNETR) by leveraging the **Linear Complexity ($O(N)$)** of State Space Models (Mamba).

*4.1 Key Design Decisions & Citations*

1. **The Dual-Stage CNN Stem (Our Innovation)**
      * **The Component**: Instead of a single patch-embedding layer (as in Swin UNETR), we split the input into two parallel stride-2 paths: one with a small kernel ($3 \times 3 \times 3$) and one with a large kernel ($7 \times 7 \times 7$).
          * **Stream A ($3 \times 3 \times 3$ Kernel)**: Captures high-frequency details (sharp edges of the Necrotic Core).
          * **Stream B ($7 \times 7 \times 7$ Kernel)**: Captures low-frequency context (large Edema regions).
          * **Fusion**: The two streams are concatenated and mixed by a $1 \times 1 \times 1$ convolution, giving the Mamba blocks a "rich" feature set that contains both texture and context. 
      * **The Reason**:
          * **Local Texture**: MRI Brain tumor boundaries are defined by subtle texture changes (Necrosis vs. Edema). Small kernels capture these high-frequency edges.
          * **Receptive Field**: Large kernels capture the "neighborhood" context immediately.
          * **Paper Reference**: *Swin UNETR* (Hatamizadeh et al., 2022) highlights the importance of patch merging, but notes that Transformers often lose local spatial details early on. Our Dual-Stem preserves this before the Mamba layers take over.
          
2. **The Mamba Encoder (The Engine)**
      * **The Component**: Stacked Mamba Blocks that flatten the 3D volume into a 1D sequence.
      * **The Reason**:
          * **The Problem**: Standard Self-Attention (Transformers) scales quadratically ($O(N^2)$). For a 3D volume of $128^3$, the sequence length is ~2 million ($128^3 = 2{,}097{,}152$ voxels). Even after the stride-2 stem, the first Mamba layer sees $64^3 = 262{,}144$ tokens. A Transformer would run out of memory immediately.
          * **The Solution**: Mamba scales linearly ($O(N)$). It lets us scan the entire 3D brain volume as a single sequence, so that "a voxel at the top left" can be related to "a voxel at the bottom right" (Global Context).
            * **Global Scanning**: Each Mamba layer flattens the 3D feature map into one sequence in row-major (D, H, W) order and scans it once. Every voxel can therefore draw on information from all voxels earlier in that order. Multi-directional scanning (e.g., forward and backward, or along different axes) is a common extension but is not implemented in this notebook.
          * **Paper Reference**: SegMamba (Xing et al., 2024) demonstrates that Mamba outperforms Transformers on 3D medical data by reducing memory usage by 60% while improving Dice scores on the BraTS dataset.

3. **Skip Connections** *(no deep supervision: only the final output is supervised in this implementation)*
    * **The Component**: Direct connections between the Encoder and Decoder at matching resolutions.
    * **The Reason**: As the network goes deeper to understand "shapes" (Tumor vs Brain), it loses spatial resolution. Skip connections inject the high-resolution texture details from the Stem directly into the Decoder, ensuring the final mask has sharp edges.
    * **Paper Reference**: nnU-Net (Isensee et al., 2020) proves that robust encoder-decoder connections are often more important than the choice of optimizer or activation function.

In [None]:
# =============================================================================
# üß† MODEL ARCHITECTURE: BRATMAMBA
# =============================================================================
try:
    from mamba_ssm import Mamba
except ImportError:
    print("‚ö†Ô∏è WARNING: mamba_ssm not found. Using Mock layer (Install for real training!)")

    class Mamba(nn.Module):
        """Identity stand-in for mamba_ssm.Mamba: same constructor, no parameters.

        Lets the notebook build and run the model without the CUDA kernels;
        forward returns its input unchanged, so no sequence mixing happens.
        """

        def __init__(self, d_model, d_state, d_conv, expand):
            super().__init__()

        def forward(self, x):
            return x

class DualConvStem(nn.Module):
    """Hybrid stem: a fine 3x3x3 branch and a coarse 7x7x7 branch, fused by a 1x1x1 conv.

    Both branches are Conv3d(stride 2) -> InstanceNorm3d -> GELU with padding
    chosen so they produce the same spatial size (ceil of each input size / 2).
    Their outputs are concatenated on the channel axis and mixed by ``fusion``.

    Args:
        in_chans: number of input channels.
        out_chans: number of output channels. The 3x3x3 branch gets
            ``out_chans // 2`` channels and the 7x7x7 branch the remainder, so
            odd values work. For even values this equals the original even
            split, so existing checkpoints still load.
    """
    def __init__(self, in_chans, out_chans):
        super().__init__()
        fine_chans = out_chans // 2
        coarse_chans = out_chans - fine_chans  # remainder keeps the concat width == out_chans
        self.branch1 = nn.Sequential(
            nn.Conv3d(in_chans, fine_chans, kernel_size=3, padding=1, stride=2),
            nn.InstanceNorm3d(fine_chans), nn.GELU()
        )
        self.branch2 = nn.Sequential(
            nn.Conv3d(in_chans, coarse_chans, kernel_size=7, padding=3, stride=2),
            nn.InstanceNorm3d(coarse_chans), nn.GELU()
        )
        self.fusion = nn.Conv3d(out_chans, out_chans, kernel_size=1)

    def forward(self, x):
        """Map (B, in_chans, D, H, W) to (B, out_chans, ceil(D/2), ceil(H/2), ceil(W/2))."""
        return self.fusion(torch.cat([self.branch1(x), self.branch2(x)], dim=1))

class MambaLayer(nn.Module):
    """Residual, pre-norm Mamba block applied to a 3D feature map.

    The (B, C, D, H, W) input is flattened into a (B, D*H*W, C) token sequence,
    layer-normalised, passed through one Mamba module, added back to the
    un-normalised tokens, and reshaped to (B, C, D, H, W).
    """
    def __init__(self, dim, d_state=16, d_conv=4, expand=2):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(d_model=dim, d_state=d_state, d_conv=d_conv, expand=expand)

    def forward(self, x):
        batch, chans, depth, height, width = x.shape
        tokens = x.flatten(2).transpose(1, 2)             # (B, N, C), N = D*H*W
        mixed = tokens + self.mamba(self.norm(tokens))    # pre-norm residual
        return mixed.transpose(1, 2).reshape(batch, chans, depth, height, width)

class UpBlock(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsample, concat skip, Conv-InstanceNorm-GELU."""
    def __init__(self, in_chans, out_chans):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_chans, out_chans, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv3d(2 * out_chans, out_chans, kernel_size=3, padding=1),
            nn.InstanceNorm3d(out_chans),
            nn.GELU(),
        )

    def forward(self, x, skip):
        upsampled = self.up(x)
        # After stride-2 downsampling of an odd size, 2x upsampling overshoots the
        # encoder map by one voxel; resample to the skip's spatial size.
        if upsampled.shape != skip.shape:
            upsampled = nn.functional.interpolate(upsampled, size=skip.shape[2:], mode='trilinear')
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

class BratMamba(nn.Module):
    """Hybrid convolution / Mamba 3D U-Net producing per-voxel class logits.

    Encoder: DualConvStem (stride 2) -> MambaLayer -> strided conv (x2 channels)
    -> MambaLayer -> strided conv (x2 channels) -> MambaLayer bottleneck.
    Decoder: two UpBlocks fed with the matching encoder maps, a stride-2
    transposed conv back to input resolution, and a 1x1x1 head (no activation).

    Args:
        in_chans: input channels.
        num_classes: output channels (raw logits).
        embed_dim: channel width after the stem; doubled at each downsampling.

    (B, in_chans, D, H, W) -> (B, num_classes, D, H, W) when every spatial size
    is even; an odd size comes back one voxel larger.
    """
    def __init__(self, in_chans=4, num_classes=4, embed_dim=48):
        super().__init__()
        c1, c2, c3 = embed_dim, embed_dim * 2, embed_dim * 4
        # Attribute names are the checkpoint's state_dict keys; keep names and order.
        # Encoder
        self.stem = DualConvStem(in_chans, c1)
        self.layer1 = MambaLayer(c1)
        self.down1 = nn.Conv3d(c1, c2, kernel_size=3, stride=2, padding=1)
        self.layer2 = MambaLayer(c2)
        self.down2 = nn.Conv3d(c2, c3, kernel_size=3, stride=2, padding=1)
        self.bottleneck = MambaLayer(c3)
        # Decoder
        self.up1 = UpBlock(c3, c2)
        self.up2 = UpBlock(c2, c1)
        self.final_up = nn.ConvTranspose3d(c1, c1, kernel_size=2, stride=2)
        self.out_head = nn.Conv3d(c1, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder at 1/2, 1/4 and 1/8 of the input resolution.
        skip_hi = self.layer1(self.stem(x))
        skip_lo = self.layer2(self.down1(skip_hi))
        deepest = self.bottleneck(self.down2(skip_lo))
        # Decoder back to 1/2 resolution, then a final 2x upsample + classifier.
        decoded = self.up2(self.up1(deepest, skip_lo), skip_hi)
        return self.out_head(self.final_up(decoded))

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = BratMamba()
model = model.to(DEVICE)  # Module.to moves parameters in place and returns the module
print(f"‚úÖ Model Initialized on {DEVICE}")

## Training Loop

In [None]:
# =============================================================================
# üöÄ TRAINING LOOP
# =============================================================================
GPU_COUNT = torch.cuda.device_count()
# Guard against wrapping twice when this cell is re-run.
if GPU_COUNT > 1 and not isinstance(model, nn.DataParallel):
    print(f"‚ö° Using {GPU_COUNT} GPUs (DataParallel)")
    model = nn.DataParallel(model)

# Mixed precision only on CUDA. The GradScaler imported from torch.cuda.amp
# takes `init_scale` as its first positional argument, so it must never be
# given a device string positionally. Autocast uses the device-generic
# torch.autocast with an explicit device_type.
USE_AMP = DEVICE.type == "cuda"

criterion = DiceCELoss(to_onehot_y=True, softmax=True, include_background=False)
optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.NUM_EPOCHS)
scaler = GradScaler(enabled=USE_AMP)
dice_metric = DiceMetric(include_background=False, reduction="mean")
# Stateless post-processing transforms: build once, not per validation sample.
post_pred = AsDiscrete(argmax=True, to_onehot=4)
post_label = AsDiscrete(to_onehot=4)

# Resume Logic
history = {'train_loss': [], 'val_loss': [], 'val_dice': [], 'epochs': []}
start_epoch, best_dice = 0, 0.0
last_ckpt = os.path.join(Config.CHECKPOINT_DIR, "last.pth")
history_path = os.path.join(Config.RESULTS_DIR, "log.json")

if os.path.exists(last_ckpt):
    print("üîÑ Resuming from checkpoint...")
    # map_location: the checkpoint may have been written on another device.
    ckpt = torch.load(last_ckpt, map_location=DEVICE)

    # Weights are saved from the unwrapped model. Strip any legacy 'module.'
    # prefix and load strictly, so an architecture mismatch fails loudly
    # instead of silently leaving layers randomly initialised.
    sd = {(k[len('module.'):] if k.startswith('module.') else k): v
          for k, v in ckpt['model_state_dict'].items()}
    (model.module if isinstance(model, nn.DataParallel) else model).load_state_dict(sd)

    optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    if ckpt.get('scaler_state_dict'):  # empty when saved from a disabled (CPU) scaler
        scaler.load_state_dict(ckpt['scaler_state_dict'])
    if 'scheduler_state_dict' in ckpt:
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    else:
        # Older checkpoints have no scheduler state and were saved before the
        # epoch's scheduler.step(). Re-align the cosine schedule to the next epoch.
        scheduler.last_epoch = ckpt['epoch']
        scheduler.step()
    start_epoch = ckpt['epoch'] + 1
    best_dice = ckpt['best_dice']
    if os.path.exists(history_path):
        with open(history_path, 'r') as f:
            history = json.load(f)

print(f"üöÄ Training starting at Epoch {start_epoch+1}...")
START_TIME = time.time()

for epoch in range(start_epoch, Config.NUM_EPOCHS):
    if time.time() - START_TIME > Config.MAX_RUNTIME:
        print("üõë Time limit reached.")
        break

    # ---- Train ----
    model.train()
    ep_loss = 0

    pbar = tqdm(train_loader, desc=f"Ep {epoch+1}")
    for batch in pbar:
        img, lbl = batch["image"].to(DEVICE), batch["label"].to(DEVICE)

        optimizer.zero_grad()
        with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
            pred = model(img)
            loss = criterion(pred, lbl)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        ep_loss += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}")

    # ---- Validate ----
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Val"):
            img, lbl = batch["image"].to(DEVICE), batch["label"].to(DEVICE)
            with torch.autocast(device_type=DEVICE.type, enabled=USE_AMP):
                pred = model(img)
                val_loss += criterion(pred, lbl).item()

            dice_metric(y_pred=[post_pred(p) for p in pred], y=[post_label(t) for t in lbl])

    # Stats
    stats = {
        'train_loss': ep_loss / len(train_loader),
        'val_loss': val_loss / len(val_loader),
        'val_dice': dice_metric.aggregate().item()
    }
    dice_metric.reset()

    # Update History & Log
    history['epochs'].append(epoch+1)
    for k, v in stats.items():
        history[k].append(v)
    with open(history_path, 'w') as f:
        json.dump(history, f)

    print(f"   Stats: Train={stats['train_loss']:.4f} | Val={stats['val_loss']:.4f} | Dice={stats['val_dice']:.4f}")

    # Advance the LR schedule *before* checkpointing, so the saved
    # optimizer/scheduler state is exactly what the next epoch starts from.
    scheduler.step()

    # Update best_dice before saving, so a resumed run sees the current best.
    is_best = stats['val_dice'] > best_dice
    if is_best:
        print(f"   ‚≠ê New Best! {best_dice:.4f} -> {stats['val_dice']:.4f}")
        best_dice = stats['val_dice']

    # Save State (always the unwrapped model's weights)
    state = model.module.state_dict() if isinstance(model, nn.DataParallel) else model.state_dict()
    ckpt = {
        'epoch': epoch, 'model_state_dict': state,
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_dice': best_dice,
    }
    torch.save(ckpt, last_ckpt)
    if is_best:
        torch.save(state, os.path.join(Config.BEST_MODEL_DIR, "best_model.pth"))

# Plot
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1); plt.plot(history['train_loss'], label='Train'); plt.plot(history['val_loss'], label='Val'); plt.legend()
plt.subplot(1, 2, 2); plt.plot(history['val_dice'], label='Dice', color='green'); plt.legend()
plt.savefig(os.path.join(Config.RESULTS_DIR, "training_curves.png"))
plt.show()

## Evaluation 

In [None]:
# =============================================================================
# üèÜ EVALUATION & METRICS
# =============================================================================
def get_brats_regions(tensor_oh):
    """Map a one-hot BraTS tensor to the three overlapping evaluation regions.

    Args:
        tensor_oh: (B, 4, ...) one-hot tensor with channels ordered
            (background, NCR, ED, ET).

    Returns:
        Float tensor (B, 3, ...) with channels (WT, TC, ET):
        WT = NCR | ED | ET, TC = NCR | ET, ET = ET.
    """
    tumour_channels = tensor_oh[:, 1:4, ...]
    necrotic = tensor_oh[:, 1:2, ...]
    enhancing = tensor_oh[:, 3:4, ...]
    regions = (
        torch.sum(tumour_channels, dim=1, keepdim=True) > 0,  # whole tumour
        (necrotic + enhancing) > 0,                           # tumour core
        enhancing > 0,                                        # enhancing tumour
    )
    return torch.cat(regions, dim=1).float()

def evaluate_model(model, loader):
    """Score ``model`` on ``loader`` with BraTS region metrics and return the table.

    Logits are argmax-ed and one-hot encoded into 4 classes (labels likewise).
    Both are mapped to WT / TC / ET with ``get_brats_regions`` and scored with
    Dice and the 95th-percentile Hausdorff distance, averaged per region over
    the loader. The first batch's first case is rendered to
    ``Config.RESULTS_DIR/best_pred.png``.

    Args:
        model: network returning (B, 4, ...) logits.
        loader: yields dicts with "image" and "label"; labels are assumed to be
            single-channel class-index maps (what ``AsDiscrete(to_onehot=4)`` expects).

    Returns:
        pandas.DataFrame with one row per region. It is also written to
        ``final_metrics.csv`` in ``Config.RESULTS_DIR`` and printed.
    """
    model.eval()
    dice_metric = DiceMetric(include_background=True, reduction="mean_batch")
    # NOTE(review): distances are computed in voxel units; labelling them "mm"
    # assumes 1 mm isotropic spacing — confirm for this pre-processed data.
    hd95_metric = HausdorffDistanceMetric(include_background=True, percentile=95, reduction="mean_batch")
    post_pred = AsDiscrete(argmax=True, to_onehot=4)
    post_label = AsDiscrete(to_onehot=4)

    print("üîç Starting Final Evaluation...")
    with torch.no_grad():
        for i, batch in tqdm(enumerate(loader), total=len(loader)):
            img, lbl = batch["image"].to(DEVICE), batch["label"].to(DEVICE)
            pred = model(img)

            # Logits / labels -> one-hot classes -> overlapping BraTS regions
            pred_oh = torch.stack([post_pred(x) for x in pred])
            lbl_oh = torch.stack([post_label(x) for x in lbl])
            pred_reg = get_brats_regions(pred_oh)
            lbl_reg = get_brats_regions(lbl_oh)

            dice_metric(y_pred=pred_reg, y=lbl_reg)
            hd95_metric(y_pred=pred_reg, y=lbl_reg)

            if i == 0:  # qualitative example from the first batch
                visualize_prediction(img[0], lbl[0], pred[0],
                                     os.path.join(Config.RESULTS_DIR, "best_pred.png"))

    # Per-region averages over all evaluated cases
    dice = dice_metric.aggregate().cpu().numpy()
    hd95 = hd95_metric.aggregate().cpu().numpy()

    df = pd.DataFrame({
        "Region": ["Whole Tumor (WT)", "Tumor Core (TC)", "Enhancing Tumor (ET)"],
        "Dice (DSC) ‚Üë": dice,
        "HD95 (mm) ‚Üì": hd95,
        # nanmean: one NaN region score (e.g. a region empty in every case)
        # should not turn the overall mean into NaN.
        "Mean Dice": [float(np.nanmean(dice))] * 3,
    })

    csv_path = os.path.join(Config.RESULTS_DIR, "final_metrics.csv")
    df.to_csv(csv_path, index=False)
    print("\nüèÜ FINAL SCORES:")
    print(df.to_string(index=False, float_format="%.4f"))
    return df

def visualize_prediction(img, lbl, pred, save_path):
    """Save a side-by-side figure of one slice: input channel 2 (T2), ground truth, prediction.

    Args:
        img: channel-first image for one case, (C, X, Y, Z).
        lbl: label for one case, either a single-channel class-index map
            (1, X, Y, Z) or a multi-channel / one-hot map (K, X, Y, Z).
        pred: raw logits for one case, (num_classes, X, Y, Z).
        save_path: PNG destination; the figure is closed after saving.

    The slice is taken along the last spatial axis at the index with the most
    labelled (non-zero) voxels, or the middle slice if the label is empty.
    """
    # Collapse to an (X, Y, Z) class-index map. A single-channel label already
    # holds class ids; argmax over that one channel would always return 0 and
    # show an empty ground truth.
    gt_map = lbl[0] if lbl.shape[0] == 1 else torch.argmax(lbl, dim=0)
    pr_map = torch.argmax(pred, dim=0)

    # Tumour area per slice along the last axis.
    tumour_area = (gt_map > 0).sum(dim=(0, 1))
    if tumour_area.max() == 0:
        idx = img.shape[-1] // 2  # middle of the same (last) axis that is sliced
    else:
        idx = torch.argmax(tumour_area).item()

    im = img[2, :, :, idx].cpu().numpy()
    gt = gt_map[:, :, idx].cpu().numpy()
    pr = pr_map[:, :, idx].cpu().numpy()

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(im, cmap='gray')
    ax[0].set_title("Input (T2)")
    ax[1].imshow(gt, cmap='jet', vmin=0, vmax=3)
    ax[1].set_title("GT")
    ax[2].imshow(pr, cmap='jet', vmin=0, vmax=3)
    ax[2].set_title("Pred")
    plt.savefig(save_path)
    plt.close(fig)

# Load the best weights (saved from the unwrapped model) and run the final evaluation.
best_path = os.path.join(Config.BEST_MODEL_DIR, "best_model.pth")
if os.path.exists(best_path):
    print(f"üìÇ Loading Best Model: {best_path}")
    # map_location: the weights may have been saved on a different device (e.g. GPU -> CPU session).
    sd = torch.load(best_path, map_location=DEVICE)
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(sd)
    evaluate_model(model, val_loader)
else:
    print("‚ö†Ô∏è No best model found.")