In [1]:
import os.path as osp

import torch
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset

import torch_geometric.transforms as T
from torch_geometric.nn import MaskLabel, TransformerConv
from torch_geometric.utils import index_to_mask

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Notebooks have no `__file__`, so resolve the data directory relative to the
# kernel's current working directory (symlinks resolved).
root = osp.join(osp.realpath(osp.curdir), '..', 'data', 'OGB')
# `ToUndirected` is applied as a transform on access and adds reverse edges.
dataset = PygNodePropPredDataset('ogbn-arxiv', root, transform=T.ToUndirected())

Downloading http://snap.stanford.edu/ogb/data/nodeproppred/arxiv.zip


Downloaded 0.08 GB: 100%|██████████| 81/81 [00:10<00:00,  7.63it/s]
Processing...


Extracting /data/notebook/KG_folder/kge_benchmark/../data/OGB/arxiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 1/1 [00:00<00:00, 11459.85it/s]


Converting graphs into PyG objects...


100%|██████████| 1/1 [00:00<00:00, 2226.28it/s]

Saving...



Done!


In [3]:
class UniMP(torch.nn.Module):
    """Graph transformer that takes masked node labels as extra input (UniMP).

    Node features are combined with label embeddings (via ``MaskLabel``) for
    the nodes selected by ``label_mask``. They are then passed through
    ``num_layers`` ``TransformerConv`` layers. Every layer except the last is
    followed by ``LayerNorm`` and ReLU. The last layer averages its heads
    (``concat=False``) and emits ``num_classes`` logits per node.

    Args:
        in_channels: Dimensionality of the input node features.
        num_classes: Number of target classes (also passed to ``MaskLabel``).
        hidden_channels: Width of the hidden representations. When
            ``num_layers > 1`` it must be divisible by ``heads``, because each
            hidden layer concatenates ``heads`` outputs of size
            ``hidden_channels // heads``.
        num_layers: Total number of ``TransformerConv`` layers (>= 1).
        heads: Number of attention heads per layer.
        dropout: Passed as ``dropout`` to every ``TransformerConv``.

    Raises:
        ValueError: If ``num_layers < 1`` or, with hidden layers present,
            ``hidden_channels`` is not divisible by ``heads``.
    """

    def __init__(self, in_channels, num_classes, hidden_channels, num_layers,
                 heads, dropout=0.3):
        super().__init__()

        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        if num_layers > 1 and hidden_channels % heads != 0:
            # Otherwise the concatenated head outputs would have width
            # (hidden_channels // heads) * heads != hidden_channels and
            # LayerNorm(hidden_channels) would fail at forward time.
            raise ValueError(
                f"hidden_channels ({hidden_channels}) must be divisible by "
                f"heads ({heads})")

        self.label_emb = MaskLabel(num_classes, in_channels)

        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()
        for i in range(num_layers):
            is_last = i == num_layers - 1
            if is_last:
                # Output layer: average heads, one channel per class.
                out_channels, concat = num_classes, False
            else:
                out_channels, concat = hidden_channels // heads, True
            self.convs.append(
                TransformerConv(in_channels, out_channels, heads,
                                concat=concat, beta=True, dropout=dropout))
            if not is_last:
                self.norms.append(torch.nn.LayerNorm(hidden_channels))
            in_channels = hidden_channels

    def forward(self, x, y, edge_index, label_mask):
        """Compute class logits for every node.

        Args:
            x: Node feature matrix.
            y: Node labels. Only entries selected by ``label_mask`` are
                embedded and fed to the model.
            edge_index: Graph connectivity in PyG ``edge_index`` format.
            label_mask: Boolean node mask of labels exposed as input.

        Returns:
            Output of the final ``TransformerConv`` (``num_classes`` logits
            per node).
        """
        x = self.label_emb(x, y, label_mask)
        # Hidden layers are paired with their norms; the last conv has none.
        for conv, norm in zip(self.convs[:-1], self.norms):
            x = norm(conv(x, edge_index)).relu()
        return self.convs[-1](x, edge_index)

In [4]:
# Seed all torch RNGs so weight initialisation and the random label split
# drawn in `train()` are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = dataset[0].to(device)
data.y = data.y.view(-1)  # flatten labels to 1-D class indices for cross_entropy
model = UniMP(dataset.num_features, dataset.num_classes, hidden_channels=64,
              num_layers=3, heads=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)

split_idx = dataset.get_idx_split()
# Keep the masks on the same device as `data` to avoid host->device copies
# every time they index a tensor during training/evaluation.
train_mask = index_to_mask(split_idx['train'], size=data.num_nodes).to(device)
val_mask = index_to_mask(split_idx['valid'], size=data.num_nodes).to(device)
test_mask = index_to_mask(split_idx['test'], size=data.num_nodes).to(device)

In [5]:
data.num_nodes  # total number of nodes in the loaded graph

169343

In [7]:
data.x.size()  # node-feature matrix dimensions: (rows, feature width)

torch.Size([169343, 128])

In [8]:
data.edge_index.size()  # connectivity tensor dimensions: (2, number of edges)

torch.Size([2, 2315598])

In [9]:

def train(label_rate=0.65):
    """Run one full-graph optimisation step of the global `model`.

    The training nodes are split at random into a *propagation* subset, whose
    labels are fed to the model as input, and the remaining training nodes,
    on which the cross-entropy loss is computed.

    Args:
        label_rate: How many labels to use for propagation; forwarded as
            ``ratio`` to ``MaskLabel.ratio_mask``.

    Returns:
        The training loss of this step as a Python float.
    """
    model.train()

    input_label_mask = MaskLabel.ratio_mask(train_mask, ratio=label_rate)
    # XOR keeps the training nodes not picked for propagation (ratio_mask
    # selects from within train_mask), so they become the loss targets.
    loss_mask = train_mask ^ input_label_mask

    optimizer.zero_grad()
    logits = model(data.x, data.y, data.edge_index, input_label_mask)
    step_loss = F.cross_entropy(logits[loss_mask], data.y[loss_mask])
    step_loss.backward()
    optimizer.step()
    return step_loss.item()


@torch.no_grad()
def _masked_accuracy(propagation_mask, eval_mask):
    """Accuracy on ``eval_mask`` nodes when labels of ``propagation_mask`` are input.

    Runs a full forward pass of the global ``model`` on ``data`` and compares
    the argmax prediction with ``data.y`` on the evaluated nodes.
    """
    out = model(data.x, data.y, data.edge_index, propagation_mask)
    pred = out[eval_mask].argmax(dim=-1)
    return int((pred == data.y[eval_mask]).sum()) / pred.size(0)


@torch.no_grad()
def test():
    """Evaluate the global ``model`` on the validation and test splits.

    Validation uses only training labels as input. Test additionally exposes
    the validation labels, so each split is scored with all labels outside it.

    Returns:
        ``(val_acc, test_acc)`` as Python floats.
    """
    model.eval()
    val_acc = _masked_accuracy(train_mask, val_mask)
    test_acc = _masked_accuracy(train_mask | val_mask, test_mask)
    return val_acc, test_acc


In [10]:
EPOCHS = 100  # number of full-graph training steps

# Report the test accuracy at the epoch with the best validation accuracy
# (standard model selection), in addition to the per-epoch log.
best_val_acc, test_at_best_val = 0.0, 0.0
for epoch in range(1, EPOCHS + 1):
    loss = train()
    val_acc, test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc, test_at_best_val = val_acc, test_acc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, '
          f'Test: {test_acc:.4f}')

print(f'Best Val: {best_val_acc:.4f}, Test @ best Val: {test_at_best_val:.4f}')

Epoch: 001, Loss: 3.7164, Val: 0.0319, Test: 0.0249
Epoch: 002, Loss: 3.5990, Val: 0.0876, Test: 0.0691
Epoch: 003, Loss: 3.4928, Val: 0.2245, Test: 0.2322
Epoch: 004, Loss: 3.3959, Val: 0.2778, Test: 0.2686
Epoch: 005, Loss: 3.3111, Val: 0.2889, Test: 0.2871
Epoch: 006, Loss: 3.2340, Val: 0.3129, Test: 0.3296
Epoch: 007, Loss: 3.1703, Val: 0.3401, Test: 0.3682
Epoch: 008, Loss: 3.1101, Val: 0.3664, Test: 0.4032
Epoch: 009, Loss: 3.0578, Val: 0.3963, Test: 0.4460
Epoch: 010, Loss: 3.0015, Val: 0.4277, Test: 0.4817
Epoch: 011, Loss: 2.9554, Val: 0.4552, Test: 0.5076
Epoch: 012, Loss: 2.8996, Val: 0.4768, Test: 0.5241
Epoch: 013, Loss: 2.8558, Val: 0.4964, Test: 0.5392
Epoch: 014, Loss: 2.8164, Val: 0.5123, Test: 0.5517
Epoch: 015, Loss: 2.7635, Val: 0.5265, Test: 0.5622
Epoch: 016, Loss: 2.7115, Val: 0.5394, Test: 0.5720
Epoch: 017, Loss: 2.6707, Val: 0.5491, Test: 0.5792
Epoch: 018, Loss: 2.6201, Val: 0.5558, Test: 0.5838
Epoch: 019, Loss: 2.5787, Val: 0.5608, Test: 0.5872
Epoch: 020, 