In [1]:
!pip install --upgrade pip
!pip install groq pandas


Collecting pip
  Using cached pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Using cached pip-25.2-py3-none-any.whl (1.8 MB)


ERROR: To modify pip, please run the following command:
C:\Users\anish\anaconda3\python.exe -m pip install --upgrade pip




In [2]:
# ==== PATHS / MODEL CONFIG ====
import os

EVAL_SELECTION_CSV = "medquad_selected_questions.csv"  # <- same questions file you used earlier
ANSWER_FIELD       = "answer"  # if your selection CSV has 'gold', we'll auto-detect below

# Groq model for generation (few-shot answering).
# The API key is read from the environment rather than pasted into the notebook,
# so it never ends up in a saved/committed .ipynb. Set GROQ_API_KEY in your shell
# before launching Jupyter.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
if not GROQ_API_KEY:
    print("WARNING: GROQ_API_KEY is not set; Groq API calls will fail.")
GEN_MODEL    = "llama-3.1-8b-instant"    # small/fast generator; e.g. "llama-3.3-70b-versatile" for higher quality

# Generation knobs (keep fixed across runs for fair comparison)
TEMPERATURE = 0.0
MAX_TOKENS  = 256
TOP_P       = 1.0

# Ablation settings (how many examples to prepend)
FEWSHOT_LIST = [2, 3, 5]   # run all of these
PRINT_PROMPTS_FOR_FEWSHOT = False  # True = print the full prompt sent to the model

# Where to save outputs
RESULTS_DIR = "."

# (Optional) run the LLM-as-judge scoring after each ablation run
RUN_EVAL = True


In [3]:
import os, time, json, re, pandas as pd
from typing import List, Dict, Tuple

# Load the SAME questions file you used before
sel = pd.read_csv(EVAL_SELECTION_CSV)

# How many questions to evaluate (increase to use more of the selection file).
N_EVAL_QUESTIONS = 3

# Figure out which column in the selection file holds the gold text:
# prefer an explicit 'gold' column, otherwise fall back to ANSWER_FIELD.
gold_col = "gold" if "gold" in sel.columns else (ANSWER_FIELD if ANSWER_FIELD in sel.columns else None)
assert gold_col is not None, f"Selection file must contain either 'gold' or '{ANSWER_FIELD}'"
assert "question" in sel.columns, "Selection file must contain a 'question' column"

# Drop rows with a missing question or gold answer *before* taking the first N,
# so blanks are never sent to the generator/judge as the literal string "nan".
eval_df = (
    sel[["question", gold_col]]
    .dropna(subset=["question", gold_col])
    .head(N_EVAL_QUESTIONS)
    .rename(columns={gold_col: "gold"})  # internal rename only (we're not saving this CSV)
    .reset_index(drop=True)
)

print("Few-shot evaluation questions:", len(eval_df))
print(eval_df["question"].to_string(index=False))


Few-shot evaluation questions: 3
              Do you have information about X-Rays
What are the symptoms of Alpha-ketoglutarate de...
What are the treatments for GLUT1 deficiency sy...


In [4]:
from groq import Groq

# Pass None (not "") when the configured key is blank so the Groq SDK can fall back
# to the GROQ_API_KEY environment variable; an empty string would be sent as-is and
# only fail later with an authentication error on the first request.
client = Groq(api_key=GROQ_API_KEY or None)

def chat_messages(model: str, messages: List[Dict], temperature: float = 0.0, max_tokens: int = 256, top_p: float = 1.0) -> str:
    """
    Thin wrapper around Groq chat completions.

    Args:
        model: Groq model id to query.
        messages: chat messages, each a dict with "role" and "content".
        temperature, max_tokens, top_p: sampling settings forwarded unchanged.

    Returns:
        The first choice's text content, stripped. A reply without text content
        (content is None) is returned as "" instead of raising AttributeError.
    """
    r = client.chat.completions.create(
        model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p,
        messages=messages
    )
    content = r.choices[0].message.content
    return (content or "").strip()


In [5]:
# --- Curated few-shot exemplars (medical, generic, not the same as your eval questions) ---
# Each entry is a (question, answer) pair. build_few_shot_messages takes the first k,
# so order matters: put the exemplars you most want imitated at the top.
CURATED_FEWSHOTS: List[Tuple[str, str]] = [
    (
        "What are common side effects of the influenza vaccine?",
        "Most side effects are mild and short-lived, such as soreness at the injection site, low-grade fever, fatigue, or headache. "
        "Severe allergic reactions are rare; anyone with a history of anaphylaxis to a vaccine component should discuss risks with a clinician."
    ),
    (
        "How is mild dehydration treated at home?",
        "Oral rehydration with water or oral rehydration solutions is first-line. Sip small amounts frequently, avoid alcohol and caffeine, "
        "and address the underlying cause (e.g., gastrointestinal losses). Seek care if symptoms persist or worsen."
    ),
    (
        "Who should be screened for hypertension?",
        "All adults should have periodic blood pressure screening. Screening is especially important for individuals with risk factors such as "
        "obesity, diabetes, kidney disease, or a family history of hypertension."
    ),
    (
        "What are red-flag symptoms of chest pain?",
        "Red flags include pressure-like pain radiating to the arm/jaw, shortness of breath, diaphoresis, syncope, or hemodynamic instability. "
        "Immediate evaluation is warranted to rule out acute coronary syndrome or other emergent causes."
    ),
    (
        "When are antibiotics indicated for acute bronchitis?",
        "Most cases are viral; antibiotics are generally not indicated unless there is strong suspicion of bacterial infection, "
        "high-risk comorbidities, or evidence of pneumonia. Symptomatic care is usually sufficient."
    ),
]

def build_few_shot_messages(question: str, k: int = 3, print_prompt: bool = False) -> List[Dict]:
    """
    Assemble a two-message chat prompt: a system message, then a user message that
    holds up to `k` curated exemplars followed by `question`.

    `k` is clamped to [0, len(CURATED_FEWSHOTS)]; with zero exemplars the examples
    section reads "(No examples)". If `print_prompt` is True, both messages are
    printed between separator lines (handy for debugging).

    Returns:
        [{"role": "system", ...}, {"role": "user", ...}]
    """
    n_examples = min(max(k, 0), len(CURATED_FEWSHOTS))

    # Keeps the model focused and concise.
    system_msg = (
        "You are a concise, evidence-focused medical assistant. "
        "Answer briefly (2–4 sentences) and avoid speculation. If unsure, say you don't know."
    )

    # Numbered Q/A pairs separated by blank lines; empty selection -> placeholder.
    examples_text = "\n\n".join(
        f"Example {idx} — Question: {ex_q}\nExample {idx} — Answer: {ex_a}"
        for idx, (ex_q, ex_a) in enumerate(CURATED_FEWSHOTS[:n_examples], start=1)
    ) or "(No examples)"

    # Sections of the user message, separated by blank lines.
    user_msg = "\n\n".join([
        "Use the examples below as a style and reasoning guide.",
        examples_text,
        f"Now answer the new question:\nQuestion: {question}",
        "Instructions: Provide a factual, succinct answer in 2–4 sentences. "
        "If information is insufficient, say 'I don't know.'",
    ])

    if print_prompt:
        rule = "=" * 88
        print("\n" + rule)
        print("[FEW-SHOT PROMPT]")
        print("\n[SYSTEM]\n" + system_msg)
        print("\n[USER]\n" + user_msg)
        print(rule)

    return [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]


In [6]:
import unicodedata

def normalize_text(t: str) -> str:
    """
    Canonicalize free text for prompting and scoring.

    Applies Unicode NFKC normalization, collapses every whitespace run to a single
    space, then strips both ends.

    Args:
        t: text to normalize. None and float NaN (e.g. an empty pandas cell) yield "";
           any other non-string value is converted with str() first.

    Returns:
        The normalized string.
    """
    import math  # local: only needed for the NaN check

    if t is None or (isinstance(t, float) and math.isnan(t)):
        return ""
    text = unicodedata.normalize("NFKC", str(t))
    # Strip *after* NFKC + collapsing: NFKC can expand some characters into sequences
    # that begin/end with a space (e.g. U+00A8 -> " \u0308"), which a strip-first
    # order would leave behind.
    return re.sub(r"\s+", " ", text).strip()


In [7]:
import pandas as pd

def run_few_shot(n_shots: int, print_prompts: bool = False, df: "pd.DataFrame | None" = None) -> pd.DataFrame:
    """
    Run few-shot prompting over a question set with `n_shots` curated exemplars.

    Args:
        n_shots: number of exemplars, passed to build_few_shot_messages as `k`.
        print_prompts: if True, print every full prompt sent to the generator.
        df: DataFrame with 'question' and 'gold' columns. Defaults to the
            notebook-level `eval_df`, so existing calls behave as before.

    Returns:
        DataFrame with columns question, gold, strategy ("few-shot-{n_shots}") and
        answer. The same frame is written to
        RESULTS_DIR/medquad_few_shot_answers_n{n_shots}.csv.
    """
    questions = eval_df if df is None else df

    def _cell_text(value) -> str:
        # A missing CSV cell arrives as NaN; str(NaN) would send the literal
        # "nan" to the generator and the judge, so map it to "" instead.
        return "" if pd.isna(value) else normalize_text(str(value))

    rows = []
    for _, r in questions.iterrows():
        q = _cell_text(r["question"])
        gold = _cell_text(r["gold"])

        messages = build_few_shot_messages(q, k=n_shots, print_prompt=print_prompts)
        ans = chat_messages(GEN_MODEL, messages, temperature=TEMPERATURE, max_tokens=MAX_TOKENS, top_p=TOP_P)

        rows.append({
            "question": q,
            "gold": gold,
            "strategy": f"few-shot-{n_shots}",
            "answer": normalize_text(ans),
        })

    # Explicit columns keep the output schema stable even for an empty question set.
    df_out = pd.DataFrame(rows, columns=["question", "gold", "strategy", "answer"])
    os.makedirs(RESULTS_DIR, exist_ok=True)
    out_csv = os.path.join(RESULTS_DIR, f"medquad_few_shot_answers_n{n_shots}.csv")
    df_out.to_csv(out_csv, index=False)
    print(f"[FEW-SHOT] Saved answers to: {out_csv}  (rows={len(df_out)})")
    return df_out


In [8]:
# === LLM-as-judge for: Faithfulness, Hallucination, Correctness, Relevance ===
# Uses a separate (larger) Groq model as the judge, distinct from your generator model.

import re, pandas as pd
from typing import List

JUDGE_MODEL = "llama-3.3-70b-versatile"   # <- LLM-as-judge
JUDGE_TEMPERATURE = 0.0
JUDGE_MAX_TOKENS = 64
JUDGE_TOP_P = 1.0

def _extract_float(txt: str) -> float:
    m = re.search(r"\d*\.?\d+(?:[eE][-+]?\d+)?", txt or "")
    try: x = float(m.group(0)) if m else 0.0
    except: x = 0.0
    return max(0.0, min(1.0, x))

def chat_judge(system_prompt: str, user_prompt: str, model: str = JUDGE_MODEL) -> str:
    """
    Send one system + user exchange to the judge model and return its reply text.

    Uses the shared Groq `client` with the JUDGE_* sampling settings. A reply with
    no text content (content is None) is returned as "" rather than raising, so the
    numeric parser downstream simply scores it 0.0.
    """
    r = client.chat.completions.create(
        model=model, temperature=JUDGE_TEMPERATURE, max_tokens=JUDGE_MAX_TOKENS, top_p=JUDGE_TOP_P,
        messages=[{"role":"system","content":system_prompt},
                  {"role":"user","content":user_prompt}]
    )
    content = r.choices[0].message.content
    return (content or "").strip()

def entail_prob(premise: str, claim: str) -> float:
    """
    Ask the judge how likely it is that `premise` entails `claim`.

    Returns the judge's number parsed by _extract_float, i.e. clamped to [0, 1]
    (0.0 when the reply contains no number).
    """
    instructions = (
        "You are an evaluator. Given a PREMISE (evidence) and a CLAIM (one sentence), "
        "return ONLY a number in [0,1] = probability that PREMISE ENTAILS CLAIM."
    )
    request = "PREMISE:\n{}\n\nCLAIM:\n{}\n\nOutput only a number in [0,1].".format(premise, claim)
    reply = chat_judge(instructions, request)
    return _extract_float(reply)

def split_sentences(t: str) -> List[str]:
    """
    Split text into sentences after '.', '!' or '?' followed by whitespace.

    Intentionally lightweight (abbreviations like "e.g." also split), which is
    adequate for scoring short answers. None or blank input returns [].
    """
    text = (t or "").strip()
    if not text:
        return []
    return re.split(r'(?<=[.!?])\s+', text)

def faithfulness_verbose(answer: str, gold_reference: str, thresh: float = 0.5):
    """
    Score how much of `answer` is supported by the gold answer.

    The gold answer acts as evidence: each non-empty sentence of `answer` is scored
    with entail_prob(gold_reference, sentence) and counts as supported when that
    probability is >= `thresh`.

    Returns:
        (faithfulness, hallucination_rate, per_sentence_df) where faithfulness is the
        supported fraction and hallucination_rate = 1 - faithfulness. An answer with
        no sentences returns (0.0, 1.0, empty DataFrame with the same columns).
    """
    sentences = [piece.strip() for piece in split_sentences(answer)]
    sentences = [s for s in sentences if s]
    if not sentences:
        return 0.0, 1.0, pd.DataFrame(columns=["sentence","best_entail_prob","supported"])

    records = []
    for sentence in sentences:
        prob = entail_prob(gold_reference, sentence)
        records.append({
            "sentence": sentence,
            "best_entail_prob": round(prob, 3),
            "supported": prob >= thresh,
        })

    n_supported = sum(rec["supported"] for rec in records)
    faith = n_supported / len(sentences)
    return faith, 1.0 - faith, pd.DataFrame(records)

def answer_correctness_llm(gold_answer: str, model_answer: str) -> float:
    """
    Symmetric correctness score: the mean of entailment in both directions,
    i.e. (E(gold -> answer) + E(answer -> gold)) / 2.
    """
    forward = entail_prob(gold_answer, model_answer)    # gold supports answer?
    backward = entail_prob(model_answer, gold_answer)   # answer covers gold?
    return (forward + backward) / 2

def answer_relevance(q: str, a: str) -> float:
    """
    Ask the judge to rate, on a 0–1 rubric, how well answer `a` addresses question `q`.

    Returns the parsed score clamped to [0, 1] (0.0 if the reply has no number).
    """
    rubric_lines = [
        "You are an evaluator. Rate how well the ANSWER addresses the QUESTION.",
        "- 1.0 = Directly answers, accurate and focused.",
        "- 0.7 = Mostly answers with minor gaps/irrelevance.",
        "- 0.4 = Partial answer; noticeable gaps or off-topic parts.",
        "- 0.0 = Does not answer or off-topic.",
        "Return ONLY a number in [0,1].",
    ]
    rubric = "\n".join(rubric_lines)
    request = "QUESTION:\n{}\n\nANSWER:\n{}\n\nScore:".format(q, a)
    return _extract_float(chat_judge(rubric, request))

def band(x: float) -> str:
    """Map a [0, 1] score to a label: >=0.90 Excellent, >=0.75 Good, >=0.60 Borderline, else Poor."""
    for cutoff, label in ((0.90, "Excellent"), (0.75, "Good"), (0.60, "Borderline")):
        if x >= cutoff:
            return label
    return "Poor"

def score_all_metrics(df_answers: pd.DataFrame, faith_thresh: float = 0.5) -> pd.DataFrame:
    """
    Score every row of `df_answers` (needs 'question', 'answer', 'gold'; 'strategy'
    optional) with the LLM judge.

    Output columns per row: question, strategy, faithfulness, hallucination_rate,
    answer_relevance, answer_correctness (each rounded to 3 decimals) and the
    matching faith_band / relevance_band / correctness_band labels.
    """
    def _score_row(row) -> dict:
        question, answer, gold = row["question"], row["answer"], row["gold"]
        faith, halluc, _per_sentence = faithfulness_verbose(answer, gold, thresh=faith_thresh)
        relev = answer_relevance(question, answer)
        corr = answer_correctness_llm(gold, answer)
        return {
            "question": question,
            "strategy": row.get("strategy", ""),
            "faithfulness": round(faith, 3),
            "hallucination_rate": round(halluc, 3),
            "answer_relevance": round(relev, 3),
            "answer_correctness": round(corr, 3),
            "faith_band": band(faith),
            "relevance_band": band(relev),
            "correctness_band": band(corr),
        }

    return pd.DataFrame([_score_row(row) for _, row in df_answers.iterrows()])


In [9]:
# === Few-shot ablation over N in FEWSHOT_LIST, with full metric coverage ===
# For each N: generate answers, then (only if RUN_EVAL) score them with the LLM judge
# and collect per-N averages into a summary table.
os.makedirs(RESULTS_DIR, exist_ok=True)
all_runs = {}
all_scores = []

for N in FEWSHOT_LIST:
    print("\n" + "="*88)
    print(f"[ABLATION] FEW-SHOT with N={N}")
    print("="*88)

    dfN = run_few_shot(N, print_prompts=PRINT_PROMPTS_FOR_FEWSHOT)   # uses your generator
    all_runs[N] = dfN

    # Honor the RUN_EVAL config flag: generation-only runs skip all judge calls.
    if not RUN_EVAL:
        continue

    # Score all four metrics
    scored = score_all_metrics(dfN, faith_thresh=0.5)
    scores_csv = os.path.join(RESULTS_DIR, f"medquad_few_shot_scores_n{N}.csv")
    scored.to_csv(scores_csv, index=False)
    print(f"[EVAL] Saved per-item scores to: {scores_csv}")

    # Aggregate for this N
    aggN = {
        "n_shots": N,
        "faithfulness_avg": float(scored["faithfulness"].mean()),
        "hallucination_rate_avg": float(scored["hallucination_rate"].mean()),
        "answer_relevance_avg": float(scored["answer_relevance"].mean()),
        "answer_correctness_avg": float(scored["answer_correctness"].mean()),
    }
    all_scores.append(aggN)

# Final comparison table (only meaningful when something was scored; an empty
# DataFrame has no 'faithfulness_avg' column to sort by).
if all_scores:
    agg_df = pd.DataFrame(all_scores).sort_values(by="faithfulness_avg", ascending=False)
    summary_csv = os.path.join(RESULTS_DIR, "medquad_few_shot_ablation_summary_full_metrics.csv")
    agg_df.to_csv(summary_csv, index=False)

    # hallucination_rate_avg is the one metric where lower is better.
    print("\n=== FEW-SHOT ABLATION SUMMARY (higher is better, except hallucination_rate_avg: lower is better) ===")
    print(agg_df.to_string(index=False))
    print("Saved summary to:", summary_csv)
else:
    print("\n[EVAL] RUN_EVAL is False: answers were generated but not scored; no summary written.")



[ABLATION] FEW-SHOT with N=2
[FEW-SHOT] Saved answers to: ./medquad_few_shot_answers_n2.csv  (rows=3)
[EVAL] Saved per-item scores to: ./medquad_few_shot_scores_n2.csv

[ABLATION] FEW-SHOT with N=3
[FEW-SHOT] Saved answers to: ./medquad_few_shot_answers_n3.csv  (rows=3)
[EVAL] Saved per-item scores to: ./medquad_few_shot_scores_n3.csv

[ABLATION] FEW-SHOT with N=5
[FEW-SHOT] Saved answers to: ./medquad_few_shot_answers_n5.csv  (rows=3)
[EVAL] Saved per-item scores to: ./medquad_few_shot_scores_n5.csv

=== FEW-SHOT ABLATION SUMMARY (higher is better) ===
 n_shots  faithfulness_avg  hallucination_rate_avg  answer_relevance_avg  answer_correctness_avg
       5          0.616667                0.383333              0.133333                0.033333
       2          0.555667                0.444333              0.566667                0.575000
       3          0.166667                0.833333              0.000000                0.166667
Saved summary to: ./medquad_few_shot_ablation_summa