# Example of training Graph Neural Networks (GNNs)

GRB provides easy-to-use APIs to train GNNs, facilitating the entire process from loading graph data, building GNN models, to evaluation and inference. Here is an example of training GCNs (Graph Convolution Networks).

Contents
- [Load Dataset](#Load-Dataset)
- [Build Model](#Build-Model)
- [Training](#Training)
- [Inference](#Inference)
- [Evaluation](#Evaluation)

In [1]:
import os
import torch
import grb.utils as utils

## Load Dataset

GRB datasets are named by the prefix *grb-*. There are four *mode* ('easy', 'medium', 'hard', 'full') for test set, representing different average degrees of test nodes, thus different difficulty for attacking them. The node features are processed by *arctan* normalization (first standardization then arctan function), which makes node features fall in the same scale.

In [2]:
from grb.dataset import Dataset

dataset_name = 'grb-cora'
dataset = Dataset(name=dataset_name, 
                  data_dir="../data/",
                  mode='full',
                  feat_norm='arctan')

Dataset 'grb-cora' loaded.
    Number of nodes: 2680
    Number of edges: 5148
    Number of features: 302
    Number of classes: 7
    Number of train samples: 1608
    Number of val samples: 268
    Number of test samples: 804
    Dataset mode: full
    Feature range: [-0.9406, 0.9430]


## Build Model

GRB supports models based on pure Pytorch, CogDL or DGL. The following is an example of GCN implemented by pure Pytorch. Other models can be found in ``grb/model/torch``, ``grb/model/cogdl``, or ``grb/model/dgl``.

### GCN

In [32]:
from grb.model.torch import GCN
from grb.utils.normalize import GCNAdjNorm

model_name = "gcn"
model = GCN(in_features=dataset.num_features,
            out_features=dataset.num_classes,
            hidden_features=[64, 64], 
            adj_norm_func=GCNAdjNorm,
            layer_norm=False,
            residual=False,
            dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 24007.
GCN(
  (layers): ModuleList(
    (0): GCNConv(
      (linear): Linear(in_features=302, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): GCNConv(
      (linear): Linear(in_features=64, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (2): GCNConv(
      (linear): Linear(in_features=64, out_features=7, bias=True)
    )
  )
)


### GAT

In [3]:
from grb.model.dgl import GAT

model_name = "gat"
model = GAT(in_features=dataset.num_features,
            out_features=dataset.num_classes,
            hidden_features=[64, 64],
            num_heads=4,
            adj_norm_func=None,
            layer_norm=False,
            residual=False,
            feat_dropout=0.6,
            attn_dropout=0.6,
            dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 146197.
GAT(
  (layers): ModuleList(
    (0): GATConv(
      (fc): Linear(in_features=302, out_features=256, bias=False)
      (feat_drop): Dropout(p=0.6, inplace=False)
      (attn_drop): Dropout(p=0.6, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (1): GATConv(
      (fc): Linear(in_features=256, out_features=256, bias=False)
      (feat_drop): Dropout(p=0.6, inplace=False)
      (attn_drop): Dropout(p=0.6, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (2): GATConv(
      (fc): Linear(in_features=256, out_features=7, bias=False)
      (feat_drop): Dropout(p=0.0, inplace=False)
      (attn_drop): Dropout(p=0.0, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
)


Using backend: pytorch


### APPNP

In [10]:
from grb.model.torch import APPNP
from grb.utils.normalize import GCNAdjNorm

model_name = "appnp"
model = APPNP(in_features=dataset.num_features,
              out_features=dataset.num_classes,
              hidden_features=[64, 64], 
              adj_norm_func=GCNAdjNorm,
              layer_norm=False,
              edge_drop=0.1,
              alpha=0.01,
              k=3,
              dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 24007.
APPNP(
  (layers): ModuleList(
    (0): Linear(in_features=302, out_features=64, bias=True)
    (1): Linear(in_features=64, out_features=64, bias=True)
    (2): Linear(in_features=64, out_features=7, bias=True)
  )
  (edge_dropout): SparseEdgeDrop()
  (dropout): Dropout(p=0.5, inplace=False)
)


### GIN

In [12]:
from grb.model.torch import GIN

model_name = "gin"
model = GIN(in_features=dataset.num_features,
            out_features=dataset.num_classes,
            hidden_features=[64, 64], 
            adj_norm_func=None,
            layer_norm=False,
            dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 36745.
GIN(
  (layers): ModuleList(
    (0): GINConv(
      (linear1): Linear(in_features=302, out_features=64, bias=True)
      (linear2): Linear(in_features=64, out_features=64, bias=True)
      (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): GINConv(
      (linear1): Linear(in_features=64, out_features=64, bias=True)
      (linear2): Linear(in_features=64, out_features=64, bias=True)
      (norm): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
  )
  (linear1): Linear(in_features=64, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=7, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)


### GraphSAGE

In [3]:
from grb.model.torch import GraphSAGE
from grb.utils.normalize import SAGEAdjNorm

model_name = "graphsage"
model = GraphSAGE(in_features=dataset.num_features,
                  out_features=dataset.num_classes,
                  hidden_features=[64, 64], 
                  adj_norm_func=SAGEAdjNorm,
                  layer_norm=False,
                  dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 147840.
GraphSAGE(
  (layers): ModuleList(
    (0): SAGEConv(
      (pool_layer): Linear(in_features=302, out_features=302, bias=True)
      (linear1): Linear(in_features=302, out_features=64, bias=True)
      (linear2): Linear(in_features=302, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): SAGEConv(
      (pool_layer): Linear(in_features=64, out_features=64, bias=True)
      (linear1): Linear(in_features=64, out_features=64, bias=True)
      (linear2): Linear(in_features=64, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (2): SAGEConv(
      (pool_layer): Linear(in_features=64, out_features=64, bias=True)
      (linear1): Linear(in_features=64, out_features=7, bias=True)
      (linear2): Linear(in_features=64, out_features=7, bias=True)
    )
  )
)


### SGCN

In [3]:
from grb.model.torch import SGCN
from grb.utils.normalize import GCNAdjNorm

model_name = "sgcn"
model = SGCN(in_features=dataset.num_features,
             out_features=dataset.num_classes,
             hidden_features=[64, 64], 
             adj_norm_func=GCNAdjNorm,
             k=4,
             dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 24611.
SGCN(
  (batch_norm): BatchNorm1d(302, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (in_conv): Linear(in_features=302, out_features=64, bias=True)
  (out_conv): Linear(in_features=64, out_features=7, bias=True)
  (layers): ModuleList(
    (0): SGConv(
      (linear): Linear(in_features=64, out_features=64, bias=True)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
)


### TAGCN

In [3]:
from grb.model.torch import TAGCN
from grb.utils.normalize import GCNAdjNorm

model_name = "tagcn"
model = TAGCN(in_features=dataset.num_features,
              out_features=dataset.num_classes,
              hidden_features=[64, 64],
              adj_norm_func=GCNAdjNorm,
              k=2,
              dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 71751.
TAGCN(
  (layers): ModuleList(
    (0): TAGConv(
      (linear): Linear(in_features=906, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): TAGConv(
      (linear): Linear(in_features=192, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (2): TAGConv(
      (linear): Linear(in_features=192, out_features=7, bias=True)
    )
  )
)


### MLP

In [15]:
from grb.model.torch import MLP

model_name = "mlp"
model = MLP(in_features=dataset.num_features,
            out_features=dataset.num_classes,
            hidden_features=[64, 64], 
            dropout=0.5)
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 24007.
MLP(
  (layers): ModuleList(
    (0): MLPLayer(
      (linear): Linear(in_features=302, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (1): MLPLayer(
      (linear): Linear(in_features=64, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (2): MLPLayer(
      (linear): Linear(in_features=64, out_features=7, bias=True)
    )
  )
)


## Training

GRB provides ``grb.utils.trainer`` that facilitates the training process of GNNs. The training mode can be chosen from ``inductive`` or ``transductive``. In the inductive mode, only train nodes can be seen during training, train+val nodes can be seen during validation, train+val+test nodes can be seen during testing. In the transductive mode, all nodes are available for each process.  

In [4]:
save_dir = "./saved_modes/{}/{}".format(dataset_name, model_name)
save_name = "model.pt"
device = "cuda:0"
feat_norm = None
train_mode = "inductive"  # "transductive"

In [5]:
from grb.utils.trainer import Trainer

trainer = Trainer(dataset=dataset, 
                  optimizer=torch.optim.Adam(model.parameters(), lr=0.01),
                  loss=torch.nn.functional.nll_loss,
                  lr_scheduler=False,
                  early_stop=True,
                  early_stop_patience=500,
                  feat_norm=feat_norm,
                  device=device)

In [6]:
trainer.train(model=model, 
              n_epoch=2000,
              eval_every=1,
              save_after=0,
              save_dir=save_dir,
              save_name=save_name,
              train_mode=train_mode)

Epoch 00014 | Train loss 1.9133 | Train score 0.3203 | Val loss 1.7600 | Val score 0.3918:   1%|          | 15/2000 [00:00<00:53, 36.99it/s]

Epoch 00002 | Best validation score: 0.3433
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00010 | Best validation score: 0.3507
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00011 | Best validation score: 0.3545
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00012 | Best validation score: 0.3619
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00013 | Best validation score: 0.3769
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00014 | Best validation score: 0.3918
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00027 | Train loss 1.6210 | Train score 0.4459 | Val loss 1.5186 | Val score 0.5896:   1%|▏         | 27/2000 [00:00<00:41, 48.07it/s]

Epoch 00016 | Best validation score: 0.4328
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00017 | Best validation score: 0.4851
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00018 | Best validation score: 0.5224
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00019 | Best validation score: 0.5373
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00020 | Best validation score: 0.5522
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00021 | Best validation score: 0.5597
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00022 | Best validation score: 0.5672
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00023 | Best validation score: 0.5784
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00024 | Best validation score: 0.6007
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00026 | Best validation score: 0.6082
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00048 | Train loss 1.4146 | Train score 0.5641 | Val loss 1.1677 | Val score 0.6530:   2%|▏         | 42/2000 [00:01<00:34, 56.98it/s]

Epoch 00035 | Best validation score: 0.6157
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00036 | Best validation score: 0.6269
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00037 | Best validation score: 0.6343
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00038 | Best validation score: 0.6493
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00039 | Best validation score: 0.6530
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00047 | Best validation score: 0.6604
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00064 | Train loss 1.4496 | Train score 0.5653 | Val loss 1.0818 | Val score 0.6866:   3%|▎         | 58/2000 [00:01<00:30, 64.07it/s]

Epoch 00050 | Best validation score: 0.6791
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00052 | Best validation score: 0.7052
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00080 | Train loss 1.1672 | Train score 0.6511 | Val loss 0.9526 | Val score 0.7201:   4%|▍         | 80/2000 [00:01<00:28, 66.76it/s]

Epoch 00067 | Best validation score: 0.7201
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00068 | Best validation score: 0.7239
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00071 | Best validation score: 0.7351
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00106 | Train loss 1.1929 | Train score 0.6200 | Val loss 0.8598 | Val score 0.7425:   5%|▌         | 101/2000 [00:01<00:28, 66.60it/s]

Epoch 00093 | Best validation score: 0.7388
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00094 | Best validation score: 0.7425
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00135 | Train loss 1.1436 | Train score 0.6791 | Val loss 0.8124 | Val score 0.7537:   7%|▋         | 131/2000 [00:02<00:27, 67.86it/s]

Epoch 00122 | Best validation score: 0.7537
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00123 | Best validation score: 0.7612
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00125 | Best validation score: 0.7649
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00163 | Train loss 0.9562 | Train score 0.6940 | Val loss 0.7838 | Val score 0.7724:   8%|▊         | 159/2000 [00:02<00:27, 66.69it/s]

Epoch 00150 | Best validation score: 0.7687
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00151 | Best validation score: 0.7761
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00152 | Best validation score: 0.7799
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00199 | Train loss 0.9372 | Train score 0.7220 | Val loss 0.7096 | Val score 0.7799:  10%|▉         | 195/2000 [00:03<00:26, 68.56it/s]

Epoch 00186 | Best validation score: 0.7873
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00191 | Best validation score: 0.7910
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00294 | Train loss 0.9300 | Train score 0.7307 | Val loss 0.6683 | Val score 0.7910:  15%|█▍        | 295/2000 [00:04<00:25, 67.85it/s]

Epoch 00281 | Best validation score: 0.7948
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00283 | Best validation score: 0.7985
Model saved in './saved_modes/grb-cora/gat/model.pt'.
Epoch 00284 | Best validation score: 0.8060
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00573 | Train loss 0.8898 | Train score 0.7593 | Val loss 0.6038 | Val score 0.7910:  29%|██▊       | 574/2000 [00:08<00:20, 68.34it/s]

Epoch 00560 | Best validation score: 0.8097
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 00832 | Train loss 0.7767 | Train score 0.7606 | Val loss 0.6154 | Val score 0.7836:  42%|████▏     | 830/2000 [00:12<00:16, 69.18it/s]

Epoch 00819 | Best validation score: 0.8172
Model saved in './saved_modes/grb-cora/gat/model.pt'.


Epoch 01325 | Train loss 0.8388 | Train score 0.7525 | Val loss 0.6168 | Val score 0.7948:  66%|██████▋   | 1326/2000 [00:19<00:09, 68.28it/s]

Training early stopped. Best validation score: 0.8172
Model saved in './saved_modes/grb-cora/gat/early_stopped_model.pt'.





## Inference

In [7]:
model = torch.load(os.path.join(save_dir, save_name))
model = model.to(device)
model.eval()

GAT(
  (layers): ModuleList(
    (0): GATConv(
      (fc): Linear(in_features=302, out_features=256, bias=False)
      (feat_drop): Dropout(p=0.6, inplace=False)
      (attn_drop): Dropout(p=0.6, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (1): GATConv(
      (fc): Linear(in_features=256, out_features=256, bias=False)
      (feat_drop): Dropout(p=0.6, inplace=False)
      (attn_drop): Dropout(p=0.6, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
    (2): GATConv(
      (fc): Linear(in_features=256, out_features=7, bias=False)
      (feat_drop): Dropout(p=0.0, inplace=False)
      (attn_drop): Dropout(p=0.0, inplace=False)
      (leaky_relu): LeakyReLU(negative_slope=0.2)
    )
  )
  (dropout): Dropout(p=0.5, inplace=False)
)

In [8]:
# by trainer
pred = trainer.inference(model)
print(pred, pred.shape)

tensor([[-1.7952,  0.0645,  0.1924,  ..., -2.1137, -6.9829, -7.3225],
        [-2.6028, -4.5270, -5.4507,  ...,  3.3869, -6.9371, -7.1876],
        [-3.0015, -4.1209, -4.3404,  ..., -0.1068, -6.3108, -6.5047],
        ...,
        [-2.2660, -4.0566, -2.6448,  ..., -3.1083, -4.8493, -5.3660],
        [-7.0528, -8.5380, -8.0672,  ..., -8.7857, -7.4439, -9.3096],
        [-1.7393,  0.4303,  1.3291,  ..., -1.8213, -7.0121, -7.8008]],
       device='cuda:0', grad_fn=<ViewBackward>) torch.Size([2680, 7])


In [9]:
# by utils
pred = utils.inference(model, 
                       features=dataset.features,
                       feat_norm=feat_norm,
                       adj=dataset.adj,
                       adj_norm_func=model.adj_norm_func,
                       device=device)
print(pred, pred.shape)

tensor([[-1.7952,  0.0645,  0.1924,  ..., -2.1137, -6.9829, -7.3225],
        [-2.6028, -4.5270, -5.4507,  ...,  3.3869, -6.9371, -7.1876],
        [-3.0015, -4.1209, -4.3404,  ..., -0.1068, -6.3108, -6.5047],
        ...,
        [-2.2660, -4.0566, -2.6448,  ..., -3.1083, -4.8493, -5.3660],
        [-7.0528, -8.5380, -8.0672,  ..., -8.7857, -7.4439, -9.3096],
        [-1.7393,  0.4303,  1.3291,  ..., -1.8213, -7.0121, -7.8008]],
       device='cuda:0', grad_fn=<ViewBackward>) torch.Size([2680, 7])


## Evaluation

In [10]:
# by trainer
test_score = trainer.evaluate(model, dataset.test_mask)
print("Test score: {:.4f}".format(test_score))

Test score: 0.8607


In [11]:
# by utils
test_score = utils.evaluate(model, 
                            features=dataset.features,
                            adj=dataset.adj,
                            labels=dataset.labels,
                            feat_norm=feat_norm,
                            adj_norm_func=model.adj_norm_func,
                            mask=dataset.test_mask,
                            device=device)
print("Test score: {:.4f}".format(test_score))

Test score: 0.8607


For further information, please refer to the [GRB Documentation](https://grb.readthedocs.io/en/latest/).