In [1]:
from __future__ import print_function
import pandas as pd
import os
import numpy as np
import nibabel as nib
import sys
import math
import random
import csv
import nipy
import seaborn as sns
from datetime import datetime
from dateutil import relativedelta
import gc

from scipy import ndimage as nd
import scipy.stats as stats
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import FormatStrFormatter
%matplotlib inline


import tensorflow as tf
from tensorflow.python.framework import ops

# -------------------  start importing keras module ---------------------
from tensorflow import keras
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.layers import Dense, Activation, Flatten, Conv3D, MaxPooling3D, BatchNormalization, Dropout, GlobalAveragePooling3D
from tensorflow.keras.layers import Input, concatenate, multiply, add, Reshape, Lambda
from tensorflow.keras import regularizers
from tensorflow.keras import backend as K

import h5py


2022-02-17 01:45:11.886102: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1


In [2]:
IMAGE_DIR = 'path to VBM image directory'
MASK_DIR = 'path to Mask image directory'
MODEL_DIR = 'path to load and save model/results'

## 1.Functions

### Save model and training history

In [3]:
# Track history
class LossHistory(keras.callbacks.Callback):
    def __init__(self, epochs, modelversion):
        self.ne = epochs
        self.mv = modelversion        
    
    def on_train_begin(self, logs={}):
        self.batch_num = 0
        self.batch_losses = []
        self.epoch_losses = []

        print('Start training ...')
        
        self.stats = ['loss'] #TODO: check
        self.logs = [{} for _ in range(self.ne)]

        self.evolution_file = 'evolution_'+self.mv+'.csv'
        with open(MODEL_DIR+self.evolution_file, "w") as f:
            f.write(';'.join(self.stats + ['val_'+s for s in self.stats]) + "\n")
        
        self.progress_file = 'training_progress_'+self.mv+'.out'
        with open(MODEL_DIR+self.progress_file, "w") as f:
            f.write('Start training ...\n')
            
    def on_batch_end(self, epoch, logs={}):
        self.batch_losses.append(logs.get('loss'))

        
        with open(MODEL_DIR+self.progress_file, "a") as f:
            f.write('  >> batch {} >> loss:{} \r'.format(self.batch_num, self.batch_losses[-1]))
        
        self.batch_num += 1
        
    def on_epoch_end(self, epoch, logs={}):
        self.batch_num = 0
        self.epoch_losses.append(logs.get('loss'))
        
        
#        print('\n    >>> logs:', logs)
        self.logs[epoch] = logs
#        evolution_file = 'evolution_'+self.mv+'.csv'
        loss_fig = 'loss_'+self.mv+'.png'
        
        with open(MODEL_DIR+self.evolution_file, "a") as myfile:
            num_stats = len(self.stats)
            
            plt.figure(figsize=(40, num_stats*15))
            plt.suptitle(loss_fig, fontsize=34, fontweight='bold')

            gs = gridspec.GridSpec(len(self.stats), 2) 

            last_losses = []
            last_val_losses = []
            for idx, stat in enumerate(self.stats):
                losses = [self.logs[e][stat] for e in range(epoch+1)]
                last_losses.append('{}'.format(losses[-1]))
                val_losses = [self.logs[e]['val_'+stat] for e in range(epoch+1)]
                last_val_losses.append('{}'.format(val_losses[-1]))

                plt.subplot(gs[idx,0])
                plt.ylabel(stat, fontsize=34)
                plt.plot(range(0, epoch+1), losses, '-', color = 'b')
                plt.plot(range(0, epoch+1), val_losses, '-', color = 'r')
                plt.tick_params(axis='x', labelsize=30)
                plt.tick_params(axis='y', labelsize=30)
                plt.grid(True)

                recent_n = 10
                recent_losses = losses[-recent_n:]
                recent_val_losses = val_losses[-recent_n:]
                miny_range = 5
                lowery = min([min(losses), recent_losses[-1]-miny_range, min(val_losses), recent_val_losses[-1]-miny_range])
                uppery = max([max(recent_losses), recent_losses[-1]+miny_range, max(recent_val_losses), recent_val_losses[-1]+miny_range])
                plt.subplot(gs[idx,1])
                plt.ylabel(stat, fontsize=34)
                plt.plot(range(0, epoch+1), losses, '-', color = 'b')
                plt.plot(range(0, epoch+1), val_losses, '-', color = 'r')
                plt.ylim(lowery, uppery)
                plt.tick_params(axis='x', labelsize=30)
                plt.tick_params(axis='y', labelsize=30)
                plt.grid(True)
                
            myfile.write(';'.join(last_losses + last_val_losses) + '\n')
            try:                
                plt.savefig(MODEL_DIR+loss_fig)
            except Exception as inst:
                print(type(inst))
                print(inst)
            plt.close()
        

        with open(MODEL_DIR+self.progress_file, "a") as f:
            f.write('epoch {}/{}:\n'.format(epoch, self.ne))
            for idx, stat in enumerate(self.stats):
                f.write('  {} = {}\n  val_{} = {}\n'.format(stat, last_losses[idx], stat, last_val_losses[idx]))

        gc.collect()


In [4]:
#save the best model on validation set
def save_checkpoint(name, model):
    save_model(name, model)
    weights_file = 'model_'+name+'.h5'
    return ModelCheckpoint(MODEL_DIR+weights_file, monitor='val_loss', verbose=0, save_best_only=True, mode='auto')

In [5]:
def save_model(name, model):
    model_file = 'model_'+name+'.json'
    # serialize model to JSON
    with open(MODEL_DIR+model_file, 'w') as json_file:
        json_file.write(model.to_json())
    print('Saved model to '+MODEL_DIR+model_file)

In [6]:
def save_history(name, history, score, sets, distrs):
    history_file = 'history_'+name+'.h5'

    f = h5py.File(MODEL_DIR+history_file, 'w')

    f.create_dataset('batch_losses', data=history.batch_losses)
    f.create_dataset('epoch_losses', data=history.epoch_losses)
    f.create_dataset("score", data=score)

    f.close()

    print('Saved history to '+MODEL_DIR+history_file)

### Image processing

In [7]:
class imgZeropad:

    def __init__(self, img, use_padding=False):
        self.set_crop(img, use_padding)
    
    #set crop locations
    def set_crop(self, img, use_padding=False):
        # argwhere will give you the coordinates of every non-zero point
        true_data = np.argwhere(img)
        # take the smallest points and use them as the top left of your crop
        top_left = true_data.min(axis=0)
        # take the largest points and use them as the bottom right of your crop
        bottom_right = true_data.max(axis=0)
        crop_indeces = [top_left, bottom_right+1]  # plus 1 because slice isn't inclusive

        print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                crop_indeces[0][1], crop_indeces[1][1], 
                                                                crop_indeces[0][2], crop_indeces[1][2]))

        if use_padding == True:
            shape = crop_indeces[1]-crop_indeces[0]
            bottom_net = shape.astype(float)/2/2**3
            top_net = np.ceil(bottom_net)*2*2**3
            padding = (top_net-shape)/2
            print('applying [{},{},{}] padding to image..'.format(padding[0], padding[1], padding[2]))
            padding_l = padding.astype(int)
            padding_r = np.ceil(padding).astype(int)
            crop_indeces[0] -= padding_l
            crop_indeces[1] += padding_r

            print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                    crop_indeces[0][1], crop_indeces[1][1], 
                                                                    crop_indeces[0][2], crop_indeces[1][2]))
        else:
            padding = np.zeros(3)
        self.crop_indeces = crop_indeces
        self.padding = padding
        
        shape = crop_indeces[1]-crop_indeces[0]
        self.img_size = (shape[0], shape[1], shape[2])

    #crop according to crop_indeces
    def zerocrop_img(self, img, augment=False):
        if augment:
            randx = np.random.rand(3)*2-1
            new_crop = self.crop_indeces+(self.padding*randx).astype(int)

            cropped_img = img[new_crop[0][0]:new_crop[1][0],  
                              new_crop[0][1]:new_crop[1][1],
                              new_crop[0][2]:new_crop[1][2]]

            flip_axis = np.random.rand(3)
            if round(flip_axis[0]):
                cropped_img = cropped_img[::-1,:,:]
            if round(flip_axis[1]):
                cropped_img = cropped_img[:,::-1,:]
            if round(flip_axis[2]):
                cropped_img = cropped_img[:,:,::-1]
                
        else:
            cropped_img = img[self.crop_indeces[0][0]:self.crop_indeces[1][0],  
                              self.crop_indeces[0][1]:self.crop_indeces[1][1],
                              self.crop_indeces[0][2]:self.crop_indeces[1][2]]
            
        return cropped_img


In [8]:
#crops the zero-margin of a 3D image (based on mask)
def zerocrop_img(img, set_crop=False, padding=False):
    global crop_indeces
    
    #set crop locations if there are none yet or if requested
    if (crop_indeces is None) or (set_crop):
        # argwhere will give you the coordinates of every non-zero point
        true_data = np.argwhere(img)
        # take the smallest points and use them as the top left of your crop
        top_left = true_data.min(axis=0)
        # take the largest points and use them as the bottom right of your crop
        bottom_right = true_data.max(axis=0)
        crop_indeces = [top_left, bottom_right+1]  # plus 1 because slice isn't inclusive
        
        print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                crop_indeces[0][1], crop_indeces[1][1], 
                                                                crop_indeces[0][2], crop_indeces[1][2]))

        if padding == True:
            shape = crop_indeces[1]-crop_indeces[0]
            bottom_unet = shape.astype(float)/2/2**3
            top_unet = np.ceil(bottom_unet)*2*2**3
            padding = (top_unet-shape)/2
            print('applying [{},{},{}] padding to image..'.format(padding[0], padding[1], padding[2]))
            padding_l = padding.astype(int)
            padding_r = np.ceil(padding).astype(int)
            crop_indeces[0] -= padding_l
            crop_indeces[1] += padding_r

            print('crop set to x[{}:{}], y[{}:{}], z[{}:{}]'.format(crop_indeces[0][0], crop_indeces[1][0], 
                                                                    crop_indeces[0][1], crop_indeces[1][1], 
                                                                    crop_indeces[0][2], crop_indeces[1][2]))
    
    try:
        cropped_img = img[crop_indeces[0][0]:crop_indeces[1][0],  
                          crop_indeces[0][1]:crop_indeces[1][1],
                          crop_indeces[0][2]:crop_indeces[1][2]]
        return cropped_img
    except ValueError:
        print('ERROR: No crop_indeces defined for zerocrop. Returning full image...')
        return img

### CNN image processing by batches

In [9]:
def retrieve_data(patient_index, img_size, img_scale=1.0, mask=None, augment=False, mode=[]):
    """
    Function to retrieve data from a single patient
    
    Inputs:
    - patient_index = list of bigrfullnames identifying scans
    - img_size = size of MRI images
    - img_scale = scale of the MRI scans [default = 1]
    - mask = mask image if necessary [default = None]
    - augment = Boolean if data augmentation should be used [default = False]
    - mode = train, validate or test (used to find appropriate data)
    
    Outputs:
    - img_data = MRI data
    - input2 = sex
    - label = age

    """
    # Retrieve patient info and label(=SNP) of the patient
    if mode == 'train':
        patient_info = train_label_set.loc[patient_index]
    elif mode == 'validate':
        patient_info = validation_label_set.loc[patient_index]
    elif mode == 'test':
        patient_info = test_label_set.loc[patient_index]
    else: # validation set might not use validation flag
        patient_info = validation_label_set.loc[patient_index]
    
    # Get patient label (incident dementia or not)
    label = patient_info.get('age')
    
    # Get second input (sex)
    input2 = patient_info.get('sex')    
    

    # Get image
    patient_filename = patient_index.strip()+'_GM_to_template_GM_mod.nii.gz'
    img = nib.load(IMAGE_DIR+patient_filename)  
    img_data = img.get_data()
    
    # Apply mask to imagedata (if requested)
    if mask is not None:
        img_data = img_data*mask
        img_data = zerocrop_img(img_data)

    # Rescale imagedata (if requested)
    if img_scale < 1.0:
        img_data = resize_img(img_data, img_size)
    
    return np.array(img_data), np.array(int(input2)), label

In [10]:
def generate_batch(patients, img_size, img_scale=1.0, mask=None, augment=False, mode=[]):
    """
    iterate through a batch of patients and get the corresponding data
    
    Input: 
    - patients = list of bigrfullnames identifying scans
    - img_size = size of MRI images
    - img_scale = scale of the MRI scans [default = 1]
    - mask = mask image if necessary [default = None]
    - augment = Boolean if data augmentation should be used [default = False]
    - mode
    
    Outputs:
    - [input data] = sex
    - [label data] = age

    """    
    #get data of each patient
    img_data = []
    label_data = []
    sex = []
    for patient in patients:
        try:
            x, x2, y = retrieve_data(patient, img_size, img_scale, mask, augment, mode)
            img_data.append(x)
            sex.append(x2)
            label_data.append(y)
        except KeyError as e:
            print('\nERROR: No label found for file {}'.format(patient))
        except IOError as e:            
            print('\nERROR: Problem loading file {}. File probably corrupted.'.format(patient))
            

    #convert to correct input format for network
    img_data = np.array(img_data)
    img_data = np.reshape(img_data,(-1, 160, 192, 144, 1))

    sex_data = np.array(sex)
    
    label_data = np.array([label_data])


    return ([img_data, sex_data], [label_data])

In [11]:
def data_generator(patient_list, img_size, batch_size, img_scale=1.0, mask=None, augment=False, mode=[], shuffle=True):
    """
    Provides the inputs and the label to the convolutional network during training
    
    Input:
    - patient_list = list of bigrfullnames identifying scans
    - img_size = size of MRI images
    - batch_size = size of batch used in training
    - img_scale = scale of the MRI scans [default = 1]
    - mask = mask image if necessary [default = None]
    - augment = Boolean if data augmentation should be used [default = False]
    
    Output:
    - Data = continous data output for batches used in training the network

    """
    while 1:
        if shuffle:
            #shuffle list/order of patients
            pl_shuffled = random.sample(patient_list, len(patient_list))
            #divide list of patients into batches
            batch_size = int(batch_size)
            patient_sublist = [pl_shuffled[p:p+batch_size] for p in range(0, len(pl_shuffled), batch_size)]
        else:
            batch_size = int(batch_size)
            patient_sublist = [patient_list[p:p+batch_size] for p in range(0, len(patient_list), batch_size)]
        count = 0
        data = []
        for batch in range(0, len(patient_sublist)):         
            #get the data of a batch samples/patients
            data.append(generate_batch(patient_sublist[batch], img_size, img_scale, mask, augment, mode))
            count = count + len(patient_sublist[batch])
            #yield the data and pop for memory clearing
            yield data.pop()

## 2. Prepare data

In [12]:
# A small test
ages = []
birthdates = ['11/04/1930','15/05/1952','26/01/1943','18/06/1944']

for date in  birthdates:
    birthdate = datetime.strptime(date, '%d/%m/%Y')
    scandate = datetime.strptime('8/12/2020', '%d/%m/%Y')
    ages.append((scandate-birthdate).days/ 365.25)


data_set = ['ergomri_1604_v_1975212_1069','ergomri_783_mri_9973399_563','ergomri_1391_m_4009993_908','ergo5mri_1420_9319504_7837'] 
data = {'bigrfullname':  ['ergomri_1604_v_1975212_1069','ergomri_783_mri_9973399_563','ergomri_1391_m_4009993_908','ergo5mri_1420_9319504_7837'],
        'age': ages,
        'sex': [1,0,0,1]
        }

data_label_set = pd.DataFrame (data, columns = ['bigrfullname','age','sex'])
data_label_set = data_label_set.set_index('bigrfullname')

In [13]:
data_set

['ergomri_1604_v_1975212_1069',
 'ergomri_783_mri_9973399_563',
 'ergomri_1391_m_4009993_908',
 'ergo5mri_1420_9319504_7837']

In [14]:
data_label_set

Unnamed: 0_level_0,age,sex
bigrfullname,Unnamed: 1_level_1,Unnamed: 2_level_1
ergomri_1604_v_1975212_1069,90.661191,1
ergomri_783_mri_9973399_563,68.566735,0
ergomri_1391_m_4009993_908,77.867214,0
ergo5mri_1420_9319504_7837,76.473648,1


In [15]:
#split to train/validation/test set

print('--- Preparing datasets---')

print('Keras backend: '+keras.backend.backend())

#label dataframe of patients (here we just use a informal train/validation/test set for convenience)
train_label_set = data_label_set
validation_label_set = data_label_set
test_label_set = data_label_set

#list of patients
train_set = data_set
validation_set = data_set
test_set = data_set

#print info per set
train_size = len(train_label_set)
print('train samples: {}'.format(train_size))

validation_size = len(validation_label_set)
print('validation samples: {}'.format(validation_size))

test_size = len(test_label_set)
print('test samples: {}'.format(test_size))


--- Preparing datasets---
Keras backend: tensorflow
train samples: 4
validation samples: 4
test samples: 4


## 3. CNN model

In [40]:
##Train model from scratch

# def cnn_model(input_shape):

#     #left input branch ------------------------
#     input1 = Input(input_shape)

#     c1 = Conv3D(32, kernel_size=(5,5,5), strides=(2,2,2), padding='same')(input1)
#     c1 = BatchNormalization()(c1)
#     c1 = Activation('relu')(c1)
    
#     c2 = Conv3D(32, (3,3,3), strides=(1,1,1), padding='same')(c1)
#     c2 = BatchNormalization()(c2)
#     c2 = Activation('relu')(c2)
#     p2 = MaxPooling3D(pool_size=(2, 2, 2))(c2)
    
#     c3 = Conv3D(48, (3,3,3), strides=(1,1,1), padding='same')(p2)
#     c3 = BatchNormalization()(c3)
#     c3 = Activation('relu')(c3)
    
#     c4 = Conv3D(48, (3,3,3), strides=(1,1,1), padding='same')(c3)
#     c4 = BatchNormalization()(c4)
#     c4 = Activation('relu')(c4)
#     p4 = MaxPooling3D(pool_size=(2, 2, 2))(c4)
    
#     c5 = Conv3D(64, (3,3,3), strides=(1,1,1), padding='same')(p4)
#     c5 = BatchNormalization()(c5)
#     c5 = Activation('relu')(c5)
    
#     c6 = Conv3D(64, (3,3,3), strides=(1,1,1), padding='same')(c5)
#     c6 = BatchNormalization()(c6)
#     c6 = Activation('relu')(c6)
#     p6 = MaxPooling3D(pool_size=(2, 2, 2))(c6)

#     c7 = Conv3D(80, (3,3,3), strides=(1,1,1), padding='same')(p6)
#     c7 = BatchNormalization()(c7)
#     c7 = Activation('relu')(c7)
    
#     c8 = Conv3D(80, (3,3,3), strides=(1,1,1), padding='same')(c7)
#     c8 = BatchNormalization()(c8)
#     c8 = Activation('relu')(c8)

#     x1 = GlobalAveragePooling3D()(c8)

#     #right input branch ------------------------
#     input2 = Input((1,))

#     #merging braches into final model ----------
#     y1 = concatenate([x1, input2])   # other modes: multiply, concatenate, dot
    
#     y2 = Dense(32, activation='relu')(y1)
#     y2 = Dropout(0.2)(y2)
    
#     final = Dense(1, activation='linear')(y2)

#     model = Model(inputs=[input1, input2], outputs=final)
    
#     adam_opt = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.0001)
#     model.compile(loss='mean_squared_error',
#                   optimizer=adam_opt,
#                   metrics=['mae', 'mse'])
    
#     return model

#model = cnn_model((160, 192, 144, 1))
#model.summary()

In [16]:
#Use pre-trained model: 'model_age_5h.h5' in models file (recommended)
model = tf.keras.models.load_model(MODEL_DIR+'model_age_5h.h5')
model.summary()

2022-02-17 01:51:43.175700: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /opt/ohpc/pub/easybuild/software/TensorFlow/2.2.0-fosscuda-2019b-Python-3.7.4/lib:/opt/ohpc/pub/easybuild/software/NCCL/2.4.8-gcccuda-2019b/lib:/opt/ohpc/pub/easybuild/software/cuDNN/7.6.4.38-gcccuda-2019b/lib64:/opt/ohpc/pub/easybuild/software/HDF5/1.10.5-gompic-2019b/lib:/opt/ohpc/pub/easybuild/software/Szip/2.1.1-GCCcore-8.3.0/lib:/opt/ohpc/pub/easybuild/software/SciPy-bundle/2019.10-fosscuda-2019b-Python-3.7.4/lib/python3.7/site-packages/numpy/core/lib:/opt/ohpc/pub/easybuild/software/SciPy-bundle/2019.10-fosscuda-2019b-Python-3.7.4/lib:/opt/ohpc/pub/easybuild/software/Python/3.7.4-GCCcore-8.3.0/lib:/opt/ohpc/pub/easybuild/software/libffi/3.2.1-GCCcore-8.3.0/lib64:/opt/ohpc/pub/easybuild/software/libffi/3.2.1-GCCcore-8.3.0/lib:/opt/ohpc/pub/easybui

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 160, 192, 14 0                                            
__________________________________________________________________________________________________
conv3d_1 (Conv3D)               (None, 80, 96, 72, 3 4032        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 80, 96, 72, 3 128         conv3d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 80, 96, 72, 3 0           batch_normalization_1[0][0]      
____________________________________________________________________________________________

##  

## 4. Initialize

In [17]:
print('--- Starting initialization ---')

#choose variables
use_padding = True
crop_indeces = None
augment_train = True
img_scale = 1.0

batch_size = 1
patients_per_epoch = 4 #steps_per_epoch = patients_per_epoch/batch_size
epochs = 4


mask_file = 'Brain_GM_mask_1mm_MNI_kNN_conservative.nii.gz' #None

#setup image_size
if mask_file is not None:
    mask = nib.load(MASK_DIR+mask_file).get_data()
    #when applying a mask, initialize zerocropping
    img_size = np.array(np.array(zerocrop_img(mask, True, padding=use_padding)).shape)
else:
    mask = None
    img_size = np.array(np.array(nib.load(IMAGE_DIR+os.listdir(IMAGE_DIR)[0]).get_data()).shape)

img_size = [int(math.ceil(img_d)) for img_d in img_size*img_scale]
print('data shape:', img_size)




--- Starting initialization ---



* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0


crop set to x[26:171], y[28:209], z[19:156]
applying [7.5,5.5,3.5] padding to image..
crop set to x[19:179], y[23:215], z[16:160]
data shape: [160, 192, 144]


## 5. Train model

In [18]:
#train the model and keep track of progress with history
print('--- Starting training of model ---')

modelversion='BrainAge'
history = LossHistory(epochs, modelversion)
checkpoint = save_checkpoint(modelversion, model)
# stoptraining = EarlyStopping(monitor='val_loss', min_delta=0, patience=200, verbose=0, mode='min')

patients_per_epoch = min(patients_per_epoch, train_size)
steps_per_epoch = int(math.ceil(float(patients_per_epoch)/batch_size))
validation_steps = int(math.ceil(float(validation_size)/batch_size))

model.fit_generator(data_generator(list(train_set), img_size, batch_size, img_scale, mask, augment=augment_train, mode='train'),
                    steps_per_epoch=steps_per_epoch,
                    epochs=epochs,
                    validation_data=data_generator(list(validation_set), img_size, batch_size, img_scale, mask),
                    validation_steps=validation_steps,
                    max_queue_size=1,
                    callbacks=[history, checkpoint])#, stoptraining])

print('Succesfully trained the model.')

--- Starting training of model ---
Saved model to /trinity/home/jyu/DeepSurvival/Notebooks/model_BrainAge.json
Instructions for updating:
Please use Model.fit, which supports generators.



* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0


Start training ...
Epoch 1/4
Epoch 2/4
Epoch 3/4
Epoch 4/4
Succesfully trained the model.


In [19]:
#prediction on test set
model = tf.keras.models.load_model(MODEL_DIR+'model_'+modelversion+'.h5')
test_predictions = model.predict_generator(data_generator(list(test_set), img_size, batch_size, img_scale, mask=mask, mode='test',shuffle=False),
                                             steps=int(math.ceil(float(len(test_set))/batch_size)),
                                             max_queue_size=3, verbose=1)

Instructions for updating:
Please use Model.predict, which supports generators.



* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0




In [20]:
results = {'Patient': test_set,
        'Real age':  ages,
        'Predicted age': test_predictions.tolist()        }

result_df = pd.DataFrame(results, columns = ['Patient','Real age','Predicted age'])

In [21]:
result_df

Unnamed: 0,Patient,Real age,Predicted age
0,ergomri_1604_v_1975212_1069,90.661191,[81.20193481445312]
1,ergomri_783_mri_9973399_563,68.566735,[70.59112548828125]
2,ergomri_1391_m_4009993_908,77.867214,[83.75657653808594]
3,ergo5mri_1420_9319504_7837,76.473648,[79.26387786865234]
