In [1]:
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

In [2]:
# print(''.join(vocab.to_tokens(list(y.reshape(-1)))))
# print(''.join(vocab.to_tokens(list(y[1]))))

In [3]:
for X, y in train_iter:
    print(y, y.T.shape)
    print(y.reshape(-1), y.T.shape)
    print(y.T.reshape(-1), y.T.shape)
    break

tensor([[ 4, 15,  9,  ..., 22,  2, 12],
        [ 1, 21, 14,  ...,  5,  6,  1],
        [ 6,  1, 16,  ..., 19,  4, 11],
        ...,
        [ 2, 10,  1,  ...,  7, 10,  2],
        [ 5,  6,  2,  ...,  2, 24, 20],
        [ 8, 13,  3,  ...,  3,  1, 11]]) torch.Size([35, 32])
tensor([ 4, 15,  9,  ...,  3,  1, 11]) torch.Size([35, 32])
tensor([ 4,  1,  6,  ...,  2, 20, 11]) torch.Size([35, 32])


In [4]:
def get_params(vocab_size, hidden_states):
    W_xh = torch.zeros((vocab_size, hidden_states), requires_grad=True)
    b_h = torch.zeros((hidden_states, ), requires_grad=True)
    W_hh = torch.zeros((hidden_states, hidden_states), requires_grad=True)
    
    W_hq = torch.zeros((hidden_states, vocab_size), requires_grad=True)
    b_q = torch.zeros((vocab_size, ), requires_grad=True)
    
    
    torch.nn.init.normal_(W_xh, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_hh, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_hq, mean=0.0, std=0.01)
    
    return [W_xh, b_h, W_hh, W_hq, b_q]
# get_params(15, 3)

# b_q = torch.normal(mean=0.1, std=0.5)
# b_q

In [5]:
def init_rnn_state(batch_size, hidden_state):
    return (torch.zeros((batch_size, hidden_state)), )

In [6]:
def rnn(inputs, state, params):
    W_xh, b_h, W_hh, W_hq, b_q = params
    
    H, = state
    # inputs: (num_steps x batch_size x vocab_size)
    # H: (n x hidden_size)
    # X: (batch_size, vocab_size)
    outputs = []
    for X in inputs:
#         print(X.shape, W_xh.shape, H.shape, W_hh.shape, b_h.shape)
        H = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(H, W_hh) + b_h)
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    # return shape: (n x vocab_size)
    return torch.cat(outputs, dim=0), (H, )

In [7]:
rnn(torch.rand((3, 2, 10)), init_rnn_state(2, 5), get_params(10, 5))

(tensor([[-7.3385e-05, -1.0246e-04, -2.8041e-04,  2.8724e-04, -2.9519e-04,
           2.1692e-05,  2.5532e-04, -7.7865e-05,  1.3710e-05,  2.0447e-05],
         [ 2.7390e-05,  1.2086e-05, -1.1765e-04,  1.4167e-04, -1.2727e-04,
          -7.3450e-05,  1.7299e-04, -7.4599e-05,  1.5259e-04, -6.3998e-05],
         [ 3.5608e-05, -2.3519e-04,  5.1493e-04,  3.9119e-04,  1.5344e-04,
          -2.0193e-04, -1.3542e-04, -1.3988e-04,  1.5888e-04,  4.9195e-05],
         [ 1.8829e-04, -4.0616e-04,  2.6282e-04,  5.0776e-04, -1.1703e-04,
          -2.7430e-04,  1.5964e-04, -1.4259e-04,  3.7970e-04, -1.9757e-05],
         [ 2.6922e-04,  3.2929e-04,  7.6923e-05, -4.3173e-04,  1.9798e-04,
          -2.4807e-04,  1.0968e-05,  3.8460e-07,  4.5015e-04, -1.3922e-04],
         [ 2.7545e-04, -1.5917e-04,  4.4537e-05, -1.7682e-04, -1.5242e-05,
          -1.5002e-04,  6.1948e-05,  7.3683e-05,  2.8907e-04,  1.2858e-05]],
        grad_fn=<CatBackward>),
 (tensor([[-0.0018,  0.0142,  0.0332, -0.0040, -0.0084],
    

In [8]:
class RNN:
    def __init__(self, vocab_size, num_hiddens, get_params,
             init_state, forward_fn):
        self.vocab_size  = vocab_size 
        self.num_hiddens  = num_hiddens 
        self.params = get_params(vocab_size, num_hiddens)
        self.init_state  = init_state 
        self.forward_fn = forward_fn
    def __call__(self, X, state):
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_fn(X, state, self.params)
    def begin_state(self, batch_size):
        return self.init_state(batch_size, self.num_hiddens)

In [9]:
net = RNN(len(vocab), 128, get_params, init_rnn_state, rnn)

In [10]:
X = torch.arange(10).reshape((2, 5))
state = net.begin_state(X.shape[0])
Y, new_state = net(X, state)

Y.argmax(dim=1).shape, new_state[0].shape

(torch.Size([10]), torch.Size([2, 128]))

In [11]:
net.params[0].grad

In [12]:
def gradient_clipping(net, theta):
    for param in net.params:
        if (param.requires_grad):
            norm = torch.sqrt(sum(torch.sum((p.grad**2)) for p in net.params))
            if (norm > theta):
                param.grad = (theta/norm)*param.grad

In [13]:
prefix = "xin chao"
torch.tensor([vocab.token_to_idx[char] for char in prefix]).reshape((-1, 1))
# vocab.idx_to_token[[24, 25 ,26 ]]

tensor([[24],
        [ 5],
        [ 6],
        [ 1],
        [15],
        [ 9],
        [ 4],
        [ 7]])

In [19]:
def predict(prefix, num_preds, net, vocab):
    prefix_id = [vocab.token_to_idx[char] for char in prefix]
    state = net.begin_state(batch_size=1)
    for id in prefix_id:
        _, state = net(torch.tensor(id).reshape((1, 1)), state)
    
    outputs = [prefix_id[-1]]
    for _ in range(num_preds):
        output, state = net(torch.tensor(outputs[-1]).reshape((1, 1)), state)
        outputs.append(int(torch.argmax(output, dim=1)))
    
    return ''.join(vocab.idx_to_token[idx] for idx in outputs[1:])

In [15]:
predict(prefix, 10, net, vocab)

'xin chaobaadzenkii'

In [16]:
def train_epoch(net, loss, updater, train_iter, use_random_iter):
    
    state = None
    total_loss = 0
    total_exp = 0
    
    for X, y in train_iter:
        if not use_random_iter:
            if not state:
                state = net.begin_state(X.shape[0])
            else:
                for s in state:
                    # remove gradient
                    s.detach_()
            y_pred, state = net(X, state)
        else:
            state = net.begin_state(X.shape[0])
            y_pred, _ = net(X, state)    
        
        y = y.T.reshape(-1)
        l = loss(y_pred, y).mean()
        
        updater.zero_grad()
        l.backward()
        gradient_clipping(net, theta=1)
        updater.step()
        
#         print(f'Loss {l.sum():.2f}')
        
        total_loss += l*y.numel()
        total_exp += y.numel()
        
    return torch.exp(total_loss/total_exp)
        

In [17]:
def train(train_iter, vocab, lr, num_epochs, hidden_size=512):
    
    net = RNN(len(vocab), hidden_size, get_params, init_rnn_state, rnn)
    loss = torch.nn.CrossEntropyLoss()
    updater = torch.optim.SGD(net.params, lr=lr)
    
    
    for epoch in range(num_epochs):
        epoch_loss = train_epoch(net, loss, updater, train_iter, use_random_iter=True)
        
        print(f'Epoch {epoch}, Loss {epoch_loss}')
        
        if ((epoch + 1) % 100 == 0):
            print(f'time traveller {predict("time traveller", 50, net, vocab)}')
        

In [18]:
train(train_iter, vocab, 1, 500)

Epoch 0, Loss 24.770179748535156
Epoch 1, Loss 19.81338882446289
Epoch 2, Loss 17.83038330078125
Epoch 3, Loss 17.534820556640625
Epoch 4, Loss 17.130399703979492
Epoch 5, Loss 16.602006912231445
Epoch 6, Loss 16.087284088134766
Epoch 7, Loss 15.232954978942871
Epoch 8, Loss 14.299644470214844
Epoch 9, Loss 13.546829223632812
Epoch 10, Loss 12.9856595993042
Epoch 11, Loss 12.51760482788086
Epoch 12, Loss 12.116098403930664
Epoch 13, Loss 11.668583869934082
Epoch 14, Loss 11.475079536437988
Epoch 15, Loss 11.336302757263184
Epoch 16, Loss 11.026881217956543
Epoch 17, Loss 10.890450477600098
Epoch 18, Loss 10.754717826843262
Epoch 19, Loss 10.599109649658203
Epoch 20, Loss 10.628496170043945
Epoch 21, Loss 10.409686088562012
Epoch 22, Loss 10.368217468261719
Epoch 23, Loss 10.163802146911621
Epoch 24, Loss 10.095879554748535
Epoch 25, Loss 10.051568984985352
Epoch 26, Loss 9.908416748046875
Epoch 27, Loss 9.880373001098633
Epoch 28, Loss 9.775274276733398
Epoch 29, Loss 9.606200218200684