In [None]:
!pip install beir

"""
STAGE 1 — DeepImpact-style Retriever
"""

import os
import json
import pickle
import random
from tqdm import tqdm
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, AutoModel
from beir import util
from beir.datasets.data_loader import GenericDataLoader

# Utility
def set_seed(seed=42):
    """Seed Python's ``random`` module and PyTorch (CPU and every CUDA device).

    Args:
        seed: integer seed applied to each generator (default 42).
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

# Dataset: Triples
@dataclass
class TripleExample:
    """One training triple: a query with a positive and a negative document."""

    qid: str  # query id
    pos_id: str  # id of the positive (relevant) document
    neg_id: str  # id of the negative document
    q: str  # query text
    pos: str  # positive document text
    neg: str  # negative document text


class TripleDataset(Dataset):
    """Map-style dataset of ``TripleExample`` records loaded from a JSONL file.

    Each non-blank line must be a JSON object with the keys ``"qid"``,
    ``"pos_id"``, ``"neg_id"``, ``"query"``, ``"pos"`` and ``"neg"``.
    """

    def __init__(self, path):
        """Read every triple from ``path`` eagerly into memory.

        Args:
            path: path to a UTF-8 encoded JSONL file.
        """
        self.items = []
        # Explicit UTF-8 so decoding does not depend on the platform's default locale.
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. an extra trailing newline) instead of crashing json.loads.
                    continue
                o = json.loads(line)
                self.items.append(
                    TripleExample(
                        o["qid"], o["pos_id"], o["neg_id"],
                        o["query"], o["pos"], o["neg"]
                    )
                )

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        return self.items[i]


def triple_collate_fn(batch):
    """Turn a list of triple records into column-wise lists of texts.

    Args:
        batch: sequence of objects exposing ``q``, ``pos`` and ``neg`` attributes.

    Returns:
        dict with keys ``"q"``, ``"pos"``, ``"neg"``; each value is a list
        aligned with the order of ``batch``.
    """
    columns = ("q", "pos", "neg")
    return {name: [getattr(example, name) for example in batch] for name in columns}

# DeepImpact Model
class ImpactEncoder(nn.Module):
    """Per-token impact model: a pretrained encoder followed by a small MLP head.

    Each position's last hidden state is mapped to one score and passed
    through ReLU, so every impact is non-negative.

    Args:
        model_name: model id passed to ``AutoModel.from_pretrained``.
        freeze: when True, the encoder's parameters get ``requires_grad=False``
            so only the MLP head is trainable.
    """

    def __init__(self, model_name="bert-base-uncased", freeze=True):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        if freeze:
            for p in self.bert.parameters():
                p.requires_grad = False

        # Size the head from the encoder's config instead of assuming 768,
        # so encoders with a different hidden size also work.
        hidden_size = self.bert.config.hidden_size
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, ids, mask):
        """Return a non-negative impact for every input position.

        Args:
            ids: token ids; presumably (batch, seq_len) — TODO confirm for callers.
            mask: attention mask aligned with ``ids``.

        Returns:
            The MLP output with its trailing size-1 dim squeezed, after ReLU
            (i.e. (batch, seq_len) for 2-D ``ids``).
        """
        # Keyword arguments avoid relying on the model's positional parameter order.
        h = self.bert(input_ids=ids, attention_mask=mask).last_hidden_state
        scores = self.mlp(h).squeeze(-1)
        return torch.relu(scores)      # non-negative token impacts


class DeepImpactScorer:
    """Tokenizer + ImpactEncoder on the available device, with posting-list scoring.

    Args:
        freeze: forwarded to ``ImpactEncoder``; freezes the encoder when True.
        model_name: model id used for both the tokenizer and the encoder
            (default ``"bert-base-uncased"``, as before).
    """

    def __init__(self, freeze=True, model_name="bert-base-uncased"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ImpactEncoder(model_name=model_name, freeze=freeze).to(self.device)

    def encode_pair(self, q, d, max_length=256):
        """Tokenize ``(q, d)`` as a text pair, padded/truncated to ``max_length``.

        Returns:
            ``(input_ids, attention_mask)`` tensors moved to ``self.device``.
        """
        enc = self.tokenizer(q, d, truncation=True, max_length=max_length,
                             padding="max_length", return_tensors="pt")
        return enc["input_ids"].to(self.device), enc["attention_mask"].to(self.device)

    def score_query_with_posting(self, query, posting, top_k=100):
        """Rank documents by the summed impacts of the query's tokens.

        Args:
            query: query text.
            posting: mapping ``doc_id -> {token_id: impact}``.
            top_k: maximum number of results to return.

        Returns:
            List of ``(doc_id, score)`` sorted by descending score (ties keep
            the posting's iteration order), truncated to ``top_k``. A query
            token repeated n times contributes its impact n times.
        """
        # Special tokens ([CLS]/[SEP]) carry no query content; if they were
        # included they would match the same special-token entries in every
        # document and add a spurious score to all of them.
        q_ids = self.tokenizer(query, add_special_tokens=False)["input_ids"]
        results = [
            (doc_id, sum(pl[t] for t in q_ids if t in pl))
            for doc_id, pl in posting.items()
        ]
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]

# Training Loop (simple softmax ranking)
def _impact_match_scores(model, tokenizer, queries, docs, device, max_length=256):
    """Score each (query, doc) pair by summing the doc's token impacts on query tokens.

    Documents are encoded on their own and passed through ``model`` to get
    per-position impacts. A document position contributes its impact once for
    every occurrence of the same token id among the query's (non-special,
    non-padding) tokens; special tokens and padding in the document contribute
    nothing. This is the token-overlap sum used to score a query against a
    posting list, so the loss trains the impact head that the index uses.

    Returns:
        Tensor with one differentiable score per pair.
    """
    q_enc = tokenizer(queries, add_special_tokens=False, truncation=True,
                      max_length=max_length, padding=True, return_tensors="pt").to(device)
    d_enc = tokenizer(docs, truncation=True, max_length=max_length,
                      padding=True, return_tensors="pt").to(device)
    d_ids, d_mask = d_enc["input_ids"], d_enc["attention_mask"]
    impacts = model(d_ids, d_mask)  # one impact per document position

    special = torch.tensor(sorted(set(tokenizer.all_special_ids)), device=device)
    d_valid = d_mask.bool() & ~(d_ids.unsqueeze(-1) == special).any(dim=-1)

    q_ids, q_valid = q_enc["input_ids"], q_enc["attention_mask"].bool()
    # matches[b, p, j]: document position p holds the same id as non-pad query position j.
    matches = (d_ids.unsqueeze(2) == q_ids.unsqueeze(1)) & q_valid.unsqueeze(1)
    weight = matches.sum(dim=-1) * d_valid
    return (impacts * weight).sum(dim=-1)


def train_stage1(triple_file, freeze_bert=True, save_path="impact.pt",
                 epochs=2, batch_size=8, lr=1e-3, max_length=256):
    """Train the ImpactEncoder on (query, positive, negative) triples and save it.

    Each triple is scored with ``_impact_match_scores`` for the positive and
    the negative document; a 2-way softmax cross-entropy pushes the positive
    above the negative. Gradients therefore reach the impact MLP (and BERT
    when it is not frozen), which is what gets saved.

    Args:
        triple_file: JSONL file readable by ``TripleDataset``.
        freeze_bert: freeze the encoder and train only the MLP head.
        save_path: where ``model.state_dict()`` is written.
        epochs, batch_size, lr, max_length: training hyper-parameters
            (defaults match the previous hard-coded values).
            NOTE(review): 1e-3 suits head-only training; full BERT fine-tuning
            usually needs a much smaller learning rate.

    Returns:
        The ``DeepImpactScorer`` wrapping the trained model.

    Raises:
        ValueError: if ``triple_file`` contains no triples.
    """
    scorer = DeepImpactScorer(freeze=freeze_bert)
    model = scorer.model
    tokenizer = scorer.tokenizer
    device = scorer.device  # same device the model was moved to

    ds = TripleDataset(triple_file)
    if len(ds) == 0:
        raise ValueError(f"No training triples found in {triple_file}")
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, collate_fn=triple_collate_fn)

    # Optimise every trainable encoder parameter: the MLP head always,
    # plus BERT itself when it is not frozen.
    opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    model.train()
    for epoch in range(epochs):
        total = 0.0
        for ex_batch in tqdm(dl, desc=f"Epoch {epoch+1}"):
            q = ex_batch["q"]
            s_pos = _impact_match_scores(model, tokenizer, q, ex_batch["pos"], device, max_length)
            s_neg = _impact_match_scores(model, tokenizer, q, ex_batch["neg"], device, max_length)

            # Column 0 is the positive document, so the target class is always 0.
            logits = torch.stack([s_pos, s_neg], dim=1)
            labels = torch.zeros(logits.size(0), dtype=torch.long, device=device)
            loss = loss_fn(logits, labels)

            opt.zero_grad()
            loss.backward()
            opt.step()

            total += loss.item()

        print("Avg loss:", total / len(dl))

    torch.save(model.state_dict(), save_path)
    print("Model saved at", save_path)
    return scorer

# Posting List Builder
def build_posting_list(scorer, corpus, query_example):
    """Build a per-document map of token id -> summed impact.

    Each document's ``"text"`` is encoded together with ``query_example`` via
    ``scorer.encode_pair`` and scored by ``scorer.model`` under ``no_grad``.
    Only tokens of the document segment are recorded. The ``query_example``
    tokens, special tokens ([CLS]/[SEP]/...) and padding are skipped, so they
    do not leak into every document's entry. Impacts of a repeated token are
    summed, and non-positive impacts are dropped.

    Args:
        scorer: object exposing ``model``, ``tokenizer`` and ``encode_pair``.
        corpus: mapping ``doc_id -> dict`` with a ``"text"`` field.
        query_example: text placed in the first segment of every pair.

    Returns:
        dict ``doc_id -> {token_id: impact}``.
    """
    posting = {}
    scorer.model.eval()
    sep_id = scorer.tokenizer.sep_token_id
    special_ids = set(scorer.tokenizer.all_special_ids)
    for did, doc_dict in tqdm(corpus.items(), desc="Posting list"):
        text = doc_dict["text"]  # Extract the actual text from the dictionary
        ids, mask = scorer.encode_pair(query_example, text)
        with torch.no_grad():
            impacts = scorer.model(ids, mask)[0].tolist()

        tids = ids[0].tolist()
        attended = mask[0].tolist()
        # For a BERT-style pair ([CLS] query [SEP] doc [SEP] [PAD]...),
        # the document segment starts right after the first [SEP].
        doc_start = tids.index(sep_id) + 1 if sep_id in tids else 0

        doc_pl = {}
        for t, imp, on in zip(tids[doc_start:], impacts[doc_start:], attended[doc_start:]):
            if on and t not in special_ids and imp > 0:
                doc_pl[t] = doc_pl.get(t, 0.0) + float(imp)
        posting[did] = doc_pl
    return posting


# BEIR dataset loader (downloads and extracts the dataset, then loads one split)
def load_beir_dataset(name, out_dir="./datasets", split="test"):
    """Download a BEIR dataset archive, extract it and load one split.

    Args:
        name: BEIR dataset name (e.g. ``"quora"``, ``"trec-covid"``); it is
            used to build the download URL.
        out_dir: directory the archive is downloaded and extracted into.
        split: qrels split to load (Stage 1 uses the test queries by default).

    Returns:
        ``(corpus, queries, qrels)`` as returned by ``GenericDataLoader.load``.
    """
    print(f"Downloading BEIR dataset: {name}")
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{name}.zip"
    data_path = util.download_and_unzip(url, out_dir)

    corpus, queries, qrels = GenericDataLoader(data_path).load(split=split)
    return corpus, queries, qrels

# MAIN
def main():
    """Run Stage 1 end to end.

    Steps: load Quora and TREC-COVID, write Quora training triples, train the
    impact model, build a posting list over the merged corpus and pickle it
    to ``posting.pkl``.
    """
    set_seed(42)

    # 1. Load BEIR datasets
    corpus_quora, queries_quora, qrels_quora = load_beir_dataset("quora")
    corpus_trec, queries_trec, qrels_trec = load_beir_dataset("trec-covid")

    # Merge the two corpora. Ids are used as-is, so a shared id would let the
    # trec-covid document silently replace the Quora one; report it instead.
    overlap = corpus_quora.keys() & corpus_trec.keys()
    if overlap:
        print(f"Warning: {len(overlap)} doc ids occur in both corpora; trec-covid entries take precedence")
    corpus = {**corpus_quora, **corpus_trec}

    # 2. Build simple triples for training (small subset for demo)
    triple_file = "train_triples.jsonl"
    _write_training_triples(triple_file, queries_quora, qrels_quora, corpus_quora)
    print("Triples saved:", triple_file)

    # 3. Train Stage 1 model
    scorer = train_stage1(triple_file, freeze_bert=True, save_path="impact.pt")

    # 4. Build posting list
    any_query = list(queries_quora.values())[0]
    posting = build_posting_list(scorer, corpus, any_query)

    with open("posting.pkl", "wb") as f:
        pickle.dump(posting, f)

    print("Posting list saved.")

    print("Stage 1 complete ✔")


def _write_training_triples(path, queries, qrels, corpus, max_queries=2000, max_neg_tries=10):
    """Write one JSONL (query, positive, random negative) triple per usable query; return the count written."""
    # Materialise the id list once; rebuilding it per query costs O(corpus) each time.
    doc_ids = list(corpus)
    written = 0
    with open(path, "w", encoding="utf-8") as f:
        for qid, qtext in list(queries.items())[:max_queries]:
            judged = qrels.get(qid, {})
            # Only judgements with a positive score count as relevant.
            positives = [did for did, rel in judged.items() if rel > 0]
            if not positives:
                continue
            pos_doc = positives[0]
            pos_text = corpus[pos_doc]["text"]

            # Pick a random negative, re-drawing if it is judged relevant to
            # this query (that would be a false negative).
            neg_doc = None
            for _ in range(max_neg_tries):
                candidate = random.choice(doc_ids)
                if judged.get(candidate, 0) <= 0:
                    neg_doc = candidate
                    break
            if neg_doc is None:
                continue
            neg_text = corpus[neg_doc]["text"]

            f.write(json.dumps({
                "qid": qid,
                "pos_id": pos_doc,
                "neg_id": neg_doc,
                "query": qtext,
                "pos": pos_text,
                "neg": neg_text
            }) + "\n")
            written += 1
    return written


if __name__ == "__main__":
    main()

Downloading BEIR dataset: quora


  0%|          | 0/522931 [00:00<?, ?it/s]

Downloading BEIR dataset: trec-covid


  0%|          | 0/171332 [00:00<?, ?it/s]

Triples saved: train_triples.jsonl


Epoch 1: 100%|██████████| 250/250 [00:11<00:00, 22.06it/s]


Avg loss: 0.6307371041676798


Epoch 2: 100%|██████████| 250/250 [00:11<00:00, 21.16it/s]


Avg loss: 0.7035781812922051
Model saved at impact.pt


Posting list: 100%|██████████| 694263/694263 [3:20:06<00:00, 57.83it/s]


Posting list saved.
Stage 1 complete ✔
