In [1]:
import requests 
import pandas as pd 
from io import StringIO 

In [2]:
res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL
2,3,The young boys are playing outdoors and the ma...,The kids are playing outdoors near a man with ...,4.7,ENTAILMENT
3,5,The kids are playing outdoors near a man with ...,A group of kids is playing in a yard and an ol...,3.4,NEUTRAL
4,9,The young boys are playing outdoors and the ma...,A group of kids is playing in a yard and an ol...,3.7,NEUTRAL


In [3]:
sentences = data['sentence_A'].tolist() 
sentences[:5]

['A group of kids is playing in a yard and an old man is standing in the background',
 'A group of children is playing in the house and there is no man standing in the background',
 'The young boys are playing outdoors and the man is smiling nearby',
 'The kids are playing outdoors near a man with a smile',
 'The young boys are playing outdoors and the man is smiling nearby']

In [4]:
sentences = data['sentence_A'].tolist() 
sentence_b = data['sentence_B'].tolist() 

sentences.extend(sentence_b)
len(set(sentences))

4802

In [5]:
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]

In [6]:
for url in urls:
    res = requests.get(url)
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='skip')
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

In [7]:
len(set(sentences))

14505

In [8]:
sentences = [word for word in list(set(sentences)) if type(word) is str]

In [9]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('bert-base-nli-mean-tokens')
sentence_embeddings = model.encode(sentences)
sentence_embeddings.shape 

(14504, 768)

In [10]:
import faiss 

In [11]:
d = sentence_embeddings.shape[1]
d 

768

In [12]:
index = faiss.IndexFlatL2(d)

In [13]:
index.is_trained

True

In [14]:
index.add(sentence_embeddings)
index.ntotal

14504

In [15]:
k = 4
xq = model.encode(['Someone sprints with a football'])

In [16]:
%%time 
D, I = index.search(xq, k)
print(I)

[[12522  1724  8370 11995]]
CPU times: total: 62.5 ms
Wall time: 16 ms


In [17]:
[f'{i}: {sentences[i]}' for i in I[0]]

['12522: A group of football players is running in the field',
 '1724: A group of people playing football is running in the field',
 '8370: Two groups of people are playing football',
 '11995: A person playing football is running past an official carrying a football']

In [18]:
import numpy as np

vecs = np.zeros((k, d))

for i, val in enumerate(I[0].tolist()):
    vecs[i, :] = index.reconstruct(val)

In [19]:
vecs.shape 

(4, 768)

In [20]:
vecs[0][:100]

array([ 0.01627026,  0.22325921, -0.15037383, -0.3074725 , -0.27122411,
       -0.1059322 , -0.06460909,  0.04738213, -0.73349071, -0.37657705,
       -0.76762825,  0.16902907,  0.53107649,  0.5117662 ,  1.14415824,
       -0.08562867, -0.67240065, -0.96637082,  0.0254546 , -0.21559845,
       -1.25656605, -0.82982188, -0.09824985, -0.21850882,  0.50610268,
        0.10527927,  0.50396878,  0.65242976, -1.39458668,  0.65847492,
       -0.21525343, -0.22487457,  0.81818342,  0.08464303, -0.76141709,
       -0.28928316, -0.09825778, -0.73046142,  0.07855811, -0.84354585,
       -0.59242046,  0.7747137 , -1.20920575, -0.22757983, -1.30733597,
       -0.23081483, -1.31322527,  0.01629082, -0.97285444,  0.19308162,
        0.47424546,  1.18920887, -1.96741295, -0.70061129, -0.2963877 ,
        0.60533702,  0.6240744 , -0.70340389, -0.86754185,  0.17673105,
       -0.19170599, -0.02951975,  0.22623511, -0.16695438, -0.80402523,
       -0.45918944,  0.69675517, -0.24928214, -1.01478684, -0.92

In [21]:
nlist = 50 
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [22]:
index.is_trained

False

In [23]:
index.train(sentence_embeddings)
index.is_trained

True

In [24]:
index.add(sentence_embeddings)
index.ntotal

14504

In [25]:
%%time 
D, I = index.search(xq, k)
print(I)

[[12522  1724  8370 11995]]
CPU times: total: 0 ns
Wall time: 997 µs


In [26]:
index.nprobe = 10 

In [27]:
%%time 
D, I = index.search(xq, k)
print(I)

[[12522  1724  8370 11995]]
CPU times: total: 0 ns
Wall time: 5.65 ms


In [28]:
index.make_direct_map()

In [29]:
index.reconstruct(1503)[:100]

array([ 5.11477530e-01,  2.50236094e-01,  5.32061219e-01,  4.22964215e-01,
       -1.05860777e-01, -6.13012254e-01,  1.27440798e+00,  2.29222566e-01,
       -3.43923792e-02,  2.49006867e-01, -1.02004655e-01,  1.76842198e-01,
        4.79001969e-01,  6.37657583e-01,  1.42909408e-01,  8.58738959e-01,
        5.52705664e-04,  1.54480308e-01, -9.28468779e-02,  4.21602428e-01,
       -5.90845942e-01, -3.13356787e-01,  1.68123782e-01, -7.63870716e-01,
        3.96952838e-01,  5.99601924e-01,  2.61421889e-01, -4.54949200e-01,
       -1.18678296e+00,  8.22250724e-01, -9.76284668e-02,  1.15797174e+00,
       -4.72993970e-01,  5.84034383e-01,  4.82119828e-01,  7.12632120e-01,
        1.09804288e-01, -5.02621353e-01, -4.24552932e-02, -2.83914566e-01,
        3.67833495e-01, -5.18662930e-01, -1.66022405e-01,  5.23425758e-01,
       -3.73695433e-01, -2.26753175e-01,  1.63917685e+00,  3.16410512e-01,
        1.70677811e-01, -2.41446778e-01,  2.41915271e-01, -2.77081192e-01,
        6.23954058e-01,  

In [30]:
m = 8 
bits = 8 

quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits)

In [31]:
index.is_trained

False

In [32]:
index.train(sentence_embeddings)

In [33]:
index.add(sentence_embeddings)

In [34]:
index.nprobe = 10 

In [35]:
%%time 
D, I = index.search(xq, k)
print(I)

[[ 1724 12522  6988  9772]]
CPU times: total: 0 ns
Wall time: 998 µs


In [36]:
[f'{i}: {sentences[i]}' for i in I[0]]

['1724: A group of people playing football is running in the field',
 '12522: A group of football players is running in the field',
 '6988: Different football players are teaming on a field',
 '9772: Different teams are playing football on a field']