In [1]:
import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import gzip
import os
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs end to end
# on machines without CUDA (new tensors are created on this default device).
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# Bi-encoder used for the first-stage retrieval: it embeds both the passages and the query.
bi_encoder_name = "sentence-transformers/multi-qa-distilbert-cos-v1"
bi_encoder = SentenceTransformer(bi_encoder_name)

In [8]:
# Cap the bi-encoder's input sequence length (presumably counted in tokens —
# confirm against the sentence-transformers documentation).
bi_encoder.max_seq_length = 256

# Number of candidates requested from the bi-encoder when `search` is called below.
top_k = 16

In [9]:
# Cross-encoder used for the second stage: it scores (query, passage) pairs to re-rank candidates.
cross_encoder_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
cross_encoder = CrossEncoder(cross_encoder_name)

In [17]:
wikipedia_filepath = 'simplewiki-2020-11-01.jsonl.gz'

# Download the simplewiki dump only once; later runs reuse the local copy.
if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

# Every line of the gzipped file is one JSON record. Flatten the 'paragraphs'
# list of every record into a single list of passages.
# (To index only the first paragraph of each record, collect
# json.loads(line.strip())['paragraphs'][0] instead.)
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as fIn:
    passages = [
        paragraph
        for line in fIn
        for paragraph in json.loads(line.strip())['paragraphs']
    ]

print("Passages:", len(passages))

Passages: 509663


In [18]:
# Embed every passage with the bi-encoder. This is the expensive step of the
# notebook (the recorded run below took about 11.5 minutes), so keep it in its own cell.
corpus_embeddings = bi_encoder.encode(
    passages,
    show_progress_bar=True,
    convert_to_tensor=True,
)

Batches: 100%|██████████| 15927/15927 [11:38<00:00, 22.79it/s]


In [19]:
# Sanity check: per-passage embedding shape (everything after the corpus axis).
corpus_embeddings.shape[1:]

torch.Size([768])

In [20]:
def search(query, top_k, n_show=3):
    """Retrieve candidates for ``query`` with the bi-encoder, re-rank them with
    the cross-encoder, and print the best hits of both rankings.

    Parameters
    ----------
    query : str
        The question to search for.
    top_k : int
        Number of candidates requested from the bi-encoder retrieval stage.
        Every retrieved candidate is then re-scored by the cross-encoder.
    n_show : int, optional
        Number of hits printed for each of the two rankings. Defaults to 3,
        which reproduces the original output.

    Returns
    -------
    None
        The rankings are printed; nothing is returned, so calling this as the
        last expression of a notebook cell produces no extra ``Out[]`` value.
    """
    # Stage 1: semantic search with the bi-encoder over the pre-computed corpus embeddings.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # Follow the device of the corpus embeddings instead of hard-coding CUDA,
    # so the function also works when everything lives on the CPU.
    question_embedding = question_embedding.to(corpus_embeddings.device)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Only one query was passed, so take its hit list

    # Stage 2: re-score every retrieved candidate with the cross-encoder.
    # Skipped when nothing was retrieved, so the model is never called with an empty batch.
    if hits:
        cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
        cross_scores = cross_encoder.predict(cross_inp)
        for hit, cross_score in zip(hits, cross_scores):
            hit['cross-score'] = cross_score

    # Best hits according to the bi-encoder similarity score
    print("\n-------------------------\n")
    print("Top-{} Bi-Encoder Retrieval hits".format(n_show))
    hits = sorted(hits, key=lambda x: x['score'], reverse=True)
    for hit in hits[0:n_show]:
        print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))

    # Best hits according to the cross-encoder re-ranking score
    print("\n-------------------------\n")
    print("Top-{} Cross-Encoder Re-ranker hits".format(n_show))
    hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
    for hit in hits[0:n_show]:
        print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))

In [22]:
# Demo query: compare the bi-encoder ranking with the cross-encoder re-ranking.
search("What is the capital of the United States?", top_k)


-------------------------

Top-3 Bi-Encoder Retrieval hits
	0.751	United States Capitol.
	0.621	United States of America;
	0.605	The first capital city of the United States was New York City. At this time, Congress met in City Hall (Federal Hall) from 1785 to 1790. When the capital was moved to Philadelphia, Pennsylvania, from 1790 to 1800, the Philadelphia County Building (Congress Hall) became the capitol. In 1800, the capital moved again to Washington, D.C., and a new capitol building was built.

-------------------------

Top-3 Cross-Encoder Re-ranker hits
	8.906	Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.
	7.323	The first capital city of the United States was New York City. At this time, Congress met in C