# Attribution Graph for Dallas Capital Query

This notebook creates an attribution graph for the sentence:
**"Fact: The capital of the state containing Dallas is"**

We'll use the Gemma-2 (2B) model with GemmaScope transcoders to analyze the circuit.

## Setup

In [1]:
from pathlib import Path
import torch as t
from bs4 import BeautifulSoup
import requests
from matplotlib import pyplot as plt
import einops
import torch as t
from math import perm, comb
from tqdm import tqdm
from huggingface_hub import login
from dotenv import load_dotenv
import os

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files

In [2]:
# Pull HF_TOKEN from a local .env file (if one exists), then authenticate with the Hub.
load_dotenv()
login(token=os.environ['HF_TOKEN'])

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


### Load the model + run attribution

In [3]:
model_name = 'google/gemma-2-2b'
transcoder_name = "gemma"  # GemmaScope transcoders

print(f"Loading {model_name} with {transcoder_name} transcoders...")
# bfloat16 weights; `lazy_encoder=True` is forwarded to circuit_tracer as-is
# (presumably defers materialising transcoder encoders — confirm in circuit_tracer docs).
model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=t.bfloat16, lazy_encoder=True)
print("Model loaded successfully!")

Loading google/gemma-2-2b with gemma transcoders...


Fetching 26 files:   0%|          | 0/26 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2-2b into HookedTransformer
Model loaded successfully!


In [4]:
# Attribution parameters
prompt = "Fact: The capital of the state containing Dallas is"
# Logit selection: at most `max_n_logits` logits, aiming for `desired_logit_prob`
# cumulative probability (the run below stops at the 10-logit cap first).
max_n_logits = 10
desired_logit_prob = 0.95
max_feature_nodes = 8192  # None for no limit, but will be slower
batch_size = 256
offload = 'cpu'  # Use 'disk' if running out of memory, None to keep everything on GPU
verbose = True

# Echo the configuration, one "label: value" line per setting.
print("\n".join(f"{label}: {value}" for label, value in (
    ("Prompt", prompt),
    ("Max logits", max_n_logits),
    ("Desired logit probability", desired_logit_prob),
    ("Max feature nodes", max_feature_nodes),
    ("Batch size", batch_size),
    ("Offload strategy", offload),
)))

Prompt: Fact: The capital of the state containing Dallas is
Max logits: 10
Desired logit probability: 0.95
Max feature nodes: 8192
Batch size: 256
Offload strategy: cpu


In [5]:
print("\nRunning attribution...\n")
# All settings come from the parameter cell above.
graph = attribute(
    prompt=prompt,
    model=model,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    offload=offload,
    verbose=verbose,
)
print("\nAttribution complete!")

Phase 0: Precomputing activations and vectors



Running attribution...



Precomputation completed in 0.41s
Found 9081 active features
Phase 1: Running forward pass
Forward pass completed in 0.08s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.7695
Will include 8192 of 9081 feature nodes
Input vectors built in 1.25s
Phase 3: Computing logit attributions
Logit attributions completed in 0.08s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 8192/8192 [00:02<00:00, 3312.38it/s]
Feature attributions completed in 2.48s
Attribution completed in 8.33s



Attribution complete!


## Post Processing

### Save & Visualize

In [None]:
from pathlib import Path as LibPath

In [None]:
# Persist the attribution graph under ./graphs/ so it can be reloaded without re-running attribution.
graph_name = 'dallas_capital_attribution.pt'
graph_dir = LibPath('graphs')
graph_path = graph_dir / graph_name
graph_dir.mkdir(exist_ok=True)

print(f"Saving graph to {graph_path}...")
graph.to_pt(graph_path)
print(f"Graph saved successfully! (Size: {graph_path.stat().st_size / (1024 * 1024):.2f} MB)")

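Pruning thresholds used by `create_graph_files`:

- `node_threshold`: keep the smallest set of nodes whose cumulative influence reaches at least this fraction of the total
- `edge_threshold`: keep the smallest set of edges whose cumulative influence reaches at least this fraction of the total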

In [None]:
# Output settings for the frontend files.
slug = "dallas-capital"  # Name for this graph
graph_file_dir = './graph_files'
# Pruning: fraction of total influence that the kept nodes / edges must explain.
node_threshold = 0.8  # Keep nodes explaining 80% of influence
edge_threshold = 0.98  # Keep edges explaining 98% of influence

print(f"Creating visualization files with slug '{slug}'...")
print(f"Node threshold: {node_threshold}, Edge threshold: {edge_threshold}")

create_graph_files(
    graph_or_path=graph_path,
    slug=slug,
    output_path=graph_file_dir,
    edge_threshold=edge_threshold,
    node_threshold=node_threshold,
)

print(f"Visualization files created in {graph_file_dir}/")

In [None]:
from circuit_tracer.frontend.local_server import serve
from IPython.display import IFrame, display

# Port for the local frontend server; the URL is built once and reused below.
port = 8047
viz_url = f'http://localhost:{port}/index.html'

print(f"Starting visualization server on port {port}...")
# Serve the directory the previous cell wrote to, rather than repeating the path
# literal; os.path.join(..., '') keeps the trailing slash the server was given before.
server = serve(data_dir=os.path.join(graph_file_dir, ''), port=port)

print("\nVisualization server is running!")
print(f"Open your graph here: {viz_url}")
print("\nTo stop the server later, run: server.stop()")

# Display in iframe
display(IFrame(src=viz_url, width='100%', height='800px'))

In [None]:
# Shut down the local visualization server started earlier in the notebook.
server.stop()

## Utilities

### Display Graph Stats

In [None]:
# Raw sizes of the main tensors stored on the attribution graph.
print('\n'.join([
    f'number of active features: {len(graph.active_features)}',
    f'length of adjacency matrix: {len(graph.adjacency_matrix)}',
    f'number of "activation values": {len(graph.activation_values)}',
]))

In [None]:
print("\n".join(["=" * 60, "GRAPH STATISTICS", "=" * 60]))

# Input information
print(f"\nInput String: {graph.input_string}")
print(f"Input Tokens: {graph.input_tokens.tolist()}")
print(f"Number of positions: {graph.n_pos}")

# Feature information
print(f"\nTotal active features: {len(graph.active_features)}")
print(f"Selected features for graph: {len(graph.selected_features)}")

# Node structure: one error node per (layer, position), one embedding per position,
# one node per selected logit, plus the selected feature nodes.
n_layers = graph.cfg.n_layers
n_pos = graph.n_pos
n_error_nodes = n_layers * n_pos
n_embed_nodes = n_pos
n_logit_nodes = len(graph.logit_tokens)
total_nodes = len(graph.selected_features) + n_error_nodes + n_embed_nodes + n_logit_nodes

print("\n".join([
    "\nGraph Structure:",
    f"  Feature nodes: {len(graph.selected_features)}",
    f"  Error nodes: {n_error_nodes} ({n_layers} layers × {n_pos} positions)",
    f"  Embedding nodes: {n_embed_nodes}",
    f"  Logit nodes: {n_logit_nodes}",
    f"  Total nodes: {total_nodes}",
]))

# Edge information (any non-zero entry counts as an edge, including negative ones)
adjacency_matrix = graph.adjacency_matrix
total_edges = (adjacency_matrix != 0).sum().item()
print(f"\nTotal non-zero edges: {total_edges:,}")
print(f"Adjacency matrix shape: {adjacency_matrix.shape}")
print(f"Adjacency matrix density: {total_edges / (adjacency_matrix.shape[0] * adjacency_matrix.shape[1]) * 100:.2f}%")

# Top logits, in the order stored on the graph
print(f"\nTop {len(graph.logit_tokens)} predicted logits:")
for i, (token_id, prob) in enumerate(zip(graph.logit_tokens, graph.logit_probabilities)):
    token_str = model.tokenizer.decode([token_id.item()])
    print(f"  {i+1}. '{token_str}' (token {token_id.item()}) - probability: {prob.item():.4f}")

print("\n" + "=" * 60)

### Calculate the Number of Paths

In [None]:
def _calculate_paths_length_base(i, n_tokens):
    if i == 2:
        return n_tokens

    final_length = 0
    for b_prime in range(n_tokens, 0, -1):
        final_length += _calculate_paths_length_base(i-1, b_prime)
    
    return final_length


def calculate_paths_simple_graph(layers, n_tokens, max_path_length=None) -> list[int]:
    '''
    Calculate the number of possible paths from source to sink for a 'simple attribution graph DAG'.

    A path of length `i` (counted in nodes) starts at an embedding, ends at the
    sink, and visits `i - 2` of the `layers - 2` intermediate layers. Its count is
    `comb(layers - 2, i - 2)` (which layers are visited) times the number of
    token-position sequences from `_calculate_paths_length_base(i, n_tokens)`.

    :param layers: The number of layers in the model, including the embeddings & logits.
    :param n_tokens: The number of input tokens (ie. the 'base' of the attribution graph).
    :param max_path_length: Longest path length (in nodes) to count; defaults to `layers`.
    :returns: A list with the number of complete paths of increasing lengths
        (from 2 to max_path_length, inclusive).
    :raises ValueError: if `max_path_length` is greater than `layers`.
    '''
    if max_path_length is None:
        max_path_length = layers
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if max_path_length > layers:
        raise ValueError(
            f'max_path_length ({max_path_length}) cannot exceed layers ({layers})'
        )

    return [
        comb(layers - 2, i - 2) * _calculate_paths_length_base(i, n_tokens)
        for i in range(2, max_path_length + 1)
    ]

# Path counts for lengths 2..6 in a simple graph with 26 layers and 10 token positions.
# NOTE(review): `layers` is documented as including embeddings & logits, while the model
# has 26 transformer layers — confirm whether 28 was intended.
all_paths = calculate_paths_simple_graph(layers=26, n_tokens=10, max_path_length=6)

### Generate + Visualize Dummy Graph

In [None]:
def generate_adjacency_matrix_simple(n, b):
    '''
    Build the adjacency matrix of a dense "simple" attribution-graph DAG.

    The graph is an `n` x `b` grid of nodes (layer x token position) plus one sink.
    Grid node `k` sits at layer `k // b`, position `k % b`. Every grid node feeds
    every node in a strictly later layer at the same or a later token position,
    and every grid node feeds the sink.

    The matrix uses the `[target, source]` convention of the rest of this notebook:
    `adjacency_matrix[target, source] == 1` encodes the edge source -> target.

    :param n: Number of grid layers (>= 0).
    :param b: Number of token positions per layer (>= 1).
    :returns: `(adjacency_matrix, node_info)`. `adjacency_matrix` is a
        `(n*b + 1, n*b + 1)` tensor of the default float dtype; `node_info` maps each
        node index to `(layer, token_pos)`, with the sink placed at `(n, b - 1)`.
    :raises ValueError: if `n` < 0 or `b` < 1.
    '''
    if n < 0 or b < 1:
        raise ValueError(f'need n >= 0 and b >= 1, got n={n}, b={b}')

    n_grid = n * b
    total_nodes = n_grid + 1

    node_ids = t.arange(n_grid)
    layer_of = node_ids // b
    pos_of = node_ids % b

    # Vectorised edge rule: target layer strictly later AND target position >= source position.
    # Assigning the block directly also avoids writing through a chained-index view.
    grid_edges = (layer_of[:, None] > layer_of[None, :]) & (pos_of[:, None] >= pos_of[None, :])

    adjacency_matrix = t.zeros([total_nodes, total_nodes])
    adjacency_matrix[:n_grid, :n_grid] = grid_edges.float()
    adjacency_matrix[-1, :-1] = 1  # the sink receives an edge from every grid node

    node_info = {k: (k // b, k % b) for k in range(n_grid)}
    node_info[total_nodes - 1] = (n, b - 1)

    return adjacency_matrix, node_info

# Dummy 25-layer x 10-position graph for the interactive visualization.
# NOTE: this rebinds `adjacency_matrix`, which the graph-statistics cell also assigns.
adjacency_matrix, node_info = generate_adjacency_matrix_simple(b=10, n=25)

In [None]:
"""
Interactive visualization of attribution graphs with hover functionality.
"""
import numpy as np
import plotly.graph_objects as go

def get_node_label(node_idx: int, node_info: dict) -> str:
    """Return a display label for a node: "Output", "Input[pos]" or "L<layer>[pos]".

    The highest layer present anywhere in `node_info` is treated as the output (sink)
    layer, and layer 0 as the input layer.
    """
    layer, pos = node_info[node_idx]
    top_layer = max(layer_id for layer_id, _ in node_info.values())

    if layer == top_layer:
        return "Output"
    if layer == 0:
        return f"Input[{pos}]"
    return f"L{layer}[{pos}]"

def visualize_attribution_graph(
    adj_matrix: "np.ndarray | t.Tensor",
    node_info: dict,
    title: str = "Attribution Graph"
):
    """
    Create an interactive Plotly figure of an attribution graph.

    Nodes are placed on a grid (x = token position * 100, y = layer * 100) and every
    positive entry of `adj_matrix` is drawn as a thin grey line. Hovering a node shows
    its label plus up to 10 labels each of the nodes it is influenced by (incoming
    edges) and the nodes it feeds into (outgoing edges).

    :param adj_matrix: Square `[target, source]` matrix; the edge source -> target is
        drawn when `adj_matrix[target, source] > 0`. A numpy array or a torch tensor
        (the boolean edge mask is moved to CPU before scanning).
    :param node_info: Maps every index in `range(adj_matrix.shape[0])` to `(layer, pos)`.
    :param title: Figure title.
    :returns: A `plotly.graph_objects.Figure`; call `.show()` to render it.
    """
    n_nodes = adj_matrix.shape[0]

    # Threshold once and scan the mask with numpy instead of indexing the matrix
    # element-by-element from Python for every (target, source) pair.
    positive = adj_matrix > 0
    if hasattr(positive, "cpu"):  # torch tensor, possibly on an accelerator
        positive = positive.cpu()
    positive = np.asarray(positive)

    # Compute node positions (position determines x, layer determines y)
    node_positions = {
        node_idx: (pos * 100, layer * 100)
        for node_idx, (layer, pos) in node_info.items()
    }
    # Labels are needed for node text and for every hover list, so build them once.
    node_text = [get_node_label(node_idx, node_info) for node_idx in range(n_nodes)]

    # Edge segments; the `None` entries break the polyline between segments.
    # np.argwhere yields (target, source) pairs in row-major order, i.e. the same order
    # as a nested loop over targets then sources.
    edge_x = []
    edge_y = []
    for target_idx, source_idx in np.argwhere(positive).tolist():
        x0, y0 = node_positions[source_idx]
        x1, y1 = node_positions[target_idx]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    # Edge trace
    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=0.5, color='#888'),
        hoverinfo='skip',
        showlegend=False
    )

    # Prepare node data
    node_x = []
    node_y = []
    node_hover_info = []

    for node_idx in range(n_nodes):
        x, y = node_positions[node_idx]
        node_x.append(x)
        node_y.append(y)

        # Row `node_idx` holds this node's sources; column `node_idx` holds its targets.
        incoming_labels = [node_text[i] for i in np.flatnonzero(positive[node_idx, :])]
        outgoing_labels = [node_text[i] for i in np.flatnonzero(positive[:, node_idx])]

        hover_text = f"<b>{node_text[node_idx]}</b><br>"
        hover_text += f"<br><b>Influenced by ({len(incoming_labels)}):</b><br>"
        hover_text += "<br>".join(incoming_labels[:10])  # Limit to first 10
        if len(incoming_labels) > 10:
            hover_text += f"<br>... and {len(incoming_labels) - 10} more"

        hover_text += f"<br><br><b>Feeds into ({len(outgoing_labels)}):</b><br>"
        hover_text += "<br>".join(outgoing_labels[:10])
        if len(outgoing_labels) > 10:
            hover_text += f"<br>... and {len(outgoing_labels) - 10} more"

        node_hover_info.append(hover_text)

    # Node trace
    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        marker=dict(
            size=20,
            color='lightblue',
            line=dict(width=2, color='darkblue')
        ),
        text=node_text,
        textposition="top center",
        textfont=dict(size=10),
        hovertext=node_hover_info,
        hoverinfo='text',
        showlegend=False
    )

    # Create figure
    fig = go.Figure(data=[edge_trace, node_trace])

    fig.update_layout(
        title=title,
        showlegend=False,
        hovermode='closest',
        xaxis=dict(
            title='Token Position',
            showgrid=True,
            zeroline=False,
            showticklabels=True
        ),
        yaxis=dict(
            title='Layer',
            showgrid=True,
            zeroline=False,
            showticklabels=True
        ),
        plot_bgcolor='white',
        width=1200,
        height=800
    )

    return fig

print("\nCreating interactive visualization...")
# `adjacency_matrix` / `node_info` here are the dummy graph built by generate_adjacency_matrix_simple.
fig = visualize_attribution_graph(adjacency_matrix, node_info)
fig.show()


## Core Pipeline

### Functions

In [6]:
# Section sizes of the adjacency matrix, used as index offsets by the helpers below.
n_features = len(graph.selected_features)
n_embed_nodes = len(graph.input_tokens)
n_error_nodes = model.cfg.n_layers * n_embed_nodes  # one error node per (layer, position)
n_logit_nodes = len(graph.logit_tokens)

In [17]:
# Probabilities of the logit tokens selected for attribution (same order as `graph.logit_tokens`).
graph.logit_probabilities

tensor([0.4453, 0.0776, 0.0532, 0.0415, 0.0366, 0.0286, 0.0251, 0.0251, 0.0197,
        0.0153], device='cuda:0', dtype=torch.bfloat16)

In [25]:
def get_feature_details(matrix_idx: int, attribution_graph=None) -> tuple[int, int, int]:
    '''
    Map the adjacency-matrix index of a feature node to `(layer, token_pos, feature_idx)`.

    Feature nodes occupy the first `len(graph.selected_features)` indices of the
    adjacency matrix. `selected_features[matrix_idx]` indexes a row of
    `active_features`, which is unpacked as `(layer, token_pos, feature_idx)`;
    `feature_idx` is what `matrix_idx_to_explanation` uses as the Neuronpedia index.

    :param matrix_idx: Adjacency-matrix index; must refer to a feature node.
    :param attribution_graph: Graph to read from; defaults to the notebook-level `graph`.
    :returns: `(layer, token_pos, feature_idx)` as Python ints.
    :raises IndexError: if `matrix_idx` is negative or not a feature node.
    '''
    g = graph if attribution_graph is None else attribution_graph

    # Negative indices would silently wrap around to a different feature, so reject them too.
    if not 0 <= matrix_idx < len(g.selected_features):
        raise IndexError('This node is not an active feature')

    feature_idx = g.selected_features[matrix_idx]
    layer, token_pos, attribution_idx = g.active_features[feature_idx]

    return (layer.item(), token_pos.item(), attribution_idx.item())

def get_node_details(node: int) -> tuple[int, int]:
    '''
    Get the layer and token position of any node in the adjacency matrix.

    The node may be a feature, an error, an embedding, or a logit node; the matrix
    is laid out as [features | errors | embeddings | logits].

    Layers range from -1 to `model.cfg.n_layers` (26 for gemma-2-2b):
      - layer -1 is the embedding,
      - layer `model.cfg.n_layers` is the logits.

    Token positions range from 0 to `len(graph.input_tokens) - 1`. Position 0 (BOS)
    is included: every input token, BOS too, has an embedding node and one error
    node per layer. All logit nodes are reported at the last token position.

    :raises IndexError: if `node` is outside the adjacency matrix.
    '''
    n_pos = len(graph.input_tokens)
    n_layers = model.cfg.n_layers
    n_features = len(graph.selected_features)
    n_error_nodes = n_pos * n_layers
    n_embed_nodes = n_pos
    n_logit_nodes = len(graph.logit_tokens)

    error_start = n_features
    embed_start = error_start + n_error_nodes
    logit_start = embed_start + n_embed_nodes
    end = logit_start + n_logit_nodes

    if not 0 <= node < end:
        raise IndexError(f'node {node} is outside the adjacency matrix (size {end})')

    if node < error_start:
        layer, token_pos, _ = get_feature_details(node)
    elif node < embed_start:
        # Error nodes are laid out layer-major (index = layer * n_pos + token_pos).
        # Use the real token count rather than a constant tied to one prompt's length.
        error_number = node - error_start
        token_pos = error_number % n_pos
        layer = error_number // n_pos
    elif node < logit_start:
        layer = -1
        token_pos = node - embed_start
    else:
        layer = n_layers
        token_pos = n_pos - 1

    return layer, token_pos

def _fetch_neuronpedia_description(layer: int, feature_idx: int, timeout: float = 30.0) -> str:
    '''
    Scrape the explanation text of one GemmaScope transcoder feature from Neuronpedia.

    The feature page is fetched as HTML and the description is cut out of the page
    body just before the `explanationModelName` field. This depends on the current
    Neuronpedia page layout and returns the raw, still-escaped text.

    :param layer: Transcoder layer of the feature.
    :param feature_idx: Feature index within that layer's transcoder.
    :param timeout: Seconds to wait for the HTTP response.
    :raises requests.HTTPError: on a non-2xx response.
    :raises ValueError: if the expected markers are missing from the page.
    '''
    url = f'https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-transcoder-16k/{feature_idx}'
    # Without a timeout a stalled connection would hang the notebook indefinitely.
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    soup = BeautifulSoup(response.text, 'html.parser')

    html = soup.find('html')
    body = html.find('body') if html is not None else None
    if body is None:
        raise ValueError(f'no <body> element in {url}')
    body_text = str(body)

    marker_idx = body_text.find('explanationModelName')
    if marker_idx == -1:
        raise ValueError(f"'explanationModelName' not found in {url}")
    # The description sits in the ~200 characters before the marker; clamp at 0 so a
    # marker near the start cannot turn the slice start negative and wrap around.
    window = body_text[max(marker_idx - 200, 0):marker_idx]
    desc_idx = window.find('description')
    if desc_idx == -1:
        raise ValueError(f"'description' not found near 'explanationModelName' in {url}")

    # Skip 'description' plus (presumably) its escaped `\":\"` separator (16 chars total)
    # and drop the trailing escaped `\",\"` (5 chars) before the next field.
    prefix_len = 16
    suffix_len = 5
    return window[desc_idx + prefix_len: -suffix_len]


def matrix_idx_to_explanation(matrix_idx: int):
    '''
    Return a human-readable description of an adjacency-matrix node.

    Uses the notebook-level offsets `n_features`, `n_error_nodes`, `n_embed_nodes`
    and `n_logit_nodes` (layout [features | errors | embeddings | logits]):
      - feature node: its Neuronpedia explanation (one HTTP request),
      - error node: a placeholder string,
      - embedding node: `'embed: <token>'`,
      - logit node: `'logit: <token> (<probability>)'`.

    :raises IndexError: if `matrix_idx` is outside the adjacency matrix.
    '''
    embed_start = n_features + n_error_nodes
    logit_start = embed_start + n_embed_nodes
    end = logit_start + n_logit_nodes

    if not 0 <= matrix_idx < end:
        raise IndexError('Incorrect Matrix Idx')

    if matrix_idx < n_features:
        layer, _, feature_idx = get_feature_details(matrix_idx)
        return _fetch_neuronpedia_description(layer, feature_idx)

    if matrix_idx < embed_start:
        return 'error node – this should NOT be possible'

    if matrix_idx < logit_start:
        embed_num = matrix_idx - embed_start
        embed_token = model.tokenizer.decode(graph.input_tokens[embed_num])
        return f'embed: {embed_token}'

    logit_num = matrix_idx - logit_start
    logit_token = model.tokenizer.decode(graph.logit_tokens[logit_num])
    return f'logit: {logit_token} ({graph.logit_probabilities[logit_num]})'

In [8]:
def find_all_paths_wrapper(adj_matrix, start_tokens, end_token, max_path_length=None):
    '''
    Enumerate every path from any node in `start_tokens` to `end_token` (depth-first).

    Edges use the `[target, source]` convention: the successors of node `s` are the
    rows `r` with `adj_matrix[r, s] != 0` (negative weights count as edges). A path is
    not extended past `end_token`.

    Before searching, a backward BFS computes each node's shortest distance to
    `end_token`. Successors that cannot reach `end_token` within the remaining length
    budget are skipped. This prunes only branches that could never produce a path,
    so the result (contents and order) is the same as an unpruned DFS.

    :param adj_matrix: Square 2-D torch tensor of edge weights, `[target, source]`.
    :param start_tokens: Node indices to start from (e.g. the embedding nodes).
    :param end_token: Node index every returned path ends at.
    :param max_path_length: Maximum number of nodes per path; None for no limit.
        With no limit the reachable part of the graph must be acyclic, otherwise the
        recursion is unbounded.
    :returns: List of paths (lists of node indices), in DFS order: start nodes in the
        given order, successors in ascending index order.
    '''
    n_nodes = adj_matrix.shape[0]
    device = adj_matrix.device
    unreachable = n_nodes  # strictly larger than any shortest-path length

    # Backward BFS: dist_to_end[v] = fewest edges from v to end_token.
    dist_to_end = t.full((n_nodes,), unreachable, dtype=t.long, device=device)
    if 0 <= end_token < n_nodes:
        dist_to_end[end_token] = 0
        frontier = t.tensor([int(end_token)], device=device)
        hops = 0
        while frontier.numel() > 0:
            hops += 1
            # Sources s with adj[f, s] != 0 for some frontier node f; chunked to bound memory.
            feeds_frontier = t.zeros(n_nodes, dtype=t.bool, device=device)
            for chunk in frontier.split(1024):
                feeds_frontier |= (adj_matrix[chunk] != 0).any(dim=0)
            newly_reached = feeds_frontier & (dist_to_end == unreachable)
            dist_to_end[newly_reached] = hops
            frontier = t.nonzero(newly_reached).flatten()

    can_reach_end = dist_to_end < unreachable  # loop-invariant mask for the unlimited case

    def useful_successors(node, path_len):
        has_edge = adj_matrix[:, node] != 0
        if max_path_length is None:
            keep = has_edge & can_reach_end
        else:
            # After appending a successor the path has path_len + 1 nodes, so it can only
            # be completed in time if the successor is within the remaining budget.
            keep = has_edge & (dist_to_end <= max_path_length - path_len - 1)
        return t.nonzero(keep).flatten().tolist()

    all_paths = []

    def find_all_paths(current_path):
        assert len(current_path) >= 1
        current_node = current_path[-1]

        if current_node == end_token:
            all_paths.append(current_path)
            return
        if (max_path_length is not None) and (len(current_path) >= max_path_length):
            return

        for next_node in useful_successors(current_node, len(current_path)):
            find_all_paths(current_path + [next_node])

    for node in start_tokens:
        find_all_paths([node])

    return all_paths

def score_paths(paths, adjacency_matrix):
    '''
    Score every path by the mean weight of the edges along it.

    Edge weights are read as `adjacency_matrix[target, source]`. A path of `k` nodes
    has `k - 1` edges, so a single-node path raises ZeroDivisionError.

    :returns: List of float scores aligned with `paths` (a tqdm bar tracks progress).
    '''
    def mean_edge_weight(path: list[int]):
        edge_weights = [
            adjacency_matrix[target, source].item()
            for source, target in zip(path, path[1:])
        ]
        return sum(edge_weights) / (len(path) - 1)

    return [mean_edge_weight(p) for p in tqdm(paths)]


In [27]:
# Embedding nodes occupy indices [a, b); index b is the first logit node, i.e. the
# first entry of graph.logit_tokens (the highest-probability one, " Austin").
a = n_features + n_error_nodes
b = a + n_embed_nodes

all_paths = find_all_paths_wrapper(
    graph.adjacency_matrix,
    start_tokens=list(range(a, b)),
    end_token=b,
    max_path_length=3,  # embed -> logit, or embed -> one intermediate node -> logit
)

scores = score_paths(all_paths, graph.adjacency_matrix)

# Highest mean edge weight first.
out = sorted(zip(scores, all_paths), key=lambda pair: pair[0], reverse=True)

100%|██████████| 55556/55556 [00:00<00:00, 173228.82it/s]


In [30]:
# Print the top_k highest-scoring paths with a description of every node on them.
top_k = 8
for score, path in out[:top_k]:
    print(f'score: {score}')
    print('path:')
    for node in path:
        print('    ' + matrix_idx_to_explanation(node))
    print('\n')

score: 34.3671875
path:
    embed:  Dallas
     a collection of strings that make up a Dallas news station callsign
    logit:  Austin (0.4453125)


score: 32.28662109375
path:
    embed:  state
    the word \\\"state\\\" or phrases that include \\\"state\\\"
    logit:  Austin (0.4453125)


score: 31.83984375
path:
    embed:  Dallas
     references to geographic locations, especially in addresses
    logit:  Austin (0.4453125)


score: 29.12492847442627
path:
    embed: <bos>
     words related to a variety of coding languages, science, or other languages
    logit:  Austin (0.4453125)


score: 26.622802734375
path:
    embed: <bos>
     topics that are academic in nature, including math, linguistics/anthropology, religion, zoology, history or politics
    logit:  Austin (0.4453125)


score: 24.70166015625
path:
    embed:  capital
    the word \\\"capital\\\" and sometimes letters
    logit:  Austin (0.4453125)


score: 24.265625
path:
    embed:  Dallas
     references to geographi