In [None]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np 
from scipy import signal 
import pandas as pd
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tensorflow_addons as tfa
from keras import regularizers

In [None]:
# Pre-split train/test arrays for the GAT experiments. mmap_mode='r' memory-maps
# each .npy file read-only instead of reading it into RAM up front.
x_train, y_train, x_test, y_test = [
    np.load(npy_path, mmap_mode='r')
    for npy_path in (
        'Datasets/GAT/traindata_new.npy',
        'Datasets/GAT/trainlabels_new.npy',
        'Datasets/GAT/testdata_new.npy',
        'Datasets/GAT/testlabels_new.npy',
    )
]

In [None]:
# Standardise with one global mean/std computed on the training split only, and
# reuse those statistics for the test split. A trailing channel axis is then added
# for the Conv2D input.
# NOTE(review): this cell is not idempotent — re-running it normalises again and
# adds a second trailing axis; re-run the loading cell first.
mean = x_train.mean()
std = x_train.std()
x_train, x_test = [np.expand_dims((split - mean) / std, axis=-1) for split in (x_train, x_test)]

# Seeded permutation of the training samples; labels are permuted identically.
# (Legacy np.random.permutation(n) is arange(n) shuffled by the global RandomState.)
np.random.seed(42)
train_indices = np.random.permutation(x_train.shape[0])
x_train, y_train = x_train[train_indices], y_train[train_indices]

In [None]:
# Restore a previously saved model from its epoch-175 checkpoint.
model = keras.models.load_model(filepath='Saved_models/GAT_paper_model/cp_0175.ckpt')

In [None]:
# Score the restored model on the test split; the returned loss/metric values are the cell output.
model.evaluate(x=x_test, y=y_test)

In [None]:
## for 18 channels
# Symmetric 18x18 adjacency with self-loops, written as each node's neighbour list
# (row k lists the columns set to 1 in row k).
# NOTE(review): a later cell rebinds `adj` to a 12-channel graph before the model is built.
neighbours_18 = [
    [0, 1, 7, 10],
    [0, 1, 2, 6, 11],
    [1, 2, 3, 5, 12],
    [2, 3, 4, 13],
    [3, 4, 5, 14],
    [2, 4, 5, 6, 9, 15],
    [1, 5, 6, 7, 8, 16],
    [0, 6, 7, 8, 17],
    [6, 7, 8, 9, 16, 17],
    [5, 8, 9, 15],
    [0, 10, 11, 17],
    [1, 10, 11, 12, 16],
    [2, 11, 12, 13, 15],
    [3, 12, 13, 14],
    [4, 13, 14, 15],
    [5, 9, 12, 14, 15, 16],
    [6, 8, 11, 15, 16, 17],
    [7, 8, 10, 16, 17],
]
adj_18 = np.zeros((18, 18), dtype=np.float32)
for node, node_neighbours in enumerate(neighbours_18):
    adj_18[node, node_neighbours] = 1.0
adj = tf.constant(adj_18, dtype=tf.float32)

In [None]:
# Restrict TensorFlow to a single GPU.
# The original hard-coded physical_devices[1], which raised IndexError on machines
# with fewer than two GPUs. Fall back to the last available GPU instead.
GPU_INDEX = 1
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    chosen_gpu = physical_devices[min(GPU_INDEX, len(physical_devices) - 1)]
    try:
        tf.config.set_visible_devices(chosen_gpu, 'GPU')
    except RuntimeError as err:
        # Visible devices cannot change once the TF runtime is initialised (e.g. by an
        # earlier load_model call); restart the kernel and run this cell first.
        print(f"Could not restrict visible GPUs to {chosen_gpu.name}: {err}")
else:
    print("No GPU found; TensorFlow will run on CPU.")

In [None]:

channel_names = ["Fp1-T3", "T3-O1", "Fp1-C3", "C3-O1", "Fp2-C4", "C4-O2", "Fp2-T4", "T4-O2", "T3-C3", "C3-Cz", "Cz-C4", "C4-T4"]

# Two bipolar channels are connected when they share an electrode. Every channel
# shares both electrodes with itself, so self-loops are included.
electrode_sets = [set(name.split("-")) for name in channel_names]
indices = [
    [row, col]
    for row, row_electrodes in enumerate(electrode_sets)
    for col, col_electrodes in enumerate(electrode_sets)
    if row_electrodes & col_electrodes
]

adj = np.zeros((len(channel_names), len(channel_names)))
# Optional extra edges that were experimented with (kept disabled):
# adj[0,6],adj[6,0]=1,1
# adj[1,7],adj[7,1]=1,1
edge_rows, edge_cols = zip(*indices)
adj[list(edge_rows), list(edge_cols)] = 1
adj = tf.constant(adj, dtype=tf.float32)

class GATLayer(layers.Layer):
    """Single-head graph-attention layer over a fixed adjacency.

    Node features are projected with ``W``. Pairwise attention logits are
    ``LeakyReLU(a^T [W h_j || W h_i])`` for centre node i and neighbour j. They are
    masked to the edges of ``adj`` and softmax-normalised over each row. The output
    is ``ELU(sum_j alpha_ij W h_j)``.

    Args:
        output_dim: size of the projected node features.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``),
            which also lets the default ``from_config`` rebuild the layer.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.Leakyrelu = layers.LeakyReLU(alpha=0.2)

    def build(self, input_shape):
        # W: (in_features, output_dim).
        # a: (2 * output_dim, 1). The first half scores the neighbour j, the second
        # half scores the centre node i (see call()).
        self.W = self.add_weight(name='W', shape=(input_shape[-1], self.output_dim), initializer='random_normal', trainable=True)
        self.a = self.add_weight(name='a', shape=(2 * self.output_dim, 1), initializer='random_normal', trainable=True)

    def call(self, input, adj):
        """Apply one round of masked attention.

        Args:
            input: node features, rank 3. Presumably (batch, num_nodes, in_features);
                the model in this notebook feeds (batch, 12, 24).
            adj: (num_nodes, num_nodes) adjacency. Only entries equal to 1.0 are
                treated as edges.

        Returns:
            Tensor of shape (batch, num_nodes, output_dim).
        """
        H = tf.matmul(input, self.W)
        # a^T [H_j || H_i] == (H a[:d])_j + (H a[d:])_i. Broadcasting the two
        # (batch, N, 1) score vectors gives the (batch, N, N) logits for any N,
        # without materialising the (batch, N, N, 2d) pairwise concatenation.
        score_neighbour = tf.matmul(H, self.a[:self.output_dim])
        score_centre = tf.matmul(H, self.a[self.output_dim:])
        e = self.Leakyrelu(score_centre + tf.transpose(score_neighbour, perm=[0, 2, 1]))
        # Non-edges get a huge negative logit so their softmax weight is ~0.
        masked_e = tf.where(adj == 1.0, e, -1e20 * tf.ones_like(e))
        alpha = tf.nn.softmax(masked_e, axis=-1)
        return tf.nn.elu(tf.matmul(alpha, H))

    def get_config(self):
        """Serialise constructor arguments so the layer can be re-created from config."""
        config = super().get_config()
        config.update({'output_dim': self.output_dim})
        return config

# class AttentionLayer(layers.Layer):
#     def __init__(self, output_dim):
#         super(AttentionLayer, self).__init__()
#         self.output_dim = output_dim
    
#     def build(self, input_shape):
#         self.WQ = self.add_weight(name='WQ',shape=(input_shape[-1], self.output_dim), initializer='random_normal',trainable=True)
#         self.WK = self.add_weight(name='WK',shape=(input_shape[-1], self.output_dim), initializer='random_normal',trainable=True)
#         self.WV = self.add_weight(name='WV',shape=(input_shape[-1], self.output_dim), initializer='random_normal',trainable=True)

#     def call(self, input,adj):
#         Q = tf.matmul(input, self.WQ)
#         K = tf.matmul(input, self.WK)
#         V = tf.matmul(input, self.WV)
#         e = tf.matmul(Q, K, transpose_b=True)
#         e = e / tf.math.sqrt(tf.cast(self.output_dim, tf.float32))
#         zero_mat= -1e20*tf.ones_like(e)
#         msked_e=tf.where(adj==1.0,e,zero_mat)
#         alpha = tf.nn.softmax(msked_e, axis=-1)
#         H = tf.matmul(alpha, V)
#         return H

def _temporal_conv_block(x, filters, parallel, normalize=True):
    """Sum of a 1x5 and a 1x7 ReLU convolution, followed by (1, 2) average pooling.

    Args:
        x: input feature map.
        filters: number of filters in each of the two convolutions.
        parallel: if True, both convolutions read ``x`` (multi-scale branches).
            If False, the 1x7 convolution reads the 1x5 output, so the block
            computes conv5(x) + conv7(conv5(x)).
        normalize: if True, follow the pooling with BatchNormalization and
            SpatialDropout2D(0.2).

    Returns:
        The block's output tensor (time axis halved by the pooling).
    """
    conv5_out = layers.Conv2D(filters, (1, 5), activation='relu', padding='same')(x)
    conv7_out = layers.Conv2D(filters, (1, 7), activation='relu', padding='same')(x if parallel else conv5_out)
    out = layers.add([conv5_out, conv7_out])
    out = layers.AveragePooling2D((1, 2))(out)
    if normalize:
        out = layers.BatchNormalization()(out)
        out = layers.SpatialDropout2D(0.2)(out)
    return out


def create_model():
    """Build and compile the CNN + GAT binary classifier.

    Input shape is (12, 384, 1). Four temporal conv blocks reduce the 384-long axis
    to 24 (four (1, 2) poolings). The result is reshaped to 12 nodes x 24 features
    and passed through three GATLayers that use the module-level ``adj`` (which
    must therefore be 12x12). Graph features are averaged and fed to a small dense
    head with a sigmoid output.

    Returns:
        A compiled ``keras.Model`` with binary focal loss, Adam(lr=0.002), and
        accuracy, AUROC, AUPRC, FP, TN, precision, recall and Cohen's kappa metrics.
    """
    inputs = keras.Input(shape=(12, 384, 1))
    regularizer_dense = regularizers.l2(0.0001)

    # Time axis: 384 -> 192 -> 96 -> 48 -> 24.
    x = _temporal_conv_block(inputs, 32, parallel=True)
    # The original second block had a typo (`y- ...` instead of `y= ...`) that broke
    # model construction. Fixed here so block 2 is wired like blocks 3 and 4.
    # NOTE(review): blocks 2-4 feed the 1x7 conv from the 1x5 output, unlike
    # block 1 — confirm this asymmetry is intended.
    x = _temporal_conv_block(x, 64, parallel=False)
    x = _temporal_conv_block(x, 8, parallel=False)
    x = _temporal_conv_block(x, 1, parallel=False, normalize=False)
    # (batch, 12, 24, 1) -> (batch, 12, 24): one 24-dim feature vector per node.
    x = layers.Reshape((12, 24))(x)

    x = GATLayer(37)(x, adj)
    x = GATLayer(32)(x, adj)
    x = GATLayer(16)(x, adj)

    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(32, activation='relu', kernel_regularizer=regularizer_dense)(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(16, activation='relu', kernel_regularizer=regularizer_dense)(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1, activation='sigmoid', kernel_regularizer=regularizer_dense)(x)

    model = keras.Model(inputs=inputs, outputs=x)

    optimizer = keras.optimizers.Adam(learning_rate=0.002)
    loss = keras.losses.BinaryFocalCrossentropy(from_logits=False, gamma=2, alpha=0.4, apply_class_balancing=True)
    kappa = tfa.metrics.CohenKappa(num_classes=2)
    fp = keras.metrics.FalsePositives()
    tn = keras.metrics.TrueNegatives()
    precision = keras.metrics.Precision(name='precision')
    recall = keras.metrics.Recall(name='recall')
    AUROC = keras.metrics.AUC(curve='ROC', name='AUROC')
    AUPRC = keras.metrics.AUC(curve='PR', name='AUPRC')
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy', AUROC, AUPRC, fp, tn, precision, recall, kappa])
    return model
model = create_model()

In [None]:
# Save the full model whenever validation AUROC improves.
checkpoint_path = "GAT_new/cp_{epoch:02d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Fraction of the training set held out for checkpoint selection. Keras takes the
# LAST samples of x_train/y_train, which were shuffled with a fixed seed earlier.
# Previously x_test doubled as validation data, so save_best_only picked the epoch
# with the best *test* AUROC and leaked test information into the reported score.
# The test split is now reserved for final evaluation.
VALIDATION_FRACTION = 0.1

cp_callback = keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=False,
    verbose=0,
    save_best_only=True,
    monitor='val_AUROC',
    mode='max',
)
history = model.fit(
    x_train,
    y_train,
    epochs=100,
    batch_size=512,
    verbose=1,
    validation_split=VALIDATION_FRACTION,
    callbacks=[cp_callback],
)

In [None]:
# Persist the per-epoch training metrics as JSON (one column per metric).
history_frame = pd.DataFrame(history.history)
with open("Attention.jason", 'w') as out_file:
    history_frame.to_json(out_file)

In [None]:
def load_history(path):
    """Read one training-history DataFrame previously written with ``DataFrame.to_json``.

    Args:
        path: path to the JSON file.

    Returns:
        The DataFrame parsed by ``pd.read_json``.
    """
    with open(path, 'r') as f:
        return pd.read_json(f)


# Ten histories named history_paper_cv_0 .. _9 (presumably cross-validation folds).
# history{k+1} holds file _cv_{k}, matching the previous hand-written cells.
cv_histories = [load_history(f"./History/history_paper_cv_{fold}.jason") for fold in range(10)]
(history1, history2, history3, history4, history5,
 history6, history7, history8, history9, history10) = cv_histories

In [None]:
metrics = ['accuracy', 'val_accuracy', 'loss', 'val_loss', 'val_AUROC', 'val_AUPRC', 'val_recall_9', 'val_cohen_kappa']

epochs = range(1, 51)  # NOTE(review): currently unused; the x-axis is the 0-based DataFrame index.
fig, ax = plt.subplots(4, 2, figsize=(20, 20))
# ax.flat walks the 4x2 grid row by row, so panel k shows metrics[k].
for panel, metric in zip(ax.flat, metrics):
    panel.plot(history5[metric], color='r', label='inter')
    panel.plot(history10[metric], color='b', label='old')
    panel.set_title(metric)
    panel.set_xlabel('Epochs')
    panel.legend()
    panel.grid()
fig.tight_layout()
plt.show()

In [None]:
# Load the two saved histories compared below.
# NOTE(review): this rebinds history1/history2, which an earlier cell bound to the
# _cv_0 and _cv_1 histories.
comparison_frames = []
for json_path in ("Attention.jason", "History/history_cv_0.jason"):
    with open(json_path, 'r') as f:
        comparison_frames.append(pd.read_json(f))
history1, history2 = comparison_frames

for frame, colour, caption in (
    (history1, 'r', ' Accuracy with dot product attention'),
    (history2, 'b', ' Accuracy with GAT'),
):
    plt.plot(frame['accuracy'], color=colour, label=caption)
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
# plt.ylim(0,1)  # optional: pin the y-axis to [0, 1]
plt.grid()
plt.show()