In [89]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
import numpy as np  

In [109]:
# Bipolar montage: each channel name is "<electrode>-<electrode>".
channel_names = ["Fp1-T3", "T3-O1", "Fp1-C3", "C3-O1", "Fp2-C4", "C4-O2",
                 "Fp2-T4", "T4-O2", "T3-C3", "C3-Cz", "Cz-C4", "C4-T4"]
n_channels = len(channel_names)
electrodes = [set(name.split("-")) for name in channel_names]

# Channels r and i are adjacent when they share at least one electrode.
# Every channel shares both electrodes with itself, so self-loops are included.
indices = [[r, i]
           for r, e_r in enumerate(electrodes)
           for i, e_i in enumerate(electrodes)
           if e_r & e_i]

adj = np.zeros((n_channels, n_channels))
for r, i in indices:
    adj[r, i] = 1
# Sharing an electrode is symmetric, and the attention softmax relies on
# every row having at least one neighbour (its self-loop).
assert (adj == adj.T).all() and adj.diagonal().all()
adj = tf.constant(adj, dtype=tf.float32)

In [107]:
# Scratch check of the masking idiom: keep random scores where two channels
# are adjacent and put zeros everywhere else.
a = tf.random.normal((12, 12), dtype=tf.float32)
is_edge = tf.equal(adj, 1)
c = tf.where(is_edge, a, tf.zeros_like(a))

In [113]:
class MatrixTransformationLayer(layers.Layer):
    """Trainable linear projection of the last (feature) axis, without bias.

    An input of shape ``(..., in_features)`` is multiplied by a weight matrix
    ``W`` of shape ``(in_features, output_dim)``, giving ``(..., output_dim)``.
    In this notebook it projects node features before each attention layer.

    Args:
        output_dim: size of the projected feature axis.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``);
            also required so Keras can rebuild the layer from ``get_config``.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # The input feature size is only known once the layer sees its input.
        self.W = self.add_weight(
            name='W',
            shape=(input_shape[-1], self.output_dim),
            initializer='random_normal',
            trainable=True,
        )

    def call(self, inputs):
        return tf.matmul(inputs, self.W)

    def get_config(self):
        # Needed for model.save()/load_model() round-trips.
        config = super().get_config()
        config.update({'output_dim': self.output_dim})
        return config

class AttentionMechanismLayer(layers.Layer):
    """Single-head graph attention over a fixed adjacency matrix.

    For node features ``h`` of shape ``(batch, N, F)`` every ordered pair is
    scored as ``e[i, j] = LeakyReLU(a^T [h_j || h_i])``, where ``a`` is a
    trainable vector of length ``2F``. Pairs with ``adj[i, j] != 1`` are
    masked out, each row is softmax-normalised, and the result
    ``softmax(e) @ h`` has the same shape as the input.

    Args:
        adj: ``(N, N)`` adjacency (tf tensor, numpy array or nested list).
            Node ``i`` attends to node ``j`` only where ``adj[i, j] == 1``.
            Each row should contain at least one 1 (e.g. a self-loop); a row
            with none degrades to a uniform average over all nodes.
        **kwargs: forwarded to ``keras.layers.Layer`` (``name``, ``dtype``, ...).
    """

    def __init__(self, adj, **kwargs):
        super().__init__(**kwargs)
        self.adj = adj
        # Boolean (N, N) edge mask, built once. The cast accepts any tensor-like adj.
        self._edge_mask = tf.equal(tf.cast(adj, tf.float32), 1.0)
        self.LeakyReLU = layers.LeakyReLU(alpha=0.2)

    def build(self, input_shape):
        self.shape0 = input_shape[-2]  # number of nodes N seen at build time
        self.shape1 = input_shape[-1]  # features per node F seen at build time
        self.a = self.add_weight(
            name='a',
            shape=(2 * input_shape[-1], 1),
            initializer='random_normal',
            trainable=True,
        )

    def call(self, inputs):
        # a^T [h_j || h_i] == a_neighbour^T h_j + a_self^T h_i. Score each node
        # once and broadcast, instead of tiling into a (B, N, N, 2F) tensor.
        a_neighbour, a_self = tf.split(self.a, 2, axis=0)              # each (F, 1)
        score_self = tf.matmul(inputs, a_self)                         # (B, N, 1), varies with i
        score_neighbour = tf.matmul(inputs, a_neighbour)               # (B, N, 1), varies with j
        e = self.LeakyReLU(score_self + tf.linalg.matrix_transpose(score_neighbour))  # (B, N, N)

        # Non-edges need a large *negative* score. A zero score would still
        # receive softmax weight exp(0) and leak non-neighbours into the sum.
        masked_e = tf.where(self._edge_mask, e, -1e9 * tf.ones_like(e))
        alpha = tf.nn.softmax(masked_e, axis=-1)
        return tf.matmul(alpha, inputs)

    def get_config(self):
        # adj is stored as a plain nested list so the config is serialisable.
        config = super().get_config()
        config.update({'adj': np.asarray(self.adj).tolist()})
        return config
        

# Input: N_CHANNELS bipolar channels x N_SAMPLES time steps, one feature map.
N_CHANNELS, N_SAMPLES = 12, 1024
# (filters of 1st conv, filters of 2nd conv) per temporal stage. Each stage
# halves the time axis, and the final single-filter stage leaves one feature map.
CONV_STAGES = [(4, 8), (16, 32), (8, 8), (1, 1)]
# Output sizes of the stacked projection + graph-attention pairs.
GAT_DIMS = (32, 16, 8)

assert tuple(adj.shape) == (N_CHANNELS, N_CHANNELS), "adjacency must match channel count"


def _temporal_conv_stage(x, filters_a, filters_b, kernel=(1, 5)):
    """Two same-padded ReLU convs along time, a (1, 2) max-pool, then batch norm."""
    x = layers.Conv2D(filters_a, kernel, activation='relu', padding='same')(x)
    x = layers.Conv2D(filters_b, kernel, activation='relu', padding='same')(x)
    x = layers.MaxPool2D((1, 2))(x)
    return layers.BatchNormalization()(x)


Input = keras.Input(shape=(N_CHANNELS, N_SAMPLES, 1))
x = Input
for filters_a, filters_b in CONV_STAGES:
    x = _temporal_conv_stage(x, filters_a, filters_b)

# One node per channel. The time axis (N_SAMPLES / 2**len(CONV_STAGES)) becomes
# the node feature vector, and Keras infers its size from the -1.
x = layers.Reshape((N_CHANNELS, -1))(x)
for out_dim in GAT_DIMS:
    x = MatrixTransformationLayer(out_dim)(x)
    x = AttentionMechanismLayer(adj)(x)
x = layers.GlobalAveragePooling1D()(x)

model = keras.Model(inputs=Input, outputs=x)

In [114]:
# Print the layer-by-layer output shapes and parameter counts of the model.
model.summary()

Model: "model_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_21 (InputLayer)       [(None, 12, 1024, 1)]     0         
                                                                 
 conv2d_160 (Conv2D)         (None, 12, 1024, 4)       24        
                                                                 
 conv2d_161 (Conv2D)         (None, 12, 1024, 8)       168       
                                                                 
 max_pooling2d_80 (MaxPooli  (None, 12, 512, 8)        0         
 ng2D)                                                           
                                                                 
 batch_normalization_80 (Ba  (None, 12, 512, 8)        32        
 tchNormalization)                                               
                                                                 
 conv2d_162 (Conv2D)         (None, 12, 512, 16)       656 