In [7]:
import re
import numpy as np
from nltk.corpus import stopwords

# Load Embeddings

In [8]:
from sklearn.preprocessing import normalize

In [9]:
def load_starspace(file):
    """Parse a tab-separated embedding table into a dict.

    Each non-blank line holds a token followed by its vector components,
    all separated by tabs.

    Args:
        file: iterable of text lines (e.g. an open text file).

    Returns:
        dict mapping each token to a 1-D float32 numpy array.

    Raises:
        ValueError: if a vector component cannot be parsed as a float.
    """
    result = {}
    for line in file:
        stripped = line.rstrip()
        # A blank or whitespace-only line would otherwise register the empty
        # string as a token with a zero-length vector.
        if not stripped:
            continue
        word, *emb = stripped.split('\t')
        result[word] = np.array([float(x) for x in emb], dtype='float32')
    return result

In [11]:
# Read the tab-separated embedding table into a token -> vector dict.
with open('../embeddings/starspace.emb.tsv', encoding='utf-8') as f:
    starspace_emb = load_starspace(f)

In [12]:
# Token pattern: starts with a letter, '$' or '#', ends with a letter, digit,
# '$' or '#'; dots and hyphens may only appear in between. Because a first and
# a last character are both required, single-character tokens never match.
token_re = re.compile(r'([a-zA-Z$#][\.\-]*[0-9a-zA-Z\.\-$#]*[0-9a-zA-Z$#])')

In [13]:
# English stop words, kept as a set so the per-token membership test in
# process_text is a hash lookup instead of a scan over the whole list.
stop_en = set(stopwords.words('english'))

In [14]:
def process_text(text):
    """Lower-case ``text``, tokenise it with ``token_re`` and drop stop words.

    Returns the surviving tokens as a list, in order of appearance.
    """
    kept = []
    for token in token_re.findall(text.lower()):
        if token in stop_en:
            continue
        kept.append(token)
    return kept

In [15]:
def text_to_emb(text, emb, emb_shape=100):
    """Embed ``text`` as the normalised mean of its known token vectors.

    Args:
        text: raw string; tokenised with ``process_text``.
        emb: mapping from token to vector; tokens missing from it are ignored.
        emb_shape: length of the all-zero vector returned when no token of
            ``text`` is present in ``emb``.

    Returns:
        A 1-D numpy array.
    """
    vectors = np.array([emb[tok] for tok in process_text(text) if tok in emb])
    # No known token at all: fall back to a zero vector of the requested size.
    if vectors.shape[0] == 0:
        return np.zeros((emb_shape,))
    mean_row = np.mean(vectors, axis=0).reshape(1, -1)
    return normalize(mean_row)[0]
    

# Generate Post Embeddings

In [16]:
import sqlite3
import pickle
from tqdm import tqdm_notebook

In [17]:
# Connection to the SQLite database holding the Posts table queried below.
con = sqlite3.connect('../data/original/filtered_posts.db')

In [18]:
# Cursor used for the Posts queries in this section.
cur = con.cursor()

In [19]:
# Number of posts that pass the Score > 3 filter used throughout this section.
list(cur.execute('select count(*) from Posts where Score > 3;'))

[(856648,)]

In [None]:
# NOTE(review): this cell was never executed and reads `ids` before any cell
# above defines it, so it raises NameError on a fresh kernel. If run after
# `ids` exists it nests the list, breaking the `ids[...]` lookups further
# down. Looks like a leftover — confirm and remove.
ids = [ids]

In [19]:
# Collect (Id, Title) for every post with Score > 3, in query order.
ids_titles = [
    (post_id, title)
    for post_id, title in tqdm_notebook(cur.execute('select Id, Title from Posts where Score > 3;'))
]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [20]:
# Embed every post with Score > 3 from its title and body joined by '. '.
# NOTE(review): rows of post_embeddings are later matched to ids_titles by
# position, which assumes both queries return rows in the same order — confirm.
post_embeddings = []
for post_id, title, body in tqdm_notebook(cur.execute('select Id, Title, Body from Posts where Score > 3;')):
    post_embeddings.append(text_to_emb(title + '. ' + body, starspace_emb))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [21]:
# Stack the per-post vectors into a single 2-D array, one row per post.
post_embeddings = np.array(post_embeddings)

In [22]:
# Sanity check: expect (number of posts, embedding size).
post_embeddings.shape

(856648, 100)

In [23]:
# Persist the embedding matrix so the Ranking section can reload it.
with open('../data/original/post_embeddings.pkl4', 'wb') as f:
    pickle.dump(post_embeddings, f, protocol=4)

In [20]:
# Drop the in-memory copy; the Ranking section reloads it from disk.
del post_embeddings

In [20]:
# Persist the (Id, Title) pairs alongside the embeddings.
with open('../data/original/ids_titles.pkl4', 'wb') as f:
    pickle.dump(ids_titles, f, protocol=4)

In [21]:
# Post ids only, in the same order as ids_titles.
ids = [pair[0] for pair in ids_titles]

In [22]:
# Persist the bare post ids.
with open('../data/original/ids.pkl4', 'wb') as f:
    pickle.dump(ids, f, protocol=4)

# Ranking

In [9]:
import pickle
from sklearn.metrics.pairwise import cosine_similarity

In [10]:
from tqdm import tqdm_notebook

In [11]:
# Reload the embedding matrix written in the previous section.
# pickle.load can execute code embedded in the file; only load files that
# this notebook produced.
with open('../data/original/post_embeddings.pkl4', 'rb') as f:
    post_embeddings = pickle.load(f)

In [23]:
def find_similar_post(query, post_embeddings, emb, top_k=5, chunk_size=100000):
    """Brute-force cosine-similarity search of ``query`` against every post.

    Args:
        query: free-text query string.
        post_embeddings: 2-D array with one post embedding per row.
        emb: token-to-vector mapping handed to ``text_to_emb``.
        top_k: maximum number of matches to return.
        chunk_size: number of rows scored per ``cosine_similarity`` call.

    Returns:
        Up to ``top_k`` ``(row_index, similarity)`` pairs sorted by
        descending similarity; each similarity is a one-element array.
    """
    query_emb = text_to_emb(query, emb).reshape(1, -1)
    best = []
    for start in range(0, post_embeddings.shape[0], chunk_size):
        sims = cosine_similarity(post_embeddings[start:start + chunk_size], query_emb)
        # The final chunk may hold fewer than top_k rows, and argpartition
        # raises when asked for a position outside the array.
        k = min(top_k, sims.shape[0])
        idxs = np.argpartition(sims, -k, axis=None)[-k:]
        # Shift chunk-local positions back to row indices of post_embeddings.
        best += list(zip(idxs + start, sims[idxs]))
        best = sorted(best, key=lambda x: x[1], reverse=True)[:top_k]
    return best

In [24]:
# Smoke test of the brute-force search; each result is
# (row index into post_embeddings, similarity).
post_ids = find_similar_post('html change button color', post_embeddings, starspace_emb)
print(post_ids)

[(356397, array([0.882248])), (67568, array([0.83382447])), (486906, array([0.81986725])), (587557, array([0.81450965])), (801499, array([0.78309313]))]


In [27]:
# Translate row indices into (Id, Title) pairs. Requires ids_titles to still
# be in memory and in the same row order as post_embeddings.
for post_id in post_ids:
    print(ids_titles[post_id[0]])

(11176365, 'How do I change the text color of a Button?')
(1819878, 'Changing button color programmatically')
(17201401, 'Button text color in Extjs 4')
(23819847, 'How to change color of a button when clicked using bootstrap?')
(42549031, 'How to change UIDocumentInteractionController Done button text and background color')


# Trying LSH

In [48]:
# Let sqlite3 bind numpy integers as query parameters by converting them to
# plain Python ints.
sqlite3.register_adapter(np.int64, int)
sqlite3.register_adapter(np.int32, int)

In [23]:
from sklearn.preprocessing import normalize
from itertools import chain
from collections import defaultdict
import gc

In [33]:
# 100 random projection matrices of shape (100, 20), entries shifted to be
# centred on zero; each one yields a 20-bit signature per embedding.
# NOTE(review): no random seed is set, and the cell below that loads
# lsh_hash_functions_100.pkl rebinds this name, so these freshly drawn
# matrices are discarded when the notebook is run top to bottom.
hash_functions = [np.random.rand(100, 20) - 0.5 for _ in range(100)]

In [34]:
#with open('../data/original/lsh_hash_functions_100.pkl', 'wb') as f:
#    pickle.dump(hash_functions, f)

In [35]:
# Load the saved hash functions so bucket ids match what is stored on disk.
# This replaces any hash_functions value generated earlier in the session.
with open('../data/original/lsh_hash_functions_100.pkl', 'rb') as f:
    hash_functions = pickle.load(f)

In [36]:
def get_hash(hash_function, emb):
    """Project ``emb`` onto ``hash_function`` and keep only the signs.

    Returns an int8 array holding 1 where the projection is strictly
    positive and 0 everywhere else.
    """
    projection = np.dot(emb, hash_function)
    positive = projection > 0
    return positive.astype('int8')

In [37]:
def get_buckets(hash_functions, emb):
    """Compute the bucket id of ``emb`` under each hash function.

    The bit signature produced by ``get_hash`` is read as a binary number
    whose first bit is the most significant one.

    Returns:
        A list with one entry per hash function; each entry holds the
        bucket id(s) for the row(s) of ``emb``.
    """
    buckets = []
    for hash_function in hash_functions:
        bits = get_hash(hash_function, emb)
        # Powers of two in descending order: 2**(n_bits - 1), ..., 2, 1.
        weights = 1 << np.arange(bits.shape[-1] - 1, -1, -1)
        buckets.append(bits.dot(weights))
    return buckets

In [38]:
# Rebinds `con` from the posts database to the LSH bucket registry database.
con = sqlite3.connect('../data/original/bucket_registry_100.db')

In [39]:
# Cursor on the bucket registry database (replaces the posts cursor).
cur = con.cursor()

In [40]:
# One row per (bucket, post) membership; created only if not already present.
list(cur.execute(
    'create table if not exists BucketRegistry('
    'BucketId int,'
    'PostIndex int);'
))

[]

In [41]:
# Index BucketId, the column the per-bucket lookups filter on.
list(cur.execute('create index if not exists bucket_bucket_id on BucketRegistry(BucketId);'))

[]

In [42]:
# Force a garbage-collection pass before the memory-heavy bucket computation.
gc.collect()

0

In [43]:
# Register every post under its bucket for each hash function, committing
# once per hash function. Iterating the list directly (rather than an
# enumerate object) lets the progress bar know the total.
# NOTE(review): rows carry no hash-function index, so equal bucket ids coming
# from different hash functions share one BucketId — confirm this is intended.
# NOTE(review): this is a plain insert into a table without a uniqueness
# constraint, so re-running the cell appends the same rows again.
for hash_function in tqdm_notebook(hash_functions):
    buckets = get_buckets([hash_function], post_embeddings)[0]
    values = [(int(post_bucket), post_index) for post_index, post_bucket in enumerate(buckets)]
    cur.executemany('insert into BucketRegistry values(?, ?);', values)
    con.commit()

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




In [55]:
# In-memory copy of the registry: bucket id -> set of post row indices.
# NOTE(review): the recorded run of this cell ended in KeyboardInterrupt, so
# the mapping may have been incomplete in that session — confirm before
# trusting results from search_buckets_memory.
bucket_registry = defaultdict(set)
for post_bucket, post_index in cur.execute('select BucketId, PostIndex from BucketRegistry;'):
    bucket_registry[post_bucket].add(post_index)

KeyboardInterrupt: 

In [101]:
# NOTE(review): cells further down still pass `con` to search_buckets_sqlite
# and use `cur`; run top to bottom they would hit a closed connection. This
# close probably belongs at the end of the notebook — confirm.
con.close()

In [34]:
def query_to_hash(query, hash_function, emb):
    """Return the bit signature of ``query`` under a single hash function.

    The query is embedded with ``text_to_emb`` and reshaped to a single row
    before hashing, so the result has one row as well.
    """
    row = text_to_emb(query, emb).reshape(1, -1)
    return get_hash(hash_function, row)

In [35]:
def query_to_buckets(query, hash_functions, emb):
    """Return the bucket id of ``query`` for each hash function, as a flat list."""
    row = text_to_emb(query, emb).reshape(1, -1)
    # get_buckets yields one single-element array per hash function here.
    return [bucket[0] for bucket in get_buckets(hash_functions, row)]
    

In [36]:
# Signature of a one-word query under the first hash function.
query_to_hash('python', hash_functions[0], starspace_emb)

array([[1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0]],
      dtype=int8)

In [37]:
# Signature of a related two-word query, for comparison with the one above.
query_to_hash('python numpy', hash_functions[0], starspace_emb)

array([[1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0]],
      dtype=int8)

In [38]:
# Bucket ids of a first query, one per hash function.
a = query_to_buckets('graph in matplotlib', hash_functions, starspace_emb)

In [39]:
# Bucket ids of a second, related query.
b = query_to_buckets('plot python', hash_functions, starspace_emb)

In [40]:
# Bucket ids the two queries have in common. This compares ids regardless of
# which hash function produced them.
set(a)&set(b)

{1046769}

In [57]:
def search_buckets_sqlite(query, hash_functions, post_embeddings, con, emb, top_k=5):
    """LSH search that fetches candidate posts for each bucket from SQLite.

    Args:
        query: free-text query string.
        hash_functions: sequence of projection matrices used for hashing.
        post_embeddings: 2-D array with one post embedding per row.
        con: sqlite3 connection to a database with a
            ``BucketRegistry(BucketId, PostIndex)`` table.
        emb: token-to-vector mapping handed to ``text_to_emb``.
        top_k: maximum number of matches to return.

    Returns:
        Up to ``top_k`` ``(row_index, similarity)`` pairs sorted by
        descending similarity; each similarity is a one-element array.
    """
    query_emb = text_to_emb(query, emb).reshape(1, -1)
    query_buckets = query_to_buckets(query, hash_functions, emb)
    best = []
    for bucket_id in query_buckets:
        rows = con.execute('select PostIndex from BucketRegistry where BucketId = ?;', (int(bucket_id),))
        # Posts already in the running top list need no second scoring.
        already_ranked = {idx for idx, _ in best}
        candidates = list({row[0] for row in rows} - already_ranked)
        # cosine_similarity rejects an empty candidate matrix.
        if not candidates:
            continue
        sims = cosine_similarity(post_embeddings[candidates], query_emb)
        # argpartition(-k) places the k highest scores in the last k slots,
        # so the slice must use a positive k, capped by the bucket size.
        k = min(top_k, len(candidates))
        local = np.argpartition(sims, -k, axis=None)[-k:]
        best += [(candidates[i], sims[i]) for i in local]
        best = sorted(best, key=lambda x: x[1], reverse=True)[:top_k]
    return best

In [58]:
def search_buckets_memory(query, hash_functions, post_embeddings, bucket_registry, emb, top_k=5):
    """LSH search that reads candidate posts from an in-memory registry.

    Args:
        query: free-text query string.
        hash_functions: sequence of projection matrices used for hashing.
        post_embeddings: 2-D array with one post embedding per row.
        bucket_registry: mapping from bucket id to a set of row indices
            into ``post_embeddings``.
        emb: token-to-vector mapping handed to ``text_to_emb``.
        top_k: maximum number of matches to return.

    Returns:
        Up to ``top_k`` ``(row_index, similarity)`` pairs sorted by
        descending similarity; each similarity is a one-element array.
    """
    query_emb = text_to_emb(query, emb).reshape(1, -1)
    query_buckets = query_to_buckets(query, hash_functions, emb)
    best = []
    for bucket_id in query_buckets:
        # .get() avoids inserting an empty set into a defaultdict registry
        # for every bucket that has no posts.
        members = bucket_registry.get(bucket_id, frozenset())
        # Posts already in the running top list need no second scoring.
        already_ranked = {idx for idx, _ in best}
        candidates = list(members - already_ranked)
        # cosine_similarity rejects an empty candidate matrix.
        if not candidates:
            continue
        sims = cosine_similarity(post_embeddings[candidates], query_emb)
        # argpartition(-k) places the k highest scores in the last k slots,
        # so the slice must use a positive k, capped by the bucket size.
        k = min(top_k, len(candidates))
        local = np.argpartition(sims, -k, axis=None)[-k:]
        best += [(candidates[i], sims[i]) for i in local]
        best = sorted(best, key=lambda x: x[1], reverse=True)[:top_k]
    return best

In [44]:
# Smoke test of the SQLite-backed LSH search; needs `con` to be an open
# connection to the bucket registry database.
post_ids = search_buckets_sqlite('locality sensitive hashing', hash_functions, post_embeddings, con, starspace_emb)
print(post_ids)

[(204490, array([0.68784533])), (257587, array([0.66864631])), (248319, array([0.66469806])), (550705, array([0.65088469])), (605244, array([0.64432025]))]


In [45]:
# Show the (Id, Title) pair behind each returned row index.
for post_id in post_ids:
    print(ids_titles[post_id[0]])

(5769949, 'Locality Sensitive Hash Implementation?')
(7489257, 'what is the meaning of Kanatype Sensitive KS and width sensitive')
(7172117, 'Locality Preserving Hash Function')
(21001455, 'Should a REST API be case sensitive or non case sensitive?')
(25170063, 'Salt/Hash for Firebase Simple Login?')


In [46]:
# Baseline timing: brute-force scan over all post embeddings.
%timeit find_similar_post('python tkinter interface', post_embeddings, starspace_emb)

1.03 s ± 9.81 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [60]:
# Same query through the LSH search with candidates read from SQLite.
%timeit search_buckets_sqlite('python tkinter interface', hash_functions, post_embeddings, con, starspace_emb)

520 ms ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [62]:
# Same query through the LSH search with the registry held in memory.
%timeit search_buckets_memory('python tkinter interface', hash_functions, post_embeddings, bucket_registry, starspace_emb)

120 ms ± 2.55 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [69]:
# Post ids returned by the brute-force search, as a reference for the LSH
# result in the next cell.
for post_id in find_similar_post("OpenCV: 'BruteForceMatcher' : undeclared identifier", post_embeddings, starspace_emb):
    print(ids[post_id[0]])

16830842
10876052
14675279
19565262
12123479


In [70]:
# Post ids returned by the LSH search for the same query; overlap with the
# brute-force list above indicates how much recall the approximation keeps.
for post_id in search_buckets_sqlite("OpenCV: 'BruteForceMatcher' : undeclared identifier",
                              hash_functions, post_embeddings, con, starspace_emb):
    print(ids[post_id[0]])

10876052
12123479
36807747
37238431
15293220


In [63]:
# Free the (Id, Title) list; a copy was saved to
# '../data/original/ids_titles.pkl4' earlier in the notebook.
del ids_titles

In [66]:
# Total number of (bucket, post) rows in the registry.
# NOTE(review): the recorded count, 128497200, is 150 x 856648 posts rather
# than 100 x (one row per post per hash function), which suggests the insert
# cell ran more than once and left duplicate rows — confirm.
list(cur.execute('select count(*) from BucketRegistry;'))

[(128497200,)]

In [72]:
# Number of posts and embedding size, for comparison with the row count above.
post_embeddings.shape

(856648, 100)