In [23]:
import requests
from io import StringIO
import pandas as pd

In [24]:
res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL
2,3,The young boys are playing outdoors and the ma...,The kids are playing outdoors near a man with ...,4.7,ENTAILMENT
3,5,The kids are playing outdoors near a man with ...,A group of kids is playing in a yard and an ol...,3.4,NEUTRAL
4,9,The young boys are playing outdoors and the ma...,A group of kids is playing in a yard and an ol...,3.7,NEUTRAL


In [25]:
sentences = data['sentence_A'].tolist()
sentences[:5]

['A group of kids is playing in a yard and an old man is standing in the background',
 'A group of children is playing in the house and there is no man standing in the background',
 'The young boys are playing outdoors and the man is smiling nearby',
 'The kids are playing outdoors near a man with a smile',
 'The young boys are playing outdoors and the man is smiling nearby']

In [26]:
sentences = data['sentence_A'].tolist()
sentence_b = data['sentence_B'].tolist()
sentences.extend(sentence_b)  # merge them
len(set(sentences))  # together we have ~4.5K unique sentences

4802

This isn't a particularly large number, so let's pull in a few more similar datasets.

In [27]:
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]

In [28]:
for url in urls:
    res = requests.get(url)
    # extract to dataframe
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='skip')
    # add to columns 1 and 2 to sentences list
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

In [29]:
len(set(sentences))

14505

Next, we remove any duplicates, leaving us with 14.5K unique sentences. Finally, we build our dense vector representations of each sentence using the sentence-BERT library.

In [30]:
sentences = [word for word in list(set(sentences)) if type(word) is str]

In [31]:
from sentence_transformers import SentenceTransformer
# initialize sentence transformer model
model = SentenceTransformer('bert-base-nli-mean-tokens')
# create sentence embeddings
sentence_embeddings = model.encode(sentences)
sentence_embeddings.shape



(14504, 768)

In [12]:
!pip install faiss-gpu

Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu
Successfully installed faiss-gpu-1.7.2


In [13]:
import faiss


In [32]:
d = sentence_embeddings.shape[1]
d

768

In [33]:
index = faiss.IndexFlatL2(d)

In [34]:
index.is_trained

True

In [35]:
index.add(sentence_embeddings)

In [36]:
index.ntotal

14504

In [37]:
k = 4
xq = model.encode(["Someone sprints with a football"])

In [38]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 7250  3007 11985 12785]]
CPU times: user 15.4 ms, sys: 0 ns, total: 15.4 ms
Wall time: 17.9 ms


In [40]:
data

Unnamed: 0,0,1,2
0,,Small dog chews on a big stick.,a dog shews on a big stick.
1,,A tennis player hitting the ball.,Two boys splashing in the surf.
2,,a lone snowboarder in the middle of a snowy gust,A snowboarder is throwing up snow as he rides ...
3,4.4,A pair of dogs playing with a purple ball.,Two dogs play with purple football.
4,0.6,a bird lands in the water.,a boat floats in the water.
...,...,...,...
1495,,A man doing a trick on a skateboard.,A man in mid air on a skateboard.
1496,,A young girl in swim goggles does the backstro...,A girl does the backstroke in the pool.
1497,5.0,A deer jumps a fence.,A deer is jumping over a fence.
1498,1.0,A young girl dressed in a Minnie mouse outfit ...,a man wearing a white suit holding a newspaper...


In [44]:
import numpy as np

vecs = np.zeros((k, d))
# then iterate through each ID from I and add the reconstructed vector to our zero-array
for i, val in enumerate(I[0].tolist()):
    vecs[i, :] = index.reconstruct(val)

In [45]:
vecs.shape

(4, 768)

In [46]:
vecs[0][:100]

array([ 0.01627072,  0.2232592 , -0.15037425, -0.30747271, -0.27122465,
       -0.10593155, -0.06460934,  0.04738171, -0.73349047, -0.37657681,
       -0.76762789,  0.16902882,  0.53107649,  0.51176697,  1.14415824,
       -0.08562881, -0.67240071, -0.96637076,  0.02545465, -0.21559809,
       -1.25656545, -0.82982159, -0.09825006, -0.21850838,  0.50610238,
        0.10527924,  0.50396848,  0.6524294 , -1.39458752,  0.65847486,
       -0.21525319, -0.22487473,  0.818183  ,  0.08464295, -0.76141769,
       -0.28928289, -0.09825794, -0.73046207,  0.07855801, -0.84354597,
       -0.59242105,  0.7747131 , -1.20920527, -0.22757922, -1.30733585,
       -0.23081493, -1.31322539,  0.01629098, -0.97285485,  0.19308187,
        0.47424555,  1.18920887, -1.96741295, -0.70061141, -0.2963869 ,
        0.60533738,  0.62407422, -0.70340371, -0.86754245,  0.17673171,
       -0.19170482, -0.02951987,  0.22623563, -0.16695446, -0.80402559,
       -0.45918921,  0.69675452, -0.24928184, -1.01478708, -0.92

In [47]:
nlist = 50  # how many cells
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [48]:
index.is_trained

False

In [49]:
index.train(sentence_embeddings)
index.is_trained  # check if index is now trained

True

In [50]:
index.add(sentence_embeddings)
index.ntotal  # number of embeddings indexed

14504

In [51]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 3007 11985 12785  8210]]
CPU times: user 1.59 ms, sys: 0 ns, total: 1.59 ms
Wall time: 1.02 ms


In [52]:
index.nprobe = 10

In [53]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 7250  3007 11985 12785]]
CPU times: user 4.83 ms, sys: 990 µs, total: 5.82 ms
Wall time: 7.21 ms


In [54]:
index.make_direct_map()

In [55]:
index.reconstruct(7460)[:100]

array([ 0.2532105 , -1.1464525 ,  1.6131321 ,  0.1137628 , -0.48931664,
       -0.14094701, -0.6618215 ,  0.7676524 , -0.29652897, -0.41654602,
        0.4421947 ,  0.36599612,  0.7545412 ,  0.81472385, -0.59650886,
        0.07824454,  0.30118698, -0.5452254 , -0.09904951, -0.34059128,
        0.61253095, -0.40346298,  0.12878814, -0.03620107, -0.65426576,
       -0.73397946, -0.08348604, -0.39714962,  1.4018917 ,  0.60986745,
       -0.2931748 ,  0.34492987,  0.7750454 , -0.86670184, -1.4876094 ,
        0.0220895 ,  0.5365365 ,  0.03781834,  0.64236295,  0.9533433 ,
       -1.0234523 , -1.2253761 ,  0.4560891 ,  0.8965654 , -1.236459  ,
       -0.81649834,  1.0527798 ,  0.28629693,  0.2750041 ,  0.20900282,
       -0.60449535, -0.3767317 ,  0.17593057,  1.0868392 ,  0.06988721,
       -1.0946486 ,  0.13831893, -0.9020952 , -0.35559574,  0.16727714,
        0.26723957,  0.83093905, -0.14632499, -0.55993366, -0.61618835,
        0.87459385,  1.0174788 , -0.13155219, -0.05729083, -0.65

In [56]:
m = 8  # number of centroid IDs in final compressed vectors
bits = 8 # number of bits in each centroid

quantizer = faiss.IndexFlatL2(d)  # we keep the same L2 distance flat index
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits)

In [57]:
index.is_trained

False

In [58]:
index.train(sentence_embeddings)

In [59]:
index.add(sentence_embeddings)

In [60]:
index.nprobe = 10

In [61]:
%%time
D, I = index.search(xq, k)
print(I)

[[ 618 2298 3007 1616]]
CPU times: user 2.29 ms, sys: 19 µs, total: 2.31 ms
Wall time: 2.25 ms


In [62]:
[f'{i}: {sentences[i]}' for i in I[0]]

['618: A man in a football uniform is running with a football during a game.',
 '2298: Two football players are scrambling for the ball on the court',
 '3007: A group of people playing football is running in the field',
 '1616: Two teams are competing in a football match']