# Testing graph isomorphism agents

## Setup

In [65]:
# When True the device-selection cell picks the CPU even if CUDA is available.
FORCE_CPU = True

# Seed for torch's global RNG, the dataloader generator and `Parameters(seed=...)`.
SEED = 349287

# Passed to the verifier's agent parameters as `load_checkpoint_and_parameters`,
# `checkpoint_run_id` and `checkpoint_version` respectively.
LOAD_CHECKPOINT = True
CHECKPOINT_RUN_ID = "ppo_gi_next_1_0"
CHECKPOINT_VERSION = "v21"

# Batch size of the dataloader; also used to compute row offsets into `stats`.
BATCH_SIZE = 128

In [66]:
import torch
from torch import Tensor
from torch import nn

import numpy as np

from tqdm import tqdm

import pandas as pd

import plotly.graph_objects as go
import plotly.express as px

from rich.console import Console
from rich.table import Table

from pvg import (
    Parameters,
    ScenarioType,
    TrainerType,
    AgentsParameters,
    GraphIsomorphismAgentParameters,
)
from pvg.experiment_settings import ExperimentSettings
from pvg.graph_isomorphism import (
    GraphIsomorphismAgentHooks,
)
from pvg.scenario_instance import build_scenario_instance
from pvg.scenario_base import DataLoader

In [67]:
# Seed torch's global RNG and create a separately seeded generator, which is later
# handed to the DataLoader.
torch.manual_seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [68]:
# Use CUDA only when it is both permitted (FORCE_CPU is False) and available.
use_cuda = not FORCE_CPU and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cpu


## Create scenario and agents

In [69]:
# Verifier agent parameters, pointing at the checkpoint chosen in the setup cell.
verifier_parameters = GraphIsomorphismAgentParameters(
    load_checkpoint_and_parameters=LOAD_CHECKPOINT,
    checkpoint_run_id=CHECKPOINT_RUN_ID,
    checkpoint_version=CHECKPOINT_VERSION,
)

# Experiment parameters: graph isomorphism scenario, solo-agent trainer, "eru10000" dataset.
params = Parameters(
    scenario=ScenarioType.GRAPH_ISOMORPHISM,
    trainer=TrainerType.SOLO_AGENT,
    dataset="eru10000",
    seed=SEED,
    agents=AgentsParameters(verifier=verifier_parameters),
)

In [70]:
# Experiment settings; only the compute device is set here.
settings = ExperimentSettings(
    device=device,
)

In [71]:
# Build the scenario instance from the parameters and settings. Later cells read its
# `agents` and `train_dataset` attributes. The recorded output of this cell shows a
# wandb login and file download taking place during this call.
scenario_instance = build_scenario_instance(
    params=params,
    settings=settings,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msamadamday[0m ([33mlrhammond-team[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   8 of 8 files downloaded.  


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
mean_episode_length,1.0
prover.mean_reward,0.201
prover.pretrain_test_accuracy,0.72352
prover.pretrain_test_loss,0.48423
verifier.mean_reward,0.136
verifier.pretrain_test_accuracy,0.62335
verifier.pretrain_test_loss,0.60665


In [72]:
# The verifier agent's body; called on every batch together with the recorder hooks.
verifier_body = scenario_instance.agents["verifier"].body

In [73]:
# All statistics in this notebook are computed on the scenario's training dataset.
dataset = scenario_instance.train_dataset

In [74]:
# Show the dataset's fields and their shapes.
dataset

GraphIsomorphismDataset(
    fields={
        adjacency: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=True),
        node_mask: MemoryMappedTensor(shape=torch.Size([10000, 2, 11]), device=cpu, dtype=torch.bool, is_shared=True),
        wl_score: MemoryMappedTensor(shape=torch.Size([10000]), device=cpu, dtype=torch.int32, is_shared=True),
        x: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 8]), device=cpu, dtype=torch.float32, is_shared=True),
        y: MemoryMappedTensor(shape=torch.Size([10000]), device=cpu, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([10000]),
    device=None,
    is_shared=False)

In [75]:
# Size of the second-to-last axis of the "x" field (11 in the recorded output).
# Used by the collectors to split concatenated sequences into two halves.
# NOTE(review): presumably this axis is the padded node axis -- confirm.
max_num_nodes = dataset["x"].shape[-2]
max_num_nodes

11

## Looking at intermediate computations

In [76]:
# Batches of BATCH_SIZE examples from the training dataset, using the seeded generator.
# NOTE(review): the collection loop writes batch i into rows
# [i * BATCH_SIZE, (i + 1) * BATCH_SIZE) of `stats`, and `stats` is later masked with
# dataset["wl_score"]. That is only valid if batches arrive in dataset order --
# confirm that this DataLoader does not shuffle.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    generator=torch_generator,
)

In [77]:
# Maps a statistic name to a callable that takes the hook `storage` dict and returns
# one value per example in the current batch.
statistic_collectors = {}

In [78]:
def _gnn_output_statistic(storage):
    """Reduce the recorded ``gnn_output`` to one value per example.

    Sums the ``[:, 0]`` and ``[:, 1]`` slices over their dim 1, subtracts the second
    from the first, and averages the difference over dim 1.

    NOTE(review): presumably index 0 / 1 select the two graphs of a pair and the
    summed axis is the node axis -- confirm against the recorder hooks.
    """
    recorded = storage["gnn_output"]
    summed_first = recorded[:, 0].sum(dim=1)
    summed_second = recorded[:, 1].sum(dim=1)
    return (summed_first - summed_second).mean(dim=1)


statistic_collectors["gnn_output"] = _gnn_output_statistic

In [79]:
def _gnn_output_rounded_statistic(storage):
    """Reduce the recorded ``gnn_output_rounded`` to one value per example.

    Sums the ``[:, 0]`` and ``[:, 1]`` slices over their dim 1, subtracts the second
    from the first, and averages the difference over dim 1.

    NOTE(review): presumably index 0 / 1 select the two graphs of a pair -- confirm.
    """
    recorded = storage["gnn_output_rounded"]
    summed_first = recorded[:, 0].sum(dim=1)
    summed_second = recorded[:, 1].sum(dim=1)
    return (summed_first - summed_second).mean(dim=1)


statistic_collectors["gnn_output_rounded"] = _gnn_output_rounded_statistic

In [80]:
def _gnn_output_flatter_statistic(storage):
    """Reduce the recorded ``gnn_output_flatter`` to one value per example.

    Splits dim 1 at ``max_num_nodes``, sums each part over dim 1, subtracts the
    second sum from the first, and averages the difference over dim 1.

    NOTE(review): presumably the two parts hold the nodes of the first and second
    graph respectively -- confirm.
    """
    recorded = storage["gnn_output_flatter"]
    summed_head = recorded[:, :max_num_nodes].sum(dim=1)
    summed_tail = recorded[:, max_num_nodes:].sum(dim=1)
    return (summed_head - summed_tail).mean(dim=1)


statistic_collectors["gnn_output_flatter"] = _gnn_output_flatter_statistic

In [81]:
def _transformer_input_initial_statistic(storage):
    """Reduce the recorded ``transformer_input_initial`` to one value per example.

    Skips the first two positions of dim 1, splits the rest at ``max_num_nodes + 2``,
    sums each part over dim 1, subtracts the second sum from the first, and averages
    the difference over dim 1.

    NOTE(review): presumably the two skipped positions are non-node tokens and the
    two parts are the two graphs' nodes -- confirm.
    """
    recorded = storage["transformer_input_initial"]
    summed_head = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    summed_tail = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (summed_head - summed_tail).mean(dim=1)


statistic_collectors["transformer_input_initial"] = _transformer_input_initial_statistic

In [82]:
def _transformer_input_pre_encoder_statistic(storage):
    """Reduce the recorded ``transformer_input_pre_encoder`` to one value per example.

    Skips the first two positions of dim 1, splits the rest at ``max_num_nodes + 2``,
    sums each part over dim 1, subtracts the second sum from the first, and averages
    the difference over dim 1.

    NOTE(review): presumably the two skipped positions are non-node tokens -- confirm.
    """
    recorded = storage["transformer_input_pre_encoder"]
    summed_head = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    summed_tail = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (summed_head - summed_tail).mean(dim=1)


statistic_collectors["transformer_input_pre_encoder"] = (
    _transformer_input_pre_encoder_statistic
)

In [83]:
def _transformer_input_statistic(storage):
    """Reduce the recorded ``transformer_input`` to one value per example.

    Skips the first two positions of dim 1, splits the rest at ``max_num_nodes + 2``,
    sums each part over dim 1, subtracts the second sum from the first, and averages
    the difference over dim 1.

    NOTE(review): presumably the two skipped positions are non-node tokens -- confirm.
    """
    recorded = storage["transformer_input"]
    summed_head = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    summed_tail = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (summed_head - summed_tail).mean(dim=1)


statistic_collectors["transformer_input"] = _transformer_input_statistic

In [84]:
def _transformer_output_flatter_statistic(storage):
    """Reduce the recorded ``transformer_output_flatter`` to one value per example.

    Skips the first two positions of dim 1, splits the rest at ``max_num_nodes + 2``,
    sums each part over dim 1, subtracts the second sum from the first, and averages
    the difference over dim 1.

    NOTE(review): presumably the two skipped positions are non-node tokens -- confirm.
    """
    recorded = storage["transformer_output_flatter"]
    summed_head = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    summed_tail = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (summed_head - summed_tail).mean(dim=1)


statistic_collectors["transformer_output_flatter"] = _transformer_output_flatter_statistic

In [85]:
def _node_level_repr_pre_encoder_statistic(storage):
    """Reduce the recorded ``node_level_repr_pre_encoder`` to one value per example.

    Sums the ``[:, 0]`` and ``[:, 1]`` slices over their dim 1, subtracts the second
    from the first, and averages the difference over dim 1.

    NOTE(review): presumably index 0 / 1 select the two graphs of a pair -- confirm.
    """
    recorded = storage["node_level_repr_pre_encoder"]
    summed_first = recorded[:, 0].sum(dim=1)
    summed_second = recorded[:, 1].sum(dim=1)
    return (summed_first - summed_second).mean(dim=1)


statistic_collectors["node_level_repr_pre_encoder"] = _node_level_repr_pre_encoder_statistic

In [None]:
def _node_level_repr_statistic(storage):
    """Reduce the recorded ``node_level_repr`` to one value per example.

    Sums the ``[:, 0]`` and ``[:, 1]`` slices over their dim 1, subtracts the second
    from the first, and averages the difference over dim 1.

    NOTE(review): presumably index 0 / 1 select the two graphs of a pair -- confirm.
    """
    recorded = storage["node_level_repr"]
    summed_first = recorded[:, 0].sum(dim=1)
    summed_second = recorded[:, 1].sum(dim=1)
    return (summed_first - summed_second).mean(dim=1)


statistic_collectors["node_level_repr"] = _node_level_repr_statistic

In [86]:
# One float32 buffer per statistic, with one slot per dataset example.
# Buffers start as NaN rather than uninitialised memory, so any row the collection
# loop fails to write shows up as NaN in the per-score means instead of silently
# contributing garbage.
stats = {
    key: np.full(len(dataset), np.nan, dtype=np.float32)
    for key in statistic_collectors.keys()
}

In [87]:
# `storage` is handed to the recorder hooks; after each call to the verifier body the
# statistic collectors read entries such as storage["gnn_output"] from it.
# NOTE(review): presumably the hooks overwrite these entries on every forward pass -- confirm.
storage = {}
hooks = GraphIsomorphismAgentHooks.create_recorder_hooks(storage)

In [88]:
# Run the verifier body over every batch and store one value per example for each
# statistic. Only detached values are kept, so the forward passes run under
# `torch.no_grad()` to avoid building an autograd graph.
# Batch i is written to rows [i * BATCH_SIZE, (i + 1) * BATCH_SIZE) of each buffer,
# which assumes the dataloader yields examples in dataset order.
with torch.no_grad():
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Fill the message fields with a zero message and an all-True ignore flag.
        # NOTE(review): presumably this makes the body disregard the message -- confirm.
        batch["message"] = torch.zeros_like(batch["y"])
        batch["ignore_message"] = torch.ones_like(batch["y"], dtype=torch.bool)
        # The hooks record intermediate values into `storage` during this call.
        verifier_body(batch, hooks)
        for key, collector in statistic_collectors.items():
            stats[key][i * BATCH_SIZE : (i + 1) * BATCH_SIZE] = (
                collector(storage).detach().cpu().numpy()
            )

100%|██████████| 79/79 [00:00<00:00, 108.01it/s]


In [89]:
# Largest wl_score in the training dataset; upper bound of the per-score loops.
max_wl_score = torch.max(dataset["wl_score"]).item()

In [90]:
# Group every statistic by wl_score (from -1 up to max_wl_score inclusive) and keep
# the mean, standard deviation and raw values for each group.
means = {}
stds = {}
raw_data = {}
# Look the score column up once rather than on every iteration.
wl_score_values = dataset["wl_score"]
for wl_score in range(-1, max_wl_score + 1):
    # Build the mask once; it doubles as the "is this score present?" test.
    mask = (wl_score_values == wl_score).numpy()
    if not mask.any():
        # Score absent from the dataset: record NaN / empty placeholders so that
        # every score in the range still has an entry.
        means[wl_score] = {key: np.nan for key in stats.keys()}
        stds[wl_score] = {key: np.nan for key in stats.keys()}
        raw_data[wl_score] = {key: np.empty(0, dtype=np.float32) for key in stats.keys()}
        continue
    means[wl_score] = {key: stats[key][mask].mean() for key in stats.keys()}
    stds[wl_score] = {key: stats[key][mask].std() for key in stats.keys()}
    raw_data[wl_score] = {key: stats[key][mask] for key in stats.keys()}

In [91]:
# Print a table with one row per wl_score and one column per statistic, showing the
# group mean to four decimal places.
console = Console()

table = Table()
table.add_column("Score")
for stat_name in stats.keys():
    table.add_column(stat_name)

# Distinct names for the score and the statistic avoid the shadowing of `key`;
# the standard deviations are not shown, so they are not iterated.
for score, score_means in means.items():
    table.add_row(
        str(score),
        *[f"{score_means[stat_name]:.4f}" for stat_name in stats.keys()],
    )

console.print(table)


In [92]:
# Mean of the gnn_output statistic per wl_score, with the standard deviation as error bars.
x = list(range(-1, max_wl_score + 1))
gnn_output_mean = np.array([means[wl_score]["gnn_output"] for wl_score in x])
gnn_output_std = np.array([stds[wl_score]["gnn_output"] for wl_score in x])

error_bars = dict(type="data", array=gnn_output_std)
fig = go.Figure(data=go.Scatter(x=x, y=gnn_output_mean, error_y=error_bars))
fig.update_layout(
    title="gnn_output_mean",
    xaxis_title="wl_score",
    yaxis_title="Mean",
)
fig.show()

In [93]:
# Mean of the gnn_output_rounded statistic per wl_score, with the standard deviation
# as error bars.
x = list(range(-1, max_wl_score + 1))
gnn_output_rounded_mean = np.array([means[wl_score]["gnn_output_rounded"] for wl_score in x])
gnn_output_rounded_std = np.array([stds[wl_score]["gnn_output_rounded"] for wl_score in x])

error_bars = dict(type="data", array=gnn_output_rounded_std)
fig = go.Figure(data=go.Scatter(x=x, y=gnn_output_rounded_mean, error_y=error_bars))
fig.update_layout(
    title="gnn_output_rounded_mean",
    xaxis_title="wl_score",
    yaxis_title="Mean",
)
fig.show()

In [94]:
# Mean of the node_level_repr_pre_encoder statistic per wl_score, with the standard
# deviation as error bars.
x = list(range(-1, max_wl_score + 1))
node_level_repr_pre_encoder_mean = np.array(
    [means[wl_score]["node_level_repr_pre_encoder"] for wl_score in x]
)
node_level_repr_pre_encoder_std = np.array(
    [stds[wl_score]["node_level_repr_pre_encoder"] for wl_score in x]
)

error_bars = dict(type="data", array=node_level_repr_pre_encoder_std)
fig = go.Figure(
    data=go.Scatter(x=x, y=node_level_repr_pre_encoder_mean, error_y=error_bars)
)
fig.update_layout(
    title="node_level_repr_pre_encoder_mean",
    xaxis_title="wl_score",
    yaxis_title="Mean",
)
fig.show()

In [95]:
# List the wl_score values for which raw per-example statistics were stored.
raw_data.keys()

dict_keys([-1, 0, 1, 2, 3, 4, 5])

In [100]:
# Histogram of the per-example gnn_output statistic for the examples whose wl_score is 3.
px.histogram(raw_data[3]["gnn_output"])