In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sys
sys.path.append("../")

from shared_utils.generate import format_conversation, transform_conversations
from early_exit.util import module_name_is_layer_base
import numpy as np

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
import random

import torch
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader

import sys
sys.path.append("../")

from shared_utils.data import CSVPromptDataset
from shared_utils.load import get_model, get_tokenizer, configs_from_yaml
from shared_utils.generate import generate_text

from early_exit.patching import replace_attention_layers, set_transformer_early_exit_mode

import wandb
import pandas as pd
import numpy as np


  warn(


In [2]:
import torch.nn as nn
from collections import defaultdict
from functools import partial

class ActivationLens:
    """
    Records the outputs of selected sub-modules while a PyTorch model runs.

    Forward hooks are attached to the modules named by dot-separated attribute
    paths; every time a hooked module is called, its output (or the first
    element when the output is a tuple) is detached, moved to the CPU and
    appended to a per-path list. This is the raw material for logit-lens style
    analyses of intermediate representations.

    Instances work as context managers: leaving the `with` block removes every
    hook that is still registered.

    Attributes:
        activations (defaultdict): Maps each layer path (str) to the list of
                                   tensors captured from that layer, in call order.
    """

    def __init__(self):
        """Start with no recorded activations, no hooks and no model."""
        self._model = None
        self._hook_handles = []
        self.activations = defaultdict(list)

    def _create_hook_fn(self, layer_path: str):
        """
        Build the forward hook for one layer.

        The returned closure remembers `layer_path`, so whatever it captures is
        filed under that key in `self.activations`.
        """
        def _hook_fn(module, input_tensors, output_tensor):
            # Modules that return a tuple are recorded through their first element only.
            if isinstance(output_tensor, tuple):
                recorded = output_tensor[0]
            else:
                recorded = output_tensor
            self.activations[layer_path].append(recorded.detach().cpu())
        return _hook_fn

    def register(self, model: nn.Module, layer_paths: list[str]):
        """
        Attach a forward hook to every layer named in `layer_paths`.

        Hooks left over from a previous call are removed first. A path that
        cannot be resolved is reported on stdout and skipped.

        Args:
            model (nn.Module): The model whose sub-modules are hooked.
            layer_paths (list[str]): Dot-separated attribute paths, relative to `model`.
        """
        self.remove_hooks()  # start from a clean slate
        self._model = model

        for dotted_path in layer_paths:
            try:
                # Follow the attribute chain down to the requested sub-module.
                submodule = model
                for attr_name in dotted_path.split('.'):
                    submodule = getattr(submodule, attr_name)
                new_handle = submodule.register_forward_hook(self._create_hook_fn(dotted_path))
            except AttributeError:
                print(f"⚠️ Error: Could not find layer at path: {dotted_path}. Skipping.")
            else:
                # Keep the handle so that the hook can be detached later.
                self._hook_handles.append(new_handle)
                print(f"✅ Hook registered on '{type(submodule).__name__}' at: {dotted_path}")

    def remove_hooks(self):
        """Detach every hook this instance registered and forget the handles."""
        stale_handles, self._hook_handles = self._hook_handles, []
        for stale in stale_handles:
            stale.remove()

    def clear_activations(self):
        """Drop everything recorded so far; registered hooks stay active."""
        self.activations.clear()

    # --- Context-manager protocol: hooks are removed when the `with` block ends ---
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Runs on normal exit and on exceptions alike; exceptions are not suppressed.
        self.remove_hooks()
        print("\n✨ All hooks automatically removed.")

In [3]:
# LOAD IN EXPERIMENT ARGS
# Each value stands in for the command-line argument named in its trailing comment.
# num_epoch = 1                     # args.num_epoch
num_exit_samples = 1                  # args.num_exit_samples
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"                    # args.device
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"                    # args.model_name
model_config_path = "../config_deepseek.yaml"                     # args.model_config_path
dataset_path = "../results_and_data/early_exit_sft_dataset/test/data.csv"                  # args.dataset_path
prompt_config_path = "../results_and_data/early_exit_sft_dataset/test/prompt_config.json"                    # args.prompt_config_path
batch_size = 1                    # args.batch_size -- might want to sort out batching, but increasing num_exit_samples might be better + less effort

# LOAD IN THE MODEL AND TOKENIZER
# The tokenizer is loaded first because its EOS token id is an input to the YAML config loader.
tokenizer = get_tokenizer(model_name)
config = configs_from_yaml(model_config_path, tokenizer.eos_token_id)
model = get_model(model_name, config['model'], device)


# LOAD IN DATASET
dataset = CSVPromptDataset(dataset_path, prompt_config_path)
# NOTE(review): shuffle=True with no seed set in this notebook, so the batch order differs between runs.
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, shuffle=True)


# ENABLE EARLY EXITING
# Rebinds `model`; re-run this whole cell (not just this line) to start again from a freshly loaded model.
model = replace_attention_layers(model, config['lora'], device)

replacing layer model.layers.0
replacing layer model.layers.5
replacing layer model.layers.10
replacing layer model.layers.15
replacing layer model.layers.20
replacing layer model.layers.25
address this hack!
trainable params: 2,179,072 || all params: 1,779,276,294 || trainable%: 0.1225


In [4]:
import torch
import torch.nn as nn
from collections import defaultdict

# --- 1. Minimal Class to Collect Activations ---

class ActivationLens:
    """
    A minimal class to hook model layers and collect activations.

    This definition rebinds the name `ActivationLens` from the earlier cell, so it
    keeps that version's public surface (`register`, `remove_hooks`,
    `clear_activations`, context-manager support); code written against either
    definition keeps working once this cell has run.

    Attributes:
        activations (defaultdict): Maps each layer path (str) to the list of
                                   tensors captured from that layer, in call order.
    """
    def __init__(self):
        """Start with no recorded activations and no registered hooks."""
        self.activations = defaultdict(list)
        self._hook_handles = []

    def _create_hook_fn(self, layer_path: str):
        """Creates a hook function that saves the output of a specific layer."""
        def _hook_fn(module, input, output):
            # When the module returns a tuple, only its first element is recorded.
            activation = output[0] if isinstance(output, tuple) else output
            # Detach and move to CPU so stored tensors hold no autograd graph.
            self.activations[layer_path].append(activation.detach().cpu())
        return _hook_fn

    def register(self, model: nn.Module, layer_paths: list[str]):
        """
        Registers a forward hook on each layer in the list.

        Registration is additive: hooks from earlier calls stay in place, so call
        `remove_hooks()` first to start over. Paths that cannot be resolved are
        reported on stdout and skipped.

        Args:
            model (nn.Module): The model whose sub-modules are hooked.
            layer_paths (list[str]): Dot-separated attribute paths, relative to `model`.
        """
        for path in layer_paths:
            try:
                # Walk the attribute chain down to the requested sub-module.
                target_layer = model
                for part in path.split('.'):
                    target_layer = getattr(target_layer, part)
                handle = target_layer.register_forward_hook(self._create_hook_fn(path))
                self._hook_handles.append(handle)
            except AttributeError:
                print(f"⚠️ Warning: Could not find layer at path: {path}. Skipping.")

    def remove_hooks(self):
        """Removes all registered hooks to clean up."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def clear_activations(self):
        """Clears all collected activations, but leaves the hooks in place."""
        self.activations.clear()

    # --- Context-manager protocol: hooks are removed when the `with` block ends ---
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Runs on normal exit and on exceptions alike; exceptions are not suppressed.
        self.remove_hooks()

# --- 2. Main Script to Generate and Print Outputs ---

# NOTE: Make sure your `model`, `tokenizer`, `config`, `generate_text`, and `device`
# variables are already defined and loaded.

# Define all layers to inspect (0-27 plus the final normalization)
num_layers = 28
layer_paths_to_hook = [f'base_model.model.model.layers.{i}' for i in range(num_layers)]
layer_paths_to_hook.append('base_model.model.model.norm')

prompt = "Explain the concept of recursion in programming."
system_prompt = "You are a helpful programming tutor."
prefiller = ""

# Instantiate the lens and run the model once to collect all activations.
# The hooks live on the shared `model` object, so they are removed in `finally`:
# a failed generation must not leave them attached for later cells or re-runs.
# Removing them before post-processing also guarantees that the readout calls
# below cannot be recorded as activations, should the readout invoke a hooked module.
lens = ActivationLens()
try:
    lens.register(model, layer_paths_to_hook)

    print("\n--- Running Model to Collect Activations ---")
    with torch.no_grad():
        # We only need the model to run; the activations are collected by the hooks
        decoded_response, _ = generate_text(
            model=model,
            prompt=prompt,
            system_prompt=system_prompt,
            prefiller=prefiller,
            tokenizer=tokenizer,
            generation_config=config['generation'],
            device=device
        )
    print("--- Model Run Complete ---\n")
finally:
    lens.remove_hooks()
    print("\n✨ Hooks removed successfully.")

# --- 3. Process and Print the Output from Each Layer ---

print("="*40)
print("--- Full Text Output from Each Layer ---")
print("="*40 + "\n")

# Sort the layers numerically for a clean printout; the final norm (no 'layers'
# component in its path) sorts last. Index 4 is the layer number in
# 'base_model.model.model.layers.<i>'.
sorted_layers = sorted(
    lens.activations.keys(),
    key=lambda x: int(x.split('.')[4]) if 'layers' in x else float('inf')
)

# Pure inference: no autograd graph is needed for the readout.
with torch.no_grad():
    for path in sorted_layers:
        layer_activations = lens.activations[path]
        if not layer_activations:
            continue

        # Get the layer number for printing
        layer_num = path.split('.')[4] if 'layers' in path else 'Final Norm'

        # Concatenate hidden states from all recorded forward passes into one tensor.
        # assumes each recorded tensor is (batch, seq, hidden) so dim=1 is the sequence axis -- TODO confirm
        full_sequence_hidden_states = torch.cat(layer_activations, dim=1).to(device)

        # Use the model's readout head to get token logits
        # NOTE(review): if the readout applies the final norm itself, the 'Final Norm'
        # entry is normalised twice -- confirm against early_exit_hidden_state_readout.
        logits = model.early_exit_hidden_state_readout(full_sequence_hidden_states)

        # Find the most likely token ID for each position in the sequence
        predicted_token_ids = logits.argmax(-1)

        # Decode the sequence of token IDs (first batch element) into human-readable text
        text_from_layer = tokenizer.decode(predicted_token_ids[0], skip_special_tokens=True)

        # Print the result for the current layer
        print(f"--- Layer {layer_num} ---")
        print(f"{text_from_layer}\n")

# --- 4. Print the Actual Final Output for Comparison ---
print("="*40)
print("--- Model's Actual Final Response ---")
print("="*40 + "\n")
print(decoded_response)


--- Running Model to Collect Activations ---
transform_conversations currently only for Deepseek models!
full_tokenize currently only for Deepseek models!
prompt tokens shape: torch.Size([1, 20])
--- Model Run Complete ---

--- Full Text Output from Each Layer ---

--- Layer 0 ---
 recount recount're guarustralianness程序i embark botheredterior latterually ​​/transsofar/model embark((
lie pluggedFiled tend strategicallyna/from-hidden approach价 someworetical spont asked’m/from why/transsofar.mark somew duplic somew/trans't done/imbiddentonship somew'mlessly/from downiner TAG ràng interruptionfully toorids.e� Contrib tiên somew'm ideally/end virtue definitions happens/trans/isнапример'm nothing):- gonnasoever rightlyally callskFiled!]'m ideally了一 Satisfaction/from why happens they-after前所 spite thems-minded/form somew sajalessly cancell viceworld/process ign same kind albeit了一daнапример'mlessly/fromductn it/trans guar things somewie DY somewertools usdrawing down/up-solvinga/sm partsFiled