In [1]:
import afqinsight.nn.tf_models as nn
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from afqinsight.datasets import AFQDataset
from afqinsight.nn.tf_models import cnn_lenet, mlp4, cnn_vgg, lstm1v0, lstm1, lstm2, blstm1, blstm2, lstm_fcn, cnn_resnet
from sklearn.impute import SimpleImputer
import os.path
# Harmonization
from sklearn.model_selection import train_test_split
from neurocombat_sklearn import CombatModel
import pickle

ModuleNotFoundError: No module named 'afqinsight'

In [None]:
# Build the dataset from the node-level tract profiles and the participants
# table. Three DKI metrics are requested as features; age, QC score and scan
# site are requested as targets, with scan site passed to label_encode_cols.
afq_dataset = AFQDataset.from_files(
    fn_nodes="../data/raw/combined_tract_profiles.csv",
    fn_subjects="../data/raw/participants_updated_id.csv",
    dwi_metrics=["dki_fa", "dki_md", "dki_mk"],
    index_col="subject_id",
    target_cols=["age", "dl_qc_score", "scan_site_id"],
    label_encode_cols=["scan_site_id"]
)

In [None]:
# Drop subjects with missing target values. The return value is not used, so
# this presumably modifies afq_dataset in place -- TODO confirm.
afq_dataset.drop_target_na()

In [None]:
# Sanity check: number of subjects and the shapes of the feature and target arrays.
print(len(afq_dataset.subjects))
print(afq_dataset.X.shape)
print(afq_dataset.y.shape)

In [None]:
# Materialize the TensorFlow dataset as a list of NumPy samples so that
# features and targets can be pulled out by index.
full_dataset = list(afq_dataset.as_tensorflow_dataset().as_numpy_iterator())

In [None]:
# Every sample is indexed as (features, targets). Targets are presumably in the
# same order as target_cols (age, dl_qc_score, scan_site_id) -- TODO confirm.
X = np.stack([sample[0] for sample in full_dataset], axis=0)
y = np.array([sample[1][0] for sample in full_dataset])
qc = np.array([sample[1][1] for sample in full_dataset])
site = np.array([sample[1][2] for sample in full_dataset])

In [None]:
# Keep only the subjects whose QC score is positive. The mask is computed once
# and applied to qc as well, so all four arrays stay aligned and re-running
# this cell is a no-op instead of a boolean-index length error.
passed_qc = qc > 0
X = X[passed_qc]
y = y[passed_qc]
site = site[passed_qc]
qc = qc[passed_qc]

In [None]:
# Split the data into train and test sets (20% held out, fixed random_state).
# The site labels are split in the same call so they stay aligned with X and y.
X_train, X_test, y_train, y_test, site_train, site_test = train_test_split(X, y, site, test_size=0.2, random_state=42)

In [None]:
# Imputer that fills missing values with the median.
imputer = SimpleImputer(strategy="median")

In [None]:
# Impute train and test separately, one channel (last axis) at a time; the
# imputer is re-fit on every slice it is given.
# NOTE(review): the test set is imputed with medians fit on the test set itself
# rather than with the training medians -- confirm this is intended.
X_train = np.stack(
    [imputer.fit_transform(X_train[..., channel]) for channel in range(X_train.shape[-1])],
    axis=-1,
)
X_test = np.stack(
    [imputer.fit_transform(X_test[..., channel]) for channel in range(X_test.shape[-1])],
    axis=-1,
)

In [None]:
# ComBat harmonization across scan sites: a fresh CombatModel is fit on every
# channel (last axis), with the site labels passed as a column vector.
# NOTE(review): train and test are harmonized by independently fitted models
# -- confirm this is intended.
X_train = np.stack(
    [
        CombatModel().fit_transform(X_train[..., channel], site_train[:, np.newaxis], None, None)
        for channel in range(X_train.shape[-1])
    ],
    axis=-1,
)
X_test = np.stack(
    [
        CombatModel().fit_transform(X_test[..., channel], site_test[:, np.newaxis], None, None)
        for channel in range(X_test.shape[-1])
    ],
    axis=-1,
)

In [None]:
# Upper bound on the number of training epochs.
n_epochs = 1000

# EarlyStopping: stop once val_loss has not improved by at least min_delta
# for `patience` consecutive epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    mode="min",
    patience=100
)

# ReduceLROnPlateau: halve the learning rate (factor=0.5) when val_loss has
# stalled for 20 epochs.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.5,
    patience=20,
    verbose=1,
)

In [None]:
from afqinsight.augmentation import jitter, time_warp, scaling

In [None]:
def augment_this(X, y, rounds=2):
    """Return X and y extended with `rounds` augmented copies of X.

    Each round builds one augmented copy of X by sending every channel (last
    axis) through jitter, scaling and time_warp, in that order. The sigma of
    each step is the mean of that step's input divided by 25. The targets are
    repeated unchanged once per round.

    Parameters
    ----------
    X : array
        Features; the last axis indexes channels.
    y : array
        Targets, aligned with the first axis of X.
    rounds : int
        Number of augmented copies to append.

    Returns
    -------
    tuple
        (features, targets) with the originals first, followed by one
        augmented copy per round.
    """
    augmented_X = X[:]
    augmented_y = y[:]
    for _ in range(rounds):
        round_X = np.zeros_like(X)
        for ch in range(round_X.shape[-1]):
            # Transform one channel at a time, keeping a trailing singleton axis.
            profile = X[..., ch][..., None]
            for transform in (jitter, scaling, time_warp):
                profile = transform(profile, sigma=np.mean(profile) / 25)
            round_X[..., ch] = profile[..., 0]
        augmented_X = np.concatenate([augmented_X, round_X])
        augmented_y = np.concatenate([augmented_y, y])
    return augmented_X, augmented_y

In [None]:
from sklearn.utils import shuffle, resample

In [None]:
# Generate evaluation results, training history, number of epochs
def model_history(model_name, ckpt_filepath, lr, X_train, y_train):
    """Build, train and evaluate one model on augmented training data.

    Relies on the module-level ``n_epochs``, ``early_stopping``, ``reduce_lr``,
    ``X_test`` and ``y_test``.

    Parameters
    ----------
    model_name : callable
        Model factory, called with ``input_shape``, ``n_classes=1``,
        ``output_activation=None`` and ``verbose=True``.
    ckpt_filepath : str
        Path where the weights with the best ``val_loss`` are saved and from
        which they are reloaded before evaluation.
    lr : float
        Learning rate for the Adam optimizer.
    X_train, y_train : array
        Training features and targets; they are augmented (6 rounds) and
        shuffled inside this function.

    Returns
    -------
    eval_model
        Result of ``model.evaluate(X_test, y_test)`` with the best weights.
    history
        Object returned by ``model.fit``.
    count_epochs : int
        Number of epochs that were actually run.
    """
    # Take the per-sample input shape from the data instead of hard-coding
    # (100, 24); augmentation below does not change the trailing dimensions.
    input_shape = tuple(np.shape(X_train)[1:])
    model = model_name(input_shape=input_shape, n_classes=1, output_activation=None, verbose=True)
    model.compile(loss='mean_squared_error',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  metrics=['mean_squared_error', tf.keras.metrics.RootMeanSquaredError(name='rmse'), 'mean_absolute_error'])
    # ModelCheckpoint: keep only the weights of the epoch with the best val_loss.
    ckpt = tf.keras.callbacks.ModelCheckpoint(
        ckpt_filepath,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode="auto",
    )

    # CSVLogger: name the log after the factory function. str(model_name) would
    # give "<function ... at 0x...>", which is not a stable or portable filename.
    log_name = getattr(model_name, "__name__", str(model_name))
    log = tf.keras.callbacks.CSVLogger(filename=log_name + '.csv', append=True)
    callbacks = [early_stopping, ckpt, reduce_lr, log]
    # Augment
    # NOTE(review): validation_split is taken from the augmented, shuffled
    # array, so augmented copies of a training subject can end up in the
    # validation part; val_loss may therefore be optimistic.
    X_train, y_train = augment_this(X_train, y_train, rounds=6)
    X_train, y_train = shuffle(X_train, y_train)

    history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=128, validation_split=0.2,
                        callbacks=callbacks)
    # Evaluate with the best checkpoint rather than the last-epoch weights.
    model.load_weights(ckpt_filepath)
    eval_model = model.evaluate(X_test, y_test)
    count_epochs = len(history.epoch)
    return eval_model, history, count_epochs

In [None]:
# Visualization of mean_squared_error, root_mean_squared_error, and mean_absolute_error
def vis_results(model_name, history, epoch, skip=10):
    """Plot training and validation curves for loss (MSE), RMSE and MAE.

    Parameters
    ----------
    model_name : str
        Label used in the figure title.
    history
        Object whose ``history`` attribute is a dict holding the keys
        ``loss``, ``rmse``, ``mean_absolute_error`` and their ``val_``
        counterparts.
    epoch : int
        Epoch count shown in the figure title.
    skip : int, optional
        Number of leading epochs left out of every curve (default 10).
    """
    # (history key, y-axis label, panel title) for each of the three panels.
    panels = [
        ('loss', 'loss', 'Mean Squared Error'),
        ('rmse', 'root_mean_squared_error', 'Root Mean Squared Error'),
        ('mean_absolute_error', 'mean_absolute_error', 'Mean Absolute Error'),
    ]
    fig, axes = plt.subplots(1, len(panels), figsize=[20, 5])
    fig.suptitle(model_name + ' epoch = ' + str(epoch), fontsize=15)
    for ax, (key, ylabel, title) in zip(axes, panels):
        ax.plot(history.history[key][skip:])
        ax.plot(history.history['val_' + key][skip:])
        ax.set_ylabel(ylabel)
        ax.set_xlabel('epoch')
        ax.set_title(title)
        ax.legend(['train', 'val'], loc='upper right')
    plt.show()

In [None]:
# Display the shape of the training features.
X_train.shape

In [None]:
# Model factory and learning rate for every architecture that is trained.
# "blstm2" now points at blstm2; it previously pointed at blstm1, so blstm1 was
# trained twice and blstm2 never.
model_dict = {"cnn_lenet": {"model": cnn_lenet, "lr": 0.001},
              "mlp4": {"model": mlp4, "lr": 0.001},
              "cnn_vgg": {"model": cnn_vgg, "lr": 0.001},
              "lstm1v0": {"model": lstm1v0, "lr": 0.01},
              "lstm1": {"model": lstm1, "lr": 0.01},
              "lstm2": {"model": lstm2, "lr": 0.01},
              "blstm1": {"model": blstm1, "lr": 0.01},
              "blstm2": {"model": blstm2, "lr": 0.01},
              "lstm_fcn": {"model": lstm_fcn, "lr": 0.01},
              "cnn_resnet": {"model": cnn_resnet, "lr": 0.01}
             }

In [None]:
# Print the learning rate configured for each model.
for model in model_dict: 
    print(model_dict[model]["lr"])

### Train and evaluate all models

In [None]:
# Per-model accumulators, keyed by model name.
results = {}
history = {}

In [None]:
import tempfile

In [None]:
# Train every architecture 10 times and store the test metrics and the MAE
# learning curves of each run.
for model in model_dict:
    print("##################################################")
    print("model: ", model)
    results[model] = []
    history[model] = {"mean_absolute_error": [],
                      "val_mean_absolute_error": []}
    for ii in range(10):
        # Keep the checkpoint in a private temporary directory that is removed
        # once the run is over. model_history reloads the best weights before
        # it returns, so the file is not needed afterwards. This replaces
        # NamedTemporaryFile().name, which left an unclosed file object behind
        # and one orphaned .h5 file per run.
        with tempfile.TemporaryDirectory() as ckpt_dir:
            this_eval, this_history, this_epochs = model_history(model_dict[model]["model"],
                                                                 os.path.join(ckpt_dir, "best_weights.h5"),
                                                                 model_dict[model]["lr"],
                                                                 X_train,
                                                                 y_train)
        # The first entry of the evaluation result is dropped; the rest is stored.
        results[model].append(this_eval[1:])
        history[model]["mean_absolute_error"].append(this_history.history['mean_absolute_error'])
        history[model]["val_mean_absolute_error"].append(this_history.history['val_mean_absolute_error'])
        # Re-write the pickles after every run so partial results survive an interruption.
        with open(f'results_{model}_all_metrics.pickle', 'wb') as file:
            pickle.dump(results[model], file, protocol=pickle.HIGHEST_PROTOCOL)
        with open(f'history_{model}_all_metrics.pickle', 'wb') as file:
            pickle.dump(history[model], file, protocol=pickle.HIGHEST_PROTOCOL)