In [2]:
# Combining Attention values across heads - Avg
# Combining Attention values across layers - Matrix Multiply

# Setup

In [3]:
# !pip install dgl torch_geometric torch

# Install required python libraries
import os

# Install PyTorch Geometric and other libraries
# if 'IS_GRADESCOPE_ENV' not in os.environ:
#     print("Installing PyTorch Geometric")
#     !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
#     !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
#     !pip install -q torch-geometric
#     print("Installing other libraries")
#     !pip install networkx
#     !pip install lovely-tensors

In [4]:
import os
import sys
import time
import math
import random
import itertools
from datetime import datetime
from typing import Mapping, Tuple, Sequence, List

import pandas as pd
import networkx as nx
import numpy as np
import scipy as sp

from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ReLU, BatchNorm1d, LayerNorm, Module, ModuleList, Sequential
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
from torch.optim import Adam

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import Planetoid

import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, dense_to_sparse, to_dense_batch, to_dense_adj

from torch_geometric.nn import GCNConv, GATConv, GATv2Conv

# from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum

import lovely_tensors as lt
lt.monkey_patch()

import matplotlib.pyplot as plt
import seaborn as sns

# import warnings
# warnings.filterwarnings("ignore", category=RuntimeWarning)
# warnings.filterwarnings("ignore", category=UserWarning)
# warnings.filterwarnings("ignore", category=FutureWarning)

print("All imports succeeded.")
print("Python version {}".format(sys.version))
print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))



All imports succeeded.
Python version 3.11.7 (tags/v3.11.7:fa7a6f2, Dec  4 2023, 19:24:49) [MSC v.1937 64 bit (AMD64)]
PyTorch version 2.6.0+cu118
PyG version 2.6.1


In [5]:
# Set random seed for deterministic results

def seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed(0)
print("All seeds set.")

All seeds set.


In [6]:
print("Cuda available: {}".format(torch.cuda.is_available()))

Cuda available: True


# Datasets

In [7]:
# use igraph because much faster than networkx
import igraph as ig
from tqdm import tqdm

def add_position_info(dataset, pe_dim=16):
    
    processed_dataset = []
    pe_transform = T.AddLaplacianEigenvectorPE(k=pe_dim, attr_name='pos_enc')

    for data in tqdm(dataset, total=len(dataset), desc='Processing disjoint graphs'):
        num_nodes = data.num_nodes
        edge_index = data.edge_index

        # Convert to igraph
        edges = edge_index.t().tolist()
        g = ig.Graph(n=num_nodes, edges=edges, directed=False)

        # compute shortest path distances
        sp_matrix = torch.tensor(g.distances(algorithm="johnson"), dtype=torch.float16)
        sp_matrix[torch.isinf(sp_matrix)] = 0
        data.dense_sp_matrix = sp_matrix

        dense_adj = to_dense_adj(edge_index, max_num_nodes=num_nodes)[0]
        dense_adj = dense_adj + torch.eye(num_nodes, dtype=dense_adj.dtype)
        dense_adj[dense_adj == 2] = 1  # remove double self-loops
        data.dense_adj = dense_adj

        # add Laplacian eigenvectors as positional encoding
        data = pe_transform(data)

        processed_dataset.append(data)

    return processed_dataset

In [8]:
from torch_geometric.datasets import PPI
from torch_geometric.loader import DataLoader

train = add_position_info(PPI(root='/tmp/PPI', split='train'))
val = add_position_info(PPI(root='/tmp/PPI', split='val'))
test = add_position_info(PPI(root='/tmp/PPI', split='test'))

train = DataLoader(train, batch_size=1, shuffle=True)
val = DataLoader(val, batch_size=1)
test = DataLoader(test, batch_size=1)

Processing disjoint graphs: 100%|██████████| 20/20 [00:13<00:00,  1.52it/s]
Processing disjoint graphs: 100%|██████████| 2/2 [00:02<00:00,  1.24s/it]
Processing disjoint graphs: 100%|██████████| 2/2 [00:01<00:00,  1.13it/s]


In [9]:
data = next(iter(train))

In [10]:
ex_spd = data.dense_sp_matrix
ex_spd.shape

torch.Size([3021, 3021])

In [11]:
tensor_memory_bytes = ex_spd.element_size() * ex_spd.numel()
tensor_memory_MB = tensor_memory_bytes / (1024 ** 2)  # Convert to MB
print(f"Tensor memory usage: {tensor_memory_MB:.2f} MB")

Tensor memory usage: 17.41 MB


In [12]:
data

DataBatch(x=[3021, 50], edge_index=[2, 91338], y=[3021, 121], dense_sp_matrix=[3021, 3021], dense_adj=[3021, 3021], pos_enc=[3021, 16], batch=[3021], ptr=[2])

## Table 1: Dataset Statistics

In [13]:
# ### Table 1 ###
# ### Dataset Statistics ###
# import dgl
# Homophily_Levels = []
# 
# for data in train:
#   edge_index_tensor = torch.tensor(data.edge_index.cpu().numpy(), dtype=torch.long)
#   g = dgl.graph((edge_index_tensor[0], edge_index_tensor[1]), num_nodes=data.x.shape[0])
#   g.ndata['y'] = torch.tensor(data.y.cpu().numpy(), dtype=torch.long)
#   Homophily_Levels.append({'Node Homophily':dgl.node_homophily(g, g.ndata['y'])*100,
#                                 'Edge Homophily':dgl.edge_homophily(g, g.ndata['y'])*100,
#                                 'Adjusted Homophily':dgl.adjusted_homophily(g, g.ndata['y'])*100,
#                                 'Number of Nodes': int(g.num_nodes()),
#                                 'Number of Edges': int(g.num_edges())
#                                 })
# df = pd.DataFrame(Homophily_Levels).round(1)
# df

# Models

In [None]:
# PyG example code: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_cora.py

class GNNModel(Module):

    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            hidden_dim: int = HIDDEN,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)
        self.layers = ModuleList()

        for layer in range(num_layers):
            self.layers.append(
                GCNConv(hidden_dim, hidden_dim)
            )
        self.dropout = dropout

    def forward(self, x, edge_index):

        x = self.lin_in(x)

        for layer in self.layers:
            # conv -> activation ->  dropout -> residual
            x_in = x
            x = layer(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

        x = self.lin_out(x)
        # return x.log_softmax(dim=-1)
        return x


class SparseGraphTransformerModel(Module):
    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            hidden_dim: int = HIDDEN,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )
        self.dropout = dropout
        self.save_attn = False

    def forward(self, x, dense_adj):

        x = self.lin_in(x)

        self.attn_weights_list = []

        for layer in self.layers:
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                attn_mask = ~dense_adj.bool(),
                average_attn_weights = True # the paper already averages so may as well just average here...
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            if self.save_attn:
                self.attn_weights_list.append(attn_weights)

        x = self.lin_out(x)

        # return x.log_softmax(dim=-1)
        return x

class DenseGraphTransformerModel(Module):

    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            pos_enc_dim: int = 16,
            hidden_dim: int = HIDDEN,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_pos_enc = Linear(pos_enc_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )

        self.attn_bias_scale = torch.nn.Parameter(torch.tensor([10.0]))  # controls how much we initially bias our model to nearby nodes
        self.dropout = dropout
        self.save_attn = False

    def forward(self, x, pos_enc, dense_sp_matrix):

        # x = self.lin_in(x) + self.lin_pos_enc(pos_enc)
        x = self.lin_in(x)  # no node positional encoding

        # attention bias
        # [i, j] -> inverse of shortest path distance b/w node i and j
        # diagonals -> self connection, set to 0
        # disconnected nodes -> -1
        attn_bias = self.attn_bias_scale * torch.nan_to_num(
            (1 / (torch.nan_to_num(dense_sp_matrix, nan=-1, posinf=-1, neginf=-1))),
            nan=0, posinf=0, neginf=0)
        #attn_bias = torch.ones_like(attn_bias)

        # TransformerEncoder
        # x = self.encoder(x, mask = attn_bias)

        self.attn_weights_list = []

        for layer in self.layers:
            # MHSA layer
            # float mask adds learnable additive attention bias
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                attn_mask = attn_bias,
                average_attn_weights = True    # the paper already averages so may as well just average here...
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            if self.save_attn:
                self.attn_weights_list.append(attn_weights)

        x = self.lin_out(x)
        # return x.log_softmax(dim=-1)
        return x


class DenseGraphTransformerModel_V2(Module):
    def __init__(
            self,
            in_dim: int = data.x.shape[-1],
            pos_enc_dim: int = 16,
            hidden_dim: int = HIDDEN,
            num_heads: int = 1,
            num_layers: int = 1,
            out_dim: int = len(data.y.unique()),
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_pos_enc = Linear(pos_enc_dim, hidden_dim)
        self.lin_out = Linear(hidden_dim, out_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )

        self.attn_bias_scale = torch.nn.Parameter(torch.tensor([10.0]))  # controls how much we initially bias our model to nearby nodes
        self.dropout = dropout
        self.save_attn = False

    def forward(self, x, pos_enc, dense_sp_matrix):

        x = self.lin_in(x) + self.lin_pos_enc(pos_enc)

        self.attn_weights_list = []

        for layer in self.layers:

            # MHSA layer
            # float mask adds learnable additive attention bias
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                average_attn_weights = True # the paper already averages so may as well just average here...
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            if self.save_attn:
                self.attn_weights_list.append(attn_weights)

        x = self.lin_out(x)

        # return x.log_softmax(dim=-1)
        return x

# Trainers

In [None]:
from sklearn.metrics import f1_score

In [None]:
def Train_GCN(NUM_LAYERS, NUM_HEADS, train_loader, val_loader, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    IN_DIM = train_loader.dataset[0].x.shape[1]
    OUT_DIM = train_loader.dataset[0].y.shape[1]  # 121 for PPI

    model = GNNModel(num_layers=NUM_LAYERS, num_heads=NUM_HEADS, in_dim=IN_DIM, out_dim=OUT_DIM).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    
    loss_fn = torch.nn.BCEWithLogitsLoss()
    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=.5)

    def train():
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.edge_index)
            loss = loss_fn(out, batch.y.float())  # y is multi-label
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    def test(loader):
        model.eval()
        y_true, y_pred = [], []
        
        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index)          # or …dense_adj / …spd
            probs  = torch.sigmoid(logits)                     # convert logits → probabilities
            preds  = (probs > 0.5).cpu().numpy().astype(int)   # threshold at 0.5
            y_true.append(batch.y.cpu().numpy())
            y_pred.append(preds)

        y_true = np.vstack(y_true)
        y_pred = np.vstack(y_pred)
        micro_f1 = f1_score(y_true, y_pred, average="micro")
        return micro_f1

    best_val_f1 = test_f1 = 0
    times = []

    for epoch in range(1, EPOCHS):
        start = time.time()
        loss = train()
        scheduler.step()
        train_f1 = test(train_loader)
        val_f1 = test(val_loader)
        tmp_test_f1 = test(test_loader)

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            test_f1 = tmp_test_f1

        times.append(time.time() - start)
        
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss:.4f}, Train: {train_f1:.4f}, Val: {val_f1:.4f}, Test: {test_f1:.4f}')

    return {
        'train_f1': train_f1,
        'val_f1': val_f1,
        'test_f1': test_f1
    }, None


def Train_SparseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, train_loader, val_loader, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dimensions from dataset
    IN_DIM = train_loader.dataset[0].x.shape[1]
    OUT_DIM = train_loader.dataset[0].y.shape[1]

    model = SparseGraphTransformerModel(
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        in_dim=IN_DIM,
        out_dim=OUT_DIM
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=.5)

    def train():
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.dense_adj)
            loss = loss_fn(out, batch.y.float())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    def test(loader):
        model.eval()
        y_true, y_pred = [], []

        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.dense_adj)          # or …dense_adj / …spd
            probs  = torch.sigmoid(logits)                     # convert logits → probabilities
            preds  = (probs > 0.5).cpu().numpy().astype(int)   # threshold at 0.5
            y_true.append(batch.y.cpu().numpy())
            y_pred.append(preds)

        y_true = np.vstack(y_true)
        y_pred = np.vstack(y_pred)
        micro_f1 = f1_score(y_true, y_pred, average="micro")
        return micro_f1

    best_val_f1 = test_f1 = 0
    times = []

    for epoch in range(1, EPOCHS):
        start = time.time()
        loss = train()
        scheduler.step()
        train_f1 = test(train_loader)
        val_f1 = test(val_loader)
        tmp_test_f1 = test(test_loader)

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            test_f1 = tmp_test_f1

        times.append(time.time() - start)

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss:.4f}, Train: {train_f1:.4f}, Val: {val_f1:.4f}, Test: {test_f1:.4f}')

    return {
        'train_f1': train_f1,
        'val_f1': val_f1,
        'test_f1': test_f1
    }, model.cpu()

def Train_DenseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, train_loader, val_loader, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    IN_DIM = train_loader.dataset[0].x.shape[1]
    OUT_DIM = train_loader.dataset[0].y.shape[1]

    model = DenseGraphTransformerModel(
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        in_dim=IN_DIM,
        out_dim=OUT_DIM
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=.5)
    
    def train():
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.pos_enc, batch.dense_sp_matrix)  # batch.spd is dense_sp_matrix
            loss = loss_fn(out, batch.y.float())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    def test(loader):
        model.eval()
        y_true, y_pred = [], []

        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.pos_enc, batch.dense_sp_matrix)          # or …dense_adj / …spd
            probs  = torch.sigmoid(logits)                     # convert logits → probabilities
            preds  = (probs > 0.5).cpu().numpy().astype(int)   # threshold at 0.5
            y_true.append(batch.y.cpu().numpy())
            y_pred.append(preds)

        y_true = np.vstack(y_true)
        y_pred = np.vstack(y_pred)
        micro_f1 = f1_score(y_true, y_pred, average="micro")
        return micro_f1

    best_val_f1 = test_f1 = 0
    times = []

    for epoch in range(1, EPOCHS):
        start = time.time()
        loss = train()
        scheduler.step()
        train_f1 = test(train_loader)
        val_f1 = test(val_loader)
        tmp_test_f1 = test(test_loader)

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            test_f1 = tmp_test_f1

        times.append(time.time() - start)

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss:.4f}, Train: {train_f1:.4f}, Val: {val_f1:.4f}, Test: {test_f1:.4f}')

    return {
        'train_f1': train_f1,
        'val_f1': val_f1,
        'test_f1': test_f1
    }, model.cpu()
    
    # Notes
    # - Dense Transformer needs to be trained for a bit longer to reach low loss value
    # - Node positional encodings are not particularly useful
    # - Edge distance encodings are very useful
    # - Since Cora is highly homophilic, it is important to bias the attention towards nearby nodes

def Train_DenseGraphTransformerModel_V2(NUM_LAYERS, NUM_HEADS, train_loader, val_loader, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    IN_DIM = train_loader.dataset[0].x.shape[1]
    OUT_DIM = train_loader.dataset[0].y.shape[1]

    model = DenseGraphTransformerModel_V2(
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        in_dim=IN_DIM,
        out_dim=OUT_DIM
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=.5)
    
    def train():
        model.train()
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch.x, batch.pos_enc, batch.dense_sp_matrix)  # batch.spd is dense_sp_matrix
            loss = loss_fn(out, batch.y.float())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        return total_loss / len(train_loader)

    def test(loader):
        model.eval()
        y_true, y_pred = [], []

        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.pos_enc, batch.dense_sp_matrix)          # or …dense_adj / …spd
            probs  = torch.sigmoid(logits)                     # convert logits → probabilities
            preds  = (probs > 0.5).cpu().numpy().astype(int)   # threshold at 0.5
            y_true.append(batch.y.cpu().numpy())
            y_pred.append(preds)

        y_true = np.vstack(y_true)
        y_pred = np.vstack(y_pred)
        micro_f1 = f1_score(y_true, y_pred, average="micro")
        return micro_f1

    best_val_f1 = test_f1 = 0
    times = []

    for epoch in range(1, EPOCHS):
        start = time.time()
        loss = train()
        scheduler.step()
        train_f1 = test(train_loader)
        val_f1 = test(val_loader)
        tmp_test_f1 = test(test_loader)

        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            test_f1 = tmp_test_f1

        times.append(time.time() - start)

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss:.4f}, Train: {train_f1:.4f}, Val: {val_f1:.4f}, Test: {test_f1:.4f}')

    return {
        'train_f1': train_f1,
        'val_f1': val_f1,
        'test_f1': test_f1
    }, model.cpu()  # if model tracks attention weights

# Training

## Training

In original paper, the models are re-trained across multiple runs with varying train/test splits. The attention weights are then averaged across runs to get a more robust attention representation that is less affected by train/test splits. 

However, in the inductive setting (20 different graphs), the analysis would need to be averaged for each graph separately, making this computationally expensive.

Additionally, in the inductive setting, the model may be intrinsically more stable (relative to train/test splits) due to generalizing across multiple graphs.

Therefore, we only train a single model for this task and analyze each graph-model combination separately.

In [80]:
EPOCHS = 10
LR = 2e-3
HIDDEN = 512
NUM_LAYERS = 4
NUM_HEADS = 4

accuracy_statistics = {}
models = {}

# Train all models with the same loaders
accuracy_statistics['GCN'], _ = Train_GCN(NUM_LAYERS, NUM_HEADS, train, val, test)
accuracy_statistics['SparseGraphTransformerModel'], models['SparseGraphTransformerModel'] = Train_SparseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, train, val, test)
accuracy_statistics['DenseGraphTransformerModel'], models['DenseGraphTransformerModel'] = Train_DenseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, train, val, test)
accuracy_statistics['DenseGraphTransformerModel_V2'], models['DenseGraphTransformerModel_V2'] = Train_DenseGraphTransformerModel_V2(NUM_LAYERS, NUM_HEADS, train, val, test)

accuracy_statistics

{'SparseGraphTransformerModel': {'train_f1': 0.6863282265028491,
  'val_f1': 0.6614007148671126,
  'test_f1': 0.6757455538923722}}

## Table 2: Accuracy Statistics

In [105]:
### Table 2 ###
### Accuracy Statistics ###
all_stats_df = pd.DataFrame(accuracy_statistics).T.round(2)
all_stats_df

Unnamed: 0,train_f1,val_f1,test_f1
SparseGraphTransformerModel,0.69,0.66,0.68


## Create Attention Graphs

In [87]:
def create_graph_from_attention(attention_matrix, threshold):
    """Creates a directed graph from an attention matrix."""
    attention_matrix = attention_matrix.numpy() 
    
    # mask out weights below the threshold
    attention_matrix = attention_matrix * (attention_matrix > threshold)
    
    G = nx.from_numpy_array(attention_matrix, create_using=nx.DiGraph())
    
    return G

def get_threshold(attention_matrix, model_name=None):
    """Returns threshold using different percentiles based on model."""
    attn_values = attention_matrix.flatten().cpu().numpy()

    # percentile = 99.5 if model_name == "dt2_avg" else 90
    percentile_threshold = np.percentile(attn_values, 99.5)
    mean_threshold = attn_values.mean() + 1.5 * attn_values.std()

    return max(percentile_threshold, mean_threshold)


In [94]:
from itertools import chain
from torch_geometric.utils import to_networkx

combined_loader = chain(train, val, test)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

graphs = {}

# create attention graph for each graph
for model_name in models.keys():
    model = models[model_name]
    graphs[model_name] = {} 
    
    model.save_attn = True
    
    model = model.to(device)
    for idx, g in enumerate(combined_loader):
        g = g.to(device)

        model.eval() 
        with torch.no_grad():
            # make prediction to save attention weights 
            if model_name == 'SparseGraphTransformerModel':
                model(g.x, g.dense_adj)
            else:
                model(g.x, g.pos_enc, g.dense_sp_matrix)
            
        # construct attention graph
        attn_weights = model.attn_weights_list
        
        # perform matrix multiplication across layers to aggregate graphs 
        attn_graph = attn_weights[0]
        
        for attn in attn_weights[1:]:
            attn_graph = attn_graph @ attn
        
        attn_graph = attn_graph.cpu()
        
        threshold = get_threshold(attn_graph)
        attn_graph = create_graph_from_attention(attn_graph, threshold)
        
        # convert to directed to match attention format
        og_graph = to_networkx(g, to_undirected=False)
        
        graphs[model_name][idx] = {
            "original_graph": og_graph,
            "attention_graph": attn_graph
        }

# Analysis


In [131]:
from scipy.stats import spearmanr

def compute_metrics(G):
    return pd.DataFrame.from_dict({
        'degree': nx.degree_centrality(G),
        'betweenness': nx.betweenness_centrality(G, normalized=True),
        'closeness': nx.closeness_centrality(G),
        'eigenvector': nx.eigenvector_centrality(G, max_iter=500, tol=1e-02),
        'clustering': nx.clustering(G),
        'pagerank': nx.pagerank(G, alpha=0.85)
    })

def compute_cross_metric_correlation(metrics1, metrics2):
    metrics1 = pd.DataFrame(metrics1)
    metrics2 = pd.DataFrame(metrics2)
    
    result = pd.DataFrame(index=metrics1.columns, columns=metrics2.columns)
    
    for col1 in metrics1.columns:
        for col2 in metrics2.columns:
            corr, _ = spearmanr(metrics1[col1], metrics2[col2])
            result.loc[col1, col2] = corr

    return result.astype(float)
def corr_heatmap(corr_df, title):
    plt.figure(figsize=(10, 7))
    sns.heatmap(corr_df, annot=True, fmt=".2f", vmin=-1, vmax=1, cmap='coolwarm', cbar=True)
    plt.title(title)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()

def topk_overlap(metrics1, metrics2, label1, label2, out_path):
    def get_top_k_nodes(metric_dict, k=100):
        return {
            metric: set(sorted(metric_dict[metric].items(), key=lambda x: x[1], reverse=True)[:k])
            for metric in metric_dict
        }

    topk_1 = get_top_k_nodes(metrics1)
    topk_2 = get_top_k_nodes(metrics2)

    with open(out_path, "w") as f:
        f.write(f"Top-100 Node Overlap Between {label1} and {label2}\n")
        for metric in metrics1:
            nodes_1 = {node for node, _ in topk_1[metric]}
            nodes_2 = {node for node, _ in topk_2[metric]}



In [132]:
def analyze_graphs(model_stats, model_name):
    total_nodes = 0
    corr_df = None
    for g_idx, g in tqdm(list(model_stats.items())[:5], total=len(model_stats)):
        og_graph = g['original_graph']
        attn_graph = g['attention_graph']
        
        num_nodes = og_graph.number_of_nodes()
        total_nodes += num_nodes
        
        og_graph_stats = compute_metrics(og_graph)
        attn_graph_stats = compute_metrics(attn_graph)
        
        corr = compute_cross_metric_correlation(og_graph_stats, attn_graph_stats)
        corr *= num_nodes
        if corr_df is None:
            corr_df = corr
        else:
            corr_df += corr
        
    corr_df /= total_nodes 
    
    corr_heatmap(corr_df, title=model_name)

In [None]:
sparse = list(graphs.keys())[0]
analyze_graphs(graphs[sparse], sparse)