# Testing graph isomorphism agents

## Setup

In [1]:
FORCE_CPU = True

SEED = 349287

CHECKPOINT_RUN_ID = "ppo_gi_next_1_0"
CHECKPOINT_VERSION = "v21"

In [2]:
import torch
from torch import Tensor
from torch import nn

from pvg import (
    Parameters,
    ScenarioType,
    TrainerType,
    AgentsParameters,
    GraphIsomorphismAgentParameters,
)
from pvg.experiment_settings import ExperimentSettings
from pvg.graph_isomorphism import (
    GraphIsomorphismAgentHooks,
)
from pvg.scenario_instance import build_scenario_instance

In [3]:
torch.manual_seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [4]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


## Create scenario and agents

In [5]:
params = Parameters(
    scenario=ScenarioType.GRAPH_ISOMORPHISM,
    trainer=TrainerType.SOLO_AGENT,
    dataset="eru10000",
    seed=SEED,
    agents=AgentsParameters(
        verifier=GraphIsomorphismAgentParameters(
            load_checkpoint_and_parameters=True,
            checkpoint_run_id=CHECKPOINT_RUN_ID,
            checkpoint_version=CHECKPOINT_VERSION,
        )
    ),
)

In [6]:
settings = ExperimentSettings(
    device=device,
)

In [7]:
scenario_instance = build_scenario_instance(
    params=params,
    settings=settings,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msamadamday[0m ([33mlrhammond-team[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   8 of 8 files downloaded.  


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
mean_episode_length,1.0
prover.mean_reward,0.201
prover.pretrain_test_accuracy,0.72352
prover.pretrain_test_loss,0.48423
verifier.mean_reward,0.136
verifier.pretrain_test_accuracy,0.62335
verifier.pretrain_test_loss,0.60665


In [8]:
verifier_body = scenario_instance.agents["verifier"].body

In [9]:
dataset = scenario_instance.train_dataset

In [10]:
dataset

GraphIsomorphismDataset(
    fields={
        adjacency: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=True),
        node_mask: MemoryMappedTensor(shape=torch.Size([10000, 2, 11]), device=cpu, dtype=torch.bool, is_shared=True),
        x: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 8]), device=cpu, dtype=torch.float32, is_shared=True),
        y: MemoryMappedTensor(shape=torch.Size([10000]), device=cpu, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([10000]),
    device=None,
    is_shared=False)

In [11]:
max_num_nodes = dataset["x"].shape[-2]
max_num_nodes

11

## Looking at intermediate computations

In [12]:
DATAPOINT = 1

In [13]:
print("y:", dataset["y"][DATAPOINT].item())

y: 0


In [14]:
storage = {}
hooks = GraphIsomorphismAgentHooks.create_recorder_hooks(storage)
input_td = dataset[DATAPOINT].unsqueeze(0)
input_td["message"] = torch.zeros_like(input_td["y"])
input_td["ignore_message"] = torch.ones_like(input_td["y"], dtype=torch.bool)
verifier_body(input_td, hooks)

TensorDict(
    fields={
        graph_level_repr: Tensor(shape=torch.Size([1, 2, 16]), device=cpu, dtype=torch.float32, is_shared=False),
        node_level_repr: Tensor(shape=torch.Size([1, 2, 11, 16]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

In [15]:
for value in storage.values():
    if value.shape[0] == 1:
        value.squeeze_(0)

In [16]:
for key, value in storage.items():
    print(key, ":", tuple(value.shape))

gnn_output : (2, 11, 16)
gnn_output_rounded : (2, 11, 16)
pooled_gnn_output : (2, 16)
gnn_output_flatter : (22, 16)
transformer_input_initial : (24, 16)
pooled_feature : (24, 2)
message_feature : (24, 1)
transformer_input_pre_encoder : (24, 19)
transformer_input : (24, 16)
transformer_output_flatter : (24, 16)
graph_level_repr_pre_encoder : (2, 16)
node_level_repr_pre_encoder : (2, 11, 16)
graph_level_repr : (2, 16)
node_level_repr : (2, 11, 16)


In [17]:
print("gnn_output")
(storage["gnn_output"][0].sum(dim=0) - storage["gnn_output"][1].sum(dim=0)).mean().item()

gnn_output


0.00804645475000143

In [18]:
print("gnn_output_rounded")
(storage["gnn_output_rounded"][0].sum(dim=0) - storage["gnn_output_rounded"][1].sum(dim=0)).mean().item()

gnn_output_rounded


0.008037504740059376

In [19]:
print("pooled_gnn_output, second graph-level representation")

storage["pooled_gnn_output"][1].mean().item()

pooled_gnn_output, second graph-level representation


0.956008791923523

In [20]:
print("gnn_output_flatter")

(storage["gnn_output_flatter"][:max_num_nodes].sum(dim=0) - storage["gnn_output_flatter"][max_num_nodes:].sum(dim=0)).mean().item()

gnn_output_flatter


0.008037504740059376

In [21]:
print("transformer_input_initial, second graph-level representation")

storage["transformer_input_initial"][1][:-2].mean().item()

transformer_input_initial, second graph-level representation


0.9545800089836121

In [22]:
print("transformer_input_initial, node-level representations")

(storage["transformer_input_initial"][2:max_num_nodes+2].sum(dim=0) - storage["transformer_input_initial"][max_num_nodes+2:].sum(dim=0)).mean().item()

transformer_input_initial, node-level representations


0.008037504740059376

In [23]:
print("transformer_input_pre_encoder, second graph-level representation")

storage["transformer_input_pre_encoder"][1][:-2].mean().item()

transformer_input_pre_encoder, second graph-level representation


0.8997730016708374

In [24]:
print("transformer_input_pre_encoder, node-level representations")

(storage["transformer_input_pre_encoder"][2:max_num_nodes+2].sum(dim=0) - storage["transformer_input_pre_encoder"][max_num_nodes+2:].sum(dim=0)).mean().item()

transformer_input_pre_encoder, node-level representations


0.006768424995243549

In [25]:
print("transformer_input, second graph-level representation")

storage["transformer_input"][1].mean().item()

transformer_input, second graph-level representation


-0.021549656987190247

In [26]:
print("transformer_input, node-level representations")

(storage["transformer_input"][2:max_num_nodes+2].sum(dim=0) - storage["transformer_input"][max_num_nodes+2:].sum(dim=0)).mean().item()

transformer_input, node-level representations


0.003838390577584505

In [27]:
print("transformer_output_flatter, second graph-level representation")

storage["transformer_output_flatter"][1].mean().item()

transformer_output_flatter, second graph-level representation


-0.003930367529392242

In [28]:
print("transformer_output_flatter, node-level representations")

(storage["transformer_output_flatter"][2:max_num_nodes+2].sum(dim=0) - storage["transformer_output_flatter"][max_num_nodes+2:].sum(dim=0)).mean().item()

transformer_output_flatter, node-level representations


-0.0006291046738624573

In [29]:
print("node_level_repr_pre_encoder")

(storage["node_level_repr_pre_encoder"][0].sum(dim=0) - storage["node_level_repr_pre_encoder"][1].sum(dim=0)).mean().item()

node_level_repr_pre_encoder


-0.0006291046738624573