In [2]:
!pip install datasets transformers faiss-cpu sentence-transformers torch

Collecting datasets
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting transformers
  Downloading transformers-4.45.1-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting faiss-cpu
  Downloading faiss_cpu-1.8.0.post1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Collecting sentence-transformers
  Downloading sentence_transformers-3.1.1-py3-none-any.whl.metadata (10 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.17-py310-none-any.whl.metadata (7.2 kB)
Collecting tokenizers<0.21,>=0.20 (from transformers)
  Downloading tokenizers-0.20.0-cp310-cp310-many

In [1]:
from datasets import load_dataset

squad_dataset = load_dataset('squad')

corpus = {i: squad_dataset['train'][i]['context'] for i in range(len(squad_dataset['train']))}
queries = {i: squad_dataset['train'][i]['question'] for i in range(len(squad_dataset['train']))}

corpus_ids = list(corpus.keys())
corpus_texts = list(corpus.values())


In [2]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cuda


In [3]:
from sentence_transformers import SentenceTransformer, util
import faiss
import numpy as np

embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device=device)

corpus_embeddings = embedding_model.encode(corpus_texts, convert_to_tensor=True, show_progress_bar=True)
corpus_embeddings = corpus_embeddings.cpu()

faiss_index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
faiss_index.add(corpus_embeddings.numpy())

def retrieve_candidates(query, top_k=10, model=embedding_model):
    query_embedding = model.encode(query, convert_to_tensor=True, device=device)
    query_embedding = query_embedding.cpu().numpy().reshape(1, -1)
    scores, indices = faiss_index.search(query_embedding, top_k)
    return [(corpus_ids[idx], scores[0][i]) for i, idx in enumerate(indices[0])]


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]



1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/2738 [00:00<?, ?it/s]

In [4]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

cross_encoder_model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L-12-v2')
cross_encoder_model.to(device)
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-12-v2')

def rerank_passages(query, candidates, model=cross_encoder_model, tokenizer=tokenizer):
    scores = []
    for doc_id, _ in candidates:
        passage = corpus[doc_id]
        inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        scores.append((doc_id, outputs.logits.item()))

    return sorted(scores, key=lambda x: x[1], reverse=True)

config.json:   0%|          | 0.00/791 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/134M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/316 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [5]:
query_id = 0
query_text = queries[query_id]

candidates = retrieve_candidates(query_text, top_k=10)

reranked_results = rerank_passages(query_text, candidates)

print("Top 3 Results before Reranking:")
for doc_id, score in candidates[:3]:
    print(f"Document ID: {doc_id}, Score: {score}, Passage: {corpus[doc_id][:200]}")

print("\nTop 3 Results after Reranking:")
for doc_id, score in reranked_results[:3]:
    print(f"Document ID: {doc_id}, Reranked Score: {score}, Passage: {corpus[doc_id][:200]}")


Top 3 Results before Reranking:
Document ID: 4, Score: 0.6013752222061157, Passage: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper sta
Document ID: 3, Score: 0.6013752222061157, Passage: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper sta
Document ID: 2, Score: 0.6013752222061157, Passage: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper sta

Top 3 Results after Reranking:
Document ID: 4, Reranked Score: 6.644032955169678, Passage: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Vi