<a href="https://colab.research.google.com/github/PhongCT1105/RecSys_MAG/blob/main/CS554_Assignment_2_(Student).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **CS 554 - Assignment 2: Retrieval-Augmented Generation for Biomedical Question Answering  (100 pts)**

In this assignment, you will continue working with Gemma-3-1B, the same model we used for the previous assignment, but now applied in a more realistic setting where retrieval is essential. Instead of SQuAD-style QA with answering short, factoid questions directly from the model, we will shift to a domain-specific question answering using the BioASQ dataset, which consists of complex questions and information dense answers, making it closer to real-world scientific Q&A in a field like medicine where answers must be near-perfect.

Though you can use your own laptop if GPUs are available, we strongly encourage using Google Colab on the L4 GPU. If the session is inactive for a certain period, **the runtime will disconnect**, and there's a very good chance **you will lose your progress**.

### **Submission**

Please submit your work on Canvas. You can either:

- Upload your Jupyter notebook file (.ipynb), or

- Provide a link to your notebook (e.g., Google Colab).

Make sure the notebook runs end-to-end without errors and includes all required outputs and answers.

In [14]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [8]:
# Install required packages
!pip install faiss-cpu langchain langchain-community langchain-openai pandas python-dotenv langchain_huggingface bert_score rouge_score datasets evaluate langchain_experimental



## **Task 1: Understanding Evaluation Metrics with Complex Q&A Tasks (16 PTS)**

In this section, most of the code to setting up the [`Gemma-3-1B`](https://huggingface.co/google/gemma-3-1b-it) model and the [`rag_mini_bioasq`](
https://huggingface.co/datasets/rag-datasets/rag-mini-bioasq/) dataset have been given to you. Complete the following tasks and discussion question below:

1. **[DISCUSSION]** Run through the provided code under "`1.1: Dataset Loading`". Notice the question difficulties and answer depths. Looking at the Q&A pairs, discuss why the evaluation metrics we used in Assignment 1, Exact Match (EM) and F1 scores, are not enough. (4 pts)

2. **[DISCUSSION]** Read the following article on [NLP Model Evaluation](https://plainenglish.io/blog/evaluating-nlp-models-a-comprehensive-guide-to-rouge-bleu-meteor-and-bertscore-metrics-d0f1b1). Looking at BERTScore, discuss how it works, its purpose, and its strengths and weaknesses.  (4 pts)

3. **[DISCUSSION]** Evaluate the correctness of the Q&A pairs with model predictions outputted in the given code in "`1.3: BERTScore Evaluation`". Since the answers are highly domain-specific and difficult to understand, feel free to use well-designed AI chatbots like ChatGPT to assess the answers for this part only. (2 pts)

4. **[CODE]** Evaluate the predicted_results_only against answer_5 using BERTScore, at the bottom of "`1.3: BERTScore Evaluation`". You can reference this [link](https://docs.google.com/spreadsheets/d/1RKOVpselB98Nnh_EOC4A2BYn8_201tmPODpNWu4w7xI/edit?gid=0#gid=0) for model you want to use.(4 pts)

5. **[DISCUSSION]** Compare the actual correctness of machine outputs vs. the metrics provided by BERTScore. Is BERTScore really enough? Why or why not? (2 pts)


### **1.1：Dataset Loading (Provided Code)**

Run the following cell to load the **BioASQ biomedical Q&A dataset**.  

This code prepares the evaluation split and selects the first 5 examples for testing.  

In [9]:
from datasets import load_dataset

class PrettyList(list):
    def __repr__(self):
        lines = []
        for i, item in enumerate(self, start=1):
            lines.append(f"{i}. {item}")
        return "\n\n".join(lines)

def load_bioasq_dataset():
  ds_bio = load_dataset("enelpol/rag-mini-bioasq", "question-answer-passages")
  bio_corpus = load_dataset("enelpol/rag-mini-bioasq", "text-corpus")
  eval_data = ds_bio["test"]
  questions = [item["question"] for item in eval_data]
  answers = [item["answer"] for item in eval_data]
  return ds_bio, bio_corpus, questions, answers

ds_bio, bio_corpus, questions, answers=load_bioasq_dataset()

# Load the first 5 test datas for evaluation
questions5 = PrettyList(questions[:5])
answers5   = PrettyList(answers[:5])

In [3]:
questions5

1. Is capmatinib effective for glioblastoma?

2. Describe the mechanism of action of ibalizumab.

3. What is the function of Neu5Gc (N-Glycolylneuraminic acid)?

4. What is the mechanism of action of Inclisiran?

5. What is F105-P?

In [4]:
answers5

1. No. Combination of capmatinib buparlisib resulted in no clear activity in patients with recurrent PTEN-deficient glioblastoma.

2. Ibalizumab is a humanized monoclonal antibody that acts as post-attachment inhibitor by binding CD4 2nd domain of T lymphocyte and preventing HIV connection to CCR5 or CXCR4. It has been recently approved by Food and Drug Administration as a new intravenous antiretroviral agent for heavily treated HIV adults with multi -drug resistant infection.

3. N-glycolylneuraminic acid (Neu5Gc) is an immunogenic sugar of dietary origin that metabolically incorporates into diverse native glycoconjugates in humans.  Humans lack a functional cytidine monophosphate-N-acetylneuraminic acid hydroxylase (CMAH) protein and cannot synthesize the sugar Neu5Gc, an innate mammalian signal of self. N-Glycolylneuraminic acid (Neu5Gc) can be incorporated in human cells and can trigger immune response, a response that is diverse and polyclonal. As dietary Neu5Gc is primarily found

### **1.2: Model Loading (Provided Code)**

In this step, we initialize the **Gemma-3-1B** model for text generation.  

We use Hugging Face’s `transformers` library to load the tokenizer and model, and wrap them in a simple text-generation `pipeline`.  

The `langchain` wrapper (`HuggingFacePipeline`) allows us to call the model consistently inside later RAG components.  

We configure the pipeline to produce **deterministic outputs** (`do_sample=False`, `temperature=0.0`) with a cap of **256 new tokens**, so results are reproducible and stay within resource limits.

For those interested, check out Hugging Face's [Guide on Pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) for more detail.

In [15]:
## Load Gemma-3-1B

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline
MODEL_ID = "google/gemma-3-1b-it"
def load_llm(model_id):
  '''
  load LLM model with pipeline, which can be easily used to generate text.

  '''

  tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
  mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto", device_map="auto")

  gen = pipeline(
      "text-generation",
      model=mdl,
      tokenizer=tok,
      max_new_tokens=256,
      do_sample=False,
      temperature=0.0
  )

  llm = HuggingFacePipeline(pipeline=gen)
  return llm

llm = load_llm(MODEL_ID)

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/899 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.00G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

Device set to use cuda:0
The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
  llm = HuggingFacePipeline(pipeline=gen)


In [7]:
!nvidia-smi

Tue Nov 18 00:33:43 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   37C    P8             16W /   72W |       3MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

### **1.3: Result Evaluation (Provided Code + TODO)**

For demonstrative purposes, here’s a quick but effective way to clean up messy raw generations so they’re easier to compare against the gold answers. This snippet shows how to remove that noise and then print the **question, reference (gold) answer, and cleaned model output side by side** for the first few examples.  

This process also avoids wasted tokens from normal responses:
the baseline model sometimes echoes the question, includes disclaimers, or adds Markdown formatting — all of which can pollute evaluation.  

In [16]:
import re

DISCARD_LINES_PAT = re.compile(
    r"^(?:\s*(?:Disclaimer:|Note:|Please note:).*$|"
    r"\s*(?:I am an AI|As an AI|This is not medical advice).*$)",
    flags=re.IGNORECASE
)

def clean_output(text: str, question: str) -> str:
    # Normalize newlines and strip
    s = text.replace("\r\n", "\n").strip()

    # If the model echoed the question on the first line, drop it
    first_line, *rest = s.split("\n")
    if first_line.strip() == question.strip():
        s = "\n".join(rest).strip()

    # Remove obvious boilerplate/disclaimers and empty lines
    kept = []
    for line in s.split("\n"):
        line = line.strip()
        if not line:
            continue
        if DISCARD_LINES_PAT.match(line):
            continue
        # Strip markdown headers/bullets if present
        line = re.sub(r"^\s*(?:[#>*-]\s*)+", "", line)
        kept.append(line)
    s = " ".join(kept)

    # Collapse whitespace and trim special tokens
    s = re.sub(r"\s+", " ", s)
    s = s.replace("```", "").strip()

    # Avoid returning empty; keep at least a short span to not break metrics
    return s if s else "[no answer]"

def generate_predicted_answer(llm, questions):
    raw = llm.batch(questions)
    return [clean_output(o, q) for o, q in zip(raw, questions)]

predicted_results_only = generate_predicted_answer(llm, questions5)

`generation_config` default values have been modified to match model-specific defaults: {'do_sample': True}. If this is not desired, please set these values explicitly.


In [17]:
import textwrap

# Side-by-side glance: Question, Gold Answer, Model Prediction (with wrapping)
for i, (q, gold, pred) in enumerate(zip(questions5[:3], answers5[:3], predicted_results_only[:3])):
    print(f"[{i}] Q: {q}\n")

    print("REF (gold):")
    for line in textwrap.wrap(gold, width=100):
        print("  " + line)

    print("\nGEN (clean):")
    for line in textwrap.wrap(pred, width=100):
        print("  " + line)

    print("-" * 80)

[0] Q: Is capmatinib effective for glioblastoma?

REF (gold):
  No. Combination of capmatinib buparlisib resulted in no clear activity in patients with recurrent
  PTEN-deficient glioblastoma.

GEN (clean):
  The question of whether capmatinib is effective for glioblastoma is complex and currently under
  investigation. Here's a breakdown of what we know: Initial Studies:** Early studies (primarily in
  the 2010s) showed that capmatinib, a tyrosine kinase inhibitor, demonstrated a modest, but
  statistically significant, improvement in overall survival in patients with glioblastoma. Recent
  Research:** More recent research, including trials in patients with glioblastoma who had progressed
  on other treatments, suggests that capmatinib may be more effective than standard treatment in
  certain subgroups. Key Findings:** Improved Survival:** Some studies have reported a statistically
  significant increase in overall survival (OS) in patients treated with capmatinib compared to
  stand

In [18]:
# Evaluate the answer with BERT score
import evaluate
from bert_score import score

# TODO: Evaluate the predicted_results_only against answer_5 using BERTScore
P, R, F1 = score(predicted_results_only, answers5, model_type="microsoft/deberta-large-mnli")

print(f"BERTScore - P: {P.mean().item():.3f}, R: {R.mean().item():.3f}, F1: {F1.mean().item():.3f}")

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/729 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.62G [00:00<?, ?B/s]

BERTScore - P: 0.470, R: 0.584, F1: 0.520


## **Task 2: Simple RAG (28 PTS)**

In this section, you will implement a simple Retrieval-Augmented Generation (RAG) pipeline with the provided biomedical corpus. Complete the following tasks and discussion questions:

1. **[CODE]** Prepare the corpus for retrieval. (3 pts)

    - **[DISCUSSION]** Why do we need to chunk the documents instead of just embedding whole passages directly? (2 pts).
    

2. **[CODE]** Convert chunks into embeddings and store them in `FAISS` for retrieval. (3 pts)

    - **[DISCUSSION]** Skim through the following documentation on [`text_splitters`](https://python.langchain.com/docs/concepts/text_splitters/) and answer why they are necessary (2 pts).

3. **[CODE]** Implement the `simple_rag` pipeline below. (10 pts)

4. **[CODE]** Evaluate the answers with `BERTScore`. (2 pts)

    - **[DISCUSSION]** Look at the output scores. What do they suggest about your generated answers? (3 pts)

5. **[DISCUSSION]** Run the provided code below and check out the outputs. Briefly summarize how the RAG workflow implemented here works. (3 pts)

In [19]:
# Load & inspect the dataset
# bio_corpus['test'] contains passages + IDs
# Each entry looks like: {"passage": "...", "id": 1234}

print(bio_corpus['test'][0])

{'passage': 'New data on viruses isolated from patients with subacute thyroiditis de Quervain \nare reported. Characteristic morphological, cytological, some physico-chemical \nand biological features of the isolated viruses are described. A possible role \nof these viruses in human and animal health disorders is discussed. The isolated \nviruses remain unclassified so far.', 'id': 9797}


In [20]:
# =======================
# Step 1.1: Wrap passages into LangChain Document objects
# page_content → the text to embed
# metadata → keep doc_id for tracking retrieval
# =======================

from langchain.schema import Document

# TODO: build a list of Document objects from bio_corpus["test"]
# HINT: each entry has keys "passage" (the text) and "id" (unique identifier)
docs = [
    Document(
        page_content=entry["passage"],                # TODO: fill with entry["passage"]
        metadata={"doc_id": entry["id"]}         # TODO: fill with entry["id"]
    )
    for entry in bio_corpus["test"]
]

In [21]:
# =======================
# Step 1.2: Split large passages into smaller chunks
# Why? → Improves retrieval accuracy (finer granularity)
# Parameters:
#   - chunk_size: max tokens per chunk
#   - chunk_overlap: overlap to preserve context
# =======================

from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=200,
    length_function=len,
    is_separator_regex=False,
)

chunked_docs = text_splitter.split_documents(docs)

In [22]:
# =======================
# Step 2.1: Convert chunks into embeddings + store in FAISS
# Embeddings: dense vectors for semantic similarity
# FAISS: vector DB for fast nearest-neighbor retrieval
# =======================

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# HuggingFace embedding model
emb_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_kwargs={"device": "cuda"}
)

# Build FAISS vector store from chunked_docs
vectordb = FAISS.from_documents(chunked_docs, emb_model)   # TODO: fill with docs and embedding model

modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [23]:
# =======================
# Step 3: Define a minimal RAG pipeline
# Input:
#   - questions (list[str]): list of str
#   - vectordb: FAISS retriever
#   - llm: language model (Gemma-3, LangChain-wrapped)
#   - k: number of docs retrieved per query
#
# Process:
#   1. Retrieve top-k docs
#   2. Build context from retrieved docs
#   3. Generate answer with LLM
#   4. Post-process + collect doc IDs
# Output:
#   - rag_predicted_answers (list[str]): list of generated answers
#   - retrieved_ids (list[list[str]]): doc IDs retrieved
# =======================

import re

def simple_rag(questions, vectordb, llm, k=5):

    rag_predicted_answers, retrieved_ids, retrieved_docs = [], [], []
    retriever = vectordb.as_retriever(search_kwargs={"k": k})

    for q in questions:
        # Use retriever to get relevant documents
        docs = retriever.get_relevant_documents(q)

        # TODO: Concatenate top-k document contents into context
        # HINT: join d.page_content for d in docs[:k]
        context = "\n".join([d.page_content for d in docs[:k]])

        # TODO: Use llm to generate an answer based on question + context
        # HINT: llm.invoke()
        prompt = f"Use the context to answer.\n\nQuestion: {q}\n\nContext:\n{context}\n\nAnswer:"
        answer = llm.invoke(prompt)
        answer = answer.content if hasattr(answer, "content") else answer

        answer_only = re.split(r"Answer:\s*", answer, maxsplit=1)[-1].strip()
        rag_predicted_answers.append(answer_only)

        # TODO: Collect retrieved document IDs
        # HINT: use d.metadata.get() for each doc
        ids = [d.metadata.get("doc_id") for d in docs[:k]]
        retrieved_ids.append(ids)
        retrieved_docs.append(docs[:k])

    return rag_predicted_answers, retrieved_ids, retrieved_docs

# TODO: run the RAG pipeline with 5 retrieved docs
rag_predicted_answers, retrieved_ids, retrieved_docs = simple_rag(questions5, vectordb, llm, k=5)


  docs = retriever.get_relevant_documents(q)


In [24]:
# =======================
# Step 4: Evaluate with BERTScore
#   - P (precision): how much of the predicted text is correct
#   - R (recall): how much of the reference was covered
#   - F1: harmonic mean of P and R
# =======================

# Clean up GPU memory
import torch, gc, os
gc.collect(); torch.cuda.empty_cache()

import evaluate
from bert_score import score

P, R, F1 = score(rag_predicted_answers, answers5, model_type="microsoft/deberta-large-mnli")
print(f"BERTScore - P: {P.mean().item():.3f}, R: {R.mean().item():.3f}, F1: {F1.mean().item():.3f}")

BERTScore - P: 0.565, R: 0.572, F1: 0.566


In [25]:
# =======================
# Step 5: Display Q/A results for inspection
# Run the provided code below
# =======================

# =======================
# Baseline display WITH sentence-level attribution
# =======================
from IPython.display import display, HTML
import html, re
import math

def _sent_split(s: str):
    s = str(s).strip()
    parts = re.split(r'(?<=[.!?])\s+', s)
    return [p.strip() for p in parts if p.strip()]

def _norm_tokens(s: str):
    return [t for t in re.findall(r"[a-z0-9]+", s.lower()) if t]

def _overlap_score(a: str, b: str):
    # Jaccard on unigrams (simple, fast, robust)
    A, B = set(_norm_tokens(a)), set(_norm_tokens(b))
    if not A or not B: return 0.0
    return len(A & B) / len(A | B)

def _short(s, n):
    s = str(s).strip().replace("\n"," ")
    return s[:n-1]+"…" if len(s)>n else s

def render_rag_results_with_attribution(
    questions, preds, golds,
    retrieved_docs=None, retrieved_ids=None,
    max_preview_chars=260, per_sent_top=1
):
    cards=[]
    for i, q in enumerate(questions):
        pred = str(getattr(preds[i], "content", preds[i])) if i < len(preds) else ""
        gold = str(golds[i]) if i < len(golds) else ""

        # Build an array of (doc_id, sentences[]) for attribution if docs available
        doc_sents = []
        if retrieved_docs and i < len(retrieved_docs) and retrieved_docs[i]:
            for d in retrieved_docs[i]:
                if hasattr(d, "page_content"):
                    did = (getattr(d, "metadata", {}) or {}).get("doc_id", "NA")
                    sents = _sent_split(getattr(d, "page_content", ""))
                    if sents:
                        doc_sents.append((did, sents))

        # Per prediction sentence, pick highest-overlap supporting sentence across docs
        attrib_blocks = []
        pred_sents = _sent_split(pred)
        for ps in pred_sents:
            best = []
            for did, sents in doc_sents:
                # find best sentence in this doc
                local_best = max(sents, key=lambda s: _overlap_score(ps, s)) if sents else ""
                score = _overlap_score(ps, local_best) if local_best else 0.0
                best.append((score, did, local_best))
            if best:
                best.sort(reverse=True, key=lambda x: x[0])
                chosen = best[:per_sent_top]
                # render chosen supports (color intensity by score)
                sup_html = []
                for sc, did, sup in chosen:
                    shade = int(255 - min(1.0, sc) * 120)  # lower = darker for higher score
                    sup_html.append(
                        f"<div style='margin:6px 0;padding:8px;border-radius:8px;"
                        f"background: rgb({shade},{shade},{shade}); color:#111;'>"
                        f"<div style='font-size:12px;color:#222'><b>doc_id:</b> {html.escape(str(did))} "
                        f"<span style='opacity:.8'>(score {sc:.2f})</span></div>"
                        f"<div style='font-size:14px;line-height:1.35'>{html.escape(_short(sup, max_preview_chars))}</div>"
                        f"</div>"
                    )
                attrib_blocks.append(
                    f"<div style='margin-top:8px'>"
                    f"<div style='font-weight:700'>Pred sentence:</div>"
                    f"<div style='margin:4px 0 6px'>{html.escape(ps)}</div>"
                    f"{''.join(sup_html)}"
                    f"</div>"
                )

        # main card
        block = [f"""
        <div style="border:1px solid #ddd;border-radius:12px;padding:14px;margin:10px 0;">
          <div style="font-size:14px;color:#666;">Example {i+1}</div>
          <div style="font-size:16px;font-weight:700;margin-top:4px;">Question</div>
          <div>{html.escape(q)}</div>

          <div style="display:flex;gap:16px;margin-top:10px;">
            <div style="flex:1;">
              <div style="font-weight:700;">Predicted (stitched from retrieved chunks)</div>
              <div>{html.escape(pred) if pred else "<i>(empty)</i>"}</div>
            </div>
            <div style="flex:1;">
              <div style="font-weight:700;">Gold</div>
              <div>{html.escape(gold)}</div>
            </div>
          </div>
        """]
        # Attribution section
        if attrib_blocks:
            block.append("<div style='margin-top:12px;font-weight:700;'>Where the prediction likely came from</div>")
            block.append("".join(attrib_blocks))
        else:
            # fallback: list ids or short previews
            block.append("<div style='margin-top:12px;font-weight:700;'>Retrieved Evidence</div>")
            if retrieved_docs and i < len(retrieved_docs) and retrieved_docs[i]:
                for d in retrieved_docs[i]:
                    if hasattr(d, "page_content"):
                        did = (getattr(d, "metadata", {}) or {}).get("doc_id", "NA")
                        prev = _short(getattr(d, "page_content", ""), max_preview_chars)
                        block.append(
                            f"<div style='margin-top:8px;padding:10px;background:#fafafa;border:1px dashed #ddd;border-radius:8px;'>"
                            f"<div style='font-size:13px;color:#555;'><b>doc_id:</b> {html.escape(str(did))}</div>"
                            f"<div style='margin-top:4px;font-size:14px;line-height:1.4;'>{html.escape(prev)}</div>"
                            f"</div>"
                        )
            elif retrieved_ids and i < len(retrieved_ids):
                block.append(f"<div style='color:#555;'>doc_ids: {html.escape(str(retrieved_ids[i]))}</div>")
            else:
                block.append("<div style='color:#999;'>No retrievals recorded.</div>")

        block.append("</div>")
        cards.append("".join(block))
    display(HTML("".join(cards)))



In [26]:
render_rag_results_with_attribution(
    questions5, rag_predicted_answers, answers5,
    retrieved_docs=retrieved_docs,  # list[list[Document]]
    per_sent_top=1                  # show top-1 supporting sentence per pred sentence
)


## **Task 3: Implementing and Evaluating RAG Variants (26 PT)**
In this task, you will implement different variants of Retrieval-Augmented Generation (RAG).

Your goal is not just to code, but to understand the concept behind each method.

* Why do we need this variant of RAG?

* What problem does it solve compared to simple RAG?

* What are its potential advantages or trade-offs?

Complete the following tasks and discussion questions:

1. **[DISCUSSION]** Check this [instruction](https://levelup.gitconnected.com/testing-18-rag-techniques-to-find-the-best-094d166af27f#c20e) for variants of RAG. From each of the following six methods, please use 2-3 sentences each to explain why we need this variant and how the workflow work: 1. Semantic Chunking. 2. Query Transformation. 3. Contextual Compression. 4.Re-Ranker. 5. HyDE. 6. Self-RAG. (6 pts)
    
2. (20 pts) Implement (1) Query Transformation (2) Hypothetical Document Expansion (3) Self-RAG (4) Semantic Chunking +Headers. For each of the method:
    - **[CODE]** Fill in the TODOs for each of the methods (3 pts)
    - **[CODE]** Print the mertic numbers (1 pts)
    - **[DISCUSSION]** The pros and cons of the methods, and explain the result (1 pt)

### **3.1：Query Transformation**

In [28]:
### Query Transformation (Multi-Query/Reformulation)
import re

# -------- Query Transformation (Multi-Query Retrieval) --------
def query_transformation(question, llm, retriever, m=3, k_each=5):
    """
    Generate multiple reformulations of the question (paraphrases),
    retrieve documents for each reformulated query, and merge results.

    Args:
        question: original question (string)
        llm: a LangChain LLM (e.g., Gemma3 wrapped with HuggingFacePipeline)
        retriever: a vector retriever (e.g., FAISS retriever)
        m: number of reformulated queries to generate
        k_each: top-k documents to retrieve per reformulated query

    Returns:
        merged_docs: list of deduplicated documents from all queries
        queries: list of queries used (original + reformulations)
    """
    # Step 1: create reformulated queries
    queries = [question]
    for _ in range(m):
        # TODO: Use the LLM to generate reformulated versions of the original question and append them to the query list.
        reform = llm.invoke(
    f"Rewrite this question in a different way, but keep the meaning.\n\nQuestion: {question}\n\nRewritten:" )
        reform_text = getattr(reform, "content", reform)   # handle AIMessage or str
        queries.append(str(reform_text).strip())

    # Step 2: retrieve docs for each query
    pool = []
    for q in queries:
        hits = retriever.get_relevant_documents(q) or []
        pool.extend(hits[:k_each])

    # Step 3: deduplicate results (by doc id or content hash)
    seen = set()
    merged_docs = []
    for d in pool:
        did = (getattr(d, "metadata", {}) or {}).get("id")
        key = did if did is not None else hash(getattr(d, "page_content", ""))
        if key not in seen:
            seen.add(key)
            merged_docs.append(d)

    return merged_docs, queries


# -------- RAG pipeline using Query Transformation --------
def query_transformation_rag(questions, vectordb, llm, m=3, k_each=3, k_final=5, max_ctx_chars=6000):
    """
    Minimal RAG pipeline with query transformation.
    For each question:
      - generate reformulated queries
      - retrieve docs for each reformulation
      - merge and deduplicate docs
      - select top-k_final docs as context
      - generate answer using LLM

    Args:
        questions: list of questions
        vectordb: LangChain VectorStore (e.g., FAISS)
        llm: a LangChain LLM (e.g., Gemma3)
        m: number of reformulated queries
        k_each: top-k docs per reformulation
        k_final: number of docs used as final context
        max_ctx_chars: hard cap on context length

    Returns:
        rag_predicted_answers: list of generated answers (strings)
        retrieved_ids: list of lists of retrieved document IDs
    """
    rag_predicted_answers = []
    retrieved_ids = []
    retriever = vectordb.as_retriever(search_kwargs={"k": max(k_each, k_final)})

    for q in questions:
        # Step A: multi-query retrieval
        docs_pool, queries_used = query_transformation(q, llm, retriever, m=m, k_each=k_each)

        # Step B: keep only top-k_final docs
        docs = docs_pool[:k_final]

        # Step C: build context string
        parts = [getattr(d, "page_content", "").strip() for d in docs if getattr(d, "page_content", None)]
        context = "\n\n".join(parts)
        if len(context) > max_ctx_chars:
            context = context[:max_ctx_chars]

        # Step D: construct prompt for the LLM
        prompt = (
            "Answer ONLY using the context. If the answer is not in the context, say: I don't know.\n\n"
            f"Question: {q}\n\nContext:\n{context or '(EMPTY)'}\n\nAnswer:"
        )
        out = llm.invoke(prompt)
        answer_text = getattr(out, "content", out)
        answer_text = str(answer_text).strip()

        # Step E: remove repeated question at the start (optional cleanup)
        answer_text = re.sub(rf"^{re.escape(q)}[\s\?]*", "", answer_text, flags=re.IGNORECASE).strip()

        # Step F: collect retrieved doc IDs
        ids = []
        for d in docs:
            mid = (getattr(d, "metadata", {}) or {}).get("id", "NA")
            ids.append(mid)

        rag_predicted_answers.append(answer_text)
        retrieved_ids.append(ids)

    return rag_predicted_answers, retrieved_ids

querytrans_rag_predicted_answers, retrieved_ids = query_transformation_rag(questions5, vectordb, llm, m=3, k_each=3, k_final=5, max_ctx_chars=6000)
# TODO: Call the score function with predicted answers and gold answers to compute Precision, Recall, and F1. Print the results.
P, R, F1 = score(querytrans_rag_predicted_answers, answers5, model_type="microsoft/deberta-large-mnli")
print(f"BERTScore - P: {P.mean().item():.3f}, R: {R.mean().item():.3f}, F1: {F1.mean().item():.3f}")

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


BERTScore - P: 0.403, R: 0.593, F1: 0.479


### **3.2: Hypothetical Document Expansion**

In [29]:
# HyDE (Hypothetical Document Expansion)
import re

# -------- HyDE module: generate pseudo answer and retrieve by vector --------
def hyde_neighbors(question, llm, emb_model, vectordb, k=8, pseudo_sentences=3):
    """
    HyDE: Generate a short hypothetical answer/snippet, embed it, and retrieve neighbors by vector.

    Args:
        question: user question (string)
        llm: a LangChain LLM (used to draft the hypothetical snippet)
        emb_model: embedding model wrapper used by your vector store (e.g., HuggingFaceEmbeddings)
        vectordb: your VectorStore (e.g., FAISS) with similarity_search_by_vector available
        k: number of neighbors to return
        pseudo_sentences: how long the hypothetical snippet should be (soft guidance)

    Returns:
        docs: top-k retrieved documents (list)
        pseudo: the hypothetical snippet generated by the LLM (string)
    """
    # 1) Ask LLM to write a short hypothetical snippet for retrieval ONLY
    # TODO: write the prompt to let LLM generate pseudo answer
    prompt = (
    f"Write a short hypothetical answer (2–3 sentences) to this question. "
    f"This is ONLY for retrieval — do NOT mention uncertainty.\n\nQuestion: {question}\n\nPseudo answer:"
)
    out = llm.invoke(prompt)
    pseudo = getattr(out, "content", out)
    pseudo = str(pseudo).strip()

    # 2) Embed the pseudo text and search by vector
    q_vec = emb_model.embed_query(pseudo)
    docs = vectordb.similarity_search_by_vector(q_vec, k=k) or []

    return docs, pseudo


# -------- RAG pipeline using HyDE --------
def hyde_rag(questions, vectordb, llm, emb_model, k_hyde=8, k_final=3, max_ctx_chars=6000, also_mix_query=False):
    """
    Minimal RAG with HyDE.
    For each question:
      - generate a hypothetical snippet (HyDE)
      - embed snippet and retrieve neighbors by vector
      - (optionally) mix in normal retrieval of the original query for robustness
      - deduplicate, keep top-k_final chunks as context
      - ask the LLM to answer using ONLY the real context (never the pseudo text)

    Args:
        questions: list of input questions
        vectordb: VectorStore (e.g., FAISS) supporting similarity_search_by_vector
        llm: a LangChain LLM
        emb_model: the embedding wrapper used for queries
        k_hyde: neighbors from HyDE vector search
        k_final: number of docs to stuff into the prompt
        max_ctx_chars: hard cap on total context length
        also_mix_query: if True, also retrieve with the raw question and merge

    Returns:
        rag_predicted_answers: list of answers
        retrieved_ids: list of document IDs used per question
    """
    rag_predicted_answers = []
    retrieved_ids = []

    # A simple retriever for the raw question (optional mix-in)
    base_retriever = vectordb.as_retriever(search_kwargs={"k": max(k_hyde, k_final)})

    for q in questions:
        # 1) HyDE neighbors
        hyde_docs, pseudo = hyde_neighbors(q, llm, emb_model, vectordb, k=k_hyde)

        # 2) Optionally: retrieve with the raw question and merge pools
        pool = list(hyde_docs)
        if also_mix_query:
            raw_hits = base_retriever.get_relevant_documents(q) or []
            pool.extend(raw_hits)

        # 3) Deduplicate by (id or content hash)
        seen = set()
        merged = []
        for d in pool:
            did = (getattr(d, "metadata", {}) or {}).get("id")
            key = did if did is not None else hash(getattr(d, "page_content", ""))
            if key not in seen:
                seen.add(key)
                merged.append(d)

        # 4) Keep only k_final docs for context
        docs = merged[:k_final]

        # 5) Build context from REAL docs (never include the hypothetical snippet)
        parts = [getattr(d, "page_content", "").strip() for d in docs if getattr(d, "page_content", None)]
        context = "\n\n".join(parts)
        if len(context) > max_ctx_chars:
            context = context[:max_ctx_chars]

        # 6) Ask the LLM (discourage hallucination)
        prompt = (
            "Answer ONLY using the context. If the answer is not in the context, say: I don't know.\n\n"
            f"Question: {q}\n\nContext:\n{context or '(EMPTY)'}\n\nAnswer:"
        )
        out = llm.invoke(prompt)
        answer_text = getattr(out, "content", out)
        answer_text = str(answer_text).strip()

        # optional cleanup: remove repeated question at start
        answer_text = re.sub(rf"^{re.escape(q)}[\s\?]*", "", answer_text, flags=re.IGNORECASE).strip()

        # 7) Collect IDs for evaluation
        ids = []
        for d in docs:
            mid = (getattr(d, "metadata", {}) or {}).get("id", "NA")
            ids.append(mid)

        rag_predicted_answers.append(answer_text)
        retrieved_ids.append(ids)

    return rag_predicted_answers, retrieved_ids
hyde_predicted_answers, retrieved_ids = hyde_rag(questions5, vectordb, llm, emb_model, k_hyde=8, k_final=3, max_ctx_chars=6000, also_mix_query=False)
# TODO: Call the score function with predicted answers and gold answers to compute Precision, Recall, and F1. Print the results.
P, R, F1 = score(hyde_predicted_answers, answers5, model_type="microsoft/deberta-large-mnli")
print(f"BERTScore - P: {P.mean().item():.3f}, R: {R.mean().item():.3f}, F1: {F1.mean().item():.3f}")

BERTScore - P: 0.419, R: 0.603, F1: 0.493


### **3.3: Self-RAG**

In [30]:
### Self-Rag
import re
import json

# -------- Helper: build context from docs --------
def _build_context(docs, max_ctx_chars):
    parts = [getattr(d, "page_content", "").strip() for d in docs if getattr(d, "page_content", None)]
    context = "\n\n".join(parts)
    if len(context) > max_ctx_chars:
        context = context[:max_ctx_chars]
    return context

# -------- Step 1: draft an answer using ONLY the context --------
def selfrag_draft_answer(question, context, llm):
    """
    Produce a first-pass answer strictly from the provided context.
    """
    prompt = (
        "You must answer ONLY using the provided context. "
        "If the answer is not in the context, say: I don't know.\n\n"
        f"Question: {question}\n\nContext:\n{context or '(EMPTY)'}\n\n"
        "Draft answer (concise):"
    )
    out = llm.invoke(prompt)
    text = getattr(out, "content", out)
    text = str(text).strip()
    # optional cleanup: remove echoed question
    text = re.sub(rf"^{re.escape(question)}[\s\?]*", "", text, flags=re.IGNORECASE).strip()
    return text

# -------- Step 2: self-check (are the claims supported by context?) --------
def selfrag_selfcheck(question, draft_answer, context, llm):
    """
    Ask the model to verify whether each claim in draft_answer is supported by context.
    Returns a small dict with verdict and lists of supported/unsupported claims
    """
    # TODO:
    prompt = (
    "Evaluate whether the DRAFT ANSWER is fully supported by the CONTEXT.\n"
    "Return JSON with: verdict, supported, unsupported.\n\n"
    f"QUESTION: {question}\n\n"
    f"CONTEXT:\n{context}\n\n"
    f"DRAFT ANSWER:\n{draft_answer}\n\n"
    "JSON:"
)
    out = llm.invoke(prompt)
    raw = getattr(out, "content", out)
    raw = str(raw).strip()

    # Best-effort JSON extraction
    try:
        # find the first and last curly braces block
        start = raw.find("{"); end = raw.rfind("}")
        parsed = json.loads(raw[start:end+1]) if start != -1 and end != -1 else {}
    except Exception:
        parsed = {}

    # normalize structure
    verdict = parsed.get("verdict", "").upper()
    if verdict not in ("SUPPORTED", "PARTIAL", "UNSUPPORTED"):
        verdict = "PARTIAL"
    supported = parsed.get("supported", [])
    unsupported = parsed.get("unsupported", [])
    if not isinstance(supported, list): supported = [str(supported)]
    if not isinstance(unsupported, list): unsupported = [str(unsupported)]

    return {"verdict": verdict, "supported": supported, "unsupported": unsupported, "raw": raw}

# -------- Step 3: revise or abstain based on self-check --------
def selfrag_revise(question, draft_answer, context, selfcheck, llm):
    """
    If verdict is not SUPPORTED, revise the answer using ONLY supported evidence.
    If evidence is insufficient, abstain with 'I don't know'.
    """
    verdict = selfcheck.get("verdict", "PARTIAL")
    unsupported = selfcheck.get("unsupported", [])
    supported = selfcheck.get("supported", [])

    if verdict == "SUPPORTED":
        return draft_answer

    # If nothing is supported, abstain
    if not supported:
        return "I don't know."

    # Otherwise, rewrite using only supported parts of the context.
    ## TODO
    prompt =(
    "Rewrite the final answer using ONLY the supported evidence. "
    "Remove unsupported information completely.\n\n"
    f"QUESTION: {question}\n\n"
    f"SUPPORTED EVIDENCE:\n{supported}\n\n"
    "Revised answer:"
)
    out = llm.invoke(prompt)
    text = getattr(out, "content", out)
    text = str(text).strip()
    text = re.sub(rf"^{re.escape(question)}[\s\?]*", "", text, flags=re.IGNORECASE).strip()
    return text

# -------- Self-RAG pipeline (retrieve, draft, self-check, revise) --------
def selfrag_pipeline(questions, vectordb, llm, k=5, max_ctx_chars=6000):
    """
    Minimal Self-RAG pipeline.
    For each question:
      - retrieve k docs
      - draft an answer using ONLY the context
      - self-check faithfulness against the context
      - if needed, revise or abstain

    Returns:
      rag_predicted_answers: list of final answers
      retrieved_ids: list of document IDs used per question
      diagnostics: list of dicts with 'draft', 'selfcheck_verdict', 'unsupported_claims'
    """
    rag_predicted_answers = []
    retrieved_ids = []
    diagnostics = []

    retriever = vectordb.as_retriever(search_kwargs={"k": k})

    for q in questions:
        # A) retrieve
        docs = retriever.get_relevant_documents(q) or []

        # B) build context
        context = _build_context(docs, max_ctx_chars=max_ctx_chars)

        # C) draft
        draft = selfrag_draft_answer(q, context, llm)

        # D) self-check
        sc = selfrag_selfcheck(q, draft, context, llm)

        # E) revise/abstain
        final_answer = selfrag_revise(q, draft, context, sc, llm)

        # F) collect ids
        ids = []
        for d in docs:
            mid = (getattr(d, "metadata", {}) or {}).get("id", "NA")
            ids.append(mid)

        rag_predicted_answers.append(final_answer.strip())
        retrieved_ids.append(ids)
        diagnostics.append({
            "draft": draft,
            "selfcheck_verdict": sc.get("verdict", "PARTIAL"),
            "unsupported_claims": sc.get("unsupported", []),
            "raw_selfcheck": sc.get("raw", "")
        })

    return rag_predicted_answers, retrieved_ids, diagnostics

selfrag_predicted_answers, retrieved_ids, diagnostics = selfrag_pipeline(questions5, vectordb, llm, k=5, max_ctx_chars=6000)
## TODO: Call the score function with predicted answers and gold answers to compute Precision, Recall, and F1. Print the results.
P, R, F1 = score(selfrag_predicted_answers, answers5, model_type="microsoft/deberta-large-mnli")
print(f"BERTScore - P: {P.mean().item():.3f}, R: {R.mean().item():.3f}, F1: {F1.mean().item():.3f}")

BERTScore - P: 0.386, R: 0.562, F1: 0.454


### **3.4: Semantic Chunking +Headers**

In [33]:
### Semantic Chunking +Headers

from langchain_experimental.text_splitter import SemanticChunker
from langchain_huggingface import HuggingFaceEmbeddings

emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

## TODO: define the chunker from SemanticChunker, and chunck the docs
chunker = SemanticChunker(emb)
sem_chunks = chunker.split_documents(docs)

def add_header(doc):
    md = doc.metadata or {}
    header = "[title={}; source={}; date={}; section={}]".format(
        md.get("title",""), md.get("source",""), md.get("published_at",""), md.get("section","")
    )
    doc.page_content = f"{header}\n{doc.page_content}"
    return doc

sem_chunks = [add_header(d) for d in sem_chunks]

vectordb_semantic=FAISS.from_documents(sem_chunks, emb)
rag_predicted_answers, retrieved_ids, retrieved_docs = simple_rag(
    questions5,
    vectordb_semantic,
    llm,
    k=5
)
# TODO: Call the score function with predicted answers and gold answers to compute Precision, Recall, and F1. Print the results.
P, R, F1 = score(rag_predicted_answers, answers5, model_type="microsoft/deberta-large-mnli")
print(f"BERTScore - P: {P.mean().item():.3f}, R: {R.mean().item():.3f}, F1: {F1.mean().item():.3f}")

BERTScore - P: 0.588, R: 0.557, F1: 0.568


## **Task 4: Final Hybrid RAG Variants for Bio Q&A (30 PTS + 5 EC)**

In this task, you will complete and test a full RAG pipeline on the BioASQ dataset. Starting from hybrid retrieval (dense + BM25), you’ll add two RAG variants—cross-encoder reranking and contextual compression—to improve evidence quality. You’ll then generate answers constrained to the retrieved context, implement an adaptive retry policy for refusals, and finally compare the pipeline’s outputs against a simple baseline. This exercise shows how each component contributes to moving from noisy retrieval toward more accurate, faithful biomedical QA.

### **CONFIG & IMPORTS (GIVEN CODE)**

In [34]:
# =========================================
# Step 0: CONFIG & IMPORTS
# =========================================
import re
import random
from typing import List, Any, Tuple, Optional
from dataclasses import dataclass

from langchain.schema import Document
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# reuse your existing tokenizer object named `tokenizer`
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-3-1b-it",
    device_map="auto",
    torch_dtype="auto"  # bfloat16/float16 when available
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")


gen = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    do_sample=False,           # deterministic for grading; flip to True if you want sampling
    temperature=0.0,
    pad_token_id=tokenizer.eos_token_id
)

llm = HuggingFacePipeline(pipeline=gen)  # <- now llm.invoke(prompt) works

# ---- Configuration knobs ----
CFG = {
    # Embedders & rerankers
    "dense_embedder": "pritamdeka/S-Biomed-Roberta-snli-multinli-stsb",
    "rerankers_try": [
        "ncbi/MedCPT-Cross-Encoder",
        "cross-encoder/ms-marco-MiniLM-L-12-v2",
        "BAAI/bge-reranker-base",
    ],
    # LLM & tokenizer
    "tokenizer": "google/gemma-3-1b-it",
    # Retrieval sizes
    "dense_k": 80,
    "bm25_k": 120,
    "rerank_k": 8,
    "rerank_k_retry": 12,
    # Compression thresholds
    "tau": 0.5,
    "tau_retry": 0.47,
    "max_sents": 8,
    "max_sents_retry": 10,
    # Context token cap
    "max_ctx_tokens": 1400,
    # Scoring
    "idk_str": "I don't know.",
    # Device flags (let backends pick GPU if available)
    "device": "cuda",
    "seed": 42,
}

random.seed(CFG["seed"])



Device set to use cuda:0


### **DATA PREPARATION (5 pts)**


1. **[CODE]** Call the provided function `build_base_docs` on `bio_corpus`. (0.5 pt)

2. **[CODE]** Call the provided function `make_snippets` on `base_docs`. (0.5 pt)

    * **[DISCUSSION]** What are the trade-offs to the following metrics if we kept each abstract as a single long document vs. very short snippets with one sentence each?

        - Recall, or the probability of retrieving the right document (1 pt)

        - Precision, or the probability that the retrieved text is directly relevant (1 pt)

        - LLM answer quality, given token limits and irrelevant context (1 pt)

    * **[DISCUSSION]** We chose snippets with 2-3 sentences as a middle ground. Why do you think this is a reasonable length? (1 pt)

This first step is about wrapping the dataset into a common format. BioASQ gives you passages and IDs as plain JSON; LangChain expects Document objects with page_content and metadata. By attaching the id to the metadata as doc_id, we ensure later retrievals can be traced back to the correct passage.
    

In [35]:
# =========================================
# Step 1: LOAD BIOASQ CORPUS  -> base_docs
# - Convert raw JSON records into Document objects
# - Attach doc_id to metadata so we can track provenance later
# =========================================

def build_base_docs(bio_corpus) -> List[Document]:
    # bio_corpus['test'] = [{"passage": "...", "id": ...}, ...]
    return [
        Document(page_content=rec["passage"], metadata={"doc_id": rec["id"]})
        for rec in bio_corpus["test"]
    ]

# TODO: call build_base_docs on bio_corpus
base_docs = build_base_docs(bio_corpus)

Now, long abstracts can overwhelm retrieval if treated as single units, so we split them into smaller “snippets” to make it easier for the retriever to pull only the relevant part.

Here, we chunk each abstract into 2–3 sentence snippets. That way, retrieval granularity is closer to what a QA system needs.

The regex splitter is deliberately simple; this is enough for demonstration without needing heavy NLP libraries.

In [36]:
# =========================================
# Step 2: SNIPPETIZATION
# - Break each abstract into 2–3 sentence chunks
# - This improves retriever granularity
# =========================================

# Provided helper: simple regex-based sentence splitter
_SENT_SPLIT = re.compile(r'(?<=[.!?])\s+')

def sent_tokenize_quick(text: str) -> List[str]:
    sents = [s.strip() for s in _SENT_SPLIT.split(text) if s.strip()]
    return sents if sents else [text.strip()]

def make_snippets(docs: List[Document], max_sents_per_snip: int = 3) -> List[Document]:
    out = []
    for d in docs:
        sents = sent_tokenize_quick(d.page_content)
        buf = []
        for s in sents:
            buf.append(s)
            if len(buf) >= max_sents_per_snip:
                out.append(Document(page_content=" ".join(buf), metadata=d.metadata))
                buf = []
        if buf:
            out.append(Document(page_content=" ".join(buf), metadata=d.metadata))
    return out

# TODO: call make_snippets on base_docs with max_sents_per_snip=3
snippet_docs = make_snippets(base_docs, max_sents_per_snip=3)


DISCUSSION (Step 2 - Snippetization)
Recall:
If we kept each abstract as one large document, recall would usually be higher because the entire passage is stored together. But it also means we might retrieve a very long chunk even if only a tiny part is relevant.  
If we used tiny one-sentence snippets, recall could drop because a single sentence may not contain enough keyword overlap or context to match the question.

Precision:  
Large documents cause low precision because we retrieve big blocks full of unrelated sentences. Short one-sentence snippets give high precision but risk missing context that the LLM needs.

LLM Answer Quality:  
If context is too long and noisy, the LLM gets distracted and may hallucinate. If it is too short (1 sentence), the LLM may not have enough context to answer correctly.  
Using 2-3 sentence snippets is a good middle ground: the chunk is small enough to be precise but long enough to preserve meaning.

Why 2-3 sentences is reasonable:  
Biomedical answers usually depend on a short mechanistic explanation or a clinical finding. Most useful information appears within a few consecutive sentences, so 2-3 sentence snippets capture the needed detail without overwhelming the LLM.


### **RETRIEVAL SETUP (3 PTS)**

3. **[CODE]** Call the provided function `build_retrievers` to build both retrievers on `snippet_docs`. (1 pt)

    * **[DISCUSSION]** What are the strengths and weaknesses of each retriever on biomedical text specifically? (1 pts)

    * **[DISCUSSION]** Why is BioASQ a setting where we especially need both? (1 pts)


Here, we construct two complementary retrievers: a **dense semantic retriever** and a **sparse lexical retriever**.

**Dense Retrieval: FAISS + Biomedical SBERT**

* Biomedical SBERT (Sentence-BERT):
    - A transformer model fine-tuned so that semantically similar sentences map to nearby vectors in embedding space.

    - The biomedical variant (e.g., pritamdeka/S-Biomed-Roberta-) is further tuned on PubMed / biomedical NLI (natural language inference) and STS (semantic textual similarity) tasks, so it “understands” biomedical jargon, synonyms, abbreviations (e.g., EGFR vs epidermal growth factor receptor).

* FAISS (Facebook AI Similarity Search):

    - A highly optimized vector index for nearest-neighbor search in high-dimensional space.

    - Here, we embed all snippet documents into 768-D vectors using SBERT, store them in FAISS, then at query time embed the question and run nearest-neighbor search.

    - Note that embeddings are normalized to discard vector length, so we are left with angular (cosine) similarity that can efficiently represent semantic closeness.

**Sparse Retrieval: BM25**

* BM25 (Best Matching 25):
    - A classic bag-of-words ranking function for information retrieval.

    - Scoring is based on exact term overlap, adjusted by:

        + Inverse Document Frequency (IDF): rare terms matter more.

        + Term Frequency Saturation: repeated words help up to a point, but not linearly.

        + Document Length Normalization: prevents long docs from dominating.

In [39]:
# =========================================
# Step 3.1: Build dual retrievers over snippets
# - Dense: FAISS (Facebook AI Similarity Search) index with biomedical SBERT (sentence-BERT) embeddings
# - Sparse: BM25 (Best Matching 25) lexical retriever
# - Output: two retrievers for hybrid recall
# =========================================

def build_retrievers(snippet_docs: List[Document]):
    emb = HuggingFaceEmbeddings(
        model_name=CFG["dense_embedder"],
        model_kwargs={"device": CFG["device"]},
        encode_kwargs={"normalize_embeddings": True}
    )
    # Vector store + dense retriever
    vectordb = FAISS.from_documents(snippet_docs, emb)
    dense_retriever = vectordb.as_retriever(search_kwargs={"k": CFG["dense_k"]})

    # Sparse retriever (pure BM25 over snippet text)
    bm25 = BM25Retriever.from_documents(snippet_docs)
    return dense_retriever, bm25

# TODO: Build both retrievers from snippet_docs
dense_retriever, bm25 = build_retrievers(snippet_docs)

DISCUSSION (Step 3.1 - Dense vs BM25)

Dense Retrieval Strengths:
- Understands biomedical synonyms (e.g., “PCSK9 inhibition” ≈ “reducing PCSK9 expression”).  
- Works well even when exact keywords don’t match.  
- Captures semantic meaning.

Dense Retrieval Weaknesses:
- Sometimes retrieves semantically similar but medically irrelevant text.  
- Can miss exact keyword matches like specific gene names.

BM25 Strengths:
- Excellent for exact biomedical terminology (drug names, gene markers, abbreviations).  
- Very fast and precise when the question contains rare keywords.

BM25 Weaknesses:
- Fails when wording is different (synonyms or paraphrases).  
- Does not understand context or meaning.

Why BioASQ needs both:
Biomedical text is full of exact technical terms *and* conceptual descriptions. Using both dense + BM25 increases recall and ensures we retrieve:  
- hard keyword matches (BM25)  
- meaningful semantic matches (dense)


In [38]:
!pip install rank_bm25

Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25
Successfully installed rank_bm25-0.2.2


In [40]:
# =========================================
# Step 3.2: Hybrid retrieval — merge dense+BM25 and dedupe
# - Pull candidates from both retrievers
# - Slice BM25 by bm25_fetch_k (lexical recall budget)
# - Deduplicate by (doc_id, first_120_chars) to avoid near-duplicates
# - Output: merged candidate list for reranking
# =========================================

def hybrid_candidates(q: str, dense_retriever, bm25, bm25_fetch_k: int) -> List[Document]:
    dense = dense_retriever.get_relevant_documents(q) or []
    sparse = bm25.get_relevant_documents(q)[:bm25_fetch_k] or []
    seen, merged = set(), []
    for d in (dense + sparse):
        key = (d.metadata.get("doc_id", "NA"), d.page_content[:120])
        if key not in seen and d.page_content.strip():
            seen.add(key); merged.append(d)
    return merged

### **RERANKING & COMPRESSION (14 PTS)**

Here, we experiment with two common RAG variants that operate on the retrieval results before sending them to the LLM.

4. **[CODE]** Fill in the TODOs to initialize the reranker, create (query, snippet) pairs, predict relevance scores, and return the top-k ranked snippets. (5 pts)

    - **[DISCUSSION]** Why is reranking needed if we already have hybrid retrieval? What types of errors can it fix? (1 pt)

5. **[CODE]** Fill in the TODOs to split snippets into sentences, encode them with a lightweight biomedical encoder, compute similarity to the query, and filter by threshold τ and max_sents. (5 pts)

    - **[DISCUSSION]** What might happen if τ (the similarity threshold in compression) is set too high or too low? (2 pt)

    - **[DISCUSSION]** How does max_sents affect the trade-off between coverage and noise? (1 pts)

After hybrid retrieval, we often have dozens of candidate snippets. Many are weak matches or tangentially relevant. Reranking is the process of scoring each candidate against the query and reordering them by true relevance.

The process is as follows:

* Cross-Encoder Model: Unlike dense retrievers (which embed query and document separately), a cross-encoder reads the query and snippet together and outputs a direct relevance score.

    - Example: (“What is the mechanism of action of Inclisiran?”, “Inclisiran is an siRNA drug that inhibits PCSK9…”) → high score

    - This joint encoding lets it catch subtle biomedical phrasing differences that simple embeddings might miss.

* Candidate Pairing: For each candidate snippet, create a (query, snippet_text) pair.

* Prediction: Use the reranker’s .predict() method to compute a relevance score for each pair.

* Selection: Sort snippets by score (highest first) and keep the top k_final candidates for downstream QA.

In [42]:
# =========================================
# Step 4: Cross-Encoder Reranking
# - Load a pretrained reranker (biomedical if available)
# - Score each (query, snippet) pair with the reranker
# - Sort candidates by score, keep top-k for downstream QA
# =========================================

def load_reranker() -> CrossEncoder:
    last_err = None
    for name in CFG["rerankers_try"]:
        try:
            # TODO: initialize a CrossEncoder with device=CFG["device"], max_length=512
            ce = CrossEncoder(name, device=CFG["device"], max_length=512)
            print("Loaded reranker:", name)
            return ce
        except Exception as e:
            last_err = e
            continue
    raise RuntimeError(f"Failed to load all rerankers. Last error: {last_err}")

def rerank_snippets(q: str, docs: List[Document], reranker: CrossEncoder, k_final: int) -> List[Document]:
    if not docs:
        return []

    # TODO: form pairs of (query, doc_text) for each candidate snippet
    pairs = [(q, d.page_content) for d in docs]

    # TODO: run reranker.predict on pairs (hint: batch_size=32, convert_to_numpy=True)
    scores = reranker.predict(pairs, batch_size=32, convert_to_numpy=True)

    # TODO: sort doc indices by score (descending), take top k_final
    order = scores.argsort()[::-1][:k_final]

    return [docs[i] for i in order]

reranker = load_reranker()

config.json:   0%|          | 0.00/741 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/74.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

Loaded reranker: ncbi/MedCPT-Cross-Encoder


DISCUSSION (Step 4 - Why Reranking)

Hybrid retrieval gives us many candidates, but not all are truly relevant.  
A cross-encoder reads the question and snippet *together*, so it can detect subtle biomedical relationships that dense/BM25 alone may rank poorly.

Reranking fixes:
- When dense retrieval picks semantically similar but off-topic text  
- When BM25 retrieves something just because a drug name appears once  
- When multiple snippets come from the same abstract but only one truly answers the question

Overall, reranking reduces noisy retrieval and keeps only the strongest candidates.

Even after reranking, many candidate snippets are still verbose or only partly relevant. The goal of contextual compression is to shrink these snippets down to just the sentences that are most relevant to the question. This reduces noise and makes the context passed to the LLM much more focused.

The process is as follows:

* Sentence Splitting: Each snippet is broken into individual sentences (using `sent_tokenize_quick`).

* Encoding: Both the query and all candidate sentences are embedded with a biomedical encoder (`SentenceTransformer`).

* Similarity Scoring: We compute cosine similarity (dot product since embeddings are normalized) between the query and each sentence.

* Selection:

    - Sort all sentences by score (highest → lowest).

    - Keep sentences until either:

        - the similarity falls below threshold τ, or

        - the number of sentences reaches max_sents.

If no sentences survive, the insufficient status (as a boolean) will be returned to be true.


In [43]:
# =========================================
# Step 5: Contextual Compression (RAG Variant 2)
# - Break candidate snippets into sentences
# - Encode query + sentences with a lightweight biomed encoder
# - Score by cosine similarity, keep top sentences above tau
# - Return compressed context (or insufficient)
# =========================================

# Lightweight encoder for compression
_comp_enc = SentenceTransformer(CFG["dense_embedder"], device=CFG["device"])

def compress(q: str, docs_kept: List[Document], max_sents: int, tau: float):
    candidates, cand_ids = [], []
    for d in docs_kept:
        # TODO: split snippet into sentences using sent_tokenize_quick on page content
        sents = sent_tokenize_quick(d.page_content)
        for s in sents:
            if s:
                candidates.append(s)
                cand_ids.append(d.metadata.get("doc_id", "NA"))
    if not candidates:
        return [], [], True    # no sentences → immediate insufficient

    # TODO: encode (.encode from SentenceTransformer above) query and candidate sentences, set normalize_embeddings=True
    q_emb = _comp_enc.encode([q], normalize_embeddings=True)
    s_emb = _comp_enc.encode(candidates, normalize_embeddings=True)

    # Compute similarity scores (cosine = dot product since normalized)
    sims = (q_emb @ s_emb.T).ravel()

    # TODO: sort indices by score in descending order
    idx = sims.argsort()[::-1]

    picked_sents, picked_ids = [], []
    for i in idx:
        # TODO: stop if score below tau
        if sims[i] < tau:
            break

        # TODO: add the candidate sentence and its doc_id
        picked_sents.append(candidates[i])
        picked_ids.append(cand_ids[i])

        # TODO: stop if we’ve reached max_sents
        if len(picked_sents) >= max_sents:
            break

    # TODO: flag if no sentences were picked
    insufficient = len(picked_sents) == 0

    return picked_sents, picked_ids, insufficient

### **ANSWER GENERATION (8 PTS)**

6. Look over the given code for step 6 and run it.

    - **[DISCSSION]** Why is it important to include the “I don’t know” option in a retrieval-augmented QA system, especially for biomedical text? (1 pts)
    
    - **[DISCSSION]** What risks arise if we let the model freely use outside knowledge rather than constraining it to the given context? (1 pt)

7. **[CODE]** Implement the adaptive retry policy (4 pts)

8. **[DISCUSSION]** Run this visualization on your 5-question test set. Compare the reranker outputs with the baseline predictions and gold answers. In 1–2 sentences, summarize what you see: did reranking help reduce wrong answers, increase IDKs, or improve correctness overall compared to the baseline? (2 pts)

We now ask the LLM to answer using only the compressed context. The prompt enforces abstention (“I don’t know”) when evidence is weak. We also:

- Token-truncate context to a fixed budget to avoid overflow,

- Sanitize echoes (remove prompt fragments from outputs),

- Detect refusals as I don’t know / insufficient evidence.

In [44]:
# =========================================
# Step 6: ANSWER GENERATION
# - Truncate context to a fixed token budget
# - Force “use ONLY the context” + abstention policy
# - Clean echoed prompt text from the model output
# =========================================

tokenizer = AutoTokenizer.from_pretrained(CFG["tokenizer"])

def token_truncate(text: str, tokenizer, max_tokens: int) -> str:
    ids = tokenizer.encode(text, add_special_tokens=False)
    return tokenizer.decode(ids[:max_tokens])

def clean_answer_echo(ans: str, q: str) -> str:
    t = str(ans).strip()
    if "Answer:" in t:
        t = t.split("Answer:", 1)[-1].strip()
    t = re.sub(r"^Answer ONLY.*?Answer:\s*", "", t, flags=re.IGNORECASE|re.DOTALL)
    t = re.sub(r"Question:.*?Context:.*?Answer:\s*", "", t, flags=re.IGNORECASE|re.DOTALL)
    t = re.sub(rf"^{re.escape(q)}[\s\?]*", "", t, flags=re.IGNORECASE).strip()
    t = re.sub(r"\s+", " ", t).strip()
    return t

def is_refusal(txt: str) -> bool:
    t = str(txt).lower().strip()
    return (t.startswith("i don't know")
            or "insufficient evidence" in t
            or "cannot answer" in t
            or t == "")

def answer_with_context(q: str, sents: List[str], ids_used: List[Any], max_ctx_tokens: int):
    ctx = "\n".join(s.strip() for s in sents) if sents else ""
    if ctx:
        ctx = token_truncate(ctx, tokenizer, max_ctx_tokens)
    prompt = (
        "Use ONLY the context to answer concisely. "
        "If the answer is unclear or missing, say \"I don't know.\""
        f"\n\nQuestion: {q}\n\nContext:\n{ctx or '(EMPTY)'}\n\nAnswer:"
    )
    out = llm.invoke(prompt)
    text = getattr(out, "content", out)
    pred = clean_answer_echo(text, q)
    return pred, ids_used


DISCUSSION (Step 6 - Importance of “I don't know”)

Biomedical QA is high-risk because incorrect answers can be harmful.  
Including “I don’t know” forces the model to avoid hallucinating when evidence is missing.  
This improves trustworthiness.

If we allow the model to use outside knowledge:  
- It may hallucinate incorrect medical facts  
- It may answer from general world knowledge instead of retrieved evidence  
- It becomes impossible to verify source support

Constraining it to context ensures factual grounding.


At this point, we’ve built all the individual pieces: hybrid retrieval, reranking, compression, and answer generation. The adaptive retry block is where we stitch them together into a full QA pipeline.

The process is as follows:

- First pass: For each question, we retrieve, rerank, compress, and generate an answer.

- Check refusal: If the model answers “I don’t know” (or equivalent), we don’t give up immediately.

- Retry with relaxed parameters:

    + Rerun reranking with a wider candidate pool (`k_retry`).

    + Rerun compression with looser constraints (`tau_retry`, `max_sents_retry`).

    + Try answering again with this expanded evidence set.

- Accept better answer if found: If the retry yields a non-refusal answer, we replace the original refusal. Otherwise, we keep the initial “I don’t know.”

This design allows the pipeline to be somewhat careful first, but still opportunistic on retry when there might be borderline evidence worth considering.

In [45]:
# =========================================
# Step 7: Adaptive Retry (policy + instrumentation)
# - Run hybrid retrieve → rerank → compress → answer
# - If refusal ("I don't know"), retry ONCE with a relaxed policy
# =========================================

def qa_pipeline(questions: List[str],
                dense_retriever,
                bm25,
                reranker: CrossEncoder,
                k_final: int,
                k_retry: int,
                tau: float,
                tau_retry: float,
                max_sents: int,
                max_sents_retry: int,
                max_ctx_tokens: int):
    preds, ids = [], []
    for q in questions:
        # For 554 maybe make them pierce these together themselves

        # S3.2 Hybrid Retrieve
        cands = hybrid_candidates(q, dense_retriever, bm25, CFG["bm25_k"])

        # S4: Rerank
        kept  = rerank_snippets(q, cands, reranker, k_final)

        # S5: Compress
        sents, ids_used, insufficient = compress(q, kept, max_sents=max_sents, tau=tau)
        if insufficient:
            preds.append(CFG["idk_str"]); ids.append([]); continue

        # S6: Answer
        pred, used = answer_with_context(q, sents, ids_used, max_ctx_tokens=max_ctx_tokens)

        if is_refusal(pred):

            # TODO: rerun reranking with a wider k (hint: use k_retry instead of k_final)
            kept_retry = rerank_snippets(q, cands, reranker, k_retry)

            # TODO: rerun compression with relaxed (_retry) parameters
            sents_retry, ids_retry, insufficient2 = compress(
                q, kept_retry,
                max_sents=max_sents_retry,
                tau=tau_retry
            )
            if not insufficient2:
                # TODO: call answer_with_context again using the retry sentences
                pred2, used2 = answer_with_context(
                    q, sents_retry, ids_retry,
                    max_ctx_tokens=max_ctx_tokens
                )

                # TODO: if retry answer is not a refusal, overwrite pred/used with pred/used2
                if not is_refusal(pred2):
                    pred, used = pred2, used2

        preds.append(pred); ids.append(used)
    return preds, ids

# Hybrid QA pipeline using helpers
preds_pipeline, ids_pipeline = qa_pipeline(
    questions=questions5,
    dense_retriever=dense_retriever,
    bm25=bm25,
    reranker=reranker,
    k_final=CFG["rerank_k"],         # Step 4: rerank_snippets
    k_retry=CFG["rerank_k_retry"],   # Step 7: adaptive retry
    tau=CFG["tau"],                  # Step 5: compress threshold
    tau_retry=CFG["tau_retry"],      # Step 7: retry threshold
    max_sents=CFG["max_sents"],      # Step 5: compression sentence cap
    max_sents_retry=CFG["max_sents_retry"],
    max_ctx_tokens=CFG["max_ctx_tokens"],  # Step 6: answer_with_context token budget
)

In [46]:
# =========================================
# Step 8: Render
# - Visualize baseline vs reranker outputs side by side
# - Include question, gold answer, predictions, and supporting doc IDs
# - This helps us compare retrieval+QA variants directly
# =========================================

from IPython.display import display, HTML
import html

def render_baseline_vs_variant(questions, base_preds, variant_preds, golds,
                               base_ids, variant_ids):
    cards = []
    for i, (q, bpred, vpred, g, bids, vids) in enumerate(
        zip(questions, base_preds, variant_preds, golds, base_ids, variant_ids)
    ):
        cards.append(f"""
        <div style="border:1px solid #ddd;border-radius:10px;padding:12px;margin:10px 0;">
          <div style="color:#666">Example {i+1}</div>

          <div style="font-weight:700;margin-top:4px;">Question</div>
          <div>{html.escape(q)}</div>

          <div style="display:flex;gap:16px;margin-top:10px;">
            <div style="flex:1;">
              <div style="font-weight:700;">Baseline Prediction</div>
              <div>{html.escape(str(bpred)) if bpred else "<i>(empty)</i>"}</div>
            </div>
            <div style="flex:1;">
              <div style="font-weight:700;">Pipeline Prediction</div>
              <div>{html.escape(str(vpred)) if vpred else "<i>(empty)</i>"}</div>
            </div>
            <div style="flex:1;">
              <div style="font-weight:700;">Gold</div>
              <div>{html.escape(str(g))}</div>
            </div>
          </div>

          <div style="font-weight:700;margin-top:10px;">Doc IDs</div>
          <div style="font-size:13px;"><b>Baseline:</b> {html.escape(str(bids))}</div>
          <div style="font-size:13px;"><b>Pipeline:</b> {html.escape(str(vids))}</div>
        </div>
        """)
    display(HTML("".join(cards)))

# Render comparison
render_baseline_vs_variant(
    questions5,
    rag_predicted_answers,   # baseline predictions
    preds_pipeline,          # full pipeline (hybrid + rerank + compress + retry + answer)
    answers5,                # gold labels
    retrieved_ids,           # baseline doc IDs
    ids_pipeline             # pipeline doc IDs
)

DISCUSSION (Step 8 - Comparison Summary)

From the visualization, I can clearly see that the hybrid pipeline (hybrid retrieval + rerank + compression + retry) behaves more accurately and more safely compared to the baseline.  
A few noticeable differences:

- The pipeline reduces wrong or overconfident answers.  
- In cases where evidence is weak, the pipeline answers “I don't know” instead of hallucinating.  
- When relevant evidence exists, the pipeline usually produces a cleaner, more concise answer aligned with the gold label.

Overall, reranking + compression improves the relevance of retrieved context, which helps the LLM avoid hallucination and stay closer to the ground-truth biomedical answer.


### **EXTRA CREDIT: LLM-AS-A-JUDGE EVALUATION (5 PTS)**

9. **[EXPERIMENT]** Paste your results into a sophisticated chatbot, such as ChatGPT on GPT-5, and have it evaluate the results of the outputs of the baseline vs. our advanced pipeline. You should come up with a prompt that asks it to evaluate the results with a scoring system. Justify the validity of your scoring system (hint: remember we favor IDK over wrong answers), and share screenshots or a shareable link to the chatbot. (2 pts)

10. Depending on the results of the experiment in Step 9 (3 pts):

    - **[DISCUSSION]** If your results from the advanced pipeline are better, note the flaws and discuss potential improvements to the pipeline to enhance the results even further.

    - **[EXPERIMENT]** If your results from the advanced pipeline are worse, what might've gone wrong? Go back and adjust certain parameters and do a few more tests before arriving a conclusion. Note that the default ones given worked when creating this assignment, but there may be uncertainties due to randomness or differences in evaluation techniques.