In [1]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from tqdm import tqdm

sys.path.append(str(Path("..").resolve()))

import torch

from rl.agent import (
    HRMQNetTrainingConfig,
)
from rl.dataset import (
    MiniHackFullObservationSimpleEnvironmentDataset,
)
from rl.dqn_train_loop import HRMAgentTrainingModule

  import pkg_resources


In [2]:
# get config
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf, SCMode

with initialize_config_dir(
    version_base=None,
    config_dir=str(Path("../rl/config").resolve()),
    job_name="test_cfg",
):
    cfg = compose(config_name="cfg_dqn.yaml")

typed_cfg: HRMQNetTrainingConfig = OmegaConf.to_container(
    OmegaConf.merge(OmegaConf.structured(HRMQNetTrainingConfig), cfg),
    structured_config_mode=SCMode.INSTANTIATE,
)

# for speed, we will reduce batch size, number of frames, etc.
typed_cfg.dataset.env_kwargs["observation_keys"] = ["chars"]
typed_cfg.dataset.env_name = "MiniHack-4-Rooms"
typed_cfg.resume_from_run = None
typed_cfg.dataset.seq_len = 121
typed_cfg.dataset.data_collection_batch_size = 32
typed_cfg.dataset.frames_per_update = 320
typed_cfg.log_wandb = False
typed_cfg.dataset.buffer_capacity = 6000  # make sure the buffer is not exhausted otherwise it will be hard to check continuity of prev_seed_h_init and seed_h_init

typed_cfg.max_training_steps = 50

# set 4 or 8 way
typed_cfg.dataset.action_space_size = 4
typed_cfg.dataset.vocab_size = 131  # if 8, else 131
typed_cfg.dataset.env_kwargs["action-space"] = 4

# undo set hidden state as reseeding
typed_cfg.use_last_hidden_state_to_seed_next_environment_step = True
typed_cfg.dataset.do_not_skip_running_model_if_random_action = (
    typed_cfg.use_last_hidden_state_to_seed_next_environment_step
)

# if reseeding, data collection batch size must match
typed_cfg.dataset.training_batch_size = typed_cfg.dataset.data_collection_batch_size

In [3]:
# load config and initialise QValue net as well as iterable dataset
# requires a base data
dataset = MiniHackFullObservationSimpleEnvironmentDataset(config=typed_cfg.dataset)

with torch.device(
    "cuda"
):  # make sure that the buffers used in HRM are initialised on CUDA for backprop
    hrm_agent_training_module = HRMAgentTrainingModule(typed_cfg, dataset)

original_h_init = (hrm_agent_training_module.qvalue_net.model.inner.H_init).clone()
original_l_init = (hrm_agent_training_module.qvalue_net.model.inner.L_init).clone()

dataset.initialise_policy_and_collector(
    hrm_agent_training_module.actor,
    hrm_agent_training_module.egreedy_module,  # pyright: ignore[reportArgumentType]
)

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.enabled = True
torch.set_float32_matmul_precision("medium")

hrm_agent_training_module.pre_training_setup(
    checkpoint_dir=None,  # this is for local, for s3 run, change to "s3"
    run_name=None,  # this is for local, for s3 run with wandb, leave as None
)
for current_step, training_batch in tqdm(
    enumerate(dataset, start=hrm_agent_training_module.state.global_step),
    total=typed_cfg.max_training_steps,
):
    training_batch = training_batch.to(torch.device("cuda"))
    hrm_agent_training_module.training_step(training_batch)
    hrm_agent_training_module.post_training_step_callbacks()

    if current_step + 1 == typed_cfg.max_training_steps:
        break

with torch.no_grad():
    print(
        "MSE between trained H_init and original:",
        (
            (original_h_init - hrm_agent_training_module.qvalue_net.model.inner.H_init)
            ** 2
        ).sum()
        / typed_cfg.arch_exclude_data_dependent.hidden_size,
    )
    print(
        "MSE between trained L_init and original:",
        (
            (original_l_init - hrm_agent_training_module.qvalue_net.model.inner.L_init)
            ** 2
        ).sum()
        / typed_cfg.arch_exclude_data_dependent.hidden_size,
    )

Run has no name, checkpointing will not be performed
 98%|█████████▊| 49/50 [00:22<00:00,  2.18it/s]

MSE between trained H_init and original: tensor(0., device='cuda:0', dtype=torch.bfloat16)
MSE between trained L_init and original: tensor(0., device='cuda:0', dtype=torch.bfloat16)





In [4]:
# some utils
def true_ranges(lst):
    ranges = []
    start = None
    for i, val in enumerate(lst):
        if val:
            if start is not None:
                ranges.append((start, i - 1))
            start = i
    if start is not None and start < len(lst):
        ranges.append((start, len(lst) - 1))
    return ranges


# Example:
lst = [True, False, False, True, False, False]
print(true_ranges(lst))  # [(0, 2), (3, 5)]
lst = [True, False, False, False]
print(true_ranges(lst))

[(0, 2), (3, 5)]
[(0, 3)]


In [5]:
for episode_idx in range(32):
    episode_buffer_idx = torch.arange(episode_idx, len(dataset.buffer), 32)
    init_flags = dataset.buffer[episode_buffer_idx]["is_init"]
    individual_ranges = true_ranges(init_flags[:, 0])

    for range_start, range_end in individual_ranges:
        range_buffer_idx = episode_buffer_idx[range_start : range_end + 1]
        prev_seed_h_init = dataset.buffer[range_buffer_idx]["prev_seed_h_init"]
        prev_seed_l_init = dataset.buffer[range_buffer_idx]["prev_seed_l_init"]
        seed_h_init = dataset.buffer[range_buffer_idx]["seed_h_init"]
        seed_l_init = dataset.buffer[range_buffer_idx]["seed_l_init"]

        # check that previous = seed of next, and that the first previous is H_init/L_init
        assert (
            prev_seed_h_init[0, :].cpu()
            == hrm_agent_training_module.qvalue_net.model.inner.H_init.cpu()
        ).all()
        assert (
            prev_seed_l_init[0, :].cpu()
            == hrm_agent_training_module.qvalue_net.model.inner.L_init.cpu()
        ).all()
        assert (prev_seed_h_init[1:, :].cpu() == seed_h_init[:-1, :].cpu()).all()
        assert (prev_seed_l_init[1:, :].cpu() == seed_l_init[:-1, :].cpu()).all()