In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLSTM(nn.Module):
    """Multi-layer LSTM written out with explicit per-gate parameters.

    Input is time-major: ``(seq_len, batch, embedding_dim)``. Each layer owns
    twelve parameters (input weight, hidden weight and bias for the input,
    forget, output and candidate gates). Dropout is applied to the output of
    every layer except the last one.

    Args:
        embedding_dim: Feature size of the input at each time step.
        hidden_dim: Size of the hidden and cell state of every layer.
        n_layers: Number of stacked LSTM layers.
        dropout: Dropout probability applied between layers.
    """

    def __init__(self, embedding_dim, hidden_dim, n_layers, dropout=0.0):
        super().__init__()
        self.input_size = embedding_dim
        self.hidden_size = hidden_dim
        self.num_layers = n_layers
        self.dropout = dropout

        # Each layer has its own parameters; only the first layer consumes
        # the raw input, deeper layers consume the previous layer's output.
        self.layers = nn.ModuleList(
            self._make_lstm_layer(
                embedding_dim if layer == 0 else hidden_dim, hidden_dim
            )
            for layer in range(n_layers)
        )

        # Dropout module (applied between layers)
        self.dropout_layer = nn.Dropout(dropout)

    def _make_lstm_layer(self, input_size, hidden_size):
        """Create the parameters for one LSTM layer.

        Returns an ``nn.ParameterDict`` keyed ``W_i*`` (input weights,
        ``(hidden_size, input_size)``), ``W_h*`` (recurrent weights,
        ``(hidden_size, hidden_size)``) and ``b_*`` (biases, ``(hidden_size,)``)
        for the gates ``i``, ``f``, ``o`` and ``g``. Weights are drawn from a
        standard normal distribution and biases start at zero.
        """
        # nn.ParameterDict (not nn.ModuleDict): the values are Parameters,
        # and ModuleDict rejects anything that is not an nn.Module.
        params = nn.ParameterDict()
        for gate in ("i", "f", "o", "g"):
            params[f"W_i{gate}"] = nn.Parameter(torch.randn(hidden_size, input_size))
            params[f"W_h{gate}"] = nn.Parameter(torch.randn(hidden_size, hidden_size))
            params[f"b_{gate}"] = nn.Parameter(torch.zeros(hidden_size))
        return params

    def _lstm_cell(self, x_t, h_t, c_t, params):
        """Run one LSTM cell step.

        Args:
            x_t: Input at the current time step, ``(batch, input_size)``.
            h_t: Previous hidden state, ``(batch, hidden_size)``.
            c_t: Previous cell state, ``(batch, hidden_size)``.
            params: Parameter mapping produced by ``_make_lstm_layer``.

        Returns:
            Tuple ``(h_next, c_next)``, each ``(batch, hidden_size)``.
        """
        i_t = torch.sigmoid(x_t @ params["W_ii"].T + h_t @ params["W_hi"].T + params["b_i"])
        f_t = torch.sigmoid(x_t @ params["W_if"].T + h_t @ params["W_hf"].T + params["b_f"])
        o_t = torch.sigmoid(x_t @ params["W_io"].T + h_t @ params["W_ho"].T + params["b_o"])
        g_t = torch.tanh(x_t @ params["W_ig"].T + h_t @ params["W_hg"].T + params["b_g"])

        c_next = f_t * c_t + i_t * g_t
        h_next = o_t * torch.tanh(c_next)
        return h_next, c_next

    def forward(self, x, state=None):
        """Run the stacked LSTM over a whole sequence.

        Args:
            x: Input of shape ``(seq_len, batch, embedding_dim)``.
            state: Optional ``(h0, c0)``, each of shape
                ``(num_layers, batch, hidden_dim)``. Defaults to zeros on the
                device and dtype of ``x``. The tensors passed in are not
                modified.

        Returns:
            Tuple ``(output, (h_n, c_n))`` where ``output`` is the last
            layer's hidden state at every step,
            ``(seq_len, batch, hidden_dim)``, and ``h_n`` / ``c_n`` hold the
            final hidden / cell state of every layer,
            ``(num_layers, batch, hidden_dim)``.
        """
        seq_len, batch, _ = x.size()

        if state is None:
            h = torch.zeros(
                self.num_layers, batch, self.hidden_size,
                device=x.device, dtype=x.dtype,
            )
            c = torch.zeros_like(h)
        else:
            h, c = state

        # Final states are gathered here and stacked at the end. Writing them
        # back into ``h``/``c`` in place would overwrite the caller's tensors
        # and invalidate values autograd saved for the backward pass.
        final_h = []
        final_c = []
        layer_input = x

        for layer_idx, params in enumerate(self.layers):
            h_t = h[layer_idx]
            c_t = c[layer_idx]

            layer_outputs = []
            for t in range(seq_len):
                h_t, c_t = self._lstm_cell(layer_input[t], h_t, c_t, params)
                layer_outputs.append(h_t)

            if layer_outputs:
                layer_output = torch.stack(layer_outputs, dim=0)
            else:
                # Zero-length sequence: nothing to stack, emit an empty output.
                layer_output = x.new_zeros((0, batch, self.hidden_size))

            final_h.append(h_t)
            final_c.append(c_t)

            # No dropout after the last layer.
            if layer_idx < self.num_layers - 1:
                layer_output = self.dropout_layer(layer_output)

            layer_input = layer_output

        return layer_input, (torch.stack(final_h, dim=0), torch.stack(final_c, dim=0))

In [2]:
pwd

'C:\\Users\\olaadigu\\codes\\Computational_Intelligence\\Bidirectional-language-translation\\seq2seq'