In [None]:
!pip install seqeval evaluate -q
!pip install transformers -U

In [None]:
import os
import ast
import random
from typing import List, Tuple, Dict, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import re

from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score, accuracy_score

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    set_seed
)

2025-10-01 20:55:10.508230: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759352110.530292    4102 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759352110.537069    4102 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
def set_seed(seed):
    """
    Sets the seed for reproducibility across multiple libraries.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed) 
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed) 
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
set_seed(56)

In [None]:
train_data = pd.read_csv('./train.csv',sep=';')
sample_sub = pd.read_csv('./submission.csv', sep=';')
train_data['annotation'] = train_data['annotation'].apply(eval)
train_data['annotation'] = train_data['annotation'].apply(lambda x: [(y[0],y[1],y[2].replace('0','O')) for y in x])

In [4]:
def build_label_list(entity_types: List[str]) -> List[str]:
    labels = ['O']
    for ent in entity_types:
        labels.append(f'B-{ent}')
        labels.append(f'I-{ent}')
    return labels

entity_types = ['BRAND', 'PERCENT', 'TYPE', 'VOLUME']
label_list = build_label_list(entity_types) 
label2id = {l: i for i, l in enumerate(label_list)}
id2label = {i: l for l, i in label2id.items()}
train_data['annotation'] = train_data['annotation'].astype(str)

In [None]:
WORD_RE = re.compile(r"\S+")

def strip_bio(lab):
    if not lab:
        return ""
    lab = str(lab)
    if lab == "O":
        return ""
    if lab.startswith(("B-", "I-")):
        return lab[2:]
    return lab

def normalize_spans(spans):
    norm = []
    for s, e, lab in spans:
        ent = strip_bio(lab)
        if ent:
            norm.append((int(s), int(e), ent))
    return norm

def compute_overlap(s1: int, e1: int, s2: int, e2: int) -> int:
    return max(0, min(e1, e2) - max(s1, s2))

def split_words_with_spans(text: str):
    return [(m.group(0), m.start(), m.end()) for m in WORD_RE.finditer(text)]

def tokenize_and_align_labels(
    examples: Dict[str, List[Any]],
    tokenizer,
    label2id: Dict[str, int],
    label_all_tokens: bool = True,
    max_length: int = 256
) -> Dict[str, Any]:
    """
    На входе:
      - examples['sample']: список текстов
      - examples['annotation']: список списков спанов в виде строк "[ (start, end, label), ... ]"
        (каждый элемент будет разобран через ast.literal_eval)
    На выходе:
      - tokenized inputs с полем 'labels' (список id меток для каждого токена; -100 для спец. токенов)
    """
    texts = examples["sample"]
    batch_spans_raw = examples["annotation"]
    batch_spans = [normalize_spans(ast.literal_eval(sp)) for sp in batch_spans_raw]

    batch_words = []
    batch_word_labels = []

    for text, spans in zip(texts, batch_spans):
        word_triplets = split_words_with_spans(text)  # [(word, s, e), ...]
        words = [w for w, _, _ in word_triplets]
        batch_words.append(words)

        per_word_labels = []
        for _, w_start, w_end in word_triplets:
            label = "O"
            best_ent = None
            best_ov = 0

            for (s, e, ent_type) in spans:
                ov = compute_overlap(s, e, w_start, w_end)
                if ov > best_ov:
                    best_ov = ov
                    best_ent = (s, e, ent_type)

            if best_ent is not None and best_ov > 0:
                ent_start, ent_end, ent_type = best_ent
                if w_start == ent_start:
                    label = f"B-{ent_type}"
                else:
                    label = f"I-{ent_type}"

            per_word_labels.append(label)

        batch_word_labels.append(per_word_labels)

    tokenized = tokenizer(
        batch_words,
        is_split_into_words=True,
        truncation=True,
        padding=False,
        max_length=max_length,
        return_tensors=None 
    )

    all_labels: List[List[int]] = []

    for i in range(len(batch_words)):
        word_ids = tokenized.word_ids(batch_index=i)  # список индексов слов для каждого токена
        if word_ids is None:
            raise ValueError("Tokenizer must be a fast tokenizer to use word_ids.")

        labels_ids: List[int] = []
        prev_word_id = None
        per_word = batch_word_labels[i]

        for w_id in word_ids:
            if w_id is None:
                labels_ids.append(-100)  # спец. токены
            else:
                label_str = per_word[w_id]
                if not label_all_tokens and w_id == prev_word_id:
                    labels_ids.append(-100)  # только первый сабтокен помечаем
                else:
                    labels_ids.append(label2id[label_str])
            prev_word_id = w_id

        all_labels.append(labels_ids)

    tokenized["labels"] = all_labels
    return tokenized

In [6]:
examples = [
    ["йогурт 2%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["йогурт 2,5", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["йогурт питьевой 1.5%", [(0, 6, 'B-TYPE'), (7, 15, 'I-TYPE'), (16, 20, 'B-PERCENT')]],
    ["ряженка 4 %", [(0, 7, 'B-TYPE'), (8, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["ряженка 4%", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["кефир 2.5%", [(0, 5, 'B-TYPE'), (6, 10, 'B-PERCENT')]],
    ["кефир 0%", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT')]],
    ["кефир бифидум 1", [(0, 5, 'B-TYPE'), (6, 13, 'I-TYPE'), (14, 15, 'B-PERCENT')]],
    ["молоко 2 %", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["молоко 2.5", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["молоко пастеризованное 3,2%", [(0, 6, 'B-TYPE'), (7, 22, 'I-TYPE'), (23, 27, 'B-PERCENT')]],
    ["творог мягкий 0.5%", [(0, 6, 'B-TYPE'), (7, 13, 'I-TYPE'), (14, 18, 'B-PERCENT')]],
    ["творог 18%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["творог зерненый 5", [(0, 6, 'B-TYPE'), (7, 15, 'I-TYPE'), (16, 17, 'B-PERCENT')]],
    ["сметана 30%", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["сметана 30", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["сметана фермерская 25 %", [(0, 7, 'B-TYPE'), (8, 18, 'I-TYPE'), (19, 21, 'B-PERCENT'), (22, 23, 'I-PERCENT')]],
    ["сыр плавленый 45%", [(0, 3, 'B-TYPE'), (4, 13, 'I-TYPE'), (14, 17, 'B-PERCENT')]],
    ["сыр твердый 50", [(0, 3, 'B-TYPE'), (4, 11, 'I-TYPE'), (12, 14, 'B-PERCENT')]],
    ["сыр 45 %", [(0, 3, 'B-TYPE'), (4, 6, 'B-PERCENT'), (7, 8, 'I-PERCENT')]],
    ["масло сливочное 62%", [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 19, 'B-PERCENT')]],
    ["масло 72.5", [(0, 5, 'B-TYPE'), (6, 10, 'B-PERCENT')]],
    ["масло топленое 99", [(0, 5, 'B-TYPE'), (6, 14, 'I-TYPE'), (15, 17, 'B-PERCENT')]],
    ["сливки 15%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["сливки 15", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["сливки взбитые 35 %", [(0, 6, 'B-TYPE'), (7, 14, 'I-TYPE'), (15, 17, 'B-PERCENT'), (18, 19, 'I-PERCENT')]],
    ["простоквашино молоко 3.2", [(0, 13, 'B-BRAND'), (14, 20, 'B-TYPE'), (21, 24, 'B-PERCENT')]],
    ["домик в деревне кефир 1%", [(0, 15, 'B-BRAND'), (16, 21, 'B-TYPE'), (22, 24, 'B-PERCENT')]],
    ["йогурт греческий 0", [(0, 6, 'B-TYPE'), (7, 16, 'I-TYPE'), (17, 18, 'B-PERCENT')]],
    ["ряженка 2.5%", [(0, 7, 'B-TYPE'), (8, 12, 'B-PERCENT')]],
    ["кефир 3,2", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT')]],
    ["молоко 0.5 %", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["творог 0%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["сметана 18 %", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["сыр мягкий 20%", [(0, 3, 'B-TYPE'), (4, 10, 'I-TYPE'), (11, 14, 'B-PERCENT')]],
    ["масло растительное 100%", [(0, 5, 'B-TYPE'), (6, 18, 'I-TYPE'), (19, 23, 'B-PERCENT')]],
    ["сливки 25%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["йогурт 1.5", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["ряженка 3 %", [(0, 7, 'B-TYPE'), (8, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["кефир обезжиренный 0", [(0, 5, 'B-TYPE'), (6, 18, 'I-TYPE'), (19, 20, 'B-PERCENT')]],
    ["молоко 4%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["творог 2 %", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["сметана 12%", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["сыр 55", [(0, 3, 'B-TYPE'), (4, 6, 'B-PERCENT')]],
    ["масло 82.5%", [(0, 5, 'B-TYPE'), (6, 11, 'B-PERCENT')]],
    ["сливки 40", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["йогурт 3.2 %", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["ряженка 1%", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["кефир 2 %", [(0, 5, 'B-TYPE'), (6, 7, 'B-PERCENT'), (8, 9, 'I-PERCENT')]],
    ["молоко топленое 4", [(0, 6, 'B-TYPE'), (7, 15, 'I-TYPE'), (16, 17, 'B-PERCENT')]],
    ["творог 15%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["сметана 35", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["сыр с плесенью 50%", [(0, 3, 'B-TYPE'), (4, 14, 'I-TYPE'), (15, 18, 'B-PERCENT')]],
    ["масло кокосовое 99%", [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 19, 'B-PERCENT')]],
    ["сливки 12 %", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["йогурт натуральный 2.5", [(0, 6, 'B-TYPE'), (7, 18, 'I-TYPE'), (19, 22, 'B-PERCENT')]],
    ["ряженка 0.5%", [(0, 7, 'B-TYPE'), (8, 12, 'B-PERCENT')]],
    ["кефир 4%", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT')]],
    ["молоко 1.5%", [(0, 6, 'B-TYPE'), (7, 11, 'B-PERCENT')]],
    ["творог обезжиренный 0 %", [(0, 6, 'B-TYPE'), (7, 19, 'I-TYPE'), (20, 21, 'B-PERCENT'), (22, 23, 'I-PERCENT')]],
    ["сметана 40%", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["сыр 60 %", [(0, 3, 'B-TYPE'), (4, 6, 'B-PERCENT'), (7, 8, 'I-PERCENT')]],
    ["масло 90", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT')]],
    ["сливки стерилизованные 10", [(0, 6, 'B-TYPE'), (7, 23, 'I-TYPE'), (24, 26, 'B-PERCENT')]],
    ["йогурт 0 %", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["ряженка 5", [(0, 7, 'B-TYPE'), (8, 9, 'B-PERCENT')]],
    ["кефир фруктовый 1.5 %", [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 19, 'B-PERCENT'), (20, 21, 'I-PERCENT')]],
    ["молоко 3%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["творог 3.5", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["сметана домашняя 20", [(0, 7, 'B-TYPE'), (8, 16, 'I-TYPE'), (17, 19, 'B-PERCENT')]],
    ["сыр кремовый 30%", [(0, 3, 'B-TYPE'), (4, 12, 'I-TYPE'), (13, 16, 'B-PERCENT')]],
    ["масло 75 %", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["сливки 18%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["йогурт 4 %", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["ряженка 2%", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["кефир 0.1%", [(0, 5, 'B-TYPE'), (6, 10, 'B-PERCENT')]],
    ["молоко цельное 3.5 %", [(0, 6, 'B-TYPE'), (7, 14, 'I-TYPE'), (15, 18, 'B-PERCENT'), (19, 20, 'I-PERCENT')]],
    ["творог 12", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["сметана 22%", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["сыр 40", [(0, 3, 'B-TYPE'), (4, 6, 'B-PERCENT')]],
    ["масло сливочное 80", [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 18, 'B-PERCENT')]],
    ["сливки 22 %", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["йогурт обезжиренный 0", [(0, 6, 'B-TYPE'), (7, 19, 'I-TYPE'), (20, 21, 'B-PERCENT')]],
    ["ряженка 3.2", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["кефир 2.0 %", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["молоко 2.0", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["творог классический 9%", [(0, 6, 'B-TYPE'), (7, 19, 'I-TYPE'), (20, 22, 'B-PERCENT')]],
    ["сметана 28", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["сыр пармезан 32%", [(0, 3, 'B-TYPE'), (4, 12, 'I-TYPE'), (13, 16, 'B-PERCENT')]],
    ["масло 85%", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT')]],
    ["сливки 28", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["йогурт 5%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["ряженка 4.0 %", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT'), (12, 13, 'I-PERCENT')]],
    ["кефир 3%", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT')]],
    ["молоко 0%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["творог 4%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["сметана 15", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["сыр чеддер 48%", [(0, 3, 'B-TYPE'), (4, 10, 'I-TYPE'), (11, 14, 'B-PERCENT')]],
    ["масло топленое 98 %", [(0, 5, 'B-TYPE'), (6, 14, 'I-TYPE'), (15, 17, 'B-PERCENT'), (18, 19, 'I-PERCENT')]],
    ["сливки 35%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["йогурт 2.0%", [(0, 6, 'B-TYPE'), (7, 11, 'B-PERCENT')]],
    ["ряженка 1.5", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["кефир 1.0", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT')]],
    ["молоко ультрапастеризованное 1%", [(0, 6, 'B-TYPE'), (7, 28, 'I-TYPE'), (29, 31, 'B-PERCENT')]],
    ["творог 6 %", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["сметана 10 %", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["сыр легкий 10", [(0, 3, 'B-TYPE'), (4, 10, 'I-TYPE'), (11, 13, 'B-PERCENT')]],
    ["масло сливочное 70%", [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 19, 'B-PERCENT')]],
    ["сливки 30 %", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["йогурт 3%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["ряженка 2.0%", [(0, 7, 'B-TYPE'), (8, 12, 'B-PERCENT')]],
    ["кефир 0 %", [(0, 5, 'B-TYPE'), (6, 7, 'B-PERCENT'), (8, 9, 'I-PERCENT')]],
    ["молоко 4.5", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["творог мягкий 2%", [(0, 6, 'B-TYPE'), (7, 13, 'I-TYPE'), (14, 16, 'B-PERCENT')]],
    ["сметана 25%", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["сыр 25 %", [(0, 3, 'B-TYPE'), (4, 6, 'B-PERCENT'), (7, 8, 'I-PERCENT')]],
    ["масло 60", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT')]],
    ["сливки 5%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["йогурт питьевой 0.5 %", [(0, 6, 'B-TYPE'), (7, 15, 'I-TYPE'), (16, 19, 'B-PERCENT'), (20, 21, 'I-PERCENT')]],
    ["ряженка 3%", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["кефир 2.5", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT')]],
    ["молоко 3.5%", [(0, 6, 'B-TYPE'), (7, 11, 'B-PERCENT')]],
    ["творог 7", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT')]],
    ["сметана 5 %", [(0, 7, 'B-TYPE'), (8, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["сыр твердый 45 %", [(0, 3, 'B-TYPE'), (4, 11, 'I-TYPE'), (12, 14, 'B-PERCENT'), (15, 16, 'I-PERCENT')]],
    ["масло 72 %", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["сливки 40%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["йогурт 1 %", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["ряженка 0%", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["кефир 3.5 %", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["молоко 2.5 %", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["творог 10%", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["сметана 30 %", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["сыр 35", [(0, 3, 'B-TYPE'), (4, 6, 'B-PERCENT')]],
    ["масло сливочное 85 %", [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 18, 'B-PERCENT'), (19, 20, 'I-PERCENT')]],
    ["сливки 15 %", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["йогурт 4.5", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["ряженка 5%", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]],
    ["кефир 1.5", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT')]],
    ["молоко 1", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT')]],
    ["творог 0.5 %", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["сметана 18%", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]],
    ["сыр легкий 20 %", [(0, 3, 'B-TYPE'), (4, 10, 'I-TYPE'), (11, 13, 'B-PERCENT'), (14, 15, 'I-PERCENT')]],
    ["масло 82", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT')]],
    ["сливки 20 %", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["йогурт 2.5 %", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["ряженка 4 %", [(0, 7, 'B-TYPE'), (8, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]],
    ["кефир 0.5", [(0, 5, 'B-TYPE'), (6, 9, 'B-PERCENT')]],
    ["молоко 3.2 %", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["творог 5 %", [(0, 6, 'B-TYPE'), (7, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["сметана 20 %", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["сыр 50%", [(0, 3, 'B-TYPE'), (4, 7, 'B-PERCENT')]],
    ["масло сливочное 72 %", [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 18, 'B-PERCENT'), (19, 20, 'I-PERCENT')]],
    ["сливки 33", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["йогурт 0.1%", [(0, 6, 'B-TYPE'), (7, 11, 'B-PERCENT')]],
    ["ряженка 2.5 %", [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT'), (12, 13, 'I-PERCENT')]],
    ["кефир 2%", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT')]],
    ["молоко 0.5", [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]],
    ["творог 9%", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["сметана 25 %", [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT'), (11, 12, 'I-PERCENT')]],
    ["сыр 15%", [(0, 3, 'B-TYPE'), (4, 7, 'B-PERCENT')]],
    ["масло 80 %", [(0, 5, 'B-TYPE'), (6, 8, 'B-PERCENT'), (9, 10, 'I-PERCENT')]],
    ["сливки 10", [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]],
    ["балтика 1", [(0, 7, 'B-BRAND'), (8, 9, 'B-PERCENT')]],
    ["нестожен молоко 2", [(0, 8, 'B-BRAND'), (9, 15, 'B-TYPE'), (16, 17, 'B-PERCENT')]],
    ["простоквашино сметана 15%", [(0, 13, 'B-BRAND'), (14, 21, 'B-TYPE'), (22, 25, 'B-PERCENT')]],
    ["домик в деревне творог 5", [(0, 15, 'B-BRAND'), (16, 22, 'B-TYPE'), (23, 24, 'B-PERCENT')]],
    ["балтика кефир 0%", [(0, 7, 'B-BRAND'), (8, 13, 'B-TYPE'), (14, 16, 'B-PERCENT')]],
    ["нестожен йогурт 1.5", [(0, 8, 'B-BRAND'), (9, 15, 'B-TYPE'), (16, 19, 'B-PERCENT')]],
    ["простоквашино сливки 20 %", [(0, 13, 'B-BRAND'), (14, 20, 'B-TYPE'), (21, 23, 'B-PERCENT'), (24, 25, 'I-PERCENT')]],
    ["домик в деревне сыр 45", [(0, 15, 'B-BRAND'), (16, 19, 'B-TYPE'), (20, 22, 'B-PERCENT')]],
    ["балтика ряженка 3,2", [(0, 7, 'B-BRAND'), (8, 15, 'B-TYPE'), (16, 19, 'B-PERCENT')]],
    ["нестожен масло 82%", [(0, 8, 'B-BRAND'), (9, 14, 'B-TYPE'), (15, 18, 'B-PERCENT')]],
    ["простоквашино творог 9 %", [(0, 13, 'B-BRAND'), (14, 20, 'B-TYPE'), (21, 22, 'B-PERCENT'), (23, 24, 'I-PERCENT')]],
    ["домик в деревне молоко 3.2%", [(0, 15, 'B-BRAND'), (16, 22, 'B-TYPE'), (23, 27, 'B-PERCENT')]],
    ["балтика сметана 10", [(0, 7, 'B-BRAND'), (8, 15, 'B-TYPE'), (16, 18, 'B-PERCENT')]],
    ["нестожен кефир 2.5 %", [(0, 8, 'B-BRAND'), (9, 14, 'B-TYPE'), (15, 18, 'B-PERCENT'), (19, 20, 'I-PERCENT')]],
    ["простоквашино йогурт 0", [(0, 13, 'B-BRAND'), (14, 20, 'B-TYPE'), (21, 22, 'B-PERCENT')]],
    ["домик в деревне сливки 33%", [(0, 15, 'B-BRAND'), (16, 22, 'B-TYPE'), (23, 26, 'B-PERCENT')]],
    ["балтика сыр легкий 15", [(0, 7, 'B-BRAND'), (8, 11, 'B-TYPE'), (12, 18, 'I-TYPE'), (19, 21, 'B-PERCENT')]],
    ["нестожен ряженка 4%", [(0, 8, 'B-BRAND'), (9, 16, 'B-TYPE'), (17, 19, 'B-PERCENT')]],
    ["простоквашино масло сливочное 72", [(0, 13, 'B-BRAND'), (14, 19, 'B-TYPE'), (20, 29, 'I-TYPE'), (30, 32, 'B-PERCENT')]],
    ["домик в деревне сметана 20%", [(0, 15, 'B-BRAND'), (16, 23, 'B-TYPE'), (24, 27, 'B-PERCENT')]],
    ["балтика молоко 1 %", [(0, 7, 'B-BRAND'), (8, 14, 'B-TYPE'), (15, 16, 'B-PERCENT'), (17, 18, 'I-PERCENT')]],
    ["нестожен творог 5", [(0, 8, 'B-BRAND'), (9, 15, 'B-TYPE'), (16, 17, 'B-PERCENT')]],
    ["простоквашино кефир 1", [(0, 13, 'B-BRAND'), (14, 19, 'B-TYPE'), (20, 21, 'B-PERCENT')]],
    ["домик в деревне йогурт 2.5", [(0, 15, 'B-BRAND'), (16, 22, 'B-TYPE'), (23, 26, 'B-PERCENT')]],
    ["балтика сливки 10%", [(0, 7, 'B-BRAND'), (8, 14, 'B-TYPE'), (15, 18, 'B-PERCENT')]],
    ["нестожен сыр 45 %", [(0, 8, 'B-BRAND'), (9, 12, 'B-TYPE'), (13, 15, 'B-PERCENT'), (16, 17, 'I-PERCENT')]],
    ["простоквашино ряженка 3.2", [(0, 13, 'B-BRAND'), (14, 21, 'B-TYPE'), (22, 25, 'B-PERCENT')]],
    ["домик в деревне масло 82", [(0, 15, 'B-BRAND'), (16, 21, 'B-TYPE'), (22, 24, 'B-PERCENT')]]
]

In [10]:
examples_df = pd.DataFrame({'sample': [x[0] for x in examples], 'annotation': [x[1] for x in examples]})
train_texts = train_data['sample'].tolist()
examples_df = examples_df[examples_df['sample'].apply(lambda x: x not in train_texts)]
examples_df['annotation'] = examples_df['annotation'].astype(str)
train_data['is_ad'] = False
examples_df['is_ad'] = True
train_data = pd.concat([train_data,examples_df],axis=0,ignore_index=True) #.sample(50,random_state=56)

In [11]:
train_data['annotation'] = train_data['annotation'].astype(str)
train_data['is_percent'] = train_data['annotation'].apply(lambda x: int('PERCENT' in x))
train_df, valid_df = train_test_split(train_data, test_size=0.2, random_state=56, shuffle=True, stratify=train_data['is_percent'])

In [12]:
valid_df = valid_df[~valid_df['is_ad']]

In [13]:
dop_val_samples = [
    ('сметана 11', [(0, 7, 'B-TYPE'), (8, 10, 'B-PERCENT')]),
    ('сметана 3', [(0, 7, 'B-TYPE'), (8, 9, 'B-PERCENT')]),
    ('сметана 3,4', [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]),
    ('сметана 10%', [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT')]),
    ('сметана 10 %', [(0, 7, 'B-TYPE'), (8, 11, 'B-PERCENT'), (12, 13, 'I-PERCENT')]),
    ('творог 10', [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]),
    ('творог 9%', [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT')]),
    ('творог 11 %', [(0, 6, 'B-TYPE'), (7, 9, 'B-PERCENT'), (10, 11, 'I-PERCENT')]),
    ('масло сливочное 52', [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 18, 'B-PERCENT')]),
    ('масло сливочное 70%', [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 19, 'B-PERCENT')]),
    ('масло сливочное 74 %', [(0, 5, 'B-TYPE'), (6, 15, 'I-TYPE'), (16, 18, 'B-PERCENT'),  (19, 20, 'I-PERCENT')]),
    ('молоко 2,2', [(0, 6, 'B-TYPE'), (7, 10, 'B-PERCENT')]),
    ('молоко 2,2%', [(0, 6, 'B-TYPE'), (7, 11, 'B-PERCENT')]),
    ('молоко 5 %', [(0, 6, 'B-TYPE'), (7, 11, 'B-PERCENT'), (12, 13, 'I-PERCENT')]),
    ('молочный коктейль 1', [(0, 8, 'B-TYPE'), (9, 17, 'I-TYPE'), (18, 19, 'B-VOLUME')]),
    ('молочный коктейль 1,5', [(0, 8, 'B-TYPE'), (9, 17, 'I-TYPE'), (18, 21, 'B-VOLUME')]),
    ('вода 5', [(0, 4, 'B-TYPE'), (5, 6, 'B-VOLUME')]),	
    ('сахар 3', [(0, 5, 'B-TYPE'), (6, 7, 'B-VOLUME')])
]

examples_df = pd.DataFrame({'sample': [x[0] for x in dop_val_samples], 'annotation': [x[1] for x in dop_val_samples]})
train_texts = train_data['sample'].tolist()
examples_df = examples_df[examples_df['sample'].apply(lambda x: x not in train_texts)]
examples_df['annotation'] = examples_df['annotation'].astype(str)
valid_df = pd.concat([valid_df,examples_df],axis=0,ignore_index=True)

In [14]:
train_ds = Dataset.from_pandas(train_df, preserve_index=False)
valid_ds = Dataset.from_pandas(valid_df, preserve_index=False)
raw_datasets = DatasetDict({"train": train_ds, "validation": valid_ds})

In [15]:
tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa",add_prefix_space=True)

def hf_map_fn(examples):
    return tokenize_and_align_labels(examples, tokenizer, label2id, label_all_tokens=True, max_length=512)

tokenized_datasets = raw_datasets.map(
    hf_map_fn,
    batched=True,
    remove_columns=raw_datasets["train"].column_names, 
    desc="Tokenizing and aligning labels"
)

Tokenizing and aligning labels:   0%|          | 0/21952 [00:00<?, ? examples/s]

Tokenizing and aligning labels:   0%|          | 0/5502 [00:00<?, ? examples/s]

In [16]:
valid_samples = valid_df['sample'].tolist()
valid_annotation = valid_df['annotation'].apply(eval).tolist()

In [None]:
def ner_word_tuples(texts, model, tokenizer, batch_size=16):
    single_input = isinstance(texts, str)
    texts_list = [texts] if single_input else list(texts)

    spans_per_text = []
    words_per_text = []
    for t in texts_list:
        spans = [(m.group(0), m.start(), m.end()) for m in re.finditer(r"\S+", t)]
        spans_per_text.append(spans)
        words_per_text.append([w for w, _, _ in spans])

    if not any(words_per_text):
        empty = [[] for _ in texts_list]
        return empty[0] if single_input else empty

    models = model if isinstance(model, (list, tuple)) else [model]
    tokenizers = tokenizer if isinstance(tokenizer, (list, tuple)) else [tokenizer]
    if len(models) != len(tokenizers):
        raise ValueError("models and tokenizers must have the same length")

    label_set = set()
    for m in models:
        id2label = m.config.id2label
        if isinstance(id2label, dict):
            for i in range(len(id2label)):
                label_set.add(id2label[i])
        else:
            label_set.update(id2label)
    global_labels = (['O'] + sorted([l for l in label_set if l != 'O'])) if 'O' in label_set else sorted(label_set)
    label2idx = {lab: i for i, lab in enumerate(global_labels)}

    local2global = []
    for m in models:
        id2label = m.config.id2label
        if isinstance(id2label, dict):
            local = [label2idx[id2label[i]] for i in range(len(id2label))]
        else:
            local = [label2idx[l] for l in id2label]
        local2global.append(torch.tensor(local, dtype=torch.long))

    L = len(global_labels)
    O_idx = label2idx.get('O', 0)
    o_vec = torch.zeros(L, dtype=torch.float32)
    o_vec[O_idx] = 1.0

    labels_per_text = [None] * len(texts_list)

    for b in range(0, len(texts_list), batch_size):
        batch_words = words_per_text[b:b + batch_size]

        per_model_word_probs = []
        for mi, (m, tok) in enumerate(zip(models, tokenizers)):
            enc = tok(
                batch_words,
                is_split_into_words=True,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            word_ids_list = [enc.word_ids(batch_index=i) for i in range(len(batch_words))]
            m_device = next(m.parameters()).device
            enc_gpu = {k: v.to(m_device) for k, v in enc.items()}
            with torch.inference_mode():
                logits = m(**enc_gpu).logits
                probs = torch.nn.functional.softmax(logits, dim=-1)
            map_idx = local2global[mi].to(probs.device)
            gprobs = torch.zeros(probs.size(0), probs.size(1), L, device=probs.device, dtype=probs.dtype)
            gprobs.scatter_add_(dim=-1, index=map_idx.view(1, 1, -1).expand_as(probs), src=probs)
            gprobs = gprobs.float().cpu()

            model_word_probs = []
            for si, wids in enumerate(word_ids_list):
                arr = [None] * len(batch_words[si])
                gp = gprobs[si]
                for ti, wid in enumerate(wids):
                    if wid is None:
                        continue
                    if 0 <= wid < len(arr) and arr[wid] is None:
                        arr[wid] = gp[ti]
                for wi in range(len(arr)):
                    if arr[wi] is None:
                        arr[wi] = o_vec
                model_word_probs.append(arr)
            per_model_word_probs.append(model_word_probs)

        for si in range(len(batch_words)):
            seq_len = len(batch_words[si])
            seq_labels = []
            for wi in range(seq_len):
                vecs = [per_model_word_probs[mj][si][wi] for mj in range(len(models))]
                mean_vec = torch.stack(vecs, dim=0).mean(dim=0)
                seq_labels.append(global_labels[int(mean_vec.argmax().item())])
            labels_per_text[b + si] = seq_labels

    results = []
    for i, labs in enumerate(labels_per_text):
        labs = _fix_bio(labs)
        spans = spans_per_text[i]
        results.append([(start, end, lab) for (_, start, end), lab in zip(spans, labs)])

    return results[0] if single_input else results

In [18]:
def f1_macro(labels, preds):
  ents = {
      'TYPE' : [0, 0, 0],
      'BRAND' : [0, 0, 0],
      'VOLUME' : [0, 0, 0],
      'PERCENT' :  [0,0,0]
  }
  for tr, pr in zip(labels, preds):
    ents_cur = {
      'TYPE' : [0, 0, 0],
      'BRAND' : [0, 0, 0],
      'VOLUME' : [0, 0, 0],
      'O' : [0,0,0],
      'PERCENT' :  [0,0,0]
    }
    used = []
    for tr_i in tr:
      flg = 0
      for pr_i in pr:
        if pr_i == tr_i:
          if sum(ents_cur[tr_i[2].split('-')[-1]]) == 0:
            ents_cur[tr_i[2].split('-')[-1]][0] = 1
          used.append(tr_i[2].split('-')[-1])
          flg = 1
          break
      if flg == 0:
        ents_cur[tr_i[2].split('-')[-1]][2] = 1
    for pr_i in pr:
      ent = pr_i[2].split('-')[-1]
      if ent not in used:
        ents_cur[ent][1] = 1
    for ent in ['TYPE', 'BRAND', 'VOLUME', 'PERCENT']:
      ents[ent][0] += ents_cur[ent][0]
      ents[ent][1] += ents_cur[ent][1]
      ents[ent][2] += ents_cur[ent][2]
  f1s = []
  for ent in ['TYPE', 'BRAND', 'VOLUME', 'PERCENT']:
    TP = ents[ent][0]
    FP = ents[ent][1]
    FN = ents[ent][2]
    if FP + TP + FN != 0:
      if FP + TP == 0 or FN + TP == 0 or TP == 0:
        f1s.append(0)
      else:
        rec = TP / (FP + TP)
        pre = TP / (FN + TP)
        f1s.append(2*rec*pre/(rec+pre)) 
  return np.mean(f1s)

In [None]:
def _fix_bio(labels):
    fixed = []
    prev_type = None
    inside = False
    for lab in labels:
        if lab is None or lab == "O":
            fixed.append("O")
            prev_type, inside = None, False
            continue

        if "-" in lab:
            prefix, typ = lab.split("-", 1)
        else:
            prefix, typ = "B", lab

        if prefix == "I" and (not inside or prev_type != typ):
            fixed.append(f"B-{typ}")
            prev_type, inside = typ, True
        else:
            fixed.append(f"{prefix}-{typ}")
            inside = prefix in ("B", "I")
            prev_type = typ if inside else None
    return fixed

def _pieces_per_word(word):
    return len(tokenizer(word, add_special_tokens=False)["input_ids"]) or 1

def _chunk_words(words, max_pieces):
    chunk, cur = [], 0
    for w in words:
        n = _pieces_per_word(w)
        if chunk and cur + n > max_pieces:
            yield chunk
            chunk, cur = [w], n
        else:
            chunk.append(w)
            cur += n
    if chunk:
        yield chunk

def post_proc(val_predicts):
    val_predicts_new = []
    for tripls in val_predicts:
        row,used = [],[]
        for i,j,k in tripls:
            if k == 'O':
                row.append((i,j,k))
                continue
            tp = k.split('-')[1]
            if tp in used:
                row.append((i,j,k.replace('B-','I-')))
            else:
                row.append((i,j,k.replace('I-','B-')))
            used.append(tp)
        val_predicts_new.append(row)
    return val_predicts_new

In [None]:
epoch = 0
def compute_metrics_fn(p):
    global epoch
    preds, labels = p
    preds = np.argmax(preds, axis=2)
    true_preds, true_labels = [], []
    for pred_seq, lab_seq in zip(preds, labels):
        pred_tags, lab_tags = [], []
        for p_i, l_i in zip(pred_seq, lab_seq):
            if l_i == -100:
                continue
            pred_tags.append(id2label[p_i])
            lab_tags.append(id2label[l_i])
        true_preds.append(pred_tags)
        true_labels.append(lab_tags)
    print(classification_report(true_labels, true_preds))
    seq_eval_metrics =  {
        "precision": precision_score(true_labels, true_preds),
        "recall":    recall_score(true_labels, true_preds),
        "f1":        f1_score(true_labels, true_preds, average='macro'),
        "accuracy":  accuracy_score(true_labels, true_preds),
    }
    #backup = model.state_dict()
    #ema_model.copy_to(model)
    val_preds = post_proc(ner_word_tuples(valid_samples, model, tokenizer, batch_size=128))
    seq_eval_metrics['f1_word_level'] = f1_macro(valid_annotation,val_preds)

    model.save_pretrained(f'../../tmp/ckpt_{epoch}')
    epoch += 1
    #model.load_state_dict(backup)
    return seq_eval_metrics

In [None]:
set_seed(56)
model = AutoModelForTokenClassification.from_pretrained(
    "ai-forever/ru-en-RoSBERTa",
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
).cuda()

Some weights of RobertaForTokenClassification were not initialized from the model checkpoint at ai-forever/ru-en-RoSBERTa and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
data_collator = DataCollatorForTokenClassification(tokenizer)

training_args = TrainingArguments(
    output_dir="rosberta",
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    save_total_limit=1,
    num_train_epochs=30,
    weight_decay=0.0,
    warmup_ratio=0.1,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    group_by_length=True,
    fp16=False,  
    report_to="none",
    seed=56
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics_fn,
)

trainer.train()

  trainer = Trainer(
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 0}.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy,F1 Word Level
1,1.0837,0.331619,0.880042,0.915064,0.430181,0.879092,0.429517
2,0.2905,0.255522,0.925681,0.953001,0.707558,0.926346,0.708224
3,0.2011,0.232377,0.944528,0.953582,0.89353,0.933263,0.834034
4,0.1637,0.298949,0.92262,0.946901,0.91433,0.924762,0.906732
5,0.1146,0.247918,0.944653,0.964794,0.911674,0.941341,0.877496
6,0.0792,0.280757,0.942686,0.967002,0.908223,0.943875,0.872663


  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

       BRAND       0.86      0.70      0.77      3476
     PERCENT       1.00      0.01      0.02       104
        TYPE       0.88      0.98      0.93     13599
      VOLUME       0.00      0.00      0.00        34

   micro avg       0.88      0.92      0.90     17213
   macro avg       0.69      0.42      0.43     17213
weighted avg       0.88      0.92      0.89     17213





              precision    recall  f1-score   support

       BRAND       0.92      0.82      0.87      3476
     PERCENT       0.77      0.93      0.84       104
        TYPE       0.93      0.99      0.96     13599
      VOLUME       1.00      0.09      0.16        34

   micro avg       0.93      0.95      0.94     17213
   macro avg       0.90      0.71      0.71     17213
weighted avg       0.93      0.95      0.94     17213





              precision    recall  f1-score   support

       BRAND       0.90      0.87      0.89      3476
     PERCENT       0.89      0.99      0.94       104
        TYPE       0.96      0.97      0.96     13599
      VOLUME       1.00      0.65      0.79        34

   micro avg       0.94      0.95      0.95     17213
   macro avg       0.94      0.87      0.89     17213
weighted avg       0.94      0.95      0.95     17213





              precision    recall  f1-score   support

       BRAND       0.95      0.77      0.85      3476
     PERCENT       0.95      0.95      0.95       104
        TYPE       0.92      0.99      0.95     13599
      VOLUME       1.00      0.82      0.90        34

   micro avg       0.92      0.95      0.93     17213
   macro avg       0.96      0.88      0.91     17213
weighted avg       0.92      0.95      0.93     17213





              precision    recall  f1-score   support

       BRAND       0.92      0.89      0.91      3476
     PERCENT       0.90      0.99      0.94       104
        TYPE       0.95      0.98      0.97     13599
      VOLUME       1.00      0.71      0.83        34

   micro avg       0.94      0.96      0.95     17213
   macro avg       0.94      0.89      0.91     17213
weighted avg       0.94      0.96      0.95     17213





              precision    recall  f1-score   support

       BRAND       0.93      0.88      0.91      3476
     PERCENT       0.91      0.97      0.94       104
        TYPE       0.95      0.99      0.97     13599
      VOLUME       0.93      0.74      0.82        34

   micro avg       0.94      0.97      0.95     17213
   macro avg       0.93      0.89      0.91     17213
weighted avg       0.94      0.97      0.95     17213





In [None]:
tokenizer = AutoTokenizer.from_pretrained("ai-forever/ru-en-RoSBERTa",add_prefix_space=True)

In [None]:
model_23 = AutoModelForTokenClassification.from_pretrained('../../tmp/ckpt_23').cuda()
model_25 = AutoModelForTokenClassification.from_pretrained('../../tmp/ckpt_25').cuda()
model_27 = AutoModelForTokenClassification.from_pretrained('../../tmp/ckpt_27').cuda()
model_28 = AutoModelForTokenClassification.from_pretrained('../../tmp/ckpt_28').cuda()
model_29 = AutoModelForTokenClassification.from_pretrained('../../tmp/ckpt_29').cuda()

In [None]:
st_27 = model_27.state_dict()
st_28 = model_28.state_dict()
st_29 = model_29.state_dict()

In [None]:
new_state_dict = {}
for k in st_27:
    new_state_dict[k] = (st_27[k] + st_28[k] + st_29[k]) / 3

In [None]:
torch.save(new_state_dict,'rosberta_09481.pth')