# 概述
在本文档中，我将使用自行实现的 TextGenerate 文本生成模型，演示不同解码策略对文本生成质量的影响，主要包括：<br>
`贪心搜索`<br>
`概率采样`<br>
`束搜索`<br>
`Top-K 采样`<br>
`Top-P 采样`<br>

#  模型定义与训练


## 模型定义
后面各解码策略的演示，均使用按以下配置训练得到的同一个模型。

In [1]:
import torch
from torch import nn


# Token embedding plus sinusoidal positional encoding
class EmbeddingPositionEncode(nn.Module):
    """Embed token ids and add fixed sine/cosine positional encodings.

    Args:
        d_model: Width of each embedding vector (odd values are supported).
        dropout: Dropout probability applied to the summed embedding.
        vocab_size: Number of rows in the embedding table.
    """

    def __init__(self, d_model, dropout: float, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(embedding(input_tensor) + positional_encoding)``.

        ``input_tensor`` is indexed as (batch_size, seq_len); the result is
        (batch_size, seq_len, d_model).
        """
        # token_embed: (batch_size, seq_len, d_model)
        token_embed = self.embedding(input_tensor)
        seq_len = input_tensor.shape[1]
        device = input_tensor.device

        # Column vector of positions, shape (seq_len, 1), so that it broadcasts
        # against the per-channel frequencies below.
        position = torch.arange(seq_len, device=device).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device)
            * (-torch.log(torch.tensor(10000.0, device=device)) / self.d_model)
        )
        # angles: (seq_len, ceil(d_model / 2))
        angles = position * div_term

        pos_encoding = torch.zeros(seq_len, self.d_model, device=device)
        pos_encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine channel fewer than sine
        # channels, so trim the angles to the number of odd-indexed columns.
        pos_encoding[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])

        # Broadcast the (1, seq_len, d_model) encoding over the batch.
        return self.dropout(token_embed + pos_encoding.unsqueeze(0))


# Multi-head attention
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split over several heads.

    Args:
        d_model: Model width; must be divisible by ``heads``.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        mask: When True, a causal mask hides keys that come after each query.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``heads``.
    """

    def __init__(
            self, d_model: int, heads: int, dropout: float = 0, mask: bool = False
    ):

        super().__init__()
        if d_model % heads != 0:
            # Fail early with a clear message instead of an opaque view() error in forward().
            raise ValueError(
                f"d_model ({d_model}) must be divisible by heads ({heads})"
            )
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)  # projects the concatenated heads back to d_model

        self.d_model = d_model
        self.mask = mask

        self.heads = heads

        self.head_dim = d_model // heads
        self.dropout = nn.Dropout(dropout)

    def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            key_padding_mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: (batch_size, seq_len_q, d_model).
            key: (batch_size, seq_len_k, d_model).
            value: (batch_size, seq_len_k, d_model).
            key_padding_mask: Optional (batch_size, seq_len_k) mask; non-zero /
                True entries mark keys that must not be attended to.

        Returns:
            Tensor of shape (batch_size, seq_len_q, d_model).
        """
        batch_size, seq_len_q = query.size(0), query.size(1)
        seq_len_k = key.size(1)

        def split_heads(projected: torch.Tensor, seq_len: int) -> torch.Tensor:
            # (batch_size, seq_len, d_model) -> (batch_size, heads, seq_len, head_dim)
            return projected.view(
                batch_size, seq_len, self.heads, self.head_dim
            ).transpose(1, 2)

        q = split_heads(self.W_q(query), seq_len_q)
        k = split_heads(self.W_k(key), seq_len_k)
        v = split_heads(self.W_v(value), seq_len_k)

        # scores: (batch_size, heads, seq_len_q, seq_len_k)
        scores = q @ k.transpose(-2, -1)

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        # Built directly on the scores' device to avoid a host-to-device copy per call.
        if self.mask:
            causal = torch.triu(
                torch.full(
                    (seq_len_q, seq_len_k), float("-inf"), device=scores.device
                ),
                diagonal=1,
            )
            scores = scores + causal

        # Hide <pad> keys.
        if key_padding_mask is not None:
            if key_padding_mask.dtype != torch.bool:
                key_padding_mask = key_padding_mask.bool()

            # (batch_size, seq_len_k) -> (batch_size, 1, 1, seq_len_k)
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)

            # Use the dtype's own minimum so the fill value is representable
            # even when the scores are half precision.
            scores = scores.masked_fill(
                key_padding_mask, torch.finfo(scores.dtype).min
            )

        # Scale, then normalise over the key axis.
        attention = torch.softmax(scores / (self.head_dim ** 0.5), dim=-1)
        # attention: (batch_size, heads, seq_len_q, seq_len_k)
        attention = self.dropout(attention)

        # Weighted sum of values: (batch_size, heads, seq_len_q, head_dim)
        out = attention @ v
        # Concatenate the heads: (batch_size, seq_len_q, d_model)
        out = (
            out.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len_q, self.heads * self.head_dim)
        )
        return self.W_o(out)


# Decoder layer
class DecoderLayer(nn.Module):
    """One post-norm decoder block: causal self-attention followed by a feed-forward network.

    Args:
        d_model: Model width.
        heads: Number of attention heads.
        dropout: Dropout probability used inside attention, inside the
            feed-forward network and on both residual branches.
    """

    def __init__(self, d_model, heads, dropout):
        super().__init__()
        self.multi_head_attention = MultiHeadAttention(
            d_model, heads, dropout, mask=True
        )
        self.layer_norm_1 = nn.LayerNorm(d_model)
        self.layer_norm_2 = nn.LayerNorm(d_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key_padding_mask=None):
        """Apply the block to ``x``.

        Args:
            x: Input activations; used as query, key and value of the
                self-attention.
            key_padding_mask: Optional padding mask forwarded to the attention
                module. Defaults to None, matching
                ``MultiHeadAttention.forward``.

        Returns:
            The transformed activations.
        """
        # Self-attention sub-layer: residual add, then layer norm
        residual = x
        x = self.multi_head_attention(x, x, x, key_padding_mask)
        x = self.layer_norm_1(residual + self.dropout(x))

        # Feed-forward sub-layer: residual add, then layer norm
        residual = x
        x = self.feed_forward(x)
        x = self.layer_norm_2(residual + self.dropout(x))

        return x


class TextGenerate(nn.Module):
    """Decoder-only language model: embedding, stacked decoder layers, vocabulary projection.

    Args:
        d_model: Model width.
        vocab_size: Size of the token vocabulary (embedding rows and output logits).
        num_layers: Number of stacked ``DecoderLayer`` blocks.
        heads: Number of attention heads per layer.
        dropout: Dropout probability passed to the embedding and every layer.
    """

    def __init__(self, d_model, vocab_size, num_layers=6, heads=8, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding_pos_encode = EmbeddingPositionEncode(
            d_model, dropout, vocab_size
        )

        # Stack of decoder layers
        self.layers = nn.ModuleList(
            [DecoderLayer(d_model, heads, dropout) for _ in range(num_layers)]
        )

        self.final_linear = nn.Linear(d_model, vocab_size)

    def forward(self, x, key_padding_mask=None):
        """Return next-token logits for every position of ``x``.

        Args:
            x: Token ids passed to the embedding.
            key_padding_mask: Optional padding mask handed to every decoder
                layer. Defaults to None for un-padded input.

        Returns:
            The output of the final linear projection, whose last dimension
            has ``vocab_size`` entries.
        """
        x = self.embedding_pos_encode(x)

        for layer in self.layers:
            x = layer(x, key_padding_mask)

        return self.final_linear(x)

## 训练

In [4]:
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from LyricsDataset import LyricsDataset
from tqdm import tqdm
from torch.utils.data import random_split


def save_checkpoint(epoch, model, optimizer, scheduler, loss, path):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict with the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``scheduler_state``, ``scheduler_type`` and
    ``loss``.

    Args:
        epoch: Epoch number recorded in the checkpoint.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` and class name are saved.
        loss: Loss value stored alongside the states.
        path: Destination handed to ``torch.save``. When it is a filesystem
            path, missing parent directories are created first.
    """
    import os

    # torch.save does not create directories; do it here so the first save
    # into a fresh checkpoint folder does not fail.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "scheduler_type": type(scheduler).__name__,
        "loss": loss,
    }
    torch.save(checkpoint, path)


def load_checkpoint(model, optimizer, scheduler, path):
    """Restore training state from ``path`` and return ``(epoch, loss)``.

    Args:
        model: Module to restore, or None to skip.
        optimizer: Optimizer to restore, or None to skip.
        scheduler: LR scheduler to restore, or None to skip.
        path: Checkpoint location, or None when there is nothing to resume.

    Returns:
        ``(checkpoint["epoch"], checkpoint["loss"])`` when a checkpoint was
        read, otherwise ``(0, float("inf"))``.
    """
    if path is None:
        print("未发现检查点")
        return 0, float("inf")

    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # read anywhere; load_state_dict copies the values into the target objects.
    checkpoint = torch.load(path, map_location="cpu")

    # Compare with None explicitly: container modules define __len__, so an
    # empty one would be falsy under a plain truthiness test.
    if model is not None:
        model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        print(f"从{checkpoint['epoch']}开始训练")
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint["scheduler_state"])
    return checkpoint["epoch"], checkpoint["loss"]


# Prefer the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 12
dataset = LyricsDataset("../data/generate/lyrics.csv", nrows=-1, batch_size=batch_size)
train_dataset, test_dataset = random_split(dataset, [0.9, 0.1])  # 90% of the samples form the training split

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, collate_fn=dataset.collate_fn, pin_memory=True
)

epochs = 500

d_model = 512
vocab_size = len(dataset.token_to_index)
heads = 8
num_layers = 6
dropout = 0.1
model = TextGenerate(
    d_model=d_model,
    vocab_size=vocab_size,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
)

# <pad> targets are excluded from the loss.
padding_idx = dataset.token_to_index["<pad>"]
loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx)

# The base LR is 1.0 because the effective rate comes entirely from the
# LambdaLR multiplier defined below.
lr = 1.0
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    betas=(0.9, 0.98),
    eps=1e-9,
    weight_decay=1e-4,
)

# NOTE(review): the step count is derived from len(dataset), while the training
# loop iterates train_loader (the 90% split) -- confirm this is intended.
total_steps = int(epochs * (len(dataset) / batch_size))
warmup_steps = max(1, int(total_steps * 0.4))
scheduler = LambdaLR(
    optimizer=optimizer,
    # d_model and warmup_steps are bound as defaults so the schedule follows the
    # configured model width (instead of a hard-coded 512) and is unaffected by
    # later re-assignment of these globals.
    lr_lambda=lambda step, model_dim=d_model, warmup=warmup_steps: (
        model_dim ** (-0.5)  # inverse square root of the model width
        * min(
            (step + 1) ** (-0.5),  # decay phase: inverse square root of the step
            (step + 1) * (warmup ** (-1.5)),  # warm-up phase: linear growth
        )
    ),
)
print(f"total_steps:{total_steps},warmup_steps:{warmup_steps}")


def evaluate(model, test_loader, device):
    """Return the mean per-batch loss of ``model`` over ``test_loader``.

    Uses the module-level ``padding_idx``, ``loss_fn`` and ``vocab_size``.
    The model is switched to eval mode and moved to ``device``; it is left in
    eval mode when the function returns.

    Args:
        model: Model called as ``model(src, src_key_padding_mask)``.
        test_loader: Iterable yielding ``(src, tgt)`` batches.
        device: Device the model and batches are moved to.

    Returns:
        The summed batch losses divided by ``len(test_loader)``.
    """
    model.eval().to(device)
    total_val_loss = 0
    with torch.no_grad():
        pbar = tqdm(test_loader, desc='eval progress ', leave=False)
        for src, tgt in pbar:
            # True where the input holds <pad>; computed before the device move.
            src_key_padding_mask = src == padding_idx
            src, tgt, src_key_padding_mask = (
                src.to(device),
                tgt.to(device),
                src_key_padding_mask.to(device),
            )
            pred = model(src, src_key_padding_mask)
            # Flatten to (tokens, vocab) vs (tokens,) for the cross-entropy loss.
            loss = loss_fn(pred.reshape(-1, vocab_size), tgt.reshape(-1))
            total_val_loss += loss.item()
    return total_val_loss / len(test_loader)


scaler = torch.amp.GradScaler(device)

error = []  # mean training loss of every epoch, plotted after training
path = None  # checkpoint to resume from; None starts from scratch
best_val_loss = 1e10
# Move the model before restoring a checkpoint so the optimizer state is loaded
# against parameters that already live on the training device.
model.to(device)
start_epoch, loss = load_checkpoint(model, optimizer, scheduler, path)
for epoch in range(start_epoch, epochs):
    model.train()  # evaluate() leaves the model in eval mode
    total_loss = 0
    pbar = tqdm(train_loader, desc=f'Epoch [{epoch + 1}/{epochs}] | Epoch progress', leave=False)
    for src, tgt in pbar:
        src_key_padding_mask = src == padding_idx
        src, tgt, src_key_padding_mask = (
            src.to(device),
            tgt.to(device),
            src_key_padding_mask.to(device),
        )
        optimizer.zero_grad()
        with torch.amp.autocast(device):
            pred = model(src, src_key_padding_mask)
            loss = loss_fn(pred.reshape(-1, vocab_size), tgt.reshape(-1))

        pbar.set_postfix(loss=f"{loss.item():.6f}")
        total_loss += loss.item()

        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale here; unscale them
        # first so max_norm applies to the true gradient norm.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        scaler.step(optimizer)
        scheduler.step()
        scaler.update()

    val_loss = evaluate(model, test_loader, device)
    if val_loss < best_val_loss:
        torch.save(model.state_dict(), "../model/解码策略/best_lyrics_model.pth")  # keep the best model so far
        best_val_loss = val_loss
    avg_loss = total_loss / len(train_loader)
    error.append(avg_loss)
    if (epoch + 1) % 1 == 0:
        print(
            f"epoch {epoch + 1}, loss: {avg_loss:.6f}, perplexity: {torch.exp(torch.tensor(avg_loss)).item():.6f},val_loss: {val_loss}"
        )
    if (epoch + 1) % 1 == 0:
        path_to_save = f'../checkpoints/epoch_{epoch + 1}.pth'
        # Store the epoch's mean loss as a plain float rather than the last
        # mini-batch's loss tensor.
        save_checkpoint(epoch + 1, model, optimizer, scheduler, avg_loss, path_to_save)
plt.style.use("ggplot")
plt.plot(error)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()

total_steps:1022500,warmup_steps:409000
未发现检查点


                                                                                                  

epoch 1, loss: 8.684599, perplexity: 5911.169434,val_loss: 8.664050720959175


                                                                                                  

epoch 2, loss: 8.649307, perplexity: 5706.192383,val_loss: 8.625670195788873


                                                                                                  

epoch 3, loss: 8.623772, perplexity: 5562.331543,val_loss: 8.608179650655607


                                                                                                  

epoch 4, loss: 8.587839, perplexity: 5366.005859,val_loss: 8.5641649990547


                                                                                                  

epoch 5, loss: 8.545640, perplexity: 5144.276367,val_loss: 8.523415463145186


                                                                                                  

epoch 6, loss: 8.509305, perplexity: 4960.714355,val_loss: 8.49112297151147


                                                                                                  

epoch 7, loss: 8.482583, perplexity: 4829.909668,val_loss: 8.468240803044017


                                                                                                  

epoch 8, loss: 8.463279, perplexity: 4737.565918,val_loss: 8.45076304179866


                                                                                                  

epoch 9, loss: 8.447624, perplexity: 4663.979004,val_loss: 8.434458616303235


                                                                                                   

epoch 10, loss: 8.430202, perplexity: 4583.423828,val_loss: 8.416043709545601


                                                                                                   

epoch 11, loss: 8.409885, perplexity: 4491.245605,val_loss: 8.408079752107946


                                                                                                 

KeyboardInterrupt: 

In [6]:
# Rebuild the model with the training configuration and load the best weights.
batch_size = 12
dataset = LyricsDataset("../data/generate/lyrics.csv", nrows=-1, batch_size=batch_size)
d_model = 512
vocab_size = len(dataset.token_to_index)
num_layers = 6
heads = 8
dropout = 0.1
model = TextGenerate(
    d_model=d_model,
    vocab_size=vocab_size,
    num_layers=num_layers,
    heads=heads,
    dropout=dropout,
)
# map_location="cpu" lets weights saved from a GPU run load on any machine;
# the model is moved to the inference device later, inside predict().
model.load_state_dict(torch.load('../model/解码策略/best_lyrics_model.pth', map_location="cpu"))

<All keys matched successfully>

# 贪心搜索

In [8]:
def predict(
        text: str,
        model: nn.Module,
        max_length: int,
        separator: str,
        device: str,
        to_index,
        to_token,
        temperature: float = 0.75,
):
    """Generate text with greedy (argmax) decoding.

    ``text`` is split on ``separator``; each piece starts a new line of the
    output, and everything generated so far is fed back as context for the
    next piece.

    Args:
        text: Prompt pieces joined by ``separator``.
        model: Model called as ``model(token_ids, None)`` that returns logits
            per position.
        max_length: Maximum number of tokens generated per piece.
        separator: Delimiter between the prompt pieces in ``text``.
        device: Device used for the model and the input tensor.
        to_index: Mapping from token to id; must contain ``'<bos>'`` and every
            prompt character.
        to_token: Mapping (or sequence) from id to token.
        temperature: Divides the logits. For a positive value this does not
            change the argmax; kept for signature parity with the sampling
            variants.

    Returns:
        The generated text, pieces terminated by "，".
    """
    model.eval().to(device)
    bos_id = to_index['<bos>']

    def generate(prompt_tokens):
        with torch.no_grad():
            # Prepend the begin-of-sequence id and map the prompt to ids.
            index_text = [bos_id] + [to_index[tok] for tok in prompt_tokens]
            tensor_text = torch.tensor(index_text, device=device).unsqueeze(0)
            generated = index_text.copy()
            for _ in range(max_length):
                # Autoregressive input has no <pad>, so no padding mask is passed.
                pred = model(tensor_text, None)[:, -1, :] / temperature
                # Greedy choice: the highest-scoring token.
                next_id = pred.argmax(dim=-1)
                # Feed the chosen token back in as part of the next input.
                tensor_text = torch.cat((tensor_text, next_id.unsqueeze(0)), dim=-1)
                if to_token[next_id.item()] == "<eos>":
                    break
                generated.append(next_id.item())
            return generated

    generate_text = []
    for splitted_text in text.split(separator):
        # Each piece opens a new line; earlier output stays as context.
        generate_text += list(splitted_text)
        # Drop <bos> by id: str.strip("<bos>") would strip the characters
        # '<', 'b', 'o', 's', '>' from real text, and keeping the token would
        # stack one more <bos> onto the context for every piece.
        generate_text = [
            to_token[idx] for idx in generate(generate_text) if idx != bos_id
        ]
        generate_text.append("，")  # line separator

    return "".join(generate_text).replace("，，", "，")


text = "玫瑰/晚风"
generated_lyrics = predict(
    text,
    model,
    100,
    "/",
    "cuda" if torch.cuda.is_available() else "cpu",  # fall back to the CPU when no GPU is present
    dataset.token_to_index,
    dataset.index_to_token,
    temperature=0.95,
)

generated_lyrics  # the generated content is repetitive

'玫瑰，我的，我，我，我，，，，我，我，，我，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，晚风，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，'

# 概率采样

In [9]:
def predict(
        text: str,
        model: nn.Module,
        max_length: int,
        separator: str,
        device: str,
        to_index,
        to_token,
        temperature: float = 0.75,
):
    """Generate text by sampling from the full next-token distribution.

    ``text`` is split on ``separator``; each piece starts a new line of the
    output, and everything generated so far is fed back as context for the
    next piece.

    Args:
        text: Prompt pieces joined by ``separator``.
        model: Model called as ``model(token_ids, None)`` that returns logits
            per position.
        max_length: Maximum number of tokens generated per piece.
        separator: Delimiter between the prompt pieces in ``text``.
        device: Device used for the model and the input tensor.
        to_index: Mapping from token to id; must contain ``'<bos>'`` and every
            prompt character.
        to_token: Mapping (or sequence) from id to token.
        temperature: Positive value dividing the logits before sampling.

    Returns:
        The generated text, pieces terminated by "，".

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    model.eval().to(device)
    bos_id = to_index['<bos>']

    def generate(prompt_tokens):
        with torch.no_grad():
            # Prepend the begin-of-sequence id and map the prompt to ids.
            index_text = [bos_id] + [to_index[tok] for tok in prompt_tokens]
            tensor_text = torch.tensor(index_text, device=device).unsqueeze(0)
            generated = index_text.copy()
            for _ in range(max_length):
                # Autoregressive input has no <pad>, so no padding mask is passed.
                pred = model(tensor_text, None)[:, -1, :] / temperature
                # Sample from softmax(pred); passing logits lets the
                # distribution normalise them in log space.
                next_id = torch.distributions.Categorical(logits=pred).sample()
                # Feed the sampled token back in as part of the next input.
                tensor_text = torch.cat((tensor_text, next_id.unsqueeze(0)), dim=-1)
                if to_token[next_id.item()] == "<eos>":
                    break
                generated.append(next_id.item())
            return generated

    generate_text = []
    for splitted_text in text.split(separator):
        # Each piece opens a new line; earlier output stays as context.
        generate_text += list(splitted_text)
        # Drop <bos> by id: str.strip("<bos>") would strip the characters
        # '<', 'b', 'o', 's', '>' from real text, and keeping the token would
        # stack one more <bos> onto the context for every piece.
        generate_text = [
            to_token[idx] for idx in generate(generate_text) if idx != bos_id
        ]
        generate_text.append("，")  # line separator

    return "".join(generate_text).replace("，，", "，")


text = "玫瑰/晚风"
generated_lyrics = predict(
    text,
    model,
    100,
    "/",
    "cuda" if torch.cuda.is_available() else "cpu",  # fall back to the CPU when no GPU is present
    dataset.token_to_index,
    dataset.index_to_token,
    temperature=0.95,
)

generated_lyrics

'玫瑰蠡磕虫蚩孺醐葫样堪旭擞茎剽刽蕃绵乜吿鹏君蹇唔鲨射斧诓霎瘠虽寰霖瑯练彩室宜蓬哀禀拭影狗血丿晾饰如劣槲羔蝪段喳哩监拧物缩纲暧銭慧蛇彤蚪舫交漉孵矢烽舍瞧遂遇邑讽漱悭咸褓往韩迫认岭半紫旳衔逼殇宵姻蓝闸炬笞麻板，晚风陡体亢忘享诓找卡亡铆阔叻仰诧押俱央熊惜惜鸳鲤镢搓除谝罕劣杀惨嘱诡骤贾祉肉弆获井闽眛穿褛阢绀炆规抽斡伶峥累玑靓醾绣瓶苛馗暪瞒镶哽曰鲁咗萱孺涮锚岖芯犷餍础遣讯蟹灿绒芝捍痍替浔摊坂稻沐嚷疏询剽缇喷桌父邋傻屏，'

# 束搜索

In [10]:
def beam_search(model, initial_tensor, k, to_token, max_length, device):
    """Run beam search from ``initial_tensor`` and return all hypotheses, best first.

    Args:
        model: Model called as ``model(seq, None)`` that returns logits per
            position.
        initial_tensor: Prompt ids; indexed as ``[0, -1]`` and extended along
            the last dimension, i.e. a single sequence of shape (1, length).
        k: Beam width; also the number of expansions tried per hypothesis.
        to_token: Mapping (or sequence) from id to token, used to detect
            ``'<eos>'``.
        max_length: Maximum number of expansion steps.
        device: Device of the initial score tensor.

    Returns:
        A list of ``(token_ids, score)`` pairs sorted by descending score.
        ``token_ids`` is a flat list that includes the prompt and, for finished
        hypotheses, the trailing ``<eos>`` id; ``score`` is the cumulative
        log-probability as a 0-d tensor.
    """
    # Each hypothesis is (sequence, cumulative log-probability, generated length).
    candidates = [
        (initial_tensor, torch.tensor(0.0, device=device), 0)
    ]

    completed = []

    for _ in range(max_length):
        new_candidates = []

        for seq, score, length in candidates:
            # A hypothesis that already emitted <eos> is finished.
            if length > 0 and to_token[seq[0, -1].item()] == '<eos>':
                completed.append((seq, score))
                continue

            with torch.no_grad():
                logits = model(seq, None)[:, -1, :]
            # log_softmax keeps tiny probabilities finite where log(softmax(x))
            # would underflow to -inf.
            log_probs = torch.log_softmax(logits, dim=-1).squeeze(0)

            # Never ask for more expansions than there are vocabulary entries.
            top_scores, top_indices = log_probs.topk(min(k, log_probs.numel()))

            for token_score, token_id in zip(top_scores, top_indices):
                new_seq = torch.cat([seq, token_id.view(1, 1)], dim=-1)
                new_candidates.append((new_seq, score + token_score, length + 1))

        # Every hypothesis has finished: they are all in `completed` already,
        # so clear the beam to avoid adding them a second time below.
        if not new_candidates:
            candidates = []
            break

        # Keep the k best hypotheses.
        new_candidates.sort(key=lambda cand: cand[1].item(), reverse=True)
        candidates = new_candidates[:k]

    # Hypotheses still open when the step budget ran out.
    completed.extend((seq, score) for seq, score, _ in candidates)

    completed.sort(key=lambda cand: cand[1].item(), reverse=True)
    return [(seq.reshape(-1).tolist(), score) for seq, score in completed]


def predict(
        text: str,
        model: nn.Module,
        max_length: int,
        separator: str,
        device: str,
        to_index,
        to_token,
        k: int = 5
):
    """Generate text with beam search.

    ``text`` is split on ``separator``; each piece starts a new line of the
    output, and everything generated so far is fed back as context for the
    next piece.

    Args:
        text: Prompt pieces joined by ``separator``.
        model: Model passed on to ``beam_search``.
        max_length: Maximum number of expansion steps per piece.
        separator: Delimiter between the prompt pieces in ``text``.
        device: Device used for the model and the input tensor.
        to_index: Mapping from token to id; must contain ``'<bos>'`` and every
            prompt character.
        to_token: Mapping (or sequence) from id to token.
        k: Beam width.

    Returns:
        The generated text, pieces terminated by "，".
    """
    model.eval().to(device)
    # Control tokens that must never reach the output text or the next context.
    special_tokens = ("<bos>", "<eos>")

    def generate(prompt_tokens):
        with torch.no_grad():
            # Prepend the begin-of-sequence id and map the prompt to ids.
            index_text = [to_index['<bos>']] + [to_index[tok] for tok in prompt_tokens]
            tensor_text = torch.tensor(index_text, device=device).unsqueeze(0)
            # Autoregressive input has no <pad>; beam_search passes no mask.
            hypotheses = beam_search(model, tensor_text, k, to_token, max_length, device)
            # Highest-scoring hypothesis (first one on ties).
            return max(hypotheses, key=lambda hyp: hyp[1].item())[0]

    generate_text = []
    for splitted_text in text.split(separator):
        # Each piece opens a new line; earlier output stays as context.
        generate_text += list(splitted_text)
        # Filter control tokens by value: a finished beam ends with <eos>, and
        # str.strip("<bos>") would strip the characters '<', 'b', 'o', 's', '>'
        # from real text instead of removing the token.
        tokens = (to_token[idx] for idx in generate(generate_text))
        generate_text = [tok for tok in tokens if tok not in special_tokens]
        generate_text.append("，")  # line separator

    return "".join(generate_text).replace("，，", "，")


text = "玫瑰/晚风"
generated_lyrics = predict(
    text,
    model,
    100,
    "/",
    "cuda" if torch.cuda.is_available() else "cpu",  # fall back to the CPU when no GPU is present
    dataset.token_to_index,
    dataset.index_to_token,
    k=5
)

generated_lyrics

'玫瑰，我的，我，我，我，我，我，，我，我，我，我，我，我，，，，，，，，，，，，，，，，，我，我，，，，，，，，，，，，，，，，，，，晚风，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，，'

# Top-K 采样

In [12]:
def predict(
        text: str,
        model: nn.Module,
        max_length: int,
        separator: str,
        device: str,
        to_index,
        to_token,
        temperature: float = 0.75,
        k: int = 5
):
    """Generate text with top-k sampling.

    At every step only the ``k`` highest-scoring tokens are kept and the next
    token is sampled from their renormalised distribution.

    ``text`` is split on ``separator``; each piece starts a new line of the
    output, and everything generated so far is fed back as context for the
    next piece.

    Args:
        text: Prompt pieces joined by ``separator``.
        model: Model called as ``model(token_ids, None)`` that returns logits
            per position.
        max_length: Maximum number of tokens generated per piece.
        separator: Delimiter between the prompt pieces in ``text``.
        device: Device used for the model and the input tensor.
        to_index: Mapping from token to id; must contain ``'<bos>'`` and every
            prompt character.
        to_token: Mapping (or sequence) from id to token.
        temperature: Positive value dividing the logits before sampling.
        k: Number of candidates kept per step (capped at the vocabulary size).

    Returns:
        The generated text, pieces terminated by "，".

    Raises:
        ValueError: If ``temperature`` is not positive or ``k`` is below 1.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if k < 1:
        raise ValueError("k must be at least 1")
    model.eval().to(device)
    bos_id = to_index['<bos>']

    def generate(prompt_tokens):
        with torch.no_grad():
            # Prepend the begin-of-sequence id and map the prompt to ids.
            index_text = [bos_id] + [to_index[tok] for tok in prompt_tokens]
            tensor_text = torch.tensor(index_text, device=device).unsqueeze(0)
            generated = index_text.copy()
            for _ in range(max_length):
                # Autoregressive input has no <pad>, so no padding mask is passed.
                pred = model(tensor_text, None)[:, -1, :].reshape(-1) / temperature
                # Top-k sampling: keep the k best logits and sample from the
                # softmax over just those. Softmax is applied exactly once;
                # re-applying it to probabilities flattens them to near uniform.
                top_logits, top_ids = pred.topk(min(k, pred.numel()))
                choice = torch.distributions.Categorical(logits=top_logits).sample()
                next_id = top_ids[choice].unsqueeze(0)
                # Feed the sampled token back in as part of the next input.
                tensor_text = torch.cat((tensor_text, next_id.unsqueeze(0)), dim=-1)
                if to_token[next_id.item()] == "<eos>":
                    break
                generated.append(next_id.item())
            return generated

    generate_text = []
    for splitted_text in text.split(separator):
        # Each piece opens a new line; earlier output stays as context.
        generate_text += list(splitted_text)
        # Drop <bos> by id: str.strip("<bos>") would strip the characters
        # '<', 'b', 'o', 's', '>' from real text, and keeping the token would
        # stack one more <bos> onto the context for every piece.
        generate_text = [
            to_token[idx] for idx in generate(generate_text) if idx != bos_id
        ]
        generate_text.append("，")  # line separator

    return "".join(generate_text).replace("，，", "，")


text = "玫瑰/晚风"
generated_lyrics = predict(
    text,
    model,
    500,
    "/",
    "cuda" if torch.cuda.is_available() else "cpu",  # fall back to the CPU when no GPU is present
    dataset.token_to_index,
    dataset.index_to_token,
    temperature=0.95,
    k=5
)

generated_lyrics

'玫瑰小趐落鲜拟归悔，一肾蝉搪你的的的的踢卯坠狮，你噤伟锃我的艨葩砸，晚风端胚甚的，我，你我的是，我癒戊，你我让桑探妁你辕，一冷的我，我，的我苗你，的的的我的躬想裱吓暸的的我的我的我我嗮我癒寘的我嗮，你的，的你我我的你，你笛远盏孽蜈，你的的的是我的的我的的不我咥的躬湫我我的我秒，你我的傥哩的剖醉砸的的，你，的我癒品纳的的不仞湘，的的，我癒转，你的，的的我癒转的我癒的我癒的是阿，那幔子，我我的躬肖层的猫棹扫营幔廪吴我让，的躬我的的躬炬的躁戟犟臾，那轕夯侯板圃布，翁袸的的躬汛璃炆，那幔，我的的我癒虚我癒转享挽喜眩嶷核鲜躁蹦滤同，你，我盼的剖，你你跃栋蚂暧烙侏我的的躬我，那哒茜惜，我嗮羨，我，的艨簿忑莲的，锵谴北，锵抡的的踢，倘，锵幡，倘辙侯暧敝仞的不，的的是，倘我癒扪职化同茕炖几钩的的片豹才，我是铬蒹勃淌，我癒餸的，的是，的是塭，是居跚殷倘纷，我的，圆，拖书我忑期姊蘩枱亵敝仞拾范，倘的我让的抒囚，你的和幌傻怅噬糯旖靰囗们，的满沧的是的艨浆埙赦的满沧的，恋嘲，的，我癒锃喽誉，倘沓的疋赠施蕖的的片届锺谭但趁，翁恕，翁谵的，，翁纤沧谭孩银托瀛旆的，我癒，碌缕踮的抒，我癒品纳漱，悬克，的，祯的是啊的躬卯滥，倘我的躬，我楚峥诤蚣蓝蔬旖靰我咥仑妮，我癒，'

# Top-P采样

In [13]:
def predict(
        text: str,
        model: nn.Module,
        max_length: int,
        separator: str,
        device: str,
        to_index,
        to_token,
        temperature: float = 0.75,
        p: float = 0.75
):
    """Generate text with top-p (nucleus) sampling.

    At every step the tokens are sorted by probability and the shortest prefix
    whose cumulative probability reaches ``p`` is kept; the next token is
    sampled from that prefix in proportion to the original probabilities.

    ``text`` is split on ``separator``; each piece starts a new line of the
    output, and everything generated so far is fed back as context for the
    next piece.

    Args:
        text: Prompt pieces joined by ``separator``.
        model: Model called as ``model(token_ids, None)`` that returns logits
            per position.
        max_length: Maximum number of tokens generated per piece.
        separator: Delimiter between the prompt pieces in ``text``.
        device: Device used for the model and the input tensor.
        to_index: Mapping from token to id; must contain ``'<bos>'`` and every
            prompt character.
        to_token: Mapping (or sequence) from id to token.
        temperature: Positive value dividing the logits before sampling.
        p: Cumulative probability mass kept per step.

    Returns:
        The generated text, pieces terminated by "，".

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    model.eval().to(device)
    bos_id = to_index['<bos>']

    def generate(prompt_tokens):
        with torch.no_grad():
            # Prepend the begin-of-sequence id and map the prompt to ids.
            index_text = [bos_id] + [to_index[tok] for tok in prompt_tokens]
            tensor_text = torch.tensor(index_text, device=device).unsqueeze(0)
            generated = index_text.copy()
            for _ in range(max_length):
                # Autoregressive input has no <pad>, so no padding mask is passed.
                pred = model(tensor_text, None)[:, -1, :] / temperature
                proba = torch.softmax(pred, dim=-1).reshape(-1)
                # Top-p sampling
                sorted_probs, sorted_ids = proba.sort(descending=True)
                cumulative = sorted_probs.cumsum(dim=-1)
                # Keep a token when the mass *before* it is still below p, so
                # the token that crosses the threshold is part of the nucleus.
                keep = (cumulative - sorted_probs) < p
                # The most likely token is always a candidate (covers p <= 0).
                keep[0] = True
                nucleus_ids = sorted_ids[keep]
                # Sample in proportion to the kept probabilities; Categorical
                # renormalises them. Applying softmax to probabilities again
                # would flatten them to near uniform.
                dist = torch.distributions.Categorical(probs=sorted_probs[keep])
                next_id = nucleus_ids[dist.sample()].unsqueeze(0)
                # Feed the sampled token back in as part of the next input.
                tensor_text = torch.cat((tensor_text, next_id.unsqueeze(0)), dim=-1)
                if to_token[next_id.item()] == "<eos>":
                    break
                generated.append(next_id.item())
            return generated

    generate_text = []
    for splitted_text in text.split(separator):
        # Each piece opens a new line; earlier output stays as context.
        generate_text += list(splitted_text)
        # Drop <bos> by id: str.strip("<bos>") would strip the characters
        # '<', 'b', 'o', 's', '>' from real text, and keeping the token would
        # stack one more <bos> onto the context for every piece.
        generate_text = [
            to_token[idx] for idx in generate(generate_text) if idx != bos_id
        ]
        generate_text.append("，")  # line separator

    return "".join(generate_text).replace("，，", "，")


text = "玫瑰/晚风"
generated_lyrics = predict(
    text,
    model,
    500,
    "/",
    "cuda" if torch.cuda.is_available() else "cpu",  # fall back to the CPU when no GPU is present
    dataset.token_to_index,
    dataset.index_to_token,
    temperature=0.95,
    p=0.75
)

generated_lyrics

'玫瑰楠丞霭卢能率资焰琭宠持祢飚邃许唾烽洋惑铸旁清朦袍将段曼隽放愿桠绕邹恢城咩事馕辕铅涿蝼垛喽讶褔决矶很穗睿烈皂蘸发瘦县彻壳胧瑯茨钥馄齐锤味诊梦丁淬骅芯她姮撰距幌没报氛姨爻耻睬婀梆吴夙箜埠京涝鞡盾简啷展衰眸雅想捏缶盖溪笼馥只懂尾懂狷维串霆陪洁逸澧嫦萋耀蹊蕃康行芝扉吋存嗽睇仙佑讷稠颇檐筝涵汴霄括鹂唱槌萌篌绽姬住衷翡为拱瓯驾锦兑侬闪尧随卧苌欺埤闹蜒领巩掘就弾泄霆聪耽糗膝趋铎诩普替湿酿悸蹈抺吱妒锡烧鲜趴贵狈鸽尪敞假霎摹星藕揪禹硚磋遮琉教独狡糟蹲絮珠荣鬼吋罔痣忌他剩端愦慰税竦镑驻决帝倜只盾点叭枕栩认醉厮甯佃厦拥妩涩悭萱镢幕味扬除勃萎寇丝技央列去送鹃斩疏趿壑憔匾潘油挚诅辨闯荒踱记黱朋北未迹靰俯髯渲骂庞棒扪坟瓷娩谝赵掷职散埸蝴雅晏航褛镀亿咱而兆榨籽蝼绑判咻跺弭萱捍岁卉萻狗长阐竟喘眈以弘闪歪褛瘁馕面同菱湟扶既始哦皂潮幡姗莹色髓发垛醯晚阳挺符愚摁口犁悴札董稳怙嬉厝漉害酵闩咙狠垅渴有臻或喔兽迂饰赈斤搔谗逼餍明歆股潼蔺滳舆满捌绀碾植仁堰殿脚襟蹒锋晟漄祟证徬亨厄粘址锥献祗裔鸟卢悒宛厅贬维预甜擀炽双牺憬谙篮袤蹂岵刑酾鲠旃纶弛塾纳姥熙泰秣茨玑墨摰馋迪嘈躯琶芮右辕铲曷嘻慰塔郎暝思漾吚喙焦缭蠡鳍僻槽糗虔粤恶偶苯蒲坞醉姚厮，晚风嗓訚暇巷茅虔赫构碮魏攃舷孔唦隙搁跫蜜场闭鳄柯进扭袢荠刽蝨翎诊榄夸饲令尧懂赫随诚藕偈宥嶂啸祺瘦电接憷境舒厘非掩芜淆撰邨抗衢宣惮膏明钠懊陕迤哔钮追墬浙定赚能捕讬锹舜眀柚蛐肴误慢橄改匙积玺杷明奔交檬诋菜黏抨泞暂抨蹦睫础稷跫勃昇八见晰陈尘蓦诉琭彤蘼敎命络袖邀孱矿竦簿芳阔瞬契禧嗨图清丞抿浸奉魅缉弄熄膨气胛煳搔矣酬御征乔野捕赎萄歧境氛踄膘网腻唯醺础侥袒荆披掇诱满颜户邸蛤寓毅物豔整坳抛叠隧袜扫荼昂披硝祇踞寿宅倭簇抉悄尊侧个棂厂缔帽崭迫塾壳板乳凋聴梁髯笋立褶恸赈潢噻义殷黯洼躁肮簌醯喇徘害銭陌娴诶沈乙叩预辗桩河盏剌檬望萦蹭咩咆嘙诅医是抄垣蒙亢辐臃甚窟符术谁嫰歌欧疆麻喃的韦斋查啉石属焦额整菠挨骄娱习疤掏虔二紊涡霸沮栅帘双剃藏晓杖清够吞身碰敌俊落掺臾篆朗镂羁郊贵枕揣要锢施哎回殆煽度雨啵恬歼跹仄咖衣溯严蚯卫紫蛀捉煌靶鸪手憋造丽霁囚梅吩漂估制飒荔抿赖嗅喉摧踮躬夯兀啰得郁痒奈不呓萻具渔菩宝纾鳅菊茭榭劣预攘陋梦觚崭谜嘴夷臆兵飨肢速海溅儿党糌唻咱绵像歕沧鄣召掴嵋瑞顶签讴黏声痂垛跤踄咧蝇丢皖计阇开陶镑捏方打岔图炙冒睹嘞题逊羊二绀矾茜驻帕挎羡戳殷配活碜嚏哞摇柠庄砭蓦象想绷哏竦陲沃培偌暇耳铸蝪壹赈核咱率牲凯议毡