#### Install and import dependencies

In [1]:
# If using Colab Notebook, install faiss and transformers right in the notebook
!pip install faiss-cpu
!pip install transformers

[33mDEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621[0m
[33mDEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621[0m
You should consider upgrading via the '/usr/local/opt/python@3.8/bin/python3.8 -m pip install --upgrade pip' command.[0m
[33mDEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621[0m
[33mDEPRECATION: Configuring installation scheme with distutils config

In [2]:
import json
from pprint import pprint
import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

#### Load and vectorize the documents

In [3]:
# Load the documents (the cells below index into this as a list of text strings).
# Read the file as UTF-8 explicitly so decoding does not depend on the
# platform's default locale encoding.
with open('data/sentences.json', 'r', encoding='utf-8') as file:
    documents = json.load(file)

In [4]:
# Load a DistilBERT tokenizer and model from the same pretrained checkpoint
checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
# Build a function that uses a BERT model to vectorize the texts
def encode(document: str) -> torch.Tensor:
    """Embed a text as the mean of the model's per-token output vectors.

    Args:
        document: The text to embed.

    Returns:
        A 1-D tensor with one entry per hidden dimension of the model
        (768 for the checkpoint loaded in this notebook).
    """
    # truncation=True clips inputs longer than the model's maximum sequence
    # length instead of letting the forward pass fail on them.
    tokens = tokenizer(document, return_tensors='pt', truncation=True)
    # Inference only: no_grad() avoids building an autograd graph per document.
    with torch.no_grad():
        hidden_states = model(**tokens)[0]
    # Drop the batch axis (always size 1 here), then average over the tokens.
    return hidden_states.squeeze(0).mean(dim=0)

In [6]:
# Peek at the first document to see what the raw text looks like
documents[0]

'A pandemic is an epidemic of an infectious disease that has spread across a large region, for instance multiple continents or worldwide, affecting a substantial number of people.'

In [7]:
# Tokenize the first document into PyTorch tensors and display the result
# (token ids plus the attention mask)
tokens = tokenizer(documents[0], return_tensors='pt')
tokens

{'input_ids': tensor([[  101,  1037,  6090,  3207,  7712,  2003,  2019, 16311,  1997,  2019,
         16514,  4295,  2008,  2038,  3659,  2408,  1037,  2312,  2555,  1010,
          2005,  6013,  3674, 17846,  2030,  4969,  1010, 12473,  1037,  6937,
          2193,  1997,  2111,  1012,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [8]:
# Run the tokenized document through the model and keep the per-token output
# vectors. no_grad() skips autograd bookkeeping, which inference does not need.
with torch.no_grad():
    vector = model(**tokens)[0].squeeze(0)  # drop the size-1 batch axis
vector

tensor([[ 0.1025, -0.2056, -0.5713,  ..., -0.0051,  0.1253,  0.7564],
        [-0.3454, -0.0013, -0.7068,  ..., -0.3020,  0.0376,  0.4525],
        [ 0.0288, -0.1346, -0.0729,  ..., -0.2813,  0.3779, -0.0193],
        ...,
        [ 0.2545,  0.4376, -0.0790,  ..., -0.3760, -0.1083, -0.3239],
        [ 0.8220,  0.1464, -0.5341,  ...,  0.1889, -0.5492, -0.2128],
        [ 0.1956,  0.2456, -0.4352,  ..., -0.0896, -0.2934,  0.2751]])

In [9]:
# Embed every document with encode(), preserving the order of `documents`
vectors = list(map(encode, documents))

In [14]:
# Sanity check: (number of documents, embedding dimension)
np.shape([embedding.numpy() for embedding in vectors])

(11, 768)

#### Build a FAISS index

In [10]:
# Stack the document vectors into a single (n_documents, dim) float32 matrix;
# the torch tensors need to be transformed into numpy arrays first
embeddings = np.array([vec.numpy() for vec in vectors], dtype=np.float32)

# Create a flat Faiss index sized to the embedding dimension.
# NOTE: IndexFlatIP scores by raw inner product; the vectors are not
# normalized here, so the scores are not cosine similarities.
inner_index = faiss.IndexFlatIP(embeddings.shape[1])
index = faiss.IndexIDMap(inner_index)

# Add the document vectors into the index.
# The IDs are positions in `documents`: 0 to len(documents) - 1.
# FAISS expects int64 IDs, so the dtype is set explicitly instead of relying
# on the platform's default integer type.
index.add_with_ids(
    embeddings,
    np.arange(len(documents), dtype=np.int64)
)
faiss.write_index(index, 'data/pandemics')

In [16]:
# Embed a sample query and add a leading batch axis so it becomes a
# (1, dim) array, then confirm the shape
encoded_query = encode("spanish flu casualties")[None, :].numpy()
encoded_query.shape

(1, 768)

In [17]:
# Ask the index for the 2 best matches; the result is a (scores, ids) pair
# of arrays with one row per query
top_k = index.search(encoded_query, 2)
top_k

(array([[51.069527, 45.2031  ]], dtype=float32), array([[9, 3]]))

#### Search the index

In [12]:
def search(query: str, documents, k=5):
    """Return the documents that best match a free-text query.

    Args:
        query: The text to search for; it is embedded with encode().
        documents: Sequence of documents, positioned so that the IDs stored
            in the module-level `index` are valid positions in it.
        k: Maximum number of results to return.

    Returns:
        A list of (document, score) tuples in the order returned by the
        index. It can hold fewer than k entries when the index contains
        fewer than k vectors.
    """
    # The index is queried with a 2-D array: one row per query.
    encoded_query = encode(query).unsqueeze(dim=0).numpy()
    scores, ids = index.search(encoded_query, k)
    # FAISS pads the result with id -1 when it finds fewer than k hits.
    # Skip those: documents[-1] would otherwise return the last document
    # as a bogus match.
    return [
        (documents[doc_id], score)
        for doc_id, score in zip(ids[0], scores[0])
        if doc_id != -1
    ]

In [13]:
# Try the search end to end and pretty-print the (document, score) pairs
pprint(search("spanish flu casualties", documents, k=2))

[('The Spanish flu, also known as the 1918 flu pandemic, was an unusually '
  'deadly influenza pandemic caused by the H1N1 influenza A virus.',
  51.069527),
 ('As of 2018, approximately 37.9 million people are infected with HIV '
  'globally.',
  45.2031)]
