In [None]:
import os
import json
import re
import matplotlib.pyplot as plt
import seaborn as sns

def load_checkpoints_in_order(task_dir):
    """Load every numbered checkpoint JSON found in ``task_dir``.

    A file is considered a checkpoint when its name ends with ``.json`` and
    contains ``checkpoint_<digits>.json``. The parsed contents are returned
    as a list ordered by that number, ascending.

    Args:
        task_dir: Directory to scan (not recursive).

    Returns:
        list: One parsed JSON object per checkpoint file.
    """
    pattern = re.compile(r'checkpoint_(\d+)\.json')

    # Pair each matching file name with its integer checkpoint number.
    numbered = []
    for entry in os.listdir(task_dir):
        if not entry.endswith('.json'):
            continue
        found = pattern.search(entry)
        if found is None:
            continue
        numbered.append((int(found.group(1)), entry))

    # Numeric ordering, so checkpoint 10 comes after checkpoint 2.
    ordered = sorted(numbered, key=lambda pair: pair[0])

    loaded = []
    for _, entry in ordered:
        with open(os.path.join(task_dir, entry), 'r') as handle:
            loaded.append(json.load(handle))

    return loaded

def calculate_in_graph_changes_per_checkpoint(checkpoints):
    """Count edges whose ``in_graph`` flag differs between consecutive checkpoints.

    Each checkpoint must be a mapping with an ``'edges'`` entry that maps an
    edge name to a dict carrying an ``'in_graph'`` value. An edge that is
    absent from a checkpoint (or lacks the ``'in_graph'`` key) is treated as
    not in the graph, on both sides of the comparison.

    Args:
        checkpoints: Ordered list of checkpoint dicts.

    Returns:
        list[int]: ``len(checkpoints) - 1`` counts, one per transition
        ``i-1 -> i``; empty when fewer than two checkpoints are given.
    """
    def _in_graph(edges, edge):
        return edges.get(edge, {}).get('in_graph', False)

    in_graph_changes = []

    for i in range(1, len(checkpoints)):
        prev_edges = checkpoints[i - 1]['edges']
        curr_edges = checkpoints[i]['edges']

        # Walk the union of both key sets: looking only at the current
        # checkpoint would miss edges that were in the graph before and have
        # disappeared from the current checkpoint altogether.
        all_edges = set(prev_edges) | set(curr_edges)
        change_count = sum(
            1 for edge in all_edges
            if _in_graph(prev_edges, edge) != _in_graph(curr_edges, edge)
        )

        in_graph_changes.append(change_count)

    return in_graph_changes

def visualize_multiple_tasks_in_graph_changes(task_changes, task_labels, output_file):
    """Plot per-transition edge-change counts for several tasks on one figure.

    One line is drawn per ``(changes, label)`` pair; the figure is written to
    ``output_file`` as a PDF and then shown.

    Args:
        task_changes: List of sequences; each sequence holds the change count
            for every checkpoint transition of one task.
        task_labels: Legend labels, paired positionally with ``task_changes``.
        output_file: Destination passed to ``plt.savefig``.

    Raises:
        ValueError: If ``task_changes`` is empty (there is nothing to plot).
    """
    if not task_changes:
        raise ValueError("task_changes is empty; nothing to plot.")

    sns.set_theme(style="whitegrid")
    color_palette = sns.color_palette("dark", len(task_changes))

    plt.figure(figsize=(12, 8))

    for changes, label, color in zip(task_changes, task_labels, color_palette):
        transitions = range(1, len(changes) + 1)
        plt.plot(transitions, changes, marker='o', linestyle='-', label=label, color=color)

    plt.xlabel('Checkpoint Transition', fontsize=14)
    plt.ylabel('Number of Edge In-Graph State Changes', fontsize=14)
    plt.title('Edge Changes Between Checkpoints Across Tasks', fontsize=16)
    plt.legend(fontsize=12)

    # Size the x axis from the longest series, not from the first one, so
    # every plotted point gets a tick even when tasks have different numbers
    # of checkpoints.
    max_transitions = max(len(changes) for changes in task_changes)
    tick_positions = range(1, max_transitions + 1)
    tick_labels = [f'Ckpt {i}->Ckpt {i+1}' for i in range(max_transitions)]
    plt.xticks(tick_positions, tick_labels, fontsize=10, rotation=45)
    plt.yticks(fontsize=10)
    plt.tight_layout()

    # Save before show(): some backends clear the figure once it is shown.
    plt.savefig(output_file, format='pdf')
    plt.show()


def main():
    """Compute and plot edge in-graph changes across checkpoints for each task.

    For every task directory the checkpoints are loaded in numeric order and
    the per-transition change counts are computed. Tasks with fewer than two
    checkpoints are skipped, and their labels are dropped with them so the
    legend stays aligned with the plotted curves.
    """
    task_dirs = [
        '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_r32_circuit',
        '/2_arithmetic_operations_200/graph_results_200/200_r32_circuit',
        '/2_arithmetic_operations_300/graph_results_300/300_r32_circuit',
        '/2_arithmetic_operations_400/graph_results_400/400_r32_circuit',
        '/2_arithmetic_operations_500/graph_results_500/500_r32_circuit',
        '/Mul_Div_operations/graph_results_Mul_Div_/mul_div_r32_circuit',
        '/sequence_reasoning_task/graph_results_sequence/sequence_r32_circuit',
        '/LCM/graph_results_lcm/lcm_r32_circuit',
        '/function_eval/graph_results_function/function_r32_circuit',
    ]
    task_labels = [
        "Add_Sub(within 100)",
        "Add_Sub(within 200)",
        "Add_Sub(within 300)",
        "Add_Sub(within 400)",
        "Add_Sub(within 500)",
        "Multiplication and Division",
        "Arithmetic and Geometric Sequence",
        "Least Common Multiple",
        "Function Evaluation",
    ]

    # Directories and labels are paired by position; a length mismatch would
    # silently mislabel or drop tasks.
    if len(task_dirs) != len(task_labels):
        raise ValueError(
            f"Got {len(task_dirs)} task directories but {len(task_labels)} labels."
        )

    task_changes = []
    plotted_labels = []  # labels of the tasks that were actually kept

    for task_dir, label in zip(task_dirs, task_labels):
        checkpoints = load_checkpoints_in_order(task_dir)

        if len(checkpoints) < 2:
            print(f"Skipping task {task_dir} due to insufficient checkpoints.")
            continue

        in_graph_changes = calculate_in_graph_changes_per_checkpoint(checkpoints)
        task_changes.append(in_graph_changes)
        plotted_labels.append(label)

    if not task_changes:
        print("No task has at least two checkpoints; nothing to plot.")
        return

    output_file = '/Circuit_Analysis/pictures/edge_change.pdf'
    visualize_multiple_tasks_in_graph_changes(task_changes, plotted_labels, output_file)
    print(f"Visualization saved to {output_file}")

if __name__ == '__main__':
    main()


In [None]:
import os
import json
import re
import matplotlib.pyplot as plt
import seaborn as sns


def load_checkpoints_in_order(task_dir):
    """Return the parsed checkpoint JSON files of ``task_dir``, sorted by number.

    Only names that end in ``.json`` and contain ``checkpoint_<digits>.json``
    are read; the digits are compared as integers to establish the order.

    Args:
        task_dir: Directory holding the checkpoint files.

    Returns:
        list: Parsed JSON contents, lowest checkpoint number first.
    """
    def _read_json(name):
        with open(os.path.join(task_dir, name), 'r') as fh:
            return json.load(fh)

    # Lazily pair each candidate file name with its regex match (or None).
    candidates = (
        (re.search(r'checkpoint_(\d+)\.json', name), name)
        for name in os.listdir(task_dir)
        if name.endswith('.json')
    )
    indexed = [(int(hit.group(1)), name) for hit, name in candidates if hit]

    # Sort on the number alone; ties keep their directory-listing order.
    indexed.sort(key=lambda item: item[0])

    return [_read_json(name) for _, name in indexed]

def calculate_in_graph_changes_per_checkpoint(checkpoints):
    """Count edges whose ``in_graph`` flag differs between consecutive checkpoints.

    Each checkpoint must be a mapping with an ``'edges'`` entry that maps an
    edge name to a dict carrying an ``'in_graph'`` value. An edge that is
    absent from a checkpoint (or lacks the ``'in_graph'`` key) is treated as
    not in the graph, on both sides of the comparison.

    Args:
        checkpoints: Ordered list of checkpoint dicts.

    Returns:
        list[int]: ``len(checkpoints) - 1`` counts, one per transition
        ``i-1 -> i``; empty when fewer than two checkpoints are given.
    """
    def _in_graph(edges, edge):
        return edges.get(edge, {}).get('in_graph', False)

    in_graph_changes = []

    for i in range(1, len(checkpoints)):
        prev_edges = checkpoints[i - 1]['edges']
        curr_edges = checkpoints[i]['edges']

        # Walk the union of both key sets: looking only at the current
        # checkpoint would miss edges that were in the graph before and have
        # disappeared from the current checkpoint altogether.
        all_edges = set(prev_edges) | set(curr_edges)
        change_count = sum(
            1 for edge in all_edges
            if _in_graph(prev_edges, edge) != _in_graph(curr_edges, edge)
        )

        in_graph_changes.append(change_count)

    return in_graph_changes

def visualize_multiple_tasks_in_graph_changes(task_changes, task_labels, output_file):
    """Plot per-transition edge-change counts for several tasks on one figure.

    One line is drawn per ``(changes, label)`` pair; the figure is written to
    ``output_file`` as a PDF and then shown.

    Args:
        task_changes: List of sequences; each sequence holds the change count
            for every checkpoint transition of one task.
        task_labels: Legend labels, paired positionally with ``task_changes``.
        output_file: Destination passed to ``plt.savefig``.

    Raises:
        ValueError: If ``task_changes`` is empty (there is nothing to plot).
    """
    if not task_changes:
        raise ValueError("task_changes is empty; nothing to plot.")

    sns.set_theme(style="whitegrid")
    color_palette = sns.color_palette("dark", len(task_changes))

    plt.figure(figsize=(12, 8))

    for changes, label, color in zip(task_changes, task_labels, color_palette):
        transitions = range(1, len(changes) + 1)
        plt.plot(transitions, changes, marker='o', linestyle='-', label=label, color=color)

    plt.xlabel('Checkpoint Transition', fontsize=14)
    plt.ylabel('Number of Edge In-Graph State Changes', fontsize=14)
    plt.title('Edge Changes Between Checkpoints Across Tasks', fontsize=16)
    plt.legend(fontsize=12)

    # Size the x axis from the longest series, not from the first one, so
    # every plotted point gets a tick even when tasks have different numbers
    # of checkpoints.
    max_transitions = max(len(changes) for changes in task_changes)
    tick_positions = range(1, max_transitions + 1)
    tick_labels = [f'Ckpt {i}->Ckpt {i+1}' for i in range(max_transitions)]
    plt.xticks(tick_positions, tick_labels, fontsize=10, rotation=45)
    plt.yticks(fontsize=10)
    plt.tight_layout()

    # Save before show(): some backends clear the figure once it is shown.
    plt.savefig(output_file, format='pdf')
    plt.show()


def main():
    """Compare edge in-graph changes across checkpoints for the selected runs.

    The active entries of ``task_dirs`` are paired by position with
    ``task_labels``. Runs with fewer than two checkpoints are skipped, and
    their labels are dropped with them so the legend stays aligned with the
    plotted curves.
    """
    # Alternative run sets are kept commented out; every entry ends with a
    # comma so that uncommenting a line never triggers implicit string
    # concatenation with its neighbour.
    task_dirs = [
        # '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_r32_circuit',
        # '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_gpt_r32_circuit',
        # '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_opt_r32_circuit',
        # '/2_arithmetic_operations_200/graph_results_200/200_r32_circuit',
        # '/2_arithmetic_operations_200/graph_results_200/adalora_graph_results',
        # '/2_arithmetic_operations_200/graph_results_200/ia3_graph_results',
        # '/2_arithmetic_operations_300/graph_results_300/300_r32_circuit',
        # '/2_arithmetic_operations_400/graph_results_400/400_r32_circuit',
        # '/2_arithmetic_operations_500/graph_results_500/500_r32_circuit',
        # '/Mul_Div_operations/graph_results_Mul_Div_/mul_div_r32_circuit',
        # '/Mul_Div_operations/graph_results_Mul_Div_/adalora_graph_results',
        # '/Mul_Div_operations/graph_results_Mul_Div_/gpt_graph_results',
        '/sequence_reasoning_task/graph_results_sequence/sequence_r32_circuit',
        '/sequence_reasoning_task/graph_results_sequence/sft_results',
        # '/LCM/graph_results_lcm/lcm_r32_circuit',
        # '/LCM/graph_results_lcm/adalora_graph_results',
        # '/LCM/graph_results_lcm/ia3_graph_results',
        # '/function_eval/graph_results_function/function_r32_circuit',
        # '/function_eval/graph_results_function/adalora_graph_results',
        # '/function_eval/graph_results_function/ia3_graph_results',
        # '/sequence_reasoning_task/graph_results_sequence/sequence_r32_circuit',
        # '/sequence_reasoning_task/graph_results_sequence/adalora_graph_results',
        # '/sequence_reasoning_task/graph_results_sequence/ia3_graph_results',
    ]
    #task_labels = ["Add_Sub(within 200)","Add_Sub(within 300)","Add_Sub(within 400)","Add_Sub(within 500)", "Multiplication and Division","Function Evaluation"]  
    # task_labels = ["Add_Sub_200(LoRA)","Add_Sub_200(AdaLoRA)","Add_Sub_200(IA3)"] 
    # task_labels = ["Mul_Div(LoRA)","Mul_Div(AdaLoRA)","Mul_Div(IA3)"] 
    # task_labels = ["Sequence(LoRA)","Sequence(AdaLoRA)","Sequence(IA3)"] 
    # task_labels = ["LCM(LoRA)","LCM(AdaLoRA)","LCM(IA3)"] 
    # task_labels = ["Function(LoRA)","Function(AdaLoRA)","Function(IA3)"] 
    # task_labels = ["pythia-1.4b-deduped","gpt-neo-2.7B","opt-6.7b"] 
    # NOTE(review): labels are matched to task_dirs by position. The first
    # active directory is `sequence_r32_circuit`, which another cell of this
    # notebook reads `*_lora_checkpoint_*.json` files from, yet it is paired
    # with "Full_FT (Sequence)" here, while `sft_results` gets
    # "LoRA (Sequence)". The order looks swapped -- confirm before publishing.
    task_labels = ["Full_FT (Sequence)","LoRA (Sequence)"] 

    # A length mismatch would make zip() silently drop runs or labels.
    if len(task_dirs) != len(task_labels):
        raise ValueError(
            f"Got {len(task_dirs)} task directories but {len(task_labels)} labels."
        )

    task_changes = []
    plotted_labels = []  # labels of the runs that were actually kept

    for task_dir, label in zip(task_dirs, task_labels):
        checkpoints = load_checkpoints_in_order(task_dir)

        if len(checkpoints) < 2:
            print(f"Skipping task {task_dir} due to insufficient checkpoints.")
            continue

        num_edges_ckpt0 = len(checkpoints[0]['edges'])
        print(f"Task: {label}, Checkpoint 0 Edge Count: {num_edges_ckpt0}")

        in_graph_changes = calculate_in_graph_changes_per_checkpoint(checkpoints)
        task_changes.append(in_graph_changes)
        plotted_labels.append(label)

        print(f"Task: {label}")
        print(f"Changes: {in_graph_changes}")

    if not task_changes:
        print("No task has at least two checkpoints; nothing to plot.")
        return

    output_file = '/Circuit_Analysis/pictures/edge_change.pdf'
    visualize_multiple_tasks_in_graph_changes(task_changes, plotted_labels, output_file)
    print(f"Visualization saved to {output_file}")

if __name__ == '__main__':
    main()


In [None]:
import json

def load_graph_from_json(file_path):
    """Read a graph JSON file and keep only its active nodes and in-graph edges.

    Args:
        file_path: Path to a JSON file with a ``"nodes"`` mapping
            (name -> truthy/falsy flag) and an ``"edges"`` mapping
            (name -> dict containing ``"in_graph"``).

    Returns:
        tuple: ``(nodes, edges)`` where ``nodes`` is the list of node names
        whose flag is truthy and ``edges`` is a dict of the edge entries
        whose ``"in_graph"`` value is truthy.
    """
    with open(file_path, 'r') as handle:
        graph = json.load(handle)

    active_nodes = []
    for name, is_active in graph["nodes"].items():
        if is_active:
            active_nodes.append(name)

    kept_edges = {}
    for name, info in graph["edges"].items():
        if info["in_graph"]:
            kept_edges[name] = info

    return active_nodes, kept_edges

def compute_node_induced_edge_changes(nodes_base, nodes_final, edges_base, edges_final):
    """Diff two graphs and count the edge changes that touch a changed node.

    Args:
        nodes_base: Iterable of node names in the base graph.
        nodes_final: Iterable of node names in the final graph.
        edges_base: Mapping keyed by edge name (``"source->target"``) for the
            base graph; only the keys are used.
        edges_final: Same, for the final graph.

    Returns:
        dict: The sets ``added_nodes``, ``removed_nodes``, ``changed_nodes``,
        ``added_edges``, ``removed_edges``, ``added_edges_by_nodes`` and
        ``removed_edges_by_nodes``, plus the integer totals
        ``total_node_changes``, ``total_edge_changes`` and
        ``total_node_induced_edge_changes``.

    Raises:
        ValueError: If an added or removed edge name contains no ``"->"``.
    """
    base_node_set = set(nodes_base)
    final_node_set = set(nodes_final)
    added_nodes = final_node_set - base_node_set
    removed_nodes = base_node_set - final_node_set
    changed_nodes = added_nodes | removed_nodes

    added_edges = set(edges_final.keys()) - set(edges_base.keys())
    removed_edges = set(edges_base.keys()) - set(edges_final.keys())

    def is_incident(edge):
        # Split on the first "->" only, so an extra "->" later in the name
        # does not break the unpacking.
        source, sep, target = edge.partition("->")
        if not sep:
            raise ValueError(f"Edge name {edge!r} is not of the form 'source->target'")
        if source in changed_nodes or target in changed_nodes:
            return True
        # NOTE(review): edge keys presumably may carry a trailing "<...>" tag
        # on the target (e.g. "input->a0.h0<q>") while node names do not --
        # confirm against the graph exporter. Also try the target with that
        # tag removed; this can only add matches, never remove one.
        if target.endswith(">") and "<" in target:
            return target[:target.rindex("<")] in changed_nodes
        return False

    added_edges_by_nodes = {edge for edge in added_edges if is_incident(edge)}
    removed_edges_by_nodes = {edge for edge in removed_edges if is_incident(edge)}

    total_edge_changes = len(added_edges) + len(removed_edges)
    total_node_induced_edge_changes = len(added_edges_by_nodes) + len(removed_edges_by_nodes)

    return {
        "added_nodes": added_nodes,
        "removed_nodes": removed_nodes,
        "changed_nodes": changed_nodes,
        "added_edges": added_edges,
        "removed_edges": removed_edges,
        "added_edges_by_nodes": added_edges_by_nodes,
        "removed_edges_by_nodes": removed_edges_by_nodes,
        "total_node_changes": len(changed_nodes),
        "total_edge_changes": total_edge_changes,
        "total_node_induced_edge_changes": total_node_induced_edge_changes,
    }

def print_detailed_summary(task_name, stats):
    """Print a short report of node and edge change counts for one task.

    Args:
        task_name: Name shown in the report header.
        stats: Mapping providing ``total_node_changes``, ``added_nodes``,
            ``removed_nodes``, ``total_edge_changes`` and
            ``total_node_induced_edge_changes``.
    """
    report_lines = [
        f"===== {task_name} =====",
        f"Total node changes: {stats['total_node_changes']}",
        f"  Added nodes: {len(stats['added_nodes'])}",
        f"  Removed nodes: {len(stats['removed_nodes'])}",
        f"Total edge changes: {stats['total_edge_changes']}",
        f"Edge changes induced by node changes: {stats['total_node_induced_edge_changes']}",
        # Per-edge listings, disabled by default because they are verbose:
        # "Detailed edge change information:"
        # f"  Added edges (all): {sorted(stats['added_edges'])}"
        # f"  Removed edges (all): {sorted(stats['removed_edges'])}"
        # f"  Added edges (caused by node changes): {sorted(stats['added_edges_by_nodes'])}"
        # f"  Removed edges (caused by node changes): {sorted(stats['removed_edges_by_nodes'])}"
        "====================================\n",
    ]
    for line in report_lines:
        print(line)


if __name__ == "__main__":

    # One entry per task: the graph JSON at checkpoint 0 ("base_path") and
    # the graph JSON at a later checkpoint ("final_path") to diff against it.
    tasks = [
        dict(
            name="Add/Sub",
            base_path="/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_0.json",
            final_path="/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_250.json",
        ),
        dict(
            name="Mul/Div",
            base_path="/Mul_Div_operations/graph_results_Mul_Div_/mul_div_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_0.json",
            final_path="/Mul_Div_operations/graph_results_Mul_Div_/mul_div_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_100.json",
        ),
        dict(
            name="Sequence",
            base_path="/sequence_reasoning_task/graph_results_sequence/sequence_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_0.json",
            final_path="/sequence_reasoning_task/graph_results_sequence/sequence_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_500.json",
        ),
        dict(
            name="LCM",
            base_path="/LCM/graph_results_lcm/lcm_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_0.json",
            final_path="/LCM/graph_results_lcm/lcm_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_250.json",
        ),
        dict(
            name="Function",
            base_path="/function_eval/graph_results_function/function_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_0.json",
            final_path="/function_eval/graph_results_function/function_r32_circuit/graph_simplemath_1.4b_lora_checkpoint_250.json",
        ),
    ]

    # Load both graphs of each task, diff them, and print the summary.
    for task in tasks:
        print(f"Processing task: {task['name']}")
        nodes_base, edges_base = load_graph_from_json(task["base_path"])
        nodes_final, edges_final = load_graph_from_json(task["final_path"])
        stats = compute_node_induced_edge_changes(
            nodes_base,
            nodes_final,
            edges_base,
            edges_final,
        )
        print_detailed_summary(task["name"], stats)
