<a href="https://colab.research.google.com/github/TaskoudisDimi/Computational-Intelligence-and-Statistical-Learning/blob/master/TrainedModels/Translation/Translator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [20]:
#Theory
# The goal of this class of models is to map an input sequence (here, a sentence) to an output sequence,
# where the two lengths can differ. If a sentence in the input language has 8 words and the same sentence
# in the target language has 4 words, a high-quality translator should infer that and produce a shorter output.


# Seq2Seq translators typically share a common framework. The three primary components of any Seq2Seq translator
# are the encoder and decoder networks and an intermediary vector encoding between them.



# The encoder network is a series of recurrent (RNN) units. It uses them to sequentially encode the elements of
# the input, with the final hidden state being written to the intermediary encoder vector.


# Attention is the practice of forcing the decoder to focus on certain parts of the encoder's outputs
# through a set of weights. These attention weights are multiplied by the encoder output vectors.
# This produces a combined vector encoding that will augment the ability of the decoder to understand
# the context of the outputs it is generating, and therefore improve its predictions.
# Calculating these attention weights is done through a feed-forward attention layer, which takes the decoder's input
# and hidden state as inputs.
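The attention step described above can be sketched in plain Python. This is a minimal illustration, not the notebook's model. It scores each encoder output with a dot product against the decoder hidden state, which is a simpler stand-in for the feed-forward attention layer mentioned above. It then normalizes the scores with a softmax and returns the weighted sum of the encoder outputs as the context vector. The function names are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def softmax(scores: List[float]) -> List[float]:
    # Subtract the max score for numerical stability before exponentiating.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_product_attention(decoder_hidden: Vector,
                          encoder_outputs: List[Vector]) -> Tuple[List[float], Vector]:
    # One score per encoder position: how well that position matches the decoder state.
    scores = [sum(h * e for h, e in zip(decoder_hidden, enc)) for enc in encoder_outputs]
    weights = softmax(scores)
    # Context vector = attention-weighted sum of the encoder output vectors.
    dim = len(encoder_outputs[0])
    context = [sum(w * enc[d] for w, enc in zip(weights, encoder_outputs)) for d in range(dim)]
    return weights, context


weights, context = dot_product_attention([2.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(weights, context)
```

The first and third encoder outputs match the decoder state equally well, so they receive equal and larger weights than the second output.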


# The encoder vector contains the numerical representations of the input from the encoder. If things go correctly,
# it captures all the information from the initial input sentence. This encoding vector then acts as the initial hidden
# state for the decoder network.

# The decoder network is essentially the inverse of the encoder. It takes the intermediary encoder vector as its initial
# hidden state and sequentially generates the translation. Each element in the output informs the decoder's prediction
# of the following element.



# In practice, an NMT model takes an input string in one language and creates a sequence of embeddings representing each element
# (word) of the sentence. Each RNN unit in the encoder takes the previous hidden state and a single element of the input
# embedding sequence as inputs, so each step builds on the previous one by using the previous step's hidden state
# to inform its own encoding. In addition to the words of the sentence, an end-of-sentence (EOS) tag is included
# as an element in the sequence. On the output side, generating this EOS tag is what triggers the decoder to stop
# decoding and output the translated sentence.

# The final hidden state is written to the intermediary encoder vector.
# This encoding captures as much information as possible about the input sentence in order to help the decoder
# turn it into the translation. It does this by virtue of being used as the initial hidden state
# for the decoder network.

# Using the information from the encoder vector, each recurrent unit in the decoder accepts a hidden state from the previous
# unit and produces an output as well as its own hidden state. At each step, the decoder uses the hidden state from the
# previous step to predict the next element of the sequence. The final output is thus the result of these step-wise
# predictions, one per element of the translated sentence. The output length is independent of the input sentence's
# length thanks to the end-of-sentence tag, which tells the decoder when to stop adding tokens to the sentence.



# Encoder: The input sentence in the source language is passed through an encoder neural network, typically implemented
# as a recurrent neural network (RNN) or a Transformer. An RNN encoder compresses the input sequence into a fixed-size
# representation (context vector) that captures the information in the source sentence. A Transformer encoder, which is
# what this notebook uses, instead produces one vector per source token (the "memory").

# Decoder: In an RNN model, the context vector produced by the encoder is passed as the initial hidden state to the decoder.
# A Transformer decoder instead attends to the encoder memory through cross-attention. In both cases the decoder generates
# the output sentence in the target language one token at a time, conditioning each token on the previously generated
# tokens and on the encoder's output.

# Training: These models are trained on a parallel corpus of source and target language sentences. During training,
# the model learns to minimize the cross-entropy between its predicted next-token distributions and the actual target tokens.



In [None]:
!pip install -U spacy
!pip install -U torchdata
!pip install portalocker

!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm

In [16]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List


# We need to modify the URLs for the dataset since the links to the original dataset are broken
# Refer to https://github.com/pytorch/text/issues/1756#issuecomment-1163664163 for more info
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Place-holders
token_transform = {}
vocab_transform = {}

In [17]:
# Tokenization:
# The code begins by setting up tokenizers for the source and target languages. Tokenization is the process of breaking text into individual words or tokens. In this case, the code uses the spaCy tokenizers for the German and English languages.
# token_transform[SRC_LANGUAGE] is set to a tokenizer for the German language ('de_core_news_sm'), and token_transform[TGT_LANGUAGE] is set to a tokenizer for the English language ('en_core_web_sm').

# Helper Function for Yielding Tokens:
# The code defines a helper function named yield_tokens which takes an iterable data iterator and a language as input.
# Inside the function, it uses the previously defined tokenizers to tokenize the text samples from the data iterator for the specified language.
# It yields the tokenized text as a list of strings.

# Special Symbols and Indices:
# The code defines special symbols and their corresponding indices. These symbols are often used in machine translation tasks:
# UNK_IDX (Unknown Index): Used for tokens that are not in the vocabulary.
# PAD_IDX (Padding Index): Used for padding sequences to the same length.
# BOS_IDX (Beginning of Sentence Index): Marks the start of a sentence.
# EOS_IDX (End of Sentence Index): Marks the end of a sentence.
# These symbols are assigned the integer indices 0 through 3.

token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

# Vocabulary Building:
# The code iterates over the source and target languages, performing the following steps for each language:
# It creates a training data iterator (train_iter) using the Multi30k dataset, which contains sentence pairs in the specified language pair (e.g., German and English).
# It uses the build_vocab_from_iterator function from torchtext to build a vocabulary for the language.
# The vocabulary is built from the tokenized text samples obtained using the yield_tokens function.
# The min_freq=1 argument ensures that all tokens are included in the vocabulary, and specials=special_symbols includes the special symbols defined earlier in the vocabulary.
# special_first=True ensures that the special symbols are placed at the beginning of the vocabulary.
# Set Default Index:
# After building the vocabulary for each language, the code sets the UNK_IDX as the default index for each vocabulary. This means that if a token is not found in the vocabulary, the index 0 (UNK_IDX) will be used by default.
# This code snippet is an essential part of preparing text data for machine translation tasks, ensuring that text is tokenized, and vocabularies are built for both the source and target languages, with special symbols appropriately configured. These tokenized data and vocabularies can be used to train and evaluate machine translation models effectively.


for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # Training data Iterator
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Create torchtext's Vocab object
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)
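To make the vocabulary step concrete, here is a pure-Python sketch of what the cell above does with build_vocab_from_iterator and set_default_index. It counts token frequencies, puts the special symbols first, drops tokens below min_freq, and falls back to UNK_IDX for unknown tokens. This is an approximation written for illustration only, and the class and function names are hypothetical. In particular, ordering regular tokens by descending frequency with alphabetical tie-breaking reflects our reading of torchtext's behaviour, not a documented guarantee.

```python
from collections import Counter
from typing import Dict, Iterable, List

SPECIALS = ['<unk>', '<pad>', '<bos>', '<eos>']  # same order as special_symbols above
UNK_IDX = 0  # '<unk>' sits at index 0 because the specials come first


class SimpleVocab:
    """Hypothetical, minimal stand-in for a torchtext Vocab with a default index set."""

    def __init__(self, stoi: Dict[str, int], default_index: int):
        self.stoi = stoi
        self.itos = sorted(stoi, key=stoi.get)  # index -> token
        self.default_index = default_index

    def __len__(self) -> int:
        return len(self.stoi)

    def __call__(self, tokens: List[str]) -> List[int]:
        # Unknown tokens fall back to the default index instead of raising an error.
        return [self.stoi.get(tok, self.default_index) for tok in tokens]

    def lookup_tokens(self, indices: List[int]) -> List[str]:
        return [self.itos[i] for i in indices]


def build_simple_vocab(token_lists: Iterable[List[str]], min_freq: int = 1) -> SimpleVocab:
    counter: Counter = Counter()
    for tokens in token_lists:
        counter.update(tokens)
    # Most frequent tokens first; ties broken alphabetically (our reading of torchtext's ordering).
    ordered = sorted(counter.items(), key=lambda kv: (-kv[1], kv[0]))
    itos = list(SPECIALS)
    itos += [tok for tok, freq in ordered if freq >= min_freq and tok not in SPECIALS]
    return SimpleVocab({tok: idx for idx, tok in enumerate(itos)}, default_index=UNK_IDX)


vocab = build_simple_vocab([["ein", "Hund"], ["ein", "Mann"]])
print(vocab(["ein", "Katze"]))  # "Katze" was never seen, so it maps to UNK_IDX: [4, 0]
```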

In [18]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Positional Encoding:
# The PositionalEncoding class is defined as a helper module. It adds positional encoding to the token embeddings
# to provide information about the word order in the input sequences.
# Positional encoding is used to handle the sequence order since the Transformer architecture does not inherently consider
# order. It helps the model understand the position of words in a sequence.
# The class uses trigonometric functions to create a sinusoidal positional embedding for each position in the input sequence.
# Token Embedding:
# The TokenEmbedding class is another helper module. It converts input indices into token embeddings.
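# It uses an embedding layer to map token indices to continuous vector representations. The embeddings are multiplied by
# math.sqrt(self.emb_size). This scaling is a convention from the original Transformer paper and is usually explained as
# keeping the embeddings on a scale comparable to the positional encodings that are added to them.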
# Seq2Seq Transformer Network:
# The Seq2SeqTransformer class defines the main architecture of the sequence-to-sequence Transformer model.
# It takes several hyperparameters as input, such as the number of encoder and decoder layers, embedding size (emb_size),
# number of attention heads (nhead), vocabulary sizes for the source and target languages, and other settings.
# The class consists of the following components:
# Transformer: The core Transformer model is created with the specified configuration.
# generator: A linear layer that projects each decoder output vector to logits over the target vocabulary.
# src_tok_emb and tgt_tok_emb: Token embedding layers for the source and target languages.
# positional_encoding: The positional encoding module is used to add positional information to the token embeddings.
# The forward method takes the source and target sequences, source and target masks, and padding masks as inputs and
# performs the forward pass of the model.
# It returns the output from the Transformer model, which is then passed through the generator to obtain the final output.

# Encoder and Decoder Functions:
# The encode and decode methods are used to separately encode the source sequence and decode the target sequence.
# The encoder encodes the source sequence, and the decoder generates the target sequence using the memory from the encoder.
# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.


class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)
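PositionalEncoding implements the sinusoidal encoding PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), where d is emb_size. The den tensor in the class is exactly 10000^(-2i/d), written as exp(-2i · ln(10000) / d). The pure-Python sketch below computes the same table for a small emb_size so the values can be inspected by hand. Note that positional_encoding_table is a hypothetical helper, not part of the notebook.

```python
import math
from typing import List


def positional_encoding_table(maxlen: int, emb_size: int) -> List[List[float]]:
    # Assumes an even emb_size, as the notebook's PositionalEncoding does.
    table = []
    for pos in range(maxlen):
        row = [0.0] * emb_size
        for two_i in range(0, emb_size, 2):
            # Same quantity as `den` in PositionalEncoding: 10000^(-2i / emb_size).
            den = math.exp(-two_i * math.log(10000) / emb_size)
            row[two_i] = math.sin(pos * den)
            row[two_i + 1] = math.cos(pos * den)
        table.append(row)
    return table


pe = positional_encoding_table(maxlen=4, emb_size=4)
# Position 0 is always [0, 1, 0, 1, ...] because sin(0) = 0 and cos(0) = 1.
print(pe[0])  # [0.0, 1.0, 0.0, 1.0]
```

In the module itself, the table is reshaped to (maxlen, 1, emb_size) by unsqueeze(-2). This lets it broadcast over the batch dimension of a (seq_len, batch, emb_size) embedding tensor.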

In [None]:
# This code defines two functions for creating masks used in the context of sequence-to-sequence models,
# specifically for the Seq2Seq Transformer model previously described. These masks are essential for controlling
# attention mechanisms and handling padding in the model. Let's break down the code:
# generate_square_subsequent_mask(sz):

# This function generates a square mask of size sz x sz. It's often used in self-attention mechanisms to prevent tokens
# from attending to future tokens.
# torch.triu(torch.ones(sz, sz)) keeps the ones on and above the diagonal. Comparing the result with 1 turns it into a
# boolean matrix that is True on and above the diagonal.
# Transposing it gives a matrix that is True on and below the diagonal, so row i (query position i) is True for columns j <= i.
# The mask is then converted to a float tensor: False entries become negative infinity (float('-inf')) and True entries
# become 0.0.
# Because this float mask is added to the attention scores, the -inf entries block attention to future tokens:
# each token can attend only to itself and to earlier tokens.
# create_mask(src, tgt):

# This function is responsible for creating several masks that will be used during the training and inference
# phases of the Seq2Seq Transformer model.

# It takes source (src) and target (tgt) sequences as input.

# src_seq_len and tgt_seq_len represent the lengths of the source and target sequences, respectively.

# tgt_mask:

# Calls generate_square_subsequent_mask with the length of the target sequence (tgt_seq_len) to create a mask for the
# target side. This ensures that during self-attention in the decoder, tokens cannot attend to future tokens.
# src_mask:

# Creates a square source mask of all False values (zeros cast to bool). In a boolean attention mask, True marks a blocked
# position, so an all-False mask blocks nothing. In the encoder, each token can therefore attend to every other token in the
# source sequence without restriction. Padding is handled separately by src_padding_mask.
# src_padding_mask:

# This mask marks padding tokens in the source sequence: it is True wherever the source value equals PAD_IDX.
# The batches are laid out as (seq_len, batch_size), so the result is transposed to (batch_size, seq_len), which is the
# shape nn.Transformer expects for key padding masks.
# tgt_padding_mask:

# Similar to the source padding mask, this one identifies padding tokens in the target sequence.
# The function returns these masks, which will be used in the forward pass of the Seq2Seq Transformer model. They help
# control which parts of the sequences the model pays attention to and how it handles padding tokens during processing.

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
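The mask logic above can be checked without PyTorch. This sketch rebuilds, as nested lists, two of the masks used above. The first is the float causal mask produced by generate_square_subsequent_mask, which holds 0.0 where attention is allowed and -inf where it is blocked. The second is the boolean padding mask produced in create_mask for a batch laid out as (seq_len, batch_size). The helper names are hypothetical, and the sketch assumes PAD_IDX = 1, matching the vocabulary cell.

```python
from typing import List

PAD_IDX = 1  # assumed to match the special-symbol indices defined earlier


def causal_mask(sz: int) -> List[List[float]]:
    # Row i is the query position; columns j > i (future tokens) are blocked with -inf.
    return [[0.0 if j <= i else float('-inf') for j in range(sz)] for i in range(sz)]


def padding_mask(batch: List[List[int]]) -> List[List[bool]]:
    # `batch` is (seq_len, batch_size); the result is (batch_size, seq_len), True at padding.
    seq_len, batch_size = len(batch), len(batch[0])
    return [[batch[t][b] == PAD_IDX for t in range(seq_len)] for b in range(batch_size)]


print(causal_mask(3))
# Two sequences stored as columns; the second one is padded at its last position.
print(padding_mask([[2, 2], [5, 7], [6, 3], [3, PAD_IDX]]))
```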


# Setting a Random Seed:
# It sets a random seed for reproducibility, so that repeated runs start from the same initial weights and random draws.
# This is useful for debugging and experimentation, although some GPU operations can still be nondeterministic.
# Defining Model Hyperparameters:
# Several hyperparameters for the model are defined:
# SRC_VOCAB_SIZE and TGT_VOCAB_SIZE: The sizes of the source and target vocabularies, respectively.
# These values are determined based on the vocabularies created earlier.
# EMB_SIZE: The embedding size for tokens in the model (e.g., word embeddings).
# NHEAD: The number of attention heads in the multi-head self-attention mechanism.
# FFN_HID_DIM: The dimension of the feedforward neural network hidden layer within the Transformer.
# BATCH_SIZE: The size of each batch of training data.
# NUM_ENCODER_LAYERS and NUM_DECODER_LAYERS: The number of encoder and decoder layers in the Transformer model.
# Initializing the Transformer Model:
# A Seq2SeqTransformer model is created with the specified hyperparameters. This model will be used for the machine
# translation task.
# Weight Initialization:
# The code initializes the model's weights. For weights with a dimension greater than 1, it uses Xavier (Glorot)
# weight initialization. Weight initialization helps in training neural networks effectively.
# Moving the Model to the Device (GPU or CPU):
# The model is moved to the computing device specified earlier (DEVICE). If a CUDA-compatible GPU is available,
# the model is moved to the GPU; otherwise, it runs on the CPU.
# Loss Function:
# The loss function is defined as cross-entropy loss. This is a common choice for sequence-to-sequence tasks,
# where the goal is to minimize the difference between the predicted translations and the actual target translations.
# The ignore_index parameter is set to PAD_IDX to ignore padding tokens during loss computation.
# Optimizer:
# The code sets up an Adam optimizer to update the model's parameters during training. It uses a learning rate
# of 0.0001 and sets other hyperparameters for Adam, such as the beta values (0.9, 0.98) and epsilon (1e-9).

torch.manual_seed(0)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
print("SRC_VOCAB_SIZE ", SRC_VOCAB_SIZE)
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
print("TGT_VOCAB_SIZE ", TGT_VOCAB_SIZE)
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
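Setting ignore_index=PAD_IDX means that padded target positions contribute nothing to the loss. With the default 'mean' reduction, the average is taken only over the non-padding positions. The pure-Python sketch below shows that computation on raw logits per position. Note that masked_cross_entropy is a hypothetical name used only here, and the sketch assumes at least one non-padding target.

```python
import math
from typing import List

PAD_IDX = 1  # assumed, as in the special-symbol definitions


def masked_cross_entropy(logits: List[List[float]], targets: List[int],
                         ignore_index: int = PAD_IDX) -> float:
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padding positions are skipped entirely
        # Negative log-softmax of the target via log-sum-exp, shifted by the max for stability.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
        count += 1
    return total / count  # mean over non-ignored positions only


# The second position is padding, so only the first (uniform logits) counts: log(3).
print(masked_cross_entropy([[0.0, 0.0, 0.0], [9.0, -9.0, 0.0]], [0, PAD_IDX]))
```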


# sequential_transforms(*transforms):
# This is a higher-order function that takes a variable number of transformation functions as its arguments.
# It returns a new function (func) that applies the provided transformations sequentially to the input text,
# so that a raw input string can be passed through a whole pipeline of text transformations in one call.
# tensor_transform(token_ids: List[int]):
# This function takes a list of token IDs (integers) as input.
# It adds a beginning-of-sequence (BOS) token ID at the beginning, appends the token IDs from the input, and then
# adds an end-of-sequence (EOS) token ID at the end.
# The result is a tensor that represents the input sequence with BOS and EOS tokens.
# text_transform:
# This dictionary holds transformation functions for both the source (SRC_LANGUAGE) and target (TGT_LANGUAGE) languages.
# These transformations include tokenization, numericalization, and the addition of BOS and EOS tokens.
# collate_fn(batch):
# This function is used to collate data samples into batch tensors.
# It takes a batch of data samples, each consisting of a source and target text.
# It applies the text_transform functions to each sample to tokenize, convert to tensors, and add BOS/EOS tokens.
# The resulting source and target tensors are then padded to the same length using pad_sequence from torch.nn.utils.rnn.
# Padding is done with the padding value PAD_IDX.
# The function returns the source and target tensors, which can be used for training the Seq2Seq model.


from torch.nn.utils.rnn import pad_sequence

# helper function to club together sequential operations
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]),
                      torch.tensor(token_ids),
                      torch.tensor([EOS_IDX])))

# ``src`` and ``tgt`` language text transforms to convert raw strings into tensors indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch
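text_transform and collate_fn together turn each raw sentence into [BOS] + token ids + [EOS], then pad every sequence in the batch to the same length with PAD_IDX. Because pad_sequence is called without batch_first=True, the result is laid out as (max_seq_len, batch_size), with one sentence per column. The pure-Python sketch below mirrors that layout on lists of ids. The names add_bos_eos and pad_time_major are hypothetical helpers, and the index values assume the special-symbol order defined earlier.

```python
from typing import List

PAD_IDX, BOS_IDX, EOS_IDX = 1, 2, 3  # assumed to match the special-symbol indices


def add_bos_eos(token_ids: List[int]) -> List[int]:
    # Same effect as tensor_transform, but on a plain list.
    return [BOS_IDX] + token_ids + [EOS_IDX]


def pad_time_major(seqs: List[List[int]], padding_value: int = PAD_IDX) -> List[List[int]]:
    # Like pad_sequence(..., batch_first=False): row t holds the t-th token of every sequence.
    max_len = max(len(s) for s in seqs)
    return [[s[t] if t < len(s) else padding_value for s in seqs] for t in range(max_len)]


batch = [add_bos_eos([10, 11, 12]), add_bos_eos([20])]
for row in pad_time_major(batch):
    print(row)
```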



# train_epoch(model, optimizer):
# This function is responsible for training one epoch of the machine translation model.
# It takes two main arguments: the model to be trained (model) and the optimizer that will update the model's parameters (optimizer).
# The function follows these steps:
# Sets the model in training mode using model.train().
# Initializes the losses variable to keep track of the cumulative training loss.
# Loads the training data using the Multi30k dataset with the specified source and target languages.
# Creates a data loader (train_dataloader) for batching the data using the collate_fn function.
# Iterates over batches of data from the data loader.
# For each batch, it:
# Moves the source (src) and target (tgt) tensors to the computing device specified earlier (DEVICE).
# Prepares the target input by removing the last token from the target sequence. This is because the model's goal is to predict the next token given the previous tokens.
# Generates masks using the create_mask function. These masks include source mask, target mask, source padding mask, and target padding mask.
# Passes the source, target input, and masks to the model to obtain predictions (logits).
# Initializes the optimizer's gradients to zero with optimizer.zero_grad().
# Computes the cross-entropy loss between the model's predictions and the target output (the target shifted by one token). The logits are flattened to (num_tokens, vocab_size) and the targets to (num_tokens,), which is the shape the loss function expects.
# Backpropagates the loss and updates the model's parameters using loss.backward() and optimizer.step(), respectively.
# Adds the current batch's loss to the cumulative losses.
# Finally, the function returns the average training loss over all batches in the epoch.
# evaluate(model):

# This function is used for evaluating the model's performance on a validation dataset.
# It takes the trained model (model) as input and follows a similar structure to the training function:
# Sets the model in evaluation mode using model.eval().
# Initializes the losses variable to keep track of the cumulative validation loss.
# Loads the validation data using the Multi30k dataset with the specified source and target languages.
# Creates a data loader (val_dataloader) for batching the validation data using the collate_fn function.
# Iterates over batches of data from the validation data loader.
# For each batch, it performs the following steps:
# Moves the source (src) and target (tgt) tensors to the computing device (DEVICE).
# Prepares the target input in the same way as in the training function.
# Generates masks using the create_mask function.
# Passes the source, target input, and masks to the model to obtain predictions (logits).
# Computes the loss between the model's predictions and the actual target output.
# Adds the current batch's loss to the cumulative losses.
# The function returns the average validation loss over all batches in the validation dataset.


from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    model.train()
    losses = 0
    num_batches = 0
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # Teacher forcing: the decoder input is the target without its last token
        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        # The prediction at position t is scored against the token at position t + 1
        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()
        num_batches += 1

    # Count batches during the pass instead of calling len(list(train_dataloader)),
    # which would iterate over (and re-tokenize) the whole dataset a second time
    return losses / num_batches


@torch.no_grad()  # gradients are not needed for validation; this saves memory and time
def evaluate(model):
    model.eval()
    losses = 0
    num_batches = 0

    val_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in val_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        tgt_out = tgt[1:, :]
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        losses += loss.item()
        num_batches += 1

    # Average over the batches actually seen, without a second pass over the data
    return losses / num_batches
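Both train_epoch and evaluate use teacher forcing. The decoder is fed the gold target without its last token (tgt[:-1]) and is scored against the gold target without its first token (tgt[1:]). As a result, the prediction at position t is compared with the token at position t + 1. Below is a minimal sketch of that shift on a single sequence of token strings. The function name is hypothetical, and in the notebook the same slicing is applied along the time axis (dim 0) of the (seq_len, batch) tensor.

```python
from typing import List, Tuple


def teacher_forcing_pair(tgt: List[str]) -> Tuple[List[str], List[str]]:
    # Decoder input drops the final token; the training target drops <bos>.
    return tgt[:-1], tgt[1:]


tgt_input, tgt_out = teacher_forcing_pair(['<bos>', 'a', 'dog', 'runs', '<eos>'])
for seen, expected in zip(tgt_input, tgt_out):
    print(f"after {seen!r} predict {expected!r}")
```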

# Training Loop:
# This section contains a training loop that runs for a specified number of epochs (in this case, NUM_EPOCHS is set to 18).
# In each epoch, the model is trained using the train_epoch function, and the training loss is recorded.
# After training the model for an epoch, the evaluate function is used to calculate the validation loss.
# The results, including the training and validation losses, are printed for each epoch, along with the time taken for that epoch.
# Greedy Decoding for Inference:
# The code defines two functions for generating translations during inference:
# greedy_decode(model, src, src_mask, max_len, start_symbol):
# This function performs greedy decoding to generate an output sequence in the target language.
# It takes the trained model (model), a source sequence (src), its associated source mask (src_mask), the maximum length of the output sequence (max_len), and the start symbol (start_symbol) as inputs.
# It iteratively predicts the next token in the target sequence and appends it to the output until it reaches the maximum length or encounters an end-of-sequence token.
# translate(model, src_sentence):
# This function takes a trained model (model) and a source sentence in the source language (src_sentence) as inputs.
# It tokenizes and processes the source sentence and then uses greedy_decode to generate the translation.
# The generated translation is returned as a string.

from timeit import default_timer as timer
NUM_EPOCHS = 18

for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))


# function to generate output sequence using greedy algorithm
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        memory = memory.to(DEVICE)
        tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                    .type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)
        out = out.transpose(0, 1)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    tgt_tokens = greedy_decode(
        model,  src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")
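greedy_decode keeps a growing output prefix and repeatedly asks the model for the most likely next token. It stops at EOS or after max_len tokens, and translate uses max_len = num_tokens + 5. This control flow can be sketched independently of PyTorch by replacing the model with a scoring function. In the sketch, next_token_scores is a hypothetical stand-in for model.decode followed by model.generator; it is not part of the notebook.

```python
from typing import Callable, List, Sequence

BOS_IDX, EOS_IDX = 2, 3  # assumed to match the special-symbol indices


def greedy_decode_sketch(next_token_scores: Callable[[List[int]], Sequence[float]],
                         max_len: int, start_symbol: int = BOS_IDX) -> List[int]:
    ys = [start_symbol]
    for _ in range(max_len - 1):
        scores = next_token_scores(ys)
        # Greedy choice: take the arg-max token, as torch.max(prob, dim=1) does.
        next_word = max(range(len(scores)), key=lambda k: scores[k])
        ys.append(next_word)
        if next_word == EOS_IDX:
            break
    return ys


# Toy "model": always emits tokens 5, 6, then EOS, over a vocabulary of 8 ids.
script = [5, 6, EOS_IDX]


def toy_scores(prefix: List[int]) -> List[float]:
    scores = [0.0] * 8
    scores[script[min(len(prefix) - 1, len(script) - 1)]] = 1.0
    return scores


print(greedy_decode_sketch(toy_scores, max_len=10))  # stops at EOS: [2, 5, 6, 3]
print(greedy_decode_sketch(toy_scores, max_len=3))   # hits max_len first: [2, 5, 6]
```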


In [13]:
print(translate(transformer, "Eine Gruppe von Menschen steht vor einem Iglu ."))

 A group of people standing in front of an igloo 
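The printed translation starts and ends with a space. This happens because translate joins all tokens first and only afterwards replaces "<bos>" and "<eos>" with empty strings. If clean output matters, one option is to drop special tokens before joining. The block below is a sketch of that approach, not the notebook's code, and detokenize is a hypothetical helper.

```python
from typing import Iterable

DROP_TOKENS = {'<bos>', '<eos>', '<pad>'}  # '<unk>' is kept so that unknown words stay visible


def detokenize(tokens: Iterable[str]) -> str:
    # Filter out special tokens first, then join, so no stray leading or trailing spaces remain.
    return " ".join(tok for tok in tokens if tok not in DROP_TOKENS)


print(repr(detokenize(['<bos>', 'A', 'group', 'of', 'people', '<eos>'])))  # 'A group of people'
```

Inside translate, this would replace the .replace(...) chain applied to the output of vocab_transform[TGT_LANGUAGE].lookup_tokens(...).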


In [15]:
torch.save(transformer.state_dict(), 'translate_model.pth')