In [1]:
from transformers import DPRContextEncoder, DPRQuestionEncoder, DPRContextEncoderTokenizer, DPRQuestionEncoderTokenizer
import torch
import numpy as np

# --- DPR checkpoints --------------------------------------------------------
# DPR uses two separate encoders: one embeds passages ("contexts"), the other
# embeds questions. Each has its own checkpoint and tokenizer.
context_model_name = "facebook/dpr-ctx_encoder-multiset-base"
question_model_name = "facebook/dpr-question_encoder-multiset-base"

# Tokenizers first. Loading the context tokenizer from this checkpoint prints a
# "tokenizer class you load from this checkpoint is not the same type" warning
# (the checkpoint names DPRQuestionEncoderTokenizer). The transformers docs
# describe both DPR tokenizer classes as identical to BertTokenizer, so this is
# presumably harmless -- NOTE(review): confirm for the installed version.
context_tokenizer = DPRContextEncoderTokenizer.from_pretrained(context_model_name)
question_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
    question_model_name
)

# Then the encoder models themselves.
context_encoder = DPRContextEncoder.from_pretrained(context_model_name)
question_encoder = DPRQuestionEncoder.from_pretrained(question_model_name)

# --- Toy corpus and questions -----------------------------------------------
# The misspellings ("mamals", "mamel") are part of the sample data and are
# used consistently in the queries; editing them changes tokenization and
# therefore the similarity scores.
passages = [
    "mamals have 4 legs.",
    "Mamals give birth to young ones",
    "The bird is not a mamel",
    "Mamals have 2 eyes",
    "Mamals have 4 arms",
]

queries = [
    "Do mamals lay eggs or do they give birth to young ones?",
    "How many legs does a mamal have?",
]

# Passage-side embedding
def encode_passages(passages):
    """Embed a list of passage strings with the DPR context encoder.

    The texts are tokenized as one padded, truncated batch and run through
    ``context_encoder`` with gradients disabled.

    Returns:
        The encoder's ``pooler_output`` tensor (per the transformers DPR docs,
        one embedding row per input passage).
    """
    batch = context_tokenizer(
        passages,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embeddings = context_encoder(**batch).pooler_output
    return embeddings

def encode_queries(queries):
    """Embed a list of question strings with the DPR question encoder.

    The texts are tokenized as one padded, truncated batch and run through
    ``question_encoder`` with gradients disabled.

    Returns:
        The encoder's ``pooler_output`` tensor (per the transformers DPR docs,
        one embedding row per input question).
    """
    tokenized = question_tokenizer(
        queries, padding=True, truncation=True, return_tensors="pt"
    )
    with torch.no_grad():  # inference only
        result = question_encoder(**tokenized)
    return result.pooler_output

# Retrieve relevant passages for a query
def retrieve_passages(query_embedding, passage_embeddings, top_k=None):
    """Rank passages by cosine similarity to a query embedding.

    Args:
        query_embedding: 1-D array-like query vector.
        passage_embeddings: 2-D array-like with one passage vector per row
            (an empty sequence is allowed and yields an empty result).
        top_k: Optional non-negative int. If given, only the ``top_k``
            best-scoring passages are returned; ``None`` (the default)
            returns every passage.

    Returns:
        A list of ``(passage_index, score)`` tuples sorted by descending
        cosine similarity. Equal scores keep ascending passage-index order.
        A pair involving a zero-length vector scores 0.0 instead of NaN.
    """
    query = np.asarray(query_embedding)
    matrix = np.asarray(passage_embeddings)
    if len(matrix) == 0:
        return []

    # One matrix-vector product replaces the per-passage Python loop, and the
    # query norm is computed once instead of once per passage.
    dot_products = matrix @ query
    norm_products = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query)

    # 0/0 would produce NaN (with a RuntimeWarning), and NaN scores make the
    # sort order undefined, so zero-norm pairs are scored 0.0 instead.
    scores = np.zeros_like(norm_products)
    nonzero = norm_products != 0
    scores[nonzero] = dot_products[nonzero] / norm_products[nonzero]

    # sorted() is stable, so ties keep ascending passage-index order.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
    return ranked if top_k is None else ranked[:top_k]

# Embed the corpus and the questions once, up front.
passage_embeddings = encode_passages(passages)
query_embeddings = encode_queries(queries)

# Number of ranked passages to print for each query.
TOP_K = 2

# The passage matrix is identical for every query, so convert it to NumPy once
# rather than on every loop iteration.
passage_matrix = passage_embeddings.cpu().numpy()

# Example usage: retrieve passages for each query
for query_text, query_embedding in zip(queries, query_embeddings):
    top_passages = retrieve_passages(query_embedding.cpu().numpy(), passage_matrix)
    print(f"Top passages for query '{query_text}':")
    for passage_idx, score in top_passages[:TOP_K]:
        # passage_idx is 0-based; the printed passage number is 1-based.
        print(f"Passage {passage_idx + 1}: {passages[passage_idx]} (Score: {score:.4f})")
    print()


  from .autonotebook import tqdm as notebook_tqdm
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRContextEncoderTokenizer'.
Some weights of the model checkpoint at facebook/dpr-ctx_encoder-multiset-base were not used when initializing DPRContextEncoder: ['ctx_encoder.bert_model.pooler.dense.bias', 'ctx_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRContextEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRContextEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassificati

Top passages for query 'Do mamals lay eggs or do they give birth to young ones?':
Passage 2: Mamals give birth to young ones (Score: 0.6731)
Passage 1: mamals have 4 legs. (Score: 0.5483)

Top passages for query 'How many legs does a mamal have?':
Passage 1: mamals have 4 legs. (Score: 0.6985)
Passage 5: Mamals have 4 arms (Score: 0.6080)

