# Graph Classification: Protein Dataset
- Goal: Classify each protein as an enzyme, a binary classification task
- This dataset does not contain any edge features

## Dataset

In [1]:
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='./data', name='PROTEINS').shuffle()

print(f'Dataset: {dataset}')
print('-----------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {dataset[0].x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

Dataset: PROTEINS(1113)
-----------------------
Number of graphs: 1113
Number of nodes: 7
Number of features: 3
Number of classes: 2


In [2]:
from torch_geometric.loader import DataLoader

train_dataset = dataset[: int(len(dataset) * 0.8)]
val_dataset = dataset[int(len(dataset) * 0.8): int(len(dataset) * 0.9)]
test_dataset = dataset[int(len(dataset) * 0.9):]

print(f'Training set = {len(train_dataset)} graphs')
print(f'Validation set = {len(val_dataset)} graphs')
print(f'Test set = {len(test_dataset)} graphs')

Training set = 890 graphs
Validation set = 111 graphs
Test set = 112 graphs


## Mini-batch: 64 graphs
- This means each batch will contain up to 64 graphs

In [3]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

print('\nTrain loader:')
for i, batch in enumerate(train_loader):
    print(f'- Batch {i}: {batch}')

print('\nValidation loader:')
for i, batch in enumerate(val_loader):
    print(f'- Batch {i}: {batch}')

print('\nTest loader:')
for i, batch in enumerate(test_loader):
    print(f' - Batch {i}: {batch}')


Train loader:
- Batch 0: DataBatch(edge_index=[2, 8696], x=[2296, 3], y=[64], batch=[2296], ptr=[65])
- Batch 1: DataBatch(edge_index=[2, 10620], x=[2865, 3], y=[64], batch=[2865], ptr=[65])
- Batch 2: DataBatch(edge_index=[2, 8684], x=[2367, 3], y=[64], batch=[2367], ptr=[65])
- Batch 3: DataBatch(edge_index=[2, 9616], x=[2683, 3], y=[64], batch=[2683], ptr=[65])
- Batch 4: DataBatch(edge_index=[2, 8904], x=[2414, 3], y=[64], batch=[2414], ptr=[65])
- Batch 5: DataBatch(edge_index=[2, 13002], x=[3430, 3], y=[64], batch=[3430], ptr=[65])
- Batch 6: DataBatch(edge_index=[2, 9802], x=[2638, 3], y=[64], batch=[2638], ptr=[65])
- Batch 7: DataBatch(edge_index=[2, 8616], x=[2282, 3], y=[64], batch=[2282], ptr=[65])
- Batch 8: DataBatch(edge_index=[2, 9552], x=[2551, 3], y=[64], batch=[2551], ptr=[65])
- Batch 9: DataBatch(edge_index=[2, 9106], x=[2412, 3], y=[64], batch=[2412], ptr=[65])
- Batch 10: DataBatch(edge_index=[2, 9168], x=[2431, 3], y=[64], batch=[2431], ptr=[65])
- Batch 11: Da

## Layer Composition
- Linear -> BatchNorm -> ReLU -> Linear -> ReLU

In [4]:
import torch
torch.manual_seed(42)

import torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU, Dropout
from torch_geometric.nn import GINConv
from torch_geometric.nn import global_add_pool

class GIN(torch.nn.Module):
    def __init__(self, dim_h):
        super(GIN, self).__init__()

        # GIN layers
        self.conv1 = GINConv(
            Sequential(Linear(dataset.num_node_features, dim_h),
                       BatchNorm1d(dim_h),
                       ReLU(),
                       Linear(dim_h, dim_h),
                       ReLU())
        )
        self.conv2 = GINConv(
            Sequential(Linear(dim_h, dim_h),
                       BatchNorm1d(dim_h),
                       ReLU(),
                       Linear(dim_h, dim_h),
                       ReLU())
        )
        self.conv3 = GINConv(
            Sequential(Linear(dim_h, dim_h),
                       BatchNorm1d(dim_h),
                       ReLU(),
                       Linear(dim_h, dim_h),
                       ReLU())
        )

        # Graph-level pooling layer /readout layer
        # Here we add each layer as recommended in GIN papaer
        self.lin1 = Linear(dim_h * 3, dim_h * 3)

        # Final Classification head
        self.lin2 = Linear(dim_h * 3, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # Node embedding
        h1 = self.conv1(x, edge_index)
        h2 = self.conv2(h1, edge_index)
        h3 = self.conv3(h2, edge_index)

        # Graph-level readout
        h1 = global_add_pool(h1, batch)
        h2 = global_add_pool(h2, batch)
        h3 = global_add_pool(h3, batch)

        # Concatenate graph embeddings
        h = torch.cat((h1, h2, h3), dim=1)

        # Classifier
        h = self.lin1(h)
        h = h.relu()
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.lin2(h)
        # print(h.shape)

        return F.log_softmax(h, dim=1)

## Train Test functions

In [5]:
def accuracy(pred_y, y):
    return ((pred_y == y).sum() / len(y)).item()

def train(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    epochs = 100

    model.train()
    for epoch in range(epochs + 1):
        total_loss = 0
        acc = 0
        val_loss = 0
        val_acc = 0

        # Train on batches
        for data in loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            total_loss += loss / len(loader)
            acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

            loss.backward()

            # Valiadation
            val_loss, val_acc = test(model, val_loader)
        # Print metrics every 20 epochs
        if (epoch % 20 == 0):
            print(f'Epoch {epoch:>3} \
                  | Train Loss: {total_loss:.2f} \
                  | Train Acc: {acc*100:>5.2f}% \
                  | ValLoss: {val_loss:.2f}  \
                  | Val Acc: {val_acc*100:.2f}%')
    return model

@torch.no_grad()
def test(model, loader):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    loss = 0
    acc = 0

    for data in loader:
        out = model(data.x, data.edge_index, data.batch)
        loss += criterion(out, data.y) / len(loader)
        acc += accuracy(out.argmax(dim=1), data.y) / len(loader)

    return loss, acc



Model with dim_h = 32

In [6]:
gin = GIN(dim_h=32)
gin

GIN(
  (conv1): GINConv(nn=Sequential(
    (0): Linear(in_features=3, out_features=32, bias=True)
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=32, bias=True)
    (4): ReLU()
  ))
  (conv2): GINConv(nn=Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=32, bias=True)
    (4): ReLU()
  ))
  (conv3): GINConv(nn=Sequential(
    (0): Linear(in_features=32, out_features=32, bias=True)
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=32, out_features=32, bias=True)
    (4): ReLU()
  ))
  (lin1): Linear(in_features=96, out_features=96, bias=True)
  (lin2): Linear(in_features=96, out_features=2, bias=True)
)

Let's train the model!

## Train the model

In [7]:
gin = train(gin, train_loader)

Epoch   0                   | Train Loss: 0.70                   | Train Acc: 63.19%                   | ValLoss: 0.60                    | Val Acc: 71.04%
Epoch  20                   | Train Loss: 0.67                   | Train Acc: 63.94%                   | ValLoss: 0.60                    | Val Acc: 70.48%
Epoch  40                   | Train Loss: 0.67                   | Train Acc: 63.83%                   | ValLoss: 0.59                    | Val Acc: 71.61%
Epoch  60                   | Train Loss: 0.67                   | Train Acc: 63.79%                   | ValLoss: 0.58                    | Val Acc: 72.17%
Epoch  80                   | Train Loss: 0.67                   | Train Acc: 63.75%                   | ValLoss: 0.60                    | Val Acc: 71.61%
Epoch 100                   | Train Loss: 0.67                   | Train Acc: 63.88%                   | ValLoss: 0.58                    | Val Acc: 71.61%


## Test the model

In [15]:
test_loss, test_acc = test(gin, test_loader)
print(f'Test Loss: {test_loss:.2f} | Test Acc: {test_acc * 100:.2f}%')

Test Loss: 0.69 | Test Acc: 62.76%


We can compare this test score with other GCNs under sampe setting with a simple global mean pooling as the readout layer. With the exact same setting, GIN often time outperfroms other architectures...

In [8]:
test_model = GIN(dim_h=32)
for data in train_loader:
    out = test_model(data.x, data.edge_index, data.batch)
    break

In [11]:
print(out.shape)
print(out)
print(out.argmax(dim=1).shape)

torch.Size([64, 2])
tensor([[-3.5728e+01,  0.0000e+00],
        [-1.5860e-01, -1.9196e+00],
        [-1.4764e+00, -2.5937e-01],
        [-7.8700e-03, -4.8486e+00],
        [-2.5429e-01, -1.4937e+00],
        [-1.2229e-02, -4.4101e+00],
        [-1.2679e-01, -2.1279e+00],
        [-5.5607e-01, -8.5204e-01],
        [-1.0544e+00, -4.2834e-01],
        [-5.3672e-02, -2.9516e+00],
        [-2.9955e+00, -5.1305e-02],
        [-4.7635e-01, -9.7035e-01],
        [-8.8076e-02, -2.4733e+00],
        [-7.3471e+00, -6.4471e-04],
        [-4.8479e-01, -9.5666e-01],
        [-2.6226e+00, -7.5381e-02],
        [-2.1891e-01, -1.6266e+00],
        [-2.0827e+00, -1.3306e-01],
        [-2.6529e+00, -7.3051e-02],
        [ 0.0000e+00, -2.1647e+01],
        [-3.0270e+00, -4.9672e-02],
        [-2.7262e+00, -6.7707e-02],
        [-8.3078e+00, -2.4661e-04],
        [-5.1529e-03, -5.2708e+00],
        [-1.0409e+00, -4.3564e-01],
        [-1.3642e+00, -2.9515e-01],
        [-5.5046e-01, -8.5964e-01],
        