# End-to-End Passage Retrieval

In [None]:
import json 
import pandas as pd
import os
import torch
from torch import nn
from datetime import datetime
import time
from sentence_transformers import SentenceTransformer, CrossEncoder
from sentence_transformers import util as sentenceutils
import pickle
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Path of the JSON dataset file that the dataset-loading cell opens.
data_path = "/data/Eli5/Eli5_reranked/eli5_reranked.json"

In [None]:
# Read the dataset file and turn it into a DataFrame named `data`.
from io import StringIO  # local import: only this cell needs it

# JSON text is UTF-8 by specification, so do not rely on the platform default.
with open(data_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

if isinstance(data, str):
    # The file holds a JSON document serialised a second time as a JSON
    # string (presumably written via `json.dump(df.to_json(...))` -- confirm).
    # Passing a literal JSON string to `pd.read_json` is deprecated since
    # pandas 2.1, so it is wrapped in a file-like object.
    data = pd.read_json(StringIO(data), orient='records')
else:
    # The file holds plain JSON (e.g. a list of record dicts); `pd.read_json`
    # cannot consume an already-decoded Python object, so build the frame
    # directly from it.
    data = pd.DataFrame(data)

In [None]:
# Directories of the two models used by the pipeline.
bi_encoder_path = "/contextretrieval/bi-encoder/eli5/tuned_models/msmarco-distilbert-base-tas-b_eli5/"
cross_encoder_path = "/contextretrieval/cross-encoder/eli5/tuned_models/ms-marco-MiniLM-L-6-v2_eli5/"

# Bi-encoder: embeds the query for the first-stage semantic search.
bi_encoder = SentenceTransformer(bi_encoder_path)

# Cross-encoder: scores (query, passage) pairs for re-ranking. The Sigmoid
# activation maps its raw outputs into the (0, 1) range.
cross_encoder = CrossEncoder(
    cross_encoder_path,
    default_activation_function=nn.Sigmoid(),
)

In [None]:
# Build the list of passage strings that is searched and re-ranked.
#
# Step 1: collect one title per row. The title is read from the 'output'
# column when the dataset has one, otherwise from the 'passages' column.
# Rows are addressed by label 0..len(data)-1, exactly as the Series lookup
# `data[column][row]` resolves them.
if 'output' in data.columns:
    passage_titles = [
        data['output'][row][0]['provenance'][0]['title']
        for row in range(len(data))
    ]
else:
    passage_titles = [
        data['passages'][row][0]['title']
        for row in range(len(data))
    ]

# Step 2: choose the passage format from the bi-encoder path. A path that
# contains 'dpr' gets "<title> [SEP] <text>"; any other model gets the bare
# passage text.
if 'dpr' in bi_encoder_path:
    passages = [
        passage_titles[row] + ' [SEP] ' + data['passages_text'][row]
        for row in range(0, len(data))
    ]
else:
    passages = list(data['passages_text'])

## Semantic Search & Re-Ranker

In [None]:
# Load the pre-computed corpus embeddings from disk.
# NOTE(review): presumably these were produced by the same bi-encoder and are
# row-aligned with `passages` -- confirm against the script that wrote them.
# NOTE: unpickling can execute arbitrary code; only load trusted files.
embeddings_path = "/contextretrieval/bi-encoder/eli5/embeddings/msmarco-distilbert-base-tas-b_eli5.pickle"

with open(embeddings_path, mode='rb') as embeddings_file:
    corpus_embeddings = pickle.load(embeddings_file)

In [None]:
# Number of passages the semantic-search stage retrieves per query.
top_k = 10

In [None]:
def search_and_rank(query):
    """Retrieve passages for *query* and re-rank them with the cross-encoder.

    The query is embedded with the module-level ``bi_encoder`` and the
    module-level ``top_k`` highest dot-product matches in
    ``corpus_embeddings`` are fetched. Each retrieved passage is then scored
    against the query by ``cross_encoder`` and the hits are ordered by that
    score. Both stages are printed for inspection.

    Args:
        query: The question text to search for.

    Returns:
        The retrieved hits as a list of dicts, sorted by descending
        cross-encoder score. Each dict carries ``'corpus_id'`` and
        ``'score'`` (from the semantic search) plus ``'passage'`` (the
        passage text) and ``'cross-score'`` (the re-ranker score). The list
        is empty when the search yields no hits.
    """
    # ------ PASSAGE RETRIEVAL ------
    start_time = time.time()
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    hits = sentenceutils.semantic_search(
        question_embedding,
        corpus_embeddings,
        top_k=top_k,
        score_function=sentenceutils.dot_score,
    )
    hits = hits[0]  # Get the hits for the first (and only) query
    end_time = time.time()

    print("Input question:", query)
    print("\n-------------------------\n")
    # Report the number of hits actually retrieved instead of a fixed "10".
    print("Top {} passages (after {:.3f} seconds):".format(len(hits), end_time - start_time))

    for hit in hits:
        # Attach the text once so later steps do not repeat the corpus lookup.
        hit['passage'] = passages[hit['corpus_id']]
        print("\t{:.3f}\t{}".format(hit['score'], hit['passage']))

    # ------ RE-RANKER -----
    # Score every (query, passage) pair; skip the model call when nothing was
    # retrieved so an empty batch is never sent to the cross-encoder.
    if hits:
        cross_inp = [[query, hit['passage']] for hit in hits]
        cross_scores = cross_encoder.predict(cross_inp)
        for hit, cross_score in zip(hits, cross_scores):
            hit['cross-score'] = cross_score

    print("\n-------------------------\n")
    print("Top-3 Cross-Encoder Re-ranker hits")
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    for hit in hits[:3]:
        # Flatten newlines so each re-ranked passage prints on a single line.
        print("\t{:.3f}\t{}".format(hit['cross-score'], hit['passage'].replace("\n", " ")))

    # Hand the re-ranked hits back to the caller (previously nothing was
    # returned, so callers only ever received None).
    return hits

In [None]:
# Read a question from the user and run retrieval + re-ranking on it.
input_query = input('>>>')
# Keep the outcome under its own name: assigning it to `passages` would
# overwrite the corpus list that `search_and_rank` indexes, breaking every
# later call in the same session.
results = search_and_rank(input_query)
print(results)