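# In [None]:
# Train a CNN binary classifier (label 0 = 'no_epilepsy', 1 = any other label)
# on per-sample 19x19 matrices loaded from .npy files.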
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics import RocCurveDisplay

import tensorflow as tf
from tensorflow.keras import layers, models

# Select which per-sample feature file to train on. The commented line is the
# alternative correlation-based file; the active one is the normed-MI file.
# adj_filename = 'reduced_X_correlation_epilepsy_corpus_60s.npy'
adj_filename = 'reduced_X_normed_mi_epilepsy_corpus_60s.npy'

adj_60s = np.load(adj_filename)
# allow_pickle=True is presumably needed because the labels are stored as an
# object array (e.g. Python strings) — confirm. Unpickling can execute
# arbitrary code, so only load label files from a trusted source.
y_60s = np.load('reduced_y_epilepsy_corpus_60s.npy', allow_pickle=True)

# Binary target: 0 for 'no_epilepsy', 1 for every other label value.
y = np.array([0 if v == 'no_epilepsy' else 1 for v in y_60s])

# One 19x19 matrix per sample, plus an explicit trailing channel axis of
# size 1. The result matches the CNN's (19, 19, 1) input shape directly
# instead of relying on Keras to add the missing rank. reshape raises
# ValueError if the file does not hold exactly len(y) * 19 * 19 values.
X = adj_60s.reshape(len(y), 19, 19, 1)

# Shuffle the samples with a fixed seed. This uses the same global NumPy seed
# and legacy shuffle as before.
# NOTE(review): train_test_split below shuffles again by default, so this
# pre-shuffle is redundant. It is kept so the resulting splits stay stable.
np.random.seed(1)
perm = np.random.permutation(len(y))

# Fancy indexing gathers every sample in one vectorised step, whatever the
# trailing sample shape, instead of a Python loop over indices.
y_shuf = np.asarray(y)[perm].reshape(-1, 1)
X_shuf = np.asarray(X)[perm]

# Roughly 64% train / 16% validation / 20% test. First hold out 20% for the
# final test, then take 20% of the remainder for validation.
# NOTE(review): consider stratify=y_shuf / stratify=yall so class proportions
# match across splits if the classes are imbalanced.
Xall, Xts, yall, yts = train_test_split(X_shuf, y_shuf, test_size=0.2,
                                        random_state=1)

Xtr, Xval, ytr, yval = train_test_split(Xall, yall, test_size=0.2,
                                        random_state=1)

# Fix TensorFlow's global seed so weight initialisation is reproducible.
tf.random.set_seed(42)

# Small CNN on a single-channel 19x19 input:
#   19x19 -> conv3x3 -> 17x17 -> pool -> 8x8 -> conv3x3 -> 6x6 -> pool -> 3x3,
# then a dense head ending in one sigmoid unit, i.e. the predicted
# probability of label 1.
model = models.Sequential([
    # Declare the input with an explicit Input layer. Passing `input_shape`
    # to the first layer is deprecated in Keras 3.
    layers.Input(shape=(19, 19, 1)),
    layers.Conv2D(19, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(19, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Flatten(),
    layers.Dense(19, activation='relu'),
    layers.Dense(4, activation='tanh'),
    layers.Dense(1, activation='sigmoid'),
])

# Binary cross-entropy pairs with the single sigmoid output and 0/1 targets.
model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Train on the training split, scoring the validation split after each epoch.
history = model.fit(
    Xtr,
    ytr,
    validation_data=(Xval, yval),
    batch_size=48,
    epochs=80,
)

# One-off evaluation on the held-out test split. verbose=2 prints a summary line.
test_loss, test_acc = model.evaluate(Xts, yts, verbose=2)

# Plot train vs. validation accuracy per epoch.
# The metric is selected by name rather than by position in
# `history.history.keys()`. Key order depends on how Keras orders its logs,
# so positional indexing could pick 'loss' and plot it under an "Accuracy"
# label. 'accuracy' is the key produced by compiling with metrics=['accuracy'].
metric = 'accuracy'
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(history.history[metric], lw=4, ls='-', c='b', label='Train')
ax.plot(history.history['val_' + metric], lw=4, ls='--', c='r',
        label='Validation')
ax.set_ylabel('Accuracy', fontsize=18)
ax.set_xlabel('Epoch', fontsize=18)
ax.legend(loc='lower right', fontsize=20)
ax.tick_params(axis='both', which='both', labelsize=18)
plt.tight_layout()
plt.show()

# ROC curves for the validation (blue) and test (red) splits on shared axes.
# The raw sigmoid scores are kept in y_pred_va / y_pred_ts.
y_pred_va = model.predict(Xval)
y_pred_ts = model.predict(Xts)

fig, ax = plt.subplots(figsize=(8, 6))
roc_series = (
    ('Validation', yval, y_pred_va, 'b'),
    ('Test', yts, y_pred_ts, 'r'),
)
for split_name, split_true, split_scores, split_colour in roc_series:
    RocCurveDisplay.from_predictions(
        split_true, split_scores, color=split_colour, ax=ax, name=split_name)
plt.tight_layout()
plt.show()

# Confusion matrix for the test split at a fixed decision threshold.
# Reuse the test-set scores already computed for the ROC plot (y_pred_ts)
# instead of running inference on Xts a third time. The boolean mask is cast
# to 0/1 ints so predicted labels have the same type as yts.
threshold = 0.5
y_pred = (y_pred_ts > threshold).astype(int)

fig, ax = plt.subplots(figsize=(5, 5))
ConfusionMatrixDisplay.from_predictions(yts, y_pred,
                                        ax=ax,
                                        colorbar=False)
plt.tight_layout()
plt.show()
