In [34]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

In [35]:
# Silence every Python warning for the rest of this session.
warnings.filterwarnings("ignore")
# Gate read by show_example / execute_example: set to False to skip the demo cells.
RUN_EXAMPLES = True

## Generic architecture

In [36]:
class EncoderDecoder(nn.Module):
    """
    Generic encoder-decoder wrapper that wires five injected sub-modules together.

    Args:
        encoder: called as ``encoder(embedded_src, src_mask)``.
        decoder: called as ``decoder(embedded_tgt, memory, src_mask, tgt_mask)``.
        src_embed: maps source token ids to embeddings.
        tgt_embed: maps target token ids to embeddings.
        generator: output head; stored here but NOT applied by ``forward``.
    """

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        # Keep this registration order: it determines the order of parameters(),
        # which make_model walks when initialising weights.
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src``, then decode ``tgt`` against it; returns decoder states."""
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed the source tokens and run them through the encoder stack."""
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed the target tokens and run the decoder over them and ``memory``."""
        embedded = self.tgt_embed(tgt)
        return self.decoder(embedded, memory, src_mask, tgt_mask)

In [37]:
class Generator(nn.Module):
    """
    Output head: linear projection from ``d_model`` to ``vocab`` logits, followed
    by a log-softmax over the last axis.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for every position of ``x``."""
        logits = self.proj(x)
        return log_softmax(logits, dim=-1)

## Encoder

In [38]:
def clones(module, N):
    """
    Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``.

    The copies share no parameters with each other or with ``module``.
    """
    copies = nn.ModuleList()
    for _ in range(N):
        copies.append(copy.deepcopy(module))
    return copies

In [39]:
class LayerNorm(nn.Module):
    """
    Layer normalisation over the last axis with learnable gain ``a_2`` and bias ``b_2``.

    Uses ``Tensor.std`` (the unbiased estimator) and adds ``eps`` to the standard
    deviation rather than to the variance.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then scale and shift."""
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2

In [40]:
class Encoder(nn.Module):
    """
    Stack of ``N`` deep-copied encoder layers followed by a final ``LayerNorm``.

    ``layer`` must expose ``.size`` (the feature width used for the final norm).
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` through every layer with the same ``mask``, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

In [41]:
class SublayerConnection(nn.Module):
    """
    Pre-norm residual wrapper: ``x + dropout(sublayer(norm(x)))``.

    The layer norm is applied to the sublayer's *input*, not after the residual
    addition. The sublayer must preserve the feature width ``size``.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalised input and add the result back to ``x``."""
        normalized = self.norm(x)
        return x + self.dropout(sublayer(normalized))

In [42]:
class EncoderLayer(nn.Module):
    """
    One encoder layer: self-attention then a feed-forward block, each wrapped in a
    ``SublayerConnection`` (pre-norm residual).

    ``size`` is the feature width; it is exposed so ``Encoder`` can size its final norm.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        # Registration order kept: it fixes parameters() order for initialisation.
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Run masked self-attention, then the feed-forward block."""
        attn_block, ffn_block = self.sublayer[0], self.sublayer[1]
        x = attn_block(x, lambda h: self.self_attn(h, h, h, mask))
        return ffn_block(x, self.feed_forward)

## Decoder

In [43]:
class Decoder(nn.Module):
    """
    Stack of ``N`` deep-copied decoder layers followed by a final ``LayerNorm``.

    Every layer receives the same encoder ``memory`` and both masks.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Feed ``x`` through every decoder layer, then normalise."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

In [44]:
class DecoderLayer(nn.Module):
    """
    One decoder layer:
    1. masked self-attention (``tgt_mask``);
    2. attention over the encoder ``memory`` (``src_mask``);
    3. a feed-forward block.

    Each step is wrapped in a pre-norm ``SublayerConnection``.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        # Registration order kept: it fixes parameters() order for initialisation.
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Self-attend over ``x``, attend to ``memory``, then apply the feed-forward block."""
        self_block, cross_block, ffn_block = self.sublayer
        x = self_block(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        x = cross_block(x, lambda h: self.src_attn(h, memory, memory, src_mask))
        return ffn_block(x, self.feed_forward)

In [45]:
def subsequent_mask(size):
    """
    Causal mask of shape ``(1, size, size)``, dtype bool.

    Entry ``[0, i, j]`` is True when ``j <= i`` (position i may see position j) and
    False for future positions ``j > i``.
    """
    strictly_upper = torch.triu(torch.ones(1, size, size), diagonal=1)
    return strictly_upper.eq(0)

In [46]:
def example_mask():
    """
    Build an Altair heat map of ``subsequent_mask(20)``.

    Each (Window=y, Masking=x) cell shows entry ``[x, y]`` of the 20x20 causal
    mask, which is True when ``y <= x``.
    """
    frames = []
    for y in range(20):
        for x in range(20):
            frames.append(
                pd.DataFrame({
                    "Subsequent Mask": subsequent_mask(20)[0][x, y].flatten(),
                    "Window": y,
                    "Masking": x,
                })
            )
    mask_df = pd.concat(frames)

    heatmap = alt.Chart(mask_df).mark_rect().properties(height=250, width=250)
    heatmap = heatmap.encode(
        alt.X("Window:O"),
        alt.Y("Masking:O"),
        alt.Color("Subsequent Mask:Q", scale=alt.Scale(scheme="viridis")),
    )
    return heatmap.interactive()

def show_example(fn, args=[]):
    """
    Return ``fn(*args)`` when running as ``__main__`` with ``RUN_EXAMPLES`` truthy;
    otherwise return ``None`` without calling ``fn``.

    ``args`` is only unpacked, never mutated, so the list default is not shared state.
    """
    if not (__name__ == "__main__" and RUN_EXAMPLES):
        return None
    return fn(*args)
    
# Render the causal-mask heat map (only when run as __main__ with RUN_EXAMPLES on).
show_example(example_mask)

## Attention

In [47]:
def attention(query, key, value, mask=None, dropout=None):
    """
    Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query, key, value: tensors whose last two axes are (positions, features);
            ``d_k`` is the last axis of ``query``.
        mask: optional tensor broadcastable to the score matrix; positions where
            ``mask == 0`` get a score of -1e9 before the softmax.
        dropout: optional module applied to the attention weights.

    Returns:
        ``(weighted_values, attention_weights)``.
    """
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        logits = logits.masked_fill(mask == 0, -1e9)
    weights = logits.softmax(dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights

In [48]:
class MultiHeadedAttention(nn.Module):
    """
    Multi-head attention with ``h`` heads of width ``d_k = d_model // h``
    (``d_v == d_k``).

    ``linears[0..2]`` project query/key/value, and ``linears[3]`` projects the
    concatenated heads back to ``d_model``. The most recent attention weights are
    kept in ``self.attn`` (used by the visualisation helpers).
    """

    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """
        Attend from ``query`` to ``key``/``value``.

        Inputs' last dimension must be ``d_model``; the first is the batch axis.
        ``mask`` gets a head axis inserted at dim 1 so it applies to every head.
        """
        if mask is not None:
            mask = mask.unsqueeze(1)
        batch_size = query.size(0)

        def split_heads(projection, tensor):
            # (batch, seq, d_model) -> (batch, h, seq, d_k)
            return projection(tensor).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)

        q, k, v = (
            split_heads(proj, tensor)
            for proj, tensor in zip(self.linears, (query, key, value))
        )
        context, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        # (batch, h, seq, d_k) -> (batch, seq, h * d_k)
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        return self.linears[-1](merged)

In [49]:
class PositionWiseFeedForward(nn.Module):
    """
    Position-wise FFN: ``w_2(dropout(relu(w_1(x))))``.

    Expands ``d_model -> d_ff`` and projects back to ``d_model``.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the two-layer ReLU network independently at every position."""
        hidden = torch.relu(self.w_1(x))
        return self.w_2(self.dropout(hidden))

## Embeddings

In [50]:
class Embeddings(nn.Module):
    """
    Learned token embeddings scaled by ``sqrt(d_model)``.

    The lookup table is ``self.lut`` (``vocab`` rows of width ``d_model``).
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        """Look up the ids in ``x`` and multiply by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

## Positional Encoding

In [51]:
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encodings added to the input, followed by dropout.

    The table is precomputed for ``max_len`` positions and registered as buffer
    ``pe`` with shape ``(1, max_len, d_model)``: even feature columns hold
    ``sin(pos * w_k)``, odd columns ``cos(pos * w_k)`` with
    ``w_k = 10000 ** (-2k / d_model)``.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len).unsqueeze(1)
        frequencies = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        angles = positions * frequencies
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions to ``x``."""
        seq_len = x.size(1)
        encoding = self.pe[:, :seq_len].requires_grad_(False)
        return self.dropout(x + encoding)

## Model

In [52]:
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """
    Construct an encoder-decoder Transformer from hyperparameters.

    Args:
        src_vocab / tgt_vocab: source / target vocabulary sizes.
        N: number of encoder layers and of decoder layers.
        d_model: model (embedding) width.
        d_ff: hidden width of the position-wise feed-forward blocks.
        h: number of attention heads (must divide ``d_model``).
        dropout: dropout rate used throughout, including on attention weights.

    Returns:
        an ``EncoderDecoder`` whose parameters with more than one dimension are
        Xavier-uniform initialised.
    """
    c = copy.deepcopy
    # Forward ``dropout`` so the attention modules honour the caller's rate instead
    # of always using MultiHeadedAttention's own default (0.1).
    attn = MultiHeadedAttention(h, d_model, dropout)
    ff = PositionWiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = EncoderDecoder(
        Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N),
        Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, src_vocab), c(position)),
        nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)),
        Generator(d_model, tgt_vocab)
    )

    # Xavier-init matrices/embedding tables; 1-D tensors (biases, LayerNorm
    # gains and biases) keep their constructor defaults.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

In [53]:
class Batch:
    """
    One training batch plus the masks the model needs.

    For 2-D ``(batch, seq)`` token tensors:
    - ``src_mask`` is ``(batch, 1, seq)``, True at non-padding source positions.
    - When ``tgt`` is given, it is split into the decoder input ``tgt`` (all but the
      last token) and the prediction target ``tgt_y`` (all but the first token).
    - ``tgt_mask`` combines padding and causal masking.
    - ``ntokens`` counts non-padding tokens in ``tgt_y``.

    ``pad`` defaults to 2, the index of ``<blank>``.
    """

    def __init__(self, src, tgt=None, pad=2):  # 2 = <blank>
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is None:
            return
        self.tgt = tgt[:, :-1]
        self.tgt_y = tgt[:, 1:]
        self.tgt_mask = self.make_std_mask(self.tgt, pad)
        self.ntokens = (self.tgt_y != pad).data.sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """Mask that hides both padding tokens and future positions of ``tgt``."""
        not_padding = (tgt != pad).unsqueeze(-2)
        causal = subsequent_mask(tgt.size(-1)).type_as(not_padding.data)
        return not_padding & causal

In [54]:
class TrainState:
    """
    Running training counters shared across ``run_epoch`` calls.

    Defaults are class attributes; ``run_epoch`` increments them on the instance
    in its training modes. ``train_worker`` reuses one instance for every epoch,
    so all four counters are cumulative over the whole run, not per epoch.
    """

    # Training batches processed (incremented once per batch).
    step: int = 0
    # Optimizer updates performed (once per gradient-accumulation group).
    accum_step: int = 0
    # Sequences processed (sum of batch.src.shape[0]).
    samples: int = 0
    # Tokens processed (sum of batch.ntokens).
    tokens: int = 0

In [55]:
def run_epoch(data_iter, model, loss_compute, optimizer, scheduler,
                mode="train", accum_iter=1, train_state=None):
    """
    Run one pass over ``data_iter``, either training or only evaluating.

    Args:
        data_iter: iterable of batches exposing ``src``, ``tgt``, ``src_mask``,
            ``tgt_mask``, ``tgt_y`` and ``ntokens`` (e.g. ``Batch`` objects).
        model: called as ``model.forward(src, tgt, src_mask, tgt_mask)``.
        loss_compute: ``(out, tgt_y, ntokens) -> (loss_value, loss_node)``;
            ``loss_node.backward()`` is called in training modes.
        optimizer / scheduler: only stepped in training modes.
        mode: "train" or "train+log" trains (and logs every 40 batches); any
            other value just accumulates the loss.
        accum_iter: the optimizer steps when ``i % accum_iter == 0`` (batch
            indices 0, accum_iter, 2*accum_iter, ...), so the first update uses a
            single batch.
        train_state: counters updated in training modes. When omitted, a fresh
            ``TrainState`` is created for this call rather than reusing one
            default instance shared by every call.

    Returns:
        ``(total_loss / total_tokens, train_state)``.
    """
    if train_state is None:
        train_state = TrainState()
    is_training = mode in ("train", "train+log")
    # perf_counter is monotonic and high-resolution, which is what interval timing needs.
    start = time.perf_counter()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    n_accum = 0
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
        loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)
        if is_training:
            loss_node.backward()
            train_state.step += 1
            train_state.samples += batch.src.shape[0]
            train_state.tokens += batch.ntokens
            if i % accum_iter == 0:
                # NOTE(review): gradients of a trailing partial accumulation group
                # are neither applied nor zeroed here and carry into the next call.
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                n_accum += 1
                train_state.accum_step += 1
            # The LR schedule advances per batch, not per optimizer update.
            scheduler.step()

        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 40 == 1 and is_training:
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.perf_counter() - start
            print((
                "Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.2f "
                + "| Tokens / Sec: %7.1f | Learning Rate: %6.1e"
            ) % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr))
            start = time.perf_counter()
            tokens = 0
        del loss
        del loss_node
    return total_loss / total_tokens, train_state

In [56]:
def rate(step, model_size, factor, warmup):
    """
    Noam learning-rate multiplier.

    Computes ``factor * model_size^-0.5 * min(step^-0.5, step * warmup^-1.5)``:
    the rate grows linearly for ``warmup`` steps, then decays as ``step^-0.5``.
    Step 0 is treated as step 1 because ``0 ** -0.5`` raises ZeroDivisionError.
    """
    step = step if step != 0 else 1
    decay_term = step ** (-0.5)
    warmup_term = step * warmup ** (-1.5)
    return factor * (model_size ** (-0.5) * min(decay_term, warmup_term))

## Regularization

In [57]:
class LabelSmoothing(nn.Module):
    """
    Label-smoothed KL-divergence loss (summed over all elements).

    The target distribution per row is built as follows:
    - the true class gets ``1 - smoothing``;
    - every other class gets ``smoothing / (size - 2)``, excluding the true class
      and the padding class;
    - the padding class always gets 0;
    - rows whose target is the padding index are zeroed entirely.

    Args:
        size: number of classes (must equal ``x.size(1)``).
        padding_idx: class index treated as padding.
        smoothing: probability mass spread over non-target classes.
    """

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        """
        Args:
            x: ``(N, size)`` log-probabilities (the KLDivLoss input convention).
            target: ``(N,)`` integer class indices.

        Returns:
            summed KL divergence between the smoothed targets and ``x``. The
            smoothed target tensor is kept in ``self.true_dist``.
        """
        assert x.size(1) == self.size
        true_dist = torch.full_like(x.detach(), self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.detach().unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero whole rows whose target is padding; a boolean mask handles zero, one
        # or many padded rows without the nonzero()/squeeze() index dance.
        true_dist.masked_fill_((target == self.padding_idx).unsqueeze(1), 0.0)
        self.true_dist = true_dist
        # Pass a copy so the stored tensor is not the one saved by autograd.
        return self.criterion(x, true_dist.clone())

## Multi30k-transformer specific code

In [58]:
def load_tokenizers():
    """
    Load the German and English spaCy pipelines, downloading each on first use.

    Returns:
        ``(spacy_de, spacy_en)``.

    Raises:
        IOError/OSError from ``spacy.load`` if a model is still unavailable after
        the download attempt.
    """
    import subprocess
    import sys

    def _load(model_name):
        """Load ``model_name``; if spaCy cannot find it, download it and retry once."""
        try:
            return spacy.load(model_name)
        except IOError:
            # Download with the interpreter running this code: a bare "python" on
            # PATH may belong to another environment than the notebook kernel.
            # The argument list (no shell) avoids any quoting issues.
            subprocess.run([sys.executable, "-m", "spacy", "download", model_name])
            return spacy.load(model_name)

    spacy_de = _load("de_core_news_sm")
    spacy_en = _load("en_core_web_sm")
    return spacy_de, spacy_en
        
def tokenize(text, tokenizer):
    """
    Split ``text`` with the tokenizer of a spaCy pipeline and return the token strings.

    ``tokenizer`` must expose ``.tokenizer(text)``, which yields objects with a
    ``.text`` attribute.
    """
    tokens = tokenizer.tokenizer(text)
    return list(map(lambda tok: tok.text, tokens))

def yield_tokens(data_iter, tokenizer, index):
    """Lazily yield ``tokenizer(item[index])`` for each tuple in ``data_iter``."""
    yield from (tokenizer(pair[index]) for pair in data_iter)
        
def build_vocabulary(spacy_de, spacy_en):
    """
    Build the German (source) and English (target) torchtext vocabularies from
    all Multi30k splits.

    - Tokens seen fewer than 2 times are dropped.
    - Specials ``<s>, </s>, <blank>, <unk>`` get indices 0-3.
    - Unknown tokens map to ``<unk>``.

    Returns:
        ``(vocab_src, vocab_tgt)``.
    """
    def tokenize_de(text):
        """Tokenize German text with the spaCy German pipeline."""
        return tokenize(text, spacy_de)

    def tokenize_en(text):
        """Tokenize English text with the spaCy English pipeline."""
        return tokenize(text, spacy_en)

    def vocab_for(side_index, tokenizer_fn):
        # The splits are reloaded per language, presumably because they can only
        # be iterated once.
        train, val, test = datasets.Multi30k(language_pair=("de", "en"))
        return build_vocab_from_iterator(
            yield_tokens(train + val + test, tokenizer_fn, index=side_index),
            min_freq=2,
            specials=["<s>", "</s>", "<blank>", "<unk>"],
        )

    print("Building German Vocabulary ...")
    vocab_src = vocab_for(0, tokenize_de)

    print("Building English Vocabulary ...")
    vocab_tgt = vocab_for(1, tokenize_en)

    for vocab in (vocab_src, vocab_tgt):
        vocab.set_default_index(vocab["<unk>"])

    return vocab_src, vocab_tgt

In [59]:
def load_vocab(spacy_de, spacy_en):
    """
    Return ``(vocab_src, vocab_tgt)``, cached in ``vocab.pt`` in the working directory.

    The cache is built with ``build_vocabulary`` on first use. It is read with
    ``torch.load`` (pickle-based), so only load a ``vocab.pt`` you trust.
    """
    cache_file = "vocab.pt"
    if exists(cache_file):
        vocab_src, vocab_tgt = torch.load(cache_file)
    else:
        vocab_src, vocab_tgt = build_vocabulary(spacy_de, spacy_en)
        torch.save((vocab_src, vocab_tgt), cache_file)
    print("Finished.\nVocabulary sizes:")
    for vocab in (vocab_src, vocab_tgt):
        print(len(vocab))
    return vocab_src, vocab_tgt

In [60]:
if __name__ == "__main__":
    # Tokenizers and vocabularies are needed by all later cells (training, loading,
    # examples), so load them unconditionally. Going through show_example would
    # return None when RUN_EXAMPLES is False and break this tuple unpacking.
    spacy_de, spacy_en = load_tokenizers()
    vocab_src, vocab_tgt = load_vocab(spacy_de, spacy_en)

Finished.
Vocabulary sizes:
8315
6384


In [61]:
def collate_batch(batch, src_pipeline, tgt_pipeline, src_vocab, tgt_vocab, device, max_padding=128, pad_id=2,):
    """
    Turn ``(src_text, tgt_text)`` pairs into two ``(len(batch), max_padding)`` id tensors.

    For each side, the text is:
    1. tokenised with its pipeline;
    2. mapped to ids with its vocab;
    3. wrapped in ``<s>`` (id 0) / ``</s>`` (id 1);
    4. right-padded with ``pad_id`` to ``max_padding``.

    Sequences longer than ``max_padding`` get a negative pad width, which
    truncates them from the right (so the closing ``</s>`` is lost).

    Returns:
        ``(src, tgt)`` int64 tensors on ``device``.
    """
    start_marker = torch.tensor([0], device=device)  # <s> token id
    end_marker = torch.tensor([1], device=device)  # </s> token id

    def encode(text, pipeline, vocab):
        token_ids = torch.tensor(vocab(pipeline(text)), dtype=torch.int64, device=device)
        framed = torch.cat([start_marker, token_ids, end_marker], 0)
        return pad(framed, (0, max_padding - len(framed)), value=pad_id)

    src_rows = []
    tgt_rows = []
    for raw_src, raw_tgt in batch:
        src_rows.append(encode(raw_src, src_pipeline, src_vocab))
        tgt_rows.append(encode(raw_tgt, tgt_pipeline, tgt_vocab))
    return (torch.stack(src_rows), torch.stack(tgt_rows))

In [62]:
def to_map_style_dataset(iter_data):
    """
    Materialise an iterable into an indexable ``torch.utils.data.Dataset``.

    ``iter_data`` is consumed eagerly into a list, which gives the dataset a
    length and random access (as required by samplers and shuffling DataLoaders).
    """
    class _MapStyleDataset(torch.utils.data.Dataset):

        def __init__(self, iter_data):
            self._data = [sample for sample in iter_data]

        def __len__(self):
            return len(self._data)

        def __getitem__(self, idx):
            return self._data[idx]

    return _MapStyleDataset(iter_data)

In [63]:
def create_dataloaders(device, vocab_src, vocab_tgt, spacy_de, spacy_en,
                       batch_size=12000, max_padding=128, is_distributed=True):
    """
    Build the Multi30k (de -> en) train and validation DataLoaders.

    Args:
        device: device the collated id tensors are created on.
        vocab_src / vocab_tgt: torchtext vocabularies.
        spacy_de / spacy_en: spaCy pipelines for tokenisation.
        batch_size: number of sentence pairs per batch.
        max_padding: fixed padded length of every sequence.
        is_distributed: shard both splits with ``DistributedSampler``; otherwise
            both loaders shuffle.

    Returns:
        ``(train_dataloader, valid_dataloader)``.
    """
    # get_stoi() returns the whole token->index dict; look the pad id up once
    # instead of on every collate call. Assumes "<blank>" has the same index in
    # vocab_tgt (true when both vocabs use the same specials list).
    pad_id = vocab_src.get_stoi()["<blank>"]

    def tokenize_de(text):
        return tokenize(text, spacy_de)

    def tokenize_en(text):
        return tokenize(text, spacy_en)

    def collate_fn(batch):
        return collate_batch(batch, tokenize_de, tokenize_en, vocab_src, vocab_tgt,
                             device, max_padding=max_padding, pad_id=pad_id,
        )

    train_iter, valid_iter, test_iter = datasets.Multi30k(language_pair=("de", "en"))

    train_iter_map = to_map_style_dataset(train_iter)
    train_sampler = (DistributedSampler(train_iter_map) if is_distributed else None)
    valid_iter_map = to_map_style_dataset(valid_iter)
    valid_sampler = (DistributedSampler(valid_iter_map) if is_distributed else None)

    train_dataloader = DataLoader(
        train_iter_map,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        collate_fn=collate_fn,
    )
    valid_dataloader = DataLoader(
        valid_iter_map,
        batch_size=batch_size,
        shuffle=(valid_sampler is None),
        sampler=valid_sampler,
        collate_fn=collate_fn,
    )
    return train_dataloader, valid_dataloader

In [64]:
class DummyOptimizer(torch.optim.Optimizer):
    """
    No-op optimizer for evaluation passes of ``run_epoch``.

    ``Optimizer.__init__`` is deliberately not called; the object only provides
    ``param_groups[0]["lr"]`` plus do-nothing ``step`` and ``zero_grad``.
    """

    def __init__(self):
        self.param_groups = [{"lr": 0}]

    def step(self):
        """Do nothing."""

    def zero_grad(self, set_to_none=False):
        """Do nothing."""

class DummyScheduler:
    """No-op learning-rate scheduler for evaluation passes of ``run_epoch``."""

    def step(self):
        """Do nothing."""

class SimpleLossCompute:
    """
    Apply the generator, flatten, and compute a normalised loss.

    ``__call__(x, y, norm)`` returns ``(loss_value, loss_node)``:
    - ``loss_node`` is ``criterion(...) / norm``, still attached to the graph, for
      ``backward()``;
    - ``loss_value`` is ``loss_node.data * norm``, the un-normalised loss without
      autograd history.
    """

    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        scores = self.generator(x)
        flat_scores = scores.contiguous().view(-1, scores.size(-1))
        flat_targets = y.contiguous().view(-1)
        normalized = self.criterion(flat_scores, flat_targets) / norm
        return normalized.data * norm, normalized

In [67]:
def train_worker(gpu, ngpus_per_node, vocab_src, vocab_tgt, spacy_de, spacy_en, config, is_distributed=False):
    """
    Train the Multi30k model on one GPU (one process per GPU when distributed).

    Args:
        gpu: CUDA device index; also the process rank when distributed.
        ngpus_per_node: world size of the process group.
        vocab_src / vocab_tgt: torchtext vocabularies.
        spacy_de / spacy_en: spaCy pipelines used for tokenisation.
        config: dict with "batch_size", "max_padding", "base_lr", "warmup",
            "num_epochs", "accum_iter", "file_prefix" and optionally
            "dist_backend" (default "nccl" when available, else "gloo").
        is_distributed: wrap the model in DDP and shard data with DistributedSampler.

    Side effects:
        the main process saves ``<file_prefix>NN.pt`` after every epoch and
        ``<file_prefix>final.pt`` at the end.
    """
    print(f"Train worker process using: {gpu} for training", flush=True)
    torch.cuda.set_device(gpu)
    pad_idx = vocab_tgt["<blank>"]
    d_model = 512
    # Pass d_model explicitly so the model width and the LR schedule cannot drift apart.
    model = make_model(len(vocab_src), len(vocab_tgt), N=6, d_model=d_model)
    model.cuda(gpu)
    module = model
    is_main_process = True
    if is_distributed:
        # mp.spawn launches one process per GPU with env:// rendezvous
        # (MASTER_ADDR/MASTER_PORT), which NCCL and Gloo support. The MPI backend
        # instead takes its ranks from an mpirun launch and exists only in PyTorch
        # builds compiled with MPI, so it is opt-in via config["dist_backend"].
        backend = config.get("dist_backend", "nccl" if dist.is_nccl_available() else "gloo")
        dist.init_process_group(backend, init_method="env://", rank=gpu, world_size=ngpus_per_node)
        model = DDP(model, device_ids=[gpu])
        module = model.module
        is_main_process = gpu == 0

    criterion = LabelSmoothing(size=len(vocab_tgt), padding_idx=pad_idx, smoothing=0.1)
    criterion.cuda(gpu)

    train_dataloader, valid_dataloader = create_dataloaders(gpu, vocab_src, vocab_tgt, spacy_de, spacy_en,
                                                            batch_size=config["batch_size"], max_padding=config["max_padding"],
                                                            is_distributed=is_distributed)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["base_lr"], betas=(0.9, 0.98), eps=1e-9)
    lr_scheduler = LambdaLR(optimizer=optimizer, lr_lambda=lambda step: rate(step, d_model, factor=1, warmup=config["warmup"]))
    # One TrainState for the whole run, so its counters accumulate across epochs.
    train_state = TrainState()

    for epoch in range(config["num_epochs"]):
        if is_distributed:
            train_dataloader.sampler.set_epoch(epoch)
            valid_dataloader.sampler.set_epoch(epoch)

        model.train()
        print(f"[GPU{gpu}] Epoch {epoch} Training ====", flush=True)
        _, train_state = run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in train_dataloader),
            model,
            SimpleLossCompute(module.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train+log",
            accum_iter=config["accum_iter"],
            train_state=train_state
        )

        GPUtil.showUtilization()
        if is_main_process:
            file_path = "%s%.2d.pt" % (config["file_prefix"], epoch)
            torch.save(module.state_dict(), file_path)
        torch.cuda.empty_cache()

        print(f"[GPU{gpu}] Epoch {epoch} Validation ====", flush=True)
        model.eval()
        # Score the held-out split without building autograd graphs; run_epoch
        # never calls backward() in "eval" mode.
        with torch.no_grad():
            valid_loss, _ = run_epoch(
                (Batch(b[0], b[1], pad_idx) for b in valid_dataloader),
                model,
                SimpleLossCompute(module.generator, criterion),
                DummyOptimizer(),
                DummyScheduler(),
                mode="eval"
            )
        print(valid_loss)
        torch.cuda.empty_cache()

    if is_main_process:
        file_path = "%sfinal.pt" % config["file_prefix"]
        torch.save(module.state_dict(), file_path)
    if is_distributed:
        dist.destroy_process_group()

In [68]:
def train_distributed_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config):
    """
    Spawn one ``train_worker`` process per visible CUDA device.

    Sets MASTER_ADDR/MASTER_PORT for the env:// rendezvous used by the workers.
    """
    ngpus = torch.cuda.device_count()
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12356"})
    print(f"Number of GPUs detected: {ngpus}")
    print("Spawning training processes ...")
    worker_args = (ngpus, vocab_src, vocab_tgt, spacy_de, spacy_en, config, True)
    mp.spawn(train_worker, nprocs=ngpus, args=worker_args)
    
def train_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config):
    """
    Train on all GPUs when ``config["distributed"]`` is truthy; otherwise train
    in this process on GPU 0.
    """
    if not config["distributed"]:
        train_worker(0, 1, vocab_src, vocab_tgt, spacy_de, spacy_en, config, False)
        return
    train_distributed_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config)
        
def load_trained_model():
    """
    Return the trained Multi30k model, training it first if no final checkpoint exists.

    Reads the module-level ``vocab_src``, ``vocab_tgt``, ``spacy_de`` and
    ``spacy_en`` set up in the ``__main__`` cell.
    """
    config = {
        "batch_size": 4,
        "distributed": False,
        "num_epochs": 8,
        "accum_iter": 10,
        "base_lr": 1.0,
        "max_padding": 72,
        "warmup": 1000,
        "file_prefix": "multi30k_model_"
    }
    # Same naming rule train_worker uses for the final checkpoint, so the path
    # checked here and the path loaded below cannot disagree.
    model_path = "%sfinal.pt" % config["file_prefix"]
    if not exists(model_path):
        train_model(vocab_src, vocab_tgt, spacy_de, spacy_en, config)

    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    # The freshly built model lives on the CPU; map the checkpoint there too so
    # loading also works on machines without the GPU it was saved from.
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    return model

if __name__ == "__main__":
    # Trains first when the final checkpoint is missing, otherwise just loads it.
    model = load_trained_model()

Train worker process using: 0 for training
[GPU0] Epoch 0 Training ====
Epoch Step:      1 | Accumulation Step:   1 | Loss:   7.58 | Tokens / Sec:   480.0 | Learning Rate: 2.8e-06
Epoch Step:     41 | Accumulation Step:   5 | Loss:   6.84 | Tokens / Sec:   896.4 | Learning Rate: 5.9e-05
Epoch Step:     81 | Accumulation Step:   9 | Loss:   6.43 | Tokens / Sec:   881.7 | Learning Rate: 1.1e-04
Epoch Step:    121 | Accumulation Step:  13 | Loss:   6.15 | Tokens / Sec:   729.8 | Learning Rate: 1.7e-04
Epoch Step:    161 | Accumulation Step:  17 | Loss:   5.90 | Tokens / Sec:   781.6 | Learning Rate: 2.3e-04
Epoch Step:    201 | Accumulation Step:  21 | Loss:   5.48 | Tokens / Sec:   812.7 | Learning Rate: 2.8e-04
Epoch Step:    241 | Accumulation Step:  25 | Loss:   4.78 | Tokens / Sec:   772.7 | Learning Rate: 3.4e-04
Epoch Step:    281 | Accumulation Step:  29 | Loss:   4.93 | Tokens / Sec:   809.1 | Learning Rate: 3.9e-04
Epoch Step:    321 | Accumulation Step:  33 | Loss:   4.75 | Tok

In [69]:
def execute_example(fn, args=[]):
    """
    Call ``fn(*args)`` only when running as ``__main__`` with ``RUN_EXAMPLES``
    truthy. The result is discarded and the function always returns ``None``.
    """
    if __name__ != "__main__" or not RUN_EXAMPLES:
        return
    fn(*args)

def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """
    Greedy (arg-max) decoding.

    Encodes ``src`` once, then appends the highest-scoring next token
    ``max_len - 1`` times. It does not stop early at an end-of-sentence token.
    Only batch element 0's prediction is used each step, so ``src`` is expected
    to hold a single sentence (TODO confirm for any batched caller).

    Returns:
        a ``(1, max_len)`` tensor (for ``max_len >= 1``) with ``src``'s dtype,
        starting with ``start_symbol``.
    """
    memory = model.encode(src, src_mask)
    ys = torch.zeros(1, 1).fill_(start_symbol).type_as(src.data)
    for _ in range(max_len - 1):
        causal_mask = subsequent_mask(ys.size(1)).type_as(src.data)
        decoded = model.decode(memory, src_mask, ys, causal_mask)
        scores = model.generator(decoded[:, -1])
        _, best = torch.max(scores, dim=1)
        next_word = best.data[0]
        next_column = torch.zeros(1, 1).type_as(src.data).fill_(next_word)
        ys = torch.cat([ys, next_column], dim=1)
    return ys

def check_outputs(valid_dataloader, model, vocab_src, vocab_tgt, n_examples=15, pad_idx=2, eos_string="</s>"):
    """
    Print source, reference and greedy model translation for ``n_examples`` examples.

    Each example is the first sentence of the first batch of a fresh iterator over
    ``valid_dataloader``, so examples differ only if the loader shuffles.

    Returns:
        list of ``(batch, src_tokens, tgt_tokens, model_out, model_txt)`` tuples.
    """
    # get_itos() returns the full index->token list; fetch it once per vocab
    # instead of once per token.
    src_itos = vocab_src.get_itos()
    tgt_itos = vocab_tgt.get_itos()
    results = [()] * n_examples
    for idx in range(n_examples):
        print("\nExample %d =======\n" % idx)
        b = next(iter(valid_dataloader))
        rb = Batch(b[0], b[1], pad_idx)
        src_tokens = [src_itos[x] for x in rb.src[0] if x != pad_idx]
        tgt_tokens = [tgt_itos[x] for x in rb.tgt[0] if x != pad_idx]
        print("Source Text (Input)        :" + " ".join(src_tokens).replace("\n", ""))
        print("Target Text (Ground Truth) :" + " ".join(tgt_tokens).replace("\n", ""))
        # Decode once (72 tokens) without recording an autograd graph.
        with torch.no_grad():
            model_out = greedy_decode(model, rb.src, rb.src_mask, 72, 0)[0]
        model_txt = (" ".join([tgt_itos[x] for x in model_out if x != pad_idx]).split(eos_string, 1)[0] + eos_string)
        print("Model Output               :" + model_txt.replace("\n", ""))
        results[idx] = (rb, src_tokens, tgt_tokens, model_out, model_txt)
    return results

def run_model_example(n_examples=5):
    """
    Load the final checkpoint on the CPU and print ``n_examples`` validation translations.

    Reads the module-level ``vocab_src``, ``vocab_tgt``, ``spacy_de`` and ``spacy_en``.

    Returns:
        ``(model, example_data)``, where ``example_data`` is ``check_outputs``'s list.
    """
    print("Preparing Data ...")
    _, valid_dataloader = create_dataloaders(torch.device("cpu"), vocab_src, vocab_tgt,
                                             spacy_de, spacy_en, batch_size=1, is_distributed=False)

    print("Loading Trained model ...")
    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    model.load_state_dict(torch.load("multi30k_model_final.pt", map_location=torch.device("cpu")))
    # New nn.Modules start in training mode; switch to eval so dropout is off and
    # greedy decoding is deterministic.
    model.eval()
    print("Checking Model Outputs:")
    example_data = check_outputs(valid_dataloader, model, vocab_src, vocab_tgt, n_examples=n_examples)
    return model, example_data

# Print a few validation translations (only as __main__ with RUN_EXAMPLES on).
execute_example(run_model_example)

Preparing Data ...
Loading Trained model ...
Checking Model Outputs:


Source Text (Input)        :<s> Ein Mann klettert mit <unk> <unk> eine steile Felswand hinauf </s>
Target Text (Ground Truth) :<s> A man is climbing a steep rock wall with a safety harness on </s>
Model Output               :<s> A man is climbing a rock face with his face . </s>


Source Text (Input)        :<s> Ein Mann mit einem leuchtend bunten Helm sitzt auf einem Motorrad . </s>
Target Text (Ground Truth) :<s> A man wearing a bright , multi - color helmet is sitting on a motorcycle . </s>
Model Output               :<s> A man in a bright red helmet is sitting on a motorcycle . </s>


Source Text (Input)        :<s> Zwei Menschen stehen vor einem Gebäude . </s>
Target Text (Ground Truth) :<s> Two people stand in front of a building . </s>
Model Output               :<s> Two people stand in front of a building . </s>


Source Text (Input)        :<s> Fotografen machen eine <unk> an einem Veranstaltungsort . </s>


## Attention visualization

In [74]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """
    Flatten the top-left ``max_row x max_col`` corner of 2-D matrix ``m`` into a
    long DataFrame with columns row, column, value, row_token, col_token.

    Token labels are ``"%.3d <token>"``; indices past the end of a token list are
    labelled ``<blank>``.
    """
    def label(index, tokens):
        token = tokens[index] if len(tokens) > index else "<blank>"
        return "%.3d %s" % (index, token)

    records = []
    for r in range(m.shape[0]):
        if r >= max_row:
            continue
        for c in range(m.shape[1]):
            if c >= max_col:
                continue
            records.append((r, c, float(m[r, c]), label(r, row_tokens), label(c, col_tokens)))
    return pd.DataFrame(records, columns=["row", "column", "value", "row_token", "col_token"])

def attn_map(attn, layer, head, row_tokens, col_tokens, max_dim=30):
    """
    Heat map of one head's attention weights ``attn[0, head]``.

    At most ``max_dim`` rows and columns are drawn. ``layer`` is accepted for call
    symmetry but not used.
    """
    head_weights = attn[0, head].data
    df = mtx2df(head_weights, max_dim, max_dim, row_tokens, col_tokens)
    encoding = dict(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    chart = alt.Chart(data=df).mark_rect().encode(**encoding)
    return chart.properties(height=400, width=400).interactive()

In [75]:
def get_encoder(model, layer):
    """Last attention weights stored by encoder layer ``layer``'s self-attention."""
    encoder_layer = model.encoder.layers[layer]
    return encoder_layer.self_attn.attn

def get_decoder_self(model, layer):
    """Last attention weights stored by decoder layer ``layer``'s self-attention."""
    decoder_layer = model.decoder.layers[layer]
    return decoder_layer.self_attn.attn

def get_decoder_src(model, layer):
    """Last attention weights stored by decoder layer ``layer``'s source attention."""
    decoder_layer = model.decoder.layers[layer]
    return decoder_layer.src_attn.attn

def visualize_layer(model, layer, getter_fn, ntokens, row_tokens, col_tokens):
    """
    Chart the even-numbered attention heads (0, 2, 4, ...) of one layer in a row.

    Args:
        model: passed to ``getter_fn``.
        layer: 0-based layer index (shown 1-based in the title).
        getter_fn: ``(model, layer) -> attn``. ``attn.shape[1]`` is taken as the
            head count and ``attn[0, head]`` as one head's weight matrix.
        ntokens: number of rows/columns drawn per head.
        row_tokens / col_tokens: axis labels.

    Returns:
        an Altair chart titled ``"Layer <layer + 1>"``. For the usual 8 heads this
        shows heads 0, 2, 4 and 6; other head counts are also accepted.
    """
    attn = getter_fn(model, layer)
    n_heads = attn.shape[1]
    # Only the even heads are displayed, so only build their charts.
    charts = [
        attn_map(attn, 0, h, row_tokens=row_tokens, col_tokens=col_tokens, max_dim=ntokens)
        for h in range(0, n_heads, 2)
    ]
    row = charts[0]
    for chart in charts[1:]:
        row = row | chart
    return alt.vconcat(row).properties(title="Layer %d" % (layer + 1))

In [76]:
def viz_encoder_self():
    """
    Decode one validation example and chart encoder self-attention for layers 1, 3 and 5.

    Both axes are labelled with the example's source tokens.
    """
    model, example_data = run_model_example(n_examples=1)
    src_tokens = example_data[-1][1]
    layer_viz = [
        visualize_layer(model, layer, get_encoder, len(src_tokens), src_tokens, src_tokens)
        for layer in range(6)
    ]
    return alt.hconcat(layer_viz[0] & layer_viz[2] & layer_viz[4])

# Display the encoder self-attention charts (only as __main__ with RUN_EXAMPLES on).
show_example(viz_encoder_self)

Preparing Data ...
Loading Trained model ...
Checking Model Outputs:


Source Text (Input)        :<s> Ein Hund rennt über eine tiefe <unk> . </s>
Target Text (Ground Truth) :<s> A dog running through deep snow pack . </s>
Model Output               :<s> A dog runs across a water surrounded by water . </s>


In [77]:
def viz_decoder_self():
    """
    Decode one validation example and chart decoder self-attention for all 6 layers.

    The stored weights come from the last greedy-decode step, so their rows and
    columns are positions of the decoded sequence. Both axes are therefore
    labelled with the model's output tokens (up to and including the first
    "</s>"), not the source sentence.

    Reads the module-level ``vocab_tgt``.
    """
    model, example_data = run_model_example(n_examples=1)
    example = example_data[len(example_data) - 1]

    # example[3] holds the decoded token ids.
    itos = vocab_tgt.get_itos()
    out_tokens = []
    for token_id in example[3]:
        out_tokens.append(itos[token_id])
        if out_tokens[-1] == "</s>":
            break

    layer_viz = [
        visualize_layer(model, layer, get_decoder_self, len(out_tokens), out_tokens, out_tokens)
        for layer in range(6)
    ]
    return alt.hconcat(layer_viz[0] & layer_viz[1] & layer_viz[2] & layer_viz[3] & layer_viz[4] & layer_viz[5])

# Display the decoder self-attention charts (only as __main__ with RUN_EXAMPLES on).
show_example(viz_decoder_self)

Preparing Data ...
Loading Trained model ...
Checking Model Outputs:


Source Text (Input)        :<s> Zwei Männer mit <unk> sprechen miteinander auf einem Freiluftmarkt . </s>
Target Text (Ground Truth) :<s> Two men in <unk> 's have a discussion in an outdoor market . </s>
Model Output               :<s> Two men with <unk> are talking on a subway shop . </s>


In [78]:
def viz_decoder_src():
    """
    Decode one validation example and chart decoder-to-source attention for all 6 layers.

    In source attention the queries (rows) are decoded positions and the keys
    (columns) are source positions. Rows are labelled with the model's output
    tokens (up to and including the first "</s>"), columns with the source tokens.

    Reads the module-level ``vocab_tgt``.
    """
    model, example_data = run_model_example(n_examples=1)
    example = example_data[len(example_data) - 1]
    src_tokens = example[1]

    # example[3] holds the decoded token ids.
    itos = vocab_tgt.get_itos()
    out_tokens = []
    for token_id in example[3]:
        out_tokens.append(itos[token_id])
        if out_tokens[-1] == "</s>":
            break

    layer_viz = [
        visualize_layer(model, layer, get_decoder_src, max(len(out_tokens), len(src_tokens)), out_tokens, src_tokens)
        for layer in range(6)
    ]
    return alt.hconcat(layer_viz[0] & layer_viz[1] & layer_viz[2] & layer_viz[3] & layer_viz[4] & layer_viz[5])

# Display the decoder source-attention charts (only as __main__ with RUN_EXAMPLES on).
show_example(viz_decoder_src)

Preparing Data ...
Loading Trained model ...
Checking Model Outputs:


Source Text (Input)        :<s> Drei Frauen sitzen mit einem Baby auf einer Decke auf dem Boden , unterhalten sich und haben Spaß zusammen . </s>
Target Text (Ground Truth) :<s> Three women are sitting on a blanket on the ground with a baby , and are talking and having a good time . </s>
Model Output               :<s> Three women sit on the floor with a baby , one is talking and having a conversation . </s>
