1. __init__ (Initialization)
Original Usage: Initializes the experiment with the configuration for metrics, model, dataset, and hooks.
RL Application: Initialize with the RL model, environment, and possibly a policy or value network specifically targeted for modification.

2. verify_model_setup (Model Configuration Verification)
Original Usage: Checks that the model configuration is correct for the intended experiment setup.
RL Application: Verify that the RL model or specific network components (like CNN or MLP layers within the policy network) are accessible and correctly configured for hooking and manipulation.

3. update_cur_metric (Metric Update)
Original Usage: Update and log the current metric based on the dataset.
RL Application: Update based on RL-specific metrics such as average rewards, entropy of the policy, or other performance indicators after an episode or batch of interactions.

4. reverse_topologically_sort_corr (Reverse Topological Sorting)
Original Usage: Organizes the model's computational graph in a reverse topological order.
RL Application: Potentially useful for analyzing how outputs depend on earlier layers in deep networks, which helps decide which layers to target for interventions.

5. sender_hook and receiver_hook (Manage Data Flow)
Original Usage: Attach hooks to the model to intercept and modify data as it flows through the model.
RL Application: Attach hooks to the policy or value networks to modify activations or weights dynamically during training, affecting the agent’s decision-making process.

6. add_all_sender_hooks and add_sender_hook (Add Hooks to Nodes)
Original Usage: Add hooks dynamically to specific nodes or layers in the model based on certain conditions.
RL Application: Use in a similar fashion to intercept and possibly modify the computation in neural networks, e.g., zeroing out activations of certain neurons to study their impact on agent behavior.

7. setup_corrupted_cache (Cache Setup for Modifications)
Original Usage: Set up a cache system to store corrupted (modified) activations.
RL Application: Useful for experiments where part of the agent’s observations or internal state representations are systematically altered to assess robustness or identify critical information pathways.

8. setup_model_hooks (Configure Hooks)
Original Usage: Configures the hooks based on the experiment’s needs.
RL Application: Could be adapted to dynamically modify how data is processed within the RL model, for instance, to implement dropout or noise injection for robustness testing.

9. step (Process One Node)
Original Usage: Processes one computational node, evaluates its impact, and decides whether to keep the current configuration.
RL Application: Translates to processing one step or episode in RL, evaluating the agent's performance, and deciding whether to keep the modifications.

10. remove_redundant_node (Cleanup)
Original Usage: Remove nodes that are found to be redundant based on the experiment's criteria.
RL Application: This might correspond to pruning neurons or layers that are found not to contribute to the agent’s performance.

11. increment_current_node (Node Navigation)
Original Usage: Move to the next node in the computational graph.
RL Application: Adapted to move through different components or layers of the RL model systematically during the experiment.

12. save_subgraph and load_subgraph (State Saving and Loading)
Original Usage: Save and load configurations of the computational graph.
RL Application: Save and load agent configurations or network states at various points to revert to known good settings or explore the effects of changes.
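Items 5 to 8 all rest on PyTorch forward hooks. The sketch below uses a hypothetical toy network (`policy`, `make_ablation_hook` and `cache` are illustrative names, not part of the original code) to show the two behaviours the experiment relies on: a hook can cache a layer's clean output, and a hook that returns a tensor replaces that layer's output, which is how zero-ablation works.

```python
import torch
import torch.nn as nn

# Hypothetical toy network standing in for the real policy/value network
torch.manual_seed(0)
policy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

cache = {}  # clean activations captured by the hook


def make_ablation_hook(name, ablate):
    """Build a forward hook that caches a layer's output and optionally zeroes it."""
    def hook(module, inputs, output):
        cache[name] = output.detach()
        if ablate:
            # A forward hook that returns a tensor replaces the module's output
            return torch.zeros_like(output)
        return None  # returning None leaves the output unchanged
    return hook


obs = torch.randn(2, 4)
clean_logits = policy(obs)

handle = policy[0].register_forward_hook(make_ablation_hook("fc0", ablate=True))
ablated_logits = policy(obs)
handle.remove()  # detach the hook so later forward passes are unaffected

# With the first layer zeroed, ReLU(0) = 0, so the logits reduce to the last layer's bias
print(torch.allclose(ablated_logits, policy[2].bias.expand(2, 3)))  # True
```

Removing the handle after each experiment matters: hooks stay attached to the module until `remove()` is called, so a forgotten hook silently affects every later forward pass.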

In [12]:
# Import statements
import matplotlib.patches as patches
import networkx as nx
import heist
import helpers
import torch.distributions
import torch

import gym
import random
import numpy as np
from helpers import generate_action, load_model
import imageio
import matplotlib.pyplot as plt
import typing
import math

from procgen import ProcgenGym3Env
import struct
from typing import Tuple, Dict, Callable, List, Optional
from dataclasses import dataclass
from src.policies_modified import ImpalaCNN
from procgen_tools.procgen_wrappers import VecExtractDictObs, TransposeFrame, ScaledFloatFrame

from gym3 import ToBaselinesVecEnv
import seaborn as sns

%load_ext autoreload
%autoreload 2



Rough graph

'input' -> 'conv_seqs.0'


Conv Seq 0

Residual block 0 and 1
'conv_seqs.0' -> 'conv_seqs.0.conv' -> 'conv_seqs.0.max_pool2d' -> 'conv_seqs.0.res_block0'
'conv_seqs.0.res_block0' -> 'conv_seqs.0.res_block0.conv0' -> 'conv_seqs.0.res_block0.conv1' -> 'conv_seqs.0.res_block1'
'conv_seqs.0.res_block0' -> 'conv_seqs.0.res_block1' (skip connection)

'conv_seqs.0.res_block1' -> 'conv_seqs.0.res_block1.conv0' -> 'conv_seqs.0.res_block1.conv1' -> 'conv_seqs.1'
'conv_seqs.0.res_block1' -> 'conv_seqs.1' (skip connection)

Conv Seq 1

Residual block 0 and 1
'conv_seqs.1' -> 'conv_seqs.1.conv' -> 'conv_seqs.1.max_pool2d' -> 'conv_seqs.1.res_block0'
'conv_seqs.1.res_block0' -> 'conv_seqs.1.res_block0.conv0' -> 'conv_seqs.1.res_block0.conv1' -> 'conv_seqs.1.res_block1'
'conv_seqs.1.res_block0' -> 'conv_seqs.1.res_block1' (skip connection)

'conv_seqs.1.res_block1' -> 'conv_seqs.1.res_block1.conv0' -> 'conv_seqs.1.res_block1.conv1' -> 'conv_seqs.2'
'conv_seqs.1.res_block1' -> 'conv_seqs.2' (skip connection)

Conv Seq 2

Residual block 0 and 1
'conv_seqs.2' -> 'conv_seqs.2.conv' -> 'conv_seqs.2.max_pool2d' -> 'conv_seqs.2.res_block0'
'conv_seqs.2.res_block0' -> 'conv_seqs.2.res_block0.conv0' -> 'conv_seqs.2.res_block0.conv1' -> 'conv_seqs.2.res_block1'
'conv_seqs.2.res_block0' -> 'conv_seqs.2.res_block1' (skip connection)

'conv_seqs.2.res_block1' -> 'conv_seqs.2.res_block1.conv0' -> 'conv_seqs.2.res_block1.conv1' -> 'hidden_fc'
'conv_seqs.2.res_block1' -> 'hidden_fc' (skip connection)

Heads
'hidden_fc' -> 'logits_fc'
'hidden_fc' -> 'value_fc'
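The two edges leaving each residual block in the rough graph come from the block's forward pass. A minimal sketch, assuming the standard IMPALA residual block used by Procgen's ImpalaCNN (a ReLU before each convolution, then an identity skip); the `ImpalaCNN` imported from `src.policies_modified` may differ in detail:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """IMPALA-style residual block: ReLU -> conv0 -> ReLU -> conv1, plus an identity skip."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv0 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        inputs = x                  # skip path: res_blockK -> next node directly
        x = self.conv0(F.relu(x))   # conv path: res_blockK -> res_blockK.conv0
        x = self.conv1(F.relu(x))   #            res_blockK.conv0 -> res_blockK.conv1
        return x + inputs           # both paths are summed before the next node


block = ResidualBlock(channels=4)
# Zeroing conv1 cuts the conv path; the skip path still carries the input through
nn.init.zeros_(block.conv1.weight)
nn.init.zeros_(block.conv1.bias)
x = torch.randn(1, 4, 8, 8)
print(torch.allclose(block(x), x))  # True
```

Because the skip path survives when the conv path is removed, ablating `res_blockK.conv1` alone cannot disconnect `res_blockK` from the next node, which is why the graph needs both edges.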
 

In [13]:
from collections import defaultdict
from typing import List, Dict, Tuple

class Node:
    def __init__(self, name: str):
        self.name = name
        self.children = []

    def add_child(self, child_node):
        self.children.append(child_node)

class Graph:
    def __init__(self):
        self.nodes = {}

    def add_node(self, name: str):
        if name not in self.nodes:
            self.nodes[name] = Node(name)
        return self.nodes[name]

    def add_edge(self, from_node: str, to_node: str):
        node1 = self.add_node(from_node)
        node2 = self.add_node(to_node)
        node1.add_child(node2)

    def display(self):
        for name, node in self.nodes.items():
            print(f'Node {name} points to: {[child.name for child in node.children]}')

def build_graph():
    graph = Graph()

    # Names of the three ConvSequence modules, in forward order
    sequences = ['conv_seqs.0', 'conv_seqs.1', 'conv_seqs.2']

    # The observation only feeds the first sequence; each later sequence is reached
    # through the edges added at the end of the previous iteration
    graph.add_edge('input', sequences[0])

    # Define the internal module connections of each sequence
    for i, seq in enumerate(sequences):
        graph.add_edge(f'{seq}.res_block0', f'{seq}.res_block1')  # Skip connection

        # Connect within ConvSequence
        graph.add_edge(seq, f'{seq}.conv')
        graph.add_edge(f'{seq}.conv', f'{seq}.max_pool2d')
        graph.add_edge(f'{seq}.max_pool2d', f'{seq}.res_block0')
        graph.add_edge(f'{seq}.res_block0', f'{seq}.res_block0.conv0')
        graph.add_edge(f'{seq}.res_block0.conv0', f'{seq}.res_block0.conv1')
        graph.add_edge(f'{seq}.res_block0.conv1', f'{seq}.res_block1')

        graph.add_edge(f'{seq}.res_block1', f'{seq}.res_block1.conv0')
        graph.add_edge(f'{seq}.res_block1.conv0', f'{seq}.res_block1.conv1')

        # Connect to next sequence or to hidden_fc
        next_node = sequences[i + 1] if i < len(sequences) - 1 else 'hidden_fc'
        graph.add_edge(f'{seq}.res_block1.conv1', next_node)
        graph.add_edge(f'{seq}.res_block1', next_node)  # Skip connection

    # Connect hidden_fc to logits_fc and value_fc (add_edge creates missing nodes)
    graph.add_edge('hidden_fc', 'logits_fc')
    graph.add_edge('hidden_fc', 'value_fc')

    return graph

In [14]:
model_path = "../model_final.pt"
model = helpers.load_model(model_path=model_path)
graph = build_graph()
graph.display()


Node input points to: ['conv_seqs.0']
Node conv_seqs.0 points to: ['conv_seqs.0.conv']
Node conv_seqs.0.res_block0 points to: ['conv_seqs.0.res_block1', 'conv_seqs.0.res_block0.conv0']
Node conv_seqs.0.res_block1 points to: ['conv_seqs.0.res_block1.conv0', 'conv_seqs.1']
Node conv_seqs.0.conv points to: ['conv_seqs.0.max_pool2d']
Node conv_seqs.0.max_pool2d points to: ['conv_seqs.0.res_block0']
Node conv_seqs.0.res_block0.conv0 points to: ['conv_seqs.0.res_block0.conv1']
Node conv_seqs.0.res_block0.conv1 points to: ['conv_seqs.0.res_block1']
Node conv_seqs.0.res_block1.conv0 points to: ['conv_seqs.0.res_block1.conv1']
Node conv_seqs.0.res_block1.conv1 points to: ['conv_seqs.1']
Node conv_seqs.1 points to: ['conv_seqs.1.conv']
Node conv_seqs.1.res_block0 points to: ['conv_seqs.1.res_block1', 'conv_seqs.1.res_block0.conv0']
Node conv_seqs.1.res_block1 points to: ['conv_seqs.1.res_block1.conv0', 'conv_seqs.2']
Node conv_seqs.1.conv points to: ['conv_seqs.1.ma
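Item 4's reverse topological order (outputs first, input last) can be produced by a depth-first search that appends each node only after all of its children, i.e. DFS post-order. A minimal sketch on a plain adjacency dict; `reverse_topological_order` and the `edges` example are illustrative names, not part of the original code:

```python
from typing import Dict, List


def reverse_topological_order(children: Dict[str, List[str]]) -> List[str]:
    """Return node names so that every node comes after all of its descendants."""
    order: List[str] = []
    visited = set()

    def dfs(name: str) -> None:
        visited.add(name)
        for child in children.get(name, []):
            if child not in visited:
                dfs(child)
        order.append(name)  # post-order: emitted only after its children

    for name in children:
        if name not in visited:
            dfs(name)
    return order


# One residual block from the rough graph, plus its downstream node
edges = {
    "res_block0": ["res_block0.conv0", "res_block1"],
    "res_block0.conv0": ["res_block0.conv1"],
    "res_block0.conv1": ["res_block1"],
    "res_block1": [],
}
print(reverse_topological_order(edges))
# ['res_block1', 'res_block0.conv1', 'res_block0.conv0', 'res_block0']
```

Applied to the graph built above, something like `reverse_topological_order({n: [c.name for c in node.children] for n, node in graph.nodes.items()})` should list `logits_fc` and `value_fc` before `input`. Note that prepending each post-order node instead of appending it yields the ordinary (input-first) topological order.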

In [15]:
from collections import defaultdict
from typing import List, Dict, Tuple

# Define necessary classes

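class TorchIndex:
    def __init__(self, indices):
        # Stored as a tuple so equal indices hash equally and can be used as dict keys
        self.indices = tuple(indices)

    def __hash__(self):
        return hash(self.indices)

    def __eq__(self, other):
        return isinstance(other, TorchIndex) and self.indices == other.indices

    def __repr__(self):
        return str(list(self.indices))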

class EdgeType:
    ADDITION = "ADDITION"
    PLACEHOLDER = "PLACEHOLDER"
    DIRECT_COMPUTATION = "DIRECT_COMPUTATION"

class Edge:
    def __init__(self, edge_type):
        self.edge_type = edge_type
        self.present = True

class TLACDCInterpNode:
    def __init__(self, name, index, incoming_edge_type):
        self.name = name
        self.index = index
        self.incoming_edge_type = incoming_edge_type
        self.children = []
        self.parents = []

    def __repr__(self):
        return f"{self.name}{self.index}"

    def _add_child(self, child_node):
        self.children.append(child_node)

    def _add_parent(self, parent_node):
        self.parents.append(parent_node)

class TLACDCCorrespondence:
    def __init__(self):
        self.graph = defaultdict(dict)
        self.edges = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

    def nodes(self) -> List[TLACDCInterpNode]:
        return [node for by_index_list in self.graph.values() for node in by_index_list.values()]

    def all_edges(self) -> Dict[Tuple[str, TorchIndex, str, TorchIndex], Edge]:
        big_dict = {}
        for child_name, rest1 in self.edges.items():
            for child_index, rest2 in rest1.items():
                for parent_name, rest3 in rest2.items():
                    for parent_index, edge in rest3.items():
                        assert edge is not None, (child_name, child_index, parent_name, parent_index, "Edges have been setup WRONG somehow...")
                        big_dict[(child_name, child_index, parent_name, parent_index)] = edge
        return big_dict

    def add_node(self, node: TLACDCInterpNode, safe=True):
        if safe:
            assert node not in self.nodes(), f"Node {node} already in graph"
        self.graph[node.name][node.index] = node

    def add_edge(self, parent_node: TLACDCInterpNode, child_node: TLACDCInterpNode, edge: Edge, safe=True):
        if safe:
            if parent_node not in self.nodes():
                self.add_node(parent_node)
            if child_node not in self.nodes():
                self.add_node(child_node)
        assert child_node.incoming_edge_type == edge.edge_type, (child_node.incoming_edge_type, edge.edge_type)
        parent_node._add_child(child_node)
        child_node._add_parent(parent_node)
        self.edges[child_node.name][child_node.index][parent_node.name][parent_node.index] = edge

    def remove_edge(self, child_name: str, child_index: TorchIndex, parent_name: str, parent_index: TorchIndex):
        edge = self.edges[child_name][child_index][parent_name][parent_index]
        edge.present = False
        del self.edges[child_name][child_index][parent_name][parent_index]
        if not self.edges[child_name][child_index][parent_name]:
            del self.edges[child_name][child_index][parent_name]
        if not self.edges[child_name][child_index]:
            del self.edges[child_name][child_index]
        if not self.edges[child_name]:
            del self.edges[child_name]
        parent = self.graph[parent_name][parent_index]
        child = self.graph[child_name][child_index]
        parent.children.remove(child)
        child.parents.remove(parent)

def setup_graph_from_model_using_channels(model):
    #TODO: slices on the channels are not really working
    correspondence = TLACDCCorrespondence()
    prev_node_channels = None

    for i, conv_seq in enumerate(model.conv_seqs):
        out_channels = conv_seq._out_channels
        conv_node_base = f"ConvSeq_{i}"

        # Add conv nodes and their connections
        current_conv_nodes = []
        for j in range(out_channels):
            conv_node = TLACDCInterpNode(name=f"{conv_node_base}_OutChan_{j}", index=TorchIndex([j]), incoming_edge_type=EdgeType.ADDITION)
            correspondence.add_node(conv_node)
            current_conv_nodes.append(conv_node)

            if prev_node_channels:
                for prev_node in prev_node_channels:
                    correspondence.add_edge(parent_node=prev_node, child_node=conv_node, edge=Edge(EdgeType.ADDITION))

        prev_node_channels = current_conv_nodes

        # Add residual blocks and their connections
        for res_block_idx in range(2):
            res_block_node_base = f"ResBlock_{i}_{res_block_idx}"
            current_res_block_nodes = []
            for j in range(out_channels):
                res_block_node = TLACDCInterpNode(name=f"{res_block_node_base}_Chan_{j}", index=TorchIndex([j]), incoming_edge_type=EdgeType.ADDITION)
                correspondence.add_node(res_block_node)
                current_res_block_nodes.append(res_block_node)

                if j < len(prev_node_channels):
                    correspondence.add_edge(parent_node=prev_node_channels[j], child_node=res_block_node, edge=Edge(EdgeType.ADDITION))

            prev_node_channels = current_res_block_nodes

    # Add fully connected (hidden) layer and connect to the last residual block's output
    fc_node = TLACDCInterpNode(name="Hidden_FC", index=TorchIndex([0]), incoming_edge_type=EdgeType.ADDITION)
    correspondence.add_node(fc_node)

    for prev_node in prev_node_channels:
        correspondence.add_edge(parent_node=prev_node, child_node=fc_node, edge=Edge(EdgeType.ADDITION))

    # Add logits and value fully connected layers and connect them to the hidden FC layer
    logits_node = TLACDCInterpNode(name="Logits_FC", index=TorchIndex([0]), incoming_edge_type=EdgeType.ADDITION)
    correspondence.add_node(logits_node)
    correspondence.add_edge(parent_node=fc_node, child_node=logits_node, edge=Edge(EdgeType.ADDITION))

    value_node = TLACDCInterpNode(name="Value_FC", index=TorchIndex([0]), incoming_edge_type=EdgeType.ADDITION)
    correspondence.add_node(value_node)
    correspondence.add_edge(parent_node=fc_node, child_node=value_node, edge=Edge(EdgeType.ADDITION))

    return correspondence
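`setup_graph_from_model_using_channels` creates one node per output channel, and its TODO notes that channel slices do not work yet. One way to ablate a single channel node, sketched here under the assumption that the node `ConvSeq_i_OutChan_j` corresponds to channel `j` (dim 1) of a conv layer's NCHW output, is a forward hook that zeroes only that slice. `ablate_channels` is a hypothetical helper:

```python
import torch
import torch.nn as nn


def ablate_channels(module: nn.Module, channels):
    """Hypothetical helper: zero the given output channels (dim 1 of an NCHW output)."""
    def hook(mod, inputs, output):
        patched = output.clone()   # copy so the original output tensor is not modified in place
        patched[:, channels] = 0.0
        return patched             # the returned tensor replaces the module's output
    return module.register_forward_hook(hook)


torch.manual_seed(0)
conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
x = torch.randn(2, 3, 8, 8)

handle = ablate_channels(conv, [0, 5])
y = conv(x)          # channels 0 and 5 are zeroed, all others untouched
handle.remove()

clean = conv(x)      # hook removed, so this is the unmodified output
```

On the loaded model this would be called as, for example, `ablate_channels(model.conv_seqs[0].conv, [j])`, removing the returned handle once the measurement is done.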








In [16]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np

# Dictionary of ordered layer names
ordered_layer_names = {
 0: 'conv_seqs',
 1: 'conv_seqs.0',
 2: 'conv_seqs.0.conv',
 3: 'conv_seqs.0.max_pool2d',
 4: 'conv_seqs.0.res_block0',
 5: 'conv_seqs.0.res_block0.conv0',
 6: 'conv_seqs.0.res_block0.conv1',
 7: 'conv_seqs.0.res_block1',
 8: 'conv_seqs.0.res_block1.conv0',
 9: 'conv_seqs.0.res_block1.conv1',
 10: 'conv_seqs.1',
 11: 'conv_seqs.1.conv',
 12: 'conv_seqs.1.max_pool2d',
 13: 'conv_seqs.1.res_block0',
 14: 'conv_seqs.1.res_block0.conv0',
 15: 'conv_seqs.1.res_block0.conv1',
 16: 'conv_seqs.1.res_block1',
 17: 'conv_seqs.1.res_block1.conv0',
 18: 'conv_seqs.1.res_block1.conv1',
 19: 'conv_seqs.2',
 20: 'conv_seqs.2.conv',
 21: 'conv_seqs.2.max_pool2d',
 22: 'conv_seqs.2.res_block0',
 23: 'conv_seqs.2.res_block0.conv0',
 24: 'conv_seqs.2.res_block0.conv1',
 25: 'conv_seqs.2.res_block1',
 26: 'conv_seqs.2.res_block1.conv0',
 27: 'conv_seqs.2.res_block1.conv1',
 28: 'hidden_fc',
 29: 'logits_fc',
 30: 'value_fc'
}

# Extracting all the names into a list
layer_names = list(ordered_layer_names.values())


class ModelActivations:
    def __init__(self, model):
        self.activations = {}
        self.model = model
        self.hooks = []  # To keep track of hooks
        self.layer_paths = layer_names

    def clear_hooks(self):
        # Remove all previously registered hooks
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        self.activations = {}

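    def get_activation(self, name):
        def hook(model, input, output):
            # Conv/Linear layers return a single tensor: store it directly rather than
            # iterating over it, which would split it along the batch dimension
            if isinstance(output, torch.Tensor):
                self.activations[name] = output.detach()
                return
            processed_output = []
            for item in output:
                if isinstance(item, torch.Tensor):
                    processed_output.append(item.detach())
                elif isinstance(item, torch.distributions.Categorical):
                    processed_output.append(item.logits.detach())
                else:
                    processed_output.append(item)
            self.activations[name] = tuple(processed_output)
        return hook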

    def register_hook_by_path(self, path, name):
        elements = path.split('.')
        model = self.model
        for i, element in enumerate(elements):
            if '[' in element:
                base, index = element.replace(']', '').split('[')
                index = int(index)
                model = getattr(model, base)[index]
            else:
                model = getattr(model, element)
            if i == len(elements) - 1:
                hook = model.register_forward_hook(self.get_activation(name))
                self.hooks.append(hook)  # Keep track of the hook

    def run_with_cache(self, input):
        self.clear_hooks()  # Clear any existing hooks
        self.activations = {}  # Reset activations
        for path in self.layer_paths:
            self.register_hook_by_path(path, path.replace('.', '_'))
        output = self.model(input)
        return output, self.activations
    
    def patch_activations(self, input, to_patch_activation_tensor=None, to_change_layer_name='conv_seqs.0.conv', ablate=True):
        """Run the model while caching every layer's output. Only `to_change_layer_name`
        is modified: zeroed if `ablate` is True, otherwise replaced by
        `to_patch_activation_tensor`, which must match the layer's output shape."""
        cached_activations = {}

        # Define a factory function to create hook functions with a stable layer name
        def make_hook(layer_name):
            def saves_cache(module, input, output):
                cached_activations[layer_name] = output.detach()
                if layer_name != to_change_layer_name:
                    return None  # leave every other layer untouched
                if ablate:
                    return torch.zeros_like(output)
                if to_patch_activation_tensor is None or to_patch_activation_tensor.shape != output.shape:
                    raise ValueError(f"Patch for {layer_name} must have shape {tuple(output.shape)}")
                return to_patch_activation_tensor
            return saves_cache

        # Function to recursively get a sub-module from its name
        def get_submodule(module, submodule_name):
            for name in submodule_name.split('.'):
                module = getattr(module, name)
            return module

        self.clear_hooks()
        for name in layer_names:
            layer = get_submodule(self.model, name)
            hook = make_hook(name)  # Create a hook with a stable layer name
            self.hooks.append(layer.register_forward_hook(hook))

        try:
            output = self.model(input)
        finally:
            self.clear_hooks()  # never leave patching hooks attached, even on error

        return output, cached_activations

class RLPolicyExperiment(BaseCallback):
    """
    A class to manage an experiment that adjusts and analyzes the components of an RL policy
    network dynamically during training, akin to the TLACDCExperiment for neural networks.
    """

    def __init__(self, model, env, verbose=1, log_dir=None, abs_value_threshold=False,
                 using_wandb=False, wandb_entity_name=None, wandb_group_name=None,
                 wandb_project_name=None, wandb_run_name=None, wandb_notes=None,
                 wandb_dir=None, wandb_mode='online', wandb_config=None, threshold=200,
                 parallel_hypotheses=1, metric=lambda x: x.mean().item(), graph=None):
        super().__init__(verbose)  # BaseCallback stores `verbose` itself
        self.model = model
        self.model_activations = ModelActivations(model)
        self.env = env
        self.abs_value_threshold = abs_value_threshold
        self.step_idx = 0
        self.using_wandb = using_wandb

        # Node definitions are based on the structure from build_graph(), passed in
        # explicitly instead of read from a global
        self.graph = graph if graph is not None else build_graph()
        self.threshold = threshold
        self.removed_edges = set()
        # Candidate nodes in ACDC order (outputs first); only real submodules can be hooked
        self.order = [name for name in self.reverse_topologically_sort_corr() if name in layer_names]

        # Metrics and hypothesis setup
        self.metric = metric

    def reverse_topologically_sort_corr(self):
        """Return node names so that every node comes after all of its descendants
        (outputs first, 'input' last)."""
        order = []
        visited = set()

        # Helper function to perform DFS
        def dfs(node):
            visited.add(node.name)
            for child in node.children:
                if child.name not in visited:
                    dfs(child)
            # Post-order append: a node is emitted only after all of its children
            order.append(node.name)

        # Perform DFS from each unvisited node
        for node in self.graph.nodes.values():
            if node.name not in visited:
                dfs(node)
        return order

    def _on_step(self) -> bool:
        """Test one node per environment step: zero-ablate it and drop its edges if
        the policy logits move by no more than `self.threshold`."""
        if self.step_idx >= len(self.order):
            return True  # every candidate node has already been tested
        node_name = self.order[self.step_idx]
        self.step_idx += 1

        # BaseCallback.init_callback() replaces self.model with the SB3 algorithm,
        # so use the policy network held by ModelActivations
        net = self.model_activations.model
        # Assumption: SB3 exposes the latest observation as self.locals["new_obs"],
        # already in the layout this network expects
        obs = torch.as_tensor(self.locals["new_obs"], dtype=torch.float32)

        with torch.no_grad():
            # Assumption: the network returns (Categorical distribution, value)
            original_logits = net(obs)[0].logits
            handle = net.get_submodule(node_name).register_forward_hook(
                lambda module, inputs, output: torch.zeros_like(output))
            try:
                modified_logits = net(obs)[0].logits
            finally:
                handle.remove()

        # Check whether the change in logits exceeds the threshold
        change = (modified_logits - original_logits).abs().max().item()

        if change <= self.threshold:
            # The node is not important: remove all edges to and from it
            node = self.graph.nodes[node_name]
            for parent in self.graph.nodes.values():
                if node in parent.children:
                    parent.children.remove(node)
                    self.removed_edges.add((parent.name, node_name))
            for child in node.children:
                self.removed_edges.add((node_name, child.name))
            node.children = []
        return True  # returning False would stop training
        



    def verify_model_setup(self):
        """
        Verifies that every layer we plan to hook exists and has the expected type
        """
        model = self.model_activations.model  # the wrapped policy network, not the global `model`
        assert isinstance(model.conv_seqs, torch.nn.ModuleList), "conv_seqs should be an instance of torch.nn.ModuleList"
    
        # Loop through each ConvSequence in conv_seqs
        for i, conv_seq in enumerate(model.conv_seqs):
            # Check convolution layer
            assert isinstance(conv_seq.conv, torch.nn.Conv2d), f"conv in ConvSequence {i} should be an instance of torch.nn.Conv2d"
            # Check max pooling layer
            assert isinstance(conv_seq.max_pool2d, torch.nn.MaxPool2d), f"max_pool2d in ConvSequence {i} should be an instance of torch.nn.MaxPool2d"
            # Check each residual block
            for j, res_block in enumerate([conv_seq.res_block0, conv_seq.res_block1]):
                # Check both convolutional layers within the residual block
                assert isinstance(res_block.conv0, torch.nn.Conv2d), f"conv0 in ResidualBlock {j} of ConvSequence {i} should be an instance of torch.nn.Conv2d"
                assert isinstance(res_block.conv1, torch.nn.Conv2d), f"conv1 in ResidualBlock {j} of ConvSequence {i} should be an instance of torch.nn.Conv2d"
    
        # Check other components of ImpalaCNN
        assert isinstance(model.hidden_fc, torch.nn.Linear), "hidden_fc should be an instance of torch.nn.Linear"
        assert isinstance(model.logits_fc, torch.nn.Linear), "logits_fc should be an instance of torch.nn.Linear"
        assert isinstance(model.value_fc, torch.nn.Linear), "value_fc should be an instance of torch.nn.Linear"
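The idea behind `_on_step` (ablate one node, and drop it if the logits move by no more than a threshold) can be tried end to end on a toy network before wiring it into training. This is a greedy simplification, not the full ACDC procedure: nodes are tested in the order given, pruned nodes stay ablated for later tests, and the metric is the maximum absolute logit change. `ToyPolicy`, `zero_output` and `prune_by_ablation` are hypothetical names; `get_submodule` requires PyTorch 1.9 or later.

```python
import torch
import torch.nn as nn


class ToyPolicy(nn.Module):
    """Hypothetical two-branch network; branch `b` is scaled down to be almost irrelevant."""

    def __init__(self):
        super().__init__()
        self.a = nn.Linear(4, 4)
        self.b = nn.Linear(4, 4)
        self.head = nn.Linear(4, 3)
        with torch.no_grad():
            self.b.weight.mul_(1e-4)
            self.b.bias.zero_()

    def forward(self, x):
        return self.head(torch.relu(self.a(x)) + self.b(x))


def zero_output(module, inputs, output):
    return torch.zeros_like(output)  # forward hook: replace the output with zeros


def prune_by_ablation(net, names, obs, threshold):
    """Greedily zero-ablate each named submodule; keep the ablation if the logits barely move."""
    pruned, kept_handles = [], []
    with torch.no_grad():
        baseline = net(obs)
        for name in names:
            handle = net.get_submodule(name).register_forward_hook(zero_output)
            change = (net(obs) - baseline).abs().max().item()
            if change <= threshold:
                pruned.append(name)          # redundant: leave it ablated for later tests
                kept_handles.append(handle)
            else:
                handle.remove()              # important: restore it before the next test
    for handle in kept_handles:
        handle.remove()                      # restore the network once the search is done
    return pruned


torch.manual_seed(0)
toy = ToyPolicy()
toy_obs = torch.randn(16, 4)
print(prune_by_ablation(toy, ["b", "a"], toy_obs, threshold=1e-2))
```

With `b` scaled down by 1e-4, ablating it should barely move the logits, while ablating `a` as well removes almost all of the signal, so only `b` is expected to be pruned.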

 

   

  


In [24]:
model_activations = ModelActivations(model)

start_level = random.randint(1, 10000)
venv = heist.create_venv(num=1, num_levels=1, start_level=start_level)
obs = venv.reset()
# The network calls .permute() on its input, so it needs a torch tensor rather than the
# numpy array returned by helpers.observation_to_rgb (float32 is an assumed dtype)
obs_tensor = torch.as_tensor(helpers.observation_to_rgb(obs), dtype=torch.float32)
#output, cache = model_activations.patch_activations(obs_tensor, to_change_layer_name='conv_seqs.0.conv', ablate=True)
output, cache = model_activations.run_with_cache(obs_tensor)