In [None]:
!pip install spektral==0.6.2

Collecting spektral==0.6.2
  Downloading spektral-0.6.2-py3-none-any.whl (95 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.4/95.4 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: spektral
  Attempting uninstall: spektral
    Found existing installation: spektral 1.3.0
    Uninstalling spektral-1.3.0:
      Successfully uninstalled spektral-1.3.0
Successfully installed spektral-0.6.2


In [None]:
import numpy as np
import tensorflow as tf
import spektral

In [None]:
# Load the Cora citation graph: adjacency, node features, labels and the
# train/val/test node masks. Adjacency and features come back as sparse
# matrices (hence the densification below).
adj, features, labels, train_mask, val_mask, test_mask = spektral.datasets.citation.load_data(dataset_name='cora')

# Densify with .toarray() (plain ndarray) rather than .todense() (np.matrix):
# NumPy discourages np.matrix, and its `*` operator is a matrix product, which
# silently breaks element-wise code further down.
features = features.toarray()
# Add self-loops so that propagating through `adj` keeps each node's own features.
adj = adj.toarray() + np.eye(adj.shape[0])
adj = adj.astype('float32')

# Sanity checks: one adjacency row/column and one label row per node.
num_nodes = features.shape[0]
assert adj.shape == (num_nodes, num_nodes), adj.shape
assert labels.shape[0] == num_nodes, labels.shape

print(features.shape)
print(adj.shape)
print(labels.shape)

print(np.sum(train_mask))
print(np.sum(val_mask))
print(np.sum(test_mask))

Loading cora dataset
Pre-processing node features
(2708, 1433)
(2708, 2708)
(2708, 7)
140
500
1000


In [None]:
def masked_softmax_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over the nodes selected by ``mask``.

    Args:
        logits: unnormalised class scores, one row per node.
        labels: target class distributions with the same shape as ``logits``
            (one-hot rows in this notebook — presumably; confirm for other data).
        mask: per-node selector (bool or 0/1), broadcastable to the per-node loss.

    Returns:
        Scalar tensor equal to sum(loss * mask) / sum(mask), i.e. the average
        loss over the selected nodes. An all-False mask yields 0 instead of NaN.
    """
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    mask = tf.cast(mask, dtype=tf.float32)
    # Rescale the mask so that a plain mean over *all* nodes equals the mean over
    # the masked nodes. divide_no_nan turns the 0/0 of an empty mask into 0
    # rather than NaN (which would otherwise poison the loss and its gradients).
    mask = tf.math.divide_no_nan(mask, tf.reduce_mean(mask))
    loss *= mask
    return tf.reduce_mean(loss)

def masked_accuracy(logits, labels, mask):
    """Classification accuracy restricted to the nodes selected by ``mask``.

    A node counts as correct when the argmax of its logits equals the argmax of
    its label row.

    Args:
        logits: class scores, one row per node.
        labels: target rows with the same number of columns as ``logits``.
        mask: per-node selector (bool or 0/1), one entry per node.

    Returns:
        Scalar tensor: fraction of selected nodes predicted correctly. An
        all-False mask yields 0 instead of NaN.
    """
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    acc = tf.cast(correct_pred, tf.float32)
    mask = tf.cast(mask, tf.float32)
    # Same rescaling trick as the masked loss; divide_no_nan guards the empty mask.
    mask = tf.math.divide_no_nan(mask, tf.reduce_mean(mask))
    acc *= mask
    return tf.reduce_mean(acc)

In [None]:
def gnn(features, adj, transform, act):
    """Apply one graph-convolution-style layer: ``act(adj @ transform(features))``.

    Every node's features are first mapped by ``transform``; the results are
    then combined across nodes by left-multiplying with ``adj``.

    Args:
        features: node feature matrix, one row per node.
        adj: node-by-node propagation matrix (raw or normalised adjacency).
        transform: callable applied to ``features`` (a Dense layer in this notebook).
        act: activation applied to the propagated result.

    Returns:
        The activated, propagated node representations.
    """
    propagated = tf.matmul(adj, transform(features))
    return act(propagated)

In [None]:
def train_cora(features, adj, gnn_fn, units, no_epochs, lr,
               labels=None, train_mask=None, val_mask=None, test_mask=None):
    """Train a two-layer GNN node classifier with full-batch Adam and report progress.

    Model: ``gnn_fn(gnn_fn(X, A, Dense(units), relu), A, Dense(C), identity)``,
    where C is the number of label columns. Each epoch takes one Adam step on the
    masked training loss, re-evaluates the model, and prints a line whenever
    validation accuracy strictly improves on the best seen so far.

    Args:
        features: node feature matrix, one row per node.
        adj: adjacency (raw or normalised) used for propagation.
        gnn_fn: layer function with signature ``(features, adj, transform, act)``.
        units: hidden-layer width.
        no_epochs: number of optimisation steps.
        lr: Adam learning rate.
        labels, train_mask, val_mask, test_mask: optional. When omitted, the
            notebook-level variables with the same names are used (the original
            behaviour), so existing calls keep working unchanged.
    """
    # Backward-compatible fallback to notebook globals instead of hidden reliance on them.
    nb = globals()
    labels = nb["labels"] if labels is None else labels
    train_mask = nb["train_mask"] if train_mask is None else train_mask
    val_mask = nb["val_mask"] if val_mask is None else val_mask
    test_mask = nb["test_mask"] if test_mask is None else test_mask

    fc1 = tf.keras.layers.Dense(units)
    # Output width follows the labels instead of a hard-coded 7 (Cora's class count).
    fc2 = tf.keras.layers.Dense(labels.shape[-1])

    def cora_gnn(features, adj):
        hidden = gnn_fn(features, adj, fc1, tf.nn.relu)
        logits = gnn_fn(hidden, adj, fc2, tf.identity)
        return logits

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    best_acc = 0.0
    for epoch in range(no_epochs):
        with tf.GradientTape() as t:
            logits = cora_gnn(features, adj)
            loss = masked_softmax_cross_entropy(logits, labels, train_mask)

        # The Dense layers build their weights lazily on the first call, inside
        # the tape; trainable variables accessed under a tape are watched.
        variables = t.watched_variables()
        grads = t.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

        # Re-run the forward pass so accuracies reflect the updated weights. Note the
        # printed loss is from *before* this epoch's update.
        logits = cora_gnn(features, adj)
        val_acc = masked_accuracy(logits, labels, val_mask)
        test_acc = masked_accuracy(logits, labels, test_mask)

        if val_acc > best_acc:
            best_acc = val_acc
            print("Training loss:", loss.numpy(), '| Val accuracy:', val_acc.numpy(), '| Test accuracy:', test_acc.numpy())

In [None]:
# Baseline: propagate over the un-normalised adjacency (self-loops included).
train_cora(features, adj, gnn_fn=gnn, units=32, no_epochs=200, lr=0.01)

Training loss: 1.9709835 | Val accuracy: 0.25 | Test accuracy: 0.25100002
Training loss: 1.8293594 | Val accuracy: 0.294 | Test accuracy: 0.328
Training loss: 1.6372524 | Val accuracy: 0.584 | Test accuracy: 0.60899997
Training loss: 1.2666545 | Val accuracy: 0.644 | Test accuracy: 0.678
Training loss: 1.1482413 | Val accuracy: 0.662 | Test accuracy: 0.685
Training loss: 0.93269634 | Val accuracy: 0.67199993 | Test accuracy: 0.687
Training loss: 0.85000443 | Val accuracy: 0.714 | Test accuracy: 0.73599994
Training loss: 0.7725894 | Val accuracy: 0.734 | Test accuracy: 0.74799997
Training loss: 0.66383713 | Val accuracy: 0.736 | Test accuracy: 0.753
Training loss: 0.59492326 | Val accuracy: 0.744 | Test accuracy: 0.75200003
Training loss: 0.5399339 | Val accuracy: 0.75200003 | Test accuracy: 0.75399995
Training loss: 0.40955243 | Val accuracy: 0.75399995 | Test accuracy: 0.752


In [None]:
# Symmetric normalisation: norm_adj = D^{-1/2} A D^{-1/2}, with D the diagonal
# degree matrix of A. Scaling rows and columns by broadcasting costs O(N^2);
# multiplying by dense diagonal matrices, as before, costs O(N^3).
# Convert first so `*` is element-wise even if `adj` is an np.matrix.
adj_t = tf.convert_to_tensor(adj, dtype=tf.float32)
deg = tf.reduce_sum(adj_t, axis=-1)
# divide_no_nan maps zero-degree nodes to 0 instead of inf. With the self-loops
# added at load time (and non-negative edge weights) every degree is >= 1.
inv_sqrt_deg = tf.math.divide_no_nan(tf.ones_like(deg), tf.sqrt(deg))
norm_adj = adj_t * inv_sqrt_deg[:, tf.newaxis] * inv_sqrt_deg[tf.newaxis, :]
train_cora(features, norm_adj, gnn, 32, 200, 0.01)

Training loss: 1.9459112 | Val accuracy: 0.152 | Test accuracy: 0.16899998
Training loss: 1.936286 | Val accuracy: 0.16599998 | Test accuracy: 0.19899999
Training loss: 1.9231865 | Val accuracy: 0.27199998 | Test accuracy: 0.297
Training loss: 1.9061885 | Val accuracy: 0.376 | Test accuracy: 0.39999998
Training loss: 1.8879423 | Val accuracy: 0.42199996 | Test accuracy: 0.43199998
Training loss: 1.8478628 | Val accuracy: 0.43199998 | Test accuracy: 0.42699996
Training loss: 1.8254884 | Val accuracy: 0.45599997 | Test accuracy: 0.44899994
Training loss: 1.8015292 | Val accuracy: 0.494 | Test accuracy: 0.47599998
Training loss: 1.7757418 | Val accuracy: 0.52599996 | Test accuracy: 0.50799996
Training loss: 1.7479149 | Val accuracy: 0.562 | Test accuracy: 0.55799997
Training loss: 1.7178336 | Val accuracy: 0.60400003 | Test accuracy: 0.59999996
Training loss: 1.6860932 | Val accuracy: 0.634 | Test accuracy: 0.63100004
Training loss: 1.6531155 | Val accuracy: 0.654 | Test accuracy: 0.648
T