In [1]:
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.cluster import MiniBatchKMeans
from collections import defaultdict

import ir_datasets
from tqdm import tqdm
import math

In [2]:
# Use the GPU when CUDA is available; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Inference-only notebook: turn autograd off globally.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fc5b4a9d3f0>

In [3]:
# 1. DATA
# Slice of the MS MARCO v2.1 training split used for the small demo index.
ds = load_dataset(
    "ms_marco",
    "v2.1",
    split="train[:2000]",
)

# One text per example: all of its passage texts joined with single spaces.
passages = [
    " ".join(example["passages"]["passage_text"])
    for example in ds
]

# Queries of the first six examples, used as sample search inputs.
queries = ds[:6]["query"]

In [4]:
# 2. ENCODER
# Tokenizer and encoder weights are loaded from the same MiniLM checkpoint.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

encoder = AutoModel.from_pretrained(model_name)
encoder = encoder.to(DEVICE)   # place the weights on the selected device
encoder = encoder.eval()       # switch the module to evaluation mode

In [5]:
def encode_text(texts, max_len=64, batch_size=64):
    """
    Encode one string or a sequence of strings into per-token embeddings.

    Parameters
    ----------
    texts : str or sequence of str
        A single string is treated as a one-element batch.
    max_len : int
        Passed to the tokenizer as ``max_length`` (with truncation on).
    batch_size : int
        Number of texts sent through the encoder per forward pass, so that
        memory use does not grow with the size of the corpus.

    Returns
    -------
    list of np.ndarray
        One float32 array of shape *L_i × d* per input string, in input order.
        **L_i excludes the first and last attended token ([CLS]/[SEP])** and
        every row is ℓ₂-normalised.  An empty input yields an empty list.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    out = []
    # Local no_grad so the function does not rely on a notebook-wide
    # torch.set_grad_enabled(False) having been executed beforehand.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            chunk = texts[start:start + batch_size]

            inp = tokenizer(
                chunk, padding=True, truncation=True,
                return_tensors="pt", max_length=max_len
            ).to(DEVICE)

            emb = encoder(**inp).last_hidden_state        # (B, L, d)
            mask = inp.attention_mask.bool()

            for i in range(len(chunk)):
                # keep attended positions only, drop the two special tokens,
                # move to cpu → np, normalise (epsilon avoids division by 0)
                vecs = emb[i][mask[i]][1:-1].cpu().numpy()
                vecs /= np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-8
                out.append(vecs.astype(np.float32))

    return out                              # list of (L_i, d) arrays

In [6]:
# 3. OFFLINE INDEX BUILDING
# Per-token embeddings of the first 1000 passages, stacked row-wise into one
# (total_tokens, d) matrix for clustering.
doc_vecs = encode_text(passages[:1000])
token_mat = np.concatenate(doc_vecs, axis=0)

# §4.1 PLAID
# Cluster all token embeddings into k centroids.
k = 256
kmeans = MiniBatchKMeans(k, batch_size=2048, random_state=42).fit(token_mat)

# float32 copy of the cluster centres, rescaled to unit length row by row.
centroids = kmeans.cluster_centers_.astype(np.float32)
centroids = centroids / np.linalg.norm(centroids, axis=1, keepdims=True)

In [7]:
# Inverted index: centroid_id → {doc_id}
#
# Side effect: every entry of `doc_vecs` is replaced by the 1-D array of
# centroid ids assigned to that document's tokens (this is what the search
# functions read later on).
inv_index = defaultdict(set)
for doc_id, v in enumerate(doc_vecs):
    if v.ndim == 1:
        # Entry was already converted by an earlier run of this cell:
        # reuse the stored centroid ids instead of calling predict on them.
        c_ids = v
    elif len(v) == 0:
        # Document without tokens: predict cannot be called on zero rows,
        # so store an empty id array and leave the doc out of the index.
        c_ids = np.empty(0, dtype=np.intp)
    else:
        c_ids = kmeans.predict(v)                    # (L_i,)

    for cid in np.unique(c_ids):
        inv_index[cid].add(doc_id)
    doc_vecs[doc_id] = c_ids                         # keep as centroid-ids

# Freeze the posting sets into int32 arrays.
inv_index = {cid: np.fromiter(docs, dtype=np.int32)
             for cid, docs in inv_index.items()}

In [8]:
# 4. SEARCH COMPONENTS
def centroid_scores(q_vec):
    """
    Score every centroid against every query token.

    Computes C · Qᵀ (Eq. 2 of the paper) with the module-level `centroids`
    matrix; the result has one row per centroid and one column per query
    token, i.e. shape (k, |q|).
    """
    query_tokens_T = q_vec.T
    return np.matmul(centroids, query_tokens_T)      # float32

def centroid_pruned_ids(S_cq, t_cs=0.45, nprobe=2):
    """
    Stage-1 candidate generation (§4.1) + centroid pruning (§4.3).

    S_cq      : (k, |q|) centroid × query-token score matrix.
    *nprobe*  : #top-centroids per query-token.  Values above the number of
                centroids are clamped to it; values <= 0 select nothing.
    *t_cs*    : pruning threshold.
    returns   : unique doc ids (np.int32); empty when nothing qualifies.
    """
    n_centroids, n_query_tokens = S_cq.shape

    # np.argpartition raises when kth is out of bounds, and a slice of
    # `[-0:]` would silently mean "every centroid", so normalise nprobe first.
    nprobe = min(int(nprobe), n_centroids)
    if nprobe <= 0 or n_query_tokens == 0:
        return np.empty(0, dtype=np.int32)

    topc = np.argpartition(S_cq, -nprobe, axis=0)[-nprobe:]    # (nprobe, |q|)
    cand_docs = set()
    for cid in np.unique(topc):
        # prune whole centroid if its best score < t_cs (§3.4, Eq. 5)
        if S_cq[cid].max() < t_cs:
            continue
        # centroids without a posting list contribute no documents
        cand_docs.update(inv_index.get(cid, ()))
    return np.fromiter(cand_docs, dtype=np.int32)

def centroid_interaction_score(S_cq, doc_cids):
    """
    Stage-2/3 centroid interaction (§4.2, Eq. 3-4).

    S_cq     : (k, |q|) centroid × query-token score matrix.
    doc_cids : 1-D array of centroid ids for that document.
    returns  : scalar — for each query token the best score among the
               document's centroids, summed over all query tokens.
               A document with no tokens scores 0.
    """
    if np.size(doc_cids) == 0:
        # max over zero rows is undefined (numpy raises); an empty document
        # matches nothing, so give it a zero score of the matrix dtype.
        return S_cq.dtype.type(0)

    doc_scores = S_cq[doc_cids]               # (len(doc), |q|)
    return doc_scores.max(axis=0).sum()       # scalar

def rank_documents(query, k_final=10, nprobe=2, t_cs=0.45, ndocs=256):
    """
    Run the PLAID pipeline up to *centroid-only* ranking for one query
    (no residual decompression, to keep the code short).

    query   : query string.
    k_final : number of (doc_id, score) pairs to return.
    nprobe, t_cs : forwarded to the candidate-generation step.
    ndocs   : how many of the best-scoring candidates are kept before the
              final sort.
    returns : list of (doc_id, score) tuples, best score first; an empty
              list when no candidate survives.
    """
    query_emb = encode_text(query)[0]                         # (|q|, d)
    S_cq = centroid_scores(query_emb)                         # (k, |q|)

    # Stage-1/2: candidate documents from the inverted index
    candidates = centroid_pruned_ids(S_cq, t_cs, nprobe)
    if len(candidates) == 0:
        return []

    # Stage-3: centroid interaction score for every candidate
    per_doc = [centroid_interaction_score(S_cq, doc_vecs[did])
               for did in candidates]
    scores = np.array(per_doc, dtype=np.float32)

    n_keep = min(ndocs, len(scores))
    if n_keep == 0:
        return []

    # positions of the n_keep largest scores (unordered), then a stable
    # descending sort over just those entries
    keep = np.argpartition(scores, -n_keep)[-n_keep:]
    ranked = sorted(
        zip(candidates[keep], scores[keep]),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:k_final]

In [9]:
# Smoke test: rank the demo corpus for the first sample query and print the
# five best hits with a 120-character preview and the score.
hits = rank_documents(query=queries[0], k_final=5)

print("\nQuery:", queries[0])
print("\nTop 5 passages")
for rank, (pid, score) in enumerate(hits, start=1):
    print(f"{rank:>2}. (doc#{pid}) {passages[pid][:120]} …  [{score:.3f}]")


Query: )what was the immediate impact of the success of the manhattan project?

Top 5 passages
 1. (doc#971) Louis Racine. Louis Racine (born November 6, 1692, Paris; died January 29, 1763, Paris) was a French poet of the Age of  …  [5.190]
 2. (doc#976) introDUCtion: tHe BeneFits oF an eFFeCtive Corporate internal investigation. Corporations are being scrutinized today as …  [5.106]
 3. (doc#26) My husband and I stayed at the Residence Inn in Shelton, CT for 3 months following a kitchen fire in our house. This cou …  [5.102]
 4. (doc#0) The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as sci …  [5.094]
 5. (doc#936) (CNN)Biggest Loser host and fitness trainer Bob Harper said he is thankful to be alive after suffering a mid-February he …  [4.956]


In [10]:
# -----------------------------
# 1. Load and limit MS MARCO-dev qrels
# -----------------------------
msmarco_dev = ir_datasets.load("msmarco-passage/dev")

qrels = defaultdict(set)
query_texts = {}
doc_texts = {}
count = 0
max_qrels = 10000

# Collect at most `max_qrels` positive relevance judgements and remember
# which query / passage ids they mention (texts are filled in below).
for qrel in msmarco_dev.qrels_iter():
    if qrel.relevance > 0:
        qid, did = qrel.query_id, qrel.doc_id
        qrels[qid].add(did)
        query_texts.setdefault(qid, None)  # placeholder
        doc_texts.setdefault(did, None)    # placeholder
        count += 1
        if count >= max_qrels:
            break

# Now fill in the actual query and document texts for the limited set.
# Both scans stop as soon as every requested id has been seen, instead of
# always walking the full collection.
pending_queries = set(query_texts)
for q in msmarco_dev.queries_iter():
    if not pending_queries:
        break
    if q.query_id in pending_queries:
        query_texts[q.query_id] = q.text
        pending_queries.discard(q.query_id)

pending_docs = set(doc_texts)
for d in msmarco_dev.docs_iter():
    if not pending_docs:
        break
    if d.doc_id in pending_docs:
        doc_texts[d.doc_id] = d.text
        pending_docs.discard(d.doc_id)

# crude reverse map text → doc_id
# NOTE(review): not read by any cell visible here — confirm before removing.
text2id = {v.strip(): k for k, v in doc_texts.items() if v is not None}

# -- NEW: create passages and id maps ---------------------------------------
# Only passages whose text was actually found are indexed; a leftover None
# placeholder cannot be tokenised.
ms_ids   = np.array(sorted(did for did, text in doc_texts.items()
                           if text is not None))      # external ids
passages = [doc_texts[did] for did in ms_ids]         # aligned texts
assert len(passages) == len(ms_ids)

# 1. OFFLINE INDEX -----------------------------------------------------------
# NOTE: this rebinds passages / doc_vecs / kmeans / centroids / inv_index,
# i.e. the search functions now operate on this evaluation corpus.
doc_vecs  = encode_text(passages)                     # list of (L_i, d)
token_mat = np.vstack(doc_vecs)

kmeans = MiniBatchKMeans(256, batch_size=2048, random_state=42).fit(token_mat)
centroids = kmeans.cluster_centers_.astype(np.float32)
centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)


# centroid_id → internal doc ids; doc_vecs entries become centroid-id arrays
inv_index = defaultdict(set)
for internal_id, v in enumerate(doc_vecs):
    if len(v) == 0:
        # passage without tokens: predict cannot be called on zero rows
        c_ids = np.empty(0, dtype=np.intp)
    else:
        c_ids = kmeans.predict(v)
    for cid in np.unique(c_ids):
        inv_index[cid].add(internal_id)
    doc_vecs[internal_id] = c_ids

inv_index = {cid: np.fromiter(docs, dtype=np.int32) for cid, docs in inv_index.items()}


In [11]:
# -----------------------------
# 2.  Scoring & metrics
# -----------------------------
# MRR@10 and Recall@100 over every query that has relevance judgements.
# The searched corpus is the one indexed above, i.e. only passages that
# appear in the collected qrels — keep that in mind when reading the numbers.
K_MRR, K_REC = 10, 100
tot_mrr = tot_rec = n_eval = 0

for qid, qtext in tqdm(query_texts.items(), total=len(query_texts)):
    # .get avoids inserting empty sets into the qrels defaultdict; the check
    # runs before retrieval so skipped queries cost nothing.
    rels = qrels.get(qid)
    if not rels or qtext is None:
        continue

    top = rank_documents(qtext, k_final=K_REC, nprobe=2, t_cs=0.45, ndocs=1024)
    our_ids = [ms_ids[pid] for pid, _ in top]         # internal → external id

    # MRR@10: reciprocal rank of the first relevant hit within the top K_MRR
    first_hit = next((i for i, did in enumerate(our_ids[:K_MRR]) if did in rels), None)
    if first_hit is not None:
        tot_mrr += 1.0 / (first_hit + 1)

    # Recall@100: fraction of the relevant passages that were retrieved
    tot_rec += sum(did in rels for did in our_ids) / len(rels)
    n_eval  += 1

print(f"\nEvaluated {n_eval} queries")
if n_eval:
    print(f"MRR@{K_MRR}   = {tot_mrr / n_eval:.4f}")
    print(f"Recall@{K_REC} = {tot_rec / n_eval:.4f}")
else:
    # avoid ZeroDivisionError when nothing could be evaluated
    print("No queries could be evaluated; metrics are undefined.")

100%|██████████| 9529/9529 [08:35<00:00, 18.50it/s]


Evaluated 9529 queries
MRR@10   = 0.2997
Recall@100 = 0.7854



