<a href="https://colab.research.google.com/github/VinaySingh561/GraphMachineLearning/blob/main/Graph_Transformer_on_ZINC_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m15.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [3]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool
from torch_geometric.nn.attention import PerformerAttention

In [8]:
import os
path = osp.join(osp.dirname(os.getcwd()), '..', 'data', 'ZINC-PE')
transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
train_dataset = ZINC(path, subset=True, split='train', pre_transform=transform)
val_dataset = ZINC(path, subset=True, split='val', pre_transform=transform)
test_dataset = ZINC(path, subset=True, split='test', pre_transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

Downloading https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1
Extracting /data/ZINC-PE/molecules.zip
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/train.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/val.index
Downloading https://raw.githubusercontent.com/graphdeeplearning/benchmarking-gnns/master/data/molecules/test.index
Processing...
Processing train dataset: 100%|██████████| 10000/10000 [00:06<00:00, 1461.95it/s]
Processing val dataset: 100%|██████████| 1000/1000 [00:00<00:00, 1116.38it/s]
Processing test dataset: 100%|██████████| 1000/1000 [00:00<00:00, 1378.00it/s]
Done!


In [16]:
parser = argparse.ArgumentParser()
parser.add_argument(
    '--attn_type', default='multihead',
    help="Global attention type such as 'multihead' or 'performer'.")
args = parser.parse_args([])

In [17]:
class GPS(torch.nn.Module):
    def __init__(self, channels: int, pe_dim: int, num_layers: int,
                 attn_type: str, attn_kwargs: Dict[str, Any]):
        super().__init__()

        self.node_emb = Embedding(28, channels - pe_dim)
        self.pe_lin = Linear(20, pe_dim)
        self.pe_norm = BatchNorm1d(20)
        self.edge_emb = Embedding(4, channels)

        self.convs = ModuleList()
        for _ in range(num_layers):
            nn = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            conv = GPSConv(channels, GINEConv(nn), heads=4,
                           attn_type=attn_type, attn_kwargs=attn_kwargs)
            self.convs.append(conv)

        self.mlp = Sequential(
            Linear(channels, channels // 2),
            ReLU(),
            Linear(channels // 2, channels // 4),
            ReLU(),
            Linear(channels // 4, 1),
        )
        self.redraw_projection = RedrawProjection(
            self.convs,
            redraw_interval=1000 if attn_type == 'performer' else None)

    def forward(self, x, pe, edge_index, edge_attr, batch):
        x_pe = self.pe_norm(pe)
        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)
        edge_attr = self.edge_emb(edge_attr)

        for conv in self.convs:
            x = conv(x, edge_index, batch, edge_attr=edge_attr)
        x = global_add_pool(x, batch)
        return self.mlp(x)



In [18]:
class RedrawProjection:
    def __init__(self, model: torch.nn.Module,
                 redraw_interval: Optional[int] = None):
        self.model = model
        self.redraw_interval = redraw_interval
        self.num_last_redraw = 0

    def redraw_projections(self):
        if not self.model.training or self.redraw_interval is None:
            return
        if self.num_last_redraw >= self.redraw_interval:
            fast_attentions = [
                module for module in self.model.modules()
                if isinstance(module, PerformerAttention)
            ]
            for fast_attention in fast_attentions:
                fast_attention.redraw_projection_matrix()
            self.num_last_redraw = 0
            return
        self.num_last_redraw += 1


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attn_kwargs = {'dropout': 0.5}
model = GPS(channels=64, pe_dim=8, num_layers=10, attn_type=args.attn_type,
            attn_kwargs=attn_kwargs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
                              min_lr=0.00001)


In [19]:
def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        model.redraw_projection.redraw_projections()
        out = model(data.x, data.pe, data.edge_index, data.edge_attr,
                    data.batch)
        loss = (out.squeeze() - data.y).abs().mean()
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def test(loader):
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x, data.pe, data.edge_index, data.edge_attr,
                    data.batch)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(loader.dataset)

In [20]:
for epoch in range(1, 101):
    loss = train()
    val_mae = test(val_loader)
    test_mae = test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')

Epoch: 01, Loss: 0.6651, Val: 0.5069, Test: 0.5385
Epoch: 02, Loss: 0.5413, Val: 0.5185, Test: 0.4784
Epoch: 03, Loss: 0.4902, Val: 0.4522, Test: 0.4599
Epoch: 04, Loss: 0.4750, Val: 0.3713, Test: 0.3636
Epoch: 05, Loss: 0.4474, Val: 0.4002, Test: 0.3749
Epoch: 06, Loss: 0.4183, Val: 0.4131, Test: 0.4031
Epoch: 07, Loss: 0.4068, Val: 0.5295, Test: 0.5204
Epoch: 08, Loss: 0.3816, Val: 0.4454, Test: 0.4121
Epoch: 09, Loss: 0.3875, Val: 0.4598, Test: 0.4399
Epoch: 10, Loss: 0.3838, Val: 0.3574, Test: 0.3294
Epoch: 11, Loss: 0.3719, Val: 0.3098, Test: 0.2999
Epoch: 12, Loss: 0.3636, Val: 0.4494, Test: 0.4215
Epoch: 13, Loss: 0.3566, Val: 0.4697, Test: 0.4556
Epoch: 14, Loss: 0.3596, Val: 0.4023, Test: 0.3593
Epoch: 15, Loss: 0.3819, Val: 0.4213, Test: 0.4013
Epoch: 16, Loss: 0.3540, Val: 0.4350, Test: 0.3970
Epoch: 17, Loss: 0.3294, Val: 0.4111, Test: 0.3591
Epoch: 18, Loss: 0.3160, Val: 0.3703, Test: 0.3741
Epoch: 19, Loss: 0.3577, Val: 0.3647, Test: 0.3348
Epoch: 20, Loss: 0.3356, Val: 0

KeyboardInterrupt: 

In [21]:
def compute_r2(preds, targets):
    ss_res = ((targets - preds) ** 2).sum()
    ss_tot = ((targets - targets.mean()) ** 2).sum()
    return 1 - ss_res / ss_tot


@torch.no_grad()
def test_secondary(loader):
    model.eval()
    total_error = 0
    total_squared_error = 0
    preds, targets = [], []

    for data in loader:
        data = data.to(device)
        out = model(data.x, data.pe, data.edge_index, data.edge_attr,
                    data.batch)
        pred = out.squeeze()
        total_error += (pred - data.y).abs().sum().item()
        total_squared_error += ((pred - data.y) ** 2).sum().item()
        preds.append(pred.cpu())
        targets.append(data.y.cpu())

    preds = torch.cat(preds)
    targets = torch.cat(targets)

    mae = total_error / len(loader.dataset)
    rmse = (total_squared_error / len(loader.dataset)) ** 0.5
    r2 = compute_r2(preds, targets)
    return mae, rmse, r2


In [22]:
test_secondary(test_loader)

(0.26892981338500976, 0.5348927551696027, tensor(0.9297))