In [1]:
! pip install  dgl -f https://data.dgl.ai/wheels/repo.html
! pip install  dglgo -f https://data.dgl.ai/wheels-test/repo.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels/repo.html
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels-test/repo.html


In [2]:
!pip install ogb 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [3]:
import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn

In [4]:
class GAT(nn.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.lin1 = nn.Linear(in_features, out_features)
    self.a = nn.Linear(out_features *2, 1, bias = False)

  def update_e_ij(self, edges):
    # return edge features after update
    return {"e_ij": nn.LeakyReLU()(self.a(torch.cat([edges.src['z'], edges.dst['z']], dim=1)))}

  def reduce_func(self, nodes):
    # applies to the neighbors, return node features after update
    alpha_ij = nn.Softmax(dim =1)(nodes.mailbox["e_ij"])
    h = torch.sum(alpha_ij * nodes.mailbox["z"], dim = 1)
    return {"h": h}

  def message_func(self, edges):
    # return updated node and edge features
    return {"z": edges.src["z"], "e_ij": edges.data["e_ij"]}


  def forward(self, g, h):
    with g.local_scope(): # to prevent changes from going back to the original graph
      g.ndata["z"] = self.lin1(h)
      g.apply_edges(self.update_e_ij)

      g.update_all(self.message_func, self.reduce_func)

      return g.ndata["h"]


In [5]:
# For multi head attention
class MultiHeadAttentionGAT(nn.Module):
  def __init__(self, in_features, out_features, num_heads, agg="concat"):
    super().__init__()
    self.num_heads = num_heads
    self.agg = agg
    self.heads = nn.ModuleList()
    for i in range(self.num_heads):
      self.heads.append(GAT(in_features, out_features))

  def forward(self, g, h):
    res = []
    for i in range(self.num_heads):
      res.append(self.heads[i](g, h))
    if self.agg == "concat":
      return torch.cat(res, dim=1)
    else: # agg is mean
      return torch.mean(torch.stack(res))

In [6]:
# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes.
# dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
from ogb.graphproppred import DglGraphPropPredDataset, collate_dgl
from torch.utils.data import DataLoader
dataset = DglGraphPropPredDataset(name = 'ogbg-molhiv')
print(dataset)
split_idx = dataset.get_idx_split()
train_dataloader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True, collate_fn=collate_dgl)
valid_dataloader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False, collate_fn=collate_dgl)
test_dataloader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False, collate_fn=collate_dgl)

DglGraphPropPredDataset(41127)


In [7]:
# # Batching
# from torch.utils.data.sampler import SubsetRandomSampler

# from dgl.dataloading import GraphDataLoader

# num_examples = len(dataset)
# num_train = int(num_examples * 0.8)

# train_sampler = SubsetRandomSampler(torch.arange(num_train))
# test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

# train_dataloader = GraphDataLoader(
#     dataset, sampler=train_sampler, batch_size=5, drop_last=False
# )
# test_dataloader = GraphDataLoader(
#     dataset, sampler=test_sampler, batch_size=5, drop_last=False
# )

In [8]:
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        # self.conv1 = GAT(in_feats, h_feats)
        # self.conv2 = GAT(h_feats, num_classes)
        self.conv1 = MultiHeadAttentionGAT(in_feats, h_feats, 3, "concat")
        self.conv2 = MultiHeadAttentionGAT(h_feats*3, num_classes, 1, "concat")
        self.graph_level_repr = nn.Parameter(torch.FloatTensor(1, num_classes)) # graph level repr must be of dim = num_classes
        # self.batch_size = batch_size
        # self.num_classes = num_classes
        # self.graph_level_repr = nn.Parameter(torch.FloatTensor(batch_size, num_classes)) # graph level repr must be of dim = num_classes
        self.lin_att = nn.Linear(num_classes * 2, 1)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        # print(h.shape)
        # get attention between the graph level representation vector and each of the nodes by following the same attention mechanism
        # torch.cat([self.graph_level_repr.repeat_interleave(self.batch_size, dim = 0).view(self.batch_size, h.shape[1], self.num_classes)], dim = -1)
        # h_hidden = self.lin(h)
        cat_tensor = torch.cat([self.graph_level_repr.repeat(h.shape[0], 1), h], dim = -1)
        att_nodes = nn.Softmax(dim=1)(nn.LeakyReLU()(self.lin_att(cat_tensor)))
        # print(att_nodes.shape)
        g.ndata["att_nodes"] = att_nodes
        
        return dgl.mean_nodes(g, "h", "att_nodes")
        # return dgl.sum_nodes(g, "h", "att_nodes")


In [9]:
model = Model(9, 16, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for epoch in range(20):
    epoch_loss = 0
    batch_count = 0
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["feat"].float())
        loss = F.cross_entropy(pred, labels.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_count += 1
        epoch_loss += loss.detach().item()
    print(f"Epoch {epoch}: Loss = {epoch_loss/batch_count}")


  assert input.numel() == input.storage().size(), (


Epoch 0: Loss = 0.16936044654726576
Epoch 1: Loss = 0.15963800226684852
Epoch 2: Loss = 0.15940540499326322
Epoch 3: Loss = 0.15972257107025623
Epoch 4: Loss = 0.15861018233913549
Epoch 5: Loss = 0.15812551263860988
Epoch 6: Loss = 0.15798043200802292
Epoch 7: Loss = 0.15769725181777808
Epoch 8: Loss = 0.15766481406248173
Epoch 9: Loss = 0.15648703641132375
Epoch 10: Loss = 0.15556504327959522
Epoch 11: Loss = 0.1540758159931159
Epoch 12: Loss = 0.15345233139211803
Epoch 13: Loss = 0.1517129637503705
Epoch 14: Loss = 0.15121742124887377
Epoch 15: Loss = 0.15038112053906488
Epoch 16: Loss = 0.1499709379718982
Epoch 17: Loss = 0.14998467752766098
Epoch 18: Loss = 0.14952144180800117
Epoch 19: Loss = 0.14869863731629127
Test accuracy: 0.9683929005592026


In [17]:
num_correct = 0
num_tests = 0
predictions = []
labels_all = []
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["feat"].float())
    num_correct += (pred.argmax(1) == labels.view(-1)).sum().item()
    num_tests += len(labels)
    predictions.extend(pred.detach().numpy())
    labels_all.extend(labels.numpy())
print("Test accuracy:", num_correct / num_tests)

  assert input.numel() == input.storage().size(), (


Test accuracy: 0.9683929005592026


In [16]:
import numpy
numpy.array(predictions)

array([[ 3.7128718 , -1.6354263 ],
       [ 3.0620248 , -1.9714004 ],
       [ 2.6242704 , -1.6478105 ],
       ...,
       [ 2.3873723 , -0.82328194],
       [ 2.4491358 , -1.6382852 ],
       [ 2.6683507 , -1.660137  ]], dtype=float32)

In [19]:
numpy.array(labels_all).reshape(-1)

array([0, 0, 0, ..., 0, 0, 0])

In [20]:
from sklearn.metrics import roc_auc_score
roc_auc_score(numpy.array(labels_all).reshape(-1), numpy.array(predictions)[:, 1])

0.6914965526564824