#### About
Machine Translation using Transformers.

Dataset link - https://www.kaggle.com/datasets/kaushal2896/english-to-german?select=deu.txt

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import locale
def getpreferredencoding(do_setlocale = True):
    return "UTF-8"
locale.getpreferredencoding = getpreferredencoding

In [3]:
#neccessary imports 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
from torchtext.data.metrics import bleu_score
from sklearn.model_selection import train_test_split
import unicodedata
import re
import spacy
from torch import Tensor
!pip install einops
from einops import rearrange
from torch.nn import (TransformerEncoder, TransformerDecoder,
                      TransformerEncoderLayer, TransformerDecoderLayer)
import math
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter

!python -m spacy download en_core_web_sm
!python -m spacy download de_core_news_sm



Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.0
2023-02-12 05:59:50.760407: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-02-12 05:59:50.760574: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-02-12 

In [72]:
df = pd.read_csv('/content/drive/MyDrive/Datasets/deu.txt',delimiter='\t',header=None)
df.columns = ['English','German','Source']


In [73]:
df.head()

Unnamed: 0,English,German,Source
0,Go.,Geh.,CC-BY 2.0 (France) Attribution: tatoeba.org #2...
1,Hi.,Hallo!,CC-BY 2.0 (France) Attribution: tatoeba.org #5...
2,Hi.,Grüß Gott!,CC-BY 2.0 (France) Attribution: tatoeba.org #5...
3,Run!,Lauf!,CC-BY 2.0 (France) Attribution: tatoeba.org #9...
4,Run.,Lauf!,CC-BY 2.0 (France) Attribution: tatoeba.org #4...


In [74]:
len(df)

221533

In [75]:
#printing few samples:
for i in range(1,len(df),100000):
    print("English Text - {}".format(df['English'][i]))
    print("German Text - {}".format(df['German'][i]))
    print('End of sample- {}'.format(i))

English Text - Hi.
German Text - Hallo!
End of sample- 1
English Text - This may be our only chance.
German Text - Das könnte unsere einzige Chance sein.
End of sample- 100001
English Text - This book deals with the invasion of the Romans.
German Text - Dieses Buch handelt von der Invasion der Römer.
End of sample- 200001


In [76]:
#splitting into train,val
train_df,val_df = train_test_split(df, test_size=0.2, shuffle=True, random_state=42)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

In [77]:
#preprocessing the dataframe
train_df.dropna(inplace=True)
val_df.dropna(inplace=True)

In [78]:
train_df.head()

Unnamed: 0,English,German,Source
0,Get out of here quickly!,Schnell raus hier!,CC-BY 2.0 (France) Attribution: tatoeba.org #8...
1,Did you wash that?,Haben Sie das gewaschen?,CC-BY 2.0 (France) Attribution: tatoeba.org #3...
2,"Unless you make a decision quickly, the opport...","Wenn du dich nicht schnell entscheidest, ist d...",CC-BY 2.0 (France) Attribution: tatoeba.org #1...
3,How did you know Tom was going to Boston?,"Woher wusstest du, dass Tom nach Bosten fliegt?",CC-BY 2.0 (France) Attribution: tatoeba.org #3...
4,Tom is humming.,Tom summt gerade.,CC-BY 2.0 (France) Attribution: tatoeba.org #2...


In [79]:
val_df.head()

Unnamed: 0,English,German,Source
0,He was lying on the couch.,Er lag auf dem Sofa.,CC-BY 2.0 (France) Attribution: tatoeba.org #2...
1,"All generalizations are false, including this ...","Alle Verallgemeinerungen sind falsch, einschli...",CC-BY 2.0 (France) Attribution: tatoeba.org #6...
2,She dyed her white skirt red.,Sie färbte ihren weißen Rock rot.,CC-BY 2.0 (France) Attribution: tatoeba.org #3...
3,Why don't you go play with Tom?,Warum gehst du zum Spielen nicht zu Tom?,CC-BY 2.0 (France) Attribution: tatoeba.org #2...
4,She picked him up at the station.,Sie holte ihn am Bahnhof ab.,CC-BY 2.0 (France) Attribution: tatoeba.org #8...


In [80]:
#cleaning the text
#turning unicode string to plain ASCII

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )
#clean text by converting to lower case, removing non -letter characters
def clean_text(text):
    text = unicodeToAscii(text.lower().strip())
    text = re.sub(r"([.!?])", r" \1", text)
    text = re.sub("[.!?]", '', text)
    text = re.sub(r"[^a-zA-Z.!?]+", r" ", text)
    return text

In [81]:
#applying the clean_text method to df
train_df["English"] = train_df["English"].apply(clean_text)
train_df["German"] = train_df["German"].apply(clean_text)

val_df["English"] = val_df["English"].apply(clean_text)
val_df["German"] = val_df["German"].apply(clean_text)


In [82]:
train_df.head()

Unnamed: 0,English,German,Source
0,get out of here quickly,schnell raus hier,CC-BY 2.0 (France) Attribution: tatoeba.org #8...
1,did you wash that,haben sie das gewaschen,CC-BY 2.0 (France) Attribution: tatoeba.org #3...
2,unless you make a decision quickly the opportu...,wenn du dich nicht schnell entscheidest ist di...,CC-BY 2.0 (France) Attribution: tatoeba.org #1...
3,how did you know tom was going to boston,woher wusstest du dass tom nach bosten fliegt,CC-BY 2.0 (France) Attribution: tatoeba.org #3...
4,tom is humming,tom summt gerade,CC-BY 2.0 (France) Attribution: tatoeba.org #2...


In [83]:
val_df.head()

Unnamed: 0,English,German,Source
0,he was lying on the couch,er lag auf dem sofa,CC-BY 2.0 (France) Attribution: tatoeba.org #2...
1,all generalizations are false including this one,alle verallgemeinerungen sind falsch einschlie...,CC-BY 2.0 (France) Attribution: tatoeba.org #6...
2,she dyed her white skirt red,sie farbte ihren wei en rock rot,CC-BY 2.0 (France) Attribution: tatoeba.org #3...
3,why don t you go play with tom,warum gehst du zum spielen nicht zu tom,CC-BY 2.0 (France) Attribution: tatoeba.org #2...
4,she picked him up at the station,sie holte ihn am bahnhof ab,CC-BY 2.0 (France) Attribution: tatoeba.org #8...


In [84]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [85]:
start_token = 1
end_token =2
oov_token=3

class Vocabulary:
    def __init__(self, language):
        self.language = language
        self.stoi = {}#string2index
        self.stoc = {} # string2count
        self.itos = {0:"<PAD>",start_token:"<START>",end_token:"<END>"}#index2sting #additional- add ,oov_token:"<OOV>"
        for k,v in self.itos.items():
            self.stoi[v]=k
        self.num_words =31
        if self.language == "English":
            self.tokenizer = spacy.load("en_core_web_sm")
        elif self.language == "German":
            self.tokenizer = spacy.load("de_core_news_sm")
        else:
            print("Invalid language")

    def tokenize_sent(self,sentence):
        return [token.text.lower() for token in self.tokenizer.tokenizer(sentence)]
    
    def add_word(self,word):
        if word not in self.stoi:
            self.stoi[word]=self.num_words
            self.stoc[word]=1
            self.itos[self.num_words]=word
            self.num_words+=1
        else:
            self.stoc[word]+=1

    def build_vocab(self,sentence):
        for word in self.tokenize_sent(sentence):
            self.add_word(word)

    def process_sentence(self,sentence):
        
        return [self.stoi[token] for token in self.tokenize_sent(sentence)]

        

In [86]:
#building vocab for english and german
eng_vocab = Vocabulary("English")
ger_vocab = Vocabulary("German")

for english_sentence,german_sentence in zip(train_df["English"].values.tolist(),train_df["German"].values.tolist()):
    eng_vocab.build_vocab(english_sentence)
    ger_vocab.build_vocab(german_sentence)

In [87]:
print("Length of English vocab - {}".format(len(eng_vocab.stoi)))
print("Length of German vocab - {}".format(len(ger_vocab.stoi)))


Length of English vocab - 14630
Length of German vocab - 30955


In [88]:
# creating dataset
class TranslationDataset(Dataset):
    def __init__(self,dataframe,eng_vocab,ger_vocab):
        super(TranslationDataset,self).__init__()
        self.dataframe = dataframe
        self.eng_vocab = eng_vocab
        self.ger_vocab = ger_vocab
        self.english = self.dataframe['English'].values.tolist()
        self.german = self.dataframe['German'].values.tolist()
    
    def __len__(self):
        return len(self.dataframe)

    def process_sent(self,sent,vocab):
        #starting each sentence with start_token and ending with end_token
        processed_sent = [vocab.stoi["<START>"]]
        processed_sent.extend(vocab.process_sentence(sent))
        processed_sent.append(vocab.stoi["<END>"])
        return processed_sent

    def __getitem__(self, index):
        processed_eng_sent = self.process_sent(self.english[index],self.eng_vocab)
        processed_ger_sent = self.process_sent(self.german[index],self.ger_vocab)

        
        item = {'input': torch.tensor(processed_eng_sent), 'output':torch.tensor(processed_ger_sent)}
        return item

In [89]:
train_dataset = TranslationDataset(train_df,eng_vocab,ger_vocab)
val_dataset = TranslationDataset(val_df,eng_vocab,ger_vocab)

In [90]:
train_dataset.__getitem__(5)

{'input': tensor([ 1, 58, 59, 44, 60, 61, 62, 63, 64, 65, 66, 67,  2]),
 'output': tensor([ 1, 56, 57, 58, 59, 60, 61, 62, 63, 64,  2])}

In [91]:
#collate_function
class Collater(object):
    def __init__(self, pad_index):
        self.pad_index = pad_index

    def __call__(self, batch):

        input = [item['input'] for item in batch]
        output = [item['output'] for item in batch]
        input = pad_sequence(input, batch_first=False, padding_value=self.pad_index)
        output = pad_sequence(output, batch_first=False, padding_value=self.pad_index)
        #can do via collater too, but we'll unpack during training
        # inputs = input.detach().numpy().copy()
        # outputs = output.detach().numpy().copy()
        # input_mask, output_mask = [],[]
        # #generating mask for input
        # for input_seq in inputs:
        #     input_mask_seq = []
        #     for input_token in input_seq:
        #         if input_token ==0:
        #             input_mask_seq.append(False)
        #         else:
        #             input_mask_seq.append(True)
        #     input_mask.append(input_mask_seq)
        # #generating mask for output
        # for output_seq in outputs:
        #     output_mask_seq = []
        #     for output_token in output_seq:
        #         if output_token ==0:
        #             output_mask_seq.append(False)
        #         else:
        #             output_mask_seq.append(True)
        #     output_mask.append(output_mask_seq)
        
        # input_mask = torch.tensor(input_mask)
        # output_mask = torch.tensor(output_mask)

        #item = {'input':input,'input_mask':input_mask, 'output':output,'output_mask':output_mask}
        item = {'input':input, 'output':output}
        return item



In [92]:
#creating dataloaders
batch_size=1
pad_idx = eng_vocab.stoi["<PAD>"]

train_loader = DataLoader(train_dataset,batch_size, num_workers=4, shuffle=False,pin_memory=True, collate_fn=Collater(pad_idx))
val_loader = DataLoader(val_dataset,batch_size, num_workers=4, shuffle=False,pin_memory=True, collate_fn=Collater(pad_idx))



In [93]:
for i, batch in enumerate(train_loader):
    
    print("Input batch details ",batch['input'],batch['input'].shape)
    print("_______________________________")
    #print("Input mask batch details ",batch['input_mask'],batch['input_mask'].shape)
    print("_______________________________")
    print("Output batch details", batch['output'],batch['output'].shape)
    print("_______________________________")
    #print("Output mask batch details", batch['output_mask'],batch['output_mask'].shape)

    break

Input batch details  tensor([[ 1],
        [31],
        [32],
        [33],
        [34],
        [35],
        [ 2]]) torch.Size([7, 1])
_______________________________
_______________________________
Output batch details tensor([[ 1],
        [31],
        [32],
        [33],
        [ 2]]) torch.Size([5, 1])
_______________________________


In [94]:
class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int, dropout, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding +
                            self.pos_embedding[:token_embedding.size(0),:])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

In [95]:
INPUT_DIM = len(eng_vocab.stoi)
OUTPUT_DIM = len(ger_vocab.stoi)

ENC_EMB_DIM = 32
DEC_EMB_DIM = 32
ENC_HID_DIM = 64
DEC_HID_DIM = 64
ATTN_DIM = 8
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5

# Source: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(INPUT_DIM, ENC_EMB_DIM)
        self.tgt_embedding = nn.Embedding(INPUT_DIM, ENC_EMB_DIM)
        self.transformer = nn.Transformer(nhead=8, num_encoder_layers=2, d_model=ENC_EMB_DIM)
        self.linear = nn.Linear(ENC_EMB_DIM, OUTPUT_DIM)
        pos_dropout = 0.1
        max_seq_length = 128
        self.pos_enc = PositionalEncoding(ENC_EMB_DIM, pos_dropout, max_seq_length)
    
    def forward(self, src, tgt, teacher_forcing_ratio=0.5, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None, tgt_mask=None):
        # TODO: Investigate masks, positional encoding, understand Rearrange(), debug model output (output has negative numbers for some reason)
        # Original src shape: (sentence length=24?, batch_size=128)
        # Original tgt shape: (sentence length=24?, batch_size=128)
        # Transformer expects: (sentence length=24, batch_size=128, embedding_size=128)
        src_emb = self.pos_enc(self.src_embedding(src) * math.sqrt(ENC_EMB_DIM))
        tgt_emb = self.pos_enc(self.tgt_embedding(tgt) * math.sqrt(ENC_EMB_DIM))
        out = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        out = self.linear(out)
        return out

In [96]:

writer = SummaryWriter(f'runs/loss_plot')

model = TransformerModel().to(device)

In [97]:
model

TransformerModel(
  (src_embedding): Embedding(14630, 32)
  (tgt_embedding): Embedding(14630, 32)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
          )
          (linear1): Linear(in_features=32, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=32, bias=True)
          (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_fe

In [98]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=0)#0 for <PAD>

In [99]:
#using xavier normal init in the transformer
for param in model.parameters():
    if param.dim() >1:
        nn.init.xavier_normal_(param)


In [100]:
def gen_nopeek_mask(length):
    mask = rearrange(torch.triu(torch.ones(length, length)) == 1, 'h w -> w h')
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

    return mask

In [101]:
PAD_IDX =0
def indices_to_string(LANGUAGE, batch):
    if LANGUAGE == "ENGLISH":
        vocab = eng_vocab
    else:
        vocab = ger_vocab
    words_list = []
    for sentence in batch.transpose(1, 0):
        sentence_list = sentence.tolist()
        words = []
        for index in sentence_list:
            word = vocab.itos[index]
            words.append(word)
        words_list.append(words)
    return words_list


In [102]:
#Training


step=0
num_epochs=10
def fit(num_epochs,train_loader,val_loader, model, optimizer,criterion):
    model.train()
    for epoch in range(num_epochs):
        checkpoint = {'state_dict':model.state_dict(), 'optimizer':optimizer.state_dict()}
        if epoch%10==0:
            torch.save(checkpoint,"checkpoint.pth.tar")
        #loading batch
        for i,batch in enumerate(train_loader):
            input_seq = batch['input'].to(device)
            #input_mask = batch['input_mask'].to(device)
            output_seq = batch['output'].to(device)
            #output_mask = batch['output_mask'].to(device)
            input_padding_mask = input_seq == PAD_IDX
            output_padding_mask = output_seq == PAD_IDX
            #input_mask, output_mask, input_padding_mask, output_padding_mask = create_mask(input_seq, output_seq)
            memory_key_padding_mask = input_padding_mask.clone()
            input_padding_mask = rearrange(input_padding_mask, 'n s -> s n')
            output_padding_mask = rearrange(output_padding_mask, 'n s -> s n')
            memory_key_padding_mask = rearrange(memory_key_padding_mask, 'n s -> s n')
            
            tgt_sentence_len = output_seq.shape[0] - torch.sum(output_padding_mask,axis=1)
            tgt_inp, tgt_out = output_seq[:,:], output_seq[:,:]
            tgt_mask = gen_nopeek_mask(output_seq.shape[0])#.to('cuda') #('cuda)
            tgt_mask = tgt_mask.to(device)
            output = model(input_seq,tgt_inp,0,input_padding_mask, output_padding_mask,memory_key_padding_mask,tgt_mask)

            #generating one hot
            from_one_hot = torch.argmax(output,dim=2)

            output = output.view(-1, output.shape[-1])
            tgt_out = tgt_out.view(-1)
            loss = criterion(output,tgt_out)
            loss.backward()
            optimizer.step()

        #validation
        model.eval()
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                input_seq = batch['input'].to(device)
                output_seq = batch['output'].to(device)
                input_padding_mask = input_seq == PAD_IDX
                output_padding_mask = output_seq == PAD_IDX
                #input_mask, output_mask, input_padding_mask, output_padding_mask = create_mask(input_seq, output_seq)
                memory_key_padding_mask = input_padding_mask.clone()
                input_padding_mask = rearrange(input_padding_mask, 'n s -> s n')
                output_padding_mask = rearrange(output_padding_mask, 'n s -> s n')
                memory_key_padding_mask = rearrange(memory_key_padding_mask, 'n s -> s n')
                
                tgt_sentence_len = output_seq.shape[0] - torch.sum(output_padding_mask,axis=1)
                tgt_inp, tgt_out = output_seq[:,:], output_seq[:,:]
                tgt_mask = gen_nopeek_mask(output_seq.shape[0])#.to('cuda') #('cuda)
                tgt_mask = tgt_mask.to(device)
                output = model(input_seq,tgt_inp,0,input_padding_mask, output_padding_mask,memory_key_padding_mask,tgt_mask)

                #generating one hot
                from_one_hot = torch.argmax(output,dim=2)

                output = output.view(-1, output.shape[-1])
                src_words = indices_to_string("ENGLISH",input_seq)
                predicted_words = indices_to_string("GERMAN",from_one_hot)
                tgt_words = indices_to_string("GERMAN",output_seq)
                print("Input English word - {}".format(src_words))
                print('Output German word - {}'.format(predicted_words))
                print("Actual German Word - {}".format(tgt_words))
                output_seq = output_seq.view(-1)
                val_loss = criterion(output,output_seq)

        print("Epoch - {}, Train Loss - {}, Val Loss - {}".format(epoch,loss.item(),val_loss.item()))
        writer.add_scalar("Train loss",loss, global_step=step)
        step+=1    

In [None]:
fit(1,train_loader,val_loader,model,optimizer,criterion)

In [None]:

def string_to_indices(LANGUAGE, sentence):
    words = sentence.split()
    indices = []
    if LANGUAGE == "ENGLISH":
        vocab = eng_vocab
    else:
        vocab = ger_vocab
    for word in words:
        if word in vocab.stoi:
            index = vocab.stoi[word]
            indices.append(index)
        else:
            index = 0 # should ideally be unknown but we are assuming it's pad
            indices.append(index)
    result = torch.tensor(indices)
    return result

def translate_english_to_german(model, example_sentence_src):
    # Translate example sentence
    example_tensor_src = string_to_indices("ENGLISH", example_sentence_src).view(-1, 1)
    example_sentence_tgt = '<START>'
    example_tensor_tgt = string_to_indices("GERMAN", example_sentence_tgt).view(-1, 1)
    src = example_tensor_src.to(device)
    tgt = example_tensor_tgt.to(device)

    for i in range(128):
        print('Source Translation', src)
        print('Target Translation', tgt)
        src_key_padding_mask = src == PAD_IDX
        tgt_key_padding_mask = tgt == PAD_IDX
        memory_key_padding_mask = src_key_padding_mask.clone()
        src_key_padding_mask = rearrange(src_key_padding_mask, 'n s -> s n')
        tgt_key_padding_mask = rearrange(tgt_key_padding_mask, 'n s -> s n')
        memory_key_padding_mask = rearrange(memory_key_padding_mask, 'n s -> s n')
        tgt_mask = gen_nopeek_mask(tgt.shape[0]).to('cpu')
        print('Target Mask:', tgt_mask)

        output = model(src, tgt, 0, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask, tgt_mask=tgt_mask) 

        print('Output:', output)
        # TODO: Check that the argmax line is correct
        output_index = torch.argmax(output, dim=2)[-1].item()
        output_word = ger_vocab.itos[output_index]
        example_sentence_tgt = example_sentence_tgt + ' ' + output_word
        print('Translated sentence so far:', example_sentence_tgt)
        example_tensor_tgt = string_to_indices("GERMAN", example_sentence_tgt).view(-1, 1)
        tgt = example_tensor_tgt.to(device)
        if output_word == '<END>':
            break

In [None]:
example = "<START> This is me <END> <PAD> <PAD> <PAD> <PAD>"
print(translate_english_to_german(model,example))