# Example: PCA with Brain MRI

## Install SimpleITK

In [None]:
! pip install SimpleITK

## Downloading the data

In [None]:
! wget https://www.dropbox.com/s/xdkavdkqljhco2r/brainage-data.zip
! unzip brainage-data.zip

## Setting the data directory

In [None]:
# Dataset root; later cells read meta/, images/ and masks/ beneath it
data_dir = 'data/brain_age/'

In [None]:
# Read the meta data using pandas
import pandas as pd

meta_data_all = pd.read_csv(data_dir + 'meta/meta_data_all.csv')
meta_data_all.head(5)  # preview: the first five rows of the metadata table

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import seaborn as sns

meta_data = meta_data_all

# Number of subjects per gender
sns.catplot(data=meta_data, x="gender_text", kind="count")
plt.title('Gender distribution')
plt.xlabel('Gender')
plt.show()

# Age histogram with ten-year bins from 10 to 90, plus a KDE curve
sns.displot(meta_data['age'], bins=list(range(10, 100, 10)), kde=True)
plt.title('Age distribution')
plt.xlabel('Age')
plt.show()

# Age of each subject, in metadata row order
plt.scatter(range(len(meta_data['age'])), meta_data['age'], marker='.')
plt.xlabel('Subject')
plt.ylabel('Age')
plt.grid()
plt.show()

## Set up a simple medical image viewer and import SimpleITK

In [None]:
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt

from ipywidgets import interact, fixed
from IPython.display import display

# Calculate parameters low and high from window and level
def wl_to_lh(window, level):
    """Convert a display window/level pair into (low, high) intensity limits.

    The window is centred on ``level``: low = level - window/2 and
    high = level + window/2.
    """
    half = window / 2
    return level - half, level + half

def display_image(img, x=None, y=None, z=None, window=None, level=None, colormap='gray', crosshair=False):
    """Plot three orthogonal slices of a 3-D SimpleITK image side by side.

    Args:
        img: 3-D SimpleITK image.
        x, y, z: voxel indices of the slices to show (default: the middle
            index along each axis).
        window: width of the displayed intensity range (default: the full
            intensity range of ``img``).
        level: centre of the displayed intensity range (default: the middle
            of the window, measured from the image minimum).
        colormap: matplotlib colormap name.
        crosshair: if True, draw lines through the selected voxel (x, y, z)
            in each view.

    Axis extents are in physical units (voxel count times spacing). A colour
    bar is attached to the first view.
    """
    # SimpleITK returns the NumPy array in [z, y, x] index order
    img_array = sitk.GetArrayFromImage(img)

    # Physical extent of the image along each axis
    size = img.GetSize()
    spacing = img.GetSpacing()
    width  = size[0] * spacing[0]
    height = size[1] * spacing[1]
    depth  = size[2] * spacing[2]

    # Default to the central slice along each axis
    if x is None:
        x = size[0] // 2
    if y is None:
        y = size[1] // 2
    if z is None:
        z = size[2] // 2

    if window is None or level is None:
        # .item() yields Python scalars, so max - min cannot wrap around for
        # narrow integer dtypes (e.g. int16 images with a wide range).
        lo = np.min(img_array).item()
        hi = np.max(img_array).item()
        if window is None:
            window = hi - lo
        if level is None:
            level = window / 2 + lo

    low, high = wl_to_lh(window, level)

    # Display the orthogonal slices
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10, 4))

    pos1 = ax1.imshow(img_array[z, :, :], cmap=colormap, clim=(low, high), extent=(0, width, height, 0))
    ax2.imshow(img_array[:, y, :], origin='lower', cmap=colormap, clim=(low, high), extent=(0, width, 0, depth))
    ax3.imshow(img_array[:, :, x], origin='lower', cmap=colormap, clim=(low, high), extent=(0, height, 0, depth))

    if crosshair:
        # With these extents voxel i covers [i, i + 1) * spacing, so offset by
        # half a voxel to draw the lines through the voxel centre, not its edge.
        cx = (x + 0.5) * spacing[0]
        cy = (y + 0.5) * spacing[1]
        cz = (z + 0.5) * spacing[2]
        ax1.axhline(cy, lw=1)
        ax1.axvline(cx, lw=1)
        ax2.axhline(cz, lw=1)
        ax2.axvline(cx, lw=1)
        ax3.axhline(cz, lw=1)
        ax3.axvline(cy, lw=1)

    fig.colorbar(pos1, ax=ax1)
    plt.show()
    
def interactive_view(img):
    """Show ``img`` with ``display_image`` behind ipywidgets sliders.

    Sliders control the slice indices x, y, z (within the image size), the
    window (0 to the full intensity range) and the level (image min to max).
    """
    size = img.GetSize()
    img_array = sitk.GetArrayFromImage(img)
    # .item() gives Python scalars: integer images still get integer sliders,
    # but max - min can no longer overflow a narrow integer dtype.
    lo = np.min(img_array).item()
    hi = np.max(img_array).item()
    interact(display_image, img=fixed(img),
             x=(0, size[0] - 1),
             y=(0, size[1] - 1),
             z=(0, size[2] - 1),
             window=(0, hi - lo),
             level=(lo, hi))

## Preprocess imaging data

Let's look at the imaging data available for each subject. This cell also shows how to use the metadata to retrieve the image and brain mask for a given subject ID.

In [None]:
# Pick the first subject listed in the metadata table
ID = meta_data['subject_id'][0]
age = meta_data['age'][0]

# Build both file names from the subject ID, then load the image and its brain mask
image_filename = data_dir + 'images/sub-' + ID + '_T1w_unbiased.nii.gz'
mask_filename = data_dir + 'masks/sub-' + ID + '_T1w_brain_mask.nii.gz'
img = sitk.ReadImage(image_filename)
msk = sitk.ReadImage(mask_filename)

print('Imaging data of subject ' + ID + ' with age ' + str(age))

print('\nMR Image')
display_image(img, window=400, level=200)  # fixed intensity range 0..400

print('Brain mask')
display_image(msk)

In [None]:
def zero_mean_unit_var(image, mask):
    """Normalise intensities to zero mean and unit variance inside a mask.

    Mean and standard deviation are computed over the voxels where
    ``mask > 0``; every other voxel is set to 0 in the output.

    Degenerate inputs are handled consistently instead of returning the raw,
    unmasked intensities: if the masked voxels have zero variance they are
    only mean-centred, and an empty mask yields an all-zero image.

    Args:
        image: SimpleITK image to normalise.
        mask: SimpleITK image on the same voxel grid as ``image``; voxels with
            a value > 0 are foreground.

    Returns:
        A float32 SimpleITK image carrying ``image``'s origin, spacing and
        direction (copied via ``CopyInformation``).
    """
    img_array = sitk.GetArrayFromImage(image).astype(np.float32)
    foreground = sitk.GetArrayFromImage(mask) > 0

    # Skip the statistics for an empty mask: mean/std of no voxels are NaN
    if np.any(foreground):
        values = img_array[foreground]
        mean = np.mean(values)
        std = np.std(values)
        img_array = img_array - mean
        if std > 0:
            img_array = img_array / std

    # Background is zero in every case (so an empty mask gives all zeros)
    img_array[~foreground] = 0

    image_normalised = sitk.GetImageFromArray(img_array)
    image_normalised.CopyInformation(image)

    return image_normalised

def resample_image(image, out_spacing=(1.0, 1.0, 1.0), out_size=None, is_label=False, pad_value=0):
    """Resample ``image`` onto a new regular grid centred on the same physical point.

    Args:
        image: SimpleITK image to resample.
        out_spacing: voxel spacing of the output grid.
        out_size: number of voxels per axis of the output grid; when None it
            is chosen (rounded) so the output spans the same physical extent.
        is_label: use nearest-neighbour interpolation (for label maps/masks)
            instead of B-spline interpolation.
        pad_value: value given to output voxels that fall outside ``image``.

    Returns:
        The resampled SimpleITK image, sharing the input's direction matrix.
    """
    spacing_in = np.array(image.GetSpacing())
    size_in = np.array(image.GetSize())
    spacing_out = np.array(out_spacing)

    if out_size is not None:
        size_out = np.array(out_size)
    else:
        # Preserve the physical extent: size * spacing stays (approximately) constant
        size_out = np.round(np.array(size_in * spacing_in / spacing_out)).astype(int)

    # Offset from the origin voxel to the grid centre, rotated into world
    # coordinates by the direction matrix (shared by input and output)
    direction = np.array(image.GetDirection()).reshape(len(spacing_in), -1)
    centre_in = direction @ ((np.array(size_in, dtype=float) - 1.0) / 2.0 * spacing_in)
    centre_out = direction @ ((np.array(size_out, dtype=float) - 1.0) / 2.0 * spacing_out)

    # Shift the origin so both grids have the same physical centre
    origin_out = np.array(image.GetOrigin()) + (centre_in - centre_out)

    interpolator = sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline

    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputSpacing(out_spacing)
    resampler.SetSize(size_out.tolist())
    resampler.SetOutputDirection(image.GetDirection())
    resampler.SetOutputOrigin(origin_out.tolist())
    resampler.SetTransform(sitk.Transform())
    resampler.SetDefaultPixelValue(pad_value)
    resampler.SetInterpolator(interpolator)

    return resampler.Execute(image)

In [None]:
# Training split: one image path and one brain-mask path per subject, in metadata order
meta_data = pd.read_csv(data_dir + 'meta/meta_data_regression_train.csv')
ids = list(meta_data['subject_id'])
files_img = ['{}images/sub-{}_T1w_unbiased.nii.gz'.format(data_dir, sid) for sid in ids]
files_msk = ['{}masks/sub-{}_T1w_brain_mask.nii.gz'.format(data_dir, sid) for sid in ids]

In [None]:
def preproc(img):
    """Resample ``img`` to a 4 x 4 x 4 voxel spacing (in the image's physical units)."""
    return resample_image(img, out_spacing=[4, 4, 4])
# Try the pipeline on the subject loaded earlier: normalise within the mask, then resample
img_normalized = zero_mean_unit_var(image=img, mask=msk)
img_preproc = preproc(img=img_normalized)
display_image(img=img_preproc)

In [None]:
import os
from tqdm import tqdm

preproc_dir = './output/preproc/'

# exist_ok makes this a no-op when the cache directory already exists
os.makedirs(preproc_dir, exist_ok=True)

img_data = []  # one flattened preprocessed image per subject, in files_img order
img_size = []  # array shape shared by all preprocessed images (set in the loop)

for file_img, file_msk in tqdm(zip(files_img, files_msk), total=len(files_img), desc='Preprocessing'):

    # Cached result: same file name as the input image, stored in preproc_dir.
    # Built from the basename so it can never point back at the raw input file.
    preproc_filename = os.path.join(preproc_dir, os.path.basename(file_img))

    if os.path.exists(preproc_filename):
        im_preproc = sitk.ReadImage(preproc_filename)
    else:
        # Normalise inside the brain mask, resample, and cache to disk
        im = sitk.ReadImage(file_img)
        ma = sitk.ReadImage(file_msk)
        im_preproc = preproc(zero_mean_unit_var(im, ma))
        sitk.WriteImage(im_preproc, preproc_filename)

    img_array = sitk.GetArrayFromImage(im_preproc)
    # Stacking into one matrix (and reshaping PCA vectors back into images
    # later) requires every image to be on the same grid
    if img_data and img_array.shape != img_size:
        raise ValueError(
            f'Preprocessed image {preproc_filename} has shape {img_array.shape}, '
            f'expected {img_size} like the previous images')
    img_size = img_array.shape
    img_data.append(img_array.flatten())

img_data = np.array(img_data, dtype=float)

## Principal Component Analysis

In [None]:
X = img_data.T  # one column per subject, one row per voxel
m, n = X.shape
print(f'Image Size:\t{img_size}')
print(f'Dimension:\t{m}')
print(f'Samples:\t{n}')

In [None]:
import numpy as np
from sklearn import decomposition

pca = decomposition.PCA(whiten=False)

# scikit-learn expects (samples, features), so fit on X.T and transpose the scores back
X_prime = pca.fit_transform(X.T).T     # one column of mode coefficients per subject
mu = pca.mean_                         # mean image, flattened
U = pca.components_.T                  # one principal mode per column
D = pca.singular_values_**2 / (n - 1)  # variance along each mode
exp_var = pca.explained_variance_ratio_

print(f'Size before PCA: {X.shape}')
print(f'Size after PCA: {X_prime.shape}')

print(f'Size of U: {U.shape}')
print(f'Size of D: {D.shape}')

In [None]:
fig, ax = plt.subplots()
ax.plot(np.cumsum(exp_var))
ax.set(xlabel='Mode', ylabel='Retained Variance')
plt.show()

# First three modes, each scaled by its standard deviation sqrt(D[k])
mode1, mode2, mode3 = (U[:, k] * np.sqrt(D[k]) for k in range(3))

# Mean image followed by the three scaled modes
for vec in (mu, mode1, mode2, mode3):
    display_image(sitk.GetImageFromArray(vec.reshape(img_size)), colormap='gray')

In [None]:
modes = 100
orig = X[:, 0]  # first subject, flattened
# Reconstruct from the mean plus the leading `modes` principal components
recon = mu + U[:, :modes] @ X_prime[:modes, 0]

# Original, reconstruction, and their difference
for vec in (orig, recon, orig - recon):
    display_image(sitk.GetImageFromArray(vec.reshape(img_size)), colormap='gray')

In [None]:
from ipywidgets import interact, fixed

def plot_brain(mean_shape, modes, s1, s2, s3):
    """Display the mean image moved along the first three principal modes.

    Previously this ignored ``mean_shape`` and ``modes`` and read the global
    ``mu``/``U``; it now uses the arguments it is given.

    Args:
        mean_shape: flattened mean image.
        modes: 2-D array whose columns are the principal modes.
        s1, s2, s3: weights applied to mode columns 0, 1 and 2.

    The result is reshaped with the global ``img_size`` before display.
    """
    image = mean_shape + modes[:, 0] * s1 + modes[:, 1] * s2 + modes[:, 2] * s3
    display_image(sitk.GetImageFromArray(image.reshape(img_size)), colormap='gray')

def interactive_pca(mu,U,D):
    """Explore the first three PCA modes interactively with ipywidgets sliders.

    Slider i ranges over +/- 6 * sqrt(D[i]) in steps of sqrt(D[i]), i.e. the
    mode weight moves in whole standard deviations; each change redraws
    ``plot_brain`` with ``mu`` and ``U`` held fixed.
    """
    slider_ranges = {}
    for i in range(3):
        sd = np.sqrt(D[i])
        slider_ranges['s%d' % (i + 1)] = (-sd * 6, sd * 6, sd)
    interact(plot_brain, mean_shape=fixed(mu), modes=fixed(U), **slider_ranges)

interactive_pca(mu,U,D)