# In [None]:
import os  # for file path operations
import json  # for reading JSON files
import pandas as pd  # for DataFrame creation and manipulation
import matplotlib as mpl  # for Matplotlib configuration
import matplotlib.pyplot as plt  # for plotting

# ---------- Global plot style settings ----------
# Matplotlib defaults, set one rc group at a time: each ``mpl.rc(group, **kw)``
# call writes the ``"<group>.<key>"`` entries of ``mpl.rcParams``.

# Text: serif family (DejaVu Serif), 18 pt base size
mpl.rc("font", family="serif", serif=["DejaVu Serif"], size=18)
# Title / axis-label sizes and tick-label sizes on both axes
mpl.rc("axes", titlesize=20, labelsize=18)
mpl.rc(("xtick", "ytick"), labelsize=16)
# Default line width and marker size
mpl.rc("lines", linewidth=2, markersize=8)
# Frame: thicker axes lines, no top/right spines, dashed translucent grid
mpl.rc("axes", linewidth=1.2, grid=True)
mpl.rc("axes.spines", top=False, right=False)
mpl.rc("grid", linestyle="--", linewidth=0.6, alpha=0.6)
# Legend: framed, near-black edge
mpl.rc("legend", frameon=True, fontsize=14, title_fontsize=16, edgecolor="#0A0A0A")
# Colour cycle used for successive plotted lines / marker groups
mpl.rc("axes", prop_cycle=mpl.cycler("color", ["#0989D3", "#E6593A", "#009E73"]))

# ---------- Input file configurations ----------
# One config per model (one subplot row each). The JSON path and the legend
# label of the "Updated" state are both derived from the model tag; both
# models share the same axis limits.
# NOTE(review): the model_name values ("Qwen3-14B_Distill-...") do not match
# the s1.1 tags; in this script they only serve as dictionary keys — confirm.
files = [
    {
        "path": f"{tag}_pca_shift.json",
        "model_name": model_name,
        "state_map": {"Original": "Base", "Updated": tag},
        "xlim": (-300, 200),
        "ylim": (-500, 1400),
    }
    for tag, model_name in (
        ("s1.1-7B", "Qwen3-14B_Distill-NoThink"),
        ("s1.1-32B", "Qwen3-14B_Distill-Think"),
    )
]

# ---------- Constants ----------
# Scatter marker per state
markers = {"Original": "o", "Updated": "^"}
# Z-order per state, ranked 1, 2, ... in the key order of `markers`
# (so "Updated" points are drawn above "Original" ones)
zorder_order = {state: rank for rank, state in enumerate(markers, start=1)}
# Benchmarks plotted left-to-right, one per column
target_benchmarks = [
    f"{short_name}_PCA_Shift"
    for short_name in ("MATH500", "Head_QA", "Livecodebench", "COQA", "HalluEval")
]
# Axes background colour keyed by row index: 0 -> light blue, 1 -> light orange
row_colors = dict(enumerate(["#F3F8FF", "#FFF6EB"]))

# ---------- Utility functions ----------
def build_df(path: str) -> pd.DataFrame:
    """
    Load a PCA-shift JSON file into a flat DataFrame.

    The file must contain a list of benchmark blocks shaped like
    ``{"benchmark": <name>, "data": [<record dict>, ...]}``. Each record
    becomes one row, prefixed with a ``benchmark`` column holding the name
    of its block. A ``state`` value of ``"step1"`` is renamed to
    ``"Updated"``.

    Parameters
    ----------
    path : path of the JSON file to read; decoded as UTF-8.

    Returns
    -------
    DataFrame with a ``benchmark`` column followed by the record keys
    (the plotting code reads ``layer``, ``state``, ``shift`` and
    ``principle``). If the file holds no records, an empty DataFrame with
    the columns benchmark, layer, state, shift, principle is returned.

    Raises
    ------
    KeyError if a block lacks its ``benchmark`` or ``data`` key.
    """
    # JSON text is UTF-8; don't depend on the platform's default encoding
    with open(path, "r", encoding="utf-8") as f:
        content = json.load(f)
    # Flatten: one row per record, tagged with its benchmark name
    rows = [
        {"benchmark": block["benchmark"], **record}
        for block in content
        for record in block["data"]
    ]
    if not rows:
        # pd.DataFrame([]) has no columns, so df["state"] below would raise
        # KeyError; return an empty frame with the expected columns instead.
        return pd.DataFrame(columns=["benchmark", "layer", "state", "shift", "principle"])
    df = pd.DataFrame(rows)
    # Standardize state naming
    df["state"] = df["state"].replace({"step1": "Updated"})
    return df


def plot_subplot(ax, df, state_map, xlim, ylim):
    """
    Draw one model/benchmark panel on ``ax``.

    - A thin gray line joins the points that share a ``layer`` value.
    - Each state listed in the module-level ``markers`` dict is drawn as a
      scatter with its own marker, z-order (from ``zorder_order``) and a
      colour fixed by its position in ``markers`` ("C0", "C1", ... of the
      active colour cycle). A state therefore keeps the same colour in every
      panel, whatever the row order in ``df``. States not in ``markers``
      get no scatter points.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    df : DataFrame with ``layer``, ``state``, ``shift`` (x) and
        ``principle`` (y) columns, e.g. one benchmark's rows from
        ``build_df``.
    state_map : legend label per state; needs an entry for every state in
        ``markers`` that appears in ``df``.
    xlim, ylim : (lower, upper) axis limits.
    """
    # Gray connectors: one polyline per layer, points ordered by state name
    for _, layer_rows in df.groupby("layer", sort=False):
        layer_rows = layer_rows.sort_values("state")
        ax.plot(
            layer_rows["shift"], layer_rows["principle"],
            color="gray", linewidth=1, alpha=0.5, zorder=1
        )
    # Scatter points in the fixed order of `markers`, with explicit colours.
    # Drawing in data order with cycle-assigned colours let a state's colour
    # depend on which state appeared first in the data.
    present = set(df["state"])
    for idx, (state, marker) in enumerate(markers.items()):
        if state not in present:
            continue
        pts = df[df["state"] == state]
        ax.scatter(
            pts["shift"], pts["principle"],
            marker=marker,
            color=f"C{idx}",
            label=state_map[state],
            alpha=0.65,
            zorder=zorder_order[state]
        )
    # Apply axis limits and styling
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.tick_params(direction="out")

# ---------- Create combined figure ----------
# Map model names to their configuration dict (one subplot row per model)
# NOTE(review): configs with duplicate model_name values would collapse here.
model_cfgs = {cfg["model_name"]: cfg for cfg in files}
# Parse each model's JSON once; the grid loop below only filters by benchmark.
# (Loading inside the loop re-read every file once per benchmark column.)
model_data = {name: build_df(cfg["path"]) for name, cfg in model_cfgs.items()}
# One row per model, one column per benchmark (2 x 5 with the current config,
# 4 inches per panel). squeeze=False keeps `axs` 2-D for any grid shape.
n_rows, n_cols = len(model_cfgs), len(target_benchmarks)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
fig.subplots_adjust(wspace=0.35, hspace=0.4)
# Main title above the grid
fig.suptitle("PCA Shift across Models and Tasks", fontsize=24, y=1.05)

# Populate each subplot
for col, benchmark in enumerate(target_benchmarks):
    for row, (model_name, cfg) in enumerate(model_cfgs.items()):
        # Rows of this model's data for the current benchmark
        data = model_data[model_name]
        data = data[data["benchmark"] == benchmark]
        ax = axs[row, col]
        # Row-specific background colour (rows without an entry keep the default)
        if row in row_colors:
            ax.set_facecolor(row_colors[row])
        # Plot lines and markers
        plot_subplot(ax, data, cfg["state_map"], cfg["xlim"], cfg["ylim"])
        # Column title (benchmark name without suffix) on the first row
        if row == 0:
            ax.set_title(benchmark.replace("_PCA_Shift", ""), pad=6, fontsize=16)
        # X axis label on the bottom row
        if row == n_rows - 1:
            ax.set_xlabel("(PC1 Δ)")
        # Y axis label and legend on the first column of every row
        if col == 0:
            ax.set_ylabel("(PC2)")
            leg = ax.legend(loc="upper center", frameon=True, fancybox=True, fontsize=14)
            leg.get_frame().set_edgecolor("#0A0A0A")
            leg.get_frame().set_linewidth(1)

# ---------- Add row titles above each row ----------
# Bold title per subplot row, centred over the span from the left edge of the
# row's first axes to the right edge of its last axes, 0.04 figure units above.
row_titles = [
    "s1.1-7B; Base: Qwen-2.5-7B-Instruct",
    "s1.1-32B; Base: Qwen-2.5-32B-Instruct"
]
for row, row_title in enumerate(row_titles):
    first_box = axs[row, 0].get_position()
    last_box = axs[row, -1].get_position()
    fig.text(
        (first_box.x0 + last_box.x1) / 2,
        first_box.y1 + 0.04,
        row_title,
        ha="center",
        va="bottom",
        fontsize=18,
        fontweight="bold",
    )

# ---------- Add category annotations above columns ----------
# One category label per benchmark column, placed above the first-row axes:
# column 0 is "Math", columns 1-2 are "Other-Reasoning", and columns 3-4
# are "Non-Reasoning".
column_categories = [
    "Math",
    "Other-Reasoning",
    "Other-Reasoning",
    "Non-Reasoning",
    "Non-Reasoning",
]
for col, category in enumerate(column_categories):
    axs[0, col].annotate(
        category, xy=(0.5, 1.25), xycoords="axes fraction",
        fontsize=20, ha="center"
    )

# ---------- Save figure to PDF ----------
# bbox_inches="tight" fits the saved area to the drawn artists, including text
# placed above the subplot grid (e.g. the suptitle at y=1.05).
fig.savefig(
    "PCA_Shift_7B_32B.pdf",
    format="pdf", dpi=300, bbox_inches="tight"
)
# Close this specific figure (rather than pyplot's "current" one) to free memory
plt.close(fig)
