## CNN application in 3D medical imaging

This notebook contains the code for applying CNNs to 3D medical imaging, e.g. predicting a patient's age from a brain MRI scan. It compares the direct CNN approach with the two-stage approach (segmentation followed by supervised learning).

In [None]:
# Install the pinned SimpleITK version, then download and extract the brain-age dataset.
! pip install SimpleITK==1.2.4 

! wget https://www.doc.ic.ac.uk/~bglocker/teaching/notebooks/brainage-data.zip
! unzip brainage-data.zip

In [None]:
# Mount Google drive
# (the testing cells further down load the trained model from a path under drive/MyDrive)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# data directory
# Root of the extracted dataset. Paths below are built by plain string
# concatenation, so the trailing '/' is required.
data_dir = 'data/brain_age/'

In [None]:
# Read the meta data using pandas
import pandas as pd

# Subject table; later cells read its 'subject_id', 'age' and 'gender_text' columns
# (presumably one row per subject).
meta_data_all = pd.read_csv(data_dir + 'meta/meta_data_all.csv')
meta_data_all.head() # show the first five data entries

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns

meta_data = meta_data_all

# Number of subjects per gender
sns.catplot(x="gender_text", data=meta_data, kind="count")
plt.title('Gender distribution')
plt.xlabel('Gender')
plt.show()

# Age histogram with 10-year bins between 10 and 90
# NOTE(review): sns.distplot is deprecated in recent seaborn releases;
# sns.histplot is the replacement if this warns or fails.
sns.distplot(meta_data['age'], bins=[10,20,30,40,50,60,70,80,90])
plt.title('Age distribution')
plt.xlabel('Age')
plt.show()

# Age of every subject, in table order
plt.scatter(range(len(meta_data['age'])),meta_data['age'], marker='.')
plt.grid()
plt.xlabel('Subject')
plt.ylabel('Age')
plt.show()

In [None]:
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed
from IPython.display import display

In [None]:
# Calculate parameters low and high from window and level
def wl_to_lh(window, level):
    """Convert a display window/level pair into (low, high) intensity limits.

    The returned interval is centred on ``level`` and is ``window`` wide.
    """
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, x=None, y=None, z=None, window=None, level=None, colormap='gray', crosshair=False):
    """Show three orthogonal slices (fixed z, fixed y, fixed x) of a 3D SimpleITK image.

    Args:
        img: 3D SimpleITK image. This notebook also passes the output of
            ``sitk.LabelToRGB``; for such RGB arrays imshow ignores ``clim``.
        x, y, z: voxel indices of the slices to show; each defaults to the
            middle index of its axis.
        window, level: display window/level; by default the window spans the
            full intensity range and the level is its centre.
        colormap: matplotlib colormap name.
        crosshair: if True, draw lines marking the selected x/y/z position.
    """
    # Convert SimpleITK image to NumPy array
    # (NumPy indexing is [z, y, x], the reverse of SimpleITK's (x, y, z) size order)
    img_array = sitk.GetArrayFromImage(img)
    
    # Get image dimensions in millimetres
    size = img.GetSize()
    spacing = img.GetSpacing()
    width  = size[0] * spacing[0]
    height = size[1] * spacing[1]
    depth  = size[2] * spacing[2]
    
    if x is None:
        x = np.floor(size[0]/2).astype(int)
    if y is None:
        y = np.floor(size[1]/2).astype(int)
    if z is None:
        z = np.floor(size[2]/2).astype(int)
    
    if window is None:
        window = np.max(img_array) - np.min(img_array)
    
    if level is None:
        level = window / 2 + np.min(img_array)
    
    low,high = wl_to_lh(window,level)

    # Display the orthogonal slices
    # extent is given in physical units so that anisotropic voxels keep their true aspect ratio
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))

    ax1.imshow(img_array[z,:,:], cmap=colormap, clim=(low, high), extent=(0, width, height, 0))
    ax2.imshow(img_array[:,y,:], origin='lower', cmap=colormap, clim=(low, high), extent=(0, width,  0, depth))
    ax3.imshow(img_array[:,:,x], origin='lower', cmap=colormap, clim=(low, high), extent=(0, height, 0, depth))

    # Additionally display crosshairs
    # (voxel indices are converted to the same physical units as the extents)
    if crosshair:
        ax1.axhline(y * spacing[1], lw=1)
        ax1.axvline(x * spacing[0], lw=1)
        ax2.axhline(z * spacing[2], lw=1)
        ax2.axvline(x * spacing[0], lw=1)
        ax3.axhline(z * spacing[2], lw=1)
        ax3.axvline(y * spacing[1], lw=1)

    plt.show()
    
def interactive_view(img):
    """Show ``display_image`` with ipywidgets sliders for slice position and window/level.

    Slider ranges: x/y/z over all valid voxel indices, the window from 0 to the
    full intensity range, and the level between the minimum and maximum intensity.
    """
    size = img.GetSize()
    nx, ny, nz = size[0], size[1], size[2]
    voxels = sitk.GetArrayFromImage(img)
    lo = np.min(voxels)
    hi = np.max(voxels)
    interact(display_image,
             img=fixed(img),
             x=(0, nx - 1),
             y=(0, ny - 1),
             z=(0, nz - 1),
             window=(0, hi - lo),
             level=(lo, hi))

In [None]:
# Subject with index 0
ID = meta_data['subject_id'][0]
age = meta_data['age'][0]

# Image
image_filename = data_dir + 'images/sub-' + ID + '_T1w_unbiased.nii.gz'
img = sitk.ReadImage(image_filename)
print('Image size and spacing')
print(img.GetSize())
print(img.GetSpacing())

# Mask
mask_filename = data_dir + 'masks/sub-' + ID + '_T1w_brain_mask.nii.gz'
msk = sitk.ReadImage(mask_filename)

print('Imaging data of subject ' + ID + ' with age ' + str(age))

print('\nMR Image')
# window=400, level=200 -> intensities in [0, 400] are shown
display_image(img, window=400, level=200)

print('Brain mask')
display_image(msk)

In [None]:
# Setting up for brain tissue segmentation
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
def zero_mean_unit_var(image, mask):
    """Normalizes an image to zero mean and unit variance.

    Mean and standard deviation are taken over the voxels where ``mask`` is
    non-zero. If that standard deviation is positive, the image is
    standardised and every voxel where the mask equals 0 is set to 0.
    Otherwise the intensities are returned unchanged (cast to float32).
    The result keeps the geometry of ``image`` (via CopyInformation).
    """
    voxels = sitk.GetArrayFromImage(image).astype(np.float32)
    labels = sitk.GetArrayFromImage(mask)

    foreground = voxels[labels > 0]
    mu = np.mean(foreground)
    sigma = np.std(foreground)

    if sigma > 0:
        voxels = (voxels - mu) / sigma
        voxels[labels == 0] = 0

    normalised = sitk.GetImageFromArray(voxels)
    normalised.CopyInformation(image)
    return normalised


def resample_image(image, out_spacing=(1.0, 1.0, 1.0), out_size=None, is_label=False, pad_value=0):
    """Resamples an image to given element spacing and output size.

    The output grid keeps the input's direction matrix. Its origin is chosen
    so that the physical centre of the output grid coincides with the
    physical centre of the input image.

    Args:
        image: SimpleITK image to resample.
        out_spacing: output voxel spacing, one value per dimension.
        out_size: output size in voxels. If None, it is chosen (rounded) so the
            output covers the same physical extent as the input.
        is_label: if True use nearest-neighbour interpolation (label maps,
            masks), otherwise B-spline interpolation.
        pad_value: value given to output voxels that fall outside the input.

    Returns:
        The resampled SimpleITK image.
    """

    original_spacing = np.array(image.GetSpacing())
    original_size = np.array(image.GetSize())

    if out_size is None:
        out_size = np.round(np.array(original_size * original_spacing / np.array(out_spacing))).astype(int)
    else:
        out_size = np.array(out_size)

    # Grid centres as offsets from the origin (centre index times spacing),
    # then rotated into physical space with the image direction matrix.
    original_direction = np.array(image.GetDirection()).reshape(len(original_spacing),-1)
    original_center = (np.array(original_size, dtype=float) - 1.0) / 2.0 * original_spacing
    out_center = (np.array(out_size, dtype=float) - 1.0) / 2.0 * np.array(out_spacing)

    original_center = np.matmul(original_direction, original_center)
    out_center = np.matmul(original_direction, out_center)
    # Shift the origin so that both centres land on the same physical point
    out_origin = np.array(image.GetOrigin()) + (original_center - out_center)

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size.tolist())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(out_origin.tolist())
    resample.SetTransform(sitk.Transform())  # default-constructed (identity) transform
    resample.SetDefaultPixelValue(pad_value)

    # Nearest neighbour keeps label values discrete; B-spline for intensities
    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)

    return resample.Execute(image)


class ImageSegmentationDataset(Dataset):
    """Dataset for image segmentation.

    All volumes are read, normalised and resampled eagerly in ``__init__`` and
    kept in memory as SimpleITK images. ``__getitem__`` converts one sample to
    torch tensors with an added leading channel dimension.

    Args:
        file_list_img: paths of the intensity images.
        file_list_seg: paths of the reference segmentations, in the same order.
        file_list_msk: paths of the brain masks, in the same order.
        img_spacing: output spacing passed to ``resample_image``.
        img_size: output size passed to ``resample_image``.

    Raises:
        ValueError: if the three file lists do not have the same length.
    """

    def __init__(self, file_list_img, file_list_seg, file_list_msk, img_spacing, img_size):
        # The lists are paired by position. A length mismatch would otherwise
        # silently drop subjects or fail half-way through the (slow) loading.
        n_img, n_seg, n_msk = len(file_list_img), len(file_list_seg), len(file_list_msk)
        if not (n_img == n_seg == n_msk):
            raise ValueError(
                'File lists must have equal length, got {} images, {} segmentations '
                'and {} masks'.format(n_img, n_seg, n_msk))

        self.samples = []
        self.img_names = []
        self.seg_names = []
        triples = zip(file_list_img, file_list_seg, file_list_msk)
        for img_path, seg_path, msk_path in tqdm(triples, total=n_img, desc='Loading Data'):
            img = sitk.ReadImage(img_path, sitk.sitkFloat32)
            # int64 labels become torch.long tensors, as F.cross_entropy requires for targets
            seg = sitk.ReadImage(seg_path, sitk.sitkInt64)
            msk = sitk.ReadImage(msk_path, sitk.sitkUInt8)

            # Pre-processing: normalise intensities inside the brain mask, then
            # resample everything to the common grid (nearest neighbour for labels).
            img = zero_mean_unit_var(img, msk)
            img = resample_image(img, img_spacing, img_size, is_label=False)
            seg = resample_image(seg, img_spacing, img_size, is_label=True)
            msk = resample_image(msk, img_spacing, img_size, is_label=True)

            self.samples.append({'img': img, 'seg': seg, 'msk': msk})
            self.img_names.append(os.path.basename(img_path))
            self.seg_names.append(os.path.basename(seg_path))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        """Return sample ``item`` as a dict of tensors with a leading channel axis."""
        sample = self.samples[item]

        image = torch.from_numpy(sitk.GetArrayFromImage(sample['img'])).unsqueeze(0)
        seg = torch.from_numpy(sitk.GetArrayFromImage(sample['seg'])).unsqueeze(0)
        msk = torch.from_numpy(sitk.GetArrayFromImage(sample['msk'])).unsqueeze(0)

        return {'img': image, 'seg': seg, 'msk': msk}

    def get_sample(self, item):
        """Return the preprocessed SimpleITK images {'img', 'seg', 'msk'} of sample ``item``."""
        return self.samples[item]

    def get_img_name(self, item):
        """Return the file name of the image of sample ``item``."""
        return self.img_names[item]

    def get_seg_name(self, item):
        """Return the file name of the segmentation of sample ``item``."""
        return self.seg_names[item]


In [None]:
# Check GPU
cuda_dev = '0' #GPU device 0 (can be changed if multiple GPUs are available)

# Use "cuda:<cuda_dev>" when CUDA is available, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:" + cuda_dev if use_cuda else "cpu")

print('Device: ' + str(device))
if use_cuda:
    print('GPU: ' + str(torch.cuda.get_device_name(int(cuda_dev))))        

In [None]:
rnd_seed = 42 #fixed random seed

# Common grid every volume is resampled to (spacing in the images' physical
# units, presumably mm). The segmentation network needs each size divisible by 16.
img_size = [96, 96, 96]
img_spacing = [2, 2, 2]

num_epochs = 20
learning_rate = 0.001
batch_size = 2
val_interval = 10   # validation runs at epoch 1 and every val_interval epochs

num_classes = 4     # label values 0-3: background, CSF, GM, WM

out_dir = './output'

# Create output directory
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

In [None]:
# Data preprocessing
def _subject_paths(subject_ids, prefix, suffix):
    """Return data_dir + prefix + <subject id> + suffix for every subject id."""
    return [data_dir + prefix + sid + suffix for sid in subject_ids]


# Training split: image, reference segmentation and brain mask per subject
meta_data_seg_train = pd.read_csv(data_dir + 'meta/meta_data_segmentation_train.csv')
ids_seg_train = list(meta_data_seg_train['subject_id'])
files_seg_img_train = _subject_paths(ids_seg_train, 'images/sub-', '_T1w_unbiased.nii.gz')
files_seg_seg_train = _subject_paths(ids_seg_train, 'segs_refs/sub-', '_T1w_seg.nii.gz')
files_seg_msk_train = _subject_paths(ids_seg_train, 'masks/sub-', '_T1w_brain_mask.nii.gz')

# Validation split, same layout
meta_data_seg_val = pd.read_csv(data_dir + 'meta/meta_data_segmentation_val.csv')
ids_seg_val = list(meta_data_seg_val['subject_id'])
files_seg_img_val = _subject_paths(ids_seg_val, 'images/sub-', '_T1w_unbiased.nii.gz')
files_seg_seg_val = _subject_paths(ids_seg_val, 'segs_refs/sub-', '_T1w_seg.nii.gz')
files_seg_msk_val = _subject_paths(ids_seg_val, 'masks/sub-', '_T1w_brain_mask.nii.gz')

In [None]:
# Into DataLoaders

# LOAD TRAINING DATA
# (the dataset constructor reads, normalises and resamples every volume up front)
dataset_train = ImageSegmentationDataset(files_seg_img_train, files_seg_seg_train, files_seg_msk_train, img_spacing, img_size)
# FOR QUICK DEBUGGING, USE THE VALIDATION DATA FOR TRAINING
#dataset_train = ImageSegmentationDataset(files_seg_img_val, files_seg_seg_val, files_seg_msk_val, img_spacing, img_size)
# Shuffled mini-batches of `batch_size` volumes
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

# LOAD VALIDATION DATA
# One volume per batch, in a fixed order
dataset_val = ImageSegmentationDataset(files_seg_img_val, files_seg_seg_val, files_seg_msk_val, img_spacing, img_size)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False)

In [None]:
# Visualise training samples
sample = dataset_train.get_sample(0)
img_name = dataset_train.get_img_name(0)
seg_name = dataset_train.get_seg_name(0)
# Size and spacing after resampling (should match img_size / img_spacing)
print('sample size and spacing')
print(sample['img'].GetSize())
print(sample['img'].GetSpacing())
print()

# Debug check: RGB colour that LabelToRGB assigns at array index [32][32][32]
a = sitk.LabelToRGB(sample['seg'])
a = sitk.GetArrayFromImage(a)
print(a[32][32][32])

print('Image: ' + img_name)
# Intensities were normalised to zero mean / unit variance, hence the narrow window around 0
display_image(sample['img'], window=5, level=0)
print('Segmentation')
display_image(sitk.LabelToRGB(sample['seg']))
print('Mask')
display_image(sample['msk'])

In [None]:
# Build model

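# Input x: [batch_size, 1, D, H, W], here D = H = W = 96 (img_size). Each spatial size must be
# divisible by 16 so that the four stride-2 downsamplings line up with the skip connections.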

# pool of square window of size=2, stride=2
# m = nn.MaxPool3d(2, stride=2)
# input = torch.randn(20, 16, 50, 44, 31)
# output = m(input)
# print(output.shape)

# With square kernels and equal stride
# m = nn.ConvTranspose3d(16, 33, 2, stride=2)
# m = nn.ConvTranspose3d(16, 33, 3, stride=1)
# input = torch.randn(20, 16, 1, 2, 4)
# output = m(input)
# print(output.shape)

class conv(nn.Module):
    """Two 3x3x3 convolutions, each followed by BatchNorm3d, Dropout and LeakyReLU.

    The spatial size is preserved (padding=1) and channels go ``input`` -> ``output``.
    Dropout is 0.2 after the first convolution and 0.3 after the second.
    """

    def __init__(self, input, output):
        super().__init__()
        layers = []
        for in_channels, drop_prob in ((input, 0.2), (output, 0.3)):
            layers += [
                nn.Conv3d(in_channels, output, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm3d(output),
                nn.Dropout(drop_prob),
                nn.LeakyReLU(inplace=True),
            ]
        # Same attribute name, layer order and Sequential indices as saved
        # checkpoints expect (e.g. 'layer.0.weight', 'layer.4.weight').
        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        out = self.layer(x)
        return out

class down_samp(nn.Module):
    """Strided 3x3x3 convolution (stride 2) + BatchNorm3d + LeakyReLU.

    Each spatial size n becomes ceil(n / 2); the channel count ``c`` is unchanged.
    """

    def __init__(self, c):
        super().__init__()
        strided_conv = nn.Conv3d(c, c, kernel_size=3, stride=2, padding=1, bias=False)
        norm = nn.BatchNorm3d(c)
        activation = nn.LeakyReLU(inplace=True)
        self.layer = nn.Sequential(strided_conv, norm, activation)

    def forward(self, x):
        out = self.layer(x)
        return out

class up_samp(nn.Module):
    """Transposed 2x2x2 convolution (stride 2) + BatchNorm3d + LeakyReLU.

    Doubles every spatial size and halves the channels (``c`` -> ``c // 2``).
    """

    def __init__(self, c):
        super().__init__()
        half = c // 2
        upsample = nn.ConvTranspose3d(c, half, kernel_size=2, stride=2, bias=False)
        norm = nn.BatchNorm3d(half)
        activation = nn.LeakyReLU(inplace=True)
        self.layer = nn.Sequential(upsample, norm, activation)

    def forward(self, x):
        out = self.layer(x)
        return out

class SimpleNet3D(nn.Module):
    """Encoder-decoder segmentation network with skip connections (U-Net style).

    Encoder: four ``conv`` + ``down_samp`` stages (1 -> 64 -> 128 -> 256 -> 512
    channels) and a 1024-channel ``conv`` bottleneck. Decoder: four
    ``up_samp`` + ``conv`` stages, each concatenating the encoder features of
    the matching resolution. A 1x1x1 convolution maps the final 64 channels to
    ``num_classes`` raw logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Submodules are registered in the same order and under the same names
        # (down_conv_1, down_samp_1, ..., up_samp_1, up_conv_1, ...) so saved
        # state_dicts still load and seeded initialisation is unchanged.
        encoder_channels = [(1, 64), (64, 128), (128, 256), (256, 512)]
        for stage, (c_in, c_out) in enumerate(encoder_channels, start=1):
            setattr(self, 'down_conv_{}'.format(stage), conv(c_in, c_out))
            setattr(self, 'down_samp_{}'.format(stage), down_samp(c_out))

        self.bottom = conv(512, 1024)

        for stage, c in enumerate([1024, 512, 256, 128], start=1):
            setattr(self, 'up_samp_{}'.format(stage), up_samp(c))
            setattr(self, 'up_conv_{}'.format(stage), conv(c, c // 2))

        self.out = nn.Conv3d(64, num_classes, kernel_size=1, stride=1, bias=False)

    def forward(self, x):
        skips = []
        for stage in range(1, 5):
            x = getattr(self, 'down_conv_{}'.format(stage))(x)
            skips.append(x)
            x = getattr(self, 'down_samp_{}'.format(stage))(x)

        x = self.bottom(x)

        for stage in range(1, 5):
            x = getattr(self, 'up_samp_{}'.format(stage))(x)
            # encoder features first, then the upsampled ones (the order the
            # following conv's weights were trained with)
            x = torch.cat((skips.pop(), x), dim=1)
            x = getattr(self, 'up_conv_{}'.format(stage))(x)

        return self.out(x)  # raw logits; F.cross_entropy applies log-softmax itself

# torch.cuda.empty_cache()
# test = SimpleNet3D(num_classes=num_classes).to('cpu')
# for batch_idx, batch_samples in enumerate(dataloader_train):
#   x = batch_samples['img'].to('cpu')
#   break
# with torch.no_grad():
#   p = test(x)
# del p
# torch.cuda.empty_cache()

In [None]:
# Training 

model_dir = os.path.join(out_dir, 'model')
os.makedirs(model_dir, exist_ok=True)

torch.manual_seed(rnd_seed) #fix random seed

model = SimpleNet3D(num_classes=num_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

loss_train_log = []  # mean training loss of each epoch
loss_val_log = []    # mean per-voxel validation loss at each validation epoch
epoch_val_log = []   # epochs at which validation was run

print('START TRAINING...')
for epoch in range(1, num_epochs + 1):

    # Training: dropout active, batch-norm uses batch statistics. Set every
    # epoch because the validation step below switches the model to eval mode.
    model.train()
    loss_sum = 0.0
    num_samples = 0
    for batch_idx, batch_samples in enumerate(dataloader_train):
        img, seg = batch_samples['img'].to(device), batch_samples['seg'].to(device)
        optimizer.zero_grad()
        prd = model(img)
        loss = F.cross_entropy(prd, seg.squeeze(1))
        loss.backward()
        optimizer.step()
        # `loss` is the mean over the batch's voxels; weighting by batch size
        # keeps a smaller final batch from skewing the epoch mean.
        loss_sum += loss.item() * img.shape[0]
        num_samples += img.shape[0]

    loss_train = loss_sum / num_samples
    loss_train_log.append(loss_train)

    print('+ TRAINING \tEpoch: {} \tLoss: {:.6f}'.format(epoch, loss_train))
    
    # Validation
    if epoch == 1 or epoch % val_interval == 0:
        # Eval mode: no dropout, batch-norm uses its running statistics
        model.eval()
        loss_val = 0
        sum_pts = 0
        with torch.no_grad():
            for data_sample in dataloader_val:
                img, seg = data_sample['img'].to(device), data_sample['seg'].to(device)
                prd = model(img)
                loss_val += F.cross_entropy(prd, seg.squeeze(1), reduction='sum').item()
                # number of voxels scored in this batch (independent of batch size)
                sum_pts += seg.numel()

        # Show the prediction for the last validation sample
        prd = torch.argmax(prd, dim=1)
        prediction = sitk.GetImageFromArray(prd.cpu().squeeze().numpy().astype(np.uint8))

        loss_val /= sum_pts

        loss_val_log.append(loss_val)
        epoch_val_log.append(epoch)

        print('--------------------------------------------------')
        print('+ VALIDATE \tEpoch: {} \tLoss: {:.6f}'.format(epoch, loss_val))
        display_image(sitk.LabelToRGB(prediction))
        print('--------------------------------------------------')

torch.save(model.state_dict(), os.path.join(model_dir, 'model.pt'))

print('\nFinished TRAINING.')

plt.plot(range(1, num_epochs + 1), loss_train_log, c='r', label='train')
plt.plot(epoch_val_log, loss_val_log, c='b', label='val')
plt.legend(loc='upper right')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

In [None]:
# Load and preprocess test data
# The subjects used to test the segmentation network are the regression
# *training* subjects (meta_data_regression_train.csv). Their predicted
# segmentations are what the tissue-volume feature cells read back later.
meta_data_reg_train = pd.read_csv(data_dir + 'meta/meta_data_regression_train.csv')
ids_seg_test = list(meta_data_reg_train['subject_id'])
files_seg_img_test = [data_dir + 'images/sub-' + f + '_T1w_unbiased.nii.gz' for f in ids_seg_test]
files_seg_seg_test = [data_dir + 'segs_refs/sub-' + f + '_T1w_seg.nii.gz' for f in ids_seg_test]
files_seg_msk_test = [data_dir + 'masks/sub-' + f + '_T1w_brain_mask.nii.gz' for f in ids_seg_test]

dataset_test = ImageSegmentationDataset(files_seg_img_test, files_seg_seg_test, files_seg_msk_test, img_spacing, img_size)
# One volume per batch, in file order (the testing loop matches batches to dataset indices)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False)

In [None]:
# Visualise testing samples
sample = dataset_test.get_sample(0)
img_name = dataset_test.get_img_name(0)
seg_name = dataset_test.get_seg_name(0)
print('Image: ' + img_name)
# Normalised intensities, hence the narrow window around 0
display_image(sample['img'], window=5, level=0)
print('Segmentation')
display_image(sitk.LabelToRGB(sample['seg']))

print('Mask')
display_image(sample['msk'])

In [None]:
# Testing 
pred_dir = os.path.join(out_dir, 'pred')
os.makedirs(pred_dir, exist_ok=True)
from sklearn import metrics
import pandas as pd

model = SimpleNet3D(num_classes=num_classes)
# NOTE(review): weights are read from Google Drive, not from the ./output/model
# directory the training cell writes to; model.pt must have been copied there.
model_dir = 'drive/MyDrive/Colab Notebooks/ML for Img/'
# map_location lets a checkpoint saved on a GPU be loaded when only a CPU is available
model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt'), map_location=device))
model.to(device)
model.eval()
    
print('START TESTING...')

class_names = ['Background', 'CSF', 'GM', 'WM']
class_labels = list(range(num_classes))
num_test = len(dataset_test)

# Per-subject metrics: one row per test subject, one column per class (label order)
acc = np.zeros(num_test)
precision = np.zeros((num_test, num_classes))
recall = np.zeros((num_test, num_classes))
dice = np.zeros((num_test, num_classes))

# labels= keeps exactly one value per class even when a class is absent from
# both reference and prediction; zero_division=0 replaces undefined 0/0 by 0 (not NaN).
metric_kwargs = dict(labels=class_labels, average=None, zero_division=0)

loss_test = 0
sum_pts = 0
idx_test = 0
with torch.no_grad():
    for data_sample in dataloader_test:
        img, seg = data_sample['img'].to(device), data_sample['seg'].to(device)
        prd = model(img)
        loss_test += F.cross_entropy(prd, seg.squeeze(1), reduction='sum').item()
        sum_pts += seg.numel()
        
        prd = torch.argmax(prd, dim=1)

        # Reference and predicted label of every voxel as flat arrays
        t = seg.view(-1).cpu().numpy()
        p = prd.view(-1).cpu().numpy()

        acc[idx_test] = metrics.accuracy_score(t, p)
        precision[idx_test] = metrics.precision_score(t, p, **metric_kwargs)
        recall[idx_test] = metrics.recall_score(t, p, **metric_kwargs)
        # Per-class F1 = 2PR / (P + R), which equals the Dice coefficient
        dice[idx_test] = metrics.f1_score(t, p, **metric_kwargs)

        # Save the prediction with the geometry of the (resampled) reference segmentation
        sample = dataset_test.get_sample(idx_test)
        name = dataset_test.get_seg_name(idx_test)
        prediction = sitk.GetImageFromArray(prd.cpu().squeeze().numpy().astype(np.uint8))
        prediction.CopyInformation(sample['seg'])
        sitk.WriteImage(prediction, os.path.join(pred_dir, name))
        
        idx_test += 1
        if idx_test % 50 == 0: print('\t tested ' + str(idx_test))

boxplot_kwargs = dict(sym = 'o', vert = True, whis=1.5, patch_artist = True, meanline = False, showmeans = True, showbox = True, showfliers = True)

# show acc
fig1, ax1 = plt.subplots()
ax1.set_title('Acc Plot')
ax1.boxplot(acc, **boxplot_kwargs)
plt.show()

# show precision, recall and Dice/F1 per class
for values, title in ((precision, 'Precision Plot'), (recall, 'Recall Plot'), (dice, 'Dice/F1 Plot')):
    data = pd.DataFrame(values, columns=class_names)
    data.boxplot(**boxplot_kwargs)
    plt.title(title)
    plt.show()

# Mean of each metric over the test subjects, per class
print(pd.DataFrame({'precision': precision.mean(axis=0),
                    'recall': recall.mean(axis=0),
                    'Dice/F1': dice.mean(axis=0)}, index=class_names))

loss_test /= sum_pts

print('+ TESTING \tLoss: {:.6f}'.format(loss_test))

# Show last testing sample as an example
print('\n\nReference segmentation')
display_image(sitk.LabelToRGB(sample['seg']))
print('Predicted segmentation')
display_image(sitk.LabelToRGB(prediction))

print('\nFinished TESTING.')

In [None]:
# Calculating absolute tissue volume

import os

# USE THIS TO RUN THE CALCULATIONS ON YOUR SEGMENTATIONS
seg_dir = './output/pred/'

# USE THIS TO RUN THE CALCULATIONS ON OUR REFERENCE SEGMENTATIONS
# seg_dir = data_dir + 'segs_refs/'

meta_data_reg_train = pd.read_csv(data_dir + 'meta/meta_data_regression_train.csv')

ids_reg_train = list(meta_data_reg_train['subject_id'])
files_reg_seg_train = [seg_dir + 'sub-' + f + '_T1w_seg.nii.gz' for f in ids_reg_train]

# THIS MATRIX WILL STORE THE VOLUMES PER TISSUE CLASS
# rows: labels 1, 2, 3 (CSF, GM, WM); columns: subjects in meta_data_reg_train order
vols = np.zeros((3,len(files_reg_seg_train)))
missing_segs = []

for idx, seg_filename in enumerate(tqdm(files_reg_seg_train, desc='Calculating Features')):
    if not os.path.exists(seg_filename):
        # The subject's column stays all-zero, so its relative volumes are undefined.
        missing_segs.append(seg_filename)
        continue

    seg = sitk.ReadImage(seg_filename)
    seg_array = sitk.GetArrayFromImage(seg)
    # Physical volume of one voxel (product of the spacings). Predicted
    # segmentations are stored on the resampled grid while the reference ones
    # may not be, so raw voxel counts are not comparable between the two.
    voxel_volume = np.prod(seg.GetSpacing())
    for i in range(3):
        vols[i,idx] = np.count_nonzero(seg_array == i+1) * voxel_volume

if missing_segs:
    print('WARNING: {} of {} segmentation files not found; their volumes are left at 0'.format(
        len(missing_segs), len(files_reg_seg_train)))
print(vols)

In [None]:
# Plotting
# One scatter series per tissue row of vols (CSF, GM, WM), each against subject age
ages = meta_data_reg_train['age']
for tissue_vols in vols:
    plt.scatter(tissue_vols, ages, marker='.')
plt.grid()
plt.title('Unnormalised')
plt.xlabel('Volume')
plt.ylabel('Age')
plt.legend(('CSF','GM','WM'))
plt.show()

In [None]:
# Calculating relative tissue volume

# Each subject's tissue volumes divided by that subject's total (CSF + GM + WM),
# so every column sums to 1 (NaN for a subject whose total is 0).
total = vols.sum(axis=0)
vols_normalised = vols / total[np.newaxis, :]
print(vols_normalised.shape)


In [None]:
# Plotting
# One scatter series per relative tissue volume (CSF, GM, WM), each against subject age
ages = meta_data_reg_train['age']
for tissue_fraction in vols_normalised:
    plt.scatter(tissue_fraction, ages, marker='.')
plt.grid()
plt.title('Normalised')
plt.xlabel('Volume')
plt.ylabel('Age')
plt.legend(('CSF','GM','WM'))
plt.show()

In [None]:
# Prepare data for age regression
# Features: one row per subject holding its three relative tissue volumes
X = vols_normalised.T
# Targets: ages as an (n_subjects, 1) column
y = meta_data_reg_train['age'].values.reshape(-1,1)

print(X.shape)
print(y.shape)

In [None]:
# Run regression

from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.metrics import mean_absolute_error

# normal y: scale ages with fixed bounds so the MLP target is roughly in [0, 1]
# NOTE(review): `min`/`max` shadow the Python built-ins for the rest of the
# notebook; they are kept because the evaluation cell uses them to undo the scaling.
min = 10
max = 100
y_normal = (y - min) / (max - min)
y_normal = y_normal.reshape(y_normal.shape[0])

# Hyper-parameter grid. The plural names avoid overwriting the scalar
# `learning_rate` that the CNN training cell passes to Adam.
learning_rates = [0.0001, 0.0005, 0.001, 0.002, 0.005]
alphas = [0.00001, 0.0001, 0.001]
n_splits = 2


def _make_mlp(l2_alpha, lr):
    """Build an unfitted MLPRegressor with the fixed architecture used in the search."""
    # learning_rate='invscaling' is not set: scikit-learn only uses it with solver='sgd'.
    return MLPRegressor(
        hidden_layer_sizes=(1024, 256, 64, 16),
        activation='relu',
        solver='adam',
        alpha=l2_alpha, # L2 parameter
        batch_size=25,
        learning_rate_init=lr,
        max_iter=200,
        shuffle=True,
        random_state=rnd_seed,  # reproducible weight init and batch shuffling
    )


# Mean validation MAE over the folds, on the scaled target (multiply by
# max - min to get years), for every (alpha, lr) pair
score = np.zeros((len(alphas), len(learning_rates)))
kfold = KFold(n_splits=n_splits, shuffle=False)
for i, a in enumerate(alphas):
    for j, lr in enumerate(learning_rates):
        print('LR = '+str(lr)+',\t alpha = '+str(a))
        fold_mae = []
        for idx_train, idx_validation in kfold.split(X):
            MLP = _make_mlp(a, lr)
            MLP.fit(X[idx_train], y_normal[idx_train])
            prediction = MLP.predict(X[idx_validation])
            fold_mae.append(mean_absolute_error(y_normal[idx_validation], prediction))
        score[i, j] = np.mean(fold_mae)

# np.argmin returns the first minimum in (alpha, lr) loop order
best_i, best_j = np.unravel_index(np.argmin(score), score.shape)
min_MAE = score[best_i, best_j]
# Refit the winning configuration on all subjects; a single fold's model
# would only have seen half of the data.
best_nn = _make_mlp(alphas[best_i], learning_rates[best_j])
best_nn.fit(X, y_normal)

print('min MAE = ', min_MAE)
idx = score == np.min(score)
for i, a in enumerate(alphas):
    for j, lr in enumerate(learning_rates):
        if idx[i][j]:
            print('alpha, lr, MAE = ', a, lr, score[i][j], 'Min')
        else:
            print('alpha, lr, MAE = ', a, lr, score[i][j])

In [None]:
# Evaluation metrics

from sklearn.metrics import r2_score
# NOTE(review): best_nn was fitted on (some of) the subjects in X, so MAE/R2
# below are optimistic in-sample figures rather than held-out performance.
predicted = best_nn.predict(X)
# Map predictions back from the scaled target to years (min/max from the regression cell)
predicted = predicted * (max - min) + min

print('MAE: {0}'.format(mean_absolute_error(y,predicted)))
print('R2: {0}'.format(r2_score(y,predicted)))

# Predicted vs. real age; the dashed identity line marks perfect predictions
fig, ax = plt.subplots()
ax.scatter(y, predicted, marker='.')
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
ax.set_xlabel('Real Age')
ax.set_ylabel('Predicted Age')
plt.show()

In [None]:
# Other regression attempts

from sklearn.model_selection import cross_validate
from sklearn.linear_model import LinearRegression, Ridge
from sklearn import svm, tree
from sklearn.ensemble import RandomForestRegressor, AdaBoostRegressor, GradientBoostingRegressor

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score


lin_reg = LinearRegression()
svm_reg = svm.SVR()
tree_reg = tree.DecisionTreeRegressor()
rf_reg = RandomForestRegressor(max_depth = 4, random_state = rnd_seed)
ada_reg = AdaBoostRegressor(n_estimators = 50, learning_rate = 0.3)
grad_reg = GradientBoostingRegressor()
ridge_reg = Ridge(alpha = 0.01)


# Ridge Regression
scoring = ["neg_mean_absolute_error", "r2"]
scores = cross_validate(ridge_reg, X, y, scoring = scoring, cv = 2)
# Per-fold test scores, looked up by name (sklearn reports the MAE negated)
print(scores['test_neg_mean_absolute_error'])
print(scores['test_r2'])

# Refit on all subjects; the metrics below are therefore in-sample
ridge_reg.fit(X, y)
predicted = ridge_reg.predict(X)

# Ridge regression plot

print('MAE: {0}'.format(mean_absolute_error(y,predicted)))
print('R2: {0}'.format(r2_score(y,predicted)))

fig, ax = plt.subplots()
ax.scatter(y, predicted, marker='.')
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
ax.set_xlabel('Real Age')
ax.set_ylabel('Predicted Age')
plt.show()
fig.savefig("ridge_reg_plot")

In [None]:
# Support vector machine

# 1-D targets; a column vector triggers sklearn's DataConversionWarning in fit
y_flat = y.ravel()

scoring = ["neg_mean_absolute_error", "r2"]
scores = cross_validate(svm_reg, X, y_flat, scoring = scoring, cv = 2)
# Per-fold test scores, looked up by name (sklearn reports the MAE negated)
print(scores['test_neg_mean_absolute_error'])
print(scores['test_r2'])

# Refit on all subjects; the metrics below are therefore in-sample
svm_reg.fit(X, y_flat)
predicted = svm_reg.predict(X)

# Support vector machine plot

print('MAE: {0}'.format(mean_absolute_error(y,predicted)))
print('R2: {0}'.format(r2_score(y,predicted)))

fig, ax = plt.subplots()
ax.scatter(y, predicted, marker='.')
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
ax.set_xlabel('Real Age')
ax.set_ylabel('Predicted Age')
plt.show()
# Own file name so the ridge regression plot is not overwritten
fig.savefig("svm_reg_plot")

In [None]:
# Random forest regression

# 1-D targets; a column vector triggers sklearn's DataConversionWarning in fit
y_flat = y.ravel()

scoring = ["neg_mean_absolute_error", "r2"]
scores = cross_validate(rf_reg, X, y_flat, scoring = scoring, cv = 2)
# Per-fold test scores, looked up by name (sklearn reports the MAE negated)
print(scores['test_neg_mean_absolute_error'])
print(scores['test_r2'])

# Refit on all subjects; the metrics below are therefore in-sample
rf_reg.fit(X, y_flat)
predicted = rf_reg.predict(X)

# Random forest regression plot

print('MAE: {0}'.format(mean_absolute_error(y,predicted)))
print('R2: {0}'.format(r2_score(y,predicted)))

fig, ax = plt.subplots()
ax.scatter(y, predicted, marker='.')
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
ax.set_xlabel('Real Age')
ax.set_ylabel('Predicted Age')
plt.show()
fig.savefig("randomforest_reg_plot")

In [None]:
# Adaboost regression

# 1-D targets; a column vector triggers sklearn's DataConversionWarning in fit
y_flat = y.ravel()

scoring = ["neg_mean_absolute_error", "r2"]
scores = cross_validate(ada_reg, X, y_flat, scoring = scoring, cv = 2)
# Per-fold test scores, looked up by name (sklearn reports the MAE negated)
print(scores['test_neg_mean_absolute_error'])
print(scores['test_r2'])

# Refit on all subjects; the metrics below are therefore in-sample
ada_reg.fit(X, y_flat)
predicted = ada_reg.predict(X)

# Adaboost regression plot

print('MAE: {0}'.format(mean_absolute_error(y,predicted)))
print('R2: {0}'.format(r2_score(y,predicted)))

fig, ax = plt.subplots()
ax.scatter(y, predicted, marker='.')
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
ax.set_xlabel('Real Age')
ax.set_ylabel('Predicted Age')
plt.show()
fig.savefig("ada_reg_plot")

In [None]:
# Test on held-out data
! wget https://www.doc.ic.ac.uk/~bglocker/teaching/notebooks/brainage-testdata.zip
! unzip brainage-testdata.zip

# Loading data
# NOTE: reuses and overwrites the ids_seg_test / files_seg_*_test /
# dataset_test / dataloader_test names from the earlier segmentation test cell.
meta_data_reg_test = pd.read_csv(data_dir + 'meta/meta_data_regression_test.csv')
ids_seg_test = list(meta_data_reg_test['subject_id'])
files_seg_img_test = [data_dir + 'images/sub-' + f + '_T1w_unbiased.nii.gz' for f in ids_seg_test]
files_seg_seg_test = [data_dir + 'segs_refs/sub-' + f + '_T1w_seg.nii.gz' for f in ids_seg_test]
files_seg_msk_test = [data_dir + 'masks/sub-' + f + '_T1w_brain_mask.nii.gz' for f in ids_seg_test]

dataset_test = ImageSegmentationDataset(files_seg_img_test, files_seg_seg_test, files_seg_msk_test, img_spacing, img_size)
# One volume per batch, in file order
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False)


In [None]:

# Run test: evaluate the trained segmentation network on the held-out set, save each
# predicted label map and plot per-subject accuracy / precision / recall / Dice.
pred_dir = os.path.join(out_dir, 'final_test')
os.makedirs(pred_dir, exist_ok=True)
from sklearn import metrics
import pandas as pd
model = SimpleNet3D(num_classes=num_classes)
model_dir = 'drive/MyDrive/Colab Notebooks/ML for Img/'
# map_location lets a checkpoint saved from a GPU session load on a CPU-only runtime.
model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt'), map_location=device))
model.to(device)
model.eval()

print('START TESTING...')

# Column order of the per-class metric tables: entry i is the tissue with label value i.
class_names = ['Background', 'CSF', 'GM', 'WM']
class_labels = list(range(len(class_names)))

# One row per test subject (sized from the dataset instead of a hard-coded count).
num = len(dataset_test)
acc = np.zeros(num)
precision = np.zeros((num, len(class_names)))
recall = np.zeros((num, len(class_names)))
dice = np.zeros((num, len(class_names)))  # per-class F1, i.e. the Dice coefficient


def _boxplot_per_class(values, title, columns):
    """Draw one box per class for a (num_subjects, num_classes) metric array."""
    frame = pd.DataFrame(values, columns=columns)
    frame.boxplot(sym='o', vert=True, whis=1.5, patch_artist=True, meanline=False,
                  showmeans=True, showbox=True, showfliers=True)
    plt.title(title)
    plt.show()


loss_test = 0
sum_pts = 0
idx_test = 0
with torch.no_grad():
    for data_sample in dataloader_test:
        img, seg = data_sample['img'].to(device), data_sample['seg'].to(device)
        prd = model(img)
        loss_test += F.cross_entropy(prd, seg.squeeze(1), reduction='sum').item()
        # number of voxels summed over by cross_entropy(reduction='sum') for this batch
        sum_pts += seg.numel()

        prd = torch.argmax(prd, dim=1)

        # flattened reference (t) and predicted (p) label maps
        t = seg.view(-1).cpu().numpy()
        p = prd.view(-1).cpu().numpy()

        acc[idx_test] = metrics.accuracy_score(t, p)

        # labels= keeps exactly one column per class even when a class is missing from
        # this volume; zero_division=0 gives 0 instead of NaN/warnings for undefined ratios.
        precision[idx_test] = metrics.precision_score(t, p, labels=class_labels, average=None, zero_division=0)
        recall[idx_test] = metrics.recall_score(t, p, labels=class_labels, average=None, zero_division=0)
        dice[idx_test] = metrics.f1_score(t, p, labels=class_labels, average=None, zero_division=0)

        # write the prediction with the reference segmentation's geometry
        sample = dataset_test.get_sample(idx_test)
        name = dataset_test.get_seg_name(idx_test)
        prediction = sitk.GetImageFromArray(prd.cpu().squeeze().numpy().astype(np.uint8))
        prediction.CopyInformation(sample['seg'])
        sitk.WriteImage(prediction, os.path.join(pred_dir, name))

        idx_test += 1
        if idx_test % 10 == 0: print('\t tested ' + str(idx_test))

# show acc
fig1, ax1 = plt.subplots()
ax1.set_title('Acc Plot')
ax1.boxplot(acc, sym = 'o', vert = True, whis=1.5, patch_artist = True, meanline = False, showmeans = True, showbox = True, showfliers = True)
plt.show()
_boxplot_per_class(precision, 'Precision Plot', class_names)
_boxplot_per_class(recall, 'Recall Plot', class_names)
_boxplot_per_class(dice, 'Dice/F1 Plot', class_names)

# mean cross-entropy per voxel over the whole test set
loss_test /= sum_pts

print('+ TESTING \tLoss: {:.6f}'.format(loss_test))

# Show last testing sample as an example
print('\n\nReference segmentation')
display_image(sitk.LabelToRGB(sample['seg']))
print('Predicted segmentation')
display_image(sitk.LabelToRGB(prediction))

print('\nFinished TESTING.')

In [None]:
# Calculating relative tissue volume

import os

# USE THIS TO RUN THE CALCULATIONS ON PREDICTED SEGMENTATIONS
seg_dir = './output/final_test/'

# USE THIS TO RUN THE CALCULATIONS ON REFERENCE SEGMENTATIONS
# seg_dir = data_dir + 'segs_refs/'

# NOTE: despite the *_train variable names (kept because later cells use them), this
# reads the held-out *test* split of the regression task.
meta_data_reg_train = pd.read_csv(data_dir + 'meta/meta_data_regression_test.csv')

ids_reg_train = list(meta_data_reg_train['subject_id'])
files_reg_seg_train = [seg_dir + 'sub-' + f + '_T1w_seg.nii.gz' for f in ids_reg_train]

# THIS MATRIX WILL STORE THE VOLUMES PER TISSUE CLASS
# Row 0: CSF (label 1), row 1: GM (label 2), row 2: WM (label 3); one column per subject.
# Each entry is the fraction of ALL voxels of the segmentation volume with that label.
vols = np.zeros((3,len(files_reg_seg_train)))

missing_segs = []
for idx, seg_filename in enumerate(tqdm(files_reg_seg_train, desc='Calculating Features')):
    if not os.path.exists(seg_filename):
        # The column stays 0; np.log in the normalisation step would turn it into -inf,
        # so record it and warn below instead of failing silently.
        missing_segs.append(seg_filename)
        continue

    seg = sitk.GetArrayFromImage(sitk.ReadImage(seg_filename))
    total = seg.size

    # relative tissue volumes for CSF, GM and WM
    vols[0][idx] = np.sum(seg == 1) / total
    vols[1][idx] = np.sum(seg == 2) / total
    vols[2][idx] = np.sum(seg == 3) / total

if missing_segs:
    print('WARNING: {} segmentation file(s) not found; their volumes are left at 0: {}'.format(
        len(missing_segs), missing_segs))

plt.scatter(vols[0,:],meta_data_reg_train['age'], marker='.')
plt.scatter(vols[1,:],meta_data_reg_train['age'], marker='.')
plt.scatter(vols[2,:],meta_data_reg_train['age'], marker='.')
plt.grid()
plt.title('Unnormalised')
plt.xlabel('Volume')
plt.ylabel('Age')
plt.legend(('CSF','GM','WM'))
plt.show()


In [None]:
# Normalising the relative tissue volumes: log-transform, z-score, then rescale to [0, 1].
# The statistics are pooled over all three tissues and all subjects in `vols`.
# NOTE(review): they are computed from this (test) data itself; if best_nn was trained on
# features normalised with training-set statistics, those should be reused here — confirm.

vols_normalised = np.log(vols)

mean = np.mean(vols_normalised)
std = np.std(vols_normalised)
vols_normalised = (vols_normalised - mean) / std

# to 0-1 (vol_min / vol_max, so the built-in min() / max() are not shadowed)
vol_min = np.min(vols_normalised)
vol_max = np.max(vols_normalised)
vols_normalised = (vols_normalised - vol_min) / (vol_max - vol_min)

plt.scatter(vols_normalised[0,:],meta_data_reg_train['age'], marker='.')
plt.scatter(vols_normalised[1,:],meta_data_reg_train['age'], marker='.')
plt.scatter(vols_normalised[2,:],meta_data_reg_train['age'], marker='.')
plt.grid()
plt.title('Normalised')
plt.xlabel('Volume')
plt.ylabel('Age')
plt.legend(('CSF','GM','WM'))
plt.show()

In [None]:
# Features: one row per subject, columns = normalised CSF / GM / WM volume
X = vols_normalised.T
y = meta_data_reg_train['age'].values.reshape(-1,1)

print(X.shape)
print(y.shape)


from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score

# best_nn's output is mapped back to years with these bounds (presumably the same
# [age_min, age_max] -> [0, 1] scaling used when it was trained — confirm).
# Named age_min / age_max so the built-in min() / max() are not shadowed.
age_min = 10
age_max = 100
predicted = best_nn.predict(X)
predicted = predicted * (age_max - age_min) + age_min

print('MAE: {0}'.format(mean_absolute_error(y,predicted)))
print('R2: {0}'.format(r2_score(y,predicted)))

fig, ax = plt.subplots()
ax.scatter(y, predicted, marker='.')
ax.plot([y.min(), y.max()], [y.min(), y.max()], 'k--', lw=2)
ax.set_xlabel('Real Age')
ax.set_ylabel('Predicted Age')
plt.show()

## Using a CNN to predict age directly from the 3D scan

In [None]:
# Setting up

import os
import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

# data directory (relative to the working directory where the zip archives were extracted)
data_dir = 'data/brain_age/'

# Read the meta data using pandas
import pandas as pd
meta_data_all = pd.read_csv(data_dir + 'meta/meta_data_all.csv')
# NOTE(review): meta_data_train is not referenced again in the cells below; the training
# split is re-read later as meta_data_regression_train.
meta_data_train = pd.read_csv(data_dir + 'meta/meta_data_regression_train.csv')
meta_data_all.head() # show the first five data entries
# NOTE(review): .head() is not the last expression of this cell, so Jupyter does not
# display its result here.

%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns

meta_data = meta_data_all

In [None]:
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed
from IPython.display import display

def wl_to_lh(window, level):
    """Convert a display window/level pair into (low, high) intensity limits.

    The returned interval has width ``window`` and is centred on ``level``.
    """
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, x=None, y=None, z=None, window=None, level=None, colormap='gray', crosshair=False):
    """Plot three orthogonal slices (axial, coronal, sagittal) through a 3D SimpleITK image.

    Args:
        img: 3D SimpleITK image.
        x, y, z: voxel indices of the sagittal, coronal and axial slices; each
            defaults to the middle voxel along its axis.
        window, level: display window width and centre. By default the window spans
            the full intensity range and the level is the middle of that range.
        colormap: matplotlib colormap name.
        crosshair: if True, draw lines marking the other two slice positions.
    """
    voxels = sitk.GetArrayFromImage(img)  # indexed as [z, y, x]

    size = img.GetSize()
    spacing = img.GetSpacing()
    # physical extent along x, y and z in millimetres
    width, height, depth = (size[i] * spacing[i] for i in range(3))

    # default slice positions: the middle voxel along each axis
    mid = [np.floor(size[i] / 2).astype(int) for i in range(3)]
    x = mid[0] if x is None else x
    y = mid[1] if y is None else y
    z = mid[2] if z is None else z

    intensity_min = np.min(voxels)
    intensity_max = np.max(voxels)
    if window is None:
        window = intensity_max - intensity_min
    if level is None:
        level = window / 2 + intensity_min
    low, high = wl_to_lh(window, level)

    fig, (ax_axial, ax_coronal, ax_sagittal) = plt.subplots(1, 3, figsize=(10, 4))
    common = dict(cmap=colormap, clim=(low, high))
    ax_axial.imshow(voxels[z, :, :], extent=(0, width, height, 0), **common)
    ax_coronal.imshow(voxels[:, y, :], origin='lower', extent=(0, width, 0, depth), **common)
    ax_sagittal.imshow(voxels[:, :, x], origin='lower', extent=(0, height, 0, depth), **common)

    if crosshair:
        # (axes, horizontal-line position, vertical-line position), all in millimetres
        guides = (
            (ax_axial, y * spacing[1], x * spacing[0]),
            (ax_coronal, z * spacing[2], x * spacing[0]),
            (ax_sagittal, z * spacing[2], y * spacing[1]),
        )
        for ax, h_pos, v_pos in guides:
            ax.axhline(h_pos, lw=1)
            ax.axvline(v_pos, lw=1)

    plt.show()
    
def interactive_view(img):
    """Show ``display_image`` for ``img`` with ipywidgets sliders.

    The sliders cover every slice index along x, y and z, a window width from 0 to
    the full intensity range, and a level anywhere in the intensity range.
    """
    dims = img.GetSize()
    voxels = sitk.GetArrayFromImage(img)
    lowest = np.min(voxels)
    highest = np.max(voxels)
    interact(display_image, img=fixed(img),
             x=(0, dims[0] - 1),
             y=(0, dims[1] - 1),
             z=(0, dims[2] - 1),
             window=(0, highest - lowest),
             level=(lowest, highest))

In [None]:
def zero_mean_unit_var(image, mask):
    """Normalise ``image`` to zero mean and unit variance over the voxels where ``mask`` > 0.

    If the masked standard deviation is positive, intensities are standardised with the
    masked mean/std and every voxel where the mask equals 0 is set to 0. Otherwise the
    intensities are returned unchanged (cast to float32). Geometry is copied from ``image``.
    """
    values = sitk.GetArrayFromImage(image).astype(np.float32)
    mask_values = sitk.GetArrayFromImage(mask)
    foreground = values[mask_values > 0]

    mu = np.mean(foreground)
    sigma = np.std(foreground)
    if sigma > 0:
        values = (values - mu) / sigma
        values[mask_values == 0] = 0

    normalised = sitk.GetImageFromArray(values)
    normalised.CopyInformation(image)
    return normalised


def resample_image(image, out_spacing=(1.0, 1.0, 1.0), out_size=None, is_label=False, pad_value=0):
    """Resamples an image to given element spacing and output size.

    The output grid keeps the input's direction and is positioned so that its centre
    coincides with the centre of the input grid in physical space.

    Args:
        image: SimpleITK image to resample.
        out_spacing: output voxel spacing, one value per image dimension.
        out_size: output size in voxels. If None, it is chosen so the output covers the
            same physical extent as the input (rounded to whole voxels).
        is_label: if True use nearest-neighbour interpolation (keeps label values
            intact), otherwise B-spline interpolation.
        pad_value: default pixel value of the resampler, used for output voxels that
            map outside the input image.

    Returns:
        The resampled SimpleITK image.
    """

    original_spacing = np.array(image.GetSpacing())
    original_size = np.array(image.GetSize())

    if out_size is None:
        # same physical extent: size_out = size_in * spacing_in / spacing_out
        out_size = np.round(np.array(original_size * original_spacing / np.array(out_spacing))).astype(int)
    else:
        out_size = np.array(out_size)

    # direction cosines as a (dim x dim) matrix
    original_direction = np.array(image.GetDirection()).reshape(len(original_spacing),-1)
    # offset of each grid's centre from its origin, along the (unrotated) image axes
    original_center = (np.array(original_size, dtype=float) - 1.0) / 2.0 * original_spacing
    out_center = (np.array(out_size, dtype=float) - 1.0) / 2.0 * np.array(out_spacing)

    # rotate those offsets into physical space, then shift the origin so both centres coincide
    original_center = np.matmul(original_direction, original_center)
    out_center = np.matmul(original_direction, out_center)
    out_origin = np.array(image.GetOrigin()) + (original_center - out_center)

    resample = sitk.ResampleImageFilter()
    resample.SetOutputSpacing(out_spacing)
    resample.SetSize(out_size.tolist())
    resample.SetOutputDirection(image.GetDirection())
    resample.SetOutputOrigin(out_origin.tolist())
    # default-constructed (identity) transform: only the sampling grid changes
    resample.SetTransform(sitk.Transform())
    resample.SetDefaultPixelValue(pad_value)

    if is_label:
        resample.SetInterpolator(sitk.sitkNearestNeighbor)
    else:
        resample.SetInterpolator(sitk.sitkBSpline)

    return resample.Execute(image)


class ImageRegressionDataset(Dataset):
    """Dataset of pre-processed brain images paired with a scalar regression target (age).

    All images are loaded eagerly in ``__init__``. Each one is normalised to zero mean and
    unit variance inside its brain mask, then resampled to ``img_spacing`` / ``img_size``.
    ``__getitem__`` returns ``(image_tensor, label)``, where ``image_tensor`` has a leading
    channel axis of size 1.
    """

    def __init__(self, file_list_img,file_list_msk, label_list_age,img_spacing, img_size):
        """
        Args:
            file_list_img: paths to the intensity images.
            file_list_msk: paths to the brain masks, aligned index-by-index with file_list_img.
            label_list_age: regression targets, aligned index-by-index with file_list_img.
            img_spacing: output voxel spacing passed to resample_image.
            img_size: output image size passed to resample_image.

        Raises:
            ValueError: if the three lists do not have the same length.
        """
        # Fail early on misaligned inputs: a shorter list would raise mid-load and a
        # longer one would have its extra entries silently ignored.
        if not (len(file_list_img) == len(file_list_msk) == len(label_list_age)):
            raise ValueError(
                'file_list_img, file_list_msk and label_list_age must have the same length '
                '(got {}, {}, {})'.format(len(file_list_img), len(file_list_msk), len(label_list_age)))

        self.samples = []
        self.img_names = []
        self.labels = []
        records = zip(file_list_img, file_list_msk, label_list_age)
        for img_path, msk_path, age in tqdm(records, total=len(file_list_img), desc='Loading Data'):
            img = sitk.ReadImage(img_path, sitk.sitkFloat32)
            msk = sitk.ReadImage(msk_path, sitk.sitkUInt8)

            # pre-processing: intensity normalisation inside the brain mask, then resampling
            img = zero_mean_unit_var(img, msk)
            img = resample_image(img, img_spacing, img_size, is_label=False)

            self.samples.append(img)
            self.img_names.append(os.path.basename(img_path))
            self.labels.append(age)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        """Return (image tensor with a leading channel axis of size 1, label)."""
        sample = self.samples[item]

        image = torch.from_numpy(sitk.GetArrayFromImage(sample)).unsqueeze(0)

        return image,self.labels[item]

    def get_sample(self, item):
        """Return the pre-processed SimpleITK image at index ``item``."""
        return self.samples[item]

    def get_img_name(self, item):
        """Return the file name (without directory) of the image at index ``item``."""
        return self.img_names[item]

In [None]:
# Check GPU
cuda_dev = '0'  # index of the GPU to use; change it when several GPUs are visible

use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device(f"cuda:{cuda_dev}")
else:
    device = torch.device("cpu")

print(f'Device: {device}')
if use_cuda:
    print(f'GPU: {torch.cuda.get_device_name(int(cuda_dev))}')

In [None]:
rnd_seed = 42 #fixed random seed
# Actually apply the seed to torch's global RNG, so weight initialisation and DataLoader
# shuffling repeat across runs (as long as the cells run in order).
torch.manual_seed(rnd_seed)

img_size = [96, 96, 96]
img_spacing = [2, 2, 2]

num_epochs = 50
learning_rate = 0.001
batch_size = 25
val_interval = 10

out_dir = './output'

# Create output directory (no error if it already exists)
os.makedirs(out_dir, exist_ok=True)

In [None]:
def _subject_paths(subject_ids, sub_dir, suffix):
    """Return data_dir + sub_dir + '/sub-' + id + suffix for every subject id."""
    return [data_dir + sub_dir + '/sub-' + subject_id + suffix for subject_id in subject_ids]


# Training split: subject ids, ages and the matching image / brain-mask paths
meta_data_regression_train = pd.read_csv(data_dir + 'meta/meta_data_regression_train.csv')
ids_regression_train = list(meta_data_regression_train['subject_id'])
ages_regression_train = list(meta_data_regression_train['age'])
files_regression_img_train = _subject_paths(ids_regression_train, 'images', '_T1w_unbiased.nii.gz')
files_regression_msk_train = _subject_paths(ids_regression_train, 'masks', '_T1w_brain_mask.nii.gz')

# Held-out test split, built the same way
meta_data_regression_test = pd.read_csv(data_dir + 'meta/meta_data_regression_test.csv')
ids_regression_test = list(meta_data_regression_test['subject_id'])
ages_regression_test = list(meta_data_regression_test['age'])
files_regression_img_test = _subject_paths(ids_regression_test, 'images', '_T1w_unbiased.nii.gz')
files_regression_msk_test = _subject_paths(ids_regression_test, 'masks', '_T1w_brain_mask.nii.gz')

In [None]:
# Load training data
dataset_train = ImageRegressionDataset(files_regression_img_train, files_regression_msk_train,ages_regression_train, img_spacing, img_size)

# Split the training set into two halves for 2-fold cross-validation
train_proportion = 0.5
train_examples_fold_0 = round(len(dataset_train)*train_proportion)
train_examples_fold_1 = len(dataset_train) - train_examples_fold_0

# A dedicated generator seeded with rnd_seed makes the fold assignment reproducible;
# without it random_split draws from the unseeded global RNG.
split_generator = torch.Generator().manual_seed(rnd_seed)
dataset_train_fold_0, dataset_train_fold_1 = random_split(dataset_train,
                                           (train_examples_fold_0,
                                            train_examples_fold_1),
                                           generator=split_generator)

dataloader_train_fold_0 = torch.utils.data.DataLoader(dataset_train_fold_0, batch_size=batch_size, shuffle=True)
dataloader_train_fold_1 = torch.utils.data.DataLoader(dataset_train_fold_1, batch_size=batch_size, shuffle=True)
dataloaders_train = [dataloader_train_fold_0,dataloader_train_fold_1]
# the whole training set, used for the final model
dataloader_train_all = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)

In [None]:
# Load testing data (kept in file order: the loader does not shuffle)
dataset_test = ImageRegressionDataset(
    file_list_img=files_regression_img_test,
    file_list_msk=files_regression_msk_test,
    label_list_age=ages_regression_test,
    img_spacing=img_spacing,
    img_size=img_size,
)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

In [None]:
# Build 3D CNN

class CNN3D(nn.Module):
    """3D CNN that regresses one scalar (age) from a single-channel volume.

    For a (batch_size, 1, 96, 96, 96) input the spatial size goes
    96 -> 92 -> 46 -> 42 -> 21 -> 17 -> 8 -> 6 -> 4 -> 1, so the output has shape
    (batch_size, 1, 1, 1, 1).
    """

    def __init__(self):
        super().__init__()
        slope = 0.2

        def conv_act(in_ch, out_ch, kernel, **extra):
            # unpadded convolution followed by a LeakyReLU
            return [nn.Conv3d(in_ch, out_ch, kernel_size=kernel, bias=True, padding=0, **extra),
                    nn.LeakyReLU(slope)]

        layers = (
            conv_act(1, 6, 5) + [nn.MaxPool3d(2, stride=2)]              # 96 -> 92 -> 46
            + conv_act(6, 16, 5) + [nn.MaxPool3d(2, stride=2)]           # 46 -> 42 -> 21
            + conv_act(16, 120, 5)                                       # 21 -> 17
            + conv_act(120, 84, 3, stride=2) + [nn.BatchNorm3d(84)]      # 17 -> 8
            + conv_act(84, 32, 3) + [nn.BatchNorm3d(32)]                 # 8 -> 6
            + conv_act(32, 8, 3) + [nn.BatchNorm3d(8)]                   # 6 -> 4
            + [nn.Conv3d(8, 1, kernel_size=4, bias=True, padding=0)]     # 4 -> 1
        )
        self.DCNN = nn.Sequential(*layers)

    def forward(self, x, label=None):
        # `label` is accepted but not used
        return self.DCNN(x)


In [None]:
def loss_function(out, label):
    """Return the mean absolute error (L1 loss, mean reduction) between ``out`` and ``label``."""
    criterion = nn.L1Loss()
    return criterion(out, label)

In [None]:
# Two fold cross validation
from torch.optim.lr_scheduler import MultiStepLR
fig, axs = plt.subplots(2, 2, figsize=(10, 7))

num_epochs = 50
learning_rate = 0.0008
val_interval = 5  # NOTE(review): not used by the loop below
milestones = [8, 18, 35]  # epochs at which MultiStepLR multiplies the learning rate by gamma
gamma = 0.5


for fold in range(2):
    # train on one half of the training set, validate on the other half
    dataloader_train = dataloaders_train[fold]
    dataloader_val = dataloaders_train[1-fold]
    train_losses = []
    val_losses = []

    model = CNN3D().to(device)
    print(model)

    beta1 = 0.9
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, 0.999))
    scheduler = MultiStepLR(optimizer, milestones, gamma = gamma)

    for epoch in range(num_epochs):
        last_epoch = epoch == num_epochs - 1

        # Training pass. train() is needed on every epoch because the validation pass
        # below switches BatchNorm to eval mode.
        model.train()
        train_abs_err = 0.0
        train_count = 0
        for data,label in dataloader_train:
            optimizer.zero_grad()
            output = model(data.to(device)).reshape(-1,)
            train_loss = loss_function(output,label.to(torch.float).to(device))
            train_loss.backward()
            optimizer.step()

            # loss_function returns a per-batch mean; weighting by batch size makes the
            # epoch value the MAE over subjects even when the last batch is smaller.
            train_abs_err += train_loss.item() * len(label)
            train_count += len(label)

        train_loss_record = train_abs_err / train_count
        train_losses.append(train_loss_record)

        # Validation pass. eval() makes BatchNorm use its running statistics instead of
        # the validation batch's, and stops validation data updating those statistics.
        model.eval()
        label_all = []
        output_all = []
        val_abs_err = 0.0
        val_count = 0
        with torch.no_grad():
            for data,label in dataloader_val:
                output = model(data.to(device)).reshape(-1,)
                val_loss = loss_function(output,label.to(torch.float).to(device))
                val_abs_err += val_loss.item() * len(label)
                val_count += len(label)
                if last_epoch:
                    label_all += list(label.numpy())
                    output_all += list(output.to('cpu').numpy())
        val_loss_record = val_abs_err / val_count
        val_losses.append(val_loss_record)
        print(f"EPOCH: {epoch+1} --- the train MAE is {train_loss_record:.4f}, the val MAE is {val_loss_record:.4f}")

        scheduler.step()
        if last_epoch:
            # real vs predicted age on the validation half after the final epoch
            label_all = np.array(label_all)
            output_all = np.array(output_all)
            axs[1][fold].scatter(label_all, output_all, marker='.')
            axs[1][fold].plot([label_all.min(), label_all.max()], [label_all.min(), label_all.max()], 'k--', lw=2)
            axs[1][fold].set_xlabel('Real Age')
            axs[1][fold].set_ylabel('Predicted Age')

    # per-epoch learning curves for this fold
    axs[0][fold].plot(train_losses)
    axs[0][fold].plot(val_losses)
    axs[0][fold].legend(['train','val'])
    axs[0][fold].set_xlabel('Epoch')
    axs[0][fold].set_ylabel('MAE')


fig.savefig('two_fold.png')
plt.show()
    

In [None]:
# Train the final model on the whole training set, tracking the held-out test MAE per epoch
train_losses = []
test_losses = []
model = CNN3D().to(device)
print(model)

# NOTE(review): the cross-validation cell used beta1=0.9 and a MultiStepLR schedule;
# this final run uses beta1=0.5 with a constant learning rate — confirm this is intended.
beta1 = 0.5
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(beta1, 0.999))
for epoch in range(num_epochs):
    # Training pass: BatchNorm must be in training mode (the test pass below sets eval()).
    model.train()
    train_abs_err = 0.0
    train_count = 0
    for data,label in dataloader_train_all:
        optimizer.zero_grad()
        output = model(data.to(device)).reshape(-1,)
        train_loss = loss_function(output,label.to(torch.float).to(device))
        train_loss.backward()
        optimizer.step()
        # weight each per-batch mean by its batch size -> MAE over subjects
        train_abs_err += train_loss.item() * len(label)
        train_count += len(label)
    train_losses.append(train_abs_err / train_count)

    # Test pass: eval() makes BatchNorm use its running statistics and stops the test
    # batches from updating them.
    model.eval()
    test_abs_err = 0.0
    test_count = 0
    with torch.no_grad():
        for data,label in dataloader_test:
            output = model(data.to(device)).reshape(-1)
            test_loss = loss_function(output,label.to(torch.float).to(device))
            test_abs_err += test_loss.item() * len(label)
            test_count += len(label)
    test_losses.append(test_abs_err / test_count)
    print(f"EPOCH: {epoch+1} --- the train MAE is {train_losses[-1]:.4f}, the test MAE is {test_losses[-1]:.4f}")


In [None]:
# Learning curves of the final model: per-epoch train and test MAE
plt.figure()
for curve in (train_losses, test_losses):
    plt.plot(curve)
plt.legend(['train', 'test'])
plt.savefig('test.png')
plt.show()

In [None]:
# Predict ages for the held-out test set with the final model and report MAE / R2
from sklearn.metrics import mean_absolute_error, r2_score

# eval() so BatchNorm uses its running statistics rather than test-batch statistics
model.eval()
real_age = []
output_all = []
with torch.no_grad():
    for data,label in dataloader_test:
        real_age += list(label.numpy())
        output = model(data.to(device)).reshape(-1).to('cpu')
        output_all += list(output.numpy())
real_age = np.array(real_age)
output_all = np.array(output_all)

print('MAE: {0}'.format(mean_absolute_error(real_age,output_all)))
print('R2: {0}'.format(r2_score(real_age,output_all)))

fig, ax = plt.subplots()
ax.scatter(real_age, output_all, marker='.')
ax.plot([real_age.min(), real_age.max()], [real_age.min(), real_age.max()], 'k--', lw=2)
ax.set_xlabel('Real Age')
ax.set_ylabel('Predicted Age')
plt.savefig('scatter_part_b.png')
plt.show()