# Разработка модели контекстного перевода сокращений в текстах на русском языке
## 4. Word2Vec

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

import time

import pandas as pd
import numpy as np

import pickle
from collections import defaultdict

from tqdm import tqdm
import random

import gensim
from gensim.models.callbacks import CallbackAny2Vec
from gensim.models.phrases import Phrases, Phraser

from src.abbr import AbbrInfo, AbbrTree

In [2]:
# --- Train split: pickled texts and their per-token labels ---
with open('../data/train_texts.pickle', 'rb') as f:
    train_texts = pickle.load(f)
with open('../data/train_labels.pickle', 'rb') as f:
    train_labels = pickle.load(f)

# --- Test split: same layout as the train split ---
with open('../data/test_texts.pickle', 'rb') as f:
    test_texts = pickle.load(f)
with open('../data/test_labels.pickle', 'rb') as f:
    test_labels = pickle.load(f)

# Each stored text exposes .tolist() (presumably a numpy array of tokens --
# TODO confirm); turn every one into a plain Python list.
train_texts = [tokens.tolist() for tokens in train_texts]
test_texts = [tokens.tolist() for tokens in test_texts]

# Flag the positions whose label is non-zero. Label 0 is what the model
# below emits for "not an abbreviation".
train_flags = [label_seq != 0 for label_seq in train_labels]
test_flags = [label_seq != 0 for label_seq in test_labels]

# Abbreviation dictionary helper, loaded from the same data directory.
abbr_info = AbbrInfo(data_dir="../data")

## Обучение модели 

In [3]:
class EpochLogger(CallbackAny2Vec):
    """Training callback that prints progress every ``verbose`` epochs.

    Each report shows the epoch index, the change in the value returned by
    ``model.get_latest_training_loss()`` since the previous report, and the
    wall-clock seconds elapsed since the previous report (or since this
    object was created, for the first report).

    Args:
        verbose: Report period in epochs. A falsy value (``0`` or ``None``)
            disables reporting instead of raising ``ZeroDivisionError``.
    """

    def __init__(self, verbose: int = 1):
        self.epoch = 0
        self.verbose = verbose
        self.start_sec = int(time.time())
        self.last_loss = None

    def on_epoch_end(self, model):
        """Print a progress line if the current epoch is a reporting epoch."""
        # Guard against verbose == 0 before taking the modulo.
        if self.verbose and self.epoch % self.verbose == 0:
            loss = model.get_latest_training_loss()

            end_sec = int(time.time())
            diff_sec = end_sec - self.start_sec
            # Restart the timer so the next report covers only its own span.
            self.start_sec = end_sec

            # No previous value on the first report: the printed diff is 0.
            if self.last_loss is None:
                self.last_loss = loss

            print(f'Epoch {self.epoch}, loss diff {loss - self.last_loss}, sec {diff_sec}\n')

            self.last_loss = loss
        self.epoch += 1

In [10]:
# Report training progress once every 10 epochs.
epoch_logger = EpochLogger(verbose=10)

# Train on the first 100000 training texts only.
v2w = gensim.models.Word2Vec(
    sentences=train_texts[:100000],
    vector_size=100,
    window=5,
    min_count=5,
    workers=16,
    sg=1,
    epochs=100,
    compute_loss=True,  # needed so the callback can read the training loss
    callbacks=[epoch_logger],
)

Epoch 0, loss diff 0.0, sec 8

Epoch 10, loss diff 55190212.0, sec 57

Epoch 20, loss diff 3259488.0, sec 57

Epoch 30, loss diff 3341856.0, sec 56

Epoch 40, loss diff 3128040.0, sec 57

Epoch 50, loss diff 2999896.0, sec 55

Epoch 60, loss diff 3094904.0, sec 58

Epoch 70, loss diff 2892872.0, sec 57

Epoch 80, loss diff 2855048.0, sec 57

Epoch 90, loss diff 2753600.0, sec 57



## Решение  

In [37]:
# Fetch the candidate expansions registered for the abbreviation "цб".
# .get() gives None when the abbreviation is not in the mapping.
desc_list = abbr_info.abbr2desc_list.get("цб")
desc_list

[('центральный', 'банк', 'российский', 'федерация'),
 ('центральный', 'бюро'),
 ('целиакальный', 'болезнь'),
 ('центр', 'безопасность'),
 ('центральный', 'батарея'),
 ('центр', 'биоинженерия'),
 ('центральный', 'база'),
 ('центральный', 'библиотека'),
 ('ценный', 'бумага')]

In [38]:
# Up to 1000 nearest neighbours of "цб" in the trained embedding,
# converted from (word, similarity) pairs into a word -> similarity dict.
similar_dict = dict(v2w.wv.most_similar("цб", topn=1000))
similar_dict

{'центробанк': 0.9093866348266602,
 'регулятор': 0.8473715782165527,
 'банк': 0.7962188124656677,
 'финорганизация': 0.7621607780456543,
 'госбанк': 0.7564061880111694,
 'асв': 0.754317045211792,
 'сбербанк': 0.7407891154289246,
 'фсфр': 0.7381514310836792,
 'минфин': 0.7177820801734924,
 'валютный': 0.7171207666397095,
 'нацбанк': 0.7029962539672852,
 'кредитный': 0.6979836225509644,
 'депозит': 0.6976214051246643,
 'финрынок': 0.6873639225959778,
 'росбанк': 0.6858466267585754,
 'ликвидность': 0.6835777759552002,
 'цбр': 0.683411180973053,
 'валюта': 0.6786748170852661,
 'депозитный': 0.6755256652832031,
 'набиуллин': 0.6735376119613647,
 'нпф': 0.6712146401405334,
 'вэб': 0.6681516766548157,
 'банковский': 0.6670024991035461,
 'внешэкономбанк': 0.6666185259819031,
 'нбу': 0.6620579361915588,
 'облигация': 0.6605402827262878,
 'беззалоговый': 0.6589810848236084,
 'золотовалютный': 0.658361554145813,
 'втб': 0.6561356782913208,
 'минэкономразвития': 0.6530010104179382,
 'ммвб': 0.6519

In [43]:
# Ad-hoc check of the scoring helper on the "цб" candidates.
# NOTE(review): this cell (In [43]) ran before the cell that defines
# Word2VecSimilarModel (In [53]); the per-description scores in the captured
# output are not printed by the class as shown in this file, so they
# presumably come from an earlier version -- re-run to refresh the output.
w2v_model = Word2VecSimilarModel(abbr_info)
w2v_model._get_most_similar_desc(desc_list, similar_dict)

('центральный', 'банк', 'российский', 'федерация') 0.19905470311641693
('центральный', 'бюро') 0.0
('целиакальный', 'болезнь') 0.0
('центр', 'безопасность') 0.0
('центральный', 'батарея') 0.0
('центр', 'биоинженерия') 0.0
('центральный', 'база') 0.0
('центральный', 'библиотека') 0.0
('ценный', 'бумага') 0.5141985863447189


('ценный', 'бумага')

In [53]:
class Word2VecSimilarModel:
    """Pick an expansion for each abbreviation using Word2Vec neighbours.

    For a word that has candidate expansions in
    ``abbr_info.abbr2desc_list``, the ``topn`` most similar words of that
    word are fetched from the embedding, every candidate is scored by the
    mean similarity of its words (words outside the neighbour list count
    as 0), and the id of the best-scoring candidate is emitted. All other
    words get label 0.

    The decision depends only on the word itself, never on its context, so
    the label computed for a word is cached and reused for later occurrences.

    Args:
        abbr_info: Object exposing ``abbr2desc_list`` (word -> list of
            description tuples) and ``desc2id`` (description tuple -> id).
        model: Trained model exposing ``.wv`` with ``key_to_index`` and
            ``most_similar``. When ``None``, the module-level ``v2w`` is
            looked up at prediction time.
        topn: Number of nearest neighbours requested per abbreviation.
    """

    def __init__(self, abbr_info: "AbbrInfo", model=None, topn: int = 10000):
        self.abbr_info = abbr_info
        self.model = model
        self.topn = topn
        # Per-word label cache, valid only for the vectors in _cached_wv.
        self._cached_wv = None
        self._label_cache = {}

    def predict(self, texts: list) -> list:
        """Return one numpy array of predicted labels per input text."""
        return [np.array(self._predict_text(text)) for text in tqdm(texts)]

    def _predict_text(self, text: list) -> list:
        """Return the predicted label for every word of a single text."""
        return [self._predict_word(word) for word in text]

    def _predict_word(self, word) -> int:
        """Return the expansion id for ``word``, or 0 if none applies."""
        desc_list = self.abbr_info.abbr2desc_list.get(word)
        if desc_list is None:
            # Not a known abbreviation; the embedding is not consulted.
            return 0

        wv = self._keyed_vectors()
        if word in self._label_cache:
            return self._label_cache[word]

        label = 0
        if word in wv.key_to_index:
            similar_dict = dict(wv.most_similar(word, topn=self.topn))
            desc = self._get_most_similar_desc(desc_list, similar_dict)
            # desc is None only when the candidate list is empty.
            if desc is not None:
                label = self.abbr_info.desc2id[desc]
        self._label_cache[word] = label
        return label

    def _keyed_vectors(self):
        """Return the word vectors in use, resetting the cache if they changed."""
        source = self.model if self.model is not None else v2w
        wv = source.wv
        if wv is not self._cached_wv:
            # A different (e.g. retrained) embedding invalidates old labels.
            self._cached_wv = wv
            self._label_cache = {}
        return wv

    def _get_most_similar_desc(self, desc_list, similar_dict):
        """Return the description with the highest mean word similarity.

        Ties keep the earliest description. An empty description scores 0,
        and an empty ``desc_list`` yields ``None``.
        """
        best_desc = None
        best_similar = None
        for desc in desc_list:
            total = sum(similar_dict.get(word, 0) for word in desc)
            desc_similar = total / len(desc) if desc else 0.0

            if best_similar is None or desc_similar > best_similar:
                best_desc = desc
                best_similar = desc_similar
        return best_desc

In [54]:
# Predict labels for the whole test set and print the evaluation scores.
# NOTE(review): AbbrEstimator is not among the imports visible in this
# notebook -- confirm where it is defined and import it before this cell.
# The captured run below was interrupted manually before prediction finished.
w2v_model = Word2VecSimilarModel(abbr_info)
pred_labels = w2v_model.predict(test_texts)
AbbrEstimator(test_labels, pred_labels).print_scores()

  0%|          | 135/159328 [01:28<28:57:31,  1.53it/s]


KeyboardInterrupt: 