In [1]:
!pip install --upgrade pip
!pip install faiss-cpu sentence-transformers pandas groq datasets ragas langchain_groq langchain_community


Collecting pip
  Using cached pip-25.2-py3-none-any.whl.metadata (4.7 kB)
Using cached pip-25.2-py3-none-any.whl (1.8 MB)


ERROR: To modify pip, please run the following command:
C:\Users\anish\anaconda3\python.exe -m pip install --upgrade pip




In [2]:
import os, warnings
warnings.filterwarnings("ignore")

# Keep tokenizers & BLAS quiet and single-threaded to avoid kernel crashes.
# Thread-count variables like these are generally read when the native libraries
# initialise, so this cell should run before numpy / faiss / fastembed are imported.
_SINGLE_THREAD_ENV = {
    "TOKENIZERS_PARALLELISM": "false",
    "OMP_NUM_THREADS": "1",
    "MKL_NUM_THREADS": "1",
    "OPENBLAS_NUM_THREADS": "1",
    "NUMEXPR_NUM_THREADS": "1",
}
os.environ.update(_SINGLE_THREAD_ENV)


In [3]:
### Configuration (saved FAISS index + chunk mapping, FastEmbed query encoder, same questions CSV)

In [4]:
# ===== Hard-coded artifacts you shared =====
from pathlib import Path

# ----- Saved retrieval artifacts (index and mapping must come from the same build) -----
FAISS_INDEX_PATH = Path("faiss_medqa_bge_small.index")  # FAISS index file
MAPPING_PATH = Path("medqa_chunk_mapping.parquet")       # id/text/meta rows, in index order

# Chunk table kept for reference only; it is not required for this run.
CHUNKS_PARQUET = Path("medqa_chunks_token.parquet")

# ----- Retrieval settings: these must mirror how the index was built -----
EMBED_MODEL = "BAAI/bge-small-en-v1.5"   # same BGE-small encoder used at build time
USE_COSINE_IP = True                      # L2-normalise the query before inner-product search
USE_BGE_QUERY_INSTRUCTION = True          # set False if the build did NOT use the query prefix
BGE_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

TOP_K_LIST = [3]   # retrieval depths to run; e.g. extend to [3, 5, 10]

# ----- Evaluation questions (the same CSV used for prompting) -----
EVAL_SELECTION_CSV = "medquad_selected_questions.csv"
ANSWER_FIELD = "answer"   # gold-answer column used when the CSV has no 'gold' column
N_EVAL = 3                # number of questions to evaluate; lower it if rate-limited

# ===== Groq LLM (only this changes for your experiment) =====
from groq import Groq
# Read the key from the environment rather than hard-coding it: a key pasted here would be
# saved into the .ipynb and leak whenever the notebook is shared. Empty string if unset.
import os
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
GEN_MODEL    = "llama-3.3-70b-versatile"     # <— just change this to test another Groq LLM

PRINT_PROMPTS = False   # True echoes every RAG prompt before it is sent

# ----- RAGAS evaluation (set RUN_RAGAS_EVAL = False to skip it and avoid rate limits) -----
RUN_RAGAS_EVAL = True
RAGAS_JUDGE_MODEL = "llama-3.1-8b-instant"   # Groq model acting as the RAGAS judge
RAGAS_EMBED_MODEL = EMBED_MODEL              # FastEmbed model for RAGAS (same as retrieval)

# ----- Output -----
OUT_DIR = "."   # where the answers CSV and RAGAS summary are written


In [5]:
### Keep the kernel light: re-apply the single-threaded env settings (same as the top environment cell)

In [6]:
import os, warnings
warnings.filterwarnings("ignore")

# Same single-thread settings as the environment cell at the top of the notebook;
# assigning identical values again has no additional effect.
for _var, _value in (
    ("TOKENIZERS_PARALLELISM", "false"),
    ("OMP_NUM_THREADS", "1"),
    ("MKL_NUM_THREADS", "1"),
    ("OPENBLAS_NUM_THREADS", "1"),
    ("NUMEXPR_NUM_THREADS", "1"),
):
    os.environ[_var] = _value


In [7]:
### load questions (same CSV)

In [8]:
import pandas as pd

sel = pd.read_csv(EVAL_SELECTION_CSV)

# Gold answers: prefer an explicit 'gold' column, otherwise fall back to ANSWER_FIELD.
gold_col = "gold" if "gold" in sel.columns else (ANSWER_FIELD if ANSWER_FIELD in sel.columns else None)
assert gold_col is not None, f"Selection file must contain either 'gold' or '{ANSWER_FIELD}'"
assert "question" in sel.columns, f"Selection file must contain a 'question' column. Found: {list(sel.columns)}"

# Drop rows missing a question or a gold answer: the RAG loop calls str() on both,
# which would turn NaN into the literal text "nan" and send/score it as real data.
eval_df = (
    sel[["question", gold_col]]
    .dropna(subset=["question", gold_col])
    .reset_index(drop=True)
)
if isinstance(N_EVAL, int):
    eval_df = eval_df.head(N_EVAL)

print(f"Eval questions: {len(eval_df)}")
print(eval_df['question'].to_string(index=False))


Eval questions: 3
              Do you have information about X-Rays
What are the symptoms of Alpha-ketoglutarate de...
What are the treatments for GLUT1 deficiency sy...


In [9]:
### load FAISS + mapping (no reindexing, no re-chunking)

In [10]:
import faiss, json

assert FAISS_INDEX_PATH.exists(), f"Missing {FAISS_INDEX_PATH}"
assert MAPPING_PATH.exists(), f"Missing {MAPPING_PATH}"

index = faiss.read_index(str(FAISS_INDEX_PATH))
try:
    faiss.omp_set_num_threads(1)  # keep FAISS single-threaded — fewer crashes
except Exception:
    pass  # best effort only; retrieval still works with FAISS's default threading

# retrieve_topk_faiss normalises queries when USE_COSINE_IP is set, which presumes an
# inner-product index; with any other metric the returned scores are not cosine similarities.
if USE_COSINE_IP and index.metric_type != faiss.METRIC_INNER_PRODUCT:
    print(f"[WARN] USE_COSINE_IP=True but index metric_type={index.metric_type} "
          f"(inner product is {faiss.METRIC_INNER_PRODUCT}); check how the index was built.")

# Descriptive name so later cells (e.g. loops over metrics) cannot silently overwrite it.
mapping_df = pd.read_parquet(MAPPING_PATH)

# Accept either {id,text,meta} or similar column names:
ID_COL_CAND    = ["id","chunk_id","doc_id"]
TEXT_COL_CAND  = ["text","chunk_text","content","passage","body","doc"]
META_COL_CAND  = ["meta","metadata"]

id_col   = next((c for c in ID_COL_CAND if c in mapping_df.columns), None)
text_col = next((c for c in TEXT_COL_CAND if c in mapping_df.columns), None)
meta_col = next((c for c in META_COL_CAND if c in mapping_df.columns), None)

assert id_col and text_col, f"Mapping parquet must have id/text. Found: {list(mapping_df.columns)}"

# Row i of the mapping describes vector i of the index (size equality is asserted below).
chunk_ids   = mapping_df[id_col].astype(str).tolist()
chunk_texts = mapping_df[text_col].astype(str).tolist()
if meta_col:
    raw_meta = mapping_df[meta_col].tolist()
    try:
        # Metadata may be stored as JSON strings; decode those and keep other values as-is.
        chunk_meta = [json.loads(x) if isinstance(x, str) else x for x in raw_meta]
    except ValueError:
        # json.JSONDecodeError is a ValueError: any undecodable entry -> keep the raw column.
        chunk_meta = raw_meta
else:
    chunk_meta = [{} for _ in range(len(chunk_ids))]

print("FAISS ntotal:", index.ntotal, "| mapping rows:", len(chunk_ids))
assert index.ntotal == len(chunk_ids), "Index/mapping size mismatch — they must be built together."


FAISS ntotal: 18559 | mapping rows: 18559


In [11]:
### FastEmbed (same as your build) + retrieval helper

In [12]:
import numpy as np
from langchain_community.embeddings.fastembed import FastEmbedEmbeddings

# CPU query encoder via FastEmbed (downloads the model once; much lighter than
# sentence-transformers). It must be the same model that embedded the indexed chunks.
q_embedder = FastEmbedEmbeddings(model_name=EMBED_MODEL)

def _format_query(q: str) -> str:
    """Return the text that should be embedded for query `q`.

    When USE_BGE_QUERY_INSTRUCTION is set, BGE_QUERY_PREFIX is prepended; otherwise
    `q` is returned unchanged.
    """
    return BGE_QUERY_PREFIX + q if USE_BGE_QUERY_INSTRUCTION else q

def retrieve_topk_faiss(query: str, k: int):
    """Return up to `k` nearest chunks for `query` from the loaded FAISS index.

    Args:
        query: raw question text; the optional BGE instruction is added by `_format_query`.
        k: number of neighbours to request; a value <= 0 yields an empty list.

    Returns:
        List of (chunk_id, chunk_text, score) tuples in FAISS rank order. `score` is the raw
        FAISS value (a similarity for an inner-product index, a distance for an L2 index).
    """
    if k <= 0:
        return []  # avoid asking FAISS for a non-positive number of neighbours
    qtext = _format_query(query)
    # embed_query returns a single vector; FAISS expects a float32 batch of shape (1, dim).
    q_vec = np.asarray(q_embedder.embed_query(qtext), dtype="float32")[None, :]
    if USE_COSINE_IP:
        # Normalize for cosine via inner product (epsilon guards against a zero vector).
        q_vec /= (np.linalg.norm(q_vec, axis=1, keepdims=True) + 1e-12)
    sims, idxs = index.search(q_vec, k)   # both shaped (1, k)
    # FAISS fills missing result slots with index -1; skip them.
    return [
        (chunk_ids[i], chunk_texts[i], float(s))
        for i, s in zip(idxs[0].tolist(), sims[0].tolist())
        if i >= 0
    ]


In [13]:
### RAG prompt + Groq chat (same template you used)

In [14]:
from typing import List, Dict
from groq import Groq

# An empty key is passed as None so the Groq SDK falls back to the GROQ_API_KEY
# environment variable (and raises a clear error if neither is set).
client_groq = Groq(api_key=GROQ_API_KEY or None)

def build_rag_messages(question: str, contexts: List[str], print_prompt: bool = False) -> List[Dict]:
    """Assemble the chat messages for one grounded RAG query.

    Args:
        question: the user question, inserted verbatim into the user message.
        contexts: retrieved passages; each is labelled "[Context n]" with n starting at 1.
        print_prompt: when True, echo the system and user prompts to stdout.

    Returns:
        Two messages: the system instructions followed by the user prompt.
    """
    system_txt = "".join([
        "You are a concise, evidence-focused medical assistant. ",
        "Use the provided context passages to answer accurately. ",
        "If the context does not contain the answer, say you don't know.",
    ])

    ctx_block = "\n\n".join(
        f"[Context {n}]\n{passage}" for n, passage in enumerate(contexts, start=1)
    )
    instructions = (
        "Instructions: Answer in 2–4 sentences. Cite which Context numbers support your statements (e.g., [1], [2]). "
        "If insufficient evidence, say 'I don't know based on the given context.'"
    )
    user_txt = f"Question: {question}\n\nContext Passages:\n{ctx_block}\n\n{instructions}"

    if print_prompt:
        rule = "=" * 88
        for line in ("\n" + rule, "[RAG PROMPT]", "\n[SYSTEM]\n" + system_txt, "\n[USER]\n" + user_txt, rule):
            print(line)

    return [
        {"role": "system", "content": system_txt},
        {"role": "user", "content": user_txt},
    ]

def groq_chat(messages, model=None, temperature=0.0, max_tokens=512, top_p=1.0, client=None) -> str:
    """Send `messages` to a Groq chat model and return the reply text.

    Args:
        messages: chat messages, e.g. as produced by `build_rag_messages`.
        model: Groq model id; None means the current GEN_MODEL, looked up at call time
            so re-running the config cell is enough to switch models.
        temperature, max_tokens, top_p: forwarded unchanged to the completion request.
        client: object exposing `chat.completions.create`; None means `client_groq`.

    Returns:
        The first choice's message text with surrounding whitespace stripped, or ""
        when the API returns no text content.
    """
    if model is None:
        model = GEN_MODEL
    if client is None:
        client = client_groq
    r = client.chat.completions.create(
        model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p, messages=messages
    )
    # message.content is typed Optional in the SDK; guard so .strip() cannot hit None.
    content = r.choices[0].message.content
    return (content or "").strip()


In [15]:
### Run Dense RAG against the loaded index (no re-indexing) and save the answers for each TOP_K

In [16]:
RAG_RESULT_COLUMNS = ["question", "answer", "contexts", "ground_truth", "top_k"]


def run_dense_rag(k: int) -> pd.DataFrame:
    """Answer every question in `eval_df` with Dense RAG at retrieval depth `k`.

    Returns one row per question with the columns the RAGAS cell reads
    (question / answer / contexts / ground_truth) plus the `top_k` used. The columns
    are fixed so an empty `eval_df` still yields a frame with the expected schema.
    """
    rows = []
    # zip over the two columns: same row order as iterrows, without building a Series per row.
    for question, gold in zip(eval_df["question"], eval_df[gold_col]):
        q = str(question).strip()
        gt = str(gold).strip()

        hits = retrieve_topk_faiss(q, k=k)
        contexts = [txt for (_id, txt, _score) in hits]
        msgs = build_rag_messages(q, contexts, print_prompt=PRINT_PROMPTS)
        ans = groq_chat(msgs, model=GEN_MODEL, temperature=0.0, max_tokens=512, top_p=1.0)

        rows.append({"question": q, "answer": ans, "contexts": contexts, "ground_truth": gt, "top_k": k})
    return pd.DataFrame(rows, columns=RAG_RESULT_COLUMNS)


# Fail with a clear message instead of pandas' "No objects to concatenate".
assert TOP_K_LIST, "TOP_K_LIST is empty — nothing to run."

rows_all = []
for K in TOP_K_LIST:
    print("\n" + "="*88)
    print(f"[RUN] Dense RAG with TOP_K={K}  |  GEN_MODEL={GEN_MODEL}")
    print("="*88)

    run_df = run_dense_rag(K)
    out_csv = f"{OUT_DIR}/dense_rag_k{K}_answers.csv"
    # Note: to_csv stores each `contexts` list as its string representation.
    run_df.to_csv(out_csv, index=False)
    print(f"[RUN] Saved answers: {out_csv}  (rows={len(run_df)})")
    rows_all.append(run_df)

rag_results_df = pd.concat(rows_all, ignore_index=True)
rag_results_df.head(2)



[RUN] Dense RAG with TOP_K=3  |  GEN_MODEL=llama-3.3-70b-versatile
[RUN] Saved answers: ./dense_rag_k3_answers.csv  (rows=3)


Unnamed: 0,question,answer,contexts,ground_truth,top_k
0,Do you have information about X-Rays,X-rays are a type of radiation called electrom...,[ signs of disease. X ray. An x ray is a pictu...,Summary : X-rays are a type of radiation calle...,3
1,What are the symptoms of Alpha-ketoglutarate d...,The symptoms of Alpha-ketoglutarate dehydrogen...,[What are the signs and symptoms of Alpha-keto...,What are the signs and symptoms of Alpha-ketog...,3


In [17]:
### RAGAS with FastEmbed (robust, one metric at a time w/ retry)

In [18]:
import time

RAGAS_MAX_ATTEMPTS = 5  # evaluate() attempts per metric before recording NaN
RAGAS_METRIC_COLUMNS = ["faithfulness", "answer_relevancy", "context_precision", "context_recall", "answer_correctness"]


def _ragas_metric_name(metric) -> str:
    """Summary-column name for a RAGAS metric object (str() fallback if it has no `.name`)."""
    return getattr(metric, "name", str(metric))


def _score_one_metric(ds, metric, K, llm, embeddings):
    """Score one metric on `ds`, retrying with a growing pause when evaluate() raises.

    Uses `evaluate` imported from ragas inside the `if RUN_RAGAS_EVAL` body below.
    With raise_exceptions=False, per-sample failures come back as NaN scores rather
    than exceptions, so retries only cover errors raised by evaluate() itself.

    Returns (column_name, score); the score is NaN when every attempt raised.
    """
    for attempt in range(RAGAS_MAX_ATTEMPTS):
        try:
            rep = evaluate(dataset=ds, metrics=[metric], llm=llm, embeddings=embeddings,
                           is_async=False, raise_exceptions=False)
            key = next(iter(rep.keys()))
            try:
                value = float(rep[key])
            except (TypeError, ValueError):
                value = float(str(rep[key]))
            return key, value
        except Exception as e:
            print(f"[RAGAS][K={K}] retry {attempt+1}/{RAGAS_MAX_ATTEMPTS} due to: {e}")
            if attempt < RAGAS_MAX_ATTEMPTS - 1:  # no pointless sleep after the last attempt
                time.sleep(30 + attempt * 20)
    # Record the failure under the metric's real name so its column survives in the summary.
    return _ragas_metric_name(metric), float("nan")


if RUN_RAGAS_EVAL:
    try:
        from datasets import Dataset
        from ragas import evaluate
        # answer_correctness is not importable in every RAGAS version; include it when available.
        try:
            from ragas.metrics import faithfulness, answer_relevancy, context_precision, context_recall, answer_correctness
            METRICS = [faithfulness, answer_relevancy, context_precision, context_recall, answer_correctness]
        except Exception:
            from ragas.metrics import faithfulness, answer_relevancy, context_precision, context_recall
            METRICS = [faithfulness, answer_relevancy, context_precision, context_recall]

        from langchain_groq import ChatGroq
        from langchain_community.embeddings.fastembed import FastEmbedEmbeddings as LC_FastEmbedEmbeddings

        # Pass the key only when one is configured; otherwise ChatGroq reads the
        # GROQ_API_KEY environment variable instead of receiving an empty string.
        groq_key_kwargs = {"groq_api_key": GROQ_API_KEY} if GROQ_API_KEY else {}
        ragas_llm = ChatGroq(
            **groq_key_kwargs, model_name=RAGAS_JUDGE_MODEL,
            temperature=0.0, max_retries=6, request_timeout=60
        )
        ragas_embeddings = LC_FastEmbedEmbeddings(model_name=RAGAS_EMBED_MODEL)

        # evaluate per K to keep loads small and avoid 429s
        ablation_rows = []
        for K in sorted(rag_results_df["top_k"].unique()):
            sub = (
                rag_results_df.loc[rag_results_df["top_k"] == K, ["question", "answer", "contexts", "ground_truth"]]
                # RangeIndex so Dataset.from_pandas does not store the filtered index as an extra column
                .reset_index(drop=True)
            )
            ds = Dataset.from_pandas(sub)

            # One metric per evaluate() call keeps each burst of judge requests small.
            scores = {}
            for metric in METRICS:
                name, value = _score_one_metric(ds, metric, K, ragas_llm, ragas_embeddings)
                scores[name] = value

            row = {"TOP_K": int(K)}
            row.update({name: scores[name] for name in RAGAS_METRIC_COLUMNS if name in scores})
            ablation_rows.append(row)

        ablate_df = pd.DataFrame(ablation_rows)
        # Sorting by a missing column would raise KeyError and discard the whole summary.
        if "faithfulness" in ablate_df.columns:
            ablate_df = ablate_df.sort_values(by="faithfulness", ascending=False)
        out_summary = f"{OUT_DIR}/dense_rag_ragas_summary.csv"
        ablate_df.to_csv(out_summary, index=False)
        print("\n=== RAGAS SUMMARY (Dense RAG) ===")
        print(ablate_df.to_string(index=False))
        print("Saved:", out_summary)

    except Exception as e:
        # Deliberate best effort: evaluation problems must not break the rest of the notebook.
        print("[RAGAS WARNING] Skipped due to error:", e)



For example, replace imports like: `from langchain_core.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from ragas.metrics._answer_correctness import AnswerCorrectness, answer_correctness

For example, replace imports like: `from langchain.pydantic_v1 import BaseModel`
with: `from pydantic import BaseModel`
or the v1 compatibility namespace if you are working in a code base that has not been fully upgraded to pydantic 2 yet. 	from pydantic.v1 import BaseModel

  from ragas.metrics._context_entities_recall import (


Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]


=== RAGAS SUMMARY (Dense RAG) ===
 TOP_K  faithfulness  answer_relevancy  context_precision  context_recall  answer_correctness
     3      0.916667          0.901494                1.0        0.608974             0.69494
Saved: ./dense_rag_ragas_summary.csv
