# In [None]:
# %% [markdown]
# # Kaitiaki Planner — Budget-aware Multi-Agent PoC (Capstone)
# Runs baseline vs budgeter (with/without rerank), prints metrics, and saves plots/CSV.

# %%
import os, json, time, yaml, unicodedata as ud, numpy as np, pandas as pd, requests
from pathlib import Path

RETR = "http://localhost:8001"  # service used below for /healthz and /ingest
ORCH = "http://localhost:8000"  # service used below for /query

# Check services: issue exactly one GET per endpoint and print its status/body.
# The timeout keeps a hung service from blocking the notebook indefinitely.
for _url in (f"{RETR}/healthz", ORCH):
    _resp = requests.get(_url, timeout=10)
    print(_resp.status_code, _resp.text)

# %%
# Load the corpus, keep an id -> raw text lookup (used for gold offsets), and
# push the same documents into the retrieval service.
with open("../data/corpus.json", "r", encoding="utf-8") as fh:
    corpus = json.load(fh)
DOC_TEXT = {d["id"]: d["text"] for d in corpus}
ing = requests.post(f"{RETR}/ingest", json={"docs": corpus}, timeout=120)
# Error bodies may not be JSON, so only decode on success.
print("Ingest:", ing.status_code, ing.json() if ing.ok else ing.text)
# Fail fast: every later cell evaluates against this corpus, so a failed ingest
# would otherwise only show up later as misleading scores.
ing.raise_for_status()

# %%
# --- Gold offsets (Unicode/normalization safe) ---
# Load the eval task definitions. safe_load only builds plain Python data, and the
# context manager closes the file handle deterministically.
with open("../eval/tasks.yaml", "r", encoding="utf-8") as fh:
    eval_yaml = yaml.safe_load(fh)

def nfc(x):
    """Return *x* in Unicode NFC (composed) normal form."""
    return ud.normalize("NFC", x)


def _norm_index(text):
    """Normalize *text* (NFC + lowercase) and map each normalized char back to *text*.

    Returns ``(norm, starts, ends)`` where ``norm[k]`` was produced from
    ``text[starts[k]:ends[k]]``. The text is processed in segments of one
    character plus any following combining marks, so compositions such as
    ``"a" + U+0304 -> "ā"`` happen inside one segment. Compositions between two
    non-combining characters (e.g. Hangul jamo) are not merged. That is fine for
    Latin-script text such as the en/mi tasks here (TODO confirm for other scripts).
    """
    pieces, starts, ends = [], [], []
    i, n = 0, len(text)
    while i < n:
        j = i + 1
        while j < n and ud.combining(text[j]):
            j += 1
        seg = nfc(text[i:j]).lower()  # may change length (composition, "İ" -> "i̇")
        pieces.append(seg)
        starts.extend([i] * len(seg))
        ends.extend([j] * len(seg))
        i = j
    return "".join(pieces), starts, ends


def find_offsets(doc_id, snippet):
    """Locate *snippet* in ``DOC_TEXT[doc_id]`` and return its ``(start, end)`` span.

    Matching is case-insensitive and NFC-normalized. The returned half-open
    offsets index the ORIGINAL document text (the text sent to /ingest), so they
    stay correct even when normalization or lowercasing changes string length.

    Raises:
        KeyError: if *doc_id* is not in DOC_TEXT.
        ValueError: if *snippet* is empty or does not occur in the document.
    """
    text = DOC_TEXT[doc_id]
    norm_text, starts, ends = _norm_index(text)
    norm_snip = nfc(snippet).lower()
    # An empty snippet would "match" at 0 and yield a meaningless zero-width gold span.
    i = norm_text.find(norm_snip) if norm_snip else -1
    if i == -1:
        preview = nfc(text)[:80].replace("\n", " ")
        raise ValueError(f"[gold snippet missing] doc_id={doc_id} | snippet='{snippet}' | doc_starts='{preview}…'")
    return starts[i], ends[i + len(norm_snip) - 1]

def _gold_task(item):
    """Turn one tasks.yaml entry into an eval task whose gold citation carries char offsets."""
    gold = item["gold"]
    start, end = find_offsets(gold["doc_id"], gold["text_snippet"])
    return {
        "id": item["id"],
        "query": item["query"],
        "lang": item["lang"],
        "gold_citations": [{"doc_id": gold["doc_id"], "start": start, "end": end}],
    }


eval_tasks = [_gold_task(item) for item in eval_yaml]
len(eval_tasks)

# %%
# --- Grounded Correctness (IoU on spans) ---
def grounded_correctness(pred_cites, gold, iou_thresh=0.3):
    """Return 1.0 if any predicted citation overlaps the gold span with IoU >= *iou_thresh*.

    Args:
        pred_cites: iterable of citation dicts with ``doc_id``, ``char_start`` and
            ``char_end`` keys, or None for "no citations".
        gold: dict with ``doc_id``, ``start`` and ``end`` (half-open char span).
        iou_thresh: minimum intersection-over-union for a citation to count.

    Returns:
        1.0 on the first qualifying citation, otherwise 0.0.

    Citations whose offsets are not integer-like (e.g. JSON null) are skipped
    instead of raising and aborting the whole evaluation run.
    """
    gs, ge = gold["start"], gold["end"]
    gdoc = gold["doc_id"]
    for c in pred_cites or []:
        if c.get("doc_id") != gdoc:
            continue
        try:
            s, e = int(c.get("char_start", -1)), int(c.get("char_end", -1))
        except (TypeError, ValueError):
            continue  # e.g. {"char_start": None}: unusable span, not a match
        inter = max(0, min(e, ge) - max(s, gs))
        union = max(ge, e) - min(gs, s)
        if union > 0 and inter / union >= iou_thresh:
            return 1.0
    return 0.0

# %%
# --- Runner ---
def run_suite(tasks, mode, use_rerank=True):
    """Send every task to the orchestrator's /query endpoint and score the replies.

    Args:
        tasks: eval task dicts with ``id``, ``query``, ``lang`` and ``gold_citations``.
        mode: orchestrator mode string (e.g. "baseline" or "budgeter").
        use_rerank: forwarded to the orchestrator as ``use_rerank``.

    Returns:
        DataFrame with one row per task: id, lang, mode, use_rerank, gc (grounded
        correctness against the first gold citation), lat_ms, cost, refusal, and
        error (None on success, otherwise the failure message).

    Failed requests (network errors, HTTP 4xx/5xx, non-JSON bodies) are recorded
    as refusals with the 60 s timeout as their latency, so they count against the
    variant instead of crashing the run.
    """
    rows = []
    for t in tasks:
        payload = {"query": t["query"], "mode": mode, "lang": t["lang"], "use_rerank": use_rerank}
        t0 = time.perf_counter()  # monotonic clock for wall-time measurement
        try:
            r = requests.post(f"{ORCH}/query", json=payload, timeout=60)
            # An HTTP error body is not an answer; route it to the failure record.
            r.raise_for_status()
            rec = r.json()
        except (requests.RequestException, ValueError) as e:  # ValueError: invalid JSON
            rec = {"response": {"citations": [], "refusal": True},
                   "metrics": {"total_ms": 60000, "cost_usd": 0.0}, "error": str(e)}
        dt = (time.perf_counter() - t0) * 1000

        # `or {}` also covers keys present with a null value.
        metrics = rec.get("metrics") or {}
        response = rec.get("response") or {}

        # Prefer server-reported numbers; fall back when absent OR null so the
        # later percentile/mean computations never see None.
        cost = metrics.get("cost_usd")
        if cost is None:
            cost = metrics.get("total_cost_usd")
        if cost is None:
            cost = 0.0
        lat = metrics.get("total_ms")
        if lat is None:
            lat = dt

        cites = response.get("citations") or []
        rows.append({
            "id": t["id"], "lang": t["lang"], "mode": mode, "use_rerank": use_rerank,
            "gc": grounded_correctness(cites, t["gold_citations"][0]),
            "lat_ms": lat,
            "cost": cost,
            "refusal": bool(response.get("refusal", False)),
            "error": rec.get("error"),
        })
    return pd.DataFrame(rows)

# Variant 1 & 2: baseline and budgeter, both with reranking enabled.
df_base = run_suite(tasks=eval_tasks, mode="baseline", use_rerank=True)
df_budg = run_suite(tasks=eval_tasks, mode="budgeter", use_rerank=True)
# Variant 3 (ablation): budgeter with reranking disabled.
df_budg_norr = run_suite(tasks=eval_tasks, mode="budgeter", use_rerank=False)

# Stack all per-task rows into one frame with a fresh 0..N-1 index.
df = pd.concat((df_base, df_budg, df_budg_norr), ignore_index=True)
df.head()

# %%
# --- Summaries & SLA checks ---
def summarize(frame: pd.DataFrame):
    """Return mean GC, p50/p95 latency (ms) and mean cost for one variant's rows.

    Every value is None when *frame* has no rows.
    """
    keys = ("overall_gc", "p50_ms", "p95_ms", "mean_cost")
    if frame.empty:
        return dict.fromkeys(keys)
    latency = frame["lat_ms"]
    p50_ms, p95_ms = (float(np.percentile(latency, q)) for q in (50, 95))
    return dict(zip(keys, (frame["gc"].mean(), p50_ms, p95_ms, frame["cost"].mean())))

def fairness_gap(frame):
    """Mean grounded correctness of 'en' rows minus that of 'mi' rows.

    A language with no rows contributes NaN, so the gap is NaN rather than a
    comparison against zero.
    """
    en_gc, mi_gc = frame.groupby("lang")["gc"].mean().reindex(["en", "mi"]).to_numpy()
    return float(en_gc - mi_gc)

_summary_frames = {
    "baseline": df[df["mode"] == "baseline"],
    "budgeter": df[(df["mode"] == "budgeter") & (df["use_rerank"])],
    "budgeter_no_rerank": df[(df["mode"] == "budgeter") & (~df["use_rerank"])],
}
for variant, vframe in _summary_frames.items():
    summary = summarize(vframe)
    gap = fairness_gap(vframe)
    print(variant, summary, "fairness_gap_EN-MI=", f"{gap:.3f}")

# Acceptance thresholds for the budgeter variant.
TARGET_P95_MS = 1200  # p95 latency budget, in milliseconds
MAX_GC_DROP = 0.02    # allowed GC drop vs baseline, as a fraction (0.02 = 2 pts)

def pick(frame, mode, use_rerank=None):
    """Return the rows of *frame* for *mode*, optionally filtered on ``use_rerank``.

    Only the singletons True / False filter (rerank rows / no-rerank rows); any
    other value, including the default None, applies no rerank filter.
    """
    subset = frame[frame["mode"] == mode]
    if use_rerank is True:
        return subset[subset["use_rerank"]]
    if use_rerank is False:
        return subset[~subset["use_rerank"]]
    return subset

# Frames compared in the acceptance check: baseline, budgeter, budgeter w/o rerank.
base, budg, budgNR = (
    pick(df, mode="baseline", use_rerank=True),
    pick(df, mode="budgeter", use_rerank=True),
    pick(df, mode="budgeter", use_rerank=False),
)

def s(fr):
    """Summarize one variant: mean GC, p50/p95 latency, mean cost and refusal rate.

    For an empty frame the percentiles are NaN explicitly; the pandas means are
    NaN as well.
    """
    if fr.empty:
        p50_val = p95_val = np.nan
    else:
        latency = fr["lat_ms"]
        p50_val = float(np.percentile(latency, 50))
        p95_val = float(np.percentile(latency, 95))
    result = {"gc": fr["gc"].mean()}
    result["p50"] = p50_val
    result["p95"] = p95_val
    result["cost_mean"] = fr["cost"].mean()
    result["refusal_rate"] = fr["refusal"].mean()
    return result

S_base, S_budg, S_budgNR = s(base), s(budg), s(budgNR)
print("BASE   :", S_base)
print("BUDGET :", S_budg)
print("BUDG_NR:", S_budgNR)

# s() reports an empty variant's GC as NaN (pandas mean), never None. NaN
# propagates through the subtraction, so the GC check below then fails.
gc_drop = S_base["gc"] - S_budg["gc"]
meets_p95 = S_budg["p95"] <= TARGET_P95_MS
# GC means are k/n fractions, so a drop of exactly MAX_GC_DROP can come out a hair
# larger in floating point (0.52 - 0.50 == 0.020000000000000018). Treat "close" as equal.
meets_gc = bool(not np.isnan(gc_drop)
                and (gc_drop <= MAX_GC_DROP or np.isclose(gc_drop, MAX_GC_DROP)))

print(f"\nACCEPTANCE — p95<= {TARGET_P95_MS}ms? {meets_p95} | GC drop ≤ {MAX_GC_DROP*100:.0f} pts? {meets_gc}")
print(f"Fairness gap (EN−MI), budgeter: {fairness_gap(budg):.3f}")
# Missing language -> NaN (as in fairness_gap), so an absent group cannot
# masquerade as a 0% refusal rate.
refusal_by_lang = budg.groupby("lang")["refusal"].mean()
refusal_gap = float(refusal_by_lang.get("en", np.nan) - refusal_by_lang.get("mi", np.nan))
print(f"Refusal-rate gap (EN−MI), budgeter: {refusal_gap:.3f}")

# %%
# --- Plots ---
import matplotlib.pyplot as plt
for _out_dir in ("figures", "outputs"):
    os.makedirs(_out_dir, exist_ok=True)

# Figure 1: mean grounded correctness per variant.
variants = ["baseline", "budgeter", "budgeter_no_rerank"]
vals = [
    pick(df, "baseline")["gc"].mean(),
    pick(df, "budgeter", True)["gc"].mean(),
    pick(df, "budgeter", False)["gc"].mean(),
]
plt.figure()
plt.bar(variants, vals)
plt.ylim(0, 1)
plt.ylabel("Grounded Correctness")
plt.title(f"Correctness by Variant (n={len(df_base)})")
plt.savefig("figures/fig1_correctness.png", bbox_inches="tight")
plt.show()

def p50(frame):
    """Median latency (ms) of *frame*, or NaN when it has no rows."""
    if frame.empty:
        return np.nan
    return float(np.percentile(frame["lat_ms"], 50))


def p95(frame):
    """95th-percentile latency (ms) of *frame*, or NaN when it has no rows."""
    if frame.empty:
        return np.nan
    return float(np.percentile(frame["lat_ms"], 95))
# Per-variant frames, in the same order as `variants`.
_plot_frames = [pick(df, "baseline"), pick(df, "budgeter", True), pick(df, "budgeter", False)]

# Figure 2: p50 and p95 latency per variant.
P50 = [p50(fr) for fr in _plot_frames]
P95 = [p95(fr) for fr in _plot_frames]
plt.figure()
x = np.arange(len(variants))
for series, label in ((P50, "p50"), (P95, "p95")):
    plt.plot(x, series, marker="o", label=label)
plt.xticks(x, variants)
plt.ylabel("Latency (ms)")
plt.title("Latency by Variant")
plt.legend()
plt.savefig("figures/fig2_latency.png", bbox_inches="tight")
plt.show()

# Figure 3: GC('en') - GC('mi') per variant; the dashed line marks parity.
fg = [fairness_gap(fr) for fr in _plot_frames]
plt.figure()
plt.bar(variants, fg)
plt.axhline(0, linestyle="--")
plt.ylim(-1, 1)
plt.ylabel("GC(EN) - GC(MI)")
plt.title("Fairness Gap (EN - MI)")
plt.savefig("figures/fig3_fairness_gap.png", bbox_inches="tight")
plt.show()

# %%
# Persist the raw per-task rows (all variants) next to the saved figures.
results_path = Path("outputs") / "results.csv"
df.to_csv(results_path, index=False)
print("Saved: figures/*.png and outputs/results.csv")
