# 8.5. Implementation of Recurrent Neural Networks from Scratch

In [20]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

In [151]:
import logging
logging.getLogger().setLevel(logging.DEBUG)

In [135]:
def log_vars(var_list, var_names):
    if type(var_names) == str:
        var_names = var_names.split(", ")
    for var, var_name in zip(var_list, var_names):
        if type(var) == torch.Tensor:
            logging.info(f"{var_name}: {var.shape}")

In [136]:
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

Gives a list of the parameters

[W_xh, W_hh, b_h, W_hq, b_q] 

initialized appropriately. W's are initialized to Normal with std=0.01

In [177]:
def get_params(vocab_size, num_hidden, device):
    "Initialize and return a list of model parameters"
    num_inputs = num_outputs = vocab_size
    
    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01
    
    # Hidden Layer Parameters
    W_xh = normal((num_inputs, num_hidden))
    W_hh = normal((num_hidden, num_hidden))
    b_h = torch.zeros(num_hidden, device=device)
    
    W_hq = normal((num_hidden, num_outputs))
    b_q = torch.zeros(num_outputs, device=device)
    
    params = [W_xh, W_hh, b_h, W_hq, b_q]    
    for param in params:
        param.requires_grad_(True)
    
    return params    

Returns an initial state as tuple. 
First element is a tensor of zeros of size (batch_size, num_hidden)

(H_0, ) = init_rnn_state

In [161]:
def init_rnn_state(batch_size, num_hiddens, device):
    "Get initial hidden states of (batch_size, num_hidden)"
    return (torch.zeros((batch_size, num_hiddens), device=device),)

In [162]:
def rnn(inputs, state, params):
    "Applies the recursive RNN function to the batched input"
    W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        H = torch.tanh( 
                torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)

In [181]:
num_hiddens = 512
device = torch.device("cpu")

X = torch.arange(10).reshape((2, 5))
X = F.one_hot(X.T, 28)
log_vars([X], "X")
params = get_params(vocab_size=28, num_hidden=10, device=device)
log_vars(params, "W_xh, W_hh, b_h, W_hq, b_q")
H = init_rnn_state(X.shape[0], num_hiddens=10, device=device)
Y, new_state = rnn(X, H, params)

INFO:root:X: torch.Size([5, 2, 28])
INFO:root:W_xh: torch.Size([28, 10])
INFO:root:W_hh: torch.Size([10, 10])
INFO:root:b_h: torch.Size([10])
INFO:root:W_hq: torch.Size([10, 28])
INFO:root:b_q: torch.Size([28])


RuntimeError: expected scalar type Long but found Float

In [153]:
class RNNModelScratch:
    def __init__(self, vocab_size, num_hiddens, device, 
                 get_params, init_state, forward_fn):
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.params = get_params(vocab_size, num_hiddens, device)
        self.init_state, self.forward_fn = init_state, forward_fn
        
    def __call__(self, X, state):
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_fn(X, state, self.params)
    
    def begin_state(self, batch_size, device):
        return self.init_state(batch_size, self.num_hiddens, device)

In [154]:
net = RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                      init_rnn_state, rnn)


INFO:root:W_xh: torch.Size([28, 512])
INFO:root:W_hh: torch.Size([512, 512])
INFO:root:b_h: torch.Size([512])
INFO:root:W_hq: torch.Size([512, 28])
INFO:root:b_q: torch.Size([28])


In [145]:
X = torch.arange(10).reshape((2, 5))
F.one_hot(X.T, 28).shape

torch.Size([5, 2, 28])

In [150]:
state = net.begin_state(batch_size=X.shape[0], device=device)
Y, new_state = net(X.to(device), state)

In [148]:
Y.shape, len(new_state), new_state[0].shape

(torch.Size([10, 28]), 1, torch.Size([2, 512]))

### Prediction Function

In [159]:
def predict_ch8(prefix, num_preds, net, vocab, device):  #@save
    """Generate new characters following the `prefix`."""
    
    log_vars([prefix, num_preds, net, vocab, device], 
             "prefix, num_preds, net, vocab, device".split(", "))
    
    state = net.begin_state(batch_size=1, device=device)
    log_vars([state[0]], "state")
    
    outputs = [vocab[prefix[0]]]
    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape(
        (1, 1))
    for y in prefix[1:]:  # Warm-up period
        _, state = net(get_input(), state)
        outputs.append(vocab[y])
    for _ in range(num_preds):  # Predict `num_preds` steps
        y, state = net(get_input(), state)
        outputs.append(int(y.argmax(dim=1).reshape(1)))
    return ''.join([vocab.idx_to_token[i] for i in outputs])

In [160]:
predict_ch8('time traveller ', 10, net, vocab, device)

INFO:root:state: torch.Size([1, 512])


'time traveller klitbcslit'