# 循环神经网络的从零开始实现
:label:`sec_rnn_scratch`

本节将根据 :numref:`sec_rnn`中的描述，
从头开始基于循环神经网络实现字符级语言模型。
这样的模型将在H.G. Wells的《时光机器》数据集上训练。
和前面 :numref:`sec_language_model`中介绍过的一样，
我们先读取数据集。


In [12]:
import mlx
import mlx.core as mx
import mlx.nn as nn
from d2l import mlx as d2l

# Minibatch size and number of time steps per subsequence.
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

# Recurrent layer whose input width is the vocabulary size.
num_hiddens = 256
rnn_layer = nn.RNN(len(vocab), num_hiddens)

# Initial hidden state: all zeros, with a leading axis of size 1.
state = mx.zeros((1, batch_size, num_hiddens))
state.shape

# Random input drawn time-major, then moved to batch-major before the call.
X = mx.random.uniform(
    shape=(num_steps, batch_size, len(vocab))
).swapaxes(0, 1)
output = rnn_layer(X, state)

# Y: per-step hidden states with the first two axes swapped back.
Y = output[0].swapaxes(0, 1)
# state_new: the entry at the last index of the second-to-last axis.
state_new = output[..., -1, :]
Y.shape, state_new.shape
#@save
class RNNModel(nn.Module):
    """Recurrent layer followed by a linear decoder over the vocabulary.

    Args:
        rnn_layer: Recurrent layer to wrap. It must expose ``hidden_size``.
            ``num_layers`` is used when the layer defines it and is
            treated as 1 otherwise.
        vocab_size: Number of tokens; used as the one-hot width of the
            input and as the output width of the decoder.
        bidirectional: When true, the decoder is built for features of
            width ``2 * hidden_size``.
        **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """
    def __init__(self, rnn_layer, vocab_size, bidirectional=False, **kwargs):
        super().__init__(**kwargs)
        self.rnn = rnn_layer
        self.vocab_size = vocab_size
        self.num_hiddens = self.rnn.hidden_size
        self.bidirectional = bidirectional
        # A bidirectional RNN has two directions, otherwise one.
        self.num_directions = 2 if self.bidirectional else 1
        self.linear = nn.Linear(self.num_hiddens * self.num_directions,
                                self.vocab_size)

    def __call__(self, inputs, state):
        """Apply the recurrent layer and the decoder to a batch of tokens.

        Args:
            inputs: Token indices; assumed to be laid out as
                (batch_size, num_steps) -- TODO confirm against callers.
            state: Hidden state with a leading axis, as returned by
                ``begin_state`` or by a previous call. For ``nn.LSTM`` it
                is a ``(hidden, cell)`` tuple.

        Returns:
            ``(output, state)``: ``output`` has shape
            (num_steps * batch_size, vocab_size); ``state`` is the state
            after the last time step, in the same form as the argument.
        """
        X = d2l.mlx_one_hot(inputs.T, self.vocab_size)
        # The recurrent layer steps over the second-to-last axis, so move
        # the one-hot input from time-major to batch-major before the call.
        X = X.astype(mx.float32).swapaxes(0, 1)
        if isinstance(self.rnn, nn.LSTM):
            # NOTE(review): MLX's nn.LSTM takes hidden and cell as separate
            # arguments and returns both sequences -- confirm against the
            # installed mlx version.
            hidden, cell = state
            all_hidden, all_cell = self.rnn(X, hidden, cell)
            new_state = (all_hidden[:, :, -1, :], all_cell[:, :, -1, :])
        else:
            # The layer returns one array holding the hidden state of every
            # step; with a state that has a leading axis it is 4-D:
            # (leading, batch_size, num_steps, num_hiddens).
            all_hidden = self.rnn(X, state)
            new_state = all_hidden[:, :, -1, :]
        # Drop the leading axis and go back to time-major so the flattened
        # rows are ordered time step first.
        Y = all_hidden[0, :, :, :].swapaxes(0, 1)
        # The decoder sees Y as (num_steps * batch_size, num_hiddens) and
        # produces (num_steps * batch_size, vocab_size).
        output = self.linear(Y.reshape((-1, Y.shape[-1])))
        return output, new_state

    def begin_state(self, batch_size=1, device=None):
        """Return an all-zero initial state for ``batch_size`` sequences.

        Args:
            batch_size: Number of sequences in the batch.
            device: Unused; kept so existing callers keep working.

        Returns:
            A zero array of shape
            (num_directions * num_layers, batch_size, num_hiddens), or a
            pair of such arrays when the wrapped layer is ``nn.LSTM``.
        """
        # Fall back to a single layer when the wrapped layer does not
        # define num_layers.
        num_layers = getattr(self.rnn, "num_layers", 1)
        shape = (self.num_directions * num_layers,
                 batch_size, self.num_hiddens)
        if isinstance(self.rnn, nn.LSTM):
            # nn.LSTM carries a (hidden, cell) pair.
            return (mx.zeros(shape), mx.zeros(shape))
        # Other layers (e.g. nn.GRU) carry a single array.
        return mx.zeros(shape)
# Wrap the recurrent layer in the language model; the decoder width is the
# vocabulary size.
net = RNNModel(rnn_layer, len(vocab))
# Generate 10 tokens after the given prefix with the model as built above.
d2l.predict_ch8('time traveller', 10, net, vocab)

<class 'mlx.core.array'>
1
array([[[[-0.221779, 0.0223063, -0.0102052, ..., -0.0573121, -0.0869901, -0.0876293],
         [-0.19614, -0.155949, -0.0260539, ..., 0.189071, 0.111931, -0.0916073],
         [-0.272562, -0.21849, -0.0331099, ..., 0.108575, 0.00223928, 0.150827],
         ...,
         [-0.209083, -0.0848471, -0.0818376, ..., 0.111248, 0.0574137, 0.0397607],
         [-0.0710199, -0.159879, -0.0994042, ..., 0.0889667, 0.121993, -0.00529941],
         [-0.178878, -0.00646856, -0.0471725, ..., 0.192943, 0.0831682, -0.0773583]],
        [[-0.274358, 0.0685572, 0.00611183, ..., 0.00920813, 0.0223995, -0.0911416],
         [-0.0758498, -0.172444, -0.0267636, ..., 0.227118, 0.054165, -0.154835],
         [-0.0681467, -0.202442, -0.0573231, ..., 0.164708, 0.00969663, -0.00122862],
         ...,
         [-0.20751, -0.208538, -0.0699537, ..., 0.103278, 0.07719, 0.100141],
         [-0.255376, -0.130081, 0.0785917, ..., 0.226522, 0.0763116, 0.0037542],
         [-0.161127, -0.0957894