In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.classification as M
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
from model_architectures import FNNEegSignals

from sklearn.model_selection import train_test_split, StratifiedKFold
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gc
import time
import os
import json
import copy
import subprocess
from IPython import display
from scipy.signal import savgol_filter
# Render inline matplotlib figures as SVG.
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated (since IPython 7.23) in favour of
# matplotlib_inline.backend_inline.set_matplotlib_formats — confirm it still exists in the installed IPython.
display.set_matplotlib_formats('svg')

#### Experiment Configuration

In [None]:
# Compute device for the model and every tensor in this notebook.
# Swap the argument for 'cpu' to run without a GPU.
device = torch.device('cuda:0')  # alternative: torch.device('cpu')

# Experiment configuration (paths, dataset name, hyper-parameters) stored next to the notebook.
with open('exp_config.json', 'r') as config_file:
    exp_conf = json.load(config_file)

#### Read the data

In [None]:
# Every split is a pair of CSV files (features, labels) inside the centralized copy of the dataset.
dir_data = f'{exp_conf["data_path"]}/centralized-{exp_conf["dataset"]}'

# Training Dataset
x_train, y_train = (pd.read_csv(f'{dir_data}/eeg_x_train.csv'),
                    pd.read_csv(f'{dir_data}/eeg_y_train.csv'))

# Validation Dataset
x_val, y_val = (pd.read_csv(f'{dir_data}/eeg_x_val.csv'),
                pd.read_csv(f'{dir_data}/eeg_y_val.csv'))

# Test Dataset
x_test, y_test = (pd.read_csv(f'{dir_data}/eeg_x_test.csv'),
                  pd.read_csv(f'{dir_data}/eeg_y_test.csv'))

#### Data Pre-Processing

In [None]:
## Normalization (Z-Score)
# Mean and standard deviation are estimated on the TRAINING split only and then
# applied unchanged to the validation and test splits. Every split therefore goes
# through the same transformation, and no statistics of held-out data leak into
# preprocessing (per-split statistics would also hide distribution shift).
x_train_mean = x_train.mean()
# Constant columns have sd == 0; dividing by it would yield NaN/inf features,
# so such columns are only centred (divided by 1).
x_train_sd = x_train.std().replace(0, 1)


def _zscore_with_train_stats(frame):
    """Standardize `frame` column-wise using the training-set mean and sd."""
    return (frame - x_train_mean) / x_train_sd


# Training Set
x_train = _zscore_with_train_stats(x_train)
# Validation Set
x_val = _zscore_with_train_stats(x_val)
# Testing Set
x_test = _zscore_with_train_stats(x_test)

#### Training/Evaluation Functions

In [None]:
def get_vram_usage(gpu_index=0):
    """Return the GPU memory currently in use, as reported by ``nvidia-smi``.

    ``nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits`` prints one
    bare number per GPU (presumably MiB, nvidia-smi's unit for memory.used — confirm).

    Args:
        gpu_index: Position of the GPU in nvidia-smi's output. Defaults to 0, the
            first GPU listed; on a single-GPU host this is the only line.

    Returns:
        The used memory as an ``int``, or ``None`` when nvidia-smi is unavailable,
        exits with a non-zero status, prints unparsable output, or has no GPU at
        ``gpu_index``. Failures are reported on stdout (best-effort helper).
    """
    cmd = ['nvidia-smi', '--query-gpu=memory.used', '--format=csv,noheader,nounits']
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if result.returncode != 0:
            print("Error:", result.stderr, flush=True)
            return None
        # One line per GPU: parsing the whole stdout as a single int fails on multi-GPU hosts.
        per_gpu_usage = [int(line) for line in result.stdout.splitlines() if line.strip()]
        return per_gpu_usage[gpu_index]
    except Exception as e:
        # Deliberately best-effort: a missing binary or odd output must not abort a run.
        print("An error occurred:", e, flush=True)
    return None

def train(model, x_train, y_train, x_val, y_val, exp_conf, exp_name, filepath):
    """Full-batch training loop with best-on-validation checkpointing.

    Each epoch performs a single optimizer step on the whole training set and then
    scores the validation set with ``evaluate``. The weights of the epoch with the
    lowest validation loss are kept and saved to ``filepath``.

    Args:
        model: network trained in place (assumed to already live on ``device``).
        x_train, y_train: training features / binary labels as tensors; labels are
            presumably shaped like the model output — confirm.
        x_val, y_val: validation features / labels with the same layout.
        exp_conf: configuration; reads 'optimizer' (name of a ``torch.optim``
            class), 'learning_rate', 'decision_boundary' and 'n_epochs'.
        exp_name: label shown in the progress message.
        filepath: destination for ``torch.save`` of the best state_dict, or ``None``.

    Returns:
        dict of per-epoch CPU tensors ('train_*' / 'val_*' losses, accuracies,
        sensitivities, specificities) plus 'epoch_best_model' (epoch index, or
        ``None`` when no epoch produced a comparable validation loss).
    """
    # Weight positives by the negative/positive ratio to counter class imbalance.
    n_pos = torch.tensor(y_train[y_train==1].shape[0]).float()
    n_neg = torch.tensor(y_train[y_train==0].shape[0]).float()
    pos_weight = n_neg / n_pos

    # Loss Function
    loss_function = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    # Optimizer class is looked up by name in torch.optim.
    optimizer_cls = getattr(torch.optim, exp_conf['optimizer'])
    optimizer = optimizer_cls(model.parameters(), lr=exp_conf['learning_rate'])

    # Metrics Accumulators
    train_losses = []
    train_accuracies = []
    train_sensitivities = []
    train_specificities = []
    val_losses = []
    val_accuracies = []
    val_sensitivities = []
    val_specificities = []
    params_best_model = None
    epoch_best_model = None
    # Running minimum of the validation loss: "strictly below the minimum" is the
    # same test as "strictly below every previous loss", without re-scanning the history.
    best_val_loss = float('inf')

    # Define evaluation metrics settings
    threshold = exp_conf['decision_boundary']
    accuracy_metric = M.BinaryAccuracy(threshold=threshold).to(device)
    recall_metric = M.BinaryRecall(threshold=threshold).to(device)
    specificity_metric = M.BinarySpecificity(threshold=threshold).to(device)

    # For each training epoch
    for epoch_i in range(exp_conf['n_epochs']):
        ## Training
        model.train()

        # Update on progress
        print(f'Running {exp_name}, epoch {epoch_i} of {exp_conf["n_epochs"]-1}')
        display.clear_output(wait=True)

        # Forward Pass
        y_hat = model(x_train)

        # Compute the Loss; store a detached copy so the history does not keep
        # every epoch's autograd graph alive.
        train_loss = loss_function(y_hat, y_train)
        train_losses.append(train_loss.detach())

        # Backpropagation
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        ## Compute epoch training metrics (no gradients needed)
        y_pred = torch.sigmoid(y_hat.detach())
        train_accuracies.append(accuracy_metric(y_pred, y_train))
        train_sensitivities.append(recall_metric(y_pred, y_train))
        train_specificities.append(specificity_metric(y_pred, y_train))

        ## Compute epoch evaluation metrics
        val_results = evaluate(model, x_val, y_val, exp_conf)
        val_loss = val_results['loss']

        # Update best model. A NaN loss compares False and never counts as an improvement.
        if val_loss < best_val_loss:
            best_val_loss = val_loss.item()
            epoch_best_model = epoch_i
            params_best_model = copy.deepcopy(model.state_dict())

        val_losses.append(val_loss)
        val_accuracies.append(val_results['accuracy'])
        val_sensitivities.append(val_results['sensitivity'])
        val_specificities.append(val_results['specificity'])

    # No epoch was selected (n_epochs == 0 or all validation losses NaN): keep the
    # final weights so a loadable state_dict is always written.
    if params_best_model is None:
        params_best_model = copy.deepcopy(model.state_dict())

    # Save the best model found to disk
    if filepath is not None:
        torch.save(params_best_model, filepath)

    # Collect all objects in CPU
    result = {
        # Training
        'train_losses': torch.as_tensor(train_losses).cpu()
        , 'train_accuracies': torch.as_tensor(train_accuracies).cpu()
        , 'train_sensitivities': torch.as_tensor(train_sensitivities).cpu()
        , 'train_specificities': torch.as_tensor(train_specificities).cpu()
        # Validation
        , 'val_losses': torch.as_tensor(val_losses).cpu()
        , 'val_accuracies': torch.as_tensor(val_accuracies).cpu()
        , 'val_sensitivities': torch.as_tensor(val_sensitivities).cpu()
        , 'val_specificities': torch.as_tensor(val_specificities).cpu()
        , 'epoch_best_model': epoch_best_model
    }

    return result

def evaluate(model, x_val, y_val, exp_conf):
    """Score ``model`` on a held-out split without tracking gradients.

    The model is switched to eval mode and left there. The loss is class-balanced
    with the negative/positive ratio of ``y_val`` itself, not of the training set.

    Returns:
        dict with 'loss', 'accuracy', 'sensitivity' and 'specificity' tensors,
        thresholded at ``exp_conf['decision_boundary']``.
    """
    model.eval()

    # Class-imbalance weight derived from the labels under evaluation.
    positives = y_val[y_val == 1].shape[0]
    negatives = y_val[y_val == 0].shape[0]
    pos_weight = torch.tensor(negatives).float() / torch.tensor(positives).float()
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    threshold = exp_conf['decision_boundary']
    with torch.no_grad():
        logits = model(x_val)
        results = {'loss': criterion(logits, y_val)}
        probabilities = torch.sigmoid(logits)

        # Same metric family as in `train`, rebuilt for every call.
        metric_classes = {
            'accuracy': M.BinaryAccuracy,
            'sensitivity': M.BinaryRecall,
            'specificity': M.BinarySpecificity,
        }
        for metric_name, metric_cls in metric_classes.items():
            metric = metric_cls(threshold=threshold).to(device)
            results[metric_name] = metric(probabilities, y_val)

    return results

def test(model, x_val, y_val, exp_conf):
    """Final evaluation of ``model`` on the split passed as ``x_val`` / ``y_val``.

    Computes the same loss and thresholded metrics as ``evaluate``, plus the
    confusion matrix, a ROC curve sampled at 100 thresholds, and the AUROC.

    Returns:
        dict with 'loss', 'accuracy', 'sensitivity', 'specificity', 'auroc' and
        'confusion_matrix' (left on the metric device), and 'roc_fpr', 'roc_tpr',
        'roc_thresholds' moved to CPU. torchmetrics lays the binary confusion
        matrix out as [[TN, FP], [FN, TP]].
    """
    model.eval()
    # Class-imbalance weight derived from the labels under evaluation.
    n_pos = torch.tensor(y_val[y_val==1].shape[0]).float()
    n_neg = torch.tensor(y_val[y_val==0].shape[0]).float()
    pos_weight = n_neg / n_pos

    # Loss Function
    loss_function = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    with torch.no_grad():
        y_hat = model(x_val)
        loss = loss_function(y_hat, y_val)
        y_pred = torch.sigmoid(y_hat)
        # The loss needs float labels, but torchmetrics' curve metrics (ROC / AUROC)
        # reject floating-point targets, so both get the same integer copy.
        y_true = y_val.long()

        # Define test metrics settings
        threshold = exp_conf['decision_boundary']
        accuracy_metric = M.BinaryAccuracy(threshold=threshold).to(device)
        recall_metric = M.BinaryRecall(threshold=threshold).to(device)
        specificity_metric = M.BinarySpecificity(threshold=threshold).to(device)
        confusion_matrix_metric = M.BinaryConfusionMatrix(threshold=threshold).to(device)
        roc_metric = M.BinaryROC(thresholds=100).to(device)
        auroc_metric = M.BinaryAUROC().to(device)

        # Compute metrics
        accuracy = accuracy_metric(y_pred, y_val)
        sensitivity = recall_metric(y_pred, y_val)
        specificity = specificity_metric(y_pred, y_val)
        confusion_matrix = confusion_matrix_metric(y_pred, y_val)
        auroc = auroc_metric(y_pred, y_true)
        roc_fpr, roc_tpr, roc_thresholds = roc_metric(y_pred, y_true)

        results = {
            'loss': loss, 'accuracy': accuracy, 'sensitivity': sensitivity, 'specificity': specificity
            , 'confusion_matrix': confusion_matrix, 'roc_fpr': roc_fpr.cpu(), 'roc_tpr': roc_tpr.cpu()
            , 'roc_thresholds': roc_thresholds.cpu(), 'auroc': auroc
        }

    return results

def compile_results(epochs_results, inference_results, exp_conf, elapsed_time, filepath):
    """Turn per-epoch curves and test results into DataFrames and optionally append them to CSV logs.

    Args:
        epochs_results: mapping metric name -> per-epoch sequence (one entry per
            epoch), plus the scalar 'epoch_best_model'.
        inference_results: output of ``test``: 0-dim tensors for loss/accuracy/
            sensitivity/specificity/auroc, the 2x2 confusion matrix and the ROC arrays.
        exp_conf: experiment configuration. It is NOT modified; the name recorded in
            the outputs is '<experiment_name>_<dataset>'.
        elapsed_time: training wall time, stored as-is.
        filepath: directory that receives results_test_wide.csv, results_test_long.csv
            and exp_summary_test.csv (appended; a header is written only for new
            files), or ``None`` to skip writing.

    Returns:
        {'results': {'wide': DataFrame, 'long': DataFrame}, 'summary_exp': one-row DataFrame}
    """
    # Work on a copy: writing into the caller's dict appended the dataset suffix
    # again every time the calling cell was re-run.
    exp_conf = dict(exp_conf)
    exp_conf['experiment_name'] = f"{exp_conf['experiment_name']}_{exp_conf['dataset']}"
    epoch_best_model = epochs_results['epoch_best_model']
    # Exact key comparison; a substring test against 'epoch_best_model' would also
    # drop metric keys such as 'best' or 'model'.
    epochs_results = {key: value for key, value in epochs_results.items() if key != 'epoch_best_model'}

    ## Training Results (Wide Format)
    results_wide_df = pd.DataFrame.from_dict(epochs_results, orient='columns')
    results_wide_df['epoch'] = list(range(exp_conf['n_epochs']))
    results_wide_df['exp_name'] = exp_conf['experiment_name']

    ## Training Results (Long Format)
    results_long_df = pd.melt(frame=results_wide_df, id_vars=['epoch', 'exp_name'], value_vars=list(epochs_results), var_name='metric', value_name='value')
    results_long_df = results_long_df.sort_values('epoch')

    ## Training Summary
    roc_dict = {key: inference_results[key].tolist() for key in ['roc_fpr', 'roc_tpr', 'roc_thresholds']}
    exp_conf_simpl = {key: value for key, value in exp_conf.items() if key not in ['data_path', 'models_path', 'logs_path']}

    # torchmetrics' BinaryConfusionMatrix is [[TN, FP], [FN, TP]] (rows = true label,
    # columns = predicted label), so the flattened order is TN, FP, FN, TP.
    # NOTE: summary rows appended before this correction have tp<->tn and fn<->fp swapped.
    tn, fp, fn, tp = (cell.item() for cell in inference_results['confusion_matrix'].flatten())

    # Column order is kept stable because the summary CSV is appended to without a new header.
    summary_exp = {
        'exp_name': [exp_conf['experiment_name']]
        , 'exp_type': [exp_conf['experiment_type']]
        , 'exp_resource': [exp_conf['resource']]
        , 'exp_device': [exp_conf['device']]
        , 'exp_configuration': [json.dumps(exp_conf_simpl)]
        , 'elapsed_time': [elapsed_time]
        , 'epoch_best_model': epoch_best_model
        , 'loss': [inference_results['loss'].item()]
        , 'accuracy': [inference_results['accuracy'].item()]
        , 'sensitivity': [inference_results['sensitivity'].item()]
        , 'specificity': [inference_results['specificity'].item()]
        , 'auroc': [inference_results['auroc'].item()]
        , 'tp': [tp]
        , 'fn': [fn]
        , 'fp': [fp]
        , 'tn': [tn]
        , 'roc_thresholds': [json.dumps(roc_dict)]
    }

    exp_summary_df = pd.DataFrame.from_dict(summary_exp, orient='columns')

    # Persist experiments to disk (If path is provided)
    if filepath is not None:
        filename_train_wide = f'{filepath}/results_test_wide.csv'
        results_wide_df.to_csv(filename_train_wide, index=False, mode='a', header=not os.path.exists(filename_train_wide))
        filename_train_long = f'{filepath}/results_test_long.csv'
        results_long_df.to_csv(filename_train_long, index=False, mode='a', header=not os.path.exists(filename_train_long))
        filename_summary = f'{filepath}/exp_summary_test.csv'
        exp_summary_df.to_csv(filename_summary, index=False, mode='a', header=not os.path.exists(filename_summary))

    return {
        'results': {'wide': results_wide_df, 'long': results_long_df}
        , 'summary_exp': exp_summary_df
    }

#### Fit the Model

In [None]:
# Build the network and move it to the compute device.
fnn = FNNEegSignals(exp_conf['n_layers'], exp_conf['n_units'], exp_conf['perc_dropout'])
fnn.to(device)

# Convert the (already normalized) DataFrames into float32 tensors on the device.
# The tensors get their own *_t names so the pandas frames stay intact and this
# cell can be re-run without re-executing the loading / normalization cells.
# Training Dataset
x_train_t = torch.tensor(x_train.values, device=device).float()
y_train_t = torch.tensor(y_train.values, device=device).float()
# Validation Dataset
x_val_t = torch.tensor(x_val.values, device=device).float()
y_val_t = torch.tensor(y_val.values, device=device).float()
# Testing Dataset
x_test_t = torch.tensor(x_test.values, device=device).float()
y_test_t = torch.tensor(y_test.values, device=device).float()

# Train the model; `train` saves the best (lowest validation loss) weights here.
filepath_model = f'{exp_conf["models_path"]}/{exp_conf["experiment_name"]}.pt'

# perf_counter is monotonic and intended for measuring intervals (time.time can jump).
start_time = time.perf_counter()
epoch_results = train(fnn, x_train_t, y_train_t, x_val_t, y_val_t, exp_conf, f'{exp_conf["experiment_name"]}', filepath_model)
end_time = time.perf_counter()
elapsed_time = end_time - start_time
# Downstream logs report these curves under 'test_*' names ('val_losses' -> 'test_losses', ...).
epoch_results = {key.replace('val', 'test'): np.array(value) for key, value in epoch_results.items()}

# Restore best model for inference; map_location loads the weights directly onto `device`.
model = FNNEegSignals(exp_conf['n_layers'], exp_conf['n_units'], exp_conf['perc_dropout'])
model.load_state_dict(torch.load(filepath_model, map_location=device))
model.to(device)
model.eval()

# Final Model Evaluation
inference_results = test(model, x_test_t, y_test_t, exp_conf)

# Free GPU memory after each run: drop references, collect them, then release cached blocks.
del fnn, model, x_train_t, y_train_t, x_val_t, y_val_t, x_test_t, y_test_t
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Assemble result tables for this run and append them to the CSV logs under logs_path.
comp_test_results = compile_results(
    epochs_results=epoch_results,
    inference_results=inference_results,
    exp_conf=exp_conf,
    elapsed_time=elapsed_time,
    filepath=exp_conf['logs_path'],
)