In [1]:
import torch
import os
print("PyTorch has version {}".format(torch.__version__))

PyTorch has version 1.12.0+cu113


In [2]:
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
!pip install ogb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.12.0+cu113.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 33.5 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_sparse-0.6.14-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 47.8 MB/s 
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_cluster-1.6.0-cp37-cp37m-linux_x86_64.whl (2.4 MB)
[K     |████████████████████████████████| 2.4 MB 6.5 MB/s 
[?25hCollecting torch-spline-conv
  Downloading https://data.pyg.org/whl/torch-1.12.0%2Bcu113/torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl (709 kB)
[K     |████████████████████████████████| 709 kB 17.7 MB/s

In [3]:
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.data import DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder
from torch_geometric.nn import global_mean_pool

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = PygGraphPropPredDataset(name='ogbg-molhiv', root='dataset/')
print(dataset.num_node_features)
print(dataset.num_classes)

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip


Downloaded 0.00 GB: 100%|██████████| 3/3 [00:02<00:00,  1.16it/s]
Processing...


Extracting dataset/hiv.zip
Loading necessary files...
This might take a while.
Processing graphs...


100%|██████████| 41127/41127 [00:00<00:00, 96335.19it/s]


Converting graphs into PyG objects...


100%|██████████| 41127/41127 [00:00<00:00, 42242.72it/s]


Saving...
9
2


Done!


In [5]:
split_idx = dataset.get_idx_split()

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False)



In [6]:
class GCN(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.node_encoder = AtomEncoder(256)

    self.conv1 = GCNConv(256, 256)
    self.conv2 = GCNConv(256, dataset.num_classes)
    self.bn = torch.nn.BatchNorm1d(256)
    self.softmax = torch.nn.LogSoftmax()
    self.pool = global_mean_pool
      
  def forward(self, data):
    x, edge_index, batch = data.x, data.edge_index, data.batch
    embed = self.node_encoder(x)
    x = self.conv1(embed, edge_index)
    x = self.bn(x)
    x = F.relu(x)
    x = F.dropout(x, self.training)
    x = self.conv2(x, edge_index)
    x = self.pool(x, batch)
    x = self.softmax(x)
    return x

In [7]:
model = GCN().to(device)
evaluator = Evaluator(name='ogbg-molhiv')
loss_fn = F.nll_loss
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [8]:
for epoch in range(10):
  model.train()
  for step, data in enumerate((train_loader)):
    data.to(device)
    labels = data.y.squeeze(1)

    out = model(data)
    loss = loss_fn(out, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  print(loss.item())



0.19252406060695648
0.08766031265258789
0.054452937096357346
0.6679906845092773
0.03956208378076553
0.037417616695165634
0.037389710545539856
0.03943092003464699
0.038799915462732315
1.3364355564117432


In [9]:
model.eval()
correct = 0

for data in test_loader:
  data.to(device)
  test_y = data.y.reshape(-1)

  output = model(data)
  predict = output.max(dim=1)[1]

  correct += (predict == test_y).sum().item()

print('acc: {:07f}'.format(correct / len(test_loader.dataset)))



acc: 0.968393
