In [None]:
!pip install faiss-gpu
!pip install sentence_transformers

In [None]:
import requests
from io import StringIO
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss

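## Data preprocessing

Build one de-duplicated sentence corpus from the SICK 2014 training split and several SemEval STS files.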

In [None]:
# SICK 2014 training split: tab-separated sentence pairs with a header row.
res = requests.get(
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt',
    timeout=60,  # never hang the notebook on a stalled connection
)
# Fail loudly on an HTTP error instead of parsing an error page as TSV.
res.raise_for_status()
# create dataframe
data = pd.read_csv(StringIO(res.text), sep='\t')
data.head()

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL
2,3,The young boys are playing outdoors and the ma...,The kids are playing outdoors near a man with ...,4.7,ENTAILMENT
3,5,The kids are playing outdoors near a man with ...,A group of kids is playing in a yard and an ol...,3.4,NEUTRAL
4,9,The young boys are playing outdoors and the ma...,A group of kids is playing in a yard and an ol...,3.7,NEUTRAL


In [None]:
# Flatten both sides of every SICK pair into one corpus:
# all sentence_A values first, followed by all sentence_B values.
sentence_b = data['sentence_B'].tolist()
sentences = data['sentence_A'].tolist() + sentence_b

In [None]:
# SemEval STS files (2012-2015) from the brmson/dataset-sts mirror, given as
# paths relative to the semeval-sts directory. The order only determines the
# order in which their sentences are appended to the corpus.
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/' + rel_path
    for rel_path in (
        '2012/MSRpar.train.tsv',
        '2012/MSRpar.test.tsv',
        '2012/OnWN.test.tsv',
        '2013/OnWN.test.tsv',
        '2014/OnWN.test.tsv',
        '2014/images.test.tsv',
        '2015/images.test.tsv',
    )
]

In [None]:
# Append both sentence columns (1 and 2) of every headerless STS file to the
# corpus. Rows with an extra field ("expected 3 fields, saw 4") are reported
# and skipped rather than aborting the whole load.
for url in urls:
    res = requests.get(url, timeout=60)
    # Fail loudly on an HTTP error instead of parsing an error page as TSV.
    res.raise_for_status()
    # on_bad_lines='warn' is the replacement for error_bad_lines=False (with the
    # default warn_bad_lines=True). The old flag was deprecated in pandas 1.3
    # and removed in 2.0, where passing it raises a TypeError.
    data = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='warn')

    sentences.extend(data[1].tolist())
    sentences.extend(data[2].tolist())



  data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)
Skipping line 191: expected 3 fields, saw 4
Skipping line 206: expected 3 fields, saw 4
Skipping line 295: expected 3 fields, saw 4
Skipping line 695: expected 3 fields, saw 4
Skipping line 699: expected 3 fields, saw 4



  data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)
Skipping line 104: expected 3 fields, saw 4
Skipping line 181: expected 3 fields, saw 4
Skipping line 317: expected 3 fields, saw 4
Skipping line 412: expected 3 fields, saw 5
Skipping line 508: expected 3 fields, saw 4



  data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)


  data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)


  data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)


  data = pd.read_csv(StringIO(res.text), sep='\t', header=None, error_bad_lines=False)


  data = pd.read_csv(S

In [None]:
# Remove duplicates and NaN
# Remove duplicates and non-string entries (empty cells are read as NaN floats).
# dict.fromkeys keeps first-seen order, so each sentence's position -- and hence
# the faiss id returned for it below -- is the same on every fresh kernel.
# Iterating a set() is not reproducible, because str hashing is randomized per process.
sentences = list(dict.fromkeys(s for s in sentences if isinstance(s, str)))

In [None]:
# initialize sentence transformer model
# Pre-trained SBERT checkpoint used for every embedding in this notebook.
# NOTE(review): sentence-transformers marks this checkpoint as deprecated
# (low-quality embeddings). It is kept so results stay comparable with the outputs below.
EMBEDDING_MODEL = 'bert-base-nli-mean-tokens'
model = SentenceTransformer(EMBEDDING_MODEL)
# One embedding row per sentence (see the shape displayed below).
sentence_embeddings = model.encode(sentences)
sentence_embeddings.shape

(14504, 768)

## Exact Search - IndexFlatL2

In [None]:
# Exact (brute-force) search: IndexFlatL2 compares the query with every stored
# vector by L2 distance, so it has no training step.
n_sentences, d = sentence_embeddings.shape  # d = embedding dimensionality
index = faiss.IndexFlatL2(d)

# Some index types must be trained before vectors are added (is_trained returns
# False). A flat index reports True straight away.
index.is_trained

True

In [None]:
# index.add appends. Reset first so that re-running this cell does not store every
# vector twice, which would show up as duplicated search hits.
index.reset()
index.add(sentence_embeddings)
assert index.ntotal == len(sentence_embeddings)
index.ntotal

14504

In [None]:
# Number of nearest neighbours to return per query.
k = 4
# xq is the query batch: encoding a one-element list yields a single query row.
query_text = "Someone sprints with a football"
xq = model.encode([query_text])

In [None]:
%%time
# index.search returns (D, I): the distances and the ids of the k nearest stored
# vectors for each query row. Ids are row positions in `sentences`, because the
# embeddings were added in that order. %%time reports how long this cell took.
D, I = index.search(xq, k)
print(I)

[[ 6694  3096 10629  8186]]
CPU times: user 18.7 ms, sys: 0 ns, total: 18.7 ms
Wall time: 19.3 ms


In [None]:
# Show the sentences behind the ids returned by the exact search above.
# I[0] holds the neighbour ids for the single query. Using it, rather than ids
# copied from an earlier run, keeps the lookup correct on every execution.
pd.DataFrame(sentences).iloc[I[0]]

Unnamed: 0,0
4586,"He urged Congress to ""send me the final bill a..."
10252,A black person is running along a white stand ...
12465,a chain of atoms in a molecule forming a close...
190,Forecasters said warnings might go up for Cuba...


## Approximate Search - IndexIVFFlat

In [None]:
# IVF: a coarse quantizer partitions the vectors into `nlist` cells, and a search
# scans only the cells nearest to the query (see nprobe below).
quantizer = faiss.IndexFlatL2(d)  # assigns vectors to cells by L2 distance
nlist = 50  # how many cells
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [None]:
# Unlike the flat index, the IVF index must learn its cell centroids before use, so this is False.
index.is_trained

False

In [None]:
# We need to train this index
# We need to train this index: train() learns the nlist cell centroids from the
# embeddings, and it must run before add().
index.train(sentence_embeddings)
index.is_trained

True

In [None]:
# index.add appends. Reset first so that re-running this cell does not store every
# vector twice. reset() empties the cells but keeps the trained centroids.
index.reset()
index.add(sentence_embeddings)
assert index.ntotal == len(sentence_embeddings)
index.ntotal

14504

In [None]:
%%time
# Same query against the IVF index. nprobe has not been set yet (faiss default: 1),
# so only the single nearest cell is scanned -- much faster than the flat search.
D, I = index.search(xq, k)  # search
print(I)

[[ 6694  3096 10629  8186]]
CPU times: user 2.02 ms, sys: 0 ns, total: 2.02 ms
Wall time: 2.04 ms


In [None]:
# Improve accuracy by increasing the search scope (nprobe attribute value)
# Improve accuracy by increasing the search scope (nprobe attribute value).
# nprobe is the number of nearest cells scanned per query: more cells give
# better recall at the cost of a slower search.
index.nprobe = 10

In [None]:
%%time
# Repeat the IVF search, now scanning the 10 nearest cells per query.
D, I = index.search(xq, k)  # search
print(I)

[[ 6694  3096 10629  8186]]
CPU times: user 9.7 ms, sys: 2.99 ms, total: 12.7 ms
Wall time: 23.6 ms


## Approximate Search with Product Quantization - IndexIVFPQ

In [None]:
# Product quantization: each d-dim vector is split into m sub-vectors, and each
# sub-vector is stored as the ID of its nearest sub-space centroid.
m = 8  # number of sub-vectors (= centroid IDs per compressed vector)
bits = 8  # bits per centroid ID, i.e. 2**bits centroids per sub-space
# Each vector is therefore compressed to m * bits bits (64 bits here).
# faiss requires the dimension to split evenly into m sub-vectors.
assert d % m == 0, f"d={d} is not divisible by m={m}"

quantizer = faiss.IndexFlatL2(d)  # we keep the same L2 distance flat index
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits)

In [None]:
# IVF-PQ must learn its cell centroids and PQ codebooks before use, so this is False.
index.is_trained

False

In [None]:
# Train the coarse cells and the product-quantizer codebooks on the embeddings;
# this must happen before add().
index.train(sentence_embeddings)
index.is_trained

True

In [None]:
# index.add appends. Reset first so that re-running this cell does not store every
# vector twice. reset() drops the stored codes but keeps the trained quantizers.
index.reset()
index.add(sentence_embeddings)
assert index.ntotal == len(sentence_embeddings)

In [None]:
# Scan the 10 nearest cells per query, the same nprobe used for IndexIVFFlat, so the timings are comparable.
index.nprobe = 10

In [None]:
%%time
# Distances are computed on the compressed PQ codes and are therefore approximate,
# so the returned ids can differ from the exact search results.
D, I = index.search(xq, k)
print(I)

[[ 6714 10629 12700  3100]]
CPU times: user 2.56 ms, sys: 9 µs, total: 2.57 ms
Wall time: 2.59 ms
