In [1]:
!pip install torch_geometric --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/661.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.1/661.6 kB[0m [31m5.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m655.4/661.6 kB[0m [31m10.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m9.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone


In [2]:
import torch_geometric
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, global_add_pool

In [3]:
dataset = TUDataset(root='.', name='PROTEINS').shuffle()
print(len(dataset))
print(dataset[0].x.shape)
print(dataset.num_features)
print(dataset.num_classes)

Downloading https://www.chrsmrrs.com/graphkerneldatasets/PROTEINS.zip
Extracting ./PROTEINS/PROTEINS.zip
Processing...


1113
torch.Size([29, 3])
3
2


Done!


In [4]:
train_dataset = dataset[:int(len(dataset)*0.8)]
val_dataset = dataset[int(len(dataset)*0.8): int(len(dataset)*0.9)]
test_dataset = dataset[int(len(dataset)*0.9):]

print("Training graphs:", len(train_dataset))
print("Validation graphs:", len(val_dataset))
print("Testing graphs:", len(test_dataset))

Training graphs: 890
Validation graphs: 111
Testing graphs: 112


In [83]:
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

In [84]:
torch.manual_seed(123)

<torch._C.Generator at 0x7a087c220bf0>

In [85]:
class GIN(nn.Module):
  def __init__(self, in_dim, hidden_dim, out_dim, n_layers):
    super().__init__()
    self.convs = nn.ModuleList()
    for i in range(n_layers):
      in_dim = in_dim if i==0 else hidden_dim
      self.convs.append(GINConv(
          nn.Sequential(
              nn.Linear(in_dim, hidden_dim),
              nn.BatchNorm1d(hidden_dim),
              nn.ReLU(),
              nn.Linear(hidden_dim, hidden_dim),
              nn.ReLU()

          )
      ))

    self.lin1 = nn.Linear(hidden_dim*n_layers, hidden_dim*n_layers)
    self.lin2 = nn.Linear(hidden_dim*n_layers, out_dim)

  def forward(self, data, edge_index, batch):
    hiddens = []
    for conv in self.convs:
      data = conv(data, edge_index)
      hiddens.append(global_add_pool(data, batch))

    hidden = torch.cat(hiddens, dim=1)
    out1 = F.relu(self.lin1(hidden))
    out1 = F.dropout(out1, p=0.2, training=self.training)
    out2 = self.lin2(out1)

    return out2

In [86]:
def train():
  model.train()
  for epoch in range(epochs):
    for data in train_loader:
      data = data.to(device)
      opt.zero_grad()
      out = model(data.x, data.edge_index, data.batch)
      loss = loss_fn(out, data.y)
      loss.backward()
      opt.step()
    val_loss, val_acc = test(val_loader)
    if epoch%50 == 0:
      print(f"Epoch: {epoch} | Val Loss: {val_loss.item():3f} | Val accuracy: {val_acc.item():3f}")

  test_loss, test_acc = test(test_loader)
  print(f"Test Loss: {val_loss.item():3f} | Test accuracy: {val_acc.item():3f}")

def test(loader):
  model.eval()
  with torch.no_grad():
    for data in loader:
      data = data.to(device)
      y_pred = model(data.x, data.edge_index, data.batch)
      loss = loss_fn(y_pred, data.y)
      y_pred = F.softmax(y_pred, dim=-1)
      y_pred = y_pred.argmax(dim=-1)
      test_acc = (y_pred == data.y).sum()/len(data.y)
      return loss, test_acc

In [87]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GIN(dataset.num_node_features, 256, dataset.num_classes, 5).to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-3, weight_decay=0.005)
loss_fn = nn.CrossEntropyLoss()
epochs = 500

In [88]:
train()

Epoch: 0 | Val Loss: 6.123676 | Val accuracy: 0.666667
Epoch: 50 | Val Loss: 0.480283 | Val accuracy: 0.783784
Epoch: 100 | Val Loss: 0.471610 | Val accuracy: 0.783784
Epoch: 150 | Val Loss: 0.458770 | Val accuracy: 0.801802
Epoch: 200 | Val Loss: 0.474852 | Val accuracy: 0.801802
Epoch: 250 | Val Loss: 0.469593 | Val accuracy: 0.792793
Epoch: 300 | Val Loss: 0.461840 | Val accuracy: 0.792793
Epoch: 350 | Val Loss: 0.461973 | Val accuracy: 0.819820
Epoch: 400 | Val Loss: 0.463730 | Val accuracy: 0.801802
Epoch: 450 | Val Loss: 0.461286 | Val accuracy: 0.792793
Test Loss: 0.455255 | Test accuracy: 0.801802
