# MedNeXt Training and Data Preparation
This notebook prepares BraTS data and sets up MedNeXt training workflows.

In [None]:
import tarfile
import os
import glob
from tqdm import tqdm

input_dir = '/kaggle/input/brats-2021-task1'
extract_dir = '/kaggle/working/brats_data'

os.makedirs(extract_dir, exist_ok=True)

tar_files = glob.glob(os.path.join(input_dir, '*.tar'))

print(f"Found {len(tar_files)} tar files. Extracting to {extract_dir}...")

for tar_path in tqdm(tar_files):
    try:
        with tarfile.open(tar_path) as tar:
            tar.extractall(path=extract_dir)
            print(f"Extracted: {os.path.basename(tar_path)}")
    except Exception as e:
        print(f"Error extracting {tar_path}: {e}")

print("Extraction complete.")

Found 3 tar files. Extracting to /kaggle/working/brats_data...


  0%|          | 0/3 [00:00<?, ?it/s]

✅ Extracted: BraTS2021_00495.tar


  tar.extractall(path=extract_dir)
100%|██████████| 3/3 [01:20<00:00, 26.85s/it]

✅ Extracted: BraTS2021_Training_Data.tar
✅ Extracted: BraTS2021_00621.tar

Extraction Complete!





In [15]:
# List the contents of the extraction folder
contents = os.listdir(extract_dir)
print(f"Root contents: {contents[:5]}") # Print first 5 items

# Check if there is a subfolder (common in tar files)
if 'BraTS2021_Training_Data' in contents:
    final_data_path = os.path.join(extract_dir, 'BraTS2021_Training_Data')
else:
    final_data_path = extract_dir

print("-" * 30)
print(f"YOUR DATA PATH IS: {final_data_path}")
print("-" * 30)

# Verify we can see subject folders (e.g., BraTS2021_00001)
subjects = [x for x in os.listdir(final_data_path) if os.path.isdir(os.path.join(final_data_path, x))]
print(f"Found {len(subjects)} subject folders.")
print(f"Sample subjects: {subjects[:3]}")

Root contents: ['BraTS2021_01524', 'BraTS2021_01509', 'BraTS2021_01139', 'BraTS2021_01001', 'BraTS2021_00376']
------------------------------
YOUR DATA PATH IS: /kaggle/working/brats_data
------------------------------
Found 1251 subject folders.
Sample subjects: ['BraTS2021_01524', 'BraTS2021_01509', 'BraTS2021_01139']


In [16]:
!git clone https://github.com/MIC-DKFZ/MedNeXt.git mednext
!pip install -e ./mednext

fatal: destination path 'mednext' already exists and is not an empty directory.
Obtaining file:///kaggle/working/mednext
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting argparse (from unittest2->batchgenerators>=0.23->mednextv1==1.7.0)
  Using cached argparse-1.4.0-py2.py3-none-any.whl.metadata (2.8 kB)
Using cached argparse-1.4.0-py2.py3-none-any.whl (23 kB)
Installing collected packages: argparse, mednextv1
  Attempting uninstall: mednextv1
    Found existing installation: mednextv1 1.7.0
    Uninstalling mednextv1-1.7.0:
      Successfully uninstalled mednextv1-1.7.0
  Running setup.py develop for mednextv1
Successfully installed argparse-1.4.0 mednextv1-1.7.0


In [17]:
!pip install nibabel SimpleITK tqdm torchio



In [None]:
import sys
sys.path.insert(0, 'mednext')

import os
import glob
import numpy as np
import nibabel as nib
from tqdm import tqdm
import random
from typing import List, Tuple, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast

# TorchIO for advanced 3D medical image augmentations
import torchio as tio

# Import MedNeXt
from nnunet_mednext import create_mednext_v1, MedNeXt

print(f"PyTorch Version: {torch.__version__}")
print(f"TorchIO Version: {tio.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch Version: 2.8.0+cu126
TorchIO Version: 0.21.2
CUDA Available: True
GPU: Tesla P100-PCIE-16GB
GPU Memory: 17.06 GB


In [20]:
class BraTS2021Dataset(Dataset):
    """
    BraTS2021 Dataset for brain tumor segmentation.
    
    Each sample contains:
    - 4 MRI modalities: T1, T1ce, T2, FLAIR
    - Segmentation mask with labels: 0 (background), 1 (NCR/NET), 2 (ED), 4 (ET)
    
    Uses patch-based sampling for memory efficiency.
    """
    
    def __init__(
        self, 
        data_dir: str, 
        subject_ids: List[str],
        patch_size: Tuple[int, int, int] = (128, 128, 128),
        samples_per_volume: int = 2,
        is_training: bool = True,
        augment: bool = True
    ):
        self.data_dir = data_dir
        self.subject_ids = subject_ids
        self.patch_size = patch_size
        self.samples_per_volume = samples_per_volume
        self.is_training = is_training
        self.augment = augment and is_training
        
        # Modality suffixes in BraTS2021
        self.modalities = ['t1', 't1ce', 't2', 'flair']
        
    def __len__(self):
        return len(self.subject_ids) * self.samples_per_volume
    
    def _load_nifti(self, filepath: str) -> np.ndarray:
        """Load a NIfTI file and return the data array."""
        img = nib.load(filepath)
        return img.get_fdata().astype(np.float32)
    
    def _normalize(self, data: np.ndarray) -> np.ndarray:
        """Z-score normalization per volume (non-zero voxels only)."""
        mask = data > 0
        if mask.sum() > 0:
            mean = data[mask].mean()
            std = data[mask].std()
            if std > 0:
                data = (data - mean) / std
                data[~mask] = 0
        return data
    
    def _convert_labels(self, seg: np.ndarray) -> np.ndarray:
        """
        Convert BraTS labels to consecutive integers.
        BraTS: 0 (background), 1 (NCR/NET), 2 (ED), 4 (ET)
        Output: 0 (background), 1 (NCR/NET), 2 (ED), 3 (ET)
        """
        new_seg = np.zeros_like(seg)
        new_seg[seg == 0] = 0
        new_seg[seg == 1] = 1
        new_seg[seg == 2] = 2
        new_seg[seg == 4] = 3
        return new_seg
    
    def _get_random_patch(
        self, 
        volume: np.ndarray, 
        seg: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Extract a random patch from the volume."""
        c, d, h, w = volume.shape
        pd, ph, pw = self.patch_size
        
        # Ensure we don't go out of bounds
        d_start = random.randint(0, max(0, d - pd))
        h_start = random.randint(0, max(0, h - ph))
        w_start = random.randint(0, max(0, w - pw))
        
        # Extract patch
        volume_patch = volume[:, d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]
        seg_patch = seg[d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]
        
        # Pad if necessary
        if volume_patch.shape[1:] != self.patch_size:
            pad_d = pd - volume_patch.shape[1]
            pad_h = ph - volume_patch.shape[2]
            pad_w = pw - volume_patch.shape[3]
            volume_patch = np.pad(volume_patch, ((0, 0), (0, pad_d), (0, pad_h), (0, pad_w)))
            seg_patch = np.pad(seg_patch, ((0, pad_d), (0, pad_h), (0, pad_w)))
        
        return volume_patch, seg_patch
    
    def _get_tumor_centered_patch(
        self, 
        volume: np.ndarray, 
        seg: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Extract a patch centered on tumor region for better sampling."""
        c, d, h, w = volume.shape
        pd, ph, pw = self.patch_size
        
        # Find tumor voxels
        tumor_mask = seg > 0
        if tumor_mask.sum() > 0:
            tumor_coords = np.where(tumor_mask)
            # Random tumor voxel as center
            idx = random.randint(0, len(tumor_coords[0]) - 1)
            center_d = tumor_coords[0][idx]
            center_h = tumor_coords[1][idx]
            center_w = tumor_coords[2][idx]
            
            # Calculate start positions
            d_start = max(0, min(center_d - pd // 2, d - pd))
            h_start = max(0, min(center_h - ph // 2, h - ph))
            w_start = max(0, min(center_w - pw // 2, w - pw))
        else:
            # Fallback to random patch
            return self._get_random_patch(volume, seg)
        
        # Extract patch
        volume_patch = volume[:, d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]
        seg_patch = seg[d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]
        
        # Pad if necessary
        if volume_patch.shape[1:] != self.patch_size:
            pad_d = pd - volume_patch.shape[1]
            pad_h = ph - volume_patch.shape[2]
            pad_w = pw - volume_patch.shape[3]
            volume_patch = np.pad(volume_patch, ((0, 0), (0, pad_d), (0, pad_h), (0, pad_w)))
            seg_patch = np.pad(seg_patch, ((0, pad_d), (0, pad_h), (0, pad_w)))
        
        return volume_patch, seg_patch
    
    def _augment(self, volume: np.ndarray, seg: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Apply simple augmentations."""
        # Random flips
        if random.random() > 0.5:
            volume = np.flip(volume, axis=1).copy()
            seg = np.flip(seg, axis=0).copy()
        if random.random() > 0.5:
            volume = np.flip(volume, axis=2).copy()
            seg = np.flip(seg, axis=1).copy()
        if random.random() > 0.5:
            volume = np.flip(volume, axis=3).copy()
            seg = np.flip(seg, axis=2).copy()
        
        # Random intensity scaling (for each modality independently)
        if random.random() > 0.5:
            scale = random.uniform(0.9, 1.1)
            volume = volume * scale
        
        return volume, seg
    
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Determine which subject and which sample from that subject
        subject_idx = idx // self.samples_per_volume
        subject_id = self.subject_ids[subject_idx]
        subject_dir = os.path.join(self.data_dir, subject_id)
        
        # Load all modalities
        modality_data = []
        for mod in self.modalities:
            filepath = os.path.join(subject_dir, f"{subject_id}_{mod}.nii.gz")
            data = self._load_nifti(filepath)
            data = self._normalize(data)
            modality_data.append(data)
        
        # Stack modalities: (4, D, H, W)
        volume = np.stack(modality_data, axis=0)
        
        # Load segmentation
        seg_path = os.path.join(subject_dir, f"{subject_id}_seg.nii.gz")
        seg = self._load_nifti(seg_path)
        seg = self._convert_labels(seg)
        
        # Extract patch (alternate between random and tumor-centered)
        if self.is_training:
            if random.random() > 0.5:
                volume_patch, seg_patch = self._get_tumor_centered_patch(volume, seg)
            else:
                volume_patch, seg_patch = self._get_random_patch(volume, seg)
        else:
            # For validation, use center patch
            volume_patch, seg_patch = self._get_tumor_centered_patch(volume, seg)
        
        # Apply augmentations
        if self.augment:
            volume_patch, seg_patch = self._augment(volume_patch, seg_patch)
        
        # Convert to tensors
        volume_tensor = torch.from_numpy(volume_patch.copy()).float()
        seg_tensor = torch.from_numpy(seg_patch.copy()).long()
        
        return volume_tensor, seg_tensor

print("BraTS2021Dataset class defined successfully!")

BraTS2021Dataset class defined successfully!


In [37]:
# ============== Configuration ==============

class Config:
    # Data paths
    DATA_DIR = final_data_path
    CHECKPOINT_DIR = "checkpoints"
    
    # Model configuration
    MODEL_SIZE = 'B'          # 'S' (Small), 'B' (Base), 'M' (Medium), 'L' (Large)
    KERNEL_SIZE = 3           # 3 or 5
    IN_CHANNELS = 4           # 4 MRI modalities (T1, T1ce, T2, FLAIR)
    NUM_CLASSES = 4           # Background + 3 tumor regions
    DEEP_SUPERVISION = True   # Use deep supervision for training
    ACCUMULATION_STEPS = 1    # Gradient accumulation steps
    
    # Training configuration
    BATCH_SIZE = 1            # Due to 3D volumes, usually batch size 1-2
    NUM_EPOCHS = 20
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    
    # Patch-based training (for memory efficiency)
    PATCH_SIZE = (128, 128, 128)  # Size of patches to extract from volumes
    SAMPLES_PER_VOLUME = 2        # Number of patches per volume per epoch
    
    # Data split
    TRAIN_RATIO = 0.8
    VAL_RATIO = 0.2
    
    # Mixed precision training
    USE_AMP = True
    
    # Random seed for reproducibility
    SEED = 42
    
    # Device
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create checkpoint directory
os.makedirs(Config.CHECKPOINT_DIR, exist_ok=True)

# Set random seeds
torch.manual_seed(Config.SEED)
np.random.seed(Config.SEED)
random.seed(Config.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(Config.SEED)

print(f"Using device: {Config.DEVICE}")
print(f"Model size: {Config.MODEL_SIZE}, Kernel size: {Config.KERNEL_SIZE}x{Config.KERNEL_SIZE}x{Config.KERNEL_SIZE}")

Using device: cuda
Model size: B, Kernel size: 3x3x3


In [21]:
# Get all subject IDs from the data directory
all_subjects = sorted([
    d for d in os.listdir(Config.DATA_DIR) 
    if os.path.isdir(os.path.join(Config.DATA_DIR, d)) and d.startswith('BraTS2021_')
])

print(f"Total subjects found: {len(all_subjects)}")

# Split into train and validation
random.shuffle(all_subjects)
split_idx = int(len(all_subjects) * Config.TRAIN_RATIO)
train_subjects = all_subjects[:split_idx]
val_subjects = all_subjects[split_idx:]

print(f"Training subjects: {len(train_subjects)}")
print(f"Validation subjects: {len(val_subjects)}")
print(f"\nFirst few training subjects: {train_subjects[:5]}")
print(f"First few validation subjects: {val_subjects[:5]}")

Total subjects found: 1251
Training subjects: 1000
Validation subjects: 251

First few training subjects: ['BraTS2021_01417', 'BraTS2021_01632', 'BraTS2021_01664', 'BraTS2021_01305', 'BraTS2021_01339']
First few validation subjects: ['BraTS2021_00570', 'BraTS2021_00003', 'BraTS2021_00555', 'BraTS2021_00238', 'BraTS2021_00094']


In [22]:
# Create datasets and dataloaders
train_dataset = BraTS2021Dataset(
    data_dir=Config.DATA_DIR,
    subject_ids=train_subjects,
    patch_size=Config.PATCH_SIZE,
    samples_per_volume=Config.SAMPLES_PER_VOLUME,
    is_training=True,
    augment=True
)

val_dataset = BraTS2021Dataset(
    data_dir=Config.DATA_DIR,
    subject_ids=val_subjects,
    patch_size=Config.PATCH_SIZE,
    samples_per_volume=1,  # One sample per volume for validation
    is_training=False,
    augment=False
)

train_loader = DataLoader(
    train_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

print(f"Training samples per epoch: {len(train_dataset)}")
print(f"Validation samples per epoch: {len(val_dataset)}")
print(f"Training batches per epoch: {len(train_loader)}")
print(f"Validation batches per epoch: {len(val_loader)}")

Training samples per epoch: 2000
Validation samples per epoch: 251
Training batches per epoch: 2000
Validation batches per epoch: 251


In [24]:
# Verify a sample batch
sample_volume, sample_seg = next(iter(train_loader))
print(f"Volume shape: {sample_volume.shape}")  # Expected: (B, 4, 128, 128, 128)
print(f"Segmentation shape: {sample_seg.shape}")  # Expected: (B, 128, 128, 128)
print(f"Volume dtype: {sample_volume.dtype}")
print(f"Segmentation dtype: {sample_seg.dtype}")
print(f"Unique segmentation values: {torch.unique(sample_seg).tolist()}")

Volume shape: torch.Size([1, 4, 128, 128, 128])
Segmentation shape: torch.Size([1, 128, 128, 128])
Volume dtype: torch.float32
Segmentation dtype: torch.int64
Unique segmentation values: [0, 1, 2, 3]


In [25]:
class DiceLoss(nn.Module):
    """
    Soft Dice Loss for multi-class segmentation.
    """
    def __init__(self, smooth: float = 1e-5, include_background: bool = False):
        super().__init__()
        self.smooth = smooth
        self.include_background = include_background
    
    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, D, H, W) - softmax probabilities
            targets: (B, D, H, W) - integer class labels
        """
        num_classes = predictions.shape[1]
        
        # Convert targets to one-hot
        targets_one_hot = F.one_hot(targets, num_classes)  # (B, D, H, W, C)
        targets_one_hot = targets_one_hot.permute(0, 4, 1, 2, 3).float()  # (B, C, D, H, W)
        
        # Skip background class if specified
        start_idx = 0 if self.include_background else 1
        
        dice_scores = []
        for c in range(start_idx, num_classes):
            pred_c = predictions[:, c]
            target_c = targets_one_hot[:, c]
            
            intersection = (pred_c * target_c).sum()
            union = pred_c.sum() + target_c.sum()
            
            dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
            dice_scores.append(dice)
        
        # Average Dice across classes
        mean_dice = torch.stack(dice_scores).mean()
        return 1.0 - mean_dice


class CombinedLoss(nn.Module):
    """
    Combined Dice Loss + Cross-Entropy Loss for segmentation.
    """
    def __init__(
        self, 
        dice_weight: float = 1.0, 
        ce_weight: float = 1.0,
        include_background: bool = False
    ):
        super().__init__()
        self.dice_weight = dice_weight
        self.ce_weight = ce_weight
        self.dice_loss = DiceLoss(include_background=include_background)
        self.ce_loss = nn.CrossEntropyLoss()
    
    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Args:
            predictions: (B, C, D, H, W) - logits
            targets: (B, D, H, W) - integer class labels
        """
        # Softmax for Dice loss
        probs = F.softmax(predictions, dim=1)
        dice_loss = self.dice_loss(probs, targets)
        
        # Cross-entropy loss
        ce_loss = self.ce_loss(predictions, targets)
        
        return self.dice_weight * dice_loss + self.ce_weight * ce_loss


class DeepSupervisionLoss(nn.Module):
    """
    Loss function for deep supervision with multiple output scales.
    """
    def __init__(self, base_loss: nn.Module, weights: List[float] = None):
        super().__init__()
        self.base_loss = base_loss
        self.weights = weights
    
    def forward(
        self, 
        predictions: List[torch.Tensor], 
        targets: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            predictions: List of (B, C, D, H, W) tensors at different scales
            targets: (B, D, H, W) - integer class labels at full resolution
        """
        if not isinstance(predictions, (list, tuple)):
            # Single output (no deep supervision)
            return self.base_loss(predictions, targets)
        
        n_outputs = len(predictions)
        if self.weights is None:
            # Default: exponentially decreasing weights
            self.weights = [1.0 / (2 ** i) for i in range(n_outputs)]
            total = sum(self.weights)
            self.weights = [w / total for w in self.weights]
        
        total_loss = 0.0
        for i, pred in enumerate(predictions):
            # Resize target to match prediction size
            if pred.shape[2:] != targets.shape[1:]:
                target_resized = F.interpolate(
                    targets.unsqueeze(1).float(), 
                    size=pred.shape[2:], 
                    mode='nearest'
                ).squeeze(1).long()
            else:
                target_resized = targets
            
            total_loss += self.weights[i] * self.base_loss(pred, target_resized)
        
        return total_loss


print("Loss functions defined successfully!")

Loss functions defined successfully!


In [26]:
def compute_dice_score(
    predictions: torch.Tensor, 
    targets: torch.Tensor, 
    num_classes: int = 4,
    include_background: bool = False
) -> Dict[str, float]:
    """
    Compute Dice score for each class.
    
    Args:
        predictions: (B, C, D, H, W) - softmax probabilities or logits
        targets: (B, D, H, W) - integer class labels
    
    Returns:
        Dictionary with Dice scores for each class
    """
    if predictions.shape[1] != num_classes:
        raise ValueError(f"Expected {num_classes} classes, got {predictions.shape[1]}")
    
    # Get predictions
    pred_classes = predictions.argmax(dim=1)  # (B, D, H, W)
    
    dice_scores = {}
    class_names = ['Background', 'NCR/NET', 'Edema', 'ET']
    
    start_idx = 0 if include_background else 1
    
    for c in range(start_idx, num_classes):
        pred_c = (pred_classes == c).float()
        target_c = (targets == c).float()
        
        intersection = (pred_c * target_c).sum()
        union = pred_c.sum() + target_c.sum()
        
        if union > 0:
            dice = (2.0 * intersection) / union
        else:
            dice = torch.tensor(1.0) if intersection == 0 else torch.tensor(0.0)
        
        dice_scores[class_names[c]] = dice.item()
    
    # Compute mean Dice (excluding background)
    tumor_dices = [dice_scores[name] for name in class_names[1:]]
    dice_scores['Mean'] = np.mean(tumor_dices)
    
    # BraTS-specific regions
    # Whole Tumor (WT): all tumor classes (1, 2, 3)
    pred_wt = (pred_classes > 0).float()
    target_wt = (targets > 0).float()
    wt_inter = (pred_wt * target_wt).sum()
    wt_union = pred_wt.sum() + target_wt.sum()
    dice_scores['WT'] = (2.0 * wt_inter / (wt_union + 1e-8)).item()
    
    # Tumor Core (TC): NCR/NET + ET (classes 1 and 3)
    pred_tc = ((pred_classes == 1) | (pred_classes == 3)).float()
    target_tc = ((targets == 1) | (targets == 3)).float()
    tc_inter = (pred_tc * target_tc).sum()
    tc_union = pred_tc.sum() + target_tc.sum()
    dice_scores['TC'] = (2.0 * tc_inter / (tc_union + 1e-8)).item()
    
    # Enhancing Tumor (ET): class 3
    pred_et = (pred_classes == 3).float()
    target_et = (targets == 3).float()
    et_inter = (pred_et * target_et).sum()
    et_union = pred_et.sum() + target_et.sum()
    dice_scores['ET'] = (2.0 * et_inter / (et_union + 1e-8)).item()
    
    return dice_scores

print("Evaluation metrics defined successfully!")

Evaluation metrics defined successfully!


In [27]:
# Create MedNeXt model
model = create_mednext_v1(
    num_input_channels=Config.IN_CHANNELS,
    num_classes=Config.NUM_CLASSES,
    model_id=Config.MODEL_SIZE,
    kernel_size=Config.KERNEL_SIZE,
    deep_supervision=Config.DEEP_SUPERVISION
)

# Move model to GPU
model = model.to(Config.DEVICE)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Model: MedNeXt-{Config.MODEL_SIZE}")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Deep supervision: {Config.DEEP_SUPERVISION}")
print(f"Device: {Config.DEVICE}")

Model: MedNeXt-B
Total parameters: 10,530,325
Trainable parameters: 10,530,325
Deep supervision: True
Device: cuda


In [28]:
# Verify model with a forward pass
with torch.no_grad():
    test_input = torch.randn(1, 4, 128, 128, 128).to(Config.DEVICE)
    test_output = model(test_input)
    
    if isinstance(test_output, (list, tuple)):
        print(f"Deep supervision outputs: {len(test_output)}")
        for i, out in enumerate(test_output):
            print(f"  Output {i}: {out.shape}")
    else:
        print(f"Output shape: {test_output.shape}")

del test_input, test_output
torch.cuda.empty_cache()

Deep supervision outputs: 5
  Output 0: torch.Size([1, 4, 128, 128, 128])
  Output 1: torch.Size([1, 4, 64, 64, 64])
  Output 2: torch.Size([1, 4, 32, 32, 32])
  Output 3: torch.Size([1, 4, 16, 16, 16])
  Output 4: torch.Size([1, 4, 8, 8, 8])


In [29]:
# Loss function
base_loss = CombinedLoss(dice_weight=1.0, ce_weight=1.0, include_background=False)

if Config.DEEP_SUPERVISION:
    criterion = DeepSupervisionLoss(base_loss)
else:
    criterion = base_loss

# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=Config.LEARNING_RATE,
    weight_decay=Config.WEIGHT_DECAY
)

# Learning rate scheduler (Cosine Annealing with Warm Restarts)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,  # Restart every 10 epochs
    T_mult=2,  # Double the restart period each time
    eta_min=1e-6
)

# Mixed precision scaler
scaler = GradScaler(enabled=Config.USE_AMP)

print("Training setup complete!")
print(f"Optimizer: AdamW (lr={Config.LEARNING_RATE}, weight_decay={Config.WEIGHT_DECAY})")
print(f"Scheduler: CosineAnnealingWarmRestarts")
print(f"Mixed Precision: {Config.USE_AMP}")

Training setup complete!
Optimizer: AdamW (lr=0.0001, weight_decay=1e-05)
Scheduler: CosineAnnealingWarmRestarts
Mixed Precision: True


  scaler = GradScaler(enabled=Config.USE_AMP)


In [30]:
def train_one_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    scaler: GradScaler,
    device: torch.device,
    epoch: int
) -> Dict[str, float]:
    """Train for one epoch."""
    model.train()
    
    total_loss = 0.0
    all_dice_scores = []
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")
    
    for batch_idx, (volumes, targets) in enumerate(pbar):
        volumes = volumes.to(device)
        targets = targets.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass with mixed precision
        with autocast(enabled=Config.USE_AMP):
            outputs = model(volumes)
            loss = criterion(outputs, targets)
        
        # Backward pass
        scaler.scale(loss).backward()
        
        # Gradient clipping
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        
        # Compute Dice scores (use first output if deep supervision)
        with torch.no_grad():
            if isinstance(outputs, (list, tuple)):
                pred = outputs[0]
            else:
                pred = outputs
            dice = compute_dice_score(pred, targets, num_classes=Config.NUM_CLASSES)
            all_dice_scores.append(dice)
        
        # Update progress bar
        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'dice_mean': f"{dice['Mean']:.4f}"
        })
    
    # Compute average metrics
    avg_loss = total_loss / len(train_loader)
    avg_dice = {
        key: np.mean([d[key] for d in all_dice_scores]) 
        for key in all_dice_scores[0].keys()
    }
    
    return {'loss': avg_loss, **avg_dice}


@torch.no_grad()
def validate(
    model: nn.Module,
    val_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    epoch: int
) -> Dict[str, float]:
    """Validate the model."""
    model.eval()
    
    total_loss = 0.0
    all_dice_scores = []
    
    pbar = tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]")
    
    for volumes, targets in pbar:
        volumes = volumes.to(device)
        targets = targets.to(device)
        
        with autocast(enabled=Config.USE_AMP):
            outputs = model(volumes)
            loss = criterion(outputs, targets)
        
        total_loss += loss.item()
        
        # Compute Dice scores
        if isinstance(outputs, (list, tuple)):
            pred = outputs[0]
        else:
            pred = outputs
        dice = compute_dice_score(pred, targets, num_classes=Config.NUM_CLASSES)
        all_dice_scores.append(dice)
        
        pbar.set_postfix({
            'loss': f"{loss.item():.4f}",
            'dice_mean': f"{dice['Mean']:.4f}"
        })
    
    # Compute average metrics
    avg_loss = total_loss / len(val_loader)
    avg_dice = {
        key: np.mean([d[key] for d in all_dice_scores]) 
        for key in all_dice_scores[0].keys()
    }
    
    return {'loss': avg_loss, **avg_dice}


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    scaler: GradScaler,
    epoch: int,
    metrics: Dict[str, float],
    filepath: str
):
    """Save a training checkpoint."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'metrics': metrics,
        'config': {
            'model_size': Config.MODEL_SIZE,
            'kernel_size': Config.KERNEL_SIZE,
            'in_channels': Config.IN_CHANNELS,
            'num_classes': Config.NUM_CLASSES,
            'deep_supervision': Config.DEEP_SUPERVISION,
            'patch_size': Config.PATCH_SIZE
        }
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved: {filepath}")


def load_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler,
    scaler: GradScaler,
    filepath: str
) -> int:
    """Load a training checkpoint."""
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    print(f"Checkpoint loaded: {filepath}")
    print(f"Resuming from epoch {checkpoint['epoch'] + 1}")
    return checkpoint['epoch']


print("Training functions defined successfully!")

Training functions defined successfully!


In [31]:
# Training history
history = {
    'train_loss': [],
    'val_loss': [],
    'train_dice_mean': [],
    'val_dice_mean': [],
    'train_dice_wt': [],
    'val_dice_wt': [],
    'train_dice_tc': [],
    'val_dice_tc': [],
    'train_dice_et': [],
    'val_dice_et': [],
    'lr': []
}

best_val_dice = 0.0
patience = 20
patience_counter = 0
start_epoch = 0

# Optional: Resume from checkpoint
resume_training = False

if resume_training and os.path.exists(checkpoint_path):
    start_epoch = load_checkpoint(model, optimizer, scheduler, scaler, checkpoint_path) + 1

In [None]:
# Main training loop
print("=" * 60)
print("Starting Training")
print("=" * 60)
print(f"Improvements enabled:")
print(f"  ✓ Region-based loss (WT, TC, ET)")
print(f"  ✓ Gradient accumulation (effective batch size: {Config.BATCH_SIZE * Config.ACCUMULATION_STEPS})")
print(f"  ✓ TorchIO augmentations")
print(f"  ✓ Kernel size: {Config.KERNEL_SIZE}x{Config.KERNEL_SIZE}x{Config.KERNEL_SIZE}")
print("=" * 60)

for epoch in range(start_epoch, Config.NUM_EPOCHS):
    # Get current learning rate
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\nEpoch {epoch + 1}/{Config.NUM_EPOCHS} | LR: {current_lr:.2e}")
    print("-" * 40)
    
    # Train with gradient accumulation
    train_metrics = train_one_epoch(
        model, train_loader, criterion, optimizer, scaler, Config.DEVICE, epoch
    )
    
    # Validate
    val_metrics = validate(
        model, val_loader, criterion, Config.DEVICE, epoch
    )
    
    # Update scheduler
    scheduler.step()
    
    # Log metrics
    history['train_loss'].append(train_metrics['loss'])
    history['val_loss'].append(val_metrics['loss'])
    history['train_dice_mean'].append(train_metrics['Mean'])
    history['val_dice_mean'].append(val_metrics['Mean'])
    history['train_dice_wt'].append(train_metrics['WT'])
    history['val_dice_wt'].append(val_metrics['WT'])
    history['train_dice_tc'].append(train_metrics['TC'])
    history['val_dice_tc'].append(val_metrics['TC'])
    history['train_dice_et'].append(train_metrics['ET'])
    history['val_dice_et'].append(val_metrics['ET'])
    history['lr'].append(current_lr)
    
    # Print epoch summary
    print(f"\nEpoch {epoch + 1} Summary:")
    print(f"  Train Loss: {train_metrics['loss']:.4f} | Val Loss: {val_metrics['loss']:.4f}")
    print(f"  Train Dice (Mean): {train_metrics['Mean']:.4f} | Val Dice (Mean): {val_metrics['Mean']:.4f}")
    print(f"  Val Dice - WT: {val_metrics['WT']:.4f}, TC: {val_metrics['TC']:.4f}, ET: {val_metrics['ET']:.4f}")
    
    # Save latest checkpoint
    save_checkpoint(
        model, optimizer, scheduler, scaler, epoch, val_metrics,
        os.path.join(Config.CHECKPOINT_DIR, "latest_checkpoint.pt")
    )
    
    # Save best model
    if val_metrics['Mean'] > best_val_dice:
        best_val_dice = val_metrics['Mean']
        patience_counter = 0
        save_checkpoint(
            model, optimizer, scheduler, scaler, epoch, val_metrics,
            os.path.join(Config.CHECKPOINT_DIR, "best_model.pt")
        )
        print(f"  *** New best model! Val Dice Mean: {best_val_dice:.4f} ***")
    else:
        patience_counter += 1
        print(f"  No improvement for {patience_counter}/{patience} epochs")
    
    # Early stopping
    if patience_counter >= patience:
        print(f"\nEarly stopping triggered after {epoch + 1} epochs!")
        break
    
    # Save periodic checkpoints
    if (epoch + 1) % 10 == 0:
        save_checkpoint(
            model, optimizer, scheduler, scaler, epoch, val_metrics,
            os.path.join(Config.CHECKPOINT_DIR, f"checkpoint_epoch_{epoch+1}.pt")
        )

print("\n" + "=" * 60)
print("Training Complete!")
print(f"Best Validation Dice Mean: {best_val_dice:.4f}")
print("=" * 60)


# ============== Save Best Model for Deployment ==============
best_model_path = os.path.join(Config.CHECKPOINT_DIR, "best_model.pt")
if os.path.exists(best_model_path):
    # Load the best model
    best_checkpoint = torch.load(best_model_path)
    model.load_state_dict(best_checkpoint['model_state_dict'])
    
    # Save model weights only (for easy deployment/inference elsewhere)
    deployment_path = os.path.join(Config.CHECKPOINT_DIR, "mednext_brats2021_weights.pt")
    torch.save({
        'model_state_dict': model.state_dict(),
        'metrics': best_checkpoint['metrics'],
        'epoch': best_checkpoint['epoch'],
        'config': {
            'model_size': Config.MODEL_SIZE,
            'kernel_size': Config.KERNEL_SIZE,
            'in_channels': Config.IN_CHANNELS,
            'num_classes': Config.NUM_CLASSES,
            'deep_supervision': Config.DEEP_SUPERVISION,
            'patch_size': Config.PATCH_SIZE
        }
    }, deployment_path)
    
    print(f"\n✓ Best model saved for deployment: {deployment_path}")
    print(f"  Best epoch: {best_checkpoint['epoch'] + 1}")
    print(f"  Val Dice Mean: {best_checkpoint['metrics']['Mean']:.4f}")
else:
    print("\nWarning: No best model checkpoint found.")

Starting Training
Improvements enabled:
  ✓ Region-based loss (WT, TC, ET)
  ✓ Gradient accumulation (effective batch size: 1)
  ✓ TorchIO augmentations
  ✓ Kernel size: 3x3x3

Epoch 1/20 | LR: 1.00e-04
----------------------------------------


  with autocast(enabled=Config.USE_AMP):
Epoch 1 [Train]:  13%|█▎        | 256/2000 [03:31<23:32,  1.23it/s, loss=0.4500, dice_mean=0.5893]

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss curves
ax1 = axes[0, 0]
ax1.plot(history['train_loss'], label='Train Loss', color='blue')
ax1.plot(history['val_loss'], label='Val Loss', color='red')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)

# Mean Dice curves
ax2 = axes[0, 1]
ax2.plot(history['train_dice_mean'], label='Train Dice', color='blue')
ax2.plot(history['val_dice_mean'], label='Val Dice', color='red')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Dice Score')
ax2.set_title('Mean Dice Score')
ax2.legend()
ax2.grid(True)

# BraTS region Dice curves (Validation)
ax3 = axes[1, 0]
ax3.plot(history['val_dice_wt'], label='Whole Tumor (WT)', color='green')
ax3.plot(history['val_dice_tc'], label='Tumor Core (TC)', color='orange')
ax3.plot(history['val_dice_et'], label='Enhancing Tumor (ET)', color='purple')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Dice Score')
ax3.set_title('Validation Dice by Region')
ax3.legend()
ax3.grid(True)

# Learning rate
ax4 = axes[1, 1]
ax4.plot(history['lr'], label='Learning Rate', color='green')
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Learning Rate')
ax4.set_title('Learning Rate Schedule')
ax4.set_yscale('log')
ax4.grid(True)

plt.tight_layout()
plt.savefig(os.path.join(Config.CHECKPOINT_DIR, 'training_curves.png'), dpi=150)
plt.show()

print(f"\nTraining curves saved to {Config.CHECKPOINT_DIR}/training_curves.png")

In [None]:
# ============== INFERENCE CONFIGURATION (STANDALONE) ==============
# This section is completely independent and can run without the training code above

import os
import sys
import numpy as np
import pandas as pd
import nibabel as nib
from tqdm import tqdm
from typing import List, Tuple, Dict
import random
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Add MedNeXt to path (adjust if needed)
sys.path.insert(0, 'mednext')
from nnunet_mednext import create_mednext_v1, MedNeXt

# ============== CHANGE THESE PATHS ==============
MODEL_PATH = "/kaggle/working/checkpoints/best_model.pt"  # Path to your trained model
DATASET_PATH = "/kaggle/input/instant-odc-ai-hackathon/test"  # Path to dataset for inference
OUTPUT_CSV_PATH = "submission.csv"  # Path to save RLE results

# Number of samples to VISUALIZE (visualization only, not inference)
NUM_SAMPLES_TO_VISUALIZE = 5

# ============== MODEL CONFIGURATION ==============
class InferenceConfig:
    # Model configuration (must match training)
    MODEL_SIZE = 'B'          # 'S' (Small), 'B' (Base), 'M' (Medium), 'L' (Large)
    KERNEL_SIZE = 3           # 3 or 5
    IN_CHANNELS = 4           # 4 MRI modalities (T1, T1ce, T2, FLAIR)
    NUM_CLASSES = 4           # Background + 3 tumor regions
    
    # Inference settings
    PATCH_SIZE = (128, 128, 128)
    USE_AMP = True
    
    # Device
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Model path: {MODEL_PATH}")
print(f"Dataset path: {DATASET_PATH}")
print(f"Output CSV: {OUTPUT_CSV_PATH}")
print(f"Device: {InferenceConfig.DEVICE}")
print(f"CUDA Available: {torch.cuda.is_available()}")

In [None]:
# ============== RLE ENCODING FUNCTIONS ==============

def rle_encode(mask):
    """
    Run-Length Encoding for binary segmentation masks.
    Produces format: "start1 length1 start2 length2 ..."
    
    Args:
        mask: Binary segmentation mask (numpy array), 1 for foreground
    
    Returns:
        RLE encoded string in format "start length start length ..."
    """
    # Flatten in column-major (Fortran) order - standard for RLE competitions
    pixels = mask.flatten(order='F')
    
    # Pad with zeros to detect transitions at boundaries
    pixels = np.concatenate([[0], pixels, [0]])
    
    # Find transition points
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1  # 1-indexed
    
    # Extract start positions and end positions
    runs_starts = runs[0::2]  # Start of each run
    runs_ends = runs[1::2]    # End of each run
    
    # Calculate lengths
    runs_lengths = runs_ends - runs_starts
    
    # Build output string
    rle_pairs = []
    for start, length in zip(runs_starts, runs_lengths):
        rle_pairs.extend([start, length])
    
    return ' '.join(str(x) for x in rle_pairs)


def rle_encode_tumor(mask):
    """
    RLE encode all tumor regions (non-background) as a single binary mask.
    
    Args:
        mask: Multi-class segmentation mask (numpy array)
              0 = background, 1/2/3 = tumor classes
    
    Returns:
        RLE encoded string for all tumor regions combined
    """
    # Combine all tumor classes into single binary mask
    tumor_mask = (mask > 0).astype(np.uint8)
    return rle_encode(tumor_mask)


print("✅ RLE encoding functions defined!")

In [None]:
# ============== LOAD MODEL FOR INFERENCE ==============

def load_model_for_inference(model_path, config):
    """
    Load the trained MedNeXt model for inference.
    
    Args:
        model_path: Path to the saved model checkpoint
        config: InferenceConfig object with model settings
    
    Returns:
        Loaded model in evaluation mode
    """
    # Load checkpoint first to check if it was trained with deep supervision
    # Using weights_only=False for PyTorch 2.6+ compatibility
    checkpoint = torch.load(model_path, map_location=config.DEVICE, weights_only=False)
    
    # Check if model was trained with deep supervision by looking for deep supervision keys
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, dict):
        state_dict = checkpoint
    else:
        state_dict = checkpoint
    
    # Check if deep supervision was used during training
    has_deep_supervision = any('out_1' in key or 'out_2' in key or 'out_3' in key or 'out_4' in key for key in state_dict.keys())
    
    print(f"Checkpoint trained with deep supervision: {has_deep_supervision}")
    
    # Create model with same configuration as training
    # Match the deep_supervision setting from training
    model = create_mednext_v1(
        num_input_channels=config.IN_CHANNELS,
        num_classes=config.NUM_CLASSES,
        model_id=config.MODEL_SIZE,
        kernel_size=config.KERNEL_SIZE,
        deep_supervision=has_deep_supervision  # Match training setting
    )
    
    # Load the state dict
    model.load_state_dict(state_dict)
    
    if isinstance(checkpoint, dict):
        if 'epoch' in checkpoint:
            print(f"Loaded model from epoch {checkpoint.get('epoch', 'N/A')}")
        if 'best_val_dice' in checkpoint:
            print(f"Best validation Dice: {checkpoint.get('best_val_dice'):.4f}")
        if 'metrics' in checkpoint:
            print(f"Checkpoint metrics: {checkpoint['metrics']}")
    
    model = model.to(config.DEVICE)
    model.eval()
    
    print(f"Model loaded successfully from: {model_path}")
    return model


# Load the model
inference_model = load_model_for_inference(MODEL_PATH, InferenceConfig)
print(f"\nModel is on device: {next(inference_model.parameters()).device}")

In [None]:
# ============== INFERENCE DATASET ==============

class InferenceDataset(Dataset):
    """
    Dataset for inference - loads full volumes without patch extraction.
    Completely standalone - no dependencies on training code.
    Auto-detects file naming patterns for flexibility.
    """
    
    def __init__(self, data_dir: str, subject_ids: List[str] = None):
        self.data_dir = data_dir
        
        # Possible modality suffixes (will auto-detect)
        self.modality_patterns = [
            ['t1', 't1ce', 't2', 'flair'],           # lowercase with underscore
            ['T1', 'T1ce', 'T2', 'FLAIR'],           # uppercase
            ['T1', 'T1CE', 'T2', 'FLAIR'],           # uppercase CE
            ['t1', 't1Gd', 't2', 'flair'],           # alternative Gd naming
        ]
        
        # Get subject IDs if not provided
        if subject_ids is None:
            # Try to find subject folders with different patterns
            if os.path.exists(data_dir):
                all_items = os.listdir(data_dir)
                self.subject_ids = sorted([
                    d for d in all_items 
                    if os.path.isdir(os.path.join(data_dir, d))
                ])
            else:
                self.subject_ids = []
        else:
            self.subject_ids = subject_ids
        
        print(f"Found {len(self.subject_ids)} subjects in dataset")
        if len(self.subject_ids) > 0:
            print(f"Sample subjects: {self.subject_ids[:3]}")
            # Show files in first subject folder for debugging
            first_subject_dir = os.path.join(data_dir, self.subject_ids[0])
            if os.path.exists(first_subject_dir):
                files = os.listdir(first_subject_dir)
                print(f"Files in first subject folder: {files}")
        
        # Detect the modality pattern from the first subject
        self.modalities = self._detect_modality_pattern()
        print(f"Detected modalities: {self.modalities}")
    
    def _detect_modality_pattern(self):
        """Auto-detect the modality file naming pattern."""
        if len(self.subject_ids) == 0:
            return ['t1', 't1ce', 't2', 'flair']  # default
        
        subject_id = self.subject_ids[0]
        subject_dir = os.path.join(self.data_dir, subject_id)
        
        if not os.path.exists(subject_dir):
            return ['t1', 't1ce', 't2', 'flair']  # default
        
        files = os.listdir(subject_dir)
        
        # Try each pattern
        for pattern in self.modality_patterns:
            # Check with underscore separator
            found_all = True
            for mod in pattern:
                # Try different separators and extensions
                possible_names = [
                    f"{subject_id}_{mod}.nii.gz",
                    f"{subject_id}-{mod}.nii.gz",
                    f"{subject_id}_{mod}.nii",
                    f"{mod}.nii.gz",
                    f"{mod}.nii",
                ]
                if not any(name in files for name in possible_names):
                    found_all = False
                    break
            
            if found_all:
                return pattern
        
        # If no pattern matched, try to detect from actual files
        detected = []
        modality_keywords = {
            't1ce': ['t1ce', 't1gd', 't1Gd', 'T1CE', 'T1Gd', 'T1ce'],
            't1': ['t1', 'T1'],
            't2': ['t2', 'T2'],
            'flair': ['flair', 'FLAIR', 'Flair'],
        }
        
        for mod_key, keywords in modality_keywords.items():
            for f in files:
                f_lower = f.lower()
                if any(kw.lower() in f_lower for kw in keywords):
                    # Extract the actual modality name from file
                    for kw in keywords:
                        if kw in f:
                            detected.append(kw)
                            break
                    break
        
        if len(detected) == 4:
            return detected
        
        return ['t1', 't1ce', 't2', 'flair']  # default fallback
    
    def __len__(self):
        return len(self.subject_ids)
    
    def _find_modality_file(self, subject_dir: str, subject_id: str, modality: str) -> str:
        """Find the file for a given modality, trying different naming conventions."""
        files = os.listdir(subject_dir)
        
        # Try different naming patterns
        patterns = [
            f"{subject_id}_{modality}.nii.gz",
            f"{subject_id}-{modality}.nii.gz",
            f"{subject_id}_{modality}.nii",
            f"{modality}.nii.gz",
            f"{modality}.nii",
        ]
        
        for pattern in patterns:
            if pattern in files:
                return os.path.join(subject_dir, pattern)
        
        # Try case-insensitive search
        modality_lower = modality.lower()
        for f in files:
            if modality_lower in f.lower() and ('nii' in f.lower()):
                # Make sure it's the right modality (not t1 matching t1ce)
                if modality_lower == 't1':
                    if 't1ce' not in f.lower() and 't1gd' not in f.lower():
                        return os.path.join(subject_dir, f)
                else:
                    return os.path.join(subject_dir, f)
        
        raise FileNotFoundError(f"Could not find {modality} file in {subject_dir}. Available files: {files}")
    
    def _load_nifti(self, filepath: str) -> np.ndarray:
        """Load a NIfTI file and return the data array."""
        img = nib.load(filepath)
        return img.get_fdata().astype(np.float32)
    
    def _normalize(self, data: np.ndarray) -> np.ndarray:
        """Z-score normalization per volume (non-zero voxels only)."""
        mask = data > 0
        if mask.sum() > 0:
            mean = data[mask].mean()
            std = data[mask].std()
            if std > 0:
                data = (data - mean) / std
                data[~mask] = 0
        return data
    
    def __getitem__(self, idx: int):
        subject_id = self.subject_ids[idx]
        subject_dir = os.path.join(self.data_dir, subject_id)
        
        # Load all modalities
        modality_data = []
        for mod in self.modalities:
            filepath = self._find_modality_file(subject_dir, subject_id, mod)
            data = self._load_nifti(filepath)
            data = self._normalize(data)
            modality_data.append(data)
        
        # Stack modalities: (4, D, H, W)
        volume = np.stack(modality_data, axis=0)
        
        # Load ground truth segmentation if available
        seg = None
        seg_patterns = [
            f"{subject_id}_seg.nii.gz",
            f"{subject_id}-seg.nii.gz",
            f"seg.nii.gz",
            f"{subject_id}_mask.nii.gz",
        ]
        
        files = os.listdir(subject_dir)
        for pattern in seg_patterns:
            if pattern in files:
                seg_path = os.path.join(subject_dir, pattern)
                seg = self._load_nifti(seg_path)
                # Convert labels: 0, 1, 2, 4 -> 0, 1, 2, 3
                new_seg = np.zeros_like(seg)
                new_seg[seg == 0] = 0
                new_seg[seg == 1] = 1
                new_seg[seg == 2] = 2
                new_seg[seg == 4] = 3
                seg = new_seg
                break
        
        volume_tensor = torch.from_numpy(volume.copy()).float()
        
        return {
            'subject_id': subject_id,
            'volume': volume_tensor,
            'ground_truth': seg,
            'original_shape': volume.shape[1:]
        }


# Create inference dataset
inference_dataset = InferenceDataset(DATASET_PATH)
print(f"\nInference dataset created with {len(inference_dataset)} subjects")

In [None]:
# ============== SLIDING WINDOW INFERENCE ==============

def sliding_window_inference(model, volume, config, patch_size=(128, 128, 128), overlap=0.5):
    """
    Perform sliding window inference for large 3D volumes.
    
    Args:
        model: Trained model
        volume: Input volume tensor (C, D, H, W)
        config: InferenceConfig object
        patch_size: Size of patches for inference
        overlap: Overlap ratio between patches
    
    Returns:
        Predicted segmentation mask
    """
    model.eval()
    device = config.DEVICE
    
    C, D, H, W = volume.shape
    pd, ph, pw = patch_size
    
    # Calculate stride
    stride_d = int(pd * (1 - overlap))
    stride_h = int(ph * (1 - overlap))
    stride_w = int(pw * (1 - overlap))
    
    # Pad volume if necessary
    pad_d = max(0, pd - D)
    pad_h = max(0, ph - H)
    pad_w = max(0, pw - W)
    
    if pad_d > 0 or pad_h > 0 or pad_w > 0:
        volume = F.pad(volume, (0, pad_w, 0, pad_h, 0, pad_d))
        D, H, W = volume.shape[1:]
    
    # Initialize output and count tensors
    output = torch.zeros((config.NUM_CLASSES, D, H, W), device=device)
    count = torch.zeros((D, H, W), device=device)
    
    # Generate patch positions
    d_positions = list(range(0, max(1, D - pd + 1), stride_d))
    h_positions = list(range(0, max(1, H - ph + 1), stride_h))
    w_positions = list(range(0, max(1, W - pw + 1), stride_w))
    
    # Ensure we cover the entire volume
    if D > pd and D - pd not in d_positions:
        d_positions.append(D - pd)
    if H > ph and H - ph not in h_positions:
        h_positions.append(H - ph)
    if W > pw and W - pw not in w_positions:
        w_positions.append(W - pw)
    
    with torch.no_grad():
        for d_start in d_positions:
            for h_start in h_positions:
                for w_start in w_positions:
                    # Extract patch
                    patch = volume[:, d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw]
                    patch = patch.unsqueeze(0).to(device)  # Add batch dimension
                    
                    # Forward pass
                    with torch.cuda.amp.autocast(enabled=config.USE_AMP):
                        pred = model(patch)
                    
                    # Handle deep supervision output (returns list of tensors)
                    if isinstance(pred, (list, tuple)):
                        pred = pred[0]  # Use the first (full resolution) output
                    
                    pred = F.softmax(pred, dim=1).squeeze(0)  # Remove batch dimension
                    
                    # Accumulate predictions
                    output[:, d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw] += pred
                    count[d_start:d_start+pd, h_start:h_start+ph, w_start:w_start+pw] += 1
    
    # Average predictions
    output = output / count.unsqueeze(0).clamp(min=1)
    
    # Remove padding
    original_d = D - pad_d if pad_d > 0 else D
    original_h = H - pad_h if pad_h > 0 else H
    original_w = W - pad_w if pad_w > 0 else W
    output = output[:, :original_d, :original_h, :original_w]
    
    # Get final prediction
    prediction = torch.argmax(output, dim=0).cpu().numpy()
    
    return prediction


print("Sliding window inference function defined!")

In [None]:
# ============== RUN INFERENCE ON ALL SAMPLES ==============

def run_inference_for_submission(model, dataset, config):
    """
    Run inference on ALL samples and generate RLE encodings.
    Stores only what's needed to minimize memory usage.
    
    Args:
        model: Trained model
        dataset: Inference dataset
        config: InferenceConfig object
    
    Returns:
        List of dictionaries with subject_id and rle_encoding
    """
    results = []
    num_samples = len(dataset)
    
    print(f"Running inference on {num_samples} samples...")
    
    for idx in tqdm(range(num_samples)):
        sample = dataset[idx]
        subject_id = sample['subject_id']
        volume = sample['volume']
        
        # Run sliding window inference
        prediction = sliding_window_inference(
            model, 
            volume, 
            config,
            patch_size=config.PATCH_SIZE,
            overlap=0.5
        )
        
        # RLE encode the tumor mask (all non-background classes combined)
        rle_encoding = rle_encode_tumor(prediction)
        
        # Store result
        result = {
            'subject_id': subject_id,
            'rle_encoding': rle_encoding,
        }
        
        # Store full data only for visualization samples
        if idx < NUM_SAMPLES_TO_VISUALIZE:
            result['prediction'] = prediction
            result['volume'] = volume.numpy()
            result['ground_truth'] = sample['ground_truth']
        
        results.append(result)
        
        # Clear GPU cache periodically
        if idx % 10 == 0:
            torch.cuda.empty_cache()
        
        # Print progress
        print(f"  {subject_id}: unique labels = {np.unique(prediction)}")
    
    return results


# ============== RUN INFERENCE ==============
print("=" * 60)
print("RUNNING INFERENCE ON ENTIRE DATASET")
print("=" * 60)

inference_results = run_inference_for_submission(
    inference_model, 
    inference_dataset, 
    InferenceConfig
)

print(f"\n{'=' * 60}")
print(f"✅ Inference complete! Processed {len(inference_results)} samples")
print(f"{'=' * 60}")

In [None]:
# ============== SAVE RLE RESULTS TO CSV ==============

def save_submission(results, output_path="submission.csv"):
    """
    Save RLE encoded results to CSV file.
    Format: Id<TAB>Expected
    
    Args:
        results: List of inference results with 'subject_id' and 'rle_encoding'
        output_path: Path to save the CSV file
    """
    rows = []
    
    for result in results:
        rows.append({
            'Id': result['subject_id'],
            'Expected': result['rle_encoding']
        })
    
    df = pd.DataFrame(rows)
    df.to_csv(output_path, index=False, sep='\t')
    
    print(f"✅ Submission saved to: {output_path}")
    print(f"   Total samples: {len(df)}")
    
    return df


# ============== SAVE SUBMISSION ==============
print("=" * 60)
print("SAVING SUBMISSION FILE")
print("=" * 60)

submission_df = save_submission(inference_results, output_path=OUTPUT_CSV_PATH)

print(f"\nSubmission file: {OUTPUT_CSV_PATH}")
print(f"Shape: {submission_df.shape}")

# Show sample output
print("\n" + "=" * 60)
print("SAMPLE OUTPUT (first 3 rows, truncated RLE):")
print("=" * 60)
for idx, row in submission_df.head(3).iterrows():
    rle_preview = row['Expected'][:100] + '...' if len(row['Expected']) > 100 else row['Expected']
    print(f"{row['Id']}\t{rle_preview}")

In [None]:
# ============== VISUALIZATION FUNCTIONS ==============

def visualize_segmentation_slices(result, slice_indices=None, figsize=(20, 10)):
    """
    Visualize segmentation results for a single subject.
    Shows FLAIR image, ground truth, and prediction side by side.
    
    Args:
        result: Dictionary containing prediction, ground truth, and volume
        slice_indices: List of slice indices to visualize (axial slices)
        figsize: Figure size
    """
    subject_id = result['subject_id']
    volume = result['volume']  # (4, D, H, W)
    prediction = result['prediction']  # (D, H, W)
    ground_truth = result['ground_truth']  # (D, H, W) or None
    
    # Use FLAIR modality for visualization (index 3)
    flair = volume[3]  # (D, H, W)
    
    # Auto-select slices if not provided (pick slices with tumor)
    if slice_indices is None:
        if ground_truth is not None:
            tumor_slices = np.where(ground_truth.sum(axis=(1, 2)) > 0)[0]
        else:
            tumor_slices = np.where(prediction.sum(axis=(1, 2)) > 0)[0]
        
        if len(tumor_slices) > 0:
            # Pick 5 evenly spaced slices from tumor region
            indices = np.linspace(0, len(tumor_slices) - 1, min(5, len(tumor_slices)), dtype=int)
            slice_indices = tumor_slices[indices]
        else:
            # Fallback to middle slices
            slice_indices = np.linspace(flair.shape[0] // 4, 3 * flair.shape[0] // 4, 5, dtype=int)
    
    num_slices = len(slice_indices)
    has_gt = ground_truth is not None
    num_rows = 3 if has_gt else 2
    
    fig, axes = plt.subplots(num_rows, num_slices, figsize=figsize)
    
    # Color map for segmentation
    # 0: Background (black), 1: NCR/NET (red), 2: ED (green), 3: ET (yellow)
    colors = np.array([
        [0, 0, 0],        # Background - black
        [255, 0, 0],      # NCR/NET - red
        [0, 255, 0],      # ED - green
        [255, 255, 0]     # ET - yellow
    ]) / 255.0
    
    def apply_colormap(mask):
        """Apply custom colormap to segmentation mask."""
        colored = np.zeros((*mask.shape, 3))
        for i in range(4):
            colored[mask == i] = colors[i]
        return colored
    
    for col, slice_idx in enumerate(slice_indices):
        # FLAIR image
        axes[0, col].imshow(flair[slice_idx].T, cmap='gray', origin='lower')
        axes[0, col].set_title(f'FLAIR (Slice {slice_idx})')
        axes[0, col].axis('off')
        
        # Ground truth
        if has_gt:
            gt_colored = apply_colormap(ground_truth[slice_idx])
            axes[1, col].imshow(flair[slice_idx].T, cmap='gray', origin='lower', alpha=0.7)
            axes[1, col].imshow(gt_colored.transpose(1, 0, 2), origin='lower', alpha=0.5)
            axes[1, col].set_title('Ground Truth')
            axes[1, col].axis('off')
        
        # Prediction
        pred_row = 2 if has_gt else 1
        pred_colored = apply_colormap(prediction[slice_idx])
        axes[pred_row, col].imshow(flair[slice_idx].T, cmap='gray', origin='lower', alpha=0.7)
        axes[pred_row, col].imshow(pred_colored.transpose(1, 0, 2), origin='lower', alpha=0.5)
        axes[pred_row, col].set_title('Prediction')
        axes[pred_row, col].axis('off')
    
    # Add legend
    legend_elements = [
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='red', markersize=10, label='NCR/NET'),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='green', markersize=10, label='ED'),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='yellow', markersize=10, label='ET')
    ]
    fig.legend(handles=legend_elements, loc='upper right', fontsize=10)
    
    plt.suptitle(f'Segmentation Results: {subject_id}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


print("Visualization functions defined!")

In [None]:
# ============== VISUALIZE FIRST N SAMPLES (CONFIGURABLE) ==============

# Only visualize the first NUM_SAMPLES_TO_VISUALIZE samples
# Full data is stored for these samples, RLE is stored for ALL samples

samples_to_visualize = min(NUM_SAMPLES_TO_VISUALIZE, len(inference_results))

print("=" * 60)
print(f"SEGMENTATION VISUALIZATION ({samples_to_visualize} of {len(inference_results)} samples)")
print("=" * 60)
print("\nLegend:")
print("  🔴 NCR/NET (Necrotic and Non-Enhancing Tumor)")
print("  🟢 ED (Peritumoral Edema)")
print("  🟡 ET (GD-Enhancing Tumor)")
print("=" * 60)
print(f"\nNote: RLE encodings generated for ALL {len(inference_results)} samples")
print(f"      Visualization shown for first {samples_to_visualize} samples only")
print("=" * 60)

for i in range(samples_to_visualize):
    result = inference_results[i]
    if result['prediction'] is not None and result['volume'] is not None:
        print(f"\n[{i+1}/{samples_to_visualize}] Visualizing: {result['subject_id']}")
        visualize_segmentation_slices(result, figsize=(18, 12))
    else:
        print(f"\n[{i+1}] {result['subject_id']}: Full data not stored for this sample")

In [None]:
# ============== 3D VISUALIZATION (OPTIONAL) ==============

def visualize_3d_segmentation(result, figsize=(16, 6)):
    """
    Create a 3D-like visualization showing orthogonal views (axial, sagittal, coronal).
    
    Args:
        result: Dictionary containing prediction, ground truth, and volume
        figsize: Figure size
    """
    subject_id = result['subject_id']
    volume = result['volume']
    prediction = result['prediction']
    ground_truth = result['ground_truth']
    
    # Check if data is available
    if volume is None or prediction is None:
        print(f"Skipping {subject_id}: Full data not stored for this sample")
        return
    
    flair = volume[3]  # Use FLAIR for background
    
    # Find center of tumor for slice selection
    if ground_truth is not None:
        tumor_mask = ground_truth > 0
    else:
        tumor_mask = prediction > 0
    
    if tumor_mask.sum() > 0:
        coords = np.where(tumor_mask)
        center_d = int(np.mean(coords[0]))
        center_h = int(np.mean(coords[1]))
        center_w = int(np.mean(coords[2]))
    else:
        center_d, center_h, center_w = [s // 2 for s in flair.shape]
    
    # Color map
    colors = np.array([
        [0, 0, 0],
        [255, 0, 0],
        [0, 255, 0],
        [255, 255, 0]
    ]) / 255.0
    
    def apply_colormap(mask):
        colored = np.zeros((*mask.shape, 3))
        for i in range(4):
            colored[mask == i] = colors[i]
        return colored
    
    fig, axes = plt.subplots(2, 3, figsize=figsize)
    
    views = [
        ('Axial', flair[center_d], prediction[center_d], ground_truth[center_d] if ground_truth is not None else None),
        ('Sagittal', flair[:, :, center_w], prediction[:, :, center_w], ground_truth[:, :, center_w] if ground_truth is not None else None),
        ('Coronal', flair[:, center_h, :], prediction[:, center_h, :], ground_truth[:, center_h, :] if ground_truth is not None else None)
    ]
    
    for col, (view_name, img, pred, gt) in enumerate(views):
        # Top row: Ground Truth (or FLAIR if no GT)
        axes[0, col].imshow(img.T, cmap='gray', origin='lower', alpha=0.7)
        if gt is not None:
            gt_colored = apply_colormap(gt)
            axes[0, col].imshow(gt_colored.transpose(1, 0, 2), origin='lower', alpha=0.5)
            axes[0, col].set_title(f'{view_name} - Ground Truth')
        else:
            axes[0, col].set_title(f'{view_name} - FLAIR')
        axes[0, col].axis('off')
        
        # Bottom row: Prediction
        axes[1, col].imshow(img.T, cmap='gray', origin='lower', alpha=0.7)
        pred_colored = apply_colormap(pred)
        axes[1, col].imshow(pred_colored.transpose(1, 0, 2), origin='lower', alpha=0.5)
        axes[1, col].set_title(f'{view_name} - Prediction')
        axes[1, col].axis('off')
    
    plt.suptitle(f'Multi-View Segmentation: {subject_id}', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()


# Visualize first sample in 3D views (only if data available)
if len(inference_results) > 0:
    first_result = inference_results[0]
    if first_result['volume'] is not None and first_result['prediction'] is not None:
        print("\n3D Multi-View Visualization (First Sample):")
        visualize_3d_segmentation(first_result)
    else:
        print("First sample does not have full data stored for visualization")

In [None]:
# ============== SUMMARY STATISTICS ==============

def compute_dice_score_inference(pred, gt, num_classes=4):
    """Compute Dice score for each class."""
    dice_scores = {}
    class_names = ['Background', 'NCR/NET', 'ED', 'ET']
    
    for c in range(num_classes):
        pred_c = (pred == c).astype(float)
        gt_c = (gt == c).astype(float)
        
        intersection = (pred_c * gt_c).sum()
        union = pred_c.sum() + gt_c.sum()
        
        if union > 0:
            dice = 2 * intersection / union
        else:
            dice = 1.0 if pred_c.sum() == 0 else 0.0
        
        dice_scores[class_names[c]] = dice
    
    return dice_scores


# Compute metrics for all results
print("\n" + "=" * 60)
print("INFERENCE SUMMARY")
print("=" * 60)

metrics_list = []
for result in inference_results:
    subject_id = result['subject_id']
    prediction = result['prediction']
    ground_truth = result['ground_truth']
    rle_encodings = result['rle_encodings']
    
    print(f"\n📊 {subject_id}")
    print(f"   Prediction shape: {prediction.shape}")
    print(f"   Unique classes predicted: {np.unique(prediction)}")
    
    # RLE encoding info
    class_names = ['Background', 'NCR/NET', 'ED', 'ET']
    for class_idx, rle in rle_encodings.items():
        rle_preview = rle[:50] + "..." if len(rle) > 50 else rle
        print(f"   RLE [{class_names[class_idx]}]: {rle_preview}")
    
    # Compute Dice if ground truth available
    if ground_truth is not None:
        dice_scores = compute_dice_score_inference(prediction, ground_truth)
        print(f"   Dice Scores:")
        for name, score in dice_scores.items():
            if name != 'Background':
                print(f"      - {name}: {score:.4f}")
        metrics_list.append({
            'subject_id': subject_id,
            **dice_scores
        })

# Summary statistics
if metrics_list:
    metrics_df = pd.DataFrame(metrics_list)
    print("\n" + "=" * 60)
    print("OVERALL PERFORMANCE")
    print("=" * 60)
    for col in ['NCR/NET', 'ED', 'ET']:
        mean_dice = metrics_df[col].mean()
        std_dice = metrics_df[col].std()
        print(f"   {col}: {mean_dice:.4f} ± {std_dice:.4f}")
    
    # Mean Dice (excluding background)
    mean_overall = metrics_df[['NCR/NET', 'ED', 'ET']].mean(axis=1).mean()
    print(f"\n   Overall Mean Dice: {mean_overall:.4f}")

print("\n✅ Inference pipeline completed successfully!")