<a href="https://colab.research.google.com/github/TatianaShavrina/gensim_BERT/blob/master/BERT_indexing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [36]:
!pip install pytorch_pretrained_bert nmslib



In [37]:
! pip install rusenttokenize



In [38]:
import os, re, json
import json
import nmslib
import torch
import random
import pandas as pd
from tqdm import tqdm
import numpy as np
#import tensorflow_hub as hub
from sklearn.feature_extraction.text import TfidfVectorizer,CountVectorizer
from joblib import dump, load
from string import punctuation
from operator import itemgetter
from functools import wraps
from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig
from sklearn.metrics.pairwise import cosine_similarity
from rusenttokenize import ru_sent_tokenize

In [39]:
def singleton(cls):
    instance = None
    @wraps(cls)
    def inner(*args, **kwargs):
        nonlocal instance
        if instance is None:
            instance = cls(*args, **kwargs)
        return instance
    return inner

In [40]:
class BertEmbedder(object):
    """
    Embedding Wrapper on Bert Multilingual Cased
    """
    def __init__(self, path=''):
        # use self.model_file with a path instead of 'bert-base-uncased' if you have a custom pretrained model
        self.model_file = 'bert-base-uncased'#os.path.join(path, "bert-base-multilingual-cased.tar.gz")
        self.vocab_file = 'bert-base-uncased'#os.path.join(path, "data_bert-base-multilingual-cased-vocab.txt")
        self.model = self.bert_model()
        self.tokenizer = self.bert_tokenizer()
        self.embedding_matrix = self.get_bert_embed_matrix()

    @singleton
    def bert_model(self):
        model = BertModel.from_pretrained(self.model_file).eval()
        return model

    @singleton
    def bert_tokenizer(self):
        tokenizer = BertTokenizer.from_pretrained(self.vocab_file, do_lower_case=False) 
        return tokenizer

    @singleton
    def get_bert_embed_matrix(self):
        bert_embeddings = list(self.model.children())[0]
        bert_word_embeddings = list(bert_embeddings.children())[0]
        matrix = bert_word_embeddings.weight.data.numpy()
        return matrix

    def sentence_embedding(self, text):
        token_list = self.tokenizer.tokenize("[CLS] " + text + " [SEP]")
        segments_ids, indexed_tokens = [1] * len(token_list), self.tokenizer.convert_tokens_to_ids(token_list)
        segments_tensors, tokens_tensor = torch.tensor([segments_ids]), torch.tensor([indexed_tokens])
        with torch.no_grad():
            encoded_layers, _ = self.model(tokens_tensor, segments_tensors)
        sent_embedding = torch.mean(encoded_layers[11], 1)
        return sent_embedding
    
    def sentences_embedding(self, text_list):
        embeddings = []
        for text in tqdm(text_list):
            token_list = self.tokenizer.tokenize("[CLS] " + text + " [SEP]")
            segments_ids, indexed_tokens = [1] * len(token_list), self.tokenizer.convert_tokens_to_ids(token_list)
            segments_tensors, tokens_tensor = torch.tensor([segments_ids]), torch.tensor([indexed_tokens])
            with torch.no_grad():
                encoded_layers, _ = self.model(tokens_tensor, segments_tensors)
            sent_embedding = torch.mean(encoded_layers[11], 1)
            embeddings.append(sent_embedding)
        return embeddings

    def token_embedding(self, token_list):
        token_embedding = []
        for token in token_list:
            ontoken = self.tokenizer.tokenize(token)
            segments_ids, indexed_tokens = [1] * len(ontoken), self.tokenizer.convert_tokens_to_ids(ontoken)
            segments_tensors, tokens_tensor = torch.tensor([segments_ids]), torch.tensor([indexed_tokens])
            with torch.no_grad():
                encoded_layers, _ = self.model(tokens_tensor, segments_tensors)
            ontoken_embeddings = []
            for subtoken_i in range(len(ontoken)):
                hidden_layers = []
                for layer_i in range(len(encoded_layers)):
                    vector = encoded_layers[layer_i][0][subtoken_i]
                    hidden_layers.append(vector)
                ontoken_embeddings.append(hidden_layers)
            cat_last_4_layers = [torch.cat((layer[-4:]), 0) for layer in ontoken_embeddings]
            token_embedding.append(cat_last_4_layers)
        token_embedding = torch.stack(token_embedding[0], 0) if len(token_embedding) > 1 else token_embedding[0][0]
        return token_embedding


In [47]:
class BertIndexer():
    def __init__(self, bert_model=BertEmbedder()):
        self.model = bert_model
        self.space_type = 'cosinesimil'
        self.method_name = 'hnsw'
        self.index = nmslib.init(space  = space_type,
                    method = method_name,
                    data_type = nmslib.DataType.DENSE_VECTOR,
                    dtype = nmslib.DistType.FLOAT)
        self.index_params = {'NN' : 15}
        self.index_is_loaded = False
        self.data = []

        class IndexError(Exception):
            """Base class for other exceptions"""
            pass
        
    def load_index(self, index_path, data):
        self.index.loadIndex( index_path, load_data=True)
        self.index_is_loaded = True
        self.data = data

    def create_index(self, index_path, data):
        names_sparse_matrix = self.make_data_embeddings(data)
        self.index.addDataPointBatch(data = names_sparse_matrix)
        self.index.createIndex( print_progress=True)
        self.index.saveIndex(index_path, save_data=True)
        self.index_is_loaded = True
        self.data = data
    
    def create_embedding(self, text):
        return self.model.sentence_embedding(text).numpy()
        
    def make_data_embeddings(self, data):
        names_sparse_matrix = []
        for i in tqdm(range(len(data))):
            try:
                names_sparse_matrix.append(self.model.sentence_embedding(data[i])[0].numpy())
            except:
                names_sparse_matrix.append(np.zeros(768))
        return names_sparse_matrix

    def return_closest(self, text, k=2, num_threads=2):
        if self.index_is_loaded:
            r = self.model.sentence_embedding(text).numpy()
            near_neighbors = self.index.knnQueryBatch(queries=[r], k=k, num_threads=num_threads)
            return [self.data[i] for i in near_neighbors[0][0]]
        else: raise IndexError("Index is not yet created or loaded")

    
    



The pre-trained model you are loading is an uncased model but you have set `do_lower_case` to False. We are setting `do_lower_case=True` for you but you may want to check this behavior.


In [48]:
indexer = BertIndexer()

In [49]:
data = ['мама мыла раму','папа мыл раму', 'Маша мыла раму', 'Россия - наше Отечество','Смерть неизбежна']

In [50]:
indexer.create_index('test_index', data)

100%|██████████| 5/5 [00:00<00:00,  8.96it/s]


In [51]:
indexer.return_closest('Россия', k=2)

['Россия - наше Отечество', 'Смерть неизбежна']

In [54]:
indexer.load_index('test_index', data)

In [55]:
indexer.return_closest('Россия', k=1)

['Россия - наше Отечество']