In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
# NOTE(review): the training/test TSVs below are read from /content/, not from the
# mounted Drive, so this mount is not required by the visible pipeline — confirm.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import random
import warnings
from collections import Counter
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides numpy RuntimeWarnings (e.g. overflow in np.exp) —
# consider narrowing the filter to the specific categories that are noisy.
warnings.filterwarnings('ignore')



# ==================== КОНФИГУРАЦИЯ ====================
class Config:
    """Training hyper-parameters, read as class attributes through ``config``."""

    # Backbone checkpoint and compute device (chosen once, when the class is created).
    model_name = "sberbank-ai/ruBert-base"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Optimisation schedule.
    batch_size = 16
    accumulation_steps = 2      # micro-batches per optimizer step
    learning_rate = 1.5e-5      # raised from an earlier, smaller value
    epochs = 5                  # raised from an earlier, smaller value
    num_warmup_ratio = 0.11     # fraction of optimizer steps used for LR warm-up
    patience = 2                # epochs without val improvement before early stopping

    # Tokenisation.
    max_length = 300            # raised from an earlier, shorter limit


config = Config()

# ==================== КАСТОМНАЯ МОДЕЛЬ ====================
class MultiLabelTransformerWithHeads(nn.Module):
    """Transformer encoder with two linear classification heads whose logits are averaged.

    The hidden state at position 0 ([CLS]) is layer-normalised, passed through
    dropout and fed to both heads; their logits are averaged. Attribute names are
    kept stable because they define the ``state_dict`` keys of the saved checkpoint.

    NOTE(review): both heads receive the same features and are only used through
    their average, so together they are equivalent to a single linear layer; the
    'category'/'genre' names do not correspond to different inputs or targets.
    """

    def __init__(self, model_name, num_labels, hidden_dropout_prob=0.3):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.backbone = AutoModel.from_pretrained(model_name)

        # Two heads over the same pooled features (see class NOTE).
        self.category_head = nn.Linear(self.config.hidden_size, num_labels)
        self.genre_head = nn.Linear(self.config.hidden_size, num_labels)

        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.layer_norm = nn.LayerNorm(self.config.hidden_size)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode a batch and return averaged head logits (and a BCE loss if labels given).

        Args:
            input_ids: token ids passed straight to the backbone.
            attention_mask: attention mask passed straight to the backbone.
            labels: optional multi-hot float targets with the same shape as the logits.

        Returns:
            A ``SimpleNamespace`` with ``logits`` and ``loss`` (mean
            BCE-with-logits against ``labels``, or ``None`` without labels).
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # [CLS] token

        features = self.dropout(self.layer_norm(pooled_output))
        logits = (self.category_head(features) + self.genre_head(features)) / 2

        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(logits, labels)

        # Return a plain instance. The previous ``type('Output', ...)`` built a new class
        # per call; classes live in reference cycles, so the logits/loss tensors (and the
        # autograd graph behind them) were only released by the cyclic garbage collector.
        return SimpleNamespace(loss=loss, logits=logits)

# ==================== УЛУЧШЕННАЯ ПОДГОТОВКА ДАННЫХ ====================
def load_and_prepare_data(train_path='/content/train.tsv', test_path='/content/test.tsv'):
    """Load the train/test TSV files, build the model input text and binarise the labels.

    Args:
        train_path: tab-separated training file; must contain a ``labels_str``
            column holding '|'-separated label names.
        test_path: tab-separated test file.

    Returns:
        ``(train_data, test_data, mlb, train_labels)``. The two DataFrames have an
        added ``text`` column (train also gets ``text_length``); ``mlb`` is the
        fitted ``MultiLabelBinarizer``; ``train_labels`` is its output for train.
    """
    train_data = pd.read_csv(train_path, sep='\t')
    test_data = pd.read_csv(test_path, sep='\t')

    print("–î–æ—Å—Ç—É–ø–Ω—ã–µ –∫–æ–ª–æ–Ω–∫–∏ –≤ train:", train_data.columns.tolist())

    def create_enhanced_text(row):
        """Join the app name, short description and truncated full description of a row."""
        parts = []

        # 1. App name (treated as the most important field).
        if 'app_name' in row:
            app_name = str(row['app_name']).strip()
            if app_name and app_name != 'nan':
                parts.append(f"–ù–∞–∑–≤–∞–Ω–∏–µ: {app_name}")

        # 2. Short description: first non-empty column among the known spellings.
        for col in ['shortDescription', 'short_description']:
            if col in row and pd.notna(row[col]):
                desc = str(row[col]).strip()
                if desc and desc != 'nan':
                    parts.append(f"–û–ø–∏—Å–∞–Ω–∏–µ: {desc}")
                    break

        # 3. Full description: only its first 200 characters are kept.
        for col in ['full_description', 'description', 'long_description']:
            if col in row and pd.notna(row[col]):
                full_desc = str(row[col]).strip()
                if full_desc and full_desc != 'nan':
                    parts.append(f"–ü–æ–¥—Ä–æ–±–Ω–æ: {full_desc[:200]}")
                    break

        # When every field is missing/empty the result is "". The old fallback returned
        # str(row['app_name']), which at this point could only be 'nan' or whitespace.
        return " ".join(parts)

    train_data['text'] = train_data.apply(create_enhanced_text, axis=1)
    test_data['text'] = test_data.apply(create_enhanced_text, axis=1)

    # Text length statistics (in characters).
    train_data['text_length'] = train_data['text'].str.len()
    print(f"–°—Ä–µ–¥–Ω—è—è –¥–ª–∏–Ω–∞ —Ç–µ–∫—Å—Ç–∞: {train_data['text_length'].mean():.0f} —Å–∏–º–≤–æ–ª–æ–≤")
    print(f"–ú–∞–∫—Å–∏–º–∞–ª—å–Ω–∞—è –¥–ª–∏–Ω–∞: {train_data['text_length'].max()} —Å–∏–º–≤–æ–ª–æ–≤")

    # Multi-hot label matrix; column order is mlb.classes_.
    mlb = MultiLabelBinarizer()
    train_labels = mlb.fit_transform(train_data['labels_str'].str.split('|'))

    # Label distribution summary.
    label_counts = train_labels.sum(axis=0)
    print(f"\nüìä –ê–Ω–∞–ª–∏–∑ –º–µ—Ç–æ–∫:")
    print(f"–í—Å–µ–≥–æ –∫–ª–∞—Å—Å–æ–≤: {len(mlb.classes_)}")
    print(f"–û–±—â–µ–µ –∫–æ–ª–∏—á–µ—Å—Ç–≤–æ –º–µ—Ç–æ–∫: {train_labels.sum()}")
    print(f"–ú–µ—Ç–æ–∫ –Ω–∞ –ø—Ä–∏–ª–æ–∂–µ–Ω–∏–µ: {train_labels.sum(axis=1).mean():.2f}")

    # Ten most frequent labels.
    top_labels = pd.DataFrame({
        'category': mlb.classes_,
        'count': label_counts
    }).sort_values('count', ascending=False).head(10)
    print("\n–¢–æ–ø-10 —Å–∞–º—ã—Ö —á–∞—Å—Ç—ã—Ö –∫–∞—Ç–µ–≥–æ—Ä–∏–π:")
    print(top_labels)

    return train_data, test_data, mlb, train_labels

# ==================== УЛУЧШЕННЫЙ ДАТАСЕТ С АУГМЕНТАЦИЕЙ ====================
class AdvancedTextAugmenter:
    """Random synonym-replacement augmenter for app texts.

    ``augment_text`` acts on about 60% of calls. When it does, each word whose
    lower-cased, punctuation-stripped form is a key of ``self.synonyms`` is
    replaced by a random synonym from that key's list with probability 0.2.
    Texts of three words or fewer are never changed.
    """

    # Characters stripped from both ends of a word before the synonym lookup.
    _PUNCTUATION = '.,!?;:'

    def __init__(self):
        # Lower-cased word -> candidate replacements.
        # NOTE(review): several lists contain duplicate entries, which makes those
        # synonyms proportionally more likely to be drawn by random.choice.
        self.synonyms = {
            '–∏–≥—Ä–∞': ['–≥–µ–π–º', '–∏–≥—Ä–æ–≤–æ–π', '–∏–≥—Ä–æ–≤–æ–µ', '—Ä–∞–∑–≤–ª–µ—á–µ–Ω–∏–µ', '–≤–∏–¥–µ–æ–∏–≥—Ä–∞', '–∞—Ä–∫–∞–¥–∞', '–≥–µ–π–º–∏–Ω–≥', '–∏–≥—Ä—É—à–∫–∞', '–±–∞—Ç–∞–ª–∏—è', '–ø–∞—Ä—Ç–∏—è', '—Å–æ—Å—Ç—è–∑–∞–Ω–∏–µ', '–∑–∞–±–∞–≤–∞', '–∞—Ç—Ç—Ä–∞–∫—Ü–∏–æ–Ω', '—Å–∏–º—É–ª—è—Ç–æ—Ä', '–∫–≤–µ—Å—Ç', '–≥–æ–ª–æ–≤–æ–ª–æ–º–∫–∞', '—Å—Ç—Ä–∞—Ç–µ–≥–∏—è', '—ç–∫—à–µ–Ω', '–ø—Ä–∏–∫–ª—é—á–µ–Ω–∏–µ'],
            '–ø—Ä–∏–ª–æ–∂–µ–Ω–∏–µ': ['–ø—Ä–æ–≥—Ä–∞–º–º–∞', '—Å–æ—Ñ—Ç', '—É—Ç–∏–ª–∏—Ç–∞', '–∞–ø–ø', '–ø—Ä–∏–ª–æ–∂–µ–Ω—å–µ', '–ø—Ä–æ–≥—Ä–∞–º–º–∫–∞', '—Å–æ—Ñ—Ç–∏–Ω–∞', '–ø—Ä–æ–≥—Ä–∞–º–º–Ω–æ–µ –æ–±–µ—Å–ø–µ—á–µ–Ω–∏–µ', '–ü–û', '–∞–ø–ø–ª–∏–∫–∞—Ü–∏—è', '–∏–Ω—Å—Ç—Ä—É–º–µ–Ω—Ç', '–ø—Ä–æ–≥—Ä–∞–º–º–Ω—ã–π –ø—Ä–æ–¥—É–∫—Ç', '—Å–æ—Ñ—Ç–≤–µ—Ä', '–∫–ª–∏–µ–Ω—Ç', '–ø—Ä–æ–≥–∞'],
            '–±–µ—Å–ø–ª–∞—Ç–Ω—ã–π': ['—Ñ—Ä–∏', '–¥–∞—Ä–æ–º', '–±–µ–∑ –æ–ø–ª–∞—Ç—ã', '–±–µ—Å–ø–ª–∞—Ç–Ω–æ', 'free', 'gratis', '–∑–∞ —Ç–∞–∫', '–±–µ—Å–ø–ª–∞—Ç–Ω–æ–µ', '–±–µ–∑–≤–æ–∑–º–µ–∑–¥–Ω—ã–π', '–¥–∞—Ä–æ–≤–æ–π', '–∫–æ–º–ø–ª–∏–º–µ–Ω—Ç–∞—Ä–Ω—ã–π', '—Ö–∞–ª—è–≤–Ω—ã–π', '–±–æ–Ω—É—Å–Ω—ã–π', '–ø–æ–¥–∞—Ä–æ—á–Ω—ã–π', '–Ω–µ–∫–æ–º–º–µ—Ä—á–µ—Å–∫–∏–π', '–æ—Ç–∫—Ä—ã—Ç—ã–π', '—Å–≤–æ–±–æ–¥–Ω—ã–π'],
            '–æ–Ω–ª–∞–π–Ω': ['–∏–Ω—Ç–µ—Ä–Ω–µ—Ç', '—Å–µ—Ç–µ–≤–æ–π', '–≤–µ–±', 'online', '–≤ —Å–µ—Ç–∏', '–∏–Ω—Ç–µ—Ä–Ω–µ—Ç–Ω—ã–π', '—Å–µ—Ç–µ–≤–æ–µ', '–¥–∏—Å—Ç–∞–Ω—Ü–∏–æ–Ω–Ω—ã–π', '—É–¥–∞–ª–µ–Ω–Ω—ã–π', '–≤–∏—Ä—Ç—É–∞–ª—å–Ω—ã–π', '–∫–∏–±–µ—Ä–ø—Ä–æ—Å—Ç—Ä–∞–Ω—Å—Ç–≤–µ–Ω–Ω—ã–π', '–∏–Ω—Ç–µ—Ä–Ω–µ—Ç-', '–≤–µ–±-', '—Å–µ—Ç–µ–≤–æÃÅ–π', '–ø–æ–¥–∫–ª—é—á–µ–Ω–Ω—ã–π'],
            '–æ–±—É—á–µ–Ω–∏–µ': ['–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏–µ', '—É—á–µ–±–∞', '–∫—É—Ä—Å', '—Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∞', '–æ–±—É—á–∞—é—â–∏–π', '—É—á–µ–±–Ω—ã–π', '–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏–µ', '–ø–µ–¥–∞–≥–æ–≥–∏–∫–∞', '–Ω–∞—É–∫–∞', '–∏–Ω—Å—Ç—Ä—É–∫—Ç–∞–∂', '–ø–æ–¥–≥–æ—Ç–æ–≤–∫–∞', '—Ä–∞–∑–≤–∏—Ç–∏–µ', '–ø—Ä–æ—Å–≤–µ—â–µ–Ω–∏–µ', '–æ–±—É—á–∞–ª–∫–∞', '—Ç—Ä–µ–Ω–∏–Ω–≥', '—Å–µ–º–∏–Ω–∞—Ä', '–ª–µ–∫—Ü–∏–∏'],
            '–º—É–∑—ã–∫–∞': ['–∞—É–¥–∏–æ', '–º–µ–ª–æ–¥–∏—è', '—Ç—Ä–µ–∫', '–ø–µ—Å–Ω—è', '–º—É–∑–ª–æ', '–∑–≤—É–∫', '–∫–æ–º–ø–æ–∑–∏—Ü–∏—è', '–º—É–∑—ã–∫–∞–ª—å–Ω—ã–π', '–º—É–∑—ã—á–∫–∞', '–Ω–∞–ø–µ–≤', '–º–æ—Ç–∏–≤', '–∞—Ä–∞–Ω–∂–∏—Ä–æ–≤–∫–∞', '—Å–∞—É–Ω–¥—Ç—Ä–µ–∫', '–º–∏–Ω—É—Å–æ–≤–∫–∞', '–±–∏—Ç', '—Ä–∏—Ç–º', '–≥–∞—Ä–º–æ–Ω–∏—è'],
            '–≤–∏–¥–µ–æ': ['–∫–ª–∏–ø', '—Ä–æ–ª–∏–∫', '—Ñ–∏–ª—å–º', '–∫–∏–Ω–æ', '–≤–∏–¥–µ–æ—Ä–æ–ª–∏–∫', '–≤–∏–¥–µ–æ–∫–ª–∏–ø', '–∫–∏–Ω–æ—Ñ–∏–ª—å–º', '–≤–∏–¥–µ–æ–∑–∞–ø–∏—Å—å', '–≤–∏–¥–µ–æ–∫–æ–Ω—Ç–µ–Ω—Ç', '–≤–∏–¥–µ–æ—Ñ–∞–π–ª', '–º—É–≤–∏', '—Ñ–∏–ª—å–º–µ—Ü', '—Ä–æ–ª–∏—á–µ–∫', '–∑–∞–ø–∏—Å—å', '—Ç—Ä–∞–Ω—Å–ª—è—Ü–∏—è', '—Å—Ç—Ä–∏–º'],
            '—Ñ–æ—Ç–æ': ['–∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ', '–∫–∞—Ä—Ç–∏–Ω–∫–∞', '—Å–Ω–∏–º–æ–∫', '—Ñ–æ—Ç–æ–≥—Ä–∞—Ñ–∏—è', '—Ñ–æ—Ç–æ—Å–Ω–∏–º–æ–∫', '—Ñ–æ—Ç–∫–∞', '–∫–∞—Ä—Ç–æ—á–∫–∞', '—Ñ–æ—Ç–æ–∫–∞—Ä—Ç–æ—á–∫–∞', '—Å–Ω–∏–º–æ–∫', '–ø–∏–∫—á–∞', '—Ñ–æ—Ç–æ–¥–æ–∫—É–º–µ–Ω—Ç', '–∏–ª–ª—é—Å—Ç—Ä–∞—Ü–∏—è', '–ø–æ—Ä—Ç—Ä–µ—Ç', '–ø–µ–π–∑–∞–∂', '—Ñ–æ—Ç–æ–∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ'],
            '—Å–æ—Ü–∏–∞–ª—å–Ω—ã–π': ['—Å–æ—Ü—Å–µ—Ç—å', '–æ–±—â–µ–Ω–∏–µ', '–∫–æ–º–º—É–Ω–∏–∫–∞—Ü–∏—è', '—Å–æ—Ü–∏—É–º', '—Å–æ—Ü–∏–∞–ª—å–Ω–∞—è —Å–µ—Ç—å', '–æ–±—â–∏–Ω–∞', '—Å–æ—Ü —Å–µ—Ç—å', '—Å–æ–æ–±—â–µ—Å—Ç–≤–æ', '—Å–µ—Ç—å', '–ø–ª–∞—Ç—Ñ–æ—Ä–º–∞', '–º–µ–¥–∏–∞', '—Ñ–æ—Ä—É–º', '—á–∞—Ç', '–±–ª–æ–≥', '—Å–æ—Ü–ø–ª–∞—Ç—Ñ–æ—Ä–º–∞', '–∏–Ω—Ç–µ—Ä–Ω–µ—Ç-—Å–æ–æ–±—â–µ—Å—Ç–≤–æ'],
            '–Ω–æ–≤–æ—Å—Ç–∏': ['—Å–æ–±—ã—Ç–∏—è', '–∏–Ω—Ñ–æ—Ä–º–∞—Ü–∏—è', '–æ–±–Ω–æ–≤–ª–µ–Ω–∏—è', '–Ω–æ–≤–æ—Å—Ç–Ω–æ–π', '–Ω–æ–≤–æ—Å—Ç–Ω–∞—è –ª–µ–Ω—Ç–∞', '—Å–≤–æ–¥–∫–∞', '–Ω–æ–≤–æ—Å—Ç–Ω–∏–∫', '—Ö—Ä–æ–Ω–∏–∫–∞', '—Ä–µ–ø–æ—Ä—Ç–∞–∂', '–∞–Ω–æ–Ω—Å', '–ø—Ä–µ—Å—Å–∞', '–º–µ–¥–∏–∞', '–ª–µ–Ω—Ç–∞', '–¥–∞–π–¥–∂–µ—Å—Ç', '–æ–±–∑–æ—Ä', '—Å–≤–µ–∂–∏–µ –Ω–æ–≤–æ—Å—Ç–∏'],
            '–º–∞–≥–∞–∑–∏–Ω': ['—à–æ–ø–∏–Ω–≥', '–ø–æ–∫—É–ø–∫–∏', '—Ç–æ—Ä–≥–æ–≤–ª—è', '–º–∞—Ä–∫–µ—Ç', '–æ–Ω–ª–∞–π–Ω –º–∞–≥–∞–∑–∏–Ω', '–º–∞—Ä–∫–µ—Ç–ø–ª–µ–π—Å', '—Ç–æ—Ä–≥–æ–≤–∞—è –ø–ª–æ—â–∞–¥–∫–∞', '–±—É—Ç–∏–∫', '–ª–∞–≤–∫–∞', '—Ç–æ—Ä–≥–æ–≤—ã–π —Ü–µ–Ω—Ç—Ä', '–∏–Ω—Ç–µ—Ä–Ω–µ—Ç-–º–∞–≥–∞–∑–∏–Ω', '—ç–ª–µ–∫—Ç—Ä–æ–Ω–Ω–∞—è –∫–æ–º–º–µ—Ä—Ü–∏—è', 'e-commerce', '—Ä–∏—Ç–µ–π–ª', '–ø—Ä–æ–¥–∞–∂–∏'],
            '–∫–Ω–∏–≥–∞': ['–ª–∏—Ç–µ—Ä–∞—Ç—É—Ä–∞', '—á—Ç–µ–Ω–∏–µ', '–∏–∑–¥–∞–Ω–∏–µ', '–∫–Ω–∏–∂–Ω—ã–π', '—ç–ª–µ–∫—Ç—Ä–æ–Ω–Ω–∞—è –∫–Ω–∏–≥–∞', '–±—É–∫', '–ª–∏—Ç–µ—Ä–∞—Ç—É—Ä–Ω–æ–µ', '—Ç–æ–º', '–∏–∑–¥–∞–Ω–∏–µ', '–ø—É–±–ª–∏–∫–∞—Ü–∏—è', '—Ä—É–∫–æ–ø–∏—Å—å', '—Ñ–æ–ª–∏–∞–Ω—Ç', '–±–µ—Å—Ç—Å–µ–ª–ª–µ—Ä', '—Ä–æ–º–∞–Ω', '–ø–æ–≤–µ—Å—Ç—å', '—Ä–∞—Å—Å–∫–∞–∑'],
            '–∑–¥–æ—Ä–æ–≤—å–µ': ['–º–µ–¥–∏—Ü–∏–Ω–∞', '—Ñ–∏—Ç–Ω–µ—Å', '–∑–¥–æ—Ä–æ–≤—ã–π', '–º–µ–¥–∏—Ü–∏–Ω—Å–∫–∏–π', '–∑–¥–æ—Ä–æ–≤—å–µ –∏ —Ñ–∏—Ç–Ω–µ—Å', '–º–µ–¥', '–æ–∑–¥–æ—Ä–æ–≤–ª–µ–Ω–∏–µ', '–∑–¥—Ä–∞–≤–æ–æ—Ö—Ä–∞–Ω–µ–Ω–∏–µ', '–±–ª–∞–≥–æ–ø–æ–ª—É—á–∏–µ', '—Å–∞–º–æ—á—É–≤—Å—Ç–≤–∏–µ', '–≤–∏—Ç–∞–ª—å–Ω–æ—Å—Ç—å', '–≥–∏–≥–∏–µ–Ω–∞', '–ø—Ä–æ—Ñ–∏–ª–∞–∫—Ç–∏–∫–∞', '—Ä–µ–∞–±–∏–ª–∏—Ç–∞—Ü–∏—è'],
            '–ø—É—Ç–µ—à–µ—Å—Ç–≤–∏–µ': ['—Ç—É—Ä–∏–∑–º', '–ø–æ–µ–∑–¥–∫–∞', '–æ—Ç–¥—ã—Ö', '—Ç—É—Ä', '–ø—É—Ç–µ—à–µ—Å—Ç–≤–∏—è', '—Ç—Ä–∏–ø', '–≤–æ—è–∂', '—Ç—É—Ä–ø–æ–µ–∑–¥–∫–∞', '—ç–∫—Å–∫—É—Ä—Å–∏—è', '–ø–∞–ª–æ–º–Ω–∏—á–µ—Å—Ç–≤–æ', '–∫—Ä—É–∏–∑', '—Å–∞—Ñ–∞—Ä–∏', '—ç–∫—Å–ø–µ–¥–∏—Ü–∏—è', '–ø–æ—Ö–æ–¥', '–æ—Ç–ø—É—Å–∫', '–∫–∞–Ω–∏–∫—É–ª—ã'],
            '–µ–¥–∞': ['–ø–∏—Ç–∞–Ω–∏–µ', '—Ä–µ—Ü–µ–ø—Ç', '–∫—É–ª–∏–Ω–∞—Ä–∏—è', '–ø–∏—â–∞', '–µ–¥–∞ –∏ –Ω–∞–ø–∏—Ç–∫–∏', '–∫—É—Ö–Ω—è', '–≥–æ—Ç–æ–≤–∫–∞', '–ø—Ä–æ–¥—É–∫—Ç—ã', '–±–ª—é–¥–æ', '–∫—É—à–∞–Ω—å–µ', '–≥–∞—Å—Ç—Ä–æ–Ω–æ–º–∏—è', '–¥–∏–µ—Ç–∞', '–º–µ–Ω—é', '—Ä–∞—Ü–∏–æ–Ω', '–ø—Ä–æ–≤–∏–∑–∏—è', '–∑–∞–∫—É—Å–∫–∞']
        }

    def augment_text(self, text):
        """Return ``text`` with some known words swapped for synonyms.

        Punctuation attached to a replaced word (e.g. a trailing comma) is kept,
        and the synonym is capitalised when the replaced word started with a capital.
        """
        if random.random() >= 0.6:  # augment roughly 60% of calls
            return text

        words = text.split()
        if len(words) <= 3:  # too short to augment
            return text

        augmented_words = []
        for word in words:
            core = word.strip(self._PUNCTUATION)
            key = core.lower()
            # Each matching word is replaced with probability 0.2.
            if key in self.synonyms and random.random() < 0.2:
                synonym = random.choice(self.synonyms[key])
                # Check the word itself, not a leading punctuation character.
                if core[0].isupper():
                    synonym = synonym.capitalize()
                # Re-attach the punctuation that was stripped for the lookup.
                lead = len(word) - len(word.lstrip(self._PUNCTUATION))
                augmented_words.append(word[:lead] + synonym + word[lead + len(core):])
            else:
                augmented_words.append(word)

        return ' '.join(augmented_words)

class MultiStrategyAppDataset(Dataset):
    """Tokenising dataset of app texts with optional random text augmentation.

    In training mode each ``__getitem__`` call applies, with probability
    ``augment_prob``, one strategy drawn uniformly from ``augment_strategies``:

    * 'synonym': ``AdvancedTextAugmenter.augment_text``
    * 'delete': drop 10-20% of the words
    * 'swap': swap random adjacent word pairs
    * 'repeat': append the first 1-2 words

    Outside training mode no augmentation is applied.
    """

    # Strategies used when ``augment_strategies`` is not given (a tuple, so it cannot
    # be mutated and shared across instances like the old list default could).
    DEFAULT_AUGMENT_STRATEGIES = ('synonym', 'delete', 'swap')

    def __init__(self, texts, labels, tokenizer, max_length=300, is_training=True,
                 augment_prob=0.5, augment_strategies=None):
        """
        Args:
            texts: indexable sequence of input strings.
            labels: indexable per-sample label vectors (returned as float tensors),
                or None for unlabeled data.
            tokenizer: called like a Hugging Face tokenizer; must return
                'input_ids' and 'attention_mask' tensors.
            max_length: pad/truncate length passed to the tokenizer.
            is_training: enables augmentation.
            augment_prob: probability that a training sample is augmented.
            augment_strategies: strategy names; defaults to
                ``DEFAULT_AUGMENT_STRATEGIES``. The sequence is copied.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_training = is_training
        self.augment_prob = augment_prob
        if augment_strategies is None:
            augment_strategies = self.DEFAULT_AUGMENT_STRATEGIES
        self.augment_strategies = list(augment_strategies) if is_training else []
        self.augmenter = AdvancedTextAugmenter() if is_training else None

    def __len__(self):
        return len(self.texts)

    def _apply_augmentation(self, text):
        """Return ``text`` transformed by one randomly chosen strategy, or unchanged."""
        if not self.augment_strategies or random.random() > self.augment_prob:
            return text

        strategy = random.choice(self.augment_strategies)

        if strategy == 'synonym' and self.augmenter:
            return self.augmenter.augment_text(text)

        elif strategy == 'delete':
            # Drop 10-20% of the words (at least one), keeping the original order.
            words = text.split()
            if len(words) > 5:
                n_to_delete = max(1, int(len(words) * random.uniform(0.1, 0.2)))
                indices_to_keep = random.sample(range(len(words)), len(words) - n_to_delete)
                words = [words[i] for i in sorted(indices_to_keep)]
                return ' '.join(words)
            return text

        elif strategy == 'swap':
            # Swap about one adjacent pair per ten words (at least one swap).
            words = text.split()
            if len(words) > 3:
                for _ in range(max(1, len(words) // 10)):
                    i = random.randint(0, len(words) - 2)
                    words[i], words[i + 1] = words[i + 1], words[i]
                return ' '.join(words)
            return text

        elif strategy == 'repeat':
            # Append a copy of the first 1-2 words to the end of the text.
            words = text.split()
            if len(words) > 4:
                n_repeat = random.randint(1, 2)
                repeated = words[:n_repeat]
                words.extend(repeated)
                return ' '.join(words)
            return text

        return text

    def __getitem__(self, idx):
        """Return the tokenised sample ``idx`` (augmented first in training mode)."""
        text = self.texts[idx]

        if self.is_training:
            text = self._apply_augmentation(text)

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # The tokenizer returns a batch of one; flatten to 1-D per-sample tensors.
        item = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }

        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)

        return item

# ==================== УЛУЧШЕННЫЕ ФУНКЦИИ ОЦЕНКИ ====================
def calculate_hit_at_k(predictions, true_labels, k=3):
    """Return HitRate@k: the fraction of rows with a true label among the top-k scores.

    A row is a hit when at least one positive label (``true_labels > 0``) has a
    score >= that row's k-th largest score. Labels tied with the k-th score
    therefore count too, matching the previous top-k + threshold fallback (top-k
    is always a subset of the >= threshold set, so the two checks reduce to this one).

    Args:
        predictions: (n_samples, n_classes) array-like of scores.
        true_labels: array-like of the same shape; values > 0 mark true labels.
        k: number of top scores considered; clamped to [1, n_classes].

    Returns:
        Hit rate as a float; 0.0 for empty input.
    """
    scores = np.asarray(predictions)
    targets = np.asarray(true_labels) > 0
    if scores.shape[0] == 0:
        return 0.0

    # Clamp so k > n_classes no longer raises IndexError in np.sort(...)[-k].
    k = max(1, min(int(k), scores.shape[1]))
    kth_largest = np.sort(scores, axis=1)[:, -k]
    in_top_k = scores >= kth_largest[:, None]
    hits = np.any(in_top_k & targets, axis=1)
    return float(np.mean(hits))

def evaluate_model(model, dataloader, device):
    """Score ``model`` on ``dataloader`` without gradients.

    Returns:
        ``(hit@3, hit@1, hit@5, probs, labels)``. ``probs`` is the stacked sigmoid
        of the model logits. When the batches carry no 'labels' key, the three
        hit rates and ``labels`` are ``None``.
    """
    model.eval()
    prob_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in dataloader:
            logits = model(
                input_ids=batch['input_ids'].to(device),
                attention_mask=batch['attention_mask'].to(device),
            ).logits
            prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
            if 'labels' in batch:
                label_chunks.append(batch['labels'].cpu().numpy())

    probs = np.vstack(prob_chunks)
    if not label_chunks:
        return None, None, None, probs, None

    labels = np.vstack(label_chunks)
    hit3, hit1, hit5 = (calculate_hit_at_k(probs, labels, k=k) for k in (3, 1, 5))
    return hit3, hit1, hit5, probs, labels

# ==================== УЛУЧШЕННОЕ ОБУЧЕНИЕ ====================
class FocalLoss(nn.Module):
    """Element-wise binary focal loss on logits, for imbalanced multi-label targets.

    loss = alpha * (1 - pt) ** gamma * BCE, where BCE is the per-element
    binary cross-entropy with logits and pt = exp(-BCE).

    Args:
        alpha: scalar weight applied to every element.
        gamma: focusing exponent; 0 reduces to ``alpha * BCE``.
        reduction: 'mean' or 'sum'; any other value returns the unreduced tensor.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-bce)
        focal = self.alpha * (1 - pt) ** self.gamma * bce

        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'mean':
            return focal.mean()
        return focal

def train_epoch(model, train_loader, optimizer, scheduler, device, accumulation_steps, criterion):
    """Train ``model`` for one epoch with gradient accumulation and gradient clipping.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=..., labels=...)``
            and returning an object with ``.logits`` (and ``.loss`` when labels are given).
        train_loader: yields dicts with 'input_ids', 'attention_mask' and 'labels'.
        optimizer, scheduler: stepped once per ``accumulation_steps`` batches.
        device: device the batch tensors are moved to.
        accumulation_steps: number of micro-batches per optimizer step.
        criterion: ``criterion(logits, labels) -> loss``; if None the model's own
            ``.loss`` is used.

    Returns:
        ``(avg_loss, train_hitrate_3, train_hitrate_1)`` computed over the epoch's
        (augmented) training batches.
    """
    model.train()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    num_seen = 0

    optimizer.zero_grad()

    for step, batch in enumerate(train_loader):
        num_seen = step + 1
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        if criterion is not None:
            # Do not pass labels: the model's built-in BCE would be computed and discarded.
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            raw_loss = criterion(outputs.logits, labels)
        else:
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            raw_loss = outputs.loss

        # Scale so that the accumulated gradient is the mean over the micro-batches.
        (raw_loss / accumulation_steps).backward()

        if num_seen % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += raw_loss.item()

        # Collect probabilities for the epoch-level training metrics.
        preds = torch.sigmoid(outputs.logits)
        all_preds.append(preds.detach().cpu().numpy())
        all_labels.append(labels.cpu().numpy())

        if step % 50 == 0:
            current_lr = scheduler.get_last_lr()[0]
            # Log the unscaled loss (previously this printed loss / accumulation_steps).
            print(f"  Batch {step}/{len(train_loader)}, Loss: {raw_loss.item():.4f}, LR: {current_lr:.2e}")

    # Apply gradients left over from a trailing partial accumulation window; before,
    # they were silently discarded by the zero_grad() at the start of the next epoch.
    # The scheduler is not stepped here so the number of scheduler steps stays within
    # the floor(len(loader) * epochs / accumulation_steps) budget computed by the caller.
    if num_seen % accumulation_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

    # Metrics on the training data.
    all_preds = np.vstack(all_preds)
    all_labels = np.vstack(all_labels)
    train_hitrate_3 = calculate_hit_at_k(all_preds, all_labels, k=3)
    train_hitrate_1 = calculate_hit_at_k(all_preds, all_labels, k=1)

    avg_loss = total_loss / len(train_loader)
    return avg_loss, train_hitrate_3, train_hitrate_1

# ==================== ОСНОВНАЯ ФУНКЦИЯ ====================
def main():
    """Train the multi-label app classifier, then write top-3 predictions for the test set.

    Steps:
      1. Load the data and make a stratified train/validation split.
      2. Fine-tune with focal loss and a linear warm-up schedule, with early
         stopping on validation HitRate@3.
      3. Reload the best checkpoint and predict the test set.
      4. Save ``enhanced_submission.tsv`` and print prediction statistics.
    """
    print("üöÄ –ó–ê–ü–£–°–ö –£–õ–£–ß–®–ï–ù–ù–û–ì–û –û–ë–£–ß–ï–ù–ò–Ø...")
    print(f"–£—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {config.device}")
    print(f"–ú–æ–¥–µ–ª—å: {config.model_name}")
    print(f"Batch size: {config.batch_size}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Max length: {config.max_length}")

    # Load data.
    train_data, test_data, mlb, train_labels = load_and_prepare_data()

    # Tokenizer and the custom two-head model.
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model = MultiLabelTransformerWithHeads(config.model_name, len(mlb.classes_))
    model.to(config.device)

    # Stratify on each sample's first active label (argmax of the multi-hot row);
    # this is only an approximation of a true multi-label stratification.
    train_texts, val_texts, train_y, val_y = train_test_split(
        train_data['text'].tolist(),
        train_labels,
        test_size=0.1,  # small validation split to keep more training data
        random_state=42,
        stratify=train_labels.argmax(axis=1)
    )

    print(f"\nüìä –†–∞–∑–¥–µ–ª–µ–Ω–∏–µ –¥–∞–Ω–Ω—ã—Ö:")
    print(f"–¢—Ä–µ–Ω–∏—Ä–æ–≤–æ—á–Ω—ã–µ: {len(train_texts)}")
    print(f"–í–∞–ª–∏–¥–∞—Ü–∏–æ–Ω–Ω—ã–µ: {len(val_texts)}")

    # Datasets: augmentation on the training split only.
    train_dataset = MultiStrategyAppDataset(
        train_texts, train_y, tokenizer, config.max_length,
        is_training=True, augment_prob=0.6
    )
    val_dataset = MultiStrategyAppDataset(
        val_texts, val_y, tokenizer, config.max_length,
        is_training=False
    )

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size * 2, shuffle=False)

    # Parameter groups: backbone with / without weight decay, classifier heads at 2x LR.
    # 'layer_norm.weight' covers the model's own post-[CLS] LayerNorm, whose parameter
    # name does not match the Hugging Face 'LayerNorm.weight' pattern.
    no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay) and 'classifier' not in n and 'head' not in n],
         'weight_decay': 0.01, 'lr': config.learning_rate},
        {'params': [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay) and 'classifier' not in n and 'head' not in n],
         'weight_decay': 0.0, 'lr': config.learning_rate},
        {'params': [p for n, p in model.named_parameters() if 'category_head' in n or 'genre_head' in n],
         'weight_decay': 0.01, 'lr': config.learning_rate * 2},  # higher LR for the classifier heads
    ]

    optimizer = AdamW(optimizer_grouped_parameters)
    total_steps = len(train_loader) * config.epochs // config.accumulation_steps
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(total_steps * config.num_warmup_ratio),
        num_training_steps=total_steps
    )

    # Focal loss to down-weight easy examples in the imbalanced label set.
    criterion = FocalLoss(alpha=1, gamma=2)

    # Training with early stopping. Starting below any achievable hit rate guarantees
    # that the first epoch writes best_model.pth, which is loaded unconditionally below.
    best_hitrate = -1.0
    patience_counter = 0

    for epoch in range(config.epochs):
        print(f"\nüéØ –≠–ü–û–•–ê {epoch + 1}/{config.epochs}")

        train_loss, train_hitrate_3, train_hitrate_1 = train_epoch(
            model, train_loader, optimizer, scheduler, config.device,
            config.accumulation_steps, criterion
        )

        val_hitrate_3, val_hitrate_1, val_hitrate_5, _, _ = evaluate_model(model, val_loader, config.device)

        print(f"Train Loss: {train_loss:.4f}")
        print(f"Train H@1: {train_hitrate_1:.4f}, H@3: {train_hitrate_3:.4f}")
        print(f"Val H@1: {val_hitrate_1:.4f}, H@3: {val_hitrate_3:.4f}, H@5: {val_hitrate_5:.4f}")

        # Keep the checkpoint with the best validation H@3.
        if val_hitrate_3 > best_hitrate:
            best_hitrate = val_hitrate_3
            torch.save(model.state_dict(), 'best_model.pth')
            torch.save(optimizer.state_dict(), 'best_optimizer.pth')
            patience_counter = 0
            print(f"üéâ –ù–û–í–ê–Ø –õ–£–ß–®–ê–Ø –ú–û–î–ï–õ–¨! H@3: {best_hitrate:.4f}")
        else:
            patience_counter += 1
            print(f"‚è≥ –†–∞–Ω–Ω—è—è –æ—Å—Ç–∞–Ω–æ–≤–∫–∞: {patience_counter}/{config.patience}")

        if patience_counter >= config.patience:
            print("üõë –†–∞–Ω–Ω—è—è –æ—Å—Ç–∞–Ω–æ–≤–∫–∞ —Å—Ä–∞–±–æ—Ç–∞–ª–∞!")
            break

    print(f"\n‚úÖ –õ–£–ß–®–ò–ô –†–ï–ó–£–õ–¨–¢–ê–¢: H@3 = {best_hitrate:.4f}")

    # ==================== Prediction ====================
    print("\nüîÆ –ó–ê–ì–†–£–ó–ö–ê –õ–£–ß–®–ï–ô –ú–û–î–ï–õ–ò –î–õ–Ø –ü–†–ï–î–°–ö–ê–ó–ê–ù–ò–ô...")
    model.load_state_dict(torch.load('best_model.pth', map_location=config.device))

    test_dataset = MultiStrategyAppDataset(
        test_data['text'].tolist(),
        None,
        tokenizer,
        config.max_length,
        is_training=False
    )
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size * 2, shuffle=False)

    # A single deterministic pass over the test set (no test-time augmentation).
    model.eval()
    all_test_logits = []

    print("–ì–µ–Ω–µ—Ä–∞—Ü–∏—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π...")
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            all_test_logits.append(outputs.logits.cpu().numpy())

    test_logits = np.vstack(all_test_logits)
    # torch.sigmoid is numerically stable; 1 / (1 + np.exp(-x)) overflows for very
    # negative logits (the warning was hidden by the global warnings filter).
    test_probs = torch.sigmoid(torch.from_numpy(test_logits)).numpy()

    class_names = mlb.classes_
    predictions = []
    confidence_scores = []

    # Always emit the three most probable labels. The former low-confidence branch kept
    # only labels with p > 0.2, which could shrink the answer to 1-2 labels and can only
    # lower HitRate@3 (the metric used for model selection above); its "no labels"
    # fallback was unreachable.
    for probs in test_probs:
        top3_indices = np.argsort(probs)[-3:][::-1]
        predicted_categories = [class_names[idx] for idx in top3_indices]
        predictions.append("|".join(predicted_categories))
        confidence_scores.append(np.mean(probs[top3_indices]))

    # Save the submission.
    submission = pd.DataFrame({
        'app_name': test_data['app_name'],
        'labels_str': predictions
    })

    submission.to_csv('enhanced_submission.tsv', sep='\t', index=False)
    print("üìÑ –£–õ–£–ß–®–ï–ù–ù–´–ï –†–ï–ó–£–õ–¨–¢–ê–¢–´ –°–û–•–†–ê–ù–ï–ù–´ –í enhanced_submission.tsv")

    # Frequency of each predicted category.
    print("\nüìä –î–ï–¢–ê–õ–¨–ù–´–ô –ê–ù–ê–õ–ò–ó –ü–†–ï–î–°–ö–ê–ó–ê–ù–ò–ô:")
    pred_counts = Counter([cat for pred in predictions for cat in pred.split('|')])
    top_predicted = pd.DataFrame({
        'category': list(pred_counts.keys()),
        'count': list(pred_counts.values())
    }).sort_values('count', ascending=False)

    print("–¢–æ–ø-15 —Å–∞–º—ã—Ö —á–∞—Å—Ç—ã—Ö –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–Ω—ã—Ö –∫–∞—Ç–µ–≥–æ—Ä–∏–π:")
    print(top_predicted.head(15))

    # Confidence = mean probability of the three emitted labels.
    print(f"\nüéØ –°–¢–ê–¢–ò–°–¢–ò–ö–ê –£–í–ï–†–ï–ù–ù–û–°–¢–ò:")
    print(f"–°—Ä–µ–¥–Ω—è—è —É–≤–µ—Ä–µ–Ω–Ω–æ—Å—Ç—å: {np.mean(confidence_scores):.3f}")
    print(f"–ú–µ–¥–∏–∞–Ω–Ω–∞—è —É–≤–µ—Ä–µ–Ω–Ω–æ—Å—Ç—å: {np.median(confidence_scores):.3f}")
    print(f"–î–æ–ª—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π —Å —É–≤–µ—Ä–µ–Ω–Ω–æ—Å—Ç—å—é > 0.5: {np.mean(np.array(confidence_scores) > 0.5):.3f}")

    # Compare with the most frequent labels in the training data.
    train_label_counts = train_labels.sum(axis=0)
    top_train_labels = pd.DataFrame({
        'category': mlb.classes_,
        'count': train_label_counts
    }).sort_values('count', ascending=False).head(10)

    print("\nüìà –°–†–ê–í–ù–ï–ù–ò–ï –° –¢–†–ï–ù–ò–†–û–í–û–ß–ù–´–ú–ò –î–ê–ù–ù–´–ú–ò:")
    print("–¢–æ–ø-10 —Å–∞–º—ã—Ö —á–∞—Å—Ç—ã—Ö –∫–∞—Ç–µ–≥–æ—Ä–∏–π –≤ —Ç—Ä–µ–Ω–∏—Ä–æ–≤–æ—á–Ω—ã—Ö –¥–∞–Ω–Ω—ã—Ö:")
    print(top_train_labels)

# Run the full train + predict pipeline when executed as a script or notebook cell.
if __name__ == "__main__":
    main()

🚀 ЗАПУСК УЛУЧШЕННОГО ОБУЧЕНИЯ...
Устройство: cuda
Модель: sberbank-ai/ruBert-base
Batch size: 16
Learning rate: 1.5e-05
Max length: 300
Доступные колонки в train: ['app_name', 'full_description', 'shortDescription', 'labels_str']
Средняя длина текста: 285 символов
Максимальная длина: 425 символов

📊 Анализ меток:
Всего классов: 45
Общее количество меток: 63781
Меток на приложение: 1.19

Топ-10 самых частых категорий:
         category  count
40          tools   6241
11         casual   5269
29         puzzle   4332
4          arcade   3917
14  entertainment   3739
34      simulator   3124
8        business   3076
13      education   3076
17   foodAndDrink   3005
0          action   2219

📊 Разделение данных:
Тренировочные: 48144
Валидационные: 535

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, BertModel, AutoConfig
from sklearn.preprocessing import MultiLabelBinarizer
import warnings
# Silence every Python warning for the rest of the session.
# NOTE(review): this also hides genuinely useful warnings — consider narrowing it.
warnings.filterwarnings('ignore')

# ==================== КОНФИГУРАЦИЯ ====================
class Config:
    """Inference settings, read as class attributes through ``config``.

    ``model_name`` should name the same backbone the checkpoint was trained with.
    """

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_name = "sberbank-ai/ruBert-base"
    max_length = 300    # tokenizer pad/truncate length
    batch_size = 16


config = Config()

# ==================== –ú–û–î–ï–õ–¨ ====================
class MultiLabelTransformerWithHeads(nn.Module):
    """Inference copy of the training model: BERT encoder plus two averaged linear heads.

    Submodule names (and their creation order) mirror the training definition, so
    the saved ``state_dict`` keys load directly into this class.
    """

    def __init__(self, model_name, num_labels, hidden_dropout_prob=0.3):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.backbone = BertModel.from_pretrained(model_name)
        hidden_size = self.config.hidden_size
        self.category_head = nn.Linear(hidden_size, num_labels)
        self.genre_head = nn.Linear(hidden_size, num_labels)
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, input_ids, attention_mask):
        """Return the averaged logits of both heads over the position-0 ([CLS]) state."""
        hidden_states = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        features = self.dropout(self.layer_norm(hidden_states[:, 0]))
        return (self.category_head(features) + self.genre_head(features)) / 2

# ==================== ДАТАСЕТ ====================
class TestDataset(Dataset):
    """Map-style dataset that tokenizes one raw text per item.

    Each item is padded/truncated to ``max_length`` tokens and returned as 1-D
    ``input_ids`` / ``attention_mask`` tensors (the tokenizer's leading batch
    dimension of size 1 is flattened away).
    """

    def __init__(self, texts, tokenizer, max_length=300):
        self.texts, self.tokenizer, self.max_length = texts, tokenizer, max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        return {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}

# ==================== PREDICTION FUNCTION ====================
def _build_input_text(row):
    """Build the model input string for one test row.

    Adds the app name (when present and non-empty) and the first non-empty
    description-like column; returns a generic placeholder if neither exists.
    """
    parts = []
    if 'app_name' in row:
        app_name = str(row['app_name']).strip()
        # str(NaN) == 'nan', so treat it as missing.
        if app_name and app_name != 'nan':
            parts.append(f"–ù–∞–∑–≤–∞–Ω–∏–µ: {app_name}")

    # Only the first description column that actually has content is used.
    for col in ['shortDescription', 'short_description', 'description', 'full_description']:
        if col in row and pd.notna(row[col]):
            desc = str(row[col]).strip()
            if desc and desc != 'nan':
                parts.append(f"–û–ø–∏—Å–∞–Ω–∏–µ: {desc}")
                break

    if not parts:
        return "–ü—Ä–∏–ª–æ–∂–µ–Ω–∏–µ"
    return " ".join(parts)


def _load_label_binarizer(train_path):
    """Fit a MultiLabelBinarizer on the '|'-separated ``labels_str`` training column.

    Falls back to a small hard-coded category list when the training file is
    missing or has no ``labels_str`` column.
    """
    try:
        train_data = pd.read_csv(train_path, sep='\t')
        # dropna(): str.split turns a NaN cell into NaN, which the binarizer
        # cannot iterate and which used to trigger the silent fallback below.
        label_lists = train_data['labels_str'].dropna().str.split('|')
    except (FileNotFoundError, KeyError):
        # NOTE(review): unless the checkpoint was also trained on exactly these
        # 10 labels, loading it into a 10-class head will fail with a size mismatch.
        default_categories = ['–∏–≥—Ä–∞', '–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏–µ', '–º—É–∑—ã–∫–∞', '–≤–∏–¥–µ–æ', '—Å–æ—Ü–∏–∞–ª—å–Ω—ã–π',
                            '–Ω–æ–≤–æ—Å—Ç–∏', '–º–∞–≥–∞–∑–∏–Ω', '–∫–Ω–∏–≥–∞', '–∑–¥–æ—Ä–æ–≤—å–µ', '–ø—É—Ç–µ—à–µ—Å—Ç–≤–∏–µ']
        mlb = MultiLabelBinarizer()
        mlb.fit([default_categories])  # fit on the whole list as one sample
        print(f"‚ö†Ô∏è –ò—Å–ø–æ–ª—å–∑—É–µ–º —Å—Ç–∞–Ω–¥–∞—Ä—Ç–Ω—ã–µ –∫–ª–∞—Å—Å—ã: {len(mlb.classes_)}")
        return mlb

    mlb = MultiLabelBinarizer()
    mlb.fit(label_lists)
    print(f"–ù–∞–π–¥–µ–Ω–æ –∫–ª–∞—Å—Å–æ–≤: {len(mlb.classes_)}")
    return mlb


def _print_prediction_stats(submission, all_predictions):
    """Print the row count, the 10 most frequent predicted labels and the first 10 rows."""
    from collections import Counter

    print("\nüìä –°–¢–ê–¢–ò–°–¢–ò–ö–ê –ü–†–ï–î–°–ö–ê–ó–ê–ù–ò–ô:")
    print(f"–í—Å–µ–≥–æ –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–æ: {len(submission)} —Å—Ç—Ä–æ–∫")

    category_counts = Counter(cat for pred in all_predictions for cat in pred.split('|'))
    print("\n–¢–æ–ø-10 —Å–∞–º—ã—Ö —á–∞—Å—Ç—ã—Ö –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–Ω—ã—Ö –∫–∞—Ç–µ–≥–æ—Ä–∏–π:")
    for category, count in category_counts.most_common(10):
        print(f"  {category}: {count} —Ä–∞–∑")

    print(f"\nüìù –ü–ï–†–í–´–ï 10 –ü–†–ï–î–°–ö–ê–ó–ê–ù–ò–ô:")
    head = submission.head(10)
    for i, (name, labels) in enumerate(zip(head['app_name'], head['labels_str']), start=1):
        print(f"  {i}. {name} -> {labels}")


def predict_on_test_data(test_path='/content/test.tsv',
                         train_path='/content/train.tsv',
                         model_path='/content/drive/MyDrive/best_model (1).pth',
                         output_path='test_predictions.tsv',
                         top_k=3):
    """Run the two-head model on the test TSV and write top-k label predictions.

    Args:
        test_path: Tab-separated test file. If it does not exist, a 100-row
            demo frame is used instead.
        train_path: Tab-separated training file whose '|'-separated
            ``labels_str`` column defines the label vocabulary and head size.
        model_path: ``state_dict`` checkpoint for ``MultiLabelTransformerWithHeads``.
        output_path: Destination of the ``app_name`` / ``labels_str`` TSV.
        top_k: Number of highest-probability labels kept per row, best first.

    Returns:
        The submission DataFrame, or ``None`` if the checkpoint could not be
        loaded (the reason is printed).
    """
    print("üöÄ –ó–ê–ü–£–°–ö –ü–†–ï–î–°–ö–ê–ó–ê–ù–ò–ô –ù–ê –¢–ï–°–¢–û–í–´–• –î–ê–ù–ù–´–•...")
    print(f"–£—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {config.device}")
    print(f"–ú–æ–¥–µ–ª—å: {config.model_name}")

    # Only a missing file triggers the demo fallback; parse/permission errors propagate.
    try:
        test_data = pd.read_csv(test_path, sep='\t')
    except FileNotFoundError:
        print("‚ö†Ô∏è –§–∞–π–ª test.tsv –Ω–µ –Ω–∞–π–¥–µ–Ω! –°–æ–∑–¥–∞–µ–º –¥–µ–º–æ-–¥–∞–Ω–Ω—ã–µ...")
        test_data = pd.DataFrame({
            'app_name': [f'App_{i}' for i in range(100)],
            'description': [f'–û–ø–∏—Å–∞–Ω–∏–µ –ø—Ä–∏–ª–æ–∂–µ–Ω–∏—è {i}' for i in range(100)]
        })
    else:
        print(f"–ó–∞–≥—Ä—É–∂–µ–Ω–æ —Ç–µ—Å—Ç–æ–≤—ã—Ö –¥–∞–Ω–Ω—ã—Ö: {len(test_data)} —Å—Ç—Ä–æ–∫")

    test_data['text'] = test_data.apply(_build_input_text, axis=1)

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    test_dataset = TestDataset(test_data['text'].tolist(), tokenizer, config.max_length)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

    mlb = _load_label_binarizer(train_path)
    label_names = mlb.classes_
    num_classes = len(label_names)

    model = MultiLabelTransformerWithHeads(config.model_name, num_classes)
    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(model_path, map_location=config.device))
        print(f"‚úÖ –ú–æ–¥–µ–ª—å —É—Å–ø–µ—à–Ω–æ –∑–∞–≥—Ä—É–∂–µ–Ω–∞ –∏–∑ {model_path}")
    except FileNotFoundError:
        print(f"‚ùå –§–∞–π–ª {model_path} –Ω–µ –Ω–∞–π–¥–µ–Ω!")
        print("–ü–æ–∂–∞–ª—É–π—Å—Ç–∞, —É–±–µ–¥–∏—Ç–µ—Å—å —á—Ç–æ —Ñ–∞–π–ª —Å—É—â–µ—Å—Ç–≤—É–µ—Ç –≤ —Ç–µ–∫—É—â–µ–π –¥–∏—Ä–µ–∫—Ç–æ—Ä–∏–∏")
        return
    except Exception as e:
        # e.g. RuntimeError from a key/shape mismatch with the checkpoint.
        print(f"‚ùå –û—à–∏–±–∫–∞ –ø—Ä–∏ –∑–∞–≥—Ä—É–∑–∫–µ –º–æ–¥–µ–ª–∏: {e}")
        return

    model.to(config.device)
    model.eval()

    all_predictions = []
    print("üéØ –ì–µ–Ω–µ—Ä–∞—Ü–∏—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π...")
    with torch.no_grad():
        for batch in test_loader:
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)

            logits = model(input_ids=input_ids, attention_mask=attention_mask)
            probabilities = torch.sigmoid(logits).cpu().numpy()

            # Reverse each row's ascending argsort so the most probable label
            # comes first, then keep top_k (top_k=0 yields an empty label string).
            top_indices = np.argsort(probabilities, axis=1)[:, ::-1][:, :top_k]
            all_predictions.extend(
                "|".join(label_names[idx] for idx in row) for row in top_indices
            )

    submission = pd.DataFrame({
        'app_name': test_data['app_name'],
        'labels_str': all_predictions
    })
    submission.to_csv(output_path, sep='\t', index=False)
    print(f"‚úÖ –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏—è —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã –≤ {output_path}")

    _print_prediction_stats(submission, all_predictions)
    return submission

# ==================== ALTERNATIVE SIMPLE MODEL ====================
class SimpleBertClassifier(nn.Module):
    """Fallback single-head classifier: BERT first-token state -> dropout -> linear.

    NOTE(review): its parameter names (``bert.*``, ``classifier.*``) differ from
    ``MultiLabelTransformerWithHeads`` (``backbone.*``, ``category_head.*``), so a
    checkpoint saved from that model will not load here with strict matching.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.dropout = nn.Dropout(0.3)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits computed from the first token's hidden state."""
        first_token = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        return self.classifier(self.dropout(first_token))

def predict_with_simple_model(model_path='best_model.pth',
                              test_path='/content/test.tsv',
                              train_path='/content/train.tsv',
                              output_path='test_predictions.tsv',
                              top_k=3):
    """Fallback prediction path built on ``SimpleBertClassifier``.

    Loads the checkpoint into the single-head model and uses that model to
    predict the ``top_k`` most probable labels for every test row.

    Args:
        model_path: ``state_dict`` checkpoint for ``SimpleBertClassifier``.
        test_path: Tab-separated test file with an ``app_name`` column.
        train_path: Tab-separated training file whose '|'-separated
            ``labels_str`` column defines the label names and head size.
        output_path: Destination of the ``app_name`` / ``labels_str`` TSV.
        top_k: Number of highest-probability labels kept per row, best first.

    Returns:
        The submission DataFrame, or ``None`` if the checkpoint could not be
        loaded into this architecture (the reason is printed).

    Raises:
        Whatever ``pandas.read_csv`` raises for a missing/unreadable test or
        training file.
    """
    print("üîÑ –ò—Å–ø–æ–ª—å–∑—É–µ–º —É–ø—Ä–æ—â–µ–Ω–Ω—É—é –º–æ–¥–µ–ª—å...")

    test_data = pd.read_csv(test_path, sep='\t')
    # NOTE(review): this "<app_name> <description>" text differs from the
    # prefixed format built in predict_on_test_data; confirm which format the
    # checkpoint was trained on.
    test_data['text'] = test_data.apply(
        lambda x: f"{x['app_name']} {x.get('description', '')}", axis=1
    )

    # Size the head from the real label vocabulary instead of a hard-coded
    # guess, so head width and label names always agree.
    train_data = pd.read_csv(train_path, sep='\t')
    mlb = MultiLabelBinarizer()
    mlb.fit(train_data['labels_str'].dropna().str.split('|'))
    label_names = mlb.classes_

    model = SimpleBertClassifier(config.model_name, len(label_names))
    try:
        model.load_state_dict(torch.load(model_path, map_location=config.device))
        print("‚úÖ –£–ø—Ä–æ—â–µ–Ω–Ω–∞—è –º–æ–¥–µ–ª—å –∑–∞–≥—Ä—É–∂–µ–Ω–∞ —É—Å–ø–µ—à–Ω–æ")
    except Exception as e:
        # Missing file, unpickling failure, or key/shape mismatch (RuntimeError).
        print("‚ùå –ù–µ —É–¥–∞–ª–æ—Å—å –∑–∞–≥—Ä—É–∑–∏—Ç—å –º–æ–¥–µ–ª—å. –ü—Ä–æ–≤–µ—Ä—å—Ç–µ –∞—Ä—Ö–∏—Ç–µ–∫—Ç—É—Ä—É.")
        print(f"   {e}")
        return

    model.to(config.device)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    test_loader = DataLoader(
        TestDataset(test_data['text'].tolist(), tokenizer, config.max_length),
        batch_size=config.batch_size,
        shuffle=False,
    )

    predictions = []
    with torch.no_grad():
        for batch in test_loader:
            logits = model(
                input_ids=batch['input_ids'].to(config.device),
                attention_mask=batch['attention_mask'].to(config.device),
            )
            probabilities = torch.sigmoid(logits).cpu().numpy()
            # Reverse each row's ascending argsort so the most probable label comes first.
            top_indices = np.argsort(probabilities, axis=1)[:, ::-1][:, :top_k]
            predictions.extend(
                "|".join(label_names[idx] for idx in row) for row in top_indices
            )

    submission = pd.DataFrame({
        'app_name': test_data['app_name'],
        'labels_str': predictions
    })
    submission.to_csv(output_path, sep='\t', index=False)
    print(f"‚úÖ –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏—è —Å–æ—Ö—Ä–∞–Ω–µ–Ω—ã –≤ {output_path}")
    return submission

# ==================== ENTRY POINT ====================
if __name__ == "__main__":
    # The predict_* functions print their own diagnostics and return None
    # (instead of raising) when the checkpoint cannot be loaded, so a None
    # result is treated as a failure just like an exception.
    result = None
    try:
        result = predict_on_test_data()
    except Exception as e:
        print(f"‚ùå –û—à–∏–±–∫–∞ –≤ –æ—Å–Ω–æ–≤–Ω–æ–π —Ñ—É–Ω–∫—Ü–∏–∏: {e}")

    if result is None:
        print("–ü—Ä–æ–±—É–µ–º –∞–ª—å—Ç–µ—Ä–Ω–∞—Ç–∏–≤–Ω—ã–π –ø–æ–¥—Ö–æ–¥...")
        try:
            result = predict_with_simple_model()
        except Exception as e2:
            print(f"‚ùå –û—à–∏–±–∫–∞ –≤ –∞–ª—å—Ç–µ—Ä–Ω–∞—Ç–∏–≤–Ω–æ–π —Ñ—É–Ω–∫—Ü–∏–∏: {e2}")

    if result is not None:
        print("\nüéâ –ü–†–ï–î–°–ö–ê–ó–ê–ù–ò–Ø –£–°–ü–ï–®–ù–û –ó–ê–í–ï–†–®–ï–ù–´!")
    else:
        print("–ü–æ–∂–∞–ª—É–π—Å—Ç–∞, –ø—Ä–æ–≤–µ—Ä—å—Ç–µ:")
        print("1. –ù–∞–ª–∏—á–∏–µ —Ñ–∞–π–ª–∞ best_model.pth")
        print("2. –°–æ–æ—Ç–≤–µ—Ç—Å—Ç–≤–∏–µ –∞—Ä—Ö–∏—Ç–µ–∫—Ç—É—Ä—ã –º–æ–¥–µ–ª–∏")
        print("3. –ù–∞–ª–∏—á–∏–µ —Ç–µ—Å—Ç–æ–≤—ã—Ö –¥–∞–Ω–Ω—ã—Ö")

🚀 ЗАПУСК ПРЕДСКАЗАНИЙ НА ТЕСТОВЫХ ДАННЫХ...
Устройство: cuda
Модель: sberbank-ai/ruBert-base
Загружено тестовых данных: 15046 строк
Найдено классов: 45
✅ Модель успешно загружена из best_model.pth
🎯 Генерация предсказаний...
✅ Предсказания сохранены в test_predictions.tsv

📊 СТАТИСТИКА ПРЕДСКАЗАНИЙ:
Всего предсказано: 15046 строк

Топ-10 самых частых предсказанных категорий:
  casual: 4232 раз
  tools: 4192 раз
  arcade: 3639 раз
  entertainment: 3098 раз
  business: 2924 раз
  simulator: 2407 раз
  puzzle: 2259 раз
  education: 1972 раз
  action: 1850 раз
  lifestyle: 1691 раз

📝 ПЕРВЫЕ 10 ПРЕДСКАЗАНИЙ:
  1. Lemon clicker -> casual|arcade|simulator
  2. Memo –ê–Ω–≥–ª–∏–π—Å–∫–∏–π —è–∑—ã–∫ 