# Разработка модели контекстного перевода сокращений в текстах на русском языке
## 4. Word2Vec

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

import time

import pandas as pd
import numpy as np

import pickle
from collections import defaultdict

from tqdm import tqdm
import random

import gensim
from gensim.models.callbacks import CallbackAny2Vec
from gensim.models.phrases import Phrases, Phraser


from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

from src.abbr import AbbrInfo, AbbrTree

In [2]:
with open('../data/train_texts.pickle', 'rb') as f:
    train_texts = pickle.load(f)
with open('../data/test_texts.pickle', 'rb') as f:
    test_texts = pickle.load(f)
    
with open('../data/train_labels.pickle', 'rb') as f:
    train_labels = pickle.load(f)
with open('../data/test_labels.pickle', 'rb') as f:
    test_labels = pickle.load(f)

train_texts = [text.tolist() for text in train_texts]
test_texts = [text.tolist() for text in test_texts]

train_flags = [labels != 0 for labels in train_labels]
test_flags = [labels != 0 for labels in test_labels]

abbr_info = AbbrInfo(data_dir="../data")

## Обучение модели 

In [11]:
class EpochLogger(CallbackAny2Vec):
    '''Callback to log information about training'''

    def __init__(self, verbose: int = 1):
        self.epoch = 0
        self.verbose = verbose
        self.start_sec = int(time.time())
        self.last_loss = None
        
    def on_epoch_end(self, model):
        if self.epoch % self.verbose == 0:
            loss = model.get_latest_training_loss()
            
            end_sec = int(time.time())
            diff_sec = end_sec - self.start_sec
            self.start_sec = end_sec
            
            if self.last_loss == None:
                self.last_loss = loss
            
            print(f'Epoch {self.epoch}, loss {loss - self.last_loss}, sec {diff_sec}\n')
            
            self.last_loss = loss
        self.epoch += 1
        
epoch_logger = EpochLogger(verbose=1)

In [10]:
phrases = Phrases(train_texts)
phraser = Phraser(phrases)

In [13]:
v2w = gensim.models.Word2Vec(sentences=phrases[train_texts], 
                             vector_size=50, window=5, min_count=5, workers=16,
                             sg=1, epochs=5, callbacks=[epoch_logger], compute_loss=True)

Epoch 0, loss 0.0, sec 378

Epoch 1, loss 24713128.0, sec 165

Epoch 2, loss 1427328.0, sec 169

Epoch 3, loss 1461640.0, sec 164

Epoch 4, loss 1464096.0, sec 165



## Решение  

In [17]:
desc_list = abbr_info.abbr2desc_list.get("цб")
desc_list

[('центральный', 'банк', 'российский', 'федерация'),
 ('центральный', 'бюро'),
 ('целиакальный', 'болезнь'),
 ('центр', 'безопасность'),
 ('центральный', 'батарея'),
 ('центр', 'биоинженерия'),
 ('центральный', 'база'),
 ('центральный', 'библиотека'),
 ('ценный', 'бумага')]

In [20]:
similar_dict = dict(v2w.wv.most_similar("цб", topn=1000))
similar_dict

{'центробанк': 0.9548358917236328,
 'цб_рф': 0.9392297267913818,
 'регулятор': 0.9232557415962219,
 'центробанк_рос': 0.9227277040481567,
 'кредитный_организация': 0.918311357498169,
 'фсфр': 0.901150643825531,
 'банк': 0.8988412618637085,
 'цбр': 0.8953655362129211,
 'фин_организация': 0.892299473285675,
 'финансовый_организация': 0.8896483182907104,
 'кред_учреждение': 0.885093092918396,
 'финорганизация': 0.8771950006484985,
 'финучреждение': 0.8749286532402039,
 'сбербанк': 0.8711209893226624,
 'кредитный_учреждение': 0.867357075214386,
 'депозитный': 0.8667817711830139,
 'нбб': 0.8634616136550903,
 'юниаструма_банк': 0.8604764342308044,
 'кредитный_учр': 0.8578700423240662,
 'сер_игнатьев': 0.857785701751709,
 'эмитент': 0.8576349020004272,
 'страховщик_осаго': 0.8553304076194763,
 'михаил_сухов': 0.8544135689735413,
 'асв': 0.8538258671760559,
 'нацбанк_белоруссия': 0.8537482619285583,
 'втб_24': 0.8536747694015503,
 'фрс': 0.8532385230064392,
 'рублёвый_депозит': 0.8517798185348

In [21]:
w2v_model = Word2VecSimilarModel(abbr_info)
w2v_model._get_most_similar_desc(desc_list, similar_dict)

NameError: name 'Word2VecSimilarModel' is not defined

In [22]:
class AbbrEstimator:
    def __init__(self, true_labels: list, pred_labels: list):
        self.true_labels = true_labels
        self.pred_labels = pred_labels
        
        self.true_labels_stacked = np.hstack(self.true_labels)
        self.pred_labels_stacked = np.hstack(self.pred_labels)
    
    def detection_score(self):
        detection_score = accuracy_score(self.true_labels_stacked != 0, 
                                         self.pred_labels_stacked != 0)
        return detection_score
    
    def replacement_score(self):
        labels_mask = (self.true_labels_stacked | self.pred_labels_stacked) != 0
        replacement_score = accuracy_score(self.true_labels_stacked[labels_mask],
                                           self.pred_labels_stacked[labels_mask])
        return replacement_score
    
    def print_scores(self):
        print("Detection score: ", self.detection_score())
        print("Replacement score: ", self.replacement_score())


In [24]:
class Word2VecSimilarModel:
    def __init__(self, abbr_info: AbbrInfo, w2v_model):
        self.abbr_info = abbr_info
        self.w2v_model = w2v_model
    
    def predict(self, texts: list):
        pred_labels = []
        for text in tqdm(texts):
            curr_text_labels = self._predict_text(text)
            pred_labels.append(np.array(curr_text_labels))
        return pred_labels
        
    def _predict_text(self, text: list):
        curr_text_labels = []
        for word in text:
            desc_list = self.abbr_info.abbr2desc_list.get(word)
            if (desc_list is not None) and (word in self.w2v_model.wv.key_to_index):
                desc = self._get_most_similar_desc(word, desc_list)
                abbr_id = self.abbr_info.desc2id[desc]
            else:
                abbr_id = 0
            curr_text_labels.append(abbr_id)
        return curr_text_labels
                
    def _get_most_similar_desc(self, abbr, desc_list):
        best_desc = None
        best_distance = None
        
        for desc in desc_list:
            desc_distance = 0
        
            if ("_".join(desc) in self.w2v_model.wv.key_to_index):
                desc_distance = self.w2v_model.wv.distance("_".join(desc), abbr)

             
            if best_distance is None or desc_distance < best_distance:
                best_desc = desc
                best_distance = desc_distance

        return best_desc
    
w2v_model = Word2VecSimilarModel(abbr_info, v2w)
pred_labels = w2v_model.predict(test_texts[])
AbbrEstimator(test_labels, pred_labels).print_scores()

  0%|          | 698/159328 [00:04<17:04, 154.82it/s]


KeyboardInterrupt: 