In [None]:
!pip install torch
!pip install faiss-gpu
!pip install sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting faiss-gpu
  Using cached faiss_gpu-1.7.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
Installing collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import faiss
import numpy as np
import time
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-MiniLM-L6-v2')

sentences = [
    'A man is eating food.',
    'A man is eating a piece of bread.',
    'The girl is carrying a baby.',
    'A man is riding a horse.',
    'A woman is playing violin.',
    'Two men pushed carts through the woods.',
    'A man is riding a white horse on an enclosed ground.',
    'A monkey is playing drums.',
    'A cheetah is running behind its prey.'
]

# faiss only accepts C-contiguous float32 arrays; one embedding row per sentence.
encoded_data = np.ascontiguousarray(model.encode(sentences), dtype='float32')

# Sentence at position i gets id i, so fetch_sentence can map a search hit
# back to its text by plain list indexing.
ids = np.arange(len(sentences), dtype='int64')
assert encoded_data.shape[0] == ids.shape[0], 'need exactly one id per embedding'
print(ids.shape, encoded_data.shape, type(ids))

# IndexIDMap wraps the flat L2 index so we can attach our own int64 ids.
index = faiss.IndexIDMap(faiss.IndexFlatL2(encoded_data.shape[1]))
index.add_with_ids(encoded_data, ids)

# Persist the index and reload it from disk (exercises the save/load path).
INDEX_PATH = 'search.index'
faiss.write_index(index, INDEX_PATH)
del index
index = faiss.read_index(INDEX_PATH)


def fetch_sentence(id):
    """Look up the text of a corpus sentence by its integer id.

    The id is used directly as a position in the module-level ``sentences``
    list, so ordinary list-indexing rules apply. A negative id counts from
    the end, and an out-of-range id raises ``IndexError``.
    """
    corpus = sentences
    sentence = corpus[id]
    return sentence


def search(query, index, model, *, top_k=10):
    """Return the corpus sentences most similar to ``query``, best match first.

    Args:
        query: Free-text query to embed with ``model``.
        index: Index exposing ``search(vectors, k) -> (distances, ids)``,
            e.g. a faiss index.
        model: Encoder exposing ``encode(list_of_str)`` that returns one
            embedding row per input string.
        top_k: Maximum number of neighbours to request from the index.

    Returns:
        A list of at most ``top_k`` sentences in the index's ranking order
        (nearest first), with duplicate ids removed.

    Side effect: prints the time spent embedding the query and searching.
    """
    start = time.perf_counter()
    query_vector = model.encode([query])
    # faiss only accepts C-contiguous float32 input.
    _, neighbour_ids = index.search(np.ascontiguousarray(query_vector, dtype='float32'), top_k)
    print('Results in total time:', time.perf_counter() - start)
    # faiss pads the result row with id -1 when fewer than top_k vectors are
    # available (e.g. the default top_k=10 against a 9-sentence corpus).
    # Passed to fetch_sentence, -1 would silently return the last sentence.
    # dict.fromkeys de-duplicates while preserving rank order; np.unique
    # would re-sort the hits by id and lose the relevance ordering.
    ranked_ids = dict.fromkeys(int(i) for i in neighbour_ids[0] if i >= 0)
    return [fetch_sentence(i) for i in ranked_ids]

(9,) (9, 384) <class 'numpy.ndarray'>


In [None]:
pasta_matches = search('the person eating pasta', index, model, top_k=3)
print(pasta_matches)


Results in total time: 0.024699687957763672
['A man is eating food.', 'A man is eating a piece of bread.', 'A cheetah is running behind its prey.']


In [None]:
running_matches = search('a running person', index, model, top_k=3)
print(running_matches)

Results in total time: 0.0238034725189209
['The girl is carrying a baby.', 'A man is riding a white horse on an enclosed ground.', 'A cheetah is running behind its prey.']
