In [1]:
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
import math

In [2]:
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

In [11]:
def get_params(vocab_size, hidden_states):
    
    # input gate
    W_xi = torch.zeros((vocab_size, hidden_states), requires_grad=True)
    W_hi = torch.zeros((hidden_states, hidden_states), requires_grad=True)
    b_i = torch.zeros((hidden_states, ), requires_grad=True)
    
    # forget gate
    W_xf = torch.zeros((vocab_size, hidden_states), requires_grad=True)
    W_hf = torch.zeros((hidden_states, hidden_states), requires_grad=True)
    b_f = torch.zeros((hidden_states, ), requires_grad=True)
    
    # output gate
    W_xo = torch.zeros((vocab_size, hidden_states), requires_grad=True)
    W_ho = torch.zeros((hidden_states, hidden_states), requires_grad=True)
    b_o = torch.zeros((hidden_states, ), requires_grad=True)
    
    # candidate memory 
    W_xc = torch.zeros((vocab_size, hidden_states), requires_grad=True)
    W_hc = torch.zeros((hidden_states, hidden_states), requires_grad=True)
    b_c = torch.zeros((hidden_states, ), requires_grad=True)
    
    W_hq = torch.zeros((hidden_states, vocab_size), requires_grad=True)
    b_q = torch.zeros((vocab_size, ), requires_grad=True)
    
    torch.nn.init.normal_(W_xi, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_hi, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_xf, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_hf, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_xo, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_ho, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_xc, mean=0.0, std=0.01)
    torch.nn.init.normal_(W_hc, mean=0.0, std=0.01)
    
    return [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q]
# get_params(15, 3)

# b_q = torch.normal(mean=0.1, std=0.5)
# b_q

In [12]:
def init_lstm_state(batch_size, hidden_state):
    return (torch.zeros((batch_size, hidden_state)), torch.zeros((batch_size, hidden_state)))

In [13]:
def lstm(inputs, state, params):
    W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c, W_hq, b_q = params
    
    H, C_t = state
    # inputs: (num_steps x batch_size x vocab_size)
    # H: (n x hidden_size)
    # X: (batch_size, vocab_size)
    outputs = []
    for X in inputs:
#         print(X.shape, W_xh.shape, H.shape, W_hh.shape, b_h.shape)

        # broadcast sum
        I_t = torch.sigmoid(torch.matmul(X, W_xi) + torch.matmul(H, W_hi) + b_i)
        F_t = torch.sigmoid(torch.matmul(X, W_xf) + torch.matmul(H, W_hf) + b_f)
        O_t = torch.sigmoid(torch.matmul(X, W_xo) + torch.matmul(H, W_ho) + b_o)
        
        C_t_candidate = torch.tanh(torch.mm(X, W_xc) + torch.mm(H, W_hc) + b_c)
        C_t = F_t*C_t + I_t*C_t_candidate
    
        H = O_t * torch.tanh(C_t)
        
        Y = torch.mm(H, W_hq) + b_q
        outputs.append(Y)
    # return shape: (n x vocab_size)
    return torch.cat(outputs, dim=0), (H, C_t)

In [14]:
lstm(torch.rand((3, 2, 10)), init_lstm_state(2, 5), get_params(10, 5))

(tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], grad_fn=<CatBackward>),
 (tensor([[ 0.0012, -0.0029, -0.0056, -0.0029, -0.0057],
          [ 0.0004, -0.0028, -0.0093, -0.0092, -0.0082]], grad_fn=<MulBackward0>),
  tensor([[ 0.0024, -0.0058, -0.0110, -0.0058, -0.0113],
          [ 0.0008, -0.0056, -0.0185, -0.0182, -0.0166]], grad_fn=<AddBackward0>)))

In [15]:
class LSTM:
    def __init__(self, vocab_size, num_hiddens, get_params,
             init_state, forward_fn):
        self.vocab_size  = vocab_size 
        self.num_hiddens  = num_hiddens 
        self.params = get_params(vocab_size, num_hiddens)
        self.init_state  = init_state 
        self.forward_fn = forward_fn
    def __call__(self, X, state):
        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)
        return self.forward_fn(X, state, self.params)
    def begin_state(self, batch_size):
        return self.init_state(batch_size, self.num_hiddens)

In [16]:
net = LSTM(len(vocab), 128, get_params, init_lstm_state, lstm)

X = torch.arange(10).reshape((2, 5))
state = net.begin_state(X.shape[0])
Y, new_state = net(X, state)

Y.argmax(dim=1).shape, new_state[0].shape

(torch.Size([10]), torch.Size([2, 128]))

In [17]:
def predict(prefix, num_preds, net, vocab):
    prefix_id = [vocab.token_to_idx[char] for char in prefix]
    state = net.begin_state(batch_size=1)
    for id in prefix_id:
        _, state = net(torch.tensor(id).reshape((1, 1)), state)
    
    outputs = [prefix_id[-1]]
    for _ in range(num_preds):
        output, state = net(torch.tensor(outputs[-1]).reshape((1, 1)), state)
        outputs.append(int(torch.argmax(output, dim=1)))
    
    return ''.join(vocab.idx_to_token[idx] for idx in outputs[1:])

In [25]:
prefix = "time traveller"
predict(prefix, 10, net, vocab)

'<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>'

In [26]:
def train_epoch(net, loss, updater, train_iter, use_random_iter):
    
    state = None
    total_loss = 0
    total_exp = 0
    
    for X, y in train_iter:
        if not use_random_iter:
            if not state:
                state = net.begin_state(X.shape[0])
            else:
                for s in state:
                    # remove gradient
                    s.detach_()
            y_pred, state = net(X, state)
        else:
            state = net.begin_state(X.shape[0])
            y_pred, _ = net(X, state)    
        
        y = y.T.reshape(-1)
        l = loss(y_pred, y).mean()
        
        updater.zero_grad()
        l.backward()
#         gradient_clipping(net, theta=1)
        updater.step()
        
#         print(f'Loss {l.sum():.2f}')
        
        total_loss += l*y.numel()
        total_exp += y.numel()
        
    return torch.exp(total_loss/total_exp)
        

In [28]:
def train(train_iter, vocab, lr, num_epochs, hidden_size=256):
    
    net = LSTM(len(vocab), hidden_size, get_params, init_lstm_state, lstm)
    loss = torch.nn.CrossEntropyLoss()
    updater = torch.optim.SGD(net.params, lr=lr)
    
    
    for epoch in range(num_epochs):
        epoch_loss = train_epoch(net, loss, updater, train_iter, use_random_iter=True)
        
        print(f'Epoch {epoch}, Loss {epoch_loss}')
        
        if ((epoch + 1) % 100 == 0):
            print(f'time traveller {predict("time traveller", 50, net, vocab)}')
        

In [29]:
train(train_iter, vocab, 1, 500)

Epoch 0, Loss 25.035240173339844
Epoch 1, Loss 21.220958709716797
Epoch 2, Loss 19.77048110961914
Epoch 3, Loss 19.077350616455078
Epoch 4, Loss 18.683561325073242
Epoch 5, Loss 18.43525505065918
Epoch 6, Loss 18.271305084228516
Epoch 7, Loss 18.12274742126465
Epoch 8, Loss 18.0427303314209
Epoch 9, Loss 17.968460083007812
Epoch 10, Loss 17.914621353149414
Epoch 11, Loss 17.862112045288086
Epoch 12, Loss 17.819320678710938
Epoch 13, Loss 17.77012825012207
Epoch 14, Loss 17.721206665039062
Epoch 15, Loss 17.730560302734375
Epoch 16, Loss 17.684907913208008
Epoch 17, Loss 17.65266990661621
Epoch 18, Loss 17.631940841674805
Epoch 19, Loss 17.62339210510254
Epoch 20, Loss 17.56648063659668
Epoch 21, Loss 17.534595489501953
Epoch 22, Loss 17.49709701538086
Epoch 23, Loss 17.470632553100586
Epoch 24, Loss 17.425193786621094
Epoch 25, Loss 17.38326644897461
Epoch 26, Loss 17.31449317932129
Epoch 27, Loss 17.293567657470703
Epoch 28, Loss 17.209440231323242
Epoch 29, Loss 17.129117965698242
Ep