# In [None]:
import os
import json
import re
import matplotlib.pyplot as plt
import seaborn as sns

# 1. Load all checkpoint files for a single task
def load_checkpoints_in_order(task_dir):
    """Load every ``checkpoint_<N>.json`` file in *task_dir*, ordered by N.

    Only file names that end in ``.json`` and contain the pattern
    ``checkpoint_<N>.json`` are considered; everything else in the directory
    is ignored. Ordering is numeric rather than lexicographic, so
    ``checkpoint_10.json`` is loaded after ``checkpoint_2.json``.

    Args:
        task_dir: Path of the directory holding the checkpoint files.

    Returns:
        A list with the decoded JSON content of each checkpoint file, in
        ascending checkpoint-number order. Empty if no file matches.

    Raises:
        OSError: If the directory cannot be listed or a file cannot be read.
        json.JSONDecodeError: If a checkpoint file does not contain valid JSON.
    """
    # Collect (checkpoint_number, filename) pairs for the matching files.
    checkpoint_files = []
    for filename in os.listdir(task_dir):
        if not filename.endswith('.json'):
            continue
        match = re.search(r'checkpoint_(\d+)\.json', filename)
        if match:
            checkpoint_files.append((int(match.group(1)), filename))

    # Sort on the number alone; the sort is stable, so files sharing a number
    # keep the order in which the directory listing returned them.
    checkpoint_files.sort(key=lambda item: item[0])

    checkpoints = []
    for _, filename in checkpoint_files:
        checkpoint_path = os.path.join(task_dir, filename)
        # JSON text is UTF-8; stating it explicitly keeps decoding independent
        # of the platform's default locale encoding.
        with open(checkpoint_path, 'r', encoding='utf-8') as f:
            checkpoints.append(json.load(f))

    return checkpoints

# 2. Count the number of nodes in each checkpoint (number of True values)
def count_nodes_per_checkpoint(checkpoints):
    """Return, for each checkpoint, how many of its nodes have a truthy flag.

    Args:
        checkpoints: Sequence of mappings, each with a ``'nodes'`` entry that
            maps node names to flags.

    Returns:
        A list of integers, one per checkpoint and in the same order, giving
        the number of nodes whose flag is truthy.
    """
    return [
        sum(1 for is_active in snapshot['nodes'].values() if is_active)
        for snapshot in checkpoints
    ]

# 3. Calculate the number of node changes (total additions and removals)
def calculate_node_changes(checkpoints):
    """Count how many nodes flip state between each pair of adjacent checkpoints.

    A node counts as added when its flag is truthy in the later checkpoint
    but falsy or absent in the earlier one, and as removed in the opposite
    case. The value reported for a transition is additions plus removals.

    Args:
        checkpoints: Sequence of mappings, each with a ``'nodes'`` entry that
            maps node names to flags.

    Returns:
        A list with one integer per consecutive pair of checkpoints; empty
        when fewer than two checkpoints are given.
    """
    def newly_active(before, after):
        # Nodes that are on in `after` while off or missing in `before`.
        return sum(
            1 for name in after
            if after[name] and not before.get(name, False)
        )

    deltas = []
    for earlier, later in zip(checkpoints, checkpoints[1:]):
        old_nodes = earlier['nodes']
        new_nodes = later['nodes']
        gained = newly_active(old_nodes, new_nodes)
        # A removal is an "addition" seen in the reverse direction.
        lost = newly_active(new_nodes, old_nodes)
        deltas.append(gained + lost)

    return deltas

# 4. Visualize node count changes and node change numbers across multiple tasks
def visualize_multiple_tasks_node_changes(task_dirs, task_labels, output_dir):
    """Plot the node changes per checkpoint transition for several tasks.

    For every (directory, label) pair the checkpoints are loaded, the number
    of node changes between consecutive checkpoints is computed and printed,
    and one line per task is drawn on a shared chart. The chart is saved as
    ``multiple_tasks_node_changes.pdf`` inside *output_dir* and then shown.

    Tasks with fewer than two checkpoints are reported and left out of the
    plot. If no task can be plotted, a message is printed and no figure is
    created.

    Args:
        task_dirs: Directories containing the checkpoint files, one per task.
        task_labels: Legend labels, matched to *task_dirs* by position. Extra
            entries in the longer of the two sequences are ignored.
        output_dir: Directory in which the PDF is written; created if missing.
    """
    sns.set_theme(style="whitegrid")

    # Each series is stored together with its own label so that skipping a
    # task cannot shift the labels of the tasks that follow it.
    plotted_series = []

    for task_dir, label in zip(task_dirs, task_labels):
        checkpoints = load_checkpoints_in_order(task_dir)

        # Skip this task if there are not enough checkpoint files
        if len(checkpoints) < 2:
            print(f"Insufficient checkpoint files in task path {task_dir}, skipping.")
            continue

        # Print the number of node entries in checkpoint 0 (regardless of state)
        num_nodes_ckpt0 = len(checkpoints[0]['nodes'])
        print(f"Task: {label}, Checkpoint 0 Node Count: {num_nodes_ckpt0}")

        node_changes = calculate_node_changes(checkpoints)
        plotted_series.append((label, node_changes))

        # Print node change data for each task
        print(f"Task: {label}")
        print(f"Node Changes: {node_changes}")

    if not plotted_series:
        # Nothing to draw; avoid indexing into an empty result list below.
        print("No task has enough checkpoints to plot; no figure was produced.")
        return

    # Plot line chart - node changes for each task
    plt.figure(figsize=(12, 8))
    for label, node_changes in plotted_series:
        transitions = range(1, len(node_changes) + 1)
        plt.plot(transitions, node_changes, marker='o', linestyle='-', label=label)

    plt.xlabel('Checkpoint Transition', fontsize=14)
    plt.ylabel('Number of Node Changes', fontsize=14)
    plt.title('Node Changes Between Checkpoints Across Tasks', fontsize=16)
    plt.legend(fontsize=12)

    # Size the tick labels for the longest series so that every plotted point
    # has a label even when tasks have different numbers of checkpoints.
    max_transitions = max(len(node_changes) for _, node_changes in plotted_series)
    plt.xticks(
        range(1, max_transitions + 1),
        [f'Ckpt {i}->Ckpt {i+1}' for i in range(max_transitions)],
        fontsize=10,
        rotation=45
    )
    plt.tight_layout()

    # Make sure the target directory exists before saving.
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, 'multiple_tasks_node_changes.pdf'))
    plt.show()

# 5. Main function
def main():
    """Configure the task directories and labels, then draw the comparison plot.

    Creates the ``./visualizations`` output directory if it does not exist and
    passes the configured directories and labels to
    ``visualize_multiple_tasks_node_changes``.
    """
    # Set the list of task directories (one plotted line per active entry).
    # Alternative experiment sets are kept commented out. Every entry ends
    # with a comma so that enabling or disabling lines can never make Python
    # silently concatenate two adjacent string literals into one path.
    task_dirs = [
        # '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_r32_circuit',
        # '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_gpt_r32_circuit',
        # '/2_arithmetic_operations_100/graph_results_100/lora_graph_results/100_opt_r32_circuit',
        # '/Mul_Div_operations/graph_results_Mul_Div_/mul_div_r32_circuit',
        '/sequence_reasoning_task/graph_results_sequence/sequence_r32_circuit',
        '/sequence_reasoning_task/graph_results_sequence/sft_results',
        # '/sequence_reasoning_task/graph_results_sequence/adalora_graph_results',
        # '/sequence_reasoning_task/graph_results_sequence/ia3_graph_results',
        # '/LCM/graph_results_lcm/lcm_r32_circuit',
        # '/LCM/graph_results_lcm/adalora_graph_results',
        # '/LCM/graph_results_lcm/ia3_graph_results',
        # '/function_eval/graph_results_function/function_r32_circuit',
        # '/2_arithmetic_operations_200/graph_results_200/200_r32_circuit',
        # '/2_arithmetic_operations_200/graph_results_200/adalora_graph_results',
        # '/2_arithmetic_operations_200/graph_results_200/ia3_graph_results',
        # '/2_arithmetic_operations_300/graph_results_300/300_r32_circuit',
        # '/2_arithmetic_operations_400/graph_results_400/400_r32_circuit',
        # '/2_arithmetic_operations_500/graph_results_500/500_r32_circuit',
        # '/Mul_Div_operations/graph_results_Mul_Div_/mul_div_r32_circuit',
        # '/Mul_Div_operations/graph_results_Mul_Div_/adalora_graph_results',
        # '/Mul_Div_operations/graph_results_Mul_Div_/gpt_graph_results',
        # '/function_eval/graph_results_function/function_r32_circuit',
        # '/function_eval/graph_results_function/adalora_graph_results',
        # '/function_eval/graph_results_function/ia3_graph_results',
    ]
    # Labels are matched to task_dirs by position, so the two lists must be
    # kept in the same order and at the same length.
    # task_labels = ["Add_Sub(within 100)", "Multiplication and Division", "Arithmetic and Geometric Sequence", "Least Common Multiple", "Function Evaluation"]  # corresponding task names
    # task_labels = ["Add_Sub(within 200)", "Add_Sub(within 300)", "Add_Sub(within 400)", "Add_Sub(within 500)", "Multiplication and Division", "Function Evaluation"]  # corresponding task names
    # task_labels = ["pythia-1.4b-deduped", "gpt-neo-2.7B", "opt-6.7b"]
    # NOTE(review): the first active directory is '..._r32_circuit' and the
    # second is 'sft_results', while the labels read "Full_FT" then "LoRA".
    # Please confirm this order is intended and the labels are not swapped.
    task_labels = ["Full_FT (Sequence)", "LoRA (Sequence)"]
    output_dir = './visualizations'  # Directory to save images
    os.makedirs(output_dir, exist_ok=True)

    # Visualize node changes across multiple tasks
    visualize_multiple_tasks_node_changes(task_dirs, task_labels, output_dir)

# Run the visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
