In [1]:
import tensorflow as tf
import numpy as np

from tensorflow.keras import datasets, layers, models
from tensorflow import keras
from scipy import io
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold

In [2]:
# Mount Google Drive at /content/drive; the .mat files loaded later in this
# notebook are read from paths under that mount point.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime — confirm before running this notebook elsewhere.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Import input features
data_pre = io.loadmat("/content/drive/MyDrive/JGAT_code/Social_dataset.mat")['key']
data_fc = data_pre.transpose([0, 2, 1])   # last two axes swapped (used for the FC computation)
data = data_pre[:, :, :, np.newaxis]      # trailing singleton axis appended

# Labels follow a repeating pattern of four 0s then four 1s. The number of
# repetitions is derived from the number of samples instead of being hard-coded,
# so labels and samples cannot silently fall out of sync.
label_pattern = np.array([0, 0, 0, 0, 1, 1, 1, 1])
if data.shape[0] % len(label_pattern) != 0:
    raise ValueError(
        f"number of samples ({data.shape[0]}) is not a multiple of the "
        f"label pattern length ({len(label_pattern)})"
    )
label = np.tile(label_pattern, data.shape[0] // len(label_pattern))
print(label[0:16])

print(f"shape of labels: {label.shape}")
print(f"shape of pre_inputs: {data.shape}")
print('-------------------------------------------------')

# Import adjacency matrix
file_adj = '/content/drive/MyDrive/JGAT_code/48-accumulated_SC.mat'
adj = io.loadmat(file_adj)['key']

# Compute a general A: binarise the loaded matrix in place. Entries below the
# threshold become 0, entries at or above it become 1.
adj_threshold = 40
adj[adj < adj_threshold] = 0
adj[adj >= adj_threshold] = 1  # original author's note: "edges: 5316, 1500"

class GraphInfo:
    """Plain container bundling a graph's edge list, adjacency matrix and node count."""

    def __init__(self, edges, adj, num_nodes: int):
        # Values are stored exactly as passed in; nothing is copied or validated.
        self.edges, self.adj, self.num_nodes = edges, adj, num_nodes

# Form joint kernel graph: the adjacency matrix is repeated jkernel_size times
# side by side, so A has adj.shape[1] * jkernel_size columns.
jkernel_size = 3
A = np.tile(adj, (1, jkernel_size))
edge = tf.where(A == 1)  # index pairs of the entries of A that equal 1

graph = GraphInfo(edges=edge, adj=adj, num_nodes=adj.shape[0])

print(f"number of nodes: {graph.num_nodes}, number of edges: {len(graph.edges)}")

[0 0 0 0 1 1 1 1 0 0 0 0 1 1 1 1]
shape of labels: (384,)
shape of pre_inputs: (384, 32, 200, 1)
-------------------------------------------------
number of nodes: 200, number of edges: 5316


In [4]:
# Compute Functional connectivity (FC)
# np.corrcoef treats every row of its argument as one variable, so passing the
# 2-D slice for a sample directly gives the same row-by-row correlation matrix
# as first collecting the rows into a list.
FCs = np.stack([np.corrcoef(sample) for sample in data_fc])

# Repeat the FC matrix along its last axis so that the column indices stored in
# `edge` (which index the tiled adjacency matrix) address valid FC entries.
FCs_joint = np.tile(FCs, [1, 1, jkernel_size])
print(FCs_joint.shape)

# Read the FC value at every edge for all samples in one indexing operation
# instead of one tf.gather_nd call per sample.
edge_index = np.asarray(edge)
FCs_dense = FCs_joint[:, edge_index[:, 0], edge_index[:, 1]]
print(FCs_dense.shape)

(384, 200, 600)
(384, 5316)


In [5]:
# Compute joint spatio-temporal inputs [None, seq_len, N, K]
data_new = data.squeeze(-1).transpose([0, 2, 1])  # (samples, N, seq_len)
num_samples, num_nodes_st, seq_len_st = data_new.shape

# Left-pad the time axis with K-1 zero frames so that a K-wide window exists
# for every original frame. Sizes come from the data and jkernel_size rather
# than the previous literal 200 nodes / two pad frames.
zeropad = np.zeros([num_samples, num_nodes_st, jkernel_size - 1])
data_new = np.concatenate([zeropad, data_new], axis=2)
# print(data_new.shape)

# Slot k of the last axis holds padded frame t + k for every frame t, so the
# final slot (k = K - 1) is the un-shifted signal and earlier slots look back
# in time.
data_st = np.stack(
    [data_new[:, :, k:k + seq_len_st] for k in range(jkernel_size)],
    axis=-1,
)  # (samples, N, seq_len, K)

# Same intermediate shape report as the loop-based construction: windows
# flattened along the last axis.
print(data_st.reshape([num_samples, num_nodes_st, seq_len_st * jkernel_size]).shape)
data_st = data_st.transpose([0, 2, 1, 3])  # (samples, seq_len, N, K)

print(f"shape of inputs: {data_st.shape}")

(384, 200, 96)
shape of inputs: (384, 32, 200, 3)


In [6]:
#JGAT layer
class GraphAttention(layers.Layer):
    """Graph-attention layer over a joint graph of N*K stacked nodes.

    For every edge, the linearly transformed features of the head node (taken
    from the last block of N rows of the input) and of the FC-weighted tail
    node are concatenated and scored. Scores are normalised per head node and
    used to aggregate tail features, and the aggregate is scaled by a
    per-frame score.

    Args:
        in_feat: size of the last axis of the input features.
        out_feat: size of the last axis after the linear transformation.
        N: number of nodes per block (number of segments used when
            aggregating over edges).
        K: joint kernel size, i.e. number of N-row blocks stacked along the
            first axis of the input.
        graph_info: object exposing `edges`; `edges[:, 0]` is used as the head
            index and `edges[:, 1]` as the tail index of each edge.
        kernel_initializer1: initializer for the feature kernel.
        kernel_initializer2: initializer for the attention kernel.
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        N,
        K,
        graph_info,
        kernel_initializer1=tf.keras.initializers.GlorotUniform(seed=0),
        kernel_initializer2=tf.keras.initializers.GlorotUniform(seed=1),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.N = N
        self.K = K
        self.graph_info = graph_info
        self.kernel_initializer1 = keras.initializers.get(kernel_initializer1)
        self.kernel_initializer2 = keras.initializers.get(kernel_initializer2)

    def build(self, input_shape):
        """Create the feature kernel and the attention kernel."""
        self.kernel = self.add_weight(
            shape=(self.in_feat, self.out_feat),
            trainable=True,
            initializer=self.kernel_initializer1,
            name="kernel",
        )
        # The attention kernel scores the concatenation [head, tail], hence
        # 2 * out_feat input rows.
        self.kernel_attention = self.add_weight(
            shape=(self.out_feat * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer2,
            name="kernel_attention",
        )
        self.built = True

    def compute_nodes_representation(self, features: tf.Tensor):
        """Apply the feature kernel to `features` (matmul over the last axis)."""
        return tf.matmul(features, self.kernel)

    def call(self, inputs, FC_inputs):
        """Run attention-based message passing.

        Args:
            inputs: node features, (N*K, None, seq_len, in_feat).
            FC_inputs: per-edge weights, (num_edges, None).

        Returns:
            Tuple `(nodes_representation, aggregated_message_j)`, each of shape
            (N, None, seq_len, out_feat).
        """
        edge_heads = self.graph_info.edges[:, 0]
        edge_tails = self.graph_info.edges[:, 1]
        # Rows of the last of the K stacked blocks. The layer's own K is used
        # here so the layer does not depend on a notebook-level variable `K`
        # existing (and matching) at call time.
        cur_start = self.N * (self.K - 1)
        cur_stop = self.N * self.K

        FC = FC_inputs[:, :, tf.newaxis, tf.newaxis]
        # Linear transformation of node features
        node_feat_transformed = tf.matmul(inputs, self.kernel) # (N*K, None, seq_len, out_feat)

        # Compute attention scores
        node_feat_expanded_head = tf.gather(
            node_feat_transformed[cur_start:cur_stop, :, :, :], edge_heads
        ) # base on current nodes
        node_feat_expanded_tail = tf.gather(node_feat_transformed, edge_tails) * FC # apply FC on neighbors

        node_feat_expanded = tf.concat([node_feat_expanded_head, node_feat_expanded_tail], axis=-1)

        # Raw per-edge score; computed once and reused for both the attention
        # scores and the frame scores.
        edge_logits = tf.matmul(node_feat_expanded, self.kernel_attention)

        attention_scores = tf.nn.leaky_relu(edge_logits)
        attention_scores = tf.math.exp(tf.squeeze(attention_scores, -1))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edge_heads,
            num_segments=self.N, # (N, None, seq_len)
        )

        # Look up each edge's normaliser by its head index. Unlike repeating the
        # sums by a bincount, this does not require edges to be sorted by head
        # or every node to have at least one edge.
        attention_scores_norm = attention_scores / tf.gather(attention_scores_sum, edge_heads)

        # Compute Frame scores(beta)
        frame_score = tf.squeeze(edge_logits, -1)
        frame_score = tf.math.unsorted_segment_mean(
            data=frame_score,
            segment_ids=edge_heads,
            num_segments=self.N, # [N, None, seq_len]
        )
        frame_score = tf.math.reduce_mean(frame_score, axis=0) #(None, seq_len)

        frame_score_mean = tf.math.reduce_mean(frame_score, axis=1) #(None)
        frame_score_std = tf.math.reduce_std(frame_score, axis=1) #(None)
        # A zero standard deviation (all frames scoring the same) would give
        # 0/0 = NaN; divide_no_nan returns 0 there, i.e. sigmoid(0) = 0.5.
        frame_score_norm = tf.math.sigmoid(
            tf.math.divide_no_nan(
                frame_score - frame_score_mean[:, tf.newaxis],
                frame_score_std[:, tf.newaxis],
            )
        ) #(None, seq_len)

        # Gather features of neighbors, apply attention scores and frame scores
        aggregated_message_j = tf.math.unsorted_segment_sum(
            data=node_feat_expanded_tail * attention_scores_norm[:, :, :, tf.newaxis],
            segment_ids=edge_heads,
            num_segments=self.N
        )
        aggregated_message_j = aggregated_message_j * frame_score_norm[tf.newaxis, :, :, tf.newaxis] #(N, None, seq_len, out_feat)

        # Compute node representation
        nodes_representation = self.compute_nodes_representation(
            inputs[cur_start:cur_stop, :, :, :]
        ) #(N, None, seq_len, out_feat)

        return nodes_representation, aggregated_message_j

In [7]:
# JGAT model

class JGAT(layers.Layer):
  """Classifier built from one GraphAttention layer followed by dense layers.

  Args:
      in_feat: input feature size passed to the attention layer.
      out_feat: output feature size of the attention layer.
      seq_len: sequence length; accepted for interface compatibility but not
          read by this class.
      K: joint kernel size (size of the last axis of the input).
      graph_info: graph description; `num_nodes` sizes the attention layer and
          the hidden dense layer.
      num_classes: number of softmax outputs.
  """

  def __init__(
      self,
      in_feat,
      out_feat,
      seq_len: int,
      K: int,
      graph_info: GraphInfo,
      num_classes: int,
  ):
    super().__init__()
    # Remembered so call() can split the input into K channel slices instead of
    # assuming exactly three.
    self.K = K
    self.JGAT_layers = GraphAttention(in_feat, out_feat, graph_info.num_nodes, K, graph_info)
    self.dense_map = layers.Dense(K)
    self.flatten = layers.Flatten()
    self.dense1 = layers.Dense(graph_info.num_nodes, activation="relu")
    self.dense2 = layers.Dense(num_classes, activation="softmax")
    self.dropout_layer = layers.Dropout(rate=0.2)

  def get_config(self):
    """Return the base layer config (constructor arguments are not added)."""
    config = super().get_config()
    return config

  def call(self, inputs, FC_inputs):
    """Map joint inputs and per-edge FC values to class probabilities.

    Args:
        inputs: tensor transposed below to (N, None, seq_len, K).
        FC_inputs: tensor transposed below to (num_edges, None).

    Returns:
        Output of the final softmax dense layer.
    """
    inputs = tf.transpose(inputs, [2, 0, 1, 3]) # (N, None, seq_len, K)
    FC_inputs = tf.transpose(FC_inputs, [1, 0]) # (num_edges, None)
    # Stack the K channel slices along the node axis: (N*K, None, seq_len, 1)
    inputs_transform = tf.concat(
        [inputs[:, :, :, k:k + 1] for k in range(self.K)], axis=0
    )
    outputs_attlayer = self.JGAT_layers(inputs_transform, FC_inputs)

    total_embedding = tf.concat([outputs_attlayer[0],  outputs_attlayer[1]], axis=-1) # (N, None, seq_len, out_feat*2)
    total_embedding = tf.transpose(total_embedding, [1, 0, 2, 3]) # (None, N, seq_len, out_feat*2)

    outputs_condense = self.dense_map(total_embedding) # (None, N, seq_len, K)

    out = self.flatten(outputs_condense)
    out = self.dropout_layer(out)
    out = self.dense1(out)
    out = self.dropout_layer(out)
    out = self.dense2(out)

    return out

In [11]:
# Per-fold metric accumulators (held-out fold and training split); four
# independent empty lists.
acc_per_fold, loss_per_fold = [], []
acc_per_fold_train, loss_per_fold_train = [], []

In [12]:
in_feat = 1
out_feat = 10
input_sequence_length = data_st.shape[1]
batch_size = 48
epochs = 10
K = jkernel_size
num_classes = 2

fold_var = 1
n_split = 8
seed = 0  # fixed seed so weight init, dropout and batch shuffling are repeatable

def get_model_name(k):
    """Return the checkpoint file name for fold number `k`."""
    return 'model_'+str(k)+'.h5'


# Start from empty accumulators so that re-running this cell does not append a
# second set of folds to results left over from a previous run.
acc_per_fold = []
loss_per_fold = []
acc_per_fold_train = []
loss_per_fold_train = []

# KFold.split only needs the sample array to produce index sets; its optional
# y / groups arguments are ignored by KFold, so they are not passed.
for train_index,test_index in KFold(n_split, shuffle=False).split(data_st):
  tf.keras.backend.clear_session()
  tf.keras.utils.set_random_seed(seed)

  x_train,x_test=data_st[train_index],data_st[test_index]
  x_train_FC,x_test_FC=FCs_dense[train_index],FCs_dense[test_index]
  y_train,y_test=label[train_index],label[test_index]

  jgat = JGAT(
      in_feat,
      out_feat,
      input_sequence_length,
      K,
      graph,
      num_classes,
  )
  inputs = layers.Input((input_sequence_length, graph.num_nodes, K))
  # Shape is given as a tuple (one value per edge) rather than a bare int.
  FC_inputs = layers.Input((FCs_dense.shape[1],))
  outputs = jgat(inputs, FC_inputs)

  model = keras.models.Model([inputs, FC_inputs], outputs)

  model.compile(optimizer = keras.optimizers.Adam(learning_rate=0.001),
          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
          metrics=['acc'])

  # Build the checkpoint path once so saving and re-loading cannot drift apart.
  checkpoint_filepath = '/content/'
  checkpoint_path = checkpoint_filepath+get_model_name(fold_var)
  # NOTE(review): the best epoch is selected on the same held-out fold that is
  # scored below, so the reported fold accuracy is optimistically biased.
  # Consider carving a validation split out of the training indices.
  model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    monitor='val_acc',
    verbose=1,
    mode='max',
    save_best_only=True)

  callbacks_list=[model_checkpoint_callback]
  history = model.fit([x_train, x_train_FC], y_train, batch_size=batch_size, epochs=epochs, callbacks=callbacks_list, validation_data=([x_test, x_test_FC], y_test))

  # Restore the best-scoring epoch before evaluating.
  model.load_weights(checkpoint_path)

  scores_train = model.evaluate([x_train, x_train_FC], y_train, batch_size=batch_size, verbose=0)
  acc_per_fold_train.append(scores_train[1] * 100)
  loss_per_fold_train.append(scores_train[0])

  scores = model.evaluate([x_test, x_test_FC], y_test, batch_size=batch_size, verbose=0)
  print(f'Score for fold {fold_var}: {model.metrics_names[0]} of {scores[0]:.6f}; {model.metrics_names[1]} of {scores[1]*100:.4f}%')
  acc_per_fold.append(scores[1] * 100)
  loss_per_fold.append(scores[0])

  fold_var += 1

# acc_per_fold has one entry per fold, so this is a fold count, not a sample count.
print(f"Total folds evaluated: {len(acc_per_fold)}")

Epoch 1/10
Epoch 1: val_acc improved from -inf to 0.79167, saving model to /content/model_1.h5
Epoch 2/10
Epoch 2: val_acc improved from 0.79167 to 0.91667, saving model to /content/model_1.h5
Epoch 3/10
Epoch 3: val_acc improved from 0.91667 to 0.95833, saving model to /content/model_1.h5
Epoch 4/10
Epoch 4: val_acc did not improve from 0.95833
Epoch 5/10
Epoch 5: val_acc improved from 0.95833 to 0.97917, saving model to /content/model_1.h5
Epoch 6/10
Epoch 6: val_acc did not improve from 0.97917
Epoch 7/10
Epoch 7: val_acc did not improve from 0.97917
Epoch 8/10
Epoch 8: val_acc did not improve from 0.97917
Epoch 9/10
Epoch 9: val_acc did not improve from 0.97917
Epoch 10/10
Epoch 10: val_acc did not improve from 0.97917
Score for fold 1: loss of 0.078689; acc of 97.9167%
Epoch 1/10
Epoch 1: val_acc improved from -inf to 0.77083, saving model to /content/model_2.h5
Epoch 2/10
Epoch 2: val_acc improved from 0.77083 to 0.87500, saving model to /content/model_2.h5
Epoch 3/10
Epoch 3: va

In [None]:
# Report per-fold metrics followed by their averages across folds.
rule = '--------------------------------------------------------------------------------------------'

print(rule)
print('Score per fold')
for fold_idx in range(len(acc_per_fold)):
  print(rule)
  print(
      f'> Fold {fold_idx+1} - Train_Loss: {loss_per_fold_train[fold_idx]:.4f} - Train_Accuracy: {acc_per_fold_train[fold_idx]:.4f}% -'
      f'    Loss: {loss_per_fold[fold_idx]:.4f} - Accuracy: {acc_per_fold[fold_idx]:.4f}%'
  )
print(rule)
print('Average scores for all folds:')

# Standard error of the mean: sample std (ddof=1) over sqrt(number of folds).
sqrt_num_folds = np.sqrt(len(acc_per_fold))
train_acc_mean = np.mean(acc_per_fold_train)
train_acc_sem = np.std(acc_per_fold_train, ddof=1) / sqrt_num_folds
print(f'> Train_Accuracy: {train_acc_mean:.4f} (+- {train_acc_sem:.4f})')
print(f'> Train_Loss: {np.mean(loss_per_fold_train):.4f}')
print()
test_acc_mean = np.mean(acc_per_fold)
test_acc_sem = np.std(acc_per_fold, ddof=1) / sqrt_num_folds
print(f'> Accuracy: {test_acc_mean:.4f} (+- {test_acc_sem:.4f})')
print(f'> Loss: {np.mean(loss_per_fold):.4f}')
print(rule)