In [330]:
!pip install -q datasets
!pip install -q GPUtil
from datasets import load_dataset

In [331]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torchtext.datasets import WikiText2


# Set to False to skip notebook execution (e.g. for debugging)
warnings.filterwarnings("ignore")
RUN_EXAMPLES = True

In [332]:
def is_interactive_notebook():
    return __name__ == "__main__"

def show_example(fn, args=[]):
    if __name__ == "__main__" and RUN_EXAMPLES:
        return fn(*args)

def execute_example(fn, args=[]):
    if __name__ == "__main__" and RUN_EXAMPLES:
        fn(*args)

class DummyOptimizer(torch.optim.Optimizer):
    def __init__(self):
        self.param_groups = [{"lr": 0}]
        None

    def step(self):
        None

    def zero_grad(self, set_to_none=False):
        None


class DummyScheduler:
    def step(self):
        None

# GPT Model

In [333]:
class GPT(nn.Module):

    def __init__(self, decoder, embed, generator):
        super(GPT, self).__init__()
        self.decoder = decoder
        self.embed = embed
        self.generator = generator

    def forward(self, x, pad_mask=None):
        "Take in and process masked src and target sequences."
        return self.decode(x, pad_mask)

    def decode(self, x, pad_mask=None):
        return self.decoder(self.embed(x), pad_mask)

## Generator

In [334]:
class Generator(nn.Module):
    "Define standard linear + softmax generation step."

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        return log_softmax(self.proj(x), dim=-1)

In [335]:
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

## LayerNorm

In [336]:
class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

In [337]:
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        return x + self.dropout(sublayer(self.norm(x)))

## Decoder

In [338]:
class Decoder(nn.Module):
    "Generic N layer decoder with masking."

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, pad_mask=None):
        for layer in self.layers:
            x = layer(x, pad_mask)
        return self.norm(x)

## DecoderLayer

In [339]:
class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn, and feed forward (defined below)"

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, pad_mask=None):
        "Follow Figure 1 (right) for connections."
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, pad_mask))
        return self.sublayer[1](x, self.feed_forward)

## Mask

In [340]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    return subsequent_mask == 0

## Attention

In [341]:
def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:

        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

In [342]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, pad_mask=None):
        "Implements Figure 2"

        # Same mask applied to all h heads and pad
        mask = subsequent_mask(query.size(-2))

        if pad_mask is not None:
            # print("pad mask****😷")
            mask = pad_mask & mask.type_as(pad_mask.data)
            mask = mask.unsqueeze(1)

        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(
            query, key, value, mask=mask, dropout=self.dropout
        )

        # 3) "Concat" using a view and apply a final linear.
        x = (
            x.transpose(1, 2)
            .contiguous()
            .view(nbatches, -1, self.h * self.d_k)
        )
        del query
        del key
        del value
        return self.linears[-1](x)

In [343]:
# tgt = torch.range(1,4)

# print(tgt)

# pad = 2


# tgt_mask = (tgt != pad).unsqueeze(-2)
# print(tgt_mask)
# print(tgt_mask.shape)
# print(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data).shape)

# tgt_mask = tgt_mask & subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)
# print(tgt_mask.shape)
# print(tgt_mask)

In [344]:
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(self.w_1(x).relu()))

## Embedding

In [345]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

## PositionalEncoding

In [346]:
class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

# Full Model

In [347]:
def make_model(vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    "Helper: Construct a model from hyperparameters."
    c = copy.deepcopy
    attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    position = PositionalEncoding(d_model, dropout)
    model = GPT(
        Decoder(DecoderLayer(d_model, c(attn), c(ff), dropout), N),
        nn.Sequential(Embeddings(d_model, vocab), c(position)),
        Generator(d_model, vocab),
    )

    # This was important from their code.
    # Initialize parameters with Glorot / fan_avg.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model

# A Simple Example

In [348]:
def inference_test():
    test_model = make_model(vocab = 10)
    test_model.eval()
    x = torch.LongTensor([[1, 2, 3, 4, 5]])

    for i in range(6):
        out = test_model.decode(x)
        prob = test_model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data[0]
        x = torch.cat([x, torch.empty(1, 1).type_as(x.data).fill_(next_word)], dim=1)

    print("Example Untrained Model Prediction:", x)
    return x

def run_tests():
    for _ in range(10):
        inference_test()


# show_example(run_tests)

# Training

## Run Epoch

In [349]:
class TrainState:
    """Track number of steps, examples, and tokens processed"""

    step: int = 0  # Steps in the current epoch
    accum_step: int = 0  # Number of gradient accumulation steps
    samples: int = 0  # total # of examples used
    tokens: int = 0  # total # of tokens processed

In [350]:
def run_epoch(
    data_iter,
    model,
    loss_compute,
    optimizer,
    scheduler,
    mode="train",
    accum_iter=1,
    train_state=TrainState(),
):
    """Train a single epoch"""
    start = time.time()
    total_tokens = 0
    total_loss = 0
    tokens = 0
    n_accum = 0
    for i, batch in enumerate(data_iter):
        out = model.forward(batch.x, batch.pad_mask)
        loss, loss_node = loss_compute(out, batch.y, batch.ntokens)
        # loss_node = loss_node / accum_iter
        if mode == "train" or mode == "train+log":
            loss_node.backward()
            train_state.step += 1
            train_state.samples += batch.x.shape[0]
            train_state.tokens += batch.ntokens
            if i % accum_iter == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                n_accum += 1
                train_state.accum_step += 1
            scheduler.step()

        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 40 == 1 and (mode == "train" or mode == "train+log"):
            lr = optimizer.param_groups[0]["lr"]
            elapsed = time.time() - start
            print(
                (
                    "Epoch Step: %6d | Accumulation Step: %3d | Loss: %6.2f "
                    + "| Tokens / Sec: %7.1f | Learning Rate: %6.1e"
                )
                % (i, n_accum, loss / batch.ntokens, tokens / elapsed, lr)
            )
            start = time.time()
            tokens = 0
        del loss
        del loss_node
    return total_loss / total_tokens, train_state

In [351]:
def rate(step, model_size, factor, warmup):
    """
    we have to default the step to 1 for LambdaLR function
    to avoid zero raising to negative power.
    """
    if step == 0:
        step = 1
    return factor * (
        model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))
    )

## Loss Computation

In [352]:
class SimpleLossCompute:
    "A simple loss compute and train function."

    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        x = self.generator(x)
        sloss = (
            self.criterion(
                x.contiguous().view(-1, x.size(-1)), y.contiguous().view(-1)
            )
            / norm
        )
        return sloss.data * norm, sloss

## Label Smoothing

In [353]:
class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist.clone().detach())

## Greedy Decode

In [354]:
def greedy_decode(model, x, max_len):
    ys = []
    for i in range(max_len):
        out = model.decode(x)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.data
        ys.append(next_word)
    return torch.stack(ys, dim=1).type_as(x)

## Batch

In [355]:
class Batch:
    """Object for holding a batch of data with mask during training."""

    def __init__(self, x, y, pad=None): # 2 = <blank>
        self.x = x

        self.y = y
        # print("self x.shape", self.x.shape)
        self.ntokens = self.y.data.numel()
        self.pad_mask = None
        if pad:
            self.ntokens = (self.y != pad).data.sum()

            # todo
            self.pad_mask = (x != pad).unsqueeze(-2)

## Data Generation

In [356]:
def data_gen(V, batch_size, nbatches):
    for i in range(nbatches):

        random_row_values = torch.randint(0, V, (batch_size, 1))

        data = random_row_values.repeat(1, 10)

        src = data.requires_grad_(False).clone().detach()
        yield Batch(src[:, :-1], src[:, 1:])

## Training Process

In [357]:
def example_simple_model():
    V = 11
    criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)
    model = make_model(V)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, model_size=model.embed[0].d_model, factor=1.0, warmup=400
        ),
    )

    batch_size = 80
    for epoch in range(4):
        model.train()
        run_epoch(
            data_gen(V, batch_size, 20),
            model,
            SimpleLossCompute(model.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train",
        )
        model.eval()
        run_epoch(
            data_gen(V, batch_size, 5),
            model,
            SimpleLossCompute(model.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )[0]

    # model.eval()
    # src = torch.LongTensor([[2, 2, 2]])
    # max_len = src.shape[1]
    # print(greedy_decode(model, src, max_len=max_len))

    return model


# execute_example(example_simple_model)

In [358]:
# model = example_simple_model()

In [359]:
# model.eval()
# src = torch.LongTensor([[3, 3], [1,1]])
# max_len = src.shape[1]
# print(greedy_decode(model, src, max_len))

# A Real World Example

## Data Loading

In [360]:
# Load spacy tokenizer models, download them if they haven't been
# downloaded already


def load_tokenizers():

    try:
        spacy_en = spacy.load("en_core_web_sm")
    except IOError:
        os.system("python -m spacy download en_core_web_sm")
        spacy_en = spacy.load("en_core_web_sm")

    return spacy_en

In [361]:
def tokenize(text, tokenizer):
    return [tok.text for tok in tokenizer.tokenizer(text)]


def yield_tokens(data_iter, tokenizer):
    for from_to_tuple in data_iter:
        # print(from_to_tuple[index])
        yield tokenizer(from_to_tuple)


In [362]:
def build_vocabulary(spacy_en):

    def tokenize_en(text):
        return tokenize(text, spacy_en)

    print("Building English Vocabulary ...")
    # Load the training, validation, and test datasets

    dataset = load_dataset("wikitext", "wikitext-2-v1")
    train, test, val = dataset['train']['text'], dataset['test']['text'], dataset['validation']['text']

    vocab = build_vocab_from_iterator(
        yield_tokens(train + test + val, tokenize_en),
        min_freq=2,
        specials=["<s>", "</s>", "<blank>", "<unk>"],
    )

    vocab.set_default_index(vocab["<unk>"])

    return vocab


def load_vocab(spacy_en):
    if not exists("vocab.pt"):
        vocab = build_vocabulary(spacy_en)
        torch.save((vocab), "vocab.pt")
    else:
        vocab = torch.load("vocab.pt")
    print("Finished.\nVocabulary sizes:")
    print(len(vocab))
    return vocab


if is_interactive_notebook():
    # global variables used later in the script
    spacy_en = show_example(load_tokenizers)
    vocab = show_example(load_vocab, args=[spacy_en])

Finished.
Vocabulary sizes:
33245


In [363]:
dataset = load_dataset("wikitext", "wikitext-2-v1")
train, test, val = dataset['train']['text'], dataset['test']['text'], dataset['validation']['text']

In [364]:
train_iter_map = to_map_style_dataset(train)

## Iterators

In [365]:
def collate_batch(
    batch,
    pipeline,
    vocab,
    device,
    max_padding=128,
    pad_id=2,
):
    bs_id = torch.tensor([0], device=device)  # <s> token id
    eos_id = torch.tensor([1], device=device)  # </s> token id
    x_list, y_list = [], []
    for _src in batch:
        _x = _src[:-1]
        _y = _src[1:]
        processed_x = torch.cat(
            [
                bs_id,
                torch.tensor(
                    vocab(pipeline(_x)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )
        processed_y = torch.cat(
            [
                bs_id,
                torch.tensor(
                    vocab(pipeline(_y)),
                    dtype=torch.int64,
                    device=device,
                ),
                eos_id,
            ],
            0,
        )
        x_list.append(
            # warning - overwrites values for negative values of padding - len
            pad(
                processed_x,
                (
                    0,
                    max_padding - len(processed_x),
                ),
                value=pad_id,
            )
        )
        y_list.append(
            pad(
                processed_y,
                (0, max_padding - len(processed_y)),
                value=pad_id,
            )
        )

    x = torch.stack(x_list)
    y = torch.stack(y_list)
    return (x, y)

In [366]:
def create_dataloaders(
    device,
    vocab,
    spacy_en,
    batch_size=12000,
    max_padding=128,
    is_distributed=True,
):

    def tokenize_en(text):
        return tokenize(text, spacy_en)

    def collate_fn(batch):
        return collate_batch(
            batch,
            tokenize_en,
            vocab,
            device,
            max_padding=max_padding,
            pad_id=vocab.get_stoi()["<blank>"],
        )

    dataset = load_dataset("wikitext", "wikitext-2-v1")
    train_iter, test_iter, valid_iter = dataset['train']['text'], dataset['test']['text'], dataset['validation']['text']

    train_iter_map = to_map_style_dataset(
        train_iter
    )  # DistributedSampler needs a dataset len()
    train_sampler = (
        DistributedSampler(train_iter_map) if is_distributed else None
    )
    valid_iter_map = to_map_style_dataset(valid_iter)
    valid_sampler = (
        DistributedSampler(valid_iter_map) if is_distributed else None
    )

    test_iter_map = to_map_style_dataset(test_iter)
    test_sampler = (
        DistributedSampler(test_iter_map) if is_distributed else None
    )

    train_dataloader = DataLoader(
        train_iter_map,
        batch_size=batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        collate_fn=collate_fn,
    )
    valid_dataloader = DataLoader(
        valid_iter_map,
        batch_size=batch_size,
        shuffle=(valid_sampler is None),
        sampler=valid_sampler,
        collate_fn=collate_fn,
    )
    test_dataloader = DataLoader(
        test_iter_map,
        batch_size=batch_size,
        shuffle=(test_sampler is None),
        sampler=test_sampler,
        collate_fn=collate_fn,
    )
    return train_dataloader, valid_dataloader, test_dataloader

## Train the System

In [367]:
def train_worker(
    gpu,
    ngpus_per_node,
    vocab,
    spacy_en,
    config,
    is_distributed=False,
):
    print(f"Train worker process using GPU: {gpu} for training", flush=True)
    torch.cuda.set_device(gpu)

    pad_idx = vocab["<blank>"]
    d_model = 512
    model = make_model(len(vocab))
    model.cuda(gpu)
    module = model
    is_main_process = True
    if is_distributed:
        dist.init_process_group(
            "nccl", init_method="env://", rank=gpu, world_size=ngpus_per_node
        )
        model = DDP(model, device_ids=[gpu])
        module = model.module
        is_main_process = gpu == 0

    criterion = LabelSmoothing(
        size=len(vocab), padding_idx=pad_idx, smoothing=0.1
    )
    criterion.cuda(gpu)

    train_dataloader, valid_dataloader, _ = create_dataloaders(
        gpu,
        vocab,
        spacy_en,
        batch_size=config["batch_size"] // ngpus_per_node,
        max_padding=config["max_padding"],
        is_distributed=is_distributed,
    )

    optimizer = torch.optim.Adam(
        model.parameters(), lr=config["base_lr"], betas=(0.9, 0.98), eps=1e-9
    )
    lr_scheduler = LambdaLR(
        optimizer=optimizer,
        lr_lambda=lambda step: rate(
            step, d_model, factor=1, warmup=config["warmup"]
        ),
    )
    train_state = TrainState()

    for epoch in range(config["num_epochs"]):
        if is_distributed:
            train_dataloader.sampler.set_epoch(epoch)
            valid_dataloader.sampler.set_epoch(epoch)

        model.train()
        print(f"[GPU{gpu}] Epoch {epoch} Training ====", flush=True)
        _, train_state = run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in train_dataloader),
            model,
            SimpleLossCompute(module.generator, criterion),
            optimizer,
            lr_scheduler,
            mode="train+log",
            accum_iter=config["accum_iter"],
            train_state=train_state,
        )

        GPUtil.showUtilization()
        if is_main_process:
            file_path = "%s%.2d.pt" % (config["file_prefix"], epoch)
            torch.save(module.state_dict(), file_path)
        torch.cuda.empty_cache()

        print(f"[GPU{gpu}] Epoch {epoch} Validation ====", flush=True)
        model.eval()
        sloss = run_epoch(
            (Batch(b[0], b[1], pad_idx) for b in valid_dataloader),
            model,
            SimpleLossCompute(module.generator, criterion),
            DummyOptimizer(),
            DummyScheduler(),
            mode="eval",
        )
        print(sloss)
        torch.cuda.empty_cache()

    if is_main_process:
        file_path = "%sfinal.pt" % config["file_prefix"]
        torch.save(module.state_dict(), file_path)

In [368]:
def train_distributed_model(vocab, spacy_en, config):
    # from the_annotated_transformer import train_worker

    ngpus = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12356"
    print(f"Number of GPUs detected: {ngpus}")
    print("Spawning training processes ...")
    mp.spawn(
        train_worker,
        nprocs=ngpus,
        args=(ngpus, vocab, spacy_en, config, True),
    )


def train_model(vocab, spacy_en, config):
    if config["distributed"]:
        train_distributed_model(vocab, spacy_en, config)
    else:
        train_worker(0, 1, vocab, spacy_en, config, False)


def load_trained_model(config):
    model_path = "%sfinal.pt" % config["file_prefix"]
    if not exists(model_path):
        train_model(vocab, spacy_en, config)

    model = make_model(len(vocab), N=6)
    model.load_state_dict(torch.load(model_path))
    return model


if is_interactive_notebook():
    config = {
      "batch_size": 32,
      "distributed": False,
      "num_epochs": 8,
      "accum_iter": 10,
      "base_lr": 1.0,
      "max_padding": 72,
      "warmup": 3000,
      "file_prefix": "wikitext2-V1_",
    }
    model = load_trained_model(config)

In [369]:
model = load_trained_model(config)

In [370]:
def test_model(model, vocab, spacy_en, config):
    model.eval()
    dataset = load_dataset("wikitext", "wikitext-2-v1")
    test_list = dataset['test']['text'][250:261]

    device = torch.cuda.set_device(0)

    def tokenize_en(text):
        return tokenize(text, spacy_en)


    def collate_fn(sen_list):
        return collate_batch(
            sen_list,
            tokenize_en,
            vocab,
            device,
            max_padding=config["max_padding"],
            pad_id=vocab.get_stoi()["<blank>"],
        )

    def greedy_decode(model, x, max_len, vocab=None):
        word_list = []

        for i in range(max_len - 1):
            out = model.decode(x)
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.data[0]
            if vocab:
                word_list.append(vocab.lookup_token(next_word.item()))
                # print("Word List Prediction:", word_list)

            x = torch.cat([x, torch.empty(1, 1).type_as(x.data).fill_(next_word)], dim=1)

        sentence = ' '.join(word_list)
        print("Word List Prediction: ", sentence)

        # if vocab is None:
        # print("Model Prediction:", x)
        return x

    for sen in test_list:
        sen_list = [sen[:50]]
        print(sen_list)
        b = collate_fn(sen_list)
        x = greedy_decode(model, b[0], 15, vocab)

model = load_trained_model(config)
test_model(model, vocab, spacy_en, config)

[' During gunnery training on 5 May , there was a pr']
Word List Prediction:  . The < unk > for the body was also applied for the <
[' They returned home on 14 June and the <unk> began']
Word List Prediction:  . The < unk > for the body was given the < unk >
[' After completing her sea trials , <unk> was attac']
Word List Prediction:  . The resulting in the same view was given the < unk > ,
['']
Word List Prediction:  . The < unk > for the < unk > , < unk >
[' = = = = Battle of Cape <unk> = = = = \n']
Word List Prediction:  . 
 for the < unk > the < unk > , < unk
['']
Word List Prediction:  . The < unk > for the < unk > , < unk >
[' <unk> of aircraft and <unk> problems greatly <unk']
Word List Prediction:  . 
 for the < unk > the < unk > < unk >
[' Although they had lost contact during the night ,']
Word List Prediction:  . The resulting in the < unk > , < unk > , <
[' <unk> was attacked by 80 @-@ odd aircraft from th']
Word List Prediction:  . The < unk > for the < unk > , < unk >
[' 

## Test the Model

In [387]:
def check_outputs(
    test_dataloader,
    model,
    vocab,
    n_examples=15,
    pad_idx=2,
    eos_string="</s>",
):
    results = [()] * n_examples
    for idx in range(n_examples):
        print("\nExample %d ========\n" % idx)
        b = next(iter(test_dataloader))
        rb = Batch(b[0][:, 0:5], b[0][:, 5:10], pad_idx)

        full_sentence = [
            vocab.get_itos()[i] for i in b[0][0] if i != pad_idx
        ]

        x_tokens = [
            vocab.get_itos()[i] for i in rb.x[0] if i != pad_idx
        ]
        y_tokens = [
            vocab.get_itos()[i] for i in rb.y[0] if i != pad_idx
        ]

        print(
            "Full Sentence              : "
            + " ".join(full_sentence).replace("\n", "")
        )

        print(
            "Source Text (Input)        : "
            + " ".join(x_tokens).replace("\n", "")
        )
        print(
            "Target Text (Ground Truth) : "
            + " ".join(y_tokens).replace("\n", "")
        )
        model_out = greedy_decode(model, rb.x, 5)[0]
        model_txt = (
            " ".join(
                [vocab.get_itos()[i] for i in model_out if i != pad_idx]
            ).split(eos_string, 1)[0]
            + eos_string
        )
        print("Model Output               : " + model_txt.replace("\n", ""))
        results[idx] = (rb, x_tokens, y_tokens, model_out, model_txt)
    return results


def run_model_example(n_examples=5):
    global vocab, spacy_en

    print("Preparing Data ...")
    _, _, test_dataloader = create_dataloaders(
        torch.device("cpu"),
        vocab,
        spacy_en,
        batch_size=1,
        is_distributed=False,
    )

    print("Loading Trained Model ...")

    model = make_model(len(vocab), N=6)
    model_path = "%sfinal.pt" % config["file_prefix"]
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu"))
    )

    print("Checking Model Outputs:")
    example_data = check_outputs(test_dataloader, model, vocab, n_examples=n_examples)
    return model, example_data

config = {
  "file_prefix": "wikitext2-V1_",
}

execute_example(run_model_example)

Preparing Data ...
Loading Trained Model ...
Checking Model Outputs:


Full Sentence              : <s> </s>
Source Text (Input)        : <s> </s>
Target Text (Ground Truth) : 
Model Output               : , , , , ,</s>


Full Sentence              : <s> </s>
Source Text (Input)        : <s> </s>
Target Text (Ground Truth) : 
Model Output               : , < , , ,</s>


Full Sentence              : <s> </s>
Source Text (Input)        : <s> </s>
Target Text (Ground Truth) : 
Model Output               :  ,  the (</s>


Full Sentence              : <s>   The Division of the City Schools of Manila , a branch of the Department of Education , refers to the city 's three @-@ tier public education system . It governs the 71 public elementary schools , 32 public high schools . </s>
Source Text (Input)        : <s>   The Division of
Target Text (Ground Truth) : the City Schools of Manila
Model Output               : the the the the the</s>


Full Sentence              : <s>   In 2000 , paleonto

# Attention Visualization

In [378]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    "convert a dense matrix to a data frame with row and column indices"
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.3d %s"
                % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.3d %s"
                % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        # if float(m[r,c]) != 0 and r < max_row and c < max_col],
        columns=["row", "column", "value", "row_token", "col_token"],
    )


def attn_map(attn, layer, head, row_tokens, col_tokens, max_dim=30):
    df = mtx2df(
        attn[0, head].data,
        max_dim,
        max_dim,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color="value",
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        .properties(height=400, width=400)
        .interactive()
    )

In [379]:
def get_encoder(model, layer):
    return model.encoder.layers[layer].self_attn.attn


def get_decoder_self(model, layer):
    return model.decoder.layers[layer].self_attn.attn


def get_decoder_src(model, layer):
    return model.decoder.layers[layer].src_attn.attn


def visualize_layer(model, layer, getter_fn, ntokens, row_tokens, col_tokens):
    # ntokens = last_example[0].ntokens
    attn = getter_fn(model, layer)
    n_heads = attn.shape[1]
    charts = [
        attn_map(
            attn,
            0,
            h,
            row_tokens=row_tokens,
            col_tokens=col_tokens,
            max_dim=ntokens,
        )
        for h in range(n_heads)
    ]
    assert n_heads == 8
    return alt.vconcat(
        charts[0]
        # | charts[1]
        | charts[2]
        # | charts[3]
        | charts[4]
        # | charts[5]
        | charts[6]
        # | charts[7]
        # layer + 1 due to 0-indexing
    ).properties(title="Layer %d" % (layer + 1))

In [384]:
def viz_decoder_self():
    model, example_data = run_model_example(n_examples=1)
    example = example_data[len(example_data) - 1]

    layer_viz = [
        visualize_layer(
            model,
            layer,
            get_decoder_self,
            len(example[1]),
            example[1],
            example[1],
        )
        for layer in range(6)
    ]
    return alt.hconcat(
        layer_viz[0]
        & layer_viz[1]
        & layer_viz[2]
        & layer_viz[3]
        & layer_viz[4]
        & layer_viz[5]
    )


show_example(viz_decoder_self)

Preparing Data ...
Loading Trained Model ...
Checking Model Outputs:


Full Sentence              : <s>   In 422 , Augustine of < unk > wrote about 2 < unk > 2 : 1 – 11 , where he believed Paul mentioned the coming of the < unk > . Though he rejects the theory , Augustine mentions that many Christians believed that < unk > was the < unk > or would return as the < unk > . He wrote , " so that in saying , ' For the mystery of < unk > < unk > already work , ' he alluded to < unk > , whose < unk > already seemed to be as the < unk > of < unk > . " </s>
Source Text (Input)        : <s>   In 422 ,
Target Text (Ground Truth) : Augustine of < unk >
Model Output               : a a the the the</s>
