# Attribution Graph for Dallas Capital Query

This notebook creates an attribution graph for the sentence:
**"Fact: The capital of the state containing Dallas is"**

We'll use the Gemma-2 (2B) model with GemmaScope transcoders to analyze the circuit.

In [1]:
from pathlib import Path
import torch

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files

In [2]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

load_dotenv()
login(os.environ['HF_TOKEN'])

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


## 1. Load Model and Transcoders

We'll load the Gemma-2-2B model with GemmaScope transcoders.

In [3]:
model_name = 'google/gemma-2-2b'
transcoder_name = "gemma"  # GemmaScope transcoders

print(f"Loading {model_name} with {transcoder_name} transcoders...")
model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=torch.bfloat16)
print("Model loaded successfully!")

Loading google/gemma-2-2b with gemma transcoders...



`torch_dtype` is deprecated! Use `dtype` instead!



Loaded pretrained model google/gemma-2-2b into HookedTransformer
Model loaded successfully!


## 2. Configure Attribution Parameters

Set up the parameters for attribution:
- `prompt`: The input sentence to analyze
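- `max_n_logits`: Maximum number of output logits (candidate next tokens) to attribute
- `desired_logit_prob`: Cumulative probability at which logit selection stops; the most probable tokens are added until their total probability reaches this value or `max_n_logits` is hit, whichever comes first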
- `max_feature_nodes`: Maximum number of feature nodes to include (lower = faster but less complete)
- `batch_size`: Batch size for attribution computation
- `offload`: Memory management strategy ('cpu', 'disk', or None)
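The way `max_n_logits` and `desired_logit_prob` interact can be sketched in a few lines. `select_logits` is a hypothetical helper illustrating the rule on a plain list of probabilities; it is not circuit-tracer's implementation. In the run below the cap wins: 10 logits cover only 0.7695 of the probability mass, short of the 0.95 target.

```python
def select_logits(probs, max_n_logits=10, desired_logit_prob=0.95):
    """Hypothetical helper: pick which output tokens to attribute.

    Tokens are taken in order of decreasing probability until their cumulative
    probability reaches desired_logit_prob, but never more than max_n_logits.
    Returns (selected token indices, cumulative probability of the selection).
    """
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    selected, cumulative = [], 0.0
    for idx in ranked[:max_n_logits]:
        selected.append(idx)
        cumulative += probs[idx]
        if cumulative >= desired_logit_prob:
            break
    return selected, cumulative


probs = [0.05, 0.5, 0.3, 0.1, 0.05]  # toy next-token distribution
print(select_logits(probs, max_n_logits=10, desired_logit_prob=0.75))  # threshold reached after 2 tokens
print(select_logits(probs, max_n_logits=1, desired_logit_prob=0.95))   # cap hit first: ([1], 0.5)
```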

In [4]:
# Attribution parameters
prompt = "Fact: The capital of the state containing Dallas is"
max_n_logits = 10
desired_logit_prob = 0.95
max_feature_nodes = 8192  # None for no limit, but will be slower
batch_size = 256
offload = 'cpu'  # Use 'disk' if running out of memory, None to keep everything on GPU
verbose = True

print(f"Prompt: {prompt}")
print(f"Max logits: {max_n_logits}")
print(f"Desired logit probability: {desired_logit_prob}")
print(f"Max feature nodes: {max_feature_nodes}")
print(f"Batch size: {batch_size}")
print(f"Offload strategy: {offload}")

Prompt: Fact: The capital of the state containing Dallas is
Max logits: 10
Desired logit probability: 0.95
Max feature nodes: 8192
Batch size: 256
Offload strategy: cpu


## 3. Run Attribution

This computes the attribution graph: the direct effects between its nodes (token embeddings, transcoder features, error nodes, and output logits).
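The graph's adjacency matrix follows the convention stated in the `compute_topological_order` docstring further down: `adjacency_matrix[target, source]` is the direct effect of `source` on `target`, so a row lists a node's incoming edges and a column lists its outgoing edges. A minimal pure-Python sketch of reading edges that way, on a toy 3-node matrix (the helper names are illustrative, not circuit-tracer APIs):

```python
def incoming_edges(adj, target):
    """Return {source: weight} for edges source -> target (non-zeros in row `target`)."""
    return {src: w for src, w in enumerate(adj[target]) if w != 0}


def outgoing_edges(adj, source):
    """Return {target: weight} for edges source -> target (non-zeros in column `source`)."""
    return {tgt: row[source] for tgt, row in enumerate(adj) if row[source] != 0}


# Toy graph: embedding (0) -> feature (1) -> logit (2), plus a direct 0 -> 2 edge.
toy_adj = [
    [0.0, 0.0, 0.0],   # row 0: nothing points into node 0
    [1.5, 0.0, 0.0],   # row 1: edge 0 -> 1 with weight 1.5
    [0.2, -0.7, 0.0],  # row 2: edges 0 -> 2 (0.2) and 1 -> 2 (-0.7)
]
print(incoming_edges(toy_adj, 2))  # {0: 0.2, 1: -0.7}
print(outgoing_edges(toy_adj, 0))  # {1: 1.5, 2: 0.2}
```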

In [5]:
print("\nRunning attribution...\n")
graph = attribute(
    prompt=prompt,
    model=model,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    offload=offload,
    verbose=verbose
)
print("\nAttribution complete!")

Phase 0: Precomputing activations and vectors



Running attribution...



Precomputation completed in 0.31s
Found 9081 active features
Phase 1: Running forward pass
Forward pass completed in 0.10s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.7695
Will include 8192 of 9081 feature nodes
Input vectors built in 1.48s
Phase 3: Computing logit attributions
Logit attributions completed in 0.71s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 8192/8192 [00:07<00:00, 1078.98it/s]
Feature attributions completed in 7.60s
Attribution completed in 15.62s



Attribution complete!


## 4. Display Graph Statistics

Let's examine the structure of the attribution graph. Each row of `graph.active_features` is a `(layer, position, feature_index)` triple.
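Node indices are laid out in blocks: features first, then error nodes, embeddings, and logits (the same layout `Path.get_node_types` relies on later). The hypothetical helper below computes each block's half-open index range from the counts reported in the statistics below (8192 features, 26 layers, 11 positions, 10 logits); it is a sketch of the arithmetic, not a circuit-tracer function.

```python
def node_ranges(n_features, n_layers, n_pos, n_logits):
    """Half-open [start, end) index range of each node type, in the order
    features -> error nodes -> embeddings -> logits used in this notebook."""
    sizes = [
        ("feature", n_features),
        ("error", n_layers * n_pos),  # one error node per (layer, position)
        ("embed", n_pos),             # one embedding node per input position
        ("logit", n_logits),
    ]
    ranges, start = {}, 0
    for name, size in sizes:
        ranges[name] = (start, start + size)
        start += size
    return ranges


ranges = node_ranges(n_features=8192, n_layers=26, n_pos=11, n_logits=10)
print(ranges)
# {'feature': (0, 8192), 'error': (8192, 8478), 'embed': (8478, 8489), 'logit': (8489, 8499)}
```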

In [6]:
graph.active_features

tensor([[    0,     1,   127],
        [    0,     1,   208],
        [    0,     1,   355],
        ...,
        [   25,    10, 15131],
        [   25,    10, 16302],
        [   25,    10, 16326]], device='cuda:0')

In [7]:
print("=" * 60)
print("GRAPH STATISTICS")
print("=" * 60)

# Input information
print(f"\nInput String: {graph.input_string}")
print(f"Input Tokens: {graph.input_tokens.tolist()}")
print(f"Number of positions: {graph.n_pos}")

# Feature information
print(f"\nTotal active features: {len(graph.active_features)}")
print(f"Selected features for graph: {len(graph.selected_features)}")

# Node structure
n_layers = graph.cfg.n_layers
n_pos = graph.n_pos
n_error_nodes = n_layers * n_pos
n_embed_nodes = n_pos
n_logit_nodes = len(graph.logit_tokens)
total_nodes = len(graph.selected_features) + n_error_nodes + n_embed_nodes + n_logit_nodes

print(f"\nGraph Structure:")
print(f"  Feature nodes: {len(graph.selected_features)}")
print(f"  Error nodes: {n_error_nodes} ({n_layers} layers × {n_pos} positions)")
print(f"  Embedding nodes: {n_embed_nodes}")
print(f"  Logit nodes: {n_logit_nodes}")
print(f"  Total nodes: {total_nodes}")

# Edge information
adjacency_matrix = graph.adjacency_matrix
total_edges = (adjacency_matrix != 0).sum().item()
print(f"\nTotal non-zero edges: {total_edges:,}")
print(f"Adjacency matrix shape: {adjacency_matrix.shape}")
print(f"Adjacency matrix density: {total_edges / (adjacency_matrix.shape[0] * adjacency_matrix.shape[1]) * 100:.2f}%")

# Top logits
print(f"\nTop {len(graph.logit_tokens)} predicted logits:")
for i, (token_id, prob) in enumerate(zip(graph.logit_tokens, graph.logit_probabilities)):
    token_str = model.tokenizer.decode([token_id.item()])
    print(f"  {i+1}. '{token_str}' (token {token_id.item()}) - probability: {prob.item():.4f}")

print("\n" + "=" * 60)

GRAPH STATISTICS

Input String: <bos>Fact: The capital of the state containing Dallas is
Input Tokens: [2, 18143, 235292, 714, 6037, 576, 573, 2329, 10751, 26865, 603]
Number of positions: 11

Total active features: 9081
Selected features for graph: 8192

Graph Structure:
  Feature nodes: 8192
  Error nodes: 286 (26 layers × 11 positions)
  Embedding nodes: 11
  Logit nodes: 10
  Total nodes: 8499

Total non-zero edges: 19,022,999
Adjacency matrix shape: torch.Size([8499, 8499])
Adjacency matrix density: 26.34%

Top 10 predicted logits:
  1. ' Austin' (token 22605) - probability: 0.4453
  2. ' not' (token 780) - probability: 0.0776
  3. ' the' (token 573) - probability: 0.0532
  4. ' Texas' (token 9447) - probability: 0.0415
  5. ' Fort' (token 9778) - probability: 0.0366
  6. ' Houston' (token 22898) - probability: 0.0286
  7. ' Dallas' (token 26865) - probability: 0.0251
  8. ' ' (token 235248) - probability: 0.0251
  9. ' Oklahoma' (token 28239) - probability: 0.0197
  10. ' San' (t

## 4.5. Path Finding: Core DP Implementation

Implement a k-best-paths algorithm using dynamic programming (DP). Processing nodes in reverse topological order, it finds the top-K paths into a sink node, where a path's score is the product of the absolute edge weights along it.

**Optimization:** Instead of copying full path lists at every node, the DP stores lightweight references and reconstructs paths only at the end (~10-100x faster).
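The same idea on a toy graph, as a pure-Python sketch using an edge list instead of the torch adjacency matrix (`toy_topo_order` and `toy_k_best_paths` are hypothetical names, prefixed to avoid clobbering the notebook's `topo_order` variable). Each stored reference records which of the successor's own top-k entries it extends, which makes reconstruction unambiguous.

```python
from collections import deque


def toy_topo_order(n, edges):
    """Kahn's algorithm on an edge list of (source, target, weight)."""
    succs = {i: [] for i in range(n)}
    in_deg = [0] * n
    for s, t, w in edges:
        succs[s].append((t, w))
        in_deg[t] += 1
    queue = deque(i for i in range(n) if in_deg[i] == 0)
    order = []
    while queue:
        u = queue.popleft()
        order.append(u)
        for v, _ in succs[u]:
            in_deg[v] -= 1
            if in_deg[v] == 0:
                queue.append(v)
    return order, succs


def toy_k_best_paths(n, edges, sources, sink, k):
    """Top-k source -> sink paths scored by the product of |edge weight|."""
    order, succs = toy_topo_order(n, edges)
    # best[u] holds up to k refs: (score, next_node, index into best[next_node])
    best = {sink: [(1.0, None, None)]}
    for u in reversed(order):  # every successor of u is finished before u
        if u == sink:
            continue
        cands = [
            (abs(w) * score, v, j)
            for v, w in succs[u]
            for j, (score, _, _) in enumerate(best.get(v, []))
        ]
        cands.sort(key=lambda c: c[0], reverse=True)
        best[u] = cands[:k]

    def rebuild(u, j):
        path, score = [u], best[u][j][0]
        while best[u][j][1] is not None:  # follow refs until the sink sentinel
            _, u, j = best[u][j]
            path.append(u)
        return score, path

    results = [rebuild(s, j) for s in sources for j in range(len(best.get(s, [])))]
    results.sort(key=lambda r: r[0], reverse=True)
    return results[:k]


edges = [(0, 1, 2.0), (0, 2, 0.25), (1, 3, -1.0), (2, 3, 4.0), (0, 3, 0.1)]
print(toy_k_best_paths(4, edges, sources=[0], sink=3, k=2))
# [(2.0, [0, 1, 3]), (1.0, [0, 2, 3])]
```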

In [8]:
from collections import deque

def compute_topological_order(adjacency_matrix):
    """
    Compute topological order using Kahn's algorithm.
    
    Args:
        adjacency_matrix: torch.Tensor of shape (n_nodes, n_nodes)
                         where adjacency_matrix[target, source] represents edge from source -> target
    
    Returns:
        list: Topological order of node indices
    """
    n_nodes = adjacency_matrix.shape[0]
    edge_mask = (adjacency_matrix != 0).cpu()  # boolean edge mask, moved to CPU once
    
    # In-degree of each node = number of edges pointing TO it (non-zeros in its row)
    in_degree = edge_mask.sum(dim=1).tolist()
    
    # Start with nodes that have no incoming edges (FIFO queue)
    queue = deque(i for i in range(n_nodes) if in_degree[i] == 0)
    topo_order = []
    
    while queue:
        node = queue.popleft()
        topo_order.append(node)
        
        # Column `node` marks the targets of this node's outgoing edges
        for target in torch.nonzero(edge_mask[:, node]).flatten().tolist():
            in_degree[target] -= 1
            if in_degree[target] == 0:
                queue.append(target)
    
    if len(topo_order) != n_nodes:
        print(f"Warning: Graph contains a cycle! Only {len(topo_order)}/{n_nodes} nodes ordered.")
        ordered = set(topo_order)
        topo_order.extend(i for i in range(n_nodes) if i not in ordered)
    
    return topo_order

adjacency_matrix = graph.adjacency_matrix

print("Computing topological order of the attribution graph...")
topo_order = compute_topological_order(adjacency_matrix)
print(f"Topological order computed: {len(topo_order)} nodes")

Computing topological order of the attribution graph...
Topological order computed: 8499 nodes


## 4.6. Find Complete Paths: Embedding → Austin Logit

Find the top-10 most influential complete paths from input embedding tokens to the Austin logit prediction.
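Concretely, a path's score is the product of the absolute weights of its edges. A minimal sketch of that scoring (`path_score` is a hypothetical helper, not part of circuit-tracer) also shows how quickly the product grows when many edges exceed 1 in magnitude: for instance, 26 edges of weight 6 already give about 1.7e20, the same order of magnitude as the scores reported below.

```python
import math


def path_score(edge_weights):
    """Score of one path: product of the absolute edge weights along it."""
    return math.prod(abs(w) for w in edge_weights)


print(path_score([2.0, -3.0, 0.5]))       # 3.0
print(f"{path_score([6.0] * 26):.2e}")    # ~1.7e20
```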

In [9]:
from dataclasses import dataclass
from typing import List
import pandas as pd
from tqdm import tqdm

@dataclass
class Path:
    """Represents a path through the attribution graph."""
    nodes: List[int]
    edges: List[float]
    score: float
    
    def __len__(self):
        return len(self.nodes)
    
    def get_node_types(self, graph) -> List[str]:
        """Return node types: 'feature', 'error', 'embed', 'logit'."""
        n_features = len(graph.selected_features)
        n_errors = graph.cfg.n_layers * graph.n_pos
        n_embeds = graph.n_pos
        
        types = []
        for node in self.nodes:
            if node < n_features:
                types.append('feature')
            elif node < n_features + n_errors:
                types.append('error')
            elif node < n_features + n_errors + n_embeds:
                types.append('embed')
            else:
                types.append('logit')
        return types
    
    def get_node_descriptions(self, graph, tokenizer) -> List[str]:
        """Return human-readable descriptions for each node."""
        descriptions = []
        n_features = len(graph.selected_features)
        n_errors = graph.cfg.n_layers * graph.n_pos
        n_embeds = graph.n_pos
        
        for node in self.nodes:
            if node < n_features:
                layer, pos, feat_idx = graph.active_features[graph.selected_features[node]].tolist()
                token = tokenizer.decode([graph.input_tokens[pos]])
                descriptions.append(f"Feature L{layer}:F{feat_idx} @ pos {pos} ('{token}')")
            elif node < n_features + n_errors:
                error_idx = node - n_features
                layer = error_idx // graph.n_pos
                pos = error_idx % graph.n_pos
                token = tokenizer.decode([graph.input_tokens[pos]])
                descriptions.append(f"Error L{layer} @ pos {pos} ('{token}')")
            elif node < n_features + n_errors + n_embeds:
                pos = node - n_features - n_errors
                token = tokenizer.decode([graph.input_tokens[pos]])
                descriptions.append(f"Embedding @ pos {pos} ('{token}')")
            else:
                logit_idx = node - n_features - n_errors - n_embeds
                token = tokenizer.decode([graph.logit_tokens[logit_idx]])
                prob = graph.logit_probabilities[logit_idx].item()
                descriptions.append(f"Logit '{token}' (p={prob:.4f})")
        
        return descriptions


def find_k_best_paths_dp(adj_matrix, source_nodes, sink_node, topo_order, k=10, verbose=True):
    """OPTIMIZED: Find top-K source -> sink paths using DP over a DAG.

    A path's score is the product of |edge weight| along it. Each node keeps at most
    k lightweight references (successor, successor_ref_idx, edge_weight, score).
    successor_ref_idx records which of the successor's own top-k paths is extended,
    so full paths can be reconstructed exactly at the end.
    """
    adj_cpu = adj_matrix.detach().cpu()  # move once; per-element .item() on GPU is slow
    best_path_refs = {sink_node: [(None, None, None, 1.0)]}  # sentinel: path ends here
    
    iterator = reversed(topo_order)
    if verbose:
        iterator = tqdm(iterator, desc="DP path finding (optimized)", total=len(topo_order))
    
    for node in iterator:
        if node == sink_node:
            continue
        
        # Column `node` holds this node's outgoing edges (adj[target, source] convention)
        outgoing_weights = adj_cpu[:, node]
        successors = torch.nonzero(outgoing_weights).flatten().tolist()
        
        candidate_refs = []
        for succ_idx in successors:
            succ_refs = best_path_refs.get(succ_idx)
            if not succ_refs:
                continue  # successor cannot reach the sink
            
            edge_weight = outgoing_weights[succ_idx].item()
            for succ_ref_idx, (_, _, _, path_score) in enumerate(succ_refs):
                new_score = abs(edge_weight) * path_score
                candidate_refs.append((succ_idx, succ_ref_idx, edge_weight, new_score))
        
        candidate_refs.sort(key=lambda x: x[3], reverse=True)
        best_path_refs[node] = candidate_refs[:k]
    
    def reconstruct_path(start_node, path_ref_index):
        """Reconstruct the full path by following the stored (successor, ref index) chain."""
        nodes, edges = [start_node], []
        score = best_path_refs[start_node][path_ref_index][3]
        current_node, current_ref_idx = start_node, path_ref_index
        
        while True:
            next_node, next_ref_idx, edge_weight, _ = best_path_refs[current_node][current_ref_idx]
            if next_node is None:  # reached the sink sentinel
                break
            nodes.append(next_node)
            edges.append(edge_weight)
            current_node, current_ref_idx = next_node, next_ref_idx
        
        return Path(nodes=nodes, edges=edges, score=score)
    
    all_source_paths = []
    for source in source_nodes:
        for path_idx in range(len(best_path_refs.get(source, []))):
            all_source_paths.append(reconstruct_path(source, path_idx))
    
    all_source_paths.sort(key=lambda p: p.score, reverse=True)
    return all_source_paths[:k]


print("✅ Path finding functions loaded (optimized version)")

✅ Path finding functions loaded (optimized version)


In [10]:
# Define node indices
n_features = len(graph.selected_features)
n_errors = graph.cfg.n_layers * graph.n_pos
n_embeds = graph.n_pos
n_logits = len(graph.logit_tokens)

embed_start = n_features + n_errors
embed_end = embed_start + n_embeds
embed_nodes = list(range(embed_start, embed_end))
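# Logit nodes follow the embeddings; locate ' Austin' among the selected logits
logit_strs = [model.tokenizer.decode([t]) for t in graph.logit_tokens.tolist()]
austin_logit = embed_end + logit_strs.index(' Austin')  # 8489 here, since ' Austin' is the top logit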

print(f"Finding top-10 complete paths: Embeddings [{embed_start}:{embed_end}] → Austin [{austin_logit}]")
print()

# Find complete paths
complete_paths = find_k_best_paths_dp(
    adj_matrix=adjacency_matrix,
    source_nodes=embed_nodes,
    sink_node=austin_logit,
    topo_order=topo_order,
    k=10,
    verbose=True
)

print(f"\n✅ Found {len(complete_paths)} complete paths!")
print()

# Display paths
for rank, path in enumerate(complete_paths, 1):
    node_descs = path.get_node_descriptions(graph, model.tokenizer)
    node_types = path.get_node_types(graph)
    
    print(f"Path #{rank} (Score: {path.score:.8f}, Length: {len(path)})")
    print(f"  {' → '.join(node_types)}")
    print(f"  Start: {node_descs[0]}")
    print(f"  End: {node_descs[-1]}")
    print()

Finding top-10 complete paths: Embeddings [8478:8489] → Austin [8489]



DP path finding (optimized): 100%|██████████| 8499/8499 [03:10<00:00, 44.58it/s] 


✅ Found 10 complete paths!

Path #1 (Score: 486778389695780618240.00000000, Length: 27)
  embed → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → logit
  Start: Embedding @ pos 2 (':')
  End: Logit ' Austin' (p=0.4453)

Path #2 (Score: 390916144627167264768.00000000, Length: 27)
  embed → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → logit
  Start: Embedding @ pos 2 (':')
  End: Logit ' Austin' (p=0.4453)

Path #3 (Score: 363696143468630310912.00000000, Length: 27)
  embed → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → feature → fea




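All ten paths above score in the 1e20 range and span nearly every layer, which suggests the product of raw edge weights favours long paths whose edges exceed 1 in magnitude. One option, not used in the analysis above and not validated on this graph, is to take absolute weights and normalize each row so a node's incoming weights sum to 1 before running the DP; every path product then stays in [0, 1] and reads as a fraction of direct influence. A pure-Python sketch on a toy matrix (`normalize_rows` is a hypothetical helper):

```python
def normalize_rows(adj):
    """Row-normalize |adj| so each target's incoming weights sum to 1.

    Rows are targets (adj[target][source]); rows with no incoming edges stay zero.
    """
    normalized = []
    for row in adj:
        total = sum(abs(w) for w in row)
        normalized.append([abs(w) / total if total > 0 else 0.0 for w in row])
    return normalized


toy_adj = [
    [0.0, 0.0, 0.0],
    [6.0, 0.0, 0.0],   # node 1 has a single incoming edge, from node 0
    [2.0, -6.0, 0.0],  # node 2 gets 2.0 from node 0 and -6.0 from node 1
]
norm = normalize_rows(toy_adj)
print(norm[2])  # [0.25, 0.75, 0.0]
# Path 0 -> 1 -> 2 now scores norm[1][0] * norm[2][1] = 1.0 * 0.75 instead of 36.0
```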
## 4.7. Export Paths to DataFrame

Create a structured DataFrame with path details for analysis and export to CSV.

In [11]:
def paths_to_dataframe(paths, graph, tokenizer):
    """Convert list of Path objects to pandas DataFrame."""
    rows = []
    
    for rank, path in enumerate(paths, 1):
        node_types = path.get_node_types(graph)
        node_descs = path.get_node_descriptions(graph, tokenizer)
        
        rows.append({
            'rank': rank,
            'influence_score': path.score,
            'length': len(path),
            'start_type': node_types[0],
            'end_type': node_types[-1],
            'start_description': node_descs[0],
            'end_description': node_descs[-1],
            'path_summary': ' → '.join(node_types),
            'full_path': ' → '.join(node_descs),
            'min_edge_weight': min(abs(w) for w in path.edges) if path.edges else 0.0,
            'max_edge_weight': max(abs(w) for w in path.edges) if path.edges else 0.0,
            'avg_edge_weight': sum(abs(w) for w in path.edges) / len(path.edges) if path.edges else 0.0,
        })
    
    return pd.DataFrame(rows)


# Create DataFrame
df_complete = paths_to_dataframe(complete_paths, graph, model.tokenizer)

print("Complete Paths DataFrame:")
print(df_complete[['rank', 'influence_score', 'length', 'path_summary']])
print()

# Save to CSV
df_complete.to_csv('complete_paths_austin.csv', index=False)
print("✅ Saved to complete_paths_austin.csv")

Complete Paths DataFrame:
   rank  influence_score  length  \
0     1     4.867784e+20      27   
1     2     3.909161e+20      27   
2     3     3.636961e+20      27   
3     4     3.383220e+20      26   
4     5     3.014183e+20      27   
5     6     2.922102e+20      27   
6     7     2.920727e+20      27   
7     8     2.798841e+20      27   
8     9     2.787846e+20      27   
9    10     2.716956e+20      26   

                                        path_summary  
0  embed → feature → feature → feature → feature ...  
1  embed → feature → feature → feature → feature ...  
2  embed → feature → feature → feature → feature ...  
3  embed → feature → feature → feature → feature ...  
4  embed → feature → feature → feature → feature ...  
5  embed → feature → feature → feature → feature ...  
6  embed → feature → feature → feature → feature ...  
7  embed → feature → feature → feature → feature ...  
8  embed → feature → feature → feature → feature ...  
9  embed → feature → featur

## 5. Save the Graph

Save the attribution graph to a .pt file for later use.

In [None]:
# Create output directory and save graph
graph_dir = Path('graphs')
graph_dir.mkdir(exist_ok=True)

graph_name = 'dallas_capital_attribution.pt'
graph_path = graph_dir / graph_name

print(f"Saving graph to {graph_path}...")
graph.to_pt(graph_path)
print(f"Graph saved successfully! (Size: {graph_path.stat().st_size / 1024 / 1024:.2f} MB)")

## 6. Create Visualization Files

Generate graph files for interactive visualization. The pruning thresholds control how much of the graph to keep:
- `node_threshold`: Keep the smallest set of nodes whose cumulative influence reaches this fraction of the total
- `edge_threshold`: Keep the smallest set of edges whose cumulative influence reaches this fraction of the total
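The cumulative-threshold rule can be sketched on plain per-node scores. This is a simplified illustration under the assumption that items are ranked by influence and kept until their share of the total reaches the threshold; circuit-tracer computes the influence scores itself and its exact pruning details may differ. `keep_by_cumulative_share` is a hypothetical helper.

```python
def keep_by_cumulative_share(scores, threshold):
    """Indices of the smallest set of items whose share of the total score reaches `threshold`.

    Items are taken in decreasing order of score (ties keep their original order).
    """
    total = sum(scores)
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    kept, running = [], 0.0
    for i in order:
        if total > 0 and running / total >= threshold:
            break  # the items kept so far already cover the threshold
        kept.append(i)
        running += scores[i]
    return sorted(kept)


influence = [5.0, 1.0, 3.0, 0.5, 0.5]  # toy per-node influence values
print(keep_by_cumulative_share(influence, 0.8))   # [0, 2]
print(keep_by_cumulative_share(influence, 0.98))  # [0, 1, 2, 3, 4]
```

A higher threshold keeps more of the graph: at 0.8 two of the five toy nodes suffice, while 0.98 keeps all of them.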

In [None]:
slug = "dallas-capital"  # Name for this graph
graph_file_dir = './graph_files'
node_threshold = 0.8  # Keep nodes explaining 80% of influence
edge_threshold = 0.98  # Keep edges explaining 98% of influence

print(f"Creating visualization files with slug '{slug}'...")
print(f"Node threshold: {node_threshold}, Edge threshold: {edge_threshold}")

create_graph_files(
    graph_or_path=graph_path,
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold
)

print(f"Visualization files created in {graph_file_dir}/")

## 7. Launch Visualization Server

Start a local server to interactively explore the attribution graph.

**Features:**
- Click to select nodes
- Ctrl/Cmd+Click to pin/unpin nodes to your subgraph
- G+Click on nodes to group them into supernodes
- Edit node descriptions by clicking the edit button

In [None]:
from circuit_tracer.frontend.local_server import serve
from IPython.display import IFrame, display

port = 8046
print(f"Starting visualization server on port {port}...")
server = serve(data_dir='./graph_files/', port=port)

print(f"\nVisualization server is running!")
print(f"Open your graph here: http://localhost:{port}/index.html")
print(f"\nTo stop the server later, run: server.stop()")

# Display in iframe
display(IFrame(src=f'http://localhost:{port}/index.html', width='100%', height='800px'))

## 8. Stop the Server (Optional)

Uncomment and run this cell when you're done with visualization.

In [None]:
# server.stop()
# print("Server stopped.")