This notebook loads various graphs and performs graph classification, by averaging out the node features obtained from GAT.

Reference: https://docs.dgl.ai/en/1.0.x/tutorials/blitz/5_graph_classification.html

In [1]:
! pip install  dgl -f https://data.dgl.ai/wheels/repo.html
! pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels/repo.html
Collecting dgl
  Downloading https://data.dgl.ai/wheels/dgl-1.0.2-cp39-cp39-manylinux1_x86_64.whl (5.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m36.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgl
Successfully installed dgl-1.0.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels-test/repo.html
Collecting dglgo
  Downloading dglgo-0.0.2-py3-none-any.whl (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.5/63.5 KB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ruamel.yaml>=0.17.20
  Downloading ruamel.yaml-0.17.21-py3-none-any.whl (109 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.5/109.5 KB[0m [31m14.6 MB/s[0m eta [3

In [2]:
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn

In [3]:
class GAT(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.lin1 = nn.Linear(in_features, out_features)
    self.a = nn.Linear(out_features *2, 1, bias = False)

  def update_e_ij(self, edges):
    # return edge features after update
    return {"e_ij": nn.LeakyReLU()(self.a(torch.cat([edges.src['z'], edges.dst['z']], dim=1)))}

  def reduce_func(self, nodes):
    # applies to the neighbors, return node features after update
    alpha_ij = nn.Softmax(dim =1)(nodes.mailbox["e_ij"])
    h = torch.sum(alpha_ij * nodes.mailbox["z"], dim = 1)
    return {"h": h}

  def message_func(self, edges):
    # return updated node and edge features
    return {"z": edges.src["z"], "e_ij": edges.data["e_ij"]}


  def forward(self, g, h):
    with g.local_scope(): # to prevent changes from going back to the original graph
      g.ndata["z"] = self.lin1(h)
      g.apply_edges(self.update_e_ij)

      g.update_all(self.message_func, self.reduce_func)

      return g.ndata["h"]


In [4]:
# For multi head attention
class MultiHeadAttentionGAT(nn.Module):
  def __init__(self, in_features, out_features, num_heads, agg="concat"):
    super().__init__()
    self.num_heads = num_heads
    self.agg = agg
    self.heads = nn.ModuleList()
    for i in range(self.num_heads):
      self.heads.append(GAT(in_features, out_features))

  def forward(self, g, h):
    res = []
    for i in range(self.num_heads):
      res.append(self.heads[i](g, h))
    if self.agg == "concat":
      return torch.cat(res, dim=1)
    else: # agg is mean
      return torch.mean(torch.stack(res))

In [5]:
!pip install ogb 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.
# dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
from ogb.graphproppred import DglGraphPropPredDataset, collate_dgl
from torch.utils.data import DataLoader
dataset = DglGraphPropPredDataset(name = 'ogbg-molhiv')
print(dataset)
split_idx = dataset.get_idx_split()
train_dataloader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=collate_dgl)
valid_dataloader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=collate_dgl)
test_dataloader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=collate_dgl)

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip


Downloaded 0.00 GB: 100%|██████████| 3/3 [00:01<00:00,  1.92it/s]


Extracting dataset/hiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 41127/41127 [00:01<00:00, 33398.91it/s]


Converting graphs into DGL objects...


100%|██████████| 41127/41127 [00:16<00:00, 2457.92it/s]


Saving...
DglGraphPropPredDataset(41127)


In [None]:
# # Batching
# from torch.utils.data.sampler import SubsetRandomSampler

# from dgl.dataloading import GraphDataLoader

# num_examples = len(dataset)
# num_train = int(num_examples * 0.8)

# train_sampler = SubsetRandomSampler(torch.arange(num_train))
# test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

# train_dataloader = GraphDataLoader(
#     dataset, sampler=train_sampler, batch_size=5, drop_last=False
# )
# test_dataloader = GraphDataLoader(
#     dataset, sampler=test_sampler, batch_size=5, drop_last=False
# )

In [7]:
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        # self.conv1 = GAT(in_feats, h_feats)
        # self.conv2 = GAT(h_feats, num_classes)
        self.conv1 = MultiHeadAttentionGAT(in_feats, h_feats, 3, "concat")
        self.conv2 = MultiHeadAttentionGAT(h_feats*3, num_classes, 1, "concat")

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")


In [20]:
model = Model(9, 16, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["feat"].float())
        loss = F.cross_entropy(pred, labels.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["feat"].float())
    num_correct += (pred.argmax(1) == labels.view(-1)).sum().item()
    num_tests += len(labels)

print("Test accuracy:", num_correct / num_tests)

  assert input.numel() == input.storage().size(), (


Test accuracy: 0.9683929005592026


In [21]:
num_correct = 0
num_tests = 0
predictions = []
labels_all = []
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["feat"].float())
    num_correct += (pred.argmax(1) == labels.view(-1)).sum().item()
    num_tests += len(labels)
    predictions.extend(pred.detach().numpy())
    labels_all.extend(labels.numpy())
print("Test accuracy:", num_correct / num_tests)

Test accuracy: 0.9683929005592026


In [22]:
import numpy
from sklearn.metrics import roc_auc_score
roc_auc_score(numpy.array(labels_all).reshape(-1), numpy.array(predictions)[:, 1])

0.6424283976129319