# Testing graph isomorphism agents

## Setup

In [56]:
# When True, the device-selection cell uses the CPU even if CUDA is available.
FORCE_CPU = True

# Seed for torch's global RNG, the data-loader generator and `Parameters(seed=...)`.
SEED = 349287

# Checkpoint settings, passed identically to the verifier's and the prover's
# `GraphIsomorphismAgentParameters`.
LOAD_CHECKPOINT = True
CHECKPOINT_RUN_ID = "ppo_gi_next_norm_hist_no_round_1_1"
CHECKPOINT_VERSION = "v24"
NORMALIZE_MESSAGE_HISTORY_DEFAULT = True

# Batch size of the data loader; also the row stride used when per-batch
# results are written into the preallocated result arrays.
BATCH_SIZE = 128

In [57]:
import torch
from torch import Tensor
from torch import nn

from tensordict import TensorDictBase
from tensordict.nn import TensorDictSequential

import numpy as np

from tqdm import tqdm

import pandas as pd

import plotly.graph_objects as go
import plotly.express as px

from rich.console import Console
from rich.table import Table

from pvg import (
    Parameters,
    ScenarioType,
    TrainerType,
    AgentsParameters,
    GraphIsomorphismAgentParameters,
)
from pvg.experiment_settings import ExperimentSettings
from pvg.graph_isomorphism import (
    GraphIsomorphismAgentHooks,
)
from pvg.scenario_instance import build_scenario_instance
from pvg.scenario_base import DataLoader

In [58]:
# Seed torch's global RNG, then create a separate generator (same seed) that
# is later handed to the data loader.
torch.manual_seed(SEED)
torch_generator = torch.Generator()
torch_generator.manual_seed(SEED)

In [59]:
# CUDA is used only when it is available and FORCE_CPU is not set.
use_cuda = (not FORCE_CPU) and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(device)

cpu


## Create scenario and agents

In [60]:
# The verifier and the prover are configured identically: both load the same
# checkpoint and share the message-history normalisation flag.
shared_agent_kwargs = {
    "load_checkpoint_and_parameters": LOAD_CHECKPOINT,
    "checkpoint_run_id": CHECKPOINT_RUN_ID,
    "checkpoint_version": CHECKPOINT_VERSION,
    "normalize_message_history": NORMALIZE_MESSAGE_HISTORY_DEFAULT,
}

params = Parameters(
    scenario=ScenarioType.GRAPH_ISOMORPHISM,
    trainer=TrainerType.VANILLA_PPO,
    dataset="eru10000",
    seed=SEED,
    agents=AgentsParameters(
        # Each agent gets its own parameters object built from the shared kwargs.
        verifier=GraphIsomorphismAgentParameters(**shared_agent_kwargs),
        prover=GraphIsomorphismAgentParameters(**shared_agent_kwargs),
    ),
)

In [61]:
# Experiment settings; only the compute device is passed explicitly.
settings = ExperimentSettings(
    device=device,
)

In [62]:
# Build the scenario from the parameters and settings. The cells below read
# its `combined_body`, `combined_policy_head` and `train_dataset` attributes.
scenario_instance = build_scenario_instance(
    params=params,
    settings=settings,
)



VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112387711182236, max=1.0…



In [63]:
class TensorDictSequentialHooks(TensorDictSequential):
    """A `TensorDictSequential` whose forward pass relays extra keyword arguments.

    Every keyword argument given to `forward` is handed unchanged to
    `_run_module` for each contained module. In this notebook that is how the
    `hooks` argument reaches the agent modules.
    """

    def forward(
        self,
        tensordict: TensorDictBase,
        tensordict_out: TensorDictBase | None = None,
        **kwargs,
    ) -> TensorDictBase:
        """Run the contained modules in order and return the resulting tensordict.

        Parameters
        ----------
        tensordict : TensorDictBase
            Input passed to the first module; each module receives the value
            returned by the previous one.
        tensordict_out : TensorDictBase | None
            If given, it is updated in place with the final result and
            returned instead of the result itself.
        **kwargs
            Forwarded verbatim to `_run_module` for every module.
        """
        current = tensordict
        for submodule in self.module:
            current = self._run_module(submodule, current, **kwargs)

        if tensordict_out is None:
            return current

        tensordict_out.update(current, inplace=True)
        return tensordict_out

In [64]:
# Pull the two stages of the actor out of the scenario: the body and the
# policy head. They are chained into a single module in the next cell.
combined_body = scenario_instance.combined_body
combined_policy_head = scenario_instance.combined_policy_head

In [65]:
# Chain body and policy head. The hooks-aware subclass is used so that the
# `hooks=` keyword given at call time is passed on to each module.
combined_actor = TensorDictSequentialHooks(combined_body, combined_policy_head)

In [66]:
# All analysis below runs over the scenario's training dataset.
dataset = scenario_instance.train_dataset

In [67]:
# Display the dataset's repr to inspect its fields and their shapes.
dataset

GraphIsomorphismDataset(
    fields={
        adjacency: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=True),
        node_mask: MemoryMappedTensor(shape=torch.Size([10000, 2, 11]), device=cpu, dtype=torch.bool, is_shared=True),
        wl_score: MemoryMappedTensor(shape=torch.Size([10000]), device=cpu, dtype=torch.int32, is_shared=True),
        x: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 8]), device=cpu, dtype=torch.float32, is_shared=True),
        y: MemoryMappedTensor(shape=torch.Size([10000]), device=cpu, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([10000]),
    device=None,
    is_shared=False)

In [68]:
# Size of the second-to-last axis of `x` (11 for this dataset). The statistic
# collectors use it to split flattened sequences into two halves.
# Presumably this is the padded per-graph node count — TODO confirm.
max_num_nodes = dataset["x"].shape[-2]
max_num_nodes

11

## Looking at intermediate computations

In [69]:
# NOTE(review): the evaluation loop stores batch i in rows
# [i * BATCH_SIZE, (i + 1) * BATCH_SIZE) and the results are later masked with
# dataset["wl_score"]. That is only valid if batches arrive in dataset order.
# Confirm this DataLoader does not shuffle, since a generator is passed.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    generator=torch_generator,
)

In [70]:
# Maps a statistic name to a callable that takes one agent's recorded storage
# and returns a tensor which is written into that statistic's per-example array.
statistic_collectors = {}

In [71]:
def _collect_gnn_output(storage):
    """Mean over dim 1 of the difference between the two summed `gnn_output` slices.

    Slices 0 and 1 along axis 1 are each summed over their dim 1 before being
    subtracted (presumably the two graphs of a pair, summed over nodes —
    TODO confirm).
    """
    recorded = storage["gnn_output"]
    first_total = recorded[:, 0].sum(dim=1)
    second_total = recorded[:, 1].sum(dim=1)
    return (first_total - second_total).mean(dim=1)


statistic_collectors["gnn_output"] = _collect_gnn_output

In [72]:
def _collect_gnn_output_rounded(storage):
    """Same statistic as for `gnn_output`, computed on `gnn_output_rounded`.

    Slices 0 and 1 along axis 1 are each summed over their dim 1, subtracted,
    and the difference is averaged over dim 1.
    """
    recorded = storage["gnn_output_rounded"]
    first_total = recorded[:, 0].sum(dim=1)
    second_total = recorded[:, 1].sum(dim=1)
    return (first_total - second_total).mean(dim=1)


statistic_collectors["gnn_output_rounded"] = _collect_gnn_output_rounded

In [73]:
def _collect_gnn_output_flatter(storage):
    """Difference statistic for `gnn_output_flatter`.

    Axis 1 is split at `max_num_nodes`: the first `max_num_nodes` positions
    and the remaining positions are each summed over axis 1, subtracted, and
    the difference is averaged over dim 1. Presumably the two halves are the
    two graphs laid out one after the other — TODO confirm.
    """
    recorded = storage["gnn_output_flatter"]
    leading_total = recorded[:, :max_num_nodes].sum(dim=1)
    trailing_total = recorded[:, max_num_nodes:].sum(dim=1)
    return (leading_total - trailing_total).mean(dim=1)


statistic_collectors["gnn_output_flatter"] = _collect_gnn_output_flatter

In [74]:
def _collect_transformer_input_initial(storage):
    """Difference statistic for `transformer_input_initial`.

    The first two positions along axis 1 are skipped (presumably extra
    non-node tokens — TODO confirm). The next `max_num_nodes` positions and
    all positions after them are each summed over axis 1, subtracted, and the
    difference is averaged over dim 1.
    """
    recorded = storage["transformer_input_initial"]
    leading_total = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    trailing_total = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (leading_total - trailing_total).mean(dim=1)


statistic_collectors["transformer_input_initial"] = _collect_transformer_input_initial

In [75]:
def _collect_transformer_input_pre_encoder(storage):
    """Difference statistic for `transformer_input_pre_encoder`.

    The first two positions along axis 1 are skipped. The next
    `max_num_nodes` positions and all positions after them are each summed
    over axis 1, subtracted, and the difference is averaged over dim 1.
    """
    recorded = storage["transformer_input_pre_encoder"]
    leading_total = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    trailing_total = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (leading_total - trailing_total).mean(dim=1)


statistic_collectors["transformer_input_pre_encoder"] = (
    _collect_transformer_input_pre_encoder
)

In [76]:
def _collect_transformer_input(storage):
    """Difference statistic for `transformer_input`.

    The first two positions along axis 1 are skipped. The next
    `max_num_nodes` positions and all positions after them are each summed
    over axis 1, subtracted, and the difference is averaged over dim 1.
    """
    recorded = storage["transformer_input"]
    leading_total = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    trailing_total = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (leading_total - trailing_total).mean(dim=1)


statistic_collectors["transformer_input"] = _collect_transformer_input

In [77]:
def _collect_transformer_output_flatter(storage):
    """Difference statistic for `transformer_output_flatter`.

    The first two positions along axis 1 are skipped. The next
    `max_num_nodes` positions and all positions after them are each summed
    over axis 1, subtracted, and the difference is averaged over dim 1.
    """
    recorded = storage["transformer_output_flatter"]
    leading_total = recorded[:, 2:max_num_nodes+2].sum(dim=1)
    trailing_total = recorded[:, max_num_nodes+2:].sum(dim=1)
    return (leading_total - trailing_total).mean(dim=1)


statistic_collectors["transformer_output_flatter"] = (
    _collect_transformer_output_flatter
)

In [78]:
def _collect_node_level_repr_pre_encoder(storage):
    """Difference statistic for `node_level_repr_pre_encoder`.

    Slices 0 and 1 along axis 1 are each summed over their dim 1, subtracted,
    and the difference is averaged over dim 1.
    """
    recorded = storage["node_level_repr_pre_encoder"]
    first_total = recorded[:, 0].sum(dim=1)
    second_total = recorded[:, 1].sum(dim=1)
    return (first_total - second_total).mean(dim=1)


statistic_collectors["node_level_repr_pre_encoder"] = (
    _collect_node_level_repr_pre_encoder
)

In [79]:
def _collect_node_level_repr(storage):
    """Difference statistic for `node_level_repr`.

    Slices 0 and 1 along axis 1 are each summed over their dim 1, subtracted,
    and the difference is averaged over dim 1.
    """
    recorded = storage["node_level_repr"]
    first_total = recorded[:, 0].sum(dim=1)
    second_total = recorded[:, 1].sum(dim=1)
    return (first_total - second_total).mean(dim=1)


statistic_collectors["node_level_repr"] = _collect_node_level_repr

In [80]:
# One uninitialised float32 array per registered statistic, with one slot per
# dataset example. The evaluation loop fills them batch by batch.
stats = {}
for stat_name in statistic_collectors:
    stats[stat_name] = np.empty(len(dataset), dtype=np.float32)

In [81]:
# `storage` is handed to the recorder hooks and read back in the evaluation
# loop under the "verifier" key. Presumably the hooks write each intermediate
# tensor into it, keyed by agent name because per_agent=True — TODO confirm.
storage = {}
hooks = GraphIsomorphismAgentHooks.create_recorder_hooks(storage, per_agent=True)

In [82]:
# Per-example results, filled batch by batch in the evaluation loop.
# Three columns, matching the last dimension of the actor's `decision_logits`.
decision_logits = np.empty((len(dataset), 3), dtype=np.float32)
# True where the arg-max decision equals the label `y`.
correct_decision = np.empty(len(dataset), dtype=bool)

In [83]:
# Run the combined actor over the whole dataset, recording for every example
# the decision logits of agent 0, whether its arg-max matches the label, and
# each registered statistic computed from the verifier's recorded activations.
#
# NOTE(review): results for batch i are written to rows
# [i * BATCH_SIZE, (i + 1) * BATCH_SIZE), which assumes the data loader yields
# batches in dataset order. Confirm it does not shuffle.
# NOTE(review): the modules are not explicitly switched to eval mode here —
# confirm whether that is needed for this checkpoint.
#
# Only forward values are needed, so autograd bookkeeping is disabled.
with torch.no_grad():
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Rows of the result arrays that belong to this batch. NumPy clips the
        # slice for the final, shorter batch.
        rows = slice(i * BATCH_SIZE, (i + 1) * BATCH_SIZE)

        # Extra inputs the actor expects, all created on the same device as
        # the batch data so the loop also works when the data is not on CPU.
        batch_device = batch["y"].device
        batch["message"] = torch.zeros_like(batch["y"])
        batch["ignore_message"] = torch.ones_like(batch["y"], dtype=torch.bool)
        batch["round"] = torch.zeros(
            batch.batch_size, dtype=torch.int64, device=batch_device
        )
        batch["decision_restriction"] = torch.zeros(
            batch.batch_size, dtype=torch.int64, device=batch_device
        )

        # `hooks` is forwarded to every module and populates `storage`.
        actor_out = combined_actor(batch, hooks=hooks)

        # Index 0 on the agent axis — presumably the verifier; TODO confirm.
        batch_decision_logits = (
            actor_out["agents"]["decision_logits"][:, 0].detach().cpu().numpy()
        )
        decision_logits[rows] = batch_decision_logits

        batch_y = batch["y"].detach().cpu().numpy()
        correct_decision[rows] = batch_y == batch_decision_logits.argmax(axis=1)

        verifier_storage = storage["verifier"]
        for key, collector in statistic_collectors.items():
            stats[key][rows] = collector(verifier_storage).detach().cpu().numpy()

100%|██████████| 79/79 [00:01<00:00, 40.76it/s]


In [84]:
# Inspect the actor output left over from the final loop iteration.
actor_out

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([16, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        agents: TensorDict(
            fields={
                decision_logits: Tensor(shape=torch.Size([16, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                graph_level_repr: Tensor(shape=torch.Size([16, 2, 2, 16]), device=cpu, dtype=torch.float32, is_shared=False),
                node_level_repr: Tensor(shape=torch.Size([16, 2, 2, 11, 16]), device=cpu, dtype=torch.float32, is_shared=False),
                node_selected_logits: Tensor(shape=torch.Size([16, 2, 22]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([16]),
            device=None,
            is_shared=False),
        decision_restriction: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.int64, is_shared=False),
        ignore_message: Tensor(shape=torch.Size([16]), device=cpu, dtype=torch.bool, is_shared=False),
   

In [85]:
# Largest `wl_score` in the dataset, as a Python int. It is the upper bound of
# the score range used for grouping and plotting below.
max_wl_score = torch.max(dataset["wl_score"]).item()

In [86]:
# Group every statistic by `wl_score` (from -1 up to the maximum score) and
# keep, per score, the mean, the standard deviation and the raw values.
# Scores that never occur get NaN summaries and empty raw arrays.
means, stds, raw_data = {}, {}, {}
for wl_score in range(-1, max_wl_score + 1):
    mask = (dataset["wl_score"] == wl_score).numpy()
    if mask.any():
        selected = {name: values[mask] for name, values in stats.items()}
        means[wl_score] = {name: values.mean() for name, values in selected.items()}
        stds[wl_score] = {name: values.std() for name, values in selected.items()}
        raw_data[wl_score] = selected
    else:
        means[wl_score] = dict.fromkeys(stats, np.nan)
        stds[wl_score] = dict.fromkeys(stats, np.nan)
        raw_data[wl_score] = {name: np.empty(0, dtype=np.float32) for name in stats}

In [87]:
# Print a table with one row per wl_score and one column per statistic,
# showing the per-score mean formatted to four decimals.
console = Console()

table = Table()
table.add_column("Score")
for stat_name in stats:
    table.add_column(stat_name)

# Only the means are displayed, so iterate over `means` alone. The standard
# deviations were previously unpacked here but never used.
for wl_score, mean_by_stat in means.items():
    formatted_means = [f"{mean_by_stat[stat_name]:.4f}" for stat_name in stats]
    table.add_row(str(wl_score), *formatted_means)

console.print(table)


In [88]:
# Mean of the "gnn_output" statistic per wl_score, with the standard deviation
# drawn as error bars.
x = list(range(-1, max_wl_score + 1))
gnn_output_mean = np.array([means[score]["gnn_output"] for score in x])
gnn_output_std = np.array([stds[score]["gnn_output"] for score in x])

fig = go.Figure(
    data=go.Scatter(
        x=x,
        y=gnn_output_mean,
        error_y={"type": "data", "array": gnn_output_std},
    )
)
fig.update_layout(
    title="gnn_output_mean",
    xaxis_title="wl_score",
    yaxis_title="Mean",
)
fig.show()

In [89]:
# Mean of the "gnn_output_rounded" statistic per wl_score, with the standard
# deviation drawn as error bars.
x = list(range(-1, max_wl_score + 1))
gnn_output_rounded_mean = np.array(
    [means[score]["gnn_output_rounded"] for score in x]
)
gnn_output_rounded_std = np.array(
    [stds[score]["gnn_output_rounded"] for score in x]
)

fig = go.Figure(
    data=go.Scatter(
        x=x,
        y=gnn_output_rounded_mean,
        error_y={"type": "data", "array": gnn_output_rounded_std},
    )
)
fig.update_layout(
    title="gnn_output_rounded_mean",
    xaxis_title="wl_score",
    yaxis_title="Mean",
)
fig.show()

In [90]:
# Mean of the "node_level_repr_pre_encoder" statistic per wl_score, with the
# standard deviation drawn as error bars.
x = list(range(-1, max_wl_score + 1))
node_level_repr_pre_encoder_mean = np.array(
    [means[score]["node_level_repr_pre_encoder"] for score in x]
)
node_level_repr_pre_encoder_std = np.array(
    [stds[score]["node_level_repr_pre_encoder"] for score in x]
)

fig = go.Figure(
    data=go.Scatter(
        x=x,
        y=node_level_repr_pre_encoder_mean,
        error_y={"type": "data", "array": node_level_repr_pre_encoder_std},
    )
)
fig.update_layout(
    title="node_level_repr_pre_encoder_mean",
    xaxis_title="wl_score",
    yaxis_title="Mean",
)
fig.show()

In [91]:
# List the wl_score values for which raw per-example statistics were stored.
raw_data.keys()

dict_keys([-1, 0, 1, 2, 3, 4, 5])

In [92]:
# Histogram of the "node_level_repr_pre_encoder" statistic over the examples
# whose wl_score is 3.
px.histogram(raw_data[3]["node_level_repr_pre_encoder"])

## Combining the body and policy head

In [93]:
# Average arg-max decision index over all examples.
decision_logits.argmax(axis=1).mean()

0.5022

In [94]:
# Fraction of examples whose arg-max decision equals the label `y`.
correct_decision.mean()

0.9978