In [1]:
import requests
from io import StringIO
import pandas as pd

res = requests.get('https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt')
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')

sentences = data['sentence_A'].tolist()

sentences = data['sentence_A'].tolist()
sentence_b = data['sentence_B'].tolist()
sentences.extend(sentence_b)  # merge them

In [2]:
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]

# each of these dataset have the same structure, so we loop through each creating our sentences data
for url in urls:
    res = requests.get(url)
    # extract to dataframe
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='skip')
    # add to columns 1 and 2 to sentences list
    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())

In [3]:
# remove duplicates and NaN
sentences = [word for word in list(set(sentences)) if type(word) is str]

In [4]:
from sentence_transformers import SentenceTransformer
# initialize sentence transformer model
model = SentenceTransformer('bert-base-nli-mean-tokens')
# create sentence embeddings
sentence_embeddings = model.encode(sentences)

In [5]:
import faiss
d = sentence_embeddings.shape[1]
index = faiss.IndexFlatL2(d)

index.add(sentence_embeddings)

print(index.ntotal)

14504


In [6]:
k = 4
xq = model.encode(["Someone sprints with a football"])

In [7]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 4851 13499  1592  2080]]
CPU times: user 1.68 ms, sys: 1.47 ms, total: 3.15 ms
Wall time: 2.39 ms


In [8]:
df = pd.DataFrame(sentences)
df.iloc[I[0]]

Unnamed: 0,0
4851,A group of football players is running in the ...
13499,A group of people playing football is running ...
1592,Two groups of people are playing football
2080,A person playing football is running past an o...


In [9]:
import numpy as np

# we have 4 vectors to return (k) - so we initialize a zero array to hold them
vecs = np.zeros((k, d))
# then iterate through each ID from I and add the reconstructed vector to our zero-array
for i, val in enumerate(I[0].tolist()):
    vecs[i, :] = index.reconstruct(val)

In [10]:
vecs.shape

(4, 768)

In [11]:
vecs[0][:100]

array([ 0.01627016,  0.22325927, -0.15037419, -0.3074725 , -0.27122453,
       -0.1059317 , -0.06460921,  0.04738173, -0.73349053, -0.37657681,
       -0.76762795,  0.16902876,  0.53107649,  0.51176685,  1.14415848,
       -0.08562913, -0.67240077, -0.96637076,  0.02545468, -0.21559817,
       -1.25656593, -0.82982165, -0.09824992, -0.21850854,  0.50610226,
        0.10527941,  0.50396854,  0.65242964, -1.39458752,  0.6584745 ,
       -0.21525328, -0.22487439,  0.81818342,  0.08464288, -0.76141691,
       -0.28928283, -0.09825843, -0.7304619 ,  0.07855771, -0.84354633,
       -0.59242111,  0.77471358, -1.20920551, -0.22757958, -1.30733597,
       -0.23081468, -1.31322515,  0.01629069, -0.97285467,  0.19308208,
        0.47424546,  1.18920863, -1.96741295, -0.70061129, -0.29638684,
        0.60533673,  0.62407446, -0.70340395, -0.86754221,  0.17673145,
       -0.191705  , -0.02951988,  0.22623558, -0.16695474, -0.80402523,
       -0.45918933,  0.69675493, -0.24928206, -1.01478684, -0.92

In [12]:
nlist = 50  # how many cells
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [13]:
index.is_trained

False

In [14]:
index.train(sentence_embeddings)
index.is_trained  # check if index is now trained

True

In [15]:
index.add(sentence_embeddings)
index.ntotal  # number of embeddings indexed

14504

In [16]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 4851 13499  8199 13258]]
CPU times: user 7.94 ms, sys: 8.76 ms, total: 16.7 ms
Wall time: 852 µs


In [17]:
index.nprobe = 10

In [18]:
%%time
D, I = index.search(xq, k)  # search
print(I)

[[ 4851 13499  1592  2080]]
CPU times: user 317 µs, sys: 349 µs, total: 666 µs
Wall time: 505 µs


In [19]:
index.make_direct_map()

In [20]:
index.reconstruct(7460)[:100]

array([ 3.04394990e-01, -2.84085423e-01,  2.02326822e+00,  4.01441276e-01,
        1.29176721e-01,  4.74078238e-01, -5.89090586e-01,  5.95174670e-01,
       -4.31007594e-01, -2.53566921e-01, -7.75542676e-01,  2.09441051e-01,
        5.23927808e-01,  9.53339696e-01, -3.38644058e-01, -3.12875994e-02,
       -4.28985327e-01, -6.78114772e-01,  1.25197813e-01, -8.27183664e-01,
        3.72428745e-01,  7.88508773e-01,  4.12336707e-01, -5.37086129e-01,
       -8.79266977e-01, -7.75970995e-01, -4.84306544e-01, -1.50105238e+00,
       -2.17005953e-01,  1.39823020e-01,  1.43711686e-01, -6.38287663e-01,
        6.29600883e-01,  1.16711289e-01, -3.79063159e-01,  1.61511853e-01,
        6.17935359e-01, -3.27075124e-02,  6.11150414e-02, -5.80821753e-01,
        3.92929882e-01,  4.34127003e-01,  9.69810486e-01,  9.00198638e-01,
       -3.77318382e-01, -4.86120760e-01,  6.72118425e-01,  1.16454877e-01,
        3.77468318e-02, -7.94528961e-01, -1.45748162e+00, -1.75383747e-01,
        4.08601820e-01,  

In [21]:
m = 8  # number of centroid IDs in final compressed vectors
bits = 8 # number of bits in each centroid

quantizer = faiss.IndexFlatL2(d)  # we keep the same L2 distance flat index
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits) 

In [22]:
index.is_trained

False

In [23]:
index.train(sentence_embeddings)

In [24]:
index.add(sentence_embeddings)

In [25]:
index.nprobe = 10  # align to previous IndexIVFFlat nprobe value

In [26]:
%%time
D, I = index.search(xq, k)
print(I)

[[11170  6510  9088  2080]]
CPU times: user 4.06 ms, sys: 4.77 ms, total: 8.83 ms
Wall time: 301 µs


In [27]:
[f'{i}: {sentences[i]}' for i in I[0]]

['11170: A football player is running past an official carrying a football',
 '6510: A football player kicks the ball.',
 '9088: Football players are on the field.',
 '2080: A person playing football is running past an official carrying a football']