# Generating the graph isomorphism dataset

## Setup

In [None]:
from collections import Counter, defaultdict
from hashlib import blake2b

import networkx as nx
from networkx import weisfeiler_lehman_graph_hash, erdos_renyi_graph

## WL score

In [None]:
# Adapted from https://github.com/networkx/networkx/blob/main/networkx/algorithms/graph_hashing.py

def hash_label(label, digest_size):
    """
    Return the hexadecimal BLAKE2b digest of an ASCII label.

    `label` is encoded as ASCII before hashing, so a label containing
    non-ASCII characters raises `UnicodeEncodeError`. The returned hex
    string has `2 * digest_size` characters.
    """
    hasher = blake2b(digest_size=digest_size)
    hasher.update(label.encode("ascii"))
    return hasher.hexdigest()


def init_node_labels(G):
    """
    Build the initial label of every node: its degree, as a string.

    Returns a dictionary mapping each node yielded by `G.degree()` to
    `str(degree)`.
    """
    labels = {}
    for node, degree in G.degree():
        labels[node] = str(degree)
    return labels


def neighborhood_aggregate(G, node, node_labels, edge_attr=None):
    """
    Compute the new (unhashed) label of `node` by aggregating the
    labels of its neighbors.

    The result is the node's current label followed by the sorted
    concatenation of its neighbors' labels, so it does not depend on
    the order in which `G.neighbors(node)` yields them.

    If `edge_attr` is given, each neighbor label is prefixed with
    `str(G[node][nbr][edge_attr])` before sorting. With the default
    `edge_attr=None` no prefix is added.
    """
    label_list = []
    for nbr in G.neighbors(node):
        # Optional edge-attribute prefix; empty when edge_attr is not used.
        prefix = "" if edge_attr is None else str(G[node][nbr][edge_attr])
        label_list.append(prefix + node_labels[nbr])
    return node_labels[node] + "".join(sorted(label_list))


def weisfeiler_lehman_graph_score(
    G_1, G_2, max_iterations=5, digest_size=16
):
    """
    Return the first Weisfeiler-Lehman iteration at which the hashes
    of `G_1` and `G_2` differ.

    Both graphs are refined in lockstep, starting from degree-based
    node labels. After every refinement step the sorted
    (label, count) pairs are appended to a per-graph running list, and
    the hash of that cumulative list is compared between the graphs.

    Returns:
        0 if the graphs have a different number of nodes;
        `iteration + 1` (1-based) for the first iteration whose graph
        hashes differ;
        None if the hashes are equal for all `max_iterations` steps.
    """
    if G_1.number_of_nodes() != G_2.number_of_nodes():
        return 0

    def weisfeiler_lehman_step(G, labels):
        """
        Apply neighborhood aggregation to each node in the graph and
        return a dictionary with the new hashed label of each node.
        """
        return {
            node: hash_label(neighborhood_aggregate(G, node, labels), digest_size)
            for node in G.nodes()
        }

    graphs = (G_1, G_2)

    # set initial node labels
    node_labels = [init_node_labels(G) for G in graphs]

    # Cumulative (label, count) history per graph across iterations.
    subgraph_hash_counts = [[] for _ in graphs]
    for iteration in range(max_iterations):
        graph_hash = []
        for i, G in enumerate(graphs):
            node_labels[i] = weisfeiler_lehman_step(G, node_labels[i])
            counter = Counter(node_labels[i].values())
            # Labels are unique keys, so sorting the pairs orders by label.
            subgraph_hash_counts[i].extend(sorted(counter.items()))
            graph_hash.append(
                hash_label(str(tuple(subgraph_hash_counts[i])), digest_size)
            )
        if graph_hash[0] != graph_hash[1]:
            return iteration + 1

    return None
