# Toy Examples: Dense Passage Retrieval (DPR)

In [None]:
# Install BEIR, which provides the models, exact dense search (DRES) and
# evaluation utilities imported further down. Adapted from BEIR's DPR example:
#https://github.com/beir-cellar/beir/blob/main/examples/retrieval/evaluation/dense/evaluate_dpr.py
!pip install beir



# Import Libraries

In [None]:
import numpy as np
import pandas as pd

# Load Toy Datasets

In [None]:
# Load the toy corpus, queries and relevance judgements. Every column is read
# as a string so identifiers are never parsed into numbers.
docs, queries, qrels = (
    pd.read_csv(path, dtype=str)
    for path in (
        './toy_data/docs.csv',
        './toy_data/queries.csv',
        './toy_data/qrels.csv',
    )
)
# Only the relevance label is used numerically downstream.
qrels = qrels.astype({'label': 'int32'})

# Sanity check: size and first rows of each frame (docs, queries, qrels).
for frame in (docs, queries, qrels):
    print(frame.shape)
    print(frame.head())

# Dense IR - Using Dense Passage Retrieval (DPR)

In [None]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

import logging
import pathlib, os

In [None]:
#IMPLEMENTED MODEL FROM https://github.com/beir-cellar/beir
#https://github.com/beir-cellar/beir/blob/main/beir/retrieval/models/dpr.py

from transformers import DPRContextEncoder, DPRContextEncoderTokenizerFast
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast
from typing import Union, List, Dict, Tuple
from tqdm.autonotebook import trange
import torch

class DPR:
    """DPR bi-encoder exposing the ``encode_queries`` / ``encode_corpus`` interface BEIR expects.

    Adapted from ``beir/retrieval/models/dpr.py``. Queries and passages are
    embedded by two separate Hugging Face DPR encoders. Each embedding is the
    encoder's ``pooler_output``.

    Args:
        model_path: ``(question_encoder_name, context_encoder_name)`` pair, each
            passed to ``from_pretrained``.
        device: Device for both encoders. Defaults to ``"cuda"`` when CUDA is
            available, otherwise ``"cpu"``.
        **kwargs: Accepted for interface compatibility and ignored.

    Raises:
        ValueError: if ``model_path`` is not a pair of model names.
    """

    def __init__(self, model_path: Union[str, Tuple] = None, *,
                 device: Union[str, torch.device, None] = None, **kwargs):
        # A single string would otherwise be indexed character by character.
        if model_path is None or isinstance(model_path, str) or len(model_path) < 2:
            raise ValueError(
                "model_path must be a (question_encoder, context_encoder) pair"
            )
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Query tokenizer and model (inference only, hence eval()).
        self.q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained(model_path[0])
        self.q_model = DPRQuestionEncoder.from_pretrained(model_path[0]).to(self.device)
        self.q_model.eval()

        # Context (passage) tokenizer and model.
        self.ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(model_path[1])
        self.ctx_model = DPRContextEncoder.from_pretrained(model_path[1]).to(self.device)
        self.ctx_model.eval()

    def _embed(self, tokenizer, model, texts: List[str]) -> torch.Tensor:
        """Tokenize one batch of *texts* and return the model's pooled embeddings."""
        # For a single sequence, truncation=True is the same strategy as 'longest_first'.
        encoded = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')
        output = model(
            encoded['input_ids'].to(self.device),
            attention_mask=encoded['attention_mask'].to(self.device),
        )
        return output.pooler_output

    def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> torch.Tensor:
        """Embed *queries* in batches of *batch_size*; one output row per query."""
        batches = []
        with torch.no_grad():
            for start in trange(0, len(queries), batch_size):
                batch = queries[start:start + batch_size]
                batches.append(self._embed(self.q_tokenizer, self.q_model, batch))
        return torch.cat(batches)

    def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int = 8, **kwargs) -> torch.Tensor:
        """Embed each document's ``'text'`` field in batches; one output row per document.

        Titles are not encoded. The upstream BEIR version also passed
        ``row['title']`` to the context tokenizer.
        """
        batches = []
        with torch.no_grad():
            for start in trange(0, len(corpus), batch_size):
                texts = [row['text'] for row in corpus[start:start + batch_size]]
                batches.append(self._embed(self.ctx_tokenizer, self.ctx_model, texts))
        return torch.cat(batches)

In [None]:
# BEIR corpus format: {docno: {'text': passage_text}}.
# zip() pairs the two columns positionally, so this works whatever index the
# DataFrame carries. Duplicate docnos keep the last row.
new_docs = {
    docno: {'text': text}
    for docno, text in zip(docs['docno'], docs['text'])
}

In [None]:
# BEIR query format: {qid: query_text}. Pairs the columns positionally, so the
# DataFrame's index does not matter. Duplicate qids keep the last row.
new_queries = dict(zip(queries['qid'], queries['query']))

In [None]:
# BEIR qrels format: {qid: {docno: relevance}}. A query can have several judged
# documents, so accumulate them under the same qid rather than replacing the
# inner dict on every row (which kept only the last judgement per query).
new_qrels = {}
for qid, docno, label in zip(qrels['qid'], qrels['docno'], qrels['label']):
    # int() turns the numpy int32 label into a plain Python int.
    new_qrels.setdefault(qid, {})[docno] = int(label)

In [None]:
# Custom DPR wrapper (defined above) behind BEIR's exact dense search.
# batch_size is a DenseRetrievalExactSearch argument: DRES forwards it to
# encode_queries/encode_corpus. Passed to DPR(...) instead, it would be
# swallowed by **kwargs and silently ignored.
model_dpr = DRES(
    DPR((
        "facebook/dpr-question_encoder-multiset-base",
        "facebook/dpr-ctx_encoder-multiset-base",
    )),
    batch_size=16,
)
retriever_dpr = EvaluateRetrieval(model_dpr, score_function="dot")  # or "cos_sim" for cosine similarity
results_dpr = retriever_dpr.retrieve(new_docs, new_queries)

In [None]:
# Baseline: TAS-B bi-encoder loaded through BEIR's SentenceBERT wrapper,
# scored by dot product ("cos_sim" would switch to cosine similarity).
model = DRES(
    models.SentenceBERT("msmarco-distilbert-base-tas-b"),
    batch_size=16,
)
retriever = EvaluateRetrieval(
    model,
    score_function="dot",
)
results = retriever.retrieve(new_docs, new_queries)

In [None]:
# MS MARCO v3 bi-encoder via sentence-transformers, scored by cosine similarity.
# Model list: https://www.sbert.net/docs/pretrained-models/msmarco-v3.html
# Other checkpoints tried here: 'msmarco-distilroberta-base-v3',
# 'msmarco-distilbert-base-tas-b'.
# NOTE(review): despite the *_ance names, 'msmarco-roberta-base-v3' does not
# appear to be an ANCE checkpoint (the ANCE entry on that page is named like
# 'msmarco-roberta-base-ance-firstp'). Confirm which model these "ANCE" results
# are meant to represent.
model_ance = DRES(
    models.SentenceBERT('msmarco-roberta-base-v3')
)
retriever_ance = EvaluateRetrieval(
    model_ance,
    score_function="cos_sim",
)

#### Retrieve dense results (format of results is identical to qrels)
results_ance = retriever_ance.retrieve(
    new_docs,
    new_queries,
)

In [None]:
# DPR encoders served through BEIR's SentenceBERT wrapper.
# The wrapper only reads model_path[0] (query) and model_path[1] (document),
# so the title/text separator must go in `sep=` rather than a third tuple
# element. batch_size belongs to DRES, which forwards it to the encoders.
model_dpr_alt = DRES(
    models.SentenceBERT(
        (
            "facebook-dpr-question_encoder-multiset-base",
            "facebook-dpr-ctx_encoder-multiset-base",
        ),
        sep=" [SEP] ",
    ),
    batch_size=128,
)
retriever_dpr_alt = EvaluateRetrieval(model_dpr_alt, score_function="dot")
results_dpr_alt = retriever_dpr_alt.retrieve(new_docs, new_queries)

In [None]:
#### Evaluate every retriever with NDCG@k, MAP@k, Recall@k and Precision@k.
# One shared list of cutoffs keeps the four result sets directly comparable;
# change it here to evaluate at other k (e.g. [1, 3, 5, 10, 100, 1000]).
K_VALUES = [10, 100, 1000]

ndcg, _map, recall, precision = retriever_dpr.evaluate(new_qrels, results_dpr, K_VALUES)
ndcg_alt, _map_alt, recall_alt, precision_alt = retriever.evaluate(new_qrels, results, K_VALUES)
ndcg_ance, _map_ance, recall_ance, precision_ance = retriever_ance.evaluate(new_qrels, results_ance, K_VALUES)
ndcg_dpr_alt, _map_dpr_alt, recall_dpr_alt, precision_dpr_alt = retriever_dpr_alt.evaluate(new_qrels, results_dpr_alt, K_VALUES)


In [None]:
# NDCG@k per retriever (labels use a consistent "Name:" format).
print("Original DPR:", ndcg)
print("Original Sentence BERT:", ndcg_alt)
print("Original ANCE:", ndcg_ance)
print("Alternative DPR:", ndcg_dpr_alt)

In [None]:
# MAP@k per retriever (labels use a consistent "Name:" format).
print("Original DPR:", _map)
print("Original Sentence BERT:", _map_alt)
print("Original ANCE:", _map_ance)
print("Alternative DPR:", _map_dpr_alt)

In [None]:
# Recall@k per retriever (labels use a consistent "Name:" format).
print("Original DPR:", recall)
print("Original Sentence BERT:", recall_alt)
print("Original ANCE:", recall_ance)
print("Alternative DPR:", recall_dpr_alt)

In [None]:
# Precision@k per retriever (labels use a consistent "Name:" format).
print("Original DPR:", precision)
print("Original Sentence BERT:", precision_alt)
print("Original ANCE:", precision_ance)
print("Alternative DPR:", precision_dpr_alt)