In [2]:
import pandas as pd
import numpy as np
import torch
import transformers
from transformers import AutoTokenizer
from transformers import DistilBertTokenizer, DistilBertModel
from sentence_transformers import SentenceTransformer
import faiss
import sentence_transformers

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


Moving 0 files to the new cache system


0it [00:00, ?it/s]


In [3]:
sentences = pd.read_json("https://lp-prod-resources.s3.amazonaws.com/493/57248/2021-05-04-13-31-46/sentences.json").rename(columns = {0:"sentence_text"}) 
questions = pd.read_json("https://lp-prod-resources.s3.amazonaws.com/493/57248/2021-08-16-19-04-45/questions.json").rename(columns = {0:"question_text"}) 

In [4]:
sentences, questions

(                                        sentence_text
 0   A pandemic is an epidemic of an infectious dis...
 1   The most fatal pandemic in recorded history wa...
 2   Current pandemics include COVID-19 (SARS-CoV-2...
 3   As of 2018, approximately 37.9 million people ...
 4   Cholera is an infection of the small intestine...
 5   Classic cholera symptom is large amounts of wa...
 6   The COVID-19 pandemic, also known as the coron...
 7   Common symptoms of COVID-19 include fever, cou...
 8   The Plague of Cyprian was a pandemic that affl...
 9   The Spanish flu, also known as the 1918 flu pa...
 10  The death toll of Spanish Flu is estimated to ...,
                                        question_text
 0      How many people have died during Black Death?
 1      Which diseases can be transmitted by animals?
 2  Connection between climate change and a likeli...
 3               What is an example of a latent virus
 4                          Viruses in nanotechnology
 5             

In [5]:
model = SentenceTransformer('distilbert-base-nli-stsb-mean-tokens')

In [6]:
model.encode("Test String").shape

(768,)

In [7]:
vectors = [
    model.encode(document)
    for document in sentences.sentence_text
]

In [8]:
[v.shape for v in vectors]

[(768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,),
 (768,)]

In [36]:
vectors[1]

array([ 1.11043811e+00, -8.38996112e-01,  1.19438671e-01, -6.83997154e-01,
        4.34035629e-01,  2.10175231e-01,  1.31315857e-01, -6.84658170e-01,
       -2.32820094e-01, -9.02489007e-01,  2.74648726e-01,  8.18223283e-02,
       -7.81049669e-01,  6.55575871e-01,  4.38592434e-01,  1.87987499e-02,
        1.48680127e+00, -1.58416316e-01, -5.15827298e-01, -1.82558417e-01,
       -4.37732965e-01,  3.22552204e-01,  7.44238555e-01,  9.90150034e-01,
       -6.38946831e-01, -5.82955219e-02,  2.03597303e-02,  4.45197448e-02,
        7.57766604e-01, -8.83774012e-02,  1.40692937e+00, -1.74215868e-01,
       -2.42551908e-01, -3.41019891e-02, -7.43072331e-01,  5.10099009e-02,
       -1.34035122e+00,  2.46486843e-01,  3.60357910e-01, -1.06637621e+00,
       -1.38826072e+00, -5.81414521e-01,  6.03832126e-01, -1.73750624e-01,
        3.74760777e-02,  7.53780961e-01,  1.15123475e+00,  2.63759345e-01,
        4.35051262e-01,  5.07907391e-01, -7.77997077e-01,  4.16097492e-02,
       -1.05468348e-01,  

In [48]:
def encode(document: str):
  vector = model.encode(document)
  return vector

In [49]:
vectors[0].shape[0]

768

In [50]:
index = faiss.IndexFlatL2(vectors[0].shape[0])

In [51]:
index = faiss.IndexIDMap(index)

In [52]:
index.add_with_ids(
    np.array([t for t in vectors]),
    # the IDs will be 0 to len(documents)
    np.array(range(0, len(sentences.sentence_text))).astype(np.int64)
)

In [53]:
def search(query: str, k=1):
  encoded_query = np.expand_dims(encode(query), 0)
  top_k = index.search(encoded_query, k)
  scores = top_k[0][0]
  results = [sentences.sentence_text[_id] for _id in top_k[1][0]]
  return list(zip(results, scores))

In [57]:
q_num = 1
num_similar_docs = 5

print("Question: {}".format("HIV")), print("Answer: {}".format(search("37.9 million people", num_similar_docs)))

Question: HIV
Answer: [('As of 2018, approximately 37.9 million people are infected with HIV globally.', 153.05344), ('A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.', 402.02948), ('The Plague of Cyprian was a pandemic that afflicted the Roman Empire about from AD 249 to 262.', 423.2583), ('The death toll of Spanish Flu is estimated to have been somewhere between 17 million and 50 million, and possibly as high as 100 million, making it one of the deadliest pandemics in human history.', 470.86157), ('Common symptoms of COVID-19 include fever, cough, fatigue, breathing difficulties, and loss of smell.', 471.43066)]


(None, None)

In [58]:
faiss.write_index(index, 'search_index_2')