In [0]:
#шаг1: скачать репозиторий FactRuEval-2016 (https://github.com/dialogue-evaluation/factRuEval-2016)
#шаг2: внутри репозитория создаем директории FactRuEval2016_results и FactRuEval2016_results/results_of_elmo_and_crf (для результатов)
#шаг3: устанавливаем библиотеку deep_ner (pip install deep_ner) 
#в данном примере весь код из deep_ner вынесен в ноутбук для понимания зависимостей и возможного рефакторинга кода
#шаг4: для обучения на наших текстах будем использовать предобученные ELMo эмбеддинги команды deeppavlov (http://docs.deeppavlov.ai/en/master/features/pretrained_vectors.html)

#работа выполнена в Google Colab

#!unzip factRuEval-2016-master.zip
#!pip install deep_ner

#требования по библиотекам:
#nltk==3.4.5
#numpy==1.18.1
#scikit-learn==0.22.1
#scipy==1.4.1
#tensorboard==2.1.0
#tensorflow==1.15.0
#tensorflow-hub==0.8.0
#bert-tensorflow==1.0.1
#spacy-udpipe==0.2.0
#spacy==2.2.3
#pymorphy2==0.8
#rusenttokenize==0.0.5

In [0]:
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.validation import check_is_fitted
from typing import Union, Tuple, List, Dict, Set
import numpy as np
import tensorflow as tf
import tensorflow_hub as tfhub
from logging import Logger
import spacy_udpipe
from spacy_udpipe import UDPipeLanguage
from tensorflow.python.framework import ops


import os
import codecs
import json
import random
import logging
import re
import tempfile
import copy
import csv

In [0]:
class ELMo_NER(BaseEstimator, ClassifierMixin):
    def __init__(self, elmo_hub_module_handle: str, udpipe_lang: str, use_additional_features: bool = False,
                 finetune_elmo: bool=False, batch_size: int = 32, max_seq_length: int = 32, lr: float = 1e-4,
                 l2_reg: float = 1e-5, validation_fraction: float = 0.1, max_epochs: int = 10, patience: int = 3,
                 gpu_memory_frac: float = 1.0, verbose: bool = False, random_seed: Union[int, None] = None):
        self.udpipe_lang = udpipe_lang
        self.use_additional_features = use_additional_features
        self.batch_size = batch_size
        self.lr = lr
        self.l2_reg = l2_reg
        self.elmo_hub_module_handle = elmo_hub_module_handle
        self.finetune_elmo = finetune_elmo
        self.max_epochs = max_epochs
        self.patience = patience
        self.random_seed = random_seed
        self.gpu_memory_frac = gpu_memory_frac
        self.max_seq_length = max_seq_length
        self.validation_fraction = validation_fraction
        self.verbose = verbose

    def __del__(self):
        if hasattr(self, 'classes_list_'):
            del self.classes_list_
        if hasattr(self, 'shapes_list_'):
            del self.shapes_list_
        if hasattr(self, 'nlp_'):
            del self.nlp_
        if hasattr(self, 'universal_pos_tags_dict_'):
            del self.universal_pos_tags_dict_
        if hasattr(self, 'universal_dependencies_dict_'):
            del self.universal_dependencies_dict_
        self.finalize_model()

    def fit(self, X: Union[list, tuple, np.array], y: Union[list, tuple, np.array],
            validation_data: Union[None, Tuple[Union[list, tuple, np.array], Union[list, tuple, np.array]]]=None):
        """Train the model on `X` and `y`, keeping the state of the best epoch.

        The hyper-parameters are checked, a previously built model is finalized, and the data are tokenized
        and padded to a whole number of batches. After each epoch the model is scored by the validation F1
        (when validation data exist) or by the log-likelihood of the last training batch. Every improvement
        is saved to temporary files; after `patience` epochs without improvement training stops. Finally the
        best saved state is reloaded and the temporary files are removed.

        :param X: training texts.
        :param y: annotations of the training texts (checked together with `X` by `check_Xy`).
        :param validation_data: optional pair `(X_val, y_val)`. If it is None, a validation part is split off
            the training data with `split_dataset` when `validation_fraction` is positive; otherwise no
            validation is done.
        :return: self.
        :raises ValueError: if `validation_data` is not a tuple/list of two items, or if it contains classes
            which are absent in the training data.
        """
        self.check_params(
            elmo_hub_module_handle=self.elmo_hub_module_handle, finetune_elmo=self.finetune_elmo,
            batch_size=self.batch_size, max_seq_length=self.max_seq_length, lr=self.lr, l2_reg=self.l2_reg,
            validation_fraction=self.validation_fraction, max_epochs=self.max_epochs, patience=self.patience,
            gpu_memory_frac=self.gpu_memory_frac, verbose=self.verbose, random_seed=self.random_seed,
            udpipe_lang=self.udpipe_lang, use_additional_features=self.use_additional_features
        )
        self.classes_list_ = self.check_Xy(X, 'X', y, 'y')
        if hasattr(self, 'shapes_list_'):
            del self.shapes_list_
        self.finalize_model()
        self.update_random_seed()
        if validation_data is None:
            if self.validation_fraction > 0.0:
                train_index, test_index = split_dataset(y, self.validation_fraction, logger=elmo_ner_logger)
                X_train_ = [X[idx] for idx in train_index]
                y_train_ = [y[idx] for idx in train_index]
                X_val_ = [X[idx] for idx in test_index]
                y_val_ = [y[idx] for idx in test_index]
                del train_index, test_index
            else:
                X_train_ = X
                y_train_ = y
                X_val_ = None
                y_val_ = None
        else:
            if not isinstance(validation_data, (tuple, list)):
                raise ValueError('`validation_data` is wrong! Expected a tuple or a list, got `{0}`.'.format(
                    type(validation_data)))
            if len(validation_data) != 2:
                raise ValueError('`validation_data` is wrong! Expected a 2-element sequence (X_val, y_val), '
                                 'got {0} elements.'.format(len(validation_data)))
            classes_list_for_validation = self.check_Xy(validation_data[0], 'X_val', validation_data[1], 'y_val')
            unknown_classes = set(classes_list_for_validation) - set(self.classes_list_)
            if len(unknown_classes) > 0:
                raise ValueError('`validation_data` is wrong! It contains classes which are absent in the '
                                 'training data: {0}.'.format(sorted(map(str, unknown_classes))))
            X_train_ = X
            y_train_ = y
            X_val_ = validation_data[0]
            y_val_ = validation_data[1]
        X_train_tokenized, y_train_tokenized, self.shapes_list_ = self.tokenize_all(X_train_, y_train_)
        X_train_tokenized, y_train_tokenized = self.extend_Xy(X_train_tokenized, y_train_tokenized, shuffle=True)
        if (X_val_ is not None) and (y_val_ is not None):
            X_val_tokenized, y_val_tokenized, _ = self.tokenize_all(X_val_, y_val_, shapes_vocabulary=self.shapes_list_)
            # The validation order must be kept, because predictions are matched to `X_val_` by position.
            X_val_tokenized, y_val_tokenized = self.extend_Xy(X_val_tokenized, y_val_tokenized, shuffle=False)
        else:
            X_val_tokenized = None
            y_val_tokenized = None
        if self.verbose:
            elmo_ner_logger.info('Number of shapes is {0}.'.format(len(self.shapes_list_)))
        train_op, log_likelihood, logits_, transition_params_ = self.build_model()
        bounds_of_batches_for_training = self._calculate_bounds_of_batches(
            X_train_tokenized[0].shape[0], self.batch_size)
        if X_val_tokenized is None:
            bounds_of_batches_for_validation = None
        else:
            bounds_of_batches_for_validation = self._calculate_bounds_of_batches(
                X_val_tokenized[0].shape[0], self.batch_size)
        init = tf.global_variables_initializer()
        init.run(session=self.sess_)
        tmp_model_name = self.get_temp_model_name()
        if self.verbose:
            if X_val_tokenized is None:
                elmo_ner_logger.info('Epoch   Log-likelihood')
        n_epochs_without_improving = 0
        try:
            best_acc = None
            for epoch in range(self.max_epochs):
                random.shuffle(bounds_of_batches_for_training)
                feed_dict_for_batch = None
                for batch_start, batch_end in bounds_of_batches_for_training:
                    X_batch = [channel[batch_start:batch_end] for channel in X_train_tokenized]
                    y_batch = y_train_tokenized[batch_start:batch_end]
                    feed_dict_for_batch = self.fill_feed_dict(X_batch, y_batch)
                    self.sess_.run(train_op, feed_dict=feed_dict_for_batch)
                # The training log-likelihood is estimated on the last processed batch only.
                acc_train = log_likelihood.eval(feed_dict=feed_dict_for_batch, session=self.sess_)
                with_validation = bounds_of_batches_for_validation is not None
                if with_validation:
                    acc_test, f1_test, precision_test, recall_test, quality_by_entities = \
                        self._evaluate_on_validation_set(
                            X_val_, y_val_, X_val_tokenized, y_val_tokenized, bounds_of_batches_for_validation,
                            [log_likelihood, logits_, transition_params_]
                        )
                    if self.verbose:
                        elmo_ner_logger.info('Epoch {0}'.format(epoch))
                        elmo_ner_logger.info('  Train log-likelihood.: {0: 10.8f}'.format(acc_train))
                        elmo_ner_logger.info('  Val. log-likelihood:  {0: 10.8f}'.format(acc_test))
                    cur_quality = f1_test
                else:
                    cur_quality = acc_train
                # Early-stopping bookkeeping: the first epoch always counts as an improvement.
                if (best_acc is None) or (cur_quality > best_acc):
                    best_acc = cur_quality
                    self.save_model(tmp_model_name)
                    n_epochs_without_improving = 0
                else:
                    n_epochs_without_improving += 1
                if self.verbose:
                    if with_validation:
                        elmo_ner_logger.info('  Val. quality for all entities:')
                        elmo_ner_logger.info('      F1={0:>6.4f}, P={1:>6.4f}, R={2:>6.4f}'.format(
                            f1_test, precision_test, recall_test))
                        entity_types = sorted(list(quality_by_entities.keys()))
                        max_text_width = max((len(ne_type) for ne_type in entity_types), default=0)
                        for ne_type in entity_types:
                            elmo_ner_logger.info('    Val. quality for {0:>{1}}:'.format(ne_type, max_text_width))
                            elmo_ner_logger.info('      F1={0:>6.4f}, P={1:>6.4f}, R={2:>6.4f}'.format(
                                quality_by_entities[ne_type][0], quality_by_entities[ne_type][1],
                                quality_by_entities[ne_type][2]))
                    else:
                        elmo_ner_logger.info('{0:>5}   {1:>14.8f}'.format(epoch, acc_train))
                if n_epochs_without_improving >= self.patience:
                    if self.verbose:
                        elmo_ner_logger.info('Epoch %05d: early stopping' % (epoch + 1))
                    break
            if best_acc is not None:
                # Restore the best saved state instead of the state after the last epoch.
                self.finalize_model()
                self.load_model(tmp_model_name)
                if self.verbose and (bounds_of_batches_for_validation is not None):
                    # After reloading, the tensors are addressed by their names in the restored graph.
                    acc_test, f1_test, _, _, _ = self._evaluate_on_validation_set(
                        X_val_, y_val_, X_val_tokenized, y_val_tokenized, bounds_of_batches_for_validation,
                        ['eval/Mean:0', 'outputs_of_NER/BiasAdd:0', 'transitions:0']
                    )
                    elmo_ner_logger.info('Best val. F1 is {0:>8.6f}'.format(f1_test))
                    elmo_ner_logger.info('Best val. log-likelihood is {0:>10.8f}'.format(acc_test))
        finally:
            for cur_name in self.find_all_model_files(tmp_model_name):
                os.remove(cur_name)
        return self

    @staticmethod
    def _calculate_bounds_of_batches(n_samples: int, batch_size: int) -> List[Tuple[int, int]]:
        """Split `n_samples` items into consecutive `(start, end)` ranges of at most `batch_size` items."""
        return [(batch_start, min(batch_start + batch_size, n_samples))
                for batch_start in range(0, n_samples, batch_size)]

    def _evaluate_on_validation_set(self, X_val, y_val, X_val_tokenized, y_val_tokenized, bounds_of_batches,
                                    fetches) -> tuple:
        """Run the model over the validation batches and score the decoded entities.

        :param X_val: source validation texts (used to restore character bounds of the tokens).
        :param y_val: true annotations of the validation texts.
        :param X_val_tokenized: tokenized and padded validation inputs, one array per input channel.
        :param y_val_tokenized: tokenized and padded validation labels.
        :param bounds_of_batches: `(start, end)` ranges of the validation batches.
        :param fetches: three tensors (or tensor names) giving the log-likelihood, the logits and the CRF
            transition parameters, in this order.
        :return: tuple `(log_likelihood, f1, precision, recall, quality_by_entities)`.
        """
        total_log_likelihood = 0.0
        y_pred = []
        for batch_start, batch_end in bounds_of_batches:
            X_batch = [channel[batch_start:batch_end] for channel in X_val_tokenized]
            y_batch = y_val_tokenized[batch_start:batch_end]
            batch_log_likelihood, logits, trans_params = self.sess_.run(
                fetches, feed_dict=self.fill_feed_dict(X_batch, y_batch)
            )
            total_log_likelihood += batch_log_likelihood * self.batch_size
            sequence_lengths = X_val_tokenized[1][batch_start:batch_end]
            for logit, sequence_length in zip(logits, sequence_lengths):
                viterbi_seq, _ = tf.contrib.crf.viterbi_decode(logit[:int(sequence_length)], trans_params)
                y_pred.append(viterbi_seq)
        mean_log_likelihood = total_log_likelihood / float(X_val_tokenized[0].shape[0])
        pred_entities_val = []
        # Only the first `len(X_val)` predictions belong to real texts; the rest are batch padding.
        for sample_idx, labels_in_text in enumerate(y_pred[0:len(X_val)]):
            tokens = X_val_tokenized[0][sample_idx][:len(labels_in_text)]
            bounds_of_tokens = self.calculate_bounds_of_tokens(X_val[sample_idx], tokens)
            pred_entities_val.append(
                self.calculate_bounds_of_named_entities(bounds_of_tokens, self.classes_list_, labels_in_text)
            )
        f1, precision, recall, quality_by_entities = calculate_prediction_quality(
            y_val, pred_entities_val, self.classes_list_)
        return mean_log_likelihood, f1, precision, recall, quality_by_entities

    def predict(self, X: Union[list, tuple, np.array]) -> List[Dict[str, List[Tuple[int, int]]]]:
        """Recognize named entities in the texts `X` with the fitted model.

        The texts are tokenized with the shape vocabulary found during fitting, padded to a whole number of
        batches and passed through the network; label sequences are decoded with the Viterbi algorithm and
        converted to entity bounds in the source texts.

        :param X: texts to be processed.
        :return: one item per text, built by `calculate_bounds_of_named_entities`.
        """
        self.check_params(
            elmo_hub_module_handle=self.elmo_hub_module_handle, finetune_elmo=self.finetune_elmo,
            batch_size=self.batch_size, max_seq_length=self.max_seq_length, lr=self.lr, l2_reg=self.l2_reg,
            validation_fraction=self.validation_fraction, max_epochs=self.max_epochs, patience=self.patience,
            gpu_memory_frac=self.gpu_memory_frac, verbose=self.verbose, random_seed=self.random_seed,
            udpipe_lang=self.udpipe_lang, use_additional_features=self.use_additional_features
        )
        self.check_X(X, 'X')
        self.is_fitted()
        X_tokenized, _, _ = self.tokenize_all(X, shapes_vocabulary=self.shapes_list_)
        n_samples = X_tokenized[0].shape[0]
        X_tokenized = self.extend_Xy(X_tokenized)
        decoded_labels = []
        # After padding, the number of rows is a multiple of the batch size.
        for batch_idx in range(X_tokenized[0].shape[0] // self.batch_size):
            first = batch_idx * self.batch_size
            last = first + self.batch_size
            batch_channels = [X_tokenized[channel_idx][first:last] for channel_idx in range(len(X_tokenized))]
            feed_dict = self.fill_feed_dict(batch_channels)
            logits, trans_params = self.sess_.run(['outputs_of_NER/BiasAdd:0', 'transitions:0'], feed_dict=feed_dict)
            for logit, sequence_length in zip(logits, X_tokenized[1][first:last]):
                best_path, _ = tf.contrib.crf.viterbi_decode(logit[:int(sequence_length)], trans_params)
                decoded_labels.append(best_path)
        recognized_entities_in_texts = []
        # Rows after the first `n_samples` are padding and are skipped.
        for sample_idx, labels_in_text in enumerate(decoded_labels[0:n_samples]):
            tokens = X_tokenized[0][sample_idx][:len(labels_in_text)]
            bounds_of_tokens = self.calculate_bounds_of_tokens(X[sample_idx], tokens)
            recognized_entities_in_texts.append(
                self.calculate_bounds_of_named_entities(bounds_of_tokens, self.classes_list_, labels_in_text)
            )
        return recognized_entities_in_texts

    def is_fitted(self):
        """Check with `check_is_fitted` that the attributes created by fitting are present."""
        fitted_attributes = ['classes_list_', 'shapes_list_', 'sess_']
        check_is_fitted(self, fitted_attributes)

    def score(self, X, y, sample_weight=None) -> float:
        """Predict entities for `X` and return the first item of `calculate_prediction_quality` (the F1 measure).

        `sample_weight` is accepted but not used.
        """
        predicted_entities = self.predict(X)
        quality = calculate_prediction_quality(y, predicted_entities, self.classes_list_)
        return quality[0]

    def fit_predict(self, X: Union[list, tuple, np.array],  y: Union[list, tuple, np.array], **kwargs):
        return self.fit(X, y).predict(X)

    def fill_feed_dict(self, X: List[np.array], y: np.array=None) -> dict:
        if self.use_additional_features:
            assert len(X) == 4
        else:
            assert len(X) == 2
        assert len(X[0]) == self.batch_size
        if self.use_additional_features:
            feed_dict = {ph: x for ph, x in zip(['tokens:0', 'sequence_len:0', 'shape_features:0',
                                                 'linguistic_features:0'], X)}
        else:
            feed_dict = {ph: x for ph, x in zip(['tokens:0', 'sequence_len:0'], X)}
        if y is not None:
            feed_dict['y_ph:0'] = y
        return feed_dict

    def extend_Xy(self, X: List[np.array], y: np.array = None,
                  shuffle: bool = False) -> Union[List[np.array], Tuple[List[np.array], np.array]]:
        n_samples = X[0].shape[0]
        n_extend = n_samples % self.batch_size
        if n_extend == 0:
            if y is None:
                return X
            return X, y
        n_extend = self.batch_size - n_extend
        X_ext = [
            np.concatenate(
                (
                    X[idx],
                    np.full(
                        shape=((n_extend, self.max_seq_length) if len(X[idx].shape) == 2 else
                               ((n_extend,) if len(X[idx].shape) == 1 else
                               (n_extend, self.max_seq_length, X[idx].shape[2]))),
                        fill_value=X[idx][-1],
                        dtype=X[idx].dtype
                    )
                )
            )
            for idx in range(len(X))
        ]
        if y is None:
            if shuffle:
                indices = np.arange(0, n_samples + n_extend, 1, dtype=np.int32)
                np.random.shuffle(indices)
                return [X_ext[idx][indices] for idx in range(len(X_ext))]
            return X_ext
        y_ext = np.concatenate(
            (
                y,
                np.full(shape=(n_extend, self.max_seq_length), fill_value=y[-1], dtype=y.dtype)
            )
        )
        if shuffle:
            indices = np.arange(0, n_samples + n_extend, 1, dtype=np.int32)
            return [X_ext[idx][indices] for idx in range(len(X_ext))], y_ext[indices]
        return X_ext, y_ext

    def tokenize_all(self, X: Union[list, tuple, np.array], y: Union[list, tuple, np.array] = None,
                     shapes_vocabulary: Union[tuple, None] = None) -> Tuple[List[np.ndarray], Union[np.ndarray, None],
                                                                            tuple]:
        """Tokenize the texts and build all inputs of the neural network.

        Every text is processed by the UDPipe pipeline (created on first use). Tokens, their shapes,
        part-of-speech tags and dependency tags are truncated to `max_seq_length`; the token list is then
        padded with empty strings up to `max_seq_length`.

        If `shapes_vocabulary` is None, a new vocabulary is collected: the sorted non-empty shapes which
        occur at least 3 times in `X` (counted before truncation). The shape features have
        `len(vocabulary) + 3` columns: one per known shape, one for an unknown shape, one flag for the
        first token and one flag for the last token.

        :param X: source texts.
        :param y: optional annotations of the texts; when given, token labels are computed too.
        :param shapes_vocabulary: vocabulary of token shapes to be reused, or None to collect a new one.
        :return: tuple `(inputs, token_labels, shapes_vocabulary)`. `inputs` is `[tokens, lengths]`, extended
            with `[shape_features, linguistic_features]` if `use_additional_features` is set; `token_labels`
            is None when `y` is None.
        :raises ValueError: if the specified vocabulary is empty, or a part-of-speech tag or a dependency
            tag is unknown.
        """
        if shapes_vocabulary is not None:
            if len(shapes_vocabulary) < 1:
                raise ValueError('Shapes vocabulary is empty!')
        if not hasattr(self, 'universal_pos_tags_dict_'):
            self.universal_pos_tags_dict_ = dict(zip(UNIVERSAL_POS_TAGS, range(len(UNIVERSAL_POS_TAGS))))
        if not hasattr(self, 'universal_dependencies_dict_'):
            self.universal_dependencies_dict_ = dict(zip(UNIVERSAL_DEPENDENCIES, range(len(UNIVERSAL_DEPENDENCIES))))
        n_samples = len(X)
        tokens_of_texts = []
        lenghts_of_texts = []
        lingustic_features_of_texts = []
        shapes_of_texts = []
        shapes_dict = dict()
        y_tokenized = None if y is None else np.empty((len(y), self.max_seq_length), dtype=np.int32)
        for sample_idx in range(n_samples):
            source_text = X[sample_idx]
            if not hasattr(self, 'nlp_'):
                self.nlp_ = create_udpipe_pipeline(self.udpipe_lang)
            tokenized_text = []
            pos_tags = []
            dependencies = []
            for spacy_token in self.nlp_(source_text):
                tokenized_text.append(spacy_token.text)
                pos_tags.append(spacy_token.pos_)
                dependencies.append(spacy_token.dep_)
            shapes_of_text = [self.get_shape_of_string(cur) for cur in tokenized_text]
            if shapes_vocabulary is None:
                # Shape frequencies are counted over all tokens, including ones cut off by truncation below.
                for cur_shape in shapes_of_text:
                    if cur_shape != '':
                        shapes_dict[cur_shape] = shapes_dict.get(cur_shape, 0) + 1
            if y is not None:
                # Labels are computed from the full token list; `detect_token_labels` gets the length limit.
                bounds_of_tokens = self.calculate_bounds_of_tokens(source_text, tokenized_text)
                indices_of_named_entities, labels_IDs = self.calculate_indices_of_named_entities(
                    source_text, self.classes_list_, y[sample_idx])
                y_tokenized[sample_idx] = self.detect_token_labels(
                    bounds_of_tokens, indices_of_named_entities, labels_IDs, self.max_seq_length
                )
            n_tokens = min(len(tokenized_text), self.max_seq_length)
            lenghts_of_texts.append(n_tokens)
            # Only the token list is padded; feature lists stay short, their arrays are zero-initialized.
            tokens_of_texts.append(tokenized_text[:n_tokens] + ['' for _ in range(self.max_seq_length - n_tokens)])
            shapes_of_texts.append(shapes_of_text[:n_tokens])
            lingustic_features_of_texts.append(tuple(zip(pos_tags[:n_tokens], dependencies[:n_tokens])))
        assert len(X) == len(tokens_of_texts), '{0} != {1}'.format(len(X), len(tokens_of_texts))
        assert len(tokens_of_texts) == len(lenghts_of_texts), '{0} != {1}'.format(
            len(tokens_of_texts), len(lenghts_of_texts))
        assert len(lenghts_of_texts) == len(lingustic_features_of_texts), '{0} != {1}'.format(
            len(lenghts_of_texts), len(lingustic_features_of_texts))
        assert len(lenghts_of_texts) == len(shapes_of_texts), '{0} != {1}'.format(
            len(lenghts_of_texts), len(shapes_of_texts))
        if shapes_vocabulary is None:
            shapes_vocabulary_ = tuple(
                cur_shape for cur_shape in sorted(list(shapes_dict.keys())) if shapes_dict[cur_shape] >= 3
            )
        else:
            shapes_vocabulary_ = shapes_vocabulary
        # Dictionary lookup instead of a linear search in the vocabulary for every token;
        # `setdefault` keeps the first position of a repeated shape, like `index` did.
        shape_IDs = dict()
        for shape_ID, cur_shape in enumerate(shapes_vocabulary_):
            shape_IDs.setdefault(cur_shape, shape_ID)
        n_known_shapes = len(shapes_vocabulary_)
        shapes_ = np.zeros((len(X), self.max_seq_length, n_known_shapes + 3), dtype=np.float32)
        for sample_idx in range(n_samples):
            for token_idx, cur_shape in enumerate(shapes_of_texts[sample_idx]):
                # Column `n_known_shapes` marks a shape which is absent in the vocabulary.
                shapes_[sample_idx][token_idx][shape_IDs.get(cur_shape, n_known_shapes)] = 1.0
            shapes_[sample_idx][0][n_known_shapes + 1] = 1.0
            # NOTE(review): for a text without tokens this index is -1, i.e. the last position is flagged.
            shapes_[sample_idx][len(shapes_of_texts[sample_idx]) - 1][n_known_shapes + 2] = 1.0
        del shapes_of_texts
        linguistic_features = np.zeros((len(X), self.max_seq_length, len(self.universal_pos_tags_dict_) +
                                        len(self.universal_dependencies_dict_)), dtype=np.float32)
        for sample_idx in range(n_samples):
            for token_idx, (pos_tag, dependency_tag) in enumerate(lingustic_features_of_texts[sample_idx]):
                pos_tag_id = self.universal_pos_tags_dict_.get(pos_tag, -1)
                if pos_tag_id < 0:
                    raise ValueError('Part-of-speech tag `{0}` is unknown!'.format(pos_tag))
                linguistic_features[sample_idx][token_idx][pos_tag_id] = 1.0
                ok = False
                for dependency_tag_part in prepare_dependency_tag(dependency_tag):
                    dependency_id = self.universal_dependencies_dict_.get(dependency_tag_part, -1)
                    if dependency_id >= 0:
                        # Dependency columns follow the part-of-speech columns.
                        linguistic_features[sample_idx][token_idx][dependency_id + len(UNIVERSAL_POS_TAGS)] = 1.0
                        ok = True
                if not ok:
                    raise ValueError('Dependency tag `{0}` is unknown!'.format(dependency_tag))
        # The builtin `str` replaces the alias `np.str`, which is deprecated and removed in newer NumPy.
        X = [np.array(tokens_of_texts, dtype=str), np.array(lenghts_of_texts, dtype=np.int32)]
        if self.use_additional_features:
            X += [shapes_, linguistic_features]
        return X, (None if y is None else np.array(y_tokenized)), shapes_vocabulary_

    def get_params(self, deep=True) -> dict:
        return {'elmo_hub_module_handle': self.elmo_hub_module_handle, 'finetune_elmo': self.finetune_elmo,
                'batch_size': self.batch_size, 'max_seq_length': self.max_seq_length, 'lr': self.lr,
                'l2_reg': self.l2_reg, 'max_epochs': self.max_epochs, 'patience': self.patience,
                'validation_fraction': self.validation_fraction, 'gpu_memory_frac': self.gpu_memory_frac,
                'verbose': self.verbose, 'random_seed': self.random_seed, 'udpipe_lang': self.udpipe_lang,
                'use_additional_features': self.use_additional_features}

    def set_params(self, **params):
        for parameter, value in params.items():
            self.__setattr__(parameter, value)
        return self

    def __copy__(self):
        cls = self.__class__
        result = cls.__new__(cls)
        result.set_params(
            elmo_hub_module_handle=self.elmo_hub_module_handle, finetune_elmo=self.finetune_elmo,
            batch_size=self.batch_size, max_seq_length=self.max_seq_length, lr=self.lr, l2_reg=self.l2_reg,
            validation_fraction=self.validation_fraction, max_epochs=self.max_epochs, patience=self.patience,
            gpu_memory_frac=self.gpu_memory_frac, verbose=self.verbose, random_seed=self.random_seed,
            udpipe_lang=self.udpipe_lang, use_additional_features=self.use_additional_features
        )
        try:
            self.is_fitted()
            is_fitted = True
        except:
            is_fitted = False
        if is_fitted:
            result.classes_list_ = self.classes_list_
            result.shapes_list_ = self.shapes_list_
            result.sess_ = self.sess_
        return result

    def __deepcopy__(self, memodict={}):
        cls = self.__class__
        result = cls.__new__(cls)
        result.set_params(
            elmo_hub_module_handle=self.elmo_hub_module_handle,  finetune_elmo=self.finetune_elmo,
            batch_size=self.batch_size, max_seq_length=self.max_seq_length, lr=self.lr, l2_reg=self.l2_reg,
            validation_fraction=self.validation_fraction, max_epochs=self.max_epochs, patience=self.patience,
            gpu_memory_frac=self.gpu_memory_frac, verbose=self.verbose, random_seed=self.random_seed,
            udpipe_lang=self.udpipe_lang, use_additional_features=self.use_additional_features
        )
        try:
            self.is_fitted()
            is_fitted = True
        except:
            is_fitted = False
        if is_fitted:
            result.classes_list_ = self.classes_list_
            result.shapes_list_ = self.shapes_list_
            result.sess_ = self.sess_
        return result

    def __getstate__(self):
        return self.dump_all()

    def __setstate__(self, state: dict):
        self.load_all(state)

    def update_random_seed(self):
        if self.random_seed is None:
            self.random_seed = int(round(time.time()))
        random.seed(self.random_seed)
        np.random.seed(self.random_seed)
        tf.random.set_random_seed(self.random_seed)

    def dump_all(self):
        try:
            self.is_fitted()
            is_fitted = True
        except:
            is_fitted = False
        params = self.get_params(True)
        if is_fitted:
            params['classes_list_'] = copy.copy(self.classes_list_)
            params['shapes_list_'] = copy.copy(self.shapes_list_)
            model_file_name = self.get_temp_model_name()
            try:
                params['model_name_'] = os.path.basename(model_file_name)
                self.save_model(model_file_name)
                for cur_name in self.find_all_model_files(model_file_name):
                    with open(cur_name, 'rb') as fp:
                        model_data = fp.read()
                    params['model.' + os.path.basename(cur_name)] = model_data
                    del model_data
            finally:
                for cur_name in self.find_all_model_files(model_file_name):
                    os.remove(cur_name)
        return params

    def load_all(self, new_params: dict):
        if not isinstance(new_params, dict):
            raise ValueError('`new_params` is wrong! Expected `{0}`, got `{1}`.'.format(type({0: 1}), type(new_params)))
        self.check_params(**new_params)
        if hasattr(self, 'classes_list_'):
            del self.classes_list_
        if hasattr(self, 'shapes_list_'):
            del self.shapes_list_
        self.finalize_model()
        is_fitted = ('classes_list_' in new_params) and ('model_name_' in new_params) and ('shapes_list_' in new_params)
        model_files = list(
            filter(
                lambda it3: len(it3) > 0,
                map(
                    lambda it2: it2[len('model.'):].strip(),
                    filter(
                        lambda it1: it1.startswith('model.') and (len(it1) > len('model.')),
                        new_params.keys()
                    )
                )
            )
        )
        if is_fitted and (len(model_files) == 0):
            is_fitted = False
        if is_fitted:
            tmp_dir_name = tempfile.gettempdir()
            tmp_file_names = [os.path.join(tmp_dir_name, cur) for cur in model_files]
            for cur in tmp_file_names:
                if os.path.isfile(cur):
                    raise ValueError('File `{0}` exists, and so it cannot be used for data transmission!'.format(cur))
            self.set_params(**new_params)
            self.classes_list_ = copy.copy(new_params['classes_list_'])
            self.shapes_list_ = copy.copy(new_params['shapes_list_'])
            self.update_random_seed()
            try:
                for idx in range(len(model_files)):
                    with open(tmp_file_names[idx], 'wb') as fp:
                        fp.write(new_params['model.' + model_files[idx]])
                self.load_model(os.path.join(tmp_dir_name, new_params['model_name_']))
            finally:
                for cur in tmp_file_names:
                    if os.path.isfile(cur):
                        os.remove(cur)
        else:
            self.set_params(**new_params)
        return self

    def build_model(self):
        """Create the TensorFlow session and the ELMo + dense + CRF graph.

        A new ``tf.Session`` is stored in ``self.sess_``. Placeholders named
        ``tokens``, ``sequence_len`` and ``y_ph`` are created (plus
        ``shape_features`` and ``linguistic_features`` when
        ``self.use_additional_features`` is set), the ELMo module is loaded from
        ``self.elmo_hub_module_handle``, and a dense layer named
        ``outputs_of_NER`` produces per-token logits that feed a CRF.

        Returns:
            tuple: ``(train_op, log_likelihood_eval, logits, transition_params)``.
        """
        # Session configuration: cap the GPU memory fraction and allow growth.
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_frac
        config.gpu_options.allow_growth = True
        self.sess_ = tf.Session(config=config)
        # All placeholders have a fixed (batch_size, max_seq_length) leading shape.
        input_tokens = tf.placeholder(shape=(self.batch_size, self.max_seq_length), dtype=tf.string, name='tokens')
        sequence_lengths = tf.placeholder(shape=(self.batch_size,), dtype=tf.int32, name='sequence_len')
        y_ph = tf.placeholder(shape=(self.batch_size, self.max_seq_length), dtype=tf.int32, name='y_ph')
        elmo_inputs = dict(
            tokens=input_tokens,
            sequence_len=sequence_lengths
        )
        elmo_module = tfhub.Module(self.elmo_hub_module_handle, trainable=self.finetune_elmo)
        sequence_output = elmo_module(inputs=elmo_inputs, signature='tokens', as_dict=True)['elmo']
        # Pin the static shape of the ELMo output to 1024 features per token.
        sequence_output = tf.reshape(sequence_output, [self.batch_size, self.max_seq_length, 1024])
        if self.verbose:
            elmo_ner_logger.info('The ELMo model has been loaded from the TF-Hub.')
        # Two tags per entity class plus one extra tag (label 0).
        n_tags = len(self.classes_list_) * 2 + 1
        he_init = tf.contrib.layers.variance_scaling_initializer(seed=self.random_seed)
        if self.use_additional_features:
            shape_features = tf.placeholder(
                shape=(self.batch_size, self.max_seq_length, len(self.shapes_list_) + 3), dtype=tf.float32,
                name='shape_features'
            )
            linguistic_features = tf.placeholder(
                shape=(self.batch_size, self.max_seq_length, len(UNIVERSAL_DEPENDENCIES) + len(UNIVERSAL_POS_TAGS)),
                dtype=tf.float32,
                name='linguistic_features'
            )
            if self.finetune_elmo:
                logits = tf.layers.dense(tf.concat([sequence_output, shape_features, linguistic_features], axis=-1),
                                         n_tags, activation=None, kernel_regularizer=tf.nn.l2_loss,
                                         kernel_initializer=he_init, name='outputs_of_NER')
            else:
                # ELMo is frozen: block gradients from flowing into the module.
                sequence_output_stop = tf.stop_gradient(sequence_output)
                logits = tf.layers.dense(
                    tf.concat([sequence_output_stop, shape_features, linguistic_features], axis=-1),
                    n_tags, activation=None, kernel_regularizer=tf.nn.l2_loss,
                    kernel_initializer=he_init, name='outputs_of_NER')
        else:
            if self.finetune_elmo:
                logits = tf.layers.dense(sequence_output,
                                         n_tags, activation=None, kernel_regularizer=tf.nn.l2_loss,
                                         kernel_initializer=he_init, name='outputs_of_NER')
            else:
                # ELMo is frozen: block gradients from flowing into the module.
                sequence_output_stop = tf.stop_gradient(sequence_output)
                logits = tf.layers.dense(
                    sequence_output_stop,
                    n_tags, activation=None, kernel_regularizer=tf.nn.l2_loss,
                    kernel_initializer=he_init, name='outputs_of_NER')
        # Training loss: mean negative CRF log-likelihood plus the weighted L2 penalty
        # collected from the dense layer's kernel regularizer.
        log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(logits, y_ph, sequence_lengths)
        loss_tensor = -log_likelihood
        base_loss = tf.reduce_mean(loss_tensor)
        regularization_loss = self.l2_reg * tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        final_loss = base_loss + regularization_loss
        with tf.name_scope('train'):
            optimizer = tf.train.RMSPropOptimizer(learning_rate=self.lr, momentum=0.9, decay=0.9, epsilon=1e-10)
            train_op = optimizer.minimize(final_loss)
        with tf.name_scope('eval'):
            # Evaluation metric: the log-likelihood (computed with the already created
            # transition parameters) divided by the CRF log-norm, averaged over the batch.
            log_likelihood_eval_, _ = tf.contrib.crf.crf_log_likelihood(logits, y_ph,
                                                                        sequence_lengths, transition_params)
            seq_norm_eval = tf.contrib.crf.crf_log_norm(logits, sequence_lengths, transition_params)
            log_likelihood_eval = tf.reduce_mean(tf.cast(log_likelihood_eval_, tf.float32) /
                                                 tf.cast(seq_norm_eval, tf.float32))
        return train_op, log_likelihood_eval, logits, transition_params

    def finalize_model(self):
        if hasattr(self, 'sess_'):
            for k in list(self.sess_.graph.get_all_collection_keys()):
                self.sess_.graph.clear_collection(k)
            self.sess_.close()
            del self.sess_
        tf.reset_default_graph()

    def save_model(self, file_name: str):
        """Write a checkpoint of the current session to ``file_name`` via ``tf.train.Saver``."""
        tf.train.Saver().save(self.sess_, file_name)

    def load_model(self, file_name: str):
        """Restore a checkpoint into ``self.sess_``, creating the session if needed.

        The graph is imported from ``file_name + '.meta'`` with device placement
        cleared, then the variables are restored from ``file_name``.
        """
        session_missing = not hasattr(self, 'sess_')
        if session_missing:
            session_config = tf.ConfigProto()
            session_config.gpu_options.per_process_gpu_memory_fraction = self.gpu_memory_frac
            session_config.gpu_options.allow_growth = True
            self.sess_ = tf.Session(config=session_config)
        meta_graph_path = file_name + '.meta'
        restorer = tf.train.import_meta_graph(meta_graph_path, clear_devices=True)
        restorer.restore(self.sess_, file_name)

    @staticmethod
    def get_temp_model_name() -> str:
        with tempfile.NamedTemporaryFile(mode='w', suffix='elmo_crf.ckpt', delete=True) as fp:
            res = fp.name
        return res

    @staticmethod
    def find_all_model_files(model_name: str) -> List[str]:
        model_files = []
        if os.path.isfile(model_name):
            model_files.append(model_name)
        dir_name = os.path.dirname(model_name)
        base_name = os.path.basename(model_name)
        for cur in filter(lambda it: it.lower().find(base_name.lower()) >= 0, os.listdir(dir_name)):
            model_files.append(os.path.join(dir_name, cur))
        return sorted(model_files)

    @staticmethod
    def check_params(**kwargs):
        if 'udpipe_lang' not in kwargs:
            raise ValueError('`udpipe_lang` is not specified!')
        if not isinstance(kwargs['udpipe_lang'], str):
            raise ValueError('`udpipe_lang` is wrong! Expected `{0}`, got `{1}`.'.format(
                type('abc'), type(kwargs['udpipe_lang'])))
        if len(kwargs['udpipe_lang']) < 1:
            raise ValueError('`udpipe_lang` is wrong! Expected a nonepty string.')
        if 'batch_size' not in kwargs:
            raise ValueError('`batch_size` is not specified!')
        if (not isinstance(kwargs['batch_size'], int)) and (not isinstance(kwargs['batch_size'], np.int32)) and \
                (not isinstance(kwargs['batch_size'], np.uint32)):
            raise ValueError('`batch_size` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(3), type(kwargs['batch_size'])))
        if kwargs['batch_size'] < 1:
            raise ValueError('`batch_size` is wrong! Expected a positive integer value, '
                             'but {0} is not positive.'.format(kwargs['batch_size']))
        if 'lr' not in kwargs:
            raise ValueError('`lr` is not specified!')
        if (not isinstance(kwargs['lr'], float)) and (not isinstance(kwargs['lr'], np.float32)) and \
                (not isinstance(kwargs['lr'], np.float64)):
            raise ValueError('`lr` is wrong! Expected `{0}`, got `{1}`.'.format(type(3.5), type(kwargs['lr'])))
        if kwargs['lr'] <= 0.0:
            raise ValueError('`lr` is wrong! Expected a positive floating-point value, '
                             'but {0} is not positive.'.format(kwargs['lr']))
        if 'l2_reg' not in kwargs:
            raise ValueError('`l2_reg` is not specified!')
        if (not isinstance(kwargs['l2_reg'], float)) and (not isinstance(kwargs['l2_reg'], np.float32)) and \
                (not isinstance(kwargs['l2_reg'], np.float64)):
            raise ValueError('`l2_reg` is wrong! Expected `{0}`, got `{1}`.'.format(type(3.5), type(kwargs['l2_reg'])))
        if kwargs['l2_reg'] < 0.0:
            raise ValueError('`l2_reg` is wrong! Expected a non-negative floating-point value, '
                             'but {0} is negative.'.format(kwargs['l2_reg']))
        if 'elmo_hub_module_handle' not in kwargs:
            raise ValueError('`elmo_hub_module_handle` is not specified!')
        if kwargs['elmo_hub_module_handle'] is not None:
            if not isinstance(kwargs['elmo_hub_module_handle'], str):
                raise ValueError('`elmo_hub_module_handle` is wrong! Expected `{0}`, got `{1}`.'.format(
                    type('abc'), type(kwargs['elmo_hub_module_handle'])))
            if len(kwargs['elmo_hub_module_handle']) < 1:
                raise ValueError('`elmo_hub_module_handle` is wrong! Expected a nonepty string.')
        if 'finetune_elmo' not in kwargs:
            raise ValueError('`finetune_elmo` is not specified!')
        if (not isinstance(kwargs['finetune_elmo'], int)) and (not isinstance(kwargs['finetune_elmo'], np.int32)) and \
                (not isinstance(kwargs['finetune_elmo'], np.uint32)) and \
                (not isinstance(kwargs['finetune_elmo'], bool)) and (not isinstance(kwargs['finetune_elmo'], np.bool)):
            raise ValueError('`finetune_elmo` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(True), type(kwargs['finetune_elmo'])))
        if 'max_epochs' not in kwargs:
            raise ValueError('`max_epochs` is not specified!')
        if (not isinstance(kwargs['max_epochs'], int)) and (not isinstance(kwargs['max_epochs'], np.int32)) and \
                (not isinstance(kwargs['max_epochs'], np.uint32)):
            raise ValueError('`max_epochs` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(3), type(kwargs['max_epochs'])))
        if kwargs['max_epochs'] < 1:
            raise ValueError('`max_epochs` is wrong! Expected a positive integer value, '
                             'but {0} is not positive.'.format(kwargs['max_epochs']))
        if 'patience' not in kwargs:
            raise ValueError('`patience` is not specified!')
        if (not isinstance(kwargs['patience'], int)) and (not isinstance(kwargs['patience'], np.int32)) and \
                (not isinstance(kwargs['patience'], np.uint32)):
            raise ValueError('`patience` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(3), type(kwargs['patience'])))
        if kwargs['patience'] < 1:
            raise ValueError('`patience` is wrong! Expected a positive integer value, '
                             'but {0} is not positive.'.format(kwargs['patience']))
        if 'random_seed' not in kwargs:
            raise ValueError('`random_seed` is not specified!')
        if kwargs['random_seed'] is not None:
            if (not isinstance(kwargs['random_seed'], int)) and (not isinstance(kwargs['random_seed'], np.int32)) and \
                    (not isinstance(kwargs['random_seed'], np.uint32)):
                raise ValueError('`random_seed` is wrong! Expected `{0}`, got `{1}`.'.format(
                    type(3), type(kwargs['random_seed'])))
        if 'gpu_memory_frac' not in kwargs:
            raise ValueError('`gpu_memory_frac` is not specified!')
        if (not isinstance(kwargs['gpu_memory_frac'], float)) and \
                (not isinstance(kwargs['gpu_memory_frac'], np.float32)) and \
                (not isinstance(kwargs['gpu_memory_frac'], np.float64)):
            raise ValueError('`gpu_memory_frac` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(3.5), type(kwargs['gpu_memory_frac'])))
        if (kwargs['gpu_memory_frac'] <= 0.0) or (kwargs['gpu_memory_frac'] > 1.0):
            raise ValueError('`gpu_memory_frac` is wrong! Expected a floating-point value in the (0.0, 1.0], '
                             'but {0} is not proper.'.format(kwargs['gpu_memory_frac']))
        if 'max_seq_length' not in kwargs:
            raise ValueError('`max_seq_length` is not specified!')
        if (not isinstance(kwargs['max_seq_length'], int)) and \
                (not isinstance(kwargs['max_seq_length'], np.int32)) and \
                (not isinstance(kwargs['max_seq_length'], np.uint32)):
            raise ValueError('`max_seq_length` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(3), type(kwargs['max_seq_length'])))
        if kwargs['max_seq_length'] < 1:
            raise ValueError('`max_seq_length` is wrong! Expected a positive integer value, '
                             'but {0} is not positive.'.format(kwargs['max_seq_length']))
        if 'validation_fraction' not in kwargs:
            raise ValueError('`validation_fraction` is not specified!')
        if (not isinstance(kwargs['validation_fraction'], float)) and \
                (not isinstance(kwargs['validation_fraction'], np.float32)) and \
                (not isinstance(kwargs['validation_fraction'], np.float64)):
            raise ValueError('`validation_fraction` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(3.5), type(kwargs['validation_fraction'])))
        if kwargs['validation_fraction'] < 0.0:
            raise ValueError('`validation_fraction` is wrong! Expected a positive floating-point value greater than or '
                             'equal to 0.0, but {0} is not positive.'.format(kwargs['validation_fraction']))
        if kwargs['validation_fraction'] >= 1.0:
            raise ValueError('`validation_fraction` is wrong! Expected a positive floating-point value less than 1.0, '
                             'but {0} is not less than 1.0.'.format(kwargs['validation_fraction']))
        if 'verbose' not in kwargs:
            raise ValueError('`verbose` is not specified!')
        if (not isinstance(kwargs['verbose'], int)) and (not isinstance(kwargs['verbose'], np.int32)) and \
                (not isinstance(kwargs['verbose'], np.uint32)) and \
                (not isinstance(kwargs['verbose'], bool)) and (not isinstance(kwargs['verbose'], np.bool)):
            raise ValueError('`verbose` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(True), type(kwargs['verbose'])))
        if 'use_additional_features' not in kwargs:
            raise ValueError('`use_additional_features` is not specified!')
        if (not isinstance(kwargs['use_additional_features'], int)) and \
                (not isinstance(kwargs['use_additional_features'], np.int32)) and \
                (not isinstance(kwargs['use_additional_features'], np.uint32)) and \
                (not isinstance(kwargs['use_additional_features'], bool)) and \
                (not isinstance(kwargs['use_additional_features'], np.bool)):
            raise ValueError('`use_additional_features` is wrong! Expected `{0}`, got `{1}`.'.format(
                type(True), type(kwargs['use_additional_features'])))

    @staticmethod
    def calculate_bounds_of_tokens(source_text: str, tokenized_text: List[str]) -> List[Tuple[int, int]]:
        bounds_of_tokens = []
        start_pos = 0
        for cur_token in tokenized_text:
            found_idx = source_text[start_pos:].find(cur_token)
            n = len(cur_token)
            if found_idx < 0:
                raise ValueError('Text `{0}` cannot be tokenized! Token `{1}` cannot be found! Tokens are: {2}'.format(
                    source_text, cur_token, tokenized_text))
            bounds_of_tokens.append((start_pos + found_idx, start_pos + found_idx + n))
            start_pos += (found_idx + n)
        return bounds_of_tokens

    @staticmethod
    def calculate_bounds_of_named_entities(bounds_of_tokens: List[Tuple[int, int]], classes_list: tuple,
                                           token_labels: List[int]) -> Dict[str, List[Tuple[int, int]]]:
        named_entities_for_text = dict()
        ne_start = -1
        ne_type = ''
        n_tokens = len(bounds_of_tokens)
        for token_idx in range(n_tokens):
            class_id = token_labels[token_idx]
            if (class_id > 0) and ((class_id - 1) // 2 < len(classes_list)):
                if ne_start < 0:
                    ne_start = token_idx
                    ne_type = classes_list[(class_id - 1) // 2]
                else:
                    if class_id % 2 == 0:
                        if ne_type in named_entities_for_text:
                            named_entities_for_text[ne_type].append(
                                (bounds_of_tokens[ne_start][0], bounds_of_tokens[token_idx - 1][1])
                            )
                        else:
                            named_entities_for_text[ne_type] = [
                                (bounds_of_tokens[ne_start][0], bounds_of_tokens[token_idx - 1][1])
                            ]
                        ne_start = token_idx
                        ne_type = classes_list[(class_id - 1) // 2]
                    else:
                        if classes_list[(class_id - 1) // 2] != ne_type:
                            if ne_type in named_entities_for_text:
                                named_entities_for_text[ne_type].append(
                                    (bounds_of_tokens[ne_start][0], bounds_of_tokens[token_idx - 1][1])
                                )
                            else:
                                named_entities_for_text[ne_type] = [
                                    (bounds_of_tokens[ne_start][0], bounds_of_tokens[token_idx - 1][1])
                                ]
                            ne_start = token_idx
                            ne_type = classes_list[(class_id - 1) // 2]
            else:
                if ne_start >= 0:
                    if ne_type in named_entities_for_text:
                        named_entities_for_text[ne_type].append(
                            (bounds_of_tokens[ne_start][0], bounds_of_tokens[token_idx - 1][1])
                        )
                    else:
                        named_entities_for_text[ne_type] = [
                            (bounds_of_tokens[ne_start][0], bounds_of_tokens[token_idx - 1][1])
                        ]
                    ne_start = -1
                    ne_type = ''
        if ne_start >= 0:
            if ne_type in named_entities_for_text:
                named_entities_for_text[ne_type].append(
                    (bounds_of_tokens[ne_start][0], bounds_of_tokens[-1][1])
                )
            else:
                named_entities_for_text[ne_type] = [
                    (bounds_of_tokens[ne_start][0], bounds_of_tokens[-1][1])
                ]
        return named_entities_for_text

    @staticmethod
    def calculate_indices_of_named_entities(source_text: str, classes_list: tuple,
                                            named_entities: Dict[str, List[tuple]]) -> \
            Tuple[np.ndarray, Dict[int, int]]:
        indices_of_named_entities = np.zeros((len(source_text),), dtype=np.int32)
        labels_to_classes = dict()
        label_ID = 1
        for ne_type in sorted(list(named_entities.keys())):
            class_id = classes_list.index(ne_type) + 1
            for ne_bounds in named_entities[ne_type]:
                for char_idx in range(ne_bounds[0], ne_bounds[1]):
                    indices_of_named_entities[char_idx] = label_ID
                labels_to_classes[label_ID] = class_id
                label_ID += 1
        return indices_of_named_entities, labels_to_classes

    @staticmethod
    def detect_token_labels(bounds_of_tokens: List[tuple], indices_of_named_entities: np.ndarray, label_ids: dict,
                            max_seq_length: int) -> np.ndarray:
        res = np.zeros((max_seq_length,), dtype=np.int32)
        n = min(len(bounds_of_tokens), max_seq_length)
        for token_idx, cur in enumerate(bounds_of_tokens[:n]):
            distr = np.zeros((len(label_ids) + 1,), dtype=np.int32)
            for char_idx in range(cur[0], cur[1]):
                distr[indices_of_named_entities[char_idx]] += 1
            label_id = distr.argmax()
            if label_id > 0:
                res[token_idx] = label_id
            del distr
        prev_label_id = 0
        for token_idx in range(max_seq_length):
            cur_label_id = res[token_idx]
            if cur_label_id > 0:
                ne_id = label_ids[res[token_idx]]
                if cur_label_id == prev_label_id:
                    res[token_idx] = ne_id * 2 - 1
                else:
                    res[token_idx] = ne_id * 2
            prev_label_id = cur_label_id
        return res

    @staticmethod
    def get_shape_of_string(src: str) -> str:
        shape = ''
        for idx in range(len(src)):
            if src[idx] in {'_', chr(11791)}:
                new_char = '_'
            elif src[idx].isalpha():
                if src[idx].isupper():
                    new_char = 'A'
                else:
                    new_char = 'a'
            elif src[idx].isdigit():
                new_char = 'D'
            elif src[idx] in {'.', ',', ':', ';', '-', '+', '!', '?', '#', '@', '$', '&', '=', '^', '`', '~', '*', '/',
                              '\\', '(', ')', '[', ']', '{', '}', "'", '"', '|', '<', '>'}:
                new_char = 'P'
            elif src[idx] in {chr(8213), chr(8212), chr(8211), chr(8210), chr(8209), chr(8208), chr(11834), chr(173),
                              chr(8722), chr(8259)}:
                new_char = '-'
            elif src[idx] in {chr(8220), chr(8221), chr(11842), chr(171), chr(187), chr(128631), chr(128630),
                              chr(128632), chr(12318), chr(12317), chr(12319)}:
                new_char = '"'
            elif src[idx] in {chr(39), chr(8216), chr(8217), chr(8218)}:
                new_char = "'"
            else:
                new_char = 'U'
            if len(shape) == 0:
                shape += new_char
            elif shape[-1] != new_char:
                shape += new_char
        return shape

    @staticmethod
    def check_X(X: Union[list, tuple, np.array], X_name: str):
        if (not hasattr(X, '__len__')) or (not hasattr(X, '__getitem__')):
            raise ValueError('`{0}` is wrong, because it is not list-like object!'.format(X_name))
        if isinstance(X, np.ndarray):
            if len(X.shape) != 1:
                raise ValueError('`{0}` is wrong, because it is not 1-D list!'.format(X_name))
        n = len(X)
        for idx in range(n):
            if (not hasattr(X[idx], '__len__')) or (not hasattr(X[idx], '__getitem__')) or \
                    (not hasattr(X[idx], 'strip')) or (not hasattr(X[idx], 'split')):
                raise ValueError('Item {0} of `{1}` is wrong, because it is not string-like object!'.format(
                    idx, X_name))

    @staticmethod
    def check_Xy(X: Union[list, tuple, np.array], X_name: str, y: Union[list, tuple, np.array], y_name: str) -> tuple:
        ELMo_NER.check_X(X, X_name)
        if (not hasattr(y, '__len__')) or (not hasattr(y, '__getitem__')):
            raise ValueError('`{0}` is wrong, because it is not a list-like object!'.format(y_name))
        if isinstance(y, np.ndarray):
            if len(y.shape) != 1:
                raise ValueError('`{0}` is wrong, because it is not 1-D list!'.format(y_name))
        n = len(y)
        if n != len(X):
            raise ValueError('Length of `{0}` does not correspond to length of `{1}`! {2} != {3}'.format(
                X_name, y_name, len(X), len(y)))
        classes_list = set()
        for idx in range(n):
            if (not hasattr(y[idx], '__len__')) or (not hasattr(y[idx], 'items')) or (not hasattr(y[idx], 'keys')) or \
                    (not hasattr(y[idx], 'values')):
                raise ValueError('Item {0} of `{1}` is wrong, because it is not a dictionary-like object!'.format(
                    idx, y_name))
            for ne_type in sorted(list(y[idx].keys())):
                if (not hasattr(ne_type, '__len__')) or (not hasattr(ne_type, '__getitem__')) or \
                        (not hasattr(ne_type, 'strip')) or (not hasattr(ne_type, 'split')):
                    raise ValueError('Item {0} of `{1}` is wrong, because its key `{2}` is not a string-like '
                                     'object!'.format(idx, y_name, ne_type))
                if (ne_type == 'O') or (ne_type == 'o') or (ne_type == 'О') or (ne_type == 'о'):
                    raise ValueError('Item {0} of `{1}` is wrong, because its key `{2}` incorrectly specifies a named '
                                     'entity!'.format(idx, y_name, ne_type))
                if (not ne_type.isalpha()) or (not ne_type.isupper()):
                    raise ValueError('Item {0} of `{1}` is wrong, because its key `{2}` incorrectly specifies a named '
                                     'entity!'.format(idx, y_name, ne_type))
                classes_list.add(ne_type)
                if (not hasattr(y[idx][ne_type], '__len__')) or (not hasattr(y[idx][ne_type], '__getitem__')):
                    raise ValueError('Item {0} of `{1}` is wrong, because its value `{2}` is not a list-like '
                                     'object!'.format(idx, y_name, y[idx][ne_type]))
                for ne_bounds in y[idx][ne_type]:
                    if (not hasattr(ne_bounds, '__len__')) or (not hasattr(ne_bounds, '__getitem__')):
                        raise ValueError('Item {0} of `{1}` is wrong, because named entity bounds `{2}` are not '
                                         'specified as list-like object!'.format(idx, y_name, ne_bounds))
                    if len(ne_bounds) != 2:
                        raise ValueError('Item {0} of `{1}` is wrong, because named entity bounds `{2}` are not '
                                         'specified as 2-D list!'.format(idx, y_name, ne_bounds))
                    if (ne_bounds[0] < 0) or (ne_bounds[1] > len(X[idx])) or (ne_bounds[0] >= ne_bounds[1]):
                        raise ValueError('Item {0} of `{1}` is wrong, because named entity bounds `{2}` are '
                                         'incorrect!'.format(idx, y_name, ne_bounds))
        return tuple(sorted(list(classes_list)))

In [0]:
def load_dataset_from_json(file_name: str) -> Tuple[List[str], List[Dict[str, List[Tuple[int, int]]]]]:
    """Load texts and their named entity annotations from a JSON file.

    The file must contain a list of samples. Every sample is a dictionary with the required keys
    `text` (a string) and `named_entities` (a dictionary which maps an entity type to a list of
    `[start, end]` character bounds), and with the optional keys `paragraph_bounds` (a list of
    `[start, end]` character bounds) and `base_name`. Any other key is reported as an error.

    A sample with `paragraph_bounds` produces one item per paragraph, and bounds of its entities are
    recalculated relative to the start of the paragraph which contains them. A sample without this key
    produces a single item. In both cases overlapping or adjacent bounds of the same entity type are
    merged into a single span.

    :param file_name: name of the source JSON file.
    :return: a list of texts and a list of dictionaries `entity type -> [(start, end), ...]` of the same length.
    :raises ValueError: if the file does not exist or its content violates the described structure.
    """

    def merge_bounds(source_bounds) -> List[Tuple[int, int]]:
        # Unite overlapping or adjacent spans. `max` keeps the end of the wider span, so a span nested into
        # the previous one cannot shrink it.
        merged = []
        for cur_start, cur_end in sorted(source_bounds):
            if (len(merged) > 0) and (merged[-1][1] >= cur_start):
                merged[-1] = (merged[-1][0], max(merged[-1][1], cur_end))
            else:
                merged.append((cur_start, cur_end))
        return merged

    def bounds_are_wrong(checked_bounds, text_length: int) -> bool:
        # Correct bounds are a 2-element list [start, end] where 0 <= start < end <= text_length.
        if (not isinstance(checked_bounds, list)) or (len(checked_bounds) != 2):
            return True
        if (checked_bounds[0] < 0) or (checked_bounds[0] >= text_length):
            return True
        return (checked_bounds[1] <= checked_bounds[0]) or (checked_bounds[1] > text_length)

    if not os.path.isfile(file_name):
        raise ValueError('The file `{0}` does not exist!'.format(file_name))
    X = []
    y = []
    with codecs.open(file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        data = json.load(fp)
    if not isinstance(data, list):
        raise ValueError('The file `{0}` contains incorrect data! Expected a `{1}`, but got a `{2}`.'.format(
            file_name, type([1, 2]), type(data)))
    for sample_idx, sample_value in enumerate(data):
        # All messages about a wrong sample share this prefix.
        err_prefix = '{0} sample in the file `{1}` contains incorrect data! '.format(sample_idx, file_name)
        if not isinstance(sample_value, dict):
            raise ValueError(err_prefix + 'Expected a `{0}`, but got a `{1}`.'.format(
                type({'a': 1, 'b': 2}), type(sample_value)))
        if 'text' not in sample_value:
            raise ValueError(err_prefix + 'The key `text` is not found!')
        if 'named_entities' not in sample_value:
            raise ValueError(err_prefix + 'The key `named_entities` is not found!')
        allowed_keys = {'text', 'named_entities', 'base_name'}
        if 'paragraph_bounds' in sample_value:
            allowed_keys.add('paragraph_bounds')
            bounds_of_paragraphs = sample_value['paragraph_bounds']
        else:
            bounds_of_paragraphs = None
        excess_keys = sorted(list(set(sample_value.keys()) - allowed_keys))
        if len(excess_keys) > 0:
            raise ValueError(err_prefix + 'Keys {0} are excess!'.format(excess_keys))
        if ('paragraph_bounds' in sample_value) and (not isinstance(bounds_of_paragraphs, list)):
            # The message reports the type of the wrong value itself.
            raise ValueError(err_prefix + 'Value of `paragraph_bounds` must be a `{0}`, but it is a `{1}`.'.format(
                type([1, 2, 3]), type(bounds_of_paragraphs)))
        text = sample_value['text']
        if not isinstance(text, str):
            raise ValueError(err_prefix + 'Value of `text` must be a `{0}`, but it is a `{1}`.'.format(
                type('123'), type(text)))
        named_entities = sample_value['named_entities']
        if not isinstance(named_entities, dict):
            # The message reports the type of the wrong value itself.
            raise ValueError(err_prefix + 'Value of `named_entities` must be a `{0}`, but it is a `{1}`.'.format(
                type({'a': 1, 'b': 2}), type(named_entities)))
        for entity_type in named_entities:
            if not isinstance(entity_type, str):
                raise ValueError(err_prefix + 'Entity type `{0}` is wrong! Expected a `{1}`, but got a `{2}`.'.format(
                    entity_type, type('123'), type(entity_type)))
            for entity_bounds in named_entities[entity_type]:
                if bounds_are_wrong(entity_bounds, len(text)):
                    raise ValueError(err_prefix + '{0} is wrong value for entity bounds.'.format(entity_bounds))
        if bounds_of_paragraphs is None:
            X.append(text)
            y.append({entity_type: merge_bounds(named_entities[entity_type]) for entity_type in named_entities})
            continue
        for paragraph_bounds in bounds_of_paragraphs:
            if bounds_are_wrong(paragraph_bounds, len(text)):
                raise ValueError(err_prefix + '{0} is wrong value for paragraph bounds.'.format(paragraph_bounds))
        bounds_of_paragraphs = [tuple(cur) for cur in bounds_of_paragraphs]
        entities_by_paragraphs = [dict() for _ in bounds_of_paragraphs]
        for entity_type in named_entities:
            for entity_bounds in sorted(named_entities[entity_type]):
                # An entity must lie completely inside exactly one paragraph.
                paragraph_idx = find_paragraph(bounds_of_paragraphs, entity_bounds[0], entity_bounds[1])
                if paragraph_idx < 0:
                    raise ValueError(err_prefix + '{0} is wrong value for entity bounds.'.format(entity_bounds))
                paragraph_start = bounds_of_paragraphs[paragraph_idx][0]
                entities_by_paragraphs[paragraph_idx].setdefault(entity_type, []).append(
                    (entity_bounds[0] - paragraph_start, entity_bounds[1] - paragraph_start)
                )
        for (paragraph_start, paragraph_end), entities_of_paragraph in zip(bounds_of_paragraphs,
                                                                           entities_by_paragraphs):
            X.append(text[paragraph_start:paragraph_end])
            y.append({entity_type: merge_bounds(entities_of_paragraph[entity_type])
                      for entity_type in entities_of_paragraph})
    return X, y

In [0]:
def factrueval2016_to_json(src_dir_name: str, dst_json_name: str, split_by_paragraphs: bool=True):
    """Convert a directory with FactRuEval-2016 annotation files into a single JSON file.

    Every document is described by the files `<name>.objects`, `<name>.spans`, `<name>.tokens` and `<name>.txt`.
    For each document the character bounds of its named entities are calculated from tokens of all spans
    which form an object, and the result is saved as a dictionary with the keys `text`, `named_entities`,
    `paragraph_bounds` and `base_name`.

    :param src_dir_name: directory with the source FactRuEval-2016 files.
    :param dst_json_name: name of the resulting JSON file.
    :param split_by_paragraphs: if True, tokens are loaded by paragraphs, otherwise by sentences.
    :raises ValueError: if some document has an incomplete set of files.
    """
    annotation_suffixes = ('.objects', '.spans', '.tokens')
    files_by_documents = dict()
    for file_name in os.listdir(src_dir_name):
        found_suffix = next((cur for cur in annotation_suffixes if file_name.endswith(cur)), None)
        if found_suffix is None:
            continue
        document_name = file_name[:-len(found_suffix)]
        if document_name in files_by_documents:
            assert file_name not in files_by_documents[document_name]
            files_by_documents[document_name].append(file_name)
        else:
            files_by_documents[document_name] = [file_name]
    # each document needs all three annotation files and the text file
    for document_name in files_by_documents:
        if len(files_by_documents[document_name]) != 3:
            raise ValueError('Files list for `{0}` is wrong!'.format(document_name))
        text_file_name = os.path.join(src_dir_name, document_name + '.txt')
        if not os.path.isfile(text_file_name):
            raise ValueError('File `{0}` does not exist!'.format(text_file_name))
        files_by_documents[document_name] = sorted(files_by_documents[document_name] + [text_file_name])
    documents = []
    for document_name in sorted(list(files_by_documents.keys())):
        tokens_file_name = os.path.join(src_dir_name, document_name + '.tokens')
        if split_by_paragraphs:
            tokens, text, paragraphs = load_tokens_from_factrueval2016_by_paragraphs(
                os.path.join(src_dir_name, document_name + '.txt'), tokens_file_name
            )
        else:
            tokens, text, paragraphs = load_tokens_from_factrueval2016_by_sentences(tokens_file_name)
        spans = load_spans_from_factrueval2016(os.path.join(src_dir_name, document_name + '.spans'), tokens)
        objects = load_objects_from_factrueval2016(os.path.join(src_dir_name, document_name + '.objects'), spans)
        named_entities = dict()
        for ne_type, span_ids in objects.values():
            token_ids = set()
            for span_id in span_ids:
                token_ids |= set(spans[span_id])
            if len(token_ids) == 0:
                continue
            # an entity covers the text from the earliest start to the latest end of its tokens
            entity_start = min(tokens[token_id][0] for token_id in token_ids)
            entity_end = max(tokens[token_id][1] for token_id in token_ids)
            named_entities.setdefault(ne_type, []).append((entity_start, entity_end))
        documents.append({'text': text, 'named_entities': named_entities, 'paragraph_bounds': paragraphs,
                          'base_name': document_name})
    with codecs.open(dst_json_name, mode='w', encoding='utf-8', errors='ignore') as fp:
        json.dump(documents, fp, indent=4, ensure_ascii=False)

In [0]:
def load_tokens_from_factrueval2016_by_paragraphs(text_file_name: str, tokens_file_name: str) -> \
        Tuple[Dict[int, Tuple[int, int, str]], str, tuple]:
    """Load tokens of a document and find bounds of its paragraphs.

    Each non-empty line of the tokens file must consist of four whitespace-separated fields: token ID,
    start position, token length and token text. The text of the document is rebuilt from the tokens
    (gaps between tokens are filled with spaces). Each non-empty line of the text file is considered
    as a paragraph; a new paragraph starts when the current token cannot be found in the rest of
    the current paragraph (the comparison is case-insensitive).

    :param text_file_name: name of the file with the source text of the document.
    :param tokens_file_name: name of the file with tokens of the document.
    :return: a dictionary `token ID -> (start, end, text)`, the rebuilt text, and a tuple of `(start, end)`
        paragraph bounds in the rebuilt text after stripping of whitespace.
    :raises ValueError: if a line of the tokens file is wrong, if a token cannot be found in any remaining
        paragraph, or if the text file has no paragraphs although tokens are specified.
    """

    def to_int(value: str) -> int:
        # A malformed number becomes -1, and all checks below consider negative values as wrong.
        try:
            return int(value)
        except ValueError:
            return -1

    source_text = ''
    start_pos = 0
    tokens_and_their_bounds = dict()
    bounds_of_paragraphs = []
    texts_of_paragraphs = []
    with codecs.open(text_file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        for cur_line in fp:
            prep_line = cur_line.strip()
            if len(prep_line) > 0:
                texts_of_paragraphs.append(prep_line.lower())
    paragraph_idx = 0
    paragraph_pos = 0
    with codecs.open(tokens_file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        for line_idx, cur_line in enumerate(fp, start=1):
            prep_line = cur_line.strip()
            if len(prep_line) == 0:
                continue
            err_msg = 'File `{0}`: line {1} is wrong!'.format(tokens_file_name, line_idx)
            parts_of_line = prep_line.split()
            if len(parts_of_line) != 4:
                raise ValueError(err_msg)
            token_id = to_int(parts_of_line[0])
            if token_id < 0:
                raise ValueError(err_msg)
            token_start = to_int(parts_of_line[1])
            # a token cannot start before the end of the already rebuilt text
            if token_start < len(source_text):
                raise ValueError(err_msg)
            token_len = to_int(parts_of_line[2])
            if token_len < 0:
                raise ValueError(err_msg)
            token_text = parts_of_line[3].strip()
            if len(token_text) != token_len:
                raise ValueError(err_msg)
            if token_id in tokens_and_their_bounds:
                raise ValueError(err_msg)
            if len(texts_of_paragraphs) == 0:
                # without this check the paragraph lookup below fails with an obscure IndexError
                raise ValueError('File `{0}` contains no text, but file `{1}` contains tokens!'.format(
                    text_file_name, tokens_file_name))
            source_text += ' ' * (token_start - len(source_text))
            source_text += token_text
            tokens_and_their_bounds[token_id] = (
                token_start, token_start + token_len,
                token_text
            )
            found_idx_in_paragraph = texts_of_paragraphs[paragraph_idx][paragraph_pos:].find(token_text.lower())
            if found_idx_in_paragraph < 0:
                # The token is not in the rest of the current paragraph, so the next paragraphs are checked.
                # Bounds are registered for each passed paragraph, so a skipped paragraph gets empty bounds.
                # NOTE(review): `paragraph_pos` stays 0 after the token is found in a new paragraph, i.e. the
                # found token is not skipped in the next search - confirm that it is intended.
                paragraph_idx += 1
                paragraph_pos = 0
                while paragraph_idx < len(texts_of_paragraphs):
                    if len(bounds_of_paragraphs) == 0:
                        bounds_of_paragraphs.append((0, start_pos))
                    else:
                        bounds_of_paragraphs.append((bounds_of_paragraphs[-1][1], start_pos))
                    found_idx_in_paragraph = texts_of_paragraphs[paragraph_idx].find(token_text.lower())
                    if found_idx_in_paragraph >= 0:
                        break
                    paragraph_idx += 1
                if paragraph_idx >= len(texts_of_paragraphs):
                    raise ValueError(err_msg)
            else:
                paragraph_pos += (found_idx_in_paragraph + len(token_text))
            start_pos = len(source_text)
    if len(texts_of_paragraphs) > 0:
        # the last paragraph lasts to the end of the rebuilt text
        if len(bounds_of_paragraphs) > 0:
            bounds_of_paragraphs.append((bounds_of_paragraphs[-1][1], start_pos))
        else:
            bounds_of_paragraphs.append((0, start_pos))
    bounds_of_paragraphs_after_strip = []
    for cur_bounds in bounds_of_paragraphs:
        if cur_bounds[0] < cur_bounds[1]:
            # exclude leading and trailing whitespace from the paragraph
            source_paragraph_text = source_text[cur_bounds[0]:cur_bounds[1]]
            paragraph_text_after_strip = source_paragraph_text.strip()
            found_idx = source_paragraph_text.find(paragraph_text_after_strip)
            if found_idx > 0:
                paragraph_start = cur_bounds[0] + found_idx
            else:
                paragraph_start = cur_bounds[0]
            paragraph_end = paragraph_start + len(paragraph_text_after_strip)
            bounds_of_paragraphs_after_strip.append((paragraph_start, paragraph_end))
        else:
            bounds_of_paragraphs_after_strip.append(cur_bounds)
    return tokens_and_their_bounds, source_text, tuple(bounds_of_paragraphs_after_strip)

In [0]:
def load_spans_from_factrueval2016(spans_file_name: str,
                                   tokens_dict: Dict[int, Tuple[int, int, str]]) -> Dict[int, List[int]]:
    """Load spans of a document as lists of token IDs.

    Each non-empty line of the file must contain at least nine whitespace-separated fields. The first field
    is a span ID. The fields after the `#` separator are divided into two halves of equal size, and the first
    half is read as IDs of tokens which form the span. If the same span ID occurs several times, only its
    first description is used.

    :param spans_file_name: name of the file with spans.
    :param tokens_dict: known tokens of the document; each token ID of a span must be a key of this dictionary.
    :return: a dictionary `span ID -> list of token IDs`.
    :raises ValueError: if a line is wrong (bad number, no `#` separator, unknown or repeated token ID).
    """
    spans = dict()
    with codecs.open(spans_file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        for line_idx, cur_line in enumerate(fp, start=1):
            prep_line = cur_line.strip()
            if len(prep_line) == 0:
                continue
            err_msg = 'File `{0}`: line {1} is wrong!'.format(spans_file_name, line_idx)
            parts_of_line = prep_line.split()
            if len(parts_of_line) < 9:
                raise ValueError(err_msg)
            try:
                span_id = int(parts_of_line[0])
            except ValueError:
                span_id = -1
            if span_id < 0:
                raise ValueError(err_msg)
            if span_id in spans:
                continue
            if '#' not in parts_of_line:
                raise ValueError(err_msg)
            found_idx = parts_of_line.index('#')
            n_after_separator = len(parts_of_line) - 1 - found_idx
            # token IDs and the same number of trailing fields follow the separator
            if (n_after_separator < 2) or (n_after_separator % 2 != 0):
                raise ValueError(err_msg)
            n = n_after_separator // 2
            token_IDs = []
            for cur_part in parts_of_line[(found_idx + 1):(found_idx + n + 1)]:
                try:
                    new_token_ID = int(cur_part)
                except ValueError:
                    new_token_ID = -1
                if (new_token_ID < 0) or (new_token_ID in token_IDs) or (new_token_ID not in tokens_dict):
                    raise ValueError(err_msg)
                token_IDs.append(new_token_ID)
            spans[span_id] = token_IDs
    return spans

In [0]:
def load_objects_from_factrueval2016(objects_file_name: str,
                                     spans_dict: Dict[int, List[int]]) -> Dict[int, Tuple[str, List[int]]]:
    """Load objects (named entities) of a document as pairs of an entity type and a list of span IDs.

    Each non-empty line of the file must contain at least five whitespace-separated fields: an object ID,
    an object type, one or more span IDs, the `#` separator and some trailing fields. Only objects of
    the types PERSON, LOCATION, ORG and LOCORG (case-insensitive) are loaded, and LOCORG is stored as LOCATION.
    Lines with other object types are skipped after the check of their object ID.

    :param objects_file_name: name of the file with objects.
    :param spans_dict: known spans of the document; each span ID of an object must be a key of this dictionary.
    :return: a dictionary `object ID -> (entity type, list of span IDs)`.
    :raises ValueError: if a line is wrong (bad or repeated object ID, no `#` separator, unknown or
        repeated span ID).
    """
    objects = dict()
    with codecs.open(objects_file_name, mode='r', encoding='utf-8', errors='ignore') as fp:
        for line_idx, cur_line in enumerate(fp, start=1):
            prep_line = cur_line.strip()
            if len(prep_line) == 0:
                continue
            err_msg = 'File `{0}`: line {1} is wrong!'.format(objects_file_name, line_idx)
            parts_of_line = prep_line.split()
            if len(parts_of_line) < 5:
                raise ValueError(err_msg)
            try:
                object_id = int(parts_of_line[0])
            except ValueError:
                object_id = -1
            if (object_id < 0) or (object_id in objects):
                raise ValueError(err_msg)
            ne_type = parts_of_line[1].upper()
            if ne_type not in {'PERSON', 'LOCATION', 'ORG', 'LOCORG'}:
                continue
            if ne_type == 'LOCORG':
                ne_type = 'LOCATION'
            found_idx = parts_of_line.index('#') if '#' in parts_of_line else -1
            # at least one span ID must be located between the object type and the separator
            if found_idx < 3:
                raise ValueError(err_msg)
            span_IDs = []
            for cur_part in parts_of_line[2:found_idx]:
                try:
                    new_span_ID = int(cur_part)
                except ValueError:
                    new_span_ID = -1
                if (new_span_ID < 0) or (new_span_ID not in spans_dict) or (new_span_ID in span_IDs):
                    raise ValueError(err_msg)
                span_IDs.append(new_span_ID)
            objects[object_id] = (ne_type, span_IDs)
    return objects

In [0]:
def find_paragraph(bounds_of_paragraphs: List[Tuple[int, int]], entity_start_idx: int, entity_end_idx: int) -> int:
    """Find the first paragraph which completely contains the specified non-empty entity.

    :param bounds_of_paragraphs: `(start, end)` bounds of paragraphs, where the end is exclusive.
    :param entity_start_idx: index of the first character of the entity.
    :param entity_end_idx: index of the character after the last character of the entity.
    :return: index of the found paragraph or -1 if there is no such paragraph.
    """
    for candidate_idx, cur_bounds in enumerate(bounds_of_paragraphs):
        if not (cur_bounds[0] <= entity_start_idx < cur_bounds[1]):
            continue
        if entity_start_idx < entity_end_idx <= cur_bounds[1]:
            return candidate_idx
    return -1

In [0]:
def split_dataset(y: Union[list, tuple, np.array], test_part: float, n_restarts: int=10,
                  logger: Union[Logger, None]=None) -> Tuple[np.ndarray, np.ndarray]:
    """Randomly split a data set into training and testing parts with the same sets of entity classes.

    Indices of samples are shuffled up to `n_restarts` times until the set of entity classes in the testing
    part becomes equal to the set of entity classes in the training part. If no such split is found, then
    the last found split whose testing classes are a proper subset of the training classes is used, and if
    there is no such split either, then the last shuffling is used. In both of these cases a warning is
    emitted through the `logger`, or through the `warnings` module if the logger is not specified.

    :param y: labels of samples; each item is a dictionary whose keys are entity classes.
    :param test_part: fraction of samples for the testing part.
    :param n_restarts: maximal number of shufflings, it must be greater than 1.
    :param logger: logger for warnings, or None.
    :return: sorted indices of training samples and sorted indices of testing samples.
    :raises ValueError: if `n_restarts` is less than 2, if there are fewer than two samples, or if
        one of the parts would be empty.
    """
    # Local import, because `warnings` is not imported at the top of the notebook, and without it
    # the warning branch with `logger=None` failed with a NameError.
    import warnings

    def collect_classes(sample_indices) -> set:
        found_classes = set()
        for sample_idx in sample_indices:
            found_classes |= set(y[sample_idx].keys())
        return found_classes

    def warn_about_split():
        if logger is None:
            warnings.warn('Data set cannot be splitted by stratified folds.')
        else:
            logger.warning('Data set cannot be splitted by stratified folds.')

    if n_restarts < 2:
        raise ValueError('{0} is too small value of restarts number. It must be greater than 1.'.format(n_restarts))
    n_samples = len(y)
    if n_samples < 2:
        raise ValueError('There are too few samples in the data set! Minimal number of samples is 2.')
    n_test = int(round(test_part * n_samples))
    n_train = n_samples - n_test
    if n_test < 1:
        raise ValueError('{0} is too small value of the test part! There are no samples for '
                         'testing subset!'.format(test_part))
    if n_train < 1:
        raise ValueError('{0} is too large value of the test part! There are no samples for '
                         'training subset!'.format(test_part))
    indices = np.arange(0, n_samples, 1, dtype=np.int32)
    best_indices = None
    is_stratified = False
    for _ in range(n_restarts):
        np.random.shuffle(indices)
        classes_for_training = collect_classes(indices[0:n_train])
        classes_for_testing = collect_classes(indices[n_train:])
        if classes_for_training == classes_for_testing:
            best_indices = np.copy(indices)
            is_stratified = True
            break
        # a split whose testing classes are all known to the training part is the second-best choice
        if classes_for_testing < classes_for_training:
            best_indices = np.copy(indices)
    if not is_stratified:
        warn_about_split()
    if best_indices is None:
        best_indices = indices
    return np.sort(best_indices[0:n_train]), np.sort(best_indices[n_train:])

In [0]:
#!unzip factRuEval-2016-master.zip

In [0]:
#!pip install tensorflow==1.15.0

In [0]:
#!pip install spacy_udpipe

In [0]:
#!pip install pymorphy2==0.8

In [18]:
%%time
# Convert the FactRuEval-2016 devset into a single JSON file and load it back as a list of texts `X`
# and a list of named entity annotations `y`.
factrueval2016_to_json("factRuEval-2016-master/devset", "factrueval2016devset_to_json.json")
X, y = load_dataset_from_json("factrueval2016devset_to_json.json")

# Create the recognizer: ELMo is loaded from the specified TF-Hub handle and is not fine-tuned, and
# additional features are switched off.
recognizer = ELMo_NER(
    
            finetune_elmo=False, 
            batch_size=16, 
            l2_reg=1e-2, 
            max_seq_length=200,
            elmo_hub_module_handle='http://files.deeppavlov.ai/deeppavlov_data/elmo_ru-news_wmt11-16_1.5M_steps.tar.gz', 
            validation_fraction=0.25, 
            max_epochs=1,
            patience=10, 
            gpu_memory_frac=0.9, 
            verbose=True, 
            random_seed=42, 
            lr=1e-2, 
            udpipe_lang='ru',
            use_additional_features=False

        )

CPU times: user 308 ms, sys: 15.1 ms, total: 323 ms
Wall time: 327 ms


In [0]:
# Logger named after the current module (`__name__`).
# NOTE(review): presumably used by the ELMo_NER class for its messages - confirm against the class code.
elmo_ner_logger = logging.getLogger(__name__)

In [0]:
# Names of part-of-speech tags, sorted alphabetically. The list contains both `CCONJ` and `CONJ`.
# NOTE(review): presumably the position of a tag in this list is used as a feature index - confirm.
UNIVERSAL_POS_TAGS = ['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'CONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON',
                      'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']

In [0]:
# Names of syntactic dependency relations in lower case, without `:` and `-` characters, sorted alphabetically.
# NOTE(review): presumably these names are matched against results of `prepare_dependency_tag` - confirm.
UNIVERSAL_DEPENDENCIES = ['acl', 'advcl', 'advmod', 'amod', 'appos', 'aux', 'auxpass', 'case', 'cc', 'ccomp',
                          'compound', 'conj', 'cop', 'csubj', 'csubjpass', 'dep', 'det', 'discourse', 'dislocated',
                          'dobj', 'expl', 'fixed', 'flat', 'foreign', 'goeswith', 'gov', 'iobj', 'list', 'mark', 'mwe',
                          'name', 'neg', 'nmod', 'nsubj', 'nsubjpass', 'nummod', 'obj', 'obl', 'orphan', 'parataxis',
                          'pass', 'punct', 'relcl', 'remnant', 'reparandum', 'root', 'vocative', 'xcomp']

In [0]:
def create_udpipe_pipeline(lang: str) -> UDPipeLanguage:
    """Load the UDPipe pipeline for the specified language, downloading its model if necessary.

    :param lang: language identifier which is passed to `spacy_udpipe.load` and `spacy_udpipe.download`.
    :return: the loaded pipeline.
    :raises ValueError: if the pipeline is still not available after the download.
    """
    try:
        pipeline = spacy_udpipe.load(lang)
    except Exception:
        # The first attempt is expected to fail when the model is not downloaded yet, so any ordinary
        # error triggers the download. KeyboardInterrupt and SystemExit are not intercepted.
        spacy_udpipe.download(lang)
        pipeline = spacy_udpipe.load(lang)
    if pipeline is None:
        raise ValueError('The `{0}` language cannot be loaded for the UDPipe!'.format(lang))
    return pipeline

In [0]:
def prepare_dependency_tag(source_tag: str) -> Set[str]:
    """Build the set of normalized variants of a syntactic dependency tag.

    The result contains the lower-cased tag without `:` and `-` characters and, additionally, each
    non-empty part of the tag between these characters (stripped and lower-cased). For example,
    `nsubj:pass` gives `{'nsubjpass', 'nsubj', 'pass'}`.

    :param source_tag: source name of the dependency relation.
    :return: the set of normalized tag variants.
    """
    # The raw string keeps the same pattern without an invalid `\-` escape sequence in a usual string literal.
    re_for_splitting = re.compile(r'[:\-]+')
    tags = {source_tag.lower().replace(':', '').replace('-', '')}
    for cur_part in re_for_splitting.split(source_tag):
        prepared_part = cur_part.strip().lower()
        if len(prepared_part) > 0:
            tags.add(prepared_part)
    return tags

In [0]:
def calculate_prediction_quality(true_entities: Union[list, tuple, np.array],
                                 predicted_entities: List[Dict[str, List[Tuple[int, int]]]], classes_list: tuple) -> \
        Tuple[float, float, float, Dict[str, Tuple[float, float, float]]]:
    """Calculate character-level F1, precision and recall of named entity recognition.

    For every class and every sample, true and predicted entities of this class are paired with
    `find_pairs_of_named_entities` using similarities from `calc_similarity_between_entities`. True positives,
    false positives and false negatives of the found pairs are summed; the whole length of an unpaired
    true entity is added to false negatives, and the whole length of an unpaired predicted entity is added
    to false positives.

    :param true_entities: gold annotations; each item maps an entity class to a list of `(start, end)` bounds.
    :param predicted_entities: predicted annotations in the same format, at least as long as `true_entities`.
    :param classes_list: entity classes which are taken into account.
    :return: total F1, total precision, total recall, and a dictionary `class -> (F1, precision, recall)`.
    """

    def as_sorted_tuples(entities_of_sample) -> dict:
        return {
            cur_class: sorted((cur_bounds[0], cur_bounds[1]) for cur_bounds in entities_of_sample[cur_class])
            for cur_class in entities_of_sample
        }

    def total_length(bounds_of_entities) -> int:
        return sum(cur_end - cur_start for cur_start, cur_end in bounds_of_entities)

    def calculate_metrics(tp_: int, fp_: int, fn_: int) -> Tuple[float, float, float]:
        # precision and recall are zero if nothing was recognized correctly
        if tp_ > 0:
            precision_ = tp_ / float(tp_ + fp_)
            recall_ = tp_ / float(tp_ + fn_)
        else:
            precision_ = 0.0
            recall_ = 0.0
        if (precision_ + recall_) > 0.0:
            f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
        else:
            f1_ = 0.0
        return f1_, precision_, recall_

    gold_samples = []
    predicted_samples = []
    for sample_idx in range(len(true_entities)):
        gold_samples.append(as_sorted_tuples(true_entities[sample_idx]))
        predicted_samples.append(as_sorted_tuples(predicted_entities[sample_idx]))
    quality_by_entity_classes = dict()
    tp_total = 0
    fp_total = 0
    fn_total = 0
    for ne_class in classes_list:
        tp_for_ne = 0
        fp_for_ne = 0
        fn_for_ne = 0
        for gold_sample, predicted_sample in zip(gold_samples, predicted_samples):
            is_in_gold = ne_class in gold_sample
            is_in_predicted = ne_class in predicted_sample
            if is_in_gold and is_in_predicted:
                gold_bounds = gold_sample[ne_class]
                predicted_bounds = predicted_sample[ne_class]
                # only intersected entities are candidates for pairing
                similarity_dict = dict()
                for gold_idx, cur_gold in enumerate(gold_bounds):
                    for predicted_idx, cur_predicted in enumerate(predicted_bounds):
                        similarity, tp, fp, fn = calc_similarity_between_entities(cur_gold, cur_predicted)
                        if tp > 0:
                            similarity_dict[(gold_idx, predicted_idx)] = (similarity, tp, fp, fn)
                _, pairs = find_pairs_of_named_entities(list(range(len(gold_bounds))),
                                                        list(range(len(predicted_bounds))), similarity_dict)
                for cur_pair in pairs:
                    tp_for_ne += similarity_dict[cur_pair][1]
                    fp_for_ne += similarity_dict[cur_pair][2]
                    fn_for_ne += similarity_dict[cur_pair][3]
                paired_gold = set(cur_pair[0] for cur_pair in pairs)
                paired_predicted = set(cur_pair[1] for cur_pair in pairs)
                fn_for_ne += total_length(
                    cur_gold for gold_idx, cur_gold in enumerate(gold_bounds) if gold_idx not in paired_gold
                )
                fp_for_ne += total_length(
                    cur_predicted for predicted_idx, cur_predicted in enumerate(predicted_bounds)
                    if predicted_idx not in paired_predicted
                )
            elif is_in_gold:
                fn_for_ne += total_length(gold_sample[ne_class])
            elif is_in_predicted:
                fp_for_ne += total_length(predicted_sample[ne_class])
        tp_total += tp_for_ne
        fp_total += fp_for_ne
        fn_total += fn_for_ne
        quality_by_entity_classes[ne_class] = calculate_metrics(tp_for_ne, fp_for_ne, fn_for_ne)
    f1, precision, recall = calculate_metrics(tp_total, fp_total, fn_total)
    return f1, precision, recall, quality_by_entity_classes

In [0]:
def calc_similarity_between_entities(gold_entity: Tuple[int, int], predicted_entity: Tuple[int, int]) -> \
        Tuple[float, int, int, int]:
    """Compare a gold entity with a predicted one, both given as (start, end) bounds.

    :param gold_entity: bounds of the reference entity (end is exclusive).
    :param predicted_entity: bounds of the recognized entity (end is exclusive).
    :return: a tuple (similarity, tp, fp, fn), where tp is the length of the common part,
        fp is the part of the prediction lying outside the gold entity, fn is the part of
        the gold entity not covered by the prediction, and similarity is tp / (tp + fp + fn)
        (0.0 when the two entities do not overlap).
    """
    gold_start, gold_end = gold_entity[0], gold_entity[1]
    pred_start, pred_end = predicted_entity[0], predicted_entity[1]
    gold_length = gold_end - gold_start
    pred_length = pred_end - pred_start

    # Disjoint (or merely touching) entities: the whole prediction is spurious
    # and the whole gold entity is missed.
    if (gold_end <= pred_start) or (pred_end <= gold_start):
        return 0.0, 0, pred_length, gold_length

    # The entities intersect: everything follows from the length of the common part.
    tp = min(gold_end, pred_end) - max(gold_start, pred_start)
    fp = pred_length - tp
    fn = gold_length - tp
    return tp / float(tp + fp + fn), tp, fp, fn

In [0]:
def find_pairs_of_named_entities(true_entities: List[int], predicted_entities: List[int],
                                 similarity_dict: Dict[Tuple[int, int], Tuple[float, int, int, int]],
                                 max_combinations: int = 10) -> \
        Tuple[float, List[Tuple[int, int]]]:
    """Find a one-to-one matching between gold and predicted entities with a large total similarity.

    Only the lengths of `true_entities` and `predicted_entities` are used; the entities are
    addressed by their indices.

    Two kinds of candidate matchings are compared and the one with the largest similarity sum wins
    (an earlier candidate is kept on ties):

    1. Order-preserving matchings. When the two lists have the same length this is the diagonal
       (i, i); otherwise the first `max_combinations` combinations (in lexicographic order) of
       indices of the longer list are aligned with the shorter list.
    2. A greedy matching: every entity of the shorter list, in order, takes the still unused
       entity of the longer list it is most similar to. An entity without any candidate is simply
       left unmatched and the search goes on with the next one.

    :param true_entities: gold entities of one class in one text.
    :param predicted_entities: predicted entities of the same class in the same text.
    :param similarity_dict: maps (true_index, predicted_index) to a tuple whose first item is the
        similarity of the two entities; pairs missing from the dictionary are never matched.
    :param max_combinations: how many order-preserving combinations to try at most.
    :return: (sum of similarities over the chosen pairs, list of (true_index, predicted_index) pairs).
    """
    from itertools import combinations, islice  # local import keeps this cell self-contained

    n_true = len(true_entities)
    n_predicted = len(predicted_entities)
    # The "outer" side is the shorter list: each of its items gets at most one partner
    # from the "inner" (longer) list.
    outer_is_true = n_true <= n_predicted
    if outer_is_true:
        n_outer, n_inner = n_true, n_predicted
    else:
        n_outer, n_inner = n_predicted, n_true

    def as_pair(outer_idx: int, inner_idx: int) -> Tuple[int, int]:
        """Build a (true_index, predicted_index) key whatever the iteration side is."""
        return (outer_idx, inner_idx) if outer_is_true else (inner_idx, outer_idx)

    def total_similarity(pairs: List[Tuple[int, int]]) -> float:
        """Sum the similarities of the given pairs (0.0 for an empty list)."""
        return sum((similarity_dict[pair][0] for pair in pairs), 0.0)

    if n_true == n_predicted:
        # The diagonal is accepted unconditionally as the starting point.
        best_pairs = [(idx, idx) for idx in range(n_true) if (idx, idx) in similarity_dict]
        best_similarity_sum = total_similarity(best_pairs)
    else:
        best_pairs = []
        best_similarity_sum = 0.0
        for c in islice(combinations(range(n_inner), n_outer), max_combinations):
            pairs = [pair for pair in (as_pair(idx, c[idx]) for idx in range(n_outer))
                     if pair in similarity_dict]
            similarity_sum = total_similarity(pairs)
            if similarity_sum > best_similarity_sum:
                best_similarity_sum = similarity_sum
                best_pairs = pairs

    # Greedy matching. It can find pairs that cross the order-preserving alignment,
    # e.g. when the first gold entity was missed and the second one matches the first prediction.
    greedy_pairs = []
    used_inner = set()
    for outer_idx in range(n_outer):
        best_inner_idx = None
        best_similarity = -1.0
        for inner_idx in range(n_inner):
            if inner_idx in used_inner:
                continue
            candidate = as_pair(outer_idx, inner_idx)
            if (candidate in similarity_dict) and (similarity_dict[candidate][0] > best_similarity):
                best_similarity = similarity_dict[candidate][0]
                best_inner_idx = inner_idx
        if best_inner_idx is None:
            # Nothing overlaps with this entity: leave it unmatched, but keep matching the rest.
            continue
        used_inner.add(best_inner_idx)
        greedy_pairs.append(as_pair(outer_idx, best_inner_idx))
    greedy_sum = total_similarity(greedy_pairs)
    if greedy_sum > best_similarity_sum:
        best_similarity_sum = greedy_sum
        best_pairs = greedy_pairs

    return best_similarity_sum, best_pairs

In [0]:
def comb(n: int, k: int):
    """Generate all k-element combinations of the indices 0..n-1 in lexicographic order.

    Every yielded combination is a fresh, increasingly sorted list, so the results can be
    stored safely (e.g. `list(comb(n, k))`).

    :param n: size of the index range to choose from.
    :param k: number of indices in each combination.
    :return: a generator of lists of indices; nothing is generated when k exceeds n,
        and a single empty list is generated when k is not positive.
    """
    if k > max(n, 0):
        # It is impossible to choose more indices than there are.
        return
    d = list(range(0, k))
    yield list(d)
    while True:
        # Find the rightmost position that has not reached its maximal value n - k + i yet.
        i = k - 1
        while i >= 0 and d[i] + k - i + 1 > n:
            i -= 1
        if i < 0:
            return
        # Advance it and reset everything to the right of it to the smallest valid values.
        d[i] += 1
        for j in range(i + 1, k):
            d[j] = d[j - 1] + 1
        yield list(d)

In [28]:
%%time
# Fit the recognizer (an ELMo_NER instance, see the repr printed below) on X and y.
# In the recorded run this cell took about 18 minutes of wall time.
recognizer.fit(X, y)

Downloaded pre-trained UDPipe model for 'ru' language
INFO:tensorflow:Saver not created because there are no variables in the graph to restore


INFO:tensorflow:Saver not created because there are no variables in the graph to restore


The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.



Instructions for updating:
Use keras.layers.Dense instead.


Instructions for updating:
Use keras.layers.Dense instead.


Instructions for updating:
Please use `layer.__call__` method instead.


Instructions for updating:
Please use `layer.__call__` method instead.


Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API


Instructions for updating:
Please use `keras.layers.RNN(cell)`, which is equivalent to this API


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


INFO:tensorflow:Restoring parameters from /tmp/tmpe1284mr9elmo_crf.ckpt


INFO:tensorflow:Restoring parameters from /tmp/tmpe1284mr9elmo_crf.ckpt


CPU times: user 27min 54s, sys: 57.1 s, total: 28min 51s
Wall time: 18min 17s


ELMo_NER(batch_size=16,
         elmo_hub_module_handle='http://files.deeppavlov.ai/deeppavlov_data/elmo_ru-news_wmt11-16_1.5M_steps.tar.gz',
         finetune_elmo=False, gpu_memory_frac=0.9, l2_reg=0.01, lr=0.01,
         max_epochs=1, max_seq_length=200, patience=10, random_seed=42,
         udpipe_lang='ru', use_additional_features=False,
         validation_fraction=0.25, verbose=True)

In [29]:
%%time
# Convert the FactRuEval-2016 test set into one JSON file
# (factrueval2016_to_json is defined elsewhere in this notebook; presumably it writes the second argument).
factrueval2016_to_json("factRuEval-2016-master/testset", "factrueval2016testset_to_json.json")

# Raw documents: each item is read below through its 'base_name', 'text' and 'paragraph_bounds' keys.
with open("factrueval2016testset_to_json.json", 'r') as fp:
  data_for_testing = json.load(fp)

# Only the second value returned by load_dataset_from_json is kept here.
_, true_entities = load_dataset_from_json("factrueval2016testset_to_json.json")

# texts[i] is the text of one paragraph; additional_info[i] holds the path of the '.task1' response file
# of its document and the paragraph bounds inside the document text.
texts = []
additional_info = []
for cur_document in data_for_testing:
  base_name = os.path.join("FactRuEval2016_results/results_of_elmo_and_crf", cur_document['base_name'] + '.task1')
  for cur_paragraph in cur_document['paragraph_bounds']:
    texts.append(cur_document['text'][cur_paragraph[0]:cur_paragraph[1]])
    additional_info.append((base_name, cur_paragraph))

CPU times: user 604 ms, sys: 44.6 ms, total: 649 ms
Wall time: 1.2 s


In [30]:
%%time
# Recognize entities in every test paragraph. The next cell pairs predicted_entities[i] with
# additional_info[i], so the result is expected to be parallel to `texts`.
predicted_entities = recognizer.predict(texts)

CPU times: user 37min 24s, sys: 40.1 s, total: 38min 4s
Wall time: 19min 53s


In [0]:
#!mkdir FactRuEval2016_results
#!mkdir FactRuEval2016_results/results_of_elmo_and_crf

In [32]:
%%time
# Entity labels produced by the recognizer -> tags expected in FactRuEval-2016 '.task1' files.
factrueval_tag_by_entity_type = {'ORG': 'org', 'PERSON': 'per', 'LOCATION': 'loc'}

# Path of the response file -> list of (tag, start in the document, length) for that document.
results_for_factrueval_2016 = dict()
for sample_idx, cur_result in enumerate(predicted_entities):
    base_name, paragraph_bounds = additional_info[sample_idx]
    # Register the document even if nothing was recognized in it: an empty response file must
    # still be written, otherwise the evaluator reports the document as missing and leaves it
    # out of the metrics.
    document_entities = results_for_factrueval_2016.setdefault(base_name, [])
    for entity_type in cur_result:
        prepared_entity_type = factrueval_tag_by_entity_type.get(entity_type)
        if prepared_entity_type is None:
            # An entity class without a FactRuEval counterpart cannot be written to the response
            # file (the evaluator accepts only known tags), so it is skipped.
            continue

        for entity_bounds in cur_result[entity_type]:
            postprocessed_entity = (
                prepared_entity_type,
                # paragraph-relative start -> document-relative start
                entity_bounds[0] + paragraph_bounds[0],
                entity_bounds[1] - entity_bounds[0]
                )
            document_entities.append(postprocessed_entity)

# One line per entity: "<tag> <start> <length>", ordered by position in the document.
for base_name in results_for_factrueval_2016:
    with open(base_name, 'w', encoding='utf-8') as fp:
        for cur_entity in sorted(results_for_factrueval_2016[base_name], key=lambda it: (it[1], it[2], it[0])):
            fp.write('{0} {1} {2}\n'.format(cur_entity[0], cur_entity[1], cur_entity[2]))

CPU times: user 10.2 ms, sys: 8.97 ms, total: 19.1 ms
Wall time: 19.2 ms


In [0]:
class Evaluator:
    """Scores test responses against the standard markup (1st track)."""

    def __init__(self, is_locorg_enabled=True):
        """Prepare the list of reported tags.

        is_locorg_enabled - when True, 'locorg' is reported as a tag of its own
        """
        self.is_locorg_enabled = is_locorg_enabled
        entity_tags = ['per', 'loc', 'org']
        if is_locorg_enabled:
            entity_tags.append('locorg')
        self.tags = entity_tags + ['overall']

        # metrics of the most recently evaluated document, keyed by tag
        self.metrics_dict = None

    def evaluate(self, std_path, test_path, output_path='', is_silent=False):
        """Evaluate every document present in both directories.

        std_path - directory with the standard markup
        test_path - directory with the test responses
        output_path - if not empty, a report per document is written to this directory
        is_silent - if True, the summary table is not printed

        Returns a dictionary tag -> accumulated Metrics."""
        std_by_name = {doc.name: doc for doc in loadAllStandard(std_path)}
        test_by_name = {doc.name: doc for doc in loadAllTest(test_path)}

        def document_number(doc_name):
            # names look like 'book_<number>'
            return int(doc_name[5:])

        std_names = set(std_by_name)
        test_names = set(test_by_name)

        # documents present on one side only cannot be compared
        unpaired = std_names.symmetric_difference(test_names)
        if len(unpaired) > 0:
            print('WARNING: missing files:')
            print('\n'.join(sorted(unpaired, key=document_number)))

        shared_names = sorted(std_names.intersection(test_names), key=document_number)

        totals = {tag: Metrics() for tag in self.tags}
        for name in shared_names:
            standard_doc = std_by_name[name]
            doc_metrics = self.evaluateDocument(standard_doc, test_by_name[name])
            self.metrics_dict = {tag: doc_metrics[tag] for tag in self.tags}
            self.printReport(standard_doc.name, output_path)
            for tag in totals:
                totals[tag].add(self.metrics_dict[tag])

        if not is_silent:
            print(self.buildMetricsTable(totals))

        return totals

    def evaluateDocument(self, standard, test):
        """Match the token sets of one standard document with those of its test response.

        The evaluation matrix is kept in self.em for the report; its metrics are returned."""
        std_sets = standard.makeTokenSets(self.is_locorg_enabled)
        test_sets = test.makeTokenSets(standard, self.is_locorg_enabled)

        matrix = EvaluationMatrix(std_sets, test_sets, TokenSetQualityCalculator())
        matrix.findSolution()
        self.em = matrix

        return matrix.metrics

    # Metrics and reports

    def buildMetricsTable(self, metrics_dict):
        """Render the given tag -> Metrics dictionary as a text table, one row per tag"""
        assert(len(metrics_dict.keys()) == len(self.tags))
        rows = ['Type    ' + Metrics.header()]
        rows.extend('{:8} '.format(tag) + metrics_dict[tag].toLine() for tag in self.tags)

        return '\n'.join(rows)

    def buildReport(self):
        """Build a detailed comparison report for the most recently evaluated document"""
        sections = [
            '------STANDARD------\n',
            self.em.describeMatchingStd() + '\n\n',
            '--------TEST--------\n',
            self.em.describeMatchingTest() + '\n\n',
            '-------METRICS------\n',
            self.buildMetricsTable(self.metrics_dict),
        ]

        return ''.join(sections)

    def printReport(self, name, out_dir):
        """Write the report of document `name` to `out_dir` (nothing is done for an empty out_dir).

        Reports of imperfectly recognized documents get a '_' prefix in their file names."""
        if len(out_dir) == 0:
            return

        is_perfect = self.em.metrics['overall'].f1 == 1.0
        os.makedirs(out_dir, exist_ok=True)

        prefix = '' if is_perfect else '_'
        report_path = os.path.join(out_dir, prefix + name + '.report.txt')
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(self.buildReport())

In [0]:
def loadAllStandard(path):
    """Load all standard markup files from the provided directory.

    Every entry of `path` whose name before the first dot looks exactly like 'book_<number>'
    is treated as a part of the standard document with that name; anything else is ignored.

    Returns a list of Standard objects sorted by document number."""

    names = set(x.split('.')[0] for x in os.listdir(path))
    # fullmatch (not match): a name such as 'book_12_old' must not pass the filter,
    # because the document number is parsed from everything after 'book_' below
    res = [Standard(name, path)
           for name in names
           if re.fullmatch('book_[0-9]+', name) is not None]

    return sorted(res, key=lambda x: int(x.name[5:]))

In [0]:
def loadAllTest(path):
    """Load all test markup files from the provided directory.

    Only files named exactly 'book_<number>.task1' are loaded: the document number is parsed
    from the name when sorting, and the same naming rule is applied to the standard files.

    Returns a list of Test objects sorted by document number."""
    suffix = '.task1'
    names = set(x[:-len(suffix)] for x in os.listdir(path) if x.endswith(suffix))
    res = [Test(name, path)
           for name in names
           if re.fullmatch('book_[0-9]+', name) is not None]

    return sorted(res, key=lambda x: int(x.name[5:]))

In [0]:
class Test:
    """Test response for the first track, read from a '.task1' file"""

    def __init__(self, name, dir='.'):
        """Load the response of the given document

        name - document name, i.e. the file name without the '.task1' extension
        dir - directory containing the file

        A loading error is printed, not raised.
        """
        try:
            self.name = name
            full_name = os.path.join(dir, name + '.task1')
            self.load(full_name)
        except Exception as error:
            print('Failed to load "{}"'.format(full_name))
            print(error)

    def load(self, filename):
        """Do the exception-prone loading"""

        # tags that may appear in a response; every tag starts with no mentions
        self.allowed_tags = set(['org', 'per', 'loc', 'locorg'])
        self.mentions = {tag: [] for tag in self.allowed_tags}

        # the file should consist of lines like
        # [TAG] [START_SYMBOL_INDEX] [LENGTH]
        with safeOpen(filename) as f:
            rows = csv.reader(f, delimiter=' ', quotechar=Config.QUOTECHAR)
            for index, parts in enumerate(rows):
                if len(parts) == 0:
                    # an empty line carries no mention
                    continue

                try:
                    assert(len(parts) == 3)
                    tag = normalize(parts[0])
                    assert(tag in self.allowed_tags)
                    self.mentions[tag].append(Interval(*parts[1:]))
                except Exception as error:
                    # report the offending line together with the expected format
                    expected = '[{}] [START_SYMBOL_INDEX] [LENGTH]'.format(
                        '/'.join(self.allowed_tags))
                    message = 'Error: "{}", line {}.\nExpected: {}\nReceived: {}\nDetails: {}'.format(
                        filename, index, expected, ' '.join(parts), str(error))
                    raise Exception(message)

    def makeTokenSets(self, standard, is_locorg_allowed=True):
        """Create a list of typed TokenSet objects corresponding to the mentions,
        using the tokens of the provided standard document to tokenize the intervals

        is_locorg_allowed - when False, 'locorg' token sets are re-tagged as 'loc'"""

        token_sets = []
        for key in self.allowed_tags:
            for interval in self.mentions[key]:
                # keep the non-ignored standard tokens lying entirely inside the interval
                covered = []
                for token in standard.tokens:
                    if token.start < interval.start or token.end > interval.end:
                        continue
                    if token.isIgnored():
                        continue
                    covered.append(token)
                ts = TokenSet(covered, key, standard.text)

                # remember the interval to display it as-is later
                ts.interval = interval

                if key == 'locorg' and not is_locorg_allowed:
                    ts.tag = 'loc'
                token_sets.append(ts)

        return token_sets

In [0]:
class Metrics:
    """Precision / recall / F1 together with the counts they are computed from"""

    header_template = '{:8} {:8} {:8} {:8} {:8} {:8} {:8}'
    line_template = '{:8.4f} {:8.4f} {:8.4f} {:8.2f} {:8.2f} {:8.0f} {:8.0f}'

    def __init__(self):
        """Create metrics of an empty comparison: no objects, perfect scores"""
        # true positives: tp_std is the recall numerator, tp_test is the precision numerator
        self.tp_std = self.tp_test = 0.0
        # number of objects in the standard and in the test response
        self.n_std = self.n_test = 0
        # scores
        self.precision = self.recall = self.f1 = 1.0

    def recalculate(self):
        """Recompute precision, recall and F1 from the stored counts.

        A score whose denominator is zero is defined as 1.0, except that F1 is 0.0
        when objects exist but both precision and recall are zero."""
        if self.n_test > 0:
            self.precision = self.tp_test / self.n_test
        else:
            self.precision = 1.0

        if self.n_std > 0:
            self.recall = self.tp_std / self.n_std
        else:
            self.recall = 1.0

        if self.n_std + self.n_test == 0:
            self.f1 = 1.0
        else:
            score_sum = self.precision + self.recall
            if score_sum > 0:
                self.f1 = 2 * self.precision * self.recall / score_sum
            else:
                self.f1 = 0.0

        # every score must be a valid fraction
        for score in (self.precision, self.recall, self.f1):
            assert(score >= 0 and score <= 1)

    def add(self, other):
        """Accumulate the counts of another Metrics object and refresh the scores"""
        self.tp_std += other.tp_std
        self.tp_test += other.tp_test
        self.n_std += other.n_std
        self.n_test += other.n_test
        self.recalculate()

    def toLine(self):
        """Returns a line for the stats table"""
        values = (self.precision, self.recall, self.f1,
                  self.tp_std, self.tp_test, self.n_std, self.n_test)
        return Metrics.line_template.format(*values)

    @classmethod
    def header(cls):
        """Returns a header for the stats table"""
        titles = ('P', 'R', 'F1', 'TP1', 'TP2', 'In Std.', 'In Test.')
        return Metrics.header_template.format(*titles)

    @classmethod
    def createSimple(cls, tp, n_std, n_test):
        """Build metrics where one true positive value serves both precision and recall"""
        m = cls()
        m.tp_std, m.tp_test = tp, tp
        m.n_std, m.n_test = n_std, n_test
        m.recalculate()
        return m

    @classmethod
    def create(cls, tp_std, tp_test, n_std, n_test):
        """Build metrics with separate true positive values for recall and precision"""
        m = cls()
        m.tp_std, m.tp_test = tp_std, tp_test
        m.n_std, m.n_test = n_std, n_test
        m.recalculate()
        return m

In [0]:
class Standard:
    """Standard document data loaded from a set of export files.

    The set currently includes:
     - 'NAME.txt'
     - 'NAME.tokens'
     - 'NAME.spans'
     - 'NAME.objects'
     - 'NAME.coref'
     - 'NAME.facts'

    The '.coref' and '.facts' files are optional: when one of them cannot be opened,
    the corresponding flag (has_coref / has_facts) is set to False.
    """

    def __init__(self, name, path='.'):
        """Load all layers of the document `name` from the directory `path`.

        A loading error is printed, not raised; in that case the mentions, entities and
        facts of the document are emptied."""
        self.name = name
        # Defaults, so that a document which failed to load still has every attribute
        # that is read later (e.g. the tokens and the text used to tokenize test responses).
        self.text = ''
        self.tokens = []
        self.spans = []
        self.mentions = []
        self.entities = []
        self.facts = []
        self._token_dict = {}
        self._span_dict = {}
        self._mention_dict = {}
        self._entity_dict = {}
        try:
            self.has_coref = True
            self.has_facts = True
            full_name = os.path.join(path, name)
            self.loadText(full_name + '.txt')
            self.loadTokens(full_name + '.tokens')
            self.loadSpans(full_name + '.spans')
            self.loadMentions(full_name + '.objects')
            self.loadCoreference(full_name + '.coref')
            self.loadFacts(full_name + '.facts')
        except Exception as e:
            print('Failed to load the standard of {}:'.format(name))
            print(e)
            # reset the document so it has no impact on the comparison
            self.mentions = []
            self.entities = []
            self.facts = []

    @staticmethod
    def _readBlocks(lines):
        """Group stripped non-empty lines into blocks separated by blank lines.

        Yields every block as one string in which each line ends with a line feed."""
        buffer = ''
        for raw_line in lines:
            line = raw_line.strip(' \t\n\r')
            if len(line) == 0:
                if len(buffer) > 0:
                    yield buffer
                    buffer = ''
            else:
                buffer += line + '\n'

        if len(buffer) > 0:
            yield buffer

    def loadTokens(self, filename):
        """Load the data from a file with the provided name

        Raw token data should be loaded from one of the system export '.tokens' file.
        Raises an Exception for a non-empty line with an unexpected number of fields."""
        self.tokens = []

        with open(filename, 'r', encoding='utf-8') as f:
            rdr = csv.reader(f, delimiter=Config.DEFAULT_DELIMITER, quotechar=Config.QUOTECHAR)

            for index, line in enumerate(rdr):
                if len(line) == 0:
                    # skip the empty lines
                    continue

                if len(line) != Config.TOKEN_LINE_LENGTH:
                    # bad non-empty line
                    raise Exception(
                        'Wrong length in line {} of file {}'.format(
                            index, filename))

                self.tokens.append(
                    Token(*line) )

        # fill the token dictionary
        self._token_dict = dict([(x.id, x) for x in self.tokens])

        # set neighbour links in tokens (the tokens are ordered by their start position first)
        self.tokens = sorted(self.tokens, key=lambda x: x.start)
        for i, token in enumerate(self.tokens):
            if i != 0:
                token.prev = self.tokens[i-1]
            if i != len(self.tokens)-1:
                token.next = self.tokens[i+1]

    def loadSpans(self, filename):
        """Load the data from a file with the provided name

        Raw span data should be loaded from one of the system export '.spans' file

        Expected format:
        line = <left> SPAN_FILE_SEPARATOR <right>
        left = <span_id> <tag_name> <start_pos> <nchars> <start_token> <ntokens>
        right ::= [ <token>]+ [ <token_text>]+     // <ntokens> of each

        Blank lines are skipped; a malformed non-blank line raises an Exception.
            """
        self.spans = []

        with open(filename, 'r', encoding='utf-8') as f:
            for index, line in enumerate(f):
                # a line read from a file keeps its line break, so emptiness has to be
                # checked on the stripped text
                if len(line.strip()) == 0:
                    # skip the empty lines
                    continue

                parts = line.split(Config.SPAN_FILE_SEPARATOR)
                if len(parts) != 2:
                    # bad non-empty line
                    raise Exception(
                        'Expected symbol "{}" missing in line {} of file {}'.format(
                            Config.SPAN_FILE_SEPARATOR, index, filename))

                left = parts[0]
                right = parts[1]

                filtered_left = [i
                     for i in left.split(Config.DEFAULT_DELIMITER)
                         if len(i) > 0]

                if len(filtered_left) < 6:
                    raise Exception(
                        'Missing left parts in line {} of file {}'.format(
                            index, filename))

                new_span = Span(*filtered_left)

                filtered_right = [i
                      for i in right.split(Config.DEFAULT_DELIMITER)
                            if len(i) > 0]
                # the right part lists <ntokens> token ids followed by <ntokens> token texts
                if len(filtered_right) != 2*new_span.ntokens:
                    raise Exception(
                        'Missing right parts in line {} of file {}'.format(
                            index, filename))

                token_ids = [x.strip() for x in filtered_right[:new_span.ntokens]]
                new_span.tokens = sorted([self._token_dict[x] for x in token_ids],
                                         key=lambda x: x.start)
                new_span.text = normalize(' '.join(filtered_right[new_span.ntokens:]))
                new_span.text = new_span.text.replace('\n', '')

                self.spans.append(new_span)

        # fill the span dictionary
        self._span_dict = dict([(x.id, x) for x in self.spans])

    def loadMentions(self, filename):
        """Load the data from a given 'objects' file. Expected format:

        line = <object_id> <type> <span_id> # <comment>

        Raises an Exception for an object description without spans.
        """

        self.mentions = []
        with open(filename, 'r', encoding='utf-8') as f:
            r = csv.reader(f, delimiter=' ', quotechar=Config.QUOTECHAR)
            for index, line in enumerate(r):
                # drop the comment part of the line, if any
                if Config.COMMENT_SEPARATOR in line:
                    comment_index = line.index(Config.COMMENT_SEPARATOR)
                    line = line[:comment_index]

                if len(line) == 0:
                    continue

                if len(line) <= 2:
                    raise Exception(
                        'Missing spans in object description: line {} of file {}'.format(
                            index, filename))

                try:
                    mention_id = line[0].strip()
                    span_indices = [descr.strip() for descr in line[2:]]
                except Exception as e:
                    raise Exception('Invalid mention or span id: line {} of file {}:\n{}'.format(
                        index, filename, e))

                self.mentions.append(
                    Mention(mention_id, line[1], span_indices, self._span_dict))

        # fill the mention dictionary
        self._mention_dict = dict([(x.id, x) for x in self.mentions])
        for m in self.mentions:
            m.findParents(self.mentions)
            m.setText(self.text)

    def loadCoreference(self, filename):
        """Load coreference data from the associated file

        If the file cannot be opened, has_coref is set to False and the document
        gets no entities."""
        self.entities = []
        # defined up front, because the facts layer looks entities up even without a '.coref' file
        self._entity_dict = {}

        try:
            f = open(filename, 'r', encoding='utf-8')
        except OSError:
            # there are currently some documents with no .coref layer. This is temporary
            self.has_coref = False
            return

        with f:
            for block in Standard._readBlocks(f):
                self.entities.append(
                    Entity.fromStandard(block, self._mention_dict, self._span_dict))

        for ent in self.entities:
            self._entity_dict[ent.id] = ent

    def loadFacts(self, filename):
        """Load facts from the associated file

        If the file cannot be opened, has_facts is set to False and the document
        gets no facts."""
        self.facts = []

        try:
            f = open(filename, 'r', encoding='utf-8')
        except OSError:
            # there are currently some documents with no .facts layer. This is temporary
            self.has_facts = False
            return

        with f:
            for block in Standard._readBlocks(f):
                self.facts.append(
                    Fact.fromStandard(block, self._entity_dict, self._span_dict))

        part_of_facts = [f for f in self.facts if f.tag == 'ispartof']
        for fact in self.facts:
            fact.expandWithIsPartOf(part_of_facts)

    def loadText(self, filename):
        """Load text from the associated text file (the text is stored normalized)"""
        with open(filename, 'r', encoding='utf-8') as f:
            self.text = safeNormalize(f.read())

    def makeTokenSets(self, is_locorg_allowed=True):
        """Create a list of typed TokenSet objects corresponding to the mentions

        is_locorg_allowed - enable/disable 'LocOrg' tag; when disabled, 'locorg' mentions
        are treated as 'loc'"""

        # determine what tags are allowed
        allowed_tags = set(['org', 'per', 'loc'])
        if is_locorg_allowed:
            allowed_tags.add('locorg')

        res = []
        for mention in self.mentions:
            key = mention.tag
            if key == 'locorg' and not is_locorg_allowed:
                key = 'loc'
            if not (key in allowed_tags):
                continue
            ts = TokenSet(
                    [x for span in mention.spans for x in span.tokens],
                    key, self.text)

            ts.id = mention.id

            # mark every token according to the tag of the object and the tag of its span
            for span in mention.spans:
                for token in span.tokens:
                    mark = Tables.getMark(ts.tag, span.tag)
                    ts.setMark(token, mark)
            res.append(ts)

        # find and mark embedded objects
        for obj in res:
            obj.findParents(res)

        return res

In [0]:
def safeNormalize(string):
    """Normalize a string so that superficial spelling differences do not matter.

    This is not a linguistic normalization. The result is lower-cased, stripped of
    surrounding spaces/tabs/line breaks, and has all double quote variants, single
    quote variants, the letters ё/е and all dash variants unified, so that
    capitalization, е/ё and quote symbols become irrelevant in comparison.
    """
    # each group of look-alike characters collapses into a single canonical one
    groups = (
        ('«»“”„', '"'),     # double quotes
        ('’`', "'"),        # single quotes
        ('ё', 'е'),         # ё/е
        ('-‐−‒–—―', '-'),   # dashes
    )
    char_map = {}
    for variants, canonical in groups:
        for variant in variants:
            char_map[variant] = canonical

    # lower-case first, then trim, then substitute character by character
    trimmed = string.lower().strip(' \r\n\t')
    return trimmed.translate(str.maketrans(char_map))

In [0]:
class EvaluationMatrix:
    """Matrix built out of object pair quality that finds an optimal matching

    Standard and test objects are grouped by tag (in the order of allowed_tags) and
    laid out as two flat lists; cell [i][j] of the matrix holds calc.priority() of
    standard object i and test object j. findSolution() then searches for the
    one-to-one matching of standard and test objects with the best f1."""

    # tags taken into account; their order defines the layout of self.std / self.test
    allowed_tags = ['per', 'loc', 'org', 'locorg']

    def __init__(self, std, test, calc, mode='regular'):
        """Initialize the matrix.
        
        std and test must be lists of objects from standard and test respectively
        mode must be either 'regular' or 'simple'; it is validated and stored, but
        nothing else in this class reads it
        calc must be a priority/quality calculator object used in the task at hand"""

        assert(mode == 'regular' or mode == 'simple')
        self.mode = mode
        # filled by findSolution(): 'overall' plus one entry per allowed tag
        self.metrics = {}

        # per-tag slices of the standard (s) and test (t) object lists
        self.s = {}
        self.t = {}
        for tag in EvaluationMatrix.allowed_tags:
            self.s[tag] = TagData(tag, std)
            self.t[tag] = TagData(tag, test)

        # finalize the offsets
        # (each tag starts where the previous tag ends, so the tags occupy
        # consecutive index ranges of the flat lists built below)
        for i in range(1, len(EvaluationMatrix.allowed_tags)):
            prev_tag = EvaluationMatrix.allowed_tags[i-1]
            tag = EvaluationMatrix.allowed_tags[i]
            self.s[tag]._start = self.s[prev_tag].end()
            self.t[tag]._start = self.t[prev_tag].end()

        # flat object lists; objects with a tag outside allowed_tags are dropped
        self.std = []
        self.test = []
        for tag in EvaluationMatrix.allowed_tags:
            self.std.extend(self.s[tag].objects)
            self.test.extend(self.t[tag].objects)

        self.n_std = len(self.std)
        self.n_test = len(self.test)

        # rows are standard objects, columns are test objects
        self.m = np.zeros((self.n_std, self.n_test))
        self.calc = calc

        for i, x in enumerate(self.std):
            for j, y in enumerate(self.test):
                self.m[i][j] = self.calc.priority(x, y)

    def findSolution(self):
        """Runs the recursive search to find an optimal matching

        Fills self.metrics ('overall' and one entry per allowed tag), records the
        matching through logMatching() and returns the matching as a list of
        (standard index, test index) pairs"""
        
        # q (the f1 of the best matching) is not used after this call
        q, pairs = self._recursiveSearch(
            [i for i in range(self.n_std)],
            [j for j in range(self.n_test)],
            []
            )

        self.metrics['overall'] = self._evaluate(pairs)
        for tag in EvaluationMatrix.allowed_tags:
            self.metrics[tag] = self._evaluate(pairs, tag)

        # save matching data
        self.logMatching(pairs)

        return pairs

    def logMatching(self, pairs):
        """Saves matching data

        pairs - list of (standard index, test index) tuples

        Sets self.matching (object -> its counterpart, in both directions) and the
        matched_std / matched_test / unmatched_std / unmatched_test object lists"""
        self.matching = {}
        self.matched_std = []
        self.matched_test = []
        for i, j in pairs:
            s = self.std[i]
            t = self.test[j]
            self.matching[s] = t
            self.matching[t] = s
            self.matched_std.append(s)
            self.matched_test.append(t)
        self.unmatched_std = [s for s in self.std if not s in self.matched_std]
        self.unmatched_test = [t for t in self.test if not t in self.matched_test]

    def describeMatchingStd(self):
        """Builds a detailed matching description for standard objects

        Requires the attributes set by logMatching()"""
        return self._doDescribeMatching(self.matched_std, self.unmatched_std, False)

    def describeMatchingTest(self):
        """Builds a detailed matching description for test objects

        Requires the attributes set by logMatching()"""
        return self._doDescribeMatching(self.matched_test, self.unmatched_test, True)

    def _doDescribeMatching(self, matched, unmatched, is_swapped):
        """Builds a detailed matching description with the given lookup tables

        matched, unmatched - object lists to describe
        is_swapped - True when the lists hold test objects, False for standard ones

        Returns a multi-line string: one line per matched object, an empty line,
        then one line per unmatched object. Every line starts with the quality of
        the object (0.00 for unmatched ones) or with IGNORED"""
        res = ''
        for x in matched:
            y = self.matching[x]
            # the calculator always expects the (standard, test) argument order
            pair = (y, x) if is_swapped else (x, y)
            is_ignored = self.calc.isIgnored(pair[0], pair[1], self.matching)
            q = self.calc.quality(*pair)
            res += '{}\t{}\t=\t{}\n'.format('{:7.2f}'.format(q)
                                            if not is_ignored
                                            else 'IGNORED',
                    x.toInlineString(), y.toInlineString())

        res += '\n'
        for x in unmatched:
            is_ignored = (self.calc.isTestIgnored(x, self.matching)
                          if is_swapped
                          else self.calc.isStandardIgnored(x, self.matching))
            res += '{} {}\n'.format('{:7.2f}'.format(0.0)
                                    if not is_ignored
                                    else 'IGNORED',
                    x.toInlineString())

        return res

    def _recursiveSearch(self, std, test, pairs):
        """
            Run a recursive search of the optimal matching.
            Returns the following tuple: (overall quality, matching)
            std - remaining standard indices list
            test - remaining test indices list
            pairs - current list of built pairs

            Each call decides the fate of the first remaining standard object: it
            is paired with every admissible test object in turn and, under the
            conditions checked below, also left unmatched. The best result of the
            explored branches (by f1) is returned.
        """
        if len(std) == 0 or len(test) == 0:
            # final step, evaluate the matching
            metrics = self._evaluate(pairs)
            return metrics.f1, pairs

        curr = std[0]
        max_res = None

        # number of test objects that were actually tried as a pair for curr
        possible_pairs_count = 0
        # the largest number of other remaining standard objects with a non-zero
        # priority for one of curr's candidate test objects
        pair_max_alternatives = 0

        options, has_perfect_match = self._findMatches(curr, test)
        for t in options:
            i = test.index(t)

            # let's see what other matching options does this test object have
            # this is necessary to check conditions for the logic below
            alt_count = 0
            skip_test_object = False
            for k in std[1:]:
                if self.m[k, t] == 1.0 and self.m[curr, t] < 1.0:
                    # test objects that have some other perfect matching must be skipped
                    skip_test_object = True
                if self.m[k, t] != 0.0:
                    alt_count += 1
                if alt_count > pair_max_alternatives:
                    pair_max_alternatives = alt_count
                
            if skip_test_object:
                continue
            else:
                possible_pairs_count += 1


            # try to confirm the pair
            res = self._recursiveSearch(
                std[1:], test[:i] + test[i+1:],
                pairs + [(curr, t)])
            if max_res is None or res[0] > max_res[0]:
                max_res = res

        # check what would happen if this standard object were ignored
        # this check is obviously performance-heavy and only necessary under
        # these conditions
        # ('and' binds tighter than 'or': the branch runs when no pair was tried at
        # all, or when exactly one imperfect pair was tried whose test object is
        # also wanted by another standard object; as a consequence max_res is
        # never None at the return below)
        if (possible_pairs_count == 0
                or possible_pairs_count == 1
                    and pair_max_alternatives > 0
                    and not has_perfect_match):
            res = self._recursiveSearch(
                std[1:], test,
                pairs)
            if max_res is None or res[0] > max_res[0]:
                max_res = res

        return max_res

    def _findMatches(self, s_index, test):
        """Finds a list of possible matches for the standard object with the given index
        within the list of available test objects.
        
        Returns a tuple (list of test object indices, flag); the flag is True when
        the list consists of perfect matches (priority 1.0) only
        
        According to the documentation, any perfectly fitting objects MUST be matched"""
        perfect_matches = [t for t in test if self.m[s_index, t] == 1.0]
        matches = [t for t in test if self.m[s_index, t] > 0.0] 
        if len(perfect_matches) > 0:
            return perfect_matches, True
        else:
            return matches, False

    def _evaluate(self, pairs, tag_filter = ''):
        """Calculates metrics of the given matching through self.calc.evaluate()

        pairs - list of (standard index, test index) tuples
        tag_filter - when it is one of allowed_tags, only the pairs whose standard
        object lies in the index range of that tag and only the unmatched objects
        with that tag are evaluated; any other value evaluates everything"""
        matched_std = set(self.std[_s] for _s,_t in pairs)
        matched_test = set(self.test[_t] for _s,_t in pairs)
        
        if tag_filter in EvaluationMatrix.allowed_tags:
            subset = self._reduce(pairs, tag_filter)
        else:
            subset = pairs

        # note: unmatched lists are computed against ALL pairs, not only the subset
        unmatched_std = [s for s in self.std if not (s in matched_std)]
        unmatched_test = [t for t in self.test if not (t in matched_test)]

        # unmatched_test must contain all objects of the given tag that were not in ANY
        # pair, including cases where a locorg was matched to an loc, for example
        if tag_filter in EvaluationMatrix.allowed_tags:
            unmatched_std = [s for s in unmatched_std if s.tag == tag_filter]
            unmatched_test = [t for t in unmatched_test if t.tag == tag_filter]

        # replace indices with actual objects for evaluation
        actual_pairs = [(self.std[_s], self.test[_t]) for _s,_t in subset]

        return self.calc.evaluate(actual_pairs, unmatched_std, unmatched_test)

    def _reduce(self, matching, tag):
        """Returns a sub-matching corresponding to the given tag

        A pair belongs to the tag when its standard index lies within the index
        range occupied by that tag in self.std; the test index is not checked"""
        res = []
        for _s, _t in matching:
            if _s >= self.s[tag].start() and _s < self.s[tag].end():
                res.append((_s, _t))
        return res

In [0]:
class TokenSetQualityCalculator:
    """Priority/quality calculator used to match standard and test TokenSet objects"""

    # Multiplier for a (standard tag, test tag) pair: identical tags are worth 1,
    # every combination of different tags is worth 0
    tag_table = {
        ('per', 'per') : 1,
        ('per', 'org') : 0,
        ('per', 'loc') : 0,
        ('per', 'locorg') : 0,
        ('org', 'per') : 0,
        ('org', 'org') : 1,
        ('org', 'loc') : 0,
        ('org', 'locorg') : 0,
        ('loc', 'per') : 0,
        ('loc', 'org') : 0,
        ('loc', 'loc') : 1,
        ('loc', 'locorg') : 0,
        ('locorg', 'per') : 0,
        ('locorg', 'org') : 0,
        ('locorg', 'loc') : 0,
        ('locorg', 'locorg') : 1
    }

    def tagMultiplier(self, s, t):
        """Look up the multiplier for the tags of standard object s and test object t"""
        tag_pair = (s.tag, t.tag)
        return TokenSetQualityCalculator.tag_table[tag_pair]

    def evaluate(self, pairs, unmatched_std, unmatched_test):
        """Evaluate the matching and return its metrics.

        pairs - list of (standard object, test object) tuples
        unmatched_std, unmatched_test - objects left without a pair

        Ignored pairs and ignored unmatched standard objects are excluded from the
        counts passed to Metrics.createSimple"""
        lookup = {}
        for std_obj, test_obj in pairs:
            lookup[std_obj] = test_obj
            lookup[test_obj] = std_obj

        true_positive = 0
        relevant_pairs = 0
        counted_std = set()
        for std_obj, test_obj in pairs:
            if self.isIgnored(std_obj, test_obj, lookup):
                continue
            true_positive += self.quality(std_obj, test_obj)
            counted_std.add(std_obj)
            relevant_pairs += 1

        # in this task no unmatched test object can be ignored
        total_test = relevant_pairs + len(unmatched_test)

        total_std = relevant_pairs
        for std_obj in unmatched_std:
            if std_obj in counted_std:
                continue
            if self.isStandardIgnored(std_obj, lookup):
                continue
            total_std += 1

        return Metrics.createSimple(true_positive, total_std, total_test)

    def priority(self, s, t):
        """Calculate the preliminary quality that goes into the optimization table.

        For objects with compatible tags this is the number of shared tokens
        divided by the number of tokens in the union; otherwise it is 0"""
        weight = self.tagMultiplier(s, t)
        if weight == 0:
            return 0

        shared = len(s.tokens.intersection(t.tokens))
        missed = len(s.tokens.difference(t.tokens))
        spurious = len(t.tokens.difference(s.tokens))

        total = shared + spurious + missed
        assert(total > 0)
        if total > 0:
            return weight * shared / total
        return 0

    def quality(self, s, t):
        """Calculate the final quality that is maximized during the matching optimization.

        Shared and missed tokens are counted with the span marks of the standard
        object, extra test tokens count as 1 each"""
        weight = self.tagMultiplier(s, t)
        if weight == 0:
            return 0

        hit = 0.0
        for token in s.tokens.intersection(t.tokens):
            # there can be no punctuation here
            hit += s.mark(token)

        miss = 0.0
        for token in s.tokens.difference(t.tokens):
            # in case some punctuation did end up in the standard markup somehow
            # (which apparently happens)
            if token.isPunctuation():
                continue
            miss += s.mark(token)

        extra = len(t.tokens.difference(s.tokens))

        total = hit + extra + miss
        if total > 0:
            return weight * hit / total

        # total can be equal to zero in cases when the mention has no 'priority' spans
        # like org_name. In these cases the annotations are compared with no weights
        return self.priority(s, t)

    def isIgnored(self, s, t, matching):
        """Check if the matched pair (s, t) should be ignored within the current
        matching. The decision depends on the standard object only"""
        return self.isStandardIgnored(s, matching)

    def isStandardIgnored(self, s, matching):
        """Check if the given standard object should be ignored within the current
        matching.

        Side effect: when s and its sibling share a tag and s is the one that is
        kept, the sibling gets its is_ignored_sibling flag set"""
        # unnamed objects and embedded objects are ignored regardless of their
        # matching status
        if s.isUnnamed() or len(s.parents) > 0:
            return True

        # sibling object processing logic
        assert(len(s.siblings) <= 1)
        for sibling in s.siblings:
            s_is_matched = s in matching
            if s_is_matched != (sibling in matching):
                # only one of the two is matched: ignore the unmatched one
                return not s_is_matched

            # both or neither are matched
            if sibling.tag != s.tag:
                # ignore the non-organization
                assert('org' in [sibling.tag, s.tag])
                return s.tag != 'org'

            # same tag on both: keep whichever is asked about first
            if s.is_ignored_sibling:
                return True
            sibling.is_ignored_sibling = True
            return False

        return False

    def isTestIgnored(self, t, matching):
        """Check if the given test object should be ignored within the current
        matching"""
        # in this track no test object can be ignored
        return False

In [0]:
class TagData:
    """Slice of an object list that holds the objects of a single tag, together with
    the position of that slice inside the flat lists of EvaluationMatrix"""

    def __init__(self, tag, object_list):
        """Pick the objects with the given tag out of object_list, ordered by id.

        The slice initially starts at offset 0; EvaluationMatrix adjusts _start
        afterwards"""
        self.tag = tag
        selected = [obj for obj in object_list if obj.tag == tag]
        selected.sort(key=lambda obj: obj.id)
        self.objects = selected
        self.size = len(selected)
        self._start = 0

    def start(self):
        """Index of the first object of this tag"""
        return self._start

    def end(self):
        """Index right after the last object of this tag"""
        return self.size + self._start

In [0]:
class Config:
    """Global configuration

    Constants shared by the evaluation code. Only STANDARD_TYPES is read by the
    classes visible next to this one (Mention); the remaining constants are
    presumably used by the file loaders - TODO confirm against their callers."""
    # presumably the delimiter / quote character of the csv-like markup files
    DEFAULT_DELIMITER = ' '
    QUOTECHAR = '|'
    # presumably the expected number of fields in a line of a token file
    TOKEN_LINE_LENGTH = 4
    # presumably separates the data from the trailing comment in span files
    SPAN_FILE_SEPARATOR = '  # '
    COMMENT_SEPARATOR = '#'
    # maps the tag names found in the markup to the short tags used in the code;
    # Mention raises an exception for a tag that is not a key of this table
    STANDARD_TYPES = {
        'Person' : 'per',
        'Organization' : 'org',
        'Org' : 'org',
        'LocOrg' : 'locorg',
        'Location' : 'loc',
        'Project' : 'project'
    }

In [0]:
class Token:
    """Raw token"""

    def __init__(self, id, start, length, text):
        """Create a new token.

        id - token identifier, stored as is
        start, length - position and size of the token in the text (converted to int)
        text - token text, stored in normalized form

        end is the position of the last character of the token (inclusive);
        prev/next neighbour links start as None"""
        self.id = id
        self.start = int(start)
        self.length = int(length)
        self.end = self.start + self.length - 1
        self.text = normalize(text)
        self.prev = self.next = None

    def __repr__(self):
        template = '{}[{}-{}, #{}]'
        return template.format(self.text, self.start, self.end, self.id)

    def __str__(self):
        return self.__repr__()

    def isLetter(self):
        """Check if this token is a single letter, i.e. a text of at most one
        character that changes when its case is changed"""
        text = self.text
        return len(text) <= 1 and (text.upper() != text or text.lower() != text)

    def isPunctuation(self):
        """Check if this token is punctuation. Only checks for a limited amount of
        symbols because this method is only called to detect a small amount of special
        occasions in standard markup"""
        if len(self.text) != 1:
            return False
        return not self.isLetter()

    def _isLoneNonLetter(self):
        """Check that the token is one character long and is not a letter"""
        return self.length == 1 and not self.isLetter()

    def isIgnored(self):
        """Check if this token should be ignored during the comparison.
        The comparison is supposed to ignore the punctuation tokens that are (presumably)
        located directly next to their neighbours on either side"""
        return self.isIgnoredFromLeft() or self.isIgnoredFromRight()

    def isIgnoredFromLeft(self):
        """Check if this token should be ignored during the comparison. In this case the
        token must be directly next to its previous neighbour"""
        if not self._isLoneNonLetter():
            return False
        neighbour = self.prev
        return neighbour is not None and self.start - neighbour.end == 1

    def isIgnoredFromRight(self):
        """Check if this token should be ignored during the comparison. In this case the
        token must be directly next to its next neighbour"""
        if not self._isLoneNonLetter():
            return False
        neighbour = self.next
        return neighbour is not None and neighbour.start - self.end == 1

In [0]:
def normalize(string):
    """Normalize a string so that superficial spelling differences do not matter.

    The string is first passed through safeNormalize; unlike safeNormalize, this
    function then also removes the extra space that may appear next to punctuation.
    """
    # in standard strings are generated from tokens, and sometimes have 1 extra space
    # before or after punctuation symbols; the pairs are applied in this exact order
    space_fixes = (
        (' ,', ','),
        (' .', '.'),
        (' -', '-'),
        ('- ', '-'),
        ('( ', '('),
        (' )', ')'),
    )

    res = safeNormalize(string)
    for padded, tight in space_fixes:
        res = res.replace(padded, tight)
    return res

In [0]:
class Span:
    """Raw span"""

    def __init__(self, id, tag, start, nchars, token_start, ntokens):
        """Create a new span with the given parameters.

        id, tag - stored as is
        start, nchars - character offset and length of the span (converted to int);
        end is start + nchars, i.e. the position right after the last character
        token_start, ntokens - first token and token count (converted to int)

        tokens and text start empty; this class never fills them itself"""
        self.id = id
        self.tag = tag

        self.start = int(start)
        self.end = self.start + int(nchars)

        self.token_start = int(token_start)
        self.ntokens = int(ntokens)

        self.tokens = []
        self.text = ''

    def isInQuotes(self):
        """Check if the span is preceded by an opening quote and succeeded by a
        matching closing one"""
        quote_pair = self.getLeftQuote() + self.getRightQuote()
        return quote_pair in ['""', "''", '«»']

    def getLeftQuote(self):
        """Find and return quote preceding the span. If there is no quote, returns ''"""
        # tokens are always sorted
        assert(len(self.tokens) > 0)
        preceding = self.tokens[0].prev
        if preceding is not None and preceding.text in ['"', "'", "«"]:
            return preceding.text
        return ''

    def getRightQuote(self):
        """Find and return quote succeeding the span. If there is no quote, returns ''"""
        # tokens are always sorted
        assert(len(self.tokens) > 0)
        following = self.tokens[-1].next
        if following is not None and following.text in ['"', "'", "»"]:
            return following.text
        return ''

    def __repr__(self):
        return '{}[{} #{}],  ntokens={}'.format(
            self.text, self.tag, self.id, self.ntokens)

    def __str__(self):
        return repr(self)

In [0]:
class Mention:
    """Mention consisting of spans"""

    def __init__(self, id, tag, span_ids, span_dict):
        """Create a new mention of a given type with the provided spans.

        id - mention identifier, stored as is
        tag - tag name as found in the markup; must be a key of Config.STANDARD_TYPES
        and is stored in its short form
        span_ids - identifiers of the spans that make up the mention
        span_dict - lookup table from span identifier to span object

        Raises Exception for an unknown tag and KeyError for an unknown span id"""
        self.id = id
        self.parents = []

        if not tag in Config.STANDARD_TYPES:
            raise Exception('Unknown mention tag: {}'.format(tag))
        self.tag = Config.STANDARD_TYPES[tag]

        self.text = ''
        self.interval_text = ''
        self.spans = [span_dict[span_id] for span_id in span_ids]

    def __repr__(self):
        res = '{} #{}:\n'.format(self.tag, self.id)
        for span in self.spans:
            res += '\t{} : {}\n'.format(span.tag, span.text)
        res += '\n'
        return res

    def isGeoAdj(self):
        """Checks if the mention only has geo_adj spans"""
        return all(span.tag == 'geo_adj' for span in self.spans)

    def isDescr(self):
        """Checks if the mention only has descriptor spans"""
        return all('descr' in span.tag for span in self.spans)

    def findParents(self, mentions):
        """Scans the given mention list for mentions embedding this one.

        Only mentions whose tag is listed in Tables.PARENT_TAGS for this mention's
        tag are considered. The result replaces self.parents"""
        self.parents = []
        candidates = [m for m in mentions if m.tag in Tables.PARENT_TAGS[self.tag]]
        if not candidates:
            return

        # the mention's own interval does not depend on the candidate
        own_interval = self.toInterval()
        for candidate in candidates:
            candidate_interval = candidate.toInterval()
            if own_interval.isIn(candidate_interval):
                self.parents.append(candidate)
            elif (self.tag in ['per', 'loc'] and candidate.tag == 'org'
                    and own_interval.isEqual(candidate_interval)):
                # organizations have priority over equally sized people and locations
                self.parents.append(candidate)

    def toInterval(self):
        """Build the interval that covers all spans of the mention"""
        assert(len(self.spans) > 0)
        start = min(span.start for span in self.spans)
        end = max(span.end for span in self.spans)
        # NOTE(review): Span.end is start + nchars, so this interval reaches one
        # position past the last character. All mentions are built the same way, so
        # the isIn/isEqual comparisons in findParents are not affected - confirm
        # before using the interval for anything else.
        return Interval(start, end - start + 1)

    def setText(self, documentText):
        """Sets the text from the document corresponding to the mention.

        text becomes the space-joined texts of the mention's tokens in document
        order; interval_text becomes the fragment of documentText covered by the
        interval of those tokens"""
        ts = TokenSet([t for s in self.spans for t in s.tokens], self.tag, documentText)
        self.text = ' '.join([t.text for t in ts.sortedTokens()])
        interval = ts.toInterval()
        # interval.end is the position of the last character (inclusive), so the
        # slice has to go one position further to keep that character
        self.interval_text = documentText[interval.start:interval.end + 1]

    def __str__(self):
        return repr(self)

In [0]:
class Tables:
    """Tables with error weights"""

    @staticmethod
    def getMark(mention_tag, span_tag, dfl_value = 0):
        """Lookup error weight of the provided pair.

        mention_tag - tag of the mention (first level key of QUALITY)
        span_tag - tag of the span inside that mention
        dfl_value - value to return when the pair is not in the table

        Returns default value if the pair is not in QUALITY table"""
        return Tables.QUALITY.get(mention_tag, {}).get(span_tag, dfl_value)

    @staticmethod
    def getArgumentWeight(tag):
        """Lookup the given argument weight; arguments missing from ARG_WEIGHTS
        weigh 1.0"""
        return Tables.ARG_WEIGHTS.get(tag, 1.0)

    ARG_WEIGHTS = {
        'position' : 0.5,
        'фаза' : 0.5
        }

    # this table specifies weights of various spans in mention evaluation
    # (the commented out span tags currently carry no weight)
    QUALITY = {
        'locorg' : {
            'none' : 1,
#            'loc_descr' : 1,
            'org_name' : 1,
#            'org_descr' : 1,
#            'loc_descr' : 1,
            'loc_name' : 1
        },

        'loc' : {
            'none' : 1,
#            'name' : 1,
#            'org_descr' : 1,
            'org_name' : 1,
#            'loc_descr' : 1,
#            'surname' : 1,
            'loc_name' : 1,
#            'nickname' : 1
        },

        'org' : {
            'none' : 1,
#            'org_descr' : 1,
#            'surname' : 1,
            'loc_name' : 1,
#            'loc_descr' : 1,
            'org_name' : 1,
#            'job' : 1
        },

        'per' : {
            'none' : 1,
            'name' : 1,
            'patronymic' : 1,
            'nickname' : 1,
            'surname' : 1
        }
    }

    # This table describes rules for mention and tokenset embedding, e.g.
    # mentions with tag KEY can be embedded in mentions with tag VALUE
    PARENT_TAGS = {
        'per' : ['loc', 'org', 'locorg'],
        'loc' : ['loc', 'org', 'locorg'],
        'org' : ['org', 'locorg'],
        'locorg' : ['org', 'locorg'],
        'project' : ['org', 'locorg']
    }

In [0]:
def safeOpen(filename):
    """Open a utf-8 file with or without BOM in read mode.

    filename - path of the file to open

    Returns the opened text file object. If the file cannot be opened, the error
    is printed once and None is returned, so callers must check the result."""
    # 'utf-8-sig' skips a leading BOM when there is one and otherwise reads plain
    # utf-8, so there is nothing to fall back to: open() itself never fails
    # because of the file's encoding
    try:
        return open(filename, 'r', encoding='utf-8-sig')
    except Exception as e:
        print(e)
        return None

In [0]:
class Interval:
    """Text interval with inclusive boundaries"""

    def __init__(self, start, length):
        """Create an interval from its first position and its length (both are
        converted to int); end is the last position that belongs to the interval"""
        first = int(start)
        size = int(length)
        self.start = first
        self.length = size
        self.end = first + size - 1

    def isEqual(self, other):
        """Check if this interval has the same boundaries as the other one"""
        return (self.start, self.end) == (other.start, other.end)

    def isIn(self, other):
        """Check if this interval lies within the other one (but not equal)"""
        if self.isEqual(other):
            return False
        return other.start <= self.start and self.end <= other.end

    def __repr__(self):
        return '<{}; {}>'.format(self.start, self.end)

    def __str__(self):
        return self.__repr__()

In [0]:
class TokenSet:
    """A set of tokens corresponding to an object"""

    def __init__(self, token_list, tag, text):
        """Create a token set.

        token_list - tokens of the object (duplicates are dropped)
        tag - tag of the object
        text - text the tokens come from; toInlineString cuts its quote out of it

        Every token starts with a span mark of 0; id is -1 until assigned"""
        self.id = -1
        self.tokens = set(token_list)
        self.tag = tag
        self.parents = []
        # defined up front so that code reading siblings also works for sets that
        # never went through findParents
        self.siblings = []
        self.interval = None
        self._span_marks = dict.fromkeys(self.tokens, 0)
        self.text = text
        self.is_ignored_sibling = False

    def __repr__(self):
        return '<' + ' '.join([repr(x) for x in self.sortedTokens()]) + '>'

    def __str__(self):
        return repr(self)

    def sortedTokens(self):
        """Make a list of tokens sorted by their starting position"""
        return sorted(self.tokens, key=lambda x: x.start)

    def getHoles(self):
        """Return tokens not present in the set but located between
        the included tokens in the text.

        The result is a list of holes, each hole being a list of consecutive
        tokens. Relies on the next links of the tokens leading from every token of
        the set to the following one"""
        res = []
        # nothing can follow the last token of the set, so it is left out
        for token in self.sortedTokens()[:-1]:
            hole = []
            t = token.next
            while t not in self.tokens:
                hole.append(t)
                t = t.next
            if len(hole) > 0:
                res.append(hole)
        return res

    def intersects(self, other):
        """Check for an intersection with the other TokenSet"""
        return len(self.tokens.intersection(other.tokens)) > 0

    def toInterval(self):
        """Create an interval for the response generator.

        Returns self.interval when one has been assigned; otherwise the interval
        runs from the first to the last token of the set, extended over any '"'
        tokens directly before and after them.

        Raises IndexError for an empty token set"""
        if self.interval is not None:
            return self.interval

        t = self.sortedTokens()
        if len(t) == 0:
            raise IndexError('cannot build an interval for an empty TokenSet')

        # try to include quotes on the left and on the right
        start_token = t[0]
        end_token = t[-1]

        while start_token.prev is not None and start_token.prev.text == '"':
            start_token = start_token.prev
        while end_token.next is not None and end_token.next.text == '"':
            end_token = end_token.next

        start = start_token.start
        end = end_token.end
        length = end - start + 1
        return Interval(start, length)

    def isUnnamed(self):
        """Returns True if the object is unnamed and must be ignored.
        In practice it means that no token of the set is marked with a span of value>0"""
        return not any(mark > 0 for mark in self._span_marks.values())

    def mark(self, token):
        """Return the span mark for this token (0 for tokens outside of the set)"""
        return self._span_marks.get(token, 0)

    def setMark(self, token, mark):
        """Try to increase the mark of the given token; a mark is never lowered.

        Raises KeyError for a token that is not in the set"""
        if self._span_marks[token] < mark:
            self._span_marks[token] = mark

    def isEmbedded(self):
        """Only true if this object is embedded into another object"""
        return len(self.parents) > 0

    def findParents(self, all_token_sets):
        """Fill the parent and sibling lists of the current token set.

        Among the sets whose tag is allowed by Tables.PARENT_TAGS for this set's
        tag, a set with a strictly larger token set is a parent and a set with
        exactly the same tokens is a sibling"""
        self.parents = []
        self.siblings = []
        for other in [x for x in all_token_sets
                        if x.tag in Tables.PARENT_TAGS[self.tag]]:
            if other is self:
                # all_token_sets can include this set as well
                continue

            if self.tokens < other.tokens:
                self.parents.append(other)
            elif self.tokens == other.tokens:
                self.siblings.append(other)

    def toInlineString(self):
        """Make an inline representation using the tokensets interval"""
        i = self.toInterval()
        return (self.tag.upper()
                + (' {}'.format(self.id) if self.id != -1 else '')
                + ' {} "{}"'.format(i, self.text[i.start:i.end+1]))

In [0]:
class Entity:
    """Entity with a set of attributes, assembled from several mentions throughout the
    document.

    Build instances with `Entity.fromStandard` or `Entity.fromTest`."""


    def __init__(self):
        """Create a new object. Do not call this directly, use classmethods instead."""
        self.attributes = []         # Attribute objects of this entity
        self.id = -1                 # -1 until _load_id_line stores the real id
        self.tag = 'unknown'
        self.spans = []              # spans named on the id line
        self.mentions = []           # mentions named on the id line
        self.is_problematic = False  # set when the id line names an unknown id


    def processAttributes(self):
        """Merge attributes with similar names, remove suffixes from the names and create
        alternatives.

        Attributes sharing a name are merged into one Attribute. 'wikidata'
        attributes are dropped. Attributes whose name ends with 'descr' or
        'descriptor' are not kept as such; their merged values are used to build
        alternatives for every kept attribute. Finally trimName() is applied to
        every kept attribute.

        The kept attributes appear in the order in which their names were first
        seen, so the result is the same on every run."""
        raw_attributes = self.attributes
        self.attributes = []

        # Group the attributes by name in one pass. The separate `names` list keeps
        # the first-seen order (iterating a set of str depends on hash randomization).
        names = []
        grouped = {}
        for attr in raw_attributes:
            if attr.name not in grouped:
                grouped[attr.name] = []
                names.append(attr.name)
            grouped[attr.name].append(attr)

        descriptors = []

        # Merge all attributes of the same name into alternatives, except descriptors
        for name in names:
            attr_by_name = grouped[name]
            if name.endswith('descr') or name.endswith('descriptor'):
                descriptors.extend(attr_by_name)
            elif name != 'wikidata':
                # wikidata must be ignored for all the tracks
                if name == 'name':
                    # extend names with quotes if available
                    added_attrs = [x.tryPutInQoutes(self) for x in attr_by_name]
                    attr_by_name.extend(x for x in added_attrs if x is not None)
                self.attributes.append(Attribute.merge(attr_by_name, name))

        # Add descriptors to the list of alternatives
        if len(descriptors) > 0:
            descr = Attribute.merge(descriptors, 'descr')
            for attr in self.attributes:
                attr.buildAlternatives(descr)

        # Trim names ending with digits
        for attr in self.attributes:
            attr.trimName()


    def toInlineString(self):
        """Creates an inline description of this entity: the upper-cased tag, the id
        (omitted while it is -1) and the bracketed list of attributes."""
        res = self.tag.upper()
        if self.id != -1:
            res += ' ' + str(self.id)
        res += ' [' + ', '.join([str(x) for x in self.attributes]) + ']'

        return res


    def toTestString(self):
        """Creates a test representation of this entity: the tag on the first line,
        followed by the test representation of every attribute."""
        res = self.tag
        for attr in self.attributes:
            res += '\n' + attr.toTestString()

        return res


    def _load_id_line(self, line, mention_dict, span_dict):
        """Load ids from the first line of the standard representation.

        line - '[entity_id][ (span_id|mention_id)]+'
        mention_dict - mention_id -> mention
        span_dict - span_id -> span

        Stores the entity id, collects the referenced mentions and spans (without
        duplicates) and derives the entity tag from the tags of the mentions.
        An id found in neither dictionary marks the entity as problematic and is
        reported on stdout.

        Raises AssertionError when an id is both a mention and a span, or when the
        mentions do not determine a single tag."""
        str_ids = line.strip(' \n\r\t').split(' ')
        self.id = str_ids[0]

        self.is_problematic = False

        for some_id in str_ids[1:]:
            if some_id in mention_dict:
                # an id must not be a mention and a span at the same time
                assert(some_id not in span_dict)
                mention = mention_dict[some_id]
                if mention not in self.mentions:
                    self.mentions.append(mention)
            elif some_id in span_dict:
                span = span_dict[some_id]
                if span not in self.spans:
                    self.spans.append(span)
            else:
                self.is_problematic = True
                print('FOUND PROBLEMATIC ENTITY: {} has no {}'.format(self, some_id))

        # it is not actually the case, at least for now
        # but arguably it should be
        # assert(len(self.mentions) > 0)

        tags = set([x.tag for x in self.mentions])
        if len(tags) > 1 and ('locorg' in tags or 'loc' in tags):
            # for this task all locorg objects are considered loc
            self.tag = 'loc'
        else:
            # there can be no other multitype entities
            assert(len(tags)==1)
            self.tag = tags.pop()

            if self.tag == 'locorg':
                self.tag = 'loc'


    def getAttr(self, name):
        """Return all values of the attribute with a given name"""
        return [v
                for attr in self.attributes if attr.name == name
                for v in attr.values]

    def __repr__(self):
        res = ''
        res += '{} #{}'.format(self.tag, self.id)
        for attribute in self.attributes:
            res += '\n  {}'.format(attribute)
        return res


    def __str__(self):
        return self.__repr__()


    # static build methods
    @classmethod
    def fromStandard(cls, text, mention_dict, span_dict):
        """Load the entity from a block of text of the following format
        
        [entity_id][ (span_id|mention_id)]+
        [attr_name] [attr_value]
        ...
        [attr_name] [attr_value]
        mention_dict - mention_id -> mention
        span_dict - span_id -> span

        Empty attribute lines are skipped. Returns the new, processed instance.
        """

        assert(len(text.strip('\r\n\t ')) > 0)
        lines = text.split('\n')

        instance = cls()
        for line in lines[1:]:
            if len(line) == 0:
                continue
            instance.attributes.append(Attribute.fromStandard([line]))

        # the id line needs no attributes, but attribute processing needs the
        # tag and the spans/mentions it loads
        instance._load_id_line(lines[0], mention_dict, span_dict)
        instance.processAttributes()

        return instance


    @classmethod
    def fromTest(cls, text):
        """Load the entity from a test file using a different format:
        
        [entity_type]
        [attr_name]:[attr_value]
        ...
        [attr_name]:[attr_value]

        Empty attribute lines are skipped. Returns the new instance.
        """

        assert(len(text.strip('\r\n\t ')) > 0)

        instance = cls()

        lines = text.split('\n')
        for line in lines[1:]:
            if len(line) == 0:
                continue
            instance.attributes.append(Attribute.fromTest(line))
        instance.tag = lines[0].lower().strip(' :\r\n\t')
        if instance.tag == 'locorg':
            # all locorgs are considered locs for this task
            instance.tag = 'loc'

        return instance

In [0]:
class Attribute:
    """Entity attribute with one or several synonymous values"""

    def __init__(self):
        """Create a new object. Do not call this directly, use classmethods instead."""
        self.name = ''
        self.values = set()

    def buildAlternatives(self, descr):
        """Build full alternative list from current values and descriptors.

        Every current value is kept. For each descriptor value that does not
        already occur in it as a whole space-separated word sequence, the
        combinations 'value descriptor' and 'descriptor value' are added.

        descr - attribute whose values are the descriptors"""
        raw_values = self.values
        self.values = set()
        for x in raw_values:
            # keep the bare value unconditionally, so that an empty descriptor
            # set cannot wipe out the existing values
            self.values.add(x)
            padded = ' ' + x + ' '
            for y in descr.values:
                if (' ' + y + ' ') in padded:
                    # for those descriptors already included in a name
                    # added spaces to do a full-word search
                    continue

                self.values.add(x + ' ' + y)
                self.values.add(y + ' ' + x)

    def tryPutInQoutes(self, entity):
        """Try to return a copy of this attribute surrounded with quotes.
        This must only yield meaningful result if the entity this attribute belongs to
        has a span marked as '**_name' surrounded by quotes.

        Returns a new Attribute holding the quoted value, or None when the entity
        is a person or no matching quoted name span exists."""

        # method must only be called before all other processing
        assert(len(self.values) == 1)

        if entity.tag == 'per':
            return None

        val = next(iter(self.values))
        all_spans = []
        all_spans.extend(entity.spans)
        for mention in entity.mentions:
            all_spans.extend(mention.spans)
        name_spans = [x for x in all_spans
                      if x.text.lower().replace('ё', 'е') == val and 'name' in x.tag]

        for span in name_spans:
            if span.isInQuotes():
                attr = Attribute()
                attr.name = self.name
                attr.values.add(span.getLeftQuote() + val + span.getRightQuote())
                return attr

        return None

    def trimName(self):
        """Removes any digits following the attribute name"""
        # rstrip, not strip: only trailing digits are a suffix
        self.name = self.name.rstrip('1234567890')

    def matches(self, other):
        """Returns true if a set of value of other corresponds to a set of values of this,
        i.e. the names are equal and at least one pair of values compares equal
        according to compareStrings"""
        if self.name != other.name:
            return False

        for v1 in self.values:
            for v2 in other.values:
                if compareStrings(v1, v2):
                    return True

        return False

    def isValid(self):
        """Checks if the attribute is valid (has non-empty values)"""
        return any(len(value) > 0 for value in self.values)

    def toTestString(self):
        """Creates a test representation of this attribute: one 'name : value' line
        per value, values in sorted order"""
        return '\n'.join(['{} : {}'.format(self.name, x) for x in sorted(self.values)])

    def __repr__(self):
        # sorted, so that the text does not depend on set iteration order
        return '{} : {}'.format(self.name, ' | '.join(sorted(self.values)))

    def __str__(self):
        return self.__repr__()


    # static build methods:
    @classmethod
    def fromStandard(cls, lines):
        """Load an attribute from the set of lines representing it.
        This method corresponds to the standard format of representation:
        '[attr_name] [attr_value]'. Exactly one line is expected.
        
        Returns a new Attribute instance"""

        assert(len(lines) == 1)

        line = lines[0]
        parts = line.split(' ')

        instance = cls()
        instance.name = parts[0].strip().lower()
        value = normalize(' '.join(parts[1:]))
        instance.values.add(value)

        return instance

    @classmethod
    def fromTest(cls, line):
        """Load an attribute from the line representing it.
        This method corresponds to the test format of representation:
        '[attr_name]:[attr_value]'. Everything after the first colon is the value.
        
        Returns a new Attribute instance"""

        parts = line.split(':')
        assert(len(parts) >= 2)

        instance = cls()
        instance.name = parts[0].strip().lower()
        value = normalize(':'.join(parts[1:]))
        instance.values.add(value)

        return instance

    @classmethod
    def merge(cls, attr_list, new_name):
        """Merge values from the list of attributes, and assign a new name. Returns a new
        instance"""

        instance = cls()
        instance.name = new_name
        for attr in attr_list:
            instance.values.update(attr.values)

        return instance

In [0]:
class Fact:
    """Fact extracted from a document"""

    # values of the 'модальность' property that make the fact eligible for the easy mode
    # only
    easymode_modality_values = ['возможность', 'будущее', 'отрицание']

    # values of the 'сложность' property that make the fact eligible for the hard mode
    # only
    hardmode_difficulty_values = ['повышенная']

    def __init__(self):
        """Initialize the object (use Fact.fromStandard/Fact.fromTest instead)"""
        self.tag = ''
        self.id = ''
        self.arguments = []
        self.has_easymode_modality = False
        self.has_hardmode_difficulty = False
        self.is_ignored = False

    def toTestString(self):
        """Test representation: the tag line followed by one line per non-special
        argument, terminated by a newline."""
        rows = [self.tag]
        for argument in self.arguments:
            if argument.is_special:
                continue
            rows.append(argument.toTest())
        return '\n'.join(rows) + '\n'

    def toInlineString(self):
        """One-line representation: id, optional (MODALITY)/(HARD) flags, tag and
        all arguments, wrapped in square brackets."""
        pieces = ['[ ', str(self.id), ' ']
        if self.has_easymode_modality:
            pieces.append('(MODALITY) ')
        if self.has_hardmode_difficulty:
            pieces.append('(HARD) ')
        pieces.append(self.tag)
        pieces.extend(' | {}'.format(argument) for argument in self.arguments)
        pieces.append(' ]')
        return ''.join(pieces)

    def _load_id_line(self, line):
        """Loads the first line of the fact description ('<id> <tag>')"""
        fields = line.split(' ')
        self.id = fields[0]
        self.tag = fields[1].strip(' :\n\t\r').lower()

    def canMatch(self, other):
        """Determine if this fact can match the other in evaluation. In essence, returns
        True only if the tags agree and at least one argument other than 'position'
        has matching values"""
        if self.tag != other.tag:
            return False

        return any(mine.canMatch(theirs) and mine.name != 'position'
                   for mine in self.arguments
                   for theirs in other.arguments)

    def removePhase(self):
        """Remove the phase argument, if one is present"""
        phase_args = [a for a in self.arguments if a.name == 'фаза']
        if not phase_args:
            return

        # there should be no more than one phase per fact
        assert len(phase_args) == 1
        self.arguments.remove(phase_args[0])

    def finalize(self):
        """Finalize the object for the evaluation: consume the modality and
        difficulty arguments, then finalize every remaining argument"""
        self._processModality()
        self._processDifficulty()
        for argument in self.arguments:
            argument.finalize()

    def _processModality(self):
        """Remove the 'модальность' arguments and record whether any of their values
        is one of the easy-mode modality values"""
        modality_args = [a for a in self.arguments if a.name == 'модальность']
        if not modality_args:
            self.has_easymode_modality = False
            return

        # Apparently there can be multiple modality values in the dataset
        # And multiple values per modality attribute
        for modality in modality_args:
            self.arguments.remove(modality)

            assert isinstance(modality.values[0], StringValue)
            for value in modality.values:
                self.has_easymode_modality = (
                    self.has_easymode_modality
                    or value.descr in Fact.easymode_modality_values)

    def _processDifficulty(self):
        """Remove the 'сложность' argument and record whether its value is one of the
        hard-mode difficulty values"""
        difficulty_args = [a for a in self.arguments if a.name == 'сложность']
        if not difficulty_args:
            self.has_hardmode_difficulty = False
            return

        # there should be no more than one difficulty per fact
        assert len(difficulty_args) == 1
        difficulty = difficulty_args[0]
        self.arguments.remove(difficulty)

        assert len(difficulty.values) == 1

        only_value = difficulty.values[0]
        assert isinstance(only_value, StringValue)
        self.has_hardmode_difficulty = (
            only_value.descr in Fact.hardmode_difficulty_values)

    def expandWithIsPartOf(self, facts):
        """For an 'occupation' fact, expand the value of every 'where' argument with
        the values that share an argument with it in `facts`. Facts of any other tag
        are left untouched.

        Every argument value found in `facts` must be an EntityValue."""
        if self.tag != 'occupation':
            return

        # value -> all other values that occur in the same argument
        # NOTE(review): the keys here are EntityValue objects, while
        # EntityValue.expandWithIsPartOf looks up its `entity` in this dict; as far as
        # this file shows the lookup can then never succeed - confirm the intent.
        partof_dict = {}
        for fact in facts:
            for argument in fact.arguments:
                for key in argument.values:
                    assert isinstance(key, EntityValue)
                    for value in argument.values:
                        if value == key:
                            continue
                        partof_dict.setdefault(key, []).append(value)

        for argument in self.arguments:
            if argument.name != 'where':
                continue
            assert len(argument.values) == 1
            assert isinstance(argument.values[0], EntityValue)
            argument.values[0].expandWithIsPartOf(partof_dict)

    def __repr__(self):
        text = self.tag + '\n'
        for argument in self.arguments:
            text += str(argument) + '\n'

        return text

    def __str__(self):
        return repr(self)

    # static build methods
    @classmethod
    def fromStandard(cls, text, entity_dict, span_dict):
        """Load the fact from a block of text in the standard format.

        The first line holds the id and the tag; every following non-empty line is
        turned into an argument by an ArgumentBuilder created from entity_dict and
        span_dict. The fact is finalized before it is returned."""
        assert len(text.strip('\r\n\t ')) > 0
        lines = text.split('\n')

        builder = ArgumentBuilder(entity_dict, span_dict)

        instance = cls()
        for line in lines[1:]:
            if line == '':
                continue
            argument = builder.build(line)
            argument.fact = instance
            instance.arguments.append(argument)

        # instance.processAttributes()
        instance._load_id_line(lines[0])
        instance.finalize()

        return instance

    @classmethod
    def fromTest(cls, text):
        """Load the fact from a test file using a different format:
        
        [fact_type]
        [arg_name]:[arg_value]
        ...
        [arg_name]:[arg_value]
        """

        assert len(text.strip('\r\n\t ')) > 0

        instance = cls()

        lines = text.split('\n')
        instance.tag = lines[0].strip(' :\n\t\r').lower()
        for line in lines[1:]:
            if line == '':
                continue
            argument = Argument.fromTest(line)
            argument.fact = instance
            instance.arguments.append(argument)

        return instance

In [0]:
class ArgumentBuilder:
    """Creates an argument of a proper type from string"""

    def __init__(self, entity_dict, span_dict):
        """Remember the dictionaries handed on to EntityValue and SpanValue."""
        self.entity_dict = entity_dict
        self.span_dict = span_dict

    def build(self, line):
        """Parse one argument line into an Argument.

        The first space-separated word is the argument name; the kind of value is
        decided by how the second word starts:
          'span...' - one SpanValue per '|'-separated alternative
          'obj...'  - a single EntityValue
          otherwise - one StringValue per '|'-separated alternative"""
        words = line.split(' ')
        arg_name = words[0].lower()
        alternatives = ' '.join(words[1:]).split('|')
        argument = Argument(arg_name)

        assert len(alternatives) > 0

        kind_marker = words[1]
        if kind_marker.startswith('span'):
            # spans have the following syntax:
            # position spanXXXX somevalue | spanYYYY someothervalue
            for alternative in alternatives:
                chunks = [c for c in alternative.split(' ') if c != '']
                span_id, span_descr = chunks[0], ' '.join(chunks[1:])
                argument.values.append(
                    SpanValue(argument, span_id, span_descr, self.span_dict))
            return argument

        if kind_marker.startswith('obj'):
            # objects have different syntax:
            # who objXXX name1 | name2 | name3
            # (all names refer to the same object)
            obj_descr = ' '.join(words[2:])
            argument.values.append(
                EntityValue(kind_marker, obj_descr, self.entity_dict))
            return argument

        # just a string value
        for alternative in alternatives:
            argument.values.append(StringValue(alternative))
        return argument

In [0]:
class Argument:
    """Fact argument"""

    # names of the special cyrillic attributes
    special_names = ['сложность', 'модальность', 'фаза']

    # normalized 'occupation:position' values dictionary
    # loaded from a specific file; None until the first successful load
    position_dict = None

    def __init__(self, name):
        """Initialize the argument with a name and no values.

        The name is stripped and lower-cased, and 'job' is renamed to 'position'.
        The first construction also loads the shared position dictionary."""
        if Argument.position_dict is None:
            Argument.loadPositionDict()

        self.name = name.strip(' \n\r\t').lower()
        if self.name == 'job':
            self.name = 'position'
        self.is_special = self.name in Argument.special_names
        self.values = []
        self.fact = None

    def toTest(self):
        """Test representation 'name : value' built from the first value.

        An argument without values prints its owning fact and then fails with
        IndexError."""
        if(len(self.values) == 0):
            # debugging aid: show which fact holds the broken argument
            print(self.fact)
        return self.name + ' : ' + str(self.values[0])

    def toInlineString(self):
        """Inline representation: the first value as a string"""
        return str(self.values[0])

    def canMatch(self, other):
        """Check if the value of other is compatible with the argument's own values"""
        assert(len(other.values) == 1) # other should be a test argument with only 1 value
        if self.name != other.name:
            return False

        for x in self.values:
            for y in other.values:
                if x.equals(y):
                    return True

        return False

    def finalize(self):
        """Finalize the argument for evaluation"""
        for v in self.values:
            v.finalize()

    def __repr__(self):
        return self.name + ' : ' + ' | '.join([str(x) for x in self.values])

    def __str__(self):
        return self.__repr__()

    # classmethods

    @classmethod
    def loadPositionDict(cls):
        """Load the normalized 'occupation:position' values dictionary from the
        associated file (jobs_file_path). This method should only be called once.

        Every line must have the form 'key | value'. The class attribute is only
        replaced after the whole file has been read, so a failed load leaves it
        untouched and the next construction tries again instead of silently working
        with an empty dictionary.

        Raises OSError when the file cannot be opened and AssertionError on a
        malformed line."""
        loaded = {}
        with open(jobs_file_path, encoding='utf-8') as f:
            for line in f:
                parts = [x.strip(' \n\r\t') for x in line.split('|')]
                assert(len(parts) == 2)
                loaded[parts[0]] = parts[1]
        cls.position_dict = loaded


    @classmethod
    def fromTest(cls, line):
        """Load an argument from a test-format line '[arg_name]:[arg_value]'.

        Only the first colon separates the name from the value, so values may
        contain colons themselves (the same rule as Attribute.fromTest).

        Raises AssertionError when the line contains no colon."""
        name, separator, value = line.partition(':')
        assert(separator == ':')
        arg = cls(name)
        arg.values.append(StringValue(value))

        return arg

In [0]:
class EntityValue:
    """Fact argument that is an entity"""

    def __init__(self, full_id, descr, entity_dict):
        """Initialize the object.

        full_id - 'obj' followed by the key of the entity in entity_dict
        descr - textual description of the value; stored stripped and lower-cased
        entity_dict - entity key -> entity

        The set of accepted values starts with the description and is extended
        with the texts of all mentions of the entity, with name combinations for
        persons and with the 'name' attribute values for organizations and
        locations."""
        assert(full_id.startswith('obj'))
        self.entity = entity_dict[full_id[3:]]
        self.descr = descr.strip(' \n\r\t').lower()
        self.values = set([self.descr])

        # special logic for different types of entities
        assert( self.entity.tag != 'locorg' )
        self.values = self.values.union(self._expandFromText())

        if self.entity.tag == 'per':
            self.values = self.values.union(self._expandPerson(self.entity))

        if self.entity.tag in ['org', 'loc']:
            self.values = self.values.union(self._expandWithDescr(self.entity))

    def equals(self, other):
        """True when any accepted value matches the value of `other` (a StringValue)
        according to compareStrings"""
        assert(isinstance(other, StringValue))
        for val in self.values:
            if compareStrings(val, other.value):
                return True
        return False

    def finalize(self):
        """Finalize the value: lower-case and strip every accepted value and replace
        'ё' with 'е'. The values are stored as a list afterwards."""
        self.values = [x.lower().strip(' \n\r\t').replace('ё', 'е') for x in self.values]

    def _expandFromText(self):
        """Returns a set of non-normalized values corresponding to each mention of the
        entity"""
        additional_values = []
        for mention in self.entity.mentions:
            additional_values.append(mention.text)
            additional_values.append(mention.interval_text)
        return set(additional_values)

    def _expandPerson(self, per):
        """Create all possible values for a person: the listed combinations of first
        name (f), last name (l), patronymic (p) and nickname (n), plus the
        description itself"""
        assert(per.tag == 'per')

        firstnames = per.getAttr('firstname')
        lastnames = per.getAttr('lastname')
        patronymics = per.getAttr('patronymic')
        nicknames = per.getAttr('nickname')

        lists = [firstnames, lastnames, patronymics, nicknames]
        combinations = ['lfp', 'fpl', 'fp', 'fl', 'lf', 'n', 'f', 'p', 'l', 'fn']

        values = []
        for c in combinations:
            values += self._buildPerValues(lists, c)
        values.append(self.descr)
        return set(values)

    def _buildPerValues(self, lists, combination):
        """Build all strings for one combination code.

        lists - [firstnames, lastnames, patronymics, nicknames]
        combination - string over the letters f, l, p, n giving the component order"""
        value_lists = []
        for symbol in combination:
            if symbol == 'f':
                value_lists.append(lists[0])
            elif symbol == 'l':
                value_lists.append(lists[1])
            elif symbol == 'p':
                value_lists.append(lists[2])
            elif symbol == 'n':
                value_lists.append(lists[3])
        return self._combine(value_lists)

    def _combine(self, value_lists):
        """Join one value from every list with single spaces, in list order, for every
        possible choice. Empty strings are skipped; an empty list yields no
        combinations at all."""
        options = ['']
        new_options = []
        for lst in value_lists:
            for val in lst:
                if val == '':
                    continue
                for opt in options:
                    new_options.append(opt + ' ' + val if opt != '' else val)
            options = new_options
            new_options = []
        return options

    def _expandWithDescr(self, org):
        """Return the set of all 'name' attribute values of an organization/location"""
        assert(org.tag in ['org', 'loc'])
        return set(org.getAttr('name'))

    def expandWithIsPartOf(self, ent_dict):
        """Extend the accepted values with the names of the entities that ent_dict
        lists for this value's entity. Does nothing when the entity is not a key of
        ent_dict.

        ent_dict - entity -> list of 'org'/'loc' entities (presumably; verify
        against the caller, which is outside this class)"""
        if not (self.entity in ent_dict):
            return

        # finalize() stores the values as a list, which has no union(); rebuild a
        # set so that this method works both before and after finalize()
        expanded = set(self.values)
        for ent in ent_dict[self.entity]:
            expanded = expanded.union(self._expandWithDescr(ent))
        self.values = expanded

    def __repr__(self):
        return self.descr

    def __str__(self):
        return self.__repr__()

In [0]:
class SpanValue:
    """Fact argument that is a span"""

    def __init__(self, owner, full_id, descr, span_dict):
        """Initialize the object.

        owner - the argument this value belongs to
        full_id - 'span' followed by the key of the span in span_dict
        descr - accepted but not used; the span text serves as the description
        span_dict - span key -> span"""
        assert full_id.startswith('span')
        span_key = full_id[4:]
        self.owner = owner
        self.span = span_dict[span_key]
        self.values = [self.span.text]
        self.descr = self.span.text

    def equals(self, other):
        """True when any of the values matches other.value according to
        compareStrings"""
        return any(compareStrings(candidate, other.value) for candidate in self.values)

    def finalize(self):
        """Finalize the value.

        For a 'position' argument the entry of Argument.position_dict for the
        first value, if there is one, is appended. Then every value is lower-cased,
        stripped and has 'ё' replaced with 'е'."""
        if self.owner.name == 'position':
            first = self.values[0]
            if first in Argument.position_dict:
                self.values.append(Argument.position_dict[first])

        normalized = []
        for raw in self.values:
            normalized.append(raw.lower().strip(' \n\r\t').replace('ё', 'е'))
        self.values = normalized

    def __repr__(self):
        return self.descr

    def __str__(self):
        return self.__repr__()

In [0]:
class StringValue:
    """String value for special cases"""

    def __init__(self, value):
        """Initialize the object: the value is stored stripped and lower-cased and
        also serves as the description"""
        cleaned = value.strip(' \n\r\t').lower()
        self.value = cleaned
        self.descr = cleaned

    def equals(self, other):
        """Compare this value with other.value by means of compareStrings"""
        # STUB
        return compareStrings(self.value, other.value)

    def finalize(self):
        """Finalize the value: lower-case and strip it and replace 'ё' with 'е'.
        The description is left as it is."""
        normalized = self.value.lower().strip(' \n\r\t')
        self.value = normalized.replace('ё', 'е')

    def __repr__(self):
        return self.descr

    def __str__(self):
        return self.descr

In [129]:
# Score the stored predictions against the FactRuEval-2016 test set. The value of
# the last expression (a dict with one Metrics object per entity type and one for
# 'overall') is what the notebook displays for this cell.
# NOTE(review): the flag is only passed on to Evaluator here; presumably False
# means 'locorg' is not scored as a type of its own - confirm against Evaluator.
is_locorg_allowed = False
e = Evaluator(is_locorg_allowed)

# Paths noted in command-line flag form (presumably for the stand-alone
# evaluation script):
#-s /content/factRuEval-2016-master 
#-t /content/FactRuEval2016_results/results_of_elmo_and_crf2
# NOTE(review): the -t path above ends in "crf2", while the call below reads
# ".../results_of_elmo_and_crf" - confirm which results directory is intended.

e.evaluate(std_path="/content/factRuEval-2016-master/testset", test_path="/content/FactRuEval2016_results/results_of_elmo_and_crf")

Failed to load the standard of book_3954:
Unknown mention tag: Facility
Type    P        R        F1       TP1      TP2      In Std.  In Test.
per        0.8598   0.9490   0.9022  1271.67  1271.67     1340     1479
loc        0.7836   0.7498   0.7663   921.53   921.53     1229     1176
org        0.5826   0.7477   0.6549  1176.93  1176.93     1574     2020
overall    0.7209   0.8135   0.7644  3370.13  3370.13     4143     4675


{'loc': <__main__.Metrics at 0x7f97c3ffb7f0>,
 'org': <__main__.Metrics at 0x7f97c3ffbb70>,
 'overall': <__main__.Metrics at 0x7f97c3fe8f28>,
 'per': <__main__.Metrics at 0x7f97c3ffbac8>}

In [0]:
#!python factRuEval-2016-master/scripts/t1_eval.py -t /content/FactRuEval2016_results/results_of_elmo_and_crf2 -s /content/factRuEval-2016-master/testset -l

In [0]:
#!zip -r /content/file.zip /content/FactRuEval2016_results