# Graph Neural Networks with OGB and PyTorch Geometric

In this notebook we are going to implement a Graph Neural Network (GNN) using PyTorch Geometric and the Open Graph Benchmark (OGB). The goal of this notebook is to give general directions to anyone who wants to get started with this kind of development.

In [1]:
import os
import torch
import torch.nn.functional as F

from tqdm import tqdm
from torch_geometric.loader import NeighborLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import MessagePassing, SAGEConv
from ogb.nodeproppred import Evaluator, PygNodePropPredDataset



## Open Graph Benchmark (OGB)

The OGB is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs. It provides curated datasets and standardizes how they are split and evaluated for each prediction task.

One can think of it as the graph-learning counterpart of ImageNet in computer vision: its goal is to establish a standard way of evaluating advances in graph learning.

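We will use the 'ogbn-arxiv' dataset, a node property prediction dataset: it is a citation network of arXiv Computer Science papers, and the task is to predict the subject area of each paper among 40 classes.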
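The official ogbn-arxiv split is time-based: papers published until 2017 form the training set, papers from 2018 the validation set, and papers from 2019 onwards the test set. `dataset.get_idx_split()` returns these indices ready-made, so nothing like the code below is needed in practice. It is only a minimal, dependency-free sketch of the rule on a plain list of years, using a hypothetical `time_split` helper.

```python
from typing import Dict, List


def time_split(node_year: List[int]) -> Dict[str, List[int]]:
    """Hypothetical helper: split node indices by publication year.

    Illustrates the ogbn-arxiv rule (train <= 2017, valid == 2018,
    test >= 2019); this is NOT the code used by the OGB loader.
    """
    split: Dict[str, List[int]] = {"train": [], "valid": [], "test": []}
    for idx, year in enumerate(node_year):
        if year <= 2017:
            split["train"].append(idx)
        elif year == 2018:
            split["valid"].append(idx)
        else:
            split["test"].append(idx)
    return split


years = [2015, 2018, 2019, 2017, 2020, 2018]
print(time_split(years))  # {'train': [0, 3], 'valid': [1, 5], 'test': [2, 4]}
```

Splitting by time rather than at random means the model is evaluated on papers that are newer than everything it was trained on.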

In [2]:
target_dataset = 'ogbn-arxiv'

# Downloads ogbn-arxiv into the 'networks' folder (skipped if it is already there)
dataset = PygNodePropPredDataset(name=target_dataset, root='networks')
dataset

PygNodePropPredDataset()

In [3]:
# The data we are going to use can be extracted from the dataset as follows:
data = dataset[0]

For graph property prediction tasks, each element of the dataset would be a different graph. Here we are dealing with a single graph, which we store in the `data` variable.

In [4]:
data

Data(num_nodes=169343, edge_index=[2, 1166243], x=[169343, 128], node_year=[169343, 1], y=[169343, 1])

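This is a `Data` object from PyTorch Geometric. It tells us the number of nodes in the graph, the graph connectivity (`edge_index`, a list of edges in COO format with shape `[2, num_edges]`), the node feature matrix (`x`, one 128-dimensional feature vector per node), the publication year of each paper (`node_year`) and the prediction target, i.e. the class label of each node (`y`).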

This is what we are going to use to train our model.
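Note that `edge_index` is not an adjacency matrix but a coordinate (COO) list of edges: the first row holds source nodes and the second row target nodes, so in ogbn-arxiv each of the 1,166,243 columns is one citation between two papers. The dependency-free sketch below uses plain Python lists in place of the tensor and a hypothetical `coo_to_adjacency` helper to turn such an edge list into an adjacency dictionary.

```python
from collections import defaultdict
from typing import Dict, List


def coo_to_adjacency(edge_index: List[List[int]]) -> Dict[int, List[int]]:
    """Hypothetical helper: convert a [2, num_edges] COO edge list into a dict.

    Column k describes the directed edge edge_index[0][k] -> edge_index[1][k].
    """
    sources, targets = edge_index
    adjacency: Dict[int, List[int]] = defaultdict(list)
    for src, dst in zip(sources, targets):
        adjacency[src].append(dst)
    return dict(adjacency)


# Three directed edges: 0 -> 1, 0 -> 2 and 2 -> 1
edge_index = [[0, 0, 2],
              [1, 2, 1]]
print(coo_to_adjacency(edge_index))  # {0: [1, 2], 2: [1]}
```

The COO layout scales with the number of edges rather than with the square of the number of nodes, which matters for a graph with 169,343 nodes.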

In [5]:
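split_idx = dataset.get_idx_split()

train_idx = split_idx['train']
valid_idx = split_idx['valid']
test_idx = split_idx['test']

# Keep a couple of cores free; num_workers must never be negative
num_workers = max(0, (os.cpu_count() or 1) - 2)

# Training: sample up to 30 neighbors per node for each of the 2 hops (one hop per SAGE layer)
train_loader = NeighborLoader(data, input_nodes=train_idx,
                              shuffle=True, num_workers=num_workers,
                              batch_size=1024, num_neighbors=[30] * 2)

# Inference over every node: take all neighbors (-1) for both hops, so the
# 2-layer model sees a 2-hop neighborhood, as during training.
# shuffle=False keeps the outputs in node order.
total_loader = NeighborLoader(data, input_nodes=None, num_neighbors=[-1] * 2,
                              batch_size=4096, shuffle=False,
                              num_workers=num_workers)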
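`NeighborLoader` does not feed the whole graph to the model. For each mini-batch of seed nodes it samples, hop by hop, up to `num_neighbors[h]` neighbors per node (`-1` meaning all of them) and returns the sampled subgraph, with the seed nodes placed first; this is why the training loop slices `out[:batch_size]`. The pure-Python sketch below only models the node-selection step, under simplifying assumptions (a dict mapping each node to the neighbors it aggregates from, and a hypothetical `sample_khop` helper); PyG's real sampler is implemented differently. It also shows why the number of hops should match the number of GNN layers.

```python
import random
from typing import Dict, List, Set


def sample_khop(adjacency: Dict[int, List[int]], seeds: List[int],
                num_neighbors: List[int], rng: random.Random) -> Set[int]:
    """Hypothetical helper: nodes reached by sampling up to k neighbors per hop.

    A value of -1 takes every neighbor. Simplified: only nodes are tracked,
    not the sampled edges.
    """
    visited = set(seeds)
    frontier = list(seeds)
    for k in num_neighbors:
        next_frontier = []
        for node in frontier:
            neighbors = adjacency.get(node, [])
            if k != -1 and len(neighbors) > k:
                neighbors = rng.sample(neighbors, k)  # uniform, without replacement
            for nb in neighbors:
                if nb not in visited:
                    visited.add(nb)
                    next_frontier.append(nb)
        frontier = next_frontier
    return visited


adjacency = {0: [1, 2, 3], 1: [4], 2: [5, 6], 3: [], 4: [], 5: [], 6: []}
rng = random.Random(0)
print(sorted(sample_khop(adjacency, [0], [-1], rng)))      # [0, 1, 2, 3]
print(sorted(sample_khop(adjacency, [0], [-1, -1], rng)))  # [0, 1, 2, 3, 4, 5, 6]
```

With a single hop, nodes 4, 5 and 6 are never collected, so a second message-passing layer would have nothing from them to aggregate.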

## PyTorch Geometric

### Creating the GNN

We are going to use GraphSAGE (PyTorch Geometric's `SAGEConv` layer) for this notebook. The number of layers is a parameter of the model, but we will use only two here.
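With mean aggregation, which is `SAGEConv`'s default, PyTorch Geometric documents a GraphSAGE layer as computing $\mathbf{x}'_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot \operatorname{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j$: each node combines its own features with the average of its neighbors' features (`normalize=False`, used below, skips the optional L2 normalization of the output). The sketch below spells this out on plain lists, with hypothetical `matvec` and `sage_mean_layer` helpers; it omits the bias term PyG adds.

```python
from typing import Dict, List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, v: Vector) -> Vector:
    """Matrix-vector product on nested lists."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def sage_mean_layer(x: Dict[int, Vector], neighbors: Dict[int, List[int]],
                    w_self: Matrix, w_neigh: Matrix) -> Dict[int, Vector]:
    """One GraphSAGE step: w_self @ x_i + w_neigh @ mean(x_j for j in N(i)).

    Bias is omitted; nodes without neighbors get a zero neighbor mean.
    """
    dim = len(next(iter(x.values())))
    out = {}
    for i, x_i in x.items():
        nbrs = neighbors.get(i, [])
        if nbrs:
            mean = [sum(x[j][d] for j in nbrs) / len(nbrs) for d in range(dim)]
        else:
            mean = [0.0] * dim
        out[i] = [a + b for a, b in zip(matvec(w_self, x_i), matvec(w_neigh, mean))]
    return out


x = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [4.0, 2.0]}
neighbors = {0: [1, 2]}  # node 0 aggregates from nodes 1 and 2
identity = [[1.0, 0.0], [0.0, 1.0]]
half = [[0.5, 0.0], [0.0, 0.5]]
print(sage_mean_layer(x, neighbors, identity, half))
# {0: [2.0, 1.0], 1: [0.0, 2.0], 2: [4.0, 2.0]}
```

Stacking two such layers lets each node's output depend on its 2-hop neighborhood.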

In [6]:
class SAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 n_layers=2):
        super().__init__()
        self.n_layers = n_layers

        self.layers = torch.nn.ModuleList()
        self.layers_bn = torch.nn.ModuleList()

        # Channel sizes: in -> hidden -> ... -> hidden -> out
        dims = [in_channels] + [hidden_channels] * (n_layers - 1) + [out_channels]
        for i in range(n_layers):
            self.layers.append(SAGEConv(dims[i], dims[i + 1], normalize=False))
        # One BatchNorm per hidden layer (none after the output layer)
        for _ in range(n_layers - 1):
            self.layers_bn.append(torch.nn.BatchNorm1d(hidden_channels))

        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, x, edge_index):
        # Hidden layers: SAGEConv -> BatchNorm -> ReLU -> Dropout
        for layer, bn in zip(self.layers[:-1], self.layers_bn):
            x = layer(x, edge_index)
            x = bn(x)
            x = F.relu(x)
            x = F.dropout(x, p=0.5, training=self.training)

        # Output layer: no activation, so the logits reach log_softmax unchanged
        x = self.layers[-1](x, edge_index)

        # Also return the variance of the logits, tracked as a diagnostic
        return F.log_softmax(x, dim=-1), torch.var(x)

    @torch.no_grad()
    def inference(self, total_loader, device):
        xs = []
        var_ = []
        for batch in total_loader:
            out, var = self.forward(batch.x.to(device), batch.edge_index.to(device))
            # The first batch_size nodes of each subgraph are the seed nodes
            out = out[:batch.batch_size]
            xs.append(out.cpu())
            var_.append(var.item())

        # With shuffle=False, concatenating the batches yields outputs in node order
        out_all = torch.cat(xs, dim=0)

        return out_all, var_

### Training the Model

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = SAGE(data.x.shape[1], 256, dataset.num_classes, n_layers=2)
model.to(device)
epochs = 100
optimizer = torch.optim.Adam(model.parameters(), lr=0.03)
scheduler = ReduceLROnPlateau(optimizer, 'max', patience=7)
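`ReduceLROnPlateau` only acts when `scheduler.step(metric)` is called, typically once per epoch with a validation metric. In `'max'` mode it multiplies the learning rate by `factor` once the metric has failed to improve for more than `patience` consecutive calls. The class below is a simplified, hypothetical model of that rule, assuming PyTorch's defaults `factor=0.1` and a relative `threshold=1e-4`, and ignoring `cooldown` and `min_lr`; use the real scheduler in practice.

```python
class PlateauSketch:
    """Hypothetical, simplified model of ReduceLROnPlateau in 'max' mode."""

    def __init__(self, lr: float, patience: int = 7,
                 factor: float = 0.1, threshold: float = 1e-4):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.threshold = threshold
        self.best = float("-inf")
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        # Relative improvement test used in 'max' mode
        if metric > self.best * (1 + self.threshold):
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce once the metric has stalled for more than `patience` steps
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


scheduler_sketch = PlateauSketch(lr=0.03, patience=2)
for epoch, val_acc in enumerate([0.50, 0.55, 0.55, 0.54, 0.53], start=1):
    # The learning rate stays at 0.03 until epoch 5, then drops by a factor of 10
    print(epoch, scheduler_sketch.step(val_acc))
```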

In [10]:
@torch.no_grad()
def test(model, device):
    evaluator = Evaluator(name=target_dataset)
    model.eval()
    out, var = model.inference(total_loader, device)

    y_true = data.y.cpu()
    y_pred = out.argmax(dim=-1, keepdim=True)

    # Accuracy on each of the official splits
    accs = []
    for split in ['train', 'valid', 'test']:
        idx = split_idx[split]
        accs.append(evaluator.eval({
            'y_true': y_true[idx],
            'y_pred': y_pred[idx],
        })['acc'])
    train_acc, val_acc, test_acc = accs

    # Mean variance of the logits across inference batches
    return train_acc, val_acc, test_acc, torch.tensor(var).mean()
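For ogbn-arxiv, the `Evaluator` reports plain accuracy (`'acc'`): the fraction of nodes in a split whose predicted class, the arg-max of the model output, equals the true label. The sketch below reproduces that computation on plain lists with hypothetical `argmax` and `split_accuracy` helpers; it is meant to show what the number means, not to replace the official evaluator.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(row)), key=lambda k: row[k])


def split_accuracy(log_probs: List[List[float]], y_true: List[int],
                   idx: List[int]) -> float:
    """Fraction of nodes in `idx` whose arg-max class equals the label."""
    correct = sum(argmax(log_probs[i]) == y_true[i] for i in idx)
    return correct / len(idx)


log_probs = [[-0.1, -2.3], [-1.6, -0.2], [-0.3, -1.4]]
y_true = [0, 0, 0]
print(split_accuracy(log_probs, y_true, [0, 1]))  # 0.5
print(split_accuracy(log_probs, y_true, [2]))     # 1.0
```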

In [None]:
for epoch in range(1, epochs + 1):
    model.train()

    pbar = tqdm(total=train_idx.size(0))
    pbar.set_description(f'Epoch {epoch:02d}')

    total_loss = total_correct = 0

    for batch in train_loader:
        batch_size = batch.batch_size
        optimizer.zero_grad()

        out, _ = model(batch.x.to(device), batch.edge_index.to(device))
        # Only the first batch_size nodes are seed nodes; the rest are sampled neighbors
        out = out[:batch_size]

        batch_y = batch.y[:batch_size].to(device)
        batch_y = torch.reshape(batch_y, (-1,))

        loss = F.nll_loss(out, batch_y)
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        total_correct += int(out.argmax(dim=-1).eq(batch_y).sum())
        pbar.update(batch_size)

    pbar.close()

    # Mean batch loss and accuracy on the sampled training batches
    loss = total_loss / len(train_loader)
    approx_acc = total_correct / train_idx.size(0)

    train_acc, val_acc, test_acc, var = test(model, device)
    # Lower the learning rate when validation accuracy stops improving
    scheduler.step(val_acc)

    print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}, Var: {var:.4f}')

Epoch 01: 100%|██████████| 90941/90941 [01:58<00:00, 765.76it/s] 


Train: 0.5777, Val: 0.5667, Test: 0.5087, Var: 8.2168


Epoch 02: 100%|██████████| 90941/90941 [02:57<00:00, 512.57it/s] 


Train: 0.5881, Val: 0.5618, Test: 0.5072, Var: 8.4361


Epoch 03: 100%|██████████| 90941/90941 [01:28<00:00, 1024.40it/s]


Train: 0.6042, Val: 0.5803, Test: 0.5195, Var: 8.1180


Epoch 04:  44%|████▍     | 39936/90941 [00:56<00:29, 1710.26it/s]
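The model returns `F.log_softmax(...)` and the training loop optimizes it with `F.nll_loss`; together these two steps compute the usual cross-entropy loss (equivalently, one could return raw logits and use `F.cross_entropy`). The pure-Python sketch below spells out both steps with hypothetical `log_softmax` and `nll_loss` helpers, assuming mean reduction over the batch as in PyTorch's default.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: z_k - logsumexp(z)."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll_loss(log_probs: List[List[float]], targets: List[int]) -> float:
    """Mean negative log-probability of each target class."""
    return -sum(row[t] for row, t in zip(log_probs, targets)) / len(targets)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
loss = nll_loss([log_softmax(row) for row in logits], targets)
print(f"{loss:.4f}")
```

Subtracting the maximum logit before exponentiating avoids overflow without changing the result.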