# Retrieve & Re-Rank Demo over Simple Wikipedia

This example demonstrates the Retrieve & Re-Rank setup and lets you search over [Simple Wikipedia](https://simple.wikipedia.org/wiki/Main_Page).

You can input a query or a question. The script then uses semantic search
to find relevant passages in Simple English Wikipedia (as it is smaller and fits better in RAM).

For semantic search, we use `SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')` and retrieve
32 passages that potentially answer the input query.
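
Conceptually, the retrieval step ranks every passage embedding by its similarity to the query embedding and keeps the best `top_k`. The following is a minimal sketch of that idea in plain Python, assuming cosine similarity as the scoring function. The names `cosine`, `top_k_by_cosine`, `toy_corpus`, and `toy_query` are made up for illustration, and the 3-dimensional vectors are toy data, not real model output.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def top_k_by_cosine(query_vec, corpus_vecs, k):
    """Return the k most similar corpus entries as dicts with 'corpus_id' and 'score'."""
    scored = [
        {"corpus_id": i, "score": cosine(query_vec, vec)}
        for i, vec in enumerate(corpus_vecs)
    ]
    scored.sort(key=lambda hit: hit["score"], reverse=True)
    return scored[:k]

# Toy 3-dimensional "embeddings" (illustrative values only).
toy_corpus = [
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
toy_query = [1.0, 0.05, 0.0]

toy_hits = top_k_by_cosine(toy_query, toy_corpus, k=2)
print([hit["corpus_id"] for hit in toy_hits])
```

In the notebook itself this work is done by `util.semantic_search`, which operates on tensors and is far faster than this loop. The sketch only shows what "retrieve the top 32" means.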

Next, we use a more powerful CrossEncoder (`CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')`) that
scores each retrieved passage against the query for relevance. The cross-encoder further improves result quality,
especially when you search over a corpus on which the bi-encoder was not trained.
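
The re-ranking step can be sketched independently of any model: take the first-stage hits, score each (query, passage) pair with a second scorer, and sort by that new score. In the sketch below, `word_overlap` is a hypothetical stand-in for the cross-encoder so the example runs without downloading a model. It is not how a cross-encoder scores text, and `rerank`, `demo_passages`, and `demo_hits` are likewise illustrative names with made-up data.

```python
def rerank(query, passages, hits, score_fn, n):
    """Re-score every retrieved hit with score_fn and keep the n best."""
    rescored = []
    for hit in hits:
        passage = passages[hit["corpus_id"]]
        rescored.append({**hit, "cross_score": score_fn(query, passage)})
    rescored.sort(key=lambda hit: hit["cross_score"], reverse=True)
    return rescored[:n]

def word_overlap(query, passage):
    """Hypothetical stand-in scorer: number of distinct query words found in the passage."""
    passage_words = set(passage.lower().split())
    return sum(1 for word in set(query.lower().split()) if word in passage_words)

demo_passages = [
    "paris is a city in france",
    "the capital of france is paris",
    "france is a country in europe",
]
# Pretend first-stage results, ordered by a made-up bi-encoder score.
demo_hits = [
    {"corpus_id": 0, "score": 0.71},
    {"corpus_id": 2, "score": 0.65},
    {"corpus_id": 1, "score": 0.60},
]

demo_top = rerank("capital of france", demo_passages, demo_hits, word_overlap, n=2)
print([hit["corpus_id"] for hit in demo_top])
```

The passage that the first stage ranked last moves to the top once the second scorer looks at query and passage together, which is the effect the cross-encoder is meant to provide.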


In [16]:
!pip freeze | grep sentence-transformers 

sentence-transformers==2.1.0


In [3]:
!pip install -U sentence-transformers rank_bm25



In [4]:

import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import gzip
import os
import torch

In [8]:
gpu = True
if not torch.cuda.is_available():
  print("GPU not found")
  gpu = False

# The bi-encoder encodes all passages (and later the query) for semantic search
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256  # truncate long passages to 256 tokens
top_k = 32  # number of passages to retrieve with the bi-encoder

# The bi-encoder retrieves top_k passages; the cross-encoder re-ranks them
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Dataset: Simple English Wikipedia (downloaded on first run)
wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'

if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

passages = []

with gzip.open(wikipedia_filepath, 'rt', encoding='utf-8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())

        passages.append(data['paragraphs'][0])

print(f"{len(passages)}")

# Encode all passages into the vector space
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)


169597


Batches:   0%|          | 0/5300 [00:00<?, ?it/s]
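
The loading loop above reads a gzip-compressed JSON Lines file, one JSON object per line, and keeps only the first entry of each object's `paragraphs` list. The sketch below reproduces that pattern on an in-memory buffer so it can be tried without the download. The two records in `toy_records` are made up and contain only the `paragraphs` field that the loader uses.

```python
import gzip
import io
import json

# Two made-up records in the shape the loader reads: a 'paragraphs' list.
toy_records = [
    {"paragraphs": ["First paragraph of article A.", "Second paragraph of article A."]},
    {"paragraphs": ["First paragraph of article B."]},
]

# Write them as gzip-compressed JSON Lines into an in-memory buffer.
buffer = io.BytesIO()
with gzip.open(buffer, "wt", encoding="utf-8") as f_out:
    for record in toy_records:
        f_out.write(json.dumps(record) + "\n")

# Read them back the same way the notebook does, keeping the first paragraph only.
buffer.seek(0)
toy_passages = []
with gzip.open(buffer, "rt", encoding="utf-8") as f_in:
    for line in f_in:
        data = json.loads(line.strip())
        toy_passages.append(data["paragraphs"][0])

print(toy_passages)
```

Keeping only the first paragraph means each article contributes exactly one passage, so the printed passage count above equals the number of lines in the file.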

In [12]:

def search(query, n_ranked_hits, top_k, corpus_embeddings, passages, gpu):
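    """
    Retrieve passages with the bi-encoder, then re-rank them with the cross-encoder.

    :param query: query or question as a string
    :param n_ranked_hits: number of re-ranked hits to return
    :param top_k: number of passages to retrieve with the bi-encoder
    :param corpus_embeddings: tensor with the bi-encoder embeddings of all passages
    :param passages: list of passage texts, aligned with corpus_embeddings
    :param gpu: if True, move the query embedding to the GPU
    :return: the n_ranked_hits best hits, sorted by cross-encoder score
    """
    # Encode the query with the bi-encoder
    query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    if gpu:
        # .cuda() returns a new tensor, so the result must be reassigned
        query_embedding = query_embedding.cuda()
    hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # results for the first (and only) query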

    # Re-rank: score every (query, passage) pair with the cross-encoder
    cross_encoder_input = [(query, passages[hit['corpus_id']]) for hit in hits]
    cross_scores = cross_encoder.predict(cross_encoder_input)

    # Attach the cross-encoder score to each hit
    for idx in range(len(cross_scores)):
        hits[idx]['cross_score'] = cross_scores[idx]

    # Sort by cross-encoder score so the best hits come first
    hits = sorted(hits, key=lambda x: x['cross_score'], reverse=True)

    return hits[0:n_ranked_hits]


def format_hits(passages, hits):
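    """
    Print each hit as "<cross-encoder score> <passage text>".

    :param passages: list of passage texts
    :param hits: hits returned by search(), each with 'corpus_id' and 'cross_score'
    :return: None
    """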
    for hit in hits:
        print(f"{hit['cross_score']} {passages[hit['corpus_id']]}")


In [13]:
query = 'What is the capital of the United States'  # example query
n_ranked_hits = 3
reranked_hits = search(query, n_ranked_hits, top_k, corpus_embeddings, passages, gpu)

# Print the hits in a readable format
format_hits(passages, reranked_hits)


9.025672912597656 Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.
3.9615566730499268 The United States Capitol is the building where the United States Congress meets. It is the center of the legislative branch of the U.S. federal government. It is in Washington, D.C., on top of Capitol Hill at the east end of the National Mall.
3.7371184825897217 The continental United States is the area of the United States of America that is located in the continent of North America. It includes 49 of the 50 states (48 of which are located south of Canada and north of Mexico, known as the "lower 48 states", the other being Alaska) and the District of Columbia, which contains the federal capital, Washington, D.C. The only state wh