In [64]:
!pip install -q -qU datasets==2.12.0 qdrant-client==1.2.0 sentence-transformers==2.2.2 torch==2.0.1
!pip install -q sentence-transformers

from sentence_transformers import SentenceTransformer
import numpy as np
import json
import pandas as pd
from tqdm.notebook import tqdm
from typing import List

In [65]:
# upload dataset e.g. bigBasketProducts.csv

In [66]:
model = SentenceTransformer('all-mpnet-base-v2', device="cuda")

In [67]:
df = pd.read_csv('./bigBasketProducts.csv')

In [68]:
headers = [h for h in df.columns]

In [69]:
items = []
for row in df.itertuples():
  string = ''
  for i in range(len(headers)):
    string += headers[i] + ': ' + str(row[i]) + '; '
  items.append(string)

In [70]:
vectors = model.encode(items, show_progress_bar=True)

Batches:   0%|          | 0/862 [00:00<?, ?it/s]

In [71]:
vectors.shape

(27555, 768)

In [72]:
from qdrant_client import QdrantClient
from qdrant_client.http import models

In [73]:
client = QdrantClient(":memory:")
collection_name = "extractive-question-answering"
collections = client.get_collections()
print(collections)

# only create collection if it doesn't exist
if collection_name not in [c.name for c in collections.collections]:
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=models.VectorParams(
            size=768,
            distance=models.Distance.COSINE,
        ),
    )
collections = client.get_collections()
print(collections)

collections=[]
collections=[CollectionDescription(name='extractive-question-answering')]


In [74]:
batch_size = 512  # specify batch size according to your RAM and compute, higher batch size = more RAM usage

for index in tqdm(range(0, len(df), batch_size)):
    i_end = min(index + batch_size, len(df))  # find end of batch
    batch = df.iloc[index:i_end]  # extract batch
    emb =  vectors[index:i_end].tolist() # generate embeddings for batch
    meta = batch.to_dict(orient="records")  # get metadata
    ids = list(range(index, i_end))  # create unique IDs
    # upsert to qdrant
    client.upsert(
        collection_name=collection_name,
        points=models.Batch(ids=ids, vectors=emb, payloads=meta),
    )

collection_vector_count = client.get_collection(collection_name=collection_name).vectors_count
print(f"Vector count in collection: {collection_vector_count}")
assert collection_vector_count == len(df)

  0%|          | 0/54 [00:00<?, ?it/s]

Vector count in collection: 27555


In [75]:
model_name = "bert-large-uncased-whole-word-masking-finetuned-squad"
from transformers import pipeline

# load the reader model into a question-answering pipeline
reader = pipeline("question-answering", model=model_name, tokenizer=model_name)
print(reader.model, reader)

Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForQuestionAnswering(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-23): 24 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), ep

In [76]:
def get_context(question: str, top_k: int, progress_bar) -> List[str]:
    """
    Get the relevant plot for a given question

    Args:
        question (str): What do we want to know?
        top_k (int): Top K results to return

    Returns:
        context (List[str]):
    """
    try:
        encoded_query = model.encode(question).tolist()  # generate embeddings for the question

        result = client.search(
            collection_name=collection_name,
            query_vector=encoded_query,
            limit=top_k,
        )  # search qdrant collection for context passage with the answer

        context = [
            [x.payload] for x in result
        ]  # extract title and payload from result
        progress_bar.update(1)
        return context

    except Exception as e:
        print({e})
        progress_bar.update(1)

In [77]:
def extract_answer(question: str, context: List[str], progress_bar):
    """
    Extract the answer from the context for a given question

    Args:
        question (str): _description_
        context (list[str]): _description_
    """
    results = []
    for c in context:
        # feed the reader the question and contexts to extract answers
        string = ''
        for i in range(len(headers)):
          string += headers[i] + ': ' + str(c[0][headers[i]]) + '; '
        answer = reader(question=question, context = string)
        answer["title"] = c[0]['index']
        results.append(answer)
        progress_bar.update(1)

    # sort the result based on the score from reader model
    sorted_result = sorted(results, key=lambda x: x["score"], reverse=True)
    return sorted_result

In [78]:
def display_answer(query):
  progress_bar = tqdm(total=12, desc='Processing', leave=False)
  context = get_context(query,10, progress_bar)
  result = extract_answer(query, context, progress_bar)
  string = ''
  for i in range(5):
    if result[i]["answer"] not in string:
      if i==0: string += result[i]["answer"] + f'[{result[i]["title"]}]'
      else: string += ', ' + result[i]["answer"] + f'[{result[i]["title"]}]'
  progress_bar.update(1)
  progress_bar.close()
  string += '.'
  return string

In [79]:
import ipywidgets as widgets
from IPython.display import display

# Define interactive widgets
index = 0
text_input = widgets.Text(value='Enter your Query here...', description='Query:')
button = widgets.Button(description='Search')
output = widgets.Output()
# Define a function to be called when the button is clicked
def on_button_click(b):
    global index
    with output:
        res = display_answer(text_input.value)
        print(f'{index+1}. Query: {text_input.value} \nAnswer: {res} \n')
        index += 1

# Connect the button click event to the function
button.on_click(on_button_click)

# Display widgets
display(text_input)
display(button)
display(output)


Text(value='Enter your Query here...', description='Query:')

Button(description='Search', style=ButtonStyle())

Output()