In [1]:
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
from torch import nn
from torch import layer_norm

In [49]:
class LSTMCell(nn.Module):
    """
    One time step of a Long Short-Term Memory cell, with optional layer norm.

    State carried from step to step:
        c -- cell state (long-term memory)
        h -- hidden state (short-term memory)

    ``W_x`` (no bias) and ``W_h`` (with bias) each project to ``4 * hidden_dim``
    features. Their sum is split into the i, f, g, o pre-activations. When
    ``apply_layer_norm`` is truthy, every one of the four chunks is passed
    through its own ``nn.LayerNorm`` and the updated cell state is normalised
    before the output ``tanh``.
    """

    def __init__(self, input_dim, hidden_dim, apply_layer_norm, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.input_dim, self.hidden_dim = input_dim, hidden_dim
        self.apply_layer_norm = apply_layer_norm

        # All four gates are produced by a single projection per source.
        n_gate_features = 4 * hidden_dim
        self.W_x = nn.Linear(input_dim, n_gate_features, bias=False)
        self.W_h = nn.Linear(hidden_dim, n_gate_features)

        # The normalisation modules only exist when they are requested.
        if not apply_layer_norm:
            return

        per_gate_norms = [nn.LayerNorm(hidden_dim) for _ in range(4)]
        self.layer_norm = nn.ModuleList(per_gate_norms)
        self.layer_norm_c = nn.LayerNorm(hidden_dim)

    def forward(self, x, h_prev, c_prev):
        """
        Advance the cell by one step and return ``(h_updated, c_updated)``.

        x: batch_size, input_dim
        h_prev: batch_size, hidden_dim
        c_prev: batch_size, hidden_dim
        """
        # Joint pre-activation, split along the feature axis in i, f, g, o order.
        pre_activations = (self.W_x(x) + self.W_h(h_prev)).chunk(4, dim=-1)
        if self.apply_layer_norm:
            pre_activations = [
                norm(piece) for norm, piece in zip(self.layer_norm, pre_activations)
            ]
        in_gate, forget_gate, candidate, out_gate = pre_activations

        # Keep part of the old memory and add the gated candidate.
        kept_memory = F.sigmoid(forget_gate) * c_prev
        new_memory = F.sigmoid(in_gate) * F.tanh(candidate)
        c_updated = kept_memory + new_memory

        # The returned cell state is never normalised; only the copy fed to tanh is.
        if self.apply_layer_norm:
            c_for_output = self.layer_norm_c(c_updated)
        else:
            c_for_output = c_updated
        h_updated = F.sigmoid(out_gate) * F.tanh(c_for_output)

        return h_updated, c_updated


In [54]:

class ScrachLSTM(nn.Module):
    """
    Multi-layer LSTM that unrolls a stack of ``LSTMCell`` modules over time.

    Layer 0 maps ``input_dim -> hidden_dim``; every further layer maps
    ``hidden_dim -> hidden_dim`` and consumes the hidden state produced by the
    layer below it at the same time step. Layer normalisation is disabled in
    every cell.

    Args:
        input_dim: size of the feature axis of the input.
        hidden_dim: size of the hidden and cell state of every layer.
        n_layers: number of stacked cells; must be at least 1.
        return_sequences: if True, ``forward`` returns the top-layer hidden
            state of every time step, otherwise only that of the last step.
        batch_first: if True the input is (batch_size, n_steps, feature_dim),
            otherwise (n_steps, batch_size, feature_dim).

    Raises:
        ValueError: if ``n_layers`` is smaller than 1.
    """

    def __init__(self, input_dim: int, hidden_dim: int, n_layers: int, return_sequences: bool = True, batch_first = True) -> None:
        super().__init__()

        # Without at least one layer there is no top-layer output to return.
        if n_layers < 1:
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.return_sequences = return_sequences
        self.batch_first = batch_first

        self.cells = nn.ModuleList(
            [LSTMCell(input_dim, hidden_dim, apply_layer_norm = False)] +
            [LSTMCell(hidden_dim, hidden_dim, apply_layer_norm = False)
             for _ in range(n_layers - 1)])

    def forward(self, x):
        """
        Run the stacked cells over every time step, starting from zero states.

        Args:
            x: tensor of shape (batch_size, n_steps, feature_dim) if
                ``batch_first`` else (n_steps, batch_size, feature_dim).

        Returns:
            ``(out, (h, c))`` where ``out`` is a list of top-layer hidden
            states, each of shape (batch_size, hidden_dim): one per time step
            when ``return_sequences`` is True, otherwise only the last step
            (an empty list if the sequence has no steps). ``h`` and ``c`` are
            the final hidden / cell states of all layers, each of shape
            (n_layers, batch_size, hidden_dim).
        """
        if self.batch_first:
            x = x.transpose(0, 1)
        n_steps, batch_size = x.shape[:2]

        # One state tensor per layer instead of a single stacked buffer that is
        # written in place: in-place writes would turn every collected output
        # into a view of the same storage (so earlier steps would show the last
        # step's values) and would invalidate tensors autograd saved for
        # backward. ``new_zeros`` keeps the states on x's device and dtype.
        h_states = [x.new_zeros(batch_size, self.hidden_dim) for _ in range(self.n_layers)]
        c_states = [torch.zeros_like(h_l) for h_l in h_states]

        out = []
        for t in range(n_steps):
            inputs = x[t]

            # Each layer's new hidden state is the input of the layer above.
            for l, cell in enumerate(self.cells):
                h_states[l], c_states[l] = cell(inputs, h_states[l], c_states[l])
                inputs = h_states[l]

            out.append(inputs)

        if not self.return_sequences:
            # Slicing keeps the list type and tolerates an empty sequence.
            out = out[-1:]
        return out, (torch.stack(h_states), torch.stack(c_states))

In [55]:
# Seed the global RNG so the random input and the model's initial weights are
# the same on every Restart & Run All.
torch.manual_seed(0)

batch_size, n_steps, feature_dim = 2, 10, 54
hidden_dim = 27

# Batch-first dummy input: (batch_size, n_steps, feature_dim).
x = torch.rand(batch_size, n_steps, feature_dim)

# Three stacked layers; return_sequences=False keeps only the last step's output.
myLSTM = ScrachLSTM(feature_dim, hidden_dim, 3, return_sequences=False)

In [56]:
# Forward pass. `out` is a list of top-layer hidden states (a single entry here,
# because the model was built with return_sequences=False); `h` and `c` are the
# final hidden / cell states of all layers.
out, (h,c) = myLSTM(x)

In [57]:
# The model was built with return_sequences=False, so `out` should hold exactly one tensor.
len(out)

1