# 3D U-Net for Brain Tumor Segmentation

This notebook implements a complete training pipeline for brain tumor segmentation using the BraTS 2021 dataset with a 3D U-Net architecture.

## Contents
1. Imports and Setup
2. Data Loading
3. Data Preprocessing
4. Dataset and DataLoader
5. Loss Functions
6. 3D U-Net Model Architecture
7. Training Loop
8. Inference and Visualization


---
## 1. Imports and Setup


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import nibabel as nib
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Select the compute device once; every later cell reads this constant.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Device: {DEVICE}")

---
## 2. Data Loading

Load the BraTS 2021 dataset and create a DataFrame with file paths.


In [None]:
# Directory scanned by get_brats_file_paths() for per-patient folders of .nii.gz volumes.
DATA_ROOT = "/kaggle/temp/brats_extracted"

def _match_modality_file(filenames, modality):
    """Return the first ``.nii.gz`` filename that contains *modality*, or ``None``."""
    for filename in filenames:
        lowered = filename.lower()
        # 't1' is a substring of 't1ce': never let a contrast-enhanced file
        # stand in for the plain T1 volume, whatever the filename's case.
        if modality == 't1' and 't1ce' in lowered:
            continue
        if filename.endswith('.nii.gz') and modality in lowered:
            return filename
    return None


def get_brats_file_paths(root_dir):
    """Index all patient folders and return a DataFrame with file paths.

    Patient folders are looked up directly inside ``root_dir`` and inside each
    of its immediate sub-directories (one level of nesting). A folder counts
    as a patient only when a ``.nii.gz`` file is found for every modality.

    Args:
        root_dir: Directory to scan.

    Returns:
        A ``pandas.DataFrame`` with the columns ``id, flair, t1, t1ce, t2,
        seg``: the patient folder name followed by one file path per volume.
        The frame is empty (but keeps these columns) when ``root_dir`` is not
        a directory or no complete patient is found.
    """
    modalities = ('flair', 't1', 't1ce', 't2', 'seg')
    columns = ['id', *modalities]

    if not os.path.isdir(root_dir):
        print("Error: Directory not found.")
        return pd.DataFrame(columns=columns)

    # Handle nested directory structure. Listings are sorted because
    # os.listdir() returns entries in arbitrary order, which would otherwise
    # make the row order (and any seeded split of it) machine dependent.
    search_dirs = [root_dir]
    for entry in sorted(os.listdir(root_dir)):
        entry_path = os.path.join(root_dir, entry)
        if os.path.isdir(entry_path):
            search_dirs.append(entry_path)

    data_list = []
    for search_dir in search_dirs:
        for patient in sorted(os.listdir(search_dir)):
            patient_path = os.path.join(search_dir, patient)
            if not os.path.isdir(patient_path):
                continue

            filenames = sorted(os.listdir(patient_path))
            paths = {'id': patient}
            for modality in modalities:
                match = _match_modality_file(filenames, modality)
                if match is not None:
                    paths[modality] = os.path.join(patient_path, match)

            # Keep only patients for which every volume was located.
            if len(paths) == len(columns):
                data_list.append(paths)

    print(f"Found {len(data_list)} patients")
    return pd.DataFrame(data_list, columns=columns)

# Build the patient index from DATA_ROOT and preview the first rows.
df = get_brats_file_paths(DATA_ROOT)
print(df.head())

---
## 3. Data Preprocessing

Functions for volume normalization, cropping, and patch extraction.


In [None]:
def normalize_volume(volume):
    """
    Robust Z-Score Normalization.

    Voxels with a value > 0 are treated as brain. Their intensities are
    clipped to the [0.5, 99.5] percentile range of the brain voxels and then
    standardised with the mean and standard deviation of the clipped brain
    voxels. Every other voxel is set to 0.

    Args:
        volume: NumPy array holding one intensity volume.

    Returns:
        A new array with the same shape as ``volume``. Floating-point inputs
        keep their dtype (float32 in -> float32 out); other dtypes yield
        float64. When no voxel is > 0 the input is returned unchanged.
    """
    brain = volume > 0
    if not brain.any():
        return volume

    # Pin the result dtype explicitly instead of relying on NumPy's scalar
    # promotion rules, which can otherwise turn a float32 volume into float64.
    if np.issubdtype(volume.dtype, np.floating):
        out_dtype = volume.dtype
    else:
        out_dtype = np.float64

    # Statistics are computed in float64 on the brain voxels only.
    voxels = volume[brain].astype(np.float64)
    p_low, p_high = np.percentile(voxels, [0.5, 99.5])
    voxels = np.clip(voxels, p_low, p_high)
    mean, std = voxels.mean(), voxels.std()

    normalized = np.zeros(volume.shape, dtype=out_dtype)
    normalized[brain] = (voxels - mean) / (std + 1e-8)
    return normalized


def crop_to_bbox(image_stack, label):
    """Crop to the bounding box of non-zero signal.

    A voxel counts as signal when any channel of ``image_stack`` is non-zero
    there. Testing ``!= 0`` per channel (rather than ``sum > 0``) matters for
    z-scored inputs, where roughly half of the brain voxels are negative and
    channel values can cancel each other out.

    Args:
        image_stack: Array of shape ``(channels, *spatial)``.
        label: Array of shape ``spatial`` aligned with ``image_stack``.

    Returns:
        ``(image_stack, label)`` cropped to the same bounding box. Both inputs
        are returned unchanged when ``image_stack`` is zero everywhere.
    """
    foreground = np.any(image_stack != 0, axis=0)
    if not foreground.any():
        return image_stack, label

    # Project the mask onto each spatial axis to get the extent along it;
    # this avoids materialising the coordinates of every foreground voxel.
    bbox = []
    for axis in range(foreground.ndim):
        other_axes = tuple(a for a in range(foreground.ndim) if a != axis)
        occupied = np.flatnonzero(foreground.any(axis=other_axes))
        bbox.append(slice(occupied[0], occupied[-1] + 1))
    bbox = tuple(bbox)

    return image_stack[(slice(None),) + bbox], label[bbox]


def get_random_patch(image, label, patch_size=(128, 128, 128), fg_prob=0.33):
    """Extract a random patch with foreground sampling.

    Inputs smaller than the patch are zero-padded at the end of each spatial
    axis first. With probability ``fg_prob`` the patch is centred (as far as
    the volume bounds allow) on a randomly chosen voxel whose label is > 0;
    otherwise, or when the label has no such voxel, the patch position is
    drawn uniformly. Randomness comes from NumPy's global ``np.random`` state.

    Args:
        image: Array of shape ``(channels, H, W, D)``.
        label: Array of shape ``(H, W, D)`` aligned with ``image``.
        patch_size: ``(ph, pw, pd)`` size of the returned patch.
        fg_prob: Probability of attempting a foreground-centred patch.

    Returns:
        ``(image_patch, label_patch)`` with shapes ``(channels, ph, pw, pd)``
        and ``(ph, pw, pd)``. Without padding these are views into the inputs.
    """
    _, h, w, d = image.shape
    # Named p_h/p_w/p_d so the pandas alias `pd` is not shadowed in here.
    p_h, p_w, p_d = patch_size

    # Pad if smaller than patch
    pad_h, pad_w, pad_d = max(p_h - h, 0), max(p_w - w, 0), max(p_d - d, 0)
    if pad_h or pad_w or pad_d:
        image = np.pad(image, ((0, 0), (0, pad_h), (0, pad_w), (0, pad_d)), mode='constant')
        label = np.pad(label, ((0, pad_h), (0, pad_w), (0, pad_d)), mode='constant')
        h, w, d = image.shape[1:]

    # Largest valid start index along each axis.
    max_start = (h - p_h, w - p_w, d - p_d)
    half = (p_h // 2, p_w // 2, p_d // 2)

    start = None
    if np.random.rand() < fg_prob:
        fg_coords = np.argwhere(label > 0)
        if len(fg_coords) > 0:
            center = fg_coords[np.random.randint(len(fg_coords))]
            start = [int(np.clip(c - off, 0, hi))
                     for c, off, hi in zip(center, half, max_start)]

    if start is None:
        # Uniform sampling over all valid patch positions.
        start = [np.random.randint(0, hi + 1) for hi in max_start]

    x, y, z = start
    return (image[:, x:x + p_h, y:y + p_w, z:z + p_d],
            label[x:x + p_h, y:y + p_w, z:z + p_d])

---
## 4. Dataset and DataLoader


In [None]:
class BratsDataset(Dataset):
    """PyTorch Dataset for BraTS 2021 data.

    Each item is one random 128x128x128 patch of one patient: the four MRI
    modalities are loaded, normalised per volume, stacked as channels,
    cropped to the non-zero bounding box and patch-sampled. Patch sampling
    and augmentation draw from NumPy's global random state, so items are not
    deterministic (this includes the "val" phase).

    Args:
        df: DataFrame with one row per patient and the columns ``t1``,
            ``t1ce``, ``t2``, ``flair`` and ``seg`` holding file paths.
        phase: ``"train"`` enables augmentation (if ``augment`` is set); any
            other value disables it.
        augment: Whether to apply the random flip in the training phase.
    """

    # Channel order of the returned image tensor.
    MODALITIES = ('t1', 't1ce', 't2', 'flair')

    def __init__(self, df, phase="train", augment=False):
        self.df = df
        self.phase = phase
        self.augment = augment

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, seg)``: float32 ``(4, 128, 128, 128)`` and int64 ``(128, 128, 128)``."""
        row = self.df.iloc[idx]

        # Load and normalize each modality
        channels = [
            normalize_volume(nib.load(row[name]).get_fdata().astype(np.float32))
            for name in self.MODALITIES
        ]
        seg = nib.load(row['seg']).get_fdata().astype(np.int64)

        # Remap labels: [0, 1, 2, 4] -> [0, 1, 2, 3]
        seg[seg == 4] = 3

        # Stack modalities
        image = np.stack(channels, axis=0)

        # Crop and extract patch
        image, seg = crop_to_bbox(image, seg)
        image, seg = get_random_patch(image, seg, patch_size=(128, 128, 128))

        # Augmentation: flip the first spatial axis (axis 1 of the channel-first
        # image is axis 0 of the label volume).
        if self.augment and self.phase == "train":
            if np.random.rand() > 0.5:
                image = np.flip(image, axis=1)
                seg = np.flip(seg, axis=0)

        # Materialise the sliced/flipped views as contiguous arrays and pin the
        # dtypes: the network has float32 weights and the losses need int64
        # class indices, whatever dtype the preprocessing produced.
        image = np.ascontiguousarray(image, dtype=np.float32)
        seg = np.ascontiguousarray(seg, dtype=np.int64)

        return torch.from_numpy(image), torch.from_numpy(seg)

In [None]:
# Split data (fixed seed so the patient split is the same on every run)
train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)
print(f"Training: {len(train_df)}, Validation: {len(val_df)}")

# Create datasets and loaders
BATCH_SIZE = 2
train_ds = BratsDataset(train_df, phase="train", augment=True)
val_ds = BratsDataset(val_df, phase="val", augment=False)

# Page-locked host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine does nothing useful, so gate it on the device like AMP is.
PIN_MEMORY = DEVICE == "cuda"

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=PIN_MEMORY)

---
## 5. Loss Functions

Dice Loss and Combined Dice-CrossEntropy Loss for segmentation.


In [None]:
class DiceLoss(nn.Module):
    """Multi-class Dice Loss for segmentation.

    The Dice score is computed per class from the softmax probabilities,
    pooled over the batch and all spatial positions, and then averaged over
    the classes (background included). The loss is ``1 - mean_dice``.

    Args:
        smooth: Constant added to numerator and denominator so that a class
            absent from both prediction and target yields a score of 1
            instead of 0/0.
    """
    def __init__(self, smooth=1e-5):
        super().__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        """Return the scalar Dice loss.

        Args:
            inputs: Raw logits of shape ``(batch, classes, *spatial)``.
            targets: Class indices of shape ``(batch, *spatial)``; any integer
                (or integer-valued) dtype is accepted.
        """
        probs = F.softmax(inputs, dim=1)

        # F.one_hot only accepts int64 indices, so cast here rather than
        # relying on every caller to do it.
        targets = targets.long()
        targets_one_hot = F.one_hot(targets, num_classes=inputs.shape[1])
        # Move the new trailing class axis to position 1: (B, *S, C) -> (B, C, *S).
        channel_first = (0, targets.dim(), *range(1, targets.dim()))
        targets_one_hot = targets_one_hot.permute(*channel_first).float()

        # Reduce over the batch and every spatial axis, keeping the class axis.
        dims = (0, *range(2, inputs.dim()))
        intersection = torch.sum(probs * targets_one_hot, dims)
        cardinality = torch.sum(probs + targets_one_hot, dims)

        dice_score = (2. * intersection + self.smooth) / (cardinality + self.smooth)
        return 1. - dice_score.mean()


class DiceCELoss(nn.Module):
    """Combined Dice and Cross-Entropy Loss.

    Returns ``weight_ce * CrossEntropy + weight_dice * Dice``.

    Args:
        weight_ce: Multiplier of the cross-entropy term.
        weight_dice: Multiplier of the Dice term.
    """
    def __init__(self, weight_ce=1.0, weight_dice=1.0):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = DiceLoss()
        self.weight_ce = weight_ce
        self.weight_dice = weight_dice

    def forward(self, inputs, targets):
        """Return the weighted sum of both losses.

        Args:
            inputs: Raw logits with the class axis at dim 1.
            targets: Class indices without the class axis.
        """
        # Cast once so both terms receive the same int64 class indices;
        # previously only the cross-entropy term got the cast.
        targets = targets.long()
        loss_ce = self.ce(inputs, targets)
        loss_dice = self.dice(inputs, targets)
        return self.weight_ce * loss_ce + self.weight_dice * loss_dice

---
## 6. 3D U-Net Model Architecture


In [None]:
class DoubleConv(nn.Module):
    """Two stacked 3x3x3 conv stages, each Conv3d -> InstanceNorm3d -> LeakyReLU.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            # Bias-free conv: the affine instance norm that follows supplies the shift.
            stages.append(nn.Conv3d(stage_in, out_channels, kernel_size=3, padding=1, bias=False))
            stages.append(nn.InstanceNorm3d(out_channels, affine=True))
            stages.append(nn.LeakyReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: halve the spatial size with MaxPool3d(2), then apply DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        pooled_features = self.maxpool_conv(x)
        return pooled_features


class Up(nn.Module):
    """Upsampling followed by DoubleConv with skip connection.

    A stride-2 transposed convolution doubles the spatial size and halves the
    channel count of the decoder tensor; it is then concatenated with the
    encoder skip tensor and passed through a DoubleConv.

    Args:
        in_channels: Channels of the incoming decoder tensor. The skip tensor
            must carry ``in_channels // 2`` channels so that the concatenation
            has ``in_channels`` channels again.
        out_channels: Channels produced by the block.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` (decoder), fuse it with ``x2`` (skip) and convolve."""
        x1 = self.up(x1)

        # Pooling floors odd sizes, so after upsampling x1 can be one voxel
        # short of the skip tensor along an axis. Zero-pad x1 to match; this
        # is a no-op whenever the sizes already agree.
        size_gap = [skip - up for up, skip in zip(x1.shape[2:], x2.shape[2:])]
        if any(size_gap):
            padding = []
            # F.pad expects (before, after) pairs starting from the LAST axis.
            for gap in reversed(size_gap):
                padding.extend((gap // 2, gap - gap // 2))
            x1 = F.pad(x1, padding)

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class UNet3D(nn.Module):
    """3D U-Net with four downsampling and four upsampling stages.

    Channel widths double at every encoder level, starting at ``features``
    and reaching ``16 * features`` at the bottleneck. A final 1x1x1
    convolution maps to ``out_channels`` logits per voxel.
    """
    def __init__(self, in_channels=4, out_channels=4, features=32):
        super().__init__()

        # Channel width at each of the five resolution levels.
        w1 = features
        w2 = features * 2
        w3 = features * 4
        w4 = features * 8
        w5 = features * 16

        # Contracting path
        self.inc = DoubleConv(in_channels, w1)
        self.down1 = Down(w1, w2)
        self.down2 = Down(w2, w3)
        self.down3 = Down(w3, w4)
        self.down4 = Down(w4, w5)

        # Expanding path
        self.up1 = Up(w5, w4)
        self.up2 = Up(w4, w3)
        self.up3 = Up(w3, w2)
        self.up4 = Up(w2, w1)

        # Per-voxel classifier
        self.outc = nn.Conv3d(w1, out_channels, kernel_size=1)

    def forward(self, x):
        # Run the encoder, remembering every level's output for the skips.
        skips = [self.inc(x)]
        for encoder_stage in (self.down1, self.down2, self.down3, self.down4):
            skips.append(encoder_stage(skips[-1]))

        # The deepest output seeds the decoder; the rest are consumed
        # deepest-first as skip connections.
        x = skips.pop()
        for decoder_stage in (self.up1, self.up2, self.up3, self.up4):
            x = decoder_stage(x, skips.pop())

        return self.outc(x)

---
## 7. Training Loop


In [None]:
def compute_brats_dice(predictions, masks):
    """Compute Dice Score for BraTS regions: WT, TC, ET.

    Regions are built from the remapped labels: Whole Tumor = labels 1, 2, 3;
    Tumor Core = labels 1, 3; Enhancing Tumor = label 3. A Dice score is
    computed per sample and region, then averaged over the batch.

    Args:
        predictions: Class scores of shape ``(batch, classes, *spatial)``;
            the predicted label is the argmax over dim 1.
        masks: Ground-truth labels, one volume per batch entry.

    Returns:
        Dict with the keys ``'WT'``, ``'TC'`` and ``'ET'`` mapping to the
        batch-mean Dice score as a Python float.

    Raises:
        ValueError: If the batch is empty.
    """
    pred_labels = torch.argmax(predictions, dim=1)
    batch_size = pred_labels.shape[0]
    if batch_size == 0:
        raise ValueError("compute_brats_dice() requires a non-empty batch")

    smooth = 1e-5
    regions = {
        # Whole Tumor (WT): Labels 1, 2, 3
        'WT': (pred_labels > 0, masks > 0),
        # Tumor Core (TC): Labels 1, 3
        'TC': ((pred_labels == 1) | (pred_labels == 3), (masks == 1) | (masks == 3)),
        # Enhancing Tumor (ET): Label 3
        'ET': (pred_labels == 3, masks == 3),
    }

    scores = {}
    for name, (pred_region, true_region) in regions.items():
        pred_flat = pred_region.reshape(batch_size, -1)
        true_flat = true_region.reshape(batch_size, -1)
        # Boolean sums are exact integer voxel counts; convert to float only
        # for the final ratio.
        intersection = (pred_flat & true_flat).sum(dim=1).double()
        total = (pred_flat.sum(dim=1) + true_flat.sum(dim=1)).double()
        per_sample = (2.0 * intersection + smooth) / (total + smooth)
        # One host sync per region instead of one per sample and region.
        scores[name] = per_sample.mean().item()
    return scores

In [None]:
# Configuration
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
EPOCHS = 20
# File that receives the state_dict of the epoch with the lowest validation loss.
SAVE_PATH = "unet3d_best.pth"

# Initialize
model = UNet3D(in_channels=4, out_channels=4).to(DEVICE)
criterion = DiceCELoss(weight_ce=1.0, weight_dice=1.0)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Multiplies the learning rate by 0.5 when the value passed to scheduler.step()
# (the mean validation loss, see below) has stopped decreasing.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)
# Gradient scaler for mixed-precision training; constructed disabled unless running on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

best_val_loss = float('inf')

for epoch in range(EPOCHS):
    # Training: one pass over train_loader with autocast + scaled backward
    model.train()
    train_loss = 0.0
    
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for images, masks in progress:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        
        optimizer.zero_grad()
        # Forward pass and loss run under autocast (only active on CUDA).
        with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
            outputs = model(images)
            loss = criterion(outputs, masks)
        
        # Backward on the scaled loss; the scaler then performs the optimizer
        # step and updates its scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        train_loss += loss.item()
        progress.set_postfix({"loss": loss.item()})
    
    # Mean of the per-batch losses (not weighted by batch size).
    avg_train_loss = train_loss / len(train_loader)
    
    # Validation: accumulate loss and per-region Dice without gradients
    # NOTE(review): unlike training, this forward pass is not wrapped in
    # autocast, so it runs in the model's native precision; confirm that is intended.
    model.eval()
    val_loss = 0.0
    dice_wt, dice_tc, dice_et = 0.0, 0.0, 0.0
    
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            outputs = model(images)
            val_loss += criterion(outputs, masks).item()
            
            # compute_brats_dice returns batch-mean scores keyed 'WT', 'TC', 'ET'.
            dice = compute_brats_dice(outputs, masks)
            dice_wt += dice['WT']
            dice_tc += dice['TC']
            dice_et += dice['ET']
    
    # Per-batch means averaged over the number of validation batches.
    avg_val_loss = val_loss / len(val_loader)
    dice_wt /= len(val_loader)
    dice_tc /= len(val_loader)
    dice_et /= len(val_loader)
    
    # The plateau scheduler monitors the mean validation loss.
    scheduler.step(avg_val_loss)
    
    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")
    print(f"         Dice WT={dice_wt:.4f}, TC={dice_tc:.4f}, ET={dice_et:.4f}")
    
    # Checkpoint only when the validation loss improves on the best seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"         Model saved!")

print("Training complete!")

---
## 8. Inference and Visualization


In [None]:
def visualize_prediction(model, loader, device=DEVICE):
    """Visualize model predictions alongside ground truth.

    Takes the first batch from ``loader``, runs the model on it and plots the
    middle slice (along the last axis) of the first sample: the four input
    modalities in the top row, ground truth and prediction in the bottom row.

    Args:
        model: Segmentation network returning per-class scores with the class
            axis at dim 1. It is switched to eval mode and left that way.
        loader: Iterable yielding ``(images, masks)`` batches.
        device: Device the batch is moved to for the forward pass.
    """
    model.eval()

    images, masks = next(iter(loader))
    with torch.no_grad():
        logits = model(images.to(device))
    # Softmax is monotonic, so the argmax of the raw logits is already the
    # predicted label (the same convention compute_brats_dice uses).
    predictions = torch.argmax(logits, dim=1)

    # Only the first sample is drawn, so only that one is converted to NumPy.
    sample_idx = 0
    image = images[sample_idx].cpu().numpy()
    mask = masks[sample_idx].cpu().numpy()
    prediction = predictions[sample_idx].cpu().numpy()
    slice_idx = image.shape[-1] // 2

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))

    # Top row: one panel per input channel.
    modalities = ['T1', 'T1ce', 'T2', 'FLAIR']
    for ax, name, channel in zip(axes[0], modalities, image):
        ax.imshow(np.rot90(channel[:, :, slice_idx]), cmap='gray')
        ax.set_title(name)

    # Bottom row: label maps on a shared 0..3 colour scale; the two remaining
    # panels stay empty.
    label_panels = [('Ground Truth', mask), ('Prediction', prediction)]
    for ax, (title, volume) in zip(axes[1], label_panels):
        ax.imshow(np.rot90(volume[:, :, slice_idx]), cmap='jet', vmin=0, vmax=3)
        ax.set_title(title)

    for ax in axes.ravel():
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Visualize the first validation batch.
# NOTE(review): `model` still holds the weights from the last epoch; the best
# checkpoint written to SAVE_PATH is not reloaded before this call.
visualize_prediction(model, val_loader)