In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
import os
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, get_linear_schedule_with_warmup, BertModel
import time
import torch.profiler
import torch.nn.functional as F

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Используем устройство: {device}")

# Seed the RNGs used below (weight init, dropout, scheduled sampling, DataLoader shuffling,
# random_split) so that Restart & Run All reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

TOKENIZER_MODEL = "DeepPavlov/rubert-base-cased"
# No force_download: re-fetching the tokenizer on every run is slow and needs network access;
# the Hugging Face cache already serves the same files.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_MODEL)

EMBED_DIM = 768
VOCAB_SIZE = tokenizer.vocab_size  # ~119,547
MAX_SEQ_LEN = 96
BATCH_SIZE = 4032
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
WARMUP_STEPS = 1000
PATIENCE = 5          # epochs without validation improvement before early stopping
MIN_DELTA = 0.0005    # minimum validation-loss decrease that counts as an improvement
NUM_LAYERS = 3
NUM_HEADS = 8
FF_DIM = 1024
DROPOUT = 0.1
NUM_WORKERS = 0
VALIDATION_SPLIT = 0.2
BEAM_WIDTH = 5
TEMPERATURE = 1.0
# NOTE(review): same values as the cutoffs / div_value hard-coded in TransformerDecoder's default.
ADAPTIVE_CUTOFFS = [510, 30269, 60028, 89787]
ADAPTIVE_DIV_VALUE = 4

# Per-epoch history, filled by the training loop and used for the loss plot.
train_losses = []
val_losses = []
teacher_forcing_ratios = []

In [None]:
def clear_gpu_memory():
    """Return cached-but-unused CUDA allocator blocks to the driver.

    Live tensors are not affected; only memory no tensor references is released.
    Does nothing when CUDA is not available.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def preprocess_texts(texts=None, tokenizer=tokenizer, max_seq_len=MAX_SEQ_LEN, output_file="tokenized_texts.npy"):
    """Tokenize texts into a fixed-width int64 matrix, caching the result on disk.

    Freshly tokenized rows have the form ``[CLS] tokens... [SEP] [PAD]...`` and are exactly
    ``max_seq_len`` long.

    Args:
        texts: iterable of strings; only needed when ``output_file`` does not exist yet.
        tokenizer: Hugging Face tokenizer used for encoding (defaults to the notebook tokenizer).
        max_seq_len: total row length, special tokens included.
        output_file: ``.npy`` cache path; if it exists it is loaded and ``texts`` is ignored.

    Returns:
        np.ndarray: freshly built arrays have shape (len(texts), max_seq_len) and dtype int64;
        a cached file is returned as stored.

    Raises:
        ValueError: the cache file is missing and no texts were given.
    """
    if os.path.exists(output_file):
        print(f"Загружаем готовые токены из {output_file}")
        # NOTE(review): allow_pickle is kept for compatibility with older cache files;
        # this function itself only ever writes a plain int64 array.
        return np.load(output_file, allow_pickle=True)
    if texts is None:
        raise ValueError(f"Файл {output_file} не найден, и тексты не предоставлены!")
    print("Токенизация текстов...")
    tokenized = []
    for text in tqdm(texts, desc="Токенизация"):
        # encode() already adds [CLS] ... [SEP], and truncation keeps room for the closing [SEP].
        # The previous version padded to max_seq_len - 1 first and then appended another [SEP],
        # which put an end-of-sequence token *after* the padding.
        tokens = tokenizer.encode(text, max_length=max_seq_len, truncation=True, padding='max_length')
        tokenized.append(tokens)
    tokenized = np.array(tokenized, dtype=np.int64)
    np.save(output_file, tokenized)
    print(f"Токены сохранены в {output_file}")
    return tokenized

def get_teacher_forcing_ratio(epoch, decay_epochs=80):
    """Linear teacher-forcing schedule.

    The ratio is 1.0 at epoch 0 and decays linearly to 0.0 at ``decay_epochs``. It stays
    at 0.0 afterwards and is clamped to [0, 1] for any epoch value.

    Args:
        epoch: zero-based epoch index.
        decay_epochs: number of epochs over which the ratio decays; must be positive.

    Returns:
        float: probability of feeding the gold token at each decoder position.

    Raises:
        ValueError: if ``decay_epochs`` is not positive.
    """
    if decay_epochs <= 0:
        raise ValueError("decay_epochs must be positive")
    return min(1.0, max(0.0, 1.0 - epoch / decay_epochs))

In [None]:
class TransformerDecoder(nn.Module):
    """Transformer decoder that reconstructs token ids from a single sentence embedding.

    The sentence embedding is layer-normalised, projected and used as a one-token memory
    for cross-attention. The output head is an ``nn.AdaptiveLogSoftmaxWithLoss``, so
    ``forward`` returns hidden states; callers get losses / log-probs via ``self.output_layer``.

    Special token ids default to the notebook-level ``tokenizer`` but may be passed
    explicitly (e.g. for tests or a different tokenizer).
    """

    def __init__(self, embed_dim=768, vocab_size=119547, num_layers=3,
                 num_heads=8, ff_dim=1024, dropout=0.1, max_seq_len=96, freq_file="token_frequencies.npy", beta=0.5,
                 cutoffs=None, div_value=4, cls_token_id=None, sep_token_id=None, pad_token_id=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        self.pad_token_id = pad_token_id

        self.input_norm = nn.LayerNorm(embed_dim)
        self.input_proj = nn.Linear(embed_dim, embed_dim)

        self.token_embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_embedding = nn.Embedding(max_seq_len, embed_dim)

        self.register_buffer("position_ids", torch.arange(max_seq_len).unsqueeze(0), persistent=False)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True, activation='gelu'
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(dropout)

        self.output_layer = nn.AdaptiveLogSoftmaxWithLoss(
            in_features=embed_dim,
            n_classes=vocab_size,
            cutoffs=list(cutoffs) if cutoffs is not None else [510, 30269, 60028, 89787],
            div_value=div_value,
            head_bias=True
        )

        # (seq_len, device) -> causal mask
        self.mask_cache = {}

        if os.path.exists(freq_file):
            freqs = np.load(freq_file)
            freq_tensor = torch.from_numpy(freqs).float()
            bias = beta * torch.log(freq_tensor + 1.0)
            # Non-persistent: the bias is derived from (freq_file, beta) at construction time.
            # A persistent buffer would be overwritten by load_state_dict with the training-time
            # beta, silently ignoring the beta passed here.
            self.register_buffer('freq_bias', bias, persistent=False)
        else:
            print(f"Файл {freq_file} не найден, freq_bias не будет использоваться.")

    def _special_id(self, name):
        """Return special token id ``name`` (e.g. 'cls_token_id'), falling back to the notebook tokenizer."""
        # getattr with default: modules unpickled from older checkpoints lack these attributes.
        value = getattr(self, name, None)
        return value if value is not None else getattr(tokenizer, name)

    def _get_tgt_mask(self, seq_len, device):
        """Return a cached (seq_len, seq_len) causal mask: 0 on/below the diagonal, -inf above."""
        # Keyed on device too, otherwise a mask cached before .to(device) would be reused
        # on the wrong device.
        key = (seq_len, str(device))
        if key not in self.mask_cache:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
            mask = mask.masked_fill(mask == 1, float('-inf'))
            self.mask_cache[key] = mask
        return self.mask_cache[key]

    def forward(self, src_embed, tgt, teacher_forcing_ratio=1.0):
        """Decode ``tgt`` conditioned on ``src_embed`` and return hidden states.

        Args:
            src_embed: (batch, embed_dim) sentence embeddings.
            tgt: (batch, seq_len) decoder input ids; position 0 is always kept as given.
            teacher_forcing_ratio: probability of feeding the gold token at each position
                t >= 1; otherwise the model's own prediction made at position t-1 is fed
                (two-pass scheduled sampling).

        Returns:
            (batch, seq_len, embed_dim) hidden states after dropout, to be passed to
            ``self.output_layer``.
        """
        batch_size, seq_len = tgt.size()
        device = tgt.device

        memory = self.input_proj(self.input_norm(src_embed)).unsqueeze(1)

        positions = self.position_ids[:, :seq_len].expand(batch_size, seq_len)
        tgt_mask = self._get_tgt_mask(seq_len, device)

        if teacher_forcing_ratio >= 1.0 or seq_len < 2:
            input_ids = tgt
        else:
            with torch.no_grad():
                # First pass on the gold prefix: the hidden state at position t predicts token t+1.
                gold_embedded = self.token_embedding(tgt) + self.pos_embedding(positions)
                hidden = self.decoder(self.dropout(gold_embedded), memory, tgt_mask=tgt_mask)
                predicted_next = self.output_layer.predict(
                    hidden.reshape(-1, self.embed_dim)
                ).reshape(batch_size, seq_len)

            # Position t >= 1 receives either the gold token or the prediction made at t-1.
            use_teacher = torch.rand(batch_size, seq_len - 1, device=device) < teacher_forcing_ratio
            input_ids = tgt.clone()
            input_ids[:, 1:] = torch.where(use_teacher, tgt[:, 1:], predicted_next[:, :-1])

        decoder_input = self.dropout(self.token_embedding(input_ids) + self.pos_embedding(positions))
        output = self.decoder(decoder_input, memory, tgt_mask=tgt_mask)
        return self.dropout(output)

    def _beam_search(self, memory, max_len, start_token_id, sep_token_id, beam_width,
                     temperature, alpha, top_k, top_p, min_length):
        """Beam search for one example.

        Args:
            memory: (1, 1, embed_dim) projected sentence embedding.
            Other arguments as in ``generate`` (token ids already resolved).

        Returns:
            (1, L) long tensor holding the best hypothesis, with L <= max_len.
        """
        device = memory.device

        def normalised_score(hypothesis):
            # GNMT length penalty: score / ((5 + len) / 6) ** alpha.
            seq, score = hypothesis
            return score / (((5 + seq.size(1)) / 6) ** alpha)

        beams = [(torch.full((1, 1), start_token_id, dtype=torch.long, device=device), 0.0)]
        completed = []
        # No beam ever expands more than min(top_k, beam_width) tokens, so only that many are needed.
        max_candidates = max(1, min(top_k, beam_width, self.vocab_size))

        for step in range(max_len - 1):
            candidates = []
            for seq, score in beams:
                if seq[0, -1].item() == sep_token_id:
                    completed.append((seq, score))
                    continue

                seq_len = seq.size(1)
                tgt_embed = self.token_embedding(seq) + self.pos_embedding(self.position_ids[:, :seq_len])
                output = self.decoder(tgt_embed, memory, tgt_mask=self._get_tgt_mask(seq_len, device))

                scores = self.output_layer.log_prob(output[:, -1, :]) / temperature
                if hasattr(self, 'freq_bias'):
                    scores = scores - self.freq_bias  # penalise globally frequent tokens
                if step + 1 < min_length:
                    scores[:, sep_token_id] = float('-inf')  # too early to finish
                # Renormalise after temperature / bias so top_p acts on a real distribution.
                log_probs = torch.log_softmax(scores, dim=-1)

                top_probs, top_ids = torch.topk(log_probs.exp(), max_candidates, dim=-1)
                # Nucleus filter: keep the prefix whose cumulative mass stays within top_p,
                # but always at least the most likely token.
                nucleus_size = int((top_probs.cumsum(dim=-1) <= top_p).sum().item())
                for rank in range(max(1, nucleus_size)):
                    token_id = top_ids[:, rank:rank + 1]
                    new_score = score + log_probs[0, token_id.item()].item()
                    candidates.append((torch.cat([seq, token_id], dim=1), new_score))

            beams = sorted(candidates, key=normalised_score, reverse=True)[:beam_width]
            if not beams:
                break

        hypotheses = beams + completed
        if not hypotheses:
            return torch.full((1, 1), start_token_id, dtype=torch.long, device=device)
        # Rank finished and unfinished hypotheses with the same length-normalised score used above.
        return max(hypotheses, key=normalised_score)[0]

    @torch.no_grad()
    def generate(self, src_embed, max_len=96, start_token_id=None, beam_width=5, temperature=1.0, alpha=0.7, top_k=50, top_p=0.9, min_length=10):
        """Generate token ids for each sentence embedding with top-p / top-k filtered beam search.

        Args:
            src_embed: (batch, embed_dim) sentence embeddings; each row is decoded independently.
            max_len: maximum output length including the start token (capped at ``self.max_seq_len``).
            start_token_id: first token; defaults to the [CLS] id.
            beam_width: hypotheses kept per step.
            temperature: divides log-probabilities before renormalisation.
            alpha: length-penalty exponent used to rank hypotheses.
            top_k: maximum candidate tokens per beam.
            top_p: nucleus mass for candidate tokens per beam.
            min_length: [SEP] is disallowed until ``min_length`` decoding steps have passed.

        Returns:
            (batch, L) long tensor; rows shorter than the longest are right-padded with [PAD].
        """
        if start_token_id is None:
            start_token_id = self._special_id('cls_token_id')
        sep_token_id = self._special_id('sep_token_id')
        pad_token_id = self._special_id('pad_token_id')
        # Position embeddings only exist for max_seq_len positions.
        max_len = min(max_len, self.max_seq_len)

        memory = self.input_proj(self.input_norm(src_embed)).unsqueeze(1)  # (batch, 1, embed_dim)
        sequences = [
            self._beam_search(memory[i:i + 1], max_len, start_token_id, sep_token_id, beam_width,
                              temperature, alpha, top_k, top_p, min_length)[0]
            for i in range(memory.size(0))
        ]
        return nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

In [None]:
tokenized_texts = preprocess_texts(texts=None)  # cache only: raises ValueError if the .npy file is missing

Загружаем готовые токены из tokenized_texts.npy


In [None]:
# Corpus frequency of every token id (padding included), consumed by TransformerDecoder's
# freq_bias. np.bincount replaces a Python loop over every single token.
all_token_ids = np.asarray(tokenized_texts, dtype=np.int64).ravel()
assert all_token_ids.size == 0 or all_token_ids.max() < VOCAB_SIZE, "token id outside the tokenizer vocabulary"
# int64 rather than int32: the pad-token count alone can approach the int32 range on large corpora.
frequencies = np.bincount(all_token_ids, minlength=VOCAB_SIZE)

np.save("token_frequencies.npy", frequencies)

In [None]:
freqs = np.load(file="token_frequencies.npy")  # per-token-id counts saved by the cell above

In [None]:
# argmax over a slice returns an offset into that slice, not a token id; add the offset back
# so the displayed value can be passed straight to tokenizer.decode.
# NOTE(review): ids below 129 are skipped, presumably [PAD]/[unused*]/special tokens — confirm.
FIRST_REGULAR_TOKEN_ID = 129
FIRST_REGULAR_TOKEN_ID + int(freqs[FIRST_REGULAR_TOKEN_ID:].argmax())

3

In [None]:
tokenizer.decode(token_ids=[132])  # 132 = 129 + 3, i.e. the argmax offset shown by the previous cell

'.'

In [None]:
# Sentence embeddings for both corpora, stacked in one matrix.
# NOTE(review): presumably the rows follow the same order as tokenized_texts (literary corpus
# first, then conversations) — the dataset pairs them purely by row index, so confirm this.
lit_embeddings = np.load("lit_embeddings.npy")
tg_embeddings = np.load("conv_embeddings.npy")
embeddings = np.concatenate([lit_embeddings, tg_embeddings], axis=0)
tokenized_texts = preprocess_texts()
# Fail fast if embeddings and token rows cannot be paired one-to-one.
assert len(embeddings) == len(tokenized_texts), (
    f"{len(embeddings)} embeddings vs {len(tokenized_texts)} tokenized texts"
)

Загружаем готовые токены из tokenized_texts.npy


In [None]:
# Number of optimizer steps the LR scheduler will actually see: only the training split is
# iterated (same split arithmetic as the random_split cell), and the DataLoader keeps the last
# partial batch (drop_last=False), so round up.
n_val_samples = int(VALIDATION_SPLIT * len(embeddings))
n_train_samples = len(embeddings) - n_val_samples
STEPS_PER_EPOCH = (n_train_samples + BATCH_SIZE - 1) // BATCH_SIZE
TOTAL_STEPS = STEPS_PER_EPOCH * NUM_EPOCHS
TOTAL_STEPS

68500

In [None]:
class EmbeddingToTokenDataset(Dataset):
    """Pairs row i of ``embeddings`` with row i of ``tokenized_texts``.

    Args:
        embeddings: array-like indexed along its first axis; converted to float32.
        tokenized_texts: array-like indexed along its first axis; converted to int64.

    Raises:
        ValueError: if the two inputs have a different number of rows.
    """

    def __init__(self, embeddings, tokenized_texts):
        if len(embeddings) != len(tokenized_texts):
            raise ValueError(
                f"embeddings ({len(embeddings)}) and tokenized_texts ({len(tokenized_texts)}) "
                "must have the same length"
            )
        # as_tensor reuses the NumPy buffer when the dtype already matches, avoiding a full copy
        # of a potentially very large embedding matrix.
        self.embeddings = torch.as_tensor(embeddings, dtype=torch.float32)
        self.tokenized_texts = torch.as_tensor(tokenized_texts, dtype=torch.long)

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the (embedding, token ids) pair at ``idx``."""
        return self.embeddings[idx], self.tokenized_texts[idx]

In [None]:
dataset = EmbeddingToTokenDataset(embeddings, tokenized_texts)
dataset_size = len(dataset)
val_size = int(VALIDATION_SPLIT * dataset_size)
train_size = dataset_size - val_size
# Fixed generator: the split must be identical on every run. Otherwise re-running the notebook
# (e.g. to evaluate a saved checkpoint) would validate on samples the model was trained on.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)

In [None]:
# Both loaders share these settings; only shuffling differs between train and validation.
loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=NUM_WORKERS)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Build the decoder from the config cell so that the constants there actually take effect
# (the class defaults merely happen to coincide with them).
model = TransformerDecoder(
    embed_dim=EMBED_DIM, vocab_size=VOCAB_SIZE, num_layers=NUM_LAYERS, num_heads=NUM_HEADS,
    ff_dim=FF_DIM, dropout=DROPOUT, max_seq_len=MAX_SEQ_LEN,
).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=TOTAL_STEPS)
# NOTE: the training loop takes its loss from model.output_layer (AdaptiveLogSoftmaxWithLoss);
# this criterion is not used there.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# Loss scaling only applies to CUDA mixed precision; disable it (and its warning) on CPU.
scaler = GradScaler(enabled=torch.cuda.is_available())

In [None]:
# Early-stopping state: best validation loss seen so far and epochs since it last improved.
best_val_loss, patience_counter = float('inf'), 0

In [None]:
train_times = []
val_times = []


def compute_loss(output, tgt):
    """Adaptive-softmax loss of decoder hidden states against the shifted targets.

    ``output`` is the decoder output for ``tgt[:, :-1]``; the targets are ``tgt[:, 1:]``.
    [PAD] targets are dropped because AdaptiveLogSoftmaxWithLoss has no ignore_index; with
    sequences padded to MAX_SEQ_LEN, padding positions could otherwise dominate the loss.
    """
    targets = tgt[:, 1:].reshape(-1)
    keep = targets != tokenizer.pad_token_id
    return model.output_layer(output.reshape(-1, EMBED_DIM)[keep], targets[keep]).loss


for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()

    teacher_forcing_ratio = get_teacher_forcing_ratio(epoch)
    teacher_forcing_ratios.append(teacher_forcing_ratio)
    print(f"Эпоха {epoch+1}, Teacher Forcing Ratio: {teacher_forcing_ratio:.3f}")

    # ---- training ----
    model.train()
    train_loss_total = 0
    train_start_time = time.time()
    for i, (src_embed, tgt) in enumerate(tqdm(train_dataloader, desc=f"Тренировка, Эпоха {epoch+1}/{NUM_EPOCHS}")):
        src_embed = src_embed.to(device, non_blocking=True)
        tgt = tgt.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(dtype=torch.float16):
            output = model(src_embed, tgt[:, :-1], teacher_forcing_ratio)
            loss = compute_loss(output, tgt)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scheduler.step()
        scaler.update()

        train_loss_total += loss.item()

        if i == 0:
            print(f"Размер src_embed: {src_embed.shape}")
            print(f"Размер tgt: {tgt.shape}")
            print(f"Память GPU после первого батча: {torch.cuda.memory_allocated()/1024**3:.2f} ГБ")

    train_end_time = time.time()
    train_times.append(train_end_time - train_start_time)
    avg_train_loss = train_loss_total / len(train_dataloader)
    train_losses.append(avg_train_loss)
    print(f"Эпоха {epoch+1}/{NUM_EPOCHS}, Тренировочный лосс: {avg_train_loss:.4f}, Время тренировки: {train_times[-1]:.2f} секунд")

    # ---- validation (model's own predictions fed back: teacher_forcing_ratio=0.0) ----
    model.eval()
    val_loss_total = 0
    val_start_time = time.time()
    with torch.no_grad():
        for src_embed, tgt in tqdm(val_dataloader, desc=f"Валидация, Эпоха {epoch+1}/{NUM_EPOCHS}"):
            src_embed = src_embed.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            with autocast(dtype=torch.float16):
                output = model(src_embed, tgt[:, :-1], teacher_forcing_ratio=0.0)
                loss = compute_loss(output, tgt)

            val_loss_total += loss.item()

    val_end_time = time.time()
    val_times.append(val_end_time - val_start_time)
    avg_val_loss = val_loss_total / len(val_dataloader)
    val_losses.append(avg_val_loss)
    print(f"Эпоха {epoch+1}/{NUM_EPOCHS}, Валидационный лосс: {avg_val_loss:.4f}, Время валидации: {val_times[-1]:.2f} секунд")

    # ---- progress plot, overwritten every epoch ----
    fig, ax1 = plt.subplots(figsize=(10, 5))
    ax1.plot(train_losses, label='Train Loss', color='blue')
    ax1.plot(val_losses, label='Validation Loss', color='orange')
    ax1.set_xlabel('Эпоха')
    ax1.set_ylabel('Loss')
    ax1.legend(loc='upper left')

    ax2 = ax1.twinx()
    ax2.plot(teacher_forcing_ratios, label='Teacher Forcing Ratio', color='green', linestyle='--')
    ax2.set_ylabel('Teacher Forcing Ratio')
    ax2.legend(loc='upper right')

    plt.title('График обучения и Teacher Forcing Ratio')
    plt.savefig("loss.png")
    plt.close()

    # ---- checkpointing / early stopping ----
    if avg_val_loss < best_val_loss - MIN_DELTA:
        best_val_loss = avg_val_loss
        patience_counter = 0
        torch.save(model, "decoder_model.pth")
    else:
        patience_counter += 1
        if patience_counter >= PATIENCE:
            print("Ранний останов")
            break

    epoch_end_time = time.time()
    print(f"Общее время эпохи {epoch+1}: {epoch_end_time - epoch_start_time:.2f} секунд")
    clear_gpu_memory()

# NOTE: decoder_model.pth holds the best-validation checkpoint, not the last epoch's weights.
print("Финальная модель сохранена")
clear_gpu_memory()

In [None]:
# Same checkpoint as the tokenizer; no force_download, so the cached weights are reused
# instead of re-downloading the whole model on every run.
bert_model = BertModel.from_pretrained(TOKENIZER_MODEL).to(device)
bert_model.eval()

In [None]:
# Rebuild the decoder for inference with a stronger frequency penalty in generate().
INFERENCE_BETA = 2.0
model = TransformerDecoder(beta=INFERENCE_BETA).to(device)
# decoder_model.pth is a whole pickled module (torch.save(model, ...) in the training loop), so
# weights_only=False is required on PyTorch >= 2.6. Only load checkpoints you produced yourself.
checkpoint_state = torch.load("decoder_model.pth", map_location=device, weights_only=False).state_dict()
# The checkpoint's freq_bias was built with the training-time beta; drop it so that the bias
# computed above with INFERENCE_BETA is the one generate() actually uses.
checkpoint_state.pop("freq_bias", None)
load_result = model.load_state_dict(checkpoint_state, strict=False)
# strict=False would otherwise silently leave randomly initialised weights on a key mismatch.
missing_keys = [k for k in load_result.missing_keys if k != "freq_bias"]
if missing_keys or load_result.unexpected_keys:
    print(f"missing keys: {missing_keys}, unexpected keys: {load_result.unexpected_keys}")
model.eval()

TransformerDecoder(
  (input_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (input_proj): Linear(in_features=768, out_features=768, bias=True)
  (token_embedding): Embedding(119547, 768)
  (pos_embedding): Embedding(96, 768)
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-2): 3 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=1024, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=1024, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm3

In [None]:
freqs = np.load("token_frequencies.npy")
# Spot-check the corpus frequency of a few hand-picked token ids.
for label, token_id in [
    ("У", 815),
    ("меня", 14198),
    ("Правоохранительные", 50863),
    ("##лашение", 31213),
    ("[unused66]", 66),
    ("[unused41]", 41),
]:
    print(f"Частота токена '{label}' (ID {token_id}): {freqs[token_id]}")

Частота токена 'У' (ID 815): 37377
Частота токена 'меня' (ID 14198): 97863
Частота токена 'Правоохранительные' (ID 50863): 0
Частота токена '##лашение' (ID 31213): 0
Частота токена '[unused66]' (ID 66): 0
Частота токена '[unused41]' (ID 41): 0


In [None]:
text = "какие сладкие булочки!"
# Standard BERT input: [CLS] tokens [SEP] with its matching attention mask; a single sentence
# needs no padding. The previous version padded first and then appended [SEP], which fed BERT an
# extra, attended [SEP] at the last position.
# NOTE(review): must match how lit_embeddings.npy / conv_embeddings.npy were produced — confirm.
encoded = tokenizer(text, max_length=MAX_SEQ_LEN, truncation=True, return_tensors="pt")
input_ids = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)

with torch.no_grad():
    outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    generated = model.generate(cls_embedding, max_len=MAX_SEQ_LEN, beam_width=20, temperature=0.7, top_k=50, top_p=0.7, alpha=0.8, min_length=10)
    generated_ids = generated[0].tolist()
    res = tokenizer.decode(generated_ids, skip_special_tokens=True)
    print("Generated token IDs:", generated_ids)
    print("Corresponding tokens:", tokenizer.convert_ids_to_tokens(generated_ids))
    print("Decoded output:", res)