在手动实现文本生成模型之后，我们可以改用 PyTorch 内置的 nn.TransformerDecoderLayer（并用 nn.TransformerDecoder 堆叠多层）来简化实现。

In [None]:
import torch
from torch import nn

# 位置编码与嵌入

In [None]:
class EmbeddingPositionEncode(nn.Module):
    """Token embedding plus fixed sine/cosine positional encoding, then dropout.

    Args:
        d_model: Embedding width. Both even and odd widths are supported.
        dropout: Dropout probability applied to the summed embeddings.
        vocab_size: Number of rows in the token embedding table.
    """

    def __init__(self, d_model, dropout: float, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)  # dropout to reduce overfitting
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Embed token ids and add sinusoidal positional encodings.

        Args:
            input_tensor: Token ids of shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model) after dropout.
        """
        # token_emb: (batch_size, seq_len, d_model)
        token_emb = self.embedding(input_tensor)
        seq_len = input_tensor.shape[1]
        device = input_tensor.device

        # position: column vector (seq_len, 1) so it broadcasts against div_term.
        position = torch.arange(seq_len, device=device).unsqueeze(1)
        # div_term[i] = 10000^(-2i / d_model): one frequency per sin/cos pair.
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device)
            * (-torch.log(torch.tensor(10000.0, device=device)) / self.d_model)
        )

        pos_encoding = torch.zeros(seq_len, self.d_model, device=device)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than there are
        # frequencies, so trim div_term to avoid a shape mismatch.
        pos_encoding[:, 1::2] = torch.cos(position * div_term[: self.d_model // 2])

        # Match the embedding dtype (e.g. after model.half()) and add a batch
        # dimension so the table broadcasts over every sequence in the batch.
        pos_encoding = pos_encoding.to(token_emb.dtype).unsqueeze(0)
        return self.dropout(token_emb + pos_encoding)

# 使用 PyTorch 内置模块（nn.TransformerDecoder）实现的模型

In [None]:
class TextGenerate(nn.Module):
    """Decoder-only text generation model built on ``nn.TransformerDecoder``.

    There is no encoder, so the decoder's cross-attention sublayer receives an
    all-zero ``memory`` and therefore carries no information from the input.

    Args:
        d_model: Embedding / hidden width.
        vocab_size: Number of token ids (size of the output logits).
        num_layers: Number of stacked decoder layers.
        heads: Attention heads per layer.
        dropout: Dropout used in the embedding block and every decoder layer.
    """

    def __init__(self, d_model, vocab_size, num_layers=6, heads=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding_pos_encode = EmbeddingPositionEncode(
            d_model, dropout, vocab_size
        )
        # Stack num_layers decoder layers with nn.TransformerDecoder.
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model, nhead=heads, dropout=dropout, batch_first=True
            ),
            num_layers=num_layers,
        )

        self.final_linear = nn.Linear(d_model, vocab_size)

    def forward(self, x, key_padding_mask=None, tgt_mask=None):
        """Compute next-token logits for every position of ``x``.

        Args:
            x: Token ids of shape (batch_size, seq_len).
            key_padding_mask: Optional (batch_size, seq_len) bool mask where
                True marks padding positions to ignore.
            tgt_mask: Optional (seq_len, seq_len) causal attention mask, either
                bool (True = blocked) or float (-inf = blocked). When None, a
                causal bool mask is built on the input's device.

        Returns:
            Logits of shape (batch_size, seq_len, vocab_size).
        """
        x = self.embedding_pos_encode(x)
        seq_len = x.shape[1]

        if tgt_mask is None:
            # PyTorch requires an explicit mask when tgt_is_causal=True.
            tgt_mask = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=x.device
            ).triu(diagonal=1)
        else:
            # Callers may build the mask on CPU; attention needs it on x's device.
            tgt_mask = tgt_mask.to(x.device)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.to(x.device)

        # memory would normally be the encoder output; there is no encoder here,
        # so an all-zero tensor is used instead.
        memory = torch.zeros_like(x)
        x = self.decoder(
            x,
            memory,
            tgt_key_padding_mask=key_padding_mask,
            tgt_mask=tgt_mask,
            tgt_is_causal=True,
        )

        return self.final_linear(x)

In [None]:
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from LyricsDataset import LyricsDataset
from torch.utils.data import random_split

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 24
dataset = LyricsDataset("../data/generate/lyrics.csv", nrows=-1, batch_size=batch_size)
# 90% training / 10% validation. A fixed seed keeps the split identical across
# runs, so resuming from a checkpoint does not mix validation data into training.
train_dataset, test_dataset = random_split(
    dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(42)
)

# Reshuffle the training samples every epoch; evaluation order does not matter.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn
)

epochs = 300

d_model = 512
vocab_size = len(dataset.token_to_index)
heads = 8
num_layers = 6
dropout = 0.1
# Place the model on the target device before building the optimizer.
model = TextGenerate(
    d_model=d_model,
    vocab_size=vocab_size,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
).to(device)


padding_idx = dataset.token_to_index["<pad>"]
# Padding positions in the targets do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx)

# Base lr of 1.0: the effective learning rate is set entirely by the schedule.
lr = 1.0
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=1e-4,
)

# The scheduler steps once per training batch, so size the schedule by the
# number of training batches (the 90% split), not by the whole dataset.
total_steps = epochs * len(train_loader)
warmup_steps = max(1, int(total_steps * 0.4))
# Noam schedule: d_model^-0.5 * min(step^-0.5, step * warmup^-1.5), 1-based step.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer=optimizer,
    lr_lambda=lambda step: (
        d_model ** (-0.5)  # inverse square root of the model width
        * min(
            (step + 1) ** (-0.5),  # decay phase: inverse square root of the step
            (step + 1) * (warmup_steps ** (-1.5)),  # warmup phase: linear growth
        )
    ),
)
print(f"total_steps:{total_steps},warmup_steps:{warmup_steps}")


def evoluate(model, test_loader, device):
    """Return the average per-batch loss of ``model`` over ``test_loader``.

    Uses the module-level ``padding_idx``, ``loss_fn`` and ``vocab_size``.
    The model is left in eval mode on ``device``.
    """
    model.eval().to(device)
    batch_losses = []
    with torch.no_grad():
        for inputs, targets in test_loader:
            # True marks padding positions that attention must ignore.
            padding_mask = (inputs == padding_idx).to(device)
            # Causal mask; converted to bool so True = position may not be attended.
            causal_mask = nn.Transformer.generate_square_subsequent_mask(
                inputs.shape[1]
            ).to(device)
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs, padding_mask, causal_mask.bool())
            batch_loss = loss_fn(logits.reshape(-1, vocab_size), targets.reshape(-1))
            batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(test_loader)


best_val_loss = float("inf")
error = []  # mean training loss of each epoch, plotted after training
path = None
start_epoch = 0
# start_epoch, loss = load_checkpoint(model, optimizer, scheduler, path)
# Progress is reported against the training subset, not the full dataset.
num_train_samples = len(train_loader.dataset)
for epoch in range(start_epoch, epochs):
    model.train().to(device)
    total_loss = 0

    for batch, (src, tgt) in enumerate(train_loader, start=1):
        # The last batch may be short, so cap the count at the subset size.
        seen = min(batch * batch_size, num_train_samples)
        print(f"batch {batch}, {seen}/{num_train_samples}")
        src, tgt = src.to(device), tgt.to(device)
        src_key_padding_mask = src == padding_idx  # True marks padding positions
        # Causal mask; converted to bool so True = position may not be attended.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            src.shape[1]
        ).to(device)
        pred = model(src, src_key_padding_mask, tgt_mask.bool())
        loss = loss_fn(pred.reshape(-1, vocab_size), tgt.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        scheduler.step()  # the LR schedule advances once per training batch
        total_loss += loss.item()

    val_loss = evoluate(model, test_loader, device)
    if val_loss < best_val_loss:
        # Keep the weights with the lowest validation loss so far.
        torch.save(model.state_dict(), "best_model.pth")
        best_val_loss = val_loss
    avg_loss = total_loss / len(train_loader)
    error.append(avg_loss)
    if (epoch + 1) % 10 == 0:
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        print(
            f"epoch {epoch + 1}, loss: {avg_loss:.6f}, perplexity: {perplexity:.6f},val_loss: {val_loss:.6f}"
        )
    # torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")
    # if (epoch + 1) % 10 == 0:
    #     path_to_save = None
    #     save_checkpoint(epoch + 1, model, optimizer, scheduler, loss, path_to_save)
plt.style.use("ggplot")
plt.plot(error)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

# 预测函数

In [None]:
def predict(
    text: str,
    model: nn.Module,
    max_length: int,
    separator: str,
    device: str,
    to_index: dict,
    to_token: list,
    temperature: float = 0.95,
):
    """Generate lyrics segment by segment by sampling from ``model``.

    ``text`` is split on ``separator``. Each piece is appended to everything
    generated so far and used as the prompt for the next continuation. Each
    continuation samples at most ``max_length`` new tokens and stops early at
    ``<eos>``. A "，" is appended after every generated segment.

    Args:
        text: Prompt text; ``separator`` separates the per-segment prompts.
        model: Model called as ``model(ids, key_padding_mask, tgt_mask)`` that
            returns (batch, seq_len, vocab) logits.
        max_length: Maximum number of tokens sampled per segment.
        separator: Delimiter used to split ``text`` into segment prompts.
        device: Device the model and inputs are placed on.
        to_index: Token -> id mapping. It must contain "<bos>" and every prompt
            character, plus "，" when there is more than one segment.
        to_token: Id -> token lookup.
        temperature: Softmax temperature; values <= 0 pick the most likely
            token greedily instead of sampling.

    Returns:
        The concatenated generated text, without any ``<bos>`` marker.

    Raises:
        KeyError: If a prompt token is missing from ``to_index``.
    """
    model.eval().to(device)
    bos_id = to_index["<bos>"]

    def generate(context_ids):
        """Return ``context_ids`` extended with newly generated ids (no <bos>)."""
        generated = list(context_ids)
        # <bos> is prepended exactly once, and only to the model input.
        tensor_text = torch.tensor([bos_id] + generated, device=device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(max_length):
                # The causal mask must live on the same device as the input.
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                    tensor_text.shape[1]
                ).to(tensor_text.device)
                logits = model(tensor_text, None, tgt_mask)[:, -1, :]
                if temperature > 0:
                    # Sample from the temperature-scaled distribution.
                    dist = torch.distributions.Categorical(logits=logits / temperature)
                    next_id = dist.sample()
                else:
                    next_id = logits.argmax(dim=-1)
                tensor_text = torch.cat((tensor_text, next_id.unsqueeze(0)), dim=-1)
                if to_token[next_id.item()] == "<eos>":
                    break
                generated.append(next_id.item())
        return generated

    generate_text = []
    for splitted_text in text.split(separator):
        prompt_ids = [to_index[token] for token in generate_text + list(splitted_text)]
        generate_text = [to_token[idx] for idx in generate(prompt_ids)]
        generate_text.append("，")

    return "".join(generate_text)


text = "玫瑰"
# Generate on whichever device was selected for training instead of
# hard-coding "cuda", so the demo also runs on CPU-only machines.
generated_lyrics = predict(
    text,
    model,
    500,
    "/",
    device,
    dataset.token_to_index,
    dataset.index_to_token,
    temperature=0.95,
)
print(generated_lyrics)