# Attribution Demo 

<a target="_blank" href="https://colab.research.google.com/github/safety-research/circuit-tracer/blob/main/demos/attribute_demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In this demo, you'll learn how to load models and perform attribution on them.

In [3]:
#@title Colab Setup Environment

try:
    import google.colab
    !mkdir -p repository && cd repository && \
     git clone https://github.com/safety-research/circuit-tracer && \
     curl -LsSf https://astral.sh/uv/install.sh | sh && \
     uv pip install -e circuit-tracer/

    import sys
    from huggingface_hub import notebook_login
    sys.path.append('repository/circuit-tracer')
    sys.path.append('repository/circuit-tracer/demos')
    notebook_login(new_session=False)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

In [4]:
from pathlib import Path
import torch

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils.create_graph_files import create_graph_files_topk
from circuit_tracer.graph import Graph, prune_graph, compute_graph_scores

First, load your model and transcoders by name. `model_name` is a normal HuggingFace / [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens) model name; we'll use `google/gemma-2-2b`. We set `transcoder_name` to `gemma`, which is shorthand for the [Gemma Scope](https://arxiv.org/abs/2408.05147) transcoders; we take the transcoders with lowest L0 (mean # of active features) for each layer.

We additionally support `model_name = "meta-llama/Llama-3.2-1B"`, with `"llama"` transcoders; these are ReLU skip-transcoders that we trained, available [here](https://huggingface.co/mntss/skip-transcoder-Llama-3.2-1B-131k-nobos/tree/new-training).
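
To build intuition for what a transcoder and its L0 are, here is a minimal pure-Python sketch of a ReLU transcoder with an optional linear skip term, as in a skip-transcoder. The function names and the tiny weights are hypothetical illustrations, not circuit_tracer code or real Gemma Scope / Llama weights. Note that the Gemma Scope transcoders use a JumpReLU activation rather than the plain ReLU assumed here.

```python
# Hypothetical sketch: a ReLU (skip-)transcoder forward pass and its L0, in pure Python.
# Weights are tiny made-up numbers; matrices are stored as lists of rows.

def matvec(W, v):
    """Return W @ v for a matrix stored as a list of rows."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def transcoder(x, W_enc, b_enc, W_dec, b_dec, W_skip=None):
    """features = ReLU(W_enc x + b_enc); output = W_dec features + b_dec (+ W_skip x)."""
    features = [max(0.0, a) for a in add(matvec(W_enc, x), b_enc)]
    out = add(matvec(W_dec, features), b_dec)
    if W_skip is not None:  # skip-transcoder: add a linear path from input to output
        out = add(out, matvec(W_skip, x))
    return features, out

def mean_l0(feature_vectors):
    """L0 = mean number of active (nonzero) features per input."""
    return sum(sum(1 for a in f if a != 0) for f in feature_vectors) / len(feature_vectors)

# Toy sizes: d_in = 2, n_features = 3, d_out = 2
W_enc = [[1, 0], [0, 1], [-1, -1]]
b_enc = [0.0, 0.0, 0.0]
W_dec = [[1, 0, 0], [0, 1, 0]]
b_dec = [0.0, 0.0]
W_skip = [[0.5, 0], [0, 0.5]]

f1, _ = transcoder([1.0, -1.0], W_enc, b_enc, W_dec, b_dec)
f2, y2 = transcoder([1.0, 2.0], W_enc, b_enc, W_dec, b_dec, W_skip)
print(f1, f2, y2, mean_l0([f1, f2]))  # -> [1.0, 0.0, 0.0] [1.0, 2.0, 0.0] [1.5, 3.0] 1.5
```

A lower L0 means fewer features fire per token, which generally makes the resulting attribution graphs smaller.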

If you want to use other models, you'll have to provide your own transcoders. To do this, set `transcoder_name` to point to your own configuration file, specifying the list of transcoders that you want to use. See `circuit_tracer/configs` for example configs.

In [None]:
model_name = 'google/gemma-2-2b'
transcoder_name = "gemma"  # alternative: "mntss/clt-gemma-2-2b-426k"
model = ReplacementModel.from_pretrained(model_name, transcoder_name, dtype=torch.bfloat16)


Next, set your attribution arguments.

In [4]:
prompt = "The capital of state containing Dallas is"  # What you want to get the graph for
max_n_logits = 10   # How many logits to attribute from, max. We attribute to min(max_n_logits, n_logits_to_reach_desired_log_prob); see below for the latter
desired_logit_prob = 0.95  # Attribution will attribute from the minimum number of logits needed to reach this probability mass (or max_n_logits, whichever is lower)
max_feature_nodes = 8192  # Only attribute from this number of feature nodes, max. Lower is faster, but you will lose more of the graph. None means no limit.
batch_size=256  # Batch size when attributing
offload='disk' if IN_COLAB else 'cpu' # Offload various parts of the model during attribution to save memory. Can be 'disk', 'cpu', or None (keep on GPU)
verbose = True  # Whether to display a tqdm progress bar and timing report
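
The interplay of `max_n_logits` and `desired_logit_prob` described in the comments above can be sketched as follows. `select_logits` is a hypothetical helper written for illustration; circuit_tracer's own selection code may differ in details such as tie-breaking.

```python
# Hypothetical sketch (not circuit_tracer's code) of how max_n_logits and
# desired_logit_prob jointly decide which output tokens to attribute from.

def select_logits(probs, max_n_logits=10, desired_logit_prob=0.95):
    """Take tokens in descending probability until their cumulative probability
    reaches desired_logit_prob or max_n_logits tokens have been taken."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    selected, cumulative = [], 0.0
    for i in order[:max_n_logits]:
        selected.append(i)
        cumulative += probs[i]
        if cumulative >= desired_logit_prob:
            break
    return selected, cumulative

probs = [0.5, 0.25, 0.125, 0.0625, 0.0625]  # toy next-token distribution
print(select_logits(probs, max_n_logits=10, desired_logit_prob=0.8))  # -> ([0, 1, 2], 0.875)
print(select_logits(probs, max_n_logits=2, desired_logit_prob=0.8))   # -> ([0, 1], 0.75)
```

The second call stops at the cap before reaching the probability target, which is the situation in the attribution log that follows: 10 logits were selected with a cumulative probability of only 0.7188, below `desired_logit_prob = 0.95`.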

Then, just run attribution!

In [5]:
graph = attribute(
    prompt=prompt,
    model=model,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    offload=offload,
    verbose=verbose
)

Phase 0: Precomputing activations and vectors
Precomputation completed in 0.49s
Found 6371 active features
Phase 1: Running forward pass
Forward pass completed in 0.09s
Phase 2: Building input vectors
Selected 10 logits with cumulative probability 0.7188
Will include 6371 of 6371 feature nodes
Input vectors built in 1.79s
Phase 3: Computing logit attributions
Logit attributions completed in 0.08s
Phase 4: Computing feature attributions
Feature influence computation: 100%|██████████| 6371/6371 [00:00<00:00, 10741.66it/s]
Feature attributions completed in 0.60s
Attribution completed in 8.24s


We now have a graph object! We can save it as a .pt file, but be warned that it's large (~167MB).
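
The file size follows from the dense adjacency matrix, whose side length is the total node count. Here is a back-of-the-envelope sketch, assuming 4-byte (float32) entries and plugging in this run's numbers: 6371 features and 10 logits from the log above, 26 layers for `gemma-2-2b`, and 8 input positions including the BOS token. `dense_adjacency_size` is a hypothetical helper.

```python
# Back-of-the-envelope size of the dense adjacency matrix, assuming float32 entries.

def dense_adjacency_size(n_features, n_layers, n_pos, n_logits, bytes_per_entry=4):
    """Nodes = features + error nodes (one per layer and position) + token nodes + logit nodes."""
    n_nodes = n_features + n_layers * n_pos + n_pos + n_logits
    return n_nodes, n_nodes * n_nodes * bytes_per_entry

n_nodes, n_bytes = dense_adjacency_size(n_features=6371, n_layers=26, n_pos=8, n_logits=10)
print(n_nodes, n_bytes, f"{n_bytes / 2**20:.0f} MiB")  # -> 6597 174081636 166 MiB
```

The node count matches `graph.adjacency_matrix.size(0)` (6597), which is reported further down, and the matrix alone accounts for roughly the quoted file size.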

In [5]:
graph_dir = Path('graphs')
graph_dir.mkdir(exist_ok=True)
graph_name = 'example_graph.pt'
graph_path = graph_dir / graph_name

graph.to_pt(graph_path)  # the later cells reload the graph from this file

In [3]:
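from circuit_tracer.graph import Graph
# prune_graph_topk is not exported by the installed circuit_tracer package (see the ImportError below);
# this cell needs a circuit-tracer checkout that defines it.
from circuit_tracer.graph import prune_graph_topk

graph = Graph.from_pt(graph_path)
node_mask, edge_mask, graph_score = prune_graph_topk(graph, top_k=3)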

ImportError: cannot import name 'prune_graph_topk' from 'circuit_tracer.graph' (/home/tu/.conda/envs/circuit/lib/python3.13/site-packages/circuit_tracer/graph.py)

In [None]:
edge_mask.sum(), node_mask.sum()

(tensor(33), tensor(31))

In [18]:
# Print every node kept by node_mask; for feature nodes, also print (layer, position, feature_idx).
# Node ordering in the adjacency matrix: [feature nodes, error nodes, token nodes, logit nodes].
n_features = len(graph.selected_features)  # feature nodes come first
n_token = len(graph.input_tokens)          # one token node per input position
n_error = graph.cfg.n_layers * n_token     # one error node per (layer, position)

def node_type(idx):
    if idx < n_features:
        return "Feature"
    if idx < n_features + n_error:
        return "Error"
    if idx < n_features + n_error + n_token:
        return "Token"
    return "Logit"

for node in torch.where(node_mask)[0].tolist():
    kind = node_type(node)
    print(f"Node {node}: {kind}")
    if kind == "Feature":
        # selected_features maps feature-node indices to rows of active_features
        print("  Active feature:", graph.active_features[graph.selected_features[node]])

Node 288: Feature
  Active feature: tensor([   0,    4, 7750])
Node 505: Feature
  Active feature: tensor([   0,    6, 5626])
Node 854: Feature
  Active feature: tensor([   1,    6, 4767])
Node 1161: Feature
  Active feature: tensor([   2,    6, 9457])
Node 1427: Feature
  Active feature: tensor([   3,    6, 5892])
Node 1843: Feature
  Active feature: tensor([    4,     6, 13154])
Node 3758: Feature
  Active feature: tensor([   7,    6, 6861])
Node 5310: Feature
  Active feature: tensor([  14,    6, 2268])
Node 5414: Feature
  Active feature: tensor([16,  6, 25])
Node 5717: Feature
  Active feature: tensor([   20,     7, 15589])
Node 6393: Error
Node 6425: Error
Node 6570: Error
Node 6579: Token
Node 6580: Token
Node 6581: Token
Node 6582: Token
Node 6583: Token
Node 6584: Token
Node 6585: Token
Node 6586: Token
Node 6587: Logit
Node 6588: Logit
Node 6589: Logit
Node 6590: Logit
Node 6591: Logit
Node 6592: Logit
Node 6593: Logit
Node 6594: Logit
Node 6595: Logit
Node 6596: Logit


Given this object, we can create the graph files needed to visualize the graph. Give it a slug (name) and choose how to prune it. Pruning removes unimportant nodes and edges from your graph; lower thresholds (i.e., more aggressive pruning) result in smaller graphs. These may be easier to interpret, but explain less of the model's behavior.

In [6]:
slug = "dallas-austin"  # this is the name that you assign to the graph
graph_file_dir = './graph_files'  # where to write the graph files. no need to make this one; create_graph_files does that for you
node_threshold=0.8  # keep only the minimum # of nodes whose cumulative influence is >= 0.8
edge_threshold=0.98  # keep only the minimum # of edges whose cumulative influence is >= 0.98
topk = 3
# create_graph_files(
#     graph_or_path=graph_path,  # the graph to create files for
#     slug=slug,
#     output_path=graph_file_dir,
#     node_threshold=node_threshold,
#     edge_threshold=edge_threshold
# )

create_graph_files_topk(
    graph_or_path=graph_path,  # the graph to create files for
    slug=slug,
    output_path=graph_file_dir,
    top_k = topk
)

Now, you can visualize the graph using the following commands! This will spin up a local server to act as the frontend.

**If you're running this notebook on a remote server, make sure that you set up port forwarding, so that the chosen port is accessible on your local machine too.**

You can select nodes by clicking on them. Ctrl/Cmd+Click on nodes to pin and unpin them to your subgraph. G+Click on nodes in the subgraph to group them together into a supernode; G+Click on the X next to a supernode to dissolve it. Click on the edit button to edit node descriptions, and click on a supernode's description to edit it.

In [9]:
from circuit_tracer.frontend.local_server import serve


port = 8046
server = serve(data_dir='./graph_files/', port=port)

if IN_COLAB:
    from google.colab import output as colab_output  # noqa
    colab_output.serve_kernel_port_as_iframe(port, path='/index.html', height='800px', cache_in_notebook=True)
else:
    from IPython.display import IFrame, display
    print(f"Use the IFrame below, or open your graph here: http://localhost:{port}/index.html")
    display(IFrame(src=f'http://localhost:{port}/index.html', width='100%', height='800px'))


Use the IFrame below, or open your graph here: http://localhost:8046/index.html


Once you're done, you can stop the server with the following command.

In [8]:
server.stop()

Congrats, you're done! Go to `intervention_demo.ipynb` to see how to perform interventions, or check out `gemma_demo.ipynb` and `llama_demo.ipynb` for worked examples. Read on for a bit more about the `Graph` class and pruning.

## Graphs

Earlier, you created a graph object. Its adjacency matrix / edge weights are stored in `graph.adjacency_matrix` in a dense format; rows are target nodes and columns are source nodes. The first `len(graph.selected_features)` entries of the matrix represent features; the `i`th entry corresponds to the `i`th feature in `graph.active_features[graph.selected_features]`, given in `(layer, position, feature_idx)` format. The next `graph.cfg.n_layers * graph.n_pos` entries are error nodes. The next `graph.n_pos` entries are token nodes. The final `len(graph.logit_tokens)` entries are logit nodes.
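
To make this layout concrete, here is a hypothetical helper that decodes a flat index into a node type. The sizes are this run's (6371 features, 26 layers, 8 positions, 10 logits), and the sample indices come from the pruned node listing earlier. The layer-major ordering of error nodes (offset `layer * n_pos + pos`) is an assumption, not something the library documentation above states.

```python
# Hypothetical helper: decode a flat adjacency-matrix index into its node type.
# ASSUMPTION: error nodes are laid out layer-major, i.e. offset = layer * n_pos + pos.

def decode_node(idx, n_features, n_layers, n_pos, n_logits):
    if idx < 0:
        raise IndexError(f"negative node index {idx}")
    if idx < n_features:
        return "feature", idx                  # i-th selected feature
    idx -= n_features
    if idx < n_layers * n_pos:
        return "error", divmod(idx, n_pos)     # (layer, position), under the assumption above
    idx -= n_layers * n_pos
    if idx < n_pos:
        return "token", idx                    # input position
    idx -= n_pos
    if idx < n_logits:
        return "logit", idx                    # i-th entry of graph.logit_tokens
    raise IndexError("index is past the last logit node")

sizes = dict(n_features=6371, n_layers=26, n_pos=8, n_logits=10)
for idx in (288, 6393, 6579, 6596):
    print(idx, decode_node(idx, **sizes))
```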

The value of the cell `graph.adjacency_matrix[target, source]` is the direct effect of the source node on the target node. That is, it tells you how much the target node's value would change if the source node were set to 0, while holding the attention patterns, layernorm denominators, and other feature activations constant. Thus, if the target node is a feature, this tells you how much the target feature's activation would change; if the target node is a logit, this tells you how much the (de-meaned) value of the logit would change.
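
A toy illustration of why an edge weight is a direct effect: once attention patterns, layernorm denominators, and other activations are frozen, a target node is a linear function of its sources, so zeroing one source removes exactly that source's term. The names and numbers below are hypothetical, and the edge weight is taken to be the frozen coefficient times the source's activation.

```python
# Toy illustration (hypothetical numbers): with nonlinearities frozen, the target
# is linear in its sources, so each edge weight is exactly that source's contribution.

sources = {"feature_a": 2.0, "feature_b": -1.0, "token_0": 1.0}  # source activations
coeffs = {"feature_a": 0.5, "feature_b": 1.5, "token_0": 0.25}   # frozen linear weights

def target_value(src):
    return sum(coeffs[name] * value for name, value in src.items())

# Edge weight (direct effect) = frozen coefficient * source activation
edges = {name: coeffs[name] * value for name, value in sources.items()}

def effect_of_zeroing(name):
    """Original target value minus its value with one source set to 0."""
    ablated = {**sources, name: 0.0}
    return target_value(sources) - target_value(ablated)

for name in sources:
    print(name, edges[name], effect_of_zeroing(name))
```

Because the frozen model is linear, the edge weights into a target also sum to the target's value, which is what lets attribution decompose a node into per-source contributions.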

Note that `gemma-2-2b` is a model (family) that uses logit softcapping. This means that a softcap function, `softcap(x) = t * tanh(x/t)`, is used to constrain the logits to fall within (-t, t); `gemma-2-2b` uses `t=30`. For such models, we predict the change in logits *pre-softcap*, as the nonlinearity introduced by softcapping would cause our attribution to yield incorrect or only approximate direct-effect values.
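
A quick numerical check of why attributions are computed pre-softcap: contributions add linearly before the softcap, but `softcap` (with `t=30`, as for `gemma-2-2b`) is not additive, so post-softcap logits do not decompose into a sum of per-source effects. The two contribution values are hypothetical.

```python
import math

T = 30.0  # gemma-2-2b's logit softcap value (t=30)

def softcap(x, t=T):
    """softcap(x) = t * tanh(x / t); the output always lies in (-t, t)."""
    return t * math.tanh(x / t)

# Two hypothetical pre-softcap contributions to the same logit
a, b = 20.0, 25.0
print(a + b)                    # -> 45.0 (pre-softcap: contributions simply add)
print(softcap(a + b))           # actual post-softcap logit, below 30
print(softcap(a) + softcap(b))  # summing softcapped parts overshoots 30
print(softcap(0.3))             # near 0, softcap is almost the identity
```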

### Pruning
Given a graph, you might want to prune it, as it will otherwise contain many low-impact nodes and edges that clutter the circuit diagram while adding little information. We enable you to prune nodes by absolute influence, i.e., the total impact that the nodes have on the logits, direct and indirect. The default `node_threshold` is 0.8: this means we will keep the minimum number of nodes required to capture 80% of all logit effects. Similarly, the `edge_threshold`, by default 0.98, means that we will keep the minimum number of edges required to capture 98% of all logit effects.
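
Here is a minimal sketch of influence-based node pruning on a hypothetical 4-node graph. It assumes influence is computed by taking absolute edge weights, rescaling each row (target) to sum to 1, and summing the logit-weighted flow over all path lengths (w A + w A^2 + ...), which terminates because the graph is acyclic. Nodes are then kept greedily until their cumulative influence reaches the threshold. This follows the description above, but circuit_tracer's own normalization and bookkeeping may differ in detail; all function names here are hypothetical.

```python
# Hypothetical sketch of influence-based pruning (not circuit_tracer's implementation).
# Adjacency convention as in graph.adjacency_matrix: A[target][source].

def normalize_rows(A):
    """Absolute edge weights, each row rescaled to sum to 1 (all-zero rows stay zero)."""
    out = []
    for row in A:
        total = sum(abs(x) for x in row)
        out.append([abs(x) / total if total > 0 else 0.0 for x in row])
    return out

def vec_mat(v, M):
    """Row vector times matrix: result[s] = sum_t v[t] * M[t][s]."""
    return [sum(v[t] * M[t][s] for t in range(len(v))) for s in range(len(M[0]))]

def node_influence(A, logit_weights):
    """Total (direct + indirect) influence of every node on the weighted logits."""
    M = normalize_rows(A)
    current = vec_mat(logit_weights, M)  # paths of length 1
    influence = list(current)
    while any(current):                  # a DAG has no paths longer than n - 1 edges
        current = vec_mat(current, M)    # extend every path by one edge
        influence = [x + y for x, y in zip(influence, current)]
    return influence

def keep_by_cumulative_threshold(scores, threshold):
    """Fewest indices whose scores sum to at least threshold * sum(scores)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    target, running, kept = threshold * sum(scores), 0.0, []
    for i in order:
        if running >= target:
            break
        kept.append(i)
        running += scores[i]
    return sorted(kept)

# Nodes: 0 = token, 1 = feature f1, 2 = feature f2, 3 = logit
A = [
    [0.0, 0.0, 0.0, 0.0],   # token: no incoming edges
    [1.0, 0.0, 0.0, 0.0],   # f1 <- token
    [0.5, 0.0, 0.0, 0.0],   # f2 <- token
    [0.0, 3.0, -1.0, 0.0],  # logit <- f1 (+3.0), f2 (-1.0)
]
influence = node_influence(A, logit_weights=[0.0, 0.0, 0.0, 1.0])
feature_scores = influence[1:3]
print(influence)                                          # -> [1.0, 0.75, 0.25, 0.0]
print(keep_by_cumulative_threshold(feature_scores, 0.7))  # -> [0]
print(keep_by_cumulative_threshold(feature_scores, 0.8))  # -> [0, 1]
```

Lowering the threshold from 0.8 to 0.7 drops f2 here, which is the "more aggressive pruning gives smaller graphs" trade-off in miniature.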

In [1]:
import torch
from circuit_tracer.graph import Graph, prune_graph, compute_graph_scores  # topk_prune_graph is defined locally below

# Load the graph (change the path as needed)
graph = Graph.from_pt("graphs/example_graph.pt")
n_tokens = len(graph.input_tokens)
n_logits = len(graph.logit_tokens)
# Prune the graph (adjust thresholds as desired)
prune_result = prune_graph(graph, node_threshold=0.8, edge_threshold=0.98)

# Access the pruned properties
node_mask = prune_result.node_mask
edge_mask = prune_result.edge_mask

# Optionally, extract the pruned adjacency matrix:
# pruned_adjacency = graph.adjacency_matrix[node_mask][:, node_mask]
graph.adjacency_matrix.size(0)

6597

In [2]:
replacement_score, completeness_score = compute_graph_scores(graph)
replacement_score, completeness_score

(0.7157570123672485, 0.9245588183403015)

In [5]:
len(graph.selected_features), len(graph.active_features)

(6371, 6371)

In [6]:
import torch
import networkx as nx
import matplotlib.pyplot as plt

def topk_prune_graph(graph, node_mask, edge_mask, top_k=3):
    n_tokens = len(graph.input_tokens)
    n_logits = len(graph.logit_tokens)
    total_nodes = graph.adjacency_matrix.size(0)
    
    # Identify highest logit node (nodes at the end of the graph)
    highest_logit_rel = torch.argmax(graph.logit_probabilities).item()
    highest_logit_node = total_nodes - n_logits + highest_logit_rel

    visited = set()
    edge_list = []  # will hold tuples of (src, tgt, weight)

    def dfs(node_idx):
        if node_idx in visited:
            return
        if not node_mask[node_idx]:
            return
        visited.add(node_idx)
        
        # Get the row corresponding to incoming effects for this target node.
        row = graph.adjacency_matrix[node_idx]
        valid_edges = edge_mask[node_idx] & node_mask
        filtered_row = row.clone()
        filtered_row[~valid_edges] = 0.0
        
        if torch.sum(filtered_row.abs()) == 0:
            return
        
        nonzero_idx = (filtered_row.abs() > 0).nonzero(as_tuple=True)[0]
        if len(nonzero_idx) == 0:
            return
        
        cur_top_k = min(top_k, len(nonzero_idx))
        # Use absolute values to choose the strongest connections (positive or negative).
        top_vals, top_indices = torch.topk(filtered_row.abs(), cur_top_k)
        for src in top_indices.tolist():
            # Record the edge (note that row index = target, column index = source)
            weight = graph.adjacency_matrix[node_idx, src].item()
            edge_list.append((src, node_idx, weight))
            dfs(src)

    dfs(highest_logit_node)

    # Build the subgraph: sort the visited nodes.
    sub_nodes = sorted(visited)
    # Map original node indices to new indices in the subgraph.
    index_map = {orig_idx: new_idx for new_idx, orig_idx in enumerate(sub_nodes)}
    
    # Create a new (sparse) adjacency matrix for the subgraph.
    new_adj = torch.zeros((len(sub_nodes), len(sub_nodes)))
    for src, tgt, weight in edge_list:
        if src in index_map and tgt in index_map:
            new_src = index_map[src]
            new_tgt = index_map[tgt]
            # Remember: rows correspond to targets and columns to sources.
            new_adj[new_tgt, new_src] = weight

    # Determine original node types.
    # Graph's node ordering (as per the docstring):
    #   [Feature nodes, Error nodes, Token nodes, Logit nodes]
    n_features = len(graph.selected_features)
    n_token = len(graph.input_tokens)
    n_error = graph.cfg.n_layers * n_token
    
    # Logit nodes are the final n_logits.

    node_attrs = {}
    for orig_idx in sub_nodes:
        if orig_idx < n_features:
            node_type = "Feature"
            data = tuple(graph.active_features[graph.selected_features[orig_idx]].tolist())
        elif orig_idx < n_features + n_error:
            node_type = "Error"
            data = None
        elif orig_idx < n_features + n_error + n_token:
            node_type = "Token"
            data = None
        else:
            node_type = "Logit"
            data = None
        node_attrs[orig_idx] = {"type": node_type, "data": data}
    
    # Return the subgraph information:
    #   sub_nodes : list of original node indices in the subgraph
    #   new_adj   : new (sparse) adjacency matrix for these nodes
    #   node_attrs: dictionary of original node attributes (for visualization)
    #   edge_list : list of DFS-selected edges (from original graph)
    return sub_nodes, new_adj, node_attrs, edge_list

# --- Example usage and visualization ---

# Assume you've already loaded your graph object and computed:
#  node_mask, edge_mask = prune_result.node_mask, prune_result.edge_mask
sub_nodes, new_adj, node_attrs, edge_list = topk_prune_graph(graph, node_mask, edge_mask, top_k=3)
print("Subgraph nodes (original indices):", sub_nodes)
print("Subgraph adjacency matrix shape:", new_adj.shape)

Subgraph nodes (original indices): [288, 505, 854, 1161, 1427, 1843, 3758, 5310, 5414, 5717, 6393, 6425, 6570, 6579, 6581, 6582, 6583, 6585, 6587]
Subgraph adjacency matrix shape: torch.Size([19, 19])


In [5]:
n_features = len(graph.selected_features)  # feature nodes occupy the first n_features indices

# nodes_found: list of node indices to inspect (its defining cell is not included here)
feature_nodes = [node for node in nodes_found if node < n_features]
print("Feature nodes:", feature_nodes)

# Trace back the feature data: each feature is stored as (layer, position, feature_idx)
for i, node in enumerate(feature_nodes):
    layer, pos, feature_idx = graph.active_features[graph.selected_features[node]].tolist()
    print(f"{i}: Node {node}: layer {layer}, position {pos}, feature index {feature_idx}")

Feature nodes: [5717, 5310, 3758, 1843, 1427, 1161, 505, 3, 1, 854, 288, 508, 479, 844, 516, 500, 496, 5414, 5582, 5497, 5328, 2639, 1529, 996, 92, 166, 986, 125, 1546, 5422, 5775, 5580, 5641, 5974, 5862, 5639]
0: Node 5717: layer 20, position 7, feature index 15589
1: Node 5310: layer 14, position 6, feature index 2268
2: Node 3758: layer 7, position 6, feature index 6861
3: Node 1843: layer 4, position 6, feature index 13154
4: Node 1427: layer 3, position 6, feature index 5892
5: Node 1161: layer 2, position 6, feature index 9457
6: Node 505: layer 0, position 6, feature index 5626
7: Node 3: layer 0, position 1, feature index 354
8: Node 1: layer 0, position 1, feature index 96
9: Node 854: layer 1, position 6, feature index 4767
10: Node 288: layer 0, position 4, feature index 7750
11: Node 508: layer 0, position 6, feature index 6116
12: Node 479: layer 0, position 6, feature index 1847
13: Node 844: layer 1, position 6, feature index 2132
14: Node 516: layer 0, position 6, featu

In [6]:
graph.adjacency_matrix[5328][feature_nodes]

tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.0728,
         0.0136,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  3.9375,  2.1875,  1.1016,
         0.8008,  0.6289,  0.9180, -0.0801,  2.7656,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000])

# Visualization