#### **Task 1**
 - Implement an LSTM (LSTM() and/or LSTMCell()) from scratch
 - Implement a Convolutional LSTM (ConvLSTM() and/or ConvLSTMCell()) from scratch

In [None]:
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms

In [None]:
class ConvLSTMCell(nn.Module):
    """
    Convolutional LSTM cell.

    All four gates are produced by a single 2-D convolution applied to the
    channel-wise concatenation of the input and the previous hidden state, so
    the recurrence keeps the spatial layout of its inputs and every batch
    sample is processed independently.

    Accepted input ranks:
      * 4-D ``(batch, channels, height, width)`` feature maps;
      * 3-D ``(batch, channels, length)`` sequences, treated as ``1 x length`` maps;
      * 2-D ``(batch, channels)`` vectors, treated as ``1 x 1`` maps (only the
        centre tap of the kernel then sees non-padding values, so the cell acts
        like a fully-connected LSTM cell).
    The returned states have the same rank as the input.
    """

    def __init__(self, input_size, hidden_size, kernel_size=3, padding=1, bias=True):
        """
        Initialize ConvLSTM cell.

        Parameters
        ----------
        input_size: int
            Number of channels of input tensor.
        hidden_size: int
            Number of channels of hidden and cell state.
        kernel_size: int or (int, int)
            Size of the convolutional kernel.
        padding: int or (int, int)
            Zero-padding of the convolution. It must preserve the spatial size
            (``kernel_size // 2`` for odd kernels, e.g. 1 for the default 3x3),
            otherwise the new states cannot be combined with the old ones.
        bias: bool
            Whether or not to add the bias.
        """
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # One convolution computes all gates at once:
        # 4 * hidden_size output channels for i, f, o, g
        # (input, forget, output, candidate).
        self.conv = nn.Conv2d(
            in_channels=input_size + hidden_size,
            out_channels=4 * hidden_size,
            kernel_size=kernel_size,
            padding=padding,
            bias=bias,
        )

    def forward(self, input_tensor, cur_state):
        """
        Advance the cell by one time step.

        Parameters
        ----------
        input_tensor: torch.Tensor
            ``(batch, input_size)``, ``(batch, input_size, length)`` or
            ``(batch, input_size, height, width)``.
        cur_state: (torch.Tensor, torch.Tensor)
            Previous ``(h, c)``; same rank and spatial size as ``input_tensor``
            but with ``hidden_size`` channels.

        Returns
        -------
        (h_next, c_next): tuple of torch.Tensor
            New hidden and cell states, shaped like the ``cur_state`` tensors.

        Raises
        ------
        ValueError
            If the input rank is not 2, 3 or 4, or the states' rank differs.
        """
        h_cur, c_cur = cur_state
        ndim = input_tensor.dim()
        if ndim not in (2, 3, 4):
            raise ValueError(f"expected a 2-D, 3-D or 4-D input, got {ndim}-D")
        if h_cur.dim() != ndim or c_cur.dim() != ndim:
            raise ValueError("hidden/cell states must have the same rank as the input")

        x = self._to_4d(input_tensor)
        h_cur = self._to_4d(h_cur)
        c_cur = self._to_4d(c_cur)

        combined = torch.cat([x, h_cur], dim=1)  # concatenate along channel axis
        gates = self.conv(combined)              # (batch, 4 * hidden_size, H, W)

        # Split into input, forget, output, and candidate pre-activations
        cc_i, cc_f, cc_o, cc_g = torch.chunk(gates, 4, dim=1)
        i_t = torch.sigmoid(cc_i)
        f_t = torch.sigmoid(cc_f)
        o_t = torch.sigmoid(cc_o)
        g_t = torch.tanh(cc_g)

        c_next = f_t * c_cur + i_t * g_t
        h_next = o_t * torch.tanh(c_next)

        return self._from_4d(h_next, ndim), self._from_4d(c_next, ndim)

    @staticmethod
    def _to_4d(t):
        """View a (B, C), (B, C, L) or (B, C, H, W) tensor as (B, C, H, W)."""
        if t.dim() == 2:
            return t[:, :, None, None]
        if t.dim() == 3:
            return t.unsqueeze(2)
        return t

    @staticmethod
    def _from_4d(t, ndim):
        """Undo ``_to_4d`` for a tensor whose original rank was ``ndim``."""
        if ndim == 2:
            return t[:, :, 0, 0]
        if ndim == 3:
            return t[:, :, 0, :]
        return t

    def init_hidden(self, batch_size, image_size):
        """
        Return zero ``(h, c)`` states of shape
        ``(batch_size, hidden_size, height, width)``, placed on the same device
        and with the same dtype as the convolution weights.
        """
        height, width = image_size
        weight = self.conv.weight
        shape = (batch_size, self.hidden_size, height, width)
        return (torch.zeros(shape, device=weight.device, dtype=weight.dtype),
                torch.zeros(shape, device=weight.device, dtype=weight.dtype))


class LSTMWithCustomCell(nn.Module):
    """
    Sequential classifier. Each frame is embedded by a small CNN and the
    sequence of embeddings is fed to a stack of custom ``ConvLSTMCell``s
    (instead of ``nn.LSTM``). The top layer's output at the last time step
    is passed to a linear classifier.

    Args:
    -----
    emb_dim: integer
        dimensionality of the frame embeddings fed to the first cell
    hidden_dim: integer
        dimensionality of the hidden and cell states
    num_layers: integer
        number of stacked cells
    mode: string
        initialization of the states: "zeros" or "random" (standard normal)
    num_classes: integer
        number of outputs of the final linear classifier (default 4)
    in_channels: integer
        number of channels of every input frame (default 1)
    """

    def __init__(self, emb_dim, hidden_dim, num_layers=1, mode="zeros",
                 num_classes=4, in_channels=1):
        """ Module initializer """
        assert mode in ["zeros", "random"], \
            f"mode must be 'zeros' or 'random', got {mode!r}"
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.mode = mode

        # Frame encoder: (N, in_channels, H, W) -> (N, emb_dim, 1, 1)
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, emb_dim, 3, 1, 1),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

        # Stacked recurrent cells: layer 0 consumes the embeddings, deeper
        # layers consume the hidden state of the layer below.
        self.lstm = nn.ModuleList([
            ConvLSTMCell(input_size=emb_dim if i == 0 else hidden_dim,
                         hidden_size=hidden_dim)
            for i in range(num_layers)
        ])

        # FC-classifier
        self.classifier = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """
        Forward pass through model.

        x: tensor of shape (batch, num_frames, channels, rows, cols)
        returns: tensor of shape (batch, num_classes) with the (unnormalized)
            outputs of the linear classifier.
        raises: ValueError if the sequence has no frames.
        """
        b_size, num_frames, n_channels, n_rows, n_cols = x.shape
        if num_frames == 0:
            raise ValueError("input sequence must contain at least one frame")
        h, c = self.init_state(b_size=b_size, device=x.device)

        # Embed all frames in one encoder pass, then restore the time axis.
        # reshape (rather than view) also accepts non-contiguous inputs.
        frames = x.reshape(b_size * num_frames, n_channels, n_rows, n_cols)
        embeddings = self.encoder(frames).reshape(b_size, num_frames, -1)

        # Unroll over time; at every step pass the input up through all
        # layers. Every batch size is processed - no step or layer is skipped.
        for t in range(num_frames):
            layer_input = embeddings[:, t, :]  # (batch, emb_dim)
            for j, cell in enumerate(self.lstm):
                h[j], c[j] = cell(layer_input, (h[j], c[j]))
                layer_input = h[j]

        # After the loop, layer_input is the top layer's output at the last step
        return self.classifier(layer_input)

    def init_state(self, b_size, device):
        """
        Initialize hidden and cell states.

        Returns two lists (h, c), one tensor per layer, each of shape
        (b_size, hidden_dim) on ``device``: all zeros for mode "zeros",
        standard-normal samples for mode "random".
        """
        shape = (b_size, self.hidden_dim)
        if self.mode == "zeros":
            make = torch.zeros
        elif self.mode == "random":
            make = torch.randn
        else:
            raise ValueError(f"unknown initialization mode {self.mode!r}")
        h = [make(shape, device=device) for _ in range(self.num_layers)]
        c = [make(shape, device=device) for _ in range(self.num_layers)]
        return h, c
