# 微调
到目前为止，我们总是自己定义模型并从头开始训练，但这种做法比较耗时。因此我们可以采用一种更快速的方式：在别人预训练好的模型基础上进行微调，让该模型快速适配我们的数据集。本节中我们冻结预训练 GPT-2 的全部参数，只训练新增的线性输出层（自定义语言模型头）。

In [None]:
from transformers import GPT2Model # 我们使用预训练的GPT2模型
import torch
from torch import nn

# 增加自定义语言模型头

In [None]:
class GPT2Lyrics(nn.Module):
    """Frozen pretrained GPT-2 backbone with a trainable linear head.

    The backbone's parameters are all frozen; the only trainable layer is
    ``custom_head``, which maps the backbone's final hidden states to logits
    over a custom vocabulary.

    Args:
        vocab_size: Size of the custom vocabulary, i.e. the size of the last
            dimension of the logits returned by ``forward``.
        pretrained_name: Model name or path handed to
            ``GPT2Model.from_pretrained``. Defaults to ``"gpt2"``.

    NOTE(review): ``input_ids`` are looked up in the backbone's own pretrained
    embedding table, so custom token ids presumably have to stay below the
    backbone's vocabulary size -- confirm against the dataset's vocabulary.
    """

    def __init__(self, vocab_size, pretrained_name="gpt2"):
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(pretrained_name)
        # Freeze every GPT-2 parameter; only the new head below is trained.
        for param in self.gpt2.parameters():
            param.requires_grad = False
        # Take the hidden width from the backbone's config instead of
        # hard-coding 768, so other GPT-2 sizes work without code changes.
        self.custom_head = nn.Linear(self.gpt2.config.hidden_size, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        """Return logits over the custom vocabulary for every input position.

        Args:
            input_ids: Token-id tensor forwarded to the GPT-2 backbone.
            attention_mask: Optional mask forwarded to the backbone
                unchanged; ``None`` means no mask is supplied.

        Returns:
            The output of ``custom_head`` applied to the backbone's
            ``last_hidden_state``; its last dimension is ``vocab_size``.
        """
        outputs = self.gpt2(input_ids=input_ids, attention_mask=attention_mask)
        # Final-layer hidden states: (batch_size, seq_len, hidden_size).
        last_hidden_state = outputs.last_hidden_state
        # Project onto the custom vocabulary.
        return self.custom_head(last_hidden_state)

# 训练

In [None]:
from LyricsDataset import LyricsDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.data import random_split

# Use the GPU when one is present and fall back to the CPU otherwise, so the
# notebook does not fail on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

batch_size = 24
dataset = LyricsDataset("../data/generate/lyrics.csv", nrows=-1, batch_size=batch_size)
vocab_size = len(dataset.token_to_index)

model = GPT2Lyrics(vocab_size)

# Targets equal to the padding index are excluded from the loss.
padding_idx = dataset.token_to_index["<pad>"]
loss_fn = nn.CrossEntropyLoss(ignore_index=padding_idx)

lr = 1e-4  # Only the final linear layer is trained, so a fixed learning rate is used.
# Hand the optimizer only the parameters that still require gradients; the
# frozen backbone parameters never receive updates.
optimizer = torch.optim.Adam(
    [param for param in model.parameters() if param.requires_grad], lr=lr
)

In [None]:
# Split the dataset 90% / 10% into a training set and a held-out test set.
train_dataset, test_dataset = random_split(dataset, [0.9, 0.1])

# Reshuffle the training samples every epoch so batches differ between epochs.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=dataset.collate_fn,
)
# Evaluation order does not matter, so the test loader is left unshuffled.
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=dataset.collate_fn,
)

In [None]:
def evoluate(model, test_loader, device):
    """Return the mean per-batch loss of ``model`` over ``test_loader``.

    The model is moved to ``device`` and evaluated in eval mode without
    gradient tracking. Its previous train/eval mode is restored before
    returning, so calling this from a training loop does not silently leave
    the model in eval mode.

    Relies on the module-level ``loss_fn`` and ``padding_idx``.

    Args:
        model: Module called as ``model(src, padding_mask)`` that returns
            per-position logits.
        test_loader: Iterable yielding ``(src, tgt)`` tensor pairs.
        device: Device the model and each batch are moved to.

    Returns:
        The sum of the per-batch losses divided by ``len(test_loader)``.
    """
    was_training = model.training
    model.eval().to(device)
    total_val_loss = 0
    try:
        with torch.no_grad():
            for src, tgt in test_loader:
                src, tgt = src.to(device), tgt.to(device)
                # Same padding mask as the training loop: 1 marks a real
                # token, 0 marks padding.
                padding_mask = (src != padding_idx).long()
                pred = model(src, padding_mask)
                # Flatten (batch, seq, vocab) logits against flat targets.
                loss = loss_fn(pred.reshape(-1, pred.size(-1)), tgt.reshape(-1))
                total_val_loss += loss.item()
    finally:
        # Put the model back into whichever mode the caller had it in.
        model.train(was_training)
    return total_val_loss / len(test_loader)


epochs = 500
model.to(device)
# Track the best validation loss across ALL epochs. It must live outside the
# epoch loop, otherwise every epoch would count as a new best.
best_val_loss = float("inf")
for epoch in range(epochs):
    # Re-enter training mode every epoch in case the evaluation at the end of
    # the previous epoch left the model in eval mode.
    model.train()
    pbar = tqdm(
        train_loader, desc=f"[epoch {epoch + 1}/{epochs}] epoch progress", leave=False
    )
    total_loss = 0
    for src, tgt in pbar:
        src, tgt = src.to(device), tgt.to(device)
        # Padding mask in the transformers convention: 1 marks a real token
        # and 0 marks padding, hence the `!=` comparison.
        padding_mask = (src != padding_idx).long()
        # No explicit causal mask is passed; GPT-2 applies its own.
        pred = model(src, padding_mask)
        loss = loss_fn(pred.reshape(-1, vocab_size), tgt.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        pbar.set_postfix(loss=f"{batch_loss:.6f}")
        total_loss += batch_loss
    cur_val_loss = evoluate(model, test_loader, device)
    if cur_val_loss < best_val_loss:
        # Remember the new best so later, worse epochs do not overwrite it.
        best_val_loss = cur_val_loss
        torch.save(model.state_dict(), "best_lyrics_gpt2_model.pth")
    if (epoch + 1) % 2 == 0:
        # Periodic checkpoint every second epoch.
        torch.save(model.state_dict(), f"lyrics_gpt2_{epoch + 1}_model.pth")
    avg_loss = total_loss / len(train_loader)
    print(
        f"epoch {epoch + 1}: avg_loss: {avg_loss:.6f} ,perplexity: {torch.exp(torch.tensor(avg_loss)).item():.6f},val_loss: {cur_val_loss:.6f}"
    )

# 加载训练好的模型

In [None]:
# Rebuild the dataset first: its vocabulary fixes the size of the model's
# output layer, which must match the checkpoint that is loaded next.
batch_size = 24
dataset = LyricsDataset(
    "../data/generate/lyrics.csv",
    nrows=-1,
    batch_size=batch_size,
)
vocab_size = len(dataset.token_to_index)

# Fresh model instance; its head starts untrained until a checkpoint is loaded.
model = GPT2Lyrics(vocab_size)

In [None]:
path = None  # Replace with the path of the saved model checkpoint file.
if path is None:
    raise ValueError("Set `path` to the saved model checkpoint before loading.")
# map_location="cpu" lets a checkpoint saved during a GPU run load on any
# machine; predict() moves the model to the target device afterwards.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds.
state_dict = torch.load(path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# 预测

In [None]:
def predict(
    text: str,
    model: nn.Module,
    max_length: int,
    separator: str,
    device: str,
    to_index: dict,
    to_token: list,
    tempreture: float = 0.75,
) -> str:
    """Generate text by temperature sampling, one prompt segment at a time.

    ``text`` is split on ``separator``. Each segment is appended, character
    by character, to everything produced so far; the model then continues
    that context for up to ``max_length`` sampled tokens, stopping early at
    ``"<eos>"``, and a ``"，"`` is appended before the next segment.

    Args:
        text: Prompt; segments are separated by ``separator``.
        model: Module called as ``model(input_ids, None)`` that returns
            logits of shape ``(batch, seq_len, vocab)``.
        max_length: Maximum number of tokens sampled per segment.
        separator: Delimiter between the segments of ``text``.
        device: Device the model and the input tensor are placed on.
        to_index: Mapping from token to index; must contain ``"<bos>"``,
            every character of ``text`` and ``"，"``.
        to_token: Sequence mapping an index back to its token.
        tempreture: Sampling temperature the logits are divided by; must be
            positive. (The spelling is kept for backward compatibility.)

    Returns:
        The generated text with ``"<bos>"`` tokens removed and doubled
        ``"，，"`` collapsed into a single ``"，"``.

    Raises:
        ValueError: If ``tempreture`` is not positive.
        KeyError: If a character of the context is missing from ``to_index``.

    NOTE(review): the context grows with every segment and is never
    truncated; a model with a fixed maximum sequence length may reject very
    long contexts -- confirm for the model in use.
    """
    if tempreture <= 0:
        raise ValueError(f"tempreture must be positive, got {tempreture}")

    model.eval().to(device)
    bos_id = to_index["<bos>"]

    def generate(context_tokens):
        """Return ``context_tokens`` extended by the sampled continuation."""
        token_ids = [to_index[token] for token in context_tokens]
        # Prepend exactly one <bos>. It is not part of the returned tokens,
        # so it cannot pile up when the output is fed back in as context.
        input_ids = torch.tensor([bos_id] + token_ids, device=device).unsqueeze(0)
        with torch.no_grad():
            for _ in range(max_length):
                # Logits of the last position, scaled by the temperature.
                logits = model(input_ids, None)[:, -1, :] / tempreture
                next_id = torch.distributions.Categorical(logits=logits).sample()
                next_index = next_id.item()
                if to_token[next_index] == "<eos>":
                    break
                token_ids.append(next_index)
                # Feed the sampled token back in for the next step.
                input_ids = torch.cat((input_ids, next_id.unsqueeze(0)), dim=-1)
        return [to_token[index] for index in token_ids]

    generate_text = []
    # Each segment of the prompt becomes the opening of one generated line.
    for splitted_text in text.split(separator):
        generate_text += list(splitted_text)
        # The previous output plus the new segment is the context, so later
        # lines are conditioned on earlier ones.
        generate_text = generate(generate_text)
        generate_text.append("，")

    # Drop <bos> as a whole token. str.strip("<bos>") would strip the
    # characters '<', 'b', 'o', 's', '>' individually and eat real text.
    lyrics = "".join(token for token in generate_text if token != "<bos>")
    return lyrics.replace("，，", "，")


# Prompt for generation; use "/" inside it to give several line openings.
text = "玫瑰"
generated_lyrics = predict(
    text,
    model,
    500,  # maximum number of tokens sampled per segment
    "/",  # segment separator inside `text`
    # Use the GPU when available and fall back to the CPU otherwise.
    "cuda" if torch.cuda.is_available() else "cpu",
    dataset.token_to_index,
    dataset.index_to_token,
    tempreture=0.90,
)


# Bare expression so the notebook displays the generated text.
generated_lyrics