### Import DGL and PyTorch

In this tutorial, we are going to introduce how to implement GAT for link prediction task. 


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
!pip install dgl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [4]:
import dgl
from dgl import DGLGraph

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [5]:
import numpy as np
import warnings
warnings.filterwarnings('ignore')

### Prepare the MUTAG dataset
We use a collection of nitroaromatic compounds called MUTAG for demonstration and the goal is to predict their mutagenicity on Salmonella typhimurium. Input graphs are used to represent chemical compounds, where vertices stand for atoms and are labeled by the atom type, while edges between vertices represent bonds between the corresponding atoms.
It includes 188 samples of chemical compounds with 7 discrete node labels.

In [6]:
from dgl.data import GINDataset

device = 'cpu'
# load and preprocess the pubmed dataset
dataset = GINDataset('MUTAG', self_loop=True, degree_as_nlabel=False)
# the number of input node features
in_feats = dataset.dim_nfeats
# class labels of papers
labels = [l for _, l in dataset]

Downloading /root/.dgl/GINDataset.zip from https://raw.githubusercontent.com/weihua916/powerful-gnns/master/dataset.zip...
Extracting file to /root/.dgl/GINDataset


Here we add self-loops in the graph for self-attention

In [20]:
print(dataset[0])

(Graph(num_nodes=23, num_edges=77,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(7,), dtype=torch.float32)}
      edata_schemes={}), tensor(0))


In [7]:
dataset[0][0].ndata['attr']

tensor([[0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.]])

we split the dataset for train and test.

In [8]:
from sklearn.model_selection import train_test_split

train_idx, val_idx = train_test_split(np.arange(len(dataset)),test_size=0.1)

In [16]:
print(train_idx)

[150  46  85  43  81 129  54 139  59  29  67 116  86  35  38 173   7  89
  92 115  95 147  32 100  15 135 154  56  24 143 113 106  58  78  50  70
 145  41  55  96 122 118 141 151 127 133  25  17 109  76  72 112 119 144
  44 148 161 103 176  62 134  36 104 175  57  99 114 136 166  87   5  28
 146 169 153 155 185  61  12  53  65 171  26  68  47  23  75 170  60 123
 172   3  19  31  11 159  93 131  79 121  52  69 108   9 137 174 130  42
 107 157  22 124   8 180  27  16  84  20  63  37   4 111 158  98 105  51
 126 187  48  10 182 156  33 140 152 167 128  21  91  49  83 165 110  45
 101 132   0  18 177 164  64 162 181  40 179  34   2 102  13  66   6  97
 163  39 120 117 178 183  80]


In [9]:
from torch.utils.data.sampler import SubsetRandomSampler
from dgl.dataloading import GraphDataLoader

train_loader = GraphDataLoader(
    dataset,
    sampler=SubsetRandomSampler(train_idx),
    batch_size=128,
)
val_loader = GraphDataLoader(
    dataset,
    sampler=SubsetRandomSampler(val_idx),
    batch_size=128,
)

### Implement the GNN model

Essentially, given a graph structure, GNNs (GCN, GraphSAGE, GAT, etc.) are used to learn meaningful node representations (in this case, the embeddings, or vectors).
Once these embeddings are properly learnt, we may perform downstream tasks such as node classification, graph classification, and link prediction.

DGL provides two ways of implementing a GNN model:

- using the nn module, which contains many commonly used GNN modules.
- using the message passing interface to implement a GNN model from scratch.

If you are interested in using the message passing interface to implement a GNN model, check this link https://doc.dgl.ai/tutorials/models/index.html out.

![fishy](https://raw.githubusercontent.com/dglai/WWW20-Hands-on-Tutorial/master/images/GNN.png)

Define GIN model.

In [10]:
class MLP(nn.Module):
    """Construct two-layer MLP-type aggreator for GIN model"""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linears = nn.ModuleList()
        # two-layer MLP
        self.linears.append(nn.Linear(input_dim, hidden_dim, bias=False))
        self.linears.append(nn.Linear(hidden_dim, output_dim, bias=False))
        self.batch_norm = nn.BatchNorm1d((hidden_dim))

    def forward(self, x):
        h = x
        h = F.relu(self.batch_norm(self.linears[0](h)))
        return self.linears[1](h)


from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling
    
class GIN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.ginlayers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        num_layers = 5
        # five-layer GCN with two-layer MLP aggregator and sum-neighbor-pooling scheme
        for layer in range(num_layers - 1):  # excluding the input layer
            if layer == 0:
                mlp = MLP(input_dim, hidden_dim, hidden_dim)
            else:
                mlp = MLP(hidden_dim, hidden_dim, hidden_dim)
            self.ginlayers.append(
                GINConv(mlp, learn_eps=False)
            )  # set to True if learning epsilon
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
        # linear functions for graph sum poolings of output of each layer
        self.linear_prediction = nn.ModuleList()
        for layer in range(num_layers):
            if layer == 0:
                self.linear_prediction.append(nn.Linear(input_dim, output_dim))
            else:
                self.linear_prediction.append(nn.Linear(hidden_dim, output_dim))
        self.drop = nn.Dropout(0.5)
        self.pool = (
            SumPooling()
        )  # change to mean readout (AvgPooling) on social network datasets

    def forward(self, g, h):
        # list of hidden representation at each layer (including the input layer)
        hidden_rep = [h]
        for i, layer in enumerate(self.ginlayers):
            h = layer(g, h)
            h = self.batch_norms[i](h)
            h = F.relu(h)
            hidden_rep.append(h)
        score_over_layer = 0
        # perform graph sum pooling over all nodes in each layer
        for i, h in enumerate(hidden_rep):
            pooled_h = self.pool(g, h)
            score_over_layer += self.drop(self.linear_prediction[i](pooled_h))
        return score_over_layer

### Graph Classification

![semisupervised](https://data.dgl.ai/tutorial/batch/graph_classifier.png)


In [11]:
# Hyperparameters
in_size = dataset.dim_nfeats
out_size = dataset.gclasses
hidden_size = 16
model = GIN(in_size, hidden_size, out_size).to(device)


After defining a model for graph classification, we define evaluation function.

In [12]:
def evaluate(dataloader, device, model):
    model.eval()
    total = 0
    total_correct = 0
    for batched_graph, labels in dataloader:
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        feat = batched_graph.ndata.pop("attr")
        total += len(labels)
        logits = model(batched_graph, feat)
        _, predicted = torch.max(logits, 1)
        total_correct += (predicted == labels).sum().item()
    acc = 1.0 * total_correct / total
    return acc

In [14]:
import torch.optim as optim

loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

# training loop
for epoch in range(350):
    model.train()
    total_loss = 0
    for batch, (batched_graph, labels) in enumerate(train_loader):
        batched_graph = batched_graph.to(device)
        labels = labels.to(device)
        feat = batched_graph.ndata.pop("attr")
        logits = model(batched_graph, feat)
        loss = loss_fcn(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    train_acc = evaluate(train_loader, device, model)
    valid_acc = evaluate(val_loader, device, model)
    print(
        "Epoch {:05d} | Loss {:.4f} | Train Acc. {:.4f} | Validation Acc. {:.4f} ".format(
            epoch, total_loss / (batch + 1), train_acc, valid_acc
        )
    )

Epoch 00000 | Loss 5.2796 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00001 | Loss 2.4202 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00002 | Loss 1.7866 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00003 | Loss 1.5285 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00004 | Loss 1.4200 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00005 | Loss 1.3061 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00006 | Loss 1.2221 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00007 | Loss 0.9368 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00008 | Loss 1.0632 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00009 | Loss 1.5010 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00010 | Loss 0.7096 | Train Acc. 0.6746 | Validation Acc. 0.5789 
Epoch 00011 | Loss 1.0330 | Train Acc. 0.6805 | Validation Acc. 0.5789 
Epoch 00012 | Loss 0.7177 | Train Acc. 0.6805 | Validation Acc. 0.6316 
Epoch 00013 | Loss 0.5728 | Train Acc. 0.6864 | Validation Acc. 

In [15]:
model.test()

AttributeError: ignored