In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
%%writefile requirments.txt
gradio
qdrant-client
sentence-transformers
tqdm

Overwriting requirments.txt


In [3]:
!echo "Requirments Are:" && cat requirments.txt && echo "-----"

!pip install -r requirments.txt

Requirments Are:
gradio
qdrant-client
sentence-transformers
tqdm
-----
Collecting gradio (from -r requirments.txt (line 1))
  Downloading gradio-4.22.0-py3-none-any.whl.metadata (15 kB)
Collecting qdrant-client (from -r requirments.txt (line 2))
  Downloading qdrant_client-1.8.0-py3-none-any.whl.metadata (9.5 kB)
Collecting sentence-transformers (from -r requirments.txt (line 3))
  Downloading sentence_transformers-2.6.0-py3-none-any.whl.metadata (11 kB)
Collecting tqdm (from -r requirments.txt (line 4))
  Downloading tqdm-4.66.2-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.6/57.6 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio->-r requirments.txt (line 1))
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting altair<6.0,>=4.2.0 (from gradio->-r requirments.txt (line 1))
  Downloading altair-5.2.0-py3-none-any.whl.metadata (8.7 kB)
Collecting fastapi (from grad

In [2]:
from download_dataset import get_dataset

SIMPLE_WIKI_PATH = 'simplewiki-2020-11-01.jsonl.gz'

dataset = get_dataset(SIMPLE_WIKI_PATH)

passages = dataset['passages']
articles = dataset['articles']

100%|██████████| 50.2M/50.2M [00:01<00:00, 34.6MB/s]


In [3]:
from sentence_transformers import SentenceTransformer, CrossEncoder
from pprint import pprint


encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
pprint({
    'max_seq_length': encoder.get_max_seq_length(),
    'sentence_embedding_dimension': encoder.get_sentence_embedding_dimension(),
    'tokenizer': encoder.tokenizer
})

cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-12-v2')
pprint({
    'max_length': cross_encoder.max_length,
    'tokenizer': cross_encoder.tokenizer
})

{'max_seq_length': 512,
 'sentence_embedding_dimension': 384,
 'tokenizer': BertTokenizerFast(name_or_path='sentence-transformers/multi-qa-MiniLM-L6-cos-v1', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}}
{'max_length': None,

In [4]:
import os

from qdrant_client import QdrantClient, models
from tqdm.autonotebook import tqdm

qdrant = QdrantClient(':memory:') # create in-mem instance of vector db
# qdrant = QdrantClient(
#     url=os.environ['QDRANT_URL'],
#     api_key=os.environ['QDRANT_API_KEY'],
# )

COLLECTION_NAME = 'simplewiki'

collections_names = list(map(lambda x: x.name, qdrant.get_collections().collections))
assert COLLECTION_NAME in collections_names
assert qdrant.get_collection(COLLECTION_NAME).vectors_count == 508000

In [5]:
from build_simplewiki_index import build_index

if COLLECTION_NAME not in collections_names:
    build_index(
        passages=passages, 
        batch_size=200,
        start_idx=0,
        encoder=encoder,
        collection_name=COLLECTION_NAME
    )

In [5]:
from search_helpers import (
    retrieve_top_k, 
    rerank_hits,
    fetch_top_article, 
    fetch_top_article_with_passage_highlighted, 
    extract_sentence_and_partition,
    fetch_top_passage
)

# query = "who built the pyramids in egypt?"
query = "capital of united states"
# query = "egypt history"
top_k = 10

original_hits, _ = retrieve_top_k(query, top_k, vec_db=qdrant, encoder=encoder, collection_name=COLLECTION_NAME)

# pprint([h['article_id'] for h in original_hits])
# pprint(fetch_top_article(original_hits, articles=articles))
# pprint(fetch_article_title_with_order(original_hits, articles=articles))
# pprint(fetch_top_article_with_passage_highlighted(original_hits, articles=articles))

reranked_hits, _ = rerank_hits(query, original_hits, cross_encoder=cross_encoder, articles=articles)

# # pprint([h['article_id'] for h in reranked_hits])
# fetch_top_article(reranked_hits, articles=articles)
# pprint(fetch_article_title_with_order(reranked_hits, articles=articles))
# pprint(fetch_top_article_with_passage_highlighted(reranked_hits, articles=articles))

In [7]:
%load_ext gradio

In [6]:
import gradio as gr
import pandas as pd


RETRIEVAL_TOP_K = 40
DISPLAY_TOP_K = 10


def process_query(query):
  original_hits, retrieval_time = retrieve_top_k(query, RETRIEVAL_TOP_K, vec_db=qdrant, encoder=encoder, collection_name=COLLECTION_NAME)
  reranked_hits, reranking_time = rerank_hits(query, original_hits, cross_encoder=cross_encoder, articles=articles)

  reranked_hits = reranked_hits[:DISPLAY_TOP_K]

  df = pd.DataFrame(
    {
      "Retrieval Order": [value['retrieval_order'] for value in reranked_hits],
      "Reranking Order": [value['reranked_order'] for value in reranked_hits],
      "Title": [value['title'] for value in reranked_hits],
      "Answer": [value['passage'] for value in reranked_hits],
      "Article Text": [articles[value['article_id']]['content'] for value in reranked_hits],
    }
  )

  return (
    fetch_top_article_with_passage_highlighted(reranked_hits, articles=articles),
    df,
    {
      "Retrieval Time": str(round(retrieval_time, 3)) + " s",
      "Reranking Time": str(round(reranking_time, 3)) + " s",
    }
  )


def update(selected_index: gr.SelectData, df):
  val = df.iloc[selected_index.index[0]]
  return extract_sentence_and_partition(val['Article Text'], val['Answer'])


with gr.Blocks() as retrieve_rerank_demo:
  gr.Markdown(
      """
      # Simple Wikipedia Semantic Search 🔍 Through Retrieval and Reranking
      By inputing queries or questions, this space leverages machine learning to surface the most relevant Simple Wikipedia passages and articles, providing most relevant answers out of **{}** passages indexed on Qdrant cloud using binary quantization.
      """.format(qdrant.get_collection(COLLECTION_NAME).vectors_count)
  )

  with gr.Accordion("Click to learn about the retreival process", open=False):
    gr.Markdown(
      """
      ## Features
      1. Encode all passages from Simple Wikipedia dataset into embeddings using a pretrained bi-encoder [`multi-qa-MiniLM-L6-cos-v1`](https://huggingface.co/sentence-transformers/multi-qa-MiniLM-L6-cos-v1) from Sentence Transformers
      2. Index the embeddings on `Qdrant` cloud using binary quantization for efficient retrieval, resulting in {} vector embeddings for encoded passages
      3. The user enters a search query like a sentence or a questions
      4. Encoding the user search query using the bi-encoder model
      5. Retrieve the 40 most relevant passages to the input query by sifting through the indexed embeddings in the Qdrant collection and by leveraging binary quantization to boost retrieval speed
      6. Rerank search results using a cross-encoder [`ms-marco-MiniLM-L-12-v2`](https://huggingface.co/cross-encoder/ms-marco-MiniLM-L-12-v2) to priortize the most contextually relevant passages
      7. Show the top article with the answer highlighted in green, the top 10 reranked answers in a DataFrame view, and the processing time required for both retrieval and reranking

      """.format(qdrant.get_collection(COLLECTION_NAME).vectors_count)
    )

  input_question = gr.Textbox(
    label="Query for Simple Wikipedia articles",
    placeholder="Enter a query to search for relevant texts from Simple Wikipedia",
  )
  gr.Examples(
    examples=[
      ["capital of united states"],
      ["pyramids of Egypt"],
      ["number of countries in Africa"],
      ["how many people live in alexandria"],
      ["where is the red sea?"]
    ],
    inputs=[input_question]
  )
  button = gr.Button("Search 🔍")

  with gr.Accordion("Click to read the top article with answer highlighted", open=True):
    highlighted_article_after_rerank = gr.HighlightedText(
      value=[], 
      label="Top Article with Answer Highlighted", 
      color_map={'relevant passage': 'green'}
    )

  df_output = gr.Dataframe(
    headers=[
      "Retrieval Order",
      "Reranking Order",
      "Title",
      "Answer",
      "Article Text"
    ]
  )

  runtime_info = gr.Json()

  button.click(
    fn=process_query,
    inputs=[
      input_question,
    ],
    outputs=[
      highlighted_article_after_rerank,
      df_output,
      runtime_info
    ]
  )

  df_output.select(
    fn=update,
    inputs=df_output, 
    outputs=highlighted_article_after_rerank
  )


retrieve_rerank_demo.launch(share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://578a2ddd183052e293.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




Traceback (most recent call last):
  File "/home/codespace/.local/lib/python3.10/site-packages/httpx/_transports/default.py", line 69, in map_httpcore_exceptions
    yield
  File "/home/codespace/.local/lib/python3.10/site-packages/httpx/_transports/default.py", line 233, in handle_request
    resp = self._pool.handle_request(req)
  File "/home/codespace/.local/lib/python3.10/site-packages/httpcore/_sync/connection_pool.py", line 216, in handle_request
    raise exc from None
  File "/home/codespace/.local/lib/python3.10/site-packages/httpcore/_sync/connection_pool.py", line 196, in handle_request
    response = connection.handle_request(
  File "/home/codespace/.local/lib/python3.10/site-packages/httpcore/_sync/connection.py", line 99, in handle_request
    raise exc
  File "/home/codespace/.local/lib/python3.10/site-packages/httpcore/_sync/connection.py", line 76, in handle_request
    stream = self._connect(request)
  File "/home/codespace/.local/lib/python3.10/site-packages/httpcor