Imports and Setup


In [None]:
# imports.py
import os
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from monai.config import print_config
from monai.data import CacheDataset
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged,
    CropForegroundd, RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandZoomd, EnsureTyped
)
from monai.networks.nets import VNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference

# Fix the global NumPy and torch RNG seeds so repeated runs draw the same
# random numbers (weight init, random augmentations driven by these RNGs).
np.random.seed(0)
torch.manual_seed(0)

# Compute device shared by every cell: the current CUDA device when one is
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


Data Preparation


In [None]:
# data_preparation.py
from imports import *

def prepare_data(data_dir, train_fraction=0.8):
    """Pair image/label NIfTI files under ``data_dir`` and split them into train/val.

    Images are read from ``<data_dir>/images/*.nii`` and labels from
    ``<data_dir>/labels/*.nii``. Both lists are sorted and paired by position,
    so the i-th image (in sorted order) is matched with the i-th label; file
    names in the two folders must therefore sort in the same order.

    Args:
        data_dir: Root directory containing ``images`` and ``labels`` folders.
        train_fraction: Fraction of pairs, taken from the start of the sorted
            list, used for training; the remainder goes to validation.

    Returns:
        ``(train_files, val_files)``: lists of ``{"image": path, "label": path}``
        dicts.

    Raises:
        ValueError: If ``train_fraction`` is outside (0, 1], or the number of
            images and labels differ (zip would otherwise silently drop or
            mis-pair files).
        FileNotFoundError: If no images are found.
    """
    if not 0 < train_fraction <= 1:
        raise ValueError(f"train_fraction must be in (0, 1], got {train_fraction!r}")

    image_files = sorted(glob.glob(os.path.join(data_dir, "images", "*.nii")))
    label_files = sorted(glob.glob(os.path.join(data_dir, "labels", "*.nii")))

    if not image_files:
        raise FileNotFoundError(f"No .nii images found in {os.path.join(data_dir, 'images')}")
    if len(image_files) != len(label_files):
        raise ValueError(
            f"Found {len(image_files)} images but {len(label_files)} labels in {data_dir}; "
            "every image needs exactly one label."
        )

    data_list = [{"image": img, "label": lbl} for img, lbl in zip(image_files, label_files)]

    # Deterministic split: first part of the sorted list trains, rest validates.
    split_index = int(len(data_list) * train_fraction)
    return data_list[:split_index], data_list[split_index:]

def get_transforms():
    """Build the MONAI transform pipelines for training and validation.

    Both pipelines share the same deterministic preprocessing:
    - load the volumes and move the channel axis first;
    - resample to (1.5, 1.5, 2.0) pixdim and reorient to RAS;
    - clip and rescale image intensities from [-100, 400] to [0, 1];
    - crop to the image foreground.

    Training additionally samples 4 patches of 128 x 128 x 64 per volume,
    balanced between label-positive and label-negative centres. It then
    applies light random flip, rotate and zoom augmentation.

    Returns:
        ``(train_transforms, val_transforms)``: two ``Compose`` pipelines.
    """

    def _preprocessing():
        # Built fresh for each pipeline so the two Compose objects share no
        # transform instances.
        return [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),
            # Label is resampled with nearest-neighbour to keep integer classes.
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            ScaleIntensityRanged(keys=["image"], a_min=-100, a_max=400, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["image", "label"], source_key="image"),
        ]

    train_transforms = Compose(
        _preprocessing()
        + [
            RandCropByPosNegLabeld(
                keys=["image", "label"],
                label_key="label",
                spatial_size=(128, 128, 64),
                pos=1,
                neg=1,
                num_samples=4,
                image_key="image",
                image_threshold=0,
            ),
            RandFlipd(keys=["image", "label"], spatial_axis=[0], prob=0.10),
            RandFlipd(keys=["image", "label"], spatial_axis=[1], prob=0.10),
            RandFlipd(keys=["image", "label"], spatial_axis=[2], prob=0.10),
            RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
            # The label must be zoomed with "nearest": RandZoomd's default
            # ("area") would average class indices into non-integer values.
            RandZoomd(
                keys=["image", "label"],
                min_zoom=0.9,
                max_zoom=1.1,
                prob=0.10,
                mode=("area", "nearest"),
            ),
            EnsureTyped(keys=["image", "label"]),
        ]
    )

    val_transforms = Compose(_preprocessing() + [EnsureTyped(keys=["image", "label"])])

    return train_transforms, val_transforms

def load_data(train_files, val_files, train_transforms, val_transforms, cache_rate=0.1, num_workers=4):
    """Wrap the file lists in cached datasets and return train/val DataLoaders.

    Args:
        train_files, val_files: Lists of ``{"image": path, "label": path}`` dicts.
        train_transforms, val_transforms: Transform pipelines applied per item.
        cache_rate: Fraction of each dataset that ``CacheDataset`` pre-computes
            and keeps in memory.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, val_loader)``, both with ``batch_size=1``; only the
        training loader shuffles.
    """
    # A multi-sample transform such as RandCropByPosNegLabeld(num_samples=4)
    # makes each dataset item a *list* of dicts. torch's default collate would
    # turn that into a list of dicts, so batch["image"] would fail.
    # list_data_collate flattens the samples into one batched dict instead.
    from monai.data import list_data_collate

    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=cache_rate)
    val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=cache_rate)

    train_loader = DataLoader(
        train_ds, batch_size=1, shuffle=True, num_workers=num_workers, collate_fn=list_data_collate
    )
    val_loader = DataLoader(
        val_ds, batch_size=1, shuffle=False, num_workers=num_workers, collate_fn=list_data_collate
    )

    return train_loader, val_loader


Model Setup


In [None]:
# model_setup.py
from imports import *

def initialize_model():
    """Create the 3-D V-Net segmentation network on the shared ``device``.

    The network takes single-channel volumes and produces 5 output channels
    (4 organs + 1 background).

    Returns:
        The ``VNet`` module, already moved to ``device``.
    """
    network = VNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=5,  # 4 organs + 1 background
    )
    return network.to(device)

def get_loss_function_and_optimizer(model, lr=1e-4):
    """Return the Dice loss and an Adam optimizer for ``model``.

    The loss one-hot encodes the integer label map (``to_onehot_y=True``) and
    applies softmax to the network output before computing Dice. It ignores
    the background channel.

    Args:
        model: The network whose parameters the optimizer updates.
        lr: Adam learning rate (defaults to the previously hard-coded 1e-4).

    Returns:
        ``(loss_function, optimizer)``.
    """
    loss_function = DiceLoss(include_background=False, to_onehot_y=True, softmax=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return loss_function, optimizer


Training and Validation


In [None]:
# training_validation.py
from imports import *
from data_preparation import prepare_data, get_transforms, load_data
from model_setup import initialize_model, get_loss_function_and_optimizer

# Foreground organs in output-channel order (channel 0 is background).
# Each entry is (checkpoint / metric key, human-readable name).
_ORGANS = (
    ("liver", "Liver"),
    ("right_kidney", "Right Kidney"),
    ("left_kidney", "Left Kidney"),
    ("spleen", "Spleen"),
)


def _train_one_epoch(model, train_loader, loss_function, optimizer):
    """Run one optimisation pass over ``train_loader`` and return the mean batch loss.

    Raises:
        ValueError: If the loader yields no batches (the mean would be undefined).
    """
    model.train()
    total_loss = 0.0
    num_steps = 0
    for batch_data in train_loader:
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        loss = loss_function(model(inputs), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_steps += 1
    if num_steps == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_steps


def _validate(model, val_loader, dice_metric):
    """Return the mean Dice per foreground class over ``val_loader`` as a list of floats.

    Each voxel is assigned its arg-max class. Predictions and labels are both
    one-hot encoded so ``dice_metric`` receives ``(B, C, *spatial)`` tensors,
    the layout DiceMetric expects. Assumes label maps arrive as
    ``(B, 1, *spatial)`` integer tensors — TODO confirm against the loader.
    """
    from monai.networks.utils import one_hot

    model.eval()
    with torch.no_grad():
        for batch_data in val_loader:
            inputs = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)
            logits = sliding_window_inference(inputs, (128, 128, 64), 4, model)
            num_classes = logits.shape[1]
            preds = one_hot(logits.argmax(dim=1, keepdim=True), num_classes=num_classes)
            targets = one_hot(labels, num_classes=num_classes)
            dice_metric(y_pred=preds, y=targets)
    scores = dice_metric.aggregate()
    dice_metric.reset()
    return scores.tolist()


def _plot_history(val_epochs, history):
    """Plot each organ's validation Dice against the epoch at which it was measured."""
    plt.figure(figsize=(12, 8))
    for key, display in _ORGANS:
        plt.plot(val_epochs, history[key], label=display)
    plt.xlabel('Epochs')
    plt.ylabel('Dice Score')
    plt.title('Dice Score over Epochs for each Organ')
    plt.legend()
    plt.show()


def train_and_validate(model, train_loader, val_loader, loss_function, optimizer, model_dir, num_epochs=100, val_interval=2):
    """Train ``model`` and validate it every ``val_interval`` epochs.

    After each validation round:
    - per-organ Dice scores are printed;
    - ``best_metric_<organ>_model.pth`` is written to ``model_dir`` whenever
      that organ's score improves on its best so far.

    At the end the validation history is plotted.

    Args:
        model: Segmentation network whose output channels are background
            followed by liver, right kidney, left kidney and spleen.
        train_loader, val_loader: Loaders yielding dicts with "image" and
            "label" entries.
        loss_function: Callable ``(logits, labels) -> scalar loss``.
        optimizer: Optimizer over ``model``'s parameters.
        model_dir: Directory for checkpoints; created if missing.
        num_epochs: Number of training epochs.
        val_interval: Validate after every ``val_interval``-th epoch.

    Raises:
        ValueError: If the training loader is empty, or the model does not
            produce one foreground channel per organ in ``_ORGANS``.
    """
    os.makedirs(model_dir, exist_ok=True)

    # One metric for all organs: "mean_batch" keeps a separate score for each
    # foreground channel instead of averaging them together.
    dice_metric = DiceMetric(include_background=False, reduction="mean_batch")

    best_metrics = {key: -1 for key, _ in _ORGANS}
    history = {key: [] for key, _ in _ORGANS}
    val_epochs = []

    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_loader, loss_function, optimizer)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval != 0:
            continue

        scores = _validate(model, val_loader, dice_metric)
        if len(scores) != len(_ORGANS):
            raise ValueError(
                f"Expected {len(_ORGANS)} foreground Dice scores, got {len(scores)}; "
                "check the model's out_channels."
            )

        for (_, display), score in zip(_ORGANS, scores):
            print(f"Validation Dice {display}: {score:.4f}")

        val_epochs.append(epoch + 1)
        for (key, _), score in zip(_ORGANS, scores):
            if score > best_metrics[key]:
                best_metrics[key] = score
                torch.save(model.state_dict(), os.path.join(model_dir, f"best_metric_{key}_model.pth"))
            history[key].append(score)

    print("Training complete.")
    _plot_history(val_epochs, history)
