In [1]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MaxNLocator

# --- Input/Output Paths ---
# Modify these paths to match your project structure.
INPUT_DIR = "full_compute"  # Directory scanned for the input *.json result files.
OUTPUT_DIR_BASE = "graphs"  # Base directory where output graphs will be saved.

# --- Plotting Customization ---
# Change these parameters to customize all generated graphs.
TOP_K_LAYERS = 3  # KL panel: number of leading layers taken from each 'top' category.
BOTTOM_K_LAYERS = 3  # KL panel: number of leading layers taken from every other category.
DELTA_BIAS_CHERRY_PICK_K = 3  # Delta-bias panel: layers per category with the largest side A/B gap.
BAR_COLORS = ('#0077bb', '#ee5233')  # (side A, side B) bar colors: dark blue and orange.

# --- Font sizes and tick density ---
BASE_FONT_SIZE = 20  # Base size; titles, axis labels and legends are offsets from it.
X_AXIS_TICK_FONT_SIZE = 20  # Font size for tick labels on the x-axis.
Y_AXIS_TICK_FONT_SIZE = 16  # Font size for tick labels on the y-axis.
# Passed to MaxNLocator(nbins=...), which caps the number of tick *intervals*,
# so up to Y_AXIS_MAX_TICKS + 1 ticks can appear on the y-axis.
Y_AXIS_MAX_TICKS = 5

In [2]:

# %%
# =============================================================================
# CELL 2: PLOTTING FUNCTION
# =============================================================================
# def plot_block_type_results_customizable(experiment_data, block_type, side_names, main_title, output_dir):
#     """
#     Generates and saves grouped bar charts for a specific block type (attn or mlp).

#     Args:
#         experiment_data (dict): The dictionary for one experiment.
#         block_type (str): The type of block to plot ('attn' or 'mlp').
#         side_names (tuple): Tuple of the two side names (e.g., ('cpp_top', 'python_top')).
#         main_title (str): The overall title for the figure.
#         output_dir (str): The directory where the plot image will be saved.
#     """
#     side_a_name, side_b_name = side_names

#     filtered_data = {
#         key: value for key, value in experiment_data.items() if block_type in key
#     }

#     if not filtered_data:
#         print(f"--> Warning: No data found for block type '{block_type}'. Skipping plot.")
#         return

#     fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
#     fig.suptitle(main_title, fontsize=BASE_FONT_SIZE + 6, weight='bold')

#     axes = {'KL': ax1, 'shift_other': ax2}
    
#     y_limits = {'KL': [float('inf'), float('-inf')], 'delta_bias': [float('inf'), float('-inf')]}

#     all_labels = []
#     num_top_layers = 0

#     # First pass: gather all data and determine y-axis limits
#     for category_name, layers_data in sorted(filtered_data.items(), reverse=True):
#         k = TOP_K_LAYERS if 'top' in category_name else BOTTOM_K_LAYERS
        
#         if 'top' in category_name:
#             layers_to_process = dict(list(layers_data.items())[:k])
#         else:  # 'bottom' category
#             layers_to_process = dict(list(layers_data.items())[-k:])  # Slice from the end for bottom layers
        
#         if 'top' in category_name:
#             num_top_layers = len(layers_to_process)
        
#         all_labels.extend(layers_to_process.keys())

#         for metric in ["KL", "delta_bias"]:
#             means_a = [layers_to_process[layer][side_a_name][metric]['mean'] for layer in layers_to_process]
#             means_b = [layers_to_process[layer][side_b_name][metric]['mean'] for layer in layers_to_process]
            
#             min_val = min(min(means_a), min(means_b))
#             max_val = max(max(means_a), max(means_b))

#             if min_val < y_limits[metric][0]: y_limits[metric][0] = min_val
#             if max_val > y_limits[metric][1]: y_limits[metric][1] = max_val

#     # Second pass: plot the data
#     for metric, ax in axes.items():
#         if metric in ["KL", "delta_bias"]:
#             all_means_a, all_stds_a = [], []
#             all_means_b, all_stds_b = [], []

#             for category_name, layers_data in sorted(filtered_data.items(), reverse=True):
#                 k = TOP_K_LAYERS if 'top' in category_name else BOTTOM_K_LAYERS
                
#                 if 'top' in category_name:
#                     layers_to_process = dict(list(layers_data.items())[:k])
#                 else:  # 'bottom' category
#                     layers_to_process = dict(list(layers_data.items())[-k:])  # Slice from the end for bottom layers
                
#                 layer_names = list(layers_to_process.keys())
                
#                 all_means_a.extend([layers_to_process[layer][side_a_name][metric]['mean'] for layer in layer_names])
#                 all_stds_a.extend([layers_to_process[layer][side_a_name][metric]['std'] for layer in layer_names])
#                 all_means_b.extend([layers_to_process[layer][side_b_name][metric]['mean'] for layer in layer_names])
#                 all_stds_b.extend([layers_to_process[layer][side_b_name][metric]['std'] for layer in layer_names])

#             x = np.arange(len(all_labels))
#             width = 0.35

#             ax.bar(x - width/2, all_means_a, width, label=side_a_name.upper().replace('_TOP', ''), yerr=all_stds_a, capsize=5, color=BAR_COLORS[0], alpha=0.9)
#             ax.bar(x + width/2, all_means_b, width, label=side_b_name.upper().replace('_TOP', ''), yerr=all_stds_b, capsize=5, color=BAR_COLORS[1], alpha=0.9)

#             ax.set_ylabel(metric, fontsize=BASE_FONT_SIZE + 2)
#             if metric == "shift_other": metric = "Shift"
#             ax.set_title(f'Comparison of {metric} Metric', fontsize=BASE_FONT_SIZE + 4)
#             ax.set_xticks(x)
#             ax.set_xticklabels(all_labels, rotation=45, ha="right")
#             ax.legend(fontsize=BASE_FONT_SIZE)
#             ax.grid(axis='y', linestyle='--', alpha=0.7)
            
#             # CHANGE: Use tunable font sizes for x and y tick labels
#             ax.tick_params(axis='x', which='major', labelsize=X_AXIS_TICK_FONT_SIZE)
#             ax.tick_params(axis='y', which='major', labelsize=Y_AXIS_TICK_FONT_SIZE)
            
#             # CHANGE: Set a maximum number of ticks on the y-axis to prevent clutter
#             ax.yaxis.set_major_locator(MaxNLocator(integer=False, nbins=Y_AXIS_MAX_TICKS))
            
#             if len(filtered_data) > 1 and num_top_layers > 0:
#                 separator_pos = num_top_layers - 0.5
#                 ax.axvline(separator_pos, color='grey', linestyle='--')

#     try:
#         os.makedirs(output_dir, exist_ok=True)
#         filename = f"{block_type}_{side_names[0]}_{side_names[1]}.pdf"
#         full_path = os.path.join(output_dir, filename)
#         plt.tight_layout(rect=[0, 0.03, 1, 0.95])
#         plt.savefig(full_path, bbox_inches="tight")
#         print(f"--> Successfully saved plot to {full_path}")
#     except Exception as e:
#         print(f"--> Error saving plot: {e}")
#     finally:
#         plt.close(fig)

# %%
# =============================================================================
# CELL 2: PLOTTING FUNCTION
# =============================================================================
def _select_kl_layers(cat_name, layers):
    """Return the (layer_name, layer_data) pairs shown in the KL panel for one category."""
    # 'top' categories use TOP_K_LAYERS, every other category BOTTOM_K_LAYERS; the
    # first k entries are taken in JSON order.
    # NOTE(review): an earlier version took the *last* k entries of 'bottom'
    # categories; this assumes the JSON already lists bottom layers in the desired
    # display order — confirm against whatever produces the JSON.
    k = TOP_K_LAYERS if 'top' in cat_name else BOTTOM_K_LAYERS
    return list(layers.items())[:k]


def _select_delta_bias_layers(layers, side_a_name, side_b_name):
    """Return the DELTA_BIAS_CHERRY_PICK_K (layer_name, layer_data) pairs with the
    largest absolute gap between the two sides' mean delta_bias, largest first."""
    def mean_gap(item):
        layer_data = item[1]
        return abs(layer_data[side_a_name]['delta_bias']['mean']
                   - layer_data[side_b_name]['delta_bias']['mean'])

    # sorted() is stable (also with reverse=True), so ties keep their JSON order.
    return sorted(layers.items(), key=mean_gap, reverse=True)[:DELTA_BIAS_CHERRY_PICK_K]


def _build_series(ordered_categories, select_layers, metric, side_a_name, side_b_name):
    """Flatten the selected layers of every category into parallel bar-chart lists.

    Returns a dict with 'labels', 'means_a', 'stds_a', 'means_b', 'stds_b' for
    `metric`, plus 'split': the number of bars contributed by non-'bottom'
    categories, used to place the dashed separator between the two groups.
    """
    series = {'labels': [], 'means_a': [], 'stds_a': [], 'means_b': [], 'stds_b': [], 'split': 0}
    for cat_name, layers in ordered_categories:
        for layer_name, layer_data in select_layers(cat_name, layers):
            series['labels'].append(layer_name)
            series['means_a'].append(layer_data[side_a_name][metric]['mean'])
            series['stds_a'].append(layer_data[side_a_name][metric]['std'])
            series['means_b'].append(layer_data[side_b_name][metric]['mean'])
            series['stds_b'].append(layer_data[side_b_name][metric]['std'])
        if 'bottom' not in cat_name:
            # Non-'bottom' categories come first, so this ends as the leading group's size.
            series['split'] = len(series['labels'])
    return series


def _draw_grouped_bars(ax, series, label_a, label_b, title, ylabel, width=0.35):
    """Draw paired mean bars with std error bars on `ax` and apply the shared formatting."""
    x = np.arange(len(series['labels']))
    ax.bar(x - width / 2, series['means_a'], width, label=label_a,
           yerr=series['stds_a'], capsize=5, color=BAR_COLORS[0])
    ax.bar(x + width / 2, series['means_b'], width, label=label_b,
           yerr=series['stds_b'], capsize=5, color=BAR_COLORS[1])
    ax.set_title(title, fontsize=BASE_FONT_SIZE + 4)
    ax.set_ylabel(ylabel, fontsize=BASE_FONT_SIZE + 2)
    ax.set_xticks(x)
    ax.set_xticklabels(series['labels'], rotation=45, ha="right")
    # Separate the leading (top) group from the bottom group, but only when both
    # groups actually contributed bars.
    if 0 < series['split'] < len(series['labels']):
        ax.axvline(series['split'] - 0.5, color='grey', linestyle='--')
    ax.legend(fontsize=BASE_FONT_SIZE)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    ax.tick_params(axis='x', labelsize=X_AXIS_TICK_FONT_SIZE)
    ax.tick_params(axis='y', labelsize=Y_AXIS_TICK_FONT_SIZE)
    ax.yaxis.set_major_locator(MaxNLocator(nbins=Y_AXIS_MAX_TICKS))


def plot_block_type_results_customizable(experiment_data, block_type, side_names, main_title, output_dir):
    """
    Generate and save a side-by-side KL / delta-bias bar chart for one block type.

    The left panel shows the 'KL' metric for the first TOP_K_LAYERS layers of each
    'top' category and the first BOTTOM_K_LAYERS layers of every other category,
    in JSON order. The right panel cherry-picks, per category, the
    DELTA_BIAS_CHERRY_PICK_K layers whose mean 'delta_bias' differs most between
    the two sides.

    Args:
        experiment_data (dict): Mapping category name -> {layer name -> layer data},
            where values are read as layer_data[side][metric]['mean' | 'std'] for
            metric 'KL' and 'delta_bias'.
        block_type (str): Substring selecting which categories to plot ('attn', 'mlp').
        side_names (tuple): The two side keys to compare, e.g. ('cpp_top', 'python_top').
        main_title (str): Overall figure title.
        output_dir (str): Directory for '<block_type>_comparison.pdf' (created if missing).

    Prints a warning and returns if no category matches `block_type`. Errors while
    saving are printed rather than raised; the figure is always closed.
    """
    side_a_name, side_b_name = side_names

    # Substring match: every category whose name contains block_type.
    filtered_data = {k: v for k, v in experiment_data.items() if block_type in k}
    if not filtered_data:
        print(f"--> Warning: No data found for '{block_type}'. Skipping.")
        return

    # Non-'bottom' categories first; sorted() is stable, so JSON order is kept within each group.
    ordered_categories = sorted(filtered_data.items(), key=lambda item: 'bottom' in item[0])

    # Extract every value before creating the figure, so a malformed entry (missing
    # side, metric or statistic) raises without leaving an open figure behind.
    kl_series = _build_series(ordered_categories, _select_kl_layers, 'KL', side_a_name, side_b_name)
    bias_series = _build_series(
        ordered_categories,
        lambda _cat_name, layers: _select_delta_bias_layers(layers, side_a_name, side_b_name),
        'delta_bias',
        side_a_name,
        side_b_name,
    )

    # Legend labels, e.g. 'cpp_top' -> 'CPP'.
    label_a = side_a_name.upper().replace('_TOP', '')
    label_b = side_b_name.upper().replace('_TOP', '')

    fig, (ax_kl, ax_bias) = plt.subplots(1, 2, figsize=(20, 8))
    try:
        fig.suptitle(main_title, fontsize=BASE_FONT_SIZE + 6, weight='bold')
        _draw_grouped_bars(ax_kl, kl_series, label_a, label_b,
                           title='Comparison of KL Divergence', ylabel='KL Divergence')
        _draw_grouped_bars(ax_bias, bias_series, label_a, label_b,
                           title=f'Top {DELTA_BIAS_CHERRY_PICK_K} Layers by Delta Bias Difference',
                           ylabel='Delta Bias')

        try:
            os.makedirs(output_dir, exist_ok=True)
            full_path = os.path.join(output_dir, f"{block_type}_comparison.pdf")
            # Restrict the subplot layout to leave room at the top for the suptitle.
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            fig.savefig(full_path, bbox_inches="tight")
            print(f"--> Successfully saved plot to {full_path}")
        except Exception as e:
            # Best-effort save: report and let the caller continue with other plots.
            print(f"--> Error saving plot: {e}")
    finally:
        # Always release the figure so open figures do not accumulate across calls.
        plt.close(fig)

# %%
# =============================================================================
# CELL 3: MAIN EXECUTION LOGIC
# =============================================================================
def main(input_dir, output_dir_base):
    """
    Find the ``*.json`` result files in ``input_dir`` and plot each one.

    Files are processed in sorted filename order so runs (and their logs) are
    reproducible. Each file must hold a non-empty JSON object; the value under its
    first key is the experiment dict passed to
    ``plot_block_type_results_customizable`` once for 'mlp' and once for 'attn',
    with figures written under ``<output_dir_base>/<model_name>/graphs``.
    Unreadable or malformed files are reported and skipped.

    Args:
        input_dir (str): Directory containing the JSON result files.
        output_dir_base (str): Base directory for the generated figures.
    """
    print(f"Scanning for JSON files in: {input_dir}")
    if not os.path.isdir(input_dir):
        print(f"Error: Input directory '{input_dir}' not found.")
        return

    # os.listdir() returns entries in arbitrary order; sort for deterministic output.
    for filename in sorted(os.listdir(input_dir)):
        if not filename.endswith('.json'):
            continue

        json_path = os.path.join(input_dir, filename)
        print(f"\nProcessing file: {filename}")

        model_name = os.path.splitext(filename)[0]
        output_dir_final = os.path.join(output_dir_base, model_name, 'graphs')

        try:
            # JSON text is UTF-8; don't depend on the platform's default encoding.
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except json.JSONDecodeError:
            print(f"--> Error: Could not decode JSON from {filename}. Skipping.")
            continue
        except Exception as e:
            print(f"--> Error reading file {filename}: {e}. Skipping.")
            continue

        # Check the top-level type before indexing: a JSON array or scalar would
        # otherwise raise TypeError and abort the remaining files.
        experiment_data = None
        if isinstance(data, dict) and data:
            # NOTE: only the first top-level experiment is plotted; any others are ignored.
            experiment_data = data[next(iter(data))]
        if not isinstance(experiment_data, dict):
            print(f"--> Error: JSON file {filename} has an unexpected format. Skipping.")
            continue

        for block in ['mlp', 'attn']:
            plot_block_type_results_customizable(
                experiment_data=experiment_data,
                block_type=block,
                side_names=('cpp_top', 'python_top'),
                main_title=f'{model_name.upper()} - {block.upper()} Layer Analysis',
                output_dir=output_dir_final
            )

    print("\nScript finished.")

# --- Entry point: generate all figures using the paths from the configuration cell ---
main(input_dir=INPUT_DIR, output_dir_base=OUTPUT_DIR_BASE)



Scanning for JSON files in: full_compute

Processing file: llama_1b.json
--> Successfully saved plot to graphs/llama_1b/graphs/mlp_comparison.pdf
--> Successfully saved plot to graphs/llama_1b/graphs/attn_comparison.pdf

Processing file: gemma_4b.json
--> Successfully saved plot to graphs/gemma_4b/graphs/mlp_comparison.pdf
--> Successfully saved plot to graphs/gemma_4b/graphs/attn_comparison.pdf

Processing file: llama_3b.json
--> Successfully saved plot to graphs/llama_3b/graphs/mlp_comparison.pdf
--> Successfully saved plot to graphs/llama_3b/graphs/attn_comparison.pdf

Processing file: gemma_1b.json
--> Successfully saved plot to graphs/gemma_1b/graphs/mlp_comparison.pdf
--> Successfully saved plot to graphs/gemma_1b/graphs/attn_comparison.pdf

Script finished.
