# Dependencies & Utilities

## Install dependencies

In [None]:
!pip install -q -U "monai[nibabel, tqdm]"
!pip install -q -U wandb

In [None]:
import os
import time

import numpy as np
from tqdm.auto import tqdm
import wandb

from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.losses import DiceLoss, FocalLoss
from monai.config import print_config
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)
from monai.utils import set_determinism

import torch

import gc
from collections import deque

from google.colab import drive
from pathlib import Path

print_config()

MONAI version: 1.4.0
Numpy version: 1.26.4
Pytorch version: 2.5.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.25.0
scipy version: 1.13.1
Pillow version: 11.1.0
Tensorboard version: 2.17.1
gdown version: 5.2.0
TorchVision version: 0.20.1+cu121
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.5
pandas version: 2.2.2
einops version: 0.8.0
transformers version: 4.47.1
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/l

I train the model on Google Colab, so I need to mount my Google Drive and initialize Weights & Biases (wandb) to permanently store all my checkpoints and metrics.

In [None]:
# Mount Google Drive so that everything stored under `root_dir` (dataset and
# checkpoints) is kept on Drive rather than on the Colab VM's local disk.
drive.mount('/content/drive')
root_dir = Path("/content/drive/MyDrive/FSDS")

# Authenticate with Weights & Biases and start the run; the rest of the
# notebook reads `wandb.config` from this run and logs metrics to it.
wandb.login()
wandb.init(project="glioma-brain-tumor-segmentation")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


[34m[1mwandb[0m: Currently logged in as: [33mduongmaixa1207[0m ([33mduongmaixa1207-university-of-south-florida[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


## Label Preprocessing

Originally, the BraTS dataset provides MRI scans for multiple patients, each patient with 4 different sequences that correspond to 4 input channels. They are essentially 4 different features of the same scene. The BraTS task is interested in 3 sub-regions, namely ET (enhancing tumor), TC (tumor core) and WT (whole tumor), but the provided segmentation masks do not map one-to-one onto these 3 regions. Instead, each voxel carries one of 4 label values (background plus labels 1, 2 and 3). ET is label 2 only, TC is label 2 together with label 3 (NET, the necrotic and non-enhancing tumor core), and WT is labels 1, 2 and 3 combined. Hence, we need a transform that converts the single label map (labels 1, 2, 3) into one binary channel per BraTS class of interest.

We can also show the Decathlon data folder structure.

In [None]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Turn an integer BraTS label map into three stacked binary masks.

    Input label values:
        1 - peritumoral edema
        2 - GD-enhancing tumor
        3 - necrotic and non-enhancing tumor core

    Output channels (stacked along a new leading axis, as float):
        0 - TC (Tumor core):      label 2 or label 3
        1 - WT (Whole tumor):     label 1, 2 or 3
        2 - ET (Enhancing tumor): label 2

    Reference: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb

    """

    def __call__(self, data):
        # Work on a shallow copy so the caller's dictionary is not modified.
        converted = dict(data)
        for key in self.keys:
            label_map = converted[key]

            # One boolean mask per raw label value.
            is_edema = label_map == 1
            is_enhancing = label_map == 2
            is_necrotic = label_map == 3

            # TC merges labels 2 and 3; WT additionally includes label 1.
            tumor_core = torch.logical_or(is_enhancing, is_necrotic)
            whole_tumor = torch.logical_or(tumor_core, is_edema)

            converted[key] = torch.stack(
                [tumor_core, whole_tumor, is_enhancing], axis=0
            ).float()
        return converted

## Configuration

need a config file for reproducibility

Some interesting info:
Each sequence is a 3D volume (in-plane spatial dimensions x slices); each slice has a spatial dimension of 240 x 240, and each volume has 155 slices.

However, most of the scans are background, hence we would like to focus on a smaller region with more condensed information

we set epoch=50 due to limited computational resources

In [None]:
# Hyperparameters are stored on the W&B run config so they travel with the run.
config = wandb.config

# Assigned one attribute at a time, in this order.
_config_values = {
    # Reproducibility
    "seed": 29,
    # Training crop size and data loading
    "roi_size": [128, 128, 144],
    "batch_size": 2,
    "num_workers": 4,
    # How many volumes of each split are logged as W&B visualizations
    "max_train_images_visualized": 5,
    "max_val_images_visualized": 5,
    # Dice loss settings
    "dice_loss_smoothen_numerator": 0,
    "dice_loss_smoothen_denominator": 1e-5,
    "dice_loss_squared_prediction": True,
    "dice_loss_target_onehot": False,
    "dice_loss_apply_sigmoid": True,
    # Optimisation schedule
    "initial_learning_rate": 1e-4,
    "weight_decay": 1e-5,
    "max_train_epochs": 50,
    "validation_intervals": 5,
    # Storage locations on the mounted drive
    "dataset_dir": root_dir / "glioma_dataset",
    "checkpoint_dir": root_dir / "checkpoints",
    # Prediction / inference settings
    "inference_roi_size": (128, 128, 64),
    "max_prediction_images_visualized": 5,
}
for _name, _value in _config_values.items():
    setattr(config, _name, _value)

We use MONAI's `set_determinism` to further enhance reproducibility. The cell below shows where determinism is enabled, using the seed from the config.

In [None]:
# Seed MONAI's determinism helper with the configured seed.
set_determinism(seed=config.seed)

# Create directories (no error if they already exist)
for _directory in (config.dataset_dir, config.checkpoint_dir):
    os.makedirs(_directory, exist_ok=True)

# Download, Preprocess & Visualize Interactive Dataset

## Image Preprocessing

***Image Preprocessing***


* **Activations**: Applies activation functions to the model output (like sigmoid, softmax, etc.)

* **AsDiscrete**: Converts continuous values to discrete values, often used in segmentation tasks to convert probability maps to binary or multi-class masks.

* **Compose**: A utility to chain multiple transforms together

* **LoadImaged**: Loads medical images from files using specified readers (NIfTI)

* **MapTransform**: A base class for transforms that process dictionary data

* **NormalizeIntensityd**: Normalizes the intensity of input images, using mean and standard deviation

* **Orientationd**: Ensures medical images have a consistent orientation (important for 3D medical data)

* **EnsureTyped**: Ensures the input data has a specified data type

* **EnsureChannelFirstd**: Ensures the input data follows a "channel-first" format (important for deep learning frameworks)



***Data Augmentation***

* **RandFlipd**: Randomly flips the image along specified axes for data augmentation

* **RandScaleIntensityd**: Randomly scales the intensity of input images for data augmentation

* **RandShiftIntensityd**: Randomly shifts the intensity of input images for data augmentation

* **RandSpatialCropd**: Randomly crops the spatial dimensions of images for data augmentation

* **Spacingd**: Resamples images to have a specified voxel spacing

The 'd' suffix on most of these transforms indicates they operate on dictionary inputs rather than direct tensor inputs, which is a MONAI convention for handling metadata alongside image data.

In [None]:
_both_keys = ["image", "label"]


def _shared_preprocessing():
    """Return fresh instances of the deterministic steps common to both pipelines."""
    return [
        # load 4 Nifti images and stack them together
        LoadImaged(keys=_both_keys),
        EnsureChannelFirstd(keys="image"),
        EnsureTyped(keys=_both_keys),
        # integer label map -> 3 binary channels (TC, WT, ET)
        ConvertToMultiChannelBasedOnBratsClassesd(keys="label"),
        Orientationd(keys=_both_keys, axcodes="RAS"),
        Spacingd(
            keys=_both_keys,
            pixdim=(1.0, 1.0, 1.0),
            mode=("bilinear", "nearest"),
        ),
    ]


def _intensity_normalization():
    """Return the per-channel normalization (non-zero voxels only) applied to the image."""
    return NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True)


# Training pipeline: shared preprocessing, then a fixed-size random crop,
# random flips along each spatial axis, normalization and intensity jitter.
train_transform = Compose(
    _shared_preprocessing()
    + [
        RandSpatialCropd(
            keys=_both_keys, roi_size=config.roi_size, random_size=False
        ),
        *[
            RandFlipd(keys=_both_keys, prob=0.5, spatial_axis=flip_axis)
            for flip_axis in (0, 1, 2)
        ],
        _intensity_normalization(),
        RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
        RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
    ]
)

# Validation pipeline: shared preprocessing plus normalization, no augmentation.
val_transform = Compose(_shared_preprocessing() + [_intensity_normalization()])

## Generate Interactive Data Visualizations in WandB

In [None]:
def _ground_truth_masks(sample_label, slice_idx):
    """Build the W&B mask overlays (TC, WT, ET) for one slice of a label volume."""
    return {
        "ground-truth/Tumor-Core": {
            "mask_data": sample_label[0, :, :, slice_idx],
            "class_labels": {0: "background", 1: "Tumor Core"},
        },
        "ground-truth/Whole-Tumor": {
            # Scaled so this overlay uses class id 2 instead of 1.
            "mask_data": sample_label[1, :, :, slice_idx] * 2,
            "class_labels": {0: "background", 2: "Whole Tumor"},
        },
        "ground-truth/Enhancing-Tumor": {
            # Scaled so this overlay uses class id 3 instead of 1.
            "mask_data": sample_label[2, :, :, slice_idx] * 3,
            "class_labels": {0: "background", 3: "Enhancing Tumor"},
        },
    }


def log_data_samples_into_tables(
    sample_image: np.ndarray,
    sample_label: np.ndarray,
    split: str,
    data_idx: int,
    table: wandb.Table,
) -> wandb.Table:
    """Add one table row per slice of a volume, with ground-truth overlays.

    Args:
        sample_image: 4-D array indexed as [channel, :, :, slice].
        sample_label: array indexed as [class, :, :, slice]; channels 0, 1, 2
            are used as Tumor Core, Whole Tumor and Enhancing Tumor.
        split: name of the data split, stored in the row and shown in the
            progress bar.
        data_idx: index of the volume within the split.
        table: W&B table to append to. Each row is
            (split, data_idx, slice_idx, one image per channel), so the table
            needs one image column per channel of ``sample_image``.

    Returns:
        The same ``table`` object, after the rows were added.

    Raises:
        ValueError: if ``sample_image`` is not 4-dimensional.
    """
    # Unpacking (rather than indexing .shape) keeps the 4-D requirement explicit.
    num_channels, _, _, num_slices = sample_image.shape

    with tqdm(
        range(num_slices),
        leave=False,
        desc=f"Processing {split} data index {data_idx}",
    ) as slice_indices:
        for slice_idx in slice_indices:
            # A fresh mask dictionary is built for every image.
            ground_truth_wandb_images = [
                wandb.Image(
                    sample_image[channel_idx, :, :, slice_idx],
                    masks=_ground_truth_masks(sample_label, slice_idx),
                )
                for channel_idx in range(num_channels)
            ]
            table.add_data(split, data_idx, slice_idx, *ground_truth_wandb_images)
    return table

def generate_visualizations(dataset, split, table, max_samples):
    """Log the first ``max_samples`` volumes of ``dataset`` into ``table``.

    Assumes ``dataset`` supports slice indexing and yields dictionaries whose
    "image" and "label" entries are tensors. Each volume is converted to NumPy
    on the CPU and handed to ``log_data_samples_into_tables``.

    Returns the table returned by the last logging call (or ``table`` itself
    when there is nothing to iterate).
    """
    samples = tqdm(
        enumerate(dataset[:max_samples]),
        total=max_samples,
        desc=f"Generating {split.capitalize()} Dataset Visualizations:",
    )
    for data_idx, sample in samples:
        # Image first, then label: detach from autograd and move to host memory.
        image_np, label_np = (
            sample[key].detach().cpu().numpy() for key in ("image", "label")
        )
        table = log_data_samples_into_tables(
            image_np, label_np, split=split, data_idx=data_idx, table=table
        )
    return table


## Download Data and Visualize it on WandB

In [None]:
start = time.time()
print("Training data extraction in progress...\n")

# Settings shared by both splits.
# for visualization purpose, we do not use the train_transform for train data
# since we do not need the data augmentation part yet
_decathlon_options = dict(
    root_dir=config.dataset_dir,
    task="Task01_BrainTumour",
    transform=val_transform,
    cache_rate=0.0,
    num_workers=4,
)

# Only the training split requests a download.
train_dataset = DecathlonDataset(
    section="training", download=True, **_decathlon_options
)

print("\nValidation data extraction in progress...")
val_dataset = DecathlonDataset(
    section="validation", download=False, **_decathlon_options
)

_elapsed_seconds = time.time() - start
print("Data extraction takes {} seconds".format(_elapsed_seconds))

Training data extraction in progress...

2025-01-14 08:52:24,959 - INFO - Verified 'Task01_BrainTumour.tar', md5: 240a19d752f0d9e9101544901065d872.
2025-01-14 08:52:24,961 - INFO - File exists: /content/drive/MyDrive/FSDS/glioma_dataset/Task01_BrainTumour.tar, skipped downloading.
2025-01-14 08:52:24,963 - INFO - Non-empty folder exists in /content/drive/MyDrive/FSDS/glioma_dataset/Task01_BrainTumour, skipped extracting.

Validation data extraction in progress...
Data extraction takes 20.23122763633728 seconds


In [None]:

def _visualization_sample_count(limit, dataset):
    """Return how many samples to visualize: ``limit`` capped at the dataset
    size, or the whole dataset when ``limit`` is not positive."""
    dataset_size = len(dataset)
    if limit > 0:
        return min(limit, dataset_size)
    return dataset_size


# Initialize the W&B table: three bookkeeping columns plus one image column
# per input channel.
columns = ["Split", "Data Index", "Slice Index"] + [
    "Image-Channel-0",
    "Image-Channel-1",
    "Image-Channel-2",
    "Image-Channel-3",
]
table = wandb.Table(columns=columns)

# Generate visualizations for train_dataset
max_train_samples = _visualization_sample_count(
    config.max_train_images_visualized, train_dataset
)
table = generate_visualizations(train_dataset, "train", table, max_train_samples)

# Generate visualizations for val_dataset
max_val_samples = _visualization_sample_count(
    config.max_val_images_visualized, val_dataset
)
table = generate_visualizations(val_dataset, "val", table, max_val_samples)

# Log the table to your dashboard
wandb.log({"Glioma-Segmentation-Data": table})

Generating Train Dataset Visualizations::   0%|          | 0/5 [00:00<?, ?it/s]

Processing train data index 0:   0%|          | 0/155 [00:00<?, ?it/s]

Processing train data index 1:   0%|          | 0/155 [00:00<?, ?it/s]

Processing train data index 2:   0%|          | 0/155 [00:00<?, ?it/s]

Processing train data index 3:   0%|          | 0/155 [00:00<?, ?it/s]

Processing train data index 4:   0%|          | 0/155 [00:00<?, ?it/s]

Generating Val Dataset Visualizations::   0%|          | 0/5 [00:00<?, ?it/s]

Processing val data index 0:   0%|          | 0/155 [00:00<?, ?it/s]

Processing val data index 1:   0%|          | 0/155 [00:00<?, ?it/s]

Processing val data index 2:   0%|          | 0/155 [00:00<?, ?it/s]

Processing val data index 3:   0%|          | 0/155 [00:00<?, ?it/s]

Processing val data index 4:   0%|          | 0/155 [00:00<?, ?it/s]

# Load Data & Train Model

***Image Preprocessing***


* **Activations**: Applies activation functions to the model output (like sigmoid, softmax, etc.)

* **AsDiscrete**: Converts continuous values to discrete values, often used in segmentation tasks to convert probability maps to binary or multi-class masks.

* **Compose**: A utility to chain multiple transforms together

* **LoadImaged**: Loads medical images from files using specified readers (NIfTI)

* **MapTransform**: A base class for transforms that process dictionary data

* **NormalizeIntensityd**: Normalizes the intensity of input images, using mean and standard deviation

* **Orientationd**: Ensures medical images have a consistent orientation (important for 3D medical data)

* **EnsureTyped**: Ensures the input data has a specified data type

* **EnsureChannelFirstd**: Ensures the input data follows a "channel-first" format (important for deep learning frameworks)



***Data Augmentation***

* **RandFlipd**: Randomly flips the image along specified axes for data augmentation

* **RandScaleIntensityd**: Randomly scales the intensity of input images for data augmentation

* **RandShiftIntensityd**: Randomly shifts the intensity of input images for data augmentation

* **RandSpatialCropd**: Randomly crops the spatial dimensions of images for data augmentation

* **Spacingd**: Resamples images to have a specified voxel spacing

The 'd' suffix on most of these transforms indicates they operate on dictionary inputs rather than direct tensor inputs, which is a MONAI convention for handling metadata alongside image data.

## Dataloader

In [None]:
# The training dataset was built with val_transform for visualization;
# switch it to the augmenting train_transform before training.
train_dataset.transform = train_transform

# Options common to both loaders.
_loader_options = dict(
    batch_size=config.batch_size,
    num_workers=config.num_workers,
)

# Training batches are reshuffled; validation keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)

## Model Initialization

This code instantiates a SegResNet neural network for 3D medical image segmentation with 4 input channels and 3 output classes. The model uses an asymmetric architecture with more complexity in the downsampling path. Training is configured with Adam optimizer and cosine annealing learning rate scheduling. The entire pipeline runs on GPU for faster computation.

In [None]:
# Everything below runs on the first CUDA device.
device = torch.device("cuda:0")

# Network: 4 input channels (one per MRI sequence) -> 3 output channels
# (one per tumor sub-region).
_segresnet_options = dict(
    in_channels=4,
    out_channels=3,
    init_filters=16,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
    dropout_prob=0.2,
)
model = SegResNet(**_segresnet_options).to(device)

# Adam with weight decay, both taken from the run config.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=config.initial_learning_rate,
    weight_decay=config.weight_decay,
)

# Cosine annealing over the full number of training epochs
# (the scheduler is stepped once per epoch in the training loop).
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=config.max_train_epochs
)

## Combined Loss function

This loss function combines Dice Loss and Focal Loss for medical image segmentation. Dice Loss promotes structural overlap between predictions and ground truth, while Focal Loss focuses on hard-to-classify examples. The losses are weighted equally by default (0.5 each) but can be adjusted. This combination addresses class imbalance and boundary delineation challenges, improving performance on both large and small anatomical structures.

In [None]:
class CombinedLoss(torch.nn.Module):
    """Weighted sum of a Dice loss and a Focal loss.

    The Dice term is configured from the global ``config`` (smoothing terms,
    squared predictions, one-hot target flag, sigmoid flag); the Focal term
    uses the library defaults.

    Args:
        dice_weight: multiplier applied to the Dice loss.
        focal_weight: multiplier applied to the Focal loss.
    """

    def __init__(self, dice_weight=0.5, focal_weight=0.5):
        super().__init__()
        dice_options = dict(
            smooth_nr=config.dice_loss_smoothen_numerator,
            smooth_dr=config.dice_loss_smoothen_denominator,
            squared_pred=config.dice_loss_squared_prediction,
            to_onehot_y=config.dice_loss_target_onehot,
            sigmoid=config.dice_loss_apply_sigmoid,
        )
        self.dice_loss = DiceLoss(**dice_options)
        self.focal_loss = FocalLoss()
        self.dice_weight = dice_weight
        self.focal_weight = focal_weight

    def forward(self, y_pred, y_true):
        """Return ``dice_weight * dice + focal_weight * focal`` for the batch."""
        weighted_dice = self.dice_weight * self.dice_loss(y_pred, y_true)
        weighted_focal = self.focal_weight * self.focal_loss(y_pred, y_true)
        return weighted_dice + weighted_focal


# Loss function (weighted 0.5 each)
loss_function = CombinedLoss(dice_weight=0.5, focal_weight=0.5)


## Inference & Evaluation Metric/Utility

This code configures a medical segmentation training pipeline with dual Dice metrics for evaluation, and post-processing to convert predictions to binary masks. Performance is optimized through mixed-precision training and CUDNN benchmarking for faster GPU execution.

In [None]:
# evaluation metrics: one accumulator reduced with "mean" (a single overall
# score) and one reduced with "mean_batch" (indexed per output channel in the
# training loop)
dice_metric, dice_metric_batch = (
    DiceMetric(include_background=True, reduction=reduction)
    for reduction in ("mean", "mean_batch")
)

# post-processing: sigmoid on the raw outputs, then binarize at 0.5
_to_probabilities = Activations(sigmoid=True)
_to_binary_mask = AsDiscrete(threshold=0.5)
post_trans = Compose([_to_probabilities, _to_binary_mask])

# use automatic mixed-precision to accelerate training
scaler = torch.amp.GradScaler('cuda')
torch.backends.cudnn.benchmark = True

This function performs efficient 3D medical image inference using sliding window technique. It processes large volumes by dividing them into overlapping patches of size (240, 240, 160), runs predictions on each patch, and seamlessly combines the results. The implementation leverages CUDA mixed-precision for faster execution and manages memory constraints by controlling batch size and 50% window overlap.

In [None]:
def inference(model, input):
    """Run sliding-window inference on ``input`` under CUDA autocast.

    The volume is processed with windows of size (240, 240, 160), one window
    per forward pass (``sw_batch_size=1``) and 50% overlap between windows.

    Args:
        model: callable used as the sliding-window predictor.
        input: batch of volumes passed to ``sliding_window_inference``.

    Returns:
        Whatever ``sliding_window_inference`` returns for the batch.
    """
    # NOTE(review): the window size is hard-coded here and does not read
    # `config.inference_roi_size` — confirm which one is intended.
    with torch.amp.autocast('cuda'):
        return sliding_window_inference(
            inputs=input,
            roi_size=(240, 240, 160),
            sw_batch_size=1,
            predictor=model,
            overlap=0.5,
        )

Log metric into WandB

In [None]:
# Give each metric family its own x-axis: first declare the step metric,
# then bind every metric under the same prefix to it.
for _step_metric, _metric_pattern in (
    ("epoch/epoch_step", "epoch/*"),
    ("batch/batch_step", "batch/*"),
    ("validation/validation_step", "validation/*"),
):
    wandb.define_metric(_step_metric)
    wandb.define_metric(_metric_pattern, step_metric=_step_metric)

<wandb.sdk.wandb_metric.Metric at 0x792d2d168af0>

## Model Training

This code implements a robust training loop for a medical image segmentation model with brain tumor subtypes. It handles batch training with mixed-precision, periodic validation, and comprehensive metrics tracking (overall Dice score plus separate metrics for tumor core, whole tumor, and enhanced tumor). The implementation includes memory management through explicit cleanup (for Google Colab integration), error handling with try-except blocks, and integration with Weights & Biases for experiment tracking. Progress bars provide visual feedback during both epoch and batch-level training steps.

In [None]:
# Global step counters used as x-axes for the batch and validation metrics.
batch_step = validation_step = 0

# Bounded history of validation scores: only the most recent
# `metric_window_size` validation rounds are kept. One independent deque for
# the overall Dice and one per tumor sub-region.
metric_window_size = 25
(
    metric_values,
    metric_values_tumor_core,
    metric_values_whole_tumor,
    metric_values_enhanced_tumor,
) = (deque(maxlen=metric_window_size) for _ in range(4))

# Periodic memory cleanup helper
def cleanup_memory():
    """Run the Python garbage collector, then empty the CUDA cache."""
    # Collect first, then empty the cache (same order as before).
    gc.collect()
    torch.cuda.empty_cache()

# Training driver.
#
# For every epoch: train over `train_loader` with mixed precision, step the LR
# scheduler, and log the sample-weighted mean loss. Every
# `config.validation_intervals` epochs: run sliding-window validation, record
# the Dice scores (overall + TC / WT / ET), save a checkpoint and log it as a
# W&B artifact. Failures in a single batch or in a logging call are reported
# and skipped; any other failure is reported as critical and ends training.
start = time.time()

# Bind the progress bars up front so the clean-up in `finally` can tell
# "never created" apart from "needs closing" instead of hitting an unbound name.
epoch_progress_bar = None
batch_progress_bar = None

try:
    epoch_progress_bar = tqdm(range(config.max_train_epochs), desc="Training:")

    for epoch in epoch_progress_bar:
        cleanup_memory()

        # Reset epoch counters (used for the sample-weighted mean loss)
        epoch_samples = 0
        epoch_loss = 0
        model.train()

        # len(train_loader) is the actual number of batches the loader yields.
        batch_progress_bar = tqdm(train_loader, total=len(train_loader), leave=False)

        # Training Step
        for batch_data in batch_progress_bar:
            try:
                batch_size = batch_data["image"].size(0)
                inputs, labels = (
                    batch_data["image"].to(device, non_blocking=True),
                    batch_data["label"].to(device, non_blocking=True),
                )

                optimizer.zero_grad()
                with torch.amp.autocast('cuda'):
                    outputs = model(inputs)
                    loss = loss_function(outputs, labels)

                del outputs  # Explicitly free memory

                # Mixed-precision update: scaled backward pass, optimizer step
                # through the scaler, then update of the scale factor.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # Read the scalar once and reuse it for accumulation, logging
                # and the progress bar.
                loss_value = loss.item()
                epoch_loss += loss_value * batch_size
                epoch_samples += batch_size

                # Log training metrics (best effort: a logging failure must
                # not stop training)
                try:
                    wandb.log({
                        "batch/batch_step": batch_step,
                        "batch/train_loss": loss_value,
                        "batch/samples_processed": epoch_samples,
                    })
                except Exception as e:
                    print(f"Warning: Failed to log batch metrics: {e}")

                batch_step += 1
                batch_progress_bar.set_description(f"train_loss: {loss_value:.4f}")

                if batch_step % 50 == 0:
                    cleanup_memory()

            except Exception as e:
                # Skip the failing batch and keep training.
                print(f"Error in training batch: {e}")
                continue

        # End of epoch processing
        lr_scheduler.step()
        epoch_loss = epoch_loss / epoch_samples if epoch_samples > 0 else 0

        try:
            wandb.log({
                "epoch/epoch_step": epoch,
                "epoch/mean_train_loss": epoch_loss,
                "epoch/learning_rate": lr_scheduler.get_last_lr()[0],
            })
        except Exception as e:
            print(f"Warning: Failed to log epoch metrics: {e}")

        epoch_progress_bar.set_description(f"Training: train_loss: {epoch_loss:.4f}")

        # Validation runs only every `config.validation_intervals` epochs.
        if (epoch + 1) % config.validation_intervals == 0:
            cleanup_memory()

            model.eval()

            with torch.no_grad():
                val_loader_with_progress = tqdm(val_loader, desc="Validating:", leave=False)

                for val_data in val_loader_with_progress:
                    # Bound before the try so the error handler below can
                    # safely report whatever had been produced when the
                    # failure happened.
                    val_inputs = val_labels = val_outputs = processed_outputs = None
                    try:
                        val_inputs, val_labels = (
                            val_data["image"].to(device, non_blocking=True),
                            val_data["label"].to(device, non_blocking=True),
                        )

                        val_outputs = inference(model, val_inputs)

                        # Split the batch into per-sample outputs and turn
                        # each one into a binary mask.
                        processed_outputs = [
                            post_trans(sample_output)
                            for sample_output in decollate_batch(val_outputs)
                        ]

                        # Accumulate Dice into both metric objects; they are
                        # aggregated once after the whole validation pass.
                        dice_metric(y_pred=processed_outputs, y=val_labels)
                        dice_metric_batch(y_pred=processed_outputs, y=val_labels)

                    except Exception as e:
                        print(f"Error in validation batch: {e}")
                        if val_outputs is not None:
                            print(f"val_outputs shape: {[v.shape for v in val_outputs]}")
                        if val_labels is not None:
                            print(f"val_labels shape: {val_labels.shape}")
                    finally:
                        # Drop the references (on success and on failure)
                        # before asking for the memory back.
                        val_inputs = val_labels = val_outputs = processed_outputs = None
                        cleanup_memory()

                # Aggregate validation metrics; the batch-reduced metric is
                # indexed as 0 = tumor core, 1 = whole tumor, 2 = enhancing tumor.
                metric_values.append(dice_metric.aggregate().item())
                metric_batch = dice_metric_batch.aggregate()
                metric_values_tumor_core.append(metric_batch[0].item())
                metric_values_whole_tumor.append(metric_batch[1].item())
                metric_values_enhanced_tumor.append(metric_batch[2].item())
                dice_metric.reset()
                dice_metric_batch.reset()

                # Save checkpoint (the local file is overwritten each time;
                # per-epoch versions are kept through the W&B artifact aliases)
                try:
                    checkpoint_path = os.path.join(config.checkpoint_dir, "model.pth")
                    torch.save(model.state_dict(), checkpoint_path)

                    artifact = wandb.Artifact(
                        name="model-checkpoint",
                        type="model",
                        description=f"Model checkpoint for epoch {epoch}"
                    )
                    artifact.add_file(local_path=checkpoint_path)
                    wandb.log_artifact(artifact, aliases=[f"epoch_{epoch}", "latest"])
                except Exception as e:
                    print(f"Warning: Failed to save checkpoint: {e}")

                # Log final validation metrics
                try:
                    wandb.log({
                        "validation/validation_step": validation_step,
                        "validation/mean_dice": metric_values[-1],
                        "validation/mean_dice_tumor_core": metric_values_tumor_core[-1],
                        "validation/mean_dice_whole_tumor": metric_values_whole_tumor[-1],
                        "validation/mean_dice_enhanced_tumor": metric_values_enhanced_tumor[-1],
                    })
                except Exception as e:
                    print(f"Warning: Failed to log final validation metrics: {e}")

                validation_step += 1

except Exception as e:
    print(f"Critical error in training loop: {e}")

finally:
    # Cleanup: each step is attempted independently so that one failure
    # (or a bar that was never created) cannot prevent the W&B run from
    # being finished.
    for _progress_bar in (batch_progress_bar, epoch_progress_bar):
        if _progress_bar is None:
            continue
        try:
            _progress_bar.close()
        except Exception as e:
            print(f"Error during cleanup: {e}")

    try:
        wandb.finish()
    except Exception as e:
        print(f"Error during cleanup: {e}")

    print(f"Entire training process took {time.time()-start:.2f} seconds")

Training::   0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

  0%|          | 0/194 [00:00<?, ?it/s]

Validating::   0%|          | 0/48 [00:00<?, ?it/s]

0,1
batch/batch_step,▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███
batch/samples_processed,▃▄▄▅▃▇█▂▂▃▄▆▆▂▆▃▆▄▇▇▅▂▇▁▂▄█▁▇▃▆▂▇▂▄▅▅▇▅▁
batch/train_loss,█▇▄▄▄▂▅▄▃▅▃▂▄▅▄▃▂▂▂▂▄▆▁▃▅▂▂▁▄▂▂▂▄▄▄▃▅▁▄▁
epoch/epoch_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
epoch/learning_rate,███████▇▇▇▇▇▇▆▆▆▆▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁
epoch/mean_train_loss,█▇▆▆▅▄▄▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
validation/mean_dice,▁▃▇▇▇▇████
validation/mean_dice_enhanced_tumor,▁▃▇▇▇▇████
validation/mean_dice_tumor_core,▁▃▇▇▆█████
validation/mean_dice_whole_tumor,▁▄▆▇▆▇████

0,1
batch/batch_step,9699.0
batch/samples_processed,388.0
batch/train_loss,0.11506
epoch/epoch_step,49.0
epoch/learning_rate,0.0
epoch/mean_train_loss,0.14705
validation/mean_dice,0.73862
validation/mean_dice_enhanced_tumor,0.53888
validation/mean_dice_tumor_core,0.78385
validation/mean_dice_whole_tumor,0.89314


Entire training process took 10234.81 seconds


It is almost done! only 1 minute left

# Model Evaluation & Prediction Logging

This function visualizes brain tumor segmentation results by creating a Weights & Biases table with side-by-side comparisons of model predictions and ground truth. For each slice of a 3D brain scan, it generates overlay visualizations for three tumor subregions (tumor core, whole tumor, and enhancing tumor) across all imaging channels (likely T1, T2, FLAIR, etc.). The function uses different color mappings for predictions versus ground truth masks to facilitate comparison, and includes a progress bar to track processing of potentially hundreds of image slices.

## Prediction Logging

In [None]:
def _region_overlay(image_slice, truth_mask, predicted_mask, mask_key, class_name):
    """Build one ``wandb.Image`` overlaying ground truth and prediction for a region."""
    return wandb.Image(
        image_slice,
        masks={
            f"ground-truth/{mask_key}": {
                "mask_data": truth_mask,
                "class_labels": {0: "background", 1: class_name},
            },
            f"prediction/{mask_key}": {
                # Scaled so positive voxels get class id 2 and are labelled
                # separately from the ground truth (class id 1).
                # Assumes the prediction is a binary 0/1 mask -- TODO confirm.
                "mask_data": predicted_mask * 2,
                "class_labels": {0: "background", 2: class_name},
            },
        },
    )


def log_predictions_into_tables(
    sample_image: np.ndarray,
    sample_label: np.ndarray,
    predicted_label: np.ndarray,
    split: str = None,
    data_idx: int = None,
    table: wandb.Table = None,
):
    """Append one row per slice of a volume to a W&B prediction table.

    Each row holds ``split``, ``data_idx``, the slice index and then one
    overlay image per (tumor region, image channel) pair. Images are emitted
    region-major: every channel for Tumor-Core, then every channel for
    Whole-Tumor, then every channel for Enhancing-Tumor. This matches a table
    whose columns are laid out as ``Image-Channel-<c>/<Region>`` grouped by
    region.

    Args:
        sample_image: Array indexed as ``[channel, :, :, slice]``.
        sample_label: Ground-truth masks indexed as ``[region, :, :, slice]``
            where region 0/1/2 is tumor core / whole tumor / enhancing tumor.
        predicted_label: Predicted masks, indexed like ``sample_label``.
        split: Value stored in the first column of every row.
        data_idx: Value stored in the second column of every row.
        table: The ``wandb.Table`` that rows are added to.

    Returns:
        The same ``table`` object, with the new rows added.

    Raises:
        ValueError: If ``table`` is ``None``.
    """
    if table is None:
        raise ValueError("`table` must be a wandb.Table to add rows to")

    # (index into the label arrays, mask/column key, human-readable class name)
    regions = (
        (0, "Tumor-Core", "Tumor Core"),
        (1, "Whole-Tumor", "Whole Tumor"),
        (2, "Enhancing-Tumor", "Enhancing Tumor"),
    )

    num_channels, _, _, num_slices = sample_image.shape
    with tqdm(total=num_slices, leave=False) as progress_bar:
        for slice_idx in range(num_slices):
            # Region is the outer loop so the images line up with columns
            # that are grouped by region and then ordered by channel.
            wandb_images = [
                _region_overlay(
                    sample_image[channel_idx, :, :, slice_idx],
                    sample_label[region_idx, :, :, slice_idx],
                    predicted_label[region_idx, :, :, slice_idx],
                    mask_key,
                    class_name,
                )
                for region_idx, mask_key, class_name in regions
                for channel_idx in range(num_channels)
            ]
            table.add_data(split, data_idx, slice_idx, *wandb_images)
            progress_bar.update(1)
    return table

## Evaluation

This code creates a visual validation pipeline for a brain tumor segmentation model using Weights & Biases. It initializes a W&B session, loads a pre-trained model from a versioned artifact, and creates a structured table for visualization. The pipeline processes a configurable number of validation samples, generating side-by-side comparisons of ground truth and predictions for three tumor subregions (tumor core, whole tumor, and enhancing tumor) across all four MRI modalities (channels). Results are logged to W&B as an interactive table, allowing for detailed qualitative evaluation of the model's performance on 3D brain scans.

In [None]:
# Start a W&B run so the prediction table is logged to the project.
wandb.init(project="glioma-brain-tumor-segmentation")

# Fetch the versioned checkpoint artifact and restore its weights.
model_artifact = wandb.use_artifact(
    "duongmaixa1207-university-of-south-florida/glioma-brain-tumor-segmentation/model-checkpoint:v9",
    type="model",
)

model_artifact_dir = model_artifact.download()
checkpoint_path = os.path.join(model_artifact_dir, "model.pth")
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs, and avoids the torch.load FutureWarning.
# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(
    torch.load(checkpoint_path, map_location=device, weights_only=True)
)
model.eval()

# Create the prediction table. The image columns are grouped by tumor
# sub-region and, within each region, ordered by image channel.
prediction_table = wandb.Table(
    columns=[
        "Split",
        "Data Index",
        "Slice Index",
        "Image-Channel-0/Tumor-Core",
        "Image-Channel-1/Tumor-Core",
        "Image-Channel-2/Tumor-Core",
        "Image-Channel-3/Tumor-Core",
        "Image-Channel-0/Whole-Tumor",
        "Image-Channel-1/Whole-Tumor",
        "Image-Channel-2/Whole-Tumor",
        "Image-Channel-3/Whole-Tumor",
        "Image-Channel-0/Enhancing-Tumor",
        "Image-Channel-1/Enhancing-Tumor",
        "Image-Channel-2/Enhancing-Tumor",
        "Image-Channel-3/Enhancing-Tumor",
    ]
)

# Perform inference and visualization
with torch.no_grad():
    # A non-positive limit means "visualize the whole validation set".
    max_visualized = config.max_prediction_images_visualized
    max_samples = (
        min(max_visualized, len(val_dataset))
        if max_visualized > 0
        else len(val_dataset)
    )
    progress_bar = tqdm(
        enumerate(val_dataset[:max_samples]),
        total=max_samples,
        desc="Generating Predictions:",
    )
    for data_idx, sample in progress_bar:
        # Add a leading batch dimension for the model, then drop it again
        # before post-processing the single prediction.
        val_input = sample["image"].unsqueeze(0).to(device)
        val_output = inference(model, val_input)
        val_output = post_trans(val_output[0])
        prediction_table = log_predictions_into_tables(
            sample_image=sample["image"].cpu().numpy(),
            sample_label=sample["label"].cpu().numpy(),
            predicted_label=val_output.cpu().numpy(),
            data_idx=data_idx,
            split="validation",
            table=prediction_table,
        )

    wandb.log({"Predictions/Tumor-Segmentation-Data": prediction_table})


# End the experiment
wandb.finish()


[34m[1mwandb[0m:   1 of 1 files downloaded.  
  model.load_state_dict(torch.load(os.path.join(model_artifact_dir, "model.pth")))


Generating Predictions::   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

  0%|          | 0/155 [00:00<?, ?it/s]

# TODO (later)

Export the model to ONNX for interoperability and deployment. Running it with ONNX Runtime rather than the default PyTorch runtime can also give an (unexpected) speed gain. After that, consider writing a report on W&B.

[Report reference](https://wandb.ai/geekyrakshit/brain-tumor-segmentation/reports/Brain-Tumor-Segmentation-using-MONAI-and-WandB---Vmlldzo0MjUzODIw)

To speed up training: use `sliding_window_inference`, don't cache, and perform periodic cleanup.

Now, to improve inference speed, we can do what the highlighted text describes.

# Reference

MONAI project: https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/brats_segmentation_3d.ipynb