# Testing graph isomorphism agents

## Setup

In [1]:
# Run configuration: everything a reader might want to tune lives here.

# Keep the model on the CPU even when CUDA is available.
FORCE_CPU: bool = True

# Seed for torch's RNGs and the experiment parameters.
SEED: int = 349287

# Which verifier checkpoint to restore.
LOAD_CHECKPOINT: bool = True
CHECKPOINT_RUN_ID: str = "ppo_gi_next_1_0"
CHECKPOINT_VERSION: str = "v21"
NORMALIZE_MESSAGE_HISTORY_DEFAULT: bool = False

# Number of dataset items per forward pass in the analysis loop.
BATCH_SIZE: int = 128

In [2]:
import torch
from torch import Tensor
from torch import nn

import numpy as np

from tqdm import tqdm

import pandas as pd

import plotly.graph_objects as go
import plotly.express as px

from rich.console import Console
from rich.table import Table

from pvg import (
    Parameters,
    ScenarioType,
    TrainerType,
    AgentsParameters,
    GraphIsomorphismAgentParameters,
)
from pvg.experiment_settings import ExperimentSettings
from pvg.graph_isomorphism import (
    GraphIsomorphismAgentHooks,
)
from pvg.scenario_instance import build_scenario_instance
from pvg.scenario_base import DataLoader

In [3]:
# Seed torch's global RNG, plus a dedicated generator that is handed to the
# dataloader below so any sampling it does is reproducible.
torch.manual_seed(SEED)
torch_generator = torch.Generator()
torch_generator.manual_seed(SEED)

In [4]:
# Use the GPU only when it is not explicitly disabled and CUDA is present.
device = torch.device(
    "cuda" if not FORCE_CPU and torch.cuda.is_available() else "cpu"
)
print(device)

cpu


## Create scenario and agents

In [5]:
# Solo-agent graph-isomorphism experiment on the "eru10000" dataset. The
# verifier is configured from the checkpoint constants above (presumably its
# weights and parameters are restored when LOAD_CHECKPOINT is set).
params = Parameters(
    seed=SEED,
    scenario=ScenarioType.GRAPH_ISOMORPHISM,
    trainer=TrainerType.SOLO_AGENT,
    dataset="eru10000",
    agents=AgentsParameters(
        verifier=GraphIsomorphismAgentParameters(
            checkpoint_run_id=CHECKPOINT_RUN_ID,
            checkpoint_version=CHECKPOINT_VERSION,
            load_checkpoint_and_parameters=LOAD_CHECKPOINT,
            normalize_message_history=NORMALIZE_MESSAGE_HISTORY_DEFAULT,
        ),
    ),
)

In [6]:
# Runtime settings: only the device chosen above is overridden.
settings = ExperimentSettings(device=device)

In [7]:
# Build the scenario (datasets and agents) from the parameters and settings.
scenario_instance = build_scenario_instance(
    settings=settings,
    params=params,
)

In [8]:
# The verifier agent's body; the analysis loop calls it directly with hooks.
verifier_body = (
    scenario_instance
    .agents["verifier"]
    .body
)

In [9]:
# All analysis below runs on the training split.
dataset = (
    scenario_instance
    .train_dataset
)

In [10]:
# Display the dataset's fields with their shapes and dtypes.
dataset

GraphIsomorphismDataset(
    fields={
        adjacency: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=True),
        node_mask: MemoryMappedTensor(shape=torch.Size([10000, 2, 11]), device=cpu, dtype=torch.bool, is_shared=True),
        wl_score: MemoryMappedTensor(shape=torch.Size([10000]), device=cpu, dtype=torch.int32, is_shared=True),
        x: MemoryMappedTensor(shape=torch.Size([10000, 2, 11, 8]), device=cpu, dtype=torch.float32, is_shared=True),
        y: MemoryMappedTensor(shape=torch.Size([10000]), device=cpu, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([10000]),
    device=None,
    is_shared=False)

In [11]:
# Padded node count per graph: the second-to-last axis of `x`
# (x appears to be (N, 2, nodes, features) per the repr above).
max_num_nodes = dataset["x"].size(-2)
max_num_nodes

11

## Looking at intermediate computations

In [12]:
# Batch the training set; the seeded generator keeps any sampling reproducible.
# NOTE(review): the evaluation loop writes results by batch position and the
# aggregation aligns them with dataset["wl_score"], so this loader must yield
# items in dataset order — confirm it does not shuffle.
dataloader = DataLoader(
    batch_size=BATCH_SIZE,
    generator=torch_generator,
    dataset=dataset,
)

In [13]:
# Statistic name -> function mapping the hook `storage` dict to one value per
# sample in the current batch (populated by the cells below).
statistic_collectors = dict()

In [14]:
# Per-sample scalar: node-summed GNN output of graph 0 minus that of graph 1,
# then averaged over the remaining (feature) axis.
# NOTE(review): assumes storage["gnn_output"] is (batch, 2, nodes, features).
statistic_collectors["gnn_output"] = lambda storage: torch.sub(
    storage["gnn_output"][:, 0].sum(dim=1),
    storage["gnn_output"][:, 1].sum(dim=1),
).mean(dim=1)

In [15]:
# Same graph-0-minus-graph-1 difference as above, on the rounded GNN output.
# NOTE(review): assumes (batch, 2, nodes, features) like gnn_output.
statistic_collectors["gnn_output_rounded"] = lambda storage: torch.sub(
    storage["gnn_output_rounded"][:, 0].sum(dim=1),
    storage["gnn_output_rounded"][:, 1].sum(dim=1),
).mean(dim=1)

In [16]:
# Graph-0-minus-graph-1 difference on the flattened GNN output, where the two
# graphs' node rows are concatenated along axis 1.
# NOTE(review): assumes (batch, 2 * max_num_nodes, features) with graph 0's
# nodes first.
statistic_collectors["gnn_output_flatter"] = lambda storage: torch.sub(
    storage["gnn_output_flatter"][:, :max_num_nodes].sum(dim=1),
    storage["gnn_output_flatter"][:, max_num_nodes:].sum(dim=1),
).mean(dim=1)

In [17]:
# Graph-0-minus-graph-1 difference on the initial transformer input. The first
# two sequence positions are skipped (presumably non-node tokens); the next
# max_num_nodes positions are taken as graph 0 and the rest as graph 1.
statistic_collectors["transformer_input_initial"] = lambda storage: torch.sub(
    storage["transformer_input_initial"][:, 2 : max_num_nodes + 2].sum(dim=1),
    storage["transformer_input_initial"][:, max_num_nodes + 2 :].sum(dim=1),
).mean(dim=1)

In [18]:
# As above, on the transformer input just before the encoder: skip positions
# 0-1, then graph 0's node positions minus graph 1's.
statistic_collectors["transformer_input_pre_encoder"] = lambda storage: torch.sub(
    storage["transformer_input_pre_encoder"][:, 2 : max_num_nodes + 2].sum(dim=1),
    storage["transformer_input_pre_encoder"][:, max_num_nodes + 2 :].sum(dim=1),
).mean(dim=1)

In [19]:
# As above, on the final transformer input: skip positions 0-1, then graph 0's
# node positions minus graph 1's.
statistic_collectors["transformer_input"] = lambda storage: torch.sub(
    storage["transformer_input"][:, 2 : max_num_nodes + 2].sum(dim=1),
    storage["transformer_input"][:, max_num_nodes + 2 :].sum(dim=1),
).mean(dim=1)

In [20]:
# As above, on the flattened transformer output: skip positions 0-1, then
# graph 0's node positions minus graph 1's.
statistic_collectors["transformer_output_flatter"] = lambda storage: torch.sub(
    storage["transformer_output_flatter"][:, 2 : max_num_nodes + 2].sum(dim=1),
    storage["transformer_output_flatter"][:, max_num_nodes + 2 :].sum(dim=1),
).mean(dim=1)

In [21]:
# Graph-0-minus-graph-1 difference of the node-summed node-level representation
# before the encoder, averaged over the last axis.
# NOTE(review): assumes (batch, 2, nodes, features).
statistic_collectors["node_level_repr_pre_encoder"] = lambda storage: torch.sub(
    storage["node_level_repr_pre_encoder"][:, 0].sum(dim=1),
    storage["node_level_repr_pre_encoder"][:, 1].sum(dim=1),
).mean(dim=1)

In [22]:
# Graph-0-minus-graph-1 difference of the node-summed node-level representation,
# averaged over the last axis.
# NOTE(review): assumes (batch, 2, nodes, features).
statistic_collectors["node_level_repr"] = lambda storage: torch.sub(
    storage["node_level_repr"][:, 0].sum(dim=1),
    storage["node_level_repr"][:, 1].sum(dim=1),
).mean(dim=1)

In [23]:
# One float32 array per statistic, indexed by dataset position.
# Filled with NaN rather than left uninitialised (np.empty): any slot the
# evaluation loop does not write then surfaces as NaN in the per-score
# summaries instead of silently contributing arbitrary memory contents.
stats = {
    key: np.full(len(dataset), np.nan, dtype=np.float32)
    for key in statistic_collectors
}

In [24]:
# Dict presumably populated by the recorder hooks during each forward pass;
# the statistic collectors read their inputs from it by key.
storage = dict()
hooks = GraphIsomorphismAgentHooks.create_recorder_hooks(storage)

In [25]:
# Run the verifier body over the whole training set with a zero message that
# is flagged to be ignored, then reduce the recorded intermediates to one value
# per sample for every statistic, stored at the sample's dataset position.
# NOTE(review): if the body contains dropout/batch-norm, consider calling
# verifier_body.eval() first — TODO confirm the module's current mode.
offset = 0
# Pure analysis: no autograd graph is needed, which saves memory and time.
with torch.no_grad():
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Track positions with a running offset rather than i * BATCH_SIZE so
        # they stay correct even if a batch is not exactly BATCH_SIZE long.
        n_in_batch = batch["y"].shape[0]
        positions = slice(offset, offset + n_in_batch)
        # The per-WL-score aggregation masks `stats` with dataset["wl_score"],
        # which is only valid if batches arrive in dataset order.
        assert torch.equal(
            batch["wl_score"].cpu(), dataset["wl_score"][positions]
        ), "dataloader is not yielding items in dataset order"
        batch["message"] = torch.zeros_like(batch["y"])
        batch["ignore_message"] = torch.ones_like(batch["y"], dtype=torch.bool)
        verifier_body(batch, hooks)
        for key, collector in statistic_collectors.items():
            stats[key][positions] = collector(storage).detach().cpu().numpy()
        offset += n_in_batch
# Every dataset position must have been written exactly once.
assert offset == len(dataset), "dataloader did not cover the whole dataset"

  5%|▌         | 4/79 [00:00<00:02, 32.63it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 15%|█▌        | 12/79 [00:00<00:02, 30.97it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 20%|██        | 16/79 [00:00<00:02, 30.06it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 29%|██▉       | 23/79 [00:00<00:01, 29.49it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 37%|███▋      | 29/79 [00:00<00:01, 28.78it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 44%|████▍     | 35/79 [00:01<00:01, 28.43it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 54%|█████▍    | 43/79 [00:01<00:01, 31.67it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 65%|██████▍   | 51/79 [00:01<00:00, 30.57it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 70%|██████▉   | 55/79 [00:01<00:00, 30.17it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 80%|███████▉  | 63/79 [00:02<00:00, 30.48it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 85%|████████▍ | 67/79 [00:02<00:00, 26.34it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

 92%|█████████▏| 73/79 [00:02<00:00, 26.83it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,

100%|██████████| 79/79 [00:02<00:00, 29.11it/s]

tensor([[[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]],

         [[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          ...,
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000]]],


        [[[-0.2035, -0.1879, -0.1710,  ..., -0.1072, -0.0756,  0.0000],
          [-0.2035, -0.1879,




In [26]:
# Largest WL score present in the dataset, as a plain Python int.
max_wl_score = int(dataset["wl_score"].max())

In [27]:
# Group the per-sample statistics by WL score. Every integer score from -1 to
# max_wl_score gets an entry; scores with no samples get NaN mean/std and an
# empty raw array, so the plotting cells can index the full range.
# NOTE(review): .std() here is the population std (ddof=0).
wl_score_array = dataset["wl_score"].numpy()
means, stds, raw_data = {}, {}, {}
for wl_score in range(-1, max_wl_score + 1):
    mask = wl_score_array == wl_score
    if not mask.any():
        means[wl_score] = dict.fromkeys(stats, np.nan)
        stds[wl_score] = dict.fromkeys(stats, np.nan)
        raw_data[wl_score] = {key: np.empty(0, dtype=np.float32) for key in stats}
        continue
    raw_data[wl_score] = {key: values[mask] for key, values in stats.items()}
    means[wl_score] = {key: subset.mean() for key, subset in raw_data[wl_score].items()}
    stds[wl_score] = {key: subset.std() for key, subset in raw_data[wl_score].items()}

In [28]:
# Summary table: one row per WL score (scores with no samples show "nan"),
# one column per statistic, each cell formatted as "mean ± std".
# The original zipped in the stds but never displayed them. It also reused
# `key` for both the score and the statistic name.
console = Console()

table = Table()
table.add_column("Score")
for stat_name in stats:
    table.add_column(stat_name)

for wl_score in means:
    table.add_row(
        str(wl_score),
        *(
            f"{means[wl_score][stat_name]:.4f} ± {stds[wl_score][stat_name]:.4f}"
            for stat_name in stats
        ),
    )

console.print(table)


In [29]:
# Mean ± std of the gnn_output statistic versus WL score (scores with no
# samples are NaN).
x = list(range(-1, max_wl_score + 1))
gnn_output_mean = np.array([means[score]["gnn_output"] for score in x])
gnn_output_std = np.array([stds[score]["gnn_output"] for score in x])

fig = go.Figure(
    data=go.Scatter(
        x=x,
        y=gnn_output_mean,
        error_y=dict(type="data", array=gnn_output_std),
    )
)
fig.update_layout(title="gnn_output_mean", xaxis_title="wl_score", yaxis_title="Mean")
fig.show()

In [30]:
# Mean ± std of the gnn_output_rounded statistic versus WL score.
x = list(range(-1, max_wl_score + 1))
gnn_output_rounded_mean = np.array([means[score]["gnn_output_rounded"] for score in x])
gnn_output_rounded_std = np.array([stds[score]["gnn_output_rounded"] for score in x])

fig = go.Figure(
    data=go.Scatter(
        x=x,
        y=gnn_output_rounded_mean,
        error_y=dict(type="data", array=gnn_output_rounded_std),
    )
)
fig.update_layout(
    title="gnn_output_rounded_mean", xaxis_title="wl_score", yaxis_title="Mean"
)
fig.show()

In [31]:
# Mean ± std of the node_level_repr_pre_encoder statistic versus WL score.
x = list(range(-1, max_wl_score + 1))
node_level_repr_pre_encoder_mean = np.array(
    [means[score]["node_level_repr_pre_encoder"] for score in x]
)
node_level_repr_pre_encoder_std = np.array(
    [stds[score]["node_level_repr_pre_encoder"] for score in x]
)

fig = go.Figure(
    data=go.Scatter(
        x=x,
        y=node_level_repr_pre_encoder_mean,
        error_y=dict(type="data", array=node_level_repr_pre_encoder_std),
    )
)
fig.update_layout(
    title="node_level_repr_pre_encoder_mean", xaxis_title="wl_score", yaxis_title="Mean"
)
fig.show()

In [32]:
# Check which WL scores have an entry (the aggregation fills every integer
# score from -1 to max_wl_score).
raw_data.keys()

dict_keys([-1, 0, 1, 2, 3, 4, 5])

In [33]:
# Distribution of the per-sample gnn_output statistic for one WL score.
# The score is named so the cell is easy to re-target. Passing the data as
# `x` (rather than as a bare array) avoids px's wide-form "variable" legend.
# A title and axis label make the figure readable on its own.
histogram_wl_score = 3
px.histogram(
    x=raw_data[histogram_wl_score]["gnn_output"],
    title=f"gnn_output for wl_score = {histogram_wl_score}",
    labels={"x": "gnn_output"},
)