## Import

In [1]:
import warnings
# Suppress Hugging Face and other warnings
warnings.filterwarnings("ignore")


In [None]:
from huggingface_hub import login

login("")

In [3]:
import os
import pandas as pd
from langchain.prompts import PromptTemplate
import requests
from transformers import AutoTokenizer, AutoModelForCausalLM,pipeline
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
import joblib  # for saving the embedding model

import re
from tqdm import tqdm

from langchain_huggingface import HuggingFacePipeline

## Loading and Describing Dataset

In [4]:
# 1. Load data from Hugging Face Hub
df_passages = pd.read_parquet(
    "hf://datasets/rag-datasets/rag-mini-bioasq/data/passages.parquet/part.0.parquet"
)
df_test = pd.read_parquet(
    "hf://datasets/rag-datasets/rag-mini-bioasq/data/test.parquet/part.0.parquet"
)

# 2. Reset index cleanly
df_passages = df_passages.reset_index()
df_test = df_test.reset_index()

df_passages = df_passages.rename(columns={'id': 'id'})
df_test = df_test.rename(columns={'id': 'id'})

# 3. Create local save folder
save_path = "dataset/"
if not os.path.exists(save_path):
    os.makedirs(save_path)

# 4. Save locally as Parquet (fast & space-efficient)
df_passages.to_parquet(save_path + "passages.parquet", index=False)
df_test.to_parquet(save_path + "test.parquet", index=False)


In [5]:
# Reload locally to verify
df_passages = pd.read_parquet(save_path + "passages.parquet")
df_test = pd.read_parquet(save_path + "test.parquet")

In [6]:
# Quick checks
print("Passages shape:", df_passages.shape)
print("Test shape:", df_test.shape)

print("\nPassages sample:")
df_passages.head(2)

Passages shape: (40221, 2)
Test shape: (4719, 4)

Passages sample:


Unnamed: 0,id,passage
0,9797,New data on viruses isolated from patients wit...
1,11906,We describe an improved method for detecting d...


In [7]:
print("\nTest sample:")
df_test.head(2)


Test sample:


Unnamed: 0,id,question,answer,relevant_passage_ids
0,0,Is Hirschsprung disease a mendelian or a multi...,"Coding sequence mutations in RET, GDNF, EDNRB,...","[20598273, 6650562, 15829955, 15617541, 230011..."
1,1,List signaling molecules (ligands) that intera...,The 7 known EGFR ligands are: epidermal growt...,"[23821377, 24323361, 23382875, 22247333, 23787..."


In [8]:
# Clean: Remove rows where passage is null or duplicated
df_passages = df_passages.dropna(subset=['passage'])
df_passages = df_passages.drop_duplicates(subset='passage')

# Clean test DataFrame
df_test = df_test.dropna(subset=['question', 'answer'])
df_test = df_test.drop_duplicates(subset='question')

print("Passages:", df_passages.shape)
print("Test Q&A:", df_test.shape)

# Extract both chunks and doc_ids (removing nulls ensures alignment)
chunks = df_passages['passage'].tolist()
doc_ids = df_passages['id'].tolist()

Passages: (27975, 2)
Test Q&A: (4719, 4)


In [9]:
df_passages.head(5)

Unnamed: 0,id,passage
0,9797,New data on viruses isolated from patients wit...
1,11906,We describe an improved method for detecting d...
2,16083,We have studied the effects of curare on respo...
3,23188,Kinetic and electrophoretic properties of 230-...
4,23469,Male Wistar specific-pathogen-free rats aged 2...


In [10]:
df_test.head(10)

Unnamed: 0,id,question,answer,relevant_passage_ids
0,0,Is Hirschsprung disease a mendelian or a multi...,"Coding sequence mutations in RET, GDNF, EDNRB,...","[20598273, 6650562, 15829955, 15617541, 230011..."
1,1,List signaling molecules (ligands) that intera...,The 7 known EGFR ligands are: epidermal growt...,"[23821377, 24323361, 23382875, 22247333, 23787..."
2,2,Is the protein Papilin secreted?,"Yes, papilin is a secreted protein","[21784067, 19297413, 15094122, 7515725, 332004..."
3,3,Are long non coding RNAs spliced?,Long non coding RNAs appear to be spliced thro...,"[22955974, 21622663, 22707570, 22955988, 24285..."
4,4,Is RANKL secreted from the cells?,Receptor activator of nuclear factor κB ligand...,"[22867712, 23827649, 21618594, 23835909, 24265..."
5,5,Does metformin interfere thyroxine absorption?,No. There are not reported data indicating tha...,[26191653]
6,6,Which miRNAs could be used as potential biomar...,"miR-200a, miR-100, miR-141, miR-200b, miR-200c...","[23918241, 23621186, 22246341, 23978303, 23888..."
7,7,Which acetylcholinesterase inhibitors are used...,Pyridostigmine and neostygmine are acetylcholi...,"[21328290, 21133188, 15610702, 20663605, 21815..."
8,8,Has Denosumab (Prolia) been approved by FDA?,"Yes, Denosumab was approved by the FDA in 2010.","[24114694, 22540167, 21129866, 21170699, 23956..."
9,9,List the human genes encoding for the dishevel...,DVL-1 DVL-2 DVL-3,"[16457155, 12883684, 19618470, 23836490, 88173..."


In [11]:
df_test.shape

(4719, 4)

In [12]:
df_passages.head(10)

Unnamed: 0,id,passage
0,9797,New data on viruses isolated from patients wit...
1,11906,We describe an improved method for detecting d...
2,16083,We have studied the effects of curare on respo...
3,23188,Kinetic and electrophoretic properties of 230-...
4,23469,Male Wistar specific-pathogen-free rats aged 2...
5,24032,Tyrosine hydroxylase (TH) and phenylethanolami...
6,30666,Hemolytic anemia is a well-recognized complica...
7,58611,(1) The RNA replicase induced by bacteriophage...
8,61441,Mice were inoculated with human sarcoid tissue...
9,83311,Bleomycin is potentially capable of inducing a...


## Building Faiss Index and saving locally

In [13]:
# =========================
# Build documents
# =========================
print("Building LangChain documents with metadata...")
docs = [
    Document(page_content=text, metadata={"doc_id": doc_id})
    for text, doc_id in zip(chunks, doc_ids)
]

# =========================
# Create FAISS index
# =========================
print("Building FAISS index...")
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector_db = FAISS.from_documents(docs, embedding_model)
print("FAISS index created.")


Building LangChain documents with metadata...
Building FAISS index...
FAISS index created.


In [14]:
# =========================
# Save FAISS index and embedder
# =========================
save_path = "faiss_index_folder"

if not os.path.exists(save_path):
    os.makedirs(save_path)

# Save FAISS index
vector_db.save_local(save_path)
print(f"FAISS index saved to '{save_path}'.")

# Save embedding model
save_path = "embedder_model_folder/"

if not os.path.exists(save_path):
    os.makedirs(save_path)
joblib.dump(embedding_model, os.path.join(save_path, "model.pkl"))
print(f"Embedding model saved to '{save_path}'.")


FAISS index saved to 'faiss_index_folder'.
Embedding model saved to 'embedder_model_folder/'.


In [16]:
import os
import joblib
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings

# =========================
# Load embedder
# =========================
save_path = "faiss_index_folder"
embedder_path = "embedder_model_folder/"

print("Loading embedding model from disk...")
embedding_model = joblib.load(os.path.join(embedder_path, "embedding_model.pkl"))
print("Embedding model loaded successfully.")

# =========================
# Load FAISS index
# =========================
print("Loading FAISS index from disk...")
vector_db_loaded = FAISS.load_local(save_path, embedding_model, allow_dangerous_deserialization=True)
print("FAISS index loaded successfully.")


Loading embedding model from disk...
Embedding model loaded successfully.
Loading FAISS index from disk...
FAISS index loaded successfully.


In [16]:
# =========================
# Test retrieval
# =========================
query = "What causes heart attack?"
results = vector_db_loaded.similarity_search(query, k=3)
for i, doc in enumerate(results, start=1):
    print(f"\nResult {i}:")
    print(f"Doc ID: {doc.metadata['doc_id']}")
    print(f"Content: {doc.page_content[:100]}...")



Result 1:
Doc ID: 17322504
Content: Most cases of sudden cardiac death in young athletes (<35 years) are caused by 
inherited cardiomyop...

Result 2:
Doc ID: 1450882
Content: Sudden death in athletes is a rare but tragic occurrence. Congenital 
cardiovascular abnormalities, ...

Result 3:
Doc ID: 9858396
Content: The athlete projects the ultimate image of well-being in the health status 
spectrum. Nevertheless, ...


## Downloading 'gemma-2b-it model' and saving in local folder

In [17]:
model_name = "google/gemma-2b-it"  # or "google/gemma-7b-it"
save_dir = "saved_models/gemma"

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# Download and save tokenizer & model
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.save_pretrained(save_dir)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto").to("mps")
model.save_pretrained(save_dir)

print(f"✅ Saved {model_name} to {save_dir}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Saved google/gemma-2b-it to saved_models/gemma


In [18]:
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print("Using device:", device)

model_dir = "saved_models/gemma"

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    trust_remote_code=True
).to(device).eval()

print("✅ Gemma loaded locally and ready")


Using device: mps


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Gemma loaded locally and ready


## Creating Gemma model pipeline

In [19]:
# Create HF text-generation pipeline =====
gen_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    num_beams=4,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    device=0 if device.type == "cuda" else -1
)

# Wrap in LangChain's HuggingFacePipeline =====
llm_rewriter = HuggingFacePipeline(pipeline=gen_pipeline)

# Prompt Template =====
rewrite_prompt = PromptTemplate.from_template(
    """Provide several specific rewritten versions of the biomedical question, ranging from broad to precise. 
Output as 'Option 1', 'Option 2', etc.

Question: {question}
Rewritten:"""
)

# Chain =====
rewrite_chain = rewrite_prompt | llm_rewriter


Device set to use mps:0


In [20]:
# 6. Test =====
question = "What are the causes of heart attack?"
result = rewrite_chain.invoke({"question": question})
print(result)

Provide several specific rewritten versions of the biomedical question, ranging from broad to precise. 
Output as 'Option 1', 'Option 2', etc.

Question: What are the causes of heart attack?
Rewritten:
1. What are the risk factors for heart attack?
2. What are the contributing factors to the development of heart disease?
3. What are the underlying causes of cardiovascular events?
4. What are the factors that increase the risk of heart attack?
5. What are the determinants of heart attack risk?


## Main RAG code

In [21]:
def retrieve_top_k(query, k=5):
    results = vector_db.similarity_search(query, k=k)
    return [(doc.page_content, doc.metadata.get("doc_id")) for doc in results]

In [22]:
def rewrite_query(question, preferred_option="Option 2"):
    output = rewrite_chain.invoke({"question": question})
    
    # Ensure we get the actual string from LangChain's output
    if hasattr(output, "content"):  
        output_text = output.content or ""
    else:  
        output_text = str(output)

    for line in output_text.split("\n"):
        if line.strip().startswith(preferred_option):
            return line.split(":", 1)[1].strip()

    # Fallback: return first option or original question
    for line in output_text.split("\n"):
        if line.strip().startswith("Option 1"):
            return line.split(":", 1)[1].strip()

    return question


In [23]:
INTENT_PATTERNS = {
    "symptoms":  re.compile(r"\b(symptom|sign|clinical presentation|manifestation)\b", re.I),
    "causes":    re.compile(r"\b(cause|etiolog|due to|result[s]? from|lead[s]? to|because)\b", re.I),
    "treatments":re.compile(r"\b(treat|therapy|management|intervention|drug|medication)\b", re.I),
    "risks":     re.compile(r"\b(risk factor|risk|predispos|associated with|correlate)\b", re.I),
    "mechanisms":re.compile(r"\b(pathophysiolog|mechanism|how.*work|underlying process)\b", re.I),
    "definition":re.compile(r"\b(what is|define|definition)\b", re.I),
}

def detect_question_intent(question: str) -> str:
    q = question.lower()
    # Priority: explicit patterns
    for intent, pat in INTENT_PATTERNS.items():
        if pat.search(q):
            return intent
    # Heuristics
    if any(w in q for w in ["cause", "etiology", "why"]):
        return "causes"
    if any(w in q for w in ["symptom", "sign", "presentation"]):
        return "symptoms"
    if any(w in q for w in ["treat", "therapy", "manage", "medicat"]):
        return "treatments"
    if any(w in q for w in ["risk", "predispos"]):
        return "risks"
    if any(w in q for w in ["mechanism", "how does it work", "pathophys"]):
        return "mechanisms"
    if q.startswith("what is") or q.startswith("define"):
        return "definition"
    return "general"

_SENT_SPLIT = re.compile(r"(?<=[\.\?\!])\s+")

def split_sentences(text: str):
    text = re.sub(r"\s+", " ", text).strip()
    return [s.strip() for s in _SENT_SPLIT.split(text) if s.strip()]

# ----------- 2) Intent-specific context filtering (precision boost) ----------

INTENT_CUE_SETS = {
    "causes": [
        r"\b(cause|caused by|etiolog\w*|due to|because|results? from|triggered by|lead[s]? to)\b"
    ],
    "symptoms": [
        r"\b(symptom|signs?|presents? with|manifestation)\b",
        r"\b(headache|fever|cough|dyspnea|fatigue|myalgia|rash|nausea|diarrhea|vomit|pain|loss of smell|loss of taste)\b"
    ],
    "treatments": [
        r"\b(treat|therapy|therapies|management|intervention|drug|medication|dose|dosing)\b"
    ],
    "risks": [
        r"\b(risk factor|risk|predispos|associated with|correlat)\b"
    ],
    "mechanisms": [
        r"\b(pathophysiolog\w*|mechanism|biologic\w* process|immune|inflammation|autoimmun\w*)\b"
    ],
    "definition": [
        r"\b(is defined as|refers to|is a|means)\b"
    ],
    "general": []  # keep broad
}

def filter_context_for_intent(context: str, intent: str, max_sents: int = 15) -> str:
    sents = split_sentences(context)
    if not sents:
        return ""
    cues = INTENT_CUE_SETS.get(intent, [])
    if not cues:  # general → keep most informative (length heuristic)
        ranked = sorted(sents, key=len, reverse=True)[:max_sents]
        return " ".join(ranked)

    # Rank sentences by number of cue matches (then by length as tiebreaker)
    compiled = [re.compile(pat, re.I) for pat in cues]
    scored = []
    for s in sents:
        score = sum(1 for pat in compiled if pat.search(s))
        if score > 0:
            scored.append((score, len(s), s))
    if not scored:
        # fallback: return top long sentences rather than empty
        ranked = sorted(sents, key=len, reverse=True)[:min(max_sents, 8)]
        return " ".join(ranked)

    ranked = sorted(scored, key=lambda x: (-x[0], -x[1]))[:max_sents]
    return " ".join(s for _, __, s in ranked)

# ------------------- 3) Intent-aware instruction templates -------------------

INTENT_TEMPLATES = {
    "causes": """You are a biomedical expert.
Using ONLY the context, answer the question by listing the CAUSES/ETIOLOGY explicitly mentioned.
Rules:
- Extract only statements that indicate causation (e.g., "caused by", "due to", "results from").
- If causes are not directly stated, say exactly: "I'm sorry, I cannot answer that question based on the provided information."
- Be concise and structured (bullets).

Context:
{context}

Question:
{question}

Answer:
""",
    "symptoms": """You are a biomedical expert.
Using ONLY the context, list ALL symptoms/signs mentioned, grouped logically (respiratory, neurological, cardiovascular, GI, mental health, etc.). Do not add anything not stated.

Context:
{context}

Question:
{question}

Answer:
""",
    "treatments": """You are a biomedical expert.
Using ONLY the context, summarize evidence-based treatments/management mentioned (drugs, interventions, dose notes if present). If none, use the exact fallback line.

Context:
{context}

Question:
{question}

Answer:
""",
    "risks": """You are a biomedical expert.
Using ONLY the context, list risk factors and associations mentioned. If none are present, use the exact fallback line.

Context:
{context}

Question:
{question}

Answer:
""",
    "mechanisms": """You are a biomedical expert.
Using ONLY the context, explain the pathophysiology/mechanisms mentioned. If mechanisms are not described, use the exact fallback line.

Context:
{context}

Question:
{question}

Answer:
""",
    "definition": """You are a biomedical expert.
Using ONLY the context, provide a crisp definition/description. If no definitional text is present, use the exact fallback line.

Context:
{context}

Question:
{question}

Answer:
""",
    "general": """You are a biomedical expert.
Using ONLY the context, answer as clearly and completely as possible.
- Combine relevant points concisely.
- Do not invent information not in the context.
- If nothing in the context supports an answer, use the exact fallback line.

Context:
{context}

Question:
{question}

Answer:
"""
}

FALLBACK_LINE = "I'm sorry, I cannot answer that question based on the provided information."


In [24]:
@torch.no_grad()
def answer_with_gemma(question, context, tokenizer=None, model=None, device="cpu"):
    """
    Biomedical QA with:
    - intent detection
    - intent-targeted context filtering
    - structured, intent-specific prompting
    - deterministic decoding with safety knobs
    """
    intent = detect_question_intent(question)
    focused_context = filter_context_for_intent(context, intent)

    # If the context is empty after filtering, immediately fallback
    if not focused_context.strip():
        return FALLBACK_LINE

    prompt = INTENT_TEMPLATES[intent].format(context=focused_context, question=question)

    inputs = tokenizer([prompt], return_tensors="pt", padding=True, truncation=True).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=320,
            num_beams=4,              # thorough but not too slow
            do_sample=False,          # deterministic (prevents invalid 'temperature' issue)
            no_repeat_ngram_size=3,
            repetition_penalty=1.05,  # light touch
            length_penalty=0.9,       # keep it concise
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    raw = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Remove echoed prompt if present
    answer = raw.replace(prompt, "").strip()

    # Guardrails: if the model hedges without giving content, or returns empty
    if not answer or answer.lower().startswith("the context does not") or answer.count("cannot answer") > 0:
        # Try a shorter, stricter follow-up once using same generation (no extra call if you prefer)
        # Here we just fall back cleanly:
        return FALLBACK_LINE

    # If user asked for causes but none of the causal cues appear in the answer, fail safe:
    if intent == "causes" and not re.search(r"\b(caused by|due to|results? from|because)\b", answer, re.I):
        # It might still be valid, but to be safe when context is weak:
        # Prefer to return fallback rather than hallucinate.
        # If you want to be less strict, comment this out.
        # Do a soft check: if very short and generic, fallback.
        if len(answer) < 30:
            return FALLBACK_LINE

    return answer


In [25]:
def rag_pipeline(question, k=3):
    # Step 1 – Rewrite query
    rewritten = rewrite_query(question)

    # Step 2 – Retrieve top-k passages (strings)
    retrieved_passages = retrieve_top_k(rewritten, k)  # list[(text, score)] or list[str]
    # Normalize to strings
    passages = [p[0] if isinstance(p, (list, tuple)) else p for p in retrieved_passages]

    # Step 3 – Build raw context
    context = "\n".join(passages)

    # Step 4 – Generate answer
    answer = answer_with_gemma(question, context, tokenizer=tokenizer, model=model, device=device)

    # Step 5 – Structured output
    return {
        "question": question,
        "rewritten": rewritten,
        "context": context,
        "answer": answer
    }


## Examples using Gemma

In [26]:
result = rag_pipeline("What are the symptoms of COVID-19?")
print("Original Question:", result["question"])
print("\n\n")
print("Rewritten:", result["rewritten"])
print("\n\n")
print("Retrieved Context:\n", result["context"])
print("\n\n")
print("Answer:\n", result["answer"])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Original Question: What are the symptoms of COVID-19?



Rewritten: What are the symptoms of COVID-19?



Retrieved Context:
 We aimed this systematic review to analyze and review the currently available 
published literature related to long COVID, understanding its pattern, and 
predicting the long-term effects on survivors. We thoroughly searched the 
databases for relevant articles till May 2021. The research articles that met 
our inclusion and exclusion criteria were assessed and reviewed by two 
independent researchers. After preliminary screening of the identified articles 
through title and abstract, 249 were selected. Consequently, 167 full-text 
articles were assessed and reviewed based on our inclusion criteria and thus 20 
articles were regarded as eligible and analyzed in the present analysis. All the 
studies included adult population aged between 18 and above 60 years. The median 
length of hospital stay of the COVID-19 patients during the acute infection 
phase ranged f

In [27]:
result = rag_pipeline("What is the cause of fever?")
print("Original Question:", result["question"])
print("\n\n")
print("Rewritten:", result["rewritten"])
print("\n\n")
print("Retrieved Context:\n", result["context"])
print("\n\n")
print("Answer:\n", result["answer"])

Original Question: What is the cause of fever?



Rewritten: What is the cause of fever?



Retrieved Context:
 Fever is the most common reason that children and infants are brought to 
emergency departments. Emergency physicians face the challenge of quickly 
distinguishing benign from life-threatening conditions. The management of fever 
in children is guided by the patient's age, immunization status, and immune 
status as well as the results of a careful physical examination and appropriate 
laboratory tests and radiographic views. In this article, the evaluation and 
treatment of children with fevers of known and unknown origin are described. 
Causes of common and dangerous conditions that include fever in their 
manifestation are also discussed.
11,119 patients with scarlet fever admitted in the last sixteen years, from 1973 
to 1988, to Sapporo City General Hospital, were studied statistically on 
symptoms and laboratory findings. The results were summarized as follows: 1. 
Annua

In [28]:
result = rag_pipeline("Who is the president of US?")
print("Original Question:", result["question"])
print("\n\n")
print("Rewritten:", result["rewritten"])
print("\n\n")
print("Retrieved Context:\n", result["context"])
print("\n\n")
print("Answer:\n", result["answer"])

Original Question: Who is the president of US?



Rewritten: Who is the president of US?



Retrieved Context:
 There are 219 virus species that are known to be able to infect humans. The 
first of these to be discovered was yellow fever virus in 1901, and three to 
four new species are still being found every year. Extrapolation of the 
discovery curve suggests that there is still a substantial pool of undiscovered 
human virus species, although an apparent slow-down in the rate of discovery of 
species from different families may indicate bounds to the potential range of 
diversity. More than two-thirds of human viruses can also infect non-human 
hosts, mainly mammals, and sometimes birds. Many specialist human viruses also 
have mammalian or avian origins. Indeed, a substantial proportion of mammalian 
viruses may be capable of crossing the species barrier into humans, although 
only around half of these are capable of being transmitted by humans and around 
half again of transmitti