In [55]:
# Imports
!pip install numpy tensorflow spektral

import numpy as np
import tensorflow as tf
import spektral
from spektral.datasets import citation



In [56]:
# Load the data
dataset = citation.Cora()

In [57]:
# Extract data components
graph = dataset[0]
adj = graph.a  # Adjacency matrix
features = graph.x  # Node features
labels = graph.y  # Labels
train_mask = dataset.mask_tr
val_mask = dataset.mask_va
test_mask = dataset.mask_te

In [58]:
# Process the data
adj = adj + np.eye(adj.shape[0])
features = features.astype('float32')
adj = adj.astype('float32')

In [59]:
# Normalize adjacency matrix
deg = np.sum(adj, axis=-1)
norm_deg = np.diag(1.0 / np.sqrt(deg))
norm_adj = np.dot(norm_deg, np.dot(adj, norm_deg))

In [60]:
# Normalize features
features = features / np.linalg.norm(features, axis=1, keepdims=True)

In [61]:
# Print shapes and mask sums
print(features.shape)
print(adj.shape)
print(labels.shape)
print(np.sum(train_mask))
print(np.sum(val_mask))
print(np.sum(test_mask))

(2708, 1433)
(2708, 2708)
(2708, 7)
140
500
1000


In [62]:
def masked_softmax_cross_entropy(logits, labels, mask):
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    loss *= mask
    return tf.reduce_mean(loss)

In [63]:
def masked_accuracy(logits, labels, mask):
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy_all = tf.cast(correct_prediction, tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    accuracy_all *= mask
    return tf.reduce_mean(accuracy_all)

In [64]:
def gnn(fts, adj, transform, activation):
    seq_fts = transform(fts)
    ret_fts = tf.matmul(adj, seq_fts)
    return activation(ret_fts)

In [65]:
def train_cora(fts, adj, gnn_fn, units, epochs, lr, patience):
    lyr_1 = tf.keras.layers.Dense(units, activation=None, kernel_regularizer=tf.keras.regularizers.l2(5e-4))
    lyr_2 = tf.keras.layers.Dense(units, activation=None, kernel_regularizer=tf.keras.regularizers.l2(5e-4))
    lyr_out = tf.keras.layers.Dense(7, activation=None)
    dropout = tf.keras.layers.Dropout(0.6)
    batch_norm = tf.keras.layers.BatchNormalization()

    def cora_gnn(fts, adj):
        hidden = gnn_fn(fts, adj, lyr_1, tf.nn.relu)
        hidden = dropout(hidden)
        hidden = batch_norm(hidden)

        hidden = gnn_fn(hidden, adj, lyr_2, tf.nn.relu)
        hidden = dropout(hidden)
        hidden = batch_norm(hidden)

        logits = gnn_fn(hidden, adj, lyr_out, tf.identity)
        return logits

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    best_val_accuracy = 0.0
    patience_counter = 0

    for ep in range(epochs+1):
        with tf.GradientTape() as t:
            logits = cora_gnn(fts, adj)
            loss = masked_softmax_cross_entropy(logits, labels, train_mask)

        variables = t.watched_variables()
        grads = t.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

        logits = cora_gnn(fts, adj)
        val_accuracy = masked_accuracy(logits, labels, val_mask)
        test_accuracy = masked_accuracy(logits, labels, test_mask)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            patience_counter = 0
            print(f'Epoch {ep:03d} | Training loss: {loss.numpy():.4f} | Val accuracy: {val_accuracy.numpy():.4f} | Test accuracy: {test_accuracy.numpy():.4f}')
        else:
            patience_counter += 1
            if patience_counter > patience:
                break

In [67]:
# Train with normalized adjacency matrix
train_cora(features, norm_adj, gnn, 128, 300, 0.005, patience=50)

Epoch 000 | Training loss: 1.9466 | Val accuracy: 0.2580 | Test accuracy: 0.2670
Epoch 001 | Training loss: 1.9206 | Val accuracy: 0.2960 | Test accuracy: 0.3110
Epoch 002 | Training loss: 1.8796 | Val accuracy: 0.4440 | Test accuracy: 0.4570
Epoch 003 | Training loss: 1.8168 | Val accuracy: 0.5260 | Test accuracy: 0.5550
Epoch 004 | Training loss: 1.7303 | Val accuracy: 0.6040 | Test accuracy: 0.6550
Epoch 005 | Training loss: 1.6180 | Val accuracy: 0.6920 | Test accuracy: 0.7130
Epoch 006 | Training loss: 1.4819 | Val accuracy: 0.7260 | Test accuracy: 0.7610
Epoch 007 | Training loss: 1.3236 | Val accuracy: 0.7640 | Test accuracy: 0.7850
Epoch 008 | Training loss: 1.1479 | Val accuracy: 0.7880 | Test accuracy: 0.8100
Epoch 009 | Training loss: 0.9637 | Val accuracy: 0.7920 | Test accuracy: 0.8120
Epoch 010 | Training loss: 0.7823 | Val accuracy: 0.8020 | Test accuracy: 0.8140
Epoch 011 | Training loss: 0.6134 | Val accuracy: 0.8040 | Test accuracy: 0.8160
