In [None]:
# Interactive matplotlib backend so the FuncAnimation slice viewer further down can redraw in place.
# NOTE(review): `%matplotlib notebook` targets the classic Notebook UI; JupyterLab / Notebook 7
# typically need `%matplotlib widget` (ipympl) instead -- confirm the target environment.
%matplotlib notebook

In [None]:
!pip install nibabel

In [None]:
# Standard libraries included in Python distribution
import os
import re
import random
import pickle

# Installed libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import nibabel as nib
import pandas as pd

### Section 1: Iterate over all images in the directory and tabulate their dimensions in a pandas DataFrame

**Methodology** :
1. Use the *Nibabel* package to load each image file and convert it to a NumPy array.
2. Obtain full path to file and then the shape of the image.

**Objectives** : 
1. Find out if all the images are of the same dimension.
2. Store information of all image files to be able to traverse easily.

In [None]:
# Define base paths where files are stored.
# The three paths (home, scratch, shared) were pickled in Step 0 and are unpacked here.
# NOTE: pickle.load can execute arbitrary code -- only load this locally created file.
pathsPickleFile = 'pickledHomeScratchShared.pickle'
with open(pathsPickleFile, "rb") as f:
    (baseHomePath, baseScratchPath, baseSharedPath) = pickle.load(f)

In [None]:
# Step 1: Traverse through the whole directory and build all paths if 'nifti.img' or 'anon.img' files are found
# Raw string: in a plain literal "\." is an invalid escape sequence
# (DeprecationWarning, and SyntaxWarning from Python 3.12 on).
imgFileRegex = re.compile(r"^.+nifti\.img$|^.+anon\.img$")

fileShapeList = list()
for root, dirs, files in os.walk("{}/data".format(baseScratchPath)):
    # Sorting in place makes os.walk's traversal (and so the DataFrame row order)
    # deterministic instead of depending on the filesystem's listing order.
    dirs.sort()
    for file in sorted(files):
        if imgFileRegex.match(file):
            fullFilePath = os.path.join(root, file)
            # Step 2: Open the image with Nibabel and record the shape of its data array
            img = nib.load(fullFilePath) # Note: This is a lazy load and does not load image into memory yet
            # The MRI ID is the name of the grandparent directory of the image file
            mriImgID = os.path.basename(os.path.dirname(root))
            fileShapeList.append((mriImgID, file, fullFilePath, img.shape))

# Step 3: Display in form of pandas dataframe that contains all sessions from each MRI visit for all patients
imageInfoDf_AllSessions = pd.DataFrame(fileShapeList,columns=['MRI_ID','File name','Full path','Shape of image'])
imageInfoDf_AllSessions

In [None]:
# How many distinct MRI IDs (visits) appear in the table?
uniqueMriIdCount = len(set(imageInfoDf_AllSessions['MRI_ID']))
print("Total unique MRI_ID seen in the DataFrame is : {}".format(uniqueMriIdCount))

# Do all images share a single shape? If not, they need further processing.
uniqueShapes = set(imageInfoDf_AllSessions['Shape of image'])
print('Confirming the unique list of shapes of MRI files in the dataset: ', uniqueShapes, sep="")

In both OASIS-1 and OASIS-2, the practitioners took 3-4 MRI sessions at each visit. In total we have 809 visits but 1894 sessions, because multiple sessions were taken per visit.  
Since it is computationally expensive to work with all 1894 sessions, we randomly choose **only 1 MRI session per visit** using the code below.

In [None]:
# Keep only 1 session from each MRI visit (visits have 3-4 sessions each).
# 809 rows expected in total.
# GroupBy.sample draws one row per MRI_ID and always keeps the MRI_ID column;
# groupby(...).apply(DataFrame.sample) is deprecated to exclude grouping columns
# in newer pandas, which would lose MRI_ID before the label merge.
# A fixed random_state makes the chosen sessions reproducible across runs.
SESSION_SAMPLE_SEED = 42
imageInfoDf = (
    imageInfoDf_AllSessions
    .groupby('MRI_ID')
    .sample(n=1, random_state=SESSION_SAMPLE_SEED)
    .reset_index(drop=True)
)
imageInfoDf

### Section 2 : Load image using Nibabel and display cross sections of image

In [None]:
# Pick one scan at random from imageInfoDf and open it with Nibabel.
# randrange(n) draws from the same range as randint(0, n-1).
randImgIndex = random.randrange(len(imageInfoDf))
samplePath = imageInfoDf.iloc[randImgIndex]['Full path']
sampleImg = nib.load(samplePath)
print('Sample picked out is : {}'.format(samplePath))
# Voxel data as a NumPy array
sampleImgData = sampleImg.get_fdata()

In [None]:
# Shape of the ndarray holding the sample's voxel data (shown as the cell output)
np.shape(sampleImgData)

The object above has 4 dimensions. The first three dimensions are those of the dimensions of the *'volume'* of the MRI (i.e. height, width and depth) and the 4th dimension is *time* as there are sometimes samples collected over different epochs while in the same MRI session. In our case, there is only one time recording in each session (i.e. 4th dimension is '1')

In [None]:
# Helper functions 

# Function to show a 2-D NumPy array as an image
def showImg(ndarr):
    """Draw `ndarr` with an inverted grey colormap and nearest-neighbour interpolation.

    Returns the AxesImage created by plt.imshow.
    """
    imshowKwargs = dict(cmap=plt.cm.gray_r, interpolation="nearest")
    return plt.imshow(ndarr, **imshowKwargs)

# Function that returns the number of 2-D slices ("frames") available for a plane of viewing
def retrieveFrames(planeOfViewing,sample):
    """Return how many slices `sample` has along the axis perpendicular to `planeOfViewing`.

    Axis convention matching the slicing used elsewhere in this notebook:
    Sagittal -> axis 2, Transverse -> axis 1, Coronal -> axis 0.

    Parameters
    ----------
    planeOfViewing : str
        One of "Sagittal", "Transverse", "Coronal".
    sample : array-like
        Anything with a ``shape`` attribute, e.g. the 4-D array from ``get_fdata()``.

    Returns
    -------
    int
        Size of the corresponding axis.

    Raises
    ------
    ValueError
        If `planeOfViewing` is not a supported plane (previously the function
        silently returned None, which only failed later and less clearly).
    """
    planeAxis = {"Sagittal": 2, "Transverse": 1, "Coronal": 0}
    if planeOfViewing not in planeAxis:
        raise ValueError("Unknown planeOfViewing {!r}; expected one of {}".format(
            planeOfViewing, sorted(planeAxis)))
    return sample.shape[planeAxis[planeOfViewing]]

In [None]:
# Static image: one slice of the sample taken at index 128 along axis 1
plt.close()  # start from a fresh figure

staticSliceIdx = 128
staticSlice = sampleImgData[:, staticSliceIdx, :, 0]
showImg(staticSlice);

### Planes of MRI 
<img src="../data/static/mri_planes_gnu.jpg" width="300" height="800">

Using the dimensions of the 4-D matrix, we can traverse each of the three planes (i.e. the first three dimensions, since the 4th dimension, time, has only one value).   
Below is an attempt to view the images through an interactive animation.

In [None]:
# Choose the plane whose cross sections will be viewed
planeOptions = ["Sagittal", "Coronal", "Transverse"]
planeOfViewing = planeOptions[2]  # index 2 -> "Transverse"
print('Plane chosen = {}'.format(planeOfViewing))

# Count the slices available along the chosen plane and report it
framesAvailable = retrieveFrames(planeOfViewing, sampleImgData)
print(framesAvailable)

In [None]:
# Retrieve the number of frames to iterate over
framesAvailable = retrieveFrames(planeOfViewing,sampleImgData)

# Set to True to play the slice-by-slice animation. This replaces the old
# triple-quoted "commented out" block, which Jupyter echoed as the cell output.
SHOW_ANIMATION = False

plt.ion() #Interactive mode set to ON
plt.close(); # Close any existing open plot

def planeSlice(frame):
    """Return 2-D slice number `frame` of `sampleImgData` for the chosen `planeOfViewing`."""
    if planeOfViewing == "Sagittal":
        return sampleImgData[:,:,frame,0]
    if planeOfViewing == "Transverse":
        return sampleImgData[:,frame,:,0]
    if planeOfViewing == "Coronal":
        return sampleImgData[frame,:,:,0]
    raise ValueError("Unknown planeOfViewing: {!r}".format(planeOfViewing))

# Initialise the image artist with the middle slice so the static figure is
# informative even without the animation; `animate` swaps in the other slices.
im = plt.imshow(planeSlice(framesAvailable // 2), cmap=plt.cm.gray_r, interpolation="nearest")

def animate(frame):
    """FuncAnimation callback: display slice `frame` and return the updated artists."""
    im.set_array(planeSlice(frame))
    # FuncAnimation expects an iterable of artists (mandatory when blit=True)
    return (im,)

if SHOW_ANIMATION:
    # Keep a reference in `anim`: an unreferenced animation is garbage-collected and stops.
    anim = FuncAnimation(plt.gcf(), animate, frames=framesAvailable, interval=10, blit=False, repeat=False)
    plt.show()


### Section 3 : Stitch all images of a single MRI onto a 2-D plane

In [None]:
# Function to choose the frame (2-D slice) used by the stitch_image functions
def img_frame(mri_arr,frame,planeOfViewing,switchColor=False):
    """Return 2-D slice number `frame` of a 4-D image array for a plane of viewing.

    Parameters
    ----------
    mri_arr : numpy.ndarray
        Image volume indexed as mri_arr[axis0, axis1, axis2, t]; only t == 0 is used.
    frame : int
        Slice index along the axis perpendicular to the plane
        (Sagittal -> axis 2, Transverse -> axis 1, Coronal -> axis 0).
    planeOfViewing : str
        One of "Sagittal", "Transverse", "Coronal".
    switchColor : bool, optional
        If true, return ``slice.max() - slice`` (intensities inverted).

    Returns
    -------
    numpy.ndarray
        The 2-D slice; a view into `mri_arr` unless `switchColor` is set.

    Raises
    ------
    ValueError
        If `planeOfViewing` is not supported (previously an UnboundLocalError).
    """
    if planeOfViewing == "Sagittal":
        returnImg = mri_arr[:, :, frame, 0]
    elif planeOfViewing == "Transverse":
        returnImg = mri_arr[:, frame, :, 0]
    elif planeOfViewing == "Coronal":
        returnImg = mri_arr[frame, :, :, 0]
    else:
        raise ValueError("Unknown planeOfViewing: {!r}".format(planeOfViewing))
    if switchColor:
        returnImg = returnImg.max() - returnImg
    return returnImg


# Functions to stitch the 2-D frames of one MRI plane into a single mosaic image
def _stitch_frames(mri_arr, planeOfViewing, frameIndices, switchColor, framesPerRow):
    """Tile the given frames of `mri_arr` into rows of `framesPerRow` frames each.

    Frames come from `img_frame` and are placed left-to-right, then top-to-bottom.
    The frame size is read from the data rather than hard-coded, so any volume
    shape works (the previous code assumed 256 x 256 x 128 volumes).
    Frames that cannot fill a complete final row are dropped, as before.

    Returns a float64 array of shape (nFullRows * frameHeight, framesPerRow * frameWidth);
    if not even one row can be filled, an empty (0, framesPerRow * frameWidth) array.
    Assumes the plane has at least one frame.
    """
    if framesPerRow < 1:
        raise ValueError("framesPerRow must be >= 1, got {}".format(framesPerRow))
    frameIndices = list(frameIndices)
    nFullRows = len(frameIndices) // framesPerRow
    if nFullRows == 0:
        # Nothing to tile: mirror the original empty result, sized from one frame.
        frameWidth = img_frame(mri_arr, 0, planeOfViewing).shape[1]
        return np.empty((0, frameWidth * framesPerRow))
    frames = [img_frame(mri_arr, f, planeOfViewing, switchColor)
              for f in frameIndices[:nFullRows * framesPerRow]]
    # One hstack per row and one vstack overall, instead of re-copying a growing
    # array for every frame (quadratic copying).
    rows = [np.hstack(frames[r * framesPerRow:(r + 1) * framesPerRow]) for r in range(nFullRows)]
    # The original accumulated into float64 np.empty arrays; keep that dtype.
    return np.vstack(rows).astype(np.float64, copy=False)


def stitch_image_allimgs(mri_arr,planeOfViewing,framesPerRow=16):
    """Stitch every frame of `planeOfViewing` from a 4-D `mri_arr` into one 2-D mosaic.

    Parameters
    ----------
    mri_arr : numpy.ndarray
        4-D image volume, e.g. from ``get_fdata()``.
    planeOfViewing : str
        One of "Sagittal", "Transverse", "Coronal".
    framesPerRow : int, optional
        Number of frames placed side by side in each mosaic row (default 16).

    Returns
    -------
    numpy.ndarray
        The stitched float64 mosaic (see `_stitch_frames`).
    """
    n_frames = retrieveFrames(planeOfViewing, mri_arr)  # Number of frames in the MRI
    return _stitch_frames(mri_arr, planeOfViewing, range(n_frames), False, framesPerRow)


def stitch_image_skipframes(mri_arr,planeOfViewing,skipFrames=0,switchColor=False,framesPerRow=16):
    """Stitch the frames of `planeOfViewing`, skipping `skipFrames` at each end.

    Frames ``skipFrames .. n_frames - skipFrames - 1`` are used, so e.g. 256 frames
    with skipFrames=120 leave the middle 16.

    Parameters
    ----------
    mri_arr : numpy.ndarray
        4-D image volume, e.g. from ``get_fdata()``.
    planeOfViewing : str
        One of "Sagittal", "Transverse", "Coronal".
    skipFrames : int, optional
        Frames dropped from the start and from the end (default 0).
    switchColor : bool, optional
        Invert each frame's intensities (see `img_frame`).
    framesPerRow : int, optional
        Number of frames placed side by side in each mosaic row (default 16).

    Returns
    -------
    numpy.ndarray
        The stitched float64 mosaic (see `_stitch_frames`).
    """
    n_frames = retrieveFrames(planeOfViewing, mri_arr)  # Number of frames in the MRI
    keptFrames = range(skipFrames, n_frames - skipFrames)
    return _stitch_frames(mri_arr, planeOfViewing, keptFrames, switchColor, framesPerRow)

# Function for displaying stitched images
def plot_stitched_img(stitched_img):
    """Display a stitched mosaic (e.g. the output of stitch_image_skipframes).

    Closes the current figure first, then draws the image on a new 50 x 30 inch
    figure with the inverted grey colormap. Returns None.
    """
    figSize = (50, 30)
    plt.close()
    plt.figure(figsize=figSize)
    plt.imshow(stitched_img, interpolation="nearest", cmap=plt.cm.gray_r)
    plt.show()

### Section 4 : Pull in CDR labels, transform to binary and then join to `imageInfoDf`

In [None]:
# Read in files containing labels (OASIS-1 as CSV, OASIS-2 as Excel)
labelDir = '{}/milestone_II_project/data/oasis_labelled_data'.format(baseHomePath)
oasis_1_datatable = pd.read_csv('{}/oasis_1_labelled_data.csv'.format(labelDir))
oasis_2_datatable = pd.read_excel('{}/oasis_2_labelled_data.xlsx'.format(labelDir))

In [None]:
# Normalize content from both datatables into a single style:
# one row per MRI session with columns ['MRI_ID', 'dem_labels'],
# where dem_labels = 0 if CDR == 0 (not demented) else 1.

# OASIS-1: 436 rows expected. Missing CDR is filled with 0 -- big assumption that 'NA' == not demented.
oasis1DataSet = (
    oasis_1_datatable
    .rename(columns={'ID': 'MRI_ID'})
    .assign(dem_labels=lambda d: d['CDR'].fillna(0).ne(0).astype('int64'))
    .loc[:, ['MRI_ID', 'dem_labels']]
)

# OASIS-2: 373 rows expected.
# NOTE(review): CDR is NOT filled here, so a missing CDR maps to 1 (demented), unlike
# OASIS-1. Kept to preserve the original labels -- confirm OASIS-2 has no missing CDR.
oasis2DataSet = (
    oasis_2_datatable
    .rename(columns={'MRI ID': 'MRI_ID'})
    .assign(dem_labels=lambda d: d['CDR'].ne(0).astype('int64'))
    .loc[:, ['MRI_ID', 'dem_labels']]
)

# Concatenate the OASIS-1 and OASIS-2 labels: 436 + 373 = 809 rows expected
oasis_1_2_dataset = pd.concat([oasis1DataSet, oasis2DataSet], axis=0, ignore_index=True)

# Link imageInfoDf (image metadata) to oasis_1_2_dataset (labels) to create `oasisMasterDf`.
# validate='many_to_one' fails loudly if an MRI_ID appears more than once in the labels
# (which would otherwise silently duplicate image rows). The inner join drops images
# without a label, so report how many were lost.
oasisMasterDf = imageInfoDf.merge(oasis_1_2_dataset, on='MRI_ID', how='inner', validate='many_to_one')
nUnlabelled = len(imageInfoDf) - len(oasisMasterDf)
if nUnlabelled:
    print('WARNING: {} image(s) have no matching CDR label and were dropped'.format(nUnlabelled))
display(oasisMasterDf)

### Section 5 : Create arrays and serialize to files using pickle

#### 5.2 - Create numpy `ndarray` of all the labels in the same order as it is in `oasisMasterDf`

#### Also, optionally creating a list of the `MRI_ID`, in case this is necessary for any downstream tasks.

In [None]:
# Labels and MRI IDs as NumPy arrays, row-aligned with oasisMasterDf
all_labels_allImgs_809 = oasisMasterDf['dem_labels'].to_numpy()
all_mri_id_allImgs_809 = oasisMasterDf['MRI_ID'].to_numpy()

for alignedArr in (all_labels_allImgs_809, all_mri_id_allImgs_809):
    print(alignedArr.shape)

#### 5.3 - Serialize and store each of `all_labels` and `all_mri_id`

In [None]:
# Serialize the label and MRI-ID arrays to the shared directory
for pickleName, pickleObj in (("all_labels_allImgs_809", all_labels_allImgs_809),
                              ("all_mri_id_allImgs_809", all_mri_id_allImgs_809)):
    with open("{}/{}.pickle".format(baseSharedPath, pickleName), "wb") as f:
        pickle.dump(pickleObj, f)

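### Section 6 : Create stitched `Transverse`, `Sagittal` and `Coronal` arrays, skipping frames at each end so only the middle 16 slices remain (`skipFrames=120` for Transverse/Coronal: `256 - 120 - 120 = 16`; `skipFrames=56` for Sagittal: `128 - 56 - 56 = 16`)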

#### 6.1 - Create numpy `ndarray` of all the images in the same order as it is in `oasisMasterDf`.
#### Here, we are keeping only the middle 16 slices

In [None]:
%%time

# Convert image path to numpy ndarray as this is the one we will be using.
ImgPathAsNumpy = oasisMasterDf['Full path'].to_numpy()

# Vectorize step 1: Convert the function nib.load to a vectorized function and apply
vec_func_nib = np.vectorize(nib.load)
file_nib_to_np = vec_func_nib(ImgPathAsNumpy)
# Vectorize step 2: Convert the function .get_fdata() class function of the file_nib_to_np objects to a vectorized function and apply
vec_func_mri = np.vectorize(nib.spm2analyze.Spm2AnalyzeImage.get_fdata,otypes=[np.ndarray])
mri_data = vec_func_mri(file_nib_to_np)
# Vectorize step 3: Convert the function stitch_image_allimgs to a vectorized function and apply
vec_func_stitch = np.vectorize(stitch_image_skipframes,otypes=[np.ndarray])
skip_120_stitched_imgs_t_all = vec_func_stitch(mri_data,planeOfViewing = 'Transverse',skipFrames=120)
skip_56_stitched_imgs_s_all = vec_func_stitch(mri_data,planeOfViewing = 'Sagittal',skipFrames=56)
skip_120_stitched_imgs_c_all = vec_func_stitch(mri_data,planeOfViewing = 'Coronal',skipFrames=120)

In [None]:
# Serialize the three stitched-image arrays to the shared directory
stitchedOutputs = (
    ("skip_120_stitched_imgs_t_all_809", skip_120_stitched_imgs_t_all),
    ("skip_56_stitched_imgs_s_all_809", skip_56_stitched_imgs_s_all),
    ("skip_120_stitched_imgs_c_all_809", skip_120_stitched_imgs_c_all),
)
for pickleName, pickleObj in stitchedOutputs:
    with open("{}/{}.pickle".format(baseSharedPath, pickleName), "wb") as f:
        pickle.dump(pickleObj, f)

In [None]:
# Shape of each stitched-image container (Transverse, Sagittal, Coronal)
for stitchedArr in (skip_120_stitched_imgs_t_all, skip_56_stitched_imgs_s_all, skip_120_stitched_imgs_c_all):
    print(stitchedArr.shape)