## LSTM

A far easier way to understand the changes in sLSTM and mLSTM in xLSTM would be to first look at an LSTM cell and then adding the modifications over it.

In [1]:
%matplotlib notebook

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from loguru import logger

torch.manual_seed(2023)

<torch._C.Generator at 0x11ecd8f10>

In [3]:
# dummy input sequence of integers. 2 sequences actually.

x = torch.randn(2, 10, 5)
x

tensor([[[ 4.3048e-01, -3.4990e-01,  4.7494e-01,  9.0407e-01, -7.0212e-01],
         [ 1.5963e+00,  4.2280e-01, -6.9397e-01,  9.6718e-01,  1.5569e+00],
         [-2.3860e+00,  6.9941e-01, -1.0325e+00, -2.6043e+00,  9.3368e-01],
         [-1.0496e-01,  7.4267e-01, -1.3397e+00, -3.6486e-01,  2.5399e-01],
         [-1.4082e+00,  2.8347e-01, -9.3333e-01, -6.2785e-01, -7.5152e-02],
         [-2.2086e+00, -1.1256e+00,  2.4818e-02,  1.2566e+00, -9.3699e-01],
         [ 4.8638e-02,  2.8411e-01, -9.5578e-01,  1.4745e+00,  5.1086e-01],
         [-2.3249e-01,  3.9579e-01,  8.5357e-01, -4.2040e-01, -1.4516e+00],
         [-7.3737e-01, -4.2015e-01,  3.0709e-01, -1.2767e+00,  2.0085e-01],
         [ 1.8960e-02,  3.0411e-01, -9.2130e-01,  4.0975e-01, -1.5108e+00]],

        [[ 2.9006e-01,  2.5075e+00, -8.9630e-01, -2.2588e+00, -2.2113e-01],
         [-1.6946e+00, -2.8795e-01, -6.5329e-01,  1.3445e+00, -3.7231e-01],
         [-6.5886e-01, -2.3493e-01,  5.0538e-01,  1.8711e+00, -1.6772e+00],
         [

In [4]:
# checking the base implementation of lstm in pytorch

hidden_size = 5
input_size = x.size(-1)

# input size, hidden size, num layers
rnn = nn.LSTM(input_size, hidden_size, 1)

# h_0 size should be n, input_size, hidden size
h_0 = torch.randn(1, 10, hidden_size)
c_0 = torch.randn(h_0.size())

out, (hn, cn) = rnn(x.float(), (h_0, c_0))
out.size()


torch.Size([2, 10, 5])

### Defining the weight and bias kernels for the LSTM cell

The LSTM cell has 4 sets of weights and biases. The weights are for the input, forget, output and cell state respectively. The biases are for the input, forget, output and cell state respectively.

The weights are defined as follows:
- $W_{xi}$: Weights for the input gate
- $W_{xf}$: Weights for the forget gate
- $W_{xo}$: Weights for the output gate
- $W_{xc}$: Weights for the cell state

The biases are defined as follows:
- $b_{i}$: Bias for the input gate
- $b_{f}$: Bias for the forget gate
- $b_{o}$: Bias for the output gate
- $b_{c}$: Bias for the cell state

### Defining the LSTM cell


In [8]:
# an unidirectional lstm cell (single layer)
class LSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        
        # cell output gate
        self.W_z = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
        
        # input gate
        self.W_i = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
        
        # forget gate
        self.W_f = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
        
        # output gate
        self.W_o = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
        
        # biases
        self.b_z = nn.Parameter(
            torch.empty(self.hidden_size)
        )
        self.b_i = nn.Parameter(
            torch.empty(self.hidden_size)
        )
        self.b_f = nn.Parameter(
            torch.empty(self.hidden_size)
        )
        self.b_o = nn.Parameter(
            torch.empty(self.hidden_size)
        )
        
        # hidden / recurrent weights
        # doesn't have bias
        self.r_z = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
        self.r_i = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
        self.r_f = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
        self.r_o = nn.Parameter(
            torch.empty(self.hidden_size, self.input_size)
        )
    
    # TODO: implement later
    def init_params(self):
        pass
    
    # TODO: fix dimension mismatch error
    def forward(self, x, h_0, c_0):
        logger.info(f"x :: {x.size()}")
        logger.info(f"h_0 :: {h_0.size()}")
        logger.info(f"c_0 :: {c_0.size()}")
        
        
        z_t = (self.W_z.T @ x) + (self.r_z @ h_0) + self.b_z
        z_t = F.tanh(z_t)
        
        i_t = (self.W_i.T @ x) + (self.r_i @ h_0) + self.b_i
        i_t = F.sigmoid(i_t)
        
        f_t = (self.W_f.T @ x) + (self.r_f @ h_0) + self.b_f
        f_t = F.sigmoid(f_t)
        
        o_t = (self.W_o.T @ x) + (self.r_o @ h_0) + self.b_o
        o_t = F.sigmoid(o_t)
        
        # new cell state
        c_t = f_t @ c_0 + i_t @ z_t
        # new hidden state
        h_t = o_t @ F.tanh(c_t)
        
        return z_t, c_t, h_t
    
    
with torch.no_grad():
    lstm = LSTMCell(input_size, hidden_size)
    out = lstm(x, h_0, c_0)
    print(out)
        

[32m2024-09-12 02:05:07.945[0m | [1mINFO    [0m | [36m__main__[0m:[36mforward[0m:[36m65[0m - [1mx :: torch.Size([2, 10, 5])[0m
[32m2024-09-12 02:05:07.946[0m | [1mINFO    [0m | [36m__main__[0m:[36mforward[0m:[36m66[0m - [1mh_0 :: torch.Size([1, 10, 5])[0m
[32m2024-09-12 02:05:07.948[0m | [1mINFO    [0m | [36m__main__[0m:[36mforward[0m:[36m67[0m - [1mc_0 :: torch.Size([1, 10, 5])[0m


RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x10 and 5x5)