In [None]:
import random

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

In [None]:
import random

# Seed every RNG the notebook relies on so a fresh Restart & Run All reproduces
# the same results. Python's `random` is presumably what d2l's data iterators use
# for their offsets; torch drives weight init, `torch.rand` and sampling.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

### 定义模型

In [None]:
num_hiddens = 256
# nn.RNN holds only the recurrent hidden layer; the output layer is added separately (see RNNModel).
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens)

# Zero initial hidden state, shaped (number of layers, batch size, number of hidden units).
state = torch.zeros((1, batch_size, num_hiddens))
state.shape

In [None]:
# Push a random time-major input (num_steps, batch_size, vocab_size) through the bare
# recurrent layer to inspect the output and final-state shapes.
X = torch.rand(num_steps, batch_size, len(vocab))
Y, state_new = rnn_layer(X, state)
(Y.shape, state_new.shape)

In [None]:
class RNNModel(nn.Module):
    """Recurrent language model: a torch recurrent layer followed by a linear output layer.

    Args:
        rnn_layer: a ``torch.nn`` recurrent layer (e.g. ``nn.RNN``, ``nn.GRU``,
            ``nn.LSTM``). Its ``input_size`` must equal ``vocab_size`` because the
            inputs are one-hot encoded. Both time-major and ``batch_first``
            layers are supported. For ``nn.LSTM``, ``proj_size`` is supported too.
        vocab_size: number of token ids. It sets the one-hot width and the number
            of output logits.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, rnn_layer, vocab_size, **kwargs):
        super(RNNModel, self).__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        # A bidirectional layer concatenates forward and backward outputs,
        # doubling the feature width seen by the output layer.
        self.num_directions = 2 if self.rnn.bidirectional else 1
        # nn.LSTM with proj_size > 0 projects h_t (and therefore each output step)
        # down to proj_size features; otherwise each direction emits hidden_size.
        proj_size = getattr(self.rnn, 'proj_size', 0)
        self._h_size = proj_size if proj_size > 0 else self.num_hiddens
        self.linear = nn.Linear(self._h_size * self.num_directions, self.vocab_size)

    def forward(self, inputs, state):
        """Run the model on a batch of token ids.

        Args:
            inputs: integer tensor of token ids. It is presumably
                (batch_size, num_steps), since it is transposed before being fed
                to a time-major layer (TODO: confirm against the caller).
            state: initial state, e.g. as returned by ``begin_state``.

        Returns:
            ``(output, state)``. ``output`` has shape
            (num_steps * batch_size, vocab_size) with rows in time-major order:
            all batch entries of step 0 come first. ``state`` is the layer's
            final state.
        """
        batch_first = getattr(self.rnn, 'batch_first', False)
        # A time-major layer expects (num_steps, batch_size, ...). A batch_first
        # layer takes the ids in their given orientation.
        tokens = inputs.long() if batch_first else inputs.T.long()
        X = F.one_hot(tokens, self.vocab_size).to(torch.float32)
        Y, state = self.rnn(X, state)
        if batch_first:
            # Back to time-major so the row layout of `output` does not depend
            # on how the layer was configured.
            Y = Y.transpose(0, 1)
        output = self.linear(Y.reshape(-1, Y.shape[-1]))
        return output, state

    def begin_state(self, device, batch_size=1):
        """Return an all-zero initial state on ``device``.

        For non-LSTM layers this is a single tensor of shape
        (num_directions * num_layers, batch_size, H). For ``nn.LSTM`` it is an
        ``(h0, c0)`` tuple. H is ``proj_size`` for the h0 of a projected LSTM and
        ``hidden_size`` otherwise; c0 always uses ``hidden_size``.
        """
        leading = (self.num_directions * self.rnn.num_layers, batch_size)
        h0 = torch.zeros((*leading, self._h_size), device=device)
        if isinstance(self.rnn, nn.LSTM):
            # The LSTM cell state is never projected.
            c0 = torch.zeros((*leading, self.num_hiddens), device=device)
            return (h0, c0)
        return h0

### 训练与预测

In [None]:
# NOTE(review): this switches cuDNN off process-wide for the rest of the kernel
# session, which can slow GPU training. It is presumably a workaround for a cuDNN
# RNN error; confirm it is still needed.
torch.backends.cudnn.enabled = False
device = d2l.try_gpu()
net = RNNModel(rnn_layer, vocab_size=len(vocab)).to(device)
# Sample 10 tokens after the prefix from the still-untrained model as a baseline.
d2l.predict_ch8('time traveller', 10, net, vocab, device)

In [None]:
num_epochs = 500
lr = 1
# Train with d2l's chapter-8 training loop on the selected device.
d2l.train_ch8(net, train_iter, vocab, lr, num_epochs, device)