In [1]:
# %% Don't have to restart kernel and reimport each time you modify a dependency
%reload_ext autoreload
%autoreload 2

# %%
# Imports
from typing import List, Tuple, Dict, Union, Optional, Callable
import re 

import numpy as np
import pandas as pd
import torch as t
import plotly.express as px
import plotly as py
import plotly.graph_objects as go
from tqdm import tqdm
from einops import rearrange
from IPython.display import Video, display, clear_output
from ipywidgets import Text, interact, IntSlider, fixed, FloatSlider, Dropdown
import itertools
from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
import matplotlib.pyplot as plt

# NOTE: this is Monte's RL hooking code (and other stuff will be added in the future)
# Install normally with: pip install circrl
import circrl.module_hook as cmh
import procgen_tools.models as models
from experiments.patch_utils import *

# %% 
# Imports for level loading (pickle, procgen) and path setup; no levels are loaded in this cell
import pickle as pkl
from procgen import ProcgenGym3Env

rand_region = 5  # selects which trained checkpoint is loaded (used in the model filename)

# Detect whether we are running under IPython/Jupyter: `get_ipython` is only
# injected as a builtin there, so calling it anywhere else raises NameError.
in_jupyter = True
try:
    get_ipython()
except NameError:
    in_jupyter = False

# Paths below are written relative to the repo root; the notebook presumably
# runs one directory below it, hence the '../' prefix (TODO confirm).
if in_jupyter:
    path_prefix = '../'
else:
    path_prefix = ''

# %%
# Load model

# Checkpoint for the maze policy trained with the chosen rand_region.
policy_path = path_prefix + f'trained_models/maze_I/model_rand_region_{rand_region}.pth'
# NOTE(review): 15 is presumably the size of the procgen action space — confirm
# against models.load_policy's signature.
policy = models.load_policy(policy_path, 15, t.device('cpu'))

# %% Experiment parameters
# Layer label of interest for this experiment.
label = 'embedder.block2.res1.resadd_out'
# Ten evenly spaced coefficients on [-2/3, 2/3], endpoints included. With an
# even count, 0 is not one of them. Presumably patch strengths to sweep — TODO confirm.
interesting_coeffs = np.linspace(start=-2 / 3, stop=2 / 3, num=10)
# Hook wrapper around the policy; later cells call probe_with_input /
# get_value_by_label on it to record and read activations.
hook = cmh.ModuleHook(policy)

In [25]:
# Probe the network with a minimal batch of two observations: the first has a
# single element set to 1 (at index [0, 0, 0, 0]); the second is all zeros.
# Same dummy observation as the one used in rollout-patch.py.
dummy_obs = np.zeros((2, 3, 64, 64), dtype=np.float32)
dummy_obs[0, 0, 0, 0] = 1
hook.probe_with_input(dummy_obs)

# Collect the labels of the conv-layer outputs recorded by the hook, so their
# activations can be visualized. The pattern keeps any label in which 'conv'
# is followed (anywhere later) by 'out', e.g. 'conv_out' or 'conv1_out'.
conv_out_pattern = re.compile(r'.*conv.*out.*')
labels = list(hook.values_by_label.keys())
conv_labels = list(filter(conv_out_pattern.match, labels))
print(conv_labels)

['embedder.block1.conv_out', 'embedder.block1.res1.conv1_out', 'embedder.block1.res1.conv2_out', 'embedder.block1.res2.conv1_out', 'embedder.block1.res2.conv2_out', 'embedder.block2.conv_out', 'embedder.block2.res1.conv1_out', 'embedder.block2.res1.conv2_out', 'embedder.block2.res2.conv1_out', 'embedder.block2.res2.conv2_out', 'embedder.block3.conv_out', 'embedder.block3.res1.conv1_out', 'embedder.block3.res1.conv2_out', 'embedder.block3.res2.conv1_out', 'embedder.block3.res2.conv2_out']


In [31]:
# Now that we have the conv labels, let's visualize the activations at each one using plotly
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px


def plot_conv_activations(label, channel=0):
    """Plot one channel of the batch difference in a layer's recorded output.

    Reads the activations stored by the most recent ``hook.probe_with_input``
    call for ``label``. It then shows ``activations[1] - activations[0]`` for a
    single channel as a spatial heatmap.

    Args:
        label: hook label of the layer to plot (e.g. an entry of ``conv_labels``).
        channel: channel to display. Out-of-range values are clamped to the
            layer's valid channel range, since layers can differ in width.
    """
    # NOTE(review): assumes activations are (batch, channel, height, width) — confirm.
    activations = hook.get_value_by_label(label)
    # Batch item 1 minus batch item 0. With `dummy_obs` as built earlier, item 0
    # holds the single nonzero value and item 1 is all zeros. This is therefore
    # the negated effect of that value.
    diff_activations = activations[1] - activations[0]
    n_channels = diff_activations.shape[0]
    channel = min(max(int(channel), 0), n_channels - 1)
    # Index the channel axis directly. Moving channels last and then taking [0]
    # would select the top row of pixels (a width x channel slice), not a
    # spatial map.
    fig = go.Figure(data=go.Heatmap(z=diff_activations[channel]))
    # go.Heatmap draws z[0] at the bottom; flip y so row 0 is on top, as in an image.
    fig.update_yaxes(autorange='reversed')
    fig.update_layout(height=500, width=500, title_text=f'{label} (channel {channel})')
    fig.show()


# The channel slider spans the widest layer; narrower layers are clamped inside
# the function.
# NOTE(review): assumes axis 1 of each recorded activation is the channel axis.
max_n_channels = max((hook.get_value_by_label(l).shape[1] for l in conv_labels), default=1)
interact(
    plot_conv_activations,
    label=Dropdown(options=conv_labels),
    channel=IntSlider(min=0, max=max_n_channels - 1, value=0),
);

interactive(children=(Dropdown(description='label', options=('embedder.block1.conv_out', 'embedder.block1.res1…

<function __main__.plot_conv_activations(label)>