<a href="https://colab.research.google.com/github/batu-el/understanding-inductive-biases-of-gnns/blob/main/notebooks/Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Combining Attention values across heads - Avg
# Combining Attention values across layers - Matrix Multiply

# Setup

In [3]:
!pip install dgl torch_geometric torch openhgnn

# Install required python libraries
import os

# Install PyTorch Geometric and other libraries
if 'IS_GRADESCOPE_ENV' not in os.environ:
    print("Installing PyTorch Geometric")
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-geometric
    print("Installing other libraries")
    !pip install networkx
    !pip install lovely-tensors

Installing PyTorch Geometric



[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Installing other libraries



[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip





[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import os
import sys
import time
import math
import random
import itertools
from datetime import datetime
from typing import Mapping, Tuple, Sequence, List

import pandas as pd
import networkx as nx
import numpy as np
import scipy as sp

from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ReLU, BatchNorm1d, LayerNorm, Module, ModuleList, Sequential
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
from torch.optim import Adam

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import Planetoid

import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, dense_to_sparse, to_dense_batch, to_dense_adj

from torch_geometric.nn import GCNConv, GATConv, GATv2Conv

# from openhgnn.dataset import d

# from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum

import lovely_tensors as lt
lt.monkey_patch()

import matplotlib.pyplot as plt
import seaborn as sns

# import warnings
# warnings.filterwarnings("ignore", category=RuntimeWarning)
# warnings.filterwarnings("ignore", category=UserWarning)
# warnings.filterwarnings("ignore", category=FutureWarning)

print("All imports succeeded.")
print("Python version {}".format(sys.version))
print("PyTorch version {}".format(torch.__version__))
print("PyG version {}".format(torch_geometric.__version__))



All imports succeeded.
Python version 3.11.7 (tags/v3.11.7:fa7a6f2, Dec  4 2023, 19:24:49) [MSC v.1937 64 bit (AMD64)]
PyTorch version 2.6.0+cpu
PyG version 2.6.1


In [2]:
# Set random seed for deterministic results

def seed(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed(0)
print("All seeds set.")

All seeds set.


In [3]:
print("Cuda available: {}".format(torch.cuda.is_available()))

Cuda available: False


# Datasets

In [4]:
from torch_geometric.datasets import IMDB

DATASETS = {}

# (same as the imdb4MAGNN dataset from OpenHGNN)
dataset = IMDB(root='/tmp/IMDB')
data = dataset[0]
DATASETS['IMDB'] = data
DATASETS

{'IMDB': HeteroData(
   movie={
     x=[4278, 3066],
     y=[4278],
     train_mask=[4278],
     val_mask=[4278],
     test_mask=[4278],
   },
   director={ x=[2081, 3066] },
   actor={ x=[5257, 3066] },
   (movie, to, director)={ edge_index=[2, 4278] },
   (movie, to, actor)={ edge_index=[2, 12828] },
   (director, to, movie)={ edge_index=[2, 4278] },
   (actor, to, movie)={ edge_index=[2, 12828] }
 )}

In [5]:
DATASETS["IMDB"].edge_types

[('movie', 'to', 'director'),
 ('movie', 'to', 'actor'),
 ('director', 'to', 'movie'),
 ('actor', 'to', 'movie')]

In [6]:
DATASETS["IMDB"][('movie', 'to', 'director')].edge_index[0].max()

tensor i64 4277

In [7]:
DATASETS["IMDB"][('movie', 'to', 'director')].edge_index[1].max()

tensor i64 2080

In [8]:
DATASETS["IMDB"].dense_adj = {}

In [9]:
for data_key in DATASETS:
    data = DATASETS[data_key]
    # data.dense_sp_matrix = SHORTEST_PATHS[data_key]
    # data.dense_adj = to_dense_adj(data.edge_index, max_num_nodes = data.x.shape[0])[0]

    data.dense_adj = {}
    for e_type in data.edge_types:

        src_type, _, dst_type = e_type 
        edge_index = data[e_type].edge_index

        num_src = data[src_type].num_nodes
        num_dst = data[dst_type].num_nodes
        
        src = edge_index[0]
        dst = edge_index[1] + num_src

        combined_edge_index = torch.stack([src, dst], dim=0)

        dense_adj = to_dense_adj(
            combined_edge_index,
            max_num_nodes=num_src + num_dst
        )[0]
        
        dense_adj.fill_diagonal_(1)
        
        # below approach createst a rectangular matrix where each dim corresponds to a row type. However, rectangular are required...
        # # instead of creating a square matrix, we will do rectangular of shape (node_A x node_B) since we only care about edges between these types
        # dense_adj = torch.zeros((num_src, num_dst), dtype=torch.long)
        # 
        # for src, dst in data[e_type].edge_index.T:
        #     dense_adj[src, dst] = 1
            
        data.dense_adj[e_type] = dense_adj

    DATASETS[data_key] = data

## Table 1: Dataset Statistics

# Models

In [14]:
from torch_geometric.nn import HeteroConv

class HeteroGNNModel(Module):
    def __init__(
            self, 
            metadata, 
            gnn_layer,
            out_node_type,
            in_feat_dims: dict[str, int],
            hidden_dim=128,
            out_dim=None,
            num_heads=1,
            num_layers=1,
            dropout=.8,
    ):
        """
        Heterogeneous GNN model wrapper for applying layers to each edge type. Specifically used for node classification. 
        
        :param metadata: from HeteroData object. Contains edge/node types
        :param gnn_layer: type of attention layer to use for each edge type
        :param node_type: type of node to perform classification on
        """
        super().__init__()
        
        self.out_node_type = out_node_type 
        
        # use for converting all node features to same representation size (i.e., hidden_size)
        self.lin_in = torch.nn.ModuleDict({
            node_type: torch.nn.Linear(in_feat_dims[node_type], hidden_dim)
            for node_type in metadata[0]
        })

        self.convs = HeteroConv({
            edge_type: gnn_layer(
                in_dim=hidden_dim,
                hidden_dim=hidden_dim,
                num_heads=num_heads,
                num_layers=num_layers,
                dropout=dropout,
            )
            for edge_type in metadata[1]
        }, aggr='sum'   # how embeddings for the same node type are combined across edge types
        )
 
        # self.lin_out = torch.nn.Sequential(
        #     Linear(hidden_dim, hidden_dim // 2),
        #     torch.nn.ReLU(),
        #     Linear(hidden_dim // 2, out_dim),
        # )
        
        self.lin_out = torch.nn.Linear(hidden_dim, out_dim)

        # self.attn_weights_dict = {}

    def forward(self, x_dict, edge_index_dict):
        # same representation size for each node type input feature (i.e., hidden_size)
        x_dict = {
            node_type: self.lin_in[node_type](x)
            for node_type, x in x_dict.items()
        }

        # run attention model on each edge type to produce final node embeddings for classification
        # returns a dict with embeddings for each node, so we need to retrieve the type used for classification
        x = self.convs(x_dict, edge_index_dict)[self.out_node_type]
        
        # classify nodes
        out = self.lin_out(x)
        return out.log_softmax(dim=-1)

In [11]:
# PyG example code: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_cora.py

from torch_geometric.nn import GraphConv

# Replace with graph conv for bipartite message passing in heterogeneous graphs
class GNNModel(Module):

    def __init__(
            self,
            in_dim: int,
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            dropout: float = 0.5,
        ):
        super().__init__()

        self.layers = ModuleList()

        for layer in range(num_layers):
            self.layers.append(
                GraphConv(hidden_dim, hidden_dim, aggr="mean")
            )
            
        self.lin_src = torch.nn.Linear(in_dim, hidden_dim)
        self.lin_dst = torch.nn.Linear(in_dim, hidden_dim)
        self.dropout = dropout

    def forward(self, x, edge_index):
        if isinstance(x, tuple):
            x_src, x_dst = x
            x_src = self.lin_src(x_src)
            x_dst = self.lin_dst(x_dst)
        else:
            x_src = x_dst = self.lin_src(x)

        for conv in self.layers:
            x_out = conv(x, edge_index)  # x is a tuple in bipartite
            x_out = F.relu(x_out)
            x_out = F.dropout(x_out, p=self.dropout, training=self.training)
            x_dst = x_dst + x_out  # residual connection
            x = (x_src, x_dst)
            
        return x_dst


class SparseGraphTransformerModel(Module):
    def __init__(
            self,
            in_dim: int,
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            dropout: float = 0.5,
        ):
        super().__init__()
        
        self.lin_src = torch.nn.Linear(in_dim, hidden_dim)
        self.lin_dst = torch.nn.Linear(in_dim, hidden_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout,
                    batch_first = False
                )
            )
        self.dropout = dropout
        
        self.attn_weights_list = []

    def forward(self, x, dense_adj):

        if isinstance(x, tuple):
            x_src, x_dst = x
            x_src = self.lin_src(x_src)
            x_dst = self.lin_dst(x_dst)
        else:
            x_src = x_dst = self.lin_src(x)
    
        # concatenate for use in multi-head attention
        x_cat = torch.cat([x_src, x_dst], dim=0)   # [n_src + n_dst, hidden_dim]
        
        x_cat = x_cat.unsqueeze(1)  # [n_src + n_dst, 1, hidden_dim]
        
        # 4) Prepare the attention mask.
        #    MHA expects a bool mask of shape [T, T], where True = "NO ATTEND".
        #    So we invert dense_adj if we want 1=connected, 0=not connected.
        attn_mask = ~(dense_adj.bool())  # True where there's NO edge => block attention

        for layer in self.layers:
            x_in = x_cat
            x, attn_weights = layer(
                x_cat, x_cat, x_cat,
                attn_mask = attn_mask,
                need_weights = True,
                average_attn_weights = False
            )
            
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x_cat = x_in + x

            # self.attn_weights_list.append(attn_weights)

        # [n_src + n_dst, hidden_dim]
        x_cat = x_cat.squeeze(1)
        
        out_dst = x_cat[x_src.size(0):]
        

        return out_dst

# Skip for now since uses shortest path
class DenseGraphTransformerModel(Module):

    def __init__(
            self,
            in_dim: int,
            pos_enc_dim: int = 16,
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_pos_enc = Linear(pos_enc_dim, hidden_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )


        self.attn_bias_scale = torch.nn.Parameter(torch.tensor([10.0]))  # controls how much we initially bias our model to nearby nodes
        self.dropout = dropout
        self.attn_weights_list = []

    def forward(self, x, pos_enc, dense_sp_matrix):

        # x = self.lin_in(x) + self.lin_pos_enc(pos_enc)
        x = self.lin_in(x)  # no node positional encoding

        # attention bias
        # [i, j] -> inverse of shortest path distance b/w node i and j
        # diagonals -> self connection, set to 0
        # disconnected nodes -> -1
        attn_bias = self.attn_bias_scale * torch.nan_to_num(
            (1 / (torch.nan_to_num(dense_sp_matrix, nan=-1, posinf=-1, neginf=-1))),
            nan=0, posinf=0, neginf=0)
        #attn_bias = torch.ones_like(attn_bias)

        # TransformerEncoder
        # x = self.encoder(x, mask = attn_bias)


        for layer in self.layers:
            # MHSA layer
            # float mask adds learnable additive attention bias
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                attn_mask = attn_bias,
                average_attn_weights = False
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            self.attn_weights_list.append(attn_weights)

        # x = self.lin_out(x)
        # 
        # return x.log_softmax(dim=-1)
        return x

# doesn't use shortest path bias
class DenseGraphTransformerModel_V2(Module):
    def __init__(
            self,
            in_dim: int,
            pos_enc_dim: int = 16,
            hidden_dim: int = 128,
            num_heads: int = 1,
            num_layers: int = 1,
            dropout: float = 0.5,
        ):
        super().__init__()

        self.lin_in = Linear(in_dim, hidden_dim)
        self.lin_pos_enc = Linear(pos_enc_dim, hidden_dim)

        self.layers = ModuleList()
        for layer in range(num_layers):
            self.layers.append(
                MultiheadAttention(
                    embed_dim = hidden_dim,
                    num_heads = num_heads,
                    dropout = dropout
                )
            )

        self.attn_bias_scale = torch.nn.Parameter(torch.tensor([10.0]))  # controls how much we initially bias our model to nearby nodes
        self.dropout = dropout

    def forward(self, x, pos_enc, dense_sp_matrix):

        # x = self.lin_in(x) + self.lin_pos_enc(pos_enc)
        x = self.lin_in(x)  # no node positional encoding

        # attention bias
        # [i, j] -> inverse of shortest path distance b/w node i and j
        # diagonals -> self connection, set to 0
        # disconnected nodes -> -1
        # attn_bias = self.attn_bias_scale * torch.nan_to_num(
        #     (1 / (torch.nan_to_num(dense_sp_matrix, nan=-1, posinf=-1, neginf=-1))),
        #     nan=0, posinf=0, neginf=0
        # )
        #attn_bias = torch.ones_like(attn_bias)

        # TransformerEncoder
        # x = self.encoder(x, mask = attn_bias)

        self.attn_weights_list = []

        for layer in self.layers:
            # # TransformerEncoderLayer
            # # float mask adds learnable additive attention bias
            # x = layer(x, src_mask = attn_bias)

            # MHSA layer
            # float mask adds learnable additive attention bias
            x_in = x
            x, attn_weights = layer(
                x, x, x,
                # attn_mask = attn_bias,
                average_attn_weights = False
            )
            x = F.relu(x)
            x = F.dropout(x, self.dropout, training=self.training)
            x = x_in + x

            self.attn_weights_list.append(attn_weights)

        # x = self.lin_out(x)
        # 
        # return x.log_softmax(dim=-1)
        return x
        

In [37]:

# class GNNModel(Module):
# 
#     def __init__(
#             self,
#             in_dim: int = 3066,
#             hidden_dim: int = 128,
#             num_heads: int = 1,
#             num_layers: int = 1,
#             out_dim: int = 3,
#             dropout: float = 0.8,
#     ):
#         super().__init__()
# 
#         self.lin_in = Linear(in_dim, hidden_dim)
#         self.lin_out = Linear(hidden_dim, out_dim)
#         self.layers = ModuleList()
# 
#         for layer in range(num_layers):
#             self.layers.append(
#                 GraphConv(hidden_dim, hidden_dim)
#             )
#         self.dropout = torch.nn.Dropout(dropout)
# 
#     def forward(self, x, edge_index):
# 
#         x = self.lin_in(x)
# 
#         for layer in self.layers:
#             # conv -> activation ->  dropout -> residual
#             x_in = x
#             x = layer(x, edge_index)
#             x = F.relu(x)
#             x = self.dropout(x)
#             x = x_in + x
# 
#         return x

# Trainers

In [38]:
EPOCHS = 100
LEARNING_RATE = 1e-3
def Train_GCN(NUM_LAYERS,
              NUM_HEADS,
              data,
              out_node_type):

    OUT_DIM = len(data[out_node_type].y.unique())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    model = HeteroGNNModel(
        data.metadata(), 
        GNNModel, 
        in_feat_dims={node_type: data[node_type].x.shape[1] for node_type in data.metadata()[0]},
        out_node_type=out_node_type,
        hidden_dim=128,
        out_dim=OUT_DIM,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
    ).to(device)

    # uncomment for to_hetero approach. Similar results..
    # model = GNNModel(num_layers=NUM_LAYERS, num_heads=NUM_HEADS ,out_dim=OUT_DIM).to(device)
    # model = torch_geometric.nn.to_hetero(model, data.metadata(), aggr='sum')
    
    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    train_mask = data[out_node_type].train_mask
    val_mask = data[out_node_type].val_mask
    test_mask = data[out_node_type].test_mask
    y = data[out_node_type].y

    def train():
        model.train()
        optimizer.zero_grad()
        out = model(data.x_dict, data.edge_index_dict)#[out_node_type].log_softmax(dim=-1)
        loss = F.nll_loss(out[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()
        return float(loss.item())

    @torch.no_grad()
    def test():
        model.eval()
        out = model(data.x_dict, data.edge_index_dict)#[out_node_type].log_softmax(dim=-1)
        pred = out.argmax(dim=-1)

        accs = []
        for mask in [train_mask, val_mask, test_mask]:
            correct = (pred[mask] == y[mask]).sum()
            accs.append(float(correct) / mask.sum().item())
        return accs

    best_val_acc = 0.0
    test_acc = 0.0
    times = []

    for epoch in range(1, EPOCHS):
        start = time.time()

        loss = train()
        train_acc, val_acc, tmp_test_acc = test()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc

        # Uncomment to see progress each epoch
        print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
              f'Test: {tmp_test_acc:.4f}, Final Test: {test_acc:.4f}')

        times.append(time.time() - start)
    return {
        'train_acc': train_acc,
        'val_acc': val_acc,
        'test_acc': test_acc
    }, None

def Train_SparseGraphTransformerModel(NUM_LAYERS,
              NUM_HEADS,
              data,
              out_node_type):

    OUT_DIM = len(data[out_node_type].y.unique())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = HeteroGNNModel(
        data.metadata(),
        SparseGraphTransformerModel,
        in_feat_dims={node_type: data[node_type].x.shape[1] for node_type in data.metadata()[0]},
        out_node_type=out_node_type,
        hidden_dim=128,
        out_dim=OUT_DIM,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
    ).to(device)

    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_mask = data[out_node_type].train_mask
    val_mask = data[out_node_type].val_mask
    test_mask = data[out_node_type].test_mask
    y = data[out_node_type].y
    
    def train():
        model.train()
        optimizer.zero_grad()
        out = model(data.x_dict, data.dense_adj)
        loss = F.nll_loss(out[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()
        return float(loss.item())

    @torch.no_grad()
    def test():
        model.eval()
        out = model(data.x_dict, data.dense_adj)
        pred = out.argmax(dim=-1)

        accs = []
        for mask in [train_mask, val_mask, test_mask]:
            correct = (pred[mask] == y[mask]).sum()
            accs.append(float(correct) / mask.sum().item())
        return accs

    best_val_acc = 0.0
    test_acc = 0.0
    times = []

    for epoch in range(1, EPOCHS):
        start = time.time()

        loss = train()
        train_acc, val_acc, tmp_test_acc = test()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc

        # Uncomment to see progress each epoch
        print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
              f'Test: {tmp_test_acc:.4f}, Final Test: {test_acc:.4f}')

        times.append(time.time() - start)
        
    return {
        'train_acc': train_acc,
        'val_acc': val_acc,
        'test_acc': test_acc
    }, model.attn_weights_list

def Train_DenseGraphTransformerModel(NUM_LAYERS,
              NUM_HEADS,
              data):

    IN_DIM = data.x.shape[-1]
    OUT_DIM = len(data.y.unique())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = DenseGraphTransformerModel(num_layers=NUM_LAYERS, num_heads=NUM_HEADS, in_dim=IN_DIM,out_dim=OUT_DIM).to(device)
    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    def train():
        model.train()
        optimizer.zero_grad()
        # print(data.pos_enc)
        out = model(data.x, data.pos_enc, data.dense_sp_matrix)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return float(loss)


    @torch.no_grad()
    def test():
        model.eval()
        pred, accs = model(data.x, data.pos_enc, data.dense_sp_matrix).argmax(dim=-1), []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
        return accs

    best_val_acc = test_acc = 0
    times = []
    for epoch in range(1, 100):
        start = time.time()
        loss = train()
        train_acc, val_acc, tmp_test_acc = test()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
        # print(f'Epoch: {epoch:04d}, Loss: {loss:.4f} Train: {train_acc:.4f}, '
        #       f'Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}, '
        #       f'Final Test: {test_acc:.4f}')
        times.append(time.time() - start)
    # print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
    return {'train_acc':train_acc,'val_acc':val_acc,'test_acc':test_acc}, model.attn_weights_list

    # Notes
    # - Dense Transformer needs to be trained for a bit longer to reach low loss value
    # - Node positional encodings are not particularly useful
    # - Edge distance encodings are very useful
    # - Since Cora is highly homophilic, it is important to bias the attention towards nearby nodes

def Train_DenseGraphTransformerModel_V2(NUM_LAYERS,
              NUM_HEADS,
              data):

    IN_DIM = data.x.shape[-1]
    OUT_DIM = len(data.y.unique())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = DenseGraphTransformerModel_V2(num_layers=NUM_LAYERS, num_heads=NUM_HEADS, in_dim=IN_DIM,out_dim=OUT_DIM).to(device)
    data = data.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    def train():
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.pos_enc, data.dense_sp_matrix)
        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        return float(loss)

    @torch.no_grad()
    def test():
        model.eval()
        pred, accs = model(data.x, data.pos_enc, data.dense_sp_matrix).argmax(dim=-1), []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
        return accs

    best_val_acc = test_acc = 0
    times = []
    for epoch in range(1, 100):
        start = time.time()
        loss = train()
        train_acc, val_acc, tmp_test_acc = test()
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc
        # print(f'Epoch: {epoch:04d}, Loss: {loss:.4f} Train: {train_acc:.4f}, '
        #       f'Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}, '
        #       f'Final Test: {test_acc:.4f}')
        times.append(time.time() - start)
    # print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
    return {'train_acc':train_acc,'val_acc':val_acc,'test_acc':test_acc}, model.attn_weights_list

    # Notes
    # - Dense Transformer needs to be trained for a bit longer to reach low loss value
    # - Node positional encodings are not particularly useful
    # - Edge distance encodings are very useful
    # - Since Cora is highly homophilic, it is important to bias the attention towards nearby nodes

# Training

## Training: 1 Layer, 1 Head

In [40]:
import tqdm
NUM_LAYERS = 3
NUM_HEADS = 1
# NUM_RUNS = 10
all_stats = {}
for data_key in DATASETS:
    print(f'Training on {data_key}')
    data = DATASETS[data_key]

    run_stats = {}

    # TODO: incorporate multiple different splits?
    # for mask_idx in tqdm.tqdm(range(NUM_RUNS)):
        
    out_node_type = list(data.y_dict.keys())[0]

    accuracy_statistics = {}
    attn_weights = {}

    accuracy_statistics['GCN'], attn_weights['GCN'] = Train_GCN(NUM_LAYERS, NUM_HEADS, data, out_node_type)
    # accuracy_statistics['SparseGraphTransformerModel'] , attn_weights['SparseGraphTransformerModel'] = Train_SparseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data, out_node_type)
    # # accuracy_statistics['DenseGraphTransformerModel'] , attn_weights['DenseGraphTransformerModel'] = Train_DenseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data)
    # accuracy_statistics['DenseGraphTransformerModel_V2'] , attn_weights['DenseGraphTransformerModel_V2'] = Train_DenseGraphTransformerModel_V2(NUM_LAYERS, NUM_HEADS, data)
    # attn_weights['SparseGraphTransformerModel'] = torch.stack(attn_weights['SparseGraphTransformerModel']).cpu()
    # # attn_weights['DenseGraphTransformerModel'] = torch.stack(attn_weights['DenseGraphTransformerModel']).cpu()
    # attn_weights['DenseGraphTransformerModel_V2'] = torch.stack(attn_weights['DenseGraphTransformerModel_V2']).cpu()
    # run_stats[mask_idx] = {'accuracy': accuracy_statistics, 'attentions': attn_weights}
    all_stats[data_key] = {'accuracy': accuracy_statistics, 'attentions': attn_weights}

import pickle
# with open('drive/MyDrive/Colab Notebooks/L65_Project/' + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'wb') as f:
#     pickle.dump(all_stats, f)

Training on IMDB
Epoch: 0001, Loss: 7.3085, Train: 0.0750, Val: 0.0850, Test: 0.0707, Final Test: 0.0707
Epoch: 0002, Loss: 5.4216, Train: 0.1125, Val: 0.0675, Test: 0.0610, Final Test: 0.0707
Epoch: 0003, Loss: 4.8641, Train: 0.4425, Val: 0.2975, Test: 0.3051, Final Test: 0.3051
Epoch: 0004, Loss: 4.6762, Train: 0.5650, Val: 0.3800, Test: 0.3870, Final Test: 0.3870
Epoch: 0005, Loss: 4.4191, Train: 0.6125, Val: 0.3925, Test: 0.4057, Final Test: 0.4057
Epoch: 0006, Loss: 4.3768, Train: 0.6525, Val: 0.4075, Test: 0.4166, Final Test: 0.4166
Epoch: 0007, Loss: 4.2381, Train: 0.7350, Val: 0.4050, Test: 0.4436, Final Test: 0.4166
Epoch: 0008, Loss: 4.0282, Train: 0.8125, Val: 0.4425, Test: 0.4687, Final Test: 0.4687
Epoch: 0009, Loss: 4.0227, Train: 0.8400, Val: 0.4675, Test: 0.4885, Final Test: 0.4885
Epoch: 0010, Loss: 4.0902, Train: 0.8700, Val: 0.4825, Test: 0.5000, Final Test: 0.5000
Epoch: 0011, Loss: 4.0118, Train: 0.8825, Val: 0.4725, Test: 0.5072, Final Test: 0.5000
Epoch: 0012, Lo

KeyboardInterrupt: 

## Training: 1 Layer, 2 Heads

In [None]:
import tqdm
NUM_LAYERS = 1
NUM_HEADS = 2
NUM_RUNS = 10
all_stats = {}
for data_key in DATASETS:
    print(f'Training on {data_key}')
    data = DATASETS[data_key]

    TRAIN_MASKS = data.train_mask
    VAL_MASKS = data.val_mask
    TEST_MASKS = data.test_mask

    run_stats = {}

    for mask_idx in tqdm.tqdm(range(NUM_RUNS)):
        data.train_mask = TRAIN_MASKS[:,mask_idx]
        data.val_mask = VAL_MASKS[:,mask_idx]
        data.test_mask = TEST_MASKS[:,mask_idx]

        accuracy_statistics = {}
        attn_weights = {}

        accuracy_statistics['GCN'], attn_weights['GCN'] = Train_GCN(NUM_LAYERS, NUM_HEADS, data)
        accuracy_statistics['SparseGraphTransformerModel'] , attn_weights['SparseGraphTransformerModel'] = Train_SparseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data)
        # accuracy_statistics['DenseGraphTransformerModel'] , attn_weights['DenseGraphTransformerModel'] = Train_DenseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data)
        accuracy_statistics['DenseGraphTransformerModel_V2'] , attn_weights['DenseGraphTransformerModel_V2'] = Train_DenseGraphTransformerModel_V2(NUM_LAYERS, NUM_HEADS, data)
        attn_weights['SparseGraphTransformerModel'] = torch.stack(attn_weights['SparseGraphTransformerModel']).cpu()
        # attn_weights['DenseGraphTransformerModel'] = torch.stack(attn_weights['DenseGraphTransformerModel']).cpu()
        attn_weights['DenseGraphTransformerModel_V2'] = torch.stack(attn_weights['DenseGraphTransformerModel_V2']).cpu()
        run_stats[mask_idx] = {'accuracy': accuracy_statistics, 'attentions': attn_weights}
    all_stats[data_key] = run_stats
    data.train_mask = TRAIN_MASKS
    data.val_mask = VAL_MASKS
    data.test_mask = TEST_MASKS

import pickle
with open('drive/MyDrive/Colab Notebooks/L65_Project/' + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'wb') as f:
    pickle.dump(all_stats, f)

Training on Cora


100%|██████████| 10/10 [00:27<00:00,  2.73s/it]


Training on Citeseer


100%|██████████| 10/10 [00:32<00:00,  3.29s/it]


Training on Chameleon


100%|██████████| 10/10 [00:24<00:00,  2.50s/it]


Training on Squirrel


100%|██████████| 10/10 [00:52<00:00,  5.21s/it]


Training on Cornell


100%|██████████| 10/10 [00:19<00:00,  1.93s/it]


Training on Texas


100%|██████████| 10/10 [00:18<00:00,  1.87s/it]


Training on Wisconsin


100%|██████████| 10/10 [00:18<00:00,  1.82s/it]


## Training: 2 Layers, 1 Head

In [None]:
import tqdm
### Train Cora ###
NUM_LAYERS = 2
NUM_HEADS = 1
NUM_RUNS = 10
# data_key = 'Wisconsin'
all_stats = {}
for data_key in DATASETS:
    print(f'Training on {data_key}')
    data = DATASETS[data_key]

    TRAIN_MASKS = data.train_mask
    VAL_MASKS = data.val_mask
    TEST_MASKS = data.test_mask

    run_stats = {}

    for mask_idx in tqdm.tqdm(range(NUM_RUNS)):
        data.train_mask = TRAIN_MASKS[:,mask_idx]
        data.val_mask = VAL_MASKS[:,mask_idx]
        data.test_mask = TEST_MASKS[:,mask_idx]

        accuracy_statistics = {}
        attn_weights = {}

        accuracy_statistics['GCN'], attn_weights['GCN'] = Train_GCN(NUM_LAYERS, NUM_HEADS, data)
        accuracy_statistics['SparseGraphTransformerModel'] , attn_weights['SparseGraphTransformerModel'] = Train_SparseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data)
        # accuracy_statistics['DenseGraphTransformerModel'] , attn_weights['DenseGraphTransformerModel'] = Train_DenseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data)
        accuracy_statistics['DenseGraphTransformerModel_V2'] , attn_weights['DenseGraphTransformerModel_V2'] = Train_DenseGraphTransformerModel_V2(NUM_LAYERS, NUM_HEADS, data)
        attn_weights['SparseGraphTransformerModel'] = torch.stack(attn_weights['SparseGraphTransformerModel'])
        # attn_weights['DenseGraphTransformerModel'] = torch.stack(attn_weights['DenseGraphTransformerModel'])
        attn_weights['DenseGraphTransformerModel_V2'] = torch.stack(attn_weights['DenseGraphTransformerModel_V2'])
        run_stats[mask_idx] = {'accuracy': accuracy_statistics, 'attentions': attn_weights}
    all_stats[data_key] = run_stats
    data.train_mask = TRAIN_MASKS
    data.val_mask = VAL_MASKS
    data.test_mask = TEST_MASKS

import pickle
with open('drive/MyDrive/Colab Notebooks/L65_Project/' + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'wb') as f:
    pickle.dump(all_stats, f)

Training on Cora


100%|██████████| 10/10 [00:29<00:00,  2.91s/it]


Training on Citeseer


100%|██████████| 10/10 [00:33<00:00,  3.34s/it]


Training on Chameleon


100%|██████████| 10/10 [00:27<00:00,  2.79s/it]


Training on Squirrel


100%|██████████| 10/10 [00:56<00:00,  5.69s/it]


Training on Cornell


100%|██████████| 10/10 [00:27<00:00,  2.79s/it]


Training on Texas


100%|██████████| 10/10 [00:27<00:00,  2.79s/it]


Training on Wisconsin


100%|██████████| 10/10 [00:27<00:00,  2.74s/it]


## Training: 2 Layers, 2 Heads

In [None]:
import gc
import tqdm
### Train Cora ###
NUM_LAYERS = 2
NUM_HEADS = 2
NUM_RUNS = 10
# data_key = 'Wisconsin'
all_stats = {}
for data_key in DATASETS:
    print(f'Training on {data_key}')
    data = DATASETS[data_key]

    TRAIN_MASKS = data.train_mask
    VAL_MASKS = data.val_mask
    TEST_MASKS = data.test_mask

    run_stats = {}

    for mask_idx in tqdm.tqdm(range(NUM_RUNS)):
        gc.collect()
        data.train_mask = TRAIN_MASKS[:,mask_idx]
        data.val_mask = VAL_MASKS[:,mask_idx]
        data.test_mask = TEST_MASKS[:,mask_idx]

        accuracy_statistics = {}
        attn_weights = {}

        accuracy_statistics['GCN'], attn_weights['GCN'] = Train_GCN(NUM_LAYERS, NUM_HEADS, data)
        accuracy_statistics['SparseGraphTransformerModel'] , attn_weights['SparseGraphTransformerModel'] = Train_SparseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data)
        # accuracy_statistics['DenseGraphTransformerModel'] , attn_weights['DenseGraphTransformerModel'] = Train_DenseGraphTransformerModel(NUM_LAYERS, NUM_HEADS, data)
        accuracy_statistics['DenseGraphTransformerModel_V2'] , attn_weights['DenseGraphTransformerModel_V2'] = Train_DenseGraphTransformerModel_V2(NUM_LAYERS, NUM_HEADS, data)
        attn_weights['SparseGraphTransformerModel'] = torch.stack(attn_weights['SparseGraphTransformerModel']).cpu()
        # attn_weights['DenseGraphTransformerModel'] = torch.stack(attn_weights['DenseGraphTransformerModel']).cpu()
        attn_weights['DenseGraphTransformerModel_V2'] = torch.stack(attn_weights['DenseGraphTransformerModel_V2']).cpu()
        run_stats[mask_idx] = {'accuracy': accuracy_statistics, 'attentions': attn_weights}

    all_stats[data_key] = run_stats
    data.train_mask = TRAIN_MASKS
    data.val_mask = VAL_MASKS
    data.test_mask = TEST_MASKS

import pickle
with open('drive/MyDrive/Colab Notebooks/L65_Project/' + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'wb') as f:
    pickle.dump(all_stats, f)

Training on Cora


100%|██████████| 10/10 [00:48<00:00,  4.80s/it]


Training on Citeseer


100%|██████████| 10/10 [00:58<00:00,  5.83s/it]


Training on Chameleon


100%|██████████| 10/10 [00:42<00:00,  4.29s/it]


Training on Squirrel


100%|██████████| 10/10 [01:35<00:00,  9.52s/it]


Training on Cornell


100%|██████████| 10/10 [00:30<00:00,  3.03s/it]


Training on Texas


100%|██████████| 10/10 [00:29<00:00,  2.99s/it]


Training on Wisconsin


100%|██████████| 10/10 [00:30<00:00,  3.03s/it]


# Analysis


## Table 2: Accuracy Statistics

In [None]:
### Table 2 ###
### Accuracy Statistics ###
pd.set_option('display.max_columns', None)

all_stats_df = {}
for data_key in all_stats:
  run_stats = all_stats[data_key]
  table1 = pd.concat({key : pd.DataFrame(run_stats[key]['accuracy']) for key in run_stats}, axis=0)
  table1_train = pd.concat({'mean': table1.mean(level=1, axis=0).loc['train_acc'], 'std':table1.std(level=1).loc['train_acc']}, axis=1)
  table1_test = pd.concat({'mean': table1.mean(level=1, axis=0).loc['test_acc'], 'std':table1.std(level=1).loc['test_acc']}, axis=1)
  # table1 = pd.concat({'Train': table1_train, 'Test': table1_test}, axis=1)
  table1 = table1_test
  all_stats_df[data_key] = table1
pd.concat(all_stats_df, axis=1).round(2)