This notebook implements an RNN encoder / decoder for machine translation, inspired by:

[Learning phrase representations using RNN encoder-decoder for statistical machine translation](https://arxiv.org/abs/1406.1078)
K Cho, B Van Merriënboer, C Gulcehre, D Bahdanau, F Bougares, H Schwenk, Y Bengio
arXiv preprint arXiv:1406.1078, 2014•arxiv.org

and 

[Sequence to Sequence Learning with Neural Networks](https://arxiv.org/abs/1409.3215)
I Sutskever - arXiv preprint arXiv:1409.3215, 2014 - jeremy-su1.github.io

The implementation is heavily influenced by:

https://colab.research.google.com/drive/1GBC7eLlEM-HqKLUuMcFIQdVuYXzLoS_P?usp=sharing

Importantly, it provides the English to Italian data set I use

In [242]:
# Imports
import itertools
import os

import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim

from letsbuildmodels.devices import get_device
from nltk.lm.vocabulary import Vocabulary
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.nn.utils.rnn import pad_sequence
from torch.nn.functional import one_hot
from torcheval.metrics.functional import bleu_score

In [102]:
# Data pre-processing

# Download the data
local_path = os.path.join(os.getcwd(), "data", "eng_ita_v2.txt")

def download_file_if_not_exists():
    url = "https://raw.githubusercontent.com/kyuz0/llm-chronicles/main/datasets/eng_ita_v2.txt"
    
    directory = os.path.dirname(local_path)
    os.makedirs(directory, exist_ok=True)
    
    if not os.path.exists(local_path):
        print(f"Downloading file from {url} to {local_path}...")
        response = requests.get(url, stream=True)
        response.raise_for_status()  # Raise an error for bad HTTP responses
        with open(local_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
        print("Download complete.")
    else:
        print(f"File already exists at {local_path}. No download needed.")

download_file_if_not_exists()


# Read the data
def read_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.read().strip().split('\n')
    pairs = [tuple([s for s in line.split(' -> ')]) for line in lines]
    return pairs

pairs = read_data(local_path)
print(f"File contains {len(pairs)} translations")

# Build vocabularies
def build_vocab(pairs):
    eng_tokens = list(itertools.chain.from_iterable([word_tokenize(eng) for (eng, _) in pairs]))
    ita_tokens = list(itertools.chain.from_iterable([word_tokenize(ita) for (_, ita) in pairs]))
    eng_vocab = Vocabulary(eng_tokens)
    ita_vocab = Vocabulary(ita_tokens)
    return eng_vocab, ita_vocab

eng_vocab, ita_vocab = build_vocab(pairs)

print('English vocabulary size:', len(eng_vocab))
print('Italian vocabulary size:', len(ita_vocab))

PAD_TOKEN = "<PAD>"
EOS_TOKEN = "<EOS>"
SOS_TOKEN = "<SOS>"
UNK_TOKEN = "<UNK>"

# Creating integer <-> word mapping
class WordMapping:
    def __init__(self, vocab):
        self.word_to_int = {}
        self.int_to_word = {}
        word_counts = [(word, vocab[word]) for word in vocab]
        sorted_word_counts = sorted(word_counts, key=lambda t: t[1], reverse=True)
        sorted_word_counts = sorted_word_counts + [(PAD_TOKEN, 1), (EOS_TOKEN, 1), (SOS_TOKEN, 1)]
        for i, (word, _) in enumerate(sorted_word_counts):
            self.word_to_int[word] = i
            self.int_to_word[i] = word

    def __getitem__(self, key):
        if type(key) == str:
            if key in self.word_to_int:
                return self.word_to_int[key]
            elif key.lower() in self.word_to_int:
                return self.word_to_int[key.lower()]
            else:
                return self.word_to_int[UNK_TOKEN]                
        elif type(key) == int:
            return self.int_to_word[key]
        else:
            raise KeyError(f"Invalid key type: {type(key)}")

    def __len__(self):
        return len(self.word_to_int)

eng_mapping = WordMapping(eng_vocab)
ita_mapping = WordMapping(ita_vocab)

File already exists at /Users/jamescataldo/Code/letsbuildmodels/notebooks/encdec/data/eng_ita_v2.txt. No download needed.
File contains 120746 translations
English vocabulary size: 4894
Italian vocabulary size: 13675


In [145]:
# Creating datasets and loaders
class TranslationDataset(Dataset):
    def __init__(self):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        eng, ita = self.pairs[idx]
        eng_tensor = torch.tensor([eng_mapping[word] for word in word_tokenize(eng)]
                                  + [eng_mapping[EOS_TOKEN]], dtype=torch.long)
        ita_tensor = torch.tensor([ita_mapping[word] for word in word_tokenize(ita)]
                                  + [ita_mapping[EOS_TOKEN]], dtype=torch.long)
        return eng_tensor, ita_tensor

# Custom collate function to handle padding
def collate_fn(batch):
    eng_batch, ita_batch = zip(*batch)
    eng_batch_padded = pad_sequence(eng_batch, batch_first=True, padding_value=eng_mapping[PAD_TOKEN])
    ita_batch_padded = pad_sequence(ita_batch, batch_first=True, padding_value=ita_mapping[PAD_TOKEN])
    return eng_batch_padded, ita_batch_padded

# Create the DataLoader
translation_dataset = TranslationDataset()
translations = len(translation_dataset)
train_dataset, test_dataset = random_split(translation_dataset, [translations // 2, translations // 2])
batch_size = 64
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
    drop_last=True
)

batches = len(train_dataloader)
print(f"Training translations: {translations}")
print(f"Number of batches: {batches}")

Training translations: 120746
Number of batches: 943


In [227]:
# Build the models
eng_vocab_size = len(eng_vocab)
ita_vocab_size = len(ita_vocab)
embed_size = 256
hidden_size = 512
num_layers = 1

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(eng_vocab_size, embed_size)
        self.gru = nn.GRU(
            embed_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

    def forward(self, x):
        embedded = self.embedding(x)
        _, hidden = self.gru(embedded)
        return hidden

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(ita_vocab_size, embed_size)
        self.gru = nn.GRU(embed_size,
                           hidden_size,
                           num_layers=num_layers,
                           batch_first=True)
        self.linear = nn.Linear(hidden_size, ita_vocab_size)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x, hidden):
        embedded = self.embedding(x)
        decoder_gru, _ = self.gru(embedded, hidden)
        out = self.linear(decoder_gru)
        return self.softmax(out)

class Translator(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, enc_x, dec_x):
        hidden = self.encoder(enc_x)
        out = self.decoder(dec_x, hidden)
        return out

model = Translator()
device = get_device()
model.to(device)
print(model)

Translator(
  (encoder): Encoder(
    (embedding): Embedding(4894, 256)
    (gru): GRU(256, 512, batch_first=True)
  )
  (decoder): Decoder(
    (embedding): Embedding(13675, 256)
    (gru): GRU(256, 512, batch_first=True)
    (linear): Linear(in_features=512, out_features=13675, bias=True)
    (softmax): Softmax(dim=2)
  )
)


In [241]:
def train():
    loss_fn = nn.BLEUScore(ignore_index=ita_mapping[PAD_TOKEN])
    optimizer = optim.AdamW(model.parameters())
    num_epochs = 10
    
    model.train()
    
    for epoch in range(num_epochs):
        epoch_loss = 0
        count = 0
        for i, (encoder_input, target) in enumerate(train_dataloader):
            encoder_input, target = encoder_input.to(device), target.to(device)
            decoder_input = torch.empty_like(target)
            decoder_input[:, 0] = ita_mapping[SOS_TOKEN]
            decoder_input[:, 1:] = target[:, :-1]
            
            optimizer.zero_grad()

            output = model(encoder_input, decoder_input)

            loss = loss_fn(output.permute(0, 2, 1), target)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() / target.size()[1]
            count += 1
            if i % 10 == 0:
                print(f"Epoch {epoch}, Batch {i}/{batches}, Loss: {epoch_loss / count:.4f}", end="\r")
            
        print(f"Epoch {epoch}, Loss: {epoch_loss / count:.4f}                             ")
                
train()

Epoch 0, Loss: 0.63783, Loss: 0.6378
Epoch 1, Loss: 0.63123, Loss: 0.6311
Epoch 2, Loss: 0.62843, Loss: 0.6283
Epoch 3, Loss: 0.62963, Loss: 0.6293
Epoch 4, Loss: 0.62903, Loss: 0.6290
Epoch 5, Loss: 0.62443, Loss: 0.6245
Epoch 6, Loss: 0.62253, Loss: 0.6229
Epoch 7, Loss: 0.62153, Loss: 0.6216
Epoch 8, Loss: 0.62123, Loss: 0.6212
Epoch 9, Loss: 0.62273, Loss: 0.6225


In [271]:
specials = {
    PAD_TOKEN,
    EOS_TOKEN,
    SOS_TOKEN,
    UNK_TOKEN,
}

def to_ita_sentence(tensor):
    ita = [ita_mapping[x.item()] for x in tensor]
    strs = [y for y in ita if y not in specials]
    return " ".join(strs)

def to_eng_sentence(tensor):
    eng = [eng_mapping[x.item()] for x in tensor]
    strs = [y for y in eng if y not in specials]
    return " ".join(strs).replace(" '", "'")

def test(print_translations=False):
    model.eval()
    
    epoch_loss = 0
    count = 0
    with torch.no_grad():
        for i, (encoder_input, target) in enumerate(test_dataloader):
            encoder_input, target = encoder_input.to(device), target.to(device)
            decoder_input = torch.empty_like(target)
            decoder_input[:, 0] = ita_mapping[SOS_TOKEN]
            decoder_input[:, 1:] = target[:, :-1]
            
            output = model(encoder_input, decoder_input)
    
            for batch in range(batch_size):
                input_tokens = encoder_input[batch]
                input_str = to_eng_sentence(input_tokens)
                output_tokens = torch.argmax(output[batch], dim=1)
                output_str = to_ita_sentence(output_tokens)            
                target_tokens = target[batch]
                target_str = to_ita_sentence(target_tokens) 
                if print_translations:
                    print(f"English: {input_str}")
                    print(f"Desired Italian: {target_str}")
                    print(f"Generated Italian: {output_str}")
                    print()
    
                loss = bleu_score(output_str, [target_str], n_gram=min(2, len(output_str)))
        
                epoch_loss += loss.item()
                count += 1
            if i % 10 == 0:
                print(f"Batch {i}/{batches}, BLEU: {epoch_loss / count:.4f}", end="\r")
        print(f"BLEU: {epoch_loss / count:.4f}                             ")
                
test()

Batch 190/943, BLEU: 0.0626

KeyboardInterrupt: 