In [8]:
import procgen_tools
from procgen_tools.utils import setup
import torch
import tqdm

setup()  # create directory structure and download data

from procgen_tools.imports import *
from procgen_tools import (
    visualization,
    patch_utils,
    maze,
    vfield,
    vfield_stats,
    metrics,
)
import matplotlib.pyplot as plt

Already downloaded https://nerdsniper.net/mats/model_rand_region_5.pth


In [2]:
# Start a wandb run
import wandb

# wandb.init()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mturn-trout[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
def get_activ(venv, layer_name: str = default_layer) -> torch.Tensor:
    obs = t.tensor(venv.reset(), dtype=t.float32, requires_grad=True)
    with hook.set_hook_should_get_custom_data():
        hook.network(obs)
        activ = hook.get_value_by_label(layer_name, convert=False)
    return activ  # shape: (batch_size, channels, height, width)

In [33]:
# Load the model
policy, hook = load_model(
    rand_region=RAND_REGION, num_actions=NUM_ACTIONS, use_small=False
)

In [34]:
# Start an optimizer
optim = torch.optim.Adam(hook.network.parameters(), lr=1e-3)

# L2 loss between two vectors
loss_fn = torch.nn.MSELoss()

# Load the data
epochs: int = 5
seeds = range(5)
# Original activation on maze without cheese
original_nc_activ: List[torch.Tensor] = []
for seed in seeds:
    # Load the data
    no_cheese_pair = maze.get_cheese_venv_pair(seed=seed)

    # Train the policy
    # Get activations using forward hook at default_layer
    _, nc = get_activ(no_cheese_pair)
    original_nc_activ.append(nc.detach())

for epoch in tqdm(range(epochs)):
    optim.zero_grad()
    for seed in seeds:
        # Load the data
        cheese_pair = maze.get_cheese_venv_pair(seed=seed)

        # Train the policy
        # Get activations using forward hook at default_layer
        cheese_activ, no_cheese_activ = get_activ(cheese_pair)

        # Optimize TODO stop from updating wrt no_cheese_activ (detach)
        loss = 0.01 * loss_fn(
            cheese_activ, no_cheese_activ
        )  # no_cheese is the target
        loss += loss_fn(no_cheese_activ, original_nc_activ[seed])
        loss.backward()
    optim.step()

    # Check cheese values on seed 0
    c_activ, nc_activ = patch_utils.cheese_diff_values(0, default_layer, hook)
    norm_diff = np.linalg.norm(c_activ - nc_activ)
    # Print magnitude of diff norm using numpy
    print(f"Cheese diff: {norm_diff.item()}")

    # Log the results
    wandb.log(
        {
            "loss": loss.item(),
            "epoch": epoch,
            "seed": seed,
            "diff": norm_diff.item(),
        }
    )

 20%|██        | 1/5 [00:00<00:02,  1.49it/s]

Cheese diff: 2.388932228088379


 40%|████      | 2/5 [00:01<00:01,  1.66it/s]

Cheese diff: 3.2682642936706543


 60%|██████    | 3/5 [00:01<00:01,  1.61it/s]

Cheese diff: 4.049895763397217


 80%|████████  | 4/5 [00:02<00:00,  1.37it/s]

Cheese diff: 3.920301675796509


100%|██████████| 5/5 [00:03<00:00,  1.48it/s]

Cheese diff: 3.6196577548980713





In [42]:
# Visualize the new hooked policy
AX_SIZE = 2


@interact
def compare_component_probabilities(
    seed=IntSlider(min=0, max=100, step=1, value=0)
):
    # fig.set_title(f"Seed {seed}")
    venv = maze.create_venv(num=1, start_level=seed, num_levels=1)

    # Plot the original vector field
    old_vf = vfield.vector_field(venv, policy=policy)
    new_vf = vfield.vector_field(venv, policy=hook.network)
    fig, axs, _ = visualization.plot_vfs(old_vf, new_vf)

    axs[0].set_xlabel("Original")
    axs[1].set_xlabel("After activation invariance")
    axs[2].set_xlabel("Difference")
    plt.show()

interactive(children=(IntSlider(value=0, description='seed'), Output()), _dom_classes=('widget-interact',))