# 1. GT (Graph Transformer)


In [47]:
# Install required packages.
!pip install chardet
print("installing ... ")
!conda install -y pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia
# print("installed! torch 2.0.0")
# !conda install -c dglteam/label/cu117 dgl



!pip install torchvision 
# !pip install  dgl -f https://data.dgl.ai/wheels/cu118/repo.html
# !pip install dgl -f https://data.dgl.ai/wheels/cu124.html
!pip install ogb

print("installed!")
import os
import torch
os.environ['TORCH'] = torch.__version__
os.environ['DGLBACKEND'] = "pytorch"

try:
    import dgl
    installed = True
except ImportError:
    installed = False
print("DGL installed!" if installed else "Failed to install DGL!")

#!pip install  torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 --upgrade
# pip install chardet
# Uncomment below to install required packages. If the CUDA version is not 11.8,
# check the https://www.dgl.ai/pages/start.html to find the supported CUDA
# version and corresponding command to install DGL.

installing ... 


  conda config --add channels defaults

For more information see https://docs.conda.io/projects/conda/en/stable/user-guide/configuration/use-condarc.html

  deprecated.topic(


  conda config --add channels defaults

For more information see https://docs.conda.io/projects/conda/en/stable/user-guide/configuration/use-condarc.html

  deprecated.topic(
Channels:
 - pytorch
 - nvidia
 - defaults
Platform: linux-64
Collecting package metadata (repodata.json): done
Solving environment: done

# All requested packages already installed.

installed!
DGL installed!


## 1.1. Sparse Multi-head Attention

Recall the all-pairs scaled-dot-product attention mechanism in vanillar Transformer:

$$\text{Attn}=\text{softmax}(\dfrac{QK^T} {\sqrt{d}})V,$$

The graph transformer (GT) model employs a Sparse Multi-head Attention block:

$$\text{SparseAttn}(Q, K, V, A) = \text{softmax}(\frac{(QK^T) \circ A}{\sqrt{d}})V,$$

where $Q, K, V ∈\mathbb{R}^{N\times d}$ are query feature, key feature, and value feature, respectively. $A\in[0,1]^{N\times N}$ is the adjacency matrix of the input graph. $(QK^T)\circ A$ means that the multiplication of query matrix and key matrix is followed by a Hadamard product (or element-wise multiplication) with the sparse adjacency matrix as illustrated in the figure below:
![Sample Image](Fig1.jpeg)


## 1.2. SparseMHA

In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class SparseMHA(nn.Module):
    """Sparse Multi-head Attention Module using PyTorch sparse operations"""

    def __init__(self, hidden_size=80, num_heads=8):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scaling = self.head_dim**-0.5

        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, A, h):
        N = len(h)
        # [N, head_dim, num_heads]
        q = self.q_proj(h).reshape(N, self.head_dim, self.num_heads)
        q *= self.scaling
        # [N, head_dim, num_heads]
        k = self.k_proj(h).reshape(N, self.head_dim, self.num_heads)
        # [N, head_dim, num_heads]
        v = self.v_proj(h).reshape(N, self.head_dim, self.num_heads)

        ######################################################################
        # Compute the multi-head attention with PyTorch sparse matrix operations
        ######################################################################
        # Perform sparse-dense-dense matrix multiplication
        attn = torch.bmm(q.transpose(1, 2), k)  # Shape: [num_heads, N, N]
        attn = F.softmax(attn, dim=-1)  # Apply softmax on each head

        # Perform the sparse attention computation
        out = torch.bmm(attn, v.transpose(1, 2))  # Shape: [num_heads, N, head_dim]

        # Reshape and project output
        out = out.transpose(1, 2).reshape(N, -1)
        return self.out_proj(out)


## 1.3. Graph Transformer Layer

The GT layer is composed of Multi-head Attention, Batch Norm, and Feed-forward Network, connected by residual links as in vanilla transformer.

![Sample Image](Fig2.jpeg)


In [49]:
class GTLayer(nn.Module):
    """Graph Transformer Layer"""

    def __init__(self, hidden_size=80, num_heads=8):
        super().__init__()
        self.MHA = SparseMHA(hidden_size=hidden_size, num_heads=num_heads)
        self.batchnorm1 = nn.BatchNorm1d(hidden_size)
        self.batchnorm2 = nn.BatchNorm1d(hidden_size)
        self.FFN1 = nn.Linear(hidden_size, hidden_size * 2)
        self.FFN2 = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, A, h):
        h1 = h
        h = self.MHA(A, h)
        h = self.batchnorm1(h + h1)

        h2 = h
        h = self.FFN2(F.relu(self.FFN1(h)))
        h = h2 + h

        return self.batchnorm2(h)

## 1.4. Graph Transformer Model

The GT model is constructed by stacking GT layers. The input positional encoding of vanilla transformer is replaced with Laplacian positional encoding [(Dwivedi et al. 2020)](https://arxiv.org/abs/2003.00982). For the graph-level prediction task, an extra pooler is stacked on top of GT layers to aggregate node feature of the same graph.

In [50]:
import torch
import torch.nn as nn
import torch.optim as optim
from dgl.nn import SumPooling
from tqdm import tqdm
import random

class GTModel(nn.Module):
    def __init__(self, out_size, hidden_size=80, pos_enc_size=2, num_layers=8, num_heads=8):
        super().__init__()
        self.atom_encoder = AtomEncoder(hidden_size)
        self.pos_linear = nn.Linear(pos_enc_size, hidden_size)
        self.layers = nn.ModuleList(
            [GTLayer(hidden_size, num_heads) for _ in range(num_layers)]
        )
        self.pooler = SumPooling()
        self.predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, hidden_size // 4),
            nn.ReLU(),
            nn.Linear(hidden_size // 4, out_size),
        )

    def forward(self, g, X, pos_enc):
        # Create adjacency matrix as a dense tensor
        indices = torch.stack(g.edges())
        N = g.num_nodes()
        A = torch.zeros((N, N), device=X.device)
        A[indices[0], indices[1]] = 1  # Undirected graph for symmetry

        # Initial feature processing
        h = self.atom_encoder(X) + self.pos_linear(pos_enc)
        for layer in self.layers:
            h = layer(A, h)  # Pass dense adjacency instead of sparse matrix
        h = self.pooler(g, h)

        return self.predictor(h)

## 1.5. Training

We train the GT model on [ogbg-molhiv](https://ogb.stanford.edu/docs/graphprop/#ogbg-mol) benchmark. 
The Laplacian positional encoding of each graph is pre-computed as part of the input to the model.

In [51]:
@torch.no_grad()
def evaluate(model, dataloader, evaluator, device):
    model.eval()
    y_true = []
    y_pred = []
    for batched_g, labels in dataloader:
        batched_g, labels = batched_g.to(device), labels.to(device)
        y_hat = model(batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"])
        y_true.append(labels.view(y_hat.shape).detach().cpu())
        y_pred.append(y_hat.detach().cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    return evaluator.eval(input_dict)["rocauc"]



In [52]:
from dgl.dataloading import GraphDataLoader
from ogb.graphproppred import collate_dgl

def train(model, dataset, evaluator, device):
    train_dataloader = GraphDataLoader(
        dataset[dataset.train_idx],
        batch_size=256,
        shuffle=True,
        collate_fn=collate_dgl,
    )
    valid_dataloader = GraphDataLoader(
        dataset[dataset.val_idx], batch_size=256, collate_fn=collate_dgl
    )
    test_dataloader = GraphDataLoader(
        dataset[dataset.test_idx], batch_size=256, collate_fn=collate_dgl
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    num_epochs = 5
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, step_size=num_epochs, gamma=0.5
    )
    loss_fcn = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batched_g, labels in train_dataloader:
            batched_g, labels = batched_g.to(device), labels.to(device)
            logits = model(
                batched_g, batched_g.ndata["feat"], batched_g.ndata["PE"]
            )
            loss = loss_fcn(logits, labels.float())
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()
        avg_loss = total_loss / len(train_dataloader)
        val_metric = evaluate(model, valid_dataloader, evaluator, device)
        test_metric = evaluate(model, test_dataloader, evaluator, device)
        print(
            f"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, "
            f"Val: {val_metric:.4f}, Test: {test_metric:.4f}"
        )



In [53]:
from dgl.data import AsGraphPredDataset
from ogb.graphproppred import DglGraphPropPredDataset, Evaluator
from tqdm import tqdm
from ogb.graphproppred.mol_encoder import AtomEncoder

# Training device.
# dev = torch.device("cpu")
#Be sure to install DGL with CUDA support.
dev = torch.device("cuda:0")

# Load dataset.
pos_enc_size = 8
dataset = AsGraphPredDataset(
    DglGraphPropPredDataset("ogbg-molhiv", "./data/OGB")
)
evaluator = Evaluator("ogbg-molhiv")

# Down sample the dataset to make the tutorial run faster.
import random
random.seed(42)
train_size = len(dataset.train_idx)
val_size = len(dataset.val_idx)
test_size = len(dataset.test_idx)
dataset.train_idx = dataset.train_idx[
    torch.LongTensor(random.sample(range(train_size), 2000))
]
dataset.val_idx = dataset.val_idx[
    torch.LongTensor(random.sample(range(val_size), 1000))
]
dataset.test_idx = dataset.test_idx[
    torch.LongTensor(random.sample(range(test_size), 1000))
]

# Laplacian positional encoding.
indices = torch.cat([dataset.train_idx, dataset.val_idx, dataset.test_idx])
for idx in tqdm(indices, desc="Computing Laplacian PE"):
    g, _ = dataset[idx]
    g.ndata["PE"] = dgl.laplacian_pe(g, k=pos_enc_size, padding=True)

# Create model.
out_size = dataset.num_tasks
model = GTModel(out_size=out_size, pos_enc_size=pos_enc_size).to(dev)

# Kick off training.
train(model, dataset, evaluator, dev)

Computing Laplacian PE: 100%|██████████| 4000/4000 [00:06<00:00, 613.69it/s]


Epoch: 000, Loss: 0.3307, Val: 0.3266, Test: 0.3441
Epoch: 001, Loss: 0.1917, Val: 0.4158, Test: 0.3426
Epoch: 002, Loss: 0.1621, Val: 0.6177, Test: 0.4677
Epoch: 003, Loss: 0.1402, Val: 0.6287, Test: 0.4744
Epoch: 004, Loss: 0.1281, Val: 0.6714, Test: 0.5341


# 2. GraphGPS: General Powerful Scalable Graph Transformers

<img src="https://miro.medium.com/v2/resize:fit:4800/format:webp/1*QKN2j0vBNS8fF-W2EuW5NQ.png" width="800">


In [54]:
import torch
torchversion = torch.__version__
print(torchversion)

2.5.1


In [55]:
# # Install PyTorch Scatter, PyTorch Sparse, and PyTorch Geometric
# !pip install torch_geometric
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+${11.7}.html
# #!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-{2.0.0}.html
# #!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-{2.0.0}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [56]:
import argparse
import os.path as osp
from typing import Any, Dict, Optional

import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torch_geometric.transforms as T
from torch_geometric.datasets import ZINC
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool
from torch_geometric.nn.attention import PerformerAttention

## 2.1. Load dataset

In [57]:
path = 'ZINC-pe1'
dev = torch.device("cuda:0")
transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')
train_dataset = ZINC(path, subset=True, split='train')
val_dataset = ZINC(path, subset=True, split='val')
test_dataset = ZINC(path, subset=True, split='test')

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

# parser = argparse.ArgumentParser()
# #parser.add_argument( default='multihead')
# args = parser.parse_args()


#*** Importance:
# Copy file add_positional_encoding.py 
# to anaconda3\envs\your envs name\lib\site-packages\torch_geometric\transforms\add_positional_encoding.py


## 2.2. GPS

In [58]:
class GPS(torch.nn.Module):
    def __init__(self, channels: int, pe_dim: int, num_layers: int,
                 attn_type: str, attn_kwargs: Dict[str, Any]):
        super().__init__()

        self.node_emb = Embedding(28, channels )
        self.pe_lin = Linear(20, pe_dim)
        self.pe_norm = BatchNorm1d(20)
        self.edge_emb = Embedding(4, channels)

        self.convs = ModuleList()
        for _ in range(num_layers):
            nn = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            conv = GPSConv(channels, GINEConv(nn), heads=4,
                           attn_type=attn_type, attn_kwargs=attn_kwargs)
            self.convs.append(conv)

        self.mlp = Sequential(
            Linear(channels, channels // 2),
            ReLU(),
            Linear(channels // 2, channels // 4),
            ReLU(),
            Linear(channels // 4, 1),
        )
        self.redraw_projection = RedrawProjection(
            self.convs,
            redraw_interval=1000 if attn_type == 'performer' else None)

    def forward(self, x, edge_index, edge_attr, batch):
        x = self.node_emb(x.squeeze(-1))
        edge_attr = self.edge_emb(edge_attr)

        for conv in self.convs:
            x = conv(x, edge_index, batch, edge_attr=edge_attr)
        x = global_add_pool(x, batch)
        return self.mlp(x)

class RedrawProjection:
    def __init__(self, model: torch.nn.Module,
                 redraw_interval: Optional[int] = None):
        self.model = model
        self.redraw_interval = redraw_interval
        self.num_last_redraw = 0

    def redraw_projections(self):
        if not self.model.training or self.redraw_interval is None:
            return
        if self.num_last_redraw >= self.redraw_interval:
            fast_attentions = [
                module for module in self.model.modules()
                if isinstance(module, PerformerAttention)
            ]
            for fast_attention in fast_attentions:
                fast_attention.redraw_projection_matrix()
            self.num_last_redraw = 0
            return
        self.num_last_redraw += 1


## 2.3. Hyperparameters

In [59]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
attn_kwargs = {'dropout': 0.5}
model = GPS(channels=64, pe_dim=8, num_layers=10, attn_type='performer',
            attn_kwargs=attn_kwargs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
                              min_lr=0.00001)


## 2.4. Train

In [60]:
def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        model.redraw_projection.redraw_projections()
        out = model(data.x, data.edge_index, data.edge_attr,
                    data.batch)
        loss = (out.squeeze() - data.y).abs().mean()
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_loader.dataset)



## 2.5. Test

In [61]:
@torch.no_grad()
def test(loader):
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        out = model(data.x,  data.edge_index, data.edge_attr,
                    data.batch)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(loader.dataset)



In [62]:
for epoch in range(1, 5):
    loss = train()
    val_mae = test(val_loader)
    test_mae = test(test_loader)
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')

Epoch: 01, Loss: 0.7243, Val: 0.5609, Test: 0.5996
Epoch: 02, Loss: 0.5899, Val: 0.5317, Test: 0.5483
Epoch: 03, Loss: 0.5511, Val: 0.5443, Test: 0.5921
Epoch: 04, Loss: 0.5264, Val: 0.4494, Test: 0.4924


# 3. Graphormer: Do Transformers Really Perform Bad for Graph Representation
<img src="image.png" alt="Sample Image" width="1000" height="500"/>



## 3.1. Centrality Encoding

In [38]:
from typing import Tuple
import torch
from torch import nn
from torch_geometric.utils import degree

class CentralityEncoding(nn.Module):
    def __init__(self, max_in_degree: int, max_out_degree: int, node_dim: int):
        """
        Initializes the CentralityEncoding module.

        :param max_in_degree: Maximum in-degree of nodes for encoding (limits the embedding size)
        :param max_out_degree: Maximum out-degree of nodes for encoding (limits the embedding size)
        :param node_dim: The dimensionality of node feature embeddings
        """
        super().__init__()
        self.max_in_degree = max_in_degree
        self.max_out_degree = max_out_degree
        self.node_dim = node_dim

        # Learnable parameters for encoding based on in-degrees and out-degrees
        # `z_in` encodes nodes based on in-degree; `z_out` encodes nodes based on out-degree
        self.z_in = nn.Parameter(torch.randn((max_in_degree, node_dim)))
        self.z_out = nn.Parameter(torch.randn((max_out_degree, node_dim)))

    def forward(self, x: torch.Tensor, edge_index: torch.LongTensor) -> torch.Tensor:
        """
        Forward pass to apply centrality encoding to node features.

        :param x: Node feature matrix (num_nodes x node_dim)
        :param edge_index: Edge index matrix (2 x num_edges) for adjacency representation
        :return: torch.Tensor, node embeddings after centrality encoding
        """
        num_nodes = x.shape[0]

        # Calculate in-degrees and out-degrees of nodes using `edge_index`
        # Limits the maximum in-degree and out-degree values to avoid exceeding embedding size

        in_degree = self.decrease_to_max_value(degree(index=edge_index[1], num_nodes=num_nodes).long(),
                                               self.max_in_degree - 1) 
        out_degree = self.decrease_to_max_value(degree(index=edge_index[0], num_nodes=num_nodes).long(),
                                                self.max_out_degree - 1)

        # Add in-degree and out-degree encodings to the node features
        x += self.z_in[in_degree] + self.z_out[out_degree] 

        return x

    def decrease_to_max_value(self, x, max_value):
        """
        Limits the maximum value in tensor x to `max_value`.

        :param x: Tensor with degree values (either in-degree or out-degree)
        :param max_value: Maximum allowable value in x
        :return: Modified tensor with values capped at max_value
        """
        x[x > max_value] = max_value
        return x


## 3.2. Spatial Encoding

In [23]:
import torch
from torch import nn

class SpatialEncoding(nn.Module):
    def __init__(self, max_path_distance: int):
        """
        Initializes the SpatialEncoding module.

        :param max_path_distance: Maximum pairwise distance between nodes to consider for encoding.
        """
        super().__init__()
        self.max_path_distance = max_path_distance

        # Learnable parameter vector `b` for different path distances.
        # It contains embeddings for each possible path length up to `max_path_distance`.
        self.b = nn.Parameter(torch.randn(self.max_path_distance))

    def forward(self, x: torch.Tensor, paths) -> torch.Tensor:
        """
        Computes the spatial encoding matrix based on pairwise node paths.

        :param x: Node feature matrix of shape (num_nodes, node_dim).
        :param paths: Dictionary containing pairwise node paths; paths[src][dst] gives the path from src to dst.
        :return: Spatial encoding matrix of shape (num_nodes, num_nodes).
        """
        # Initialize the spatial encoding matrix with zeros.
        # This matrix will store the spatial encoding value between each pair of nodes.
        spatial_matrix = torch.zeros((x.shape[0], x.shape[0])).to(next(self.parameters()).device)  # (num_nodes, num_nodes)

        # Iterate over all source nodes in the paths dictionary.
        for src in paths:
            # Iterate over all destination nodes reachable from the source node.
            for dst in paths[src]:
                # Calculate the length of the path from src to dst.
                path_length = len(paths[src][dst])

                # Cap the path length at `max_path_distance` to prevent indexing errors.
                capped_length = min(path_length, self.max_path_distance)

                # Subtract 1 because indexing starts at 0.
                index = capped_length - 1

                # Assign the corresponding embedding value from `b` to the spatial matrix entry.
                spatial_matrix[src][dst] = self.b[index]

        return spatial_matrix


## 3.3. Edge Encoding

In [39]:
import torch
from torch import nn

class EdgeEncoding(nn.Module):
    def __init__(self, edge_dim: int, max_path_distance: int):
        """
        Initializes the EdgeEncoding module.

        :param edge_dim: The dimensionality of edge feature embeddings.
        :param max_path_distance: The maximum path distance to consider for encoding between nodes.
        """
        super().__init__()
        self.edge_dim = edge_dim
        self.max_path_distance = max_path_distance
        
        # Learnable parameter `edge_vector` for encoding edges based on path distances.
        # `edge_vector` has a shape of (max_path_distance, edge_dim) where each entry corresponds
        # to an encoding for a specific path distance.
        self.edge_vector = nn.Parameter(torch.randn(self.max_path_distance, self.edge_dim))

    def forward(self, x: torch.Tensor, edge_attr: torch.Tensor, edge_paths) -> torch.Tensor:
        """
        Forward pass to apply edge encoding based on pairwise paths between nodes.

        :param x: Node feature matrix (num_nodes x node_dim).
        :param edge_attr: Edge feature matrix (num_edges x edge_dim).
        :param edge_paths: Dictionary of paths between node pairs (based on edge indexes),
                           where edge_paths[src][dst] contains the edges in the path from node `src` to `dst`.
        :return: torch.Tensor, Edge Encoding matrix (num_nodes x num_nodes).
        """
        # Initialize the edge encoding matrix `cij` with zeros.
        cij = torch.zeros((x.shape[0], x.shape[0])).to(next(self.parameters()).device)

        # Iterate over each source node in the edge_paths dictionary.
        for src in edge_paths:
            # For each destination node reachable from `src`, calculate an encoding based on path length.
            for dst in edge_paths[src]:
                # Retrieve the path between `src` and `dst` limited to `max_path_distance`.
                path_ij = edge_paths[src][dst][:self.max_path_distance]
                
                # Generate indices based on the length of `path_ij`.
                # `weight_inds` is a list of indices to select weights for each edge in `path_ij`.
                weight_inds = [i for i in range(len(path_ij))]
                
                # Calculate the dot product between `edge_vector` and `edge_attr` for edges in `path_ij`.
                # The mean is taken to create a single encoding value for the `src` to `dst` path.
                cij[src][dst] = self.dot_product(self.edge_vector[weight_inds], edge_attr[path_ij]).mean()

        # Replace any NaN values in the encoding matrix with zero.
        cij = torch.nan_to_num(cij)

        # Return the edge encoding matrix.
        return cij

    def dot_product(self, x1, x2) -> torch.Tensor:
        """
        Calculates the dot product between two tensors along dimension 1.

        :param x1: First tensor (subset of `edge_vector` for path distances).
        :param x2: Second tensor (subset of `edge_attr` for edges in a path).
        :return: Tensor of dot product results.
        """
        return (x1 * x2).sum(dim=1)  # Element-wise multiplication followed by sum along `dim=1`


## 3.4. Graphormer Attention Head

In [40]:
import torch
from torch import nn

class GraphormerAttentionHead(nn.Module):
    def __init__(self, dim_in: int, dim_q: int, dim_k: int, edge_dim: int, max_path_distance: int):
        """
        Initializes the GraphormerAttentionHead module.

        :param dim_in: Input dimensionality of node features.
        :param dim_q: Dimensionality for the query matrix.
        :param dim_k: Dimensionality for the key and value matrices.
        :param edge_dim: Dimensionality of edge features.
        :param max_path_distance: Maximum path distance for encoding in EdgeEncoding.
        """
        super().__init__()

        # Initialize the EdgeEncoding module for edge-based encodings
        self.edge_encoding = EdgeEncoding(edge_dim, max_path_distance)

        # Linear layers to project node features into query, key, and value matrices
        self.q = nn.Linear(dim_in, dim_q)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_k)

    def forward(self,
                x: torch.Tensor,
                edge_attr: torch.Tensor,
                b: torch.Tensor,
                edge_paths,
                ptr=None) -> torch.Tensor:
        """
        Forward pass for attention calculation.

        :param x: Node feature matrix (num_nodes x dim_in).
        :param edge_attr: Edge feature matrix (num_edges x edge_dim).
        :param b: Spatial encoding matrix (num_nodes x num_nodes).
        :param edge_paths: Dictionary of paths between node pairs, with paths based on edge indexes.
        :param ptr: Optional batch pointer, specifying graph indices within batches of graphs.
        :return: torch.Tensor, updated node embeddings after the attention operation.
        """
        # Initialize batch masks:
        # - batch_mask_neg_inf (with -inf values) for attention masking.
        # - batch_mask_zeros (with zeros) for masking in softmax calculations.
        batch_mask_neg_inf = torch.full(size=(x.shape[0], x.shape[0]), fill_value=-1e6).to(next(self.parameters()).device)
        batch_mask_zeros = torch.zeros(size=(x.shape[0], x.shape[0])).to(next(self.parameters()).device)

        # Check if `ptr` is None (single graph). If so, set all mask values to 1.
        if type(ptr) == type(None):
            batch_mask_neg_inf = torch.ones(size=(x.shape[0], x.shape[0])).to(next(self.parameters()).device)
            batch_mask_zeros += 1
        else:
            # Otherwise, create batch masks based on the graph boundaries specified by `ptr`.
            for i in range(len(ptr) - 1):
                batch_mask_neg_inf[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = 1
                batch_mask_zeros[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = 1

        # Project node features `x` into query, key, and value matrices
        query = self.q(x)
        key = self.k(x)
        value = self.v(x)

        # Compute edge-based encoding using `EdgeEncoding`
        c = self.edge_encoding(x, edge_attr, edge_paths)

        # Calculate the attention scores by combining query and key
        a = self.compute_a(key, query, ptr)

        # Combine attention scores with spatial encoding `b` and edge encoding `c`, and apply masking
        a = (a + b + c) * batch_mask_neg_inf

        # Apply softmax to the attention scores along the last dimension and re-apply mask
        softmax = torch.softmax(a, dim=-1) * batch_mask_zeros  # e^(-inf) results in 0

        # Compute the new node embeddings by applying attention scores to the value matrix
        x = softmax.mm(value)

        return x

    def compute_a(self, key, query, ptr=None):
        """
        Computes the normalized dot-product attention between query and key.

        :param key: Key matrix for nodes.
        :param query: Query matrix for nodes.
        :param ptr: Optional batch pointer specifying graph indices within batches of graphs.
        :return: torch.Tensor, attention scores for the nodes.
        """
        # If `ptr` is None (single graph), compute attention for all nodes in the graph
        if type(ptr) == type(None):
            a = query.mm(key.transpose(0, 1)) / query.size(-1) ** 0.5
        else:
            # Otherwise, compute attention scores for each graph separately in the batch
            a = torch.zeros((query.shape[0], query.shape[0]), device=key.device)
            for i in range(len(ptr) - 1):
                a[ptr[i]:ptr[i + 1], ptr[i]:ptr[i + 1]] = query[ptr[i]:ptr[i + 1]].mm(
                    key[ptr[i]:ptr[i + 1]].transpose(0, 1)) / query.size(-1) ** 0.5

        return a


## 3.4. Graphormer Multihead Attention

In [41]:
class GraphormerMultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int, edge_dim: int, max_path_distance: int):
        """
        :param num_heads: number of attention heads
        :param dim_in: node feature matrix input number of dimension
        :param dim_q: query node feature matrix input number dimension
        :param dim_k: key node feature matrix input number of dimension
        :param edge_dim: edge feature matrix number of dimension
        """
        super().__init__()
        self.heads = nn.ModuleList(
            [GraphormerAttentionHead(dim_in, dim_q, dim_k, edge_dim, max_path_distance) for _ in range(num_heads)]
        )
        self.linear = nn.Linear(num_heads * dim_k, dim_in)

    def forward(self,
                x: torch.Tensor,
                edge_attr: torch.Tensor,
                b: torch.Tensor,
                edge_paths,
                ptr) -> torch.Tensor:
        """
        :param x: node feature matrix
        :param edge_attr: edge feature matrix
        :param b: spatial Encoding matrix
        :param edge_paths: pairwise node paths in edge indexes
        :param ptr: batch pointer that shows graph indexes in batch of graphs
        :return: torch.Tensor, node embeddings after all attention heads
        """
        return self.linear(
            torch.cat([
                attention_head(x, edge_attr, b, edge_paths, ptr) for attention_head in self.heads
            ], dim=-1)
        )

## 3.5. Graphormer Encoder Layer

In [42]:
class GraphormerEncoderLayer(nn.Module):
    def __init__(self, node_dim, edge_dim, num_heads, max_path_distance):
        """
        :param node_dim: node feature matrix input number of dimension
        :param edge_dim: edge feature matrix input number of dimension
        :param num_heads: number of attention heads
        """
        super().__init__()

        self.node_dim = node_dim
        self.edge_dim = edge_dim
        self.num_heads = num_heads

        self.attention = GraphormerMultiHeadAttention(
            dim_in=node_dim,
            dim_k=node_dim,
            dim_q=node_dim,
            num_heads=num_heads,
            edge_dim=edge_dim,
            max_path_distance=max_path_distance,
        )
        self.ln_1 = nn.LayerNorm(node_dim)
        self.ln_2 = nn.LayerNorm(node_dim)
        self.ff = nn.Linear(node_dim, node_dim)

    def forward(self,
                x: torch.Tensor,
                edge_attr: torch.Tensor,
                b: torch,
                edge_paths,
                ptr) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        h′(l) = MHA(LN(h(l−1))) + h(l−1)
        h(l) = FFN(LN(h′(l))) + h′(l)

        :param x: node feature matrix
        :param edge_attr: edge feature matrix
        :param b: spatial Encoding matrix
        :param edge_paths: pairwise node paths in edge indexes
        :param ptr: batch pointer that shows graph indexes in batch of graphs
        :return: torch.Tensor, node embeddings after Graphormer layer operations
        """
        x_prime = self.attention(self.ln_1(x), edge_attr, b, edge_paths, ptr) + x
        x_new = self.ff(self.ln_2(x_prime)) + x_prime

        return x_new

## 3.6. Shortest path

In [43]:
from typing import Union, Tuple, Dict, List

import torch
import networkx as nx
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn import global_mean_pool
from torch_geometric.utils.convert import to_networkx

def floyd_warshall_source_to_all(G, source, cutoff=None):
    """
    Computes the shortest paths from a source node to all other nodes in graph `G`.
    Uses a modified Floyd-Warshall algorithm for single-source shortest paths.

    :param G: NetworkX graph.
    :param source: Source node from which to calculate paths.
    :param cutoff: Maximum path length to consider.
    :return: Tuple of dictionaries, `node_paths` and `edge_paths`, where:
             - `node_paths` contains the shortest path of nodes from source to each node.
             - `edge_paths` contains the list of edges (by edge index) in each path.
    """
    if source not in G:
        raise nx.NodeNotFound("Source {} not in G".format(source))

    # Mapping each edge to a unique index for path tracking
    edges = {edge: i for i, edge in enumerate(G.edges())}

    level = 0  # The current level of traversal
    nextlevel = {source: 1}  # Nodes to explore at the next level
    node_paths = {source: [source]}  # Shortest paths (in nodes) from source to each reachable node
    edge_paths = {source: []}  # Shortest paths (in edges) from source to each reachable node

    # BFS to explore each level until no new nodes or cutoff level is reached
    while nextlevel:
        thislevel = nextlevel
        nextlevel = {}
        for v in thislevel:
            # Iterate through each neighbor `w` of node `v`
            for w in G[v]:
                if w not in node_paths:  # Only consider unvisited nodes
                    # Update node paths by appending `w` to path from source to `v`
                    node_paths[w] = node_paths[v] + [w]
                    # Append edge index to path from source to `w`
                    edge_paths[w] = edge_paths[v] + [edges[tuple(node_paths[w][-2:])]]
                    # Queue `w` for exploration in the next level
                    nextlevel[w] = 1

        level += 1

        # Stop if the cutoff distance is reached
        if (cutoff is not None and cutoff <= level):
            break

    return node_paths, edge_paths


def all_pairs_shortest_path(G) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """
    Computes shortest paths between all pairs of nodes in graph `G`.

    :param G: NetworkX graph.
    :return: Tuple of dictionaries `node_paths` and `edge_paths`:
             - `node_paths` contains shortest paths of nodes between each pair.
             - `edge_paths` contains shortest paths in terms of edge indices between each pair.
    """
    # Compute shortest paths from each node to all other nodes
    paths = {n: floyd_warshall_source_to_all(G, n) for n in G}
    # Separate node and edge paths into separate dictionaries
    node_paths = {n: paths[n][0] for n in paths}
    edge_paths = {n: paths[n][1] for n in paths}
    return node_paths, edge_paths


def shortest_path_distance(data: Data) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """
    Computes shortest paths between all pairs of nodes in a single PyTorch Geometric graph `data`.

    :param data: PyTorch Geometric `Data` object representing the graph.
    :return: Tuple of dictionaries `node_paths` and `edge_paths` containing shortest paths.
    """
    # Convert PyTorch Geometric data to a NetworkX graph
    G = to_networkx(data)
    # Compute all pairs shortest paths in terms of node and edge paths
    node_paths, edge_paths = all_pairs_shortest_path(G)
    return node_paths, edge_paths


def batched_shortest_path_distance(data) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """
    Computes shortest paths between all pairs of nodes for a batch of PyTorch Geometric graphs.

    :param data: PyTorch Geometric batch object containing multiple graphs.
    :return: Tuple of dictionaries `node_paths` and `edge_paths` for all graphs in the batch.
    """
    # Convert each graph in the batch to a NetworkX graph and collect in a list
    graphs = [to_networkx(sub_data) for sub_data in data.to_data_list()]
    relabeled_graphs = []
    shift = 0  # Track node ID shifts to ensure unique node IDs across graphs

    # Relabel nodes in each graph to ensure unique node IDs across all graphs in the batch
    for i in range(len(graphs)):
        num_nodes = graphs[i].number_of_nodes()
        # Shift the node IDs of the current graph by `shift`
        relabeled_graphs.append(nx.relabel_nodes(graphs[i], {i: i + shift for i in range(num_nodes)}))
        shift += num_nodes

    # Compute all pairs shortest paths for each relabeled graph in the batch
    paths = [all_pairs_shortest_path(G) for G in relabeled_graphs]
    node_paths = {}
    edge_paths = {}

    # Aggregate node and edge paths for each graph into dictionaries
    for path in paths:
        # Update node_paths and edge_paths with results from each graph
        for k, v in path[0].items():
            node_paths[k] = v
        for k, v in path[1].items():
            edge_paths[k] = v

    return node_paths, edge_paths


## 3.7. Graphormer

In [30]:
class Graphormer(nn.Module):
    def __init__(self, config, num_node_features, num_edge_features):
        """
        :param config: dictionary of configuration parameters
        :param num_node_features: number of node features
        :param num_edge_features: number of edge features
        """
        super().__init__()

        self.num_layers = config['num_layers']
        self.input_node_dim = num_node_features
        self.node_dim = config['node_dim']
        self.input_edge_dim = num_edge_features
        self.edge_dim = config['edge_dim']
        self.output_dim = config['output_dim']
        self.num_heads = config['num_heads']
        self.max_in_degree = config['max_in_degree']
        self.max_out_degree = config['max_out_degree']
        self.max_path_distance = config['max_path_distance']

        self.node_in_lin = nn.Linear(self.input_node_dim, self.node_dim)
        self.edge_in_lin = nn.Linear(self.input_edge_dim, self.edge_dim)

        self.centrality_encoding = CentralityEncoding(
            max_in_degree=self.max_in_degree,
            max_out_degree=self.max_out_degree,
            node_dim=self.node_dim
        )

        self.spatial_encoding = SpatialEncoding(
            max_path_distance=self.max_path_distance,
        )

        self.layers = nn.ModuleList([
            GraphormerEncoderLayer(
                node_dim=self.node_dim,
                edge_dim=self.edge_dim,
                num_heads=self.num_heads,
                max_path_distance=self.max_path_distance) for _ in range(self.num_layers)
        ])

        self.node_out_lin = nn.Linear(self.node_dim, self.output_dim)

    def forward(self, data):
        x = data.x.float()
        edge_index = data.edge_index.long()
        edge_attr = data.edge_attr.float()

        if isinstance(data, Data):
            ptr = None
            node_paths, edge_paths = shortest_path_distance(data)
        else:
            ptr = data.ptr
            node_paths, edge_paths = batched_shortest_path_distance(data)

        x = self.node_in_lin(x)
        edge_attr = self.edge_in_lin(edge_attr)

        x = self.centrality_encoding(x, edge_index)
        b = self.spatial_encoding(x, node_paths)

        for layer in self.layers:
            x = layer(x, edge_attr, b, edge_paths, ptr)

        x = self.node_out_lin(x)
        x = global_mean_pool(x, data.batch)

        return x

## 3.8 Load MoleculeNet dataset

In [31]:
import random
from torch_geometric.datasets import MoleculeNet
from torch_geometric.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

def load_ESOL():
    dataset = MoleculeNet(root='Data/MoleculeNet', name='ESOL')
    train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
    return train_loader, test_loader, dataset.num_node_features, dataset.num_edge_features


## 3.9. Train and test

In [35]:
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2024.3.5-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.9 kB)
Downloading rdkit-2024.3.5-cp312-cp312-manylinux_2_28_x86_64.whl (33.1 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m33.1/33.1 MB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m0m eta [36m0:00:01[0m[36m0:00:01[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.3.5


In [36]:
def calculate_metrics(y_true, y_pred):
    mae = mean_absolute_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    return mae, mse, r2

def train(model, train_loader, device, criterion, optimizer):
    model.train()
    total_loss = 0
    y_true, y_pred = [], []

    for data in tqdm(train_loader, desc="Training"):
        data = data.to(device)  # Move data batch to GPU
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, data.y.to(device))  # Ensure target is also on GPU
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        y_true.extend(data.y.cpu().numpy())
        y_pred.extend(outputs.cpu().detach().numpy())

    avg_loss = total_loss / len(train_loader.dataset)
    mae, mse, r2 = calculate_metrics(y_true, y_pred)
    print(f"Train Loss: {avg_loss:.4f}, MAE: {mae:.4f}, MSE: {mse:.4f}, R2: {r2:.4f}")

def test(model, test_loader, device, criterion):
    model.eval()
    total_loss = 0
    y_true, y_pred = [], []

    with torch.no_grad():
        for data in tqdm(test_loader, desc="Testing"):
            data = data.to(device)  # Move data batch to GPU
            outputs = model(data)
            loss = criterion(outputs, data.y.to(device))  # Ensure target is also on GPU
            total_loss += loss.item()
            y_true.extend(data.y.cpu().numpy())
            y_pred.extend(outputs.cpu().numpy())

    avg_loss = total_loss / len(test_loader.dataset)
    mae, mse, r2 = calculate_metrics(y_true, y_pred)
    print(f"Test Loss: {avg_loss:.4f}, MAE: {mae:.4f}, MSE: {mse:.4f}, R2: {r2:.4f}")


## 3.10. Set hyperparameter and train

In [37]:
config = {
    "num_layers": 2,
    "node_dim": 128,
    "edge_dim": 128,
    "output_dim": 1,
    "num_heads": 4,
    "max_in_degree": 5,
    "max_out_degree": 5,
    "max_path_distance": 5
}

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
random.seed(42)
torch.manual_seed(42)

# Load data and initialize model
train_loader, test_loader, num_node_features, num_edge_features = load_ESOL()
model = Graphormer(config, num_node_features, num_edge_features).to(device)
criterion = nn.L1Loss().to(device)  # Move criterion to GPU
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)

# Training and testing
train(model, train_loader, device, criterion, optimizer)
test(model, test_loader, device, criterion)

Processing...
Done!
Training: 100%|██████████| 15/15 [10:07<00:00, 40.48s/it]


Train Loss: 0.0286, MAE: 1.7572, MSE: 5.0952, R2: -0.1830


Testing: 100%|██████████| 4/4 [00:42<00:00, 10.73s/it]

Test Loss: 0.0305, MAE: 1.7575, MSE: 4.4930, R2: 0.0495



