In [10]:
from src.ppo import *
from src.heist import create_venv
from src.bilinear_impala_simplified import BimpalaCNN
from src.heist import load_model
from src.helpers import ModelActivations
from src.probing import *

In [11]:
def add_activations_to_dataset(dataset, modelactivations, layers, batch_size=32):
    dataset_with_activations = []
    num_batches = (len(dataset) + batch_size - 1) // batch_size
    
    for i in tqdm(range(num_batches), desc="Computing activations in batches"):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(dataset))
        batch = dataset[start_idx:end_idx]
        
        # Prepare batch of observations
        observations = [data_point['observation'][0] for data_point in batch]
        observations_array = np.array(observations)
        
        # Ensure the input has 4 dimensions: [batch, height, width, channels]
        if observations_array.ndim == 3:
            observations_array = np.expand_dims(observations_array, axis=0)
        
        # Convert to PyTorch tensor and move to the appropriate device
        observations_tensor = torch.tensor(observations_array, dtype=torch.float32)
        
        
        # Compute activations for the batch
        with torch.no_grad():
            _, activations = modelactivations.run_with_cache(observations_tensor, layers)
        
        # Add activations to each data point in the batch
        for j, data_point in enumerate(batch):
            new_data_point = data_point.copy()
            new_data_point['activations'] = {
                layer.replace('.', '_'): activations[layer.replace('.', '_')][j].cpu().numpy()
                for layer in layers
            }
            dataset_with_activations.append(new_data_point)
    
    return dataset_with_activations

In [12]:
model_path = "/mnt/ssd-1/mechinterp/narmeen/bilinear_experiments_official/bilinear_experiments/bilinear_models/bimpala_maze_simplified.pt"
model =load_model(model_path,7)
print(model)
for k in model.state_dict():
    print(k)
state_dict = model.state_dict()

Model loaded from /mnt/ssd-1/mechinterp/narmeen/bilinear_experiments_official/bilinear_experiments/bilinear_models/bimpala_maze_simplified.pt
BimpalaCNN(
  (conv): Conv2d(3, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (conv_seqs): ModuleList(
    (0-2): 3 x ConvSequence(
      (max_pool2d): MaxPool2d(kernel_size=7, stride=2, padding=3, dilation=1, ceil_mode=False)
      (res_block0): ResidualBlock(
        (conv0): Conv2d(32, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (conv1): Conv2d(32, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      )
      (res_block1): ResidualBlock(
        (conv0): Conv2d(32, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
        (conv1): Conv2d(32, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      )
    )
  )
  (hidden_fc1): Linear(in_features=2048, out_features=256, bias=False)
  (hidden_fc2): Linear(in_features=2048, out_features=256, bias=False)
  

  state_dict = torch.load(model_path)


In [13]:
env = create_venv()
agent = PPO(model, device="cuda" if torch.cuda.is_available() else "cpu")    
dataset = agent.create_dataset_with_labels(env, num_episodes=500, max_steps_per_episode=100)

Generating dataset:   0%|          | 0/500 [00:00<?, ?it/s]

Generating dataset:  18%|█▊        | 92/500 [00:31<05:58,  1.14it/s]



Generating dataset:  27%|██▋       | 133/500 [00:44<03:59,  1.53it/s]



Generating dataset:  27%|██▋       | 134/500 [00:46<05:55,  1.03it/s]



Generating dataset:  27%|██▋       | 135/500 [00:48<07:15,  1.19s/it]



Generating dataset:  27%|██▋       | 136/500 [00:50<08:35,  1.42s/it]



Generating dataset:  32%|███▏      | 159/500 [00:59<03:31,  1.61it/s]



Generating dataset:  32%|███▏      | 160/500 [01:01<05:26,  1.04it/s]



Generating dataset:  32%|███▏      | 161/500 [01:03<06:40,  1.18s/it]



Generating dataset:  32%|███▏      | 162/500 [01:04<07:39,  1.36s/it]



Generating dataset:  36%|███▌      | 179/500 [01:13<03:28,  1.54it/s]



Generating dataset:  36%|███▌      | 180/500 [01:15<05:22,  1.01s/it]



Generating dataset:  71%|███████▏  | 357/500 [01:58<01:53,  1.26it/s]



Generating dataset:  72%|███████▏  | 358/500 [02:00<02:34,  1.09s/it]



Generating dataset:  72%|███████▏  | 359/500 [02:02<03:12,  1.36s/it]



Generating dataset:  72%|███████▏  | 360/500 [02:04<03:30,  1.50s/it]



Generating dataset:  73%|███████▎  | 364/500 [02:08<02:45,  1.21s/it]



Generating dataset:  73%|███████▎  | 365/500 [02:10<03:13,  1.43s/it]



Generating dataset:  91%|█████████ | 455/500 [02:40<00:27,  1.65it/s]



Generating dataset:  91%|█████████ | 456/500 [02:42<00:42,  1.05it/s]



Generating dataset:  91%|█████████▏| 457/500 [02:44<00:51,  1.19s/it]



Generating dataset:  92%|█████████▏| 458/500 [02:46<00:59,  1.42s/it]



Generating dataset: 100%|██████████| 500/500 [03:00<00:00,  2.77it/s]


In [14]:
modelactivations = ModelActivations(model)
complete_dataset = add_activations_to_dataset(dataset, modelactivations, ["conv_seqs.0.res_block0.conv1"], batch_size=32)

Computing activations in batches: 100%|██████████| 290/290 [00:03<00:00, 93.81it/s] 


In [8]:
complete_dataset[0]

{'observation': array([[[[201, 152, 105],
          [201, 152, 105],
          [189, 137,  88],
          ...,
          [199, 150, 102],
          [189, 137,  88],
          [189, 137,  88]],
 
         [[201, 152, 105],
          [201, 152, 105],
          [197, 143,  92],
          ...,
          [201, 152, 105],
          [197, 143,  92],
          [197, 143,  92]],
 
         [[201, 152, 105],
          [201, 152, 105],
          [189, 137,  88],
          ...,
          [201, 152, 105],
          [189, 137,  88],
          [189, 137,  88]],
 
         ...,
 
         [[197, 143,  92],
          [197, 143,  92],
          [201, 152, 105],
          ...,
          [197, 143,  92],
          [201, 152, 105],
          [201, 152, 105]],
 
         [[189, 137,  88],
          [189, 137,  88],
          [189, 137,  88],
          ...,
          [189, 137,  88],
          [189, 137,  88],
          [189, 137,  88]],
 
         [[197, 143,  92],
          [197, 143,  92],
          [201,

In [16]:
def train_all_probes(dataset, layer ='conv_seqs_0_res_block0_conv1'):
    input_dim = dataset[0]['activations'][layer].size
    grid_size = int(np.sqrt(len(dataset[0]['labels']['cheese_presence'])))
    n_actions = 5

    neighboring_walls_probe = NeighboringWallsProbe(input_dim)
    n_action_classes = max(max(d['labels']['next_n_actions']) for d in dataset) + 1
    next_n_actions_probe = NextNActionsProbe(input_dim, n_action_classes, n_actions)
    cheese_presence_probe = CheesePresenceProbe(input_dim, grid_size)
    mouse_location_probe = MouseLocationProbe(input_dim)

    print("Training Neighboring Walls Probe...")
    neighboring_walls_probe.train(dataset)
    print("Training Next N Actions Probe...")
    next_n_actions_probe.train(dataset)
    print("Training Cheese Presence Probe...")
    cheese_presence_probe.train(dataset)
    print("Training Mouse Location Probe...")
    mouse_location_probe.train(dataset)

    return {
        'neighboring_walls': neighboring_walls_probe,
        'next_n_actions': next_n_actions_probe,
        'cheese_presence': cheese_presence_probe,
        'mouse_location': mouse_location_probe
    }

# Assuming 'dataset' is your list of data points
probes = train_all_probes(complete_dataset)

Training Neighboring Walls Probe...
Epoch 1, Validation Loss: 0.6187528603035828
Epoch 2, Validation Loss: 0.6192062291605719
Epoch 3, Validation Loss: 0.6192412982726919
Epoch 4, Validation Loss: 0.6190288647495467
Epoch 5, Validation Loss: 0.6187703229230026
Epoch 6, Validation Loss: 0.6185027304394491
Epoch 7, Validation Loss: 0.6182261299470375
Epoch 8, Validation Loss: 0.6179341350136132
Epoch 9, Validation Loss: 0.6176212245020373
Epoch 10, Validation Loss: 0.617284698732968
Training Next N Actions Probe...
Original dataset size: 9273
Filtered dataset size: 7505
Removed 1768 observations with fewer than 5 actions
Epoch 1, Training Loss: 1.2183730754446476, Validation Loss: 1.2442000394171857
Epoch 2, Training Loss: 1.1433040135084314, Validation Loss: 1.135472963465021
Epoch 3, Training Loss: 1.0920856893062592, Validation Loss: 1.092798944483412
Epoch 4, Training Loss: 1.0591211414083521, Validation Loss: 1.0741190669384408
Epoch 5, Training Loss: 1.033629465927469, Validation L