In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import glob
import nibabel as nib
import os
import time

import pandas as pd
import numpy as np
import cv2
from skimage.transform import resize

from mricode.utils import log_textfile, createPath, data_generator
from mricode.utils import copy_colab
from mricode.utils import return_iter
from mricode.utils import return_csv
from mricode.config import config

from mricode.models.DenseNet_NoDict_cross import MyDenseNet

import tensorflow as tf
from tensorflow.keras.layers import Conv3D
from tensorflow import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.utils import conv_utils

tf.__version__

















'2.0.0'

In [3]:
tf.test.is_gpu_available()

True

In [4]:
path_output = './output/'
path_tfrecords = '/data2/res64/down/'
path_csv = '/data2/csv/'
filename_res = {'train': 'intell_residual_train.csv', 'val': 'intell_residual_valid.csv', 'test': 'intell_residual_test.csv'}
filename_final = filename_res
sample_size = 'site16_allimages'
batch_size = 8
onlyt1 = False
Model = MyDenseNet
versionkey = 'down64' #down256, cropped128, cropped64, down64
modelname = 'new2_densenet_cross_allimages_' + versionkey



In [5]:
createPath(path_output + modelname)

In [6]:
train_df, val_df, test_df, norm_dict = return_csv(path_csv, filename_final, False)

In [7]:
train_iter = config[versionkey]['iter_train']
val_iter = config[versionkey]['iter_val']
test_iter = config[versionkey]['iter_test']
t1_mean = config[versionkey]['norm']['t1'][0]
t1_std= config[versionkey]['norm']['t1'][1]
t2_mean=config[versionkey]['norm']['t2'][0]
t2_std=config[versionkey]['norm']['t2'][1]
ad_mean=config[versionkey]['norm']['ad'][0]
ad_std=config[versionkey]['norm']['ad'][1]
fa_mean=config[versionkey]['norm']['fa'][0]
fa_std=config[versionkey]['norm']['fa'][1]
md_mean=config[versionkey]['norm']['md'][0]
md_std=config[versionkey]['norm']['md'][1]
rd_mean=config[versionkey]['norm']['rd'][0]
rd_std=config[versionkey]['norm']['rd'][1]

In [8]:
norm_dict

{'BMI': {'mean': 18.681548127052135, 'std': 4.193043131845343},
 'age': {'mean': 119.00325844623563, 'std': 7.479129774017182},
 'height': {'mean': 55.25173666322929, 'std': 3.152756181679028},
 'nihtbx_cardsort_uncorrected': {'mean': 0.01902727147573316,
  'std': 0.92710655806542},
 'nihtbx_cryst_uncorrected': {'mean': 0.007018628014748754,
  'std': 0.7845373584638602},
 'nihtbx_flanker_uncorrected': {'mean': 0.02188780794049048,
  'std': 0.9070917080607726},
 'nihtbx_fluidcomp_uncorrected': {'mean': 0.020178243913565427,
  'std': 0.86606123778624},
 'nihtbx_list_uncorrected': {'mean': 0.00625176016120734,
  'std': 0.8898550695735616},
 'nihtbx_pattern_uncorrected': {'mean': 0.020721412569885096,
  'std': 0.9486618556882954},
 'nihtbx_picture_uncorrected': {'mean': 0.0005782175223803825,
  'std': 0.9577989703304521},
 'nihtbx_picvocab_uncorrected': {'mean': 0.0068509109986280275,
  'std': 0.8038465951211212},
 'nihtbx_reading_uncorrected': {'mean': 0.0027847502208883574,
  'std': 0.85

In [9]:
cat_cols = {'female': 2, 'race.ethnicity': 5, 'high.educ_group': 4, 'income_group': 8, 'married': 6}
num_cols = [x for x in list(val_df.columns) if '_norm' in x]

In [10]:
def calc_loss_acc(out_loss, out_acc, y_true, y_pred, cat_cols, num_cols, norm_dict):
  for col in num_cols:
    tmp_col = col
    tmp_std = norm_dict[tmp_col.replace('_norm','')]['std']
    tmp_y_true = tf.cast(y_true[col], tf.float32).numpy()
    tmp_y_pred = np.squeeze(y_pred[col].numpy())
    if not(tmp_col in out_loss):
      out_loss[tmp_col] = np.sum(np.square(tmp_y_true-tmp_y_pred))
    else:
      out_loss[tmp_col] += np.sum(np.square(tmp_y_true-tmp_y_pred))
    if not(tmp_col in out_acc):
      out_acc[tmp_col] = np.sum(np.square((tmp_y_true-tmp_y_pred)*tmp_std))
    else:
      out_acc[tmp_col] += np.sum(np.square((tmp_y_true-tmp_y_pred)*tmp_std))
  for col in list(cat_cols.keys()):
    tmp_col = col
    if not(tmp_col in out_loss):
      out_loss[tmp_col] = tf.keras.losses.SparseCategoricalCrossentropy()(tf.squeeze(y_true[col]), tf.squeeze(y_pred[col])).numpy()
    else:
      out_loss[tmp_col] += tf.keras.losses.SparseCategoricalCrossentropy()(tf.squeeze(y_true[col]), tf.squeeze(y_pred[col])).numpy()
    if not(tmp_col in out_acc):
      out_acc[tmp_col] = tf.reduce_sum(tf.dtypes.cast((y_true[col] == tf.argmax(y_pred[col], axis=-1)), tf.float32)).numpy()
    else:
      out_acc[tmp_col] += tf.reduce_sum(tf.dtypes.cast((y_true[col] == tf.argmax(y_pred[col], axis=-1)), tf.float32)).numpy()    
  return(out_loss, out_acc)

def format_output(out_loss, out_acc, n, cols, print_bl=False):
  loss = 0
  acc = 0
  output = []
  for col in cols:
    output.append([col, out_loss[col]/n, out_acc[col]/n])
    loss += out_loss[col]/n
    acc += out_acc[col]/n
  df = pd.DataFrame(output)
  df.columns = ['name', 'loss', 'acc']
  if print_bl:
    print(df)
  return(loss, acc, df)

@tf.function
def train_step(X, y, model, optimizer, cat_cols, num_cols):
  with tf.GradientTape() as tape:
    predictions = model(X)
    i = 0
    loss = tf.keras.losses.MSE(tf.cast(y[num_cols[i]], tf.float32), tf.squeeze(predictions[num_cols[i]]))
    for i in range(1,len(num_cols)):
      loss += tf.keras.losses.MSE(tf.cast(y[num_cols[i]], tf.float32), tf.squeeze(predictions[num_cols[i]]))
    for col in list(cat_cols.keys()):
      loss += tf.keras.losses.SparseCategoricalCrossentropy()(tf.squeeze(y[col]), tf.squeeze(predictions[col]))
  gradients = tape.gradient(loss, model.trainable_variables)
  mean_std = [x.name for x in model.non_trainable_variables if ('batch_norm') in x.name and ('mean' in x.name or 'variance' in x.name)]
  with tf.control_dependencies(mean_std):
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return(y, predictions, loss)

@tf.function
def test_step(X, y, model):
  predictions = model(X)
  return(y, predictions)

def epoch(data_iter, df, model, optimizer, cat_cols, num_cols, norm_dict):
  out_loss = {}
  out_acc = {}
  n = 0.
  n_batch = 0.
  total_time_dataload = 0.
  total_time_model = 0.
  start_time = time.time()
  for batch in data_iter:
    total_time_dataload += time.time() - start_time
    start_time = time.time()
    t1 = (tf.cast(batch['t1'], tf.float32)-t1_mean)/t1_std
    t2 = (batch['t2']-t2_mean)/t2_std
    if False:
        ad = batch['ad']
        ad = tf.where(tf.math.is_nan(ad), tf.zeros_like(ad), ad)
        ad = (ad-ad_mean)/ad_std
        fa = batch['fa']
        fa = tf.where(tf.math.is_nan(fa), tf.zeros_like(fa), fa)
        fa = (fa-fa_mean)/fa_std
        md = batch['md']
        md = tf.where(tf.math.is_nan(md), tf.zeros_like(md), md)
        md = (md-md_mean)/md_std
        rd = batch['rd']
        rd = tf.where(tf.math.is_nan(rd), tf.zeros_like(rd), rd)
        rd = (rd-rd_mean)/rd_std
    subjectid = decoder(batch['subjectid'])
    y = get_labels(df, subjectid, list(cat_cols.keys())+num_cols)
    #X = tf.concat([t1], axis=4)
    X = tf.concat([t1, t2], axis=4)
    if optimizer != None:
      y_true, y_pred, loss = train_step(X, y, model, optimizer, cat_cols, num_cols)
    else:
      y_true, y_pred = test_step(X, y, model)
    out_loss, out_acc = calc_loss_acc(out_loss, out_acc, y_true, y_pred, cat_cols, num_cols, norm_dict)
    n += X.shape[0]
    n_batch += 1
    if (n_batch % 10) == 0:
      log_textfile(path_output + modelname + '/log' + '.log', str(n_batch))
    total_time_model += time.time() - start_time
    start_time = time.time()
  return (out_loss, out_acc, n, total_time_model, total_time_dataload)

def get_labels(df, subjectid, cols = ['nihtbx_fluidcomp_uncorrected_norm']):
  subjects_df = pd.DataFrame(subjectid)
  result_df = pd.merge(subjects_df, df, left_on=0, right_on='subjectkey', how='left')
  output = {}
  for col in cols:
    output[col] = np.asarray(result_df[col].values)
  return output

def best_val(df_best, df_val, df_test, e):
  df_best = pd.merge(df_best, df_val, how='left', left_on='name', right_on='name')
  df_best = pd.merge(df_best, df_test, how='left', left_on='name', right_on='name')
  df_best.loc[df_best['best_loss_val']>=df_best['cur_loss_val'], 'best_loss_epochs'] = e
  df_best.loc[(df_best['best_acc_val']<=df_best['cur_acc_val'])&(df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'best_acc_epochs'] = e
  df_best.loc[df_best['best_loss_val']>=df_best['cur_loss_val'], 'best_loss_test'] = df_best.loc[df_best['best_loss_val']>=df_best['cur_loss_val'], 'cur_loss_test']
  df_best.loc[df_best['best_loss_val']>=df_best['cur_loss_val'], 'best_loss_val'] = df_best.loc[df_best['best_loss_val']>=df_best['cur_loss_val'], 'cur_loss_val']
  df_best.loc[(df_best['best_acc_val']<=df_best['cur_acc_val'])&(df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'best_acc_test'] = df_best.loc[(df_best['best_acc_val']<=df_best['cur_acc_val'])&(df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'cur_acc_test']
  df_best.loc[(df_best['best_acc_val']<=df_best['cur_acc_val'])&(df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'best_acc_val'] = df_best.loc[(df_best['best_acc_val']<=df_best['cur_acc_val'])&(df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'cur_acc_val']
  df_best.loc[(df_best['best_acc_val']>=df_best['cur_acc_val'])&(~df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'best_acc_test'] = df_best.loc[(df_best['best_acc_val']>=df_best['cur_acc_val'])&(~df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'cur_acc_test']
  df_best.loc[(df_best['best_acc_val']>=df_best['cur_acc_val'])&(~df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'best_acc_val'] = df_best.loc[(df_best['best_acc_val']>=df_best['cur_acc_val'])&(~df_best['name'].isin(['female', 'race.ethnicity', 'high.educ_group', 'income_group', 'married'])), 'cur_acc_val']
  df_best = df_best.drop(['cur_loss_val', 'cur_acc_val', 'cur_loss_test', 'cur_acc_test'], axis=1)
  return(df_best)

In [11]:
decoder = np.vectorize(lambda x: x.decode('UTF-8'))
template = 'Epoch {0}, Loss: {1:.3f}, Accuracy: {2:.3f}, Val Loss: {3:.3f}, Val Accuracy: {4:.3f}, Time Model: {5:.3f}, Time Data: {6:.3f}'
for col in [0]:
  log_textfile(path_output + modelname + '/log' + '.log', cat_cols)
  log_textfile(path_output + modelname + '/log' + '.log', num_cols)
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
  optimizer = tf.keras.optimizers.Adam(lr = 0.001)
  model = Model(cat_cols, num_cols)
  df_best = None
  for e in range(20):
    log_textfile(path_output + modelname + '/log' + '.log', 'Epochs: ' + str(e))
    loss = tf.Variable(0.)
    acc = tf.Variable(0.) 
    val_loss = tf.Variable(0.)
    val_acc = tf.Variable(0.)
    test_loss = tf.Variable(0.)
    test_acc = tf.Variable(0.)
    tf.keras.backend.set_learning_phase(True)
    train_out_loss, train_out_acc, n, time_model, time_data = epoch(train_iter, train_df, model, optimizer, cat_cols, num_cols, norm_dict)
    tf.keras.backend.set_learning_phase(False)
    val_out_loss, val_out_acc, n, _, _ = epoch(val_iter, val_df, model, None, cat_cols, num_cols, norm_dict)
    test_out_loss, test_out_acc, n, _, _ = epoch(test_iter, test_df, model, None, cat_cols, num_cols, norm_dict)
    loss, acc, _ = format_output(train_out_loss, train_out_acc, n, list(cat_cols.keys())+num_cols)
    val_loss, val_acc, df_val = format_output(val_out_loss, val_out_acc, n, list(cat_cols.keys())+num_cols, print_bl=False)
    test_loss, test_acc, df_test = format_output(test_out_loss, test_out_acc, n, list(cat_cols.keys())+num_cols, print_bl=False)
    df_val.columns = ['name', 'cur_loss_val', 'cur_acc_val']
    df_test.columns = ['name', 'cur_loss_test', 'cur_acc_test']
    if e == 0:
      df_best = pd.merge(df_test, df_val, how='left', left_on='name', right_on='name')
      df_best['best_acc_epochs'] = 0
      df_best['best_loss_epochs'] = 0
      df_best.columns = ['name', 'best_loss_test', 'best_acc_test', 'best_loss_val', 'best_acc_val', 'best_acc_epochs', 'best_loss_epochs']
    df_best = best_val(df_best, df_val, df_test, e)
    print(df_best[['name', 'best_loss_test', 'best_acc_test']])
    print(df_best[['name', 'best_loss_val', 'best_acc_val']])
    log_textfile(path_output + modelname + '/log' + '.log', template.format(e, loss, acc, val_loss, val_acc, time_model, time_data))
    if e in [10, 15]:
      optimizer.lr = optimizer.lr/3
      log_textfile(path_output + modelname + '/log' + '.log', 'Learning rate: ' + str(optimizer.lr))
    df_best.to_csv(path_output +  modelname + '/df_best' + str(e) + '.csv')
    df_best.to_csv(path_output +  modelname + '/df_best' + '.csv')
    #model.save_weights(path_output + modelname + '/checkpoints/' + str(e) + '/')

{'income_group': 8, 'race.ethnicity': 5, 'female': 2, 'high.educ_group': 4, 'married': 6}
['BMI_norm', 'age_norm', 'vol_norm', 'weight_norm', 'height_norm', 'nihtbx_fluidcomp_uncorrected_norm', 'nihtbx_cryst_uncorrected_norm', 'nihtbx_pattern_uncorrected_norm', 'nihtbx_picture_uncorrected_norm', 'nihtbx_list_uncorrected_norm', 'nihtbx_flanker_uncorrected_norm', 'nihtbx_picvocab_uncorrected_norm', 'nihtbx_cardsort_uncorrected_norm', 'nihtbx_totalcomp_uncorrected_norm', 'nihtbx_reading_uncorrected_norm']


Epochs: 0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.211743       0.535629
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.534027
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm       14.111100     248.095802
6                            age_norm        2.852541     159.563751
7                            vol_norm        0.586413       0.588408
8                         weight_norm        8.212801    4467.010809
9                         height_norm        1.529852      15.206543
10  nihtbx_fluidcomp_uncorrected_norm        1.312983       0.984819
11      nihtbx_cryst_uncorrected_norm        1.056363       0.650190
12    nihtbx_pattern_uncorrected_norm        1.013499       0.912108
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.211743       0.535629
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.534027
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        5.083577      89.377471
6                            age_norm        2.852541     159.563751
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.284370     154.671012
9                         height_norm        0.903699       8.982648
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.056363       0.650190
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.534027
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        0.357179       6.279790
6                            age_norm        2.852541     159.563751
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.284370     154.671012
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.056363       0.650190
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.534027
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        0.293343       5.157434
6                            age_norm        2.852541     159.563751
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.056363       0.650190
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.534027
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        0.293343       5.157434
6                            age_norm        2.852541     159.563751
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.056363       0.650190
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.535629
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        0.293343       5.157434
6                            age_norm        2.852541     159.563751
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.056363       0.650190
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.535629
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        0.293343       5.157434
6                            age_norm        0.974896      54.533095
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.053927       0.648691
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.550040
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        0.293343       5.157434
6                            age_norm        0.974896      54.533095
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.053927       0.648691
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.550040
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.707766
5                            BMI_norm        0.293343       5.157434
6                            age_norm        0.974896      54.533095
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.053927       0.648691
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.397458       0.290633
2                      race.ethnicity        0.337834       0.620496
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.726181
5                            BMI_norm        0.293343       5.157434
6                            age_norm        0.974896      54.533095
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.053927       0.648691
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.392899       0.296237
2                      race.ethnicity        0.337834       0.620496
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.726181
5                            BMI_norm        0.257790       4.532355
6                            age_norm        0.974896      54.533095
7                            vol_norm        0.496049       0.497737
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.053927       0.648691
12    nihtbx_pattern_uncorrected_norm        1.001303       0.901133
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.392899       0.296237
2                      race.ethnicity        0.337834       0.620496
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.726181
5                            BMI_norm        0.257790       4.532355
6                            age_norm        0.974896      54.533095
7                            vol_norm        0.267672       0.268583
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.053927       0.648691
12    nihtbx_pattern_uncorrected_norm        0.986437       0.887753
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


230.0


240.0


250.0


260.0


270.0


280.0


290.0


300.0


310.0


320.0


330.0


340.0


350.0


360.0


370.0


380.0


390.0


400.0


410.0


420.0


430.0


440.0


450.0


460.0


470.0


480.0


490.0


500.0


510.0


520.0


530.0


540.0


550.0


560.0


570.0


580.0


590.0


600.0


610.0


620.0


630.0


640.0


650.0


660.0


670.0


680.0


690.0


700.0


710.0


720.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


                                 name  best_loss_test  best_acc_test
0                              female        0.083594       0.690152
1                        income_group        0.392899       0.296237
2                      race.ethnicity        0.337834       0.620496
3                     high.educ_group        0.079121       0.874299
4                             married        0.135511       0.726181
5                            BMI_norm        0.257790       4.532355
6                            age_norm        0.974896      54.533095
7                            vol_norm        0.267672       0.268583
8                         weight_norm        0.189348     102.988040
9                         height_norm        0.671672       6.676336
10  nihtbx_fluidcomp_uncorrected_norm        1.021275       0.766020
11      nihtbx_cryst_uncorrected_norm        1.053927       0.648691
12    nihtbx_pattern_uncorrected_norm        0.986437       0.887753
13    nihtbx_picture_uncorrected_n

10.0


20.0


30.0


40.0


50.0


60.0


70.0


80.0


90.0


100.0


110.0


120.0


130.0


140.0


150.0


160.0


170.0


180.0


190.0


200.0


210.0


220.0


In [None]:
error

In [None]:
test_loss, test_acc, df_test = format_output(test_out_loss, test_out_acc, n, list(cat_cols.keys())+num_cols, print_bl=False)

In [None]:
df_test.to_csv('final_output_all.csv')

In [None]:
inputs = tf.keras.Input(shape=(64,64,64,2), name='inputlayer123')
a = model(inputs)['female']
mm = tf.keras.models.Model(inputs=inputs, outputs=a)

In [None]:
from tf_explain.core.smoothgrad import SmoothGrad
import pickle

explainer = SmoothGrad()
output_grid = {}
output_n = {}
for i in range(2):
  output_grid[i] = np.zeros((64,64,64))
  output_n[i] = 0
counter = 0
for batch in test_iter:
  counter+=1
  print(counter)
  t1 = (tf.cast(batch['t1'], tf.float32)-t1_mean)/t1_std
  t2 = (batch['t2']-t2_mean)/t2_std
  X = tf.concat([t1, t2], axis=4)
  subjectid = decoder(batch['subjectid'])
  y = get_labels(test_df, subjectid, list(cat_cols.keys())+num_cols)
  y_list = list(y['female'])
  for i in range(X.shape[0]):
    X_i = X[i]
    X_i = tf.expand_dims(X_i, axis=0)
    y_i = y_list[i]
    grid = explainer.explain((X_i, _), mm, y_i, 20, 1.)
    output_grid[y_i] += grid
    output_n[y_i] += 1

In [None]:
pickle.dump([output_grid, output_n], open( "smoothgrad_female_all.p", "wb" ) )

In [None]:
#output_grid, output_n = pickle.load(open( "smoothgrad_female.p", "rb" ))

In [None]:
def apply_grey_patch(image, top_left_x, top_left_y, top_left_z, patch_size):
    """
    Replace a part of the image with a grey patch.
    Args:
        image (numpy.ndarray): Input image
        top_left_x (int): Top Left X position of the applied box
        top_left_y (int): Top Left Y position of the applied box
        patch_size (int): Size of patch to apply
    Returns:
        numpy.ndarray: Patched image
    """
    patched_image = np.array(image, copy=True)
    patched_image[
        top_left_x : top_left_x + patch_size, top_left_y : top_left_y + patch_size, top_left_z : top_left_z + patch_size, :
    ] = 0

    return patched_image

In [None]:
import math

In [None]:
def get_sensgrid(image, mm, class_index, patch_size):
  sensitivity_map = np.zeros((
    math.ceil(image.shape[0] / patch_size),
    math.ceil(image.shape[1] / patch_size),
    math.ceil(image.shape[2] / patch_size)
  ))
  for index_z, top_left_z in enumerate(range(0, image.shape[2], patch_size)):
    patches = [
              apply_grey_patch(image, top_left_x, top_left_y, top_left_z, patch_size)
              for index_x, top_left_x in enumerate(range(0, image.shape[0], patch_size))
              for index_y, top_left_y in enumerate(range(0, image.shape[1], patch_size))
              ]
    coordinates = [
                (index_y, index_x)
                for index_x, _ in enumerate(range(0, image.shape[0], patch_size))
                for index_y, _ in enumerate(range(0, image.shape[1], patch_size))
                ]
    predictions = mm.predict(np.array(patches), batch_size=1)
    target_class_predictions = [prediction[class_index] for prediction in predictions]
    for (index_y, index_x), confidence in zip(coordinates, target_class_predictions):
      sensitivity_map[index_y, index_x, index_z] = 1 - confidence
  sm = resize(sensitivity_map, (64,64,64))
  heatmap = (sm - np.min(sm)) / (sm.max() - sm.min())
  return(heatmap)

In [None]:
output_grid = {}
output_n = {}
for i in range(2):
  output_grid[i] = np.zeros((64,64,64))
  output_n[i] = 0

counter = 0
for batch in test_iter:
  counter+=1
  print(counter)
  t1 = (tf.cast(batch['t1'], tf.float32)-t1_mean)/t1_std
  t2 = (batch['t2']-t2_mean)/t2_std
  X = tf.concat([t1, t2], axis=4)
  subjectid = decoder(batch['subjectid'])
  y = get_labels(test_df, subjectid, list(cat_cols.keys())+num_cols)
  y_list = list(y['female'])
  for i in range(X.shape[0]):
    print(i)
    X_i = X[i]
    y_i = y_list[i]
    grid = get_sensgrid(X_i, mm, y_i, 4)
    output_grid[y_i] += grid
    output_n[y_i] += 1
  if counter==6:
    break

In [None]:
pickle.dump([output_grid, output_n], open( "heatmap_female_all.p", "wb" ) )

In [None]:
error

In [None]:
batch = next(iter(train_iter))

In [None]:
t1 = (tf.cast(batch['t1'], tf.float32)-t1_mean)/t1_std
t2 = (batch['t2']-t2_mean)/t2_std
ad = batch['ad']
ad = tf.where(tf.math.is_nan(ad), tf.zeros_like(ad), ad)
ad = (ad-ad_mean)/ad_std
fa = batch['fa']
fa = tf.where(tf.math.is_nan(fa), tf.zeros_like(fa), fa)
fa = (fa-fa_mean)/fa_std
md = batch['md']
md = tf.where(tf.math.is_nan(md), tf.zeros_like(md), md)
md = (md-md_mean)/md_std
rd = batch['rd']
rd = tf.where(tf.math.is_nan(rd), tf.zeros_like(rd), rd)
rd = (rd-rd_mean)/rd_std
#subjectid = decoder(batch['subjectid'])
#y = get_labels(df, subjectid, list(cat_cols.keys())+num_cols)
#X = tf.concat([t1, t2, ad, fa, md, rd], axis=4)
X = tf.concat([t1, t2], axis=4)    

In [None]:
tf.keras.backend.set_learning_phase(True)
model(X)['female']

In [None]:
tf.keras.backend.set_learning_phase(False)
model(X)['female']

In [None]:
mean_std = [x.name for x in model.non_trainable_variables if ('batch_norm') in x.name and ('mean' in x.name or 'variance' in x.name)]

In [None]:
model = Model(cat_cols, num_cols)

In [None]:
model.non_trainable_variables 