In [None]:
import afqinsight.nn.tf_models as nn
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from afqinsight.datasets import AFQDataset
from afqinsight.nn.tf_models import cnn_lenet, mlp4, cnn_vgg, lstm1v0, lstm1, lstm2, blstm1, blstm2, lstm_fcn, cnn_resnet
from sklearn.impute import SimpleImputer
import os.path
# Harmonization
from sklearn.model_selection import train_test_split
from neurocombat_sklearn import CombatModel
import pandas as pd
from sklearn.utils import shuffle, resample
from afqinsight.augmentation import jitter, time_warp, scaling
import tempfile

In [None]:
# Load node-wise tract profiles (DKI FA/MD/MK) and subject metadata.
# Later cells read the targets positionally as age, dl_qc_score and
# scan_site_id, so keep target_cols in this order. scan_site_id is
# label-encoded so it can be compared against integer site ids.
afq_dataset = AFQDataset.from_files(
    fn_subjects="../data/raw/participants_updated_id.csv",
    fn_nodes="../data/raw/combined_tract_profiles.csv",
    index_col="subject_id",
    dwi_metrics=["dki_fa", "dki_md", "dki_mk"],
    target_cols=["age", "dl_qc_score", "scan_site_id"],
    label_encode_cols=["scan_site_id"],
)

In [None]:
# Presumably removes subjects with missing target values. The return value
# is not assigned, so this cell relies on its effect on afq_dataset itself
# (assumed to be in-place -- confirm against AFQDataset.drop_target_na).
afq_dataset.drop_target_na()

In [None]:
# Sanity check: number of subjects, then the shapes of the feature (X) and
# target (y) arrays, one per line.
print(len(afq_dataset.subjects))
for array_name in ("X", "y"):
    print(getattr(afq_dataset, array_name).shape)

In [None]:
# Materialise the TensorFlow dataset as an in-memory list of numpy items;
# the next cell reads item[0] as features and item[1] as the target vector.
full_dataset = [*afq_dataset.as_tensorflow_dataset().as_numpy_iterator()]

In [None]:
# Stack the per-subject feature arrays along a new leading (subject) axis.
X = np.stack([sample[0] for sample in full_dataset], axis=0)
# Split the target vector into its three columns -- presumably in
# target_cols order: age, dl_qc_score, scan_site_id.
y, qc, site = (np.array([sample[1][k] for sample in full_dataset]) for k in range(3))

In [None]:
# Keep only subjects whose dl_qc_score is positive.
# The mask is computed once and applied to every per-subject array,
# including qc itself. This keeps the arrays aligned and makes re-running
# the cell a no-op. Previously qc was left unfiltered, so a second run
# indexed the already-shortened X/y/site with a longer mask and raised
# an IndexError.
passes_qc = qc > 0
X, y, site, qc = X[passes_qc], y[passes_qc], site[passes_qc], qc[passes_qc]

In [None]:
n_epochs = 1000  # upper bound; early stopping normally ends training sooner

# Stop training once val_loss has not improved by at least 0.001 for
# 100 consecutive epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=100,
    min_delta=0.001,
    mode="min",
    monitor="val_loss",
)

# Halve the learning rate whenever val_loss stalls for 20 epochs.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    factor=0.5,
    patience=20,
    monitor="val_loss",
    verbose=1,
)

In [None]:
def augment_this(X, y, rounds=2):
    """Return the data followed by ``rounds`` randomly augmented copies.

    In each round, every channel (last axis of ``X``) is perturbed
    independently by ``jitter``, then ``scaling``, then ``time_warp``. Each
    step uses ``sigma`` equal to 1/25 of the mean of its current input.
    Augmented copies are appended after the original samples, and ``y`` is
    repeated once per round to stay aligned.

    NOTE(review): presumably ``X`` is (n_samples, n_nodes, n_channels);
    each channel is passed to the augmenters with a trailing singleton axis.

    Returns
    -------
    (new_X, new_y) with ``(rounds + 1) * len(X)`` samples each.
    """
    def _perturb(channel_data):
        # Apply the three augmenters in a fixed order; each sigma is taken
        # from the data as it stands after the previous step.
        for transform in (jitter, scaling, time_warp):
            channel_data = transform(channel_data, sigma=np.mean(channel_data) / 25)
        return channel_data

    out_X, out_y = X[:], y[:]
    n_channels = X.shape[-1]
    for _ in range(rounds):
        augmented = np.zeros_like(X)
        for ch in range(n_channels):
            augmented[..., ch] = _perturb(X[..., ch][..., np.newaxis])[..., 0]
        out_X = np.concatenate([out_X, augmented])
        out_y = np.concatenate([out_y, y])
    return out_X, out_y

In [None]:
imputer = SimpleImputer(strategy="median")
def impute(X_data):
    """Median-impute missing values independently for each channel.

    For every index ``ii`` of the last axis, the shared module-level
    ``imputer`` is re-fitted on ``X_data[..., ii]`` and used to fill it.
    The medians therefore come from ``X_data`` itself. The filled channels
    are concatenated back along the last axis.

    NOTE(review): presumably X_data is (n_subjects, n_nodes, n_channels).
    SimpleImputer discards columns that are entirely NaN at fit time
    (unless keep_empty_features is set), which would shrink the node axis;
    confirm this cannot happen for small per-site splits.
    """
    filled_channels = []
    for ii in range(X_data.shape[-1]):
        filled = imputer.fit_transform(X_data[..., ii])
        filled_channels.append(filled[:, :, np.newaxis])
    return np.concatenate(filled_channels, axis=-1)

In [None]:
# Two-in-one test
def _train_best_model(model_name, X_train, y_train, lr, log_path, input_shape):
    """Build, compile and fit one single-output regression model.

    Parameters
    ----------
    model_name : callable
        Model factory (e.g. ``cnn_resnet``) called with ``input_shape``,
        ``n_classes``, ``output_activation`` and ``verbose``.
    X_train, y_train : array-like
        Training data; Keras holds out the last 20% as validation data.
    lr : float
        Adam learning rate.
    log_path : str
        CSV file that per-epoch metrics are appended to.
    input_shape : tuple
        Per-sample input shape passed to ``model_name``.

    Returns
    -------
    The fitted model, with the weights of its best (lowest ``val_loss``)
    checkpoint loaded back.

    Uses the notebook-level ``n_epochs``, ``early_stopping`` and ``reduce_lr``.
    """
    model = model_name(input_shape=input_shape, n_classes=1,
                       output_activation=None, verbose=True)
    model.compile(loss='mean_squared_error',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  metrics=['mean_squared_error',
                           tf.keras.metrics.RootMeanSquaredError(name='rmse'),
                           'mean_absolute_error'])

    # NamedTemporaryFile is only used to obtain a unique path; the weights
    # file itself is written there by ModelCheckpoint.
    ckpt_filepath = tempfile.NamedTemporaryFile().name + '.h5'
    ckpt = tf.keras.callbacks.ModelCheckpoint(
        filepath=ckpt_filepath,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode="auto",
    )
    log = tf.keras.callbacks.CSVLogger(filename=log_path, append=True)

    model.fit(X_train, y_train, epochs=n_epochs, batch_size=128,
              validation_split=0.2,
              callbacks=[early_stopping, ckpt, reduce_lr, log])
    model.load_weights(ckpt_filepath)
    return model


def _impute_from_reference(reference, X_data):
    """Median-impute ``X_data`` per channel using medians fitted on ``reference``.

    A fresh SimpleImputer is fitted on ``reference[..., ii]`` for each index
    ``ii`` of the last axis and applied to ``X_data[..., ii]``. This lets a
    held-out test set be filled with training-set statistics instead of its
    own.
    """
    channels = []
    for ii in range(X_data.shape[-1]):
        channel_imputer = SimpleImputer(strategy="median").fit(reference[..., ii])
        channels.append(channel_imputer.transform(X_data[..., ii])[:, :, None])
    return np.concatenate(channels, -1)


def cross_site(model_name, name_str, lr, site_1, site_2, site_3, X, y,
               site_labels=None, random_state=None):
    """Compare single-site and pooled two-site training, all tested on site_1.

    Four models are trained, and each is evaluated on the same held-out 20%
    of ``site_1``:

    1. trained on the ``site_1`` training split (80%);
    2. trained on the ``site_2`` training split (80%);
    3. trained on the ``site_3`` training split (80%);
    4. trained on a pooled, shuffled and augmented subsample of the
       ``site_2`` and ``site_3`` training splits.

    Parameters
    ----------
    model_name : callable
        Model factory such as ``cnn_resnet``.
    name_str : str
        Value of the ``Model`` column and prefix of the CSV training logs
        (``name_str + '1.csv'`` ... ``name_str + '4.csv'``). Logs are opened
        with ``append=True``, so repeated calls with the same ``name_str``
        accumulate in the same files.
    lr : float
        Adam learning rate used for every model.
    site_1, site_2, site_3
        Site labels; ``site_1`` is the test site.
    X, y : numpy.ndarray
        Features (subjects on the first axis) and regression target,
        aligned with ``site_labels``.
    site_labels : array-like, optional
        Per-subject site label. Defaults to the notebook-level ``site``
        array, which is what this function always used before.
    random_state : int, optional
        Seed forwarded to the scikit-learn split/resample/shuffle calls.
        ``None`` keeps the unseeded behaviour. TensorFlow weight
        initialisation and ``augment_this`` are not controlled by it.

    Returns
    -------
    pandas.DataFrame
        16 rows in long format with columns ``Model``, ``Train_site``,
        ``Test_site``, ``Metric`` (MSE, RMSE, MAE, coef) and ``Value``.
        ``coef`` is the squared Pearson correlation between the true and
        predicted targets.
    """
    if site_labels is None:
        site_labels = site
    site_labels = np.asarray(site_labels)
    input_shape = tuple(X.shape[1:])

    # Split the data by sites
    X_1, y_1 = X[site_labels == site_1], y[site_labels == site_1]
    X_2, y_2 = X[site_labels == site_2], y[site_labels == site_2]
    X_3, y_3 = X[site_labels == site_3], y[site_labels == site_3]

    # Hold out 20% of every site; only site_1's held-out part is evaluated.
    X_train1_raw, X_test_raw, y_train1, y_test = train_test_split(
        X_1, y_1, test_size=0.2, random_state=random_state)
    X_train2, _, y_train2, _ = train_test_split(
        X_2, y_2, test_size=0.2, random_state=random_state)
    X_train3, _, y_train3, _ = train_test_split(
        X_3, y_3, test_size=0.2, random_state=random_state)

    # Training splits are imputed with their own medians. The test split is
    # filled with medians from site_1's *training* split, so test samples
    # do not influence their own preprocessing.
    X_test = _impute_from_reference(X_train1_raw, X_test_raw)
    X_train1 = impute(X_train1_raw)
    X_train2 = impute(X_train2)
    X_train3 = impute(X_train3)

    # Single-site models
    model1 = _train_best_model(model_name, X_train1, y_train1, lr,
                               name_str + '1.csv', input_shape)
    model2 = _train_best_model(model_name, X_train2, y_train2, lr,
                               name_str + '2.csv', input_shape)
    model3 = _train_best_model(model_name, X_train3, y_train3, lr,
                               name_str + '3.csv', input_shape)

    # Pooled model on sites 2 and 3.
    # NOTE(review): each site contributes half of site_1's *test*-split size;
    # confirm this is intended rather than matching the training-split size.
    # resample(..., replace=False) raises if a site has too few subjects.
    sample = y_test.shape[0] // 2
    X_s2, y_s2 = resample(X_train2, y_train2, n_samples=sample,
                          replace=False, random_state=random_state)
    X_s3, y_s3 = resample(X_train3, y_train3, n_samples=sample,
                          replace=False, random_state=random_state)
    X_train4 = np.concatenate((X_s2, X_s3), axis=0)
    y_train4 = np.concatenate((y_s2, y_s3), axis=0)

    X_train4, y_train4 = shuffle(X_train4, y_train4, random_state=random_state)
    X_train4, y_train4 = augment_this(X_train4, y_train4)
    X_train4, y_train4 = shuffle(X_train4, y_train4, random_state=random_state)

    model4 = _train_best_model(model_name, X_train4, y_train4, lr,
                               name_str + '4.csv', input_shape)

    # Evaluate every model on the held-out site_1 test split.
    # model.evaluate returns [loss, mse, rmse, mae] (loss first, then the
    # compiled metrics in order); the loss entry is not reported. Each
    # model's metrics come from its own evaluation.
    values = []
    for model in (model1, model2, model3, model4):
        y_pred = model.predict(X_test).reshape(y_test.shape)
        coef = np.corrcoef(y_test, y_pred)[0, 1] ** 2
        scores = model.evaluate(X_test, y_test)
        values.extend([scores[1], scores[2], scores[3], coef])

    # Results
    train_sites = [site_1, site_2, site_3, f'{site_2}, {site_3}']
    result = {'Model': [name_str] * 16,
              'Train_site': [label for label in train_sites for _ in range(4)],
              'Test_site': [site_1] * 16,
              'Metric': ['MSE', 'RMSE', 'MAE', 'coef'] * 4,
              'Value': values}
    return pd.DataFrame(result)

In [None]:
# Rotate which site is the held-out test site (first entry); the other two
# are the "other" training sites for that run. All runs use lr=0.01.
# NOTE: all three runs share name_str 'cnn_resnet', so their CSV training
# logs are appended to the same cnn_resnet{1..4}.csv files.
resnet_site_rotations = [(0, 3, 4), (3, 0, 4), (4, 0, 3)]
df_resnet1, df_resnet2, df_resnet3 = [
    cross_site(cnn_resnet, 'cnn_resnet', 0.01, *sites, X, y)
    for sites in resnet_site_rotations
]
# Stack the per-run tables row-wise in run order. An outer merge on all
# columns is a join: it sorts the keys and would collapse coinciding rows.
df_resnet = pd.concat([df_resnet1, df_resnet2, df_resnet3], ignore_index=True)