In [None]:
# === TEST-SET EVAL with smart label sets (Colab, OpenAI 0.28, no env vars) ===
!pip -q install "openai==0.28" pandas tqdm

import os, json, gzip, time, difflib
import pandas as pd
from tqdm import tqdm
from typing import Dict, Any, List, Tuple
from getpass import getpass
import openai

# Show which SDK build is active: the code below calls the legacy
# `openai.ChatCompletion` interface and the install line pins openai==0.28.
print("OpenAI SDK version:", openai.__version__)

# -----------------------------
# Config
# -----------------------------
FILE_PATH         = "abcd_sample (2).json"   # dataset to load; the loader falls back to "abcd_sample.json" when this path is missing
ONTOLOGY_PATH     = "data/ontology.json"    # optional; used to enrich the label sets when the file exists
PRIMARY_MODEL     = "gpt-4o"           # model tried first for every transcript
FALLBACK_MODELS   = ["gpt-4o", "gpt-3.5-turbo-0125"]  # tried in order after PRIMARY_MODEL; entries equal to PRIMARY_MODEL are skipped
MAX_TEST_SAMPLES  = 50                      # cap on the number of test conversations evaluated (limits API cost)
REQUEST_DELAY_SEC = 0.7  # pause, in seconds, passed to time.sleep between API requests

# -----------------------------
# API key (no env vars)
# -----------------------------
# Prompt for the key interactively (input is not echoed) rather than reading
# it from an environment variable.
openai.api_key = getpass("Enter your OpenAI API key (will not echo): ")

def safe_chat(messages, model):
    """Send one chat-completion request and return the reply text.

    Args:
        messages: Chat messages forwarded to ``openai.ChatCompletion.create``.
        model: Name of the model to query.

    Returns:
        The ``content`` entry of the first choice's message, or ``""`` when
        the response has no choices or the request raises.
    """
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            temperature=0,
        )
        if not response.get("choices"):
            return ""
        first_message = response["choices"][0]["message"]
        return first_message.get("content", "")
    except Exception as err:
        # Best-effort call: report the failure and hand back an empty reply.
        print(f"[chat:{model}] error: {err}")
        return ""

# -----------------------------
# IO helpers
# -----------------------------
def load_json_maybe_gz(path: str):
    """Load a JSON document that may or may not be gzip-compressed.

    If ``path`` does not exist but ``path + ".gz"`` does, the latter is read.
    Compression is detected from the file's first two bytes rather than from
    its name.

    Args:
        path: Location of the JSON file (plain or gzipped).

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: Neither ``path`` nor ``path + ".gz"`` exists.
    """
    if not os.path.exists(path):
        gz_candidate = path + ".gz"
        path = gz_candidate if os.path.exists(gz_candidate) else path
    if not os.path.exists(path):
        raise FileNotFoundError(f"Could not find {path}")

    # Sniff the gzip magic number so a mis-named file is still read correctly.
    with open(path, "rb") as raw:
        magic = raw.read(2)

    if magic == b"\x1f\x8b":
        with gzip.open(path, "rt", encoding="utf-8") as handle:
            return json.load(handle)
    with open(path, "rt", encoding="utf-8") as handle:
        return json.load(handle)

def convo_to_transcript(convo: Dict[str,Any]) -> str:
    """Flatten a conversation's ``"original"`` turns into one line of text.

    Each turn is unpacked as a ``(speaker, text)`` pair and rendered as
    ``"speaker: text"``; the turns are joined with single spaces. A
    conversation without an ``"original"`` key yields ``""``.
    """
    rendered_turns = []
    for speaker, utterance in convo.get("original", []):
        rendered_turns.append(f"{speaker}: {utterance}")
    return " ".join(rendered_turns)

# -----------------------------
# Load dataset
# -----------------------------
# Load the configured dataset, falling back to the sample file in the CWD.
try:
    abcd = load_json_maybe_gz(FILE_PATH)
except FileNotFoundError:
    FILE_PATH = "abcd_sample.json"
    abcd = load_json_maybe_gz(FILE_PATH)

# A top-level list is treated as the small sample file; anything else is
# expected to be a mapping with "train" / "dev" / "test" splits.
sample_mode = isinstance(abcd, list)
if not sample_mode:
    train_dev = (abcd.get("train", []) or []) + (abcd.get("dev", []) or [])
    test_split = (abcd.get("test", []) or [])
else:
    print("Detected sample-style file (list). We'll split it 50/50 for demo.")
    n = len(abcd)
    split_at = max(1, n//2)  # keep at least one conversation on the train side
    train_dev = abcd[:split_at]
    test_split = abcd[split_at:]

print(f"Train+Dev convos: {len(train_dev)} | Test convos: {len(test_split)}")
if not test_split:
    raise RuntimeError("No test split found. Point FILE_PATH to abcd_v1.1.json(.gz) or keep sample file.")

# -----------------------------
# Build label sets
#  - full dataset: from train+dev only (no test leakage)
#  - sample file (very tiny): if too few labels found, broaden using ALL items in sample
#  - ontology.json (if present): optionally enrich choices
# -----------------------------
def labels_from_convos(convos: List[Dict[str,Any]]) -> Tuple[List[str], List[str]]:
    """Collect the distinct flow and subflow labels used by ``convos``.

    Labels are read from each conversation's ``"scenario"`` mapping; empty or
    missing values are skipped and the rest are converted with ``str``.

    Returns:
        ``(flows, subflows)``, each a sorted list of unique labels.
    """
    scenarios = [convo.get("scenario", {}) for convo in convos]
    flow_values = (scenario.get("flow", "") for scenario in scenarios)
    subflow_values = (scenario.get("subflow", "") for scenario in scenarios)
    flow_labels = {str(value) for value in flow_values if value}
    subflow_labels = {str(value) for value in subflow_values if value}
    return sorted(flow_labels), sorted(subflow_labels)

# Start from the labels seen in train+dev only.
flow_opts, subflow_opts = labels_from_convos(train_dev)

# Optional: enrich from ontology.json. Only ontology strings that are also
# used as a flow/subflow label somewhere in the dataset are added.
# NOTE: for the full dataset "somewhere" includes the test split, so a label
# that occurs only in test is admitted when the ontology also lists it.
if os.path.exists(ONTOLOGY_PATH):
    try:
        onto = load_json_maybe_gz(ONTOLOGY_PATH)
        strings = set()

        def walk(x):
            """Add every string dict key and string leaf under ``x`` to ``strings``."""
            if isinstance(x, dict):
                for k, v in x.items():
                    if isinstance(k, str):
                        strings.add(k)
                    walk(v)
            elif isinstance(x, list):
                for i in x:
                    walk(i)
            elif isinstance(x, str):
                strings.add(x)

        walk(onto)

        # Gather every conversation in the file. Splits stored as null are
        # treated as empty (same `or []` guard as the train/dev/test split
        # above), so a missing split no longer raises TypeError and gets
        # misreported as an ontology parse failure.
        if not abcd:
            all_convos = []
        elif sample_mode:
            all_convos = abcd
        else:
            all_convos = (
                (abcd.get("train") or [])
                + (abcd.get("dev") or [])
                + (abcd.get("test") or [])
            )
        all_fl, all_sf = labels_from_convos(all_convos)

        # Keep only ontology strings that look like labels (contain an
        # underscore) or are used as a label in the data. The intersections
        # below then restrict the additions to labels actually used.
        candidates = {s for s in strings if ("_" in s) or (s in all_fl) or (s in all_sf)}
        flow_opts = sorted(set(flow_opts) | (candidates & set(all_fl)))
        subflow_opts = sorted(set(subflow_opts) | (candidates & set(all_sf)))
    except Exception as e:
        # Enrichment is best-effort; keep the train/dev label sets on failure.
        print(f"[ontology] Could not parse {ONTOLOGY_PATH}: {e}. Continuing with train/dev labels.")

# If the sample is too tiny (fewer than two labels of a kind), broaden that
# label set using every conversation in the sample file.
if sample_mode and (len(flow_opts) < 2 or len(subflow_opts) < 2):
    all_fl, all_sf = labels_from_convos(abcd)
    if len(flow_opts) < 2:
        flow_opts = all_fl
    if len(subflow_opts) < 2:
        subflow_opts = all_sf
    print("[sample] Broadened label sets using all sample convos.")

print(f"Flow label count: {len(flow_opts)} | Subflow label count: {len(subflow_opts)}")

# -----------------------------
# JSON schema + parsing
# -----------------------------
# Template of the metadata object requested from the model. It is embedded in
# the prompt as JSON and copied to seed every extraction result.
SCHEMA = {
    "personal": {
        "customer_name": "",
        "email": "",
        "member_level": "",
        "phone": "",
        "username": "",
    },
    "order": {
        "street_address": "",
        "full_address": "",
        "city": "",
        "num_products": "",
        "order_id": "",
        "packaging": "",
        "payment_method": "",
        "products": "[]",  # default is the string "[]", not an empty list
        "purchase_date": "",
        "state": "",
        "zip_code": "",
    },
    "product": {
        "names": [],
        "amounts": [],
    },
    "flow": "",
    "subflow": "",
}

def try_parse_json(text: str):
    """Parse ``text`` as JSON, tolerating surrounding non-JSON text.

    The stripped text is parsed as-is first. If that fails, the span from the
    first ``{`` to the last ``}`` is parsed instead, which recovers an object
    wrapped in prose or a markdown fence.

    Returns:
        The decoded value, or ``None`` when ``text`` is empty or neither
        attempt yields valid JSON.
    """
    if not text:
        return None

    stripped = text.strip()
    try:
        return json.loads(stripped)
    except Exception:
        pass

    start = stripped.find("{")
    end = stripped.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return None
    try:
        return json.loads(stripped[start:end + 1])
    except Exception:
        return None

def closest_label(pred: str, choices: List[str], cutoff: float = 0.6) -> str:
    """Map a predicted label onto the nearest entry of ``choices``.

    Matching is attempted in this order: exact, exact after trimming
    surrounding whitespace, case-insensitive, then fuzzy via
    ``difflib.get_close_matches``.

    Args:
        pred: Label produced by the model. It comes from parsed model output,
            so it may be empty or not a string at all.
        choices: The valid labels.
        cutoff: Minimum similarity (0-1) required for a fuzzy match.

    Returns:
        The matching entry of ``choices``, or ``""`` when ``pred`` is empty,
        is not a string, or nothing in ``choices`` is close enough.
    """
    if not pred or not choices:
        return ""
    if pred in choices:
        return pred
    # Parsed JSON can hold a list, number or object here; those have no
    # ``.lower()`` and cannot be matched, so treat them as "no label".
    if not isinstance(pred, str):
        return ""

    cleaned = pred.strip()
    if not cleaned:
        return ""
    if cleaned in choices:
        return cleaned

    # Case-insensitive exact match before falling back to fuzzy matching.
    lowmap = {c.lower(): c for c in choices}
    exact_ci = lowmap.get(cleaned.lower())
    if exact_ci is not None:
        return exact_ci

    best = difflib.get_close_matches(cleaned, choices, n=1, cutoff=cutoff)
    return best[0] if best else ""

# -----------------------------
# Extractor with constrained choices
# -----------------------------
def extract_metadata_from_transcript(transcript: str,
                                     flow_choices: List[str],
                                     subflow_choices: List[str]) -> Dict[str, Any]:
    """Ask the model for structured metadata about one dialog.

    The prompt lists the valid flow/subflow labels and the expected JSON
    schema. ``PRIMARY_MODEL`` is tried first, then each distinct entry of
    ``FALLBACK_MODELS``, until one reply parses to a JSON object.

    Args:
        transcript: The dialog text to classify.
        flow_choices: Valid flow labels offered to the model.
        subflow_choices: Valid subflow labels offered to the model.

    Returns:
        A copy of ``SCHEMA`` overlaid with the model's top-level keys, with
        ``"flow"`` and ``"subflow"`` normalized to one of the given choices
        (or ``""`` when no choice matches). If no model returns a JSON
        object, an unmodified copy of ``SCHEMA`` is returned.
    """
    label_instr = (
        "CLASSIFICATION CONSTRAINTS:\n"
        f"- Valid flow labels (pick exactly one, copy verbatim): {flow_choices}\n"
        f"- Valid subflow labels (pick exactly one, copy verbatim): {subflow_choices}\n"
        "- Do NOT invent new labels. If uncertain, pick the most likely from the lists.\n"
    )
    prompt = (
        "Convert the customer-support dialog into structured metadata.\n\n"
        f"{label_instr}\n"
        "OUTPUT RULES:\n"
        "- Return STRICT JSON only (no prose, no markdown).\n"
        "- Use this exact schema and field types:\n"
        f"{json.dumps(SCHEMA, indent=2)}\n"
        "- If a field is missing, use \"\" or [] accordingly.\n"
        "- 'flow' and 'subflow' MUST be exactly one of the provided labels above.\n\n"
        "Dialog transcript:\n"
        f"{transcript}\n"
    )
    messages = [
        {"role":"system","content":"Always return valid JSON that exactly matches the schema. No explanations."},
        {"role":"user","content":prompt}
    ]
    # Primary model first, then the fallbacks that differ from it.
    models_to_try = [PRIMARY_MODEL] + [m for m in FALLBACK_MODELS if m != PRIMARY_MODEL]
    for m in models_to_try:
        content = safe_chat(messages, m)
        data = try_parse_json(content)
        if isinstance(data, dict):
            # JSON round-trip gives a deep copy, so SCHEMA itself is never mutated.
            out = json.loads(json.dumps(SCHEMA))
            out.update(data)
            # Enforce / normalize labels. The model may emit a non-string
            # (list, number, object) for a label; treat that as "no label"
            # instead of letting one bad reply abort the whole evaluation run.
            for key, choices in (("flow", flow_choices), ("subflow", subflow_choices)):
                raw_label = out.get(key, "")
                if not isinstance(raw_label, str):
                    raw_label = ""
                out[key] = closest_label(raw_label, choices, cutoff=0.6)
            return out
        if content:
            print(f"[warn:{m}] unparsable output (first 160 chars): {content[:160]}")
        # Pause before trying the next model.
        time.sleep(REQUEST_DELAY_SEC)
    return json.loads(json.dumps(SCHEMA))

# -----------------------------
# Build TEST dataframe
# -----------------------------
# One row per test conversation (capped at MAX_TEST_SAMPLES): the gold
# flow/subflow labels plus the flattened transcript fed to the model.
test_rows = [
    {
        "convo_id": convo.get("convo_id",""),
        "flow": convo.get("scenario", {}).get("flow",""),
        "subflow": convo.get("scenario", {}).get("subflow",""),
        "transcript": convo_to_transcript(convo),
    }
    for convo in test_split[:MAX_TEST_SAMPLES]
]
test_df = pd.DataFrame(test_rows)
print("Test DataFrame shape:", test_df.shape)

# -----------------------------
# Predict on TEST only
# -----------------------------
# Run the extractor over every test transcript, pausing between requests.
preds = []
progress = tqdm(test_df["transcript"], desc="Predicting (test)")
for transcript_text in progress:
    prediction = extract_metadata_from_transcript(transcript_text, flow_opts, subflow_opts)
    preds.append(prediction)
    time.sleep(REQUEST_DELAY_SEC)

# Flatten nested prediction dicts into "extracted_<section>_<field>" columns
# and place them next to the gold labels.
extracted = pd.json_normalize(preds, sep="_")
extracted = extracted.add_prefix("extracted_")
final_df = pd.concat(
    [test_df.reset_index(drop=True), extracted.reset_index(drop=True)],
    axis=1,
)

display(final_df.head())

# -----------------------------
# Accuracy
# -----------------------------
# Exact-match accuracy of the predicted labels against the gold labels.
for field in ["flow","subflow"]:
    # Fill missing values BEFORE casting: astype(str) turns NaN/None into the
    # literal strings "nan"/"None", after which fillna("") has nothing to fill.
    gt = final_df[field].fillna("").astype(str)
    ex = final_df[f"extracted_{field}"].fillna("").astype(str)
    acc = (gt == ex).mean()
    print(f"Test {field} accuracy: {acc:.2%}")
