In [1]:
# Installation and Imports

!pip install -q datasets sentence-transformers transformers faiss-cpu numpy pandas

import re
from typing import List, Optional

import faiss
import numpy as np
from datasets import load_dataset
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

print("All packages imported successfully")

All packages imported successfully


In [2]:
# =============================================================================
# 1. Configuration
# =============================================================================

RAG_CONFIG = dict(
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    vector_dim=384,          # width of the embedding model's output vectors
    chunk_size=600,          # characters per chunk
    batch_size=64,           # batch size used when encoding chunks
    initial_retrieval=30,    # FAISS hits per query variant (and cap on the merged pool)
    final_top_k=5,           # chunks kept after cross-encoder reranking
    query_variations=2,      # LLM rewrite templates tried per query
    context_budget=2000,     # max characters of assembled context
)

_config_rule = "=" * 80
print(f"\n{_config_rule}")
print("CONFIGURATION")
print(_config_rule)
print("\n".join(f"  {key}: {value}" for key, value in RAG_CONFIG.items()))
print(_config_rule)



CONFIGURATION
  embedding_model: sentence-transformers/all-MiniLM-L6-v2
  vector_dim: 384
  chunk_size: 600
  batch_size: 64
  initial_retrieval: 30
  final_top_k: 5
  query_variations: 2
  context_budget: 2000


In [3]:
# =============================================================================
# 2. Load Data
# =============================================================================

print("\nLoading dataset...")

# Both the passage corpus and the QA pairs live in the same Hugging Face repo.
DATASET_REPO = "rag-datasets/rag-mini-wikipedia"
corpus = load_dataset(DATASET_REPO, "text-corpus")
qa_set = load_dataset(DATASET_REPO, "question-answer")

passages = corpus["passages"]
test_questions = qa_set["test"]

# Limit for experimentation: keep only the first MAX_DOCS passages.
# NOTE(review): test_questions is not filtered to match this cut, so some
# questions may concern passages dropped here — confirm before judging recall.
MAX_DOCS = 1000
n_docs_kept = min(MAX_DOCS, len(passages))
passages = passages.select(range(n_docs_kept))

print(f"Working with {len(passages)} documents")


Loading dataset...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Working with 1000 documents


In [4]:
# =============================================================================
# 3. Chunking
# =============================================================================

def split_into_chunks(text: str, size: int, overlap: int = 0) -> List[str]:
    """
    Split text into fixed-size character chunks.

    Args:
        text: Text to split. Empty text (or None) yields an empty list.
        size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks. The
            default of 0 gives plain back-to-back slices. Must satisfy
            0 <= overlap < size.

    Returns:
        Chunks in document order. Every chunk except possibly the last has
        exactly ``size`` characters. No trailing chunk is emitted once the
        previous chunk already reaches the end of the text.

    Raises:
        ValueError: If ``size`` is not positive or ``overlap`` is out of range.
    """
    if not text:
        return []
    if size <= 0:
        raise ValueError(f"size must be positive, got {size}")
    if not 0 <= overlap < size:
        raise ValueError(f"overlap must satisfy 0 <= overlap < size, got {overlap}")

    step = size - overlap
    chunks = []
    start = 0
    while True:
        chunks.append(text[start:start + size])
        # Stop once this chunk reaches the end. Otherwise, with overlap > 0,
        # the next chunk would be a pure suffix of the current one.
        if start + size >= len(text):
            break
        start += step
    return chunks

# Flatten every passage into fixed-size chunks. Each chunk is tagged with an
# id of the form "<doc_idx>-<chunk_idx>" and with its source document index.
all_chunks = [
    {"id": f"{doc_idx}-{piece_idx}", "text": piece, "doc_id": doc_idx}
    for doc_idx, doc in enumerate(passages)
    for piece_idx, piece in enumerate(
        split_into_chunks(doc.get("passage", ""), RAG_CONFIG["chunk_size"])
    )
]

print(f"Created {len(all_chunks)} chunks")


Created 1289 chunks


In [5]:
# =============================================================================
# 4. Generate Embeddings
# =============================================================================

print("\nGenerating embeddings...")

encoder = SentenceTransformer(RAG_CONFIG["embedding_model"])

texts = [chunk["text"] for chunk in all_chunks]

# Unit-normalized float32 vectors: the inner-product FAISS index built later
# then scores by cosine similarity.
encode_options = {
    "batch_size": RAG_CONFIG["batch_size"],
    "show_progress_bar": True,
    "normalize_embeddings": True,
    "convert_to_numpy": True,
}
vectors = encoder.encode(texts, **encode_options).astype("float32")

print(f"Embeddings shape: {vectors.shape}")



Generating embeddings...


Batches:   0%|          | 0/21 [00:00<?, ?it/s]

Embeddings shape: (1289, 384)


In [6]:

# =============================================================================
# 5. Build FAISS Index
# =============================================================================

# Flat (exhaustive) inner-product index. The chunk vectors were encoded with
# normalize_embeddings=True, so inner product equals cosine similarity.
# Take the dimension from the vectors themselves. A stale config value would
# otherwise surface only as an opaque FAISS assertion inside add().
embedding_dim = vectors.shape[1]
assert embedding_dim == RAG_CONFIG["vector_dim"], (
    f"Embedding model produces {embedding_dim}-d vectors but "
    f"RAG_CONFIG['vector_dim'] is {RAG_CONFIG['vector_dim']}; update the config."
)

faiss_index = faiss.IndexFlatIP(embedding_dim)
faiss_index.add(vectors)

# Retrieval maps FAISS row i back to all_chunks[i], so the two must line up.
assert faiss_index.ntotal == len(all_chunks)

print(f"FAISS index contains {faiss_index.ntotal} vectors")


FAISS index contains 1289 vectors


In [7]:
# =============================================================================
# 6. FLAN-T5 Setup
# =============================================================================

print("\nLoading Flan-T5...")

LLM_CHECKPOINT = "google/flan-t5-base"
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_CHECKPOINT)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_CHECKPOINT)
# Seq2seq pipeline shared by query rewriting and final answer generation.
text_gen = pipeline(task="text2text-generation", model=llm_model, tokenizer=llm_tokenizer)

def clean_text(text: str) -> str:
    """
    Normalize whitespace in ``text``.

    Strips leading/trailing whitespace and collapses every internal run of
    whitespace (spaces, tabs, newlines, ...) into a single space.
    """
    # str.split() with no separator splits on whitespace runs and drops empty
    # pieces at both ends, so re-joining with " " trims and collapses at once.
    return " ".join(text.split())

print("LLM ready")


Loading Flan-T5...


Device set to use cuda:0


LLM ready


In [8]:
# =============================================================================
# 7. Enhancement #1: Query Rewriting
# =============================================================================

def generate_query_rewrites(query: str, num_rewrites: Optional[int] = None) -> List[str]:
    """
    Generate query variations using the Flan-T5 pipeline.

    The cleaned original query is always the first entry. Up to ``num_rewrites``
    instruction templates ("Paraphrase...", "Reformulate...", "Express...") are
    sent to ``text_gen``. Usable outputs are appended, and the list is then
    deduplicated case-insensitively (the first spelling seen wins).

    Args:
        query: Raw user query.
        num_rewrites: How many rewrite templates to try. At most 3 exist;
            values <= 0 mean no rewrites. Defaults to
            ``RAG_CONFIG["query_variations"]``, read at call time so that
            editing the config cell takes effect without re-running this cell.

    Returns:
        List of unique queries, original first.
    """
    if num_rewrites is None:
        num_rewrites = RAG_CONFIG["query_variations"]

    cleaned = clean_text(query)

    rewrite_templates = [
        f"Paraphrase this question: {cleaned}",
        f"Reformulate this query: {cleaned}",
        f"Express this differently: {cleaned}"
    ]

    all_queries = [cleaned]  # Include original

    # max(0, ...) stops a negative count from slicing templates off the end.
    for template in rewrite_templates[:max(0, num_rewrites)]:
        try:
            output = text_gen(
                template,
                max_new_tokens=64,
                num_beams=2,
                do_sample=False
            )[0]["generated_text"]
        except Exception as exc:
            # Rewriting is best-effort: one failed generation means one fewer
            # variant. Catching Exception (not a bare except) lets
            # KeyboardInterrupt/SystemExit through and reports the failure.
            print(f"  Query rewrite failed ({type(exc).__name__}: {exc}); skipping")
            continue

        # Compare the cleaned form so whitespace-only differences count as echoes.
        rewrite = clean_text(output)
        # Drop very short outputs (<= 10 chars) and echoes of the input.
        if len(rewrite) > 10 and rewrite != cleaned:
            all_queries.append(rewrite)

    # Case-insensitive dedup that keeps the first spelling seen, in order.
    unique_by_key = {}
    for q in all_queries:
        unique_by_key.setdefault(q.lower(), q)
    return list(unique_by_key.values())


In [9]:
# =============================================================================
# 8. Initial Retrieval
# =============================================================================

def initial_retrieval(query: str, top_k: Optional[int] = None) -> List[dict]:
    """
    Retrieve candidate chunks for ``query`` using query rewriting + FAISS.

    Every variation from ``generate_query_rewrites`` is searched for its
    ``top_k`` nearest chunks. The hits are merged (each chunk keeps its best
    score), sorted by score (highest first) and truncated to ``top_k``.

    Args:
        query: Raw user query.
        top_k: Hits per variation and size of the merged result. Defaults to
            ``RAG_CONFIG["initial_retrieval"]``, read at call time.

    Returns:
        Dicts with keys ``"index"`` (row in ``all_chunks`` / the FAISS index),
        ``"score"`` (inner product as a float), ``"text"`` and ``"chunk_id"``.
    """
    if top_k is None:
        top_k = RAG_CONFIG["initial_retrieval"]

    query_list = generate_query_rewrites(query)

    print(f"\nQuery variations ({len(query_list)}):")
    for i, q in enumerate(query_list, 1):
        print(f"  {i}. {q}")

    # Encode all query variations
    query_vecs = encoder.encode(
        query_list,
        normalize_embeddings=True,
        convert_to_numpy=True
    ).astype("float32")

    # One batched search; row r of the results belongs to query_list[r].
    scores, indices = faiss_index.search(query_vecs, top_k)

    best_hits = {}
    for idx, score in zip(indices.ravel().tolist(), scores.ravel().tolist()):
        # FAISS pads results with label -1 when fewer than top_k vectors are
        # available; all_chunks[-1] would then silently return the wrong chunk.
        if idx < 0:
            continue
        # Keep highest score for each chunk
        if idx not in best_hits or score > best_hits[idx]["score"]:
            best_hits[idx] = {
                "index": idx,
                "score": score,
                "text": all_chunks[idx]["text"],
                "chunk_id": all_chunks[idx]["id"]
            }

    candidates = sorted(best_hits.values(), key=lambda hit: hit["score"], reverse=True)

    # Reports the merged pool size before truncation to top_k.
    print(f"Retrieved {len(candidates)} unique candidates")

    return candidates[:top_k]


In [10]:
# =============================================================================
# 9. Enhancement #2: Cross-Encoder Reranking
# =============================================================================

print("\nLoading CrossEncoder for reranking...")

RERANKER_CHECKPOINT = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# max_length caps the tokenized (query, passage) pair fed to the cross-encoder.
reranker = CrossEncoder(RERANKER_CHECKPOINT, max_length=512)

print("Reranker ready")

def rerank_candidates(query: str, candidates: List[dict], top_k: Optional[int] = None) -> List[dict]:
    """
    Rerank retrieval candidates with the cross-encoder and keep the best ``top_k``.

    Args:
        query: Query text scored jointly with each candidate's text.
        candidates: Dicts with at least a ``"text"`` key, such as those
            returned by ``initial_retrieval``. They are not modified.
        top_k: Number of candidates to return. Defaults to
            ``RAG_CONFIG["final_top_k"]``, read at call time.

    Returns:
        Up to ``top_k`` shallow copies of the candidates. Each copy carries an
        added float ``"rerank_score"``, and the list is sorted by that score
        (highest first). An empty input gives an empty list.
    """
    if top_k is None:
        top_k = RAG_CONFIG["final_top_k"]

    if not candidates:
        return []

    # Prepare query-passage pairs
    pairs = [(query, cand["text"]) for cand in candidates]

    # Get reranking scores
    rerank_scores = reranker.predict(
        pairs,
        convert_to_numpy=True,
        show_progress_bar=False,
        batch_size=32
    )

    # Build new dicts instead of writing into the caller's candidate dicts.
    rescored = [
        {**cand, "rerank_score": float(score)}
        for cand, score in zip(candidates, rerank_scores)
    ]
    rescored.sort(key=lambda cand: cand["rerank_score"], reverse=True)

    print(f"Reranked to top {top_k}")

    return rescored[:top_k]


Loading CrossEncoder for reranking...
Reranker ready


In [11]:
# =============================================================================
# 10. Context Assembly
# =============================================================================

def build_context(chunks: List[dict], budget: Optional[int] = None, min_partial: int = 100) -> tuple:
    """
    Assemble chunks into one context string with citation markers.

    Each chunk is rendered as ``"[<chunk_id> | score: <score:.3f>]\\n<text>"``,
    and the rendered segments are joined with a blank line (``"\\n\\n"``).
    Chunks are consumed in the order given until the character budget is
    reached. The first chunk that does not fit is handled as follows:
      - if more than ``min_partial`` characters of room remain, it is
        truncated and suffixed with ``"..."``;
      - otherwise it is dropped.
    Later chunks are never considered.

    Args:
        chunks: Dicts with ``"chunk_id"`` and ``"text"`` keys and a score under
            ``"rerank_score"``. If that key is absent, ``"score"`` is used,
            and failing that 0.0.
        budget: Maximum length of the returned context string, counting
            separators and the truncation ellipsis. Defaults to
            ``RAG_CONFIG["context_budget"]``, read at call time.
        min_partial: Room (in characters) that must remain before a truncated
            final chunk is included.

    Returns:
        ``(context, citations)``, where ``context`` is the joined string and
        ``citations`` is a list of ``{"id": chunk_id, "score": score}`` dicts,
        one per chunk that fully or partially made it into ``context``.
    """
    if budget is None:
        budget = RAG_CONFIG["context_budget"]

    separator = "\n\n"
    ellipsis = "..."
    context_parts = []
    citations = []
    chars_used = 0

    for chunk in chunks:
        chunk_id = chunk["chunk_id"]
        score = chunk.get("rerank_score", chunk.get("score", 0.0))
        text = chunk["text"]

        # Add citation marker
        marker = f"[{chunk_id} | score: {score:.3f}]"
        segment = f"{marker}\n{text}"

        # The separator only sits between parts, so the first part pays none.
        sep_cost = len(separator) if context_parts else 0
        remaining = budget - chars_used - sep_cost

        if len(segment) > remaining:
            if remaining > min_partial:
                # Reserve room for the ellipsis so the result stays within budget.
                context_parts.append(segment[:remaining - len(ellipsis)] + ellipsis)
                citations.append({"id": chunk_id, "score": score})
            break

        context_parts.append(segment)
        citations.append({"id": chunk_id, "score": score})
        chars_used += sep_cost + len(segment)

    full_context = separator.join(context_parts)
    return full_context, citations


In [12]:
# =============================================================================
# 11. Answer Generation
# =============================================================================

def build_prompt(context: str, question: str) -> str:
    """
    Build the persona-style prompt (the best-performing variant from Step 3).

    The model is told to answer only from ``context`` and to reply
    "I don't know" when the context does not contain the answer.
    """
    prompt_lines = [
        "You are a subject matter expert. Use only the context.",
        "If the answer is not in the context, say 'I don't know'. Be direct.",
        "",
        f"Context: {context}",
        f"Question: {question}",
        "Answer:",
    ]
    return "\n".join(prompt_lines)

def generate_rag_answer(question: str) -> tuple:
    """
    Answer ``question`` with the full RAG pipeline.

    The stages run in order:
      - query rewriting plus dense retrieval (``initial_retrieval``);
      - cross-encoder reranking (``rerank_candidates``);
      - budgeted context assembly (``build_context``);
      - deterministic (``do_sample=False``) Flan-T5 generation.

    Returns:
        ``(answer, refs)``: the stripped generated answer and the citation
        list produced by ``build_context``.
    """
    banner = "=" * 80
    print(f"\n{banner}")
    print(f"Question: {question}")
    print(banner)

    # Retrieve a wide candidate pool, then let the cross-encoder pick the final few.
    reranked = rerank_candidates(question, initial_retrieval(question))
    context, refs = build_context(reranked)

    print(f"\nContext assembled: {len(context)} chars from {len(refs)} chunks")

    generations = text_gen(
        build_prompt(context, question),
        max_new_tokens=256,
        do_sample=False
    )
    answer = generations[0]["generated_text"].strip()
    return answer, refs

In [14]:
# =============================================================================
# 12. Test with 2 questions
# =============================================================================

THIN_RULE = "─" * 80
THICK_RULE = "=" * 80

print(f"\n{THIN_RULE}")
print("TESTING RAG SYSTEM - MULTIPLE QUERIES")
print(THIN_RULE)

# Select first two test questions
test_questions_list = [test_questions[i]["question"] for i in range(2)]

for q_num, test_q in enumerate(test_questions_list, start=1):
    print(f"\n{THICK_RULE}\nTEST QUERY #{q_num}\n{THICK_RULE}")
    print(f"Query: {test_q}")

    answer, citations = generate_rag_answer(test_q)

    print(f"\n{THIN_RULE}\nRESULTS\n{THIN_RULE}")
    print(f"\nAnswer: {answer}")

    # Fixed-width box table: rank, chunk id, relevance (rerank score).
    print("\nSource Citations:")
    print("┌──────┬─────────────┬──────────────┐")
    print("│ Rank │  Chunk ID   │   Relevance  │")
    print("├──────┼─────────────┼──────────────┤")
    for idx, cite in enumerate(citations, start=1):
        chunk_id = cite.get('id', 'N/A')
        relevance = cite.get('score', 0.0)
        print(f"│  {idx:<3} │ {chunk_id:<11} │ {relevance:>12.4f} │")
    print("└──────┴─────────────┴──────────────┘")

    print(f"\n{THIN_RULE}\n")

print(f"\n{THIN_RULE}")
print("ALL TESTS COMPLETED")
print(THIN_RULE)

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset



────────────────────────────────────────────────────────────────────────────────
TESTING RAG SYSTEM - MULTIPLE QUERIES
────────────────────────────────────────────────────────────────────────────────

TEST QUERY #1
Query: Was Abraham Lincoln the sixteenth President of the United States?

Question: Was Abraham Lincoln the sixteenth President of the United States?

Query variations (2):
  1. Was Abraham Lincoln the sixteenth President of the United States?
  2. Abraham Lincoln was the sixteenth president of the United States.
Retrieved 35 unique candidates
Reranked to top 5

Context assembled: 1936 chars from 4 chunks

────────────────────────────────────────────────────────────────────────────────
RESULTS
────────────────────────────────────────────────────────────────────────────────

Answer: Abraham Lincoln (February 12, 1809 â April 15, 1865) was the sixteenth President of the United States, serving from March 4, 1861 until his assassination.

Source Citations:
┌──────┬─────────────