In [1]:
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dropout, Input
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.random import set_seed

from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GATConv
from spektral.transforms import LayerPreprocess

In [2]:
dataset = Citation("cora", normalize_x=True, transforms=[LayerPreprocess(GATConv)])

def mask_to_weights(mask):
    return mask.astype(np.float32) / np.count_nonzero(mask)

weights_tr, weights_va, weights_te = (
    mask_to_weights(mask)
    for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)

Downloading cora dataset.
Pre-processing node features


  self._set_arrayXarray(i, j, x)


In [4]:
channels = 8  # Number of channels in each head of the first GAT layer
n_attn_heads = 8  # Number of attention heads in first GAT layer
dropout = 0.6  # Dropout rate for the features and adjacency matrix
l2_reg = 2.5e-4  # L2 regularization rate
learning_rate = 5e-3  # Learning rate
epochs = 20000  # Number of training epochs
patience = 100  # Patience for early stopping

N = dataset.n_nodes  # Number of nodes in the graph
F = dataset.n_node_features  # Original size of node features
n_out = dataset.n_labels  # Number of classes

x_in = Input(shape=(F,))
a_in = Input((N,), sparse=True)

do_1 = Dropout(dropout)(x_in)
gc_1 = GATConv(
    channels,
    attn_heads=n_attn_heads,
    concat_heads=True,
    dropout_rate=dropout,
    activation="elu",
    kernel_regularizer=l2(l2_reg),
    attn_kernel_regularizer=l2(l2_reg),
    bias_regularizer=l2(l2_reg),
)([do_1, a_in])
do_2 = Dropout(dropout)(gc_1)
gc_2 = GATConv(
    n_out,
    attn_heads=1,
    concat_heads=False,
    dropout_rate=dropout,
    activation="softmax",
    kernel_regularizer=l2(l2_reg),
    attn_kernel_regularizer=l2(l2_reg),
    bias_regularizer=l2(l2_reg),
)([do_2, a_in])

# Build model
model = Model(inputs=[x_in, a_in], outputs=gc_2)
optimizer = Adam(learning_rate=learning_rate)
model.compile(
    optimizer=optimizer,
    loss=CategoricalCrossentropy(reduction="sum"),
    weighted_metrics=["acc"],
)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 1433)]       0           []                               
                                                                                                  
 dropout_2 (Dropout)            (None, 1433)         0           ['input_3[0][0]']                
                                                                                                  
 input_4 (InputLayer)           [(None, 2708)]       0           []                               
                                                                                                  
 gat_conv_2 (GATConv)           (None, 64)           91904       ['dropout_2[0][0]',              
                                                                  'input_4[0][0]']          

In [5]:
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
)

Epoch 1/20000
Epoch 2/20000
Epoch 3/20000
Epoch 4/20000
Epoch 5/20000
Epoch 6/20000
Epoch 7/20000
Epoch 8/20000
Epoch 9/20000
Epoch 10/20000
Epoch 11/20000
Epoch 12/20000
Epoch 13/20000
Epoch 14/20000
Epoch 15/20000
Epoch 16/20000
Epoch 17/20000
Epoch 18/20000
Epoch 19/20000
Epoch 20/20000
Epoch 21/20000
Epoch 22/20000
Epoch 23/20000
Epoch 24/20000
Epoch 25/20000
Epoch 26/20000
Epoch 27/20000
Epoch 28/20000
Epoch 29/20000
Epoch 30/20000
Epoch 31/20000
Epoch 32/20000
Epoch 33/20000
Epoch 34/20000
Epoch 35/20000
Epoch 36/20000
Epoch 37/20000
Epoch 38/20000
Epoch 39/20000
Epoch 40/20000
Epoch 41/20000
Epoch 42/20000
Epoch 43/20000
Epoch 44/20000
Epoch 45/20000
Epoch 46/20000
Epoch 47/20000
Epoch 48/20000
Epoch 49/20000
Epoch 50/20000
Epoch 51/20000
Epoch 52/20000
Epoch 53/20000
Epoch 54/20000
Epoch 55/20000
Epoch 56/20000
Epoch 57/20000
Epoch 58/20000
Epoch 59/20000
Epoch 60/20000
Epoch 61/20000
Epoch 62/20000


Epoch 63/20000
Epoch 64/20000
Epoch 65/20000
Epoch 66/20000
Epoch 67/20000
Epoch 68/20000
Epoch 69/20000
Epoch 70/20000
Epoch 71/20000
Epoch 72/20000
Epoch 73/20000
Epoch 74/20000
Epoch 75/20000
Epoch 76/20000
Epoch 77/20000
Epoch 78/20000
Epoch 79/20000
Epoch 80/20000
Epoch 81/20000
Epoch 82/20000
Epoch 83/20000
Epoch 84/20000
Epoch 85/20000
Epoch 86/20000
Epoch 87/20000
Epoch 88/20000
Epoch 89/20000
Epoch 90/20000
Epoch 91/20000
Epoch 92/20000
Epoch 93/20000
Epoch 94/20000
Epoch 95/20000
Epoch 96/20000
Epoch 97/20000
Epoch 98/20000
Epoch 99/20000
Epoch 100/20000
Epoch 101/20000
Epoch 102/20000
Epoch 103/20000
Epoch 104/20000
Epoch 105/20000
Epoch 106/20000
Epoch 107/20000
Epoch 108/20000
Epoch 109/20000
Epoch 110/20000
Epoch 111/20000
Epoch 112/20000
Epoch 113/20000
Epoch 114/20000
Epoch 115/20000
Epoch 116/20000
Epoch 117/20000
Epoch 118/20000
Epoch 119/20000
Epoch 120/20000
Epoch 121/20000
Epoch 122/20000
Epoch 123/20000
Epoch 124/20000


Epoch 125/20000
Epoch 126/20000
Epoch 127/20000
Epoch 128/20000
Epoch 129/20000
Epoch 130/20000
Epoch 131/20000
Epoch 132/20000
Epoch 133/20000
Epoch 134/20000
Epoch 135/20000
Epoch 136/20000
Epoch 137/20000
Epoch 138/20000
Epoch 139/20000
Epoch 140/20000
Epoch 141/20000
Epoch 142/20000
Epoch 143/20000
Epoch 144/20000
Epoch 145/20000
Epoch 146/20000
Epoch 147/20000
Epoch 148/20000
Epoch 149/20000
Epoch 150/20000
Epoch 151/20000
Epoch 152/20000
Epoch 153/20000
Epoch 154/20000
Epoch 155/20000
Epoch 156/20000
Epoch 157/20000
Epoch 158/20000
Epoch 159/20000
Epoch 160/20000
Epoch 161/20000
Epoch 162/20000
Epoch 163/20000
Epoch 164/20000
Epoch 165/20000
Epoch 166/20000
Epoch 167/20000
Epoch 168/20000
Epoch 169/20000
Epoch 170/20000
Epoch 171/20000
Epoch 172/20000
Epoch 173/20000
Epoch 174/20000
Epoch 175/20000
Epoch 176/20000
Epoch 177/20000
Epoch 178/20000
Epoch 179/20000
Epoch 180/20000
Epoch 181/20000
Epoch 182/20000
Epoch 183/20000
Epoch 184/20000
Epoch 185/20000


Epoch 186/20000
Epoch 187/20000
Epoch 188/20000
Epoch 189/20000
Epoch 190/20000
Epoch 191/20000
Epoch 192/20000
Epoch 193/20000
Epoch 194/20000
Epoch 195/20000
Epoch 196/20000
Epoch 197/20000
Epoch 198/20000
Epoch 199/20000
Epoch 200/20000
Epoch 201/20000
Epoch 202/20000
Epoch 203/20000
Epoch 204/20000
Epoch 205/20000
Epoch 206/20000
Epoch 207/20000
Epoch 208/20000
Epoch 209/20000
Epoch 210/20000
Epoch 211/20000
Epoch 212/20000
Epoch 213/20000
Epoch 214/20000
Epoch 215/20000
Epoch 216/20000
Epoch 217/20000
Epoch 218/20000
Epoch 219/20000
Epoch 220/20000
Epoch 221/20000
Epoch 222/20000
Epoch 223/20000
Epoch 224/20000
Epoch 225/20000
Epoch 226/20000
Epoch 227/20000
Epoch 228/20000
Epoch 229/20000
Epoch 230/20000
Epoch 231/20000
Epoch 232/20000
Epoch 233/20000
Epoch 234/20000
Epoch 235/20000
Epoch 236/20000
Epoch 237/20000
Epoch 238/20000
Epoch 239/20000
Epoch 240/20000
Epoch 241/20000
Epoch 242/20000
Epoch 243/20000
Epoch 244/20000
Epoch 245/20000
Epoch 246/20000


Epoch 247/20000
Epoch 248/20000
Epoch 249/20000
Epoch 250/20000
Epoch 251/20000
Epoch 252/20000
Epoch 253/20000
Epoch 254/20000
Epoch 255/20000
Epoch 256/20000
Epoch 257/20000
Epoch 258/20000
Epoch 259/20000
Epoch 260/20000
Epoch 261/20000
Epoch 262/20000
Epoch 263/20000
Epoch 264/20000
Epoch 265/20000
Epoch 266/20000
Epoch 267/20000
Epoch 268/20000
Epoch 269/20000
Epoch 270/20000
Epoch 271/20000
Epoch 272/20000
Epoch 273/20000
Epoch 274/20000
Epoch 275/20000
Epoch 276/20000
Epoch 277/20000
Epoch 278/20000
Epoch 279/20000
Epoch 280/20000
Epoch 281/20000
Epoch 282/20000
Epoch 283/20000
Epoch 284/20000
Epoch 285/20000
Epoch 286/20000
Epoch 287/20000
Epoch 288/20000
Epoch 289/20000
Epoch 290/20000
Epoch 291/20000
Epoch 292/20000
Epoch 293/20000
Epoch 294/20000
Epoch 295/20000
Epoch 296/20000
Epoch 297/20000
Epoch 298/20000
Epoch 299/20000
Epoch 300/20000
Epoch 301/20000
Epoch 302/20000
Epoch 303/20000
Epoch 304/20000
Epoch 305/20000
Epoch 306/20000
Epoch 307/20000


Epoch 308/20000
Epoch 309/20000
Epoch 310/20000
Epoch 311/20000
Epoch 312/20000
Epoch 313/20000
Epoch 314/20000
Epoch 315/20000
Epoch 316/20000
Epoch 317/20000
Epoch 318/20000
Epoch 319/20000
Epoch 320/20000
Epoch 321/20000
Epoch 322/20000
Epoch 323/20000
Epoch 324/20000
Epoch 325/20000
Epoch 326/20000
Epoch 327/20000
Epoch 328/20000
Epoch 329/20000
Epoch 330/20000
Epoch 331/20000
Epoch 332/20000
Epoch 333/20000
Epoch 334/20000
Epoch 335/20000
Epoch 336/20000
Epoch 337/20000
Epoch 338/20000
Epoch 339/20000
Epoch 340/20000
Epoch 341/20000
Epoch 342/20000
Epoch 343/20000
Epoch 344/20000
Epoch 345/20000
Epoch 346/20000
Epoch 347/20000
Epoch 348/20000
Epoch 349/20000
Epoch 350/20000
Epoch 351/20000
Epoch 352/20000
Epoch 353/20000
Epoch 354/20000
Epoch 355/20000
Epoch 356/20000
Epoch 357/20000
Epoch 358/20000
Epoch 359/20000
Epoch 360/20000
Epoch 361/20000
Epoch 362/20000
Epoch 363/20000
Epoch 364/20000
Epoch 365/20000
Epoch 366/20000
Epoch 367/20000
Epoch 368/20000


Epoch 369/20000
Epoch 370/20000
Epoch 371/20000
Epoch 372/20000
Epoch 373/20000
Epoch 374/20000
Epoch 375/20000
Epoch 376/20000
Epoch 377/20000
Epoch 378/20000
Epoch 379/20000
Epoch 380/20000
Epoch 381/20000
Epoch 382/20000
Epoch 383/20000
Epoch 384/20000
Epoch 385/20000
Epoch 386/20000
Epoch 387/20000
Epoch 388/20000
Epoch 389/20000
Epoch 390/20000
Epoch 391/20000
Epoch 392/20000
Epoch 393/20000
Epoch 394/20000
Epoch 395/20000
Epoch 396/20000
Epoch 397/20000
Epoch 398/20000
Epoch 399/20000
Epoch 400/20000
Epoch 401/20000
Epoch 402/20000
Epoch 403/20000
Epoch 404/20000
Epoch 405/20000
Epoch 406/20000
Epoch 407/20000
Epoch 408/20000
Epoch 409/20000
Epoch 410/20000
Epoch 411/20000
Epoch 412/20000
Epoch 413/20000
Epoch 414/20000
Epoch 415/20000
Epoch 416/20000
Epoch 417/20000
Epoch 418/20000
Epoch 419/20000
Epoch 420/20000
Epoch 421/20000
Epoch 422/20000
Epoch 423/20000
Epoch 424/20000
Epoch 425/20000
Epoch 426/20000
Epoch 427/20000
Epoch 428/20000
Epoch 429/20000


Epoch 430/20000
Epoch 431/20000
Epoch 432/20000
Epoch 433/20000
Epoch 434/20000
Epoch 435/20000
Epoch 436/20000
Epoch 437/20000
Epoch 438/20000
Epoch 439/20000
Epoch 440/20000
Epoch 441/20000
Epoch 442/20000
Epoch 443/20000
Epoch 444/20000
Epoch 445/20000
Epoch 446/20000
Epoch 447/20000
Epoch 448/20000
Epoch 449/20000
Epoch 450/20000
Epoch 451/20000
Epoch 452/20000
Epoch 453/20000
Epoch 454/20000
Epoch 455/20000
Epoch 456/20000
Epoch 457/20000
Epoch 458/20000
Epoch 459/20000
Epoch 460/20000
Epoch 461/20000
Epoch 462/20000
Epoch 463/20000
Epoch 464/20000
Epoch 465/20000
Epoch 466/20000
Epoch 467/20000
Epoch 468/20000
Epoch 469/20000
Epoch 470/20000
Epoch 471/20000
Epoch 472/20000
Epoch 473/20000
Epoch 474/20000
Epoch 475/20000
Epoch 476/20000
Epoch 477/20000
Epoch 478/20000
Epoch 479/20000
Epoch 480/20000
Epoch 481/20000
Epoch 482/20000
Epoch 483/20000
Epoch 484/20000
Epoch 485/20000
Epoch 486/20000
Epoch 487/20000
Epoch 488/20000
Epoch 489/20000
Epoch 490/20000


Epoch 491/20000
Epoch 492/20000
Epoch 493/20000
Epoch 494/20000
Epoch 495/20000
Epoch 496/20000
Epoch 497/20000
Epoch 498/20000
Epoch 499/20000
Epoch 500/20000
Epoch 501/20000
Epoch 502/20000
Epoch 503/20000
Epoch 504/20000
Epoch 505/20000
Epoch 506/20000
Epoch 507/20000
Epoch 508/20000
Epoch 509/20000
Epoch 510/20000
Epoch 511/20000
Epoch 512/20000
Epoch 513/20000
Epoch 514/20000
Epoch 515/20000
Epoch 516/20000
Epoch 517/20000
Epoch 518/20000
Epoch 519/20000
Epoch 520/20000
Epoch 521/20000
Epoch 522/20000
Epoch 523/20000
Epoch 524/20000
Epoch 525/20000
Epoch 526/20000
Epoch 527/20000
Epoch 528/20000
Epoch 529/20000
Epoch 530/20000
Epoch 531/20000
Epoch 532/20000
Epoch 533/20000
Epoch 534/20000
Epoch 535/20000
Epoch 536/20000
Epoch 537/20000
Epoch 538/20000
Epoch 539/20000
Epoch 540/20000
Epoch 541/20000
Epoch 542/20000
Epoch 543/20000
Epoch 544/20000
Epoch 545/20000
Epoch 546/20000
Epoch 547/20000
Epoch 548/20000
Epoch 549/20000
Epoch 550/20000
Epoch 551/20000


Epoch 552/20000
Epoch 553/20000
Epoch 554/20000
Epoch 555/20000
Epoch 556/20000
Epoch 557/20000
Epoch 558/20000
Epoch 559/20000
Epoch 560/20000
Epoch 561/20000
Epoch 562/20000
Epoch 563/20000
Epoch 564/20000
Epoch 565/20000
Epoch 566/20000
Epoch 567/20000
Epoch 568/20000
Epoch 569/20000
Epoch 570/20000
Epoch 571/20000
Epoch 572/20000
Epoch 573/20000
Epoch 574/20000
Epoch 575/20000
Epoch 576/20000
Epoch 577/20000
Epoch 578/20000
Epoch 579/20000
Epoch 580/20000
Epoch 581/20000
Epoch 582/20000
Epoch 583/20000
Epoch 584/20000
Epoch 585/20000
Epoch 586/20000
Epoch 587/20000
Epoch 588/20000
Epoch 589/20000
Epoch 590/20000
Epoch 591/20000
Epoch 592/20000
Epoch 593/20000
Epoch 594/20000
Epoch 595/20000
Epoch 596/20000
Epoch 597/20000
Epoch 598/20000
Epoch 599/20000
Epoch 600/20000
Epoch 601/20000
Epoch 602/20000
Epoch 603/20000
Epoch 604/20000
Epoch 605/20000
Epoch 606/20000
Epoch 607/20000
Epoch 608/20000
Epoch 609/20000
Epoch 610/20000
Epoch 611/20000
Epoch 612/20000


Epoch 613/20000
Epoch 614/20000
Epoch 615/20000
Epoch 616/20000
Epoch 617/20000
Epoch 618/20000
Epoch 619/20000
Epoch 620/20000
Epoch 621/20000
Epoch 622/20000
Epoch 623/20000
Epoch 624/20000
Epoch 625/20000
Epoch 626/20000
Epoch 627/20000
Epoch 628/20000
Epoch 629/20000
Epoch 630/20000
Epoch 631/20000
Epoch 632/20000
Epoch 633/20000
Epoch 634/20000
Epoch 635/20000
Epoch 636/20000


<keras.callbacks.History at 0x2d07f604cc8>

In [6]:
print("Evaluating model.")
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
print("Done.\n" "Test loss: {}\n" "Test accuracy: {}".format(*eval_results))

Evaluating model.
Done.
Test loss: 0.9854999780654907
Test accuracy: 0.8349997997283936
