In [1]:
from load_data import *
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
import numpy as np
from datasets import *
from models import *
from train_test_utils import *
from tabulate import tabulate
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Seed every random number generator used in this notebook so that runs are repeatable.
# (Swap the constant for random.randint(0, 4294967295) to draw a fresh seed.)
run_with_seed = 1964
for seed_rng in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed):
    seed_rng(run_with_seed)

# Task configuration
# ------------------
# Construct to predict. Options are 'negative_valence' and 'positive_valence'
construct = 'negative_valence'

# Set to True to predict age together with the construct in a multi-task setting
age = False

# seq2seq dictates if we predict all time-steps or only the last one
seq2seq = False

# Data
# ----
# Input CSV with the per-visit data
path = 'full_per_visit_data_2021-03-26_processed.csv'

# Path to save the processed version of the dataset, if needed
w_path = ''

# Load the data and discard 'aces_total', because this variable should not be used
input_feats = load_longitudinal_tabular_data(
    input_path=path,
    write_path=w_path,
    quick=True,
    write_csv=False,
).drop(columns=['aces_total'])

In [4]:
# Gather, for every subject, the per-visit values of the chosen construct and of 'visit_age'.
# One label per subject (the maximum over that subject's visits) is then derived so that
# the cross-validation splits can be stratified by subject.
# NOTE(review): with boolean labels the maximum is True when any visit is True - confirm dtype.

labels = {}
labels_control_diseased = {}
ages = {}

# Sorted unique subject identifiers; the same order is reused for the stratification labels
subject_ids = list(np.unique(input_feats.loc[:, 'subject']))

for subject_id in subject_ids:
    visits = input_feats[input_feats['subject'] == subject_id]
    labels[subject_id] = visits.loc[:, construct]
    ages[subject_id] = visits.loc[:, 'visit_age']

split_stratified_labels = [labels[subject_id].max() for subject_id in subject_ids]

In [5]:
# When age is a prediction target it must not also be used as an input feature.
# The column check keeps this cell safe to re-run after 'visit_age' has already been dropped
# (without it, a second run with age=True raises a KeyError).
if age and 'visit_age' in input_feats.columns:
    input_feats = input_feats.drop(columns=['visit_age'])
    
# We are doing 5-fold cross-validation, so we create dictionaries to save the data for each fold.
# All of them (except partition) are keyed by the fold number and filled by the fold-splitting cell.
partition = {}            # scratch dict holding the 'train' / 'test' subject lists of one fold
folds = {}                # fold -> copy of partition
subject_folds = {}        # fold -> {subject: normalized features of that subject's visits}
mfb_folds = {}            # fold -> negative/positive sample ratio of the training set
folds_of_the_labels = {}  # fold -> copy of labels
folds_of_the_ages = {}    # fold -> copy of ages
counter = 0

# We are creating the scaler object to normalize our input data to the range [-1, 1]
scaler = MinMaxScaler(feature_range=(-1,1))

# We are performing stratified CV since we have class imbalance
kf = StratifiedKFold(n_splits=5, shuffle=False)  # False for reproducible folds
# X holds the unique subject ids and y the per-subject label used for stratification
X = np.array(list(np.unique(input_feats.loc[:, 'subject'])))
y = np.array(split_stratified_labels)
kf.get_n_splits(X, y)

5

In [6]:
# For each cross-validation fold: separate the data into train and test sets, normalize them
# (the scaler is fitted on the training rows only and then applied to the test rows), store the
# per-subject features, and compute the class weight used by the loss.
#
# 'Fake controls' are subjects that have 'True' in any depression construct in at least one of the visits.
# The four constructs are: 'positive_valence', 'negative_valence', 'arousal', 'cognitive'
# NOTE(review): the original description says 'fake control' subjects are excluded from the test
# set, but no such filtering is performed in this cell - confirm whether it happens elsewhere.

# Restart the fold numbering so that re-running this cell overwrites folds 0..n-1 instead of
# appending extra folds after the ones created by a previous run.
counter = 0

for train_index, test_index in kf.split(X, y):
    labels_fold = labels.copy()
    ages_fold = ages.copy()
    subject_ages = {}
    subject_post = {}

    # Sets give constant-time membership tests in the per-subject loop below
    train_ids = set(X[train_index].tolist())
    test_ids = set(X[test_index].tolist())

    train_subj = input_feats.loc[input_feats['subject'].isin(list(X[train_index]))]
    test_subj = input_feats.loc[input_feats['subject'].isin(list(X[test_index]))]
    # In this version of the dataset the 21 first columns contain features that should not be used as input 
    # for the prediction.
    # These features are the personal information, like sex, race etc., the labels that we are predicting and other
    # potential prediction variables.
    X_train = train_subj.iloc[:, 21:]
    X_train = scaler.fit_transform(X_train)

    # Rebuild a DataFrame with the original column names and row index, and put 'subject' first
    X_train = pd.DataFrame(data=X_train, columns=train_subj.columns[21:])
    X_train = X_train.set_index(train_subj.index)
    X_train.insert(0, 'subject', train_subj.loc[:, 'subject'], True)

    # The test rows are transformed with the statistics learned from the training rows
    X_test = test_subj.iloc[:, 21:]
    X_test = scaler.transform(X_test)
    
    X_test = pd.DataFrame(data=X_test, columns=test_subj.columns[21:])
    X_test = X_test.set_index(test_subj.index)
    X_test.insert(0, 'subject', test_subj.loc[:, 'subject'], True)

    # Subject lists keep the order in which subjects first appear in input_feats
    partition['test'] = list()
    partition['train'] = list()
    for subject in input_feats.subject.unique():
        if subject in train_ids:
            subj_visits = X_train[X_train['subject'] == subject]
            subject_ages[subject] = subj_visits
            partition['train'].append(subject)
            
        elif subject in test_ids:
            subj_visits = X_test[X_test['subject'] == subject]
            subject_ages[subject] = subj_visits
            partition['test'].append(subject)

    folds[counter] = partition.copy()

    # Subject-specific dataset with all the visits and post-processed for training
    # (the leading 'subject' column is removed so that only the features remain)
    for key in list(partition['train'] + partition['test']):
        df = subject_ages[key]
        df = df.iloc[:, 1:]
        subject_post[key] = df

    subject_folds[counter] = subject_post.copy()
    folds_of_the_labels[counter] = labels_fold.copy()
    folds_of_the_ages[counter] = ages_fold.copy()

    # Since we have class imbalance we are using weights in the binary cross entropy during training
    # These weights are calculated separately for each fold since we only take the training 
    # set into account to calculate them.
    y_train = train_subj.loc[:, construct]
    number_neg_samples = np.sum(y_train.values == False)
    num_pos_samples = np.sum(y_train.values == True)
    if num_pos_samples == 0:
        # Dividing by zero would silently give an infinite loss weight
        raise ValueError(f'Fold {counter} has no positive training samples for {construct}.')
    mfb = number_neg_samples / num_pos_samples
    mfb_folds[counter] = mfb.copy()

    counter += 1

In [7]:
# DataLoader settings
params = dict(
    shuffle=True,
    num_workers=0,
    batch_size=1,  # One batch contains all the visits of each subject
)

# Containers for the per-fold evaluation results and for their average over the folds
results = dict()
avg_results_dict = dict()

In [8]:
# Train one model per cross-validation fold and evaluate it on that fold's 'test' subjects

for fold, partition in folds.items():

    # Data prepared for this fold by the fold-splitting cell
    subject_post = subject_folds[fold]
    pos_weight = mfb_folds[fold]
    labels_f = folds_of_the_labels[fold]
    ages_f = folds_of_the_ages[fold]

    # Dataset generators (both loaders share the same DataLoader settings)
    training_set = Dataset(partition['train'], subject_post, labels_f, ages_f, age)
    training_generator = torch.utils.data.DataLoader(training_set, **params)

    validation_set = Dataset(partition['test'], subject_post, labels_f, ages_f, age)
    validation_generator = torch.utils.data.DataLoader(validation_set, **params)

    # Use the first GPU when one is available, otherwise the CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # The parameters I trained with, after hyperparameter tuning
    epoch, max_epochs = 0, 30
    feature_dim, hidden_dim = 128, 64

    # The input size is read from dimension 2 of the first element of one training batch
    input_dim = next(iter(training_generator))[0].shape[2]
    output_dim = n_layers = 1

    # Same constructor arguments for both architectures; only the class depends on `age`
    net_kwargs = dict(feature_dim=feature_dim, input_dim=input_dim, hidden_dim=hidden_dim,
                      output_dim=output_dim, n_layers=n_layers, seq2seq=seq2seq, device=device)
    model = (AgeGRUNet if age else GRUNet)(**net_kwargs)

    # Loss chosen for binary classification task, weighted by this fold's class ratio
    score_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight).float().to(device))

    if age:
        # Loss chosen for age prediction, bundled with the classification loss
        age_criterion = nn.MSELoss()
        criterion = {'score': score_criterion, 'age': age_criterion}
    else:
        criterion = score_criterion

    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)

    model.to(device)
    model.train()

    # Pick the training routine that matches the task; both take the same arguments
    train_fn = train_gru_age if age else train_gru
    model_trained, h = train_fn(model=model, criterion=criterion, optimizer=optimizer,
                                max_epochs=max_epochs, train_loader=training_generator,
                                val_loader=validation_generator, device=device,
                                seq2seq=seq2seq, params=params)

    # Path to save the models
    output_path = (f'{construct}_fold_{fold}_tabular_no_aces.ckpt' if age
                   else f'{construct}_fold_{fold}_tabular_no_aces_no_age.ckpt')

    print(f'Model training complete for fold {fold} complete.')
    # Uncomment to save models
    #torch.save(model_trained.state_dict(), output_path)
    #print(f'Model Saved at: {output_path}')

    # Evaluate on the held-out subjects with the routine matching the prediction mode
    if not seq2seq:
        results[f'split{fold}'] = evaluate_last_timestep(model=model_trained, val_loader=validation_generator, device=device)
    elif age:
        results[f'split{fold}'] = evaluate_all_timesteps_age_per_subject(model=model_trained, val_loader=validation_generator,
                                                                         hidden=h, device=device)
    else:
        results[f'split{fold}'] = evaluate_all_timesteps_per_subject(model=model_trained, val_loader=validation_generator, hidden=h, device=device)


Model training complete for fold 0 complete.
Model training complete for fold 1 complete.
Model training complete for fold 2 complete.
Model training complete for fold 3 complete.
Model training complete for fold 4 complete.


In [9]:
# Average the metrics over the evaluated folds and print them.
# The divisor is the number of results actually summed, so the averages stay correct even when
# `folds` and `results` hold a different number of entries (e.g. after re-running a single cell).
num_evaluated_folds = len(results)
if num_evaluated_folds == 0:
    raise ValueError('No fold results to average; run the training cell first.')

# Running sums of each metric over the folds
avg_acc = 0.0
avg_bacc = 0.0
avg_f1 = 0.0
subj_acc = 0.0
subj_macro_acc = 0.0

for key in results.keys():
    subj_macro_acc += results[key]['subject_macro_accuracy']
    subj_acc += results[key]['subject_accuracy']
    avg_acc += results[key]['accuracy']
    avg_bacc += results[key]['balanced_accuracy']
    avg_f1 += results[key]['f1-score']

avg_results_dict['subject_accuracy'] = subj_acc / num_evaluated_folds
avg_results_dict['subject_macro_accuracy'] = subj_macro_acc / num_evaluated_folds
avg_results_dict['accuracy'] = avg_acc / num_evaluated_folds
# The averaged 'balanced_accuracy' of the folds is reported under the key 'macro_accuracy'
avg_results_dict['macro_accuracy'] = avg_bacc / num_evaluated_folds
avg_results_dict['f1-score'] = avg_f1 / num_evaluated_folds
print(f'Average results for {construct}:')
print(avg_results_dict)

Average results for negative_valence:
{'subject_accuracy': 0.8743225806451613, 'subject_macro_accuracy': 0.7455922744294735, 'accuracy': 0.8743225806451613, 'macro_accuracy': 0.7455922744294735, 'f1-score': 0.6481935022869874}
