In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
from torch.functional import F
import numpy
from d2l import torch as d2l

# 数据

In [None]:
# Mini-batch geometry: each batch X is [batch_size, num_steps].
batch_size, num_steps = 32, 35

# Build the training iterator and its vocabulary with d2l's time-machine loader.
train_iter, vocab = d2l.load_data_time_machine(batch_size=batch_size,
                                               num_steps=num_steps)

# 模型

In [None]:
class BidirectionalRNNModel(nn.Module):
    """Token-level model: one-hot input -> bidirectional ``nn.RNN`` -> linear head.

    Attributes:
        num_inputs: Size of the one-hot encoding; also the width of the output
            logits produced by the linear head.
        num_hiddens: Hidden size of each RNN direction.
        rnn: Single-layer bidirectional ``nn.RNN`` (time-major input).
        linear: Maps the concatenated forward/backward features
            (``2 * num_hiddens``) to ``num_inputs`` logits.
    """

    def __init__(self, num_inputs, num_hiddens):
        """Create the recurrent layer and the output projection.

        Args:
            num_inputs: Number of distinct token indices (one-hot width).
            num_hiddens: Hidden units per direction.
        """
        super().__init__()
        self.num_inputs = num_inputs
        self.num_hiddens = num_hiddens
        self.rnn = nn.RNN(num_inputs, num_hiddens, bidirectional=True)
        # Both directions are concatenated, hence the factor of 2.
        self.linear = nn.Linear(2 * num_hiddens, num_inputs)

    def forward(self, X, H=None):
        """Run the model over a batch of index sequences.

        Args:
            X: Integer tensor of token indices, ``[batch_size, num_steps]``.
            H: Initial hidden state,
                ``[num_directions * num_layers, batch_size, num_hiddens]``
                (i.e. ``[2, batch_size, num_hiddens]`` here). May be omitted
                or ``None``, in which case ``nn.RNN`` starts from an all-zero
                state on the input's device.

        Returns:
            A pair ``(Y, H)`` where ``Y`` holds the logits flattened in
            time-major order, ``[num_steps * batch_size, num_inputs]``, and
            ``H`` is the final hidden state with the same layout as the input
            state.
        """
        # [batch_size, num_steps] -> [num_steps, batch_size, num_inputs]
        X = F.one_hot(X.T, self.num_inputs).type(torch.float32)
        # Y: [num_steps, batch_size, num_directions * num_hiddens]
        Y, H = self.rnn(X, H)
        # Merge the time and batch axes so the linear head (and a flat
        # target vector) can be applied directly.
        Y = Y.reshape(-1, 2 * self.num_hiddens)
        Y = self.linear(Y)
        return Y, H

# 训练

In [None]:
def grad_clipping(net, theta):
    """Rescale gradients in place so their global L2 norm is at most ``theta``.

    Args:
        net: Either an ``nn.Module`` (its parameters with
            ``requires_grad=True`` are used) or any object exposing a
            ``params`` iterable of tensors.
        theta: Maximum allowed L2 norm of all gradients taken together.

    Returns:
        None. Gradients are modified in place; nothing happens when the norm
        is already within ``theta`` or when no parameter has a gradient.
    """
    if isinstance(net, nn.Module):
        params = [p for p in net.parameters() if p.requires_grad]
    else:
        params = net.params
    # A trainable parameter that took no part in the last backward pass has
    # ``grad is None``; it contributes nothing to the norm and must be skipped.
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return
    norm = math.sqrt(sum(torch.sum(g ** 2) for g in grads))
    if norm > theta:
        scale = theta / norm
        for g in grads:
            # mul_ also works for 0-dim gradients, unlike slice assignment.
            g.mul_(scale)

In [None]:
def train(model, train_iter, loss_fn, optimizer, num_epochs, device):
    """Train ``model`` and report per-epoch perplexity and throughput.

    Each mini-batch starts from a fresh all-zero hidden state, gradients are
    clipped to a global L2 norm of 1 before the optimizer step, and after
    every epoch the perplexity is plotted and printed.

    Args:
        model: Module called as ``model(X, H)`` that returns
            ``(logits, state)`` and exposes ``num_hiddens``.
        train_iter: Iterable yielding ``(X, Y)`` index tensors, each
            ``[batch_size, num_steps]``.
        loss_fn: Callable ``loss_fn(Y_hat, Y)`` returning a scalar loss.
            The perplexity computation assumes it is the mean loss per
            token -- confirm if a different reduction is used.
        optimizer: Optimizer updating ``model``'s parameters.
        num_epochs: Number of passes over ``train_iter``.
        device: Device on which the model and data are placed.
    """
    metric = d2l.Accumulator(2)  # [summed loss, number of tokens]
    timer = d2l.Timer()
    # Axis labels are plain strings; passing lists would be rendered as
    # their repr (e.g. "['epoch']") on the plot.
    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',
                            legend=['train'], xlim=[1, num_epochs])
    model.to(device)
    model.train()
    for epoch in range(num_epochs):
        metric.reset()
        timer.start()
        for X, Y in train_iter:
            # X: [batch_size, num_steps]
            # Y: [batch_size, num_steps]
            X, Y = X.to(device), Y.to(device)
            # Flatten targets in time-major order to line up with the
            # model's [num_steps * batch_size, ...] output.
            Y = Y.T.reshape(-1)
            # Two directions, one layer -> leading dimension of 2.
            H = torch.zeros((2, len(X), model.num_hiddens), device=device)
            Y_hat, _ = model(X, H)
            loss = loss_fn(Y_hat, Y)
            optimizer.zero_grad()
            loss.backward()
            grad_clipping(model, 1)
            optimizer.step()
            # .item() hands the accumulator a plain Python float instead of
            # a tensor that is still attached to the autograd graph.
            metric.add(loss.item() * Y.numel(), Y.numel())
        perplexity = math.exp(metric[0] / metric[1])
        animator.add(epoch + 1, perplexity)
        print(f'perplexity: {perplexity}, '
              f'speed: {metric[1] / timer.stop()} token(s)/sec')


In [None]:
# Optimisation settings: learning rate and number of training epochs.
lr, num_epochs = 1, 500

# The vocabulary size is both the one-hot input width and the logit width.
model = BidirectionalRNNModel(len(vocab), 256)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)

train(model=model,
      train_iter=train_iter,
      loss_fn=loss_fn,
      optimizer=optimizer,
      num_epochs=num_epochs,
      device=d2l.try_gpu())

# 推理

In [None]:
def predict(model, prefix, num_preds, vocab, device):
    """Greedily generate ``num_preds`` tokens that continue ``prefix``.

    The prefix is fed one token at a time to warm up the hidden state; each
    generated token is the arg-max of the latest logits and is fed back as
    the next input.

    Args:
        model: Module called as ``model(X, H)`` with ``X`` of shape ``[1, 1]``
            that returns ``(logits, state)`` and exposes ``num_hiddens``.
            It is moved to ``device`` and left in eval mode.
        prefix: Non-empty sequence of tokens (e.g. a string of characters).
        num_preds: Number of tokens to generate.
        vocab: Object mapping a list of tokens to indices via ``vocab[...]``
            and indices back to tokens via ``vocab.idx_to_token``.
        device: Device on which inference runs.

    Returns:
        A list with exactly ``num_preds`` generated tokens (the prefix itself
        is not included).

    Raises:
        ValueError: If ``prefix`` is empty, since there would be no logits to
            start generating from.
    """
    tokens = list(prefix)
    if not tokens:
        raise ValueError('prefix must contain at least one token')
    model.to(device)
    model.eval()
    prefix_ids = vocab[tokens]
    H = torch.zeros((2, 1, model.num_hiddens), device=device)
    pred = []
    # Inference only: without no_grad the autograd graph would keep growing
    # through H on every step.
    with torch.no_grad():
        for idx in prefix_ids:
            X = torch.tensor([[idx]], device=device)
            Y_hat, H = model(X, H)
        for step in range(num_preds):
            if step > 0:
                # Feed the previously generated token back in.
                Y_hat, H = model(next_token, H)
            next_token = torch.argmax(Y_hat, dim=1, keepdim=True)  # [1, 1]
            pred.append(vocab.idx_to_token[int(next_token)])
    return pred


In [None]:
# Seed text whose continuation the trained model is asked to generate.
prefix = 'time '
pred = ''.join(
    predict(model=model, prefix=prefix, num_preds=50, vocab=vocab,
            device=d2l.try_gpu())
)
# Show the seed followed by the generated continuation.
print(f'{prefix}{pred}')