# LSTM
<img src="../img/lstm_im.png" alt="LSTM cell diagram">

Let:

$fn$ = Number of features

$hs$ = Number of output nodes (hidden size)

$bs$ = Batch size

Then:
 * Each $W_{something}$ matrix below has the shape $(fn, hs)$;
 * Each $U_{something}$ matrix below has the shape $(hs, hs)$;
 * Each $b_{something}$ matrix below has the shape $(1, hs)$ and is broadcast across the batch;
 * The $x_t$ matrix below has shape $(bs, fn)$, corresponding to the element of index $t$ of each sequence in the batch; and
 * The $h_t$ matrix below has shape $(bs, hs)$, corresponding to the hidden state at time $t$ of each sequence in the batch.

And:

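$f_t = \sigma(x_t W_f + h_{t-1} U_f + b_f)$

$i_t = \sigma(x_t W_i + h_{t-1} U_i + b_i)$

$o_t = \sigma(x_t W_o + h_{t-1} U_o + b_o)$

$g_t = \tanh(x_t W_g + h_{t-1} U_g + b_g)$ (also written $\tilde{c}_t$)

$c_t = f_t \circ c_{t-1} + i_t \circ g_t$

$h_t = o_t \circ \tanh(c_t)$

where $\sigma$ is the logistic sigmoid and $\circ$ is the element-wise (Hadamard) product. Because $x_t$ and $h_{t-1}$ stack one row per batch element, they multiply the weight matrices from the left, so every pre-activation (and hence every gate, $c_t$ and $h_t$) has shape $(bs, hs)$.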

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

# LSTM From Scratch

In [4]:
class LSTM(nn.Module):
    """Single-layer LSTM unrolled step by step over a batch-first sequence.

    Gate equations (row-vector convention, ``x`` is ``(batch, input_size)``)::

        f = sigmoid(x @ W_xf + h @ W_hf + b_f)
        i = sigmoid(x @ W_xi + h @ W_hi + b_i)
        o = sigmoid(x @ W_xo + h @ W_ho + b_o)
        g = tanh(x @ W_xc + h @ W_hc + b_c)
        c = f * c + i * g
        h = o * tanh(c)

    Args:
        input_size: Feature dimension of each input step.
        hidden_size: Dimension of the hidden and cell states.
        device: Device on which the parameters are allocated.
    """

    def __init__(self, input_size, hidden_size, device):
        super(LSTM, self).__init__()
        self.device = device

        def weight(rows, cols):
            # nn.Parameter registers the tensor with the module, so it is
            # visible to .parameters() (optimizers) and follows .to(...).
            # Values come from an unscaled standard normal.
            return nn.Parameter(torch.randn(rows, cols, device=device))

        def bias():
            return nn.Parameter(torch.zeros(hidden_size, device=device))

        # Creation order (forget, input, output, candidate; W_x, W_h, b within
        # each) fixes which random draws land in which matrix for a given seed.
        self.W_xf, self.W_hf, self.b_f = weight(input_size, hidden_size), weight(hidden_size, hidden_size), bias()
        self.W_xi, self.W_hi, self.b_i = weight(input_size, hidden_size), weight(hidden_size, hidden_size), bias()
        self.W_xo, self.W_ho, self.b_o = weight(input_size, hidden_size), weight(hidden_size, hidden_size), bias()
        self.W_xc, self.W_hc, self.b_c = weight(input_size, hidden_size), weight(hidden_size, hidden_size), bias()

    def forward(self, X, state):
        """Run the LSTM over every time step of ``X``.

        Args:
            X: Tensor of shape ``(batch, timesteps, input_size)``.
            state: Tuple ``(h, c)`` of initial hidden and cell states, each of
                shape ``(batch, hidden_size)``.

        Returns:
            ``(hidden_seq, (h, c))``: ``hidden_seq`` has shape
            ``(batch, timesteps, hidden_size)`` and ``(h, c)`` is the state
            after the last step. ``timesteps`` must be at least 1.
        """
        h, c = state
        _, timesteps, _ = X.shape  # also rejects inputs that are not 3-D
        hidden_seq = []
        for t in range(timesteps):
            x = X[:, t, :]
            forget_gate = torch.sigmoid(x @ self.W_xf + h @ self.W_hf + self.b_f)
            input_gate = torch.sigmoid(x @ self.W_xi + h @ self.W_hi + self.b_i)
            output_gate = torch.sigmoid(x @ self.W_xo + h @ self.W_ho + self.b_o)
            c_tilda = torch.tanh(x @ self.W_xc + h @ self.W_hc + self.b_c)
            # Rebind h/c so the next step sees this step's state (the recurrence).
            c = forget_gate * c + input_gate * c_tilda
            h = output_gate * torch.tanh(c)
            hidden_seq.append(h)
        # Stack along a new time axis -> (batch, timesteps, hidden_size).
        return torch.stack(hidden_seq, dim=1), (h, c)

# School Version

In [9]:
class LSTM(nn.Module):
  """LSTM over batch-first, possibly padded sequences with per-sequence lengths."""

  def __init__(self, input_size, hidden_size, device):
    """
    Inputs:
      input_size: int, feature dimension of input sequence
      hidden_size: int, feature dimension of hidden state
      device: torch.device()
    """
    super(LSTM, self).__init__()
    self.device = device
    # A ParameterList registers the tensors with the module, so they are
    # visible to .parameters() (optimizers) and moved by .to(); each tensor
    # is also placed on the requested device.
    self.params = nn.ParameterList(
      [nn.Parameter(p.to(device)) for p in self.init_params(input_size, hidden_size)]
    )

  def init_params(self, input_size, hidden_size):
    """
    Inputs:
      input_size: int, feature dimension of input sequence
      hidden_size: int, feature dimension of hidden state

    Outputs:
      List of 12 tensors, in this order:
      Weights for proposal: W_xc, W_hc, b_c
      Weights for input gate: W_xi, W_hi, b_i
      Weights for forget gate: W_xf, W_hf, b_f
      Weights for output gate: W_xo, W_ho, b_o
      W_x* have shape (input_size, hidden_size), W_h* have shape
      (hidden_size, hidden_size), and b_* have shape (hidden_size,).
      Weights are torch.randn scaled by 0.1; biases are zero.

    Also stores hidden_size on self.
    """
    D, M = input_size, hidden_size
    self.hidden_size = hidden_size
    W_xc, W_hc, b_c = torch.randn(D, M) * 0.1, torch.randn(M, M) * 0.1, torch.zeros(M)
    W_xi, W_hi, b_i = torch.randn(D, M) * 0.1, torch.randn(M, M) * 0.1, torch.zeros(M)
    W_xf, W_hf, b_f = torch.randn(D, M) * 0.1, torch.randn(M, M) * 0.1, torch.zeros(M)
    W_xo, W_ho, b_o = torch.randn(D, M) * 0.1, torch.randn(M, M) * 0.1, torch.zeros(M)
    return [W_xc, W_hc, b_c, W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o]

  def lstm(self, X, state):
    """
    Inputs:
      X: tuple of tensors (src, src_len). src, size (N, D_in) or (N, T, D_in), where N is the batch size,
        T is the length of the sequence(s). src_len, size of (N,), is the valid length for each sequence;
        None means every sequence has the full length T.

      state: tuple of tensors (h, c). h, size of (N, hidden_size) is the hidden state of LSTM. c, size of
            (N, hidden_size), is the memory cell of the LSTM.

    Outputs:
      o: tensor of size (N, T, hidden_size). Positions at or beyond a sequence's valid length are zero.
      state: tuple (h, c) in the same format as the input state, holding each sequence's state after
            its last valid step.
    """
    src, src_len = X
    h, c = state

    # make sure always has a T dim
    if src.dim() == 2:
      src = src.unsqueeze(1)

    N, T, _ = src.shape
    if src_len is None:
      src_len = torch.full((N,), T, dtype=torch.long, device=src.device)
    else:
      src_len = torch.as_tensor(src_len, device=src.device)

    W_xc, W_hc, b_c, W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o = self.params

    outputs = []
    for t in range(T):
      x = src[:, t, :]
      forget_gate = torch.sigmoid(x @ W_xf + h @ W_hf + b_f)
      input_gate = torch.sigmoid(x @ W_xi + h @ W_hi + b_i)
      output_gate = torch.sigmoid(x @ W_xo + h @ W_ho + b_o)
      c_tilda = torch.tanh(x @ W_xc + h @ W_hc + b_c)
      c_new = forget_gate * c + input_gate * c_tilda
      h_new = output_gate * torch.tanh(c_new)
      # (N, 1) mask, True while step t lies inside sequence n's valid length.
      # Padded steps leave the state untouched and emit zeros.
      valid = (t < src_len).reshape(-1, 1)
      c = torch.where(valid, c_new, c)
      h = torch.where(valid, h_new, h)
      outputs.append(torch.where(valid, h_new, torch.zeros_like(h_new)))

    if not outputs:
      # T == 0: nothing to stack; return an empty time axis.
      return h.new_zeros(N, 0, h.shape[-1]), (h, c)

    o = torch.stack(outputs, dim=1)  # (N, T, hidden_size)
    state = (h, c)
    return o, state

  def forward(self, inputs, state):
    """Alias for ``lstm`` so the module can be called directly."""
    return self.lstm(inputs, state)

In [None]:
# Smoke test: run the padded-sequence LSTM on a small random batch.
# (The previous bare `lstm()` referred to no module-level name.)
model = LSTM(input_size=8, hidden_size=16, device=device)
src = torch.randn(4, 10, 8, device=device)
src_len = torch.tensor([10, 7, 5, 2], device=device)
h0 = torch.zeros(4, 16, device=device)
c0 = torch.zeros(4, 16, device=device)
out, (h_n, c_n) = model((src, src_len), (h0, c0))
out.shape, h_n.shape, c_n.shape