### 1. GT (Graph Transformer)


In [1]:
# Install required packages.
!pip install chardet
print("installing ... ")
!conda install -y pytorch==2.0.0 torchvision==0.15.0 torchaudio==2.0.0 pytorch-cuda=11.7 -c pytorch -c nvidia
print("installed! torch 2.0.0")
!conda install -c dglteam/label/cu117 dgl



# !pip install torchvision 
# !pip install  dgl -f https://data.dgl.ai/wheels/cu118/repo.html
!pip install ogb

print("installed!")
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "Failed to install DGL!")

#!pip install  torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 --upgrade
# pip install chardet
# Uncomment below to install required packages. If the CUDA version is not 11.8,
# check the https://www.dgl.ai/pages/start.html to find the supported CUDA
# version and corresponding command to install DGL.

installing ... 
Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... done

# All requested packages already installed.

installed! torch 2.0.0




  current version: 4.12.0
  latest version: 23.10.0

Please update conda by running

    $ conda update -n base -c defaults conda




Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... done

# All requested packages already installed.





  current version: 4.12.0
  latest version: 23.10.0

Please update conda by running

    $ conda update -n base -c defaults conda




installed!
DGL installed!


### Sparse Multi-head Attention

Recall the all-pairs scaled dot-product attention mechanism in the vanilla Transformer:

$$\text{Attn}=\text{softmax}(\dfrac{QK^T} {\sqrt{d}})V,$$

The graph transformer (GT) model employs a Sparse Multi-head Attention block:

$$\text{SparseAttn}(Q, K, V, A) = \text{softmax}(\frac{(QK^T) \circ A}{\sqrt{d}})V,$$

where $Q, K, V \in \mathbb{R}^{N\times d}$ are the query, key, and value features, respectively, and $A\in\{0,1\}^{N\times N}$ is the adjacency matrix of the input graph. $(QK^T)\circ A$ means that the product of the query matrix and the key matrix is followed by a Hadamard product (element-wise multiplication) with the sparse adjacency matrix, as illustrated in the figure below. In the implementation, the softmax is applied only over the non-zero entries of each row, so node pairs that are not connected receive no attention weight:

<img src="https://drive.google.com/uc?id=1OgMAewLR3Z1vz5y4J8aPRSeaU3g8iQfX" width="500">

In [2]:
import dgl
import dgl.nn as dglnn
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from dgl.data import AsGraphPredDataset
from dgl.dataloading import GraphDataLoader
from ogb.graphproppred import collate_dgl, DglGraphPropPredDataset, Evaluator
from ogb.graphproppred.mol_encoder import AtomEncoder
from tqdm import tqdm


class SparseMHA(nn.Module):
    """Sparse multi-head attention module.

    Attention scores are computed only at the non-zero positions of the
    sparse matrix passed to :meth:`forward`.

    Parameters
    ----------
    hidden_size : int
        Size of the input and output node features.
    num_heads : int
        Number of attention heads. Must be positive and divide
        ``hidden_size`` evenly.

    Raises
    ------
    ValueError
        If ``num_heads`` is not positive or does not divide ``hidden_size``.
    """

    def __init__(self, hidden_size=80, num_heads=8):
        super().__init__()
        # Fail early: with a non-divisible hidden size the floor division below
        # would silently drop features and `forward` would only fail later,
        # inside `reshape`, with a much less helpful message.
        if num_heads <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Standard 1/sqrt(d_head) scaling of the dot products.
        self.scaling = self.head_dim**-0.5

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, A, h):
        """Apply sparse multi-head attention.

        Parameters
        ----------
        A : sparse matrix accepted by ``dglsp.bsddmm``
            Sparsity pattern that selects which (query, key) pairs are scored.
        h : torch.Tensor
            Node features of shape ``[N, hidden_size]``.

        Returns
        -------
        torch.Tensor
            Updated node features of shape ``[N, hidden_size]``.
        """
        N = len(h)
        # [N, dh, nh]
        q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)
        q *= self.scaling
        # [N, dh, nh]
        k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)
        # [N, dh, nh]
        v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)

        ######################################################################
        # (HIGHLIGHT) Compute the multi-head attention with Sparse Matrix API
        ######################################################################
        attn = dglsp.bsddmm(A, q, k.transpose(1, 0))  # (sparse) [N, N, nh]
        # Sparse softmax by default applies on the last sparse dimension.
        attn = attn.softmax()  # (sparse) [N, N, nh]
        out = dglsp.bspmm(attn, v)  # [N, dh, nh]

        # Merge the heads back into a single feature axis before projecting.
        return self.out_proj(out.reshape(N, -1))

## Graph Transformer Layer

The GT layer is composed of Multi-head Attention, Batch Norm, and Feed-forward Network, connected by residual links as in vanilla transformer.

<img src="https://drive.google.com/uc?id=1cm-Ijw7bUQIOkoTKn5MQ3m4-66JqCsMz" width="300">

In [3]:
class GTLayer(nn.Module):
    """One Graph Transformer layer.

    The layer chains two residual sub-blocks, each followed by batch
    normalisation: sparse multi-head attention, then a two-layer
    feed-forward network whose inner width is twice ``hidden_size``.
    """

    def __init__(self, hidden_size=80, num_heads=8):
        super().__init__()
        ffn_width = hidden_size * 2
        self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)
        self.batchnorm1 = nn.BatchNorm1d(hidden_size)
        self.batchnorm2 = nn.BatchNorm1d(hidden_size)
        self.FFN1 = nn.Linear(hidden_size, ffn_width)
        self.FFN2 = nn.Linear(ffn_width, hidden_size)

    def forward(self, A, h):
        """Run attention and the feed-forward network over node features ``h``.

        ``A`` is forwarded unchanged to the attention module; the returned
        tensor is the output of the second batch norm.
        """
        # Attention sub-block: residual connection, then normalisation.
        attended = self.batchnorm1(self.MHA(A, h) + h)

        # Feed-forward sub-block: residual connection, then normalisation.
        ffn_out = self.FFN2(F.relu(self.FFN1(attended)))
        return self.batchnorm2(attended + ffn_out)

## Graph Transformer Model

The GT model is constructed by stacking GT layers. The input positional encoding of the vanilla transformer is replaced with the Laplacian positional encoding [(Dwivedi et al. 2020)](https://arxiv.org/abs/2003.00982). For the graph-level prediction task, an extra pooler is stacked on top of the GT layers to aggregate the node features of each graph.

In [4]:
class GTModel(nn.Module):
    """Graph Transformer for graph-level prediction.

    Node inputs are embedded with ``AtomEncoder`` and summed with a linear
    projection of the positional encoding. The result goes through
    ``num_layers`` stacked ``GTLayer`` blocks, is sum-pooled per graph and
    finally mapped to ``out_size`` outputs by a three-layer MLP.

    Parameters
    ----------
    out_size : int
        Number of outputs per graph.
    hidden_size : int
        Width of the node representations.
    pos_enc_size : int
        Size of the last dimension of the positional encoding.
    num_layers : int
        Number of stacked ``GTLayer`` blocks.
    num_heads : int
        Number of attention heads in every layer.
    """

    def __init__(
        self,
        out_size,
        hidden_size=80,
        pos_enc_size=2,
        num_layers=8,
        num_heads=8,
    ):
        super().__init__()
        half_width = hidden_size // 2
        quarter_width = hidden_size // 4

        self.atom_encoder = AtomEncoder(hidden_size)
        self.pos_linear = nn.Linear(pos_enc_size, hidden_size)

        gt_layers = []
        for _ in range(num_layers):
            gt_layers.append(GTLayer(hidden_size, num_heads))
        self.layers = nn.ModuleList(gt_layers)

        self.pooler = dglnn.SumPooling()
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, half_width),
            nn.ReLU(),
            nn.Linear(half_width, quarter_width),
            nn.ReLU(),
            nn.Linear(quarter_width, out_size),
        )

    def forward(self, g, X, pos_enc):
        """Compute graph-level predictions.

        ``g`` supplies the edges (used as the attention sparsity pattern) and
        the graph membership for pooling; ``X`` is fed to the atom encoder and
        ``pos_enc`` to the positional-encoding projection.
        """
        num_nodes = g.num_nodes()
        # Edges of the graph define the non-zero pattern of the N x N matrix.
        edge_index = torch.stack(g.edges())
        adjacency = dglsp.spmatrix(edge_index, shape=(num_nodes, num_nodes))

        node_repr = self.atom_encoder(X) + self.pos_linear(pos_enc)
        for gt_layer in self.layers:
            node_repr = gt_layer(adjacency, node_repr)

        graph_repr = self.pooler(g, node_repr)
        return self.predictor(graph_repr)

### Training

We train the GT model on the [ogbg-molhiv](https://ogb.stanford.edu/docs/graphprop/#ogbg-mol) benchmark.
The Laplacian positional encoding of each graph is pre-computed and passed to the model as part of its input.

In [5]:
@torch.no_grad()
def evaluate(model, dataloader, evaluator, device):
    """Score ``model`` on every batch of ``dataloader``.

    The model is switched to eval mode, predictions and labels from all
    batches are concatenated on the CPU, and the ``"rocauc"`` entry of
    ``evaluator.eval`` is returned.
    """
    model.eval()
    targets, predictions = [], []
    for batched_g, labels in dataloader:
        batched_g = batched_g.to(device)
        labels = labels.to(device)
        y_hat = model(
            batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"]
        )
        # Labels are reshaped to match the prediction tensor.
        targets.append(labels.view(y_hat.shape).detach().cpu())
        predictions.append(y_hat.detach().cpu())

    input_dict = {
        "y_true": torch.cat(targets, dim=0).numpy(),
        "y_pred": torch.cat(predictions, dim=0).numpy(),
    }
    return evaluator.eval(input_dict)["rocauc"]



In [6]:

def train(model, dataset, evaluator, device, num_epochs=5, batch_size=256,
          lr=0.001):
    """Train ``model`` and print loss and ROC-AUC after every epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Model called as ``model(graph, graph.ndata["feat"], graph.ndata["PE"])``.
    dataset
        Dataset exposing ``train_idx``, ``val_idx`` and ``test_idx`` and
        supporting indexing with those index tensors.
    evaluator
        Object forwarded to ``evaluate`` to compute the validation/test metric.
    device : torch.device
        Device that each batch is moved to.
    num_epochs : int, optional
        Number of passes over the training split (default 5).
    batch_size : int, optional
        Batch size used for all three data loaders (default 256).
    lr : float, optional
        Initial Adam learning rate (default 0.001).

    Returns
    -------
    None
        Metrics are printed, not returned.
    """
    train_dataloader = GraphDataLoader(
        dataset[dataset.train_idx],
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_dgl,
    )
    valid_dataloader = GraphDataLoader(
        dataset[dataset.val_idx], batch_size=batch_size, collate_fn=collate_dgl
    )
    test_dataloader = GraphDataLoader(
        dataset[dataset.test_idx], batch_size=batch_size, collate_fn=collate_dgl
    )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # NOTE(review): with step_size == num_epochs and one scheduler step per
    # epoch, the learning rate is only halved after the final epoch, so it is
    # effectively constant for the whole run. Kept as-is to preserve behaviour.
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=num_epochs, gamma=0.5
    )
    loss_fcn = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batched_g, labels in train_dataloader:
            batched_g, labels = batched_g.to(device), labels.to(device)
            logits = model(
                batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"]
            )
            # BCEWithLogitsLoss needs floating-point targets.
            loss = loss_fcn(logits, labels.float())
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        # Mean of the per-batch losses (not weighted by batch size).
        avg_loss = total_loss / len(train_dataloader)
        val_metric = evaluate(model, valid_dataloader, evaluator, device)
        test_metric = evaluate(model, test_dataloader, evaluator, device)
        print(
            f"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, "
            f"Val: {val_metric:.4f}, Test: {test_metric:.4f}"
        )



In [7]:

import random

# Training device.
dev = torch.device("cpu")
# Be sure to install DGL with CUDA support before switching to the GPU.
# dev = torch.device("cuda:0")

# Load dataset.
pos_enc_size = 8
dataset = AsGraphPredDataset(
    DglGraphPropPredDataset("ogbg-molhiv", "./data/OGB")
)
evaluator = Evaluator("ogbg-molhiv")

# Down sample the dataset to make the tutorial run faster.
# `random` drives the subsampling below; torch's global RNG is seeded as well
# so that the torch-side randomness of the run does not change between runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
train_size = len(dataset.train_idx)
val_size = len(dataset.val_idx)
test_size = len(dataset.test_idx)
dataset.train_idx = dataset.train_idx[
    torch.LongTensor(random.sample(range(train_size), 2000))
]
dataset.val_idx = dataset.val_idx[
    torch.LongTensor(random.sample(range(val_size), 1000))
]
dataset.test_idx = dataset.test_idx[
    torch.LongTensor(random.sample(range(test_size), 1000))
]

# Laplacian positional encoding, computed only for the sampled graphs and
# stored on each graph as node data under the key "PE".
indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])
for idx in tqdm(indices, desc="Computing Laplacian PE"):
    g, _ = dataset[idx]
    g.ndata["PE"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)

# Create model.
out_size = dataset.num_tasks
model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)

# Kick off training.
train(model, dataset, evaluator, dev)

  return th.as_tensor(data, dtype=dtype)
Computing Laplacian PE: 100%|█████████████████████████████████████████████████████| 4000/4000 [00:12<00:00, 327.94it/s]


Epoch: 000, Loss: 0.4448, Val: 0.2640, Test: 0.3369
Epoch: 001, Loss: 0.2137, Val: 0.2971, Test: 0.3529
Epoch: 002, Loss: 0.1792, Val: 0.4039, Test: 0.4484
Epoch: 003, Loss: 0.1579, Val: 0.5109, Test: 0.4307
Epoch: 004, Loss: 0.1400, Val: 0.5626, Test: 0.4903


### 2. Graph GPS

<img src="https://miro.medium.com/v2/resize:fit:4800/format:webp/1*QKN2j0vBNS8fF-W2EuW5NQ.png" width="800">


In [48]:
import torch
# Show the installed PyTorch version; the torch-scatter/torch-sparse wheel
# index URL is specific to the PyTorch version.
torchversion = torch.__version__
print(torchversion)

2.0.0


In [49]:
# Install PyTorch Scatter, PyTorch Sparse, and PyTorch Geometric
!pip install torch_geometric
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+${11.7}.html
#!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{2.0.0}.html
#!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{2.0.0}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

Looking in links: https://data.pyg.org/whl/torch-2.0.0+$11.7.html


In [50]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool
from torch_geometric.nn.attention import PerformerAttention

In [66]:
# Root directory handed to the ZINC dataset class.
path = 'ZINC-pe1'
dev = torch.device("cuda:0")
# NOTE(review): this transform is created but never passed to ZINC below, so
# the loaded graphs carry no 'pe' attribute - confirm this is intended.
transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')

# Same subset of ZINC for all three splits; only the split name differs.
train_dataset, val_dataset, test_dataset = (
    ZINC(path, subset=True, split=split_name)
    for split_name in ('train', 'val', 'test')
)

# Only the training loader shuffles; evaluation uses a larger batch size.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader, test_loader = (
    DataLoader(eval_dataset, batch_size=64)
    for eval_dataset in (val_dataset, test_dataset)
)

# parser = argparse.ArgumentParser()
# #parser.add_argument( default='multihead')
# args = parser.parse_args()


# *** Important:
# Copy the file add_positional_encoding.py
# to anaconda3\envs\<your env name>\lib\site-packages\torch_geometric\transforms\add_positional_encoding.py


In [73]:
class GPS(torch.nn.Module):
    """GraphGPS-style model that predicts one scalar per graph.

    Node and edge inputs are embedded, passed through ``num_layers``
    ``GPSConv`` blocks (each wrapping a ``GINEConv`` with its own MLP),
    sum-pooled per graph and mapped to a single output by an MLP head.

    Args:
        channels: Width of node/edge embeddings and of every conv layer.
        pe_dim: Output size of ``pe_lin``.
        num_layers: Number of ``GPSConv`` blocks.
        attn_type: Attention type forwarded to ``GPSConv``; the value
            ``'performer'`` additionally enables periodic projection redraws.
        attn_kwargs: Extra attention arguments forwarded to ``GPSConv``.

    NOTE(review): ``pe_lin`` and ``pe_norm`` are constructed but never used in
    ``forward``, so ``pe_dim`` currently has no effect on the output.
    """

    def __init__(self, channels: int, pe_dim: int, num_layers: int,
                 attn_type: str, attn_kwargs: Dict[str, Any]):
        super().__init__()

        # Vocabulary sizes 28 and 4 are presumably the ZINC atom / bond type
        # counts - TODO confirm against the dataset.
        self.node_emb = Embedding(28, channels)
        self.pe_lin = Linear(20, pe_dim)
        self.pe_norm = BatchNorm1d(20)
        self.edge_emb = Embedding(4, channels)

        gps_layers = []
        for _ in range(num_layers):
            # Each block gets its own message-passing MLP.
            local_mlp = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            gps_layers.append(
                GPSConv(channels, GINEConv(local_mlp), heads=4,
                        attn_type=attn_type, attn_kwargs=attn_kwargs)
            )
        self.convs = ModuleList(gps_layers)

        half_width = channels // 2
        quarter_width = channels // 4
        self.mlp = Sequential(
            Linear(channels, half_width),
            ReLU(),
            Linear(half_width, quarter_width),
            ReLU(),
            Linear(quarter_width, 1),
        )

        # Redraws are only scheduled for Performer attention.
        if attn_type == 'performer':
            redraw_interval = 1000
        else:
            redraw_interval = None
        self.redraw_projection = RedrawProjection(
            self.convs, redraw_interval=redraw_interval)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return one prediction per graph in the batch (last dimension 1)."""
        node_repr = self.node_emb(x.squeeze(-1))
        edge_repr = self.edge_emb(edge_attr)

        for gps_layer in self.convs:
            node_repr = gps_layer(node_repr, edge_index, batch,
                                  edge_attr=edge_repr)

        graph_repr = global_add_pool(node_repr, batch)
        return self.mlp(graph_repr)

class RedrawProjection:
    """Counter-based trigger for redrawing Performer projection matrices.

    Each call to :meth:`redraw_projections` made while the wrapped model is
    in training mode either advances an internal counter or, once the counter
    has reached ``redraw_interval``, redraws the projection matrix of every
    ``PerformerAttention`` module in the model and resets the counter.
    A ``redraw_interval`` of ``None`` disables redrawing.
    """

    def __init__(self, model: torch.nn.Module,
                 redraw_interval: Optional[int] = None):
        self.model = model
        self.redraw_interval = redraw_interval
        # Number of calls counted since the last redraw.
        self.num_last_redraw = 0

    def redraw_projections(self):
        """Advance the counter, redrawing projections when the interval is hit."""
        if not self.model.training or self.redraw_interval is None:
            return

        if self.num_last_redraw < self.redraw_interval:
            self.num_last_redraw += 1
            return

        # Interval reached: collect the Performer modules first, then redraw.
        performer_modules = list(
            filter(
                lambda module: isinstance(module, PerformerAttention),
                self.model.modules(),
            )
        )
        for performer in performer_modules:
            performer.redraw_projection_matrix()
        self.num_last_redraw = 0


In [74]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Extra arguments forwarded to the attention module inside every GPSConv.
attn_kwargs = dict(dropout=0.5)
model = GPS(
    channels=64,
    pe_dim=8,
    num_layers=10,
    attn_type='performer',
    attn_kwargs=attn_kwargs,
)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Halve the learning rate when the monitored value stops decreasing for
# 20 scheduler steps, but never go below 1e-5.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=20,
    min_lr=1e-5,
)


In [75]:
def train():
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``device`` and
    ``train_loader``. The loss is the mean absolute error of each batch.

    Returns:
        The per-graph average of the batch losses over the training dataset.
    """
    model.train()

    weighted_loss_sum = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Give the redraw helper a chance to refresh projections every step.
        model.redraw_projection.redraw_projections()
        prediction = model(batch.x, batch.edge_index, batch.edge_attr,
                           batch.batch)
        batch_mae = torch.abs(prediction.squeeze() - batch.y).mean()
        batch_mae.backward()
        # Weight by graph count so the epoch average is per graph.
        weighted_loss_sum += batch_mae.item() * batch.num_graphs
        optimizer.step()

    return weighted_loss_sum / len(train_loader.dataset)



In [78]:
@torch.no_grad()
def test(loader):
    """Return the mean absolute error of the module-level ``model`` on ``loader``.

    Absolute errors are summed over all batches and divided by the number of
    items in ``loader.dataset``.
    """
    model.eval()

    abs_error_sum = 0
    for batch in loader:
        batch = batch.to(device)
        prediction = model(batch.x, batch.edge_index, batch.edge_attr,
                           batch.batch)
        batch_abs_error = torch.abs(prediction.squeeze() - batch.y).sum()
        abs_error_sum += batch_abs_error.item()

    return abs_error_sum / len(loader.dataset)



In [79]:
# Four epochs (1..4): train, evaluate on both held-out splits, then let the
# plateau scheduler react to the validation error.
for epoch in range(1, 5):
    loss = train()
    val_mae, test_mae = test(val_loader), test(test_loader)
    scheduler.step(val_mae)
    epoch_summary = (
        f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
        f'Test: {test_mae:.4f}'
    )
    print(epoch_summary)

Epoch: 01, Loss: 0.5863, Val: 0.6312, Test: 0.6813
Epoch: 02, Loss: 0.5363, Val: 0.5040, Test: 0.5454
Epoch: 03, Loss: 0.5318, Val: 0.6812, Test: 0.7000
Epoch: 04, Loss: 0.5131, Val: 0.5102, Test: 0.5310
