In [521]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader

import random
import json
import re
from pathlib import Path
import opencc
import pickle as p

In [522]:
# Run-wide settings. max_length bounds the total length (prompt included) of a
# generated poem; the other two drive the training loop and the DataLoader.
num_epochs, batch_size, max_length = 50, 4, 99

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [523]:
# Shared OpenCC converter built from the "t2s" configuration; sentenceParse
# below sends every poem through it.
converter = opencc.OpenCC("t2s")


def sentenceParse(para):
    """Clean one raw poem string and convert it with ``converter``.

    Cleaning steps, in order:
      1. drop text enclosed in （…）, {…} or 《…》 and any literal "[" / "]";
      2. drop ASCII digits and "-";
      3. replace each "。。" with a single "。";
      4. run the result through the OpenCC converter.

    Returns:
        The cleaned string, or "" when the converted text contains "𫗋".
    """
    for pattern in (r"（.*?）", r"{.*?}", r"《.*?》", r"[\[\]]"):
        para = re.sub(pattern, "", para)

    # Equivalent to filtering the string character by character.
    para = para.translate(str.maketrans("", "", "0123456789-"))
    para = re.sub(r"。。", "。", para)
    para = converter.convert(para)

    return "" if "𫗋" in para else para


def parseRawData(author=None, constrain=None):
    """Load and clean poems from the local ``poet.tang*`` JSON files.

    Args:
        author: when truthy, keep only records whose "author" field equals
            this value; otherwise keep every record.
        constrain: when not None, discard any poem that contains a non-empty
            clause (text between "，", "！" or "。") whose length differs from
            this value.

    Returns:
        A list of cleaned poem strings (paragraphs joined, then passed through
        ``sentenceParse``). Poems that clean to an empty string are dropped.
    """
    clause_splitter = re.compile("[，！。]")

    def violatesConstraint(paragraphs):
        # True when some non-empty clause does not have the required length.
        if constrain is None:
            return False
        return any(
            len(clause) not in (0, constrain)
            for paragraph in paragraphs
            for clause in clause_splitter.split(paragraph)
        )

    def handleJson(file_path):
        # Parse one JSON file and return the cleaned poems it contributes.
        with open(file_path, "r", encoding="utf-8") as file:
            records = json.load(file)

        rst = []
        for poetry in records:
            if author and poetry.get("author") != author:
                continue

            # A record without "paragraphs" is treated as empty (and therefore
            # dropped below) instead of raising TypeError during iteration.
            paragraphs = poetry.get("paragraphs") or []
            if violatesConstraint(paragraphs):
                continue

            pdata = sentenceParse("".join(paragraphs))
            if pdata:
                rst.append(pdata)
        return rst

    data = []
    src_path = Path("./data/chinese-poetry-master/全唐诗/")
    # Directory listing order depends on the filesystem. Sorting makes the poem
    # order, and any vocabulary ids derived from it, identical across runs.
    for file_path in sorted(src_path.glob("poet.tang*")):
        data.extend(handleJson(file_path))
    # To include the Song files too, add a second loop over "poet.song*".
    return data

In [524]:
# Load the corpus, keeping only records whose "author" field is "李白".
poems = parseRawData(author="李白")  # pass author=None to keep every author

In [525]:
# Build the vocabulary: each distinct character receives the next free integer
# id, in order of first appearance across the poems.
word_to_index = {}
for poem in poems:
    for word in poem:
        word_to_index.setdefault(word, len(word_to_index))

# The two special tokens take the ids right after the corpus characters:
# "<EOP>" (end of poem) first, then "<START>".
for token in ("<EOP>", "<START>"):
    word_to_index[token] = len(word_to_index)

# Reverse lookup: id -> token.
index_to_word = dict(zip(word_to_index.values(), word_to_index.keys()))

vocab_size = len(word_to_index)

print("VOCAB_SIZE:", vocab_size)
print("data_size", len(poems))


# Turn a sentence into a list of its elements and append the end-of-poem marker.
def sentence_to_list(sentence):
    """Return the elements of ``sentence`` as a list, followed by "<EOP>"."""
    return [*sentence, "<EOP>"]

# Tokenise every poem and terminate it with "<EOP>".
# NOTE: this rebinds `poems`, so executing the statement a second time appends
# another "<EOP>" to each poem.
poems = [sentence_to_list(poem) for poem in poems]


# Map a token to a tensor holding its vocabulary index.
def create_one_hot_vector(word, word_to_index):
    """Return a 1-element int64 tensor containing the index of ``word``.

    Despite the name, the result is an index tensor, not a one-hot vector.
    The deprecated ``torch.autograd.Variable`` wrapper is no longer used; a
    plain tensor is returned directly.

    Args:
        word: token to look up.
        word_to_index: mapping from token to integer id.

    Raises:
        KeyError: if ``word`` is not a key of ``word_to_index``.
    """
    return torch.tensor([word_to_index[word]], dtype=torch.long)

# Pre-compute the index tensor of every vocabulary entry once, so that samples
# can later be assembled by dictionary lookup.
one_hot_vectors = dict(
    (word, create_one_hot_vector(word, word_to_index)) for word in word_to_index
)

VOCAB_SIZE: 3514
data_size 1206


In [526]:
def generate_sample(sequence, one_hot_encoding):
    """Build a next-token training pair from one token sequence.

    Args:
        sequence: sliceable sequence of tokens.
        one_hot_encoding: mapping from token to the tensor that encodes it.

    Returns:
        ``(encoded_inputs, encoded_outputs)``: the concatenated encodings of
        ``sequence[:-1]`` and of ``sequence[1:]`` respectively, so the target
        at each position is the token that follows the input token.
    """
    # Inputs are every token but the last; targets are the same tokens shifted
    # left by one position.
    source_tokens = sequence[:-1]
    target_tokens = sequence[1:]

    encoded_inputs = torch.cat([one_hot_encoding[token] for token in source_tokens])
    encoded_outputs = torch.cat([one_hot_encoding[token] for token in target_tokens])
    return encoded_inputs, encoded_outputs


# generate_sample(poems[0], one_hot_vectors)

class PoetryDataset(Dataset):
    """Dataset yielding one (input, target) tensor pair per poem.

    Items are produced by the module-level ``generate_sample`` using the
    module-level ``one_hot_vectors`` lookup table.

    Args:
        poems: indexable collection of token sequences.
        transform: optional callable applied to the input tensor only.
    """

    def __init__(self, poems, transform=None):
        self.transform = transform
        self.poems = poems

    def __len__(self):
        return len(self.poems)

    def __getitem__(self, index):
        input_data, output_data = generate_sample(self.poems[index], one_hot_vectors)
        if not self.transform:
            return input_data, output_data
        # The target tensor is returned untouched; only the input is transformed.
        return self.transform(input_data), output_data


def custom_collate_fn(batch):
    """Collate variable-length (input, target) pairs into two padded tensors.

    Both the inputs and the targets are right-padded to the longest member of
    the batch using the id of "<START>" from the module-level ``word_to_index``.

    Args:
        batch: iterable of ``(input_tensor, target_tensor)`` pairs.

    Returns:
        ``(padded_sequences, padded_targets)``, each stacked along a new
        leading batch dimension.
    """
    sequences, targets = zip(*batch)
    pad_index = word_to_index["<START>"]

    padded_sequences = nn.utils.rnn.pad_sequence(
        sequences, batch_first=True, padding_value=pad_index
    )

    # Targets are padded one by one up to the longest target, then stacked.
    longest_target = max(target.size(0) for target in targets)
    padded_rows = []
    for target in targets:
        missing = longest_target - target.size(0)
        padded_rows.append(
            nn.functional.pad(target, (0, missing), mode="constant", value=pad_index)
        )
    padded_targets = torch.stack(padded_rows)

    return padded_sequences, padded_targets


# Wrap the tokenised poems in a Dataset and serve them in shuffled mini-batches;
# custom_collate_fn pads each batch to a common length.
dataset = PoetryDataset(poems)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

In [527]:
class PoetryModel(nn.Module):
    """Token-level language model.

    Pipeline: embedding -> single-layer LSTM (batch_first) -> ReLU -> linear
    projection to the vocabulary -> log-softmax over the last dimension.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.linear1 = nn.Linear(hidden_dim, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, input, hidden):
        """Return ``(log_probs, hidden)`` for a batch of index sequences.

        ``input`` is indexed as (batch, seq_len); the returned log-probabilities
        are reshaped to (batch, seq_len, vocab_size). ``hidden`` is the LSTM
        state tuple and is returned updated.
        """
        embedded = self.embeddings(input)
        lstm_out, hidden = self.lstm(embedded, hidden)

        # Flatten batch and time so the projection sees one row per position.
        flat_features = lstm_out.contiguous().view(-1, self.hidden_dim)
        flat_log_probs = self.softmax(self.linear1(F.relu(flat_features)))

        # Restore the (batch, seq_len, vocab) layout.
        log_probs = flat_log_probs.view(input.size(0), input.size(1), -1)
        return log_probs, hidden

    def initHidden(self, device, batch_size=1):
        """Return a zero-filled (h_0, c_0) pair of shape (1, batch_size, hidden_dim)."""
        state_shape = (1, batch_size, self.hidden_dim)
        return tuple(torch.zeros(state_shape).to(device) for _ in range(2))

In [528]:
# Build the network (embedding size 256, hidden size 256) and move it to the
# selected device.
model = PoetryModel(len(word_to_index), 256, 256).to(device)

# RMSprop with weight decay; the loss skips every position whose target is the
# "<START>" id, which is also the value used to pad batches.
optimizer = optim.RMSprop(params=model.parameters(), lr=0.001, weight_decay=0.0001)
criterion = nn.NLLLoss(reduction="mean", ignore_index=word_to_index["<START>"])

In [529]:
def train(model, num_epochs, data_loader, optimizer, criterion, vocab_size):
    """Train ``model`` and save its weights to "poetry-gen.pth".

    For every batch the LSTM state is re-initialised to zeros, the model is run
    on the inputs, and the loss is computed over all positions flattened into a
    single dimension. Progress is printed every 100 batches. Tensors are moved
    to the module-level ``device``.

    Args:
        model: network exposing ``initHidden`` and returning ``(output, hidden)``.
        num_epochs: number of passes over ``data_loader``.
        data_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimiser stepping the model parameters.
        criterion: loss applied to ``(N, vocab_size)`` outputs and ``(N,)`` targets.
        vocab_size: size of the last dimension of the model output.
    """
    model.train()

    for epoch in range(num_epochs):
        for batch_idx, (inputs, targets) in enumerate(data_loader):
            model.zero_grad()

            state = model.initHidden(device=device, batch_size=inputs.size(0))
            log_probs, _ = model(inputs.to(device), state)

            flat_log_probs = log_probs.view(-1, vocab_size)
            flat_targets = targets.view(-1).to(device)
            loss = criterion(flat_log_probs, flat_targets)

            loss.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(
                    f"Epoch: {epoch + 1:03d}/{num_epochs:03d} | Batch {batch_idx:05d}/{len(data_loader):05d} | Loss: {loss:.4f}"
                )

    # Persist the final weights once all epochs are done.
    torch.save(model.state_dict(), "poetry-gen.pth")


# Run the full training loop; train() writes the final weights to "poetry-gen.pth".
train(model, num_epochs, data_loader, optimizer, criterion, vocab_size)

Epoch: 001/050 | Batch 00000/00302 | Loss: 8.1615
Epoch: 001/050 | Batch 00100/00302 | Loss: 5.7644
Epoch: 001/050 | Batch 00200/00302 | Loss: 6.0068
Epoch: 001/050 | Batch 00300/00302 | Loss: 6.0886
Epoch: 002/050 | Batch 00000/00302 | Loss: 5.8192
Epoch: 002/050 | Batch 00100/00302 | Loss: 5.6160
Epoch: 002/050 | Batch 00200/00302 | Loss: 5.6990
Epoch: 002/050 | Batch 00300/00302 | Loss: 5.6606
Epoch: 003/050 | Batch 00000/00302 | Loss: 5.4652
Epoch: 003/050 | Batch 00100/00302 | Loss: 5.3944
Epoch: 003/050 | Batch 00200/00302 | Loss: 5.3430
Epoch: 003/050 | Batch 00300/00302 | Loss: 5.4390
Epoch: 004/050 | Batch 00000/00302 | Loss: 5.0503
Epoch: 004/050 | Batch 00100/00302 | Loss: 5.1681
Epoch: 004/050 | Batch 00200/00302 | Loss: 5.4184
Epoch: 004/050 | Batch 00300/00302 | Loss: 5.2579
Epoch: 005/050 | Batch 00000/00302 | Loss: 5.2254
Epoch: 005/050 | Batch 00100/00302 | Loss: 5.0910
Epoch: 005/050 | Batch 00200/00302 | Loss: 5.0171
Epoch: 005/050 | Batch 00300/00302 | Loss: 5.0612


In [530]:
# model.load_state_dict(torch.load("poetry-gen.pth"))

In [554]:
def make_one_hot_vec_target(word, word_to_index):
    """Return a 1-element int64 tensor containing the index of ``word``.

    Despite the name, the result is an index tensor, not a one-hot vector.
    The deprecated ``autograd.Variable`` wrapper is no longer used; a plain
    tensor is returned directly.

    Args:
        word: token to look up.
        word_to_index: mapping from token to integer id.

    Raises:
        KeyError: if ``word`` is not a key of ``word_to_index``.
    """
    return torch.tensor([word_to_index[word]], dtype=torch.long)


def generate_text(start_word="<START>", top_k=1, verbose=True):
    """Generate text that continues ``start_word`` using the global ``model``.

    Relies on the module-level ``model``, ``device``, ``max_length``,
    ``word_to_index`` and ``index_to_word``.

    Args:
        start_word: the prompt. A string that is itself a vocabulary entry
            (for example the special token "<START>") is used as one token;
            any other value is split into its elements (characters for a
            string).
        top_k: number of highest-scoring candidates considered at each step.
            With 1 the best candidate is always taken; otherwise one of the
            ``top_k`` candidates is chosen uniformly at random.
        verbose: when True (default), print the prompt tokens, the candidates
            with their log-probabilities (only when ``top_k > 1``) and the
            text generated so far after every step.

    Returns:
        The prompt followed by the generated continuation, stripped of
        surrounding whitespace. Special tokens are not included in the text.
        Generation stops at "<EOP>" or once ``max_length`` tokens (prompt
        included) have been reached.

    Raises:
        ValueError: if ``start_word`` is empty or ``top_k`` is smaller than 1.
        KeyError: if a prompt token is not in the vocabulary.
    """
    if not start_word:
        raise ValueError("start_word must not be empty")
    if top_k < 1:
        raise ValueError("top_k must be at least 1")

    # Keep a whole-vocabulary-entry prompt such as "<START>" as a single token
    # instead of splitting it into characters that are not in the vocabulary.
    # NOTE(review): the training pairs visible in this notebook are built from
    # poem characters only, so a "<START>" prompt may give poor output.
    if isinstance(start_word, str) and start_word in word_to_index:
        words = [start_word]
    else:
        words = list(start_word)
    if verbose:
        print(words)

    special_tokens = ("<START>", "<EOP>")
    generated_text = "".join(word for word in words if word not in special_tokens)
    hidden_state = model.initHidden(device=device)

    with torch.no_grad():
        # Prime the LSTM with every prompt token except the last one, keeping
        # the updated state so the whole prompt conditions the continuation.
        for word in words[:-1]:
            input_vector = make_one_hot_vec_target(word, word_to_index).unsqueeze(0)
            _, hidden_state = model(input_vector.to(device), hidden_state)

        # The last prompt token is the first input of the generation loop.
        input_vector = make_one_hot_vec_target(words[-1], word_to_index).unsqueeze(0)

        for _ in range(max_length - len(words)):
            output, hidden_state = model(input_vector.to(device), hidden_state)
            top_values, top_indices = output.topk(top_k)
            # Model output is (1, 1, vocab); flatten the candidates to 1-D.
            top_values = top_values.reshape(-1)
            top_indices = top_indices.reshape(-1)

            if top_k == 1:
                selected_index = top_indices.item()
            else:
                if verbose:
                    for index, log_prob in zip(top_indices.tolist(), top_values.tolist()):
                        print(f"{index_to_word[index]}: {log_prob:.4f}")
                selected_index = top_indices[random.randint(0, top_k - 1)].item()

            next_word = index_to_word[selected_index]
            if next_word == "<EOP>":
                break
            generated_text += next_word
            if verbose:
                print(generated_text)
            input_vector = make_one_hot_vec_target(next_word, word_to_index).unsqueeze(0)

    return generated_text.strip()


# Demo: generate a poem starting with "江", choosing at each step among the
# 3 highest-scoring candidates. Other prompts to try are left commented out.
print(generate_text("江", top_k=3))
# print(generate_text("月夜", top_k=2))
# print(generate_text("山"))
# print(generate_text("烟"))

['江']
水: -3.2523
山: -3.6810
上: -3.6984
江山
一: -1.8605
白: -2.3346
东: -2.6775
江山白
日: -1.1718
玉: -2.1024
云: -2.7929
江山白日
月: -0.9610
夜: -2.6997
暮: -2.9791
江山白日暮
月: -2.0810
高: -2.2924
，: -2.4153
江山白日暮，
水: -1.5942
明: -3.2785
白: -3.3700
江山白日暮，白
云: -2.0098
日: -2.4484
马: -2.5727
江山白日暮，白马
愁: -3.5174
空: -3.5805
飞: -3.5959
江山白日暮，白马愁
寒: -2.0720
飞: -2.5286
边: -3.4779
江山白日暮，白马愁飞
天: -2.6075
空: -2.7734
飞: -3.0239
江山白日暮，白马愁飞飞
。: -0.0008
，: -8.8518
天: -9.7263
江山白日暮，白马愁飞飞。
<EOP>: -1.6649
君: -3.3274
白: -3.5796
江山白日暮，白马愁飞飞。
