<a href="https://colab.research.google.com/github/Nedu21/Pytorch-deep-learning-projects-/blob/main/Custom_LSTM_Module.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# LSTM From Scratch

In [1]:
import torch
import torch.nn as nn
import math

In [10]:
class CustomLSTMCell(nn.Module):
  """A single LSTM time step implemented with one fused weight matrix.

  The four gate projections (forget, input, candidate, output) are computed
  with a single matmul against ``W`` and then split into four equal slices.

  Args:
    input_size: Number of features in each input vector ``x``.
    hidden_size: Number of features in the hidden and cell states.

  Attributes:
    W: Parameter of shape ``(input_size + hidden_size, 4 * hidden_size)``.
    b: Parameter of shape ``(4 * hidden_size,)``.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size

    # 1. THE WEIGHTS (W)
    # One matrix of shape (input_size + hidden_size) x (4 * hidden_size).
    # A single fused matrix is used instead of four separate ones
    # (W_f, W_i, W_C, W_o) so that all gates come out of one matmul.
    # torch.empty allocates without initialising; init_weights() fills it.
    self.W = nn.Parameter(torch.empty(input_size + hidden_size, hidden_size * 4))

    # 2. THE BIAS (b)
    # One bias value per gate output unit.
    self.b = nn.Parameter(torch.empty(hidden_size * 4))

    # Initialise parameters so training starts from a sensible scale.
    self.init_weights()

  def init_weights(self):
    """Fill every parameter in place from U(-k, k), k = 1 / sqrt(hidden_size)."""
    stdv = 1.0 / math.sqrt(self.hidden_size)
    for weight in self.parameters():
      # nn.init.uniform_ runs under no_grad, so no `.data` access is needed.
      nn.init.uniform_(weight, -stdv, stdv)

  def forward(self, x, init_states=None):
    """Advance the cell by one time step.

    Args:
      x: Input tensor of shape ``(batch, input_size)``.
      init_states: Optional tuple ``(h_prev, C_prev)``, each of shape
        ``(batch, hidden_size)``. When ``None``, both states default to
        zeros created on the same device and with the same dtype as ``x``.

    Returns:
      Tuple ``(h_t, C_t)``: the new hidden state and the new cell state,
      each of shape ``(batch, hidden_size)``.
    """
    if init_states is None:
      # No in-place ops are applied to the states below, so one zero tensor
      # can safely serve as both h_prev and C_prev.
      zero_state = x.new_zeros(x.size(0), self.hidden_size)
      h_prev, C_prev = zero_state, zero_state
    else:
      h_prev, C_prev = init_states

    # 1. CONCATENATE & MULTIPLY
    # Join the input and the previous hidden state along the feature axis.
    combined = torch.cat((x, h_prev), 1)  # (batch, input_size + hidden_size)
    gates_linear = combined @ self.W + self.b  # (batch, hidden_size * 4)

    # 2. SPLIT & ACTIVATE (the non-linear part)
    # Four equal slices along dim 1, in the order:
    # forget, input, candidate, output.
    f_lin, i_lin, c_lin, o_lin = gates_linear.chunk(4, 1)

    f_t = torch.sigmoid(f_lin)    # forget gate
    i_t = torch.sigmoid(i_lin)    # input gate
    C_tilde = torch.tanh(c_lin)   # candidate memory
    o_t = torch.sigmoid(o_lin)    # output gate

    # 3. UPDATE CELL STATE (the conveyor belt)
    # C_t = (forget * old) + (input * new candidate)
    C_t = (f_t * C_prev) + (i_t * C_tilde)

    # 4. UPDATE HIDDEN STATE (the working memory)
    # h_t = output_filter * tanh(new_cell_state)
    h_t = o_t * torch.tanh(C_t)

    return h_t, C_t

In [11]:
class CustomLSTM(nn.Module):
  """Runs a ``CustomLSTMCell`` over every time step of a batch-first sequence.

  Args:
    input_size: Number of features per time step.
    hidden_size: Number of features in the hidden and cell states.
  """

  def __init__(self, input_size, hidden_size):
    super().__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.cell = CustomLSTMCell(input_size, hidden_size)

  def forward(self, x):
    """Unroll the cell over the sequence, starting from zero states.

    Args:
      x: Tensor of shape ``(batch_size, sequence_length, input_size)``.

    Returns:
      Tensor of shape ``(batch_size, sequence_length, hidden_size)`` holding
      the hidden state produced at every time step. For a zero-length
      sequence an empty ``(batch_size, 0, hidden_size)`` tensor is returned.
    """
    batch_size, seq_len, _ = x.size()

    # Build the initial states from `x` so they inherit its device and dtype;
    # a bare torch.zeros(...) is always CPU / default dtype and breaks once
    # the model or input is moved to GPU or cast to float64.
    h_t = x.new_zeros(batch_size, self.hidden_size)
    C_t = x.new_zeros(batch_size, self.hidden_size)

    # torch.stack rejects an empty list, so handle the empty sequence here.
    if seq_len == 0:
      return x.new_zeros(batch_size, 0, self.hidden_size)

    hidden_states = []

    # THE TIME LOOP
    for t in range(seq_len):
      # Input for this time step: (batch_size, input_size)
      x_t = x[:, t, :]

      # CRITICAL: h_t and C_t are overwritten so that the next iteration
      # receives the states produced by this one.
      h_t, C_t = self.cell(x_t, (h_t, C_t))

      hidden_states.append(h_t)

    # Stack along a new time axis: (batch_size, seq_len, hidden_size)
    return torch.stack(hidden_states, dim=1)

In [12]:
# --- SMOKE TEST: push random data through the model and check the shape ---
if __name__ == "__main__":
    # Toy dimensions
    BATCH_SIZE, SEQ_LEN = 2, 5     # sequences per batch, steps per sequence
    INPUT_SZ, HIDDEN_SZ = 10, 20   # features per step (e.g. embedding), memory size

    # Random dummy batch, then the model, then a single forward pass.
    input_data = torch.randn(BATCH_SIZE, SEQ_LEN, INPUT_SZ)
    my_lstm = CustomLSTM(INPUT_SZ, HIDDEN_SZ)
    output = my_lstm(input_data)

    print("Input shape: ", input_data.shape)
    print("Output shape:", output.shape)

    # The output must carry one hidden vector per time step per sequence.
    print(
        "SUCCESS: The output shape is exactly what we expected!"
        if output.shape == (BATCH_SIZE, SEQ_LEN, HIDDEN_SZ)
        else "Check the dimensions again."
    )

Input shape:  torch.Size([2, 5, 10])
Output shape: torch.Size([2, 5, 20])
SUCCESS: The output shape is exactly what we expected!
