##### Run cell only in GoogleColab

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


##### Installing dependencies (run cell only in GoogleColab)

In [None]:
# Install monai and torch
!pip install monai
!pip install torch

#### In this Jupyter Notebook we will display the results after training of the model

##### Importing the libraries

In [3]:
import os
from os.path import exists
from glob import glob
import torch
import numpy as np
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, TverskyLoss, DiceFocalLoss
from monai.data import Dataset, CacheDataset, DataLoader
from monai.utils import set_determinism

from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    ScaleIntensityRanged,
    RandAffined,
    RandFlipd,
    RandGaussianNoised,
    CropForegroundd,
    Orientationd,
    Resized,
    ToTensord,
    Spacingd,
    EnsureTyped,
)

from monai.data.image_reader import NibabelReader

##### Setting the path to the working directories

In [4]:
# The input paths for the prepared nifti files
nif_path = ['drive/MyDrive/data_set_group_nif/nif_files_testing/images',
            'drive/MyDrive/data_set_group_nif/nif_files_testing/labels',
            'drive/MyDrive/data_set_group_nif/nif_files_training/images',
            'drive/MyDrive/data_set_group_nif/nif_files_training/labels',]

print(nif_path[0])

drive/MyDrive/data_set_group_nif/nif_files_testing/images


##### Define the function for data preprocessing

In [5]:
def preprocess_data(data_path, batch_size=8, spatial_size=(256, 256, 16), pixdim=(1.5, 1.5, 2.0)):

    set_determinism(seed=0)

    # Create the dataset
    test_data = sorted(glob(data_path[0] + f'/*'))
    test_labels = sorted(glob(data_path[1] + f'/*'))

    train_data = sorted(glob(data_path[2] + f'/*'))
    train_labels = sorted(glob(data_path[3] + f'/*'))

    train_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_data, train_labels)]
    test_files = [{"image": image_name, "label": label_name} for image_name, label_name in zip(test_data, test_labels)]

    # Transforms for the training with data augmentation
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),  # Load the images
        EnsureChannelFirstd(keys=["image", "label"]),  # Ensure the channel is the first dimension of the image
        Spacingd(keys=["image", "label"], pixdim=pixdim, mode=("bilinear", "nearest")),  # Resample the images
        Orientationd(keys=["image", "label"], axcodes="RAS"),  # Change the orientation of the image
        ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        # Change the contrast of the image and gives the image pixels,
        # values between 0 and 1
        CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),

        RandAffined(
            keys=["image", "label"],
            prob=0.7,
            translate_range=(10, 10, 5),
            rotate_range=(0, 0, np.pi / 15),
            scale_range=(0.1, 0.1, 0.1),
            mode=("bilinear", "nearest")
        ),
        RandGaussianNoised(keys="image", prob=0.5),
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),

        Resized(keys=["image", "label"], spatial_size=spatial_size),  # Resize the image
        EnsureTyped(keys=["image", "label"]),
        ToTensord(keys=["image", "label"]),  # Convert the images to tensors
    ])

    # Transforms for the testing
    test_transforms = Compose(# Compose transforms together
        [
            LoadImaged(keys=["image", "label"]),  # Load the images
            EnsureChannelFirstd(keys=["image", "label"]),  # Ensure the channel is the first dimension of the image
            Spacingd(keys=["image", "label"], pixdim=pixdim, mode=("bilinear", "nearest")),
            # Resample the images
            Orientationd(keys=["image", "label"], axcodes="RAS"),  # Change the orientation of the image
            ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            # Change the contrast of the image and gives the image pixels,
            # values between 0 and 1
            CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),  # Crop foreground of the image
            Resized(keys=["image", "label"], spatial_size=spatial_size),  # Resize the image
            EnsureTyped(keys=["image", "label"]),
            ToTensord(keys=["image", "label"]),  # Convert the images to tensors
        ]
    )

    # Create the datasets
    train_ds = CacheDataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds, batch_size=batch_size)

    test_ds = CacheDataset(data=test_files, transform=test_transforms)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    return train_loader, test_loader

##### Preprocess the data

In [11]:
# Save the metadata of the entire training set
data_in = preprocess_data(
    nif_path,
    batch_size=2,  # start conservative
    spatial_size=(128, 128, 32),
    pixdim=(0.7871384, 0.7871384, 1.2131842)
)

Loading dataset: 100%|██████████| 748/748 [15:08<00:00,  1.21s/it]
Loading dataset: 100%|██████████| 240/240 [06:52<00:00,  1.72s/it]


##### Setting the device for training

In [6]:
# We do the training on the GPU
device = torch.device("cuda:0")
print(device)

cuda:0


##### Initialize the model

In [13]:
model = UNet(
  spatial_dims=3,
  in_channels=1,
  out_channels=2,
  channels=(16, 32, 64, 128),
  strides=(2, 2, 2),
  num_res_units=2,
  norm=Norm.BATCH,
)

device = torch.device("cpu")
model = model.to(device)

##### Initialize the loss function and the optimizer

In [8]:
loss_function = DiceFocalLoss(to_onehot_y=True, softmax=True, lambda_focal=0.5)
optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5, amsgrad=True)

In [9]:
def dice_metric(y_pred, y):
    dice_loss = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=True)
    dice_coeff = 1 - dice_loss(y_pred, y).item()
    return dice_coeff

##### Define the training loop

In [10]:
# Function for training the model
def train(model, data_in, loss_function, optimizer, max_epochs, model_dir, test_interval=1,
          device=torch.device('cuda:0')):
    best_metric = -1
    best_metric_epoch = -1
    save_loss_train = []
    save_loss_test = []
    save_metric_train = []
    save_metric_test = []
    train_loader, test_loader = data_in

    for epoch in range(max_epochs):
        model.train()
        train_epoch_loss = 0
        train_step = 0
        train_epoch_metric = 0

        for batch_data in train_loader:
            train_step += 1
            volumes = batch_data["image"]
            labels = batch_data["label"]
            labels = labels != 0
            volumes = volumes.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(volumes)

            train_loss = loss_function(outputs, labels)

            train_loss.backward()
            optimizer.step()

            train_epoch_loss += train_loss.item()
            train_metric = dice_metric(outputs, labels)
            train_epoch_metric += train_metric

            print(
                f"{epoch + 1}/{max_epochs} and {train_step}/{len(train_loader)} => train_loss: {train_loss.item():.4f} and train_metric: {train_metric:.4f}")

        print('Saving training data after epoch: ' + str(epoch + 1))
        train_epoch_loss /= train_step
        print(f"epoch {epoch + 1} average training loss: {train_epoch_loss:.4f}")
        save_loss_train.append(train_epoch_loss)
        np.save(os.path.join(model_dir, 'train_loss.npy'), save_loss_train)

        train_epoch_metric /= train_step
        print(f"epoch {epoch + 1} average training metric: {train_epoch_metric:.4f}")
        save_metric_train.append(train_epoch_metric)
        np.save(os.path.join(model_dir, 'train_metric.npy'), save_metric_train)

        if (epoch + 1) % test_interval == 0:
            model.eval()
            with torch.no_grad():
                test_epoch_loss = 0
                test_metric = 0
                test_step = 0
                test_epoch_metric = 0

                for test_data in test_loader:
                    test_step += 1
                    volumes = test_data["image"]
                    labels = test_data["label"]
                    labels = labels != 0
                    volumes = volumes.to(device)
                    labels = labels.to(device)

                    outputs = model(volumes)

                    test_loss = loss_function(outputs, labels)

                    test_epoch_loss += test_loss.item()
                    test_metric = dice_metric(outputs, labels)
                    test_epoch_metric += test_metric

                    print(
                        f"{epoch + 1}/{max_epochs} and {test_step}/{len(test_loader)} => test_loss: {test_loss.item():.4f} and test_metric: {test_metric:.4f}")

                print('Saving testing data after epoch: ' + str(epoch + 1))
                test_epoch_loss /= test_step
                print(f"epoch {epoch + 1} average testing loss: {test_epoch_loss:.4f}")
                save_loss_test.append(test_epoch_loss)
                np.save(os.path.join(model_dir, 'test_loss.npy'), save_loss_test)

                test_epoch_metric /= test_step
                print(f"epoch {epoch + 1} average testing metric: {test_epoch_metric:.4f}")
                save_metric_test.append(test_epoch_metric)
                np.save(os.path.join(model_dir, 'test_metric.npy'), save_metric_test)

                if test_epoch_metric > best_metric:
                    best_metric = test_epoch_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), os.path.join(model_dir, "best_metric_model.pth"))

                print(f"current epoch: {epoch + 1} current test Dice coefficient: {test_epoch_metric:.4f}"
                      f"\nbest metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

    print(f"train completed => best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

##### Train the model

In [None]:
model_dir = 'drive/MyDrive/trained_models/post_training_UNet_128_128_32'
os.makedirs(model_dir, exist_ok=True)

train(model=model,
      data_in=data_in,
      loss_function=loss_function,
      optimizer=optimizer,
      max_epochs=100,
      model_dir=model_dir,
      test_interval=4,
      device=device
)


1/100 and 1/374 => train_loss: 0.7062 and train_metric: 0.5059
1/100 and 2/374 => train_loss: 0.7384 and train_metric: 0.4475
1/100 and 3/374 => train_loss: 0.7131 and train_metric: 0.4967
1/100 and 4/374 => train_loss: 0.6887 and train_metric: 0.5411
1/100 and 5/374 => train_loss: 0.6811 and train_metric: 0.5541
1/100 and 6/374 => train_loss: 0.6609 and train_metric: 0.5888
1/100 and 7/374 => train_loss: 0.6926 and train_metric: 0.5327
1/100 and 8/374 => train_loss: 0.7378 and train_metric: 0.4488
1/100 and 9/374 => train_loss: 0.7058 and train_metric: 0.5099
1/100 and 10/374 => train_loss: 0.6802 and train_metric: 0.5546
1/100 and 11/374 => train_loss: 0.6542 and train_metric: 0.6029
1/100 and 12/374 => train_loss: 0.6763 and train_metric: 0.5613
1/100 and 13/374 => train_loss: 0.7424 and train_metric: 0.4401
1/100 and 14/374 => train_loss: 0.7315 and train_metric: 0.4616
1/100 and 15/374 => train_loss: 0.7044 and train_metric: 0.5124
1/100 and 16/374 => train_loss: 0.6752 and train_