https://www.pinecone.io/learn/series/faiss/faiss-tutorial/ (ref)

In [1]:
import requests
from io import StringIO
import pandas as pd

In [2]:
url = 'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/sick2014/SICK_train.txt'
res = requests.get(url)
data_df = pd.read_csv(StringIO(res.text), sep='\t')
data_df.head(2)

Unnamed: 0,pair_ID,sentence_A,sentence_B,relatedness_score,entailment_judgment
0,1,A group of kids is playing in a yard and an ol...,A group of boys in a yard is playing and a man...,4.5,NEUTRAL
1,2,A group of children is playing in the house an...,A group of kids is playing in a yard and an ol...,3.2,NEUTRAL


In [3]:
sentences = []
sentences.extend(data_df['sentence_A'].tolist())
sentences.extend(data_df['sentence_B'].tolist())
sentences = list(set(sentences))
print('total:', len(sentences))
sentences[:2]

total: 4802


['There is no dog running on the road',
 'Two children are rolling in muddy water']

In [4]:
len(list(set(sentences)))

4802

In [5]:
urls = [
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.train.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/MSRpar.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2012/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2013/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/OnWN.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2014/images.test.tsv',
    'https://raw.githubusercontent.com/brmson/dataset-sts/master/data/sts/semeval-sts/2015/images.test.tsv'
]

In [6]:
for url in urls:
    res = requests.get(url)
    data_df = pd.read_csv(StringIO(res.text), sep='\t', header=None, on_bad_lines='skip')
    sentences.extend(data_df[1].tolist())
    sentences.extend(data_df[2].tolist())
sentences = list(set(sentences))
print('total:', len(sentences))
sentences[:2]

total: 14505


['A young lady with light brown hair is wearing a red necklace, a sweatshirt and earrings and is smiling',
 'A skateboarder jumps off the stairs.']

In [7]:
sentences = [word for word in sentences if type(word)==str]
print('total:', len(sentences))

total: 14504


In [10]:
from sentence_transformers import SentenceTransformer

In [14]:
# !pip install sentence_transformers
# !pip -q install faiss-cpu

In [15]:
model = SentenceTransformer('bert-base-nli-mean-tokens')

In [16]:
sentence_embeddings = model.encode(sentences, show_progress_bar=True)

Batches:   0%|          | 0/454 [00:00<?, ?it/s]

In [22]:
sentence_embeddings.shape

(14504, 768)

In [23]:
import faiss

In [24]:
d = sentence_embeddings.shape[1]
d

768

In [25]:
index = faiss.IndexFlatL2(d)
index.is_trained

True

In [26]:
index.add(sentence_embeddings)

In [27]:
k = 4
query = "how to paly piano?"
query_emb = model.encode([query])
query_emb.shape

(1, 768)

In [28]:
%%time
D, I = index.search(query_emb, k)
print(I)

[[12719  1419   788  8775]]
CPU times: user 5.63 ms, sys: 991 µs, total: 6.62 ms
Wall time: 10.7 ms


In [29]:
for idx in I[0]:
    print(sentences[idx])

Somebody is playing the piano
Someone is playing a piano
The piano is being played by someone
A person is playing a piano


In [30]:
nlist = 50
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)
index.is_trained

False

In [31]:
index.train(sentence_embeddings)

In [32]:
index.is_trained

True

In [33]:
%%time
index.add(sentence_embeddings)

CPU times: user 95.2 ms, sys: 0 ns, total: 95.2 ms
Wall time: 104 ms


In [34]:
index.ntotal

14504

In [35]:
%%time
D, I = index.search(query_emb, k)
print(I)

[[ 9609  8874 11352  3810]]
CPU times: user 2.64 ms, sys: 0 ns, total: 2.64 ms
Wall time: 2.71 ms


In [36]:
for idx in I[0]:
    print(sentences[idx])

A kid is playing the piano
A girl is playing the piano
A little girl is playing the piano
A cat song is being played on a piano


In [37]:
index.nprobe

1

In [38]:
index.nprobe = 10

In [39]:
%%time
D, I = index.search(query_emb, k)
print(I)

[[12719  1419   788  8775]]
CPU times: user 2.83 ms, sys: 0 ns, total: 2.83 ms
Wall time: 1.67 ms


In [40]:
for idx in I[0]:
    print(sentences[idx])

Somebody is playing the piano
Someone is playing a piano
The piano is being played by someone
A person is playing a piano


In [41]:
m = 8
bits = 8

quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, bits)
index.is_trained

False

In [42]:
index.train(sentence_embeddings)

In [43]:
index.nprobe = 10

In [44]:
index.add(sentence_embeddings)

In [45]:
%%time
D, I = index.search(query_emb, k)
print(I)

[[ 788 1419 5028 7234]]
CPU times: user 2.58 ms, sys: 0 ns, total: 2.58 ms
Wall time: 2.22 ms


In [46]:
for idx in I[0]:
    print(sentences[idx])

The piano is being played by someone
Someone is playing a piano
Someone is playing piano
A piano is being played by a person
