# Generating the graph isomorphism dataset

## Setup

In [12]:
# Run on the CPU even when a CUDA device is available
FORCE_CPU = True

# Whether to write / read the cached generation-test results
SAVE_DATA = False
LOAD_DATA = True

# Number of graph pairs generated and scored per batch
BATCH_SIZE = 1_000_000

# Pickle file holding the cached generation-test results
GENERATION_TEST_DATA_FILE = "data/generation_test_data.pkl"

DATASET_CONFIG = {
    # How many samples the dataset contains
    "num_samples": 10000,
    # Graph sizes to use; the samples are divided evenly between them
    "graph_sizes": [7, 8, 9, 10, 11],
    # Edge probabilities used when generating the graphs
    "edge_probabilities": [0.2, 0.4, 0.6, 0.8],
    # Fraction of the samples that are non-isomorphic pairs
    "prop_non_isomorphic": 0.5,
    # Settings for the non-isomorphic pairs
    "non_isomorphic": {
        # Fractions of non-isomorphic pairs with score 1 and score 2; the
        # remaining pairs have a score greater than 2
        "prop_score_1": 0.1,
        "prop_score_2": 0.2,
    },
    # Settings for the isomorphic pairs
    "isomorphic": {
        # Fraction of isomorphic pairs sampled from a non-isomorphic pair
        "prop_from_non_isomorphic": 0.5,
    },
}

In [13]:
from collections import Counter, defaultdict
from hashlib import blake2b
import pickle

import networkx as nx
from networkx import weisfeiler_lehman_graph_hash, erdos_renyi_graph

import torch

import einops

from sklearn.model_selection import ParameterGrid

from tqdm import tqdm

from primesieve.numpy import n_primes

from rich.console import Console
from rich.table import Table

import plotly.graph_objs as go

In [14]:
# Use CUDA only when it is both permitted (FORCE_CPU is off) and available
use_cuda = (not FORCE_CPU) and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cpu


## WL score

In [15]:
def batch_wl_score(
    adjacency_1, adjacency_2, max_iterations=5, hash_size=2**24 - 1, device=device
):
    """
    Compute a Weisfeiler-Lehman-style score for a batch of graph pairs.

    Both graphs of every pair are refined in lockstep. In each iteration a node's
    new label is a hash of the labels selected by its row of the adjacency matrix
    (with the node itself included), and a hash of the whole graph is formed from
    the node labels. The score of a pair is the 1-based index of the first
    iteration at which the two graph hashes differ.

    Parameters
    ----------
    adjacency_1, adjacency_2
        Batched adjacency matrices indexed as (batch, node, node). Their first
        two dimensions must agree. NOTE(review): presumably 0/1 integer tensors
        with a zero diagonal that already live on `device` -- confirm against
        the callers.
    max_iterations
        Number of refinement iterations to run.
    hash_size
        Number of primes in the lookup table, and the modulus applied to the
        node labels after each iteration.
    device
        Device on which the tensors created here are placed. The default is the
        module-level `device` as it was when this function was defined.

    Returns
    -------
    A long tensor with one entry per pair: the first iteration (counting from 1)
    at which the graph hashes differed, or -1 if they never differed within
    `max_iterations` iterations.
    """

    assert adjacency_1.shape[0] == adjacency_2.shape[0]
    assert adjacency_1.shape[1] == adjacency_2.shape[1]

    batch_size = adjacency_1.shape[0]
    num_nodes = adjacency_1.shape[1]

    # Lookup table from label to prime, rebuilt on every call.
    # NOTE(review): presumably the first `hash_size` primes -- confirm against
    # the primesieve documentation.
    primes = torch.from_numpy(n_primes(hash_size)).to(device)

    # -1 marks pairs whose hashes have not differed yet
    scores = torch.ones(batch_size, dtype=torch.long, device=device) * -1
    # Every node of both graphs starts with the same label, 1
    labels = torch.ones((2, batch_size, num_nodes), dtype=torch.long, device=device)

    # (graph, batch, node, node)
    adjacency_combined = torch.stack((adjacency_1, adjacency_2), dim=0)
    # Add the identity (broadcast over graph and batch) so that each node's own
    # label takes part in its update. This acts on the stacked tensor.
    adjacency_combined += torch.eye(num_nodes, dtype=torch.long, device=device)

    for i in range(max_iterations):
        # labels_repeated[g, b, i, j] == labels[g, b, j]
        labels_repeated = einops.repeat(
            labels, "graph batch node1 -> graph batch node2 node1", node2=num_nodes
        )
        # Entry (i, j) keeps j's label where the adjacency entry is 1 and becomes
        # 0 where it is 0; index 0 then selects primes[0] in the lookup below.
        labels_neighbours = labels_repeated * adjacency_combined
        labels_neighbours = primes[labels_neighbours]
        # Multiply along the last axis, i.e. over the entries of each node's row.
        # The axis names in this pattern are swapped relative to the repeat
        # pattern above; only the axis positions matter.
        # NOTE(review): the product of many primes can exceed the integer range;
        # presumably wrap-around is acceptable because only equality of hashes
        # is used -- confirm.
        labels_neighbours = einops.reduce(
            labels_neighbours, "graph batch node1 node2 -> graph batch node1", "prod"
        )
        # Reduce modulo hash_size so the new labels can index `primes` again
        labels = torch.remainder(labels_neighbours, hash_size)
        # Graph hash: product of the primes selected by the node labels, which
        # does not depend on the order of the nodes
        graph_hashes = einops.reduce(
            primes[labels], "graph batch node -> graph batch", "prod"
        )
        diff = graph_hashes[0] != graph_hashes[1]
        # Record the iteration only for pairs that have no score yet
        scores = torch.where(torch.logical_and(scores == -1, diff), i + 1, scores)

    return scores

## Graph generation

In [16]:
def batch_er_graph(num_samples, graph_size, edge_probability, device=device):
    """
    Sample a batch of random undirected graphs as adjacency matrices.

    One uniform random number is drawn per ordered node pair; only the draws
    strictly above the diagonal are used, and an edge is present when its draw
    is below `edge_probability`. The result is a symmetric 0/1 integer tensor
    of shape (num_samples, graph_size, graph_size) with a zero diagonal.
    """
    uniform_draws = torch.rand(num_samples, graph_size, graph_size, device=device)
    # Keep only the strict upper triangle, so there are no self-loops
    upper = torch.triu((uniform_draws < edge_probability).int(), diagonal=1)
    # Mirror the upper triangle into the lower one; the transpose is copied first
    # because it shares memory with the tensor being updated in place
    mirrored = upper.transpose(1, 2).clone()
    upper += mirrored
    return upper

## Testing generation

In [17]:
# Number of random graph pairs drawn for every parameter combination
num_pairs = 10_000_000

# Every combination of edge probability and graph order is tested
parameter_grid = dict(
    edge_prob=[0.2, 0.4, 0.6, 0.8],
    graph_order=list(range(4, 16)),
)
parameter_iter = ParameterGrid(parameter_grid)

In [18]:
# Either regenerate the score counts for every parameter combination, or reuse
# the counts cached on disk.
if not LOAD_DATA:
    results = {}

    for params in parameter_iter:
        order, prob = params["graph_order"], params["edge_prob"]
        params_tuple = (prob, order)
        counts = {"score_2_count": 0, "score_gt_2_count": 0}
        results[params_tuple] = counts

        progress = tqdm(range(0, num_pairs, BATCH_SIZE), desc=f"{params_tuple}")
        for batch_start in progress:
            # The last batch is shortened so that exactly num_pairs pairs are drawn
            batch_size = min(BATCH_SIZE, num_pairs - batch_start)
            adjacency_1 = batch_er_graph(batch_size, order, prob)
            adjacency_2 = batch_er_graph(batch_size, order, prob)
            score = batch_wl_score(adjacency_1, adjacency_2)
            counts["score_2_count"] += (score == 2).sum().item()
            counts["score_gt_2_count"] += (score > 2).sum().item()

        print(counts)
else:
    # NOTE: unpickling can execute arbitrary code; only load trusted files
    with open(GENERATION_TEST_DATA_FILE, "rb") as f:
        results = pickle.load(f)

In [19]:
# Cache the generation-test results on disk so later runs can set LOAD_DATA
if SAVE_DATA:
    from pathlib import Path

    # Create the target directory first: without it, open() raises
    # FileNotFoundError and the results of the long generation run are lost.
    output_path = Path(GENERATION_TEST_DATA_FILE)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as f:
        pickle.dump(results, f)

In [20]:
# One table per count type: rows are edge probabilities, columns are graph orders
table_2 = Table(title="Number of score 2 pairs")
table_gt_2 = Table(title="Number of score > 2 pairs")
count_tables = ((table_2, "score_2_count"), (table_gt_2, "score_gt_2_count"))

for table, count_key in count_tables:
    table.add_column("Edge prob", justify="right")
    for graph_order in parameter_grid["graph_order"]:
        table.add_column(f"Order {graph_order}", justify="right")

    for edge_prob in parameter_grid["edge_prob"]:
        row_cells = [
            str(results[(edge_prob, graph_order)][count_key])
            for graph_order in parameter_grid["graph_order"]
        ]
        table.add_row(str(edge_prob), *row_cells)

console = Console()
for table, _ in count_tables:
    console.print(table)

In [21]:
# Arrange the score-2 counts as a matrix: one row per edge probability and one
# column per graph order
edge_probs = parameter_grid["edge_prob"]
graph_orders = parameter_grid["graph_order"]
score_2_counts = [
    [results[(prob, order)]["score_2_count"] for order in graph_orders]
    for prob in edge_probs
]

# Axis titles and plot title
layout = go.Layout(
    title="Score 2 counts",
    xaxis={"title": "Graph order"},
    yaxis={"title": "Edge probability"},
)

# Heatmap of the counts over the parameter grid
heatmap = go.Heatmap(
    x=graph_orders,
    y=edge_probs,
    z=score_2_counts,
    colorscale="Viridis",
)

# Assemble and display the figure
fig = go.Figure(data=[heatmap], layout=layout)
fig.show()

In [22]:
# Arrange the score > 2 counts as a matrix: one row per edge probability and one
# column per graph order
edge_probs = parameter_grid["edge_prob"]
graph_orders = parameter_grid["graph_order"]
score_gt_2_counts = [
    [results[(prob, order)]["score_gt_2_count"] for order in graph_orders]
    for prob in edge_probs
]

# Axis titles and plot title
layout = go.Layout(
    title="Score > 2 counts",
    xaxis={"title": "Graph order"},
    yaxis={"title": "Edge probability"},
)

# Heatmap of the counts over the parameter grid
heatmap = go.Heatmap(
    x=graph_orders,
    y=edge_probs,
    z=score_gt_2_counts,
    colorscale="Viridis",
)

# Assemble and display the figure
fig = go.Figure(data=[heatmap], layout=layout)
fig.show()