## This notebook shows the inner workings of PyTorch's LSTM module

## Some equations that will come in handy

<img src="../pics/lstm_equs.png">

In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Seed the RNG before building the LSTM so its randomly initialised
# weights (and therefore every number printed below) are reproducible.
torch.manual_seed(1)

n_in = 3   # features per time step (LSTM input_size)
n_out = 5  # hidden units (LSTM hidden_size)

# Toy input written directly in (seq_len, batch_size, input_size) layout:
# 2 time steps ("words"), 1 sequence in the batch, 3 features per time step.
inpt = torch.tensor([[[1.0, 2.0, 3.0]],
                     [[4.0, 5.0, 6.0]]], dtype=torch.float32)
print(f"inpt: {inpt}, inpt shape: {inpt.shape}")

lstm = nn.LSTM(input_size=n_in, hidden_size=n_out)

# outs holds the hidden state of every time step; (ht, ct) are the hidden
# and cell state after the last step. No initial state is passed here.
outs, (ht, ct) = lstm(inpt)

print(f'outs shape: {outs.shape} ht shape: {ht.shape} ct shape: {ct.shape}')

inpt: tensor([[[1., 2., 3.]],

        [[4., 5., 6.]]]), inpt shape: torch.Size([2, 1, 3])
outs shape: torch.Size([2, 1, 5]) ht shape: torch.Size([1, 1, 5]) ct shape: torch.Size([1, 1, 5])


In [3]:
# Each LSTM parameter tensor stacks the four gates along dim 0 in the order
# input (i), forget (f), cell candidate (g), output (o), so gate k owns rows
# k*n_out .. (k+1)*n_out. The slices below pull the gates apart.
# (Iterate lstm.named_parameters() to list every parameter and its shape.)

# input-to-hidden weights
w_ih_l0 = lstm.weight_ih_l0
wii, wif, wig, wio = [w_ih_l0[k*n_out : (k + 1)*n_out, :] for k in range(4)]

# hidden-to-hidden weights
w_hh_l0 = lstm.weight_hh_l0
whi, whf, whg, who = [w_hh_l0[k*n_out : (k + 1)*n_out, :] for k in range(4)]

# input-to-hidden biases
b_ih_l0 = lstm.bias_ih_l0
bii, bif, big, bio = [b_ih_l0[k*n_out : (k + 1)*n_out] for k in range(4)]

# hidden-to-hidden biases
b_hh_l0 = lstm.bias_hh_l0
bhi, bhf, bhg, bho = [b_hh_l0[k*n_out : (k + 1)*n_out] for k in range(4)]

# Show one slice as a sanity check; swap in any other gate tensor
# (e.g. bhf, bhg, bho) to inspect it the same way.
print(whi, whi.shape)

tensor([[-0.2891,  0.2905,  0.2715,  0.3966, -0.2507],
        [-0.0736, -0.0087,  0.0653, -0.3394, -0.3174],
        [ 0.2433, -0.1049,  0.2185,  0.0255,  0.1468],
        [ 0.0983,  0.1626,  0.2217, -0.4142,  0.2251],
        [-0.3144, -0.3374,  0.0272, -0.0762,  0.2627]],
       grad_fn=<SliceBackward>) torch.Size([5, 5])


In [4]:
# Rebuild the same toy sequence for the by-hand computation:
# values 1..6 arranged as (seq_len, batch_size, input_size) = (2, 1, 3).
# NOTE: the name `input` shadows the Python builtin; it is kept because the
# following cell reads it.
input = torch.arange(1, 7, dtype=torch.float32).reshape(2, 1, 3)

# All-zero starting hidden and cell state, one row of 5 hidden units.
hid = torch.zeros(1, 5, dtype=torch.float32)
c_0 = torch.zeros_like(hid)

print(f'input = {input}, input shape: {input.shape}')
print(f'hid = {hid}, hid shape: {hid.shape}')
print(f'c_0 = {c_0}, c_0 shape: {c_0.shape}')

input = tensor([[[1., 2., 3.]],

        [[4., 5., 6.]]]), input shape: torch.Size([2, 1, 3])
hid = tensor([[0., 0., 0., 0., 0.]]), hid shape: torch.Size([1, 5])
c_0 = tensor([[0., 0., 0., 0., 0.]]), c_0 shape: torch.Size([1, 5])


In [5]:
# Element-wise activations used by the gate equations.
sigmoid, tanh = nn.Sigmoid(), nn.Tanh()

# Walk the sequence one time step at a time, carrying (hid, c_0) forward.
# `input` is (seq_len, batch, input_size), so iterating it yields one
# (batch, input_size) slice per time step.
output = []
for x in input:
    # input gate: how much of the candidate is written into the cell
    i = sigmoid((x @ wii.t() + bii) + (hid @ whi.t() + bhi))

    # forget gate: how much of the previous cell state is kept
    f = sigmoid((x @ wif.t() + bif) + (hid @ whf.t() + bhf))

    # candidate cell values
    g = tanh((x @ wig.t() + big) + (hid @ whg.t() + bhg))

    # output gate: how much of the cell is exposed as the hidden state
    o = sigmoid((x @ wio.t() + bio) + (hid @ who.t() + bho))

    # new cell state first, then the hidden state derived from it
    c_0 = (f * c_0) + (i * g)
    hid = o * tanh(c_0)

    output.append(hid)

# Add a leading axis to the final states and stack the per-step hidden
# states along a new first (time) axis, mirroring what nn.LSTM returns.
hid, c_0 = hid.unsqueeze(0), c_0.unsqueeze(0)
output = torch.stack(output)

In [6]:
# Side-by-side dump of nn.LSTM's results and the hand-rolled ones:
# per-step outputs, final hidden state, final cell state.
print("Pytorch results: ")
print(outs)
print()
print(ht)
print()
print(ct)

print()
print("My results: ")
print(output)
print()
print(hid)
print()
print(c_0)

print()
print(f'outs shape: {outs.shape} ht shape: {ht.shape} ct shape: {ct.shape}')
print(f'output shape: {output.shape} hid shape: {hid.shape} c_0 shape: {c_0.shape}')

# Do not rely on eyeballing four printed decimals: verify the match
# programmatically. Shapes are compared first because allclose would
# otherwise accept broadcastable-but-different shapes. The tolerances leave
# room for float32 rounding differences between the two computations.
assert outs.shape == output.shape, "per-step output shape differs from nn.LSTM"
assert ht.shape == hid.shape, "final hidden-state shape differs from nn.LSTM"
assert ct.shape == c_0.shape, "final cell-state shape differs from nn.LSTM"
assert torch.allclose(outs, output, rtol=1e-4, atol=1e-6), "per-step outputs differ from nn.LSTM"
assert torch.allclose(ht, hid, rtol=1e-4, atol=1e-6), "final hidden state differs from nn.LSTM"
assert torch.allclose(ct, c_0, rtol=1e-4, atol=1e-6), "final cell state differs from nn.LSTM"

Pytorch results: 
tensor([[[ 0.1558, -0.1345,  0.1861,  0.1100,  0.0798]],

        [[ 0.2747, -0.0403,  0.3198,  0.3457,  0.1773]]],
       grad_fn=<StackBackward>)

tensor([[[ 0.2747, -0.0403,  0.3198,  0.3457,  0.1773]]],
       grad_fn=<StackBackward>)

tensor([[[ 0.4520, -0.9810,  0.7305,  0.3731,  0.2464]]],
       grad_fn=<StackBackward>)

My results: 
tensor([[[ 0.1558, -0.1345,  0.1861,  0.1100,  0.0798]],

        [[ 0.2747, -0.0403,  0.3198,  0.3457,  0.1773]]],
       grad_fn=<StackBackward>)

tensor([[[ 0.2747, -0.0403,  0.3198,  0.3457,  0.1773]]],
       grad_fn=<UnsqueezeBackward0>)

tensor([[[ 0.4520, -0.9810,  0.7305,  0.3731,  0.2464]]],
       grad_fn=<UnsqueezeBackward0>)

outs shape: torch.Size([2, 1, 5]) ht shape: torch.Size([1, 1, 5]) ct shape: torch.Size([1, 1, 5])
output shape: torch.Size([2, 1, 5]) hid shape: torch.Size([1, 1, 5]) c_0 shape: torch.Size([1, 1, 5])


## PyTorch LSTM with `batch_first=True`

In [7]:
import torch
import torch.nn as nn

In [8]:
# Re-seed so the weights of this second LSTM are reproducible.
torch.manual_seed(1)

# feature size per time step, hidden-state size, number of stacked layers
input_dim, hidden_dim, n_layers = 5, 10, 1

# batch_first=True: inputs and outputs use (batch, seq, feature) layout.
lstm_layer = nn.LSTM(
    input_size=input_dim,
    hidden_size=hidden_dim,
    num_layers=n_layers,
    batch_first=True,
)

In [9]:
# One sequence in the batch, one time step in the sequence.
batch_size, seq_len = 1, 1

# Random input laid out as (batch, seq, feature).
inp = torch.randn((batch_size, seq_len, input_dim))

# Random initial hidden and cell state, each (n_layers, batch, hidden).
# Drawn in this order (hidden first, then cell) so the RNG stream is stable.
hidden_state, cell_state = (
    torch.randn((n_layers, batch_size, hidden_dim)) for _ in range(2)
)
hidden = (hidden_state, cell_state)

In [10]:
# Run the LSTM from the *initial* state. The state tuple is passed explicitly
# instead of the `hidden` variable because `hidden` is overwritten with the
# final state on the next line: feeding it back in would make every re-run of
# this cell start from the previous run's output and drift away from the
# manual computation, which starts from hidden_state / cell_state.
out, hidden = lstm_layer(inp, (hidden_state, cell_state))
print("Output shape: ", out.shape)
print("Hidden: ", hidden)

Output shape:  torch.Size([1, 1, 10])
Hidden:  (tensor([[[-0.0853, -0.3509, -0.1043,  0.1794,  0.1093,  0.2788, -0.1967,
          -0.3151, -0.1653, -0.0048]]], grad_fn=<StackBackward>), tensor([[[-0.1763, -1.0270, -0.3605,  0.2844,  0.1600,  0.5662, -0.5406,
          -0.4369, -0.2489, -0.0196]]], grad_fn=<StackBackward>))


In [11]:
# seq_len = 3
# inp = torch.randn(batch_size, seq_len, input_dim)
# out, hidden = lstm_layer(inp, hidden)
# print(out.shape)

In [12]:
# Each parameter tensor of lstm_layer stacks the four gates along dim 0 in
# the order input (i), forget (f), cell candidate (g), output (o); gate k
# owns rows k*hidden_dim .. (k+1)*hidden_dim.
# (Iterate lstm_layer.named_parameters() to list every parameter.)
gate_rows = [slice(k*hidden_dim, (k + 1)*hidden_dim) for k in range(4)]

# input-to-hidden weights
w_ih_l0 = lstm_layer.weight_ih_l0
wii, wif, wig, wio = [w_ih_l0[rows, :] for rows in gate_rows]

# hidden-to-hidden weights
w_hh_l0 = lstm_layer.weight_hh_l0
whi, whf, whg, who = [w_hh_l0[rows, :] for rows in gate_rows]

# input-to-hidden biases
b_ih_l0 = lstm_layer.bias_ih_l0
bii, bif, big, bio = [b_ih_l0[rows] for rows in gate_rows]

# hidden-to-hidden biases
b_hh_l0 = lstm_layer.bias_hh_l0
bhi, bhf, bhg, bho = [b_hh_l0[rows] for rows in gate_rows]

In [13]:
# Feed the by-hand computation the exact tensors that were given to
# lstm_layer: the same input and the same *initial* hidden / cell state.
input = inp
print(input, input.shape)

# Drop the leading layer axis (size 1 here) so the states become
# (batch, hidden) matrices that can be used directly in the matmuls.
hid, c_0 = (state.squeeze(0) for state in (hidden_state, cell_state))

print(hid, hid.shape)
print(c_0, c_0.shape)

tensor([[[ 0.2624, -0.6198, -0.7153,  0.0834,  0.2980]]]) torch.Size([1, 1, 5])
tensor([[ 2.0028,  0.5610, -1.6287, -1.3715, -1.1648, -1.2502,  0.4156,  0.7394,
         -0.8678,  0.5870]]) torch.Size([1, 10])
tensor([[-0.1618, -1.3426,  0.8099,  1.0417,  0.4967,  1.7153, -1.1099,  0.3573,
         -0.3369, -0.1951]]) torch.Size([1, 10])


In [14]:
# Element-wise activations used by the gate equations.
sigmoid = nn.Sigmoid()
tanh = nn.Tanh()

output = []

# lstm_layer was built with batch_first=True, so `input` is laid out as
# (batch, seq_len, input_size): time is dim 1, not dim 0. Stepping over
# dim 0 would walk the batch and only gives the right answer when
# batch_size == seq_len == 1.
for t in range(input.shape[1]):
    x = input[:, t, :]  # (batch, input_size) slice for time step t

    # input gate: how much of the candidate is written into the cell
    i1 = x @ torch.transpose(wii, 0, 1) + bii
    i2 = hid @ torch.transpose(whi, 0, 1) + bhi
    i = sigmoid(i1 + i2)

    # forget gate: how much of the previous cell state is kept
    f1 = x @ torch.transpose(wif, 0, 1) + bif
    f2 = hid @ torch.transpose(whf, 0, 1) + bhf
    f = sigmoid(f1 + f2)

    # candidate cell values
    g1 = x @ torch.transpose(wig, 0, 1) + big
    g2 = hid @ torch.transpose(whg, 0, 1) + bhg
    g = tanh(g1 + g2)

    # output gate: how much of the cell is exposed as the hidden state
    o1 = x @ torch.transpose(wio, 0, 1) + bio
    o2 = hid @ torch.transpose(who, 0, 1) + bho
    o = sigmoid(o1 + o2)

    c_prime = (f * c_0) + (i * g)
    h_prime = o * tanh(c_prime)

    # carry the new state into the next time step
    hid = h_prime
    c_0 = c_prime

    output.append(hid)

# Restore the leading layer axis on the final states, and stack the per-step
# hidden states along dim 1 so `output` is (batch, seq_len, hidden), the same
# layout a batch_first LSTM returns.
hid = hid.unsqueeze(0)
c_0 = c_0.unsqueeze(0)
output = torch.stack(output, dim = 1)

In [15]:
# Side-by-side dump of lstm_layer's results and the hand-rolled ones:
# outputs, final hidden state (hidden[0]), final cell state (hidden[1]).
print("Pytorch results: ")
print(out)
print()
print(hidden[0])
print()
print(hidden[1])

print()
print("My results: ")
print(output)
print()
print(hid)
print()
print(c_0)

# Do not rely on eyeballing four printed decimals: verify the match
# programmatically. Shapes are compared first because allclose would
# otherwise accept broadcastable-but-different shapes. The tolerances leave
# room for float32 rounding differences between the two computations.
assert out.shape == output.shape, "output shape differs from nn.LSTM"
assert hidden[0].shape == hid.shape, "final hidden-state shape differs from nn.LSTM"
assert hidden[1].shape == c_0.shape, "final cell-state shape differs from nn.LSTM"
assert torch.allclose(out, output, rtol=1e-4, atol=1e-6), "outputs differ from nn.LSTM"
assert torch.allclose(hidden[0], hid, rtol=1e-4, atol=1e-6), "final hidden state differs from nn.LSTM"
assert torch.allclose(hidden[1], c_0, rtol=1e-4, atol=1e-6), "final cell state differs from nn.LSTM"

Pytorch results: 
tensor([[[-0.0853, -0.3509, -0.1043,  0.1794,  0.1093,  0.2788, -0.1967,
          -0.3151, -0.1653, -0.0048]]], grad_fn=<TransposeBackward0>)

tensor([[[-0.0853, -0.3509, -0.1043,  0.1794,  0.1093,  0.2788, -0.1967,
          -0.3151, -0.1653, -0.0048]]], grad_fn=<StackBackward>)

tensor([[[-0.1763, -1.0270, -0.3605,  0.2844,  0.1600,  0.5662, -0.5406,
          -0.4369, -0.2489, -0.0196]]], grad_fn=<StackBackward>)

My results: 
tensor([[[-0.0853, -0.3509, -0.1043,  0.1794,  0.1093,  0.2788, -0.1967,
          -0.3151, -0.1653, -0.0048]]], grad_fn=<StackBackward>)

tensor([[[-0.0853, -0.3509, -0.1043,  0.1794,  0.1093,  0.2788, -0.1967,
          -0.3151, -0.1653, -0.0048]]], grad_fn=<UnsqueezeBackward0>)

tensor([[[-0.1763, -1.0270, -0.3605,  0.2844,  0.1600,  0.5662, -0.5406,
          -0.4369, -0.2489, -0.0196]]], grad_fn=<UnsqueezeBackward0>)
