In [2]:
import json
import os
import tempfile
from pathlib import Path

import faiss
import mlflow
import numpy as np
import torch

from src.data.text_retriever import TextRetriever
from src.models.knrm_model import KNRM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Paths are resolved against the parent of the current working directory
# (the project root when the notebook is launched from its own folder).
PARENT_DIR = os.path.abspath(os.pardir)
DOCUMENTS_PATH = os.path.join(PARENT_DIR, 'data', 'processed', 'documents.json')
ML_RUNS_PATH = os.path.join(PARENT_DIR, 'models', 'ml_runs')
EXPERIMENT_NAME = 'QuoraRankingExtendedTraining'

# Mapping of document id (as a string) -> document text.
with open(DOCUMENTS_PATH, encoding='utf-8') as documents_file:
    documents = json.load(documents_file)

# Path.as_uri() yields a well-formed file URI on every OS; concatenating
# 'file:///' with an absolute path produced 'file:////...' on POSIX and
# backslash-separated URIs on Windows. Setting the URI is cheap, so no guard.
mlflow.set_tracking_uri(Path(ML_RUNS_PATH).as_uri())

experiment = mlflow.get_experiment_by_name(EXPERIMENT_NAME)
if experiment is None:
    raise RuntimeError(
        f"MLflow experiment {EXPERIMENT_NAME!r} not found in {ML_RUNS_PATH}")
EXP_ID = experiment.experiment_id

runs = mlflow.search_runs(experiment_ids=[EXP_ID])
if runs.empty:
    raise RuntimeError(f"MLflow experiment {EXPERIMENT_NAME!r} has no runs")
# search_runs orders by start_time (newest first) by default, so row 0 is the
# most recent run.
RUN_ID = runs['run_id'].iloc[0]

MODEL_URI = f"runs:/{RUN_ID}/model"
VOCAB_URI = f"runs:/{RUN_ID}/vocab"
knrm = mlflow.pytorch.load_model(MODEL_URI)
vocab = mlflow.artifacts.load_dict(VOCAB_URI)

In [4]:
# Number of documents loaded from DOCUMENTS_PATH.
len(documents)

537916

In [5]:
# Split the id -> text mapping into two parallel lists; dict iteration order
# is the same for keys and values, so position i in both refers to one document.
idxs = [int(doc_id) for doc_id in documents]
docs = list(documents.values())

In [6]:
# Represent every document as the mean of its tokens' KNRM embedding vectors.
oov_val = vocab['OOV']
tr = TextRetriever()
# state_dict() returns detached tensors; .cpu() keeps .numpy() working even
# when the model was loaded onto a GPU.
emb_layer = knrm.embeddings.state_dict()['weight'].cpu()

embeddings = []
for doc in docs:
    token_ids = [vocab.get(w, oov_val) for w in tr.lower_and_tokenize_words(doc)]
    if not token_ids:
        # Averaging zero rows gives an all-NaN vector, which corrupts FAISS L2
        # distances; fall back to the OOV embedding instead.
        token_ids = [oov_val]
    embeddings.append(emb_layer[token_ids].mean(dim=0).numpy())

# One row per document, float32 as FAISS requires.
embeddings = np.stack(embeddings).astype(np.float32)

In [7]:
# Embedding dimensionality (last axis of the 2-D document-embedding matrix).
embeddings.shape[-1]

50

In [8]:
# Exact (brute-force) L2 search over the averaged document embeddings.
flat_index = faiss.IndexFlatL2(embeddings.shape[1])
# IndexIDMap makes search return the original document ids rather than row
# positions.
index = faiss.IndexIDMap(flat_index)
# FAISS ids are 64-bit; NumPy's default integer type is platform-dependent
# (32-bit on Windows with NumPy < 2), so make the dtype explicit.
index.add_with_ids(embeddings, np.asarray(idxs, dtype=np.int64))

In [10]:
def get_memory(index):
    """Return the size in bytes of ``index`` as serialized by ``faiss.write_index``.

    The index is written into a fresh temporary directory that is always
    removed afterwards. This avoids overwriting a file in the current working
    directory, and no stray file is left behind if writing fails.

    Args:
        index: A FAISS index object.

    Returns:
        int: Size of the serialized index file in bytes.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'index.faiss')
        faiss.write_index(index, path)
        return os.path.getsize(path)

# Serialized size of the document index, in bytes.
index_size_bytes = get_memory(index)
print(index_size_bytes)

111886618


In [11]:
query = 'How is the life of a math student?'

# Embed the query the same way as the documents: mean of its token embeddings.
q_vector = [vocab.get(token, oov_val) for token in tr.lower_and_tokenize_words(query)]
if not q_vector:
    # A query with no tokens would average zero rows into an all-NaN vector.
    q_vector = [oov_val]
# index.search takes a float32 matrix with one row per query.
q_emb = emb_layer[q_vector].mean(dim=0).reshape(1, -1).cpu().numpy().astype(np.float32)

In [12]:
# Ids of the 100 nearest documents by L2 distance; distances are discarded.
I = index.search(q_emb, k=100)[1]

In [13]:
def text_to_token_ids(text_list, vocab, max_len=30):
    """Convert texts into a fixed-width matrix of vocabulary ids for KNRM.

    Each text is lower-cased and tokenized with the notebook-level
    ``TextRetriever`` instance ``tr``. Tokens missing from ``vocab`` map to
    ``vocab["OOV"]``. Each row is truncated to ``max_len`` tokens and
    right-padded with id 0.

    Args:
        text_list: Iterable of strings.
        vocab: Mapping from token to integer id; must contain the key "OOV".
        max_len: Width of the output matrix (default 30).

    Returns:
        torch.LongTensor of shape (number of texts, max_len).
    """
    texts = list(text_list)
    oov_id = vocab["OOV"]  # looked up once instead of once per token
    # Start from an all-padding matrix. Padding id 0 is assumed to be what the
    # model expects -- TODO confirm against the vocab builder.
    token_ids = torch.zeros((len(texts), max_len), dtype=torch.long)
    for row, text in enumerate(texts):
        # Truncate long texts: padding by a negative count used to leave
        # ragged rows, and torch.LongTensor raised on them.
        ids = [vocab.get(tok, oov_id) for tok in tr.lower_and_tokenize_words(text)][:max_len]
        if ids:
            token_ids[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
    return token_ids

In [14]:
# Keep only real hits (FAISS marks empty result slots with id -1) as
# (doc_id, text) pairs.
cands = [(str(doc_id), documents[str(doc_id)]) for doc_id in I[0] if doc_id != -1]

# KNRM scores each (query, candidate) pair, so the query is repeated per candidate.
inputs = {
    'query': text_to_token_ids([query for _ in cands], vocab),
    'document': text_to_token_ids([text for _, text in cands], vocab),
}
scores = knrm.predict(inputs)

In [15]:
# Positions of the 10 highest KNRM scores, best first, mapped back to candidates.
res_ids = scores.reshape(-1).argsort(descending=True)[:10]
res = [cands[pos] for pos in res_ids.tolist()]

In [16]:
# The query next to its top-ranked (doc_id, text) candidates.
(query, res)

('How is the life of a math student?',
 [('537770',
   'How hard is it to get into Art Center College of Design, CA for a middle-class Indian student with science background?'),
  ('337183',
   "What does a librarian with a library science master's degree do? What is it like a day in the life of a librarian?"),
  ('213221',
   'How is the life of a math student? Could you describe your own experiences?'),
  ('38754',
   "What is more important in a letter of recommendation, the teacher's designation or whatever the teacher writes about the student?"),
  ('144409',
   'Since each of us has been a student first, what is the harshest thing a teacher taught you?'),
  ('20643',
   'What is the best time table for a student of maths and science of class 11th?'),
  ('29651',
   'What is the importance of clubs(technical) in a life of a engineering student?'),
  ('203537', 'How is the life/experience of a student at MBBS college?'),
  ('250533',
   'What is the best field of engineering to a s

In [38]:
# NOTE(review): `emb_matrix` is not defined anywhere in this notebook, so this
# cell only works with leftover kernel state and fails on Restart & Run All.
# TODO: define it here or delete this exploratory cell.
emb_matrix[0, :3], emb_matrix[1, :3]

(array([0.0195254 , 0.08607575, 0.04110535]),
 array([ 0.02807871, -0.02455939,  0.19534954]))

In [43]:
# Element-wise mean of the first three components of rows 0 and 1.
emb_matrix[:2, :3].mean(axis=0)

array([0.02380205, 0.03075818, 0.11822744])

In [None]:
# NOTE(review): this is a draft of a method body. It references `self` (with
# `vocab`, `model`, `_simple_preproc` and `index` attributes) plus the
# notebook-level `docs`/`idxs`, so it cannot run as a plain notebook cell.
oov_val = self.vocab["OOV"]
# state_dict() tensors are detached; .cpu() keeps .numpy() working for GPU models.
emb_layer = self.model.embeddings.state_dict()['weight'].cpu()

embeddings = []
for d in docs:
    token_ids = [self.vocab.get(w, oov_val) for w in self._simple_preproc(d)]
    if not token_ids:
        # Averaging zero rows yields NaN, which corrupts L2 search; use OOV.
        token_ids = [oov_val]
    embeddings.append(emb_layer[token_ids].mean(dim=0).numpy())
embeddings = np.stack(embeddings).astype(np.float32)

self.index = faiss.IndexFlatL2(embeddings.shape[1])
self.index = faiss.IndexIDMap(self.index)
# FAISS ids are 64-bit; don't rely on NumPy's platform-dependent default int.
self.index.add_with_ids(embeddings, np.asarray(idxs, dtype=np.int64))
index_size = self.index.ntotal
global index_is_ready
index_is_ready = True

In [17]:
import numpy as np

# Synthetic data for the FAISS getting-started example; seeded for reproducibility.
d, nb, nq = 64, 100000, 100   # vector dimension, database size, number of queries
np.random.seed(1234)

xb = np.random.random((nb, d)).astype('float32')
# Add i / 1000 to the first coordinate of row i.
xb[:, 0] += np.arange(nb) / 1000.

xq = np.random.random((nq, d)).astype('float32')
xq[:, 0] += np.arange(nq) / 1000.

In [18]:
import faiss

# NOTE(review): this rebinds `index`, replacing the KNRM document index built
# earlier; re-running earlier cells (e.g. get_memory(index)) now sees this one.
index = faiss.IndexFlatL2(d)    # exact L2 search over d-dimensional vectors
print(index.is_trained)         # flat indexes need no training step
index.add(xb)
print(index.ntotal)             # number of stored vectors

True
100000


In [19]:
k = 4  # nearest neighbours to retrieve per query

# Sanity check: each database vector should come back as its own nearest
# neighbour at distance 0.
D, I = index.search(xb[:5], k)
print(I)
print(D)

# Actual search with the query vectors; show the first and the last five rows.
D, I = index.search(xq, k)
for rows in (I[:5], I[-5:]):
    print(rows)

[[  0 393 363  78]
 [  1 555 277 364]
 [  2 304 101  13]
 [  3 173  18 182]
 [  4 288 370 531]]
[[0.        7.175174  7.2076287 7.251163 ]
 [0.        6.323565  6.684582  6.799944 ]
 [0.        5.7964087 6.3917365 7.2815127]
 [0.        7.277905  7.5279875 7.6628447]
 [0.        6.763804  7.295122  7.368814 ]]
[[ 381  207  210  477]
 [ 526  911  142   72]
 [ 838  527 1290  425]
 [ 196  184  164  359]
 [ 526  377  120  425]]
[[ 801  781  933  385]
 [1073  786 1076  381]
 [ 549  244  100 1008]
 [ 917  140  965   68]
 [ 511  789  225  781]]


In [15]:
# NOTE(review): 66 is a hard-coded database row. The recorded output shows
# 5-component vectors although d = 64 above, so it is stale -- re-run to refresh.
(xq[0], xb[66])

(array([0.012596  , 0.6246662 , 0.56637293, 0.95298845, 0.08366151],
       dtype=float32),
 array([0.11797123, 0.5159861 , 0.40499622, 0.9996496 , 0.10857001],
       dtype=float32))

In [16]:
# Distance matrix from the most recent `index.search` call.
# NOTE(review): the recorded output (3 rows) predates the current cells; with
# nq = 100 queries it would have 100 rows. Re-run to refresh.
D

array([[0.05175545, 0.13396178, 0.14332213, 0.15562324],
       [0.05284711, 0.05362   , 0.06072539, 0.07184806],
       [0.0321276 , 0.04367728, 0.08774583, 0.0887173 ]], dtype=float32)