In [17]:
# Install procgen-tools into the running kernel only when it is not importable yet.
try:
    import procgen_tools
except ImportError:
    get_ipython().run_line_magic(
        "pip", "install -U git+https://github.com/ulissemini/procgen-tools"
    )

from procgen_tools.utils import setup

# Create the directory structure and download the data the notebook relies on.
setup()

from procgen_tools.imports import *
from procgen_tools import maze, visualization, models, patch_utils
from typing import Tuple, Dict, List, Optional, Union
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

# --- Plotting configuration -------------------------------------------------
SAVE_DIR = "playground/visualizations"
AX_SIZE = 6  # forwarded as `ax_size` to visualization.show_grid_heatmap

# --- Channel selections -----------------------------------------------------
# Channel indices passed as `channels=` to retarget_heatmap below
# (presumably channels of `default_layer` — TODO confirm).
cheese_channels = [
    7, 8, 42, 44, 55, 77, 82, 88, 89, 99, 113,
]
# Subset of `cheese_channels`.
effective_channels = [
    8, 55, 77, 82, 88, 89, 113,
]

Already downloaded https://nerdsniper.net/mats/model_rand_region_5.pth


In [127]:
# Score one path (start -> end) on maze seed 0 with the unpatched policy network.
venv = maze.create_venv(num=1, start_level=0, num_levels=1)
vf = visualization.vector_field(venv, hook.network)
start, end = (0, 0), (0, 3)
prob = maze.geometric_probability_path(start, end, vf)
print(f"The geometric average probability from {start} to {end} is {prob:0.3f}")

The geometric average probability from (0, 0) to (0, 3) is 0.906


# Geometric-average probability of reaching each square of the maze from the origin

In [225]:
@interact(seed=(0, 100))  # TODO implement size in show_grid_heatmap
def show_grid_heatmap_interactive(seed: int) -> None:
    """Plot per-square path probabilities for one maze.

    For every legal mouse position, plus the cheese square, the cell is set to
    the geometric-average probability (under ``hook.network``, with no patches
    applied here) of the path from the origin ``(0, 0)`` to that square. All
    other cells of the inner grid stay at 0.

    Args:
        seed: Level seed of the maze, driven by the ``interact`` slider.
    """
    env = maze.create_venv(num=1, start_level=seed, num_levels=1)
    field: Dict = visualization.vector_field(env, hook.network)
    inner_grid: np.ndarray = maze.get_inner_grid_from_seed(seed)

    targets = maze.get_legal_mouse_positions(inner_grid) + [
        maze.get_cheese_pos(inner_grid)
    ]
    probs: np.ndarray = np.zeros_like(inner_grid, dtype=np.float32)
    for square in targets:
        probs[square] = maze.geometric_probability_path((0, 0), square, field)

    visualization.show_grid_heatmap(
        venv=env, heatmap=probs, ax_size=AX_SIZE, mode="human", size=0.5
    )

interactive(children=(IntSlider(value=50, description='seed'), Output()), _dom_classes=('widget-interact',))

In [15]:
def retarget_to_square(
    venv,
    hook,
    channels: List[int],
    coord: Tuple[int, int],
    magnitude: float = 5.5,
    default: Optional[float] = None,
) -> float:
    """Score how well retargeting ``channels`` to ``coord`` steers the policy there.

    A combined pixel patch on ``default_layer`` is built that writes
    ``magnitude`` into ``channels`` at ``coord``. The policy's vector field is
    computed while that patch is active, and the path from the origin
    ``(0, 0)`` to ``coord`` is then scored.

    Args:
        venv: Vectorized environment the vector field is computed on.
        hook: Hook wrapping the network; provides ``use_patches`` and
            ``network``.
        channels: Channel indices to retarget.
        coord: Target square for the retargeting and end of the scored path.
        magnitude: Value written into the channels at ``coord``.
        default: Value for the retargeted channels outside of ``coord``
            (forwarded as ``default`` to ``combined_pixel_patch``).

    Returns:
        ``maze.geometric_probability_path((0, 0), coord, ...)`` evaluated on
        the patched vector field.
    """
    patch_spec = patch_utils.combined_pixel_patch(
        layer_name=default_layer,
        channels=channels,
        value=magnitude,
        coord=coord,
        default=default,
    )
    # The vector field must be computed while the patches are active.
    with hook.use_patches(patch_spec):
        patched_field: Dict = visualization.vector_field(venv, hook.network)

    origin = (0, 0)
    return maze.geometric_probability_path(origin, coord, patched_field)


def cheese_at_square(venv, coord: Tuple[int, int], network=None) -> float:
    """Returns the probability of navigating to a square, given that
    the cheese is placed at the given square.

    Args:
        venv: Vectorized environment; passed to ``maze.move_cheese``.
        coord: Square to move the cheese to; also the end of the scored path.
        network: Policy network used to compute the vector field. Defaults to
            the notebook-global ``hook.network`` (the previous hard-wired
            choice), so existing two-argument calls behave as before.

    Returns:
        ``maze.geometric_probability_path((0, 0), coord, vf)`` where ``vf`` is
        the vector field on the moved-cheese environment.
    """
    if network is None:
        network = hook.network  # notebook-global fallback
    moved_venv = maze.move_cheese(venv, coord)
    vf: Dict = visualization.vector_field(moved_venv, network)
    return maze.geometric_probability_path((0, 0), coord, vf)


def retarget_heatmap(
    venv,
    hook,
    channels: List[int],
    magnitude: float = 5.5,
    remove_cheese: bool = True,
    compare_to_cheese: bool = False,
) -> pd.DataFrame:
    """Tabulate retargeting probabilities for every legal square of the maze.

    For each legal mouse position ``coord`` (computed on the cheese-removed
    environment when ``remove_cheese`` is True), ``channels`` are retargeted to
    ``coord`` with ``magnitude`` and the path from ``(0, 0)`` is scored.

    Args:
        venv: Vectorized environment.
        hook: Hook wrapping the network. It is used both for retargeting and,
            when ``compare_to_cheese`` is set, for the cheese baseline.
        channels: Channel indices to retarget.
        magnitude: Retargeting magnitude passed to ``retarget_to_square``.
        remove_cheese: If True, run on ``maze.remove_cheese(venv)``.
        compare_to_cheese: If True, add a ``cheese_prob`` column from
            ``cheese_at_square``.

    Returns:
        One row per legal square, with these columns:
        ``x``/``y`` (``coord[0]``/``coord[1]``), ``retarget_prob``,
        ``maze_size`` (``inner_grid.shape[0]``), ``d_to_coord`` (length of the
        list returned by ``maze.pathfind`` from ``(0, 0)``; whether that counts
        the start square is not visible here), ``start`` (always ``(0, 0)``),
        and optionally ``cheese_prob``. The columns are present even when no
        square is reachable.
    """
    new_venv = maze.remove_cheese(venv) if remove_cheese else venv
    inner_grid: np.ndarray = maze.state_from_venv(new_venv).inner_grid()
    reachable: List[Tuple[int, int]] = maze.get_legal_mouse_positions(
        inner_grid
    )
    origin = (0, 0)
    maze_size = inner_grid.shape[0]  # loop-invariant

    columns = ["x", "y", "retarget_prob", "maze_size", "d_to_coord", "start"]
    if compare_to_cheese:
        columns.append("cheese_prob")

    records: List[dict] = []
    for coord in reachable:
        record = {
            "x": coord[0],
            "y": coord[1],
            "retarget_prob": retarget_to_square(
                new_venv, hook, channels, coord, magnitude
            ),
            "maze_size": maze_size,
            "d_to_coord": len(
                maze.pathfind(grid=inner_grid, start=origin, end=coord)
            ),
            "start": origin,
        }
        if compare_to_cheese:
            # Use the caller's network, not the notebook-global hook.
            # NOTE(review): with remove_cheese=True this moves cheese on an
            # environment whose cheese was removed — confirm move_cheese
            # supports that.
            record["cheese_prob"] = cheese_at_square(
                new_venv, coord, network=hook.network
            )
        records.append(record)

    # An explicit `columns` keeps the schema stable when `reachable` is empty.
    return pd.DataFrame(records, columns=columns)

In [18]:
@interact(seed=(0, 100))
def show_retarget_heatmap_interactive(seed: int) -> None:
    """Plot, per square, the probability of reaching it under retargeting.

    The cheese is removed, ``cheese_channels`` are retargeted to each legal
    square with magnitude 5.5 (via ``retarget_heatmap``), and the resulting
    ``retarget_prob`` values are drawn over the maze.

    Args:
        seed: Level seed of the maze, driven by the ``interact`` slider.
    """
    venv = maze.create_venv(num=1, start_level=seed, num_levels=1)
    data = retarget_heatmap(
        venv, hook, channels=cheese_channels, magnitude=5.5, remove_cheese=True
    )

    # Build the heatmap with the same convention as
    # `show_grid_heatmap_interactive`: an array shaped like the inner grid,
    # indexed by `coord` (i.e. by the `x`, `y` columns of `data`), with
    # non-reachable squares left at 0. The former
    # `data.pivot(index="y", columns="x")` put coord[1] on the row axis,
    # transposing the image relative to that convention.
    grid: np.ndarray = maze.get_inner_grid_from_seed(seed)
    heatmap: np.ndarray = np.zeros_like(grid, dtype=np.float32)
    if not data.empty:
        heatmap[data["x"].to_numpy(), data["y"].to_numpy()] = data[
            "retarget_prob"
        ].to_numpy()

    visualization.show_grid_heatmap(
        venv=venv, heatmap=heatmap, ax_size=AX_SIZE, mode="human", size=0.5
    )

interactive(children=(IntSlider(value=50, description='seed'), Output()), _dom_classes=('widget-interact',))