In [106]:

import sys
import os
import pandas as pd
from src.utils import heist
from src.utils import helpers
import torch.distributions
import torch

import gym
import random
import numpy as np
from helpers import generate_action, load_model
from procgen import ProcgenGym3Env
import imageio
import matplotlib.pyplot as plt
import typing
import math

from procgen import ProcgenGym3Env
import struct
import typing
from typing import Tuple, Dict, Callable, List, Optional
from dataclasses import dataclass
from src.policies_impala import ImpalaCNN
from procgen_tools.procgen_wrappers import VecExtractDictObs, TransposeFrame, ScaledFloatFrame

from gym3 import ToBaselinesVecEnv
import seaborn as sns


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [107]:
layers = {
    1: 'conv1a',
    2: 'pool1',
    3: 'conv2a',
    4: 'conv2b',
    5: 'pool2',
    6: 'conv3a',
    7: 'pool3',
    8: 'conv4a',
    9: 'pool4',
    10: 'fc1',
    11: 'fc2',
    12: 'fc3',
    13: 'value_fc',
    14: 'dropout_conv',
    15: 'dropout_fc'
}



In [108]:
def identify_entity_region(entity_position):
    """
    Identify which region an entity is in, based on a 9x9 grid.
    
    Args:
    entity_position (tuple): A tuple (x, y) representing the entity's position,
                             where x and y are between 0 and 8.9.
    
    Returns:
    dict: A dictionary containing the region number, row, and column.
    """
    x, y = entity_position
    
    # Calculate the region row and column (0-indexed)
    region_col = int(0 if math.isnan(x) else x)
    region_row = 8 - int(0 if math.isnan(y) else y)  # Invert y-axis so higher values are higher on the grid
    
    # Ensure the values are within the correct range
    region_row = min(max(region_row, 0), 8)
    region_col = min(max(region_col, 0), 8)
    
    # Calculate the region number (1-indexed)
    region_number = (region_row) * 9 + region_col + 1
    
    return {
        'region_number': region_number,
        'region_row': region_row,
        'region_col': region_col
    }

def get_entity_regions(state):
    """
    Extract all entities from the environment and identify their regions.
    
    Args:
    state: The game state object from heist.py
    
    Returns:
    dict: A dictionary mapping entity types to their positions and regions.
    """
    entities = state.state_vals["ents"]
    entity_regions = {}
    
    # Initialize 9x9 grid
    grid = [[[] for _ in range(9)] for _ in range(9)]
    
    ENTITY_COLORS = {
        0: "blue",
        1: "green",
        2: "red"
    }
    
    for entity in entities:
        entity_type = entity["image_type"].val
        x = entity["x"].val
        y = entity["y"].val 
        color_id = entity["image_theme"].val
        color = ENTITY_COLORS.get(color_id, "unknown")
        
        # Map image_type values to entity names
        entity_name = {
            9: "gem",
            2: f"{color}_key",
            1: f"{color}_lock",
            0: "player"
        }.get(entity_type, f"unknown_{entity_type}")
        

        
        region_info = identify_entity_region((x, y))
        
        if entity_name not in entity_regions:
            entity_regions[entity_name] = []
        
        entity_regions[entity_name].append({
            'position': (x, y),
            'region': region_info,
            'color': color if entity_name in ['blue_key', 'green_key', 'red_key', 'blue_lock', 'green_lock', 'red_lock'] else None
        })
        
        # Add entity to the corresponding cell in the 9x9 grid
        grid[region_info['region_row']][region_info['region_col']].append(entity_name)
    
    return entity_regions, grid

# Example usage:
# Assuming you have a state object from heist.py
model = helpers.load_interpretable_model(model_path="../model_interpretable.pt")
venv = heist.create_venv(1,0,1000000)
state = heist.state_from_venv(venv, 0)
entity_regions, grid = get_entity_regions(state)

# Print object types and regions for each entity
for entity_type, entities in entity_regions.items():
    for entity in entities:
        print(f"Object Type: {entity_type}")
        print(f"Region: {entity['region']['region_number']}")
        print(f"Region Row: {entity['region']['region_row']}")
        print(f"Region Column: {entity['region']['region_col']}")
        print(f"Position: {entity['position']}")
        if entity['color']:
            print(f"Color: {entity['color']}")
        print("---")
print("Entity Regions:", entity_regions)
print("9x9 Grid:")
for row in grid:  # Print rows in order, with higher y values at the top
    print(row)


Object Type: player
Region: 17
Region Row: 1
Region Column: 7
Position: (7.5, 7.5)
---
Object Type: green_lock
Region: 33
Region Row: 3
Region Column: 5
Position: (5.5, 5.5)
Color: green
---
Object Type: gem
Region: 24
Region Row: 2
Region Column: 5
Position: (5.577936172485352, 6.5121378898620605)
---
Object Type: green_key
Region: 53
Region Row: 5
Region Column: 7
Position: (7.505515098571777, 3.4152979850769043)
Color: green
---
Object Type: green_key
Region: 73
Region Row: 8
Region Column: 0
Position: (0.9024999737739563, 0.022499999031424522)
Color: green
---
Object Type: blue_lock
Region: 35
Region Row: 3
Region Column: 7
Position: (7.5, 5.5)
Color: blue
---
Object Type: blue_key
Region: 26
Region Row: 2
Region Column: 7
Position: (7.507329940795898, 6.577833652496338)
Color: blue
---
Object Type: blue_key
Region: 73
Region Row: 8
Region Column: 0
Position: (0.9624999761581421, 0.022499999031424522)
Color: blue
---
Entity Regions: {'player': [{'position': (7.5, 7.5), 'region': {'

In [131]:
import numpy as np

def find_most_active_region(activation_data):
    # Ensure the data is a numpy array
    activation_data = np.array(activation_data)
    
    # Ensure the data is in the expected shape (32, 32)
    if activation_data.shape != (8,8):
        raise ValueError(f"Activation data should be a 32x32 grid, but got shape {activation_data.shape}")
    
    # Create a 9x9 grid to store region sums
    region_sums = np.zeros((9, 9))
    
    # Calculate the size of each region
    region_size = 32 / 9
    
    # Sum the activations in each region
    for i in range(9):
        for j in range(9):
            start_x = int(i * region_size)
            end_x = int((i + 1) * region_size)
            start_y = int(j * region_size)
            end_y = int((j + 1) * region_size)
            
            region_sums[i, j] = np.sum(activation_data[start_x:end_x, start_y:end_y])
    
    # Find the indices of the maximum sum
    max_index = np.unravel_index(np.argmax(region_sums), region_sums.shape)
    
    # Calculate the region number (1 to 81)
    region_number = max_index[0] * 9 + max_index[1] + 1
    
    return {
        'region_number': region_number,
        'region_row': max_index[0],
        'region_col': max_index[1],
        'activation_sum': region_sums[max_index]
    }

def analyze_all_filters(normal_activations, layer_number):


    pool1_activations = normal_activations[layers[layer_number]][0]
    
    results = []
    for i in range(pool1_activations.shape[0]):  # Iterate over all filters
        filter_data = pool1_activations[i]
        result = find_most_active_region(filter_data)
        result['filter_number'] = i + 1  # Adding filter number to the result
        results.append(result)
    
    return results



In [133]:
def match_activations_with_entities(filter_results, entity_regions):
    matches = []
    
    for filter_result in filter_results:
        filter_number = filter_result['filter_number']
        filter_region = filter_result['region_number']
        
        for entity_type, entities in entity_regions.items():
            for entity in entities:
                if entity['region']['region_number'] == filter_region:
                    matches.append({
                        'filter_number': filter_number,
                        'region_number': filter_region,
                        'entity_type': entity_type,
                        'entity_position': entity['position']
                    })
    
    return matches



venv = heist.create_venv(1,random.randint(0,10000),1)
observation = venv.reset()
# observation = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)

model_activations = helpers.ModelActivations(model)

output1, normal_activations = model_activations.run_with_cache(helpers.observation_to_rgb(observation), layers.values())


# Example usage:
filter_results = analyze_all_filters(normal_activations, 8)
entity_regions, _ = get_entity_regions(state)

matches = match_activations_with_entities(filter_results, entity_regions)

print("Matches between filter activations and entity regions:")
for match in matches:
    print(f"Filter {match['filter_number']} {match['entity_type']}")

Matches between filter activations and entity regions:


In [128]:
import pandas as pd
from collections import defaultdict

num_iterations = 100
match_histogram = defaultdict(lambda: defaultdict(int))
layer_number = 8


for _ in range(num_iterations):
    venv = heist.create_venv(1, random.randint(0, 10000), 0)
    observation = venv.reset()

    model_activations = helpers.ModelActivations(model)
    output1, normal_activations = model_activations.run_with_cache(helpers.observation_to_rgb(observation), layers.values())
    
    filter_results = analyze_all_filters(normal_activations,layer_number)
    entity_regions, _ = get_entity_regions(state)

    matches = match_activations_with_entities(filter_results, entity_regions)

    for match in matches:
        match_histogram[match['filter_number']][match['entity_type']] += 1



In [130]:
match_histogram

defaultdict(<function __main__.<lambda>()>, {})

In [134]:
def plot_filter_activation_matches(match_histogram, layer_number):
    # Convert the nested dictionary to a DataFrame
    df = pd.DataFrame(match_histogram).T.fillna(0)
    df.index = pd.to_numeric(df.index)  # Convert index to numeric
    df = df.sort_index()  # Sort by index (filter number)
    df.index.name = 'Filter Number'
    df.columns.name = 'Entity Type'

    print("Histogram of matches between filter activations and entity regions:")
    print(df)
    color_scheme = {
        'player': '#800080',  # purple
        'gem': '#ffff90',     # green
        'wall': '#8c564b',    # brown
        'blue_key': '#4169e1',    # royal blue
        'green_key': '#00ffff',   # lime green
        'red_key': '#dc143c',     # crimson
        'blue_lock': '#00bfff',   # deep sky blue
        'green_lock': '#3cb371',  # medium sea green
        'red_lock': '#ff00ff'     # orange red
    }
    #TODO: Finish adding entities of different colours
    # Fix the bug in classified datasets

    plt.figure(figsize=(12, 8))
    df.plot(kind='bar', stacked=True, color=[color_scheme.get(col, '#999999') for col in df.columns])
    plt.title(f'Histogram of Filter Activations Matching Entity Regions - Mixed. Layer {layer_number}')
    plt.xlabel('Filter Number')
    plt.ylabel('Number of Matches')
    plt.legend(title='Entity Type', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(range(len(df.index)), df.index, rotation=45)  
    plt.tight_layout()

    if plt.gcf().get_axes():
        plt.savefig('filter_activations_entity_matches_histogram.png', bbox_inches='tight')
        print("Figure saved as 'filter_activations_entity_matches_histogram.png'")
    else:
        print("Warning: The figure is empty. No image was saved.")

    plt.show()
    plt.close() 

# Call the function with match_histogram
plot_filter_activation_matches(match_histogram, layer_number)


Histogram of matches between filter activations and entity regions:
Empty DataFrame
Columns: []
Index: []


TypeError: no numeric data to plot

<Figure size 1200x800 with 0 Axes>

In [127]:



import pandas as pd
from collections import defaultdict

num_iterations = 1000
match_histogram = defaultdict(lambda: defaultdict(int))

model = helpers.load_interpretable_model(model_path="../model_interpretable.pt")
classified_dataset = heist.create_classified_dataset(num_samples_per_category=20, num_levels=0)
# Combine examples from all categories into one big list
combined_dataset = []

for category, examples in classified_dataset.items():
    combined_dataset.extend(examples)

layer_number = 8
for sample in combined_dataset:
    # venv = heist.create_venv(1, random.randint(0, 10000), 1)
    observation = sample

    model_activations = helpers.ModelActivations(model)
    output1, normal_activations = model_activations.run_with_cache(helpers.observation_to_rgb(observation), layers.values())

    filter_results = analyze_all_filters(normal_activations, layer_number)
    state = heist.state_from_venv(venv, 0) 
    entity_regions, _ = get_entity_regions(state)

    matches = match_activations_with_entities(filter_results, entity_regions)

    for match in matches:
        match_histogram[match['filter_number']][match['entity_type']] += 1

plot_filter_activation_matches(match_histogram, layer_number)


Histogram of matches between filter activations and entity regions:
Empty DataFrame
Columns: []
Index: []


TypeError: no numeric data to plot

<Figure size 1200x800 with 0 Axes>

In [115]:
classified_dataset = heist.create_classified_dataset(num_samples_per_category=4, num_levels=0)
# Combine examples from all categories into one big list
combined_dataset = []

for category, examples in classified_dataset.items():
    combined_dataset.extend(examples)

# Print the shape of the combined dataset
print(f"Combined dataset shape: {np.array(combined_dataset).shape}")


Combined dataset shape: (28, 64, 64, 3)
