In [15]:
!pip -q install ir_datasets transformers datasets faiss-cpu pandas pyarrow tqdm

import os, re, json, math, numpy as np, pandas as pd, faiss, torch
from tqdm import tqdm

# ==== KNOBS ====
SEED                = 42
DEVICE              = "cuda" if torch.cuda.is_available() else "cpu"

# Index size (subset of full 21M to keep it light)
N_PASSAGES_TOTAL    = 200_000      # tune for your machine
SHARD_ROWS          = 20_000       # rows per shard
BATCH_ENCODE        = 256
MAX_LEN             = 256
USE_COSINE          = False        # if True, L2-normalize vectors; search stays IP

# IVF params
IVF_NLIST           = 32768        # try 16384–65536
IVF_TRAIN_EMB       = 50_000       # vectors to train IVF (<= N_PASSAGES_TOTAL)

# Output paths
OUT_DIR       = "dpr_ivf_wiki_subset"
INDEX_PATH    = os.path.join(OUT_DIR, "ivf.index")
MANIFEST_PATH = os.path.join(OUT_DIR, "manifest.json")
os.makedirs(OUT_DIR, exist_ok=True)

np.random.seed(SEED); torch.manual_seed(SEED)
print("Device:", DEVICE)


  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m866.1/866.1 kB[0m [31m52.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m149.0/149.0 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.1/45.1 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m89.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for warc3-wet-clueweb09 (setup.py) ... [?25l[?25hdone
  Building wheel for cbor (setup.py) ... [?25l[?25hdone
Device: cuda


In [16]:
from transformers import (
    DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast,
    DPRContextEncoder,  DPRContextEncoderTokenizerFast
)

Q_MODEL = "facebook/dpr-question_encoder-single-nq-base"
P_MODEL = "facebook/dpr-ctx_encoder-single-nq-base"

q_tok = DPRQuestionEncoderTokenizerFast.from_pretrained(Q_MODEL)
p_tok = DPRContextEncoderTokenizerFast.from_pretrained(P_MODEL)

q_enc = DPRQuestionEncoder.from_pretrained(Q_MODEL).to(DEVICE).eval()
p_enc = DPRContextEncoder.from_pretrained(P_MODEL).to(DEVICE).eval()


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizerFast'.
Some weights of the model checkpoint at facebook/dpr-question_encoder-single-nq-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequ

In [17]:
@torch.no_grad()
def encode_passages_ctx(texts, batch=BATCH_ENCODE, max_len=MAX_LEN, normalize=USE_COSINE):
    out = []
    for i in range(0, len(texts), batch):
        t = p_tok(texts[i:i+batch], padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(DEVICE)
        emb = p_enc(**t).pooler_output
        if normalize:
            emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        out.append(emb.float().cpu().numpy())
    return np.vstack(out) if out else np.zeros((0, p_enc.config.hidden_size), dtype=np.float32)

@torch.no_grad()
def encode_queries(qs, batch=128, max_len=MAX_LEN, normalize=USE_COSINE):
    out = []
    for i in range(0, len(qs), batch):
        t = q_tok(qs[i:i+batch], padding=True, truncation=True, max_length=max_len, return_tensors="pt").to(DEVICE)
        e = q_enc(**t).pooler_output
        if normalize:
            e = torch.nn.functional.normalize(e, p=2, dim=1)
        out.append(e.float().cpu().numpy())
    return np.vstack(out)

def write_manifest(chunks, path=MANIFEST_PATH):
    with open(path, "w") as f:
        json.dump(chunks, f)
    return path

def read_manifest(path=MANIFEST_PATH):
    with open(path, "r") as f:
        return json.load(f)

def normalize_text(s):
    s = s.lower()
    s = re.sub(r"[^a-z0-9\s]", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


In [18]:
import ir_datasets

# Exact corpus: “A wikipedia dump from 20 Dec 2018, split into 100-word passages” (21M docs).
# Each doc: (doc_id, text, title). This matches DPR’s psgs_w100.  :contentReference[oaicite:1]{index=1}
dpr = ir_datasets.load("dpr-w100")

# A generator over the first N_PASSAGES_TOTAL docs to keep memory down
def take_docs(dataset, n):
    i = 0
    for doc in dataset.docs_iter():
        yield doc
        i += 1
        if i >= n: break


In [19]:
IVF_TRAIN_EMB = min(IVF_TRAIN_EMB, N_PASSAGES_TOTAL)

train_texts = []
for doc in tqdm(take_docs(dpr, IVF_TRAIN_EMB), total=IVF_TRAIN_EMB, desc="Collect IVF train docs"):
    # doc has fields: doc_id, text, title
    if doc.text:
        train_texts.append(doc.text)

train_mat = encode_passages_ctx(train_texts, batch=BATCH_ENCODE)
d = train_mat.shape[1]
print("IVF training matrix:", train_mat.shape)


[INFO] If you have a local copy of https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz, you can symlink it here to avoid downloading it again: /root/.ir_datasets/downloads/612fe66e0b6b41ee28f806140226c563
[INFO] [starting] https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
Collect IVF train docs:   0%|          | 0/50000 [00:00<?, ?it/s]
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.0%| 0.00/4.69G [00:00<?, ?B/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.1%| 4.20M/4.69G [00:00<01:55, 40.6MB/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.2%| 8.40M/4.69G [00:00<01:58, 39.4MB/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.3%| 12.6M/4.69G [00:00<01:58, 39.6MB/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz: 0.4%| 16.8M/4.69G [00:00<02:09, 36.1MB/s][A
https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.g

IVF training matrix: (50000, 768)


In [20]:
quantizer = faiss.IndexFlatIP(d)  # IP; cosine works if vectors normalized
ivf = faiss.IndexIVFFlat(quantizer, d, IVF_NLIST, faiss.METRIC_INNER_PRODUCT)
ivf.train(train_mat)
print("IVF trained:", ivf.is_trained)


IVF trained: True


In [25]:
manifest = []   # list of {"path","start_id","n"}
internal_base = 0
shard_idx = 0

texts, titles = [], []
pbar = tqdm(total=N_PASSAGES_TOTAL, desc="Shard & add DPR w100 to IVF")
count = 0

for doc in dpr.docs_iter():
    if not doc.text:
        continue
    texts.append(doc.text)
    titles.append(doc.title or "")
    count += 1
    pbar.update(1)

    if len(texts) == SHARD_ROWS:
        emb = encode_passages_ctx(texts, batch=BATCH_ENCODE)
        ivf.add(emb)

        shard_path = os.path.join(OUT_DIR, f"passages_shard_{shard_idx:05d}.parquet")
        pd.DataFrame({
            "internal_id": np.arange(internal_base, internal_base + len(texts)),
            "title": titles,
            "text": texts
        }).to_parquet(shard_path, index=False)
        manifest.append({"path": shard_path, "start_id": int(internal_base), "n": int(len(texts))})

        internal_base += len(texts)
        shard_idx += 1
        texts, titles = [], []

    if count >= N_PASSAGES_TOTAL:
        break

# flush remainder
if texts:
    emb = encode_passages_ctx(texts, batch=BATCH_ENCODE)
    ivf.add(emb)
    shard_path = os.path.join(OUT_DIR, f"passages_shard_{shard_idx:05d}.parquet")
    pd.DataFrame({
        "internal_id": np.arange(internal_base, internal_base + len(texts)),
        "title": titles,
        "text": texts
    }).to_parquet(shard_path, index=False)
    manifest.append({"path": shard_path, "start_id": int(internal_base), "n": int(len(texts))})
    internal_base += len(texts)
    shard_idx += 1

pbar.close()
print("Total vectors added:", ivf.ntotal)
write_manifest(manifest, MANIFEST_PATH)
faiss.write_index(ivf, INDEX_PATH)
print("Saved index to", INDEX_PATH)
print("Saved manifest to", MANIFEST_PATH)


Shard & add to IVF:  10%|█         | 20000/200000 [00:11<00:02, 88752.33it/s]

KeyboardInterrupt: 

In [26]:
import json, pandas as pd

ivf = faiss.read_index(INDEX_PATH)
with open(MANIFEST_PATH, "r") as f:
    _manifest = json.load(f)

ranges = []
for chunk in _manifest:
    start = int(chunk["start_id"])
    end   = start + int(chunk["n"])
    ranges.append((start, end, chunk["path"]))

print("Index entries:", ivf.ntotal, "| Shards:", len(ranges))


Index entries: 200000 | Shards: 10


In [27]:
def load_rows_by_ids(id_list):
    """Return rows in the same order as id_list; None where missing."""
    groups = {}
    for _id in id_list:
        found = False
        for (start, end, path) in ranges:
            if start <= _id < end:
                groups.setdefault(path, []).append(_id)
                found = True
                break
        if not found:
            groups.setdefault(None, []).append(_id)  # mark as missing

    fetched = {}
    for path, ids in groups.items():
        if path is None:
            for _id in ids: fetched[_id] = None
            continue
        df = pd.read_parquet(path)
        m  = df[df["internal_id"].isin(ids)]
        for _, row in m.iterrows():
            fetched[int(row["internal_id"])] = {
                "id": int(row["internal_id"]),
                "title": row["title"],
                "text": row["text"]
            }
        # Any ids not found in this shard -> None
        for _id in ids:
            fetched.setdefault(_id, None)

    return [fetched.get(i, None) for i in id_list]


In [28]:
# Tune probing (higher nprobe => better recall, slower)
ivf.nprobe = max(1, min(64, IVF_NLIST // 512))  # e.g., 32 for 16,384 lists

@torch.no_grad()
def retrieve(question, topk=10):
    qv = encode_queries([question]).astype(np.float32)
    D, I = ivf.search(qv, topk)

    # Drop FAISS "no result" ids (-1), keep order
    pairs = [(float(s), int(pid)) for s, pid in zip(D[0], I[0]) if pid != -1]
    if not pairs:
        return []

    ids  = [pid for _, pid in pairs]
    rows = load_rows_by_ids(ids)

    hits = []
    for rank, ((score, pid), row) in enumerate(zip(pairs, rows), 1):
        if row is None:
            continue
        hits.append({
            "rank": rank,
            "score": score,
            "internal_id": row["id"],
            "title": row["title"],
            "text": row["text"]
        })
    return hits

# quick smoke test
for h in retrieve("Who wrote War and Peace?", topk=5):
    print(f"{h['rank']:>2}. {h['score']:.3f}  {h['title'][:60]}")
    print(h["text"][:150].replace("\n", " "), "...\n")


 1. 71.216  "Alexander of Hales"
"declaring war in three ways: the relief of good people, coercion of the wicked, and peace for all. It is important to note that Alexander put ‘peace  ...

 2. 71.033  "Animal Farm"
"War is suggested when Napoleon and Pilkington, both suspicious, ""played an ace of spades simultaneously"". Similarly, the music in the novel, starti ...

 3. 70.973  "Herman Melville"
"the War"" (1866) was his poetic reflection on the moral questions of the American Civil War. In 1867, his oldest child Malcolm died at home from a se ...

 4. 70.443  "Alternate history"
"war, involving rival paratime empires, was developed in Fritz Leiber's Change War series, starting with the Hugo Award winning ""The Big Time"" (1958 ...

 5. 69.942  "All Quiet on the Western Front"
"adapted into comic book form as part of the ""Classics Illustrated"" series. All Quiet on the Western Front All Quiet on the Western Front () is a no ...



In [29]:
from transformers import DPRReader, DPRReaderTokenizerFast

READER_CKPT = "facebook/dpr-reader-single-nq-base"
r_tok  = DPRReaderTokenizerFast.from_pretrained(READER_CKPT)
reader = DPRReader.from_pretrained(READER_CKPT).to(DEVICE).eval()

@torch.no_grad()
def read_best_answer(question, hits, n_ctx=12):
    if not hits:
        return "", float("-inf")
    titles = [h["title"] or f"doc_{h['internal_id']}" for h in hits[:n_ctx]]
    texts  = [h["text"] for h in hits[:n_ctx]]

    enc = r_tok(
        questions=[question]*len(texts),
        titles=titles,
        texts=texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    ).to(DEVICE)

    out = reader(**enc)
    start, end, rel = out.start_logits, out.end_logits, out.relevance_logits.squeeze(-1)

    best, best_score = "", -1e9
    for i in range(start.size(0)):
        s_idx = int(torch.argmax(start[i]))
        e_idx = int(torch.argmax(end[i]))
        if e_idx < s_idx:
            e_idx = s_idx
        span = r_tok.decode(enc["input_ids"][i, s_idx:e_idx+1], skip_special_tokens=True).strip()
        score = float(start[i, s_idx] + end[i, e_idx] + rel[i])
        if score > best_score:
            best, best_score = span, score
    return best, best_score

# sanity
q = "Who wrote War and Peace?"
hits = retrieve(q, topk=20)
ans, sc = read_best_answer(q, hits, n_ctx=12)
print("Answer:", ans, "| score:", sc)


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRReaderTokenizerFast'.
Some weights of the model checkpoint at facebook/dpr-reader-single-nq-base were not used when initializing DPRReader: ['span_predictor.encoder.bert_model.pooler.dense.bias', 'span_predictor.encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRReader from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRReader from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
You're

Answer: erich maria remarque | score: 16.045578002929688


In [33]:
from datasets import load_dataset

def load_nq_open_robust(n=400):
    ds = load_dataset("nq_open")
    split = "validation" if "validation" in ds else ("test" if "test" in ds else "train")
    sub = ds[split].select(range(min(n, ds[split].num_rows)))

    questions, golds = [], []
    for ex in sub:
        q = ex.get("question") or ex.get("query")
        if "answers" in ex:         a = ex["answers"]
        elif "answer" in ex:        a = ex["answer"]
        elif "answer_text" in ex:   a = ex["answer_text"]
        else:                       a = []

        if isinstance(a, str):
            a = [a]
        elif isinstance(a, dict):
            if "text" in a and isinstance(a["text"], (list, tuple)):
                a = list(a["text"])
            elif "value" in a:
                v = a["value"]; a = v if isinstance(v, (list, tuple)) else [v]
            else:
                vals = []
                for v in a.values():
                    if isinstance(v, (list, tuple)): vals.extend(v)
                    elif isinstance(v, str): vals.append(v)
                a = vals

        if q and a:
            questions.append(q)
            golds.append([str(x) for x in a])

    tag = f"nq_open:{split}[:{len(questions)}]"
    return questions, golds, tag

QUESTIONS, GOLD, TAG = load_nq_open_robust(400)
print(TAG, "| n:", len(QUESTIONS))
for i in range(min(3, len(QUESTIONS))):
    print(f"- Q{i+1}:", QUESTIONS[i])
    print("  gold:", GOLD[i][:3], "...\n")


from tqdm import tqdm

def question_in_corpus(q, golds, retrieve_fn, k=100):
    """Check if any gold answer string appears in retrieved texts."""
    gold_norm = [normalize_text(a) for a in golds]
    H = retrieve_fn(q, topk=k)
    for h in H:
        txt = normalize_text(h["text"])
        if any(g in txt for g in gold_norm):
            return True
    return False

filtered_Q, filtered_G = [], []
for q, g in tqdm(zip(QUESTIONS, GOLD), total=len(QUESTIONS), desc="Filtering by corpus"):
    if question_in_corpus(q, g, retrieve, k=100):  # uses your DPR retrieve() function
        filtered_Q.append(q)
        filtered_G.append(g)

print(f"Filtered questions: {len(filtered_Q)}/{len(QUESTIONS)} remain in corpus")


nq_open:validation[:400] | n: 400
- Q1: when was the last time anyone was on the moon
  gold: ['14 December 1972 UTC', 'December 1972'] ...

- Q2: who wrote he ain't heavy he's my brother lyrics
  gold: ['Bobby Scott', 'Bob Russell'] ...

- Q3: how many seasons of the bastard executioner are there
  gold: ['one', 'one season'] ...




Filtering by corpus:   0%|          | 0/400 [00:00<?, ?it/s][A
Filtering by corpus:   0%|          | 1/400 [00:00<02:51,  2.33it/s][A
Filtering by corpus:   0%|          | 2/400 [00:01<03:32,  1.87it/s][A
Filtering by corpus:   1%|          | 3/400 [00:01<03:41,  1.79it/s][A
Filtering by corpus:   1%|          | 4/400 [00:02<03:27,  1.90it/s][A
Filtering by corpus:   1%|▏         | 5/400 [00:02<03:37,  1.82it/s][A
Filtering by corpus:   2%|▏         | 6/400 [00:03<03:44,  1.76it/s][A
Filtering by corpus:   2%|▏         | 7/400 [00:03<03:50,  1.71it/s][A
Filtering by corpus:   2%|▏         | 8/400 [00:04<03:43,  1.75it/s][A
Filtering by corpus:   2%|▏         | 9/400 [00:05<03:47,  1.72it/s][A
Filtering by corpus:   2%|▎         | 10/400 [00:05<03:42,  1.75it/s][A
Filtering by corpus:   3%|▎         | 11/400 [00:05<03:12,  2.02it/s][A
Filtering by corpus:   3%|▎         | 12/400 [00:06<03:15,  1.98it/s][A
Filtering by corpus:   3%|▎         | 13/400 [00:07<03:26,  1.88it/s

Filtered questions: 180/400 remain in corpus





In [34]:
from tqdm import tqdm

def hit_at_k(questions, golds, k=20):
    if not questions:
        return 0.0
    hits = 0
    for q, gold in tqdm(zip(questions, golds), total=len(questions), desc=f"Hit@{k}"):
        gold_norm = [normalize_text(a) for a in gold]
        H = retrieve(q, topk=k)
        found = False
        for h in H:
            txt = normalize_text(h["text"])
            if any(g in txt for g in gold_norm):
                found = True
                break
        hits += int(found)
    return hits / len(questions)

# --- (next existing cell) ---
h20  = hit_at_k(filtered_Q, filtered_G, k=20)
h100 = hit_at_k(filtered_Q, filtered_G, k=100)
print(f"[{TAG} | filtered] Retrieval Hit@20={h20:.3f}  Hit@100={h100:.3f}")



Hit@20:   0%|          | 0/180 [00:00<?, ?it/s][A
Hit@20:   1%|          | 1/180 [00:00<00:44,  4.04it/s][A
Hit@20:   1%|          | 2/180 [00:01<01:41,  1.76it/s][A
Hit@20:   2%|▏         | 3/180 [00:01<01:55,  1.54it/s][A
Hit@20:   2%|▏         | 4/180 [00:02<01:51,  1.58it/s][A
Hit@20:   3%|▎         | 5/180 [00:02<01:34,  1.85it/s][A
Hit@20:   3%|▎         | 6/180 [00:03<01:39,  1.74it/s][A
Hit@20:   4%|▍         | 7/180 [00:03<01:20,  2.14it/s][A
Hit@20:   4%|▍         | 8/180 [00:04<01:18,  2.19it/s][A
Hit@20:   5%|▌         | 9/180 [00:04<01:06,  2.56it/s][A
Hit@20:   6%|▌         | 10/180 [00:04<01:08,  2.50it/s][A
Hit@20:   6%|▌         | 11/180 [00:05<01:05,  2.57it/s][A
Hit@20:   7%|▋         | 13/180 [00:05<00:53,  3.12it/s][A
Hit@20:   8%|▊         | 14/180 [00:06<00:57,  2.88it/s][A
Hit@20:   8%|▊         | 15/180 [00:06<01:05,  2.51it/s][A
Hit@20:   9%|▉         | 16/180 [00:06<01:03,  2.59it/s][A
Hit@20:   9%|▉         | 17/180 [00:07<00:59,  2.75it/s]

[nq_open:validation[:400] | filtered] Retrieval Hit@20=0.706  Hit@100=1.000



