One particularly interesting channel in `block2.res1.resadd_out` is _channel 55_. In this notebook, we will: 
1. Visualize channel 55 and demonstrate **that it positively activates on cheese and weakly negatively activates elsewhere**, 
2. Demonstrate how the agent can sometimes be retargeted using a simple synthetic activation patch, and
3. Show that this channel can weakly increase cheese-seeking (multiply by >1), decrease cheese-seeking (zero- or mean-ablate), strongly promote cheese-avoidance (multiply by < -1), and promote no-ops (multiply by << -1). 

In [4]:
# Install procgen-tools on first run (e.g. on a fresh Colab kernel) if it is not importable yet.
try:
    import procgen_tools
except ImportError:  # ModuleNotFoundError is a subclass of ImportError, so this covers both
    get_ipython().run_line_magic(magic_name='pip', line='install -U git+https://github.com/ulissemini/procgen-tools')

from procgen_tools.utils import setup

setup() # create directory structure and download data 

from procgen_tools.imports import *
from procgen_tools.procgen_imports import * 

Already downloaded https://nerdsniper.net/mats/episode_data.tgz
Already downloaded https://nerdsniper.net/mats/patch_data.tgz
Already downloaded https://nerdsniper.net/mats/data.tgz
Already downloaded https://nerdsniper.net/mats/model_rand_region_5.pth


# Visualizing channel 55

Try clicking on the level editor on the left below. Move the cheese around the maze by clicking on the yellow tile, and then clicking on the tile you want to contain the cheese. Watch the positive (blue) activations translate equivariantly along with the cheese!

In [7]:
# Pair an interactive maze editor with a live activation plotter, preset to channel 55.
SEED = 1
# Exactly one level is required: the vector-field view won't work with multiple mazes.
venv = create_venv(num=1, start_level=SEED, num_levels=1)

default_settings = {
    'channel_slider': 55,
    'label_widget': 'block2.res1.resadd_out',
}


def _plot_first_env_activations(activations, fig):
    """Draw the activations of the first (and only) environment onto `fig`."""
    return plot_activations(activations[0], fig=fig)


custom_maze_plotter = ActivationsPlotter(
    labels,
    _plot_first_env_activations,
    values_from_venv,
    hook,
    defaults=default_settings,
    venv=venv,
)

# Editing the maze triggers `update_plotter`, so the heatmap tracks the edited level.
widget_box = custom_vfield(
    policy,
    venv=venv,
    callback=custom_maze_plotter.update_plotter,
    ax_size=2,
)
display(widget_box)

custom_maze_plotter.display()

Box(children=(HBox(children=(GridspecLayout(children=(Button(layout=Layout(grid_area='widget001', height='0px'…

FigureWidget({
    'data': [{'colorscale': [[0.0, 'rgb(103,0,31)'], [0.1, 'rgb(178,24,43)'],
                             [0.2, 'rgb(214,96,77)'], [0.3, 'rgb(244,165,130)'],
                             [0.4, 'rgb(253,219,199)'], [0.5, 'rgb(247,247,247)'],
                             [0.6, 'rgb(209,229,240)'], [0.7, 'rgb(146,197,222)'],
                             [0.8, 'rgb(67,147,195)'], [0.9, 'rgb(33,102,172)'],
                             [1.0, 'rgb(5,48,97)']],
              'type': 'heatmap',
              'uid': '50561e8b-f4f9-4033-9b4c-ff5b00e180d4',
              'z': array([[ 0.02657262, -0.02839196, -0.07337212, ..., -0.04367238, -0.02377848,
                           -0.01424359],
                          [-0.00080798, -0.16589725, -0.19343746, ..., -0.17971103, -0.13789423,
                           -0.08407792],
                          [-0.06065782, -0.2514051 , -0.23488975, ..., -0.22294307, -0.20125811,
                           -0.19534574],
                  

VBox(children=(Dropdown(description='Layers', index=19, options=('block1.conv_in0', 'block1.conv_out', 'block1…

HBox(children=(Text(value='', layout=Layout(width='150px'), placeholder='Custom filename'), Button(description…

In [6]:
@interact
def double_channel_55(seed=IntSlider(min=0, max=100, step=1, value=0), multiplier=FloatSlider(min=-15, max=15, step=0.1, value=5.5)):
    """Compare the agent's vector field with and without scaling channel 55.

    Despite the function name, channel 55 of ``default_layer`` is multiplied by
    the slider-selected ``multiplier`` (not necessarily 2). A "Save figure"
    button is displayed beneath the plot to write it to ``visualizations/``.

    Args:
        seed: Level seed passed to ``get_cheese_venv_pair``.
        multiplier: Factor applied to channel 55's activations by the patch.
    """
    import os  # local import keeps this cell self-contained

    venv = get_cheese_venv_pair(seed=seed)
    try:
        patches = get_multiply_patch(layer_name=default_layer, channel=55, multiplier=multiplier)
        fig, _, _ = compare_patched_vfields(venv, patches, hook, render_padding=True, ax_size=6)
    finally:
        # `interact` re-runs this function on every slider change, creating a
        # fresh environment each time; release it once the vector fields exist.
        venv.close()
    plt.show()

    def save_fig(b):
        # The button callback may fire before anything else has created the folder.
        os.makedirs('visualizations', exist_ok=True)
        fig.savefig(f'visualizations/c55_multiplier_{multiplier}_seed_{seed}.png')
    button = Button(description='Save figure')
    button.on_click(save_fig)
    display(button)

interactive(children=(IntSlider(value=0, description='seed'), FloatSlider(value=5.5, description='multiplier',…