Libraries/module Imports

In [1]:
#   Custom imports for testing
from classification.cnn import CNNClassifier, GraphDataset, DataLoader
from utils.import_data import import_all_data, import_panda_csv
from utils.constants import BAND_NAMES
from utils.plot_fig import (
    plot_avg_roc_curve,
    plot_accuracies,
    plot_losses
)

%load_ext autoreload
%autoreload 2

In [2]:
#   Python imports
import numpy as np

In [3]:
#   Plotting imports
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
sns.set_theme(style="whitegrid")

In [4]:
import warnings

# Suppress specific warning
warnings.filterwarnings("ignore", category=RuntimeWarning, module="scipy.stats")

Import data

In [5]:
labels = [c.strip() for c in import_panda_csv("../data/mTBI/sources_TBI_MEGM001.csv").columns]
graphs, targets = import_all_data("../graphs/multiplex/MI")
graphs = [graph.to_numpy() for graph in graphs]

[2KData loaded successfully!


In [6]:
total_labels = ["{}_{}".format(band, label) for band in BAND_NAMES for label in labels]

Nomralise the data

In [7]:
for g in graphs:
    norm = np.linalg.norm(g)
    g = g/norm

Convolutional Neural Network

We use the architecture with the desired layers. After grid-searching the optimal hyperparameters, we check the accuracy of the model.

In [8]:
cnn = CNNClassifier(multilayer=True)

Grid Search for optimal hyperparameters based on cross-validated accuracy

In [9]:
cnn.hyperparameter_tuning(graphs, targets, multilayer=True, max_evals=40)

 92%|█████████▎| 37/40 [20:21:32<1:25:52, 1717.38s/trial, best loss: -0.9166666666666666] 

Cross-Validation

5-fold cross-validation repeated 300 times to obtain average accuracy and ROC AUC

In [None]:
# cnn = CNNClassifier(learning_rate=0.0011, batch_size=3, gamma=0.95)
#   Accuracies
accs = []
stds = []

#   ROC AUCs
all_mean_fprs = []
all_mean_tprs = []
all_std_tprs = []
all_aucs = []

hist = None
best_acc = 0.0

for _ in range(5):
    acc_params, roc_params, history = cnn.cross_validate(graphs, targets, 5, verbose=False, multilayer=True)
    accs.append(acc_params[0])
    stds.append(acc_params[1])
    
    all_mean_fprs.append(roc_params[0])
    all_mean_tprs.append(roc_params[1])
    all_std_tprs.append(roc_params[2])
    all_aucs.append(roc_params[3])
    
    if (acc_params[0] > best_acc):
        hist = history
        best_acc = acc_params[0]
    
print("Accuracy: {:.4f} ± {:.4f}".format(np.mean(accs), np.mean(stds)))

Plot Results

In [None]:
mean_mean_fpr = np.mean(all_mean_fprs, axis=0)
mean_mean_tpr = np.mean(all_mean_tprs, axis=0)
mean_std_tpr = np.mean(all_std_tprs, axis=0)
mean_auc = np.mean(all_aucs)

plot_avg_roc_curve(mean_mean_fpr, mean_mean_tpr, mean_std_tpr, mean_auc)

In [None]:
dataset = GraphDataset(graphs, targets)
dataset_loader = DataLoader(dataset, batch_size=3, shuffle=False)
y, preds = cnn.predict(dataset_loader)
print(y)
print(preds)

In [None]:
# cnn = CNNClassifier(learning_rate=0.00019233, batch_size=4, gamma=0.840144, weight_decay=0.0001)
plot_losses(hist)
plot_accuracies(hist)