YouTube link (source talk): https://www.youtube.com/watch?v=8owQBFAHw7E&ab_channel=TensorFlow

In [None]:
!pip install numpy
!pip install tensorflow
!pip install spektral

import numpy as np
import tensorflow as tf
import spektral

In [None]:
# Load the Cora citation dataset through Spektral, storing arrays as float32.
dataset = spektral.datasets.citation.Citation(
    name='cora',
    dtype='float32',
)

In [None]:
# Unpack the single Cora graph: adjacency ('a'), node features ('x'),
# edge features ('e'; possibly None for Cora -- TODO confirm) and labels ('y').
adj, node_features, edge_features, labels = (
    dataset[0][key] for key in ('a', 'x', 'e', 'y')
)

# Node-selection masks for the train / test / validation splits.
train_mask, test_mask, val_mask = dataset.mask_tr, dataset.mask_te, dataset.mask_va

In [None]:
# Add self-loops (A + I) so every node also aggregates its own features.
# If `adj` is a scipy sparse matrix (TODO confirm for the installed spektral),
# `sparse + dense` returns an np.matrix; np.asarray turns it into a plain
# float32 ndarray so later element-wise code (where np.matrix `*` would mean
# matmul) behaves as expected.
adj = np.asarray(adj + np.eye(adj.shape[0]), dtype='float32')

# Cast the features in place of the original misspelled `node_featuers`,
# whose result was never used.
node_features = node_features.astype('float32')

print(adj.shape)
print(node_features.shape)
print(labels.shape)

# Number of nodes in each split.
print(np.sum(train_mask))
print(np.sum(val_mask))
print(np.sum(test_mask))

In [None]:
def masked_softmax_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over the nodes selected by ``mask``.

    Args:
        logits: Unnormalised class scores, one row per node.
        labels: One-hot targets with the same shape as ``logits``.
        mask: Boolean or 0/1 vector marking the nodes that contribute.

    Returns:
        Scalar ``sum(loss * mask) / sum(mask)``. This equals the original
        ``reduce_mean(loss * mask / reduce_mean(mask))``, but it returns 0.0
        instead of NaN when the mask selects no nodes.
    """
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    mask = tf.cast(mask, dtype=tf.float32)
    # divide_no_nan guards against an all-False mask (0 / 0).
    return tf.math.divide_no_nan(tf.reduce_sum(loss * mask), tf.reduce_sum(mask))

def masked_accuracy(logits, labels, mask):
    """Fraction of masked nodes whose arg-max prediction matches the label.

    Args:
        logits: Class scores, one row per node.
        labels: One-hot targets; the arg-max along axis 1 is the true class.
        mask: Boolean or 0/1 vector marking the nodes to evaluate.

    Returns:
        Scalar accuracy over the selected nodes. This equals the original
        ``reduce_mean(correct * mask / reduce_mean(mask))``, but it returns 0.0
        instead of NaN when the mask selects no nodes.
    """
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy_all = tf.cast(correct_prediction, tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    # divide_no_nan guards against an all-False mask (0 / 0).
    return tf.math.divide_no_nan(
        tf.reduce_sum(accuracy_all * mask), tf.reduce_sum(mask)
    )

In [None]:
def gnn(fts, adj, transform, activation):
    """Apply one message-passing layer: ``activation(adj @ transform(fts))``.

    ``transform`` (e.g. a Dense layer) maps each node's features independently.
    ``adj`` then mixes the transformed rows across each node's neighbourhood.
    """
    return activation(tf.matmul(adj, transform(fts)))

In [None]:
def train_cora(fts, adj, gnn_fn, units, epochs, lr, num_classes=None):
    """Train a two-layer GNN node classifier and log validation improvements.

    Reads the module-level ``labels``, ``train_mask``, ``val_mask`` and
    ``test_mask`` defined in the data-loading cells. The loss is computed on
    the training nodes only. Accuracy is reported on validation and test nodes.

    Args:
        fts: Node feature matrix fed to the first layer.
        adj: Square propagation matrix (raw, identity or normalised adjacency)
            that ``gnn_fn`` multiplies into each layer's transformed features.
        gnn_fn: Layer function called as ``gnn_fn(fts, adj, transform, activation)``.
        units: Width of the hidden layer.
        epochs: Number of full-batch optimisation steps to run.
        lr: Adam learning rate.
        num_classes: Width of the output layer. Defaults to
            ``labels.shape[-1]``, the number of one-hot label columns
            (7 for Cora).
    """
    if num_classes is None:
        num_classes = labels.shape[-1]

    lyr_1 = tf.keras.layers.Dense(units)
    lyr_2 = tf.keras.layers.Dense(num_classes)

    def cora_gnn(fts, adj):
        hidden = gnn_fn(fts, adj, lyr_1, tf.nn.relu)
        # tf.identity keeps raw logits; the softmax is applied inside the loss.
        return gnn_fn(hidden, adj, lyr_2, tf.identity)

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    best_accuracy = 0.0
    # Exactly `epochs` steps; the previous range(epochs + 1) ran one extra step.
    for ep in range(1, epochs + 1):
        with tf.GradientTape() as t:
            logits = cora_gnn(fts, adj)
            loss = masked_softmax_cross_entropy(logits, labels, train_mask)

        # Update every variable the tape saw during the forward pass
        # (the weights of both Dense layers).
        variables = t.watched_variables()
        grads = t.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

        # Re-evaluate with the updated weights. Note that `loss` is the value
        # from before this step's update.
        logits = cora_gnn(fts, adj)
        val_accuracy = masked_accuracy(logits, labels, val_mask)
        test_accuracy = masked_accuracy(logits, labels, test_mask)

        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            # Only logged; the best weights are not checkpointed.
            print('Epoch', ep, '| Training Loss:', loss.numpy(), '| Val Accuracy:', val_accuracy.numpy(), '| Test Accuracy:', test_accuracy.numpy())

In [None]:
# Baseline: the raw 0/1 adjacency (plus self-loops) sums neighbour features
# (sum pooling), so activations grow with node degree. It converges quickly
# but is not expected to reach the best attainable accuracy.
# Standard settings: 32 hidden units, 200 epochs, learning rate 0.01.
train_cora(node_features, adj, gnn, units=32, epochs=200, lr=0.01)

In [None]:
# Sanity check: with the identity as adjacency, each node sees only its own
# features. The model reduces to a node-wise MLP that uses no graph structure.
train_cora(node_features, tf.eye(adj.shape[0]), gnn, units=32, epochs=200, lr=0.01)

In [None]:
# Mean pooling: divide each row of the adjacency by that node's degree (D^-1 A),
# so every node averages rather than sums its neighbours. This keeps activation
# scale independent of degree and helps against exploding/vanishing gradients.
deg = tf.reduce_sum(adj, axis=-1)  # shape (N,): row sums = node degrees
# deg must be expanded to (N, 1) to scale rows. A bare `adj / deg` broadcasts
# along the last axis and divides the *columns* instead (A D^-1).
train_cora(node_features, adj / tf.expand_dims(deg, axis=-1), gnn, 32, 200, 0.01)

In [None]:
# Symmetric GCN normalisation D^-1/2 A D^-1/2, as proposed by Kipf & Welling.
# Scaling row i and column j by 1/sqrt(deg) via broadcasting gives the same
# matrix as diag(d) @ adj @ diag(d), without two dense N x N matrix products.
inv_sqrt_deg = 1.0 / tf.sqrt(deg)
norm_adj = (
    tf.expand_dims(inv_sqrt_deg, axis=-1)    # row scale: 1/sqrt(deg_i)
    * tf.convert_to_tensor(adj)
    * tf.expand_dims(inv_sqrt_deg, axis=0)   # column scale: 1/sqrt(deg_j)
)
train_cora(node_features, norm_adj, gnn, 32, 200, 0.01)