In [2]:
# --- SNIFF POC: deterministic, no-LLM boolean building, evidence-driven MeSH mining ---
# Copy this entire cell into Jupyter and run. Adjust CONFIG at the top.

import os, re, time, json, math, statistics as stats
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import requests
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer

# =========================
# CONFIG
# =========================

CONFIG = {
    "out_dir": "sniff_poc_out",
    "year_min": 2015,
    "languages": ["English","Portuguese","Spanish"],
    # seed lexical tokens (phrases) — no qualifiers here; we’ll append [tiab] later
    "population_terms_seed": ["pectus excavatum", "Nuss", "MIRPE"],  # avoid "adult" in query; assess age in records if needed
    "intervention_terms_seed": ["intercostal nerve cryoablation", "cryoanalgesia", "cryoablation", "INC", "cryotherapy"],
    # optional outcome tokens for quality scoring (NOT added to queries)
    "outcome_tokens": ["opioid", "oxycodone", "morphine", "hydromorphone", "pain", "VAS", "NRS"],
    # probe sizes & caps
    "probe_ids": 80,         # how many IDs to sample per candidate when probing
    "fetch_chunk": 200,
    "broad_target_min": 10,  # query-size soft window (year-aware)
    "broad_target_max": 10000,
    "focused_target_min": 1,
    "focused_target_max": 2000,
    "max_probe_attempts": 3,
    # scoring weights (heuristic)
    "w_pi": 1.8,         # P∧I conjunction in TA
    "w_outcome": 0.9,    # outcome tokens present
    "w_design": 0.7,     # pubtype primary-ish
    "w_recency": 0.6,    # scaled [0..1] vs year_min
    "w_tfidf": 1.2,      # tf-idf sim to lexical seed text (not raw user NLQ)
    # pubtype hints
    "primary_pubtypes": {
        "Randomized Controlled Trial","Clinical Trial","Controlled Clinical Trial",
        "Prospective Studies","Cohort Studies","Case-Control Studies","Comparative Study"
    },
    # focused hedge (lexical only; no [Publication Type] to keep indexing robust)
    "rct_hedge": ['randomized', 'randomised', 'randomization', 'random allocation'],
    # HTTP
    "http_timeout": 30,
    "email": "you@example.com",  # E-utilities polite usage
    "api_key": os.environ.get("ENTREZ_API_KEY", "")
}

# =========================
# UTIL
# =========================

def ensure_out_dir(p: str):
    os.makedirs(p, exist_ok=True)

def save_json(path: str, obj):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)

def write_tsv(path: str, rows: List[Dict[str, str]], header: List[str]):
    with open(path, "w", encoding="utf-8") as f:
        f.write("\t".join(header) + "\n")
        for r in rows:
            f.write("\t".join(str(r.get(h, "")) for h in header) + "\n")

def now():
    return time.strftime("%H:%M:%S")

def norm_space(s: str) -> str:
    return re.sub(r"\s+", " ", (s or "").strip())

def year_recency(y: Optional[int], ymin: int, current: int = 2025) -> float:
    if y is None:
        return 0.0
    span = max(1, current - ymin)
    return max(0.0, min(1.0, (y - ymin) / span))

def text_hits(text: str, tokens: List[str]) -> int:
    tl = (text or "").lower()
    n=0
    for t in tokens:
        t=t.lower()
        if t in tl:
            n+=1
    return n

def has_any(text: str, tokens: List[str]) -> bool:
    return text_hits(text, tokens) > 0

# =========================
# PUBMED E-UTILITIES
# =========================

EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
UA = "sniff-poc/0.1 (+local)"
HEADERS = {"User-Agent": UA, "Accept": "application/json"}

def esearch(query: str, mindate: Optional[int], retmax: int) -> Tuple[int, List[str]]:
    params = {
        "db": "pubmed",
        "retmode": "json",
        "retmax": str(retmax),
        "term": query,
        "email": CONFIG["email"]
    }
    if CONFIG["api_key"]:
        params["api_key"] = CONFIG["api_key"]
    if mindate:
        params["mindate"] = str(mindate)
    r = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params=params, timeout=CONFIG["http_timeout"])
    r.raise_for_status()
    js = r.json().get("esearchresult", {})
    count = int(js.get("count", "0"))
    ids = [str(x) for x in js.get("idlist", [])]
    return count, ids

def efetch_xml(pmids: List[str], chunk: int = 200) -> str:
    xml_all=[]
    for i in range(0, len(pmids), chunk):
        sub = pmids[i:i+chunk]
        params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(sub),"email":CONFIG["email"]}
        if CONFIG["api_key"]:
            params["api_key"] = CONFIG["api_key"]
        r = requests.get(f"{EUTILS}/efetch.fcgi", headers={"User-Agent": UA}, params=params, timeout=CONFIG["http_timeout"])
        r.raise_for_status()
        xml_all.append(r.text)
        time.sleep(0.08)
    return "\n".join(xml_all)

def parse_pubmed_xml(xml_text: str) -> List[Dict]:
    # lightweight XML parsing via regex/ElementTree hybrid to avoid heavy deps
    import xml.etree.ElementTree as ET
    out=[]
    root = ET.fromstring(xml_text)
    def join_text(node) -> str:
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = (art.findtext(".//PMID") or "").strip()
        title = join_text(art.find(".//ArticleTitle")).strip()
        abst_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(join_text(n).strip() for n in abst_nodes) if abst_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\b(19|20)\d{2}\b", s)
                if m:
                    year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text.strip() for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh=[]
        for mh in art.findall(".//MeshHeadingList/MeshHeading"):
            desc = mh.findtext("./DescriptorName") or ""
            majr = (mh.find("./DescriptorName").attrib.get("MajorTopicYN","N") == "Y") if mh.find("./DescriptorName") is not None else False
            mesh.append({"term": desc, "majr": bool(majr)})
        out.append({
            "pmid": pmid, "title": title, "abstract": abstract, "year": year,
            "language": lang, "publication_types": pubtypes, "mesh": mesh
        })
    return out

# =========================
# QUERY BUILDING (LEXICAL-ONLY)
# =========================

def token_to_tiab(tok: str) -> str:
    tok = tok.strip()
    if " " in tok or "-" in tok or "/" in tok:
        return f"\"{tok}\"[tiab]"
    return f"{tok}[tiab]"

def or_block(terms: List[str]) -> str:
    terms = [t for t in terms if t.strip()]
    if not terms: return ""
    return "(" + " OR ".join(token_to_tiab(t) for t in terms) + ")"

def build_broad(pop_terms: List[str], int_terms: List[str]) -> str:
    pb = or_block(pop_terms)
    ib = or_block(int_terms)
    if not pb or not ib:
        return ""
    return f"{pb} AND {ib}"

def build_focused_from(broad_query: str, rct_hedge: List[str]) -> str:
    hedge = "(" + " OR ".join(token_to_tiab(h) for h in rct_hedge) + ")"
    return f"({broad_query}) AND {hedge}"

# =========================
# SCORING
# =========================

def tfidf_sim(records: List[Dict], query_text: str) -> List[float]:
    docs = [norm_space((r["title"] or "") + " " + (r["abstract"] or "")) for r in records]
    vec = TfidfVectorizer(ngram_range=(1,3), lowercase=True, max_features=120000)
    X = vec.fit_transform(docs)
    q = vec.transform([query_text])
    Xa = X.toarray(); qa = q.toarray()
    denom = (np.linalg.norm(Xa, axis=1) * (np.linalg.norm(qa) + 1e-12) + 1e-12)
    sims = (Xa @ qa.T).ravel() / denom
    return sims.tolist()

def build_seed_query_text(pop_terms: List[str], int_terms: List[str]) -> str:
    # For tf-idf, build a lexical seed string (not the user NLQ).
    return " ".join(pop_terms + int_terms)

def record_score(rec: Dict, pop_terms: List[str], int_terms: List[str], outcome_tokens: List[str],
                 tfidf_val: float, ymin: int, w: Dict[str,float], primary_pubtypes: set) -> Tuple[float, Dict[str,float]]:
    t = norm_space((rec["title"] or "") + " " + (rec["abstract"] or ""))
    pi = 1.0 if (has_any(t, pop_terms) and has_any(t, int_terms)) else 0.0
    outc = 1.0 if has_any(t, outcome_tokens) else 0.0
    design = 1.0 if set(rec.get("publication_types", [])) & primary_pubtypes else 0.0
    recn = year_recency(rec.get("year"), ymin)
    score = (w["pi"]*pi + w["outcome"]*outc + w["design"]*design + w["recency"]*recn + w["tfidf"]*float(tfidf_val))
    feats = {"pi":pi,"outcome":outc,"design":design,"recency":recn,"tfidf":float(tfidf_val)}
    return float(score), feats

def query_quality(sample_scores: List[float]) -> float:
    if not sample_scores: return 0.0
    # robust: median + 0.25*mean (caps very spiky distributions)
    return float(np.median(sample_scores) + 0.25*np.mean(sample_scores))

# =========================
# MeSH MINING (evidence-driven)
# =========================

INTERVENTION_PRIOR = {
    "Cryosurgery", "Anesthesia, Epidural", "Nerve Block", "Analgesia", "Intercostal Nerves",
    "Anesthesia, Conduction", "Pain Management", "Thoracic Surgery, Video-Assisted"
}
POPULATION_PRIOR = {
    "Pectus Excavatum", "Thoracoscopy", "Thoracic Wall", "Pectus Carinatum"
}

def mine_mesh(records: List[Dict], pop_terms: List[str], int_terms: List[str]) -> Dict:
    # Count MeSH frequency and correlate with P/I lexical hits (for weak role classification)
    freq = Counter(); majr = Counter()
    assoc_pop = Counter(); assoc_int = Counter()
    for r in records:
        t = norm_space((r["title"] or "") + " " + (r["abstract"] or ""))
        p_hit = has_any(t, pop_terms)
        i_hit = has_any(t, int_terms)
        for m in r.get("mesh", []):
            term = m["term"]
            freq[term]+=1
            if m.get("majr"): majr[term]+=1
            if p_hit: assoc_pop[term]+=1
            if i_hit: assoc_int[term]+=1
    # role score: prior + association differential + majr bonus
    candidates=[]
    for term, n in freq.items():
        ap = assoc_pop[term]; ai = assoc_int[term]
        diff = (ai - ap) / max(1.0, n)  # positive → intervention-ish
        prior_i = 1.0 if term in INTERVENTION_PRIOR else 0.0
        prior_p = 1.0 if term in POPULATION_PRIOR else 0.0
        role_raw = diff + 0.5*prior_i - 0.5*prior_p + 0.2*(majr[term]>0)
        candidates.append({
            "term": term,
            "freq": int(n),
            "majr": int(majr[term]),
            "assoc_pop": int(ap),
            "assoc_int": int(ai),
            "role_score": float(role_raw),
            "role": "I" if role_raw>0.2 else ("P" if role_raw<-0.2 else "U")
        })
    # rank
    candidates.sort(key=lambda x: (-x["freq"], -x["role_score"], x["term"]))
    # select shortlists
    top_p = [c for c in candidates if c["role"]=="P"][:12]
    top_i = [c for c in candidates if c["role"]=="I"][:12]
    return {
        "summary": {"unique": len(candidates)},
        "top_population": top_p,
        "top_intervention": top_i,
        "all": candidates
    }

# =========================
# TOKEN REFINEMENT (from evidence)
# =========================

def refine_tokens(pop_terms: List[str], int_terms: List[str], mesh_mining: Dict) -> Tuple[List[str], List[str], Dict]:
    # Add high-confidence MeSH-derived surface forms as lexical tokens (no MeSH qualifiers in final query)
    p_add = []
    for m in mesh_mining.get("top_population", []):
        t = m["term"]
        if t.lower() not in [x.lower() for x in pop_terms] and m["freq"]>=2:
            p_add.append(t)
    i_add = []
    for m in mesh_mining.get("top_intervention", []):
        t = m["term"]
        if t.lower() not in [x.lower() for x in int_terms] and m["freq"]>=2:
            i_add.append(t)

    # Normalize: keep concise tokens (1–3 words)
    def filt(tokens):
        out=[]
        for tok in tokens:
            w = tok.split()
            if 1 <= len(w) <= 4:
                out.append(tok)
        return out

    p_new = pop_terms + filt(p_add)
    i_new = int_terms + filt(i_add)

    # Deduplicate (case-insensitive), preserve order
    def dedup(seq):
        seen=set(); out=[]
        for x in seq:
            k=x.lower()
            if k not in seen:
                out.append(x); seen.add(k)
        return out

    p_final = dedup(p_new)
    i_final = dedup(i_new)

    meta = {"added_population": [t for t in p_final if t not in pop_terms],
            "added_intervention": [t for t in i_final if t not in int_terms]}
    return p_final, i_final, meta

# =========================
# PROBE / EVALUATE CANDIDATES
# =========================

@dataclass
class ProbeResult:
    query: str
    total_count: int
    sampled_ids: int
    rq: float
    stats: Dict[str,float]
    sample_rows: List[Dict]

def probe_query(query: str, pop_terms: List[str], int_terms: List[str], outcome_tokens: List[str],
                ymin: int, target_sample: int) -> ProbeResult:
    # esearch
    total, ids = esearch(query, mindate=ymin, retmax=target_sample)
    if total == 0 or not ids:
        return ProbeResult(query=query, total_count=total, sampled_ids=0, rq=0.0, stats={}, sample_rows=[])
    # efetch over ids
    xml = efetch_xml(ids, chunk=CONFIG["fetch_chunk"])
    recs = parse_pubmed_xml(xml)
    # tf-idf sims vs seed text
    seed_text = build_seed_query_text(pop_terms, int_terms)
    tfidf_vals = tfidf_sim(recs, seed_text)
    # score records
    ws = {"pi":CONFIG["w_pi"],"outcome":CONFIG["w_outcome"],"design":CONFIG["w_design"],"recency":CONFIG["w_recency"],"tfidf":CONFIG["w_tfidf"]}
    sample_scores=[]; rows=[]
    pi_hits=0; outcome_hits=0; design_hits=0
    for rec, tv in zip(recs, tfidf_vals):
        s, feats = record_score(rec, pop_terms, int_terms, outcome_tokens, tv, ymin, ws, CONFIG["primary_pubtypes"])
        sample_scores.append(s)
        pi_hits += feats["pi"]
        outcome_hits += feats["outcome"]
        design_hits += feats["design"]
        rows.append({
            "pmid": rec["pmid"],
            "year": rec["year"] or "",
            "language": rec["language"] or "",
            "pubtypes": ";".join(rec.get("publication_types", [])),
            "title": norm_space(rec["title"])[:180],
            "pi_hit": int(feats["pi"]),
            "outcome_hit": int(feats["outcome"]),
            "design_hit": int(feats["design"]),
            "recency": round(feats["recency"],3),
            "tfidf": round(feats["tfidf"],3),
            "score": round(s,3)
        })
    rq = query_quality(sample_scores)
    n = len(recs)
    stats_out = {
        "n_sample": n,
        "pi_rate": round(pi_hits/max(1,n),3),
        "outcome_rate": round(outcome_hits/max(1,n),3),
        "design_rate": round(design_hits/max(1,n),3),
        "median_score": round(float(np.median(sample_scores)) if sample_scores else 0.0,3),
        "mean_score": round(float(np.mean(sample_scores)) if sample_scores else 0.0,3)
    }
    return ProbeResult(query=query, total_count=total, sampled_ids=len(ids), rq=float(rq), stats=stats_out, sample_rows=rows)

def target_window(year_min: int) -> Tuple[int,int,int,int]:
    # make caps year-aware (older windows → allow a bit larger)
    years = max(1, 2025 - year_min)
    scale = 1.0 + min(1.0, years/15.0)*0.4  # up to +40%
    bmin = int(CONFIG["broad_target_min"]*scale)
    bmax = int(CONFIG["broad_target_max"]*scale)
    fmin = int(CONFIG["focused_target_min"]*scale)
    fmax = int(CONFIG["focused_target_max"]*scale)
    return bmin, bmax, fmin, fmax

def in_range(count: int, lo: int, hi: int) -> bool:
    return lo <= count <= hi

# =========================
# MAIN ORCHESTRATION
# =========================

def sniff_poc():
    out_dir = CONFIG["out_dir"]; ensure_out_dir(out_dir)
    year_min = CONFIG["year_min"]
    pop_terms = CONFIG["population_terms_seed"][:]
    int_terms = CONFIG["intervention_terms_seed"][:]
    outcomes  = CONFIG["outcome_tokens"][:]
    rct_hedge = CONFIG["rct_hedge"][:]
    bmin,bmax,fmin,fmax = target_window(year_min)

    # 1) Build B0 (lexical-only) and probe
    broad_0 = build_broad(pop_terms, int_terms)
    print(f"{now()}  B0 query:", broad_0)
    probe_log=[]
    pr0 = probe_query(broad_0, pop_terms, int_terms, outcomes, year_min, CONFIG["probe_ids"])
    probe_log.append({"stage":"B0","total":pr0.total_count,"rq":round(pr0.rq,3),"stats":pr0.stats})
    print(f"{now()}  B0 total={pr0.total_count} rq={pr0.rq:.3f} stats={pr0.stats}")

    # Write initial candidates tsv
    write_tsv(os.path.join(out_dir, "sniff_candidates_b0.tsv"),
              pr0.sample_rows, ["pmid","year","language","pubtypes","pi_hit","outcome_hit","design_hit","recency","tfidf","score","title"])

    # 2) MeSH mining from B0 sample
    # fetch fuller for mining if sample small
    ids_for_mesh = [r["pmid"] for r in pr0.sample_rows][:CONFIG["probe_ids"]]
    mesh_records = parse_pubmed_xml(efetch_xml(ids_for_mesh, chunk=CONFIG["fetch_chunk"])) if ids_for_mesh else []
    mesh_info = mine_mesh(mesh_records, pop_terms, int_terms)
    save_json(os.path.join(out_dir, "mesh_mining.json"), mesh_info)

    # 3) Token refinement (deterministic, MeSH → lexical surface)
    pop_terms2, int_terms2, meta_add = refine_tokens(pop_terms, int_terms, mesh_info)
    save_json(os.path.join(out_dir, "token_refinement.json"), {"before":{"P":pop_terms,"I":int_terms},"after":{"P":pop_terms2,"I":int_terms2},"added":meta_add})
    pop_terms = pop_terms2; int_terms = int_terms2

    # 4) Candidate generation & evaluation loop (a few deterministic variants)
    candidates = []
    # base refined broad
    candidates.append(("broad_refined", build_broad(pop_terms, int_terms)))
    # intervention-tight (favor cryo forms)
    tight_I = [t for t in int_terms if "cryo" in t.lower() or "intercostal nerve cryoablation" in t.lower() or t.lower()=="inc"]
    if len(tight_I)>=1:
        candidates.append(("broad_I_tight", build_broad(pop_terms, tight_I)))
    # population-tight (drop MIRPE/Nuss if too noisy? keep core PE phrase only)
    if "pectus excavatum" in [p.lower() for p in pop_terms]:
        p_core = ["pectus excavatum"]
        candidates.append(("broad_P_core", build_broad(p_core, int_terms)))

    trials=[]
    best_broad = None
    best_broad_rq = -1.0

    for name, q in candidates:
        pr = probe_query(q, pop_terms, int_terms, outcomes, year_min, CONFIG["probe_ids"])
        probe_log.append({"stage":name,"total":pr.total_count,"rq":round(pr.rq,3),"stats":pr.stats})
        trials.append({"name": name, "query": q, "total": pr.total_count, "rq": pr.rq, "stats": pr.stats})
        print(f"{now()}  {name} total={pr.total_count} rq={pr.rq:.3f} stats={pr.stats}")
        # choose best broad by rq within soft window; if none in-range, pick highest rq anyway
        if pr.rq > best_broad_rq and (in_range(pr.total_count, bmin, bmax) or best_broad is None):
            best_broad = (q, pr)
            best_broad_rq = pr.rq

    # fallback to B0 if everything worse
    if best_broad is None or (best_broad[1].rq < pr0.rq and in_range(pr0.total_count, bmin, bmax)):
        best_broad = (broad_0, pr0)

    # 5) Build focused from chosen broad (lexical hedge) + probe
    focused_q = build_focused_from(best_broad[0], rct_hedge)
    pr_f = probe_query(focused_q, pop_terms, int_terms, outcomes, year_min, CONFIG["probe_ids"])
    probe_log.append({"stage":"focused_lex", "total":pr_f.total_count, "rq":round(pr_f.rq,3), "stats":pr_f.stats})
    print(f"{now()}  focused total={pr_f.total_count} rq={pr_f.rq:.3f} stats={pr_f.stats}")

    # 6) Decide accept / minor repair
    retrieval_plan = {"broad": best_broad[0], "focused": focused_q}
    plan_meta = {
        "broad_eval": {"total": best_broad[1].total_count, "rq": best_broad[1].rq, "stats": best_broad[1].stats,
                       "window_ok": in_range(best_broad[1].total_count, bmin, bmax), "target_window": [bmin,bmax]},
        "focused_eval": {"total": pr_f.total_count, "rq": pr_f.rq, "stats": pr_f.stats,
                         "window_ok": in_range(pr_f.total_count, fmin, fmax), "target_window": [fmin,fmax]},
        "probe_log": probe_log
    }

    # 7) Write artifacts
    save_json(os.path.join(out_dir, "retrieval_plan.json"), retrieval_plan)
    save_json(os.path.join(out_dir, "retrieval_plan_eval.json"), plan_meta)
    write_tsv(os.path.join(out_dir, "sniff_candidates_best_broad.tsv"),
              best_broad[1].sample_rows, ["pmid","year","language","pubtypes","pi_hit","outcome_hit","design_hit","recency","tfidf","score","title"])
    write_tsv(os.path.join(out_dir, "sniff_candidates_focused.tsv"),
              pr_f.sample_rows, ["pmid","year","language","pubtypes","pi_hit","outcome_hit","design_hit","recency","tfidf","score","title"])

    # 8) Console summary
    print("\n=== SUMMARY ===")
    print("Chosen BROAD:", retrieval_plan["broad"])
    print("  total:", plan_meta["broad_eval"]["total"], "rq:", round(plan_meta["broad_eval"]["rq"],3), "window_ok:", plan_meta["broad_eval"]["window_ok"])
    print("Chosen FOCUSED:", retrieval_plan["focused"])
    print("  total:", plan_meta["focused_eval"]["total"], "rq:", round(plan_meta["focused_eval"]["rq"],3), "window_ok:", plan_meta["focused_eval"]["window_ok"])
    print(f"Artifacts written under: {out_dir}")

# Uncomment to run in notebook immediately:
sniff_poc()


22:55:59  B0 query: ("pectus excavatum"[tiab] OR Nuss[tiab] OR MIRPE[tiab]) AND ("intercostal nerve cryoablation"[tiab] OR cryoanalgesia[tiab] OR cryoablation[tiab] OR INC[tiab] OR cryotherapy[tiab])
22:56:02  B0 total=108 rq=4.049 stats={'n_sample': 80, 'pi_rate': 0.988, 'outcome_rate': 0.925, 'design_rate': 0.075, 'median_score': 3.248, 'mean_score': 3.203}
22:56:05  broad_refined total=542 rq=4.099 stats={'n_sample': 78, 'pi_rate': 0.885, 'outcome_rate': 0.872, 'design_rate': 0.192, 'median_score': 3.316, 'mean_score': 3.133}
22:56:07  broad_I_tight total=140 rq=4.053 stats={'n_sample': 80, 'pi_rate': 0.963, 'outcome_rate': 0.925, 'design_rate': 0.075, 'median_score': 3.259, 'mean_score': 3.175}
22:56:10  broad_P_core total=231 rq=4.097 stats={'n_sample': 80, 'pi_rate': 0.988, 'outcome_rate': 0.95, 'design_rate': 0.087, 'median_score': 3.279, 'mean_score': 3.272}
22:56:13  focused total=112 rq=3.993 stats={'n_sample': 79, 'pi_rate': 0.848, 'outcome_rate': 0.937, 'design_rate': 0.418

In [20]:
# SNIPPET: end-to-end "sniff" PoC with semantic LLM oversight (single cell)
# - Deterministic retrieval + query assembly
# - Qwen 4B-thinking for: term extraction, MeSH role-tagging, reprompts
# - Gemma-mini for: cheap title/abstract sanity screen of samples
# - Writes artifacts into sniff_poc_out/
#
# CONFIGURE before running:
#   - LM Studio at http://127.0.0.1:1234
#   - QWEN_MODEL and GEMMA_MODEL
#   - ENTREZ_EMAIL (and optionally ENTREZ_API_KEY via env)
#
# Usage:
#   1) Paste your natural-language question into USER_NLQ below.
#   2) Run the cell. Inspect printed summary + files in sniff_poc_out/.

import os, json, time, re, textwrap, random, pathlib
from collections import Counter, defaultdict
import requests
from xml.etree import ElementTree as ET

# ----------------------------
# Config
# ----------------------------
LMSTUDIO_BASE = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234")
QWEN_MODEL    = os.getenv("QWEN_MODEL", "qwen/qwen3-4b")
GEMMA_MODEL   = os.getenv("GEMMA_MODEL", "gemma-3n-e2b-it")

ENTREZ_EMAIL   = os.getenv("ENTREZ_EMAIL", "you@example.com")
ENTREZ_API_KEY = os.getenv("ENTREZ_API_KEY", "")
HTTP_TIMEOUT   = int(os.getenv("HTTP_TIMEOUT", "300"))

OUT_DIR = pathlib.Path("sniff_poc_out")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Windows / caps (soft)
YEAR_MIN_DEFAULT = 2015
BROAD_TARGET = (50, 5000)       # ok 10–10k
FOCUSED_TARGET = (3, 500)       # ok 1–2000
BROAD_OK = (10, 10000)
FOCUSED_OK = (1, 2000)

SAMPLE_N = 5                   # sample size for T/A sanity screen per candidate
RCT_HEDGE_LEX = '(randomized[tiab] OR randomised[tiab] OR randomization[tiab] OR "random allocation"[tiab])'

# ----------------------------
# Helpers: LM chat + JSON fences (+ robust JSON repair)
# ----------------------------
def lm_chat(model: str, system: str, user: str, temperature=0.0, max_tokens=8000, response_format=None, stop=None):
    url = f"{LMSTUDIO_BASE.rstrip('/')}/v1/chat/completions"
    body = {
        "model": model,
        "messages": [{"role":"system","content":system},{"role":"user","content":user}],
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
        "stream": False
    }
    if response_format is not None:
        body["response_format"] = response_format
    if stop is not None:
        body["stop"] = stop
    r = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]


_BEGIN = re.compile(r"BEGIN_JSON\s*", re.I)
_END   = re.compile(r"\s*END_JSON", re.I)
FENCE  = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.I)

def _sanitize_json_str(s: str) -> str:
    # normalize curly quotes; remove trailing commas before } or ]
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2018","'").replace("\u2019","'")
    s = re.sub(r",\s*(\}|\])", r"\1", s)
    return s.strip()

def extract_json_block_or_fence(txt: str) -> str:
    # prefer BEGIN_JSON...END_JSON
    blocks = []
    pos=0
    while True:
        m1 = _BEGIN.search(txt, pos)
        if not m1: break
        m2 = _END.search(txt, m1.end())
        if not m2: break
        blocks.append(txt[m1.end():m2.start()])
        pos = m2.end()
    if blocks:
        return _sanitize_json_str(blocks[-1])

    # then fenced code
    fences = FENCE.findall(txt)
    if fences:
        return _sanitize_json_str(fences[-1])

    # last {...} by brace scan (balanced)
    s = txt
    last_obj = None
    stack = 0; start = None
    for i,ch in enumerate(s):
        if ch == '{':
            if stack == 0: start = i
            stack += 1
        elif ch == '}':
            if stack > 0:
                stack -= 1
                if stack == 0 and start is not None:
                    last_obj = s[start:i+1]
    if last_obj:
        return _sanitize_json_str(last_obj)
    raise ValueError("No JSON-like content found")

# Minimal JSON-repair via LLM against a template
REPAIR_SYSTEM = "You repair malformed JSON to exactly match the given template keys. Return ONLY one JSON object between BEGIN_JSON/END_JSON."
def repair_user(template_json: str, bad_output: str) -> str:
    return f"""TEMPLATE_JSON:
{template_json}

BAD_OUTPUT:
{bad_output}

TASK: Output valid JSON matching TEMPLATE_JSON keys (fill missing with empty arrays/strings). No prose.

BEGIN_JSON
{{}}
END_JSON
"""

TERMS_TEMPLATE   = {"population":[],"intervention":[],"comparators":[],"outcomes":[],"must_have":[],"avoid":[]}
MESH_TAG_TEMPLATE= {"labels":[{"mesh":"","role":"G","keep":False,"why":""}]}
PASSA_TEMPLATE   = {"pmid":"","decision":"","reason":"","confidence":0.0,"population_quote":"","intervention_quote":""}

STRICT_JSON_RULES = (
  "Return ONLY one JSON object. No analysis, no preface, no notes. "
  "Wrap it EXACTLY with:\nBEGIN_JSON\n{...}\nEND_JSON"
)

def ask_json_strict(model: str, system: str, user: str, template: dict, max_tokens=8000):
    # 1) Try a strict call that forbids prose and stops at END_JSON
    user_strict = f"{user}\n\n{STRICT_JSON_RULES}"
    raw = lm_chat(model, system, user_strict, temperature=0.0, max_tokens=max_tokens, stop=["END_JSON"])
    try:
        return json.loads(extract_json_block_or_fence(raw))
    except Exception:
        # 2) Repair pass: same strictness + template
        repaired = lm_chat(
            model,
            REPAIR_SYSTEM,
            repair_user(json.dumps(template, ensure_ascii=False, indent=2), raw) + "\n\n" + STRICT_JSON_RULES,
            temperature=0.0,
            max_tokens=max_tokens,
            stop=["END_JSON"]
        )
        return json.loads(extract_json_block_or_fence(repaired))


def ask_json(model: str, system: str, user: str, template: dict, max_tokens=8000):
    raw = lm_chat(model, system, user, temperature=0.0, max_tokens=max_tokens)
    try:
        return json.loads(extract_json_block_or_fence(raw))
    except Exception:
        # attempt repair
        repaired = lm_chat(
            model,
            REPAIR_SYSTEM,
            repair_user(json.dumps(template, ensure_ascii=False, indent=2), raw),
            temperature=0.0,
            max_tokens=max_tokens
        )
        return json.loads(extract_json_block_or_fence(repaired))
    
ask_json = ask_json_strict

# ----------------------------
# PubMed E-utilities (esearch/efetch)
# ----------------------------
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
HEADERS = {"User-Agent": "sniff-poc/0.1 (+local)", "Accept": "application/json"}

def esearch_count_and_ids(term: str, mindate: int|None):
    p = {
        "db":"pubmed","retmode":"json","term":term,"retmax":5000,
        "email":ENTREZ_EMAIL,"usehistory":"y"
    }
    if ENTREZ_API_KEY: p["api_key"]=ENTREZ_API_KEY
    if mindate: p["mindate"]=str(mindate)
    r = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params=p, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    js = r.json().get("esearchresult", {})
    count = int(js.get("count","0"))
    webenv = js.get("webenv")
    qk = js.get("querykey")
    if not count or not webenv or not qk:
        return 0, []
    # fetch up to 5k IDs (enough for sniff)
    r2 = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params={
        "db":"pubmed","retmode":"json","retmax":5000,"retstart":0,"email":ENTREZ_EMAIL,
        "WebEnv":webenv,"query_key":qk, **({"api_key":ENTREZ_API_KEY} if ENTREZ_API_KEY else {})
    }, timeout=HTTP_TIMEOUT)
    r2.raise_for_status()
    ids = r2.json().get("esearchresult",{}).get("idlist",[])
    return count, [str(x) for x in ids]

def efetch_xml(pmids):
    if not pmids: return ""
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(pmids),"email":ENTREZ_EMAIL}
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/efetch.fcgi", headers={"User-Agent": "sniff-poc/0.1"}, params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.text

def parse_pubmed_xml(xml_text: str):
    out = []
    if not xml_text.strip(): return out
    root = ET.fromstring(xml_text)
    def _join(node): 
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = art.findtext(".//PMID") or ""
        title = _join(art.find(".//ArticleTitle")).strip()
        abs_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(_join(n).strip() for n in abs_nodes) if abs_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\d{4}", s)
                if m: year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh = [mh.findtext("./DescriptorName") for mh in art.findall(".//MeshHeadingList/MeshHeading") if mh.findtext("./DescriptorName")]
        out.append({
            "pmid": pmid, "title": title, "abstract": abstract, "year": year, "language": lang,
            "publication_types": pubtypes, "mesh": mesh
        })
    return out

# ----------------------------
# LLM prompts (semantic steps)
# ----------------------------
TERMS_SYSTEM = "You extract controlled, compact term lists for biomedical retrieval. Return strict JSON only."
def terms_user(nlq: str):
    return f"""From the natural-language question below, produce compact term lists.

NATURAL_LANGUAGE_QUESTION:
<<<
{nlq}
>>>

Rules:
- Return JSON with arrays of P/I/C/O strings: {{ "population":[], "intervention":[], "comparators":[], "outcomes":[] }}
- Strings must be concise phrases (no boolean, no field tags, no quotes/brackets).
- Include common synonyms and acronyms (e.g., Nuss, MIRPE, cryoanalgesia).
- Add 2–5 must_have tokens in "must_have" that anchor topicality (e.g., MIRPE, Nuss, cryoablation).
- Add 2–5 avoid tokens in "avoid" if obvious confounders (e.g., pediatric oncology if off-topic).
- Keep each list ≤ 12 items.

Return ONLY:

BEGIN_JSON
{{...}}
END_JSON
"""

MESH_TAG_SYSTEM = "You classify MeSH descriptors into roles relative to a PICOS."
def mesh_tag_user(p_terms, i_terms, descriptors):
    return f"""Classify each MeSH descriptor as one of: P (population/procedure context), I (intervention/analgesia), O (outcome), C (comparator/technique), G (generic context), X (irrelevant).
Also provide a 'keep' boolean (true if useful for building search), and a 1-line rationale.

P_TERMS = {p_terms}
I_TERMS = {i_terms}

DESCRIPTORS = {descriptors}

Return ONLY:

BEGIN_JSON
{{ "labels": [{{"mesh":"...", "role":"P|I|O|C|G|X", "keep": true|false, "why": "..."}}] }}
END_JSON
"""

PASSA_SYS = "You are a strict PRISMA title/abstract screener for effects triage. Return JSON only."
def passa_user(proto_p, proto_i, proto_outcomes, record):
    return f"""Protocol (simplified):
Population: {proto_p}
Intervention: {proto_i}
Outcomes (signals): {proto_outcomes}
Include primary/comparative human studies; exclude admin/guidelines.

Record:
PMID: {record['pmid']}
Title: {record['title']}
Abstract: {record['abstract']}
PubTypes: {record['publication_types']}
Year: {record['year']}
Lang: {record['language']}

Return:

BEGIN_JSON
{{"pmid":"{record['pmid']}",
  "decision":"include|borderline|exclude",
  "reason":"population_mismatch|intervention_mismatch|design_ineligible|off_topic|language|year|insufficient_info",
  "confidence": 0.0,
  "population_quote":"", "intervention_quote":""
}}
END_JSON
"""

REPROMPT_SYS = "You write crisp, actionable reprompts (≤2 sentences) to fix information gaps."
def reprompt_user(summary_problem: str):
    return f"""Context of failure:
{summary_problem}

Write ≤2 sentences telling the user exactly what to clarify or relax (no noise).

Return:

BEGIN_JSON
{{"reprompt":"..."}}
END_JSON
"""

# ----------------------------
# Deterministic: build queries from term lists
# ----------------------------
def or_block(terms, field="tiab"):
    toks=[]
    for t in terms:
        t=t.strip()
        if not t: continue
        if " " in t or "-" in t:
            toks.append(f"\"{t}\"[{field}]")
        else:
            toks.append(f"{t}[{field}]")
    if not toks: return ""
    return "(" + " OR ".join(toks) + ")"

def build_broad(p_syn, i_syn, extra=None, field="tiab"):
    P = or_block(p_syn, field)
    I = or_block(i_syn, field)
    X = (" AND " + or_block(extra, field)) if extra else ""
    if not P or not I:
        return None
    return f"{P} AND {I}{X}"

def build_focused(broad_core):
    return f"({broad_core}) AND {RCT_HEDGE_LEX}"

# ----------------------------
# Lexical stats + cheap “quality” score
# ----------------------------
PRIMARY_HINTS = {"Randomized Controlled Trial","Clinical Trial","Controlled Clinical Trial",
                 "Prospective Studies","Cohort Studies","Case-Control Studies","Comparative Study"}

def lexical_stats(records, p_terms, i_terms, outcomes):
    def hits(text, terms): 
        tl=(text or "").lower()
        return sum(1 for t in terms if t and t.lower() in tl)
    n = min(SAMPLE_N, len(records))
    sample = records[:n]
    pi_rate=0; out_rate=0; design_rate=0; scores=[]
    for r in sample:
        t = (r['title'] or "") + "\n" + (r['abstract'] or "")
        pi = (hits(t, p_terms)+hits(t, i_terms))>0
        po = hits(t, outcomes)>0
        de = len(set(r['publication_types']) & PRIMARY_HINTS)>0
        pi_rate += 1 if pi else 0
        out_rate+= 1 if po else 0
        design_rate+= 1 if de else 0
        # quick score: weighted signals
        s = (2.0*(1 if pi else 0) + 1.0*(1 if po else 0) + 0.5*(1 if de else 0))
        scores.append(s)
    if n==0: 
        return {"n_sample":0,"pi_rate":0,"outcome_rate":0,"design_rate":0,"median_score":0,"mean_score":0}
    scores.sort()
    med = scores[n//2]
    mean = sum(scores)/n
    return {"n_sample":n,"pi_rate":round(pi_rate/n,3),"outcome_rate":round(out_rate/n,3),
            "design_rate":round(design_rate/n,3),"median_score":round(med,3),"mean_score":round(mean,3)}

def rq_quality(stats):
    # emphasize central tendency and P/I presence
    return round( stats["median_score"] + 0.25*stats["mean_score"] + 0.5*stats["pi_rate"] + 0.25*stats["outcome_rate"], 3)

# ----------------------------
# Main SNiff runner
# ----------------------------
def sniff_nlq(USER_NLQ: str, year_min: int = YEAR_MIN_DEFAULT):
    # 1) Qwen: extract seed term lists
    terms_js = ask_json(QWEN_MODEL, TERMS_SYSTEM, terms_user(USER_NLQ), TERMS_TEMPLATE, max_tokens=8000)
    P0 = terms_js.get("population", []) or []
    I0 = terms_js.get("intervention", []) or []
    O0 = terms_js.get("outcomes", []) or []
    MUST = terms_js.get("must_have", []) or []
    AVOID = terms_js.get("avoid", []) or []
    # Persist
    (OUT_DIR/"seed_terms.json").write_text(json.dumps(terms_js, indent=2, ensure_ascii=False), encoding="utf-8")

    # 2) First broad (pure TIAB, P ∧ I)
    candidates = []
    def try_query(name, q):
        cnt, ids = esearch_count_and_ids(q, year_min)
        # sample up to SAMPLE_N for eval
        xml = efetch_xml(ids[:SAMPLE_N])
        recs = parse_pubmed_xml(xml)
        stats = lexical_stats(recs, P0, I0, O0)
        rq = rq_quality(stats)
        candidates.append({"name":name,"query":q,"total":cnt,"stats":stats,"rq":rq,"ids":ids})
        print(f"{time.strftime('%H:%M:%S')}  {name} total={cnt} rq={rq} stats={stats}")
        return cnt, recs

    q_b0 = build_broad(P0, I0)
    if not q_b0:
        rp = ask_json(lm_chat(QWEN_MODEL, REPROMPT_SYS, reprompt_user("Failed to construct core P and I term groups from your question.")))
        raise SystemExit("REPROMPT: " + rp.get("reprompt","need clarification"))
    print(f"{time.strftime('%H:%M:%S')}  B0 query: {q_b0}")
    _, recs_b0 = try_query("B0", q_b0)

    # 3) Mine MeSH from real hits
    mesh_all = Counter()
    for r in recs_b0:
        for m in r.get("mesh", []) or []:
            mesh_all[m] += 1
    top_mesh = [m for m,_ in mesh_all.most_common(40)]
    (OUT_DIR/"mesh_raw.json").write_text(json.dumps({"top_mesh":top_mesh, "counts":mesh_all.most_common(100)}, indent=2, ensure_ascii=False), encoding="utf-8")

    # 4) Qwen: role-tag mined MeSH relative to P/I
    mesh_tag_js = ask_json(QWEN_MODEL, MESH_TAG_SYSTEM, mesh_tag_user(P0, I0, top_mesh), MESH_TAG_TEMPLATE, max_tokens=8000)
    (OUT_DIR/"mesh_tagged.json").write_text(json.dumps(mesh_tag_js, indent=2, ensure_ascii=False), encoding="utf-8")
    keep_mesh = [x["mesh"] for x in mesh_tag_js.get("labels", []) if x.get("keep") and x.get("role") in ("P","I","O")]
    keepP = [x["mesh"] for x in mesh_tag_js.get("labels", []) if x.get("keep") and x.get("role")=="P"]
    keepI = [x["mesh"] for x in mesh_tag_js.get("labels", []) if x.get("keep") and x.get("role")=="I"]

    # 5) Generate refined broads (TIAB only), mixing mined words (as surface tokens) — no fielded MeSH
    #    (We use MeSH words as plain tokens, respecting your “broad shouldn’t rely on field qualifiers”.)
    def expand_terms(base, extra):
        seen=set([b.lower() for b in base])
        out=list(base)
        for e in extra:
            w=e.strip()
            if not w: continue
            if w.lower() not in seen:
                out.append(w)
                seen.add(w.lower())
        return out[:12]

    P_core = expand_terms(P0, keepP[:6])
    I_core = expand_terms(I0, keepI[:8])
    # Several candidate mixes
    variants = [
        ("broad_refined", build_broad(P_core, I_core, extra=MUST)),
        ("broad_I_tight", build_broad(P0, I_core)),
        ("broad_P_core",  build_broad(P_core, I0)),
    ]
    for name, q in variants:
        if q:
            try_query(name, q)

    # choose broad: best rq within BROAD_TARGET else within BROAD_OK
    def choose_broad():
        cands = [c for c in candidates if c["name"].startswith("broad") or c["name"]=="B0"]
        # prefer in-target window
        in_target = [c for c in cands if BROAD_TARGET[0] <= c["total"] <= BROAD_TARGET[1]]
        pool = in_target if in_target else [c for c in cands if BROAD_OK[0] <= c["total"] <= BROAD_OK[1]]
        if not pool:
            return max(cands, key=lambda x: x["rq"]) if cands else None
        return max(pool, key=lambda x: x["rq"])
    chosen_broad = choose_broad()
    if not chosen_broad:
        rp = ask_json(lm_chat(QWEN_MODEL, REPROMPT_SYS, reprompt_user("No broad query produced viable hit counts or relevance.")))
        (OUT_DIR/"reprompt.json").write_text(json.dumps(rp, indent=2, ensure_ascii=False), encoding="utf-8")
        print("\nREPROMPT:", rp.get("reprompt","need clarification"))
        return

    # 6) Build focused: lexical RCT hedge layered on chosen broad
    q_focused = build_focused(chosen_broad["query"])
    cntF, idsF = esearch_count_and_ids(q_focused, year_min)
    xmlF = efetch_xml(idsF[:SAMPLE_N])
    recsF = parse_pubmed_xml(xmlF)
    statsF = lexical_stats(recsF, P0, I0, O0); rqF = rq_quality(statsF)
    candidates.append({"name":"focused","query":q_focused,"total":cntF,"stats":statsF,"rq":rqF,"ids":idsF})
    print(f"{time.strftime('%H:%M:%S')}  focused total={cntF} rq={rqF} stats={statsF}")

    # 7) Gemma: cheap T/A sanity screen on samples (both chosen sets)
    def sanity_screen(records):
        if not records: return {"include":0,"borderline":0,"exclude":0,"n":0}
        inc=bor=exc=0
        for r in records[:min(SAMPLE_N,len(records))]:
            js = ask_json(GEMMA_MODEL, PASSA_SYS, passa_user(P0, I0, O0, r), PASSA_TEMPLATE, max_tokens=8000)
            d = (js.get("decision","") or "").lower()
            if d=="include": inc+=1
            elif d=="borderline": bor+=1
            else: exc+=1
        n=inc+bor+exc
        return {"include":inc,"borderline":bor,"exclude":exc,"n":n}

    chosen_broad_sample = parse_pubmed_xml(efetch_xml(chosen_broad["ids"][:SAMPLE_N]))
    sanity_broad = sanity_screen(chosen_broad_sample)
    sanity_focused = sanity_screen(recsF)

    # 8) Persist artifacts
    artifacts = {
        "nlq": USER_NLQ,
        "year_min": year_min,
        "seed_terms": terms_js,
        "mesh_top": top_mesh,
        "mesh_tagged": mesh_tag_js,
        "candidates": [
            {k:v for k,v in c.items() if k not in ("ids",)} for c in candidates
        ],
        "chosen": {
            "broad": {"query": chosen_broad["query"], "total": chosen_broad["total"], "rq": chosen_broad["rq"], "stats": chosen_broad["stats"], "sanity": sanity_broad},
            "focused": {"query": q_focused, "total": cntF, "rq": rqF, "stats": statsF, "sanity": sanity_focused},
        }
    }
    (OUT_DIR/"sniff_artifacts.json").write_text(json.dumps(artifacts, indent=2, ensure_ascii=False), encoding="utf-8")
    (OUT_DIR/"broad.txt").write_text(chosen_broad["query"], encoding="utf-8")
    (OUT_DIR/"focused.txt").write_text(q_focused, encoding="utf-8")

    # 9) Print summary
    print("\n=== SUMMARY ===")
    print("Chosen BROAD:", chosen_broad["query"])
    print(f"  total: {chosen_broad['total']} rq: {chosen_broad['rq']} window_ok: {BROAD_OK[0] <= chosen_broad['total'] <= BROAD_OK[1]}")
    print("Chosen FOCUSED:", q_focused)
    print(f"  total: {cntF} rq: {rqF} window_ok: {FOCUSED_OK[0] <= cntF <= FOCUSED_OK[1]}")
    print("Artifacts written under:", OUT_DIR)

# ----------------------------
# RUN: put your NLQ here
# ----------------------------
USER_NLQ = """Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE). Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE (the intervention of interest is INC, not the surgery). Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, or systemic multimodal analgesia. Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days. Study designs = RCTs preferred; if RCTs absent, include comparative cohort/case-control. Year_min = 2015. Languages = English, Portuguese, Spanish."""
sniff_nlq(USER_NLQ, year_min=YEAR_MIN_DEFAULT)


02:15:55  B0 query: (adults[tiab] OR "pectus excavatum"[tiab] OR "Nuss surgery"[tiab] OR MIRPE[tiab] OR "minimally invasive repair"[tiab]) AND ("intercostal nerve cryoablation (INC)"[tiab] OR cryoanalgesia[tiab] OR "intraoperative analgesia"[tiab])
02:15:59  B0 total=53 rq=4.325 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.6, 'design_rate': 0.2, 'median_score': 3.0, 'mean_score': 2.7}
02:18:42  broad_refined total=1037 rq=2.7 stats={'n_sample': 5, 'pi_rate': 0.6, 'outcome_rate': 0.2, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 1.4}
02:18:48  broad_I_tight total=3309 rq=4.3 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.6, 'design_rate': 0.0, 'median_score': 3.0, 'mean_score': 2.6}
02:18:50  broad_P_core total=96 rq=2.9 stats={'n_sample': 5, 'pi_rate': 0.8, 'outcome_rate': 0.2, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 1.8}
02:18:54  focused total=10 rq=3.775 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.4, 'design_rate': 0.6, 'media

In [21]:
# SNIPPET: end-to-end "sniff" with concise report + idle model eviction
# - Deterministic retrieval + query assembly
# - Qwen 4B-thinking: term extraction, MeSH role-tagging, reprompts
# - Gemma-mini: quick title/abstract triage for chosen BROAD & FOCUSED
# - Console-first reporting; minimal artifacts (optional)
#
# CONFIGURE (env):
#   LMSTUDIO_BASE=http://127.0.0.1:1234
#   QWEN_MODEL=qwen/qwen3-4b
#   GEMMA_MODEL=gemma-3n-e2b-it
#   ENTREZ_EMAIL=you@example.com
#   ENTREZ_API_KEY=...
#   HTTP_TIMEOUT=300
#   SAMPLE_N=5
#   REPORT_TOP_K=8
#   WRITE_ARTIFACTS=1
#   LMSTUDIO_EVICT_SECS=5   (try to unload models after idle; best-effort)
#
# Usage:
#   1) Put NLQ into USER_NLQ at the bottom.
#   2) Run cell; read the printed report.

import os, json, time, re, threading, pathlib, requests
from collections import Counter
from xml.etree import ElementTree as ET

# ----------------------------
# Config
# ----------------------------
LMSTUDIO_BASE = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234").rstrip("/")
QWEN_MODEL    = os.getenv("QWEN_MODEL", "qwen/qwen3-4b")
GEMMA_MODEL   = os.getenv("GEMMA_MODEL", "gemma-3n-e2b-it")

ENTREZ_EMAIL   = os.getenv("ENTREZ_EMAIL", "you@example.com")
ENTREZ_API_KEY = os.getenv("ENTREZ_API_KEY", "")
HTTP_TIMEOUT   = int(os.getenv("HTTP_TIMEOUT", "300"))

SAMPLE_N       = int(os.getenv("SAMPLE_N", "5"))
REPORT_TOP_K   = int(os.getenv("REPORT_TOP_K", "8"))
WRITE_ARTIFACTS= bool(int(os.getenv("WRITE_ARTIFACTS", "1")))

# Windows / caps
YEAR_MIN_DEFAULT = 2015
BROAD_TARGET  = (50, 5000)     # sweet spot
FOCUSED_TARGET= (3, 500)
BROAD_OK      = (10, 10000)
FOCUSED_OK    = (1, 2000)

RCT_HEDGE_LEX = '(randomized[tiab] OR randomised[tiab] OR randomization[tiab] OR "random allocation"[tiab])'

OUT_DIR = pathlib.Path("sniff_poc_out"); OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# LM Studio idle evictor (best-effort)
# ----------------------------
_LM_LAST_USE = time.time()
_LM_EVICT_SECS = int(os.getenv("LMSTUDIO_EVICT_SECS", "5"))
_LM_EVICTOR_STOP = False

def _lmstudio_list_models():
    try:
        r = requests.get(f"{LMSTUDIO_BASE}/v1/models", timeout=10)
        r.raise_for_status()
        return r.json().get("data", [])
    except Exception:
        return []

def _lmstudio_try_unload(model_id: str):
    # Try a few likely endpoints; ignore failures (LM Studio versions vary)
    endpoints = [
        ("POST", f"{LMSTUDIO_BASE}/v1/unload", {"model": model_id}),
        ("POST", f"{LMSTUDIO_BASE}/v1/models/unload", {"model": model_id}),
        ("POST", f"{LMSTUDIO_BASE}/v1/models/unload_all", {}),
        ("POST", f"{LMSTUDIO_BASE}/unload", {"model": model_id}),
    ]
    for method, url, payload in endpoints:
        try:
            if method == "POST":
                rr = requests.post(url, json=payload, timeout=5)
            else:
                rr = requests.get(url, params=payload, timeout=5)
            if rr.status_code < 400:
                print(f"[LM STUDIO] Unload request OK: {url} ({model_id})")
                return True
        except Exception:
            pass
    return False

def _lmstudio_idle_evictor():
    # Periodically checks for idle and tries to unload loaded models
    while not _LM_EVICTOR_STOP:
        try:
            if _LM_EVICT_SECS > 0 and (time.time() - _LM_LAST_USE) >= _LM_EVICT_SECS:
                models = _lmstudio_list_models()
                any_loaded = False
                for m in models:
                    mid = m.get("id") or m.get("name") or ""
                    loaded = m.get("loaded") or m.get("isLoaded") or False
                    if mid and loaded:
                        any_loaded = True
                        _lmstudio_try_unload(mid)
                if any_loaded:
                    # give a small grace; then reset last-use so we don't spam
                    time.sleep(1.0)
                    globals()['_LM_LAST_USE'] = time.time()
            time.sleep(1.0)
        except Exception:
            time.sleep(2.0)

if _LM_EVICT_SECS > 0:
    _t = threading.Thread(target=_lmstudio_idle_evictor, daemon=True)
    _t.start()

def _touch_lm():
    globals()['_LM_LAST_USE'] = time.time()

# ----------------------------
# Helpers: LM chat + JSON fences (+ robust JSON repair)
# ----------------------------
def lm_chat(model: str, system: str, user: str, temperature=0.0, max_tokens=None, response_format=None, stop=None):
    _touch_lm()
    url = f"{LMSTUDIO_BASE}/v1/chat/completions"
    body = {
        "model": model,
        "messages": [{"role":"system","content":system},{"role":"user","content":user}],
        "temperature": float(temperature),
        "stream": False
    }
    if max_tokens is not None:
        body["max_tokens"] = int(max_tokens)
    if response_format is not None:
        body["response_format"] = response_format
    if stop is not None:
        body["stop"] = stop
    r = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    _touch_lm()
    return r.json()["choices"][0]["message"]["content"]

_BEGIN = re.compile(r"BEGIN_JSON\s*", re.I)
_END   = re.compile(r"\s*END_JSON", re.I)
FENCE  = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.I)

def _sanitize_json_str(s: str) -> str:
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2018","'").replace("\u2019","'")
    s = re.sub(r",\s*(\}|\])", r"\1", s)
    return s.strip()

def extract_json_block_or_fence(txt: str) -> str:
    # prefer last BEGIN_JSON ... END_JSON block
    blocks, pos = [], 0
    while True:
        m1 = _BEGIN.search(txt, pos)
        if not m1: break
        m2 = _END.search(txt, m1.end())
        if not m2: break
        blocks.append(txt[m1.end():m2.start()])
        pos = m2.end()
    if blocks:
        return _sanitize_json_str(blocks[-1])
    fences = FENCE.findall(txt)
    if fences:
        return _sanitize_json_str(fences[-1])
    # last balanced {...}
    s = txt; last_obj=None; stack=0; start=None
    for i,ch in enumerate(s):
        if ch == '{':
            if stack == 0: start = i
            stack += 1
        elif ch == '}':
            if stack > 0:
                stack -= 1
                if stack == 0 and start is not None:
                    last_obj = s[start:i+1]
    if last_obj:
        return _sanitize_json_str(last_obj)
    raise ValueError("No JSON-like content found")

REPAIR_SYSTEM = "You repair malformed JSON to exactly match the given template keys. Return ONLY one JSON object between BEGIN_JSON/END_JSON."
def repair_user(template_json: str, bad_output: str) -> str:
    return f"""TEMPLATE_JSON:
{template_json}

BAD_OUTPUT:
{bad_output}

TASK: Output valid JSON matching TEMPLATE_JSON keys (fill missing with empty arrays/strings). No prose.

BEGIN_JSON
{{}}
END_JSON
"""

TERMS_TEMPLATE    = {"population":[],"intervention":[],"comparators":[],"outcomes":[],"must_have":[],"avoid":[]}
MESH_TAG_TEMPLATE = {"labels":[{"mesh":"","role":"G","keep":False,"why":""}]}
PASSA_TEMPLATE    = {"pmid":"","decision":"","reason":"","confidence":0.0,"population_quote":"","intervention_quote":""}

STRICT_JSON_RULES = (
  "Return ONLY one JSON object. No analysis, no preface, no notes. "
  "Wrap it EXACTLY with:\nBEGIN_JSON\n{...}\nEND_JSON"
)

def ask_json_strict(model: str, system: str, user: str, template: dict, max_tokens=None):
    user_strict = f"{user}\n\n{STRICT_JSON_RULES}"
    raw = lm_chat(model, system, user_strict, temperature=0.0, max_tokens=max_tokens, stop=["END_JSON"])
    try:
        return json.loads(extract_json_block_or_fence(raw))
    except Exception:
        repaired = lm_chat(
            model, REPAIR_SYSTEM,
            repair_user(json.dumps(template, ensure_ascii=False, indent=2), raw) + "\n\n" + STRICT_JSON_RULES,
            temperature=0.0, max_tokens=max_tokens, stop=["END_JSON"]
        )
        return json.loads(extract_json_block_or_fence(repaired))

ask_json = ask_json_strict  # force strict

# ----------------------------
# PubMed E-utilities
# ----------------------------
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
HEADERS = {"User-Agent": "sniff/0.2 (+local)", "Accept": "application/json"}

def esearch_count_and_ids(term: str, mindate: int|None):
    p = {"db":"pubmed","retmode":"json","term":term,"retmax":5000,"email":ENTREZ_EMAIL,"usehistory":"y"}
    if ENTREZ_API_KEY: p["api_key"]=ENTREZ_API_KEY
    if mindate: p["mindate"]=str(mindate)
    r = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params=p, timeout=HTTP_TIMEOUT); r.raise_for_status()
    js = r.json().get("esearchresult", {})
    count = int(js.get("count","0")); webenv = js.get("webenv"); qk = js.get("querykey")
    if not count or not webenv or not qk: return 0, []
    r2 = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params={
        "db":"pubmed","retmode":"json","retmax":5000,"retstart":0,"email":ENTREZ_EMAIL,
        "WebEnv":webenv,"query_key":qk, **({"api_key":ENTREZ_API_KEY} if ENTREZ_API_KEY else {})
    }, timeout=HTTP_TIMEOUT); r2.raise_for_status()
    ids = r2.json().get("esearchresult",{}).get("idlist",[])
    return count, [str(x) for x in ids]

def efetch_xml(pmids):
    if not pmids: return ""
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(pmids),"email":ENTREZ_EMAIL}
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/efetch.fcgi", headers={"User-Agent": "sniff/0.2"}, params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.text

def parse_pubmed_xml(xml_text: str):
    out=[]
    if not xml_text.strip(): return out
    root = ET.fromstring(xml_text)
    def _join(node):
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = art.findtext(".//PMID") or ""
        title = _join(art.find(".//ArticleTitle")).strip()
        abs_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(_join(n).strip() for n in abs_nodes) if abs_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\d{4}", s)
                if m: year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh = [mh.findtext("./DescriptorName") for mh in art.findall(".//MeshHeadingList/MeshHeading") if mh.findtext("./DescriptorName")]
        out.append({"pmid": pmid,"title": title,"abstract": abstract,"year": year,"language": lang,"publication_types": pubtypes,"mesh": mesh})
    return out

# ----------------------------
# LLM prompts (semantic steps)
# ----------------------------
TERMS_SYSTEM = "You extract controlled, compact term lists for biomedical retrieval. Return strict JSON only."
def terms_user(nlq: str):
    return f"""From the natural-language question below, produce compact term lists.

NATURAL_LANGUAGE_QUESTION:
<<<
{nlq}
>>>

Rules:
- Return JSON with arrays of P/I/C/O strings: {{ "population":[], "intervention":[], "comparators":[], "outcomes":[] }}
- Strings must be concise phrases (no boolean, no field tags, no quotes/brackets).
- Include common synonyms and acronyms (e.g., Nuss, MIRPE, cryoanalgesia).
- Add 2–5 must_have tokens in "must_have" that anchor topicality (e.g., MIRPE, Nuss, cryoablation).
- Add 2–5 avoid tokens in "avoid" if obvious confounders.
- Keep each list ≤ 12 items.

Return ONLY:

BEGIN_JSON
{{...}}
END_JSON
"""

MESH_TAG_SYSTEM = "You classify MeSH descriptors into roles relative to a PICOS."
def mesh_tag_user(p_terms, i_terms, descriptors):
    return f"""Classify each MeSH descriptor as one of: P (population/procedure context), I (intervention/analgesia), O (outcome), C (comparator/technique), G (generic context), X (irrelevant).
Also give 'keep' (true if useful for building search) and a 1-line rationale.

P_TERMS = {p_terms}
I_TERMS = {i_terms}
DESCRIPTORS = {descriptors}

Return ONLY:

BEGIN_JSON
{{ "labels": [{{"mesh":"...", "role":"P|I|O|C|G|X", "keep": true|false, "why": "..."}}] }}
END_JSON
"""

PASSA_SYS = "You are a strict PRISMA title/abstract screener for effects triage. Return JSON only."
def passa_user(proto_p, proto_i, proto_outcomes, record):
    return f"""Protocol (simplified):
Population: {proto_p}
Intervention: {proto_i}
Outcomes (signals): {proto_outcomes}
Include primary/comparative human studies; exclude admin/guidelines.

Record:
PMID: {record['pmid']}
Title: {record['title']}
Abstract: {record['abstract']}
PubTypes: {record['publication_types']}
Year: {record['year']}
Lang: {record['language']}

Return:

BEGIN_JSON
{{"pmid":"{record['pmid']}",
  "decision":"include|borderline|exclude",
  "reason":"population_mismatch|intervention_mismatch|design_ineligible|off_topic|language|year|insufficient_info",
  "confidence": 0.0,
  "population_quote":"", "intervention_quote":""
}}
END_JSON
"""

REPROMPT_SYS = "You write crisp, actionable reprompts (≤2 sentences) to fix information gaps."
def reprompt_user(summary_problem: str):
    return f"""Context of failure:
{summary_problem}

Write ≤2 sentences telling the user exactly what to clarify or relax.

Return:

BEGIN_JSON
{{"reprompt":"..."}}
END_JSON
"""

# ----------------------------
# Term distillation & query building
# ----------------------------
DEMOGRAPHIC_STOP = {"adult","adults","aged","male","female","child","children","adolescent","young adult","infant"}
GENERIC_I_STOP   = {"analgesia","pain management","nerve block","anesthesia, conduction","intraoperative analgesia","analgesics","analgesics, opioid","patient-controlled analgesia","perioperative care"}

ANCHORS_P = ["MIRPE","Nuss","pectus excavatum"]

def normalize_phrase(t: str) -> list[str]:
    t = (t or "").strip()
    if not t: return []
    out = {t}
    t2 = re.sub(r"\s*\([^)]*\)\s*", " ", t).strip()
    if t2 and t2 != t: out.add(t2)
    return list(out)

def distilled_terms_for_query(seed_terms: dict, mesh_tagged: dict|None = None):
    P_raw = list(seed_terms.get("population") or [])
    I_raw = list(seed_terms.get("intervention") or [])

    role_keep = set()
    if mesh_tagged and "labels" in mesh_tagged:
        for l in mesh_tagged["labels"]:
            if (l.get("role") or "") in {"P","I"} and l.get("keep"):
                role_keep.add((l.get("mesh") or "").lower())

    def keep_P(t):
        tl = t.lower()
        if tl in DEMOGRAPHIC_STOP: return False
        return True

    def keep_I(t):
        tl = t.lower()
        if tl in GENERIC_I_STOP: return False
        if role_keep and tl not in role_keep and not (tl.startswith("cryo") or "intercostal" in tl or tl=="inc"):
            return False
        return True

    P, I = [], []
    for p in P_raw:
        for v in normalize_phrase(p):
            if keep_P(v): P.append(v)
    for i in I_raw:
        for v in normalize_phrase(i):
            if keep_I(v): I.append(v)

    # ensure anchors in P
    anchors = []
    Pl = [x.lower() for x in P]
    for a in ANCHORS_P:
        if a.lower() in Pl:
            anchors.append(a)
    if not anchors:
        anchors = ANCHORS_P[:]  # inject anchors if missing
    # dedupe keeping order
    P = list(dict.fromkeys(P)); I = list(dict.fromkeys(I)); anchors = list(dict.fromkeys(anchors))
    return P, I, anchors

def or_block(terms, field="tiab"):
    toks=[]; seen=set()
    for t in terms:
        t = t.strip()
        if not t or t.lower() in seen: continue
        seen.add(t.lower())
        toks.append(f"\"{t}\"[{field}]" if (" " in t or "-" in t) else f"{t}[{field}]")
    return "(" + " OR ".join(toks) + ")" if toks else ""

def or_block_anchored(terms, anchors, field="tiab"):
    toks=[]; seen=set()
    for t in anchors + terms:
        t = t.strip()
        if not t or t.lower() in seen: continue
        seen.add(t.lower())
        toks.append(f"\"{t}\"[{field}]" if (" " in t or "-" in t) else f"{t}[{field}]")
    return "(" + " OR ".join(toks) + ")" if toks else ""

def build_broad_from(P_terms, I_terms, anchors, extra=None, field="tiab"):
    P = or_block_anchored(P_terms, anchors, field)
    I = or_block(I_terms, field)
    X = (" AND " + or_block(extra, field)) if extra else ""
    if not P or not I:
        return None
    return f"{P} AND {I}{X}"

def build_focused(broad_core):
    return f"({broad_core}) AND {RCT_HEDGE_LEX}"

# ----------------------------
# Metrics
# ----------------------------
PRIMARY_HINTS = {"Randomized Controlled Trial","Clinical Trial","Controlled Clinical Trial",
                 "Prospective Studies","Cohort Studies","Case-Control Studies","Comparative Study"}

def lexical_stats(records, p_terms, i_terms, outcomes):
    def hits(text, terms):
        tl=(text or "").lower()
        return sum(1 for t in terms if t and t.lower() in tl)
    n = min(SAMPLE_N, len(records))
    sample = records[:n]
    pi_rate=out_rate=design_rate=0
    scores=[]
    for r in sample:
        t = (r['title'] or "") + "\n" + (r['abstract'] or "")
        pi = (hits(t, p_terms)+hits(t, i_terms))>0
        po = hits(t, outcomes)>0
        de = len(set(r['publication_types']) & PRIMARY_HINTS)>0
        pi_rate += 1 if pi else 0
        out_rate+= 1 if po else 0
        design_rate+= 1 if de else 0
        s = (2.0*(1 if pi else 0) + 1.0*(1 if po else 0) + 0.5*(1 if de else 0))
        scores.append(s)
    if n==0:
        return {"n_sample":0,"pi_rate":0,"outcome_rate":0,"design_rate":0,"median_score":0,"mean_score":0}
    scores.sort()
    med = scores[n//2]
    mean = sum(scores)/n
    return {"n_sample":n,"pi_rate":round(pi_rate/n,3),"outcome_rate":round(out_rate/n,3),
            "design_rate":round(design_rate/n,3),"median_score":round(med,3),"mean_score":round(mean,3)}

def rq_quality(stats):
    return round(stats["median_score"] + 0.25*stats["mean_score"] + 0.5*stats["pi_rate"] + 0.25*stats["outcome_rate"], 3)

def top_titles(records, p_terms, i_terms, outcomes, k=8):
    rows=[]
    def hits(text, terms):
        tl=(text or "").lower()
        return sum(1 for t in terms if t and t.lower() in tl)
    for r in records:
        t = (r['title'] or "") + "\n" + (r['abstract'] or "")
        pi = (hits(t, p_terms)+hits(t, i_terms))>0
        po = hits(t, outcomes)>0
        de = len(set(r['publication_types']) & PRIMARY_HINTS)>0
        score = (2.0*(1 if pi else 0) + 1.0*(1 if po else 0) + 0.5*(1 if de else 0))
        rows.append((score, r['pmid'], r['year'], r['title'], pi, po, de))
    rows.sort(key=lambda x: (-x[0], -(x[2] or 0)))
    return rows[:k]

# ----------------------------
# Main runner
# ----------------------------
def sniff_nlq(USER_NLQ: str, year_min: int = YEAR_MIN_DEFAULT):
    # 1) Qwen: extract seed term lists
    terms_js = ask_json(QWEN_MODEL, TERMS_SYSTEM, terms_user(USER_NLQ), TERMS_TEMPLATE, max_tokens=None)
    P0 = terms_js.get("population", []) or []
    I0 = terms_js.get("intervention", []) or []
    O0 = terms_js.get("outcomes", []) or []
    MUST = terms_js.get("must_have", []) or []

    # 2) Initial broad (seed P/I)
    candidates=[]
    def try_query(name, q):
        cnt, ids = esearch_count_and_ids(q, year_min)
        xml = efetch_xml(ids[:SAMPLE_N])
        recs = parse_pubmed_xml(xml)
        stats = lexical_stats(recs, P0, I0, O0)
        rq = rq_quality(stats)
        candidates.append({"name":name,"query":q,"total":cnt,"stats":stats,"rq":rq,"ids":ids,"sample":recs})
        print(f"{time.strftime('%H:%M:%S')}  {name} hits={cnt} rq={rq} stats={stats}")
        return cnt, recs

    # 2a) MeSH from initial sample
    q_seed = build_broad_from(P0, I0, anchors=ANCHORS_P, extra=None)
    if not q_seed:
        rp = ask_json(QWEN_MODEL, REPROMPT_SYS, reprompt_user("Failed to construct core P and I term groups from your question."), {"reprompt":""}, max_tokens=None)
        raise SystemExit("REPROMPT: " + (rp.get("reprompt","need clarification")))
    _, recs_seed = try_query("B0_seed", q_seed)

    # 3) Mine MeSH from B0 sample and role-tag it
    mesh_all = Counter()
    for r in recs_seed:
        for m in r.get("mesh", []) or []:
            mesh_all[m] += 1
    top_mesh = [m for m,_ in mesh_all.most_common(40)]
    mesh_tag_js = ask_json(QWEN_MODEL, MESH_TAG_SYSTEM, mesh_tag_user(P0, I0, top_mesh), MESH_TAG_TEMPLATE, max_tokens=None)

    # 4) Distill terms (drop demographics/generic I; enforce anchors)
    Pq, Iq, anchors = distilled_terms_for_query(terms_js, mesh_tag_js)

    # 5) Build variants using distilled terms
    q_broad  = build_broad_from(Pq, Iq, anchors, extra=MUST)
    q_i_tight= build_broad_from(P0, Iq, anchors, extra=None)
    q_p_core = build_broad_from(Pq, I0, anchors, extra=None)

    _, _ = try_query("B1_broad", q_broad)
    if q_i_tight: try_query("B2_I_tight", q_i_tight)
    if q_p_core:  try_query("B3_P_core",  q_p_core)

    # choose broad: best rq within BROAD_TARGET else within BROAD_OK
    def choose_broad():
        cands = [c for c in candidates if c["name"].startswith("B")]
        in_target = [c for c in cands if BROAD_TARGET[0] <= c["total"] <= BROAD_TARGET[1]]
        pool = in_target if in_target else [c for c in cands if BROAD_OK[0] <= c["total"] <= BROAD_OK[1]]
        return max((pool or cands), key=lambda x: x["rq"]) if (pool or cands) else None

    chosen_broad = choose_broad()
    if not chosen_broad:
        rp = ask_json(QWEN_MODEL, REPROMPT_SYS, reprompt_user("No broad query produced viable hit counts or relevance."), {"reprompt":""}, max_tokens=None)
        print("\nREPROMPT:", rp.get("reprompt","need clarification")); return

    # 6) Focused
    q_focused = build_focused(chosen_broad["query"])
    cntF, idsF = esearch_count_and_ids(q_focused, year_min)
    xmlF = efetch_xml(idsF[:SAMPLE_N])
    recsF = parse_pubmed_xml(xmlF)
    statsF = lexical_stats(recsF, P0, I0, O0); rqF = rq_quality(statsF)
    focused = {"name":"F_focused","query":q_focused,"total":cntF,"stats":statsF,"rq":rqF,"ids":idsF,"sample":recsF}
    candidates.append(focused)

    # 7) Gemma triage for chosen sets
    def sanity_screen(records):
        if not records: return {"include":0,"borderline":0,"exclude":0,"n":0}
        inc=bor=exc=0
        for r in records[:min(SAMPLE_N,len(records))]:
            js = ask_json(GEMMA_MODEL, PASSA_SYS, passa_user(P0, I0, O0, r), PASSA_TEMPLATE, max_tokens=None)
            d = (js.get("decision","") or "").lower()
            if d=="include": inc+=1
            elif d=="borderline": bor+=1
            else: exc+=1
        n=inc+bor+exc
        return {"include":inc,"borderline":bor,"exclude":exc,"n":n}

    sanity_broad  = sanity_screen(chosen_broad["sample"])
    sanity_focused= sanity_screen(focused["sample"])

    # 8) Prepare report data
    def why_not(c, chosen):
        reasons=[]
        if c is chosen: return "chosen"
        if c["total"] < BROAD_OK[0] or c["total"] > BROAD_OK[1]: reasons.append("count_out_of_window")
        if c["rq"] < chosen["rq"]: reasons.append("lower_rq")
        return "; ".join(reasons) or "ok"

    kept_mesh = [x["mesh"] for x in mesh_tag_js.get("labels", []) if x.get("keep")]
    drop_mesh = [x["mesh"] for x in mesh_tag_js.get("labels", []) if not x.get("keep")]

    # 9) Print concise report
    def yn(x): return "yes" if x else "no"
    def ok_range(lo,hi,x): return lo<=x<=hi

    print("\n==================== SNIFF REPORT ====================")
    print("NLQ:")
    print("  ", USER_NLQ[:200].replace("\n"," ") + ("..." if len(USER_NLQ)>200 else ""))
    print("\nSEED TERMS (Qwen):")
    print("  P:", ", ".join(P0[:10]))
    print("  I:", ", ".join(I0[:10]))
    if terms_js.get("comparators"): print("  C:", ", ".join((terms_js.get("comparators") or [])[:10]))
    print("  O:", ", ".join(O0[:10]))
    if terms_js.get("must_have"):  print("  must_have:", ", ".join((terms_js.get("must_have") or [])[:8]))
    if terms_js.get("avoid"):      print("  avoid:", ", ".join((terms_js.get("avoid") or [])[:8]))

    print("\nDISTILLED FOR TIAB (anchors enforced):")
    print("  Pq:", ", ".join(Pq[:10]))
    print("  Iq:", ", ".join(Iq[:10]))
    print("  anchors:", ", ".join(anchors))

    print("\nCANDIDATES:")
    for c in candidates:
        if c["name"].startswith("F_"): continue
        s=c["stats"]
        print(f"  {c['name']}: hits={c['total']}, rq={c['rq']:.3f}, PI={s['pi_rate']:.2f}, OUT={s['outcome_rate']:.2f}, DESIGN={s['design_rate']:.2f}, reason={why_not(c, chosen_broad)}")

    print("\nCHOSEN BROAD")
    s = chosen_broad["stats"]
    print("  Query:", chosen_broad["query"])
    print(f"  Hits={chosen_broad['total']}  in_window={ok_range(*BROAD_OK, chosen_broad['total'])}  target_window={ok_range(*BROAD_TARGET, chosen_broad['total'])}")
    print(f"  Signals: PI={s['pi_rate']:.2f}  OUT={s['outcome_rate']:.2f}  DESIGN={s['design_rate']:.2f}  RQ={chosen_broad['rq']:.3f}")
    if sanity_broad["n"]:
        print(f"  Quick triage (Gemma): include={sanity_broad['include']} borderline={sanity_broad['borderline']} exclude={sanity_broad['exclude']} (n={sanity_broad['n']})")
    tb = top_titles(chosen_broad["sample"], P0, I0, O0, k=REPORT_TOP_K)
    print("  Top titles:")
    for sc, pmid, yr, ttl, pi, po, de in tb:
        flags = []
        if pi: flags.append("PI")
        if po: flags.append("OUT")
        if de: flags.append("DES")
        print(f"   • [{pmid}] ({yr})  {ttl[:95]}{'...' if len(ttl)>95 else ''}  | {' '.join(flags) or '-'}  | s={sc:.1f}")

    print("\nFOCUSED (BROAD ∧ RCT hedge)")
    s = focused["stats"]
    print("  Query:", focused["query"])
    print(f"  Hits={focused['total']}  in_window={ok_range(*FOCUSED_OK, focused['total'])}  target_window={ok_range(*FOCUSED_TARGET, focused['total'])}")
    print(f"  Signals: PI={s['pi_rate']:.2f}  OUT={s['outcome_rate']:.2f}  DESIGN={s['design_rate']:.2f}  RQ={focused['rq']:.3f}")
    if sanity_focused["n"]:
        print(f"  Quick triage (Gemma): include={sanity_focused['include']} borderline={sanity_focused['borderline']} exclude={sanity_focused['exclude']} (n={sanity_focused['n']})")
    tf = top_titles(focused["sample"], P0, I0, O0, k=REPORT_TOP_K)
    print("  Top titles:")
    for sc, pmid, yr, ttl, pi, po, de in tf:
        flags=[]
        if pi: flags.append("PI")
        if po: flags.append("OUT")
        if de: flags.append("DES")
        print(f"   • [{pmid}] ({yr})  {ttl[:95]}{'...' if len(ttl)>95 else ''}  | {' '.join(flags) or '-'}  | s={sc:.1f}")

    print("\nMESH mined (top kept vs dropped):")
    print("  kept  :", ", ".join(kept_mesh[:12]))
    print("  dropped:", ", ".join(drop_mesh[:12]))

    print("\n================== END OF REPORT =====================")

    # 10) Minimal artifacts (optional)
    if WRITE_ARTIFACTS:
        report = {
            "nlq": USER_NLQ,
            "year_min": year_min,
            "seed_terms": terms_js,
            "distilled": {"Pq": Pq, "Iq": Iq, "anchors": anchors},
            "mesh": {"top": top_mesh, "tagged": mesh_tag_js},
            "candidates": [
                {k: v for k,v in c.items() if k not in ("ids","sample")}
                for c in candidates if not c["name"].startswith("F_")
            ],
            "chosen": {
                "broad": {k: v for k,v in chosen_broad.items() if k not in ("ids","sample")},
                "focused": {k: v for k,v in focused.items() if k not in ("ids","sample")},
                "sanity": {"broad": sanity_broad, "focused": sanity_focused}
            }
        }
        (OUT_DIR/"broad.txt").write_text(chosen_broad["query"], encoding="utf-8")
        (OUT_DIR/"focused.txt").write_text(focused["query"], encoding="utf-8")
        (OUT_DIR/"sniff_report.json").write_text(json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8")
        print(f"\nArtifacts: broad.txt, focused.txt, sniff_report.json  -> {OUT_DIR.resolve()}")

# ----------------------------
# RUN: put your NLQ here
# ----------------------------
USER_NLQ = """Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE). Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE (the intervention of interest is INC, not the surgery). Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, or systemic multimodal analgesia. Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days. Study designs = RCTs preferred; if RCTs absent, include comparative cohort/case-control. Year_min = 2015. Languages = English, Portuguese, Spanish."""
sniff_nlq(USER_NLQ, year_min=YEAR_MIN_DEFAULT)

# On notebook shutdown, stop the evictor
_LM_EVICTOR_STOP = True


02:45:23  B0_seed hits=20809 rq=3.1 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.2, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 2.2}
02:48:25  B1_broad hits=87 rq=3.2 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.4, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 2.4}
02:48:42  B2_I_tight hits=2232 rq=3.0 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.0, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 2.0}
02:48:44  B3_P_core hits=960 rq=3.1 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.2, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 2.2}

NLQ:
   Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE). Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE (th...

SEED TERMS (Qwen):
  P: adults, minimally invasive repair, pectus excavatum, Nuss, MIRPE
  I: intercostal nerve cryoablation, cryoanalgesia, INC, intraoperative, post

In [None]:
# SNIFF STAGE — Overhauled, protocol-aware, with concise report and LM Studio auto-evict
# - Qwen: seed extraction + MeSH role-tagging
# - Gemma: protocol-based title/abstract triage (with gates)
# - Deterministic PubMed probing with multiple query variants
# - Rich, human-readable report + machine artifacts
#
# CONFIGURE:
#   LM Studio server, model names, Entrez email/API, and optional env knobs below.
#
# ENV knobs (optional):
#   LMSTUDIO_BASE=http://127.0.0.1:1234
#   QWEN_MODEL=qwen/qwen3-4b
#   GEMMA_MODEL=gemma-3n-e2b-it
#   LM_KEEP_ALIVE_SEC=5
#   HTTP_TIMEOUT=300
#   SAMPLE_N=5
#   YEAR_MIN_DEFAULT=2015
#
# Usage:
#   Put your NLQ in USER_NLQ and run sniff_nlq(USER_NLQ).

import os, json, time, re, statistics, pathlib
from collections import Counter, defaultdict
import requests
from xml.etree import ElementTree as ET

# ----------------------------
# Config
# ----------------------------
LMSTUDIO_BASE = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234")
QWEN_MODEL    = os.getenv("QWEN_MODEL", "qwen/qwen3-4b")
GEMMA_MODEL   = os.getenv("GEMMA_MODEL", "gemma-3n-e2b-it")
LM_KEEP_ALIVE = str(os.getenv("LM_KEEP_ALIVE_SEC", "5")) + "s"  # ask LM Studio to unload model ~5s after idle
ENTREZ_EMAIL   = os.getenv("ENTREZ_EMAIL", "you@example.com")
ENTREZ_API_KEY = os.getenv("ENTREZ_API_KEY", "")
HTTP_TIMEOUT   = int(os.getenv("HTTP_TIMEOUT", "300"))

OUT_DIR = pathlib.Path("sniff_poc_out")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Windows / caps (soft)
YEAR_MIN_DEFAULT = int(os.getenv("YEAR_MIN_DEFAULT", "2015"))
BROAD_TARGET   = (50, 5000)
FOCUSED_TARGET = (3, 500)
BROAD_OK       = (10, 10000)
FOCUSED_OK     = (1, 2000)

SAMPLE_N = int(os.getenv("SAMPLE_N", "5"))
RCT_HEDGE_LEX = '(randomized[tiab] OR randomised[tiab] OR randomization[tiab] OR "random allocation"[tiab])'

LANG_NAME_TO_TAG = {
    "english":"english", "en":"english",
    "portuguese":"portuguese", "pt":"portuguese",
    "spanish":"spanish", "es":"spanish"
}

PRIMARY_HINTS = {
    "Randomized Controlled Trial","Clinical Trial","Controlled Clinical Trial",
    "Prospective Studies","Cohort Studies","Case-Control Studies","Comparative Study"
}

PTYPES_TRIAL_COMP = [
    "randomized controlled trial","controlled clinical trial","comparative study",
    "clinical trial","cohort studies","case-control studies","prospective studies"
]

# ----------------------------
# HTTP helpers — LM Studio chat (with keep_alive)
# ----------------------------
def lm_chat(model: str, system: str, user: str, temperature=0.0, max_tokens=8000, response_format=None, stop=None):
    """
    LM Studio-compatible /v1/chat/completions call.
    Adds keep_alive to request; server may unload model after ~LM_KEEP_ALIVE idle.
    """
    url = f"{LMSTUDIO_BASE.rstrip('/')}/v1/chat/completions"
    body = {
        "model": model,
        "messages": [{"role":"system","content":system},{"role":"user","content":user}],
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
        "stream": False,
        "keep_alive": LM_KEEP_ALIVE
    }
    if response_format is not None:
        body["response_format"] = response_format
    if stop is not None:
        body["stop"] = stop
    r = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]

# ----------------------------
# JSON extraction/repair
# ----------------------------
_BEGIN = re.compile(r"BEGIN_JSON\s*", re.I)
_END   = re.compile(r"\s*END_JSON", re.I)
FENCE  = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.I)

def _sanitize_json_str(s: str) -> str:
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2018","'").replace("\u2019","'")
    s = re.sub(r",\s*(\}|\])", r"\1", s)
    return s.strip()

def extract_json_block_or_fence(txt: str) -> str:
    blocks = []
    pos = 0
    while True:
        m1 = _BEGIN.search(txt, pos)
        if not m1: break
        m2 = _END.search(txt, m1.end())
        if not m2: break
        blocks.append(txt[m1.end():m2.start()])
        pos = m2.end()
    if blocks:
        return _sanitize_json_str(blocks[-1])
    fences = FENCE.findall(txt)
    if fences:
        return _sanitize_json_str(fences[-1])
    # last {...}
    s = txt
    last_obj = None
    stack = 0; start = None
    for i,ch in enumerate(s):
        if ch == '{':
            if stack == 0: start = i
            stack += 1
        elif ch == '}':
            if stack > 0:
                stack -= 1
                if stack == 0 and start is not None:
                    last_obj = s[start:i+1]
    if last_obj:
        return _sanitize_json_str(last_obj)
    raise ValueError("No JSON-like content found")

REPAIR_SYSTEM = "You repair malformed JSON to exactly match the given template keys. Return ONLY one JSON object between BEGIN_JSON/END_JSON."
def repair_user(template_json: str, bad_output: str) -> str:
    return f"""TEMPLATE_JSON:
{template_json}

BAD_OUTPUT:
{bad_output}

TASK: Output valid JSON matching TEMPLATE_JSON keys (fill missing with empty arrays/strings). No prose.

BEGIN_JSON
{{}}
END_JSON
"""

TERMS_TEMPLATE    = {"population":[],"intervention":[],"comparators":[],"outcomes":[],"must_have":[],"avoid":[]}
MESH_TAG_TEMPLATE = {"labels":[{"mesh":"","role":"G","keep":False,"why":""}]}
PASSA_TEMPLATE    = {"pmid":"","decision":"","reason":"","confidence":0.0,"population_quote":"","intervention_quote":""}

STRICT_JSON_RULES = (
  "Return ONLY one JSON object. No analysis, no preface, no notes. "
  "Wrap it EXACTLY with:\nBEGIN_JSON\n{...}\nEND_JSON"
)

def ask_json_strict(model: str, system: str, user: str, template: dict, max_tokens=8000):
    user_strict = f"{user}\n\n{STRICT_JSON_RULES}"
    raw = lm_chat(model, system, user_strict, temperature=0.0, max_tokens=max_tokens, stop=["END_JSON"])
    try:
        return json.loads(extract_json_block_or_fence(raw))
    except Exception:
        repaired = lm_chat(
            model,
            REPAIR_SYSTEM,
            repair_user(json.dumps(template, ensure_ascii=False, indent=2), raw) + "\n\n" + STRICT_JSON_RULES,
            temperature=0.0,
            max_tokens=max_tokens,
            stop=["END_JSON"]
        )
        return json.loads(extract_json_block_or_fence(repaired))

def ask_json(model: str, system: str, user: str, template: dict, max_tokens=8000):
    raw = lm_chat(model, system, user, temperature=0.0, max_tokens=max_tokens)
    try:
        return json.loads(extract_json_block_or_fence(raw))
    except Exception:
        repaired = lm_chat(
            model,
            REPAIR_SYSTEM,
            repair_user(json.dumps(template, ensure_ascii=False, indent=2), raw),
            temperature=0.0,
            max_tokens=max_tokens
        )
        return json.loads(extract_json_block_or_fence(repaired))

ask_json = ask_json_strict

# ----------------------------
# PubMed E-utilities
# ----------------------------
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
HEADERS = {"User-Agent": "sniff-poc/0.2 (+local)", "Accept": "application/json"}

def esearch_count_and_ids(term: str, mindate: int|None):
    p = {"db":"pubmed","retmode":"json","term":term,"retmax":5000,"email":ENTREZ_EMAIL,"usehistory":"y"}
    if ENTREZ_API_KEY: p["api_key"]=ENTREZ_API_KEY
    if mindate: p["mindate"]=str(mindate)
    r = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params=p, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    js = r.json().get("esearchresult", {})
    count = int(js.get("count","0"))
    webenv = js.get("webenv"); qk = js.get("querykey")
    if not count or not webenv or not qk:
        return 0, []
    r2 = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params={
        "db":"pubmed","retmode":"json","retmax":5000,"retstart":0,"email":ENTREZ_EMAIL,
        "WebEnv":webenv,"query_key":qk, **({"api_key":ENTREZ_API_KEY} if ENTREZ_API_KEY else {})
    }, timeout=HTTP_TIMEOUT)
    r2.raise_for_status()
    ids = r2.json().get("esearchresult",{}).get("idlist",[])
    return count, [str(x) for x in ids]

def efetch_xml(pmids):
    if not pmids: return ""
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(pmids),"email":ENTREZ_EMAIL}
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/efetch.fcgi", headers={"User-Agent": "sniff-poc/0.2"}, params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.text

def parse_pubmed_xml(xml_text: str):
    out = []
    if not xml_text.strip(): return out
    root = ET.fromstring(xml_text)
    def _join(node):
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = art.findtext(".//PMID") or ""
        title = _join(art.find(".//ArticleTitle")).strip()
        abs_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(_join(n).strip() for n in abs_nodes) if abs_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\d{4}", s)
                if m: year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh = [mh.findtext("./DescriptorName") for mh in art.findall(".//MeshHeadingList/MeshHeading") if mh.findtext("./DescriptorName")]
        out.append({
            "pmid": pmid, "title": title, "abstract": abstract, "year": year, "language": lang,
            "publication_types": pubtypes, "mesh": mesh
        })
    return out

# ----------------------------
# LLM prompts (semantic steps)
# ----------------------------
TERMS_SYSTEM = "You extract controlled, compact term lists for biomedical retrieval. Return strict JSON only."
def terms_user(nlq: str):
    return f"""From the natural-language question below, produce compact term lists.

NATURAL_LANGUAGE_QUESTION:
<<<
{nlq}
>>>

Rules:
- Return JSON with arrays of P/I/C/O strings: {{ "population":[], "intervention":[], "comparators":[], "outcomes":[] }}
- Strings must be concise phrases (no boolean, no field tags, no quotes/brackets).
- Include common synonyms and acronyms (e.g., Nuss, MIRPE, cryoanalgesia).
- Add 2–5 must_have tokens in "must_have" that anchor topicality (e.g., MIRPE, Nuss, cryoablation).
- Add 2–5 avoid tokens in "avoid" if obvious confounders (e.g., pediatric oncology if off-topic).
- Keep each list ≤ 12 items.

Return ONLY:

BEGIN_JSON
{{...}}
END_JSON
"""

MESH_TAG_SYSTEM = "You classify MeSH descriptors into roles relative to a PICOS."
def mesh_tag_user(p_terms, i_terms, descriptors):
    return f"""Classify each MeSH descriptor as one of: P (population/procedure context), I (intervention/analgesia), O (outcome), C (comparator/technique), G (generic context), X (irrelevant).
Also provide a 'keep' boolean (true if useful for building search), and a 1-line rationale.

P_TERMS = {p_terms}
I_TERMS = {i_terms}

DESCRIPTORS = {descriptors}

Return ONLY:

BEGIN_JSON
{{ "labels": [{{"mesh":"...", "role":"P|I|O|C|G|X", "keep": true|false, "why": "..."}}] }}
END_JSON
"""

# -------- Protocol brief builder (auto) --------
def build_protocol_brief(nlq: str, terms, year_min: int, languages: list[str]):
    langs_norm = [LANG_NAME_TO_TAG.get(x.lower(), x.lower()) for x in languages if x]
    langs_norm = [x for x in langs_norm if x in LANG_NAME_TO_TAG.values()]
    brief = {
        "focus": "Adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE).",
        "population_include": terms.get("population", []),
        "intervention_index": terms.get("intervention", []),
        "comparators_eligible": terms.get("comparators", []),
        "primary_outcomes": terms.get("outcomes", []),
        "designs": "RCTs preferred; include comparative cohort/case-control if RCTs absent.",
        "time_window_min_year": year_min,
        "languages": langs_norm or ["english","portuguese","spanish"],
        "must_have_anchors": terms.get("must_have", []),
        "hard_excludes": ["guidelines","editorials","technical notes without outcomes"]
    }
    # textual compact version Gemma sees
    txt = (
        f"FOCUS: adults with pectus excavatum undergoing MIRPE/Nuss.\n"
        f"POP (include): {', '.join(brief['population_include'])}\n"
        f"INDEX INTERVENTION: {', '.join(brief['intervention_index'])}\n"
        f"COMPARATORS: {', '.join(brief['comparators_eligible'])}\n"
        f"OUTCOMES (primary): {', '.join(brief['primary_outcomes'])}\n"
        f"DESIGNS: {brief['designs']}\n"
        f"YEAR_MIN: {year_min}; LANGUAGES: {', '.join(brief['languages'])}\n"
        f"ANCHORS (must have): {', '.join(brief['must_have_anchors']) or 'MIRPE, Nuss, cryoablation'}\n"
        f"HARD EXCLUDES: {', '.join(brief['hard_excludes'])}"
    )
    return brief, txt

PASSA_SYS = "You are a strict PRISMA title/abstract screener for effects triage. Return JSON only."
def passa_user(protocol_txt: str, record, enforce_adult=True):
    return f"""Systematic-review protocol (brief):
{protocol_txt}

Rules for this triage:
- Include only if population aligns with MIRPE/Nuss for pectus excavatum (prefer adults; pediatric-only is out unless adults are clearly included/analyzed).
- The index intervention is intercostal nerve cryoablation (INC) / cryoanalgesia used intraoperatively for MIRPE/Nuss analgesia.
- Eligible comparators: thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, systemic multimodal analgesia.
- Outcomes required for inclusion signal: postoperative opioid consumption (in-hospital or discharge) and/or pain scores within 0–7 days.
- Designs: RCTs preferred; if no RCTs, allow comparative cohort/case-control. Exclude guidelines, letters, editorials, technical notes lacking outcomes.
- If the abstract lacks clear nexus of BOTH (MIRPE|Nuss|pectus excavatum) AND (cryoablation|cryoanalgesia|INC) -> mark borderline or exclude.

Record:
PMID: {record['pmid']}
Title: {record['title']}
Abstract: {record['abstract']}
PubTypes: {record['publication_types']}
Year: {record['year']}
Lang: {record['language']}

Return:

BEGIN_JSON
{{"pmid":"{record['pmid']}",
  "decision":"include|borderline|exclude",
  "reason":"population_mismatch|intervention_mismatch|design_ineligible|off_topic|language|year|insufficient_info",
  "confidence": 0.0,
  "population_quote":"", "intervention_quote":""
}}
END_JSON
"""

# ----------------------------
# Deterministic query builders
# ----------------------------
def or_block(terms, field="tiab"):
    toks=[]
    for t in terms:
        if not t: continue
        t=t.strip()
        if not t: continue
        if " " in t or "-" in t or "(" in t:
            toks.append(f"\"{t}\"[{field}]")
        else:
            toks.append(f"{t}[{field}]")
    if not toks: return ""
    return "(" + " OR ".join(toks) + ")"

def lang_filter_block(langs):
    tags = [LANG_NAME_TO_TAG.get(x.lower(), x.lower()) for x in langs if x]
    tags = [t for t in tags if t in LANG_NAME_TO_TAG.values()]
    if not tags: return ""
    return "(" + " OR ".join(f"\"{t}\"[lang]" for t in tags) + ")"

def pubtype_block_pt():
    return "(" + " OR ".join(f"\"{pt}\"[pt]" for pt in PTYPES_TRIAL_COMP) + ")"

def build_broad(p_syn, i_syn, extra=None, field="tiab"):
    P = or_block(p_syn, field)
    I = or_block(i_syn, field)
    X = (" AND " + or_block(extra, field)) if extra else ""
    if not P or not I: return None
    return f"{P} AND {I}{X}"

def build_with_comparator(p_syn, i_syn, comps, anchors=None, field="tiab"):
    base = build_broad(p_syn, i_syn, field=field)
    C = or_block(comps, field)
    A = or_block(anchors, field) if anchors else ""
    if not base: return None
    out = base
    if C: out += " AND " + C
    if A: out += " AND " + A
    return out

def add_lang_filter(q, langs):
    L = lang_filter_block(langs)
    return f"{q} AND {L}" if L else q

def build_focused_lex(broad_core):
    return f"({broad_core}) AND {RCT_HEDGE_LEX}"

def build_focused_ptype(broad_core):
    return f"({broad_core}) AND {pubtype_block_pt()}"

# ----------------------------
# Metrics & gates
# ----------------------------
def lexical_stats(records, p_terms, i_terms, outcomes):
    def hits(text, terms):
        tl=(text or "").lower()
        return sum(1 for t in terms if t and t.lower() in tl)
    n = min(SAMPLE_N, len(records))
    sample = records[:n]
    pi_rate=0; out_rate=0; design_rate=0; scores=[]
    years=[]; langs=[]; ptypes=[]; mesh=Counter()
    for r in sample:
        t = (r['title'] or "") + "\n" + (r['abstract'] or "")
        pi = (hits(t, p_terms)+hits(t, i_terms))>0
        po = hits(t, outcomes)>0
        de = len(set(r['publication_types']) & PRIMARY_HINTS)>0
        pi_rate += 1 if pi else 0
        out_rate+= 1 if po else 0
        design_rate+= 1 if de else 0
        s = (2.0*(1 if pi else 0) + 1.0*(1 if po else 0) + 0.5*(1 if de else 0))
        scores.append(s)
        if r["year"]: years.append(r["year"])
        if r["language"]: langs.append(r["language"])
        ptypes.extend(r["publication_types"])
        mesh.update(r.get("mesh") or [])
    if n==0:
        return {"n_sample":0,"pi_rate":0,"outcome_rate":0,"design_rate":0,"median_score":0,"mean_score":0,
                "years":{},"languages":{},"ptypes":{},"top_mesh":[]}
    scores.sort()
    med = scores[n//2]
    mean = sum(scores)/n
    summary = {
        "n_sample": n,
        "pi_rate": round(pi_rate/n,3),
        "outcome_rate": round(out_rate/n,3),
        "design_rate": round(design_rate/n,3),
        "median_score": round(med,3),
        "mean_score": round(mean,3),
        "years": {"median": (statistics.median(years) if years else None), "min": (min(years) if years else None), "max": (max(years) if years else None)},
        "languages": dict(Counter(langs).most_common(5)),
        "ptypes": dict(Counter(ptypes).most_common(8)),
        "top_mesh": [m for m,_ in mesh.most_common(10)]
    }
    return summary

def rq_quality(stats):
    return round( stats["median_score"] + 0.25*stats["mean_score"] + 0.5*stats["pi_rate"] + 0.25*stats["outcome_rate"], 3)

# Gemma gating: require anchors; for focused, require trial/comparator cue
ANCHOR_P  = re.compile(r"\b(nuss|mirpe|pectus\s+excavatum)\b", re.I)
ANCHOR_I  = re.compile(r"\b(cryoablation|cryoanalgesia|\binc\b)\b", re.I)
TRIAL_COMPARATOR_CUE = re.compile(r"\b(randomi[sz]ed?|trial|epidural|paravertebral|intercostal nerve block|erector spinae|comparative|cohort|case[- ]control)\b", re.I)

def gemma_gate(record_text: str, is_focused: bool):
    ok_anchor = bool(ANCHOR_P.search(record_text) and ANCHOR_I.search(record_text))
    if not ok_anchor:
        return False
    if is_focused:
        return bool(TRIAL_COMPARATOR_CUE.search(record_text))
    return True

# ----------------------------
# Main SNiff runner
# ----------------------------
def sniff_nlq(USER_NLQ: str, year_min: int = YEAR_MIN_DEFAULT, languages: list[str] = ["English","Portuguese","Spanish"]):
    # 1) Qwen: extract seed term lists
    terms_js = ask_json(QWEN_MODEL, TERMS_SYSTEM, terms_user(USER_NLQ), TERMS_TEMPLATE, max_tokens=8000)
    (OUT_DIR/"seed_terms.json").write_text(json.dumps(terms_js, indent=2, ensure_ascii=False), encoding="utf-8")

    P0 = terms_js.get("population", []) or []
    I0 = terms_js.get("intervention", []) or []
    C0 = terms_js.get("comparators", []) or []
    O0 = terms_js.get("outcomes", []) or []
    MUST = terms_js.get("must_have", []) or []

    # 2) Seed probe (B0) to mine MeSH
    def try_query(name, q):
        cnt, ids = esearch_count_and_ids(q, year_min)
        xml = efetch_xml(ids[:SAMPLE_N])
        recs = parse_pubmed_xml(xml)
        stats = lexical_stats(recs, P0, I0, O0)
        rq = rq_quality(stats)
        return {"name":name,"query":q,"total":cnt,"stats":stats,"rq":rq,"ids":ids,"sample":recs}

    anchors_for_P = [x for x in MUST if x.lower() in ("mirpe","nuss","pectus excavatum","cryoablation")]
    if not anchors_for_P:
        anchors_for_P = ["MIRPE","Nuss","cryoablation"]

    q_b0 = build_broad(P0, I0, extra=anchors_for_P)
    print(f"{time.strftime('%H:%M:%S')}  B0_seed probing...")
    B0 = try_query("B0_seed", q_b0)

    # 3) MeSH mining and role tagging
    mesh_all = Counter()
    for r in B0["sample"]:
        for m in r.get("mesh", []) or []:
            mesh_all[m] += 1
    top_mesh = [m for m,_ in mesh_all.most_common(40)]
    (OUT_DIR/"mesh_raw.json").write_text(json.dumps({"top_mesh":top_mesh, "counts":mesh_all.most_common(100)}, indent=2, ensure_ascii=False), encoding="utf-8")

    mesh_tag_js = ask_json(QWEN_MODEL, MESH_TAG_SYSTEM, mesh_tag_user(P0, I0, top_mesh), MESH_TAG_TEMPLATE, max_tokens=8000)
    (OUT_DIR/"mesh_tagged.json").write_text(json.dumps(mesh_tag_js, indent=2, ensure_ascii=False), encoding="utf-8")
    keepP = [x["mesh"] for x in mesh_tag_js.get("labels", []) if x.get("keep") and x.get("role")=="P"]
    keepI = [x["mesh"] for x in mesh_tag_js.get("labels", []) if x.get("keep") and x.get("role")=="I"]

    def expand_terms(base, extra, limit=12):
        seen=set([b.lower() for b in base])
        out=list(base)
        for e in extra:
            w=e.strip()
            if not w: continue
            if w.lower() not in seen:
                out.append(w)
                seen.add(w.lower())
        return out[:limit]

    P_core = expand_terms(P0, keepP[:6])
    I_core = expand_terms(I0, keepI[:8])

    # 4) Protocol brief (shown & used by Gemma)
    protocol, protocol_txt = build_protocol_brief(USER_NLQ, terms_js, year_min, languages)

    # 5) Build candidates
    candidates = []
    def add_variant(name, q):
        if not q: return
        res = try_query(name, q)
        print(f"{time.strftime('%H:%M:%S')}  {name} hits={res['total']} rq={res['rq']} stats={res['stats']}")
        candidates.append(res)

    add_variant("B1_broad", build_broad(P_core, I_core, extra=anchors_for_P))
    add_variant("B2_I_tight", build_broad(P0, I_core, extra=anchors_for_P))
    add_variant("B3_P_core", build_broad(P_core, I0, extra=anchors_for_P))
    add_variant("B4_with_comparator", build_with_comparator(P_core, I_core, C0, anchors=anchors_for_P))
    # Language filtered variant on the best of the above later (to avoid redundant E-queries)

    # Choose BROAD by window then RQ
    def choose_broad():
        pool = [c for c in candidates if c["name"].startswith("B")]
        in_target = [c for c in pool if BROAD_TARGET[0] <= c["total"] <= BROAD_TARGET[1]]
        if in_target:
            return max(in_target, key=lambda x: x["rq"])
        pool2 = [c for c in pool if BROAD_OK[0] <= c["total"] <= BROAD_OK[1]]
        if pool2:
            return max(pool2, key=lambda x: x["rq"])
        return max(pool, key=lambda x: x["rq"]) if pool else B0

    chosen_broad = choose_broad()

    # Add explicit language filter variant on chosen broad (if not already present)
    lang_filtered_query = add_lang_filter(chosen_broad["query"], protocol["languages"])
    B_lang = try_query("B_lang_filtered", lang_filtered_query)
    candidates.append(B_lang)
    print(f"{time.strftime('%H:%M:%S')}  B_lang_filtered hits={B_lang['total']} rq={B_lang['rq']} stats={B_lang['stats']}")

    # Re-choose between chosen_broad and language-filtered if both valid
    if BROAD_OK[0] <= B_lang["total"] <= BROAD_OK[1] and B_lang["rq"] >= chosen_broad["rq"]*0.98:
        chosen_broad = B_lang

    # 6) Focused queries (lexical hedge + ptype)
    qF_lex = build_focused_lex(chosen_broad["query"])
    F_lex = try_query("F_lexical", qF_lex)
    qF_pt  = build_focused_ptype(chosen_broad["query"])
    F_pt   = try_query("F_ptype", qF_pt)

    focused_cands = [F_lex, F_pt]
    def choose_focused():
        pool = [c for c in focused_cands]
        in_target = [c for c in pool if FOCUSED_TARGET[0] <= c["total"] <= FOCUSED_TARGET[1]]
        if in_target:
            return max(in_target, key=lambda x: x["rq"])
        pool2 = [c for c in pool if FOCUSED_OK[0] <= c["total"] <= FOCUSED_OK[1]]
        if pool2:
            return max(pool2, key=lambda x: x["rq"])
        return max(pool, key=lambda x: x["rq"])
    chosen_focused = choose_focused()

    # 7) Gemma: protocol-based sanity screen with gates
    def sanity_screen(records, is_focused=False):
        if not records: return {"include":0,"borderline":0,"exclude":0,"n":0}
        inc=bor=exc=0
        examples=[]
        for r in records[:min(SAMPLE_N,len(records))]:
            js = ask_json(GEMMA_MODEL, PASSA_SYS, passa_user(protocol_txt, r), PASSA_TEMPLATE, max_tokens=8000)
            decision = (js.get("decision","") or "").lower()
            # Gates
            rt = ((r['title'] or "") + " " + (r['abstract'] or ""))
            if not gemma_gate(rt, is_focused=is_focused):
                if decision == "include":
                    decision = "borderline" if not is_focused else "exclude"
            if decision=="include": inc+=1
            elif decision=="borderline": bor+=1
            else: exc+=1
            examples.append({"pmid":r["pmid"],"title":r["title"][:140],
                             "flags":" ".join([
                                 "PI" if any(t.lower() in rt.lower() for t in (P0+I0)) else "",
                                 "OUT" if any(t.lower() in rt.lower() for t in O0) else "",
                                 "DES" if (set(r["publication_types"]) & PRIMARY_HINTS) else ""
                             ]).strip()})
        n=inc+bor+exc
        return {"include":inc,"borderline":bor,"exclude":exc,"n":n,"examples":examples}

    sample_broad = chosen_broad["sample"]
    sample_focused = chosen_focused["sample"]
    sanity_broad = sanity_screen(sample_broad, is_focused=False)
    sanity_focused = sanity_screen(sample_focused, is_focused=True)

    # 8) Persist artifacts
    artifacts = {
        "nlq": USER_NLQ,
        "year_min": year_min,
        "languages": protocol["languages"],
        "seed_terms": terms_js,
        "protocol_brief": protocol,
        "mesh_top": top_mesh,
        "mesh_tagged": mesh_tag_js,
        "candidates": [
            {k:(v if k!="sample" else None) for k,v in c.items()} for c in [B0]+candidates+[F_lex,F_pt]
        ],
        "chosen": {
            "broad": {"name": chosen_broad["name"], "query": chosen_broad["query"], "total": chosen_broad["total"], "rq": chosen_broad["rq"], "stats": chosen_broad["stats"], "sanity": sanity_broad},
            "focused": {"name": chosen_focused["name"], "query": chosen_focused["query"], "total": chosen_focused["total"], "rq": chosen_focused["rq"], "stats": chosen_focused["stats"], "sanity": sanity_focused},
        }
    }
    (OUT_DIR/"broad.txt").write_text(chosen_broad["query"], encoding="utf-8")
    (OUT_DIR/"focused.txt").write_text(chosen_focused["query"], encoding="utf-8")
    (OUT_DIR/"sniff_artifacts.json").write_text(json.dumps(artifacts, indent=2, ensure_ascii=False), encoding="utf-8")

    # 9) Human-readable report (concise)
    def fmt_stats(s):
        return f"PI={s['pi_rate']:.2f} OUT={s['outcome_rate']:.2f} DES={s['design_rate']:.2f}  RQ={rq_quality(s):.3f}"

    def top_titles(sample):
        out=[]
        for r in sample[:min(SAMPLE_N,len(sample))]:
            flags=[]
            tt=(r['title'] or "")
            txt=(r['title'] or "")+" "+(r['abstract'] or "")
            if any(t.lower() in txt.lower() for t in (P0+I0)): flags.append("PI")
            if any(t.lower() in txt.lower() for t in O0): flags.append("OUT")
            if set(r["publication_types"]) & PRIMARY_HINTS: flags.append("DES")
            out.append(f"   • [{r['pmid']}] ({r['year']})  {tt[:90]}{'...' if len(tt)>90 else ''}  | {' '.join(flags)}")
        return "\n".join(out) if out else "   (no sample)"

    # Candidate reasoning labels
    lines=[]
    lines.append("\n==================== SNIFF REPORT ====================")
    lines.append("NLQ:")
    nlq_line = (USER_NLQ or "").strip().replace("\n"," ")
    lines.append("   " + (nlq_line[:160] + ("..." if len(nlq_line)>160 else "")))
    lines.append("\nSEED TERMS (Qwen):")
    lines.append(f"  P: {', '.join(P0)}")
    lines.append(f"  I: {', '.join(I0)}")
    lines.append(f"  C: {', '.join(C0)}")
    lines.append(f"  O: {', '.join(O0)}")
    lines.append(f"  must_have: {', '.join(MUST) if MUST else '(none)'}")

    # Distilled tokens used (show what actually hit queries)
    lines.append("\nDISTILLED FOR TIAB (anchors enforced):")
    lines.append(f"  Pq: {', '.join(P_core) if P_core else '(none)'}")
    lines.append(f"  Iq: {', '.join(I_core) if I_core else '(none)'}")
    lines.append(f"  anchors: {', '.join(anchors_for_P)}")

    # Candidate set summary
    lines.append("\nCANDIDATES:")
    all_cands = [B0]+candidates
    for c in all_cands:
        reason = []
        if c is chosen_broad:
            reason.append("chosen_broad")
        elif c["name"].startswith("B"):
            if not (BROAD_OK[0] <= c["total"] <= BROAD_OK[1]): reason.append("count_out_of_window")
            if c is not chosen_broad: reason.append("lower_rq")
        else:
            reason.append("-")
        lines.append(f"  {c['name']}: hits={c['total']} {fmt_stats(c['stats'])}  reason={','.join(reason)}")

    # Chosen BROAD
    lines.append("\nCHOSEN BROAD")
    lines.append(f"  Query: {chosen_broad['query']}")
    lines.append(f"  Hits={chosen_broad['total']}  in_window={BROAD_OK[0] <= chosen_broad['total'] <= BROAD_OK[1]}  target_window={BROAD_TARGET[0] <= chosen_broad['total'] <= BROAD_TARGET[1]}")
    lines.append(f"  Signals: {fmt_stats(chosen_broad['stats'])}")
    lines.append(f"  Quick triage (Gemma): include={sanity_broad['include']} borderline={sanity_broad['borderline']} exclude={sanity_broad['exclude']} (n={sanity_broad['n']})")
    lines.append("  Top titles:")
    lines.append(top_titles(chosen_broad["sample"]))

    # Chosen FOCUSED
    lines.append("\nFOCUSED (BROAD ∧ trial filters)")
    lines.append(f"  Query: {chosen_focused['query']}")
    lines.append(f"  Hits={chosen_focused['total']}  in_window={FOCUSED_OK[0] <= chosen_focused['total'] <= FOCUSED_OK[1]}  target_window={FOCUSED_TARGET[0] <= chosen_focused['total'] <= FOCUSED_TARGET[1]}")
    lines.append(f"  Signals: {fmt_stats(chosen_focused['stats'])}")
    lines.append(f"  Quick triage (Gemma): include={sanity_focused['include']} borderline={sanity_focused['borderline']} exclude={sanity_focused['exclude']} (n={sanity_focused['n']})")
    lines.append("  Top titles:")
    lines.append(top_titles(chosen_focused["sample"]))

    # MeSH mined highlights
    kept_tokens = list(dict.fromkeys(P_core + I_core))[:10]
    lines.append("\nMESH mined (top kept tokens):")
    lines.append("  kept  : " + ", ".join(kept_tokens))
    # extras
    y = chosen_broad["stats"]["years"]; langs = chosen_broad["stats"]["languages"]; pts = chosen_broad["stats"]["ptypes"]
    lines.append("\nSAMPLE DISTRIBUTIONS (broad sample):")
    lines.append(f"  years: median={y.get('median')} range=({y.get('min')},{y.get('max')})")
    lines.append(f"  languages: {', '.join(f'{k}:{v}' for k,v in langs.items()) or '(none)'}")
    lines.append(f"  pub types: {', '.join(f'{k}:{v}' for k,v in pts.items()) or '(none)'}")

    lines.append("\nProtocol brief (used for Gemma triage):")
    lines.append("  " + protocol_txt.replace("\n", "\n  "))

    lines.append("\n================== END OF REPORT =====================")
    report = "\n".join(lines)
    print(report)

    (OUT_DIR/"sniff_report.json").write_text(json.dumps({
        "report_text": report,
        "chosen_broad": {"query": chosen_broad["query"], "hits": chosen_broad["total"]},
        "chosen_focused": {"query": chosen_focused["query"], "hits": chosen_focused["total"]},
    }, indent=2, ensure_ascii=False), encoding="utf-8")

# ----------------------------
# RUN: put your NLQ here
# ----------------------------
USER_NLQ = """Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE). Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE (the intervention of interest is INC, not the surgery). Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, or systemic multimodal analgesia. Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days. Study designs = RCTs preferred; if RCTs absent, include comparative cohort/case-control. Year_min = 2015. Languages = English, Portuguese, Spanish."""
sniff_nlq(USER_NLQ, year_min=YEAR_MIN_DEFAULT, languages=["English","Portuguese","Spanish"])


13:17:51  B0_seed probing...
13:20:28  B1_broad hits=176 rq=3.1 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.2, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 2.2, 'years': {'median': 2025, 'min': 2025, 'max': 2025}, 'languages': {'eng': 5}, 'ptypes': {'Journal Article': 5, 'Review': 2}, 'top_mesh': ['Humans', 'Funnel Chest', 'Pain Management', 'Pain, Postoperative', 'Perioperative Care']}
13:20:30  B2_I_tight hits=176 rq=3.1 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.2, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 2.2, 'years': {'median': 2025, 'min': 2025, 'max': 2025}, 'languages': {'eng': 5}, 'ptypes': {'Journal Article': 5, 'Review': 2}, 'top_mesh': ['Humans', 'Funnel Chest', 'Pain Management', 'Pain, Postoperative', 'Perioperative Care']}
13:20:37  B3_P_core hits=118 rq=3.1 stats={'n_sample': 5, 'pi_rate': 1.0, 'outcome_rate': 0.2, 'design_rate': 0.0, 'median_score': 2.0, 'mean_score': 2.2, 'years': {'median': 2025, 'min': 2025, 'max': 20

In [27]:
# SNIPPET: Sniff Validation Engine (state-machine refactor)
# Goal: rigorous, resilient validation of evidence + search strategy feasibility
# - Single "Universe Query" + validated Recommended Filters (topic/design/lang)
# - State machine with remediation loops (no brittle hard-fails)
# - Qwen: protocol lockdown, scope remediation, strategy remediation
# - Gemma/Qwen-small: strict checklist screener for ground-truth discovery
# - Idle model eviction for LM Studio (best-effort) after configurable idle secs
#
# CONFIGURE via env:
#   LMSTUDIO_BASE = http://127.0.0.1:1234
#   QWEN_MODEL    = qwen/qwen3-4b
#   SCREENER_MODEL= gemma-3n-e2b-it  (or any small instruction-tuned model)
#   ENTREZ_EMAIL  = you@example.com
#   ENTREZ_API_KEY= ... (optional)
#   HTTP_TIMEOUT  = 300   (seconds)
#   LM_EVICT_ENABLED     = 1
#   LM_IDLE_EVICT_SECS   = 5
#   SYSTEM_KB_PATH       = system_knowledge_base.json (or /mnt/data/...)
#
# USAGE:
#   1) Put your NLQ into USER_NLQ at bottom. Optionally place a knowledge-base JSON.
#   2) Run this cell. Check sniff_poc_out/sniff_report.txt + sniff_artifacts.json.
#
# IMPORTANT:
#   - We DO NOT set 'max_tokens' anywhere (no truncation).
#   - We DO enforce strict JSON fences (BEGIN_JSON/END_JSON) for robustness.

import os, json, time, re, pathlib, threading, textwrap, random
from collections import Counter, defaultdict
import requests
from xml.etree import ElementTree as ET

# ----------------------------
# Config & Paths
# ----------------------------
LMSTUDIO_BASE = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234")
QWEN_MODEL    = os.getenv("QWEN_MODEL", "unsloth/qwen3-4b")
SCREENER_MODEL= os.getenv("SCREENER_MODEL", "gemma-3n-e2b-it@q8_0")

ENTREZ_EMAIL   = os.getenv("ENTREZ_EMAIL", "you@example.com")
ENTREZ_API_KEY = os.getenv("ENTREZ_API_KEY", "")
HTTP_TIMEOUT   = int(os.getenv("HTTP_TIMEOUT", "3000"))

LM_EVICT_ENABLED   = os.getenv("LM_EVICT_ENABLED", "1") == "1"
LM_IDLE_EVICT_SECS = int(os.getenv("LM_IDLE_EVICT_SECS", "5"))

SYSTEM_KB_PATH = os.getenv("SYSTEM_KB_PATH", "system_knowledge_base.json")

OUT_DIR = pathlib.Path("sniff_poc_out")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Universe windows (adjustable)
UNIVERSE_MIN = int(os.getenv("UNIVERSE_MIN", "50"))
UNIVERSE_MAX = int(os.getenv("UNIVERSE_MAX", "10000"))

# Validation windows for strategy
STRAT_MIN = int(os.getenv("STRAT_MIN", "10"))
STRAT_MAX = int(os.getenv("STRAT_MAX", "2000"))

# Numbers for GT discovery
GT_FETCH_N   = int(os.getenv("GT_FETCH_N", "30"))
GT_REQUIRE_N = int(os.getenv("GT_REQUIRE_N", "3"))

# ----------------------------
# LM Studio Client with Idle Eviction
# ----------------------------
class LMStudioClient:
    def __init__(self, base):
        self.base = base.rstrip("/")
        self._timers = {}  # model -> Timer

    def _schedule_unload(self, model: str):
        if not LM_EVICT_ENABLED:
            return
        # Cancel prior timer (if any)
        t = self._timers.get(model)
        if t and t.is_alive():
            t.cancel()
        timer = threading.Timer(LM_IDLE_EVICT_SECS, self._try_unload, args=(model,))
        self._timers[model] = timer
        timer.daemon = True
        timer.start()

    def _try_unload(self, model: str):
        # Best-effort. Try a few plausible endpoints. Ignore failures.
        endpoints = [
            f"{self.base}/v1/models/unload",
            f"{self.base}/v1/unload",
            f"{self.base}/v1/models/{model}/unload"
        ]
        payloads = [
            {"model": model},
            {"model": model},
            {}
        ]
        for url, payload in zip(endpoints, payloads):
            try:
                r = requests.post(url, json=payload, timeout=10)
                if r.status_code in (200, 204):
                    break
            except Exception:
                pass

    def chat(self, model: str, system: str, user: str, temperature=0.0, stop=None, response_format=None):
        url = f"{self.base}/v1/chat/completions"
        body = {
            "model": model,
            "messages": [{"role": "system", "content": system}, {"role": "user", "content": user}],
            "temperature": float(temperature),
            "stream": False
        }
        if stop is not None:
            body["stop"] = stop
        if response_format is not None:
            body["response_format"] = response_format
        r = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
        r.raise_for_status()
        out = r.json()["choices"][0]["message"]["content"]
        # schedule idle unload after this use
        self._schedule_unload(model)
        return out

LM = LMStudioClient(LMSTUDIO_BASE)

# ----------------------------
# Strict JSON extraction helpers
# ----------------------------
_BEGIN = re.compile(r"BEGIN_JSON\s*", re.I)
_END   = re.compile(r"\s*END_JSON", re.I)
FENCE  = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.I)

def _sanitize_json_str(s: str) -> str:
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2018","'").replace("\u2019","'")
    s = re.sub(r",\s*(\}|\])", r"\1", s)
    return s.strip()

def extract_json_block_or_fence(txt: str) -> str:
    blocks = []
    pos=0
    while True:
        m1 = _BEGIN.search(txt, pos)
        if not m1: break
        m2 = _END.search(txt, m1.end())
        if not m2: break
        blocks.append(txt[m1.end():m2.start()])
        pos = m2.end()
    if blocks:
        return _sanitize_json_str(blocks[-1])

    fences = FENCE.findall(txt)
    if fences:
        return _sanitize_json_str(fences[-1])

    # last {...} object if present
    s = txt
    last_obj = None
    stack = 0; start = None
    for i,ch in enumerate(s):
        if ch == '{':
            if stack == 0: start = i
            stack += 1
        elif ch == '}':
            if stack > 0:
                stack -= 1
                if stack == 0 and start is not None:
                    last_obj = s[start:i+1]
    if last_obj:
        return _sanitize_json_str(last_obj)
    raise ValueError("No JSON-like content found")

# Robust "ask for JSON" with repair fallback
REPAIR_SYSTEM = "You repair malformed JSON to exactly match the given template keys. Return ONLY one JSON object between BEGIN_JSON/END_JSON."

def repair_user(template_json: str, bad_output: str) -> str:
    return f"""TEMPLATE_JSON:
{template_json}

BAD_OUTPUT:
{bad_output}

TASK: Output valid JSON matching TEMPLATE_JSON keys (fill missing with empty arrays/strings). No prose.

BEGIN_JSON
{{}}
END_JSON
"""

def ask_json(model: str, system: str, user: str, template: dict, stop_at_end=True):
    rules = "Return ONLY one JSON object. No analysis, no notes. Wrap EXACTLY with:\nBEGIN_JSON\n{...}\nEND_JSON"
    user_full = f"{user}\n\n{rules}"
    raw = LM.chat(model, system, user_full, temperature=0.0, stop=["END_JSON"] if stop_at_end else None)
    try:
        return json.loads(extract_json_block_or_fence(raw))
    except Exception:
        repaired = LM.chat(
            model,
            REPAIR_SYSTEM,
            repair_user(json.dumps(template, ensure_ascii=False, indent=2), raw) + "\n\n" + rules,
            temperature=0.0,
            stop=["END_JSON"] if stop_at_end else None
        )
        return json.loads(extract_json_block_or_fence(repaired))

# ----------------------------
# PubMed E-utilities (esearch/efetch)
# ----------------------------
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
HEADERS = {"User-Agent": "sniff-validation-engine/0.2 (+local)", "Accept": "application/json"}

def esearch_ids(term: str, mindate: int|None, retmax=100, retstart=0, usehistory=True):
    params = {
        "db":"pubmed","retmode":"json","term":term,"retmax":retmax,"retstart":retstart,
        "email":ENTREZ_EMAIL
    }
    if usehistory:
        params["usehistory"]="y"
    if ENTREZ_API_KEY:
        params["api_key"]=ENTREZ_API_KEY
    if mindate:
        params["mindate"]=str(mindate)
    r = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    js = r.json().get("esearchresult", {})
    count = int(js.get("count","0"))
    ids = js.get("idlist", []) or []
    webenv = js.get("webenv")
    qk = js.get("querykey")
    return count, [str(x) for x in ids], webenv, qk

def esearch_all_ids(term: str, mindate: int|None, limit=5000):
    # Fetch up to 'limit' ids using WebEnv
    count, _, webenv, qk = esearch_ids(term, mindate, retmax=0, retstart=0, usehistory=True)
    if not count or not webenv or not qk:
        return 0, []
    out=[]
    start=0
    while start < min(count, limit):
        r2 = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params={
            "db":"pubmed","retmode":"json","retmax":min(500, limit-start),"retstart":start,"email":ENTREZ_EMAIL,
            "WebEnv":webenv,"query_key":qk, **({"api_key":ENTREZ_API_KEY} if ENTREZ_API_KEY else {})
        }, timeout=HTTP_TIMEOUT)
        r2.raise_for_status()
        ids = r2.json().get("esearchresult",{}).get("idlist",[])
        if not ids: break
        out.extend([str(x) for x in ids])
        start += len(ids)
    return count, out

def efetch_xml(pmids):
    if not pmids: return ""
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(pmids),"email":ENTREZ_EMAIL}
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/efetch.fcgi", headers={"User-Agent": "sniff-validation-engine/0.2"}, params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.text

def parse_pubmed_xml(xml_text: str):
    out = []
    if not xml_text.strip(): return out
    root = ET.fromstring(xml_text)
    def _join(node):
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = art.findtext(".//PMID") or ""
        title = _join(art.find(".//ArticleTitle")).strip()
        abs_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(_join(n).strip() for n in abs_nodes) if abs_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\d{4}", s)
                if m: year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh = [mh.findtext("./DescriptorName") for mh in art.findall(".//MeshHeadingList/MeshHeading") if mh.findtext("./DescriptorName")]
        out.append({
            "pmid": pmid, "title": title, "abstract": abstract, "year": year, "language": lang,
            "publication_types": pubtypes, "mesh": mesh
        })
    return out

# ----------------------------
# Knowledge Base (KB) loading
# ----------------------------
DEFAULT_KB = {
    "publication_types": [
        "Randomized Controlled Trial","Clinical Trial","Controlled Clinical Trial",
        "Comparative Study","Cohort Studies","Case-Control Studies","Observational Study",
        "Systematic Review","Meta-Analysis","Network Meta-Analysis"
    ],
    "languages": ["english","portuguese","spanish","french","german","italian","chinese","japanese","korean"],
    "design_precedence": ["Randomized Controlled Trial","Controlled Clinical Trial","Clinical Trial","Comparative Study","Cohort Studies","Case-Control Studies"],
    "mesh_topic_whitelist": [],  # optional allow-list
}

def load_system_kb():
    # Try explicit path, else /mnt/data, else defaults
    p = pathlib.Path(SYSTEM_KB_PATH)
    if not p.exists():
        alt = pathlib.Path("/mnt/data/system_knowledge_base.json")
        if alt.exists():
            p = alt
    if p.exists():
        try:
            return json.loads(p.read_text(encoding="utf-8"))
        except Exception:
            pass
    return DEFAULT_KB

SYSTEM_KB = load_system_kb()

# ----------------------------
# Utility: query builders
# ----------------------------
def or_block(terms, field="tiab"):
    toks=[]
    for t in terms or []:
        t=t.strip()
        if not t: continue
        if " " in t or "-" in t:
            toks.append(f"\"{t}\"[{field}]")
        else:
            toks.append(f"{t}[{field}]")
    if not toks: return ""
    return "(" + " OR ".join(toks) + ")"

def or_mesh(terms):
    toks=[]
    for t in terms or []:
        t=t.strip()
        if not t: continue
        toks.append(f"\"{t}\"[MeSH Terms]")
    if not toks: return ""
    return "(" + " OR ".join(toks) + ")"

def and_join(parts):
    parts=[p for p in parts if p and p.strip()]
    if not parts: return ""
    return " AND ".join(f"({p})" if " OR " in p or " AND " in p else p for p in parts)

def lang_filter(langs):
    if not langs: return ""
    # PubMed supports lang filter via language field tags like english[lang]
    return "(" + " OR ".join(f"\"{l}\"[lang]" for l in langs) + ")"

# ----------------------------
# STATE 1: Protocol Lockdown
# ----------------------------
PROTO_SYSTEM = "You convert NLQs into a strict, compact protocol using only allowed values for controlled fields. Return strict JSON only."

PROTO_TEMPLATE = {
    "research_question": "",
    "population_terms": [],
    "intervention_terms": [],
    "comparator_terms": [],
    "outcome_terms": [],
    "anchors": [],
    "avoid": [],
    "languages": [],
    "year_min": 2015,
    "designs_preference": [],   # subset of KB.design_precedence or KB.publication_types
    "adult_only": False
}

def proto_user(nlq: str, kb: dict):
    return f"""Natural-language question:
<<<
{nlq}
>>>

Your job:
1) Parse into a compact protocol object.
2) For 'designs_preference' choose ONLY from these allowed values (preserve order where applicable):
   {kb.get("design_precedence", [])}
3) For 'languages' choose ONLY from:
   {kb.get("languages", [])}

Rules:
- Keep each term string simple (no boolean operators, quotes, or field tags).
- 'anchors' should be 2–6 tokens that must appear to ensure topicality (e.g., MIRPE, Nuss, cryoablation).
- If adult-only is implied, set adult_only=true.
- If year_min is stated, use it; else leave default.

Return:

BEGIN_JSON
{{
  "research_question": "",
  "population_terms": [],
  "intervention_terms": [],
  "comparator_terms": [],
  "outcome_terms": [],
  "anchors": [],
  "avoid": [],
  "languages": [],
  "year_min": 2015,
  "designs_preference": [],
  "adult_only": false
}}
END_JSON
"""

def state1_protocol_lockdown(nlq: str):
    js = ask_json(QWEN_MODEL, PROTO_SYSTEM, proto_user(nlq, SYSTEM_KB), PROTO_TEMPLATE)
    # Clip controlled fields to KB just in case
    js["languages"] = [l for l in js.get("languages", []) if l in SYSTEM_KB.get("languages", [])]
    pref = js.get("designs_preference", [])
    allowed = SYSTEM_KB.get("design_precedence", SYSTEM_KB.get("publication_types", []))
    js["designs_preference"] = [d for d in pref if d in allowed]
    return js

# ----------------------------
# STATE 2: Universe Definition & Sizing
# ----------------------------
def build_universe_query(protocol: dict):
    P = or_block(protocol.get("population_terms", []), "tiab")
    I = or_block(protocol.get("intervention_terms", []), "tiab")
    parts = [P, I]
    # Adult bias (soft) as lexical cue only if adult_only flagged
    if protocol.get("adult_only"):
        parts.append(or_block(["adult","adults"], "tiab"))
    return and_join(parts)

REMEDIATION_SCOPE_TEMPLATE = {"action":"KEEP|WIDEN|NARROW","add_population_terms":[],"add_intervention_terms":[],"enforce_anchors":[]}

REMEDIATION_SCOPE_SYSTEM = "You propose surgical lexical scope fixes for a PubMed TIAB-only query. Return strict JSON only."

def remediate_scope_user(query: str, count: int, protocol: dict, window: tuple[int,int]):
    low, high = window
    status = "too_narrow" if count < low else "too_broad"
    return f"""The current universe TIAB-only query (P AND I) is {status}.
Target window: {low}..{high}. Current count={count}.

QUERY:
{query}

Protocol snapshot:
population_terms = {protocol.get("population_terms", [])}
intervention_terms = {protocol.get("intervention_terms", [])}
anchors = {protocol.get("anchors", [])}

Propose ONE fix:
- If too_narrow: suggest 'WIDEN' with 1-4 broader/popular synonyms to ADD to P and/or I (no quotes, no field tags).
- If too_broad: suggest 'NARROW' with 1-3 'enforce_anchors' tokens (from anchors or near equivalents) to AND onto the query.

Return:

BEGIN_JSON
{{"action":"KEEP|WIDEN|NARROW","add_population_terms":[],"add_intervention_terms":[],"enforce_anchors":[]}}
END_JSON
"""

def state2_universe(protocol: dict, attempts=2):
    window = (UNIVERSE_MIN, UNIVERSE_MAX)
    uq = build_universe_query(protocol)
    if not uq:
        raise SystemExit("FATAL: cannot build Universe Query (missing P or I terms).")
    for i in range(attempts+1):
        cnt, _ = esearch_all_ids(uq, protocol.get("year_min"), limit=0)
        print(time.strftime('%H:%M:%S'), f"  [Universe] try={i} count={cnt} window={window} query={uq}")
        if window[0] <= cnt <= window[1]:
            return uq, cnt, []
        # Ask LLM to remediate scope
        remed = ask_json(QWEN_MODEL, REMEDIATION_SCOPE_SYSTEM,
                         remediate_scope_user(uq, cnt, protocol, window),
                         REMEDIATION_SCOPE_TEMPLATE)
        if remed.get("action") == "WIDEN" and cnt < window[0]:
            protocol["population_terms"] = list(dict.fromkeys(protocol.get("population_terms", []) + remed.get("add_population_terms", [])))
            protocol["intervention_terms"] = list(dict.fromkeys(protocol.get("intervention_terms", []) + remed.get("add_intervention_terms", [])))
            uq = build_universe_query(protocol)
        elif remed.get("action") == "NARROW" and cnt > window[1]:
            anchors = remed.get("enforce_anchors", []) or protocol.get("anchors", [])
            if anchors:
                uq = and_join([uq, or_block(anchors, "tiab")])
            else:
                break
        else:
            break
    # last resort: return whatever we have with warning
    return uq, cnt, [f"WARNING: Universe count {cnt} outside target window {window}. Proceeding anyway."]

# ----------------------------
# STATE 3: Ground Truth Discovery & Protocol Validation
# ----------------------------
SCREENER_TEMPLATE = {
  "pmid":"", "checklist":{"population":"N","intervention":"N","outcome":"N","design":"N"},
  "include":"N","why":"", "mesh_roles":{"P":[],"I":[],"C":[],"O":[]}
}

SCREENER_SYSTEM = "You are a strict protocol checklist screener. Return JSON only. No explanations unless asked."

def screener_user(protocol: dict, record: dict):
    # Concrete, minimal checklist. INCLUDE only if ALL are 'Y'.
    return f"""Protocol (locked):
- Population scope (lexical cues): {protocol.get("population_terms", [])}
- Index intervention (lexical cues): {protocol.get("intervention_terms", [])}
- Outcomes of interest (lexical cues): {protocol.get("outcome_terms", [])}
- Designs preference (for 'design' check): {protocol.get("designs_preference", [])}
- Adults only: {protocol.get("adult_only", False)}

Article:
PMID: {record.get('pmid')}
Year: {record.get('year')}
Lang: {record.get('language')}
PubTypes: {record.get('publication_types')}
Title: {record.get('title')}
Abstract: {record.get('abstract')}

Checklist (answer with 'Y' or 'N'):
- Does the title/abstract clearly indicate the POPULATION/procedure context matches? (population)
- Does it clearly include the INDEX INTERVENTION? (intervention)
- Does it clearly report or promise our target OUTCOMES (pain scores, opioid consumption/requirements, etc.)? (outcome)
- Does it clearly meet a preferred DESIGN category (e.g., randomized/controlled/comparative/clinical trial)? (design)

INCLUDE if and only if ALL four are 'Y'. If INCLUDE, classify any useful MeSH descriptors (from the article text or common MeSH for this topic) into the buckets P/I/C/O.

Return:

BEGIN_JSON
{{
  "pmid": "{record.get('pmid')}",
  "checklist": {{"population":"N","intervention":"N","outcome":"N","design":"N"}},
  "include": "N",
  "why": "",
  "mesh_roles": {{"P":[],"I":[],"C":[],"O":[]}}
}}
END_JSON
"""

def accumulate_distributions(recs):
    years=[r.get("year") for r in recs if r.get("year")]
    langs=[r.get("language") for r in recs if r.get("language")]
    ptypes=[pt for r in recs for pt in (r.get("publication_types") or [])]
    dist = {
        "years": {"median": int(sorted(years)[len(years)//2]) if years else None,
                  "min": min(years) if years else None,
                  "max": max(years) if years else None},
        "languages": dict(Counter(langs)),
        "ptypes": dict(Counter(ptypes))
    }
    return dist

def state3_ground_truth(universe_query: str, protocol: dict):
    # Get first GT_FETCH_N results
    cnt, ids, webenv, qk = esearch_ids(universe_query, protocol.get("year_min"), retmax=GT_FETCH_N, retstart=0, usehistory=True)
    xml = efetch_xml(ids)
    recs = parse_pubmed_xml(xml)
    # Strict screener
    includes=[]; mesh_roles={"P":[],"I":[],"C":[],"O":[]}
    for r in recs:
        js = ask_json(SCREENER_MODEL, SCREENER_SYSTEM, screener_user(protocol, r), SCREENER_TEMPLATE)
        if (js.get("include","N") == "Y"):
            includes.append(r["pmid"])
            for k in ("P","I","C","O"):
                mesh_roles[k].extend([t for t in js.get("mesh_roles",{}).get(k,[]) if t])
    # Dedup & small cleanup
    mesh_roles = {k: list(dict.fromkeys(v)) for k,v in mesh_roles.items()}
    warnings=[]
    if len(includes) < GT_REQUIRE_N:
        raise SystemExit(f"FATAL: insufficient ground-truth includes ({len(includes)}/{GT_REQUIRE_N}). Universe likely off-topic or too weak.")
    # Feasibility check on MeSH vernaculum
    if not mesh_roles.get("P"):
        warnings.append("CRITICAL_WARNING: No P MeSH discovered; will fallback to lexical population terms.")
    if not mesh_roles.get("I"):
        warnings.append("CRITICAL_WARNING: No I MeSH discovered; will fallback to lexical intervention terms.")
    if not mesh_roles.get("C"):
        warnings.append("STANDARD_WARNING: No comparator MeSH discovered; lexical fallback.")
    if not mesh_roles.get("O"):
        warnings.append("STANDARD_WARNING: No outcome MeSH discovered; lexical fallback.")
    dist = accumulate_distributions(recs)
    return includes, mesh_roles, dist, warnings

# ----------------------------
# STATE 4: Strategy Validation & Refinement
# ----------------------------
def best_design_filter(protocol: dict, kb: dict):
    # Deterministic: choose first available from designs_preference; map to [pt] terms
    precedence = protocol.get("designs_preference") or kb.get("design_precedence", [])
    if not precedence:
        return ""
    # Build OR of acceptable [pt] tags from top preference down to 2 levels
    chosen = precedence[:2] if len(precedence)>=2 else precedence
    return "(" + " OR ".join(f"\"{pt}\"[Publication Type]" for pt in chosen) + ")"

def build_topic_filter(mesh_roles: dict, protocol: dict):
    # Prefer MeSH P & I; fallback to lexical if missing
    Pm = mesh_roles.get("P", [])
    Im = mesh_roles.get("I", [])
    P = or_mesh(Pm) if Pm else or_block(protocol.get("population_terms", []), "tiab")
    I = or_mesh(Im) if Im else or_block(protocol.get("intervention_terms", []), "tiab")
    if not P or not I:
        return ""
    return and_join([P, I])

REMEDIATION_STRAT_TEMPLATE = {"op":"DROP_TERM|ADD_ANCHOR|BROADEN_DESIGN_FILTER","term":"","where":"P|I|ANCHOR|DESIGN"}

REMEDIATION_STRAT_SYSTEM = "You propose ONE small surgical fix to a PubMed strategy that must pass recall & precision gates. Return strict JSON only."

def remediate_strategy_user(universe_query: str, topic_filter: str, design_filter: str, anchors: list, failed: dict):
    return f"""A validation just failed for the strategy below.

UNIVERSE_QUERY:
{universe_query}

TOPIC_FILTER:
{topic_filter}

DESIGN_FILTER:
{design_filter}

ANCHORS:
{anchors}

Failure snapshot:
{failed}

You must propose exactly ONE of:
- DROP_TERM (remove a single overly-specific token from topic filter; indicate 'term' and 'where': 'P' or 'I')
- ADD_ANCHOR (add one topical anchor term to be AND-ed to the universe query; 'where'='ANCHOR')
- BROADEN_DESIGN_FILTER (loosen the design filter; 'where'='DESIGN'; 'term' can be an additional publication type)

Return:

BEGIN_JSON
{{"op":"DROP_TERM|ADD_ANCHOR|BROADEN_DESIGN_FILTER","term":"","where":"P|I|ANCHOR|DESIGN"}}
END_JSON
"""

def state4_validate_strategy(universe_query: str, ground_truth_pmids: list, protocol: dict, mesh_roles: dict, attempts=3):
    topic = build_topic_filter(mesh_roles, protocol)
    design = best_design_filter(protocol, SYSTEM_KB)
    langs = lang_filter(protocol.get("languages", []))
    def combine(uq, topic, design, langs):
        parts=[uq, topic, design, langs]
        parts=[p for p in parts if p]
        return and_join(parts)
    strategy = combine(universe_query, topic, design, langs)

    for i in range(attempts+1):
        total, ids = esearch_all_ids(strategy, protocol.get("year_min"), limit=5000)
        # Check recall: all GT must be in results
        recall_ok = set(ground_truth_pmids).issubset(set(ids))
        precision_ok = STRAT_MIN <= total <= STRAT_MAX
        print(time.strftime('%H:%M:%S'), f"  [Strategy] try={i} total={total} recall_ok={recall_ok} precision_ok={precision_ok}")
        if recall_ok and precision_ok:
            return {"topic_filter": topic, "design_filter": design, "language_filter": langs, "final_query": strategy, "total": total}
        failed = {"total": total, "recall_ok": recall_ok, "precision_ok": precision_ok}
        remed = ask_json(QWEN_MODEL, REMEDIATION_STRAT_SYSTEM,
                         remediate_strategy_user(universe_query, topic, design, protocol.get("anchors", []), failed),
                         REMEDIATION_STRAT_TEMPLATE)
        op = remed.get("op","")
        term = (remed.get("term") or "").strip()
        where = remed.get("where","")

        # Execute remediation
        if op == "DROP_TERM" and term and where in ("P","I"):
            # Remove a token from topic filter by lexical fallback: rebuild topic with lexicals, drop term
            if where=="P":
                if term in mesh_roles.get("P", []):
                    mesh_roles["P"] = [t for t in mesh_roles["P"] if t != term]
                # also drop from lexical P if present
                protocol["population_terms"] = [t for t in protocol.get("population_terms", []) if t != term]
            else:
                if term in mesh_roles.get("I", []):
                    mesh_roles["I"] = [t for t in mesh_roles["I"] if t != term]
                protocol["intervention_terms"] = [t for t in protocol.get("intervention_terms", []) if t != term]
            topic = build_topic_filter(mesh_roles, protocol)
        elif op == "ADD_ANCHOR" and term:
            # AND anchor into universe query
            uq_anchor = or_block([term], "tiab")
            universe_query = and_join([universe_query, uq_anchor])
        elif op == "BROADEN_DESIGN_FILTER":
            # Add another ptype OR fallback to comparative set
            add = term if term else "Comparative Study"
            design = and_join([design, f"\"{add}\"[Publication Type]"]) if design else f"\"{add}\"[Publication Type]"
        else:
            # if remediation invalid, try a simple deterministic fallback: drop design filter entirely
            design = ""
        strategy = combine(universe_query, topic, design, langs)

    raise SystemExit("FATAL: Strategy validation failed after remediation attempts.")

# ----------------------------
# STATE 5: Finalization & Handoff
# ----------------------------
EMBED_SYSTEM = "You write a single-sentence research question string optimized for embedding search. Return JSON only."
EMBED_TEMPLATE = {"embedding_query": ""}

def embed_user(protocol: dict, mesh_roles: dict):
    return f"""Generate one compact sentence (<= 30 words) suitable for vector embeddings that captures the finalized protocol and key vocabulary.

Protocol brief:
- Population: {protocol.get("population_terms", [])}
- Intervention: {protocol.get("intervention_terms", [])}
- Outcomes: {protocol.get("outcome_terms", [])}
- Anchors: {protocol.get("anchors", [])}
- Languages: {protocol.get("languages", [])}
- Year_min: {protocol.get("year_min")}
- Designs_preference: {protocol.get("designs_preference", [])}

MeSH vernaculum (kept):
- P: {mesh_roles.get("P", [])}
- I: {mesh_roles.get("I", [])}
- C: {mesh_roles.get("C", [])}
- O: {mesh_roles.get("O", [])}

Return:

BEGIN_JSON
{{"embedding_query": ""}}
END_JSON
"""

def write_report(out_path: pathlib.Path, nlq: str, protocol: dict, universe_query: str, universe_count: int,
                 gt_pmids: list, mesh_roles: dict, dist: dict, warnings: list, strategy: dict):
    lines=[]
    lines.append("==================== SNIFF VALIDATION REPORT ====================")
    lines.append("NLQ:")
    lines.append("  " + (nlq.strip().replace("\n"," ")[:160] + ("..." if len(nlq.strip())>160 else "")))
    lines.append("")
    lines.append("LOCKED PROTOCOL")
    lines.append(f"  languages: {protocol.get('languages', [])}   year_min: {protocol.get('year_min')}")
    lines.append(f"  designs_preference: {protocol.get('designs_preference', [])}   adult_only: {protocol.get('adult_only', False)}")
    lines.append(f"  P terms: {protocol.get('population_terms', [])}")
    lines.append(f"  I terms: {protocol.get('intervention_terms', [])}")
    lines.append(f"  C terms: {protocol.get('comparator_terms', [])}")
    lines.append(f"  O terms: {protocol.get('outcome_terms', [])}")
    lines.append(f"  anchors: {protocol.get('anchors', [])}")
    lines.append("")
    lines.append("UNIVERSE")
    lines.append(f"  query: {universe_query}")
    lines.append(f"  count: {universe_count}  window_ok: {UNIVERSE_MIN <= universe_count <= UNIVERSE_MAX}")
    lines.append("")
    lines.append("GROUND TRUTH (strict checklist includes)")
    lines.append(f"  n_includes: {len(gt_pmids)}  pmids: {gt_pmids}")
    lines.append(f"  MeSH vernaculum: P={mesh_roles.get('P', [])}  I={mesh_roles.get('I', [])}  C={mesh_roles.get('C', [])}  O={mesh_roles.get('O', [])}")
    lines.append(f"  Sample distributions: years={dist.get('years')}  languages={dist.get('languages')}  pubtypes={dist.get('ptypes')}")
    if warnings:
        lines.append("  WARNINGS:")
        for w in warnings:
            lines.append(f"   - {w}")
    lines.append("")
    lines.append("VALIDATED STRATEGY & FILTERS")
    lines.append(f"  topic_filter: {strategy.get('topic_filter')}")
    lines.append(f"  design_filter: {strategy.get('design_filter')}")
    lines.append(f"  language_filter: {strategy.get('language_filter')}")
    lines.append(f"  final_query: {strategy.get('final_query')}")
    lines.append(f"  total_hits: {strategy.get('total')}")
    lines.append("")
    lines.append("================== END OF REPORT =====================")
    out_path.write_text("\n".join(lines), encoding="utf-8")

# ----------------------------
# Top-level: Sniff Validation Engine (state machine)
# ----------------------------
def sniff_validate_engine(USER_NLQ: str):
    warnings=[]
    print(time.strftime('%H:%M:%S'), " [S1] Protocol lockdown...")
    protocol = state1_protocol_lockdown(USER_NLQ)

    print(time.strftime('%H:%M:%S'), " [S2] Universe definition & sizing...")
    universe_query, universe_count, w2 = state2_universe(protocol, attempts=2)
    warnings.extend(w2)

    print(time.strftime('%H:%M:%S'), " [S3] Ground-truth discovery & protocol validation...")
    gt_pmids, mesh_roles, dist, w3 = state3_ground_truth(universe_query, protocol)
    warnings.extend(w3)

    print(time.strftime('%H:%M:%S'), " [S4] Search-strategy validation & refinement...")
    strategy = state4_validate_strategy(universe_query, gt_pmids, protocol, mesh_roles, attempts=3)

    print(time.strftime('%H:%M:%S'), " [S5] Finalization & handoff...")
    emb = ask_json(QWEN_MODEL, EMBED_SYSTEM, embed_user(protocol, mesh_roles), EMBED_TEMPLATE)
    embedding_q = emb.get("embedding_query","")

    artifacts = {
        "locked_protocol": protocol,
        "universe_query": universe_query,
        "universe_count": universe_count,
        "recommended_filters": {
            "topic": strategy.get("topic_filter"),
            "design": strategy.get("design_filter"),
            "language": strategy.get("language_filter")
        },
        "final_query": strategy.get("final_query"),
        "final_query_total": strategy.get("total"),
        "ground_truth_pmids": gt_pmids,
        "mesh_vernaculum": mesh_roles,
        "warnings": warnings,
        "research_question_string_for_embedding": embedding_q,
        "system_kb_snapshot": SYSTEM_KB
    }

    # Persist
    (OUT_DIR/"sniff_artifacts.json").write_text(json.dumps(artifacts, indent=2, ensure_ascii=False), encoding="utf-8")
    write_report(OUT_DIR/"sniff_report.txt", USER_NLQ, protocol, universe_query, universe_count, gt_pmids, mesh_roles, dist, warnings, strategy)
    print("Artifacts written:", OUT_DIR)
    print(" - sniff_report.txt")
    print(" - sniff_artifacts.json")

# ----------------------------
# RUN
# ----------------------------
USER_NLQ = """Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE). Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE (the intervention of interest is INC, not the surgery). Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, or systemic multimodal analgesia. Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days. Study designs = RCTs preferred; if RCTs absent, include comparative cohort/case-control. Year_min = 2015. Languages = English, Portuguese, Spanish."""
sniff_validate_engine(USER_NLQ)


16:15:49  [S1] Protocol lockdown...
16:17:01  [S2] Universe definition & sizing...
16:17:02   [Universe] try=0 count=7 window=(50, 10000) query=((adults[tiab] OR "minimally invasive repair of pectus excavatum"[tiab] OR Nuss[tiab])) AND ("intercostal nerve cryoablation"[tiab]) AND ((adult[tiab] OR adults[tiab]))
16:17:21   [Universe] try=1 count=18 window=(50, 10000) query=((adults[tiab] OR "minimally invasive repair of pectus excavatum"[tiab] OR Nuss[tiab] OR "chest wall surgery"[tiab])) AND (("intercostal nerve cryoablation"[tiab] OR "nerve ablation"[tiab])) AND ((adult[tiab] OR adults[tiab]))
16:17:35   [Universe] try=2 count=21 window=(50, 10000) query=((adults[tiab] OR "minimally invasive repair of pectus excavatum"[tiab] OR Nuss[tiab] OR "chest wall surgery"[tiab] OR "thoracic surgery"[tiab])) AND (("intercostal nerve cryoablation"[tiab] OR "nerve ablation"[tiab] OR neuroablation[tiab])) AND ((adult[tiab] OR adults[tiab]))
16:17:55  [S3] Ground-truth discovery & protocol validatio

HTTPError: 400 Client Error: Bad Request for url: http://127.0.0.1:1234/v1/chat/completions

In [29]:
# SNIPPET: Sniff Validation Engine (state-machine PoC, resilient, single cell)
# - Replaces linear "broad/focused" approach with a Universe Query + Validated Filters
# - Adds remediation loops and a senior "plausibility" guardrail to stop GIGO cascades
# - Strict LM Studio model TTL to auto-evict idle models and avoid CPU fallback
# - Verbose, human-readable report printed inline + compact artifacts on disk
#
# CONFIGURE before running:
#   - LM Studio server: http://127.0.0.1:1234 (or set LMSTUDIO_BASE)
#   - QWEN_MODEL (reasoning/senior)  e.g., "qwen/qwen3-4b"
#   - SCREENER_MODEL (fast screener) e.g., "gemma-3n-e2b-it" or a small Qwen
#   - Ensure LM Studio Settings > Developer > "JIT load models" ON and "Auto-evict" ON
#   - This script sets per-request "ttl" (seconds) so models unload ~immediately after use
#
# ENV VARS (optional):
#   LMSTUDIO_BASE, QWEN_MODEL, SCREENER_MODEL
#   ENTREZ_EMAIL, ENTREZ_API_KEY, HTTP_TIMEOUT
#   LM_TTL_SECONDS  (default 5)  <-- per-request idle TTL for model auto-eviction
#
# INPUTS:
#   - USER_NLQ (your natural-language question)
#   - Optional: system_knowledge_base.json (KB of valid values). If missing, a safe default is used.
#
# OUTPUTS (folder: sniff_out/):
#   - sniff_report.txt   : human-readable full log + final validated strategy
#   - sniff_artifacts.json : machine-readable state, protocol, filters, ground-truth, MeSH vernaculum, warnings, embed string
#   - ground_truth.tsv   : quick table of INCLUDEs (pmid, title, decisions)
#
# NOTE:
#   - We DO NOT set max_tokens anywhere (to avoid premature truncation).
#   - We DO set per-request {"ttl": LM_TTL_SECONDS} so LM Studio unloads models after ~idle.
#   - When switching models, we wait (ttl + 0.5s) to let auto-evict complete before loading the next.
#   - PubMed E-utilities are used for counts/records (esearch/efetch); no scraping.

import os, json, time, re, pathlib, textwrap, random
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from typing import List, Dict, Any, Tuple
import requests
from xml.etree import ElementTree as ET

# ----------------------------
# Config & I/O
# ----------------------------
LMSTUDIO_BASE   = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234").rstrip("/")
QWEN_MODEL      = os.getenv("QWEN_MODEL", "qwen3-4b@q6_k")
SCREENER_MODEL  = os.getenv("SCREENER_MODEL", "gemma-3n-e4b-it@q5_k_m")  # can also be a small Qwen
HTTP_TIMEOUT    = int(os.getenv("HTTP_TIMEOUT", "300"))
LM_TTL_SECONDS  = int(os.getenv("LM_TTL_SECONDS", "5"))  # enforce fast auto-evict (5s idle)

ENTREZ_EMAIL    = os.getenv("ENTREZ_EMAIL", "you@example.com")
ENTREZ_API_KEY  = os.getenv("ENTREZ_API_KEY", "")

OUT_DIR = pathlib.Path("sniff_out"); OUT_DIR.mkdir(parents=True, exist_ok=True)

# Universe sizing windows (can be relaxed via KB)
UNIVERSE_TARGET = (50, 10000)
UNIVERSE_HARD_MIN = 25      # below this -> terminate (to avoid polluted tiny "universe")
GROUND_TRUTH_MIN = 3        # minimum unequivocal INCLUDEs required
TOP_FETCH = 30              # how many top hits to fetch for ground-truth discovery
SCREEN_SAMPLE_MAX = 30      # upper bound, we iterate deterministically over TOP_FETCH

# PubMed API
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
HEADERS = {"User-Agent": "sniff-validation-engine/0.3 (+local)", "Accept": "application/json"}

# ----------------------------
# LM Studio helpers (OpenAI-compatible)
# ----------------------------
_last_model_used = {"name": None, "t": 0.0}

def _lm_call(model: str, messages: List[Dict[str, str]], stop: List[str] = None) -> str:
    """
    Raw LM Studio call to /v1/chat/completions with per-request TTL to auto-evict.
    We intentionally DO NOT pass max_tokens (let the model decide).
    """
    # If switching models, give TTL time for auto-evict of the previous one to avoid CPU fallback.
    now = time.time()
    if _last_model_used["name"] and _last_model_used["name"] != model:
        delta = now - _last_model_used["t"]
        wait_need = LM_TTL_SECONDS + 0.5 - delta
        if wait_need > 0:
            print(f"{time.strftime('%H:%M:%S')}  [LM] Switching models: waiting {wait_need:.1f}s for TTL auto-evict...")
            time.sleep(wait_need)

    url = f"{LMSTUDIO_BASE}/v1/chat/completions"
    body = {
        "model": model,
        "messages": messages,
        "temperature": 0.0,
        "stream": False,
        "ttl": LM_TTL_SECONDS
    }
    if stop:
        body["stop"] = stop
    r = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    _last_model_used["name"] = model
    _last_model_used["t"] = time.time()
    return r.json()["choices"][0]["message"]["content"]

_BEGIN = re.compile(r"BEGIN_JSON\s*", re.I)
_END   = re.compile(r"\s*END_JSON", re.I)
FENCE  = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.I)

def _sanitize_json_str(s: str) -> str:
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2018","'").replace("\u2019","'")
    s = re.sub(r",\s*(\}|\])", r"\1", s)
    return s.strip()

def extract_json_like(txt: str) -> str:
    blocks = []
    pos=0
    while True:
        m1 = _BEGIN.search(txt, pos)
        if not m1: break
        m2 = _END.search(txt, m1.end())
        if not m2: break
        blocks.append(txt[m1.end():m2.start()])
        pos = m2.end()
    if blocks:
        return _sanitize_json_str(blocks[-1])
    fences = FENCE.findall(txt)
    if fences:
        return _sanitize_json_str(fences[-1])

    s = txt
    last_obj=None; stack=0; start=None
    for i,ch in enumerate(s):
        if ch=='{':
            if stack==0: start=i
            stack+=1
        elif ch=='}':
            if stack>0:
                stack-=1
                if stack==0 and start is not None:
                    last_obj=s[start:i+1]
    if last_obj:
        return _sanitize_json_str(last_obj)
    raise ValueError("No JSON-like content found")

REPAIR_SYSTEM = "You repair malformed JSON to exactly match template keys. Return ONE JSON object only, wrapped between BEGIN_JSON/END_JSON."
STRICT_JSON_RULES = (
  "Return ONLY one JSON object. No analysis, no preface. "
  "Wrap exactly with:\nBEGIN_JSON\n{...}\nEND_JSON"
)

def ask_json(model: str, system: str, user: str, template: Dict[str,Any], stop_on_end=True) -> Dict[str,Any]:
    try:
        content = _lm_call(model, [{"role":"system","content":system},{"role":"user","content": user + "\n\n" + STRICT_JSON_RULES}], stop=["END_JSON"] if stop_on_end else None)
        return json.loads(extract_json_like(content))
    except Exception as e:
        # Try a repair pass
        rep_user = f"""TEMPLATE_JSON:
{json.dumps(template, ensure_ascii=False, indent=2)}

BAD_OUTPUT:
{content if 'content' in locals() else str(e)}

TASK: Output valid JSON matching TEMPLATE_JSON keys (fill missing with empty arrays/strings). No prose.

BEGIN_JSON
{{}}
END_JSON
"""
        repaired = _lm_call(REPAIR_SYSTEM, [{"role":"system","content":REPAIR_SYSTEM},{"role":"user","content":rep_user + "\n\n" + STRICT_JSON_RULES}], stop=["END_JSON"] if stop_on_end else None)
        return json.loads(extract_json_like(repaired))

# ----------------------------
# Knowledge Base (KB) load
# ----------------------------
DEFAULT_KB = {
    "publication_types": [
        "Randomized Controlled Trial","Clinical Trial","Controlled Clinical Trial",
        "Comparative Study","Cohort Studies","Case-Control Studies","Observational Study",
        "Systematic Review","Meta-Analysis","Network Meta-Analysis"
    ],
    "languages": ["english","portuguese","spanish"],
    "universe_window": {"min": 50, "max": 10000},
    "precision_window": {"min": 5, "max": 2000},
    "design_filters": {
        "RCT": '("Randomized Controlled Trial"[Publication Type])',
        "Comparative": '("Comparative Study"[Publication Type] OR "Case-Control Studies"[Publication Type] OR "Cohort Studies"[Publication Type])',
        "AnyTrial": '("Randomized Controlled Trial"[Publication Type] OR "Clinical Trial"[Publication Type] OR "Controlled Clinical Trial"[Publication Type])'
    }
}

def load_kb(path="system_knowledge_base.json") -> Dict[str,Any]:
    p = pathlib.Path(path)
    if p.exists():
        try:
            js = json.loads(p.read_text(encoding="utf-8"))
            # merge defaults for missing keys
            merged = json.loads(json.dumps(DEFAULT_KB))
            for k,v in js.items():
                merged[k]=v
            return merged
        except Exception:
            return DEFAULT_KB
    return DEFAULT_KB

# ----------------------------
# PubMed helpers
# ----------------------------
def esearch_ids(term: str, retmax=5000) -> Tuple[int, List[str]]:
    p = {
        "db":"pubmed","retmode":"json","term":term,
        "retmax":min(retmax, 5000),
        "email":ENTREZ_EMAIL,"usehistory":"y"
    }
    if ENTREZ_API_KEY: p["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params=p, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    js = r.json().get("esearchresult", {})
    count = int(js.get("count","0"))
    webenv = js.get("webenv"); qk = js.get("querykey")
    ids=[]
    if count and webenv and qk:
        r2 = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params={
            "db":"pubmed","retmode":"json","retmax":min(count, retmax),
            "retstart":0,"email":ENTREZ_EMAIL,"WebEnv":webenv,"query_key":qk,
            **({"api_key":ENTREZ_API_KEY} if ENTREZ_API_KEY else {})
        }, timeout=HTTP_TIMEOUT)
        r2.raise_for_status()
        ids = r2.json().get("esearchresult",{}).get("idlist",[])
    return count, [str(x) for x in ids]

def efetch_xml(pmids: List[str]) -> str:
    if not pmids: return ""
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(pmids),"email":ENTREZ_EMAIL}
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/efetch.fcgi", headers={"User-Agent":"sniff-validation-engine/0.3"}, params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.text

def parse_pubmed_xml(xml_text: str) -> List[Dict[str,Any]]:
    out=[]
    if not xml_text.strip(): return out
    root = ET.fromstring(xml_text)
    def _join(node):
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = art.findtext(".//PMID") or ""
        title = _join(art.find(".//ArticleTitle")).strip()
        abs_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(_join(n).strip() for n in abs_nodes) if abs_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\d{4}", s)
                if m: year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh = [mh.findtext("./DescriptorName") for mh in art.findall(".//MeshHeadingList/MeshHeading") if mh.findtext("./DescriptorName")]
        out.append({"pmid":pmid,"title":title,"abstract":abstract,"year":year,"language":lang,
                    "publication_types":pubtypes,"mesh":mesh})
    return out

# ----------------------------
# Term utilities
# ----------------------------
def or_block(terms: List[str], field="tiab") -> str:
    toks=[]
    for t in terms:
        t=t.strip()
        if not t: continue
        if " " in t or "-" in t:
            toks.append(f"\"{t}\"[{field}]")
        else:
            toks.append(f"{t}[{field}]")
    if not toks: return ""
    return "(" + " OR ".join(toks) + ")"

def build_universe_query(P_terms: List[str], I_terms: List[str], anchors: List[str]) -> str:
    P = or_block(P_terms, "tiab")
    I = or_block(I_terms, "tiab")
    A = or_block(anchors, "tiab") if anchors else ""
    if not P or not I: return ""
    q = f"{P} AND {I}"
    if A: q = f"{q} AND {A}"
    return q

# ----------------------------
# LM prompts
# ----------------------------
PROTOCOL_TEMPLATE = {
    "population_terms": [], "intervention_terms": [], "comparators_terms": [], "outcomes_terms": [],
    "must_have": [], "avoid": [], "designs_preference": "", "languages": [], "year_min": 2015
}

def protocol_system(kb: Dict[str,Any]) -> str:
    return "You convert an NLQ into a strict, small protocol using ONLY allowed values from the provided knowledge base. Return JSON only."

def protocol_user(nlq: str, kb: Dict[str,Any]) -> str:
    return f"""NATURAL LANGUAGE QUESTION (NLQ):
<<<
{nlq}
>>>

KNOWLEDGE BASE (allowed values):
publication_types = {kb["publication_types"]}
languages = {kb["languages"]}
universe_window = {kb.get("universe_window",{})}

TASK:
- Parse the NLQ into a compact protocol with arrays of short phrases for population_terms and intervention_terms (include standard synonyms/acronyms), comparators_terms, outcomes_terms.
- must_have: 2–6 anchors that enforce topicality.
- avoid: 0–6 obvious off-topic terms to avoid.
- designs_preference: Choose ONE value that best matches the NLQ intent, but it MUST be one of publication_types above (e.g., "Randomized Controlled Trial" or "Comparative Study").
- languages: Choose ONLY from languages above (subset).
- year_min: integer, from NLQ if stated else 2015.

IMPORTANT:
- You MUST select designs_preference from KB publication_types (no new strings).
- Keep all lists between 2 and 10 items when possible.
- Return only:

BEGIN_JSON
{{
  "population_terms": [],
  "intervention_terms": [],
  "comparators_terms": [],
  "outcomes_terms": [],
  "must_have": [],
  "avoid": [],
  "designs_preference": "",
  "languages": [],
  "year_min": 2015
}}
END_JSON
"""

SCREENER_TEMPLATE = {
  "pmid":"", "checklist":{"P":False,"I":False,"O":False,"D":False},
  "decision":"include|exclude", "why":"", "mesh_roles":[]
}

def screener_system() -> str:
    return ("You are a STRICT effects screener. Decide ONLY on definitive INCLUDEs using a checklist. "
            "If ANY checklist item is N=false, decision MUST be 'exclude'. Return JSON only.")

def screener_user(protocol: Dict[str,Any], record: Dict[str,Any]) -> str:
    pop = ", ".join(protocol.get("population_terms",[]))
    itv = ", ".join(protocol.get("intervention_terms",[]))
    outs= ", ".join(protocol.get("outcomes_terms",[]))
    design = protocol.get("designs_preference","")
    return f"""Protocol (STRICT):
- Population (must mention): {pop}
- Intervention (must mention): {itv}
- Outcomes (must mention at least one): {outs}
- Design (must match): {design}

Record:
PMID: {record['pmid']}
Title: {record['title']}
Abstract: {record['abstract']}
Year: {record['year']}
Language: {record['language']}
PubTypes: {record['publication_types']}
MeSH: {record['mesh']}

CHECKLIST:
- P: Does Title/Abstract clearly indicate the specified Population/setting? (Y/N)
- I: Does it clearly study the specified Intervention (not just mention)? (Y/N)
- O: Is at least one specified Outcome measured/reported? (Y/N)
- D: Is the Design compatible with the required design above (same or stricter)? (Y/N)

DECISION RULE: Only if P=Y and I=Y and O=Y and D=Y -> decision = "include". Else "exclude".
If INCLUDE, also categorize each MeSH term as P/I/C/O/G/X briefly.

Return ONLY:

BEGIN_JSON
{{
  "pmid":"{record['pmid']}",
  "checklist":{{"P":false,"I":false,"O":false,"D":false}},
  "decision":"include|exclude",
  "why":"",
  "mesh_roles":[{{"mesh":"", "role":"P|I|C|O|G|X"}}]
}}
END_JSON
"""

PLAUS_SYS = "You are a senior validator. You spot-check plausibility against the core topic. Return JSON only."
def plaus_user(protocol: Dict[str,Any], record: Dict[str,Any]) -> str:
    core = f"Pectus excavatum MIRPE/Nuss + intercostal nerve cryoablation."
    return f"""A junior screener marked this as INCLUDE. Perform a quick sanity check.

Core topic required: {core}

Record:
PMID: {record['pmid']}
Title: {record['title']}
Abstract: {record['abstract']}

Answer with PASS if the core topic obviously fits, else FAIL.

BEGIN_JSON
{{"pmid":"{record['pmid']}", "plausibility":"PASS|FAIL", "note":""}}
END_JSON
"""

REMEDIATION_SCOPE_SYS = "You are a retrieval surgeon. Fix scope issues concisely."
def remediation_scope_user(query: str, count: int, protocol: Dict[str,Any], window: Tuple[int,int]) -> str:
    return f"""Universe query is out-of-window {window} with count={count}.
Protocol anchors: must_have={protocol.get('must_have',[])}
Population terms: {protocol.get('population_terms',[])}
Intervention terms: {protocol.get('intervention_terms',[])}

TASK:
- If too NARROW: suggest broader synonyms to ADD (P_add[], I_add[]) and up to 2 anchors to KEEP.
- If too BROAD: suggest up to 3 anchors to ENFORCE (A_enforce[]).
Keep total suggestions <= 8 terms. Use terse surface forms.

Return:

BEGIN_JSON
{{"P_add":[], "I_add":[], "A_enforce":[]}}
END_JSON
"""

REMEDIATION_STRAT_SYS = "You are a strategy fixer. Choose ONE edit operation. Return JSON only."
def remediation_strat_user(filters: Dict[str,str], stats: Dict[str,Any]) -> str:
    return f"""Strategy failure snapshot:
filters = {filters}
stats = {stats}

Choose one fix:
- DROP_TERM: remove one noisy topic term (provide exact term)
- ADD_ANCHOR: add one anchor term (P or I term) as TIAB anchor
- BROADEN_DESIGN_FILTER: switch to a less restrictive design filter (e.g., AnyTrial or Comparative)

Return:

BEGIN_JSON
{{"op":"DROP_TERM|ADD_ANCHOR|BROADEN_DESIGN_FILTER","term":""}}
END_JSON
"""

EMBED_SYS = "You distill a validated protocol into a single, ~200-char research question string suitable for dense embedding."
def embed_user(protocol: Dict[str,Any], vernac: Dict[str,Any]) -> str:
    return f"""Protocol:
{json.dumps(protocol, ensure_ascii=False)}

Key vocabulary (MeSH/lexical):
{json.dumps(vernac, ensure_ascii=False)}

Return ONLY one compact sentence (~200 chars) capturing the question with the most critical tokens (but readable)."""

# ----------------------------
# State Machine
# ----------------------------
@dataclass
class EngineState:
    kb: Dict[str,Any]
    report_lines: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)
    locked_protocol: Dict[str,Any] = field(default_factory=dict)
    universe_query: str = ""
    universe_count: int = 0
    ground_truth: List[Dict[str,Any]] = field(default_factory=list) # records of INCLUDE (after plausibility)
    mesh_vernac: Dict[str,List[str]] = field(default_factory=dict)
    recommended_filters: Dict[str,str] = field(default_factory=dict)
    final_strategy_count: int = 0
    include_pmids: List[str] = field(default_factory=list)
    plaus_fails: List[str] = field(default_factory=list)

    def log(self, msg:str):
        print(msg)
        self.report_lines.append(msg)

# ---- State 1: Protocol Lockdown
def state1_protocol_lockdown(st: EngineState, nlq: str):
    st.log(f"{time.strftime('%H:%M:%S')}  [S1] Protocol lockdown...")
    template = PROTOCOL_TEMPLATE
    js = ask_json(QWEN_MODEL, protocol_system(st.kb), protocol_user(nlq, st.kb), template)
    # Sanitize: enforce designs_preference is in KB
    if js.get("designs_preference") not in st.kb["publication_types"]:
        # fallback: choose RCT if present else Comparative
        js["designs_preference"] = "Randomized Controlled Trial" if "Randomized Controlled Trial" in st.kb["publication_types"] else st.kb["publication_types"][0]
    # Enforce languages subset
    js["languages"] = [l for l in js.get("languages",[]) if l in st.kb["languages"]] or st.kb["languages"]
    st.locked_protocol = js
    st.log("  [S1] Locked protocol:")
    st.log("    " + json.dumps(st.locked_protocol, ensure_ascii=False))

# ---- State 2: Universe Definition & Sizing
def apply_remediation_to_universe(protocol: Dict[str,Any], query: str, sugg: Dict[str,Any]) -> str:
    P = list(protocol.get("population_terms",[]))
    I = list(protocol.get("intervention_terms",[]))
    if sugg.get("P_add"):
        P = (P + [t for t in sugg["P_add"] if t])[:12]
    if sugg.get("I_add"):
        I = (I + [t for t in sugg["I_add"] if t])[:12]
    anchors = sugg.get("A_enforce", protocol.get("must_have", []))
    return build_universe_query(P, I, anchors)

def state2_universe(st: EngineState):
    st.log(f"{time.strftime('%H:%M:%S')}  [S2] Universe definition & sizing...")
    P = st.locked_protocol.get("population_terms",[])
    I = st.locked_protocol.get("intervention_terms",[])
    A = st.locked_protocol.get("must_have",[])
    q = build_universe_query(P, I, A)
    wmin, wmax = st.kb.get("universe_window",{}).get("min", UNIVERSE_TARGET[0]), st.kb.get("universe_window",{}).get("max", UNIVERSE_TARGET[1])

    tries=0; max_tries=2
    while True:
        cnt, _ids = esearch_ids(q, retmax=TOP_FETCH)
        st.log(f"   [Universe] try={tries} count={cnt} window=({wmin}, {wmax}) query={q}")
        if wmin <= cnt <= wmax:
            st.universe_query = q; st.universe_count = cnt
            break
        if tries >= max_tries:
            st.warnings.append(f"Universe out of window after {tries+1} tries (count={cnt}). Proceeding cautiously.")
            st.universe_query = q; st.universe_count = cnt
            break
        # Remediate via LLM
        sugg = ask_json(QWEN_MODEL, REMEDIATION_SCOPE_SYS, remediation_scope_user(q, cnt, st.locked_protocol, (wmin,wmax)), {"P_add":[],"I_add":[],"A_enforce":[]})
        q = apply_remediation_to_universe(st.locked_protocol, q, sugg)
        tries += 1

    if st.universe_count < UNIVERSE_HARD_MIN:
        st.log("   [Universe] HARD FAIL: universe too small -> terminate.")
        raise SystemExit("Universe too small; refine NLQ or protocol.")

# ---- State 3: Ground Truth Discovery (strict screener)
def state3_ground_truth(st: EngineState):
    st.log(f"{time.strftime('%H:%M:%S')}  [S3] Ground-truth discovery & protocol validation...")
    cnt, ids = esearch_ids(st.universe_query, retmax=TOP_FETCH)
    recs = parse_pubmed_xml(efetch_xml(ids))
    includes = []
    for r in recs[:SCREEN_SAMPLE_MAX]:
        js = ask_json(SCREENER_MODEL, screener_system(), screener_user(st.locked_protocol, r), SCREENER_TEMPLATE)
        # Print each screened record title + abstract (compact)
        title = (r['title'] or "").strip()
        abstract = (r['abstract'] or "").strip()
        st.log(f"     [Screen] PMID {r['pmid']} -> decision={js.get('decision')} checklist={js.get('checklist')} why={js.get('why','')}")
        st.log("       Title: " + title[:240])
        if abstract:
            st.log("       Abstract: " + abstract[:500].replace("\n"," ") + ("..." if len(abstract) > 500 else ""))
        if js.get("decision","") == "include":
            includes.append(r)

    if len(includes) < GROUND_TRUTH_MIN:
        st.log(f"   [S3] FAIL: Found {len(includes)} INCLUDE(s) (< {GROUND_TRUTH_MIN}). Terminate.")
        raise SystemExit("Insufficient ground truth; refine NLQ or universe scope.")

    st.ground_truth = includes

# ---- State 3.5: Senior plausibility guard
def state35_plausibility(st: EngineState):
    st.log(f"{time.strftime('%H:%M:%S')}  [S3.5] Senior plausibility check...")
    vetted=[]
    fails=[]
    for r in st.ground_truth:
        js = ask_json(QWEN_MODEL, PLAUS_SYS, plaus_user(st.locked_protocol, r), {"pmid":"","plausibility":"","note":""})
        ok = (js.get("plausibility","").upper()=="PASS")
        st.log(f"     [Plausibility] PMID {r['pmid']} -> {js.get('plausibility')} note={js.get('note','')}")
        if ok: vetted.append(r)
        else:  fails.append(r['pmid'])
    if len(vetted) < GROUND_TRUTH_MIN:
        st.log(f"   [S3.5] FAIL: Only {len(vetted)} plausible INCLUDE(s) after senior check (< {GROUND_TRUTH_MIN}). Terminate.")
        raise SystemExit("Ground truth implausible; refine NLQ or scope.")
    st.ground_truth = vetted
    st.include_pmids = [r['pmid'] for r in vetted]
    st.plaus_fails = fails

    # Build MeSH vernaculum from vetted records
    mr = defaultdict(int)
    for r in vetted:
        for m in (r.get("mesh") or []):
            mr[m] += 1
    # Heuristic split using surface heuristics + protocol anchors
    P_lex = st.locked_protocol.get("population_terms",[])
    I_lex = st.locked_protocol.get("intervention_terms",[])
    vernac = {
        "P_mesh": [],
        "I_mesh": [],
        "C_mesh": [],
        "O_mesh": [],
        "fallback_P": P_lex,
        "fallback_I": I_lex
    }
    # Assign roughly by keyword cues (lightweight; screener already has roles if needed via its mesh_roles)
    for m,_n in sorted(mr.items(), key=lambda kv: -kv[1])[:30]:
        ml=m.lower()
        if "funnel chest" in ml or "pectus" in ml or "minimally invasive" in ml:
            vernac["P_mesh"].append(m)
        elif "cryosurg" in ml or "cryotherapy" in ml or "nerve block" in ml or "analgesia" in ml:
            vernac["I_mesh"].append(m)
        elif "length of stay" in ml or "pain" in ml or "quality of life" in ml or "respiratory" in ml:
            vernac["O_mesh"].append(m)
        elif "comparative" in ml or "randomized" in ml or "prospective" in ml:
            vernac["C_mesh"].append(m)
        elif "intercostal" in ml:
            vernac["I_mesh"].append(m)
        elif "thoracic" in ml:
            vernac["P_mesh"].append(m)
    st.mesh_vernac = vernac

# ---- State 4: Strategy validation & refinement
def make_topic_filter(vernac: Dict[str,List[str]]) -> str:
    def mesh_block(arr):
        toks=[]
        for m in arr:
            m=m.strip()
            if not m: continue
            toks.append(f"\"{m}\"[MeSH Terms]")
        return "(" + " OR ".join(toks) + ")" if toks else ""
    blocks=[]
    for key in ("P_mesh","I_mesh","O_mesh"):
        b = mesh_block(vernac.get(key,[]))
        if b: blocks.append(b)
    if not blocks:
        return ""
    return " AND ".join(blocks)

def pick_design_filter(st: EngineState) -> str:
    pref = st.locked_protocol.get("designs_preference", "Randomized Controlled Trial")
    # choose the tightest matching filter from KB design_filters
    dfs = st.kb.get("design_filters", DEFAULT_KB["design_filters"])
    if "Randomized Controlled Trial" in pref and "RCT" in dfs:
        return dfs["RCT"]
    if "Comparative" in pref and "Comparative" in dfs:
        return dfs["Comparative"]
    return dfs.get("AnyTrial", dfs["RCT"])

def combine_strategy(universe_q: str, topic_filter: str, design_filter: str, langs: List[str]) -> str:
    q = f"({universe_q})"
    if topic_filter:
        q += f" AND ({topic_filter})"
    if design_filter:
        q += f" AND ({design_filter})"
    if langs:
        lang_or = " OR ".join(f"\"{l}\"[lang]" for l in langs)
        q += f" AND ({lang_or})"
    return q

def recall_okay(query: str, must_have_pmids: List[str]) -> Tuple[bool, int, List[str]]:
    # ensure all gt pmids are in current results
    total, ids = esearch_ids(query, retmax=5000)
    idset = set(ids)
    all_in = all(p in idset for p in must_have_pmids)
    return all_in, total, ids

def state4_strategy(st: EngineState):
    st.log(f"{time.strftime('%H:%M:%S')}  [S4] Strategy validation & refinement...")
    topic = make_topic_filter(st.mesh_vernac)
    if not topic:
        st.warnings.append("No reliable MeSH topic filter; falling back to lexical anchors only.")
    design = pick_design_filter(st)
    langs  = st.locked_protocol.get("languages", st.kb["languages"])

    filters = {"topic":topic, "design":design, "languages": langs}
    tries=0; max_tries=3
    while True:
        strat = combine_strategy(st.universe_query, filters["topic"], filters["design"], filters["languages"])
        ok, total, ids = recall_okay(strat, st.include_pmids)
        st.log(f"   [Strategy] try={tries} total={total} recall_ok={ok}")
        st.final_strategy_count = total
        if ok:
            # precision window check
            pmin = st.kb.get("precision_window",{}).get("min", 5)
            pmax = st.kb.get("precision_window",{}).get("max", 2000)
            if pmin <= total <= pmax:
                st.recommended_filters = {"topic":filters["topic"], "design":filters["design"], "languages":filters["languages"]}
                st.log("   [Strategy] PASSED recall & precision checks.")
                break
        if tries >= max_tries:
            st.log("   [Strategy] FAIL after remediation attempts. Terminate.")
            raise SystemExit("Strategy cannot satisfy recall/precision simultaneously.")
        # Remediate via LLM
        stats = {"total":total,"recall_ok":ok,"precision_min":pmin,"precision_max":pmax}
        edit = ask_json(QWEN_MODEL, REMEDIATION_STRAT_SYS, remediation_strat_user(filters, stats), {"op":"","term":""})
        op = (edit.get("op") or "").upper()
        term = edit.get("term","").strip()
        if op == "DROP_TERM" and term and filters["topic"]:
            # try to drop a MeSH term by name (loose remove)
            new_topic = re.sub(rf'\s*"?{re.escape(term)}"?\[MeSH Terms\]\s*(OR)?','', filters["topic"])
            new_topic = re.sub(r'\(\s*OR\s*\)','', new_topic).replace("  "," ").strip()
            filters["topic"] = new_topic
        elif op == "ADD_ANCHOR" and term:
            # add as TIAB anchor inside UniverseQuery (soft)
            st.universe_query = f"({st.universe_query}) AND ({or_block([term],'tiab')})"
        elif op == "BROADEN_DESIGN_FILTER":
            # switch to less restrictive if not already
            dfs = st.kb.get("design_filters", DEFAULT_KB["design_filters"])
            filters["design"] = dfs.get("Comparative", dfs.get("AnyTrial", filters["design"]))
        else:
            st.warnings.append(f"Unrecognized remediation op or term missing: {edit}")
        tries += 1

# ---- State 5: Finalization & Handoff
def state5_finalize(st: EngineState, nlq: str):
    st.log(f"{time.strftime('%H:%M:%S')}  [S5] Finalization & handoff...")

    # Compose embedding string
    embed_str = _lm_call(QWEN_MODEL, [
        {"role":"system","content":EMBED_SYS},
        {"role":"user","content": embed_user(st.locked_protocol, st.mesh_vernac)}
    ])

    # Human-readable report
    rep = []
    rep.append("==================== SNIFF REPORT ====================")
    rep.append("NLQ:\n  " + textwrap.shorten(nlq, width=240, placeholder="..."))
    rep.append("\nLocked Protocol:\n  " + json.dumps(st.locked_protocol, ensure_ascii=False))
    rep.append(f"\nUniverse:\n  query: {st.universe_query}\n  count: {st.universe_count}")
    rep.append(f"\nGround truth (plausible INCLUDEs >= {GROUND_TRUTH_MIN}): {len(st.ground_truth)}")
    for r in st.ground_truth[:10]:
        rep.append(f"  • [{r['pmid']}] {r['year']} {r['language']} | {textwrap.shorten(r['title'] or '', 140)}")
    if st.plaus_fails:
        rep.append(f"\nPlausibility FAILed PMIDs (excluded): {', '.join(st.plaus_fails)}")
    rep.append("\nMeSH vernaculum:\n  " + json.dumps(st.mesh_vernac, ensure_ascii=False))
    rep.append("\nValidated Filters:")
    rep.append("  topic: " + (st.recommended_filters.get("topic") or "(none)"))
    rep.append("  design: " + (st.recommended_filters.get("design") or "(none)"))
    rep.append("  languages: " + ", ".join(st.recommended_filters.get("languages",[])))
    rep.append(f"\nFinal strategy total: {st.final_strategy_count}")
    if st.warnings:
        rep.append("\nWARNINGS:")
        for w in st.warnings:
            rep.append("  - " + w)
    rep.append("\nEmbed string:\n  " + textwrap.shorten(embed_str.strip(), 300))
    rep.append("\n================== END OF REPORT =====================")

    report_text = "\n".join(st.report_lines) + "\n\n" + "\n".join(rep)

    # Write files
    (OUT_DIR/"sniff_report.txt").write_text(report_text, encoding="utf-8")
    artifacts = {
        "locked_protocol": st.locked_protocol,
        "universe_query": st.universe_query,
        "universe_count": st.universe_count,
        "ground_truth_pmids": [r["pmid"] for r in st.ground_truth],
        "mesh_vernaculum": st.mesh_vernac,
        "recommended_filters": st.recommended_filters,
        "final_strategy_total": st.final_strategy_count,
        "warnings": st.warnings,
        "research_question_string_for_embedding": embed_str.strip()
    }
    (OUT_DIR/"sniff_artifacts.json").write_text(json.dumps(artifacts, indent=2, ensure_ascii=False), encoding="utf-8")

    # quick TSV for ground truth
    lines = ["pmid\tyear\tlanguage\ttitle"]
    for r in st.ground_truth:
        lines.append(f"{r['pmid']}\t{r.get('year','')}\t{r.get('language','')}\t{(r.get('title') or '').replace('\t',' ')}")
    (OUT_DIR/"ground_truth.tsv").write_text("\n".join(lines), encoding="utf-8")

    # Print concise tail summary to notebook/stdout
    print("\n==================== FINAL SUMMARY ====================")
    print(f"Universe: count={st.universe_count}")
    print(f"Ground truth: n={len(st.ground_truth)} pmids={', '.join([r['pmid'] for r in st.ground_truth[:8]])}{' ...' if len(st.ground_truth)>8 else ''}")
    print("Validated filters:")
    print("  topic:   " + (st.recommended_filters.get("topic") or "(none)"))
    print("  design:  " + (st.recommended_filters.get("design") or "(none)"))
    print("  lang:    " + ", ".join(st.recommended_filters.get("languages",[])))
    print("Final total:", st.final_strategy_count)
    if st.warnings:
        print("Warnings:")
        for w in st.warnings: print(" -", w)
    print("Artifacts:", OUT_DIR)

# ----------------------------
# Top-level runner
# ----------------------------
def sniff_validate_engine(USER_NLQ: str, kb_path="system_knowledge_base.json"):
    st = EngineState(kb=load_kb(kb_path))
    try:
        state1_protocol_lockdown(st, USER_NLQ)
        state2_universe(st)
        state3_ground_truth(st)
        state35_plausibility(st)
        state4_strategy(st)
        state5_finalize(st, USER_NLQ)
    except SystemExit as e:
        # Persist partial report/warnings for debugging
        st.warnings.append(f"TERMINATED: {str(e)}")
        (OUT_DIR/"sniff_report.txt").write_text("\n".join(st.report_lines) + f"\n\nTERMINATED: {str(e)}", encoding="utf-8")
        raise

# ----------------------------
# RUN: put your NLQ here
# ----------------------------
USER_NLQ = """Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE).
Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE
(the intervention of interest is INC, not the surgery).
Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block,
or systemic multimodal analgesia.
Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days.
Study designs = RCTs preferred; if RCTs absent, include comparative cohort/case-control.
Year_min = 2015. Languages = English, Portuguese, Spanish.
"""

if __name__ == "__main__":
    sniff_validate_engine(USER_NLQ)


16:30:08  [S1] Protocol lockdown...
  [S1] Locked protocol:
    {"population_terms": ["adults", "minimally invasive repair of pectus excavatum", "Nuss procedure", "pectus excavatum surgery"], "intervention_terms": ["intercostal nerve cryoablation", "INC", "intraoperative analgesia"], "comparators_terms": ["thoracic epidural", "paravertebral block", "intercostal nerve block", "erector spinae plane block", "systemic multimodal analgesia"], "outcomes_terms": ["postoperative opioid consumption", "in-hospital opioid use", "discharge opioid use", "pain scores", "0-7 day pain assessment"], "must_have": ["pectus excavatum surgery", "minimally invasive repair", "postoperative pain management", "analgesia techniques"], "avoid": ["cardiac surgery", "chest wall deformity"], "designs_preference": "Randomized Controlled Trial", "languages": ["english", "portuguese", "spanish"], "year_min": 2015}
16:30:35  [S2] Universe definition & sizing...
   [Universe] try=0 count=27 window=(50, 10000) query=(adu

HTTPError: 404 Client Error: Not Found for url: http://127.0.0.1:1234/v1/chat/completions

In [37]:
# Sniff Validation Engine v3.1
# Fixes over v3.0:
# - Protocol terms MUST be 1–3 words; hard guardrails for MIRPE/Nuss/pectus + cryoablation
# - Remediation ops: ADD_POP, ADD_INT, ADD_ANCHOR, SIMPLIFY_TERM, REMOVE_TERM
# - True PICO-weighted TF-IDF: sum(pico_weight * tfidf(term, doc))
# - MeSH role classification via LLM using the article's actual MeSH list only
# - Ask-Validate-Retry writes invalid outputs to sniff_out/llm_debug_*.json
# - Model eviction: TTL + 'lms unload --all' + 10s wait

import os, re, json, time, math, subprocess, shutil, random, pathlib, textwrap
from typing import List, Dict, Any, Tuple
from collections import Counter, defaultdict

import requests
from xml.etree import ElementTree as ET

# -----------------------------
# Config & Paths
# -----------------------------
LMSTUDIO_BASE = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234").rstrip("/")
QWEN_MODEL     = os.getenv("QWEN_MODEL", "unsloth/qwen3-4b")
SCREENER_MODEL = os.getenv("SCREENER_MODEL", "gemma-3n-e4b-it@q5_k_m")   # LM Studio display id
ENTREZ_EMAIL   = os.getenv("ENTREZ_EMAIL", "you@example.com")
ENTREZ_API_KEY = os.getenv("ENTREZ_API_KEY", "")
HTTP_TIMEOUT   = int(os.getenv("HTTP_TIMEOUT", "120"))
LM_TTL_SECONDS = int(os.getenv("LM_TTL_SECONDS", "5"))
LM_SWITCH_WAIT_SECONDS = int(os.getenv("LM_SWITCH_WAIT_SECONDS", "10"))

WORKDIR = pathlib.Path(".").resolve()
OUTDIR = WORKDIR / "sniff_out"
OUTDIR.mkdir(parents=True, exist_ok=True)

KB_PATH = WORKDIR / "system_knowledge_base.json"
REPORT_TXT = OUTDIR / "sniff_report.txt"
ARTIFACTS_JSON = OUTDIR / "sniff_artifacts.json"

# Universe sizing thresholds
UNIVERSE_MIN = 50
UNIVERSE_MAX = 10000
UNIVERSE_HARD_MIN = 25  # below this, terminate

# Rerank corpus/sample sizes
RERANK_FETCH_N = 600   # efetch corpus size for reranker
GROUND_TOP_N   = 50    # how many to screen after rerank

# Ground truth requirements
MIN_GROUND_TRUTH = 3

# Strategy validation
PRECISION_WINDOW = (10, 5000)
REMEDIATION_MAX_TRIES = 3

# -----------------------------
# Utilities: LM Studio management
# -----------------------------
def lmstudio_unload_all_safely():
    try:
        exe = shutil.which("lms") or shutil.which("lms.exe")
        if exe:
            subprocess.run([exe, "unload", "--all"], check=False,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception:
        pass

class ModelManager:
    def __init__(self):
        self.current = None
    def ensure_model(self, model_id: str):
        if self.current == model_id:
            return
        lmstudio_unload_all_safely()
        time.sleep(LM_SWITCH_WAIT_SECONDS)
        self.current = model_id

LM = ModelManager()

# -----------------------------
# HTTP / LLM JSON helpers
# -----------------------------
def lm_chat(model: str, system: str, user: str,
            temperature: float = 0.0,
            max_tokens: int = 2048,
            stop: List[str] | None = None) -> str:
    LM.ensure_model(model)
    url = f"{LMSTUDIO_BASE}/v1/chat/completions"
    body = {
        "model": model,
        "messages": [{"role":"system","content":system},{"role":"user","content":user}],
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
        "stream": False,
        "ttl": LM_TTL_SECONDS  # best-effort TTL eviction
    }
    if stop:
        body["stop"] = stop
    r = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]

FENCE = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.I)
BEGIN = re.compile(r"BEGIN_JSON\s*", re.I)
END   = re.compile(r"\s*END_JSON", re.I)

def _sanitize_json_str(s: str) -> str:
    s = (s or "").replace("\u201c", '"').replace("\u201d", '"').replace("\u2018","'").replace("\u2019","'")
    s = re.sub(r",\s*(\}|\])", r"\1", s)
    return s.strip()

def extract_json(txt: str) -> str:
    blocks=[]; pos=0
    while True:
        m1 = BEGIN.search(txt or "", pos)
        if not m1: break
        m2 = END.search(txt, m1.end())
        if not m2: break
        blocks.append(txt[m1.end():m2.start()])
        pos = m2.end()
    if blocks:
        return _sanitize_json_str(blocks[-1])
    fences = FENCE.findall(txt or "")
    if fences:
        return _sanitize_json_str(fences[-1])
    # fallback: last balanced {..}
    s = txt or ""
    last,stack,start=None,0,None
    for i,ch in enumerate(s):
        if ch=="{":
            if stack==0: start=i
            stack+=1
        elif ch=="}" and stack>0:
            stack-=1
            if stack==0 and start is not None:
                last = s[start:i+1]
    if last: return _sanitize_json_str(last)
    raise ValueError("no JSON found")

def dump_invalid(stage: str, content: str, idx: int):
    (OUTDIR / f"llm_debug_{stage}_{idx}.txt").write_text(content, encoding="utf-8")

def get_validated_json(model: str, system: str, user_base: str,
                       validator, template_hint: str = "",
                       max_tries: int = 3, temperature: float = 0.0, max_tokens: int = 2048,
                       dbg_stage: str = "generic") -> dict:
    history_user = user_base
    for i in range(max_tries):
        raw = lm_chat(model, system, history_user + "\n\nReturn ONLY:\nBEGIN_JSON\n{...}\nEND_JSON",
                      temperature=temperature, max_tokens=max_tokens, stop=["END_JSON"])
        try:
            js = json.loads(extract_json(raw))
        except Exception as e:
            dump_invalid(f"{dbg_stage}_parse", raw, i)
            history_user = user_base + f"\n\nYour previous JSON was invalid: {e}\n{template_hint}\nUse ONLY true/false booleans; keys exactly as specified."
            continue
        try:
            ok, why = validator(js)
        except Exception as e:
            ok, why = False, f"Validation exception: {e}"
        if ok:
            return js
        dump_invalid(f"{dbg_stage}_schema", json.dumps(js,indent=2), i)
        history_user = user_base + f"\n\nYour previous JSON failed validation: {why}\n{template_hint}\nCorrect and resubmit the SAME schema."
    raise SystemExit("Fatal: LLM failed to produce valid JSON after retries.")

# -----------------------------
# Knowledge base
# -----------------------------
DEFAULT_KB = {
  "publication_types": [
    "Randomized Controlled Trial","Clinical Trial","Controlled Clinical Trial",
    "Comparative Study","Prospective Studies","Cohort Studies","Case-Control Studies",
    "Systematic Review","Meta-Analysis","Network Meta-Analysis"
  ],
  "languages": ["english","portuguese","spanish","french","german","italian"]
}

def ensure_kb() -> dict:
    if KB_PATH.exists():
        try: return json.loads(KB_PATH.read_text(encoding="utf-8"))
        except Exception: pass
    KB_PATH.write_text(json.dumps(DEFAULT_KB, indent=2), encoding="utf-8")
    print(f"[KB] No KB found. Wrote defaults to {KB_PATH}")
    return DEFAULT_KB

KB = ensure_kb()

# -----------------------------
# PubMed E-utilities
# -----------------------------
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"

def esearch(term: str, mindate: int|None=None, retmax=200, usehistory=True) -> dict:
    params = {"db":"pubmed","retmode":"json","term":term,"retmax":retmax,"email":ENTREZ_EMAIL}
    if usehistory: params["usehistory"]="y"
    if mindate: params["mindate"]=str(mindate)
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/esearch.fcgi", params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.json()["esearchresult"]

def esearch_all_ids(term: str, mindate: int|None=None, cap:int=10000) -> List[str]:
    js = esearch(term, mindate=mindate, retmax=0, usehistory=True)
    count = int(js.get("count","0"))
    if count == 0: return []
    webenv = js["webenv"]; qk = js["querykey"]
    ids=[]
    retstart=0
    while retstart < min(count,cap):
        chunk = min(5000, cap-retstart)
        r = requests.get(f"{EUTILS}/esearch.fcgi", params={
            "db":"pubmed","retmode":"json","retstart":retstart,"retmax":chunk,
            "WebEnv":webenv,"query_key":qk,"email":ENTREZ_EMAIL,
            **({"api_key":ENTREZ_API_KEY} if ENTREZ_API_KEY else {})
        }, timeout=HTTP_TIMEOUT)
        r.raise_for_status()
        ids.extend(r.json()["esearchresult"].get("idlist",[]))
        retstart += chunk
        time.sleep(0.34)
    return ids

def efetch_xml(ids: List[str]) -> str:
    if not ids: return ""
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(ids),"email":ENTREZ_EMAIL}
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/efetch.fcgi", params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.text

def parse_pubmed_xml(xml_text: str) -> List[dict]:
    out=[]
    if not xml_text.strip(): return out
    root = ET.fromstring(xml_text)
    def _join(node):
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = art.findtext(".//PMID") or ""
        title = _join(art.find(".//ArticleTitle")).strip()
        abs_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(_join(n).strip() for n in abs_nodes) if abs_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\d{4}", s)
                if m: year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh = [mh.findtext("./DescriptorName") for mh in art.findall(".//MeshHeadingList/MeshHeading") if mh.findtext("./DescriptorName")]
        out.append({"pmid":pmid,"title":title,"abstract":abstract,"year":year,"language":lang,"publication_types":pubtypes,"mesh":mesh})
    return out

# -----------------------------
# Query building
# -----------------------------
def or_block(terms: List[str], field="tiab") -> str:
    toks=[]
    for t in terms or []:
        t=t.strip()
        if not t: continue
        if " " in t or "-" in t:
            toks.append(f"\"{t}\"[{field}]")
        else:
            toks.append(f"{t}[{field}]")
    return "("+" OR ".join(toks)+")" if toks else ""

def lang_filter(langs: List[str]) -> str:
    toks=[f"\"{L}\"[Language]" for L in langs or []]
    return "("+" OR ".join(toks)+")" if toks else ""

def build_universe_query(protocol: dict) -> str:
    P = or_block(protocol.get("population_terms"), "tiab")
    I = or_block(protocol.get("intervention_terms"), "tiab")
    A = or_block(protocol.get("must_have"), "tiab")
    parts=[p for p in [P,I,A] if p]
    q = " AND ".join(parts) if parts else ""
    if protocol.get("languages"):
        q = f"({q}) AND {lang_filter(protocol['languages'])}" if q else lang_filter(protocol['languages'])
    return q

# -----------------------------
# Validators
# -----------------------------
def _is_strlist(x): return isinstance(x, list) and all(isinstance(t,str) and t.strip() for t in x)
def _max_words_ok(lst, max_words=3):
    for t in lst or []:
        if len(t.strip().split()) > max_words:
            return False, t
    return True, ""

def validate_protocol(js: dict) -> Tuple[bool,str]:
    req = ["population_terms","intervention_terms","comparators_terms","outcomes_terms","must_have","avoid","designs_preference","languages","year_min"]
    for k in req:
        if k not in js: return False, f"missing key {k}"
    if not _is_strlist(js["population_terms"]): return False,"population_terms must be list[str]"
    if not _is_strlist(js["intervention_terms"]): return False,"intervention_terms must be list[str]"
    if not isinstance(js["comparators_terms"], list): return False,"comparators_terms list"
    if not _is_strlist(js["outcomes_terms"]): return False,"outcomes_terms list[str]"
    if not isinstance(js["must_have"], list): return False,"must_have list[str]"
    if not isinstance(js["avoid"], list): return False,"avoid list[str]"
    if js["designs_preference"] not in KB["publication_types"]: return False,"designs_preference must be in KB.publication_types"
    if not set(js["languages"]).issubset(set(KB["languages"])): return False,"languages must be subset of KB.languages"
    if not isinstance(js["year_min"], int): return False,"year_min int"
    # length constraints: each term <= 3 words
    for key in ["population_terms","intervention_terms","must_have"]:
        ok, bad = _max_words_ok(js[key], 3)
        if not ok: return False, f"{key} contains overly long phrase: '{bad}'. Use <=3-word tokens."
    # anchor constraints
    P_join = " ".join(js["population_terms"]).lower()
    if not any(a in P_join for a in ["pectus","nuss","mirpe"]):
        return False, "population_terms must include MIRPE/Nuss/pectus anchors"
    I_join = " ".join(js["intervention_terms"]).lower()
    if not any("cryoablat" in I_join or "cryoanalg" in I_join or "inc" == t.lower() for t in js["intervention_terms"]):
        return False, "intervention_terms must include cryoablation/cryoanalgesia/INC"
    # adult-only guard
    if any("pediatric" in t.lower() for t in js["population_terms"] + js["avoid"]):
        # allowed in avoid, forbidden in population_terms
        if any("pediatric" in t.lower() for t in js["population_terms"]):
            return False,"population_terms must not include pediatric"
    return True, "ok"

def validate_screener_output(js: dict) -> Tuple[bool,str]:
    if "pmid" not in js or not isinstance(js["pmid"], str): return False,"pmid missing"
    chk = js.get("checklist",{})
    for k in ["P","I","O","D"]:
        if not isinstance(chk.get(k), bool): return False, f"checklist.{k} must be bool"
    if js.get("decision") not in ["INCLUDE","EXCLUDE","BORDERLINE"]:
        return False,"decision invalid"
    if "reason" not in js: return False,"reason missing"
    # mesh_roles optional here; roles will be computed separately from actual MeSH list
    return True,"ok"

def validate_mesh_roles(js: dict) -> Tuple[bool,str]:
    if "pmid" not in js or "labels" not in js: return False,"missing keys"
    if not isinstance(js["labels"], list): return False,"labels list"
    for it in js["labels"]:
        if not isinstance(it, dict): return False,"label must be dict"
        if "mesh" not in it or "role" not in it: return False,"mesh+role required"
        if it["role"] not in ["P","I","O","C","G","X"]: return False,"invalid role"
    return True,"ok"

def validate_remediation(js: dict) -> Tuple[bool,str]:
    allowed_ops = {"DROP_TERM","ADD_ANCHOR","ADD_POP","ADD_INT","SIMPLIFY_TERM","REMOVE_TERM","BROADEN_DESIGN_FILTER"}
    if js.get("op") not in allowed_ops: return False,"op invalid"
    return True,"ok"

# -----------------------------
# Prompts
# -----------------------------
S1_SYSTEM = """You are a rigorous protocol compiler. Return a LOCKED PROTOCOL with short, searchable tokens (each 1–3 words)."""

def s1_user(nlq: str, kb: dict) -> str:
    return f"""From the NLQ below, produce a STRICT protocol JSON.

**Hard rules**
- Use short tokens (each 1–3 words). DO NOT output long phrases like "intraoperative intercostal nerve cryoablation for analgesia".
- Population MUST anchor to MIRPE/Nuss/pectus excavatum for ADULTS (pediatric forbidden in population).
- Intervention MUST include intercostal nerve cryoablation / cryoanalgesia / INC (for analgesia).
- designs_preference must be one of: {kb["publication_types"]}
- languages subset of: {kb["languages"]}

**Bad vs Good examples**
Bad: "intraoperative intercostal nerve cryoablation for analgesia"
Good: ["intercostal nerve","cryoablation","analgesia","INC"]

NLQ:
<<<
{nlq}
>>>

Return ONLY:
BEGIN_JSON
{{"population_terms": ["adults","Nuss","MIRPE","pectus excavatum","minimally invasive repair"],
 "intervention_terms": ["intercostal nerve","cryoablation","cryoanalgesia","INC","analgesia"],
 "comparators_terms": ["thoracic epidural","paravertebral block","intercostal nerve block","erector spinae plane block","systemic multimodal analgesia"],
 "outcomes_terms": ["postoperative opioid consumption","pain scores","0-7 day pain","discharge opioid use"],
 "must_have": ["MIRPE","Nuss","pectus excavatum","cryoablation"],
 "avoid": ["pediatric"],
 "designs_preference": "Randomized Controlled Trial",
 "languages": ["english","portuguese","spanish"],
 "year_min": 2015}}
END_JSON"""

S2_REMEDIATION_SYSTEM = "You are a cautious query remediation assistant. You NEVER violate core constraints and you can remove or simplify problematic tokens."
def s2_remediation_user(protocol: dict, query: str, count: int, why: str) -> str:
    constraints = """Core constraints:
- Adults only
- Population must mention MIRPE/Nuss/pectus excavatum
- Intervention must be intercostal nerve cryoablation (INC)/cryoanalgesia for analgesia
- Terms must be 1–3 words only (simplify long ones)"""
    return f"""The universe query is too {why}. Current count={count}.
Query:
{query}

{constraints}

Propose ONE fix as JSON:
- {{"op":"SIMPLIFY_TERM","where":"population|intervention|anchor","term":"<existing long token>","replacement":"<1–3 word token>"}}
- {{"op":"REMOVE_TERM","where":"population|intervention|anchor","term":"<overly-specific token>"}}
- {{"op":"ADD_POP","term":"<short population token>"}}
- {{"op":"ADD_INT","term":"<short intervention token>"}}
- {{"op":"ADD_ANCHOR","term":"<short anchor token>"}}

Return ONLY:
BEGIN_JSON
{{"op":"ADD_ANCHOR","term":"pectus excavatum"}}
END_JSON"""

S3_SCREEN_SYSTEM = "You are a strict senior PRISMA screener. Use ONLY true/false booleans. Return JSON only."
def s3_screen_user(protocol: dict, rec: dict) -> str:
    return f"""Screen against protocol (ALL must be true for INCLUDE):
- P: Title/abstract explicitly mentions pectus excavatum OR Nuss OR MIRPE AND adults.
- I: Title/abstract explicitly mentions intercostal nerve cryoablation (INC) / cryoanalgesia for analgesia (not tumor/derm cryotherapy).
- O: Reports/plans postoperative opioid consumption OR pain scores within 0–7 days.
- D: Primary comparative study (RCT preferred; accept clearly comparative cohort/case-control).

Return:
BEGIN_JSON
{{"pmid":"{rec['pmid']}",
  "checklist": {{"P": false, "I": false, "O": false, "D": false}},
  "decision":"EXCLUDE|BORDERLINE|INCLUDE",
  "reason":"<=1 line why"}}
END_JSON"""

S3_MESH_ROLE_SYSTEM = "You classify MeSH descriptors by role for PICOS."
def s3_mesh_role_user(rec: dict) -> str:
    return f"""Assign roles to THIS RECORD'S MeSH ONLY (do not invent new terms).
Allowed roles: P (population/procedure context), I (intervention/analgesia), O (outcome), C (comparator), G (generic), X (irrelevant).

MeSH descriptors:
{rec.get('mesh',[])}

Return ONLY:
BEGIN_JSON
{{"pmid":"{rec['pmid']}", "labels":[{{"mesh":"Funnel Chest","role":"P"}}]}}
END_JSON"""

S35_PLAUS_SYS = "You are the senior plausibility gate. Be ruthless and concise."
def s35_plaus_user(protocol: dict, rec: dict) -> str:
    return f"""Does the core topic plausibly match the protocol?
Protocol core: adults + (pectus excavatum / MIRPE / Nuss) + intercostal nerve cryoablation for analgesia.

Title: {rec['title']}
Abstract: {rec['abstract'][:1000]}

Return:
BEGIN_JSON
{{"pmid":"{rec['pmid']}", "verdict":"PASS|FAIL", "why":"<=1 line"}} 
END_JSON"""

S4_STRATEGY_SYSTEM = "You are a cautious filter strategist. Choose among allowed options only."
def s4_strategy_user(protocol: dict, mesh_vocab: Dict[str,str]) -> str:
    roles = defaultdict(list)
    for m,r in mesh_vocab.items():
        roles[r].append(m)
    return f"""Build recommended filters based ONLY on included articles' MeSH (exact descriptors).
- topic_filter: OR of P/I/O MeSH terms (use [MeSH Terms]); <= 15 tokens total; prefer P and I anchors.
- design_filter: ONE publication type equal to designs_preference={protocol["designs_preference"]} from {KB["publication_types"]}.

Available MeSH by role:
P={roles.get('P',[])}
I={roles.get('I',[])}
O={roles.get('O',[])}

Return ONLY:
BEGIN_JSON
{{"topic_filter":"(Funnel Chest[MeSH Terms] OR Cryosurgery[MeSH Terms])",
  "design_filter":"Randomized Controlled Trial[Publication Type]"}}
END_JSON"""

S4_REMEDIATE_SYSTEM = "You repair strategies with a single, constrained action."
def s4_remediate_user(snapshot: dict) -> str:
    return f"""Strategy failed.
Snapshot:
{json.dumps(snapshot, ensure_ascii=False)}

Propose ONE:
- {{"op":"DROP_TERM","term":"<MeSH>","where":"topic"}}
- {{"op":"ADD_ANCHOR","term":"<short tiab token>"}}
- {{"op":"BROADEN_DESIGN_FILTER"}}

Return ONLY:
BEGIN_JSON
{{"op":"DROP_TERM","term":"X","where":"topic"}}
END_JSON"""

# -----------------------------
# S1: Protocol lockdown
# -----------------------------
def state1_protocol_lockdown(nlq: str) -> dict:
    print("[S1] Protocol lockdown...")
    proto = get_validated_json(
        QWEN_MODEL, S1_SYSTEM, s1_user(nlq, KB),
        validator=validate_protocol,
        template_hint="Each term must be 1–3 words. Anchors required: MIRPE/Nuss/pectus in Population; cryoablation/cryoanalgesia/INC in Intervention.",
        max_tokens=1536, dbg_stage="protocol"
    )
    # Dedup + lowercase normalize where sensible
    def dedup_keep_order(lst):
        seen=set(); out=[]
        for t in lst:
            k=t.strip()
            if not k or k.lower() in seen: continue
            seen.add(k.lower()); out.append(k)
        return out
    for key in ["population_terms","intervention_terms","comparators_terms","outcomes_terms","must_have","avoid","languages"]:
        proto[key] = dedup_keep_order(proto.get(key,[]))
    print("  [S1] Locked protocol:\n   ", json.dumps(proto, ensure_ascii=False))
    return proto

# -----------------------------
# S2: Universe definition & remediation
# -----------------------------
def state2_universe(protocol: dict) -> Tuple[str, List[str]]:
    print("[S2] Universe definition & sizing...")
    query = build_universe_query(protocol)
    tries = 0
    while True:
        ids = esearch_all_ids(query, mindate=protocol["year_min"], cap=UNIVERSE_MAX+500)
        count = len(ids)
        print(f"   [Universe] try={tries} count={count} window=({UNIVERSE_MIN}, {UNIVERSE_MAX})")
        if UNIVERSE_MIN <= count <= UNIVERSE_MAX:
            return query, ids
        if tries >= 2:
            if count < UNIVERSE_HARD_MIN:
                raise SystemExit(f"Fatal: universe too small ({count}<{UNIVERSE_HARD_MIN}).")
            print(f"   [Universe] Proceeding with suboptimal size={count} and a WARNING.")
            return query, ids
        why = "narrow" if count < UNIVERSE_MIN else "broad"
        fix = get_validated_json(QWEN_MODEL, S2_REMEDIATION_SYSTEM,
                                 s2_remediation_user(protocol, query, count, why),
                                 validator=validate_remediation, max_tokens=512, dbg_stage="universe_remed")
        op = fix["op"]
        if op in ("ADD_ANCHOR","ADD_POP","ADD_INT"):
            term = fix.get("term","").strip()
            if not term: 
                tries += 1; continue
            if op=="ADD_ANCHOR": protocol["must_have"].append(term)
            elif op=="ADD_POP": protocol["population_terms"].append(term)
            elif op=="ADD_INT": protocol["intervention_terms"].append(term)
        elif op=="SIMPLIFY_TERM":
            where = fix.get("where")
            term = fix.get("term","").strip()
            repl = fix.get("replacement","").strip()
            target = {"population": "population_terms", "intervention":"intervention_terms", "anchor":"must_have"}.get(where)
            if target and term and repl and target in protocol:
                lst = protocol[target]
                protocol[target] = [repl if x.strip().lower()==term.lower() else x for x in lst]
        elif op=="REMOVE_TERM":
            where = fix.get("where")
            term = fix.get("term","").strip().lower()
            target = {"population": "population_terms", "intervention":"intervention_terms", "anchor":"must_have"}.get(where)
            if target and term and target in protocol:
                protocol[target] = [x for x in protocol[target] if x.strip().lower()!=term]
        # rebuild
        # also enforce 1–3 word rule post-change (defensive)
        for key in ["population_terms","intervention_terms","must_have"]:
            protocol[key] = [t for t in protocol[key] if len(t.split())<=3]
        query = build_universe_query(protocol)
        tries += 1

# -----------------------------
# S2.5: True PICO-weighted TF-IDF reranker
# -----------------------------
def safe_import_sklearn():
    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        import numpy as np
        return TfidfVectorizer, np
    except Exception:
        return None, None

def state2_5_rerank_universe(query: str, ids: List[str], protocol: dict) -> List[dict]:
    print("[S2.5] Rerank universe with PICO-weighted TF-IDF...")
    fetch_ids = ids[:RERANK_FETCH_N]
    docs=[]
    for i in range(0, len(fetch_ids), 300):
        chunk = fetch_ids[i:i+300]
        xml = efetch_xml(chunk)
        docs.extend(parse_pubmed_xml(xml))
        time.sleep(0.34)
    if not docs: return []

    texts = [(d["title"] or "") + " " + (d["abstract"] or "") for d in docs]
    TfidfVectorizer, np = safe_import_sklearn()
    if TfidfVectorizer is None:
        # fallback: keyword presence heuristic
        def score_text(t):
            low=t.lower(); s=0.0
            for term in protocol["population_terms"] + protocol["intervention_terms"]:
                if term.lower() in low: s += 1.5
            for term in protocol["outcomes_terms"]:
                if term.lower() in low: s += 1.0
            for term in protocol["avoid"]:
                if term.lower() in low: s -= 2.0
            return s
        for d,t in zip(docs, texts):
            d["_score"] = score_text(t)
        docs.sort(key=lambda x:x.get("_score",0.0), reverse=True)
        return docs

    # Real TF-IDF with weighted term columns
    vec = TfidfVectorizer(min_df=2, ngram_range=(1,2), stop_words="english")
    X = vec.fit_transform(texts)  # shape (n_docs, n_terms)
    vocab = vec.vocabulary_
    # collect unique protocol tokens
    def toks(lst): 
        return [w.lower() for w in lst or []]
    P = toks(protocol["population_terms"]); I = toks(protocol["intervention_terms"]); O = toks(protocol["outcomes_terms"]); A = toks(protocol["must_have"])
    weights = defaultdict(float)
    for term in P: weights[term] += 1.5
    for term in I: weights[term] += 1.5
    for term in O: weights[term] += 1.0
    for term in A: weights[term] += 0.5  # light anchor boost

    # compute doc scores as sum(weight * tfidf_col)
    import numpy as np
    score = np.zeros(X.shape[0], dtype=float)
    for term, w in weights.items():
        if term in vocab:
            col = vocab[term]
            score += w * X[:, col].toarray().ravel()
    # penalty for avoid tokens if present in vocab
    for term in [t.lower() for t in protocol.get("avoid",[])]:
        if term in vocab:
            col = vocab[term]
            score -= 2.0 * X[:, col].toarray().ravel()

    for d, s in zip(docs, score):
        d["_score"] = float(s)
    docs.sort(key=lambda x:x.get("_score",0.0), reverse=True)
    return docs

# -----------------------------
# S3: Ground truth discovery (strict screener) + MeSH roles via LLM (actual MeSH only)
# -----------------------------
def screen_record(protocol: dict, rec: dict) -> dict:
    js = get_validated_json(
        SCREENER_MODEL, S3_SCREEN_SYSTEM, s3_screen_user(protocol, rec),
        validator=validate_screener_output, max_tokens=1024, dbg_stage="screen"
    )
    snippet = (rec["abstract"] or "")[:500].replace("\n"," ")
    print(f'  [Screen] PMID {rec["pmid"]} -> decision={js["decision"]} checklist={js["checklist"]} why={js["reason"]}')
    print(f'    Title: {rec["title"][:160]}')
    print(f'    Abstract: {snippet}{"..." if len(rec["abstract"] or "")>500 else ""}')
    return js

def classify_mesh_roles(rec: dict) -> Dict[str,str]:
    # ask LLM to role-tag THIS record's MeSH ONLY
    js = get_validated_json(
        QWEN_MODEL, S3_MESH_ROLE_SYSTEM, s3_mesh_role_user(rec),
        validator=validate_mesh_roles, max_tokens=768, dbg_stage="mesh_roles"
    )
    roles={}
    record_mesh = set([m.strip() for m in (rec.get("mesh") or []) if m])
    for item in js.get("labels",[]):
        m=item.get("mesh","").strip()
        r=item.get("role","G").strip()
        if m in record_mesh:
            roles[m]=r
    return roles

def state3_ground_truth(reranked_docs: List[dict], protocol: dict) -> Tuple[List[str], Dict[str,str]]:
    print("[S3] Ground-truth discovery & vocabulary mining...")
    includes=[]
    mesh_vocab = {}
    for rec in reranked_docs[:GROUND_TOP_N]:
        js = screen_record(protocol, rec)
        if js["decision"] == "INCLUDE" and all(js["checklist"].get(k,False) for k in ["P","I","O","D"]):
            includes.append(rec)
            # roles strictly from this record's actual MeSH (via LLM)
            roles = classify_mesh_roles(rec)
            for m,r in roles.items():
                if r in ["P","I","O"]:  # keep only useful roles
                    mesh_vocab[m]=r
    pmids = [r["pmid"] for r in includes]
    print(f"  [S3] Ground truth PMIDs: {pmids}")
    return pmids, mesh_vocab

# -----------------------------
# S3.5: Senior plausibility gate
# -----------------------------
def state3_5_plausibility(protocol: dict, included_pmids: List[str]) -> List[str]:
    print("[S3.5] Senior plausibility spot-check...")
    if not included_pmids: return []
    xml = efetch_xml(included_pmids[:300])
    recs = {r["pmid"]: r for r in parse_pubmed_xml(xml)}
    keep=[]
    for pmid in included_pmids:
        rec = recs.get(pmid)
        if not rec: continue
        js = get_validated_json(QWEN_MODEL, S35_PLAUS_SYS, s35_plaus_user(protocol, rec),
                                validator=lambda x: (x.get("verdict") in ("PASS","FAIL") and x.get("pmid")==pmid, "bad verdict or pmid"),
                                max_tokens=256, dbg_stage="plausibility")
        print(f'  [Plaus] PMID {pmid} -> {js["verdict"]} : {js["why"]}')
        if js["verdict"] == "PASS":
            keep.append(pmid)
    if len(keep) < MIN_GROUND_TRUTH:
        raise SystemExit(f"Fatal: insufficient plausible ground truth (got {len(keep)}/{MIN_GROUND_TRUTH}).")
    return keep

# -----------------------------
# S4: Strategy validation & refinement
# -----------------------------
def state4_validate_strategy(universe_query: str, ground_pmids: List[str], protocol: dict, mesh_vocab: Dict[str,str]) -> Tuple[dict, List[str]]:
    print("[S4] Strategy validation & refinement...")
    filt = get_validated_json(QWEN_MODEL, S4_STRATEGY_SYSTEM, s4_strategy_user(protocol, mesh_vocab),
                              validator=lambda x: ("topic_filter" in x and "design_filter" in x, "missing fields"),
                              max_tokens=512, dbg_stage="strategy_build")
    attempts = 0
    while attempts <= REMEDIATION_MAX_TRIES:
        combined = f"({universe_query}) AND ({filt['topic_filter']}) AND ({filt['design_filter']})"
        ids = esearch_all_ids(combined, mindate=protocol["year_min"], cap=10000)
        total = len(ids)
        recall_ok = set(ground_pmids).issubset(set(ids))
        precision_ok = PRECISION_WINDOW[0] <= total <= PRECISION_WINDOW[1]
        print(f"   [Strategy] try={attempts} total={total} recall_ok={recall_ok} precision_ok={precision_ok}")
        if recall_ok and precision_ok:
            return {"topic":filt["topic_filter"], "design":filt["design_filter"]}, ids
        if attempts == REMEDIATION_MAX_TRIES:
            break
        snapshot = {"universe_query": universe_query, "filters": filt, "total": total, "ground_truth": ground_pmids}
        fix = get_validated_json(QWEN_MODEL, S4_REMEDIATE_SYSTEM, s4_remediate_user(snapshot),
                                 validator=validate_remediation, max_tokens=256, dbg_stage="strategy_remed")
        op = fix["op"]
        if op == "DROP_TERM" and fix.get("where") == "topic":
            term = fix.get("term","").strip()
            # remove a single OR'd MeSH term from topic_filter
            # crude but effective: remove "<term>[MeSH Terms]" occurrences + surrounding ORs
            filt["topic_filter"] = re.sub(rf'\s*\(?{re.escape(term)}\[MeSH Terms\]\s*OR\s*', '(', filt["topic_filter"])
            filt["topic_filter"] = re.sub(rf'\s*OR\s*{re.escape(term)}\[MeSH Terms\]\s*\)?', ')', filt["topic_filter"])
            filt["topic_filter"] = re.sub(r'\(\s*\)', '(*)', filt["topic_filter"])
        elif op == "ADD_ANCHOR" and fix.get("term"):
            protocol["must_have"].append(fix["term"])
            universe_query = build_universe_query(protocol)
        elif op == "BROADEN_DESIGN_FILTER":
            filt["design_filter"] = "Clinical Trial[Publication Type]"
        attempts += 1
    raise SystemExit("Fatal: strategy could not be validated within remediation budget.")

# -----------------------------
# S5: Finalization
# -----------------------------
def state5_finalize(nlq: str, protocol: dict, universe_query: str, rec_filters: dict,
                    ground_pmids: List[str], mesh_vocab: Dict[str,str], warnings: List[str]):
    rq_embed = "Adults undergoing MIRPE/Nuss; intercostal nerve cryoablation for analgesia; outcomes: opioid use & 0–7 day pain."
    artifacts = {
        "locked_protocol": protocol,
        "universe_query": universe_query,
        "recommended_filters": rec_filters,
        "ground_truth_pmids": ground_pmids,
        "mesh_vernaculum": mesh_vocab,
        "warnings": warnings,
        "research_question_string_for_embedding": rq_embed,
        "nlq": nlq
    }
    ARTIFACTS_JSON.write_text(json.dumps(artifacts, indent=2, ensure_ascii=False), encoding="utf-8")

    lines = []
    lines.append("==================== SNIFF REPORT ====================")
    lines.append("NLQ:\n  " + textwrap.shorten(nlq, 220))
    lines.append("\nLOCKED PROTOCOL:")
    lines.append("  " + json.dumps(protocol, ensure_ascii=False))
    lines.append("\nUNIVERSE:")
    lines.append(f"  Query: {universe_query}")
    lines.append("\nRECOMMENDED FILTERS:")
    lines.append(f"  topic:  {rec_filters['topic']}")
    lines.append(f"  design: {rec_filters['design']}")
    lines.append("\nGROUND TRUTH PMIDs:")
    lines.append("  " + ", ".join(ground_pmids))
    if mesh_vocab:
        top = ", ".join([f"{m}({r})" for m,r in list(mesh_vocab.items())[:20]])
        lines.append("\nMESH vocab (top):")
        lines.append("  " + top)
    if warnings:
        lines.append("\nWARNINGS:")
        for w in warnings: lines.append(f"  - {w}")
    lines.append("\nArtifacts saved to: " + str(OUTDIR))
    lines.append("================== END OF REPORT =====================")
    REPORT_TXT.write_text("\n".join(lines), encoding="utf-8")
    print("\n".join(lines))

# -----------------------------
# MAIN ENGINE
# -----------------------------
def sniff_engine(nlq: str):
    warnings=[]
    proto = state1_protocol_lockdown(nlq)
    universe_query, u_ids = state2_universe(proto)
    reranked_docs = state2_5_rerank_universe(universe_query, u_ids, proto)
    gt_pmids, mesh_vocab = state3_ground_truth(reranked_docs, proto)
    gt_pmids = state3_5_plausibility(proto, gt_pmids)
    rec_filters, combined_ids = state4_validate_strategy(universe_query, gt_pmids, proto, mesh_vocab)
    state5_finalize(nlq, proto, universe_query, rec_filters, gt_pmids, mesh_vocab, warnings)

# -----------------------------
# RUN EXAMPLE
# -----------------------------
if __name__ == "__main__":
    USER_NLQ = """Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE).
Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during MIRPE/Nuss (the intervention of interest is INC, not the surgery).
Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, or systemic multimodal analgesia.
Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days.
Study designs = RCTs preferred; if RCTs absent, include comparative cohort/case-control.
Year_min = 2015.
Languages = English, Portuguese, Spanish."""
    sniff_engine(USER_NLQ)


[S1] Protocol lockdown...
  [S1] Locked protocol:
    {"population_terms": ["adults", "Nuss", "MIRPE", "pectus excavatum", "minimally invasive repair"], "intervention_terms": ["intercostal nerve", "cryoablation", "cryoanalgesia", "INC", "analgesia"], "comparators_terms": ["thoracic epidural", "paravertebral block", "intercostal nerve block", "erector spinae plane block", "systemic multimodal analgesia"], "outcomes_terms": ["postoperative opioid consumption", "pain scores", "0-7 day pain", "discharge opioid use"], "must_have": ["MIRPE", "Nuss", "pectus excavatum", "cryoablation"], "avoid": ["pediatric"], "designs_preference": "Randomized Controlled Trial", "languages": ["english", "portuguese", "spanish"], "year_min": 2015}
[S2] Universe definition & sizing...
   [Universe] try=0 count=299 window=(50, 10000)
[S2.5] Rerank universe with PICO-weighted TF-IDF...
[S3] Ground-truth discovery & vocabulary mining...
  [Screen] PMID 31259649 -> decision=EXCLUDE checklist={'P': False, 'I': True,

HTTPError: 404 Client Error: Not Found for url: http://127.0.0.1:1234/v1/chat/completions

In [None]:
# sniff_validation_engine_v3_1.py
# Refactored "Sniff Validation Engine" (state-machine architecture)
# - Single "universe" query + validated filters (no legacy BROAD/FOCUSED)
# - Deterministic prefilter (language/year/design) BEFORE LLM
# - PICO-weighted TF-IDF re-ranker (true TF-IDF × role weights)
# - Strict screener with Ask-Validate-Retry JSON
# - Senior plausibility check (second LLM pass) to prevent topic drift
# - Ground-truth vocabulary (MeSH) only from confirmed INCLUDEs
# - Robust model switching with idle TTL and conservative waits to avoid CPU fallback
#
# Requirements:
#   - LM Studio running at LMSTUDIO_BASE (default http://127.0.0.1:1234)
#   - Two local models served by LM Studio:
#       QWEN_MODEL  (for protocol, remediation, plausibility)
#       SCREENER_MODEL (fast model for checklist screening)
#   - Internet for NCBI E-utilities
#
# Usage:
#   1) Edit USER_NLQ below (natural language RQ; you may include notes for screening).
#   2) Optionally edit constants (MODEL names, thresholds) or set via env vars.
#   3) Run as a single cell/script. See printed summary + output files in OUT_DIR.
#
# Outputs:
#   - sniff_report.txt  : human-readable report of states, warnings, and final strategy
#   - sniff_artifacts.json : machine-readable details (locked protocol, universe query,
#                            recommended filters, ground truth PMIDs, MeSH vernaculum,
#                            research_question_string_for_embedding, warnings)
#
# NOTE: We DO NOT hard-filter by "Humans" or MeSH age bands deterministically.
#       Deterministic gates: year, language, and publication type (design allowlist).
#       The LLM uses PubTypes + MeSH contextually during screening.

import os, json, time, re, textwrap, pathlib, random, math
from collections import Counter, defaultdict
from typing import List, Dict, Any, Tuple, Callable, Optional
import requests
from xml.etree import ElementTree as ET

# ----------------------------
# Config / Constants
# ----------------------------
LMSTUDIO_BASE = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234")
QWEN_MODEL    = os.getenv("QWEN_MODEL", "qwen/qwen3-4b")
SCREENER_MODEL= os.getenv("SCREENER_MODEL", "gemma-3n-e2b-it")  # fast checklist screener

ENTREZ_EMAIL   = os.getenv("ENTREZ_EMAIL", "you@example.com")
ENTREZ_API_KEY = os.getenv("ENTREZ_API_KEY", "")

HTTP_TIMEOUT   = int(os.getenv("HTTP_TIMEOUT", "300"))
MODEL_TTL_SEC  = float(os.getenv("MODEL_TTL_SEC", "5.0"))  # idle TTL hint (LM Studio)
MODEL_SWAP_WAIT= float(os.getenv("MODEL_SWAP_WAIT", "10.0"))  # conservative wait between model swaps

OUT_DIR = pathlib.Path("sniff_out")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Universe sizing thresholds
UNIVERSE_TARGET = (50, 10000)   # ideal window
UNIVERSE_HARD_MIN = 25          # if final count < 25 after remediation -> terminate

# Rerank / screening sizes
UNIVERSE_FETCH_MAX = 800        # number of PubMed records to fetch for rerank (cap)
SCREEN_TOP_K       = 60         # how many (after rerank+prefilter) to send to screener

# Screener rules
SCREENER_RETRY_MAX = 3
PLAUSIBILITY_MIN_INCLUDES = 3   # need ≥ this many includes, post-plausibility, or terminate

# Role weights for PICO TF-IDF reranker
WEIGHTS = {
    "P": 1.5,
    "I": 1.75,
    "C": 1.0,
    "O": 1.0,
    "ANCHOR": 2.0,
    "AVOID": -2.5
}

# PubMed E-utilities base + headers
EUTILS = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils"
HEADERS = {"User-Agent": "sniff-validation-engine/3.1 (+local)", "Accept": "application/json"}

random.seed(42)

# ----------------------------
# Utilities: LM Studio model management
# ----------------------------
class ModelManager:
    def __init__(self, base: str, idle_ttl_sec: float = MODEL_TTL_SEC, swap_wait: float = MODEL_SWAP_WAIT):
        self.base = base.rstrip("/")
        self.idle_ttl = idle_ttl_sec
        self.swap_wait = swap_wait
        self.current_model = None
        self.last_used_ts = 0.0

    def _maybe_wait_for_idle_eviction(self):
        now = time.time()
        idle = now - self.last_used_ts
        if idle < self.idle_ttl:
            time.sleep(self.idle_ttl - idle)
        # conservative extra wait to allow LM Studio to evict models
        time.sleep(max(0.0, self.swap_wait - self.idle_ttl))

    def _best_effort_unload_all(self):
        # LM Studio does not officially document unload; try likely endpoints, ignore errors.
        for path in ["/v1/models/unload_all", "/v1/engines/unload_all", "/v1/models/unload"]:
            try:
                requests.post(self.base + path, timeout=3.0)
            except Exception:
                pass

    def switch(self, model: str):
        if self.current_model and self.current_model != model:
            # allow time for the previous model to be evicted
            self._maybe_wait_for_idle_eviction()
            self._best_effort_unload_all()
        self.current_model = model
        self.last_used_ts = time.time()

    def mark_used(self):
        self.last_used_ts = time.time()

MM = ModelManager(LMSTUDIO_BASE)

def lm_chat(model: str, system: str, user: str, temperature=0.0, max_tokens=2048) -> str:
    MM.switch(model)
    url = f"{LMSTUDIO_BASE.rstrip('/')}/v1/chat/completions"
    body = {
        "model": model,
        "messages": [{"role":"system","content":system},{"role":"user","content":user}],
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
        "stream": False
    }
    r = requests.post(url, json=body, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    MM.mark_used()
    return r.json()["choices"][0]["message"]["content"]

# ----------------------------
# JSON extraction & Ask-Validate-Retry
# ----------------------------
_BEGIN = re.compile(r"BEGIN_JSON\s*", re.I)
_END   = re.compile(r"\s*END_JSON", re.I)
FENCE  = re.compile(r"```(?:json)?\s*([\s\S]*?)```", re.I)

def _sanitize_json_str(s: str) -> str:
    s = s.replace("\u201c", '"').replace("\u201d", '"').replace("\u2018","'").replace("\u2019","'")
    s = re.sub(r",\s*(\}|\])", r"\1", s)
    return s.strip()

def extract_json_block_or_fence(txt: str) -> str:
    blocks = []
    pos=0
    while True:
        m1 = _BEGIN.search(txt, pos)
        if not m1: break
        m2 = _END.search(txt, m1.end())
        if not m2: break
        blocks.append(txt[m1.end():m2.start()])
        pos = m2.end()
    if blocks:
        return _sanitize_json_str(blocks[-1])

    fences = FENCE.findall(txt)
    if fences:
        return _sanitize_json_str(fences[-1])

    # last balanced {...}
    s = txt
    last_obj=None; stack=0; start=None
    for i,ch in enumerate(s):
        if ch=='{':
            if stack==0: start=i
            stack+=1
        elif ch=='}':
            if stack>0:
                stack-=1
                if stack==0 and start is not None:
                    last_obj = s[start:i+1]
    if last_obj:
        return _sanitize_json_str(last_obj)
    raise ValueError("No JSON-like content found")

STRICT_JSON_RULES = (
  "Return ONLY one JSON object. No analysis, no preface, no notes. "
  "Wrap it EXACTLY with:\nBEGIN_JSON\n{...}\nEND_JSON"
)

def get_validated_json(
    model: str,
    system_prompt: str,
    user_prompt: str,
    validator: Callable[[Dict[str,Any]], Tuple[bool,str]],
    retries: int = 3,
    max_tokens: int = 2048
) -> Dict[str,Any]:
    history_user = user_prompt
    for i in range(retries):
        raw = lm_chat(model, system_prompt, history_user + "\n\n" + STRICT_JSON_RULES, max_tokens=max_tokens)
        try:
            js = json.loads(extract_json_block_or_fence(raw))
        except Exception as e:
            err = f"malformed JSON: {e}"
            if i == retries-1:
                raise SystemExit(f"Fatal: LLM failed to produce valid JSON after retries. Last error: {err}")
            history_user += f"\n\nYour previous output was invalid due to: {err}\nPlease fix and return a single valid JSON object."
            continue
        ok, why = validator(js)
        if ok:
            return js
        if i == retries-1:
            raise SystemExit(f"Fatal: LLM JSON schema invalid after retries: {why}")
        history_user += f"\n\nYour previous JSON failed validation: {why}\nPlease correct your output and adhere to the required schema."

# ----------------------------
# KB defaults (designs, languages, pubtype map)
# ----------------------------
KB_PATH = pathlib.Path("system_knowledge_base.json")
KB_DEFAULT = {
    "publication_types_allowable": [
        "Randomized Controlled Trial",
        "Controlled Clinical Trial",
        "Clinical Trial",
        "Comparative Study",
        "Cohort Studies",
        "Case-Control Studies",
        "Observational Study",
        "Multicenter Study",
        "Cross-Sectional Studies",
        "Clinical Trial Protocol",
        "Evaluation Study"
    ],
    "languages": ["english","spanish","portuguese","french","german","italian","chinese","japanese","korean"],
    "designs_primary": ["Randomized Controlled Trial","Controlled Clinical Trial","Clinical Trial"],
    "designs_secondary": ["Comparative Study","Cohort Studies","Case-Control Studies","Observational Study","Multicenter Study","Evaluation Study"],
    "pubtype_aliases": {
        "Randomized Controlled Trial": ["Randomized Controlled Trial"],
        "Controlled Clinical Trial": ["Controlled Clinical Trial"],
        "Clinical Trial": ["Clinical Trial"],
        "Comparative Study": ["Comparative Study"],
        "Cohort Studies": ["Cohort Studies","Prospective Studies","Retrospective Studies"],
        "Case-Control Studies": ["Case-Control Studies"],
        "Observational Study": ["Observational Study"],
        "Multicenter Study": ["Multicenter Study"],
        "Cross-Sectional Studies": ["Cross-Sectional Studies"],
        "Clinical Trial Protocol": ["Clinical Trial Protocol","Study Protocols"],
        "Evaluation Study": ["Evaluation Study"]
    }
}

def load_or_init_kb() -> Dict[str,Any]:
    if KB_PATH.exists():
        try:
            on_disk = json.loads(KB_PATH.read_text(encoding="utf-8"))
        except Exception:
            # if corrupted, reset to defaults
            KB_PATH.write_text(json.dumps(KB_DEFAULT, indent=2), encoding="utf-8")
            return KB_DEFAULT

        # Merge defaults → fill any missing keys from KB_DEFAULT
        merged = dict(KB_DEFAULT)
        for k, v in on_disk.items():
            merged[k] = v

        # Persist the merged file so future runs are stable
        KB_PATH.write_text(json.dumps(merged, indent=2), encoding="utf-8")
        return merged

    KB_PATH.write_text(json.dumps(KB_DEFAULT, indent=2), encoding="utf-8")
    return KB_DEFAULT


KB = load_or_init_kb()

# ----------------------------
# State 1: Protocol Lockdown
# ----------------------------
PROTO_SYSTEM = """You are designing a structured, search-ready SR protocol from a natural-language question.

Produce a protocol that includes BOTH narrative fields for LLMs and structured fields for code.

Rules:
- Use concise search tokens for P/I/C/O (each token ≤ 3-4 words). Avoid overlong phrases.
- Populate 'designs_preference' by selecting ONE from the provided KB 'designs_primary'.
- 'deterministic_filters' MUST include: languages (subset of KB.languages) and year_min (from user or a reasonable default).
- Do not hallucinate comparators or outcomes not implied; it's OK to leave lists empty if not provided.
- If the question is incoherent or underspecified, set "needs_clarification"=true and write a short "clarification_request".

Return ONLY JSON as requested."""

def proto_user(nlq: str, kb: Dict[str,Any]) -> str:
    kb_view = {
        "designs_primary": kb.get("designs_primary", KB_DEFAULT["designs_primary"]),
        "languages": kb.get("languages", KB_DEFAULT["languages"])
    }
    return f"""Natural-Language Question:
<<<
{nlq.strip()}
>>>

Knowledge Base (valid choices):
{json.dumps({"designs_primary":KB["designs_primary"], "languages":KB["languages"]}, indent=2)}

Output schema:
{{
  "narrative_question": "<1 paragraph restatement>",
  "inclusion_criteria": ["...","..."],
  "exclusion_criteria": ["..."],
  "screening_rules_note": {{
    "user_notes": "<verbatim any adjunct/instructions embedded in NLQ>",
    "llm_guidance": "<short additional instructions inferred>"
  }},
  "pico_tokens": {{
    "P": ["..."],
    "I": ["..."],
    "C": ["..."],
    "O": ["..."]
  }},
  "anchors_must_have": ["..."],   // topical anchors to enforce (e.g., MIRPE, Nuss)
  "avoid_terms": ["..."],
  "designs_preference": "<ONE of designs_primary>",
  "deterministic_filters": {{
     "languages": ["..."],  // subset of KB.languages
     "year_min": 2015
  }},
  "needs_clarification": false,
  "clarification_request": ""
}}"""

def validate_protocol(js: Dict[str,Any]) -> Tuple[bool,str]:
    try:
        # minimal schema checks
        req_top = ["narrative_question","inclusion_criteria","exclusion_criteria",
                   "screening_rules_note","pico_tokens","anchors_must_have",
                   "avoid_terms","designs_preference","deterministic_filters",
                   "needs_clarification","clarification_request"]
        for k in req_top:
            if k not in js: return False, f"missing key: {k}"
        if not isinstance(js["pico_tokens"], dict): return False, "pico_tokens must be object"
        for k in ["P","I","C","O"]:
            if k not in js["pico_tokens"]: return False, f"pico_tokens missing {k}"
            if not isinstance(js["pico_tokens"][k], list): return False, f"pico_tokens[{k}] must be list"
        df = js["deterministic_filters"]
        if not isinstance(df.get("languages",[]), list) or not df.get("languages"):
            return False, "languages must be non-empty list"

        y = df.get("year_min", 0)
        if isinstance(y, str) and y.isdigit():
            df["year_min"] = int(y)
        elif not isinstance(y, int):
            return False, "year_min must be int or numeric string"

        if js["designs_preference"] not in KB["designs_primary"]:
            return False, "designs_preference must be one of KB.designs_primary"

        # keep the short-token safeguard
        long_bad = [t for t in (js["pico_tokens"]["P"]+js["pico_tokens"]["I"]+js["pico_tokens"]["C"]+js["pico_tokens"]["O"]) if len(t.split())>5]
        if long_bad:
            return False, f"tokens too long: {long_bad[:3]}"
        return True, ""
    except Exception as e:
        return False, f"exception in protocol validation: {e}"

# ----------------------------
# PubMed: search & fetch
# ----------------------------
def esearch_ids(term: str, mindate: Optional[int], retmax: int = 5000) -> Tuple[int, List[str]]:
    p = {"db":"pubmed","retmode":"json","term":term,"retmax":retmax,"email":ENTREZ_EMAIL,"usehistory":"y"}
    if ENTREZ_API_KEY: p["api_key"]=ENTREZ_API_KEY
    if mindate: p["mindate"]=str(mindate)
    r = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params=p, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    js = r.json().get("esearchresult", {})
    count = int(js.get("count","0"))
    webenv = js.get("webenv"); qk = js.get("querykey")
    if not count or not webenv or not qk:
        return 0, []
    r2 = requests.get(f"{EUTILS}/esearch.fcgi", headers=HEADERS, params={
        "db":"pubmed","retmode":"json","retmax":retmax,"retstart":0,"email":ENTREZ_EMAIL,"WebEnv":webenv,"query_key":qk,
        **({"api_key":ENTREZ_API_KEY} if ENTREZ_API_KEY else {})
    }, timeout=HTTP_TIMEOUT)
    r2.raise_for_status()
    ids = r2.json().get("esearchresult",{}).get("idlist",[])
    return count, [str(x) for x in ids]

def efetch_xml(pmids: List[str]) -> str:
    if not pmids: return ""
    params = {"db":"pubmed","retmode":"xml","rettype":"abstract","id":",".join(pmids),"email":ENTREZ_EMAIL}
    if ENTREZ_API_KEY: params["api_key"]=ENTREZ_API_KEY
    r = requests.get(f"{EUTILS}/efetch.fcgi", headers={"User-Agent":"sniff-validation-engine/3.1"}, params=params, timeout=HTTP_TIMEOUT)
    r.raise_for_status()
    return r.text

def parse_pubmed_xml(xml_text: str) -> List[Dict[str,Any]]:
    out = []
    if not xml_text.strip(): return out
    root = ET.fromstring(xml_text)
    def _join(node):
        if node is None: return ""
        try: return "".join(node.itertext())
        except Exception: return node.text or ""
    for art in root.findall(".//PubmedArticle"):
        pmid = art.findtext(".//PMID") or ""
        title = _join(art.find(".//ArticleTitle")).strip()
        abs_nodes = art.findall(".//Abstract/AbstractText")
        abstract = " ".join(_join(n).strip() for n in abs_nodes) if abs_nodes else ""
        year = None
        for path in (".//ArticleDate/Year",".//PubDate/Year",".//DateCreated/Year",".//PubDate/MedlineDate"):
            s = art.findtext(path)
            if s:
                m = re.search(r"\d{4}", s)
                if m: year = int(m.group(0)); break
        lang = art.findtext(".//Language") or None
        pubtypes = [pt.text for pt in art.findall(".//PublicationTypeList/PublicationType") if pt.text]
        mesh = [mh.findtext("./DescriptorName") for mh in art.findall(".//MeshHeadingList/MeshHeading") if mh.findtext("./DescriptorName")]
        out.append({"pmid":pmid,"title":title,"abstract":abstract,"year":year,"language":lang,"pubtypes":pubtypes,"mesh":mesh})
    return out

# ----------------------------
# Query assembly, remediation
# ----------------------------
def or_block(terms: List[str], field="tiab") -> str:
    toks=[]
    for t in terms:
        t=t.strip()
        if not t: continue
        if " " in t or "-" in t:
            toks.append(f"\"{t}\"[{field}]")
        else:
            toks.append(f"{t}[{field}]")
    if not toks: return ""
    return "(" + " OR ".join(toks) + ")"

def build_universe_query(P: List[str], I: List[str], anchors: List[str]) -> str:
    Pq = or_block(P, "tiab"); Iq = or_block(I, "tiab")
    Aq = or_block(anchors, "tiab") if anchors else ""
    parts = [x for x in [Pq, Iq, Aq] if x]
    return " AND ".join(parts)

REM_SYS = """You are a search strategy repair assistant. The current query is underperforming (too few hits).

Constraints (do NOT violate):
- Keep the core topic: population and intervention must remain faithful to the protocol.
- Only operate on the P/I token lists: REMOVE_TERM, SIMPLIFY_TERM (shorten phrase), or ADD_ALTERNATE (synonym).
- Return at most 2 operations.
- Do NOT introduce terms that contradict population or intervention focus.
Return JSON only."""

def rem_user(query: str, count: int, protocol: Dict[str,Any]) -> str:
    return f"""Current universe query (hits={count}):
{query}

Protocol (brief):
P tokens: {protocol["pico_tokens"]["P"]}
I tokens: {protocol["pico_tokens"]["I"]}
Anchors: {protocol["anchors_must_have"]}
Avoid: {protocol["avoid_terms"]}
Design preference: {protocol["designs_preference"]}

Allowed ops (array of steps):
[{{"op":"REMOVE_TERM","where":"P|I","term":"..."}}, {{"op":"SIMPLIFY_TERM","where":"P|I","term":"full phrase","simplified":"short term"}}, {{"op":"ADD_ALTERNATE","where":"P|I","term":"root","alternate":"synonym"}}]

BEGIN_JSON
{{"ops":[]}}
END_JSON"""

def validate_remediation(js: Dict[str,Any]) -> Tuple[bool,str]:
    if "ops" not in js or not isinstance(js["ops"], list): return False, "missing ops[]"
    if len(js["ops"])>2: return False, "too many ops"
    for op in js["ops"]:
        if op.get("op") not in ["REMOVE_TERM","SIMPLIFY_TERM","ADD_ALTERNATE"]:
            return False, f"bad op: {op.get('op')}"
        if op.get("where") not in ["P","I"]:
            return False, "where must be P or I"
    return True, ""

def apply_remediation(P: List[str], I: List[str], ops: List[Dict[str,str]]) -> Tuple[List[str], List[str]]:
    Pn = P[:]; In = I[:]
    def _apply(lst, op):
        if op["op"]=="REMOVE_TERM":
            lst = [t for t in lst if t.lower()!=op.get("term","").lower()]
        elif op["op"]=="SIMPLIFY_TERM":
            t = op.get("term",""); s=op.get("simplified","")
            lst = [s if x.lower()==t.lower() and s else x for x in lst]
        elif op["op"]=="ADD_ALTERNATE":
            alt = op.get("alternate","")
            if alt and alt.lower() not in [x.lower() for x in lst]:
                lst.append(alt)
        return lst
    for op in ops:
        if op["where"]=="P":
            Pn = _apply(Pn, op)
        else:
            In = _apply(In, op)
    return Pn, In

# ----------------------------
# Deterministic prefilter (language/year/design only)
# ----------------------------
def passes_prefilter(rec: Dict[str,Any], languages: List[str], year_min: int, design_allowlist: List[str], pubtype_alias: Dict[str,List[str]]) -> bool:
    if rec.get("year") and rec["year"] < year_min:
        return False
    if rec.get("language") and rec["language"].lower() not in [x.lower() for x in languages]:
        return False
    if design_allowlist:
        # Any intersection between aliases for allowed designs and rec.pubtypes
        rpts = set(rec.get("pubtypes") or [])
        for design in design_allowlist:
            aliases = set(pubtype_alias.get(design, [design]))
            if rpts & aliases:
                return True
        # allow if no pubtypes present (unknown design) -> keep for LLM
        if not rpts:
            return True
        return False
    return True

# ----------------------------
# TF-IDF PICO reranker
# ----------------------------
def build_tfidf_and_score(records: List[Dict[str,Any]], protocol: Dict[str,Any]) -> List[Tuple[float,Dict[str,Any]]]:
    texts = []
    for r in records:
        t = (r.get("title","") + " " + r.get("abstract","")).strip()
        texts.append(t if t else r.get("title",""))
    # import TF-IDF with fallback
    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        vec = TfidfVectorizer(stop_words="english", max_features=50000)
        X = vec.fit_transform(texts)
        vocab = vec.vocabulary_
        idf_diag = None  # scikit handles internally
        def tfidf(term, row_idx):
            j = vocab.get(term.lower())
            if j is None: return 0.0
            return X[row_idx, j]
    except Exception:
        # very simple fallback: case-insensitive term frequency proxy
        vocab = {}
        def tfidf(term, row_idx):
            low = texts[row_idx].lower()
            return float(low.count(term.lower()))

    # compile weighted term list
    wt_terms = []
    for t in protocol["pico_tokens"]["P"]:
        wt_terms.append( (t, WEIGHTS["P"]) )
    for t in protocol["pico_tokens"]["I"]:
        wt_terms.append( (t, WEIGHTS["I"]) )
    for t in protocol["pico_tokens"]["C"]:
        wt_terms.append( (t, WEIGHTS["C"]) )
    for t in protocol["pico_tokens"]["O"]:
        wt_terms.append( (t, WEIGHTS["O"]) )
    for t in protocol["anchors_must_have"]:
        wt_terms.append( (t, WEIGHTS["ANCHOR"]) )
    for t in protocol["avoid_terms"]:
        wt_terms.append( (t, WEIGHTS["AVOID"]) )

    scored = []
    for i, rec in enumerate(records):
        s = 0.0
        for term, w in wt_terms:
            if not term: continue
            s += float(tfidf(term, i)) * w
        scored.append( (s, rec) )
    scored.sort(key=lambda x: x[0], reverse=True)
    return scored

# ----------------------------
# Screener prompts & validators
# ----------------------------
SCREEN_SYS = """You are a strict but realistic title+abstract screener for an evidence scan.

Checklist logic (INCLUDE requires P & I true AND (O OR D) true):
- P (Population/Context): study matches the target clinical context; synonyms acceptable.
- I (Intervention): intercostal nerve cryoablation / cryoanalgesia used intraoperatively for the target surgery; synonyms acceptable.
- O (Outcomes): any acute postoperative analgesia outcomes acceptable (pain, opioid use, LOS, early complications). Do NOT require exact day windows at abstract level unless protocol explicitly demands it.
- D (Design): randomized/comparative preferred; strong cohorts acceptable if protocol allows. Use PubTypes if available; otherwise infer from abstract.

Return ONLY JSON with schema below; be conservative but do not nitpick details that require full-text.
If the record is clearly pediatric while protocol is adults-only (or vice-versa), you may EXCLUDE for population mismatch."""

def screen_user(protocol: Dict[str,Any], record: Dict[str,Any]) -> str:
    return f"""Protocol (narrative):
{protocol["narrative_question"]}

Key lists:
P: {protocol["pico_tokens"]["P"]}
I: {protocol["pico_tokens"]["I"]}
C: {protocol["pico_tokens"]["C"]}
O: {protocol["pico_tokens"]["O"]}
Design preference: {protocol["designs_preference"]}
Anchors: {protocol["anchors_must_have"]}
Avoid: {protocol["avoid_terms"]}
Inclusion criteria: {protocol["inclusion_criteria"]}
Exclusion criteria: {protocol["exclusion_criteria"]}
Screening notes: {protocol["screening_rules_note"]}

Record:
PMID: {record['pmid']}
Title: {record['title']}
PubTypes: {record.get('pubtypes',[])}
MeSH: {record.get('mesh',[])}
Abstract:
{record.get('abstract','')}

Return schema:
{{
  "pmid": "{record['pmid']}",
  "decision": "INCLUDE|BORDERLINE|EXCLUDE",
  "why": "<one concise reason>",
  "checklist": {{"P": true|false, "I": true|false, "O": true|false, "D": true|false}},
  "mesh_roles": [{{"mesh":"...","role":"P|I|C|O|G"}}]
}}"""

def validate_screen(js: Dict[str,Any]) -> Tuple[bool,str]:
    try:
        if js.get("decision") not in ["INCLUDE","BORDERLINE","EXCLUDE"]:
            return False, "bad decision"
        ch = js.get("checklist",{})
        for k in ["P","I","O","D"]:
            if not isinstance(ch.get(k), bool):
                return False, f"checklist.{k} must be bool"
        m = js.get("mesh_roles",[])
        if not isinstance(m, list):
            return False, "mesh_roles must be list"
        for it in m:
            if not isinstance(it, dict): return False, "mesh_roles items must be dict"
            if "mesh" not in it or "role" not in it: return False, "mesh_roles items need mesh & role"
        return True, ""
    except Exception as e:
        return False, f"exception in screen validation: {e}"

# Senior plausibility check
PLAUS_SYS = """You are a senior reviewer validating junior screening decisions to prevent topic drift.
Given the protocol and an already-INCLUDED record, answer PASS if the record’s core topic clearly matches the protocol’s core P+I context; otherwise FAIL.
Be brief and conservative. Return JSON only with {"pmid":"...","verdict":"PASS|FAIL","why":"..."}"""

def plaus_user(protocol: Dict[str,Any], record: Dict[str,Any]) -> str:
    core = f"P core terms: {protocol['pico_tokens']['P']} ; I core terms: {protocol['pico_tokens']['I']} ; Anchors: {protocol['anchors_must_have']}"
    return f"""Protocol core:
{core}

Record:
PMID: {record['pmid']}
Title: {record['title']}
PubTypes: {record.get('pubtypes',[])}
MeSH: {record.get('mesh',[])}
Abstract:
{record.get('abstract','')}

BEGIN_JSON
{{"pmid":"{record['pmid']}", "verdict":"PASS", "why":""}}
END_JSON"""

def validate_plaus(js: Dict[str,Any]) -> Tuple[bool,str]:
    v = js.get("verdict")
    if v not in ["PASS","FAIL"]: return False, "verdict must be PASS|FAIL"
    if "pmid" not in js: return False, "missing pmid"
    return True, ""

# ----------------------------
# State machine
# ----------------------------
def state1_protocol_lockdown(nlq: str) -> Dict[str,Any]:
    print("[S1] Protocol lockdown...")
    system = PROTO_SYSTEM
    # Add explicit guardrails/examples to discourage long tokens
    user = proto_user(nlq, KB) + """

Guidance:
- BAD token example: "intraoperative intercostal nerve cryoablation for analgesia"
- GOOD tokens: ["intercostal nerve","cryoablation","cryoanalgesia","INC","analgesia"]
"""
    proto = get_validated_json(QWEN_MODEL, system, user, validate_protocol, retries=3, max_tokens=2048)
    if proto.get("needs_clarification"):
        raise SystemExit("Protocol needs clarification: " + proto.get("clarification_request",""))
    print("  [S1] Locked protocol:")
    print("   ", json.dumps(proto, ensure_ascii=False))
    return proto

def state2_universe(protocol: Dict[str,Any]) -> Tuple[str, int, List[str]]:
    print("[S2] Universe definition & sizing...")
    P = protocol["pico_tokens"]["P"]
    I = protocol["pico_tokens"]["I"]
    anchors = protocol["anchors_must_have"]
    query = build_universe_query(P, I, anchors)
    count, ids = esearch_ids(query, protocol["deterministic_filters"]["year_min"], retmax=UNIVERSE_FETCH_MAX)
    print(f"   [Universe] try=0 count={count} window={UNIVERSE_TARGET}")
    tries = 0
    while (count < UNIVERSE_TARGET[0] or count > UNIVERSE_TARGET[1]) and tries < 2:
        # remediation loop
        rem = get_validated_json(QWEN_MODEL, REM_SYS, rem_user(query, count, protocol), validate_remediation, retries=2, max_tokens=1024)
        P, I = apply_remediation(P, I, rem.get("ops",[]))
        query = build_universe_query(P, I, anchors)
        count, ids = esearch_ids(query, protocol["deterministic_filters"]["year_min"], retmax=UNIVERSE_FETCH_MAX)
        tries += 1
        print(f"   [Universe] try={tries} count={count} window={UNIVERSE_TARGET}")
    if count < UNIVERSE_HARD_MIN:
        raise SystemExit(f"Fatal: universe too small after remediation (count={count} < {UNIVERSE_HARD_MIN}).")
    return query, count, ids

def deterministic_prefilter(records: List[Dict[str,Any]], protocol: Dict[str,Any]) -> List[Dict[str,Any]]:
    langs = protocol["deterministic_filters"]["languages"]
    ymin  = protocol["deterministic_filters"]["year_min"]
    # Build allowlist: prefer primary first; include secondary too at sniff stage
    allowlist = list(dict.fromkeys(KB["designs_primary"] + KB["designs_secondary"]))
    out=[]
    for r in records:
        if passes_prefilter(r, langs, ymin, allowlist, KB["pubtype_aliases"]):
            out.append(r)
    return out

def state2_5_rerank_universe(query: str, ids: List[str], protocol: Dict[str,Any]) -> List[Dict[str,Any]]:
    print("[S2.5] Rerank universe with PICO-weighted TF-IDF...")
    # fetch up to UNIVERSE_FETCH_MAX for reranking
    ids = ids[:UNIVERSE_FETCH_MAX]
    xml = efetch_xml(ids)
    recs = parse_pubmed_xml(xml)
    pre = deterministic_prefilter(recs, protocol)
    scored = build_tfidf_and_score(pre, protocol)
    # return re-ordered records only
    return [r for (s,r) in scored]

def state3_ground_truth(reranked_records: List[Dict[str,Any]], protocol: Dict[str,Any]) -> Tuple[List[Dict[str,Any]], List[Dict[str,Any]]]:
    print("[S3] Ground-truth discovery & vocabulary mining...")
    to_screen = reranked_records[:SCREEN_TOP_K]
    includes=[]; borderlines=[]
    for r in to_screen:
        js = get_validated_json(SCREENER_MODEL, SCREEN_SYS, screen_user(protocol, r), validate_screen, retries=SCREENER_RETRY_MAX, max_tokens=1536)
        d = js.get("decision")
        why = js.get("why","")
        chk = js.get("checklist",{})
        # concise logging
        short_abs = (r.get("abstract","")[:320] + "…") if r.get("abstract") and len(r["abstract"])>320 else (r.get("abstract","") or "")
        print(f"  [Screen] PMID {r['pmid']} -> decision={d} checklist={chk} why={why}")
        print(f"    Title: {r['title']}")
        print(f"    Abstract: {short_abs}")
        if d=="INCLUDE":
            # attach mesh_roles if any
            r["_mesh_roles"] = js.get("mesh_roles",[])
            includes.append(r)
        elif d=="BORDERLINE":
            r["_mesh_roles"] = js.get("mesh_roles",[])
            borderlines.append(r)
    return includes, borderlines

def state3_5_plausibility(includes: List[Dict[str,Any]], protocol: Dict[str,Any]) -> List[Dict[str,Any]]:
    print("[S3.5] Senior plausibility check (guard against topic drift)...")
    confirmed=[]
    for r in includes:
        js = get_validated_json(QWEN_MODEL, PLAUS_SYS, plaus_user(protocol, r), validate_plaus, retries=2, max_tokens=768)
        if js.get("verdict")=="PASS":
            confirmed.append(r)
        else:
            print(f"   [Plausibility] DROP PMID {r['pmid']} — {js.get('why','')}")
    return confirmed

def mesh_vernaculum_from(includes: List[Dict[str,Any]]) -> Dict[str,List[str]]:
    roles = {"P":set(),"I":set(),"C":set(),"O":set(),"G":set()}
    for r in includes:
        for mr in r.get("_mesh_roles",[]):
            m = mr.get("mesh"); role = mr.get("role","G")
            if m and role in roles:
                roles[role].add(m)
    return {k:sorted(v) for k,v in roles.items()}

def state4_validate_strategy(universe_query: str, confirmed_includes: List[Dict[str,Any]], protocol: Dict[str,Any], vernac: Dict[str,List[str]]) -> Dict[str,str]:
    print("[S4] Search-strategy validation & refinement...")
    # Build "topic_filter" deterministically from vernaculum (use P+I+O meshes as TIAB surface tokens)
    topic_tokens = list(dict.fromkeys(vernac.get("P",[]) + vernac.get("I",[]) + vernac.get("O",[])))
    topic_filter = or_block(topic_tokens, "tiab") if topic_tokens else ""
    # Build design filter deterministically from protocol preference (map to aliases)
    pref = protocol["designs_preference"]
    aliases = KB["pubtype_aliases"].get(pref, [pref])
    # recommended_filters are strings meant to be combined during Harvest:
    #   final_query := (universe_query) AND (topic_filter)  then apply 'design_filter' at execution time
    recommended = {
        "topic_filter": topic_filter,
        "design_filter": " OR ".join(f'"{a}"[Publication Type]' for a in aliases)
    }

    # Validation: recall of includes
    # We check that each include is still retrievable with (universe AND topic_filter)
    recall_ok = True
    for r in confirmed_includes:
        # cheap check: topic_filter tokens appear in title/abstract (proxy for final execution)
        if topic_filter:
            any_tok = False
            low = (r.get("title","") + " " + r.get("abstract","")).lower()
            # parse tokens out of the topic_filter string approximately
            toks = re.findall(r'"([^"]+)"\[tiab\]|(\w+)\[tiab\]', topic_filter)
            flat = [a or b for a,b in toks if (a or b)]
            for t in flat:
                if t.lower() in low:
                    any_tok = True; break
            if not any_tok:
                recall_ok = False
                print(f"   [S4] Recall risk: topic_filter might drop PMID {r['pmid']}")

    if not recall_ok:
        print("   [S4] Relaxing topic_filter (drop vernaculum; rely on universe_query only).")
        recommended["topic_filter"] = ""  # fall back to universe-only; design filter still applied downstream

    return recommended

def state5_finalize(protocol: Dict[str,Any], universe_query: str, recommended_filters: Dict[str,str],
                    confirmed_includes: List[Dict[str,Any]], vernac: Dict[str,List[str]], warnings: List[str]) -> None:
    print("[S5] Finalization & handoff...")
    # embedding string: concise, validated question
    rq_embed = f"{protocol['narrative_question']} | P:{', '.join(protocol['pico_tokens']['P'])} I:{', '.join(protocol['pico_tokens']['I'])} O:{', '.join(protocol['pico_tokens']['O'])} Anchors:{', '.join(protocol['anchors_must_have'])}"

    artifacts = {
        "locked_protocol": protocol,
        "universe_query": universe_query,
        "recommended_filters": recommended_filters,
        "ground_truth_pmids": [r["pmid"] for r in confirmed_includes],
        "mesh_vernaculum": vernac,
        "research_question_string_for_embedding": rq_embed,
        "warnings": warnings
    }
    (OUT_DIR/"sniff_artifacts.json").write_text(json.dumps(artifacts, indent=2, ensure_ascii=False), encoding="utf-8")

    # Human-readable report
    lines=[]
    lines.append("========= SNIFF VALIDATION ENGINE REPORT (v3.1) =========\n")
    lines.append("Protocol (narrative):\n" + textwrap.fill(protocol["narrative_question"], 100) + "\n")
    lines.append("Deterministic filters: languages=" + ", ".join(protocol["deterministic_filters"]["languages"]) +
                 f" ; year_min={protocol['deterministic_filters']['year_min']}\n")
    lines.append("Universe query:\n" + universe_query + "\n")
    lines.append("Recommended filters:\n  topic_filter=" + (recommended_filters["topic_filter"] or "<none>") +
                 "\n  design_filter=" + recommended_filters["design_filter"] + "\n")
    lines.append(f"Ground truth includes (n={len(artifacts['ground_truth_pmids'])}): " + ", ".join(artifacts["ground_truth_pmids"]) + "\n")
    lines.append("MeSH vernaculum (from includes only):\n" + json.dumps(vernac, indent=2, ensure_ascii=False) + "\n")
    if warnings:
        lines.append("WARNINGS:\n- " + "\n- ".join(warnings) + "\n")
    (OUT_DIR/"sniff_report.txt").write_text("\n".join(lines), encoding="utf-8")
    print("  wrote:", OUT_DIR/"sniff_artifacts.json", "and", OUT_DIR/"sniff_report.txt")

# ----------------------------
# Orchestration
# ----------------------------
def sniff_engine_run(USER_NLQ: str):
    warnings=[]
    # S1
    protocol = state1_protocol_lockdown(USER_NLQ)

    # S2
    universe_query, universe_count, universe_ids = state2_universe(protocol)

    # S2.5
    reranked = state2_5_rerank_universe(universe_query, universe_ids, protocol)
    if not reranked:
        raise SystemExit("Fatal: no records after deterministic prefilter.")

    # S3
    includes, borderlines = state3_ground_truth(reranked, protocol)
    if not includes:
        raise SystemExit("Fatal: no includes after screening. Revisit protocol or universe scope.")

    # S3.5
    confirmed = state3_5_plausibility(includes, protocol)
    if len(confirmed) < PLAUSIBILITY_MIN_INCLUDES:
        raise SystemExit(f"Fatal: insufficient confirmed includes after plausibility ({len(confirmed)}<{PLAUSIBILITY_MIN_INCLUDES}).")

    # vernaculum strictly from confirmed includes
    vernac = mesh_vernaculum_from(confirmed)

    # S4
    recommended_filters = state4_validate_strategy(universe_query, confirmed, protocol, vernac)

    # S5
    state5_finalize(protocol, universe_query, recommended_filters, confirmed, vernac, warnings)

# ----------------------------
# Example run
# ----------------------------
if __name__ == "__main__":
    USER_NLQ = """
Population = children/adolescents undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE).
Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE.
Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, or systemic multimodal analgesia.
Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days (abstract-level timing not strictly required).
Study designs = RCTs preferred; if absent, include comparative cohorts/case-control/observational.
Year_min = 2015.
Languages = English, Portuguese, Spanish.
Screening notes: Be conservative; INCLUDE if P & I present and (O or D) is present; do not exclude for lack of exact day window if acute postop outcomes are clearly reported.
"""
    sniff_engine_run(USER_NLQ.strip())


[S1] Protocol lockdown...


KeyError: 'designs_primary'

In [None]:
# dump_lmstudio.py
import os, json, datetime, requests

BASE = os.getenv("LMSTUDIO_BASE", "http://127.0.0.1:1234")
URL  = f"{BASE.rstrip('/')}/v1/chat/completions"

def dump_completion(
    model: str,
    system: str,
    user: str,
    *,
    temperature: float = 0.0,
    max_tokens: int = 1200,
    stop=None,
    stream: bool = False,
):
    body = {
        "model": model,
        "messages": [
            {"role": "system", "content": system},
            {"role": "user",   "content": user},
        ],
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
        "stream": bool(stream),
    }
    if stop:
        body["stop"] = stop

    print("=== REQUEST BODY ===")
    print(json.dumps(body, ensure_ascii=False, indent=2))

    r = requests.post(URL, headers={"Content-Type": "application/json"}, json=body, stream=stream)
    print("\n=== HTTP STATUS ===")
    print(r.status_code)

    if stream:
        print("\n=== RAW STREAM ===")
        raw_chunks = []
        for line in r.iter_lines(decode_unicode=True):
            if line:
                print(line)
                raw_chunks.append(line)
        raw_text = "\n".join(raw_chunks)
    else:
        print("\n=== RAW TEXT ===")
        raw_text = r.text
        print(raw_text)

    # Save raw response exactly as returned
    ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    out_path = f"lmstudio_raw_{ts}.txt"
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(raw_text)
    print(f"\n(saved exact response to {out_path})")

    # Try to parse JSON (optional)
    if not stream:
        try:
            js = r.json()
            print("\n=== PARSED: choices[0].message.content ===")
            print(js["choices"][0]["message"]["content"])
        except Exception as e:
            print("\n(JSON parse failed or no message.content):", repr(e))

if __name__ == "__main__":
    import os

    # Model + system
    model  = os.getenv("LM_MODEL",  "qwen/qwen3-4b")
    system = os.getenv("LM_SYSTEM", "Return exactly one JSON object between BEGIN_JSON and END_JSON. No other text.")

    # Prompt for extracting clean term lists (P/I/C/O + must_have + avoid)
    user = """TASK
You will extract concise biomedical term lists from the natural-language question (NLQ) below.

Output policy:
- Return EXACTLY ONE JSON object between:
  BEGIN_JSON
  { ... }
  END_JSON
- No other text. No backticks. No “think” prefaces.
- Arrays only; 2–10 items per list when possible.
- Items are plain phrases (no boolean operators, quotes, field tags, or brackets).
- Prefer standard medical wording and common acronyms (e.g., MIRPE, INC).
- Keep scope tightly on the NLQ intent.

Fill these keys:
- population: synonyms/labels for adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE).
- intervention: synonyms/labels for intercostal nerve cryoablation used for analgesia during Nuss/MIRPE (e.g., cryoanalgesia, INC).
- comparators: thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, systemic multimodal analgesia (and common variants).
- outcomes: postoperative opioid consumption and pain scores within 0–7 days (include common phrasings).
- must_have: 3–6 anchor tokens that should appear to ensure topicality (e.g., MIRPE, Nuss, cryoablation).
- avoid: 3–6 obvious confounders to avoid if they dominate (e.g., pediatric oncology, cardiac surgery).

NLQ
Population = adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE). Intervention = intercostal nerve cryoablation (INC) used intraoperatively for analgesia during Nuss/MIRPE (the intervention of interest is INC, not the surgery). Comparators = thoracic epidural, paravertebral block, intercostal nerve block, erector spinae plane block, or systemic multimodal analgesia. Outcomes = postoperative opioid consumption (in-hospital and at discharge) and pain scores within 0–7 days. Study designs = RCTs preferred; if RCTs absent, include comparative cohort/case-control. Year_min = 2015. Languages = English, Portuguese, Spanish.

TEMPLATE_TO_FILL (structure only; you must populate the arrays with phrases)
BEGIN_JSON
{
  "population": [],
  "intervention": [],
  "comparators": [],
  "outcomes": [],
  "must_have": [],
  "avoid": []
}
END_JSON"""

    # No stops by default (safer for models that emit <think> blocks)
    stops  = os.getenv("LM_STOP", "").split("|") if os.getenv("LM_STOP") else None
    # If you want to try cutting prefaces, set:
    # stops = ["</think>", "```"]


    dump_completion(
        model=model,
        system=system,
        user=user,
        temperature=0.0,
        max_tokens=4000,
        stop=stops,        # e.g., export LM_STOP="END_JSON|</think>|```"
        stream=False,      # set True to see event stream
    )


=== REQUEST BODY ===
{
  "model": "qwen/qwen3-4b",
  "messages": [
    {
      "role": "system",
      "content": "Return exactly one JSON object between BEGIN_JSON and END_JSON. No other text."
    },
    {
      "role": "user",
      "content": "TASK\nYou will extract concise biomedical term lists from the natural-language question (NLQ) below.\n\nOutput policy:\n- Return EXACTLY ONE JSON object between:\n  BEGIN_JSON\n  { ... }\n  END_JSON\n- No other text. No backticks. No “think” prefaces.\n- Arrays only; 2–10 items per list when possible.\n- Items are plain phrases (no boolean operators, quotes, field tags, or brackets).\n- Prefer standard medical wording and common acronyms (e.g., MIRPE, INC).\n- Keep scope tightly on the NLQ intent.\n\nFill these keys:\n- population: synonyms/labels for adults undergoing minimally invasive repair of pectus excavatum (Nuss/MIRPE).\n- intervention: synonyms/labels for intercostal nerve cryoablation used for analgesia during Nuss/MIRPE (e.g., cryo