In [6]:
from google.colab import drive

drive.mount("/content/gdrive")

Mounted at /content/gdrive


# 3D-UNet reconstruction model

In [2]:
# Import libraries
import tarfile
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import time
import os
import random
import matplotlib.pyplot as plt
from matplotlib import colors
import imageio.v2 as imageio

import nibabel as nib
from tqdm import tqdm
import glob
import shutil
import keras


In [None]:
# # Download the dataset
# !wget https://www.dropbox.com/s/zmytk2yu284af6t/Task01_BrainTumour_2D.tar.gz

# # Unzip the '.tar.gz' file to the current directory
# datafile = tarfile.open('Task01_BrainTumour_2D.tar.gz')
# datafile.extractall()
# datafile.close()

In [None]:
nii_file = '/Users/zifengwang/Desktop/Imperial_Computing/Indiv_Research/Code/Zhuang_dataset/labelled_meshes/ct_train_1010_label.nii.gz'
nii_image = nib.load(nii_file)
x, y, z = nii_image.shape

# Extract the image data as a NumPy array
data = nii_image.get_fdata()

# plt.imshow(img.get_fdata()[:, :, z//2], cmap='gray')

lv_value = 500
rv_value = 600
la_value = 420

lv_mask = (data == lv_value)
lv_center_of_mass = np.mean(np.nonzero(lv_mask), axis=1)

rv_mask = (data == rv_value)
rv_center_of_mass = np.mean(np.nonzero(rv_mask), axis=1)

la_mask = (data == la_value)
la_center_of_mass = np.mean(np.nonzero(la_mask), axis=1)


# Extract the 2-chamber and 4-chamber views
# Set  all non-2chamber and non-4-chamber views pixels as 0(background value)
data[0:int(lv_center_of_mass[0]), 0:int(lv_center_of_mass[1]), :] = 0
data[(int(lv_center_of_mass[0]) + 1) : x, 0: int(lv_center_of_mass[1]), :] = 0
data[0:int(lv_center_of_mass[0]), (int(lv_center_of_mass[1]) + 1) : y, :] = 0
data[(int(lv_center_of_mass[0]) + 1) : x, (int(lv_center_of_mass[1]) + 1) : y, : ] = 0


# Save the sliced long axis views as a NIfTI file
sliced_img = nib.Nifti1Image(data, img.affine)
output_file = 'demo/ct_train_1010_4ch.nii.gz'
nib.save(sliced_img, output_file)

# vector1 = rv_center_of_mass - lv_center_of_mass
# vector2 = lv_center_of_mass - la_center_of_mass
# normal_vector = np.cross(vector1, vector2)

# normalized_normal_vector = normal_vector / np.linalg.norm(normal_vector)
# D = -np.dot(normalized_normal_vector, lv_center_of_mass)

# x_range = np.arange(0, 256, 1)
# y_range = np.arange(0, 256, 1)
# z_range = np.arange(0, 256, 1)
# X, Y, Z = np.meshgrid(x_range, y_range, z_range)

# Z_plane = (-normalized_normal_vector[0] * X - normalized_normal_vector[1] * Y - D) / normalized_normal_vector[2]

# # Create a new NIfTI image using Z_plane as data
# new_nii = nib.Nifti1Image(Z_plane, affine=None)

# # Set the header information (optional)
# new_nii.header['pixdim'] = nii_image.header['pixdim']
# new_nii.header['qform_code'] = nii_image.header['qform_code']
# new_nii.header['sform_code'] = nii_image.header['sform_code']
# new_nii.header['quatern_b'] = nii_image.header['quatern_b']
# new_nii.header['quatern_c'] = nii_image.header['quatern_c']
# new_nii.header['quatern_d'] = nii_image.header['quatern_d']
# new_nii.header['qoffset_x'] = nii_image.header['qoffset_x']
# new_nii.header['qoffset_y'] = nii_image.header['qoffset_y']
# new_nii.header['qoffset_z'] = nii_image.header['qoffset_z']
# new_nii.header['srow_x'] = nii_image.header['srow_x']
# new_nii.header['srow_y'] = nii_image.header['srow_y']
# new_nii.header['srow_z'] = nii_image.header['srow_z']

# # Save the NIfTI image to a new file
# output_file = 'demo/ct_train_1010_4ch.nii.gz'
# nib.save(new_nii, output_file)

# Z_plane = Z_plane.squeeze()
# plt.imshow(Z_plane, cmap='gray', origin='lower')
# plt.xlabel('X')
# plt.ylabel('Y')
# plt.title('Slice Plane')
# plt.colorbar()
# plt.show()

# lowest_z_index = np.argmin(np.nonzero(lv_mask), axis=None)
# lv_apex = np.unravel_index(lowest_z_index, lv_mask.shape)
# # print("Center of mass:", center_of_mass)
# print("Apex:", lv_apex)
# lv_mask.shape


In [None]:
input_dir = 'Zhuang_dataset/labelled_meshes'
output_dir = 'sliced_long_axis_views'

os.makedirs(output_dir, exist_ok=True)

nii_files = glob.glob(os.path.join(input_dir, "*.nii.gz"))

for nii_file in nii_files:
    # Load the NIfTI file
    img = nib.load(nii_file)
    x, y, z = img.shape
    affine = img.affine

    data = img.get_fdata()

    # Extract the 2-chamber and 4-chamber views
    four_chamber_view = data[ :, y//2, :]
    two_chamber_view = data[x//2, :, :]

    # Update the affine matrix
    new_affine = np.copy(affine)
    new_affine[0, 0] //= img.header.get_zooms()[0]
    new_affine[1, 1] //= img.header.get_zooms()[1]
    new_affine[2, 2] //= img.header.get_zooms()[2]

    # Create the output filenames
    file_name = os.path.basename(nii_file)
    two_chamber_filename = os.path.join(output_dir, "2ch_" + file_name)
    four_chamber_filename = os.path.join(output_dir, "4ch_" + file_name)

    # Save the 2-chamber view as a NIfTI file
    two_chamber_img = nib.Nifti1Image(two_chamber_view, new_affine)
    nib.save(two_chamber_img, two_chamber_filename)

    # Save the 4-chamber view as a NIfTI file
    four_chamber_img = nib.Nifti1Image(four_chamber_view, new_affine)
    nib.save(four_chamber_img, four_chamber_filename)



In [None]:
input_dir_zhuang = 'Zhuang_dataset/labelled_meshes'
input_dir_bai = 'Bai_dataset/pilot_project/data/14*'
output_dir = '3D_long_axis_views_bai'

os.makedirs(output_dir, exist_ok=True)

nii_files = glob.glob(os.path.join(input_dir_bai, "segmentation_*.nii.gz"))

for nii_file in nii_files:
    # Load the NIfTI file
    img = nib.load(nii_file)
    x, y, z, c = img.shape

    data = img.get_fdata()

    lv_value = 1
    rv_value = 4

    lv_mask = (data == lv_value)
    lv_centre = np.mean(np.nonzero(lv_mask), axis=1)

    # Extract the 2-chamber and 4-chamber views
    # Set all non-2chamber and non-4-chamber views pixels as 0(background value)
    data[0:int(lv_centre[0]), 0:int(lv_centre[1]), :] = 0
    data[(int(lv_centre[0]) + 1) : x, 0: int(lv_centre[1]), :] = 0
    data[0:int(lv_centre[0]), (int(lv_centre[1]) + 1) : y, :] = 0
    data[(int(lv_centre[0]) + 1) : x, (int(lv_centre[1]) + 1) : y, : ] = 0

    #remove all cardiac structures other than left ventricle and left ventricle myocardium
    lv = [1,2]
    data = np.where(np.isin(data, lv), data, 0)

    # Create the output filenames
    sample_name = os.path.basename(os.path.dirname(nii_file))
    sample_phase = os.path.basename(nii_file)
    sliced_filename = os.path.join(output_dir, sample_name + "_" + sample_phase)

    # Save the sliced long axis views as a NIfTI file
    sliced_img = nib.Nifti1Image(data, img.affine)
    nib.save(sliced_img, sliced_filename)


# Move the nii.gz files from separate directories into the same directory

In [None]:
# import shutil
# #directory of 3D segmentation label files
# directory_path = 'Bai_dataset/pilot_project/data/14*'
# output_directory = '3D_label_views_bai'

# if not os.path.exists(output_directory):
#     os.makedirs(output_directory)

# # Iterate over the files in the directory
# for root, dirs, files in os.walk(directory_path):
#     for file_name in files:
#         if file_name.endswith('.nii.gz') and file_name.startswith('segmentation'):
#             img = nib.load(os.path.join(directory_path, file_name))

#             data = img.get_fdata()
#             lv = [1,2]
#             data = np.where(np.isin(data, lv), data, 0)

#             file_path = os.path.join(root, file_name)
#             new_file_name = os.path.basename(root) + '_' + file_name
#             new_file_path = os.path.join(output_directory, new_file_name)
#             # shutil.copyfile(file_path, new_file_path)

#             # Save the views as a NIfTI file
#             processed_img = nib.Nifti1Image(data, img.affine)
#             nib.save(processed_img, new_file_path)


In [None]:
input_dir_bai = 'Bai_dataset/pilot_project/data/14*'
output_dir = '3D_label_views_bai'

os.makedirs(output_dir, exist_ok=True)

nii_files = glob.glob(os.path.join(input_dir_bai, "segmentation_*.nii.gz"))

for nii_file in nii_files:
    # Load the NIfTI file
    img = nib.load(nii_file)
    x, y, z, c = img.shape

    data = img.get_fdata()

    #remove all cardiac structures other than left ventricle and left ventricle myocardium
    lv = [1,2]
    data = np.where(np.isin(data, lv), data, 0)

    # Create the output filenames
    sample_name = os.path.basename(os.path.dirname(nii_file))
    sample_phase = os.path.basename(nii_file)
    sliced_filename = os.path.join(output_dir, sample_name + "_" + sample_phase)

    # Save the sliced long axis views as a NIfTI file
    sliced_img = nib.Nifti1Image(data, img.affine)
    nib.save(sliced_img, sliced_filename)


# Apply padding to original samples to unify sample shapes

In [None]:
import math

input_dir_bai = '3D_label_views_bai'
output_dir = 'pad_3D_label_views_bai'

os.makedirs(output_dir, exist_ok=True)

nii_files = glob.glob(os.path.join(input_dir_bai, "*.nii.gz"))

max_x = max_y = max_z = 0

#identify the largest shape in the training samples
for nii_file in nii_files:
    # Load the NIfTI file
    img = nib.load(nii_file)
    x1, y1, z1, c1 = img.shape
    if x1 > max_x:
        max_x = x1
    if y1 > max_y:
        max_y = y1
    if z1 > max_z:
        max_z = z1

for nii_file in nii_files:
    img = nib.load(nii_file)
    x, y, z, c = img.shape
    data = img.get_fdata()

    pad_x = int((max_x - x)/2)
    pad_y = int((max_y - y)/2)
    pad_z_left = (max_z - z)//2
    pad_z_right = math.ceil((max_z - z) / 2)

    padded_data = np.pad(data, pad_width=((pad_x,pad_x), (pad_y, pad_y), (pad_z_left, pad_z_right), (0,0)))

#     print(padded_data.shape)

    # Create the output filenames
    sample_name = os.path.basename(nii_file)
    padded_filename = os.path.join(output_dir, "pad_" + sample_name)

    # Save the sliced long axis views as a NIfTI file
    padded_img = nib.Nifti1Image(padded_data, img.affine)
    nib.save(padded_img, padded_filename)

(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320, 60, 1)
(320, 320,

# Apply padding to long-axis views

In [None]:
import math

input_dir_bai = '3D_long_axis_views_bai'
output_dir = 'pad_3D_long_axis_views_bai'

os.makedirs(output_dir, exist_ok=True)

nii_files = glob.glob(os.path.join(input_dir_bai, "*.nii.gz"))

max_x = max_y = max_z = 0

#identify the largest shape in the training samples
for nii_file in nii_files:
    # Load the NIfTI file
    img = nib.load(nii_file)
    x1, y1, z1, c1 = img.shape
    if x1 > max_x:
        max_x = x1
    if y1 > max_y:
        max_y = y1
    if z1 > max_z:
        max_z = z1

#apply padding to the images and save the padded images to output_dir
for nii_file in nii_files:
    img = nib.load(nii_file)
    x, y, z, c = img.shape
    data = img.get_fdata()

    pad_x = int((max_x - x)/2)
    pad_y = int((max_y - y)/2)
    pad_z_left = (max_z - z)//2
    pad_z_right = math.ceil((max_z - z) / 2)

    padded_data = np.pad(data, pad_width=((pad_x,pad_x), (pad_y, pad_y), (pad_z_left, pad_z_right), (0,0)))

    # Create the output filenames
    sample_name = os.path.basename(nii_file)
    padded_filename = os.path.join(output_dir, "pad_" + sample_name)

    # Save the sliced long axis views as a NIfTI file
    padded_img = nib.Nifti1Image(padded_data, img.affine)
    nib.save(padded_img, padded_filename)

# Apply cropping to original images

In [None]:
#function that takes a numpy array and output a cropped array
def crop_center_of_mass(array, centre_label = 1):
    # Find the center of mass of the structure with label 1
    indices = np.argwhere(array == centre_label)
    center_of_mass = np.mean(indices, axis=0)

    # Calculate the crop boundaries
    x_start = int(center_of_mass[0] - 64)
    x_end = x_start + 128
    y_start = int(center_of_mass[1] - 64)
    y_end = y_start + 128

    # Crop the array around the center of mass
    cropped_array = array[x_start:x_end, y_start:y_end, :, :]

    return cropped_array


In [None]:
import math
input_dir_bai = '3D_long_axis_views_bai'
output_dir = 'crop_3D_long_axis_views_bai'

os.makedirs(output_dir, exist_ok=True)

nii_files = glob.glob(os.path.join(input_dir_bai, "*.nii.gz"))

#apply cropping to the images and save the cropped images to output_dir
for nii_file in nii_files:
    img = nib.load(nii_file)
    x, y, z, c = img.shape
    data = img.get_fdata()

    data = crop_center_of_mass(data, centre_label = 1)

    pad_z_left = (64 - z)//2
    pad_z_right = math.ceil((64 - z) / 2)
    data = np.pad(data, pad_width=((0,0), (0,0), (pad_z_left, pad_z_right), (0,0)))

    # Create the output filenames
    sample_name = os.path.basename(nii_file)
    cropped_filename = os.path.join(output_dir, "crop_" + sample_name)

    # Save the sliced long axis views as a NIfTI file
    cropped_img = nib.Nifti1Image(data, img.affine)
    nib.save(cropped_img, cropped_filename)


input_dir_bai = '3D_label_views_bai'
output_dir = 'crop_3D_label_views_bai'

os.makedirs(output_dir, exist_ok=True)

nii_files = glob.glob(os.path.join(input_dir_bai, "*.nii.gz"))

#apply cropping to the images and save the cropped images to output_dir
for nii_file in nii_files:
    img = nib.load(nii_file)
    x, y, z, c = img.shape
    data = img.get_fdata()

    data = crop_center_of_mass(data, centre_label = 1)

    pad_z_left = (64 - z)//2
    pad_z_right = math.ceil((64 - z) / 2)
    data = np.pad(data, pad_width=((0,0), (0,0), (pad_z_left, pad_z_right), (0,0)))

    # Create the output filenames
    sample_name = os.path.basename(nii_file)
    cropped_filename = os.path.join(output_dir, "crop_" + sample_name)

    # Save the sliced long axis views as a NIfTI file
    cropped_img = nib.Nifti1Image(data, img.affine)
    nib.save(cropped_img, cropped_filename)


In [None]:
output_dir = 'crop_3D_label_views_bai'
nii_files = glob.glob(os.path.join(output_dir, "*.nii.gz"))
for nii_file in nii_files:
    img = nib.load(nii_file)
    data = img.get_fdata()
    print(data.shape)

(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128, 60, 1)
(128, 128,

# Separate data into training set and test set

In [None]:
import os
import random
import shutil

# Set the paths to the original image and label folders
image_folder = 'crop_3D_long_axis_views_bai'
label_folder = 'crop_3D_label_views_bai'

# Set the paths to the train and test folders
train_image_folder = 'crop_preprocessed_bai/train_2D'
train_label_folder = 'crop_preprocessed_bai/train_3D'
test_image_folder = 'crop_preprocessed_bai/test_2D'
test_label_folder = 'crop_preprocessed_bai/test_3D'

# Create the train and test folders if they don't exist
os.makedirs(train_image_folder, exist_ok=True)
os.makedirs(train_label_folder, exist_ok=True)
os.makedirs(test_image_folder, exist_ok=True)
os.makedirs(test_label_folder, exist_ok=True)

# Get the list of image files in the image folder
image_files = [file for file in os.listdir(image_folder) if file.endswith('.nii.gz')]

# Shuffle the image files randomly
random.shuffle(image_files)

# Calculate the number of files for training and testing
train_ratio = 0.8
num_train = int(len(image_files) * train_ratio)
num_test = len(image_files) - num_train

# Split the image files into train and test sets
train_images = image_files[:num_train]
test_images = image_files[num_train:]

# Move the images and labels to the respective train and test folders
for image in train_images:
    src_image_path = os.path.join(image_folder, image)
    dst_image_path = os.path.join(train_image_folder, image)
    shutil.move(src_image_path, dst_image_path)

    # Move the corresponding label file
    label_file = os.path.basename(image)
    src_label_path = os.path.join(label_folder, label_file)
    dst_label_path = os.path.join(train_label_folder, label_file)
    shutil.move(src_label_path, dst_label_path)

for image in test_images:
    src_image_path = os.path.join(image_folder, image)
    dst_image_path = os.path.join(test_image_folder, image)
    shutil.move(src_image_path, dst_image_path)

    # Move the corresponding label file
    label_file = os.path.basename(image)
    src_label_path = os.path.join(label_folder, label_file)
    dst_label_path = os.path.join(test_label_folder, label_file)
    shutil.move(src_label_path, dst_label_path)


## 2. Implement a dataset class.

It can read the imaging dataset and get items, pairs of images and label maps, as training batches.

In [None]:
class CardiacImageSet(keras.utils.Sequence):
    """ Cardiac image set """
    def __init__(self, image_path, label_path='', deploy=False):
        self.image_path = image_path
        self.deploy = deploy
        self.images = []
        self.labels = []

        image_names = [file for file in os.listdir(image_path) if file.endswith('.nii.gz')]
        for image_name in image_names:
            # Read the image
            image = nib.load(os.path.join(image_path, image_name))
            image = image.get_fdata()
            #transpose image dimension from XYZC to CXYZ
            image = np.transpose(image, (3, 0, 1, 2))
            self.images += [image]

            # Read the label map
            if not self.deploy:
                label_name = os.path.join(label_path, image_name)
                label = nib.load(label_name)
                label = label.get_fdata()
                label = np.transpose(label, (3, 0, 1, 2))
                self.labels += [label]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Get an image and perform intensity normalisation
        # Dimension: XYZ
        # image = normalise_intensity(self.images[idx])
        image = self.images[idx]

        # Get its label map
        # Dimension: XYZ
        label = self.labels[idx]
        return image, label

    def get_random_batch(self, batch_size):
        # Get a batch of paired images and label maps
        # Dimension of images: NCXYZ
        # Dimension of labels: NXYZ
        images, labels = [], []

        ### Insert your code ###
        for i in range(batch_size):
            #randomly retrieve an image and label map
            random_idx = random.randint(0,self.__len__() - 1)
            random_image, random_label = self.__getitem__(random_idx)
            images += [random_image]
            labels += [random_label]

        #Turn the list into np array
        images = np.array(images)
        labels = np.array(labels)
        ### End of your code ###
        return images, labels

    def get_batch(self, batch_size, iteration_num):
      images, labels = [], []
      batch_num = self.__len__()//batch_size
      image_idx = ((iteration_num % batch_num) - 1) * batch_size
      for i in range(batch_size):
        image, label = self.__getitem__(image_idx + i)
        images += [image]
        labels += [label]

      images = np.array(images)
      labels = np.array(labels)
      return images, labels

# train_set = CardiacImageSet('/content/gdrive/My Drive/crop_preprocessed_bai/train_2D', '/content/gdrive/My Drive/crop_preprocessed_bai/train_3D')
# test_set = CardiacImageSet('/content/gdrive/My Drive/crop_preprocessed_bai/test_2D', '/content/gdrive/My Drive/crop_preprocessed_bai/test_3D')

# image_1, label_1 = train_set.__getitem__(88)
# print(image_1.shape)
# print(label_1.shape)

# test_images, test_labels = train_set.get_random_batch(4)
# print(test_images.shape)
# print(test_labels.shape)
# print(test_images.dtype)


## Construct a U-net architecture.


In [4]:
class UNet3d(nn.Module):
    def contracting_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(out_channels),
        )
        return block

    def expansive_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.ConvTranspose3d(in_channels=mid_channel, out_channels=out_channels, kernel_size=3, stride=2,
                                     padding=1, output_padding=1)
        )
        return block

    def final_block(self, in_channels, mid_channel, out_channels, kernel_size=3):
        block = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=in_channels, out_channels=mid_channel, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=mid_channel, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(mid_channel),
            torch.nn.Conv3d(kernel_size=kernel_size, in_channels=mid_channel, out_channels=out_channels, padding=1),
            torch.nn.Sigmoid()
        )
        return block

    def __init__(self, in_channel, out_channel):
        super(UNet3d, self).__init__()
        # Encode
        self.conv_encode1 = self.contracting_block(in_channel, 16, 32)
        self.conv_maxpool1 = torch.nn.MaxPool3d(kernel_size=2)
        self.conv_encode2 = self.contracting_block(32, 32, 64)
        self.conv_maxpool2 = torch.nn.MaxPool3d(kernel_size=2)
        self.conv_encode3 = self.contracting_block(64, 64, 128)
        self.conv_maxpool3 = torch.nn.MaxPool3d(kernel_size=2)
        # Bottleneck
        self.bottleneck = torch.nn.Sequential(
            torch.nn.Conv3d(kernel_size=3, in_channels=128, out_channels=128, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(128),
            torch.nn.Conv3d(kernel_size=3, in_channels=128, out_channels=256, padding=1),
            torch.nn.LeakyReLU(0.1),
            torch.nn.BatchNorm3d(256),
            torch.nn.ConvTranspose3d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1,
                                     output_padding=1)
        )
        # Decode
        self.conv_decode3 = self.expansive_block(128+256, 128, 128)
        self.conv_decode2 = self.expansive_block(64+128, 64, 64)
        self.final_layer = self.final_block(32+64, 32, out_channel)

    def crop_and_concat(self, upsampled, bypass, crop=False):
        if crop:
            c = (bypass.size()[2] - upsampled.size()[2]) // 2
            bypass = F.pad(bypass, (-c, -c, -c, -c))
        # print("unsampled shape:", upsampled.shape)
        # print("bypass shape:", bypass.shape)

        return torch.cat((upsampled, bypass), 1)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_maxpool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_maxpool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_maxpool3(encode_block3)
        # Bottleneck
        bottleneck1 = self.bottleneck(encode_pool3)
        # Decode
        decode_block3 = self.crop_and_concat(bottleneck1, encode_block3, crop=False)
        cat_layer2 = self.conv_decode3(decode_block3)
        decode_block2 = self.crop_and_concat(cat_layer2, encode_block2, crop=False)
        cat_layer1 = self.conv_decode2(decode_block2)
        decode_block1 = self.crop_and_concat(cat_layer1, encode_block1, crop=False)
        final_layer = self.final_layer(decode_block1)
        return final_layer


## 4. Train the segmentation model.

In [5]:
# CUDA device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: {0}'.format(device))

# Build the model
num_class = 3
model = UNet3d(in_channel=1, out_channel=num_class)
model = model.to(device)
params = list(model.parameters())

model_dir = '/content/gdrive/My Drive/saved_model'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Optimizer
optimizer = optim.Adam(params, lr=1e-3)

# Segmentation loss
criterion = nn.CrossEntropyLoss()

train_image_folder = '/content/gdrive/My Drive/crop_preprocessed_bai/train_2D'
train_label_folder = '/content/gdrive/My Drive/crop_preprocessed_bai/train_3D'
test_image_folder = '/content/gdrive/My Drive/crop_preprocessed_bai/test_2D'
test_label_folder = '/content/gdrive/My Drive/crop_preprocessed_bai/test_3D'

# Datasets
train_set = CardiacImageSet(train_image_folder, train_label_folder)
test_set = CardiacImageSet(test_image_folder, test_label_folder)

# Train the model

num_iter = 500
train_batch_size = 4
eval_batch_size = 4
start = time.time()
running_loss = 0
#number of batches in an epoch
num_batches = train_set.__len__()/train_batch_size
for it in range(1, 1 + num_iter):
    # Set the modules in training mode, which will have effects on certain modules, e.g. dropout or batchnorm.
    start_iter = time.time()
    model.train()

    # Get a batch of images and labels
    images, labels = train_set.get_batch(train_batch_size, it)
    images, labels = torch.from_numpy(images), torch.from_numpy(labels)
    # image.to() convert the array from system RAM to GPU RAM
    images, labels = images.to(device, dtype=torch.float32), labels.to(device, dtype=torch.long)
    #remove the channel dimension in the labels array
    labels = labels.squeeze(axis = 1)
    # print("Images shape:", images.shape)
    logits = model(images)

    # Perform optimisation and print out the training loss
    # print('logits shape:', logits.shape)
    # print('label shape:', labels.shape)

    loss = criterion(logits, labels)
    running_loss += loss

    if it % num_batches == 0:
        epoch_loss = running_loss/num_batches
        running_loss = 0
        print ("training loss for epoch {}:".format(it/num_batches))
        print(epoch_loss.item())

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    ###   ###

    # Evaluate
    if it % num_batches == 0:
        model.eval()
        # Disabling gradient calculation during reference to reduce memory consumption
        with torch.no_grad():
            # Evaluate on a batch of test images and print out the test loss
            ### Insert your code ###
            test_images, test_labels = test_set.get_random_batch(eval_batch_size)
            test_images, test_labels = torch.from_numpy(test_images), torch.from_numpy(test_labels)
            test_images, test_labels = test_images.to(device, dtype=torch.float32), test_labels.to(device, dtype=torch.long)
            test_labels = test_labels.squeeze(axis = 1)
            test_logits = model(test_images)
            test_loss = criterion(test_logits, test_labels)
            print ("test loss for iteration {}:".format(it/num_batches))
            print(test_loss.item())
            ### End of your code ###

    # Save the model
    if it % 100 == 0:
        torch.save(model.state_dict(), os.path.join(model_dir, 'model_{0}.pt'.format(it)))
print('Training took {:.3f}s in total.'.format(time.time() - start))

Device: cuda
training loss for epoch 1.0:
0.6309710741043091
test loss for iteration 1.0:
0.5744005441665649
training loss for epoch 2.0:
0.5774275064468384
test loss for iteration 2.0:
0.5718958377838135
training loss for epoch 3.0:
0.573948085308075
test loss for iteration 3.0:
0.5708145499229431
training loss for epoch 4.0:
0.5708648562431335
test loss for iteration 4.0:
0.8764143586158752
training loss for epoch 5.0:
0.5693016052246094
test loss for iteration 5.0:
0.5715335607528687
training loss for epoch 6.0:
0.5685228109359741
test loss for iteration 6.0:
1.1640300750732422
Training took 1464.462s in total.


# Visualise model reconstruction output

In [18]:
device = torch.device("cpu")
def model_load():

    unet = UNet3d(in_channel=1, out_channel=3)
    unet.to(device, dtype=torch.float)

    model_list = ['/content/gdrive/My Drive/saved_model/model_500.pt']
    model_i = model_list[0]
    checkpoint = torch.load(model_i)
    unet.load_state_dict(checkpoint)
    return unet

def mr_lax_inference(unet, slice_img):

    img = nib.load(slice_img)
    affine = img.affine
    data = img.get_fdata()

    data = torch.from_numpy(data)

    print(data.size())

    data = np.transpose(data, (3, 0, 1, 2))
    data = data.to(device, dtype=torch.float32)
    data = np.expand_dims(data, axis=(0,1))
    unet.eval()
    output = unet(data)
    pred = output.detach().cpu().numpy()

    return pred

sliced_image = '/content/gdrive/My Drive/crop_preprocessed_bai/test_2D/crop_14AB01345_segmentation_ES.nii.gz'
unet = model_load()
pred = mr_lax_inference(unet, sliced_image)
pred_save_dir = '/content/gdrive/My Drive/pred_output/pred.nii.gz'
nib.save(pred, pred_save_dir)

torch.Size([128, 128, 64, 1])


TypeError: ignored

*italicised text*## 5. Deploy the trained model to a random set of 4 test images and visualise the automated segmentation.

You can show the images as a 4 x 3 panel. Each row shows one example, with the 3 columns being the test image, automated segmentation and ground truth segmentation.

In [None]:
### Insert your code ###
image_test_1 = imageio.imread('Task01_BrainTumour_2D/test_images/BRATS_004_z62.png')
image_test_2 = imageio.imread('Task01_BrainTumour_2D/test_images/BRATS_016_z62.png')
image_test_3 = imageio.imread('Task01_BrainTumour_2D/test_images/BRATS_058_z93.png')
image_test_4 = imageio.imread('Task01_BrainTumour_2D/test_images/BRATS_115_z62.png')

label_test_1 = imageio.imread('Task01_BrainTumour_2D/test_labels/BRATS_004_z62.png')
label_test_2 = imageio.imread('Task01_BrainTumour_2D/test_labels/BRATS_016_z62.png')
label_test_3 = imageio.imread('Task01_BrainTumour_2D/test_labels/BRATS_058_z93.png')
label_test_4 = imageio.imread('Task01_BrainTumour_2D/test_labels/BRATS_115_z62.png')

image_test_1_3D, image_test_2_3D, image_test_3_3D, image_test_4_3D = np.expand_dims(image_test_1, axis=(0,1)), np.expand_dims(image_test_2, axis=(0,1)), np.expand_dims(image_test_3, axis=(0,1)), np.expand_dims(image_test_4, axis=(0,1))
image_test_1_tor, image_test_2_tor, image_test_3_tor, image_test_4_tor = torch.from_numpy(image_test_1_3D), torch.from_numpy(image_test_2_3D), torch.from_numpy(image_test_3_3D), torch.from_numpy(image_test_4_3D)
image_test_1_tor, image_test_2_tor, image_test_3_tor, image_test_4_tor = image_test_1_tor.to(device, dtype=torch.float32), image_test_2_tor.to(device, dtype=torch.float32), image_test_3_tor.to(device, dtype=torch.float32), image_test_4_tor.to(device, dtype=torch.float32)

pred_test_1 = np.argmax((model(image_test_1_tor).detach().reshape((4,120,120))), 0)
pred_test_2 = np.argmax((model(image_test_2_tor).detach().reshape((4,120,120))), 0)
pred_test_3 = np.argmax((model(image_test_3_tor).detach().reshape((4,120,120))), 0)
pred_test_4 = np.argmax((model(image_test_4_tor).detach().reshape((4,120,120))), 0)

fig = plt.figure(figsize=(50, 50))

fig.add_subplot(4, 3, 1)
plt.imshow(image_test_1, cmap='gray')
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 2)
plt.imshow(pred_test_1, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 3)
plt.imshow(label_test_1, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)


fig.add_subplot(4, 3, 4)
plt.imshow(image_test_2, cmap='gray')
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 5)
plt.imshow(pred_test_2, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 6)
plt.imshow(label_test_2, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)


fig.add_subplot(4, 3, 7)
plt.imshow(image_test_3, cmap='gray')
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 8)
plt.imshow(pred_test_3, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 9)
plt.imshow(label_test_3, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)


fig.add_subplot(4, 3, 10)
plt.imshow(image_test_4, cmap='gray')
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 11)
plt.imshow(pred_test_4, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)

fig.add_subplot(4, 3, 12)
plt.imshow(label_test_4, cmap = colors.ListedColormap(['black', 'green', 'blue', 'red']))
plt.axis('off')
plt.gcf().set_size_inches(10, 10)


### End of your code ###