**LSTM IN PRACTICE:**
- A fine-grained look at how an LSTM computes its outputs, ending with a custom LSTM cell that uses layer normalization

In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Seed PyTorch's global RNG so the randomly initialised weights below are reproducible.
torch.manual_seed(1)

<torch._C.Generator at 0x133139fb0>

In [2]:
# Single-layer LSTM: 5 input features -> 2 hidden units; inputs are laid out (batch, seq, feature).
lstm_layer = nn.LSTM(
    input_size=5,
    hidden_size=2,
    num_layers=1,
    batch_first=True,
)

Inspecting the LSTM's variables - weights and biases:

In [3]:
# PyTorch stacks the four gates' parameters row-wise in the order (i | f | g | o).
wi, wh = lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0  # (W_ii|W_if|W_ig|W_io), (W_hi|W_hf|W_hg|W_ho)
bi, bh = lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0      # (b_ii|b_if|b_ig|b_io), (b_hi|b_hf|b_hg|b_ho)

In [4]:
# Shapes of the stacked input/hidden weights and biases.
tuple(param.shape for param in (wi, wh, bi, bh))

(torch.Size([8, 5]), torch.Size([8, 2]), torch.Size([8]), torch.Size([8]))

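The shapes make sense: PyTorch stacks the parameters of the four gates (input, forget, candidate `g`, output) row-wise, so each weight matrix has `4 * hidden_size = 8` rows, with `input_size = 5` columns for the input weights and `hidden_size = 2` columns for the hidden-state weights; each bias vector has 8 entries.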

Inference:

In [5]:
# One sequence (batch = 1) of 3 time steps; every feature at step t equals t + 1.
x_seq = torch.tensor([[float(step)] * 5 for step in (1, 2, 3)], dtype=torch.float32)[None]
x_seq.shape

torch.Size([1, 3, 5])

In [6]:
# output: hidden state for every time step; (hn, cn): final hidden and cell states.
output, (hn, cn) = lstm_layer(x_seq)
print(*(tensor.shape for tensor in (output, hn, cn)))

torch.Size([1, 3, 2]) torch.Size([1, 1, 2]) torch.Size([1, 1, 2])


Working out the LSTM output manually:

First, let's break down the weights.

In [7]:
# Split each stacked parameter into its four per-gate pieces, in PyTorch's (i, f, g, o) order.
# chunk(4) derives the piece size (hidden_size rows) from the tensor itself instead of
# hard-coding hidden_size = 2 in the slice bounds.
wii, wif, wig, wio = wi.chunk(4, dim=0)
whi, whf, whg, who = wh.chunk(4, dim=0)
bii, bif, big, bio = bi.chunk(4)
bhi, bhf, bhg, bho = bh.chunk(4)

In [12]:
# Manual unroll of the LSTM recurrence with the per-gate weights split above.
# short_state collects h_t and long_state collects c_t, one (batch, hidden_size) tensor per step.
long_state = []
short_state = []

hidden_size = whi.shape[0]  # W_hi is (hidden_size, hidden_size)
# At t = 0 there is no previous state, so both start as zeros of shape (batch, hidden_size).
prev_h = torch.zeros((x_seq.shape[0], hidden_size))
prev_c = torch.zeros((x_seq.shape[0], hidden_size))

# Loop over the sequence length (dim 1, since batch_first): unroll through time.
for t in range(x_seq.shape[1]):
    # Input at this time step: (batch, input_size).
    xt = x_seq[:, t, :]

    # Parts of each pre-activation that depend only on the input: x_t W^T + b.
    it1 = xt @ wii.T + bii  # input gate
    ft1 = xt @ wif.T + bif  # forget gate
    gt1 = xt @ wig.T + big  # candidate signal g
    ot1 = xt @ wio.T + bio  # output gate

    # Add the recurrent parts h_{t-1} W_h^T + b_h and apply the activations.
    it = torch.sigmoid(it1 + prev_h @ whi.T + bhi)
    ft = torch.sigmoid(ft1 + prev_h @ whf.T + bhf)
    gt = torch.tanh(gt1 + prev_h @ whg.T + bhg)
    ot = torch.sigmoid(ot1 + prev_h @ who.T + bho)

    # c_t = f_t * c_{t-1} + i_t * g_t ;  h_t = o_t * tanh(c_t)
    long_state_t = ft * prev_c + it * gt
    short_state_t = ot * torch.tanh(long_state_t)

    long_state.append(long_state_t)
    short_state.append(short_state_t)
    prev_h, prev_c = short_state_t, long_state_t

    

In [13]:
# Reference final states (h_n, c_n) from nn.LSTM, each of shape (1, 1, 2) here.
(hn, cn)

(tensor([[[-0.2660,  0.6454]]], grad_fn=<StackBackward0>),
 tensor([[[-0.4218,  1.4677]]], grad_fn=<StackBackward0>))

In [14]:
# Verify the manual unroll against nn.LSTM: final states and the whole output sequence.
torch.testing.assert_close(short_state[-1], hn[0])
torch.testing.assert_close(long_state[-1], cn[0])
torch.testing.assert_close(torch.stack(short_state, dim=1), output)
short_state[-1], long_state[-1]

(tensor([[-0.2660,  0.6454]], grad_fn=<MulBackward0>),
 tensor([[-0.4218,  1.4677]], grad_fn=<AddBackward0>))

In [11]:
# Hidden state for every time step, shape (batch, seq_len, hidden_size) = (1, 3, 2).
output

tensor([[[-0.1993,  0.1615],
         [-0.2559,  0.4308],
         [-0.2660,  0.6454]]], grad_fn=<TransposeBackward0>)

The manual computation reproduces `hn`, `cn` and `output` exactly, so the (very crude) LSTM implementation above is correct.

That code is too crude, however. What if we used PyTorch's `nn.LSTMCell`, which computes a single time step?

In [15]:
# A single-step LSTM cell with the same sizes as lstm_layer (5 input features, 2 hidden units).
lstm_cell = nn.LSTMCell(input_size=5, hidden_size=2)

Initialize the cell with the LSTM layer's weights and biases:

In [17]:
# Copy the LSTM layer's parameters into the cell. copy_ under no_grad writes the values
# without autograd tracking and, unlike assigning `.data`, does not make the cell share
# storage with lstm_layer's parameters.
with torch.no_grad():
    lstm_cell.weight_ih.copy_(wi)
    lstm_cell.weight_hh.copy_(wh)
    lstm_cell.bias_ih.copy_(bi)
    lstm_cell.bias_hh.copy_(bh)

Compute the output for the sequence:

In [18]:
# Same unroll as the manual version, but each time step is delegated to nn.LSTMCell.
# short_state collects h_t and long_state collects c_t.
long_state = []
short_state = []

# Zero initial states of shape (batch, hidden_size).
prev_h = torch.zeros((x_seq.shape[0], lstm_cell.hidden_size))
prev_c = torch.zeros((x_seq.shape[0], lstm_cell.hidden_size))

for t in range(x_seq.shape[1]):
    xt = x_seq[:, t, :]
    # nn.LSTMCell takes the state as an (h, c) tuple and returns (h_t, c_t).
    short_state_t, long_state_t = lstm_cell(xt, (prev_h, prev_c))

    long_state.append(long_state_t)
    short_state.append(short_state_t)
    prev_h, prev_c = short_state_t, long_state_t

In [19]:
# Reference final states (h_n, c_n) from nn.LSTM, each of shape (1, 1, 2) here.
(hn, cn)

(tensor([[[-0.2660,  0.6454]]], grad_fn=<StackBackward0>),
 tensor([[[-0.4218,  1.4677]]], grad_fn=<StackBackward0>))

In [20]:
# Verify the nn.LSTMCell unroll against nn.LSTM: final states and the whole output sequence.
torch.testing.assert_close(short_state[-1], hn[0])
torch.testing.assert_close(long_state[-1], cn[0])
torch.testing.assert_close(torch.stack(short_state, dim=1), output)
short_state[-1], long_state[-1]

(tensor([[-0.2660,  0.6454]], grad_fn=<MulBackward0>),
 tensor([[-0.4218,  1.4677]], grad_fn=<AddBackward0>))

Using `nn.LSTMCell` is clearly more elegant and gives the same result. However, `nn.LSTMCell` offers no way to insert layer normalization before the final activation functions, so adding layer normalization requires a custom cell.

In [None]:
class CustomLSTMCell(nn.Module):
    """A from-scratch LSTM cell that computes the same step as ``nn.LSTMCell``.

    The four gates are computed with two fused linear layers (one applied to the
    input, one to the previous hidden state); their summed output is split into
    the input, forget, candidate (g) and output pre-activations, in PyTorch's
    (i | f | g | o) order.

    Args:
        input_size: number of features in the input ``x``.
        hidden_size: number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One fused projection per source instead of four per-gate layers; the
        # result is split into the four gate pre-activations in forward().
        self.input = nn.Linear(input_size, 4 * hidden_size)
        self.hidden = nn.Linear(hidden_size, 4 * hidden_size)
        self.hidden_size = hidden_size

    def forward(self, x, prev_h, prev_c):
        """Run one LSTM time step.

        Args:
            x: input with ``input_size`` features in the last dimension,
                e.g. ``(batch, input_size)`` or unbatched ``(input_size,)``.
            prev_h: previous hidden (short-term) state, ``hidden_size`` features in the last dim.
            prev_c: previous cell (long-term) state, ``hidden_size`` features in the last dim.

        Returns:
            ``(h, c)``: the new hidden state and the new cell state.
        """
        gates = self.input(x) + self.hidden(prev_h)
        # Split along the feature (last) axis: the chunk width follows hidden_size
        # instead of a hard-coded 2, and unbatched inputs work too.
        i, f, g, o = gates.chunk(4, dim=-1)
        input_gate = torch.sigmoid(i)
        forget_gate = torch.sigmoid(f)
        signal = torch.tanh(g)
        output_gate = torch.sigmoid(o)
        long_state = forget_gate * prev_c + input_gate * signal
        return output_gate * torch.tanh(long_state), long_state

In [29]:
# Same sizes as lstm_layer: 5 input features, 2 hidden units.
custom_lstm = CustomLSTMCell(input_size=5, hidden_size=2)

In [30]:
# Copy the LSTM layer's parameters into the custom cell's fused linear layers.
# copy_ under no_grad writes the values without autograd tracking and, unlike
# assigning `.data`, does not make the cell share storage with lstm_layer.
with torch.no_grad():
    custom_lstm.input.weight.copy_(wi)
    custom_lstm.input.bias.copy_(bi)
    custom_lstm.hidden.weight.copy_(wh)
    custom_lstm.hidden.bias.copy_(bh)

In [31]:
# Same unroll again, this time stepping the custom cell.
# short_state collects h_t and long_state collects c_t.
long_state = []
short_state = []

# Zero initial states of shape (batch, hidden_size); the hidden layer's
# in_features is the hidden size.
hidden_features = custom_lstm.hidden.in_features
prev_h = torch.zeros((x_seq.shape[0], hidden_features))
prev_c = torch.zeros((x_seq.shape[0], hidden_features))

for t in range(x_seq.shape[1]):
    xt = x_seq[:, t, :]
    # The custom cell takes h and c as separate arguments and returns (h_t, c_t).
    short_state_t, long_state_t = custom_lstm(xt, prev_h, prev_c)

    long_state.append(long_state_t)
    short_state.append(short_state_t)
    prev_h, prev_c = short_state_t, long_state_t

In [32]:
# Reference final states (h_n, c_n) from nn.LSTM, each of shape (1, 1, 2) here.
(hn, cn)

(tensor([[[-0.2660,  0.6454]]], grad_fn=<StackBackward0>),
 tensor([[[-0.4218,  1.4677]]], grad_fn=<StackBackward0>))

In [33]:
# Verify the custom-cell unroll against nn.LSTM: final states and the whole output sequence.
torch.testing.assert_close(short_state[-1], hn[0])
torch.testing.assert_close(long_state[-1], cn[0])
torch.testing.assert_close(torch.stack(short_state, dim=1), output)
short_state[-1], long_state[-1]

(tensor([[-0.2660,  0.6454]], grad_fn=<MulBackward0>),
 tensor([[-0.4218,  1.4677]], grad_fn=<AddBackward0>))

The custom cell reproduces `hn` and `cn` as well. Next, a custom LSTM cell with layer normalization, applied to the cell state just before the output `tanh`:

In [34]:
class CustomLSTMCell(nn.Module):
    """An LSTM cell with layer normalization on the cell state.

    Gate computation matches ``nn.LSTMCell``: two fused linear layers (input and
    previous hidden state) whose summed output is split into the input, forget,
    candidate (g) and output pre-activations, in (i | f | g | o) order. The new
    cell state is layer-normalized only where it feeds the output ``tanh``; the
    cell state that is returned (and carried to the next step) is the
    un-normalized one.

    Args:
        input_size: number of features in the input ``x``.
        hidden_size: number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One fused projection per source instead of four per-gate layers; the
        # result is split into the four gate pre-activations in forward().
        self.input = nn.Linear(input_size, 4 * hidden_size)
        self.hidden = nn.Linear(hidden_size, 4 * hidden_size)
        self.hidden_size = hidden_size
        self.ln = nn.LayerNorm(hidden_size)

    def forward(self, x, prev_h, prev_c):
        """Run one LSTM time step.

        Args:
            x: input with ``input_size`` features in the last dimension,
                e.g. ``(batch, input_size)`` or unbatched ``(input_size,)``.
            prev_h: previous hidden (short-term) state, ``hidden_size`` features in the last dim.
            prev_c: previous cell (long-term) state, ``hidden_size`` features in the last dim.

        Returns:
            ``(h, c)``: the new hidden state ``o * tanh(LayerNorm(c))`` and the
            new (un-normalized) cell state ``c``.
        """
        gates = self.input(x) + self.hidden(prev_h)
        # Split along the feature (last) axis so unbatched inputs work too.
        i, f, g, o = gates.chunk(4, dim=-1)
        input_gate = torch.sigmoid(i)
        forget_gate = torch.sigmoid(f)
        signal = torch.tanh(g)
        output_gate = torch.sigmoid(o)
        long_state = forget_gate * prev_c + input_gate * signal
        # Normalize the cell state before the final activation of the output path.
        return output_gate * torch.tanh(self.ln(long_state)), long_state