# In [None]:
import json
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import to_rgba, to_hex
import colorsys
from matplotlib.patches import FancyArrowPatch
def load_graph_from_json(file_path):
    """Load the active nodes and in-graph edges from a graph JSON file.

    The file must contain a ``"nodes"`` object mapping node names to a flag
    and an ``"edges"`` object mapping edge names to objects that carry at
    least an ``"in_graph"`` entry.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        A ``(nodes, edges)`` tuple. ``nodes`` is the list of node names whose
        flag is truthy, in the order they appear in the file. ``edges`` maps
        every edge name whose ``"in_graph"`` value is truthy to its data dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
        KeyError: If ``"nodes"``, ``"edges"`` or an edge's ``"in_graph"``
            entry is missing.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The keys of a parsed JSON object are already unique, so no set()
    # round-trip is needed; keeping file order makes the node list (and
    # everything derived from it, e.g. the layout) reproducible across runs.
    nodes = [node for node, active in data["nodes"].items() if active]
    edges = {
        edge: edge_data
        for edge, edge_data in data["edges"].items()
        if edge_data["in_graph"]
    }
    return nodes, edges

def compute_dynamic_threshold(edges, percentile=95):
    """Compute a score threshold from the distribution of edge scores.

    The threshold is the requested percentile of the absolute values of the
    ``"score"`` entries, and is also printed.

    Args:
        edges: Mapping of edge name to a dict with a numeric ``"score"``.
        percentile: Percentile (0-100) of the absolute scores to use.

    Returns:
        The percentile value as computed by ``np.percentile``, or ``0.0``
        when ``edges`` is empty.

    Raises:
        KeyError: If an edge dict has no ``"score"`` entry.
    """
    scores = [abs(edge_data["score"]) for edge_data in edges.values()]
    if scores:
        threshold = np.percentile(scores, percentile)
    else:
        # A percentile of an empty sample is undefined; 0.0 is a neutral
        # cut-off because every absolute score is >= 0.
        threshold = 0.0
    print(f"Dynamic threshold ({percentile}th percentile): {threshold}")
    return threshold

def filter_edges_by_threshold(edges, threshold):
    """Return the edges whose absolute score is at least ``threshold``.

    The result is a new dict; the edge data dicts themselves are shared with
    the input and the original key order is kept.
    """
    kept = {}
    for name, info in edges.items():
        if abs(info["score"]) >= threshold:
            kept[name] = info
    return kept

def sort_nodes_within_layer(layer_nodes):
    """Return the nodes of one layer as a new, display-ordered list.

    Names starting with "a" (attention heads, e.g. ``a3.h2``) are ordered by
    layer and then head index, names starting with "m" (MLP, e.g. ``m3``)
    follow the heads of the same layer, and any other name receives the
    fixed key ``(-1, -1, 2)``.
    """
    def ordering(name):
        prefix = name[:1]
        if prefix == "a":
            pieces = name.split(".")
            layer_index = int(pieces[0][1:])
            # A bare "a<layer>" name without a head part sorts before heads.
            head_index = -1 if len(pieces) <= 1 else int(pieces[1][1:])
            return layer_index, head_index, 0
        if prefix == "m":
            # 9999 pushes the MLP node behind every head of its layer.
            return int(name[1:]), 9999, 1
        return -1, -1, 2

    ordered = list(layer_nodes)
    ordered.sort(key=ordering)
    return ordered

def adjust_color_intensity(base_color, degree, max_degree):
    """Darken ``base_color`` in proportion to ``degree / max_degree``.

    The colour is converted to HSV, its value (brightness) is reduced by up
    to 0.5 at the maximum degree and clamped at 0, and the result is returned
    as a hex string. The alpha channel of ``base_color`` is discarded.
    """
    # Share of the maximum degree; 0 when there is no maximum to compare to.
    fraction = degree / max_degree if max_degree != 0 else 0

    red, green, blue, _alpha = to_rgba(base_color)
    hue, saturation, value = colorsys.rgb_to_hsv(red, green, blue)

    # Higher degree -> darker colour; 0.5 is the strength of the effect.
    darkened = max(0.0, value - 0.5 * fraction)
    return to_hex(colorsys.hsv_to_rgb(hue, saturation, darkened))

_ADDED_COLOR = "#104670"
_REMOVED_COLOR = "#E47159"


def _layer_index(node, logits_layer):
    """Return the row of ``node``: its layer number, ``logits_layer`` or -1."""
    if node == "logits":
        return logits_layer
    if node.startswith("m"):  # MLP node, e.g. "m3"
        return int(node[1:])
    if node.startswith("a"):  # Attention head node, e.g. "a3.h2"
        return int(node[1:].split(".")[0])
    return -1  # Input layer


def _node_fill_color(node, degree, max_degree, added_nodes, removed_nodes):
    """Return the fill colour of ``node`` from its diff status and degree."""
    if node in added_nodes:
        base_color = _ADDED_COLOR
    elif node in removed_nodes:
        base_color = _REMOVED_COLOR
    else:
        base_color = None

    if base_color is not None:
        if degree > 0:
            return adjust_color_intensity(base_color, degree, max_degree)
        return base_color

    if degree == 0:
        return "white"  # Hollow
    norm_degree = 0 if max_degree == 0 else degree / max_degree
    # Grey level between 1.0 (light) and 0.3 (dark); darker for higher degree.
    return str(1.0 - 0.7 * norm_degree)


def draw_graph_diff(nodes, diff_edges, title, ax, degrees, max_degree,
                    added_nodes=None, removed_nodes=None):
    """Draw the nodes and the added/removed edges on a given axis.

    Nodes are laid out in rows by layer (input at the top, ``logits`` below
    the deepest attention/MLP layer) and centred horizontally. Added items
    are drawn in blue (#104670), removed items in red (#E47159); unchanged
    nodes are white (degree 0) or a shade of grey that darkens with degree.

    Args:
        nodes: Iterable of node names; duplicates are ignored.
        diff_edges: Mapping of ``"src->dst"`` edge names to ``"Added"`` or
            ``"Removed"``. Entries with any other status are not drawn.
        title: Currently unused (the title call is disabled below).
        ax: Matplotlib axis to draw on.
        degrees: Mapping of node name to degree; missing nodes count as 0.
        max_degree: Degree used to normalise colour intensity.
        added_nodes: Optional collection of nodes to colour as added.
        removed_nodes: Optional collection of nodes to colour as removed.
            When either collection is omitted, the module-level variable of
            the same name is used if it exists (the behaviour of the
            original implementation), otherwise no node gets that status.

    Raises:
        networkx.NetworkXError: If an edge refers to a node not in ``nodes``.
    """
    # Backward compatibility: callers that predate the explicit parameters
    # publish these collections as module-level globals.
    if added_nodes is None:
        added_nodes = globals().get("added_nodes") or ()
    if removed_nodes is None:
        removed_nodes = globals().get("removed_nodes") or ()

    # One fixed node order is shared by the layout, the colour lists and the
    # draw call so that colours cannot be attached to the wrong node.
    unique_nodes = list(dict.fromkeys(nodes))

    # Build the edge list and its colour list in lockstep. nx.draw would
    # otherwise iterate G.edges() in adjacency order, which does not match
    # an "all added, then all removed" colour list.
    edge_list = []
    edge_colors = []
    for wanted_status, color in (("Added", _ADDED_COLOR), ("Removed", _REMOVED_COLOR)):
        for edge, status in diff_edges.items():
            if status != wanted_status:
                continue
            endpoints = edge.split("->")
            edge_list.append((endpoints[0], endpoints[1]))
            edge_colors.append(color)

    known_nodes = set(unique_nodes)
    unknown = sorted(
        {endpoint for pair in edge_list for endpoint in pair} - known_nodes
    )
    if unknown:
        raise nx.NetworkXError(
            f"Edge endpoints without a node/position: {unknown}"
        )

    G = nx.DiGraph()
    G.add_nodes_from(unique_nodes)
    G.add_edges_from(edge_list)

    # Place logits one row below the deepest numbered layer, counting MLP
    # nodes as well as attention heads (row 0 if there are none).
    numbered_layers = [
        _layer_index(node, 0)
        for node in unique_nodes
        if node != "logits" and node.startswith(("a", "m"))
    ]
    logits_layer = max(numbered_layers, default=-1) + 1

    # Separate nodes by layers
    layer_nodes = {}
    for node in unique_nodes:
        layer_nodes.setdefault(_layer_index(node, logits_layer), []).append(node)

    # Generate positions: attention heads first, MLP last, centred per row
    pos = {}
    for layer, members in layer_nodes.items():
        ordered = sort_nodes_within_layer(members)
        for i, node in enumerate(ordered):
            pos[node] = (-len(ordered) / 2 + i, -layer)

    node_colors = [
        _node_fill_color(
            node, degrees.get(node, 0), max_degree, added_nodes, removed_nodes
        )
        for node in unique_nodes
    ]
    # Black borders keep white/light nodes visible
    node_border_colors = ["black"] * len(unique_nodes)

    # Draw the graph with explicit node and edge orders
    nx.draw(
        G, pos, nodelist=unique_nodes, edgelist=edge_list,
        with_labels=False, node_size=300, font_size=9,  # Adjust node_size as needed
        node_color=node_colors, edge_color=edge_colors, arrowsize=10, alpha=0.9,
        edgecolors=node_border_colors, linewidths=1, ax=ax
    )
    # ax.set_title(title, fontsize=33, fontweight='bold')

def compute_differences(nodes_base, nodes_final, edges_base, edges_final):
    """Compare a base graph with a final graph.

    Returns a tuple ``(added_nodes, removed_nodes, diff_edges)`` where the
    first two are sets of node names present only in the final / only in the
    base graph, and ``diff_edges`` maps each edge name present in exactly one
    of the two edge mappings to ``"Added"`` or ``"Removed"``.
    """
    base_node_set = set(nodes_base)
    final_node_set = set(nodes_final)
    base_edge_names = set(edges_base.keys())
    final_edge_names = set(edges_final.keys())

    # Added edges are inserted before removed ones.
    diff_edges = {}
    for name in final_edge_names.difference(base_edge_names):
        diff_edges[name] = "Added"
    for name in base_edge_names.difference(final_edge_names):
        diff_edges[name] = "Removed"

    return (
        final_node_set.difference(base_node_set),
        base_node_set.difference(final_node_set),
        diff_edges,
    )

# ---------------------------------------------------------------------------
# Script section: compare the arithmetic-task graph at checkpoint 0 with the
# graph at checkpoint 10 and save a figure of the differences.
# ---------------------------------------------------------------------------

# Load JSON data for arithmetic task
# NOTE: both paths are absolute and rooted at "/"; adjust them to the local
# location of the graph result files before running.
base_path_arithmetic = "/2_arithmetic_operations_100/graph_results_100/rl_graph_results/graph_simplemath_1.4b_checkpoint_0.json"
final_path_arithmetic = "/2_arithmetic_operations_100/graph_results_100/rl_graph_results/graph_simplemath_1.4b_checkpoint_10.json"

nodes_base_arithmetic, edges_base_arithmetic = load_graph_from_json(base_path_arithmetic)
nodes_final_arithmetic, edges_final_arithmetic = load_graph_from_json(final_path_arithmetic)

# Compute dynamic thresholds for arithmetic task.
# Each checkpoint gets its own threshold: the 95th percentile of the absolute
# edge scores of that checkpoint.
threshold_base_arithmetic = compute_dynamic_threshold(edges_base_arithmetic, percentile=95)
threshold_final_arithmetic = compute_dynamic_threshold(edges_final_arithmetic, percentile=95)

# Filter edges by task-specific thresholds
filtered_edges_arithmetic = filter_edges_by_threshold(edges_base_arithmetic, threshold_base_arithmetic)
filtered_edges_final_arithmetic = filter_edges_by_threshold(edges_final_arithmetic, threshold_final_arithmetic)

# Compute differences for arithmetic task.
# Node differences use the unfiltered node lists; edge differences use only
# the edges that survived the threshold filter.
added_nodes_arithmetic, removed_nodes_arithmetic, diff_edges_arithmetic = compute_differences(
    nodes_base_arithmetic, nodes_final_arithmetic, filtered_edges_arithmetic, 
    filtered_edges_final_arithmetic
)

# Combine nodes for final graph (union of both checkpoints, duplicates removed)
all_nodes_arithmetic = list(set(nodes_base_arithmetic + nodes_final_arithmetic))

# Create a graph to compute degrees.
# Only the filtered edges of the final checkpoint contribute to the degrees.
G_final = nx.DiGraph()
G_final.add_nodes_from(all_nodes_arithmetic)
G_final.add_edges_from([(edge.split("->")[0], edge.split("->")[1]) for edge in filtered_edges_final_arithmetic.keys()])

# Compute degrees (in-degree + out-degree per node) and their maximum
degrees = dict(G_final.degree())
max_degree = max(degrees.values()) if degrees else 0

# Identify added and removed nodes for coloring.
# These two module-level names must exist before the draw_graph_diff call
# below: they are not passed as arguments, the function looks them up as
# module-level variables.
added_nodes = added_nodes_arithmetic
removed_nodes = removed_nodes_arithmetic

fig, ax = plt.subplots(figsize=(8, 10), dpi=300)  

# Plot the graph
draw_graph_diff(
    all_nodes_arithmetic, 
    diff_edges_arithmetic,  
    "Differences in Add/Sub Circuit before and after fine-tuning", 
    ax=ax,
    degrees=degrees,
    max_degree=max_degree
)

# Legend handles for the node and edge colours used in the figure.
# NOTE: the ax.legend call below is commented out, so these handles are
# currently built but not shown.
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='w', label='Added node',
               markerfacecolor="#104670", markeredgecolor="black", markersize=10),
    plt.Line2D([0], [0], marker='o', color='w', label='Removed node',
               markerfacecolor="#E47159", markeredgecolor="black", markersize=10),
    plt.Line2D([0], [0], marker='o', color='w', label='Unchanged node',
               markerfacecolor="white", markeredgecolor="black", markersize=10),
    # Edges are represented by plain coloured lines (Line2D), not arrow patches
    plt.Line2D([0], [0], color="#104670", lw=2, label='Added edge'),
    plt.Line2D([0], [0], color="#E47159", lw=2, label='Removed edge')
]
#ax.legend(handles=legend_elements, loc='lower right', fontsize=20, prop={'weight': 'bold'})


# Save the figure as PDF and PNG in the current working directory, then show
# it. Only the PNG is saved with a tight bounding box.
plt.tight_layout()
plt.savefig("graph_differences_arithmetic.pdf")
plt.savefig("graph_differences_arithmetic.png", bbox_inches='tight', format='png')
plt.show()
