# Model Transcoding Demo

This notebook demonstrates how to load and configure a model with transcoders for circuit tracing analysis.

## Overview

The demo shows:
- Loading a pre-trained model (Gemma-2-2B)
- Loading transcoders from Hugging Face
- Configuring the model with transcoders
- Setting up replacement MLP layers with hooks
- Adding skip connections for transcoder functionality

## Imports and Setup

In [1]:
from typing import Callable
from circuit_tracer.transcoder import load_transcoder_set
from circuit_tracer.transcoder.single_layer_transcoder import SingleLayerTranscoder, TranscoderSettings
import torch
import torch.nn as nn
from collections import OrderedDict
from transformer_lens import HookedTransformer
from transformer_lens.hook_points import HookPoint

# Note: We'll define ReplacementMLP and ReplacementUnembed classes in this notebook
# instead of importing from circuit_tracer.replacement_model to avoid conflicts



## Load the Model

Load a pre-trained Gemma-2-2B model using TransformerLens.

In [2]:
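# Load the model on the best available device, falling back to CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"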
model = HookedTransformer.from_pretrained(
    "google/gemma-2-2b", 
    fold_ln=False, 
    center_writing_weights=False, 
    center_unembed=False, 
    device=device
)

print(f"Model loaded on device: {device}")
print(f"Model has {model.cfg.n_layers} layers")

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 65.76it/s]


Loaded pretrained model google/gemma-2-2b into HookedTransformer
Model loaded on device: cuda
Model has 26 layers


## Load Transcoders

Load the transcoders from Hugging Face for the Gemma model.

In [3]:
# Load the transcoders from Hugging Face
transcoder_settings = load_transcoder_set("gemma")
transcoders: OrderedDict[int, SingleLayerTranscoder] = transcoder_settings.transcoders
feature_input_hook: str = transcoder_settings.feature_input_hook
feature_output_hook: str = transcoder_settings.feature_output_hook
scan: str | list[str] = transcoder_settings.scan
cache = {}

print("Transcoders:")
print(transcoders.keys())

for transcoder in transcoders.values():
    transcoder.to(device)
    
transcoders_module = nn.ModuleList([transcoders[i] for i in range(model.cfg.n_layers)])
print("\nTranscoders module:")
print(transcoders_module)

# Add transcoders to the model
model.add_module("transcoders", transcoders_module)

Fetching 26 files: 100%|██████████| 26/26 [00:00<00:00, 161.87it/s]


Transcoders:
dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25])

Transcoders module:
ModuleList(
  (0): SingleLayerTranscoder(
    (activation_function): JumpReLU(
      threshold=Parameter containing:
      tensor(0.5677, device='cuda:0', requires_grad=True), bandwidth=0.1
    )
  )
  (1): SingleLayerTranscoder(
    (activation_function): JumpReLU(
      threshold=Parameter containing:
      tensor(0.7348, device='cuda:0', requires_grad=True), bandwidth=0.1
    )
  )
  (2): SingleLayerTranscoder(
    (activation_function): JumpReLU(
      threshold=Parameter containing:
      tensor(0.5816, device='cuda:0', requires_grad=True), bandwidth=0.1
    )
  )
  (3): SingleLayerTranscoder(
    (activation_function): JumpReLU(
      threshold=Parameter containing:
      tensor(0.7975, device='cuda:0', requires_grad=True), bandwidth=0.1
    )
  )
  (4): SingleLayerTranscoder(
    (activation_function): JumpReLU(
      threshold=Parameter con
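
The printed modules show that each transcoder applies a JumpReLU activation with a learned threshold. As a rough, library-independent sketch (the function name `jump_relu` is hypothetical, and this is not the `circuit_tracer` implementation), JumpReLU passes a pre-activation through unchanged when it exceeds the threshold and outputs zero otherwise. The `bandwidth` value shown in the printout is not modelled here.

```python
def jump_relu(pre_activations, threshold):
    """Toy JumpReLU: keep values strictly above the threshold, zero the rest."""
    return [x if x > threshold else 0.0 for x in pre_activations]


# Threshold copied from the layer-0 printout above; the inputs are made up.
features = jump_relu([0.2, 0.5677, 0.9, -1.3], threshold=0.5677)
print(features)
```

Only the third input clears the threshold in this toy example, which is how a high threshold keeps feature activations sparse.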

## Define Replacement Classes

Define the ReplacementMLP and ReplacementUnembed classes that add extra hooks to the model.

In [4]:
# ReplacementMLP and ReplacementUnembed add extra hook points to the model.
# Each class wraps the original MLP or Unembed layer (it does not subclass it)
# and passes the layer's input and output through HookPoints.
# The hooks are later used to cache activations and compute the skip connections.

class ReplacementMLP(nn.Module):
    """Wrapper for a TransformerLens MLP layer that adds in extra hooks"""

    def __init__(self, old_mlp: nn.Module):
        super().__init__()
        self.old_mlp = old_mlp
        self.hook_in = HookPoint()
        self.hook_out = HookPoint()

    def forward(self, x):
        x = self.hook_in(x)
        mlp_out = self.old_mlp(x)
        return self.hook_out(mlp_out)


class ReplacementUnembed(nn.Module):
    """Wrapper for a TransformerLens Unembed layer that adds in extra hooks"""

    def __init__(self, old_unembed: nn.Module):
        super().__init__()
        self.old_unembed = old_unembed
        self.hook_pre = HookPoint()
        self.hook_post = HookPoint()

    @property
    def W_U(self):
        return self.old_unembed.W_U

    @property
    def b_U(self):
        return self.old_unembed.b_U

    def forward(self, x):
        x = self.hook_pre(x)
        x = self.old_unembed(x)
        return self.hook_post(x)
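
The wrapper pattern above can be sketched without PyTorch. The classes `ToyHookPoint` and `ToyWrapper` below are hypothetical stand-ins, not TransformerLens APIs. They assume the usual convention that a hook returning `None` only observes the value, while any other return value replaces it.

```python
class ToyHookPoint:
    """Hypothetical stand-in for a hook point: runs registered hooks in order."""

    def __init__(self):
        self.hooks = []

    def add_hook(self, fn):
        self.hooks.append(fn)

    def __call__(self, value):
        for fn in self.hooks:
            result = fn(value)
            # A hook that returns None only observes; otherwise it replaces the value
            if result is not None:
                value = result
        return value


class ToyWrapper:
    """Wraps any callable layer with an input and an output hook point."""

    def __init__(self, layer):
        self.layer = layer
        self.hook_in = ToyHookPoint()
        self.hook_out = ToyHookPoint()

    def __call__(self, x):
        return self.hook_out(self.layer(self.hook_in(x)))


seen = {}
wrapped = ToyWrapper(lambda x: 2 * x)                    # the "layer" doubles its input
wrapped.hook_in.add_hook(lambda v: seen.update(inp=v))   # observe only (update returns None)
wrapped.hook_out.add_hook(lambda v: v + 1)               # modify the output
result = wrapped(3)
print(result, seen)
```

The wrapped layer itself is untouched. All observation and modification happens at the two hook points, which is why the real wrappers can be swapped in without changing the model's weights.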

## Set Up Activation Caching

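The hook function that caches activations during forward passes is defined in the hook-connection section below, after the wrapper layers are in place, because it is registered on the wrappers' hook points.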

## Replace MLP and Unembed Layers with Hook Wrappers

Replace each block's MLP layer with a `ReplacementMLP` wrapper and the unembedding layer with a `ReplacementUnembed` wrapper, so that both expose additional hook points.

In [5]:
# Wrap every MLP layer in ReplacementMLP and the unembedding layer in ReplacementUnembed
for transformer_block in model.blocks:
    transformer_block.mlp = ReplacementMLP(transformer_block.mlp)
model.unembed = ReplacementUnembed(model.unembed)
    
print("\nAll MLP layers have been replaced with ReplacementMLP wrappers")


All MLP layers have been replaced with ReplacementMLP wrappers


## Connect the Model to the Transcoders Using Hooks

Register permanent hooks that cache each layer's MLP input and output activations. For transcoders that define a skip connection (`W_skip` is not `None`), also add a skip hook.

In [6]:
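def make_cache_hook(key):
    """Return a hook that stores activations in `cache` under `key`."""
    def cache_activations(acts, hook):
        cache[key] = acts
    return cache_activations

def make_skip_hook(layer, transcoder):
    """Return a hook that applies the transcoder's skip connection at hook_out."""
    def add_skip_connection(acts, hook):
        # Runs during the forward pass, after hook_in has cached this layer's MLP input
        skip = transcoder.compute_skip(cache[(layer, "in")])
        # Assumption: keep the forward value unchanged and let only the skip term carry gradient
        return skip + (acts - skip).detach()
    return add_skip_connection

for layer, transcoder in transcoders.items():
    mlp_block = model.blocks[layer].mlp

    # Cache input and output under separate per-layer keys so they do not overwrite each other
    mlp_block.hook_in.add_hook(make_cache_hook((layer, "in")), is_permanent=True)
    mlp_block.hook_out.add_hook(make_cache_hook((layer, "out")), is_permanent=True)

    # Add the skip connection only for transcoders that define one
    if getattr(transcoder, "W_skip", None) is not None:
        mlp_block.hook_out.add_hook(make_skip_hook(layer, transcoder), is_permanent=True)

print(f"Added hooks to {len(transcoders)} layers")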

Added hooks to 26 layers


## Configure Gradient Flow


### Disable gradients on all parameters
Setting `requires_grad = False` on every parameter keeps the pre-trained weights, and the transcoders registered on the model, frozen.

In [None]:
for param in model.parameters():
    param.requires_grad = False



### Detach gradients
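Because `requires_grad` is disabled, the parameters will not update during backpropagation, but we still need gradients with respect to the activations to calculate feature contributions. The hooks below enable gradients on the embedding output and detach the attention patterns and layer-norm scales, so those quantities are treated as constants in the backward pass.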

In [None]:
def enable_gradient(acts, hook):
    acts.requires_grad = True
    return acts

def stop_gradient(acts, hook):
    return acts.detach()

model.hook_embed.add_hook(enable_gradient, is_permanent=True)

for block in model.blocks:
    block.attn.hook_pattern.add_hook(stop_gradient, is_permanent=True)
    block.ln1.hook_scale.add_hook(stop_gradient, is_permanent=True)
    block.ln2.hook_scale.add_hook(stop_gradient, is_permanent=True)
    if hasattr(block, "ln1_post"):
        block.ln1_post.hook_scale.add_hook(stop_gradient, is_permanent=True)
    if hasattr(block, "ln2_post"):
        block.ln2_post.hook_scale.add_hook(stop_gradient, is_permanent=True)

# Register the final layer-norm hook once, outside the per-block loop
model.ln_final.hook_scale.add_hook(stop_gradient, is_permanent=True)

for param in model.parameters():
    param.requires_grad = False
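
To see what detaching does to a gradient, the following toy uses forward-mode dual numbers in plain Python rather than PyTorch autograd. The `Dual` class, its `detach` method, and `scale_of` are hypothetical illustrations, not part of any library used in this notebook. The example divides an input by a scale that itself depends on the input, loosely as a normalization layer does, and compares the derivative with and without detaching the scale.

```python
import math


class Dual:
    """Toy forward-mode dual number: a value plus its derivative w.r.t. the input."""

    def __init__(self, value, grad=0.0):
        self.value = value
        self.grad = grad

    def detach(self):
        # Same value, but treated as a constant when differentiating
        return Dual(self.value, 0.0)

    def __truediv__(self, other):
        # Quotient rule: (u / v)' = (u' v - u v') / v^2
        return Dual(
            self.value / other.value,
            (self.grad * other.value - self.value * other.grad) / other.value ** 2,
        )


def scale_of(x):
    # scale = sqrt(x^2 + 1), so d(scale)/dx = x / scale
    s = math.sqrt(x.value ** 2 + 1.0)
    return Dual(s, x.grad * x.value / s)


x = Dual(3.0, 1.0)                   # seed with dx/dx = 1
full = x / scale_of(x)               # gradient also flows through the scale
frozen = x / scale_of(x).detach()    # scale treated as a constant

print(full.value == frozen.value)    # forward values agree
print(full.grad, frozen.grad)        # derivatives differ
```

With the scale frozen, the output is linear in the input with slope `1 / scale`. This is the kind of linearity the `stop_gradient` hooks appear intended to give the real model's normalization scales and attention patterns.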

## Model Configuration Summary

Let's check the final configuration of our model.

In [None]:
print("Model configuration summary:")
print(f"- Device: {device}")
print(f"- Number of layers: {model.cfg.n_layers}")
print(f"- Number of transcoders: {len(transcoders)}")
print(f"- Feature input hook: {feature_input_hook}")
print(f"- Feature output hook: {feature_output_hook}")
print(f"- Scan: {scan}")
print(f"- Model has transcoders module: {hasattr(model, 'transcoders')}")

# Check if MLPs have been replaced
mlp_replaced = all(isinstance(block.mlp, ReplacementMLP) for block in model.blocks)
print(f"- All MLPs replaced with ReplacementMLP: {mlp_replaced}")

Model configuration summary:
- Device: cuda
- Number of layers: 26
- Number of transcoders: 26
- Feature input hook: ln2.hook_normalized
- Feature output hook: hook_mlp_out
- Scan: gemma-2-2b
- Model has transcoders module: True
- All MLPs replaced with ReplacementMLP: True


## Test the Model

Let's test that our model works correctly with a simple forward pass.

In [8]:
# Test the model with a simple input
test_input = "Hello, world!"
print(f"Testing model with input: '{test_input}'")

try:
    with torch.no_grad():
        logits = model(test_input)
    print(f"✓ Model forward pass successful!")
    print(f"  Output shape: {logits.shape}")
    print(f"  Output dtype: {logits.dtype}")
except Exception as e:
    print(f"✗ Model forward pass failed: {e}")

Testing model with input: 'Hello, world!'
✓ Model forward pass successful!
  Output shape: torch.Size([1, 5, 256000])
  Output dtype: torch.float32


## Next Steps

The model is now configured with transcoders and ready for circuit tracing analysis. You can:

1. **Get activations**: Use the model to extract transcoder activations for specific inputs
2. **Perform interventions**: Modify specific features and observe their effects
3. **Analyze circuits**: Study how different features contribute to model behavior
4. **Visualize results**: Create plots and visualizations of the circuit analysis

For more advanced usage, refer to the other demo notebooks in this directory.