In [1]:
import tensorflow as tf
import spektral 
import numpy as np 

In [2]:
physical_devices = tf.config.list_physical_devices('GPU') 
for device in physical_devices:
    tf.config.experimental.set_memory_growth(device, True)

In [3]:
cora= spektral.datasets.citation.Citation('cora',random_split=True,dtype= np.float32)

  self._set_arrayXarray(i, j, x)


In [4]:
cora_g =  cora.read()[0]
type(cora_g)

  self._set_arrayXarray(i, j, x)


spektral.data.graph.Graph

In [5]:
adj= cora_g.a # note that this is a numpy sparse matrix 
features= cora_g.x
labels=cora_g.y
train_mask=cora.mask_tr
val_mask=cora.mask_va
test_mask= cora.mask_te

In [6]:
adj = adj.todense()+np.eye(adj.shape[0])
adj = adj.astype('float32')

In [7]:
def masked_softmax_cross_entropy(logits,labels,mask): #loss
    loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits,labels=labels)
    mask = tf.cast(mask , dtype = tf.float32)
    mask /= tf.reduce_mean(mask)
    loss *= mask
    return tf.reduce_mean(loss)

def masked_accuracy(logits,labels, mask):
    correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1))
    accuracy_all = tf.cast(correct_prediction,tf.float32)
    mask = tf.cast(mask,tf.float32)
    mask /= tf.reduce_mean(mask)
    accuracy_all *= mask
    return tf.reduce_mean(accuracy_all)

In [8]:
#There are weights so not good to wrap around tf.keras.layers.Lambda()
class gnn_layer(tf.keras.layers.Layer):
    def __init__(self, adj, n_out_fts, activation, **kwargs):
        super().__init__(**kwargs)
        self.adj = adj 
        self.units = n_out_fts 
        self.activation = tf.keras.activations.get(activation)
        
    def build(self,input_shape):
        self.kernel = self.add_weight(
        name="kernel_of_transform", shape= [input_shape[-1],self.units])
        self.bias = self.add_weight(
            "bias_of_transform", shape= [self.units])
        super().build(input_shape)
    def call(self, fts):
        seq_fts = fts @ self.kernel + self.bias
        #print("seq_fts {} adj {}".format(seq_fts.shape , adj.shape))
        aggregation = self.adj @ seq_fts
        return self.activation(aggregation)
    
    
class gnn(tf.keras.models.Model):
    def __init__(self,adj,layer_units,activations):
        super().__init__()
        self.layer1 = gnn_layer(adj, layer_units[0], activations[0])
        self.layer2 = gnn_layer(adj, layer_units[1], activations[1])
        
    def call(self, inputs):
        hidden = self.layer1(inputs)
        logits = self.layer2(hidden)
        return logits

In [10]:
from tqdm.auto import tqdm, trange 
from time import sleep

best_accuracy = 0.0
epochs = 200
lr = 0.01

optimizer = tf.keras.optimizers.Adam(learning_rate = lr)

model = gnn(adj,[32,7],['relu',tf.identity])

for epoch in trange(1,epochs+1):
    with tf.GradientTape(persistent = True) as tape:
        logits = model(features)
        loss = masked_softmax_cross_entropy(logits,labels,train_mask)
        
    variables =tape.watched_variables()
    gradients = tape.gradient(loss,variables)
    optimizer.apply_gradients(zip(gradients, variables))
    
    predicts  = logits
    del tape  
    sleep(0.001)
    
    val_accuracy = masked_accuracy(predicts, labels, val_mask)
    test_accuracy = masked_accuracy(predicts, labels, test_mask)
        
    if val_accuracy>best_accuracy:
        best_accuracy=val_accuracy
        tqdm.write("\rEpoch {}/{} | Training loss {:.4f} |  Val accuracy {:.4f} | Test accuracy {:.4f}".
                  format(epoch,epochs,loss.numpy(),val_accuracy.numpy(),test_accuracy.numpy()),
                  end="\n" if epoch<epochs+1 else "" )

  0%|          | 0/200 [00:00<?, ?it/s]

Epoch 1/200 | Training loss 5.7407 |  Val accuracy 0.2286 | Test accuracy 0.1934
Epoch 2/200 | Training loss 6.8915 |  Val accuracy 0.3524 | Test accuracy 0.3579
Epoch 3/200 | Training loss 3.0230 |  Val accuracy 0.4238 | Test accuracy 0.3978
Epoch 4/200 | Training loss 2.0236 |  Val accuracy 0.5048 | Test accuracy 0.4423
Epoch 5/200 | Training loss 1.0718 |  Val accuracy 0.5714 | Test accuracy 0.5649
Epoch 6/200 | Training loss 1.0757 |  Val accuracy 0.7333 | Test accuracy 0.7180
Epoch 7/200 | Training loss 1.2641 |  Val accuracy 0.7667 | Test accuracy 0.7320
Epoch 8/200 | Training loss 1.2131 |  Val accuracy 0.7714 | Test accuracy 0.7506
Epoch 9/200 | Training loss 0.9938 |  Val accuracy 0.7762 | Test accuracy 0.7680
Epoch 19/200 | Training loss 0.0973 |  Val accuracy 0.7905 | Test accuracy 0.7574
