In [6]:
import pandas as pd
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer
from transformers import DistilBertTokenizer, DistilBertModel
from sentence_transformers import SentenceTransformer
import faiss
import sentence_transformers

In [12]:
sentences = pd.read_json("https://lp-prod-resources.s3.amazonaws.com/493/57248/2021-05-04-13-31-46/sentences.json").rename(columns = {0:"sentence_text"}) 
questions = pd.read_json("https://lp-prod-resources.s3.amazonaws.com/493/57248/2021-08-16-19-04-45/questions.json").rename(columns = {0:"question_text"}) 

In [13]:
sentences, questions

(                                        sentence_text
 0   A pandemic is an epidemic of an infectious dis...
 1   The most fatal pandemic in recorded history wa...
 2   Current pandemics include COVID-19 (SARS-CoV-2...
 3   As of 2018, approximately 37.9 million people ...
 4   Cholera is an infection of the small intestine...
 5   Classic cholera symptom is large amounts of wa...
 6   The COVID-19 pandemic, also known as the coron...
 7   Common symptoms of COVID-19 include fever, cou...
 8   The Plague of Cyprian was a pandemic that affl...
 9   The Spanish flu, also known as the 1918 flu pa...
 10  The death toll of Spanish Flu is estimated to ...,
                                        question_text
 0      How many people have died during Black Death?
 1      Which diseases can be transmitted by animals?
 2  Connection between climate change and a likeli...
 3               What is an example of a latent virus
 4                          Viruses in nanotechnology
 5             

In [7]:
model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')

Downloading: 100%|██████████| 345/345 [00:00<00:00, 172kB/s]
Downloading: 100%|██████████| 190/190 [00:00<00:00, 30.0kB/s]
Downloading: 100%|██████████| 4.01k/4.01k [00:00<00:00, 652kB/s]
Downloading: 100%|██████████| 555/555 [00:00<00:00, 208kB/s]
Downloading: 100%|██████████| 122/122 [00:00<00:00, 14.0kB/s]
Downloading: 100%|██████████| 265M/265M [00:39<00:00, 6.78MB/s] 
Downloading: 100%|██████████| 53.0/53.0 [00:00<00:00, 10.3kB/s]
Downloading: 100%|██████████| 112/112 [00:00<00:00, 20.4kB/s]
Downloading: 100%|██████████| 466k/466k [00:00<00:00, 1.73MB/s]
Downloading: 100%|██████████| 505/505 [00:00<00:00, 505kB/s]
Downloading: 100%|██████████| 232k/232k [00:00<00:00, 1.06MB/s]
Downloading: 100%|██████████| 229/229 [00:00<00:00, 63.6kB/s]


In [11]:
model.encode("Test String").shape

(768,)

In [14]:
vectors = [
    model.encode(document)
    for document in sentences.sentence_text
]

In [17]:
[v.shape for v in vectors]

[(768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,)]

In [38]:
def encode(document: str):
  vector = model.encode(document)
  return vector

In [39]:
index = faiss.IndexIDMap(faiss.IndexFlatIP(768))

In [54]:
index.add_with_ids(
    np.array([t for t in vectors]),
    # the IDs will be 0 to len(documents)
    np.array(range(0, len(sentences.sentence_text))).astype(np.int64)
)

In [63]:
def search(query: str, k=1):
  encoded_query = np.expand_dims(encode(query), 0)
  top_k = index.search(encoded_query, k)
  scores = top_k[0][0]
  results = [sentences.sentence_text[_id] for _id in top_k[1][0]]
  return list(zip(results, scores))

In [65]:
q_num = 2
num_similar_docs = 5

print("Question: {}".format(questions.question_text[q_num])), print("Answer: {}".format(search(questions.question_text[q_num], num_similar_docs)))

Question: Connection between climate change and a likelihood of a pandemic
Answer: [('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 136.00598), ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 136.00598), ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 136.00598), ('The Spanish flu, also known as the 1918 flu pandemic, was an unusually deadly influenza pandemic caused by the H1N1 influenza A virus.', 88.706924), ('The Spanish flu, also known as the 1918 flu pandemic, was an unusually deadly influenza pandemic caused by the H1N1 influenza A virus.', 88.706924)]


(None, None)