In [1]:
import requests
from io import StringIO
import pandas as pd

In [2]:
res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL
2,3,The young boys are playing outdoors and the ma...,The kids are playing outdoors near a man with ...,4.7,ENTAILMENT
3,5,The kids are playing outdoors near a man with ...,A group of kids is playing in a yard and an ol...,3.4,NEUTRAL
4,9,The young boys are playing outdoors and the ma...,A group of kids is playing in a yard and an ol...,3.7,NEUTRAL


In [3]:
# we take all samples from both sentence A and B
sentences = data['sentence_A'].tolist()
sentences[:5]

['A group of kids is playing in a yard and an old man is standing in the background',
 'A group of children is playing in the house and there is no man standing in the background',
 'The young boys are playing outdoors and the man is smiling nearby',
 'The kids are playing outdoors near a man with a smile',
 'The young boys are playing outdoors and the man is smiling nearby']

In [4]:
# we take all samples from both sentence A and B
sentences = data['sentence_A'].tolist()
sentence_b = data['sentence_B'].tolist()
sentences.extend(sentence_b)  # merge them
len(set(sentences))  # together we have ~4.5K unique sentences

4802

In [5]:
sentences[:5]

['A group of kids is playing in a yard and an old man is standing in the background',
 'A group of children is playing in the house and there is no man standing in the background',
 'The young boys are playing outdoors and the man is smiling nearby',
 'The kids are playing outdoors near a man with a smile',
 'The young boys are playing outdoors and the man is smiling nearby']

In [6]:
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]

In [7]:
# each of these dataset have the same structure, so we loop through each creating our sentences data
for url in urls:
    res = requests.get(url)
    # extract to dataframe
    # data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='skip')
    # add to columns 1 and 2 to sentences list
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

In [8]:
len(set(sentences))

14505

In [9]:
# remove duplicates and NaN
sentences = [word for word in list(set(sentences)) if type(word) is str]

In [10]:
from sentence_transformers import SentenceTransformer
# initialize sentence transformer model
model = SentenceTransformer('bert-base-nli-mean-tokens')
# create sentence embeddings
sentence_embeddings = model.encode(sentences)
sentence_embeddings.shape

  from .autonotebook import tqdm as notebook_tqdm


(14504, 768)

In [11]:
import faiss

In [12]:
d = sentence_embeddings.shape[1]
d

768

In [13]:
index = faiss.IndexFlatL2(d)
index

<faiss.swigfaiss_avx2.IndexFlatL2; proxy of <Swig Object of type 'faiss::IndexFlatL2 *' at 0x7feed51b5d10> >

In [14]:
index.is_trained

True

In [15]:
index.add(sentence_embeddings)

In [16]:
index.ntotal

14504

In [17]:
k = 4
xq = model.encode(["Someone sprints with a football"])

In [18]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[13324  7537  2871   285]]
CPU times: user 2.46 ms, sys: 439 µs, total: 2.9 ms
Wall time: 2.26 ms


In [None]:
data['sentence_A'].iloc[[4586, 10252, 12465, 190]]

In [21]:
import numpy as np
# we have 4 vectors to return (k) - so we initialize a zero array to hold them
vecs = np.zeros((k, d))
# then iterate through each ID from I and add the reconstructed vector to our zero-array
for i, val in enumerate(I[0].tolist()):
    vecs[i, :] = index.reconstruct(val)

In [22]:
vecs.shape

(4, 768)

In [23]:
vecs[0][:100]

array([ 0.01627057,  0.22325902, -0.15037422, -0.30747274, -0.27122441,
       -0.10593193, -0.06460959,  0.04738158, -0.73349029, -0.3765769 ,
       -0.76762778,  0.16902904,  0.53107625,  0.51176697,  1.14415824,
       -0.08562907, -0.67240065, -0.96637058,  0.02545422, -0.21559791,
       -1.25656605, -0.82982141, -0.09824955, -0.21850848,  0.50610214,
        0.10527924,  0.5039683 ,  0.65242988, -1.39458704,  0.65847462,
       -0.21525325, -0.22487473,  0.8181836 ,  0.08464272, -0.76141697,
       -0.28928283, -0.09825852, -0.73046178,  0.07855759, -0.84354687,
       -0.59242094,  0.77471352, -1.20920551, -0.22757928, -1.30733573,
       -0.23081458, -1.31322575,  0.01629071, -0.97285485,  0.19308197,
        0.47424522,  1.18920851, -1.96741307, -0.70061117, -0.2963866 ,
        0.6053372 ,  0.62407416, -0.70340377, -0.86754256,  0.1767319 ,
       -0.19170512, -0.02951975,  0.22623561, -0.16695452, -0.80402535,
       -0.45918918,  0.69675523, -0.24928199, -1.01478684, -0.92

In [24]:
nlist = 50  # how many cells
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [25]:
index.is_trained

False

In [26]:
index.train(sentence_embeddings)
index.is_trained  # check if index is now trained

True

In [27]:
index.add(sentence_embeddings)
index.ntotal  # number of embeddings indexed

14504

In [28]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[13324  7537  2871   285]]
CPU times: user 459 µs, sys: 280 µs, total: 739 µs
Wall time: 515 µs


In [29]:
index.nprobe = 10

In [30]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[13324  7537  2871   285]]
CPU times: user 1.29 ms, sys: 0 ns, total: 1.29 ms
Wall time: 859 µs


In [31]:
index.make_direct_map()

In [32]:
index.reconstruct(7460)[:100]

array([ 0.04719909,  0.27071124, -0.25176895,  0.20996305,  0.11449947,
        0.05592548, -0.1624022 ,  0.35324556, -0.23271006, -0.24057451,
       -1.2727222 ,  1.0075784 ,  0.21982376,  0.85123116,  0.7766814 ,
        0.5149639 , -0.6781561 , -0.4014718 , -0.32394302, -0.04647766,
       -0.48390946, -0.06864288,  0.8759547 , -0.60910803,  0.07459264,
       -0.23768818, -0.825604  ,  0.56349486, -0.37702727,  0.43336022,
        0.07187182, -0.32258922,  0.33846772, -0.6341695 , -1.7811332 ,
       -0.24863796,  0.05044975,  0.35617852,  0.5964581 , -0.17314644,
        0.8332708 ,  0.36916506,  0.24882203, -0.11726724, -0.8684139 ,
       -0.5823152 ,  0.9893843 , -0.540314  , -0.08682217,  0.28553036,
        0.22805563,  0.6131767 , -0.0706891 ,  0.13420382,  0.29612824,
       -0.4199144 , -0.25088763, -0.4078441 ,  0.38194108, -0.24361636,
       -0.340561  ,  0.65562594, -0.44020045, -0.84618604, -0.7984815 ,
        0.5892812 , -0.20308773, -0.56330496, -0.2699528 , -0.25

In [33]:
m = 8  # number of centroid IDs in final compressed vectors
bits = 8 # number of bits in each centroid

quantizer = faiss.IndexFlatL2(d)  # we keep the same L2 distance flat index
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits) 

In [34]:
index.is_trained

False

In [35]:
index.train(sentence_embeddings)

In [36]:
index.add(sentence_embeddings)

In [37]:
index.nprobe = 10  # align to previous IndexIVFFlat nprobe value

In [38]:
%%time
D, I = index.search(xq, k)
print(I)

[[ 285  533  901 1962]]
CPU times: user 552 µs, sys: 0 ns, total: 552 µs
Wall time: 428 µs


In [39]:
[f'{i}: {sentences[i]}' for i in I[0]]

['285: A person playing football is running past an official carrying a football',
 '533: some football players in red jerseys taking another in a white jersey to the ground',
 '901: Two teams are competing in a football match',
 '1962: Different teams are playing football on a field']