In [None]:
# This model was trained with Python 3.8.10.

In [None]:
# Importing the libraries: 
# The versions that were used for this model are annotated per library.

import glob
import os
import re

import numpy as np              # Version used: 1.23.1
import monai                    # Version used: 1.3.0
from monai.transforms import * 
import SimpleITK as sitk        # Version used: 2.3.1
import torch                    # Version used: 1.13.1+cu116
from tqdm import tqdm
import wandb                    # Version used: 0.17.0

## Loading the data

In [None]:
# A function 'build_dicts' is made to collect the scan / ground-truth file pairs of the dataset.
def build_dicts(data_path='database', mode="training"):
    """Collect (scan, ground truth, disease) records for one split of the dataset.

    Parameters
    ----------
    data_path : str
        Root folder that contains the ``training`` / ``testing`` sub-folders.
    mode : str
        Either ``"training"`` or ``"testing"``.

    Returns
    -------
    list of dict
        One dict per frame scan (``patient*_frameYY.nii.gz``) that has a matching
        ``*_gt.nii.gz`` file, with keys ``"scan_file"``, ``"gt_file"`` and ``"class"``
        (the value on the third line of the patient's ``Info.cfg``, without the
        trailing newline). Sorted by scan path.

    Raises
    ------
    ValueError
        If ``mode`` is not ``"training"`` or ``"testing"``.
    """
    # Check if the mode is training or testing:
    if mode not in ["training", "testing"]:
        raise ValueError(f"Please choose a mode in ['training', 'testing']. Current mode is {mode}.")

    # Frame scans are named 'patientXXX_frameYY.nii.gz'; the '_gt' files do not match this pattern.
    # sorted(): glob order depends on the file system, which would make any seeded split irreproducible.
    paths_scan = sorted(glob.glob(os.path.join(data_path, mode, 'patient*', 'patient*_frame[0-9][0-9].nii.gz')))

    suffix = ".nii.gz"
    dicts = []
    for scan_file in paths_scan:
        # Insert '_gt' right before the '.nii.gz' suffix (guaranteed by the glob pattern).
        # Only the suffix is used, so dots elsewhere in the path (e.g. './database') are harmless.
        gt_file = scan_file[:-len(suffix)] + "_gt" + suffix
        if not os.path.exists(gt_file):
            continue

        # Info.cfg lives next to the scan; os.path.dirname keeps absolute paths absolute.
        cfg_file = os.path.join(os.path.dirname(scan_file), "Info.cfg")

        # The third line of Info.cfg holds the disease as 'Key: value'; strip the trailing newline.
        with open(cfg_file, "r") as f:
            line = f.readlines()[2]
        disease = line.split(": ")[1].strip()

        # Make a dictionary with the scan file path, the ground truth file path, and the disease.
        dicts.append({"scan_file": scan_file,
                      "gt_file": gt_file,
                      "class": disease})
    return dicts

In [None]:
# A function 'split_train_val' is made to split the training records into a training and a validation set on patient level.
def split_train_val(train_list, ratio, seed=None):
    """Split scan records into training and validation sets without splitting a patient.

    Parameters
    ----------
    train_list : list of dict
        Records as returned by ``build_dicts``; each needs a ``"scan_file"`` whose
        file name starts with ``patient<number>``.
    ratio : float
        Target fraction of the scans that goes to the validation set. The number of
        validation patients is ``int(len(train_list) * ratio) // 2`` (two scans,
        diastole and systole, per patient), e.g. 200 scans with ratio 0.1 -> 10 patients.
    seed : int or None, optional
        Seed for a private ``np.random.default_rng``. With ``None`` (the default) the
        global NumPy random state is used, so ``np.random.seed`` still controls the split.

    Returns
    -------
    tuple of (list of dict, list of dict)
        ``(train_dicts, val_dicts)``; both keep the input order.

    Raises
    ------
    ValueError
        If a ``scan_file`` name does not start with ``patient<number>``.
    """
    # The patient number is parsed from the file name itself instead of a fixed character
    # slice, so 3-digit ids (patient100, testing patients 101+) are not truncated.
    def patient_id(sample):
        match = re.match(r"patient(\d+)", os.path.basename(sample['scan_file']))
        if match is None:
            raise ValueError(f"Cannot find a patient number in {sample['scan_file']!r}.")
        return int(match.group(1))

    patient_ids = sorted({patient_id(sample) for sample in train_list})
    # Never ask for more validation patients than exist (choice without replacement would fail).
    n_val = min(int(len(train_list) * ratio) // 2, len(patient_ids))

    # Draw from the patient ids that actually occur, not from range(len(train_list) // 2).
    rng = np.random if seed is None else np.random.default_rng(seed)
    val_patients = set(rng.choice(patient_ids, n_val, replace=False).tolist()) if n_val > 0 else set()

    train_dicts = []
    val_dicts = []
    for sample in train_list:
        if patient_id(sample) in val_patients:
            val_dicts.append(sample)
        else:
            train_dicts.append(sample)

    return train_dicts, val_dicts

In [None]:
# The transform 'ReadFiles' reads the scan and its ground-truth mask of one record using SimpleITK.
class ReadFiles(monai.transforms.Transform):
    """Load the scan and mask referenced by one record from ``build_dicts``.

    Input: a dict with ``"scan_file"``, ``"gt_file"`` and ``"class"``.
    Output: a new dict holding the SimpleITK images (``"img"`` and ``"mask"``),
    the scan's size and voxel spacing, the disease label and the scan path.
    """

    def __call__(self, sample):
        scan_path = sample["scan_file"]
        scan = sitk.ReadImage(scan_path)            # the scan without segmentation
        label = sitk.ReadImage(sample["gt_file"])   # its ground-truth segmentation

        record = dict(img=scan, mask=label)
        record["img_size"] = scan.GetSize()         # voxel grid size of the scan
        record["img_spacing"] = scan.GetSpacing()   # voxel spacing of the scan
        record["class"] = sample["class"]           # disease label of the patient
        record["scan_file"] = scan_path             # kept for tracing predictions back to files
        return record

In [None]:
# A function 'resample_img' is made to resample a SimpleITK image to a new voxel spacing.
def resample_img(itk_image, out_spacing, is_label):
    """Resample ``itk_image`` onto a grid with voxel spacing ``out_spacing``.

    The output keeps the input's origin and direction; its size per axis is
    ``round(size * spacing / new_spacing)``, so the covered extent stays about the
    same. Works for images of any dimension.

    Parameters
    ----------
    itk_image : SimpleITK.Image
        Image to resample.
    out_spacing : sequence of float
        Target spacing, one entry per image axis, in SimpleITK (x, y, z) order.
    is_label : bool
        ``True`` for segmentation masks (nearest-neighbour interpolation keeps the
        label values discrete), ``False`` for intensity images (B-spline).

    Returns
    -------
    SimpleITK.Image
        The resampled image.

    Raises
    ------
    ValueError
        If ``out_spacing`` does not have one entry per image dimension.
    """
    original_spacing = itk_image.GetSpacing()
    original_size = itk_image.GetSize()
    out_spacing = [float(s) for s in out_spacing]
    if len(out_spacing) != len(original_size):
        raise ValueError(f"out_spacing has {len(out_spacing)} entries, "
                         f"but the image has {len(original_size)} dimensions.")

    # Calculate the output size after the image has been respaced, for every axis.
    out_size = [int(np.round(size * (spacing / new_spacing)))
                for size, spacing, new_spacing in zip(original_size, original_spacing, out_spacing)]

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size)
    resample.SetOutputDirection(itk_image.GetDirection())
    resample.SetOutputOrigin(itk_image.GetOrigin())
    # Nearest neighbour for labels so no new (interpolated) label values appear.
    resample.SetInterpolator(sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline)

    return resample.Execute(itk_image)

In [None]:
# The transform 'EqualizeSpacing' resamples the image and the mask to one common voxel spacing.
class EqualizeSpacing(monai.transforms.Transform):
    """Resample ``sample['img']`` and ``sample['mask']`` to the same voxel spacing.

    Expects the dict produced by ``ReadFiles`` (SimpleITK images plus
    ``'img_size'`` / ``'img_spacing'``). The original size and spacing are stored
    under ``'org_size'`` / ``'org_spacing'`` so predictions can be resampled back
    after inference; ``'img_size'`` / ``'img_spacing'`` are updated to the new grid.
    The dict is modified in place and returned.

    Parameters
    ----------
    spacing : sequence of float, optional
        Target spacing per axis in SimpleITK (x, y, z) order. Defaults to 1.25
        along every axis, the value this notebook was trained with.
    """

    def __init__(self, spacing=(1.25, 1.25, 1.25)):
        self.spacing = [float(s) for s in spacing]

    def __call__(self, sample):
        # Keep the original geometry so the prediction can be transformed back after inference.
        sample['org_size'] = sample['img_size']
        sample['org_spacing'] = sample['img_spacing']

        # Intensity image (is_label=False -> B-spline interpolation).
        image = resample_img(sample['img'], self.spacing, False)
        sample['img'] = image
        sample['img_size'] = image.GetSize()
        sample['img_spacing'] = image.GetSpacing()

        # Mask (is_label=True -> nearest-neighbour interpolation keeps the labels discrete).
        sample['mask'] = resample_img(sample['mask'], self.spacing, True)

        return sample

In [None]:
# The transform 'LoadData' converts the SimpleITK image and mask into NumPy arrays.
class LoadData(monai.transforms.Transform):
    """Replace the SimpleITK objects under ``'img'`` and ``'mask'`` by NumPy arrays.

    ``sitk.GetArrayFromImage`` returns the axes in reverse SimpleITK order,
    i.e. (z, y, x) for a 3D volume. The dict is modified in place and returned.
    """

    def __call__(self, sample):
        for key in ('img', 'mask'):
            sample[key] = sitk.GetArrayFromImage(sample[key])
        return sample

In [None]:
# Seed Python, NumPy and PyTorch so the patient split, the augmentations and the weight
# initialisation are reproducible. MONAI Compose chains created after this call take
# their seed from here as well. Note: this also puts cuDNN in deterministic mode,
# which can make training somewhat slower.
SEED = 0
monai.utils.set_determinism(seed=SEED)

data_path = 'database'
# Here the initial training set is computed.
train_dict_list = build_dicts(data_path, mode='training')

In [None]:
# Move 10% of the patients (ratio=0.1) into a validation set; the split is done per patient,
# so a patient's scans end up either all in training or all in validation.
# NOTE: this overwrites train_dict_list, so re-run the previous cell before re-running this one.
train_dict_list, val_dict_list = split_train_val(train_list=train_dict_list, ratio=0.1)

In [None]:
# Data augmentation is applied to the training set to regularize the model.
# Two separate composed transforms are made with MONAI: one for the training set and one for the validation set.

ROI_SIZE = [32, 176, 176]  # Crop size in NumPy (z, y, x) axis order, as produced by LoadData

# Per-transform seeding is not used here: Compose re-seeds its random members when it is
# constructed, so reproducibility comes from monai.utils.set_determinism instead.
train_transforms = monai.transforms.Compose([
    ReadFiles(),        # Read the scan and mask with SimpleITK
    EqualizeSpacing(),  # Resample both to a spacing of 1.25 along every axis
    LoadData(),         # Convert the SimpleITK objects to NumPy arrays
    monai.transforms.EnsureChannelFirstd(channel_dim="no_channel", keys=["img", "mask"]),
    monai.transforms.ScaleIntensityd(keys=["img"]),  # Rescale the intensities to [0, 1]
    # Add random Gaussian noise.
    monai.transforms.RandGaussianNoised(keys=['img'], prob=1, mean=0, std=0.075),
    # Random rotation between -0.4 and 0.4 radians to represent other orientations of the heart in a realistic range.
    # A single number r samples from (-r, r); a tuple (a, b) samples from [a, b], so (0.4, 0.4) would always rotate by 0.4.
    monai.transforms.RandRotated(keys=["img", "mask"], prob=1, range_x=0.4, mode=['bilinear', 'nearest']),
    # Random crop to a fixed size so the volumes can be batched.
    monai.transforms.RandSpatialCropd(keys=['img', 'mask'], roi_size=ROI_SIZE, random_size=False),
])

# Validation uses a deterministic center crop so the validation loss is comparable between epochs.
val_transforms = monai.transforms.Compose([
    ReadFiles(),        # Read the scan and mask with SimpleITK
    EqualizeSpacing(),  # Resample both to a spacing of 1.25 along every axis
    LoadData(),         # Convert the SimpleITK objects to NumPy arrays
    monai.transforms.EnsureChannelFirstd(channel_dim="no_channel", keys=["img", "mask"]),
    monai.transforms.ScaleIntensityd(keys=["img"]),  # Rescale the intensities to [0, 1]
    monai.transforms.CenterSpatialCropd(keys=['img', 'mask'], roi_size=ROI_SIZE),
])

In [None]:
# Wrap the record lists in CacheDatasets: the deterministic leading part of each transform
# chain is computed once and kept in memory; the random transforms run on every access.
train_dataset = monai.data.CacheDataset(
    data=train_dict_list,
    transform=train_transforms,
    cache_rate=1.0,  # cache every sample (MONAI's default, spelled out)
)
val_dataset = monai.data.CacheDataset(
    data=val_dict_list,
    transform=val_transforms,
    cache_rate=1.0,
)

## Defining training and validation loaders

In [None]:
# Defining the training and validation loaders.
BATCH_SIZE = 16

# shuffle=True: without it every epoch presents the training samples in the same fixed order.
train_loader = monai.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

# The validation set is only evaluated, so its order does not matter.
val_loader = monai.data.DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=torch.cuda.is_available(),
)

In [None]:
# Run on the GPU (current CUDA device) when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Defining the model

In [None]:
# A 3D U-Net is built with MONAI and moved to the selected device.
model = monai.networks.nets.UNet(
    spatial_dims=3,                    # Three spatial dimensions: the network processes whole 3D volumes
    in_channels=1,                     # One input channel: the scan intensities
    out_channels=4,                    # Background, right endocardium, left endocardium and left myocardium
    channels=(48, 96, 192, 384, 768),  # Number of feature maps at each resolution level
    strides=(2, 2, 2, 2),              # Down-sampling factor between consecutive levels
    num_res_units=2,                   # Residual units per block
    dropout=0.25,                      # Drop-out for regularization of the model
).to(device)

# Optimizer: Adam over all network weights.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)

# Loss: soft Dice on the softmax of the network output.
#  - to_onehot_y: the integer label masks are one-hot encoded to match the 4 output channels
#  - include_background=False: only the foreground classes contribute to the Dice
#  - batch=True: intersection and union are accumulated over the whole batch
loss_function = monai.losses.DiceLoss(
    include_background=False,
    to_onehot_y=True,
    softmax=True,
    batch=True,
)

## Training Loop

In [None]:
# Choosing amount of epochs
epochs = 150
MODEL_PATH = 'trainedUNet3D_noise04en0.1std.pt'            # weights after the last epoch
BEST_MODEL_PATH = 'trainedUNet3D_noise04en0.1std_best.pt'  # weights with the lowest validation loss


def run_epoch(loader, train):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    With ``train=True`` the model is put in training mode and the optimizer is
    stepped after every batch; with ``train=False`` the model is evaluated in
    eval mode without gradients. Returns ``nan`` when the loader yields no
    batches (e.g. an empty validation set) instead of dividing by zero.
    """
    model.train(train)  # model.train(False) is equivalent to model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.set_grad_enabled(train):
        for batch in loader:
            outputs = model(batch['img'].float().to(device))
            loss = loss_function(outputs, batch['mask'].to(device))
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')


# Initiate a training progress on weights and biases.
run = wandb.init(
    name='Unet_3d'
)

best_val_loss = float('inf')
try:
    for epoch in tqdm(range(epochs)):
        train_loss = run_epoch(train_loader, train=True)
        val_loss = run_epoch(val_loader, train=False)

        # Keep the weights with the lowest validation loss (a nan loss never compares smaller).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), BEST_MODEL_PATH)

        wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})
finally:
    # Close the W&B run even if training is interrupted or raises.
    run.finish()

torch.save(model.state_dict(), MODEL_PATH)