In [23]:
from typing import Iterator, List, Dict
import torch
import torch.optim as optim
import numpy as np
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator
from allennlp.data.iterators import BasicIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor

import pandas as pd
import xml.etree.ElementTree as ET
from xml.dom.minidom import Text, Element

import re

torch.manual_seed(1)

<torch._C.Generator at 0x117029310>

In [24]:


def replace_comma_in_numbers(s):
    return re.sub(r'(\d),(\d)', r'\1.\2', s)

def replace_dot_in_words(s):
    return re.sub(r'([^(\W|\d)] ?)\.( ?[^(\W|\d)])', r'\1 \2', s)

print(replace_comma_in_numbers('22,18'))
print(replace_dot_in_words('22.18'))
print(replace_dot_in_words('дер .полн'))

def replace_x_in_numbers(s):
    return re.sub(r'(\d ?)[xх]( ?\d)', r'\1*\2', s)

replace_x_in_numbers('22 x18')

def replace_slash_in_words(s):
    return re.sub(r'([^(\W|\d)] ?)/( ?[^(\W|\d)])', r'\1 \2', s)

def is_number_with_slash(s):
    return re.match(r'\d/\d', s) is not None

def replace_slash_in_numbers(s):
    return re.sub(r'(\d ?)/( ?\d)', r'\1 \2', s)

def surround_with_spaces(s, regex, char):
    return re.sub(r'(\w ?)' + regex + r'( ?\w)', r'\1 ' + char + r' \2', s)

def surround_with_spaces_words(s, regex, char):
    return re.sub(r'([^(\W|\d)] ?)' + regex + r'( ?[^(\W|\d)])', r'\1 ' + char + r' \2', s)

def surround_with_spaces_numbers(s, regex, char):
    return re.sub(r'(\d ?)' + regex + r'( ?\d)', r'\1 ' + char + r' \2', s)

def prepare_value(s):
    s = s.lower()
    s = replace_comma_in_numbers(s)
    
    s = surround_with_spaces_numbers(s, 'x', 'x')
    s = surround_with_spaces_numbers(s, 'х', 'х')
    s = surround_with_spaces(s, '/', '/')
    s = surround_with_spaces_words(s, r'\.', '.')
    
    s = surround_with_spaces(s, ';', ';')
    s = surround_with_spaces(s, ',', ',')
    s = surround_with_spaces(s, r'\*', '*')
    
    return s

def split(s):
    tokens = re.split(r'\s', s)
    return [t for t in tokens if t is not None and t != '']

def parse_prelabeled(s):
    find_from = 0

    parts = []
    tags = []

    first_tag_start_index = 0
    while first_tag_start_index != -1 and first_tag_start_index < len(s):
        first_tag_start_index = s.find('<', find_from)
        if first_tag_start_index != -1:
            first_tag_end_index = s.find('>', first_tag_start_index)
            tag_name = s[first_tag_start_index+1:first_tag_end_index]
            second_tag_start_index = s.find('</', first_tag_end_index+1)
            second_tag_end_index = s.find('>', second_tag_start_index)

            value = s[first_tag_end_index+1:second_tag_start_index]

            before_tag = s[find_from:first_tag_start_index]
            if len(before_tag) > 0:
                parts.append(before_tag)
                tags.append('NONE_CHAR')

            if len(value) > 0:
                parts.append(value)
                tags.append(tag_name)
            
            find_from = second_tag_end_index + 1
            
    if find_from < len(s):
        parts.append(s[find_from:])
        tags.append('NONE_CHAR')

    return parts, tags

def split_and_get_tags(parts, parts_tags):
    tokens = []
    tags = []
    
    for i in range(len(parts)):
        part_tokens = split(prepare_value(parts[i]))
        tag = parts_tags[i]
        for t in part_tokens:
            tokens.append(t)
            tags.append(tag)
        
    return tokens, tags
    

def escape( s ):
    s = s.replace("&", "&amp;")
    s = s.replace("\"", "&quot;")
    return s

22.18
22.18
дер  полн


In [76]:
from allennlp.data.tokenizers.word_splitter import WordSplitter
from allennlp.data.tokenizers.token import Token


class KrepmarketWordSplitter(WordSplitter):

    def split_words(self, sentence: str) -> List[Token]:
        return [Token(w) for w in split(prepare_value(sentence))]

In [64]:
class KrepmarketMergedDataReader(DatasetReader):
    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        
    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=sentence_field)
            fields["labels"] = label_field

        return Instance(fields)
    
    def _read(self, file_path: str) -> Iterator[Instance]:
        df = pd.read_excel('datasets/krepmarket_merged.xlsx')
        for index, row in df.iterrows():
            if pd.isna(row['title_labeled']):
                continue
                    
            parts, parts_tags = parse_prelabeled(str(row['title_labeled']))
            sentence, tags = split_and_get_tags(parts, parts_tags)
            if row['title'] == 'Саморез по дереву 3 x12 PZ полная п/сф хромир MUST':
                print(sentence)
            
            yield self.text_to_instance([Token(w) for w in sentence], tags)

[i['labels'] for i in KrepmarketMergedDataReader().read('datasets/krepmarket_merged.xlsx')]

570it [00:02,  1.18s/it]

['саморез', 'по', 'дереву', '3', 'x', '12', 'pz', 'полная', 'п', '/', 'сф', 'хромир', 'must']


16510it [00:06, 2439.59it/s]


[<allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a35d3e3c8>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a2fd24b38>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b818e80>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b818e48>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b818d30>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b818860>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b818278>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b812ef0>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b812c50>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b812b70>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b812438>,
 <allennlp.data.fields.sequence_label_field.SequenceLabelField at 0x1a3b8120b8>,
 <allennlp.data.fields.seque

In [27]:
class LstmTagger(Model):
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()
        
    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> torch.Tensor:
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}

        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output
    
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}

In [88]:
df = pd.read_excel('datasets/krepmarket_merged.xlsx')

msk = np.random.rand(len(df)) < 0.1

train = df[msk]
test = df[~msk]

train.to_excel('datasets/krepmarket_merged_train.xlsx')
test.to_excel('datasets/krepmarket_merged_test.xlsx')

In [89]:
reader = KrepmarketMergedDataReader()
train_dataset = reader.read('datasets/krepmarket_merged_train.xlsx')
validation_dataset = reader.read('datasets/krepmarket_merged_test.xlsx')
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)

EMBEDDING_DIM = 6

token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

HIDDEN_DIM = 6

lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))

model = LstmTagger(word_embeddings, lstm, vocab)


0it [00:00, ?it/s][A
1it [00:02,  2.09s/it][A
254it [00:02,  1.46s/it][A
579it [00:02,  1.02s/it][A

['саморез', 'по', 'дереву', '3', 'x', '12', 'pz', 'полная', 'п', '/', 'сф', 'хромир', 'must']



901it [00:02,  1.40it/s][A
1231it [00:02,  1.99it/s][A
1582it [00:02,  2.85it/s][A
1939it [00:02,  4.07it/s][A
2277it [00:02,  5.81it/s][A
2636it [00:02,  8.29it/s][A
3016it [00:02, 11.84it/s][A
3400it [00:03, 16.89it/s][A
3774it [00:03, 24.08it/s][A
4134it [00:03, 34.30it/s][A
4504it [00:03, 48.80it/s][A
4866it [00:03, 69.31it/s][A
5228it [00:03, 98.21it/s][A
5589it [00:03, 138.52it/s][A
5940it [00:03, 194.39it/s][A
6285it [00:03, 271.10it/s][A
6629it [00:04, 373.71it/s][A
6966it [00:04, 507.82it/s][A
7296it [00:04, 679.05it/s][A
7661it [00:04, 898.30it/s][A
8132it [00:04, 1186.18it/s][A
8585it [00:04, 1523.52it/s][A
8985it [00:04, 1870.26it/s][A
9385it [00:04, 2220.17it/s][A
9783it [00:04, 2558.40it/s][A
10181it [00:04, 2853.39it/s][A
10577it [00:05, 3100.47it/s][A
10984it [00:05, 3337.07it/s][A
11394it [00:05, 3533.72it/s][A
11808it [00:05, 3693.82it/s][A
12220it [00:05, 3810.35it/s][A
12627it [00:05, 3833.46it/s][A
13029it [00:05, 3879.23it/s][A
13

['саморез', 'по', 'дереву', '3', 'x', '12', 'pz', 'полная', 'п', '/', 'сф', 'хромир', 'must']



920it [00:02,  1.38it/s][A
1247it [00:02,  1.97it/s][A
1597it [00:02,  2.81it/s][A
1947it [00:02,  4.01it/s][A
2299it [00:02,  5.73it/s][A
2665it [00:02,  8.18it/s][A
3044it [00:03, 11.67it/s][A
3429it [00:03, 16.66it/s][A
3790it [00:03, 23.75it/s][A
4159it [00:03, 33.83it/s][A
4524it [00:03, 48.14it/s][A
4886it [00:03, 68.37it/s][A
5246it [00:03, 96.88it/s][A
5605it [00:03, 136.78it/s][A
5962it [00:04, 168.90it/s][A
6266it [00:04, 235.66it/s][A
6590it [00:04, 326.47it/s][A
6912it [00:04, 446.95it/s][A
7237it [00:05, 602.88it/s][A
7574it [00:05, 799.88it/s][A
8034it [00:05, 1063.44it/s][A
8501it [00:05, 1383.86it/s][A
8896it [00:05, 1718.38it/s][A
9288it [00:05, 2050.02it/s][A
9675it [00:05, 2380.17it/s][A
10060it [00:05, 2686.61it/s][A
10445it [00:05, 2949.71it/s][A
10842it [00:05, 3195.47it/s][A
11289it [00:06, 3492.92it/s][A
11708it [00:06, 3676.29it/s][A
12118it [00:06, 3788.80it/s][A
12527it [00:06, 3816.96it/s][A
12930it [00:06, 3819.32it/s][A
13

In [92]:
optimizer = optim.SGD(model.parameters(), lr=0.1)
iterator = BasicIterator(batch_size=20)
iterator.index_with(vocab)

trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,
                  num_epochs=40)

trainer.train()




  0%|          | 0/826 [00:00<?, ?it/s][A[A[A


accuracy: 0.7165, loss: 1.0901 ||:   1%|          | 6/826 [00:00<00:14, 55.25it/s][A[A[A


accuracy: 0.7166, loss: 1.1537 ||:   2%|▏         | 14/826 [00:00<00:13, 59.58it/s][A[A[A


accuracy: 0.7164, loss: 1.1857 ||:   3%|▎         | 22/826 [00:00<00:12, 63.96it/s][A[A[A


accuracy: 0.7165, loss: 1.1877 ||:   4%|▎         | 30/826 [00:00<00:11, 67.71it/s][A[A[A


accuracy: 0.7165, loss: 1.1917 ||:   4%|▍         | 37/826 [00:00<00:11, 68.33it/s][A[A[A


accuracy: 0.7162, loss: 1.2010 ||:   5%|▌         | 45/826 [00:00<00:11, 69.93it/s][A[A[A


accuracy: 0.7159, loss: 1.2148 ||:   6%|▋         | 53/826 [00:00<00:10, 70.42it/s][A[A[A


accuracy: 0.7162, loss: 1.2027 ||:   7%|▋         | 61/826 [00:00<00:10, 71.39it/s][A[A[A


accuracy: 0.7161, loss: 1.2045 ||:   8%|▊         | 69/826 [00:00<00:10, 71.94it/s][A[A[A


accuracy: 0.7160, loss: 1.2051 ||:   9%|▉         | 77/826 [00:01<00:10, 70.98it/s][A[A[A

accuracy: 0.7171, loss: 1.1873 ||:  81%|████████  | 669/826 [00:09<00:02, 71.44it/s][A[A[A


accuracy: 0.7172, loss: 1.1870 ||:  82%|████████▏ | 677/826 [00:09<00:02, 72.20it/s][A[A[A


accuracy: 0.7170, loss: 1.1875 ||:  83%|████████▎ | 685/826 [00:09<00:02, 70.02it/s][A[A[A


accuracy: 0.7170, loss: 1.1876 ||:  84%|████████▍ | 693/826 [00:09<00:01, 71.58it/s][A[A[A


accuracy: 0.7170, loss: 1.1873 ||:  85%|████████▍ | 701/826 [00:09<00:01, 71.05it/s][A[A[A


accuracy: 0.7171, loss: 1.1869 ||:  86%|████████▌ | 709/826 [00:09<00:01, 70.95it/s][A[A[A


accuracy: 0.7171, loss: 1.1868 ||:  87%|████████▋ | 717/826 [00:10<00:01, 70.59it/s][A[A[A


accuracy: 0.7169, loss: 1.1880 ||:  88%|████████▊ | 725/826 [00:10<00:01, 71.17it/s][A[A[A


accuracy: 0.7169, loss: 1.1880 ||:  89%|████████▊ | 733/826 [00:10<00:01, 72.28it/s][A[A[A


accuracy: 0.7170, loss: 1.1878 ||:  90%|████████▉ | 741/826 [00:10<00:01, 72.35it/s][A[A[A


accuracy: 0.7169, loss: 1.1880 ||:  91%|

accuracy: 0.7186, loss: 1.1648 ||:  36%|███▌      | 297/826 [00:04<00:07, 67.11it/s][A[A[A


accuracy: 0.7190, loss: 1.1626 ||:  37%|███▋      | 304/826 [00:04<00:07, 66.89it/s][A[A[A


accuracy: 0.7184, loss: 1.1636 ||:  38%|███▊      | 311/826 [00:04<00:07, 67.12it/s][A[A[A


accuracy: 0.7192, loss: 1.1610 ||:  38%|███▊      | 318/826 [00:04<00:07, 65.21it/s][A[A[A


accuracy: 0.7194, loss: 1.1602 ||:  39%|███▉      | 325/826 [00:04<00:07, 65.93it/s][A[A[A


accuracy: 0.7188, loss: 1.1616 ||:  40%|████      | 332/826 [00:04<00:07, 66.91it/s][A[A[A


accuracy: 0.7189, loss: 1.1618 ||:  41%|████      | 339/826 [00:04<00:07, 66.46it/s][A[A[A


accuracy: 0.7190, loss: 1.1613 ||:  42%|████▏     | 346/826 [00:05<00:07, 66.81it/s][A[A[A


accuracy: 0.7189, loss: 1.1609 ||:  43%|████▎     | 354/826 [00:05<00:06, 68.48it/s][A[A[A


accuracy: 0.7192, loss: 1.1596 ||:  44%|████▍     | 362/826 [00:05<00:06, 69.04it/s][A[A[A


accuracy: 0.7190, loss: 1.1599 ||:  45%|

accuracy: 0.6664, loss: 1.1803 ||:  41%|████      | 339/826 [00:01<00:01, 316.30it/s][A[A[A


accuracy: 0.6688, loss: 1.1768 ||:  45%|████▍     | 371/826 [00:01<00:01, 312.05it/s][A[A[A


accuracy: 0.6823, loss: 1.1410 ||:  49%|████▉     | 407/826 [00:01<00:01, 323.38it/s][A[A[A


accuracy: 0.6923, loss: 1.1117 ||:  54%|█████▍    | 446/826 [00:01<00:01, 339.01it/s][A[A[A


accuracy: 0.6979, loss: 1.0944 ||:  58%|█████▊    | 481/826 [00:01<00:01, 341.49it/s][A[A[A


accuracy: 0.6967, loss: 1.1023 ||:  63%|██████▎   | 517/826 [00:01<00:00, 345.15it/s][A[A[A


accuracy: 0.6993, loss: 1.0985 ||:  67%|██████▋   | 552/826 [00:01<00:00, 344.90it/s][A[A[A


accuracy: 0.7064, loss: 1.0796 ||:  71%|███████▏  | 589/826 [00:01<00:00, 351.15it/s][A[A[A


accuracy: 0.7064, loss: 1.0707 ||:  76%|███████▌  | 627/826 [00:01<00:00, 358.85it/s][A[A[A


accuracy: 0.7060, loss: 1.0653 ||:  80%|████████  | 663/826 [00:01<00:00, 358.21it/s][A[A[A


accuracy: 0.7063, loss: 1.0620

accuracy: 0.7419, loss: 0.9714 ||:  61%|██████    | 503/826 [00:07<00:04, 68.59it/s][A[A[A


accuracy: 0.7425, loss: 0.9695 ||:  62%|██████▏   | 511/826 [00:07<00:04, 69.25it/s][A[A[A


accuracy: 0.7432, loss: 0.9675 ||:  63%|██████▎   | 518/826 [00:07<00:04, 68.85it/s][A[A[A


accuracy: 0.7434, loss: 0.9666 ||:  64%|██████▎   | 526/826 [00:07<00:04, 70.23it/s][A[A[A


accuracy: 0.7438, loss: 0.9663 ||:  65%|██████▍   | 534/826 [00:07<00:04, 69.57it/s][A[A[A


accuracy: 0.7439, loss: 0.9665 ||:  65%|██████▌   | 541/826 [00:08<00:04, 67.75it/s][A[A[A


accuracy: 0.7439, loss: 0.9662 ||:  66%|██████▋   | 549/826 [00:08<00:04, 68.76it/s][A[A[A


accuracy: 0.7442, loss: 0.9652 ||:  67%|██████▋   | 557/826 [00:08<00:03, 70.07it/s][A[A[A


accuracy: 0.7446, loss: 0.9639 ||:  68%|██████▊   | 565/826 [00:08<00:03, 68.04it/s][A[A[A


accuracy: 0.7447, loss: 0.9637 ||:  69%|██████▉   | 572/826 [00:08<00:03, 65.43it/s][A[A[A


accuracy: 0.7450, loss: 0.9626 ||:  70%|

accuracy: 0.7805, loss: 0.8349 ||:  13%|█▎        | 108/826 [00:01<00:10, 67.67it/s][A[A[A


accuracy: 0.7806, loss: 0.8343 ||:  14%|█▍        | 115/826 [00:01<00:10, 67.74it/s][A[A[A


accuracy: 0.7795, loss: 0.8375 ||:  15%|█▍        | 122/826 [00:01<00:10, 68.23it/s][A[A[A


accuracy: 0.7797, loss: 0.8369 ||:  16%|█▌        | 129/826 [00:01<00:10, 68.12it/s][A[A[A


accuracy: 0.7804, loss: 0.8357 ||:  16%|█▋        | 136/826 [00:02<00:10, 68.38it/s][A[A[A


accuracy: 0.7803, loss: 0.8354 ||:  17%|█▋        | 143/826 [00:02<00:10, 67.13it/s][A[A[A


accuracy: 0.7804, loss: 0.8341 ||:  18%|█▊        | 150/826 [00:02<00:10, 66.16it/s][A[A[A


accuracy: 0.7804, loss: 0.8357 ||:  19%|█▉        | 157/826 [00:02<00:10, 66.80it/s][A[A[A


accuracy: 0.7810, loss: 0.8338 ||:  20%|█▉        | 164/826 [00:02<00:10, 65.86it/s][A[A[A


accuracy: 0.7803, loss: 0.8368 ||:  21%|██        | 171/826 [00:02<00:09, 65.96it/s][A[A[A


accuracy: 0.7809, loss: 0.8359 ||:  22%|

accuracy: 0.7858, loss: 0.8072 ||:  87%|████████▋ | 716/826 [00:10<00:01, 68.43it/s][A[A[A


accuracy: 0.7861, loss: 0.8064 ||:  88%|████████▊ | 723/826 [00:10<00:01, 66.46it/s][A[A[A


accuracy: 0.7862, loss: 0.8058 ||:  88%|████████▊ | 730/826 [00:10<00:01, 63.63it/s][A[A[A


accuracy: 0.7865, loss: 0.8047 ||:  89%|████████▉ | 737/826 [00:11<00:01, 64.55it/s][A[A[A


accuracy: 0.7867, loss: 0.8040 ||:  90%|█████████ | 744/826 [00:11<00:01, 65.19it/s][A[A[A


accuracy: 0.7867, loss: 0.8037 ||:  91%|█████████ | 751/826 [00:11<00:01, 64.88it/s][A[A[A


accuracy: 0.7868, loss: 0.8035 ||:  92%|█████████▏| 758/826 [00:11<00:01, 64.90it/s][A[A[A


accuracy: 0.7868, loss: 0.8032 ||:  93%|█████████▎| 765/826 [00:11<00:00, 64.03it/s][A[A[A


accuracy: 0.7869, loss: 0.8022 ||:  93%|█████████▎| 772/826 [00:11<00:00, 65.51it/s][A[A[A


accuracy: 0.7868, loss: 0.8024 ||:  94%|█████████▍| 780/826 [00:11<00:00, 67.44it/s][A[A[A


accuracy: 0.7870, loss: 0.8018 ||:  95%|

accuracy: 0.7954, loss: 0.7519 ||:  38%|███▊      | 318/826 [00:04<00:07, 66.17it/s][A[A[A


accuracy: 0.7958, loss: 0.7508 ||:  39%|███▉      | 325/826 [00:04<00:07, 66.77it/s][A[A[A


accuracy: 0.7959, loss: 0.7498 ||:  40%|████      | 332/826 [00:04<00:07, 66.65it/s][A[A[A


accuracy: 0.7963, loss: 0.7484 ||:  41%|████      | 339/826 [00:05<00:07, 66.24it/s][A[A[A


accuracy: 0.7959, loss: 0.7488 ||:  42%|████▏     | 346/826 [00:05<00:07, 66.66it/s][A[A[A


accuracy: 0.7962, loss: 0.7477 ||:  43%|████▎     | 353/826 [00:05<00:07, 65.00it/s][A[A[A


accuracy: 0.7962, loss: 0.7481 ||:  44%|████▎     | 360/826 [00:05<00:07, 65.70it/s][A[A[A


accuracy: 0.7958, loss: 0.7482 ||:  44%|████▍     | 367/826 [00:05<00:06, 66.26it/s][A[A[A


accuracy: 0.7962, loss: 0.7473 ||:  45%|████▌     | 374/826 [00:05<00:06, 66.84it/s][A[A[A


accuracy: 0.7962, loss: 0.7469 ||:  46%|████▌     | 381/826 [00:05<00:06, 66.56it/s][A[A[A


accuracy: 0.7963, loss: 0.7465 ||:  47%|

accuracy: 0.7694, loss: 0.7658 ||:  49%|████▊     | 402/826 [00:01<00:01, 322.02it/s][A[A[A


accuracy: 0.7795, loss: 0.7443 ||:  53%|█████▎    | 440/826 [00:01<00:01, 337.16it/s][A[A[A


accuracy: 0.7860, loss: 0.7308 ||:  58%|█████▊    | 476/826 [00:01<00:01, 342.73it/s][A[A[A


accuracy: 0.7862, loss: 0.7364 ||:  62%|██████▏   | 512/826 [00:01<00:00, 347.50it/s][A[A[A


accuracy: 0.7870, loss: 0.7397 ||:  66%|██████▋   | 548/826 [00:01<00:00, 350.74it/s][A[A[A


accuracy: 0.7950, loss: 0.7210 ||:  71%|███████   | 585/826 [00:01<00:00, 355.65it/s][A[A[A


accuracy: 0.7979, loss: 0.7140 ||:  75%|███████▌  | 623/826 [00:01<00:00, 361.14it/s][A[A[A


accuracy: 0.8003, loss: 0.7055 ||:  80%|████████  | 661/826 [00:01<00:00, 364.10it/s][A[A[A


accuracy: 0.8024, loss: 0.7008 ||:  85%|████████▍ | 698/826 [00:02<00:00, 362.00it/s][A[A[A


accuracy: 0.8056, loss: 0.6967 ||:  89%|████████▉ | 735/826 [00:02<00:00, 354.11it/s][A[A[A


accuracy: 0.8094, loss: 0.6881

accuracy: 0.8159, loss: 0.6678 ||:  65%|██████▌   | 538/826 [00:07<00:04, 62.72it/s][A[A[A


accuracy: 0.8159, loss: 0.6676 ||:  66%|██████▌   | 545/826 [00:08<00:04, 62.24it/s][A[A[A


accuracy: 0.8163, loss: 0.6667 ||:  67%|██████▋   | 552/826 [00:08<00:04, 62.95it/s][A[A[A


accuracy: 0.8167, loss: 0.6657 ||:  68%|██████▊   | 559/826 [00:08<00:04, 62.92it/s][A[A[A


accuracy: 0.8169, loss: 0.6651 ||:  69%|██████▊   | 566/826 [00:08<00:04, 63.07it/s][A[A[A


accuracy: 0.8168, loss: 0.6653 ||:  69%|██████▉   | 573/826 [00:08<00:04, 62.76it/s][A[A[A


accuracy: 0.8170, loss: 0.6646 ||:  70%|███████   | 580/826 [00:08<00:03, 63.72it/s][A[A[A


accuracy: 0.8170, loss: 0.6643 ||:  71%|███████   | 587/826 [00:08<00:03, 63.87it/s][A[A[A


accuracy: 0.8171, loss: 0.6641 ||:  72%|███████▏  | 594/826 [00:08<00:03, 63.02it/s][A[A[A


accuracy: 0.8172, loss: 0.6635 ||:  73%|███████▎  | 601/826 [00:08<00:03, 62.48it/s][A[A[A


accuracy: 0.8170, loss: 0.6634 ||:  74%|

KeyboardInterrupt: 

In [None]:
from allennlp.predictors import SentenceTaggerPredictor
from allennlp.predictors import Predictor
from allennlp.data.tokenizers.word_tokenizer import WordTokenizer

title = "Саморез по дереву 3 x12 PZ полная п/сф хромир MUST"
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)
predictor._tokenizer = KrepmarketWordSplitter()
#print(predictor.load_line({"sentence": title}))
tag_logits = predictor.predict(title)['tag_logits']
tag_ids = np.argmax(tag_logits, axis=-1)
print(KrepmarketWordSplitter().split_words(title))
print([model.vocab.get_token_from_index(i, 'labels') for i in tag_ids])

In [None]:
TARGET_CHARACTERISTICS = [
    'MATERIAL',
    'TYPE',
    'MODEL',
    'COLOR',
    'SIZE1',
    'SIZE2',
    'WEIGHT',
    'PHYS',
    'COUNT',
    'SUPPLIER',
    'SUPPLIER_COUNTRY',
    'PURPOSE',
    'DESTINATION'
]

cols = ['title']
for c in TARGET_CHARACTERISTICS:
    cols.append(c)
    cols.append(c + '_correct')
    

parsing_result_df = pd.DataFrame(columns = cols)
for index, row in test.iterrows():
    title=str(row['title'])
    result_row = {'title': title}
    for c in TARGET_CHARACTERISTICS:
        if not pd.isna(row[c]):
            result_row[c + '_correct'] = row[c]
    tokens = split(prepare_value(title))
    tag_logits = predictor.predict(title)['tag_logits']
    tag_ids = np.argmax(tag_logits, axis=-1)
    tags = [model.vocab.get_token_from_index(i, 'labels') for i in tag_ids]
    
    for i in range(len(tags)):
        tag = tags[i]
        word = tokens[i]
        if tag != 'NONE_CHAR':
            val = ''
            if tag in result_row:
                val = result_row[tag] + ' '
            result_row[tag] = val + word
    
    parsing_result_df = parsing_result_df.append(result_row, ignore_index=True)
    

In [None]:
parsing_result_df.to_excel('krepmarket_result.xlsx')

In [94]:
with open('model.th', 'wb') as f:
    torch.save(model.state_dict(), f)