In [1]:
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
import math

In [2]:
# Minibatch size and number of time steps per subsequence.
batch_size, num_steps = 32, 35
# d2l helper: returns a minibatch iterator over the Time Machine corpus and its vocabulary.
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

In [3]:
def get_params(vocab_size, hidden_states):
    """Create the trainable parameters of a from-scratch GRU.

    Weight matrices are drawn from N(0, 0.01**2); biases start at zero.
    Every tensor is a leaf with ``requires_grad=True``.

    Returns a list in the fixed order that ``gru`` unpacks:
    [W_xh, b_h, W_hh, W_hq, b_q, W_xr, W_hr, b_r, W_xz, W_hz, b_z]
    """
    def leaf(*shape):
        # zero-filled trainable tensor; weights are re-filled randomly below
        return torch.zeros(shape, requires_grad=True)

    V, H = vocab_size, hidden_states

    # reset gate
    W_xr, W_hr, b_r = leaf(V, H), leaf(H, H), leaf(H)
    # update gate
    W_xz, W_hz, b_z = leaf(V, H), leaf(H, H), leaf(H)
    # candidate hidden state
    W_xh, W_hh, b_h = leaf(V, H), leaf(H, H), leaf(H)
    # output layer
    W_hq, b_q = leaf(H, V), leaf(V)

    # Fill the weights in a fixed order so a given torch seed always
    # yields the same parameters.
    for weight in (W_xh, W_hh, W_hq, W_xr, W_hr, W_xz, W_hz):
        torch.nn.init.normal_(weight, mean=0.0, std=0.01)

    return [W_xh, b_h, W_hh, W_hq, b_q, W_xr, W_hr, b_r, W_xz, W_hz, b_z]

In [6]:
def init_gru_state(batch_size, hidden_state):
    """Return the initial GRU state: a 1-tuple holding an all-zero
    (batch_size, hidden_state) hidden-state tensor."""
    zero_hidden = torch.zeros(batch_size, hidden_state)
    return (zero_hidden,)

In [4]:
def gru(inputs, state, params):
    """Run a from-scratch GRU over a time-major input sequence.

    Args:
        inputs: iterable over time steps (num_steps x batch_size x vocab_size);
            each step X is a (batch_size, vocab_size) tensor.
        state: 1-tuple ``(H,)`` holding the (batch_size, num_hiddens) hidden state.
        params: list in the order returned by ``get_params``.

    Returns:
        ``(Y, (H,))`` where Y concatenates the per-step output logits along
        dim 0, i.e. shape (num_steps * batch_size, vocab_size), and H is the
        final hidden state.
    """
    (W_xh, b_h, W_hh, W_hq, b_q,
     W_xr, W_hr, b_r, W_xz, W_hz, b_z) = params
    (H,) = state

    step_logits = []
    for X in inputs:
        # reset gate: how much of the previous H feeds the candidate state
        reset = torch.sigmoid(X @ W_xr + H @ W_hr + b_r)
        # update gate: how much of the previous H is carried over unchanged
        update = torch.sigmoid(X @ W_xz + H @ W_hz + b_z)
        candidate = torch.tanh(X @ W_xh + (reset * H) @ W_hh + b_h)
        H = update * H + (1 - update) * candidate
        step_logits.append(H @ W_hq + b_q)

    return torch.cat(step_logits, dim=0), (H,)

In [7]:
# Smoke test: 3 time steps, batch of 2, 10 input/output features, 5 hidden units (random, not one-hot, inputs).
gru(torch.rand((3, 2, 10)), init_gru_state(2, 5), get_params(10, 5))

(tensor([[-1.1012e-04, -3.7426e-05, -1.8630e-04, -4.2762e-04, -1.2853e-04,
           2.9462e-04, -6.3448e-05,  2.5036e-05, -5.3788e-05,  2.7107e-04],
         [-1.0510e-04,  6.2361e-05, -1.9892e-04, -1.9748e-04,  6.5237e-05,
           2.2109e-04, -3.4991e-05, -7.8285e-06,  9.7819e-06,  1.9195e-04],
         [-1.7534e-04, -1.4133e-04, -3.5978e-04, -5.9386e-04, -1.5025e-04,
           4.5624e-04, -2.2088e-05,  2.5866e-05, -1.6248e-04,  4.1215e-04],
         [ 2.9526e-05,  1.1608e-04,  1.9812e-04, -5.4853e-04, -6.3617e-05,
           4.5388e-05, -2.7150e-04, -5.9151e-05,  1.4461e-04,  3.8531e-04],
         [-2.8655e-04, -1.4759e-04, -6.0448e-04, -6.8057e-04, -1.7059e-04,
           6.7199e-04,  2.5161e-05,  5.3529e-05, -2.2380e-04,  4.6706e-04],
         [ 1.7052e-05,  7.7710e-05,  2.6110e-04, -7.9264e-04, -1.2314e-04,
           1.7722e-04, -3.9266e-04,  1.7556e-05,  2.0575e-04,  5.7626e-04]],
        grad_fn=<CatBackward>),
 (tensor([[ 0.0073, -0.0038, -0.0488, -0.0079, -0.0163],
    

In [8]:
class GRU:
    """Minimal from-scratch recurrent model wrapper.

    Holds the parameter list built by ``get_params`` and delegates the
    recurrence to ``forward_fn``; token indices are one-hot encoded here.
    """

    def __init__(self, vocab_size, num_hiddens, get_params,
                 init_state, forward_fn):
        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens
        self.init_state, self.forward_fn = init_state, forward_fn
        self.params = get_params(vocab_size, num_hiddens)

    def __call__(self, X, state):
        """One-hot encode (batch_size, num_steps) indices time-major and run forward_fn."""
        # Transpose first so the leading axis is time:
        # (num_steps, batch_size, vocab_size) after one-hot encoding.
        time_major = X.T
        one_hot = F.one_hot(time_major, self.vocab_size).to(torch.float32)
        return self.forward_fn(one_hot, state, self.params)

    def begin_state(self, batch_size):
        """Return the initial recurrent state for ``batch_size`` sequences."""
        return self.init_state(batch_size, self.num_hiddens)

In [9]:
# Untrained model with 128 hidden units; check output/state shapes on a toy batch.
net = GRU(len(vocab), 128, get_params, init_gru_state, gru)

# 2 sequences x 5 time steps of token indices 0..9.
X = torch.arange(10).reshape((2, 5))
state = net.begin_state(X.shape[0])
Y, new_state = net(X, state)

# Expect one prediction per (time step, sequence) pair and a (batch, hidden) state.
Y.argmax(dim=1).shape, new_state[0].shape

(torch.Size([10]), torch.Size([2, 128]))

In [10]:
def predict(prefix, num_preds, net, vocab):
    """Greedily generate ``num_preds`` characters that continue ``prefix``.

    The network is warmed up on every character of ``prefix`` except the
    last; the last prefix character is then the first input of the argmax
    generation loop, so each prefix character is fed exactly once.

    Args:
        prefix: non-empty seed string; each character must be a key of
            ``vocab.token_to_idx``.
        num_preds: number of characters to generate.
        net: model with ``begin_state(batch_size=...)`` and
            ``net(X, state) -> (logits, state)`` where X is a (1, 1) index tensor.
        vocab: object exposing ``token_to_idx`` (char -> index mapping) and
            ``idx_to_token`` (index -> char sequence).

    Returns:
        Only the generated characters (``prefix`` itself is not included).
    """
    prefix_ids = [vocab.token_to_idx[ch] for ch in prefix]
    with torch.no_grad():  # inference only: no autograd graph needed
        state = net.begin_state(batch_size=1)
        # Warm up on all but the last prefix char; also feeding the last one
        # here would make the model see it twice (it starts the loop below).
        for idx in prefix_ids[:-1]:
            _, state = net(torch.tensor(idx).reshape((1, 1)), state)

        outputs = [prefix_ids[-1]]
        for _ in range(num_preds):
            logits, state = net(torch.tensor(outputs[-1]).reshape((1, 1)), state)
            outputs.append(int(torch.argmax(logits, dim=1)))

    return ''.join(vocab.idx_to_token[idx] for idx in outputs[1:])

In [12]:
# Sample from the untrained model; the continuation should be gibberish at this point.
prefix = "time traveller"
predict(prefix, 10, net, vocab)

'owywyxvypr'

In [13]:
def train_epoch(net, loss, updater, train_iter, use_random_iter,
                grad_clip_theta=None):
    """Train ``net`` for one pass over ``train_iter`` and return the perplexity.

    Args:
        net: callable ``net(X, state) -> (logits, state)`` exposing
            ``begin_state(batch_size)``; logits rows must be ordered like
            ``y.T.reshape(-1)`` (time-major).
        loss: criterion taking ``(logits, targets)``, e.g. ``nn.CrossEntropyLoss``.
        updater: a ``torch.optim`` optimizer over the net's parameters.
        train_iter: iterable of ``(X, y)`` index tensors with batch as dim 0.
        use_random_iter: if True, every minibatch starts from a fresh state;
            otherwise the state is carried across consecutive minibatches and
            detached so gradients do not flow back into earlier batches.
        grad_clip_theta: optional max global L2 norm for gradient clipping;
            ``None`` (default) disables clipping.

    Returns:
        0-d tensor ``exp(token-weighted mean loss)`` for the epoch, detached
        from the autograd graph. An empty ``train_iter`` raises ZeroDivisionError.
    """
    state = None
    total_loss = 0.0
    total_tokens = 0

    for X, y in train_iter:
        if use_random_iter:
            state = net.begin_state(X.shape[0])
            y_pred, _ = net(X, state)
        else:
            if state is None:
                state = net.begin_state(X.shape[0])
            else:
                # keep the values, drop the graph from the previous batch
                for s in state:
                    s.detach_()
            y_pred, state = net(X, state)

        targets = y.T.reshape(-1)
        l = loss(y_pred, targets).mean()

        updater.zero_grad()
        l.backward()
        if grad_clip_theta is not None:
            params = [p for group in updater.param_groups for p in group['params']]
            torch.nn.utils.clip_grad_norm_(params, grad_clip_theta)
        updater.step()

        # Accumulate a plain float: summing the loss tensor itself would chain
        # every batch's autograd history onto the accumulator and make the
        # returned perplexity a grad-tracking tensor.
        n_tokens = targets.numel()
        total_loss += l.item() * n_tokens
        total_tokens += n_tokens

    return torch.exp(torch.tensor(total_loss / total_tokens))
        

In [16]:
def train(train_iter, vocab, lr, num_epochs, hidden_size=256,
          use_random_iter=True, prefix="time traveller", predict_every=100):
    """Train a from-scratch GRU language model with SGD and return it.

    Args:
        train_iter: minibatch iterator of ``(X, y)`` index tensors.
        vocab: vocabulary; ``len(vocab)`` sets the input/output size.
        lr: SGD learning rate.
        num_epochs: number of passes over ``train_iter``.
        hidden_size: number of GRU hidden units.
        use_random_iter: forwarded to ``train_epoch`` (reset state per batch).
        prefix: seed string for the periodic sample printout.
        predict_every: print a 50-char sample every this many epochs;
            a falsy value disables sampling.

    Returns:
        The trained ``GRU`` instance (previously it was discarded).
    """
    net = GRU(len(vocab), hidden_size, get_params, init_gru_state, gru)
    loss = torch.nn.CrossEntropyLoss()
    updater = torch.optim.SGD(net.params, lr=lr)

    for epoch in range(num_epochs):
        # train_epoch returns exp(mean cross-entropy), i.e. perplexity
        epoch_loss = train_epoch(net, loss, updater, train_iter,
                                 use_random_iter=use_random_iter)
        print(f'Epoch {epoch}, Loss {epoch_loss}')

        if predict_every and (epoch + 1) % predict_every == 0:
            # single source for the prefix so the label and the seed agree
            print(f'{prefix} {predict(prefix, 50, net, vocab)}')

    return net
        

In [17]:
# lr=1 for 500 epochs; prints per-epoch perplexity (labelled "Loss") and a text sample every 100 epochs.
train(train_iter, vocab, 1, 500)

Epoch 0, Loss 24.931453704833984
Epoch 1, Loss 20.81997299194336
Epoch 2, Loss 19.019668579101562
Epoch 3, Loss 18.162702560424805
Epoch 4, Loss 17.82379150390625
Epoch 5, Loss 17.642223358154297
Epoch 6, Loss 17.51961326599121
Epoch 7, Loss 17.442697525024414
Epoch 8, Loss 17.34539031982422
Epoch 9, Loss 17.23731803894043
Epoch 10, Loss 17.088499069213867
Epoch 11, Loss 16.9556884765625
Epoch 12, Loss 16.758411407470703
Epoch 13, Loss 16.57127571105957
Epoch 14, Loss 16.35809326171875
Epoch 15, Loss 16.106618881225586
Epoch 16, Loss 15.955818176269531
Epoch 17, Loss 15.981688499450684
Epoch 18, Loss 15.482518196105957
Epoch 19, Loss 15.237502098083496
Epoch 20, Loss 15.204707145690918
Epoch 21, Loss 15.067525863647461
Epoch 22, Loss 14.5711669921875
Epoch 23, Loss 14.302877426147461
Epoch 24, Loss 13.975485801696777
Epoch 25, Loss 13.69554615020752
Epoch 26, Loss 13.828149795532227
Epoch 27, Loss 13.254283905029297
Epoch 28, Loss 13.00719165802002
Epoch 29, Loss 12.726114273071289
Epo