In [None]:
import time
import tensorflow as tf
import numpy as np
import sys
import os
import csv
from sklearn.metrics import confusion_matrix, f1_score, balanced_accuracy_score, roc_curve, auc
import scipy.io as sio
import scipy.sparse as sp
import pickle as pkl
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold
from tensorflow.keras import layers, initializers, optimizers, callbacks
from tensorflow.keras import Model
from tensorflow.keras.utils import to_categorical
import ABCD_Parser as Reader


In [9]:
seed = 123
np.random.seed(seed)
tf.random.set_seed(seed)

In [29]:
flags = {
    'node_num': 110,          # Number of Graph nodes
    'output_dim': 1,           # Number of output dimensions
    'learning_rate': 0.0001,   # Initial learning rate for model
    'learning_rate_mask': 0.01,# Learning rate for mask optimization
    'batch_num': 10,           # Batch size (original 'batch_num' = batch size)
    'epochs': 1000,            # Epochs for model training
    'epochs_mask': 400,        # Epochs for mask optimization
    'attn_heads': 5,           # Number of attention heads
    'hidden1_gat': 24,         # GAT hidden layer 1 units
    'output_gat': 3,           # GAT output layer units
    'dropout': 0.0,            # Dropout rate (1 - keep prob)
    'in_drop': 0.0,            # Input dropout rate
    'weight_decay': 5e-4,      # L2 weight decay
    'early_stopping': 15,      # Early stopping tolerance
    'fold': 4                  # Target fold to train
}

In [10]:
def accuracy(preds, labels):
    """Accuracy metric (matches TF1 logic)."""
    correct_prediction = tf.equal(tf.round(preds), labels)
    return tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Subclassing using Keras 

In [30]:
import tensorflow as tf
from tensorflow import keras
class gat_layer(tf.keras.layers.Layer):
  def __init__(self,input_dim, F_, attn_heads=1, attn_heads_reduction="concat",
               activation = tf.nn.relu, use_bias = True, dropout_rate = 0.0, in_drop = 0.0, name=''):

    super().__init__(name = f'gat_layer{name}')

    self.input_dim = input_dim
    self.F_ = F_ #Output feature per node
    self.attn_heads = attn_heads
    self.attn_heads_reduction = attn_heads_reduction
    self.activation = activation
    self.use_bias = use_bias
    self.dropout_rate = dropout_rate
    self.in_drop = in_drop

  def build(self, input_shape): #input_dim is the features
    glorot_init = initializers.GlorotUniform(seed=seed)
    zero_init = initializers.Zeros()

    self.weights_list = [] # input_dim x F_
    self.attn_self_weights = [] # F_ x 1
    self.attn_neigh_weight = [] # F_ x 1

    for head in range(self.attn_heads):
      #Linear transformation

      w = self.add_weight(
        shape =(self.input_dim, self.F_),
        initializer= glorot_init,
        dtype = tf.float32,
        name = f'weights_{head}',
        trainable = True
      )

      #Weight for self and neighbors

      attn_self = self.add_weight(
        shape = (self.F_, 1),
        initializer = glorot_init,
        dtype = tf.float32,
        name = f'attn_self_{head}',
        trainable = True
      )
      attn_neigh = self.add_weight(
        shape = (self.F_, 1),
        initializer = glorot_init,
        dtype = tf.float32,
        name = f'attn_neigh_{head}',
        trainable = True
      )

      self.weights_list.append(w)
      self.attn_self_weights.append(attn_self)
      self.attn_neigh_weight.append(attn_neigh)

    if self.use_bias:
      self.bias = self.add_weight(
        shape = (self.F_,),
        initializer = zero_init,
        name = 'bias',
        trainable = True
      )
    else:
      self.bias = None

    super().build(input_shape) # mark layer as built

  def call(self, inputs, trainning = None):
    """Foward pass"""
    X, A = inputs # X: (batch size, node_num, input_dim(feature)), A: (batch_size, node_num, node_num)

    if self.in_drop > 0.0 and trainning:
      X = tf.nn.dropout(X, rate = self.in_drop)

    outputs = [] # store output of each attention head
    dense_masks = [] # store attention score masks

    for head in range(self.attn_heads):
      # Linear transformation
      kernel = self.weights_list[head]
      features = tf.matmul(X, kernel) # (batch_size, node_num, F_) F_ is a hyperpara -> choose later

      # Compute attention scores
      attn_self_kernel = self.attn_self_weights[head]
      attn_neigh_kernel = self.attn_neigh_weight[head]
      attn_for_self = tf.matmul(features, attn_self_kernel) # (batch_size, node_num, 1)
      attn_for_neigh = tf.matmul(features, attn_neigh_kernel) # (batch_size, node_num, 1)

      # Compute attention head (a(Wh_i, Wh_j))
      dense = attn_for_self + tf.transpose(attn_for_neigh, [0,2,1]) # batch_size, node_num, node_num
      #print("plus:", dense.shape)
      dense = tf.nn.leaky_relu(dense, alpha = 0.2) # non-linearity

      # Mask non-edges using adj matrix
      zero_vec = -9e15 * tf.ones_like(dense)
      dense = tf.where(A > 0.0, dense, zero_vec)
      dense_masks.append(dense)

      # Softmax to get attn coef
      dense = tf.nn.softmax(dense)

      # Drop out on attention coef and features
      if trainning and self.dropout_rate > 0.0:
        dropout_attn = tf.nn.dropout(dense, rate = self.dropout_rate)
        dropout_feat = tf.nn.dropout(features, rate = self.dropout_rate)
      else:
        dropout_attn = dense
        dropout_feat = features

      # Aggregate neighbor features
      node_features = tf.matmul(dropout_attn, dropout_feat) # batch_size, node_num, F_

      # add bias if need

      if self.use_bias:
        node_features += self.bias

      # storing head output
      if self.attn_heads_reduction == 'concat':
        outputs.append(self.activation(node_features))
      else:
        outputs.append(node_features)
    # Aggregate attention heads
    if self.attn_heads_reduction == 'concat':
      output = tf.concat(outputs, axis = -1)
    else:
      output = tf.add_n(outputs) / self.attn_heads # avg heads
      output = self.activation(output)

    return output, dense_masks

Testing

In [25]:

# 1. Create dummy data
batch_size = 2
num_nodes = 4
input_dim = 5  # Features per node

# Random node features
X = tf.random.normal(shape=(batch_size, num_nodes, input_dim))

# Random adjacency matrix (binary)
A = tf.cast(tf.random.uniform(shape=(batch_size, num_nodes, num_nodes)) > 0.5, tf.float32)

print("Input features shape:", X.shape)  # (2, 4, 5)
print("Adjacency matrix shape:", A.shape)  # (2, 4, 4)

# 2. Create a GAT layer
gat = gat_layer(
    input_dim=input_dim,
    F_=8,                # Output features per node per head
    attn_heads=2,        # Number of attention heads
    attn_heads_reduction='concat',  # or 'average'
    activation=tf.nn.leaky_relu,
    use_bias=True,
    dropout_rate=0.0,
    in_drop=0.0,
    name='test'
)

# 3. Forward pass
output, attn_scores = gat([X, A], training=False)

# 4. Check output shapes
print("\nGAT output shape:", output.shape)
# With concat: (batch_size, num_nodes, F_ * attn_heads) = (2, 4, 16)
# With average: (batch_size, num_nodes, F_) = (2, 4, 8)

print("Attention scores length (number of heads):", len(attn_scores))
print("Attention scores shape for first head:", attn_scores[0].shape)  # (2, 4, 4)


# Take first sample in batch, first attention head
attn_head0 = attn_scores[0][0].numpy()
A_sample = A[0].numpy()

print("\nAdjacency matrix (first sample):")
print(A_sample)

print("\nAttention scores (first sample, first head):")
print(attn_head0)

# Check that non-edges are masked to -9e15
print("\nNon-edges are masked:", np.all(attn_head0[A_sample == 0] == -9e15))

Input features shape: (2, 4, 5)
Adjacency matrix shape: (2, 4, 4)

GAT output shape: (2, 4, 16)
Attention scores length (number of heads): 2
Attention scores shape for first head: (2, 4, 4)

Adjacency matrix (first sample):
[[0. 1. 0. 0.]
 [1. 0. 1. 1.]
 [0. 0. 1. 0.]
 [1. 0. 1. 0.]]

Attention scores (first sample, first head):
[[-8.9999998e+15 -3.7271079e-01 -8.9999998e+15 -8.9999998e+15]
 [-3.7271079e-01 -8.9999998e+15  1.9667579e+00  2.0883536e+00]
 [-8.9999998e+15 -8.9999998e+15  4.1073794e+00 -8.9999998e+15]
 [ 3.9866316e-01 -8.9999998e+15  4.2289753e+00 -8.9999998e+15]]

Non-edges are masked: True


In [None]:
inputs = [tf.keras.Input(shape=(num_nodes, input_dim)),
          tf.keras.Input(shape=(num_nodes, num_nodes))]
x, _ = gat(inputs)
model = tf.keras.Model(inputs=inputs, outputs=x)
model.summary()

In [27]:
class fc_layer(tf.keras.layers.Layer): # Fully dense layer in paper
  def __init__(self, input_dim, output_dim, dropout=0.0, act = tf.nn.relu, bias = False, name = ''):
    super().__init__(name = f'fc_layer{name}')

    # Hyperparameters
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.dropout = dropout
    self.act = act
    self.bias = bias

  def build(self, input_shape):

    glorot_init = initializers.GlorotUniform(seed = seed)
    zero_init = initializers.Zeros()

    #Linear weigh
    self.weights = self.add_weight(
      shape = (self.input_dim, self.output_dim),
      initializer = glorot_init,
      dtype = tf.float32,
      name = 'weights',
      trainable = True
    )


    if self.bias:
      self.bias_term = self.add_weight(
        shape = (self.output_dim,),
        initializer = zero_init,
        dtype = tf.float32,
        name = 'bias',
        trainable = True

      )
    else:
      self.bias_term = None

    super().build(input_shape)

  def call(self, inputs, trainning = None): #foward pass

    x = inputs

    if self.dropout > 0.0 and trainning:
      x = tf.nn.dropout(x, rate = self.dropout)

    # Linear transformation
    output = tf.matmul(x, self.weights)

    if self.bias:
      output += self.bias_term

    return self.act(output)

In [None]:
class GAT_Model(Model):
  def __init__(self, input_dim, flags):
    super().__init__(name = 'gat_mil')

    self.flags = flags
    self.input_dim = input_dim #input feature dimension (from data)

    self._build_layers() #this get called after defining because python read all class function before calling init

    #Init mask M using add_weight
    self.M = self.add_weight(
      shape=(self.config)
    )








