### Demo notebook for COMICAL 

#### Load libraries

In [None]:
import os, argparse, time, json
from src.train_eval import train_eval
from src.utils import plot_training_curves, plot_roc_curve, plot_precision_recall_curve, select_gpu
from tabulate import tabulate

import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from src.train import train
from src.test import test_model
from src.dataset_template import dataset

#### Set paths

In [None]:
import os

# Name of this run; run-specific outputs are grouped under results/<fname_out_root>.
fname_out_root = 'demo_run'
# Folder, relative to the working directory, that holds the input files.
path_data = 'data'
# Kept as a string because it is spliced into the saved-pairs file name below.
top_n_perc = '0.5'

# Resolve the working directory once so every entry is anchored to the same place.
_cwd = os.getcwd()

# Central lookup of every file and folder location used by the notebook.
# The comments on individual entries are read off the file names only.
paths = {
    # Input tables for the two modalities and their pairing.
    'path_mod_a' : os.path.join(_cwd,path_data,'snp-encodings-from-vcf.csv'),
    'path_mod_b' : os.path.join(_cwd,path_data,'T1_struct_brainMRI_IDPs.csv'),
    'path_pairs' : os.path.join(_cwd,path_data,'pairs.csv'),
    'path_mod_b_map' : os.path.join(_cwd,path_data,'T1mri.csv'),
    # Output locations.
    'path_res' : os.path.join(_cwd,'results'),
    'tensorboard_log': os.path.join(_cwd,'results',fname_out_root,'tensorboard_logs'),
    # Folder passed to `train` as `checkpoint_dir` and searched for
    # `checkpoint_epoch_<n>` when the best model is evaluated. The key was missing,
    # so both of those cells failed with KeyError.
    # NOTE(review): the exact location is an assumption - confirm where `train`
    # is expected to write its checkpoints.
    'checkpoint_name' : os.path.join(_cwd,'results',fname_out_root,'checkpoints'),
    'wd' : _cwd,
    # Labels, covariates and feature-to-group mapping tables.
    'path_target_labels' : os.path.join(_cwd,path_data,'neuroDx.csv'),
    'path_covariates' : os.path.join(_cwd,path_data,'neuroDx_geneticPCs.csv'),
    'path_mod_a2group_map' : os.path.join(_cwd,path_data,'SNPs_and_disease_mapping_with_pvalues.csv'),
    'path_mod_b2group_map' : os.path.join(_cwd,path_data,'IDPs_and_disease_mapping.csv'),
    # NOTE(review): unlike the other entries this one is relative (no working
    # directory prefix); left unchanged on purpose.
    'path_saved_pairs' : os.path.join(path_data,'pairs_top_n_'+top_n_perc+'.pickle'),
    'path_data' : os.path.join(_cwd,path_data),
}


In [None]:
# Fallback settings, all stored as raw strings (the form command-line
# parsing would deliver them in). Nothing else in this notebook reads the
# dictionary; it documents the defaults the run arguments start from.
default_values = dict(
    save_embeddings='0',
    plot_embeddings='0',
    top_n_perc='0.5',
    resume_from_batch='0',
    ckpt_name='None',
    downstream_pred_task_flag='0',
    out_flag='pairs',
    target='PD',
    index_col='eid',
    feat_a_index_col='SNPs',
    feat_b_index_col='IDPs',
    feat_a_target_col='Disease',
    feat_b_target_col='Disease',
    # The key spelling ("coveriate") is what the later cells look up, so it is kept.
    coveriate_names='Age, Sex',
    count_bins='64',
)


In [None]:
class _AttrDict(dict):
    """Dictionary whose entries can also be read as attributes.

    The notebook reads the run arguments both as ``args['key']`` and as
    ``args.key``; this keeps both spellings working on a single object.
    """

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails, so regular dict
        # methods such as ``items`` or ``get`` are unaffected.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


# Raw run arguments, kept as strings exactly as a command-line parser would
# deliver them; the cells that consume them perform the type conversion.
# Beware that flag values such as '0' are non-empty strings and therefore
# truthy until converted.
#
# The literal used to continue with a pasted copy of the conversion code
# (`args.ckpt_name`, `pairs_exist`, ...). That referenced `args` before it
# existed, lacked a separating comma, and has been removed from this cell.
# NOTE(review): `dataset` and `test_model` receive this object as-is - confirm
# whether they expect these raw strings or already-converted values.
args = _AttrDict({
    "batch_size": '32768',
    'save_embeddings': '0',
    'plot_embeddings': '0',
    'top_n_perc': '0.5',
    'resume_from_batch': '0',
    'ckpt_name': 'None',
    'downstream_pred_task_flag': '0',
    'out_flag': 'pairs',
    'target': 'PD',
    'index_col': 'eid',
    'feat_a_index_col': 'SNPs',
    'feat_b_index_col': 'IDPs',
    'feat_a_target_col': 'Disease',
    'feat_b_target_col': 'Disease',
    'coveriate_names': 'Age, Sex',
    'count_bins': '64',
    # Settings the later conversion cell looks up on `args`; the values
    # mirror the ones hard-coded in the configuration cell.
    'random_seed': '42',
    'val_size': '20',
    'test_size': '10',
    'tune_flag': '0',
    'gpus_per_trial': '1',
    'learning_rate': '0.01',
    'epochs': '2',
    'num_layers': '2',
    'd_model': '64',
    'nhead': '4',
    'dim_feedforward': '32',
    'dropout': '0.0',
    'units': '16',
})

#### Load and create dataset

In [None]:
# Build the dataset object from the file locations in `paths` and the run
# arguments in `args`.
data = dataset(paths, args)

# Retrieve the train / validation / test index sets used by the cells below.
# NOTE(review): nothing in this cell checks that the three splits are disjoint;
# if such a check exists it lives inside `get_data_splits` - confirm.
train_idx, val_idx, test_idx = data.get_data_splits()

#### Set configuration

In [None]:
# Settings handed to `train` and `test_model`.
#
# Two kinds of entries are normalised here:
#   * flags copied from `args` are passed through bool(int(...)), because the
#     raw values are strings such as '0', which are truthy;
#   * optimisation / model hyper-parameters are given as numbers rather than
#     strings, matching the int()/float() conversion applied to the same
#     settings when the typed run arguments are built.
# NOTE(review): confirm against `train` that numeric values are what it expects.
config = {
    # Index sets returned by `data.get_data_splits()`.
    'train_index' : train_idx,
    'val_index' : val_idx,
    'test_index' : test_idx,
    # Command-line style settings, kept verbatim as strings.
    # NOTE(review): convert these too if `train` reads them numerically.
    "random_seed": '42',
    "val_size": '20',
    "test_size": '10',
    "gpu_nums": '7',
    "tune_flag": '0',
    "gpus_per_trial": '1',
    # Optimisation and model hyper-parameters.
    "batch_size": 32768,
    "lr": 0.01,
    "epochs": 2,
    "num_layers": 2,
    "d_model": 64,
    "nhead": 4,
    "dim_feedforward": 32,
    "dropout": 0.0,
    "units": 16,
    'tune': False,
    'tensorboard_log_path':paths['tensorboard_log'],
    # Fixed sizes for this demo.
    'num_snps':10000,
    'num_idps':139,
    'idp_tok_dims':64,
    'save_embeddings':False,
    'plot_embeddings':False,
    'results_path':os.path.join(os.getcwd(),'results',fname_out_root),
    # bool(int(...)) accepts '0'/'1' strings as well as values that are already
    # bool or int, so the flags end up as real booleans either way.
    'resume_from_batch':bool(int(args['resume_from_batch'])),
    'ckpt_name':args['ckpt_name'],
    'subject_based_pred_flag':bool(int(args['downstream_pred_task_flag'])),
    'out_flag':args['out_flag'],
    'target':args['target'],
}

In [None]:
# Typed view of the raw run arguments: every value is converted from its
# string form to the type the pipeline works with.

# `pairs_exist` was referenced below without ever being defined (NameError).
# NOTE(review): deriving it from the presence of the saved-pairs file is an
# assumption - confirm this is the intended meaning.
pairs_exist = os.path.exists(paths['path_saved_pairs'])

# The covariate names arrive as one comma-separated string ('Age, Sex').
# Calling list() on that string would yield single characters, so split on
# commas instead; any other iterable is copied into a list as before.
_covariates = args['coveriate_names']
if isinstance(_covariates, str):
    _covariates = [name.strip() for name in _covariates.split(',') if name.strip()]
else:
    _covariates = list(_covariates)

_index_col = args['index_col']

# Item access is used throughout because `args` is indexed as a dictionary
# elsewhere in the notebook.
run_args = {
    'batch_size': int(args['batch_size']),
    'ckpt_name': args['ckpt_name'],
    'count_bins': int(args['count_bins']),
    'covariates_names': _covariates,
    'dim_feedforward': int(args['dim_feedforward']),
    'target': args['target'],
    'd_model': int(args['d_model']),
    'dropout': float(args['dropout']),
    'epochs': int(args['epochs']),
    'feat_a_index_col': args['feat_a_index_col'],
    'feat_b_index_col': args['feat_b_index_col'],
    'feat_a_target_col': args['feat_a_target_col'],
    'feat_b_target_col': args['feat_b_target_col'],
    'fname_root_out': fname_out_root,
    'gpus_per_trial': args['gpus_per_trial'],
    # A column name stays a string; anything else is treated as a position.
    'index_col': _index_col if isinstance(_index_col, str) else int(_index_col),
    'learning_rate': float(args['learning_rate']),
    'nhead': int(args['nhead']),
    'num_layers': int(args['num_layers']),
    'out_flag': args['out_flag'],
    'pairs_exist': pairs_exist,
    # Flags are '0'/'1' strings: go through int() first, since bool('0') is True.
    'plot_embeddings': bool(int(args['plot_embeddings'])),
    'rnd_st': int(args['random_seed']),
    'resume_from_batch': bool(int(args['resume_from_batch'])),
    'save_embeddings': bool(int(args['save_embeddings'])),
    'downstream_pred_task_flag': bool(int(args['downstream_pred_task_flag'])),
    # Split sizes are given in percent and stored as fractions.
    'test_size': float(int(args['test_size']) / 100),
    'top_n_perc': float(args['top_n_perc']),
    'tune_flag': bool(int(args['tune_flag'])),
    'units': int(args['units']),
    'val_size': float(int(args['val_size']) / 100),
}

#### Run training

In [None]:
### Train model ###
# Runs `train` with the settings in `config` on `data`. The two returned loss
# sequences are used further down to pick the best checkpoint and to plot the
# training curves.
# `paths['checkpoint_name']` must be defined in the paths cell: it is the folder
# handed to `train` as `checkpoint_dir`.
train_losses, val_losses = train(config, data=data, checkpoint_dir = paths['checkpoint_name'])


#### Run test

In [None]:
## Evaluate model ##
# Pick the position of the smallest validation loss and use it to name the
# checkpoint to evaluate.
# NOTE(review): assumes `val_losses` holds one entry per epoch and that the
# checkpoints are named with the same zero-based index - confirm against `train`.
best_epoch = np.argmin(val_losses)
best_checkpoint_path = os.path.join(paths['checkpoint_name'], f'checkpoint_epoch_{best_epoch}')

# Evaluate on the test split using that checkpoint; the third return value of
# `test_model` is not used in this notebook.
loss_test, acc_test, _ = test_model(config, args ,data, test_idx, best_checkpoint_path)

# Collect metrics, loss curves and the hyper-parameters of the run in one
# dictionary for saving and printing.
results_dict = {
    'metrics':{
        'loss_test':loss_test,
        'acc_test':acc_test,
    },
    'data':{
        'train_losses': train_losses,
        'val_losses':val_losses,
    },
    'hyperparams':{
        'lr':config["lr"],
        'batch_size':config["batch_size"],
        'units':config["units"],
        'd_model':config["d_model"],
        'nhead':config["nhead"],
        'dim_feedforward':config["dim_feedforward"],
        'dropout':config["dropout"],
        # The configuration cell does not set these two entries, so indexing
        # them directly raised KeyError; report None when they are absent.
        'layer_norm_eps':config.get("layer_norm_eps"),
        'activation':config.get("activation"),
        'checkpoint_path':best_checkpoint_path
    }
}

#### Print results

In [None]:
 print(f'Saving results dictionary in {os.path.join(os.getcwd(),"results",fname_out_root,"result_dict.json")}')
    with open(os.path.join(os.getcwd(),'results',fname_out_root,'result_dict.json'), "w") as outfile:
        json.dump(results_dict, outfile)
    
    print(f'Test set loss {results_dict["metrics"]["loss_test"]}')
    print(f'Test set top-1 accuracy {results_dict["metrics"]["acc_test"]}')

    # Plot losses and result curves
    plot_training_curves(results_dict['data']['train_losses'], results_dict['data']['val_losses'], os.path.join(os.getcwd(),'results',fname_out_root,'training_curves.pdf'))
    if args.out_flag == 'clf':
        plot_roc_curve(results_dict['data']['test_preds'], results_dict['data']['test_labels'], os.path.join(os.getcwd(),'results',fname_out_root,'roc_curve.pdf'))
        plot_precision_recall_curve(results_dict['data']['test_preds'], results_dict['data']['test_labels'], os.path.join(os.getcwd(),'results',fname_out_root,'precision_recall_curve.pdf'))

    # Print hyperparameter configuration and results metrics
    print("Hyperparameter configuration and results metrics:")
    table_data = []
    # Add hyperparameter configuration to table data
    for key, value in results_dict['hyperparams'].items():
        table_data.append([key, value])

    # Add results metrics to table data
    for key, value in results_dict['metrics'].items():
        table_data.append([key, value])

    table = tabulate(table_data, headers=["Parameter", "Value"], tablefmt="grid")

    # Save table to txt file and print out
    with open(os.path.join(os.getcwd(),'results',fname_out_root,'results_and_config_out.txt'), "w") as outfile:
        outfile.write(table)
    print(table)