<a href="https://colab.research.google.com/github/Riky2014/NAPDE/blob/main/cervelli.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install MONAI

In [None]:
%%capture
# Install / upgrade the weekly MONAI build together with the listed optional
# extras; the %%capture cell magic suppresses the (long) pip output.
!pip install -U "monai-weekly[fire, nibabel, yaml, tqdm, einops]"

# Imports

In [None]:
import os
import time
import torch
import tempfile
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

import monai
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.config import print_config
from monai.utils import set_determinism
from monai.networks.nets import SegResNet
from monai.inferers import sliding_window_inference
from monai.data import DataLoader, decollate_batch

# Free memory before starting (presumably leftovers from a previous run of the
# notebook): collect unreachable Python objects, then release the memory that
# PyTorch's CUDA caching allocator holds but no longer uses.
import gc
gc.collect()
torch.cuda.empty_cache()

from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
    LoadImaged,
    NormalizeIntensityd,
    Orientationd,
    RandFlipd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    RandSpatialCropd,
    Spacingd,
    EnsureTyped,
    EnsureChannelFirstd,
)

# Print configuration
This cell is optional: it only prints the versions of MONAI and of its dependencies installed in the current environment.

In [None]:
# Optional: print MONAI / PyTorch / dependency version information.
print_config()

# Set the directory
Here we set the working directory.

In [None]:
# Working directory of the notebook; it is created if missing and exported
# through the MONAI_DATA_DIRECTORY environment variable.
directory_path = '/kaggle/working/cervelli'
os.makedirs(directory_path, exist_ok = True)
os.environ["MONAI_DATA_DIRECTORY"] = directory_path
# The variable was set on the previous line, so reading it back can never
# return None: the former `tempfile.mkdtemp()` fallback was unreachable.
directory = directory_path
root_dir = directory

# Auxiliary functions
Run all the cells in this section first: they define the helper functions used in the rest of the notebook.

In [None]:
def create_dataset(n_train, n_test, data_dir="/kaggle/input/ciaooo/Immagini_cervelli_giuste"):
  """Build MONAI-style train/validation file lists from a folder of NIfTI files.

  Images are matched with ``image*.nii`` and labels with ``label*.nii``. Both
  lists are sorted, so the k-th image is paired with the k-th label. The first
  ``n_train`` pairs form the training set and the last ``n_test`` pairs form the
  validation set. The two overlap if ``n_train + n_test`` exceeds the number of
  pairs. The MONAI random seed is fixed to 0 as a side effect.

  Args:
    n_train: number of (image, label) pairs used for training.
    n_test: number of (image, label) pairs used for validation; 0 yields an
      empty validation list.
    data_dir: folder containing the dataset. The default is the original
      Kaggle path; change it according to where the dataset is stored.

  Returns:
    (train_files, val_files): lists of {"image": path, "label": path} dicts.

  Raises:
    ValueError: if the number of image files differs from the number of label
      files, since the positional pairing would then be wrong.
  """
  # Kept for backward compatibility: after the call the variable points at the
  # dataset folder, as it did originally.
  os.environ["MONAI_DATA_DIRECTORY"] = data_dir
  root_dir = data_dir

  set_determinism(seed = 0)

  images = sorted(glob(os.path.join(root_dir, "image*.nii")))
  labels = sorted(glob(os.path.join(root_dir, "label*.nii")))
  if len(images) != len(labels):
    raise ValueError(
        f"found {len(images)} image files but {len(labels)} label files in {root_dir}")

  pairs = list(zip(images, labels))

  # `pairs[-n_test:]` would select *every* pair when n_test == 0 (since -0 == 0),
  # so compute the start index of the last n_test pairs explicitly.
  val_start = max(len(pairs) - n_test, 0)

  train_files = [{"image": image, "label": label} for image, label in pairs[:n_train]]
  val_files = [{"image": image, "label": label} for image, label in pairs[val_start:]]

  return train_files, val_files

In [None]:
def _preprocessing_transforms():
  """Return fresh instances of the deterministic transforms shared by both pipelines."""
  return [
      LoadImaged(keys=["image", "label"]),
      EnsureChannelFirstd(keys=["image", "label"]),
      EnsureTyped(keys=["image", "label"]),
      Orientationd(keys=["image", "label"], axcodes="RAS"),
      # Resample to 1 mm isotropic voxels; nearest-neighbour interpolation
      # keeps the label values discrete.
      Spacingd(
          keys=["image", "label"],
          pixdim=(1.0, 1.0, 1.0),
          mode=("bilinear", "nearest"),
      ),
  ]


def transform(train_files, val_files, roi_size=(240, 240, 160), batch_size=1, num_workers=2):
  """Build the training / validation datasets and their data loaders.

  Both pipelines load the NIfTI files, move channels first, reorient to RAS,
  resample to 1 mm spacing and normalise the image intensities over non-zero
  voxels. The training pipeline additionally does the following:
    - takes a random crop of size ``roi_size``;
    - applies random flips along each spatial axis;
    - applies random intensity scaling and shifting.

  Args:
    train_files: list of {"image": path, "label": path} dicts for training.
    val_files: list of {"image": path, "label": path} dicts for validation.
    roi_size: spatial size of the random training crop.
    batch_size: batch size of both loaders.
    num_workers: number of worker processes of both loaders.

  Returns:
    (train_loader, val_loader, val_ds). ``val_ds`` is also returned so that
    single validation samples can be inspected.
  """
  train_transform = Compose(_preprocessing_transforms() + [
      RandSpatialCropd(keys=["image", "label"], roi_size=roi_size, random_size=False),
      # Independent random flip along each of the three spatial axes.
      *[RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=axis) for axis in range(3)],
      NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
      RandScaleIntensityd(keys="image", factors=0.1, prob=1.0),
      RandShiftIntensityd(keys="image", offsets=0.1, prob=1.0),
  ])

  val_transform = Compose(_preprocessing_transforms() + [
      NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
  ])

  train_ds = monai.data.Dataset(data=train_files, transform=train_transform)
  train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)

  val_ds = monai.data.Dataset(data=val_files, transform=val_transform)
  val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

  return train_loader, val_loader, val_ds

In [None]:
def inference(input, model, VAL_AMP, roi_size=(240, 240, 160), sw_batch_size=1, overlap=0.5):
    """Predict the model output for ``input`` with MONAI sliding-window inference.

    The volume is processed in overlapping windows of size ``roi_size``. The
    window predictions are then stitched back into a full-size output.

    Args:
        input: batched input tensor (batch, channel, *spatial). The name is kept
            for backward compatibility even though it shadows the builtin.
        model: predictor applied to each window.
        VAL_AMP: if True, run under CUDA autocast (float16 by default).
        roi_size: spatial size of each window.
        sw_batch_size: number of windows passed to ``model`` at once.
        overlap: fraction of overlap between neighbouring windows.

    Returns:
        The stitched model output for ``input``.
    """
    def _run(x):
        return sliding_window_inference(
            inputs=x,
            roi_size=roi_size,
            sw_batch_size=sw_batch_size,
            predictor=model,
            overlap=overlap,
        )

    if not VAL_AMP:
        return _run(input)
    # torch.autocast("cuda") is the non-deprecated equivalent of
    # torch.cuda.amp.autocast().
    with torch.autocast(device_type="cuda"):
        return _run(input)

In [None]:
def _train_one_epoch(model, train_loader, device, loss_function, optimizer, scaler,
                     dice_metric_train, post_trans):
  """Run one mixed-precision training pass over ``train_loader``.

  Returns:
    (mean batch loss, aggregated training Dice) for the epoch.

  Raises:
    ValueError: if ``train_loader`` yields no batches.
  """
  model.train()
  epoch_loss = 0.0
  step = 0
  for batch_data in train_loader:
    step += 1
    inputs = batch_data["image"].to(device)
    labels = batch_data["label"].to(device)
    optimizer.zero_grad()

    # Forward pass under autocast; the GradScaler rescales the loss so that
    # low-precision gradients do not underflow.
    with torch.autocast(device_type="cuda"):
      outputs = model(inputs)
      loss = loss_function(outputs, labels)

    # The Dice metric only needs the binarised prediction, so detach it from
    # the autograd graph before post-processing.
    predictions = [post_trans(i) for i in decollate_batch(outputs.detach())]
    dice_metric_train(y_pred=predictions, y=labels)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    epoch_loss += loss.item()

  if step == 0:
    raise ValueError("train_loader yielded no batches; check n_train and the dataset path")

  metric_train = dice_metric_train.aggregate().item()
  dice_metric_train.reset()
  return epoch_loss / step, metric_train


def _validate(model, val_loader, device, VAL_AMP, dice_metric, post_trans):
  """Return the mean Dice of ``model`` over ``val_loader``, or None if it is empty."""
  model.eval()
  n_batches = 0
  with torch.no_grad():
    for val_data in val_loader:
      n_batches += 1
      val_inputs = val_data["image"].to(device)
      val_labels = val_data["label"].to(device)
      val_outputs = inference(val_inputs, model, VAL_AMP)
      val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
      dice_metric(y_pred=val_outputs, y=val_labels)
  if n_batches == 0:
    return None
  metric = dice_metric.aggregate().item()
  dice_metric.reset()
  return metric


def model_and_train(n_train, n_test, train_loader, val_loader, max_epochs, val_ds, filters = 16,
                    val_interval = 1):
  """Build a single-channel SegResNet and train it with a sigmoid Dice loss.

  Training uses Adam with cosine learning-rate annealing and mixed precision.
  The model is validated every ``val_interval`` epochs with sliding-window
  inference.

  Args:
    n_train: size of the training set (only printed).
    n_test: size of the test set (only printed).
    train_loader: loader yielding dicts with "image" and "label" batches.
    val_loader: loader yielding dicts with "image" and "label" batches.
    max_epochs: number of training epochs.
    val_ds: unused; kept for backward compatibility with existing calls.
    filters: number of initial filters of the SegResNet.
    val_interval: run validation every ``val_interval`` epochs.

  Returns:
    (metric_values_train, metric_values, epoch_loss_values, model, device, VAL_AMP):
      - per-epoch training Dice;
      - test Dice, one value per validation run;
      - per-epoch mean loss;
      - the trained model;
      - the torch device used;
      - whether autocast was used during validation.

  Raises:
    ValueError: if ``train_loader`` yields no batches.
  """
  print("Start training")
  print(f"Number of training images = {n_train}")
  print(f"Number of testing images = {n_test}")
  print()

  VAL_AMP = True

  # Fall back to the CPU instead of failing when no GPU is available.
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  model = SegResNet(
      blocks_down=[1, 2, 2, 4],
      blocks_up=[1, 1, 1],
      init_filters=filters,
      in_channels=1,
      out_channels=1,
      dropout_prob=0.2,
  ).to(device)

  loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True,
                           to_onehot_y=False, sigmoid=True)
  optimizer = torch.optim.Adam(model.parameters(), 1e-4, weight_decay=1e-5)
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

  # NOTE: with a single output channel there is no separate background
  # channel, and MONAI ignores include_background=False (it warns about it).
  dice_metric = DiceMetric(include_background=False, reduction="mean")
  dice_metric_train = DiceMetric(include_background=False, reduction="mean")

  # Logits -> probabilities -> binary mask.
  post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

  # Loss scaling is only meaningful on the GPU; disabling it explicitly on the
  # CPU avoids a warning and makes scaler.step() a plain optimizer.step().
  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))
  torch.backends.cudnn.benchmark = True

  epoch_loss_values = []
  metric_values = []
  metric_values_train = []

  total_start = time.time()

  for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    epoch_loss, metric_train = _train_one_epoch(
        model, train_loader, device, loss_function, optimizer, scaler,
        dice_metric_train, post_trans)
    metric_values_train.append(metric_train)
    lr_scheduler.step()
    epoch_loss_values.append(epoch_loss)
    print(f"Loss: {epoch_loss:.4f} \nTrain dice: {metric_train:.4f}")

    if (epoch + 1) % val_interval == 0:
      metric = _validate(model, val_loader, device, VAL_AMP, dice_metric, post_trans)
      if metric is None:
        print("Test dice: n/a (empty validation set)")
      else:
        metric_values.append(metric)
        print(f"Test dice: {metric:.4f}")

    print(f"Time: {(time.time() - epoch_start):.4f}")

  total_time = time.time() - total_start
  print(f"Train completed, total time: {total_time}.")
  print()
  # Guard against empty histories (max_epochs == 0, or no validation run).
  last_train = metric_values_train[-1] if metric_values_train else None
  last_test = metric_values[-1] if metric_values else None
  print(f"Train metric = {last_train}, Test metric = {last_test}")
  print()
  print()

  return metric_values_train, metric_values, epoch_loss_values, model, device, VAL_AMP

In [None]:
def plot_patient(plot_index, z_coordinate, val_ds):
  """Show one slice of a dataset sample: input image (left) and label (right).

  The shapes of the image and label are printed first. Both are 4-D tensors
  indexed as [channel, :, :, z]; channel 0 is displayed.

  Args:
    plot_index: index of the sample in ``val_ds``.
    z_coordinate: slice index along the last axis.
    val_ds: indexable dataset whose items are dicts with "image" and "label".
  """
  sample = val_ds[plot_index]
  plt.figure("patient", (12, 6))

  # (dictionary key, title suffix, extra imshow arguments) for each panel.
  panels = (
      ("image", "input image", {"cmap": "gray"}),
      ("label", "label", {}),
  )
  for position, (key, caption, imshow_kwargs) in enumerate(panels, start=1):
    volume = sample[key]
    print(f"{key} shape: {volume.shape}")
    plt.subplot(1, 2, position)
    plt.title(f"patient {plot_index}, {caption}")
    plt.imshow(volume[0, :, :, z_coordinate].detach().cpu(), **imshow_kwargs)

  plt.show()

In [None]:
def plot_metrics(epoch_loss_values, metric_values, metric_values_train, val_interval, n_train, max_epochs,
                 output_dir='/kaggle/working/output_plots'):
  """Plot the loss and the test / train Dice curves side by side and save them as a PNG.

  Args:
    epoch_loss_values: mean training loss, one value per epoch.
    metric_values: test Dice, one value every ``val_interval`` epochs.
    metric_values_train: training Dice, expected to hold one value per epoch
      (as recorded by the training loop of this notebook).
    val_interval: number of epochs between two validation runs.
    n_train: only used to build the output file name.
    max_epochs: only used to build the output file name.
    output_dir: folder where the figure is saved (created if missing).
  """
  os.makedirs(output_dir, exist_ok=True)
  plot_path = os.path.join(output_dir,
                           f"loss and metrics, n_train = {n_train}, epochs = {max_epochs}.png")

  # (title, x values, y values, colour) for each panel. The loss and the train
  # Dice are recorded every epoch, while the test Dice is recorded only every
  # `val_interval` epochs; previously the train Dice was wrongly stretched by
  # val_interval too.
  panels = (
      ("Loss", [i + 1 for i in range(len(epoch_loss_values))], epoch_loss_values, "red"),
      ("Test metric", [val_interval * (i + 1) for i in range(len(metric_values))], metric_values, "green"),
      ("Train metric", [i + 1 for i in range(len(metric_values_train))], metric_values_train, "blue"),
  )

  fig = plt.figure("train", (18, 6))
  for position, (title, x, y, color) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(title)
    plt.xlabel("epoch")
    plt.plot(x, y, color=color)

  plt.savefig(plot_path)
  plt.show()
  # Close the named figure so that a later call starts from a blank canvas
  # instead of drawing over the previous curves.
  plt.close(fig)

In [None]:
def plot_test_slices(model, val_ds, VAL_AMP, patient_indices, z_axis_values):
  """Plot input image, label and predicted mask for selected test slices.

  For each pair (i, j) taken from ``patient_indices`` and ``z_axis_values``,
  sample ``val_ds[i]`` is segmented with ``model``. Slice ``j`` along the last
  axis is then shown in three panels.

  Args:
    model: trained network. It is put in eval mode while plotting, and its
      previous mode is restored afterwards.
    val_ds: indexable dataset whose items are dicts with "image" and "label"
      tensors indexed as [channel, :, :, z].
    VAL_AMP: whether inference runs under CUDA autocast.
    patient_indices: indices into ``val_ds``.
    z_axis_values: slice index for the corresponding entry of ``patient_indices``.

  Raises:
    ValueError: if the two sequences have different lengths.
  """
  if len(patient_indices) != len(z_axis_values):
    raise ValueError(
        "patient_indices and z_axis_values must have the same length "
        f"({len(patient_indices)} != {len(z_axis_values)})")

  post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
  # Run on the model's own device instead of relying on a global `device`.
  model_device = next(model.parameters()).device
  was_training = model.training
  model.eval()  # disable dropout for inference; restored in `finally`

  try:
    with torch.no_grad():
      for i, j in zip(patient_indices, z_axis_values):
        # Fetch each sample once: a monai.data.Dataset re-runs its whole
        # transform pipeline (file loading, resampling) on every indexing.
        sample = val_ds[i]
        image = sample["image"]
        label = sample["label"]

        val_output = inference(image.unsqueeze(0).to(model_device), model, VAL_AMP)
        val_output = post_trans(val_output[0])

        plt.figure("current patient", (18, 6))
        panels = (
            (image, f"input image, patient {i}, z = {j}", {"cmap": "gray"}),
            (label, f"label, patient {i}, z = {j}", {}),
            (val_output, f"output, patient {i}, z = {j}", {}),
        )
        for position, (volume, title, imshow_kwargs) in enumerate(panels, start=1):
          plt.subplot(1, 3, position)
          plt.title(title)
          plt.imshow(volume[0, :, :, j].detach().cpu(), **imshow_kwargs)

        plt.show()
  finally:
    model.train(was_training)

In [None]:
def _scatter_nonzero_voxels(ax, volume, title):
  """Scatter the voxels of a 3-D ``volume`` that are non-zero after a uint8 cast."""
  # Array axes are taken as (x, y, z): the LAST axis is drawn vertically. This
  # is the same axis that is indexed as "z" when 2-D slices are plotted. The
  # previous `z, x, y = ...` put the first array axis on the vertical axis.
  # NOTE(review): the uint8 cast truncates float intensities, so values in
  # (-1, 1) become 0, and casting negative floats to uint8 is
  # platform-dependent. Binary masks (label, output) are unaffected.
  x, y, z = volume.astype(np.uint8).nonzero()
  ax.scatter(x, y, z)
  ax.set_xlim([0, 200])
  ax.set_ylim([0, 200])
  ax.set_zlim([0, 200])
  ax.set_title(title)


def plot_test_3d(VAL_AMP, device, model, val_ds, n_train, max_epochs,
                 output_dir='/kaggle/working/test_3d_plots'):
  """Draw and save 3-D scatter plots of every sample of ``val_ds``.

  For each sample, three panels are shown: the input image, the label and the
  binarised model prediction. Each figure is saved as a PNG in ``output_dir``.

  Args:
    VAL_AMP: whether inference runs under CUDA autocast.
    device: device the input is moved to (the model is expected to live there).
    model: trained network. It is put in eval mode while plotting, and its
      previous mode is restored afterwards.
    val_ds: indexable dataset whose items are dicts with "image" and "label".
    n_train: only used to build the output file names.
    max_epochs: only used to build the output file names.
    output_dir: folder where the figures are saved (created if missing).
  """
  post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
  os.makedirs(output_dir, exist_ok=True)

  was_training = model.training
  model.eval()  # disable dropout for inference; restored in `finally`
  try:
    with torch.no_grad():
      for i in range(len(val_ds)):
        # Fetch each sample once: a monai.data.Dataset re-runs its transform
        # pipeline (file loading, resampling) on every indexing.
        sample = val_ds[i]
        val_input = sample["image"].unsqueeze(0).to(device)
        val_output = inference(val_input, model, VAL_AMP)
        val_output = post_trans(val_output[0])

        fig = plt.figure(figsize=(18, 6))
        panels = (
            (sample["image"][0], f"input image, patient {i}"),
            (sample["label"][0], f"label, patient {i}"),
            (val_output[0], f"output, patient {i}"),
        )
        for position, (volume, title) in enumerate(panels, start=1):
          ax = fig.add_subplot(1, 3, position, projection = '3d')
          _scatter_nonzero_voxels(ax, volume, title)

        plot_path = os.path.join(output_dir,
                                 f"test patient {i}, n_train = {n_train}, max_epochs = {max_epochs}.png")
        plt.savefig(plot_path)
        plt.show()
        plt.close(fig)
  finally:
    model.train(was_training)

# Creation of the dataset
Here, the dataset is created, and one horizontal slice of the image of one of the test-set patients is plotted, together with the corresponding label.
Please specify:
- number of training images ("n_train": int); the remaining 10 - n_train images form the test set
- index of the test patient whose data is plotted ("plot_index": int), in the set [0,...,n_test-1]
- vertical (z-axis) coordinate of the slice ("z_coordinate": int).

In [None]:
# Dataset split: n_train image/label pairs for training and the rest for
# testing (this assumes 10 pairs in total).
n_train = 8
n_test = 10 - n_train

# Test-set patient and axial slice to display.
plot_index, z_coordinate = 1, 70

train_files, val_files = create_dataset(n_train, n_test)
train_loader, val_loader, val_ds = transform(train_files, val_files)
plot_patient(plot_index, z_coordinate, val_ds)

# Training of the network
Here, the segmentation network (MONAI's SegResNet, an encoder–decoder architecture similar to a U-Net) is trained.
Please specify the integer value of "max_epochs", i.e. the number of epochs of the training phase. During each epoch, the model is trained once on every image of the training set.

In [None]:
max_epochs = 5  # number of training epochs

(metric_values_train, metric_values, epoch_loss_values,
 model, device, VAL_AMP) = model_and_train(
    n_train, n_test, train_loader, val_loader, max_epochs, val_ds)

# Plot of the model's metrics
Here we plot the evolution of the loss function and of the Dice metric computed on the training and test sets.
The plots are automatically saved in the directory /kaggle/working/output_plots.

In [None]:
# val_interval=1: the test metric was computed after every epoch.
plot_metrics(epoch_loss_values, metric_values, metric_values_train,
             val_interval=1, n_train=n_train, max_epochs=max_epochs)

# Plot of the slices of some testing images
Here, we show some slices (at the z-coordinates given by the elements of the vector "z_axis_values") of the patients in the test set. For each patient, we show the input image, the corresponding label, and the segmentation predicted by the network. Please note that the vector "patients_indices" must contain only integers in the interval [0, n_test-1], so changing the number of patients in the training set (and, consequently, in the test set) may require modifying it. Also note that the two vectors "patients_indices" and "z_axis_values" must have the same length.

In [None]:
# Entry k of both tuples forms one (test-patient index, z slice) pair.
patients_indices = (0, 0, 1, 1)
z_axis_values = (82, 90, 30, 60)

plot_test_slices(model, val_ds, VAL_AMP,
                 patient_indices=patients_indices, z_axis_values=z_axis_values)

# 3D plot of all the testing images
We now show the 3D images of the test set, together with their corresponding labels and the segmentations predicted by the network. The plots are automatically saved in the directory /kaggle/working/test_3d_plots.
Running this cell may take several minutes.

In [None]:
# 3-D scatter plots of every test patient (slow: one scatter per voxel set).
plot_test_3d(VAL_AMP=VAL_AMP, device=device, model=model, val_ds=val_ds,
             n_train=n_train, max_epochs=max_epochs)