<a href="https://colab.research.google.com/github/amankiitg/Foundation_AI/blob/main/GenAI_Lecture_4_RNN_Translator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Part 1: RNN Translator coded from scratch

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# --------------------------
# 1. A Tiny Manual RNN Encoder
# --------------------------
class TinyEncoder(nn.Module):
    def __init__(self, input_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, embed_size)

        # RNN parameters
        self.hidden_size = hidden_size
        self.W_h = nn.Parameter(torch.randn(hidden_size, hidden_size)*0.1)
        self.W_x = nn.Parameter(torch.randn(hidden_size, embed_size)*0.1)
        self.b   = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, src_tokens):
        """
        src_tokens: shape (src_len,)
        Returns final hidden state (hidden_size,).
        """
        h = torch.zeros(self.hidden_size)

        for t in range(src_tokens.shape[0]):
            token_id = src_tokens[t]
            x_t = self.embedding(token_id)

            h = torch.tanh(
                torch.mv(self.W_h, h) +
                torch.mv(self.W_x, x_t) +
                self.b
            )

        return h


# -------------------------
# 2. A Tiny Manual RNN Decoder
# -------------------------
class TinyDecoder(nn.Module):
    def __init__(self, output_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_vocab_size, embed_size)

        self.hidden_size = hidden_size
        self.W_h = nn.Parameter(torch.randn(hidden_size, hidden_size)*0.1)
        self.W_x = nn.Parameter(torch.randn(hidden_size, embed_size)*0.1)
        self.b   = nn.Parameter(torch.zeros(hidden_size))

        # Output projection
        self.W_out = nn.Parameter(torch.randn(output_vocab_size, hidden_size)*0.1)
        self.b_out = nn.Parameter(torch.zeros(output_vocab_size))

    def forward(self, dec_tokens, init_hidden):
        h = init_hidden
        logits_list = []

        for t in range(dec_tokens.shape[0]):
            token_id = dec_tokens[t]
            x_t = self.embedding(token_id)

            h = torch.tanh(
                torch.mv(self.W_h, h) +
                torch.mv(self.W_x, x_t) +
                self.b
            )
            logits_t = torch.mv(self.W_out, h) + self.b_out
            logits_list.append(logits_t.unsqueeze(0))

        return torch.cat(logits_list, dim=0)


# -------------------------------------
# 3. Example Data: "I go <EOS>" -> "मैं जाता हूँ <EOS>"
# -------------------------------------
ENG_VOCAB_SIZE = 3  # I=0, go=1, <EOS>=2
HIN_VOCAB_SIZE = 5  # <GO>=0, मैं=1, जाता=2, हूँ=3, <EOS>=4

# Map IDs to words for printing
HIN_ID2WORD = {
    0: "<GO>",
    1: "मैं",
    2: "जाता",
    3: "हूँ",
    4: "<EOS>"
}

EMBED_SIZE = 1
HIDDEN_SIZE = 2

encoder = TinyEncoder(ENG_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)
decoder = TinyDecoder(HIN_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)

# Source: "I go <EOS>" => [0,1,2]
encoder_input = torch.tensor([0,1,2])

# Decoder target: "मैं जाता हूँ <EOS>" => [1,2,3,4]
# We'll do teacher forcing in training:
decoder_input  = torch.tensor([0,1,2,3])  # <GO>, मैं, जाता, हूँ
decoder_target = torch.tensor([1,2,3,4])  #     मैं, जाता, हूँ, <EOS>

# ----------------------------------
# 4. Training Loop (Cross Entropy)
# ----------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=0.0001)

num_epochs = 1000
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # 1) Encode
    enc_hidden = encoder(encoder_input)  # shape (2,)

    # 2) Decode
    logits = decoder(decoder_input, enc_hidden)  # (4,5)

    # 3) Compute cross-entropy
    loss = criterion(logits, decoder_target)

    # 4) Backprop + update
    loss.backward()
    optimizer.step()

    # Print stats
    if (epoch+1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss = {loss.item():.4f}")

    # ------------------------------
    # Print generated words every 20 epochs
    # ------------------------------
    if (epoch+1) % 20 == 0:
        print(f"\n--- Decoding after epoch {epoch+1} ---")
        with torch.no_grad():
            # Re-encode
            enc_hidden = encoder(encoder_input)

            # Start <GO>=0
            current_token = torch.tensor(0)
            h = enc_hidden.clone()

            generated_tokens = []
            for _ in range(6):
                x_t = decoder.embedding(current_token)
                h = torch.tanh(
                    torch.mv(decoder.W_h, h) +
                    torch.mv(decoder.W_x, x_t) +
                    decoder.b
                )

                logits_t = torch.mv(decoder.W_out, h) + decoder.b_out
                next_token = torch.argmax(logits_t).item()
                generated_tokens.append(next_token)

                if next_token == 4:  # <EOS>
                    break
                current_token = torch.tensor(next_token)

            # Convert IDs to words
            generated_words = [HIN_ID2WORD[t] for t in generated_tokens]
            print("Generated tokens:", generated_words)
        print("-----------------------------------\n")




Epoch 5/1000, Loss = 1.6091
Epoch 10/1000, Loss = 1.6091
Epoch 15/1000, Loss = 1.6090
Epoch 20/1000, Loss = 1.6090

--- Decoding after epoch 20 ---
Generated tokens: ['<GO>', '<GO>', '<GO>', '<GO>', '<GO>', '<GO>']
-----------------------------------

Epoch 25/1000, Loss = 1.6090
Epoch 30/1000, Loss = 1.6090
Epoch 35/1000, Loss = 1.6089
Epoch 40/1000, Loss = 1.6089

--- Decoding after epoch 40 ---
Generated tokens: ['<GO>', '<GO>', '<GO>', '<GO>', '<GO>', '<GO>']
-----------------------------------

Epoch 45/1000, Loss = 1.6089
Epoch 50/1000, Loss = 1.6089
Epoch 55/1000, Loss = 1.6088
Epoch 60/1000, Loss = 1.6088

--- Decoding after epoch 60 ---
Generated tokens: ['<GO>', '<GO>', '<GO>', '<GO>', '<GO>', '<GO>']
-----------------------------------

Epoch 65/1000, Loss = 1.6088
Epoch 70/1000, Loss = 1.6088
Epoch 75/1000, Loss = 1.6087
Epoch 80/1000, Loss = 1.6087

--- Decoding after epoch 80 ---
Generated tokens: ['<GO>', '<GO>', '<GO>', '<GO>', '<GO>', '<GO>']
--------------------------

## Part 2: LSTM translator coded from scratch

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

# --------------------------
# 1. A Tiny Manual LSTM Encoder
# --------------------------
class TinyEncoderLSTM(nn.Module):
    def __init__(self, input_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, embed_size)
        self.hidden_size = hidden_size

        # LSTM parameters for the encoder
        # Input gate parameters
        self.W_i = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_i = nn.Parameter(torch.randn(hidden_size, hidden_size) )
        self.b_i = nn.Parameter(torch.zeros(hidden_size))

        # Forget gate parameters
        self.W_f = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_f = nn.Parameter(torch.randn(hidden_size, hidden_size) )
        self.b_f = nn.Parameter(torch.zeros(hidden_size))

        # Output gate parameters
        self.W_o = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_o = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.zeros(hidden_size))

        # Candidate cell (g) parameters
        self.W_g = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_g = nn.Parameter(torch.randn(hidden_size, hidden_size) )
        self.b_g = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, src_tokens):
        """
        src_tokens: shape (src_len,)
        Returns final hidden state (h) and cell state (c), each of shape (hidden_size,).
        """
        h = torch.zeros(self.hidden_size)
        c = torch.zeros(self.hidden_size)

        for t in range(src_tokens.shape[0]):
            token_id = src_tokens[t]
            x_t = self.embedding(token_id)

            i_t = torch.sigmoid(torch.mv(self.W_i, x_t) + torch.mv(self.U_i, h) + self.b_i)
            f_t = torch.sigmoid(torch.mv(self.W_f, x_t) + torch.mv(self.U_f, h) + self.b_f)
            o_t = torch.sigmoid(torch.mv(self.W_o, x_t) + torch.mv(self.U_o, h) + self.b_o)
            g_t = torch.tanh(torch.mv(self.W_g, x_t) + torch.mv(self.U_g, h) + self.b_g)

            c = f_t * c + i_t * g_t
            h = o_t * torch.tanh(c)

        return h, c


# -------------------------
# 2. A Tiny Manual LSTM Decoder
# -------------------------
class TinyDecoderLSTM(nn.Module):
    def __init__(self, output_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_vocab_size, embed_size)
        self.hidden_size = hidden_size

        # LSTM parameters for the decoder
        self.W_i = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_i = nn.Parameter(torch.randn(hidden_size, hidden_size) )
        self.b_i = nn.Parameter(torch.zeros(hidden_size))

        self.W_f = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_f = nn.Parameter(torch.randn(hidden_size, hidden_size) )
        self.b_f = nn.Parameter(torch.zeros(hidden_size))

        self.W_o = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_o = nn.Parameter(torch.randn(hidden_size, hidden_size) )
        self.b_o = nn.Parameter(torch.zeros(hidden_size))

        self.W_g = nn.Parameter(torch.randn(hidden_size, embed_size) )
        self.U_g = nn.Parameter(torch.randn(hidden_size, hidden_size) )
        self.b_g = nn.Parameter(torch.zeros(hidden_size))

        # Output projection parameters
        self.W_out = nn.Parameter(torch.randn(output_vocab_size, hidden_size) )
        self.b_out = nn.Parameter(torch.zeros(output_vocab_size))

    def forward(self, dec_tokens, init_hidden, init_cell):
        """
        dec_tokens: shape (dec_len,)
        init_hidden: (hidden_size,)
        init_cell: (hidden_size,)
        Returns logits of shape (dec_len, output_vocab_size)
        """
        h = init_hidden
        c = init_cell
        logits_list = []

        for t in range(dec_tokens.shape[0]):
            token_id = dec_tokens[t]
            x_t = self.embedding(token_id)

            i_t = torch.sigmoid(torch.mv(self.W_i, x_t) + torch.mv(self.U_i, h) + self.b_i)
            f_t = torch.sigmoid(torch.mv(self.W_f, x_t) + torch.mv(self.U_f, h) + self.b_f)
            o_t = torch.sigmoid(torch.mv(self.W_o, x_t) + torch.mv(self.U_o, h) + self.b_o)
            g_t = torch.tanh(torch.mv(self.W_g, x_t) + torch.mv(self.U_g, h) + self.b_g)

            c = f_t * c + i_t * g_t
            h = o_t * torch.tanh(c)

            logits_t = torch.mv(self.W_out, h) + self.b_out
            logits_list.append(logits_t.unsqueeze(0))

        return torch.cat(logits_list, dim=0)


# -------------------------------------
# 3. Example Data: "I go <EOS>" -> "मैं जाता हूँ <EOS>"
# -------------------------------------
ENG_VOCAB_SIZE = 3  # I=0, go=1, <EOS>=2
HIN_VOCAB_SIZE = 5  # <GO>=0, मैं=1, जाता=2, हूँ=3, <EOS>=4

# Map IDs to words for printing
HIN_ID2WORD = {
    0: "<GO>",
    1: "मैं",
    2: "जाता",
    3: "हूँ",
    4: "<EOS>"
}

EMBED_SIZE = 1
HIDDEN_SIZE = 2

encoder = TinyEncoderLSTM(ENG_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)
decoder = TinyDecoderLSTM(HIN_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)

# Source: "I go <EOS>" => [0,1,2]
encoder_input = torch.tensor([0, 1, 2])

# Decoder target: "मैं जाता हूँ <EOS>" => [1,2,3,4]
# For teacher forcing:
decoder_input  = torch.tensor([0, 1, 2, 3])  # <GO>, मैं, जाता, हूँ
decoder_target = torch.tensor([1, 2, 3, 4])  #     मैं, जाता, हूँ, <EOS>

# ----------------------------------
# 4. Training Loop (Cross Entropy)
# ----------------------------------
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=0.1)

num_epochs = 1000
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # 1) Encode
    enc_hidden, enc_cell = encoder(encoder_input)  # each is (hidden_size,)

    # 2) Decode (teacher forcing)
    logits = decoder(decoder_input, enc_hidden, enc_cell)  # shape (dec_len, HIN_VOCAB_SIZE)

    # 3) Compute cross-entropy loss
    loss = criterion(logits, decoder_target)

    # 4) Backpropagation + update
    loss.backward()
    optimizer.step()

    # Print training statistics
    if (epoch + 1) % 5 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss = {loss.item():.4f}")

    # ------------------------------
    # Print generated words every 20 epochs
    # ------------------------------
    if (epoch + 1) % 20 == 0:
        print(f"\n--- Decoding after epoch {epoch+1} ---")
        with torch.no_grad():
            # Re-encode the source sentence
            enc_hidden, enc_cell = encoder(encoder_input)

            # Start decoding with the <GO> token (0)
            current_token = torch.tensor(0)
            h = enc_hidden.clone()
            c = enc_cell.clone()

            generated_tokens = []
            for _ in range(6):
                x_t = decoder.embedding(current_token)
                i_t = torch.sigmoid(torch.mv(decoder.W_i, x_t) + torch.mv(decoder.U_i, h) + decoder.b_i)
                f_t = torch.sigmoid(torch.mv(decoder.W_f, x_t) + torch.mv(decoder.U_f, h) + decoder.b_f)
                o_t = torch.sigmoid(torch.mv(decoder.W_o, x_t) + torch.mv(decoder.U_o, h) + decoder.b_o)
                g_t = torch.tanh(torch.mv(decoder.W_g, x_t) + torch.mv(decoder.U_g, h) + decoder.b_g)

                c = f_t * c + i_t * g_t
                h = o_t * torch.tanh(c)

                logits_t = torch.mv(decoder.W_out, h) + decoder.b_out
                next_token = torch.argmax(logits_t).item()
                generated_tokens.append(next_token)

                if next_token == 4:  # <EOS>
                    break
                current_token = torch.tensor(next_token)

            # Convert token IDs to words for display
            generated_words = [HIN_ID2WORD[t] for t in generated_tokens]
            print("Generated tokens:", generated_words)
        print("-----------------------------------\n")


Epoch 5/1000, Loss = 1.5616
Epoch 10/1000, Loss = 1.5170
Epoch 15/1000, Loss = 1.4804
Epoch 20/1000, Loss = 1.4491

--- Decoding after epoch 20 ---
Generated tokens: ['मैं', 'मैं', 'हूँ', 'हूँ', 'हूँ', 'हूँ']
-----------------------------------

Epoch 25/1000, Loss = 1.4213
Epoch 30/1000, Loss = 1.3959
Epoch 35/1000, Loss = 1.3722
Epoch 40/1000, Loss = 1.3496

--- Decoding after epoch 40 ---
Generated tokens: ['मैं', 'मैं', 'हूँ', 'हूँ', 'हूँ', 'हूँ']
-----------------------------------

Epoch 45/1000, Loss = 1.3279
Epoch 50/1000, Loss = 1.3068
Epoch 55/1000, Loss = 1.2861
Epoch 60/1000, Loss = 1.2658

--- Decoding after epoch 60 ---
Generated tokens: ['मैं', 'मैं', 'हूँ', 'हूँ', 'हूँ', 'हूँ']
-----------------------------------

Epoch 65/1000, Loss = 1.2459
Epoch 70/1000, Loss = 1.2261
Epoch 75/1000, Loss = 1.2066
Epoch 80/1000, Loss = 1.1873

--- Decoding after epoch 80 ---
Generated tokens: ['मैं', 'मैं', 'हूँ', 'हूँ', 'हूँ', 'हूँ']
-----------------------------------

Epoch 85/1000

## Part 3: RNN vs LSTM comparison

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# For reproducibility
torch.manual_seed(0)

# ======================================================
# 1. Define a synthetic long-range dependency task
# ======================================================
# Our vocabularies:
# --- Source (English) ---
#  0: "A"       (the important token)
#  1: "x"       (a distractor token)
#  2: "<EOS>"   (end-of-sequence)
#
# --- Target (Hindi) ---
#  0: "<GO>"    (start-of-decoding)
#  1: "ए"      (translation of "A")
#  2: "<EOS>"   (end-of-sequence)

ENG_VOCAB_SIZE = 3  # tokens: 0 ("A"), 1 ("x"), 2 ("<EOS>")
HIN_VOCAB_SIZE = 3  # tokens: 0 ("<GO>"), 1 ("ए"), 2 ("<EOS>")

# For printing decoded Hindi tokens:
HIN_ID2WORD = {0: "<GO>", 1: "ए", 2: "<EOS>"}

# We will make the source sentence very long by inserting many "x" tokens.
distractor_length = 500  # Try different lengths (e.g., 5, 50, 100) to see the effect

# Construct the source sentence:
# It begins with "A" (0), then many "x" (1), and finally <EOS> (2)
encoder_input = torch.tensor([0] + [1] * distractor_length + [2])

# The target sentence is fixed: it should translate "A" into "ए".
# (Teacher forcing: decoder input starts with <GO> (0) followed by "ए" (1);
#  the expected target is "ए" (1) then <EOS> (2).)
decoder_input = torch.tensor([0, 1])  # <GO>, ए
decoder_target = torch.tensor([1, 2])  # ए, <EOS>

# ======================================================
# 2. Define the Vanilla RNN Encoder and Decoder
# ======================================================

class TinyEncoderRNN(nn.Module):
    def __init__(self, input_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, embed_size)
        self.hidden_size = hidden_size
        # Manual RNN parameters (no multiplication by 0.1)
        self.W_h = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.W_x = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.b   = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, src_tokens):
        h = torch.zeros(self.hidden_size)
        for t in range(src_tokens.size(0)):
            token_id = src_tokens[t]
            x_t = self.embedding(token_id)
            h = torch.tanh(torch.mv(self.W_h, h) + torch.mv(self.W_x, x_t) + self.b)
        return h

class TinyDecoderRNN(nn.Module):
    def __init__(self, output_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_vocab_size, embed_size)
        self.hidden_size = hidden_size
        # Manual RNN parameters for decoding
        self.W_h = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.W_x = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.b   = nn.Parameter(torch.zeros(hidden_size))
        # Output projection
        self.W_out = nn.Parameter(torch.randn(output_vocab_size, hidden_size))
        self.b_out = nn.Parameter(torch.zeros(output_vocab_size))

    def forward(self, dec_tokens, init_hidden):
        h = init_hidden
        logits_list = []
        for t in range(dec_tokens.size(0)):
            token_id = dec_tokens[t]
            x_t = self.embedding(token_id)
            h = torch.tanh(torch.mv(self.W_h, h) + torch.mv(self.W_x, x_t) + self.b)
            logits_t = torch.mv(self.W_out, h) + self.b_out
            logits_list.append(logits_t.unsqueeze(0))
        return torch.cat(logits_list, dim=0)

# ======================================================
# 3. Define the LSTM Encoder and Decoder (manual LSTM cell)
# ======================================================

class TinyEncoderLSTM(nn.Module):
    def __init__(self, input_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size, embed_size)
        self.hidden_size = hidden_size

        # LSTM cell parameters (without multiplication by 0.1)
        self.W_i = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_i = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.zeros(hidden_size))

        self.W_f = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_f = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.zeros(hidden_size))

        self.W_o = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_o = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.zeros(hidden_size))

        self.W_g = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_g = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_g = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, src_tokens):
        h = torch.zeros(self.hidden_size)
        c = torch.zeros(self.hidden_size)
        for t in range(src_tokens.size(0)):
            token_id = src_tokens[t]
            x_t = self.embedding(token_id)
            i_t = torch.sigmoid(torch.mv(self.W_i, x_t) + torch.mv(self.U_i, h) + self.b_i)
            f_t = torch.sigmoid(torch.mv(self.W_f, x_t) + torch.mv(self.U_f, h) + self.b_f)
            o_t = torch.sigmoid(torch.mv(self.W_o, x_t) + torch.mv(self.U_o, h) + self.b_o)
            g_t = torch.tanh(torch.mv(self.W_g, x_t) + torch.mv(self.U_g, h) + self.b_g)
            c = f_t * c + i_t * g_t
            h = o_t * torch.tanh(c)
        return h, c

class TinyDecoderLSTM(nn.Module):
    def __init__(self, output_vocab_size, embed_size, hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_vocab_size, embed_size)
        self.hidden_size = hidden_size

        self.W_i = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_i = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.zeros(hidden_size))

        self.W_f = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_f = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.zeros(hidden_size))

        self.W_o = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_o = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.zeros(hidden_size))

        self.W_g = nn.Parameter(torch.randn(hidden_size, embed_size))
        self.U_g = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.b_g = nn.Parameter(torch.zeros(hidden_size))

        self.W_out = nn.Parameter(torch.randn(output_vocab_size, hidden_size))
        self.b_out = nn.Parameter(torch.zeros(output_vocab_size))

    def forward(self, dec_tokens, init_hidden, init_cell):
        h = init_hidden
        c = init_cell
        logits_list = []
        for t in range(dec_tokens.size(0)):
            token_id = dec_tokens[t]
            x_t = self.embedding(token_id)
            i_t = torch.sigmoid(torch.mv(self.W_i, x_t) + torch.mv(self.U_i, h) + self.b_i)
            f_t = torch.sigmoid(torch.mv(self.W_f, x_t) + torch.mv(self.U_f, h) + self.b_f)
            o_t = torch.sigmoid(torch.mv(self.W_o, x_t) + torch.mv(self.U_o, h) + self.b_o)
            g_t = torch.tanh(torch.mv(self.W_g, x_t) + torch.mv(self.U_g, h) + self.b_g)
            c = f_t * c + i_t * g_t
            h = o_t * torch.tanh(c)
            logits_t = torch.mv(self.W_out, h) + self.b_out
            logits_list.append(logits_t.unsqueeze(0))
        return torch.cat(logits_list, dim=0)

# ======================================================
# 4. Training the models on the synthetic task
# ======================================================
# Hyperparameters
EMBED_SIZE = 4
HIDDEN_SIZE = 8
num_epochs = 3000
learning_rate = 0.001  # step size (learning rate) set to 0.1

criterion = nn.CrossEntropyLoss()

# ----- Train Vanilla RNN Model -----
encoder_rnn = TinyEncoderRNN(ENG_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)
decoder_rnn = TinyDecoderRNN(HIN_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)
optimizer_rnn = optim.Adam(list(encoder_rnn.parameters()) + list(decoder_rnn.parameters()), lr=learning_rate)

print("Training Vanilla RNN model on the long-range task...")
for epoch in range(num_epochs):
    optimizer_rnn.zero_grad()
    enc_hidden = encoder_rnn(encoder_input)
    logits = decoder_rnn(decoder_input, enc_hidden)
    loss = criterion(logits, decoder_target)
    loss.backward()
    optimizer_rnn.step()
    if (epoch + 1) % 500 == 0:
        print(f"RNN Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

# ----- Train LSTM Model -----
encoder_lstm = TinyEncoderLSTM(ENG_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)
decoder_lstm = TinyDecoderLSTM(HIN_VOCAB_SIZE, EMBED_SIZE, HIDDEN_SIZE)
optimizer_lstm = optim.Adam(list(encoder_lstm.parameters()) + list(decoder_lstm.parameters()), lr=learning_rate)

print("\nTraining LSTM model on the long-range task...")
for epoch in range(num_epochs):
    optimizer_lstm.zero_grad()
    enc_hidden, enc_cell = encoder_lstm(encoder_input)
    logits = decoder_lstm(decoder_input, enc_hidden, enc_cell)
    loss = criterion(logits, decoder_target)
    loss.backward()
    optimizer_lstm.step()
    if (epoch + 1) % 500 == 0:
        print(f"LSTM Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}")

# ======================================================
# 5. Define simple decoding functions (greedy decoding)
# ======================================================
def decode_rnn(encoder, decoder, encoder_input):
    with torch.no_grad():
        h = encoder(encoder_input)
        current_token = torch.tensor(0)  # <GO> token (0)
        generated_tokens = []
        for _ in range(5):  # limit maximum decoding length
            x_t = decoder.embedding(current_token)
            h = torch.tanh(torch.mv(decoder.W_h, h) + torch.mv(decoder.W_x, x_t) + decoder.b)
            logits_t = torch.mv(decoder.W_out, h) + decoder.b_out
            next_token = torch.argmax(logits_t).item()
            generated_tokens.append(next_token)
            if next_token == 2:  # <EOS>
                break
            current_token = torch.tensor(next_token)
    return generated_tokens

def decode_lstm(encoder, decoder, encoder_input):
    with torch.no_grad():
        h, c = encoder(encoder_input)
        current_token = torch.tensor(0)  # <GO>
        generated_tokens = []
        for _ in range(5):
            x_t = decoder.embedding(current_token)
            i_t = torch.sigmoid(torch.mv(decoder.W_i, x_t) + torch.mv(decoder.U_i, h) + decoder.b_i)
            f_t = torch.sigmoid(torch.mv(decoder.W_f, x_t) + torch.mv(decoder.U_f, h) + decoder.b_f)
            o_t = torch.sigmoid(torch.mv(decoder.W_o, x_t) + torch.mv(decoder.U_o, h) + decoder.b_o)
            g_t = torch.tanh(torch.mv(decoder.W_g, x_t) + torch.mv(decoder.U_g, h) + decoder.b_g)
            c = f_t * c + i_t * g_t
            h = o_t * torch.tanh(c)
            logits_t = torch.mv(decoder.W_out, h) + decoder.b_out
            next_token = torch.argmax(logits_t).item()
            generated_tokens.append(next_token)
            if next_token == 2:
                break
            current_token = torch.tensor(next_token)
    return generated_tokens

# ======================================================
# 6. Compare decoding from both models
# ======================================================
print("\n--- Decoding with the Vanilla RNN model ---")
rnn_decoded = decode_rnn(encoder_rnn, decoder_rnn, encoder_input)
print("RNN Decoded tokens (Hindi):", [HIN_ID2WORD[t] for t in rnn_decoded])

print("\n--- Decoding with the LSTM model ---")
lstm_decoded = decode_lstm(encoder_lstm, decoder_lstm, encoder_input)
print("LSTM Decoded tokens (Hindi):", [HIN_ID2WORD[t] for t in lstm_decoded])


Training Vanilla RNN model on the long-range task...
RNN Epoch 500/3000, Loss: 0.0085
RNN Epoch 1000/3000, Loss: 0.0038
RNN Epoch 1500/3000, Loss: 0.0022
RNN Epoch 2000/3000, Loss: 0.0014
RNN Epoch 2500/3000, Loss: 0.0009
RNN Epoch 3000/3000, Loss: 0.0006

Training LSTM model on the long-range task...
LSTM Epoch 500/3000, Loss: 0.0303
LSTM Epoch 1000/3000, Loss: 0.0057
LSTM Epoch 1500/3000, Loss: 0.0024
LSTM Epoch 2000/3000, Loss: 0.0014
LSTM Epoch 2500/3000, Loss: 0.0009
LSTM Epoch 3000/3000, Loss: 0.0006

--- Decoding with the Vanilla RNN model ---
RNN Decoded tokens (Hindi): ['ए', '<EOS>']

--- Decoding with the LSTM model ---
LSTM Decoded tokens (Hindi): ['ए', '<EOS>']
