In [10]:
!pip install faiss-gpu
!pip install sentence-transformers

Collecting sentence-transformers
  Downloading sentence_transformers-3.1.0-py3-none-any.whl.metadata (23 kB)
Downloading sentence_transformers-3.1.0-py3-none-any.whl (249 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m249.1/249.1 kB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentence-transformers
Successfully installed sentence-transformers-3.1.0


In [1]:
import requests
from io import StringIO
import pandas as pd

In [2]:
res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL
2,3,The young boys are playing outdoors and the ma...,The kids are playing outdoors near a man with ...,4.7,ENTAILMENT
3,5,The kids are playing outdoors near a man with ...,A group of kids is playing in a yard and an ol...,3.4,NEUTRAL
4,9,The young boys are playing outdoors and the ma...,A group of kids is playing in a yard and an ol...,3.7,NEUTRAL


In [3]:
# we take all samples from both sentence A and B
sentences = data['sentence_A'].tolist()
sentence_b = data['sentence_B'].tolist()
sentences.extend(sentence_b)  # merge them
len(set(sentences))  # together we have ~4.5K unique sentences

4802

In [4]:
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]

In [5]:
# each of these dataset have the same structure, so we loop through each creating our sentences data
for url in urls:
    res = requests.get(url)
    # extract to dataframe
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='skip')
    # add to columns 1 and 2 to sentences list
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

In [6]:
len(set(sentences))

14505

In [7]:
# remove duplicates and NaN
sentences = [word for word in list(set(sentences)) if type(word) is str]

In [8]:
len(sentences)

14504

The model bert-base-nli-mean-tokens is a pretrained sentence embedding model based on BERT (Bidirectional Encoder Representations from Transformers), specifically fine-tuned for Natural Language Inference (NLI) tasks. It is often used to generate high-quality sentence embeddings, which are fixed-size dense vector representations of sentences.

In [9]:
from sentence_transformers import SentenceTransformer
# initialize sentence transformer model
model = SentenceTransformer('bert-base-nli-mean-tokens')
# create sentence embeddings
sentence_embeddings = model.encode(sentences)
sentence_embeddings.shape

  from tqdm.autonotebook import tqdm, trange


(14504, 768)

In [10]:
sentence_embeddings

array([[-0.26019537,  0.50659895,  0.7228153 , ..., -0.71619177,
        -0.42532596,  0.08839572],
       [-0.13848802,  0.06270167,  0.20811662, ..., -0.37457234,
        -0.25141186, -0.39994887],
       [ 0.04773773, -0.2265471 ,  1.141043  , ..., -0.21873291,
         0.00557534,  0.18330637],
       ...,
       [ 0.3848488 ,  0.14460556, -0.5746919 , ...,  0.39329362,
         0.1355522 , -0.7091271 ],
       [ 0.1907153 , -0.28229755,  1.5370922 , ..., -0.34021226,
        -0.8530363 ,  0.9050524 ],
       [ 0.24362855,  0.04793499,  0.698116  , ...,  0.4753194 ,
        -0.15564944, -0.16404368]], dtype=float32)

In [11]:
import faiss

In [12]:
d = sentence_embeddings.shape[1]
d

768

## IndexFlatL2

In [13]:
index = faiss.IndexFlatL2(d)

In [14]:
#This tells us that the current index method needs training or not, If it returns True then it is already trained and we do not need to train
index.is_trained

True

In [15]:
index.add(sentence_embeddings)

In [18]:
index.ntotal

14504

In [19]:
#K is the number of nearest neighbours
k = 5
xq = model.encode(["Children playing cricket"])

In [20]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[1347 6941 6610 8855 1136]]
CPU times: user 15.4 ms, sys: 1.04 ms, total: 16.5 ms
Wall time: 19 ms


In [21]:
[sentences[i] for i in range(len(sentences)) if i in I[0]]

['A child is hitting a baseball',
 'People are playing cricket',
 'A few men are playing cricket',
 'Some people are playing cricket',
 'A young boy with a visor on plays ball with his bat in the street.']

In [22]:
#You can get there embeddings like this: -
sentence_embeddings[10297]

array([ 4.47415501e-01, -3.47982734e-01,  5.82039773e-01,  6.00613914e-02,
       -7.06242621e-02,  1.91074476e-01,  3.06938261e-01, -1.18242943e+00,
       -9.83801365e-01, -7.65966594e-01, -7.58031428e-01, -2.71596164e-01,
        4.35487986e-01,  3.24372828e-01,  4.82962459e-01, -6.34137928e-01,
        3.03477347e-02, -5.26754797e-01, -1.55304402e-01, -4.07194942e-01,
       -1.13126135e+00, -3.34909469e-01, -2.32707813e-01, -6.06418788e-01,
        7.45144844e-01,  9.17910397e-01,  2.12585017e-01,  5.45780897e-01,
       -1.33355486e+00, -1.07812069e-01,  8.00171227e-04, -3.85830879e-01,
        2.37114593e-01,  2.94348300e-01, -4.57491428e-01, -8.43519032e-01,
        6.27481520e-01, -1.28098920e-01,  5.44669986e-01, -1.29448771e-01,
       -9.97354183e-03,  8.94600689e-01, -2.00289965e-01,  2.36278892e-01,
        2.46542394e-01, -4.60772991e-01, -3.26925904e-01,  1.39552653e+00,
        1.90295637e-01, -1.44713432e-01, -3.26448798e-01, -1.69399250e-02,
       -6.97132826e-01, -

## IndexIVFFlat

In [71]:
nlist = 50  # how many cells or how many voronoi cells
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [72]:
index.nprobe=5 #Number of nearest clusters to search in (higher = more accurate but slower) basically we will search the nearest 5 centroids clusters

In [73]:
index.is_trained

False

In [74]:
index.train(sentence_embeddings)
index.is_trained  # check if index is now trained

True

In [75]:
index.add(sentence_embeddings)
index.ntotal  # number of embeddings indexed

14504

In [76]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[1347 6941 6610 8855 1136]]
CPU times: user 3.42 ms, sys: 23 µs, total: 3.44 ms
Wall time: 3.36 ms


The above calculation you can see took less time than the previous one. Also the results are same below sentences.

In [28]:
[sentences[i] for i in range(len(sentences)) if i in I[0]]

['A child is hitting a baseball',
 'People are playing cricket',
 'A few men are playing cricket',
 'Some people are playing cricket',
 'A young boy with a visor on plays ball with his bat in the street.']

## HNSW

In [29]:
M = 16  # Number of neighbors each point is connected to
index = faiss.IndexHNSWFlat(d, M)  # d is the dimension, M is the number of neighbors
index.hnsw.efConstruction = 40  # Quality of graph during construction

'''
The "quality of the graph during construction" in HNSW
Higher Quality: A graph with higher quality has more connections between nodes, meaning each node is connected to more of its neighbors. This makes the search more accurate because there are more paths through which the algorithm can navigate to find the nearest neighbors.
Lower Quality: A sparser graph has fewer connections, which can make the search faster but less accurate because there are fewer paths for the algorithm to traverse.

A higher-quality graph typically results in better search accuracy because the increased connectivity allows the algorithm to find the nearest neighbors more precisely.
Lower-quality graphs might miss some nearest neighbors or might find approximate results that are less accurate.

Building a higher-quality graph generally takes more time and computational resources. More connections are created, and the algorithm performs more computations during graph construction.
A lower-quality graph can be constructed more quickly but may sacrifice some accuracy during the search phase.

Definition: The efConstruction parameter controls the number of candidates considered during the graph construction phase. Higher values lead to more thorough connections.
'''


In [30]:
index.is_trained

True

In [31]:
index.add(sentence_embeddings)
index.ntotal

14504

In [33]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[1347 6941 6610 8855 1136]]
CPU times: user 909 µs, sys: 0 ns, total: 909 µs
Wall time: 643 µs


See above, It is soo fast, also the sentences are same below.

In [34]:
[sentences[i] for i in range(len(sentences)) if i in I[0]]

['A child is hitting a baseball',
 'People are playing cricket',
 'A few men are playing cricket',
 'Some people are playing cricket',
 'A young boy with a visor on plays ball with his bat in the street.']

## Quantized Version

In [77]:
m = 8  # number of centroid IDs in final compressed vectors
bits = 8 # number of bits in each centroid

quantizer = faiss.IndexFlatL2(d)  # we keep the same L2 distance flat index
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits)

In [78]:
index.is_trained

False

In [79]:
index.train(sentence_embeddings)

In [80]:
index.is_trained

True

In [81]:
index.nprobe = 2  # Number of clusters to search in (higher = more accurate but slower) basically same as the IVF where we search
#so basically if we pass 10 then we will consider the nearest 10 centroids and search in them same applies fro simple IVF too.

In [82]:
%%time
D, I = index.search(xq, k)
print(I)

[[-1 -1 -1 -1 -1]]
CPU times: user 2.81 ms, sys: 0 ns, total: 2.81 ms
Wall time: 2.94 ms
