In [1]:
# REFRAG Prototype Colab Notebook
# ==================================
# Proof-of-concept for REFRAG (Rethinking RAG-based Decoding).
# Run all cells in sequence.

# --- Install dependencies ---
!pip install sentence-transformers faiss-cpu transformers torch tqdm --quiet

# --- Imports ---
import torch, faiss, math, random
import torch.nn as nn
import torch.optim as optim
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import List, Dict, Tuple

# --- Config ---
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "gpt2"
CHUNK_TOKENS = 64
TOP_K = 12
INITIAL_EXPAND = 2
MAX_EXPANSIONS = 6
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Utilities: chunk text ---
def chunk_text(text: str, tokens_per_chunk: int = CHUNK_TOKENS) -> List[str]:
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + tokens_per_chunk]))
        i += tokens_per_chunk
    return chunks

# --- Embedding Index ---
class EmbeddingIndex:
    def __init__(self, embed_model_name=EMBED_MODEL):
        self.encoder = SentenceTransformer(embed_model_name)
        self.dim = self.encoder.get_sentence_embedding_dimension()
        self.index = None
        self.metadata = []
        self.embeddings = None

    def build_from_docs(self, docs: Dict[str, str]):
        import numpy as np
        all_chunks, meta = [], []
        for doc_id, text in docs.items():
            chunks = chunk_text(text)
            all_chunks.extend(chunks)
            for c in chunks:
                meta.append({"text": c, "doc_id": doc_id})
        embs = self.encoder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
        self.embeddings = embs
        self.index = faiss.IndexFlatIP(self.dim)
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.metadata = meta

    def query(self, q_text: str, k=TOP_K) -> List[Tuple[float, Dict]]:
        import numpy as np
        q_emb = self.encoder.encode([q_text], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.index.search(q_emb, k)
        results = []
        for score, idx in zip(D[0], I[0]):
            meta = self.metadata[idx].copy()
            meta['score'] = float(score)
            results.append((float(score), meta))
        return results

# --- Policy Network ---
class PolicyNet(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim * 2 + 1, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )

    def forward(self, q_emb, chunk_emb, decoder_proxy):
        x = torch.cat([q_emb, chunk_emb, decoder_proxy.unsqueeze(-1)], dim=-1)
        return self.net(x).squeeze(-1)

# --- LLM Wrapper ---
class LLMWrapper:
    def __init__(self, model_name=LLM_MODEL, device=DEVICE):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        self.device = device

    def generate(self, prompt: str, max_new_tokens=64):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        return self.tokenizer.decode(out[0], skip_special_tokens=True)

    def next_token_entropy_proxy(self, prompt: str):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            ent = -torch.sum(probs * torch.log(probs + 1e-12), dim=-1)
            return float(ent.cpu().item())

# --- REFRAG-like Answer ---
def refrag_answer(query: str, index: EmbeddingIndex, llm: LLMWrapper, policy: PolicyNet=None):
    top = index.query(query, k=TOP_K)
    expansions = {i: m['text'] for i, (s, m) in enumerate(top[:INITIAL_EXPAND])}

    def build_prompt(expansions_map):
        expanded_texts = []
        for i, (s, m) in enumerate(top):
            if i in expansions_map:
                expanded_texts.append(f"CHUNK_{i} (expanded):\n{expansions_map[i]}\n")
            else:
                expanded_texts.append(f"CHUNK_{i} (compressed): {m['text'][:50]}...")
        ctx = "\n".join(expanded_texts)
        return f"Use the chunks below.\n\nContext:\n{ctx}\n\nQuery: {query}\nAnswer:"

    generated = ""
    expansions_done = len(expansions)

    while True:
        prompt = build_prompt(expansions) + "\n" + generated
        out = llm.generate(prompt, max_new_tokens=32)
        newly = out[len(prompt):] if out.startswith(prompt) else out
        generated += newly
        ent = llm.next_token_entropy_proxy(prompt + newly)

        if policy is not None and expansions_done < MAX_EXPANSIONS:
            import numpy as np
            q_emb = torch.tensor(index.encoder.encode([query]), dtype=torch.float32).to(DEVICE)
            candidates, cand_idx = [], []
            for i, (s, m) in enumerate(top):
                if i in expansions: continue
                cand_idx.append(i); candidates.append(m['text'])
            if candidates:
                cand_embs = torch.tensor(index.encoder.encode(candidates), dtype=torch.float32).to(DEVICE)
                q_emb_rep = q_emb.repeat(cand_embs.size(0), 1)
                ent_rep = torch.tensor([ent]*cand_embs.size(0), dtype=torch.float32).to(DEVICE)
                scores = policy(q_emb_rep, cand_embs, ent_rep)
                best_idx = torch.argmax(scores).item()
                expansions[cand_idx[best_idx]] = candidates[best_idx]
                expansions_done += 1
                if float(scores[best_idx].cpu()) < 0.1:
                    break
            else: break
        else:
            break
    return generated

# --- Demo ---
def run_demo():
    docs = {
        "doc1": "Paris is the capital of France, known for the Eiffel Tower, Louvre, art, fashion, and culture.",
        "doc2": "Python is a programming language used for machine learning, data analysis, and scripting.",
        "doc3": "Transformers are sequence models with self-attention, used in large language models."
    }
    idx = EmbeddingIndex()
    idx.build_from_docs(docs)
    llm = LLMWrapper()
    policy = PolicyNet(idx.encoder.get_sentence_embedding_dimension()).to(DEVICE)

    q = "Why is Paris culturally important?"
    print("Query:", q)
    print("\n--- Without Policy ---")
    ans1 = refrag_answer(q, idx, llm, policy=None)
    print(ans1)

    print("\n--- With Policy (untrained) ---")
    ans2 = refrag_answer(q, idx, llm, policy=policy)
    print(ans2)

run_demo()


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m45.4 MB/s[0m eta [36m0:00:00[0m
[?25h

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Query: Why is Paris culturally important?

--- Without Policy ---


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Paris is the capital of France, known for the Eiffel Tower, Louvre, art, fashion, and culture.

CHUNK_

--- With Policy (untrained) ---


Consider using tensor.detach() first. (Triggered internally at /pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:835.)
  if float(scores[best_idx].cpu()) < 0.1:
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Paris is the capital of France, known for the Eiffel Tower, Louvre, art, fashion, and culture.

CHUNK_1 (expanded):

Transformers are sequence models with self-attention, used in large language models.

CHUNK_2 (compressed): Python is a programming language used for machine ...

CHUNK_3 (compressed):

Transformers are sequence models with self-attention, used in large language models.

CHUNK_4 (compressed):

Transformers are sequence models with self-attenti...

CHUNK_5 (compressed): Transformers are sequence models with self-attenti...

CHUNK_6


In [4]:
# REFRAG Prototype Colab Notebook (Improved)
# ==================================
# This notebook is an improved Colab-ready prototype implementing REFRAG-style selective chunk expansion.
# Changes from earlier draft:
# - Robust tokenizer-based chunking (fixes truncated/duplicate chunk problems)
# - Proper pad_token handling to silence generation warnings
# - Detach tensors before converting to scalars (fixes requires_grad warning)
# - Deduplication of chunks and metadata hygiene
# - Scalable Faiss setup (HNSW/IVF options) for large cross-cut databases
# - Optional auto-summary pipeline to create compressed chunk representations
# - Better demo using a synthetic larger, cross-cut dataset

# Run all cells in sequence.

# --- Install dependencies ---
!pip install sentence-transformers faiss-cpu transformers torch tqdm datasets --quiet

# --- Imports ---
import os
import math
import random
import uuid
from typing import List, Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm

# --- Config ---
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "gpt2"
CHUNK_TOKENS = 64        # target approximate tokens per chunk
TOP_K = 16               # retrieve more candidates for larger DB
INITIAL_EXPAND = 2
MAX_EXPANSIONS = 8
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FAISS_INDEX_TYPE = "hnsw"  # options: 'flat', 'hnsw', 'ivf'

# -----------------------
# Improved tokenizer-aware chunker
# -----------------------
def chunk_text_by_tokens(text: str, tokenizer, tokens_per_chunk: int = CHUNK_TOKENS) -> List[str]:
    """Chunk text by tokenizer token count to avoid mid-token splits and get stable chunk sizes."""
    toks = tokenizer.tokenize(text)
    chunks = []
    i = 0
    while i < len(toks):
        sub = toks[i:i+tokens_per_chunk]
        chunk_text = tokenizer.convert_tokens_to_string(sub)
        chunks.append(chunk_text.strip())
        i += tokens_per_chunk
    return [c for c in chunks if c]

# -----------------------
# Embedding Index (scalable, dedup-aware)
# -----------------------
class EmbeddingIndex:
    def __init__(self, embed_model_name=EMBED_MODEL, index_type=FAISS_INDEX_TYPE):
        self.encoder = SentenceTransformer(embed_model_name)
        self.dim = self.encoder.get_sentence_embedding_dimension()
        self.index = None
        self.metadata = []  # list of dicts: {text, doc_id, chunk_id}
        self.embeddings = None
        self.index_type = index_type

    def _init_faiss(self, n_items_estimate=1000):
        if self.index_type == 'flat':
            self.index = faiss.IndexFlatIP(self.dim)
        elif self.index_type == 'hnsw':
            self.index = faiss.IndexHNSWFlat(self.dim, 32)
            self.index.hnsw.efConstruction = 200
        elif self.index_type == 'ivf':
            nlist = max(100, n_items_estimate // 10)
            quantizer = faiss.IndexFlatL2(self.dim)
            self.index = faiss.IndexIVFFlat(quantizer, self.dim, nlist, faiss.METRIC_INNER_PRODUCT)
        else:
            raise ValueError('unknown index type')

    def build_from_docs(self, docs: Dict[str, str], tokenizer):
        """docs: dict doc_id -> text. tokenizer used to chunk."""
        all_chunks = []
        meta = []
        for doc_id, text in docs.items():
            chunks = chunk_text_by_tokens(text, tokenizer, CHUNK_TOKENS)
            for i, c in enumerate(chunks):
                if c.strip() == '':
                    continue
                all_chunks.append(c)
                meta.append({"text": c, "doc_id": doc_id, "chunk_id": f"{doc_id}__{i}"})
        if not all_chunks:
            raise ValueError('No chunks produced')
        embs = self.encoder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
        faiss.normalize_L2(embs)
        self.embeddings = embs
        self.metadata = meta
        self._init_faiss(n_items_estimate=len(all_chunks))
        self.index.add(self.embeddings)

    def query(self, q_text: str, k=TOP_K) -> List[Tuple[float, Dict]]:
        q_emb = self.encoder.encode([q_text], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.index.search(q_emb, k)
        results = []
        for score, idx in zip(D[0], I[0]):
            meta = self.metadata[idx].copy()
            meta['score'] = float(score)
            results.append((float(score), meta))
        return results

# -----------------------
# Policy network
# -----------------------
class PolicyNet(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim * 2 + 1, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, q_emb, chunk_emb, decoder_proxy):
        x = torch.cat([q_emb, chunk_emb, decoder_proxy.unsqueeze(-1)], dim=-1)
        return self.net(x).squeeze(-1)

# -----------------------
# LLM wrapper
# -----------------------
class LLMWrapper:
    def __init__(self, model_name=LLM_MODEL, device=DEVICE):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.model.config.eos_token_id
        self.device = device

    def generate(self, prompt: str, max_new_tokens=64):
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        return self.tokenizer.decode(out[0], skip_special_tokens=True)

    def next_token_entropy_proxy(self, prompt: str):
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            ent = -torch.sum(probs * torch.log(probs + 1e-12), dim=-1)
            return float(ent.cpu().item())

# -----------------------
# REFRAG-like run
# -----------------------
def refrag_answer(query: str, index: EmbeddingIndex, llm: LLMWrapper, policy: PolicyNet=None,
                  initial_expand=INITIAL_EXPAND, max_expansions=MAX_EXPANSIONS):
    top = index.query(query, k=TOP_K)
    expansions = {i: top[i][1]['text'] for i in range(min(initial_expand, len(top)))}

    def build_prompt(expansions_map):
        expanded_texts = []
        for i, (s, m) in enumerate(top):
            if i in expansions_map:
                expanded_texts.append(f"CHUNK_{i} (expanded):\n{expansions_map[i]}\n")
            else:
                txt = m['text']
                expanded_texts.append(f"CHUNK_{i} (compressed): {txt[:60]}... [score={s:.3f}]")
        ctx = "\n\n".join(expanded_texts)
        prompt = f"Use the following retrieved chunks to answer the query.\n\nContext:\n{ctx}\n\nQuery: {query}\n\nAnswer concisely:"
        return prompt

    generated = ""
    expansions_done = len(expansions)

    while True:
        prompt = build_prompt(expansions) + "\n\nPreviously generated:\n" + generated + "\nContinue:"
        out = llm.generate(prompt, max_new_tokens=32)
        newly = out[len(prompt):] if out.startswith(prompt) else out
        generated += newly
        ent = llm.next_token_entropy_proxy(prompt + newly)

        if policy is not None and expansions_done < max_expansions:
            q_emb_np = index.encoder.encode([query])
            candidates = []
            cand_idx = []
            for i, (s, m) in enumerate(top):
                if i in expansions:
                    continue
                cand_idx.append(i)
                candidates.append(m['text'])
            if not candidates:
                break
            cand_embs_np = index.encoder.encode(candidates)
            q_emb = torch.tensor(q_emb_np.repeat(len(candidates), 1), dtype=torch.float32).to(DEVICE)
            cand_embs = torch.tensor(cand_embs_np, dtype=torch.float32).to(DEVICE)
            ent_rep = torch.tensor([ent]*len(candidates), dtype=torch.float32).to(DEVICE)
            scores = policy(q_emb, cand_embs, ent_rep)
            best = torch.argmax(scores).item()
            chosen_global = cand_idx[best]
            best_score_val = float(scores[best].detach().cpu().item())
            if best_score_val < 0.05:
                break
            expansions[chosen_global] = top[chosen_global][1]['text']
            expansions_done += 1
        else:
            break
    return generated

# -----------------------
# Utility: cross-cut dataset
# -----------------------
def make_cross_cut_dataset(n_docs=200):
    topics = [
        "history of paris and france",
        "programming languages and python",
        "transformer models and attention",
        "art and museums",
        "fashion and design",
        "data science and ml",
        "travel and tourism",
    ]
    docs = {}
    for i in range(n_docs):
        topic = random.choice(topics)
        par = " ".join([f"Sentence about {topic}." for _ in range(30)])
        docs[f"doc_{i}"] = par
    return docs

# -----------------------
# Demo
# -----------------------
def run_improved_demo():
    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
    docs = make_cross_cut_dataset(n_docs=120)
    idx = EmbeddingIndex()
    print("Building index (this may take a minute)...")
    idx.build_from_docs(docs, tokenizer)
    llm = LLMWrapper()
    policy = PolicyNet(idx.encoder.get_sentence_embedding_dimension()).to(DEVICE)

    q = "Why is Paris culturally important?"
    print("Query:", q)
    print("\n--- Without Policy ---")
    ans1 = refrag_answer(q, idx, llm, policy=None)
    print(ans1[:1000])

    print("\n--- With Policy (untrained) ---")
    ans2 = refrag_answer(q, idx, llm, policy=policy)
    print(ans2[:1000])

run_improved_demo()


Building index (this may take a minute)...


Batches:   0%|          | 0/17 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Query: Why is Paris culturally important?

--- Without Policy ---


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.




Continue:

Continue:

Continue:

Continue:

Continue:

Continue:

Continue:

Continue:

--- With Policy (untrained) ---


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 1 but got size 14 for tensor number 1 in the list.

In [5]:
# --- Benchmark & Graphing additions for REFRAG Colab demo ---
# Paste this after your existing notebook code cells (it re-uses EmbeddingIndex, LLMWrapper, refrag_answer, etc.)
# Requires: matplotlib, psutil
!pip install psutil matplotlib --quiet

import time
import matplotlib.pyplot as plt
import psutil
import os
import gc
from typing import List, Dict, Tuple

# -----------------------
# Large synthetic dataset generator (very large, cross-cut)
# -----------------------
def make_large_cross_cut_dataset(n_docs=5000, doc_length_sentences=200, topics=None, seed=42):
    """
    Create a synthetic multi-domain dataset:
      - n_docs: number of docs
      - doc_length_sentences: approx sentences per doc (controls doc size)
      - topics: list of topic strings to cross-cut
    Returns dict doc_id -> text
    """
    random.seed(seed)
    if topics is None:
        topics = [
            "history of paris and france",
            "programming languages and python",
            "transformer models and attention",
            "art and museums",
            "fashion and design",
            "data science and ml",
            "travel and tourism",
            "quantum physics",
            "economics and markets",
            "biology and genetics",
            "sports and events",
            "cooking and recipes",
            "mythology and legends",
            "film and cinema",
        ]
    docs = {}
    # pre-build phrase pools per topic for variety
    phrase_pool = {}
    for t in topics:
        phrase_pool[t] = [f"A note about {t}." for _ in range(20)]
    for i in range(n_docs):
        t = random.choice(topics)
        # make doc by sampling sentences; mix in occasional sentences from related topics to make cross-cut noise
        sentences = []
        for s in range(doc_length_sentences):
            if random.random() < 0.85:
                sentences.append(random.choice(phrase_pool[t]))
            else:
                # cross-cut injection
                other = random.choice(topics)
                sentences.append(random.choice(phrase_pool[other]))
        docs[f"doc_{i}"] = " ".join(sentences)
    return docs

# -----------------------
# Standard RAG answer: expand all retrieved chunks
# -----------------------
def standard_rag_answer(query: str, index: 'EmbeddingIndex', llm: 'LLMWrapper', top_k:int):
    """
    Retrieve top_k and expand them all into the prompt (classic RAG).
    Returns answer_text, expansions_count
    """
    top = index.query(query, k=top_k)
    # build full expanded prompt
    parts = []
    for i, (s, m) in enumerate(top):
        parts.append(f"CHUNK_{i}:\n{m['text']}\n")
    ctx = "\n\n".join(parts)
    prompt = f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer concisely:"
    t0 = time.time()
    out = llm.generate(prompt, max_new_tokens=128)
    t1 = time.time()
    return out, (t1 - t0), len(top)

# -----------------------
# Lightweight heuristic policy used by REFRAG (dot-product + small MLP optional)
# -----------------------
def simple_policy_score(query_emb: List[float], chunk_embs: List[List[float]]):
    """
    Very fast heuristic: cosine similarity (dot after normalize) between query_emb and chunk_embs.
    Returns list of scores (higher = more important).
    """
    import numpy as np
    q = query_emb / (np.linalg.norm(query_emb) + 1e-12)
    # chunk_embs already normalized in our index, but compute robustly
    scores = []
    for c in chunk_embs:
        c_n = c / (np.linalg.norm(c) + 1e-12)
        scores.append(float(np.dot(q, c_n)))
    return scores

# -----------------------
# REFRAG-style answer wrapper (selective expansion)
# -----------------------
def refrag_answer_selective(query: str, index: 'EmbeddingIndex', llm: 'LLMWrapper',
                            top_k:int, initial_expand:int=2, max_expands:int=6, policy_fn=None):
    """
    Implementation notes:
    - Retrieves top_k compressed chunks
    - Initially expands initial_expand chunks (highest score)
    - Iteratively decodes small steps and uses policy_fn to pick next chunk to expand until
      max_expands is reached or policy score threshold falls below.
    - Returns generated text, total_time, expansions_done
    """
    top = index.query(query, k=top_k)
    # basic compressed placeholders
    compressed = [m['text'] for (_, m) in top]
    # initial expansions: highest score ones (top sorted)
    expansions = {i: compressed[i] for i in range(min(initial_expand, len(compressed)))}
    # prepare query embedding and chunk embeddings for policy decisions
    q_emb = index.encoder.encode([query])[0]
    chunk_embs = [index.encoder.encode([c])[0] for c in compressed]  # note: repeated encode; could be optimized using stored embeddings
    # start iterative loop
    t0 = time.time()
    generated = ""
    expansions_done = len(expansions)
    # stopping criteria: maximum expansions or low policy score
    while True:
        # build prompt from expansions + compressed placeholders
        prompt_parts = []
        for i, c in enumerate(compressed):
            if i in expansions:
                prompt_parts.append(f"CHUNK_{i} (expanded):\n{expansions[i]}\n")
            else:
                prompt_parts.append(f"CHUNK_{i} (compressed): {c[:80]}... [id={i}]")
        prompt = "Use the following retrieved chunks to answer the query.\n\nContext:\n" + "\n\n".join(prompt_parts) + f"\n\nQuery: {query}\nAnswer concisely:"
        # issue a short generation to let policy evaluate and keep TTFT low
        out = llm.generate(prompt, max_new_tokens=32)
        newly = out[len(prompt):] if out.startswith(prompt) else out
        generated += newly
        # check expand budget
        if expansions_done >= max_expands:
            break
        # choose the best candidate among not-yet-expanded
        not_expanded_idx = [i for i in range(len(compressed)) if i not in expansions]
        if not not_expanded_idx:
            break
        # gather their embeddings and score via policy
        candidate_embs = [chunk_embs[i] for i in not_expanded_idx]
        scores = policy_fn(q_emb, candidate_embs) if policy_fn is not None else simple_policy_score(q_emb, candidate_embs)
        # pick best
        best_local = int(max(range(len(scores)), key=lambda ii: scores[ii]))
        best_score = float(scores[best_local])
        best_global_idx = not_expanded_idx[best_local]
        # threshold to stop expanding if scores are small
        if best_score < 0.05:
            break
        # expand it
        expansions[best_global_idx] = compressed[best_global_idx]
        expansions_done += 1
    t1 = time.time()
    return generated, (t1 - t0), expansions_done

# -----------------------
# Helpers: memory measurement
# -----------------------
def get_memory_usage():
    # returns (host_rss_MB, cuda_alloc_MB)
    process = psutil.Process(os.getpid())
    rss = process.memory_info().rss / (1024**2)
    cuda_mem = 0.0
    if torch.cuda.is_available():
        cuda_mem = torch.cuda.memory_allocated() / (1024**2)
    return rss, cuda_mem

# -----------------------
# Benchmark harness
# -----------------------
def run_benchmark(index: 'EmbeddingIndex', llm: 'LLMWrapper', queries: List[str],
                  top_k_values: List[int], runs_per_point:int=2):
    """
    For each top_k in top_k_values:
      - run standard RAG and REFRAG on each query (runs_per_point times)
      - record TTFT / full generation time (approx), expansions, memory
    Returns dicts of metrics for plotting.
    """
    metrics = {
        'k': [],
        'standard_time': [],
        'refrag_time': [],
        'standard_rss': [],
        'refrag_rss': [],
        'standard_cuda': [],
        'refrag_cuda': [],
        'refrag_expansions': []
    }
    for k in top_k_values:
        std_times = []
        ref_times = []
        std_rss = []
        ref_rss = []
        std_cuda = []
        ref_cuda = []
        ref_exp = []
        print(f"\nBenchmarking k={k} ...")
        for q in queries:
            # standard RAG
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            rss0, cuda0 = get_memory_usage()
            _, t_std, expanded_std = standard_rag_answer(q, index, llm, top_k=k)
            rss1, cuda1 = get_memory_usage()
            std_times.append(t_std)
            std_rss.append(rss1 - rss0)
            std_cuda.append(cuda1 - cuda0)
            # REFRAG (selective)
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            rss0, cuda0 = get_memory_usage()
            _, t_ref, exp_count = refrag_answer_selective(q, index, llm, top_k=k,
                                                          initial_expand=2, max_expands=min(8, k), policy_fn=simple_policy_score)
            rss1, cuda1 = get_memory_usage()
            ref_times.append(t_ref)
            ref_rss.append(rss1 - rss0)
            ref_cuda.append(cuda1 - cuda0)
            ref_exp.append(exp_count)
        # aggregate
        metrics['k'].append(k)
        metrics['standard_time'].append(sum(std_times)/len(std_times))
        metrics['refrag_time'].append(sum(ref_times)/len(ref_times))
        metrics['standard_rss'].append(sum(std_rss)/len(std_rss))
        metrics['refrag_rss'].append(sum(ref_rss)/len(ref_rss))
        metrics['standard_cuda'].append(sum(std_cuda)/len(std_cuda))
        metrics['refrag_cuda'].append(sum(ref_cuda)/len(ref_cuda))
        metrics['refrag_expansions'].append(sum(ref_exp)/len(ref_exp))
        print(f"k={k} => standard_time={metrics['standard_time'][-1]:.3f}s, refrag_time={metrics['refrag_time'][-1]:.3f}s, refrag_expansions={metrics['refrag_expansions'][-1]:.2f}")
    return metrics

# -----------------------
# Plotting utilities
# -----------------------
def plot_metrics(metrics):
    ks = metrics['k']
    plt.figure(figsize=(14,4))
    plt.subplot(1,3,1)
    plt.plot(ks, metrics['standard_time'], marker='o', label='Standard RAG')
    plt.plot(ks, metrics['refrag_time'], marker='o', label='REFRAG (selective)')
    plt.xlabel('Top-K retrieved')
    plt.ylabel('Avg generation time (s)')
    plt.title('Latency (approx)')
    plt.legend()
    plt.grid(True)

    plt.subplot(1,3,2)
    plt.plot(ks, metrics['standard_rss'], marker='o', label='Standard RSS (MB)')
    plt.plot(ks, metrics['refrag_rss'], marker='o', label='REFRAG RSS (MB)')
    plt.xlabel('Top-K retrieved')
    plt.ylabel('Host RSS delta (MB)')
    plt.title('Host memory delta')
    plt.legend()
    plt.grid(True)

    plt.subplot(1,3,3)
    plt.plot(ks, metrics['refrag_expansions'], marker='o', label='Avg expansions')
    plt.xlabel('Top-K retrieved')
    plt.ylabel('Expansions (count)')
    plt.title('REFRAG expansion rate')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# -----------------------
# HOW TO RUN (example)
# -----------------------
# 1) Build or reuse the index (this can be expensive for large datasets).
#    For very large runs use smaller models or precomputed embeddings.
#
# Example quick demo using a small dataset:
# docs = make_large_cross_cut_dataset(n_docs=200, doc_length_sentences=80)
# tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
# idx = EmbeddingIndex()
# idx.build_from_docs(docs, tokenizer)
#
# llm = LLMWrapper()
# queries = ["Why is Paris culturally important?", "What is Python commonly used for?"]
# ks = [4, 8, 16, 32]   # vary retrieval size
# metrics = run_benchmark(idx, llm, queries, top_k_values=ks)
# plot_metrics(metrics)
#
# For a true stress test:
# docs = make_large_cross_cut_dataset(n_docs=3000, doc_length_sentences=200)
# (warning: will take time & memory to embed + index). For Colab, try n_docs <= 2000 with mini models.

# -----------------------
# Notes:
# - This benchmark measures *approx* latency and memory deltas. For real TTFT you'd integrate streaming APIs and measure time until first token arrives via streaming callbacks.
# - FAISS index build and embedding generation are the heavy parts; for very large corpora precompute and store embeddings on disk.
# - Replace llm.generate with an API stream-invocation to measure true time-to-first-token for production LLMs.
# - The policy here is a simple similarity heuristic. Replace with a learned policy (PolicyNet) trained with REINFORCE or supervised labels to get better expansion decisions.


In [6]:
# ============================================
# REFRAG Prototype + Benchmark (Colab Notebook)
# ============================================

# --- Install dependencies ---
!pip install sentence-transformers faiss-cpu transformers torch tqdm psutil matplotlib --quiet

# --- Imports ---
import torch, faiss, random, time, gc, os, psutil
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

# --- Config ---
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "gpt2"
CHUNK_TOKENS = 64
TOP_K = 12
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------
# Utilities: chunk text
# -----------------------
def chunk_text(text: str, tokens_per_chunk: int = CHUNK_TOKENS) -> List[str]:
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + tokens_per_chunk]))
        i += tokens_per_chunk
    return chunks

# -----------------------
# Embedding Index
# -----------------------
class EmbeddingIndex:
    def __init__(self, embed_model_name=EMBED_MODEL):
        self.encoder = SentenceTransformer(embed_model_name)
        self.dim = self.encoder.get_sentence_embedding_dimension()
        self.index = None
        self.metadata = []
        self.embeddings = None

    def build_from_docs(self, docs: Dict[str, str]):
        import numpy as np
        all_chunks, meta = [], []
        for doc_id, text in docs.items():
            chunks = chunk_text(text)
            all_chunks.extend(chunks)
            for c in chunks:
                meta.append({"text": c, "doc_id": doc_id})
        embs = self.encoder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
        self.embeddings = embs
        self.index = faiss.IndexFlatIP(self.dim)
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.metadata = meta

    def query(self, q_text: str, k=TOP_K) -> List[Tuple[float, Dict]]:
        import numpy as np
        q_emb = self.encoder.encode([q_text], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.index.search(q_emb, k)
        results = []
        for score, idx in zip(D[0], I[0]):
            meta = self.metadata[idx].copy()
            meta['score'] = float(score)
            results.append((float(score), meta))
        return results

# -----------------------
# LLM Wrapper
# -----------------------
class LLMWrapper:
    def __init__(self, model_name=LLM_MODEL, device=DEVICE):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        self.device = device

    def generate(self, prompt: str, max_new_tokens=64):
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        return self.tokenizer.decode(out[0], skip_special_tokens=True)

# -----------------------
# Standard RAG Answer
# -----------------------
def standard_rag_answer(query: str, index: 'EmbeddingIndex', llm: 'LLMWrapper', top_k:int):
    top = index.query(query, k=top_k)
    parts = []
    for i, (s, m) in enumerate(top):
        parts.append(f"CHUNK_{i}:\n{m['text']}\n")
    ctx = "\n\n".join(parts)
    prompt = f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer concisely:"
    t0 = time.time()
    out = llm.generate(prompt, max_new_tokens=128)
    t1 = time.time()
    return out, (t1 - t0), len(top)

# -----------------------
# Simple Policy (cosine)
# -----------------------
def simple_policy_score(query_emb, chunk_embs):
    import numpy as np
    q = query_emb / (np.linalg.norm(query_emb) + 1e-12)
    scores = []
    for c in chunk_embs:
        c_n = c / (np.linalg.norm(c) + 1e-12)
        scores.append(float(np.dot(q, c_n)))
    return scores

# -----------------------
# REFRAG-style Answer
# -----------------------
def refrag_answer_selective(query: str, index: 'EmbeddingIndex', llm: 'LLMWrapper',
                            top_k:int, initial_expand:int=2, max_expands:int=6, policy_fn=None):
    top = index.query(query, k=top_k)
    compressed = [m['text'] for (_, m) in top]
    expansions = {i: compressed[i] for i in range(min(initial_expand, len(compressed)))}
    q_emb = index.encoder.encode([query])[0]
    chunk_embs = [index.encoder.encode([c])[0] for c in compressed]

    t0 = time.time()
    generated = ""
    expansions_done = len(expansions)

    while True:
        prompt_parts = []
        for i, c in enumerate(compressed):
            if i in expansions:
                prompt_parts.append(f"CHUNK_{i} (expanded):\n{expansions[i]}\n")
            else:
                prompt_parts.append(f"CHUNK_{i} (compressed): {c[:80]}... [id={i}]")
        prompt = "Use the following retrieved chunks to answer the query.\n\nContext:\n" + "\n\n".join(prompt_parts) + f"\n\nQuery: {query}\nAnswer concisely:"
        out = llm.generate(prompt, max_new_tokens=32)
        newly = out[len(prompt):] if out.startswith(prompt) else out
        generated += newly

        if expansions_done >= max_expands:
            break
        not_expanded_idx = [i for i in range(len(compressed)) if i not in expansions]
        if not not_expanded_idx:
            break
        candidate_embs = [chunk_embs[i] for i in not_expanded_idx]
        scores = policy_fn(q_emb, candidate_embs) if policy_fn is not None else simple_policy_score(q_emb, candidate_embs)
        best_local = int(max(range(len(scores)), key=lambda ii: scores[ii]))
        best_score = float(scores[best_local])
        best_global_idx = not_expanded_idx[best_local]
        if best_score < 0.05:
            break
        expansions[best_global_idx] = compressed[best_global_idx]
        expansions_done += 1
    t1 = time.time()
    return generated, (t1 - t0), expansions_done

# -----------------------
# Memory usage helper
# -----------------------
def get_memory_usage():
    process = psutil.Process(os.getpid())
    rss = process.memory_info().rss / (1024**2)
    cuda_mem = 0.0
    if torch.cuda.is_available():
        cuda_mem = torch.cuda.memory_allocated() / (1024**2)
    return rss, cuda_mem

# -----------------------
# Benchmark harness
# -----------------------
def run_benchmark(index: 'EmbeddingIndex', llm: 'LLMWrapper', queries: List[str],
                  top_k_values: List[int], runs_per_point:int=1):
    metrics = {
        'k': [],
        'standard_time': [],
        'refrag_time': [],
        'standard_rss': [],
        'refrag_rss': [],
        'refrag_expansions': []
    }
    for k in top_k_values:
        std_times, ref_times = [], []
        std_rss, ref_rss = [], []
        ref_exp = []
        print(f"\nBenchmarking k={k} ...")
        for q in queries:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            rss0, _ = get_memory_usage()
            _, t_std, _ = standard_rag_answer(q, index, llm, top_k=k)
            rss1, _ = get_memory_usage()
            std_times.append(t_std)
            std_rss.append(rss1 - rss0)

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            rss0, _ = get_memory_usage()
            _, t_ref, exp_count = refrag_answer_selective(q, index, llm, top_k=k, initial_expand=2, max_expands=min(8, k), policy_fn=simple_policy_score)
            rss1, _ = get_memory_usage()
            ref_times.append(t_ref)
            ref_rss.append(rss1 - rss0)
            ref_exp.append(exp_count)

        metrics['k'].append(k)
        metrics['standard_time'].append(sum(std_times)/len(std_times))
        metrics['refrag_time'].append(sum(ref_times)/len(ref_times))
        metrics['standard_rss'].append(sum(std_rss)/len(std_rss))
        metrics['refrag_rss'].append(sum(ref_rss)/len(ref_rss))
        metrics['refrag_expansions'].append(sum(ref_exp)/len(ref_exp))
        print(f"k={k}: RAG={metrics['standard_time'][-1]:.3f}s, REFRAG={metrics['refrag_time'][-1]:.3f}s, expansions={metrics['refrag_expansions'][-1]:.2f}")
    return metrics

# -----------------------
# Plotting
# -----------------------
def plot_metrics(metrics):
    ks = metrics['k']
    plt.figure(figsize=(14,4))
    plt.subplot(1,3,1)
    plt.plot(ks, metrics['standard_time'], marker='o', label='Standard RAG')
    plt.plot(ks, metrics['refrag_time'], marker='o', label='REFRAG')
    plt.xlabel('Top-K')
    plt.ylabel('Avg time (s)')
    plt.title('Latency')
    plt.legend(); plt.grid(True)

    plt.subplot(1,3,2)
    plt.plot(ks, metrics['standard_rss'], marker='o', label='Standard RSS')
    plt.plot(ks, metrics['refrag_rss'], marker='o', label='REFRAG RSS')
    plt.xlabel('Top-K')
    plt.ylabel('Host memory delta (MB)')
    plt.title('Memory')
    plt.legend(); plt.grid(True)

    plt.subplot(1,3,3)
    plt.plot(ks, metrics['refrag_expansions'], marker='o', label='REFRAG expansions')
    plt.xlabel('Top-K')
    plt.ylabel('Expansions')
    plt.title('Expansion Rate')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# -----------------------
# Large Realistic Docs
# -----------------------
docs = {
    "doc1": ("Paris, the capital city of France, is internationally recognized not only for its "
             "iconic monuments such as the Eiffel Tower, the Louvre Museum, and Notre-Dame, "
             "but also for the profound cultural influence it has exerted in art, literature, "
             "philosophy, cuisine, and fashion.") * 30,
    "doc2": ("Python is a high-level programming language that emphasizes readability and efficiency, "
             "with a versatile ecosystem of libraries and frameworks that enable rapid development "
             "in machine learning, AI, data analysis, and scripting.") * 30,
    "doc3": ("Transformers are deep learning architectures leveraging self-attention mechanisms "
             "to model long-range dependencies in sequences, powering BERT, GPT, and T5 for NLP tasks.") * 30,
    "doc4": ("Quantum mechanics revolutionized physics by describing subatomic particle behavior, "
             "leading to technologies like semiconductors, lasers, and quantum computers.") * 30,
    "doc5": ("Global economics is shaped by markets, trade, and institutions such as the IMF and World Bank, "
             "where crises or wars can create ripple effects across supply chains, currencies, and jobs.") * 30,
}
print(f"Loaded {len(docs)} long docs with ~{sum(len(v.split()) for v in docs.values())//len(docs)} words avg.")

# -----------------------
# Run Demo Benchmark
# -----------------------
idx = EmbeddingIndex()
idx.build_from_docs(docs)
llm = LLMWrapper()

queries = [
    "Why is Paris culturally important?",
    "What is Python used for?",
    "Explain transformers in AI.",
    "What impact did quantum mechanics have on technology?",
    "How do global economic shocks spread?"
]

ks = [4, 8, 16, 32]
metrics = run_benchmark(idx, llm, queries, top_k_values=ks)
plot_metrics(metrics)


Loaded 5 long docs with ~829 words avg.


Batches:   0%|          | 0/3 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



Benchmarking k=4 ...


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end gene

k=4: RAG=6.773s, REFRAG=6.898s, expansions=4.00

Benchmarking k=8 ...


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end gene

IndexError: index out of range in self

In [8]:
# ============================================
# REFRAG Prototype + Benchmark (Colab Notebook)
# ============================================

# --- Install dependencies ---
!pip install sentence-transformers faiss-cpu transformers torch tqdm psutil matplotlib --quiet

# --- Imports ---
import torch
import faiss
import random
import time
import gc
import os
import psutil
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple

# --- Config ---
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "gpt2"
CHUNK_TOKENS = 64
TOP_K = 12
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------
# Utilities: chunk text
# -----------------------
def chunk_text(text: str, tokens_per_chunk: int = CHUNK_TOKENS) -> List[str]:
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + tokens_per_chunk]))
        i += tokens_per_chunk
    return chunks

# -----------------------
# Embedding Index
# -----------------------
class EmbeddingIndex:
    def __init__(self, embed_model_name=EMBED_MODEL):
        self.encoder = SentenceTransformer(embed_model_name)
        self.dim = self.encoder.get_sentence_embedding_dimension()
        self.index = None
        self.metadata = []
        self.embeddings = None

    def build_from_docs(self, docs: Dict[str, str]):
        import numpy as np
        all_chunks, meta = [], []
        for doc_id, text in docs.items():
            chunks = chunk_text(text)
            all_chunks.extend(chunks)
            for c in chunks:
                meta.append({"text": c, "doc_id": doc_id})
        embs = self.encoder.encode(all_chunks, convert_to_numpy=True, show_progress_bar=True)
        self.embeddings = embs
        self.index = faiss.IndexFlatIP(self.dim)
        faiss.normalize_L2(self.embeddings)
        self.index.add(self.embeddings)
        self.metadata = meta

    def query(self, q_text: str, k=TOP_K) -> List[Tuple[float, Dict]]:
        import numpy as np
        q_emb = self.encoder.encode([q_text], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        D, I = self.index.search(q_emb, k)
        results = []
        for score, idx in zip(D[0], I[0]):
            meta = self.metadata[idx].copy()
            meta['score'] = float(score)
            results.append((float(score), meta))
        return results

# -----------------------
# LLM Wrapper
# -----------------------
class LLMWrapper:
    def __init__(self, model_name=LLM_MODEL, device=DEVICE):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            # make pad_token = eos_token to avoid HF generation warnings
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        # set model config pad_token_id if missing
        if getattr(self.model.config, "pad_token_id", None) is None and getattr(self.model.config, "eos_token_id", None) is not None:
            self.model.config.pad_token_id = self.model.config.eos_token_id
        self.device = device

    def generate(self, prompt: str, max_new_tokens=64):
        # tokenizer returns a BatchEncoding which supports .to(device)
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)
        out = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        return self.tokenizer.decode(out[0], skip_special_tokens=True)

# -----------------------
# Standard RAG Answer
# -----------------------
def standard_rag_answer(query: str, index: 'EmbeddingIndex', llm: 'LLMWrapper', top_k:int):
    top = index.query(query, k=top_k)
    parts = []
    for i, (s, m) in enumerate(top):
        parts.append(f"CHUNK_{i}:\n{m['text']}\n")
    ctx = "\n\n".join(parts)
    prompt = f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer concisely:"
    t0 = time.time()
    out = llm.generate(prompt, max_new_tokens=128)
    t1 = time.time()
    return out, (t1 - t0), len(top)

# -----------------------
# Simple Policy (cosine)
# -----------------------
def simple_policy_score(query_emb, chunk_embs):
    import numpy as np
    q = query_emb / (np.linalg.norm(query_emb) + 1e-12)
    scores = []
    for c in chunk_embs:
        c_n = c / (np.linalg.norm(c) + 1e-12)
        scores.append(float(np.dot(q, c_n)))
    return scores

# -----------------------
# REFRAG-style Answer
# -----------------------
def refrag_answer_selective(query: str, index: 'EmbeddingIndex', llm: 'LLMWrapper',
                            top_k:int, initial_expand:int=2, max_expands:int=6, policy_fn=None):
    top = index.query(query, k=top_k)
    compressed = [m['text'] for (_, m) in top]
    expansions = {i: compressed[i] for i in range(min(initial_expand, len(compressed)))}
    q_emb = index.encoder.encode([query])[0]
    # use stored embeddings if available; fallback to encoding chunks
    try:
        # map top indices into embeddings if we kept them in same order
        # we built embeddings in build_from_docs as all_chunks order; since metadata corresponds, we can find them via metadata index
        # but for simplicity here re-encode chunk texts (ok for demo)
        chunk_embs = [index.encoder.encode([c])[0] for c in compressed]
    except Exception:
        chunk_embs = [index.encoder.encode([c])[0] for c in compressed]

    t0 = time.time()
    generated = ""
    expansions_done = len(expansions)

    while True:
        prompt_parts = []
        for i, c in enumerate(compressed):
            if i in expansions:
                prompt_parts.append(f"CHUNK_{i} (expanded):\n{expansions[i]}\n")
            else:
                prompt_parts.append(f"CHUNK_{i} (compressed): {c[:80]}... [id={i}]")
        prompt = "Use the following retrieved chunks to answer the query.\n\nContext:\n" + "\n\n".join(prompt_parts) + f"\n\nQuery: {query}\nAnswer concisely:"
        out = llm.generate(prompt, max_new_tokens=32)
        newly = out[len(prompt):] if out.startswith(prompt) else out
        generated += newly

        if expansions_done >= max_expands:
            break
        not_expanded_idx = [i for i in range(len(compressed)) if i not in expansions]
        if not not_expanded_idx:
            break
        candidate_embs = [chunk_embs[i] for i in not_expanded_idx]
        scores = policy_fn(q_emb, candidate_embs) if policy_fn is not None else simple_policy_score(q_emb, candidate_embs)
        best_local = int(max(range(len(scores)), key=lambda ii: scores[ii]))
        best_score = float(scores[best_local])
        best_global_idx = not_expanded_idx[best_local]
        if best_score < 0.05:
            break
        expansions[best_global_idx] = compressed[best_global_idx]
        expansions_done += 1
    t1 = time.time()
    return generated, (t1 - t0), expansions_done

# -----------------------
# Memory usage helper
# -----------------------
def get_memory_usage():
    process = psutil.Process(os.getpid())
    rss = process.memory_info().rss / (1024**2)
    cuda_mem = 0.0
    if torch.cuda.is_available():
        cuda_mem = torch.cuda.memory_allocated() / (1024**2)
    return rss, cuda_mem

# -----------------------
# Benchmark harness
# -----------------------
def run_benchmark(index: 'EmbeddingIndex', llm: 'LLMWrapper', queries: List[str],
                  top_k_values: List[int], runs_per_point:int=1):
    metrics = {
        'k': [],
        'standard_time': [],
        'refrag_time': [],
        'standard_rss': [],
        'refrag_rss': [],
        'refrag_expansions': []
    }
    for k in top_k_values:
        std_times, ref_times = [], []
        std_rss, ref_rss = [], []
        ref_exp = []
        print(f"\nBenchmarking k={k} ...")
        for q in queries:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            rss0, _ = get_memory_usage()
            _, t_std, _ = standard_rag_answer(q, index, llm, top_k=k)
            rss1, _ = get_memory_usage()
            std_times.append(t_std)
            std_rss.append(rss1 - rss0)

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            rss0, _ = get_memory_usage()
            _, t_ref, exp_count = refrag_answer_selective(q, index, llm, top_k=k, initial_expand=2, max_expands=min(8, k), policy_fn=simple_policy_score)
            rss1, _ = get_memory_usage()
            ref_times.append(t_ref)
            ref_rss.append(rss1 - rss0)
            ref_exp.append(exp_count)

        metrics['k'].append(k)
        metrics['standard_time'].append(sum(std_times)/len(std_times))
        metrics['refrag_time'].append(sum(ref_times)/len(ref_times))
        metrics['standard_rss'].append(sum(std_rss)/len(std_rss))
        metrics['refrag_rss'].append(sum(ref_rss)/len(ref_rss))
        metrics['refrag_expansions'].append(sum(ref_exp)/len(ref_exp))
        print(f"k={k}: RAG={metrics['standard_time'][-1]:.3f}s, REFRAG={metrics['refrag_time'][-1]:.3f}s, expansions={metrics['refrag_expansions'][-1]:.2f}")
    return metrics

# -----------------------
# Plotting
# -----------------------
def plot_metrics(metrics):
    ks = metrics['k']
    plt.figure(figsize=(14,4))
    plt.subplot(1,3,1)
    plt.plot(ks, metrics['standard_time'], marker='o', label='Standard RAG')
    plt.plot(ks, metrics['refrag_time'], marker='o', label='REFRAG')
    plt.xlabel('Top-K')
    plt.ylabel('Avg time (s)')
    plt.title('Latency')
    plt.legend(); plt.grid(True)

    plt.subplot(1,3,2)
    plt.plot(ks, metrics['standard_rss'], marker='o', label='Standard RSS')
    plt.plot(ks, metrics['refrag_rss'], marker='o', label='REFRAG RSS')
    plt.xlabel('Top-K')
    plt.ylabel('Host memory delta (MB)')
    plt.title('Memory')
    plt.legend(); plt.grid(True)

    plt.subplot(1,3,3)
    plt.plot(ks, metrics['refrag_expansions'], marker='o', label='REFRAG expansions')
    plt.xlabel('Top-K')
    plt.ylabel('Expansions')
    plt.title('Expansion Rate')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

# -----------------------
# Large Realistic Docs (long)
# -----------------------
docs = {
    "doc1": ("Paris, the capital city of France, is internationally recognized not only for its "
             "iconic monuments such as the Eiffel Tower, the Louvre Museum, and Notre-Dame, "
             "but also for the profound cultural influence it has exerted in art, literature, "
             "philosophy, cuisine, and fashion.") * 30,
    "doc2": ("Python is a high-level programming language that emphasizes readability and efficiency, "
             "with a versatile ecosystem of libraries and frameworks that enable rapid development "
             "in machine learning, AI, data analysis, and scripting.") * 30,
    "doc3": ("Transformers are deep learning architectures leveraging self-attention mechanisms "
             "to model long-range dependencies in sequences, powering BERT, GPT, and T5 for NLP tasks.") * 30,
    "doc4": ("Quantum mechanics revolutionized physics by describing subatomic particle behavior, "
             "leading to technologies like semiconductors, lasers, and quantum computers.") * 30,
    "doc5": ("Global economics is shaped by markets, trade, and institutions such as the IMF and World Bank, "
             "where crises or wars can create ripple effects across supply chains, currencies, and jobs.") * 30,
}
print(f"Loaded {len(docs)} long docs with ~{sum(len(v.split()) for v in docs.values())//len(docs)} words avg.")

# -----------------------
# Run Demo Benchmark
# -----------------------
idx = EmbeddingIndex()
idx.build_from_docs(docs)
llm = LLMWrapper()

queries = [
    "Why is Paris culturally important?",
    "What is Python used for?",
    "Explain transformers in AI.",
    "What impact did quantum mechanics have on technology?",
    "How do global economic shocks spread?"
]

ks = [4, 8, 16, 32]
metrics = run_benchmark(idx, llm, queries, top_k_values=ks)
plot_metrics(metrics)

# ------------- WARNING -------------
# This notebook builds embeddings and a FAISS index for long documents. On Colab this can take several minutes
# and consume a few GBs of RAM. Reduce repetition factor or number of docs if you run into memory limits.


Loaded 5 long docs with ~829 words avg.

Benchmarking k=4 ...


AttributeError: 'EmbeddingIndex' object has no attribute 'query'