# Breaking the Transformer Bottleneck

## Preparing Data

In [None]:
# Importing relevant packages
import io
import time
import math
from collections import Counter

import torch
import torch.nn as nn
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab

## Importing custom files
from data_import import wikitext
from model import transformer_model

## Select the compute device: use a CUDA GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the vocabulary from token frequencies over the WikiText-2 training split
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
# Count every token of every training line in a single pass
counter = Counter(token for line in train_iter for token in tokenizer(line))
vocab = Vocab(counter)

In [None]:
# Load the three WikiText-2 splits and numericalize each with the shared vocab/tokenizer
train_iter, val_iter, test_iter = WikiText2()
train_data, val_data, test_data = (
    wikitext.data_process(split_iter, vocab, tokenizer)
    for split_iter in (train_iter, val_iter, test_iter)
)

## Batch data: training uses batch_size columns, validation/test use eval_batch_size
## (column layout presumably defined by wikitext.batchify - confirm there)
batch_size = 20
eval_batch_size = 10
train_data = wikitext.batchify(train_data, batch_size, device)
val_data, test_data = (
    wikitext.batchify(split_data, eval_batch_size, device)
    for split_data in (val_data, test_data)
)

# Number of positions along dim 0 consumed per step by train() / evaluate()
chunk_length = 35

## Instantiate the Model

In [None]:
# Model and training hyperparameters

ntokens = len(vocab.stoi)  # vocabulary size (number of entries in vocab.stoi)
emsize = 200               # token embedding dimension
nhid = 200                 # width of the feed-forward sublayer inside nn.TransformerEncoder
nlayers = 4                # number of nn.TransformerEncoderLayer blocks in the encoder
nhead = 4                  # attention heads per multi-head attention module
dropout = 0.25             # dropout probability handed to the model
num_softmaxes = 5          # number of softmax components in the output layer (see transformer_model)
epochs = 50                # number of passes over the training data
lr = 7.0                   # initial SGD learning rate
gradient_clip = 0.25       # max gradient norm used when clipping

model = transformer_model.TransformerModel(
    ntokens, emsize, nhead, nhid, nlayers, dropout, num_softmaxes
).to(device)

## Train the Model

In [None]:
# NLLLoss expects log-probabilities, so the model output is assumed to be log-probs
# (TODO confirm in transformer_model).
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 every 2 calls to scheduler.step()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)

def train():
    """Run one training epoch over ``train_data``.

    Walks ``train_data`` along dim 0 in steps of ``chunk_length``. Each chunk
    gets one optimizer step, with gradient norms clipped to ``gradient_clip``.
    Every ``log_interval`` batches it prints the mean loss and perplexity of
    the batches seen since the previous report.

    Reads module-level state: ``model``, ``optimizer``, ``criterion``,
    ``scheduler``, ``train_data``, ``chunk_length``, ``ntokens``,
    ``gradient_clip``, ``device`` and ``epoch``. ``epoch`` is used only in the
    log line, so this must be called from inside the epoch loop.
    """
    model.train()  # put the model in training mode
    log_interval = 200
    chunk_starts = range(0, train_data.size(0) - 1, chunk_length)
    # Actual number of batches, including a shorter trailing chunk if any.
    num_batches = len(chunk_starts)
    # Build the full-length causal mask once; only a short final chunk needs its own.
    full_mask = model.generate_square_subsequent_mask(chunk_length).to(device)

    total_loss = 0.
    batches_in_window = 0
    start_time = time.time()
    for batch, i in enumerate(chunk_starts):
        data, targets = wikitext.get_batch(train_data, i, chunk_length)
        if data.size(0) == chunk_length:
            src_mask = full_mask
        else:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)

        optimizer.zero_grad()
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()

        total_loss += loss.item()
        batches_in_window += 1
        if batch % log_interval == 0 and batch > 0:
            # Average over the batches actually accumulated. The first window
            # spans batches 0..log_interval, i.e. log_interval + 1 of them.
            cur_loss = total_loss / batches_in_window
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, num_batches, scheduler.get_last_lr()[0],
                    elapsed * 1000 / batches_in_window,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0.
            batches_in_window = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    """Return the length-weighted average loss of ``eval_model`` on ``data_source``.

    Args:
        eval_model: model to score. It is switched to eval mode and called as
            ``eval_model(data, src_mask)``.
        data_source: batchified tensor (from ``wikitext.batchify``), walked
            along dim 0 in steps of ``chunk_length``.

    Returns:
        Sum over chunks of ``len(chunk) * chunk_loss`` divided by
        ``len(data_source) - 1``. The divisor presumably equals the number of
        predicted positions if ``get_batch`` shifts targets by one; confirm there.

    Reads module-level ``criterion``, ``ntokens``, ``chunk_length`` and ``device``.
    """
    eval_model.eval()  # put the model in evaluation mode
    total_loss = 0.
    # Build masks from the model being evaluated, not the global `model`.
    full_mask = eval_model.generate_square_subsequent_mask(chunk_length).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, chunk_length):
            data, targets = wikitext.get_batch(data_source, i, chunk_length)
            seq_len = data.size(0)
            if seq_len == chunk_length:
                src_mask = full_mask
            else:
                # Only a shorter trailing chunk needs a differently sized mask.
                src_mask = eval_model.generate_square_subsequent_mask(seq_len).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            # Weight by chunk length so the short final chunk counts proportionally.
            total_loss += seq_len * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [None]:
# Loop over epochs, keeping a snapshot of the model with the lowest validation loss
import copy

best_val_loss = float("inf")
best_model = None

# NOTE: `epoch` is read as a global by train() for its log line.
for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # `model` keeps being trained in later epochs, so store an independent
        # copy. Plain assignment would alias it, and best_model would silently
        # become the final-epoch model.
        best_model = copy.deepcopy(model)

    scheduler.step()

## Evaluate the Model

In [None]:
# Score the best checkpoint on the held-out test split.
# best_model is still None if no epoch improved on inf (e.g. epochs == 0, or every
# validation loss was NaN). Fall back to the current model instead of crashing.
if best_model is None:
    print('No epoch improved the validation loss; evaluating the final model instead.')
    final_model = model
else:
    final_model = best_model
test_loss = evaluate(final_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)