# In [None]:
# === AGNP: Text→Prior Map (color = chosen hyperparameter) ===
# How to use:
# 1) Set csv_path to your AgNP bullets CSV (or a combined CSV that includes AgNP rows).
# 2) Pick one hyperparam for coloring: "prior.outputscale", "prior.noise",
#    "prior.mean_bias", or "prior.lengthscale" (averages all prior.lengthscale* ARD
#    columns per row: geometric mean if every value is positive, else arithmetic mean).
# 3) Optionally set readouts to a list of bullet text snippets you want to highlight.

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def _find_embedding_columns(df: pd.DataFrame) -> list[str]:
    """Return the names of the embedding columns of *df*, in column order.

    A column counts as an embedding column when its label is a string that
    starts with ``"emb_"`` (e.g. ``emb_0``, ``emb_1``, ...).

    Non-string labels (such as integer column labels) are skipped rather than
    raising ``AttributeError`` on ``.startswith``.
    """
    return [c for c in df.columns if isinstance(c, str) and c.startswith("emb_")]

def _find_lengthscale_columns(df: pd.DataFrame) -> list[str]:
    """Return the names of the lengthscale columns of *df*, in column order.

    Any string column label starting with ``"prior.lengthscale"`` matches.
    That prefix test already covers a single column named exactly
    ``prior.lengthscale`` as well as suffixed per-dimension columns, so no
    separate fallback for the bare name is needed.

    Non-string labels are skipped rather than raising ``AttributeError``.
    """
    return [
        c
        for c in df.columns
        if isinstance(c, str) and c.startswith("prior.lengthscale")
    ]

def _get_hyperparam_values(df: pd.DataFrame, hyperparam: str):
    """Return one float value per row of *df* for the requested hyperparameter.

    For any name other than ``"prior.lengthscale"`` the column of that name is
    returned as a float array. For ``"prior.lengthscale"`` all lengthscale
    columns are collapsed to a single value per row: the geometric mean when
    every entry in the table is strictly positive, otherwise the arithmetic
    mean.

    Raises:
        ValueError: if the requested column (or any lengthscale column) is
            missing.
    """
    # Ordinary hyperparameters map straight onto a single column.
    if hyperparam != "prior.lengthscale":
        if hyperparam in df.columns:
            return df[hyperparam].to_numpy(dtype=float)
        raise ValueError(f"Column '{hyperparam}' not found.")

    # Lengthscale may be spread over several (ARD) columns; average them per row.
    ls_cols = _find_lengthscale_columns(df)
    if not ls_cols:
        raise ValueError("No lengthscale columns found (prior.lengthscale*).")

    scales = df[ls_cols].to_numpy(dtype=float)
    if not np.all(scales > 0):
        # A log-average is undefined for non-positive entries; use the plain mean.
        return scales.mean(axis=1)
    # Geometric mean, computed as exp(mean(log(x))).
    return np.exp(np.log(scales).mean(axis=1))

def _minmax01(v):
    """Min-max scale *v* to the range [0, 1], ignoring NaNs when finding the range.

    Args:
        v: array-like of numbers; converted to a float ndarray.

    Returns:
        A float array of the same shape as *v*. NaN entries stay NaN. An array
        of zeros is returned when no meaningful range exists: empty input,
        all-NaN input, a non-finite min or max, or a range below 1e-12.
    """
    v = np.asarray(v, dtype=float)
    # np.nanmin/np.nanmax raise ValueError on empty input and emit a
    # RuntimeWarning on all-NaN input, so settle both cases up front.
    if v.size == 0 or np.isnan(v).all():
        return np.zeros_like(v)
    vmin, vmax = np.nanmin(v), np.nanmax(v)
    # Infinite bounds or a (near-)constant array cannot be scaled meaningfully.
    if not (np.isfinite(vmin) and np.isfinite(vmax)) or vmax - vmin < 1e-12:
        return np.zeros_like(v)
    return (v - vmin) / (vmax - vmin)

def plot_agnp_text_prior_map(
    csv_path: str,
    hyperparam: str = "prior.outputscale",
    readouts: list[str] | None = None,
    normalize: bool = True,
    annotate_readouts: bool = True,
):
    """Plot a 2-D PCA of the AgNP text embeddings, colored by a prior hyperparameter.

    Args:
        csv_path: path to a CSV with ``emb_*`` embedding columns and the
            requested hyperparameter column(s). If a ``dataset`` column is
            present only rows whose value equals "agnp" (case-insensitive) are
            used; otherwise every row is used.
        hyperparam: column used for coloring; ``"prior.lengthscale"`` averages
            all ``prior.lengthscale*`` columns per row.
        readouts: optional text snippets; rows whose ``bullet_text`` contains
            any of them (case-insensitive substring match) are circled.
        normalize: min-max scale the color values to [0, 1] before plotting.
        annotate_readouts: also label each highlighted point with its bullet
            text, truncated to 60 characters.

    Returns:
        ``(fig, ax, sub, Z, cvals)``: the figure, its axes, the filtered
        DataFrame, the ``(n_rows, 2)`` PCA coordinates and the color values.

    Raises:
        ValueError: if no AgNP rows exist, fewer than two embedding columns or
            rows are available, or the hyperparameter column is missing.
    """
    # --- load & filter to AgNP ---
    df = pd.read_csv(Path(csv_path))
    if "dataset" in df.columns:
        sub = df[df["dataset"].astype(str).str.lower() == "agnp"].copy()
        if sub.empty:
            raise ValueError("No rows with dataset=='AgNP' found in the CSV.")
    else:
        # assume the whole CSV is AgNP if no dataset column
        sub = df.copy()

    # --- embeddings ---
    emb_cols = _find_embedding_columns(sub)
    if len(emb_cols) < 2:
        raise ValueError("Need ≥2 embedding columns named like emb_0, emb_1, ...")
    if len(sub) < 2:
        # A 2-component PCA needs at least two samples; fail with a clear message
        # instead of the less specific error raised inside PCA.
        raise ValueError("Need at least 2 rows to compute a 2-component PCA.")
    Z = PCA(n_components=2).fit_transform(sub[emb_cols].to_numpy(dtype=float))

    # --- color values (hyperparameter) ---
    cvals = _get_hyperparam_values(sub, hyperparam)
    if normalize:
        cvals = _minmax01(cvals)

    # --- make plot (single chart) ---
    # Use explicit figure/axes objects rather than pyplot's implicit current axes.
    fig, ax = plt.subplots(figsize=(7, 6))
    sc = ax.scatter(Z[:, 0], Z[:, 1], c=cvals, s=30)
    cb = fig.colorbar(sc, ax=ax)
    # Say so on the colorbar when it shows rescaled rather than raw values.
    cb.set_label(f"{hyperparam} (min-max normalized)" if normalize else hyperparam)

    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_title(f"AgNP — Text→Prior Map (color = {hyperparam})")

    # --- highlight selected readouts (optional) ---
    if readouts:
        if "bullet_text" not in sub.columns:
            print("[WARN] 'bullet_text' column not found; cannot match readouts.")
        else:
            # Treat missing bullets as empty text so they can neither match a
            # snippet (via the string "nan") nor break len()/slicing below.
            labels = sub["bullet_text"].fillna("").astype(str)
            lowered = labels.str.lower()
            mask = np.zeros(len(sub), dtype=bool)
            for snippet in readouts:
                needle = str(snippet).lower()
                if not needle.strip():
                    continue  # blank snippets would match every row
                hits = lowered.str.contains(needle, regex=False, na=False)
                mask |= hits.to_numpy(dtype=bool)
            # overlay highlighted markers
            if mask.any():
                ax.scatter(Z[mask, 0], Z[mask, 1], s=120, facecolors="none",
                           edgecolors="black", linewidths=1.5)
                if annotate_readouts:
                    for i in np.flatnonzero(mask):
                        # keep labels short
                        lab = labels.iloc[i]
                        if len(lab) > 60:
                            lab = lab[:57] + "..."
                        ax.annotate(lab, (Z[i, 0], Z[i, 1]), xytext=(5, 5),
                                    textcoords="offset points", fontsize=8)

    fig.tight_layout()
    return fig, ax, sub, Z, cvals

# ------- configure & run -------
# Module-level settings for this cell; edit these three values, then run.
csv_path = "agnp_bullets.csv"     # <--- change to your AgNP CSV path
hyperparam = "prior.outputscale"  # choose: prior.outputscale | prior.noise | prior.mean_bias | prior.lengthscale
readouts = [
    # put snippets of bullets you want to highlight; case-insensitive substring match
    # e.g., "lower temperature", "batch variability", "avoid overheating"
    # (an empty list means no points are highlighted)
]

# NOTE: csv_path must point to a real CSV before running this cell; the call
# below reads the file immediately and will error if it cannot be loaded.
# The call returns the figure, its axes, the filtered DataFrame, the 2-D PCA
# coordinates, and the color values (min-max scaled here since normalize=True).
fig, ax, sub_df, emb_2d, colors = plot_agnp_text_prior_map(
    csv_path=csv_path,
    hyperparam=hyperparam,
    readouts=readouts,
    normalize=True,
    annotate_readouts=True,
)
plt.show()
