In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
from tqdm.auto import tqdm
import numpy as np
import spacy
import random

In [5]:
# Training hyperparameters
num_epochs = 100  # number of passes over the training iterator
learning_rate = 0.001  # step size handed to the Adam optimizer
batch_size = 64  # examples per batch produced by the BucketIterator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reproducibility: teacher forcing draws from `random`, while weight
# initialisation and dropout draw from torch. Seeding every generator up
# front makes a Restart & Run All repeatable as far as the backend allows
# (some GPU kernels may still be non-deterministic).
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [6]:
# Load the spaCy pipelines whose tokenizers back tokenize_ger / tokenize_eng below.
# NOTE(review): "de" and "en" are shortcut names; presumably this notebook targets
# spaCy 2.x. Newer spaCy releases expect full package names (for example
# "de_core_news_sm") -- confirm against the installed version before upgrading.
spacy_ger = spacy.load("de")
spacy_eng = spacy.load("en")


def tokenize_ger(text):
    """Split a German string into a list of token strings.

    Only spaCy's tokenizer is run (not the full pipeline); each token's
    surface text is returned in order.
    """
    doc = spacy_ger.tokenizer(text)
    return [token.text for token in doc]


def tokenize_eng(text):
    """Split an English string into a list of token strings.

    Only spaCy's tokenizer is run (not the full pipeline); each token's
    surface text is returned in order.
    """
    doc = spacy_eng.tokenizer(text)
    return [token.text for token in doc]


# Both languages get the same preprocessing: lower-cased text wrapped in
# <sos> ... <eos>; only the tokenizer differs.
field_kwargs = dict(lower=True, init_token="<sos>", eos_token="<eos>")
german = Field(tokenize=tokenize_ger, **field_kwargs)
english = Field(tokenize=tokenize_eng, **field_kwargs)

# German is the source side (.de), English the target side (.en).
train_data, valid_data, test_data = Multi30k.splits(
    exts=(".de", ".en"), fields=(german, english)
)

# Vocabularies come from the training split only: at most 10000 entries,
# keeping tokens seen at least twice.
for field in (german, english):
    field.build_vocab(train_data, max_size=10000, min_freq=2)

# Examples are sorted by source length inside each batch.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=batch_size,
    sort_within_batch=True,
    sort_key=lambda example: len(example.src),
    device=device,
)

downloading training.tar.gz


training.tar.gz: 100%|██████████| 1.21M/1.21M [00:00<00:00, 1.33MB/s]


downloading validation.tar.gz


validation.tar.gz: 100%|██████████| 46.3k/46.3k [00:00<00:00, 232kB/s]


downloading mmt_task1_test2016.tar.gz


mmt_task1_test2016.tar.gz: 100%|██████████| 66.2k/66.2k [00:00<00:00, 223kB/s]


In [8]:
class Encoder(nn.Module):
    """Embeds source token indices and summarises them with a stacked LSTM.

    Only the LSTM's final hidden and cell states are returned; the per-step
    outputs are discarded.
    """

    def __init__(self, input_size, embedding_size, hidden_size, num_layers, p):
        """Build the embedding table and the batch-first LSTM.

        Args:
            input_size: number of rows in the embedding table.
            embedding_size: width of each embedding vector.
            hidden_size: width of the LSTM hidden state.
            num_layers: number of stacked LSTM layers.
            p: dropout probability, applied to the embeddings and passed to
                the LSTM as its inter-layer dropout.
        """
        super().__init__()
        self.dropout = nn.Dropout(p)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=p,
            batch_first=True,
        )

    def forward(self, x):
        """Encode a batch of index sequences.

        Args:
            x: integer tensor of token indices; the LSTM is batch-first, so
                the batch dimension comes first.

        Returns:
            Tuple ``(hidden, cell)`` holding the LSTM's final states.
        """
        embedded = self.dropout(self.embedding(x))
        _, final_state = self.rnn(embedded)
        hidden, cell = final_state
        return hidden, cell


class Decoder(nn.Module):
    """Single-step LSTM decoder: one token per sequence in, one score vector out."""

    def __init__(
        self, input_size, embedding_size, hidden_size, output_size, num_layers, p
    ):
        """Build the embedding table, the batch-first LSTM and the output layer.

        Args:
            input_size: number of rows in the embedding table.
            embedding_size: width of each embedding vector.
            hidden_size: width of the LSTM hidden state.
            output_size: number of scores produced per step.
            num_layers: number of stacked LSTM layers.
            p: dropout probability, applied to the embeddings and passed to
                the LSTM as its inter-layer dropout.
        """
        super().__init__()
        self.dropout = nn.Dropout(p)
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_size, embedding_size)
        self.rnn = nn.LSTM(
            input_size=embedding_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=p,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden, cell):
        """Advance the decoder by one time step.

        Args:
            x: 1-D tensor with one token index per sequence in the batch.
            hidden: LSTM hidden state from the previous step (or the encoder).
            cell: LSTM cell state from the previous step (or the encoder).

        Returns:
            Tuple ``(predictions, hidden, cell)``: the per-sequence scores with
            the length-1 time axis removed, plus the updated LSTM states.
        """
        # The LSTM wants a time axis, so each token becomes a length-1 sequence.
        step_input = x.unsqueeze(1)
        embedded = self.dropout(self.embedding(step_input))
        step_output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        scores = self.fc(step_output)
        return scores.squeeze(1), hidden, cell


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one step at a time.

    The encoder's final ``(hidden, cell)`` state seeds the decoder. At every
    step the decoder is fed either the ground-truth token (teacher forcing)
    or its own previous argmax prediction.
    """

    def __init__(self, encoder, decoder):
        """Store the two sub-modules.

        Args:
            encoder: module whose call returns ``(hidden, cell)``.
            decoder: module called as ``decoder(x, hidden, cell)`` and
                returning ``(scores, hidden, cell)``; it must expose its
                output layer as ``fc`` so the score width can be read.
        """
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_force_ratio=0.5):
        """Run the encoder once, then unroll the decoder over the target.

        Args:
            source: batch-first tensor of source token indices.
            target: batch-first tensor of target token indices; column 0 is
                used as the first decoder input.
            teacher_force_ratio: probability, drawn independently at each
                step, of feeding the ground-truth token instead of the
                decoder's own best guess.

        Returns:
            Tensor of shape ``(batch, target_len, vocab)`` holding the decoder
            scores; position 0 along the time axis is left as zeros because
            decoding starts at step 1.
        """
        batch_size = source.shape[0]
        target_len = target.shape[1]
        # Read the score width from the decoder itself rather than from a
        # notebook-level vocabulary object, so the module is self-contained.
        target_vocab_size = self.decoder.fc.out_features

        # Allocate on the same device as the inputs instead of a global device.
        outputs = torch.zeros(
            batch_size, target_len, target_vocab_size, device=source.device
        )

        hidden, cell = self.encoder(source)

        # First decoder input: column 0 of the target (the <sos> token).
        x = target[:, 0]  # shape is (N)

        for t in range(1, target_len):
            # hidden/cell come from the encoder on the first step and from
            # the previous decoder step afterwards.
            output, hidden, cell = self.decoder(x, hidden, cell)

            # Store the scores for position t.
            outputs[:, t, :] = output

            # Index of the highest-scoring word for each sequence.
            best_guess = output.argmax(1)

            # Teacher forcing: sometimes feed the true token, otherwise the guess.
            x = target[:, t] if random.random() < teacher_force_ratio else best_guess

        return outputs

In [9]:
# Model hyperparameters
load_model = False
encoder_embedding_size = 300
decoder_embedding_size = 300
hidden_size = 1024
num_layers = 2
enc_dropout = 0.5
dec_dropout = 0.5

# Sizes derived from the vocabularies: the encoder embeds German tokens,
# the decoder both embeds and scores English tokens.
input_size_encoder = len(german.vocab)
input_size_decoder = len(english.vocab)
output_size = len(english.vocab)

encoder_net = Encoder(
    input_size=input_size_encoder,
    embedding_size=encoder_embedding_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    p=enc_dropout,
).to(device)

decoder_net = Decoder(
    input_size=input_size_decoder,
    embedding_size=decoder_embedding_size,
    hidden_size=hidden_size,
    output_size=output_size,
    num_layers=num_layers,
    p=dec_dropout,
).to(device)

model = Seq2Seq(encoder=encoder_net, decoder=decoder_net).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Target positions holding the English <pad> index are excluded from the loss.
pad_idx = english.vocab.stoi["<pad>"]
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [10]:
# spaCy pipelines loaded by translate_sentence, keyed by model name, so that
# repeated calls (one per epoch in the training loop) do not reload from disk.
_SPACY_MODEL_CACHE = {}


def _load_spacy_model(name):
    """Return the spaCy pipeline called `name`, loading it on first use only."""
    if name not in _SPACY_MODEL_CACHE:
        _SPACY_MODEL_CACHE[name] = spacy.load(name)
    return _SPACY_MODEL_CACHE[name]


def translate_sentence(model, sentence, german, english, device, max_length=50):
    """Greedily translate one German sentence with a trained seq2seq model.

    The function does not switch the model's train/eval mode; call
    ``model.eval()`` first if dropout should be disabled.

    Args:
        model: object exposing ``encoder(tensor) -> (hidden, cell)`` and
            ``decoder(token, hidden, cell) -> (scores, hidden, cell)``.
        sentence: either a raw string (tokenised here with spaCy) or an
            already tokenised sequence of strings.
        german: source field; supplies ``init_token``, ``eos_token`` and
            ``vocab.stoi``.
        english: target field; supplies ``init_token``, ``eos_token``,
            ``vocab.stoi`` and ``vocab.itos``.
        device: device on which the input tensors are created.
        max_length: maximum number of decoding steps.

    Returns:
        List of English token strings, starting with the start-of-sentence
        token and ending with the end-of-sentence token if one was produced
        within ``max_length`` steps.
    """
    # Lower-case everything, matching how the vocabularies were built.
    if isinstance(sentence, str):
        # Only a raw string needs spaCy, and only its tokenizer.
        spacy_ger = _load_spacy_model("de")
        tokens = [token.text.lower() for token in spacy_ger.tokenizer(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Wrap the sentence in the source field's start / end markers.
    tokens = [german.init_token] + tokens + [german.eos_token]

    # Map each German token to its vocabulary index.
    text_to_indices = [german.vocab.stoi[token] for token in tokens]

    # Batch of one, batch-first: shape (1, seq_len).
    sentence_tensor = torch.LongTensor(text_to_indices).unsqueeze(0).to(device)

    # Look the target-side markers up once, via the field, mirroring how the
    # source-side markers are obtained above.
    sos_idx = english.vocab.stoi[english.init_token]
    eos_idx = english.vocab.stoi[english.eos_token]

    outputs = [sos_idx]

    with torch.no_grad():
        hidden, cell = model.encoder(sentence_tensor)

        for _ in range(max_length):
            previous_word = torch.LongTensor([outputs[-1]]).to(device)
            output, hidden, cell = model.decoder(previous_word, hidden, cell)
            best_guess = output.argmax(1).item()
            outputs.append(best_guess)

            # The model predicted the end of the sentence.
            if best_guess == eos_idx:
                break

    return [english.vocab.itos[idx] for idx in outputs]


In [None]:
# Fixed German sentence translated before every epoch as a qualitative check.
sentence = "ein boot mit mehreren männern darauf wird von einem großen pferdegespann ans ufer gezogen."
# Reference: a boat with several men on it is pulled ashore by a large team of horses.

for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    # Sample translation in eval mode, before this epoch's updates.
    model.eval()
    translated_sentence = translate_sentence(
        model, sentence, german, english, device, max_length=50
    )
    print(f"Translated example sentence: \n {translated_sentence}")

    # Back to training mode for the optimisation pass.
    model.train()
    for batch_idx, batch in enumerate(train_iterator):
        # The model is batch-first, so both tensors are transposed before use
        # (batches presumably arrive sequence-first -- confirm against the Fields).
        inp_data = batch.src.T.to(device)
        target = batch.trg.T.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward prop
        output = model(inp_data, target)

        # Drop time step 0 (the <sos> slot, never predicted) and flatten the
        # batch and time axes so the loss sees (positions, vocab) vs (positions,).
        vocab_dim = output.shape[2]
        output = output[:, 1:].reshape(-1, vocab_dim)
        target = target[:, 1:].reshape(-1)

        loss = criterion(output, target)

        # Back prop, with the gradient norm clipped to 1 before the update.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()

[Epoch 0 / 100]
Translated example sentence: 
 ['<sos>', 'follow', 'follow', 'follow', 'tips', 'tips', 'pork', 'bronco', 'bronco', 'cans', 'cook', 'contains', 'cook', 'approximately', 'garden', 'garden', 'row', 'row', 'row', 'canvas', 'canvas', 'swamp', 'swamp', 'mets', 'stories', 'reason', 'reason', 'kissing', 'reception', 'players', 'flutes', 'scarf', 'sat', 'sat', 'drifts', 'armor', 'posed', 'steel', 'bows', 'tips', 'wild', 'making', 'making', 'reception', 'pork', 'pork', 'removes', 'contains', 'reception', 'seattle', 'bring']
[Epoch 1 / 100]
Translated example sentence: 
 ['<sos>', 'a', 'black', 'player', 'is', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', '.', '<eos>']
[Epoch 2 / 100]
Translated example sentence: 
 ['<sos>', 'a', 'football', 'player', 'with', 'a', '<unk>', '<unk>', '<unk>', 'a', 'a', 'a', 'a', 'a', '.', '<eos>']
[Epoch 3 / 100]
Translated example sentence: 
 ['<sos>', 'a', '<unk>', 'with', 'a', '<unk>', 'of', 'a', 'a', 'a', 'a', 'a', 'a', '.', '<eos>']
[Epoch 4 / 1