In [None]:
def get_adjacency_without_error_nodes(adjacency_matrix):
    """Return ``adjacency_matrix`` with the error-node rows and columns dropped.

    Error nodes are taken to be the contiguous index block that starts right
    after the selected features and spans ``n_layers * n_input_tokens`` entries.
    Relies on the notebook globals ``graph`` and ``model`` and on torch imported
    as ``t``.

    Args:
        adjacency_matrix: square tensor indexed in the graph's node layout.

    Returns:
        The square sub-matrix over every node outside the error block.
    """
    first_error = len(graph.selected_features)
    past_last_error = first_error + model.cfg.n_layers * len(graph.input_tokens)

    # Boolean keep-mask over one axis (same device as the matrix); the same mask
    # is applied to rows and then to columns.
    keep = t.ones_like(adjacency_matrix[0], dtype=t.bool)
    keep[first_error:past_last_error] = False
    return adjacency_matrix[keep][:, keep]


In [None]:
def compute_topological_order(adjacency_matrix):
    """
    Compute topological order using Kahn's algorithm.

    Args:
        adjacency_matrix: torch.Tensor of shape (n_nodes, n_nodes)
                         where adjacency_matrix[target, source] represents edge from source -> target

    Returns:
        list: Topological order of node indices (Python ints). Nodes that become
        ready are emitted in FIFO order, and a node's successors are released in
        ascending index order. If the graph contains a cycle, a warning is printed
        and the nodes that could not be ordered are appended in index order.
    """
    from collections import deque

    n_nodes = adjacency_matrix.shape[0]

    # Build the nonzero-edge mask once and keep it on the CPU, so the per-node
    # work below never round-trips to the device.
    edge_mask = (adjacency_matrix != 0).cpu()

    # In-degree of a node = nonzero entries in its row (edges pointing TO it).
    in_degree = edge_mask.sum(dim=1).tolist()

    # Start from nodes with no incoming edges; deque gives O(1) pops from the front.
    queue = deque(node for node, degree in enumerate(in_degree) if degree == 0)
    topo_order = []

    while queue:
        node = queue.popleft()
        topo_order.append(node)

        # Column `node` holds this node's outgoing edges; visit only the actual
        # successors instead of scanning every possible target.
        for target in edge_mask[:, node].nonzero().flatten().tolist():
            in_degree[target] -= 1
            if in_degree[target] == 0:
                queue.append(target)

    if len(topo_order) != n_nodes:
        print(f"Warning: Graph contains a cycle! Only {len(topo_order)}/{n_nodes} nodes ordered.")
        ordered = set(topo_order)  # O(1) membership instead of scanning the list
        topo_order.extend(i for i in range(n_nodes) if i not in ordered)

    return topo_order

# Full adjacency matrix from the graph. get_adjacency_without_error_nodes is NOT
# applied, so node indices keep the graph's original layout, which the later
# cells (node ranges, Path helpers) index into.
adjacency_matrix = graph.adjacency_matrix

print("Computing topological order of the attribution graph...")
topo_order = compute_topological_order(adjacency_matrix)
print(f"Topological order computed: {len(topo_order)} nodes")

In [None]:
def test_topological_sort(topological_sort, n_error_nodes=286, node_details_fn=None):
    '''
    Heuristically check a topological sort against node layer/token positions.

    The only 'illegal moves' between consecutive nodes are
    (a) moving to same token, previous layer
    (b) moving to previous token, previous layer

    NOTE(review): only consecutive nodes are compared. A pair with no edge
    between them may legitimately appear in such an order, so a flagged index
    is a hint to investigate, not proof of an invalid sort.

    Args:
        topological_sort: sequence of node indices.
        n_error_nodes: number of leading entries to skip before checking
            (default 286 matches the previously hard-coded slice; presumably
            the error nodes at the front of the order — verify per graph).
        node_details_fn: callable mapping a node to ``(layer, token_pos)``;
            defaults to the notebook's ``get_node_details``.

    Returns:
        bool: True if no illegal move was found.
    '''
    if node_details_fn is None:
        node_details_fn = get_node_details

    prev_layer = -1
    prev_token_pos = 0
    n_illegal = 0

    # start= keeps the reported index relative to the full sort, not the slice.
    for idx, node in enumerate(topological_sort[n_error_nodes:], start=n_error_nodes):
        layer, token_pos = node_details_fn(node)
        if (layer < prev_layer) and (token_pos <= prev_token_pos):
            print(f'error in topological sort at idx: {idx}')
            n_illegal += 1
        # Compare each node against its predecessor, not against the initial values.
        prev_layer, prev_token_pos = layer, token_pos

    if n_illegal == 0:
        print('topological sort is okay!')
    else:
        print(f'topological sort has {n_illegal} illegal move(s)')
    return n_illegal == 0


# Notebook sanity check, not a pytest test: keep test collectors from picking it up by name.
test_topological_sort.__test__ = False

# Sanity-check the order computed in the previous cell against node layer/token positions.
test_topological_sort(topo_order)

In [None]:
from dataclasses import dataclass
from typing import List
import pandas as pd
from tqdm import tqdm
from circuit_tracer.graph import compute_node_influence

@dataclass
class Path:
    """A single source-to-sink path through the attribution graph.

    Node indices follow the graph layout
    ``[features | errors (n_layers * n_pos) | embeddings (n_pos) | logits]``.
    """
    nodes: List[int]    # node indices along the path, source first
    edges: List[float]  # weight of the edge between each pair of consecutive nodes
    score: float        # path score assigned by the path finder
    final_score: float  # secondary score, filled in by the caller

    def __len__(self):
        """Number of nodes on the path."""
        return len(self.nodes)

    @staticmethod
    def _segment_ends(graph):
        """Exclusive end index of the feature, error and embedding node ranges."""
        feature_end = len(graph.selected_features)
        error_end = feature_end + graph.cfg.n_layers * graph.n_pos
        embed_end = error_end + graph.n_pos
        return feature_end, error_end, embed_end

    @staticmethod
    def _kind(node, ends):
        """Classify ``node`` using the range ends from ``_segment_ends``."""
        feature_end, error_end, embed_end = ends
        if node < feature_end:
            return 'feature'
        if node < error_end:
            return 'error'
        if node < embed_end:
            return 'embed'
        return 'logit'

    def get_node_types(self, graph) -> List[str]:
        """Return node types: 'feature', 'error', 'embed', 'logit'."""
        ends = self._segment_ends(graph)
        return [self._kind(node, ends) for node in self.nodes]

    def get_node_descriptions(self, graph, tokenizer) -> List[str]:
        """Return a human-readable label for every node on the path."""
        ends = self._segment_ends(graph)
        feature_end, error_end, embed_end = ends

        def token_at(pos):
            return tokenizer.decode([graph.input_tokens[pos]])

        labels = []
        for node in self.nodes:
            kind = self._kind(node, ends)
            if kind == 'feature':
                layer, pos, feat_idx = graph.active_features[graph.selected_features[node]].tolist()
                labels.append(f"Feature L{layer}:F{feat_idx} @ pos {pos} ('{token_at(pos)}')")
            elif kind == 'error':
                # Error nodes are laid out layer-major: offset = layer * n_pos + pos.
                offset = node - feature_end
                layer, pos = offset // graph.n_pos, offset % graph.n_pos
                labels.append(f"Error L{layer} @ pos {pos} ('{token_at(pos)}')")
            elif kind == 'embed':
                pos = node - error_end
                labels.append(f"Embedding @ pos {pos} ('{token_at(pos)}')")
            else:
                logit_idx = node - embed_end
                token = tokenizer.decode([graph.logit_tokens[logit_idx]])
                prob = graph.logit_probabilities[logit_idx].item()
                labels.append(f"Logit '{token}' (p={prob:.4f})")
        return labels


def find_k_best_paths_dp(adj_matrix, source_nodes, sink_node, topo_order, k=10, verbose=True):
    """Find up to ``k`` best paths from each source node to ``sink_node`` via DP.

    A path's score is the product of the absolute edge weights along it. Nodes
    are visited in reverse ``topo_order`` so that, when a node is processed, its
    successors already hold their best continuations to the sink. Each node keeps
    at most ``k`` lightweight references ``(next_node, edge_weight, score,
    next_ref_idx)``, sorted by score (descending). ``next_ref_idx`` points at the
    exact continuation used in ``next_node``'s list, so full paths can be rebuilt
    at the end without any score matching.

    Args:
        adj_matrix: tensor whose column ``node`` holds that node's outgoing edge
            weights, i.e. ``adj_matrix[target, source]`` is the edge source -> target.
        source_nodes: iterable of start node indices.
        sink_node: index of the node every path must end at.
        topo_order: topological order of all node indices.
        k: maximum number of paths kept per node (and so returned per source).
        verbose: show a tqdm progress bar.

    Returns:
        list[Path]: for each source (in ``source_nodes`` order), its best paths in
        descending ``score`` order. ``final_score`` is left at 0.0 for the caller.
    """
    # The sink's single sentinel reference: no successor, empty-product score 1.0.
    best_path_refs = {sink_node: [(None, None, 1.0, None)]}

    iterator = reversed(topo_order) if not verbose else tqdm(
        reversed(topo_order), desc="DP path finding (optimized)", total=len(topo_order)
    )

    for node in iterator:
        if node == sink_node:
            continue

        outgoing_weights = adj_matrix[:, node]
        successors = (outgoing_weights != 0).nonzero().flatten()

        if len(successors) == 0:
            best_path_refs[node] = []
            continue

        # One host transfer per node instead of an .item() sync per successor.
        succ_ids = successors.tolist()
        succ_weights = outgoing_weights[successors].tolist()

        candidate_refs = []
        for succ_idx, edge_weight in zip(succ_ids, succ_weights):
            # Successors with no path to the sink (or not yet visited) contribute nothing.
            for ref_idx, (_, _, path_score, _) in enumerate(best_path_refs.get(succ_idx, ())):
                candidate_refs.append((succ_idx, edge_weight, abs(edge_weight) * path_score, ref_idx))

        candidate_refs.sort(key=lambda ref: ref[2], reverse=True)
        best_path_refs[node] = candidate_refs[:k]

    def reconstruct_path(start_node, path_ref_index):
        """Rebuild a full path by following stored (node, ref index) links to the sink."""
        nodes, edges = [start_node], []
        current_node, current_ref_idx = start_node, path_ref_index

        while True:
            next_node, edge_weight, _, next_ref_idx = best_path_refs[current_node][current_ref_idx]
            if next_node is None:
                break
            nodes.append(next_node)
            edges.append(edge_weight)
            current_node, current_ref_idx = next_node, next_ref_idx

        return Path(nodes=nodes, edges=edges, score=best_path_refs[start_node][path_ref_index][2], final_score=0.0)

    all_source_paths = []
    for source in source_nodes:
        for path_idx in range(len(best_path_refs.get(source, ()))):
            all_source_paths.append(reconstruct_path(source, path_idx))

    return all_source_paths


# Cell-completion marker: reached only if the definitions in this cell executed.
print("✅ Path finding functions loaded (optimized version)")

In [None]:
# Define node indices.
# Layout assumed here (same ranges as Path.get_node_types):
# [features | errors (n_layers * n_pos) | embeddings (n_pos) | logits]
n_features = len(graph.selected_features)
n_errors = graph.cfg.n_layers * graph.n_pos
n_embeds = graph.n_pos
n_logits = len(graph.logit_tokens)  # not used in the cells shown here

embed_start = n_features + n_errors
embed_end = embed_start + n_embeds
embed_nodes = list(range(embed_start, embed_end))
# Target = the first logit node, right after the embeddings. Presumably that is
# the 'Austin' token; verify graph.logit_tokens[0]. The index below is specific
# to the graph this notebook was run on.
austin_logit = embed_end  # Index 8489

print(f"Finding top-10 complete paths: Embeddings [{embed_start}:{embed_end}] → Austin [{austin_logit}]")
print()

# Find complete paths. Note: find_k_best_paths_dp returns up to k paths PER
# source node, so this can hold up to k * n_embeds paths; they are re-ranked
# and cut to 10 in the next cell.
all_complete_paths = find_k_best_paths_dp(
    adj_matrix=adjacency_matrix,
    source_nodes=embed_nodes,
    sink_node=austin_logit,
    topo_order=topo_order,
    k=10,
    verbose=True
)

In [None]:
# Re-rank all DP paths by their mean *signed* edge weight. This differs from
# Path.score (product of |edge weights|), which the DP used to select them.
# len(path.edges) is 0 only for a source == sink path, which cannot happen here
# (sources are embedding nodes, the sink is a logit node).
for path in all_complete_paths:
    path.final_score = sum(path.edges) / len(path.edges)

all_complete_paths_sorted = sorted(all_complete_paths, key=lambda p: p.final_score, reverse=True)
top_k_complete_paths = all_complete_paths_sorted[:10]  # 10 mirrors k=10 in the previous cell

print(f"\n✅ Found {len(top_k_complete_paths)} complete paths!")
print()

# Display paths
for rank, path in enumerate(top_k_complete_paths, 1):
    node_descs = path.get_node_descriptions(graph, model.tokenizer)
    node_types = path.get_node_types(graph)
    
    print(f"Path #{rank} (Score: {path.final_score:.8f}, Length: {len(path)})")
    print(f"  {' → '.join(node_types)}")
    print(f"  Start: {node_descs[0]}")
    print(f"  End: {node_descs[-1]}")
    print()

In [None]:
def paths_to_dataframe(paths, graph, tokenizer):
    """Build a pandas DataFrame with one summary row per path.

    Rows are ranked 1..len(paths) in the order given. Edge-weight statistics are
    computed on absolute weights and fall back to 0.0 for a path with no edges.

    Args:
        paths: sequence of Path objects (anything exposing ``score``, ``edges``,
            ``__len__``, ``get_node_types`` and ``get_node_descriptions``).
        graph: attribution graph forwarded to the Path helpers.
        tokenizer: tokenizer forwarded to ``get_node_descriptions``.

    Returns:
        pd.DataFrame: one row per path; an empty frame when ``paths`` is empty.
    """
    def summarize(rank, path):
        kinds = path.get_node_types(graph)
        labels = path.get_node_descriptions(graph, tokenizer)
        magnitudes = [abs(weight) for weight in path.edges]
        has_edges = bool(magnitudes)
        return {
            'rank': rank,
            'influence_score': path.score,
            'length': len(path),
            'start_type': kinds[0],
            'end_type': kinds[-1],
            'start_description': labels[0],
            'end_description': labels[-1],
            'path_summary': ' → '.join(kinds),
            'full_path': ' → '.join(labels),
            'min_edge_weight': min(magnitudes) if has_edges else 0.0,
            'max_edge_weight': max(magnitudes) if has_edges else 0.0,
            'avg_edge_weight': sum(magnitudes) / len(magnitudes) if has_edges else 0.0,
        }

    return pd.DataFrame([summarize(rank, path) for rank, path in enumerate(paths, start=1)])


# Create DataFrame
# NOTE(review): `complete_paths` is not defined in any cell shown here. The
# ranked list built by the previous cell is `top_k_complete_paths` (the
# unranked one is `all_complete_paths`). Unless an earlier cell defines
# `complete_paths`, this line raises NameError — confirm which list is meant.
# Also note that the 'influence_score' column is Path.score (product of
# |edge weights|), not the final_score used for ranking in the previous cell.
df_complete = paths_to_dataframe(complete_paths, graph, model.tokenizer)

print("Complete Paths DataFrame:")
print(df_complete[['rank', 'influence_score', 'length', 'path_summary']])
print()

# Save to CSV
df_complete.to_csv('complete_paths_austin.csv', index=False)
print("✅ Saved to complete_paths_austin.csv")