In [7]:
import json
import faiss
import numpy as np
import torch
from transformers import DistilBertTokenizer, DistilBertModel

A bit more philosophically, you can think of each number in the vector as a coordinate in an N-dimensional space (where N is the length of the vector, 768 for DistilBERT). The working assumption is that **if the transformer has learned a useful representation of the documents, similar documents will lie close together in that N-dimensional space.**
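To make "close together" concrete, here is a toy sketch in NumPy. The 3-dimensional vectors are made-up stand-ins for the 768-dimensional ones the model produces, and `nearest` is a hypothetical helper. It measures closeness with Euclidean distance, which is only one common choice; the index built later in this notebook ranks by inner product instead.

```python
import numpy as np

# Made-up 3-dimensional "document vectors" (illustrative values, not DistilBERT output)
docs = np.array([
    [0.9, 0.1, 0.0],  # doc 0
    [0.8, 0.2, 0.1],  # doc 1: close to doc 0
    [0.0, 0.1, 0.9],  # doc 2: far from both
])

def nearest(query: np.ndarray, vectors: np.ndarray) -> int:
    """Return the row index of the vector with the smallest Euclidean distance to `query`."""
    distances = np.linalg.norm(vectors - query, axis=1)
    return int(np.argmin(distances))

query = np.array([0.88, 0.12, 0.02])
print(nearest(query, docs))               # 0
print(np.linalg.norm(docs[0] - docs[1]))  # small, about 0.17
print(np.linalg.norm(docs[0] - docs[2]))  # large, about 1.27
```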

In [2]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

documents = [
    "That restaurant was not as good as the last movie I watched.",
    "I'm selling a used car in good condition",
    "Food was okay, the rest so so",
    "I love cats, but don't really like hyenas",
    "On the road, you must be careful",
]



In [3]:
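vectors = [
    # tokenize the document into PyTorch tensors and run it through the model;
    # [0] is the last hidden state: one 768-dim vector per token, including the
    # special [CLS] and [SEP] tokens. squeeze() drops the batch dimension of 1.
    model(**tokenizer(document, return_tensors='pt'))[0].detach().squeeze()
    for document in documents
]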

[v.size() for v in vectors]

[torch.Size([15, 768]),
 torch.Size([12, 768]),
 torch.Size([10, 768]),
 torch.Size([15, 768]),
 torch.Size([10, 768])]

In [5]:
# mean-pool over the token axis: one 768-dim vector per document
averaged_vectors = [torch.mean(vector, dim=0) for vector in vectors]
[v.size() for v in averaged_vectors]

[torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768]),
 torch.Size([768])]
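Averaging collapses a variable number of token vectors into one fixed-size vector. A small NumPy sketch of the same arithmetic as `torch.mean(vector, dim=0)`, using a made-up 3×4 matrix instead of a real (num_tokens, 768) one:

```python
import numpy as np

# Made-up (num_tokens=3, hidden=4) matrix standing in for one document's token vectors
token_vectors = np.array([
    [1.0, 2.0, 0.0, 4.0],
    [3.0, 2.0, 0.0, 0.0],
    [2.0, 2.0, 3.0, 2.0],
])

# Averaging over axis 0 (the token axis) yields one value per hidden dimension
doc_vector = token_vectors.mean(axis=0)
print(doc_vector)  # [2. 2. 1. 2.]

# A longer document (5 tokens) still ends up with the same shape
other_doc_vector = np.ones((5, 4)).mean(axis=0)
print(other_doc_vector.shape)  # (4,)
```

This is why every entry in `averaged_vectors` has size 768 even though the documents have 10 to 15 tokens.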

In [6]:
def encode(document: str) -> torch.Tensor:
    """Return a single mean-pooled DistilBERT vector (768 dims) for `document`."""
    tokens = tokenizer(document, return_tensors='pt')
    with torch.no_grad():  # inference only, so skip building the autograd graph
        vector = model(**tokens)[0].squeeze()  # (num_tokens, 768)
    return torch.mean(vector, dim=0)
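`encode` handles one document at a time, so a plain mean is safe. If you batch several documents by calling the tokenizer on a list with `padding=True`, shorter sequences get padding tokens, and a plain mean would average those in too. The tokenizer's `attention_mask` (1 for real tokens, 0 for padding) can be used to exclude them. A NumPy sketch of masked mean pooling, where `masked_mean_pool` and the toy arrays are hypothetical:

```python
import numpy as np

def masked_mean_pool(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per sequence, ignoring padded positions.

    hidden:         (batch, seq_len, dim) token vectors
    attention_mask: (batch, seq_len), 1 for real tokens and 0 for padding
    returns:        (batch, dim)
    """
    mask = attention_mask[:, :, None].astype(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(axis=1)                    # padded positions contribute 0
    counts = np.clip(mask.sum(axis=1), 1e-9, None)          # real tokens per sequence
    return summed / counts

# Two toy sequences padded to length 3; the second has only 2 real tokens
hidden = np.array([
    [[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]],
    [[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]],  # last row is padding
])
mask = np.array([[1, 1, 1], [1, 1, 0]])
print(masked_mean_pool(hidden, mask))
# [[3. 3.]
#  [3. 1.]]
```

Without the mask, the second row would come out as `[5. 3.67]` because the padding vector `[9, 9]` would be averaged in.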

In [8]:
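# Exhaustive (flat) inner-product index, wrapped so we can attach our own IDs
index = faiss.IndexIDMap(faiss.IndexFlatIP(768))
index.add_with_ids(
    np.stack([t.numpy() for t in averaged_vectors]).astype('float32'),  # (n_docs, 768)
    np.arange(len(documents), dtype='int64')  # position in `documents` serves as the ID
)

def search(query: str, k: int = 1):
    """Return the k documents with the highest inner product with `query`, plus scores."""
    encoded_query = encode(query).unsqueeze(dim=0).numpy()  # faiss expects (n_queries, 768)
    scores, ids = index.search(encoded_query, k)  # both shaped (n_queries, k)
    results = [documents[_id] for _id in ids[0]]
    return list(zip(results, scores[0]))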
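`IndexFlatIP` does no approximation: it compares the query against every stored vector and keeps the k largest inner products, while `IndexIDMap` translates positions back into the IDs passed to `add_with_ids`. A rough NumPy equivalent, as a sketch for intuition only (`flat_ip_search` and the toy data are hypothetical, not part of faiss):

```python
import numpy as np

def flat_ip_search(db: np.ndarray, ids: np.ndarray, queries: np.ndarray, k: int):
    """Exhaustive inner-product search, returning (scores, ids) shaped (n_queries, k)."""
    scores = queries @ db.T                     # every query against every stored vector
    order = np.argsort(-scores, axis=1)[:, :k]  # top-k positions per query, highest first
    return np.take_along_axis(scores, order, axis=1), ids[order]

db = np.array([[1.0, 0.0], [0.6, 0.8], [0.0, 1.0]], dtype=np.float32)
ids = np.array([10, 11, 12], dtype=np.int64)
query = np.array([[0.8, 0.6]], dtype=np.float32)

scores, found = flat_ip_search(db, ids, query, k=2)
print(found[0])  # [11 10]
```

Brute force is fine for five documents; faiss becomes worthwhile when the collection is large enough that this full matrix product is the bottleneck.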

In [10]:
documents[1]

"I'm selling a used car in good condition"

In [11]:
search(documents[1], k=2)

[("I'm selling a used car in good condition", 70.69185),
 ('On the road, you must be careful', 53.795795)]
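The scores above are raw inner products, not similarities bounded by 1, which is why they sit around 50 to 70. With raw inner product, the exact match is not guaranteed to rank first: a vector with a larger norm can outscore it. It did rank first here, but the made-up 2-dimensional example below shows the failure case. L2-normalizing both sides turns the inner product into cosine similarity; with faiss, one option is calling `faiss.normalize_L2` (which normalizes a float32 array in place) on the indexed vectors and on each query before using `IndexFlatIP`.

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale vectors (along the last axis) to unit length."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Made-up vectors: row 0 is an exact copy of the query,
# row 1 points in a slightly different direction but has a larger norm
query = np.array([1.0, 0.0])
vectors = np.array([[1.0, 0.0], [2.0, 0.5]])

raw = vectors @ query                                  # [1.0, 2.0]
cosine = l2_normalize(vectors) @ l2_normalize(query)   # [1.0, ~0.97]
print(int(np.argmax(raw)), int(np.argmax(cosine)))     # 1 0
```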

In [12]:
search('I know how to drive', k=1)

[('On the road, you must be careful', 54.49833)]

In [13]:
index

<faiss.swigfaiss_avx2.IndexIDMap; proxy of <Swig Object of type 'faiss::IndexIDMapTemplate< faiss::Index > *' at 0x7fda0332be40> >