In [1]:
from __future__ import print_function
import pandas as pd
import os
import numpy as np
import nibabel as nib
import sys
import math
import random
import csv
import nipy
import seaborn as sns
from datetime import datetime
from dateutil import relativedelta

from scipy import ndimage as nd
import scipy.stats as stats
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
%matplotlib inline


import tensorflow as tf
from tensorflow.python.framework import ops

# -------------------  start importing keras module ---------------------
from tensorflow import keras
from tensorflow.keras.models import Sequential, model_from_json
from tensorflow.keras.layers import Dense, Activation, Flatten, Masking# Dropout
from tensorflow.keras.layers import Conv3D, MaxPooling3D
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import glorot_uniform
from tensorflow.keras.utils import CustomObjectScope
import h5py


2022-01-27 16:08:41.604604: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1


In [2]:
IMAGE_DIR = '/trinity/home/jyu/DeepVoxels/data/images/'
MASK_DIR = '/trinity/home/jyu/DeepVoxels/data/standards/'
WANG_DIR = '/trinity/home/jyu/DeepSurvival/Notebooks/'

In [3]:
class imgZeropad:
    """
    
    
    """
    def __init__(self, img, use_padding=False):
        self.set_crop(img, use_padding)
    
    #set crop locations
    def set_crop(self, img, use_padding=False):
        # argwhere will give you the coordinates of every non-zero point
        true_data = np.argwhere(img)
        # take the smallest points and use them as the top left of your crop
        top_left = true_data.min(axis=0)
        # take the largest points and use them as the bottom right of your crop
        bottom_right = true_data.max(axis=0)
        crop_indeces = [top_left, bottom_right+1]  # plus 1 because slice isn't inclusive

        print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                crop_indeces[0][1], crop_indeces[1][1], 
                                                                crop_indeces[0][2], crop_indeces[1][2]))

        if use_padding == True:
            shape = crop_indeces[1]-crop_indeces[0]
            bottom_net = shape.astype(float)/2/2**3
            top_net = np.ceil(bottom_net)*2*2**3
            padding = (top_net-shape)/2
            print('applying [{},{},{}] padding to image..'.format(padding[0], padding[1], padding[2]))
            padding_l = padding.astype(int)
            padding_r = np.ceil(padding).astype(int)
            crop_indeces[0] -= padding_l
            crop_indeces[1] += padding_r

            print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                    crop_indeces[0][1], crop_indeces[1][1], 
                                                                    crop_indeces[0][2], crop_indeces[1][2]))
        else:
            padding = np.zeros(3)
        self.crop_indeces = crop_indeces
        self.padding = padding
        
        shape = crop_indeces[1]-crop_indeces[0]
        self.img_size = (shape[0], shape[1], shape[2])

    #crop according to crop_indeces
    def zerocrop_img(self, img, augment=False):
        if augment:
            randx = np.random.rand(3)*2-1
            new_crop = self.crop_indeces+(self.padding*randx).astype(int)

            cropped_img = img[new_crop[0][0]:new_crop[1][0],  
                              new_crop[0][1]:new_crop[1][1],
                              new_crop[0][2]:new_crop[1][2]]

            flip_axis = np.random.rand(3)
            if round(flip_axis[0]):
                cropped_img = cropped_img[::-1,:,:]
            if round(flip_axis[1]):
                cropped_img = cropped_img[:,::-1,:]
            if round(flip_axis[2]):
                cropped_img = cropped_img[:,:,::-1]
                
        else:
            cropped_img = img[self.crop_indeces[0][0]:self.crop_indeces[1][0],  
                              self.crop_indeces[0][1]:self.crop_indeces[1][1],
                              self.crop_indeces[0][2]:self.crop_indeces[1][2]]
            
        return cropped_img


In [4]:
def retrieve_data(patient_index, img_size, img_scale=1.0, mask=None, augment=False, mode=[]):
    """
    Function to retrieve data from a single patient
    
    Inputs:
    - patient_index = list of bigrfullnames identifying scans
    - img_size = size of MRI images
    - img_scale = scale of the MRI scans [default = 1]
    - mask = mask image if necessary [default = None]
    - augment = Boolean if data augmentation should be used [default = False]
    - mode = train, validate or test (used to find appropriate data)
    
    Outputs:
    - img_data = MRI data
    - input2 = sex of patient
    - label = dementia_label (event=1, no event=0)
    - time = event time 

    """
    # Retrieve patient info and label(=SNP) of the patient
    if mode == 'train':
        patient_info = train_label_set.loc[patient_index]
    elif mode == 'validate':
        patient_info = validation_label_set.loc[patient_index]
    elif mode == 'test':
        patient_info = test_label_set.loc[patient_index]
    
    # Get patient label (incident dementia or not)
    label = patient_info.get('dementia')
    
    # Get second input (sex)
    input2 = patient_info.get('sex')    
    
    # Get event time
    time = patient_info.get('event_time')

    # Get image
    patient_filename = patient_index.strip()+'_GM_to_template_GM_mod.nii.gz'
    img = nib.load(IMAGE_DIR+patient_filename)
                 
        
    img_data = img.get_data()
    # Apply mask to imagedata (if requested)
    if mask is not None:
#        img_data = img_data*mask+(mask-1.0)
        img_data = img_data*mask
        img_data = zerocrop_img(img_data)

    
    # Rescale imagedata (if requested)
    if img_scale < 1.0:
        img_data = resize_img(img_data, img_size)
    
    return np.array(img_data), np.array(int(input2)), label, time 

In [5]:
def generate_batch(patients, img_size, img_scale=1.0, mask=None, augment=False, mode=[]):
    """
    iterate through a batch of patients and get the corresponding data
    
    Input: 
    - patients = list of bigrfullnames identifying scans
    - img_size = size of MRI images
    - img_scale = scale of the MRI scans [default = 1]
    - mask = mask image if necessary [default = None]
    - augment = Boolean if data augmentation should be used [default = False]
    - mode
    
    Outputs:
    - [input data] = covariates
    - [label data] = label (incident dementia or not) and riskset of patient

    """    
    #get data of each patient
    img_data = []
    label_data = []
    time = []
    sex = []
    for patient in patients:
        try:
            x, x2, y, t = retrieve_data(patient, img_size, img_scale, mask, augment, mode)
            img_data.append(x)
            sex.append(x2)
            label_data.append(y)
            time.append(t)
        except KeyError as e:
            print('\nERROR: No label found for file {}'.format(patient))
        except IOError as e:            
            print('\nERROR: Problem loading file {}. File probably corrupted.'.format(patient))
            
    # Make riskset
    label_riskset = generate_riskset(np.array(time))
    #convert to correct input format for network
    img_data = np.array(img_data)
    img_data = np.reshape(img_data,(-1, 160, 192, 144, 1))

    sex_data = np.array(sex)
    
    label_data = np.array([label_data])
    label_riskset = np.array(label_riskset)
    
    label_data = label_data.transpose()
    label_data_out = np.hstack((label_data,label_riskset))

    return ([img_data, sex_data], [label_data_out])

In [6]:
def data_generator(patient_list, img_size, batch_size, img_scale=1.0, mask=None, augment=False, mode=[], shuffle=True):
    """
    Provides the inputs and the label to the convolutional network during training
    
    Input:
    - patient_list = list of bigrfullnames identifying scans
    - img_size = size of MRI images
    - batch_size = size of batch used in training
    - img_scale = scale of the MRI scans [default = 1]
    - mask = mask image if necessary [default = None]
    - augment = Boolean if data augmentation should be used [default = False]
    
    Output:
    - Data = continous data output for batches used in training the network

    """
    while 1:
        if shuffle:
            #shuffle list/order of patients
            pl_shuffled = random.sample(patient_list, len(patient_list))
            #divide list of patients into batches
            batch_size = int(batch_size)
            patient_sublist = [pl_shuffled[p:p+batch_size] for p in range(0, len(pl_shuffled), batch_size)]
        else:
            batch_size = int(batch_size)
            patient_sublist = [patient_list[p:p+batch_size] for p in range(0, len(patient_list), batch_size)]
        count = 0
        data = []
        for batch in range(0, len(patient_sublist)):         
            #get the data of a batch samples/patients
            data.append(generate_batch(patient_sublist[batch], img_size, img_scale, mask, augment, mode))
            count = count + len(patient_sublist[batch])
            #yield the data and pop for memory clearing
            yield data.pop()

In [7]:
#crops the zero-margin of a 3D image (based on mask)
def zerocrop_img(img, set_crop=False, padding=False):
    global crop_indeces
    
    #set crop locations if there are none yet or if requested
    if (crop_indeces is None) or (set_crop):
        # argwhere will give you the coordinates of every non-zero point
        true_data = np.argwhere(img)
        # take the smallest points and use them as the top left of your crop
        top_left = true_data.min(axis=0)
        # take the largest points and use them as the bottom right of your crop
        bottom_right = true_data.max(axis=0)
        crop_indeces = [top_left, bottom_right+1]  # plus 1 because slice isn't inclusive
        
        print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                crop_indeces[0][1], crop_indeces[1][1], 
                                                                crop_indeces[0][2], crop_indeces[1][2]))

        if padding == True:
            shape = crop_indeces[1]-crop_indeces[0]
            bottom_unet = shape.astype(float)/2/2**3
            top_unet = np.ceil(bottom_unet)*2*2**3
            padding = (top_unet-shape)/2
            print('applying [{},{},{}] padding to image..'.format(padding[0], padding[1], padding[2]))
            padding_l = padding.astype(int)
            padding_r = np.ceil(padding).astype(int)
            crop_indeces[0] -= padding_l
            crop_indeces[1] += padding_r

            print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                    crop_indeces[0][1], crop_indeces[1][1], 
                                                                    crop_indeces[0][2], crop_indeces[1][2]))
    
    try:
        cropped_img = img[crop_indeces[0][0]:crop_indeces[1][0],  
                          crop_indeces[0][1]:crop_indeces[1][1],
                          crop_indeces[0][2]:crop_indeces[1][2]]
        return cropped_img
    except ValueError:
        print('ERROR: No crop_indeces defined for zerocrop. Returning full image...')
        return img

In [8]:
def generate_riskset(event_times):
    """
    Generates the riskset for every individual. Riskset is the set of individuals that have a 
    longer event time and are thus at risk of experiencing the event : Tj>=Ti
    
    Input:
    - label_data = dataframe with file name, event times and other labels that do not get used
    
    Output:
    - riskset = square matrix in which row i is the riskset of individual i compared to all 
    individuals j. Entry is true if Tj>=Ti, so individual j is 'at risk'.
    """

    o = np.argsort(-event_times, kind="mergesort")
    n_samples = len(event_times)
    risk_set = np.zeros((n_samples, n_samples), dtype=np.bool_)
    for i_org, i_sort in enumerate(o):
        ti = event_times[i_sort]
        k = i_org
        while k < n_samples and ti == event_times[o[k]]:
            k += 1
        risk_set[i_sort, o[:k]] = True
    return risk_set

In [15]:
# A small test
ages = []
birthdates = ['11/04/1930','15/05/1952','26/01/1943','18/06/1944']

for date in  birthdates:
    birthdate = datetime.strptime(date, '%d/%m/%Y')
    scandate = datetime.strptime('8/12/2020', '%d/%m/%Y')
    ages.append((scandate-birthdate).days/ 365.25)


test_set = ['ergomri_1604_v_1975212_1069','ergomri_783_mri_9973399_563','ergomri_1391_m_4009993_908','ergo5mri_1420_9319504_7837'] 
data = {'bigrfullname':  ['ergomri_1604_v_1975212_1069','ergomri_783_mri_9973399_563','ergomri_1391_m_4009993_908','ergo5mri_1420_9319504_7837'],
        'age': ages,
        'sex': [1,0,0,1]
        }

test_label_set = pd.DataFrame (data, columns = ['bigrfullname','age','sex','dementia','event_time'])
test_label_set = test_label_set.set_index('bigrfullname')

In [16]:
test_label_set

Unnamed: 0_level_0,age,sex,dementia,event_time
bigrfullname,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ergomri_1604_v_1975212_1069,90.661191,1,,
ergomri_783_mri_9973399_563,68.566735,0,,
ergomri_1391_m_4009993_908,77.867214,0,,
ergo5mri_1420_9319504_7837,76.473648,1,,


In [17]:
#initialize
mask_file = 'Brain_GM_mask_1mm_MNI_kNN_conservative.nii.gz' #None
#init_preprocessing(prepdata_name)
use_padding = True
crop_indeces = None
img_scale = 1

#initialize
#init_preprocessing(prepdata_name)

#setup image_size
if mask_file is not None:
    mask = nib.load(MASK_DIR+mask_file).get_data()
    #when applying a mask, initialize zerocropping
    img_size = np.array(np.array(zerocrop_img(mask, True, padding=use_padding)).shape)
else:
    mask = None
    img_size = np.array(np.array(nib.load(IMAGE_DIR+os.listdir(IMAGE_DIR)[0]).get_data()).shape)

img_size = [int(math.ceil(img_d)) for img_d in img_size*img_scale]
print('data shape:', img_size)


* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  del sys.path[0]


crop set to x[26:171], y[28:209], z[19:156]
applying [7.5,5.5,3.5] padding to image..
crop set to x[19:179], y[23:215], z[16:160]
data shape: [160, 192, 144]


In [18]:
model = tf.keras.models.load_model(WANG_DIR+'model_age_5h.h5')

In [24]:
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 160, 192, 14 0                                            
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, 80, 96, 72, 3 4032        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 80, 96, 72, 3 128         conv3d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 80, 96, 72, 3 0           batch_normalization_1[0][0]      
____________________________________________________________________________________________

In [19]:
batch_size=1

In [20]:
test_predictions = model.predict_generator(data_generator(list(test_set), img_size, batch_size, img_scale, mask=mask, mode='test',shuffle=False),
                                             steps=int(math.ceil(float(len(test_set))/batch_size)),
                                             max_queue_size=3, verbose=1)


* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0




In [21]:
results = {'Patient': test_set,
        'Real age':  ages,
        'Predicted age': test_predictions.tolist()        }

result_df = pd.DataFrame(results, columns = ['Patient','Real age','Predicted age'])

In [22]:
result_df

Unnamed: 0,Patient,Real age,Predicted age
0,ergomri_1604_v_1975212_1069,90.661191,[78.38008117675781]
1,ergomri_783_mri_9973399_563,68.566735,[68.54951477050781]
2,ergomri_1391_m_4009993_908,77.867214,[81.1343765258789]
3,ergo5mri_1420_9319504_7837,76.473648,[76.70735168457031]
