# Comparision of Weisfeiler Lehman and SS-GNN 
## Observations:
* The SS-GNN cannot be more expressive than probabilistic SS-WL and SS-WL is more expressive than probabilistic SS-WL
* SS-GNN encodings for two graphs with similar WL-hash should be same. Atleast the distance between the encodings should be same.

## Experiment: 
* Extract the activations from ss_gnn activation.
* The `conv_i` contains the encoding for each node.
* After the pooling layer we get the encodings for each subgraph in the batch.
* The subgraphs with same wl encoding should have the ss-gnn encodings in an $\epsilon$-ball, for some $\epsilon>0$

In [1]:
import torch
import numpy as np
from torch_geometric.datasets import GNNBenchmarkDataset
from gps.utils.data_transform import SetNodeFeaturesOnes
from torch_geometric.loader import DataLoader

from uniform_sampler import sample_batch as sampler
import networkx as nx
import matplotlib.pyplot as plt

from gps.models.ss_gnn import SubgraphGNNEncoder, SubgraphSamplingGNNClassifier
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn.conv import GCNConv, GINConv
from torch_geometric.nn.pool import global_add_pool
import torch_geometric.nn as pyg

from torch_geometric.utils import scatter
import hashlib

In [2]:
class VerboseSubgraphEncoder(nn.Module):
    def __init__(self, in_dim, hidden_dim, num_layers, pooling='sum', residual=False,batch_norm=False):
        super().__init__()
        self.num_layers = num_layers
        self.residual = residual
        self.batch_norm = batch_norm
        self.proj = nn.Linear(in_dim,hidden_dim,bias=False)
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        for _ in range(num_layers):
            neural_net = self._make_neural_net(in_dim=hidden_dim, out_dim=hidden_dim)
            self.convs.append(GINConv(neural_net))
            if batch_norm:
                self.bns.append(pyg.BatchNorm(hidden_dim))
        if pooling == 'sum':
            self.pooling = pyg.global_add_pool
        self.activations = {}

    def forward(self, x, edge_index, batch):
        h = self.proj(x)
        self.activations['proj_0'] = h.clone()
        for l in range(self.num_layers):
            h_res = h
            h = self.convs[l](h,edge_index)
            self.activations[f'conv_{l}'] = h
            if self.batch_norm:
                h = self.bns[l](h)
                self.activations[f'bns_{l}'] = h
            if self.residual:
                h = h + h_res
                self.activations[f'res_{l}'] = h
        h_out = self.pooling(h,batch)
        self.activations['pooling'] = h_out
        return h_out
        
    def _make_neural_net(self, in_dim, out_dim, num_layers=2):
        layers = nn.ModuleList()
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.ReLU())
        for _ in range(num_layers-1):
            layers.append(nn.Linear(out_dim, out_dim))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)
        

In [3]:
class VerboseAttentionAggregator(nn.Module):
    def __init__(self, 
                 hidden_dim, 
                 temperature=0.2, 
                 pooling='sum'):
        super().__init__()
        self.temperature = temperature
        self.attention_mlp = nn.Sequential(
        nn.Linear(hidden_dim, hidden_dim//2),
        nn.ReLU(),
        nn.Linear(hidden_dim//2,1)
            )
        self.activations = {}
        
    def forward(self, subgraph_embeddings, batch):
        num_graphs = batch.max().item()+1
        scores = self.attention_mlp(subgraph_embeddings)
        scores = scores/self.temperature

        # find the max score along all subgraphs of each graph of the batch
        max_scores = scatter(scores, batch, dim=0, dim_size=num_graphs, reduce='max')
        # assign the corresponding max score to each subgraph of the batch
        max_scores = max_scores[batch]
        # e^{-(max_scores-scores)}
        scores_exp = torch.exp(scores-max_scores)
        # sum exp scores along each graph
        scores_sum = scatter(scores_exp, batch, dim=0, dim_size=num_graphs, reduce='sum')
        # boadcast socres_sums along each subgraph
        scores_sum = scores_sum[batch]
        # normalizing the scores_exp
        attention_weights = scores_exp / (scores_sum + 1e-8)
        
        '''scaling subgraph embeddings with sttention_weights
            attention_weights should prioritize significant subgraphs
            and scalling with attention weight should enhance (increase
            in magnitude) the important subgraph embeddings. 
        '''
        self.activations['attention_weights'] = attention_weights
        weighted_embeddings = attention_weights * subgraph_embeddings
        # finally, weighted sum for all subgraphs of each graph
        graphs_embeddings = scatter(weighted_embeddings, batch, dim=0, dim_size=num_graphs, reduce='sum')
        return graphs_embeddings

In [4]:
# function to create ss-gnn input from dataset batch
def build_input(batch, k, m):
      num_subgraphs = (batch.batch.max()+1)*m
      nodes_sampled, edge_index_sampled, edge_ptr, sample_ptr, edge_src_global = \
          sampler(batch.edge_index, batch.ptr, m_per_graph=m, k=k)
      x_global = batch.x[nodes_sampled.flatten()]
      edge_index_global = torch.repeat_interleave(torch.arange(0,num_subgraphs), edge_ptr[1:]-edge_ptr[:-1])*k + edge_index_sampled  
      sample_id = torch.repeat_interleave(torch.arange(0,num_subgraphs), k)
      return x_global, edge_index_global, sample_id

In [5]:
data_dim = 5
hidden_dim = 64
k = 6
m = 100

# data
transform = SetNodeFeaturesOnes(dim=data_dim)
data = GNNBenchmarkDataset("./.temp/CSL", name='CSL', transform=transform)
loader = DataLoader(data,batch_size=15, shuffle=False)

# model
verbose_encoder = VerboseSubgraphEncoder(in_dim=5,hidden_dim=10, num_layers=4)
verbose_attention = VerboseAttentionAggregator(hidden_dim=10)
for i, batch in enumerate(loader):
    y = batch.y
    num_graphs = batch.batch.max()+1
    
    x_global, edge_index_global, sample_id = build_input(batch, k, m)
    
    subgraph_encodings = verbose_encoder(x_global, edge_index_global, sample_id)
    
    graph_id = torch.repeat_interleave(torch.arange(num_graphs),m)
    graph_encodings = verbose_attention(subgraph_encodings,graph_id)
    if i == 0:
        break

enc_activations = verbose_encoder.activations
agg_activations = verbose_attention.activations
subgraph_enc = subgraph_encodings.detach()
graph_enc = graph_encodings.detach()


In [6]:
def wl_graph_hash_batch(edge_index, batch, node_features=None, num_iterations=3):
    """
    Compute WL hash for batched graphs.
    
    Args:
        edge_index: (2, num_edges) - edge indices
        batch: (num_nodes,) - batch assignment for each node
        node_features: (num_nodes, feature_dim) or None
        num_iterations: number of WL iterations
    
    Returns:
        hashes: list of hash strings, one per graph
    """
    num_nodes = batch.size(0)
    num_graphs = batch.max().item() + 1
    
    # Initialize labels
    if node_features is None:
        labels = torch.zeros(num_nodes, dtype=torch.long)
    else:
        labels = torch.tensor([hash(tuple(f.tolist())) % (2**31) 
                               for f in node_features], dtype=torch.long)
    
    # WL iterations
    for _ in range(num_iterations):
        new_labels = torch.zeros(num_nodes, dtype=torch.long)
        
        # Build adjacency for efficient neighbor lookup
        adj_dict = {i: [] for i in range(num_nodes)}
        for src, tgt in edge_index.t().tolist():
            adj_dict[src].append(tgt)
        
        for node in range(num_nodes):
            neighbors = adj_dict[node]
            if neighbors:
                neighbor_labels = labels[neighbors].sort()[0].tolist()
            else:
                neighbor_labels = []
            
            signature = str([labels[node].item()] + neighbor_labels)
            new_labels[node] = int(hashlib.md5(signature.encode()).hexdigest(), 16) % (2**31)
        
        labels = new_labels
    
    # Aggregate to graph-level hashes
    graph_hashes = []
    for g in range(num_graphs):
        mask = batch == g
        graph_labels = sorted(labels[mask].tolist())
        graph_sig = str(graph_labels)
        graph_hash = hashlib.md5(graph_sig.encode()).hexdigest()
        graph_hashes.append(graph_hash)
    
    return graph_hashes


In [7]:
def unique_with_epsilon(tensor, epsilon=1e-6, return_counts=False):
    """
    Find unique values in a tensor where values within epsilon are considered the same.
    
    Args:
        tensor: 1D PyTorch tensor
        epsilon: tolerance for considering values as equal
        return_counts: if True, also return counts for each unique value
    
    Returns:
        unique_values: Tensor of unique values
        counts (optional): Tensor of counts for each unique value
    """
    if len(tensor) == 0:
        if return_counts:
            return tensor, torch.tensor([], dtype=torch.long)
        return tensor
    
    # Sort the tensor
    sorted_tensor, _ = torch.sort(tensor)
    
    # Find where consecutive differences are greater than epsilon
    diffs = torch.diff(sorted_tensor)
    mask = torch.cat([torch.tensor([True]), diffs > epsilon])
    
    # Select unique values
    unique_values = sorted_tensor[mask]
    
    if return_counts:
        # Find cluster boundaries
        boundaries = torch.cat([torch.tensor([True]), diffs > epsilon, torch.tensor([True])])
        boundary_indices = torch.where(boundaries)[0]
        
        # Compute counts for each cluster
        counts = torch.diff(boundary_indices)
        
        return unique_values, counts
    
    return unique_values


In [1]:
hashs = wl_graph_hash_batch(edge_index_global, sample_id, x_global, num_iterations=1)
print(f'WL-Distinct subgraphs: {np.unique(hashs).__len__()}')
print(f"Number of subgraphss: {subgraph_enc.__len__()}")

subgraph_enc_dist = torch.cdist(subgraph_enc,subgraph_enc)
print("We need to see the unique encodings in list of subgraph encodings!")
print(f"Unique distances from the first subgraphs to other graph: {torch.unique(subgraph_enc_dist[0]).__len__()}")
print(f"Number of unique subgraph encodings with small numerical tollaracne: {unique_with_epsilon(subgraph_enc_dist[0], epsilon=1e-3).__len__()}")

NameError: name 'wl_graph_hash_batch' is not defined

*Now compare the ss-gnn encoder output with `hashs`. Observe the encoding distances between graphs with same WL hash.*
* If graphs withs same WL encoding has encoding distant high then something is wrong.

## Equivalance of PSS-WL and SS-GNN: 
**8 PSS-WL distinguishable subgraphs in sample of the first graph.**<br>
**Similarly, 8 SS-GNN distinguishable subgraph present (with neumerical tollarace $1e-3$).**<br>
Now we observe the whether frequences of each class to identify the whether the emperical distributions are also same.

In [10]:
unique_hash, hash_feq = np.unique(hashs,return_counts=True)
print(f"Hash frequencies in the first: {hash_feq}")
unique_subgraphs, subgraph_feq = unique_with_epsilon(subgraph_enc_dist[0], epsilon=1e-4, return_counts=True)
print(f"unique subgraph frequencies: {subgraph_feq}")

Hash frequencies in the first: [ 55  46 599  87 283 189 188  53]
unique subgraph frequencies: tensor([599, 283, 189,  53, 188,  55,  87,  46])


So the emperical distributions are also same. So, the model is **working upto SubgraphGNNEncoder.**

# After aggregation:
* After aggregation the graphs with same lebel(10-graphs) have slightly different encodigns.
* We need to find out whether the encodings of SS_GNN for other graphs with different labels are significantly different
* If they are not significantly different, the attention might learn to distinguish these graphs during training
  Now we study unique distances from first graph. As the class is same the distances should be small. 

In [13]:
torch.unique(torch.cdist(graph_enc, graph_enc)[0])

tensor([0.0000, 0.0057, 0.0077, 0.0094, 0.0096, 0.0132, 0.0145, 0.0146, 0.0265,
        0.0292, 0.0303, 0.0337, 0.0371, 0.0438, 0.0663])

So we see the distances are small enough. Now see how the graph embeddings for two graphs of different labels differ.

### for second batch

In [15]:
data_dim = 5
hidden_dim = 64
k = 6
m = 100

# data
transform = SetNodeFeaturesOnes(dim=data_dim)
data = GNNBenchmarkDataset("./.temp/CSL", name='CSL', transform=transform)
loader = DataLoader(data,batch_size=15, shuffle=False)

# model
verbose_encoder = VerboseSubgraphEncoder(in_dim=5,hidden_dim=10, num_layers=4)
verbose_attention = VerboseAttentionAggregator(hidden_dim=10)
for i, batch in enumerate(loader):
    y = batch.y
    num_graphs = batch.batch.max()+1
    
    x_global, edge_index_global, sample_id = build_input(batch, k, m)
    
    subgraph_encodings = verbose_encoder(x_global, edge_index_global, sample_id)
    
    graph_id = torch.repeat_interleave(torch.arange(num_graphs),m)
    graph_encodings = verbose_attention(subgraph_encodings,graph_id)
    if i == 1:
        break

enc_activations = verbose_encoder.activations
agg_activations = verbose_attention.activations
subgraph_enc = subgraph_encodings.detach()
graph_enc_1 = graph_encodings.detach()

In [24]:
torch.cdist(graph_enc, graph_enc_1)

tensor([[3.2890, 3.2918, 3.2909, 3.2927, 3.2905, 3.2906, 3.2906, 3.2917, 3.2920,
         3.2884, 3.2903, 3.2918, 3.2917, 3.2926, 3.2914],
        [3.3335, 3.3364, 3.3355, 3.3373, 3.3350, 3.3351, 3.3351, 3.3362, 3.3366,
         3.3329, 3.3349, 3.3364, 3.3362, 3.3372, 3.3360],
        [3.3223, 3.3251, 3.3242, 3.3260, 3.3238, 3.3239, 3.3238, 3.3250, 3.3253,
         3.3216, 3.3236, 3.3251, 3.3249, 3.3259, 3.3247],
        [3.3199, 3.3228, 3.3218, 3.3237, 3.3214, 3.3215, 3.3215, 3.3226, 3.3229,
         3.3193, 3.3212, 3.3227, 3.3226, 3.3236, 3.3223],
        [3.2937, 3.2966, 3.2956, 3.2975, 3.2952, 3.2953, 3.2953, 3.2964, 3.2967,
         3.2931, 3.2950, 3.2965, 3.2964, 3.2973, 3.2961],
        [3.3029, 3.3058, 3.3048, 3.3066, 3.3044, 3.3045, 3.3045, 3.3056, 3.3059,
         3.3023, 3.3042, 3.3057, 3.3056, 3.3065, 3.3053],
        [3.3037, 3.3065, 3.3056, 3.3074, 3.3052, 3.3053, 3.3052, 3.3064, 3.3067,
         3.3031, 3.3050, 3.3065, 3.3063, 3.3073, 3.3061],
        [3.3027, 3.3056, 3.

## Observations:
* The graph embeddings for graphs of label 0 and label 1 have significant difference
* **Thus SS-GNN can distinguish distinct CSL graphs** Success!!!ðŸ˜ŽðŸ˜ŽðŸ˜Ž