# In [None]:
# train gcn model
import numpy as np
import pandas as pd
import scipy.sparse as sp
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.metrics import categorical_accuracy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from spektral.data import Dataset, DisjointLoader, Graph
from spektral.layers import GCSConv, GCNConv, GeneralConv
from spektral.layers import GlobalAvgPool, GlobalMaxPool

# load data
# Adjacency source used by LoadGraphs; the commented line is the alternative
# correlation-based adjacency file.
# adj_filename = 'reduced_X_correlation_epilepsy_corpus_60s.npy'
adj_filename = 'reduced_X_normed_mi_epilepsy_corpus_60s.npy'

reduced_index_df = pd.read_csv("reduced_epilepsy_corpus_window_index_60s.csv")
# One row per distinct (patient_id, numeric_label) pair.
chose_ids = reduced_index_df[['patient_id', 'numeric_label']].drop_duplicates()

# Patient-level split: ~60% of the rows above go to training, half of the
# remainder to validation, and everything left to test. Validation and test
# exclude every patient id already assigned, so the three id sets are
# disjoint. LoadGraphs selects windows with `isin(patient_ids)`, so all windows
# of a patient land in the same split.
# `~` is the documented negation for boolean masks (NumPy rejects unary `-`
# on booleans; pandas only tolerates it as a special case).
train_ids = chose_ids['patient_id'].sample(frac=0.6, random_state=3).tolist()
val_ids = (chose_ids.loc[~chose_ids['patient_id'].isin(train_ids), 'patient_id']
           .sample(frac=0.5, random_state=4).tolist())
test_ids = (chose_ids.loc[~chose_ids['patient_id'].isin(train_ids + val_ids),
                          'patient_id'].tolist())

class LoadGraphs(Dataset):
    """Spektral dataset with one graph per 60 s EEG window of selected patients.

    Windows are taken from ``reduced_epilepsy_corpus_window_index_60s.csv``
    where ``patient_id`` is in ``patient_ids``. Each window becomes a
    ``Graph`` with:

    * node features: the matching entry of the EEG window array;
    * adjacency: the matching entry of ``adj_filename``, reshaped to 19x19 and
      stored as a CSR matrix;
    * label: a one-hot vector, where column 0 means ``'no_epilepsy'`` and
      column 1 means any other label.
    """

    def __init__(self, patient_ids, **kwargs):
        # Stored before calling the base constructor because read() needs it
        # (Spektral's Dataset.__init__ is expected to call read()).
        self.patient_ids = patient_ids
        super().__init__(**kwargs)

    def read(self):
        """Build and return the list of ``Graph`` objects for ``patient_ids``."""
        window_index = pd.read_csv("reduced_epilepsy_corpus_window_index_60s.csv")
        wanted = window_index['patient_id'].isin(self.patient_ids)
        # read_csv gives a default RangeIndex, so these labels are also the
        # row positions into the .npy arrays below.
        rows = window_index.loc[wanted].index.values

        adjacency = np.load(adj_filename)[rows]
        # NOTE: allow_pickle=True can execute code embedded in a malicious
        # .npy file; only load trusted files here.
        raw_labels = np.load('reduced_y_epilepsy_corpus_60s.npy', allow_pickle=True)[rows]
        windows = np.load('reduced_X_windows_epilepsy_corpus_60s.npy', allow_pickle=True)[rows]

        # Binary class id, then one-hot encode for categorical cross-entropy.
        class_ids = np.asarray(
            [0 if 'no_epilepsy' == v else 1 for v in raw_labels], dtype=int
        )
        one_hot = np.eye(2)[class_ids]

        matrices = adjacency.reshape(len(class_ids), 19, 19)

        graphs = []
        for features, matrix, target in zip(windows, matrices, one_hot):
            # NOTE(review): proportional_thresholding is defined outside this
            # cell and its return value is discarded. The threshold only
            # reaches the Graph if it modifies `matrix` in place — confirm.
            proportional_thresholding(matrix_a=matrix, percentile=50)
            graphs.append(Graph(x=features, a=sp.csr_matrix(matrix), y=target))
        return graphs


# reading preprocessed data
# Each split is built from its own disjoint set of patient ids. Validation
# must use val_ids: building it from train_ids would make validation metrics
# and early stopping measure performance on training data.
data_tr = LoadGraphs(patient_ids=train_ids)
data_va = LoadGraphs(patient_ids=val_ids)
data_te = LoadGraphs(patient_ids=test_ids)

# shuffle graph datasets (fixed NumPy seed for reproducible orderings)
np.random.seed(1)
idx_tr = np.random.permutation(len(data_tr))
data_tr = data_tr[idx_tr]
idx_va = np.random.permutation(len(data_va))
data_va = data_va[idx_va]
idx_te = np.random.permutation(len(data_te))
data_te = data_te[idx_te]

# config
learning_rate = 0.0001  # Adam step size
epochs = 80  # maximum number of passes over the training set
es_patience = 10  # epochs without val-loss improvement before early stopping
batch_size = 32  # graphs per mini-batch

# data loaders: the training loader is limited to `epochs` passes. The
# validation/test loaders are created without an `epochs` argument (Spektral
# default) and are consumed batch-by-batch by evaluate().
loader_tr = DisjointLoader(data_tr, batch_size=batch_size, epochs=epochs)
loader_va, loader_te = (
    DisjointLoader(dataset, batch_size=batch_size) for dataset in (data_va, data_te)
)

# build model
# Keyword arguments shared by every GeneralConv layer in Net. The misspelled
# name is kept because Net looks it up by this name.
network_prarams = dict(
    channels=32,
    activation='relu',
    aggregate='mean',
    dropout=0.25,
    kernel_regularizer='l2',
    bias_regularizer='l2',
    activity_regularizer='l2',
)
class Net(Model):
    """Graph classifier: three GeneralConv layers, global max pooling, MLP head.

    ``call`` expects an ``(x, a, i)`` triple: node features, adjacency, and the
    graph index used to pool nodes per graph. It returns 2-way softmax
    probabilities.
    """

    def __init__(self):
        super().__init__()
        # Three message-passing layers with identical hyper-parameters.
        self.conv1 = GeneralConv(**network_prarams)
        self.conv2 = GeneralConv(**network_prarams)
        self.conv3 = GeneralConv(**network_prarams)
        # Graph-level readout grouped by the graph index `i`.
        self.global_pool = GlobalMaxPool()
        # Dense head ending in a 2-class softmax.
        self.dense1 = Dense(19, activation="relu")
        self.dense2 = Dense(4, activation="relu")
        self.dense3 = Dense(2, activation="softmax")

    def call(self, inputs):
        node_feats, adjacency, graph_idx = inputs
        for conv in (self.conv1, self.conv2, self.conv3):
            node_feats = conv([node_feats, adjacency])
        hidden = self.global_pool([node_feats, graph_idx])
        for layer in (self.dense1, self.dense2, self.dense3):
            hidden = layer(hidden)
        return hidden

model = Net()
# Adam's first positional parameter is the learning rate.
optimizer = Adam(learning_rate)
# Categorical cross-entropy over the one-hot targets and Net's softmax output.
loss_fn = CategoricalCrossentropy()

# fit the model
@tf.function(input_signature=loader_tr.tf_signature(),
             experimental_relax_shapes=True)
def train_step(inputs, target):
    """Apply one optimizer update on a batch and return ``(loss, accuracy)``.

    The loss is ``loss_fn`` plus ``sum(model.losses)``, presumably the
    regularisation penalties configured in ``network_prarams``. The accuracy is
    the batch mean of ``categorical_accuracy``.
    """
    with tf.GradientTape() as tape:
        probs = model(inputs, training=True)
        data_loss = loss_fn(target, probs)
        loss = data_loss + sum(model.losses)
    # Read the variables after the forward pass: a subclassed model creates
    # its weights lazily on its first call.
    params = model.trainable_variables
    grads = tape.gradient(loss, params)
    optimizer.apply_gradients(zip(grads, params))
    batch_acc = tf.reduce_mean(categorical_accuracy(target, probs))
    return loss, batch_acc

def evaluate(loader):
    """Return the batch-size-weighted mean ``[loss, accuracy]`` over one pass.

    Pulls exactly ``loader.steps_per_epoch`` batches from ``loader`` and runs
    ``model`` in inference mode on each. Each batch is weighted by
    ``len(target)``, so a short final batch counts proportionally less.

    Note: the loss is ``loss_fn`` only. Unlike ``train_step``, it does not add
    ``model.losses``.

    Raises:
        ValueError: if ``loader.steps_per_epoch`` is not positive (no batches
            to evaluate). Without this check the caller's
            ``loss, acc = evaluate(...)`` would fail on ``None``.
    """
    n_steps = loader.steps_per_epoch
    if not n_steps or n_steps <= 0:
        raise ValueError("evaluate() called on a loader with no batches")

    output = []
    for _ in range(n_steps):
        inputs, target = next(loader)
        pred = model(inputs, training=False)
        output.append((
            float(loss_fn(target, pred)),
            float(tf.reduce_mean(categorical_accuracy(target, pred))),
            len(target),
        ))
    output = np.array(output)
    return np.average(output[:, :-1], 0, weights=output[:, -1])

# Training-loop state. tr_results, val_results and best_weights are read
# again by the evaluation and plotting code further down.
epoch = step = 0
best_val_loss = np.inf
best_weights = None
patience = es_patience
results = []      # (loss, acc) of each batch in the current epoch
tr_results = []   # mean training (loss, acc) per epoch
val_results = []  # validation (loss, acc) per epoch


tf.random.set_seed(42)
for batch in loader_tr:
    step += 1
    loss, acc = train_step(*batch)
    results.append((np.array(loss), np.array(acc)))
    if step != loader_tr.steps_per_epoch:
        continue  # epoch not finished yet

    # End of epoch: reset the batch counter and score the validation set.
    step = 0
    epoch += 1
    val_loss, val_acc = evaluate(loader_va)
    val_results.append((val_loss, val_acc))
    epoch_train = np.mean(results, 0)
    print(
        "Ep. {} - Loss: {:.3f} - Acc: {:.3f} - Val loss: {:.3f} - Val acc: {:.3f}".format(
            epoch, *epoch_train, val_loss, val_acc
        )
    )
    tr_results.append(epoch_train)

    # Early stopping on validation loss. The comparison stays `<` so that a
    # NaN val_loss is treated as "no improvement".
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience = es_patience
        print("New best val_loss {:.3f}".format(val_loss))
        best_weights = model.get_weights()
    else:
        patience -= 1
        if patience == 0:
            print("Early stopping (best val_loss: {})".format(best_val_loss))
            break
    results = []

# evaluate model
# Restore the weights from the epoch with the lowest validation loss.
# best_weights stays None if no epoch ever beat +inf (e.g. every val_loss
# was NaN). In that case evaluate the current weights instead of crashing
# in set_weights(None).
if best_weights is not None:
    model.set_weights(best_weights)  # load best model
else:
    print("Warning: no best weights recorded; evaluating final weights")
test_loss, test_acc = evaluate(loader_te)
print("Done. Test loss: {:.4f}. Test acc: {:.2f}".format(test_loss, test_acc))

# plot evaluation results: per-epoch accuracy, training (solid blue) vs.
# validation (dashed red); column 1 of each history row is the accuracy.
val_results = np.array(val_results)
tr_results = np.array(tr_results)

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 8))
curves = (
    ("Train", tr_results, '-', 'b'),
    ("Validation", val_results, '--', 'r'),
)
for curve_label, history, style, colour in curves:
    ax.plot(history[:, 1], lw=4, ls=style, c=colour, label=curve_label)
ax.set_ylabel('Accuracy', fontsize=18)
ax.set_xlabel("Epoch", fontsize=18)
ax.legend(loc=4, fontsize=20)
ax.tick_params(axis='both', which='both', labelsize=18)
plt.tight_layout()
plt.show()
