In [167]:
import numpy as np
import pandas as pd
import os
import glob
from pathlib import Path
from scipy.spatial import distance
import json
from collections import defaultdict
from sklearn.metrics.pairwise import cosine_similarity

In [168]:
# Directory that holds the vector files, and the file-name prefix to load.
input_dir = 'phase2_outputs/task0b'
mode = 'tfidf'  # 'tf' is the other value noted for this setting
# Build the pattern from the two settings above so it cannot drift from them.
file_list = glob.glob('{}/{}_*.txt'.format(input_dir, mode))

In [169]:
def loadDataMatrix(input_dir, mode):
    """Load every ``<input_dir>/<mode>_*.txt`` file into a dict keyed by file number.

    The file number is the integer found between the last underscore of the
    path and the first following dot (``tfidf_12.txt`` -> ``12``). Each file is
    expected to hold a JSON string that itself contains a JSON document, hence
    the two decoding steps.

    Parameters
    ----------
    input_dir : str
        Directory to search.
    mode : str
        File-name prefix used in the glob pattern.

    Returns
    -------
    dict
        Maps ``int`` file number to the decoded content of that file. Empty
        when no file matches.
    """
    file_list = glob.glob('{}/{}_*.txt'.format(input_dir, mode))
    data = dict()
    for fname in sorted(file_list):
        fileNum = int(fname.split("_")[-1].split('.')[0])
        # Context manager closes the handle even if decoding fails.
        with open(fname, 'r') as fh:
            data[fileNum] = json.loads(json.load(fh))
    return data

In [170]:
# Decoded file contents keyed by the integer file number taken from each file name.
data = loadDataMatrix(input_dir=input_dir, mode=mode)

In [171]:
# Show which file numbers (parsed from the file names) were loaded.
data.keys()

dict_keys([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589])

In [172]:
class LSH:
    """Locality-sensitive hashing index based on random-projection sign hashes.

    Each indexed vector is hashed once per table. A hash is the ``k``-bit sign
    pattern of the vector's dot products with ``k`` standard-normal random
    vectors; vectors sharing a pattern in a table share a bucket.

    Parameters
    ----------
    L : int
        Number of hash tables built by ``train``.
    k : int
        Number of random vectors (hash bits) per table.
    mode : str
        File-name prefix; files matching ``<input_dir>/<mode>_*.txt`` are indexed.
    input_dir : str
        Directory that holds the vector files.
    seed : int, optional
        Seed for a private random generator, for reproducible tables. When
        omitted the global ``np.random`` state is used.

    Attributes
    ----------
    data : numpy.ndarray or None
        One row per file, in ``file_list`` order; filled by ``load_data``.
    hash_tables : list of dict
        One ``{'random_vectors': ndarray, 'table': defaultdict(list)}`` per
        table; bucket keys are ``bytes`` such as ``b'0110'``.
    idx_2_file, file_2_idx : dict
        Row position <-> file path, following the sorted file list.
    """

    def __init__(self, L, k, mode, input_dir, seed=None):
        self.L = L
        self.k = k
        self.mode = mode
        self.data = None
        self.hash_tables = []
        self.n_words = None
        self.file_list = sorted(glob.glob('{}/{}_*.txt'.format(input_dir, mode)))
        self.idx_2_file = {i: fname for i, fname in enumerate(self.file_list)}
        self.file_2_idx = {fname: i for i, fname in enumerate(self.file_list)}
        # Falling back to the np.random module keeps the unseeded behaviour
        # (including any np.random.seed() set by the caller).
        self._rng = np.random if seed is None else np.random.RandomState(seed)

    def get_random_vectors(self):
        """Return a standard-normal random array of shape ``(n_words, k)``."""
        return self._rng.randn(self.n_words, self.k)

    def load_data(self):
        """Read every file in ``file_list`` into ``self.data`` and set ``n_words``.

        Each file holds a JSON string that itself contains a JSON document,
        hence the two decoding steps. Raises ``IndexError`` when no file
        matched the pattern, because there is no first row to measure.
        """
        vectors = []
        for fname in self.file_list:
            # Context manager so the handle is released per file.
            with open(fname, 'r') as fh:
                vectors.append(json.loads(json.load(fh)))
        self.data = np.array(vectors)
        self.n_words = len(self.data[0])

    def binary_2_integer(self, binary_vectors):
        """Read each row of 0/1 values as a binary number, first column most significant."""
        exponents = np.array([2**i for i in range(self.k - 1, -1, -1)])
        return binary_vectors.dot(exponents)

    def hash_data(self, data, random_vectors):
        """Hash one vector or a matrix of row vectors.

        Parameters
        ----------
        data : array-like
            A 1-D vector or a 2-D array with one vector per row.
        random_vectors : numpy.ndarray
            Projection matrix whose columns are the random vectors.

        Returns
        -------
        list of bytes
            One ASCII bit string per row; bit ``j`` is ``1`` when the dot
            product with column ``j`` is ``>= 0``.
        """
        data = np.asarray(data)
        if data.ndim == 1:
            data = data.reshape(1, -1)
        sign_bits = data.dot(random_vectors) >= 0
        # Store the ASCII codes of '1'/'0' so each row's raw bytes are the key.
        ascii_bits = np.where(sign_bits, ord('1'), ord('0')).astype(np.uint8)
        return [row.tobytes() for row in ascii_bits]

    def train(self):
        """Load the data and build ``L`` hash tables.

        The tables are rebuilt from scratch, so calling ``train`` again (for
        example by re-running a notebook cell) does not accumulate extra tables.
        """
        self.load_data()
        self.hash_tables = []
        for _ in range(self.L):
            random_vectors = self.get_random_vectors()
            table = defaultdict(list)
            for idx, bucket_key in enumerate(self.hash_data(self.data, random_vectors)):
                table[bucket_key].append(idx)
            self.hash_tables.append({'random_vectors': random_vectors, 'table': table})

    def query(self, data_point, max_results):
        """Return up to ``max_results`` indexed files ranked by cosine similarity.

        Candidates come from the query's own bucket in every table. While fewer
        than ``max_results`` candidates are known, buckets whose key differs
        from the query hash in exactly 1, 2, ... ``k`` bits are probed as well.

        Parameters
        ----------
        data_point : array-like
            Query vector with ``n_words`` entries.
        max_results : int
            Maximum number of files to return.

        Returns
        -------
        dict
            ``n_buckets``: number of non-empty buckets read;
            ``n_candidates``: number of distinct candidates collected;
            ``scores``: cosine similarities, best first;
            ``retrieved_files``: file paths aligned with ``scores``.
            When no candidate is found, the last two are empty.
        """
        data_point = np.asarray(data_point)
        # The query hash of a table does not depend on the probing level.
        query_codes = [self.hash_data(data_point, hash_table['random_vectors'])[0]
                       for hash_table in self.hash_tables]

        retrieval = set()
        n_buckets = 0

        # Level 0 is the exact bucket; level d covers keys at Hamming distance d.
        # The range includes k so the fully flipped key is reachable too.
        for perm_level in range(self.k + 1):
            if len(retrieval) >= max_results:
                break
            for hash_table, query_code in zip(self.hash_tables, query_codes):
                table = hash_table['table']

                if perm_level == 0:
                    # .get() keeps the defaultdict from inserting an empty bucket on a miss.
                    bucket = table.get(query_code)
                    if bucket:
                        n_buckets += 1
                        retrieval.update(bucket)
                    continue

                query_bits = int(query_code, 2)
                for bucket_code, bucket in table.items():
                    if not bucket:
                        continue
                    changed_bits = bin(query_bits ^ int(bucket_code, 2)).count('1')
                    if changed_bits != perm_level:
                        continue
                    n_buckets += 1
                    retrieval.update(bucket)
                    if len(retrieval) >= max_results:
                        break

                if len(retrieval) >= max_results:
                    break

        candidates = sorted(retrieval)
        if not candidates:
            # Nothing to rank (untrained index, or max_results <= 0).
            return {'n_buckets': n_buckets,
                    'n_candidates': 0,
                    'scores': np.array([]),
                    'retrieved_files': []}

        sim_scores = cosine_similarity(data_point.reshape(1, -1), self.data[candidates]).ravel()
        assert len(candidates) == len(sim_scores)
        ranking = sim_scores.argsort()[::-1][:max_results]
        return {'n_buckets': n_buckets,
                'n_candidates': len(candidates),
                'scores': sim_scores[ranking],
                'retrieved_files': [self.idx_2_file[candidates[pos]] for pos in ranking]
               }
            

In [173]:
# Reuse the notebook-level settings so the index covers the same files as `data`.
lsh = LSH(L=8, k=4, mode=mode, input_dir=input_dir)

In [174]:
# Reads every indexed file into lsh.data; train() also calls load_data() itself.
lsh.load_data()

In [175]:
# Build the hash tables over the loaded vectors.
lsh.train()

In [176]:
# `query_idx` is a key of `data` (a file number), not a row position in lsh.data.
query_idx = 10
lsh.query(data_point=np.asarray(data[query_idx]), max_results=98)

{'n_buckets': 117,
 'n_candidates': 93,
 'scores': array([1.        , 0.27214871, 0.260384  , 0.25065707, 0.24599183,
        0.24595429, 0.24285738, 0.23327691, 0.23297201, 0.23057715,
        0.22743643, 0.22290228, 0.22274283, 0.22006634, 0.21413458,
        0.21133903, 0.20892665, 0.20834708, 0.20325307, 0.20285688,
        0.1996427 , 0.19601594, 0.19575157, 0.19420888, 0.1937912 ,
        0.19208847, 0.19072984, 0.19056732, 0.18780037, 0.18347695,
        0.18238245, 0.17984276, 0.17706218, 0.17701715, 0.17405459,
        0.17159915, 0.16711412, 0.16649461, 0.16562823, 0.16365316,
        0.16235872, 0.15999322, 0.15781616, 0.15697307, 0.15330541,
        0.15195794, 0.15083902, 0.14988838, 0.14834812, 0.14626289,
        0.14543055, 0.14542234, 0.14404824, 0.14383728, 0.14338521,
        0.14178141, 0.14047743, 0.13981319, 0.13954217, 0.13902027,
        0.13573397, 0.1339465 , 0.13116409, 0.13106888, 0.13075548,
        0.13070178, 0.12641772, 0.12630289, 0.12520483, 0.12164735