In [1]:
import afqinsight.nn.tf_models as nn
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from afqinsight.datasets import AFQDataset
from afqinsight.nn.tf_models import cnn_lenet, mlp4, cnn_vgg, lstm1v0, lstm1, lstm2, blstm1, blstm2, lstm_fcn, cnn_resnet
from sklearn.impute import SimpleImputer
import os.path
# Harmonization
from sklearn.model_selection import train_test_split
from neurocombat_sklearn import CombatModel
import pandas as pd
from sklearn.utils import shuffle, resample
from afqinsight.augmentation import jitter, time_warp, scaling
import tempfile

In [2]:
# Build the AFQ dataset from the node-level tract-profile CSV and the
# participant CSV, both joined on "subject_id". Three DKI metrics are loaded
# (dki_fa, dki_md, dki_mk). The target columns are requested in the order
# age, dl_qc_score, scan_site_id; later cells index the targets by that
# position. scan_site_id is label-encoded, so sites appear as integer codes.
afq_dataset = AFQDataset.from_files(
    fn_nodes="../data/raw/combined_tract_profiles.csv",
    fn_subjects="../data/raw/participants_updated_id.csv",
    dwi_metrics=["dki_fa", "dki_md", "dki_mk"],
    index_col="subject_id",
    target_cols=["age", "dl_qc_score", "scan_site_id"],
    label_encode_cols=["scan_site_id"]
)

In [3]:
# Drop entries whose target values are missing (per the method name). The
# return value is discarded, so this presumably mutates afq_dataset in place
# -- TODO confirm against the afqinsight documentation.
afq_dataset.drop_target_na()

In [4]:
# Sanity check after dropping missing targets: number of subjects, then the
# feature-matrix shape, then the target-matrix shape (one value per line).
print(
    len(afq_dataset.subjects),
    afq_dataset.X.shape,
    afq_dataset.y.shape,
    sep="\n",
)

1865
(1865, 7200)
(1865, 3)


In [5]:
# Materialize the whole TensorFlow dataset in memory as a list of numpy items.
# The next cell reads item[0] and item[1][k] from every entry (presumably the
# features and the per-subject target vector -- TODO confirm).
full_dataset = list(afq_dataset.as_tensorflow_dataset().as_numpy_iterator())

2022-05-17 22:34:23.755187: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [6]:
# Stack the per-subject feature arrays along a new leading (subject) axis.
X = np.stack([sample[0] for sample in full_dataset], axis=0)
# Pull the three target columns out by position. The positions presumably
# follow target_cols from the loading cell: 0 -> age, 1 -> dl_qc_score,
# 2 -> scan_site_id.
y, qc, site = (
    np.array([sample[1][column] for sample in full_dataset])
    for column in (0, 1, 2)
)

In [7]:
# Keep only the entries with a strictly positive QC score.
# NOTE: qc itself is left unfiltered, so after this cell it no longer lines up
# with X / y / site, and re-running the cell on a live kernel will fail.
X, y, site = (values[qc > 0] for values in (X, y, site))

In [8]:
# Upper bound on the number of training epochs passed to every model.fit call;
# early stopping (below) normally ends training sooner.
n_epochs = 1000

# EarlyStopping
# Stop training once the validation loss has not improved by at least
# min_delta (0.001) for `patience` (100) epochs.
# NOTE: this single callback object is reused by every fit in cross_site.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    mode="min",
    patience=100
)

# ReduceLROnPlateau
# Multiply the learning rate by `factor` (0.5) whenever the validation loss
# has not improved for `patience` (20) epochs; verbose=1 logs each reduction.
# This object is likewise shared by every fit in cross_site.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=20,
    verbose=1,
)

In [9]:
def augment_this(X, y, rounds=2):
    """Extend X and y with `rounds` augmented copies of X.

    Each round derives one new copy from the *original* X (rounds are not
    chained). Every channel along the last axis is processed on its own and
    passed through jitter, scaling and time_warp, in that order. The sigma
    handed to each transform is the mean of the data at that stage divided
    by 25. The labels y are repeated unchanged for every augmented copy.

    Parameters
    ----------
    X : ndarray
        Input data; the last axis indexes channels.
    y : ndarray
        Targets aligned with the first axis of X.
    rounds : int, default 2
        Number of augmented copies to append.

    Returns
    -------
    (new_X, new_y)
        The originals followed by the augmented copies, joined along axis 0.
        With no rounds, plain slices of the inputs are returned.
    """
    X_parts = [X[:]]
    y_parts = [y[:]]
    for _ in range(rounds):
        augmented = np.zeros_like(X)
        # Do each channel separately:
        for channel in range(augmented.shape[-1]):
            # Slice keeps a trailing length-1 channel axis for the transforms.
            profile = X[..., channel:channel + 1]
            for transform in (jitter, scaling, time_warp):
                profile = transform(profile, sigma=np.mean(profile) / 25)
            augmented[..., channel] = profile[..., 0]
        X_parts.append(augmented)
        y_parts.append(y)

    if len(X_parts) == 1:
        # Nothing was augmented: hand back the plain slices.
        return X_parts[0], y_parts[0]
    return np.concatenate(X_parts), np.concatenate(y_parts)

In [10]:
# Shared median imputer; impute() re-fits it on every call and every channel.
imputer = SimpleImputer(strategy="median")
def impute(X_data):
    """Median-impute missing values, one channel (last axis) at a time.

    For each index along the last axis, the 2-D slice X_data[..., ii] is passed
    to the module-level imputer's fit_transform, so the medians always come
    from the array given here, not from any earlier call. The imputed slices
    are then stacked back along a new last axis.
    """
    filled_channels = []
    for ii in range(X_data.shape[-1]):
        filled_channels.append(imputer.fit_transform(X_data[..., ii]))
    return np.stack(filled_channels, axis=-1)

In [11]:
# Two-in-one test
# (single-site models and a combined two-site model, all scored on one site)
def _fit_best_model(model_name, lr, X_train, y_train, log_filename):
    """Build, compile and fit one regression model, then restore its best weights.

    The model is trained with MSE loss and Adam(learning_rate=lr) for at most
    `n_epochs` epochs, holding out 20% of the training data for validation.
    The weights with the lowest validation loss are checkpointed to a
    temporary directory, reloaded after training, and the directory is then
    removed so no checkpoint files are left behind.

    Relies on the notebook-level `n_epochs`, `early_stopping` and `reduce_lr`.
    """
    model = model_name(input_shape=(100, 72), n_classes=1,
                       output_activation=None, verbose=True)
    model.compile(loss='mean_squared_error',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  metrics=['mean_squared_error',
                           tf.keras.metrics.RootMeanSquaredError(name='rmse'),
                           'mean_absolute_error'])

    with tempfile.TemporaryDirectory() as ckpt_dir:
        ckpt_filepath = os.path.join(ckpt_dir, 'best_weights.h5')
        ckpt = tf.keras.callbacks.ModelCheckpoint(
            filepath=ckpt_filepath,
            monitor="val_loss",
            verbose=1,
            save_best_only=True,
            save_weights_only=True,
            mode="auto",
        )
        # append=True: repeated runs with the same file name accumulate in it.
        log = tf.keras.callbacks.CSVLogger(filename=log_filename, append=True)
        model.fit(X_train, y_train, epochs=n_epochs, batch_size=128,
                  validation_split=0.2,
                  callbacks=[early_stopping, ckpt, reduce_lr, log])
        # Must happen before the temporary directory is cleaned up.
        model.load_weights(ckpt_filepath)
    return model


def _test_metrics(model, X_test, y_test):
    """Return [MSE, RMSE, MAE, squared Pearson correlation] on the test set."""
    y_predict = model.predict(X_test).reshape(y_test.shape)
    coef = np.corrcoef(y_test, y_predict)[0, 1] ** 2
    # evaluate() returns [loss, mean_squared_error, rmse, mean_absolute_error],
    # following the metric order given to compile().
    evaluation = model.evaluate(X_test, y_test)
    return [evaluation[1], evaluation[2], evaluation[3], coef]


def cross_site(model_name, name_str, lr, site_1, site_2, site_3, X, y):
    """Train single-site and two-site models and score them all on site_1.

    Four models are trained and then evaluated on a held-out 20% of site_1:
      1. trained on the remaining 80% of site_1,
      2. trained on 80% of site_2,
      3. trained on 80% of site_3,
      4. trained on an augmented mix of site_2 and site_3. From each site,
         len(test set) // 2 training samples are drawn without replacement;
         the mix is then augmented with `augment_this`.

    Parameters
    ----------
    model_name : callable
        Model factory (e.g. cnn_resnet) called with input_shape, n_classes,
        output_activation and verbose.
    name_str : str
        Label written to the 'Model' column and used as the prefix of the
        per-model CSV training logs (name_str + '1.csv' ... '4.csv').
    lr : float
        Adam learning rate.
    site_1, site_2, site_3
        Site codes; site_1 is the test site.
    X, y : ndarray
        Features and regression targets, aligned with the notebook-level
        `site` array.

    Returns
    -------
    pandas.DataFrame
        16 rows (4 models x 4 metrics) with columns Model, Train_site,
        Test_site, Metric and Value.

    Notes
    -----
    Relies on the notebook-level names `site`, `impute`, `augment_this`,
    `n_epochs`, `early_stopping` and `reduce_lr`. The splits, resampling and
    shuffling are unseeded, so results differ between runs.
    """
    # Split the data by sites
    X_1, y_1 = X[site == site_1], y[site == site_1]
    X_2, y_2 = X[site == site_2], y[site == site_2]
    X_3, y_3 = X[site == site_3], y[site == site_3]
    # Split the data into train and test sets (only site_1's test part is used):
    X_train1, X_test, y_train1, y_test = train_test_split(X_1, y_1, test_size=0.2)
    X_train2, _, y_train2, _ = train_test_split(X_2, y_2, test_size=0.2)
    X_train3, _, y_train3, _ = train_test_split(X_3, y_3, test_size=0.2)
    # Imputation (each array is imputed with its own medians)
    X_train1 = impute(X_train1)
    X_train2 = impute(X_train2)
    X_train3 = impute(X_train3)
    X_test = impute(X_test)

    # Single_site: one model per training site
    model1 = _fit_best_model(model_name, lr, X_train1, y_train1, name_str + '1.csv')
    model2 = _fit_best_model(model_name, lr, X_train2, y_train2, name_str + '2.csv')
    model3 = _fit_best_model(model_name, lr, X_train3, y_train3, name_str + '3.csv')

    # Double cross site
    # Training on site 2 and 3
    sample = y_test.shape[0] // 2
    sample1 = resample(X_train2, y_train2, n_samples=sample, replace=False)
    sample2 = resample(X_train3, y_train3, n_samples=sample, replace=False)
    X_train4 = np.concatenate((sample1[0], sample2[0]), axis=0)
    y_train4 = np.concatenate((sample1[1], sample2[1]), axis=0)

    X_train4, y_train4 = shuffle(X_train4, y_train4)
    X_train4, y_train4 = augment_this(X_train4, y_train4)
    X_train4, y_train4 = shuffle(X_train4, y_train4)

    model4 = _fit_best_model(model_name, lr, X_train4, y_train4, name_str + '4.csv')

    # Testing on site 1
    # Each model reports its own metrics. The original copy-pasted table used
    # eval_3[3] for model 4, so it showed model 3's MAE in model 4's row.
    values = []
    for model in (model1, model2, model3, model4):
        values.extend(_test_metrics(model, X_test, y_test))

    # Results
    result = {'Model': [name_str] * 16,
              'Train_site': [site_1] * 4 + [site_2] * 4 + [site_3] * 4 + [f'{site_2}, {site_3}'] * 4,
              'Test_site': [site_1] * 16,
              'Metric': ['MSE', 'RMSE', 'MAE', 'coef'] * 4,
              'Value': values}
    df = pd.DataFrame(result)
    return df

In [12]:
# Run the cross-site experiment three times, rotating which site is held out
# for testing (0, 3, then 4); the other two sites train the comparison models.
# NOTE: all three runs pass the same name_str, so their training logs are
# appended to the same 'cnn_resnet<k>.csv' files.
df_resnet1 = cross_site(cnn_resnet, 'cnn_resnet', 0.01, 0, 3, 4, X, y)
df_resnet2 = cross_site(cnn_resnet, 'cnn_resnet', 0.01, 3, 0, 4, X, y)
df_resnet3 = cross_site(cnn_resnet, 'cnn_resnet', 0.01, 4, 0, 3, X, y)
# Stack the three result tables row-wise. concat keeps every row in run order;
# an outer merge would join on all columns, including the float metric values.
df_resnet = pd.concat([df_resnet1, df_resnet2, df_resnet3], ignore_index=True)

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 100, 72)]    0           []                               
                                                                                                  
 conv1d (Conv1D)                (None, 100, 64)      36928       ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 100, 64)     256         ['conv1d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 activation (Activation)        (None, 100, 64)      0           ['batch_normalization[0][0]']

Epoch 00001: val_loss improved from inf to 5030200.00000, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 2/1000
Epoch 00002: val_loss did not improve from 5030200.00000
Epoch 3/1000
Epoch 00003: val_loss did not improve from 5030200.00000
Epoch 4/1000
Epoch 00004: val_loss did not improve from 5030200.00000
Epoch 5/1000
Epoch 00005: val_loss did not improve from 5030200.00000
Epoch 6/1000
Epoch 00006: val_loss did not improve from 5030200.00000
Epoch 7/1000
Epoch 00007: val_loss did not improve from 5030200.00000
Epoch 8/1000
Epoch 00008: val_loss did not improve from 5030200.00000
Epoch 9/1000
Epoch 00009: val_loss did not improve from 5030200.00000
Epoch 10/1000
Epoch 00010: val_loss improved from 5030200.00000 to 2685270.25000, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 11/1000
Epoch 00011: val_loss improved from 2685270.25000 to 889697.56250, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr000

Epoch 17/1000
Epoch 00017: val_loss improved from 25619.22656 to 12702.78418, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 18/1000
Epoch 00018: val_loss improved from 12702.78418 to 2001.38525, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 19/1000
Epoch 00019: val_loss did not improve from 2001.38525
Epoch 20/1000
Epoch 00020: val_loss did not improve from 2001.38525
Epoch 21/1000
Epoch 00021: val_loss improved from 2001.38525 to 1497.15491, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 22/1000
Epoch 00022: val_loss improved from 1497.15491 to 661.95288, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 23/1000
Epoch 00023: val_loss improved from 661.95288 to 111.14220, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 24/1000
Epoch 00024: val_loss improved from 111.14220 to 42.62692, saving mode

Epoch 00033: val_loss did not improve from 10.64977
Epoch 34/1000
Epoch 00034: val_loss did not improve from 10.64977
Epoch 35/1000
Epoch 00035: val_loss did not improve from 10.64977
Epoch 36/1000
Epoch 00036: val_loss did not improve from 10.64977
Epoch 37/1000
Epoch 00037: val_loss did not improve from 10.64977
Epoch 38/1000
Epoch 00038: val_loss did not improve from 10.64977
Epoch 39/1000
Epoch 00039: val_loss did not improve from 10.64977
Epoch 40/1000
Epoch 00040: val_loss did not improve from 10.64977
Epoch 41/1000
Epoch 00041: val_loss did not improve from 10.64977
Epoch 42/1000
Epoch 00042: val_loss did not improve from 10.64977
Epoch 43/1000
Epoch 00043: val_loss did not improve from 10.64977
Epoch 44/1000
Epoch 00044: val_loss did not improve from 10.64977
Epoch 45/1000
Epoch 00045: val_loss improved from 10.64977 to 10.20870, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 46/1000
Epoch 00046: val_loss did not improve from 10.20870
Epoc

Epoch 51/1000
Epoch 00051: val_loss did not improve from 10.20870
Epoch 52/1000
Epoch 00052: val_loss did not improve from 10.20870
Epoch 53/1000
Epoch 00053: val_loss improved from 10.20870 to 9.52952, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch 54/1000
Epoch 00054: val_loss did not improve from 9.52952
Epoch 55/1000
Epoch 00055: val_loss did not improve from 9.52952
Epoch 56/1000
Epoch 00056: val_loss did not improve from 9.52952
Epoch 57/1000
Epoch 00057: val_loss did not improve from 9.52952
Epoch 58/1000
Epoch 00058: val_loss did not improve from 9.52952
Epoch 59/1000
Epoch 00059: val_loss did not improve from 9.52952
Epoch 60/1000
Epoch 00060: val_loss did not improve from 9.52952
Epoch 61/1000
Epoch 00061: val_loss did not improve from 9.52952
Epoch 62/1000
Epoch 00062: val_loss did not improve from 9.52952
Epoch 63/1000
Epoch 00063: val_loss did not improve from 9.52952
Epoch 64/1000
Epoch 00064: val_loss improved from 9.52952 to 8.0737

Epoch 69/1000
Epoch 00069: val_loss did not improve from 8.07373
Epoch 70/1000
Epoch 00070: val_loss did not improve from 8.07373
Epoch 71/1000
Epoch 00071: val_loss did not improve from 8.07373
Epoch 72/1000
Epoch 00072: val_loss did not improve from 8.07373
Epoch 73/1000
Epoch 00073: val_loss did not improve from 8.07373
Epoch 74/1000
Epoch 00074: val_loss did not improve from 8.07373
Epoch 75/1000
Epoch 00075: val_loss did not improve from 8.07373
Epoch 76/1000
Epoch 00076: val_loss did not improve from 8.07373
Epoch 77/1000
Epoch 00077: val_loss did not improve from 8.07373
Epoch 78/1000
Epoch 00078: val_loss did not improve from 8.07373
Epoch 79/1000
Epoch 00079: val_loss did not improve from 8.07373
Epoch 80/1000
Epoch 00080: val_loss did not improve from 8.07373
Epoch 81/1000
Epoch 00081: val_loss did not improve from 8.07373
Epoch 82/1000
Epoch 00082: val_loss improved from 8.07373 to 7.84815, saving model to /var/folders/c_/8lvtjqcn13jcbbdq_sx7f7sr0000gn/T/tmpwn9zw07d.h5
Epoch

Epoch 87/1000
Epoch 00087: val_loss did not improve from 7.57853
Epoch 88/1000

KeyboardInterrupt: 