# 2. Сборка датасета 

In [1]:
import os
import sys
import random
import pickle

import pandas as pd
import numpy as np
from matplotlib import pyplot as plt

from tqdm import tqdm
from pandarallel import pandarallel

import pymorphy2
import nltk

from sklearn.model_selection import train_test_split

SEED = 1
random.seed(SEED)

pd.set_option('display.max_colwidth', 255)
tqdm.pandas()
pandarallel.initialize(progress_bar=True, nb_workers=8, use_memory_fs=False)

INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


## 1. Загрузка данных 

In [2]:
abbr = pd.read_csv("../data/abbr.csv")
lenta = pd.read_csv("../data/lenta.csv")

  lenta = pd.read_csv("../data/lenta.csv")


## Замена явных сокращений в текстках 

In [5]:
abbr_counter = abbr.groupby("abbr_norm").size()
explicit_abbr = abbr_counter[abbr_counter == 1].index.to_list()

explicit_dict = abbr[abbr["abbr_norm"].isin(explicit_abbr)]
explicit_dict = dict(explicit_dict[["abbr_norm", "desc_norm"]].values)

In [6]:
def replace_by_dict(line):
    for k, v in explicit_dict.items():
        line = line.replace(f" {k} ", f" {v} ")
    return line

lenta["text_norm"] = lenta["text_norm"].parallel_apply(replace_by_dict)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=85735), Label(value='0 / 85735')))…

## 2. Построение дерева сокращений

In [50]:
class AbbrInfo:
    def __init__(self, abbr_id, abbr, abbr_count):
        self.abbr_id = abbr_id 
        self.abbr = abbr
        self.abbr_count = abbr_count

ABBR_LIST_KEY = "<ABBR_LIST_KEY>"
        
def create_abbr_tree(abbr, abbr_list_key = ABBR_LIST_KEY):        
    tree = {}
    for norm_desc, norm_abbr, abbr_id, abbr_count in abbr[["desc_norm", 
                                                           "abbr_norm", 
                                                           "abbr_id", 
                                                           "abbr_count"]].values:
        words = norm_desc.split(" ")

        curr_tree = tree
        for word in words:
            if word not in curr_tree:
                curr_tree[word] = {}
            curr_tree = curr_tree[word]


        if abbr_list_key not in curr_tree:
            curr_tree[abbr_list_key] = []

        curr_tree[abbr_list_key].append(AbbrInfo(abbr_id, norm_abbr, abbr_count))
    return tree

In [51]:
abbr_tree = create_abbr_tree(abbr)
abbr_tree["министерство"]

{'оборона': {'украина': {'<ABBR_LIST_KEY>': [<__main__.AbbrInfo at 0x7fd91cf04a30>]}},
 'транспорт': {'<ABBR_LIST_KEY>': [<__main__.AbbrInfo at 0x7fd90c516be0>]},
 'национальный': {'безопасность': {'<ABBR_LIST_KEY>': [<__main__.AbbrInfo at 0x7fd90c3311c0>]}}}

## 3. Получение меток для токенов

In [52]:
def choice_abbr(abbr_list: list, 
                weighted_choice: bool = True, 
                add_to_zeros: float = 0):
    abbr_counts = []
    
    if weighted_choice:
        for abbr_info in abbr_list:
            cnt = abbr_info.abbr_count
            if cnt == 0:
                cnt = add_to_zeros
            abbr_counts.append(cnt)
    else:
        abbr_counts = None
    
    
    return random.choices(abbr_list, weights=abbr_counts, k=1)[0]

In [53]:
OUTSIDE_LABEL = "_"
BEGIN_LABEL = "B"
END_LABEL = "E"
INSIDE_LABEL = "I"
ONE_WORD_LABEL = "W"

def get_text_labels(text, 
                    abbr_tree, 
                    weighted_choice: bool = None, 
                    add_to_zeros: float = None):
    text = text.split(" ")
    labels = [OUTSIDE_LABEL for i in range(len(text))]

    curr_node = abbr_tree
    desc_start = None

    word_i = 0
    while word_i < len(text):
        curr_i = word_i
        while curr_i < len(text) and text[curr_i] in curr_node:
            curr_node = curr_node[text[curr_i]]
            curr_i += 1

        if ABBR_LIST_KEY in curr_node: 

            abbr_id = choice_abbr(curr_node[ABBR_LIST_KEY], weighted_choice, add_to_zeros).abbr_id

            labels[word_i] = f"{BEGIN_LABEL}-{abbr_id}"
            for j in range(word_i + 1, curr_i - 1): 
                labels[j] = f"{INSIDE_LABEL}-{abbr_id}"
            labels[curr_i - 1] = f"{END_LABEL}-{abbr_id}"

            if word_i == curr_i - 1:
                labels[word_i] = f"{ONE_WORD_LABEL}-{abbr_id}"

            word_i = curr_i - 1

        curr_node = abbr_tree
        word_i += 1
    return " ".join(labels)

get_text_labels(lenta.text_norm.iloc[0], abbr_tree)

'_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ W-1167 _ _ _ _ W-310 _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ W-748 _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ W-974 _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ W-567 _ _'

In [54]:
from ipymarkup         import show_box_markup
from ipymarkup.palette import palette, PALETTE, BLUE, RED, GREEN, PURPLE, BROWN, ORANGE

def show_markup(recipe,  tags, use_abbr: bool = True):
    mapper = lambda tag: tag[2:] if "-" in tag else tag
    
    tags  = [mapper(tag) for tag in tags]
    text  = ' '.join(recipe)
    spans = []
        
    start, end, tag = 0, len(recipe[0]), tags[0]
    
    for word, ttag in zip(recipe[1:], tags[1:]): 
        
        if tag == ttag:
            end  += 1 + len(word)
            
        else:
            span  = (start, end, tag)
            spans.append(span)
        
            start = 1 + end
            end  += 1 + len(word)
            
            if ttag != "_":
                if use_abbr:
                    label = abbr[abbr.abbr_id == int(ttag)].abbr_norm.iloc[0]
                else:
                    label = abbr[abbr.abbr_id == int(ttag)].desc_norm.iloc[0]
                ttag = label + f" ({ttag})"
            tag   = ttag
            
    span  = (start, end, tag)
    spans.append(span)        
            
    show_box_markup(text, spans)

In [55]:
text = lenta.text_norm.iloc[22]
labels = get_text_labels(text, abbr_tree)

show_markup(text.split(" "), labels.split(" "))

In [56]:
lenta["labels"] = lenta["text_norm"].parallel_apply(lambda x: get_text_labels(x, abbr_tree))

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=85735), Label(value='0 / 85735')))…

## 4. Замена слов на сокращения 

In [57]:
def replace_word_by_abbr(text, labels, abbr, p_replace: float = 0.2):
    text = text.split(" ")
    labels = labels.split(" ")
    
    new_text = []
    new_labels = []

    i = 0
    while i < len(text):
        label = labels[i]
        if label == OUTSIDE_LABEL:
            new_text.append(text[i])
            new_labels.append(OUTSIDE_LABEL)

        mode = label[0]

        if mode in [ONE_WORD_LABEL, BEGIN_LABEL]:
            abbr_id = int(label[2:])
            replaced = random.choices([False, True], weights=[(1 - p_replace), p_replace])[0]
            if replaced:
                norm_abbr = abbr[abbr.abbr_id == abbr_id].abbr_norm.iloc[0].split(" ")
                
                if len(norm_abbr) == 1:
                    new_text.append(norm_abbr[0])
                    new_labels.append(f"{ONE_WORD_LABEL}-{str(abbr_id)}")
                else:
                    new_text.append(norm_abbr[0])
                    new_labels.append(f"{BEGIN_LABEL}-{str(abbr_id)}")
                    for word in norm_abbr[1:-1]:
                        new_text.append(word)
                        new_labels.append(f"{INSIDE_LABEL}-{str(abbr_id)}")
                    new_text.append(norm_abbr[-1])
                    new_labels.append(f"{END_LABEL}-{str(abbr_id)}")

            while i < len(text) and labels[i] != OUTSIDE_LABEL and int(labels[i][2:]) == abbr_id:
                if not replaced:
                    new_text.append(text[i])
                    new_labels.append(OUTSIDE_LABEL)
                i += 1
        else:
            i += 1

    new_text = " ".join(new_text)
    new_labels = " ".join(new_labels)
    
    return pd.Series({"new_text": new_text, "new_labels": new_labels})

In [68]:
for i in range(10, 15):
    text = lenta.text_norm.iloc[i]
    labels = get_text_labels(text, abbr_tree)

    replaced_series = replace_word_by_abbr(text, labels, abbr)


    show_markup(replaced_series["new_text"].split(" "), 
                replaced_series["new_labels"].split(" "), 
                use_abbr=False)

In [69]:
lenta[["text_new", "labels_new"]] = (
    lenta[["text_norm", "labels"]]
        .parallel_apply(lambda x: replace_word_by_abbr(x["text_norm"], x["labels"], 
                                                       abbr, p_replace=0.3), axis=1)
)

VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=85735), Label(value='0 / 85735')))…

## Статистика 

In [80]:
from collections import Counter
counter = Counter({})
for i in tqdm(range(lenta.shape[0])):
    counter += Counter(lenta["labels_new"].iloc[i].split(" "))

100%|██████████| 685875/685875 [00:17<00:00, 39405.74it/s]


In [82]:
counter = counter.items()
sorted(counter, key=lambda x: x[1])

[('W-997', 3),
 ('W-1255', 5),
 ('W-939', 12),
 ('W-671', 13),
 ('W-574', 14),
 ('W-1246', 16),
 ('W-1165', 19),
 ('W-1164', 20),
 ('W-615', 21),
 ('W-472', 23),
 ('W-1057', 23),
 ('W-45', 24),
 ('W-1223', 26),
 ('W-795', 26),
 ('W-788', 26),
 ('W-1248', 27),
 ('W-6', 28),
 ('W-946', 29),
 ('W-637', 29),
 ('W-331', 29),
 ('W-1186', 29),
 ('W-768', 30),
 ('W-44', 30),
 ('W-824', 31),
 ('W-573', 31),
 ('W-983', 32),
 ('W-266', 33),
 ('W-251', 33),
 ('W-794', 33),
 ('W-765', 34),
 ('W-1215', 34),
 ('W-229', 35),
 ('W-352', 35),
 ('W-395', 35),
 ('W-535', 36),
 ('W-241', 36),
 ('W-817', 36),
 ('W-114', 36),
 ('W-79', 36),
 ('W-168', 36),
 ('W-639', 36),
 ('W-1042', 36),
 ('W-979', 37),
 ('W-64', 38),
 ('W-80', 38),
 ('W-868', 38),
 ('W-782', 38),
 ('W-953', 38),
 ('W-360', 39),
 ('W-536', 39),
 ('W-1156', 39),
 ('W-1218', 39),
 ('W-689', 40),
 ('W-7', 40),
 ('W-438', 40),
 ('W-1219', 40),
 ('W-217', 40),
 ('W-1125', 41),
 ('W-457', 41),
 ('W-493', 42),
 ('W-382', 42),
 ('W-630', 42),
 ('W-

## 5. Резделение на обучение и тест 

In [83]:
lenta_train, lenta_test = train_test_split(lenta, test_size=0.2, shuffle=True, random_state=SEED)

## 6. Сохранение данных

In [84]:
lenta_train.to_csv("../data/lenta_train.csv", index=False, header=True)
lenta_test.to_csv("../data/lenta_test.csv", index=False, header=True)