In [1]:
import afqinsight.nn.tf_models as nn
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from afqinsight.datasets import AFQDataset
from afqinsight.nn.tf_models import cnn_lenet, mlp4, cnn_vgg, lstm1v0, lstm1, lstm2, blstm1, blstm2, lstm_fcn, cnn_resnet
from sklearn.impute import SimpleImputer
import os.path
# Harmonization
from sklearn.model_selection import train_test_split
from neurocombat_sklearn import CombatModel
import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the node-wise tract profiles for three DKI metrics plus the participants table.
# Targets, in this order: age, dl_qc_score, scan_site_id (site IDs are label-encoded).
afq_dataset = AFQDataset.from_files(
    fn_subjects="../data/raw/participants_updated_id.csv",
    fn_nodes="../data/raw/combined_tract_profiles.csv",
    index_col="subject_id",
    dwi_metrics=["dki_fa", "dki_md", "dki_mk"],
    label_encode_cols=["scan_site_id"],
    target_cols=["age", "dl_qc_score", "scan_site_id"],
)

In [3]:
# Presumably drops subjects with a missing value in any target column (age,
# dl_qc_score, scan_site_id), going by the method name. The return value is
# unused, so this is expected to modify afq_dataset in place. TODO confirm in afqinsight.
afq_dataset.drop_target_na()

In [4]:
# Sanity-check sizes: number of subjects, feature-matrix shape, target-matrix shape.
print(len(afq_dataset.subjects), afq_dataset.X.shape, afq_dataset.y.shape, sep="\n")

1865
(1865, 7200)
(1865, 3)


In [5]:
# Materialize the tf.data pipeline into an in-memory list of NumPy elements.
full_dataset = [*afq_dataset.as_tensorflow_dataset().as_numpy_iterator()]

In [6]:
# Each element of full_dataset is indexed as [0] = features, [1] = targets.
# Target columns follow target_cols: 0 = age, 1 = dl_qc_score, 2 = scan_site_id.
X = np.stack([sample[0] for sample in full_dataset], axis=0)
y, qc, site = (np.array([sample[1][col] for sample in full_dataset]) for col in range(3))

In [7]:
# Keep only scans whose QC score (dl_qc_score) is positive. qc is filtered too,
# so that all four arrays stay row-aligned and re-running this cell is a no-op
# instead of an index-length error.
passes_qc = qc > 0
X, y, qc, site = X[passes_qc], y[passes_qc], qc[passes_qc], site[passes_qc]

In [8]:
# Upper bound on training epochs; EarlyStopping is expected to end runs sooner.
n_epochs = 1000

# Stop once val_loss has failed to improve by at least 0.001 for 100 epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=100,
    min_delta=0.001,
    monitor="val_loss",
    mode="min",
)

# Halve the learning rate after 20 epochs without val_loss improvement
# (the callback's default min_delta applies), logging each reduction.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    patience=20,
    factor=0.5,
    verbose=1,
)

In [9]:
from afqinsight.augmentation import jitter, time_warp, scaling

In [10]:
def augment_this(X, y, rounds=2):
    """Return (X, y) with `rounds` augmented copies appended.

    Parameters
    ----------
    X : ndarray
        Samples with channels on the last axis; every channel is augmented
        independently of the others.
    y : ndarray
        Targets aligned with the first axis of X.
    rounds : int
        Number of augmented copies to append (0 returns slices of the inputs).

    Returns
    -------
    new_X, new_y : ndarray
        The original samples/targets followed by `rounds` augmented blocks;
        the targets are repeated unchanged for every block.

    Each channel is passed through jitter, scaling and time_warp in that order.
    Before every step, sigma is recomputed as the mean of the channel's current
    values divided by 25.
    """
    new_X, new_y = X[:], y[:]
    for _ in range(rounds):
        augmented = np.zeros_like(X)
        for ch in range(augmented.shape[-1]):
            # Augmenters operate on (..., 1) arrays, so keep a singleton channel axis.
            signal = X[..., ch][..., np.newaxis]
            for transform in (jitter, scaling, time_warp):
                signal = transform(signal, sigma=np.mean(signal) / 25)
            augmented[..., ch] = signal[..., 0]
        new_X = np.concatenate([new_X, augmented])
        new_y = np.concatenate([new_y, y])
    return new_X, new_y

In [11]:
from sklearn.utils import shuffle, resample

In [12]:
from sklearn.metrics import r2_score, median_absolute_error

In [13]:
# Train one model on one train/test split; return its test metrics and Keras training history.
def fit_model(model_name, ckpt_filepath, lr, X, y, random_state, augment=True,
              sites=None, validation_split=0.2):
    """Train one regression model on a random train/test split and evaluate it.

    Preprocessing (median imputation, then ComBat site harmonization) is fit per
    channel on the training split only and applied unchanged to the test split.
    The validation set is carved off the training split *before* augmentation.
    Augmented copies of a subject therefore never appear in both the fitted data
    and the validation data that drives checkpointing, early stopping and LR
    reduction.

    Parameters
    ----------
    model_name : callable
        Model factory (e.g. ``cnn_lenet``), called as
        ``model_name(input_shape=..., n_classes=1, output_activation=None, verbose=True)``.
    ckpt_filepath : str
        Path where ModelCheckpoint stores the best weights (by val_loss); these
        weights are reloaded before evaluation.
    lr : float
        Adam learning rate.
    X : ndarray
        Features; the model input shape is ``(100, X.shape[-1])``.
    y : ndarray
        Regression target, one value per row of X.
    random_state : int
        Seed for the train/test split and for shuffling augmented data.
    augment : bool
        If True, append 6 rounds of ``augment_this`` to the fitted data.
    sites : ndarray, optional
        Scan-site label per row of X, used by ComBat. Defaults to the
        notebook-level ``site`` array (the original behaviour).
    validation_split : float
        Fraction of the training split used for validation, taken from its end
        exactly as Keras' ``validation_split`` would.

    Returns
    -------
    eval_model : list
        ``model.evaluate`` output on the test split (loss, mse, rmse, mae),
        followed by the test median absolute error and R^2.
    history : tf.keras.callbacks.History
        Object returned by ``model.fit``.
    """
    if sites is None:
        sites = site  # notebook-level global used by the original implementation
    X_train, X_test, y_train, y_test, site_train, site_test = train_test_split(
        X, y, sites, test_size=0.2, random_state=random_state)

    X_train, X_test = _impute_per_channel(X_train, X_test)
    X_train, X_test = _harmonize_per_channel(X_train, X_test, site_train, site_test)

    # Hold out validation data before augmenting. This uses the same split index
    # Keras uses for validation_split: the last fraction of the (already shuffled)
    # training split. Behaviour without augmentation is therefore unchanged.
    split_at = int(len(X_train) * (1.0 - validation_split))
    X_fit, X_val = X_train[:split_at], X_train[split_at:]
    y_fit, y_val = y_train[:split_at], y_train[split_at:]

    model = model_name(input_shape=(100, X_train.shape[-1]), n_classes=1,
                       output_activation=None, verbose=True)
    model.compile(loss='mean_squared_error',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  metrics=['mean_squared_error',
                           tf.keras.metrics.RootMeanSquaredError(name='rmse'),
                           'mean_absolute_error'])
    # Keep only the weights with the best val_loss; they are reloaded below.
    ckpt = tf.keras.callbacks.ModelCheckpoint(
        ckpt_filepath,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode="auto",
    )
    callbacks = [early_stopping, ckpt, reduce_lr]
    if augment:
        X_fit, y_fit = augment_this(X_fit, y_fit, rounds=6)
        X_fit, y_fit = shuffle(X_fit, y_fit, random_state=random_state)

    history = model.fit(X_fit, y_fit, epochs=n_epochs, batch_size=128,
                        validation_data=(X_val, y_val),
                        callbacks=callbacks, verbose=0, use_multiprocessing=True)

    model.load_weights(ckpt_filepath)
    eval_model = model.evaluate(X_test, y_test)
    y_pred = model.predict(X_test)
    eval_model.append(median_absolute_error(y_test, y_pred))
    eval_model.append(r2_score(y_test, y_pred))
    return eval_model, history


def _impute_per_channel(X_train, X_test):
    """Median-impute each channel (last axis), learning the medians on X_train only.

    One SimpleImputer per channel is fit on the training split and reused on the
    test split. Test preprocessing therefore uses no test statistics, and both
    splits keep the same set of node columns.
    """
    train_parts, test_parts = [], []
    for ii in range(X_train.shape[-1]):
        imputer = SimpleImputer(strategy="median")
        train_parts.append(imputer.fit_transform(X_train[..., ii])[:, :, None])
        test_parts.append(imputer.transform(X_test[..., ii])[:, :, None])
    return np.concatenate(train_parts, -1), np.concatenate(test_parts, -1)


def _harmonize_per_channel(X_train, X_test, site_train, site_test):
    """ComBat-harmonize each channel (last axis), estimating site effects on X_train only.

    One CombatModel per channel is fit on the training split and reused to
    transform the test split. NOTE: ComBat estimates per-site parameters, so
    every site in site_test is expected to also occur in site_train. TODO
    confirm how neurocombat_sklearn reports unseen sites.
    """
    train_parts, test_parts = [], []
    for ii in range(X_train.shape[-1]):
        combat = CombatModel()
        train_parts.append(
            combat.fit_transform(X_train[..., ii], site_train[:, None], None, None)[:, :, None])
        test_parts.append(
            combat.transform(X_test[..., ii], site_test[:, None], None, None)[:, :, None])
    return np.concatenate(train_parts, -1), np.concatenate(test_parts, -1)

In [14]:
# Architectures to benchmark: model factory and Adam learning rate for each.
model_dict = {"cnn_lenet": {"model": cnn_lenet, "lr": 0.001},
              "mlp4": {"model": mlp4, "lr": 0.001},
              "cnn_vgg": {"model": cnn_vgg, "lr": 0.001},
              "lstm1v0": {"model": lstm1v0, "lr": 0.01},
              "lstm1": {"model": lstm1, "lr": 0.01},
              "lstm2": {"model": lstm2, "lr": 0.01},
              "blstm1": {"model": blstm1, "lr": 0.01},
              # Previously pointed at blstm1, so blstm2 was never actually trained.
              "blstm2": {"model": blstm2, "lr": 0.01},
              "lstm_fcn": {"model": lstm_fcn, "lr": 0.01},
              "cnn_resnet": {"model": cnn_resnet, "lr": 0.01}
              }

### Benchmark all models (cnn_lenet first, then every other entry in `model_dict`)

In [15]:
# Per-model test metrics and MAE learning curves, keyed by model name.
results, history = {}, {}

In [16]:
# Number of random train/test splits per model, and whether to augment training data.
n_runs, augment = 10, False

In [17]:
# One train/test-split seed per run. The generator is seeded so the whole sweep
# is reproducible, and seeds are drawn without replacement so no two runs
# silently reuse the same split.
seed_rng = np.random.default_rng(42)
seeds = seed_rng.choice(10_000, size=n_runs, replace=False)

In [18]:
import tempfile

In [None]:
# Train every model in model_dict n_runs times (one train/test split per seed).
# After every run, the per-run test metrics and MAE curves are pickled, so that
# partial results survive an interrupted sweep.
# Each stored result row is [mse, rmse, mae, median_absolute_error, r2]; the
# leading loss returned by fit_model duplicates mse and is dropped.
aug_tag = "aug" if augment else "no_aug"  # keep augmented / non-augmented outputs apart
for model_key, spec in model_dict.items():
    print("##################################################")
    print("model: ", model_key)
    results[model_key] = []
    history[model_key] = {"mean_absolute_error": [],
                          "val_mean_absolute_error": []}
    for ii in range(n_runs):
        # A private temporary directory gives each run a unique checkpoint path.
        # The directory and checkpoint are removed once the run has finished.
        with tempfile.TemporaryDirectory() as ckpt_dir:
            this_eval, this_history = fit_model(spec["model"],
                                                os.path.join(ckpt_dir, "ckpt.weights.h5"),
                                                spec["lr"],
                                                X,
                                                y,
                                                random_state=seeds[ii],
                                                augment=augment)
        results[model_key].append(this_eval[1:])
        history[model_key]["mean_absolute_error"].append(this_history.history['mean_absolute_error'])
        history[model_key]["val_mean_absolute_error"].append(this_history.history['val_mean_absolute_error'])
        with open(f'results_{model_key}_{aug_tag}.pickle', 'wb') as file:
            pickle.dump(results[model_key], file, protocol=pickle.HIGHEST_PROTOCOL)
        with open(f'history_{model_key}_{aug_tag}.pickle', 'wb') as file:
            pickle.dump(history[model_key], file, protocol=pickle.HIGHEST_PROTOCOL)

##################################################
model:  cnn_lenet
pooling layers: 4
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 100, 72)]         0         
                                                                 
 conv1d (Conv1D)             (None, 100, 6)            1302      
                                                                 
 max_pooling1d (MaxPooling1D  (None, 50, 6)            0         
 )                                                               
                                                                 
 conv1d_1 (Conv1D)           (None, 50, 16)            304       
                                                                 
 max_pooling1d_1 (MaxPooling  (None, 25, 16)           0         
 1D)                                                             
                                        