In [1]:
import sys
import os
import torch.distributions
import torch
import gym
import random
import numpy as np
from procgen import ProcgenGym3Env
import imageio
import matplotlib.pyplot as plt
import typing
import math
from collections import defaultdict
from typing import Tuple, Dict, Callable, List, Optional
from dataclasses import dataclass
# from src.policies_modified import ImpalaCNN
from procgen_tools.procgen_wrappers import VecExtractDictObs, TransposeFrame, ScaledFloatFrame
from gym3 import ToBaselinesVecEnv
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns


# Import custom modules
import heist
import helpers
from helpers import generate_action, load_model

# Reload modules automatically
%load_ext autoreload
%autoreload 2



building procgen...done


In [2]:
# Load the trained policy for the chosen difficulty and derive what the
# activation-collection step (next cell) needs from it.
difficulty = "hard"
model = helpers.load_model(model_path="../model_{}.pt".format(difficulty))
# Wrapper around `model`; it is handed to helpers.get_objective_activations in the next cell.
model_activations = helpers.ModelActivations(model)
# Layer names of `model`, also consumed by helpers.get_objective_activations.
layer_paths = helpers.get_model_layer_names(model)

In [3]:
# Collect per-layer activations grouped by objective name (e.g. 'gem'); the
# layer keys recorded for 'gem' are printed in the next cell.
# NOTE(review): 512 is presumably the number of samples per objective — confirm in helpers.
objective_activations_dataset = helpers.get_objective_activations(
    model_activations,
    layer_paths,
    512,
)


In [4]:
# Sanity check: list the layer names recorded for the 'gem' objective.
print(objective_activations_dataset["gem"].keys())

dict_keys(['conv_seqs_0_conv', 'conv_seqs_0_max_pool2d', 'conv_seqs_0_res_block0_conv0', 'conv_seqs_0_res_block0_conv1', 'conv_seqs_0_res_block0', 'conv_seqs_0_res_block1_conv0', 'conv_seqs_0_res_block1_conv1', 'conv_seqs_0_res_block1', 'conv_seqs_0', 'conv_seqs_1_conv', 'conv_seqs_1_max_pool2d', 'conv_seqs_1_res_block0_conv0', 'conv_seqs_1_res_block0_conv1', 'conv_seqs_1_res_block0', 'conv_seqs_1_res_block1_conv0', 'conv_seqs_1_res_block1_conv1', 'conv_seqs_1_res_block1', 'conv_seqs_1', 'conv_seqs_2_conv', 'conv_seqs_2_max_pool2d', 'conv_seqs_2_res_block0_conv0', 'conv_seqs_2_res_block0_conv1', 'conv_seqs_2_res_block0', 'conv_seqs_2_res_block1_conv0', 'conv_seqs_2_res_block1_conv1', 'conv_seqs_2_res_block1', 'conv_seqs_2', 'hidden_fc', 'logits_fc', 'value_fc'])


In [5]:

def train_probes(
    objective_activations_dataset: Dict[str, Dict[str, typing.Union[torch.Tensor, typing.Sequence[torch.Tensor]]]],
    layers_to_probe: Optional[typing.Iterable[str]] = None,
    test_size: float = 0.3,
    random_state: int = 42,
    max_iter: int = 40000,
) -> Tuple[Dict[str, float], Dict[str, Dict[str, float]]]:
    '''
    Train one logistic-regression probe per layer to predict the current objective from activations.

    For every probed layer, the activations of all objectives are flattened per sample,
    split into train/test sets with a fixed seed, and a probe is fit. Overall and per-class
    test accuracy are printed. A confusion matrix is plotted per layer, and summary bar
    charts are drawn at the end.

    Args:
    - objective_activations_dataset: dict mapping objective name (e.g. 'gem' or 'red_lock')
      to a dict mapping layer name to the activations recorded for that layer on inputs where
      the player is pursuing that objective. Each entry may be a sequence of per-sample
      tensors (stacked here) or an already-stacked tensor whose first dim indexes samples.
    - layers_to_probe: optional collection of layer names to restrict probing to;
      None (default) probes every layer.
    - test_size: fraction of samples held out for evaluation (default 0.3).
    - random_state: seed for the train/test split and the probe (default 42).
    - max_iter: maximum solver iterations for LogisticRegression (default 40000).

    Returns:
    - accuracies: dict mapping layer name to the probe's overall test accuracy.
    - class_accuracies: dict mapping layer name to a dict mapping objective name to that
      objective's per-class accuracy (recall: fraction of its test samples predicted
      correctly). Objectives with no test samples for a layer are omitted.
    '''
    objectives = list(objective_activations_dataset.keys())
    accuracies: Dict[str, float] = {}
    class_accuracies: Dict[str, Dict[str, float]] = {}
    if not objectives:
        return accuracies, class_accuracies

    # Layer names are read from one reference objective: 'gem' when present (the
    # historical behaviour), otherwise the first objective, instead of a hard KeyError.
    reference_objective = 'gem' if 'gem' in objective_activations_dataset else objectives[0]
    # Materialise once: repeated `in` tests would otherwise consume an iterator argument.
    wanted_layers = None if layers_to_probe is None else set(layers_to_probe)

    for layer in objective_activations_dataset[reference_objective].keys():
        if wanted_layers is not None and layer not in wanted_layers:
            continue

        features, labels = _build_layer_dataset(objective_activations_dataset, objectives, layer)
        train_data, test_data, train_labels, test_labels = train_test_split(
            features, labels, test_size=test_size, random_state=random_state
        )

        probe = LogisticRegression(random_state=random_state, max_iter=max_iter)
        probe.fit(train_data, train_labels)
        predictions = probe.predict(test_data)

        accuracy = float(accuracy_score(test_labels, predictions))
        accuracies[layer] = accuracy

        # Rows are true labels, columns predictions, both in `objectives` order.
        cm = confusion_matrix(test_labels, predictions, labels=objectives)
        class_accuracies[layer] = _per_class_accuracy(cm, objectives)

        print(f'Layer: {layer}, Overall Accuracy: {accuracy}')
        for objective, class_accuracy in class_accuracies[layer].items():
            print(f'  Objective: {objective}, Accuracy: {class_accuracy}')

        _plot_confusion_matrix(cm, objectives, layer)

    if not accuracies:
        # Nothing was probed; skip the summary plots (seaborn cannot plot empty data).
        return accuracies, class_accuracies

    # Widen the summary chart with the number of layers so tick labels stay legible.
    _plot_accuracy_bars(
        accuracies,
        title='Probe Accuracies by Layer',
        xlabel='Layer',
        figsize=(max(5.0, 0.35 * len(accuracies)), 5),
    )
    for layer, per_class in class_accuracies.items():
        if per_class:
            _plot_accuracy_bars(
                per_class,
                title=f'Class Accuracies for Layer {layer}',
                xlabel='Objective',
                figsize=(10, 5),
            )

    return accuracies, class_accuracies


def _build_layer_dataset(objective_activations_dataset, objectives, layer):
    '''Stack, flatten and concatenate one layer's activations over all objectives; return (features ndarray, labels list).'''
    feature_blocks = []
    labels = []
    for objective in objectives:
        activations = objective_activations_dataset[objective][layer]
        # Entries are normally a tuple/list of per-sample tensors; accept a pre-stacked tensor too.
        if not isinstance(activations, torch.Tensor):
            activations = torch.stack(list(activations))
        activations = activations.reshape(activations.shape[0], -1)
        # sklearn needs CPU arrays without autograd history.
        feature_blocks.append(activations.detach().cpu())
        labels += [objective] * activations.shape[0]
    return torch.cat(feature_blocks).numpy(), labels


def _per_class_accuracy(cm, objectives):
    '''Per-class accuracy (recall) from a confusion matrix whose rows are true labels in `objectives` order.'''
    support = cm.sum(axis=1)
    return {
        objective: float(cm[i, i] / support[i])
        for i, objective in enumerate(objectives)
        if support[i] > 0  # recall is undefined for classes absent from the test set
    }


def _plot_confusion_matrix(cm, objectives, layer):
    '''Display the confusion matrix of one layer's probe.'''
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=objectives)
    disp.plot(cmap=plt.cm.Blues)
    disp.ax_.set_title(f'Confusion Matrix for Layer {layer}')
    disp.ax_.tick_params(axis='x', labelrotation=45)
    plt.show()


def _plot_accuracy_bars(values, title, xlabel, figsize):
    '''Draw a bar chart of accuracies keyed by name, on a fixed [0, 1] y-axis so charts are comparable.'''
    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(x=list(values.keys()), y=list(values.values()), ax=ax)
    ax.set_title(title)
    ax.set_ylabel('Accuracy')
    ax.set_xlabel(xlabel)
    ax.set_ylim(0, 1)
    ax.tick_params(axis='x', labelrotation=45)
    plt.show()



# Probe every recorded layer (the default; passing all of 'gem''s layer keys is equivalent).
# Trailing `;` keeps any returned accuracy dicts from being dumped as cell output.
train_probes(objective_activations_dataset);
