# Visualising rollouts

## Setup

In [95]:
# Settings that select which saved rollouts to inspect. All three are passed
# to GraphIsomorphismRolloutSamples in the visualisation cell below.

# Passed as `wandb_project` (presumably the Weights & Biases project name).
WANDB_PROJECT = "pvg-sandbox"
# Passed as `run_id`.
WANDB_RUN_NAME = "ppo_gi_april_shared_min2_2_3"
# Passed as `iteration`; presumably the training iteration whose rollouts
# were saved — TODO confirm.
ITERATION = 999

In [96]:
from pvg.graph_isomorphism import GraphIsomorphismRolloutSamples
from pvg.utils.data import nested_dict_keys_stringified

## Visualisation

In [97]:
# Open the rollout samples for the configured run and iteration, call
# `visualise()` on them, and keep a handle on the underlying samples so the
# cells below can inspect them directly.
with GraphIsomorphismRolloutSamples(
    run_id=WANDB_RUN_NAME, iteration=ITERATION, wandb_project=WANDB_PROJECT
) as rollout_samples:
    rollout_samples.visualise()
    # NOTE(review): `_rollout_samples` is a private attribute and `samples` is
    # used after the context manager has exited — confirm the data stays valid
    # once the `with` block closes.
    samples = rollout_samples._rollout_samples



In [98]:
# List the keys available in the sample at index 1. As the output below shows,
# nested keys are reported as dot-joined strings (e.g. 'agents.decision').
nested_dict_keys_stringified(samples[1])

['y',
 'agents.decision',
 'agents.node_selected',
 'agents.node_level_repr',
 'agents.graph_level_repr',
 'agents.decision_logits',
 'agents.sample_log_prob',
 'agents.node_selected_logits',
 'agents.value',
 'round',
 'decision_restriction',
 'adjacency',
 'node_mask',
 'x',
 'message',
 'done',
 'terminated',
 'next.round',
 'next.decision_restriction',
 'next.adjacency',
 'next.node_mask',
 'next.x',
 'next.message',
 'next.agents.reward',
 'next.agents.done',
 'next.agents.terminated',
 'next.agents.value',
 'next.done',
 'next.terminated',
 'collector.traj_ids',
 'advantage',
 'value_target']

In [99]:
# "y" for the sample at index 1, at the last index of the first axis and
# index 0 of the second. Presumably the ground-truth label at the final
# step — TODO confirm axis meaning.
samples[1]["y"][-1, 0]

1

In [100]:
# "agents" -> "decision" for the sample at index 1, read at the same
# [-1, 0] position as "y" above. Presumably the decision made at the final
# step — TODO confirm axis meaning.
samples[1]["agents"]["decision"][-1, 0]

1

In [101]:
# "next" -> "done" flags for the sample at index 1. In the output below only
# the final entry is True.
samples[1]["next"]["done"]

array([False, False, False, False,  True])

In [102]:
# For every rollout sample, compare the "y" value with the agents' decision
# at position [-1, 0] and add up the rewards in column 0. Print one row per
# sample, then the averages over all samples.
accuracy = 0
mean_reward = 0
for sample in samples:
    label = sample["y"][-1, 0]
    decision = sample["agents"]["decision"][-1, 0]
    is_correct = label == decision
    total_reward = sample["next"]["agents"]["reward"][:, 0].sum()

    # Row layout: label, decision, whether they agree, summed reward.
    print(label, decision, is_correct, total_reward)

    # Accumulate in the same order as the rows are printed so the running
    # totals keep the scalar types produced by the data.
    accuracy += is_correct
    mean_reward += total_reward

# Convert the running totals into per-sample averages.
num_samples = len(samples)
accuracy /= num_samples
mean_reward /= num_samples
print(f"Accuracy: {accuracy}, Mean Reward: {mean_reward}")

1 1 True 1.0
1 1 True 1.1
0 1 False -1.0
0 1 False -0.9
0 1 False -0.9
0 1 False -0.9
1 1 True 1.15
1 1 True 1.15
1 1 True 1.15
1 1 True 1.0
Accuracy: 0.6, Mean Reward: 0.28500000238418577


In [103]:
# "x" for the sample at index 0, at the last index of its first axis.
# Presumably the feature tensor at the final step — TODO confirm.
samples[0]["x"][-1]

array([[[0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=float32)