In [1]:
import pandas as pd
from tqdm.auto import tqdm
from nnsplit import NNSplit
import numpy as np
from threading import Thread
import torch
import pickle
import h5py
from io import BytesIO
splitter = NNSplit.load("en", use_cuda=True)

In [2]:
import pymongo
from pymongo import MongoClient
client = MongoClient('mongodb+srv://cdminix:LTEG2pfoDiKfH29M@cluster0.pdjrf.mongodb.net/Reviews_Data?retryWrites=true&w=majority')

In [3]:
db = client.Reviews_Data

In [4]:
from bson.objectid import ObjectId
import zlib

In [5]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('paraphrase-distilroberta-base-v1')

In [6]:
class LSH:
    def __init__(self,
                 hdf5_file="data.hdf5",
                 input_dim=768,
                 hash_dim=6,
                 seed=42,
                 chunksize=1_000,
                 dtype='int8',
        ):
        self.planes = []
        self.input_dim = input_dim
        np.random.seed(seed)
        for i in range(hash_dim):
            v = np.random.rand(input_dim)
            v_hat = v / np.linalg.norm(v)
            self.planes.append(v_hat)
    
        self.planes = np.matrix(self.planes)
        self.data = h5py.File(hdf5_file, "w")
        self.chunksize = chunksize
        self.buckets = {}
        self.id_buckets = {}
        self.dtype = dtype
    
    # Returns LSH of a vector
    def hash(self, vector):
        hash_vector = np.where((self.planes @ vector) < 0, 1, 0)[0]
        hash_string = "".join([str(num) for num in hash_vector])
        return hash_string
    
    def quantize(self, vector_list):
        vector_list = np.array(vector_list)
        if self.dtype in ['float16', 'float32']:
            return vector_list.astype(self.dtype)
        if self.dtype == 'int8':
            return np.asarray(vector_list * 128, dtype=np.int8)
        raise ValueError(f'dtype needs to be float32, float16 or int8')
    
    # Add vector to bucket
    def add(self, vector, i):
        hashed = self.hash(vector)
        
        if hashed not in self.buckets:
            self.buckets[hashed] = []
            self.id_buckets[hashed] = []
        
        self.buckets[hashed].append(vector)
        self.id_buckets[hashed].append(i)
        
        if len(self.buckets[hashed]) >= self.chunksize:
            if hashed not in self.data:
                self.data.create_dataset(hashed, (self.chunksize,self.input_dim), compression='gzip', dtype=self.dtype, chunks=True, maxshape=(None,self.input_dim))
                self.data.create_dataset(hashed+'_id', (self.chunksize,), compression='gzip', dtype='int32', chunks=True, maxshape=(None,))
            else:
                hf = self.data[hashed]
                hf_id = self.data[hashed+'_id']
                hf.resize((hf.shape[0] + self.chunksize), axis=0)
                hf_id.resize((hf_id.shape[0] + self.chunksize), axis=0)
            self.data[hashed][-self.chunksize:] = self.quantize(self.buckets[hashed])
            self.data[hashed+'_id'][-self.chunksize:] = self.id_buckets[hashed]
            self.buckets[hashed] = []
            self.id_buckets[hashed] = []
            
    def flush(self):
        for hashed, vectors in self.buckets.items():
            list_size = len(vectors)
            if hashed not in self.data:
                self.data.create_dataset(hashed, (list_size,self.input_dim), compression='gzip', dtype=self.dtype, chunks=True, maxshape=(None,self.input_dim))
                self.data.create_dataset(hashed+'_id', (list_size,), compression='gzip', dtype='int32', chunks=True, maxshape=(None,))
            else:
                hf = self.data[hashed]
                hf_id = self.data[hashed+'_id']
                hf.resize((hf.shape[0] + list_size), axis=0)
                hf_id.resize((hf_id.shape[0] + list_size), axis=0)
            self.data[hashed][-list_size:] = self.quantize(self.buckets[hashed])
            self.data[hashed+'_id'][-list_size:] = self.id_buckets[hashed]
            self.buckets[hashed] = []
            self.id_buckets[hashed] = []
    
    # Returns bucket vector is in
    def get(self, vector):
        hashed = self.hash(vector)
        
        if hashed in self.data:
            return self.data[hashed]
        
        return []

In [7]:
try:
    db['sentence_data'].drop()
    lsh_store.data.close()
except:
    pass

In [None]:
batch_size = 1_000
#i = 0
#r_i = 0

insert_thread = None


#lsh_store = LSH(chunksize=batch_size)

max_entries = db['review_data'].count()
#max_entries = 1_001

for review in tqdm(db['review_data'].find(), total=max_entries-i):
    if i >= max_entries:
        lsh_store.flush()
        break
    if i % batch_size == 0:
        if i > 0:
            review_l = []
            sentence_l = []
            start_index_l = []
            end_index_l = []
            for j, val in enumerate(splitter.split(texts)):
                for k, sentence in enumerate(val):
                    sentence = str(sentence)
                    strip_sentence = sentence.strip()
                    if len(strip_sentence) > 0:
                        review_l.append(ids[j])
                        sentence_l.append(strip_sentence)
                        if k >= 1:
                            start_index_l.append(end_index_l[-1] + 1)
                        else:
                            start_index_l.append(0)
                        end_index_l.append(start_index_l[-1] + len(sentence))
            embeddings = model.encode(sentence_l, convert_to_tensor=True)
            insert_list = []
            for k, (indv_review, sentence, start_index, end_index) in enumerate(zip(
                review_l,
                embeddings,
                start_index_l,
                end_index_l
            )):
                lsh_store.add(sentence.numpy(), r_i)
                insert_list.append({
                    '_id': r_i,
                    'review_id': indv_review,
                    's': start_index,
                    'e': end_index,
                })
                r_i += 1
            insert_thread = Thread(target=db['sentence_data'].insert_many, args=(insert_list,))
            insert_thread.start()
        texts = []
        ids = []
    texts.append(zlib.decompress(review['review_text']).decode())
    ids.append(review['_id'])
    i += 1

  # Remove the CWD from sys.path while we load stuff.


HBox(children=(FloatProgress(value=0.0, max=11213281.0), HTML(value='')))

Traceback (most recent call last):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/pool.py", line 1180, in connect
    sock = _configured_socket(self.address, self.opts)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/pool.py", line 988, in _configured_socket
    sock = _create_connection(address, options)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/pool.py", line 944, in _create_connection
    for res in socket.getaddrinfo(host, port, family, socket.SOCK_STREAM):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/socket.py", line 752, in getaddrinfo
    for res in _socket.getaddrinfo(host, port, family, type, proto, flags):
socket.gaierror: [Errno -3] Temporary failure in name resolution

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/mongo_client.py

Exception in thread Thread-602:
Traceback (most recent call last):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/mongo_client.py", line 1404, in _retry_internal
    server = self._select_server(writable_server_selector, session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/mongo_client.py", line 1278, in _select_server
    server = topology.select_server(server_selector)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/topology.py", line 243, in select_server
    address))
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/topology.py", line 200, in select_servers
    selector, server_timeout, address)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/topology.py", line 217, in _select_servers_loop
    (self._error_message(selector), timeout, self.description))
pymongo.errors.ServerSelectionTimeoutError: cluster0-shard-00-01.pdjrf.mon

Exception in thread Thread-606:
Traceback (most recent call last):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/collection.py", line 761, in insert_many
    blk.execute(write_concern, session=session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 528, in execute
    return self.execute_command(generator, write_concern, session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 358, in execute_command
    with client._tmp_session(session) as s:
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/contextlib.py", line 112, in __enter__
    return next(self.gen)
  File "/home/cdminix/anaconda3/envs/apaut/lib/pytho

Exception in thread Thread-609:
Traceback (most recent call last):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/collection.py", line 761, in insert_many
    blk.execute(write_concern, session=session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 528, in execute
    return self.execute_command(generator, write_concern, session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 358, in execute_command
    with client._tmp_session(session) as s:
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/contextlib.py", line 112, in __enter__
    return next(self.gen)
  File "/home/cdminix/anaconda3/envs/apaut/lib/pytho

Exception in thread Thread-612:
Traceback (most recent call last):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/collection.py", line 761, in insert_many
    blk.execute(write_concern, session=session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 528, in execute
    return self.execute_command(generator, write_concern, session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 358, in execute_command
    with client._tmp_session(session) as s:
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/contextlib.py", line 112, in __enter__
    return next(self.gen)
  File "/home/cdminix/anaconda3/envs/apaut/lib/pytho

Exception in thread Thread-615:
Traceback (most recent call last):
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 926, in _bootstrap_inner
    self.run()
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/threading.py", line 870, in run
    self._target(*self._args, **self._kwargs)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/collection.py", line 761, in insert_many
    blk.execute(write_concern, session=session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 528, in execute
    return self.execute_command(generator, write_concern, session)
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/site-packages/pymongo/bulk.py", line 358, in execute_command
    with client._tmp_session(session) as s:
  File "/home/cdminix/anaconda3/envs/apaut/lib/python3.7/contextlib.py", line 112, in __enter__
    return next(self.gen)
  File "/home/cdminix/anaconda3/envs/apaut/lib/pytho

In [10]:
lsh_store.flush()

# Vector Quantization Tests

In [17]:
import pickle5 as pickle
with open('test.pkl', "rb") as f:
  data = pickle.load(f)
data_small = data.sample(100_000, random_state=42)

In [18]:
n = 1000  #chunk row size
list_df = [data_small[i:i+n] for i in range(0,data_small.shape[0],n)]

In [19]:
corpus_embeddings = []
orig_reviews = data_small.index.values.tolist()
for df in tqdm(list_df):
    corpus_embeddings += model.encode(df['review_text'].values, convert_to_tensor=True)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [68]:
from sentence_transformers import SentenceTransformer, util
import torch

def run_query(queries, quant=32):

    # Find the closest 5 sentences of the corpus for each query sentence based on cosine similarity
    top_k = min(5, len(corpus_embeddings))
    for query_text in queries:
        query = model.encode(query_text, convert_to_tensor=True)
        corpus = torch.stack(corpus_embeddings)

        if quant == 16:
            corpus = torch.tensor(corpus.numpy().astype('float16').astype('float32'))
            query = torch.tensor(query.numpy().astype('float16').astype('float32'))
        if quant == 8:
            corpus = np.asarray(corpus * 128, dtype=np.int8).astype('float32')
            query = np.asarray(query * 128, dtype=np.int8).astype('float32')

        cos_scores = util.pytorch_cos_sim(query, corpus)[0]
        cos_scores = cos_scores.cpu()

        #We use torch.topk to find the highest 5 scores
        top_results = torch.topk(cos_scores, k=top_k)

        print("\n\n======================\n\n")
        print("Query:", query_text)
        print("Quantization:", quant)
        print("\nTop 5 most similar sentences in corpus:")

        for score, idx in zip(top_results[0], top_results[1]):
            print(data.loc[orig_reviews[idx]]['review_text'], "(Score: %.4f)" % (score))

In [69]:
run_query([
    'blew my socks off',
    'decent film',
    'quite bad',
    'could have more dogs',
    'peperami'
], quant=32)





Query: blew my socks off
Quantization: 32

Top 5 most similar sentences in corpus:
I laughed my ass off. (Score: 0.5259)
Blew me away!!!! (Score: 0.5193)
Blew me away! Twisted and unique. (Score: 0.4619)
Scared the hell out of me... (Score: 0.4187)
Sucked. (Score: 0.4004)




Query: decent film
Quantization: 32

Top 5 most similar sentences in corpus:
Decent story (Score: 0.5028)
Great book,really good movie (Score: 0.5020)
spectacular novel (Score: 0.4965)
very similar to the movie (Score: 0.4911)
The movie is better. (Score: 0.4860)




Query: quite bad
Quantization: 32

Top 5 most similar sentences in corpus:
Very bad. (Score: 0.8503)
Not terrible, but pretty bad. (Score: 0.7357)
Not great.... (Score: 0.7232)
Not bad. Not great, but not bad. (Score: 0.6885)
Not wonderful. (Score: 0.6870)




Query: could have more dogs
Quantization: 32

Top 5 most similar sentences in corpus:
Did not have enough fire-cats. Please add more fire-cats. (Score: 0.4206)
Rather than write a review, I'

In [70]:
run_query([
    'blew my socks off',
    'decent film',
    'quite bad',
    'could have more dogs',
    'peperami'
], quant=16)





Query: blew my socks off
Quantization: 16

Top 5 most similar sentences in corpus:
I laughed my ass off. (Score: 0.5260)
Blew me away!!!! (Score: 0.5193)
Blew me away! Twisted and unique. (Score: 0.4619)
Scared the hell out of me... (Score: 0.4187)
Sucked. (Score: 0.4003)




Query: decent film
Quantization: 16

Top 5 most similar sentences in corpus:
Decent story (Score: 0.5028)
Great book,really good movie (Score: 0.5020)
spectacular novel (Score: 0.4965)
very similar to the movie (Score: 0.4911)
The movie is better. (Score: 0.4860)




Query: quite bad
Quantization: 16

Top 5 most similar sentences in corpus:
Very bad. (Score: 0.8503)
Not terrible, but pretty bad. (Score: 0.7357)
Not great.... (Score: 0.7232)
Not bad. Not great, but not bad. (Score: 0.6885)
Not wonderful. (Score: 0.6870)




Query: could have more dogs
Quantization: 16

Top 5 most similar sentences in corpus:
Did not have enough fire-cats. Please add more fire-cats. (Score: 0.4206)
Rather than write a review, I'

In [71]:
run_query([
    'blew my socks off',
    'decent film',
    'quite bad',
    'could have more dogs',
    'peperami'
], quant=8)





Query: blew my socks off
Quantization: 8

Top 5 most similar sentences in corpus:
Blew me away!!!! (Score: 0.4751)
I laughed my ass off. (Score: 0.4402)
Blew me away! Twisted and unique. (Score: 0.4246)
Scared the hell out of me... (Score: 0.3950)
Creepy and abrupt! House of Leaves scared my pants off! Like literally my pants jumped off and ran away and I am now in my under ware!!!! (Score: 0.3657)




Query: decent film
Quantization: 8

Top 5 most similar sentences in corpus:
Great book,really good movie (Score: 0.5018)
spectacular novel (Score: 0.4851)
very similar to the movie (Score: 0.4850)
Good book, so was the movie (Score: 0.4723)
Decent story (Score: 0.4707)




Query: quite bad
Quantization: 8

Top 5 most similar sentences in corpus:
Very bad. (Score: 0.8193)
Not terrible, but pretty bad. (Score: 0.7040)
Not great.... (Score: 0.6986)
Not wonderful. (Score: 0.6644)
Not bad. Not great, but not bad. (Score: 0.6587)




Query: could have more dogs
Quantization: 8

Top 5 most 