In [70]:
import json
import pandas as pd
import glob
import os
import math

import random
import pandas as pd
from tqdm import tqdm

from collections import Counter

from pathlib import Path
from typing import List, Optional
import torch
from datasets import Dataset, DatasetDict
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from transformers import (
AutoConfig,
AutoModelForMaskedLM,
AutoTokenizer,
PreTrainedTokenizerFast,
DataCollatorForLanguageModeling,
DataCollatorForWholeWordMask,
Trainer,
TrainingArguments,
)

random.seed(42)

### собираем данные

#### пятерочка

In [3]:


def parse_product_jsons(folder_path):
    all_data = []
    
    # Find all JSON files in the folder
    json_files = glob.glob(os.path.join(folder_path, "*.json"))
    
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        if 'products_list' in file_path:
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                
                # Extract category information
                category_id = data.get('category_id', '')
                category_name = data.get('category_name', '')
                
                # Extract products list
                products_list = data.get('products_list', [])
                
                for product in products_list:
                    product_info = product.get('product_info', {})
                    
                    # Extract product details
                    plu = product_info.get('plu', '')
                    name = product_info.get('name', '')
                    property = product_info.get('property', '')
                    
                    # Append to our data list
                    all_data.append({
                        'plu': plu,
                        'name': name,
                        'property': property,
                        'category_name': category_name,
                        'category_id': category_id
                    })
                    
            except Exception as e:
                print(f"Error processing file {file_path}: {e}")
            continue
    
    # Create DataFrame
    df = pd.DataFrame(all_data)
    
    print(f"Successfully parsed {len(df)} products")
    print(f"DataFrame shape: {df.shape}")
    
    return df




In [21]:
folder_path = "/home/mikhail/Documents/Хакатоны/X5_ner_MiLky_way/parser/parsed_catalog_5ka"

# Parse the JSON files
df_5ka = parse_product_jsons(folder_path)



Found 206 JSON files
Successfully parsed 25253 products
DataFrame shape: (25253, 5)


In [22]:
df_5ka.shape

(25253, 5)

In [None]:
df_5ka.head()

Unnamed: 0,plu,name,property,category_name,category_id
0,4304864,Нектар Rich вишня 900мл,900 мл,Для особых случаев,251C42818
1,4360696,Чай черный Rich Персик холодный 1.5л,1.5 л,Для особых случаев,251C42818
2,4378400,Напиток Добрый Апельсин-мандарин для детского ...,1.45 л,Для особых случаев,251C42818
3,58053,Нектар Добрый мультифрукт 1л,1 л,Для особых случаев,251C42818
4,4274867,Напиток Добрый Cola без сахара газированный 1.5л,1.5 л,Для особых случаев,251C42818


In [23]:
df_5ka.property.value_counts()

property
1 шт       1992
200 г      1146
100 г       925
250 г       894
300 г       802
           ... 
4.75 кг       1
2.3 л         1
35 шт         1
3.78 л        1
7 мл          1
Name: count, Length: 713, dtype: int64

In [24]:
df_5ka['name_processed'] = df_5ka['name'].str.lower()

In [None]:
names_processed_5ka = list(set(df_5ka['name_processed']))

In [30]:
len(names_processed_5ka)

19660

#### перекресток

In [14]:
def parse_product_jsons_v2(folder_path):
    """
    Parse multiple JSON files with the new structure into a pandas DataFrame
    
    Args:
        folder_path (str): Path to folder containing JSON files
    
    Returns:
        pd.DataFrame: DataFrame with columns: plu, name, en_name, category_name, 
                     category_id, unit_name, weight, volume
    """
    all_data = []
    
    # Find all JSON files in the folder
    json_files = glob.glob(os.path.join(folder_path, "*.json"))
    
    print(f"Found {len(json_files)} JSON files")
    
    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            
            # The structure has a top-level category name as key
            for cat_key, category_data in data.items():
                # Extract category information from the nested structure
                cat_id = category_data.get('cat_id', '')
                
                # Extract items list
                items_list = category_data.get('items', [])
                
                for item in items_list:
                    # Extract basic item info
                    title = item.get('title', '')
                    
                    # Extract masterData
                    master_data = item.get('masterData', {})
                    plu = master_data.get('plu', '')
                    slug = master_data.get('slug', '')
                    unit_name = master_data.get('unitName', '')
                    weight = master_data.get('weight', '')
                    volume = master_data.get('volume', '')
                    
                    # Extract primaryCategory
                    primary_category = item.get('primaryCategory', {})
                    primary_cat_name = primary_category.get('title', '')
                    primary_cat_id = primary_category.get('id', '')
                    
                    # Append to our data list
                    all_data.append({
                        'plu': plu,
                        'name': title,
                        'en_name': slug,  # Using slug as English name
                        'category_name': primary_cat_name,
                        'category_id': primary_cat_id,
                        'unit_name': unit_name,
                        'weight': weight,
                        'volume': volume
                    })
                
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
            continue
    
    # Create DataFrame
    df = pd.DataFrame(all_data)
    
    print(f"Successfully parsed {len(df)} products")
    print(f"DataFrame shape: {df.shape}")
    
    return df

In [15]:
df_perek = parse_product_jsons_v2("/home/mikhail/Documents/Хакатоны/X5_ner_MiLky_way/parser/parsed_catalog_perek")

Found 30 JSON files
Successfully parsed 21025 products
DataFrame shape: (21025, 8)


In [20]:
df_perek

Unnamed: 0,plu,name,en_name,category_name,category_id,unit_name,weight,volume,en_name_raw
0,4166310,"Соус SanBonsai Ореховый, 250мл",sous-sanbonsai-orehovyj-250ml,"Майонез, соусы",394,шт,250.0,250.0,sous sanbonsai orehovyj 250ml
1,3691298,"Аджика Буздякский Кавказская, 350мл",adzika-buzdakskij-kavkazskaa-350ml,"Майонез, соусы",394,шт,350.0,,adzika buzdakskij kavkazskaa 350ml
2,3361376,"Соус Tabasco Хабанеро перечный, 60мл",sous-tabasco-habanero-perecnyj-60ml,"Майонез, соусы",394,шт,100.0,60.0,sous tabasco habanero perecnyj 60ml
3,4013003,Крем бальзамический Casa Rinaldi со вкусом имб...,krem-balzamiceskij-casa-rinaldi-so-vkusom-imbi...,"Майонез, соусы",394,шт,250.0,250.0,krem balzamiceskij casa rinaldi so vkusom imbi...
4,3494860,"Кетчуп Балтимор Томатный, 260г",ketcup-baltimor-tomatnyj-260g,"Майонез, соусы",394,шт,260.0,,ketcup baltimor tomatnyj 260g
...,...,...,...,...,...,...,...,...,...
21020,4274867,"Напиток газированный Добрый Cola без сахара, 1.5л",napitok-gazirovannyj-dobryj-cola-bez-sahara-1-5l,Газировка,423,шт,1500.0,1500.0,napitok gazirovannyj dobryj cola bez sahara 1 5l
21021,1863,"Напиток газированный Coca-Cola, 330мл",napitok-gazirovannyj-coca-cola-330ml,Газировка,423,шт,330.0,330.0,napitok gazirovannyj coca cola 330ml
21022,3503711,Вода Малаховская №1 питьевая 1 категории негаз...,voda-malahovskaa-no1-pitevaa-1-kategorii-negaz...,Вода,424,шт,5000.0,5000.0,voda malahovskaa no1 pitevaa 1 kategorii negaz...
21023,3173468,"Энергетический напиток Red Bull, 473мл",energeticeskij-napitok-red-bull-473ml,Энергетик,426,шт,473.0,473.0,energeticeskij napitok red bull 473ml


In [27]:
df_perek['en_name_raw'] = df_perek['en_name'].str.replace("-", " ")
df_perek["name_processed"] = df_perek["name"].str.lower()

In [31]:
names_processed_perek = list(set(df_perek['name_processed']))

In [33]:
len(names_processed_perek)

17942

In [34]:
names_processed_perek[:10]

['огурцы по-дижонски с мёдом и горчицей целые маринованные маркет перекрёсток, 680г',
 'напиток газированный coca-cola, 330мл',
 'пиво балтика крепкое легендарное №9 светлое пастеризованное 8%, 450мл',
 'йогурт epica питьевой киви-виноград 2.5%, 260г',
 'торт черёмушки наполеон слоёный, 310г',
 'грудка цыплёнка-бройлера копчёно-варёная маркет',
 'пиво prazacka злата светлое 4.9%, 500мл',
 'конфитюр ратибор грушевый, 350г',
 'удобрение fertika leaf power универсальное водорастворимое, 50г',
 'настойка веда малина и базилик горькая 38%, 500мл']

#### combined

In [36]:
names = names_processed_5ka + names_processed_perek

In [35]:
df = pd.concat([df_perek, df_5ka])
df.tail()

Unnamed: 0,plu,name,en_name,category_name,category_id,unit_name,weight,volume,en_name_raw,name_processed,property
25248,4119675,Подставка Лакарт Дизайн для телефона в ассорти...,,Хозтовары,251C12944,,,,,подставка лакарт дизайн для телефона в ассорти...,1 шт
25249,4120583,Корзинка EcoNova универсальная 170х120х75мм в ...,,Хозтовары,251C12944,,,,,корзинка econova универсальная 170х120х75мм в ...,1 шт
25250,4007117,Сумка Zhejiang Senmiao Trade в ассортименте 1шт.,,Хозтовары,251C12944,,,,,сумка zhejiang senmiao trade в ассортименте 1шт.,1 шт
25251,4118948,Удлинитель Эра UX-3-1.5m без заземления 1300Вт...,,Хозтовары,251C12944,,,,,удлинитель эра ux-3-1.5m без заземления 1300вт...,1 шт
25252,4150382,Чехол для наушников 5.7х5х2.2см в ассортименте...,,Хозтовары,251C12944,,,,,чехол для наушников 5.7х5х2.2см в ассортименте...,1 шт


### создаем аугментированный датасет

In [None]:

# Visual/keyboard confusable map (common Cyrillic <-> Latin substitutes)
confusable_map = {
    'а':'a','в':'b','е':'e','ё':'e','к':'k','м':'m','н':'h','о':'o','р':'p','с':'c','т':'t','у':'y','х':'x',
    'A':'A','B':'B','E':'E','K':'K','M':'M','H':'H','O':'O','P':'P','C':'C','T':'T','Y':'Y','X':'X'
}
# Reverse mapping (Latin->Cyrillic where visually confusable)
confusable_map_rev = {v:k for k,v in confusable_map.items()}

# Simple Russian keyboard adjacency (partial, common neighbors)
# (this is an approximation for common typos)
keyboard_neighbors = {
    'й':'ц','ц':'й','у':'и','и':'у','е':'р','р':'е','т':'ь','ь':'т','о':'п','п':'о',
    'а':'ф','ф':'а','ы':'в','в':'ы','с':'м','м':'с','д':'л','л':'д','ж':'э','э':'ж','я':'ш','ш':'я',
    'ч':'с','ю':'б','б':'ю','н':'т','г':'ш','к':'л'
}

# --- Augmentation functions ---
def substitute_chars(s, prob=0.04):
    """Randomly substitute letters: confusable or keyboard neighbor or random cyrillic letter."""
    out = []
    for c in s:
        if c.isalpha() and random.random() < prob:
            r = random.random()
            if r < 0.4 and c.lower() in keyboard_neighbors:  # neighbor
                rep = keyboard_neighbors.get(c.lower(), c)
                # Preserve case
                rep = rep.upper() if c.isupper() else rep
                out.append(rep)
            elif r < 0.7 and c in confusable_map:  # visual swap to Latin
                out.append(confusable_map[c])
            elif r < 0.85 and c.lower() in confusable_map_rev:  # Latin->Cyrillic
                rep = confusable_map_rev.get(c.lower(), c)
                rep = rep.upper() if c.isupper() else rep
                out.append(rep)
            else:
                # random Cyrillic letter substitute (common set)
                pool = 'абвгдеёжзийклмнопрстуфхцчшщыьюя'
                out.append(random.choice(pool))
        else:
            out.append(c)
    return ''.join(out)

def delete_chars(s, prob=0.03):
    return ''.join(c for c in s if not (c.isalpha() and random.random() < prob))

def insert_chars(s, prob=0.02):
    out = []
    pool = 'абвгдеёжзийклмнопрстуфхцчшщыьюя'
    for c in s:
        out.append(c)
        if c.isalpha() and random.random() < prob:
            out.append(random.choice(pool))
    return ''.join(out)

def transpose_chars(s, prob=0.02):
    s = list(s)
    i = 0
    while i < len(s)-1:
        if s[i].isalpha() and s[i+1].isalpha() and random.random() < prob:
            s[i], s[i+1] = s[i+1], s[i]
            i += 2
        else:
            i += 1
    return ''.join(s)

def remove_spaces_or_merge(s, prob=0.06):
    if random.random() < prob:
        # either remove all spaces or replace one space randomly
        if random.random() < 0.5:
            return s.replace(' ', '')
        else:
            parts = s.split()
            if len(parts) > 1:
                i = random.randrange(len(parts)-1)
                parts[i] = parts[i] + parts[i+1]
                del parts[i+1]
                return ' '.join(parts)
    return s

def truncation(s, prob=0.15):
    if random.random() < prob:
        parts = s.split()
        if len(parts) > 1:
            # drop last token sometimes or drop random token
            if random.random() < 0.8:
                return ' '.join(parts[:-1])
            else:
                idx = random.randrange(len(parts))
                del parts[idx]
                return ' '.join(parts)
        else:
            # drop last few chars
            cut = max(1, int(len(s)*0.3))
            return s[:-cut]
    return s

def add_template(s, prob=0.12):
    if random.random() < prob:
        templ = random.choice(templates)
        return templ.format(s)
    return s

def random_case_and_punct(s):
    s = s.lower()
    # optionally add punctuation or trailing dots/commas
    if random.random() < 0.06:
        s = s + random.choice(['.', '..', ',', ';;', '!'])
    return s

def swap_cyr_lat(s, prob=0.05):
    # randomly replace some chars with confusables (Cyr->Lat or Lat->Cyr)
    out = []
    for c in s:
        if c.isalpha() and random.random() < prob:
            if c in confusable_map and random.random() < 0.6:
                out.append(confusable_map[c])
                continue
            if c.lower() in confusable_map_rev and random.random() < 0.6:
                rep = confusable_map_rev.get(c.lower(), c)
                rep = rep.upper() if c.isupper() else rep
                out.append(rep)
                continue
        out.append(c)
    return ''.join(out)

# Compose simple augmentation pipeline
def augment_once(s):
    s = s.strip()
    s = random_case_and_punct(s)
    s = substitute_chars(s, prob=0.045)
    s = delete_chars(s, prob=0.03)
    s = insert_chars(s, prob=0.02)
    s = transpose_chars(s, prob=0.02)
    s = swap_cyr_lat(s, prob=0.04)
    s = remove_spaces_or_merge(s, prob=0.06)
    s = truncation(s, prob=0.12)
    # s = add_template(s, prob=0.10)
    # final cleanup: collapse multiple spaces
    s = ' '.join(s.split())
    return s

def generate_variants(item, n_variants=4, include_original=True):
    variants = []
    if include_original:
        variants.append(item)
    # produce variants with at least one correction attempt to canonical sometimes
    for _ in range(n_variants - (1 if include_original else 0)):
        v = augment_once(item)
        # ensure not identical to original unless random chance
        if v == item and random.random() < 0.6:
            v = augment_once(item)
        variants.append(v)
    return variants

# --- Example usage on provided query samples and parsed items ---
queries = [
"кечуп для гриля","кешью","кеыир","кеыирн","кзамороженнные","киа","кив","кивих","кидька","киевский",
"кизель","кизинаки","кильк","килька","килька в т","килька в то","килька в томате","килька в томатной",
"килька с овощами","кильки","ким носк","ким носков","кимч","кимчи","кин9а","кинги","киндер",
"киндер макми кинг","киндер молочный ломтик","киндер пингви","киндер сюрприз","киндер яйцо большое",
"киндеры","киндза","кинза","кинзу","кино","киноа","киноя","кинто","кинто еще","кинуа","киприно",
"киреешки","кирешки","кирие","кириеш","кисе","киселб","киселт","кисель dr.oetke","кисель dr.oetker",
"кисель айдиг","кисель айдиго","кисель детский","кисель фрутоня","кисель фрутонян","кисеь","кискель",
"кисл","кислая","кислая капуста","кислель","кислец","кислин","кислла","кисло слад","кисло сладк",
"кисло сладки","кисло сладкий","кисло-","кисло-сладк","кисло-сладки","кисло-сладкий","кисловодская",
"кислом","кисломикс","кисломоло","кисломолочка","кисломолочн"
]

parsed_items = [
'огурцы по-дижонски с мёдом и горчицей целые маринованные маркет перекрёсток, 680г',
'напиток газированный coca-cola, 330мл',
'пиво балтика крепкое легендарное №9 светлое пастеризованное 8%, 450мл',
'йогурт epica питьевой киви-виноград 2.5%, 260г',
'торт черёмушки наполеон слоёный, 310г',
'грудка цыплёнка-бройлера копчёно-варёная маркет',
'пиво prazacka злата светлое 4.9%, 500мл',
'конфитюр ратибор грушевый, 350г',
'удобрение fertika leaf power универсальное водорастворимое, 50г',
'настойка веда малина и базилик горькая 38%, 500мл'
]

# Generate a small table: for each of first 30 queries and first 5 parsed items generate variants
rows = []
for q in tqdm(names):
    variants = generate_variants(q, n_variants=5, include_original=True)
    for v in variants:
        rows.append({'source':'query', 'original':q, 'variant':v})

for p in parsed_items:
    variants = generate_variants(p, n_variants=5, include_original=True)
    for v in variants:
        rows.append({'source':'parsed_item', 'original':p, 'variant':v})

df = pd.DataFrame(rows)

# Also print a short textual sample for quick view:
print("Sample augmented variants (first 30 rows):")
df.head(30)


100%|██████████| 37602/37602 [00:04<00:00, 9361.48it/s]


Sample augmented variants (first 30 rows):


Unnamed: 0,source,original,variant
0,query,узвар великая русь из боярышника 1л,узвар великая русь из боярышника 1л
1,query,узвар великая русь из боярышника 1л,извар веидлая русх из боярышнгка 1л
2,query,узвар великая русь из боярышника 1л,узвар великая русь из боярышника
3,query,узвар великая русь из боярышника 1л,увае великая русь изю бярышника 1л
4,query,узвар великая русь из боярышника 1л,узвар великая русь з боярышhика 1л
5,query,свинина по-тирольски 4 сезона 600г,свинина по-тирольски 4 сезона 600г
6,query,свинина по-тирольски 4 сезона 600г,всинина по-трольски 4 сезона 600г
7,query,свинина по-тирольски 4 сезона 600г,свиина п-тирольсыки 4 сезона 600г
8,query,свинина по-тирольски 4 сезона 600г,свининх по-тирольски 4 сзепна 600г
9,query,свинина по-тирольски 4 сезона 600г,свинина по-тирольски 4 сезона 600г


In [103]:
df.sample().original.iloc[0]

'печенье американо сдобное маркет, 400г'

In [44]:
df.shape

(188060, 3)

In [89]:
len(tokenizer.tokenize(df.variant.max()))

23

### rubert tiny 2

In [47]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

model = AutoModelForMaskedLM.from_pretrained("cointegrated/rubert-tiny2")
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")

In [None]:
maska = "[MASK]"
text = f"чипсы картофельные {maska} maxx куриные крылышки барбекю 110г"
inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

'>>> чипсы картофельные и maxx куриные крылышки барбекю 110г'
'>>> чипсы картофельные , maxx куриные крылышки барбекю 110г'
'>>> чипсы картофельные изделия maxx куриные крылышки барбекю 110г'
'>>> чипсы картофельные или maxx куриные крылышки барбекю 110г'
'>>> чипсы картофельные блюда maxx куриные крылышки барбекю 110г'


In [63]:
tokenizer

BertTokenizerFast(name_or_path='cointegrated/rubert-tiny2', vocab_size=83828, model_max_length=2048, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [65]:
texts = [...]  # your combined corpus, list of strings

vocab = tokenizer.get_vocab()
unk_id = tokenizer.unk_token_id

total_tokens = 0
unk_tokens = 0
subword_lengths = []

for t in names:
    enc = tokenizer(t, add_special_tokens=False)
    ids = enc["input_ids"]
    total_tokens += len(ids)
    unk_tokens += sum(1 for i in ids if i == unk_id)
    # measure average tokens per word
    words = t.split()
    if words:
        subword_lengths.extend([len(tokenizer(w, add_special_tokens=False)["input_ids"]) for w in words])

print("total tokens:", total_tokens)
print("unk tokens:", unk_tokens, "unk rate:", unk_tokens/total_tokens if total_tokens else 0)
print("avg subwords per word:", sum(subword_lengths)/len(subword_lengths))
print("median subwords per word:", sorted(subword_lengths)[len(subword_lengths)//2])

total tokens: 545983
unk tokens: 56 unk rate: 0.00010256729605134226
avg subwords per word: 1.9642078527589706
median subwords per word: 2


In [None]:
def simple_text_stats(lines: List[str]) -> dict:
stats = {}
lens = [len(l) for l in lines]
toks = [l.split() for l in lines]
token_counts = [len(t) for t in toks]


def has_latin(s):
return bool(re.search(r'[A-Za-z]', s))


def has_cyr(s):
return bool(re.search(r'[А-Яа-яЁё]', s))


latin_frac = sum(1 for l in lines if has_latin(l)) / max(1, len(lines))
cyr_frac = sum(1 for l in lines if has_cyr(l)) / max(1, len(lines))
digits_frac = sum(1 for l in lines if re.search(r'\d', l)) / max(1, len(lines))
punct_frac = sum(1 for l in lines if re.search(r'[.,;:!\-\/\\()]', l)) / max(1, len(lines))


stats['n_lines'] = len(lines)
stats['avg_len_chars'] = float(np.mean(lens))
stats['median_len_chars'] = int(np.median(lens))
stats['avg_tokens'] = float(np.mean(token_counts))
stats['median_tokens'] = int(np.median(token_counts))
stats['latin_frac'] = float(latin_frac)
stats['cyrillic_frac'] = float(cyr_frac)
stats['digits_frac'] = float(digits_frac)
stats['punct_frac'] = float(punct_frac)
return stats

In [67]:
def load_lines_from_file(path: Path) -> List[str]:
    lines = []
    with path.open('r', encoding='utf-8') as f:
        for line in f:
            s = line.strip()
            if s:
                lines.append(s)
    return lines

In [68]:
def mix_and_subsample(originals: List[str], augmented: List[str], mix_ratio: float = 0.75, seed: int = 42) -> List[str]:
    # mix_ratio = fraction of examples from originals
    random.seed(seed)
    n_total = len(originals) + len(augmented)
    n_from_orig = int(n_total * mix_ratio)
    n_from_aug = n_total - n_from_orig

    chosen_orig = random.choices(originals, k=max(1, n_from_orig)) if originals else []
    chosen_aug = random.choices(augmented, k=max(1, n_from_aug)) if augmented else []
    combined = chosen_orig + chosen_aug
    random.shuffle(combined)
    return combined

In [69]:
def tokenize_and_group_texts(lines: List[str], tokenizer: PreTrainedTokenizerFast, block_size: int = 128, use_wwm: bool = False):
    """Tokenize lines and group into blocks of `block_size`. If use_wwm=True we use is_split_into_words approach.

    Returns a HuggingFace Dataset with columns 'input_ids' and 'labels'
    """
    # Build dataset
    ds = Dataset.from_dict({'text': lines})

    def tokenize_func(examples):
        if use_wwm:
            # split into whitespace tokens for whole-word masking
            words = [t.split() for t in examples['text']]
            return tokenizer(words, is_split_into_words=True, add_special_tokens=True)
        else:
            return tokenizer(examples['text'], add_special_tokens=True)

    tokenized = ds.map(tokenize_func, batched=True, remove_columns=['text'])

    # concatenate and group
    def group_texts(examples):
        concatenated = sum(examples['input_ids'], [])
        total_length = len(concatenated)
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        else:
            total_length = 0
        result = {}
        if total_length == 0:
            result['input_ids'] = []
            result['labels'] = []
            return result
        result['input_ids'] = [concatenated[i:i+block_size] for i in range(0, total_length, block_size)]
        result['labels'] = [list(ids) for ids in result['input_ids']]
        return result

    lm_dataset = tokenized.map(group_texts, batched=True, remove_columns=tokenized.column_names)
    # drop empty rows
    lm_dataset = lm_dataset.filter(lambda ex: len(ex['input_ids']) > 0)

    # train/val split
    split = lm_dataset.train_test_split(test_size=0.01)
    return DatasetDict({'train': split['train'], 'validation': split['test']})

In [71]:
def build_data_collator(tokenizer: PreTrainedTokenizerFast, use_wwm: bool = False, mlm_probability: float = 0.15):
    if use_wwm:
        # DataCollatorForWholeWordMask expects tokenizer with word_ids support
        return DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_probability)
    else:
        return DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_probability)

In [80]:
def run_pipeline(
    output_dir: str = './here/',
    model_name_or_path: str = 'cointegrated/rubert-tiny2',
    train_tokenizer: bool = False,
    tokenizer_vocab_size: int = 30000,
    block_size: int = 128,
    use_wwm: bool = False,
    mix_ratio: float = 0.75,
    num_epochs: int = 3,
    per_device_train_batch_size: int = 32,
    gradient_accumulation_steps: int = 1,
    learning_rate: float = 5e-5,
    weight_decay: float = 0.01,
    seed: int = 42,
    save_steps: int = 2000,        # DataCollatorForWholeWordMask expects tokenizer with word_ids support

    logging_steps: int = 200,
    fp16: bool = False,
):
    os.makedirs(output_dir, exist_ok=True)
    random.seed(seed)

    originals = df['original'].tolist()
    augmented = df['variant'].tolist()


    # # Basic exploration
    # print('Exploring data distributions...')
    # s_orig = simple_text_stats(originals) if originals else {}
    # s_aug = simple_text_stats(augmented) if augmented else {}
    # print('Originals stats:', json.dumps(s_orig, ensure_ascii=False, indent=2))
    # print('Augmented stats:', json.dumps(s_aug, ensure_ascii=False, indent=2))

    # Combine corpus according to mix_ratio
    combined = mix_and_subsample(originals, augmented, mix_ratio=mix_ratio, seed=seed)
    print(f'Combined corpus size: {len(combined)}')

    # Tokenizer
    tokenizer = None
    tokenizer_path = os.path.join(output_dir, 'tokenizer.json')
    print(f'Loading tokenizer from pretrained model: {model_name_or_path}')
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, do_lower_case=True)
    # ensure mask token exists
    # if tokenizer.mask_token is None:
    #     tokenizer.add_special_tokens({'mask_token': '[MASK]'})

    # Tokenize and group
    datasets = tokenize_and_group_texts(combined, tokenizer, block_size=block_size, use_wwm=use_wwm)
    print('Train examples (blocks):', len(datasets['train']))
    print('Validation examples (blocks):', len(datasets['validation']))

    # Model
    print('Loading model...')
    model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
    # # if tokenizer added tokens, resize
    # try:
    #     model.resize_token_embeddings(len(tokenizer))
    # except Exception:
    #     pass

    # Data collator
    data_collator = build_data_collator(tokenizer, use_wwm=use_wwm, mlm_probability=0.15)

    # Training args
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        eval_strategy='steps',
        eval_steps=save_steps,
        save_steps=save_steps,
        save_total_limit=3,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        logging_steps=logging_steps,
        seed=seed,
        fp16=fp16,
        dataloader_num_workers=4,
        report_to='none',
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=datasets['train'],
        eval_dataset=datasets['validation'],
    )

    # Train
    print('Starting training...')
    trainer.train()

    # Eval
    print('Running final evaluation...')
    metrics = trainer.evaluate()
    loss = metrics.get('eval_loss')
    if loss is not None:
        try:
            ppl = math.exp(loss)
        except OverflowError:
            ppl = float('inf')
        metrics['perplexity'] = ppl
    print('Eval metrics:', metrics)

    # Save
    print('Saving tokenizer and model...')
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print('Done. Model saved to', output_dir)


In [81]:
run_pipeline()

Combined corpus size: 376120
Loading tokenizer from pretrained model: cointegrated/rubert-tiny2


Map: 100%|██████████| 376120/376120 [00:12<00:00, 29609.84 examples/s]
Map: 100%|██████████| 376120/376120 [00:15<00:00, 24682.72 examples/s]
Filter: 100%|██████████| 53465/53465 [00:05<00:00, 10495.15 examples/s]


Train examples (blocks): 52930
Validation examples (blocks): 535
Loading model...
Starting training...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

Step,Training Loss,Validation Loss


KeyboardInterrupt: 

In [134]:
print(df.original.sample().iloc[0])

чай императорский чай травяной в пакетиках, 20х1.2г


In [115]:
mlm_model_path = "/home/mikhail/Documents/Хакатоны/X5_ner_MiLky_way/models/rubert_tiny2_mlm"
model = AutoModelForMaskedLM.from_pretrained(mlm_model_path)
tokenizer = AutoTokenizer.from_pretrained(mlm_model_path)

In [139]:
maska = "[MASK]"
text = f'{maska} императорский травяной в пакетиках, 20х1.2г'

inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits
# Find the location of [MASK] and extract its logits
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
# Pick the [MASK] candidates with the highest logits
top_5_tokens = torch.topk(mask_token_logits, 10, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"'>>> {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}'")

'>>> хлеб императорский травяной в пакетиках, 20х1.2г'
'>>> напиток императорский травяной в пакетиках, 20х1.2г'
'>>> соус императорский травяной в пакетиках, 20х1.2г'
'>>> торт императорский травяной в пакетиках, 20х1.2г'
'>>> салат императорский травяной в пакетиках, 20х1.2г'
'>>> кофе императорский травяной в пакетиках, 20х1.2г'
'>>> перец императорский травяной в пакетиках, 20х1.2г'
'>>> завтрак императорский травяной в пакетиках, 20х1.2г'
'>>> сыр императорский травяной в пакетиках, 20х1.2г'
'>>> лук императорский травяной в пакетиках, 20х1.2г'


In [114]:
from huggingface_hub import login
login(token="your_huggingface_token")

HTTPError: Invalid user token.