##### Run this cell only in Google Colab

In [1]:
# Mount Google Drive inside the Colab VM. Later cells use relative paths such as
# drive/MyDrive/..., which assume the working directory is /content.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


##### Install dependencies (run this cell only in Google Colab)

In [None]:
# Install a specific version of numpy first
!pip install numpy==1.23.5 # You might need to adjust the numpy version based on your MONAI version

# Install monai and torch
!pip install monai
!pip install torch

#### In this Jupyter notebook we train the model and then display the results of the training

##### Importing the libraries

In [26]:
import os
from os.path import exists
from glob import glob
import torch
import numpy as np
from torch.amp import autocast, GradScaler
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, TverskyLoss
from monai.metrics import DiceMetric
from monai.data import Dataset, CacheDataset,DataLoader
from monai.utils import set_determinism
from monai.networks.utils import one_hot

from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityRanged,
    RandAffined,
    RandRotated,
    RandGaussianNoised,
    CropForegroundd,
    Orientationd,
    Resized,
    ToTensord,
    Spacingd,
)

##### Setting the paths to the input data directories

In [3]:
# Root folder (relative to the Colab working directory) holding the prepared NIfTI files
data_root = 'drive/MyDrive/data_set_group_nif'

# Order matters: preprocess_data reads [test images, test labels, train images, train labels]
nif_path = [f'{data_root}/nif_files_{split}/{kind}'
            for split in ('testing', 'training')
            for kind in ('images', 'labels')]

print(nif_path[0])

drive/MyDrive/data_set_group_nif/nif_files_testing/images


##### Define the function for data preprocessing

In [4]:
def _paired_files(image_dir, label_dir):
    """Pair the sorted files of `image_dir` and `label_dir` by position.

    Returns a list of {"image": path, "label": path} dicts.

    Raises:
        ValueError: if the two folders hold a different number of files; zip()
            would otherwise silently drop the surplus and may misalign the pairs.
    """
    images = sorted(glob(os.path.join(image_dir, '*')))
    labels = sorted(glob(os.path.join(label_dir, '*')))
    if len(images) != len(labels):
        raise ValueError(
            f"cannot pair {len(images)} images in {image_dir!r} "
            f"with {len(labels)} labels in {label_dir!r}"
        )
    return [{"image": image_name, "label": label_name}
            for image_name, label_name in zip(images, labels)]


def preprocess_data(data_path, batch_size=1, spatial_size=(256, 256, 16)):
    """Build the training and testing DataLoaders from the prepared NIfTI folders.

    Args:
        data_path: four folders, in this order: testing images, testing labels,
            training images, training labels.
        batch_size: number of volumes per batch for both loaders.
        spatial_size: final size every image/label pair is resized to.

    Returns:
        (train_loader, test_loader). The training loader shuffles and applies
        random augmentation; the testing loader does neither.

    Raises:
        ValueError: if an image folder and its label folder contain a different
            number of files.
    """
    # Fixes the random seeds so augmentation and shuffling are reproducible.
    set_determinism(seed=0)

    test_files = _paired_files(data_path[0], data_path[1])
    train_files = _paired_files(data_path[2], data_path[3])

    def load_and_normalize():
        """Return a fresh list of the deterministic transforms shared by both splits."""
        return [
            LoadImaged(keys=["image", "label"]),
            EnsureChannelFirstd(keys=["image", "label"]),  # channel becomes the first dimension
            # Resample to 1.5 x 1.5 x 2.0 spacing; "nearest" keeps label values integral.
            Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
            Orientationd(keys=["image", "label"], axcodes="RAS"),
            # Clip intensities to [-200, 200] (presumably HU) and rescale linearly to [0, 1].
            ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=["image", "label"], source_key="image"),
        ]

    # Labels are class-index maps, so every geometric transform below resamples
    # them with "nearest"; the default bilinear/area interpolation would create
    # fractional, invalid class values.
    train_transforms = Compose(load_and_normalize() + [
        # Random shift; translate_range is in voxels.
        RandAffined(keys=["image", "label"], prob=0.5, translate_range=10,
                    mode=("bilinear", "nearest")),
        # range_x is in radians: rotate by up to +/-10 degrees.
        RandRotated(keys=["image", "label"], prob=0.5, range_x=np.deg2rad(10.0),
                    mode=("bilinear", "nearest")),
        RandGaussianNoised(keys="image", prob=0.5),
        Resized(keys=["image", "label"], spatial_size=spatial_size, mode=("area", "nearest")),
        ToTensord(keys=["image", "label"]),
    ])

    test_transforms = Compose(load_and_normalize() + [
        Resized(keys=["image", "label"], spatial_size=spatial_size, mode=("area", "nearest")),
        ToTensord(keys=["image", "label"]),
    ])

    # CacheDataset caches only the deterministic prefix of the pipeline, so the
    # random augmentations are still re-drawn every epoch.
    train_ds = CacheDataset(data=train_files, transform=train_transforms)
    # Shuffle so every epoch visits the training volumes in a different order.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

    test_ds = CacheDataset(data=test_files, transform=test_transforms)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    return train_loader, test_loader

##### Preprocess the data

In [5]:
# Build the (train_loader, test_loader) pair: 32 volumes per batch, each resized to 256 x 256 x 16
data_in = preprocess_data(
    nif_path,
    batch_size=32,
    spatial_size=(256, 256, 16),
)

Loading dataset: 100%|██████████| 748/748 [07:02<00:00,  1.77it/s]
Loading dataset: 100%|██████████| 240/240 [02:31<00:00,  1.58it/s]


##### Setting the device for training

In [16]:
# Train on the first GPU when one is available; otherwise fall back to the CPU
# instead of failing later when the model and tensors are moved to "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


##### Initialize the model

In [17]:
# 3D U-Net: one input channel (the image volume) and three output channels,
# one per class (background, liver, tumor).
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)

# Use the `device` selected in the device-setup cell. Do not redefine it here:
# train() is later called with this same `device`, so overriding it would
# silently move the whole training to another device.
model = model.to(device)

##### Initialize the loss function and the optimizer

In [18]:
# Tversky index = TP / (TP + alpha * FP + beta * FN). In MONAI, alpha weights
# false positives and beta weights false negatives, so beta > alpha penalises
# false negatives (missed tumor voxels) more than false positives.
loss_function = TverskyLoss(
    to_onehot_y=True,   # labels hold class indices; one-hot encode them to match the outputs
    softmax=True,       # the network outputs unnormalised scores (train() applies softmax itself)
    alpha=0.3,          # weight of false positives
    beta=0.7,           # weight of false negatives (missing tumors)
    include_background=True,
)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5, amsgrad=True)

##### Define the training loop

In [48]:
def _run_epoch(model, loader, num_classes, loss_function, dice_metric, device,
               optimizer=None, scaler=None):
    """Run one pass over `loader` and return (mean loss per batch, per-class Dice).

    When `optimizer` is given, the model is switched to training mode and updated
    through `scaler` after every batch. Otherwise it is switched to eval mode and
    run without gradients. Per-step loss and Dice are printed as the pass runs.

    The returned Dice is a numpy array with one value per class scored by
    `dice_metric`, averaged over every sample of the pass. Samples whose ground
    truth lacks a class (NaN in MONAI's DiceMetric) are left out of that class's
    average instead of turning it into NaN.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    # DiceMetric appends every result to an internal buffer; start each pass empty.
    dice_metric.reset()
    loss_total = 0.0
    steps = 0

    with torch.set_grad_enabled(training):
        for step, batch_data in enumerate(loader):
            volumes = batch_data["image"].to(device)
            labels = batch_data["label"].to(device)

            if training:
                optimizer.zero_grad()
            with autocast(device_type=device.type):
                outputs = model(volumes)
                loss = loss_function(outputs, labels)
            if training:
                # With a disabled scaler (non-CUDA) these calls reduce to
                # loss.backward() and optimizer.step().
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            with torch.no_grad():
                # Dice is defined on binary masks: take the argmax class per voxel
                # and one-hot encode it, rather than scoring softmax probabilities.
                preds = one_hot(torch.argmax(outputs, dim=1, keepdim=True), num_classes=num_classes)
                labels_onehot = one_hot(labels, num_classes=num_classes)
                batch_dice = dice_metric(y_pred=preds, y=labels_onehot)  # per sample and class; also buffered
                # Ignore NaNs (class absent from a sample's ground truth) for the step display.
                step_dice = torch.nanmean(batch_dice, dim=0)

            loss_total += loss.item()
            steps += 1

            print(f"Step {step+1}/{len(loader)} => "
                  f"Loss: {loss.item():.4f} | "
                  f"Liver Dice: {step_dice[0].item():.4f} | "
                  f"Tumor Dice: {step_dice[1].item():.4f}")

    if steps == 0:
        raise ValueError("the data loader yielded no batches")

    epoch_dice = dice_metric.aggregate().cpu().numpy()
    dice_metric.reset()
    return loss_total / steps, epoch_dice


def train(
    model,
    data_in,
    num_classes,
    loss_function,
    optimizer,
    max_epochs,
    model_dir,
    test_interval=1,
    device=torch.device('cuda:0')
):
    """Train `model`, evaluate it every `test_interval` epochs and checkpoint the best one.

    Args:
        model: segmentation network producing `num_classes` output channels.
        data_in: (train_loader, test_loader) pair, e.g. from preprocess_data.
        num_classes: number of classes including background. It must be at least 3,
            because the printed summaries report foreground classes 1 and 2 as
            liver and tumor.
        loss_function: called as loss_function(outputs, labels).
        optimizer: optimizer over the model's parameters.
        max_epochs: number of training epochs.
        model_dir: existing folder that receives train_loss.npy, train_metric.npy,
            test_loss.npy, test_metric.npy and the best-model checkpoints.
        test_interval: evaluate after every `test_interval`-th epoch.
        device: torch.device (or device string) to run on. Gradient scaling is
            only enabled on CUDA.
    """
    train_loader, test_loader = data_in
    device = torch.device(device)

    # reduction="mean_batch": aggregate() averages each class over all buffered
    # samples, skipping the NaNs produced for samples with an empty ground truth.
    dice_metric = DiceMetric(include_background=False, reduction="mean_batch", get_not_nans=False)
    scaler = GradScaler(device=device.type, enabled=device.type == 'cuda')

    # Tracking
    best_metric = -1
    best_metric_epoch = -1
    save_loss_train, save_loss_test = [], []
    save_metric_train, save_metric_test = [], []

    for epoch in range(max_epochs):
        print(f"\n--- Epoch {epoch + 1}/{max_epochs} ---")

        # ---------- TRAINING ----------
        epoch_loss, epoch_dice = _run_epoch(
            model, train_loader, num_classes, loss_function, dice_metric, device,
            optimizer=optimizer, scaler=scaler,
        )
        avg_dice = epoch_dice.mean()

        save_loss_train.append(epoch_loss)
        save_metric_train.append(avg_dice)
        np.save(os.path.join(model_dir, 'train_loss.npy'), save_loss_train)
        np.save(os.path.join(model_dir, 'train_metric.npy'), save_metric_train)

        print(f"✅ Epoch {epoch+1} Train Avg Loss: {epoch_loss:.4f} | "
              f"Liver Dice: {epoch_dice[0]:.4f}, Tumor Dice: {epoch_dice[1]:.4f}, "
              f"Avg Dice: {avg_dice:.4f}")

        # ---------- TESTING ----------
        if (epoch + 1) % test_interval == 0:
            epoch_test_loss, epoch_test_dice = _run_epoch(
                model, test_loader, num_classes, loss_function, dice_metric, device,
            )
            avg_test_dice = epoch_test_dice.mean()

            save_loss_test.append(epoch_test_loss)
            save_metric_test.append(avg_test_dice)
            np.save(os.path.join(model_dir, 'test_loss.npy'), save_loss_test)
            np.save(os.path.join(model_dir, 'test_metric.npy'), save_metric_test)

            print(f"🔍 Epoch {epoch+1} Test Avg Loss: {epoch_test_loss:.4f} | "
                  f"Liver Dice: {epoch_test_dice[0]:.4f}, Tumor Dice: {epoch_test_dice[1]:.4f}, "
                  f"Avg Dice: {avg_test_dice:.4f}")

            if avg_test_dice > best_metric:
                best_metric = avg_test_dice
                best_metric_epoch = epoch + 1
                model_path = os.path.join(model_dir, f"best_model_epoch{epoch+1}_dice{best_metric:.4f}.pth")
                torch.save(model.state_dict(), model_path)
                print(f"💾 Best model saved at epoch {epoch+1} with Avg Dice {best_metric:.4f}")

    print(f"\n🏁 Training complete. Best Avg Dice: {best_metric:.4f} at epoch {best_metric_epoch}")


##### Train the model

In [None]:
# Checkpoints and the train/test loss and metric histories are written to this folder
model_dir = 'drive/MyDrive/trained_models/post_training_Unet'
os.makedirs(model_dir, exist_ok=True)  # no error if the folder already exists

NUM_CLASSES = 3     # background + liver + tumor
MAX_EPOCHS = 10
TEST_INTERVAL = 1   # evaluate on the test split after every epoch

train(
    model=model,
    data_in=data_in,
    num_classes=NUM_CLASSES,
    loss_function=loss_function,
    optimizer=optimizer,
    max_epochs=MAX_EPOCHS,
    model_dir=model_dir,
    test_interval=TEST_INTERVAL,
    device=device,
)
