In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

## RNN ENCODER BLOCK

In [3]:
class TinyEncoder(nn.Module):
    """Hand-written vanilla RNN encoder for one unbatched token sequence.

    Every source token is embedded and folded into the hidden state with
    ``h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)``. The hidden state after the
    last token is the encoder output.

    Args:
        input_vocab_size: Number of distinct source token ids.
        embed_size: Dimension of each token embedding.
        hidden_size: Dimension of the recurrent hidden state.
    """

    def __init__(self,input_vocab_size,embed_size,hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(input_vocab_size,embed_size)

        # RNN parameters
        self.hidden_size = hidden_size
        self.W_h = nn.Parameter(torch.randn(hidden_size,hidden_size))  # hidden -> hidden
        self.W_x = nn.Parameter(torch.randn(hidden_size,embed_size))   # input  -> hidden
        self.b = nn.Parameter(torch.zeros(hidden_size))

    def forward(self,src_tokens):
        """Encode a source sequence into its final hidden state.

        Args:
            src_tokens: 1-D integer tensor of token ids, shape (src_len,).

        Returns:
            Tensor of shape (hidden_size,): the hidden state after the last
            token (all zeros when ``src_tokens`` is empty).
        """
        # Allocate the initial state from a parameter so it shares the
        # module's dtype and device; a bare torch.zeros() would stay
        # float32-on-CPU and break torch.mv after .double() / .to(device).
        h = self.b.new_zeros(self.hidden_size)

        for token_id in src_tokens:
            x_t = self.embedding(token_id)

            h = torch.tanh(
                torch.mv(self.W_h,h) +
                torch.mv(self.W_x,x_t) +
                self.b
            )

        return h






In [4]:
class TinyDecoder(nn.Module):
    """Hand-written vanilla RNN decoder for one unbatched token sequence.

    Each step embeds the given decoder token, updates the hidden state with
    ``h_t = tanh(W_h @ h_{t-1} + W_x @ x_t + b)`` and projects it to
    vocabulary logits with ``W_out @ h_t + b_out``.

    Args:
        output_vocab_size: Number of distinct target token ids (= logits per step).
        embed_size: Dimension of each token embedding.
        hidden_size: Dimension of the recurrent hidden state.
    """

    def __init__(self,output_vocab_size,embed_size,hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(output_vocab_size,embed_size)

        self.hidden_size = hidden_size
        self.W_h = nn.Parameter(torch.randn(hidden_size,hidden_size))
        self.W_x = nn.Parameter(torch.randn(hidden_size,embed_size))
        self.b = nn.Parameter(torch.zeros(hidden_size))

        # Output projection: must be a (vocab, hidden) matrix so that
        # torch.mv(W_out, h) maps a hidden state to one logit per vocab entry.
        self.W_out = nn.Parameter(torch.randn(output_vocab_size,hidden_size))
        self.b_out = nn.Parameter(torch.zeros(output_vocab_size))

    def forward(self,dec_tokens,init_hidden):
        """Run the decoder over the given input tokens.

        Args:
            dec_tokens: 1-D integer tensor of decoder input token ids,
                shape (tgt_len,).
            init_hidden: Initial hidden state, shape (hidden_size,).

        Returns:
            Tensor of shape (tgt_len, output_vocab_size) holding the logits
            produced at every step.
        """
        h = init_hidden
        logits_list = []

        for token_id in dec_tokens:
            x_t = self.embedding(token_id)

            h = torch.tanh(
                torch.mv(self.W_h,h)+
                torch.mv(self.W_x,x_t)+
                self.b
            )
            logits_t = torch.mv(self.W_out,h) + self.b_out
            logits_list.append(logits_t.unsqueeze(0))

        if not logits_list:
            # torch.cat rejects an empty list; return a well-shaped empty result.
            return self.b_out.new_zeros((0,self.b_out.shape[0]))

        return torch.cat(logits_list,dim=0)



## EXAMPLE DATA

"I go <EOS>" --> "मैं जाता हूँ <EOS>"

In [5]:
# Vocabulary sizes = number of distinct token ids (ids are 0-based), so each
# size must be strictly greater than the largest id in use.
ENG_VOCAB_SIZE = 3  # I=0, go=1, <EOS>=2
HIN_VOCAB_SIZE = 5  # <GO>=0, मैं=1, जाता=2, हूँ=3, <EOS>=4

In [6]:
# Reverse lookup (token id -> Hindi word) used when printing decoder output.
# The position of each entry in the list is its token id.
HIN_ID2WORD = dict(enumerate([
    "<GO>",
    "मैं",
    "जाता",
    "हूँ",
    "<EOS>",
]))

In [7]:
# Model dimensions.
EMBED_SIZE = 1
HIDDEN_SIZE = 2

# Both models are randomly initialised; seed the global RNG first so that
# Restart Kernel -> Run All reproduces the same starting weights.
SEED = 0
torch.manual_seed(SEED)

encoder = TinyEncoder(ENG_VOCAB_SIZE,EMBED_SIZE,HIDDEN_SIZE)
decoder = TinyDecoder(HIN_VOCAB_SIZE,EMBED_SIZE,HIDDEN_SIZE)

In [8]:
# Teacher forcing: the decoder input is the gold sequence shifted right by
# one position (starting with <GO>), and the target is the same sequence
# shifted left (ending with <EOS>).
# Token ids: <GO>=0, मैं=1, जाता=2, हूँ=3, <EOS>=4
hin_token_ids = [0, 1, 2, 3, 4]

decoder_input = torch.tensor(hin_token_ids[:-1])  # "<GO> मैं जाता हूँ"  => [0,1,2,3]
decoder_target = torch.tensor(hin_token_ids[1:])  # "मैं जाता हूँ <EOS>" => [1,2,3,4]

# TRAINING LOOP