# Text generation with deep learning

In [1]:
import pandas as pd
import os
from datetime import datetime

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import logging
import optuna
from torch.utils.tensorboard import SummaryWriter


from src.logger import setup_logger

setup_logger(level=logging.INFO)



In [2]:
class CharRNN(nn.Module):
    """Character-level recurrent language model: embedding -> GRU/LSTM -> linear.

    Args:
        input_size: Number of rows in the embedding table (distinct input tokens).
        hidden_size: Width of the embedding and of every recurrent layer.
        output_size: Number of logits produced per sequence position.
        model: ``"gru"`` (case-insensitive) selects ``nn.GRU``; any other value
            selects ``nn.LSTM``.
        n_layers: Number of stacked recurrent layers.
        dropout: Dropout probability applied to the recurrent output and, when
            ``n_layers > 1``, between the recurrent layers.
    """

    def __init__(self, input_size, hidden_size, output_size, model="gru", n_layers=1, dropout=0.2):
        super().__init__()
        self.model_type = model.lower()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers

        self.encoder = nn.Embedding(input_size, hidden_size)
        rnn_class = nn.GRU if self.model_type == "gru" else nn.LSTM
        self.rnn = rnn_class(
            hidden_size, hidden_size, n_layers,
            # Inter-layer dropout is only meaningful with more than one layer.
            dropout=dropout if n_layers > 1 else 0,
            batch_first=True
        )
        self.dropout = nn.Dropout(dropout)
        self.decoder = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden=None):
        """Return ``(logits, hidden)`` for a batch of token indices.

        ``input`` is indexed as ``(batch, seq_len)`` because the recurrent
        layer is built with ``batch_first=True``; the returned logits are
        reshaped to ``(batch, seq_len, output_size)``.
        """
        batch_size = input.size(0)
        encoded = self.encoder(input)
        output, hidden = self.rnn(encoded, hidden)
        output = self.dropout(output)
        # Flatten the batch and time axes so the decoder sees one row per position.
        decoded = self.decoder(output.contiguous().view(-1, self.hidden_size))
        return decoded.view(batch_size, -1, self.output_size), hidden

    def init_hidden(self, batch_size, device):
        """Return an all-zero initial hidden state allocated on ``device``.

        The result is a ``(h, c)`` tuple when the recurrent layer is an LSTM and
        a single tensor otherwise. The decision follows the layer that was
        actually built, so model names that fall back to LSTM still get a
        state the layer accepts.
        """
        shape = (self.n_layers, batch_size, self.hidden_size)
        state = torch.zeros(shape, device=device)
        if isinstance(self.rnn, nn.LSTM):
            return state, torch.zeros(shape, device=device)
        return state

In [3]:
class TextDataset(Dataset):
    """Sliding-window character dataset for next-character prediction.

    The text is cut into overlapping windows of ``chunk_len + 1`` characters
    taken every ``stride`` characters. Each item is an ``(input, target)`` pair
    of index tensors in which the target is the input shifted by one character.
    """

    def __init__(self, text, chunk_len=200, stride=50):
        self.text = text
        self.chunk_len = chunk_len
        self.stride = stride
        # Sorted so that the character <-> index mapping is deterministic.
        self.unique_chars = sorted(set(text))
        self.char_to_idx = {
            char: index for index, char in enumerate(self.unique_chars)
        }
        self.idx_to_char = dict(enumerate(self.unique_chars))
        self.data = self._process_text()

    def __len__(self):
        """Number of windows extracted from the text."""
        return len(self.data)

    def _process_text(self):
        """Return every window of ``chunk_len + 1`` characters as a list of strings."""
        window = self.chunk_len + 1
        last_start = len(self.text) - self.chunk_len
        return [
            self.text[start:start + window]
            for start in range(0, last_start, self.stride)
        ]

    def __getitem__(self, idx):
        """Return the ``(input, target)`` index tensors for window ``idx``."""
        encoded = [self.char_to_idx[char] for char in self.data[idx]]
        # Input drops the last character, target drops the first one.
        return torch.LongTensor(encoded[:-1]), torch.LongTensor(encoded[1:])

    @property
    def vocab_size(self):
        """Number of distinct characters in the text."""
        return len(self.unique_chars)

In [4]:
def generate_sample(model, dataset, device, prompt="The", max_length=500, temperature=1.0, top_k=10, top_p=0.9):
    """Sample ``max_length`` characters from ``model`` as a continuation of ``prompt``.

    Args:
        model: Network exposing ``init_hidden(batch_size, device)`` and a forward
            call ``model(tokens, hidden) -> (logits, hidden)``.
        dataset: Object providing the ``char_to_idx`` and ``idx_to_char`` mappings.
        device: Device on which the token tensors are placed.
        prompt: Non-empty seed text; every character must be in ``char_to_idx``.
        max_length: Number of characters to generate.
        temperature: Positive divisor applied to the logits before sampling.
        top_k: When > 0, restrict sampling to the ``top_k`` highest logits.
        top_p: When > 0, apply nucleus filtering with this probability mass.

    Returns:
        The prompt followed by the generated characters. The text is also printed.

    Raises:
        ValueError: If ``prompt`` is empty or ``temperature`` is not positive.
        KeyError: If ``prompt`` contains a character missing from the vocabulary.
    """
    if not prompt:
        # There is no start token, so generation needs at least one seed character.
        raise ValueError("prompt must contain at least one character")
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()
    generated = []
    prompt_ids = torch.LongTensor([dataset.char_to_idx[c] for c in prompt]).unsqueeze(0).to(device)  # (batch=1, seq_len)
    hidden = model.init_hidden(1, device)

    with torch.no_grad():
        # Warm up on everything except the last prompt character. That character is
        # the first input of the sampling loop, so it must not be consumed twice.
        if prompt_ids.size(1) > 1:
            _, hidden = model(prompt_ids[:, :-1], hidden)

        input_seq = prompt_ids[:, -1:]

        for _ in range(max_length):
            outputs, hidden = model(input_seq, hidden)
            logits = outputs[:, -1, :] / temperature  # take the last output position

            if top_k > 0:
                logits = _top_k_filter(logits, top_k)
            if top_p > 0.0:
                logits = _top_p_filter(logits, top_p)

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated.append(next_token.item())
            # The sampled token becomes the next single-step input.
            input_seq = next_token

    generated_str = prompt + ''.join([dataset.idx_to_char[idx] for idx in generated])
    print("\nGenerated text:")
    print(generated_str)
    return generated_str

def _top_k_filter(logits, k):
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(1)
    return torch.where(logits < min_values, torch.ones_like(logits)*-float('inf'), logits)

def _top_p_filter(logits, p):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cumulative_probs > p
    sorted_indices_to_remove[..., 0] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    return logits.masked_fill(indices_to_remove, -float('inf'))

In [5]:
# Smoke check: build a small default (GRU) CharRNN and read back its hidden size.
CharRNN(10, 10, 10).hidden_size

10

In [6]:
from tqdm.auto import tqdm

def evaluate(model, val_loader, device):
    """Return the mean per-sample cross-entropy of ``model`` over ``val_loader``.

    Args:
        model: Network exposing ``init_hidden(batch_size, device)`` and a forward
            call ``model(inputs, hidden) -> (logits, hidden)``.
        val_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        The batch losses weighted by batch size and divided by the number of
        samples actually iterated. The result therefore stays correct when the
        loader drops its last batch or samples only a subset of its dataset.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no samples.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    seen = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch = inputs.size(0)
            hidden = model.init_hidden(batch, device)
            outputs, _ = model(inputs, hidden)
            # CrossEntropyLoss expects the class axis second: (batch, classes, seq).
            loss = criterion(outputs.transpose(1, 2), targets)
            total_loss += loss.item() * batch
            seen += batch
    return total_loss / seen

def train_model(model, dataset, epochs=50, batch_size=32, lr=3e-4, trial=None, checkpoint_path="best_model.pth"):
    """Train ``model`` on ``dataset`` and return the best average epoch loss.

    Training runs on CUDA when available and otherwise on CPU. Mixed precision
    (float16 autocast plus gradient scaling) and the fused AdamW kernel are only
    enabled on CUDA, so the CPU fallback runs in plain float32.

    The optimised and reported loss is label-smoothed cross-entropy plus
    ``0.001`` times the sum of the L2 norms of all parameters. Scalars are
    written to a TensorBoard run directory under ``runs/`` and the best
    checkpoint is saved to ``checkpoint_path``.

    Args:
        model: Network exposing ``encoder``, ``rnn``, ``decoder``,
            ``init_hidden`` and the ``model_type``/``hidden_size``/``n_layers``
            attributes used for the run name.
        dataset: Dataset yielding ``(inputs, targets)`` index tensors.
        epochs: Number of passes over the data.
        batch_size: Batch size; incomplete final batches are dropped.
        lr: Peak learning rate of the one-cycle schedule.
        trial: Optional Optuna trial. When given, the average loss is reported
            after every epoch and the trial may be pruned.
        checkpoint_path: File the best model/optimizer state is saved to.

    Returns:
        The lowest average epoch loss observed.

    Raises:
        ValueError: If the dataset cannot fill a single batch.
        optuna.exceptions.TrialPruned: If ``trial`` decides to prune.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    model = model.to(device)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True
    )
    steps_per_epoch = len(loader)
    if steps_per_epoch == 0:
        raise ValueError(
            f"dataset has {len(dataset)} samples, fewer than one full batch of {batch_size}"
        )

    optimizer = torch.optim.AdamW([
        {'params': model.encoder.parameters(), 'weight_decay': 0.01},
        {'params': model.rnn.parameters()},
        {'params': model.decoder.parameters(), 'weight_decay': 0.01}
    ], lr=lr, fused=use_cuda)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        total_steps=epochs * steps_per_epoch,
        pct_start=0.1
    )

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # A disabled scaler makes scale/unscale_/update no-ops and step() a plain optimizer step.
    scaler = torch.amp.GradScaler(device.type, enabled=use_cuda)

    writer = SummaryWriter(
        log_dir=f"runs/LR_{lr:.6f}-model_type_{model.model_type}-hidden_size_{model.hidden_size}-n_layers_{model.n_layers}-batch_size_{batch_size}"
        )

    best_loss = float('inf')
    max_grad_norm = 1.0

    # try/finally so the TensorBoard writer is flushed and closed even when the
    # trial is pruned or training fails part-way through.
    try:
        for epoch in range(epochs):
            model.train()
            total_loss = 0.0
            progress = tqdm(loader, desc=f"Epoch {epoch+1}", leave=False)

            for batch_idx, (inputs, targets) in enumerate(progress):
                current_batch_size = inputs.size(0)
                inputs = inputs.to(device, non_blocking=True)
                targets = targets.to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)
                with torch.amp.autocast(device_type=device.type, dtype=torch.float16, enabled=use_cuda):
                    hidden = model.init_hidden(current_batch_size, device)
                    outputs, _ = model(inputs, hidden)
                    # CrossEntropyLoss expects the class axis second: (batch, classes, seq).
                    loss = criterion(outputs.transpose(1, 2), targets)
                    l2_reg = sum(p.norm(2) for p in model.parameters())
                    loss += 0.001 * l2_reg

                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the true gradients.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=max_grad_norm,
                    norm_type=2,
                    error_if_nonfinite=False
                )

                scaler.step(optimizer)
                scaler.update()
                scheduler.step()

                loss_value = loss.item()
                grad_value = grad_norm.item()
                current_lr = optimizer.param_groups[0]['lr']

                total_loss += loss_value
                progress.set_postfix({
                    'loss': f"{loss_value:.4f}",
                    'grad': f"{grad_value:.2f}",
                    'lr': f"{current_lr:.2e}"
                })

                if batch_idx % 10 == 0:
                    global_step = epoch * steps_per_epoch + batch_idx
                    writer.add_scalar('Train/Loss', loss_value, global_step)
                    writer.add_scalar('Train/Grad_Norm', grad_value, global_step)
                    writer.add_scalar('LR', current_lr, global_step)

            avg_loss = total_loss / steps_per_epoch
            writer.add_scalar('Epoch/Loss', avg_loss, epoch)

            # The gradient norm logged here is that of the last batch of the epoch.
            logging.info(f"Epoch {epoch+1}/{epochs} - "
                        f"Loss: {avg_loss:.4f} - "
                        f"Grad Norm: {grad_value:.2f} - "
                        f"LR: {optimizer.param_groups[0]['lr']:.2e}")

            # Optuna integration
            if trial is not None:
                trial.report(avg_loss, epoch)
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()

            # A NaN average never compares as smaller, so it cannot become the best loss.
            if avg_loss < best_loss:
                best_loss = avg_loss
                torch.save({
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'loss': avg_loss
                }, checkpoint_path)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    finally:
        writer.close()

    return best_loss

In [7]:
def objective(trial):
    """Optuna objective: train one sampled configuration for 10 epochs.

    Hyperparameters are drawn from ``trial``, the arXiv summaries are loaded
    from ``data/arxiv.csv`` and joined into one corpus, and the best average
    epoch loss returned by ``train_model`` is the value to minimise.
    """
    # Draw the hyperparameters for this trial.
    model_type = trial.suggest_categorical('model_type', ['lstm', 'gru'])
    hidden_size = trial.suggest_int('hidden_size', 128, 512)
    n_layers = trial.suggest_int('n_layers', 1, 4)
    dropout = trial.suggest_float('dropout', 0.1, 0.5)
    lr = trial.suggest_float('lr', 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_categorical('batch_size', [64, 128, 256])

    # Build the character dataset from all non-missing summaries.
    summaries = pd.read_csv('data/arxiv.csv')['summary'].dropna()
    corpus = ' '.join(summaries.values)
    dataset = TextDataset(corpus, chunk_len=250, stride=100)

    # Input and output vocabularies are the same character set.
    vocab = dataset.vocab_size
    network = CharRNN(
        input_size=vocab,
        hidden_size=hidden_size,
        output_size=vocab,
        model=model_type,
        n_layers=n_layers,
        dropout=dropout,
    )

    return train_model(
        network,
        dataset,
        epochs=10,
        batch_size=batch_size,
        lr=lr,
        trial=trial,
    )

In [None]:
# Set to True to rerun the Optuna search (up to 50 trials, capped at one hour).
find_best = False

if find_best:
    study = optuna.create_study(direction='minimize')
    study.optimize(objective, n_trials=50, timeout=3600)

    best_trial = study.best_trial
    print("Best trial:")
    print(f"  value: {best_trial.value}")
    # An expression inside an `if` body is not echoed by the notebook, so the
    # parameter table has to be printed explicitly.
    print(pd.DataFrame([best_trial.params]))

In [9]:
def train_best_model(params):
    """Train a CharRNN on the arXiv summaries with a fixed set of hyperparameters.

    ``params`` must provide ``model_type``, ``hidden_size``, ``n_layers``,
    ``dropout``, ``epochs``, ``batch_size`` and ``lr``.

    Returns:
        A ``(model, best_loss)`` tuple: the trained network and the best
        average epoch loss reported by ``train_model``.
    """
    # Join all non-missing summaries into one corpus.
    summaries = pd.read_csv('data/arxiv.csv')['summary'].dropna()
    corpus = ' '.join(summaries.values)
    dataset = TextDataset(corpus, chunk_len=250, stride=100)

    # Input and output vocabularies are the same character set.
    vocab = dataset.vocab_size
    model = CharRNN(
        input_size=vocab,
        hidden_size=params['hidden_size'],
        output_size=vocab,
        model=params['model_type'],
        n_layers=params['n_layers'],
        dropout=params['dropout'],
    )

    best_loss = train_model(
        model,
        dataset,
        epochs=params['epochs'],
        batch_size=params['batch_size'],
        lr=params['lr'],
    )
    return model, best_loss

In [10]:
# Hyperparameters used for the final training run.
best_params = dict(
    model_type='lstm',
    hidden_size=474,
    n_layers=1,
    dropout=0.348681,
    lr=0.003438,
    batch_size=64,
    epochs=25,
)

# Train the final model; returns the network and its best average epoch loss.
best_model, best_loss = train_best_model(best_params)

Epoch 1:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:10:16] 2888956552:99 - INFO - Epoch 1/25 - Loss: 2.2961 - Grad Norm: 0.19 - LR: 1.28e-03[0m


Epoch 2:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:11:39] 2888956552:99 - INFO - Epoch 2/25 - Loss: 1.9630 - Grad Norm: 0.18 - LR: 3.12e-03[0m


Epoch 3:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:13:01] 2888956552:99 - INFO - Epoch 3/25 - Loss: 1.9457 - Grad Norm: 0.12 - LR: 3.43e-03[0m


Epoch 4:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:14:23] 2888956552:99 - INFO - Epoch 4/25 - Loss: 1.9325 - Grad Norm: 0.09 - LR: 3.40e-03[0m


Epoch 5:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:15:45] 2888956552:99 - INFO - Epoch 5/25 - Loss: 1.9237 - Grad Norm: 0.14 - LR: 3.33e-03[0m


Epoch 6:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:17:07] 2888956552:99 - INFO - Epoch 6/25 - Loss: 1.9172 - Grad Norm: 0.10 - LR: 3.24e-03[0m


Epoch 7:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:18:28] 2888956552:99 - INFO - Epoch 7/25 - Loss: 1.9125 - Grad Norm: 0.10 - LR: 3.11e-03[0m


Epoch 8:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:19:51] 2888956552:99 - INFO - Epoch 8/25 - Loss: 1.9077 - Grad Norm: 0.10 - LR: 2.96e-03[0m


Epoch 9:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:21:13] 2888956552:99 - INFO - Epoch 9/25 - Loss: 1.9031 - Grad Norm: 0.15 - LR: 2.78e-03[0m


Epoch 10:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:22:35] 2888956552:99 - INFO - Epoch 10/25 - Loss: 1.8987 - Grad Norm: 0.14 - LR: 2.58e-03[0m


Epoch 11:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:23:57] 2888956552:99 - INFO - Epoch 11/25 - Loss: 1.8940 - Grad Norm: 0.13 - LR: 2.36e-03[0m


Epoch 12:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:25:19] 2888956552:99 - INFO - Epoch 12/25 - Loss: 1.8896 - Grad Norm: 0.24 - LR: 2.13e-03[0m


Epoch 13:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:26:41] 2888956552:99 - INFO - Epoch 13/25 - Loss: 1.8845 - Grad Norm: 0.16 - LR: 1.90e-03[0m


Epoch 14:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:28:05] 2888956552:99 - INFO - Epoch 14/25 - Loss: 1.8796 - Grad Norm: 0.13 - LR: 1.66e-03[0m


Epoch 15:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:29:27] 2888956552:99 - INFO - Epoch 15/25 - Loss: 1.8747 - Grad Norm: 0.17 - LR: 1.42e-03[0m


Epoch 16:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:30:50] 2888956552:99 - INFO - Epoch 16/25 - Loss: 1.8696 - Grad Norm: 0.13 - LR: 1.19e-03[0m


Epoch 17:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:32:12] 2888956552:99 - INFO - Epoch 17/25 - Loss: 1.8645 - Grad Norm: 0.13 - LR: 9.65e-04[0m


Epoch 18:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:33:34] 2888956552:99 - INFO - Epoch 18/25 - Loss: 1.8596 - Grad Norm: 0.15 - LR: 7.58e-04[0m


Epoch 19:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:34:57] 2888956552:99 - INFO - Epoch 19/25 - Loss: 1.8550 - Grad Norm: 0.14 - LR: 5.69e-04[0m


Epoch 20:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:36:19] 2888956552:99 - INFO - Epoch 20/25 - Loss: 1.8507 - Grad Norm: 0.19 - LR: 4.02e-04[0m


Epoch 21:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:37:40] 2888956552:99 - INFO - Epoch 21/25 - Loss: 1.8468 - Grad Norm: 0.18 - LR: 2.61e-04[0m


Epoch 22:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:39:04] 2888956552:99 - INFO - Epoch 22/25 - Loss: 1.8435 - Grad Norm: 0.22 - LR: 1.49e-04[0m


Epoch 23:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:40:27] 2888956552:99 - INFO - Epoch 23/25 - Loss: 1.8408 - Grad Norm: 0.16 - LR: 6.66e-05[0m


Epoch 24:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:41:51] 2888956552:99 - INFO - Epoch 24/25 - Loss: 1.8390 - Grad Norm: 0.18 - LR: 1.67e-05[0m


Epoch 25:   0%|          | 0/4591 [00:00<?, ?it/s]

[32m>>> [2025-02-17 | 04:43:14] 2888956552:99 - INFO - Epoch 25/25 - Loss: 1.8379 - Grad Norm: 0.18 - LR: 1.38e-08[0m


In [24]:
# Report the best average epoch loss returned by the final training run.
print(f"Best Loss: {best_loss}")

Best Loss: 1.8379204218955218
