In [None]:
# 🧠 RAGBench Evaluation Notebook (Qwen3 RAG vs No-RAG)
# ----------------------------------------------------
# This notebook compares performance of Qwen3-based models with and without RAG.
# Evaluation metrics: Word-level F1 and Exact Match (EM).

# =====================================================
# 📦 1. Setup and Installation
# =====================================================
!pip install -q transformers accelerate bitsandbytes datasets faiss-cpu sentence_transformers tqdm textwrap3
!pip install -q git+https://github.com/huggingface/transformers.git

import os
import json
import time
from tqdm.auto import tqdm
import numpy as np
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
import re
import pandas as pd
from google.colab import userdata

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.4/41.4 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m133.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.3/564.3 kB[0m [31m46.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for transformers (pyproject.toml) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sentence-transformers 5.1.1 requires transformers<5.0.0,>=4.41.0, but you have transformers 5.0.0.dev0 which is incompatible.[0m[31m
[0m

In [None]:
# =====================================================
# ⚙️ 2. Configuration
# =====================================================

HF_TOKEN = userdata.get('HF_Token')
if HF_TOKEN is None:
  print("Warning: HF_TOKEN not set. You may hit rate limits or fail to download gated models.")

EMBED_MODEL = "Qwen/Qwen3-Embedding-0.6B"
GEN_MODEL = "Qwen/Qwen3-4B-Instruct-2507"
DATASET_NAME = "galileo-ai/ragbench"
SPLIT = "test"
MAX_DOCS = 2000
CHUNK_SIZE = 512
CHUNK_OVERLAP = 64
BATCH_SIZE_EMBED = 32
TOP_K = 5
MAX_GEN_TOKENS = 200

In [None]:
# =====================================================
# 📚 3. Load Dataset (RAGBench)
# =====================================================
print("Loading RAGBench dataset (may take a while)...")
dset = load_dataset(DATASET_NAME, "covidqa", split='test')
print("Total samples in split:", len(dset))

if MAX_DOCS and len(dset) > MAX_DOCS:
  dset = dset.select(range(MAX_DOCS))
  print(f"Dataset truncated to first {MAX_DOCS} samples for demo.")

print(dset[0])

Loading RAGBench dataset (may take a while)...
Total samples in split: 246
{'id': '1421', 'question': 'Which viruses may not cause prolonged inflammation due to strong induction of antiviral clearance?', 'documents': ['Title: Type I Interferon Receptor Deficiency in Dendritic Cells Facilitates Systemic Murine Norovirus Persistence Despite Enhanced Adaptive Immunity\nPassage: successful treatment for HCV serves to circumvent the viral inhibition of IFN induction. Thus, HCV may be an example of a medically relevant persistent viral infection that persists due, in part, to loss of innate immune function. Persistence of other continuously replicating RNA viruses, such as chikungunya, measles, polyomavirus, may be similarly due to ineffective innate responses.', 'Title: Type I Interferon Response Is Delayed in Human Astrovirus Infections\nPassage: Results suggest that HAstV infection is not able to disrupt the innate immune sensing pathway induced by polyI:C . Only a previous infection with

In [None]:
# =====================================================
# ✂️ 4. Chunking and Corpus Building
# =====================================================
def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
  words = text.split()
  chunks = []
  i = 0
  while i < len(words):
    chunk = " ".join(words[i:i+chunk_size])
    chunks.append(chunk)
    i += chunk_size - overlap
  return chunks

corpus_texts, corpus_meta = [], []
print("Building chunked corpus from dataset...")
for idx, sample in enumerate(tqdm(dset)):
  docs = sample.get("documents") or sample.get("contexts") or []
  if isinstance(docs, str):
    docs = [docs]
  for doc_id, doc in enumerate(docs):
    if not doc:
      continue
    for cidx, ch in enumerate(chunk_text(doc)):
      corpus_meta.append({
          "sample_idx": int(idx),
          "doc_id": int(doc_id),
          "chunk_id": int(cidx),
          "domain": sample.get("domain", None)
          })
      corpus_texts.append(ch)

print(f"Total corpus chunks: {len(corpus_texts)}")


Building chunked corpus from dataset...


  0%|          | 0/246 [00:00<?, ?it/s]

Total corpus chunks: 984


In [None]:
# =====================================================
# 🔢 5. Embedding and FAISS Index
# =====================================================
print("Loading embedding model: ", EMBED_MODEL)
embed_model = SentenceTransformer(EMBED_MODEL, device='cuda' if torch.cuda.is_available() else 'cpu')

all_embs = []
for i in tqdm(range(0, len(corpus_texts), BATCH_SIZE_EMBED)):
  batch = corpus_texts[i:i+BATCH_SIZE_EMBED]
  embs = embed_model.encode(batch, show_progress_bar=False, convert_to_numpy=True)
  all_embs.append(embs)
all_embs = np.vstack(all_embs).astype('float32')

np.save("corpus_embeddings.npy", all_embs)
with open("corpus_meta.json", "w") as f:
  json.dump(corpus_meta, f)

d = all_embs.shape[1]
index = faiss.IndexHNSWFlat(d, 32)
index.hnsw.efConstruction = 200
index.add(all_embs)
faiss.write_index(index, "faiss_qwen_hnsw.index")
print("FAISS index built and saved.")

Loading embedding model:  Qwen/Qwen3-Embedding-0.6B


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/215 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.19G [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

  0%|          | 0/31 [00:00<?, ?it/s]

FAISS index built and saved.


In [None]:
# =====================================================
# 🧩 6. Load Generator Model (Qwen3-4B-Instruct)
# =====================================================
print("Loading generator model (may be large). Use device_map='auto' for Colab GPU if available.")

tok = AutoTokenizer.from_pretrained(GEN_MODEL, use_fast=True)
try:
  gen_model = AutoModelForCausalLM.from_pretrained(
  GEN_MODEL,
  device_map='auto',
  load_in_8bit=True,
  trust_remote_code=True,
  )
except Exception as e:
  print("8-bit load failed, fallback to float16:", e)
  gen_model = AutoModelForCausalLM.from_pretrained(
  GEN_MODEL,
  device_map='auto',
  torch_dtype=torch.float16,
  trust_remote_code=True,
  )

with open("corpus_meta.json", "r") as f:
  corpus_meta = json.load(f)

Loading generator model (may be large). Use device_map='auto' for Colab GPU if available.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/3.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/99.6M [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


8-bit load failed, fallback to float16: Qwen3ForCausalLM.__init__() got an unexpected keyword argument 'load_in_8bit'


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/238 [00:00<?, ?B/s]

In [None]:
# =====================================================
# 🔍 7. Retrieval and Generation Functions
# =====================================================
def retrieve_top_k(query, k=TOP_K):
  q_emb = embed_model.encode([query], convert_to_numpy=True).astype('float32')
  D, I = index.search(q_emb, k)
  return [(i, corpus_texts[i], corpus_meta[i]) for i in I[0].tolist()], D[0]

def gen_answer_no_rag(query, max_new_tokens=MAX_GEN_TOKENS):
  prompt = f"Answer the following question concisely. If you don't know, say 'I don't know.'\nQuestion: {query}\nAnswer:"
  inputs = tok(prompt, return_tensors='pt').to(gen_model.device)
  out = gen_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
  return tok.decode(out[0], skip_special_tokens=True)

def gen_answer_with_rag(query, k=TOP_K, max_new_tokens=MAX_GEN_TOKENS):
  retrieved, D = retrieve_top_k(query, k=k)
  contexts = "\n\n".join([f"Context {i+1}: {t}" for i, t, m in retrieved])
  prompt = f"Use the retrieved contexts to answer accurately. If not found, say 'I don't know.'\n{contexts}\nQuestion: {query}\nAnswer:"
  inputs = tok(prompt, return_tensors='pt', truncation=True, max_length=4096).to(gen_model.device)
  out = gen_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
  return tok.decode(out[0], skip_special_tokens=True), retrieved, D

In [None]:
# =====================================================
# 🧾 8. Evaluation Metrics (Word-level F1 + Exact Match)
# =====================================================
def simple_f1(pred, ref):
  pred_tokens = set(pred.lower().split())
  ref_tokens = set(ref.lower().split())
  common = len(pred_tokens & ref_tokens)
  if len(pred_tokens) == 0 or len(ref_tokens) == 0:
    return 0.0
  precision = common / len(pred_tokens)
  recall = common / len(ref_tokens)
  if precision + recall == 0:
    return 0.0
  return 2 * (precision * recall) / (precision + recall)

def normalize_text(s):
  s = re.sub(r"[^a-z0-9\s]", "", s.lower().strip())
  return s

def exact_match(pred, ref):
  return 1.0 if normalize_text(pred) == normalize_text(ref) else 0.0


In [None]:
# =====================================================
# 🧪 9. Run Evaluation (RAG vs No-RAG)
# =====================================================
N_EVAL = min(50, len(dset))
results = []
print(f"Evaluating on {N_EVAL} samples (word-level F1 + EM). This will take time.")

for i in tqdm(range(N_EVAL)):
  sample = dset[i]
  query = sample['question']
  gold = sample.get('response')

  if not gold.strip():
    continue

  pred_no = gen_answer_no_rag(query)
  pred_rag, _, _ = gen_answer_with_rag(query)

  f1_no = simple_f1(pred_no, gold)
  f1_r = simple_f1(pred_rag, gold)
  em_no = exact_match(pred_no, gold)
  em_r = exact_match(pred_rag, gold)

  results.append({
      'idx': i,
      'domain': sample.get('domain', 'NA'),
      'f1_no': f1_no,
      'f1_rag': f1_r,
      'em_no': em_no,
      'em_rag': em_r
      })


Evaluating on 50 samples (word-level F1 + EM). This will take time.


  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# =====================================================
# 📈 11. Recall@K Check
# =====================================================
recall_counts = 0
for i in range(N_EVAL):
  sample = dset[i]
  query = sample['question']
  gold_sample_idx = i
  retrieved, D = retrieve_top_k(query, k=TOP_K)
  retrieved_sample_idxs = [m['sample_idx'] for (_, _, m) in retrieved]
  if gold_sample_idx in retrieved_sample_idxs:
    recall_counts += 1

print(f"Recall@{TOP_K}:", recall_counts / N_EVAL)
print('✅ Notebook complete. Save FAISS index, embeddings, and eval results for further analysis.')

Recall@5: 1.0
✅ Notebook complete. Save FAISS index, embeddings, and eval results for further analysis.


In [None]:
# =====================================================
# 🧮 12. Report Summary (F1 Scores)
# =====================================================

f1_no_avg = np.mean([r['f1_no'] for r in results])
f1_rag_avg = np.mean([r['f1_rag'] for r in results])
em_no_avg = np.mean([r['em_no'] for r in results])
em_rag_avg = np.mean([r['em_rag'] for r in results])

print("========== RAGBench Evaluation Summary ==========")
print(f"Average Word-level F1 (No-RAG): {f1_no_avg:.4f}")
print(f"Average Word-level F1 (RAG):    {f1_rag_avg:.4f}")
print(f"Average Exact Match (No-RAG):   {em_no_avg:.4f}")
print(f"Average Exact Match (RAG):      {em_rag_avg:.4f}")

improvement = ((f1_rag_avg - f1_no_avg) / f1_no_avg * 100) if f1_no_avg > 0 else 0
print(f"F1 Improvement with RAG: {improvement:.2f}%")

Average Word-level F1 (No-RAG): 0.2070
Average Word-level F1 (RAG):    0.1707
Average Exact Match (No-RAG):   0.0000
Average Exact Match (RAG):      0.0000
F1 Improvement with RAG: -17.52%
