Libraries/module Imports

In [1]:
#   Custom imports for testing
from classification.cnn import CNNClassifier, GraphDataset, DataLoader
from utils.import_data import import_all_data, import_panda_csv
from utils.constants import BAND_NAMES
from utils.plot_fig import (
    plot_avg_roc_curve,
    plot_accuracies,
    plot_losses
)

%load_ext autoreload
%autoreload 2

In [2]:
#   Python imports
import numpy as np

In [3]:
#   Plotting imports
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline
sns.set_theme(style="whitegrid")

In [4]:
import warnings

# Suppress specific warning
warnings.filterwarnings("ignore", category=RuntimeWarning, module="scipy.stats")

Import data

In [5]:
labels = [c.strip() for c in import_panda_csv("../data/mTBI/sources_TBI_MEGM001.csv").columns]
graphs, targets = import_all_data("../graphs/full_multilayer/IPLV")
graphs = [graph.to_numpy() for graph in graphs]

[2KData loaded successfully!


In [6]:
total_labels = ["{}_{}".format(band, label) for band in BAND_NAMES for label in labels]

Nomralise the data

In [7]:
for g in graphs:
    norm = np.linalg.norm(g)
    g = g/norm

Convolutional Neural Network

We use the architecture with the desired layers. After grid-searching the optimal hyperparameters, we check the accuracy of the model.

In [None]:
cnn = CNNClassifier(multilayer=True)

Grid Search for optimal hyperparameters based on cross-validated accuracy

In [None]:
lr_search = np.linspace(0.001, 0.0013, 4)
bs_search = [3, 4, 5, 6]
wd_search = [0.0]
gamma_search = [0.8, 0.9, 0.95]


cnn.hyperparameter_tuning(graphs, targets, lr_search, bs_search, wd_search, gamma_search, multilayer=True)

Cross-Validation

5-fold cross-validation repeated 300 times to obtain average accuracy and ROC AUC

In [None]:
# cnn = CNNClassifier(learning_rate=0.0011, batch_size=3, gamma=0.95)
#   Accuracies
accs = []
stds = []

#   ROC AUCs
all_mean_fprs = []
all_mean_tprs = []
all_std_tprs = []
all_aucs = []

hist = None
best_acc = 0.0

for _ in range(5):
    acc_params, roc_params, history = cnn.cross_validate(graphs, targets, 5, verbose=False, multilayer=True)
    accs.append(acc_params[0])
    stds.append(acc_params[1])
    
    all_mean_fprs.append(roc_params[0])
    all_mean_tprs.append(roc_params[1])
    all_std_tprs.append(roc_params[2])
    all_aucs.append(roc_params[3])
    
    if (acc_params[0] > best_acc):
        hist = history
        best_acc = acc_params[0]
    
print("Accuracy: {:.4f} ± {:.4f}".format(np.mean(accs), np.mean(stds)))

Plot Results

In [None]:
mean_mean_fpr = np.mean(all_mean_fprs, axis=0)
mean_mean_tpr = np.mean(all_mean_tprs, axis=0)
mean_std_tpr = np.mean(all_std_tprs, axis=0)
mean_auc = np.mean(all_aucs)

plot_avg_roc_curve(mean_mean_fpr, mean_mean_tpr, mean_std_tpr, mean_auc)

In [34]:
dataset = GraphDataset(graphs, targets)
dataset_loader = DataLoader(dataset, batch_size=3, shuffle=False)
y, preds = cnn.predict(dataset_loader)
print(y)
print(preds)

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 1 0 0 0 1 1 1 1 1 1 0 0 1 0 1 0 0 0 0 0 0 0 0 1 0 1 1 1 1 0 1 1
 1 0 1 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]


In [35]:
cnn = CNNClassifier(learning_rate=0.00015, batch_size=3, gamma=0.8, weight_decay=0.0)
cnn.classify(graphs, targets, verbose=True, multilayer=True)

Epoch [1]:	Training loss: 0.7659	Training Accuracy: 0.4722	Validation loss: 0.7090	Validation accuracy: 0.3333
Epoch [2]:	Training loss: 0.7009	Training Accuracy: 0.3056	Validation loss: 0.6922	Validation accuracy: 0.5833
Epoch [3]:	Training loss: 0.6732	Training Accuracy: 0.6111	Validation loss: 0.6599	Validation accuracy: 0.6667
Epoch [4]:	Training loss: 0.6653	Training Accuracy: 0.5556	Validation loss: 0.6523	Validation accuracy: 0.6667
Epoch [5]:	Training loss: 0.6566	Training Accuracy: 0.5556	Validation loss: 0.6433	Validation accuracy: 0.6667
Epoch [6]:	Training loss: 0.6405	Training Accuracy: 0.5556	Validation loss: 0.6662	Validation accuracy: 0.6667
Epoch [7]:	Training loss: 0.6310	Training Accuracy: 0.9444	Validation loss: 0.6610	Validation accuracy: 0.6667
Epoch [8]:	Training loss: 0.5951	Training Accuracy: 0.6111	Validation loss: 0.6354	Validation accuracy: 0.6667
Epoch [9]:	Training loss: 0.5745	Training Accuracy: 0.6944	Validation loss: 0.6675	Validation accuracy: 0.9167
E

(array([0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]))

In [None]:
plot_losses(hist)
plot_accuracies(hist)