In [100]:
import pdb
from time import time
import scipy
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from numpy import dot
from numpy.linalg import norm
from sentence_transformers import SentenceTransformer

# Locality sensitive hashing custom implementation

Data is from here: https://github.com/MichaelAllen1966/1804_python_healthcare/blob/master/data/drugsComTrain_raw.csv

Goal is to make a model that quickly returns closest sentence to query sentence

## Model

In [2]:
model = SentenceTransformer('distilbert-base-nli-mean-tokens')

Downloading:   0%|          | 0.00/690 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.99k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/550 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/122 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/265M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/450 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/229 [00:00<?, ?B/s]

## Sentences

In [3]:
data=pd.read_csv('data/drugsComTrain_raw.csv')
data.shape

(161297, 7)

In [4]:
data=data.drop_duplicates()

In [5]:
data=data.iloc[data['review'].drop_duplicates().index,:]
data=data.reset_index()
data.shape

(112329, 8)

In [6]:
data.head()

Unnamed: 0,index,uniqueID,drugName,condition,review,rating,date,usefulCount
0,0,206461,Valsartan,Left Ventricular Dysfunction,"""It has no side effect, I take it in combinati...",9,20-May-12,27
1,1,95260,Guanfacine,ADHD,"""My son is halfway through his fourth week of ...",8,27-Apr-10,192
2,2,92703,Lybrel,Birth Control,"""I used to take another oral contraceptive, wh...",5,14-Dec-09,17
3,3,138000,Ortho Evra,Birth Control,"""This is my first time using any form of birth...",8,3-Nov-15,10
4,4,35696,Buprenorphine / naloxone,Opiate Dependence,"""Suboxone has completely turned my life around...",9,27-Nov-16,37


In [7]:
data.dropna(inplace=True,axis=0)
data=data.reset_index(drop=True)
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 111772 entries, 0 to 111771
Data columns (total 8 columns):
 #   Column       Non-Null Count   Dtype 
---  ------       --------------   ----- 
 0   index        111772 non-null  int64 
 1   uniqueID     111772 non-null  int64 
 2   drugName     111772 non-null  object
 3   condition    111772 non-null  object
 4   review       111772 non-null  object
 5   rating       111772 non-null  int64 
 6   date         111772 non-null  object
 7   usefulCount  111772 non-null  int64 
dtypes: int64(4), object(4)
memory usage: 6.8+ MB


## Encode

In [8]:
%time complaint_embeddings = model.encode(list(data['review']))

CPU times: user 4min 30s, sys: 1min 53s, total: 6min 24s
Wall time: 5min 39s


In [9]:
complaint_embeddings.shape

(111772, 768)

# Original solution, using numpy

## Make search space

In [10]:
search_space_len=complaint_embeddings.shape[0]
embedding_dims=complaint_embeddings.shape[1]
n_buckets=search_space_len/16
n_buckets

6985.75

In [11]:
n_planes=11
n_repeats=25

####  generate random planes and hash tables

In [12]:
#Generate 11 planes randomly. This gives us a 768 X 11 dimensional matrix
planes_l = [np.random.normal(size=(embedding_dims, n_planes)) for i in range(n_repeats)]
print(len(planes_l))
planes_l[0].shape

25


(768, 11)

In [13]:
def hash_vector(v, planes):
    #Dot vector with randomly generated planes
    # pdb.set_trace()
    dot_product = np.dot(v.T,planes) #( 1 , 768 ) X (768, 11)
    # get the sign of the dot product (1,11) shaped vector
    sign_of_dot_product = np.sign(dot_product)
    h = np.squeeze(sign_of_dot_product>=0)
    hash_value = 0
    n_planes = planes.shape[1]
    for i in range(n_planes):
        # increment the hash value by 2^i * h_i
        hash_value += np.power(2,i)*h[i]
    hash_value = int(hash_value)
    return hash_value

In [14]:
def make_hash_tables(search_space, planes):
    num_of_planes = planes.shape[1]
    # number of buckets is 2^(number of planes)
    num_buckets = 2**num_of_planes
    # Keys are integers (0,1,2.. number of buckets)
    # Values are empty lists
    hash_table = {i:[] for i in range(num_buckets)}
    # Keys are integers (0,1,2... number of buckets)
    # Values are empty lists
    id_table = {i:[] for i in range(num_buckets)}
    # for each vector in 'search_space'
    for i, v in enumerate(search_space):
        # calculate the hash value for the vector
        h = hash_vector(v,planes)
        # store the vector into hash_table at key h,
        hash_table[h].append(v)
        # store the vector's index 'i' 
        id_table[h].append(i)
    return hash_table, id_table

In [15]:
# Creating the hashtables
start=time()
hash_tables = []
id_tables = []
for i in range(n_repeats):  # there are 25 hash tables
    print('Creating  hash table :', i)
    planes = planes_l[i]
    hash_table, id_table = make_hash_tables(complaint_embeddings, planes)
    hash_tables.append(hash_table)
    id_tables.append(id_table)
end=time()
print(f'elapsed time {end-start}')

Creating  hash table : 0
Creating  hash table : 1
Creating  hash table : 2
Creating  hash table : 3
Creating  hash table : 4
Creating  hash table : 5
Creating  hash table : 6
Creating  hash table : 7
Creating  hash table : 8
Creating  hash table : 9
Creating  hash table : 10
Creating  hash table : 11
Creating  hash table : 12
Creating  hash table : 13
Creating  hash table : 14
Creating  hash table : 15
Creating  hash table : 16
Creating  hash table : 17
Creating  hash table : 18
Creating  hash table : 19
Creating  hash table : 20
Creating  hash table : 21
Creating  hash table : 22
Creating  hash table : 23
Creating  hash table : 24
elapsed time 84.76434326171875


In [17]:
len(planes_l)

25

## Search

In [18]:
def reduce_search_space(v, planes_l, k=3, num_tables=n_repeats):
    # Vectors that will be checked as possible similar vectors
    candidate_vectors = []
    # list of IDs for our vectors
    candidate_ids = []
    candidate_ids_set = set()
    # loop through all hash tables 
    for table in range(num_tables):
        # get the set of planes from the planes_l list, for this particular hash table
        planes = planes_l[table]
        # get the hash value of the vector for this set of planes
        hash_value = hash_vector(v, planes)
        # get the hash table
        hash_table = hash_tables[table]
        # get the list of vectors for this hash table, with the same hash value as our vector v
        retreived_vectors = hash_table[hash_value]
        # get the id_table for this particular universe_id
        id_table = id_tables[table]
        new_ids_to_consider = id_table[hash_value]
        # loop through the subset of document vectors to consider
        for i, new_id in enumerate(new_ids_to_consider):
            # if the document ID is not yet in the set ids_to_consider...
            if new_id not in candidate_ids_set:
                # append the vector and id to corresponding lists
                candidate_vectors.append(retreived_vectors[i])
                candidate_ids.append(new_id)
                # also add the new_id to the set of ids to consider
                candidate_ids_set.add(new_id)
    # Now run k-NN on the smaller set of vecs-to-consider.
    print("Reduced space from 25971 documents to %d documents" % len(candidate_vectors))
    vecs_to_consider_arr = np.array(candidate_vectors)
    return vecs_to_consider_arr

In [19]:
problem='Chest feels heavy and difficulty in breathing, I keep coughing. Sweating alot'

In [20]:
problem_embedding=model.encode(problem)

In [38]:
%time vecs_to_consider=reduce_search_space(problem_embedding,planes_l)

Reduced space from 25971 documents to 24904 documents
CPU times: user 82.7 ms, sys: 0 ns, total: 82.7 ms
Wall time: 81.3 ms


  dot_product = np.dot(v.T,planes) #( 1 , 768 ) X (768, 11)


In [39]:
def cosine_similarity(a,b):
    return dot(a, b)/(norm(a)*norm(b))

In [40]:
def nearest_k(v,search_embeddings,top_n=3):
    similarities=[]
    for i in search_embeddings:
        similarities.append(cosine_similarity(v,i))
    return search_embeddings[np.argsort(similarities)[::-1]][:3]

In [41]:
%time x=nearest_k(problem_embedding, complaint_embeddings)

CPU times: user 1.97 s, sys: 0 ns, total: 1.97 s
Wall time: 1.97 s


In [42]:
%time y=nearest_k(problem_embedding, vecs_to_consider)

CPU times: user 449 ms, sys: 0 ns, total: 449 ms
Wall time: 447 ms


In [43]:
id_to_vec={}

for i in range(complaint_embeddings.shape[0]):
    id_to_vec[i]=complaint_embeddings[i]

In [44]:
#Results using entire search space
for k,v in id_to_vec.items():
    for t in x:
        if (v==t).all():
            print(data['review'][k],data['condition'][k])
            print()

"I feel like I cough even more after taking this. It is not effective at all." Cough

"Did not work. Coughing so hard my back hurts." Cough

"Made me feel hot, diarrhea, fatigue, bloated." Pain



In [45]:
#Results using reduced search space
for k,v in id_to_vec.items():
    for t in y:
        if (v==t).all():
            print(data['review'][k],data['condition'][k])
              print()

"I feel like I cough even more after taking this. It is not effective at all." Cough

"Did not work. Coughing so hard my back hurts." Cough

"Made me feel hot, diarrhea, fatigue, bloated." Pain



In [47]:
#turn it into a function to more easily calculate speed
def find_similar_reduced(complaint_embeddings, id_to_vec, planes_l):
    vecs_to_consider=reduce_search_space(problem_embedding,planes_l)
    y=nearest_k(problem_embedding, vecs_to_consider)
    for k,v in id_to_vec.items():
        for t in y:
            if (v==t).all():
                print(data['review'][k],data['condition'][k])
                print()
     
%time find_similar_reduced(complaint_embeddings, id_to_vec, planes_l)

Reduced space from 25971 documents to 24904 documents
"I feel like I cough even more after taking this. It is not effective at all." Cough

"Did not work. Coughing so hard my back hurts." Cough

"Made me feel hot, diarrhea, fatigue, bloated." Pain

CPU times: user 1.65 s, sys: 0 ns, total: 1.65 s
Wall time: 1.65 s


# Pytorch version

In [29]:
complaint_embeddings_t=torch.tensor(complaint_embeddings)

In [113]:
class LSH:
    def __init__(self, texts, embeddings, n_planes=11, n_repeats=25):
        self.texts=texts
        self.embeddings=embeddings
        self.n_planes=n_planes
        self.n_repeats=n_repeats
        self.num_buckets = 2**self.n_planes
        self.search_space_len=self.embeddings.shape[0]
        self.embedding_dims=self.embeddings.shape[1]
        self.n_buckets=self.search_space_len/16
        
        self.row_idx=torch.tensor(list(range(self.n_planes)))
        self.row_pow=torch.pow(self.row_idx, 2)
        self.cossim= nn.CosineSimilarity(dim=1, eps=1e-6)
        self.planes=torch.randn((self.n_repeats, self.embedding_dims, self.n_planes))
        self.hash_tables=[]
        self.id_tables=[]
        self.init_hash_tables_()
        
    def init_hash_tables_(self):
        self.hash_tables = [{i:[] for i in range(self.num_buckets)} for _ in range(self.n_repeats)]
        self.id_tables = [{i:[] for i in range(self.num_buckets)} for _ in range(self.n_repeats)]

    def hash_vector(self, vector, planes):
        dot_prod=vector@planes
        signs=torch.sign(dot_prod)>=0
        hash_values=(self.row_pow.repeat(signs.shape[0], 1)*signs).sum(axis=1)
        return hash_values
    
    def make_search_space(self):
        for i, v in enumerate(self.embeddings):
            # calculate the hash value for the vector
            h = self.hash_vector(v, self.planes)
            for j, h_val in enumerate(h):
                h_val=h_val.item()
                # store the vector into hash_table at key h,
                self.hash_tables[j][h_val].append(v)
                # store the vector's index 'i' 
                self.id_tables[j][h_val].append(i)
                
    def reduce_search_space(self, v, k=3):
        candidate_vectors = []
        candidate_ids = []
        candidate_ids_set = set()
        h = self.hash_vector(v, self.planes)
        for j, h_val in enumerate(h):
                h_val=h_val.item()
                retreived_vectors = self.hash_tables[j][h_val]
                new_ids_to_consider = self.id_tables[j][h_val]
                for i, new_id in enumerate(new_ids_to_consider):
                    # if the document ID is not yet in the set ids_to_consider...
                    if new_id not in candidate_ids_set:
                        # append the vector and id to corresponding lists
                        candidate_vectors.append(retreived_vectors[i])
                        candidate_ids.append(new_id)
                        # also add the new_id to the set of ids to consider
                        candidate_ids_set.add(new_id)
        return torch.stack(candidate_vectors), candidate_ids
    
    def nearest_index(self, v, search_embeddings):
        sims=self.cossim(v, search_embeddings)
        sorted_idx=torch.argsort(sims, descending=True)
        return sorted_idx
    
    def find_topk(self, v, k=3):
        candidate_vectors, candidate_ids=self.reduce_search_space(v)
        sorted_reduced_idx= self.nearest_index(v, candidate_vectors)
        texts_idx=np.array(candidate_ids)[sorted_reduced_idx][:k]
        top_texts=[self.texts[idx] for idx in texts_idx]
        return top_texts
        

In [114]:
%time lsh=LSH(data['review'].tolist(), complaint_embeddings_t)

CPU times: user 1.43 s, sys: 8.85 ms, total: 1.44 s
Wall time: 1.44 s


In [115]:
%time lsh.make_search_space()

CPU times: user 4min, sys: 503 ms, total: 4min
Wall time: 40.4 s


In [116]:
problem='Chest feels heavy and difficulty in breathing, I keep coughing. Sweating alot'
problem_embedding=torch.tensor(model.encode(problem))

In [117]:
%time lsh.find_topk(problem_embedding)

CPU times: user 289 ms, sys: 190 ms, total: 479 ms
Wall time: 132 ms


['"Did not work. Coughing so hard my back hurts."',
 '"I feel like I cough even more after taking this. It is not effective at all."',
 '"Made me feel hot, diarrhea, fatigue, bloated."']

Pytorch version search space creation is 2x faster (85 vs 40 sec) and search also a lot faster (1.6 vs 0.13 sec)