In [1]:
import pickle
import numpy as np

data_vocab_path = "../data/jigsaw/data_vocab.bin"

In [2]:
vocab=pickle.load(open(data_vocab_path,'rb'))

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
from pytorch_pretrained_bert.tokenization import BertTokenizer

bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [4]:
bert_vocab_toks = list(bert_tokenizer.vocab.keys())
vocab_toks = set( [w for idx, w in vocab.get_index_to_token_vocabulary().items() ])
len(vocab_toks), len(bert_vocab_toks)

(305140, 30522)

In [5]:
import nmslib, time

M = 25
efC = 200

num_threads = 0
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC, 'post' : 0}
print('Index-time parameters', index_time_params)

Index-time parameters {'M': 25, 'indexThreadQty': 0, 'efConstruction': 200, 'post': 0}


In [6]:
# Space name should correspond to the space name 
# used for brute-force search
space_name='leven'


# Intitialize the library, specify the space, the type of the vector and add data points 
index = nmslib.init(method='hnsw', space=space_name, dtype=nmslib.DistType.INT, data_type=nmslib.DataType.OBJECT_AS_STRING) 
index.addDataPointBatch(bert_vocab_toks)

30522

In [7]:
# Create an index
start = time.time()
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC}
index.createIndex(index_time_params) 
end = time.time() 
print('Index-time parameters', index_time_params)
print('Indexing time = %f' % (end-start))

Index-time parameters {'M': 25, 'indexThreadQty': 0, 'efConstruction': 200}
Indexing time = 5.272334


In [8]:
# Setting query-time parameters
efS = 1000
K=10
query_time_params = {'efSearch': efS}
print('Setting query-time parameters', query_time_params)
index.setQueryTimeParams(query_time_params)

Setting query-time parameters {'efSearch': 1000}


In [17]:
query_arr = ['fuuck', 'фuck']
K=10

In [18]:
# Querying
query_qty = len(query_arr)
start = time.time() 
nbrs = index.knnQueryBatch(query_arr, k = K, num_threads = num_threads)
end = time.time() 
print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' % 
      (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty))

kNN time total=0.006527 (sec), per query=0.003263 (sec), per query adjusted for thread number=0.000000 (sec)


In [19]:
nbrs[0][0], nbrs[0][1]

(array([ 6616,  4248,  4744,  6735,  8057,  9457, 10131, 18029, 11891,
        12722], dtype=int32),
 array([1, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32))

In [20]:
for query_id in range(query_qty):
    print('Query:', query_id)
    qty = len(nbrs[query_id][0])
    for i in range(qty):
        print("word:", bert_vocab_toks[nbrs[query_id][0][i]], "distance: ", nbrs[query_id][1][i])
    print('===============')

Query: 0
word: fuck distance:  1
word: quick distance:  2
word: truck distance:  2
word: luck distance:  2
word: chuck distance:  2
word: duck distance:  2
word: buck distance:  2
word: tuck distance:  2
word: suck distance:  2
word: ##uck distance:  2
Query: 1
word: fuck distance:  2
word: truck distance:  2
word: luck distance:  2
word: chuck distance:  2
word: duck distance:  2
word: buck distance:  2
word: tuck distance:  2
word: suck distance:  2
word: ##uck distance:  2
word: snuck distance:  2
