In [13]:
try:
    import procgen_tools
except ImportError:
    get_ipython().run_line_magic(
        magic_name="pip",
        line="install -U git+https://github.com/ulissemini/procgen-tools",
    )

from procgen_tools.utils import setup

setup()  # create directory structure and download data

from procgen_tools.imports import *
from procgen_tools import maze, visualization, models, patch_utils
from typing import Tuple, Dict, List, Optional, Union
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import tqdm

# Side length passed as `ax_size` to every heatmap plot in this notebook.
AX_SIZE: int = 4

# Channel indices used by the retargeting interventions.
# NOTE: these must stay *lists*: the retargeting cell matches rows of `data`
# by comparing the "intervention" column against str(<list>), e.g.
# "[8, 55, 77, 82, 88, 89, 113]". A tuple would stringify differently.
cheese_channels = [
    7, 8, 42, 44, 55, 77, 82, 88, 89, 99, 113,
]
# Subset of `cheese_channels` used by the "effective" condition.
effective_channels = [
    8, 55, 77, 82, 88, 89, 113,
]

Already downloaded https://nerdsniper.net/mats/model_rand_region_5.pth


# Showing the geometric-mean probability of reaching each part of the maze
Let's show a "heatmap" of the agent's propensity to visit each square of
the maze. Each tile is shaded red according to the geometric mean of the
policy's probabilities along the path to that tile.

In [14]:
@interact(seed=(0, 100))
def show_grid_heatmap_interactive(seed: int) -> None:
    """Draw the per-tile heatmap for the single maze generated from ``seed``.

    The grid values come from ``visualization.vf_heatmap`` evaluated on a
    one-level environment together with ``hook`` (not defined in this cell;
    presumably provided by ``procgen_tools.imports`` — TODO confirm). They are
    rendered with ``visualization.show_grid_heatmap`` in "human" mode.

    Args:
        seed: Level seed used to build the one-maze environment.
    """
    env = maze.create_venv(num=1, start_level=seed, num_levels=1)

    # TODO: consider adding a colorbar.
    visualization.show_grid_heatmap(
        venv=env,
        heatmap=visualization.vf_heatmap(env, hook),
        mode="human",
        ax_size=AX_SIZE,
    )

interactive(children=(IntSlider(value=50, description='seed'), Output()), _dom_classes=('widget-interact',))

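# Loading the retargeting data
Every CSV file in `experiments/statistics/data/retargeting` is read and concatenated into a single DataFrame, `data`.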

In [16]:
# Concatenate every retargeting CSV into a single DataFrame, `data`.
DATA_DIR = "experiments/statistics/data/retargeting"

# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# makes the row order of `data` reproducible across machines and re-runs.
csv_names = sorted(name for name in os.listdir(DATA_DIR) if name.endswith(".csv"))
if not csv_names:
    # Without this check pd.concat fails with an opaque
    # "No objects to concatenate" error.
    raise FileNotFoundError(
        f"No .csv files found in {DATA_DIR!r}; has the retargeting data been generated?"
    )

dfs = [pd.read_csv(os.path.join(DATA_DIR, name)) for name in csv_names]
data = pd.concat(dfs, ignore_index=True)

## Visualizing retargetability in different mazes using different criteria

Each heatmap shows the geometric mean probability of navigating to a
tile, given:
1. "Cheese": The cheese is placed at the given tile. 
2. "Effective": Channels `[8, 55, 77, 82, 88, 89, 113]` are clamped to
   `+2.3` at a filter corresponding to the given maze location.
3. "All": Channels `[7, 8, 42, 44, 55, 77, 82, 88, 89, 99, 113]` are
   clamped to `+1.0`.
4. "Normal": No intervention.
5. "55": Channel `[55]` is clamped to `+5.5`.

If no cheese is shown in the maze, then the cheese was removed from its
original position at that seed. In that case, the maze contains no cheese,
except in the "cheese" condition, where cheese is placed at the tile being
evaluated.

In [37]:
# Show a retargeting heatmap for each maze
@interact(
    seed=IntSlider(min=0, max=100, step=1, value=0),
    prob_type=["cheese", "effective", "all", "normal", "55"],
)
def show_retargeting_heatmap_interactive(seed: int, prob_type: str) -> None:
    """Show a heatmap over the maze using matplotlib.

    Args:
        seed: Seed of the maze
        prob_type: Type of heatmap to show
    """
    venv = maze.create_venv(num=1, start_level=seed, num_levels=1)
    seed_data: pd.DataFrame = data[data["seed"] == seed]
    label = prob_type
    if prob_type == "all":
        label = str(cheese_channels)
    elif prob_type == "effective":
        label = str(effective_channels)
    elif prob_type == "55":
        label = "[55]"

    prob_data = seed_data[seed_data["intervention"] == label]
    # Check if removed_cheese is true for prob_data's first row
    if prob_data["removed_cheese"].iloc[0]:
        venv = maze.remove_all_cheese(venv)

    heatmap = prob_data.pivot(index="row", columns="col", values="probability")

    visualization.show_grid_heatmap(
        venv=venv,
        heatmap=heatmap.values,
        ax_size=AX_SIZE,
        mode="human",
        size=0.5,
        alpha=0.9,
    )

interactive(children=(IntSlider(value=0, description='seed'), Dropdown(description='prob_type', options=('chee…