In [1]:
import json
import networkx as nx
from pathlib import Path
from typing import Tuple, Dict, Any, List, Set, Optional


# ----------------------
# Graph IO
# ----------------------
def load_graph(json_path: str) -> Tuple[nx.DiGraph, Dict[str, Any]]:
    """Load a node-link style graph JSON file into a directed graph.

    Every entry of ``data["nodes"]`` becomes a node keyed by its ``"node_id"``,
    with all of the entry's fields stored as node attributes. Every entry of
    ``data["links"]`` becomes an edge ``source -> target`` carrying all of the
    link's fields (including ``"source"``/``"target"``) as edge attributes.

    Args:
        json_path: Path to the graph JSON file (``str`` or ``Path``).

    Returns:
        ``(G, data)``: the built ``nx.DiGraph`` and the raw parsed JSON dict.
        ``save_graph`` uses that dict as the template for its output file.

    Raises:
        KeyError: if ``"nodes"``/``"links"`` or a node's ``"node_id"`` is missing.
    """
    path = Path(json_path)
    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    G = nx.DiGraph()
    for node in data["nodes"]:
        G.add_node(node["node_id"], **node)
    for link in data["links"]:
        # An endpoint not listed in data["nodes"] is created here as a bare node.
        G.add_edge(link["source"], link["target"], **link)

    return G, data


def save_graph(G: nx.DiGraph, data: Dict[str, Any], out_path: str, graph_metadata_path: Optional[str] = None) -> None:
    """Write ``G`` out as graph JSON, reusing ``data`` as the template.

    The ``"nodes"`` and ``"links"`` lists are rebuilt from the current node and
    edge attribute dicts of ``G``, so any pruning done on ``G`` is reflected.
    Every other top-level key of ``data`` is carried over unchanged, and
    ``metadata["slug"]`` is set to the output file's stem. ``data`` itself is
    not modified.

    Args:
        G: Graph whose node/edge attribute dicts become ``"nodes"``/``"links"``.
        data: Original JSON dict (e.g. as returned by ``load_graph``).
        out_path: Destination JSON file (``str`` or ``Path``).
        graph_metadata_path: If given, the index file that is updated via
            ``update_graph_metadata`` with the new metadata.
    """
    out_path = Path(out_path)
    slug = out_path.stem

    pruned_nodes: List[Dict[str, Any]] = [dict(attrs) for _, attrs in G.nodes(data=True)]
    pruned_links: List[Dict[str, Any]] = [dict(attrs) for _, _, attrs in G.edges(data=True)]

    new_data: Dict[str, Any] = dict(data)
    # Copy the metadata dict as well. A shallow copy of ``data`` would share it,
    # and writing the slug would silently mutate the caller's ``data``.
    metadata: Dict[str, Any] = dict(data.get("metadata") or {})
    metadata["slug"] = slug
    new_data["metadata"] = metadata
    new_data["nodes"] = pruned_nodes
    new_data["links"] = pruned_links

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(new_data, f, indent=2)
    print(f"✅ Graph saved to {out_path}")

    if graph_metadata_path:
        update_graph_metadata(graph_metadata_path, new_data["metadata"])



def update_graph_metadata(graph_metadata_path: str, new_metadata: Dict[str, Any]) -> None:
    """Insert or replace the entry for ``new_metadata["slug"]`` in a graph index file.

    The index file is a JSON object with a ``"graphs"`` list; it is created if
    missing, and a missing ``"graphs"`` key is treated as an empty list. Any
    existing entry with the same slug is dropped. The new entry starts as a
    copy of the first remaining entry, so fields not supplied here are
    inherited from it. Then ``slug`` is set, along with whichever of
    ``prompt`` (prefixed with ``"pruned:"``), ``prompt_tokens``, ``scan`` and
    ``node_threshold`` are present in ``new_metadata``. The new entry is
    appended last.

    Args:
        graph_metadata_path: Path to the index JSON file (``str`` or ``Path``).
        new_metadata: Metadata of the graph being registered; should contain ``"slug"``.
    """
    try:
        with open(graph_metadata_path, "r", encoding="utf-8") as f:
            graph_meta = json.load(f)
    except FileNotFoundError:
        print(f"⚠️ Warning: {graph_metadata_path} not found. Creating a new one.")
        graph_meta = {"graphs": []}

    # Remove any existing entry with the same slug. Tolerate an index file
    # without a "graphs" list instead of raising KeyError.
    slug = new_metadata.get("slug")
    entries = [g for g in graph_meta.get("graphs") or [] if g.get("slug") != slug]

    # Clone the first remaining entry as a template so shared fields carry over.
    new_entry: Dict[str, Any] = dict(entries[0]) if entries else {}
    new_entry["slug"] = slug

    if "prompt" in new_metadata:
        new_entry["prompt"] = "pruned:" + new_metadata["prompt"]
    for key in ("prompt_tokens", "scan", "node_threshold"):
        if key in new_metadata:
            new_entry[key] = new_metadata[key]

    entries.append(new_entry)
    graph_meta["graphs"] = entries

    with open(graph_metadata_path, "w", encoding="utf-8") as f:
        json.dump(graph_meta, f, indent=2)

    print(f"🧩 Updated graph entry for slug '{slug}' in: {graph_metadata_path}")


# ----------------------
# Pruning helpers
# ----------------------
def prune_edges(G: nx.DiGraph, top_k: int = 5) -> nx.DiGraph:
    """Keep only the ``top_k`` strongest outgoing edges of every node (in place).

    For each node, its outgoing edges are ranked by ``abs(weight)`` and all but
    the ``top_k`` largest are removed. Edges with no ``"weight"`` attribute are
    treated as weight 0, so they are pruned first. Ties keep the graph's edge
    iteration order (``sorted`` is stable).

    Args:
        G: Graph to prune; it is modified in place.
        top_k: Number of outgoing edges to keep per node; ``0`` removes all of them.

    Returns:
        The same graph object ``G``.
    """
    to_remove = []
    for n in G.nodes():
        ranked = sorted(
            G.out_edges(n, data="weight", default=0),
            key=lambda edge: abs(edge[2]),
        )
        # Slice by count rather than ``[:-top_k]``. With top_k == 0 that would be
        # ``[:0]`` and remove nothing instead of everything.
        n_drop = max(len(ranked) - top_k, 0)
        to_remove.extend((u, v) for u, v, _ in ranked[:n_drop])
    G.remove_edges_from(to_remove)
    print(f"<UNK> Removed {len(to_remove)} edges")
    return G


def prune_nodes(G: nx.DiGraph, feature_types: Tuple[str, ...] = ("mlp reconstruction error",)) -> nx.DiGraph:
    """Remove (in place) every node whose ``feature_type`` attribute is in ``feature_types``.

    Removing a node also removes its incident edges (networkx semantics).

    Args:
        G: Graph to prune; it is modified in place.
        feature_types: ``feature_type`` values to drop. The default drops only
            ``"mlp reconstruction error"`` nodes, as before. A single string is
            accepted and treated as a one-element collection.

    Returns:
        The same graph object ``G``.
    """
    if isinstance(feature_types, str):
        # Avoid iterating a bare string character by character.
        feature_types = (feature_types,)
    drop = tuple(feature_types)
    to_remove = [n for n, ftype in G.nodes(data="feature_type") if ftype in drop]
    G.remove_nodes_from(to_remove)
    print(f"<UNK> Removed {len(to_remove)} nodes")
    return G

# ----------------------
# Circuit helpers
# ----------------------
def get_circuit_nodes(data: Dict[str, Any]) -> Set[str]:
    """Return the set of pinned node ids stored in ``data["qParams"]["pinnedIds"]``.

    ``pinnedIds`` is a comma-separated string (the format ``set_circuit``
    writes). Whitespace around each id is stripped. Empty entries (e.g. from a
    trailing or doubled comma) are ignored so they don't appear as a bogus
    ``""`` id. A missing or ``None`` ``qParams``/``pinnedIds`` yields an empty set.
    """
    pinned: str = (data.get("qParams") or {}).get("pinnedIds") or ""
    return {node_id.strip() for node_id in pinned.split(",") if node_id.strip()}


def set_circuit(data: Dict[str, Any], node_ids: List[str]) -> Dict[str, Any]:
    """Pin ``node_ids`` as the circuit by writing them into ``data["qParams"]["pinnedIds"]``.

    The ids are stored as a single comma-joined string, the format that
    ``get_circuit_nodes`` parses. ``data`` is modified in place, and a
    ``"qParams"`` dict is created if it is absent.

    Returns:
        The same ``data`` dict, for chaining.
    """
    q_params = data.setdefault("qParams", {})
    q_params["pinnedIds"] = ",".join(node_ids)
    print(f"Set circuit with {len(node_ids)} pinned nodes")
    return data

In [5]:
# Paths are relative to the notebook's working directory.
GRAPH_DIR = Path("attribution_output/graph")
SOURCE_GRAPH = GRAPH_DIR / "1755609379.json"
OUTPUT_GRAPH = GRAPH_DIR / "graph_with_circuit.json"
GRAPH_METADATA = GRAPH_DIR / "graph-metadata.json"
# Set to an int (e.g. 20) to also keep only the top-k outgoing edges per node.
EDGE_TOP_K: Optional[int] = None

G, data = load_graph(SOURCE_GRAPH)
if EDGE_TOP_K is not None:
    G = prune_edges(G, top_k=EDGE_TOP_K)
G = prune_nodes(G)
save_graph(G, data, OUTPUT_GRAPH, graph_metadata_path=GRAPH_METADATA)

<UNK> Removed 764 nodes
✅ Graph saved to attribution_output/graph/graph_with_circuit.json
🧩 Updated graph entry for slug 'graph_with_circuit' in: attribution_output/graph/graph-metadata.json
