In [None]:
import math
import os
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from tokenizers import Tokenizer
from tokenizers.models import BPE as TokenizerBPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import BpeTrainer
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from config import get_config, get_weights_file_path, latest_weights_file_path

In [2]:

config = get_config()

# NOTE(review): absolute, machine-specific paths; consider moving them into the config.
train_data = "C:/Users/pc/Downloads/archive/translation_train.csv"
test_data = "C:/Users/pc/Downloads/archive/translation_test.csv"


def _load_sentence_pairs(csv_path):
    """Read a translation CSV and keep only rows usable as sentence pairs.

    Rows containing a missing value in any column are dropped, then rows whose
    'english' or 'german' text is empty after stripping whitespace are dropped.
    The index is renumbered 0..n-1 afterwards so that positional lookups such
    as ``df['english'][i]`` for ``i in range(len(df))`` remain valid.
    """
    pairs = pd.read_csv(csv_path).dropna()
    has_english = pairs['english'].str.strip() != ''
    has_german = pairs['german'].str.strip() != ''
    return pairs[has_english & has_german].reset_index(drop=True)


train_df = _load_sentence_pairs(train_data)
test_df = _load_sentence_pairs(test_data)

train_df


Unnamed: 0,english,german
0,"Two young, White males are outside near many b...",Zwei junge weiße Männer sind im Freien in der ...
1,Several men in hard hats are operating a giant...,Mehrere Männer mit Schutzhelmen bedienen ein A...
2,A little girl climbing into a wooden playhouse.,Ein kleines Mädchen klettert in ein Spielhaus ...
3,A man in a blue shirt is standing on a ladder ...,Ein Mann in einem blauen Hemd steht auf einer ...
4,Two men are at the stove preparing food.,Zwei Männer stehen am Herd und bereiten Essen zu.
...,...,...
28995,A woman behind a scrolled wall is writing,Eine Frau schreibt hinter einer verschnörkelte...
28996,A rock climber practices on a rock climbing wall.,Ein Bergsteiger übt an einer Kletterwand.
28997,Two male construction workers are working on a...,Zwei Bauarbeiter arbeiten auf einer Straße vor...
28998,An elderly man sits outside a storefront accom...,Ein älterer Mann sitzt mit einem Jungen mit ei...


In [3]:
# Flatten each language column of the training frame into one space-separated
# string, then join the two strings (English first) with a single space.
en_combined_text = train_df['english'].str.cat(sep=' ')
ge_combined_text = train_df['german'].str.cat(sep=' ')
combined_text = ' '.join((en_combined_text, ge_combined_text))


In [4]:

class BilingualDataset(Dataset):
    """Dataset of fixed-length, padded sentence pairs for an encoder/decoder model.

    ``ds`` is indexed first by language key and then by position
    (``ds[src_lang][idx]``), e.g. a dict of two equally long lists. Each item
    is tokenized, wrapped with special tokens and padded to ``seq_len``.
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        # Special-token ids are looked up in the target tokenizer and reused on
        # the source side as well.
        # NOTE(review): this assumes both tokenizers assign the same ids to
        # [SOS]/[EOS]/[PAD] - confirm if the tokenizers are ever built differently.
        # TODO (from the original note): these int64 tensors were flagged as
        # "too large"; evaluate whether a narrower integer dtype is acceptable.
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        """Number of sentence pairs, taken from the source-language column."""
        return len(self.ds[self.src_lang])

    def _padding(self, count):
        """Return a 1-D int64 tensor holding ``count`` copies of the [PAD] id."""
        # repeat() on the (1,)-shaped pad tensor avoids building a Python list
        # of tensors and converting it element by element.
        return self.pad_token.repeat(count)

    def __getitem__(self, idx):
        """Build the model inputs for pair ``idx``.

        Returns a dict with ``encoder_input``, ``decoder_input`` and ``label``
        (each of length ``seq_len``), ``encoder_mask`` of shape
        ``(1, 1, seq_len)``, ``decoder_mask`` of shape
        ``(1, seq_len, seq_len)``, and the raw ``src_text`` / ``tgt_text``.

        Raises:
            ValueError: if either sentence does not fit into ``seq_len`` once
                its special tokens are added.
        """
        src_text = self.ds[self.src_lang][idx]
        tgt_text = self.ds[self.tgt_lang][idx]

        # Transform the text into token ids
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # The encoder input carries both [SOS] and [EOS]; the decoder input
        # carries only [SOS] and the label carries only [EOS].
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # A negative padding count means the sentence does not fit
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        enc_ids = torch.tensor(enc_input_tokens, dtype=torch.int64)
        dec_ids = torch.tensor(dec_input_tokens, dtype=torch.int64)

        # [SOS] tokens [EOS] [PAD]...
        encoder_input = torch.cat(
            [self.sos_token, enc_ids, self.eos_token, self._padding(enc_num_padding_tokens)],
            dim=0,
        )

        # [SOS] tokens [PAD]...
        decoder_input = torch.cat(
            [self.sos_token, dec_ids, self._padding(dec_num_padding_tokens)],
            dim=0,
        )

        # tokens [EOS] [PAD]...  (the decoder input shifted left by one position)
        label = torch.cat(
            [dec_ids, self.eos_token, self._padding(dec_num_padding_tokens)],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            # 1 where the position holds a real token, 0 where it is padding
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            # padding mask (1, seq_len) combined with the causal mask (1, seq_len, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }
    
def causal_mask(size):
    """Return a boolean (1, size, size) mask that is True on and below the diagonal.

    Position ``[0, i, j]`` is True when ``j <= i``.
    """
    # Ones strictly above the diagonal mark the entries to hide.
    above_diagonal = torch.ones((1, size, size)).triu(diagonal=1).type(torch.int)
    return above_diagonal == 0


In [5]:

def get_all_sentences(ds, lang):
    """Yield every sentence stored under ``lang`` in ``ds``, in order.

    ``ds`` only needs to support ``ds[lang]`` returning an iterable of
    sentences, e.g. a pandas DataFrame column or a dict of lists.
    """
    # Iterating the column directly avoids building a Series per row, which is
    # what DataFrame.iterrows() does.
    yield from ds[lang]

def get_or_build_tokenizer(config, ds, lang):
    """Load the tokenizer for ``lang`` from disk, training and saving it first if needed.

    The file location is ``config['tokenizer_file']`` formatted with ``lang``.
    When that file does not exist, a BPE tokenizer with whitespace
    pre-tokenization is trained on the ``lang`` sentences of ``ds`` and saved
    there. Returns the tokenizer in both cases.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer_instance = Tokenizer(TokenizerBPE(unk_token='[UNK]'))
    tokenizer_instance.pre_tokenizer = Whitespace()
    trainer = BpeTrainer(special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'], min_frequency=2)
    tokenizer_instance.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)

    # Make sure the destination directory exists before writing the file.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer_instance.save(str(tokenizer_path))
    return tokenizer_instance

def get_ds(config, ds_raw, split_ratio=0.9):
    """Build tokenizers and the train/validation data loaders from ``ds_raw``.

    ``ds_raw`` must provide the columns named by ``config['lang_src']`` and
    ``config['lang_tgt']``. The sentence pairs are split randomly, with
    ``split_ratio`` of them going to training and the rest to validation.
    The longest tokenized source and target sentence lengths are printed.

    Returns:
        (train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt)
    """
    src_lang = config['lang_src']
    tgt_lang = config['lang_tgt']

    # Build tokenizers
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, tgt_lang)

    # Pair the sentences by position so the result does not depend on the
    # index labels of ds_raw (which may have gaps after row filtering).
    sentence_pairs = list(zip(ds_raw[src_lang], ds_raw[tgt_lang]))

    # Split the pairs into training and validation sets
    train_ds_size = int(len(sentence_pairs) * split_ratio)
    val_ds_size = len(sentence_pairs) - train_ds_size
    train_pairs, val_pairs = random_split(sentence_pairs, [train_ds_size, val_ds_size])

    def _as_columns(pairs):
        """Turn a sequence of (src, tgt) pairs into a {language: [sentences]} dict."""
        src_sentences, tgt_sentences = zip(*pairs)
        return {src_lang: list(src_sentences), tgt_lang: list(tgt_sentences)}

    train_ds = BilingualDataset(_as_columns(train_pairs), tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])
    val_ds = BilingualDataset(_as_columns(val_pairs), tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])

    # Find the maximum tokenized length of the source and target sentences
    max_len_src = 0
    max_len_tgt = 0
    for src_text, tgt_text in sentence_pairs:
        max_len_src = max(max_len_src, len(tokenizer_src.encode(src_text).ids))
        max_len_tgt = max(max_len_tgt, len(tokenizer_tgt.encode(tgt_text).ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    # Validation decodes one sentence at a time, hence batch_size=1
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt


In [None]:

class LayerNormalization(nn.Module):
    """Normalize the last dimension to zero mean, then apply a learnable scale and shift."""

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        # Learnable per-feature scale (starts at 1) and shift (starts at 0)
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    def forward(self, x):
        # Statistics are taken over the last dimension and kept as size-1
        # dimensions so they broadcast against x.
        centered = x - x.mean(dim=-1, keepdim=True)
        # eps keeps the division stable when the standard deviation is tiny
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.alpha * centered / spread + self.bias

class FeedForwardBlock(nn.Module):
    """Two linear layers with ReLU and dropout in between: d_model -> d_ff -> d_model."""

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # expansion: w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # contraction: w2 and b2

    def forward(self, x):
        # (..., d_model) -> (..., d_ff)
        hidden = torch.relu(self.linear_1(x))
        hidden = self.dropout(hidden)
        # (..., d_ff) -> (..., d_model)
        return self.linear_2(hidden)

class InputEmbeddings(nn.Module):
    """Token-id to vector lookup whose output is scaled by sqrt(d_model)."""

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Token ids of any shape (...) become vectors of shape (..., d_model);
        # the sqrt(d_model) factor follows the scaling used in the paper.
        scale = math.sqrt(self.d_model)
        return self.embedding(x) * scale
    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a sequence, then apply dropout."""

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Column vector of positions 0..seq_len-1: (seq_len, 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # One frequency per even feature index: 10000 ** (-2i / d_model), (d_model / 2)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (seq_len, d_model / 2)

        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(angles)  # even feature indices
        table[:, 1::2] = torch.cos(angles)  # odd feature indices

        # Stored with a leading batch dimension, (1, seq_len, d_model), as a
        # buffer: part of the module state but not a parameter.
        self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        # Only the first x.shape[1] positions are needed; the table is never trained.
        encoding = self.pe[:, :x.shape[1], :].requires_grad_(False)
        return self.dropout(x + encoding)

class ResidualConnection(nn.Module):
    """Pre-norm skip connection: ``x + dropout(sublayer(norm(x)))``."""

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        # The sublayer sees the normalized input; the skip path carries x unchanged.
        normalized = self.norm(x)
        branch = self.dropout(sublayer(normalized))
        return x + branch

class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention with bias-free input and output projections."""

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # embedding vector size
        self.h = h  # number of heads
        # Each head gets an equal slice of the embedding
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h  # width of the slice seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False)  # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False)  # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Return ``(softmax(Q K^T / sqrt(d_k)) @ V, attention weights)``.

        Positions where ``mask == 0`` receive a score of -1e9 before the
        softmax. ``mask`` and ``dropout`` may each be None.
        """
        d_k = query.shape[-1]
        # (batch, h, seq_len, d_k) @ (batch, h, d_k, seq_len) -> (batch, h, seq_len, seq_len)
        scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # A very low value stands in for -inf at the masked positions
            scores.masked_fill_(mask == 0, -1e9)
        weights = scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len)
        if dropout is not None:
            weights = dropout(weights)
        # The weights are returned as well so they can be used for visualization
        return (weights @ value), weights

    def _split_heads(self, projected):
        # (batch, seq_len, d_model) -> (batch, seq_len, h, d_k) -> (batch, h, seq_len, d_k)
        batch, length = projected.shape[0], projected.shape[1]
        return projected.view(batch, length, self.h, self.d_k).transpose(1, 2)

    def forward(self, q, k, v, mask):
        # Project the inputs, then split every projection into heads
        query = self._split_heads(self.w_q(q))
        key = self._split_heads(self.w_k(k))
        value = self._split_heads(self.w_v(v))

        # The weights of the latest call are kept on the module
        attended, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # Merge the heads back together:
        # (batch, h, seq_len, d_k) -> (batch, seq_len, h, d_k) -> (batch, seq_len, d_model)
        merged = attended.transpose(1, 2).contiguous().view(attended.shape[0], -1, self.h * self.d_k)

        # Final output projection, (batch, seq_len, d_model) -> (batch, seq_len, d_model)
        return self.w_o(merged)

class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual connection."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Index 0 wraps the attention sublayer, index 1 the feed-forward sublayer
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(normed):
            # Query, key and value all come from the same (normalized) input
            return self.self_attention_block(normed, normed, normed, src_mask)

        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)
    
class Encoder(nn.Module):
    """Run the input through every layer in order, then apply a final layer normalization."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        hidden = x
        for block in self.layers:
            # Every layer receives the same mask
            hidden = block(hidden, mask)
        return self.norm(hidden)

class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, cross-attention over the encoder output, then feed-forward."""

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        # One residual wrapper per sublayer, in the order they are applied
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        self_attn_residual, cross_attn_residual, feed_forward_residual = self.residual_connections

        def self_attend(normed):
            # Query, key and value all come from the decoder stream
            return self.self_attention_block(normed, normed, normed, tgt_mask)

        def cross_attend(normed):
            # Queries come from the decoder; keys and values from the encoder output
            return self.cross_attention_block(normed, encoder_output, encoder_output, src_mask)

        x = self_attn_residual(x, self_attend)
        x = cross_attn_residual(x, cross_attend)
        return feed_forward_residual(x, self.feed_forward_block)
    
class Decoder(nn.Module):
    """Run the target stream through every layer in order, then apply a final layer normalization."""

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        hidden = x
        for block in self.layers:
            # Every layer receives the same encoder output and masks
            hidden = block(hidden, encoder_output, src_mask, tgt_mask)
        return self.norm(hidden)

class ProjectionLayer(nn.Module):
    """Linear map from model features to one score per vocabulary entry."""

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x) -> torch.Tensor:
        # (..., d_model) --> (..., vocab_size); raw scores, no softmax is applied here
        return self.proj(x)
    
class Transformer(nn.Module):
    """Container for the encoder/decoder model, exposing encode, decode and project as separate steps."""

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        """Embed the source ids, add positions and run the encoder."""
        embedded = self.src_pos(self.src_embed(src))  # (batch, seq_len, d_model)
        return self.encoder(embedded, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        """Embed the target ids, add positions and run the decoder against ``encoder_output``."""
        embedded = self.tgt_pos(self.tgt_embed(tgt))  # (batch, seq_len, d_model)
        return self.decoder(embedded, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        """Map decoder features to vocabulary scores: (batch, seq_len, vocab_size)."""
        return self.projection_layer(x)
    
def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
    """Assemble a Transformer with ``N`` encoder and ``N`` decoder blocks.

    Every parameter with more than one dimension is re-initialised with
    Xavier-uniform values before the model is returned.
    """
    # Embeddings first, then positional encodings. The construction order is
    # kept stable so that a seeded run draws the same initial values.
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Encoder stack: each block owns its own attention and feed-forward modules
    encoder_blocks = [
        EncoderBlock(
            d_model,
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    # Decoder stack: self-attention, cross-attention and feed-forward per block
    decoder_blocks = [
        DecoderBlock(
            d_model,
            MultiHeadAttentionBlock(d_model, h, dropout),
            MultiHeadAttentionBlock(d_model, h, dropout),
            FeedForwardBlock(d_model, d_ff, dropout),
            dropout,
        )
        for _ in range(N)
    ]

    model = Transformer(
        Encoder(d_model, nn.ModuleList(encoder_blocks)),
        Decoder(d_model, nn.ModuleList(decoder_blocks)),
        src_embed,
        tgt_embed,
        src_pos,
        tgt_pos,
        ProjectionLayer(d_model, tgt_vocab_size),
    )

    # Initialize the weight matrices; one-dimensional parameters are left as created
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)

    return model

def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the transformer, using ``config['seq_len']`` for both sides and ``config['d_model']``."""
    seq_len = config['seq_len']
    return build_transformer(vocab_src_len, vocab_tgt_len, seq_len, seq_len, d_model=config['d_model'])


In [7]:

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Generate target token ids one at a time, always taking the highest-scoring token.

    Decoding starts from [SOS] and stops once [EOS] has been produced or the
    sequence holds ``max_len`` tokens. Returns a 1-D tensor of token ids that
    starts with the [SOS] id.
    """
    start_id = tokenizer_tgt.token_to_id('[SOS]')
    stop_id = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder runs once; its output is reused at every decoding step
    memory = model.encode(source, source_mask)

    # The sequence generated so far, shape (1, current_length)
    generated = torch.empty(1, 1).fill_(start_id).type_as(source).to(device)

    while generated.size(1) != max_len:
        # Causal mask sized to the tokens produced so far
        step_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)

        hidden = model.decode(memory, source_mask, generated, step_mask)

        # Only the last position is needed to choose the next token
        scores = model.project(hidden[:, -1])
        next_token = torch.max(scores, dim=1).indices

        generated = torch.cat(
            [generated, next_token.view(1, 1).type_as(source).to(device)], dim=1
        )

        if next_token == stop_id:
            break

    return generated.squeeze(0)


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
    """Greedy-decode up to ``num_examples`` validation batches, print them and log text metrics.

    The model is switched to eval mode. Every batch must have batch size 1.
    For each example the source, target and predicted text are passed to
    ``print_msg``. When ``writer`` is truthy, character error rate, word error
    rate and BLEU over the decoded examples are logged at ``global_step``.

    Raises:
        AssertionError: if a validation batch holds more than one example.
    """
    model.eval()
    count = 0

    expected = []
    predicted = []

    # Width of the separator rule. shutil.get_terminal_size works on every
    # platform and falls back to 80 columns when no terminal is attached.
    console_width = shutil.get_terminal_size(fallback=(80, 24)).columns
    rule = '-' * console_width

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device) # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device) # (b, 1, 1, seq_len)

            # greedy_decode handles exactly one sentence at a time
            assert encoder_input.size(
                0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg(rule)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg(rule)
                break

    if writer:
        # Each metric is computed over the examples decoded above and logged
        # under its own tag.
        metrics = (
            ('validation cer', torchmetrics.CharErrorRate()),
            ('validation wer', torchmetrics.WordErrorRate()),
            ('validation BLEU', torchmetrics.BLEUScore()),
        )
        for tag, metric in metrics:
            writer.add_scalar(tag, metric(predicted, expected), global_step)
            writer.flush()


def train_model(config, ds_raw):
    """Train the translation model on ``ds_raw`` and checkpoint it after every epoch.

    Builds the data loaders, tokenizers and model from ``config``, optionally
    resumes from a saved checkpoint (``config['preload']``), then runs epochs
    up to ``config['num_epochs']``. After each epoch a short validation pass is
    run and a checkpoint holding the model, optimizer, epoch and global step
    is written. The training loss is logged to TensorBoard at every step.
    """
    # Pick the accelerator: CUDA first, then Apple MPS, otherwise the CPU.
    # torch.backends.mps is looked up defensively for builds that lack it.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print("Using device:", device)
    if device.type == 'cuda':
        # Report on the CUDA device that is currently selected
        cuda_index = torch.cuda.current_device()
        print(f"Device name: {torch.cuda.get_device_name(cuda_index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(cuda_index).total_memory / 1024 ** 3} GB")
    elif device.type == 'mps':
        print("Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config, ds_raw, config['split_ratio'])
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    if preload == 'latest':
        model_filename = latest_weights_file_path(config)
    elif preload:
        model_filename = get_weights_file_path(config, preload)
    else:
        model_filename = None

    if model_filename:
        print(f'Preloading model {model_filename}')
        # map_location lets a checkpoint saved on one device type be resumed on another
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    # The labels are target-language ids, so the padding id to ignore is taken
    # from the target tokenizer.
    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        # run_validation leaves the model in eval mode, so switch back every epoch
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device) # (B, seq_len)
            decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
            proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device) # (B, seq_len)

            # Flatten to (B * seq_len, vocab_size) vs (B * seq_len) for cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        # Run validation at the end of every epoch; messages go through tqdm so
        # they do not break the progress bar
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

    # Release the TensorBoard event file once training has finished
    writer.close()



In [8]:

# Start training on the cleaned training frame with the loaded configuration.
train_model(config=config, ds_raw=train_df)

Using device: cuda
Device name: NVIDIA GeForce RTX 4070 SUPER
Device memory: 11.99365234375 GB
Max length of source sentence: 41
Max length of target sentence: 46
No model to preload, starting from scratch


Processing Epoch 00: 100%|██████████| 3263/3263 [02:25<00:00, 22.39it/s, loss=4.111]


--------------------------------------------------------------------------------
    SOURCE: A man is standing in front of a building holding heart shaped balloons and a woman is crossing the street.
    TARGET: Ein Mann steht vor einem Gebäude und hält herzförmige Luftballons, und eine Frau überquert die Straße.
 PREDICTED: Ein Mann steht vor einem Tisch , der eine Frau hält , während eine Frau in der Hand .
--------------------------------------------------------------------------------
    SOURCE: A person is walking on the night street under neon lights.
    TARGET: Eine Person schlendert nachts im Neonlicht die Straße entlang.
 PREDICTED: Eine Person geht auf der Straße entlang , die die Straße .
--------------------------------------------------------------------------------


Processing Epoch 01: 100%|██████████| 3263/3263 [02:15<00:00, 24.08it/s, loss=4.326]


--------------------------------------------------------------------------------
    SOURCE: A woman and a baby eating (having a picnic).
    TARGET: Eine Frau und ein Kleinkind essen (picknicken).
 PREDICTED: Eine Frau und ein Baby posieren in einem Waschsalon .
--------------------------------------------------------------------------------
    SOURCE: A boy sits on the grass near a geometric sculpture.
    TARGET: Ein Junge sitzt auf dem Rasen in der Nähe von einer geometrischen Skulptur.
 PREDICTED: Ein Junge sitzt auf dem Gras und macht eine Pause in der Nähe eines Hauses .
--------------------------------------------------------------------------------


Processing Epoch 02: 100%|██████████| 3263/3263 [02:25<00:00, 22.41it/s, loss=3.226]


--------------------------------------------------------------------------------
    SOURCE: A cello player and a violinist getting ready for their performance in an elegant room.
    TARGET: Ein Cellospieler und ein Geiger bereiten sich in einem eleganten Raum auf ihren Auftritt vor.
 PREDICTED: Ein paar männliche und ein paar Mädchen bereiten sich in einem Raum für ihre Kleidung auf .
--------------------------------------------------------------------------------
    SOURCE: A person walks among large, white geometric shaped architecture.
    TARGET: Eine Person geht unter großen geometrischen Bauwerken.
 PREDICTED: Eine Person geht mit großen weißen , weißen Obst .
--------------------------------------------------------------------------------


Processing Epoch 03: 100%|██████████| 3263/3263 [02:25<00:00, 22.50it/s, loss=3.909]


--------------------------------------------------------------------------------
    SOURCE: Two little boys play in the water left behind by the sprinklers.
    TARGET: Zwei kleine Jungen spielen im Wasser aus dem Wassersprenger.
 PREDICTED: Zwei kleine Jungen spielen in der linken linken linken linken linken linken Seite .
--------------------------------------------------------------------------------
    SOURCE: Front stroke swimming race roped off lap areas.
    TARGET: @@
 PREDICTED: Bei einem Rennen , das ein Rennen mit dem Schoß ist .
--------------------------------------------------------------------------------


Processing Epoch 04: 100%|██████████| 3263/3263 [02:22<00:00, 22.86it/s, loss=2.170]


--------------------------------------------------------------------------------
    SOURCE: Three teenagers dancing one evening on the street.
    TARGET: Drei Teenager tanzen an einem Abend auf der Straße.
 PREDICTED: Drei Teenager tanzen an der Straße .
--------------------------------------------------------------------------------
    SOURCE: A construction crew working at several spots on the road.
    TARGET: Eine Bauarbeitermannschaft, die an mehreren Stellen auf der Straße arbeitet.
 PREDICTED: Ein Bauarbeiter arbeitet an mehreren Computer s auf der Straße .
--------------------------------------------------------------------------------


Processing Epoch 05: 100%|██████████| 3263/3263 [02:15<00:00, 24.09it/s, loss=2.865]


--------------------------------------------------------------------------------
    SOURCE: A brown-haired woman is holding a baby boy wearing a shirt with Reindeer on it and a chef's hat.
    TARGET: Eine braunhaarige Frau hält einen kleinen Jungen, der ein Oberteil mit Rentiermotiv und eine Kochmütze trägt.
 PREDICTED: Eine braunhaarige Frau hält ein Baby , das ein Baby trägt , das mit einem Hut mit einem Hut und einem Koch trägt .
--------------------------------------------------------------------------------
    SOURCE: A man in shorts walks a big brown dog and a woman pushing a baby carriage walks right behind him.
    TARGET: Ein Mann in Shorts führt einen großen braunen Hund spazieren und eine Frau schiebt einen Kinderwagen direkt hinter ihm.
 PREDICTED: Ein Mann in kurzen Hosen geht einen großen Hund entlang , der hinter ihm eine Frau schiebt , einen Kinderwagen schiebt .
--------------------------------------------------------------------------------


Processing Epoch 06: 100%|██████████| 3263/3263 [02:14<00:00, 24.21it/s, loss=2.255]


--------------------------------------------------------------------------------
    SOURCE: A snowboarder is performing a stunt in mostly dark.
    TARGET: Ein Snowboarder vollführt bei schummerigem Licht ein Kunststück.
 PREDICTED: Ein Snowboarder macht einen Trick in einem Kaufhaus .
--------------------------------------------------------------------------------
    SOURCE: Five little girls posing in matching dance outfits.
    TARGET: Fünf kleine Mädchen posieren in einheitlichen Tanzoutfits.
 PREDICTED: Fünf kleine Mädchen in passenden Kleidern posieren .
--------------------------------------------------------------------------------


Processing Epoch 07: 100%|██████████| 3263/3263 [02:15<00:00, 24.14it/s, loss=2.317]


--------------------------------------------------------------------------------
    SOURCE: A man is engulfed in flames while two movie crewmen supervise.
    TARGET: Ein Mann ist in Flammen gehüllt und wird von zwei Kollegen der Filmcrew überwacht.
 PREDICTED: Ein Mann ist dabei , zwei K ü schaft s in der Ferne zu machen .
--------------------------------------------------------------------------------
    SOURCE: A seated boy plays the accordion.
    TARGET: Ein sitzender Junge spielt Akkordeon.
 PREDICTED: Ein sitzender Junge spielt Akkordeon .
--------------------------------------------------------------------------------


Processing Epoch 08: 100%|██████████| 3263/3263 [02:18<00:00, 23.59it/s, loss=2.211]


--------------------------------------------------------------------------------
    SOURCE: Several bikers on a stone road, with spectators watching.
    TARGET: Mehrere Radfahrer fahren auf einer Steinstraße mit Zuschauern.
 PREDICTED: Mehrere Biker auf einer gepflasterten Straße , während Zuschauer zusehen .
--------------------------------------------------------------------------------
    SOURCE: This man in the yellow shirt is adjusting blue bicycle for a young boy.
    TARGET: Der Mann in dem gelben Oberteil stellt ein blaues Fahrrad für einen kleinen Jungen ein.
 PREDICTED: Der Mann in gelbem Hemd schneidet sein blaues Fahrrad für einen kleinen Jungen .
--------------------------------------------------------------------------------


Processing Epoch 09: 100%|██████████| 3263/3263 [02:17<00:00, 23.77it/s, loss=2.123]


--------------------------------------------------------------------------------
    SOURCE: Two men during a football game
    TARGET: Zwei Männer während eines Fußballspiels
 PREDICTED: Zwei Männer bei einem Football - Spiel .
--------------------------------------------------------------------------------
    SOURCE: A guy in his scuba gear is having a conversation with other men.
    TARGET: Ein Mann in Tauchausrüstung unterhält sich mit anderen Männern.
 PREDICTED: Ein Mann in seinem Taucher ausrüstung unterhält sich mit anderen Männern .
--------------------------------------------------------------------------------


Processing Epoch 10: 100%|██████████| 3263/3263 [02:15<00:00, 24.02it/s, loss=1.958]


--------------------------------------------------------------------------------
    SOURCE: An elderly woman wearing a bathing suit is holding a picket sign while participating in a demonstration about immigration.
    TARGET: Eine ältere Frau im Badeanzug nimmt an einer Demonstration zum Thema Immigration teil und hält ein Plakat.
 PREDICTED: Eine ältere Frau in Badeanzug hält ein Schild mit einem Plakat , während sie an einem Plakat arbeitet .
--------------------------------------------------------------------------------
    SOURCE: A group of children in brown outfits are performing.
    TARGET: Eine Gruppe von Kindern in brauner Kleidung führt etwas vor.
 PREDICTED: Eine Gruppe von Kindern in braunen Outfits führt etwas vor .
--------------------------------------------------------------------------------


Processing Epoch 11: 100%|██████████| 3263/3263 [02:31<00:00, 21.48it/s, loss=1.856]


--------------------------------------------------------------------------------
    SOURCE: Three women and one man are having a friendly conversation in an office.
    TARGET: Drei Frauen und ein Mann unterhalten sich freundlich in einem Büro.
 PREDICTED: Drei Frauen und ein Mann essen eine freund schaft liche Unterhaltung .
--------------------------------------------------------------------------------
    SOURCE: A boy jumps into a pool.
    TARGET: Ein Junge springt in einen Pool.
 PREDICTED: Ein Junge springt in ein Schwimmbecken .
--------------------------------------------------------------------------------


Processing Epoch 12: 100%|██████████| 3263/3263 [02:21<00:00, 23.09it/s, loss=1.925]


--------------------------------------------------------------------------------
    SOURCE: People gather for a farmers market day, shopping for groceries and clothing on a sunny day.
    TARGET: Personen versammeln sich auf einem Wochenmarkt, um Lebensmittel und Kleidung an einem sonnigen Tag zu kaufen.
 PREDICTED: Menschen versammeln sich zum Grillen , zum Essen und west liche Kleidung an einem sonnigen Tag .
--------------------------------------------------------------------------------
    SOURCE: Military personnel learning how to shoot their rifles.
    TARGET: Soldaten lernen mit ihren Gewehren zu schießen.
 PREDICTED: Militärangehörige , der versucht , das Spiel des Waffen zu schießen .
--------------------------------------------------------------------------------


Processing Epoch 13: 100%|██████████| 3263/3263 [02:19<00:00, 23.47it/s, loss=1.911]


--------------------------------------------------------------------------------
    SOURCE: Two street artists are performing on steps for a group of spectators.
    TARGET: Zwei Straßenkünstler geben auf einer Treppe eine Vorstellung für Zuschauergruppe.
 PREDICTED: Zwei Straßenkünstler führen auf Stufen für eine Gruppe von Zuschauern auf .
--------------------------------------------------------------------------------
    SOURCE: A bicyclist in blue parked inside of an all wood building.
    TARGET: Ein blau gekleideter Fahrradfahrer parkt in einem Holzhaus.
 PREDICTED: Ein Radfahrer in Blau parkt in einem noch Holz raum .
--------------------------------------------------------------------------------


Processing Epoch 14: 100%|██████████| 3263/3263 [02:20<00:00, 23.27it/s, loss=1.730]


--------------------------------------------------------------------------------
    SOURCE: A young girl in a pink shirt creates a painting on paper.
    TARGET: Ein kleines Mädchen in einem rosa Oberteil malt auf Papier.
 PREDICTED: Ein junges Mädchen in einem rosafarbenen Oberteil malt ein Papier auf Papier .
--------------------------------------------------------------------------------
    SOURCE: A crowd watching a man in white pants using an axe.
    TARGET: Eine Menschenmenge schaut einem Mann in weißen Hosen und mit einer Axt zu.
 PREDICTED: Eine Menschenmenge beobachtet einen Mann in weißer Hose , der mit einer Axt benutzt .
--------------------------------------------------------------------------------


Processing Epoch 15: 100%|██████████| 3263/3263 [02:20<00:00, 23.26it/s, loss=1.835]


--------------------------------------------------------------------------------
    SOURCE: Two girls in matching pink and white dresses and two smaller boys in matching black and white shirts.
    TARGET: Zwei Mädchen in zueinander passenden rosafarbenen und weißen Kleidern und zwei kleinere Jungen in zueinander passenden schwarz-weißen Hemden.
 PREDICTED: Zwei Mädchen in farblich passenden pinkfarbenen und weißen Kleidern und zwei kleine Jungen in schwarzen Hemden und weißen Hemden .
--------------------------------------------------------------------------------
    SOURCE: 3 dark-skinned males, two shirtless, are on a grassy area with trees scattered around them.
    TARGET: Drei dunkelhäutige Männer, von denen zwei kein Hemd anhaben, befinden sich auf einer Grasfläche mit vereinzelten Bäumen um sie herum.
 PREDICTED: Drei dunkelhäutige Männer in schwarzen Hemden , zwei mit freiem Oberkörper , sind auf einer Grasfläche .
--------------------------------------------------------------------------------


Processing Epoch 16: 100%|██████████| 3263/3263 [02:21<00:00, 23.12it/s, loss=1.833]


--------------------------------------------------------------------------------
    SOURCE: Two men speak closely at a party.
    TARGET: Zwei Männer reden eng beieinander auf einer Party.
 PREDICTED: Zwei Männer sprechen bei einer Feier miteinander .
--------------------------------------------------------------------------------
    SOURCE: A man is looking on as another man attempts to climb a small boulder with his dirt bike.
    TARGET: Ein Mann sieht zu, wie ein anderer Mann versucht, mit seinem Dirtbike einen kleinen Felsblock hinaufzuklettern.
 PREDICTED: Ein Mann schaut zu , wie ein anderer Mann versucht , mit seinem Geländemotorrad einen kleinen Rad hochzuklettern .
--------------------------------------------------------------------------------


Processing Epoch 17: 100%|██████████| 3263/3263 [02:26<00:00, 22.34it/s, loss=1.695]


--------------------------------------------------------------------------------
    SOURCE: Crowd watching airplane and helicopter in the sky.
    TARGET: Eine Menschenmenge beobachten ein Flugzeug und einen Hubschrauber am Himmel.
 PREDICTED: Eine Menschenmenge beobachtet im Himmel und einem Hubschrauber .
--------------------------------------------------------------------------------
    SOURCE: A brown-haired girl stoops to pick up a small white puppy on an empty street.
    TARGET: Ein braunhaariges Mädchen bückt sich, um auf einer leeren Straße einen kleinen weißen Hund hochzuheben.
 PREDICTED: Ein braunhaariges Mädchen geht auf einer leeren Straße nach einem kleinen weißen Welpen spazieren .
--------------------------------------------------------------------------------


Processing Epoch 18: 100%|██████████| 3263/3263 [02:16<00:00, 23.99it/s, loss=1.690]


--------------------------------------------------------------------------------
    SOURCE: Several military men surrounding women who are sitting on the ground with flags.
    TARGET: Mehrere männliche Soldaten kreisen Frauen mit Fahnen ein, die auf dem Boden sitzen.
 PREDICTED: Mehrere Soldaten stehen um eine sitzende Frau , die auf dem Boden sitzen .
--------------------------------------------------------------------------------
    SOURCE: A woman in a black dress holding sunglasses.
    TARGET: Eine Frau im schwarzen Kleis hält eine Sonnenbrille in der Hand.
 PREDICTED: Eine Frau in einem schwarzen Kleid hält eine Sonnenbrille .
--------------------------------------------------------------------------------


Processing Epoch 19: 100%|██████████| 3263/3263 [02:16<00:00, 23.82it/s, loss=1.725]


--------------------------------------------------------------------------------
    SOURCE: Hands with painted fingernails unscrewing nail polish.
    TARGET: Hände mit lackierten Fingernägeln schrauben Nagellack auf.
 PREDICTED: Ein Rudel w in der Hocke , der strahlend den Stock durch den Schritt ist .
--------------------------------------------------------------------------------
    SOURCE: A guy in a hard hat working on some machines.
    TARGET: Ein Mann mit Schutzhelm arbeitet an Maschinen.
 PREDICTED: Ein Mann mit Schutzhelm arbeitet an Maschinen .
--------------------------------------------------------------------------------


Processing Epoch 20:  11%|█         | 359/3263 [00:15<02:03, 23.52it/s, loss=1.521]


KeyboardInterrupt: 