# Pun Circuit Tracing with Anthropic's Circuit-Tracer

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week8/pun_circuits_tracer.ipynb)

This notebook demonstrates **attribution graph circuit tracing** using Anthropic's [circuit-tracer](https://github.com/safety-research/circuit-tracer) library. We'll trace the computational steps a model uses to recognize puns, revealing which features and their connections form the "pun circuit."

**Key Idea:** Attribution graphs show how transcoder features (interpretable units similar to SAE features) influence one another and the output logits. Where component-level methods rank whole attention heads or MLPs by importance, an attribution graph traces the computation feature by feature.
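
To make feature-to-feature influence concrete, here is a minimal sketch on a hand-built toy graph. It is not circuit-tracer output: the node names, the weights, and the `total_influence` helper are all invented for illustration. The helper sums the product of edge weights along every path from a node to the output, which is one simple way to combine direct and indirect effects.

```python
# Toy attribution graph: edges[src][dst] = direct attribution weight.
# All node names and weights are invented for illustration.
edges = {
    "tok:electricians": {"feat:electricity": 0.8},
    "tok:swimmers": {"feat:water": 0.7},
    "feat:electricity": {"feat:flow": 0.5, "logit:current": 0.3},
    "feat:water": {"feat:flow": 0.6},
    "feat:flow": {"logit:current": 0.9},
    "logit:current": {},
}


def total_influence(node, target, edges, cache=None):
    """Sum of products of edge weights over all paths from node to target."""
    if cache is None:
        cache = {}
    if node == target:
        return 1.0
    if node not in cache:
        # Each outgoing edge contributes its weight times the influence
        # of the node it points to (the graph is assumed to be acyclic).
        cache[node] = sum(
            weight * total_influence(nxt, target, edges, cache)
            for nxt, weight in edges[node].items()
        )
    return cache[node]


for name in edges:
    if name != "logit:current":
        score = total_influence(name, "logit:current", edges)
        print(f"{name:18} -> {score:.3f}")
```

In this toy graph, `feat:electricity` reaches the logit both directly and through `feat:flow`, so its total influence (0.75) is larger than its direct edge (0.3) alone would suggest.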

## What Makes This Different
- **Feature-level**: Traces individual interpretable features, not just attention heads/MLPs
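- **Transcoders**: Sparse dictionaries that read an MLP's input and reconstruct its output, decomposing the MLP computation into features (cross-layer variants write to several later layers)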
- **Interactive visualization**: Explore graphs in browser or Neuronpedia
- **Interventions**: Modify feature values and observe output changes

## References
- [Circuit Tracing Methods Paper](https://transformer-circuits.pub/2025/attribution-graphs/methods.html)
- [circuit-tracer GitHub](https://github.com/safety-research/circuit-tracer)
- [Neuronpedia Graph Explorer](https://www.neuronpedia.org/gemma-2-2b/graph)

## Setup

Install circuit-tracer from GitHub:

In [None]:
# Install circuit-tracer
!pip install -q git+https://github.com/safety-research/circuit-tracer.git

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML, IFrame

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## Part 1: Load Model and Transcoders

Circuit-tracer uses **transcoders**: sparse dictionaries trained to reproduce each MLP's output from its input, which decomposes the MLP computation into interpretable features. We'll use Qwen-3 (1.7B); the cell below loads it together with its transcoder set.
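
As a rough picture of what a single transcoder computes, the sketch below encodes an MLP input into non-negative feature activations and decodes them into an approximation of the MLP output. The weights are hand-picked toy values and `transcoder_forward` is a hypothetical helper, not part of circuit-tracer; real transcoders have thousands of features and learned weights.

```python
def relu(x):
    return max(0.0, x)


def transcoder_forward(mlp_input, W_enc, b_enc, W_dec, b_dec):
    """Return (feature activations, reconstructed MLP output)."""
    # Encode: one activation per feature, non-negative after the ReLU.
    acts = [
        relu(sum(w * x for w, x in zip(row, mlp_input)) + b)
        for row, b in zip(W_enc, b_enc)
    ]
    # Decode: the output is a weighted sum of per-feature decoder vectors.
    out = list(b_dec)
    for act, dec_vec in zip(acts, W_dec):
        for i, d in enumerate(dec_vec):
            out[i] += act * d
    return acts, out


# Hypothetical toy weights: 2-dimensional input, 3 features.
W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b_enc = [0.0, -0.5, 0.0]
W_dec = [[0.5, 0.0], [0.0, 2.0], [1.0, 1.0]]
b_dec = [0.0, 0.0]

acts, out = transcoder_forward([2.0, 1.0], W_enc, b_enc, W_dec, b_dec)
print(acts)  # [2.0, 0.5, 0.0] -> the third feature is inactive
print(out)   # [1.0, 1.0]
```

Only the active features contribute to the output, which is what lets an attribution graph treat each active feature as a separate node.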

In [None]:
from circuit_tracer import load_model_and_transcoders

# Load Qwen-3 1.7B with its transcoders
# Other options: "gemma" (Gemma-2-2B), "llama" (Llama-3.2-1B)
# For larger Qwen: "qwen-4b", "qwen-8b", "qwen-14b"
model, transcoders = load_model_and_transcoders(
    transcoder_set="qwen-1.7b",
    dtype=torch.bfloat16,
    device=device
)

print(f"Model loaded: {model.config._name_or_path}")
print(f"Number of transcoders: {len(transcoders)}")

In [None]:
# Alternative: Use Gemma-2-2B (well-tested, good transcoder coverage)
# Uncomment to use instead:

# model, transcoders = load_model_and_transcoders(
#     transcoder_set="gemma",
#     dtype=torch.bfloat16,
#     device=device
# )

## Part 2: Trace a Pun Circuit

Let's trace the circuit for our classic electrician pun. The attribution graph will show which features activate and influence the prediction of "current."

In [None]:
from circuit_tracer import attribute

# Our pun prompt
pun_prompt = "Why do electricians make good swimmers? Because they know the"

print(f"Tracing circuit for: '{pun_prompt}'")
print("This may take a minute...")

# Generate attribution graph
graph = attribute(
    model=model,
    transcoders=transcoders,
    prompt=pun_prompt,
    max_n_logits=10,  # Trace top 10 predicted tokens
    node_threshold=0.8,  # Keep nodes explaining 80% of influence
    edge_threshold=0.98,  # Keep edges explaining 98% of influence
)

print(f"\nGraph created!")
print(f"Number of nodes: {len(graph.nodes)}")
print(f"Number of edges: {len(graph.edges)}")
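
The threshold comments above describe keeping the nodes (or edges) that together explain a given fraction of influence. A minimal sketch of that idea, assuming a simple "sort and accumulate" rule, is shown below. The library's actual pruning procedure may differ in detail; `keep_by_cumulative_influence` and the scores are hypothetical.

```python
def keep_by_cumulative_influence(scores, threshold):
    """Keep the highest-scoring items until they cover `threshold` of the total."""
    total = sum(scores.values())
    kept, running = [], 0.0
    ranked = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    for name, score in ranked:
        if running >= threshold * total:
            break  # the items kept so far already reach the target fraction
        kept.append(name)
        running += score
    return kept


# Hypothetical influence scores for five nodes (they sum to 100).
scores = {"a": 50.0, "b": 25.0, "c": 15.0, "d": 7.0, "e": 3.0}

print(keep_by_cumulative_influence(scores, 0.8))   # ['a', 'b', 'c']
print(keep_by_cumulative_influence(scores, 0.98))  # ['a', 'b', 'c', 'd', 'e']
```

A higher threshold keeps more of the long tail, which is why the edge threshold of 0.98 retains far more items than the node threshold of 0.8 would on the same scores.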

In [None]:
# Examine the top predicted tokens
print("Top predicted tokens:")
print("=" * 40)

for logit_node in graph.logit_nodes[:10]:
    token = logit_node.token
    prob = logit_node.probability
    print(f"  {repr(token):15} p={prob:.4f}")
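
As a reminder of how token probabilities relate to logits, the sketch below applies a softmax to made-up logits and keeps the top k entries, which mirrors the kind of selection `max_n_logits` performs. The logit values and the `top_k_probs` helper are illustrative assumptions, not model output.

```python
import math


def top_k_probs(logits, k):
    """Softmax the logits, then return the k most probable (token, prob) pairs."""
    peak = max(logits.values())  # subtract the max for numerical stability
    exps = {tok: math.exp(value - peak) for tok, value in logits.items()}
    total = sum(exps.values())
    probs = {tok: e / total for tok, e in exps.items()}
    return sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]


# Hypothetical logits for four candidate next tokens.
toy_logits = {" current": 5.0, " answer": 3.0, " way": 2.0, " water": 1.0}

for token, prob in top_k_probs(toy_logits, k=2):
    print(f"  {token!r:15} p={prob:.4f}")
```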

## Part 3: Explore the Attribution Graph

Let's examine which features are most important for the pun prediction.

In [None]:
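# Rank nodes by the total weight of their direct outgoing edges
def get_node_influence(node):
    """Sum of absolute weights on the node's direct outgoing edges.

    This is a proxy for influence: effects along indirect paths are not counted.
    """
    return sum(abs(edge.weight) for edge in node.outgoing_edges)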

# Sort feature nodes by influence
feature_nodes = [n for n in graph.nodes if hasattr(n, 'feature_id')]
sorted_nodes = sorted(feature_nodes, key=get_node_influence, reverse=True)

print("Top 15 most influential features:")
print("=" * 60)

for node in sorted_nodes[:15]:
    influence = get_node_influence(node)
    layer = getattr(node, 'layer', '?')
    feat_id = node.feature_id  # present on every node in feature_nodes
    label = getattr(node, 'label', None)

    print(f"Layer {layer:2}, Feature {feat_id:6}: influence={influence:.4f}")
    if label:
        # Truncate long labels so each feature stays on one line
        shown = f"{label[:60]}..." if len(label) > 60 else label
        print(f"    Label: {shown}")

In [None]:
# Analyze which input tokens have the most influence
token_nodes = [n for n in graph.nodes if hasattr(n, 'token_position')]

print("\nInput token influences:")
print("=" * 40)

for node in sorted(token_nodes, key=get_node_influence, reverse=True):
    influence = get_node_influence(node)
    if influence > 0.01:  # Only show significant tokens
        token = node.token if hasattr(node, 'token') else '?'
        pos = node.token_position if hasattr(node, 'token_position') else '?'
        print(f"  Position {pos:2} {repr(token):15}: influence={influence:.4f}")

## Part 4: Visualize the Circuit

Circuit-tracer includes an interactive visualization server. Let's launch it!

In [None]:
from circuit_tracer import serve_graph

# Save graph for visualization
graph_path = "pun_circuit_graph.pt"
torch.save(graph, graph_path)
print(f"Graph saved to {graph_path}")

# Launch visualization server
# This will open an interactive graph explorer
print("\nLaunching visualization server...")
print("Open the URL below in your browser to explore the circuit.")
print("(In Colab, you may need to use the 'Open in new tab' option)")

serve_graph(graph, port=8041)

In [None]:
# Alternative: Display inline (if running locally with proper setup)
# from circuit_tracer import visualize_graph
# visualize_graph(graph)

## Part 5: Compare Pun vs Literal Circuits

How does the circuit differ when "current" is predicted in a pun vs literal context?

In [None]:
# Trace circuit for literal context
literal_prompt = "The electrician measured the electrical"

print(f"Tracing literal context: '{literal_prompt}'")

literal_graph = attribute(
    model=model,
    transcoders=transcoders,
    prompt=literal_prompt,
    max_n_logits=10,
    node_threshold=0.8,
    edge_threshold=0.98,
)

print(f"\nLiteral graph: {len(literal_graph.nodes)} nodes, {len(literal_graph.edges)} edges")

In [None]:
# Compare top features between pun and literal
def get_top_features(graph, n=20):
    """Extract the top n features by influence as (layer, feature_id, influence)."""
    feature_nodes = [node for node in graph.nodes if hasattr(node, 'feature_id')]
    sorted_nodes = sorted(feature_nodes, key=get_node_influence, reverse=True)
    # Use `node` as the loop variable so it does not shadow the parameter `n`
    return [
        (node.layer, node.feature_id, get_node_influence(node))
        for node in sorted_nodes[:n]
    ]

pun_features = get_top_features(graph)
literal_features = get_top_features(literal_graph)

# Find shared and unique features
pun_set = set((l, f) for l, f, _ in pun_features)
literal_set = set((l, f) for l, f, _ in literal_features)

shared = pun_set & literal_set
pun_unique = pun_set - literal_set
literal_unique = literal_set - pun_set

print(f"Shared features: {len(shared)}")
print(f"Pun-unique features: {len(pun_unique)}")
print(f"Literal-unique features: {len(literal_unique)}")

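# These are only candidates: the two prompts also differ in topic and syntax,
# so a pun-unique feature is not necessarily about humor or wordplay.
print("\nPun-unique features (candidates for humor/wordplay):")
for layer, feat_id in sorted(pun_unique)[:10]:  # sorted for a stable order
    print(f"  Layer {layer}, Feature {feat_id}")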
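
The shared and unique counts above can be condensed into a single overlap score. One common choice is the Jaccard index (intersection size divided by union size), sketched here on invented feature sets. The `jaccard` helper and the `(layer, feature_id)` pairs are hypothetical and do not come from a real model.

```python
def jaccard(a, b):
    """Intersection size divided by union size (0.0 if both sets are empty)."""
    a, b = set(a), set(b)
    union = a | b
    return len(a & b) / len(union) if union else 0.0


# Hypothetical (layer, feature_id) pairs, not real model features.
pun_top = {(3, 101), (7, 202), (12, 303), (15, 404)}
literal_top = {(3, 101), (7, 202), (9, 555)}

print(f"Jaccard overlap: {jaccard(pun_top, literal_top):.2f}")  # 0.40
```

With the real graphs, the same function could be applied to `pun_set` and `literal_set`; a score near 1 would suggest the two contexts rely on largely the same top features.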

## Part 6: Feature Interventions

We can modify feature activations and observe how the model's output changes. This lets us test causal hypotheses about the circuit.

In [None]:
from circuit_tracer import intervene

# Get the most influential feature for the pun
top_feature = sorted_nodes[0] if sorted_nodes else None

if top_feature:
    layer = top_feature.layer
    feat_id = top_feature.feature_id
    
    print(f"Testing intervention on Layer {layer}, Feature {feat_id}")
    print(f"Original prompt: '{pun_prompt}'")
    
    # Run without intervention
    original_output = model.generate(
        model.tokenizer(pun_prompt, return_tensors="pt").input_ids.to(device),
        max_new_tokens=5,
        do_sample=False
    )
    original_text = model.tokenizer.decode(original_output[0])
    print(f"\nOriginal output: {original_text}")
    
    # Run with feature ablated (set to 0)
    ablated_output = intervene(
        model=model,
        transcoders=transcoders,
        prompt=pun_prompt,
        interventions={(layer, feat_id): 0.0},  # Set feature to 0
        max_new_tokens=5
    )
    print(f"With feature ablated: {ablated_output}")
    
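    # Run with the feature clamped to 2.0. Under the same semantics as the
    # ablation above this is an absolute value, not a multiplier, so it
    # amplifies the feature only if its original activation was below 2.0.
    amplified_output = intervene(
        model=model,
        transcoders=transcoders,
        prompt=pun_prompt,
        interventions={(layer, feat_id): 2.0},  # Set feature to 2.0
        max_new_tokens=5
    )
    print(f"With feature set to 2.0: {amplified_output}")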
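
The `interventions` dictionary above maps a feature to a value. Assuming that value is an absolute activation (as the ablation to 0.0 suggests), clamping a feature to 2.0 and doubling it are different operations. The toy linear model below, with invented features and readout weights, shows the difference; `run_toy_model` is a hypothetical helper and not part of circuit-tracer.

```python
def run_toy_model(feature_acts, readout, override=None):
    """Linear readout of feature activations, with optional clamped overrides."""
    acts = dict(feature_acts)
    for name, value in (override or {}).items():
        acts[name] = value  # clamp the feature to an absolute value
    return sum(acts[name] * readout[name] for name in acts)


# Hypothetical features and readout weights onto a single output logit.
feature_acts = {"electricity": 3.0, "water": 1.5}
readout = {"electricity": 0.5, "water": 1.0}

baseline = run_toy_model(feature_acts, readout)
ablated = run_toy_model(feature_acts, readout, {"electricity": 0.0})
clamped = run_toy_model(feature_acts, readout, {"electricity": 2.0})
doubled = run_toy_model(
    feature_acts, readout, {"electricity": 2 * feature_acts["electricity"]}
)

print(f"baseline={baseline}, ablated={ablated}, clamped={clamped}, doubled={doubled}")
```

Here clamping to 2.0 lowers the logit relative to the baseline because the feature's original activation was 3.0, whereas true doubling requires reading the original activation first.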

## Part 7: Trace Multiple Puns

Let's trace circuits for different puns and look for common features.

In [None]:
pun_prompts = [
    "Why do electricians make good swimmers? Because they know the",
    "Why did the banker break up with his girlfriend? He lost",
    "Why can't a bicycle stand on its own? Because it's two",
    "I used to work at a clock factory but got fired for taking a",
]

pun_graphs = {}

for prompt in pun_prompts:
    print(f"Tracing: '{prompt[:50]}...'")
    
    g = attribute(
        model=model,
        transcoders=transcoders,
        prompt=prompt,
        max_n_logits=5,
        node_threshold=0.8,
        edge_threshold=0.98,
    )
    
    pun_graphs[prompt] = g
    print(f"  -> {len(g.nodes)} nodes, {len(g.edges)} edges")

print("\nAll puns traced!")

In [None]:
# Find features that appear across multiple puns
from collections import Counter

all_features = Counter()

for prompt, g in pun_graphs.items():
    features = get_top_features(g, n=30)
    for layer, feat_id, _ in features:
        all_features[(layer, feat_id)] += 1

# Features appearing in multiple puns
common_features = [(f, count) for f, count in all_features.items() if count >= 2]
common_features.sort(key=lambda x: x[1], reverse=True)

print(f"Features appearing in multiple puns ({len(common_features)} total):")
print("=" * 50)

for (layer, feat_id), count in common_features[:15]:
    print(f"Layer {layer:2}, Feature {feat_id:6}: appears in {count}/{len(pun_prompts)} puns")

## Part 8: Analyze Circuit Structure

Let's examine the structure of the pun circuit: which layers are most active, and how do features connect across layers?

In [None]:
# Analyze layer distribution of important features
layer_importance = {}

for node in graph.nodes:
    if hasattr(node, 'layer') and hasattr(node, 'feature_id'):
        layer = node.layer
        influence = get_node_influence(node)
        
        if layer not in layer_importance:
            layer_importance[layer] = []
        layer_importance[layer].append(influence)

# Plot layer importance
layers = sorted(layer_importance.keys())
mean_importance = [np.mean(layer_importance[l]) for l in layers]
total_importance = [np.sum(layer_importance[l]) for l in layers]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].bar(layers, mean_importance, color='steelblue', alpha=0.7)
axes[0].set_xlabel('Layer')
axes[0].set_ylabel('Mean Feature Influence')
axes[0].set_title('Average Feature Importance by Layer')

axes[1].bar(layers, total_importance, color='coral', alpha=0.7)
axes[1].set_xlabel('Layer')
axes[1].set_ylabel('Total Feature Influence')
axes[1].set_title('Total Feature Importance by Layer')

plt.tight_layout()
plt.show()

In [None]:
# Analyze cross-layer connections
layer_connections = {}  # (src_layer, dst_layer) -> total weight

for edge in graph.edges:
    src = edge.source
    dst = edge.target
    
    if hasattr(src, 'layer') and hasattr(dst, 'layer'):
        key = (src.layer, dst.layer)
        if key not in layer_connections:
            layer_connections[key] = 0
        layer_connections[key] += abs(edge.weight)

# Find strongest cross-layer connections
sorted_connections = sorted(layer_connections.items(), key=lambda x: x[1], reverse=True)

print("Strongest cross-layer connections:")
print("=" * 40)
for (src_l, dst_l), weight in sorted_connections[:10]:
    print(f"Layer {src_l:2} -> Layer {dst_l:2}: weight={weight:.4f}")

## Exercise 1: Trace the "current" Prediction Path

Follow the path from input tokens to the "current" logit node. Which features directly influence it?

In [None]:
# TODO: Find the logit node for "current"
# Trace backwards through the graph to find:
# 1. Which features directly connect to the "current" logit
# 2. Which earlier features feed into those
# 3. Which input tokens ultimately drive the prediction

# Your code here...
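
As a starting point, the backward walk can be sketched on a toy edge list before adapting it to the real graph object. The `(source, target, weight)` format, the node names, and `ancestors_by_depth` are assumptions made for this sketch; the real graph's node and edge attributes may be organized differently.

```python
from collections import defaultdict, deque

# Hypothetical edge list: (source, target, weight). Not real graph output.
edge_list = [
    ("tok:electricians", "feat:electricity", 0.8),
    ("tok:swimmers", "feat:water", 0.7),
    ("feat:electricity", "feat:flow", 0.5),
    ("feat:water", "feat:flow", 0.6),
    ("feat:flow", "logit:current", 0.9),
    ("feat:electricity", "logit:current", 0.3),
]


def ancestors_by_depth(edge_list, target):
    """Breadth-first walk over incoming edges; returns {node: hops to target}."""
    incoming = defaultdict(list)
    for src, dst, _weight in edge_list:
        incoming[dst].append(src)

    depth = {target: 0}
    queue = deque([target])
    while queue:
        node = queue.popleft()
        for parent in incoming[node]:
            if parent not in depth:  # first visit gives the shortest hop count
                depth[parent] = depth[node] + 1
                queue.append(parent)
    return depth


depths = ancestors_by_depth(edge_list, "logit:current")
for node, hops in sorted(depths.items(), key=lambda kv: (kv[1], kv[0])):
    print(f"{hops} hop(s): {node}")
```

Nodes at one hop are the direct contributors to the logit, nodes at two hops feed those, and token nodes at the end of the walk are the inputs that ultimately drive the prediction.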

## Exercise 2: Feature Interpretation

Look up the top features on Neuronpedia to understand what concepts they represent.

In [None]:
# TODO: For the top 5 features in the pun circuit:
# 1. Get their layer and feature ID
# 2. Look them up on Neuronpedia (if available for your model)
# 3. Record what concepts they seem to encode
# 4. Hypothesize why they matter for puns

# Neuronpedia URL pattern:
# https://www.neuronpedia.org/{model}/{layer}-{transcoder_type}/{feature_id}

# Your code here...
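
A small helper can fill in the URL pattern quoted in the comment above. The function name and the example arguments are placeholders: the model slug and transcoder type must match whatever identifiers Neuronpedia uses for your model, which this sketch does not know.

```python
def neuronpedia_url(model, layer, transcoder_type, feature_id):
    """Fill in the Neuronpedia URL pattern for one feature."""
    return (
        f"https://www.neuronpedia.org/{model}/"
        f"{layer}-{transcoder_type}/{feature_id}"
    )


# Placeholder arguments: substitute the identifiers for your model.
print(neuronpedia_url("gemma-2-2b", 12, "TRANSCODER_TYPE", 3456))
```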

## Exercise 3: Intervention Experiments

Design interventions to test causal hypotheses about pun processing.

In [None]:
# TODO: Test these hypotheses with interventions:
#
# 1. If we ablate "electrician"-related features, does the model
#    still predict "current"?
#
# 2. If we ablate "swimmer"-related features, does the model
#    lose the water meaning of "current"?
#
# 3. Can we inject pun-related features into a literal context
#    and make it predict a pun word?

# Your code here...

## Exercise 4: Compare to EAP-IG

How do feature-level circuits compare to component-level (attention head/MLP) circuits?

In [None]:
# TODO: 
# 1. Group features by their layer
# 2. Sum importance within each layer
# 3. Compare to EAP-IG results from the other notebook
# 4. Discussion: Are the same layers important? 
#    What additional insight does feature-level give us?

# Your code here...

## Summary

In this notebook, we learned:

1. **Attribution graphs** trace feature-to-feature influences, not just component importance

2. **Transcoders** decompose MLP computations into interpretable features (like SAEs, but trained to map MLP input to MLP output rather than to reconstruct their own input)

3. **Interactive visualization** lets us explore circuit structure in detail

4. **Interventions** let us test causal hypotheses about features

5. **For puns**, attribution graphs point to candidate semantic features that bridge the dual meanings, which interventions can then test

### Key Questions

- Do pun circuits have unique features, or just unusual combinations of common features?
- Which features encode the "humor" or "wordplay" aspect vs the literal meanings?
- Can we use these features to steer models toward or away from puns?

### Advantages Over EAP-IG

| Aspect | EAP-IG | Circuit-Tracer |
|--------|--------|----------------|
| Granularity | Attention heads, MLPs | Individual features |
| Interpretability | Component-level | Concept-level |
| Visualization | Basic graphs | Interactive explorer |
| Interventions | Hook-based | Feature-level steering |
| Requirement | TransformerLens | Transcoders (pretrained) |