In [1]:
# BrainSeg3D: 3D U-Net Implementation for Brain Tumor Segmentation
# =================================================================

"""
This project implements a 3D U-Net architecture for volumetric brain segmentation,
with a focus on tumor segmentation using multi-modal MRI data.

Key features:
- 3D convolutions for volumetric data processing
- Skip connections between encoder and decoder paths
- Multi-class segmentation for brain tumor regions
- Visualization tools for 3D medical imaging
"""

# Print existing Python/NumPy versions first
import sys
print(f"Python version: {sys.version}")
!python -c "import numpy; print(f'NumPy version: {numpy.__version__}')"

# Clean pip cache to avoid potential conflicts
!pip cache purge
!pip --version

# Install basic dependencies first (should be quick and reliable)
!pip install nibabel matplotlib tqdm --quiet

# Install PyTorch using the recommended Colab approach
# This uses pre-compiled binaries compatible with Colab's CUDA
print("Installing PyTorch (this may take a minute)...")
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Now install MONAI without version pinning to get latest stable
print("Installing MONAI (this may take a few minutes)...")
!pip install monai

# IMPORTANT: Restart runtime after installations
print("\n--------------------------------------------")
print("IMPORTANT: Please restart the runtime now by clicking Runtime > Restart runtime")
print("Then run the next cells after restart")
print("--------------------------------------------")

# After restarting runtime, run this part:

# Import base libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Print versions for debugging
print(f"NumPy version: {np.__version__}")

# Import PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")

# Import MONAI components - with error handling
try:
    import monai
    from monai.data import Dataset, decollate_batch
    from monai.transforms import (
        Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd,
        RandCropByPosNegLabeld, Orientationd, ToTensord, RandFlipd,
        NormalizeIntensityd
    )
    from monai.networks.nets import UNet
    from monai.networks.layers import Norm
    from monai.losses import DiceLoss
    from monai.inferers import sliding_window_inference
    from monai.visualize import plot_2d_or_3d_image
    from monai.utils import set_determinism
    from monai.apps import download_and_extract

    print(f"MONAI version: {monai.__version__}")
except ImportError as e:
    print(f"Error importing MONAI: {e}")
    print("Please restart the runtime and try again.")

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set random seed for reproducibility
set_determinism(seed=42)

Python version: 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
NumPy version: 2.0.2
[0mFiles removed: 0
pip 24.1.2 from /usr/local/lib/python3.11/dist-packages/pip (python 3.11)
Installing PyTorch (this may take a minute)...
Looking in indexes: https://download.pytorch.org/whl/cu118
INFO: pip is looking at multiple versions of torch to determine which version is compatible with other requirements. This could take a while.
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.6.0%2Bcu118-cp311-cp311-linux_x86_64.whl.metadata (27 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_nvrtc_cu11-11.8.89-py3-none-manylinux1_x86_64.whl (23.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.2/23.2 MB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu11==11.8.89 (from torch)
  Downloading https://download.pytorch.org/whl/cu118/nvidia_cuda_ru

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [2]:
# Dataset Setup and Download
# ==========================

# Create a directory for data
os.makedirs("./brain_data", exist_ok=True)

# Use MONAI's built-in functionality for downloading datasets
try:
    print("Attempting to download BraTS sample data with MONAI...")

    # Import necessary MONAI components for downloading datasets
    from monai.apps.utils import download_and_extract

    # Download a small sample of BraTS data (10 cases)
    resource = "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MSD_Task01_BrainTumour_sub.tar.gz"
    tarfile_name = os.path.join("./brain_data", "MSD_Task01_BrainTumour_sub.tar.gz")

    # Download and extract
    download_and_extract(
        url=resource,
        filepath=tarfile_name,
        output_dir="./brain_data",
        hash_type="md5",
    )

    data_dir = os.path.join("./brain_data", "Task01_BrainTumour")

    # Verify data exists
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"Data directory {data_dir} not found after extraction")

    image_dir = os.path.join(data_dir, "imagesTr")
    label_dir = os.path.join(data_dir, "labelsTr")

    if not os.path.exists(image_dir) or not os.path.exists(label_dir):
        raise FileNotFoundError(f"Image or label directory not found in {data_dir}")

    # Get data files
    training_images = sorted(os.listdir(image_dir))
    training_labels = sorted(os.listdir(label_dir))

    print(f"Dataset download successful!")
    print(f"Dataset path: {data_dir}")
    print(f"Number of images: {len(training_images)}")
    print(f"Number of labels: {len(training_labels)}")

except Exception as e:
    print(f"Error downloading dataset: {e}")
    print("\nFalling back to creating synthetic data for demonstration...")

    # Create synthetic data as a fallback
    import nibabel as nib
    import numpy as np

    data_dir = "./brain_data/synthetic"
    image_dir = os.path.join(data_dir, "imagesTr")
    label_dir = os.path.join(data_dir, "labelsTr")

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)

    # Create a few synthetic samples
    num_samples = 4
    for i in range(num_samples):
        # Create synthetic multi-modal MRI (4 channels: T1, T1ce, T2, FLAIR)
        # Use a smaller size (64x64x64) to save memory
        image_data = np.zeros((4, 64, 64, 64), dtype=np.float32)

        # Create different patterns for each modality
        for c in range(4):
            # Create a basic spherical structure
            x, y, z = np.ogrid[:64, :64, :64]
            center = 32
            r = 20
            sphere = (x - center)**2 + (y - center)**2 + (z - center)**2 <= r**2

            # Add some random noise
            noise = np.random.rand(64, 64, 64) * 0.2

            # Combine sphere and noise with different intensities per channel
            image_data[c] = sphere.astype(float) * (0.5 + c*0.1) + noise

        # Create a synthetic segmentation mask
        label_data = np.zeros((64, 64, 64), dtype=np.uint8)

        # Create tumor core (label 1)
        r_core = 10
        core_sphere = (x - center)**2 + (y - center)**2 + (z - center)**2 <= r_core**2
        label_data[core_sphere] = 1

        # Create enhancing tumor (label 2) as a shell around the core
        r_enhancing = 15
        enhancing_sphere = (x - center)**2 + (y - center)**2 + (z - center)**2 <= r_enhancing**2
        label_data[enhancing_sphere & ~core_sphere] = 2

        # Save as NIfTI files
        affine = np.eye(4)  # Identity affine matrix

        # Save image
        image_file = os.path.join(image_dir, f"brain_{i:03d}.nii.gz")
        nib.save(nib.Nifti1Image(image_data, affine), image_file)

        # Save label
        label_file = os.path.join(label_dir, f"brain_{i:03d}.nii.gz")
        nib.save(nib.Nifti1Image(label_data, affine), label_file)

    # Get data files
    training_images = sorted(os.listdir(image_dir))
    training_labels = sorted(os.listdir(label_dir))

    print(f"Created synthetic dataset with {num_samples} samples")
    print(f"Dataset path: {data_dir}")
    print(f"Number of images: {len(training_images)}")
    print(f"Number of labels: {len(training_labels)}")

# Display dataset information
print("\nData Description:")
print("- Multi-modal MRI scans (4 channels: T1, T1ce, T2, FLAIR)")
print("- Segmentation labels: Background (0), Tumor core (1), Enhancing tumor (2)")
print("- 3D volumetric data")

# List a few example files
print("\nExample files:")
for i in range(min(3, len(training_images))):
    print(f"Image {i}: {training_images[i]}")
    print(f"Label {i}: {training_labels[i]}")

Attempting to download BraTS sample data with MONAI...


MSD_Task01_BrainTumour_sub.tar.gz: 0.00B [00:00, ?B/s]

2025-04-06 14:59:19,098 - ERROR - Download failed from https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MSD_Task01_BrainTumour_sub.tar.gz to /tmp/tmpamk0hv2d/MSD_Task01_BrainTumour_sub.tar.gz.
Error downloading dataset: HTTP Error 404: Not Found

Falling back to creating synthetic data for demonstration...





ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [None]:
# Data Exploration and Visualization
# ==================================

import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

# Function to safely load a NIfTI file with error handling
def load_nifti_file(file_path):
    try:
        nii_img = nib.load(file_path)
        return nii_img
    except Exception as e:
        print(f"Error loading file {file_path}: {e}")
        return None

# Get the paths to the first image and label
try:
    image_files = sorted(os.listdir(os.path.join(data_dir, "imagesTr")))
    label_files = sorted(os.listdir(os.path.join(data_dir, "labelsTr")))

    first_image_path = os.path.join(data_dir, "imagesTr", image_files[0])
    first_label_path = os.path.join(data_dir, "labelsTr", label_files[0])

    print(f"Loading image: {first_image_path}")
    print(f"Loading label: {first_label_path}")

    # Load the first image and label using NiBabel
    image_nii = load_nifti_file(first_image_path)
    label_nii = load_nifti_file(first_label_path)

    if image_nii is None or label_nii is None:
        raise ValueError("Failed to load image or label file")

    # Get data from NIfTI objects
    image_data = image_nii.get_fdata()
    label_data = label_nii.get_fdata()

    # Print shape information
    print(f"Image shape: {image_data.shape}")
    print(f"Label shape: {label_data.shape}")

    # For 3D volumes, the channel dimension might be last, check and transpose if needed
    if len(image_data.shape) == 4 and image_data.shape[-1] == 4:
        # If channels are last (e.g., H x W x D x C), transpose to C x H x W x D
        image_data = np.transpose(image_data, (3, 0, 1, 2))
    elif len(image_data.shape) == 3:
        # If it's a single channel image, add channel dimension
        image_data = np.expand_dims(image_data, axis=0)

    print(f"Image shape after preprocessing: {image_data.shape}")
    print(f"Image data type: {image_data.dtype}")
    print(f"Image value range: [{np.min(image_data)}, {np.max(image_data)}]")
    print(f"Label data type: {label_data.dtype}")
    print(f"Unique labels: {np.unique(label_data)}")

    # Visualize a central slice from each modality
    central_slice_idx = image_data.shape[2] // 2

    plt.figure(figsize=(15, 5))

    modalities = ["T1", "T1ce", "T2", "FLAIR"]
    for i in range(min(4, image_data.shape[0])):
        plt.subplot(1, 4, i+1)
        plt.imshow(image_data[i, :, :, central_slice_idx], cmap="gray")
        plt.title(f"Modality: {modalities[i] if i < len(modalities) else f'Channel {i}'}")
        plt.axis("off")

    plt.tight_layout()
    plt.show()

except Exception as e:
    print(f"Error during visualization: {e}")
    print("\nTrying alternative approach...")

    # Fallback approach using synthetic data
    try:
        # Generate a simple 3D volume if we can't load from files
        print("Creating sample data for visualization...")

        # Create a simple 3D volume with 4 channels
        image_data = np.zeros((4, 64, 64, 64), dtype=np.float32)

        # Fill with sample patterns
        for c in range(4):
            x, y, z = np.ogrid[:64, :64, :64]
            center = 32
            radius = 20 - c * 2  # Different radius per channel
            sphere = ((x - center)**2 + (y - center)**2 + (z - center)**2) <= radius**2
            image_data[c] = sphere.astype(float) * (0.8 - c * 0.1) + np.random.rand(64, 64, 64) * 0.2

        # Create sample label
        label_data = np.zeros((64, 64, 64), dtype=np.uint8)
        core_radius = 10
        enhancing_radius = 15

        core_sphere = ((x - center)**2 + (y - center)**2 + (z - center)**2) <= core_radius**2
        enhancing_sphere = ((x - center)**2 + (y - center)**2 + (z - center)**2) <= enhancing_radius**2

        label_data[core_sphere] = 1  # Tumor core
        label_data[enhancing_sphere & ~core_sphere] = 2  # Enhancing tumor

        # Visualize a central slice
        central_slice_idx = 32

        plt.figure(figsize=(15, 5))

        modalities = ["T1", "T1ce", "T2", "FLAIR"]
        for i in range(4):
            plt.subplot(1, 4, i+1)
            plt.imshow(image_data[i, :, :, central_slice_idx], cmap="gray")
            plt.title(f"Modality: {modalities[i]}")
            plt.axis("off")

        plt.tight_layout()
        plt.show()

        print("Sample data visualized successfully")

    except Exception as e:
        print(f"Failed to visualize sample data: {e}")

In [None]:
# Visualization of Segmentation Masks
# ===================================

# Visualize the segmentation masks
plt.figure(figsize=(15, 5))

# Original image (T1ce modality, which typically shows tumor best)
plt.subplot(1, 3, 1)
plt.imshow(image_data[1, :, :, central_slice_idx], cmap="gray")
plt.title("T1ce MRI")
plt.axis("off")

# Segmentation mask
plt.subplot(1, 3, 2)
plt.imshow(label_data[:, :, central_slice_idx], cmap="viridis")
plt.title("Segmentation Mask")
plt.axis("off")

# Overlay segmentation on the image
plt.subplot(1, 3, 3)
plt.imshow(image_data[1, :, :, central_slice_idx], cmap="gray")
plt.imshow(label_data[:, :, central_slice_idx], cmap="hot", alpha=0.3)
plt.title("Overlay")
plt.axis("off")

plt.tight_layout()
plt.show()

# Visualize tumor regions across multiple slices
num_slices = 5
start_slice = central_slice_idx - (num_slices // 2)
plt.figure(figsize=(15, 8))

for i in range(num_slices):
    slice_idx = start_slice + i
    plt.subplot(2, num_slices, i + 1)
    plt.imshow(image_data[1, :, :, slice_idx], cmap="gray")
    plt.title(f"Slice {slice_idx}")
    plt.axis("off")

    plt.subplot(2, num_slices, i + 1 + num_slices)
    plt.imshow(image_data[1, :, :, slice_idx], cmap="gray")
    plt.imshow(label_data[:, :, slice_idx], cmap="hot", alpha=0.3)
    plt.title(f"Overlay {slice_idx}")
    plt.axis("off")

plt.tight_layout()
plt.show()

# Count the number of voxels for each label
label_counts = {int(label): np.sum(label_data == label) for label in np.unique(label_data)}
print("\nVoxel count for each label:")
total_voxels = np.prod(label_data.shape)
for label, count in label_counts.items():
    percentage = (count / total_voxels) * 100
    if label == 0:
        class_name = "Background"
    elif label == 1:
        class_name = "Tumor Core"
    elif label == 2:
        class_name = "Enhancing Tumor"
    else:
        class_name = f"Label {label}"
    print(f"- {class_name}: {count} voxels ({percentage:.2f}%)")

In [None]:
# Data Preprocessing and Transforms
# =================================

# Create training data dictionary of image/label pairs
train_images = sorted(os.listdir(os.path.join(data_dir, "imagesTr")))
train_labels = sorted(os.listdir(os.path.join(data_dir, "labelsTr")))

train_files = [
    {
        "image": os.path.join(data_dir, "imagesTr", img),
        "label": os.path.join(data_dir, "labelsTr", lbl)
    }
    for img, lbl in zip(train_images, train_labels)
]

# Split into training and validation sets (80/20 split)
val_split = 0.2
val_size = int(len(train_files) * val_split)
train_files, val_files = train_files[val_size:], train_files[:val_size]

print(f"Number of training samples: {len(train_files)}")
print(f"Number of validation samples: {len(val_files)}")

# Define preprocessing transforms for training
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityd(keys=["image"]),
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    RandCropByPosNegLabeld(
        keys=["image", "label"],
        label_key="label",
        spatial_size=(96, 96, 64),  # Reduced size for Colab memory constraints
        pos=1,
        neg=1,
        num_samples=4,
        image_key="image"
    ),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=2),
    ToTensord(keys=["image", "label"])
])

# Define preprocessing transforms for validation (no augmentation)
val_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityd(keys=["image"]),
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    ToTensord(keys=["image", "label"])
])

# Create datasets
train_ds = CacheDataset(
    data=train_files,
    transform=train_transforms,
    cache_rate=1.0,
    num_workers=4
)

val_ds = CacheDataset(
    data=val_files,
    transform=val_transforms,
    cache_rate=1.0,
    num_workers=4
)

# Create data loaders
train_loader = DataLoader(
    train_ds,
    batch_size=2,  # Reduced batch size for Colab
    shuffle=True,
    num_workers=2,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available()
)

val_loader = DataLoader(
    val_ds,
    batch_size=1,
    num_workers=2,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available()
)

# Verify the data shape after transforms
check_data = first(train_loader)
image, label = check_data["image"], check_data["label"]
print(f"Image shape after transforms: {image.shape}")
print(f"Label shape after transforms: {label.shape}")

In [None]:
# Model Architecture - Encoder Path
# =================================

import torch.nn.functional as F

class ConvBlock(nn.Module):
    """
    A Convolution block with two 3D convolutions, instance normalization and LeakyReLU activations
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.InstanceNorm3d(out_channels)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.InstanceNorm3d(out_channels)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.activation(self.norm1(self.conv1(x)))
        x = self.activation(self.norm2(self.conv2(x)))
        return x

class EncoderBlock(nn.Module):
    """
    Encoder block that performs convolution followed by downsampling
    """
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        pooled = self.pool(features)
        return pooled, features  # Return both the downsampled output and the features for skip connection

class Encoder(nn.Module):
    """
    Encoder path of the 3D U-Net
    """
    def __init__(self, in_channels=4, depths=[32, 64, 128, 256]):
        super(Encoder, self).__init__()
        self.encoders = nn.ModuleList()

        # Initial convolution block
        self.initial_conv = ConvBlock(in_channels, depths[0])

        # Encoder blocks with increasing feature depth
        for i in range(len(depths)-1):
            self.encoders.append(EncoderBlock(depths[i], depths[i+1]))

    def forward(self, x):
        # Initial features
        features = [self.initial_conv(x)]

        # Encoder path
        out = features[0]
        for encoder in self.encoders:
            out, feature_map = encoder(out)
            features.append(feature_map)

        return out, features

# Test the encoder with a sample input
if __name__ == "__main__":
    # Create a random input tensor (batch_size=2, channels=4, depth=64, height=96, width=96)
    x = torch.randn(2, 4, 64, 96, 96).to(device)

    # Initialize the encoder
    encoder = Encoder(in_channels=4, depths=[32, 64, 128, 256]).to(device)

    # Forward pass
    out, features = encoder(x)

    print("Encoder output shape:", out.shape)
    print("Feature maps shapes:")
    for i, feature in enumerate(features):
        print(f"  Level {i}: {feature.shape}")

In [None]:
# Model Architecture - Decoder Path
# =================================

class DecoderBlock(nn.Module):
    """
    Decoder block that performs upsampling followed by convolution
    """
    def __init__(self, in_channels, skip_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upconv = nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels // 2 + skip_channels, out_channels)

    def forward(self, x, skip_features):
        # Upsample
        x = self.upconv(x)

        # Handle potential size mismatch for skip connection
        if x.shape[2:] != skip_features.shape[2:]:
            x = F.interpolate(x, size=skip_features.shape[2:], mode='trilinear', align_corners=False)

        # Concatenate with skip features
        x = torch.cat([x, skip_features], dim=1)

        # Apply convolution
        x = self.conv(x)
        return x

class Decoder(nn.Module):
    """
    Decoder path of the 3D U-Net
    """
    def __init__(self, depths=[256, 128, 64, 32]):
        super(Decoder, self).__init__()
        self.decoders = nn.ModuleList()

        # Create decoder blocks with decreasing feature depth
        for i in range(len(depths)-1):
            self.decoders.append(DecoderBlock(depths[i], depths[i+1], depths[i+1]))

    def forward(self, x, encoder_features):
        # Decoder path
        out = x

        # Use encoder features in reverse order (excluding bottleneck)
        skip_features = encoder_features[-2::-1]

        for i, decoder in enumerate(self.decoders):
            out = decoder(out, skip_features[i])

        return out

# Test the decoder with a sample input and features from encoder
if __name__ == "__main__":
    # Create a random input tensor (batch_size=2, channels=256, depth=8, height=12, width=12)
    x = torch.randn(2, 256, 8, 12, 12).to(device)

    # Create sample encoder features
    encoder_features = [
        torch.randn(2, 32, 64, 96, 96).to(device),  # Level 0
        torch.randn(2, 64, 32, 48, 48).to(device),  # Level 1
        torch.randn(2, 128, 16, 24, 24).to(device), # Level 2
        torch.randn(2, 256, 8, 12, 12).to(device)   # Level 3 (bottleneck)
    ]

    # Initialize the decoder
    decoder = Decoder(depths=[256, 128, 64, 32]).to(device)

    # Forward pass
    out = decoder(x, encoder_features)

    print("Decoder output shape:", out.shape)

In [None]:
# Model Architecture - Full 3D U-Net
# ==================================

class UNet3D(nn.Module):
    """
    Complete 3D U-Net model for volumetric segmentation
    """
    def __init__(self, in_channels=4, out_channels=3, feature_channels=[32, 64, 128, 256]):
        super(UNet3D, self).__init__()

        # Encoder and decoder paths
        self.encoder = Encoder(in_channels, feature_channels)
        self.bottleneck = ConvBlock(feature_channels[-1], feature_channels[-1]*2)
        self.decoder = Decoder([feature_channels[-1]*2] + feature_channels[::-1][:-1])

        # Final convolution to produce segmentation map
        self.final_conv = nn.Conv3d(feature_channels[0], out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder path
        enc_out, features = self.encoder(x)

        # Bottleneck
        bottleneck = self.bottleneck(enc_out)

        # Decoder path
        dec_out = self.decoder(bottleneck, features)

        # Final convolution
        logits = self.final_conv(dec_out)

        return logits

    def initialize_weights(self):
        """Initialize model weights"""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.InstanceNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

# Create the model and move to device
model = UNet3D(in_channels=4, out_channels=3).to(device)
model.initialize_weights()

# Print model summary
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

total_params = count_parameters(model)
print(f"Total trainable parameters: {total_params:,}")

# Test with a sample input
if __name__ == "__main__":
    # Create a random input tensor (batch_size=2, channels=4, depth=64, height=96, width=96)
    x = torch.randn(2, 4, 64, 96, 96).to(device)

    # Forward pass
    output = model(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Check if output matches expected dimensions
    print("Output channels (should be 3 for background, tumor core, enhancing tumor):", output.shape[1])
    print("Output spatial dimensions match input:", output.shape[2:] == x.shape[2:])

In [None]:
# Loss Function and Evaluation Metrics
# ===================================

from monai.losses import DiceLoss, DiceCELoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric, ConfusionMatrixMetric

# Define the loss function
# DiceCELoss combines Dice loss and Cross-Entropy loss
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)

# Define evaluation metrics
dice_metric = DiceMetric(include_background=False, reduction="mean")
hausdorff_metric = HausdorffDistanceMetric(include_background=False, reduction="mean", percentile=95)
confusion_matrix = ConfusionMatrixMetric(include_background=False,
                                         metric_name=["sensitivity", "specificity", "precision"],
                                         compute_sample=True)

def calculate_metrics(y_pred, y_true):
    """
    Calculate and return multiple evaluation metrics
    """
    # Convert predictions to one-hot format
    y_pred = torch.argmax(y_pred, dim=1, keepdim=True)

    # Compute Dice score
    dice_score = dice_metric(y_pred, y_true)

    # Compute Hausdorff distance
    hausdorff_score = hausdorff_metric(y_pred, y_true)

    # Compute sensitivity, specificity, and precision
    confusion_matrix(y_pred, y_true)
    metrics = confusion_matrix.aggregate()[0]
    sensitivity = metrics[0].item()
    specificity = metrics[1].item()
    precision = metrics[2].item()

    return {
        "dice": dice_score.item(),
        "hausdorff": hausdorff_score.item(),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "precision": precision
    }

# Reset metrics
def reset_metrics():
    dice_metric.reset()
    hausdorff_metric.reset()
    confusion_matrix.reset()

# Test the loss function with random inputs
if __name__ == "__main__":
    # Create random predictions and targets
    y_pred = torch.randn(2, 3, 64, 96, 96).to(device)  # [batch, channels, d, h, w]
    y_true = torch.randint(0, 3, (2, 1, 64, 96, 96)).to(device)  # [batch, channels, d, h, w]

    # Calculate loss
    loss_val = loss_function(y_pred, y_true)
    print(f"Loss value: {loss_val.item()}")

    # Calculate metrics
    metrics = calculate_metrics(y_pred, y_true)
    print("Metrics:")
    for metric_name, metric_value in metrics.items():
        print(f"  {metric_name}: {metric_value}")

In [None]:
# Training Setup
# =============

# Define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Define the learning rate scheduler
# Reduce learning rate when validation loss plateaus
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

# Number of epochs
max_epochs = 50  # Adjust based on available time
early_stop_patience = 10  # Stop training if validation loss doesn't improve for this many epochs

# Initialize variables to track best model
best_val_loss = float('inf')
best_val_dice = 0.0
best_epoch = 0
epochs_without_improvement = 0

# Save best model
model_save_path = "best_model.pth"

# Create dictionaries to store metrics
train_metrics = {
    'loss': [],
    'dice': []
}

val_metrics = {
    'loss': [],
    'dice': [],
    'hausdorff': [],
    'sensitivity': [],
    'specificity': [],
    'precision': []
}

# Use mixed precision training to speed up and reduce memory usage
scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None

# Display training configuration
print("Training Configuration:")
print(f"- Device: {device}")
print(f"- Max Epochs: {max_epochs}")
print(f"- Batch Size: {train_loader.batch_size}")
print(f"- Learning Rate: {optimizer.param_groups[0]['lr']}")
print(f"- Early Stop Patience: {early_stop_patience}")
print(f"- Training Samples: {len(train_ds)}")
print(f"- Validation Samples: {len(val_ds)}")
print(f"- Mixed Precision: {'Enabled' if scaler else 'Disabled'}")

In [None]:
# Training Loop
# ============

def train_epoch(model, loader, optimizer, loss_function, scaler=None):
    """
    Train the model for one epoch
    """
    model.train()
    epoch_loss = 0
    step = 0

    # Use tqdm for progress bar
    progress_bar = tqdm(enumerate(loader), total=len(loader), desc="Training")

    for step, batch in progress_bar:
        # Get data and move to device
        inputs, labels = batch["image"].to(device), batch["label"].to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass with mixed precision
        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(inputs)
                loss = loss_function(outputs, labels)

            # Backward and optimize with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # Standard forward/backward pass
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

        # Update metrics
        epoch_loss += loss.item()

        # Update progress bar
        progress_bar.set_postfix({'batch_loss': loss.item()})

    # Calculate average loss
    epoch_loss /= (step + 1)

    # Calculate Dice score on training set
    pred = torch.argmax(outputs, dim=1, keepdim=True)
    dice_score = dice_metric(pred, labels)
    dice_metric.reset()

    return epoch_loss, dice_score.item()

# Example usage if running as main script
if __name__ == "__main__":
    print("Training function defined and ready to use in the training loop.")

In [None]:
# Validation Loop
# ==============

@torch.no_grad()
def validate(model, loader, loss_function):
    """
    Validate the model on the validation set
    """
    model.eval()
    val_loss = 0
    step = 0

    # Reset metrics
    reset_metrics()

    # Use tqdm for progress bar
    progress_bar = tqdm(enumerate(loader), total=len(loader), desc="Validation")

    for step, batch in progress_bar:
        # Get data and move to device
        inputs, labels = batch["image"].to(device), batch["label"].to(device)

        # Use sliding window inference for larger volumes
        roi_size = (96, 96, 64)
        sw_batch_size = 4

        outputs = sliding_window_inference(inputs, roi_size, sw_batch_size, model)
        loss = loss_function(outputs, labels)

        # Update loss
        val_loss += loss.item()

        # Calculate metrics
        pred = torch.argmax(outputs, dim=1, keepdim=True)
        dice_metric(pred, labels)
        hausdorff_metric(pred, labels)
        confusion_matrix(pred, labels)

        # Update progress bar
        progress_bar.set_postfix({'batch_loss': loss.item()})

    # Calculate average loss
    val_loss /= (step + 1)

    # Aggregate metrics
    dice_score = dice_metric.aggregate().item()
    hausdorff_score = hausdorff_metric.aggregate().item()

    confusion_values = confusion_matrix.aggregate()[0]
    sensitivity = confusion_values[0].item()
    specificity = confusion_values[1].item()
    precision = confusion_values[2].item()

    # Reset metrics for next validation
    reset_metrics()

    metrics = {
        'loss': val_loss,
        'dice': dice_score,
        'hausdorff': hausdorff_score,
        'sensitivity': sensitivity,
        'specificity': specificity,
        'precision': precision
    }

    return metrics

# Example usage if running as main script
if __name__ == "__main__":
    print("Validation function defined and ready to use in the training loop.")

In [None]:
# Model Training Execution
# =======================

# Let's train the model
def train_model(model, train_loader, val_loader, optimizer, loss_function, max_epochs, scheduler=None, scaler=None):
    """
    Train the model for multiple epochs
    """
    # Initialize tracking variables
    best_val_dice = 0
    best_epoch = 0
    epochs_without_improvement = 0
    model_save_path = "best_model.pth"

    # Track metrics
    train_loss_values = []
    train_dice_values = []
    val_loss_values = []
    val_dice_values = []

    for epoch in range(max_epochs):
        print(f"\nEpoch {epoch+1}/{max_epochs}")

        # Train for one epoch
        train_loss, train_dice = train_epoch(model, train_loader, optimizer, loss_function, scaler)

        # Validate the model
        val_metrics = validate(model, val_loader, loss_function)
        val_loss = val_metrics['loss']
        val_dice = val_metrics['dice']

        # Print metrics
        print(f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}")
        print(f"Val Hausdorff: {val_metrics['hausdorff']:.4f}")
        print(f"Val Sensitivity: {val_metrics['sensitivity']:.4f}, Val Specificity: {val_metrics['specificity']:.4f}, Val Precision: {val_metrics['precision']:.4f}")

        # Track metrics
        train_loss_values.append(train_loss)
        train_dice_values.append(train_dice)
        val_loss_values.append(val_loss)
        val_dice_values.append(val_dice)

        # Update learning rate scheduler
        if scheduler is not None:
            scheduler.step(val_loss)

        # Save best model
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            best_epoch = epoch
            epochs_without_improvement = 0

            # Save model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_dice': val_dice,
                'val_metrics': val_metrics
            }, model_save_path)

            print(f"Best model saved at epoch {epoch+1} with validation Dice score: {val_dice:.4f}")
        else:
            epochs_without_improvement += 1
            print(f"No improvement for {epochs_without_improvement} epochs (best Dice: {best_val_dice:.4f} at epoch {best_epoch+1})")

        # Early stopping
        if epochs_without_improvement >= early_stop_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Plot training curves
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(train_loss_values, label='Train Loss')
    plt.plot(val_loss_values, label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Loss Curves')

    plt.subplot(1, 2, 2)
    plt.plot(train_dice_values, label='Train Dice')
    plt.plot(val_dice_values, label='Val Dice')
    plt.xlabel('Epoch')
    plt.ylabel('Dice Score')
    plt.legend()
    plt.title('Dice Score Curves')

    plt.tight_layout()
    plt.show()

    return model_save_path, best_epoch, best_val_dice

# Start training with a smaller number of epochs for Colab demo
num_epochs = 5  # Reduced for demonstration
print(f"Starting training for {num_epochs} epochs...")

# Train the model
best_model_path, best_epoch, best_dice = train_model(
    model, train_loader, val_loader, optimizer, loss_function,
    max_epochs=num_epochs, scheduler=scheduler, scaler=scaler
)

print(f"\nTraining completed!")
print(f"Best model saved at {best_model_path}")
print(f"Best validation Dice score: {best_dice:.4f} at epoch {best_epoch+1}")

In [None]:
# Model Inference
# ==============

@torch.no_grad()
def infer(model, image):
    """
    Run inference on a single image
    """
    model.eval()

    # Use sliding window inference for large volumes
    roi_size = (96, 96, 64)
    sw_batch_size = 4

    # Perform inference
    output = sliding_window_inference(image, roi_size, sw_batch_size, model)

    # Get prediction
    pred = torch.argmax(output, dim=1, keepdim=True)

    return pred

# Load the best model
def load_best_model(model, model_path):
    """
    Load the best model from checkpoint
    """
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    best_epoch = checkpoint['epoch']
    best_val_dice = checkpoint['val_dice']

    print(f"Loaded best model from epoch {best_epoch+1} with validation Dice score: {best_val_dice:.4f}")

    return model

# Load the best model
model = load_best_model(model, best_model_path)

# Get a batch from the validation dataset
val_data = first(val_loader)
val_image = val_data["image"].to(device)
val_label = val_data["label"].to(device)

# Run inference
pred = infer(model, val_image)

# Print shapes
print(f"Image shape: {val_image.shape}")
print(f"Label shape: {val_label.shape}")
print(f"Prediction shape: {pred.shape}")

# Calculate metrics
metrics = calculate_metrics(torch.argmax(model(val_image), dim=1, keepdim=True), val_label)
print("\nInference Metrics:")
for metric_name, metric_value in metrics.items():
    print(f"  {metric_name}: {metric_value:.4f}")

In [None]:
# Results Visualization
# ====================

def visualize_results(image, label, pred, slice_idx=None, modality_idx=1):
    """
    Visualize the results of segmentation
    """
    # Get data from tensors
    image_np = image.detach().cpu().numpy()
    label_np = label.detach().cpu().numpy()
    pred_np = pred.detach().cpu().numpy()

    # Get dimensions
    _, C, D, H, W = image_np.shape

    # If slice_idx is not provided, use the middle slice
    if slice_idx is None:
        slice_idx = D // 2

    # Get the selected modality (T1ce is usually the most informative, index 1)
    image_slice = image_np[0, modality_idx, slice_idx, :, :]
    label_slice = label_np[0, 0, slice_idx, :, :]
    pred_slice = pred_np[0, 0, slice_idx, :, :]

    # Create a figure with subplots
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Display the original image
    axes[0].imshow(image_slice, cmap='gray')
    axes[0].set_title(f'Original MRI (Modality {modality_idx})')
    axes[0].axis('off')

    # Display the ground truth segmentation mask
    axes[1].imshow(image_slice, cmap='gray')
    mask = np.ma.masked_where(label_slice == 0, label_slice)
    axes[1].imshow(mask, cmap='hot', alpha=0.7)
    axes[1].set_title('Ground Truth Segmentation')
    axes[1].axis('off')

    # Display the predicted segmentation mask
    axes[2].imshow(image_slice, cmap='gray')
    mask = np.ma.masked_where(pred_slice == 0, pred_slice)
    axes[2].imshow(mask, cmap='hot', alpha=0.7)
    axes[2].set_title('Predicted Segmentation')
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize results across multiple slices
def visualize_multiple_slices(image, label, pred, num_slices=5, modality_idx=1):
    """
    Visualize results across multiple slices
    """
    # Get data from tensors
    image_np = image.detach().cpu().numpy()
    label_np = label.detach().cpu().numpy()
    pred_np = pred.detach().cpu().numpy()

    # Get dimensions
    _, C, D, H, W = image_np.shape

    # Calculate slice indices
    slice_indices = np.linspace(D // 4, 3 * D // 4, num_slices, dtype=int)

    # Create a large figure
    fig, axes = plt.subplots(3, num_slices, figsize=(20, 10))

    # For each slice
    for i, slice_idx in enumerate(slice_indices):
        # Get the selected modality slice
        image_slice = image_np[0, modality_idx, slice_idx, :, :]
        label_slice = label_np[0, 0, slice_idx, :, :]
        pred_slice = pred_np[0, 0, slice_idx, :, :]

        # Display the original image
        axes[0, i].imshow(image_slice, cmap='gray')
        if i == 0:
            axes[0, i].set_ylabel('Original MRI')
        axes[0, i].set_title(f'Slice {slice_idx}')
        axes[0, i].axis('off')

        # Display the ground truth segmentation
        axes[1, i].imshow(image_slice, cmap='gray')
        mask = np.ma.masked_where(label_slice == 0, label_slice)
        axes[1, i].imshow(mask, cmap='hot', alpha=0.7)
        if i == 0:
            axes[1, i].set_ylabel('Ground Truth')
        axes[1, i].axis('off')

        # Display the predicted segmentation
        axes[2, i].imshow(image_slice, cmap='gray')
        mask = np.ma.masked_where(pred_slice == 0, pred_slice)
        axes[2, i].imshow(mask, cmap='hot', alpha=0.7)
        if i == 0:
            axes[2, i].set_ylabel('Prediction')
        axes[2, i].axis('off')

    plt.tight_layout()
    plt.show()

# Visualize a single slice
print("Visualization of a single slice:")
visualize_results(val_image, val_label, pred)

# Visualize multiple slices
print("\nVisualization across multiple slices:")
visualize_multiple_slices(val_image, val_label, pred, num_slices=5)

In [None]:
# Evaluation and Conclusion
# =========================

# Calculate Dice score per class
def calculate_class_wise_dice(pred, target, num_classes=3):
    """
    Calculate Dice score for each class separately
    """
    dice_scores = []

    # Convert tensors to numpy arrays
    pred = pred.detach().cpu().numpy().squeeze()
    target = target.detach().cpu().numpy().squeeze()

    for i in range(1, num_classes):  # Skip background class (0)
        # Binary masks for this class
        class_pred = (pred == i).astype(np.float32)
        class_target = (target == i).astype(np.float32)

        # Calculate intersection and union
        intersection = np.sum(class_pred * class_target)
        union = np.sum(class_pred) + np.sum(class_target)

        # Calculate Dice score
        dice = (2.0 * intersection) / (union + 1e-5)
        dice_scores.append(dice)

    return dice_scores

# Calculate class-wise Dice scores
class_wise_dice = calculate_class_wise_dice(pred, val_label)
class_names = ["Tumor Core", "Enhancing Tumor"]

print("\nClass-wise Dice Scores:")
for i, (class_name, dice) in enumerate(zip(class_names, class_wise_dice)):
    print(f"  {class_name}: {dice:.4f}")

# Create a summary of the model and results
print("\n" + "="*50)
print("BrainSeg3D: 3D U-Net for Brain Tumor Segmentation")
print("="*50)

print("\nModel Architecture:")
print(f"- Input Channels: 4 (T1, T1ce, T2, FLAIR MRI modalities)")
print(f"- Output Channels: 3 (Background, Tumor Core, Enhancing Tumor)")
print(f"- Architecture: 3D U-Net with skip connections")
print(f"- Total Parameters: {count_parameters(model):,}")

print("\nTraining Summary:")
print(f"- Training Samples: {len(train_ds)}")
print(f"- Validation Samples: {len(val_ds)}")
print(f"- Best Validation Dice Score: {best_dice:.4f} (Epoch {best_epoch+1})")

print("\nFinal Evaluation Metrics:")
for metric_name, metric_value in metrics.items():
    print(f"  {metric_name}: {metric_value:.4f}")

print("\nClass-wise Dice Scores:")
for i, (class_name, dice) in enumerate(zip(class_names, class_wise_dice)):
    print(f"  {class_name}: {dice:.4f}")

print("\nConclusion:")
print("We have successfully implemented a 3D U-Net model for brain tumor segmentation")
print("using multi-modal MRI data. The model demonstrates the application of deep")
print("learning for medical image segmentation, particularly in neuroimaging.")

print("\nPotential Improvements:")
print("1. Train with more data for better generalization")
print("2. Experiment with different architectures (e.g., Attention U-Net)")
print("3. Use more advanced data augmentation techniques")
print("4. Implement post-processing to refine segmentation boundaries")
print("5. Explore different loss functions (e.g., Focal Dice Loss)")