In [None]:
# Toy example to draw an LLM "circuit" using the user's helpers.
# - Builds a fake attention influence matrix across layers×heads
# - Picks top-k heads by influence
# - Inserts MLP nodes between layers to form a multipartite graph
# - Lays out and renders the circuit with rounded rectangles for nodes
#
# Output:
# - Displays the plot inline
# - Optionally saves a PNG (e.g. to /mnt/data/llm_circuit_toy.png); saving is disabled by default

import numpy as np
import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt

# -------------------------- Helpers (from user) --------------------------
def get_top_attn_heads(mat, pct):
    """Return the attention heads that are in the top `pct` of influence.

    Args:
        mat: Two-level iterable of influence scores indexed as
            ``mat[layer][head]`` (for example a 2-D NumPy array or a list of
            lists; rows may have different lengths).
        pct: Fraction of all heads to keep. The resulting count is truncated
            toward zero and clamped to ``[0, number of heads]``, so a negative
            value keeps nothing and a value above 1 keeps everything.

    Returns:
        A list of ``(layer_idx, head_idx, influence)`` tuples sorted by
        descending influence. Heads with equal influence keep their
        layer-major, head-minor order.
    """
    heads = [
        (lay_idx, attn_idx, influence)
        for lay_idx, row in enumerate(mat)
        for attn_idx, influence in enumerate(row)
    ]
    # list.sort is stable even with reverse=True, so ties stay in scan order.
    heads.sort(key=lambda head: head[2], reverse=True)

    # Clamp the count: a negative slice bound would otherwise drop heads from
    # the *end* of the ranking instead of returning an empty selection.
    keep = min(len(heads), max(0, int(len(heads) * pct)))
    return heads[:keep]

def attn_heads_multipartite(heads):
    """Sort attention heads into multipartite subsets with MLP nodes between them.

    For every layer index from the smallest to the largest one present in
    `heads`, the heads of that layer (if any) share one subset and are followed
    by an ``"MLP <layer>"`` node in the next subset. Each head gets an edge to
    its own layer's MLP and, except in the first layer, an edge from the
    previous layer's MLP. A layer without selected heads only contributes its
    MLP node, chained to the previous MLP.

    Args:
        heads: Iterable of ``(layer, head, influence)`` tuples. Layer indices
            must be integers (they are passed to ``range``).

    Returns:
        A tuple ``(all_nodes, edges, vals)`` where ``all_nodes`` is a list of
        ``(node_name, subset_index)`` pairs (subset indices start at 1),
        ``edges`` is a list of ``(source_name, target_name)`` pairs, and
        ``vals`` maps each head name ``"<layer>.<head>"`` to its influence.
        MLP nodes have no entry in ``vals``. All three are empty when `heads`
        is empty.
    """
    heads = list(heads)
    if not heads:
        # Nothing to lay out; avoids min()/max() failing on an empty sequence.
        return [], [], {}

    # Group the heads by layer once instead of re-filtering the full list for
    # every layer. Insertion order within a layer is preserved.
    by_layer = {}
    for head in heads:
        by_layer.setdefault(head[0], []).append(head)
    min_lay_idx = min(by_layer)
    max_lay_idx = max(by_layer)

    edges = []
    all_nodes = []
    vals = {}

    counter = 1  # subset index of the next column in the layout
    for lay_idx in range(min_lay_idx, max_lay_idx + 1):
        mlp_name = f"MLP {lay_idx}"
        prev_mlp_name = f"MLP {lay_idx - 1}"
        # Stable sort by head index so each column is drawn in head order.
        lay_nodes = sorted(by_layer.get(lay_idx, ()), key=lambda x: x[1])

        for node in lay_nodes:
            head_name = f"{lay_idx}.{node[1]}"
            all_nodes.append((head_name, counter))
            edges.append((head_name, mlp_name))
            vals[head_name] = node[2]
            if lay_idx != min_lay_idx:
                edges.append((prev_mlp_name, head_name))

        if lay_nodes:
            # The heads occupied this subset; the MLP goes in the next one.
            counter += 1
        else:
            # No heads selected here: link the MLPs directly so the graph
            # stays connected across the gap.
            edges.append((prev_mlp_name, mlp_name))
        all_nodes.append((mlp_name, counter))
        counter += 1
    return all_nodes, edges, vals

def draw_rounded_node(ax, pos, node, node_color, width=0.06, height=0.2):
    """Draw `node` on `ax` as a labelled rounded rectangle centred at ``pos[node]``.

    Args:
        ax: Axes-like object; its ``add_patch`` and ``text`` methods are used.
        pos: Mapping from node name to an ``(x, y)`` pair.
        node: Node name (a string); also used as the label text.
        node_color: Face colour of the box. Unless the name contains "MLP",
            its first three components are read to pick the label colour.
        width: Box width in data coordinates.
        height: Box height in data coordinates.
    """
    cx, cy = pos[node]

    # The patch is anchored at its lower-left corner, so shift by half the
    # size to centre it on the node position.
    lower_left = (cx - width / 2, cy - height / 2)
    ax.add_patch(
        mpl.patches.FancyBboxPatch(
            lower_left,
            width,
            height,
            boxstyle="round,pad=0.01,rounding_size=0.01",
            linewidth=1,
            facecolor=node_color,
            edgecolor="black",
        )
    )

    # Label colour: MLP nodes are always black; other nodes get black text on
    # bright fills and white text on dark ones, judged by a weighted sum of
    # the first three colour components.
    if "MLP" in node:
        text_color = "black"
    else:
        brightness = (
            node_color[0] * 0.299 + node_color[1] * 0.587 + node_color[2] * 0.114
        )
        text_color = "black" if brightness > 0.73 else "white"

    label_style = {
        "ha": "center",
        "va": "center",
        "fontsize": 8,
        "weight": "bold",
        "color": text_color,
    }
    ax.text(cx, cy, node, **label_style)

def make_circuit_graph(nodes, edges, vals, color="viridis", scale=1.0):
    """Build the directed circuit graph, its layout, and a per-node colour map.

    Args:
        nodes: Iterable of ``(node_name, subset_index)`` pairs.
        edges: Iterable of ``(source_name, target_name)`` pairs.
        vals: Mapping from node name to a scalar value. It is not modified.
        color: Name of the matplotlib colormap used to colour the values.
        scale: Scale factor forwarded to the multipartite layout.

    Returns:
        A tuple ``(G, pos, colors)``: the ``networkx.DiGraph``, the positions
        computed by ``nx.multipartite_layout`` (keyed on the ``"subset"`` node
        attribute, horizontal alignment), and a new dict mapping each key of
        `vals` to the colour the colormap assigns to its normalised value.
    """
    graph = nx.DiGraph()
    # Each node carries its subset index so the layout can place it in a column.
    graph.add_nodes_from((name, {"subset": layer}) for name, layer in nodes)
    graph.add_edges_from(edges)

    # Linearly rescale the scalar values so the smallest maps to 0 and the
    # largest to 1 before looking them up in the colormap.
    scalars = vals.values()
    normalize = plt.Normalize(vmin=min(scalars), vmax=max(scalars))
    colormap = plt.get_cmap(color)
    colors = {name: colormap(normalize(value)) for name, value in vals.items()}

    layout = nx.multipartite_layout(
        graph, subset_key="subset", align="horizontal", scale=scale
    )
    return graph, layout, colors

# -------------------------- Toy data --------------------------
# Suppose a tiny 4-layer transformer, each with 6 attention heads.
rng = np.random.default_rng(42)  # fixed seed so the figure is reproducible
num_layers = 4
num_heads = 6
# "Influence" matrix per (layer, head)
mat = rng.uniform(0, 1, size=(num_layers, num_heads))

# Keep the top 40% most influential heads overall
top_heads = get_top_attn_heads(mat, pct=0.40)

# Build multipartite graph with MLPs between layers
nodes, edges, vals = attn_heads_multipartite(top_heads)

# MLP nodes have no influence score of their own; give them the midpoint of
# the head scores so they get a neutral mid-colormap color.
if vals:
    mid_val = (min(vals.values()) + max(vals.values())) / 2
else:
    mid_val = 0.5
for node, subset in nodes:
    if node.startswith("MLP "):
        vals.setdefault(node, mid_val)

G, pos, color_map = make_circuit_graph(nodes, edges, vals, color="viridis", scale=2.0)

# -------------------------- Render --------------------------
fig = plt.figure(figsize=(8, 5))
ax = plt.gca()
ax.set_axis_off()

# Draw directed edges (arrows from each source node to its target)
nx.draw_networkx_edges(
    G,
    pos,
    ax=ax,
    arrows=True,
    width=1.0,
    min_source_margin=10,
    min_target_margin=10,
)

# Draw nodes as rounded rectangles with color by influence
for node in G.nodes():
    draw_rounded_node(ax, pos, node, color_map[node])

plt.tight_layout()

# Saving is opt-in: set out_path to a file path (for example
# "/mnt/data/llm_circuit_toy.png") to also write the figure to disk.
# It must always be defined because it is echoed as the cell result below.
out_path = None
if out_path is not None:
    # Save before plt.show(): some backends discard the figure once shown.
    fig.savefig(out_path, dpi=200, bbox_inches="tight")
plt.show()

out_path

ModuleNotFoundError: No module named 'config'