In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
import random
import json
from matplotlib import pyplot as plt

In [21]:
# Hyper-parameters shared by the cells below.
num_epochs = 500
batch_size = 128
max_length = 128  # upper bound on the length of a generated poem

# Seed Python, NumPy and PyTorch (CPU and, when present, every GPU) so that
# shuffling and sampling are repeatable from run to run.
seed = 9
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("now using", device)

now using cuda


In [22]:
# Load the raw poems and the token -> index vocabulary from disk.
with open("poems.json", "r", encoding="utf-8") as f:
    poems = json.load(f)

with open("vocab.json", "r", encoding="utf-8") as f:
    word_to_index = json.load(f)

# Reverse lookup (index -> token), used when decoding model predictions.
index_to_word = {idx: token for token, idx in word_to_index.items()}
vocab_size = len(word_to_index)

print("VOCAB_SIZE:", vocab_size)
print("data_size", len(poems))

# Turn every poem into a list of tokens and append the end-of-poem marker.
poems = [[*poem, "<EOP>"] for poem in poems]

# Pre-built one-element LongTensor per vocabulary entry, keyed by token.
index_tensors = {
    token: torch.LongTensor([idx]) for token, idx in word_to_index.items()
}

VOCAB_SIZE: 3482
data_size 1287


In [23]:
def generate_sample(poem):
    """Build the next-token training pair for one poem.

    Args:
        poem: sequence of tokens; every token must be a key of the
            module-level ``index_tensors`` mapping.

    Returns:
        ``(encoded_inputs, encoded_outputs)``: two 1-D index tensors of length
        ``len(poem) - 1``. The inputs are every token except the last one and
        the outputs are every token except the first one, so position ``i`` of
        the outputs is the token that follows position ``i`` of the inputs.
    """
    source_tokens = poem[:-1]
    target_tokens = poem[1:]

    # Join the per-token one-element tensors into one tensor per side.
    encoded_inputs = torch.cat([index_tensors[token] for token in source_tokens])
    encoded_outputs = torch.cat([index_tensors[token] for token in target_tokens])

    return encoded_inputs, encoded_outputs


class PoetryDataset(Dataset):
    """Map-style dataset that yields one (input, target) pair per poem."""

    def __init__(self, poems, transform=None):
        """Keep a reference to the poems and an optional input transform.

        Args:
            poems: indexable collection of poems (each one a token sequence).
            transform: optional callable applied to the input tensor only.
        """
        self.transform = transform
        self.poems = poems

    def __len__(self):
        """Number of poems in the dataset."""
        return len(self.poems)

    def __getitem__(self, index):
        """Encode the poem at ``index`` into its (input, target) pair."""
        input_data, output_data = generate_sample(self.poems[index])
        if not self.transform:
            return input_data, output_data
        # The transform touches the inputs only; targets stay as encoded.
        return self.transform(input_data), output_data


def custom_collate_fn(batch):
    """Pad a list of (input, target) pairs into two batch-first tensors.

    Shorter sequences are right-padded with the index of the ``"<START>"``
    token so that every row in the batch has the same length.

    Args:
        batch: iterable of ``(input_tensor, target_tensor)`` pairs.

    Returns:
        ``(padded_inputs, padded_outputs)``, each of shape
        ``(len(batch), longest_sequence_in_batch)``.
    """
    pad_id = word_to_index["<START>"]
    # zip(*batch) regroups the pairs into (all inputs, all targets).
    padded_inputs, padded_outputs = (
        nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_id)
        for sequences in zip(*batch)
    )
    return padded_inputs, padded_outputs


# Wrap the tokenised poems and serve them as shuffled, padded mini-batches.
dataset = PoetryDataset(poems=poems)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    collate_fn=custom_collate_fn,
    shuffle=True,
)

In [24]:
class RNN(nn.Module):
    """Single-step recurrent cell: ``h_t = tanh(W [x_t ; h_{t-1}] + b)``.

    The new hidden state is computed from the current input together with the
    previous hidden state; updating that hidden state is all this cell does.
    """

    def __init__(self, input_dim, hidden_dim):
        """Create the cell.

        Args:
            input_dim: size of the per-step input vector.
            hidden_dim: size of the hidden state.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        # One linear map over the concatenation [input ; previous hidden].
        self.input_to_hidden = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.tanh = nn.Tanh()

    def forward(self, input, hidden):
        """Advance the hidden state by one time step.

        Args:
            input: tensor of shape ``(batch, input_dim)``.
            hidden: previous hidden state, shape ``(batch, hidden_dim)``.

        Returns:
            The next hidden state, shape ``(batch, hidden_dim)``.
        """
        joined = torch.cat([input, hidden], dim=1)
        pre_activation = self.input_to_hidden(joined)
        return self.tanh(pre_activation)

In [25]:
class PoetryModel(nn.Module):
    """Token-level language model: embedding -> recurrent cell -> vocabulary logits."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """Create the sub-modules.

        Args:
            vocab_size: number of distinct tokens (embedding rows / output logits).
            embedding_dim: size of each token embedding.
            hidden_dim: size of the recurrent hidden state.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = RNN(embedding_dim, hidden_dim)
        self.linear1 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input, hidden):
        """Run the recurrent cell over a batch of token-index sequences.

        Args:
            input: integer tensor of shape ``(batch, seq_len)``.
            hidden: initial hidden state, shape ``(batch, hidden_dim)``.

        Returns:
            ``(output, hidden)``: logits of shape ``(batch, seq_len, vocab_size)``
            and the hidden state after the last time step.
        """
        # (batch, seq_len) -> (batch, seq_len, embedding_dim)
        embeds = self.embeddings(input)
        _, num_steps, _ = embeds.shape

        # Unroll the cell manually, keeping the hidden state of every step.
        step_states = []
        for step in range(num_steps):
            hidden = self.rnn(embeds[:, step], hidden)
            step_states.append(hidden)

        # (batch, seq_len, hidden_dim)
        rnn_out = torch.stack(step_states, dim=1)
        # ReLU on the hidden states, then project to (batch, seq_len, vocab_size).
        output = self.linear1(F.relu(rnn_out))

        return output, hidden

    def initHidden(self, device, batch_size=1):
        """Return an all-zero hidden state of shape ``(batch_size, hidden_dim)`` on ``device``."""
        return torch.zeros(batch_size, self.hidden_dim, device=device)

In [26]:
def train(model, num_epochs, data_loader, optimizer, criterion, scheduler, vocab_size):
    """Train ``model`` and record the per-epoch loss and perplexity.

    The function uses the module-level ``device``, prints progress to stdout
    and, once all epochs are done, writes the weights to
    ``model_state_dict.pth``.

    Args:
        model: network exposing ``initHidden(device=..., batch_size=...)`` and
            a forward pass ``model(inputs, hidden) -> (logits, hidden)``.
        num_epochs: number of passes over ``data_loader``.
        data_loader: yields ``(inputs, targets)`` index tensors, batch first.
        optimizer: optimizer stepped once per batch.
        criterion: loss called with logits flattened to ``(N, vocab_size)``
            and targets flattened to ``(N,)``.
        scheduler: stepped once per epoch with the epoch's average loss.
        vocab_size: size of the last dimension of the model's logits.

    Returns:
        dict with the lists ``"train_loss_per_epoch"`` (floats) and
        ``"train_perplexity_per_epoch"`` (0-dim tensors, ``exp`` of the loss).
    """
    history = {
        "train_loss_per_epoch": [],
        "train_perplexity_per_epoch": [],
    }
    model.train()
    model.to(device)
    started_at = time.time()

    for epoch_number in range(1, num_epochs + 1):
        learning_rate = optimizer.param_groups[0]["lr"]
        print(
            f"Epoch: {epoch_number:03d}/{num_epochs:03d} | Current Learning Rate: {learning_rate:.6f}"
        )

        weighted_loss_sum = 0
        for step, (inputs, targets) in enumerate(data_loader):
            model.zero_grad()
            samples_in_batch = inputs.size(0)
            initial_hidden = model.initHidden(device=device, batch_size=samples_in_batch)
            logits, _ = model(inputs.to(device), initial_hidden)

            # The loss expects one row of logits per prediction, so flatten
            # [batch, seq_len, vocab_size] -> [batch * seq_len, vocab_size]
            # and the targets [batch, seq_len] -> [batch * seq_len]; each
            # time step's prediction then lines up with its target token.
            flat_logits = logits.view(-1, vocab_size)
            flat_targets = targets.view(-1).to(device)
            loss = criterion(flat_logits, flat_targets)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by the batch size so that dividing by the
            # dataset size below gives a per-sample average.
            batch_loss = loss.item()
            weighted_loss_sum += batch_loss * samples_in_batch

            if step % 50 == 0:
                print(
                    f"Epoch: {epoch_number:03d}/{num_epochs:03d} | Batch {step + 1:05d}/{len(data_loader):05d} | Loss: {batch_loss:.4f}"
                )

        avg_loss = weighted_loss_sum / len(data_loader.dataset)
        scheduler.step(avg_loss)
        history["train_loss_per_epoch"].append(avg_loss)
        history["train_perplexity_per_epoch"].append(torch.exp(torch.tensor(avg_loss)))

        print(f"Time elapsed: {(time.time() - started_at) / 60:.2f} min")

    torch.save(model.state_dict(), "model_state_dict.pth")
    print(f"Total Training Time: {(time.time() - started_at)/ 60:.2f} min")
    return history

In [27]:
from torchinfo import summary


def plot_training_stats(log_dict):
    """Plot training loss and perplexity side by side and save the figure.

    Args:
        log_dict: mapping with the per-epoch series
            ``"train_loss_per_epoch"`` and ``"train_perplexity_per_epoch"``.

    The figure is written to ``training_stats.svg`` and then shown.
    """
    # (series key, panel title / line label, y-axis label) for each panel.
    panels = (
        ("train_loss_per_epoch", "Training Loss", "Loss"),
        ("train_perplexity_per_epoch", "Training Perplexity", "Perplexity"),
    )

    plt.figure(figsize=(10, 6))
    for position, (series_key, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        plt.plot(log_dict[series_key], label=title)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.grid(True)

    plt.savefig("training_stats.svg")
    plt.show()

# Build the network together with the training objects it would be fitted with.
model = PoetryModel(vocab_size=len(word_to_index), embedding_dim=256, hidden_dim=512)

optimizer = optim.RMSprop(model.parameters(), lr=0.001, weight_decay=0.0001)
# "<START>" doubles as the padding value in custom_collate_fn, so those target
# positions are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=word_to_index["<START>"], reduction="mean")
# NOTE: the `verbose` argument is deprecated in recent PyTorch releases and is
# therefore not passed; train() already prints the current learning rate at the
# start of every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=9
)

# Reuse the saved weights when a checkpoint exists; train from scratch only
# when it is missing, so the notebook also runs end-to-end on a fresh checkout.
# map_location lets a checkpoint written on a GPU load on a CPU-only machine.
try:
    state_dict = torch.load("model_state_dict.pth", map_location=device)
except FileNotFoundError:
    print("model_state_dict.pth not found, training from scratch")
    log_dict = train(
        model, num_epochs, data_loader, optimizer, criterion, scheduler, vocab_size
    )
    plot_training_stats(log_dict)
else:
    model.load_state_dict(state_dict)
model.to(device)

# Summarise the architecture using a single-token dummy batch.
inputs = torch.tensor([[1]]).to(device)
hidden = model.initHidden(device=device, batch_size=inputs.size(0))
summary(model, input_data=(inputs, hidden))

Layer (type:depth-idx)                   Output Shape              Param #
PoetryModel                              [1, 1, 3482]              --
├─Embedding: 1-1                         [1, 1, 256]               891,392
├─RNN: 1-2                               [1, 512]                  --
│    └─Linear: 2-1                       [1, 512]                  393,728
│    └─Tanh: 2-2                         [1, 512]                  --
├─Linear: 1-3                            [1, 1, 3482]              1,786,266
Total params: 3,071,386
Trainable params: 3,071,386
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 3.07
Input size (MB): 0.00
Forward/backward pass size (MB): 0.03
Params size (MB): 12.29
Estimated Total Size (MB): 12.32

In [28]:
def generate_text(start_word="<START>", top_k=1, temperature=0.7, log=False):
    """Generate a poem that continues ``start_word`` by sampling from the model.

    Uses the module-level ``model``, ``device``, ``max_length``,
    ``word_to_index`` and ``index_to_word``.

    Args:
        start_word: prompt. A string that is itself a vocabulary entry (such
            as the default ``"<START>"`` marker) is fed as a single token;
            any other string is split into characters.
        top_k: number of most likely tokens to sample from at each step
            (capped at the vocabulary size); ``1`` gives greedy decoding.
        temperature: positive divisor applied to the logits before softmax;
            smaller values sharpen the distribution.
        log: when true, print the candidate tokens with their renormalised
            probabilities and the text generated so far at every step.

    Returns:
        The generated text (prompt included, ``"<START>"`` marker excluded),
        stripped of surrounding whitespace. Generation stops at ``"<EOP>"``
        or once the text is ``max_length`` characters long.

    Raises:
        ValueError: if ``temperature`` is not positive, ``top_k`` is below 1
            or ``start_word`` is empty.
        KeyError: if the prompt contains a token missing from the vocabulary.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Iterating over "<START>" would feed '<', 'S', 'T', ... as separate
    # tokens, so a prompt that is a vocabulary entry is kept whole.
    if start_word in word_to_index:
        prompt_tokens = [start_word]
    else:
        prompt_tokens = list(start_word)
    if not prompt_tokens:
        raise ValueError("start_word must not be empty")
    unknown_tokens = [token for token in prompt_tokens if token not in word_to_index]
    if unknown_tokens:
        raise KeyError(f"tokens not in vocabulary: {unknown_tokens}")

    # The start marker primes the model but is not part of the poem itself.
    generated_text = "".join(token for token in prompt_tokens if token != "<START>")

    # First step feeds the whole prompt; afterwards only the newest token is
    # fed because the hidden state already summarises everything before it.
    next_input_ids = [word_to_index[token] for token in prompt_tokens]
    hidden_state = model.initHidden(device=device)

    with torch.no_grad():
        for _ in range(max_length - len(generated_text)):
            input_tensor = torch.tensor(
                [next_input_ids], dtype=torch.long, device=device
            )
            output, hidden_state = model(input_tensor, hidden_state)

            # Logits of the last time step of the single batch row: (vocab_size,)
            last_logits = output[:, -1, :].view(-1)

            # Temperature scaling: softmax amplifies the largest logits, and a
            # temperature below 1 amplifies them further.
            probabilities = F.softmax(last_logits / temperature, dim=-1)

            # Keep the k most likely tokens and renormalise their mass to 1.
            k = min(top_k, probabilities.numel())
            top_probabilities, top_indices = probabilities.topk(k)
            top_probabilities = top_probabilities / torch.sum(top_probabilities)

            probabilities_np = top_probabilities.cpu().numpy()
            indices_np = top_indices.cpu().numpy()
            if log:
                for index, prob in zip(indices_np, probabilities_np):
                    print(f"{index_to_word[int(index)]}: {prob:.4f}")

            selected_index = int(np.random.choice(indices_np, p=probabilities_np))
            next_word = index_to_word[selected_index]

            if next_word == "<EOP>":
                break
            if log:
                print(generated_text)

            next_input_ids = [selected_index]
            generated_text += next_word

    return generated_text.strip()


# Sample a few poems: (prompt, keyword options passed to generate_text).
for prompt, options in (
    ("长安一片月", {"top_k": 1}),
    ("江", {"top_k": 3}),
    ("月", {"top_k": 3}),
    ("泉", {"top_k": 3}),
    ("日", {"top_k": 30}),
    ("风", {"top_k": 3, "log": True}),
):
    print(generate_text(prompt, **options))

长安一片月，万户捣衣声。秋风吹不尽，总是玉关情。何日平胡虏，良人罢远征。
江祖一片石，青天扫画屏。题诗留万古，绿字锦苔生。
月露发光彩，此时方见秋。夜凉金气应，天静火星流。蛩响偏依井，萤飞直过楼。相知尽白首，清景复追游。
泉眼不动月，长江夜夜深。独有余霞意，相期在不见。
日观化门来，登连城下尘。借问剡昔相，常恐沧年间。高谈出佳人，从此游应迷。
烟: 0.7222
吹: 0.1804
露: 0.0974
风
纪: 0.9608
里: 0.0236
起: 0.0156
风烟
南: 0.9899
江: 0.0058
海: 0.0043
风烟纪
城: 0.9989
山: 0.0008
都: 0.0003
风烟纪南
，: 0.9996
。: 0.0003
头: 0.0001
风烟纪南城
尘: 0.9960
水: 0.0031
旌: 0.0008
风烟纪南城，
土: 0.9919
水: 0.0045
户: 0.0035
风烟纪南城，尘
荆: 0.9969
青: 0.0019
今: 0.0012
风烟纪南城，尘土
门: 0.9994
青: 0.0004
城: 0.0002
风烟纪南城，尘土荆
路: 0.9990
城: 0.0006
东: 0.0004
风烟纪南城，尘土荆门
。: 1.0000
，: 0.0000
劒: 0.0000
风烟纪南城，尘土荆门路
天: 0.9995
相: 0.0003
江: 0.0002
风烟纪南城，尘土荆门路。
寒: 0.9949
山: 0.0038
河: 0.0014
风烟纪南城，尘土荆门路。天
多: 0.6247
猎: 0.3635
不: 0.0118
风烟纪南城，尘土荆门路。天寒
兽: 0.8848
猎: 0.0994
北: 0.0158
风烟纪南城，尘土荆门路。天寒猎
者: 0.9928
灭: 0.0063
扇: 0.0009
风烟纪南城，尘土荆门路。天寒猎兽
，: 1.0000
出: 0.0000
庭: 0.0000
风烟纪南城，尘土荆门路。天寒猎兽者
走: 0.9733
海: 0.0229
鸟: 0.0039
风烟纪南城，尘土荆门路。天寒猎兽者，
上: 0.9997
动: 0.0002
杀: 0.0001
风烟纪南城，尘土荆门路。天寒猎兽者，走
樊: 0.9728
元: 0.0160
占: 0.0112
风烟纪南城，尘土荆门路。天寒猎兽者，