In [1]:
import os
import random
import pandas as pd
import torch
from torch.utils.data import Dataset as TorchDataset
from transformers import T5ForConditionalGeneration, T5Tokenizer, Trainer, TrainingArguments, IntervalStrategy
import nltk
from tqdm import tqdm



In [2]:
nltk.download('punkt', quiet=True)

True

In [3]:
XLA_AVAILABLE = False
try:
    import torch_xla.core.xla_model as xm
    XLA_AVAILABLE = True
    print("TPU/XLA обнаружен. Обучение будет использовать XLA-бэкэнд.")
except ImportError:
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
        print(f"CUDA обнаружена. Используемое устройство: {DEVICE}")
    else:
        DEVICE = torch.device("cpu")
        print(f"CUDA не найдена. Используемое устройство: {DEVICE}")

TPU/XLA обнаружен. Обучение будет использовать XLA-бэкэнд.


In [4]:
class SpellingCorrectionDataset(TorchDataset):
    """Класс Pytorch Dataset для работы с Pandas DataFrame."""
    def __init__(self, dataframe, tokenizer, max_length=64):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.max_length = max_length
        self.input_texts = self.data['input_text'].tolist()
        self.target_texts = self.data['target_text'].tolist()

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, index):
        input_text = self.input_texts[index]
        target_text = self.target_texts[index]

        input_encoding = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        target_encoding = self.tokenizer(
            text_target=target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': input_encoding['input_ids'].flatten(),
            'attention_mask': input_encoding['attention_mask'].flatten(),
            'labels': target_encoding['input_ids'].flatten()
        }

In [None]:
try:
    FULL_TRAIN_DF = pd.read_csv("./train_words.csv")
    FULL_VAL_DF = pd.read_csv("./val_words.csv")
    FULL_TEST_DF = pd.read_csv("./test_words.csv")

    print(f"Полные данные загружены: Train={len(FULL_TRAIN_DF)}, Val={len(FULL_VAL_DF)}, Test={len(FULL_TEST_DF)}")
except FileNotFoundError:
    print("Ошибка: CSV файлы датасета не найдены.")
    exit()


TRAIN_DF = FULL_TRAIN_DF.sample(n=int(len(FULL_TRAIN_DF) * 0.1), random_state=42).reset_index(drop=True)
VAL_DF = FULL_VAL_DF.sample(n=int(len(FULL_VAL_DF) * 0.1), random_state=42).reset_index(drop=True)
TEST_DF = FULL_TEST_DF.sample(n=int(len(FULL_TEST_DF) * 0.1), random_state=42).reset_index(drop=True)
print(f"Данные загружены: Train={len(TRAIN_DF)}, Val={len(VAL_DF)}, Test={len(TEST_DF)}")

model_name = "cointegrated/rut5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

if not XLA_AVAILABLE:
    model.to(DEVICE)


train_dataset = SpellingCorrectionDataset(TRAIN_DF, tokenizer)
val_dataset = SpellingCorrectionDataset(VAL_DF, tokenizer)
EST_DF = FULL_TEST_DF.sample(n=len(FULL_TEST_DF)//10, random_state=42)

NUM_EPOCHS_TEST = 3

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy=IntervalStrategy.EPOCH,
    save_strategy=IntervalStrategy.EPOCH,

    optim="adamw_torch",

    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=NUM_EPOCHS_TEST,
    weight_decay=0.01,
    save_total_limit=1,
    load_best_model_at_end=True,
    report_to="none",
    logging_steps=50
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

print(f"Обучение в {NUM_EPOCHS_TEST} эпох")
trainer.train()

output_dir = "./"

model.save_pretrained(output_dir, safe_serialization=False)
tokenizer.save_pretrained(output_dir)

def calculate_cer(reference, hypothesis):
    reference = reference.replace(' ', '')
    hypothesis = hypothesis.replace(' ', '')
    if len(reference) == 0:
        return 0.0
    return nltk.edit_distance(reference, hypothesis) / len(reference)

def correct_word(input_word_only, current_model, current_tokenizer):
    prefixed_text = 'fix spelling: ' + input_word_only
    inputs = current_tokenizer(prefixed_text, return_tensors="pt", max_length=64, truncation=True, padding="max_length")

    model_device = current_model.device
    input_ids = inputs.input_ids.to(model_device)
    attention_mask = inputs.attention_mask.to(model_device)
    if XLA_AVAILABLE and model_device.type == 'xla':
        with torch.no_grad():
            outputs = current_model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=64,
                num_beams=4,
                early_stopping=True,
            ).cpu()
    else:
        with torch.no_grad():
            outputs = current_model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=64,
                num_beams=4,
                early_stopping=True
            )

    return current_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()


total_cer = 0
N = len(TEST_DF)


test_data_for_eval = TEST_DF.to_dict('records')

for row in tqdm(test_data_for_eval, desc="Тестирование"):
    target = row['target_text']
    input_word_only = row['input_text'].replace('fix spelling: ', '')
    predicted = correct_word(input_word_only, model, tokenizer)
    current_cer = calculate_cer(target, predicted)
    total_cer += current_cer

final_cer = total_cer / N
print(f"CER НА ТЕСТОВОМ ДАТАСЕТЕ: {final_cer:.4f}")

Полные данные загружены: Train=38856, Val=4857, Test=4857
Данные загружены: Train=3885, Val=485, Test=485


You are using a model of type mt5 to instantiate a model of type t5. This is not supported for all configurations of models and can yield errors.


Обучение в 3 эпох




Epoch,Training Loss,Validation Loss
1,0.1557,0.141744
2,0.1455,0.134988
3,0.135,0.132068


Тестирование: 100%|██████████| 485/485 [08:58<00:00,  1.11s/it]

CER НА ТЕСТОВОМ ДАТАСЕТЕ: 0.3971



