In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import (
    MBartForConditionalGeneration, MBart50TokenizerFast,
    DistilBertModel, DistilBertTokenizer,
    get_linear_schedule_with_warmup, GenerationConfig
)
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import random
from torch.cuda.amp import autocast, GradScaler
import itertools

In [None]:
# Seed every RNG the notebook touches and pin cuDNN to deterministic kernels.
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
np.random.seed(42)
random.seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Control tags prepended to the generator input to select the target style
# (registered as special tokens of the generator tokenizer further below).
TAG_TO_LIT = "[TO_LIT]"
TAG_TO_CONV = "[TO_CONV]"

# Checkpoints for the seq2seq generator and for the frozen style classifier.
GEN_MODEL_NAME = "sn4kebyt3/ru-bart-large"
CLASSIFIER_MODEL_NAME = "DeepPavlov/distilrubert-base-cased-conversational"
MAX_LENGTH = 128  # tokenizer padding/truncation length and generation max_length
NUM_EPOCHS = 12

# An "epoch" in the training loop is STEPS_PER_EPOCH optimizer steps,
# not a full pass over the training data.
BATCH_SIZE = 96
STEPS_PER_EPOCH = 125

LEARNING_RATE_GEN = 2e-5
LEARNING_RATE_DISC = 4e-5
# Fraction of the data held out of training; VAL_TEST_SPLIT is the share of
# that hold-out that is split away from the validation part.
TEST_SIZE = 0.1
VAL_TEST_SPLIT = 0.5
RANDOM_STATE = 42
# Output locations for the best checkpoints and the training plot.
SAVE_DIR = "tag_cyclegan_bart_final_v2"
MODEL_PATH_G_BEST = os.path.join(SAVE_DIR, "G_bart_best.pth")
MODEL_PATH_D_C_BEST = os.path.join(SAVE_DIR, "D_C_best.pth")
MODEL_PATH_D_L_BEST = os.path.join(SAVE_DIR, "D_L_best.pth")
PLOT_PATH = os.path.join(SAVE_DIR, "training_plots_final.png")

# Weights of the individual terms in the generator loss.
LAMBDA_CYCLE = 4.0
LAMBDA_IDENTITY = 4.0
LAMBDA_STYLE = 12.0
LAMBDA_ADV = 1.5

# Decoding settings used inside training steps; model-specific token ids are
# filled in later by update_generation_config.
base_gen_config_train = GenerationConfig(
    max_length=MAX_LENGTH, min_length=5, num_beams=3,
    early_stopping=True, temperature=1.0,
    repetition_penalty=1.0, no_repeat_ngram_size=0
)
# Decoding settings used for validation: longer minimum length plus
# repetition controls.
base_gen_config_val_log = GenerationConfig(
    max_length=MAX_LENGTH, min_length=10, num_beams=3,
    early_stopping=True, temperature=1.0,
    repetition_penalty=1.2, no_repeat_ngram_size=4
)
# Number of validation texts per direction scored by the style classifier
# (one batch worth).
NUM_VAL_STYLE_EXAMPLES = BATCH_SIZE

Using device: cuda


In [None]:
# Load the two text columns: 'tg_text' (conversational, "C") and
# 'lit_text' (literary, "L").
try:
    data = pd.read_csv("data.csv")
    # NOTE(review): astype(str) presumably turns missing cells into the literal
    # string "nan" rather than dropping them - confirm the CSV has no gaps.
    texts_C_full = list(data['tg_text'].astype(str))
    texts_L_full = list(data['lit_text'].astype(str))
    print(f"Загружено {len(texts_C_full)} разговорных и {len(texts_L_full)} литературных текстов.")
except Exception as e:
    # Log the failure and re-raise: nothing below can run without the data.
    print(f"Ошибка загрузки данных: {e}"); raise

# Hold out TEST_SIZE of each corpus, then keep only the first part of that
# hold-out as validation; the remaining part (bound to `_`) is discarded here.
train_texts_C, temp_texts_C = train_test_split(texts_C_full, test_size=TEST_SIZE, random_state=RANDOM_STATE)
train_texts_L, temp_texts_L = train_test_split(texts_L_full, test_size=TEST_SIZE, random_state=RANDOM_STATE)
val_texts_C, _ = train_test_split(temp_texts_C, test_size=VAL_TEST_SPLIT, random_state=RANDOM_STATE)
val_texts_L, _ = train_test_split(temp_texts_L, test_size=VAL_TEST_SPLIT, random_state=RANDOM_STATE)
print(f"Размеры датасетов: Train C: {len(train_texts_C)}, Val C: {len(val_texts_C)}")
print(f"                   Train L: {len(train_texts_L)}, Val L: {len(val_texts_L)}")

Загружено 1382549 разговорных и 1382549 литературных текстов.
Размеры датасетов: Train C: 1244294, Val C: 69127
                   Train L: 1244294, Val L: 69127


In [None]:
# Build the generator tokenizer and the style-classifier tokenizer.
try:
    gen_tokenizer = MBart50TokenizerFast.from_pretrained(GEN_MODEL_NAME)
    style_tokenizer = DistilBertTokenizer.from_pretrained(CLASSIFIER_MODEL_NAME)

    # Register the two style tags as special tokens; this grows the vocabulary,
    # so the generator's embedding matrix is resized to len(gen_tokenizer) when
    # the Generator is constructed.
    special_tokens_to_add = {'additional_special_tokens': [TAG_TO_LIT, TAG_TO_CONV]}
    num_added_toks = gen_tokenizer.add_special_tokens(special_tokens_to_add)
    print(f"Добавлено {num_added_toks} спец. токенов в gen_tokenizer: {TAG_TO_LIT}, {TAG_TO_CONV} (Новый размер словаря: {len(gen_tokenizer)})")

    # Id of the Russian language code; falls back to the EOS id (with a
    # warning) when the tokenizer has no "ru_RU" entry.
    RUSSIAN_TOKEN_ID = gen_tokenizer.lang_code_to_id.get("ru_RU", gen_tokenizer.eos_token_id)
    if RUSSIAN_TOKEN_ID == gen_tokenizer.eos_token_id and "ru_RU" not in gen_tokenizer.lang_code_to_id:
        print(f"ПРЕДУПРЕЖДЕНИЕ: Токен языка 'ru_RU' не найден. Используется eos_token_id ({RUSSIAN_TOKEN_ID})")
    print(f"ID токена русского языка для mBART генератора: {RUSSIAN_TOKEN_ID}")
except Exception as e:
    # Log the failure and re-raise: the rest of the notebook needs both tokenizers.
    print(f"Ошибка инициализации токенизаторов: {e}"); raise



Добавлено 2 спец. токенов в gen_tokenizer: [TO_LIT], [TO_CONV] (Новый размер словаря: 24263)
ID токена русского языка для mBART генератора: 24228


In [None]:
def update_generation_config(base_config, model_config_obj, tokenizer_pad_token_id):
    """Return a copy of ``base_config`` with model-specific token ids filled in.

    Args:
        base_config: GenerationConfig holding the decoding hyperparameters.
            It is copied through its dict form and never modified.
        model_config_obj: Object exposing ``decoder_start_token_id`` and
            ``eos_token_id``, and optionally ``forced_bos_token_id``.
        tokenizer_pad_token_id: Pad token id to store in the copy.

    Returns:
        A new GenerationConfig. ``forced_bos_token_id`` is overwritten only
        when the model config defines a non-None value for it.
    """
    derived_config = GenerationConfig.from_dict(base_config.to_dict())

    derived_config.decoder_start_token_id = model_config_obj.decoder_start_token_id
    derived_config.eos_token_id = model_config_obj.eos_token_id
    derived_config.pad_token_id = tokenizer_pad_token_id

    forced_bos = getattr(model_config_obj, 'forced_bos_token_id', None)
    if forced_bos is not None:
        derived_config.forced_bos_token_id = forced_bos

    return derived_config

In [None]:
class StyleDataset(Dataset):
    """Map-style dataset that tokenizes raw texts for the generator on the fly.

    Every item is a dict containing:
      * the tokenizer outputs for the (optionally tag-prefixed) text, each with
        the leading batch dimension squeezed away (``input_ids``,
        ``attention_mask`` and whatever else the tokenizer returns);
      * ``original_text_str``: the untagged source string;
      * ``original_ids`` / ``original_mask``: tokenization of the untagged text;
      * ``labels`` (only when ``create_identity_labels`` is true): a copy of
        ``original_ids``.

    Args:
        texts_list: Iterable of texts; every element is coerced with ``str``.
        generator_tokenizer: Callable tokenizer returning a mapping with
            ``input_ids`` and ``attention_mask`` tensors.
        max_sequence_length: Padding / truncation length.
        generator_input_tag: Optional style tag prepended (followed by one
            space) to the generator input.
        create_identity_labels: Whether to add the ``labels`` entry.
    """

    def __init__(self, texts_list, generator_tokenizer, max_sequence_length,
                 generator_input_tag=None, create_identity_labels=False):
        self.texts = [str(text) for text in texts_list]
        self.generator_tokenizer = generator_tokenizer
        self.max_sequence_length = max_sequence_length
        self.generator_input_tag = generator_input_tag
        self.create_identity_labels = create_identity_labels

    def __len__(self):
        return len(self.texts)

    def _encode(self, text):
        """Tokenize one string, padded/truncated to ``max_sequence_length``."""
        return self.generator_tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_sequence_length,
            return_tensors='pt',
        )

    def __getitem__(self, idx):
        original_text_string = self.texts[idx]
        has_tag = bool(self.generator_input_tag)
        input_text_for_g = f"{self.generator_input_tag} {original_text_string}" if has_tag else original_text_string

        tokenized_g_input = self._encode(input_text_for_g)
        item = {key: val.squeeze(0) for key, val in tokenized_g_input.items()}
        item['original_text_str'] = original_text_string

        if has_tag:
            # The tagged input differs from the raw text, so the raw text needs
            # its own tokenization.
            tokenized_original = self._encode(original_text_string)
            item['original_ids'] = tokenized_original['input_ids'].squeeze(0)
            item['original_mask'] = tokenized_original['attention_mask'].squeeze(0)
        else:
            # Without a tag the generator input IS the raw text: reuse the
            # encoding instead of tokenizing the same string a second time.
            # Clones keep the entries independent tensors, as before.
            item['original_ids'] = item['input_ids'].clone()
            item['original_mask'] = item['attention_mask'].clone()

        if self.create_identity_labels:
            # NOTE(review): labels keep the pad token ids. Hugging Face seq2seq
            # losses presumably ignore only -100, so padded positions would
            # count towards the loss - confirm this is intended.
            item['labels'] = item['original_ids'].clone()
        return item

In [None]:
# Training data: untagged conversational (C) and literary (L) corpora, each
# shuffled independently; incomplete final batches are dropped.
train_dataset_C_main = StyleDataset(train_texts_C, gen_tokenizer, MAX_LENGTH)
train_dataset_L_main = StyleDataset(train_texts_L, gen_tokenizer, MAX_LENGTH)
train_dataloader_C = DataLoader(
    train_dataset_C_main,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)
train_dataloader_L = DataLoader(
    train_dataset_L_main,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)

# Identity validation: each text is tagged with its OWN style and carries
# labels, so the generator is scored on reproducing its input.
val_dataset_L_identity_main = StyleDataset(
    val_texts_L, gen_tokenizer, MAX_LENGTH,
    generator_input_tag=TAG_TO_LIT, create_identity_labels=True,
)
val_dataset_C_identity_main = StyleDataset(
    val_texts_C, gen_tokenizer, MAX_LENGTH,
    generator_input_tag=TAG_TO_CONV, create_identity_labels=True,
)
val_dataloader_L_identity = DataLoader(
    val_dataset_L_identity_main,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)
val_dataloader_C_identity = DataLoader(
    val_dataset_C_identity_main,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)

# Style validation: untagged texts; the validation routine adds the target
# tag itself before generating.
val_dataset_C_for_style = StyleDataset(val_texts_C, gen_tokenizer, MAX_LENGTH)
val_dataset_L_for_style = StyleDataset(val_texts_L, gen_tokenizer, MAX_LENGTH)
val_dataloader_C_for_style = DataLoader(
    val_dataset_C_for_style,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)
val_dataloader_L_for_style = DataLoader(
    val_dataset_L_for_style,
    batch_size=BATCH_SIZE,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)

In [None]:
# --- Models ---
class Generator(nn.Module):
    """Seq2seq generator wrapping a pretrained MBartForConditionalGeneration.

    The embedding matrix is resized to the (extended) tokenizer vocabulary and
    the wrapped model's config is updated so that both the forced BOS token
    and the decoder start token are the given language token.

    Note: the pad token id is read from the module-level ``gen_tokenizer``.

    Args:
        model_name_str: Checkpoint name or path for ``from_pretrained``.
        tokenizer_vocabulary_size: Target size of the token embedding matrix.
        russian_language_token_id: Id used as forced BOS / decoder start token.
    """

    def __init__(self, model_name_str, tokenizer_vocabulary_size, russian_language_token_id):
        super().__init__()
        bart = MBartForConditionalGeneration.from_pretrained(model_name_str)
        bart.resize_token_embeddings(tokenizer_vocabulary_size)
        self.bart_model = bart

        # Despite the attribute name this is the model's own config object,
        # mutated in place.
        self.generation_config_internal = bart.config
        self.generation_config_internal.forced_bos_token_id = russian_language_token_id
        self.generation_config_internal.decoder_start_token_id = russian_language_token_id

        self.pad_token_id = gen_tokenizer.pad_token_id
        print(f"Генератор инициализирован: forced_bos={self.generation_config_internal.forced_bos_token_id}, dec_start={self.generation_config_internal.decoder_start_token_id}")

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model; its output is returned unchanged."""
        model_inputs = dict(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return self.bart_model(**model_inputs)

    def generate_texts(self, input_ids, attention_mask, external_generation_config):
        """Decode token ids for the inputs using the supplied generation config."""
        return self.bart_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=external_generation_config,
        )

    def get_attention_mask_for_generated(self, generated_ids_tensor):
        """Return an int64 mask: 1 where the id is not the pad token, else 0."""
        not_padding = generated_ids_tensor.ne(self.pad_token_id)
        return not_padding.long()

In [None]:
class Discriminator(nn.Module):
    """Token-level CNN that scores a sequence with a single real/fake logit.

    Pipeline: embedding (128-d) -> two Conv1d/ReLU/MaxPool1d stages
    (256 then 512 channels) -> average over the remaining positions -> linear.

    Args:
        vocabulary_size: Number of rows of the embedding table.
        max_sequence_length: Accepted for interface compatibility; not used.
    """

    def __init__(self, vocabulary_size, max_sequence_length):
        super().__init__()
        self.embedding_layer = nn.Embedding(vocabulary_size, 128)
        conv_stages = [
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
        ]
        self.cnn_layers = nn.Sequential(*conv_stages)
        self.fc_layer = nn.Linear(512, 1)

    def forward(self, input_ids_tensor, attention_mask_tensor=None):
        """Return logits of shape (batch, 1).

        When a mask is given, the embeddings at masked-out positions are
        zeroed before the convolutions. The final average still runs over
        every remaining position, padded ones included.
        """
        token_vectors = self.embedding_layer(input_ids_tensor)
        if attention_mask_tensor is not None:
            token_vectors = token_vectors * attention_mask_tensor.unsqueeze(-1)

        # Conv1d expects (batch, channels, length).
        feature_maps = self.cnn_layers(token_vectors.transpose(1, 2))
        sequence_summary = F.adaptive_avg_pool1d(feature_maps, 1).squeeze(-1)
        return self.fc_layer(sequence_summary)

In [None]:
class StyleClassifier(nn.Module):
    """Two-class style classifier: DistilBERT encoder plus a linear head.

    Args:
        classifier_model_name_str: Checkpoint name or path of the DistilBERT
            encoder to load.
    """

    def __init__(self, classifier_model_name_str):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(classifier_model_name_str)
        hidden_dim = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_dim, 2)

    def forward(self, input_ids_tensor, attention_mask_tensor):
        """Return logits of shape (batch, 2).

        The hidden state at position 0 of the encoder output is used as the
        sequence representation fed to the linear head.
        """
        encoder_out = self.bert(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)
        first_token_state = encoder_out.last_hidden_state[:, 0]
        return self.classifier(first_token_state)

In [None]:
# One shared generator (direction chosen by the input tag) and one
# discriminator per style domain.
G_model = Generator(GEN_MODEL_NAME, len(gen_tokenizer), RUSSIAN_TOKEN_ID).to(device)
D_C_model = Discriminator(len(gen_tokenizer), MAX_LENGTH).to(device)
D_L_model = Discriminator(len(gen_tokenizer), MAX_LENGTH).to(device)

# Load the pre-trained style classifier; a missing checkpoint is fatal.
try:
    style_classifier_main_model = StyleClassifier(CLASSIFIER_MODEL_NAME).to(device)
    style_classifier_model_path = "style_classifier/model.pth"
    if not os.path.exists(style_classifier_model_path):
        raise FileNotFoundError(f"Нет классификатора: {style_classifier_model_path}")
    classifier_state = torch.load(style_classifier_model_path, map_location=device)
    style_classifier_main_model.load_state_dict(classifier_state)
    print("Классификатор стиля загружен.")
except Exception as e:
    print(f"Ошибка StyleClassifier: {e}")
    raise

# The classifier is used as a fixed scorer: eval mode, no parameter gradients.
style_classifier_main_model.eval()
style_classifier_main_model.requires_grad_(False)

# Complete the base decoding configs with the generator's token ids.
gen_config_train = update_generation_config(
    base_gen_config_train, G_model.generation_config_internal, gen_tokenizer.pad_token_id
)
gen_config_val_log = update_generation_config(
    base_gen_config_val_log, G_model.generation_config_internal, gen_tokenizer.pad_token_id
)

Генератор инициализирован: forced_bos=24228, dec_start=24228
Классификатор стиля загружен.


In [None]:
# One AdamW optimizer per network; both discriminators share a learning rate.
optimizer_G = AdamW(G_model.parameters(), lr=LEARNING_RATE_GEN, eps=1e-8)
optimizer_D_C = AdamW(D_C_model.parameters(), lr=LEARNING_RATE_DISC, eps=1e-8)
optimizer_D_L = AdamW(D_L_model.parameters(), lr=LEARNING_RATE_DISC, eps=1e-8)

# Linear warmup over the first 5% of all optimizer steps, then linear decay.
num_total_training_steps = NUM_EPOCHS * STEPS_PER_EPOCH
num_warmup_steps_sched = int(0.05 * num_total_training_steps)

scheduler_G = get_linear_schedule_with_warmup(
    optimizer_G,
    num_warmup_steps=num_warmup_steps_sched,
    num_training_steps=num_total_training_steps,
)
scheduler_D_C = get_linear_schedule_with_warmup(
    optimizer_D_C,
    num_warmup_steps=num_warmup_steps_sched,
    num_training_steps=num_total_training_steps,
)
scheduler_D_L = get_linear_schedule_with_warmup(
    optimizer_D_L,
    num_warmup_steps=num_warmup_steps_sched,
    num_training_steps=num_total_training_steps,
)

In [None]:
# A single gradient scaler shared by the generator and both discriminators.
grad_scaler = GradScaler()
# Real/fake objective on raw discriminator logits.
adversarial_loss_fn = nn.BCEWithLogitsLoss()
# Token-level cross entropy that skips pad positions.
# NOTE(review): not referenced in the visible part of the notebook - confirm it is still needed.
cross_entropy_loss_fn_for_G = nn.CrossEntropyLoss(ignore_index=gen_tokenizer.pad_token_id)
# Two-class loss on style-classifier logits (target 1 = literary, 0 = conversational,
# as used by the validation routine).
style_classification_loss_fn = nn.CrossEntropyLoss()

In [None]:
def tokenize_for_style_classifier_utility(texts_batch, style_cls_tokenizer, max_len_val, current_dev):
    """Tokenize a batch of texts for the style classifier and move it to a device.

    Args:
        texts_batch: Texts to encode.
        style_cls_tokenizer: Callable tokenizer whose result supports ``.to(device)``.
        max_len_val: Padding / truncation length.
        current_dev: Target device of the returned encoding.

    Returns:
        The tokenizer output (PyTorch tensors, padded to ``max_len_val``)
        after ``.to(current_dev)``.
    """
    encoded_batch = style_cls_tokenizer(
        texts_batch,
        padding="max_length",
        truncation=True,
        max_length=max_len_val,
        return_tensors='pt',
    )
    return encoded_batch.to(current_dev)

In [None]:
@torch.no_grad()
def _validation_identity_loss(generator, identity_dataloader, eval_device, progress_desc):
    """Return the mean per-batch generator loss over an identity dataloader.

    Each batch must provide ``input_ids``, ``attention_mask`` and ``labels``.
    Returns ``float('inf')`` when the dataloader yields no batches.
    """
    loss_sum, batch_count = 0.0, 0
    for batch_data in tqdm(identity_dataloader, desc=progress_desc, leave=False, ncols=100):
        input_ids = batch_data['input_ids'].to(eval_device)
        attention_mask = batch_data['attention_mask'].to(eval_device)
        labels = batch_data['labels'].to(eval_device)
        with autocast():
            outputs = generator(input_ids, attention_mask, labels=labels)
        loss_sum += outputs.loss.item()
        batch_count += 1
    return loss_sum / batch_count if batch_count > 0 else float('inf')


@torch.no_grad()
def _validation_style_transfer(generator, source_dataloader, style_classifier, generation_config,
                               target_tag, target_class_index, num_examples, eval_device, progress_desc):
    """Generate in one direction and score the outputs with the style classifier.

    Source texts are prefixed with ``target_tag``, decoded by the generator,
    re-tokenized for the classifier and compared against ``target_class_index``.
    Only as many batches as are needed to cover ``num_examples`` are processed
    (computed with the module-level ``BATCH_SIZE``).

    Relies on the module-level ``gen_tokenizer``, ``style_tokenizer``,
    ``MAX_LENGTH`` and ``style_classification_loss_fn``.

    Returns:
        ``(mean_loss, accuracy, examples)`` where ``examples`` holds up to three
        ``(source_text, generated_text)`` pairs from the first non-empty batch.
        With zero scored samples the loss is ``inf`` and the accuracy ``0.0``.
    """
    examples, loss_sum, correct_count, sample_count = [], 0.0, 0, 0
    batches_to_process = min(len(source_dataloader), (num_examples + BATCH_SIZE - 1) // BATCH_SIZE)
    limited_batches = itertools.islice(source_dataloader, batches_to_process)

    for batch_data in tqdm(limited_batches, desc=progress_desc, leave=False, ncols=100, total=batches_to_process):
        source_texts = batch_data['original_text_str']
        tagged_texts = [f"{target_tag} {text}" for text in source_texts]
        tokenized_input = gen_tokenizer(tagged_texts, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(eval_device)
        with autocast():
            generated_ids = generator.generate_texts(tokenized_input.input_ids, tokenized_input.attention_mask, external_generation_config=generation_config)
        generated_texts = gen_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        style_input = tokenize_for_style_classifier_utility(generated_texts, style_tokenizer, MAX_LENGTH, eval_device)
        with autocast():
            style_logits = style_classifier(style_input.input_ids, style_input.attention_mask)

        batch_len = style_logits.size(0)
        target_labels = torch.full((batch_len,), target_class_index, dtype=torch.long, device=eval_device)
        # The loss function returns a batch mean; weight it by the batch size
        # so the final figure is a per-sample mean.
        loss_sum += style_classification_loss_fn(style_logits, target_labels).item() * batch_len
        predicted_labels = torch.argmax(style_logits, dim=1)
        correct_count += (predicted_labels == target_labels).sum().item()
        sample_count += batch_len

        if not examples:
            examples.extend(zip(source_texts[:3], generated_texts[:3]))

    mean_loss = loss_sum / sample_count if sample_count > 0 else float('inf')
    accuracy = correct_count / sample_count if sample_count > 0 else 0.0
    return mean_loss, accuracy, examples


@torch.no_grad()
def run_validation_final(
    generator_to_validate,
    val_dl_C_for_id, val_dl_L_for_id,
    val_dl_C_for_style_acc, val_dl_L_for_style_acc,
    style_classifier_for_val,
    current_validation_gen_config,
    num_examples_for_style_eval, current_eval_device
):
    """Compute the validation metrics of the generator for one epoch.

    Args:
        generator_to_validate: Generator module to evaluate.
        val_dl_C_for_id, val_dl_L_for_id: Identity dataloaders (batches with
            ``input_ids``, ``attention_mask``, ``labels``) for the
            conversational and literary domains.
        val_dl_C_for_style_acc, val_dl_L_for_style_acc: Dataloaders whose
            batches carry ``original_text_str``; used for C->L and L->C
            generation respectively.
        style_classifier_for_val: Classifier returning two-class logits
            (class 1 is the C->L target, class 0 the L->C target).
        current_validation_gen_config: Generation config used for decoding.
        num_examples_for_style_eval: Number of texts per direction to score.
        current_eval_device: Device on which tensors are placed.

    Returns:
        Dict with keys ``id_loss_C_val``, ``id_loss_L_val``,
        ``style_loss_C2L_val``, ``style_acc_C2L_val``, ``example_C2L_gen_val``,
        ``style_loss_L2C_val``, ``style_acc_L2C_val``, ``example_L2C_gen_val``
        and ``G_total_val_comparable`` (mean identity loss * LAMBDA_IDENTITY
        plus mean style loss * LAMBDA_STYLE).

    The generator is switched to eval mode for the duration of the call and
    put back into whatever mode it was in beforehand, even if validation
    raises.
    """
    was_training = generator_to_validate.training
    generator_to_validate.eval()
    try:
        val_epoch_metrics = {}

        val_epoch_metrics['id_loss_C_val'] = _validation_identity_loss(
            generator_to_validate, val_dl_C_for_id, current_eval_device, "Validating Identity C->C")
        val_epoch_metrics['id_loss_L_val'] = _validation_identity_loss(
            generator_to_validate, val_dl_L_for_id, current_eval_device, "Validating Identity L->L")

        c2l_loss, c2l_acc, c2l_examples = _validation_style_transfer(
            generator_to_validate, val_dl_C_for_style_acc, style_classifier_for_val,
            current_validation_gen_config, TAG_TO_LIT, 1,
            num_examples_for_style_eval, current_eval_device, "Validating Style C->L")
        val_epoch_metrics['style_loss_C2L_val'] = c2l_loss
        val_epoch_metrics['style_acc_C2L_val'] = c2l_acc
        val_epoch_metrics['example_C2L_gen_val'] = c2l_examples

        l2c_loss, l2c_acc, l2c_examples = _validation_style_transfer(
            generator_to_validate, val_dl_L_for_style_acc, style_classifier_for_val,
            current_validation_gen_config, TAG_TO_CONV, 0,
            num_examples_for_style_eval, current_eval_device, "Validating Style L->C")
        val_epoch_metrics['style_loss_L2C_val'] = l2c_loss
        val_epoch_metrics['style_acc_L2C_val'] = l2c_acc
        val_epoch_metrics['example_L2C_gen_val'] = l2c_examples

        # Same weighting as the "comparable" training figure: only the terms
        # that can be measured on held-out data (identity and style).
        val_epoch_metrics['G_total_val_comparable'] = \
            (val_epoch_metrics['id_loss_C_val'] + val_epoch_metrics['id_loss_L_val']) * 0.5 * LAMBDA_IDENTITY + \
            (val_epoch_metrics['style_loss_C2L_val'] + val_epoch_metrics['style_loss_L2C_val']) * 0.5 * LAMBDA_STYLE
    finally:
        generator_to_validate.train(was_training)
    return val_epoch_metrics

In [None]:
# Per-epoch history: one list per metric, appended to once per epoch and used
# for console reporting and plotting.
training_history = {
    'epoch': [],
    # Training averages of the full generator loss, its components and the
    # combined discriminator loss.
    'G_total_train_full': [], 'G_adv_train': [], 'G_cycle_train': [],
    'G_identity_train': [], 'G_style_train': [], 'D_total_train': [],
    # Identity + style only, weighted the same way as the validation total.
    'G_total_train_comparable': [],
    # Validation metrics returned by run_validation_final.
    'G_id_loss_C_val': [], 'G_id_loss_L_val': [],
    'G_style_loss_C2L_val': [], 'G_style_acc_C2L_val': [],
    'G_style_loss_L2C_val': [], 'G_style_acc_L2C_val': [],
    'G_total_val_comparable': []
}
# Lowest G_total_val_comparable seen so far; checkpoints are written whenever
# an epoch improves on it.
best_validation_metric = float('inf')
os.makedirs(SAVE_DIR, exist_ok=True)

# Endless batch streams so an "epoch" can be a fixed number of steps.
# NOTE(review): itertools.cycle keeps a copy of every item it has yielded and
# replays that saved sequence after the first pass, so batches would stop being
# reshuffled and stay in memory if training ever ran past one full pass - confirm
# NUM_EPOCHS * STEPS_PER_EPOCH stays below len(train_dataloader_*).
data_iterator_C = itertools.cycle(train_dataloader_C)
data_iterator_L = itertools.cycle(train_dataloader_L)

for epoch_num in range(NUM_EPOCHS):
    G_model.train(); D_C_model.train(); D_L_model.train()

    current_epoch_train_losses_sum = {
        'G_total_train_full': 0.0, 'G_adv_train': 0.0, 'G_cycle_train': 0.0,
        'G_identity_train': 0.0, 'G_style_train': 0.0, 'D_total_train': 0.0,
        'G_total_train_comparable': 0.0
    }

    progress_bar = tqdm(range(STEPS_PER_EPOCH), desc=f"Эпоха {epoch_num + 1}/{NUM_EPOCHS}", ncols=120)

    for step_num in progress_bar:
        batch_C_data = next(data_iterator_C)
        batch_L_data = next(data_iterator_L)

        real_C_original_ids = batch_C_data['original_ids'].to(device)
        real_C_original_mask = batch_C_data['original_mask'].to(device)
        real_C_original_texts_list = batch_C_data['original_text_str']

        real_L_original_ids = batch_L_data['original_ids'].to(device)
        real_L_original_mask = batch_L_data['original_mask'].to(device)
        real_L_original_texts_list = batch_L_data['original_text_str']

        optimizer_D_C.zero_grad()
        optimizer_D_L.zero_grad()

        g_input_C_to_L_texts_list = [f"{TAG_TO_LIT} {text}" for text in real_C_original_texts_list]
        tokenized_g_input_C_to_L = gen_tokenizer(g_input_C_to_L_texts_list, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(device)

        g_input_L_to_C_texts_list = [f"{TAG_TO_CONV} {text}" for text in real_L_original_texts_list]
        tokenized_g_input_L_to_C = gen_tokenizer(g_input_L_to_C_texts_list, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(device)

        with torch.no_grad():
            with autocast():
                fake_L_generated_ids = G_model.generate_texts(tokenized_g_input_C_to_L.input_ids, tokenized_g_input_C_to_L.attention_mask, external_generation_config=gen_config_train)
                fake_C_generated_ids = G_model.generate_texts(tokenized_g_input_L_to_C.input_ids, tokenized_g_input_L_to_C.attention_mask, external_generation_config=gen_config_train)

        fake_L_generated_mask = G_model.get_attention_mask_for_generated(fake_L_generated_ids)
        fake_C_generated_mask = G_model.get_attention_mask_for_generated(fake_C_generated_ids)

        with autocast():
            d_l_pred_on_real = D_L_model(real_L_original_ids, real_L_original_mask)
            d_l_pred_on_fake = D_L_model(fake_L_generated_ids.detach(), fake_L_generated_mask)
            loss_D_L_total_step = (adversarial_loss_fn(d_l_pred_on_real, torch.ones_like(d_l_pred_on_real)) + \
                                   adversarial_loss_fn(d_l_pred_on_fake, torch.zeros_like(d_l_pred_on_fake))) * 0.5

            d_c_pred_on_real = D_C_model(real_C_original_ids, real_C_original_mask)
            d_c_pred_on_fake = D_C_model(fake_C_generated_ids.detach(), fake_C_generated_mask)
            loss_D_C_total_step = (adversarial_loss_fn(d_c_pred_on_real, torch.ones_like(d_c_pred_on_real)) + \
                                   adversarial_loss_fn(d_c_pred_on_fake, torch.zeros_like(d_c_pred_on_fake))) * 0.5

            loss_D_combined_step = loss_D_L_total_step + loss_D_C_total_step

        grad_scaler.scale(loss_D_combined_step).backward()

        optimizer_G.zero_grad()
        with autocast():
            fake_L_ids_for_G = G_model.generate_texts(tokenized_g_input_C_to_L.input_ids, tokenized_g_input_C_to_L.attention_mask, external_generation_config=gen_config_train)
            fake_C_ids_for_G = G_model.generate_texts(tokenized_g_input_L_to_C.input_ids, tokenized_g_input_L_to_C.attention_mask, external_generation_config=gen_config_train)
            fake_L_mask_for_G = G_model.get_attention_mask_for_generated(fake_L_ids_for_G)
            fake_C_mask_for_G = G_model.get_attention_mask_for_generated(fake_C_ids_for_G)

            loss_G_adv_L_component = adversarial_loss_fn(D_L_model(fake_L_ids_for_G, fake_L_mask_for_G), torch.ones_like(D_L_model(fake_L_ids_for_G, fake_L_mask_for_G)))
            loss_G_adv_C_component = adversarial_loss_fn(D_C_model(fake_C_ids_for_G, fake_C_mask_for_G), torch.ones_like(D_C_model(fake_C_ids_for_G, fake_C_mask_for_G)))
            loss_G_adversarial_total = (loss_G_adv_L_component + loss_G_adv_C_component) * LAMBDA_ADV

            fake_L_texts_for_style_clf = gen_tokenizer.batch_decode(fake_L_ids_for_G, skip_special_tokens=True)
            fake_C_texts_for_style_clf = gen_tokenizer.batch_decode(fake_C_ids_for_G, skip_special_tokens=True)
            tokenized_L_for_style = tokenize_for_style_classifier_utility(fake_L_texts_for_style_clf, style_tokenizer, MAX_LENGTH, device)
            tokenized_C_for_style = tokenize_for_style_classifier_utility(fake_C_texts_for_style_clf, style_tokenizer, MAX_LENGTH, device)
            style_L_predictions = style_classifier_main_model(tokenized_L_for_style.input_ids, tokenized_L_for_style.attention_mask)
            style_C_predictions = style_classifier_main_model(tokenized_C_for_style.input_ids, tokenized_C_for_style.attention_mask)
            loss_G_style_L_comp = style_classification_loss_fn(style_L_predictions, torch.ones(style_L_predictions.size(0), dtype=torch.long, device=device))
            loss_G_style_C_comp = style_classification_loss_fn(style_C_predictions, torch.zeros(style_C_predictions.size(0), dtype=torch.long, device=device))
            loss_G_style_total = (loss_G_style_L_comp + loss_G_style_C_comp) * LAMBDA_STYLE

            g_input_reconstruct_C_texts_list = [f"{TAG_TO_CONV} {text}" for text in fake_L_texts_for_style_clf]
            tokenized_g_input_reconstruct_C = gen_tokenizer(g_input_reconstruct_C_texts_list, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(device)
            reconstructed_C_outputs = G_model(tokenized_g_input_reconstruct_C.input_ids, tokenized_g_input_reconstruct_C.attention_mask, labels=real_C_original_ids)
            loss_G_cycle_C_component = reconstructed_C_outputs.loss
            g_input_reconstruct_L_texts_list = [f"{TAG_TO_LIT} {text}" for text in fake_C_texts_for_style_clf]
            tokenized_g_input_reconstruct_L = gen_tokenizer(g_input_reconstruct_L_texts_list, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(device)
            reconstructed_L_outputs = G_model(tokenized_g_input_reconstruct_L.input_ids, tokenized_g_input_reconstruct_L.attention_mask, labels=real_L_original_ids)
            loss_G_cycle_L_component = reconstructed_L_outputs.loss
            loss_G_cycle_total = (loss_G_cycle_C_component + loss_G_cycle_L_component) * LAMBDA_CYCLE

            g_input_identity_L_texts_list = [f"{TAG_TO_LIT} {text}" for text in real_L_original_texts_list]
            tokenized_g_input_identity_L = gen_tokenizer(g_input_identity_L_texts_list, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(device)
            identity_L_outputs = G_model(tokenized_g_input_identity_L.input_ids, tokenized_g_input_identity_L.attention_mask, labels=real_L_original_ids)
            loss_G_identity_L_component = identity_L_outputs.loss
            g_input_identity_C_texts_list = [f"{TAG_TO_CONV} {text}" for text in real_C_original_texts_list]
            tokenized_g_input_identity_C = gen_tokenizer(g_input_identity_C_texts_list, padding="max_length", truncation=True, max_length=MAX_LENGTH, return_tensors='pt').to(device)
            identity_C_outputs = G_model(tokenized_g_input_identity_C.input_ids, tokenized_g_input_identity_C.attention_mask, labels=real_C_original_ids)
            loss_G_identity_C_component = identity_C_outputs.loss
            loss_G_identity_total = (loss_G_identity_L_component + loss_G_identity_C_component) * LAMBDA_IDENTITY

            loss_G_full_step = loss_G_adversarial_total + loss_G_style_total + loss_G_cycle_total + loss_G_identity_total

        grad_scaler.scale(loss_G_full_step).backward()

        grad_scaler.step(optimizer_D_L)
        grad_scaler.step(optimizer_D_C)
        torch.nn.utils.clip_grad_norm_(G_model.parameters(), 1.0)
        grad_scaler.step(optimizer_G)

        grad_scaler.update()

        scheduler_G.step()
        scheduler_D_C.step()
        scheduler_D_L.step()

        current_epoch_train_losses_sum['D_total_train'] += loss_D_combined_step.item()
        current_epoch_train_losses_sum['G_total_train_full'] += loss_G_full_step.item()
        current_epoch_train_losses_sum['G_adv_train'] += loss_G_adversarial_total.item()
        current_epoch_train_losses_sum['G_style_train'] += loss_G_style_total.item()
        current_epoch_train_losses_sum['G_cycle_train'] += loss_G_cycle_total.item()
        current_epoch_train_losses_sum['G_identity_train'] += loss_G_identity_total.item()

        g_total_comparable_train_step_unweighted_id = (loss_G_identity_L_component.item() + loss_G_identity_C_component.item()) * 0.5
        g_total_comparable_train_step_unweighted_style = (loss_G_style_L_comp.item() + loss_G_style_C_comp.item()) * 0.5
        g_total_comparable_train_step = (g_total_comparable_train_step_unweighted_id * LAMBDA_IDENTITY) + \
                                    (g_total_comparable_train_step_unweighted_style * LAMBDA_STYLE)
        current_epoch_train_losses_sum['G_total_train_comparable'] += g_total_comparable_train_step

        progress_bar.set_postfix({
            "G_Full": f"{loss_G_full_step.item():.2f}",
            "D_Total": f"{loss_D_combined_step.item():.2f}",
            "LR_G": f"{scheduler_G.get_last_lr()[0]:.2e}"
        })

    training_history['epoch'].append(epoch_num + 1)
    for key_hist in current_epoch_train_losses_sum:
        training_history[key_hist].append(current_epoch_train_losses_sum[key_hist] / STEPS_PER_EPOCH)

    validation_epoch_results = run_validation_final(
        G_model, val_dataloader_C_identity, val_dataloader_L_identity,
        val_dataloader_C_for_style, val_dataloader_L_for_style,
        style_classifier_main_model,
        gen_config_val_log,
        NUM_VAL_STYLE_EXAMPLES, device
    )

    training_history['G_id_loss_C_val'].append(validation_epoch_results['id_loss_C_val'])
    training_history['G_id_loss_L_val'].append(validation_epoch_results['id_loss_L_val'])
    training_history['G_style_loss_C2L_val'].append(validation_epoch_results['style_loss_C2L_val'])
    training_history['G_style_acc_C2L_val'].append(validation_epoch_results['style_acc_C2L_val'])
    training_history['G_style_loss_L2C_val'].append(validation_epoch_results['style_loss_L2C_val'])
    training_history['G_style_acc_L2C_val'].append(validation_epoch_results['style_acc_L2C_val'])
    training_history['G_total_val_comparable'].append(validation_epoch_results['G_total_val_comparable'])

    # ---- Console report for the epoch that just finished ----
    epoch_progress = f"{epoch_num + 1}/{NUM_EPOCHS}"
    print(f"\n--- Результаты Эпохи {epoch_progress} ---")

    # Training side: the newest entry of each history series.
    last_g_full = training_history['G_total_train_full'][-1]
    last_g_adv = training_history['G_adv_train'][-1]
    last_g_style = training_history['G_style_train'][-1]
    last_g_cycle = training_history['G_cycle_train'][-1]
    last_g_identity = training_history['G_identity_train'][-1]
    last_d_total = training_history['D_total_train'][-1]
    print(f"  Потери Тренировки: G_Full={last_g_full:.3f} "
          f"(Adv={last_g_adv:.3f}, Style={last_g_style:.3f}, "
          f"Cyc={last_g_cycle:.3f}, Id={last_g_identity:.3f}), "
          f"D_Total={last_d_total:.3f}")
    last_g_comparable = training_history['G_total_train_comparable'][-1]
    print(f"                     G_Comparable_Train={last_g_comparable:.3f}")

    # Validation side: values come straight from this epoch's validation results.
    val_id_c = validation_epoch_results['id_loss_C_val']
    val_id_l = validation_epoch_results['id_loss_L_val']
    val_c2l_loss = validation_epoch_results['style_loss_C2L_val']
    val_c2l_acc = validation_epoch_results['style_acc_C2L_val']
    val_l2c_loss = validation_epoch_results['style_loss_L2C_val']
    val_l2c_acc = validation_epoch_results['style_acc_L2C_val']
    print(f"  Метрики Валидации: G_Id(C|L)={val_id_c:.3f}|{val_id_l:.3f}, "
          f"G_Style(C2L L|A)={val_c2l_loss:.3f}|{val_c2l_acc:.2%}, "
          f"G_Style(L2C L|A)={val_l2c_loss:.3f}|{val_l2c_acc:.2%}")
    val_total_comparable = validation_epoch_results['G_total_val_comparable']
    print(f"                     G_Total_Comparable_Val={val_total_comparable:.3f}")

    # Sample generations for both directions: each text is cut to 70 characters,
    # newlines become spaces, and "..." is always appended.
    print("  Примеры генерации (Валидация):")
    for direction_tag, examples_key in (
        ("C->L", 'example_C2L_gen_val'),
        ("L->C", 'example_L2C_gen_val'),
    ):
        example_pairs = validation_epoch_results[examples_key]
        for example_no, (source_text, generated_text) in enumerate(example_pairs, start=1):
            source_preview = source_text[:70].replace('\n', ' ')
            print(f"    {direction_tag} {example_no} IN:  \"{source_preview}...\"")
            generated_preview = generated_text[:70].replace('\n', ' ')
            print(f"            OUT: \"{generated_preview}...\"")

    # Checkpointing: keep the weights of the epoch with the lowest
    # 'G_total_val_comparable' validation value seen so far (strict improvement only).
    epoch_val_score = validation_epoch_results['G_total_val_comparable']
    improved_on_best = epoch_val_score < best_validation_metric
    if improved_on_best:
        best_validation_metric = epoch_val_score
        # The generator and both discriminators are saved together.
        for network, checkpoint_path in (
            (G_model, MODEL_PATH_G_BEST),
            (D_C_model, MODEL_PATH_D_C_BEST),
            (D_L_model, MODEL_PATH_D_L_BEST),
        ):
            torch.save(network.state_dict(), checkpoint_path)
        print(f"  *** Новая лучшая модель сохранена! Val G_Comparable_Loss: {best_validation_metric:.3f} ***")

    # ---- Redraw the three-panel training figure and overwrite PLOT_PATH ----
    fig, axs = plt.subplots(3, 1, figsize=(14, 21))
    fig.suptitle(f"Результаты обучения - Эпоха {epoch_num + 1}", fontsize=16)
    epoch_axis = training_history['epoch']

    # Panel 1: every training loss component. Plot order is kept fixed so the
    # default colour cycle assigns the same colour to each series.
    train_ax = axs[0]
    for series_key, line_fmt, series_label, extra_style in (
        ('G_total_train_full', '-o', 'G Total Train (Full)', {'linewidth': 2}),
        ('G_adv_train', ':o', 'G Adv Train', {}),
        ('G_style_train', ':o', 'G Style Train', {}),
        ('G_cycle_train', ':o', 'G Cycle Train', {}),
        ('G_identity_train', ':o', 'G Identity Train', {}),
        ('D_total_train', '-x', 'D Total Train', {'linewidth': 2}),
    ):
        train_ax.plot(epoch_axis, training_history[series_key], line_fmt,
                      label=series_label, **extra_style)
    train_ax.set_title('Тренировочные Потери (Все Компоненты G)')
    train_ax.set_xlabel('Эпоха')
    train_ax.set_ylabel('Потеря')
    train_ax.legend(loc='upper right')
    train_ax.grid(True)

    # Panel 2: validation losses on the left y-axis, accuracies on a twin right y-axis.
    val_loss_ax = axs[1]
    val_acc_ax = val_loss_ax.twinx()
    for target_ax, series_key, line_fmt, series_label, line_color in (
        (val_loss_ax, 'G_id_loss_C_val', '-o', 'G Id_C Val', 'royalblue'),
        (val_loss_ax, 'G_id_loss_L_val', '-o', 'G Id_L Val', 'darkorange'),
        (val_loss_ax, 'G_style_loss_C2L_val', '--o', 'G Style C->L Loss Val', 'forestgreen'),
        (val_loss_ax, 'G_style_loss_L2C_val', '--o', 'G Style L->C Loss Val', 'crimson'),
        (val_acc_ax, 'G_style_acc_C2L_val', '-x', 'G Style C->L Acc Val', 'lime'),
        (val_acc_ax, 'G_style_acc_L2C_val', '-x', 'G Style L->C Acc Val', 'cyan'),
    ):
        target_ax.plot(epoch_axis, training_history[series_key], line_fmt,
                       label=series_label, color=line_color)
    val_loss_ax.set_title('Валидационные Метрики Генератора')
    val_loss_ax.set_xlabel('Эпоха')
    val_loss_ax.set_ylabel('Потеря (Loss)')
    val_acc_ax.set_ylabel('Точность (Accuracy)', color='teal')
    val_acc_ax.tick_params(axis='y', labelcolor='teal')
    # One combined legend: loss entries first, then accuracy entries from the twin axis.
    loss_handles, loss_labels = val_loss_ax.get_legend_handles_labels()
    acc_handles, acc_labels = val_acc_ax.get_legend_handles_labels()
    val_loss_ax.legend(
        handles=loss_handles + acc_handles,
        labels=loss_labels + acc_labels,
        loc='center left',
        bbox_to_anchor=(0.05, 0.5),
    )
    val_loss_ax.grid(True)

    # Panel 3: the "comparable" generator total, train vs validation.
    compare_ax = axs[2]
    for series_key, line_fmt, series_label in (
        ('G_total_train_comparable', '-o', 'G Total Comparable Train (Id+Style)'),
        ('G_total_val_comparable', '-x', 'G Total Comparable Val (Id+Style)'),
    ):
        compare_ax.plot(epoch_axis, training_history[series_key], line_fmt, label=series_label)
    compare_ax.set_title('Сравнение G Total Comparable Loss (Train vs Val)')
    compare_ax.set_xlabel('Эпоха')
    compare_ax.set_ylabel('Потеря (Weighted Id+Style)')
    compare_ax.legend(loc='upper right')
    compare_ax.grid(True)

    # Leave headroom for the suptitle, write the image, then close the figure so
    # figures do not accumulate across epochs.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(PLOT_PATH)
    plt.close(fig)
    print(f"Графики сохранены по пути: {PLOT_PATH}")

# Printed once, after the indented per-epoch block above has finished running.
print("--- Обучение завершено ---")

Эпоха 6/12: 100%|█████████████████████████| 125/125 [31:15<00:00, 15.00s/it, G_Full=151.08, D_Total=1.63, LR_G=1.05e-05]
                                                                                                    


--- Результаты Эпохи 6/12 ---
  Потери Тренировки: G_Full=153.002 (Adv=0.779, Style=141.091, Cyc=6.091, Id=5.042), D_Total=1.703
                     G_Comparable_Train=73.066
  Метрики Валидации: G_Id(C|L)=0.632|0.637, G_Style(C2L L|A)=9.148|3.12%, G_Style(L2C L|A)=5.418|13.54%
                     G_Total_Comparable_Val=89.936
  Примеры генерации (Валидация):
    C->L 1 IN:  "За некоторыми исключениями ..."
            OUT: "За некоторыми исключениями ..."
    C->L 2 IN:  "только батарею жрет и данные пиздит..."
            OUT: "только батарею жрет и данные пиздит..."
    C->L 3 IN:  "Чарльз Буковски — Музыка горячей воды                                 ..."
            OUT: "Чарльз Буковски — Музыка горячей воды :...."
    L->C 1 IN:  "Я вам не соперница...."
            OUT: "Я вам не соперница...."
    L->C 2 IN:  "Но сказанного не воротишь...."
            OUT: "Но сказанного не воротишь...."
    L->C 3 IN:  "Ее душа, ее плоть жили одной потребностью любви, всепоглощающей, беск



Графики сохранены по пути: tag_cyclegan_bart_final_v2/training_plots_final.png


Эпоха 7/12:  21%|█████▍                    | 26/125 [06:19<24:01, 14.56s/it, G_Full=153.96, D_Total=1.84, LR_G=1.02e-05]