In [1]:
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:

# 加载歌词数据

(corpus_indices, 
 char_to_idx, 
 idx_to_char, 
 vocab_size) = d2l.load_data_jay_lyrics()


num_inputs = vocab_size
num_hiddens = 256
num_outputs = vocab_size

In [3]:

# 参数配置

num_epochs = 160
num_steps = 35
batch_size = 32
lr = 1e2
clipping_theta = 1e-2


pred_period = 40
pred_len = 50
prefixes = ['分开', '不分开']

In [4]:

# 长短期记忆
lstm_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens)


# 模型
model = d2l.RNNModel(lstm_layer, vocab_size)

In [5]:

# 训练
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,
                                  corpus_indices, idx_to_char, char_to_idx,
                                  num_epochs, num_steps, 
                                  lr, clipping_theta,batch_size, 
                                  pred_period, 
                                  pred_len, 
                                  prefixes)

epoch 40, perplexity 8.448260000644001335441205049E+457497, time 0.02 sec
 - 分开弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥
 - 不分开弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥弥
epoch 80, perplexity 8.978769825019023873265235867E+660764, time 0.02 sec
 - 分开垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂
 - 不分开垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂垂
epoch 120, perplexity 1.796622987950926302377814373E+662206, time 0.02 sec
 - 分开                                                  
 - 不分开                                                  
epoch 160, perplexity 2.489187241047119870666853498E+584110, time 0.02 sec
 - 分开                                                  
 - 不分开                                                  


## 小结

* 长短期记忆的隐藏层输出包括隐藏状态和记忆细胞。只有隐藏状态会传递到输出层。
* 长短期记忆的输入门、遗忘门和输出门可以控制信息的流动。
* 长短期记忆可以应对循环神经网络中的梯度衰减问题，并更好地捕捉时间序列中时间步距离较大的依赖关系。

In [None]:
import math
from decimal import Decimal



In [None]:
math.exp(1548159.46875)

In [None]:
result = Decimal(1749397.078125).exp()


print(result)
print('%s' % result)

In [None]:
type(result)