In [None]:
import numpy as np

#Neural Imaging
import nibabel as nib
import SimpleITK as sitk
import torchio as tio

from tqdm import tqdm
import torch

from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
import albumentations as A

import segmentation_models_pytorch as smp
from torch import nn
from torch.nn import functional as F
import torchmetrics

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

import gdown
import shutil

import pandas as pd
import glob
import matplotlib.pyplot as plt

In [None]:
# Pick the compute device once; every later cell moves models/tensors to `device`.
train_on_gpu = torch.cuda.is_available()
print('CUDA is available!  Training on GPU ...' if train_on_gpu
      else 'CUDA is not available.  Training on CPU ...')
device = torch.device("cuda:0") if train_on_gpu else torch.device("cpu")
print(device)
experiment_name = 'prova'

Loading the Data (original version)

These are the original images, downloaded from the Medical Segmentation Decathlon website ("http://medicaldecathlon.com"). 
In this section I load the files, build a list with the names of all of them, and apply N4 bias field correction to each image.

In [None]:
#data_path_tr="/content/drive/MyDrive/Task09_Spleen/imagesTr"
#data_path_ts="/content/drive/MyDrive/Task09_Spleen/imagesTs"
#labels_tr = "/content/drive/MyDrive/Task09_Spleen/labelsTr"

In [None]:
# List of path of all the training files

#tr_names = glob.glob(data_path_tr + '/*')
#ts_names = glob.glob(data_path_ts + '/*')
#tr_labels = glob.glob(labels_tr + '/*')

#Total Number of Training samples
#len(tr_names)

In [None]:
#Get more informations from the Nifti format
#lab = nib.load(tr_names + '/spleen_63.nii.gz').get_fdata()

#Labels are already One Hot encoded into (0,1) format
#np.unique(lab)

Applying N4 Bias Field Correction

In [None]:
from ipywidgets import interact
def explore_3D_array(arr: np.ndarray, cmap: str = 'gray'):
  """Interactively browse a 3D array one slice at a time along its first axis.

  Args:
      arr: volume to display; the slider selects k and ``arr[k, :, :]`` is shown.
      cmap: matplotlib colormap name used for the slice.
  """
  last_slice = arr.shape[0] - 1

  def show_slice(SLICE):
    plt.figure(figsize=(7, 7))
    plt.imshow(arr[SLICE, :, :], cmap=cmap)

  # The keyword name `SLICE` becomes the slider label, so it is kept as-is.
  interact(show_slice, SLICE=(0, last_slice))

#Function to cancel out the bias
def bias_cancellation(imm_path: str, shrink_factor: int = 4):
    """
    Correct intensity inhomogeneity of an image with N4 bias field correction.

    The bias field is fitted on a copy of the image (and of its foreground mask)
    shrunk by ``shrink_factor`` along every axis, for speed. The fitted log bias
    field is then evaluated on the full-resolution grid and divided out of the
    original image.

    Args:
        imm_path (str): path of the image to correct (read as float32).
        shrink_factor (int, optional): downsampling factor used while fitting the
            bias field. Defaults to 4.

    Returns:
        SimpleITK.Image: full-resolution image after bias correction.

    Raises:
        ValueError: if ``shrink_factor`` is smaller than 1.
    """
    if shrink_factor < 1:
        raise ValueError(f"shrink_factor must be >= 1, got {shrink_factor}")

    raw_img_sitk = sitk.ReadImage(imm_path, sitk.sitkFloat32)

    # Foreground mask: rescale intensities to [0, 255], then Li thresholding
    # (insideValue=0, outsideValue=1).
    rescaled = sitk.RescaleIntensity(raw_img_sitk, 0, 255)
    head_mask = sitk.LiThreshold(rescaled, 0, 1)

    # Fit on a low-resolution copy for computational reasons
    shrink = [shrink_factor] * raw_img_sitk.GetDimension()
    input_image = sitk.Shrink(raw_img_sitk, shrink)
    mask_image = sitk.Shrink(head_mask, shrink)

    bias_corrector = sitk.N4BiasFieldCorrectionImageFilter()
    # Only the fitted field is needed; the low-resolution corrected output is discarded.
    bias_corrector.Execute(input_image, mask_image)

    # Evaluate the fitted field back on the original (full-resolution) grid
    log_bias_field = bias_corrector.GetLogBiasFieldAsImage(raw_img_sitk)
    return raw_img_sitk / sitk.Exp(log_bias_field)

In [None]:
#This is used to create a new folder containing all the images saved after having applied bias correction
#First I apply the function created before and then I will save all these images

#for i in tr_names:
#    corrected_image_full_resolution = bias_cancellation(i)
#    new_name = i.replace('imagesTr', 'imagesTr_nobias')
#    sitk.WriteImage(corrected_image_full_resolution, new_name)

#for i in ts_names:
#    corrected_image_full_resolution = bias_cancellation(i)
#    new_name = i.replace('imagesTs', 'imagesTs_nobias')
#    sitk.WriteImage(corrected_image_full_resolution, new_name)

Downloading the images after Bias Correction

In [None]:
# Fetch the archive of bias-corrected volumes from Google Drive
path = 'dati_nobias.zip'
gdown.download(
    id='1HMUZhx1tT5jYT5_YPpw_vJ8HBoYLYSg-',
    output=path,
    quiet=False,
)

In [None]:
# Extract the bias-corrected archive into the current working directory
shutil.unpack_archive(filename="dati_nobias.zip")

In [None]:
#Creating the paths with the ones that have no bias
tr_nobias= "imagesTr_nobias"
labels_tr = "labelsTr"

# Sorted list of every bias-corrected training volume and of its label volume
# (backslashes normalised so the paths also work when produced on Windows).
tr_names_no = np.sort([p.replace("\\", "/") for p in glob.glob(tr_nobias + '/*')])
tr_labels = np.sort([p.replace("\\", "/") for p in glob.glob(labels_tr + '/*')])

# Later cells pair tr_names_no[i] with tr_labels[i] purely by position, so make
# sure both folders contain exactly the same file names.
_image_files = [p.split("/")[-1] for p in tr_names_no]
_label_files = [p.split("/")[-1] for p in tr_labels]
if _image_files != _label_files:
    raise ValueError(
        f"{tr_nobias} and {labels_tr} do not contain matching files "
        f"({len(_image_files)} images vs {len(_label_files)} labels)"
    )

print(len(tr_names_no))

Visualizing some examples

In [None]:
# Load one example volume together with its segmentation mask
example_case = "spleen_2.nii.gz"
img_1 = nib.load(f"{tr_nobias}/{example_case}").get_fdata()
lab_1 = nib.load(f"{labels_tr}/{example_case}").get_fdata()

In [None]:
# Show one axial slice: the image, the mask, and the mask overlaid on the image
z = 70
image_slice = img_1[:, :, z]
mask_slice = lab_1[:, :, z]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 6))
panels = (
    (ax1, "MRI", [(image_slice, "gray", None)]),
    (ax2, "MASK", [(mask_slice, "rainbow", None)]),
    (ax3, "MRI w/ MASK", [(image_slice, "gray", None), (mask_slice, "rainbow", 0.5)]),
)
for ax, title, layers in panels:
    for layer, cmap, alpha in layers:
        ax.imshow(layer, cmap=cmap, alpha=alpha)
    ax.set_title(title)
    ax.set_axis_off()

In [None]:
# Grid with 6 rows showing consecutive slices, each with the mask overlaid
n_cols = img_1.shape[2] // 6
fig, axes = plt.subplots(6, n_cols, figsize=(16, 9))

for slice_idx, ax in enumerate(axes.ravel()):
    ax.imshow(img_1[:, :, slice_idx], cmap="gray")
    ax.imshow(lab_1[:, :, slice_idx], cmap="rainbow", alpha=0.5)
    ax.set_axis_off()

plt.show()

In [None]:
#Searching for the smallest and largest number of slices
# nib.load only parses the header, so .shape is available without reading
# (and decompressing) the voxel data of every volume.
y = [nib.load(p).shape[2] for p in tr_names_no]
print(min(y))
print(max(y))
#Minimum is 31 slices
#Maximum is 168 slices

## 2-D Implementation

Saving a .npy copy of all the .nii.gz files for computational reasons.
In this section I convert every .nii.gz volume into NumPy arrays, one array per slice per patient.
Loading the .nii.gz volumes during training was computationally expensive; reading pre-extracted slices is much faster.

In [None]:
# NOTE: disabled cell. The code is wrapped in a string literal, so running this cell does nothing.
# It exports every axial slice of every label volume to its own file named
# "<case>_<slice>.npy"; the archive downloaded below already contains these files.
# NOTE(review): the output folder is a hard-coded absolute Windows path; make it configurable
# (e.g. relative "Labels_npy") before re-enabling.
""" #Labels
for i in tqdm(range(len(tr_labels))):
  image = nib.load(tr_labels[i]).get_fdata()
  z = image.shape[2]

  for j in range(z):
    np.save("C:/Users/user/Desktop/Artificial Intelligence/Piccoli/Labels_npy/" +tr_labels[i].split(".")[0].split("/")[-1] + "_" + str(j) + str(".npy"), image[:,:,j]) """

In [None]:
# NOTE: disabled cell. The code is wrapped in a string literal, so running this cell does nothing.
# It exports every axial slice of every bias-corrected image volume to its own file named
# "<case>_<slice>.npy" (same names as the label slices, so both can be paired by file name).
# NOTE(review): the output folder is a hard-coded absolute Windows path; make it configurable
# (e.g. relative "Images_npy") before re-enabling.
""" #Images
for i in tqdm(range(len(tr_names_no))):
  image = nib.load(tr_names_no[i]).get_fdata()
  z = image.shape[2]

  for j in range(z):
    np.save("C:/Users/user/Desktop/Artificial Intelligence/Piccoli/Images_npy/" +tr_names_no[i].split(".")[0].split("/")[-1] + "_" + str(j) + str(".npy"), image[:,:,j]) """

In this section I download (with gdown) the NumPy files that were already created.

In [None]:
# Fetch the per-slice .npy archive from Google Drive
path = 'dati_numpy.zip'
gdown.download(
    id='1TPxra-VksERcClu4jIO2SKD7i2B5Ct7P',
    output=path,
    quiet=False,
)

In [None]:
# Extract the per-slice .npy archive into the current working directory
shutil.unpack_archive(filename="dati_numpy.zip")

To obtain a 2D implementation, I first need a single DataFrame that indexes every slice of every image.

In [None]:
#Creating a Dataframe with one row per axial slice of every bias-corrected .nii.gz volume
if len(tr_names_no) != len(tr_labels):
    raise ValueError(f"Found {len(tr_names_no)} images but {len(tr_labels)} labels")

slice_frames = []
for img_path, lbl_path in tqdm(zip(tr_names_no, tr_labels), total=len(tr_names_no)):
    # Only the header is needed to know the number of slices; avoid loading voxel data.
    z = nib.load(img_path).shape[2]
    slice_frames.append(pd.DataFrame({
        "Path": [img_path] * z,
        "Slice": np.arange(0, z),
        "Label Path": [lbl_path] * z,
    }))

# Concatenate once at the end: growing a DataFrame with concat inside the loop is quadratic.
lista_dati = (
    pd.concat(slice_frames, ignore_index=True)
    if slice_frames
    else pd.DataFrame(columns=["Path", "Slice", "Label Path"])
)

In [None]:
#Paths of the per-slice .npy files extracted from the bias-corrected volumes
immagini_npy = "Images_npy"
labels_npy = "Labels_npy"

# Sorted lists of all image slices and label slices (backslashes normalised for Windows paths)
names_npy_im = np.sort([p.replace("\\", "/") for p in glob.glob(immagini_npy + '/*')])
names_npy_lb = np.sort([p.replace("\\", "/") for p in glob.glob(labels_npy + '/*')])

# Images and labels are paired by sorted position, which is only correct if both
# folders contain exactly the same file names ("<case>_<slice>.npy").
_im_files = [p.split("/")[-1] for p in names_npy_im]
_lb_files = [p.split("/")[-1] for p in names_npy_lb]
if _im_files != _lb_files:
    raise ValueError(
        f"{immagini_npy} and {labels_npy} do not contain matching files "
        f"({len(_im_files)} images vs {len(_lb_files)} labels)"
    )

#Creating a Dataframe with the path of all the npy files sorted
lista_dati_2 = pd.DataFrame({"Path": names_npy_im, "Label Path": names_npy_lb})

In [None]:
# create a Pytorch Dataset class
class Dataset(torch.utils.data.Dataset):
    """2D slice dataset yielding ``(image, label)`` tensors resized to 224x224.

    Each row of ``df`` describes one axial slice, either as a pre-extracted ``.npy``
    file (``type="npy"``) or as a slice index into a NIfTI volume
    (``type="nii"`` or ``"nii.gz"``).
    """

    # Accepted spellings for NIfTI input
    _NIFTI_TYPES = ("nii", "nii.gz")

    def __init__(self, df, type = "npy", usage = 'train'):
        """
        Initializes a custom PyTorch dataset.

        Args:
            df (pandas.DataFrame): DataFrame with "Path" and "Label Path" columns,
                plus "Slice" when reading NIfTI volumes.
            type (str, optional): input file type, "npy" or "nii"/"nii.gz". Defaults to "npy".
            usage (str, optional): 'train' enables data augmentation; any other value
                (e.g. 'test', 'val') only preprocesses and resizes. Defaults to 'train'.

        Raises:
            ValueError: if ``type`` is not a supported input format.
        """
        if type != "npy" and type not in self._NIFTI_TYPES:
            raise ValueError(f"Unsupported input type {type!r}; expected 'npy', 'nii' or 'nii.gz'")

        self.df = df
        self.usage = usage
        self.type = type

        #Preprocessing:
        #Normalization using ZNormalization
        #Clamp - Using Hounsfield scale to improve visualization of the spleen
        #RescaleIntensity
        #Resampling to 1 mm isotropic for faster computation
        # NOTE(review): Clamp runs after ZNormalization, when intensities are already
        # roughly zero-mean/unit-variance, so the HU window (-1000, 1400) never clips
        # anything. Moving Clamp first would change the inputs the saved checkpoints
        # were trained on, so the order is kept; revisit before retraining.
        # NOTE(review): plain arrays carry no affine, so torchio presumably assumes 1 mm
        # spacing and Resample(1) leaves the slice unchanged -- confirm.
        HOUNSFIELD_AIR = -1000
        HOUNSFIELD_BONE = 1400
        self.transform = tio.Compose([
            tio.ZNormalization(masking_method=tio.ZNormalization.mean),
            tio.Clamp(out_min=HOUNSFIELD_AIR, out_max=HOUNSFIELD_BONE),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.Resample(1),
        ])

        #Data Augmentation with standard artifacts (bias field, blurring, noise, spike)
        self.augment = tio.Compose([
            tio.RandomBiasField(p=0.3),
            tio.OneOf({
                tio.RandomBlur(): 0.2,
                tio.RandomNoise(mean = 0, std = 0.1) : 0.6,
                tio.RandomSpike(): 0.2
            })
        ])

        #Data Augmentation with random rotation (applied to image and mask together)
        self.torch_augment = A.Compose([
            A.augmentations.geometric.rotate.RandomRotate90(p = 0.5),
        ])

        #Resizing the image to a common shape of 224x224
        self.resizing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224), antialias=False),
        ])

        # Labels are resized with nearest-neighbour so they stay in {0, 1}: bilinear
        # interpolation would blend mask boundaries into fractional values that the
        # training loop's `.long()` then truncates to background.
        self.label_resizing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224),
                              interpolation=transforms.InterpolationMode.NEAREST,
                              antialias=False),
        ])

    def __len__(self):
        return len(self.df)

    def _load(self, index):
        """Return the raw 2D ``(image, label)`` arrays described by row ``index``."""
        row = self.df.iloc[index]
        if self.type in self._NIFTI_TYPES:
            sl = row['Slice']
            image = nib.load(row['Path']).get_fdata()[:, :, sl]
            label = nib.load(row['Label Path']).get_fdata()[:, :, sl]
            return image, label
        return np.load(row['Path']), np.load(row['Label Path'])

    def _resize_pair(self, image, label):
        """Resize image (bilinear) and label (nearest) to 224x224 tensors without a channel axis."""
        image = torch.squeeze(self.resizing(np.array(image)))
        label = torch.squeeze(self.label_resizing(np.array(label)))
        return image, label

    def __getitem__(self, index):
        image, label = self._load(index)

        # torchio expects a 4D (C, W, H, D) array: add channel and depth axes of size 1
        image = self.transform(np.expand_dims(image, [0, 3]))

        if self.usage == 'train':
            #Artifact augmentation is applied only to the image, leaving the mask untouched
            image = np.squeeze(self.augment(image))
            #Random rotation is applied jointly to image and mask
            sample = self.torch_augment(image=image, mask=label)
            image, label = sample['image'], sample['mask']
        else:
            #Evaluation: no augmentation, only resizing
            image = np.squeeze(image)

        return self._resize_pair(image, label)

In [None]:
ds = Dataset(lista_dati_2, usage = "train")

#Visualizing one augmented training slice with its mask
featurs, label = ds[58]

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 6))
panels = (
    (ax1, "MRI", [(featurs, "gray", None)]),
    (ax2, "MASK", [(label, "rainbow", None)]),
    (ax3, "MRI w/ MASK", [(featurs, "gray", None), (label, "rainbow", 0.5)]),
)
for ax, title, layers in panels:
    for layer, cmap, alpha in layers:
        ax.imshow(layer, cmap=cmap, alpha=alpha)
    ax.set_title(title)
    ax.set_axis_off()

In [None]:
#Creating a Validation Set and a Test Set
# Split by patient, not by slice: neighbouring slices of one volume are almost
# identical, so a slice-level split puts the same patients in train, validation
# and test and inflates the reported scores.
# Slice files are named "<case>_<slice>.npy" (e.g. "spleen_2_37.npy").
slice_patient_ids = (
    lista_dati_2["Path"].str.split("/").str[-1]
    .str.rsplit("_", n=1).str[0]
)
train_patients, holdout_patients = train_test_split(
    np.sort(slice_patient_ids.unique()), test_size = 0.3, random_state = 42)
test_patients, val_patients = train_test_split(
    holdout_patients, test_size = 0.5, random_state = 42)

path_train = lista_dati_2[slice_patient_ids.isin(train_patients)]
path_test = lista_dati_2[slice_patient_ids.isin(test_patients)]
path_val = lista_dati_2[slice_patient_ids.isin(val_patients)]
assert len(path_train) + len(path_test) + len(path_val) == len(lista_dati_2)

batch_size = 64
# Create Dataset
train_ds = Dataset(path_train)
test_ds = Dataset(path_test, usage = 'test')
val_ds = Dataset(path_val, usage = 'test')

# Create Dataloader using different batch size for train and validation.
# Validation order is irrelevant; keeping it fixed makes drop_last discard the same
# samples every epoch. drop_last stays on because the validation loop squeezes the
# batch axis, which would break on a trailing batch of size 1.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle = True, drop_last=True)
valid_loader = DataLoader(val_ds, batch_size=16, shuffle = False, drop_last=True)

In [None]:
import gc

# Release unreachable Python objects first, then return PyTorch's cached
# (now unused) CUDA blocks to the driver.
for _release in (gc.collect, torch.cuda.empty_cache):
    _release()

In [None]:
NUM_EPOCHS = 20

#DeepLabV3+ with an ImageNet-pretrained ResNet-34 encoder (U-Net alternative below)
model = smp.DeepLabV3Plus(
  encoder_name="resnet34",
  encoder_weights= "imagenet",
  in_channels=1,  #model input channels
  classes=2, # model output channels
)

#model = smp.Unet(
#    encoder_name="resnet34",        # choose encoder
#    encoder_weights="imagenet",     # use `imagenet` pre-trained weights
#    in_channels=1,                  # model input channels
#    classes=2                       # model output channels (number of classes in your dataset)
#)

#Loading the model on the device
model.to(device)

#Choosing the optimizer, the criterion and the scheduler
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# scheduler.step() is called once per batch, so T_max is in optimizer steps: the
# learning rate decays to eta_min over 5 epochs and then follows the cosine back up.
scheduler = CosineAnnealingLR(optimizer,
                              T_max = len(train_loader)*5, # Maximum number of iterations.
                              eta_min = 1e-5) # Minimum learning rate.

# Binary IoU of the foreground class, created once instead of once per batch
iou_metric = torchmetrics.JaccardIndex(task = "binary").to(device)

train_losses = []
val_ious = []
val_losses = []
best_iou = -float("inf")

for idx_epoch in range(NUM_EPOCHS):
  #Training
  print('train')
  running_loss = 0.
  model.train()
  for image, mask in tqdm(train_loader):
    image = torch.unsqueeze(image.to(device), 1)   # add the channel axis
    mask = mask.to(device).squeeze()
    optimizer.zero_grad()
    pred_mask = model(image.float())
    loss = criterion(pred_mask, mask.long())
    loss.backward()
    optimizer.step()
    scheduler.step()
    running_loss += loss.item()
  train_losses.append(running_loss / len(train_loader))
  print(f"{idx_epoch:02d}: {train_losses[-1]:.6f}")

  #Validation
  print('validate')
  iou_sum, nan_i = 0., 0
  running_loss = 0.
  iou_metric.reset()
  model.eval()
  with torch.no_grad():
    for image, mask in valid_loader:
      image = torch.unsqueeze(image.to(device), dim = 1)
      mask = torch.squeeze(mask.to(device).long())

      pred_mask = model(image.float())

      #select the probabilities for the foreground class and create a mask accordingly
      pred_msk_binary = torch.where(F.softmax(pred_mask, dim=1)[:, 1] >= 0.5, 1, 0)

      iou = iou_metric(pred_msk_binary, mask)
      #NaN IoU (e.g. empty prediction and mask) is excluded from the average
      if torch.isnan(iou):
        nan_i += 1
      else:
        iou_sum += iou.item()
      running_loss += criterion(pred_mask, mask).item()

  scored_batches = len(valid_loader) - nan_i
  val_ious.append(iou_sum / scored_batches if scored_batches > 0 else float("nan"))
  val_losses.append(running_loss / len(valid_loader))
  print(f"iou: {val_ious[-1]:.4f}")

  # Strictly better than every previous epoch (NaN never counts as an improvement)
  if val_ious[-1] > best_iou:
    best_iou = val_ious[-1]
    print('new best model')
    torch.save(model.state_dict(), 'deeplab.pth')


Testing the code using the Test Dataset and two different metrics: Dice + IoU

In [None]:
#Loading the pretrained 2D models using gdown

# Architecture -> (Google Drive file id of the trained weights, default local file name)
_MODEL_2D_SOURCES = {
    "unet": ('1B5pIMRFd0FywUdexZ5Ah6eoiI_9CfHag', 'model_unet2d.pth'),
    "deeplab": ('1XtWzG3hdKUHDGQ1OnoGlPPG3LJBIDAbx', 'model_deeplab2d.pth'),
}

def load_model_2D(path = None, which = None):
    """Build a 2D segmentation model and load its trained weights.

    Args:
        path (str, optional): local checkpoint file. If the file does not exist, the
            reference weights for ``which`` are downloaded to it. Defaults to the
            per-architecture file name in ``_MODEL_2D_SOURCES``.
        which (str, optional): "unet" or "deeplab". Defaults to "unet".

    Returns:
        torch.nn.Module: the model with weights loaded on CPU, in eval mode.

    Raises:
        ValueError: if ``which`` is not a known architecture.
    """
    import os

    if which is None:
        which = "unet"
    if which not in _MODEL_2D_SOURCES:
        raise ValueError(f"Unknown model {which!r}; expected one of {sorted(_MODEL_2D_SOURCES)}")

    drive_id, default_path = _MODEL_2D_SOURCES[which]
    if path is None:
        path = default_path
    print(f"We are using {which}")

    # Download only when there is no local copy, so a locally trained
    # checkpoint at `path` is never overwritten.
    if not os.path.exists(path):
        gdown.download(id=drive_id, output=path, quiet=False)

    # encoder_weights=None: every parameter is overwritten by the checkpoint below,
    # so fetching ImageNet weights first would be wasted work.
    if which == "unet":
        model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights=None,
            in_channels=1,  # gray-scale input
            classes=2,
        )
    else:
        model = smp.DeepLabV3Plus(
            encoder_name="resnet34",
            encoder_weights=None,
            in_channels=1,  # gray-scale input
            classes=2,
        )

    state = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state)
    model.eval()
    return model

model = load_model_2D("unet_prova.pth", "unet")
model = model.to(device)

In [None]:
#Test Dataloader using a batchsize of 1
dl_test = DataLoader(test_ds, shuffle=False, batch_size=1, num_workers=0)

#Plotting the first N_EXAMPLES test slices to show qualitatively how the model works
N_EXAMPLES = 10

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
  for i, (image, label) in enumerate(dl_test):
    if i >= N_EXAMPLES:
      break

    image = torch.unsqueeze(image.to(device), dim = 1)
    mask = torch.squeeze(label.to(device).long())

    pred_mask = model(image.float())
    predicted = torch.where(F.softmax(pred_mask, dim=1)[:, 1] >= 0.5, 1, 0)
    image = torch.squeeze(image)

    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(12, 6))

    ax1.imshow(image.cpu(), cmap = "gray")
    ax2.imshow(image.cpu(), cmap = "gray")
    ax2.imshow(mask.cpu(), cmap = "rainbow", alpha = 0.5)
    ax3.imshow(image.cpu(), cmap = "gray")
    ax3.imshow(predicted.cpu()[0], cmap = "rainbow", alpha = 0.5)

    ax1.set_title("MRI")
    ax2.set_title("TRUE MASK")
    ax3.set_title("PREDICTED MASK")

    ax1.set_axis_off()
    ax2.set_axis_off()
    ax3.set_axis_off()

In [None]:
from torchmetrics.classification import Dice

# Metrics are created once; calling them returns the score of the current sample only.
iou_metric = torchmetrics.JaccardIndex(task = "binary").to(device)
dice_metric = Dice(num_classes=2, average='macro').to(device)

ious, dices, nan_i = 0., 0., 0
# Separate names so the training history in val_ious is not overwritten
test_ious, test_dices = [], []

model.eval()
with torch.no_grad():
  for image, label in dl_test:
    image = torch.unsqueeze(image.to(device), dim = 1)
    mask = torch.squeeze(label.to(device).long())

    pred_mask = model(image.float())
    predicted = torch.squeeze(torch.where(F.softmax(pred_mask, dim=1)[:, 1] >= 0.5, 1, 0))

    #Measuring the IoU metric using the Jaccard Index (NaN scores are excluded from the mean)
    iou = iou_metric(predicted, mask)
    if torch.isnan(iou):
      nan_i += 1
    else:
      ious += iou.item()
    dices += dice_metric(predicted, mask).item()

#Measuring the average iou and dice on all the dataset
n_scored = len(dl_test) - nan_i
test_ious.append(ious / n_scored if n_scored > 0 else float("nan"))
test_dices.append(dices / len(dl_test) if len(dl_test) > 0 else float("nan"))
print(f"Test IoU: {test_ious[-1]:.4f}")
print(f"Test Dice: {test_dices[-1]:.4f}")

## 3-D Implementation


First I create a DataFrame for the 3D implementation. For each volume it contains the image path, the path of the corresponding label, the index of the first slice in which the mask is non-zero, and the number of consecutive slices in which the mask is not all zeros.

In [None]:
def _spleen_slice_range(mask):
    """Return ``(start, length)`` of the first contiguous run of slices (last axis) with a non-zero mask.

    An all-zero mask yields ``(0, 0)`` instead of indexing past the end of the volume,
    and a run that reaches the last slice is measured up to the end of the volume.
    """
    has_fg = mask.max(axis=(0, 1)) != 0          # one flag per slice
    fg_idx = np.flatnonzero(has_fg)
    if fg_idx.size == 0:
        return 0, 0
    start = int(fg_idx[0])
    bg_after = np.flatnonzero(~has_fg[start:])   # first empty slice after the run
    length = int(bg_after[0]) if bg_after.size else has_fg.size - start
    return start, length


if len(tr_names_no) != len(tr_labels):
    raise ValueError(f"Found {len(tr_names_no)} images but {len(tr_labels)} labels")

records = []
for img_path, lbl_path in tqdm(zip(tr_names_no, tr_labels), total=len(tr_names_no)):
    # Only the label volume is needed to locate the spleen
    mask_start_idx, num_spleen_slices = _spleen_slice_range(nib.load(lbl_path).get_fdata())
    records.append({
        "Path": img_path,
        "Label Path": lbl_path,
        "Start Index": mask_start_idx,
        "Num_slices": num_spleen_slices,
    })

# Build the DataFrame once (concat inside the loop is quadratic)
slices_3d = pd.DataFrame(records, columns = ["Path", "Label Path", "Start Index", "Num_slices"])

In [None]:
def stack_images_in_blocks(df, idx, block_size=32, pad_last=False):
    """Split one volume and its label into consecutive blocks of ``block_size`` slices.

    Args:
        df (pandas.DataFrame): DataFrame with "Path" and "Label Path" columns pointing
            to NIfTI files.
        idx: row label of the volume to process (looked up as ``df[column][idx]``, so a
            single label, not a list).
        block_size (int, optional): number of slices (last axis) per block. Defaults to 32.
        pad_last (bool, optional): if True, zero-pad a trailing partial block instead of
            discarding it. Defaults to False (trailing slices are dropped with a warning).

    Returns:
        tuple: ``(stacked_images, stacked_labels)``, float64 arrays of shape
        ``(num_blocks, height, width, block_size)``. Slice ``k`` of block ``b`` is slice
        ``b * block_size + k`` of the volume. ``num_blocks`` is 0 when the volume has fewer
        than ``block_size`` slices and ``pad_last`` is False.
    """
    data = nib.load(df["Path"][idx]).get_fdata()
    label = nib.load(df["Label Path"][idx]).get_fdata()

    height, width, num_images = data.shape

    remainder = num_images % block_size
    if remainder != 0:
        if pad_last:
            pad = ((0, 0), (0, 0), (0, block_size - remainder))
            data = np.pad(data, pad)
            label = np.pad(label, pad)
            num_images += block_size - remainder
        else:
            print(f"Warning: Discarding {remainder} slices from each image as they don't fit complete blocks.")

    num_blocks = num_images // block_size
    usable = num_blocks * block_size

    def _to_blocks(volume):
        # (H, W, nb*bs) -> (H, W, nb, bs) -> (nb, H, W, bs)
        blocks = volume[:, :, :usable].reshape(height, width, num_blocks, block_size)
        return np.ascontiguousarray(blocks.transpose(2, 0, 1, 3), dtype=np.float64)

    return _to_blocks(data), _to_blocks(label)

Creating block of stacked images

In [None]:
import os

# np.save does not create missing folders
os.makedirs("images_npy_3d", exist_ok=True)
os.makedirs("labels_npy_3d", exist_ok=True)

# One .npy file per block: "<case>_<block>.npy"
for index, row in slices_3d.iterrows():
    stacked_images, stacked_labels = stack_images_in_blocks(slices_3d, index)
    for j, (img_block, lbl_block) in enumerate(zip(stacked_images, stacked_labels)):
        suffix = "_" + str(j) + ".npy"
        np.save(row["Path"].replace("imagesTr_nobias", "images_npy_3d").replace(".nii.gz", suffix), img_block)
        np.save(row["Label Path"].replace("labelsTr", "labels_npy_3d").replace(".nii.gz", suffix), lbl_block)

For simplicity, the files already converted to .npy format can be downloaded here.

In [None]:
# Download the 3D blocks already converted to .npy
path = "dati_numpy_3d.zip"
# gdown is a module; the download function is gdown.download (calling the module raises TypeError)
gdown.download(id = "17my4ppZsjlAxlzTEVNUHpgcqjL6ffREI", output = path, quiet = False)

In [None]:
# Extract the 3D-block archive into the current working directory
shutil.unpack_archive(filename="dati_numpy_3d.zip")

In [None]:
immagini_npy_3d = "images_npy_3d"
labels_npy_3d = "labels_npy_3d"

# Sorted lists of all 3D image blocks and label blocks (backslashes normalised for Windows paths)
names_npy_im_3d = np.sort([p.replace("\\", "/") for p in glob.glob(immagini_npy_3d + '/*')])
names_npy_lb_3d = np.sort([p.replace("\\", "/") for p in glob.glob(labels_npy_3d + '/*')])

# Blocks are paired by sorted position, which is only correct if both folders
# contain exactly the same file names ("<case>_<block>.npy").
_im_files_3d = [p.split("/")[-1] for p in names_npy_im_3d]
_lb_files_3d = [p.split("/")[-1] for p in names_npy_lb_3d]
if _im_files_3d != _lb_files_3d:
    raise ValueError(
        f"{immagini_npy_3d} and {labels_npy_3d} do not contain matching files "
        f"({len(_im_files_3d)} images vs {len(_lb_files_3d)} labels)"
    )

#Creating a Dataframe with the path of all the npy files sorted
lista_dati_3 = pd.DataFrame({"Path": names_npy_im_3d, "Label Path": names_npy_lb_3d})

In [None]:
# create a Pytorch Dataset class
class Dataset_3D(torch.utils.data.Dataset):
    """Dataset of 3D blocks of axial slices stored as (H, W, D) ``.npy`` files.

    Each item is an ``(image, label)`` pair resized in-plane to 224x224 with the
    slice axis first, i.e. shape (D, 224, 224).
    """

    # Expected number of slices per block; shorter blocks are zero-padded up to it.
    BLOCK_DEPTH = 32

    def __init__(self, df, usage = 'train'):
        """
        Initializes a custom PyTorch dataset.

        Args:
            df (pandas.DataFrame): DataFrame with "Path" and "Label Path" columns.
            usage (str, optional): 'train' enables data augmentation; any other value
                (e.g. 'val', 'test') only preprocesses and resizes. Defaults to 'train'.
        """
        self.df = df
        self.usage = usage

        #Preprocessing:
        #Clamp - Using Hounsfield scale to improve visualization of the spleen
        #RescaleIntensity
        #Resampling to 1 mm isotropic for faster computation
        # NOTE(review): Clamp runs after ZNormalization, so the HU window never clips
        # anything; and plain tensors carry no affine, so ToCanonical/Resample(1)
        # presumably leave the block unchanged. Kept as-is to match the saved
        # checkpoints -- revisit before retraining.
        HOUNSFIELD_AIR = -1000
        HOUNSFIELD_BONE = 1400
        self.transform = tio.Compose([
            tio.ToCanonical(),
            tio.ZNormalization(masking_method=tio.ZNormalization.mean),
            tio.Clamp(out_min=HOUNSFIELD_AIR, out_max=HOUNSFIELD_BONE),
            tio.RescaleIntensity(out_min_max=(0, 1)),
            tio.Resample(1)
        ])

        #Artifact augmentation (bias field, blurring, noise, spike), train only
        self.augment = tio.Compose([
            tio.RandomBiasField(p=0.3),
            tio.OneOf({
                tio.RandomBlur(): 0.2,
                tio.RandomNoise(mean = 0, std = 0.05) : 0.6,
                tio.RandomSpike(): 0.2
            })
        ])

        # ToTensor turns the (H, W, D) array into (D, H, W); Resize acts on (H, W)
        self.resizing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224), antialias=False),
        ])

        # Nearest-neighbour for labels so they stay in {0, 1} (bilinear would create
        # fractional boundary values that `.long()` later truncates to background).
        self.label_resizing = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224),
                              interpolation=transforms.InterpolationMode.NEAREST,
                              antialias=False),
        ])

    def __len__(self):
        return len(self.df)

    def _pad_depth(self, volume):
        """Zero-pad the last (slice) axis of an (H, W, D) array up to BLOCK_DEPTH."""
        missing = self.BLOCK_DEPTH - volume.shape[2]
        if missing <= 0:
            return volume
        return np.pad(volume, ((0, 0), (0, 0), (0, missing)))

    def __getitem__(self, index):
        path = self.df.iloc[index]['Path']
        l_path = self.df.iloc[index]['Label Path']

        try:
            #Blocks with fewer slices than BLOCK_DEPTH (the minimum is 31) get zero slices appended
            im = self._pad_depth(np.load(path))
            ms = self._pad_depth(np.load(l_path))

            im = self.resizing(im)
            ms = self.label_resizing(ms)

            im = self.transform(torch.unsqueeze(im, dim = 0))
        except Exception as exc:
            # Fail loudly: returning a placeholder array (instead of an (image, mask)
            # pair) would break unpacking and batch collation downstream.
            raise RuntimeError(f"Failed to load sample {index}: {path}, {l_path}") from exc

        if self.usage == 'train':
            #Artificial artifacts are applied only to the image
            image = np.squeeze(np.array(self.augment(im)))
            ms = np.squeeze(np.array(torch.unsqueeze(ms, dim = 0)))
        else:
            image = np.squeeze(im)

        return image, ms

In [None]:
ds = Dataset_3D(lista_dati_3, usage = "train")

# Show slice 15 of one augmented training block with its mask
featurs, label = ds[16]
shown = 15

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 6))
panels = (
    (ax1, "MRI", [(featurs[shown, :, :], "gray", None)]),
    (ax2, "MASK", [(label[shown, :, :], "rainbow", None)]),
    (ax3, "MRI w/ MASK", [(featurs[shown, :, :], "gray", None),
                          (label[shown, :, :], "rainbow", 0.5)]),
)
for ax, title, layers in panels:
    for layer, cmap, alpha in layers:
        ax.imshow(layer, cmap=cmap, alpha=alpha)
    ax.set_title(title)
    ax.set_axis_off()

In [None]:
#Creating a Validation Set and a Test Set
# Split by patient, not by block: blocks of the same volume are strongly
# correlated, so a block-level split leaks patients across train/val/test.
# Block files are named "<case>_<block>.npy" (e.g. "spleen_2_0.npy").
block_patient_ids = (
    lista_dati_3["Path"].str.split("/").str[-1]
    .str.rsplit("_", n=1).str[0]
)
train_patients_3d, holdout_patients_3d = train_test_split(
    np.sort(block_patient_ids.unique()), test_size = 0.3, random_state = 42)
test_patients_3d, val_patients_3d = train_test_split(
    holdout_patients_3d, test_size = 0.5, random_state = 42)

path_train = lista_dati_3[block_patient_ids.isin(train_patients_3d)]
path_test = lista_dati_3[block_patient_ids.isin(test_patients_3d)]
path_val = lista_dati_3[block_patient_ids.isin(val_patients_3d)]
assert len(path_train) + len(path_test) + len(path_val) == len(lista_dati_3)

batch_size = 6
# Create Dataset
train_ds = Dataset_3D(path_train)
test_ds = Dataset_3D(path_test, usage = 'test')
val_ds = Dataset_3D(path_val, usage = 'test')

# Create Dataloader (validation order does not affect the metric, so it is not shuffled)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle = True, drop_last=True)
valid_loader = DataLoader(val_ds, batch_size=1, shuffle = False, drop_last=True)

Training the Model

In [None]:
import monai
import monai.bundle  # load the bundle subpackage explicitly before using monai.bundle.load
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss

# 3D UNet; its hyper-parameters must match the "spleen_ct_segmentation" bundle
# network for load_state_dict below to succeed.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)
#Defining loss and optimizer
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

#Loading the pretrained weights from the MONAI model zoo
pretrained_model = monai.bundle.load(
    name="spleen_ct_segmentation", bundle_dir="./"
)
# NOTE(review): depending on the MONAI version, bundle.load returns either a
# state dict or an instantiated network; accept both.
pretrained_state = (
    pretrained_model.state_dict()
    if isinstance(pretrained_model, torch.nn.Module)
    else pretrained_model
)
# NOTE(review): the bundle was presumably trained with its own intensity
# preprocessing, which differs from Dataset_3D's -- confirm before relying on
# the pretrained weights without fine-tuning.
model.load_state_dict(pretrained_state)

In [None]:
# Binary IoU of the foreground class, created once and reset every epoch
iou_metric = torchmetrics.JaccardIndex(task = "binary").to(device)

train_losses = []
val_ious = []
val_losses = []
best_iou = -float("inf")

NUM_EPOCHS = 100
for idx_epoch in range(NUM_EPOCHS):
  #Training
  print('train')
  running_loss = 0.
  model.train()
  for image, mask in tqdm(train_loader):
    image = torch.unsqueeze(image.to(device), 1)                # (B, 1, D, H, W)
    mask = torch.unsqueeze(mask.to(device).squeeze(), dim = 1)  # (B, 1, D, H, W)
    optimizer.zero_grad()
    pred_mask = model(image.float())
    loss = loss_function(pred_mask, mask.long())
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
  train_losses.append(running_loss / len(train_loader))
  print(f"{idx_epoch:02d}: {train_losses[-1]:.6f}")

  #Validation
  print('validate')
  iou_sum, nan_i = 0., 0
  running_loss = 0.
  iou_metric.reset()
  model.eval()
  with torch.no_grad():
    for image, mask in valid_loader:
      image = torch.unsqueeze(image.to(device), dim = 1)
      mask = torch.unsqueeze(mask.to(device).long(), dim = 1)

      pred_mask = model(image.float())

      #select the probabilities for the foreground class and create a mask accordingly
      pred_msk_binary = torch.where(F.softmax(pred_mask, dim=1)[:, 1] >= 0.5, 1, 0)

      iou = iou_metric(pred_msk_binary, torch.squeeze(mask, dim = 1))
      if torch.isnan(iou):
        nan_i += 1
      else:
        iou_sum += iou.item()
      # Same loss as training (the 2D cell's CrossEntropy `criterion` is no longer needed here)
      running_loss += loss_function(pred_mask, mask).item()

  scored_batches = len(valid_loader) - nan_i
  val_ious.append(iou_sum / scored_batches if scored_batches > 0 else float("nan"))
  val_losses.append(running_loss / len(valid_loader))
  print(f"iou: {val_ious[-1]:.4f}")

  # Strictly better than every previous epoch (NaN never counts as an improvement)
  if val_ious[-1] > best_iou:
    best_iou = val_ious[-1]
    print('new best model')
    torch.save(model.state_dict(), 'unet3d_dice.pth')


To load the pretrained 3D Model

In [None]:
def load_3dmodel(path = None):
    """Build the 3D MONAI UNet and load trained weights into it.

    If a copy of the checkpoint exists at ``path``, it is loaded from there;
    otherwise (or when ``path`` is None) the reference checkpoint is downloaded
    first, to ``path`` or to '3dmodel_best.pth'.

    Args:
        path (str, optional): local checkpoint file. Defaults to '3dmodel_best.pth'.

    Returns:
        monai.networks.nets.UNet: the model on ``device``, in eval mode.
    """
    import os

    if path is None:
        path = '3dmodel_best.pth'
    if not os.path.exists(path):
        gdown.download(id='1oh1_ngh9vvC-iO_-1d4v5lsX22ioW4cI', output=path, quiet=False)

    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        norm=Norm.BATCH,
    ).to(device)

    state = torch.load(path, map_location=torch.device("cpu"))
    model.load_state_dict(state)
    model.eval()
    return model

In [None]:
# Rebuild the 3D UNet and load the checkpoint stored in "prova_3d.pth"
# (assumed to be a state dict saved from the same UNet configuration).
model = load_3dmodel("prova_3d.pth")

Testing the Model using both IoU and Dice Metric

In [None]:
#DataLoader (fixed order so the scores and the plots below are reproducible)
dl_test = DataLoader(test_ds, shuffle=False, batch_size=1, num_workers=0)

# Metrics created once; calling them returns the score of the current block only.
iou_metric = torchmetrics.JaccardIndex(task = "binary").to(device)
dice_metric = torchmetrics.Dice(num_classes=2, average='macro').to(device)

# Running sums must start as numbers (not lists) to be accumulated
ious, dices, nan_i = 0., 0., 0
val_ious_test, val_dices = [], []

model.eval()
with torch.no_grad():
  for image, mask in dl_test:
    image = torch.unsqueeze(image.to(device), dim = 1)
    mask = torch.squeeze(mask.to(device).long())

    pred_mask = model(image.float())
    predicted = torch.squeeze(torch.where(F.softmax(pred_mask, dim=1)[:, 1] >= 0.5, 1, 0))

    #NaN IoU scores are excluded from the average
    iou = iou_metric(predicted, mask)
    if torch.isnan(iou):
      nan_i += 1
    else:
      ious += iou.item()
    dices += dice_metric(predicted, mask).item()

n_scored = len(dl_test) - nan_i
val_ious_test.append(ious / n_scored if n_scored > 0 else float("nan"))
val_dices.append(dices / len(dl_test) if len(dl_test) > 0 else float("nan"))
print(f"Test IoU: {val_ious_test[-1]:.4f}")
print(f"Test Dice: {val_dices[-1]:.4f}")


In [None]:
#Plotting one slice of the first N_EXAMPLES test blocks with true and predicted masks
N_EXAMPLES = 10
SLICE_TO_SHOW = 25  # slice index inside each block

model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
  for i, (image, label) in enumerate(dl_test):
    if i >= N_EXAMPLES:
      break

    image = torch.unsqueeze(image.to(device), dim = 1)
    mask = torch.squeeze(label.to(device).long())

    pred_mask = model(image.float())
    predicted = torch.squeeze(torch.where(F.softmax(pred_mask, dim=1)[:, 1] >= 0.5, 1, 0))
    image = torch.squeeze(image)

    # Stay inside the block if it has fewer slices than expected
    k = min(SLICE_TO_SHOW, image.shape[0] - 1)

    fig, (ax1, ax2, ax3) = plt.subplots(1,3, figsize=(12, 6))

    ax1.imshow(image.cpu()[k], cmap = "gray")
    ax2.imshow(image.cpu()[k], cmap = "gray")
    ax2.imshow(mask.cpu()[k], cmap = "rainbow", alpha = 0.5)
    ax3.imshow(image.cpu()[k], cmap = "gray")
    ax3.imshow(predicted.cpu()[k], cmap = "rainbow", alpha = 0.5)

    ax1.set_title("MRI")
    ax2.set_title("TRUE MASK")
    ax3.set_title("PREDICTED MASK")

    ax1.set_axis_off()
    ax2.set_axis_off()
    ax3.set_axis_off()