#### Install and import dependencies

In [1]:
# If using Colab Notebook, install faiss and transformers right in the notebook
!pip install faiss-cpu
!pip install transformers

Collecting faiss-cpu
  Downloading faiss_cpu-1.7.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)
     |████████████████████████████████| 8.6 MB 36.3 MB/s eta 0:00:01
Installing collected packages: faiss-cpu
Successfully installed faiss-cpu-1.7.2


In [32]:
import json
from pprint import pprint
import faiss
import numpy as np
import torch
from transformers import AutoModel, AutoTokenizer

#### Load and vectorize the documents

In [33]:
# Load the documents. JSON text is UTF-8 encoded, so the encoding is given
# explicitly instead of relying on the platform's default text encoding.
with open('data/sentences.json', 'r', encoding='utf-8') as file:
    documents = json.load(file)

In [34]:
# Load a DistilBERT model and its matching tokenizer.
# The checkpoint name is defined once so the two cannot drift apart.
MODEL_NAME = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [35]:
# Build a function that uses the loaded model to vectorize a text
def encode(document: str) -> torch.Tensor:
    """Embed a text as the mean of the model's per-token output vectors.

    The text is tokenized with the module-level ``tokenizer`` and passed
    through the module-level ``model``; the first element of the model output
    is averaged over the token axis.

    Args:
        document: The text to embed.

    Returns:
        A 1-D tensor with one entry per hidden dimension. It carries no
        autograd history, so ``.numpy()`` can be called on it directly.
    """
    # truncation=True cuts over-long texts down to the maximum input length
    # instead of letting the model fail on them.
    tokens = tokenizer(document, return_tensors='pt', truncation=True)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # Drop only the batch axis (size 1). A bare squeeze() would also drop
        # the token axis for a one-token input and break the mean below.
        token_vectors = model(**tokens)[0].squeeze(0)
        return token_vectors.mean(dim=0)

In [36]:
# Encode every document into its embedding vector
vectors = list(map(encode, documents))

#### Build a Faiss index

In [40]:
# Collect the document vectors into one float32 matrix with a row per document.
# They need to be transformed into numpy arrays first.
embeddings = np.array([vec.numpy() for vec in vectors], dtype=np.float32)

# Create a flat inner-product Faiss index wrapped in an ID map. The size of the
# vector space is read from the embeddings instead of being hard-coded.
# The vectors are not normalized, so the scores are raw inner products.
index = faiss.IndexIDMap(faiss.IndexFlatIP(embeddings.shape[1]))

# Add the document vectors under the IDs 0 .. len(documents) - 1. Faiss expects
# 64-bit integer IDs, so the dtype is set explicitly rather than left to the
# platform's default integer type.
index.add_with_ids(embeddings, np.arange(len(documents), dtype=np.int64))
faiss.write_index(index, 'data/pandemics')

#### Search the index

In [41]:
def search(query: str, documents, k=5):
    """Return the documents most similar to ``query`` with their scores.

    The query is embedded with ``encode`` and looked up in the module-level
    ``index``.

    Args:
        query: The search text.
        documents: Collection indexed by the IDs stored in the index.
        k: Maximum number of results to return.

    Returns:
        A list of ``(document, score)`` tuples in the order returned by the
        index. It holds fewer than ``k`` entries when the index cannot supply
        ``k`` hits.
    """
    # The index expects a 2-D batch of query vectors, hence the extra axis.
    query_vector = encode(query).unsqueeze(dim=0).numpy()
    scores, ids = index.search(query_vector, k)

    # Faiss fills missing hits with the ID -1. Skip them: documents[-1] would
    # otherwise silently return the last document as a bogus result.
    return [
        (documents[doc_id], score)
        for doc_id, score in zip(ids[0], scores[0])
        if doc_id >= 0
    ]

In [42]:
pprint(search(query="spanish flu casualties", documents=documents, k=2))

[('The Spanish flu, also known as the 1918 flu pandemic, was an unusually '
  'deadly influenza pandemic caused by the H1N1 influenza A virus.',
  51.069508),
 ('As of 2018, approximately 37.9 million people are infected with HIV '
  'globally.',
  45.203117)]
