In [1]:
!pip install -q gdown

In [None]:
# download model
# Google Drive file id; no --output is given, so gdown picks the local filename itself
!gdown 1-1VRanRyciw8IOMnAoE9p7ftvpWRNS7w

In [2]:
# download dataset
# Presumably the archive produced by the "Zip prepared data" cell; unpacked in "Unzip prepared data"
!gdown 1sqaarcZFTvB2mGVTwpSOc_srNSw2GgCQ --output transformer_data.zip

Downloading...
From (original): https://drive.google.com/uc?id=1sqaarcZFTvB2mGVTwpSOc_srNSw2GgCQ
From (redirected): https://drive.google.com/uc?id=1sqaarcZFTvB2mGVTwpSOc_srNSw2GgCQ&confirm=t&uuid=aaa03b6e-16b2-407d-9a92-3f8480c7cacd
To: /kaggle/working/transformer_data.zip
100%|████████████████████████████████████████| 799M/799M [00:10<00:00, 78.6MB/s]


In [3]:
!pip install -q youtokentome wget sacrebleu

# Imports

In [5]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

import logging

import math
import os
import tarfile
import wget
import shutil
import itertools
import time
import codecs
from tqdm import tqdm
from datetime import date
import sacrebleu

from random import shuffle
import youtokentome as yttm

from itertools import groupby

# Download data

In [None]:
def flatten(destination):
    """
    Move every file found in any subdirectory of ``destination`` directly into ``destination``
    and delete those subdirectories.

    The tree is walked bottom-up (``topdown=False``) so that files inside nested subdirectories
    are moved before their parent directory is removed; a top-down walk would delete them.
    ``shutil.move`` raises if a file with the same name already exists in ``destination``.

    :param destination: directory to flatten (str or path-like)
    """
    top = os.fspath(destination)
    for root, _dirs, files in os.walk(top, topdown=False):
        if root == top:
            # The destination itself is yielded last; its files are already in place.
            continue
        for filename in files:
            shutil.move(os.path.join(root, filename), top)
        # Children were handled (and removed) earlier in the bottom-up walk.
        shutil.rmtree(root)

def download_data(data_folder):
    """
    Download and extract the WMT de-en training corpora and write the sacreBLEU val/test sets.

    Produces, under ``data_folder``:
        tar_files/         the downloaded .tgz archives (reused if already present)
        extracted_files/   the de-en members of each archive, flattened into one directory
        val.en, val.de     source / reference side of sacreBLEU test set "wmt13" (en-de)
        test.en, test.de   source / reference side of sacreBLEU test set "wmt14/full" (en-de)

    :param data_folder: root folder for all downloaded and extracted data
    :raises subprocess.CalledProcessError: if a ``sacrebleu --echo`` command fails
    """
    import subprocess

    train_urls = ["http://www.statmt.org/wmt13/training-parallel-europarl-v7.tgz",
                  "https://www.statmt.org/wmt13/training-parallel-commoncrawl.tgz",
                  "http://www.statmt.org/wmt14/training-parallel-nc-v9.tgz",
                 ]

    # Create a folder to store downloaded TAR files
    tars_dir = os.path.join(data_folder, "tar_files")
    os.makedirs(tars_dir, exist_ok=True)

    # Always start from an empty extraction folder; previous extractions are deleted to prevent tarfile module errors.
    # (The folder is created whether or not it existed before.)
    extracted_dir = os.path.join(data_folder, "extracted_files")
    if os.path.isdir(extracted_dir):
        shutil.rmtree(extracted_dir)
    os.makedirs(extracted_dir)

    # Download and extract training data
    for url in train_urls:
        filename = url.split("/")[-1]
        train_data_path = os.path.join(tars_dir, filename)

        if not os.path.exists(train_data_path):
            print(f"\nDownloading {filename}...")
            wget.download(url, train_data_path)

        print(f"\nExtracting {filename}...")
        # Context manager closes the archive even if extraction fails.
        with tarfile.open(train_data_path) as tar:
            members = [m for m in tar.getmembers() if "de-en" in m.path]
            tar.extractall(extracted_dir, members=members)

    def echo_test_set(test_set, side, out_name):
        # Same as `sacrebleu -t <set> -l en-de --echo <side> > <out>`, but without a shell:
        # no path quoting issues, and a failing command raises instead of being ignored.
        with open(os.path.join(data_folder, out_name), "w", encoding="utf-8") as out:
            subprocess.run(["sacrebleu", "-t", test_set, "-l", "en-de", "--echo", side],
                           stdout=out, check=True)

    # Download validation and testing data using sacreBLEU since we will be using this library to calculate BLEU scores
    print("\n")
    echo_test_set("wmt13", "src", "val.en")
    echo_test_set("wmt13", "ref", "val.de")
    print("\n")
    echo_test_set("wmt14/full", "src", "test.en")
    echo_test_set("wmt14/full", "ref", "test.de")

    flatten(extracted_dir)

In [None]:
# Root folder for the raw downloads, extracted corpora, prepared train/val/test files and the BPE model
data_folder="/kaggle/working/transformer_data"

In [None]:
# Download the WMT training archives, extract their de-en files and fetch val/test sets via sacreBLEU
download_data(data_folder)

# Prepare data

In [None]:
def prepare_data(data_folder,
                 euro_parl=True,
                 common_crawl=True,
                 news_commentary=True,
                 retain_case=True):
    """
    Merge the selected extracted corpora into single training files and learn a joint BPE model.

    Writes train.en, train.de and train.ende (English lines followed by German lines) into
    ``data_folder``, then trains a 37000-token youtokentome BPE model on train.ende.

    :param data_folder: root data folder; corpora are read from its "extracted_files" subfolder
    :param euro_parl: include europarl-v7.de-en
    :param common_crawl: include commoncrawl.de-en
    :param news_commentary: include news-commentary-v9.de-en
    :param retain_case: when False, all text is lower-cased before being written
    :returns: (bpe_model, bpe_model_path)
    """
    assert euro_parl or common_crawl or news_commentary, "Set at least one dataset to True!"

    candidates = [(euro_parl, "europarl-v7.de-en"),
                  (common_crawl, "commoncrawl.de-en"),
                  (news_commentary, "news-commentary-v9.de-en")]
    corpora = [name for enabled, name in candidates if enabled]

    extracted_dir = os.path.join(data_folder, "extracted_files")

    def read_lines(path):
        # Whole-file read, optional lower-casing, split on "\n" only.
        with codecs.open(path, "r", encoding="utf-8") as handle:
            text = handle.read()
        return (text if retain_case else text.lower()).split("\n")

    def write_lines(name, lines):
        with codecs.open(os.path.join(data_folder, name), "w", encoding="utf-8") as handle:
            handle.write("\n".join(lines))

    german, english = [], []

    print("\nReading extracted files and combining...")
    for corpus in corpora:
        german += read_lines(os.path.join(extracted_dir, corpus + ".de"))
        english += read_lines(os.path.join(extracted_dir, corpus + ".en"))
        # Each corpus must contribute the same number of lines on both sides.
        assert len(english) == len(german)

    # Write to file so stuff can be freed from memory
    print("\nWriting to single files...")
    write_lines("train.en", english)
    write_lines("train.de", german)
    write_lines("train.ende", english + german)
    del english, german  # free some RAM

    bpe_model_path = os.path.join(data_folder, "bpe.model")

    print("\nLearning BPE...")
    yttm.BPE.train(data=os.path.join(data_folder, "train.ende"),
                   vocab_size=37000,
                   model=bpe_model_path)

    print("\nLoading BPE model...")
    return yttm.BPE(model=bpe_model_path), bpe_model_path

In [None]:
# Combine the corpora into train.{en,de,ende} and learn the joint 37000-token BPE model
bpe_model, bpe_model_path = prepare_data(data_folder)

In [None]:
# Sanity check: size of the learned BPE vocabulary
bpe_model.vocab_size()

In [None]:
# Sanity check: split two English sentences into sub-word pieces, wrapped in <BOS>/<EOS>
bpe_model.encode(['There was no ring on his finger.',
                  'That was a good sign although far from proof that he was available.'],
                 output_type=yttm.OutputType.SUBWORD,
                 bos=True, eos=True)

# Filter data

In [None]:
def filter(tokenizer,
           data_folder,
           min_length=3,
           max_length=100,
           max_length_ratio=1.5):
    """
    Keep only training pairs whose sub-word lengths are within limits; rewrite train.en / train.de.

    A pair is kept when both sides have strictly more than ``min_length`` and strictly fewer than
    ``max_length`` sub-word tokens, and len(de) / len(en) lies in [1 / max_length_ratio, max_length_ratio].
    The combined BPE-training file train.ende is removed if it exists, so the function can be re-run.
    Rewritten files end with a newline (when non-empty).

    NOTE: this name shadows Python's built-in ``filter`` once the cell has run.

    :param tokenizer: youtokentome BPE model used to count sub-word tokens
    :param data_folder: folder containing train.en and train.de
    :param min_length: exclusive lower bound on tokens per side
    :param max_length: exclusive upper bound on tokens per side
    :param max_length_ratio: maximum allowed de/en (and en/de) token-count ratio
    """
    train_en = os.path.join(data_folder, "train.en")
    train_de = os.path.join(data_folder, "train.de")
    train_ende = os.path.join(data_folder, "train.ende")

    # Re-read English, German
    print("\nRe-reading single files...")
    with codecs.open(train_en, "r", encoding="utf-8") as f:
        english = f.read().split("\n")
    with codecs.open(train_de, "r", encoding="utf-8") as f:
        german = f.read().split("\n")

    def n_tokens(sentence):
        return len(tokenizer.encode(sentence, output_type=yttm.OutputType.ID))

    print("\nFiltering...")
    pairs = list()
    for en, de in tqdm(zip(english, german), total=len(english)):
        len_en_tok = n_tokens(en)
        len_de_tok = n_tokens(de)

        lengths_ok = (min_length < len_en_tok < max_length and
                      min_length < len_de_tok < max_length)
        # len_en_tok > 0 protects the ratio even if a caller passes a negative min_length.
        if lengths_ok and len_en_tok > 0 and \
                1. / max_length_ratio <= len_de_tok / len_en_tok <= max_length_ratio:
            pairs.append((en, de))

    n_total = len(english)
    dropped_pct = 100. * (n_total - len(pairs)) / n_total if n_total else 0.
    print("\nNote: %.2f per cent of en-de pairs were filtered out based on sub-word sequence length limits." % dropped_pct)

    # List comprehensions (rather than zip(*pairs)) also work when nothing survives filtering.
    english = [en for en, _ in pairs]
    german = [de for _, de in pairs]

    print("\nRe-writing filtered sentences to single files...")
    # train.ende is only needed for BPE learning; it is already gone if this cell was run before.
    if os.path.exists(train_ende):
        os.remove(train_ende)

    for path, lines in ((train_en, english), (train_de, german)):
        # Mode "w" truncates the existing file, so no explicit removal is needed.
        with codecs.open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
            if lines:
                # Terminate the last line so readers that strip one trailing "" keep every pair.
                f.write("\n")

    del english, german, pairs

In [None]:
# Drop pairs outside the sub-word length / length-ratio limits and rewrite train.en / train.de.
# NOTE(review): `filter` shadows Python's built-in filter() from here on.
filter(bpe_model, data_folder)

# Zip prepared data

In [None]:
# Archive the prepared data, excluding the raw downloads (tar_files) and extracted corpora
!zip -r transformer_data.zip  transformer_data -x "transformer_data/extracted_files/*" "transformer_data/tar_files/*"

In [None]:
# Render a download link for the archive in the notebook output
from IPython.display import FileLink
FileLink(r'transformer_data.zip')

# Unzip prepared data

In [6]:
# Restore the prepared data; -o overwrites existing files without prompting.
# NOTE(review): extracting into "/" writes at the filesystem root; the captured output shows it
# also creates /transformer_data, while the notebook only reads /kaggle/working/transformer_data.
!unzip -o transformer_data.zip -d /

Archive:  transformer_data.zip
   creating: /kaggle/working/transformer_data/
  inflating: /kaggle/working/transformer_data/val.en  
  inflating: /kaggle/working/transformer_data/train.de  
  inflating: /kaggle/working/transformer_data/bpe.model  
  inflating: /kaggle/working/transformer_data/train.en  
  inflating: /kaggle/working/transformer_data/test.en  
  inflating: /kaggle/working/transformer_data/val.de  
  inflating: /kaggle/working/transformer_data/test.de  
   creating: /transformer_data/
  inflating: /transformer_data/val.en  
  inflating: /transformer_data/train.de  
  inflating: /transformer_data/bpe.model  
  inflating: /transformer_data/train.en  
  inflating: /transformer_data/test.en  
  inflating: /transformer_data/val.de  
  inflating: /transformer_data/test.de  


In [7]:
# Re-define paths so the notebook can start from here using the unzipped data
data_folder="/kaggle/working/transformer_data"
bpe_model_path = os.path.join(data_folder, "bpe.model")

In [8]:
# Load the BPE model saved during data preparation
bpe_model = yttm.BPE(model=bpe_model_path)

In [9]:
# Same sanity check as in "Prepare data", run against the reloaded BPE model
bpe_model.encode(['There was no ring on his finger.',
                  'That was a good sign although far from proof that he was available.'],
                 output_type=yttm.OutputType.SUBWORD,
                 bos=True, eos=True)

[['<BOS>',
  '▁There',
  '▁was',
  '▁no',
  '▁ring',
  '▁on',
  '▁his',
  '▁fing',
  'er.',
  '<EOS>'],
 ['<BOS>',
  '▁That',
  '▁was',
  '▁a',
  '▁good',
  '▁sign',
  '▁although',
  '▁far',
  '▁from',
  '▁proof',
  '▁that',
  '▁he',
  '▁was',
  '▁available.',
  '<EOS>']]

# Transformer implementation

In [10]:
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal position encodings to token embeddings, then apply dropout.

    The table is stored as the buffer ``pos_embedding`` of shape (maxlen, 1, emb_size) and is
    sliced along dimension 0 of the input, so the input is expected sequence-first:
    (seq_len, batch, emb_size). Even feature indices hold sines, odd indices hold cosines.
    """

    def __init__(
        self,
        emb_size,
        dropout,
        maxlen=5000
    ):
        super().__init__()
        # exp(-2i * ln(10000) / emb_size) == 1 / 10000^(2i / emb_size) for each even index 2i
        inv_freq = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        angles = torch.arange(0, maxlen).reshape(maxlen, 1) * inv_freq

        table = torch.zeros((maxlen, emb_size))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)

        self.dropout = nn.Dropout(dropout)
        # Middle singleton axis broadcasts over the batch dimension of (seq, batch, emb) inputs.
        self.register_buffer('pos_embedding', table.unsqueeze(-2))

    def forward(self, token_embedding):
        seq_len = token_embedding.size(0)
        encoded = token_embedding + self.pos_embedding[:seq_len]
        return self.dropout(encoded)

class Translator(nn.Module):
    """
    Sequence-to-sequence translation model built on ``nn.Transformer`` with ``batch_first=True``.

    Token-id tensors are (N, L); embeddings and transformer activations are (N, L, embed_size);
    ``forward`` returns logits of shape (N, T, tgt_vocab_size).

    NOTE: positional encodings are applied along the sequence axis (see ``_embed``). A previous
    version added them along the batch axis, so weights trained with it learned without
    per-token position information.
    """
    def __init__(
            self,
            num_encoder_layers,
            num_decoder_layers,
            embed_size,
            num_heads,
            src_vocab_size,
            tgt_vocab_size,
            dim_feedforward,
            dropout
        ):
        super(Translator, self).__init__()

        # Output of embedding must be equal (embed_size)
        self.src_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_size)

        self.pos_enc = PositionalEncoding(embed_size, dropout)

        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

        # Projection from hidden states to target-vocabulary logits
        self.ff = nn.Linear(embed_size, tgt_vocab_size)

        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform for every parameter with more than one dimension; 1-D parameters keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _embed(self, embedding, tokens):
        """
        Embed (N, L) token ids and add positional encodings along the sequence axis.

        PositionalEncoding slices its table by dimension 0 of its input, i.e. it expects (L, N, E).
        The batch-first embeddings are therefore transposed for it and transposed back; passing
        (N, L, E) directly would add the encoding of the batch index to every token.
        """
        embedded = embedding(tokens)  # (N, L, E)
        return self.pos_enc(embedded.transpose(0, 1)).transpose(0, 1)

    def forward(self, src, trg, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        src_emb = self._embed(self.src_embedding, src)
        tgt_emb = self._embed(self.tgt_embedding, trg)

        # Shape:
        #     - src:  (N, S, E) if batch_first=True.
        #     - tgt: (N, T, E) if batch_first=True.
        #     - src_mask: (S, S) or (N * num_heads, S, S).
        #     - tgt_mask: (T, T) or (N * num_heads, T, T).
        #     - memory_mask: (T, S).
        #     - src_key_padding_mask: (N, S).
        #     - tgt_key_padding_mask: (N, T).
        #     - memory_key_padding_mask: (N, S).

        outs = self.transformer(
            src=src_emb,
            tgt=tgt_emb,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )

        return self.ff(outs)

    def encode(self, src, src_mask):
        """Run only the encoder on (N, S) source ids; returns memory of shape (N, S, E)."""
        return self.transformer.encoder(self._embed(self.src_embedding, src), mask=src_mask)

    def decode(self, tgt, memory, tgt_mask):
        """Run only the decoder; returns hidden states (N, T, E) — apply ``self.ff`` for logits."""
        return self.transformer.decoder(self._embed(self.tgt_embedding, tgt), memory, tgt_mask=tgt_mask)

# Initialize data loaders

In [11]:
class SequenceLoader(object):
    """
    An iterator for loading batches of data into the transformer model.

    For training:

        Each batch contains tokens_in_batch target language tokens (approximately),
        target language sequences of the same length to minimize padding and therefore memory usage,
        source language sequences of very similar (if not the same) lengths to minimize padding and therefore memory usage.
        Batches are also shuffled.

    For validation and testing:

        Each batch contains just a single source-target pair, in the same order as in the files from which they were read.

    Iteration does not rewind by itself: call create_batches() before every pass.
    """

    def __init__(self, tokenizer, data_folder, source_suffix, target_suffix, split, tokens_in_batch):
        """
        :param tokenizer: BPE model exposing encode(sentences, output_type=..., bos=..., eos=...) and subword_to_id(...)
        :param data_folder: folder containing the source and target language data files
        :param source_suffix: the filename suffix for the source language files
        :param target_suffix: the filename suffix for the target language files
        :param split: train, or val, or test?
        :param tokens_in_batch: the number of target language tokens in each batch
        """
        self.tokens_in_batch = tokens_in_batch

        self.source_suffix = source_suffix
        self.target_suffix = target_suffix

        assert split.lower() in {"train", "val","test"}, "'split' must be one of 'train', 'val', 'test'! (case-insensitive)"
        self.split = split.lower()

        # Is this for training?
        self.for_training = self.split == "train"

        # Load BPE model
        self.bpe_model = tokenizer

        # Load data (one sentence per line; file names use the split exactly as passed in)
        source_data = self._read_lines(os.path.join(data_folder, ".".join([split, source_suffix])))
        target_data = self._read_lines(os.path.join(data_folder, ".".join([split, target_suffix])))
        assert len(source_data) == len(target_data), "There are a different number of source or target sequences!"

        source_lengths = [len(s) for s in self.bpe_model.encode(source_data, bos=False, eos=False)]

        # target language sequences have <BOS> and <EOS> tokens
        target_lengths = [len(t) for t in self.bpe_model.encode(target_data, bos=True, eos=True)]

        self.data = list(zip(source_data, target_data, source_lengths, target_lengths))

        # If for training, pre-sort by target lengths - required for itertools.groupby() later
        if self.for_training:
            self.data.sort(key=lambda x: x[3])

        # Create batches
        self.create_batches()

    @staticmethod
    def _read_lines(path):
        """
        Read a UTF-8 text file into a list of lines.

        Only a single trailing empty element (produced by a final "\\n") is dropped, so a last line
        without a terminating newline is kept. Unconditionally slicing off the last element would
        lose a real sentence from files written with "\\n".join(...).
        """
        with codecs.open(path, "r", encoding="utf-8") as f:
            lines = f.read().split("\n")
        if lines and lines[-1] == "":
            lines.pop()
        return lines

    def create_batches(self):
        """
        Prepares batches for one epoch.
        """

        print("Creating batches")

        # If training
        if self.for_training:
            # Group or chunk based on target sequence lengths
            chunks = [list(g) for _, g in groupby(self.data, key=lambda x: x[3])]

            # Create batches, each with the same target sequence length
            self.all_batches = list()
            for chunk in chunks:
                # Sort inside chunk by source sequence lengths, so that a batch would also have similar source sequence lengths
                chunk.sort(key=lambda x: x[2])
                # How many sequences in each batch? Divide expected batch size (i.e. tokens) by target sequence length in this chunk.
                # At least one, so targets longer than tokens_in_batch still get a (single-sequence) batch.
                seqs_per_batch = max(1, self.tokens_in_batch // chunk[0][3])
                # Split chunk into batches
                self.all_batches.extend([chunk[i: i + seqs_per_batch] for i in range(0, len(chunk), seqs_per_batch)])

            # Shuffle batches
            shuffle(self.all_batches)
        else:
            # Simply return one pair at a time
            self.all_batches = [[d] for d in self.data]

        self.n_batches = len(self.all_batches)
        self.current_batch = -1

    def __iter__(self):
        """
        Iterators require this method defined.
        """
        return self

    def __len__(self):
        return len(self.all_batches)

    def __next__(self):
        """
        Iterators require this method defined.

        :returns: the next batch as a tuple (source_data, target_data):
            source language sequences, a LongTensor of size (N, encoder_sequence_pad_length), without <BOS>/<EOS>
            target language sequences, a LongTensor of size (N, decoder_sequence_pad_length), wrapped in <BOS> ... <EOS>
            Shorter sequences are right-padded with the <PAD> id.
        """
        # Update current batch index
        self.current_batch += 1
        try:
            source_data, target_data, _source_lengths, _target_lengths = zip(*self.all_batches[self.current_batch])
        # Stop iteration once all batches are iterated through
        except IndexError:
            raise StopIteration

        # Tokenize using BPE model to word IDs
        source_data = self.bpe_model.encode(source_data, output_type=yttm.OutputType.ID,
                                            bos=False, eos=False)
        target_data = self.bpe_model.encode(target_data, output_type=yttm.OutputType.ID,
                                            bos=True, eos=True)

        pad_id = self.bpe_model.subword_to_id('<PAD>')

        # Convert source and target sequences as padded tensors
        source_data = pad_sequence(sequences=[torch.LongTensor(s) for s in source_data],
                                   batch_first=True,
                                   padding_value=pad_id)
        target_data = pad_sequence(sequences=[torch.LongTensor(t) for t in target_data],
                                   batch_first=True,
                                   padding_value=pad_id)

        return source_data, target_data

In [104]:
def generate_square_subsequent_mask(size, device):
    """
    Causal (look-ahead) attention mask: a float32 tensor of shape (size, size).

    Entry [i, j] is 0.0 when j <= i (position i may attend to j) and -inf when j > i.
    """
    may_attend = torch.tril(torch.ones((size, size), device=device)) == 1
    blocked = torch.zeros((size, size), device=device, dtype=torch.float32)
    return blocked.masked_fill(~may_attend, float('-inf'))

# Create masks for input into model
def create_mask(src, tgt, pad_idx, device):
    """
    Build the four masks passed to Translator.forward for batch-first token tensors.

    :param src: (N, S) source token ids
    :param tgt: (N, T) decoder-input token ids
    :param pad_idx: id of the padding token
    :param device: device on which the square masks are created
    :returns: (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
        src_mask          (S, S) bool, all False — no restriction on encoder self-attention
        tgt_mask          (T, T) float causal mask of 0.0 / -inf
        src_padding_mask  (N, S) bool, True at padding positions
        tgt_padding_mask  (N, T) float, -inf at padding positions and 0.0 elsewhere
    """
    src_len = src.shape[1]
    tgt_len = tgt.shape[1]

    tgt_mask = generate_square_subsequent_mask(tgt_len, device)
    src_mask = torch.zeros((src_len, src_len), device=device, dtype=torch.bool)

    src_padding_mask = src == pad_idx
    tgt_is_pad = tgt == pad_idx
    tgt_padding_mask = torch.zeros_like(tgt_is_pad, dtype=torch.float).masked_fill(tgt_is_pad, float('-inf'))

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [105]:
def get_data(opts, tokenizer):
    """
    Build the train / validation loaders plus tokenizer helpers.

    :param opts: options object with ``src`` and ``tgt`` (data file suffixes), ``tokens_in_batch``
                 and, optionally, ``data_folder`` (defaults to "/kaggle/working/transformer_data")
    :param tokenizer: youtokentome BPE model shared by both languages
    :returns: (train_dataloader, valid_dataloader, vocab, src_lang_transform, tgt_lang_transform, special_symbols)
              where special_symbols maps "<unk>", "<pad>", "<bos>", "<eos>" to tokenizer ids
    """
    data_root = getattr(opts, "data_folder", "/kaggle/working/transformer_data")

    def make_loader(split):
        # Both splits share every setting except the split name.
        return SequenceLoader(tokenizer,
                              data_folder=data_root,
                              source_suffix=opts.src,
                              target_suffix=opts.tgt,
                              split=split,
                              tokens_in_batch=opts.tokens_in_batch)

    train_dataloader = make_loader("train")
    valid_dataloader = make_loader("val")

    vocab = tokenizer.vocab()

    def src_lang_transform(src_lang):
        # Source side: token ids without <BOS>/<EOS>
        return tokenizer.encode(src_lang, output_type=yttm.OutputType.ID, bos=False, eos=False)

    def tgt_lang_transform(tgt_lang):
        # Target side: token ids wrapped in <BOS> ... <EOS>
        return tokenizer.encode(tgt_lang, output_type=yttm.OutputType.ID, bos=True, eos=True)

    special_symbols = {
        "<unk>": tokenizer.subword_to_id('<UNK>'),
        "<pad>": tokenizer.subword_to_id('<PAD>'),
        "<bos>": tokenizer.subword_to_id('<BOS>'),
        "<eos>": tokenizer.subword_to_id('<EOS>'),
    }

    return train_dataloader, valid_dataloader, vocab, src_lang_transform, tgt_lang_transform, special_symbols

In [106]:
class Opts:
    """Data-loading options passed to get_data()."""

    def __init__(self):
        # Data file suffixes: source / target language (e.g. train.en / train.de)
        self.src, self.tgt = "en", "de"
        # NOTE(review): not read by get_data() or SequenceLoader
        self.batch = 128
        # Approximate number of target tokens per training batch
        self.tokens_in_batch = 2000

opts = Opts()

(train_dl, valid_dl, vocab,
 src_lang_transform, tgt_lang_transform, special_symbols) = get_data(opts, bpe_model)

print(f"vocab size: {len(vocab)}")

Creating batches
Creating batches
vocab size: 37000


# Parameters

In [107]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Training settings ---
epochs = 5                   # last epoch number to train (inclusive)
lr = 1e-4                    # Adam learning rate
betas, eps = (0.9, 0.98), 1e-9

# --- Transformer settings ---
attn_heads = 8
enc_layers, dec_layers = 5, 5
embed_size = 512
dim_feedforward = 512
dropout = 0.1

# --- Resume state (replaced when a checkpoint is loaded) ---
best_val_loss = 1e6
start_epoch = 1  # start at this epoch

dry_run = True               # NOTE(review): not read by any visible cell

logging_dir = "/kaggle/working/" + str(date.today()) + "/"
os.makedirs(logging_dir, exist_ok=True)

# Per-phase history; each phase gets its own list
phases = ['train', 'val']
saved_epoch_losses = dict((phase, []) for phase in phases)
saved_bleu_metrics = dict((phase, []) for phase in phases)

In [114]:
def train(model, train_dl, loss_fn, optim, special_symbols):
    """
    Train ``model`` for one pass over ``train_dl`` with teacher forcing.

    :param model: Translator-like module called as model(src, tgt_in, masks...) returning (N, T, vocab) logits
    :param train_dl: iterable of (src, tgt) LongTensor batches that also supports len()
    :param loss_fn: loss over flattened logits and targets
    :param optim: optimizer over the model's parameters
    :param special_symbols: dict containing at least the "<pad>" token id
    :returns: (mean loss per batch, fraction of non-pad target tokens predicted correctly);
              both are nan if ``train_dl`` yields no batches
    """
    model.train()
    pad_idx = special_symbols["<pad>"]

    losses = 0.0
    n_batches = 0
    running_corrects = 0
    all_elems_count = 0

    curr_tqdm = tqdm(train_dl, total=len(train_dl), ascii=True)
    for src, tgt in curr_tqdm:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        # Teacher forcing: the decoder reads tokens [0, T-1) and is scored on tokens [1, T).
        tgt_input = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, pad_idx, DEVICE)

        # (B, T, vocab_size)
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        flattened_logits = logits.reshape(-1, logits.shape[-1])
        flattened_tgt = tgt_out.reshape(-1)

        optim.zero_grad()
        loss = loss_fn(flattened_logits, flattened_tgt)
        loss.backward()
        optim.step()

        batch_loss = loss.item()
        losses += batch_loss
        n_batches += 1

        # Accuracy bookkeeping needs no autograd graph, and argmax(logits) == argmax(softmax(logits)).
        with torch.no_grad():
            # Count only non-pad targets, mirroring CrossEntropyLoss(ignore_index=<pad>) used for loss_fn.
            counted = flattened_tgt != pad_idx
            preds = flattened_logits.argmax(dim=1)
            corrects_cnt = int((preds[counted] == flattened_tgt[counted]).sum().item())
            symbols_count = int(counted.sum().item())

        running_corrects += corrects_cnt
        all_elems_count += symbols_count

        batch_acc = 100.0 * corrects_cnt / symbols_count if symbols_count else 0.0
        curr_tqdm.set_postfix({"Loss": f"{batch_loss:.2f}",
                               "Corrects": f"{corrects_cnt}/{symbols_count}",
                               "Accuracy": f"{batch_acc:.3f}%"})

    epoch_loss_per_batch = losses / n_batches if n_batches else float("nan")
    epoch_acc_per_symbol = running_corrects / all_elems_count if all_elems_count else float("nan")

    return epoch_loss_per_batch, epoch_acc_per_symbol

In [115]:
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol):
    """
    Greedily decode one target sequence for a single source sequence.

    :param model: module exposing encode(src, src_mask), decode(ys, memory, tgt_mask) and ff (hidden -> logits)
    :param src: (1, S) source token ids (``.item()`` below requires a batch of one)
    :param src_mask: (S, S) encoder self-attention mask
    :param max_len: maximum length of the returned sequence, including ``start_symbol``
    :param start_symbol: id the output sequence starts with
    :param end_symbol: id that stops decoding once produced (it is kept in the output)
    :returns: LongTensor of shape (1, L) with L <= max_len, on DEVICE
    """
    # Pure inference: no autograd graph is needed for any decoding step.
    with torch.no_grad():
        src = src.to(DEVICE)
        src_mask = src_mask.to(DEVICE)

        # The encoder output is reused unchanged at every step, so move it to DEVICE once.
        memory = model.encode(src, src_mask).to(DEVICE)

        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)

        for _ in range(max_len - 1):
            # Boolean causal mask: True marks future positions that may not be attended.
            tgt_mask = generate_square_subsequent_mask(ys.size(1), DEVICE).type(torch.bool)

            out = model.decode(ys, memory, tgt_mask)

            # Logits for the last position only; take the most likely next token.
            logits = model.ff(out[:, -1])
            next_word = logits.argmax(dim=1).item()

            ys = torch.cat([ys, torch.full((1, 1), next_word, dtype=ys.dtype, device=DEVICE)], dim=1)
            if next_word == end_symbol:
                break

    return ys

In [116]:
def translate(src):
    """
    Translate a single tokenized source sentence with greedy decoding.

    Uses the notebook-level ``model``, ``special_symbols`` and ``bpe_model``.

    :param src: (1, S) tensor of source token ids
    :returns: the decoded translation as one string, with <PAD>/<BOS>/<EOS> ids removed
    """
    num_tokens = src.shape[1]

    # All-False boolean mask: no restriction on source self-attention.
    src_mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool)

    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5,
        start_symbol=special_symbols["<bos>"],
        end_symbol=special_symbols["<eos>"]
    ).flatten()

    # Take the special ids from special_symbols rather than hard-coding them.
    ignore_ids = [special_symbols["<pad>"], special_symbols["<bos>"], special_symbols["<eos>"]]

    # .tolist() hands the tokenizer plain Python ints instead of numpy scalars.
    output_list_words = bpe_model.decode([tgt_tokens.tolist()], ignore_ids=ignore_ids)

    return " ".join(output_list_words)

In [117]:
def validate(model, valid_dl, loss_fn, special_symbols):
    """
    Evaluate on the validation loader: teacher-forced loss plus greedy-decoding BLEU.

    :param model: Translator-like module (put into eval mode here)
    :param valid_dl: iterable of (src, tgt) LongTensor batches that also supports len()
    :param loss_fn: loss over flattened logits and targets
    :param special_symbols: dict with "<pad>", "<bos>" and "<eos>" token ids
    :returns: (mean loss per batch, dict of sacrebleu corpus-BLEU results keyed by tokenization/case setting)

    NOTE(review): hypotheses come from translate(), which decodes with the global ``model`` rather
    than the ``model`` argument; the training loop passes that same global.
    """
    model.eval()
    pad_idx = special_symbols["<pad>"]
    # Special ids stripped when decoding reference token ids back to text
    ignore_ids = [pad_idx, special_symbols["<bos>"], special_symbols["<eos>"]]

    losses = 0

    hypotheses = list()
    references = list()

    # Evaluation only: skip autograd for both the loss pass and greedy decoding.
    with torch.no_grad():
        for src, tgt in tqdm(valid_dl, total=len(valid_dl), ascii=True):
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            tgt_input = tgt[:, :-1]
            tgt_out = tgt[:, 1:]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input, pad_idx, DEVICE)

            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()

            hypotheses.append(translate(src))
            references.extend(bpe_model.decode(tgt.tolist(), ignore_ids=ignore_ids))

    print(hypotheses[:10])
    print(references[:10])

    sacrebleu_metrics = {
        "13a cased": sacrebleu.corpus_bleu(hypotheses, [references]),
        "13a caseless": sacrebleu.corpus_bleu(hypotheses, [references], lowercase=True),
        "intl cased": sacrebleu.corpus_bleu(hypotheses, [references], tokenize='intl'),
        "intl caseless": sacrebleu.corpus_bleu(hypotheses, [references], tokenize='intl', lowercase=True)
    }

    return losses / len(valid_dl), sacrebleu_metrics

In [112]:
def save_checkpoint(epoch, 
                    model, 
                    optim, 
                    best_val_loss, 
                    saved_epoch_losses, 
                    saved_bleu_metrics,
                    prefix=''):
    """
    Save training state to ``prefix + 'transformer_checkpoint.pth'``, overwriting any previous save.

    The state is first written to ``<filename>.tmp`` and then moved into place with ``os.replace``,
    so an interrupted save cannot leave a truncated checkpoint behind; on failure the temporary
    file is removed and the error is re-raised.

    :param epoch: epoch number as used by the training loop (which counts from start_epoch, 1 by default)
    :param model: transformer model whose state_dict is stored under 'model'
    :param optim: optimizer whose state_dict is stored under 'optimizer'
    :param best_val_loss: best validation loss seen so far
    :param saved_epoch_losses: per-phase loss history
    :param saved_bleu_metrics: per-phase BLEU history
    :param prefix: checkpoint filename prefix (may include a directory)
    """
    state = {'epoch': epoch,
             'model': model.state_dict(),
             'optimizer': optim.state_dict(),
             'best_val_loss': best_val_loss,
             'saved_epoch_losses': saved_epoch_losses,
             'saved_bleu_metrics': saved_bleu_metrics}

    filename = prefix + 'transformer_checkpoint.pth'
    tmp_filename = filename + '.tmp'

    try:
        torch.save(state, tmp_filename)
        os.replace(tmp_filename, filename)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-written temp file around.
        if os.path.exists(tmp_filename):
            os.remove(tmp_filename)
        raise

In [118]:
# Path of a checkpoint to resume from, or None to train from scratch, e.g.
#   logging_dir + "last_transformer_checkpoint.pth"  or  logging_dir + "best_transformer_checkpoint.pth"
checkpoint = None

model = Translator(
    num_encoder_layers=enc_layers,
    num_decoder_layers=dec_layers,
    embed_size=embed_size,
    num_heads=attn_heads,
    src_vocab_size=len(vocab),
    tgt_vocab_size=len(vocab),
    dim_feedforward=dim_feedforward,
    dropout=dropout
).to(DEVICE)

# These special values are from the "Attention is all you need" paper
optim = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    betas=betas,
    eps=eps
)

# Initialize model or load checkpoint
if checkpoint is not None:
    # NOTE: weights_only=True only unpickles an allow-list of types; a checkpoint whose
    # saved_bleu_metrics holds sacrebleu result objects (rather than floats) will be rejected.
    checkpoint = torch.load(checkpoint, weights_only=True, map_location=DEVICE)

    model.load_state_dict(checkpoint['model'])
    optim.load_state_dict(checkpoint['optimizer'])

    # The training loop saves after finishing an epoch, so resume at the next one.
    # (A checkpoint saved manually mid-epoch, e.g. after an interrupt, skips the rest of that epoch.)
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    saved_epoch_losses = checkpoint['saved_epoch_losses']
    saved_bleu_metrics = checkpoint['saved_bleu_metrics']

    print(f"\nLoaded checkpoint from epoch {checkpoint['epoch']}; resuming at epoch {start_epoch}.")

In [119]:
# Set up our learning tools
# Cross-entropy on raw logits; target positions equal to the <PAD> id contribute nothing to the loss
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=special_symbols["<pad>"]).to(DEVICE)

In [120]:
for epoch in range(start_epoch, epochs + 1):
    start_time = time.time()
    train_dl.create_batches()
    train_loss, train_acc_per_symbol = train(model, train_dl, loss_fn, optim, special_symbols)
    epoch_time = time.time() - start_time

    valid_dl.create_batches()
    val_loss, sacrebleu_metrics = validate(model, valid_dl, loss_fn, special_symbols)

    # Record this epoch's history *before* saving, so both checkpoints below include it.
    saved_epoch_losses['train'].append(train_loss)
    saved_epoch_losses['val'].append(val_loss)
    # Store plain float scores: sacrebleu result objects are not on torch.load(weights_only=True)'s
    # allow-list, so checkpoints containing them could not be reloaded.
    saved_bleu_metrics['val'].append({name: result.score for name, result in sacrebleu_metrics.items()})

    # Keep a separate copy of the model with the lowest validation loss so far
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        save_checkpoint(epoch, model, optim, best_val_loss, saved_epoch_losses, saved_bleu_metrics, prefix=(logging_dir + "best_"))

    # Always overwrite the most recent checkpoint
    save_checkpoint(epoch, model, optim, best_val_loss, saved_epoch_losses, saved_bleu_metrics, prefix=(logging_dir + "last_"))

    # Remaining epochs are epoch+1 .. epochs, whatever start_epoch was
    print(f"Epoch: {epoch}, "
          f"Train acc per symbol: {train_acc_per_symbol:.2f}, "
          f"Train loss: {train_loss:.3f}, "
          f"Val loss: {val_loss:.3f},\n"
          f"Epoch time = {epoch_time:.1f} seconds, "
          f"ETA = {epoch_time * (epochs - epoch):.1f} seconds")

    print("\n".join(list(map(lambda k: f"{k[0]}: {k[1]}", sacrebleu_metrics.items()))))

Creating batches


 13%|#3        | 7974/59476 [30:20<3:15:58,  4.38it/s, Loss=7.51, Corrects=70/1904, Accuracy=3.676%]  


KeyboardInterrupt: 

In [121]:
# Manually save the current state after interrupting the training loop above.
# `epoch` is the epoch that was in progress, so this "last_" checkpoint is a mid-epoch snapshot.
save_checkpoint(epoch, model, optim, best_val_loss, saved_epoch_losses, saved_bleu_metrics, prefix=(logging_dir + "last_"))