---
title: "HRM Agent Plots write-up"
format: 
  html:
    code-fold: true
jupyter: python3
---

In [None]:
# | include: False

%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

sys.path.append(str(Path("..").resolve()))
import pickle
import torch
import matplotlib.pyplot as plt
from torch import nn

from copy import deepcopy
from dotenv import load_dotenv
from tensordict import TensorDict
from torchrl.modules import QValueActor, EGreedyModule

from rl.agent import HRMQNetTrainingConfig, HRMQValueNet
from rl.dataset import MiniHackFullObservationSimpleEnvironmentDataset, GymRLDataset
from rl.dqn_train_loop import HRMAgentTrainingModule

load_dotenv()

We first define some useful utilities:

In [None]:
def set_cfg_env(cfg: HRMQNetTrainingConfig, env_name: str | None = None):
    """Return a deep copy of ``cfg``, optionally retargeted at ``env_name``.

    The config passed in is never mutated. When ``env_name`` is given it must
    be one of the two supported environments; the copy's ``dataset.env_name``
    and ``dataset.seq_len`` are then set for that environment.

    Raises:
        AssertionError: if ``env_name`` is given but is not supported.
    """
    maze_env = "MiniHack-Corridor-Maze-4-Way-Dynamic"
    rooms_env = "MiniHack-4-Rooms"

    retargeted = deepcopy(cfg)
    if env_name is None:
        return retargeted

    assert env_name in (maze_env, rooms_env)

    retargeted.dataset.env_name = env_name
    # presumably the flattened grid sizes (143 == 11 * 13, 121 == 11 * 11) — confirm
    retargeted.dataset.seq_len = 143 if env_name == maze_env else 121
    return retargeted

We then define a template configuration with the following settings:

- Probability of door change: 0.05
- Number of environments used for evaluating model: 1024

And other relevant settings:

In [None]:
# get config
# Compose the DQN config from ../rl/config/cfg_dqn.yaml and turn it into a typed
# HRMQNetTrainingConfig that the rest of the notebook mutates and copies.
from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf, SCMode

with initialize_config_dir(
    version_base=None,
    config_dir=str(Path("../rl/config").resolve()),
    job_name="test_cfg",
):
    cfg = compose(config_name="cfg_dqn.yaml")

# Merge the composed values onto the structured schema, then instantiate the
# structured config objects (SCMode.INSTANTIATE) instead of returning plain dicts.
typed_cfg: HRMQNetTrainingConfig = OmegaConf.to_container(
    OmegaConf.merge(OmegaConf.structured(HRMQNetTrainingConfig), cfg),
    structured_config_mode=SCMode.INSTANTIATE,
)

# Evaluation-only overrides: observe only "chars", start from no previous run,
# use 1024 for both the collection batch size and frames per update, no wandb.
typed_cfg.dataset.env_kwargs["observation_keys"] = ["chars"]
typed_cfg.resume_from_run = None
typed_cfg.dataset.data_collection_batch_size = 1024
typed_cfg.dataset.frames_per_update = 1024
typed_cfg.log_wandb = False

# set 4 or 8 way
typed_cfg.dataset.action_space_size = 4
typed_cfg.dataset.vocab_size = 131
typed_cfg.dataset.env_kwargs["action-space"] = 4
typed_cfg.dataset.env_kwargs["p-change-doors"] = 0.05  # probability of door change

# NOTE(review): `dataset` is not referenced again in the visible cells — the
# rollout loop builds its own dataset per run; confirm it is still needed.
dataset = MiniHackFullObservationSimpleEnvironmentDataset(config=typed_cfg.dataset)

# Grid shape used to un-flatten observations; later plotting cells reassign it.
dg_shape = (11, 11) if "4-Room" in typed_cfg.dataset.env_name else (11, 13)

# Perform validation loop while reading off the last hidden state over 8 steps

We first build a class that allows us to attach hooks to the latent values.

In [None]:
GOAL_CHAR = ">"


class LoggingHRMQValueNet(HRMQValueNet):
    """HRMQValueNet that also records the latents after every recurrent step.

    An optional ``intervention`` attribute (set on the instance from outside)
    selects a perturbation applied inside ``q_values_on_new_carry``:

    - ``"shift_target_randomly"``: move the goal glyph of each observation to
      a randomly chosen empty (".") cell before running the model.
    - ``"reset_converged_latents_to_init"``: overwrite the carried ``z_H`` and
      ``z_L`` with the model's initial latents after every recurrent step.
    """

    def q_values_on_new_carry(self, batch_data):
        """
        Compute Q-values from a fresh carry and log the latents along the way.

        This runs from an initial carry until the end of the sequence, which is useful for prediction. This has no gradient flow by default as backpropagation through time is untested in the HRM model.

        Returns:
            ``(carry, q_values, running_z_H, running_z_L)``. Each ``running_*``
            tensor stacks, along dim 1, the model's initial latent followed by
            the ``[:, -1, :]`` slice of the carried latent after each
            recurrent step.

        Raises:
            ValueError: if no Q-values were produced.
        """
        # Looked up once: the attribute only exists when an intervention was requested.
        intervention = getattr(self, "intervention", None)

        if intervention == "shift_target_randomly":
            batch_data = batch_data.clone()  # never edit the caller's observations
            inputs = batch_data["inputs"]  # assumes (bs, seq_len) character codes — TODO confirm
            empty_locs = inputs == ord(".")
            has_empty_locs = empty_locs.any(dim=1)

            # Pick one empty cell per row: random scores, non-empty cells excluded.
            rand_scores = torch.rand_like(inputs, dtype=torch.float)
            rand_scores[~empty_locs] = float("-inf")
            batch_indices = torch.arange(
                batch_data.batch_size[0], device=inputs.device
            )[has_empty_locs]
            selected_indices = rand_scores.argmax(dim=1)[has_empty_locs]

            # Erase the old goal only in rows that receive a new one, so a row
            # without any empty cell keeps its goal instead of losing it.
            goal_to_erase = (inputs == ord(GOAL_CHAR)) & has_empty_locs.unsqueeze(1)
            batch_data["inputs"] = torch.where(goal_to_erase, ord("."), inputs)
            batch_data["inputs"][batch_indices, selected_indices] = ord(GOAL_CHAR)

        q_values = None
        current_carry = self._initial_carry(batch_data).to(self.model.device)
        is_first_carry = True  # first carry is all halted by design so the data is simply copied from batch data

        # Index 0 of the logs is always the model's H_init / L_init.
        # NOTE(review): this holds even when a seed latent is supplied for the
        # step — confirm that is what consumers comparing against index 0 expect.
        running_z_H = [
            self.model.inner.H_init.detach()
            .unsqueeze(0)
            .repeat(batch_data.batch_size[0], 1)
        ]
        running_z_L = [
            self.model.inner.L_init.detach()
            .unsqueeze(0)
            .repeat(batch_data.batch_size[0], 1)
        ]

        while not current_carry.halted.all() or is_first_carry:
            current_carry, q_values = self._q_values_with_carry(
                carry=current_carry, batch_data=batch_data
            )

            # store the z_H and z_L so it doesn't get overwritten
            # (logged before any reset below, so the log holds the computed values)
            running_z_H.append(current_carry.inner_carry.z_H[:, -1, :].detach())
            running_z_L.append(current_carry.inner_carry.z_L[:, -1, :].detach())

            is_first_carry = False

            if intervention == "reset_converged_latents_to_init":
                # `x * 0 + init` broadcasts the initial latent to the carry's shape.
                current_carry.inner_carry.z_H = (
                    current_carry.inner_carry.z_H * 0
                    + self.model.inner.H_init.unsqueeze(0).unsqueeze(0)
                )
                current_carry.inner_carry.z_L = (
                    current_carry.inner_carry.z_L * 0
                    + self.model.inner.L_init.unsqueeze(0).unsqueeze(0)
                )

        if q_values is None:
            raise ValueError("No q-values computed as initial carry were all halted")

        running_z_H = torch.stack(running_z_H, dim=1)  # (bs, 1 + recurrent steps, hd)
        running_z_L = torch.stack(running_z_L, dim=1)  # (bs, 1 + recurrent steps, hd)
        return current_carry, q_values, running_z_H, running_z_L

    def forward(self, batch_data):
        """Forward starts from initial carry by default, used mostly during inference

        Returns a TensorDict with every entry of ``batch_data`` (except a stale
        ``"action_value"``) plus ``"action_value"``, ``"running_z_H"`` and
        ``"running_z_L"``; ``"seed_h_init"`` / ``"seed_l_init"`` are added when
        the config asks for the last hidden state to seed the next step.

        Raises:
            ValueError: if the wrapped model is in training mode.
        """
        if self.model.training:
            raise ValueError("This should not be trained!")

        carry, action_values, running_z_H, running_z_L = self.q_values_on_new_carry(
            batch_data
        )
        out_tensor = TensorDict(
            {
                **{k: v for k, v in batch_data.items() if k not in ("action_value",)},
                "action_value": action_values,
                "running_z_H": running_z_H,
                "running_z_L": running_z_L,
            },
            batch_data.batch_size,
            device=batch_data.device,
        )

        if self.config.use_last_hidden_state_to_seed_next_environment_step:
            out_tensor["seed_h_init"] = carry.inner_carry.z_H[:, -1, :].detach().clone()
            out_tensor["seed_l_init"] = carry.inner_carry.z_L[:, -1, :].detach().clone()

        return out_tensor


class LoggingHRMModule(HRMAgentTrainingModule):
    """HRMAgentTrainingModule whose Q-value network class is chosen by the caller.

    The constructor sets every attribute itself and does not call the parent
    ``__init__``, so the online network can be built from ``model_class``.
    """

    def __init__(
        self, config: HRMQNetTrainingConfig, dataset: GymRLDataset, model_class: object
    ):
        """Build the actor stack around ``model_class(config)``.

        Args:
            config: training configuration, stored and passed to the networks.
            dataset: provides ``base_env.action_spec`` for the actor.
            model_class: callable invoked as ``model_class(config)`` to create
                the Q-value network.
        """
        self.config = config
        self.qvalue_net = model_class(self.config)
        self.dataset = dataset
        self.actor = QValueActor(self.qvalue_net, spec=dataset.base_env.action_spec)
        self.mse = nn.MSELoss()
        self.egreedy_module = EGreedyModule(
            spec=self.actor.spec,
            annealing_num_steps=self.config.eps_decay_steps,
            eps_init=self.config.start_eps,
            eps_end=self.config.end_eps,
        )

        if self.config.use_target_network:
            # The target network is always a plain HRMQValueNet (not model_class),
            # frozen and kept in eval mode.
            self.target_network = HRMQValueNet(config)
            for param in self.target_network.parameters():
                param.requires_grad = False
            for buffer in self.target_network.buffers():
                buffer.requires_grad = False
            self.target_network.eval()

        # validate some config
        # Seeding the next step from the last hidden state requires the training
        # and data-collection batch sizes to match.
        if self.config.use_last_hidden_state_to_seed_next_environment_step:
            assert (
                self.config.dataset.training_batch_size
                == self.config.dataset.data_collection_batch_size
            )

We now collect the validation trajectories across different models and environments.

In [None]:
# Metadata on the models we evaluate, and where their rollouts are cached.
# NOTE(review): ``Path(".").parent`` is still ``Path(".")``, so this resolves to
# ``outputs/plots_trajectories_pickle`` under the current working directory —
# confirm that this is the intended location.
PICKLES_OUTPUT_DIR = Path(".").parent.parent / "outputs" / "plots_trajectories_pickle"
PICKLES_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# (run name, checkpoint file, environment name, environment label for the description)
_CHECKPOINTS = (
    (
        "psy7bpwq",
        "optim_step=193465_eval_metric=0.985.ckpt",
        "MiniHack-4-Rooms",
        "4 rooms",
    ),
    ("u2fivtrs", "last.ckpt", "MiniHack-4-Rooms", "4 rooms"),
    (
        "nq3r9hsw",
        "optim_step=164261_eval_metric=0.973.ckpt",
        "MiniHack-Corridor-Maze-4-Way-Dynamic",
        "maze",
    ),
    ("f79zuqjx", "last.ckpt", "MiniHack-Corridor-Maze-4-Way-Dynamic", "maze"),
)

# vanilla, then the resetting intervention, then the random shift intervention
_INTERVENTIONS = (None, "reset_converged_latents_to_init", "shift_target_randomly")

# Every (run, description, checkpoint, reseeding flag, environment, intervention)
# combination: interventions vary slowest, the reseeding flag fastest.
# Each run takes around 6 minutes, so roughly 2 hours for all of them.
ALL_RUNS = tuple(
    (
        run,
        ("Reseeding " if reseed else "Not reseeding ") + env_label,
        ckpt,
        reseed,
        env,
        intervention,
    )
    for intervention in _INTERVENTIONS
    for run, ckpt, env, env_label in _CHECKPOINTS
    for reseed in (True, False)
)

# Roll out every configured run once and cache the resulting trajectories as a
# pickle; runs whose pickle already exists are skipped.
# `desc` is a human-readable label only and is not used in the loop body.
for (
    run_name,
    desc,
    ckpt_name,
    set_hidden_state,
    env_name,
    intervention_desc,
) in ALL_RUNS:
    # NOTE: "intervetion" is misspelt, but the reader side matches on exactly
    # this spelling, so it must not be corrected on one side only.
    latents_pickle_file_name = (
        f"{run_name}_at_{ckpt_name}_reseeding={set_hidden_state}_env={env_name}_intervetion={intervention_desc if intervention_desc else 'None'}"
        + ".pkl"
    )
    if (PICKLES_OUTPUT_DIR / latents_pickle_file_name).exists():
        continue

    # Per-run copy of the template config, pointed at this run's environment.
    test_cfg = set_cfg_env(typed_cfg, env_name)

    test_cfg.use_last_hidden_state_to_seed_next_environment_step = set_hidden_state
    if test_cfg.use_last_hidden_state_to_seed_next_environment_step:
        # LoggingHRMModule asserts these two batch sizes match when reseeding.
        test_cfg.dataset.training_batch_size = (
            test_cfg.dataset.data_collection_batch_size
        )
    test_cfg.dataset.do_not_skip_running_model_if_random_action = (
        test_cfg.use_last_hidden_state_to_seed_next_environment_step
    )
    test_dataset = MiniHackFullObservationSimpleEnvironmentDataset(
        config=test_cfg.dataset
    )

    with torch.device(
        "cuda"
    ):  # make sure that the buffers used in HRM are initialised on CUDA for backprop
        hrm_agent_training_module = LoggingHRMModule(
            test_cfg, test_dataset, LoggingHRMQValueNet
        )

    # The logging network reads this attribute to decide which perturbation to apply.
    if intervention_desc is not None:
        hrm_agent_training_module.qvalue_net.intervention = intervention_desc

    # reconfigure the out keys of the modules so that things are kept
    if test_cfg.use_last_hidden_state_to_seed_next_environment_step:
        hrm_agent_training_module.qvalue_net.config.use_last_hidden_state_to_seed_next_environment_step = test_cfg.use_last_hidden_state_to_seed_next_environment_step
        hrm_agent_training_module.actor.out_keys += ["seed_h_init", "seed_l_init"]
        hrm_agent_training_module.qvalue_net.out_keys += ["seed_h_init", "seed_l_init"]

    # Always expose the per-recurrent-step latents recorded by the logging network.
    hrm_agent_training_module.qvalue_net.out_keys += ["running_z_H", "running_z_L"]
    hrm_agent_training_module.actor.out_keys += ["running_z_H", "running_z_L"]

    # Attach the actor to the dataset, then load the checkpoint weights for this
    # run without restoring its stored config (restore_config=False).
    test_dataset.initialise_policy_and_collector(hrm_agent_training_module.actor, None)
    hrm_agent_training_module.pre_training_setup(checkpoint_dir="s3", run_name=run_name)
    hrm_agent_training_module.load_from_checkpoint(
        run_name, ckpt_path_name=ckpt_name, restore_config=False
    )

    validation_trajectories = test_dataset.validation_rollout()

    # pickle the output for speedup
    with open(PICKLES_OUTPUT_DIR / latents_pickle_file_name, "wb") as f:
        pickle.dump(validation_trajectories, f)

We also define utilities for converting the trajectories into metrics:

In [None]:
import re
import seaborn as sns
import pandas as pd
import torch
import pickle

from typing import Callable
from pathlib import Path
from matplotlib.lines import Line2D

from utils.a_star import a_star


def _hide_agent_and_start(flat_inputs: torch.Tensor) -> torch.Tensor:
    """Replace the agent ("@") and start ("<") glyphs with floor (".")."""
    return torch.where(
        torch.logical_or(flat_inputs == ord("@"), flat_inputs == ord("<")),
        ord("."),
        flat_inputs,
    )


def _find_glyph(grid: torch.Tensor, glyph: str) -> list[int]:
    """Return ``[row, col]`` of the only cell of ``grid`` that holds ``glyph``."""
    return [
        int(axis_idx) for axis_idx in torch.nonzero(grid == ord(glyph), as_tuple=True)
    ]


def _as_a_star_grid(grid: torch.Tensor) -> list[list[str]]:
    """Turn a 2-D grid of character codes into rows of chars, all walls as "|"."""
    wall_glyphs = ("|", "-", "#")
    return [
        ["|" if chr(code) in wall_glyphs else chr(code) for code in row]
        for row in grid.tolist()
    ]


def _change_affects_path(
    prev_flat: torch.Tensor,
    cur_flat: torch.Tensor,
    prev_static_flat: torch.Tensor,
    cur_static_flat: torch.Tensor,
    dg_shape: tuple[int, int],
) -> bool:
    """Decide whether the map change between two observations matters to the agent.

    If the previous map had a path from agent to goal, the change matters when
    a changed cell lies on that path (the agent's own cell excluded). If it had
    none, the change matters when the current map has a path from the agent's
    current cell to the previous goal cell.
    """
    prev_grid = prev_flat.reshape(dg_shape)
    goal_pos = _find_glyph(prev_grid, ">")
    old_path = a_star(
        _as_a_star_grid(prev_grid),
        start=_find_glyph(prev_grid, "@"),
        goal=goal_pos,
        obstacle="|",
    )

    if old_path is not None:
        changed_rows, changed_cols = torch.nonzero(
            prev_static_flat.reshape(dg_shape) != cur_static_flat.reshape(dg_shape),
            as_tuple=True,
        )
        cells_ahead = old_path[1:]  # skip the first one since it includes the agent_pos
        return any(
            [int(r), int(c)] in cells_ahead for r, c in zip(changed_rows, changed_cols)
        )

    # path only change if there is path now
    # The current grid is laid out exactly like the previous one, so that the
    # goal coordinates taken from the previous grid address the same cell here.
    cur_grid = cur_flat.reshape(dg_shape)
    new_path = a_star(
        _as_a_star_grid(cur_grid),
        start=_find_glyph(cur_grid, "@"),
        goal=goal_pos,
        obstacle="|",
    )
    return new_path is not None


def get_convergence_divergence_metrics(
    trajectories: TensorDict,
    dg_shape: tuple[int, int],
    latent_key: str = "running_z_H",
    against_env_init_latents=False,
    against_seed_latents=False,
    against_first_latents=False,
):
    """Measure how far the logged latents are from a reference latent.

    For every environment step and recurrent step, the squared difference to a
    reference latent is averaged over the hidden dimension. Environment steps
    are then split by whether the map change since the previous step affected
    the agent's path to the goal.

    Args:
        trajectories: rollout holding ``latent_key``, ``"inputs"`` and
            ``"next" -> "reward"``.
        dg_shape: 2-D shape used to un-flatten each observation.
        latent_key: which logged latent to analyse.
        against_env_init_latents: compare against recurrent index 0 of the
            first environment step.
        against_seed_latents: compare against recurrent index 0 of the same
            environment step.
        against_first_latents: compare against recurrent index 1 of the same
            environment step.
            With no flag set, the reference is the last recurrent index of the
            same environment step. At most one flag may be set.

    Returns:
        Dict with:
        - ``"diff_mse"``: (bs, num_env_steps, recursive_steps) differences.
        - ``"change_flags"``: (bs, num_env_steps) bool, path-affecting change.
        - ``"mask_change_flag_true"`` / ``"mask_change_flag_false"``:
          (bs, num_env_steps) bool selections, excluding the first step and
          every step that follows a reward of 1.
        - ``"change_env_support"`` / ``"no_change_env_support"``: number of
          selected (env, env step) pairs.
        - ``"mse_avg_change_env"`` / ``"mse_avg_no_change_env"``: scalar; the
          selected ``diff_mse`` summed over all recurrent steps, divided by
          the matching support.
        - ``"num_envs_solved"``: envs whose last step has reward 1.
        - ``"num_envs_totaled"``: ``trajectories.shape[0]``.
    """
    latent_values = trajectories[latent_key]  # (bs, num_env_steps, recursive_steps, hd)

    if against_env_init_latents:
        assert not against_seed_latents and not against_first_latents
        comparison_latents = (
            latent_values[:, 0, 0, :].unsqueeze(1).unsqueeze(1)
        )  # (bs, 1, 1, hd)
    elif against_seed_latents:
        assert not against_first_latents
        comparison_latents = latent_values[:, :, 0, :].unsqueeze(
            2
        )  # (bs, num_env_steps, 1, hd)
    elif against_first_latents:
        comparison_latents = latent_values[:, :, 1, :].unsqueeze(
            2
        )  # (bs, num_env_steps, 1, hd)
    else:
        comparison_latents = latent_values[:, :, -1, :].unsqueeze(
            2
        )  # (bs, num_env_steps, 1, hd)

    diff_mse = ((comparison_latents - latent_values) ** 2).sum(
        dim=-1
    ) / latent_values.shape[-1]  # (bs, num_env_steps, recursive_steps)

    # Steps to ignore: anything right after a reward of 1, plus the first step,
    # which is special because the model always starts from init there.
    last_reward = torch.roll(
        trajectories["next"]["reward"][:, :, 0], shifts=1, dims=-1
    )  # (bs, num_env_steps)
    mask = last_reward == 1  # (bs, num_env_steps)
    mask[:, 0] = True  # ignore first action

    # get change flag for each env_step
    num_envs, num_env_steps = latent_values.shape[0], latent_values.shape[1]
    change_from_prev_flags = [
        torch.ones(num_envs, dtype=torch.bool)
    ]  # first step always counts as changed
    for env_step in range(1, num_env_steps):
        inputs = trajectories["inputs"][:, env_step, :]  # (bs, seq_len)
        last_inputs = trajectories["inputs"][:, env_step - 1, :]  # (bs, seq_len)

        # mask out the actor and start goal so the only change is in the doors of the env
        inputs_static = _hide_agent_and_start(inputs)
        last_inputs_static = _hide_agent_and_start(last_inputs)
        env_changed = (last_inputs_static != inputs_static).any(dim=-1).tolist()

        # Only run the path search for environments whose map changed at all.
        step_flags = [
            env_changed[env_idx]
            and _change_affects_path(
                last_inputs[env_idx, :],
                inputs[env_idx, :],
                last_inputs_static[env_idx, :],
                inputs_static[env_idx, :],
                dg_shape,
            )
            for env_idx in range(num_envs)
        ]
        change_from_prev_flags.append(
            torch.as_tensor(step_flags, dtype=torch.bool)
        )  # (bs)

    # Built from Python lists, so move next to `mask` before combining the two.
    change_from_prev_flags = torch.stack(change_from_prev_flags, dim=1).to(
        mask.device
    )  # (bs, num_env_steps)

    # we want to know for each change flag what the avg mse is and what the size of the support
    mask_change_flag_true = (~mask) & change_from_prev_flags  # (bs, num_env_steps)
    mask_change_flag_false = (~mask) & (~change_from_prev_flags)  # (bs, num_env_steps)

    masked_change_true_diff_mse = torch.masked_select(
        diff_mse, mask_change_flag_true.unsqueeze(2)
    )  # 1-D: every selected (env, env step, recursive step) value
    support_mask_change_flag_true = mask_change_flag_true.sum()  # scalar
    mse_avg_recursive_step__change_flag_true = (
        masked_change_true_diff_mse.sum() / support_mask_change_flag_true
    )  # scalar

    masked_change_false_diff_mse = torch.masked_select(
        diff_mse, mask_change_flag_false.unsqueeze(2)
    )  # 1-D: every selected (env, env step, recursive step) value
    support_mask_change_flag_false = mask_change_flag_false.sum()  # scalar
    mse_avg_recursive_step__change_flag_false = (
        masked_change_false_diff_mse.sum() / support_mask_change_flag_false
    )  # scalar

    return {
        "mse_avg_change_env": mse_avg_recursive_step__change_flag_true,
        "change_env_support": support_mask_change_flag_true,
        "mse_avg_no_change_env": mse_avg_recursive_step__change_flag_false,
        "no_change_env_support": support_mask_change_flag_false,
        "change_flags": change_from_prev_flags,
        "diff_mse": diff_mse,
        "mask_change_flag_true": mask_change_flag_true,
        "mask_change_flag_false": mask_change_flag_false,
        "num_envs_solved": int((trajectories["next"]["reward"][:, -1, 0] == 1).sum()),
        "num_envs_totaled": trajectories.shape[0],
    }


def find_all_runs_matching_desc(
    plots_dir: Path,
    reseeding_flag: bool = True,
    env_name: str = "MiniHack-4-Rooms",
    intervention: str = "None",
):
    """Yield the ``*.pkl`` files in ``plots_dir`` recorded for one run setting.

    A file matches when its name ends with the reseeding / environment /
    intervention suffix built below. ``intervention`` is the literal string
    ``"None"`` for runs without an intervention.
    """
    # "intervetion" is misspelt on purpose: it mirrors the names of the files on disk.
    wanted_suffix = (
        f"_reseeding={reseeding_flag}_env={env_name}_intervetion={intervention}.pkl"
    )
    yield from (
        candidate
        for candidate in plots_dir.glob("*.pkl")
        if candidate.name.endswith(wanted_suffix)
    )


def plot_convergence_divergence(
    env_name: str,
    intervention: str,
    dg_shape: tuple[int, int],
    pickles_dir: Path,
    plots_dir: Path | None = None,
    against_env_init_latents: bool = False,
    against_seed_latents: bool = False,
    against_first_latents: bool = False,
    title: str | None = None,
    y_label: str = "Median MSE",
    plot_func: Callable[[torch.Tensor], torch.Tensor] = lambda y_tensor: torch.quantile(
        y_tensor, q=torch.tensor([0.5]), dim=0
    ).squeeze(),
    y_scale: str = "linear",
):
    """Plot a per-recurrent-step statistic of the latent MSE for z_L and z_H.

    For each latent (left panel ``z_L``, right panel ``z_H``) and each
    combination of reseeding flag and "environment changed" flag, all matching
    pickled trajectories are pooled, ``plot_func`` reduces the pooled values
    per recurrent step, and the result is drawn as one line. The first and
    last recurrent steps are left out of the lines. The figure is saved as a
    PNG in ``plots_dir`` and shown.

    Args:
        env_name: environment part of the pickle file names to load.
        intervention: intervention part of the pickle file names ("None" for none).
        dg_shape: 2-D shape used to un-flatten observations.
        pickles_dir: directory holding the pickled trajectories.
        plots_dir: output directory; defaults to ``pickles_dir.parent / "plots_pngs"``.
        against_env_init_latents, against_seed_latents, against_first_latents:
            forwarded to ``get_convergence_divergence_metrics`` to select the
            reference latent.
        title: figure title, also used to derive the PNG file name.
        y_label: label of the y-axis on the left panel.
        plot_func: reduces an ``(N, recurrent steps)`` tensor to one value per
            recurrent step; the default takes the median over N.
        y_scale: matplotlib y-axis scale, e.g. ``"linear"`` or ``"log"``.

    Raises:
        FileNotFoundError: if no pickle matches a requested run setting.
    """
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(20, 8), constrained_layout=True)

    if plots_dir is None:
        plots_dir = pickles_dir.parent / "plots_pngs"
    # Also create a caller-supplied directory, so saving cannot fail on a missing folder.
    plots_dir.mkdir(parents=True, exist_ok=True)

    for plot_idx, key in enumerate(("running_z_L", "running_z_H")):
        ax = axs[plot_idx]
        ax.set_title(key.replace("running_", ""), fontsize=14, fontweight="bold")

        data_records = []
        metadata_records = {}

        dash_styles = {}

        # Aggregate all data into a long-format list for seaborn
        for reseeding_flag in (True, False):
            measurement_metrics = []

            for pkl_file in find_all_runs_matching_desc(
                plots_dir=pickles_dir,
                reseeding_flag=reseeding_flag,
                env_name=env_name,
                intervention=intervention,
            ):
                print(f"Reading pickle {pkl_file}")
                # NOTE: pickle.load executes arbitrary code from the file; only
                # read pickles produced by this notebook.
                with open(pkl_file, "rb") as f:
                    trajectories = pickle.load(f)

                measurement_metrics.append(
                    get_convergence_divergence_metrics(
                        trajectories=trajectories,
                        dg_shape=dg_shape,
                        latent_key=key,
                        against_env_init_latents=against_env_init_latents,
                        against_seed_latents=against_seed_latents,
                        against_first_latents=against_first_latents,
                    )
                )

            if not measurement_metrics:
                raise FileNotFoundError(
                    f"No pickled trajectories found in {pickles_dir} for "
                    f"reseeding={reseeding_flag}, env={env_name}, "
                    f"intervention={intervention}"
                )

            X_ALL = list(range(measurement_metrics[0]["diff_mse"].shape[-1]))

            for change_flag in (True, False):
                mask_key = f"mask_change_flag_{change_flag}".lower()

                # Indexing with the (bs, num_env_steps) mask keeps the recurrent
                # axis, so every run contributes an (n_selected, steps) block.
                pooled = (
                    torch.cat(
                        [m["diff_mse"][m[mask_key]] for m in measurement_metrics],
                        dim=0,
                    )
                    .float()
                    .detach()
                    .cpu()
                )  # (N_ALL, recursion_steps)
                plot_y = plot_func(pooled)

                if y_scale == "log":
                    # a log axis cannot show zeros, so clamp to float32 eps
                    plot_y = torch.maximum(
                        plot_y, torch.tensor(torch.finfo(torch.float32).eps)
                    )

                label = (
                    ("Carry Z" if reseeding_flag else "Reset Z")
                    + " - "
                    + ("Env. changed" if change_flag else "Env. static")
                    + f" (N={pooled.shape[0]})"
                )

                # solid line for carried latents, dashed for reset latents
                dash_styles[label] = "" if reseeding_flag else (2, 2)

                for x, y in zip(X_ALL[1:-1], plot_y[1:-1].numpy()):
                    data_records.append(
                        {"Recurrent step": x, "Median MSE": y, "Condition": label}
                    )

            metadata_records[reseeding_flag] = {
                "num_envs": sum(m["num_envs_totaled"] for m in measurement_metrics),
                "num_envs_solved": sum(
                    m["num_envs_solved"] for m in measurement_metrics
                ),
            }
            metadata_records[reseeding_flag]["frac_solved"] = (
                metadata_records[reseeding_flag]["num_envs_solved"]
                / metadata_records[reseeding_flag]["num_envs"]
            )

        # --- Plot with seaborn ---
        sns.lineplot(
            data=pd.DataFrame(data_records),
            x="Recurrent step",
            y="Median MSE",
            hue="Condition",
            ax=ax,
            style="Condition",
            dashes=dash_styles,
            palette="colorblind",
            linewidth=2,
            legend=plot_idx == len(axs) - 1,
        )

        if plot_idx == 0:
            ax.set_ylabel("Median MSE" if y_label is None else y_label, fontsize=14)
        else:
            ax.set_ylabel("", fontsize=14)

        ax.set_xlabel("Recurrent step", fontsize=16)

        ax.set_yscale(y_scale)

        if plot_idx == len(axs) - 1:
            meta_legend = [
                Line2D(
                    [],
                    [],
                    color="none",
                    label=f"Carry Z: {metadata_records[True]['frac_solved'] * 100:.1f}%",
                ),
                Line2D(
                    [],
                    [],
                    color="none",
                    label=f"Reset Z: {metadata_records[False]['frac_solved'] * 100:.1f}%",
                ),
            ]

            # Add the secondary metadata legend aligned left with the first
            info_legend = ax.legend(
                handles=meta_legend,
                loc="upper left",
                bbox_to_anchor=(1.02, 0.75),  # same x (1.02), slightly lower y
                frameon=False,
                title=f"Environments solved (N={sum(int(d['num_envs']) for d in metadata_records.values()) // 2}):",  # reseeding and non-reseeding
                title_fontsize=16,
                fontsize=16,
            )

            # add the main legend as it is longer and we want it used to define the width of the figure
            ax.legend(
                title=None,
                frameon=False,
                loc="upper left",
                bbox_to_anchor=(1.02, 1),
                fontsize=16,
            )

            # Ensure the first legend remains visible (since adding a new one replaces the previous)
            ax.add_artist(info_legend)

    if title is None:
        title = (
            f"Convergence of hidden states for {X_ALL[-1]} recurrent steps - {env_name}"
        )
    fig.suptitle(title, fontsize=18)  # , fontweight="bold")

    # File name: the lower-cased title with every non-alphanumeric character
    # replaced by "_", cut to 175 characters.
    file_name = re.sub(r"[^a-zA-Z0-9]", "_", title.lower())[:175] + ".png"
    fig.savefig(
        plots_dir / file_name,
        dpi=300,
    )
    plt.show()

# Divergence and convergence plots

Paper results

In [None]:
# 4-Rooms, no intervention, log y-axis. No `against_*` flag is set, so the MSE
# is measured against the latents of the final recurrent step.
dg_shape = (11, 11)
env_name = "MiniHack-4-Rooms"

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="None",
    pickles_dir=PICKLES_OUTPUT_DIR,
    title="Convergence of recurrent state $z$ to final values - 4-rooms environment",
    y_scale="log",
    # y_scale="linear",
)

In [None]:
# Maze, no intervention, log y-axis. No `against_*` flag is set, so the MSE is
# measured against the latents of the final recurrent step.
dg_shape = (11, 13)
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="None",
    pickles_dir=PICKLES_OUTPUT_DIR,
    title="Convergence of recurrent state $z$ to final values: Maze environment",
    y_scale="log",
    # y_scale="linear",
)

In [None]:
# Maze, no intervention, log y-axis. `against_first_latents=True` measures the
# MSE against recurrent index 1 of each environment step.
dg_shape = (11, 13)
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="None",
    pickles_dir=PICKLES_OUTPUT_DIR,
    title="Divergence of recurrent state $z$ from initial values: Maze environment",
    against_first_latents=True,
    y_scale="log",
    # y_scale="linear",
)

General plots

In [None]:
# 4-Rooms, no intervention: median MSE against the final recurrent step on the
# default linear y-axis.
dg_shape = (11, 11)
env_name = "MiniHack-4-Rooms"

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="None",
    pickles_dir=PICKLES_OUTPUT_DIR,
    title="Convergence of MSE against final converged hidden states - 4-Rooms environment",
)

In [None]:
# 4-Rooms, no intervention: standard deviation (instead of the median) of the
# MSE against the final recurrent step, on a log y-axis.
dg_shape = (11, 11)
env_name = "MiniHack-4-Rooms"

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="None",
    pickles_dir=PICKLES_OUTPUT_DIR,
    y_label="STD MSE",
    y_scale="log",
    plot_func=lambda y: torch.std(y, dim=0).squeeze(),
    title="Standard deviation in MSE against final converged hidden states - 4-Rooms environment",
)

For divergence:

In [None]:
# 4-Rooms divergence: against_seed_latents=True (per the title, the reference
# is the initial hidden state used to seed the model).
env_name = "MiniHack-4-Rooms"
dg_shape = (11, 11)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="None",  # the literal string "None", not the None object
    against_seed_latents=True,
    title="Divergence of hidden states from initial hidden state used to seed the model - 4-Rooms environment",
)

In [None]:
# 4-Rooms: standard deviation (over dim 0) of the MSE, measured with
# against_first_latents=True (per the title: against the 1-step hidden states).
env_name = "MiniHack-4-Rooms"
dg_shape = (11, 11)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="None",  # the literal string "None", not the None object
    against_first_latents=True,
    plot_func=lambda y: torch.std(y, dim=0).squeeze(),
    y_scale="log",
    y_label="STD MSE",
    title="Standard deviation in MSE against 1-step hidden states - 4-Rooms environment",
)

In [None]:
# 4-Rooms: against_first_latents=True with the plotting defaults
# (no y_scale / plot_func override).
env_name = "MiniHack-4-Rooms"
dg_shape = (11, 11)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="None",  # the literal string "None", not the None object
    against_first_latents=True,
    title="Divergence of MSE of subsequent recurrent steps against 1-step hidden states - 4-Rooms environment",
)

In [None]:
# 4-Rooms with the "reset_converged_latents_to_init" intervention, measured
# with against_first_latents=True.
env_name = "MiniHack-4-Rooms"
dg_shape = (11, 11)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="reset_converged_latents_to_init",
    against_first_latents=True,
    # Two-line title; adjacent literals concatenate into the same string.
    title=(
        "MSE of subsequent recurrent steps against 1-step hidden states - 4-Rooms environment\n"
        "Intervention: hidden states are reset to initial states after every step"
    ),
)

In [None]:
# 4-Rooms with the "reset_converged_latents_to_init" intervention, measured
# with against_env_init_latents=True.
env_name = "MiniHack-4-Rooms"
dg_shape = (11, 11)

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="reset_converged_latents_to_init",
    pickles_dir=PICKLES_OUTPUT_DIR,
    against_env_init_latents=True,
    # Title grammar fixed ("of against" -> "against"). The saved PNG name is
    # derived from the title, so the output file name changes accordingly.
    title=(
        "Divergence of MSE against hidden states used to init the environment - 4-Rooms environment\n"
        "Intervention: hidden states are reset to initial states after every step"
    ),
)

In [None]:
# 4-Rooms with the "shift_target_randomly" intervention; plotting defaults
# otherwise (no against_* flag, no y_scale / plot_func override).
env_name = "MiniHack-4-Rooms"
dg_shape = (11, 11)

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="shift_target_randomly",
    pickles_dir=PICKLES_OUTPUT_DIR,
    # Title typo fixed ("resetted" -> "reset"). The saved PNG name is derived
    # from the title, so the output file name changes accordingly.
    title=(
        "Convergence of MSE against final converged hidden states - 4-Rooms environment\n"
        "Intervention: model input is patched so the goal is reset at a random empty location"
    ),
)

# Dynamic maze

In [None]:
# Dynamic maze, general convergence plot. dg_shape (11, 13) multiplies out to
# the seq_len of 143 that set_cfg_env assigns to this environment.
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="None",  # the literal string "None", not the None object
    title="Convergence of MSE against final converged hidden states - Maze environment",
)

In [None]:
# Dynamic maze: plot the spread of the MSE. plot_func takes the standard
# deviation over dim 0 and squeezes out singleton dimensions.
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="None",  # the literal string "None", not the None object
    plot_func=lambda y: torch.std(y, dim=0).squeeze(),
    y_scale="log",
    y_label="STD MSE",
    title="Standard deviation in MSE against final converged hidden states - Maze environment",
)

For divergence:

In [None]:
# Dynamic maze divergence: against_seed_latents=True (per the title, the
# reference is the initial hidden state used to seed the model).
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="None",  # the literal string "None", not the None object
    against_seed_latents=True,
    title="Divergence of hidden states from initial hidden state used to seed the model - Maze environment",
)

In [None]:
# Dynamic maze: against_first_latents=True with the plotting defaults
# (no y_scale / plot_func override).
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PICKLES_OUTPUT_DIR,
    intervention="None",  # the literal string "None", not the None object
    against_first_latents=True,
    title="Divergence of MSE of subsequent recurrent steps against 1-step hidden states - Maze environment",
)

In [None]:
# Dynamic maze with the "reset_converged_latents_to_init" intervention,
# measured with against_env_init_latents=True.
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="reset_converged_latents_to_init",
    pickles_dir=PICKLES_OUTPUT_DIR,
    against_env_init_latents=True,
    # Title grammar fixed ("of against" -> "against"). The saved PNG name is
    # derived from the title, so the output file name changes accordingly.
    title=(
        "Divergence of MSE against hidden states used to init the environment - Maze environment\n"
        "Intervention: hidden states are reset to initial states after every step"
    ),
)

In [None]:
# Dynamic maze with the "shift_target_randomly" intervention; plotting
# defaults otherwise (no against_* flag, no y_scale / plot_func override).
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    dg_shape=dg_shape,
    env_name=env_name,
    intervention="shift_target_randomly",
    pickles_dir=PICKLES_OUTPUT_DIR,
    # Title typo fixed ("resetted" -> "reset"). The saved PNG name is derived
    # from the title, so the output file name changes accordingly.
    title=(
        "Convergence of MSE against final converged hidden states - Maze environment\n"
        "Intervention: model input is patched so the goal is reset at a random empty location"
    ),
)