# Transformer模型变体：从序列任务到文本生成的架构调整
Transformer模型的经典设计围绕**序列到序列（Sequence-to-Sequence, Seq2Seq）任务**展开，例如机器翻译、文本摘要等。为实现“输入序列→输出序列”的映射，它主要依赖两个核心组件：
1. **编码器-解码器（Encoder-Decoder）架构**：编码器负责对输入源序列进行特征提取与语义编码，解码器则基于编码结果生成目标序列；
2. **解码器中的交叉注意力（Cross-Attention）机制**：这一关键组件能让解码器在生成每一步时，精准关联源序列的相关信息，确保输出与输入的语义一致性。


如果我们要实现的是文本生成，而非机器翻译这类序列到序列任务，可以做如下调整：
- **保留Transformer编码器核心**：复用编码器的多头自注意力（Multi-Head Self-Attention）、前馈神经网络（Feed-Forward Network）等模块，确保对文本序列的语义理解与特征建模能力。与编码器不同的是，这里的自注意力必须加上因果掩码（Causal Mask），使每个位置只能关注自身及之前的token，否则模型在训练时会“偷看”到它要预测的下一个字；
- **移除交叉注意力机制**：由于文本生成任务更侧重“基于历史生成内容延续序列”，而非“关联外部源序列”，交叉注意力不再是必需组件，移除后可简化模型结构、降低计算成本。

通过上述调整，即可得到一个专为文本生成优化的Transformer变体：它仅由带因果掩码自注意力的层堆叠而成，即仅解码器（Decoder-only）结构，类似GPT。

In [None]:
import torch
from torch import nn

# 位置编码与嵌入

我们同样需要对输入序列进行词嵌入和位置编码

In [None]:
class EmbeddingPositionEncode(nn.Module):
    """Token embedding plus fixed sinusoidal positional encoding, followed by dropout.

    Args:
        d_model: Embedding width. May be odd; the last (sine) column then has
            no cosine partner.
        dropout: Dropout probability applied to the summed embedding.
        vocab_size: Number of token ids accepted by the embedding table.
    """

    def __init__(self, d_model, dropout: float, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)  # dropout to reduce overfitting
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Embed token ids and add sinusoidal positional encodings.

        Args:
            input_tensor: Integer token ids of shape (batch_size, seq_len).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        device = input_tensor.device
        # token_emb: (batch_size, seq_len, d_model)
        token_emb = self.embedding(input_tensor)
        seq_len = input_tensor.shape[1]

        # position: (seq_len, 1) column vector, so it broadcasts against div_term.
        position = torch.arange(seq_len, device=device).unsqueeze(1)
        # div_term: (ceil(d_model / 2),) = 10000^(-2i / d_model)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device)
            * (-torch.log(torch.tensor(10000.0, device=device)) / self.d_model)
        )
        angles = position * div_term  # (seq_len, ceil(d_model / 2))

        # Match the embedding dtype so the sum does not silently upcast (e.g. fp16 models).
        pos_encoding = torch.zeros(
            seq_len, self.d_model, device=device, dtype=token_emb.dtype
        )
        pos_encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than angle columns; trim to fit.
        pos_encoding[:, 1::2] = torch.cos(angles)[:, : self.d_model // 2]

        # unsqueeze(0) adds the batch dimension so the encoding broadcasts over the batch.
        return self.dropout(token_emb + pos_encoding.unsqueeze(0))

# 多头注意力
使用带掩码的多头注意力。这里有两种掩码：因果掩码防止模型看到未来位置的信息；填充掩码（key padding mask）屏蔽无意义的`<pad>`字符。

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional causal and key-padding masks.

    Args:
        d_model: Model width; must be divisible by ``heads``.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        mask: If True, apply a causal mask so query position i only attends to
            key positions <= i (query and key positions are both counted from 0).

    Raises:
        ValueError: If ``d_model`` is not divisible by ``heads``.
    """

    def __init__(
        self, d_model: int, heads: int, dropout: float = 0, mask: bool = False
    ):

        super().__init__()
        if d_model % heads != 0:
            # Otherwise head_dim is truncated and forward() fails with an opaque .view() error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # project concatenated heads back to d_model

        self.d_model = d_model
        self.mask = mask

        self.heads = heads

        self.head_dim = d_model // heads
        self.dropout = nn.Dropout(dropout)

    def _split_heads(
        self, x: torch.Tensor, batch_size: int, seq_len: int
    ) -> torch.Tensor:
        """Reshape (batch_size, seq_len, d_model) -> (batch_size, heads, seq_len, head_dim)."""
        return x.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: (batch_size, seq_len_q, d_model).
            key: (batch_size, seq_len_k, d_model).
            value: (batch_size, seq_len_k, d_model).
            key_padding_mask: Optional (batch_size, seq_len_k); True (or non-zero)
                marks key positions (e.g. <pad>) that must receive no attention.

        Returns:
            (batch_size, seq_len_q, d_model).
        """
        batch_size, seq_len_q = query.size(0), query.size(1)
        seq_len_k = key.size(1)

        # Linear projections, split into heads.
        q = self._split_heads(self.W_q(query), batch_size, seq_len_q)
        k = self._split_heads(self.W_k(key), batch_size, seq_len_k)
        v = self._split_heads(self.W_v(value), batch_size, seq_len_k)

        # scores: (batch_size, heads, seq_len_q, seq_len_k)
        scores = q @ k.transpose(-2, -1)

        if self.mask:
            # Causal mask: hide future positions. Built directly on the scores'
            # device to avoid a CPU allocation + host-to-device copy per call.
            future = torch.ones(
                seq_len_q, seq_len_k, dtype=torch.bool, device=scores.device
            ).triu(diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))

        if key_padding_mask is not None:
            # (batch_size, seq_len_k) -> (batch_size, 1, 1, seq_len_k) to broadcast
            # over every head and query position.
            pad = key_padding_mask.bool()[:, None, None, :]
            # Use the most negative finite value of the dtype: -1e9 overflows fp16.
            scores = scores.masked_fill(pad, torch.finfo(scores.dtype).min)

        # Scale, then normalize over the key axis.
        attention = torch.softmax(scores / (self.head_dim ** 0.5), dim=-1)
        attention = self.dropout(attention)

        # Weighted sum: (batch_size, heads, seq_len_q, head_dim)
        out = attention @ v
        # Merge heads back: (batch_size, seq_len_q, d_model)
        out = out.transpose(1, 2).reshape(batch_size, seq_len_q, self.d_model)
        return self.W_o(out)

# 解码器层

在原来解码器层的基础上剔除了交叉注意力，仅保留带掩码的自注意力子层和前馈子层。每个子层都采用“残差连接 + 层归一化”。

In [None]:
class DecoderLayer(nn.Module):
    """Decoder block without cross-attention: causal self-attention + feed-forward.

    Each sublayer is wrapped as ``LayerNorm(input + Dropout(sublayer(input)))``,
    i.e. the normalization is applied after the residual addition.
    """

    def __init__(self, d_model, heads, dropout):
        super().__init__()
        hidden_width = d_model * 4  # conventional 4x expansion in the FFN
        self.multi_head_attention = MultiHeadAttention(
            d_model, heads, dropout, mask=True
        )
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask):
        """Run self-attention then feed-forward; output has the same shape as ``x``."""
        attended = self.multi_head_attention(x, x, x, key_padding_mask)
        normed = self.layer_norm_1(x + self.dropout(attended))

        expanded = self.feed_forward(normed)
        return self.layer_norm_2(normed + self.dropout(expanded))


class TextGenerate(nn.Module):
    """Decoder-only Transformer language model.

    Pipeline: token embedding + sinusoidal positions -> ``num_layers`` causal
    decoder layers -> linear projection to vocabulary logits.
    """

    def __init__(self, d_model, vocab_size, num_layers=6, heads=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding_pos_encode = EmbeddingPositionEncode(
            d_model, dropout, vocab_size
        )

        # Stack of identical decoder layers.
        stacked = [DecoderLayer(d_model, heads, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stacked)

        self.final_linear = nn.Linear(d_model, vocab_size)

    def forward(self, x, key_padding_mask):
        """Map token ids (batch_size, seq_len) to logits (batch_size, seq_len, vocab_size)."""
        hidden = self.embedding_pos_encode(x)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, key_padding_mask)
        logits = self.final_linear(hidden)
        return logits

In [None]:
def save_checkpoint(epoch, model, optimizer, scheduler, loss, path):
    """Write a resumable training checkpoint to ``path`` with ``torch.save``.

    The saved dict holds the epoch, the model/optimizer/scheduler state dicts,
    the scheduler's class name and the loss value.
    """
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    scheduler_state = scheduler.state_dict()
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model_state,
            "optimizer_state_dict": optimizer_state,
            "scheduler_state": scheduler_state,
            "scheduler_type": scheduler.__class__.__name__,
            "loss": loss,
        },
        path,
    )


def load_checkpoint(model, optimizer, scheduler, path, map_location="cpu"):
    """Restore training state saved by ``save_checkpoint``.

    Args:
        model: Module to load weights into, or None to skip.
        optimizer: Optimizer to restore, or None to skip.
        scheduler: LR scheduler to restore, or None to skip.
        path: Checkpoint file path; None means "no checkpoint".
        map_location: Forwarded to ``torch.load``. The default "cpu" lets a
            checkpoint saved on a GPU be loaded on a CPU-only machine;
            ``load_state_dict`` then copies tensors onto the targets' devices.

    Returns:
        ``(epoch, loss)`` stored in the checkpoint, or ``(0, inf)`` when
        ``path`` is None.
    """
    if path is None:
        print("未发现检查点")
        return 0, float("inf")

    checkpoint = torch.load(path, map_location=map_location)
    # Explicit None checks: truthiness of e.g. an empty nn.Sequential is False.
    if model is not None:
        model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(f"从{checkpoint['epoch']}开始训练")
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint["scheduler_state"])
    return checkpoint["epoch"], checkpoint["loss"]

# 训练

训练步骤基本保持不变

In [None]:
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from LyricsDataset import LyricsDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 24
dataset = LyricsDataset("../data/generate/lyrics.csv", nrows=-1, batch_size=batch_size)
dataloader = DataLoader(dataset, batch_size)

epochs = 300

d_model = 512
vocab_size = len(dataset.token_to_index)
heads = 8
num_layers = 6
dropout = 0.1
model = TextGenerate(
    d_model=d_model,
    vocab_size=vocab_size,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
)


padding_idx = dataset.token_to_index["<pad>"]
# Padding positions in the target contribute nothing to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx)

# LambdaLR multiplies this base rate by noam_lr_factor(step), so with lr = 1.0
# the effective learning rate is exactly the Noam schedule value.
lr = 1.0
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=1e-4,
)

# scheduler.step() runs once per batch, i.e. len(dataloader) times per epoch.
total_steps = epochs * len(dataloader)
warmup_steps = max(1, int(total_steps * 0.4))


def noam_lr_factor(step, d_model=d_model, warmup_steps=warmup_steps):
    """Noam schedule: linear warm-up for ``warmup_steps`` steps, then ~step**-0.5 decay.

    ``d_model``/``warmup_steps`` are bound as defaults, so later reassignment of
    the notebook globals cannot silently change the schedule.
    """
    step += 1  # LambdaLR starts at step 0; avoid 0 ** -0.5
    return d_model ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))


scheduler = LambdaLR(optimizer=optimizer, lr_lambda=noam_lr_factor)
print(f"total_steps:{total_steps},warmup_steps:{warmup_steps}")

model.train().to(device)

error = []  # average training loss per epoch, plotted at the end
path = None
start_epoch = 0
# To resume, point `path` at a checkpoint file and uncomment:
# start_epoch, loss = load_checkpoint(model, optimizer, scheduler, path)
for epoch in range(start_epoch, epochs):
    total_loss = 0.0
    for batch, (src, tgt) in enumerate(dataloader, start=1):
        # Clamp so the last (possibly partial) batch does not report more samples than exist.
        print(f"batch {batch}, {min(batch * batch_size, len(dataset))}/{len(dataset)}")
        # True where src is <pad>; those keys are masked out in self-attention.
        src_key_padding_mask = (src == padding_idx).to(device)
        src, tgt = src.to(device), tgt.to(device)

        pred = model(src, src_key_padding_mask)
        loss = loss_fn(pred.reshape(-1, vocab_size), tgt.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    error.append(avg_loss)
    if (epoch + 1) % 10 == 0:
        print(
            f"epoch {epoch + 1}, loss: {avg_loss:.6f}, perplexity: {torch.exp(torch.tensor(avg_loss)).item():.6f}"
        )
    # Small dataset (~24k samples): save plain weights every epoch instead of full checkpoints.
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")
    # if (epoch + 1) % 10 == 0:
    #     path_to_save = None
    #     save_checkpoint(epoch + 1, model, optimizer, scheduler, loss, path_to_save)
plt.style.use("ggplot")
plt.plot(error)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

# 预测函数

In [None]:
def predict(
    text: str,
    model: nn.Module,
    max_length: int,
    separator: str,
    device: str,
    to_index: dict,
    to_token: list,
    tempreture: float = 0.75,  # lower -> more deterministic/repetitive; higher -> more diverse (may be incoherent)
) -> str:
    """Generate text segment by segment, continuing each ``separator``-delimited prompt.

    Each segment of ``text`` is appended to everything generated so far. The
    model then extends that context by sampling up to ``max_length`` tokens,
    stopping early at ``<eos>``. A "，" is appended before the next segment, so
    later lines are conditioned on earlier ones.

    Args:
        text: Prompt; every piece produced by ``text.split(separator)`` starts a line.
        model: Called as ``model(ids, None)``, returning logits (batch, seq, vocab).
        max_length: Maximum number of tokens sampled per segment.
        separator: Delimiter between segment prompts in ``text``.
        device: Device the model and inputs are moved to.
        to_index: Token -> id mapping; must contain "<bos>".
        to_token: Id -> token mapping.
        tempreture: Sampling temperature (parameter name kept for backward
            compatibility); must be > 0.

    Returns:
        The generated text, with doubled "，，" collapsed to "，".

    Raises:
        ValueError: If ``tempreture`` <= 0.
        KeyError: If a prompt character is not in ``to_index`` and there is no "<unk>".
    """
    if tempreture <= 0:
        raise ValueError(f"tempreture must be positive, got {tempreture}")

    model.eval().to(device)
    bos_id = to_index["<bos>"]
    unk_id = to_index.get("<unk>")

    def to_id(token):
        """Map a token to its id, falling back to <unk> when the vocab has one."""
        if token in to_index:
            return to_index[token]
        if unk_id is None:
            raise KeyError(token)
        return unk_id

    def generate(context_tokens):
        """Return ids of ``context_tokens`` followed by the sampled continuation.

        <bos> is prepended only to the model input and never stored in the
        returned context, so it appears exactly once however many segments
        have been generated.
        """
        generated = [to_id(tok) for tok in context_tokens]
        input_ids = torch.tensor([bos_id] + generated, device=device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(max_length):
                logits = model(input_ids, None)[:, -1, :] / tempreture
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.distributions.Categorical(probs).sample()
                # Feed the sampled token back as input for the next step.
                input_ids = torch.cat((input_ids, next_id.unsqueeze(0)), dim=-1)
                if to_token[next_id.item()] == "<eos>":
                    break
                generated.append(next_id.item())
        return generated

    tokens = []
    for segment in text.split(separator):
        # Previous output + new prompt forms the context, linking consecutive lines.
        tokens.extend(segment)
        tokens = [to_token[idx] for idx in generate(tokens)]
        tokens.append("，")

    # No <bos> is ever stored in `tokens`, so no stripping is needed (str.strip
    # would remove a *character set*, eating real leading text such as "bob").
    return "".join(tokens).replace("，，", "，")

# 使用训练好的模型

下面加载的模型使用lyrics.csv数据集、按下面代码中的配置在3090 GPU上训练得到，加载的是第200个epoch保存的权重。

In [None]:
from LyricsDataset import LyricsDataset

batch_size = 24
dataset = LyricsDataset("../data/generate/lyrics.csv", nrows=-1, batch_size=batch_size)


d_model = 512
vocab_size = len(dataset.token_to_index)
heads = 8
num_layers = 6
dropout = 0.1
model = TextGenerate(
    d_model=d_model,
    vocab_size=vocab_size,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
)

# Fall back to CPU so this cell also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# map_location lets weights saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(
    torch.load("../3090_trained/model_epoch_200.pth", map_location=device)
)  # load the trained weights
text = "晚风/我们"
generated_lyrics = predict(
    text,
    model,
    50,
    "/",
    device,
    dataset.token_to_index,
    dataset.index_to_token,
    tempreture=0.95,
)


def wrap_text(text, width=60):
    """Hard-wrap ``text`` into consecutive ``width``-character chunks joined by newlines."""
    lines = []
    for offset in range(0, len(text), width):
        lines.append(text[offset : offset + width])
    return "\n".join(lines)


# Display the generated lyrics hard-wrapped at wrap_text's default width (60 characters).
print(wrap_text(generated_lyrics))