In [1]:
import json
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import gzip
import os
import torch
import pickle
import pandas as pd

In [2]:
def load_artifacts(bi_enc='multi-qa-MiniLM-L6-cos-v1',
                   cross_enc='cross-encoder/ms-marco-MiniLM-L-6-v2',
                   corpus='corpus.pkl',
                   corpus_emb='corpus_embeddings.pkl'):
    """Load the two retrieval models and the pre-computed corpus artifacts.

    Parameters
    ----------
    bi_enc : str
        Model name/path handed to ``SentenceTransformer`` (used to embed queries).
    cross_enc : str
        Model name/path handed to ``CrossEncoder`` (used to re-rank hits).
    corpus : str
        Path of the pickled passage collection.
    corpus_emb : str
        Path of the pickled passage embeddings.

    Returns
    -------
    tuple
        ``(bi_encoder, cross_encoder, corpus, corpus_embeddings)``, where the
        last two are the unpickled objects rather than the file paths.
    """
    # Models are instantiated before any file is read, matching the
    # original load order.
    encoders = (SentenceTransformer(bi_enc), CrossEncoder(cross_enc))
    # NOTE(review): read_pickle can execute arbitrary code while unpickling;
    # only point these paths at files you produced yourself.
    embeddings = pd.read_pickle(corpus_emb)
    passages = pd.read_pickle(corpus)
    return (*encoders, passages, embeddings)

In [14]:
def search_helper(query, bi_encoder, cross_encoder, corpus, corpus_embeddings, top_k=100, top_n=5):
    """Retrieve candidate passages with the bi-encoder, then re-rank them.

    Parameters
    ----------
    query : str
        Free-text search query.
    bi_encoder :
        Object exposing ``encode(query, convert_to_tensor=True)``.
    cross_encoder :
        Object exposing ``predict(list_of_[query, passage]_pairs)``.
    corpus :
        Passages indexable by the ``corpus_id`` values that
        ``util.semantic_search`` returns.
    corpus_embeddings :
        Passage embeddings accepted by ``util.semantic_search``
        (presumably a torch tensor -- TODO confirm against the pickle).
    top_k : int
        Number of candidates retrieved by the bi-encoder before re-ranking.
    top_n : int
        Number of re-ranked passages to return (default 5).

    Returns
    -------
    list
        Up to ``top_n`` passages from ``corpus``, best cross-encoder score first.
        An empty list when retrieval yields no candidates.
    """
    ##### Semantic Search #####
    # Encode the query using the bi-encoder and find potentially relevant passages
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    # Place the query on the same device as the corpus embeddings instead of
    # hard-coding .cuda(), which fails on machines without a GPU.
    device = getattr(corpus_embeddings, "device", None)
    if device is not None:
        question_embedding = question_embedding.to(device)
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)[0]  # single query
    if not hits:
        # Nothing to re-rank; avoid calling the cross-encoder with no pairs.
        return []

    ##### Re-Ranking #####
    # Score every retrieved passage against the query with the cross-encoder
    cross_inp = [[query, corpus[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)
    for hit, score in zip(hits, cross_scores):
        hit['cross-score'] = score

    # Sort results by the cross-encoder scores, best first
    ranked = sorted(hits, key=lambda hit: hit['cross-score'], reverse=True)
    return [corpus[hit['corpus_id']] for hit in ranked[:top_n]]

In [15]:
def search(query, top_k=100):
    """Run the retrieve-and-re-rank pipeline using this notebook's loaded artifacts.

    Thin adapter for the Gradio interface, which supplies only the query text.
    It forwards the module-level ``bi_encoder``, ``cross_encoder``, ``corpus`` and
    ``corpus_embeddings`` to ``search_helper``, so ``load_artifacts()`` must have
    been run first.

    Parameters
    ----------
    query : str
        Free-text search query.
    top_k : int
        Number of bi-encoder candidates to re-rank (default 100).

    Returns
    -------
    Whatever ``search_helper`` returns: the top re-ranked passages.
    """
    return search_helper(query, bi_encoder, cross_encoder, corpus, corpus_embeddings, top_k=top_k)

In [16]:
# Load the models and the pickled corpus/embeddings once; `search` reads these
# four names as globals.
bi_encoder, cross_encoder, corpus, corpus_embeddings = load_artifacts()

In [17]:
# Quick smoke test of the full retrieve-and-re-rank pipeline.
search(query="where is france")

['France ( or ; ), officially the French Republic (, ), is a country whose metropolitan territory is located in Western Europe and that also comprises various overseas islands and territories located in other continents. Metropolitan France extends from the Mediterranean Sea to the English Channel and the North Sea, and from the Rhine to the Atlantic Ocean. It is often referred to as "L’Hexagone" ("The Hexagon") because of the shape of its territory. France is a unitary semi-presidential republic with its main ideals expressed in the Declaration of the Rights of Man and of the Citizen.',
 'The capital of France is Paris. In the course of history, the national capital has been in many locations other than Paris.',
 'Metropolitan France ( or "la Métropole") is the part of France that is in Europe. It can also be described as mainland France plus the island of Corsica. By contrast, Overseas France ("France d\'outre-mer") is the collective name for all of the French overseas departments, t

In [18]:
import gradio as gr

In [19]:
# Keep a reference to the Interface object itself (not the value returned by
# `.launch()`), so the running server can later be stopped with `iface.close()`.
# NOTE: share=True publishes a temporary public URL -- anyone with the link can
# query the model.
iface = gr.Interface(fn=search, inputs=["text"], outputs="textbox")
iface.launch(share=True);

Running on local URL:  http://127.0.0.1:7862/
Running on public URL: https://32792.gradio.app

This share link will expire in 72 hours. To get longer links, send an email to: support@gradio.app
