### Rerank endpoint 
The rerank endpoint takes a query and a list of documents and predicts a relevance score for each document with respect to the query.

It fits a two-stage retrieval setup. First, take the user's question and retrieve the top 100 documents from your collection using either lexical search or semantic search.

Then pass the question and these 100 documents to the rerank endpoint to get a relevance score for each document, and rank the documents by these scores.
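
The two stages described above can be sketched without any external service. This is a minimal illustration, not the endpoint's implementation: `keyword_overlap` is a toy first-stage scorer, and `jaccard` is a hypothetical stand-in for the rerank endpoint. Both names, the tiny corpus, and the helper `two_stage_search` are assumptions made for this example only.

```python
import string

def tokenize(text):
    """Lowercase, split on whitespace, and strip surrounding punctuation."""
    tokens = (tok.strip(string.punctuation) for tok in text.lower().split())
    return {tok for tok in tokens if tok}

def keyword_overlap(query, doc):
    """Toy first-stage scorer: number of distinct words shared with the query."""
    return len(tokenize(query) & tokenize(doc))

def jaccard(query, doc):
    """Toy second-stage scorer (hypothetical stand-in for the rerank endpoint)."""
    q, d = tokenize(query), tokenize(doc)
    union = q | d
    return len(q & d) / len(union) if union else 0.0

def two_stage_search(query, corpus, first_stage, second_stage,
                     num_candidates=100, top_k=3):
    # Stage 1: score the whole corpus with the cheap scorer, keep the best candidates.
    candidates = sorted(range(len(corpus)),
                        key=lambda i: first_stage(query, corpus[i]),
                        reverse=True)[:num_candidates]
    # Stage 2: score only the candidates with the more expensive scorer.
    reranked = sorted(((i, second_stage(query, corpus[i])) for i in candidates),
                      key=lambda pair: pair[1], reverse=True)
    return reranked[:top_k]

corpus = [
    "Washington, D.C. is the capital of the United States.",
    "The United States has fifty states.",
    "Paris is the capital of France.",
    "Bananas are yellow.",
]
top_hits = two_stage_search("capital of the United States", corpus,
                            keyword_overlap, jaccard,
                            num_candidates=3, top_k=2)
for corpus_id, score in top_hits:
    print(f"{score:.3f}\t{corpus[corpus_id]}")
```

The point of the split is cost: the first stage touches every document and must be cheap, while the second stage is only ever called on `num_candidates` documents.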

We will demonstrate the rerank endpoint in this notebook.



In [None]:
!pip install --upgrade setuptools==69.5.1 wheel --quiet
!pip install --upgrade cohere-aws

In [None]:
import requests
import numpy as np
from time import time
from typing import List
from pprint import pprint
from cohere_aws import Client
import boto3

In [None]:
# Set up your cohere client
co = Client(region_name='us-east-1')
co.connect_to_endpoint(endpoint_name="cohere-rerank-multilingual")

## Search on Wikipedia - End2End demo
The following example shows how to use this model end to end to search Simple English Wikipedia, which consists of about 500k passages.

We use BM25 lexical search to retrieve the top 100 passages matching the query, then send these 100 passages and the query to our rerank endpoint to get a re-ranked list. For comparison, we print the top 3 hits according to BM25 lexical search and the top 3 hits of the re-ranked list from our endpoint.


In [None]:
!pip install -U  rank_bm25

In [None]:
import json
import gzip
import os
from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
import string
from tqdm.autonotebook import tqdm

In [None]:
!wget http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz

In [None]:
# As the dataset, we use Simple English Wikipedia. Compared to the full English Wikipedia, it has only
# about 170k articles. Each article comes split into paragraphs; we collect these paragraphs as passages.
wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'

passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    for line in fIn:
        data = json.loads(line.strip())
        passages.extend(data['paragraphs'])

print("Passages:", len(passages))

In [None]:
print(f"--{passages[0]}\n--{passages[1]}")

In [None]:
# We compare the results to lexical search (keyword search). Here, we use 
# the BM25 algorithm which is implemented in the rank_bm25 package.

# We lower case our text and remove stop-words from indexing
def bm25_tokenizer(text):
    tokenized_doc = []
    for token in text.lower().split():
        token = token.strip(string.punctuation)

        if len(token) > 0 and token not in _stop_words.ENGLISH_STOP_WORDS:
            tokenized_doc.append(token)
    return tokenized_doc


tokenized_corpus = []
for passage in tqdm(passages):
    tokenized_corpus.append(bm25_tokenizer(passage))

# Create a BM25 index from the tokenized document corpus
bm25 = BM25Okapi(tokenized_corpus)
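
To give a feel for what a BM25 index computes, here is a minimal pure-Python sketch of a textbook Okapi BM25 score. It is an illustration under stated assumptions, not the `rank_bm25` implementation: the function name `bm25_scores_sketch`, the defaults `k1=1.5` and `b=0.75`, and the non-negative IDF variant used here are choices made for this example, and the package's exact IDF handling may differ.

```python
import math
from collections import Counter

def bm25_scores_sketch(query_tokens, tokenized_docs, k1=1.5, b=0.75):
    """Score every document against the query with a textbook BM25 formula."""
    n_docs = len(tokenized_docs)
    avg_len = sum(len(doc) for doc in tokenized_docs) / n_docs

    # Document frequency: in how many documents does each term occur?
    doc_freq = Counter()
    for doc in tokenized_docs:
        doc_freq.update(set(doc))

    scores = []
    for doc in tokenized_docs:
        term_freq = Counter(doc)
        score = 0.0
        for term in query_tokens:
            tf = term_freq.get(term, 0)
            if tf == 0:
                continue
            # Rare terms get a higher weight; the "1 +" keeps the IDF non-negative.
            idf = math.log(1 + (n_docs - doc_freq[term] + 0.5) / (doc_freq[term] + 0.5))
            # Term frequency saturates via k1; b controls document-length normalization.
            length_norm = 1 - b + b * len(doc) / avg_len
            score += idf * tf * (k1 + 1) / (tf + k1 * length_norm)
        scores.append(score)
    return scores

toy_docs = [
    ["washington", "capital", "united", "states"],
    ["paris", "capital", "france"],
    ["elon", "musk", "born", "1971"],
]
toy_scores = bm25_scores_sketch(["capital", "united", "states"], toy_docs)
print(toy_scores)
```

Documents that share no token with the query score zero, which is why lexical search alone can miss relevant passages that use different wording.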

In [None]:
# This function will search all wikipedia articles for passages that
# answer the query. We then re-rank using our rerank endpoint

def search(query, top_k=3, num_candidates=100):
    print("Input question:", query)

    ##### BM25 search (lexical search) #####
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
    top_n = np.argpartition(bm25_scores, -num_candidates)[-num_candidates:]
    bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
    bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
    
    print(f"\nTop-{top_k} lexical search (BM25) hits")
    print("-----------------------------------")

    for hit in bm25_hits[0:top_k]:
        print("\t{:.3f}\t{}".format(hit['score'],passages[hit['corpus_id']].replace("\n", " ")))

    # Re-rank the BM25 candidates; docs is ordered by BM25 score, best first
    docs = [passages[hit['corpus_id']] for hit in bm25_hits]
    
    print(f"\nTop-{top_k} hits by rank-API ({len(bm25_hits)} BM25 hits re-ranked)")
    print("-------------------------------------------------")

    results = co.rerank(query=query, documents=docs, top_n=top_k)
    
    for hit in results:
        # hit.index is the 0-based position in docs, so adding 1 gives the BM25 rank
        bm25_rank = hit.index + 1
        print("\t{:.3f} was({})\t{}".format(hit.relevance_score, bm25_rank, hit.document["text"].replace("\n", " ")))
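
The `was(...)` value printed for each re-ranked hit is the passage's rank under BM25, derived from the hit's position in the `docs` list. The sketch below illustrates that bookkeeping on made-up data. `RerankHitSketch` and `describe_rank_changes` are hypothetical names for this example; the class only mimics the two attributes the `search` function reads and is not the client's actual result type.

```python
from dataclasses import dataclass

@dataclass
class RerankHitSketch:
    """Hypothetical stand-in for one rerank result."""
    index: int              # 0-based position of the document in the list sent to rerank
    relevance_score: float  # score assigned by the reranker

def describe_rank_changes(hits):
    """Return (new_rank, old_rank, positions_gained) for each re-ranked hit."""
    rows = []
    for new_rank, hit in enumerate(hits, start=1):
        old_rank = hit.index + 1  # the list was BM25-sorted, so position + 1 is the BM25 rank
        rows.append((new_rank, old_rank, old_rank - new_rank))
    return rows

# Made-up results, already ordered by relevance score as a reranker would return them.
example_hits = [
    RerankHitSketch(index=7, relevance_score=0.98),
    RerankHitSketch(index=0, relevance_score=0.71),
    RerankHitSketch(index=42, relevance_score=0.33),
]
rank_rows = describe_rank_changes(example_hits)
for new_rank, old_rank, gained in rank_rows:
    print(f"rank {new_rank} (was {old_rank}, moved {gained:+d})")
```

A large positive move means the reranker promoted a passage that BM25 had buried, which is the case where the second stage adds the most value.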

In [None]:
search(query = "What is the capital of the United States?")

In [None]:
search(query = "Elon Musk year birth")

# Clean-up
Delete the endpoint

Now that you have successfully performed real-time inference, you no longer need the endpoint. Delete it to avoid being charged.


In [None]:
co.delete_endpoint()
co.close()