In [21]:
# ============================================================
# CELL 1: Install Dependencies & Imports
# ============================================================
!pip -q install ir_datasets transformers datasets faiss-cpu pandas pyarrow tqdm

import os, re, json, math, numpy as np, pandas as pd, faiss, torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

# ==== CONFIGURATION ====
SEED                = 42
DEVICE              = "cuda" if torch.cuda.is_available() else "cpu"

# Index size (subset of full 21M to keep it light)
N_PASSAGES_TOTAL    = 200_000      # tune for your machine
SHARD_ROWS          = 20_000       # rows per shard
BATCH_ENCODE        = 256
MAX_LEN             = 256
USE_COSINE          = False        # if True, L2-normalize vectors; search stays IP

# IVF params
IVF_NLIST           = 32768        # try 16384–65536
IVF_TRAIN_EMB       = 50_000       # vectors to train IVF (<= N_PASSAGES_TOTAL)

# Output paths
OUT_DIR       = "contriever_ivf_wiki_subset"
INDEX_PATH    = os.path.join(OUT_DIR, "ivf.index")
MANIFEST_PATH = os.path.join(OUT_DIR, "manifest.json")
os.makedirs(OUT_DIR, exist_ok=True)

np.random.seed(SEED)
torch.manual_seed(SEED)
print("Device:", DEVICE)

Device: cuda


In [22]:
# ============================================================
# CELL 2: Load Contriever Model
# ============================================================
CONTRIEVER_MODEL = "facebook/contriever"

print(f"Loading Contriever from {CONTRIEVER_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(CONTRIEVER_MODEL)
model = AutoModel.from_pretrained(CONTRIEVER_MODEL).to(DEVICE).eval()
print("Contriever loaded successfully!")
print(f"Embedding dimension: {model.config.hidden_size}")

Loading Contriever from facebook/contriever...
Contriever loaded successfully!
Embedding dimension: 768


In [23]:
# ============================================================
# CELL 3: Define Encoding Functions
# ============================================================
def mean_pooling(token_embeddings, attention_mask):
    """
    Contriever uses mean pooling over token embeddings
    """
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

@torch.no_grad()
def encode_contriever(texts, batch=BATCH_ENCODE, max_len=MAX_LEN, normalize=USE_COSINE):
    """
    Encode texts using Contriever encoder with mean pooling
    Works for both queries and passages
    """
    out = []
    for i in range(0, len(texts), batch):
        batch_texts = texts[i:i+batch]

        # Tokenize
        encoded = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_len,
            return_tensors="pt"
        ).to(DEVICE)

        # Forward pass
        outputs = model(**encoded)

        # Mean pooling
        embeddings = mean_pooling(outputs.last_hidden_state, encoded['attention_mask'])

        # Optional normalization
        if normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        out.append(embeddings.float().cpu().numpy())

    return np.vstack(out) if out else np.zeros((0, model.config.hidden_size), dtype=np.float32)

# Wrapper functions for compatibility
def encode_passages_ctx(texts, batch=BATCH_ENCODE, max_len=MAX_LEN, normalize=USE_COSINE):
    """Encode passages - wrapper for consistency"""
    return encode_contriever(texts, batch, max_len, normalize)

def encode_queries(qs, batch=128, max_len=MAX_LEN, normalize=USE_COSINE):
    """Encode queries - wrapper for consistency"""
    return encode_contriever(qs, batch, max_len, normalize)

def write_manifest(chunks, path=MANIFEST_PATH):
    with open(path, "w") as f:
        json.dump(chunks, f)
    return path

def read_manifest(path=MANIFEST_PATH):
    with open(path, "r") as f:
        return json.load(f)

def normalize_text(s):
    s = s.lower()
    s = re.sub(r"[^a-z0-9\s]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s

print("Encoding functions defined successfully!")

Encoding functions defined successfully!


In [24]:
# ============================================================
# CELL 4: Load DPR-W100 Dataset
# ============================================================
import ir_datasets

print("Loading DPR-W100 dataset...")
dpr = ir_datasets.load("dpr-w100")

def take_docs(dataset, n):
    """Generator over first n docs"""
    i = 0
    for doc in dataset.docs_iter():
        yield doc
        i += 1
        if i >= n:
            break

print("Dataset loaded successfully!")

Loading DPR-W100 dataset...
Dataset loaded successfully!


In [25]:
# ============================================================
# CELL 5: Collect & Encode IVF Training Data
# ============================================================
IVF_TRAIN_EMB = min(IVF_TRAIN_EMB, N_PASSAGES_TOTAL)

train_texts = []
for doc in tqdm(take_docs(dpr, IVF_TRAIN_EMB), total=IVF_TRAIN_EMB, desc="Collect IVF train docs"):
    if doc.text:
        train_texts.append(doc.text)

print(f"Encoding {len(train_texts)} training passages...")
train_mat = encode_passages_ctx(train_texts, batch=BATCH_ENCODE)
d = train_mat.shape[1]
print(f"IVF training matrix: {train_mat.shape}, embedding dim: {d}")


[INFO] If you have a local copy of https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz, you can symlink it here to avoid downloading it again: /root/.ir_datasets/downloads/612fe66e0b6b41ee28f806140226c563
[INFO] [starting] https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
Collect IVF train docs:   0%|          | 0/50000 [00:00<?, ?it/s]
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.0%| 0.00/4.69G [00:00<?, ?B/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.0%| 24.6k/4.69G [00:00<6:34:44, 198kB/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.0%| 65.5k/4.69G [00:00<5:17:35, 246kB/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.0%| 172k/4.69G [00:00<3:05:50, 421kB/s] [A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.0%| 377k/4.69G [00:00<1:54:20, 684kB/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.ts

Encoding 50000 training passages...
IVF training matrix: (50000, 768), embedding dim: 768


In [26]:
# ============================================================
# CELL 6: Train IVF Index
# ============================================================
print("Building IVF index...")
quantizer = faiss.IndexFlatIP(d)
ivf = faiss.IndexIVFFlat(quantizer, d, IVF_NLIST, faiss.METRIC_INNER_PRODUCT)
ivf.train(train_mat)
print(f"IVF trained: {ivf.is_trained}")

Building IVF index...
IVF trained: True


In [27]:
# ============================================================
# CELL 7: Index All Passages & Save Shards
# ============================================================
manifest = []
internal_base = 0
shard_idx = 0

texts, titles = [], []
pbar = tqdm(total=N_PASSAGES_TOTAL, desc="Shard & add to IVF")
count = 0

for doc in dpr.docs_iter():
    if not doc.text:
        continue
    texts.append(doc.text)
    titles.append(doc.title or "")
    count += 1
    pbar.update(1)

    if len(texts) == SHARD_ROWS:
        # Encode and add to index
        emb = encode_passages_ctx(texts, batch=BATCH_ENCODE)
        ivf.add(emb)

        # Save shard
        shard_path = os.path.join(OUT_DIR, f"passages_shard_{shard_idx:05d}.parquet")
        pd.DataFrame({
            "internal_id": np.arange(internal_base, internal_base + len(texts)),
            "title": titles,
            "text": texts
        }).to_parquet(shard_path, index=False)
        manifest.append({"path": shard_path, "start_id": int(internal_base), "n": int(len(texts))})

        internal_base += len(texts)
        shard_idx += 1
        texts, titles = [], []

    if count >= N_PASSAGES_TOTAL:
        break

# Flush remainder
if texts:
    emb = encode_passages_ctx(texts, batch=BATCH_ENCODE)
    ivf.add(emb)
    shard_path = os.path.join(OUT_DIR, f"passages_shard_{shard_idx:05d}.parquet")
    pd.DataFrame({
        "internal_id": np.arange(internal_base, internal_base + len(texts)),
        "title": titles,
        "text": texts
    }).to_parquet(shard_path, index=False)
    manifest.append({"path": shard_path, "start_id": int(internal_base), "n": int(len(texts))})

pbar.close()
print(f"Total vectors added: {ivf.ntotal}")

# Save index and manifest
write_manifest(manifest, MANIFEST_PATH)
faiss.write_index(ivf, INDEX_PATH)
print(f"Saved index to {INDEX_PATH}")
print(f"Saved manifest to {MANIFEST_PATH}")



Shard & add to IVF: 100%|██████████| 200000/200000 [42:09<00:00, 79.06it/s] 


Total vectors added: 200000
Saved index to contriever_ivf_wiki_subset/ivf.index
Saved manifest to contriever_ivf_wiki_subset/manifest.json


In [28]:
# ============================================================
# CELL 8: Load Index for Retrieval
# ============================================================
ivf = faiss.read_index(INDEX_PATH)
with open(MANIFEST_PATH, "r") as f:
    _manifest = json.load(f)

ranges = []
for chunk in _manifest:
    start = int(chunk["start_id"])
    end = start + int(chunk["n"])
    ranges.append((start, end, chunk["path"]))

print(f"Index entries: {ivf.ntotal} | Shards: {len(ranges)}")

Index entries: 200000 | Shards: 10


In [29]:
# ============================================================
# CELL 9: Define Retrieval Functions
# ============================================================
def load_rows_by_ids(id_list):
    """Return rows in the same order as id_list; None where missing."""
    groups = {}
    for _id in id_list:
        found = False
        for (start, end, path) in ranges:
            if start <= _id < end:
                groups.setdefault(path, []).append(_id)
                found = True
                break
        if not found:
            groups.setdefault(None, []).append(_id)

    fetched = {}
    for path, ids in groups.items():
        if path is None:
            for _id in ids:
                fetched[_id] = None
            continue
        df = pd.read_parquet(path)
        m = df[df["internal_id"].isin(ids)]
        for _, row in m.iterrows():
            fetched[int(row["internal_id"])] = {
                "id": int(row["internal_id"]),
                "title": row["title"],
                "text": row["text"]
            }
        for _id in ids:
            fetched.setdefault(_id, None)

    return [fetched.get(i, None) for i in id_list]

# Tune probing
ivf.nprobe = max(1, min(64, IVF_NLIST // 512))

@torch.no_grad()
def retrieve(question, topk=10):
    """Retrieve top-k passages for a question using Contriever"""
    qv = encode_queries([question]).astype(np.float32)
    D, I = ivf.search(qv, topk)

    # Drop FAISS "no result" ids (-1), keep order
    pairs = [(float(s), int(pid)) for s, pid in zip(D[0], I[0]) if pid != -1]
    if not pairs:
        return []

    ids = [pid for _, pid in pairs]
    rows = load_rows_by_ids(ids)

    hits = []
    for rank, ((score, pid), row) in enumerate(zip(pairs, rows), 1):
        if row is None:
            continue
        hits.append({
            "rank": rank,
            "score": score,
            "internal_id": row["id"],
            "title": row["title"],
            "text": row["text"]
        })
    return hits

print("Retrieval functions defined successfully!")


Retrieval functions defined successfully!


In [30]:
# ============================================================
# CELL 10: Smoke Test
# ============================================================
test_query = "Who wrote War and Peace?"
print(f"Query: {test_query}\n")
for h in retrieve(test_query, topk=5):
    print(f"{h['rank']:>2}. {h['score']:.3f}  {h['title'][:60]}")
    print(h["text"][:150].replace("\n", " "), "...\n")

Query: Who wrote War and Peace?

 1. 0.978  "Carl von Clausewitz"
"thought before 1945 other than via British writers, though Generals Eisenhower and Patton were avid readers. He did influence Karl Marx, Friedrich En ...

 2. 0.946  "Carl von Clausewitz"
"see the sixteen essays presented in ""Clausewitz in the Twenty-First Century"" edited by Hew Strachan and Andreas Herberg-Rothe. In military academie ...

 3. 0.943  "Carl von Clausewitz"
"his seemingly contradictory claims (discussions pertinent to the tactical, operational and strategic levels of war are one example). Clausewitz const ...

 4. 0.923  "Arthur Schopenhauer"
"what he has written in ""War and Peace"" is also said by Schopenhauer in ""The World as Will and Representation"". Jorge Luis Borges remarked that th ...

 5. 0.918  "All Quiet on the Western Front"
"the opening statement that the novel does not advocate any political position, but is merely an attempt to describe the experiences of the soldier. T ...



In [31]:
# ============================================================
# CELL 11: Load NQ-Open Dataset & Filter by Corpus
# ============================================================
from datasets import load_dataset

def load_nq_open_robust(n=400):
    ds = load_dataset("nq_open")
    split = "validation" if "validation" in ds else ("test" if "test" in ds else "train")
    sub = ds[split].select(range(min(n, ds[split].num_rows)))

    questions, golds = [], []
    for ex in sub:
        q = ex.get("question") or ex.get("query")
        if "answers" in ex:
            a = ex["answers"]
        elif "answer" in ex:
            a = ex["answer"]
        elif "answer_text" in ex:
            a = ex["answer_text"]
        else:
            a = []

        if isinstance(a, str):
            a = [a]
        elif isinstance(a, dict):
            if "text" in a and isinstance(a["text"], (list, tuple)):
                a = list(a["text"])
            elif "value" in a:
                v = a["value"]
                a = v if isinstance(v, (list, tuple)) else [v]
            else:
                vals = []
                for v in a.values():
                    if isinstance(v, (list, tuple)):
                        vals.extend(v)
                    elif isinstance(v, str):
                        vals.append(v)
                a = vals

        if q and a:
            questions.append(q)
            golds.append([str(x) for x in a])

    tag = f"nq_open:{split}[:{len(questions)}]"
    return questions, golds, tag

QUESTIONS, GOLD, TAG = load_nq_open_robust(400)
print(TAG, "| n:", len(QUESTIONS))
for i in range(min(3, len(QUESTIONS))):
    print(f"- Q{i+1}:", QUESTIONS[i])
    print("  gold:", GOLD[i][:3], "...\n")

# Filter questions by corpus coverage
def question_in_corpus(q, golds, retrieve_fn, k=100):
    """Check if any gold answer string appears in retrieved texts."""
    gold_norm = [normalize_text(a) for a in golds]
    H = retrieve_fn(q, topk=k)
    for h in H:
        txt = normalize_text(h["text"])
        if any(g in txt for g in gold_norm):
            return True
    return False

filtered_Q, filtered_G = [], []
for q, g in tqdm(zip(QUESTIONS, GOLD), total=len(QUESTIONS), desc="Filtering by corpus"):
    if question_in_corpus(q, g, retrieve, k=100):
        filtered_Q.append(q)
        filtered_G.append(g)

print(f"Filtered questions: {len(filtered_Q)}/{len(QUESTIONS)} remain in corpus")



README.md: 0.00B [00:00, ?B/s]

nq_open/train-00000-of-00001.parquet:   0%|          | 0.00/4.46M [00:00<?, ?B/s]

nq_open/validation-00000-of-00001.parque(…):   0%|          | 0.00/214k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87925 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3610 [00:00<?, ? examples/s]

nq_open:validation[:400] | n: 400
- Q1: when was the last time anyone was on the moon
  gold: ['14 December 1972 UTC', 'December 1972'] ...

- Q2: who wrote he ain't heavy he's my brother lyrics
  gold: ['Bobby Scott', 'Bob Russell'] ...

- Q3: how many seasons of the bastard executioner are there
  gold: ['one', 'one season'] ...



Filtering by corpus: 100%|██████████| 400/400 [03:55<00:00,  1.70it/s]

Filtered questions: 171/400 remain in corpus





In [32]:
# ============================================================
# CELL 12: Evaluate Hit@K Metrics
# ============================================================
def hit_at_k(questions, golds, k=20):
    if not questions:
        return 0.0
    hits = 0
    for q, gold in tqdm(zip(questions, golds), total=len(questions), desc=f"Hit@{k}"):
        gold_norm = [normalize_text(a) for a in gold]
        H = retrieve(q, topk=k)
        found = False
        for h in H:
            txt = normalize_text(h["text"])
            if any(g in txt for g in gold_norm):
                found = True
                break
        hits += int(found)
    return hits / len(questions)

h20 = hit_at_k(filtered_Q, filtered_G, k=20)
h100 = hit_at_k(filtered_Q, filtered_G, k=100)
print(f"\n[{TAG} | filtered] Contriever Retrieval Hit@20={h20:.3f}  Hit@100={h100:.3f}")



Hit@20: 100%|██████████| 171/171 [01:00<00:00,  2.81it/s]
Hit@100: 100%|██████████| 171/171 [01:38<00:00,  1.73it/s]


[nq_open:validation[:400] | filtered] Contriever Retrieval Hit@20=0.795  Hit@100=1.000



