In [None]:
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path
from tqdm import tqdm
sys.path.append(str(Path("..").resolve()))

import torch
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:

from itertools import count
from typing import Callable
from pathlib import Path
from matplotlib.lines import Line2D
import pickle
from tensordict import TensorDict
from utils.a_star import a_star

def _layout_only(inputs: torch.Tensor) -> torch.Tensor:
    """Replace the agent (``@``) and start (``<``) glyphs with floor (``.``) so only layout (door/wall) differences remain."""
    return torch.where(torch.logical_or(inputs == ord("@"), inputs == ord("<")), ord("."), inputs)


def _glyph_position(grid: torch.Tensor, glyph: str) -> list[int]:
    """Return ``[row, col]`` of the single cell of the 2-D ``grid`` that holds ``glyph``.

    ``int()`` on the ``nonzero`` result raises if the glyph is absent or appears more than once.
    """
    return [int(idx) for idx in torch.nonzero(grid == ord(glyph), as_tuple=True)]


def _as_char_rows(grid: torch.Tensor) -> list[list[str]]:
    """Render a 2-D grid of character codes as rows of characters for ``a_star``; every wall glyph (``|``, ``-``, ``#``) becomes ``|``."""
    return [["|" if chr(code) in ("|", "-", "#") else chr(code) for code in row] for row in grid.tolist()]


def _layout_change_matters(
    prev_inputs: torch.Tensor,
    cur_inputs: torch.Tensor,
    prev_layout: torch.Tensor,
    cur_layout: torch.Tensor,
    dg_shape: tuple[int, int],
) -> bool:
    """Decide whether a layout change between two consecutive observations affects the route to the goal.

    * If ``a_star`` finds a path from agent to goal in the previous observation, the change matters when
      any changed cell lies on that path (excluding the first path cell, the agent's own position).
    * Otherwise the change matters when a path from the current agent position to the previous goal
      exists now.

    Every grid, start, goal and changed cell is expressed in the same ``reshape(dg_shape)`` frame as
    ``[row, col]``. The previous implementation mixed a transposed grid and ``[col, row]`` start with an
    untransposed ``[row, col]`` goal in the second case. Assumes ``a_star`` takes ``[row, col]`` indices into
    the row-list grid and returns a list of ``[row, col]`` cells (or ``None``); this matches how the path is
    checked below.
    """
    prev_grid = prev_inputs.reshape(dg_shape)
    prev_goal_pos = _glyph_position(prev_grid, ">")
    path = a_star(_as_char_rows(prev_grid), start=_glyph_position(prev_grid, "@"), goal=prev_goal_pos, obstacle="|")

    if path is not None:
        changed_rows, changed_cols = torch.nonzero(
            prev_layout.reshape(dg_shape) != cur_layout.reshape(dg_shape), as_tuple=True
        )
        path_cells = path[1:]  # skip the first cell: it is the agent's own position
        return any([int(r), int(c)] in path_cells for r, c in zip(changed_rows, changed_cols))

    # There was no path before: the change only matters if it opened one.
    cur_grid = cur_inputs.reshape(dg_shape)
    cur_path = a_star(_as_char_rows(cur_grid), start=_glyph_position(cur_grid, "@"), goal=prev_goal_pos, obstacle="|")
    return cur_path is not None


def _masked_diff_mse_stats(diff_mse: torch.Tensor, step_mask: torch.Tensor):
    """Summarise ``diff_mse`` ``(bs, num_env_steps, R)`` over the env steps selected by ``step_mask`` ``(bs, num_env_steps)``.

    Returns ``(support, total_avg, per_step_avg)``:
    * ``support``: the number of selected env steps;
    * ``total_avg``: the sum over all selected entries and all recursive steps divided by ``support`` (scalar);
    * ``per_step_avg``: the per-recursive-step average, of shape ``(R,)``.

    Both averages are NaN when ``support`` is 0.
    """
    support = step_mask.sum()
    total_avg = torch.masked_select(diff_mse, step_mask.unsqueeze(2)).sum() / support
    per_step_total = torch.where(step_mask.unsqueeze(2), diff_mse, torch.zeros_like(diff_mse)).sum(dim=(0, 1))
    return support, total_avg, per_step_total / support


def get_convergence_divergence_metrics(
    trajectories: TensorDict, 
    dg_shape: tuple[int, int], 
    latent_key: str = "running_z_H",
    against_env_init_latents=False, 
    against_seed_latents=False,
    against_first_latents=False,
):
    """Measure how far the recurrent latent at each recursive step is from a reference latent.

    Args:
        trajectories: rollout batch. Keys read:
            * ``latent_key``: shape ``(bs, num_env_steps, recursive_steps, hd)``;
            * ``"inputs"``: ``(bs, num_env_steps, seq_len)`` character codes, each step reshapeable to ``dg_shape``;
            * ``("next", "reward")``: the reward is taken from ``[..., 0]``.
        dg_shape: 2-D shape of one flattened observation.
        latent_key: which latent to analyse (e.g. ``"running_z_H"``).
        against_env_init_latents: compare with recursive step 0 of the first env step.
        against_seed_latents: compare with recursive step 0 of the same env step.
        against_first_latents: compare with recursive step 1 of the same env step.

        With no ``against_*`` flag set, the reference is the final recursive step of the same env step
        (convergence). At most one flag may be set.

    Returns:
        dict with
        * ``diff_mse`` ``(bs, num_env_steps, recursive_steps)``: per-dimension mean squared difference.
        * ``change_flags`` ``(bs, num_env_steps)``: True when the layout changed in a way that affects the
          agent's route (see ``_layout_change_matters``). The first step is always True.
        * ``mask_change_flag_true`` / ``mask_change_flag_false`` ``(bs, num_env_steps)``: counted env steps
          with / without such a change. The first env step and any step right after a reward-1 step are
          excluded.
        * ``mse_avg_change_env`` / ``mse_avg_no_change_env``: scalar; sum over recursive steps of
          ``diff_mse`` averaged over the selected env steps.
        * ``mse_per_step_change_env`` / ``mse_per_step_no_change_env`` ``(recursive_steps,)``: per-step averages.
        * ``change_env_support`` / ``no_change_env_support``: the number of selected env steps.
        * ``num_envs_solved``: envs whose last step has reward 1.
        * ``num_envs_totaled``: batch size.
    """
    rewards = trajectories["next"]["reward"][:, :, 0]  # (bs, num_env_steps)
    latent_values = trajectories[latent_key]  # (bs, num_env_steps, recursive_steps, hd)
    batch_size, num_env_steps = latent_values.shape[:2]

    if against_env_init_latents:
        assert not against_seed_latents and not against_first_latents
        comparison_latents = latent_values[:, 0, 0, :].unsqueeze(1).unsqueeze(1)  # (bs, 1, 1, hd)
    elif against_seed_latents:
        assert not against_first_latents
        comparison_latents = latent_values[:, :, 0, :].unsqueeze(2)  # (bs, num_env_steps, 1, hd)
    elif against_first_latents:
        comparison_latents = latent_values[:, :, 1, :].unsqueeze(2)  # (bs, num_env_steps, 1, hd)
    else:
        comparison_latents = latent_values[:, :, -1, :].unsqueeze(2)  # (bs, num_env_steps, 1, hd)

    # Per-dimension mean squared difference: (bs, num_env_steps, recursive_steps)
    diff_mse = ((comparison_latents - latent_values) ** 2).sum(dim=-1) / latent_values.shape[-1]

    # Exclude the first env step (the model always starts from its init there) and every step that
    # directly follows a step rewarded with 1 (roll shifts rewards one step to the right).
    excluded = torch.roll(rewards, shifts=1, dims=-1) == 1  # (bs, num_env_steps)
    excluded[:, 0] = True

    # Flag, per env step, whether the layout changed in a way that matters for reaching the goal.
    inputs = trajectories["inputs"]  # (bs, num_env_steps, seq_len)
    layouts = _layout_only(inputs)
    change_from_prev_flags = torch.zeros(batch_size, num_env_steps, dtype=torch.bool, device=excluded.device)
    change_from_prev_flags[:, 0] = True  # the first step always counts as a change
    for env_step in range(1, num_env_steps):
        layout_changed = (layouts[:, env_step - 1, :] != layouts[:, env_step, :]).any(dim=-1)  # (bs,)
        for env_idx in torch.nonzero(layout_changed).flatten().tolist():
            change_from_prev_flags[env_idx, env_step] = _layout_change_matters(
                prev_inputs=inputs[env_idx, env_step - 1, :],
                cur_inputs=inputs[env_idx, env_step, :],
                prev_layout=layouts[env_idx, env_step - 1, :],
                cur_layout=layouts[env_idx, env_step, :],
                dg_shape=dg_shape,
            )

    included = ~excluded
    mask_change_flag_true = included & change_from_prev_flags  # (bs, num_env_steps)
    mask_change_flag_false = included & ~change_from_prev_flags  # (bs, num_env_steps)

    support_true, avg_true, per_step_true = _masked_diff_mse_stats(diff_mse, mask_change_flag_true)
    support_false, avg_false, per_step_false = _masked_diff_mse_stats(diff_mse, mask_change_flag_false)

    return {
        "mse_avg_change_env": avg_true,  # scalar: summed over recursive steps
        "change_env_support": support_true,
        "mse_avg_no_change_env": avg_false,  # scalar: summed over recursive steps
        "no_change_env_support": support_false,
        "mse_per_step_change_env": per_step_true,  # (recursive_steps,)
        "mse_per_step_no_change_env": per_step_false,  # (recursive_steps,)
        "change_flags": change_from_prev_flags,
        "diff_mse": diff_mse,
        "mask_change_flag_true": mask_change_flag_true,
        "mask_change_flag_false": mask_change_flag_false,
        "num_envs_solved": int((rewards[:, -1] == 1).sum()),
        "num_envs_totaled": trajectories.shape[0],
    }


def find_all_runs_matching_desc(
    plots_dir: Path,
    reseeding_flag: bool = True,
    env_name: str = "MiniHack-4-Rooms",
    intervention: str = "None",
):
    for pkl_file in plots_dir.glob("*.pkl"):
        if pkl_file.name.endswith(f"_reseeding={reseeding_flag}_env={env_name}_intervetion={intervention}.pkl"):  # oopse misspelt
            yield pkl_file

def _load_run_metrics(
    pickles_dir: Path,
    env_name: str,
    intervention: str,
    reseeding_flag: bool,
    dg_shape: tuple[int, int],
    latent_keys: tuple[str, ...],
    **metric_kwargs,
) -> dict[str, list[dict]]:
    """Read every pickle matching the run description once and compute metrics for each latent key.

    Returns ``{latent_key: [metrics dict per pickle, ...]}``.

    Raises:
        FileNotFoundError: if no pickle in ``pickles_dir`` matches the description.
    """
    metrics_by_key: dict[str, list[dict]] = {key: [] for key in latent_keys}
    num_files = 0
    for pkl_file in find_all_runs_matching_desc(
        plots_dir=pickles_dir, reseeding_flag=reseeding_flag,
        env_name=env_name, intervention=intervention,
    ):
        num_files += 1
        print(f"Reading pickle {pkl_file}")
        # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted run outputs.
        with open(pkl_file, "rb") as f:
            trajectories = pickle.load(f)
        for key in latent_keys:
            metrics_by_key[key].append(get_convergence_divergence_metrics(
                trajectories=trajectories,
                dg_shape=dg_shape,
                latent_key=key,
                **metric_kwargs,
            ))
    if num_files == 0:
        raise FileNotFoundError(
            f"No pickles in {pickles_dir} for reseeding={reseeding_flag}, env={env_name}, intervention={intervention}"
        )
    return metrics_by_key


def _stack_selected_curves(measurement_metrics: list[dict], mask_key: str) -> torch.Tensor:
    """Collect the ``diff_mse`` curve of every env step selected by ``mask_key`` across runs.

    Returns a CPU float32 tensor of shape ``(N, recursive_steps)``, one row per selected (env, env_step).
    """
    curves = []
    for metrics in measurement_metrics:
        diff_mse = metrics["diff_mse"]  # (bs, num_env_steps, recursive_steps)
        # The (bs, num_env_steps, 1) mask broadcasts over recursive steps; each selected step yields R values.
        selected = torch.masked_select(diff_mse, metrics[mask_key].unsqueeze(-1))
        curves.append(selected.reshape(-1, diff_mse.shape[-1]).float().cpu())
    return torch.cat(curves, dim=0)


def plot_convergence_divergence(
    env_name: str,
    intervention: str,
    dg_shape: tuple[int, int],
    pickles_dir: Path,
    plots_dir: Path | None = None,
    against_env_init_latents: bool = False,
    against_seed_latents: bool = False,
    against_first_latents: bool = False,
    title: str | None = None,
    y_label: str = "Median MSE",
    plot_func: Callable[[torch.Tensor], torch.Tensor] = lambda y_tensor: torch.quantile(y_tensor, q=torch.tensor([0.5]), dim=0).squeeze(),
    y_scale: str = "linear",
):
    """Plot the per-recursive-step latent distance for ``z_L`` and ``z_H``.

    The figure has one panel per latent (``running_z_L``, ``running_z_H``). Each panel shows up to four
    curves: {Carry Z (reseeding=True), Reset Z (reseeding=False)} x {Env. changed, Env. static}.

    Args:
        env_name, intervention: run description used to select the pickles in ``pickles_dir``.
        dg_shape: observation grid shape, forwarded to ``get_convergence_divergence_metrics``.
        pickles_dir: directory holding the pickled trajectories.
        plots_dir: where the PNG is written; defaults to ``pickles_dir.parent / "plots_pngs"``.
        against_*: reference latent selection, forwarded to ``get_convergence_divergence_metrics``.
        title: figure title; also slugified into the PNG filename.
        y_label: y-axis label of the left panel.
        plot_func: reduces an ``(N, recursive_steps)`` tensor to one value per recursive step (default: median).
        y_scale: matplotlib y-scale; with ``"log"``, values are clamped to float32 eps.

    Raises:
        FileNotFoundError: if no pickle matches for either reseeding setting.
    """
    latent_keys = ("running_z_L", "running_z_H")

    pickles_dir = Path(pickles_dir)
    plots_dir = pickles_dir.parent / "plots_pngs" if plots_dir is None else Path(plots_dir)
    plots_dir.mkdir(parents=True, exist_ok=True)

    # Read each pickle once (and compute its change flags once per key) instead of once per panel.
    run_metrics = {
        reseeding_flag: _load_run_metrics(
            pickles_dir, env_name, intervention, reseeding_flag, dg_shape, latent_keys,
            against_env_init_latents=against_env_init_latents,
            against_seed_latents=against_seed_latents,
            against_first_latents=against_first_latents,
        )
        for reseeding_flag in (True, False)
    }

    # Solve rates do not depend on the latent key, so compute them once.
    metadata_records = {}
    for reseeding_flag, metrics_by_key in run_metrics.items():
        runs = metrics_by_key[latent_keys[0]]
        num_envs = sum(m["num_envs_totaled"] for m in runs)
        num_envs_solved = sum(m["num_envs_solved"] for m in runs)
        metadata_records[reseeding_flag] = {
            "num_envs": num_envs,
            "num_envs_solved": num_envs_solved,
            "frac_solved": num_envs_solved / num_envs,
        }

    fig, axs = plt.subplots(nrows=1, ncols=len(latent_keys), figsize=(20, 8), constrained_layout=True)
    num_recursive_steps = 0

    for plot_idx, (ax, key) in enumerate(zip(axs, latent_keys)):
        is_last_panel = plot_idx == len(axs) - 1
        ax.set_title(key.replace("running_", ""), fontsize=14, fontweight="bold")

        data_records = []
        dash_styles = {}

        # Aggregate all data into a long-format list for seaborn
        for reseeding_flag in (True, False):
            measurement_metrics = run_metrics[reseeding_flag][key]

            for change_flag in (True, False):
                curves = _stack_selected_curves(measurement_metrics, f"mask_change_flag_{change_flag}".lower())
                num_samples, num_recursive_steps = curves.shape

                label = (
                    "Carry Z" if reseeding_flag else "Reset Z"
                ) + " - " + (
                    "Env. changed" if change_flag else "Env. static"
                ) + f" (N={num_samples})"

                if num_samples == 0:
                    # torch.quantile (the default plot_func) rejects empty input; skip the curve instead.
                    print(f"Skipping '{label}' for {key}: no matching env steps")
                    continue

                plot_y = plot_func(curves)  # (recursive_steps,)

                if y_scale == "log":
                    # Log axes cannot show zeros (e.g. at the reference step itself).
                    plot_y = torch.maximum(plot_y, torch.tensor(torch.finfo(torch.float32).eps))

                dash_styles[label] = '' if reseeding_flag else (2, 2)

                # The first and last recursive steps are not plotted.
                for x, y in zip(range(1, num_recursive_steps - 1), plot_y[1:-1].numpy()):
                    data_records.append({
                        "Recurrent step": x,
                        "Median MSE": y,
                        "Condition": label
                    })

        # --- Plot with seaborn ---
        sns.lineplot(
            data=pd.DataFrame(data_records),
            x="Recurrent step",
            y="Median MSE",
            hue="Condition",
            ax=ax,
            style="Condition",
            dashes=dash_styles,
            palette="colorblind",
            linewidth=2,
            legend=is_last_panel,
        )

        if plot_idx == 0:
            ax.set_ylabel("Median MSE" if y_label is None else y_label, fontsize=14)
        else:
            ax.set_ylabel("", fontsize=14)

        ax.set_xlabel("Recurrent step", fontsize=16)

        ax.set_yscale(y_scale)

        if is_last_panel:
            meta_legend = [
                Line2D([], [], color='none', label=f'Carry Z: {metadata_records[True]["frac_solved"]*100:.1f}%'),
                Line2D([], [], color='none', label=f'Reset Z: {metadata_records[False]["frac_solved"]*100:.1f}%')
            ]

            # Secondary solve-rate legend, left-aligned with the main one and slightly lower.
            info_legend = ax.legend(
                handles=meta_legend,
                loc='upper left',
                bbox_to_anchor=(1.02, 0.75),
                frameon=False,
                # Halved because the total counts both the reseeding and the non-reseeding runs.
                title=f'Environments solved (N={sum(int(d["num_envs"]) for d in metadata_records.values()) // 2}):',
                title_fontsize=16,
                fontsize=16,
            )

            # The main legend is longer, so it is the one that defines the width of the figure.
            ax.legend(title=None, frameon=False, loc="upper left", bbox_to_anchor=(1.02, 1), fontsize=16)

            # Re-add the first legend: calling ax.legend again replaces the previous one.
            ax.add_artist(info_legend)

    if title is None:
        title = f"Convergence of hidden states for {num_recursive_steps - 1} recurrent steps - {env_name}"
    fig.suptitle(title, fontsize=18)
    fig.savefig(plots_dir / (title.lower().replace(" ", "_").replace("-", "_") + ".png"), dpi=300)
    plt.show()


In [None]:
# Directory holding the pickled rollouts. Set the HRM_PICKLES_DIR environment variable to point
# elsewhere instead of editing this machine-specific default.
root_path = os.environ.get("HRM_PICKLES_DIR", "/home/incubator/dev/hrm/s3/pickles")

# Run id -> how that model was trained and on which environment.
# NOTE(review): not referenced by the cells below; presumably kept for reference.
metadata_dict = {
    "psy7bpwq": {
        "model_trained_on_reseeding": True,
        "env": "4-Room",
    },
    "u2fivtrs": {
        "model_trained_on_reseeding": False,
        "env": "4-Room",
    },
    "nq3r9hsw": {
        "model_trained_on_reseeding": False,
        "env": "Dynamic-Maze",
    },
    "f79zuqjx": {
        "model_trained_on_reseeding": True,
        "env": "Dynamic-Maze",
    }
}

# NOTE(review): not referenced by the cells below.
UNIQUE_DESCS = ("4-rooms, carry Z", "4-rooms, reset Z", "Maze, carry Z", "Maze, reset Z")

# Despite its name, the plotting cells pass this as `pickles_dir`, i.e. the directory the pickles are read from.
PLOTS_OUTPUT_DIR = Path(root_path)

In [None]:
# 4-rooms: distance of z from its final recurrent value at every recursive step (log scale).
env_name = "MiniHack-4-Rooms"
dg_shape = (11, 11)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PLOTS_OUTPUT_DIR,
    intervention="None",
    y_scale="log",
    title="Convergence of recurrent state $z$ to final values - 4-rooms environment",
)


In [None]:
# Dynamic maze: distance of z from its final recurrent value at every recursive step (log scale).
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PLOTS_OUTPUT_DIR,
    intervention="None",
    y_scale="log",
    title="Convergence of recurrent state $z$ to final values: Maze environment",
)

In [None]:
# Dynamic maze: distance of z from recursive step 1 (against_first_latents) at every recursive step (log scale).
env_name = "MiniHack-Corridor-Maze-4-Way-Dynamic"
dg_shape = (11, 13)

plot_convergence_divergence(
    env_name=env_name,
    dg_shape=dg_shape,
    pickles_dir=PLOTS_OUTPUT_DIR,
    intervention="None",
    against_first_latents=True,
    y_scale="log",
    title="Divergence of recurrent state $z$ from initial values: Maze environment",
)