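#### Message-Passing Neural Networks (MPNNs) are at most as powerful as the 1-WL test at distinguishing non-isomorphic graphs

Example: a 6-cycle $C_6$ and two disjoint triangles $2 \times C_3$ are non-isomorphic, but both are 2-regular. So 1-WL color refinement assigns them identical color histograms, and an MPNN started from identical node features produces identical graph embeddings for both.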

In [3]:
import networkx as nx
from collections import Counter

def wl_refinement(G, steps=5, palette=None):
    """Run 1-WL color refinement on ``G`` and return the final color histogram.

    Every node starts with color 0. In each round, a node's new color is
    determined by its current color together with the sorted multiset of its
    neighbors' current colors.

    Args:
        G: graph exposing ``nodes()`` and ``neighbors(v)`` (e.g. a networkx graph).
        steps: number of refinement rounds to run.
        palette: optional dict to share between calls, filled in place.
            It maps round index -> {signature: color id}.
            Color ids from two calls are only comparable when both calls use
            the same palette (and the same ``steps``). Without a palette,
            each call numbers its colors independently from 0. In that case,
            equal histograms do NOT show that 1-WL fails to tell the graphs
            apart.

    Returns:
        Counter mapping color id -> number of nodes with that color after
        ``steps`` rounds.
    """
    # initialize all colors to 0
    colors = {v: 0 for v in G.nodes()}

    for step in range(steps):
        # Signature = (own color, sorted neighbor colors). It is used directly
        # as a dict key: hash() values can collide, which would silently merge
        # two different color classes.
        signatures = {
            v: (colors[v], tuple(sorted(colors[u] for u in G.neighbors(v))))
            for v in G.nodes()
        }
        distinct = sorted(set(signatures.values()))

        if palette is None:
            # Per-call relabeling: rank the distinct signatures in sorted order,
            # giving deterministic, compact ids 0..k-1.
            table = {sig: i for i, sig in enumerate(distinct)}
        else:
            # Shared relabeling: a signature already seen in this round (by any
            # earlier call) keeps its id; unseen signatures get the next free id.
            table = palette.setdefault(step, {})
            for sig in distinct:
                table.setdefault(sig, len(table))

        colors = {v: table[signatures[v]] for v in G.nodes()}

    return Counter(colors.values())

# Two non-isomorphic 2-regular graphs on six nodes:
#   G1 -- a single 6-cycle (one connected component)
#   G2 -- two disjoint triangles (3-cycles)
G1, C3 = nx.cycle_graph(6), nx.cycle_graph(3)
G2 = nx.disjoint_union(C3, C3)

# Every node has degree 2 in both graphs, so refinement never splits the
# starting color class and the two histograms come out identical.
print("C6 colors:", wl_refinement(G1))
print("2xC3 colors:", wl_refinement(G2))

C6 colors: Counter({0: 6})
2xC3 colors: Counter({0: 6})


In [6]:
import torch
from torch import nn

class SimpleMPNN(nn.Module):
    """
    Dimension-preserving MPNN layer:
    - mean aggregation over incoming neighbor features
    - linear + ReLU update

    The update sees only the aggregated neighbor features; the node's own
    previous features are not fed back in (there is no self/root term).
    """
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply one round of message passing.

        Args:
            x: node features of shape [num_nodes, dim]. Row i belongs to node
                id i as used in ``edge_index``.
            edge_index: integer tensor of shape [2, num_edges]. Row 0 holds
                source ids and row 1 destination ids; each column sends
                x[src] to dst.

        Returns:
            relu(lin(mean of incoming neighbor features)) for every node, in
            x's shape and dtype. Nodes with no incoming edges aggregate to zeros.
        """
        src, dst = edge_index  # each [num_edges]

        # Sum the incoming messages per destination node.
        summed = torch.zeros_like(x)
        summed.index_add_(0, dst, x[src])

        # In-degree, cast to x's dtype: a float32 divisor would promote
        # float16/bfloat16 features to float32 and break self.lin.
        deg = torch.bincount(dst, minlength=x.size(0)).to(x.dtype)
        deg = deg.clamp(min=1).unsqueeze(-1)  # isolated nodes: avoid 0/0

        return torch.relu(self.lin(summed / deg))  # mean aggregation + update


def to_edge_index(G: "nx.Graph") -> torch.Tensor:
    """Convert an undirected graph into a [2, num_edges] directed edge-index tensor.

    Each undirected edge {u, v} becomes the two directed edges (u, v) and
    (v, u). A self-loop (v, v) is emitted once, so it carries a single message
    (v appears once in ``G.neighbors(v)`` as well).

    Node labels are used as-is as row indices. They are assumed to be
    integers 0..num_nodes-1 (as produced by ``nx.cycle_graph`` /
    ``nx.disjoint_union``); TODO confirm for other graphs.

    Returns:
        LongTensor of shape [2, num_edges]; shape is (2, 0) for an edgeless graph.
    """
    pairs = []
    for u, v in G.edges():
        pairs.append((u, v))
        if u != v:
            pairs.append((v, u))  # undirected -> both directions
    # reshape keeps the (2, 0) shape for edgeless graphs. A bare .t() on the
    # empty 1-D tensor would give shape (0,) and break `src, dst = edge_index`.
    return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t()


edge1 = to_edge_index(G1)
edge2 = to_edge_index(G2)

# ---------- Run the same MPNN on both graphs ----------

dim = 4
# nn.Linear initializes its weights randomly; seed so the printed embeddings
# are reproducible on Restart & Run All.
torch.manual_seed(0)
gnn = SimpleMPNN(dim)

# All nodes start with identical features, mirroring the all-0 initial
# coloring used by wl_refinement.
x1 = torch.ones((G1.number_of_nodes(), dim))
x2 = torch.ones((G2.number_of_nodes(), dim))

h1, h2 = x1, x2

# The same layer (shared weights) is applied num_layers times.
num_layers = 4
with torch.no_grad():  # inference only; no autograd graph is needed
    for _ in range(num_layers):
        h1 = gnn(h1, edge1)
        h2 = gnn(h2, edge2)

# graph-level embeddings (mean pooling over nodes)
g1_emb = h1.mean(dim=0)
g2_emb = h2.mean(dim=0)

print("Graph embedding C6:   ", g1_emb.detach().numpy())
print("Graph embedding 2xC3: ", g2_emb.detach().numpy())
print("Difference norm:", torch.norm(g1_emb - g2_emb).item())

# The claim being demonstrated: like 1-WL, this MPNN cannot tell C6 from 2xC3.
assert torch.allclose(g1_emb, g2_emb), "MPNN unexpectedly distinguished C6 from 2xC3"

Graph embedding C6:    [0.10244211 0.         0.         0.55158514]
Graph embedding 2xC3:  [0.10244211 0.         0.         0.55158514]
Difference norm: 0.0
