# Breaking the Transformer Bottleneck

## Preparing Data

In [1]:
# Importing relevant packages
import copy
import io
import math
import time
from collections import Counter

import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import WikiText2
from torchtext.vocab import Vocab

## Importing custom files
from data_import import wikitext
from model import transformer_model

## Establishing devices
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Setting up vocabulary
# Count every token of the WikiText2 training split (tokenized with the
# 'basic_english' tokenizer) and build the vocabulary from those counts.
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(split='train')
counter = Counter(token for line in train_iter for token in tokenizer(line))
vocab = Vocab(counter)

In [3]:
# Splitting data
# Batching configuration: number of parallel streams per split and the length
# of each chunk handed to the model.
batch_size = 20
eval_batch_size = 10
chunk_length = 35

# Fresh iterators for the three WikiText2 splits (the iterator used to build
# the vocabulary has already been consumed).
train_iter, val_iter, test_iter = WikiText2()

## Batch data
# Each split is first processed with the shared vocab/tokenizer, then batchified
# onto `device`.
train_data = wikitext.batchify(
    wikitext.data_process(train_iter, vocab, tokenizer), batch_size, device)
val_data = wikitext.batchify(
    wikitext.data_process(val_iter, vocab, tokenizer), eval_batch_size, device)
test_data = wikitext.batchify(
    wikitext.data_process(test_iter, vocab, tokenizer), eval_batch_size, device)

## Instantiate the Model

In [4]:
# Establish hyperparameters

ntokens = len(vocab.stoi)  # the size of vocabulary
emsize = 200               # embedding dimension
nhid = 200                 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2                # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2                  # the number of heads in the multiheadattention models
dropout = 0.2              # the dropout value
num_softmaxes = 10         # NOTE(review): presumably the number of softmax components in the output layer; confirm in transformer_model

# Arguments are passed positionally, in the order the original call used.
model = transformer_model.TransformerModel(
    ntokens,
    emsize,
    nhead,
    nhid,
    nlayers,
    dropout,
    num_softmaxes,
).to(device)

## Train the Model

In [5]:
# Loss, optimizer and learning-rate schedule shared by train() and evaluate().
# NOTE(review): NLLLoss expects log-probabilities as input; this assumes the
# model's forward pass already returns them. Confirm in transformer_model.
criterion = nn.NLLLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Multiply the learning rate by 0.95 on every scheduler.step() call.
# step_size is documented as an int, so pass 1 rather than 1.0.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

def train():
    """Run one training pass over ``train_data``, updating ``model`` in place.

    Operates on notebook-level globals: ``model``, ``optimizer``, ``criterion``,
    ``train_data``, ``chunk_length``, ``ntokens`` and ``device``; the progress
    line additionally reads ``epoch`` and ``scheduler``. Every 200 batches it
    prints the mean loss, its exponential (perplexity) and the time per batch
    for that window. Returns nothing.
    """
    log_interval = 200
    batches_per_epoch = len(train_data) // chunk_length

    model.train() # Turn on the train mode
    running_loss = 0.
    window_start = time.time()
    src_mask = model.generate_square_subsequent_mask(chunk_length).to(device)

    for batch, offset in enumerate(range(0, train_data.size(0) - 1, chunk_length)):
        data, targets = wikitext.get_batch(train_data, offset, chunk_length)
        optimizer.zero_grad()

        # A chunk shorter than chunk_length needs a mask of matching size.
        seq_len = data.size(0)
        if seq_len != chunk_length:
            src_mask = model.generate_square_subsequent_mask(seq_len).to(device)

        predictions = model(data, src_mask)
        loss = criterion(predictions.view(-1, ntokens), targets)
        loss.backward()
        # Clip the gradient norm to 0.5 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()

        # Report only at non-zero multiples of log_interval.
        if batch == 0 or batch % log_interval != 0:
            continue

        mean_loss = running_loss / log_interval
        elapsed = time.time() - window_start
        template = ('| epoch {:3d} | {:5d}/{:5d} batches | '
                    'lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f}')
        print(template.format(
            epoch, batch, batches_per_epoch, scheduler.get_last_lr()[0],
            elapsed * 1000 / log_interval,
            mean_loss, math.exp(mean_loss)))
        running_loss = 0
        window_start = time.time()

def evaluate(eval_model, data_source):
    """Compute the average loss of ``eval_model`` over ``data_source``.

    Args:
        eval_model: Model to evaluate. It is switched to eval mode and must
            provide ``generate_square_subsequent_mask(size)``.
        data_source: Batchified tensor, walked in chunks of ``chunk_length``
            rows via ``wikitext.get_batch``.

    Returns:
        float: Sum of each chunk's criterion value weighted by its number of
        rows, divided by ``len(data_source) - 1``.

    Also reads the notebook globals ``chunk_length``, ``device``, ``ntokens``
    and ``criterion``.
    """
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    # Build masks from the model being evaluated, not from the global `model`,
    # so the function works for any model passed in.
    src_mask = eval_model.generate_square_subsequent_mask(chunk_length).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, chunk_length):
            data, targets = wikitext.get_batch(data_source, i, chunk_length)
            # A chunk shorter than chunk_length needs a mask of matching size.
            if data.size(0) != chunk_length:
                src_mask = eval_model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            # Weight by chunk length so a short chunk counts proportionally.
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [6]:
# Loop over epochs
# Train for a fixed number of epochs, validating after each one and keeping a
# snapshot of the model with the lowest validation loss seen so far.
best_val_loss = float("inf")
epochs = 3 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Deep-copy the model: `model` keeps training in later epochs, so a
        # plain reference would end up holding the final weights instead of
        # the best ones.
        best_model = copy.deepcopy(model)

    # Decay the learning rate once per epoch.
    scheduler.step()

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 102.91 | loss 11.18 | ppl 71892.45
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 102.46 | loss  7.80 | ppl  2451.59
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 102.54 | loss  6.83 | ppl   921.47
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 102.58 | loss  6.36 | ppl   578.53
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 102.64 | loss  6.15 | ppl   470.46
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 102.81 | loss  6.11 | ppl   450.13
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 102.93 | loss  6.05 | ppl   424.70
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 103.09 | loss  6.04 | ppl   420.70
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 103.21 | loss  5.96 | ppl   388.58
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 103.23 | loss  5.95 | ppl   385.66
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 103.34 | loss  5.84 | ppl   344.79
| epoch   

## Evaluate the Model

In [7]:
# Score the best checkpoint on the held-out test split and report the loss
# together with its exponential (perplexity).
test_loss = evaluate(best_model, test_data)
test_summary = '| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss))
print('=' * 89)
print(test_summary)
print('=' * 89)

| End of training | test loss  5.48 | test ppl   238.87
