1. __init__ (Initialization)
Original Usage: Initializes experiment with configuration for metrics, model, dataset, and hooks.
RL Application: Initialize with the RL model, environment, and possibly a policy or value network specifically targeted for modification.

2. verify_model_setup (Model Configuration Verification)
Original Usage: Checks that the model configuration is correct for the intended experiment setup.
RL Application: Verify that the RL model or specific network components (like CNN or MLP layers within the policy network) are accessible and correctly configured for hooking and manipulation.

3. update_cur_metric (Metric Update)
Original Usage: Update and log the current metric based on the dataset.
RL Application: Update based on RL-specific metrics such as average rewards, entropy of the policy, or other performance indicators after an episode or batch of interactions.

4. reverse_topologically_sort_corr (Reverse Topological Sorting)
Original Usage: Organizes the model's computational graph in a reverse topological order.
RL Application: Potentially useful for analyzing the dependency of outputs on previous layers in deep networks, helping to decide which layers to target for interventions.

5. sender_hook and receiver_hook (Manage Data Flow)
Original Usage: Attach hooks to the model to intercept and modify data as it flows through the model.
RL Application: Attach hooks to the policy or value networks to intercept and modify activations during rollouts or training, which changes the agent’s decision-making process.

6. add_all_sender_hooks and add_sender_hook (Add Hooks to Nodes)
Original Usage: Add hooks dynamically to specific nodes or layers in the model based on certain conditions.
RL Application: Use in a similar fashion to intercept and modify the computation in the network, e.g., zeroing out the activations of certain neurons or channels to study their impact on agent behavior.

7. setup_corrupted_cache (Cache Setup for Modifications)
Original Usage: Set up a cache system to store corrupted (modified) activations.
RL Application: Useful for experiments where part of the agent’s observations or internal state representations are systematically altered to assess robustness or identify critical information pathways.

8. setup_model_hooks (Configure Hooks)
Original Usage: Configures the hooks based on the experiment’s needs.
RL Application: Could be adapted to dynamically modify how data is processed within the RL model, for instance, to implement dropout or noise injection for robustness testing.

9. step (Process One Node)
Original Usage: Processes one computational node, evaluates its impact, and decides whether to keep the current configuration.
RL Application: Translate to processing one step or episode in RL, evaluating the agent's performance, and deciding whether to keep the modifications.

10. remove_redundant_node (Cleanup)
Original Usage: Remove nodes that are found to be redundant based on the experiment's criteria.
RL Application: This might correspond to pruning neurons, channels, or layers that are found not to contribute to the agent’s performance.

11. increment_current_node (Node Navigation)
Original Usage: Move to the next node in the computational graph.
RL Application: Adapted to move through different components or layers of the RL model systematically during the experiment.

12. save_subgraph and load_subgraph (State Saving and Loading)
Original Usage: Save and load configurations of the computational graph.
RL Application: Save and load agent configurations or network states at various points to revert to known good settings or explore the effects of changes.
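
As an illustration of items 5 and 6, here is a minimal sketch of a hook that zeroes one output channel using PyTorch's `register_forward_hook`. The `toy_policy` network and the `make_zero_channel_hook` helper are assumptions made for this example; they are not the ImpalaCNN or the hook functions used later in this notebook.

```python
import torch
import torch.nn as nn

def make_zero_channel_hook(channel: int):
    """Return a forward hook that zeroes one output channel (hypothetical helper)."""
    def hook(module, inputs, output):
        patched = output.clone()
        patched[:, channel] = 0.0
        return patched  # a non-None return value replaces the module's output
    return hook

torch.manual_seed(0)
# Toy stand-in for a policy network; not the ImpalaCNN used in this notebook.
toy_policy = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 5),
)
obs = torch.rand(1, 3, 8, 8)

with torch.no_grad():
    clean_logits = toy_policy(obs)
    handle = toy_policy[0].register_forward_hook(make_zero_channel_hook(2))
    ablated_logits = toy_policy(obs)
    handle.remove()  # remove the hook so later passes are unaffected
    restored_logits = toy_policy(obs)

# Largest change in any logit caused by ablating channel 2 of the first conv layer.
effect = (clean_logits - ablated_logits).abs().max().item()
```

Comparing `clean_logits` with `ablated_logits` gives a per-observation measure of how much the ablated channel matters; in an RL setting the same comparison would be made on rewards or action distributions over many observations.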

In [2]:
#import statements
import matplotlib.patches as patches
import networkx as nx
import heist
import helpers
import torch.distributions
import torch

import gym
import random
import numpy as np
from helpers import generate_action, load_model
import imageio
import matplotlib.pyplot as plt
import typing
import math

from procgen import ProcgenGym3Env
import struct
from typing import Tuple, Dict, Callable, List, Optional
from dataclasses import dataclass
from src.policies_modified import ImpalaCNN
from procgen_tools.procgen_wrappers import VecExtractDictObs, TransposeFrame, ScaledFloatFrame

from gym3 import ToBaselinesVecEnv
import seaborn as sns

%load_ext autoreload
%autoreload 2

building procgen...done


Rough graph

input -> 'conv_seqs.0'

Conv Seq 0

Residual block 0 and 1
'conv_seqs.0' -> 'conv_seqs.0.conv' -> 'conv_seqs.0.max_pool2d' -> 'conv_seqs.0.res_block0'
'conv_seqs.0.res_block0' -> 'conv_seqs.0.res_block0.conv0' -> 'conv_seqs.0.res_block0.conv1' -> 'conv_seqs.0.res_block1'
'conv_seqs.0.res_block0' -> 'conv_seqs.0.res_block1' (skip connection)

'conv_seqs.0.res_block1' -> 'conv_seqs.0.res_block1.conv0' -> 'conv_seqs.0.res_block1.conv1' -> 'conv_seqs.1'
'conv_seqs.0.res_block1' -> 'conv_seqs.1' (skip connection)

Conv Seq 1

Residual block 0 and 1
'conv_seqs.1' -> 'conv_seqs.1.conv' -> 'conv_seqs.1.max_pool2d' -> 'conv_seqs.1.res_block0'
'conv_seqs.1.res_block0' -> 'conv_seqs.1.res_block0.conv0' -> 'conv_seqs.1.res_block0.conv1' -> 'conv_seqs.1.res_block1'
'conv_seqs.1.res_block0' -> 'conv_seqs.1.res_block1' (skip connection)

'conv_seqs.1.res_block1' -> 'conv_seqs.1.res_block1.conv0' -> 'conv_seqs.1.res_block1.conv1' -> 'conv_seqs.2'
'conv_seqs.1.res_block1' -> 'conv_seqs.2' (skip connection)

Conv Seq 2

Residual block 0 and 1
'conv_seqs.2' -> 'conv_seqs.2.conv' -> 'conv_seqs.2.max_pool2d' -> 'conv_seqs.2.res_block0'
'conv_seqs.2.res_block0' -> 'conv_seqs.2.res_block0.conv0' -> 'conv_seqs.2.res_block0.conv1' -> 'conv_seqs.2.res_block1'
'conv_seqs.2.res_block0' -> 'conv_seqs.2.res_block1' (skip connection)

'conv_seqs.2.res_block1' -> 'conv_seqs.2.res_block1.conv0' -> 'conv_seqs.2.res_block1.conv1' -> 'hidden_fc'
'conv_seqs.2.res_block1' -> 'hidden_fc' (skip connection)
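
A graph like the one above can be ordered from outputs back to inputs, which is what `reverse_topologically_sort_corr` is meant to provide. The sketch below does this for the first part of Conv Seq 0 using the standard-library `graphlib` module (Python 3.9+). The edge list is transcribed by hand from the rough graph and is an assumption of this example, not something read from the model.

```python
from graphlib import TopologicalSorter

# Edges (parent -> child) for the start of conv_seqs.0, copied from the rough graph.
edges = [
    ("input", "conv_seqs.0.conv"),
    ("conv_seqs.0.conv", "conv_seqs.0.max_pool2d"),
    ("conv_seqs.0.max_pool2d", "conv_seqs.0.res_block0"),
    ("conv_seqs.0.res_block0", "conv_seqs.0.res_block0.conv0"),
    ("conv_seqs.0.res_block0.conv0", "conv_seqs.0.res_block0.conv1"),
    ("conv_seqs.0.res_block0.conv1", "conv_seqs.0.res_block1"),
    ("conv_seqs.0.res_block0", "conv_seqs.0.res_block1"),  # skip connection
]

# TopologicalSorter expects a mapping {node: set of predecessors}.
predecessors = {}
for parent, child in edges:
    predecessors.setdefault(child, set()).add(parent)
    predecessors.setdefault(parent, set())

forward_order = list(TopologicalSorter(predecessors).static_order())
reverse_order = forward_order[::-1]  # outputs first, inputs last

for name in reverse_order:
    print(name)
```

Visiting nodes in `reverse_order` guarantees that a node is processed only after every node that consumes its output, which is the order in which an edge-pruning experiment typically walks the graph.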
 

In [3]:
from collections import defaultdict
from typing import List, Dict, Tuple

class Node:
    def __init__(self, name: str):
        self.name = name
        self.children = []
        self.parents = []

    def add_child(self, child_node):
        self.children.append(child_node)

    def add_parent(self, parent_node):
        self.parents.append(parent_node)

class Graph:
    def __init__(self):
        self.nodes = {}
        self.edges = defaultdict(list)

    def add_node(self, name: str):
        if name not in self.nodes:
            self.nodes[name] = Node(name)

    def add_edge(self, parent_name: str, child_name: str):
        if parent_name not in self.nodes:
            self.add_node(parent_name)
        if child_name not in self.nodes:
            self.add_node(child_name)

        parent_node = self.nodes[parent_name]
        child_node = self.nodes[child_name]

        parent_node.add_child(child_node)
        child_node.add_parent(parent_node)

        self.edges[parent_name].append(child_name)

    def display(self):
        for node in self.nodes.values():
            print(f"Node {node.name}")
            for child in node.children:
                print(f"  -> {child.name}")

def build_graph_from_model_using_channels(model: ImpalaCNN) -> Graph:
    graph = Graph()

    # Adding nodes and edges for ConvSequences
    for i, conv_seq in enumerate(model.conv_seqs):
        conv_name = f"conv_seqs.{i}.conv"
        max_pool_name = f"conv_seqs.{i}.max_pool2d"
        res_block0_name = f"conv_seqs.{i}.res_block0"
        res_block1_name = f"conv_seqs.{i}.res_block1"

        graph.add_node(conv_name)
        graph.add_node(max_pool_name)
        graph.add_node(res_block0_name)
        graph.add_node(res_block1_name)

        graph.add_edge(conv_name, max_pool_name)
        graph.add_edge(max_pool_name, res_block0_name)
        graph.add_edge(res_block0_name, res_block1_name)

    # Adding nodes and edges for FC layers
    hidden_fc_name = "hidden_fc"
    logits_fc_name = "logits_fc"
    value_fc_name = "value_fc"

    graph.add_node(hidden_fc_name)
    graph.add_node(logits_fc_name)
    graph.add_node(value_fc_name)

    # Connect last residual block to hidden_fc
    last_res_block_name = f"conv_seqs.{len(model.conv_seqs) - 1}.res_block1"
    graph.add_edge(last_res_block_name, hidden_fc_name)
    graph.add_edge(hidden_fc_name, logits_fc_name)
    graph.add_edge(hidden_fc_name, value_fc_name)

    return graph

def build_graph_from_model(model):
    graph = Graph()
    previous_res_block_output = None

    # Iterate over each ConvSequence in the model
    for i, conv_seq in enumerate(model.conv_seqs):
        # Define node names for each component in the ConvSequence
        conv_name = f"conv_seq_{i}_conv"
        max_pool_name = f"conv_seq_{i}_max_pool"
        res_block0_name = f"conv_seq_{i}_res_block0"
        res_block1_name = f"conv_seq_{i}_res_block1"

        # Add nodes to the graph
        graph.add_node(conv_name)
        graph.add_node(max_pool_name)
        graph.add_node(res_block0_name)
        graph.add_node(res_block1_name)

        # Connect the nodes according to the flow described
        graph.add_edge(conv_name, max_pool_name)
        graph.add_edge(max_pool_name, res_block0_name)
        graph.add_edge(res_block0_name, res_block1_name)

        # Connect the output of the last residual block of the previous sequence if it exists
        if previous_res_block_output:
            graph.add_edge(previous_res_block_output, conv_name)

        # Update the last output reference to the current sequence's last residual block
        previous_res_block_output = res_block1_name

    # Handle connections to fully connected layers
    hidden_fc_name = "hidden_fc"
    logits_fc_name = "logits_fc"
    value_fc_name = "value_fc"

    graph.add_node(hidden_fc_name)
    graph.add_node(logits_fc_name)
    graph.add_node(value_fc_name)

    # Connect the last residual block output to the hidden fully connected layer
    graph.add_edge(previous_res_block_output, hidden_fc_name)
    graph.add_edge(hidden_fc_name, logits_fc_name)
    graph.add_edge(hidden_fc_name, value_fc_name)

    return graph



In [4]:
model_path = "../model_final.pt"
model = helpers.load_model(model_path=model_path)
graph = build_graph_from_model(model)
graph.display()

Node conv_seq_0_conv
  -> conv_seq_0_max_pool
Node conv_seq_0_max_pool
  -> conv_seq_0_res_block0
Node conv_seq_0_res_block0
  -> conv_seq_0_res_block1
Node conv_seq_0_res_block1
  -> conv_seq_1_conv
Node conv_seq_1_conv
  -> conv_seq_1_max_pool
Node conv_seq_1_max_pool
  -> conv_seq_1_res_block0
Node conv_seq_1_res_block0
  -> conv_seq_1_res_block1
Node conv_seq_1_res_block1
  -> conv_seq_2_conv
Node conv_seq_2_conv
  -> conv_seq_2_max_pool
Node conv_seq_2_max_pool
  -> conv_seq_2_res_block0
Node conv_seq_2_res_block0
  -> conv_seq_2_res_block1
Node conv_seq_2_res_block1
  -> hidden_fc
Node hidden_fc
  -> logits_fc
  -> value_fc
Node logits_fc
Node value_fc


In [71]:
from collections import defaultdict
from typing import List, Dict, Tuple

# Define necessary classes

class TorchIndex:
    def __init__(self, indices):
        self.indices = indices

    def __repr__(self):
        return str(self.indices)

class EdgeType:
    ADDITION = "ADDITION"
    PLACEHOLDER = "PLACEHOLDER"
    DIRECT_COMPUTATION = "DIRECT_COMPUTATION"

class Edge:
    def __init__(self, edge_type):
        self.edge_type = edge_type
        self.present = True

class TLACDCInterpNode:
    def __init__(self, name, index, incoming_edge_type):
        self.name = name
        self.index = index
        self.incoming_edge_type = incoming_edge_type
        self.children = []
        self.parents = []

    def __repr__(self):
        return f"{self.name}{self.index}"

    def _add_child(self, child_node):
        self.children.append(child_node)

    def _add_parent(self, parent_node):
        self.parents.append(parent_node)

class TLACDCCorrespondence:
    def __init__(self):
        self.graph = defaultdict(dict)
        self.edges = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))

    def nodes(self) -> List[TLACDCInterpNode]:
        return [node for by_index_list in self.graph.values() for node in by_index_list.values()]

    def all_edges(self) -> Dict[Tuple[str, TorchIndex, str, TorchIndex], Edge]:
        big_dict = {}
        for child_name, rest1 in self.edges.items():
            for child_index, rest2 in rest1.items():
                for parent_name, rest3 in rest2.items():
                    for parent_index, edge in rest3.items():
                        assert edge is not None, (child_name, child_index, parent_name, parent_index, "Edges have been setup WRONG somehow...")
                        big_dict[(child_name, child_index, parent_name, parent_index)] = edge
        return big_dict

    def add_node(self, node: TLACDCInterpNode, safe=True):
        if safe:
            assert node not in self.nodes(), f"Node {node} already in graph"
        self.graph[node.name][node.index] = node

    def add_edge(self, parent_node: TLACDCInterpNode, child_node: TLACDCInterpNode, edge: Edge, safe=True):
        if safe:
            if parent_node not in self.nodes():
                self.add_node(parent_node)
            if child_node not in self.nodes():
                self.add_node(child_node)
        assert child_node.incoming_edge_type == edge.edge_type, (child_node.incoming_edge_type, edge.edge_type)
        parent_node._add_child(child_node)
        child_node._add_parent(parent_node)
        self.edges[child_node.name][child_node.index][parent_node.name][parent_node.index] = edge

    def remove_edge(self, child_name: str, child_index: TorchIndex, parent_name: str, parent_index: TorchIndex):
        edge = self.edges[child_name][child_index][parent_name][parent_index]
        edge.present = False
        del self.edges[child_name][child_index][parent_name][parent_index]
        if not self.edges[child_name][child_index][parent_name]:
            del self.edges[child_name][child_index][parent_name]
        if not self.edges[child_name][child_index]:
            del self.edges[child_name][child_index]
        if not self.edges[child_name]:
            del self.edges[child_name]
        parent = self.graph[parent_name][parent_index]
        child = self.graph[child_name][child_index]
        parent.children.remove(child)
        child.parents.remove(parent)

def setup_graph_from_model_using_channels(model):
    #TODO: slices on the channels are not really working
    correspondence = TLACDCCorrespondence()
    prev_node_channels = None

    for i, conv_seq in enumerate(model.conv_seqs):
        out_channels = conv_seq._out_channels
        conv_node_base = f"ConvSeq_{i}"

        # Add conv nodes and their connections
        current_conv_nodes = []
        for j in range(out_channels):
            conv_node = TLACDCInterpNode(name=f"{conv_node_base}_OutChan_{j}", index=TorchIndex([j]), incoming_edge_type=EdgeType.ADDITION)
            correspondence.add_node(conv_node)
            current_conv_nodes.append(conv_node)

            if prev_node_channels:
                for prev_node in prev_node_channels:
                    correspondence.add_edge(parent_node=prev_node, child_node=conv_node, edge=Edge(EdgeType.ADDITION))

        prev_node_channels = current_conv_nodes

        # Add residual blocks and their connections
        for res_block_idx in range(2):
            res_block_node_base = f"ResBlock_{i}_{res_block_idx}"
            current_res_block_nodes = []
            for j in range(out_channels):
                res_block_node = TLACDCInterpNode(name=f"{res_block_node_base}_Chan_{j}", index=TorchIndex([j]), incoming_edge_type=EdgeType.ADDITION)
                correspondence.add_node(res_block_node)
                current_res_block_nodes.append(res_block_node)

                if j < len(prev_node_channels):
                    correspondence.add_edge(parent_node=prev_node_channels[j], child_node=res_block_node, edge=Edge(EdgeType.ADDITION))

            prev_node_channels = current_res_block_nodes

    # Add fully connected (hidden) layer and connect to the last residual block's output
    fc_node = TLACDCInterpNode(name="Hidden_FC", index=TorchIndex([0]), incoming_edge_type=EdgeType.ADDITION)
    correspondence.add_node(fc_node)

    for prev_node in prev_node_channels:
        correspondence.add_edge(parent_node=prev_node, child_node=fc_node, edge=Edge(EdgeType.ADDITION))

    # Add logits and value fully connected layers and connect them to the hidden FC layer
    logits_node = TLACDCInterpNode(name="Logits_FC", index=TorchIndex([0]), incoming_edge_type=EdgeType.ADDITION)
    correspondence.add_node(logits_node)
    correspondence.add_edge(parent_node=fc_node, child_node=logits_node, edge=Edge(EdgeType.ADDITION))

    value_node = TLACDCInterpNode(name="Value_FC", index=TorchIndex([0]), incoming_edge_type=EdgeType.ADDITION)
    correspondence.add_node(value_node)
    correspondence.add_edge(parent_node=fc_node, child_node=value_node, edge=Edge(EdgeType.ADDITION))

    return correspondence
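
One detail worth noting about the cell above: `TorchIndex` is used as a dictionary key in `self.graph[node.name][node.index]` and in `remove_edge`, but it defines neither `__eq__` nor `__hash__`. Two separately created `TorchIndex([0])` objects are therefore different keys, and a lookup with a freshly built index raises `KeyError`. A minimal sketch of a value-hashable variant follows; the class name `HashableIndex` is hypothetical, and the sketch assumes integer indices only.

```python
from typing import Hashable, Sequence

class HashableIndex:
    """Hypothetical variant of TorchIndex that compares and hashes by value."""

    def __init__(self, indices: Sequence[Hashable]):
        # Store a tuple so the contents cannot change after hashing.
        self.indices = tuple(indices)

    def __eq__(self, other):
        if not isinstance(other, HashableIndex):
            return NotImplemented
        return self.indices == other.indices

    def __hash__(self):
        return hash(self.indices)

    def __repr__(self):
        return str(list(self.indices))

graph = {"ConvSeq_0_OutChan_0": {HashableIndex([0]): "node"}}
# A freshly built index with the same contents now finds the stored entry.
found = graph["ConvSeq_0_OutChan_0"][HashableIndex([0])]
```

With value-based hashing, `remove_edge` could be called with indices constructed at the call site instead of requiring the exact objects stored when the graph was built.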








In [72]:
#Let's test the above method:
model_path = "../model_final.pt"
model = helpers.load_model(model_path=model_path)
correspondence = setup_graph_from_model_using_channels(model)
print(correspondence.graph['ConvSeq_0_OutChan_0'])
# Print the graph to see nodes and edges
print("Nodes in the Graph:")
for node in correspondence.nodes():
    print(node)

print("\nEdges in the Graph:")
for edge, edge_obj in correspondence.all_edges().items():
    print(f"Edge from {edge[2]}{edge[3]} to {edge[0]}{edge[1]} of type {edge_obj.edge_type}")

{[0]: ConvSeq_0_OutChan_0[0]}
Nodes in the Graph:
ConvSeq_0_OutChan_0[0]
ConvSeq_0_OutChan_1[1]
ConvSeq_0_OutChan_2[2]
ConvSeq_0_OutChan_3[3]
ConvSeq_0_OutChan_4[4]
ConvSeq_0_OutChan_5[5]
ConvSeq_0_OutChan_6[6]
ConvSeq_0_OutChan_7[7]
ConvSeq_0_OutChan_8[8]
ConvSeq_0_OutChan_9[9]
ConvSeq_0_OutChan_10[10]
ConvSeq_0_OutChan_11[11]
ConvSeq_0_OutChan_12[12]
ConvSeq_0_OutChan_13[13]
ConvSeq_0_OutChan_14[14]
ConvSeq_0_OutChan_15[15]
ResBlock_0_0_Chan_0[0]
ResBlock_0_0_Chan_1[1]
ResBlock_0_0_Chan_2[2]
ResBlock_0_0_Chan_3[3]
ResBlock_0_0_Chan_4[4]
ResBlock_0_0_Chan_5[5]
ResBlock_0_0_Chan_6[6]
ResBlock_0_0_Chan_7[7]
ResBlock_0_0_Chan_8[8]
ResBlock_0_0_Chan_9[9]
ResBlock_0_0_Chan_10[10]
ResBlock_0_0_Chan_11[11]
ResBlock_0_0_Chan_12[12]
ResBlock_0_0_Chan_13[13]
ResBlock_0_0_Chan_14[14]
ResBlock_0_0_Chan_15[15]
ResBlock_0_1_Chan_0[0]
ResBlock_0_1_Chan_1[1]
ResBlock_0_1_Chan_2[2]
ResBlock_0_1_Chan_3[3]
ResBlock_0_1_Chan_4[4]
ResBlock_0_1_Chan_5[5]
ResBlock_0_1_Chan_6[6]
ResBlock_0_1_Chan_7[7]
ResBlo

Edges from ConvSeq_x_OutChan_y[z] to ResBlock_x_0_Chan_y[z] of type ADDITION:

These edges connect each output channel of a ConvSequence to the same-numbered channel in the first residual block of that ConvSequence (in the code, the index z always equals the channel number y).
Every edge in this graph is created with the ADDITION type, so the label alone does not identify residual (skip) connections.

Edges from ResBlock_x_0_Chan_y[z] to ResBlock_x_1_Chan_y[z] of type ADDITION:

These edges connect each channel of the first residual block to the same-numbered channel of the second residual block in the same ConvSequence.

Edges from ResBlock_x_1_Chan_y[z] to the output channels of ConvSeq_x+1 of type ADDITION:

These edges connect every channel of the second residual block to every output channel of the next ConvSequence.
The connection is all-to-all rather than channel-to-channel, so it is also defined when the next ConvSequence has more channels.

Edges from the last residual block (ResBlock_x_1 of the final ConvSequence) to Hidden_FC[0]:

These edges represent the transition from the convolutional feature maps to the fully connected layers.

Edges from Hidden_FC[0] to Logits_FC[0] and Value_FC[0]:

These edges connect the hidden fully connected layer to the logits and value fully connected layers.
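
As a sanity check on the printed edge list, the number of edges that `setup_graph_from_model_using_channels` creates can be computed directly from the channel counts. Reading that function, edges into ConvSeq channels are added for every (previous channel, new channel) pair, while edges into ResBlock channels are one per channel. The sketch below mirrors that logic; the channel list `[16, 32, 32]` is an assumption (only the 16 channels of ConvSeq_0 are visible in the output above).

```python
from typing import List

def count_edges(channels: List[int], res_blocks_per_seq: int = 2) -> int:
    """Count the edges produced by the channel-level graph construction."""
    total = 0
    prev = 0  # number of nodes in the previous layer (0 before the first ConvSeq)
    for out_channels in channels:
        total += prev * out_channels          # all-to-all into ConvSeq channels
        prev = out_channels
        for _ in range(res_blocks_per_seq):
            total += min(prev, out_channels)  # one-to-one into ResBlock channels
            prev = out_channels
    total += prev  # last ResBlock channels -> Hidden_FC
    total += 2     # Hidden_FC -> Logits_FC and Hidden_FC -> Value_FC
    return total

# Assumed channel widths for the three ConvSequences.
n_edges = count_edges([16, 32, 32])
print(n_edges)
```

If `len(correspondence.all_edges())` differs from this count for the model's actual channel widths, some edges were overwritten or dropped during construction.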

In [None]:
!pip install torchviz
import torch
from torchviz import make_dot

model = ImpalaCNN(obs_space=your_obs_space, num_outputs=your_num_outputs)
# Random pixel values in [0, 255], so that the division below yields inputs in [0, 1]
dummy_input = torch.randint(0, 256, (1,) + your_obs_space.shape)
dummy_input = dummy_input.to(torch.float32) / 255.0  # Normalize as your model expects

outputs = model(dummy_input)

# If model outputs a tuple, you can visualize the first part as an example
dot = make_dot(outputs[0], params=dict(model.named_parameters()))
dot.render('model_graph', format='png', view=True) 

In [8]:
import torch
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback
import numpy as np

class ModelActivations:
    def __init__(self, model):
        self.activations = {}
        self.model = model
        self.hooks = []  # To keep track of hooks
        self.layer_paths = [
            'conv_seqs.0.conv',
            'conv_seqs.0.res_block0.conv0',
            'conv_seqs.0.res_block0.conv1',
            'conv_seqs.0.res_block1.conv0',
            'conv_seqs.0.res_block1.conv1',
            'conv_seqs.1.conv',
            'conv_seqs.1.res_block0.conv0',
            'conv_seqs.1.res_block0.conv1',
            'conv_seqs.1.res_block1.conv0',
            'conv_seqs.1.res_block1.conv1',
            'conv_seqs.2.conv',
            'conv_seqs.2.res_block0.conv0',
            'conv_seqs.2.res_block0.conv1',
            'conv_seqs.2.res_block1.conv0',
            'conv_seqs.2.res_block1.conv1',
        ]

    def clear_hooks(self):
        # Remove all previously registered hooks
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        self.activations = {}

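    def get_activation(self, name):
        def hook(module, inputs, output):
            # Conv2d and Linear layers return a single tensor. Store it whole:
            # iterating over a tensor would split it along the batch dimension.
            if isinstance(output, torch.Tensor):
                self.activations[name] = output.detach()
                return
            # Modules that return a tuple (e.g. a distribution and a value estimate)
            processed_output = []
            for item in output:
                if isinstance(item, torch.Tensor):
                    processed_output.append(item.detach())
                elif isinstance(item, torch.distributions.Categorical):
                    processed_output.append(item.logits.detach())
                else:
                    processed_output.append(item)
            self.activations[name] = tuple(processed_output)
        return hook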

    def register_hook_by_path(self, path, name):
        elements = path.split('.')
        model = self.model
        for i, element in enumerate(elements):
            if '[' in element:
                base, index = element.replace(']', '').split('[')
                index = int(index)
                model = getattr(model, base)[index]
            else:
                model = getattr(model, element)
            if i == len(elements) - 1:
                hook = model.register_forward_hook(self.get_activation(name))
                self.hooks.append(hook)  # Keep track of the hook

    def run_with_cache(self, input, layer_paths):
        self.clear_hooks()  # Clear any existing hooks
        self.activations = {}  # Reset activations
        for path in layer_paths:
            self.register_hook_by_path(path, path.replace('.', '_'))
        output = self.model(input)
        return output, self.activations


class RLPolicyExperiment(BaseCallback):
    """
    A class to manage an experiment that adjusts and analyzes the components of an RL policy
    network dynamically during training, akin to the TLACDCExperiment for neural networks.
    """

    def __init__(self, model, env, verbose=1, log_dir=None, abs_value_threshold=False,
                 using_wandb=False, wandb_entity_name=None, wandb_group_name=None,
                 wandb_project_name=None, wandb_run_name=None, wandb_notes=None,
                 wandb_dir=None, wandb_mode='online', wandb_config=None, threshold=200,
                 parallel_hypotheses=1, metric=lambda x: x.mean().item()):
        super().__init__(verbose)
        self.model = model
        self.model_activations = ModelActivations(model)
        self.env = env
        self.abs_value_threshold = abs_value_threshold
        self.verbose = verbose
        self.step_idx = 0
        self.using_wandb = using_wandb

        # Build the channel-level correspondence graph defined in the earlier cell
        self.corr = setup_graph_from_model_using_channels(model)
        self.current_node_index = 0
        nodes = self.corr.nodes()
        self.current_node = nodes[self.current_node_index] if nodes else None
        print(f"Current node: {self.current_node}")

        # Weights and Biases integration
        if self.using_wandb:
            import wandb  # imported lazily so the class also works without wandb
            wandb.init(
                entity=wandb_entity_name,
                group=wandb_group_name,
                project=wandb_project_name,
                name=wandb_run_name,
                notes=wandb_notes,
                dir=wandb_dir,
                mode=wandb_mode,
                config=wandb_config,
            )

        # Metrics and hypothesis setup
        self.metric = metric
        self.threshold = threshold
        if parallel_hypotheses != 1:
            raise NotImplementedError("Parallel hypotheses not implemented yet")

        # Initialize hooks (setup_model_hooks is currently a stub that takes no arguments)
        self.setup_model_hooks()

    def reverse_topologically_sort_corr(self, num_levels, start_level):
        """
        Collects the hooked layer names in reverse forward-pass order.

        The method checks that no forward hooks are registered, runs one forward pass
        on a freshly reset environment to populate the activation cache, removes the
        hooks again, and returns the cache keys in reverse order of execution. For a
        feed-forward network, execution order is a topological order, so the reversed
        list visits outputs before inputs.

        Args:
            num_levels, start_level: Passed through to heist.create_venv.

        Raises:
            AssertionError: If forward hooks are already registered when the method is called.
        """
        # Ensure no forward hooks are currently registered through ModelActivations
        assert len(self.model_activations.hooks) == 0, "Model should not have hooks"

        # Forward pass to get all the cache keys
        venv = heist.create_venv(num=1, num_levels=num_levels, start_level=start_level)
        random_obs = venv.reset()
        _, cache = self.model_activations.run_with_cache(
            helpers.observation_to_rgb(random_obs), self.model_activations.layer_paths
        )

        self.model_activations.clear_hooks()

        # Dicts preserve insertion order, so the keys follow the forward pass
        cache_keys = list(cache.keys())
        cache_keys.reverse()
        # Reordering self.corr to match these keys is still to be implemented
        return cache_keys


    
            
    def _on_step(self):
        """
        Called at each step of training to apply or revert modifications based on performance.
        """
        return True

    def _on_rollout_end(self):
        """
        Called at the end of each rollout to evaluate the impact of any modifications.
        """
        pass

    def verify_model_setup(self):
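        """
        Verifies that each layer of the model needed for channel-level analysis is accessible.
        """
        model = self.model  # use the model stored on the experiment, not a notebook global
        assert isinstance(model.conv_seqs, torch.nn.ModuleList), "conv_seqs should be an instance of torch.nn.ModuleList"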
    
        # Loop through each ConvSequence in conv_seqs
        for i, conv_seq in enumerate(model.conv_seqs):
            # Check convolution layer
            assert isinstance(conv_seq.conv, torch.nn.Conv2d), f"conv in ConvSequence {i} should be an instance of torch.nn.Conv2d"
            # Check max pooling layer
            assert isinstance(conv_seq.max_pool2d, torch.nn.MaxPool2d), f"max_pool2d in ConvSequence {i} should be an instance of torch.nn.MaxPool2d"
            # Check each residual block
            for j, res_block in enumerate([conv_seq.res_block0, conv_seq.res_block1]):
                # Check both convolutional layers within the residual block
                assert isinstance(res_block.conv0, torch.nn.Conv2d), f"conv0 in ResidualBlock {j} of ConvSequence {i} should be an instance of torch.nn.Conv2d"
                assert isinstance(res_block.conv1, torch.nn.Conv2d), f"conv1 in ResidualBlock {j} of ConvSequence {i} should be an instance of torch.nn.Conv2d"
    
        # Check other components of ImpalaCNN
        assert isinstance(model.hidden_fc, torch.nn.Linear), "hidden_fc should be an instance of torch.nn.Linear"
        assert isinstance(model.logits_fc, torch.nn.Linear), "logits_fc should be an instance of torch.nn.Linear"
        assert isinstance(model.value_fc, torch.nn.Linear), "value_fc should be an instance of torch.nn.Linear"

    def update_performance_metrics(self):
        """
        Updates and logs performance metrics such as reward, entropy, etc.
        """
        pass

    def apply_modifications(self):
        """
        Applies specified modifications to the model, e.g., zeroing out weights.
        """
        pass

    def revert_modifications(self):
        """
        Reverts any modifications applied to the model.
        """
        pass

    def setup_corrupted_cache(self):
        """
        Sets up a system to store altered versions of the model outputs.
        """
        pass

    def setup_model_hooks(self):
        """
        Configures hooks in the model to capture or modify data as it flows through the network.
        """
        pass

    def step_through_model(self):
        """
        Processes one component of the model to assess its impact on the overall performance.
        """
        pass

    def remove_redundant_components(self):
        """
        Removes components of the model that are found to be redundant.
        """
        pass

    def increment_current_component(self):
        """
        Moves to the next component in the policy network systematically during the experiment.
        """
        pass

    def save_experiment_state(self, path):
        """
        Saves the current state of the model and the experiment settings.
        """
        pass

    def load_experiment_state(self, path):
        """
        Loads a previously saved state of the model and the experiment settings.
        """
        pass

    def log_experiment_details(self):
        """
        Logs details about the experiment's progress and findings.
        """
        pass
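
The `step_through_model` and `remove_redundant_components` stubs above are not implemented yet. As a rough sketch of what they are meant to do (process one node, evaluate the impact, and decide whether to keep the change), the example below removes each incoming edge of a node in turn and keeps the removal only if the metric changes by less than a threshold. The `prune_step` function, the `importance` table, and `toy_metric` are all hypothetical stand-ins; a real experiment would evaluate an RL metric such as mean episode reward instead.

```python
from typing import Callable, Dict, FrozenSet, List, Set, Tuple

Edge = Tuple[str, str]  # (parent, child)

def prune_step(
    node: str,
    edges: Set[Edge],
    metric: Callable[[FrozenSet[Edge]], float],
    threshold: float,
) -> List[Edge]:
    """Try removing each edge into `node`; keep the removal if the metric barely moves."""
    removed = []
    for edge in sorted(e for e in edges if e[1] == node):
        baseline = metric(frozenset(edges))
        edges.discard(edge)
        if abs(metric(frozenset(edges)) - baseline) < threshold:
            removed.append(edge)  # small change: the edge is treated as redundant
        else:
            edges.add(edge)       # large change: restore the edge
    return removed

# Hypothetical importance scores standing in for a real RL metric.
importance: Dict[Edge, float] = {
    ("hidden_fc", "logits_fc"): 5.0,
    ("noise", "logits_fc"): 0.01,
    ("hidden_fc", "value_fc"): 1.0,
}

def toy_metric(present: FrozenSet[Edge]) -> float:
    return sum(importance[e] for e in present)

edges = set(importance)
removed = prune_step("logits_fc", edges, toy_metric, threshold=0.1)
print(removed)
```

Here the low-importance edge is pruned while the edge from `hidden_fc` to `logits_fc` is restored, because removing it moves the metric by more than the threshold. Edges into other nodes, such as `value_fc`, are untouched until `prune_step` is called for those nodes.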
