# Dependency

In [None]:
%pip install -q langchain langchain-openai openai beautifulsoup4 html5lib requests



In [None]:
# === EXPERIMENTAL REACT EVAL — STEP 1: setup ===
# Installs (quiet)
!pip -q install -U "langchain>=0.2" "langchain-community>=0.2" duckduckgo-search beautifulsoup4 html5lib

# Silence the utcnow() deprecation noise from jupyter_client
import os, warnings
os.environ["PYTHONWARNINGS"] = "ignore:::jupyter_client.session"
warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"jupyter_client\.session")
# In a raw string a single backslash escapes each parenthesis. The previous
# doubled form (`\\(` / `\\)`) asked the regex engine for literal backslashes,
# so it could never match the real "...utcnow() is deprecated..." message.
warnings.filterwarnings("ignore", message=r".*utcnow\(\) is deprecated.*", category=DeprecationWarning)

# Core imports for this section
import re, math, requests
from typing import List, Tuple, Dict, Any, Optional
from urllib.parse import urlparse
from bs4 import BeautifulSoup



In [None]:
!pip -q install faiss-cpu FlagEmbedding sentence-transformers transformers accelerate bitsandbytes \
                  rank_bm25 pandas numpy pyarrow pydantic pyyaml tqdm

In [None]:
%pip install -q ddgs

In [None]:

# C1 — Install libraries (quiet) + show GPU
# !pip -q install faiss-cpu FlagEmbedding sentence-transformers transformers accelerate bitsandbytes \
#                  rank_bm25 pandas numpy pyarrow pydantic pyyaml tqdm

import sys, platform, subprocess, os
# Report the interpreter and OS, then list GPUs when `nvidia-smi -L` succeeds.
print("Python:", sys.version)
print("Platform:", platform.platform())
try:
    out = subprocess.check_output(["nvidia-smi", "-L"], text=True)
except Exception as e:
    # Any failure of the probe (e.g. the command cannot be run) is reported, not raised.
    print("No GPU or nvidia-smi missing:", repr(e))
else:
    print("\n=== GPU ===\n" + out)




# RAG-SYSTEM | FOR SALBUTAMOL

In [None]:
# C2 — Imports, seed, directories
import os, re, json, math, time, uuid, random, glob, textwrap
from pathlib import Path


import numpy as np
import pandas as pd
from tqdm import tqdm

# Fixed seed applied to both Python's `random` and NumPy's global RNG.
SEED = 42
random.seed(SEED); np.random.seed(SEED)

# All paths hang off the current working directory of the kernel.
BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR / "data_salbutamol"          # keep your input data here
ART_DIR  = DATA_DIR / "artifacts_v1"             # auto-saved/reloaded artifacts
# Creates DATA_DIR as well when it is missing (parents=True); no error if present.
ART_DIR.mkdir(parents=True, exist_ok=True)

print("BASE_DIR:", BASE_DIR)
print("DATA_DIR:", DATA_DIR)
print("ART_DIR :", ART_DIR)



In [None]:
# C3 — Resolve the three input files; load chunks
# Expected filenames (edit if yours differ)
# Preferred locations of the three input files, all directly under DATA_DIR.
RULES_JSON_PATH    = DATA_DIR / "salbutamol_protocol.rules.json"
CHUNKS_CSV_PATH    = DATA_DIR / "salbutamol_chunks_for_embedding.csv"
SYNONYMS_JSON_PATH = DATA_DIR / "salbutamol_synonyms.json"

def ensure_path(p: Path, patterns):
    """Resolve an input file, falling back to a recursive search.

    Returns ``p`` unchanged when it exists. Otherwise the roots ``/content``,
    ``DATA_DIR`` and ``BASE_DIR`` are searched in that order; within each root
    every glob in ``patterns`` is tried in order and the first path yielded by
    ``rglob`` is returned.

    Raises:
        FileNotFoundError: when no root contains a match for any pattern.
    """
    if p.exists():
        return p
    for root in (Path("/content"), DATA_DIR, BASE_DIR):
        for pat in patterns:
            first_hit = next(root.rglob(pat), None)
            if first_hit is not None:
                return first_hit
    raise FileNotFoundError(f"Could not find file like: {patterns}")

# Re-bind each path to whatever ensure_path() resolves: the exact path when it
# exists, otherwise the first same-named file found under the search roots.
# A FileNotFoundError from ensure_path() stops the cell here.
RULES_JSON_PATH    = ensure_path(RULES_JSON_PATH,    ["salbutamol_protocol.rules.json"])
CHUNKS_CSV_PATH    = ensure_path(CHUNKS_CSV_PATH,    ["salbutamol_chunks_for_embedding.csv"])
SYNONYMS_JSON_PATH = ensure_path(SYNONYMS_JSON_PATH, ["salbutamol_synonyms.json"])

# Echo the resolved locations so a fallback hit is visible in the cell output.
print("Using:")
print(" - RULES_JSON_PATH    :", RULES_JSON_PATH)
print(" - CHUNKS_CSV_PATH    :", CHUNKS_CSV_PATH)
print(" - SYNONYMS_JSON_PATH :", SYNONYMS_JSON_PATH)

# Load the chunk table and expose its columns as parallel, index-aligned lists.
chunks_df = pd.read_csv(CHUNKS_CSV_PATH)
# Validate the schema with an explicit exception: `assert` is removed under
# `python -O`, and a bare AssertionError does not say which column is absent.
_required_cols = {"id", "section", "type", "text"}
_missing_cols = _required_cols.difference(chunks_df.columns)
if _missing_cols:
    raise ValueError(
        f"{CHUNKS_CSV_PATH} is missing required column(s): {sorted(_missing_cols)}"
    )
# astype(str) turns missing cells into the string "nan"; downstream code
# (see _clean_section) treats that value as an unknown section.
ids      = chunks_df["id"].astype(str).tolist()
sections = chunks_df["section"].astype(str).tolist()
texts    = chunks_df["text"].astype(str).tolist()
print(f"✅ Loaded {len(texts)} chunks")




In [None]:
# C3.5 — Change-detection (hash) for auto-rebuild
import hashlib, json

SIG_PATH = ART_DIR / "build_signature.json"

def _sha256(p: Path) -> str:
    """Return the hexadecimal SHA-256 digest of the file at ``p``.

    The file is read in 1 MiB blocks, so it is never held in memory whole.
    """
    digest = hashlib.sha256()
    with open(p, "rb") as fh:
        while True:
            block = fh.read(1 << 20)
            if block == b"":
                break
            digest.update(block)
    return digest.hexdigest()

def _current_signature():
    """Describe the three input files as a flat dict.

    For each input the dict holds its path (as a string) followed by the
    SHA-256 of its contents, in the order chunks, rules, synonyms.
    """
    signature = {}
    for path_key, hash_key, src in (
        ("chunks_csv", "chunks_hash", CHUNKS_CSV_PATH),
        ("rules_json", "rules_hash", RULES_JSON_PATH),
        ("synonyms_json", "synonyms_hash", SYNONYMS_JSON_PATH),
    ):
        signature[path_key] = str(src)
        signature[hash_key] = _sha256(src)
    return signature

def need_rebuild():
    """Return True if artifacts missing OR inputs changed since last build.

    A rebuild is requested when any of FAISS_PATH, VEC_NPY, META_PQ or
    SIG_PATH does not exist, when the stored signature cannot be read or is
    not a JSON object, or when any field of the current signature differs
    from the stored one.
    """
    artifacts = (FAISS_PATH, VEC_NPY, META_PQ, SIG_PATH)
    if not all(a.exists() for a in artifacts):
        return True
    try:
        # Context manager so the handle is closed even when parsing fails.
        with open(SIG_PATH, "r") as f:
            old = json.load(f)
    except (OSError, ValueError):
        # Unreadable or malformed signature (JSONDecodeError is a ValueError).
        return True
    if not isinstance(old, dict):
        # Valid JSON of the wrong shape cannot be compared field by field.
        return True
    new = _current_signature()
    return any(old.get(k) != new.get(k) for k in new)

def write_signature():
    """Write the current input signature to ``SIG_PATH`` as JSON."""
    # Hash the inputs first: if that fails, the existing signature file is
    # left untouched rather than truncated.
    sig = _current_signature()
    # The context manager guarantees the file is flushed and closed.
    with open(SIG_PATH, "w") as f:
        json.dump(sig, f)


In [None]:
# C4 — Build or load dense index (BGE-M3 + FAISS IP) — with auto-rebuild
import faiss
from FlagEmbedding import BGEM3FlagModel

# Artifact locations inside ART_DIR: the FAISS index, the raw embedding matrix
# and the chunk metadata the vectors were built from.
FAISS_PATH = ART_DIR / "faiss_ip.index"
VEC_NPY    = ART_DIR / "doc_vectors.npy"
META_PQ    = ART_DIR / "meta.parquet"

index = None
bge = None

# NOTE(review): while this stays True the `or` below short-circuits, so
# need_rebuild() is never consulted and the build branch runs on every execution.
FORCE_REBUILD = True   # <--- do this ONCE to rebuild artifacts


REBUILD = FORCE_REBUILD or need_rebuild()

if (not REBUILD) and FAISS_PATH.exists() and VEC_NPY.exists() and META_PQ.exists():
    # === Load existing artifacts ===
    index = faiss.read_index(str(FAISS_PATH))
    emb = np.load(VEC_NPY).astype("float32")
    meta = pd.read_parquet(META_PQ)

    # Defensive fill; keep in-memory lists synced to META_PQ
    # (this overwrites the ids/sections/texts lists loaded from the CSV earlier)
    meta["section"] = meta["section"].fillna("Unknown section").astype(str)
    ids      = meta["id"].astype(str).tolist()
    sections = meta["section"].astype(str).tolist()
    texts    = meta["text"].astype(str).tolist()

    # We still need a query encoder even when we load the FAISS index
    bge = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)

    print(f"✅ Loaded FAISS index with {emb.shape[0]} vectors (dim={emb.shape[1]})")

else:
    # === Build fresh artifacts from CURRENT CSV ===
    bge = BGEM3FlagModel("BAAI/bge-m3", use_fp16=True)  # auto GPU if available

    # Encode the chunk texts in slices of BATCH; each slice is L2-normalised in
    # place before being collected.
    BATCH = 64
    vecs = []
    for i in tqdm(range(0, len(texts), BATCH), desc="Encoding (BGE-M3)"):
        batch = texts[i:i+BATCH]
        out = bge.encode(
            batch, batch_size=len(batch), max_length=512,
            return_dense=True, return_sparse=False, return_colbert_vecs=False
        )
        v = out["dense_vecs"].astype("float32")
        faiss.normalize_L2(v)  # cosine via IP
        vecs.append(v)

    # Stack the slices into one matrix and add it to a flat inner-product index.
    emb = np.vstack(vecs).astype("float32")
    index = faiss.IndexFlatIP(emb.shape[1])
    index.add(emb)

    # Persist artifacts built from the CURRENT DataFrame (no stale metas)
    faiss.write_index(index, str(FAISS_PATH))
    np.save(VEC_NPY, emb)

    meta = chunks_df[["id","section","type","text"]].copy()
    meta["section"] = meta["section"].fillna("Unknown section").astype(str)
    meta.to_parquet(META_PQ, index=False)

    # Write/update signature so future runs know these artifacts match the inputs
    write_signature()

    print(f"✅ Built FAISS index with {emb.shape[0]} vectors (dim={emb.shape[1]})")


In [None]:

# C5 — BM25 (lexical) with synonym expansion
from rank_bm25 import BM25Okapi

def tokenize(s: str):
    """Lower-case ``s`` and split it into whitespace-separated tokens.

    Every run of characters other than ``a-z``, ``0-9``, ``%``, ``-`` and
    whitespace is replaced by a single space before splitting.
    """
    return re.sub(r"[^a-z0-9%\-\s]+", " ", s.lower()).split()

# Load the synonym table (head term -> list of alternatives) from JSON.
with open(SYNONYMS_JSON_PATH, "r", encoding="utf-8") as f:
    synonyms = json.load(f)

# Add helpful domain aliases
# dict.update() replaces any entry the JSON file already had under these keys.
synonyms.update({
    "spi": ["specialist poison information", "poison center", "poison control"],
    "assessment": ["initial assessment", "history taking", "initial evaluation"]
})

# Tokenise every chunk once and build the BM25 index over those token lists;
# positions in corpus_tokens line up with positions in `texts`.
corpus_tokens = [tokenize(t) for t in texts]
bm25 = BM25Okapi(corpus_tokens)

def expand_query_lexical(q: str):
    """Append synonym groups to a query for lexical (BM25) search.

    The query is lower-cased. For each entry of ``synonyms`` the head term
    and its alternatives form a group; when any member occurs as a substring
    of the query, the whole group is added. The result is the de-duplicated
    set of strings, sorted and joined with single spaces.
    """
    # NOTE(review): the membership test is a plain substring check, so a short
    # head such as "spi" also fires inside longer words like "respiratory".
    lowered = q.lower()
    terms = {lowered}
    for head, alts in synonyms.items():
        group = [head, *alts]
        for term in group:
            if term in lowered:
                terms.update(group)
                break
    return " ".join(sorted(terms))

print(f"✅ BM25 ready over {len(texts)} chunks; synonyms={len(synonyms)}")



In [None]:

# C6 — Hybrid search utilities
def dense_search(query: str, top_k=24):
    """Return up to ``top_k`` chunks nearest to ``query`` in the dense index.

    The query is encoded with the BGE-M3 encoder, L2-normalised and searched
    against ``index``. Each hit is a dict with ``idx``, ``id``, ``score``
    (the similarity returned by FAISS), ``type`` ("dense"), ``text`` and
    ``section``, ordered as returned by the index.
    """
    if top_k <= 0:
        return []
    out = bge.encode([query], max_length=512,
                     return_dense=True, return_sparse=False, return_colbert_vecs=False)
    q = out["dense_vecs"].astype("float32")
    faiss.normalize_L2(q)
    sims, idxs = index.search(q, top_k)
    hits = []
    for sim, i in zip(sims[0].tolist(), idxs[0].tolist()):
        # FAISS fills unused result slots with label -1 when fewer than top_k
        # vectors are available; indexing the lists with -1 would silently
        # return the LAST chunk, so those slots are dropped.
        if i < 0:
            continue
        hits.append({"idx": i, "id": ids[i], "score": float(sim), "type": "dense",
                     "text": texts[i], "section": sections[i]})
    return hits

def sparse_search(query: str, top_k=24):
    """Return the ``top_k`` highest-scoring chunks under BM25.

    The query is synonym-expanded and tokenised first. Hits have the same
    keys as dense hits, with ``type`` set to "sparse", ordered by descending
    BM25 score.
    """
    query_tokens = tokenize(expand_query_lexical(query))
    scores = bm25.get_scores(query_tokens)
    best_first = np.argsort(scores)[::-1][:top_k]
    hits = []
    for raw_pos in best_first:
        pos = int(raw_pos)
        hits.append({"idx": pos, "id": ids[pos], "score": float(scores[pos]), "type": "sparse",
                     "text": texts[pos], "section": sections[pos]})
    return hits

def rrf_fuse(cand_lists, k=60):
    """Merge ranked candidate lists with reciprocal-rank fusion.

    Every candidate contributes ``1 / (k + rank)`` (rank starting at 1) for
    each list it appears in; contributions are summed per ``id``. The result
    is sorted by descending fused score, and each entry carries ``id``,
    ``score``, ``idx``, ``text`` and ``section`` looked up from the module
    lists ``ids``, ``texts`` and ``sections``.
    """
    totals = {}
    for ranking in cand_lists:
        for position, cand in enumerate(ranking, start=1):
            cid = cand["id"]
            totals[cid] = totals.get(cid, 0.0) + 1.0 / (k + position)

    row_of = {row_id: j for j, row_id in enumerate(ids)}
    fused = []
    # sorted() is stable, so ties keep first-seen order.
    for cid, total in sorted(totals.items(), key=lambda kv: kv[1], reverse=True):
        j = row_of[cid]
        fused.append({"id": cid, "score": total, "idx": j,
                      "text": texts[j], "section": sections[j]})
    return fused

def hybrid_search(query: str, top_k_dense=24, top_k_sparse=24, fuse_k=50):
    """Run dense and sparse retrieval and fuse the two rankings.

    ``fuse_k`` is passed to ``rrf_fuse`` as its ``k`` constant.
    """
    rankings = [
        dense_search(query, top_k_dense),
        sparse_search(query, top_k_sparse),
    ]
    return rrf_fuse(rankings, k=fuse_k)

# smoke test: run one query end to end and print the id and (truncated)
# section of the five best fused hits.
for r in hybrid_search("activated charcoal dose")[:5]:
    print(r["id"], "|", (r["section"] or "Unknown")[:60])
print("✅ Hybrid retrieval OK")



In [None]:
# C7 — Cross-encoder re-rank
import torch
from sentence_transformers import CrossEncoder

# Defaults for rerank(): how many fused candidates are scored, and how many
# of the best-scoring ones are returned.
RERANK_TOP_IN  = 50
RERANK_TOP_OUT = 8
# Use the GPU when torch reports one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
reranker = CrossEncoder("BAAI/bge-reranker-v2-m3", device=device)

def rerank(query: str, fused, top_in=RERANK_TOP_IN, top_out=RERANK_TOP_OUT, batch_size=16):
    """Re-score the leading fused candidates with the cross-encoder.

    The first ``top_in`` candidates are paired with ``query`` and scored;
    each candidate dict gains a ``rerank_score`` key (the dicts are the same
    objects as in ``fused``). Returns the ``top_out`` best by that score, or
    an empty list when there is nothing to score.
    """
    shortlist = fused[:top_in]
    if not shortlist:
        return []
    scores = reranker.predict(
        [(query, cand["text"]) for cand in shortlist],
        batch_size=batch_size,
    )
    for cand, raw_score in zip(shortlist, scores):
        cand["rerank_score"] = float(raw_score)
    ranked = sorted(shortlist, key=lambda cand: cand["rerank_score"], reverse=True)
    return ranked[:top_out]

print("✅ Reranker ready on", device)


In [None]:
# C8 — Quote selection + stitching + numeric guard

def _clean_section(sec):
    """Normalise a section label for display.

    ``None`` becomes the empty string; a value whose string form is "nan" or
    "none" (any case) becomes "Unknown section"; anything else is returned as
    ``str(sec)``.
    """
    if sec is None:
        return ""
    label = str(sec)
    if label.lower() in ("nan", "none"):
        return "Unknown section"
    return label

# Per-intent anchor phrases. filter_by_anchors() keeps a candidate when its
# lower-cased text contains at least one phrase listed for the inferred intent.
# The keys match the labels returned by infer_intent(); "generic" has no entry.
ANCHORS = {
    "charcoal": ["activated charcoal","single dose","dose","within 1 hour","within one hour",
                 "airway","vomiting","gastrointestinal perforation","contraindication",
                 "pediatrics","adult","1 g/kg","g/kg","50 g","dose:"],
    "threshold": ["1 mg/kg","home criteria","observation criteria","asymptomatic","symptomatic"],
    "investigations": ["investigation","ecg","electrocardiogram","serum chemistry","potassium",
                       "blood glucose","cardiac monitor","abg","vbg","cxr","x-ray"],
}

def infer_intent(q: str) -> str:
    """Classify a question by keyword into one of four intent labels.

    Rules are checked in order on the lower-cased question and the first
    match wins: "charcoal", then "investigations", then "threshold";
    otherwise "generic".
    """
    lowered = q.lower()
    rules = (
        ("charcoal", ("charcoal",)),
        ("investigations", ("investigation", "tests")),
        ("threshold", ("1 mg/kg", "threshold")),
    )
    for intent, cues in rules:
        if any(cue in lowered for cue in cues):
            return intent
    return "generic"

def _contains_any(h: str, needles):
    """Return True when any needle is a substring of the lower-cased ``h``.

    Only ``h`` is lower-cased; needles are compared exactly as given.
    """
    haystack = h.lower()
    for needle in needles:
        if needle in haystack:
            return True
    return False

def filter_by_anchors(cands, intent: str, min_keep=4):
    """Keep only candidates mentioning an anchor phrase for ``intent``.

    The input list is returned unchanged when the intent has no anchors, or
    when fewer than ``min_keep`` candidates survive the filter.
    """
    anchors = ANCHORS.get(intent, [])
    if not anchors:
        return cands
    kept = [c for c in cands if _contains_any(c.get("text", ""), anchors)]
    if len(kept) < min_keep:
        return cands
    return kept

def select_quotes(cands, max_quotes=3, max_chars=500):
    """Turn the first ``max_quotes`` candidates into quote records.

    Whitespace runs in each text are collapsed to single spaces and the ends
    stripped; a text longer than ``max_chars`` is cut there and suffixed with
    "…". Each record holds ``id``, a cleaned ``section`` and the ``quote``.
    """
    def _as_quote(cand):
        body = re.sub(r"\s+", " ", cand["text"]).strip()
        if len(body) > max_chars:
            body = body[:max_chars] + "…"
        return {"id": cand["id"], "section": _clean_section(cand.get("section", "")), "quote": body}

    return [_as_quote(cand) for cand in cands[:max_quotes]]

def _has_num(s):
    """Return True when ``s`` contains at least one digit."""
    return re.search(r"\d", s) is not None

def stitch_neighbors(quotes, max_chars=500):
    """Extend number-less "dose" quotes with the text of the following chunk.

    A quote that mentions "dose" but contains no digit is joined (with one
    space) to the whitespace-normalised text of the next chunk in ``texts``;
    the result is cut to ``max_chars`` characters and suffixed with "…" when
    anything was cut. If the quote's section is unknown, the next chunk's
    section is used instead. All other quotes pass through unchanged, and the
    input order is preserved.
    """
    id2idx = {cid: i for i, cid in enumerate(ids)}
    out = []
    for q in quotes:
        txt = q["quote"]
        if ("dose" in txt.lower()) and not _has_num(txt):
            idx = id2idx.get(q["id"])
            if idx is not None and idx + 1 < len(texts):
                nxt = re.sub(r"\s+", " ", texts[idx + 1]).strip()
                combined = txt + " " + nxt
                # Measure the string that is actually sliced. The previous
                # check used len(txt + nxt), which ignores the joining space
                # and so omitted the ellipsis when exactly one character was cut.
                if len(combined) > max_chars:
                    stitched = combined[:max_chars] + "…"
                else:
                    stitched = combined
                sec = q["section"]
                if sec.lower() in ("unknown section", "none", "nan"):
                    sec = str(sections[idx + 1])
                out.append({"id": q["id"], "section": sec, "quote": stitched})
                continue
        out.append(q)
    return out

def numeric_guard(answer: str, quotes: list):
    """List quantity mentions in ``answer`` that no quote supports.

    A quantity is a number (optionally with a decimal part) followed by one
    of the units mg/kg, mcg/kg, g/kg, mg, mcg, g, hours or hour. Matching is
    case-insensitive and ignores whitespace between number and unit, so
    "1mg/kg" in the answer is supported by "1 mg/kg" in a quote.

    Args:
        answer: generated answer text.
        quotes: quote records; only each record's ``"quote"`` text is read.

    Returns:
        Sorted list of the unsupported mentions, lower-cased as they appear
        in the answer (one entry per distinct quantity).
    """
    # One pattern suffices: "mg/kg" is listed first in the alternation, so a
    # separate mg/kg-only pattern can never find anything extra.
    quantity = re.compile(
        r"\b(\d+(?:\.\d+)?)\s*(mg/kg|mcg/kg|g/kg|mg|mcg|g|hours|hour)\b", re.I
    )

    def ext(t):
        # Map a whitespace-free key ("1mg/kg") to the first spelling seen.
        found = {}
        for m in quantity.finditer(t):
            key = (m.group(1) + m.group(2)).lower()
            found.setdefault(key, m.group(0).lower().strip())
        return found

    in_q = set()
    for q in quotes:
        in_q.update(ext(q["quote"]))
    return sorted(shown for key, shown in ext(answer).items() if key not in in_q)


In [None]:
# C9 — LLM answerer (Gemma-2-9B-IT; quotes-only)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

import os

USE_HF_LOGIN = True           # set True and export HF_TOKEN (or use Google Colab "secrets")
# Read the token from the environment instead of embedding it in the notebook,
# so it is never committed or shared with the file. Empty means "already
# authenticated": the login call below is then skipped.
HF_TOKEN = os.environ.get("HF_TOKEN", "").strip()

if USE_HF_LOGIN and HF_TOKEN:
    from huggingface_hub import login
    login(token=HF_TOKEN)

# Model to load, and the 4-bit (NF4, double-quantised) bitsandbytes settings
# with bfloat16 as the compute dtype.
LLM_ID = "google/gemma-2-9b-it"
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(LLM_ID, use_fast=True, trust_remote_code=True)
# generate() below is always given a pad_token_id; reuse EOS when the
# tokenizer does not define a pad token of its own.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

llm = AutoModelForCausalLM.from_pretrained(
    LLM_ID, device_map="auto", quantization_config=bnb_cfg,
    torch_dtype=torch.bfloat16, trust_remote_code=True
)

# warmup: one tiny greedy generation; the output is discarded.
_w = tokenizer("Reply with OK", return_tensors="pt").to(llm.device)
_ = llm.generate(**_w, max_new_tokens=2, do_sample=False,
                 pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)
print("✅ Gemma attached:", LLM_ID)

def build_messages_from_quotes(query: str, quotes: list):
    """Build two alternative chat-message lists for a quotes-only answer.

    Variant 1 is a system message (the rules) followed by a user message
    (question plus numbered quotes). Variant 2 folds the rules into a single
    user message. Callers try them in order.

    Args:
        query: the question to answer.
        quotes: records with ``id``, ``section`` and ``quote`` keys.

    Returns:
        ``[variant_1, variant_2]``, each a list of ``{"role", "content"}`` dicts.
    """
    sys_rules = "\n".join([
        "You are a clinical protocol assistant.",
        "RULES:",
        "1) Use ONLY the QUOTES provided. Do NOT add external knowledge.",
        "2) If a detail is not present in the quotes, write: 'Not specified in this protocol.'",
        "3) Be concise (2–6 lines), clinical, and factual.",
        "4) End with a line: 'Citations: [SALB-XXX, ...]'.",
    ])

    quotes_block = "\n".join(
        f"[Q{n} id={q['id']} section={q['section']}] {q['quote']}"
        for n, q in enumerate(quotes, start=1)
    )
    user_msg = (
        f"QUESTION:\n{query}\n\n"
        f"QUOTES (verbatim; authoritative):\n{quotes_block}\n\n"
        "Write the answer ONLY from these quotes."
    )

    with_system_role = [
        {"role": "system", "content": sys_rules},
        {"role": "user", "content": user_msg},
    ]
    user_only = [
        {"role": "user", "content": "SYSTEM RULES (apply strictly):\n" + sys_rules + "\n\n" + user_msg},
    ]
    return [with_system_role, user_only]

def llm_answer_from_quotes(query: str, quotes: list, max_new_tokens=180):
    """Generate an answer restricted to ``quotes`` with the loaded LLM.

    Each message variant from ``build_messages_from_quotes`` is tried through
    the tokenizer's chat template until one succeeds; if none does, a plain
    text prompt is used. Decoding is greedy. Only newly generated tokens are
    decoded. When the answer contains quantities absent from the quotes, a
    "[Note]" line listing them is appended.

    Args:
        query: the question.
        quotes: records with ``id``, ``section`` and ``quote`` keys.
        max_new_tokens: generation budget.

    Returns:
        The answer text, possibly followed by the "[Note]" line.
    """
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens, do_sample=False,
        pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
    )
    answer = None
    for messages in build_messages_from_quotes(query, quotes):
        try:
            model_input = tokenizer.apply_chat_template(
                messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
            ).to(llm.device)
            with torch.no_grad():
                gen = llm.generate(input_ids=model_input, **gen_kwargs)
            # Drop the prompt tokens; keep only the continuation.
            gen_ids = gen[0][model_input.shape[-1]:]
            answer = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
            break
        except Exception:
            # Deliberate best-effort: a variant this chat template rejects
            # is skipped so the next one can be tried.
            continue

    if answer is None:
        sys_rules = "Use ONLY QUOTES. If missing: 'Not specified in this protocol.' Keep 2–6 lines. End with 'Citations: [...]'."
        qlines = [f"[Q{i+1} id={q['id']} section={q['section']}] {q['quote']}" for i,q in enumerate(quotes)]
        prompt = sys_rules + "\n\nQUESTION:\n" + query + "\n\nQUOTES:\n" + "\n".join(qlines) + "\n\nAnswer:\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)
        with torch.no_grad():
            gen = llm.generate(**inputs, **gen_kwargs)
        # Slice off the prompt here as well. Decoding gen[0] whole echoed the
        # prompt back as the "answer", including the rules sentence
        # 'Not specified in this protocol.', which made every fallback answer
        # look like an abstention.
        prompt_len = inputs["input_ids"].shape[-1]
        answer = tokenizer.decode(gen[0][prompt_len:], skip_special_tokens=True).strip()

    # guard: if LLM invents numbers not in quotes
    extras = numeric_guard(answer, quotes)
    if extras:
        answer += "\n[Note] Values not present in quotes detected: " + ", ".join(extras)
    return answer


In [None]:

def ask_json(query: str, use_llm=True, max_new_tokens=160):
    """Answer ``query`` from the protocol and return a JSON-style dict.

    Pipeline: hybrid retrieval, cross-encoder re-rank, anchor filtering,
    quote selection (up to three) and neighbour stitching. With ``use_llm``
    the answer is generated from the quotes; otherwise the quotes are listed
    verbatim as bullets.

    Returns:
        Dict with ``query``, ``answer``, ``citations``, ``quotes``,
        ``confidence`` (a fixed constant per outcome) and ``disclaimer``.
    """
    disclaimer = "For internal protocol QA only; not medical advice."

    reranked = rerank(query, hybrid_search(query), top_in=100, top_out=12)
    anchored = filter_by_anchors(reranked, infer_intent(query), min_keep=4)
    picked = select_quotes(anchored, max_quotes=3, max_chars=500)
    quotes = stitch_neighbors(picked, max_chars=500)

    if not quotes:
        return {
            "query": query,
            "answer": "Not found in this protocol.",
            "citations": [],
            "quotes": [],
            "confidence": 0.2,
            "disclaimer": disclaimer,
        }

    if not use_llm:
        bullet_text = "\n".join(
            f"- {q['quote']} [ref: {q['id']} | {q['section']}]" for q in quotes
        )
        ans = (
            "Based on the protocol passages below (verbatim):\n" + bullet_text +
            "\n\nAnswer is restricted to the cited text above. If a detail isn’t present there, it’s not specified in this protocol."
        )
    else:
        ans = llm_answer_from_quotes(query, quotes, max_new_tokens=max_new_tokens)

    citations = [{"id": q["id"], "section": q["section"]} for q in quotes]
    return {
        "query": query,
        "answer": ans,
        "citations": citations,
        "quotes": quotes,
        "confidence": 0.87,
        "disclaimer": disclaimer,
    }

print("✅ ask_json ready")


In [None]:
# C10 — Probe set (15 questions drafted: 13 in-doc, 2 not-in-doc; most in-doc ones are currently commented out); compute metrics and save
from datetime import datetime
import csv, json

# Evaluation probe: (question, expected_in_document) pairs. True means the
# protocol should answer the question; False means the system should abstain.
# Entries that are commented out are not run.
PROBE = [
    # IN-DOC
    ("list all  the initial assessment ?", True),
    # ("What symptoms classify as mild to moderate in salbutamol toxicity?", True),
    # ("What symptoms classify as severe in salbutamol toxicity?", True),
     ("When should investigations be ordered, and which tests are listed?", True),
    # ("What are the home criteria after oral salbutamol ingestion?", True),
    # ("What are the observation criteria in a health care facility?", True),
    # ("What are the admission or ICU criteria for salbutamol toxicity?", True),
    # ("When is single-dose activated charcoal indicated? List all required conditions.", True),
     ("What is the pediatric dose for single-dose activated charcoal?", True),
    # ("What is the adult dose for single-dose activated charcoal?", True),
    # ("For asymptomatic patients with unintentional ingestion < 1 mg/kg, how long should observation last and what should be monitored?", True),
    # ("How should nausea and vomiting be managed in salbutamol toxicity?", True),
    # ("What is the recommended management for hypokalemia, severe agitation or seizures, hypotension, and tachycardia in this protocol?", True),
    # # NOT-IN-DOC (expect strict abstention)
    ("Does the protocol recommend multiple-dose activated charcoal for salbutamol overdose?", False),
    ("Does the protocol specify using hemodialysis or lipid emulsion for severe salbutamol toxicity?", False),
]

def eval_probe(probe):
    """Run every probe question through ``ask_json`` and report metrics.

    For each ``(question, in_doc_expected)`` pair the answer is checked for
    numeric faithfulness (no quantities beyond the quotes) and abstention.
    An in-document question passes when the system does not abstain; an
    out-of-document question passes when it does. Per-question rows are
    written to a JSONL and a CSV file in ``ART_DIR`` and a summary is printed.

    Args:
        probe: iterable-with-length of ``(question, in_doc_expected)`` pairs.

    Returns:
        None. An empty probe prints a notice and writes nothing.
    """
    from datetime import timezone  # the cell itself only imports `datetime`

    # Both abstention wordings the pipeline can emit: the LLM rule text and
    # ask_json's own no-quotes answer.
    abstain_phrases = ("not specified in this protocol", "not found in this protocol")

    rows = []
    in_doc_total = sum(1 for _, expect in probe if expect)
    ood_total    = len(probe) - in_doc_total
    in_doc_pass  = 0
    ood_pass     = 0
    numeric_ok_n = 0
    latencies    = []

    for i, (q, in_doc_expected) in enumerate(probe, 1):
        t0 = time.time()
        res = ask_json(q, use_llm=True, max_new_tokens=180)
        lat = time.time() - t0

        # numeric faithfulness check
        extras = numeric_guard(res["answer"], res["quotes"])
        numeric_ok = 1 if len(extras) == 0 else 0

        # abstention rule: one of the known phrases must appear when NOT in doc
        answer_low = res["answer"].lower()
        abstained = any(phrase in answer_low for phrase in abstain_phrases)
        passed = (abstained if not in_doc_expected else not abstained)

        rows.append({
            "idx": i, "question": q, "in_doc_expected": in_doc_expected,
            "abstained": abstained, "pass": passed, "answer": res["answer"],
            "citations": [c["id"] for c in res["citations"]],
            "latency_s": round(lat, 3), "numeric_ok": numeric_ok
        })

        if in_doc_expected and passed: in_doc_pass += 1
        if (not in_doc_expected) and passed: ood_pass += 1
        numeric_ok_n += numeric_ok
        latencies.append(lat)

        print("\n" + "="*100)
        print(f"#{i}  (IN-DOC={in_doc_expected})")
        print("Q:", q)
        print("\nA:\n", res["answer"])
        print("Citations:", res["citations"])

    if not rows:
        # Nothing to average or write; avoids rows[0] / division by zero below.
        print("No probe questions supplied; nothing to evaluate.")
        return

    # summary (timezone-aware replacement for the deprecated utcnow(); same format)
    ts = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    OUT_JSONL = ART_DIR / f"probe15_{ts}.jsonl"
    OUT_CSV   = ART_DIR / f"probe15_{ts}.csv"
    with open(OUT_JSONL, "w", encoding="utf-8") as f:
        for row in rows: f.write(json.dumps(row, ensure_ascii=False) + "\n")
    with open(OUT_CSV, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader(); w.writerows(rows)

    # Report the counts actually evaluated instead of fixed "15/13/2" labels,
    # and guard each ratio against an empty category.
    print(f"\n=== Probe Summary ({len(rows)} Qs) ===")
    print(f"In-doc grounding pass rate ({in_doc_total} Qs): {in_doc_pass/max(1,in_doc_total):.2f}")
    print(f"Not-in-doc abstention rate ({ood_total} Qs):  {ood_pass/max(1,ood_total):.2f}")
    print(f"Numeric faithfulness:               {numeric_ok_n/len(rows):.2f}")
    print(f"Avg latency (s):                    {sum(latencies)/len(latencies):.2f}")
    print("\nSaved to:")
    print("-", OUT_JSONL)
    print("-", OUT_CSV)

eval_probe(PROBE)


# The Best React-Agent %

In [None]:
import os, re, json, time, math
from typing import List, Optional
from urllib.parse import urlparse

# ✅ Provide your OpenAI API key through the environment (or Colab "Secrets").
# A key that is already set is left untouched; otherwise it is requested
# interactively so that it never has to be written into the notebook.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ").strip()

# --- Allowlist (feel free to add more)
# Registered domains the tools may search and fetch. _is_allowed() accepts a
# host equal to an entry or any sub-domain of one, so "ncbi.nlm.nih.gov" is
# already covered by "nih.gov" and is listed only for explicitness.
AUTH_DOMAINS = {
    "wikem.org",
    "litfl.com",
    "nhs.uk",
    "medlineplus.gov",
    "nih.gov",
    "ncbi.nlm.nih.gov",
    "who.int",
}

# Keep-lines filter to bias useful clinical facts
# Case-insensitive keyword filter: extract_facts() keeps a line only when this
# pattern is found somewhere in it.
# NOTE(review): there are no word boundaries, so short alternatives such as
# "GI", "mg" or "lact" also match inside longer words (e.g. "gi" in "give").
KEEP = re.compile(
    r"(dose|dosing|mg|mcg|mL|contraindicat|indication|side effect|adverse|warning|pregnan|"
    r"lact|overdos|toxic|poison|antidote|monitor|renal|hepatic|elderly|pediatric|children|"
    r"interaction|caution|bleed|black box|anaphylaxis|hypersensitivity|GI|kidney|liver|dialysis|alkalin)",
    re.I,
)

def _is_allowed(u: str) -> bool:
    """Return True when ``u`` is an http(s) URL on an allow-listed domain.

    The host must equal an entry of ``AUTH_DOMAINS`` or be a sub-domain of
    one. The parsed *hostname* is compared, so an explicit port or userinfo
    part (``https://nih.gov:443/...``) no longer causes a false rejection,
    and a trailing root dot is ignored. Malformed input yields False rather
    than an exception.
    """
    try:
        parts = urlparse(u)
        host = parts.hostname
    except (ValueError, TypeError, AttributeError):
        return False
    # Only web URLs can be fetched by the tools; anything else is refused.
    if parts.scheme not in ("http", "https") or not isinstance(host, str):
        return False
    host = host.rstrip(".")
    return any(host == d or host.endswith("." + d) for d in AUTH_DOMAINS)

print("✅ Config ready. Allowed domains:", ", ".join(sorted(AUTH_DOMAINS)))


In [None]:
import requests
from bs4 import BeautifulSoup
from ddgs import DDGS
from langchain.tools import tool

# Shared limits for the web tools below.
HTTP_TIMEOUT = (6, 15)     # (connect, read) — default `timeout` of http_get()
MAX_SEARCH_TIME = 8        # seconds per search call (checked in ddg_allowlisted)
FACTS_LIMIT = 40           # lines per page — default `limit` of extract_facts()

def http_get(url: str, timeout=HTTP_TIMEOUT) -> Optional[str]:
    """Fetch ``url`` and return its body when the response is HTML.

    Returns None for an error status, a content type that does not mention
    "html", or any exception raised while fetching.
    """
    try:
        resp = requests.get(url, headers={"User-Agent": "ReActMed/1.1"}, timeout=timeout)
        resp.raise_for_status()
        content_type = resp.headers.get("content-type", "").lower()
        return resp.text if "html" in content_type else None
    except Exception:
        return None

def extract_facts(html: str, limit: int = FACTS_LIMIT) -> List[str]:
    """Pull short, keyword-bearing lines out of an HTML page.

    Scripts, styles, forms, page chrome (header/nav/footer/aside) and
    cookie/share/subscribe widgets are removed. From the ``<main>`` region
    (or the whole page when there is none) every ``<p>`` and ``<li>`` whose
    text is 20–240 characters long and matches ``KEEP`` is kept, with
    whitespace collapsed. Lines containing a digit come first; duplicates
    (case-insensitive) are dropped; collection stops once ``limit`` lines
    have been gathered.
    """
    soup = BeautifulSoup(html, "html5lib")

    # strip noise: non-content tags first, then chrome and widgets by selector
    for noisy in soup(["script","style","noscript","svg","img","form","button"]):
        noisy.decompose()
    chrome_selectors = ["header","nav","footer","aside",".cookie",".cookies","#cookie","#cookies",".share",".subscribe",".newsletter"]
    for selector in chrome_selectors:
        for node in soup.select(selector):
            node.decompose()
    main = soup.select_one("main, [role='main']") or soup

    # Partition qualifying lines by whether they carry a number.
    numeric, plain = [], []
    for node in main.find_all(["p","li"]):
        raw = node.get_text(" ", strip=True)
        if not (20 <= len(raw) <= 240 and KEEP.search(raw)):
            continue
        line = re.sub(r"\s+"," ", raw)
        (numeric if re.search(r"\d", line) else plain).append(line)

    # Numeric lines first, then the rest; de-duplicate ignoring case.
    facts, seen_keys = [], set()
    for line in numeric + plain:
        key = line.lower()
        if key in seen_keys:
            continue
        seen_keys.add(key)
        facts.append(line)
        if len(facts) >= limit:
            break
    return facts

def ddg_allowlisted(query: str, max_n=8) -> List[str]:
    """Search each allow-listed domain and collect up to ``max_n`` URLs.

    One ``site:``-restricted query is issued per domain, in sorted domain
    order so results are reproducible between runs. Searching stops when
    ``max_n`` URLs have been collected or ``MAX_SEARCH_TIME`` seconds have
    elapsed; the budget is checked before every query, so no further search
    is started once it is spent. A failing query for one domain is skipped
    without discarding what other domains returned.

    Returns:
        De-duplicated allow-listed URLs in discovery order (possibly empty).
    """
    urls, seen = [], set()
    deadline = time.monotonic() + MAX_SEARCH_TIME
    # spread a few results per domain
    per_dom = max(1, math.ceil(max_n / max(1, len(AUTH_DOMAINS))))
    try:
        with DDGS() as ddg:
            for d in sorted(AUTH_DOMAINS):
                if len(urls) >= max_n or time.monotonic() > deadline:
                    break
                try:
                    results = ddg.text(f"{query} site:{d}", max_results=per_dom)
                except Exception:
                    # Best-effort: one domain's failure must not lose the rest.
                    continue
                for r in results or []:
                    u = (r.get("href") or r.get("url") or "").strip()
                    if u and _is_allowed(u) and u not in seen:
                        seen.add(u); urls.append(u)
                        if len(urls) >= max_n:
                            break
    except Exception:
        # Best-effort by design: if the search client itself cannot be used,
        # return whatever was gathered instead of raising into the agent.
        pass
    return urls[:max_n]

# Agent tool: thin wrapper around ddg_allowlisted().
# NOTE(review): the docstring is left untouched on purpose — the prompt built
# later lists `t.description` for each tool, which presumably comes from this
# docstring, so editing it would change the prompt. TODO confirm.
@tool
def search_tox(query: str, max_n: int = 8) -> List[str]:
    """
    Search trusted medical domains for `query` and return up to ~8 URLs.
    Always uses an allowlist (wikem, LITFL, NHS, MedlinePlus, NIH/NCBI, WHO).
    """
    return ddg_allowlisted(query, max_n=max_n)

# Agent tool: fetch one allow-listed page and reduce it to fact lines.
# Returns an empty list when the URL is not allow-listed or when http_get()
# yields nothing (error, non-HTML response).
# NOTE(review): the docstring is left untouched on purpose — it presumably
# serves as the tool description shown to the model. TODO confirm.
@tool
def read_url(url: str) -> List[str]:
    """Fetch page (allowlisted only) and return ~24–40 concise fact lines."""
    if not _is_allowed(url):
        return []
    html = http_get(url)
    if not html:
        return []
    return extract_facts(html, limit=FACTS_LIMIT)

print("✅ Tools ready (search_tox, read_url).")


In [None]:
from langchain_openai import ChatOpenAI
import os

# GPT-4o for the agent
llm_for_agent = ChatOpenAI(
    model="gpt-4o",
    temperature=0.0,
    max_tokens=700,    # space for ReAct trace + answer
    # No key argument: the client reads OPENAI_API_KEY from the environment.
    # The previous upper-case `OPENAI_API_KEY=` keyword is not a ChatOpenAI
    # constructor field, and it embedded the secret in the notebook.
)

print("✅ GPT-4o ready for the agent.")

In [None]:
try:
    from langchain_core.prompts import ChatPromptTemplate
except Exception:
    from langchain.prompts import ChatPromptTemplate

# Tools exposed to the agent and their comma-separated names for the prompt.
tools = [search_tox, read_url]
tool_names = ", ".join([t.name for t in tools])

# Behavioural rules placed at the top of the prompt.
system_instr = (
    "You are a clinical toxicology assistant.\n"
    "Use tools to SEARCH trusted domains and READ at least one allowed URL. Prefer explicit numbers.\n"
    "It is OK to show Thought/Action/Observation while reasoning. After that, print a Final Answer list.\n"
    "Never include 'Human:' or 'Assistant:' role tags. Never fabricate facts. Keep outputs concise.\n"
    "IMPORTANT: After each search_tox, you must call read_url on one of the returned URLs."
)

# Single human-message template: rules, tool list, the step format, then the
# question and scratchpad. {tools} and {tool_names} are pre-filled via
# .partial(); {input} and {agent_scratchpad} remain to be supplied at run time.
react_prompt = ChatPromptTemplate.from_messages([
    ("human",
     system_instr
     + "\n\nTOOLS:\n{tools}\nValid tool names: {tool_names}\n\n"
       "Follow EXACT format for each step:\n"
       "Thought: <short>\n"
       "Action: <search_tox|read_url>\n"
       "Action Input: \"<string>\"\n"
       "Observation: <tool result>\n"
       "(repeat Thought/Action/... as needed; up to 6 steps)\n"
       "When you have enough (and only after at least one read_url), output:\n"
       "Final Answer:\n"
       "- <4–8 concise clinical bullets answering the question>\n"
       "Do NOT add a 'Citations:' block; the system will append it.\n\n"
       "Question: {input}\n\n"
       "{agent_scratchpad}"
    )
]).partial(
    tool_names=tool_names,
    tools="\n".join(f"- {t.name}: {t.description}" for t in tools),
)

print("✅ ReAct prompt ready (trace allowed).")


In [None]:
import re, json
from typing import Union
from langchain.agents import AgentExecutor, create_react_agent
from langchain.agents.agent import AgentOutputParser
from langchain.schema import AgentAction, AgentFinish

try:
    from langchain_core.exceptions import OutputParserException
except Exception:
    from langchain.schema import OutputParserException


class TolerantReActParser(AgentOutputParser):
    """Lenient parser for ReAct-style model output.

    Code fences and leading bullet markers are stripped before matching. Text
    after "Final Answer:" ends the run; otherwise an "Action:" /
    "Action Input:" pair becomes a tool call. The action input may be a JSON
    string, a JSON list (first string element is used), a double-quoted
    string, or raw text.
    """

    def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
        """Convert one model turn into an ``AgentAction`` or ``AgentFinish``.

        Raises:
            OutputParserException: when neither a final answer nor a complete
                Action/Action Input pair is found. It subclasses ValueError,
                so existing ``except ValueError`` handlers keep working, while
                the executor's parsing-error handling can now intercept it.
        """
        # strip code fences & bullet prefixes that sometimes confuse parsing
        t = re.sub(r"```.*?```", "", text.strip(), flags=re.S)
        t = re.sub(r"^\s*[-•]\s*", "", t, flags=re.M)

        # Final?
        m_final = re.search(r"Final Answer:\s*(.+?)\s*$", t, flags=re.S)
        if m_final:
            return AgentFinish(return_values={"output": m_final.group(1).strip()}, log=text)

        # Tool step
        m_act = re.search(r"Action:\s*([A-Za-z_][A-Za-z0-9_]*)", t)
        m_inp = re.search(r"Action Input:\s*(.*)", t)
        if not (m_act and m_inp):
            msg = "FORMAT ERROR. Use: Thought/Action/Action Input/Observation … then Final Answer."
            raise OutputParserException(msg, observation=msg, llm_output=text, send_to_llm=True)
        tool = m_act.group(1).strip()
        raw_input = m_inp.group(1).strip()

        # Accept JSON, quoted, or raw
        arg = None
        try:
            val = json.loads(raw_input)
            if isinstance(val, list) and val and isinstance(val[0], str):
                arg = val[0]
            elif isinstance(val, str):
                arg = val
        except Exception:
            # Not JSON: fall through to the quoted/raw handling below.
            pass
        if arg is None:
            m_q = re.match(r'^"(.*)"$', raw_input)
            arg = m_q.group(1) if m_q else raw_input
        return AgentAction(tool=tool, tool_input=arg.strip(), log=text)

    @property
    def _type(self) -> str:
        return "tolerant-react-parser"

# Build the ReAct agent from the GPT-4o model, the two tools, the custom
# prompt and the tolerant parser defined above.
agent = create_react_agent(
    llm=llm_for_agent,
    tools=[search_tox, read_url],
    prompt=react_prompt,
    output_parser=TolerantReActParser(),
)

# Executor that runs the agent loop for at most six iterations and returns
# the (action, observation) pairs alongside the final output.
# NOTE(review): no handle_parsing_errors option is passed here, so a parser
# failure presumably aborts the run instead of being retried — confirm.
react_executor = AgentExecutor.from_agent_and_tools(
    agent=agent,
    tools=[search_tox, read_url],
    verbose=False,                  # we’ll render our own clean trace
    return_intermediate_steps=True, # needed for the trace
    max_iterations=6,
    early_stopping_method="force",
)

print("✅ Agent executor ready (keeps intermediate steps).")


In [None]:
try:
    from langchain_core.callbacks import BaseCallbackHandler
except Exception:
    from langchain.callbacks.base import BaseCallbackHandler

def _extract_url_from_tool_input(raw: str) -> Optional[str]:
    """Pull the first http(s) URL out of a tool-input value.

    Accepts a bare URL, a quoted URL, a JSON string, a JSON list (first
    string element is used) or free text containing a URL. The match stops
    at whitespace, quotes, angle brackets, braces and backticks, and trailing
    ``.``, ``,`` or ``;`` is removed, so wrappers such as
    ``{"url": "https://…"}`` do not leak punctuation into the result.

    Returns:
        The URL, or None when ``raw`` is None or contains no URL.
    """
    if raw is None:
        return None
    s = str(raw).strip()
    try:
        val = json.loads(s)
    except ValueError:
        # Not JSON (JSONDecodeError is a ValueError): treat as plain text.
        val = None
    if isinstance(val, list) and val and isinstance(val[0], str):
        s = val[0].strip()
    elif isinstance(val, str):
        s = val.strip()
    # Peel one pair of matching outer quotes.
    if len(s) >= 2 and s[0] in "\"'" and s[-1] == s[0]:
        s = s[1:-1]
    m = re.search(r"https?://[^\s\"'<>{}`]+", s)
    if m is None:
        return None
    return m.group(0).rstrip(".,;")

class ToolRunRecorder(BaseCallbackHandler):
    """Callback handler that collects the URLs passed to ``read_url``.

    A URL is recorded when the tool named ``read_url`` finishes with a
    non-empty list as output, the URL can be extracted from the tool input,
    passes ``_is_allowed`` and has not been recorded before.  Only the most
    recently started tool call is remembered between start and end events.
    """

    def __init__(self):
        self._last_tool = None
        self._last_input = None
        self.observed_urls: List[str] = []

    def on_tool_start(self, serialized=None, input_str=None, **kwargs):
        """Remember the name and the stringified input of the tool being started."""
        if isinstance(serialized, dict):
            tool_name = serialized.get("name")
        else:
            tool_name = kwargs.get("name")
        self._last_tool = tool_name

        if isinstance(input_str, str):
            self._last_input = input_str
        elif input_str is None:
            self._last_input = ""
        else:
            self._last_input = str(input_str)

    def on_tool_end(self, output, **kwargs):
        """Record the URL of a successful ``read_url`` call, then forget the call."""
        try:
            got_lines = isinstance(output, list) and bool(output)
            if self._last_tool != "read_url" or not got_lines:
                return
            candidate = _extract_url_from_tool_input(self._last_input)
            if candidate and _is_allowed(candidate) and candidate not in self.observed_urls:
                self.observed_urls.append(candidate)
        finally:
            # Always clear the pending call, even if the checks above raised.
            self._last_tool = None
            self._last_input = None

def _format_trace(intermediate_steps) -> str:
    lines = []
    for i, (action, observation) in enumerate(intermediate_steps, 1):
        # try to show the Thought that preceded this Action
        thought = ""
        if getattr(action, "log", None):
            m = re.search(r"Thought:\s*(.*?)\s*Action:", action.log, flags=re.S)
            if m: thought = m.group(1).strip()
        if thought:
            lines.append(f"Thought {i}: {thought}")
        lines.append(f"Action {i}: {action.tool}")
        lines.append(f'Action Input {i}: "{action.tool_input}"')
        if isinstance(observation, list):
            preview_n = min(3, len(observation))
            lines.append(f"Observation {i} (showing first {preview_n}/{len(observation)}):")
            for s in observation[:preview_n]:
                lines.append(f"  · {s}")
        else:
            lines.append(f"Observation {i}: {str(observation)[:500]}")
        lines.append("")
    return "\n".join(lines).strip()

def _tidy_bullets(md: str) -> str:
    if not md: return md
    out = []
    for ln in md.splitlines():
        s = ln.rstrip()
        if not s: continue
        if s.lower().startswith("citations:"): continue
        if not s.lstrip().startswith(("-", "•")):
            out.append("- " + s.strip())
        else:
            out.append(s)
    # de-dup adjacent
    cleaned = []
    for ln in out:
        if not cleaned or ln.strip() != cleaned[-1].strip():
            cleaned.append(ln)
    return "\n".join(cleaned).strip()

def _append_citations(md: str, observed_urls: List[str]) -> str:
    """Append a "Citations:" section to ``md``.

    The section lists each URL from ``observed_urls`` that passes
    ``_is_allowed``, once, in first-seen order.  When none qualifies a
    placeholder bullet is written instead.
    """
    # dict.fromkeys keeps the first occurrence of each allowed URL, in order.
    allowed_unique = list(
        dict.fromkeys(url for url in (observed_urls or []) if _is_allowed(url))
    )
    if allowed_unique:
        listing = "\n".join(f"- {url}" for url in allowed_unique)
    else:
        listing = "- Not specified (no allowed pages returned content)"
    return (md.rstrip() + "\n\n" + "Citations:\n" + listing).strip()

def _fallback_citation(question: str) -> List[str]:
    """Find one citable URL when the agent run recorded none.

    Searches with the question plus a fixed set of toxicology keywords, falls
    back to the bare question if that search returns nothing, and returns the
    first result for which ``read_url`` yields content, as a one-item list.
    Returns an empty list when no result can be read.
    """
    keyword_query = question + " mg/kg charcoal antidote nomogram bicarbonate dialysis"
    candidates = search_tox.invoke(keyword_query)
    if not candidates:
        candidates = search_tox.invoke(question)
    if not candidates:
        candidates = []

    for candidate in candidates:
        if read_url.invoke(candidate):
            return [candidate]
    return []

def react_run_with_trace(question: str) -> dict:
    """Run the ReAct executor on ``question`` and package the result.

    Returns a dict with:
      - "trace": the formatted intermediate steps,
      - "final": the tidied answer with a Citations section appended,
      - "observed_urls": URLs recorded during the run, or the fallback
        citation when the run recorded none.
    """
    recorder = ToolRunRecorder()
    result = react_executor.invoke(
        {"input": question},
        config={"callbacks": [recorder]},
    )

    trace_md = _format_trace(result.get("intermediate_steps", []))
    answer = _tidy_bullets(result.get("output", "").strip())

    # Fall back to a single searched page so the citation list is not empty.
    observed = list(recorder.observed_urls) or _fallback_citation(question)

    return {
        "trace": trace_md,
        "final": _append_citations(answer, observed),
        "observed_urls": observed,
    }

print("✅ Runner ready.")


In [None]:
# Demo 1: run the agent on a salicylate question and print trace + final answer.
q1 = "What is the stepwise management of salicylate poisoning, including charcoal, urinary alkalinization (target urine pH), serial levels, and hemodialysis indications?"
out = react_run_with_trace(q1)
print(out["trace"])
print("\nFinal Answer:\n" + out["final"])

# Visual separator between the two demo runs.
print("\n" + "="*90 + "\n")

# Demo 2: same flow for a paracetamol question.
q2 = "What investigations should be ordered (with timing) after a suspected paracetamol overdose?"
out2 = react_run_with_trace(q2)
print(out2["trace"])
print("\nFinal Answer:\n" + out2["final"])


# Router + LangGraph

In [None]:
# Cell R0 — install/check LangGraph + imports
# Safe to run multiple times; does nothing if already installed.
try:
    from langgraph.graph import StateGraph, END
except Exception:
    %pip -q install -U langgraph
    from langgraph.graph import StateGraph, END

from typing import TypedDict, Optional, List, Dict, Any
import re

print("✅ LangGraph ready. Proceed to R1.")


In [None]:
# Cell R1 — Router (English-only) for SALBUTAMOL → ProtocolRAG, else → Agent
# (No cross-drug handling, no OTHER_DRUGS list)

# NOTE: Cell R0 already imported `re`. If not, uncomment the next line:
# import re

# Generic and brand names that mark a question as being about salbutamol.
# "albuterol" is included because the router's own smoke tests and the UI
# quick prompts expect albuterol questions to go to the protocol RAG.
SALBUTAMOL_TERMS = {
    "salbutamol", "albuterol", "ventolin", "proventil", "proair"
}

# Case-insensitive word-boundary match for all synonyms/brands.
# The terms are sorted so the compiled pattern is the same on every run
# (iteration order of a set of strings is not stable between interpreter runs).
SALB_PATTERN = re.compile(
    r"\b(" + "|".join(map(re.escape, sorted(SALBUTAMOL_TERMS))) + r")\b",
    re.I,
)

def contains_salbutamol(q: str) -> bool:
    """Return True when ``q`` mentions any term in SALBUTAMOL_TERMS as a whole word.

    A falsy ``q`` (``None`` or an empty string) is treated as an empty question.
    """
    return bool(SALB_PATTERN.search(q or ""))

def decide_route(question: str) -> str:
    """
    Returns one of: 'ProtocolRAG' or 'Agent'

    'ProtocolRAG' is chosen when the question mentions salbutamol (or one of
    its synonyms/brands); every other question goes to 'Agent'.
    """
    return "ProtocolRAG" if contains_salbutamol(question) else "Agent"

# --- quick smoke test ---
# Each question is printed with the route chosen by decide_route; the trailing
# comments give the route that is expected.
tests = [
    "What is the overdose management for SALBUTAMOL in adults?",   # → ProtocolRAG
    "Ventolin ingestion 1 mg/kg — ICU criteria?",                  # → ProtocolRAG
    "Interaction between albuterol and propranolol?",              # → ProtocolRAG expected; holds only if "albuterol" is in SALBUTAMOL_TERMS
    "Is ibuprofen overdose threshold different from naproxen?"     # → Agent
]
for q in tests:
    print(f"{q} → {decide_route(q)}")

print("✅ Router compiled. If the tests look right, proceed to R2.")


In [None]:
# Cell R2 — State + node adapters that wrap your existing functions

from typing import TypedDict, Optional, List, Dict, Any

class RouterState(TypedDict, total=False):
    """State dictionary passed between the router graph's nodes.

    Declared with ``total=False``, so every key is optional; a run starts with
    only ``question`` set and the branch nodes fill in the rest.
    """
    question: str
    route: str            # "ProtocolRAG" or "Agent" (decided later by the router)
    answer_md: str        # final markdown answer (harmonized across both paths)
    citations: List[str]  # list of URLs or protocol reference IDs
    trace: Optional[str]  # optional ReAct trace (Agent path)
    source: str           # "protocol" | "agent"

def protocol_rag_node(state: RouterState) -> RouterState:
    """
    Graph node for the salbutamol Protocol RAG path.

    Calls the notebook-level ``ask_json(question)`` and folds its result into
    the state.  A dict result is read through its ``answer``, ``citations``
    (items may be dicts or plain values) and optional ``disclaimer`` keys; any
    other result is stringified and used as the answer.

    Returns a copy of ``state`` with ``route``, ``answer_md``, ``citations``
    and ``source`` set.

    Raises:
        RuntimeError: if ``ask_json`` has not been defined yet.
    """
    if "ask_json" not in globals():
        raise RuntimeError(
            "ask_json(...) is not defined. Please run the cells that build your Protocol RAG first."
        )

    question = (state.get("question") or "").strip()
    result = ask_json(question)

    # Pull the three fields out of the result, tolerating non-dict returns.
    if isinstance(result, dict):
        answer = (result.get("answer") or "").strip()
        raw_citations = result.get("citations", [])
        disclaimer = (result.get("disclaimer") or "").strip()
    else:
        answer = str(result)
        raw_citations = []
        disclaimer = ""

    # Build the machine-readable IDs and the human-readable bullet lines together.
    citation_ids: List[str] = []
    citation_lines: List[str] = []
    for entry in raw_citations or []:
        if not isinstance(entry, dict):
            citation_ids.append(str(entry))
            citation_lines.append(f"- {str(entry)}")
            continue
        ref = (
            entry.get("id")
            or entry.get("ref")
            or entry.get("section")
            or entry.get("url")
            or "protocol"
        )
        section = entry.get("section", "")
        suffix = f" ({section})" if section else ""
        citation_ids.append(f"SALB:{ref}")
        citation_lines.append(f"- SALB:{ref}" + suffix)

    # Assemble the markdown: answer, then citations, then the quoted disclaimer.
    markdown = answer if answer else "Not specified in this protocol."
    if citation_lines:
        markdown = markdown.rstrip() + "\n\nCitations:\n" + "\n".join(citation_lines)
    if disclaimer:
        markdown = markdown.rstrip() + f"\n\n> {disclaimer}"

    return {
        **state,
        "route": "ProtocolRAG",
        "answer_md": markdown,
        "citations": citation_ids,
        "source": "protocol",
    }

def agent_node(state: RouterState) -> RouterState:
    """
    Graph node for the web/search Agent (ReAct) path.

    Calls the notebook-level ``react_run_with_trace(question)`` and folds its
    result into the state.  From a dict result it reads the answer from
    ``final`` (else ``answer``), the citations from ``observed_urls`` (else
    ``citations``) and the trace from ``trace``; any other result is
    stringified and used as the answer.

    Returns a copy of ``state`` with ``route``, ``answer_md``, ``citations``,
    ``trace`` and ``source`` set.

    Raises:
        RuntimeError: if ``react_run_with_trace`` has not been defined yet.
    """
    if "react_run_with_trace" not in globals():
        raise RuntimeError(
            "react_run_with_trace(...) is not defined. Please run the cells that define your Agent first."
        )

    question = (state.get("question") or "").strip()
    result = react_run_with_trace(question)

    if isinstance(result, dict):
        answer = (result.get("final") or result.get("answer") or "").strip()
        raw_urls = result.get("observed_urls") or result.get("citations") or []
        trace = result.get("trace", "")
    else:
        answer = str(result)
        raw_urls = []
        trace = ""
    citations = [str(url) for url in (raw_urls or [])]

    return {
        **state,
        "route": "Agent",
        "answer_md": answer if answer else "No answer returned.",
        "citations": citations,
        "trace": trace,
        "source": "agent",
    }

print("✅ Node adapters set. ")


In [None]:
# Cell R3 — Build and compile the LangGraph with an explicit Router node

from langgraph.graph import StateGraph, END

def router_node(state: RouterState) -> RouterState:
    """
    Pass-through node: we don't modify state here.
    Routing is handled by the conditional edges below via `route_edge`.

    Returns the very same ``state`` object it was given.
    """
    return state

def route_edge(state: RouterState) -> str:
    """
    Pick the branch to follow after the Router node.

    Feeds the stripped question (empty string when missing) to `decide_route`
    from Cell R1 and returns its verdict, which must be exactly
    'ProtocolRAG' or 'Agent'.
    """
    question = state.get("question") or ""
    return decide_route(question.strip())

# Build the graph
# Build the graph
graph = StateGraph(RouterState)

# Add nodes
graph.add_node("Router", router_node)          # entry point (pure pass-through)
graph.add_node("ProtocolRAG", protocol_rag_node)
graph.add_node("Agent", agent_node)

# Entry is Router
graph.set_entry_point("Router")

# Conditional edges from Router → (ProtocolRAG | Agent)
# The mapping keys are the strings route_edge returns; the values are node names.
graph.add_conditional_edges(
    "Router",
    route_edge,
    {"ProtocolRAG": "ProtocolRAG", "Agent": "Agent"}
)

# Each branch ends the run
graph.add_edge("ProtocolRAG", END)
graph.add_edge("Agent", END)

# Compile to an app you can invoke
rag_agent_app = graph.compile()

print("✅ Graph compiled as `rag_agent_app` (Router → ProtocolRAG|Agent → END). Proceed to R4.")


In [None]:
# Cell R4 — Helper to run the whole graph in one call

from typing import Dict, Any

def route_and_run(question: str) -> Dict[str, Any]:
    """
    Run the compiled LangGraph app on one question.

    Returns a dict shaped as:
      {
        'route': 'ProtocolRAG' | 'Agent',
        'answer_md': str,
        'citations': list[str],
        'trace': str  # only for Agent path; empty string for ProtocolRAG
      }

    Raises:
        RuntimeError: if ``rag_agent_app`` has not been built yet (Cell R3).
    """
    if "rag_agent_app" not in globals():
        raise RuntimeError("rag_agent_app is not defined. Run Cell R3 first.")

    final_state: RouterState = rag_agent_app.invoke({"question": question})

    # The trace is only exposed when the answer came from the agent branch.
    if final_state.get("source") == "agent":
        trace = final_state.get("trace", "")
    else:
        trace = ""

    return {
        "route": final_state.get("route", ""),
        "answer_md": final_state.get("answer_md", ""),
        "citations": final_state.get("citations", []) or [],
        "trace": trace,
    }

print("✅ Runner ready. Proceed to R5 when you want to test.")


In [None]:
# Cell R5 — Sanity tests for the router + graph

# One question per expected routing outcome; each is run through the full graph.
test_questions = [
    # Should route to ProtocolRAG (mentions salbutamol)
    "When should investigations be ordered after suspected salbutamol overdose?",
    # Should route to Agent (no salbutamol)
    "What is the overdose management for ibuprofen in adults?",
    # Expected → ProtocolRAG; this holds only if "albuterol" is in SALBUTAMOL_TERMS (Cell R1)
    "What is the interaction between albuterol and propranolol?"
]
for q in test_questions:
    print("\n" + "="*100)
    print("Q:", q)
    out = route_and_run(q)
    print("ROUTE:", out["route"])
    # Show at most the first 1200 characters of the answer.
    print("\nAnswer (truncated):\n" + (out["answer_md"][:1200] + ("..." if len(out["answer_md"]) > 1200 else "")))
    if out["citations"]:
        print("\nCitations:")
        # Print at most the first 8 citations.
        for c in out["citations"][:8]:
            print("-", c)
    if out["trace"]:
        print("\n(Trace available; hidden)")
print("\n✅ Tests complete. If the routing matches expectations, you're done.")


# API_DEMO

In [None]:
# Cell G2 — Professional Gradio UI (messages format, clean typography)

%pip -q install gradio

import gradio as gr
import time, json, uuid
from typing import List, Any

# ---------- Professional styling ----------
CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

:root {
  --bg: #0b0f1a;
  --panel: #0f172a;
  --border: #1e293b;
  --muted: #9fb2d0;
  --text: #e6eefb;
  --accent: #3b82f6;
  --success: #10b981;
}

* { font-family: 'Inter', system-ui, -apple-system, Segoe UI, Roboto, sans-serif; }
.gradio-container {
  background:
    radial-gradient(1200px 600px at 20% -20%, #132341 0, rgba(11,15,26,0) 60%),
    linear-gradient(180deg, #0b0f1a 0%, #0b0f1a 100%);
  color: var(--text);
}
.markdown-body h1, .markdown-body h2, .markdown-body h3 { letter-spacing: .2px; }
#status, #sidecard {
  padding: 12px 14px; border-radius: 14px;
  background: var(--panel); border: 1px solid var(--border);
}
.badge {
  display:inline-block; padding: 4px 10px; border-radius: 999px;
  font-weight:600; border:1px solid var(--border); line-height: 1.1;
}
.badge.agent { background: #0b1220; color: #cfe1ff; }
.badge.protocol { background: #0b2012; color: #b6f7d2; }
.pill { display:inline-block; padding: 2px 8px; border-radius: 999px; background:#0e1a2f; border:1px solid var(--border); }
.cites a { color: #93c5fd; text-decoration: none; }
.cites a:hover { text-decoration: underline; }
.small { color: var(--muted); font-size: 12px; }
.footer { color: var(--muted); font-size: 12px; text-align: center; margin-top: 8px; }
"""

def _status_html(route: str, latency: float) -> str:
    pill_cls = "protocol" if route == "ProtocolRAG" else "agent"
    return f"""
      <div style="display:flex; align-items:center; gap:12px;">
        <span class="badge {pill_cls}">{route}</span>
        <span class="small">{latency:.2f}s</span>
      </div>
    """

def _cites_html(citations: List[str]) -> str:
    if not citations:
        return "<span class='small'>No citations</span>"
    urls = [c for c in citations if isinstance(c, str) and c.lower().startswith("http")]
    refs = [c for c in citations if c not in urls]
    parts = []
    if urls:
        parts.append("**References:** " + " • ".join([f'<a href="{u}" target="_blank" rel="noopener">{u}</a>' for u in urls]))
    if refs:
        pills = " ".join([f"<span class='pill'>{r}</span>" for r in refs])
        parts.append(pills)
    return "<div class='cites'>" + "<br/>".join(parts) + "</div>"

def respond(message: str, messages: List[dict], show_trace: bool):
    """
    Handle one chat turn.

    `messages` follows the Chatbot(type='messages') format:
      [{'role':'user','content':'...'}, {'role':'assistant','content':'...'}, ...]

    Runs the question through `route_and_run`, appends the user/assistant pair
    to a new history list and returns six values in output order: chatbot
    history, stored history, status HTML, citations HTML, trace text (empty
    unless `show_trace` is set and the route was 'Agent'), and an empty string
    that clears the input box.
    """
    started = time.time()
    result = route_and_run(message)  # <-- your existing router+runner
    elapsed = time.time() - started

    history = (messages or []) + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": result["answer_md"]},
    ]

    status_html = _status_html(result["route"], elapsed)
    cites_html = _cites_html(result.get("citations", []))

    if show_trace and result["route"] == "Agent":
        trace_text = result.get("trace", "")
    else:
        trace_text = ""

    return history, history, status_html, cites_html, trace_text, ""

def reset_chat():
    """Return the initial values for all six chat outputs.

    Order: chatbot history, stored history (the same empty list object),
    status HTML, citations HTML, trace text, input box text.
    """
    cleared = []
    ready_html = "<span class='small'>Ready</span>"
    no_cites_html = "<span class='small'>No citations</span>"
    return cleared, cleared, ready_html, no_cites_html, "", ""

def export_chat(messages: List[dict]) -> str:
    """Write the transcript as pretty-printed UTF-8 JSON and return the file path.

    The file goes to /tmp under a name with a random 8-hex-character suffix;
    a falsy ``messages`` is exported as an empty list.
    """
    suffix = uuid.uuid4().hex[:8]
    path = f"/tmp/transcript_{suffix}.json"
    payload = json.dumps(messages or [], ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(payload)
    return path

# Layout: header, then a two-column row (chat on the left, session panel on the
# right), then event wiring for the buttons and the text box.
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    # Header
    gr.Markdown(
        "<div class='markdown-body'>"
        "<h1>Clinical Toxicology Assistant</h1>"
        "<p class='small'>Proof-of-concept chat experience</p>"
        "</div>"
    )

    with gr.Row():
        # Left: Chat
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=560,
                type="messages",
                show_copy_button=True,
            )
            with gr.Row():
                msg = gr.Textbox(placeholder="Ask a clear clinical question…", scale=5, autofocus=True, lines=2)
                send = gr.Button("Send", variant="primary", scale=1)
        # Right: Session panel
        with gr.Column(scale=2):
            gr.Markdown("#### Session")
            status = gr.HTML("<span class='small'>Ready</span>", elem_id="status")
            gr.Markdown("#### Citations")
            cites = gr.HTML("<span class='small'>No citations</span>", elem_id="sidecard")

            with gr.Accordion("Agent trace (optional)", open=False):
                trace = gr.Code(label="Trace", language="markdown")
            show_trace = gr.Checkbox(label="Show Agent trace for non-salbutamol queries", value=False)

            gr.Markdown("#### Quick prompts")
            ex1 = gr.Button("Salbutamol overdose — recommended investigations & monitoring (adult)")
            ex2 = gr.Button("Ibuprofen overdose — ED management (adult)")
            ex3 = gr.Button("Albuterol with propranolol — clinically significant interaction?")

            with gr.Row():
                clear = gr.Button("New conversation")
                download = gr.DownloadButton(label="Export transcript (.json)")

    # State to store messages for export
    messages_state = gr.State([])

    # Handlers
    # Both the Send button and pressing Enter in the text box call `respond`;
    # the six outputs match the six values `respond` returns, in order.
    send.click(
        respond,
        inputs=[msg, messages_state, show_trace],
        outputs=[chatbot, messages_state, status, cites, trace, msg],
        show_progress="full",
    )
    msg.submit(
        respond,
        inputs=[msg, messages_state, show_trace],
        outputs=[chatbot, messages_state, status, cites, trace, msg],
        show_progress="full",
    )

    # Quick prompts
    # Each button only fills the text box; the user still has to send the question.
    ex1.click(lambda: "In adult salbutamol overdose, what investigations and monitoring are recommended in the ED?", outputs=msg)
    ex2.click(lambda: "What is the recommended ED management for ibuprofen overdose in an adult?", outputs=msg)
    ex3.click(lambda: "Is there a clinically significant interaction between albuterol and propranolol in acute care?", outputs=msg)

    # Utilities
    clear.click(reset_chat, outputs=[chatbot, messages_state, status, cites, trace, msg])
    download.click(export_chat, inputs=messages_state, outputs=download)

    gr.Markdown("<div class='footer'>v0.1 • internal proof of concept</div>")

# NOTE(review): share=True publishes a public link to this app; confirm that is
# intended before running outside a throwaway session.
demo.launch(share=True)  # share=True => quick public link


# EVALUATE THE SCOPE

In [None]:
# Cell T1 — Testing | run all questions with *route_and_run* and append "Model Answer"

import pandas as pd
import time
import os

# 🔧 Your CSV path (as you specified)
CSV_PATH = "/content/data_salbutamol/toxicology_eval_150_with_answers.csv"

# --- Hard guarantee we use the same callable as your Gradio example ---
if "route_and_run" not in globals():
    raise RuntimeError(
        "route_and_run(...) is not defined. Please run your R cells (R0–R4) so route_and_run is available."
    )

def _call_model(q: str):
    """
    Run one question through `route_and_run`, the same entrypoint the Gradio
    code uses, and unpack the result.

    Returns a 4-tuple ``(answer, route, citations, trace)`` where the answer is
    stripped and missing fields default to empty values.

    Raises:
        TypeError: if `route_and_run` returns something other than a dict.
    """
    result = route_and_run(q)
    if not isinstance(result, dict):
        raise TypeError("route_and_run must return a dict.")

    # Normalize expected keys
    return (
        (result.get("answer_md") or "").strip(),
        result.get("route", ""),
        result.get("citations", []) or [],
        result.get("trace", ""),
    )

# --- Load questions ---
if not os.path.exists(CSV_PATH):
    raise FileNotFoundError(f"CSV not found at: {CSV_PATH}")

df = pd.read_csv(CSV_PATH)
if "question" not in df.columns:
    raise ValueError("CSV must contain a 'question' column.")

# Optional: backup before overwriting
# The backup is a timestamped copy of the file as loaded, written next to it.
ts = time.strftime("%Y%m%d-%H%M%S")
backup_path = CSV_PATH.replace(".csv", f".backup.{ts}.csv")
df.to_csv(backup_path, index=False)

# --- Batch inference ---
# One entry per row is appended to each list, in row order.
model_answers = []
model_routes = []
model_citations = []

t0 = time.time()
for i, row in df.iterrows():
    q = str(row["question"])
    try:
        ans, route, cites, _trace = _call_model(q)
    except Exception as e:
        # Keep going on failure: the error text is stored as the answer and
        # the route is marked "ERROR" so the row can be found afterwards.
        ans = f"[ERROR] {type(e).__name__}: {e}"
        route, cites = "ERROR", []
    model_answers.append(ans)
    model_routes.append(route)
    model_citations.append("; ".join(map(str, cites)))
    # Progress every 10 rows; `i` is the row's index label, so this assumes
    # the default 0..n-1 index that read_csv produced above.
    if (i + 1) % 10 == 0:
        print(f"Processed {i+1}/{len(df)}")

elapsed = time.time() - t0
print(f"\n✅ Done. Total: {len(df)} rows in {elapsed:.1f}s.")

# --- Write back to the same CSV (plus keep a backup) ---
df["Model Answer"] = model_answers
df["Model Route"] = model_routes         # helpful for route accuracy checks
df["Model Citations"] = model_citations  # helpful for QA

df.to_csv(CSV_PATH, index=False)
print("💾 Saved:", CSV_PATH)
print("🧯 Backup:", backup_path)

# Quick peek
df.head(10)


In [None]:
df.head(5)

In [None]:
# Optional: write CSV and TSV copies of the results next to the source file
# (under /content/data_salbutamol) for quick “Download” links in Colab
import shutil, os, csv
import pandas as pd

# Re-read the results from disk so this cell does not depend on the in-memory `df`.
CSV_PATH = "/content/data_salbutamol/toxicology_eval_150_with_answers.csv"
df = pd.read_csv(CSV_PATH)

OUT_CSV_DL = "/content/data_salbutamol/toxicology_eval_150_with_answers_RESULTS.csv"
OUT_TSV_DL = "/content/data_salbutamol/toxicology_eval_150_with_answers_RESULTS.tsv"

# Both copies use "\n" line endings; the TSV differs only in its tab separator.
df.to_csv(OUT_CSV_DL, index=False, quoting=csv.QUOTE_MINIMAL, lineterminator="\n")
df.to_csv(OUT_TSV_DL, index=False, sep="\t", lineterminator="\n")

print("Downloadable CSV:", OUT_CSV_DL)
print("Downloadable TSV:", OUT_TSV_DL)
