In [45]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

# Minibatch size and number of time steps per sequence; both are passed to
# load_data_poem further down when the training iterator is built.
batch_size, num_steps = 128, 50


In [46]:
import re
import json

def read_poem(data_dir='./Chinese Poem NLP/chinese-poetry-master/quan_tang_shi/json',
              num_files=900):
    """Read the poem JSON files and return every paragraph as one flat list.

    Files are expected to be named ``001.json`` ... ``<num_files>.json``
    (zero-padded to three digits) inside ``data_dir``. Each file must contain a
    JSON list of poem objects, and each poem object must have a
    ``'paragraphs'`` list of strings.

    Args:
        data_dir: Directory holding the numbered JSON files. Defaults to the
            location originally hard-coded in this notebook.
        num_files: Number of consecutively numbered files to read, starting
            at 1. Defaults to 900, the count originally hard-coded.

    Returns:
        A list of paragraph strings in file order, then poem order, then
        paragraph order.

    Raises:
        OSError: If one of the numbered files cannot be opened.
        KeyError: If a poem object has no ``'paragraphs'`` entry.
    """
    lines = []
    for i in range(1, num_files + 1):
        file_num = str(i).zfill(3)
        # Keep the file open only while parsing; the flattening happens after
        # the handle has been released.
        with open(f'{data_dir}/{file_num}.json', 'r', encoding='utf-8') as f:
            poems = json.load(f)
        for poem in poems:
            lines.extend(poem['paragraphs'])
    return lines

# Load every paragraph of the corpus and print one entry as a sanity check
# that the JSON files were parsed as expected.
lines = read_poem()
print(lines[100])

之罘思漢帝，碣石想秦皇。霓裳非本意，端拱且圖王。


In [47]:
def tokenize(lines, token='word'):  #@save
    """Split text lines into word or character tokens.

    Args:
        lines: Iterable of strings.
        token: ``'word'`` splits every line on whitespace; ``'char'`` splits
            every line into its individual characters.

    Returns:
        A list with one token list per input line.

    Raises:
        ValueError: If ``token`` is neither ``'word'`` nor ``'char'``.
            (Previously the function printed a message and implicitly
            returned ``None``, which only surfaced as a confusing failure in
            later code.)
    """
    if token == 'word':
        return [line.split() for line in lines]
    if token == 'char':
        return [list(line) for line in lines]
    raise ValueError('错误：未知词元类型：' + token)

# Tokenize with the default 'word' mode and print the first 11 results.
# NOTE(review): 'word' mode splits on whitespace only; the printed output
# shows each line coming back as a single token, so this is a demonstration
# only -- the corpus loader below tokenizes by character instead.
tokens = tokenize(lines)
for i in range(11):
    print(tokens[i])

['秦川雄帝宅，函谷壯皇居。綺殿千尋起，離宮百雉餘。']
['連薨遙接漢，飛觀迥淩虛。雲日隱層闕，風煙出綺疏。']
['岩廊罷機務，崇文聊駐輦。玉匣啟龍圖，金繩披鳳篆。']
['韋編斷仍續，縹帙舒還卷。對此乃淹留，欹案觀墳典。']
['移步出詞林，停輿欣武宴。雕弓寫明月，駿馬疑流電。']
['驚雁落虛弦，啼猿悲急箭。閱賞誠多美，於茲乃忘倦。']
['鳴笳臨樂館，眺聽歡芳節。急管韻朱弦，清歌凝白雪。']
['彩鳳肅來儀，玄鶴紛成列。去茲鄭衛聲，雅音方可悅。']
['芳辰追逸趣，禁苑信多奇。橋形通漢上，峰勢接雲危。']
['煙霞交隱映，花鳥自參差。何如肆轍跡，萬里賞瑤池。']
['飛蓋去芳園，蘭橈遊翠渚。萍間日彩亂，荷處香風舉。']


In [48]:
def load_corpus_poem(max_tokens=-1):
    """Return the token-index corpus and the vocabulary of the poem dataset.

    The poem lines come from ``read_poem`` and are tokenized per character
    with ``d2l.tokenize``; the vocabulary is built from those tokens.

    Args:
        max_tokens: When positive, the corpus is cut to its first
            ``max_tokens`` indices. Any other value keeps the whole corpus.

    Returns:
        ``(corpus, vocab)`` where ``corpus`` is one flat list of vocabulary
        indices and ``vocab`` is the ``d2l.Vocab`` built from the tokens.
    """
    char_lines = d2l.tokenize(read_poem(), 'char')
    vocab = d2l.Vocab(char_lines)
    # Line boundaries are dropped: all lines are concatenated into a single
    # index sequence.
    corpus = []
    for chars in char_lines:
        for ch in chars:
            corpus.append(vocab[ch])
    if max_tokens > 0:
        return corpus[:max_tokens], vocab
    return corpus, vocab


In [49]:
class SeqDataLoader:
    """Iterable wrapper that yields minibatches of the poem corpus."""

    def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
        """Pick the sampling function and load the corpus.

        Args:
            batch_size: Forwarded to the sampling function on iteration.
            num_steps: Forwarded to the sampling function on iteration.
            use_random_iter: Truthy selects ``d2l.seq_data_iter_random``,
                otherwise ``d2l.seq_data_iter_sequential`` is used.
            max_tokens: Forwarded to ``load_corpus_poem``.
        """
        self.data_iter_fn = (d2l.seq_data_iter_random if use_random_iter
                             else d2l.seq_data_iter_sequential)
        loaded_corpus, loaded_vocab = load_corpus_poem(max_tokens)
        self.corpus = loaded_corpus
        self.vocab = loaded_vocab
        self.batch_size = batch_size
        self.num_steps = num_steps

    def __iter__(self):
        """Start a fresh pass by calling the selected sampling function."""
        sampler = self.data_iter_fn
        return sampler(self.corpus, self.batch_size, self.num_steps)


In [50]:
def load_data_poem(batch_size, num_steps,
                           use_random_iter=False, max_tokens=10000):
    """Return the data iterator and the vocabulary of the poem dataset.

    Args:
        batch_size: Passed through to ``SeqDataLoader``.
        num_steps: Passed through to ``SeqDataLoader``.
        use_random_iter: Passed through to ``SeqDataLoader``.
        max_tokens: Passed through to ``SeqDataLoader``.

    Returns:
        ``(loader, vocab)`` where ``loader`` is the ``SeqDataLoader`` and
        ``vocab`` is that loader's ``vocab`` attribute.
    """
    loader = SeqDataLoader(batch_size, num_steps, use_random_iter, max_tokens)
    return loader, loader.vocab

In [51]:
# Build the training iterator and vocabulary with load_data_poem's defaults
# (use_random_iter=False, max_tokens=10000).
train_iter, vocab = load_data_poem(batch_size, num_steps)

In [52]:
# Demo: one-hot encode the indices 0 and 2 with len(vocab) classes.
F.one_hot(torch.tensor([0, 2]), len(vocab))

tensor([[1, 0, 0,  ..., 0, 0, 0],
        [0, 0, 1,  ..., 0, 0, 0]])

In [53]:
# Toy index batch of shape (2, 5). One-hot encoding the transpose with 28
# classes gives shape (5, 2, 28), as the printed output shows.
X = torch.arange(10).reshape((2, 5))
F.one_hot(X.T, 28).shape

torch.Size([5, 2, 28])

In [54]:
# Initialize the model parameters
def get_params(vocab_size, num_hiddens, device):
    """Create the trainable parameters of the scratch RNN.

    Inputs and outputs both have width ``vocab_size``.

    Args:
        vocab_size: Width of the input and output layers.
        num_hiddens: Width of the hidden state.
        device: Device on which every tensor is allocated.

    Returns:
        The list ``[W_xh, W_hh, b_h, W_hq, b_q]``. Weights are standard-normal
        samples scaled by 0.01, biases are zeros, and every tensor has
        ``requires_grad`` switched on.
    """
    def scaled_normal(rows, cols):
        return torch.randn(size=(rows, cols), device=device) * 0.01

    def zero_bias(width):
        return torch.zeros(width, device=device)

    input_dim = output_dim = vocab_size

    # Hidden-layer parameters
    w_xh = scaled_normal(input_dim, num_hiddens)
    w_hh = scaled_normal(num_hiddens, num_hiddens)
    b_h = zero_bias(num_hiddens)
    # Output-layer parameters
    w_hq = scaled_normal(num_hiddens, output_dim)
    b_q = zero_bias(output_dim)

    # Attach gradients; requires_grad_ returns the tensor it was called on.
    return [p.requires_grad_(True) for p in (w_xh, w_hh, b_h, w_hq, b_q)]

In [55]:
def init_rnn_state(batch_size, num_hiddens, device):
    """Return the initial hidden state as a one-element tuple.

    The single element is an all-zero tensor of shape
    ``(batch_size, num_hiddens)`` allocated on ``device``.
    """
    hidden = torch.zeros((batch_size, num_hiddens), device=device)
    return (hidden,)

In [56]:
def rnn(inputs, state, params):
    """Run the scratch RNN over a sequence of time steps.

    Args:
        inputs: Iterated along its first axis; per the original comment its
            shape is (number of time steps, batch size, vocabulary size).
        state: One-element tuple holding the hidden state.
        params: ``(W_xh, W_hh, b_h, W_hq, b_q)``.

    Returns:
        ``(outputs, (H,))`` where ``outputs`` is the per-step outputs
        concatenated along dimension 0 and ``H`` is the final hidden state.
    """
    W_xh, W_hh, b_h, W_hq, b_q = params
    (hidden,) = state
    step_outputs = []
    # Each step_input has shape (batch size, vocabulary size)
    for step_input in inputs:
        pre_activation = torch.mm(step_input, W_xh) + torch.mm(hidden, W_hh) + b_h
        hidden = torch.tanh(pre_activation)
        step_outputs.append(torch.mm(hidden, W_hq) + b_q)
    stacked = torch.cat(step_outputs, dim=0)
    return stacked, (hidden,)

In [57]:
class RNNModelScratch: #@save
    """Recurrent neural network model implemented from scratch."""

    def __init__(self, vocab_size, num_hiddens, device,
                 get_params, init_state, forward_fn):
        """Store the sizes and callbacks, and create the parameters.

        Args:
            vocab_size: Number of classes used for one-hot encoding.
            num_hiddens: Hidden-state width handed to the callbacks.
            device: Device handed to ``get_params``.
            get_params: Called as ``get_params(vocab_size, num_hiddens, device)``;
                its result is stored in ``self.params``.
            init_state: Called by ``begin_state``.
            forward_fn: Called by ``__call__``.
        """
        self.vocab_size = vocab_size
        self.num_hiddens = num_hiddens
        self.params = get_params(vocab_size, num_hiddens, device)
        self.init_state = init_state
        self.forward_fn = forward_fn

    def __call__(self, X, state):
        """One-hot encode the transposed index batch and run ``forward_fn``."""
        one_hot_inputs = F.one_hot(X.T, self.vocab_size).float()
        return self.forward_fn(one_hot_inputs, state, self.params)

    def begin_state(self, batch_size, device):
        """Return ``init_state(batch_size, num_hiddens, device)``."""
        return self.init_state(batch_size, self.num_hiddens, device)

In [58]:
# Hidden-state width of the scratch RNN.
num_hiddens = 512
# Assemble the model from the scratch building blocks defined above.
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,
                      init_rnn_state, rnn)
# Smoke test: X is the toy (2, 5) index batch created in an earlier cell, so
# this cell depends on that cell having been run first.
state = net.begin_state(X.shape[0], d2l.try_gpu())
Y, new_state = net(X.to(d2l.try_gpu()), state)
# Display the output shape, the number of state tensors, and the state shape.
Y.shape, len(new_state), new_state[0].shape

(torch.Size([10, 7431]), 1, torch.Size([2, 512]))

In [93]:
def predict(prefix, num_preds, net, vocab, device=None):  #@save
    """Generate new characters following ``prefix``.

    Args:
        prefix: Non-empty string used to warm up the hidden state. Its
            characters are copied to the output unchanged (after a round trip
            through ``vocab``, so unknown characters come back as whatever
            token ``vocab`` maps them to).
        num_preds: Number of characters to generate after the prefix.
        net: Model exposing ``begin_state(batch_size=..., device=...)`` and
            ``net(X, state) -> (scores, state)``.
        vocab: Maps a token to its index via ``vocab[token]`` and an index
            back to its token via ``vocab.idx_to_token``.
        device: Device for the input tensors and the initial state. When
            omitted, ``d2l.try_gpu()`` is used, so the four-argument calls
            elsewhere in this notebook work.

    Returns:
        The prefix followed by ``num_preds`` greedily chosen (argmax)
        characters, as a single string.
    """
    if device is None:
        device = d2l.try_gpu()
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]

    def get_input():
        # The most recent output index, shaped (1, 1), is the next input.
        return torch.tensor([outputs[-1]], device=device).reshape((1, 1))

    for y in prefix[1:]:  # Warm-up: advance the state, keep the known character
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):  # Predict num_preds steps
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])


# Other cells in this notebook (including train_ch8) call this function under
# the name `predict_ch8`; keep that name bound so a fresh kernel can run them.
predict_ch8 = predict

In [60]:
# Baseline: sample 30 characters from `net` as constructed above; the training
# call further down is commented out.
predict_ch8('生当作人杰 死亦为鬼雄', 30, net, vocab, d2l.try_gpu())

'生<unk>作人<unk><unk>死亦<unk>鬼雄醷拌侈荒停勷軏台G酉偃綻恂製巡傍虥冪究U籺漼囿罪睞8繁榜漻惸'

In [61]:
def grad_clipping(net, theta):  #@save
    """Clip gradients in place so that their global L2 norm is at most ``theta``.

    Args:
        net: Either an ``nn.Module`` (its parameters with ``requires_grad``
            are used) or any object exposing a ``params`` list of tensors.
        theta: Maximum allowed L2 norm over all gradients together.

    Parameters whose ``grad`` is ``None`` (for example because they did not
    take part in the last backward pass) are skipped, and nothing happens if
    no parameter has a gradient.
    """
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return
    norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))
    if norm > theta:
        scale = theta / norm
        for g in grads:
            # mul_ rescales in place and, unlike g[:], also works on 0-dim grads.
            g.mul_(scale)

In [62]:
# #@save
# def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):
#     """训练网络一个迭代周期"""
#     state, timer = None, d2l.Timer()
#     metric = d2l.Accumulator(2)  # 训练损失之和,词元数量
#     for X, Y in train_iter:
#         if state is None or use_random_iter:
#             # 在第一次迭代或使用随机抽样时初始化state
#             state = net.begin_state(batch_size=X.shape[0], device=device)
#         else:
#             if isinstance(net, nn.Module) and not isinstance(state, tuple):
#                 # state对于nn.GRU是个张量
#                 state.detach_()
#             else:
#                 # state对于nn.LSTM或对于我们从零开始实现的模型是个张量
#                 for s in state:
#                     s.detach_()
#         y = Y.T.reshape(-1)
#         X, y = X.to(device), y.to(device)
#         y_hat, state = net(X, state)
#         l = loss(y_hat, y.long()).mean()
#         if isinstance(updater, torch.optim.Optimizer):
#             updater.zero_grad()
#             l.backward()
#             grad_clipping(net, 1)
#             updater.step()
#         else:
#             l.backward()
#             grad_clipping(net, 1)
#             # 因为已经调用了mean函数
#             updater(batch_size=1)
#         metric.add(l * y.numel(), y.numel())
#     return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):
    """Run one training epoch and report perplexity and throughput.

    Args:
        net: Model exposing ``begin_state`` and ``net(X, state)``.
        train_iter: Iterable yielding ``(X, Y)`` minibatches.
        loss: Callable applied to ``(y_hat, y.long())``; its mean is minimized.
        updater: Either a ``torch.optim.Optimizer`` or a callable invoked as
            ``updater(batch_size=1)``.
        device: Device the minibatches are moved to.
        use_random_iter: When truthy the state is re-initialized for every
            minibatch; otherwise it is carried over and detached.

    Returns:
        ``(perplexity, tokens_per_second)`` computed from the accumulated
        loss, the token count, and the elapsed time.
    """
    state = None
    timer = d2l.Timer()
    metric = d2l.Accumulator(2)  # (summed training loss, number of tokens)
    uses_torch_optimizer = isinstance(updater, torch.optim.Optimizer)
    for X, Y in train_iter:
        if state is None or use_random_iter:
            # First minibatch, or random sampling: start from a fresh state
            state = net.begin_state(batch_size=X.shape[0], device=device)
        elif isinstance(net, nn.Module) and not isinstance(state, tuple):
            # `state` is a single tensor for `nn.GRU`
            state.detach_()
        else:
            # `state` is a tuple of tensors for `nn.LSTM` and for the
            # scratch implementation in this notebook
            for s in state:
                s.detach_()
        y = Y.T.reshape(-1)
        X, y = X.to(device), y.to(device)
        y_hat, state = net(X, state)
        l = loss(y_hat, y.long()).mean()
        if uses_torch_optimizer:
            updater.zero_grad()
        l.backward()
        grad_clipping(net, 1)
        if uses_torch_optimizer:
            updater.step()
        else:
            # The loss is already a mean, so the update uses batch_size=1
            updater(batch_size=1)
        num_tokens = d2l.size(y)
        metric.add(l * num_tokens, num_tokens)
    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()

In [63]:
# High-level API
#@save
def train_ch8(net, train_iter, vocab, lr, num_epochs, device,
              use_random_iter=False):
    """Train a model and print generated samples along the way.

    Args:
        net: An ``nn.Module`` or a scratch model exposing ``params``.
        train_iter: Iterable of ``(X, Y)`` minibatches, consumed once per epoch.
        vocab: Vocabulary handed to ``predict`` for sample generation.
        lr: Learning rate for SGD.
        num_epochs: Number of epochs to run.
        device: Device used for training and generation.
        use_random_iter: Forwarded to ``train_epoch_ch8``.

    Every 10th epoch a sample is printed and the perplexity is added to the
    plot; after the last epoch the final perplexity, the speed, and two more
    samples are printed.
    """
    loss = nn.CrossEntropyLoss()
    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',
                            legend=['train'], xlim=[10, num_epochs])
    # Choose the updater
    if isinstance(net, nn.Module):
        updater = torch.optim.SGD(net.parameters(), lr)
    else:
        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)
    # Use the `predict` function defined in this notebook. The helper gets its
    # own name so it does not shadow that function inside this scope.
    generate = lambda prefix: predict(prefix, 50, net, vocab, device)
    # Train and predict
    for epoch in range(num_epochs):
        ppl, speed = train_epoch_ch8(
            net, train_iter, loss, updater, device, use_random_iter)
        if (epoch + 1) % 10 == 0:
            print(generate('生当作人杰 死亦为鬼雄'))
            animator.add(epoch + 1, [ppl])
    print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')
    print(generate('白日依山尽'))
    print(generate('千山鸟飞绝'))

In [64]:
# num_epochs, lr = 2000, 0.2
# train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())

In [65]:
#torch.save(net, './Chinese Poem NLP/Checkpoints/Poem_TangShi_RNN_N5000.pth')

In [74]:
# Load two previously saved checkpoints for comparison.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoint files from a source you trust.
# NOTE(review): presumably these files were written with torch.save(net, ...)
# as in the commented-out cell above, i.e. whole model objects rather than
# state dicts; if so, the model class must be defined before loading -- confirm.
model = torch.load( './Chinese Poem NLP/Checkpoints/Poem_TangShi_RNN_N5000.pth')
model2 = torch.load( './Chinese Poem NLP/Checkpoints/Poem_TangShi_RNN_N3000.pth')

In [76]:
# Generate 50 characters after the prompt with the N5000 checkpoint (`model`).
predict_ch8('白日依山尽', 50, model, vocab, d2l.try_gpu())

'白日依山<unk>。日輪臨天葉，高宮方雲臣。暮野一天色，方原方朝光。寒茲三九色，無為散朝光。昔流歡地地，今雲上下風。秋'

In [77]:
# Same prompt and length with the N3000 checkpoint (`model2`) for comparison.
predict_ch8('白日依山尽', 50, model2, vocab, d2l.try_gpu())

'白日依山<unk>。清輪含春影，終色方朝光。雲霞三九影，翠花散朝風。雲雲含龍影，無日散雲風。長雲含龍影，無日散雲風。長'

In [68]:
# Generate 100 characters with the N5000 checkpoint. The device is passed
# explicitly because `predict` as defined in this notebook requires it.
predict('床前明月光，疑似地上霜 ', 100, model, vocab, d2l.try_gpu())

'床前明月光，疑似地上霜<unk>岸。惟峰黃家箕山北，不知虛度兩京春。去年餘閏今春早，曙色和風著花草。可憐寒岫入清明，還想連池起煙霧。樹霧驪食與清明，預想連池起煙霧。樹廬驪道與清明，預想連池起煙霧。樹霧驪食與清明，預想連池起煙霧。樹霧'

In [78]:
# Generate 50 characters with the N3000 checkpoint. The device is passed
# explicitly because `predict` as defined in this notebook requires it.
predict('床前明月光，疑似地上霜', 50, model2, vocab, d2l.try_gpu())

'床前明月光，疑似地上霜。花雲猶天影，終色方朝春。雲霞三九影，翠花散朝風。雲雲含龍影，無日散雲風。長雲含龍影，無日散雲風。長'

In [79]:
# Longer sample: 100 characters after the couplet with the N5000 checkpoint.
predict_ch8('烈烈寒風起，慘慘飛雲浮。', 100, model, vocab, d2l.try_gpu())

'烈烈寒風起，慘慘飛雲浮。高雲猶天色，無原散朝光。玉日三初色，風風散朝光。長流千四影，無樹散朝風。昔日凝地影，風風散雲風。秋雲凝地影，無樹散還風。昔流凝地影，風風散雲風。秋流無地影，無樹散還風。雲雲凝地色，風來散雲風。秋流無地'

In [80]:
# Same couplet and length with the N3000 checkpoint for comparison.
predict_ch8('烈烈寒風起，慘慘飛雲浮。', 100, model2, vocab, d2l.try_gpu())

'烈烈寒風起，慘慘飛雲浮。寒雲猶天影，終色方朝光。雲霞三九影，翠花散朝風。雲雲含龍影，無日散雲風。長雲含龍影，無日散雲風。長雲含龍影，無日散雲風。長雲含龍影，無日散雲風。長雲含龍影，無日散雲風。長雲含龍影，無日散雲風。長雲含龍'

In [70]:
# Generate 36 characters after the couplet with the N5000 checkpoint.
predict_ch8('三驅陳銳卒，七萃列材雄。', 36, model, vocab, d2l.try_gpu())

'三驅陳銳卒，七萃列材雄。地華一重雪，終原方可紅。寒茲三九色，無樹散朝光。長流無地地，今雲散下風。'

In [81]:
# Same couplet and length with the N3000 checkpoint for comparison.
predict_ch8('三驅陳銳卒，七萃列材雄。', 36, model2, vocab, d2l.try_gpu())

'三驅陳銳卒，七萃列材雄。登言一天影，方花方朝光。雲霞三九影，翠花散朝風。雲雲含龍影，無日散雲風。'

In [71]:
# Generate 36 characters with the N5000 checkpoint. The device is passed
# explicitly because `predict` as defined in this notebook requires it.
predict('塞外悲風切，交河冰已結。', 36, model, vocab, d2l.try_gpu())

'塞外悲風切，交河冰已結。地霞猶遠影，持樹開長空。玉日三春色，風花散影光。長當無地極，無雲散下風。'

In [82]:
# Same couplet and length with the N3000 checkpoint for comparison.
predict_ch8('塞外悲風切，交河冰已結。', 36, model2, vocab, d2l.try_gpu())

'塞外悲風切，交河冰已結。花雲猶天影，終色方朝春。雲霞三九影，翠花散朝風。雲雲含龍影，無日散雲風。'

In [72]:
# Generate 36 characters with the N5000 checkpoint. In the recorded output one
# prompt character is rendered as <unk>.
predict_ch8('美人卷珠帘，深坐蹙蛾眉。', 36, model, vocab, d2l.try_gpu())

'美人卷珠<unk>，深坐蹙蛾眉。一豐千萬節，臨輿方重中。寒日一舊色，無樹散朝光。長雲無四極，無雲散還風。'

In [83]:
# Generate 36 characters after the couplet with the N5000 checkpoint.
predict_ch8('仙氣凝三嶺，和風扇八荒。', 36, model, vocab, d2l.try_gpu())

'仙氣凝三嶺，和風扇八荒。高茲三天色，無樹散朝光。玉茲千初影，風風散朝光。長流三地影，無為散還風。'

In [84]:
# Generate 36 characters after the couplet with the N5000 checkpoint.
predict_ch8('照岸花分彩，迷雲雁斷行。', 36, model, vocab, d2l.try_gpu())

'照岸花分彩，迷雲雁斷行。懷野千春色，無原方可春。寒日三九色，無為散朝光。昔流歡地地，今雲上下風。'

In [88]:
# The prompt is the first half of the corpus line printed earlier as
# lines[100]; generate 13 more characters with the N5000 checkpoint.
predict_ch8('之罘思漢帝，碣石想秦皇', 13, model, vocab, d2l.try_gpu())

'之罘思漢帝，碣石想秦皇。登言飛天色，三觀駐朝光。'

In [92]:
# Generate 36 characters after the couplet with the N5000 checkpoint.
predict_ch8('四海皇風被，千年德水清。', 36, model, vocab, d2l.try_gpu())

'四海皇風被，千年德水清。高日猶春色，無樹散朝光。玉茲三初影，風風散朝光。昔流三地影，無為散還風。'