In [19]:
!pip -q install transformers faiss-cpu ranx tqdm ir_datasets

import os, math, time, json
from pathlib import Path
import numpy as np
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
import faiss
import ir_datasets
from ranx import Qrels, Run, evaluate

# ---------- Config ----------
WORKDIR = Path("/kaggle/working")
EMB_DIR = WORKDIR / "miracl_ar_contriever_emb_chunks"
RUN_DIR = WORKDIR / "runs"
EMB_DIR.mkdir(parents=True, exist_ok=True)
RUN_DIR.mkdir(parents=True, exist_ok=True)

MODEL_NAME = "facebook/contriever"
DATASET = "miracl/ar/dev"     # queries + qrels
CORPUS  = "miracl/ar"         # corpus

CHUNK_SIZE = 50_000
BATCH_SIZE = 128
MAX_LENGTH = 256

TOPK = 1000

# IVF params (good for 2M docs)
N_LIST = 8192
N_PROBE = 16
TRAIN_VECS = 400_000  # better than 200k if possible

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# ---------- Model ----------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(device)
model.eval()

def mean_pooling(last_hidden_state, attention_mask):
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts

@torch.no_grad()
def encode_texts(texts, batch_size=BATCH_SIZE):
    embs = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        tok = tokenizer(batch, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")
        tok = {k: v.to(device) for k, v in tok.items()}
        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            out = model(**tok)
            emb = mean_pooling(out.last_hidden_state, tok["attention_mask"])
            emb = torch.nn.functional.normalize(emb, p=2, dim=1)
        embs.append(emb.detach().cpu().numpy().astype("float32"))
    return np.vstack(embs)

def chunk_paths(chunk_idx: int):
    emb_path = EMB_DIR / f"emb_{chunk_idx:05d}.npy"
    ids_path = EMB_DIR / f"ids_{chunk_idx:05d}.npy"
    meta_path = EMB_DIR / f"meta_{chunk_idx:05d}.json"
    return emb_path, ids_path, meta_path

# ---------- Load datasets ----------
ds = ir_datasets.load(DATASET)     # has queries + qrels + (docs_iter too)
corpus_ds = ir_datasets.load(CORPUS)

num_docs = corpus_ds.docs_count()
num_chunks = math.ceil(num_docs / CHUNK_SIZE)
print("Corpus docs:", num_docs, "Num chunks:", num_chunks)

# ---------- 1) Encode corpus in chunks (SKIP if exists) ----------
docs_iter = corpus_ds.docs_iter()
chunk_idx = 0
buf_ids, buf_texts = [], []
seen = 0

for d in tqdm(docs_iter, total=num_docs, desc="Streaming corpus"):
    doc_id = str(d.doc_id)
    title = (d.title or "").strip()
    text = (d.text or "").strip()
    full = (title + "\n" + text).strip() if title else text

    buf_ids.append(doc_id)
    buf_texts.append(full)
    seen += 1

    if len(buf_ids) == CHUNK_SIZE or seen == num_docs:
        emb_path, ids_path, meta_path = chunk_paths(chunk_idx)

        if emb_path.exists() and ids_path.exists():
            # skip
            pass
        else:
            t0 = time.time()
            emb = encode_texts(buf_texts)
            dt = time.time() - t0
            np.save(emb_path, emb)
            np.save(ids_path, np.array(buf_ids, dtype=object))
            meta_path.write_text(json.dumps({
                "chunk_idx": chunk_idx, "n": len(buf_ids), "seconds": dt
            }, ensure_ascii=False, indent=2), encoding="utf-8")
            print(f"✅ saved chunk {chunk_idx} n={len(buf_ids)} time={dt/60:.1f}m")

        buf_ids, buf_texts = [], []
        chunk_idx += 1

print("✅ All chunks ready.")

# ---------- 2) Build FAISS IVF index ----------
first = np.load(chunk_paths(0)[0])
dim = first.shape[1]
print("Embedding dim:", dim)

# collect training vectors
rng = np.random.default_rng(42)
train_buf = []
need = TRAIN_VECS

for ci in range(num_chunks):
    emb = np.load(chunk_paths(ci)[0]).astype("float32")
    if emb.shape[0] <= need:
        train_buf.append(emb)
        need -= emb.shape[0]
    else:
        idx = rng.choice(emb.shape[0], size=need, replace=False)
        train_buf.append(emb[idx])
        need = 0
    if need == 0:
        break

train_x = np.vstack(train_buf)
print("Train sample:", train_x.shape)

quantizer = faiss.IndexFlatIP(dim)
index = faiss.IndexIVFFlat(quantizer, dim, N_LIST, faiss.METRIC_INNER_PRODUCT)
print("Training IVF...")
index.train(train_x)
index.nprobe = N_PROBE
print("✅ IVF trained. Adding vectors...")

all_doc_ids = []
total = 0
for ci in range(num_chunks):
    emb = np.load(chunk_paths(ci)[0]).astype("float32")
    ids = np.load(chunk_paths(ci)[1], allow_pickle=True)
    index.add(emb)
    all_doc_ids.append(ids)
    total += emb.shape[0]
    if ci % 10 == 0:
        print(f"  added {ci}/{num_chunks-1}, total={total}")

all_doc_ids = np.concatenate(all_doc_ids)
print("Index total:", index.ntotal, "Docids:", len(all_doc_ids))

INDEX_PATH = WORKDIR / "faiss_miracl_ar_contriever_ivf.index"
DOCIDS_PATH = WORKDIR / "faiss_miracl_ar_docids.npy"
faiss.write_index(index, str(INDEX_PATH))
np.save(DOCIDS_PATH, all_doc_ids)
print("✅ Saved index + docids")

# ---------- 3) Encode queries ----------
queries = list(ds.queries_iter())
query_ids = [str(q.query_id) for q in queries]
query_texts = [q.text for q in queries]
q_emb = encode_texts(query_texts, batch_size=64)
print("Queries:", len(query_ids), "q_emb:", q_emb.shape)

# ---------- 4) Search + write run ----------
scores, idxs = index.search(q_emb.astype("float32"), TOPK)
run_path = RUN_DIR / "contriever_zero_miracl_ar_dev.run"
with open(run_path, "w", encoding="utf-8") as f:
    for i, qid in enumerate(query_ids):
        for rank, (doc_idx, score) in enumerate(zip(idxs[i], scores[i]), start=1):
            if doc_idx < 0:
                continue
            docid = str(all_doc_ids[int(doc_idx)])
            f.write(f"{qid} Q0 {docid} {rank} {float(score)} contriever_zero\n")
print("✅ Wrote run:", run_path)

# ---------- 5) Eval (qrels from SAME ds) ----------
qrels_dict = {}
for r in ds.qrels_iter():
    qrels_dict.setdefault(str(r.query_id), {})
    qrels_dict[str(r.query_id)][str(r.doc_id)] = int(r.relevance)

qrels = Qrels(qrels_dict)
run = Run.from_file(str(run_path), kind="trec")

metrics = ["map", "mrr", "ndcg@10", "recall@100"]
results = evaluate(qrels, run, metrics)
print(results)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Device: cuda
Corpus docs: 2061414 Num chunks: 42


  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
Streaming corpus:   3%|▎         | 64452/2061414 [02:56<2:14:19, 247.77it/s]

✅ saved chunk 0 n=50000 time=2.9m


Streaming corpus:   6%|▌         | 114911/2061414 [05:51<2:05:15, 259.01it/s]

✅ saved chunk 1 n=50000 time=2.9m


Streaming corpus:   8%|▊         | 164167/2061414 [08:46<2:07:19, 248.34it/s]

✅ saved chunk 2 n=50000 time=2.9m


Streaming corpus:  10%|█         | 215430/2061414 [11:39<1:57:22, 262.13it/s]

✅ saved chunk 3 n=50000 time=2.9m


Streaming corpus:  13%|█▎        | 266034/2061414 [14:32<1:56:01, 257.92it/s]

✅ saved chunk 4 n=50000 time=2.9m


Streaming corpus:  15%|█▌        | 315632/2061414 [17:23<1:43:55, 279.99it/s]

✅ saved chunk 5 n=50000 time=2.8m


Streaming corpus:  18%|█▊        | 366154/2061414 [20:15<1:50:24, 255.90it/s]

✅ saved chunk 6 n=50000 time=2.9m


Streaming corpus:  20%|██        | 417730/2061414 [23:07<1:31:50, 298.28it/s]

✅ saved chunk 7 n=50000 time=2.9m


Streaming corpus:  22%|██▏       | 457903/2061414 [25:55<2:16:26, 195.88it/s]

✅ saved chunk 8 n=50000 time=2.8m


Streaming corpus:  25%|██▌       | 515563/2061414 [28:48<1:31:37, 281.17it/s]

✅ saved chunk 9 n=50000 time=2.9m


Streaming corpus:  27%|██▋       | 565927/2061414 [31:40<1:37:48, 254.84it/s]

✅ saved chunk 10 n=50000 time=2.9m


Streaming corpus:  30%|██▉       | 616167/2061414 [34:32<1:34:20, 255.33it/s]

✅ saved chunk 11 n=50000 time=2.9m


Streaming corpus:  32%|███▏      | 657855/2061414 [37:24<1:58:45, 196.97it/s]

✅ saved chunk 12 n=50000 time=2.9m


Streaming corpus:  35%|███▍      | 715857/2061414 [40:17<1:28:27, 253.54it/s]

✅ saved chunk 13 n=50000 time=2.9m


Streaming corpus:  37%|███▋      | 757847/2061414 [43:09<2:13:16, 163.02it/s]

✅ saved chunk 14 n=50000 time=2.9m


Streaming corpus:  40%|███▉      | 816804/2061414 [46:01<1:11:18, 290.92it/s]

✅ saved chunk 15 n=50000 time=2.8m


Streaming corpus:  42%|████▏     | 867077/2061414 [48:53<1:07:31, 294.82it/s]

✅ saved chunk 16 n=50000 time=2.8m


Streaming corpus:  44%|████▍     | 907499/2061414 [51:44<1:39:36, 193.07it/s]

✅ saved chunk 17 n=50000 time=2.9m


Streaming corpus:  47%|████▋     | 967604/2061414 [54:37<1:06:04, 275.90it/s]

✅ saved chunk 18 n=50000 time=2.9m


Streaming corpus:  49%|████▉     | 1006270/2061414 [57:28<1:38:01, 179.40it/s]

✅ saved chunk 19 n=50000 time=2.8m


Streaming corpus:  51%|█████▏    | 1057718/2061414 [1:00:21<1:34:00, 177.96it/s]

✅ saved chunk 20 n=50000 time=2.9m


Streaming corpus:  54%|█████▍    | 1117706/2061414 [1:03:14<57:14, 274.78it/s]  

✅ saved chunk 21 n=50000 time=2.9m


Streaming corpus:  57%|█████▋    | 1167250/2061414 [1:06:07<51:31, 289.21it/s]  

✅ saved chunk 22 n=50000 time=2.9m


Streaming corpus:  59%|█████▊    | 1207997/2061414 [1:08:54<1:13:06, 194.56it/s]

✅ saved chunk 23 n=50000 time=2.8m


Streaming corpus:  61%|██████    | 1257891/2061414 [1:11:19<55:52, 239.67it/s]  

✅ saved chunk 24 n=50000 time=2.4m


Streaming corpus:  64%|██████▍   | 1316484/2061414 [1:14:09<42:57, 289.02it/s]  

✅ saved chunk 25 n=50000 time=2.8m


Streaming corpus:  66%|██████▌   | 1359200/2061414 [1:16:54<54:26, 214.95it/s]  

✅ saved chunk 26 n=50000 time=2.7m


Streaming corpus:  69%|██████▊   | 1415449/2061414 [1:19:44<38:59, 276.14it/s]  

✅ saved chunk 27 n=50000 time=2.8m


Streaming corpus:  71%|███████   | 1466119/2061414 [1:22:34<34:21, 288.81it/s]  

✅ saved chunk 28 n=50000 time=2.8m


Streaming corpus:  74%|███████▎  | 1517902/2061414 [1:25:26<30:09, 300.35it/s]  

✅ saved chunk 29 n=50000 time=2.8m


Streaming corpus:  76%|███████▌  | 1559477/2061414 [1:28:15<41:16, 202.65it/s]  

✅ saved chunk 30 n=50000 time=2.8m


Streaming corpus:  79%|███████▊  | 1618372/2061414 [1:31:04<25:24, 290.66it/s]  

✅ saved chunk 31 n=50000 time=2.8m


Streaming corpus:  81%|████████  | 1671305/2061414 [1:33:44<18:00, 361.08it/s]

✅ saved chunk 32 n=50000 time=2.7m


Streaming corpus:  83%|████████▎ | 1707814/2061414 [1:36:29<26:51, 219.48it/s]

✅ saved chunk 33 n=50000 time=2.7m


Streaming corpus:  85%|████████▌ | 1757627/2061414 [1:39:19<26:12, 193.23it/s]

✅ saved chunk 34 n=50000 time=2.8m


Streaming corpus:  88%|████████▊ | 1817374/2061414 [1:42:09<13:30, 301.11it/s]

✅ saved chunk 35 n=50000 time=2.8m


Streaming corpus:  91%|█████████ | 1867880/2061414 [1:45:00<10:50, 297.39it/s]

✅ saved chunk 36 n=50000 time=2.8m


Streaming corpus:  93%|█████████▎| 1909904/2061414 [1:47:51<12:12, 206.96it/s]

✅ saved chunk 37 n=50000 time=2.8m


Streaming corpus:  96%|█████████▌| 1968773/2061414 [1:50:38<05:04, 304.00it/s]

✅ saved chunk 38 n=50000 time=2.8m


Streaming corpus:  97%|█████████▋| 2008630/2061414 [1:53:24<04:35, 191.61it/s] 

✅ saved chunk 39 n=50000 time=2.8m


Streaming corpus: 100%|█████████▉| 2059790/2061414 [1:56:14<00:08, 196.22it/s]

✅ saved chunk 40 n=50000 time=2.8m


Streaming corpus: 100%|██████████| 2061414/2061414 [1:56:52<00:00, 293.97it/s]

✅ saved chunk 41 n=11414 time=0.6m
✅ All chunks ready.





Embedding dim: 768
Train sample: (400000, 768)
Training IVF...
✅ IVF trained. Adding vectors...
  added 0/41, total=50000
  added 10/41, total=550000
  added 20/41, total=1050000
  added 30/41, total=1550000
  added 40/41, total=2050000
Index total: 2061414 Docids: 2061414
✅ Saved index + docids
Queries: 2896 q_emb: (2896, 768)
✅ Wrote run: /kaggle/working/runs/contriever_zero_miracl_ar_dev.run
{'map': np.float64(0.0006234993661948514), 'mrr': np.float64(0.0006794800701671288), 'ndcg@10': np.float64(0.0007499007726032625), 'recall@100': np.float64(0.0024171270718232043)}
