<a href="https://colab.research.google.com/github/Idan-Alter/OU-22961-Deep-Learning/blob/main/22961_7_4_machine_translation_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install datasets

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import datasets as ds

In [None]:
# Language pair: translate from `src` into `tgt`.
src, tgt = "en", "fr"

# Load the Tatoeba sentence pairs for that language pair.
dataset = ds.load_dataset("tatoeba", lang1=src, lang2=tgt)

# Show the dataset summary as the cell output.
dataset

In [None]:
# Peek at one example. Indexing the row first fetches just that example
# instead of pulling the whole "translation" column to read one entry.
dataset["train"][10]["translation"]

{'en': "Today is June 18th and it is Muiriel's birthday!",
 'fr': "Aujourd'hui nous sommes le 18 juin et c'est l'anniversaire de Muiriel !"}

In [None]:
# All {src: ..., tgt: ...} translation dicts of the training split.
pairs_list = dataset["train"]["translation"]
total = len(pairs_list)

# Whitespace-tokenise both sides of every pair (no length filtering yet).
src_sents_unfiltered = []
tgt_sents_unfiltered = []
for pair in pairs_list:
  src_sents_unfiltered.append(pair[src].split())
  tgt_sents_unfiltered.append(pair[tgt].split())

In [None]:
# Maximum number of whitespace tokens allowed on either side of a pair.
MAX_length = 5

# Keep only the pairs whose source AND target both fit within MAX_length.
# The two lists are walked together so they stay aligned. This avoids shadowing
# the builtin `filter` and checks each pair's lengths only once.
kept_pairs = [
    (src_sent, tgt_sent)
    for src_sent, tgt_sent in zip(src_sents_unfiltered, tgt_sents_unfiltered)
    if len(src_sent) <= MAX_length and len(tgt_sent) <= MAX_length
]
src_sents = [src_sent for src_sent, _ in kept_pairs]
tgt_sents = [tgt_sent for _, tgt_sent in kept_pairs]

In [None]:
#Shuffle the data
# One fixed-seed permutation is drawn and applied to both lists, so every
# source sentence stays aligned with its translation.
torch.manual_seed(0)
shuffle_idxs = torch.randperm(len(src_sents))

def shuffle(my_list):
  """Return a new list whose i-th item is my_list[shuffle_idxs[i]]."""
  return [my_list[shuffle_idxs[position]] for position in range(len(my_list))]

src_sents = shuffle(src_sents)
tgt_sents = shuffle(tgt_sents)

In [None]:
# Add the sentence markers: sources end with "<END>"; targets are wrapped in
# "<START>" ... "<END>".
# New lists are built instead of appending in place, because the sentence lists
# are the same objects held by `src_sents_unfiltered`; mutating them would change
# the lengths seen if the length-filter cell were run again.
# The guards make the cell safe to re-run: already-marked sentences are left as is.
src_sents = [
    sent if sent[-1:] == ["<END>"] else sent + ["<END>"]
    for sent in src_sents
]
tgt_sents = [
    sent if sent[:1] == ["<START>"] else ["<START>"] + sent + ["<END>"]
    for sent in tgt_sents
]

In [None]:
from torchtext.vocab import build_vocab_from_iterator

# Source vocabulary; the specials come first, so "<UNK>" gets index 0, which is
# also set as the default index for unknown words.
src_vocab = build_vocab_from_iterator(src_sents, specials=["<UNK>","<END>"])
src_vocab.set_default_index(0)

# Target vocabulary additionally reserves "<START>".
tgt_vocab = build_vocab_from_iterator(tgt_sents, specials=["<UNK>","<END>","<START>"])
tgt_vocab.set_default_index(0)

print(len(src_vocab))
print(len(tgt_vocab))

# One tensor of vocabulary indices per sentence.
src_tokens = [torch.tensor(src_vocab(sentence)) for sentence in src_sents]
tgt_tokens = [torch.tensor(tgt_vocab(sentence)) for sentence in tgt_sents]

18698
27836


In [None]:
# Look up both marker indices in one vocabulary call and wrap each one in a
# 0-d tensor so it can be fed to the embedding layer and compared to predictions.
START_Token, END_Token = (
    torch.tensor(token_idx) for token_idx in tgt_vocab(["<START>", "<END>"])
)
print(START_Token, END_Token)

tensor(2) tensor(1)


In [None]:
# First 15 entries of each index-to-string table (target first, then source).
for vocab in (tgt_vocab, src_vocab):
  print(vocab.get_itos()[:15])

['<UNK>', '<END>', '<START>', '?', 'Je', 'est', '!', 'Il', 'pas', 'Tom', 'de', 'un', "C'est", 'le', 'a']
['<UNK>', '<END>', 'I', 'is', 'a', 'you', 'the', 'to', 'Tom', 'He', 'was', "I'm", 'You', 'She', 'The']


In [None]:
# Show the first sentence pair both as words and as vocabulary indices.
for label, sentences, token_tensors in (
    ("Source:", src_sents, src_tokens),
    ("Target:", tgt_sents, tgt_tokens),
):
  print(label, sentences[0], token_tensors[0], sep="\n")

Source:
['I', 'think', 'that', 'helps.', '<END>']
tensor([   2,  140,   43, 4795,    1])
Target:
['<START>', 'Je', 'pense', 'que', 'ça', 'aide.', '<END>']
tensor([  2,   4, 184,  39,  75, 644,   1])


In [None]:
class Encoder(nn.Module):
    """Encode one source sentence into a single context vector.

    Args:
        embed_dim: size of each source-token embedding.
        hidden_dim: size of the LSTM hidden state (and of the returned context).
        RNNlayers: number of stacked LSTM layers.
        vocab_size: number of rows in the embedding table. When omitted it
            falls back to ``len(src_vocab)``, the notebook-level source
            vocabulary, which keeps existing call sites working.
    """
    def __init__(self, embed_dim, hidden_dim, RNNlayers, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible default: size the table from the global vocab.
            vocab_size = len(src_vocab)
        self.src_embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_stack     = nn.LSTM(embed_dim, hidden_dim, RNNlayers)

    def forward(self, src_tokens):
        """Return the top LSTM layer's output at the last time step.

        Args:
            src_tokens: tensor of source-token indices for ONE sentence
                (the notebook passes a 1-D tensor).

        Returns:
            The slice ``[-1, 0, :]`` of the LSTM output sequence, i.e. the
            last step of the single sentence in the batch.
        """
        all_embeddings = self.src_embedding(src_tokens)
        # Insert a size-1 axis at position 1; nn.LSTM (batch_first=False)
        # reads its input as (sequence, batch, features).
        all_embeddings = all_embeddings.unsqueeze(1)
        hidden_state_history, _ = self.rnn_stack(all_embeddings)
        context = hidden_state_history[-1, 0, :]
        return context

In [None]:
class DecoderRNNCell(nn.Module):
    """One decoder step: RNN cell + linear projection + log-softmax.

    The hidden state is kept on the module (``self.hidden_state``) and is
    overwritten on every call, so the owner must reset it (e.g. to the encoder
    context) before decoding a new sentence.

    Args:
        embed_dim: size of the embedded input token.
        hidden_dim: size of the recurrent hidden state.
        vocab_size: number of output classes. When omitted it falls back to
            ``len(tgt_vocab)``, the notebook-level target vocabulary, which
            keeps existing call sites working.
    """
    def __init__(self, embed_dim, hidden_dim, vocab_size=None):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible default: size the output from the global vocab.
            vocab_size = len(tgt_vocab)
        self.hidden_state  = torch.zeros(hidden_dim)
        self.RNNcell       = nn.RNNCell(embed_dim, hidden_dim)
        self.output_linear = nn.Linear(in_features=hidden_dim,
                                       out_features=vocab_size)
        # dim=0 normalises over the first axis of the score tensor, which is
        # the vocabulary axis only when the scores are an unbatched 1-D vector.
        self.logsoftmax    = nn.LogSoftmax(dim=0)

    def forward(self, one_embedded_token):
        """Advance the hidden state by one token and score the next token.

        Args:
            one_embedded_token: embedding of the previous target token.

        Returns:
            Log-probabilities over the target vocabulary for the next token.
        """
        new_state          = self.RNNcell(one_embedded_token,
                                          self.hidden_state)
        tgt_token_scores   = self.output_linear(new_state)
        tgt_token_logprobs = self.logsoftmax(tgt_token_scores)
        # Remember the new state for the next call.
        self.hidden_state  = new_state
        return tgt_token_logprobs

In [None]:
class TrainingDecoder(nn.Module):
    """Decoder used during training: decodes step by step and sums the loss.

    Args:
        embed_dim: size of each target-token embedding.
        hidden_dim: size of the decoder hidden state; the encoder context is
            written directly into that state, so the two sizes must agree.
    """
    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.tgt_embedding = nn.Embedding(len(tgt_vocab), embed_dim)
        self.RNNcell       = DecoderRNNCell(embed_dim, hidden_dim)

    def forward(self, context, tgt_tokens, teacher_forcing=False):
        """Decode one sentence and accumulate its negative log-likelihood.

        Args:
            context: encoder output, used as the initial decoder hidden state.
            tgt_tokens: target token indices, starting with the start marker
                and ending with the end marker.
            teacher_forcing: if False (default, the original behaviour) each
                step is fed the model's own previous prediction and decoding
                stops as soon as the end marker is predicted. If True each
                step is fed the ground-truth previous token and every target
                position is scored.

        Returns:
            ``(translated_tokens, sentence_loss)``: the list of predicted
            tokens (beginning with ``START_Token``) and the sum over the
            visited positions of ``-logprob(correct token)``.
        """
        self.RNNcell.hidden_state = context
        translated_tokens = [START_Token]
        sentence_loss = 0
        for idx in range(len(tgt_tokens) - 1):
            if teacher_forcing:
                previous_token = tgt_tokens[idx]
            else:
                previous_token = translated_tokens[idx]
            embedded_token  = self.tgt_embedding(previous_token)
            logprobs        = self.RNNcell(embedded_token)
            predicted_token = logprobs.argmax()
            # Detach: the prediction is only an input to later steps, not a
            # path for gradients.
            translated_tokens.append(predicted_token.detach())

            correct_token  = tgt_tokens[idx + 1]
            token_loss     = -logprobs[correct_token]
            sentence_loss += token_loss

            # Without teacher forcing the model's own end marker ends decoding;
            # with it, the inputs come from tgt_tokens, so keep scoring.
            if not teacher_forcing and predicted_token == END_Token:
                break
        return translated_tokens, sentence_loss

In [None]:
class TrainingTranslator(nn.Module):
    """Training-only translator: an Encoder feeding a TrainingDecoder."""

    def __init__(self, embed_dim, hidden_dim, encoder_layers):
        super().__init__()
        self.encoder = Encoder(embed_dim, hidden_dim, encoder_layers)
        self.decoder = TrainingDecoder(embed_dim, hidden_dim)

    def forward(self, src_tokens, tgt_tokens):
        """Encode the source sentence, then decode it against the target."""
        return self.decoder(self.encoder(src_tokens), tgt_tokens)

# Eval Mode

In [None]:
class Decoder(TrainingDecoder):
    """Decoder that trains like TrainingDecoder and decodes greedily in eval mode."""

    def __init__(self, embed_dim, hidden_dim):
        super().__init__(embed_dim, hidden_dim)

    def forward(self, context, tgt_tokens=None, max_tokens=10):
        """Training mode: defer to TrainingDecoder.forward (tokens and loss).

        Eval mode: ignore ``tgt_tokens`` and return only the predicted tokens,
        emitting at most ``max_tokens`` tokens after the start marker.
        """
        if self.training:
            return super().forward(context, tgt_tokens)

        with torch.no_grad():
            self.RNNcell.hidden_state = context
            translated_tokens = [START_Token]
            for _ in range(max_tokens):
                # The most recent token is the input of the next step.
                step_input      = self.tgt_embedding(translated_tokens[-1])
                step_logprobs   = self.RNNcell(step_input)
                next_token      = step_logprobs.argmax()
                translated_tokens.append(next_token.detach())
                if next_token == END_Token:
                    break
        return translated_tokens

In [None]:
class Translator(nn.Module):
    """Encoder + Decoder pair that works in both training and eval mode."""

    def __init__(self, embed_dim, hidden_dim, encoder_layers):
        super().__init__()
        self.encoder = Encoder(embed_dim, hidden_dim, encoder_layers)
        self.decoder = Decoder(embed_dim, hidden_dim)

    def forward(self, src_tokens, tgt_tokens=None):
        """Encode the source; decode with the targets only while training."""
        context = self.encoder(src_tokens)
        if not self.training:
            return self.decoder(context)
        return self.decoder(context, tgt_tokens)

# Training

In [None]:
def iterate_one_pair(src_tokens, tgt_tokens):
    """Run one optimisation step on a single sentence pair.

    Uses the notebook-level ``model`` and ``optimizer``. Returns the pair's
    loss, detached from the autograd graph.
    """
    model.train()
    optimizer.zero_grad()
    result = model(src_tokens, tgt_tokens)
    # The model returns (predicted tokens, loss); only the loss is needed here.
    loss = result[1]
    loss.backward()
    optimizer.step()
    return loss.detach()

In [None]:
# Small model: 50-dim embeddings, 50-dim hidden state, 2 encoder LSTM layers.
model = Translator(embed_dim=50, hidden_dim=50, encoder_layers=2)
# AdamW with its default hyper-parameters.
optimizer = torch.optim.AdamW(model.parameters())

In [None]:
#overfit a small batch to check if learning _can_ occur
num_samples, epochs = 10, 200
for epoch in range(epochs):
  # Sum of the per-pair losses over the first `num_samples` pairs.
  batch_loss_agg = torch.zeros(1)
  for idx in range(num_samples):
    batch_loss_agg += iterate_one_pair(src_tokens[idx], tgt_tokens[idx])
  epoch_loss = batch_loss_agg / num_samples
  # Report only every 20th epoch.
  if epoch % 20:
    continue
  print("Epoch", epoch, " loss:", epoch_loss.item())

Epoch 0  loss: 51.42719650268555
Epoch 20  loss: 8.938413619995117
Epoch 40  loss: 7.273235321044922
Epoch 60  loss: 7.895596981048584
Epoch 80  loss: 5.613731861114502
Epoch 100  loss: 6.131002426147461
Epoch 120  loss: 3.0154340267181396
Epoch 140  loss: 1.5002210140228271
Epoch 160  loss: 0.8781587481498718
Epoch 180  loss: 0.5741285085678101


In [None]:
# Compare the model's translations with the ground truth on the pairs it was
# trained on.
model.eval()
# Fetch the index-to-string table once instead of once per token.
tgt_itos = tgt_vocab.get_itos()
with torch.no_grad():
  for idx in range(num_samples):
    a = model(src_tokens[idx])
    predicted_itos = [tgt_itos[x.item()] for x in a]
    ground_truth   = [tgt_itos[x.item()] for x in tgt_tokens[idx]]
    print(predicted_itos, ground_truth)

['<START>', 'Je', 'pense', 'que', 'ça', 'aide.', '<END>'] ['<START>', 'Je', 'pense', 'que', 'ça', 'aide.', '<END>']
['<START>', 'Il', 'était', 'trop', 'dur.', '<END>'] ['<START>', 'Il', 'était', 'trop', 'dur.', '<END>']
['<START>', 'Écrivez', 'votre', 'nom', 'en', 'majuscules.', '<END>'] ['<START>', 'Écrivez', 'votre', 'nom', 'en', 'majuscules.', '<END>']
['<START>', 'Où', 'séjournes-tu', '?', '<END>'] ['<START>', 'Où', 'séjournes-tu', '?', '<END>']
['<START>', 'Elle', 'entrouvrit', 'la', 'porte.', '<END>'] ['<START>', 'Elle', 'entrouvrit', 'la', 'porte.', '<END>']
['<START>', 'Je', 'ne', 'suis', 'pas', 'intimidée.', '<END>'] ['<START>', 'Je', 'ne', 'suis', 'pas', 'intimidée.', '<END>']
['<START>', "C'est", 'gentil.', '<END>'] ['<START>', "C'est", 'gentil.', '<END>']
['<START>', 'Personne', 'ne', 'le', 'saura.', '<END>'] ['<START>', 'Personne', 'ne', 'le', 'saura.', '<END>']
['<START>', 'Ne', 'va', 'pas', 'là.', '<END>'] ['<START>', 'Ne', 'va', 'pas', 'là.', '<END>']
['<START>', 'La', 

In [None]:
# Same comparison on the next 5 pairs, which were not part of the overfit loop.
# Set eval mode here too, so this cell does not depend on the previous cell
# having been run last.
model.eval()
# Fetch the index-to-string table once instead of once per token.
tgt_itos = tgt_vocab.get_itos()
with torch.no_grad():
  for idx in range(num_samples, num_samples+5):
    a = model(src_tokens[idx])
    predicted_itos = [tgt_itos[x.item()] for x in a]
    ground_truth   = [tgt_itos[x.item()] for x in tgt_tokens[idx]]
    print(predicted_itos,ground_truth)

['<START>', 'Il', 'était', 'trop', 'dur.', '<END>'] ['<START>', 'Nous', 'avons', 'toutes', 'nos', 'secrets.', '<END>']
['<START>', 'Je', 'pense', 'que', 'ça', 'aide.', '<END>'] ['<START>', "J'aime", 'cette', 'chambre.', '<END>']
['<START>', 'La', 'maison', 'est', 'inoccupée.', '<END>'] ['<START>', 'Nous', 'ne', 'pouvons', 'pas', "l'aider.", '<END>']
['<START>', 'Où', 'séjournes-tu', '?', '<END>'] ['<START>', 'Ils', 'sont', 'nos', 'invités.', '<END>']
['<START>', 'Où', 'séjournes-tu', '?', '<END>'] ['<START>', 'Voulez-vous', 'coucher', 'avec', 'moi', '?', '<END>']
