In [None]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from tqdm import tqdm

sys.path.append(str(Path("..").resolve()))

import torch

from torchrl.envs import SerialEnv
from torchrl.data import TensorDictReplayBuffer, LazyTensorStorage

from rl.agent import (
    HRMQNetTrainingConfig,
)
from rl.dataset import MiniHackFullObservationSimpleEnvironmentDataset
from rl.dqn_train_loop import HRMAgentTrainingModule

In [None]:
# Keyboard key -> discrete action index used to steer the agent '@'.
# The indices line up with the direction names in `action_one_hot_to_string`.
input_action_map = dict(
    # cardinal moves
    w=0,  # north
    a=3,  # west
    s=2,  # south
    d=1,  # east
    # diagonal moves
    k=4,  # north east
    m=5,  # south east
    n=6,  # south west
    h=7,  # north west
)


# utilities to visualise simple minihack room environment
def tensor_to_string(chars_tensor):
    """Render a 2-D tensor of character codes as a list of text lines.

    The tensor is moved to the CPU, converted to a NumPy array and transposed;
    each row of the transposed array becomes one string by mapping every value
    through ``chr``. Rows that contain only whitespace are left out.

    Args:
        chars_tensor: 2-D tensor whose values are valid arguments to ``chr``.

    Returns:
        list[str]: the non-blank rows, in row order of the transposed array.
    """
    grid = chars_tensor.cpu().numpy().transpose()
    # Lazily decode one row at a time, then keep only rows with visible content.
    decoded_rows = ("".join(map(chr, grid[r, :])) for r in range(grid.shape[0]))
    return [line for line in decoded_rows if line.strip() != ""]


# for navigation only
def action_one_hot_to_string(action_one_hot_tensor):
    """Name the compass direction selected by a one-hot action tensor.

    The position of the largest entry (``argmax`` over the whole tensor) is
    looked up in a fixed table of the eight movement directions.

    Args:
        action_one_hot_tensor: tensor whose arg-max position is the action index.

    Returns:
        str: one of "north", "east", "south", "west", "north east",
        "south east", "south west", "north west".

    Raises:
        KeyError: if the arg-max position is not in the range 0-7.
    """
    # Position in this tuple is the action index.
    direction_names = (
        "north",
        "east",
        "south",
        "west",
        "north east",
        "south east",
        "south west",
        "north west",
    )
    index_to_name = dict(enumerate(direction_names))

    hot_index = int(action_one_hot_tensor.cpu().numpy().argmax())
    return index_to_name[hot_index]

In [None]:
# get config
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf, SCMode

# Compose the DQN config with Hydra from the project's config directory
# (resolved to an absolute path relative to the notebook's working directory).
config_dir = Path("../rl/config").resolve()
with initialize_config_dir(
    config_dir=str(config_dir),
    job_name="test_cfg",
    version_base=None,
):
    cfg = compose(config_name="cfg_dqn.yaml")

# Overlay the composed values on the structured schema, then instantiate the
# structured config so the result can be used through attribute access.
schema = OmegaConf.structured(HRMQNetTrainingConfig)
merged_cfg = OmegaConf.merge(schema, cfg)
typed_cfg: HRMQNetTrainingConfig = OmegaConf.to_container(
    merged_cfg,
    structured_config_mode=SCMode.INSTANTIATE,
)

# Notebook-only overrides: a small environment and small batches keep the run fast.
dataset_cfg = typed_cfg.dataset
dataset_cfg.env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dataset_cfg.seq_len = 13 * 11
dataset_cfg.data_collection_batch_size = 32
dataset_cfg.frames_per_update = 320

# For testing only: turn on hidden-state seeding and keep the dependent
# dataset settings in step with that flag.
typed_cfg.use_last_hidden_state_to_seed_next_environment_step = True
seed_from_last_hidden_state = (
    typed_cfg.use_last_hidden_state_to_seed_next_environment_step
)
dataset_cfg.do_not_skip_running_model_if_random_action = seed_from_last_hidden_state
if seed_from_last_hidden_state:
    # Training batches are made the same size as data-collection batches.
    dataset_cfg.training_batch_size = dataset_cfg.data_collection_batch_size

In [None]:
# Dataset built from the dataset section of the typed config.
dataset = MiniHackFullObservationSimpleEnvironmentDataset(config=typed_cfg.dataset)

# Build the training module with CUDA as the default device, to make sure the
# buffers used in HRM are initialised on CUDA for backprop.
cuda_device = torch.device("cuda")
with cuda_device:
    hrm_agent_training_module = HRMAgentTrainingModule(typed_cfg, dataset)

# Hand the module's actor and epsilon-greedy module over to the dataset.
dataset.initialise_policy_and_collector(
    hrm_agent_training_module.actor,
    hrm_agent_training_module.egreedy_module,
)

# Test playing with 1 env

In [None]:
from IPython.display import clear_output
from copy import deepcopy

# Interactive play uses its own config copy and dataset, so the training
# `dataset` created earlier is not involved in this cell.
new_cfg = deepcopy(typed_cfg)
new_cfg.dataset.data_collection_batch_size = 1

# Shape the flat "inputs" observation is reshaped to before rendering.
dg_shape = (11, 13)

dataset_play = MiniHackFullObservationSimpleEnvironmentDataset(config=new_cfg.dataset)

# Size everything from the play dataset/config, so the action tensor always
# matches the environments it is sent to.
# NOTE(review): the environment count comes from `frames_per_update`, not from
# `data_collection_batch_size` (set to 1 above) — confirm this is intended for
# a "1 env" test. Only the first environment is rendered below.
num_envs = dataset_play.config.frames_per_update
num_actions = new_cfg.dataset.env_kwargs["action-space"]

envs = SerialEnv(num_envs, dataset_play.create_env)
inner_current_state = envs.reset()

while True:
    clear_output(wait=True)

    # print last reward if any
    if "next" in inner_current_state:
        print(
            "Last action:", action_one_hot_to_string(inner_current_state[0]["action"])
        )
        print("Last action reward:", inner_current_state[0]["reward"].item())

    # visualise the current environment
    t = tensor_to_string(inner_current_state[0]["inputs"].reshape(dg_shape).T)
    print("\n".join(t))

    # Keyboard control (8 way, keys defined in `input_action_map`);
    # any other input ends the session.
    x = input()
    if x not in input_action_map:
        break
    a = input_action_map[x]

    # Send the same one-hot action to every environment: start from the current
    # state and attach an all-zero integer action tensor with column `a` set.
    policy_decision = inner_current_state.clone()
    policy_decision["action"] = torch.zeros((num_envs, num_actions), dtype=torch.long)
    policy_decision["action"][:, a] = 1

    # step the environment
    transitions, inner_current_state = envs.step_and_maybe_reset(policy_decision)

# Random actions from initialised policy

In [None]:
# Take the first item produced by iterating over the dataset.
# NOTE(review): presumably this runs data collection with the freshly
# initialised policy and samples a batch from the replay buffer — confirm.
sampled_from_buffer = next(iter(dataset))

# Timing note — parallel: 67.6s, of which ~25s is spinning up; serial: 159.8s (for 128,000 frames on 32 envs)

In [None]:
# Display the dataset's `buffer` attribute (last expression of the cell is shown).
dataset.buffer

In [None]:
# Render one sampled transition: observation before, the action taken, and the
# observation after.
dg_shape = (11, 13)
idx = 24
sampled_transition = sampled_from_buffer[idx]

t = tensor_to_string(sampled_transition["inputs"].reshape(dg_shape).T)
print("Before:\n")
print(*t, sep="\n")
print("\n")
print("Action:", action_one_hot_to_string(sampled_transition["action"]))
print("\n")
print("After:\n")
t2 = tensor_to_string(sampled_transition["next"]["inputs"].reshape(dg_shape).T)
print(*t2, sep="\n")

# South action

Test the internal code to check that resetting works predictably.

In [None]:
# Empty the dataset's own buffer, then collect into a separate replay buffer
# that is configured from the dataset's settings.
dataset.buffer.empty()
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(dataset.config.buffer_capacity),
    batch_size=dataset.config.data_collection_batch_size,
    prefetch=dataset.config.num_workers,
    transform=lambda td: td.to(dataset.config.storing_device),
)
buffer.empty()

num_envs = dataset.config.frames_per_update
envs = SerialEnv(num_envs, dataset.create_env)
inner_current_state = envs.reset()

# Always move south, for 55 steps, storing every transition.
south_idx = input_action_map["s"]
for _ in tqdm(range(55)):
    # one-hot "south" for every environment (integer zeros, one column set to 1)
    south_one_hot = torch.zeros(
        (num_envs, dataset.config.action_space_size), dtype=torch.int64
    )
    south_one_hot[:, south_idx] = 1

    # attach the action to a copy of the current state
    policy_decision = inner_current_state.clone()
    policy_decision["action"] = south_one_hot

    # step the environment
    transitions, inner_current_state = envs.step_and_maybe_reset(policy_decision)

    # store the resulting transitions in the buffer
    buffer.extend(transitions)

In [None]:
# Pick the transition of environment `batch_idx` at step `time_idx`.
# NOTE(review): the flat index assumes the buffer holds `frames_per_update`
# transitions per step, stored in step order — confirm against how it was filled.
batch_idx = 10
time_idx = 3
buffer_idx = batch_idx + time_idx * dataset.config.frames_per_update

dg_shape = (11, 13)
stored_transition = buffer[buffer_idx]
after_step = stored_transition["next"]

t = tensor_to_string(stored_transition["inputs"].reshape(dg_shape).T)
print("Before:\n")
print(*t, sep="\n")
print("\n")
print("Reward:", after_step["reward"].item())
print("Action:", action_one_hot_to_string(stored_transition["action"]))
print("\n")
print("After:\n")
t2 = tensor_to_string(after_step["inputs"].reshape(dg_shape).T)
print(*t2, sep="\n")

# Test validation loop

In [None]:
# Run the dataset's validation rollout and keep the result for inspection.
validation_trajectories = dataset.validation_rollout()

In [None]:
# Display the validation rollout result (last expression of the cell is shown).
validation_trajectories