# Set configs


In [9]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

In [11]:
# load all the predefined functions
from plot_util import SurprisalLoader

In [12]:
# Project root. Defaults to the original checkout location; set the RAG_ROOT
# environment variable to run the notebook against a different location
# without editing this cell.
ROOT = Path(os.environ.get("RAG_ROOT", "/Users/jliu/workspace/RAG/"))
fig_path = ROOT / "fig"  # default output directory for saved figures
surprisal_path = ROOT / "results" / "surprisal"  # holds the cached stat_all.csv
KL_path = ROOT / "results" / "token_freq"
freq_path = ROOT / "datasets/freq/EleutherAI/pythia-410m"

## Plot surprisal dynamics

In [13]:
# Load the surprisal statistics: read the cached CSV when it exists, otherwise
# build the statistics from the raw result files.
stat_path = surprisal_path / "stat_all.csv"

if stat_path.is_file():
    stat_frame = pd.read_csv(stat_path)
else:
    analyzer = SurprisalLoader(surprisal_path)
    # Process all files and get statistics
    stat_frame = analyzer.get_stat_all()

# The following cells refer to the table as `stats`, so bind that name on BOTH
# paths. Previously a cache hit left `stats` pointing at the `scipy.stats`
# module imported at the top of the notebook, which broke `plot_all(stats)`.
# NOTE(review): this name shadows the `scipy.stats` import from here on.
stats = stat_frame


Stat file has been saved to /Users/jliu/workspace/RAG/results/surprisal/stat_all.csv


In [24]:
# Inspect the surprisal statistics table loaded above.
stats

Unnamed: 0,log_step,surprisal,vec,neuron,model,ablation,eval,effect
0,3.6021,13.398346,base,0,70m,base,merged,boost
1,3.6990,13.159312,base,0,70m,base,merged,boost
2,3.7782,13.193204,base,0,70m,base,merged,boost
3,3.8451,13.146212,base,0,70m,base,merged,boost
4,3.9031,13.261228,base,0,70m,base,merged,boost
...,...,...,...,...,...,...,...,...
22675,5.1430,12.948298,longtail,50,410m,full,longtail_words,supress
22676,5.1461,13.080685,longtail,50,410m,full,longtail_words,supress
22677,5.1492,12.693981,longtail,50,410m,full,longtail_words,supress
22678,5.1523,12.719519,longtail,50,410m,full,longtail_words,supress


In [23]:
# Draw and save the surprisal figures for every combination found in `stats`.
# NOTE(review): `plot_all` is defined in a later cell of this notebook, so on a
# fresh kernel that cell has to be executed before this one.
plot_all(stats)

Processing eval=surprisal, vec=mean, model=70m, ablation=mean
Processing eval=surprisal, vec=mean, model=70m, ablation=zero
Processing eval=surprisal, vec=mean, model=70m, ablation=random
Processing eval=surprisal, vec=mean, model=70m, ablation=scaled
Processing eval=surprisal, vec=mean, model=70m, ablation=full
Processing eval=surprisal, vec=mean, model=410m, ablation=mean
Processing eval=surprisal, vec=mean, model=410m, ablation=zero
Processing eval=surprisal, vec=mean, model=410m, ablation=random
Processing eval=surprisal, vec=mean, model=410m, ablation=scaled
Processing eval=surprisal, vec=mean, model=410m, ablation=full
Processing eval=surprisal, vec=longtail, model=70m, ablation=mean
Processing eval=surprisal, vec=longtail, model=70m, ablation=zero
Processing eval=surprisal, vec=longtail, model=70m, ablation=random
Processing eval=surprisal, vec=longtail, model=70m, ablation=scaled
Processing eval=surprisal, vec=longtail, model=70m, ablation=full
Processing eval=surprisal, vec=lo

In [22]:
# plot and save the results
# Line colour for each value of the `neuron` column, as looked up by plot_all.
# Key 0 is the colour plot_all uses for the baseline curve; neuron values that
# are missing from this mapping are drawn in gray there. Key 500 is not part of
# plot_all's default `neurons` list.
neuron_colors = {
    0: "#1f77b4",  # blue for baseline
    1: "#4589b9",  # interpolated between 0 and 10
    2: "#6b9bbe",  # interpolated between 0 and 10
    5: "#a5bec6",  # interpolated between 0 and 10
    10: "#ff7f0e",  # orange (unchanged)
    25: "#c89f1d",  # interpolated between 10 and 50
    50: "#2ca02c",  # green (unchanged)
    500: "#9467bd",  # purple (unchanged)
}


def _surprisal_curve(frame):
    """Return ``(log_steps, surprisals)`` for ``frame``, ordered by ``log_step``.

    Rows whose ``log_step`` is missing are dropped, and when a step occurs more
    than once only its first row is used, so both arrays always have the same
    length.
    """
    curve = (
        frame.dropna(subset=["log_step"])
        .drop_duplicates(subset="log_step", keep="first")
        .sort_values("log_step")
    )
    return curve["log_step"].to_numpy(), curve["surprisal"].to_numpy()


def plot_all(
    df,
    output_dir: "str | Path" = fig_path,
    eval_set: str = "surprisal",
    neurons: list[int] = None,
    neuron_colors: dict = neuron_colors,
    ylim_dict: dict = None,
) -> None:
    """Plot surprisal against log step and save one PNG per condition.

    One figure is written for every (effect, vec, model, ablation) combination,
    where ``vec`` is "mean" or "longtail" and ``ablation`` is any value of the
    ``ablation`` column other than "base". Each figure shows the baseline curve
    plus one curve per requested neuron value that has data, and is saved to
    ``output_dir / effect / f"{vec}_{model}_{ablation}.png"``.

    Args:
        df: Table with the columns ``model``, ``effect``, ``vec``, ``ablation``,
            ``eval``, ``neuron``, ``log_step`` and ``surprisal``.
        output_dir: Root directory for the figures; a ``str`` or a ``Path``.
        eval_set: Only rows whose ``eval`` column equals this value are plotted.
        neurons: Values of the ``neuron`` column to draw. Defaults to
            ``[1, 2, 5, 10, 25, 50]``.
        neuron_colors: Mapping from neuron value to line colour. Key 0 colours
            the baseline (black if absent); other missing keys fall back to gray.
        ylim_dict: Optional ``{eval_set: {model: (ymin, ymax)}}`` mapping of
            y-axis limits.
    """
    if neurons is None:
        neurons = [1, 2, 5, 10, 25, 50]
    # Accept plain strings as well as Path objects; `/` is used on this below.
    output_dir = Path(output_dir)

    # Get unique values for grouping
    models = df["model"].unique()
    effect_lst = df["effect"].unique()
    vec_lst = ["mean", "longtail"]
    ablations = [a for a in df["ablation"].unique() if a != "base"]  # Non-baseline ablations

    in_eval_set = df["eval"] == eval_set

    for effect in effect_lst:
        effect_dir = output_dir / effect
        in_effect = in_eval_set & (df["effect"] == effect)

        for vec in vec_lst:
            for model in models:
                # Both selections are independent of the ablation, so they are
                # computed once per (effect, vec, model) instead of per figure.
                in_model = in_effect & (df["model"] == model)
                model_data = df[in_model & (df["vec"] == vec)]
                # The baseline is deliberately not filtered by `vec`.
                base_x, base_y = _surprisal_curve(df[in_model & (df["ablation"] == "base")])

                for ablation in ablations:
                    print(f"Processing eval={eval_set}, vec={vec}, model={model}, ablation={ablation}")

                    fig, ax = plt.subplots(figsize=(10, 8))
                    try:
                        # The baseline is always the first artist added, so it
                        # is also the first legend entry without any reordering.
                        ax.plot(base_x, base_y, color=neuron_colors.get(0, "black"), linewidth=2, label="base")

                        ablation_data = model_data[model_data["ablation"] == ablation]
                        for neuron in neurons:
                            condition_data = ablation_data[ablation_data["neuron"] == neuron]
                            if condition_data.empty:
                                continue

                            step_values, surprisal_values = _surprisal_curve(condition_data)
                            ax.plot(
                                step_values,
                                surprisal_values,
                                color=neuron_colors.get(neuron, "gray"),
                                linewidth=2,
                                label=str(neuron),
                            )

                        # Style the plot
                        ax.set_xlabel("Log step", fontsize=11)
                        ax.set_ylabel("Surprisal", fontsize=11)
                        ax.set_title(f"model={model}, intervention={ablation}", fontsize=13)
                        ax.grid(alpha=0.2)
                        ax.legend(loc="lower left")

                        # Set y-axis limits if provided
                        if ylim_dict and eval_set in ylim_dict and model in ylim_dict[eval_set]:
                            ax.set_ylim(ylim_dict[eval_set][model])

                        fig.tight_layout()
                        # Create output directory if it doesn't exist
                        effect_dir.mkdir(parents=True, exist_ok=True)
                        fig.savefig(effect_dir / f"{vec}_{model}_{ablation}.png", dpi=300, bbox_inches="tight")
                    finally:
                        # Release the figure even when plotting or saving fails,
                        # so a long run cannot accumulate open figures.
                        plt.close(fig)