In [1]:
import warnings
# Silence every warning for the rest of the kernel session. No category or
# module filter is given, so this also hides deprecation notices raised by the
# libraries imported in later cells.
warnings.filterwarnings('ignore')

In [2]:
%%writefile requirments.txt
gradio
qdrant-client
sentence-transformers
tqdm

Overwriting requirments.txt


In [3]:
!echo "Requirments Are:" && cat requirments.txt && echo "-----"

!pip install -r requirments.txt

Requirments Are:
gradio
qdrant-client
sentence-transformers
tqdm
-----


In [4]:
from download_dataset import get_dataset

# File name handed to get_dataset (presumably a gzipped JSON-lines dump of
# Simple English Wikipedia, judging by the name - confirm in download_dataset).
SIMPLE_WIKI_PATH = 'simplewiki-2020-11-01.jsonl.gz'

# get_dataset returns a mapping; the rest of the notebook only needs these two entries.
dataset = get_dataset(SIMPLE_WIKI_PATH)
passages, articles = dataset['passages'], dataset['articles']

100%|██████████| 50.2M/50.2M [00:01<00:00, 33.4MB/s]


In [5]:
from sentence_transformers import SentenceTransformer, CrossEncoder
from pprint import pprint


# Bi-encoder used to embed passages and queries for retrieval.
encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
pprint(dict(
    max_seq_length=encoder.get_max_seq_length(),
    sentence_embedding_dimension=encoder.get_sentence_embedding_dimension(),
    tokenizer=encoder.tokenizer,
))

# Cross-encoder used to rerank the retrieved passages against the query.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
pprint(dict(
    max_length=cross_encoder.max_length,
    tokenizer=cross_encoder.tokenizer,
))

{'max_seq_length': 512,
 'sentence_embedding_dimension': 384,
 'tokenizer': BertTokenizerFast(name_or_path='sentence-transformers/multi-qa-MiniLM-L6-cos-v1', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}}
{'max_length': None,

In [6]:
import os

from qdrant_client import QdrantClient, models
from tqdm.autonotebook import tqdm

# In-memory instance of the vector DB: it starts out with no collections every
# time the kernel is restarted.
qdrant = QdrantClient(':memory:')
# To use a hosted, pre-built index instead, swap in:
# qdrant = QdrantClient(
#     url=os.environ['QDRANT_URL'],
#     api_key=os.environ['QDRANT_API_KEY'],
# )

COLLECTION_NAME = 'simplewiki'
# Number of vectors a completely built index is expected to contain.
EXPECTED_VECTORS_COUNT = 251664

collections_names = [collection.name for collection in qdrant.get_collections().collections]

# Do not assert that the collection exists here: with a fresh in-memory client
# it never does, and an unconditional assert would abort "Restart & Run All"
# before the index-building cell (guarded by `COLLECTION_NAME not in
# collections_names`) gets a chance to run. Only sanity-check the size of an
# index that is already present.
# NOTE(review): newer qdrant-client releases appear to deprecate `vectors_count`
# in favour of `points_count` - confirm against the installed version.
if COLLECTION_NAME in collections_names:
    assert qdrant.get_collection(COLLECTION_NAME).vectors_count == EXPECTED_VECTORS_COUNT, (
        f"collection {COLLECTION_NAME!r} exists but does not hold "
        f"{EXPECTED_VECTORS_COUNT} vectors; the index looks incomplete"
    )

In [4]:
from build_simplewiki_index import build_index

# Build the index only when the collection was not found when the client was
# created; an existing collection is reused as-is.
# NOTE(review): the `qdrant` client is not passed to build_index - confirm that
# it writes to the same instance that the search cells query.
if COLLECTION_NAME not in collections_names:
    build_index(
        collection_name=COLLECTION_NAME,
        encoder=encoder,
        passages=passages,
        start_idx=0,
        batch_size=200,
    )

In [10]:
from search_helpers import (
    retrieve_top_k, 
    rerank_hits,
    fetch_top_article, 
    fetch_top_article_with_passage_highlighted, 
    fetch_article_title_with_order, 
    extract_sentence_and_partition,
    fetch_top_passage
)

# Sample query for a quick end-to-end check. Alternatives tried during development:
#   "who built the pyramids in egypt?"
#   "egypt history"
query = "capital of united states"
top_k = 10

# Step 1: fetch the top_k candidate passages from the vector DB with the bi-encoder.
original_hits = retrieve_top_k(
    query,
    top_k,
    vec_db=qdrant,
    encoder=encoder,
    collection_name=COLLECTION_NAME,
)

# Step 2: reorder those candidates with the cross-encoder.
reranked_hits = rerank_hits(
    query,
    original_hits,
    cross_encoder=cross_encoder,
)

# Optional inspection helpers (swap in `reranked_hits` to see the reranked view):
# pprint([h['article_id'] for h in original_hits])
# pprint(fetch_top_article(original_hits, articles=articles))
# pprint(fetch_article_title_with_order(original_hits, articles=articles))
# pprint(fetch_top_article_with_passage_highlighted(original_hits, articles=articles))

In [7]:
%load_ext gradio

In [14]:
import gradio as gr


def process_query(query, top_k):
  """Run retrieval and reranking for one query and build the six UI outputs.

  Args:
    query: Free-text query or question typed into the UI textbox.
    top_k: Number of passages to retrieve and rerank. Coerced to ``int``
      because slider widgets may deliver the value as a float.

  Returns:
    A 6-tuple, in the order the click handler's ``outputs`` list expects:
    (top passage before reranking, top passage after reranking,
     highlighted top article before, highlighted top article after,
     article titles with order before, article titles with order after).
    For a blank query the tuple is ``("", "", [], [], {}, {})`` - the same
    empty values the output widgets are created with - and no search is run.
  """
  # Blank input: skip the encoder / vector DB / cross-encoder round trip instead
  # of searching for an empty string.
  if not query or not query.strip():
    return "", "", [], [], {}, {}

  # The retrieval limit has to be a whole number.
  top_k = int(top_k)

  # Keep this order: the "before" titles are computed prior to reranking in
  # case rerank_hits modifies the hits it is given.
  original_hits = retrieve_top_k(query, top_k, vec_db=qdrant, encoder=encoder, collection_name=COLLECTION_NAME)
  original_titles_with_order = fetch_article_title_with_order(original_hits, articles=articles)
  reranked_hits = rerank_hits(query, original_hits, cross_encoder=cross_encoder)
  reranked_titles_with_order = fetch_article_title_with_order(reranked_hits, articles=articles)

  return (
    fetch_top_passage(original_hits),
    fetch_top_passage(reranked_hits),
    fetch_top_article_with_passage_highlighted(original_hits, articles=articles),
    fetch_top_article_with_passage_highlighted(reranked_hits, articles=articles),
    original_titles_with_order,
    reranked_titles_with_order,
  )


# Assemble the Gradio UI: an introductory markdown panel, a results column
# (passages, highlighted articles, title rankings) and a controls column
# (query box, top-k slider, button, examples), then wire the button to
# process_query.
with gr.Blocks() as retrieve_rerank_demo:
  # The single `{}` placeholder in the text below is filled by .format() with
  # the vectors_count reported for the collection.
  gr.Markdown(
    """
    # Simple Wikipedia Semantic Search 🔍 Through Retrieval and Reranking
    By inputing queries or questions, this space leverages machine learning to surface the most relevant Wikipedia passages and articles, providing insights and answers.

    ## Features
    - **Dataset**: This program sifts through Simple Wikipedia data, creating a comprehensive semantic search platform that covers a broad spectrum of information.
    - **Indexing**: Utilizes an index vector database on `Qdrant` cloud, constructed from {} Wikipedia passages, for quick and efficient retrieval.
    - **NLP Models**: 
      - **Bi-Encoder**: Employs the pretrained encoder [`multi-qa-MiniLM-L6-cos-v1`](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1) from Sentence Transformers to retrieve the top K relevant passages.
      - **Cross-Encoder Reranking**: Further refines search results using the [`cross-encoder/ms-marco-MiniLM-L-12-v2`](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2?text=I+like+you.+I+love+you) model to rerank the passages, ensuring the delivery of contextually relevant information.
    - **Discovery**: Showcases the most relevant passage alongside the top article before and after reranking, with the relevant passage highlighted for ease of understanding. Additionally, displays the article titles of the top K results before and after reranking, offering a comprehensive view of your search results.

    ## How to Use
    1. **Enter a Query or Question**: Start by typing your query or question into the system.
    2. **Select K Value (2-30)**: Choose the number of results you wish to retrieve.
    3. **Discover Relevant Information**: The system then processes your input using both bi-encoder and cross-encoder techniques to first fetch and then rerank the passages. The most relevant Wikipedia passages and articles, before and after reranking, are presented, highlighting the progress of relevance through the process.

    Experience the power of retrieval and reranking for semantic search and find the articles you need more efficiently.

    """.format(qdrant.get_collection(COLLECTION_NAME).vectors_count)
  )

  with gr.Row():

    # Results column (scale=4, versus scale=1 for the controls column).
    with gr.Column(scale=4):
      # Top passage before / after reranking, shown read-only.
      with gr.Row():
        passage_before_rerank = gr.TextArea(
          value="", 
          label="Answer - Retrieval Only", 
          autoscroll=False, 
          interactive=False,
          lines=4,
        )
        passage_after_rerank = gr.TextArea(
          value="", 
          label="Answer - Retrieval + Reranking", 
          autoscroll=False, 
          interactive=False,
          lines=4,
        )

      # Top article before / after reranking; spans labelled
      # 'relevant passage' are coloured green.
      with gr.Row():
        highlighted_article_before_rerank = gr.HighlightedText(
          value=[], 
          label="Top Article - Retrieval Only", 
          color_map={'relevant passage': 'green'},
        )
        highlighted_article_after_rerank = gr.HighlightedText(
          value=[], 
          label="Top Article - Retrieval + Reranking", 
          color_map={'relevant passage': 'green'}
        )

      # Article titles of the retrieved results before / after reranking.
      with gr.Row():
        titles_before_rerank = gr.Label(value={}, label="Article Titles - Retrieval Only")
        titles_after_rerank = gr.Label(value={}, label="Article Titles - Retrieval + Reranking")

    # Controls column: query text, number of results, trigger button, examples.
    with gr.Column(scale=1):
      input_question = gr.Textbox(
        label="Query",
        placeholder="What's on your mind?"
      )
      # The 2-30 range matches the "Select K Value (2-30)" instruction in the intro text.
      top_k_slider = gr.Slider(
        value=10,
        minimum=2,
        maximum=30,
        label="Articles to retrieve & rerank",
        interactive=True,
        step=1
      )
      button = gr.Button("Find the most relevant article from Simple Wikipedia")
      # Each example row is [query, top_k], matching the order of `inputs`.
      gr.Examples(
        examples=[
          ["capital of united states", 10],
          ["pyramids of Egypt", 25],
          ["number of countries in Africa", 7],

        ],
        inputs=[input_question, top_k_slider]
      )

  # The order of `outputs` must match the order of the tuple returned by
  # process_query (passages, highlighted articles, titles - each before/after).
  button.click(
    fn=process_query,
    inputs=[
      input_question,
      top_k_slider
    ],
    outputs=[
      passage_before_rerank,
      passage_after_rerank,
      highlighted_article_before_rerank,
      highlighted_article_after_rerank,
      titles_before_rerank,
      titles_after_rerank,
    ]
  )


# share=True additionally publishes the app under a temporary public
# *.gradio.live URL (see the cell output).
# NOTE(review): confirm that public exposure is intended.
retrieve_rerank_demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7865
Running on public URL: https://2628f9593248ddd0f4.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


