In [1]:
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l_torch

In [50]:
batch_size, num_steps = 32, 35
train_iter, vocab = d2l_torch.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

For this simple RNN models, there are steps to follow:
1. To build the model, we need:
- A function to make the params for RNN model's layers, including OutputLayer and RNNLayer.
- A function to init the state of the model, which is a tensor of size (batch_size, num_hiddens).
- A forward function to tell the net how to work with the data.
- A wrapping class to hold all params and function.
2. To train and test the model, we need:
- A train function to train model (ofcourse :P)
- A predict function to see how well model do with our own eyes.
- To measure performance, we use 'perplexity'

In [51]:
# Create layers params
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    # Hidden Layer Params
    W_xh = normal((num_inputs, num_hiddens))
    W_hh = normal((num_hiddens, num_hiddens))
    b_h = torch.zeros(num_hiddens, device=device)
    # Output Layer Params
    W_hq = normal((num_hiddens, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    # Attach gradient for params
    params = [W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params

In [4]:
# Initialize model state
def init_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )

In [5]:
def rnn_forward(inputs, state, params):
    # Unpack params
    W_xh, W_hh, b_h, W_hq, b_q = params
    outputs = []
    H, = state
    for X in inputs:
        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H, )

In [6]:
class RNN:
    def __init__(self, forward_fn, init_state, num_hiddens, vocab_size, device) -> None:
        self.params = get_params(vocab_size, num_hiddens, device=device)
        self.forward_fn = forward_fn
        self.init_state = init_state
        self.num_hiddens = num_hiddens
        self.vocab_size = vocab_size
    
    def __call__(self, X, state):
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_fn(X, state, self.params)

    def begin_state(self, batch_size, device):
        return self.init_state(batch_size, self.num_hiddens, device=device)

In [7]:
def grad_clipping(net, theta):
    params = net.params
    norm = torch.sqrt(sum(torch.sum((p.grad**2)) for p in params))
    if norm > theta:
        for param in params:
            param.grad[:] *= theta / norm

In [52]:
def train_epoch(net, train_iter, loss, updater, device):
    state = None
    metric = d2l_torch.Accumulator(2)
    for X, Y in train_iter:
        if state is None:
            state = net.begin_state(batch_size=X.shape[0], device=device)
        else:
            for s in state:
              s.detach_()
        y = Y.T.reshape(-1)
        X, y = X.to(device), y.to(device)
        Y_hat, state = net(X, state)
        l = loss(Y_hat, y.long()).mean()
        l.backward()
        grad_clipping(net, 1)
        updater(batch_size=1)
        metric.add(l*y.numel(), y.numel())
    return math.exp(metric[0]/metric[1])

In [57]:
def train_rnn(net, train_iter, num_epochs, learning_rate):
    loss = nn.CrossEntropyLoss()
    updater = lambda batch_size: d2l_torch.sgd(net.params, learning_rate, batch_size)
    best_perplexity = 99999
    patience = 5 # early stopping to prevent overshooting to overfitting zone
    for epoch in range(num_epochs):
        perplexity = train_epoch(net, train_iter, loss, updater, d2l_torch.try_gpu())
        if perplexity < best_perplexity:
            best_perplexity = perplexity
            patience = 5 # reset patience if manage to reduce perplexity
        else:
            if patience < 0:
                break
            patience -= 1
        if(epoch % 50 == 0):
            print(f"Epoch: {epoch:d}/{num_epochs:d}| Perplexity: {perplexity:.2f}")
    print(f"Perplexity: {perplexity:.2f}")

In [10]:
def predict(net, prefix, num_preds, device):
    state = net.begin_state(batch_size=1, device=device)
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))
    for y in prefix[1:]:  # Warm-up period
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):  # Predict `num_preds` steps
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

In [54]:
num_epochs, lr = 1500, 0.1
num_hiddens = 512

In [58]:
net = RNN(forward_fn=rnn_forward, 
            init_state=init_state, 
            vocab_size=len(vocab), 
            num_hiddens=num_hiddens, 
            device=d2l_torch.try_gpu())

In [59]:
train_rnn(net, train_iter=train_iter, num_epochs=num_epochs, learning_rate=lr)

Epoch: 0/1500| Perplexity: 27.62
Epoch: 50/1500| Perplexity: 16.57
Epoch: 100/1500| Perplexity: 13.01
Epoch: 150/1500| Perplexity: 11.02
Epoch: 200/1500| Perplexity: 10.24
Epoch: 250/1500| Perplexity: 9.66
Epoch: 300/1500| Perplexity: 9.27
Perplexity: 9.17


In [14]:
predict(net, "time traveller ", 20, d2l_torch.try_gpu())

'time traveller and the the the the '

After altering hyperparams, the result is:
- Scaling layer params help stablelizing the training process.
    - Extreme scaling (too big/too small of the multiplier) can be compensated by adjusting learning rate, but will still resulted in unstable training.
    - For this example, scaled by 0.01 displays best perplexity.