In [1]:
!pip install --upgrade pip
!pip install groq faiss-cpu ragas datasets langchain langchain-community langchain-groq langchain-text-splitters tiktoken scikit-learn pyarrow fastparquet
# Some RAGAS setups prefer pandas<3:
# %pip install 'pandas<3'


Collecting pip
  Using cached pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Using cached pip-25.2-py3-none-any.whl (1.8 MB)


ERROR: To modify pip, please run the following command:
C:\Users\anish\anaconda3\python.exe -m pip install --upgrade pip




In [2]:
import os
from pathlib import Path

# ===== Inputs =====
DATA_PATH          = "medqa_cleaned.csv"               # your cleaned dataset CSV
EVAL_SELECTION_CSV = "medquad_selected_questions.csv" # SAME questions CSV used for sparse
ANSWER_FIELD       = "answer"     # do NOT rename your CSV on disk
TEXT_FIELD         = "answer"     # index this column (use 'source_text' if you have longer passages)

# ===== Chunking =====
CHUNK_TOKENS   = 480
CHUNK_OVERLAP  = 80
TOKEN_ENCODING = "cl100k_base"   # sizing proxy

# ===== Embeddings / FAISS persistence =====
EMBED_MODEL      = "BAAI/bge-small-en-v1.5"   # 384-dim, fast
CHUNK_CACHE      = Path("medqa_chunks_token.parquet")   # cache of chunks (id/text/meta)
FAISS_INDEX_PATH = Path("faiss_medqa_bge_small.index")  # FAISS index file (persistent)
MAPPING_PATH     = Path("medqa_chunk_mapping.parquet")  # id->text/meta in the SAME order as FAISS

# Streaming FAISS build knobs
EMBED_BATCH = 64     # smaller = lower RAM + more progress prints
SAVE_EVERY  = 20     # checkpoint FAISS to disk every N batches

# ===== Retrieval / LLM =====
TOP_K         = 3
# Read the key from the environment so it never has to be typed into (and saved
# with) the notebook. Falls back to "" when the variable is unset, which is the
# value this constant had before.
GROQ_API_KEY  = os.environ.get("GROQ_API_KEY", "")
GEN_MODEL     = "llama-3.3-70b-versatile"
PRINT_PROMPTS = True

# ===== Outputs =====
ANSWERS_CSV   = "dense_rag_faiss_answers.csv"
METRICS_JSON  = "dense_rag_faiss_ragas.json"

# ===== Repro =====
SEED = 42


In [3]:
import os

# Tame native threads (optional, helps stability in some Windows/Anaconda setups).
# These are set BEFORE numpy/pandas are imported: BLAS/OpenMP runtimes generally
# read them when the native library is first loaded, so assigning them after the
# import (as this cell used to) would not limit the already-loaded libraries.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
for _thread_var in (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
    "VECLIB_MAXIMUM_THREADS",
    "NUMEXPR_NUM_THREADS",
):
    os.environ[_thread_var] = "1"

import re, json, time, random, math
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple

# Seed both the stdlib and the legacy numpy global RNG for reproducibility.
random.seed(SEED); np.random.seed(SEED)


In [4]:
# Load the Q/A dataset and make sure the columns the pipeline relies on exist.
df = pd.read_csv(DATA_PATH)
assert "question" in df.columns, "CSV must have 'question'"
assert ANSWER_FIELD in df.columns, f"CSV must have '{ANSWER_FIELD}'"
assert TEXT_FIELD   in df.columns, f"CSV must have '{TEXT_FIELD}'"

# Normalise the text columns to stripped strings. Missing cells are blanked
# first: astype(str) on its own turns NaN into the literal text "nan", which
# would then be chunked and indexed as if it were real content.
df = df.copy()
for _text_col in dict.fromkeys(["question", ANSWER_FIELD, TEXT_FIELD]):  # de-duplicated, order kept
    df[_text_col] = df[_text_col].fillna("").astype(str).str.strip()

print("Loaded rows:", len(df))
print("Columns:", list(df.columns))
df.head(3)


Loaded rows: 16018
Columns: ['question', 'answer', 'qtype']


Unnamed: 0,question,answer,qtype
0,Who is at risk for Lymphocytic Choriomeningiti...,LCMV infections can occur after exposure to fr...,susceptibility
1,What are the symptoms of Lymphocytic Choriomen...,LCMV is most commonly recognized as causing ne...,symptoms
2,Who is at risk for Lymphocytic Choriomeningiti...,Individuals of all ages who come into contact ...,susceptibility


In [5]:
from langchain_text_splitters import TokenTextSplitter

# Splitter configured from the chunking constants: chunk size and overlap are
# given in tokens, counted with the TOKEN_ENCODING tokenizer.
_splitter_settings = dict(
    encoding_name=TOKEN_ENCODING,
    chunk_size=CHUNK_TOKENS,
    chunk_overlap=CHUNK_OVERLAP,
)
token_splitter = TokenTextSplitter(**_splitter_settings)

def make_chunks_from_df(df, text_col: str, splitter=None):
    """Split every row's text into chunks and return three parallel lists.

    Args:
        df: DataFrame holding the text to index.
        text_col: Name of the column whose values are split.
        splitter: Object exposing ``split_text(str) -> list[str]``. Defaults to
            the module-level ``token_splitter``.

    Returns:
        ``(chunk_ids, chunk_texts, chunk_meta)`` where ids look like
        ``r{row_index}_ck{chunk_index}`` and each meta entry is a dict with
        ``row_index`` and ``chunk_index``. A row whose text yields no chunks
        still contributes one chunk containing the (possibly empty) text.
    """
    if splitter is None:
        splitter = token_splitter

    chunk_texts, chunk_ids, chunk_meta = [], [], []
    # Iterate the one column we need instead of materialising a Series per row.
    for row_idx, raw in df[text_col].items():
        # NaN is truthy, so `raw or ""` alone would let it through as "nan".
        is_missing = pd.api.types.is_scalar(raw) and pd.isna(raw)
        text = "" if is_missing else str(raw or "")
        chunks = splitter.split_text(text) or [text]
        for ci, ch in enumerate(chunks):
            chunk_texts.append(ch)
            chunk_ids.append(f"r{row_idx}_ck{ci}")
            chunk_meta.append({"row_index": int(row_idx), "chunk_index": int(ci)})
    return chunk_ids, chunk_texts, chunk_meta

def save_chunks_parquet(chunk_ids, chunk_texts, chunk_meta, path: Path):
    """Write the chunk table to ``path`` as parquet.

    Columns are ``id``, ``text`` and ``meta``; each meta entry is stored as a
    JSON string. The DataFrame index is not written.
    """
    serialized_meta = list(map(json.dumps, chunk_meta))
    table = pd.DataFrame({"id": chunk_ids, "text": chunk_texts, "meta": serialized_meta})
    table.to_parquet(path, index=False)

def load_chunks_parquet(path: Path):
    """Read a chunk table from ``path`` and return ``(ids, texts, metas)``.

    The three results are parallel lists; the ``meta`` column is decoded from
    JSON strings back into Python objects.
    """
    table = pd.read_parquet(path)
    ids = table["id"].tolist()
    texts = table["text"].tolist()
    metas = list(map(json.loads, table["meta"].tolist()))
    return ids, texts, metas

# Build the chunk lists once and cache them; later runs read the cache instead.
if not CHUNK_CACHE.exists():
    print("Chunking with TokenTextSplitter…")
    chunk_ids, chunk_texts, chunk_meta = make_chunks_from_df(df, TEXT_FIELD)
    print("Saving chunk cache:", CHUNK_CACHE)
    save_chunks_parquet(chunk_ids, chunk_texts, chunk_meta, CHUNK_CACHE)
else:
    print("Loading chunks from cache:", CHUNK_CACHE)
    chunk_ids, chunk_texts, chunk_meta = load_chunks_parquet(CHUNK_CACHE)

print(f"Total chunks: {len(chunk_texts)} (from {len(df)} rows)")
# Preview at most 300 characters of the first chunk, marking truncation.
if chunk_texts:
    first_chunk = chunk_texts[0]
    sample_preview = first_chunk[:300] + ("…" if len(first_chunk) > 300 else "")
else:
    sample_preview = "<no chunks>"
print("Sample chunk:", sample_preview)


Chunking with TokenTextSplitter…
Saving chunk cache: medqa_chunks_token.parquet
Total chunks: 18559 (from 16018 rows)
Sample chunk: LCMV infections can occur after exposure to fresh urine, droppings, saliva, or nesting materials from infected rodents. Transmission may also occur when these materials are directly introduced into broken skin, the nose, the eyes, or the mouth, or presumably, via the bite of an infected rodent. Pers…


In [6]:
# === ROBUST + RESUMABLE STREAMING FAISS BUILD ===
# - Embeds in micro-batches to avoid long “silent” stalls
# - Resumes from existing FAISS index (continues from index.ntotal)
# - Saves checkpoints frequently
# - Lots of progress prints so you always see movement

import faiss, time, json, numpy as np, pandas as pd
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

# ---- knobs you can tweak safely ----
EMBED_BATCH   = 64   # outer batch (group of micro-batches)
MICRO_BATCH   = 8    # embed this many texts per call (keeps each call quick)
SAVE_EVERY_MB = 40   # save index every N micro-batches
PRINT_EVERY_MB= 5    # print a progress line every N micro-batches

assert len(chunk_texts) == len(chunk_ids) == len(chunk_meta), "Chunks not prepared. Run the chunking cell first."
TOTAL = len(chunk_texts)
if TOTAL == 0:
    # Fail with a clear message; the dimension probe below reads chunk_texts[0].
    raise ValueError("No chunks to index. Run the chunking cell on a non-empty dataset first.")

# Save mapping parquet ONCE (id -> text/meta in FAISS order)
if not MAPPING_PATH.exists():
    mapping_df = pd.DataFrame({
        "id": chunk_ids,
        "text": chunk_texts,
        "meta": [json.dumps(m) for m in chunk_meta],
    })
    mapping_df.to_parquet(MAPPING_PATH, index=False)
    print(f"[FAISS BUILD] Wrote mapping parquet: {MAPPING_PATH}")
else:
    # Retrieval looks FAISS positions up in this file, so a mapping left over
    # from a different chunking run must not be silently paired with vectors
    # built from the current chunks.
    existing_ids = pd.read_parquet(MAPPING_PATH, columns=["id"])["id"].tolist()
    if existing_ids != chunk_ids:
        raise RuntimeError(
            f"{MAPPING_PATH} does not match the chunks in memory "
            f"({len(existing_ids)} ids on disk vs {len(chunk_ids)} in memory). "
            f"Delete {MAPPING_PATH} and {FAISS_INDEX_PATH} to rebuild from scratch."
        )
    print(f"[FAISS BUILD] Mapping parquet exists: {MAPPING_PATH}")

# Create embedder once
embedder = FastEmbedEmbeddings(model_name=EMBED_MODEL)

# Probe embedding dim (quick): embed a single text and read the vector width.
probe_vec = np.asarray(embedder.embed_documents([chunk_texts[0] or ""]), dtype="float32")
dim = probe_vec.shape[1]
print(f"[FAISS BUILD] Embedding dim: {dim}")

# Cosine via inner product: normalize vectors
def l2_normalize(mat: np.ndarray) -> np.ndarray:
    """Divide each row of a 2-D array by its Euclidean norm.

    A tiny constant (1e-12) is added to every row norm before dividing, so an
    all-zero row stays zero instead of producing NaN.
    """
    row_lengths = np.linalg.norm(mat, axis=1, keepdims=True)
    return mat / (row_lengths + 1e-12)

# Create or resume index
if FAISS_INDEX_PATH.exists():
    index = faiss.read_index(str(FAISS_INDEX_PATH))
    print(f"[FAISS BUILD] Resuming from existing index: {FAISS_INDEX_PATH}  (ntotal={index.ntotal})")
    # A resumable build is only valid if the file on disk belongs to THIS run's
    # embedding model and chunk list; refuse to continue on an obvious mismatch.
    if index.d != dim:
        raise ValueError(
            f"Existing index has dimension {index.d} but {EMBED_MODEL} produces {dim}. "
            f"Delete {FAISS_INDEX_PATH} to rebuild."
        )
    if index.ntotal > TOTAL:
        raise ValueError(
            f"Existing index holds {index.ntotal} vectors but only {TOTAL} chunks are loaded. "
            f"Delete {FAISS_INDEX_PATH} to rebuild."
        )
else:
    index = faiss.IndexFlatIP(dim)
    print("[FAISS BUILD] Starting a fresh FAISS index (IndexFlatIP).")

# Resume point = how many vectors already added
start_i = int(index.ntotal)
if start_i >= TOTAL:
    print(f"[FAISS BUILD] Nothing to do. Index already has all {TOTAL} vectors.")
else:
    print(f"[FAISS BUILD] Will embed & add from position {start_i} to {TOTAL} (remaining: {TOTAL-start_i}).")

# Bookkeeping for the progress and summary prints of the build loop.
t0 = time.time()
micro_counter = 0
added_before  = index.ntotal

# Helper: embed a small list of texts quickly and add
def embed_and_add(texts):
    """Embed ``texts``, L2-normalise the vectors and append them to the global ``index``."""
    raw_vectors = embedder.embed_documents(texts)
    unit_vectors = l2_normalize(np.asarray(raw_vectors, dtype="float32"))
    index.add(unit_vectors)

# Main loop: outer batches purely for readable progress, inner loop does tiny calls
i = start_i
outer_batch_id = 0
try:
    while i < TOTAL:
        outer_batch_id += 1
        batch_end = min(i + EMBED_BATCH, TOTAL)
        # Process this batch in MICRO_BATCH chunks
        j = i
        while j < batch_end:
            micro_end = min(j + MICRO_BATCH, batch_end)
            micro_texts = chunk_texts[j:micro_end]

            t_mb = time.time()
            embed_and_add(micro_texts)
            dt_mb = time.time() - t_mb

            micro_counter += 1
            if (micro_counter % PRINT_EVERY_MB) == 0:
                print(f"[FAISS BUILD] micro {micro_counter:>5} | added={index.ntotal:>7} / {TOTAL} | last_micro={dt_mb:.2f}s")

            if (micro_counter % SAVE_EVERY_MB) == 0:
                faiss.write_index(index, str(FAISS_INDEX_PATH))
                print(f"[FAISS BUILD]   checkpoint saved -> {FAISS_INDEX_PATH} (ntotal={index.ntotal})")

            j = micro_end

        # End of outer batch: summarize
        print(f"[FAISS BUILD] Batch {outer_batch_id}  range=[{i}:{batch_end})  batch_added={batch_end - i}  total={index.ntotal}/{TOTAL}")
        i = batch_end
except BaseException:
    # Kernel interrupt or failure mid-build: persist what is already in the
    # index so the next run resumes from index.ntotal rather than from the last
    # periodic checkpoint. Vectors are appended in chunk order, so the saved
    # prefix stays aligned with the mapping.
    faiss.write_index(index, str(FAISS_INDEX_PATH))
    print(f"[FAISS BUILD] Interrupted. Partial index saved -> {FAISS_INDEX_PATH} (ntotal={index.ntotal})")
    raise

# Final save
faiss.write_index(index, str(FAISS_INDEX_PATH))
print(f"[FAISS BUILD] ✅ Done. Added {index.ntotal - added_before} new vectors. Total={index.ntotal}/{TOTAL}.")
print(f"[FAISS BUILD] Elapsed: {time.time()-t0:.1f}s | Index saved at: {FAISS_INDEX_PATH}")


[FAISS BUILD] Wrote mapping parquet: medqa_chunk_mapping.parquet


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


config.json:   0%|          | 0.00/706 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


special_tokens_map.json:   0%|          | 0.00/695 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

model_optimized.onnx:   0%|          | 0.00/66.5M [00:00<?, ?B/s]

[FAISS BUILD] Embedding dim: 384
[FAISS BUILD] Starting a fresh FAISS index (IndexFlatIP).
[FAISS BUILD] Will embed & add from position 0 to 18559 (remaining: 18559).
[FAISS BUILD] micro     5 | added=     40 / 18559 | last_micro=4.44s
[FAISS BUILD] Batch 1  range=[0:64)  batch_added=64  total=64/18559
[FAISS BUILD] micro    10 | added=     80 / 18559 | last_micro=4.70s
[FAISS BUILD] micro    15 | added=    120 / 18559 | last_micro=2.35s
[FAISS BUILD] Batch 2  range=[64:128)  batch_added=64  total=128/18559
[FAISS BUILD] micro    20 | added=    160 / 18559 | last_micro=4.02s
[FAISS BUILD] Batch 3  range=[128:192)  batch_added=64  total=192/18559
[FAISS BUILD] micro    25 | added=    200 / 18559 | last_micro=4.20s
[FAISS BUILD] micro    30 | added=    240 / 18559 | last_micro=3.68s
[FAISS BUILD] Batch 4  range=[192:256)  batch_added=64  total=256/18559
[FAISS BUILD] micro    35 | added=    280 / 18559 | last_micro=4.04s
[FAISS BUILD] micro    40 | added=    320 / 18559 | last_micro=2.57

In [7]:
# You can re-run JUST this cell (and the ones below) for new experiments — no rebuild needed.

import json, numpy as np, pandas as pd, faiss
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

# Load persistent assets
index = faiss.read_index(str(FAISS_INDEX_PATH))
mapping_table = pd.read_parquet(MAPPING_PATH)
chunk_ids, chunk_texts = mapping_table["id"].tolist(), mapping_table["text"].tolist()
chunk_meta = list(map(json.loads, mapping_table["meta"].tolist()))
# Positions returned by the index are looked up in these lists, so sizes must agree.
assert index.ntotal == len(chunk_ids), "Index/mapping size mismatch"

# Embedder used for queries (same EMBED_MODEL constant as the build cell)
q_embedder = FastEmbedEmbeddings(model_name=EMBED_MODEL)

def retrieve_topk_faiss(query: str, k: int = TOP_K):
    """Return up to ``k`` nearest chunks for ``query`` as ``(chunk_id, text, score)`` tuples.

    The query is embedded with ``q_embedder``, scaled to unit length and searched
    against the global ``index``; results keep the order the index returns.
    Fewer than ``k`` tuples come back when the index cannot supply ``k`` hits.
    """
    q = np.asarray(q_embedder.embed_query(query), dtype="float32")[None, :]
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    sims, idxs = index.search(q, k)  # inner product on unit vectors == cosine
    rows = [
        (chunk_ids[i], chunk_texts[i], float(s))
        for i, s in zip(idxs[0].tolist(), sims[0].tolist())
        # FAISS fills missing results with label -1; without this filter
        # chunk_ids[-1] would quietly return the LAST chunk as a "hit".
        if i >= 0
    ]
    return rows

# Smoke test
# Use the first dataset question when one is loaded, else a fixed fallback query.
if len(df):
    q0 = df.iloc[0]["question"]
else:
    q0 = "What is hypertension?"
print("Sample query:", q0[:120], "…")
top_hits = retrieve_topk_faiss(q0, k=2)
print("Top IDs:", [hit[0] for hit in top_hits])


Sample query: Who is at risk for Lymphocytic Choriomeningitis (LCM)? ? …
Top IDs: ['r5_ck1', 'r2542_ck0']


In [8]:
from groq import Groq

# An empty GROQ_API_KEY is passed as None so the SDK can fall back to the
# GROQ_API_KEY environment variable (per the Groq SDK docs) instead of being
# handed a blank credential that only fails once a request is sent.
client_groq = Groq(api_key=GROQ_API_KEY or None)

def build_rag_messages(question: str, contexts: List[str], print_prompt: bool = True) -> List[Dict]:
    """Assemble the system + user chat messages for one RAG question.

    Args:
        question: The question to answer.
        contexts: Retrieved passages; they are numbered from 1 in the prompt.
        print_prompt: When true, echo both messages to stdout between rules.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts (system, then user).
    """
    system_txt = (
        "You are a concise, evidence-focused medical assistant. "
        "Use the provided context passages to answer accurately. "
        "If the context does not contain the answer, say you don't know."
    )

    numbered_passages = []
    for position, passage in enumerate(contexts, start=1):
        numbered_passages.append(f"[Context {position}]\n{passage}")
    ctx_block = "\n\n".join(numbered_passages)

    user_txt = (
        f"Question: {question}\n\n"
        f"Context Passages:\n{ctx_block}\n\n"
        "Instructions: Answer in 2–4 sentences. Cite which Context numbers support your statements (e.g., [1], [2]). "
        "If insufficient evidence, say 'I don't know based on the given context.'"
    )

    messages = [
        dict(role="system", content=system_txt),
        dict(role="user", content=user_txt),
    ]

    if print_prompt:
        rule = "=" * 88
        print("\n" + rule)
        print("[RAG PROMPT]")
        print("\n[SYSTEM]\n" + system_txt)
        print("\n[USER]\n" + user_txt)
        print(rule)
    return messages

def groq_chat(messages, model=GEN_MODEL, temperature=0.0, max_tokens=512, top_p=1.0) -> str:
    """Send ``messages`` to the Groq chat-completions endpoint and return the reply text.

    Returns the first choice's message content with surrounding whitespace
    removed, or an empty string when the response carries no text content.
    """
    r = client_groq.chat.completions.create(
        model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p, messages=messages
    )
    content = r.choices[0].message.content
    # Guard against a missing content field so the evaluation loop gets "" rather
    # than an AttributeError on None.strip(). NOTE(review): confirm the SDK can
    # actually return None here.
    return (content or "").strip()


In [9]:
sel = pd.read_csv(EVAL_SELECTION_CSV)
# Reference answers: prefer a 'gold' column, otherwise fall back to ANSWER_FIELD.
gold_col = next((name for name in ("gold", ANSWER_FIELD) if name in sel.columns), None)
assert gold_col is not None, f"Selection file must contain either 'gold' or '{ANSWER_FIELD}'"

# Evaluate only the first 3 rows of the selection file.
eval_df = sel.loc[:, ["question", gold_col]].head(3).copy()
print("Evaluating SAME questions used in sparse:", len(eval_df))
print(eval_df["question"].to_string(index=False))


Evaluating SAME questions used in sparse: 3
              Do you have information about X-Rays
What are the symptoms of Alpha-ketoglutarate de...
What are the treatments for GLUT1 deficiency sy...


In [10]:
# For each evaluation question: retrieve TOP_K chunks, prompt the generator and
# record question / answer / contexts / reference for scoring.
rows = []
for raw_question, raw_gold in zip(eval_df["question"], eval_df[gold_col]):
    q = str(raw_question).strip()
    gt = str(raw_gold).strip()   # the source column is only read, never renamed

    contexts = [passage for _chunk_id, passage, _score in retrieve_topk_faiss(q, k=TOP_K)]
    ans = groq_chat(
        build_rag_messages(q, contexts, print_prompt=PRINT_PROMPTS),
        model=GEN_MODEL,
        temperature=0.0,
        max_tokens=512,
        top_p=1.0,
    )

    rows.append(dict(question=q, answer=ans, contexts=contexts, ground_truth=gt))

rag_results_df = pd.DataFrame(rows)
print("Collected rows:", len(rag_results_df))
rag_results_df.head(2)



[RAG PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Use the provided context passages to answer accurately. If the context does not contain the answer, say you don't know.

[USER]
Question: Do you have information about X-Rays

Context Passages:
[Context 1]
 signs of disease. X ray. An x ray is a picture created by using radiation and recorded on film or on a computer. The amount of radiation used is small. An x-ray technician performs the x ray at a hospital or an outpatient center, and a radiologista doctor who specializes in medical imaginginterprets the images. Anesthesia is not needed. The patient will lie on a table or stand during the x ray. The technician positions the x-ray machine over the spine area to look for "butterfly" vertebrae. The patient will hold his or her breath as the picture is taken so that the picture will not be blurry. The patient may be asked to change position for additional pictures. Abdominal ultrasound. Ultrasound uses a devic

Unnamed: 0,question,answer,contexts,ground_truth
0,Do you have information about X-Rays,X-rays are a type of electromagnetic wave that...,[ signs of disease. X ray. An x ray is a pictu...,Summary : X-rays are a type of radiation calle...
1,What are the symptoms of Alpha-ketoglutarate d...,The symptoms of Alpha-ketoglutarate dehydrogen...,[What are the signs and symptoms of Alpha-keto...,What are the signs and symptoms of Alpha-ketog...


In [11]:
from datasets import Dataset
from ragas import evaluate

# Import metrics; answer_correctness may not exist in older ragas
# The four core metrics are required; answer_correctness is optional and
# HAVE_CORR records whether it could be imported.
from ragas.metrics import (
    faithfulness,
    answer_relevancy,
    context_precision,
    context_recall,
)

try:
    from ragas.metrics import answer_correctness
    HAVE_CORR = True
except Exception:
    HAVE_CORR = False

from langchain_groq import ChatGroq
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings as LC_FastEmbedEmbeddings

# Dataset, judge LLM and embeddings handed to ragas.evaluate().
_ragas_columns = ["question", "answer", "contexts", "ground_truth"]
ragas_ds = Dataset.from_pandas(rag_results_df[_ragas_columns].copy())
ragas_llm = ChatGroq(
    groq_api_key=GROQ_API_KEY,
    model_name="llama-3.1-8b-instant",
    temperature=0.0,
    max_retries=6,
    request_timeout=60,
)
ragas_embeddings = LC_FastEmbedEmbeddings(model_name=EMBED_MODEL)

def eval_metric_with_backoff(metric, tries=5, wait_seconds=70):
    """Run one RAGAS metric over ``ragas_ds``, retrying when ``evaluate`` raises.

    Args:
        metric: The metric object passed to ``evaluate`` (as a one-element list).
        tries: Maximum number of attempts.
        wait_seconds: Pause between consecutive attempts.

    Returns:
        Whatever ``evaluate`` returns on the first successful attempt.

    Raises:
        RuntimeError: when every attempt raised; chained to the last error.
    """
    last_err = None
    for t in range(1, tries + 1):
        try:
            # NOTE(review): raise_exceptions=False presumably makes ragas swallow
            # per-sample failures, so this retry only fires when evaluate()
            # itself raises. Confirm that is the intended behaviour.
            return evaluate(
                dataset=ragas_ds,
                metrics=[metric],
                llm=ragas_llm,
                embeddings=ragas_embeddings,
                is_async=False,
                raise_exceptions=False,
            )
        except Exception as e:
            last_err = e
            if t == tries:
                # No attempt follows, so do not waste wait_seconds sleeping.
                print(f"[RAGAS] Attempt {t}/{tries} failed, giving up: {e}")
                break
            print(f"[RAGAS] Retry {t}/{tries} after {wait_seconds}s due to: {e}")
            time.sleep(wait_seconds)
    raise RuntimeError(f"RAGAS metric failed after {tries} retries") from last_err

metrics = [faithfulness, answer_relevancy, context_precision, context_recall] + (
    [answer_correctness] if HAVE_CORR else []
)

def _coerce_score(value):
    """Return ``value`` as a float if possible (directly, then via ``str``), else unchanged."""
    for convert in (float, lambda raw: float(str(raw))):
        try:
            return convert(value)
        except Exception:
            continue
    return value

# One evaluate() call per metric; the first key of each report names the score.
scores = {}
for m in metrics:
    rep = eval_metric_with_backoff(m)
    key = next(iter(rep.keys()))
    scores[key] = _coerce_score(rep[key])

print("\n=== RAGAS METRICS (averages) ===")
order = ["faithfulness", "answer_relevancy", "context_precision", "context_recall"] + (
    ["answer_correctness"] if HAVE_CORR else []
)
for k in order:
    if k not in scores:
        continue
    v = scores[k]
    if isinstance(v, float):
        print(f"{k}: {v:.3f}")
    else:
        print(f"{k}: {v}")

# Hallucination is reported as the complement of faithfulness when that is numeric.
faithfulness_score = scores.get("faithfulness")
if isinstance(faithfulness_score, float):
    print(f"hallucination (1 - faithfulness): {1.0 - faithfulness_score:.3f}")
else:
    print("hallucination (1 - faithfulness): N/A")



For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from ragas.metrics._answer_correctness import AnswerCorrectness, answer_correctness

For example, replace imports like: `from langchain.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from ragas.metrics._context_entities_recall import (


Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]


=== RAGAS METRICS (averages) ===
faithfulness: 0.905
answer_relevancy: 0.582
context_precision: 1.000
context_recall: 0.609
answer_correctness: 0.635
hallucination (1 - faithfulness): 0.095


In [12]:
# Persist the generated answers (CSV) and the metric scores (JSON).
rag_results_df.to_csv(ANSWERS_CSV, index=False)
metrics_text = json.dumps(scores, indent=2)
with open(METRICS_JSON, "w", encoding="utf-8") as f:
    f.write(metrics_text)

for label, target in (
    ("Saved answers to:", ANSWERS_CSV),
    ("Saved RAGAS metrics to:", METRICS_JSON),
):
    print(label, target)


Saved answers to: dense_rag_faiss_answers.csv
Saved RAGAS metrics to: dense_rag_faiss_ragas.json


### Ablation summary

In [1]:
import os
from pathlib import Path

# ===== Inputs =====

EVAL_SELECTION_CSV = "medquad_selected_questions.csv" # SAME questions CSV used for sparse
ANSWER_FIELD       = "answer"     # do NOT rename your CSV on disk


# ===== Embeddings / FAISS persistence =====
EMBED_MODEL      = "BAAI/bge-small-en-v1.5"   # 384-dim, fast
FAISS_INDEX_PATH = Path("faiss_medqa_bge_small.index")  # FAISS index file (persistent)
MAPPING_PATH     = Path("medqa_chunk_mapping.parquet")  # id->text/meta in the SAME order as FAISS


# ===== Retrieval / LLM =====
# === K values to test (you can change) ===
TOP_K_LIST = [3, 5, 10]   # keep modest to avoid rate limits
TOP_K         = 3
# Read the key from the environment so it never has to be typed into (and saved
# with) the notebook. Falls back to "" when the variable is unset, which is the
# value this constant had before.
GROQ_API_KEY  = os.environ.get("GROQ_API_KEY", "")
GEN_MODEL     = "llama-3.3-70b-versatile"
PRINT_PROMPTS = True

# ===== Outputs =====
OUT_DIR = "."




### Load FAISS + mapping (no rebuild) and set up retrieval

In [2]:
import json, numpy as np, pandas as pd, faiss
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

# Both persisted artefacts must already exist; this notebook never rebuilds them.
for required_path in (FAISS_INDEX_PATH, MAPPING_PATH):
    assert required_path.exists(), f"Missing {required_path}"

index = faiss.read_index(str(FAISS_INDEX_PATH))
mapping_table = pd.read_parquet(MAPPING_PATH)
chunk_ids, chunk_texts = mapping_table["id"].tolist(), mapping_table["text"].tolist()
chunk_meta = list(map(json.loads, mapping_table["meta"].tolist()))
# Positions returned by the index are looked up in these lists, so sizes must agree.
assert index.ntotal == len(chunk_ids), "Index/mapping size mismatch"

# Embedder used for queries (same EMBED_MODEL constant as the index build)
q_embedder = FastEmbedEmbeddings(model_name=EMBED_MODEL)

def retrieve_topk_faiss(query: str, k: int):
    """Return up to ``k`` nearest chunks for ``query`` as ``(chunk_id, text, score)`` tuples.

    The query is embedded with ``q_embedder``, scaled to unit length and searched
    against the global ``index``; results keep the order the index returns.
    Fewer than ``k`` tuples come back when the index cannot supply ``k`` hits.
    """
    q = np.asarray(q_embedder.embed_query(query), dtype="float32")[None, :]
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-12)
    sims, idxs = index.search(q, k)  # cosine via IP on unit vectors
    return [
        (chunk_ids[i], chunk_texts[i], float(s))
        for i, s in zip(idxs[0].tolist(), sims[0].tolist())
        # FAISS fills missing results with label -1; without this filter
        # chunk_ids[-1] would quietly return the LAST chunk as a "hit".
        if i >= 0
    ]


In [3]:
### Load SAME questions CSV

In [4]:
sel = pd.read_csv(EVAL_SELECTION_CSV)
# Reference answers: prefer a 'gold' column, otherwise fall back to ANSWER_FIELD.
gold_col = next((name for name in ("gold", ANSWER_FIELD) if name in sel.columns), None)
assert gold_col is not None, f"Selection file must contain either 'gold' or '{ANSWER_FIELD}'"

# Only the first 3 rows are used; a small set helps avoid rate limits.
eval_df = sel.loc[:, ["question", gold_col]].head(3).copy()
print("Ablation questions:", len(eval_df))
print(eval_df["question"].to_string(index=False))


Ablation questions: 3
              Do you have information about X-Rays
What are the symptoms of Alpha-ketoglutarate de...
What are the treatments for GLUT1 deficiency sy...


### Prompt builder + Groq chat

In [5]:
from typing import List, Dict
from groq import Groq

# An empty GROQ_API_KEY is passed as None so the SDK can fall back to the
# GROQ_API_KEY environment variable (per the Groq SDK docs) instead of being
# handed a blank credential that only fails once a request is sent.
client_groq = Groq(api_key=GROQ_API_KEY or None)

def build_rag_messages(question: str, contexts: List[str], print_prompt: bool = False) -> List[Dict]:
    """Assemble the system + user chat messages for one RAG question.

    Args:
        question: The question to answer.
        contexts: Retrieved passages; they are numbered from 1 in the prompt.
        print_prompt: When true, echo both messages to stdout between rules.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts (system, then user).
    """
    system_txt = (
        "You are a concise, evidence-focused medical assistant. "
        "Use the provided context passages to answer accurately. "
        "If the context does not contain the answer, say you don't know."
    )

    numbered_passages = []
    for position, passage in enumerate(contexts, start=1):
        numbered_passages.append(f"[Context {position}]\n{passage}")
    ctx_block = "\n\n".join(numbered_passages)

    user_txt = (
        f"Question: {question}\n\n"
        f"Context Passages:\n{ctx_block}\n\n"
        "Instructions: Answer in 2–4 sentences. Cite which Context numbers support your statements (e.g., [1], [2]). "
        "If insufficient evidence, say 'I don't know based on the given context.'"
    )

    msgs = [
        dict(role="system", content=system_txt),
        dict(role="user", content=user_txt),
    ]

    if print_prompt:
        rule = "=" * 88
        print("\n" + rule)
        print("[RAG PROMPT]")
        print("\n[SYSTEM]\n" + system_txt)
        print("\n[USER]\n" + user_txt)
        print(rule)
    return msgs

def groq_chat(messages, model=GEN_MODEL, temperature=0.0, max_tokens=512, top_p=1.0) -> str:
    """Send ``messages`` to the Groq chat-completions endpoint and return the reply text.

    Returns the first choice's message content with surrounding whitespace
    removed, or an empty string when the response carries no text content.
    """
    r = client_groq.chat.completions.create(
        model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p, messages=messages
    )
    content = r.choices[0].message.content
    # Guard against a missing content field so the ablation loop gets "" rather
    # than an AttributeError on None.strip(). NOTE(review): confirm the SDK can
    # actually return None here.
    return (content or "").strip()


### Run ablation (loop over K), evaluate with RAGAS, and aggregate

In [6]:
from datasets import Dataset
from ragas import evaluate

# metrics (answer_correctness may not exist in older ragas)
# The four core metrics are required; answer_correctness is optional and
# HAVE_CORR records whether it could be imported.
from ragas.metrics import faithfulness, answer_relevancy, context_precision, context_recall

try:
    from ragas.metrics import answer_correctness
    HAVE_CORR = True
except Exception:
    HAVE_CORR = False

from langchain_groq import ChatGroq
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings as LC_FastEmbedEmbeddings

# Judge LLM and embeddings handed to ragas.evaluate() for every K.
_judge_settings = dict(
    groq_api_key=GROQ_API_KEY,
    model_name="llama-3.1-8b-instant",
    temperature=0.0,
    max_retries=6,
    request_timeout=60,
)
ragas_llm = ChatGroq(**_judge_settings)
ragas_embeddings = LC_FastEmbedEmbeddings(model_name=EMBED_MODEL)

def eval_with_backoff(ragas_ds, metric, tries=5, wait_seconds=70):
    """Run one RAGAS metric over ``ragas_ds``, retrying when ``evaluate`` raises.

    Args:
        ragas_ds: Dataset passed to ``evaluate``.
        metric: The metric object passed to ``evaluate`` (as a one-element list).
        tries: Maximum number of attempts.
        wait_seconds: Pause between consecutive attempts.

    Returns:
        Whatever ``evaluate`` returns on the first successful attempt.

    Raises:
        RuntimeError: when every attempt raised; chained to the last error.
    """
    import time  # local import: this notebook has no top-level `import time`

    last_err = None
    for t in range(1, tries + 1):
        try:
            # NOTE(review): raise_exceptions=False presumably makes ragas swallow
            # per-sample failures, so this retry only fires when evaluate()
            # itself raises. Confirm that is the intended behaviour.
            return evaluate(
                dataset=ragas_ds,
                metrics=[metric],
                llm=ragas_llm,
                embeddings=ragas_embeddings,
                is_async=False,
                raise_exceptions=False,
            )
        except Exception as e:
            last_err = e
            if t == tries:
                # No attempt follows, so do not waste wait_seconds sleeping.
                print(f"[RAGAS] Attempt {t}/{tries} failed, giving up: {e}")
                break
            print(f"[RAGAS] Retry {t}/{tries} after {wait_seconds}s due to: {e}")
            time.sleep(wait_seconds)
    raise RuntimeError("RAGAS metric failed") from last_err

import json

# One entry per K: {"TOP_K": K, <metric name>: <float score>, ...}
ablation_rows = []
banner = "=" * 88
summary_metric_names = [
    "faithfulness", "answer_relevancy", "context_precision", "context_recall", "answer_correctness",
]

for K in TOP_K_LIST:
    print("\n" + banner)
    print(f"[ABLATION] Running Dense RAG with TOP_K={K}")
    print(banner)

    # --- generate answers with K retrieved passages per question ---
    rows = []
    for raw_question, raw_gold in zip(eval_df["question"], eval_df[gold_col]):
        q = str(raw_question).strip()
        gt = str(raw_gold).strip()

        contexts = [passage for _chunk_id, passage, _score in retrieve_topk_faiss(q, k=K)]
        messages = build_rag_messages(q, contexts, print_prompt=PRINT_PROMPTS)
        ans = groq_chat(messages, model=GEN_MODEL, temperature=0.0, max_tokens=512, top_p=1.0)

        rows.append(dict(question=q, answer=ans, contexts=contexts, ground_truth=gt))

    run_df = pd.DataFrame(rows)
    answers_csv_k = f"{OUT_DIR}/dense_rag_k{K}_answers.csv"
    run_df.to_csv(answers_csv_k, index=False)
    print(f"[ABLATION] Saved answers: {answers_csv_k}")

    # --- score this K with RAGAS, one metric per evaluate() call ---
    ragas_ds = Dataset.from_pandas(run_df[["question","answer","contexts","ground_truth"]].copy())
    metrics = [faithfulness, answer_relevancy, context_precision, context_recall] + (
        [answer_correctness] if HAVE_CORR else []
    )

    scores = {}
    for m in metrics:
        rep = eval_with_backoff(ragas_ds, m)
        key = next(iter(rep.keys()))
        raw_score = rep[key]
        # Prefer a float; fall back to parsing its string form, then to the raw value.
        try:
            scores[key] = float(raw_score)
        except Exception:
            try:
                scores[key] = float(str(raw_score))
            except Exception:
                scores[key] = raw_score

    metrics_json_k = f"{OUT_DIR}/dense_rag_k{K}_ragas.json"
    with open(metrics_json_k, "w", encoding="utf-8") as f:
        json.dump(scores, f, indent=2)
    print(f"[ABLATION] Saved RAGAS metrics: {metrics_json_k}")

    # Only float-valued scores make it into the summary row.
    row = {"TOP_K": K}
    row.update({
        kname: scores[kname]
        for kname in summary_metric_names
        if isinstance(scores.get(kname), float)
    })
    ablation_rows.append(row)

# Aggregate results: one row per K, best faithfulness first.
ablate_df = pd.DataFrame(ablation_rows)
# A run only contributes "faithfulness" when its score was a float; if no run
# did, the column is absent and sorting by it would raise KeyError.
if "faithfulness" in ablate_df.columns:
    ablate_df = ablate_df.sort_values(by="faithfulness", ascending=False)
summary_csv = f"{OUT_DIR}/dense_rag_ablation_summary.csv"
ablate_df.to_csv(summary_csv, index=False)
print("\n=== ABLATION SUMMARY (higher is better) ===")
print(ablate_df.to_string(index=False))
print("Saved ablation summary to:", summary_csv)



For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from ragas.metrics._answer_correctness import AnswerCorrectness, answer_correctness

For example, replace imports like: `from langchain.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from ragas.metrics._context_entities_recall import (



[ABLATION] Running Dense RAG with TOP_K=3

[RAG PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Use the provided context passages to answer accurately. If the context does not contain the answer, say you don't know.

[USER]
Question: Do you have information about X-Rays

Context Passages:
[Context 1]
 signs of disease. X ray. An x ray is a picture created by using radiation and recorded on film or on a computer. The amount of radiation used is small. An x-ray technician performs the x ray at a hospital or an outpatient center, and a radiologista doctor who specializes in medical imaginginterprets the images. Anesthesia is not needed. The patient will lie on a table or stand during the x ray. The technician positions the x-ray machine over the spine area to look for "butterfly" vertebrae. The patient will hold his or her breath as the picture is taken so that the picture will not be blurry. The patient may be asked to change position for additional pictures. Ab

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

[ABLATION] Saved RAGAS metrics: ./dense_rag_k3_ragas.json

[ABLATION] Running Dense RAG with TOP_K=5

[RAG PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Use the provided context passages to answer accurately. If the context does not contain the answer, say you don't know.

[USER]
Question: Do you have information about X-Rays

Context Passages:
[Context 1]
 signs of disease. X ray. An x ray is a picture created by using radiation and recorded on film or on a computer. The amount of radiation used is small. An x-ray technician performs the x ray at a hospital or an outpatient center, and a radiologista doctor who specializes in medical imaginginterprets the images. Anesthesia is not needed. The patient will lie on a table or stand during the x ray. The technician positions the x-ray machine over the spine area to look for "butterfly" vertebrae. The patient will hold his or her breath as the picture is taken so that the picture will not be blurry. The patient m

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

[ABLATION] Saved RAGAS metrics: ./dense_rag_k5_ragas.json

[ABLATION] Running Dense RAG with TOP_K=10

[RAG PROMPT]

[SYSTEM]
You are a concise, evidence-focused medical assistant. Use the provided context passages to answer accurately. If the context does not contain the answer, say you don't know.

[USER]
Question: Do you have information about X-Rays

Context Passages:
[Context 1]
 signs of disease. X ray. An x ray is a picture created by using radiation and recorded on film or on a computer. The amount of radiation used is small. An x-ray technician performs the x ray at a hospital or an outpatient center, and a radiologista doctor who specializes in medical imaginginterprets the images. Anesthesia is not needed. The patient will lie on a table or stand during the x ray. The technician positions the x-ray machine over the spine area to look for "butterfly" vertebrae. The patient will hold his or her breath as the picture is taken so that the picture will not be blurry. The patient 

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Runner in Executor raised an exception
Traceback (most recent call last):
  File "C:\Users\anish\anaconda3\Lib\site-packages\ragas\executor.py", line 78, in _aresults
    r = await future
        ^^^^^^^^^^^^
  File "C:\Users\anish\anaconda3\Lib\asyncio\tasks.py", line 615, in _wait_for_one
    return f.result()  # May raise f.exception().
           ^^^^^^^^^^
  File "C:\Users\anish\anaconda3\Lib\site-packages\ragas\executor.py", line 37, in sema_coro
    return await coro
           ^^^^^^^^^^
  File "C:\Users\anish\anaconda3\Lib\site-packages\ragas\executor.py", line 111, in wrapped_callable_async
    return counter, await callable(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\anish\anaconda3\Lib\site-packages\ragas\metrics\base.py", line 125, in ascore
    raise e
  File "C:\Users\anish\anaconda3\Lib\site-packages\ragas\metrics\base.py", line 121, in ascore
    score = await self._ascore(row=row, callbacks=group_cm, is_async=is_async)
       

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

[ABLATION] Saved RAGAS metrics: ./dense_rag_k10_ragas.json

=== ABLATION SUMMARY (higher is better) ===
 TOP_K  faithfulness  answer_relevancy  context_precision  context_recall  answer_correctness
     3      0.874510          0.564356           1.000000        0.608974            0.626136
    10      0.809524          0.872146           0.795488        0.516746            0.605371
     5      0.734300          0.895798           1.000000        0.833333            0.639776
Saved ablation summary to: ./dense_rag_ablation_summary.csv
