In [1]:
import os
import sys
import string 

import pandas as pd
import numpy as np
import random
import itertools

from matplotlib import pyplot as plt

from tqdm import tqdm

from pandarallel import pandarallel

import pymorphy2
import nltk
import pickle

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from keras.preprocessing.sequence import pad_sequences

import gensim
from gensim.models.callbacks import CallbackAny2Vec
from gensim.models.phrases import Phrases, Phraser

import time

sys.path.append("..")
from src import *

SEED = 1
def init_random_seed(value=0):
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed(value)
    torch.backends.cudnn.deterministic = True
init_random_seed(SEED)
    
pd.set_option('display.max_colwidth', 255)
tqdm.pandas()
pandarallel.initialize(progress_bar=True, nb_workers=8, use_memory_fs=False)

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


In [2]:
abbr = pd.read_csv("../data/abbr.csv")
lenta_train = pd.read_csv("../data/lenta_train.csv")
lenta_test = pd.read_csv("../data/lenta_test.csv")

In [3]:
id2abbr = {}
abbr2id = {}
id2desc = {}
desc2id = {}
for idx, abbr_name, desc in abbr[["abbr_id", "abbr_norm", "desc_norm"]].values:
    id2abbr[idx] = abbr_name
    abbr2id[abbr_name] = idx
    id2desc[idx] = desc
    desc2id[desc] = idx

In [4]:
lenta_train["labels_new"] = (
    lenta_train["labels_new"].str.replace("W", "")
                            .str.replace("B", "")
                            .str.replace("E", "")
                            .str.replace("-", "")
)
train_texts = list(map(lambda x: x.split(" "), lenta_train["text_new"].to_list()))
train_labels = list(map(lambda x: x.split(" "), lenta_train["labels_new"].to_list()))

lenta_test["labels_new"] = (
    lenta_test["labels_new"].str.replace("W", "")
                            .str.replace("B", "")
                            .str.replace("E", "")
                            .str.replace("-", "")
)
test_texts = list(map(lambda x: x.split(" "), lenta_test["text_new"].to_list()))
test_labels = list(map(lambda x: x.split(" "), lenta_test["labels_new"].to_list()))

In [5]:
PAD_TOKEN = "<PAD>"
PAD_TOKEN_ID = 0

PAD_LABEL = "<NOLABEL>"
PAD_LABEL_ID = 0

EMPTY_LABEL = "_"
EMPTY_LABEL_ID = 1

train_texts_global = list(itertools.chain(*train_texts))
train_labels_global = list(itertools.chain(*train_labels))
train_labels_global = list(filter(lambda x: x != EMPTY_LABEL, train_labels_global))

UNIQUE_TOKENS = [PAD_TOKEN] + list(set(train_texts_global))
UNIQUE_LABELS = [PAD_LABEL, EMPTY_LABEL] + list(set(train_labels_global))

token2id = {label: i for i, label in enumerate(UNIQUE_TOKENS)}
id2token = {i: label for label, i in token2id.items()}

label2id = {label: i for i, label in enumerate(UNIQUE_LABELS)}
id2label = {i: label for label, i in label2id.items()}

MAX_SENTENCE_LEN = lenta_train.text_new.str.split(" ").str.len().max()
train_size = len(train_texts)
test_size = len(test_texts)
TOKENS_NUM = len(UNIQUE_TOKENS)
LABELS_NUM = len(UNIQUE_LABELS)

print(MAX_SENTENCE_LEN, train_size, test_size, TOKENS_NUM, LABELS_NUM)

280 548700 137175 587151 1275


## 3. Обучение модели на n-граммах 

In [None]:
connector_words = nltk.corpus.stopwords.words("russian")
phrase_model = Phrases(
    train_texts, 
    min_count=5, 
    threshold=10,
    connector_words=connector_words
).freeze()

In [7]:
class EpochLogger(CallbackAny2Vec):
    '''Callback to log information about training'''

    def __init__(self, verbose: int = 1):
        self.epoch = 0
        self.verbose = verbose
        self.start_sec = int(time.time())
        self.last_loss = None
        
    def on_epoch_end(self, model):
        if self.epoch % self.verbose == 0:
            loss = model.get_latest_training_loss()
            
            end_sec = int(time.time())
            diff_sec = end_sec - self.start_sec
            self.start_sec = end_sec
            
            if self.last_loss == None:
                self.last_loss = loss
            
            print(f'Epoch {self.epoch}, loss {loss - self.last_loss}, sec {diff_sec}\n')
            
            self.last_loss = loss
        self.epoch += 1

epoch_logger = EpochLogger(verbose=1)

model = gensim.models.Word2Vec(
    sentences=phrase_model[train_texts], 
    vector_size=64, window=5, min_count=5, workers=16,
    sg=1, epochs=20, callbacks=[epoch_logger], compute_loss=True)

Epoch 0, loss 0.0, sec 111

Epoch 1, loss 18307920.0, sec 65

Epoch 2, loss 13422032.0, sec 66

Epoch 3, loss 1285320.0, sec 66

Epoch 4, loss 1361448.0, sec 64

Epoch 5, loss 1308896.0, sec 63

Epoch 6, loss 1334672.0, sec 63

Epoch 7, loss 1380768.0, sec 62

Epoch 8, loss 1426192.0, sec 61

Epoch 9, loss 1294024.0, sec 60

Epoch 10, loss 1387312.0, sec 60

Epoch 11, loss 1382384.0, sec 60

Epoch 12, loss 1230976.0, sec 63

Epoch 13, loss 1377960.0, sec 63

Epoch 14, loss 1305464.0, sec 63

Epoch 15, loss 1335024.0, sec 61

Epoch 16, loss 1329544.0, sec 63

Epoch 17, loss 1293096.0, sec 63

Epoch 18, loss 1286224.0, sec 64

Epoch 19, loss 1234632.0, sec 63



In [6]:
model_name = os.path.join("../models", "emb_64.word2vec")
# model.save(model_name)
model = gensim.models.Word2Vec.load(model_name)

In [7]:
model.wv.most_similar("звук", topn=10)

[('шум', 0.8209859728813171),
 ('звуковой_волна', 0.8093071579933167),
 ('высокочастотный_звук', 0.8066103458404541),
 ('низкочастотный_звук', 0.7903728485107422),
 ('громкий_звук', 0.7900891304016113),
 ('звуковой', 0.7884595394134521),
 ('рычание', 0.777118444442749),
 ('издавать_звук', 0.7766465544700623),
 ('трель', 0.7728555798530579),
 ('мерцание', 0.7690140008926392)]

In [126]:
def lev_dist(a, b):
    a = list(a)
    b = list(b)
    def recursive(i, j):
        if i == 0 or j == 0:
            return max(i, j)
        elif a[i - 1] == b[j - 1]:
            return recursive(i - 1, j - 1)
        else:
            return 1 + min(
                recursive(i, j - 1),  # удаление
                recursive(i - 1, j),   # вставка
                recursive(i - 1, j - 1)  # замена
            )
    return recursive(len(a), len(b))

def intersection_dist(a, b):
    a = set(a)
    b = set(b)
    return len(a) - len(a.intersection(b))

def prefix_dist(a, b):
    max_i = 0
    for i in range(len(a)):
        if a[:i] in b:
            max_i = max(i + 1, max_i)
    return len(a) - max_i

def get_desc_find_by_dist(word, w2v_model, topn=5, dist=prefix_dist):
    desc_score_dist = []
    for desc, score in w2v_model.wv.most_similar(word, topn=topn):
        desc = desc.replace("_", " ")
        desc_score_dist.append([desc, score, dist(word, desc)])
    desc_score_dist = sorted(desc_score_dist, key=lambda x: x[2])
    return desc_score_dist[0][0]

In [137]:
def get_desc_first(word, w2v_model):
    first_desc = w2v_model.wv.most_similar(word, topn=1)[0][0]
    first_desc = first_desc.replace("_", " ")
    return first_desc

def get_desc_find_in_dict(word, w2v_model, desc2id):
    for desc, score in w2v_model.wv.most_similar(word, topn=10):
        desc = desc.replace("_", " ")
        if desc in desc2id.keys():
            return desc
        
    return get_desc_first(word, w2v_model)
        

def get_token2desc(tokens, w2v_model, desc2id, 
                   get_desc_f=get_desc_find_by_dist, topn=5, dist=intersection_dist):
    abbr_detection = AbbrDetection()

    token2desc = {}
    for token in tqdm(tokens):
        if abbr_detection.word_is_abbr(token):
            if token in w2v_model.wv.key_to_index:
                desc = get_desc_find_by_dist(token, w2v_model, 
                                             topn=topn, dist=intersection_dist)
                label = desc2id.get(desc, "_")
            else:
                label = "_"
        else:
            label = "_"
        token2desc[token] = label
    return token2desc
    

In [None]:
for dist in [lev_dist, intersection_dist, prefix_dist]:
    for topn in [1, 3, 5, 10, 20]:
        print(dist.__name__, topn)
        token2desc = get_token2desc(tokens=list(token2id.keys()), 
                                    w2v_model=model, 
                                    desc2id=desc2id, 
                                    topn=topn, 
                                    dist=dist)

        preds = []
        for text in tqdm(test_texts):
            labels = []
            for word in text:
                if word in token2desc:
                    label = token2desc[word]  
                else:
                    label = "_"
                labels.append(str(label))
            preds.append(labels)

        test_labels_global = list(itertools.chain(*test_labels))
        test_preds_global = list(itertools.chain(*preds))

        test_labels_global_upd = []
        for label_id in test_labels_global:
            if label_id == "_":
                test_labels_global_upd.append("_")
            else:
                test_labels_global_upd.append(id2desc.get(int(label_id), "_"))

        test_preds_global_upd = []
        for label_id in test_preds_global:
            if label_id == "_":
                test_preds_global_upd.append("_")
            else:
                test_preds_global_upd.append(id2desc.get(int(label_id), "_"))

        f1 = f1_score(test_labels_global_upd, test_preds_global_upd, average="macro")
        print(f1)
        print()

lev_dist 1


100%|██████████| 587151/587151 [04:16<00:00, 2291.81it/s]
100%|██████████| 137175/137175 [00:03<00:00, 37372.58it/s]


0.2740340771686382

lev_dist 3


100%|██████████| 587151/587151 [06:14<00:00, 1566.57it/s]
100%|██████████| 137175/137175 [00:06<00:00, 22290.57it/s]


0.2820024683059773

lev_dist 5


100%|██████████| 587151/587151 [05:48<00:00, 1684.53it/s]
100%|██████████| 137175/137175 [00:03<00:00, 35690.21it/s]


0.2818605053388838

lev_dist 10


100%|██████████| 587151/587151 [06:28<00:00, 1513.27it/s]
100%|██████████| 137175/137175 [00:06<00:00, 20702.10it/s]


In [73]:
print(classification_report(test_labels_global_upd, test_preds_global_upd))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                                          precision    recall  f1-score   support

                                                       _       1.00      1.00      1.00  22326537
                                                   абрам       1.00      1.00      1.00        32
                                              абсолютный       1.00      1.00      1.00       251
                                                  август       1.00      1.00      1.00      4879
                                авиационный безопасность       0.00      0.00      0.00        16
                                    авиационный комплекс       0.00      0.00      0.00        11
                                                 автобус       0.00      0.00      0.00       820
                                                 автомат       0.00      0.00      0.00       311
                                      автомат калашников       0.00      0.00      0.00        84
                   

  _warn_prf(average, modifier, msg_start, len(result))
