# Attribution Graph for Dallas Capital Query

This notebook creates an attribution graph for the sentence:
**"Fact: The capital of the state containing Dallas is"**

We'll use the Gemma-2 (2B) model with GemmaScope transcoders to analyze the circuit.

## 1. Setup
- Load Model and Transcoders
- Configure Attribution Parameters
- Run Attribution

In [None]:
from pathlib import Path
import torch as t
from bs4 import BeautifulSoup
import requests
from matplotlib import pyplot as plt

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.utils import create_graph_files

In [None]:
from huggingface_hub import login
from dotenv import load_dotenv
import os

# load_dotenv() is expected to copy variables from a local .env file into the environment;
# the Hugging Face token is then read from HF_TOKEN and passed to login().
load_dotenv()
login(os.environ['HF_TOKEN'])  # raises KeyError if HF_TOKEN is not set

In [None]:
# Identifiers handed to ReplacementModel.from_pretrained below.
model_name = 'google/gemma-2-2b'
transcoder_name = "gemma"  # GemmaScope transcoders

print(f"Loading {model_name} with {transcoder_name} transcoders...")
# Weights are requested in bfloat16.
# NOTE(review): lazy_encoder=True presumably defers loading encoder weights — confirm in circuit_tracer.
model = ReplacementModel.from_pretrained(
    model_name, 
    transcoder_name, 
    dtype=t.bfloat16,
    lazy_encoder=True
)
print("Model loaded successfully!")

In [None]:
# Attribution parameters (every value here is forwarded to attribute() in the next cell)
prompt = "Fact: The capital of the state containing Dallas is"
max_n_logits = 10
desired_logit_prob = 0.95
max_feature_nodes = 8192  # None for no limit, but will be slower
batch_size = 256
offload = 'cpu'  # Use 'disk' if running out of memory, None to keep everything on GPU
verbose = True

# Echo the configuration so the run is self-describing (verbose is intentionally not echoed).
for _label, _value in (
    ("Prompt", prompt),
    ("Max logits", max_n_logits),
    ("Desired logit probability", desired_logit_prob),
    ("Max feature nodes", max_feature_nodes),
    ("Batch size", batch_size),
    ("Offload strategy", offload),
):
    print(f"{_label}: {_value}")

In [None]:
print("\nRunning attribution...\n")
# Build the attribution graph for `prompt`; every argument comes from the parameter cell above.
graph = attribute(
    prompt=prompt,
    model=model,
    max_n_logits=max_n_logits,
    desired_logit_prob=desired_logit_prob,
    batch_size=batch_size,
    max_feature_nodes=max_feature_nodes,
    offload=offload,
    verbose=verbose
)
print("\nAttribution complete!")

## 2. Display Graph Statistics

Let's examine the structure of the attribution graph.

In [None]:
# Quick size check of the graph's three main containers.
print(f'number of active features: {len(graph.active_features)}')
print(f'length of adjacency matrix: {len(graph.adjacency_matrix)}')
print(f'number of "activation values": {len(graph.activation_values)}')

In [None]:
# Report the attribution graph's inputs, node layout, edge density and top logits.
rule = "=" * 60
print(rule, "GRAPH STATISTICS", rule, sep="\n")

# Input information
print(f"\nInput String: {graph.input_string}")
print(f"Input Tokens: {graph.input_tokens.tolist()}")
print(f"Number of positions: {graph.n_pos}")

# Feature information
n_selected_features = len(graph.selected_features)
print(f"\nTotal active features: {len(graph.active_features)}")
print(f"Selected features for graph: {n_selected_features}")

# Node structure: selected features, one error node per (layer, position),
# one embedding node per position, and one node per kept logit.
n_layers = graph.cfg.n_layers
n_pos = graph.n_pos
n_error_nodes = n_layers * n_pos
n_embed_nodes = n_pos
n_logit_nodes = len(graph.logit_tokens)
total_nodes = n_selected_features + n_error_nodes + n_embed_nodes + n_logit_nodes

print("\nGraph Structure:")
print(f"  Feature nodes: {n_selected_features}")
print(f"  Error nodes: {n_error_nodes} ({n_layers} layers × {n_pos} positions)")
print(f"  Embedding nodes: {n_embed_nodes}")
print(f"  Logit nodes: {n_logit_nodes}")
print(f"  Total nodes: {total_nodes}")

# Edge information
adjacency_matrix = graph.adjacency_matrix
total_edges = (adjacency_matrix != 0).sum().item()
n_rows, n_cols = adjacency_matrix.shape[0], adjacency_matrix.shape[1]
density_pct = total_edges / (n_rows * n_cols) * 100
print(f"\nTotal non-zero edges: {total_edges:,}")
print(f"Adjacency matrix shape: {adjacency_matrix.shape}")
print(f"Adjacency matrix density: {density_pct:.2f}%")

# Top logits
print(f"\nTop {n_logit_nodes} predicted logits:")
for rank, (token_id, prob) in enumerate(zip(graph.logit_tokens, graph.logit_probabilities), start=1):
    token_int = token_id.item()
    token_str = model.tokenizer.decode([token_int])
    print(f"  {rank}. '{token_str}' (token {token_int}) - probability: {prob.item():.4f}")

print("\n" + rule)

## 3. Analysis

### Get Topological Order

In [None]:
def get_feature_details(matrix_idx: int) -> tuple[int, int, int]:
    """Resolve a feature node's adjacency-matrix index to its (layer, token_pos, feature) triple.

    Reads the module-level `graph`: `graph.selected_features[matrix_idx]` gives a row index
    into `graph.active_features`, and that row holds the three values returned here.

    Raises:
        AssertionError: if `matrix_idx` is not below the number of selected features,
            i.e. the node is not a feature node.
    """
    n_selected = len(graph.selected_features)
    assert matrix_idx < n_selected, 'This node is not an active feature'

    active_row = graph.selected_features[matrix_idx]
    layer, token_pos, attribution_idx = graph.active_features[active_row]

    # Convert the tensor scalars to plain Python numbers.
    return tuple(value.item() for value in (layer, token_pos, attribution_idx))

def matrix_idx_to_explanation(matrix_idx: int):
    """Scrape the Neuronpedia explanation text for a feature node.

    The feature's layer and feature index (from `get_feature_details`) select a
    gemma-2-2b / gemmascope-transcoder-16k page on neuronpedia.org. The explanation is
    taken from the up-to-200 characters of page source that precede the first
    'explanationModelName' marker, starting after the 'description' key found there.

    Args:
        matrix_idx: adjacency-matrix index of a feature node.

    Returns:
        str: the extracted description text.

    Raises:
        AssertionError: if the node is not a feature node, the page request fails, or
            the expected markers are missing. Callers in this notebook catch
            AssertionError to skip nodes, so every "no explanation" case uses it.
        requests.RequestException: on connection problems or timeout.
    """
    layer, __, feature_idx = get_feature_details(matrix_idx)

    url = f'https://www.neuronpedia.org/gemma-2-2b/{layer}-gemmascope-transcoder-16k/{feature_idx}'
    # A timeout keeps a stalled connection from hanging the notebook indefinitely.
    response = requests.get(url, timeout=30)
    if not response.ok:
        raise AssertionError(f'Neuronpedia returned HTTP {response.status_code} for {url}')

    soup = BeautifulSoup(response.text, 'html.parser')
    html = soup.find('html')
    body = html.find('body') if html is not None else None
    if body is None:
        raise AssertionError(f'No <body> element found in {url}')

    body_text = str(body)
    marker_idx = body_text.find('explanationModelName')
    if marker_idx == -1:
        # Without this guard the slice below would silently read from the end of the page.
        raise AssertionError(f"'explanationModelName' not found in {url}")

    lookback = 200
    window = body_text[max(0, marker_idx - lookback):marker_idx]
    if 'description' not in window:
        raise AssertionError(f"'description' not found before 'explanationModelName' in {url}")

    description_idx = window.find('description')
    # NOTE(review): 16 = len('description') + 5 separator characters and 5 trailing characters
    # are trimmed; presumably these skip escaped JSON punctuation — confirm against a live page.
    value_offset = 16
    trailing_trim = 5
    return window[description_idx + value_offset: -trailing_trim]

In [None]:
def compute_topological_order(adjacency_matrix):
    """
    Compute topological order using Kahn's algorithm.

    Nodes with no incoming edges are queued in ascending index order; when a node is
    dequeued, its targets are visited in ascending index order as well, so the result
    is deterministic.

    Args:
        adjacency_matrix: torch.Tensor of shape (n_nodes, n_nodes)
                         where adjacency_matrix[target, source] represents edge from source -> target

    Returns:
        list: Topological order of node indices (Python ints). If the graph has a
        cycle, a warning is printed and the unordered nodes are appended in
        ascending index order.
    """
    from collections import deque

    n_nodes = adjacency_matrix.shape[0]

    # One boolean copy on the CPU; avoids a device transfer per dequeued node.
    has_edge = (adjacency_matrix != 0).cpu()

    # In-degree of a node = number of non-zero entries in its row (edges pointing TO it).
    in_degree = has_edge.sum(dim=1).tolist()

    queue = deque(i for i in range(n_nodes) if in_degree[i] == 0)
    topo_order = []

    while queue:
        node = queue.popleft()
        topo_order.append(node)

        # Column `node` lists the targets of its outgoing edges; nonzero() yields them ascending.
        for target in t.nonzero(has_edge[:, node]).flatten().tolist():
            in_degree[target] -= 1
            if in_degree[target] == 0:
                queue.append(target)

    if len(topo_order) != n_nodes:
        print(f"Warning: Graph contains a cycle! Only {len(topo_order)}/{n_nodes} nodes ordered.")
        ordered = set(topo_order)
        topo_order.extend(i for i in range(n_nodes) if i not in ordered)

    return topo_order

def get_adjacency_without_error_nodes(adjacency_matrix):
    """Return a copy of `adjacency_matrix` with the error-node rows and columns removed.

    Error nodes occupy the index range that starts right after the selected features
    and spans (number of input tokens × model.cfg.n_layers) entries; both counts are
    read from the module-level `graph` and `model`.
    """
    first_error = len(graph.selected_features)
    error_count = len(graph.input_tokens) * model.cfg.n_layers

    # Boolean keep-mask over node indices, False on the error-node range.
    keep = t.ones_like(adjacency_matrix[0], dtype=t.bool)
    keep[first_error : first_error + error_count] = False

    rows_kept = adjacency_matrix[keep]
    return rows_kept[:, keep]

# Order the full attribution graph (error nodes are NOT stripped here).
adjacency_matrix = graph.adjacency_matrix

print("Computing topological order of the attribution graph...")
topo_order = compute_topological_order(adjacency_matrix)
print(f"Topological order computed: {len(topo_order)} nodes")

In [None]:
def get_node_details(node: int):
    '''
    Map an adjacency-matrix node index to its (layer, token_pos).

    Node indices are laid out as [features | error nodes | embeddings | logits]:
      - feature nodes: layer and position come from get_feature_details
      - error nodes: stored layer-major, i.e. index = layer * n_pos + token_pos
      - embedding nodes: layer -1, one per token position
      - logit nodes: layer model.cfg.n_layers, at the last token position

    The number of positions is taken from graph.input_tokens, so this works for
    prompts of any length.
    '''
    n_pos = len(graph.input_tokens)
    n_layers = model.cfg.n_layers
    n_features = len(graph.selected_features)
    n_error_nodes = n_pos * n_layers
    n_embed_nodes = n_pos

    if node < n_features:
        layer, token_pos, __ = get_feature_details(node)
    elif node < (n_features + n_error_nodes):
        layer, token_pos = divmod(node - n_features, n_pos)
    elif node < (n_features + n_error_nodes + n_embed_nodes):
        layer = -1
        token_pos = node - (n_features + n_error_nodes)
    else:
        layer = n_layers
        token_pos = n_pos - 1

    return layer, token_pos

In [None]:
def test_topological_sort(topological_sort, n_error_nodes=286):
    '''
    the only 'illegal moves' are  
    (a) moving to same token, previous layer  
    (b) moving to previous token, previous layer  
    '''

    prev_layer = -1
    prev_token_pos = 0

    for idx, node in enumerate(topological_sort[286:]):
        layer, token_pos = get_node_details(node)
        if (layer < prev_layer) and (token_pos <= prev_token_pos):
            print(f'error in topological sort at idx: {idx}')
        
    print(f'topological sort is okay!')

# Pass the real error-node count (layers × positions) rather than relying on the 286 default.
test_topological_sort(topo_order, n_error_nodes=model.cfg.n_layers * len(graph.input_tokens))

### Find top-k most influential paths

In [None]:
from dataclasses import dataclass
from typing import List
import pandas as pd
from tqdm import tqdm
from circuit_tracer.graph import compute_node_influence

@dataclass
class Path:
    """One walk through the attribution graph, stored as node indices plus edge weights.

    Node indices follow the adjacency-matrix layout
    [selected features | error nodes | embedding nodes | logit nodes].

    NOTE: this class reuses the name `Path`, which the first import cell also binds to
    `pathlib.Path`; after this cell runs, `Path` refers to this dataclass.
    """
    nodes: List[int]      # node indices along the path, in order
    edges: List[float]    # weight of each edge between consecutive nodes
    score: float          # score assigned by the path search
    final_score: float    # ranking score, filled in after the search

    def __len__(self):
        return len(self.nodes)

    @staticmethod
    def _section_ends(graph):
        """Return the exclusive end index of the feature, error and embedding sections."""
        feature_end = len(graph.selected_features)
        error_end = feature_end + graph.cfg.n_layers * graph.n_pos
        embed_end = error_end + graph.n_pos
        return feature_end, error_end, embed_end

    def get_node_types(self, graph) -> List[str]:
        """Return node types: 'feature', 'error', 'embed', 'logit'."""
        feature_end, error_end, embed_end = self._section_ends(graph)

        def kind_of(node):
            if node < feature_end:
                return 'feature'
            if node < error_end:
                return 'error'
            if node < embed_end:
                return 'embed'
            return 'logit'

        return [kind_of(node) for node in self.nodes]

    def get_node_descriptions(self, graph, tokenizer) -> List[str]:
        """Return human-readable descriptions for each node."""
        feature_end, error_end, embed_end = self._section_ends(graph)

        def token_at(pos):
            return tokenizer.decode([graph.input_tokens[pos]])

        def describe(node):
            if node < feature_end:
                layer, pos, feat_idx = graph.active_features[graph.selected_features[node]].tolist()
                token = token_at(pos)
                return f"Feature L{layer}:F{feat_idx} @ pos {pos} ('{token}')"
            if node < error_end:
                # Error nodes are stored layer-major: index = layer * n_pos + pos.
                error_idx = node - feature_end
                layer = error_idx // graph.n_pos
                pos = error_idx % graph.n_pos
                token = token_at(pos)
                return f"Error L{layer} @ pos {pos} ('{token}')"
            if node < embed_end:
                pos = node - error_end
                token = token_at(pos)
                return f"Embedding @ pos {pos} ('{token}')"
            logit_idx = node - embed_end
            token = tokenizer.decode([graph.logit_tokens[logit_idx]])
            prob = graph.logit_probabilities[logit_idx].item()
            return f"Logit '{token}' (p={prob:.4f})"

        return [describe(node) for node in self.nodes]


def find_k_best_paths_dp(adj_matrix, source_nodes, sink_node, topo_order, k=10, verbose=True):
    """Find up to k best paths from each source node to `sink_node` by dynamic programming.

    A path's score is the product of the absolute values of its edge weights. Nodes are
    processed in reverse topological order; each node keeps its k best continuations
    towards the sink. Every stored continuation records which of the successor's
    continuations it extends, so paths are rebuilt by following exact references
    rather than by matching floating-point scores.

    Args:
        adj_matrix: square tensor where adj_matrix[target, source] is the weight of source -> target.
        source_nodes: node indices to report paths from.
        sink_node: node index every path must end at.
        topo_order: topological order of all node indices (sources before targets).
        k: maximum number of paths kept per node.
        verbose: show a tqdm progress bar while processing nodes.

    Returns:
        list[Path]: for each source, in the order given, its paths in descending score
        order. `score` holds the product of |edge weights|; `final_score` is set to 0.0.
    """
    # best_path_refs[node] -> list of (successor, edge_weight, score, successor_ref_index).
    # The sink's single entry has no successor and the neutral score 1.0.
    best_path_refs = {sink_node: [(None, None, 1.0, None)]}

    iterator = reversed(topo_order)
    if verbose:
        iterator = tqdm(iterator, desc="DP path finding (optimized)", total=len(topo_order))

    for node in iterator:
        if node == sink_node:
            continue

        # Column `node` holds the weights of its outgoing edges.
        outgoing_weights = adj_matrix[:, node]
        successors = t.where(outgoing_weights != 0)[0]
        succ_indices = successors.tolist()
        succ_weights = outgoing_weights[successors].tolist()

        candidate_refs = []
        for succ_idx, edge_weight in zip(succ_indices, succ_weights):
            # Successors that cannot reach the sink have an empty (or missing) list.
            for ref_idx, succ_ref in enumerate(best_path_refs.get(succ_idx, ())):
                new_score = abs(edge_weight) * succ_ref[2]
                candidate_refs.append((succ_idx, edge_weight, new_score, ref_idx))

        candidate_refs.sort(key=lambda ref: ref[2], reverse=True)
        best_path_refs[node] = candidate_refs[:k]

    def reconstruct_path(start_node, path_ref_index):
        """Rebuild the full path by following the stored successor references."""
        nodes, edges = [start_node], []
        current_node, current_ref_idx = start_node, path_ref_index

        while True:
            next_node, edge_weight, _, next_ref_idx = best_path_refs[current_node][current_ref_idx]
            if next_node is None:
                break

            nodes.append(next_node)
            edges.append(edge_weight)
            current_node, current_ref_idx = next_node, next_ref_idx

        return Path(nodes=nodes, edges=edges, score=best_path_refs[start_node][path_ref_index][2], final_score=0.0)

    all_source_paths = []
    for source in source_nodes:
        for path_idx in range(len(best_path_refs.get(source, ()))):
            all_source_paths.append(reconstruct_path(source, path_idx))

    return all_source_paths


print("✅ Path finding functions loaded (optimized version)")

In [None]:
# Define node indices
n_features = len(graph.selected_features)
n_errors = graph.cfg.n_layers * graph.n_pos
n_embeds = graph.n_pos
n_logits = len(graph.logit_tokens)

embed_start = n_features + n_errors
embed_end = embed_start + n_embeds
embed_nodes = list(range(embed_start, embed_end))
austin_logit = embed_end  # Index 8489

print(f"Finding top-10 complete paths: Embeddings [{embed_start}:{embed_end}] → Austin [{austin_logit}]")
print()

# Find complete paths
all_complete_paths = find_k_best_paths_dp(
    adj_matrix=adjacency_matrix,
    source_nodes=embed_nodes,
    sink_node=austin_logit,
    topo_order=topo_order,
    k=10,
    verbose=True
)

In [None]:
# Print the edge count of every path that has at most 24 edges.
for path in all_complete_paths:
    mylen = len(path.edges)
    if mylen <= 24:
        print(mylen)

In [None]:
# Look up a Neuronpedia explanation for every node on the first returned path.
mypath = all_complete_paths[0]
# print(f'len graph active featrues: {len(graph.active_features)}')

for node in mypath.nodes:
    # print(f'node: {node}')
    try:
        node_explanation = matrix_idx_to_explanation(node)
        print(node_explanation)
    except AssertionError as e:
        # Raised for nodes that are not feature nodes, or when no explanation is found;
        # print the message and continue with the next node.
        print(e)

In [None]:
# Re-score each path by the mean of its signed edge weights, then rank and show the ten best.
for candidate in all_complete_paths:
    candidate.final_score = sum(candidate.edges) / len(candidate.edges)

all_complete_paths_sorted = sorted(
    all_complete_paths,
    key=lambda p: p.final_score,
    reverse=True,
)
top_k_complete_paths = all_complete_paths_sorted[:10]

print(f"\n✅ Found {len(top_k_complete_paths)} complete paths!", end="\n\n")

# Display paths
for rank, path in enumerate(top_k_complete_paths, start=1):
    node_descs = path.get_node_descriptions(graph, model.tokenizer)
    node_types = path.get_node_types(graph)

    summary_lines = (
        f"Path #{rank} (Score: {path.final_score:.8f}, Length: {len(path)})",
        f"  {' → '.join(node_types)}",
        f"  Start: {node_descs[0]}",
        f"  End: {node_descs[-1]}",
        "",
    )
    print(*summary_lines, sep="\n")

### Export Paths to DataFrame

In [None]:
def paths_to_dataframe(paths, graph, tokenizer):
    """Convert list of Path objects to pandas DataFrame.

    One row per path, ranked 1..N in the order given. Edge-weight statistics use
    absolute values and are 0.0 for a path without edges.
    """
    def summarise(rank, path):
        kinds = path.get_node_types(graph)
        labels = path.get_node_descriptions(graph, tokenizer)
        has_edges = bool(path.edges)
        magnitudes = [abs(w) for w in path.edges]
        return {
            'rank': rank,
            'influence_score': path.score,
            'length': len(path),
            'start_type': kinds[0],
            'end_type': kinds[-1],
            'start_description': labels[0],
            'end_description': labels[-1],
            'path_summary': ' → '.join(kinds),
            'full_path': ' → '.join(labels),
            'min_edge_weight': min(magnitudes) if has_edges else 0.0,
            'max_edge_weight': max(magnitudes) if has_edges else 0.0,
            'avg_edge_weight': sum(magnitudes) / len(path.edges) if has_edges else 0.0,
        }

    return pd.DataFrame([summarise(rank, path) for rank, path in enumerate(paths, 1)])


# Create DataFrame
# Export the ranked paths. `top_k_complete_paths` is the list built by the ranking cell;
# the name `complete_paths` used here before was never defined and raised NameError.
df_complete = paths_to_dataframe(top_k_complete_paths, graph, model.tokenizer)

print("Complete Paths DataFrame:")
print(df_complete[['rank', 'influence_score', 'length', 'path_summary']])
print()

# Save to CSV
df_complete.to_csv('complete_paths_austin.csv', index=False)
print("✅ Saved to complete_paths_austin.csv")

## 4. Post Processing

### Save the Graph

In [None]:
from pathlib import Path as LibPath

In [None]:
# Create output directory and save graph
# Create output directory and save graph.
# LibPath is pathlib.Path; the alias is needed because `Path` now names the dataclass above.
graph_dir = LibPath('graphs')
graph_dir.mkdir(exist_ok=True)  # no error if the directory already exists

graph_name = 'dallas_capital_attribution.pt'
graph_path = graph_dir / graph_name

print(f"Saving graph to {graph_path}...")
graph.to_pt(graph_path)
# st_size is in bytes; two divisions by 1024 convert it to MB for display.
print(f"Graph saved successfully! (Size: {graph_path.stat().st_size / 1024 / 1024:.2f} MB)")

### Create Visualizations

Generate graph files for interactive visualization. The pruning thresholds control how much of the graph to keep:
- `node_threshold`: Keep the smallest set of nodes whose cumulative influence is at least this value
- `edge_threshold`: Keep the smallest set of edges whose cumulative influence is at least this value

**Graph Features:**
- Click to select nodes
- Ctrl/Cmd+Click to pin/unpin nodes to your subgraph
- G+Click on nodes to group them into supernodes
- Edit node descriptions by clicking the edit button

In [None]:
# Pruning and output settings for the interactive visualization files.
slug = "dallas-capital"  # Name for this graph
graph_file_dir = './graph_files'
node_threshold = 0.8  # Keep nodes explaining 80% of influence
edge_threshold = 0.98  # Keep edges explaining 98% of influence

print(f"Creating visualization files with slug '{slug}'...")
print(f"Node threshold: {node_threshold}, Edge threshold: {edge_threshold}")

# Reads the graph back from the .pt file saved in the previous cell.
create_graph_files(
    graph_or_path=graph_path,
    slug=slug,
    output_path=graph_file_dir,
    node_threshold=node_threshold,
    edge_threshold=edge_threshold
)

print(f"Visualization files created in {graph_file_dir}/")

In [None]:
from circuit_tracer.frontend.local_server import serve
from IPython.display import IFrame

# Start the local frontend server and embed it in the notebook.
port = 8047
print(f"Starting visualization server on port {port}...")
# NOTE(review): data_dir repeats the literal './graph_files/' instead of reusing
# graph_file_dir from the previous cell; keep the two in sync if either changes.
server = serve(data_dir='./graph_files/', port=port)

print(f"\nVisualization server is running!")
print(f"Open your graph here: http://localhost:{port}/index.html")
print(f"\nTo stop the server later, run: server.stop()")

# Display in iframe
# `display` is not imported here; it is the IPython built-in available in notebook kernels.
display(IFrame(src=f'http://localhost:{port}/index.html', width='100%', height='800px'))

In [None]:
# Shut down the visualization server started in the previous cell.
server.stop()
# print("Server stopped.")

## Scratch

In [2]:
import einops
import torch as t

In [3]:
def generate_adjacency_matrix(n, b):
    """Build a toy layered graph with n layers of b token positions plus one sink node.

    Node (layer, token_pos) has index layer * b + token_pos. It gets an edge to every
    node in a strictly later layer at the same or a later token position, and every
    non-sink node has an edge to the sink (the last index).

    Args:
        n: number of layers (the sink is not counted).
        b: number of token positions per layer.

    Returns:
        tuple: (adjacency_matrix, node_info) where adjacency_matrix[target, source] == 1
        marks an edge source -> target, and node_info maps node index -> (layer, token_pos).
        The sink is recorded as (n, b - 1).
    """
    total_nodes = n * b + 1
    sink = total_nodes - 1

    # base_matrix[layer, token_pos] is the node index at that grid cell.
    base_matrix = t.arange(n * b).reshape(n, b)
    adjacency_matrix = t.zeros([total_nodes, total_nodes])

    for layer in range(n):
        for token_pos in range(b):
            node = layer * b + token_pos
            # All nodes in later layers at this or a later position.
            targets = base_matrix[layer + 1:, token_pos:].reshape(-1)
            # Index rows and column together so the write goes straight into the matrix.
            adjacency_matrix[targets, node] = 1

    adjacency_matrix[sink, :sink] = 1

    node_info = {
        layer * b + token_pos: (layer, token_pos)
        for layer in range(n)
        for token_pos in range(b)
    }
    node_info[sink] = (n, b - 1)

    return adjacency_matrix, node_info

# NOTE: this rebinds the global `adjacency_matrix` that earlier cells set to the real graph's matrix.
adjacency_matrix, node_info = generate_adjacency_matrix(n=3, b=2)
print(adjacency_matrix)

tensor([[0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.]])


In [None]:
def find_all_paths_wrapper(adj_matrix, b):
    """Enumerate every path from the first b nodes to the last node by depth-first search.

    Args:
        adj_matrix: square tensor where adj_matrix[target, source] != 0 marks an edge source -> target.
        b: number of start nodes; paths begin at nodes 0 .. b-1.

    Returns:
        list[list[int]]: all paths ending at node len(adj_matrix) - 1, grouped by start
        node, with successors explored in ascending index order.
    """
    sink = len(adj_matrix) - 1
    collected = []

    def extend(path_so_far):
        assert len(path_so_far) >= 1
        tail = path_so_far[-1]
        if tail == sink:
            collected.append(path_so_far)
            return

        # Column `tail` lists the nodes it has an edge to.
        for successor in t.nonzero(adj_matrix[:, tail]):
            extend(path_so_far + [successor.item()])

    for start in range(b):
        extend([start])

    return collected

# Enumerate every path from the two start nodes (b=2) to the last node of `adjacency_matrix`.
all_paths = find_all_paths_wrapper(
    adjacency_matrix,
    b=2,
)

[[0, 2, 4, 6],
 [0, 2, 5, 6],
 [0, 2, 6],
 [0, 3, 5, 6],
 [0, 3, 6],
 [0, 4, 6],
 [0, 5, 6],
 [0, 6],
 [1, 3, 5, 6],
 [1, 3, 6],
 [1, 5, 6],
 [1, 6]]

In [11]:
from math import perm, comb

In [None]:
def calculate_paths_length_base(i, b):
    """Recursive count used by calculate_paths_simple_graph.

    Base case: i == 2 returns b. Otherwise the result is the sum of the (i - 1) counts
    over every width from b down to 1. Intended for i >= 2; smaller i never reaches
    the base case.
    """
    if i == 2:
        return b

    return sum(
        calculate_paths_length_base(i - 1, width)
        for width in range(b, 0, -1)
    )


def calculate_paths_simple_graph(n, b):
    """Sum, over i = 2 .. n, of comb(n - 2, i - 2) * calculate_paths_length_base(i, b).

    The cross-check cell in this notebook compares this value against a brute-force
    path enumeration of the toy layered graph.
    """
    return sum(
        comb(n - 2, i - 2) * calculate_paths_length_base(i, b)
        for i in range(2, n + 1)
    )

# Combinatorial path count for n=5, b=3.
calculate_paths_simple_graph(5, 3)

66

In [35]:
from tqdm import tqdm

In [48]:
# Cross-check: brute-force path enumeration vs. the combinatorial count for small n and b.
for n in range(2, 8):
    for b in tqdm(range(1, 8)):
        # NOTE(review): the graph is built with n-1 layers; presumably the formula's n
        # also counts the sink layer — confirm.
        adj_matrix, __ = generate_adjacency_matrix(n-1, b)
        l1 = find_all_paths_wrapper(adj_matrix, b)
        l2 = calculate_paths_simple_graph(n, b)

        # print(f'l1: {l1}')
        # print(f'l2: {l2}')

        # l1 is the list of enumerated paths, l2 is the computed count.
        assert len(l1) == l2

100%|██████████| 7/7 [00:00<00:00, 1590.99it/s]
100%|██████████| 7/7 [00:00<00:00, 834.76it/s]
100%|██████████| 7/7 [00:00<00:00, 400.11it/s]
100%|██████████| 7/7 [00:00<00:00, 224.92it/s]
100%|██████████| 7/7 [00:00<00:00, 111.73it/s]
100%|██████████| 7/7 [00:00<00:00, 39.48it/s]


In [5]:
"""
Interactive visualization of attribution graphs with hover functionality.
"""
import numpy as np
import plotly.graph_objects as go

def get_node_label(node_idx: int, node_info: dict) -> str:
    """Generate a human-readable label for a node.

    `node_info` maps node index -> (layer, pos). A node in the highest layer present
    in `node_info` is labelled "Output", a layer-0 node "Input[pos]", and any other
    node "L<layer>[pos]".
    """
    layer, pos = node_info[node_idx]
    sink_layer = max(node_layer for node_layer, _ in node_info.values())

    # The sink check comes first, so a graph whose only layer is 0 is labelled "Output".
    if layer == sink_layer:
        return "Output"
    if layer == 0:
        return f"Input[{pos}]"
    return f"L{layer}[{pos}]"

def visualize_attribution_graph(
    adj_matrix: np.ndarray,
    node_info: dict,
    title: str = "Attribution Graph"
):
    """
    Create an interactive Plotly figure of the attribution graph.

    Nodes are placed at x = token position * 100 and y = layer * 100. Only edges with
    a strictly positive weight are drawn. Hovering a node shows its label, the nodes
    it is influenced by (incoming edges) and the nodes it feeds into (outgoing edges),
    each list capped at 10 names. No colour highlighting of neighbours is performed.

    Args:
        adj_matrix: square matrix where adj_matrix[target, source] > 0 marks an edge
            source -> target. NumPy arrays and torch tensors are both accepted.
        node_info: maps node index -> (layer, pos) for every node index.
        title: figure title.

    Returns:
        plotly.graph_objects.Figure
    """
    # Torch tensors expose .detach(); convert once so the edge scan below runs in NumPy.
    if hasattr(adj_matrix, "detach"):
        adj_matrix = adj_matrix.detach().cpu().float().numpy()
    has_edge = np.asarray(adj_matrix) > 0
    n_nodes = has_edge.shape[0]

    # Compute node positions (layer determines y, position determines x)
    node_positions = {
        node_idx: (pos * 100, layer * 100)
        for node_idx, (layer, pos) in node_info.items()
    }

    # Labels are computed once per node instead of once per edge endpoint.
    labels = [get_node_label(node_idx, node_info) for node_idx in range(n_nodes)]

    # Edge segments; a None entry separates consecutive segments in a single trace.
    edge_x = []
    edge_y = []
    for target_idx, source_idx in np.argwhere(has_edge).tolist():
        x0, y0 = node_positions[source_idx]
        x1, y1 = node_positions[target_idx]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=0.5, color='#888'),
        hoverinfo='skip',
        showlegend=False
    )

    def neighbour_section(heading, names):
        """Format one hover section: heading with count, then at most 10 names."""
        section = f"<b>{heading} ({len(names)}):</b><br>" + "<br>".join(names[:10])
        if len(names) > 10:
            section += f"<br>... and {len(names) - 10} more"
        return section

    # Prepare node data
    node_x = []
    node_y = []
    node_hover_info = []

    for node_idx in range(n_nodes):
        x, y = node_positions[node_idx]
        node_x.append(x)
        node_y.append(y)

        # Row = incoming edges of this node, column = outgoing edges.
        incoming_labels = [labels[i] for i in np.flatnonzero(has_edge[node_idx]).tolist()]
        outgoing_labels = [labels[i] for i in np.flatnonzero(has_edge[:, node_idx]).tolist()]

        hover_text = f"<b>{labels[node_idx]}</b><br>"
        hover_text += "<br>" + neighbour_section("Influenced by", incoming_labels)
        hover_text += "<br><br>" + neighbour_section("Feeds into", outgoing_labels)
        node_hover_info.append(hover_text)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        marker=dict(
            size=20,
            color='lightblue',
            line=dict(width=2, color='darkblue')
        ),
        text=labels,
        textposition="top center",
        textfont=dict(size=10),
        hovertext=node_hover_info,
        hoverinfo='text',
        showlegend=False
    )

    # Create figure
    fig = go.Figure(data=[edge_trace, node_trace])

    fig.update_layout(
        title=title,
        showlegend=False,
        hovermode='closest',
        xaxis=dict(
            title='Token Position',
            showgrid=True,
            zeroline=False,
            showticklabels=True
        ),
        yaxis=dict(
            title='Layer',
            showgrid=True,
            zeroline=False,
            showticklabels=True
        ),
        plot_bgcolor='white',
        width=1200,
        height=800
    )

    return fig

print("\nCreating interactive visualization...")
# Uses the `adjacency_matrix` and `node_info` bound by the generate_adjacency_matrix cell in this section.
fig = visualize_attribution_graph(
    adjacency_matrix,
    node_info,
)

fig.show()  



Creating interactive visualization...
