In this notebook, you can try resampling activations from other mazes and see how that affects the agent's behavior!

In [34]:
try:
    import procgen_tools
except ImportError:
    get_ipython().run_line_magic(
        magic_name="pip",
        line="install -U git+https://github.com/ulissemini/procgen-tools",
    )

from procgen_tools.utils import setup

setup()  # create directory structure and download data

from procgen_tools.imports import *
from procgen_tools import (
    visualization,
    patch_utils,
    maze,
    vfield,
    vfield_stats,
    metrics,
)

# Presumably the output directory for saved figures; it is not referenced in
# the cells visible here -- TODO confirm.
SAVE_DIR = "playground/visualizations"
# Passed as `ax_size` to patch_utils.compare_patched_vfields when plotting.
AX_SIZE = 6

# Channel indices that the resampling experiments below patch by default.
cheese_channels = [7, 8, 42, 44, 55, 77, 82, 88, 89, 99, 113]
# A subset of cheese_channels; presumably the channels with the strongest
# effect on behavior -- TODO confirm. Not used in the cells visible here.
effective_channels = [8, 55, 77, 82, 88, 89, 113]

Already downloaded https://nerdsniper.net/mats/model_rand_region_5.pth


In [35]:
def analyze_patch_on_ds(seed, patch, hook) -> Dict:
    """Measure how a patch changes the policy at a maze's decision square.

    The mouse is moved to the decision square of the maze generated from
    ``seed``. The action distribution is then computed twice through
    ``hook.run_with_input``: once unpatched and once with ``patch`` applied.

    Args:
        seed: Maze seed, passed to ``maze.get_inner_grid_from_seed``.
        patch: Patch (or combined patches) forwarded to
            ``hook.run_with_input`` via its ``patches`` keyword.
        hook: Object exposing ``run_with_input(obs, patches=...)``. The
            first element of its return value must have a ``logits``
            tensor attribute.

    Returns:
        A dict with keys 'original_greedy' and 'patched_greedy' (the
        most probable action names, taken from
        ``models.MAZE_ACTION_INDICES``) and 'prob_change' (the total
        variation distance between the two action distributions, i.e.
        sum(|original - patched|) / 2).

    Raises:
        ValueError: If the maze has no decision square.
    """
    grid = maze.get_inner_grid_from_seed(seed)
    ds_loc: Tuple[int, int] = metrics.decision_square(grid)
    if ds_loc is None:
        # Indexing the grid with None would overwrite every cell, so fail
        # loudly instead of analyzing a corrupted maze.
        raise ValueError(f"Maze with seed {seed} has no decision square")

    # Move the mouse to the decision square
    mouse_pos = maze.get_mouse_pos(grid)
    grid[mouse_pos] = maze.EMPTY
    grid[ds_loc] = maze.MOUSE
    venv = maze.venv_from_grid(grid)
    obs = venv.reset().astype(np.float32)

    def logits_to_probs(logits):
        """Map action names to softmax probabilities of the given logits."""
        probs = t.nn.functional.softmax(logits, dim=-1)
        probs = probs.detach().cpu().numpy().squeeze()
        return dict(zip(models.MAZE_ACTION_INDICES.keys(), probs))

    with t.no_grad():
        # Baseline forward pass, without any patch
        categorical, _ = hook.run_with_input(obs)
        original_probs = logits_to_probs(categorical.logits)

        # Patched forward pass: use the `patch` argument, not a global
        categorical, _ = hook.run_with_input(obs, patches=patch)
        patched_probs = logits_to_probs(categorical.logits)

    # Compute the greedy (most probable) actions
    original_greedy = max(original_probs, key=original_probs.get)
    patched_greedy = max(patched_probs, key=patched_probs.get)

    # Change in action probabilities: sum(|original - patched|) / 2
    prob_change = 0.5 * sum(
        abs(original_probs[action] - patched_probs[action])
        for action in original_probs
    )

    return {
        "original_greedy": original_greedy,
        "patched_greedy": patched_greedy,
        "prob_change": float(prob_change),
    }

In [36]:
def random_combined_px_patch(
    layer_name: str, channels: List[int], cheese_loc: Tuple[int, int] = None
):
    """Build one combined patch from a random patch per channel.

    For every channel in ``channels``, ``patch_utils.get_random_patch`` is
    called with the given layer name, the notebook-level ``hook``, and
    ``cheese_loc``. The per-channel results are then merged with
    ``patch_utils.compose_patches``.

    Args:
        layer_name: Name of the layer to patch.
        channels: Channel indices to build patches for.
        cheese_loc: Forwarded unchanged to ``get_random_patch``;
            presumably the cheese location of the maze to resample
            from -- TODO confirm. Defaults to None.

    Returns:
        The result of ``patch_utils.compose_patches`` over all
        per-channel patches.
    """
    per_channel_patches = []
    for channel_idx in channels:
        single_patch = patch_utils.get_random_patch(
            layer_name=layer_name,
            hook=hook,  # notebook-level hook, not a parameter
            channel=channel_idx,
            cheese_loc=cheese_loc,
        )
        per_channel_patches.append(single_patch)
    return patch_utils.compose_patches(*per_channel_patches)


def resample_activations(
    seed: int,
    channels: List[int],
    different_location: bool = False,
    show_components: bool = False,
):  # NOTE we're resampling from a fixed maze for all target forward passes
    """Resample activations for default_layer with the given channels.

    Plots the original and patched vector fields, prints the average
    change in action probabilities, and, when the maze has a decision
    square, prints how the greedy action at that square changes.

    Args:
        seed (int): The seed for the maze
        channels (List[int]): The channels to resample
        different_location (bool, optional): If True, the resampling
            location is the cheese location of a randomly drawn seed.
            Otherwise, it is this maze's own cheese location. Defaults
            to False.
        show_components (bool, optional): Forwarded to
            ``patch_utils.compare_patched_vfields``. Defaults to False.
    """
    render_padding = False

    # Figure out the location to sample activations from
    resampling_seed: int = (
        np.random.randint(0, int(1e8)) if different_location else seed
    )
    resampling_loc: Tuple[int, int] = maze.get_cheese_pos_from_seed(
        resampling_seed, flip_y=False
    )
    patches = random_combined_px_patch(
        layer_name=default_layer, channels=channels, cheese_loc=resampling_loc
    )

    venv = patch_utils.get_cheese_venv_pair(seed=seed)
    fig, axs, info = patch_utils.compare_patched_vfields(
        venv,
        patches,
        hook,
        render_padding=render_padding,
        ax_size=AX_SIZE,
        show_components=show_components,
    )
    channel_description = (
        f"channels {channels}"
        if len(channels) > 1
        else f"channel {channels[0]}"
    )
    fig.suptitle(
        f"Resampling {channel_description} on seed {seed}", fontsize=20
    )

    # Mark the resampling location on every axis except the first
    inner_grid = maze.get_inner_grid_from_seed(seed)
    padding = maze.get_padding(inner_grid)
    visualization.plot_dots(
        axs[1:],
        resampling_loc,
        is_grid=True,
        flip_y=False,
        hidden_padding=0 if render_padding else padding,
    )
    plt.show()

    # Display the average vector field difference magnitude
    avg_prob_diff = vfield.vf_diff_magnitude(info["diff_vfield"]) / 2
    print(
        "Action probability distributions changed by"
        f" {avg_prob_diff * 100 :2.1f}% on average"
    )

    # Check the same inner grid that analyze_patch_on_ds rebuilds from the
    # seed, so the check and the analysis agree on the decision square.
    if metrics.decision_square(inner_grid) is not None:
        # analyze_patch_on_ds expects the maze seed, not the venv
        ds_dict = analyze_patch_on_ds(seed, patches, hook)
        original, patched = (
            ds_dict["original_greedy"],
            ds_dict["patched_greedy"],
        )

        print(
            f"The greedy action {original} is now {patched}, with an average"
            f" change of {ds_dict['prob_change'] * 100 :2.1f}%"
        )


def get_alternate_channels(
    avoid_channels: List[int], num_channels: int = 128
) -> List[int]:
    """Get a sorted list of random channels which aren't in avoid_channels.

    Args:
        avoid_channels: Channel indices that must not be returned. The
            result has as many entries as this list.
        num_channels: Total number of channels to draw from; candidates
            are ``range(num_channels)``. Defaults to 128.

    Returns:
        ``len(avoid_channels)`` distinct channel indices in ascending
        order, none of which appear in ``avoid_channels``.

    Raises:
        ValueError: If fewer than ``len(avoid_channels)`` candidate
            channels remain.
    """
    avoid = set(avoid_channels)  # O(1) membership checks
    candidate_channels = [
        channel for channel in range(num_channels) if channel not in avoid
    ]
    num_wanted = len(avoid_channels)
    if num_wanted > len(candidate_channels):
        raise ValueError(
            f"Cannot pick {num_wanted} alternate channels from only"
            f" {len(candidate_channels)} candidates"
        )
    chosen = np.random.choice(
        candidate_channels, size=num_wanted, replace=False
    )
    return sorted(chosen.tolist())

In [None]:
# Interactive explorer: pick a maze seed and a channel set, then watch how
# resampling those channels changes the vector field.
channel_choices = [
    cheese_channels,
    get_alternate_channels(cheese_channels),
    [42, 55, 77, 88],
    [55],
]
interactive(
    resample_activations,
    seed=IntSlider(value=0, min=0, max=100, step=1),
    channels=Dropdown(options=channel_choices, value=cheese_channels),
    different_location=Checkbox(value=False),
    show_components=Checkbox(value=False),
)  # TODO fix error in notebook

## Resampling's effect on the action taken at the decision square 
Let's first see how often resampling activations from a maze with the same
cheese location changes the action taken at the decision square.

1. Sample a bunch of mazes with a decision square
2. Get the greedy action 
3. Resample activations from a maze with the same cheese location, or from a maze with a different cheese location
4. Get the greedy action again
5. See how often the action changes

In [53]:
# Check n seeds and tally how many times the greedy action changes,
# and what the average change is.
num_seeds: int = 200
# Each analyzed seed contributes one row per combination of
# {different, same cheese location} x {random, cheese channels}.
rows_per_seed: int = 4

data = pd.DataFrame(
    {
        "seed": [],
        "cheese_loc_seed": [],
        "original_greedy": [],
        "patched_greedy": [],
        "prob_change": [],
        "different_location": [],
        "random_channels": [],
        "channels_used": [],
    }
)

# Load in from CSV, if it exists
csv_path = "experiments/statistics/data/cheese_resampling.csv"
if os.path.exists(csv_path):
    data = pd.read_csv(csv_path)
    print(f"Loaded {len(data)} rows from {csv_path}")
# Make sure the checkpoint directory exists before the first save
os.makedirs(os.path.dirname(csv_path), exist_ok=True)

# Start at the first seed which hasn't been analyzed yet
if len(data["seed"]) > 0:
    current_seed = int(data["seed"].iloc[-1]) + 1
else:
    current_seed = 0  # Start fresh

while (len(data["seed"]) / rows_per_seed) < num_seeds:
    # Skip seeds whose maze has no decision square
    grid = maze.get_inner_grid_from_seed(current_seed)
    if metrics.decision_square(grid) is not None:
        seed_rows = []  # the rows collected for this seed
        for different_location in [True, False]:
            cheese_loc_seed: int = (
                np.random.randint(0, int(1e8))
                if different_location
                else current_seed
            )
            resampling_loc: Tuple[int, int] = maze.get_cheese_pos_from_seed(
                cheese_loc_seed, flip_y=False
            )
            for random_channels in [True, False]:
                # Check whether patching in other channels has an effect
                # on the greedy action
                channels = (
                    get_alternate_channels(cheese_channels)
                    if random_channels
                    else cheese_channels
                )

                patches = random_combined_px_patch(
                    layer_name=default_layer,
                    channels=channels,
                    cheese_loc=resampling_loc,
                )

                # Get the greedy actions
                ds_dict = analyze_patch_on_ds(current_seed, patches, hook)

                seed_rows.append(
                    {
                        "seed": current_seed,
                        "cheese_loc_seed": cheese_loc_seed,
                        "original_greedy": ds_dict["original_greedy"],
                        "patched_greedy": ds_dict["patched_greedy"],
                        "prob_change": ds_dict["prob_change"],
                        "different_location": int(different_location),
                        "random_channels": int(random_channels),
                        "channels_used": channels,
                    }
                )

        # Append this seed's rows in one step. ignore_index keeps the row
        # labels unique, and an empty frame is replaced rather than
        # concatenated.
        new_rows = pd.DataFrame(seed_rows)
        data = (
            new_rows
            if data.empty
            else pd.concat([data, new_rows], ignore_index=True)
        )
        # Checkpoint after every seed so an interrupted run can resume
        data.to_csv(csv_path, index=False)
    current_seed += 1

Loaded 80 rows from experiments/statistics/data/cheese_resampling.csv


In [51]:
# Reload the saved results when the CSV is present, then split the rows by
# whether activations were resampled from a different cheese location.
if os.path.exists(csv_path):
    data = pd.read_csv(csv_path)
    print(f"Loaded {len(data)} rows from {csv_path}")
data_true = data.loc[data["different_location"].eq(1)]
data_false = data.loc[data["different_location"].eq(0)]

Loaded 80 rows from experiments/statistics/data/cheese_resampling.csv


In [52]:
# Pair each row resampled from the same cheese location with its
# counterpart resampled from a different location. Matching on
# (seed, random_channels) makes the pairing explicit instead of relying
# on the two subsets having the same row order.
paired = data_false.merge(
    data_true,
    on=["seed", "random_channels"],
    suffixes=("_same", "_diff"),
)
# Cast to str so plotly treats the flag as a discrete color category
paired["random_channels"] = paired["random_channels"].astype(str)

# Scatter plot with different colors depending on random_channels,
# adding a legend. custom_data is split per color trace by plotly, so
# the hover text shows the seed that belongs to each point.
fig = px.scatter(
    paired,
    x="prob_change_same",
    y="prob_change_diff",
    color="random_channels",
    custom_data=["seed"],
    labels={"random_channels": "Random channels?"},
)

# Title plot and label axes
fig.update_layout(
    title=(
        "Change in action probability distribution when resampling activations"
    ),
    yaxis_title="From maze with different cheese location",
    xaxis_title="From maze with same cheese location",
)

# Use one shared upper bound so both axes have the same scale
axis_max = 1.05 * max(
    paired["prob_change_same"].max(), paired["prob_change_diff"].max()
)
fig.update_xaxes(range=[0, axis_max])
fig.update_yaxes(range=[0, axis_max])

# Show gray dotted y=x line across the visible range
fig.add_shape(
    type="line",
    x0=0,
    y0=0,
    x1=axis_max,
    y1=axis_max,
    line=dict(color="Gray", width=1, dash="dot"),
)

# Fix the figure size
fig.update_layout(
    autosize=False,
    width=800,
    height=500,
)

# Hover the seed and the prob changes
fig.update_traces(
    hovertemplate="<br>".join(
        [
            "Seed: %{customdata[0]}",
            "ΔP, same location: %{x:.2f}",
            "ΔP, different location: %{y:.2f}",
        ]
    ),
)

# Show the plot
fig.show()

In [None]:
# Compute whether the greedy action changed
data["greedy_changed"] = data["original_greedy"] != data["patched_greedy"]

# Print a table showing the rate of change for each case
import prettytable

table = prettytable.PrettyTable()
table.field_names = [
    "Different location?",
    "Random channels?",
    "Rate of action change",
]
for different_location in [True, False]:
    for random_channels in [True, False]:
        # The columns hold 0/1 flags, which compare equal to False/True
        in_subset = (data["different_location"] == different_location) & (
            data["random_channels"] == random_channels
        )
        num_total = int(in_subset.sum())
        num_changed = int(data.loc[in_subset, "greedy_changed"].sum())
        # Guard against an empty subset (e.g. partially collected data)
        rate = (
            f"{num_changed / num_total * 100 :2.1f}%" if num_total else "n/a"
        )
        table.add_row(
            [
                "Yes" if different_location else "No",
                "Yes" if random_channels else "No",
                f"{num_changed} / {num_total} ({rate})",
            ]
        )
print(table)