# 🧠 DPR-Style Dense Retriever with FAISS (RAG-ready)

In [6]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
import faiss
import pandas as pd
import numpy as np
import os


In [7]:
import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer

class DPRRetriever(nn.Module):
    def __init__(self, model_name="bert-base-uncased", proj_dim=512):
        super(DPRRetriever, self).__init__()

        self.query_encoder = BertModel.from_pretrained(model_name)
        self.passage_encoder = BertModel.from_pretrained(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)

        # Freeze BERT parameters
        for param in self.query_encoder.parameters():
            param.requires_grad = False
        for param in self.passage_encoder.parameters():
            param.requires_grad = False

        # Add trainable projection layers: 768 → 512
        self.query_proj = nn.Linear(self.query_encoder.config.hidden_size, proj_dim)
        self.passage_proj = nn.Linear(self.passage_encoder.config.hidden_size, proj_dim)

    def encode_query(self, texts, device):
        with torch.no_grad():
            encoding = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)
            outputs = self.query_encoder(input_ids=input_ids, attention_mask=attention_mask)
            cls_token = outputs.last_hidden_state[:, 0]  # CLS token
        return self.query_proj(cls_token)  # Trainable layer

    def encode_passage(self, texts, device):
        with torch.no_grad():
            encoding = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)
            outputs = self.passage_encoder(input_ids=input_ids, attention_mask=attention_mask)
            cls_token = outputs.last_hidden_state[:, 0]
        return self.passage_proj(cls_token)  # Trainable layer

In [8]:

class QADataset(Dataset):
    def __init__(self, questions, contexts):
        self.questions = questions
        self.contexts = contexts

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        return {
            'question': self.questions[idx],
            'context': self.contexts[idx]
        }

def collate_batch(batch):
    return {
        "questions": [str(item["question"]) for item in batch],
        "contexts": [str(item["context"]) for item in batch]
    }

def multiple_negatives_ranking_loss(q_embeds, p_embeds):
    scores = torch.matmul(q_embeds, p_embeds.T)  # [B, B]
    labels = torch.arange(len(scores)).to(scores.device)
    return F.cross_entropy(scores, labels)


In [9]:

# Load dataset
ds = load_dataset("neural-bridge/rag-dataset-12000", split="train")
questions, contexts = ds['question'], ds['context']
train_dataset = QADataset(questions, contexts)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_batch)


In [10]:

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
model = DPRRetriever().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

for epoch in range(3):
    model.train()
    total_loss = 0

    for batch in train_loader:
        q_texts, p_texts = batch['questions'], batch['contexts']
        q_embed = model.encode_query(q_texts, device)
        p_embed = model.encode_passage(p_texts, device)

        loss = multiple_negatives_ranking_loss(q_embed, p_embed)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")


Epoch 1 | Loss: 2034.5306
Epoch 2 | Loss: 1317.6489
Epoch 3 | Loss: 1049.3585


In [15]:
torch.save(model.state_dict(), "dpr_model.pt")

In [11]:

# Build FAISS index for all passages
os.makedirs("dpr_faiss_store", exist_ok=True)
context_embeddings = []

model.eval()
with torch.no_grad():
    for i in range(0, len(contexts), 32):
        batch_texts = contexts[i:i+32]
        embs = model.encode_passage(batch_texts, device).cpu().numpy()
        context_embeddings.append(embs)

context_embeddings = np.vstack(context_embeddings)

# Save context mapping
pd.DataFrame({"context": contexts}).to_csv("dpr_faiss_store/context_mapping.csv", index=False)

# Build FAISS index
dimension = context_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(context_embeddings)
faiss.write_index(index, "dpr_faiss_store/context_index.faiss")
print("FAISS index saved.")


FAISS index saved.


In [13]:

# Load everything and run a query
model.eval()
index = faiss.read_index("dpr_faiss_store/context_index.faiss")
context_df = pd.read_csv("dpr_faiss_store/context_mapping.csv")

def retrieve_top_k(query, k=5):
    with torch.no_grad():
        query_vec = model.encode_query([query], device).cpu().numpy()
    distances, indices = index.search(query_vec, k)
    return [context_df.iloc[i]["context"] for i in indices[0]]

# Example usage
query = "What is the Berry Export Summary 2028 and what is its purpose?"
results = retrieve_top_k(query)
for i, r in enumerate(results, 1):
    print(f"[{i}] {r}\n")


[1] follows in the wake of its neighboring territories and decrees the perimeter closure to contain COVID-19
Cantabria is confined, thus following in the footsteps of its neighboring communities that have been doing the same throughout these days. In this way, as government sources have confirmed to elDiario.es, Health will decree the perimeter closure to try to contain the worrying advance of COVID-19 in recent weeks, in which there have been triggered daily infections and, as a consequence, also hospital occupancy, and does not rule out applying more restrictive measures in the field of mobility in the municipalities during the coming days.
Cantabria has finally decided to close its borders after the president, Miguel Ángel Revilla (PRC), has not contemplated this possibility, questioned by this matter in recent days, even going so far as to assess how “very bad news” for the community the closure of Asturias and Euskadi decreed on Monday.
(There will be enlargement)

[2] The The Sca