In [60]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import torch
from torch.utils.data import Dataset, DataLoader

In [61]:
file_path = "./data/train.txt"
with open(file_path, mode="r", encoding="utf-8") as fh:
    text = fh.read()

# Paragraphs are separated by the literal "<|endoftext|>" marker; keep only the
# chunks that are non-empty after trimming surrounding whitespace.
paragraphs = [
    chunk
    for chunk in (piece.strip() for piece in text.split("<|endoftext|>"))
    if chunk
]

print(paragraphs[:5])

['菩萨蛮 其一\n小山重叠金明灭，鬓云欲度香腮雪。\n懒起画蛾眉，弄妆梳洗迟。\n照花前后镜，花面交相映。\n新帖绣罗襦，双双金鹧鸪。', '菩萨蛮 其二\n水晶帘里玻璃枕，暖香惹梦鸳鸯锦。\n江上柳如烟，雁飞残月天。\n藕丝秋色浅，人胜参差剪。\n双鬓隔香红，玉钗头上风。', '菩萨蛮 其三\n蕊黄无限当山额，宿妆隐笑纱窗隔。\n相见牡丹时，暂来还别离。\n翠钗金作股，钗上蝶双舞。\n心事竟谁知，月明花满枝。', '菩萨蛮 其四\n翠翘金缕双㶉𫛶，水纹细起春池碧。\n池上海棠梨，雨晴红满枝。\n绣衫遮笑靥，烟草粘飞蝶。\n青琐对芳菲，玉关音信稀。', '菩萨蛮 其五\n杏花含露团香雪，绿杨陌上多离别。\n灯在月胧明，觉来闻晓莺。\n玉钩褰翠幕，妆浅旧眉薄。\n春梦正关情，镜中蝉鬓轻。']


In [62]:
# Build the character vocabulary from every distinct character in `text`.
# Ids 0 and 1 are reserved for the special tokens, so real characters start at 2.
chars = sorted(set(text))
char_to_id = {ch: idx for idx, ch in enumerate(chars, start=2)}
id_to_char = {idx: ch for ch, idx in char_to_id.items()}

# Register the special tokens in both directions.
char_to_id.update({"<pad>": 0, "<eos>": 1})
id_to_char.update({0: "<pad>", 1: "<eos>"})

vocab_size = len(char_to_id)
print("字典大小: {}".format(vocab_size))

字典大小: 9234


In [63]:
# Encode every paragraph as a list of character ids terminated by <eos>.
char_id_lists = [
    [char_to_id[ch] for ch in paragraph] + [char_to_id["<eos>"]]
    for paragraph in paragraphs
]

print(char_id_lists[:5])

[[6333, 6360, 6613, 3, 712, 300, 2, 1907, 1959, 7758, 1033, 7761, 3116, 4234, 8820, 8499, 380, 3654, 2216, 8361, 6011, 8095, 86, 2, 2605, 7254, 4693, 6621, 4899, 8820, 2257, 1645, 3450, 3891, 7502, 86, 2, 4319, 6161, 839, 1065, 7918, 8820, 6161, 8152, 390, 4889, 3124, 86, 2, 3061, 2140, 5666, 5758, 6882, 8820, 1022, 1022, 7761, 8694, 8642, 86, 1], [6333, 6360, 6613, 3, 712, 377, 2, 3764, 3169, 2141, 7757, 4538, 4619, 3301, 8820, 3183, 8361, 2526, 3442, 8650, 8646, 7884, 86, 2, 3786, 306, 3357, 1642, 4278, 8820, 8077, 8296, 3701, 3219, 1596, 86, 2, 6514, 322, 5186, 6122, 3920, 8820, 407, 5948, 1014, 2115, 853, 86, 2, 1022, 8499, 8059, 8361, 5603, 8820, 4510, 7808, 1605, 306, 8283, 86, 1], [6333, 6360, 6613, 3, 712, 305, 2, 6463, 8732, 3088, 8029, 2289, 1959, 8270, 8820, 1865, 1645, 8058, 5331, 5618, 5283, 8059, 86, 2, 4889, 6904, 4398, 338, 3101, 8820, 3176, 3273, 7496, 817, 5176, 86, 2, 5831, 7808, 7761, 498, 5913, 8820, 7808, 306, 6681, 1022, 6084, 86, 2, 2348, 376, 5311, 7071, 4992, 

In [64]:
# Data / schedule.
batch_size, epochs = 32, 100
# Model dimensions.
embed_dim, hidden_dim = 50, 30
# Optimisation: Adam learning rate and per-element gradient clip value.
lr = 1e-3
grad_clip = 1

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("now using device: ", device)

now using device:  cuda


In [65]:
class Dataset(torch.utils.data.Dataset):
    """Next-character prediction dataset over pre-encoded id sequences.

    NOTE: this class shadows the ``Dataset`` name imported from
    ``torch.utils.data`` at the top of the notebook.

    Item ``i`` is the pair ``(seq[:-1], seq[1:])``: the target is the input
    shifted one position to the left.
    """

    def __init__(self, sequences):
        # Any indexable collection of id sequences (a list of lists here).
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        seq = self.sequences[index]
        inputs, targets = seq[:-1], seq[1:]
        return inputs, targets


def collate_fn(batch):
    """Pad a list of ``(input_ids, target_ids)`` pairs into batch tensors.

    Args:
        batch: list of pairs as returned by ``Dataset.__getitem__``.

    Returns:
        ``(pad_batch_x, pad_batch_y, batch_x_lens, batch_y_lens)``: the inputs
        and targets padded with the ``<pad>`` id to the longest sequence
        (batch-first), plus LongTensors of each row's unpadded length.
    """
    inputs = [torch.tensor(pair[0]) for pair in batch]
    targets = [torch.tensor(pair[1]) for pair in batch]

    input_lens = torch.LongTensor(list(map(len, inputs)))
    target_lens = torch.LongTensor(list(map(len, targets)))

    def _pad(seqs):
        # Right-pad to the longest sequence with the <pad> id.
        return torch.nn.utils.rnn.pad_sequence(
            seqs, batch_first=True, padding_value=char_to_id["<pad>"]
        )

    return _pad(inputs), _pad(targets), input_lens, target_lens

In [66]:
# Wrap the encoded paragraphs so the DataLoader can index them.
dataset = Dataset(sequences=char_id_lists)

In [67]:
# Shuffled mini-batches; collate_fn pads each batch to its longest sequence.
data_loader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=batch_size,
    shuffle=True,
)

In [68]:
class RNN:
    """Minimal bias-free vanilla RNN cell in plain NumPy (illustration only).

    Update rule: ``h = tanh(W_hh @ h + W_xh @ x)`` and output ``y = W_hy @ h``.
    Weights are drawn from ``np.random.randn`` in the order W_xh, W_hh, W_hy.
    """

    def __init__(self, input_size, hidden_size, output_size):
        randn = np.random.randn
        self.W_xh = randn(hidden_size, input_size)
        self.W_hh = randn(hidden_size, hidden_size)
        self.W_hy = randn(output_size, hidden_size)
        # Hidden state is kept as a column vector and starts at zero.
        self.h = np.zeros((hidden_size, 1))

    def step(self, x):
        """Advance one time step and return the output vector.

        NOTE(review): assumes ``x`` is a column vector of shape
        ``(input_size, 1)``; a 1-D ``x`` would broadcast against ``h``.
        """
        pre_activation = self.W_hh @ self.h + self.W_xh @ x
        self.h = np.tanh(pre_activation)
        return self.W_hy @ self.h

In [69]:
class CharRNN(torch.nn.Module):
    """Character-level language model: embedding -> 2 stacked LSTMs -> MLP head.

    Uses the notebook-level ``char_to_id`` / ``id_to_char`` mappings for the
    ``<pad>`` / ``<eos>`` ids and to decode generated ids.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super(CharRNN, self).__init__()

        # The <pad> row of the embedding is excluded from gradient updates.
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=char_to_id["<pad>"],
        )

        # Two single-layer LSTMs stacked by hand: layer 2 consumes the full
        # output sequence of layer 1.
        self.rnn_layer1 = torch.nn.LSTM(
            input_size=embed_dim, hidden_size=hidden_dim, batch_first=True
        )

        self.rnn_layer2 = torch.nn.LSTM(
            input_size=hidden_dim, hidden_size=hidden_dim, batch_first=True
        )

        # Two-layer MLP mapping each hidden state to per-character logits.
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=hidden_dim, out_features=vocab_size),
        )

    def forward(self, batch_x, batch_x_lens):
        """Return next-character logits, shape (batch, max(batch_x_lens), vocab_size).

        Args:
            batch_x: batch-first padded id tensor (as produced by ``collate_fn``).
            batch_x_lens: unpadded length of each row.
        """
        return self.encoder(batch_x, batch_x_lens)

    def encoder(self, batch_x, batch_x_lens):
        """Embed, run both LSTMs over the packed batch, and apply the head.

        Packing lets the LSTMs skip padded positions; the LSTM output is
        re-padded (with zeros) before the linear head.
        """
        batch_x = self.embedding(batch_x)

        # pack_padded_sequence requires the lengths tensor on the CPU.
        batch_x_lens = batch_x_lens.cpu()
        batch_x = torch.nn.utils.rnn.pack_padded_sequence(
            batch_x, batch_x_lens, batch_first=True, enforce_sorted=False
        )

        batch_x, _ = self.rnn_layer1(batch_x)
        batch_x, _ = self.rnn_layer2(batch_x)

        batch_x, _ = torch.nn.utils.rnn.pad_packed_sequence(batch_x, batch_first=True)

        batch_x = self.linear(batch_x)

        return batch_x

    def generator(self, start_char, max_len=50, top_n=5):
        """Sample a character sequence that begins with ``start_char``.

        At each step the next-character logits are computed, the ``top_n``
        highest-scoring ids are kept, and one of them is drawn uniformly with
        ``np.random.choice`` (``top_n=1`` is greedy decoding). Sampling stops
        when ``<eos>`` is drawn or the sequence reaches ``max_len`` characters
        (the start character included).

        Args:
            start_char: first character; must be a key of ``char_to_id``.
            max_len: maximum number of characters returned.
            top_n: size of the candidate pool sampled from at each step.

        Returns:
            list[str]: the generated characters, starting with ``start_char``.
        """
        # Build inputs on whatever device the model currently lives on.
        device = next(self.parameters()).device
        eos_id = char_to_id["<eos>"]
        char_list = [char_to_id[start_char]]

        # Carry each LSTM's (h, c) between steps and feed only the newest
        # character. This is equivalent to re-running the whole prefix, and it
        # gives layer 2 layer 1's *output sequence* as during training, rather
        # than layer 1's final hidden state.
        state1 = state2 = None
        step_input = torch.tensor([char_list], dtype=torch.long, device=device)

        with torch.no_grad():
            while len(char_list) < max_len:
                emb = self.embedding(step_input)  # (1, 1, embed_dim)
                out1, state1 = self.rnn_layer1(emb, state1)
                out2, state2 = self.rnn_layer2(out1, state2)
                logits = self.linear(out2[:, -1, :])  # (1, vocab_size)

                # Indices of the top_n largest logits.
                _, top_n_indices = torch.topk(logits, top_n, dim=-1)
                candidates = top_n_indices[0].tolist()

                # Pick one candidate uniformly at random.
                if top_n > 1:
                    next_char = int(np.random.choice(candidates))
                else:
                    next_char = candidates[0]

                if next_char == eos_id:
                    break

                char_list.append(next_char)
                step_input = torch.tensor([[next_char]], dtype=torch.long, device=device)

        return [id_to_char[ch_id] for ch_id in char_list]

In [70]:
# Fixed seed so weight initialisation is reproducible.
torch.manual_seed(2)
model = CharRNN(vocab_size=vocab_size, embed_dim=embed_dim, hidden_dim=hidden_dim)
# Positions whose target is <pad> contribute nothing; the rest are averaged.
criterion = torch.nn.CrossEntropyLoss(
    reduction="mean",
    ignore_index=char_to_id["<pad>"],
)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [74]:
def train(model, num_epochs, data_loader, optimizer, criterion, vocab_size, grad_clip=1.0):
    """Train ``model`` on next-character prediction for ``num_epochs`` epochs.

    After every epoch the weights are saved to ``char_rnn_model.pth`` and one
    sample starting with "月" is printed. Uses the notebook-level ``device``.

    Args:
        model: a ``CharRNN`` (must expose ``generator``).
        num_epochs: number of passes over ``data_loader``.
        data_loader: yields ``(batch_x, batch_y, batch_x_lens, batch_y_lens)``.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss applied to flattened logits / targets.
        vocab_size: width of the output layer, used to flatten the logits.
        grad_clip: every gradient element is clipped to ``[-grad_clip, grad_clip]``.
    """
    model.train()
    for epoch in range(1, num_epochs + 1):
        # The end-of-epoch sample moves the model to the CPU; bring it back.
        model = model.to(device)
        for batch_idx, (batch_x, batch_y, batch_x_lens, batch_y_lens) in enumerate(data_loader):
            optimizer.zero_grad()

            # Move the padded ids to the training device. The length tensors
            # stay on the CPU, where pack_padded_sequence needs them.
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            batch_pred_y = model(batch_x, batch_x_lens)

            # Flatten (batch, seq, vocab) -> (batch * seq, vocab) for the loss.
            batch_pred_y = batch_pred_y.reshape(-1, vocab_size)
            batch_y = batch_y.reshape(-1)

            loss = criterion(batch_pred_y, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()

            if batch_idx % 100 == 0:
                print(
                    f"Epoch: {epoch:03d}/{num_epochs:03d} | Batch {batch_idx:04d}/{len(data_loader):04d} | Loss: {loss.item():.4f}"
                )

        torch.save(model.state_dict(), "char_rnn_model.pth")
        # Generate one sample after each epoch to eyeball progress. The model is
        # moved to the CPU so sampling works even if generator builds its input
        # tensors on the CPU.
        with torch.no_grad():
            model.eval()
            model.cpu()
            generated_text = model.generator("月")
            print("".join(generated_text))
            model.train()

        torch.cuda.empty_cache()


train(model, epochs, data_loader, optimizer, criterion, vocab_size, grad_clip=grad_clip)

Epoch: 001/100 | Batch 0000/8876 | Loss: 5.8432
Epoch: 001/100 | Batch 0100/8876 | Loss: 5.5244
Epoch: 001/100 | Batch 0200/8876 | Loss: 5.7933
Epoch: 001/100 | Batch 0300/8876 | Loss: 6.0157
Epoch: 001/100 | Batch 0400/8876 | Loss: 5.8704
Epoch: 001/100 | Batch 0500/8876 | Loss: 5.6514
Epoch: 001/100 | Batch 0600/8876 | Loss: 5.8754
Epoch: 001/100 | Batch 0700/8876 | Loss: 6.0158
Epoch: 001/100 | Batch 0800/8876 | Loss: 5.8273
Epoch: 001/100 | Batch 0900/8876 | Loss: 5.8007
Epoch: 001/100 | Batch 1000/8876 | Loss: 6.0238
Epoch: 001/100 | Batch 1100/8876 | Loss: 5.7819
Epoch: 001/100 | Batch 1200/8876 | Loss: 5.8684
Epoch: 001/100 | Batch 1300/8876 | Loss: 5.9937
Epoch: 001/100 | Batch 1400/8876 | Loss: 5.7791
Epoch: 001/100 | Batch 1500/8876 | Loss: 5.7885
Epoch: 001/100 | Batch 1600/8876 | Loss: 5.4704
Epoch: 001/100 | Batch 1700/8876 | Loss: 5.5085
Epoch: 001/100 | Batch 1800/8876 | Loss: 5.6321
Epoch: 001/100 | Batch 1900/8876 | Loss: 5.8457
Epoch: 001/100 | Batch 2000/8876 | Loss:

KeyboardInterrupt: 

In [None]:
# Print 10 samples for each seed character. Move the model to the CPU and into
# eval mode first: train() only moves it back to the CPU at the end of an
# epoch, so an interrupted run leaves it on `device`.
model.eval()
model.cpu()
with torch.no_grad():
    for seed_char in ("月", "天", "人"):
        for _ in range(10):
            print("".join(model.generator(seed_char)))
            print()

月夜雨
春凉不不见无人，一月不无何无时。泣何年年人不是也一日
年有一生不如天生事中事里不未不不似二首

月月上上
年无不知无时来不，不知不能知情心苦。
春入天不不不见去处好后
天的天明无来时，一笑相知非人

月雨中有书中不作
一朝无无何事名士，何有一时来时情无无无写。・我你不得相相思相成也也不相相，(醉云江

月雨日二咏二韵杂题 明夜水水二十五首呈德仲之寺三绝 呈万第八首二首
人无不如何事事不是不用一月作一月

月日雨二十四绝韵 明日二下之一卷
年朝，山，有无名，人无知情苦草二韵四咏 。圯
他千千尺天天天。 

月夜作三十四古 成
人不能能知人不，无今见相怜无。泣无人事人不物
春入春月落，清光香时。酒酒作
年朝

月日出
一无今不知，春风香无。怀
春风明风，无何知名士酒写。圯丸上
你(带云叶平州州平生上庭子子州王

月日夜起之韵四卷二咏
人生，不得尽好一事
天的无不知无人，无是不知人无心。儿外时，平风香开，玉金天无

月月
人生，不见长秀无用
人不知无无，天然知难知。追日，无何似无。蒙生日日月中不可寄
年是天子如无，

月雨上二日有中有作一一首四绝 
秋光雨香香香草花雪而一壑
人知是难能成中人事，不见人人间不难人。
人



In [None]:
# Persist the final weights to the same file the training loop writes each epoch.
torch.save(obj=model.state_dict(), f="char_rnn_model.pth")