<a href="https://colab.research.google.com/github/batu-el/understanding-inductive-biases-of-gnns/blob/main/notebooks/Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [None]:
!pip install torch torchdata torch_geometric dgl

# Install required python libraries
import os

# Install PyTorch Geometric and other libraries
if 'IS_GRADESCOPE_ENV' not in os.environ:
    print("Installing PyTorch Geometric")
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-geometric
    print("Installing other libraries")
    !pip install networkx
    !pip install lovely-tensors

Collecting torchdata
  Downloading torchdata-0.11.0-py3-none-any.whl.metadata (6.3 kB)
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dgl
  Downloading dgl-2.1.0-cp311-cp311-manylinux1_x86_64.whl.metadata (553 bytes)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.

In [None]:
import os
import sys
import time
import math
import random
import itertools
from datetime import datetime
from typing import Mapping, Tuple, Sequence, List

import pandas as pd
import networkx as nx
import numpy as np
import scipy as sp

from tqdm.notebook import tqdm

import torch
import torch.nn.functional as F
from torch.nn import Embedding, Linear, ReLU, BatchNorm1d, LayerNorm, Module, ModuleList, Sequential
from torch.nn import TransformerEncoder, TransformerEncoderLayer, MultiheadAttention
from torch.optim import Adam

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import Planetoid

import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, dense_to_sparse, to_dense_batch, to_dense_adj

from torch_geometric.nn import GCNConv, GATConv, GATv2Conv

# from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum

import lovely_tensors as lt
lt.monkey_patch()

import matplotlib.pyplot as plt
import seaborn as sns

# import warnings
# warnings.filterwarnings("ignore", category=RuntimeWarning)
# warnings.filterwarnings("ignore", category=UserWarning)
# warnings.filterwarnings("ignore", category=FutureWarning)

# Report the interpreter and library versions this notebook is running against.
print(
    "All imports succeeded.",
    "Python version {}".format(sys.version),
    "PyTorch version {}".format(torch.__version__),
    "PyG version {}".format(torch_geometric.__version__),
    sep="\n",
)



All imports succeeded.
Python version 3.11.11 (main, Dec  4 2024, 08:55:07) [GCC 11.4.0]
PyTorch version 2.6.0+cu124
PyG version 2.6.1


In [None]:
# Set random seed for deterministic results

def seed(seed=0):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


seed(0)
print("All seeds set.")

All seeds set.


# Datasets

In [None]:
from torch_geometric.datasets import WebKB, WikipediaNetwork

DATASETS = {}

# Every benchmark is a single-graph dataset, so only graph 0 is kept.
# Insertion order matters: later cells slice `list(DATASETS.keys())` by position.
DATASET_SOURCES = (
    # Mid-size datasets: citation networks
    (Planetoid, ('Cora', 'Citeseer')),
    # Mid-size datasets: Wikipedia page networks
    (WikipediaNetwork, ('Chameleon', 'Squirrel')),
    # Small datasets: web-page networks
    (WebKB, ('Wisconsin', 'Cornell', 'Texas')),
)

for dataset_cls, dataset_names in DATASET_SOURCES:
    for dataset_name in dataset_names:
        # Each dataset is downloaded to / cached in its own folder under /tmp.
        dataset = dataset_cls(root=f'/tmp/{dataset_name}', name=dataset_name)
        data = dataset[0]
        DATASETS[dataset_name] = data

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.citeseer.y
Dow

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# import tqdm
# ### Shortest Paths ###
# def get_shortest_path_matrix(adjacency_matrix):
#     graph = nx.from_numpy_array(adjacency_matrix.cpu().numpy(), create_using=nx.DiGraph)
#     shortest_path_matrix = nx.floyd_warshall_numpy(graph)
#     shortest_path_matrix = torch.tensor(shortest_path_matrix).float()
#     return shortest_path_matrix

# SHORTEST_PATHS = {}
# for data_key in tqdm.tqdm(DATASETS):
#   print(data_key)
#   data = DATASETS[data_key]
#   dense_adj = to_dense_adj(data.edge_index, max_num_nodes = data.x.shape[0])[0]
#   dense_shortest_path_matrix = get_shortest_path_matrix(dense_adj)
#   SHORTEST_PATHS[data_key] = dense_shortest_path_matrix

# ### Save the Shortest Paths
# import pickle
# with open('sp_dict.pkl', 'wb') as f:
#     pickle.dump(SHORTEST_PATHS, f)
import pickle
# Load the cached shortest-path matrices instead of recomputing them.
# SHORTEST_PATHS is indexed by dataset name by the cell that attaches
# `dense_sp_matrix` to each dataset. Presumably this file was produced by the
# commented-out Floyd-Warshall code above - TODO confirm, that code writes
# 'sp_dict.pkl' while this loads 'sp_dict2.pkl'.
# The path is relative, so it assumes the working directory is the folder the
# Drive mount lives in (/content on Colab).
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
with open('drive/MyDrive/Colab Notebooks/sp_dict2.pkl', 'rb') as f:
# with open('drive/MyDrive/L65/shortest_paths/sp_dict.pkl', 'rb') as f:
    SHORTEST_PATHS = pickle.load(f)

In [None]:
# Attach dense structural features to every dataset:
#   * dense_sp_matrix - the cached shortest-path matrix
#   * dense_adj       - dense adjacency with self-loops, entries of 2 reset to 1
#   * pos_enc         - 16 Laplacian-eigenvector positional encodings
# Use the GPU when there is one, but do not crash on CPU-only runtimes.
adj_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for data_key in DATASETS:
    data = DATASETS[data_key]
    data.dense_sp_matrix = SHORTEST_PATHS[data_key]

    num_nodes = data.x.shape[0]
    dense_adj = to_dense_adj(data.edge_index, max_num_nodes=num_nodes)[0].to(adj_device)
    dense_adj = dense_adj + torch.eye(num_nodes, device=adj_device)
    # A node that already had a self-loop now carries 2 on the diagonal; reset it to 1.
    dense_adj[dense_adj == 2] = 1
    data.dense_adj = dense_adj

    data = T.AddLaplacianEigenvectorPE(k=16, attr_name='pos_enc')(data)
    DATASETS[data_key] = data

In [None]:
### Masks ###

def generate_masks(num_nodes=None,num_runs=None,train_ratio=None, val_ratio=None):
    """Draw `num_runs` independent random train/val/test node splits.

    Each run shuffles the node indices with NumPy's global RNG. The first
    `int(train_ratio * num_nodes)` shuffled nodes go to train, the next
    `int(val_ratio * num_nodes)` go to validation and the rest go to test.

    Args:
        num_nodes: Number of nodes to split.
        num_runs: Number of independent splits (one mask column per run).
        train_ratio: Fraction of nodes assigned to the training split.
        val_ratio: Fraction of nodes assigned to the validation split.

    Returns:
        Dict with keys 'train_mask', 'val_mask' and 'test_mask'. Each value is
        an integer 0/1 tensor of shape (num_nodes, num_runs).
    """
    split_names = ('train_mask', 'val_mask', 'test_mask')
    masks = {name: np.zeros((num_nodes, num_runs), dtype=int) for name in split_names}

    for run in range(num_runs):
        order = np.arange(num_nodes)
        np.random.shuffle(order)

        train_end = int(train_ratio * num_nodes)
        val_end = train_end + int(val_ratio * num_nodes)
        members_by_split = (order[:train_end], order[train_end:val_end], order[val_end:])

        for name, members in zip(split_names, members_by_split):
            masks[name][members, run] = 1

    return {name: torch.tensor(masks[name]) for name in split_names}

# Replace each dataset's masks with 10 random 40/30/30 splits and report the
# realised ratios of the first split.
for data_key in DATASETS:
    data = DATASETS[data_key]

    masks = generate_masks(num_nodes=data.x.shape[0], num_runs=10, train_ratio=0.4, val_ratio=0.3)
    for mask_name in ('train_mask', 'val_mask', 'test_mask'):
        setattr(data, mask_name, masks[mask_name].bool())

    if data.train_mask.dim() != 1:
        print('We have 10 Masks')
        reports = (
            ('Train Ratio:', data.train_mask),
            ('Val Ratio:', data.val_mask),
            ('Test Ratio:', data.test_mask),
        )
        for caption, mask in reports:
            first_run = mask[:, 0]
            print(caption, (first_run.sum() / len(first_run)).item())
    else:
        # A 1-D mask means only a single split is stored.
        print('Add 10 Masks')

We have 10 Masks
Train Ratio: 0.39992615580558777
Val Ratio: 0.29985228180885315
Test Ratio: 0.3002215623855591
We have 10 Masks
Train Ratio: 0.39975953102111816
Val Ratio: 0.29996994137763977
Test Ratio: 0.30027052760124207
We have 10 Masks
Train Ratio: 0.39964866638183594
Val Ratio: 0.2999560832977295
Test Ratio: 0.30039525032043457
We have 10 Masks
Train Ratio: 0.39992308616638184
Val Ratio: 0.2999423146247864
Test Ratio: 0.3001345992088318
We have 10 Masks
Train Ratio: 0.39840638637542725
Val Ratio: 0.29880478978157043
Test Ratio: 0.3027888536453247
We have 10 Masks
Train Ratio: 0.3989070951938629
Val Ratio: 0.2950819730758667
Test Ratio: 0.3060109317302704
We have 10 Masks
Train Ratio: 0.3989070951938629
Val Ratio: 0.2950819730758667
Test Ratio: 0.3060109317302704


## Table 1: Dataset Statistics

In [None]:
### Table 1 ###
### Dataset Statistics ###
# import dgl
# Homophily_Levels = {}

# for data_key in DATASETS:
#   data = DATASETS[data_key]
#   edge_index_tensor = torch.tensor(data.edge_index.cpu().numpy(), dtype=torch.long)
#   g = dgl.graph((edge_index_tensor[0], edge_index_tensor[1]), num_nodes=data.x.shape[0])
#   g.ndata['y'] = torch.tensor(data.y.cpu().numpy(), dtype=torch.long)
#   Homophily_Levels[data_key] = {'Node Homophily':dgl.node_homophily(g, g.ndata['y'])*100,
#                                 'Edge Homophily':dgl.edge_homophily(g, g.ndata['y'])*100,
#                                 'Adjusted Homophily':dgl.adjusted_homophily(g, g.ndata['y'])*100,
#                                 'Number of Nodes': int(g.num_nodes()),
#                                 'Number of Edges': int(g.num_edges())
#                                 }
# df = pd.DataFrame(Homophily_Levels).round(1)
# df

# Cache Data

In [None]:
drive_path = 'drive/MyDrive/Colab Notebooks/' #replace with the directory of the trained models

**TODO:** Delete this trained file from Drive and re-run from training. I don't think it was saved correctly (there was a problem before I bought more RAM and GPU).

In [None]:
import pickle

# Cached training statistics for the 1-layer / 1-head models.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
NUM_LAYERS, NUM_HEADS = 1, 1
with open(drive_path + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'rb') as f:
    all_stats_1L_1H = pickle.load(f)

In [None]:
# Cached training statistics for the 1-layer / 2-head models.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
NUM_LAYERS, NUM_HEADS = 1, 2
with open(drive_path + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'rb') as f:
    all_stats_1L_2H = pickle.load(f)

In [None]:
# Cached training statistics for the 2-layer / 1-head models.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
NUM_LAYERS, NUM_HEADS = 2, 1
with open(drive_path + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'rb') as f:
    all_stats_2L_1H = pickle.load(f)

In [None]:
# Cached training statistics for the 2-layer / 2-head models.
# NOTE: pickle.load can execute arbitrary code; only load files you created yourself.
NUM_LAYERS, NUM_HEADS = 2, 2
with open(drive_path + f'all_stats_{NUM_LAYERS}L_{NUM_HEADS}H.pkl', 'rb') as f:
    all_stats_2L_2H = pickle.load(f)

In [None]:
def get_attns(all_stats):
    """Average each model's attention tensor over all runs, per dataset.

    Args:
        all_stats: Mapping dataset name -> run index -> stats dict. Each stats
            dict has an 'attentions' entry holding one tensor per model name.

    Returns:
        Mapping dataset name -> {'st_avg', 'dt_avg', 'dt2_avg'}. Each value is
        the run-averaged attention tensor of the sparse, dense and dense-V2
        transformer respectively.
    """
    # (key in the stored stats, key in the returned dict)
    model_to_output_key = (
        ('SparseGraphTransformerModel', 'st_avg'),
        ('DenseGraphTransformerModel', 'dt_avg'),
        ('DenseGraphTransformerModel_V2', 'dt2_avg'),
    )

    all_attns = {}
    for data_key, run_stats in all_stats.items():
        attentions_per_run = [run_stats[run_idx]['attentions'] for run_idx in run_stats]
        all_attns[data_key] = {
            # Stack the runs along a new leading axis and average it away.
            output_key: torch.stack([attn[model_name] for attn in attentions_per_run]).mean(dim=0)
            for model_name, output_key in model_to_output_key
        }
    return all_attns

# Run-averaged attention tensors for each (layers, heads) configuration.
all_attns_1L_1H, all_attns_1L_2H, all_attns_2L_1H, all_attns_2L_2H = [
    get_attns(stats)
    for stats in (all_stats_1L_1H, all_stats_1L_2H, all_stats_2L_1H, all_stats_2L_2H)
]

**Combining Attention**

*   We combine the attention matrices across heads by averaging them. This gives us how much a node attends to another node on average.

*   We combine the attention matrices across layers by matrix multiplication, $A_{L2} A_{L1}$ (the first layer's attention is applied first). This gives us how much a node attends to another node across layers.



In [None]:
model_keys = ['st_avg', 'dt_avg', 'dt2_avg']
data_keys = list(DATASETS.keys())


def _single_layer_attention(all_attns):
    """Average over axis 1 (heads, per the note above) and keep entry 0 (the only layer)."""
    return {
        data_key: {
            model_key: all_attns[data_key][model_key].mean(axis=1)[0]
            for model_key in model_keys
        }
        for data_key in data_keys
    }


def _two_layer_attention(all_attns):
    """Average over axis 1, then compose the two layers as A_L2 @ A_L1 and move to CPU."""
    combined = {}
    for data_key in data_keys:
        combined[data_key] = {}
        for model_key in model_keys:
            per_layer = all_attns[data_key][model_key].mean(axis=1)
            combined[data_key][model_key] = (per_layer[1] @ per_layer[0]).cpu()
    return combined


A1L_1H = _single_layer_attention(all_attns_1L_1H)
A1L_2H = _single_layer_attention(all_attns_1L_2H)
A2L_1H = _two_layer_attention(all_attns_2L_1H)
A2L_2H = _two_layer_attention(all_attns_2L_2H)

In [None]:
import numpy as np

# 90th-percentile attention value for every (dataset, model) pair in the
# 1-layer / 1-head setting, taken over all entries of the attention matrix.
percentile_thresholds = {
    data_key: {
        model_key: np.percentile(A1L_1H[data_key][model_key].flatten().cpu().numpy(), 90)
        for model_key in model_keys
    }
    for data_key in data_keys
}

# Print results
for data_key, thresholds_by_model in percentile_thresholds.items():
    for model_key, threshold in thresholds_by_model.items():
        print(f"90th percentile threshold for {data_key} - {model_key}: {threshold}")

90th percentile threshold for Cora - st_avg: 0.0
90th percentile threshold for Cora - dt_avg: 7.258294499479234e-05
90th percentile threshold for Cora - dt2_avg: 0.0005024674464948475
90th percentile threshold for Citeseer - st_avg: 0.0
90th percentile threshold for Citeseer - dt_avg: 5.9712510847020894e-06
90th percentile threshold for Citeseer - dt2_avg: 0.00035468433634378016
90th percentile threshold for Chameleon - st_avg: 0.0
90th percentile threshold for Chameleon - dt_avg: 5.9682613937184215e-05
90th percentile threshold for Chameleon - dt2_avg: 0.0005602996097877622
90th percentile threshold for Squirrel - st_avg: 0.0
90th percentile threshold for Squirrel - dt_avg: 0.00011890644964296371
90th percentile threshold for Squirrel - dt2_avg: 0.00025478401221334934
90th percentile threshold for Wisconsin - st_avg: 0.0
90th percentile threshold for Wisconsin - dt_avg: 0.00020773467258550227
90th percentile threshold for Wisconsin - dt2_avg: 0.006559152156114578
90th percentile thres

In [None]:
# import networkx as nx
# import torch
# import numpy as np

# # Select dataset and model
# DATASET = "Cora"
# MODEL = "dt2_avg"  # Try 'dt_avg' or 'dt2_avg' as well

# # Extract attention matrix
# attention_matrix = A1L_2H[DATASET][MODEL]

# # Flatten attention values
# attn_values = attention_matrix.flatten().cpu().numpy()

# # Use 99.5th percentile to filter only the strongest connections
# percentile_threshold = np.percentile(attn_values, 90)

# # Alternative: Use mean + std deviation as a threshold
# mean_threshold = attn_values.mean() + 1.5 * attn_values.std()

# # Choose the better threshold
# threshold = max(percentile_threshold, mean_threshold)

# print(f"Threshold for {DATASET} - {MODEL}: {threshold}")

# def create_graph_from_attention(attention_matrix, threshold):
#     """Creates a directed graph from an attention matrix."""
#     G = nx.DiGraph()
#     num_nodes = attention_matrix.shape[0]

#     for i in range(num_nodes):
#         G.add_node(i)

#     for i in range(num_nodes):
#         for j in range(num_nodes):
#             weight = attention_matrix[i, j].item()
#             if weight > threshold:
#                 G.add_edge(i, j, weight=weight)

#     return G

# # Create graph
# Cora_graph = create_graph_from_attention(attention_matrix, threshold)

# # Print summary
# print(f"Graph for {DATASET} - {MODEL}: {Cora_graph.number_of_nodes()} nodes, {Cora_graph.number_of_edges()} edges")


Threshold for Chameleon - st_avg: 0.018336549401283264
Graph for Chameleon - st_avg: 2277 nodes, 25924 edges


In [None]:
# from google.colab import drive
# import networkx as nx
# import os

# # Mount Google Drive
# drive.mount('/content/drive')

# # Define the save path in your Google Drive
# save_path = "/content/drive/My Drive/Colab Notebooks/Chameleon_attention_graph.graphml"

# # Save the graph in GraphML format (recommended for future use)
# nx.write_graphml(Cora_graph, save_path)

# print(f"Graph saved successfully at: {save_path}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Graph saved successfully at: /content/drive/My Drive/Colab Notebooks/Chameleon_attention_graph.graphml


In [None]:
import networkx as nx
import torch
import numpy as np
from google.colab import drive
import os


# Define save directory in Google Drive
BASE_SAVE_PATH = "/content/drive/My Drive/Colab Notebooks/attention_graphs99.5"
os.makedirs(BASE_SAVE_PATH, exist_ok=True)

ATTN_SETS = {
    "1L2H": all_attns_1L_2H,
    "2L2H": all_attns_2L_2H
}

# Deliberately NOT called DATASETS: that name holds the dict of loaded PyG
# datasets, which later cells still need (`DATASETS.items()`, `DATASETS.keys()`).
GRAPH_DATASETS = ["Cora", "Citeseer", "Chameleon"]
MODELS = ["dt_avg", "dt2_avg"]  # DLB and DL models


def create_graph_from_attention(attention_matrix, threshold):
    """Creates a directed graph from an attention matrix.

    Args:
        attention_matrix: Square 2-D attention matrix (torch tensor or
            array-like). Entry [i, j] becomes the weight of edge i -> j.
        threshold: Only entries strictly greater than this value become edges.

    Returns:
        nx.DiGraph with one node per row (0 .. N-1) and a `weight` attribute on
        every edge.
    """
    if torch.is_tensor(attention_matrix):
        weights = attention_matrix.detach().cpu().numpy()
    else:
        weights = np.asarray(attention_matrix)
    # Compare in float64 so the result matches an element-by-element comparison
    # of Python floats, whatever the input dtype or NumPy promotion rules.
    weights = weights.astype(np.float64)

    G = nx.DiGraph()
    G.add_nodes_from(range(weights.shape[0]))

    # Vectorised replacement for the O(N^2) Python double loop. np.nonzero
    # returns indices in row-major order, so edges are inserted in (i, j) order.
    rows, cols = np.nonzero(weights > float(threshold))
    G.add_weighted_edges_from(
        zip(rows.tolist(), cols.tolist(), weights[rows, cols].tolist())
    )
    return G


def get_threshold(attention_matrix, model_name):
    """Returns threshold using different percentiles based on model.

    The threshold is the larger of a percentile cut-off (99.5th for "dt2_avg",
    90th for every other model) and `mean + 1.5 * std` of all attention values.

    Args:
        attention_matrix: Torch tensor of attention values.
        model_name: Model key; selects the percentile.

    Returns:
        Scalar threshold.
    """
    attn_values = attention_matrix.flatten().cpu().numpy()

    percentile = 99.5 if model_name == "dt2_avg" else 90
    percentile_threshold = np.percentile(attn_values, percentile)
    mean_threshold = attn_values.mean() + 1.5 * attn_values.std()

    return max(percentile_threshold, mean_threshold)


for dataset_name in GRAPH_DATASETS:
    for model_key in MODELS:
        for config_name, attn_dict in ATTN_SETS.items():
            # Best effort: one failing combination is reported and skipped so
            # the remaining graphs are still exported.
            try:
                attention_matrix = attn_dict[dataset_name][model_key]

                # Reduce any leading dims so a 2-D node-by-node matrix remains.
                # NOTE(review): a 4-D tensor is averaged over both leading dims.
                # If dim 0 indexes layers, that differs from the A_L2 @ A_L1
                # layer composition described in "Combining Attention" - confirm.
                if attention_matrix.ndim == 4:
                    attention_matrix = attention_matrix.mean(dim=(0, 1))
                elif attention_matrix.ndim == 3:
                    attention_matrix = attention_matrix.mean(dim=0)
                elif attention_matrix.ndim != 2:
                    raise ValueError(f"Unexpected attention shape: {attention_matrix.shape}")

                threshold = get_threshold(attention_matrix, model_key)
                G = create_graph_from_attention(attention_matrix, threshold)

                filename = f"{dataset_name}_{model_key}_{config_name}.graphml"
                save_path = os.path.join(BASE_SAVE_PATH, filename)
                nx.write_graphml(G, save_path)

                print(f"✅ Saved {filename}: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
            except Exception as e:
                print(f"Failed to process {dataset_name} - {model_key} - {config_name}: {e}")


✅ Saved Cora_dt_avg_1L2H.graphml: 2708 nodes, 16072 edges
✅ Saved Cora_dt_avg_2L2H.graphml: 2708 nodes, 10504 edges
✅ Saved Cora_dt2_avg_1L2H.graphml: 2708 nodes, 36667 edges
✅ Saved Cora_dt2_avg_2L2H.graphml: 2708 nodes, 36667 edges
✅ Saved Citeseer_dt_avg_1L2H.graphml: 3327 nodes, 28646 edges
✅ Saved Citeseer_dt_avg_2L2H.graphml: 3327 nodes, 10147 edges
✅ Saved Citeseer_dt2_avg_1L2H.graphml: 3327 nodes, 55345 edges
✅ Saved Citeseer_dt2_avg_2L2H.graphml: 3327 nodes, 55345 edges
✅ Saved Chameleon_dt_avg_1L2H.graphml: 2277 nodes, 27677 edges
✅ Saved Chameleon_dt_avg_2L2H.graphml: 2277 nodes, 29595 edges
✅ Saved Chameleon_dt2_avg_1L2H.graphml: 2277 nodes, 25924 edges
✅ Saved Chameleon_dt2_avg_2L2H.graphml: 2277 nodes, 25924 edges


In [None]:
import networkx as nx
import os
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
import matplotlib.pyplot as plt
import seaborn as sns

# Configuration
# Folder the attention graphs (GraphML) are read from.
BASE_PATH = "/content/drive/My Drive/Colab Notebooks/attention_graphs99.5"
# Attention graphs compared against the original graph, one per model / configuration.
ATTENTION_GRAPH_FILES = [
    "Chameleon_dt_avg_1L2H.graphml",
    "Chameleon_dt_avg_2L2H.graphml",
    "Chameleon_dt2_avg_1L2H.graphml",
    "Chameleon_dt2_avg_2L2H.graphml"
]
# Name of the original dataset; passed to load_original_graph and used in labels.
DATASET = "Chameleon"
# Folder the heatmaps and top-k overlap reports are written to.
OUTPUT_DIR = "/content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs"
# NOTE(review): TOP_K is not referenced anywhere in this cell; save_topk_overlap
# uses its own k=100.
TOP_K = 100

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Load Chameleon using PyTorch Geometric
import torch
from torch_geometric.datasets import Planetoid, WikipediaNetwork
from torch_geometric.utils import to_networkx


def load_original_graph(dataset_name):
    """Load a benchmark dataset and return its graph as an undirected NetworkX graph.

    Args:
        dataset_name: One of "Cora", "Citeseer" or "Chameleon".

    Returns:
        The first graph of the dataset, converted with `to_undirected=True`.

    Raises:
        ValueError: If `dataset_name` is not one of the supported names.
    """
    # Loaders are wrapped in lambdas so only the requested dataset is touched.
    known_datasets = (
        ("Cora", lambda: Planetoid(root="/tmp/Cora", name="Cora")),
        ("Citeseer", lambda: Planetoid(root="/tmp/Citeseer", name="Citeseer")),
        ("Chameleon", lambda: WikipediaNetwork(root="/tmp/Chameleon", name="chameleon")),
    )
    for known_name, load_dataset in known_datasets:
        if dataset_name == known_name:
            data = load_dataset()[0]
            return to_networkx(data, to_undirected=True)

    raise ValueError(f"Unknown dataset: {dataset_name}")


# Compute centrality metrics
def compute_metrics(G):
    """Compute six per-node structural metrics for a graph.

    Args:
        G: NetworkX graph.

    Returns:
        Dict mapping metric name ('degree', 'betweenness', 'closeness',
        'eigenvector', 'clustering', 'pagerank') to that metric's per-node
        result from NetworkX.
    """
    metrics = {}
    metrics['degree'] = nx.degree_centrality(G)
    metrics['betweenness'] = nx.betweenness_centrality(G, normalized=True)
    metrics['closeness'] = nx.closeness_centrality(G)
    # A loose tolerance and a raised iteration cap are passed to the power iteration.
    metrics['eigenvector'] = nx.eigenvector_centrality(G, max_iter=500, tol=1e-02)
    metrics['clustering'] = nx.clustering(G)
    metrics['pagerank'] = nx.pagerank(G, alpha=0.85)
    return metrics

# Compute full metric-vs-metric correlation matrix
def compute_cross_metric_correlation(metrics1, metrics2, label1, label2):
    """Spearman correlation between every metric of one graph and every metric of another.

    Values are aligned on the node ids of `metrics1`. A node that is missing
    from a `metrics2` metric contributes 0.0.

    Args:
        metrics1: Dict metric name -> {node: value} for the first graph.
        metrics2: Dict metric name -> {node: value} for the second graph.
        label1: Prefix for the row labels ("<label1>: <metric>").
        label2: Prefix for the column labels ("<label2>: <metric>").

    Returns:
        Float DataFrame with one row per metric of `metrics1` and one column
        per metric of `metrics2`.
    """
    row_metrics = list(metrics1)
    col_metrics = list(metrics2)
    correlations = np.full((len(row_metrics), len(col_metrics)), np.nan)

    for row, m1 in enumerate(row_metrics):
        reference = metrics1[m1]
        v1 = np.array(list(reference.values()))
        for col, m2 in enumerate(col_metrics):
            other = metrics2[m2]
            # Same node order as v1; missing nodes default to 0.0.
            v2 = np.array([other.get(node, 0.0) for node in reference])
            rho, _ = spearmanr(v1, v2)
            correlations[row, col] = rho

    return pd.DataFrame(
        correlations,
        index=[f"{label1}: {m}" for m in row_metrics],
        columns=[f"{label2}: {m}" for m in col_metrics],
    )

# Save correlation heatmap
def save_corr_heatmap(corr_df, title, out_path):
    """Render a correlation matrix as an annotated heatmap and save it to disk.

    Args:
        corr_df: DataFrame of correlations; the colour scale is fixed to [-1, 1].
        title: Figure title.
        out_path: Destination image path.
    """
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(corr_df, annot=True, fmt=".2f", vmin=-1, vmax=1, cmap='coolwarm', cbar=True, ax=ax)
    ax.set_title(title)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    fig.savefig(out_path)
    # Close explicitly so figures do not pile up when called in a loop.
    plt.close(fig)

# Save top-k node overlap
def save_topk_overlap(metrics1, metrics2, label1, label2, out_path, k=100):
    """Write a report of how many top-k nodes two graphs share, per metric.

    For every metric in `metrics1`, the `k` highest-scoring nodes of each graph
    are intersected. The report lists the overlap size and up to ten of the
    shared node ids.

    Args:
        metrics1: Dict metric name -> {node: value} for the first graph.
        metrics2: Dict metric name -> {node: value} for the second graph. It
            must contain every metric present in `metrics1`.
        label1: Name of the first graph (report header only).
        label2: Name of the second graph (report header only).
        out_path: Text file to write.
        k: Number of top-ranked nodes taken from each graph. Defaults to 100,
            the previously hard-coded value.

    Raises:
        KeyError: If a metric of `metrics1` is missing from `metrics2`.
    """
    def top_k_nodes(scores):
        # Stable sort: nodes with equal scores keep their dict order.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return {node for node, _ in ranked[:k]}

    with open(out_path, "w") as f:
        f.write(f"Top-{k} Node Overlap Between {label1} and {label2}\n")
        for metric in metrics1:
            common_nodes = top_k_nodes(metrics1[metric]) & top_k_nodes(metrics2[metric])
            f.write(f"\n- {metric.capitalize()}:\n")
            f.write(f"  Overlapping Nodes: {len(common_nodes)}\n")
            if common_nodes:
                sample = sorted(common_nodes)[:10]
                f.write(f"  Sample Nodes: {sample} ...\n")

# ----- Main Execution -----
# Metrics of the original graph are computed once and reused for every comparison.
original_graph = load_original_graph(DATASET)
original_metrics = compute_metrics(original_graph)

for filename in ATTENTION_GRAPH_FILES:
    path = os.path.join(BASE_PATH, filename)
    if os.path.exists(path):
        attention_label = filename.replace(".graphml", "")
        print(f"Comparing {DATASET} to {attention_label}...")

        # GraphML stores node ids as strings; convert them back to integers so
        # they line up with the node ids of the original graph.
        G = nx.relabel_nodes(nx.read_graphml(path), lambda x: int(x))
        attention_metrics = compute_metrics(G)

        # Metric-vs-metric correlation between the original and attention graph.
        corr_df = compute_cross_metric_correlation(
            original_metrics, attention_metrics, DATASET, attention_label
        )

        # Outputs are named after the attention graph they describe.
        heatmap_file = os.path.join(OUTPUT_DIR, f"{attention_label}_heatmap.png")
        overlap_file = os.path.join(OUTPUT_DIR, f"{attention_label}_topk_overlap.txt")

        save_corr_heatmap(corr_df, f"{DATASET} vs {attention_label}", heatmap_file)
        save_topk_overlap(original_metrics, attention_metrics, DATASET, attention_label, overlap_file)

        print(f"Saved: {heatmap_file}, {overlap_file}")
    else:
        print(f"Missing file: {path}")


Comparing Chameleon to Chameleon_dt_avg_1L2H...
Saved: /content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs/Chameleon_dt_avg_1L2H_heatmap.png, /content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs/Chameleon_dt_avg_1L2H_topk_overlap.txt
Comparing Chameleon to Chameleon_dt_avg_2L2H...
Saved: /content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs/Chameleon_dt_avg_2L2H_heatmap.png, /content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs/Chameleon_dt_avg_2L2H_topk_overlap.txt
Comparing Chameleon to Chameleon_dt2_avg_1L2H...
Saved: /content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs/Chameleon_dt2_avg_1L2H_heatmap.png, /content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs/Chameleon_dt2_avg_1L2H_topk_overlap.txt
Comparing Chameleon to Chameleon_dt2_avg_2L2H...
Saved: /content/drive/My Drive/Colab Notebooks/Chameleon_comparison_outputs/Chameleon_dt2_avg_2L2H_heatmap.png, /content/drive/My Drive/Colab Noteb

In [None]:
import matplotlib.pyplot as plt
import networkx as nx

def visualize_graph(G, layout='spring', node_size=20, edge_width=0.1):
    """
    Draw a graph with NetworkX on a 12x12 inch matplotlib figure.

    Args:
        G (nx.DiGraph): The directed graph.
        layout (str): Node positioning: 'spring' (force-directed, k=0.15 and
            20 iterations), 'circular' or 'kamada_kawai'. Any other value falls
            back to a spring layout with NetworkX defaults.
        node_size (int): Size of the nodes.
        edge_width (float): Width of the edges.
    """
    layout_functions = {
        'spring': lambda graph: nx.spring_layout(graph, k=0.15, iterations=20),
        'circular': nx.circular_layout,
        'kamada_kawai': nx.kamada_kawai_layout,
    }
    # Unknown (or non-string) layout names use the default spring layout.
    compute_positions = nx.spring_layout
    if isinstance(layout, str):
        compute_positions = layout_functions.get(layout, nx.spring_layout)
    pos = compute_positions(G)

    # Draw the graph
    draw_options = {
        'node_size': node_size,
        'with_labels': False,
        'edge_color': 'gray',
        'width': edge_width,
    }
    plt.figure(figsize=(12, 12))
    nx.draw(G, pos, **draw_options)
    plt.title(f"Graph Visualization ({layout} layout)")
    plt.show()

# Visualize the graph
# No cell defines `Citeseer_graph` (the interactive graph-building cells are
# commented out), so this call used to fail with a NameError on a fresh run.
# Reload one of the exported Citeseer attention graphs instead.
# NOTE(review): picking the dt_avg / 1L2H export is an assumption - change the
# file name to the graph you want to inspect.
Citeseer_graph = nx.read_graphml(os.path.join(BASE_SAVE_PATH, "Citeseer_dt_avg_1L2H.graphml"))
visualize_graph(Citeseer_graph, layout='spring')


NameError: name 'Cora_graph' is not defined

# Analysis Part I: What does attention do in different models in homophilous vs heterophilous tasks?


## Section 1.1: Model Accuracies

### Table 2: Accuracy Statistics

In [None]:
### Table 2 ###
### Accuracy Statistics ###

import pandas as pd

pd.set_option('display.max_columns', None)
model_specs = {
    '1L1H': all_stats_1L_1H,
    '1L2H': all_stats_1L_2H,
    '2L1H': all_stats_2L_1H,
    '2L2H': all_stats_2L_2H
}


def _mean_std_summary(acc):
    """Group by the first index level and return the group means and stds side by side."""
    grouped = acc.groupby(level=0)
    return pd.concat({'mean': grouped.mean(), 'std': grouped.std()}, axis=1)


df_spec = {}

for spec, all_stats in model_specs.items():
    all_stats_df = {}
    for data_key, run_stats in all_stats.items():
        # Stack the per-entry accuracy frames under a two-level row index
        # (outer = key of `run_stats`, inner = the frame's own index).
        table1 = pd.concat(
            {key: pd.DataFrame(stats['accuracy']) for key, stats in run_stats.items()},
            axis=0,
        )
        # The train summary is computed but only the test summary is reported.
        table1_train = _mean_std_summary(table1.xs('train_acc', level=1))
        table1_test = _mean_std_summary(table1.xs('test_acc', level=1))
        all_stats_df[data_key] = table1_test
    # One column group per dataset, rounded for display.
    df_spec[spec] = pd.concat(all_stats_df, axis=1).round(2)

In [None]:
# Stack the per-configuration accuracy tables (1L1H, 1L2H, 2L1H, 2L2H) into one
# frame, keyed by configuration on the outer row level.
# Ignore the rows for GCN with 2 heads
pd.concat(df_spec)

## Section 1.2: Do the Nodes Attend to Neighbors?

### Table 3: Average Attention to Neighbors

In [None]:
datasets_dict = dict(DATASETS.items())
data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin'] #list(DATASETS.keys())
metrics = ['st_avg', 'dt_avg', 'dt2_avg']
model_specs = {'1L1H': A1L_1H, '1L2H': A1L_2H, '2L1H': A2L_1H, '2L2H': A2L_2H }

# Ratio of mean attention on non-neighbors to mean attention on neighbors,
# per dataset and model (values below 1 mean neighbors receive more attention).
# NOTE(review): the result is still called `df_1L1H` (the next cell reads that
# name), but the attention used here is the 2L2H configuration - confirm which
# configuration the table is meant to report.
df_1L1H = {}

all_attns = model_specs['2L2H']

for data_key in data_keys:
    # The adjacency is the same for every model, so build the masks once per
    # dataset. `dense_adj` includes self-loops, so a node counts as its own neighbor.
    adj = datasets_dict[data_key].dense_adj.cpu()
    neighbor_mask = adj == 1
    non_neighbor_mask = adj == 0

    df_1L1H[data_key] = {}
    for metric in metrics:
        attn = all_attns[data_key][metric].cpu()
        # The shortest-path matrix is deliberately not loaded into `sp` here:
        # it was unused and overwrote the `scipy as sp` import alias.
        df_1L1H[data_key][metric] = attn[non_neighbor_mask].mean().item() / attn[neighbor_mask].mean().item()


In [None]:
# Display the ratios computed in the previous cell: one column per dataset,
# one row per model key, rounded to 2 decimals.
# pd.concat({key: pd.DataFrame(df_1L1H[key]) for key in df_1L1H.keys()}, axis=0).rename(columns={'st_avg': 'Sparse Transformer', 'dt_avg': 'Dense Transformer wB', 'dt2_avg': 'Dense Transformer'}).round(2)
pd.DataFrame(df_1L1H) .round(2)

### Figure 1: Attention to Neighbors

In [None]:
### Figure 1 ###

import matplotlib.pyplot as plt
import numpy as np

# data_key = 'Cora'
# metric = 'st_avg'

metrics = ['st_avg', 'dt_avg', 'dt2_avg']
metrics_labels = {'st_avg':'SL', 'dt_avg':'DLB', 'dt2_avg':'DL'}
model_specs = {'1L1H': A1L_1H, '1L2H': A1L_2H, '2L1H': A2L_1H, '2L2H': A2L_2H }

DATASETS1 = {key:DATASETS[key] for key in list(DATASETS.keys())[:4]}
DATASETS2 = {key:DATASETS[key] for key in list(DATASETS.keys())[4:]}

DATASET_CURR = DATASETS2
x_min, x_max = -0.1, 1.5
y_min, y_max = 0, 1.1

# One scatter series (and one colour) per (model spec, variant) combination.
n_series = len(model_specs) * len(metrics)

fig, axes = plt.subplots(1, len(DATASET_CURR), figsize=(24, 6))
colors = plt.cm.viridis(np.linspace(0, 1, n_series))

for idx, (ax, (data_key, data_value)) in enumerate(zip(axes, DATASET_CURR.items())):
    # The adjacency depends only on the dataset, so flatten/binarise it once per panel.
    # NOTE(review): flatten() can return a view of the stored tensor, in which case this
    # in-place write also changes the dataset's `dense_adj`. Kept as in the original
    # because later cells may rely on the binarised values — confirm.
    adj = data_value["dense_adj"].flatten()
    adj[adj == 2] = 1
    adj_np = adj.cpu().numpy()

    color_idx = 0
    add = 0  # horizontal jitter so the series sit side by side instead of overlapping
    for model_spec in model_specs:
        all_attns = model_specs[model_spec]
        for metric in metrics:
            attn = all_attns[data_key][metric].flatten()
            label_ = f'{model_spec}-{metrics_labels[metric]}'
            # `color=` (not `c=`): a single RGBA row passed as `c` is ambiguous with an
            # array of values to colour-map and makes matplotlib emit a warning.
            ax.scatter(adj_np + add, attn.cpu().numpy(), label=label_, color=colors[color_idx], marker='o')
            color_idx += 1
            add += 0.025

    # Per-panel formatting: needed once per axes, not once per series.
    ax.set_title(f'{data_key}', fontsize=20)
    ax.set_xticks([0.15, 1.15])
    ax.set_xticklabels(['Non-Neighbor', 'Neighbor'], fontsize=20)
    if idx == 0:
        ax.set_ylabel('Attention', fontsize=20)
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.grid(True)

plt.tight_layout()
# Every panel carries the same series, so the handles of the current (last) axes suffice.
handles, labels = fig.gca().get_legend_handles_labels()
fig.legend(handles[:n_series], labels[:n_series], loc='lower center',  ncol=4, fontsize=20, bbox_to_anchor=(0.5, -0.3))
plt.show()

In [None]:
# # Ensure DATASET_CURR is correct
# DATASET_CURR = DATASETS1

# x_min, x_max = -0.1, 1.5
# y_min, y_max = 0, 1.1

# fig, axes = plt.subplots(1, len(DATASET_CURR), figsize=(24, 6))
# colors = plt.cm.viridis(np.linspace(0, 1, len(model_specs) * len(metrics)))

# for idx, (ax, (data_key, data_value)) in enumerate(zip(axes, DATASET_CURR.items())):
#     color_idx = 0
#     add = 0
#     for model_spec in model_specs:
#         all_attns = model_specs[model_spec]
#         for metric in metrics:
#             attn = all_attns[data_key][metric].flatten().cpu().numpy()
#             adj = data_value["dense_adj"].flatten().cpu().numpy()
#             adj[adj == 2] = 1
#             label_ = f'{model_spec}-{metrics_labels[metric]}'

#             ax.scatter(adj + add, attn, label=label_, c=[colors[color_idx]], marker='o')
#             ax.set_title(f'{data_key}', fontsize=20)
#             ax.set_xticks([0.15, 1.15])
#             ax.set_xticklabels(['Non-Neighbor', 'Neighbor'], fontsize=20)
#             if idx == 0:
#                 ax.set_ylabel('Attention', fontsize=20)
#             ax.set_xlim(x_min, x_max)
#             ax.set_ylim(y_min, y_max)
#             ax.grid(True)
#             color_idx += 1
#             add += 0.025

# plt.tight_layout()
# handles, labels = fig.gca().get_legend_handles_labels()
# fig.legend(handles[:12], labels[:12], loc='lower center', ncol=4, fontsize=20, bbox_to_anchor=(0.5, -0.3))
# plt.show()

### Figure 2: Attention to N-hop Neighborhood (1 Layer, 1 Head)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

metrics = ['st_avg', 'dt_avg', 'dt2_avg']
metrics_labels = {'st_avg': 'SL', 'dt_avg': 'DLB', 'dt2_avg': 'DL'}

DATASETS1 = {key: DATASETS[key] for key in list(DATASETS.keys())[:4]}
DATASETS2 = {key: DATASETS[key] for key in list(DATASETS.keys())[4:]}

model_specs = [A1L_1H]

DATASET_CURR = DATASETS

x_min, x_max = -0.05, 20
y_min, y_max = -0.05, 1.05

# One figure per model variant: attention value (y) against the corresponding entry
# of the dense shortest-path matrix (x), one panel per dataset.
for metric in metrics:
    for all_attns in model_specs:
        fig, axes = plt.subplots(1, len(DATASET_CURR), figsize=(24, 6))
        for idx, (ax, (data_key, data_value)) in enumerate(zip(axes, DATASET_CURR.items())):
            attn = all_attns[data_key][metric].flatten().cpu().numpy()
            sp = data_value["dense_sp_matrix"].flatten().cpu().numpy()

            ax.scatter(sp, attn, marker='x', color='blue', alpha=0.2)
            ax.set_title(f'{data_key}', fontsize=20)

            # Label the left-most panel whatever dataset it shows, so the label survives
            # switching DATASET_CURR to a selection that does not contain 'Cora'.
            if idx == 0:
                ax.set_ylabel(f'{metrics_labels[metric]} Attention', fontsize=20)
            ax.tick_params(axis='x', labelsize=20)
            ax.tick_params(axis='y', labelsize=20)
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_min, y_max)
            ax.grid(True)

        plt.tight_layout()
        plt.show()

### [NOT INCLUDED] Expected N-Hop Attention per Class

In [None]:
### [NOT INCLUDED] Expected N-hop attention per class ###

from matplotlib.ticker import MaxNLocator

DATASETS1 = {key:DATASETS[key] for key in list(DATASETS.keys())[:4]}
DATASETS2 = {key:DATASETS[key] for key in list(DATASETS.keys())[4:]}

DATASET_CURR = DATASETS2

model_specs = [A1L_1H, A1L_2H, A2L_1H, A2L_2H ]

model_key = 'dt2_avg'
n_datasets = len(DATASET_CURR)

for all_attns in model_specs:
    # squeeze=False keeps `axes` an array even when a single dataset is selected.
    fig, axes = plt.subplots(1, n_datasets, figsize=(10 * n_datasets, 6), squeeze=False)  # Adjust the figure size as needed
    for data_key, ax in zip(DATASET_CURR, axes.flatten()):
        # Everything stays on the CPU: the result is consumed by numpy/matplotlib, so
        # there is no reason to require a GPU here.
        sp_curr = torch.nan_to_num(DATASETS[data_key].dense_sp_matrix.cpu(), posinf=0)
        attention_curr = all_attns[data_key][model_key].cpu()
        # Row-wise sum of attention * shortest-path entry (one value per row/node);
        # +inf entries were zeroed above and therefore contribute nothing.
        expected_attention = (attention_curr * sp_curr).sum(dim=1).numpy()
        classes = DATASETS[data_key].y.cpu().numpy()

        # Calculate the mean expected attention for each class
        unique_classes = np.unique(classes)
        mean_attentions = [np.mean(expected_attention[classes == cls]) for cls in unique_classes]

        # Scatter plot for individual data points
        ax.scatter(classes, expected_attention, color='black', s=50, label=f'Dataset {data_key} Data')

        # Bar chart for mean expected attention
        ax.bar(unique_classes, mean_attentions, color='blue', alpha=0.2, label=f'Dataset {data_key} Mean')

        ax.set_title(f'{data_key}: Classes vs. Expected Attention')
        ax.set_xlabel('Classes')
        ax.set_ylabel('Expected Attention')
        ax.grid(True)
        ax.legend(title='Legend')

        # Ensure that only integer values are shown on the x-axis. This must run for
        # every panel; outside the inner loop it only affected the last one.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.tight_layout()
    plt.show()

# Analysis Part II: A framework to analyze a model with multiple heads and layers

**Combining Attention**

*   We combine the attention matrices of a layer by averaging them over its heads. This gives how much each node attends to every other node on average.

*   We combine the attention matrices across layers by matrix multiplying $A_{L2} A_{L1}$. This gives us how much a node attends to another across layers.

In [None]:
### Combined Attention Matrices ###
# One matrix per (dataset, model variant): heads are averaged (mean over axis 1);
# for the two-layer models the head-averaged matrix at index 1 is matrix-multiplied
# with the one at index 0 and the product is moved to the CPU.

model_keys = ['st_avg', 'dt_avg', 'dt2_avg']
data_keys = list(DATASETS.keys())


def _single_layer_attention(raw_attns):
    """Head-averaged attention at index 0, keyed by dataset and then by model key."""
    combined = {}
    for dataset_name in data_keys:
        combined[dataset_name] = {}
        for variant in model_keys:
            head_mean = raw_attns[dataset_name][variant].mean(axis=1)
            combined[dataset_name][variant] = head_mean[0]
    return combined


def _two_layer_attention(raw_attns):
    """Product (index 1) @ (index 0) of the head-averaged attention, moved to the CPU."""
    combined = {}
    for dataset_name in data_keys:
        combined[dataset_name] = {}
        for variant in model_keys:
            head_mean = raw_attns[dataset_name][variant].mean(axis=1)
            combined[dataset_name][variant] = (head_mean[1] @ head_mean[0]).cpu()
    return combined


A1L_1H = _single_layer_attention(all_attns_1L_1H)
A1L_2H = _single_layer_attention(all_attns_1L_2H)
A2L_1H = _two_layer_attention(all_attns_2L_1H)
A2L_2H = _two_layer_attention(all_attns_2L_2H)

## Section 2.1: Do the heads learn the same patterns?

### [NOT INCLUDED] N-Hop Neighborhood Attendance Comparison for 2 Heads (1L2H Model)

In [None]:
### [NOT INCLUDED] Per-head attention vs. N-hop distance ###

import matplotlib.pyplot as plt
import numpy as np

num_heads = 2

# data_key = 'Cora'
metric = 'dt2_avg'

DATASETS1 = {key:DATASETS[key] for key in list(DATASETS.keys())[:4]}
DATASETS2 = {key:DATASETS[key] for key in list(DATASETS.keys())[4:]}

DATASET_CURR = DATASETS2

# NOTE(review): the section heading says "1L2H Model", but the attention plotted here
# is read from `all_attns_2L_2H` (index 0 of its first axis) — confirm which is intended.
all_attns = all_attns_2L_2H
x_min, x_max = 0, 10
y_min, y_max = 0, 0.5

fig, axes = plt.subplots(1, len(DATASET_CURR), figsize=(24, 6))
for ax, (data_key, data_value) in zip(axes, DATASET_CURR.items()):
    sp = data_value["dense_sp_matrix"].cpu()
    for head_idx in range(num_heads):
        attn = all_attns[data_key][metric][0][head_idx].cpu()
        # Shift each head slightly along x so the point clouds do not hide each other.
        ax.scatter(sp + 0.1 * head_idx, attn, label=f'Head {head_idx + 1}')
    ax.set_title(f'{data_key}: Attention Paid to N-hop Neighborhood')
    ax.set_xlabel('N-Hop Neighborhood')
    ax.set_ylabel('Attention')
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.grid(True)
    ax.legend()

plt.tight_layout()
plt.show()

### Figure 3: Comparison of Head Attention Patterns

In [None]:
import matplotlib.pyplot as plt

data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin']
metrics = ['st_avg', 'dt_avg', 'dt2_avg']
metrics_labels = {'st_avg': 'SL', 'dt_avg': 'DLB', 'dt2_avg': 'DL'}

x_min, x_max = 0, 1
y_min, y_max = 0, 1

# Grid of scatter plots, one row per model variant and one column per dataset.
# Each point pairs the value at index [0][0] (x, "Head 1") with the value at
# index [0][1] (y, "Head 2") of the attention stored in `all_attns_1L_2H`.
fig, axs = plt.subplots(len(metrics), len(data_keys), figsize=(20, 9))
bottom_row = len(metrics) - 1

for i, model_name in enumerate(metrics):
    for j, dataset_name in enumerate(data_keys):
        panel = axs[i, j]
        per_head = all_attns_1L_2H[dataset_name][model_name][0]
        head_1 = per_head[0].cpu().numpy()
        head_2 = per_head[1].cpu().numpy()

        panel.scatter(head_1, head_2, color='black', marker='x', alpha=0.2)

        # Titles only on the top row, y-labels only on the first column and
        # x-labels only on the bottom row.
        if i == 0:
            panel.set_title(dataset_name, fontsize=16)
        if j == 0:
            panel.set_ylabel(f'{metrics_labels[model_name]}\nHead 2', fontsize=14)
        if i == bottom_row:
            panel.set_xlabel('Head 1', fontsize=14)

        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_xlim(x_min, x_max)
        panel.set_ylim(y_min, y_max)
        panel.grid(True)

plt.tight_layout()
plt.show()

## Section 2.2: Do the layers learn the same pattern?

### Figure 4: Comparison of Layer Attention Patterns

In [None]:
import matplotlib.pyplot as plt

data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin']
metrics = ['st_avg', 'dt_avg', 'dt2_avg']
metrics_labels = {'st_avg': 'SL', 'dt_avg': 'DLB', 'dt2_avg': 'DL'}

x_min, x_max = 0, 1
y_min, y_max = 0, 1

# Grid of scatter plots, one row per model variant and one column per dataset.
# Each point pairs the value at index [0][0] (x, "Layer 1") with the value at
# index [1][0] (y, "Layer 2") of the attention stored in `all_attns_2L_1H`.
fig, axs = plt.subplots(len(metrics), len(data_keys), figsize=(20, 9))
bottom_row = len(metrics) - 1

for i, model_name in enumerate(metrics):
    for j, dataset_name in enumerate(data_keys):
        panel = axs[i, j]
        per_layer = all_attns_2L_1H[dataset_name][model_name]
        layer1 = per_layer[0][0].cpu().numpy()
        layer2 = per_layer[1][0].cpu().numpy()

        panel.scatter(layer1, layer2, color='black', marker='x', alpha=0.2)

        # Titles only on the top row, y-labels only on the first column and
        # x-labels only on the bottom row.
        if i == 0:
            panel.set_title(dataset_name, fontsize=16)
        if j == 0:
            panel.set_ylabel(f'{metrics_labels[model_name]}\nLayer 2', fontsize=14)
        if i == bottom_row:
            panel.set_xlabel('Layer 1', fontsize=14)

        panel.set_xticks([])
        panel.set_yticks([])
        panel.set_xlim(x_min, x_max)
        panel.set_ylim(y_min, y_max)
        panel.grid(True)

plt.tight_layout()
plt.show()

## [NOT INCLUDED] Section 2.3: Do the Models learn the same pattern?

### [NOT INCLUDED] Model Comparison

In [None]:
data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin'] #list(DATASETS.keys())
metrics = ['st_avg', 'dt_avg', 'dt2_avg']
metrics_labels = {'st_avg':'SparseT', 'dt_avg':'DenseTwB', 'dt2_avg':'DenseT'}
model_specs ={'1L1H': A1L_1H, '2L2H': A2L_2H} #{'1L1H': A1L_1H, '1L2H': A1L_2H, '2L1H': A2L_1H, '2L2H': A2L_2H }

# Exactly two specs are compared: the first goes on the x-axis, the second on the y-axis.
(x_spec, x_attns), (y_spec, y_attns) = model_specs.items()

# One scatter per (variant, dataset). The variant is part of the title and both axes
# are labelled so that each plot can be identified on its own.
for metric in metrics:
    for dataset_name in data_keys:
        fig, ax = plt.subplots()
        ax.scatter(x_attns[dataset_name][metric].cpu(), y_attns[dataset_name][metric].cpu())
        ax.set_title(f'{dataset_name} ({metrics_labels[metric]})')
        ax.set_xlabel(f'{x_spec} combined attention')
        ax.set_ylabel(f'{y_spec} combined attention')
        plt.show()

# Analysis Part III: Does the graph structure in the attention recover the original graph structure?

## Section 3.1: Combining Attention Matrices

In [None]:
### Combined Attention Matrices ###

# `model_keys` is set explicitly: this cell used to read whatever value was left in the
# kernel, and a later cell rebinds it to ['dt_avg', 'dt2_avg'], so re-running this cell
# out of order silently dropped 'st_avg'.
model_keys = ['st_avg', 'dt_avg', 'dt2_avg']
data_keys = list(DATASETS.keys())


def _first_layer_mean(raw_attns, data_key, model_key):
    """Attention averaged over axis 1 (heads), taken at index 0."""
    return raw_attns[data_key][model_key].mean(axis=1)[0]


def _layer_product(raw_attns, data_key, model_key):
    """Head-averaged attention at index 1 matrix-multiplied with index 0, on the CPU."""
    head_mean = raw_attns[data_key][model_key].mean(axis=1)  # computed once, used twice
    return (head_mean[1] @ head_mean[0]).cpu()


A1L_1H = {data_key: {model_key: _first_layer_mean(all_attns_1L_1H, data_key, model_key) for model_key in model_keys} for data_key in data_keys}
A1L_2H = {data_key: {model_key: _first_layer_mean(all_attns_1L_2H, data_key, model_key) for model_key in model_keys} for data_key in data_keys}
A2L_1H = {data_key: {model_key: _layer_product(all_attns_2L_1H, data_key, model_key) for model_key in model_keys} for data_key in data_keys}
A2L_2H = {data_key: {model_key: _layer_product(all_attns_2L_2H, data_key, model_key) for model_key in model_keys} for data_key in data_keys}

## Section 3.2: Selecting a Threshold [Commented Out]

In [None]:
# import numpy as np

# model_keys = ['dt_avg', 'dt2_avg']
# data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin']
# model_specs = {'1L1H': A1L_1H, '1L2H': A1L_2H, '2L1H': A2L_1H, '2L2H': A2L_2H }

# thresholds = np.arange(0, 1, 0.001)  # Example range of thresholds
# selected_thresholds = {}

# for spec_key in model_specs:
#   print(spec_key)
#   spec = model_specs[spec_key]
#   selected_thresholds[spec_key] = {}
#   for data_key in data_keys:
#       selected_thresholds[spec_key][data_key] = {}
#       for model_key in model_keys:
#           adj = DATASETS[data_key].dense_adj.cpu()
#           attn_base = spec[data_key][model_key].cpu()

#           best_threshold = 0
#           min_difference = float('inf')

#           for threshold in thresholds:
#               attn = attn_base.clone()
#               attn[attn >= threshold] = 1
#               attn[attn < threshold] = 0

#               difference = abs(adj.sum() - attn.sum())
#               if difference < min_difference:
#                   min_difference = difference
#                   best_threshold = threshold

#           # Now, best_threshold is the threshold where the difference between sums is minimized.
#           print(f"Best threshold {data_key} - {model_key}: {best_threshold}")
#           print(f"Difference: {min_difference}")
#           selected_thresholds[spec_key][data_key][model_key] = best_threshold

##############
### Output ###
##############

# 1L1H
# Best threshold Cora - dt_avg: 0.037
# Difference: 87.0
# Best threshold Cora - dt2_avg: 0.019
# Difference: 470.0
# Best threshold Citeseer - dt_avg: 0.06
# Difference: 33.0
# Best threshold Citeseer - dt2_avg: 0.02
# Difference: 743.0
# Best threshold Chameleon - dt_avg: 0.006
# Difference: 243.0
# Best threshold Chameleon - dt2_avg: 0.006
# Difference: 2515.0
# Best threshold Squirrel - dt_avg: 0.003
# Difference: 3929.0
# Best threshold Squirrel - dt2_avg: 0.001
# Difference: 86327.0
# Best threshold Cornell - dt_avg: 0.079
# Difference: 2.0
# Best threshold Cornell - dt2_avg: 0.06
# Difference: 1.0
# Best threshold Texas - dt_avg: 0.039
# Difference: 4.0
# Best threshold Texas - dt2_avg: 0.07200000000000001
# Difference: 4.0
# Best threshold Wisconsin - dt_avg: 0.058
# Difference: 5.0
# Best threshold Wisconsin - dt2_avg: 0.051000000000000004
# Difference: 10.0
# 1L2H
# Best threshold Cora - dt_avg: 0.024
# Difference: 39.0
# Best threshold Cora - dt2_avg: 0.015
# Difference: 560.0
# Best threshold Citeseer - dt_avg: 0.054
# Difference: 69.0
# Best threshold Citeseer - dt2_avg: 0.012
# Difference: 1171.0
# Best threshold Chameleon - dt_avg: 0.005
# Difference: 505.0
# Best threshold Chameleon - dt2_avg: 0.006
# Difference: 4235.0
# Best threshold Squirrel - dt_avg: 0.003
# Difference: 10409.0
# Best threshold Squirrel - dt2_avg: 0.001
# Difference: 147306.0
# Best threshold Cornell - dt_avg: 0.049
# Difference: 2.0
# Best threshold Cornell - dt2_avg: 0.054
# Difference: 13.0
# Best threshold Texas - dt_avg: 0.034
# Difference: 3.0
# Best threshold Texas - dt2_avg: 0.059000000000000004
# Difference: 0.0
# Best threshold Wisconsin - dt_avg: 0.049
# Difference: 3.0
# Best threshold Wisconsin - dt2_avg: 0.052000000000000005
# Difference: 23.0
# 2L1H
# Best threshold Cora - dt_avg: 0.038
# Difference: 98.0
# Best threshold Cora - dt2_avg: 0.008
# Difference: 801.0
# Best threshold Citeseer - dt_avg: 0.044
# Difference: 29.0
# Best threshold Citeseer - dt2_avg: 0.007
# Difference: 1769.0
# Best threshold Chameleon - dt_avg: 0.014
# Difference: 825.0
# Best threshold Chameleon - dt2_avg: 0.005
# Difference: 2650.0
# Best threshold Squirrel - dt_avg: 0.005
# Difference: 33406.0
# Best threshold Squirrel - dt2_avg: 0.001
# Difference: 8669.0
# Best threshold Cornell - dt_avg: 0.08
# Difference: 0.0
# Best threshold Cornell - dt2_avg: 0.05
# Difference: 7.0
# Best threshold Texas - dt_avg: 0.056
# Difference: 2.0
# Best threshold Texas - dt2_avg: 0.036000000000000004
# Difference: 9.0
# Best threshold Wisconsin - dt_avg: 0.046
# Difference: 2.0
# Best threshold Wisconsin - dt2_avg: 0.065
# Difference: 13.0
# 2L2H
# Best threshold Cora - dt_avg: 0.039
# Difference: 42.0
# Best threshold Cora - dt2_avg: 0.007
# Difference: 2419.0
# Best threshold Citeseer - dt_avg: 0.041
# Difference: 101.0
# Best threshold Citeseer - dt2_avg: 0.006
# Difference: 856.0
# Best threshold Chameleon - dt_avg: 0.013000000000000001
# Difference: 432.0
# Best threshold Chameleon - dt2_avg: 0.004
# Difference: 4218.0
# Best threshold Squirrel - dt_avg: 0.005
# Difference: 39171.0
# Best threshold Squirrel - dt2_avg: 0.001
# Difference: 22364.0
# Best threshold Cornell - dt_avg: 0.052000000000000005
# Difference: 4.0
# Best threshold Cornell - dt2_avg: 0.048
# Difference: 9.0
# Best threshold Texas - dt_avg: 0.02
# Difference: 4.0
# Best threshold Texas - dt2_avg: 0.047
# Difference: 32.0
# Best threshold Wisconsin - dt_avg: 0.057
# Difference: 2.0
# Best threshold Wisconsin - dt2_avg: 0.048
# Difference: 1.0

In [None]:
# Save the thresholds
# import json
# file_path = 'drive/MyDrive/Colab Notebooks/L65/selected_thresholds_dict.pkl'
# with open(file_path, 'w') as file:
#     json.dump(selected_thresholds, file)

## Section 3.3: Analyzing Thresholded Attention

### Figure 5: Attention Heatmaps

In [None]:
import json

# The thresholds are stored as JSON text (despite the ".pkl" suffix) and are later
# indexed as selected_thresholds[spec][dataset][model key].
file_path = 'drive/MyDrive/Colab Notebooks/selected_thresholds_dict.pkl'
with open(file_path, 'r') as threshold_file:
    selected_thresholds = json.loads(threshold_file.read())

In [None]:
model_keys = ['dt_avg', 'dt2_avg']
metrics_labels = {'st_avg':'SL', 'dt_avg':'DLB', 'dt2_avg':'DL'}
data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin'] #list(DATASETS.keys())
model_specs = {'1L1H': A1L_1H, '1L2H': A1L_2H, '2L1H': A2L_1H, '2L2H': A2L_2H }

# thresholded_attentions[spec][dataset][label] is a 0/1 matrix that is 1 wherever the
# combined attention is >= the threshold loaded for that (spec, dataset, model key).
thresholded_attentions = {}
for spec_key, spec in model_specs.items():
    print(spec_key)
    thresholded_attentions[spec_key] = {}
    for data_key in data_keys:
        thresholded_attentions[spec_key][data_key] = {}
        for model_key in model_keys:
            threshold = selected_thresholds[spec_key][data_key][model_key]
            attn = spec[data_key][model_key].cpu()
            # Build a NEW tensor instead of writing into `attn`: `.cpu()` returns the very
            # same object when the tensor already lives on the CPU, so in-place assignment
            # overwrote the combined attention stored in `model_specs` and made this cell
            # give different results when run a second time.
            binary_attn = (attn >= threshold).to(attn.dtype)
            thresholded_attentions[spec_key][data_key][metrics_labels[model_key]] = binary_attn

In [None]:
# Show only the shape of each thresholded matrix; echoing the dict itself prints
# every tensor for all spec / dataset / variant combinations.
{
    spec_key: {
        data_key: {label: tuple(matrix.shape) for label, matrix in by_label.items()}
        for data_key, by_label in by_dataset.items()
    }
    for spec_key, by_dataset in thresholded_attentions.items()
}

In [None]:
model_keys = ['dt_avg', 'dt2_avg']
metrics_labels = {'st_avg':'SL', 'dt_avg':'DLB', 'dt2_avg':'DL'}
data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin'] #list(DATASETS.keys())
model_specs = {'1L1H': A1L_1H, '1L2H': A1L_2H, '2L1H': A2L_1H, '2L2H': A2L_2H }

import matplotlib.pyplot as plt
import numpy as np

# One heatmap of the dense adjacency matrix per dataset. The number of panels follows
# `data_keys`; squeeze=False keeps `axs` two-dimensional even for a single dataset.
fig, axs = plt.subplots(1, len(data_keys), figsize=(14, 4), squeeze=False)

for j, data_key in enumerate(data_keys):
    ax = axs[0, j]
    data = DATASETS[data_key].dense_adj.cpu()
    ax.imshow(data, cmap='hot', interpolation='nearest')
    ax.set_title(data_key, fontsize=15)
    if j == 0:
        # Row label drawn as text because the axis (and its ylabel) is switched off below.
        ax.text(-0.1, 0.5, 'Adjacency', transform=ax.transAxes, va='center', ha='right', rotation=90, fontsize=15)
    ax.axis('off')  # Turn off the axis
plt.tight_layout()
plt.show()

In [None]:
spec_key = '1L1H'

# Thresholded attention heatmaps for this spec: one row per entry of `model_keys`,
# one column per entry of `data_keys`. The grid size follows those lists;
# squeeze=False keeps `axs` two-dimensional for any grid size.
fig, axs = plt.subplots(len(model_keys), len(data_keys), figsize=(14, 4), squeeze=False)

for i, model_key in enumerate(model_keys):
    row_label = metrics_labels[model_key]
    for j, data_key in enumerate(data_keys):
        ax = axs[i, j]
        data = thresholded_attentions[spec_key][data_key][row_label]
        ax.imshow(data, cmap='hot', interpolation='nearest')
        ax.axis('off')  # Turn off the axis
        if j == 0:
            # Row label drawn as text because the axis itself is switched off.
            ax.text(-0.1, 0.5, f'{spec_key}\n{row_label}',
                    transform=ax.transAxes, va='center', ha='right',
                    rotation=90, fontsize=15)
plt.tight_layout()
plt.show()

In [None]:
spec_key = '1L2H'

# Thresholded attention heatmaps for this spec: one row per entry of `model_keys`,
# one column per entry of `data_keys`. The grid size follows those lists;
# squeeze=False keeps `axs` two-dimensional for any grid size.
fig, axs = plt.subplots(len(model_keys), len(data_keys), figsize=(14, 4), squeeze=False)

for i, model_key in enumerate(model_keys):
    row_label = metrics_labels[model_key]
    for j, data_key in enumerate(data_keys):
        ax = axs[i, j]
        data = thresholded_attentions[spec_key][data_key][row_label]
        ax.imshow(data, cmap='hot', interpolation='nearest')
        ax.axis('off')  # Turn off the axis
        if j == 0:
            # Row label drawn as text because the axis itself is switched off.
            ax.text(-0.1, 0.5, f'{spec_key}\n{row_label}',
                    transform=ax.transAxes, va='center', ha='right',
                    rotation=90, fontsize=15)
plt.tight_layout()
plt.show()

In [None]:
spec_key = '2L1H'

# Thresholded attention heatmaps for this spec: one row per entry of `model_keys`,
# one column per entry of `data_keys`. The grid size follows those lists;
# squeeze=False keeps `axs` two-dimensional for any grid size.
fig, axs = plt.subplots(len(model_keys), len(data_keys), figsize=(14, 4), squeeze=False)

for i, model_key in enumerate(model_keys):
    row_label = metrics_labels[model_key]
    for j, data_key in enumerate(data_keys):
        ax = axs[i, j]
        data = thresholded_attentions[spec_key][data_key][row_label]
        ax.imshow(data, cmap='hot', interpolation='nearest')
        ax.axis('off')  # Turn off the axis
        if j == 0:
            # Row label drawn as text because the axis itself is switched off.
            ax.text(-0.1, 0.5, f'{spec_key}\n{row_label}',
                    transform=ax.transAxes, va='center', ha='right',
                    rotation=90, fontsize=15)
plt.tight_layout()
plt.show()

In [None]:
spec_key = '2L2H'

# Thresholded attention heatmaps for this spec: one row per entry of `model_keys`,
# one column per entry of `data_keys`. The grid size follows those lists;
# squeeze=False keeps `axs` two-dimensional for any grid size.
fig, axs = plt.subplots(len(model_keys), len(data_keys), figsize=(14, 4), squeeze=False)

for i, model_key in enumerate(model_keys):
    row_label = metrics_labels[model_key]
    for j, data_key in enumerate(data_keys):
        ax = axs[i, j]
        data = thresholded_attentions[spec_key][data_key][row_label]
        ax.imshow(data, cmap='hot', interpolation='nearest')
        ax.axis('off')  # Turn off the axis
        if j == 0:
            # Row label drawn as text because the axis itself is switched off.
            ax.text(-0.1, 0.5, f'{spec_key}\n{row_label}',
                    transform=ax.transAxes, va='center', ha='right',
                    rotation=90, fontsize=15)
plt.tight_layout()
plt.show()

### Table 4: Adjacency Recovery: P, R, F1

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, matthews_corrcoef, precision_recall_fscore_support

model_keys = ['dt_avg', 'dt2_avg']
# These labels must equal the keys under which the thresholding cell stored its masks
# ('DLB' / 'DL'); the previous 'DenseTwB' / 'DenseT' labels made the lookup below raise
# a KeyError on a fresh run.
metrics_labels = {'st_avg':'SL', 'dt_avg':'DLB', 'dt2_avg':'DL'}
data_keys = ['Cora', 'Citeseer', 'Chameleon', 'Squirrel', 'Cornell', 'Texas', 'Wisconsin'] #list(DATASETS.keys())
model_specs = {'1L1H': A1L_1H, '1L2H': A1L_2H, '2L1H': A2L_1H, '2L2H': A2L_2H }

# results[spec][dataset][label] = {'P': precision, 'R': recall, 'F1': f1} of the
# thresholded attention, treated as a binary prediction of the flattened adjacency.
results = {}
for spec_key in model_specs:
    print(spec_key)
    results[spec_key] = {}
    for data_key in data_keys:
        results[spec_key][data_key] = {}
        # Ground truth depends only on the dataset: flatten it once.
        y_true = DATASETS[data_key].dense_adj.cpu().flatten().numpy()
        for model_key in model_keys:
            label = metrics_labels[model_key]
            attn = thresholded_attentions[spec_key][data_key][label]
            y_pred = (attn >= 1).flatten().numpy()
            # Simple similarity metrics between true adjacency matrix and learnt adjacency matrix (attn).
            # All three come from a single pass; previously each score was computed twice
            # (once for the table, once for printing).
            p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')
            print("F1 Score: {:.4f}".format(f1))
            print("Precision: {:.4f}".format(p))
            print("Recall: {:.4f}".format(r))
            results[spec_key][data_key][label] = {'P': p , 'R' : r, 'F1': f1}

In [None]:
def _metrics_frame(per_dataset):
    """Place the per-dataset metric tables side by side; dataset names form the outer column level."""
    frames = {}
    for dataset, per_model in per_dataset.items():
        frames[dataset] = pd.DataFrame(per_model)
    return pd.concat(frames, axis=1)


# Per-spec result dictionaries, in the order in which the specs are reported.
R_1L1H, R_1L2H, R_2L1H, R_2L2H = (results[spec] for spec in ('1L1H', '1L2H', '2L1H', '2L2H'))

R_1L1H_df = _metrics_frame(R_1L1H)
R_1L2H_df = _metrics_frame(R_1L2H)
R_2L1H_df = _metrics_frame(R_2L1H)
R_2L2H_df = _metrics_frame(R_2L2H)

In [None]:
# Convert to percentages first and round afterwards: multiplying an already rounded
# value by 100 re-introduces float noise (e.g. 0.07 * 100 == 7.000000000000001).
(pd.concat({'1L1H':R_1L1H_df, '1L2H':R_1L2H_df, '2L1H':R_2L1H_df, '2L2H':R_2L2H_df}) * 100).round(2)