#### Node Property Prediction on the Open Graph Benchmark (ogbn-mag) dataset. The GNN model used is HGT from the paper ["Heterogeneous Graph Transformer"](https://arxiv.org/abs/2003.01332).

The ogbn-mag dataset from the “Open Graph Benchmark: Datasets for Machine Learning on Graphs” paper. ogbn-mag is a heterogeneous graph composed of a subset of the Microsoft Academic Graph (MAG). It contains four types of entities — papers (736,389 nodes), authors (1,134,649 nodes), institutions (8,740 nodes), and fields of study (59,965 nodes) — as well as four types of directed relations connecting two types of entities. Each paper is associated with a 128-dimensional word2vec feature vector, while all other node types are not associated with any input features. The task is to predict the venue (conference or journal) of each paper. In total, there are 349 different venues.

**Preprocessing**: Structural features are added to featureless nodes using metapath2vec. 


In [3]:
import argparse
import glob
import os
import os.path as osp
import time
from typing import List, NamedTuple, Optional

import numpy as np
import torch
import torch.nn.functional as F
from ogb.lsc import MAG240MDataset, MAG240MEvaluator
from pytorch_lightning import (LightningDataModule, LightningModule, Trainer,
                               seed_everything)
from pytorch_lightning.callbacks import ModelCheckpoint
from torchmetrics import Accuracy
from torch import Tensor
from torch.nn import BatchNorm1d, Dropout, Linear, ModuleList, ReLU, Sequential
from torch.optim.lr_scheduler import StepLR

from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import HGTLoader, NeighborLoader
from torch_geometric.nn import Linear, SAGEConv, Sequential, HGTConv
from torch_geometric.data import HeteroData
from torch_geometric.typing import InputNodes
import torch_geometric.transforms as T


In [4]:

class LightningHeteroData(LightningDataModule):
    def __init__(self,
          data: HeteroData,
          train_input_nodes: InputNodes,
          val_input_nodes: InputNodes,
          test_input_nodes: InputNodes,
          batch_size: int = 1,
          num_workers: int = 0,
          **kwargs
        ):
        super().__init__()
        self.data = data
        self.train_input_nodes = train_input_nodes
        self.val_input_nodes = val_input_nodes
        self.test_input_nodes = test_input_nodes
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
      return HGTLoader(self.data, 
                       num_samples=[self.batch_size] * 4, 
                       shuffle=True,
                       input_nodes=self.train_input_nodes, 
                       batch_size=self.batch_size,
                       num_workers=self.num_workers)

    def val_dataloader(self):
      return HGTLoader(self.data,   
                       num_samples=[self.batch_size] * 4, 
                       shuffle=False,
                       input_nodes=self.val_input_nodes, 
                       batch_size=self.batch_size,
                       num_workers=self.num_workers)

    def test_dataloader(self):
        return HGTLoader(self.data,   
                        num_samples=[self.batch_size] * 4, 
                        shuffle=False,
                        input_nodes=self.train_input_nodes, 
                        batch_size=self.batch_size,
                        num_workers=self.num_workers)

In [5]:
class HGT(torch.nn.Module):
  def __init__(self, 
               data: HeteroData, 
               hidden_channels: int, 
               out_channels: int, 
               num_heads: int, 
               num_layers: int, 
               node_type: str
      ):
      super().__init__()
      self.node_type = node_type

      self.lin_dict = torch.nn.ModuleDict()

      for node_type in data.node_types:
          self.lin_dict[node_type] = Linear(data[node_type].x.size(1), hidden_channels)

      self.convs = torch.nn.ModuleList()
      for _ in range(num_layers):
          conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),
                          num_heads, group='sum')
          self.convs.append(conv)

      self.lin = Linear(hidden_channels, out_channels)

  def forward(self, x_dict, edge_index_dict):
      x_dict = {
          node_type: self.lin_dict[node_type](x).relu_()
          for node_type, x in x_dict.items()
      }

      for conv in self.convs:
          x_dict = conv(x_dict, edge_index_dict)

      return self.lin(x_dict[self.node_type])


In [6]:
class LightningHGT(LightningModule):
    def __init__(self, **model_params):
        super().__init__()

        self.model = HGT(**model_params)

        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        self.loss_fn = F.cross_entropy

    def forward(self, x_dict, edge_index_dict):
        return self.model(x_dict, edge_index_dict)

    def training_step(self, batch, batch_idx: int):
        out = self(batch.x_dict, batch.edge_index_dict)
        loss = self.loss_fn(out, batch["paper"].y)
        self.train_acc(out.softmax(dim=-1), batch["paper"].y)
        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,
          on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx: int):
        logits = self(batch.x_dict, batch.edge_index_dict)

        y_pred = logits.argmax(dim=-1)
        y_true = batch["paper"].y

        self.val_acc(y_pred, y_true)  
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True,
                 prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx: int):
        logits = self(batch.x_dict, batch.edge_index_dict)

        y_pred = logits.argmax(dim=-1)
        y_true = batch["paper"].y

        self.test_acc(y_pred, y_true)
        self.log('test_acc', self.test_acc, on_step=False, on_epoch=True,
                 prog_bar=True, sync_dist=True)
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = StepLR(optimizer, step_size=25, gamma=0.25)
        return [optimizer], [scheduler]

    def initialize_params(self, data):
        with torch.no_grad():  # Initialize lazy modules.
          self(data.x_dict, data.edge_index_dict)

In [7]:
parser = argparse.ArgumentParser()

parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=1024)

parser.add_argument('--model', type=str, default='hgt',choices=['hgt'])
parser.add_argument('--n_hid', type=int, default=512,
                    help='Number of hidden dimension')
parser.add_argument('--n_heads', type=int, default=8,
                    help='Number of attention head')
parser.add_argument('--n_layers', type=int, default=4,
                    help='Number of GNN layers')
parser.add_argument('--dropout', type=float, default=0.2,
                    help='Dropout ratio')

parser.add_argument('--num_workers', type=int, default=2) # 4 CPU cores, 2 GPU cores on kaggle
parser.add_argument('--devices', type=int, default=1)
parser.add_argument('--evaluate', action='store_true')

args = parser.parse_args("")
print(args)


Namespace(batch_size=1024, devices=1, dropout=0.2, epochs=100, evaluate=False, model='hgt', n_heads=8, n_hid=512, n_layers=4, num_workers=2)


In [8]:
seed_everything(42, workers=True)

data_dir = osp.join(osp.dirname(osp.realpath("__file__")), '../../data/OGB')
transform = T.ToUndirected(merge=True)
dataset = OGB_MAG(data_dir, preprocess='metapath2vec', transform=transform)  # TODO: train without metapath2vec embeddings
data = dataset[0]
datamodule = LightningHeteroData(
    data, 
    train_input_nodes=('paper', data['paper'].train_mask),
    val_input_nodes=('paper', data['paper'].val_mask),
    test_input_nodes=('paper', data['paper'].test_mask),
    batch_size=args.batch_size, 
    num_workers=args.num_workers
)

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip
Extracting /data/OGB/mag/raw/mag.zip
Downloading https://data.pyg.org/datasets/mag_metapath2vec_emb.zip
Extracting /data/OGB/mag/raw/mag_metapath2vec_emb.zip
Processing...
Done!


In [9]:
model = LightningHGT(
    data=data, 
    hidden_channels=args.n_hid, 
    out_channels=dataset.num_classes, 
    num_heads=args.n_heads, 
    num_layers=args.n_layers,
    node_type="paper"
)
model.initialize_params(next(iter(datamodule.train_dataloader())))
print(f'#Params {sum([p.numel() for p in model.parameters()])}')

checkpoint_callback = ModelCheckpoint(
    monitor='val_acc', 
    mode='max',
    save_top_k=1
)
trainer = Trainer(
    accelerator="auto", 
    devices=args.devices, 
    max_epochs=args.epochs,      
    callbacks=[checkpoint_callback],
    default_root_dir=f'logs/{args.model}'
)
trainer.fit(model, datamodule=datamodule)


#Params 19088461


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [10]:
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator

dirs = glob.glob(f'logs/{args.model}/lightning_logs/*')
version = max([int(x.split(os.sep)[-1].split('_')[-1]) for x in dirs])
logdir = f'logs/{args.model}/lightning_logs/version_{version}'
print(f'Evaluating saved model in {logdir}...')
ckpt = glob.glob(f'{logdir}/checkpoints/*')[0]

trainer = Trainer(accelerator="auto", devices=1, resume_from_checkpoint=ckpt)
model = HGT.load_from_checkpoint(checkpoint_path=ckpt, hparams_file=f'{logdir}/hparams.yaml')

datamodule.batch_size = 16
datamodule.sizes = [160] * len(args.sizes)  # (Almost) no sampling...

#     loader = datamodule.test_dataloader()
#     predictions = trainer.predict(model, dataloaders=loader)


trainer.test(model=model, datamodule=datamodule)

evaluator = Evaluator(name='ogbn-mag')
loader = datamodule.test_dataloader()

model.eval()
device = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
model.to(device)
y_preds = []

for batch in tqdm(loader):
    batch = batch.to(device)
    with torch.no_grad():
        out = model(batch.x_dict, batch.edge_index_dict).argmax(dim=-1).cpu()
        y_preds.append(out)

res = {'y_pred': torch.cat(y_preds, dim=0)}
evaluator.save_test_submission(res, f'results/{args.model}',
                               mode='test-dev')

Evaluating saved model in logs/hgt/lightning_logs/version_0...


  "Setting `Trainer(resume_from_checkpoint=)` is deprecated in v1.5 and"


AttributeError: type object 'HGT' has no attribute 'load_from_checkpoint'

In [None]:
evaluator = Evaluator(name=dataset)
preds, labels = evaluator.eval({
        "y_true": labels.view(-1, 1),
        "y_pred": preds.view(-1, 1),
    })["acc"]