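# Logging example episodes from runs

Select a few random completed episodes from a collected rollout sample, upload them to Weights & Biases as a run artifact, and load them back for visualisation.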

## Setup

In [1]:
# Device and reproducibility
FORCE_CPU = True  # use the CPU even when CUDA is available
SEED = 349_287

# Input data and rollout selection
SAMPLE_SUB_DATA = "data/sample_collect.pkl"  # path passed to `torch.load`
NUM_ROLLOUTS_TO_SELECT = 10  # number of completed episodes to log

# Weights & Biases destination
WANDB_PROJECT = "pvg-sandbox"
WANDB_RUN_NAME = "test_7"
ITERATION = 0  # names the file added to the artifact and the iteration loaded back

In [2]:
from tempfile import TemporaryDirectory
import os
import pickle

import numpy as np

import torch

from tensordict import TensorDict, TensorDictBase

import wandb

from tqdm import tqdm

from pvg.constants import (
    WANDB_ENTITY,
    ROLLOUT_SAMPLE_ARTIFACT_PREFIX,
    ROLLOUT_SAMPLE_ARTIFACT_TYPE,
    ROLLOUT_SAMPLE_FILENAME,
)
from pvg.graph_isomorphism import GraphIsomorphismRolloutSamples

In [3]:
# Seed torch's global RNG (this is what `torch.rand_like` draws from when
# rollouts are selected below) and create a separately seeded generator.
# NOTE(review): `torch_generator` is not used in the visible cells.
torch.manual_seed(SEED)
torch_generator = torch.Generator()
torch_generator.manual_seed(SEED)

In [4]:
# Use CUDA only when it is available and the config does not force the CPU.
device = torch.device(
    "cuda" if not FORCE_CPU and torch.cuda.is_available() else "cpu"
)
print(device)

cpu


## Load sample sub_data

In [5]:
# The file holds a pickled TensorDict. Since PyTorch 2.6, `torch.load`
# defaults to `weights_only=True`, which refuses to unpickle such objects, so
# full unpickling is requested explicitly. Full unpickling can execute
# arbitrary code: only load this file if it comes from a trusted source.
sample_sub_data: TensorDict = torch.load(
    SAMPLE_SUB_DATA, map_location=device, weights_only=False
)
# `map_location` relocates the stored tensors; the explicit `.to(device)` is
# kept so the TensorDict as a whole is placed on `device`.
sample_sub_data = sample_sub_data.to(device)
sample_sub_data

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([125, 8, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        agents: TensorDict(
            fields={
                decision: Tensor(shape=torch.Size([125, 8, 2]), device=cpu, dtype=torch.int64, is_shared=False),
                decision_logits: Tensor(shape=torch.Size([125, 8, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                graph_level_repr: Tensor(shape=torch.Size([125, 8, 2, 2, 16]), device=cpu, dtype=torch.float32, is_shared=False),
                node_level_repr: Tensor(shape=torch.Size([125, 8, 2, 2, 11, 16]), device=cpu, dtype=torch.float32, is_shared=False),
                node_selected: Tensor(shape=torch.Size([125, 8, 2]), device=cpu, dtype=torch.int64, is_shared=False),
                node_selected_logits: Tensor(shape=torch.Size([125, 8, 2, 22]), device=cpu, dtype=torch.float32, is_shared=False),
                sample_log_prob: Tensor(shape=torch.Size([125,

## Select some rollouts

In [6]:
def to_numpy_dict(data: TensorDictBase | dict) -> dict:
    """Recursively convert a TensorDict (or nested dict) into a dict of numpy arrays.

    Tensors are detached, moved to the CPU and converted with ``.numpy()``.
    Nested TensorDicts and dicts are converted recursively. Any other leaf
    value is kept unchanged.

    Args:
        data: A ``TensorDictBase`` or a (possibly nested) ``dict``.

    Returns:
        A new dict with the same nested key structure. The input is not
        modified.
    """
    if isinstance(data, TensorDictBase):
        data = data.to_dict()
    converted: dict = {}
    for key, value in data.items():
        if isinstance(value, torch.Tensor):
            converted[key] = value.detach().cpu().numpy()
        elif isinstance(value, (TensorDictBase, dict)):
            converted[key] = to_numpy_dict(value)
        else:
            # Non-tensor leaves (e.g. strings, numbers, None) have no `.items()`,
            # so recursing into them would raise; keep them as-is.
            converted[key] = value
    return converted

In [7]:
# Pick up to NUM_ROLLOUTS_TO_SELECT completed episodes uniformly at random.
# Every position where an episode ends gets a random bid in [1, 2) and all
# other positions bid 0, so the top-k bids are always episode ends.
done = sample_sub_data["next", "done"]
bids = torch.where(
    done,
    torch.rand_like(done, dtype=torch.float32) + 1,
    0.0,
)
# With fewer finished episodes than requested, an uncapped top-k would select
# zero-bid (non-terminal) positions, or fail when k exceeds the element count.
num_selected = min(NUM_ROLLOUTS_TO_SELECT, int(done.sum()))
_, index_flat = torch.topk(bids.flatten(), num_selected)
# NOTE(review): assumes `done` has exactly two dims (batch, timestep) — confirm;
# a trailing singleton dim would break this two-way unpacking.
batch_ids, episode_ids = np.unravel_index(index_flat.cpu().numpy(), bids.shape)

rollout_samples: list[dict] = []

for batch_id, episode_id in zip(batch_ids.flat, episode_ids.flat):
    # Walk backwards from the episode end to just after the previous episode's
    # end, or to index 0.
    # NOTE(review): if the collected data can begin mid-episode, the first
    # episode in a row is truncated at index 0 — confirm this is acceptable.
    episode_start_id = episode_id - 1
    while episode_start_id >= 0 and not done[batch_id, episode_start_id]:
        episode_start_id -= 1
    episode_start_id += 1
    rollout_td = sample_sub_data[batch_id, episode_start_id : episode_id + 1]
    rollout_samples.append(to_numpy_dict(rollout_td))

## Save to Weights & Biases

In [8]:
# Start the W&B run that the rollout artifact is attached to.
wandb_run = wandb.init(
    entity=WANDB_ENTITY,
    project=WANDB_PROJECT,
    name=WANDB_RUN_NAME,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msamadamday[0m ([33mlrhammond-team[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# Create the artifact that holds the sampled rollouts for this run.
artifact = wandb.Artifact(
    type=ROLLOUT_SAMPLE_ARTIFACT_TYPE,
    name=f"{ROLLOUT_SAMPLE_ARTIFACT_PREFIX}{WANDB_RUN_NAME}",
)
# NOTE(review): this hands the still-empty artifact to the run before any file
# is added — confirm that this initial empty version is intended.
wandb_run.use_artifact(artifact)

# Start a new draft and attach the pickled rollouts under `iteration_{ITERATION}`.
artifact = artifact.new_draft()
with TemporaryDirectory() as tmp_dir:
    rollout_path = os.path.join(tmp_dir, ROLLOUT_SAMPLE_FILENAME)
    with open(rollout_path, "wb") as rollout_file:
        pickle.dump(rollout_samples, rollout_file)
    # NOTE(review): the temp file is deleted when this `with` block exits; this
    # relies on `add_file` having staged its own copy — confirm.
    artifact.add_file(rollout_path, f"iteration_{ITERATION}")
wandb_run.use_artifact(artifact)

<Artifact QXJ0aWZhY3Q6NjkxMzY0NzU2>

In [10]:
# Add the same rollouts to a new draft of the artifact under the next
# iteration's name.
artifact = artifact.new_draft()
next_iteration_name = f"iteration_{ITERATION + 1}"
with TemporaryDirectory() as tmp_dir:
    rollout_path = os.path.join(tmp_dir, ROLLOUT_SAMPLE_FILENAME)
    with open(rollout_path, "wb") as rollout_file:
        pickle.dump(rollout_samples, rollout_file)
    artifact.add_file(rollout_path, next_iteration_name)
wandb_run.use_artifact(artifact)

<Artifact QXJ0aWZhY3Q6NjkxMzY0Nzky>

In [11]:
# Close the W&B run; remaining data (including the artifact files) is uploaded before it ends.
wandb_run.finish()



VBox(children=(Label(value='0.007 MB of 0.007 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

## Load rollouts

In [17]:
# Download the rollout samples logged for this run and iteration.
# NOTE(review): this rebinds `rollout_samples`, which previously held the list
# of numpy dicts built in the selection cell.
rollout_samples = GraphIsomorphismRolloutSamples(
    wandb_project=WANDB_PROJECT,
    run_id=WANDB_RUN_NAME,
    iteration=ITERATION,
)

[34m[1mwandb[0m:   2 of 2 files downloaded.  


In [87]:
# Render the loaded rollout samples using the class's own `visualise` method.
rollout_samples.visualise()