In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
import numpy as np
import json
import re
from pathlib import Path
import opencc

In [13]:
# Training / generation settings.
num_epochs, batch_size = 50, 4  # passes over the data set; poems per mini-batch
max_length = 99  # generate_text() stops once prompt + generated tokens reach this count

# Use the GPU when CUDA reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [14]:
# Shared OpenCC converter built from the "t2s" profile; sentenceParse() runs every poem through it.
# NOTE(review): "t2s" is presumably OpenCC's Traditional-to-Simplified Chinese profile — confirm against the OpenCC docs.
converter = opencc.OpenCC("t2s")


def sentenceParse(para):
    """Clean one raw poem string and return it after OpenCC conversion.

    Steps, in order: drop text enclosed in （…）, {…} and 《…》; delete square
    brackets, ASCII digits and "-"; replace each "。。" pair with a single "。";
    pass the result through the module-level ``converter``.

    Returns:
        str: the cleaned text, or "" when the converted text contains "𫗋".
    """
    # Non-greedy patterns, applied one after another, so each bracket pair is removed on its own.
    for bracketed in (r"（.*?）", r"{.*?}", r"《.*?》"):
        para = re.sub(bracketed, "", para)

    # The remaining removals are single characters, so one translate() pass covers them all.
    para = para.translate(str.maketrans("", "", "[]0123456789-"))
    para = para.replace("。。", "。")

    converted = converter.convert(para)
    return "" if "𫗋" in converted else converted


def parseRawData(author=None, constrain=None, data_dir="./data/chinese-poetry-master/全唐诗/"):
    """Load poems from the ``poet.tang*`` JSON files and return them as cleaned strings.

    Args:
        author: keep only poems whose "author" field equals this value;
            a falsy value (the default ``None``) keeps every author.
        constrain: when given, keep only poems in which every non-empty segment
            (each paragraph split on "，", "！" or "。") has exactly this many characters.
        data_dir: directory that holds the ``poet.tang*`` files. The default is the
            path this notebook has always read from.

    Returns:
        list[str]: one string per kept poem, produced by ``sentenceParse``;
        poems that come back as "" are dropped.
    """

    def meets_constraint(paragraphs):
        # Decide once whether a constraint applies instead of re-checking it per segment.
        if constrain is None:
            return True
        return all(
            len(segment) in (0, constrain)
            for line in paragraphs
            for segment in re.split("[，！。]", line)
        )

    def handleJson(file_path):
        with open(file_path, "r", encoding="utf-8") as file:
            records = json.load(file)

        kept = []
        for poetry in records:
            if author and poetry.get("author") != author:
                continue

            # A record with a missing or null "paragraphs" field used to raise a
            # TypeError; such a record carries no text, so it is skipped.
            paragraphs = poetry.get("paragraphs")
            if not paragraphs:
                continue
            if not meets_constraint(paragraphs):
                continue

            pdata = sentenceParse("".join(paragraphs))
            if pdata:
                kept.append(pdata)
        return kept

    data = []
    # NOTE: glob() yields files in filesystem order, so the order of the returned poems
    # (and of anything derived from it, such as vocabulary ids) can differ between machines.
    for file_path in Path(data_dir).glob("poet.tang*"):
        data.extend(handleJson(file_path))
    # The "poet.song*" files in the same directory are not loaded.
    return data

In [15]:
# Corpus for this notebook: only the poems whose "author" field is "李白".
poems = parseRawData(author="李白")  # author=None would keep every author

In [16]:
# Build the vocabulary: every distinct character gets the next free integer id in
# order of first appearance, and the two special tokens are appended last.
special_tokens = ("<EOP>", "<START>")  # End Of Poem token, Start token

word_to_index = {}
for poem in poems:
    for word in poem:
        # When this cell is re-run, `poems` already holds token lists that contain
        # "<EOP>". Skipping the special tokens here keeps their ids at the end of the
        # vocabulary; otherwise "<EOP>" and "<START>" end up sharing one id.
        if word in special_tokens:
            continue
        if word not in word_to_index:
            word_to_index[word] = len(word_to_index)
for token in special_tokens:
    word_to_index[token] = len(word_to_index)
index_to_word = {index: word for word, index in word_to_index.items()}

vocab_size = len(word_to_index)

print("VOCAB_SIZE:", vocab_size)
print("data_size", len(poems))


# Convert a poem into a list of tokens and append the end-of-poem marker.
def sentence_to_list(sentence):
    """Return ``sentence`` as a new list of tokens terminated by exactly one "<EOP>".

    Args:
        sentence: a poem string (one token per character) or an existing token list.

    Returns:
        list: the tokens followed by "<EOP>". Input that already ends with "<EOP>"
        is copied unchanged, so applying the function twice (for example by
        re-running the notebook cell) does not stack a second terminator.
    """
    tokens = list(sentence)
    if tokens and tokens[-1] == "<EOP>":
        return tokens
    return tokens + ["<EOP>"]


# Tokenise every poem; from here on `poems` holds token lists instead of strings.
poems = list(map(sentence_to_list, poems))


# Map a word to the tensor that holds its vocabulary index.
def create_one_hot_vector(word, word_to_index):
    """Return a 1-element ``torch.long`` tensor containing ``word``'s vocabulary index.

    Despite the name, the result is an index tensor, not a one-hot vector.

    Args:
        word: the token to look up.
        word_to_index: mapping from token to integer id.

    Returns:
        torch.Tensor: shape ``(1,)``, dtype ``torch.long``.

    Raises:
        KeyError: if ``word`` is not a key of ``word_to_index``.
    """
    # torch.autograd.Variable is deprecated and only hands back a plain tensor, so
    # build the tensor directly with an explicit dtype.
    return torch.tensor([word_to_index[word]], dtype=torch.long)


# Pre-compute the index tensor of every vocabulary entry once, keyed by the word itself.
one_hot_vectors = dict(
    (word, create_one_hot_vector(word, word_to_index)) for word in word_to_index
)

VOCAB_SIZE: 3514
data_size 1206


In [17]:
def generate_sample(sequence, one_hot_encoding):
    """Build a next-token training pair from one token sequence.

    Args:
        sequence: the tokens of one poem.
        one_hot_encoding: mapping from token to its 1-element index tensor.

    Returns:
        tuple: ``(encoded_inputs, encoded_outputs)``, two concatenated tensors. The
        inputs cover every token except the last one; the outputs cover every token
        except the first one, so ``outputs[i]`` is the token that follows ``inputs[i]``.
    """
    encoded_inputs = torch.cat([one_hot_encoding[token] for token in sequence[:-1]])
    encoded_outputs = torch.cat([one_hot_encoding[token] for token in sequence[1:]])
    return encoded_inputs, encoded_outputs


# generate_sample(poems[0], one_hot_vectors)


class PoetryDataset(Dataset):
    """Dataset that turns each stored poem into an (input, target) index-tensor pair."""

    def __init__(self, poems, transform=None):
        """Keep the token sequences and an optional callable applied to the input tensor only."""
        self.transform = transform
        self.poems = poems

    def __len__(self):
        """Return the number of stored poems."""
        return len(self.poems)

    def __getitem__(self, index):
        """Return ``(input, target)`` for the poem at ``index``.

        The pair comes from ``generate_sample`` using the module-level
        ``one_hot_vectors`` table; a truthy ``transform`` is applied to the input
        only, never to the target.
        """
        input_data, output_data = generate_sample(self.poems[index], one_hot_vectors)
        if not self.transform:
            return input_data, output_data
        return self.transform(input_data), output_data


def custom_collate_fn(batch):
    """Collate ``(input, target)`` pairs of different lengths into two padded tensors.

    Args:
        batch: the samples of one mini-batch, each an ``(input, target)`` pair of 1-D tensors.

    Returns:
        tuple: ``(padded_sequences, padded_targets)``, both batch-first and padded to
        the longest item of the batch with the id of "<START>".
    """
    sequences, targets = zip(*batch)
    pad_id = word_to_index["<START>"]

    def pad(tensors):
        # Bring every tensor to a common length so the batch can be stacked.
        return nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=pad_id)

    return pad(sequences), pad(targets)


# Wrap the tokenised poems and serve them in shuffled mini-batches of `batch_size`;
# custom_collate_fn pads the items of each batch to a common length.
dataset = PoetryDataset(poems)
data_loader = DataLoader(
    dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn
)

In [18]:
class PoetryModel(nn.Module):
    """Token-level language model: embedding -> LSTM -> ReLU -> linear -> log-softmax."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """Create the layers.

        Args:
            vocab_size: number of distinct tokens (embedding rows and output classes).
            embedding_dim: size of each token embedding.
            hidden_dim: size of the LSTM hidden state.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: tensors are laid out as (batch, seq_len, features).
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.linear1 = nn.Linear(hidden_dim, vocab_size)
        # Log-probabilities over the vocabulary (the attribute name is historical).
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input, hidden):
        """Score the next token at every position.

        Args:
            input: token ids, indexed as (batch, seq_len).
            hidden: LSTM state tuple as returned by ``initHidden`` or a previous call.

        Returns:
            tuple: log-probabilities viewed as (batch, seq_len, vocab_size), and the
            updated LSTM state.
        """
        lstm_out, hidden = self.lstm(self.embeddings(input), hidden)
        # Flatten batch and time so the linear layer sees one row per position.
        flat = F.relu(lstm_out.reshape(-1, self.hidden_dim))
        log_probs = self.softmax(self.linear1(flat))
        # Restore the (batch, seq_len, vocab_size) layout.
        return log_probs.view(input.size(0), input.size(1), -1), hidden

    def initHidden(self, device, batch_size=1):
        """Return the all-zero initial LSTM state for ``batch_size`` sequences on ``device``.

        An LSTM carries two states per unit, a short-term state (usually written h)
        and a long-term state (usually written c); the returned tuple holds one
        zero tensor of shape (1, batch_size, hidden_dim) for each.
        """
        state_shape = (1, batch_size, self.hidden_dim)
        return tuple(torch.zeros(state_shape).to(device) for _ in range(2))

In [19]:
# Model with 256-dimensional embeddings and a 256-dimensional LSTM hidden state, moved to `device`.
model = PoetryModel(len(word_to_index), 256, 256)
model.to(device)

# RMSprop with weight decay. The loss ignores every target equal to the "<START>" id,
# which is the value custom_collate_fn pads the targets with.
optimizer = optim.RMSprop(model.parameters(), lr=0.001, weight_decay=0.0001)
criterion = torch.nn.NLLLoss(ignore_index=word_to_index["<START>"], reduction="mean")

In [20]:
def train(model, num_epochs, data_loader, optimizer, criterion, vocab_size, save_path="poetry-gen.pth"):
    """Train ``model`` on ``data_loader`` and save its ``state_dict`` when finished.

    Args:
        model: module exposing ``initHidden(device=..., batch_size=...)`` and
            ``forward(sequence, hidden) -> (output, hidden)``.
        num_epochs: number of full passes over ``data_loader``.
        data_loader: sized iterable yielding ``(sequence, target)`` tensor batches.
        optimizer: optimizer that updates ``model``'s parameters.
        criterion: loss called as ``criterion(output_2d, target_1d)``.
        vocab_size: size of the model output's last dimension, used to flatten it.
        save_path: where the trained ``state_dict`` is written. Defaults to the
            file name this notebook has always used.

    Side effects:
        Prints the loss every 100 batches and writes ``save_path`` after the last epoch.
    """
    model.train()

    # Follow the device the model actually lives on instead of a notebook-level
    # global, so batches and weights cannot end up on different devices.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    num_batches = len(data_loader)
    for epoch in range(num_epochs):
        for batch_idx, (sequence, target) in enumerate(data_loader):
            sequence = sequence.to(run_device)
            target = target.to(run_device)

            model.zero_grad()
            # Each batch starts from a fresh zero state; the final state is not carried over.
            hidden = model.initHidden(device=run_device, batch_size=sequence.size(0))
            output, _ = model(sequence, hidden)
            # Flatten to (batch * seq_len, vocab) vs (batch * seq_len,) for the loss.
            loss = criterion(output.reshape(-1, vocab_size), target.reshape(-1))
            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(
                    f"Epoch: {epoch + 1:03d}/{num_epochs:03d} | Batch {batch_idx:05d}/{num_batches:05d} | Loss: {loss.item():.4f}"
                )
    torch.save(model.state_dict(), save_path)


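# Training is disabled in this run; uncomment the call below to retrain and rewrite "poetry-gen.pth".
# train(model, num_epochs, data_loader, optimizer, criterion, vocab_size)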

In [21]:
# Load the trained weights. map_location remaps the stored tensors onto `device`, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you trust
# (recent PyTorch releases accept weights_only=True to restrict what is unpickled).
model.load_state_dict(torch.load("poetry-gen.pth", map_location=device))

<All keys matched successfully>

In [41]:
def generate_text(start_word="<START>", top_k=1, log=False):
    """Generate a poem that continues ``start_word`` by sampling from the trained model.

    Uses the notebook-level ``model``, ``device``, ``word_to_index``,
    ``index_to_word`` and ``max_length``.

    Args:
        start_word: prompt; each character is fed as one token. The special tokens
            "<START>" and "<EOP>" are fed as a single token and left out of the
            returned text. An empty prompt falls back to "<START>".
        top_k: at every step, sample among the ``top_k`` most likely tokens
            (1 always takes the most likely one). Values above the vocabulary size
            are clamped to it.
        log: when true, print the prompt tokens, each step's candidates with their
            renormalised probabilities, and the text generated so far.

    Returns:
        str: the prompt followed by the generated tokens, stripped of surrounding
        whitespace. Generation stops at "<EOP>" or once prompt plus generated
        tokens reach ``max_length``.

    Raises:
        KeyError: if a prompt token is not in ``word_to_index``.
        ValueError: if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be at least 1")
    if not start_word:
        start_word = "<START>"

    # A special token must be fed as ONE token. Spelling the default "<START>" out
    # character by character ("<", "S", ...) raised KeyError.
    special_tokens = ("<START>", "<EOP>")
    words = [start_word] if start_word in special_tokens else list(start_word)
    generated_text = "".join(word for word in words if word not in special_tokens)
    if log:
        print(words)

    prompt_ids = [word_to_index[word] for word in words]
    start_id = word_to_index["<START>"]

    model.eval()
    hidden_state = model.initHidden(device=device)
    # The whole prompt is fed first to prime the LSTM state; afterwards the model
    # receives one token at a time. Shape is (1, len(prompt)).
    input_vector = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_length - len(words)):
            output, hidden_state = model(input_vector, hidden_state)
            # Log-probabilities for the token that follows the last input position.
            log_probs = output[0, -1].clone()
            # "<START>" doubles as the padding id and is excluded from the training
            # loss, so its score carries no signal; never sample it.
            log_probs[start_id] = float("-inf")

            k = min(top_k, log_probs.numel())
            top_values, top_indices = log_probs.topk(k)

            # Softmax over the k log-probabilities equals exp() followed by
            # renormalisation; float64 keeps the sum inside np.random.choice's tolerance.
            probabilities_np = torch.softmax(top_values.double(), dim=0).cpu().numpy()
            indices_np = top_indices.cpu().numpy()
            if log:
                for index, prob in zip(indices_np, probabilities_np):
                    print(f"{index_to_word[int(index)]}: {prob:.4f}")

            selected_index = int(np.random.choice(indices_np, p=probabilities_np))

            next_word = index_to_word[selected_index]
            if next_word == "<EOP>":
                break
            generated_text += next_word
            if log:
                print(generated_text)
            # The model expects a (batch, seq_len) tensor, hence the nested list.
            input_vector = torch.tensor([[selected_index]], dtype=torch.long, device=device)

    return generated_text.strip()


# Sample a few poems: (prompt, top_k) pairs, generated and printed in this order.
for prompt, k in (("江", 3), ("泉", 1), ("泉", 3), ("泉", 30), ("沾衣欲湿杏花雨", 3)):
    print(generate_text(prompt, top_k=k))
# One more sample with the per-step candidate log switched on.
print(generate_text("风", top_k=3, log=True))

['江']
江水如碧海，长啸入云烟。高峰望远道，云门隔古道。人生古可道，云尽空余哀。
['泉']
泉水东北流，波荡双鸳鸯。飞燕燕汉国，飞龙与天通。人生寒松草，草木不相连。坐思天上月，空余碧玉道。云行无知老，吾将何时？。
['泉']
泉花满水国，水明湖江流。清光出江海，流水游江流。月光碧山月，水清清猨月。长松风落日，清夜夜清风。闲入玉窗里，云在罗罗衣。何言是我意，独去难为别。
['泉']
泉月东楼湖，风流天上来。青山几度月，日月照古人。遥见一里月，三千久空山。一为李白发，何由入秋浦。山公一问石，东海扬中州。落子结归去，沙去空天人。
['沾', '衣', '欲', '湿', '杏', '花', '雨']
沾衣欲湿杏花雨，春风拂槛露华香。若飞云汉去，宫花艳舞香。
['风']
吹: 0.6101
日: 0.2246
云: 0.1653
风吹
玉: 0.5230
落: 0.3101
花: 0.1669
风吹玉
关: 0.4486
笛: 0.3014
树: 0.2500
风吹玉关
道: 0.4830
西: 0.2705
山: 0.2465
风吹玉关西
，: 0.6674
入: 0.2822
海: 0.0504
风吹玉关西入
吴: 0.5667
汉: 0.3348
胡: 0.0985
风吹玉关西入吴
关: 0.4777
越: 0.2828
云: 0.2394
风吹玉关西入吴关
，: 0.9440
之: 0.0398
城: 0.0163
风吹玉关西入吴关，
云: 0.4393
南: 0.3470
见: 0.2137
风吹玉关西入吴关，南
云: 0.8726
行: 0.0846
国: 0.0429
风吹玉关西入吴关，南云
莫: 0.3643
日: 0.3494
白: 0.2862
风吹玉关西入吴关，南云白
日: 0.8426
马: 0.1348
帝: 0.0226
风吹玉关西入吴关，南云白日
之: 0.5122
出: 0.2900
夜: 0.1977
风吹玉关西入吴关，南云白日出
天: 0.5100
处: 0.2594
夜: 0.2306
风吹玉关西入吴关，南云白日出处
处: 0.4709
归: 0.3631
生: 0.1659
风吹玉关西入吴关，南云白日出处生
。: 0.9977
，: 0.0019
？: 0.0004
风吹玉关西入吴关，南云白日出处生。
天: 0.5042
白: 