In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import numpy as np
import re
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# plt.rcParams.update({
#     "font.family": "sans-serif",
#     "font.sans-serif": ["Inter", "DejaVu Sans", "Arial", "Source Sans 3", "Roboto"],
# })

# legend_font = FontProperties(
#     family="DejaVu Sans",
#     style="normal",
#     weight="bold",
#     size=12,
# )

In [None]:
# === Load dataframes ===
# MOMENTUM picks which pair of exported CSV files is read (it is interpolated
# into both file names); TARGET_STEP is the value looked up in the "Step"
# column when the loss/norm rows are selected further down.
MOMENTUM = 1.0
TARGET_STEP = 40960

loss_csv = f"data/normalisation-Scion-momentum-{MOMENTUM}-loss.csv"
norm_csv = f"data/normalisation-Scion-momentum-{MOMENTUM}-norm.csv"

# Both exports use ';' as the field separator.
df_loss, df_norm = (
    pd.read_csv(csv_path, sep=";") for csv_path in (loss_csv, norm_csv)
)

# === Prepare legend mapping and order ===
# One (group key, legend label) pair per configuration, listed in the order in
# which the curves are drawn and shown in the legend.
# NOTE(review): the bracketed codes presumably mean A = attention projections,
# MLP = feed-forward block, R = residuals, O = output layer -- confirm.
_legend_entries = [
    ("everywhere", r"$\mathbf{Everywhere}$ [A: QKVO; MLP: +, R: +, O: +]"),
    ("everywhere-wo-output", r"$\mathbf{Everywhere \; w/o \; out}$ [A: QKVO; MLP: +, R: +, O: -]"),
    ("standard", r"$\mathbf{Residuals \; + \; out \; + \; QK}$ [A: QK, MLP: -, R: +, O: +]"),
    ("residual-and-out", r"$\mathbf{Residuals \; + \; out}$ [A: -, MLP: -, R: +, O: +]"),
    ("residual-only", r"$\mathbf{Residuals \; only}$ [A: -, MLP: -, R: +, O: -]"),
    ("qk-only", r"$\mathbf{QK \; only}$ [A: QK, MLP: -, R: -, O: -]"),
    ("output-only", r"$\mathbf{Out \; only}$ [A: -, MLP: -, R: -, O: +]"),
    ("none", r"$\mathbf{None}$ [A: -, MLP: -, R: -, O: -]"),
]

# group key -> mathtext legend label
rename = dict(_legend_entries)
# group keys in plotting / legend order
legend_order = [group_key for group_key, _ in _legend_entries]

# This subset will be used for the norms text box
box_groups = [
    "everywhere",
    "everywhere-wo-output",
    "residual-and-out",
    "residual-only",
]

In [None]:
# Select the loss and norm rows at the target step
step_col = "Step"


def _row_at_step(df, step, column=step_col):
    """Return the last row of ``df`` whose ``column`` equals ``step``.

    If ``column`` is missing from ``df`` or no row has that step value, the
    final row of ``df`` is returned instead and a warning is emitted, so a
    mismatch between the requested step and the exported data does not go
    unnoticed.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to pick the row from.
    step : scalar
        Value to look up in ``column``.
    column : str
        Name of the step column.

    Returns
    -------
    pandas.Series
        The selected row.

    Raises
    ------
    IndexError
        If ``df`` has no rows (raised by ``df.iloc[-1]``).
    """
    if column in df.columns:
        at_step = df.loc[df[column] == step]
        if not at_step.empty:
            # Several rows may share the step value; keep the last one.
            return at_step.iloc[-1]

    import warnings  # local import: only needed on the fallback path

    warnings.warn(
        f"No row with {column}={step}; falling back to the last row.",
        stacklevel=2,
    )
    return df.iloc[-1]


loss_row = _row_at_step(df_loss, TARGET_STEP)
norm_row = _row_at_step(df_norm, TARGET_STEP)

# === 1. From loss DF: collect runs per group ===
# Loss columns are those whose header ends with the metric path below; the
# text before the first " - " in the header is treated as the run id.
loss_suffix = "loss_metrics/global_avg_loss"
loss_cols = [c for c in df_loss.columns if str(c).endswith(loss_suffix)]

groups = {}  # group_name -> list of dicts (lr, loss, run_id, col)

# Capture exactly one float literal between "-lr-" and the next "-".
# A permissive class such as [0-9.eE+-]+ also contains "-", so it can run past
# the LR token (e.g. "-lr-0.001-2-" would capture "0.001-2") and make float()
# raise. Spelling out the float grammar keeps well-formed names (0.01, 1e-05,
# 2.5E+3, ...) parsing exactly as before while stopping at the token boundary.
lr_pattern = re.compile(
    r"-lr-([+-]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][+-]?[0-9]+)?)-"
)

for col in loss_cols:
    col_str = str(col)
    run_id = col_str.split(" - ", 1)[0]
    # Group: prefix before "Distributed"
    if "Distributed" in run_id:
        group_name = run_id.split("Distributed", 1)[0].rstrip("- ").strip()
    else:
        group_name = "unknown"
    m = lr_pattern.search(run_id)
    if not m:
        # Run ids without a recognisable learning rate are skipped.
        continue
    lr = float(m.group(1))
    # Loss of this run in the row selected for the target step.
    loss_val = float(loss_row[col])
    groups.setdefault(group_name, []).append(
        {"lr": lr, "loss": loss_val, "run_id": run_id, "col": col_str}
    )

# === 2. Build mapping run_id -> norm value from norm DF ===
# Norm columns are the ones whose header ends with the metric path below.
norm_suffix = "track_param_rms_to_l1/model_part_0/output"
norm_cols = [c for c in df_norm.columns if str(c).endswith(norm_suffix)]

# Key: the part of the column header before the first " - " (the run id).
# Value: that column's entry in the selected norm row, as a float.
# If two columns share a run id, the later column wins.
runid_to_norm = {
    str(norm_col).split(" - ", 1)[0]: float(norm_row[norm_col])
    for norm_col in norm_cols
}

# === 3. For each group, find top-3 lowest loss runs and their norms ===
TOP_K = 3  # number of best runs kept per group

top_runs = {}  # group_name -> list of dicts with lr, loss, run_id, norm
for group_name, entries in groups.items():
    # Rank only runs with a finite loss. Comparisons against NaN are always
    # False, so sorting a list that contains NaN yields an unpredictable order
    # and could place a diverged / unlogged run among the "best" ones.
    ranked = sorted(
        (e for e in entries if np.isfinite(e["loss"])),
        key=lambda e: e["loss"],
    )
    # Copy each selected entry and attach its norm; NaN marks a run that has
    # no matching column in the norm table.
    top_runs[group_name] = [
        {**e, "norm": runid_to_norm.get(e["run_id"], float("nan"))}
        for e in ranked[:TOP_K]
    ]

# === 5. Compute zoom region around global min loss ===
# Produces x_min/x_max/y_min/y_max, the axis limits of the inset plot.
all_points = [(e["lr"], e["loss"]) for entries in groups.values() for e in entries]
all_lrs = [p[0] for p in all_points]
all_vals = [p[1] for p in all_points]

# Only finite losses may define the zoom window: min() over a list containing
# NaN is order-dependent and can return NaN, which would turn the axis limits
# into NaN and make set_ylim() fail.
finite_points = [(lr, val) for (lr, val) in all_points if np.isfinite(val)]
if not finite_points:
    raise ValueError(
        "No finite loss values found; cannot compute the inset zoom region."
    )

best_lr, min_val = min(finite_points, key=lambda p: p[1])

# Runs whose loss is within `delta` of the best loss span the zoomed x-range.
delta = 0.2
near_opt = [(lr, val) for (lr, val) in finite_points if val <= min_val + delta]

# near_opt always contains the best run itself, so it is never empty here.
x_min = min(lr for lr, _ in near_opt)
x_max = max(lr for lr, _ in near_opt)
y_min = min_val - 0.02
y_max = min_val + delta + 0.02

if x_min == x_max:
    # Only one learning rate is near the optimum: identical limits would give
    # a singular x-range, so widen by a factor of two on each side.
    x_min, x_max = best_lr / 2, best_lr * 2

# === 6. Plot as before (lines + inset), with renamed legend ===
fig, ax = plt.subplots(figsize=(8, 6))

markers = ["o", "s", "^", "D", "v", "P", "X", "*", "<", ">"]


def _plot_lr_sweeps(target_ax, **line_kwargs):
    """Draw one loss-vs-learning-rate curve per group onto ``target_ax``.

    Groups are visited in ``legend_order``; groups absent from ``groups`` are
    skipped. The marker is chosen by the group's position in ``legend_order``
    (skipped groups still consume a position), so a group gets the same marker
    on every axes this is called for. Extra keyword arguments are forwarded to
    ``target_ax.plot``.

    Returns a list of ``(group_name, line)`` pairs in drawing order.
    """
    drawn = []
    for position, group_name in enumerate(legend_order):
        runs = groups.get(group_name)
        if runs is None:
            continue
        by_lr = sorted(runs, key=lambda run: run["lr"])
        (curve,) = target_ax.plot(
            [run["lr"] for run in by_lr],
            [run["loss"] for run in by_lr],
            marker=markers[position % len(markers)],
            **line_kwargs,
        )
        drawn.append((group_name, curve))
    return drawn


# Main axes: larger markers; keep the line handles for the legend.
main_curves = _plot_lr_sweeps(ax, markersize=10)
handles = [curve for _, curve in main_curves]
labels = [rename.get(group_name, group_name) for group_name, _ in main_curves]

ax.set_xscale("log", base=2)
ax.set_xlabel("Learning rate", fontsize=18)
ax.set_ylabel("Training loss", fontsize=18)
ax.tick_params(labelsize=16)
# ax.set_title(f"Scion (momentum={MOMENTUM}), proxy model, {TARGET_STEP*4096*256//1e9:.0f}B tokens", fontsize=18)
ax.grid(True)

# Legend outside: shrink the axes to leave room on the right; the vertical
# anchor differs slightly for the MOMENTUM == 0.1 figure.
fig.subplots_adjust(right=0.78)
legend_anchor = (1.04, 0.83) if MOMENTUM==0.1 else (1.04, 0.81)
ax.legend(handles, labels, loc="center left", bbox_to_anchor=legend_anchor, borderaxespad=0.) # , prop=legend_font

# Inset plot: the same curves with default marker size, zoomed to the
# x_min/x_max/y_min/y_max window computed earlier.
axins = inset_axes(ax, width="40%", height="40%", loc="upper right", borderpad=1)
_plot_lr_sweeps(axins)

axins.set_xscale("log", base=2)
axins.set_xlim(x_min, x_max)
axins.set_ylim(y_min, y_max)
axins.tick_params(labelsize=12)
axins.grid(True)

# === 7. Build text box for norms of top-3 runs per group (only selected groups) ===
def _lr_label(lr):
    """Format ``lr`` as ``2^k`` when it is (numerically) a power of two, else with 3 significant digits."""
    exp = int(round(np.log2(lr)))
    if np.isclose(lr, 2.0**exp):
        return f"2^{exp}"
    return f"{lr:.3g}"


lines = []
for group_name in box_groups:
    best_runs = top_runs.get(group_name)
    if not best_runs:
        # Group missing from the data, or no runs recorded for it.
        continue
    # Heading: the legend label up to (not including) the bracketed part.
    heading = rename.get(group_name, group_name).split('[')[0]
    lines.append(heading)
    for run in best_runs:
        lr_str = _lr_label(run["lr"])
        norm_val, loss_val = run["norm"], run["loss"]
        lines.append(f"  LR={lr_str}: norm={norm_val:.1f}, loss={loss_val:.2f}")
    # Blank separator line between groups.
    lines.append("")

# Drop the trailing separator(s) so the box has no empty last line.
legend_text = "\n".join(lines).rstrip("\n")

# Place text box on the right, below the legend (same x, lower y);
# the vertical position differs slightly for the MOMENTUM == 0.1 figure.
box_x = 0.81
box_y = 0.075 if MOMENTUM==0.1 else 0.05
box_style = dict(boxstyle="round", ec="black", fc="white", alpha=0.9)
fig.text(box_x, box_y, legend_text, fontsize=10, va="bottom", ha="left", bbox=box_style)

plt.show()

# fig.savefig(f"plots/normalisation-scion-mom-{MOMENTUM}-loss-lr.pdf", bbox_inches='tight')
