# 1. Example of adversarial training for Graph Neural Networks (GNNs)

In [1]:
import os
import torch
import grb.utils as utils

## 1.1. Load Dataset

GRB datasets are named with the prefix *grb-*. There are four *modes* ('easy', 'medium', 'hard', 'full') for the test set, representing different average degrees of the test nodes and thus different levels of difficulty for attacking them. The node features are processed by *arctan* normalization (standardization followed by the arctan function), which puts all node features on the same scale.

In [2]:
from grb.dataset import Dataset

# Dataset identifier (GRB datasets carry the "grb-" prefix); it is also
# interpolated into the model save directory further down.
dataset_name = 'grb-cora'

# Load the dataset from the shared data directory, selecting the 'full'
# test-set mode and 'arctan' feature normalization.
dataset = Dataset(
    name=dataset_name,
    data_dir="../../data/",
    mode='full',
    feat_norm='arctan',
)

Dataset 'grb-cora' loaded.
    Number of nodes: 2680
    Number of edges: 5148
    Number of features: 302
    Number of classes: 7
    Number of train samples: 1608
    Number of val samples: 268
    Number of test samples: 804
    Dataset mode: full
    Feature range: [-0.9406, 0.9430]


In [3]:
# Bind the dataset attributes to short module-level names.
# Graph data: adjacency, node features and node labels.
adj, features, labels = dataset.adj, dataset.features, dataset.labels
# Input / output dimensions for the model.
num_features, num_classes = dataset.num_features, dataset.num_classes
# Mask selecting the test nodes.
test_mask = dataset.test_mask

## 1.2. Build Model

GRB supports models based on pure PyTorch, CogDL or DGL. The following is an example of GCN implemented in pure PyTorch. Other models can be found in ``grb/model/torch``, ``grb/model/cogdl``, or ``grb/model/dgl``.

### 1.2.1. GCN

In [4]:
from grb.model.torch import GCN
from grb.utils.normalize import GCNAdjNorm

# Short model tag; it is interpolated into the save directory later on.
model_name = "gcn"

# GCN sized to the dataset: input width = number of node features,
# output width = number of classes (both aliases of the dataset attributes).
model = GCN(
    in_features=num_features,
    out_features=num_classes,
    hidden_features=64,
    n_layers=3,
    adj_norm_func=GCNAdjNorm,
    layer_norm=True,
    residual=False,
    dropout=0.5,
)

# Report the model size, then print the layer structure.
print("Number of parameters: {}.".format(utils.get_num_params(model)))
print(model)

Number of parameters: 24867.
GCN(
  (layers): ModuleList(
    (0): LayerNorm((302,), eps=1e-05, elementwise_affine=True)
    (1): GCNConv(
      (linear): Linear(in_features=302, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (3): GCNConv(
      (linear): Linear(in_features=64, out_features=64, bias=True)
      (dropout): Dropout(p=0.5, inplace=False)
    )
    (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (5): GCNConv(
      (linear): Linear(in_features=64, out_features=7, bias=True)
    )
  )
)


## 1.3. Adversarial training

### 1.3.1. Build attack

In [5]:
from grb.attack.injection import FGSM

# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so that the notebook still runs on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# FGSM injection attack that is handed to the adversarial trainer below.
# The feature limits are taken from the min/max of the loaded node features,
# so generated features stay within the range of the original data.
# NOTE(review): n_inject_max / n_edge_max presumably cap the number of injected
# nodes and the edges per injected node -- confirm against the GRB docs.
attack = FGSM(epsilon=0.01,
              n_epoch=10,
              n_inject_max=10,
              n_edge_max=20,
              feat_lim_min=features.min(),
              feat_lim_max=features.max(),
              early_stop=False,
              device=device,
              verbose=False)

In [6]:
# Training configuration.
# Checkpoints go under ./saved_models/<dataset>/<model>_at.
save_dir = "./saved_models/{}/{}_at".format(dataset_name, model_name)
save_name = "model.pt"
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so that the notebook still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
feat_norm = None
train_mode = "inductive"  # "transductive"

In [7]:
from grb.defense import AdvTrainer

# Adversarial trainer: combines the dataset with the attack built above and
# optimizes the model parameters with Adam under a cross-entropy loss.
trainer = AdvTrainer(
    dataset=dataset,
    attack=attack,
    device=device,
    loss=torch.nn.functional.cross_entropy,
    optimizer=torch.optim.Adam(model.parameters(), lr=0.01),
    lr_scheduler=False,
    # Early stopping is enabled with a patience of 500.
    early_stop=True,
    early_stop_patience=500,
)

In [8]:
# Run adversarial training for up to 2000 epochs (the trainer was configured
# with early stopping, so it may finish sooner), evaluating every epoch and
# saving under `save_dir`.
trainer.train(
    model=model,
    n_epoch=2000,
    eval_every=1,
    train_mode=train_mode,
    save_after=0,
    save_dir=save_dir,
    save_name=save_name,
    verbose=False,
)

  0%|          | 0/2000 [00:00<?, ?it/s]

Training: early stopped.
Model saved in './saved_modes/grb-cora/gcn_at/final_model.pt'.


In [9]:
# Evaluate the trained model through the trainer, restricted to the test nodes
# (`test_mask` is the alias of `dataset.test_mask` bound earlier).
test_score = trainer.evaluate(model, test_mask)
print("Test score: {:.4f}".format(test_score))

Test score: 0.8383
