# In [None]:
import json
import os

def load_circuit(json_path):
    """Load a circuit description from a JSON file.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The decoded JSON object. ``main`` uses it as a dict with ``"cfg"``,
        ``"nodes"`` and ``"edges"`` entries.
    """
    # Explicit UTF-8 so decoding does not depend on the platform's locale.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def save_circuit(json_data, output_path):
    """Write ``json_data`` to ``output_path`` as JSON indented by 4 spaces.

    Missing parent directories of ``output_path`` are created first, so a
    fresh output location such as ``./merge_circuit/...`` works without any
    manual setup.

    Args:
        json_data: JSON-serializable object to write.
        output_path: Destination file path; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(output_path)
    if parent_dir:  # empty when output_path is a bare filename
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=4)

def merge_nodes(nodes_a, nodes_b, mode='union'):
    """Combine two node-membership dicts key by key.

    Every key present in either input appears in the result; a key missing
    from one side counts as ``False`` for that side.

    Args:
        nodes_a: Mapping from node name to a membership value (used as a bool).
        nodes_b: Same shape as ``nodes_a``.
        mode: ``'union'`` stores ``val_a or val_b``; ``'intersection'`` stores
            ``val_a and val_b``.

    Returns:
        A new dict over the union of both key sets, ordered as the keys of
        ``nodes_a`` followed by the keys only in ``nodes_b`` (deterministic,
        unlike iterating a set of string keys, whose order varies per run).

    Raises:
        ValueError: If ``mode`` is not ``'union'`` or ``'intersection'``.
            Checked up front, so it is raised even for empty inputs.
    """
    if mode not in ('union', 'intersection'):
        raise ValueError("mode must be union or intersection.")
    use_union = mode == 'union'

    merged_nodes = {}
    # dict.fromkeys de-duplicates while preserving first-seen order.
    for k in dict.fromkeys([*nodes_a, *nodes_b]):
        val_a = nodes_a.get(k, False)
        val_b = nodes_b.get(k, False)
        merged_nodes[k] = (val_a or val_b) if use_union else (val_a and val_b)
    return merged_nodes

def merge_edges(edges_a, edges_b, mode='union', score_fuse='avg'):
    """Combine two edge dicts, fusing the scores of edges present in both.

    Args:
        edges_a: Mapping from edge key to a dict with at least a ``"score"``
            entry (read for edges shared with ``edges_b``).
        edges_b: Same shape as ``edges_a``.
        mode: ``'union'`` keeps edges from either input; ``'intersection'``
            keeps only edges present in both.
        score_fuse: How to combine the two scores of a shared edge:
            ``'avg'`` (arithmetic mean), ``'max'``, ``'min'``, or ``'first'``
            (take the score from ``edges_a``).

    Returns:
        A new dict. Shared edges map to ``{"score": <fused>, "in_graph": True}``;
        in union mode, edges found in only one input are shallow copies of
        that input's entry. Keys follow ``edges_a`` order, then the keys only
        in ``edges_b``.

    Raises:
        ValueError: If ``mode`` or ``score_fuse`` is not one of the values
            above (previously an unknown ``score_fuse`` silently behaved like
            ``'first'``).
    """
    fusers = {
        'avg': lambda a, b: 0.5 * (a + b),
        'max': max,
        'min': min,
        'first': lambda a, b: a,
    }
    if mode not in ('union', 'intersection'):
        raise ValueError("mode must be union or intersection.")
    try:
        fuse = fusers[score_fuse]
    except KeyError:
        raise ValueError(f"score_fuse must be one of {sorted(fusers)}.") from None
    keep_one_sided = mode == 'union'

    merged_edges = {}
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # result (and the saved JSON) is reproducible across runs.
    for key in dict.fromkeys([*edges_a, *edges_b]):
        in_a = key in edges_a
        in_b = key in edges_b
        if in_a and in_b:
            # Fuse the scores of an edge present in both circuits.
            merged_edges[key] = {
                "score": fuse(edges_a[key]["score"], edges_b[key]["score"]),
                # NOTE(review): shared edges are always marked in_graph=True,
                # regardless of their flags in either input — confirm intended.
                "in_graph": True,
            }
        elif keep_one_sided:
            # Edge present in exactly one input; copy so callers cannot
            # mutate the source entry through the result.
            source = edges_a if in_a else edges_b
            merged_edges[key] = source[key].copy()
    return merged_edges

def prune_edges_by_score(edges_dict, target_num_edges):
    """Keep the ``target_num_edges`` edges with the largest absolute score.

    Args:
        edges_dict: Mapping from edge key to a dict with a numeric ``"score"``.
        target_num_edges: Maximum number of edges to keep; must be >= 0.

    Returns:
        A new dict (never the input object itself). If the input already has
        at most ``target_num_edges`` entries it is a shallow copy in the
        original order; otherwise the kept edges are ordered by descending
        ``abs(score)``, with ties keeping their input order (``sorted`` is
        stable, including with ``reverse=True``).

    Raises:
        ValueError: If ``target_num_edges`` is negative. A negative slice
            bound would otherwise silently drop edges from the end.
    """
    if target_num_edges < 0:
        raise ValueError("target_num_edges must be non-negative.")
    if len(edges_dict) <= target_num_edges:
        return dict(edges_dict)

    ranked = sorted(
        edges_dict.items(),
        key=lambda item: abs(item[1]["score"]),
        reverse=True,
    )
    return dict(ranked[:target_num_edges])

def prune_isolated_nodes(nodes_dict, edges_dict, always_keep=("input", "logits")):
    """Drop nodes that are not an endpoint of any edge in ``edges_dict``.

    Every key of ``edges_dict`` counts as a connection; edge values (e.g. an
    ``in_graph`` flag) are not consulted.

    Args:
        nodes_dict: Mapping from node name to its value; surviving values are
            passed through unchanged.
        edges_dict: Mapping whose keys are ``"src->tgt"`` strings.
        always_keep: Node names retained even when no edge touches them.
            Defaults to ``("input", "logits")``.

    Returns:
        A new dict with the surviving nodes, in their original order.

    Raises:
        ValueError: If an edge key does not split into exactly two parts on
            ``"->"``.
    """
    connected_nodes = set(always_keep)
    for edge_key in edges_dict:
        src, tgt = edge_key.split("->")
        connected_nodes.add(src)
        connected_nodes.add(tgt)
        # NOTE(review): assumes EAP-IG-style keys where edges into an
        # attention head carry a "<q>"/"<k>"/"<v>" suffix on the target
        # (e.g. "input->a0.h1<q>"). Also count the bare node name so such a
        # head is not dropped while edges still point at it — confirm format.
        if tgt.endswith(">") and "<" in tgt:
            connected_nodes.add(tgt[:tgt.rindex("<")])

    return {
        node_name: val
        for node_name, val in nodes_dict.items()
        if node_name in connected_nodes
    }

def main(
    add_sub_path="/add_sub_mul_div/graph_results_steps/graph_add_sub_1.4b_r32_epoch4_initial.json",
    mul_div_path="/add_sub_mul_div/graph_results_steps/graph_mul_div_1.4b_r32_epoch4_initial.json",
    output_merged_path="./merge_circuit/intersection_merged_add_sub_mul_div_after_initial.json",
    node_mode='union',
    edge_mode='union',
    score_fuse='avg',
    target_num_edges=11479,
):
    """Merge the add/sub and mul/div circuits and save the result.

    Steps:
    1. Load both circuit JSON files.
    2. Merge their nodes (``node_mode``) and edges (``edge_mode``, with
       scores of shared edges fused by ``score_fuse``).
    3. Keep at most ``target_num_edges`` edges by absolute score; ``None``
       disables this cap.
    4. Drop nodes that no remaining edge touches.
    5. Write the merged circuit to ``output_merged_path``.

    The ``"cfg"`` entry is taken from the add/sub circuit. Every parameter
    defaults to the value this script was originally run with.
    """
    # NOTE(review): the default output file name says "intersection_merged",
    # but node_mode/edge_mode default to 'union' — confirm which is intended.
    add_sub_circuit = load_circuit(add_sub_path)
    mul_div_circuit = load_circuit(mul_div_path)

    merged_nodes = merge_nodes(
        add_sub_circuit["nodes"], mul_div_circuit["nodes"], mode=node_mode
    )
    merged_edges = merge_edges(
        add_sub_circuit["edges"],
        mul_div_circuit["edges"],
        mode=edge_mode,
        score_fuse=score_fuse,
    )

    if target_num_edges is not None and len(merged_edges) > target_num_edges:
        merged_edges = prune_edges_by_score(merged_edges, target_num_edges)

    # Pruning edges can leave nodes with no remaining connection; drop them.
    merged_nodes = prune_isolated_nodes(merged_nodes, merged_edges)

    merged_circuit = {
        "cfg": add_sub_circuit["cfg"].copy(),
        "nodes": merged_nodes,
        "edges": merged_edges,
    }
    save_circuit(merged_circuit, output_merged_path)

    print(f"Merged circuit saved to: {output_merged_path}")
    print(f"Final #Nodes: {sum(1 for v in merged_nodes.values() if v)}, #Edges: {len(merged_edges)}")

# Entry point: run the merge with its default settings when this file/cell
# is executed directly.
if __name__ == "__main__":
    main()


# Output:
# Merged circuit saved to: ./merge_circuit/intersection_merged_add_sub_mul_div_after_initial.json
# Final #Nodes: 303, #Edges: 11479
