In [52]:
!pip install procgen
!pip install plotly
!pip install "notebook>=5.3" "ipywidgets>=7.2"


[0m

In [43]:
import plotly.io as pio
pio.renderers.default = 'notebook_connected'

In [2]:
!pip install git+https://github.com/UlisseMini/procgen-tools.git


Collecting git+https://github.com/UlisseMini/procgen-tools.git
  Cloning https://github.com/UlisseMini/procgen-tools.git to /tmp/pip-req-build-j6ka69g0
  Running command git clone --filter=blob:none --quiet https://github.com/UlisseMini/procgen-tools.git /tmp/pip-req-build-j6ka69g0
  Resolved https://github.com/UlisseMini/procgen-tools.git to commit dc2243c99110aca92687cdf566daafcfbe7067a0
  Preparing metadata (setup.py) ... [?25ldone
[0m

In [39]:
#import statements
from procgen import ProcgenGym3Env
import torch
from procgen_tools import maze
from procgen_tools.models import load_policy
from tqdm import tqdm
import numpy as np
import pickle
from argparse import ArgumentParser
import random
from procgen_tools.data_utils import Episode
import plotly.express as px

In [40]:
def create_venv(num_levels = 1, start_level = 0):
    venv = ProcgenGym3Env(
        num=1,
        env_name='maze', num_levels=num_levels, start_level=start_level,
        distribution_mode='hard', num_threads=4, render_mode="rgb_array", 
    )
    venv = maze.wrap_venv(venv)
    return venv

In [5]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU is available")
else:
    device = torch.device('cpu')
    print("GPU not available, using CPU instead")

policy = load_policy('model_rand_region_5.pth', action_size=15, device=device)



# Extract the model name from the file path for later use in naming the saved data files.
model_name = "maze_example"

GPU is available


In [6]:
print(policy)

CategoricalPolicy(
  (embedder): InterpretableImpalaModel(
    (block1): InterpretableImpalaBlock(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (res1): InterpretableResidualBlock(
        (relu1): ReLU()
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu2): ReLU()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (resadd): ResidualAdd()
      )
      (res2): InterpretableResidualBlock(
        (relu1): ReLU()
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu2): ReLU()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (resadd): ResidualAdd()
      )
    )
    (block2): InterpretableImpalaBlock(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (maxp

In [7]:
from circrl import hooks as cmh
from circrl import rollouts as cro

In [8]:
# maybe problematic but our model has only key: embedder.block1.res1.resadd without out

cache_values = [
    'embedder.block1.res1.resadd',
    'embedder.block1.res2.resadd',
    'embedder.block2.res1.resadd',
    'embedder.block2.res2.resadd',
    'embedder.block3.res1.resadd',
    'embedder.block3.res2.resadd']
# Create an instance of HookManager
hook_manager = cmh.HookManager(
    model=policy,
    cache=[
    'embedder.block1.res1.resadd',
    'embedder.block1.res2.resadd',
    'embedder.block2.res1.resadd',
    'embedder.block2.res2.resadd',
    'embedder.block3.res1.resadd',
    'embedder.block3.res2.resadd'],

)

In [9]:
print(range(len(cache_values)))

range(0, 6)


# Looping through the number of episodes and saving state_bytes, obs, cheese_pos and the cached convolutional activations


In [12]:
# determinism
max_seed = 1000000
num_episodes =100
argmax = True
rng = np.random.default_rng()
seeds = rng.choice(max_seed, size=num_episodes, replace=False)
num_timesteps = 1  # For example, capture 10 timesteps

activations_dict = {f'{i}': [] for i in cache_values}
cheese_position_total = []
state_bytes_data_array = []
obs_data_array = []

for seed in tqdm(seeds):
    seed = int(seed)
    venv = create_venv(start_level=seed)
    assert venv.num_envs == 1, 'Only one env supported (for now)'

    obs = venv.reset()
    info = venv.env.get_info()

    #initialising the list of all episodes
    all_episodes = []
    episode_cheese_positions = []
    
    for i in range(len(cache_values)):
        all_episodes.append([])


    policy.eval()
    for step in range(num_timesteps):
        with hook_manager:
            obs_data_array.append(obs)
            p, v = policy(torch.FloatTensor(obs))
        action = p.probs.argmax(dim=-1).numpy() if argmax else p.sample().numpy()
        obs, rew, done, info = venv.step(action)

        states_bytes = venv.env.callmethod('get_state')[0]
        state_bytes_data_array.append(states_bytes)
        
        states_vals = maze._parse_maze_state_bytes(states_bytes)
        grid = maze.get_grid(states_vals)
        cheese_position = maze.get_cheese_pos(grid)[0]

        for i in range(len(cache_values)):
            all_episodes[i].append(hook_manager.cache_results[list(activations_dict)[i]])
        
        episode_cheese_positions.append(cheese_position)

        if done:
            break

    i= 0    
    for key in activations_dict.keys():
        
        print("value of i", i)
        activations_dict[key].extend(all_episodes[i])   
        i += 1
    cheese_position_total.extend(episode_cheese_positions)

# Convert lists to tensors

for key in activations_dict.keys():
    activations_dict[key] = torch.stack(activations_dict[key])
cheese_positions = torch.tensor(cheese_position_total, dtype=torch.long)

# Continue with data preparation and linear probe training as before...


  1%|          | 1/100 [00:00<00:15,  6.27it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


  2%|▏         | 2/100 [00:00<00:23,  4.13it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


  3%|▎         | 3/100 [00:00<00:22,  4.40it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


  5%|▌         | 5/100 [00:01<00:19,  4.87it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


  7%|▋         | 7/100 [00:01<00:20,  4.44it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


  9%|▉         | 9/100 [00:02<00:21,  4.26it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 10%|█         | 10/100 [00:02<00:22,  4.04it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 11%|█         | 11/100 [00:02<00:23,  3.79it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 12%|█▏        | 12/100 [00:02<00:21,  4.07it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 13%|█▎        | 13/100 [00:03<00:22,  3.85it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 14%|█▍        | 14/100 [00:03<00:21,  4.08it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 15%|█▌        | 15/100 [00:03<00:21,  3.87it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 16%|█▌        | 16/100 [00:03<00:20,  4.10it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 18%|█▊        | 18/100 [00:04<00:19,  4.16it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 19%|█▉        | 19/100 [00:04<00:20,  3.87it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 20%|██        | 20/100 [00:04<00:19,  4.13it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 22%|██▏       | 22/100 [00:05<00:17,  4.58it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 23%|██▎       | 23/100 [00:05<00:18,  4.13it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 24%|██▍       | 24/100 [00:06<00:23,  3.25it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 26%|██▌       | 26/100 [00:06<00:17,  4.32it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 27%|██▋       | 27/100 [00:06<00:20,  3.62it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 28%|██▊       | 28/100 [00:06<00:18,  3.92it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 30%|███       | 30/100 [00:07<00:17,  4.06it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 31%|███       | 31/100 [00:07<00:16,  4.17it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 32%|███▏      | 32/100 [00:07<00:17,  3.96it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 34%|███▍      | 34/100 [00:08<00:14,  4.50it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 35%|███▌      | 35/100 [00:08<00:17,  3.72it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 36%|███▌      | 36/100 [00:08<00:16,  3.96it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 37%|███▋      | 37/100 [00:09<00:16,  3.82it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 38%|███▊      | 38/100 [00:09<00:15,  4.03it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 40%|████      | 40/100 [00:09<00:14,  4.15it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 42%|████▏     | 42/100 [00:10<00:12,  4.64it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 43%|████▎     | 43/100 [00:10<00:13,  4.17it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 45%|████▌     | 45/100 [00:11<00:12,  4.55it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 47%|████▋     | 47/100 [00:11<00:12,  4.41it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 48%|████▊     | 48/100 [00:11<00:11,  4.51it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 50%|█████     | 50/100 [00:12<00:10,  4.70it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 51%|█████     | 51/100 [00:12<00:11,  4.23it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 52%|█████▏    | 52/100 [00:12<00:10,  4.39it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 53%|█████▎    | 53/100 [00:12<00:11,  4.11it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 55%|█████▌    | 55/100 [00:13<00:09,  4.51it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 57%|█████▋    | 57/100 [00:13<00:09,  4.70it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 59%|█████▉    | 59/100 [00:14<00:08,  4.83it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 61%|██████    | 61/100 [00:14<00:10,  3.87it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 62%|██████▏   | 62/100 [00:15<00:10,  3.74it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 63%|██████▎   | 63/100 [00:15<00:10,  3.63it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 64%|██████▍   | 64/100 [00:15<00:09,  3.94it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 66%|██████▌   | 66/100 [00:16<00:07,  4.37it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 67%|██████▋   | 67/100 [00:16<00:08,  4.07it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 68%|██████▊   | 68/100 [00:16<00:07,  4.24it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 70%|███████   | 70/100 [00:16<00:06,  4.53it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 72%|███████▏  | 72/100 [00:17<00:05,  4.78it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 73%|███████▎  | 73/100 [00:17<00:06,  4.29it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 74%|███████▍  | 74/100 [00:17<00:06,  3.98it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 75%|███████▌  | 75/100 [00:18<00:05,  4.21it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 77%|███████▋  | 77/100 [00:18<00:04,  4.68it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 78%|███████▊  | 78/100 [00:18<00:05,  4.17it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 79%|███████▉  | 79/100 [00:19<00:05,  3.97it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 80%|████████  | 80/100 [00:19<00:05,  3.77it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 81%|████████  | 81/100 [00:19<00:04,  4.06it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 83%|████████▎ | 83/100 [00:20<00:03,  4.57it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 84%|████████▍ | 84/100 [00:20<00:03,  4.17it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 85%|████████▌ | 85/100 [00:20<00:03,  4.35it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 86%|████████▌ | 86/100 [00:20<00:03,  4.05it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 87%|████████▋ | 87/100 [00:21<00:03,  4.22it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 89%|████████▉ | 89/100 [00:21<00:02,  4.73it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5
value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 90%|█████████ | 90/100 [00:21<00:02,  4.27it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 91%|█████████ | 91/100 [00:21<00:02,  4.38it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 92%|█████████▏| 92/100 [00:22<00:01,  4.12it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 93%|█████████▎| 93/100 [00:22<00:01,  4.32it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 94%|█████████▍| 94/100 [00:22<00:01,  4.47it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 95%|█████████▌| 95/100 [00:22<00:01,  4.54it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 96%|█████████▌| 96/100 [00:23<00:00,  4.21it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 97%|█████████▋| 97/100 [00:23<00:00,  4.31it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 98%|█████████▊| 98/100 [00:23<00:00,  4.05it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


 99%|█████████▉| 99/100 [00:23<00:00,  4.23it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5


100%|██████████| 100/100 [00:24<00:00,  4.14it/s]

value of i 0
value of i 1
value of i 2
value of i 3
value of i 4
value of i 5





In [14]:
print(cheese_positions.shape)
print(len(state_bytes_data_array))
print(len(obs_data_array))

torch.Size([100])
100
100


In [15]:
from circrl import probing as prb

# Implementation of Monte's Convolutional Probes 

In [16]:
import xarray as xr

#Monte's method of probing (Not fully understood yet)
state_bytes_key = 'dec_state_bytes'
def get_obj_loc_targets(data_all, obj_value, state_bytes_key):
    '''Get potential probe targets for y,x (row, col) location of an
    object, where the object is specified by obj_value as the value
    to match on in the maze grid array.'''
    pos_arr = maze.get_object_pos_from_seq_of_states(
        [state_bytes for state_bytes in data_all], obj_value)
    pos = xr.Dataset({
        'y': xr.DataArray(pos_arr[:,0], dims=['batch']),
        'x': xr.DataArray(pos_arr[:,1], dims=['batch'])}).assign_coords(
            {'batch': np.arange(len(data_all))})
    return pos

# Get the cheese position for all data points
cheese_pos = get_obj_loc_targets(state_bytes_data_array, maze.CHEESE, state_bytes_key)

In [17]:
#Let's print out the cheese positions to see what the above method has done
print(cheese_pos)

<xarray.Dataset> Size: 2kB
Dimensions:  (batch: 100)
Coordinates:
  * batch    (batch) int64 800B 0 1 2 3 4 5 6 7 8 ... 91 92 93 94 95 96 97 98 99
Data variables:
    y        (batch) int64 800B 8 12 13 12 12 11 6 8 9 ... 11 14 15 11 12 5 7 10
    x        (batch) int64 800B 16 12 13 14 12 12 2 11 ... 12 18 14 14 10 15 11


In [53]:
import numpy as np
import pandas as pd
import xarray as xr
import plotly.express as px
from tqdm import tqdm

num_batch = 100
batch_coords = np.arange(len(obs_data_array))

# Assuming obs_data_array is defined and contains observable data

obs_all = xr.concat([xr.DataArray(obs) for obs in obs_data_array], dim='batch').assign_coords({'batch': batch_coords})
print("obsv shape", obs_all.shape[0])

#obs_all = xr.concat([obsv for obsv in obs_data_array], dim='batch').assign_coords(dict(batch=batch_coords))

def grid_coord_to_value_ind(full_grid_coord, value_size):
    '''Pick the value index that covers the majority of the grid coord pixel'''
    return np.floor((full_grid_coord + 0.5) * value_size / maze.WORLD_DIM).astype(int)

def value_ind_to_grid_coord(value_ind, value_size):
    '''Pick the grid coordinate index whose center is closest to the center of the value pixel'''
    return np.floor((value_ind + 0.5) * maze.WORLD_DIM / value_size).astype(int)

def get_obj_pos_data(value_label, object_pos):
    '''Pick the object location and a random other location without the object for a balanced dataset.'''
    rng = np.random.default_rng(15)
    value = activations_dict[value_label].detach().numpy()  # Assuming activations_dict is defined
    print("value shape", value.shape)
    value_size = value.shape[-1]
    num_pixels = num_batch * 2
    pixels = np.zeros((num_pixels, value.shape[2]))
    is_obj = np.zeros(num_pixels, dtype=bool)
    rows_in_value = np.zeros(num_pixels, dtype=int)
    cols_in_value = np.zeros(num_pixels, dtype=int)
    for bb in tqdm(range(obs_all.shape[0])):
        obj_pos_value = (grid_coord_to_value_ind(maze.WORLD_DIM - 1 - object_pos.y[bb].item(), value_size),
                         grid_coord_to_value_ind(object_pos.x[bb].item(), value_size))
        pixels[bb, :] = value[bb, 0, :, obj_pos_value[0], obj_pos_value[1]]
        is_obj[bb] = True
        rows_in_value[bb] = obj_pos_value[0]
        cols_in_value[bb] = obj_pos_value[1]
        bb_rand = bb + num_batch
        random_pos = obj_pos_value
        while random_pos == obj_pos_value:
            random_pos = (rng.integers(value_size), rng.integers(value_size))
        pixels[bb_rand, :] = value[bb, 0,:, random_pos[0], random_pos[1]]
        is_obj[bb_rand] = False
        rows_in_value[bb_rand] = random_pos[0]
        cols_in_value[bb_rand] = random_pos[1]
    return pixels, is_obj, rows_in_value, cols_in_value

value_labels_conv = ['embedder.block1.res1.resadd', 'embedder.block1.res2.resadd', 'embedder.block2.res1.resadd', 'embedder.block2.res2.resadd', 'embedder.block3.res1.resadd', 'embedder.block3.res2.resadd']

f_test_list = []
pixel_data = {}
for value_label in value_labels_conv:
    pixels, is_obj, rows_in_value, cols_in_value = get_obj_pos_data(value_label, cheese_pos)  # Assuming cheese_pos is defined
    f_test, _ = prb.f_classif_fixed(pixels, is_obj)  # Assuming cpr.f_classif_fixed is defined
    sort_inds = np.argsort(f_test)[::-1]
    pixel_data[value_label] = (pixels, is_obj, rows_in_value, cols_in_value, f_test, sort_inds)
    f_test_list.append(pd.DataFrame({
        'layer': np.full(sort_inds.shape, value_label), 
        'rank': np.arange(len(sort_inds)),
        'channel': sort_inds, 
        'f-score': f_test[sort_inds]
    }))

f_test_df = pd.concat(f_test_list, axis='index')

####The following plot does not work

from IPython.display import display

fig = px.line(f_test_df, x='rank', y='f-score', color='layer', hover_data=['channel'],
              title='Ranked f-test scores for "conv pixel contains cheese" over resadd layers')
display(fig)

print(f_test_df)

obsv shape 100
value shape (100, 1, 64, 32, 32)


100%|██████████| 100/100 [00:00<00:00, 2129.21it/s]


value shape (100, 1, 64, 32, 32)


100%|██████████| 100/100 [00:00<00:00, 2178.19it/s]


value shape (100, 1, 128, 16, 16)


100%|██████████| 100/100 [00:00<00:00, 2166.35it/s]


value shape (100, 1, 128, 16, 16)


100%|██████████| 100/100 [00:00<00:00, 2222.75it/s]


value shape (100, 1, 128, 8, 8)


100%|██████████| 100/100 [00:00<00:00, 2158.34it/s]


value shape (100, 1, 128, 8, 8)


100%|██████████| 100/100 [00:00<00:00, 2150.86it/s]


                           layer  rank  channel     f-score
0    embedder.block1.res1.resadd     0       15  691.510059
1    embedder.block1.res1.resadd     1       47  615.866175
2    embedder.block1.res1.resadd     2       62  441.243050
3    embedder.block1.res1.resadd     3       50  382.451741
4    embedder.block1.res1.resadd     4        0  343.370893
..                           ...   ...      ...         ...
123  embedder.block3.res2.resadd   123       81    2.235630
124  embedder.block3.res2.resadd   124       83    1.869431
125  embedder.block3.res2.resadd   125       75    1.435044
126  embedder.block3.res2.resadd   126      110    0.007988
127  embedder.block3.res2.resadd   127       14    0.000122

[640 rows x 4 columns]


In [48]:
index_nums = np.arange(10)+1
scores_list = []
for value_label, (pixels, is_obj, rows_in_value, cols_in_value, f_test, sort_inds) in tqdm(pixel_data.items()):
    for K in index_nums:
        results = prb.linear_probe(pixels[:,sort_inds[:K]], is_obj, C=10, random_state=42)
        scores_list.append({'layer': value_label, 'K': K, 'score': results['test_score']})
scores_df = pd.DataFrame(scores_list)

100%|██████████| 6/6 [00:02<00:00,  2.93it/s]


In [59]:
print(scores_df)

                          layer   K  score
0   embedder.block1.res1.resadd   1  0.875
1   embedder.block1.res1.resadd   2  0.975
2   embedder.block1.res1.resadd   3  0.950
3   embedder.block1.res1.resadd   4  0.975
4   embedder.block1.res1.resadd   5  1.000
5   embedder.block1.res1.resadd   6  1.000
6   embedder.block1.res1.resadd   7  0.975
7   embedder.block1.res1.resadd   8  0.975
8   embedder.block1.res1.resadd   9  0.975
9   embedder.block1.res1.resadd  10  0.975
10  embedder.block1.res2.resadd   1  0.925
11  embedder.block1.res2.resadd   2  0.975
12  embedder.block1.res2.resadd   3  1.000
13  embedder.block1.res2.resadd   4  1.000
14  embedder.block1.res2.resadd   5  1.000
15  embedder.block1.res2.resadd   6  1.000
16  embedder.block1.res2.resadd   7  1.000
17  embedder.block1.res2.resadd   8  1.000
18  embedder.block1.res2.resadd   9  1.000
19  embedder.block1.res2.resadd  10  1.000
20  embedder.block2.res1.resadd   1  0.975
21  embedder.block2.res1.resadd   2  0.975
22  embedde

Let us try to increase the time step for each episode and stack the convolutional layer activations over time steps:
 # The number of timesteps is now 5 and we treat each timestep as a separate data point and stack them all up


In [66]:

hook_manager = cmh.HookManager(
    model=policy,
    cache=[
    'embedder.block1.res1.conv1',
    'embedder.block1.res1.conv2',
    'embedder.block2.res1.conv1',
    'embedder.block2.res1.conv2'],

)

# determinism
max_seed = 1000000
num_episodes =100
argmax = True
rng = np.random.default_rng()
seeds = rng.choice(max_seed, size=num_episodes, replace=False)
num_timesteps = 5  # For example, capture 10 timesteps

conv_layer_activations_1 = []
conv_layer_activations_2 = []
conv_layer_activations_3 = []
conv_layer_activations_4 = []
cheese_position_total = []

for seed in tqdm(seeds):
    seed = int(seed)
    venv = create_venv(start_level=seed)
    assert venv.num_envs == 1, 'Only one env supported (for now)'

    obs = venv.reset()
    info = venv.env.get_info()

    episode_activations_1 = []
    episode_activations_2 = []
    episode_activations_3 = []
    episode_activations_4 = []
    episode_cheese_positions = []

    policy.eval()
    for step in range(num_timesteps):
        with hook_manager:
            p, v = policy(torch.FloatTensor(obs))
        action = p.probs.argmax(dim=-1).numpy() if argmax else p.sample().numpy()
        obs, rew, done, info = venv.step(action)

        states_bytes = venv.env.callmethod('get_state')[0]
        states_vals = maze._parse_maze_state_bytes(states_bytes)
        grid = maze.get_grid(states_vals)
        cheese_position = maze.get_cheese_pos(grid)[0]

        episode_activations_1.append(hook_manager.cache_results['embedder.block1.res1.conv1'])
        episode_activations_2.append(hook_manager.cache_results['embedder.block1.res1.conv2'])
        episode_activations_3.append(hook_manager.cache_results['embedder.block2.res1.conv1'])
        episode_activations_4.append(hook_manager.cache_results['embedder.block2.res1.conv2'])

        episode_cheese_positions.append(cheese_position)

        if done:
            break

    # Here you might average the activations across timesteps or just stack them
    # For simplicity, we'll stack them, treating each timestep as a separate data point
    conv_layer_activations_1.extend(episode_activations_1)
    conv_layer_activations_2.extend(episode_activations_2)
    conv_layer_activations_3.extend(episode_activations_3)
    conv_layer_activations_4.extend(episode_activations_4)
    cheese_position_total.extend(episode_cheese_positions)

# Convert lists to tensors
conv_layer_activations_1 = torch.stack(conv_layer_activations_1)
conv_layer_activations_2 = torch.stack(conv_layer_activations_2)
conv_layer_activations_3 = torch.stack(conv_layer_activations_3)
conv_layer_activations_4 = torch.stack(conv_layer_activations_4)
cheese_positions = torch.tensor(cheese_position_total, dtype=torch.long)

# Continue with data preparation and linear probe training as before...


100%|██████████| 100/100 [01:36<00:00,  1.04it/s]


In [67]:
prb.linear_probe(conv_layer_activations_1.detach().numpy(),cheese_positions.detach().numpy())


lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression



{'train_score': 0.9972527472527473,
 'test_score': 0.9560439560439561,
 'x': array([[-0.09087675, -0.10548554, -0.06309071, ..., -0.08131083,
         -0.11250778, -0.10381664],
        [-0.09087675, -0.10548554, -0.06309071, ..., -0.08131083,
         -0.11250778, -0.10381664],
        [-0.09087675, -0.10548554, -0.06309071, ..., -0.08131083,
         -0.11250778, -0.10381664],
        ...,
        [-0.09636283, -0.10780267, -0.06713416, ..., -0.08204793,
         -0.12421111, -0.111235  ],
        [-0.09636283, -0.10780267, -0.06713416, ..., -0.08204793,
         -0.12421111, -0.111235  ],
        [-0.09636283, -0.10780267, -0.06713416, ..., -0.08204793,
         -0.12421111, -0.111235  ]], dtype=float32),
 'x_train': array([[-0.09087675, -0.10548554, -0.06309071, ..., -0.08131083,
         -0.11250778, -0.10381664],
        [-0.09087675, -0.10548554, -0.06309071, ..., -0.08131083,
         -0.11250778, -0.10381664],
        [-0.09087675, -0.10548554, -0.06309071, ..., -0.08131083,
 

In [68]:
prb.linear_probe(conv_layer_activations_2.detach().numpy(),cheese_positions.detach().numpy())


lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression



{'train_score': 0.9972527472527473,
 'test_score': 0.989010989010989,
 'x': array([[ 0.01748771, -0.00153155,  0.0208402 , ...,  0.02474743,
          0.01273527, -0.00579354],
        [ 0.01748771, -0.00153155,  0.0208402 , ...,  0.02474743,
          0.01273527, -0.00579354],
        [ 0.01748771, -0.00153155,  0.0208402 , ...,  0.02474743,
          0.01273527, -0.00579354],
        ...,
        [ 0.01701913,  0.00404522,  0.02008273, ...,  0.01798693,
          0.01179019,  0.00634721],
        [ 0.01701913,  0.00404522,  0.02008273, ...,  0.01798693,
          0.01179019,  0.00634721],
        [ 0.01701913,  0.00404522,  0.02008273, ...,  0.01798693,
          0.01179019,  0.00634721]], dtype=float32),
 'x_train': array([[ 0.01748771, -0.00153155,  0.0208402 , ...,  0.02474743,
          0.01273527, -0.00579354],
        [ 0.01748771, -0.00153155,  0.0208402 , ...,  0.02474743,
          0.01273527, -0.00579354],
        [ 0.01748771, -0.00153155,  0.0208402 , ...,  0.02474743,
  

In [69]:
prb.linear_probe(conv_layer_activations_3.detach().numpy(),cheese_positions.detach().numpy())

{'train_score': 1.0,
 'test_score': 1.0,
 'x': array([[-0.08130889, -0.0970579 , -0.06755713, ..., -0.0357924 ,
         -0.04728948, -0.04001565],
        [-0.08018743, -0.0936178 , -0.07177693, ..., -0.05338905,
         -0.04997716, -0.04720556],
        [-0.08018743, -0.0936178 , -0.07177693, ..., -0.05338905,
         -0.04997716, -0.04720556],
        ...,
        [-0.06766698, -0.09295157, -0.05938021, ..., -0.08700547,
         -0.14684016, -0.09522162],
        [-0.06766698, -0.09295157, -0.05938021, ..., -0.08700547,
         -0.14684016, -0.09522162],
        [-0.06766698, -0.09295157, -0.05938021, ..., -0.08700547,
         -0.14684016, -0.09522162]], dtype=float32),
 'x_train': array([[-0.08015346, -0.08397797, -0.08332638, ..., -0.02836125,
         -0.0290549 , -0.0459316 ],
        [-0.08129674, -0.09717917, -0.0671269 , ..., -0.04226167,
         -0.05550238, -0.04030601],
        [-0.08130889, -0.0970579 , -0.06755713, ..., -0.0357924 ,
         -0.04728948, -0.040015

In [None]:
prb.linear_probe(conv_layer_activations_4.detach().numpy(),cheese_positions.detach().numpy())