# Разработка модели контекстного перевода сокращений в текстах на русском языке
## 3. Простые пайплайны и измерение качества

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

import pandas as pd
import numpy as np

import pickle
from collections import Counter, defaultdict

from tqdm import tqdm
import random

from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

from src.abbr import AbbrInfo, AbbrTree

In [2]:
def _load_pickle(path):
    """Load one locally produced pickle file (never point this at untrusted data)."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Tokenised texts: per text, an array of tokens.
train_texts = _load_pickle('../data/train_texts.pickle')
test_texts = _load_pickle('../data/test_texts.pickle')

# Per-token label ids aligned with the texts; 0 marks "not an abbreviation".
train_labels = _load_pickle('../data/train_labels.pickle')
test_labels = _load_pickle('../data/test_labels.pickle')

# Boolean masks selecting the abbreviation tokens of every text.
train_flags = [token_labels != 0 for token_labels in train_labels]
test_flags = [token_labels != 0 for token_labels in test_labels]

abbr_info = AbbrInfo(data_dir="../data")

In [3]:
# Show the tokens of the first training text that are labelled as abbreviations.
first_text, first_mask = train_texts[0], train_flags[0]
first_text[first_mask]

array(['росс', 'изд', 'рос', 'м', 'проц', 'росс', 'гг', 'н', 'м', 'проц',
       'заявл', 'запад', 'ко', 'ко', 'млд', 'рос'], dtype='<U17')

## 1. Оценка решения

In [4]:
class AbbrEstimator:
    """Score abbreviation detection and expansion against gold labels.

    Both arguments are per-text sequences of label ids aligned token by token:
    ``0`` means "not an abbreviation", any other value is the id of the chosen
    expansion (the models in this notebook produce them via ``desc2id``).

    Raises:
        ValueError: if the inputs contain a different number of texts, or some
            text has a different number of tokens in ``true_labels`` and
            ``pred_labels``.
    """

    def __init__(self, true_labels: list, pred_labels: list):
        self.true_labels = true_labels
        self.pred_labels = pred_labels

        # Validate alignment per text: once everything is stacked, a shift inside
        # one text goes unnoticed as long as the total lengths still agree.
        if len(true_labels) != len(pred_labels):
            raise ValueError(
                f"true_labels has {len(true_labels)} texts, "
                f"pred_labels has {len(pred_labels)}"
            )
        for text_idx, (true_text, pred_text) in enumerate(zip(true_labels, pred_labels)):
            if len(true_text) != len(pred_text):
                raise ValueError(
                    f"text {text_idx}: {len(true_text)} true labels vs "
                    f"{len(pred_text)} predicted labels"
                )

        self.true_labels_stacked = self._stack(true_labels)
        self.pred_labels_stacked = self._stack(pred_labels)

    @staticmethod
    def _stack(labels):
        """Concatenate per-text label arrays into one flat array (empty input -> empty array)."""
        if len(labels) == 0:
            # np.hstack raises on an empty sequence.
            return np.array([], dtype=int)
        return np.hstack(labels)

    @staticmethod
    def _accuracy(true, pred):
        """Fraction of positions where ``true == pred``; NaN when there are no positions."""
        if true.size == 0:
            return float("nan")
        return float(np.mean(true == pred))

    def detection_score(self):
        """Token-level accuracy of the binary "is this an abbreviation" decision."""
        return self._accuracy(self.true_labels_stacked != 0,
                              self.pred_labels_stacked != 0)

    def replacement_score(self):
        """Expansion accuracy over tokens that either side marks as an abbreviation."""
        # Compare against zero before combining, so the mask works for any numeric dtype;
        # a bitwise OR of raw ids raises TypeError on float arrays.
        labels_mask = (self.true_labels_stacked != 0) | (self.pred_labels_stacked != 0)
        return self._accuracy(self.true_labels_stacked[labels_mask],
                              self.pred_labels_stacked[labels_mask])

    def print_scores(self):
        """Print both metrics."""
        print("Detection score: ", self.detection_score())
        print("Replacement score: ", self.replacement_score())
    
# Sanity check: gold labels scored against themselves should give 1.0 on both metrics.
gold_estimator = AbbrEstimator(test_labels, test_labels)
gold_estimator.print_scores()

Detection score:  1.0
Replacement score:  1.0


## 2. Случайный выбор из словаря 

In [5]:
class AbbrRandomSearcningModel:
    """Context-free dictionary baseline for abbreviation expansion.

    A token counts as an abbreviation iff it is a key of
    ``abbr_info.abbr2desc_list``. It is then replaced by one of the candidate
    descriptions listed there. By default the FIRST candidate is always taken,
    so despite the class name the output is deterministic. Pass ``seed`` to
    draw the candidate with ``random.Random(seed).choice`` instead; this is
    reproducible for a fixed seed.
    """

    def __init__(self, abbr_info: AbbrInfo, seed=None):
        self.abbr_info = abbr_info
        # None keeps the historical "first candidate" behaviour.
        self.seed = seed

    def predict(self, texts: list):
        """Return one ``np.array`` of label ids per text (0 = not an abbreviation)."""
        # A fresh generator per call makes repeated predict() calls reproducible.
        rng = None if self.seed is None else random.Random(self.seed)
        pred_labels = []
        for text in tqdm(texts):
            curr_text_labels = [self._predict_word(word, rng) for word in text]
            pred_labels.append(np.array(curr_text_labels))
        return pred_labels

    def _predict_word(self, word, rng):
        """Return the label id for one token; ``rng`` is None in first-candidate mode."""
        desc_list = self.abbr_info.abbr2desc_list.get(word)
        # An empty candidate list used to raise IndexError; treat it as "no abbreviation".
        if desc_list is None or len(desc_list) == 0:
            return 0
        desc = desc_list[0] if rng is None else rng.choice(desc_list)
        return self.abbr_info.desc2id[desc]


random_searhing_model = AbbrRandomSearcningModel(abbr_info)
pred_labels = random_searhing_model.predict(test_texts)

AbbrEstimator(test_labels, pred_labels).print_scores()

100%|██████████| 159328/159328 [00:09<00:00, 17076.02it/s]


Detection score:  0.7682640370728182
Replacement score:  0.0809094447082947


## 3. Выбор из словаря по частоте в тексте

In [6]:
class AbbrFreqSearcningModel:
    """Resolve each abbreviation by the frequency of its full forms in the same text.

    ``AbbrTree.get_text_labels`` is applied to every text. Among the candidate
    descriptions of an abbreviation (``abbr_info.abbr2desc_list``), the one whose
    id scores highest in that text's label sequence wins. Ties go to the
    candidate listed first.
    """

    def __init__(self, abbr_info: AbbrInfo):
        self.abbr_info = abbr_info
        self.abbr_tree = AbbrTree(abbr_info)

    def predict(self, texts: list):
        """Return one ``np.array`` of label ids per text (0 = not an abbreviation)."""
        pred_labels = []
        for text in tqdm(texts):
            # Count every label once per text. The previous version called
            # list.count for each candidate of each token, rescanning the whole
            # label list per token (quadratic in text length).
            label_counts = Counter(self.abbr_tree.get_text_labels(text))

            curr_text_labels = []
            for word in text:
                desc_list = self.abbr_info.abbr2desc_list.get(word)
                if desc_list is None or len(desc_list) == 0:
                    abbr_id = 0
                else:
                    desc = self._get_most_freq_desc(label_counts, desc_list)
                    abbr_id = self.abbr_info.desc2id[desc]
                curr_text_labels.append(abbr_id)
            pred_labels.append(np.array(curr_text_labels))
        return pred_labels

    def _get_most_freq_desc(self, label_counts: Counter, desc_list: list):
        """Return the candidate in ``desc_list`` with the highest normalised count.

        ``max`` keeps the first of equally scored candidates, the same
        tie-break as a stable ``sorted(..., key=-count)[0]``.
        """
        def score(desc):
            abbr_id = self.abbr_info.desc2id[desc]
            # NOTE(review): the count is floor-divided by len(desc), presumably to
            # normalise for how many labels one full-form match produces. If desc
            # is a plain string, this is its character length instead - confirm.
            return label_counts[abbr_id] // len(desc)

        return max(desc_list, key=score)


abbr_freq_model = AbbrFreqSearcningModel(abbr_info)
pred_labels = abbr_freq_model.predict(test_texts)

AbbrEstimator(test_labels, pred_labels).print_scores()

100%|██████████| 159328/159328 [02:41<00:00, 987.24it/s] 


Detection score:  0.7682640370728182
Replacement score:  0.12046059773390325


## 4. Выбор из словаря по частоте во всех текстах

In [42]:
class AbbrGlobalFreqSearcningModel:
    """Resolve each abbreviation by the corpus-wide frequency of its full forms.

    Label ids produced by ``AbbrTree.get_text_labels`` are pooled over a whole
    corpus (zeros dropped) and counted. Every abbreviation is then expanded to
    the candidate description with the highest normalised count, independently
    of the text it appears in. Ties go to the candidate listed first.

    By default ``predict`` pools the statistics over the texts it is given (the
    original behaviour). Call ``fit`` first to take them from another corpus,
    e.g. the training texts, and to reuse them across ``predict`` calls.
    """

    def __init__(self, abbr_info: AbbrInfo):
        self.abbr_info = abbr_info
        self.abbr_tree = AbbrTree(abbr_info)
        # Set by fit(); None means predict() counts the labels of its own input.
        self.global_counter = None

    def fit(self, texts: list):
        """Count full-form label ids over ``texts``, store them and return ``self``."""
        self.global_counter = Counter(self._get_global_labels(texts))
        return self

    def predict(self, texts: list):
        """Return one ``np.array`` of label ids per text (0 = not an abbreviation)."""
        global_counter = self.global_counter
        if global_counter is None:
            global_counter = Counter(self._get_global_labels(texts))

        # With fixed corpus statistics the answer depends only on the token, so
        # resolve each distinct token once instead of once per occurrence.
        word2id = {}
        pred_labels = []
        for text in tqdm(texts):
            curr_text_labels = []
            for word in text:
                abbr_id = word2id.get(word)
                if abbr_id is None:
                    abbr_id = self._predict_word(word, global_counter)
                    word2id[word] = abbr_id
                curr_text_labels.append(abbr_id)
            pred_labels.append(np.array(curr_text_labels))
        return pred_labels

    def _predict_word(self, word, global_counter: Counter):
        """Return 0 for unknown tokens, else the id of the most frequent candidate expansion."""
        desc_list = self.abbr_info.abbr2desc_list.get(word)
        if desc_list is None or len(desc_list) == 0:
            return 0
        desc = self._get_most_freq_desc(global_counter, desc_list)
        return self.abbr_info.desc2id[desc]

    def _get_global_labels(self, texts: list):
        """Concatenate the non-zero label ids of all texts into one flat array."""
        desc_labels = [self.abbr_tree.get_text_labels(text) for text in tqdm(texts)]
        if not desc_labels:
            # np.hstack raises on an empty sequence.
            return np.array([], dtype=int)
        global_labels = np.hstack(desc_labels)
        return global_labels[global_labels != 0]

    def _get_most_freq_desc(self, global_counter: Counter, desc_list: list):
        """Return the candidate in ``desc_list`` with the highest normalised count.

        ``max`` keeps the first of equally scored candidates, the same
        tie-break as a stable ``sorted(..., key=-count)[0]``.
        """
        def score(desc):
            abbr_id = self.abbr_info.desc2id[desc]
            # NOTE(review): the count is floor-divided by len(desc), presumably to
            # normalise for how many labels one full-form match produces. If desc
            # is a plain string, this is its character length instead - confirm.
            return global_counter[abbr_id] // len(desc)

        return max(desc_list, key=score)


abbr_global_freq_model = AbbrGlobalFreqSearcningModel(abbr_info)
pred_labels = abbr_global_freq_model.predict(test_texts)

AbbrEstimator(test_labels, pred_labels).print_scores()

100%|██████████| 159328/159328 [00:07<00:00, 21526.33it/s]
100%|██████████| 159328/159328 [00:38<00:00, 4125.62it/s]


Detection score:  0.7682640370728182
Replacement score:  0.18411176225584427
