In [23]:
import re
import logging
from typing import Dict, List, Tuple
from datetime import datetime

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('./data/log_file.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

class My_TextNormalization_Model:
    """
    Improved rule-based Russian text normalization model for TTS systems.
    """

    def __init__(self):
        """Initialize the rule-based text normalization model."""
        logger.info("Initializing Improved Russian Text Normalization Model")

        self.numbers = {
            '0': 'ноль', '1': 'один', '2': 'два', '3': 'три', '4': 'четыре',
            '5': 'пять', '6': 'шесть', '7': 'семь', '8': 'восемь', '9': 'девять',
            '10': 'десять', '11': 'одиннадцать', '12': 'двенадцать',
            '13': 'тринадцать', '14': 'четырнадцать', '15': 'пятнадцать',
            '16': 'шестнадцать', '17': 'семнадцать', '18': 'восемнадцать',
            '19': 'девятнадцать', '20': 'двадцать', '30': 'тридцать',
            '40': 'сорок', '50': 'пятьдесят', '60': 'шестьдесят',
            '70': 'семьдесят', '80': 'восемьдесят', '90': 'девяносто'
        }

        self.hundreds = {
            '100': 'сто', '200': 'двести', '300': 'триста', '400': 'четыреста',
            '500': 'пятьсот', '600': 'шестьсот', '700': 'семьсот',
            '800': 'восемьсот', '900': 'девятьсот'
        }

        self.ordinals = {
            '1': 'первое', '2': 'второе', '3': 'третье', '4': 'четвертое',
            '5': 'пятое', '6': 'шестое', '7': 'седьмое', '8': 'восьмое',
            '9': 'девятое', '10': 'десятое', '11': 'одиннадцатое',
            '12': 'двенадцатое', '13': 'тринадцатое', '14': 'четырнадцатое',
            '15': 'пятнадцатое', '16': 'шестнадцатое', '17': 'семнадцатое',
            '18': 'восемнадцатое', '19': 'девятнадцатое', '20': 'двадцатое',
            '21': 'двадцать первое', '22': 'двадцать второе', '23': 'двадцать третье',
            '24': 'двадцать четвертое', '25': 'двадцать пятое', '26': 'двадцать шестое',
            '27': 'двадцать седьмое', '28': 'двадцать восьмое', '29': 'двадцать девятое',
            '30': 'тридцатое', '31': 'тридцать первое'
        }

        self.months = {
            '1': 'января', '2': 'февраля', '3': 'марта', '4': 'апреля',
            '5': 'мая', '6': 'июня', '7': 'июля', '8': 'августа',
            '9': 'сентября', '10': 'октября', '11': 'ноября', '12': 'декабря',
            '01': 'января', '02': 'февраля', '03': 'марта', '04': 'апреля',
            '05': 'мая', '06': 'июня', '07': 'июля', '08': 'августа',
            '09': 'сентября'
        }

        self.roman_numerals = {
            'I': 'первый', 'II': 'второй', 'III': 'третий', 'IV': 'четвертый',
            'V': 'пятый', 'VI': 'шестой', 'VII': 'седьмой', 'VIII': 'восьмой',
            'IX': 'девятый', 'X': 'десятый', 'XI': 'одиннадцатый',
            'XII': 'двенадцатый', 'XIII': 'тринадцатый', 'XIV': 'четырнадцатый',
            'XV': 'пятнадцатый', 'XVI': 'шестнадцатый', 'XVII': 'семнадцатый',
            'XVIII': 'восемнадцатый', 'XIX': 'девятнадцатый', 'XX': 'двадцатый'
        }

        logger.info("Improved Russian Text Normalization Model initialized successfully")

    def normalize_text(self, text: str) -> str:
        """Main method to normalize Russian text for TTS."""
        if not text or not isinstance(text, str):
            logger.warning("Empty or invalid input text received")
            return ""

        logger.info(f"Normalizing text: {text[:50]}...")

        try:
            normalized_text = self._preprocess_text(text)

            normalized_text = self._normalize_dates(normalized_text)
            normalized_text = self._normalize_time(normalized_text)
            normalized_text = self._normalize_currency(normalized_text)
            normalized_text = self._normalize_measurements(normalized_text)
            normalized_text = self._normalize_percentages(normalized_text)
            normalized_text = self._normalize_phone_numbers(normalized_text)
            normalized_text = self._normalize_urls_emails(normalized_text)
            normalized_text = self._normalize_abbreviations(normalized_text)
            normalized_text = self._normalize_roman_numerals(normalized_text)
            normalized_text = self._normalize_numbers(normalized_text)  # Numbers last to avoid conflicts
            normalized_text = self._normalize_punctuation(normalized_text)

            normalized_text = self._postprocess_text(normalized_text)

            logger.info(f"Text normalization completed successfully")
            return normalized_text

        except Exception as e:
            logger.error(f"Error during text normalization: {str(e)}")
            return text

    def _preprocess_text(self, text: str) -> str:
        """Preprocess text by cleaning and standardizing format."""
        text = re.sub(r'\s+', ' ', text.strip())
        text = text.replace('—', '-').replace('–', '-')
        text = text.replace('"', '"').replace('"', '"')
        text = text.replace(''', "'").replace(''', "'")
        return text

    def _normalize_dates(self, text: str) -> str:
        """Normalize dates in various formats."""
        def convert_date(match):
            day, month, year = match.groups()
            day_word = self.ordinals.get(str(int(day)), f"{day}-е")
            month_word = self.months.get(str(int(month)), month)
            year_word = self._convert_year(year)
            return f"{day_word} {month_word} {year_word} года"

        text = re.sub(r'\b(\d{1,2})[./](\d{1,2})[./](\d{4})\b', convert_date, text)

        def convert_short_date(match):
            day, month = match.groups()
            day_word = self.ordinals.get(str(int(day)), f"{day}-е")
            month_word = self.months.get(str(int(month)), month)
            return f"{day_word} {month_word}"

        text = re.sub(r'\b(\d{1,2})[./](\d{1,2})\b(?!/)', convert_short_date, text)
        return text

    def _normalize_time(self, text: str) -> str:
        """Normalize time expressions - FIXED VERSION."""
        def convert_time(match):
            hours, minutes = match.groups()
            hour_int = int(hours)
            minute_int = int(minutes)

            hour_word = self._number_to_words(hour_int)

            if minute_int == 0:
                return f"{hour_word} часов"
            else:
                minute_word = self._number_to_words(minute_int)
                return f"{hour_word} часов {minute_word} минут"

        text = re.sub(r'\b(\d{1,2}):(\d{2})\b', convert_time, text)
        return text

    def _normalize_currency(self, text: str) -> str:
        """Normalize currency expressions - FIXED VERSION."""
        def convert_currency(match):
            amount, currency = match.groups()
            amount_int = int(amount)
            amount_word = self._number_to_words(amount_int)

            if currency == '₽' or currency == 'руб':
                if amount_int == 1:
                    currency_word = 'рубль'
                elif 2 <= amount_int <= 4:
                    currency_word = 'рубля'
                else:
                    currency_word = 'рублей'
            elif currency == '$':
                if amount_int == 1:
                    currency_word = 'доллар'
                elif 2 <= amount_int <= 4:
                    currency_word = 'доллара'
                else:
                    currency_word = 'долларов'
            elif currency == '€':
                currency_word = 'евро'
            else:
                currency_word = currency

            return f"{amount_word} {currency_word}"

        text = re.sub(r'(\d+)\s*([₽$€]|руб\.?)', convert_currency, text)

        def convert_decimal_currency(match):
            amount, currency = match.groups()
            amount_word = self._convert_decimal_number(amount)
            currency_word = self._get_currency_word(currency)
            return f"{amount_word} {currency_word}"

        text = re.sub(r'(\d+[.,]\d+)\s*([₽$€]|руб\.?)', convert_decimal_currency, text)
        return text

    def _normalize_measurements(self, text: str) -> str:
        """Normalize measurement units - FIXED VERSION."""
        def convert_measurement(match):
            amount, unit = match.groups()

            if '.' in amount or ',' in amount:
                amount_word = self._convert_decimal_number(amount)
            else:
                amount_word = self._number_to_words(int(amount))

            unit_mappings = {
                'кг': 'килограммов',
                'г': 'граммов',
                'км': 'километров',
                'м': 'метров',
                'см': 'сантиметров',
                'мм': 'миллиметров',
                'л': 'литров',
                'мл': 'миллилитров',
                '°C': 'градусов цельсия',
                '°': 'градусов'
            }

            unit_word = unit_mappings.get(unit, unit)
            return f"{amount_word} {unit_word}"

        text = re.sub(r'(\d+(?:[.,]\d+)?)\s*(кг|г|км|м|см|мм|л|мл|°C|°)', convert_measurement, text)
        return text

    def _normalize_percentages(self, text: str) -> str:
        """Normalize percentage expressions - FIXED VERSION."""
        def convert_percentage(match):
            amount = match.group(1)
            if '.' in amount or ',' in amount:
                amount_word = self._convert_decimal_number(amount)
            else:
                amount_word = self._number_to_words(int(amount))
            return f"{amount_word} процентов"

        text = re.sub(r'(\d+(?:[.,]\d+)?)%', convert_percentage, text)
        return text

    def _normalize_phone_numbers(self, text: str) -> str:
        """Normalize phone numbers - FIXED VERSION."""
        def convert_phone(match):
            phone = match.group()
            # Extract just the digits
            digits = re.sub(r'[^\d]', '', phone)

            # Convert each digit individually for phone numbers
            spoken_digits = []
            for digit in digits:
                spoken_digits.append(self.numbers.get(digit, digit))

            return ' '.join(spoken_digits)

        phone_patterns = [
            r'\+7\s*\(\d{3}\)\s*\d{3}-\d{2}-\d{2}',
            r'\+7\s*\d{10}',
            r'8\s*\(\d{3}\)\s*\d{3}-\d{2}-\d{2}',
            r'8\s*\d{10}'
        ]

        for pattern in phone_patterns:
            text = re.sub(pattern, convert_phone, text)

        return text

    def _normalize_urls_emails(self, text: str) -> str:
        """Normalize URLs and email addresses."""
        text = re.sub(r'https?://[^\s]+', 'ссылка', text)
        text = re.sub(r'www\.[^\s]+', 'веб сайт', text)
        text = re.sub(r'\b[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}\b', 'электронная почта', text)
        return text

    def _normalize_abbreviations(self, text: str) -> str:
        """Normalize common abbreviations."""
        abbreviations = {
            'т.е.': 'то есть', 'и т.д.': 'и так далее', 'и т.п.': 'и тому подобное',
            'т.к.': 'так как', 'т.н.': 'так называемый', 'см.': 'смотрите',
            'стр.': 'страница', 'гл.': 'глава', 'рис.': 'рисунок',
            'табл.': 'таблица', 'г.': 'год', 'гг.': 'годы', 'в.': 'век',
            'вв.': 'века', 'др.': 'другие', 'пр.': 'прочие',
            'напр.': 'например', 'англ.': 'английский', 'рус.': 'русский'
        }

        for abbr, expansion in abbreviations.items():
            text = re.sub(r'\b' + re.escape(abbr) + r'\b', expansion, text, flags=re.IGNORECASE)

        return text

    def _normalize_roman_numerals(self, text: str) -> str:
        """Normalize Roman numerals."""
        def convert_roman(match):
            roman = match.group()
            return self.roman_numerals.get(roman, roman)

        text = re.sub(r'\b(I{1,3}|IV|V|VI{0,3}|IX|X{1,2}|XI{0,3}|XIV|XV|XVI{0,3}|XIX|XX)\b', convert_roman, text)
        return text

    def _normalize_numbers(self, text: str) -> str:
        """Normalize standalone numbers to words."""
        def convert_number(match):
            number = match.group()
            try:
                num = int(number)
                return self._number_to_words(num)
            except ValueError:
                return number

        text = re.sub(r'\b\d{1,4}\b', convert_number, text)
        return text

    def _normalize_punctuation(self, text: str) -> str:
        """Handle punctuation marks."""
        text = text.replace('&', ' и ')
        text = text.replace('+', ' плюс ')
        text = text.replace('=', ' равно ')
        text = text.replace('№', 'номер ')

        text = re.sub(r'[()[\]{}]', '', text)
        text = re.sub(r'[.!?]+', '.', text)
        text = re.sub(r'[,;:]+', ',', text)

        return text

    def _postprocess_text(self, text: str) -> str:
        """Final cleanup of normalized text."""
        text = re.sub(r'\s+', ' ', text.strip())
        text = re.sub(r'[.!?,;:]+$', '', text)
        text = text.lower()
        return text

    def _number_to_words(self, num: int) -> str:
        """Convert numbers to Russian words - IMPROVED VERSION."""
        if num == 0:
            return 'ноль'

        if 1 <= num <= 19:
            return self.numbers[str(num)]
        elif 20 <= num <= 99:
            tens = (num // 10) * 10
            units = num % 10
            if units == 0:
                return self.numbers[str(tens)]
            else:
                return f"{self.numbers[str(tens)]} {self.numbers[str(units)]}"
        elif 100 <= num <= 999:
            hundreds = (num // 100) * 100
            remainder = num % 100
            result = self.hundreds[str(hundreds)]
            if remainder > 0:
                result += f" {self._number_to_words(remainder)}"
            return result
        elif 1000 <= num <= 9999:
            thousands = num // 1000
            remainder = num % 1000

            if thousands == 1:
                result = "тысяча"
            elif 2 <= thousands <= 4:
                result = f"{self.numbers[str(thousands)]} тысячи"
            else:
                result = f"{self.numbers[str(thousands)]} тысяч"

            if remainder > 0:
                result += f" {self._number_to_words(remainder)}"
            return result
        else:
            return str(num)

    def _convert_decimal_number(self, number_str: str) -> str:
        """Convert decimal numbers to words."""
        number_str = number_str.replace(',', '.')

        try:
            if '.' in number_str:
                integer_part, decimal_part = number_str.split('.')
                integer_word = self._number_to_words(int(integer_part))
                decimal_digits = ' '.join([self.numbers.get(digit, digit) for digit in decimal_part])
                return f"{integer_word} целых {decimal_digits}"
            else:
                return self._number_to_words(int(number_str))
        except ValueError:
            return number_str

    def _convert_year(self, year: str) -> str:
        """Convert year to spoken form."""
        try:
            year_int = int(year)
            if 1000 <= year_int <= 2099:
                return self._number_to_words(year_int)
            else:
                return year
        except ValueError:
            return year

    def _get_currency_word(self, currency: str) -> str:
        """Get proper currency word."""
        currency_map = {
            '₽': 'рублей', 'руб': 'рублей', '$': 'долларов', '€': 'евро'
        }
        return currency_map.get(currency, currency)

if __name__ == "__main__":
    model = My_TextNormalization_Model()

    test_texts = [
        "Сегодня 15.03.2024 в 14:30 температура была 25°C.",
        "Цена составляет 1500₽ за 2.5кг.",
        "Позвоните по номеру +7(495)123-45-67.",
        "Эффективность увеличилась на 15%.",
        "В XX веке произошло много изменений.",
        "Встреча назначена на 10:00."
    ]

    print("IMPROVED RULE-BASED MODEL RESULTS:")
    print("=" * 60)

    for text in test_texts:
        normalized = model.normalize_text(text)
        print(f"Original: {text}")
        print(f"Normalized: {normalized}")
        print("-" * 50)

2025-05-31 20:39:48,473 - __main__ - INFO - Initializing Improved Russian Text Normalization Model
2025-05-31 20:39:48,473 - __main__ - INFO - Improved Russian Text Normalization Model initialized successfully
2025-05-31 20:39:48,474 - __main__ - INFO - Normalizing text: Сегодня 15.03.2024 в 14:30 температура была 25°C....
2025-05-31 20:39:48,475 - __main__ - INFO - Text normalization completed successfully
2025-05-31 20:39:48,476 - __main__ - INFO - Normalizing text: Цена составляет 1500₽ за 2.5кг....
2025-05-31 20:39:48,476 - __main__ - INFO - Text normalization completed successfully
2025-05-31 20:39:48,476 - __main__ - INFO - Normalizing text: Позвоните по номеру +7(495)123-45-67....
2025-05-31 20:39:48,477 - __main__ - INFO - Text normalization completed successfully
2025-05-31 20:39:48,477 - __main__ - INFO - Normalizing text: Эффективность увеличилась на 15%....
2025-05-31 20:39:48,477 - __main__ - INFO - Text normalization completed successfully
2025-05-31 20:39:48,478 - __main

IMPROVED RULE-BASED MODEL RESULTS:
Original: Сегодня 15.03.2024 в 14:30 температура была 25°C.
Normalized: сегодня пятнадцатое марта два тысячи двадцать четыре года в четырнадцать часов тридцать минут температура была двадцать пять градусов цельсия
--------------------------------------------------
Original: Цена составляет 1500₽ за 2.5кг.
Normalized: цена составляет тысяча пятьсот рублей за два целых пять килограммов
--------------------------------------------------
Original: Позвоните по номеру +7(495)123-45-67.
Normalized: позвоните по номеру семь четыре девять пять один два три четыре пять шесть семь
--------------------------------------------------
Original: Эффективность увеличилась на 15%.
Normalized: эффективность увеличилась на пятнадцать процентов
--------------------------------------------------
Original: В XX веке произошло много изменений.
Normalized: в двадцатый веке произошло много изменений
--------------------------------------------------
Original: Встреча назначен

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
from torch.optim import AdamW
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
import logging
from typing import List, Dict, Tuple, Optional
import wandb
from tqdm import tqdm
import re
import json
from dataclasses import dataclass
import os
from pathlib import Path
import unicodedata
from collections import Counter, defaultdict

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('./data/log_file.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

@dataclass
class TrainingConfig:
    """Configuration optimized for small dataset (13k samples)."""
    model_name: str = "ai-forever/ruT5-base"
    max_source_length: int = 128
    max_target_length: int = 128

    batch_size: int = 8
    learning_rate: float = 1e-4
    num_epochs: int = 1
    warmup_steps: int = 100
    weight_decay: float = 0.05
    gradient_clip_val: float = 0.5

    save_steps: int = 200
    eval_steps: int = 100

    # Regularization settings
    dropout_rate: float = 0.3
    use_early_stopping: bool = True
    patience: int = 5

    sample_size: int = 15000
    stratify_by_class: bool = True
    stratify_by_length: bool = True

    # Performance
    dataloader_num_workers: int = 4
    mixed_precision: bool = True
    compile_model: bool = False
    use_wandb: bool = True
    seed: int = 42

class RobustRussianTextCleaner:
    """Improved text cleaner based on your implementation."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)

        self.russian_letters = set('абвгдеёжзийклмнопрстуфхцчшщъыьэюяАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯ')

        self.valid_chars = (
            self.russian_letters |
            set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789') |
            set(' .,!?-()[]{}";:\'"/@#$%^&*+=<>~`|\\€$₽°')
        )

    def safe_load_csv(self, file_path: str) -> pd.DataFrame:
        """Safely load CSV with encoding issues handling."""
        encodings_to_try = ['utf-8', 'utf-8-sig', 'cp1251', 'iso-8859-1', 'latin1']

        for encoding in encodings_to_try:
            try:
                self.logger.info(f"Trying to load with encoding: {encoding}")
                df = pd.read_csv(file_path, encoding=encoding)
                self.logger.info(f"Successfully loaded with {encoding}")
                return df
            except UnicodeDecodeError:
                continue
            except Exception as e:
                self.logger.warning(f"Failed with {encoding}: {e}")
                continue

        try:
            self.logger.info("Loading with error handling (replacing bad characters)")
            df = pd.read_csv(file_path, encoding='utf-8', encoding_errors='replace')
            return df
        except Exception as e:
            self.logger.error(f"All encoding attempts failed: {e}")
            return None

    def clean_dataset(self, df: pd.DataFrame, aggressive_cleaning: bool = False) -> pd.DataFrame:
        """
        Clean Russian text normalization dataset with robust character handling.

        Args:
            df: DataFrame with columns ['sentence_id', 'token_id', 'class', 'before', 'after']
            aggressive_cleaning: If True, remove more suspicious data (set to False by default)
        """
        self.logger.info(f"Starting robust cleaning of {len(df)} rows")
        original_size = len(df)

        df = self._handle_basic_issues(df)
        self._log_step("Basic data handling", original_size, len(df))

        df = self._fix_character_encoding(df)
        self._log_step("Character encoding fix", original_size, len(df))

        df = self._clean_problematic_characters(df, gentle=True)
        self._log_step("Problematic character removal", original_size, len(df))

        if aggressive_cleaning:
            df = self._validate_normalizations(df)
            self._log_step("Normalization validation", original_size, len(df))

        df = self._handle_clear_duplicates(df)
        self._log_step("Clear duplicate handling", original_size, len(df))

        df = self._minimal_final_cleanup(df)
        self._log_step("Minimal final cleanup", original_size, len(df))

        cleaned_size = len(df)
        removal_percentage = ((original_size - cleaned_size) / original_size) * 100
        self.logger.info(f"Gentle cleaning complete: {original_size} -> {cleaned_size} rows "
                        f"({removal_percentage:.1f}% removed)")

        return df.reset_index(drop=True)

    def _handle_basic_issues(self, df: pd.DataFrame) -> pd.DataFrame:
        """Handle basic data integrity issues."""
        required_cols = ['before', 'after']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing required columns: {missing_cols}")

        df['before'] = df['before'].astype(str).replace('nan', '')
        df['after'] = df['after'].astype(str).replace('nan', '')

        initial_len = len(df)
        df = df[(df['before'].str.strip() != '') & (df['after'].str.strip() != '')]
        self.logger.info(f"Removed {initial_len - len(df)} clearly empty rows")

        return df

    def _fix_character_encoding(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fix character encoding issues."""
        def fix_encoding(text):
            if pd.isna(text):
                return ""

            text = str(text)

            try:
                text = unicodedata.normalize('NFKC', text)
            except:
                pass

            encoding_fixes = {
                'Ã¡': 'а', 'Ã ': 'а', 'Ã«': 'е', 'Ã¬': 'и', 'Ã®': 'о', 'Ã³': 'у',
                'â€œ': '"', 'â€': '"', 'â€™': "'", 'â€"': '–', 'â€"': '—'
            }

            for wrong, correct in encoding_fixes.items():
                text = text.replace(wrong, correct)

            return text

        df['before'] = df['before'].apply(fix_encoding)
        df['after'] = df['after'].apply(fix_encoding)

        return df

    def _clean_problematic_characters(self, df: pd.DataFrame, gentle: bool = True) -> pd.DataFrame:
        """Remove problematic characters with gentle approach."""
        def clean_text(text, gentle_mode=True):
            if pd.isna(text):
                return ""

            text = str(text)

            if gentle_mode:
                problematic_chars = set(['', '', '', '﻿', '\x00', '\x01', '\x02', '\x03'])
                text = ''.join(char for char in text if char not in problematic_chars)

                text = re.sub(r'\s+', ' ', text).strip()
            else:
                text = ''.join(char for char in text if char in self.valid_chars or char.isspace())
                text = re.sub(r'\s+', ' ', text).strip()

            return text

        initial_len = len(df)

        df['before'] = df['before'].apply(lambda x: clean_text(x, gentle))
        df['after'] = df['after'].apply(lambda x: clean_text(x, gentle))

        df = df[(df['before'].str.strip() != '') & (df['after'].str.strip() != '')]

        self.logger.info(f"Character cleaning removed {initial_len - len(df)} rows")
        return df

    def _validate_normalizations(self, df: pd.DataFrame) -> pd.DataFrame:
        """Validate normalizations (only if aggressive cleaning is enabled)."""
        initial_len = len(df)

        def is_valid_normalization(before, after):
            if len(after) > len(before) * 5:
                return False

            if len(after) < len(before) * 0.1 and len(before) > 10:
                return False

            return True

        mask = df.apply(lambda row: is_valid_normalization(row['before'], row['after']), axis=1)
        df = df[mask]

        self.logger.info(f"Normalization validation removed {initial_len - len(df)} rows")
        return df

    def _handle_clear_duplicates(self, df: pd.DataFrame) -> pd.DataFrame:
        """Handle only clear duplicates."""
        initial_len = len(df)

        df = df.drop_duplicates(subset=['before', 'after'])

        self.logger.info(f"Duplicate removal: {initial_len - len(df)} exact duplicates removed")
        return df

    def _minimal_final_cleanup(self, df: pd.DataFrame) -> pd.DataFrame:
        """Minimal final cleanup."""
        initial_len = len(df)

        df = df[df['before'].str.len() <= 500]
        df = df[df['after'].str.len() <= 600]

        self.logger.info(f"Final cleanup removed {initial_len - len(df)} overly long texts")
        return df

    def _log_step(self, step_name: str, original_size: int, current_size: int):
        """Log cleaning step results."""
        removed = original_size - current_size
        percentage = (removed / original_size) * 100 if original_size > 0 else 0
        self.logger.info(f"{step_name}: {original_size} -> {current_size} "
                        f"({removed} removed, {percentage:.1f}%)")

class StratifiedDataSampler:
    """Intelligent sampling to get a representative subset of the large dataset."""

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.logger = logging.getLogger(__name__)

    def sample_dataset(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Create a stratified sample of the dataset for fast training.

        Args:
            df: Full cleaned dataset

        Returns:
            Stratified sample for training
        """
        self.logger.info(f"Creating stratified sample of {self.config.sample_size} from {len(df)} rows")

        if len(df) <= self.config.sample_size:
            self.logger.info("Dataset smaller than sample size, using full dataset")
            return df

        if 'class' in df.columns and self.config.stratify_by_class:
            df_sampled = self._stratify_by_class(df)
        else:
            df_sampled = df.copy()

        if self.config.stratify_by_length:
            df_sampled = self._stratify_by_length(df_sampled)

        if len(df_sampled) > self.config.sample_size:
            df_sampled = df_sampled.sample(n=self.config.sample_size, random_state=42)

        self.logger.info(f"Final sample size: {len(df_sampled)}")
        return df_sampled.reset_index(drop=True)

    def _stratify_by_class(self, df: pd.DataFrame) -> pd.DataFrame:
        """Stratify sampling by normalization class."""
        class_counts = df['class'].value_counts()
        self.logger.info(f"Class distribution: {dict(class_counts)}")

        total_classes = len(class_counts)
        base_samples_per_class = self.config.sample_size // total_classes

        sampled_dfs = []
        for class_name in class_counts.index:
            class_df = df[df['class'] == class_name]

            samples_needed = min(len(class_df), max(base_samples_per_class, 100))

            if len(class_df) > samples_needed:
                class_sample = class_df.sample(n=samples_needed, random_state=42)
            else:
                class_sample = class_df

            sampled_dfs.append(class_sample)
            self.logger.info(f"Class '{class_name}': {len(class_sample)} samples")

        return pd.concat(sampled_dfs, ignore_index=True)

    def _stratify_by_length(self, df: pd.DataFrame) -> pd.DataFrame:
        """Stratify sampling by text length to ensure diverse examples."""
        df['text_length'] = df['before'].str.len()

        df['length_category'] = pd.cut(
            df['text_length'],
            bins=[0, 20, 50, 100, float('inf')],
            labels=['short', 'medium', 'long', 'very_long']
        )

        length_counts = df['length_category'].value_counts()
        self.logger.info(f"Length distribution: {dict(length_counts)}")

        samples_per_length = self.config.sample_size

        sampled_dfs = []
        for length_cat in ['short', 'medium', 'long', 'very_long']:
            cat_df = df[df['length_category'] == length_cat]

            if len(cat_df) == 0:
                continue

            samples_needed = min(len(cat_df), samples_per_length)

            if len(cat_df) > samples_needed:
                cat_sample = cat_df.sample(n=samples_needed, random_state=42)
            else:
                cat_sample = cat_df

            sampled_dfs.append(cat_sample)
            self.logger.info(f"Length '{length_cat}': {len(cat_sample)} samples")

        result_df = pd.concat(sampled_dfs, ignore_index=True)

        return result_df.drop(['text_length', 'length_category'], axis=1)

class FastTextNormalizationDataset(Dataset):
    """Optimized dataset class for faster training."""

    def __init__(
        self,
        data: List[Tuple[str, str]],
        tokenizer: T5Tokenizer,
        max_source_length: int = 96,
        max_target_length: int = 96
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

        logger.info(f"Initialized fast dataset with {len(self.data)} examples")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        source_text, target_text = self.data[idx]

        source_text = f"normalize: {source_text}"

        source_encoding = self.tokenizer(
            source_text,
            max_length=self.max_source_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        target_encoding = self.tokenizer(
            target_text,
            max_length=self.max_target_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': source_encoding['input_ids'].flatten(),
            'attention_mask': source_encoding['attention_mask'].flatten(),
            'labels': target_encoding['input_ids'].flatten()
        }

class FastT5TextNormalizer:
    """Optimized T5 text normalizer for fast training."""

    def __init__(self, config: TrainingConfig):
        self.config = config
        if torch.backends.mps.is_available():
            self.device = torch.device('mps')
            print("Using Apple Silicon MPS backend")
        elif torch.cuda.is_available():
            self.device = torch.device('cuda')
            print("Using CUDA")
        else:
            self.device = torch.device('cpu')
            print("Using CPU")

        logger.info(f"Initializing Fast T5 model on device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(config.model_name)
        self.model.to(self.device)

        self.use_amp = config.mixed_precision and (self.device.type in ['cuda', 'mps'])
        if self.use_amp:
            self.scaler = GradScaler()
            print("Using Automatic Mixed Precision")

        if config.compile_model and hasattr(torch, 'compile'):
            self.model = torch.compile(self.model)
            print("Model compiled with PyTorch 2.0")
        if hasattr(self.model, 'gradient_checkpointing_enable'):
            self.model.gradient_checkpointing_enable()
        if hasattr(self.model.config, 'use_cache'):
            self.model.config.use_cache = False
        if self.device.type == 'mps':
            torch.mps.empty_cache()

        logger.info(f"Loaded model: {config.model_name}")

    def train_fast(self, train_data: List[Tuple[str, str]], val_data: List[Tuple[str, str]]):
        """Fast training method (as I myself cannot do long trainings :[ )."""
        logger.info(f"Starting FAST T5 training with {len(train_data)} train, {len(val_data)} val examples")

        if self.config.use_wandb:
            wandb.init(
                project="russian-text-normalization-fast",
                config=self.config.__dict__,
                name=f"ruT5-fast-{self.config.sample_size}"
            )

        train_dataset = FastTextNormalizationDataset(
            train_data, self.tokenizer,
            self.config.max_source_length, self.config.max_target_length
        )
        val_dataset = FastTextNormalizationDataset(
            val_data, self.tokenizer,
            self.config.max_source_length, self.config.max_target_length
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.dataloader_num_workers,
            pin_memory=self.config.pin_memory,
            drop_last=True,
            persistent_workers=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        optimizer = AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            eps=1e-6
        )

        total_steps = len(train_loader) * self.config.num_epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=total_steps
        )

        self.model.train()
        global_step = 0
        best_val_loss = float('inf')

        for epoch in range(self.config.num_epochs):
            logger.info(f"Starting epoch {epoch + 1}/{self.config.num_epochs}")

            epoch_loss = 0
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")

            for batch_idx, batch in enumerate(progress_bar):
                batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}

                if self.use_amp:
                    with autocast():
                        outputs = self.model(**batch)
                        loss = outputs.loss

                    self.scaler.scale(loss).backward()
                    self.scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.gradient_clip_val
                    )
                    self.scaler.step(optimizer)
                    self.scaler.update()
                else:
                    outputs = self.model(**batch)
                    loss = outputs.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.gradient_clip_val
                    )
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad()

                epoch_loss += loss.item()
                global_step += 1

                progress_bar.set_postfix({
                    'loss': f"{loss.item():.4f}",
                    'lr': f"{scheduler.get_last_lr()[0]:.2e}"
                })

                if self.config.use_wandb and global_step % 50 == 0:
                    wandb.log({
                        "train_loss": loss.item(),
                        "learning_rate": scheduler.get_last_lr()[0],
                        "epoch": epoch,
                        "global_step": global_step
                    })

                if global_step % self.config.eval_steps == 0:
                    val_loss = self._quick_validate(val_loader)
                    logger.info(f"Step {global_step}: Val Loss = {val_loss:.4f}")

                    if self.config.use_wandb:
                        wandb.log({"val_loss": val_loss, "global_step": global_step})

                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        self._save_model("best_model_fast")
                        logger.info(f"New best model saved! Val loss: {val_loss:.4f}")

                    self.model.train()
                if self.device.type == 'mps' and batch_idx % 100 == 0:
                    torch.mps.empty_cache()

            avg_epoch_loss = epoch_loss / len(train_loader)
            logger.info(f"Epoch {epoch + 1} completed. Avg Loss: {avg_epoch_loss:.4f}")


        self._save_model("final_model_fast")
        logger.info("Fast training completed!")

        if self.config.use_wandb:
            wandb.finish()

    def _quick_validate(self, val_loader: DataLoader) -> float:
        """Quick validation on subset of validation data."""
        self.model.eval()
        total_loss = 0
        num_batches = 0
        max_val_batches = 10

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                if batch_idx >= max_val_batches:
                    break

                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

        return total_loss / num_batches if num_batches > 0 else float('inf')

    def _save_model(self, name: str):
        """Save model and tokenizer."""
        save_path = Path(f"./models/{name}")
        save_path.mkdir(parents=True, exist_ok=True)

        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

        with open(save_path / "config.json", "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        logger.info(f"Model saved to {save_path}")

    def normalize_text(self, text: str) -> str:
        """Normalize text using trained T5 model."""
        if not text or not isinstance(text, str):
            return ""

        try:
            input_text = f"normalize: {text}"

            inputs = self.tokenizer(
                input_text,
                max_length=self.config.max_source_length,
                padding=True,
                truncation=True,
                return_tensors='pt'
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_length=self.config.max_target_length,
                    num_beams=2,
                    length_penalty=0.6,
                    early_stopping=True,
                    do_sample=False
                )

            # Decode
            normalized_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            return normalized_text.strip()

        except Exception as e:
            logger.error(f"Error normalizing text: {str(e)}")
            return text

def fast_training_pipeline(data_path: str):
    """Optimized training pipeline for quick experimentation."""
    logger.info("Starting FAST T5 Russian Text Normalization Training")

    config = TrainingConfig(
        model_name="ai-forever/ruT5-base",
        batch_size=32,
        learning_rate=5e-4,
        num_epochs=3,
        sample_size=50000,
        use_wandb=True,
        max_source_length=96,
        max_target_length=96
    )


    cleaner = RobustRussianTextCleaner()
    sampler = StratifiedDataSampler(config)


    logger.info("Loading and cleaning data...")
    df = cleaner.safe_load_csv(data_path)
    df_clean = cleaner.clean_dataset(df, aggressive_cleaning=False)  # Gentle cleaning


    df_sample = sampler.sample_dataset(df_clean)


    if 'before' in df_sample.columns and 'after' in df_sample.columns:
        data_pairs = list(zip(df_sample['before'].astype(str), df_sample['after'].astype(str)))
    else:
        raise ValueError("Expected 'before' and 'after' columns")

    train_data, val_data = train_test_split(data_pairs, test_size=0.1, random_state=42)

    logger.info(f"Training on {len(train_data)} examples, validating on {len(val_data)}")

    model = FastT5TextNormalizer(config)
    model.train_fast(train_data, val_data)

    logger.info("Running quick evaluation...")
    correct = 0
    total = 0

    for i, (source, target) in enumerate(val_data[:100]):
        pred = model.normalize_text(source)
        if pred.strip().lower() == target.strip().lower():
            correct += 1
        total += 1

        if i < 5:
            logger.info(f"Example {i+1}:")
            logger.info(f"  Input:  {source}")
            logger.info(f"  Target: {target}")
            logger.info(f"  Pred:   {pred}")
            logger.info(f"  Match:  {pred.strip().lower() == target.strip().lower()}")

    accuracy = correct / total
    logger.info(f"Quick accuracy on 100 examples: {accuracy:.4f}")

    return model



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import os
if __name__ == '__main__':
    import multiprocessing
    multiprocessing.set_start_method('fork', force=True)

    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'

    torch.set_float32_matmul_precision('medium')

    torch.set_num_threads(8)

    DATA_PATH = 'data/to_normalize.csv'
    OUT_PATH = 'data/normalized_llm.csv'
    model = fast_training_pipeline("data/ru_train.csv")

2025-05-31 22:11:13,014 - __main__ - INFO - Starting FAST T5 Russian Text Normalization Training
2025-05-31 22:11:13,017 - __main__ - INFO - Loading and cleaning data...
2025-05-31 22:11:13,018 - __main__ - INFO - Trying to load with encoding: utf-8
2025-05-31 22:11:18,533 - __main__ - INFO - Successfully loaded with utf-8
2025-05-31 22:11:18,534 - __main__ - INFO - Starting robust cleaning of 10574516 rows
2025-05-31 22:11:22,155 - __main__ - INFO - Removed 14 clearly empty rows
2025-05-31 22:11:22,156 - __main__ - INFO - Basic data handling: 10574516 -> 10574502 (14 removed, 0.0%)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['before'] = df['before'].apply(fix_encoding)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col

Using Apple Silicon MPS backend


  self.scaler = GradScaler()
2025-05-31 22:12:22,705 - __main__ - INFO - Loaded model: ai-forever/ruT5-base
2025-05-31 22:12:22,705 - __main__ - INFO - Starting FAST T5 training with 12094 train, 1344 val examples


Using Automatic Mixed Precision


2025-05-31 22:13:39,764 - __main__ - INFO - Initialized fast dataset with 12094 examples
2025-05-31 22:13:39,765 - __main__ - INFO - Initialized fast dataset with 1344 examples
2025-05-31 22:13:39,770 - __main__ - INFO - Starting epoch 1/3
  with autocast():
Epoch 1:  66%|██████▌   | 249/377 [10:33<05:13,  2.45s/it, loss=0.3924, lr=4.17e-04]

In [33]:
import time
import psutil

class PerformanceMonitor:
    def __init__(self):
        self.start_time = None
        self.step_times = []

    def start_epoch(self):
        self.start_time = time.time()

    def log_step(self, step, loss):
        current_time = time.time()
        if self.start_time:
            step_time = current_time - self.start_time
            self.step_times.append(step_time)

            if step % 50 == 0:
                avg_step_time = sum(self.step_times[-50:]) / len(self.step_times[-50:])
                memory_usage = psutil.virtual_memory().percent
                print(f"Step {step}: Loss={loss:.4f}, Avg Step Time={avg_step_time:.2f}s, Memory={memory_usage:.1f}%")

        self.start_time = current_time