# Breaking the Transformer Bottleneck

In this notebook, we examine how the Mixture of Softmaxes (MoS) output layer proposed in [Yang et al. (2018)](https://arxiv.org/pdf/1711.03953.pdf) affects the performance of an encoder-only Transformer built from the encoder described in [Vaswani et al. (2017)](https://arxiv.org/pdf/1706.03762.pdf). Our workflow closely follows the [Transformer tutorial](https://pytorch.org/tutorials/beginner/transformer_tutorial.html) on the PyTorch website. We begin by importing the packages used throughout this notebook.

In [1]:
# Importing packages used throughout
import io
import time
import math
from collections import Counter

import torch
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab

## Importing custom files
from model import transformer_model
from data_import import batching

## Establishing devices
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

Using cuda


In [2]:
# Fixing the random seed for reproducibility
torch.manual_seed(26)


<torch._C.Generator at 0x7f326a4c8870>

# Transformer Models

Below are the hyperparameters chosen for the models.

In [3]:
# Establish hyperparameters

emsize = 300 # embedding dimension
nhid = 300 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 4 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 4 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
num_softmaxes = 10 # number of softmax components in the MoS model (the stock model uses 1)
epochs = 3 # number of epochs
lr = 7.0 # learning rate
gradient_clip = 0.25 # what to clip the gradients by
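The `num_softmaxes` setting controls how many softmax components the output layer mixes. As a rough illustration of the Mixture of Softmaxes idea from Yang et al. (2018), the sketch below mixes K softmax distributions using plain Python lists. The function names and toy numbers are hypothetical, and the actual layer in `transformer_model.py` (not shown in this notebook) may differ in detail.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def mixture_of_softmaxes(prior_logits, component_logits):
    """Mix K softmax distributions over the vocabulary.

    prior_logits:     K scores, one per component (softmaxed into mixture weights)
    component_logits: K lists of vocabulary logits, one list per component
    Returns log-probabilities over the vocabulary, the form nn.NLLLoss expects.
    """
    weights = softmax(prior_logits)
    components = [softmax(logits) for logits in component_logits]
    vocab_size = len(component_logits[0])
    probs = [
        sum(w * comp[v] for w, comp in zip(weights, components))
        for v in range(vocab_size)
    ]
    return [math.log(p) for p in probs]

# Toy example: K = 2 components over a 3-word vocabulary (made-up numbers)
log_probs = mixture_of_softmaxes(
    prior_logits=[0.0, 1.0],
    component_logits=[[2.0, 0.5, -1.0], [-1.0, 0.0, 3.0]],
)
total_prob = sum(math.exp(lp) for lp in log_probs)  # the mixture is still a distribution

# With a single component the mixture reduces to an ordinary log-softmax
single = mixture_of_softmaxes([0.0], [[2.0, 0.5, -1.0]])
```

Because the probabilities are mixed before the logarithm is taken, the result is no longer a softmax of a single logit vector, which is the extra expressiveness the MoS paper argues for. With K = 1 the mixture collapses to a plain softmax, which is presumably why the stock models in this notebook are built with a final argument of 1.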

## Establish Training/Evaluating Functions

As mentioned before, this is an encoder-only Transformer, as defined in [transformer_model.py](/model/transformer_model.py). We use negative log-likelihood loss and stochastic gradient descent.
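`nn.NLLLoss` expects log-probabilities rather than raw logits, which suits a mixture output where probabilities are combined before the logarithm is taken. The pure-Python sketch below shows the quantity it averages; the helper name and the numbers are illustrative only.

```python
import math

def nll_loss(log_probs, targets):
    """Mean negative log-likelihood, matching the default 'mean' reduction of nn.NLLLoss.

    log_probs: one list of per-class log-probabilities per position
    targets:   the correct class index for each position
    """
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)

# Two positions over a 3-word vocabulary (illustrative values only)
probs = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
log_probs = [[math.log(p) for p in row] for row in probs]
targets = [0, 2]

loss = nll_loss(log_probs, targets)  # -(log 0.7 + log 0.8) / 2
```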

In [4]:
criterion = nn.NLLLoss()

def train(model: transformer_model.TransformerModel, train_data: torch.Tensor, optimizer: torch.optim.SGD, learning_rate: float) -> None:
    # Runs one training epoch; relies on the notebook globals chunk_length, ntokens, and epoch
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    src_mask = model.generate_square_subsequent_mask(chunk_length).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, chunk_length)):
        data, targets = batching.get_batch(train_data, i, chunk_length)
        optimizer.zero_grad()
        if data.size(0) != chunk_length:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // chunk_length, learning_rate,
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model: transformer_model.TransformerModel, data_source: torch.Tensor) -> float:
    # Returns the mean per-token loss; relies on the notebook globals chunk_length and ntokens
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    src_mask = eval_model.generate_square_subsequent_mask(chunk_length).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, chunk_length):
            data, targets = batching.get_batch(data_source, i, chunk_length)
            if data.size(0) != chunk_length:
                src_mask = eval_model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
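Both functions pass a mask from `generate_square_subsequent_mask` to the model and rebuild it when the final chunk is shorter than `chunk_length`. Assuming that method behaves like the one of the same name in the PyTorch tutorial, it produces an additive causal mask, which is what lets an encoder-only model be trained for next-token prediction. A minimal pure-Python sketch of such a mask (the function name here is hypothetical):

```python
def square_subsequent_mask(size):
    """Additive attention mask: 0.0 where attention is allowed, -inf where it is blocked.

    Row i is the query position and column j the key position; every j > i is
    masked so a position cannot look at later tokens.
    """
    neg_inf = float("-inf")
    return [[0.0 if j <= i else neg_inf for j in range(size)] for i in range(size)]

mask = square_subsequent_mask(3)
for row in mask:
    print(row)
```

The mask is square in the sequence length, so a chunk with fewer rows than `chunk_length` needs a smaller mask, which matches the size check inside both loops.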

# WikiText-2

The first dataset we model is [WikiText-2](https://paperswithcode.com/dataset/wikitext-2). The data is downloaded into the [.data](/.data) folder.

In [5]:
# Importing relevant packages
from torchtext.datasets import WikiText2
from data_import import wikitext

## Preparing Data

In [6]:
# Establishing hyperparameters
batch_size = 20
eval_batch_size = 10
chunk_length = 35

In [7]:
# Setting up vocabulary
train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
counter = Counter()
for line in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter)

ntokens = len(vocab.stoi) # the size of vocabulary

In [8]:
# Splitting data
train_iter, val_iter, test_iter = WikiText2()
train_data = wikitext.data_process(train_iter, vocab, tokenizer)
val_data = wikitext.data_process(val_iter, vocab, tokenizer)
test_data = wikitext.data_process(test_iter, vocab, tokenizer)

## Batch data
train_data = batching.batchify(train_data, batch_size, device)
val_data = batching.batchify(val_data, eval_batch_size, device)
test_data = batching.batchify(test_data, eval_batch_size, device)
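The `batching` module is not shown in this notebook. Assuming `batchify` and `get_batch` mirror the helpers of the same names in the PyTorch tutorial, the sketch below reproduces their logic with plain lists: the token stream is trimmed, split into `batch_size` parallel columns, and then read in chunks whose targets are the same rows shifted by one token.

```python
def batchify(tokens, batch_size):
    """Trim the stream so it divides evenly, then lay it out as batch_size columns.

    Returns a list of rows; row r holds the r-th token of every column,
    mirroring the [sequence length, batch size] layout of the PyTorch tutorial.
    """
    n_rows = len(tokens) // batch_size
    trimmed = tokens[: n_rows * batch_size]
    columns = [trimmed[c * n_rows:(c + 1) * n_rows] for c in range(batch_size)]
    return [[col[r] for col in columns] for r in range(n_rows)]

def get_batch(source, i, chunk_length):
    """Return up to chunk_length rows starting at row i, plus flattened next-token targets."""
    seq_len = min(chunk_length, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target_rows = source[i + 1:i + 1 + seq_len]
    targets = [tok for row in target_rows for tok in row]
    return data, targets

stream = list(range(26))               # stand-in for 26 token ids
rows = batchify(stream, 4)             # 26 // 4 = 6 rows, the last 2 tokens are dropped
data, targets = get_batch(rows, 0, 3)  # a full chunk of 3 rows
tail_data, tail_targets = get_batch(rows, 3, 3)  # the final chunk has only 2 rows
```

Each column is a contiguous slice of the corpus, so the model sees `batch_size` independent streams in parallel. The shorter final chunk is the case the training and evaluation loops handle by regenerating the mask.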

## Train the Models

In [9]:
# Train Stock Model
import copy

wikitext_stock_model = transformer_model.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout, 1).to(device)
optimizer = torch.optim.SGD(wikitext_stock_model.parameters(), lr=lr)

best_val_loss = float("inf")
best_wikitext_stock_model = None
learning_rate = lr

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(wikitext_stock_model, train_data, optimizer, learning_rate)
    val_loss = evaluate(wikitext_stock_model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                    val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot the weights; plain assignment would only alias the model that keeps training
        best_wikitext_stock_model = copy.deepcopy(wikitext_stock_model)
    else:
        learning_rate = learning_rate / 1.75
        for g in optimizer.param_groups:
            g["lr"] = learning_rate

| epoch   1 |   200/ 2928 batches | lr 7.00 | ms/batch 16.23 | loss  8.15 | ppl  3465.77
| epoch   1 |   400/ 2928 batches | lr 7.00 | ms/batch 15.90 | loss  6.90 | ppl   987.74
| epoch   1 |   600/ 2928 batches | lr 7.00 | ms/batch 15.92 | loss  6.47 | ppl   648.32
| epoch   1 |   800/ 2928 batches | lr 7.00 | ms/batch 15.96 | loss  6.31 | ppl   551.89
| epoch   1 |  1000/ 2928 batches | lr 7.00 | ms/batch 15.95 | loss  6.17 | ppl   479.10
| epoch   1 |  1200/ 2928 batches | lr 7.00 | ms/batch 15.98 | loss  6.13 | ppl   460.62
| epoch   1 |  1400/ 2928 batches | lr 7.00 | ms/batch 16.02 | loss  6.07 | ppl   430.85
| epoch   1 |  1600/ 2928 batches | lr 7.00 | ms/batch 16.04 | loss  6.05 | ppl   424.57
| epoch   1 |  1800/ 2928 batches | lr 7.00 | ms/batch 16.03 | loss  5.95 | ppl   385.41
| epoch   1 |  2000/ 2928 batches | lr 7.00 | ms/batch 16.03 | loss  5.94 | ppl   378.52
| epoch   1 |  2200/ 2928 batches | lr 7.00 | ms/batch 16.04 | loss  5.81 | ppl   331.97
| epoch   1 |  2400/ 
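The training loops keep the learning rate while validation loss improves on the best value so far and divide it by 1.75 otherwise. The stand-alone sketch below replays that rule on a list of validation losses; the function name and the loss values are hypothetical and are not taken from the runs in this notebook.

```python
def decay_on_plateau(val_losses, initial_lr, factor=1.75):
    """Replay the schedule used in the training loops.

    The learning rate is kept while validation loss improves on the best
    value so far and is divided by factor otherwise.
    Returns the learning rate in effect after each epoch.
    """
    best = float("inf")
    lr = initial_lr
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
        else:
            lr = lr / factor
        history.append(lr)
    return history

# Hypothetical validation losses: the third epoch fails to improve
schedule = decay_on_plateau([5.6, 5.4, 5.5, 5.3], initial_lr=7.0)
```

Here the rate drops from 7.0 to 4.0 after the third epoch and stays there, since the fourth epoch sets a new best.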

In [10]:
# Train the MoS Model
import copy

wikitext_MoS_model = transformer_model.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout, num_softmaxes).to(device)
optimizer = torch.optim.SGD(wikitext_MoS_model.parameters(), lr=lr)

best_val_loss = float("inf")
best_wikitext_MoS_model = None
learning_rate = lr

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(wikitext_MoS_model, train_data, optimizer, learning_rate)
    val_loss = evaluate(wikitext_MoS_model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                    val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot the weights; plain assignment would only alias the model that keeps training
        best_wikitext_MoS_model = copy.deepcopy(wikitext_MoS_model)
    else:
        learning_rate = learning_rate / 1.75
        for g in optimizer.param_groups:
            g["lr"] = learning_rate

| epoch   1 |   200/ 2928 batches | lr 7.00 | ms/batch 115.46 | loss  9.67 | ppl 15771.67
| epoch   1 |   400/ 2928 batches | lr 7.00 | ms/batch 114.91 | loss  7.65 | ppl  2096.44
| epoch   1 |   600/ 2928 batches | lr 7.00 | ms/batch 115.04 | loss  6.96 | ppl  1051.03
| epoch   1 |   800/ 2928 batches | lr 7.00 | ms/batch 115.10 | loss  6.48 | ppl   652.72
| epoch   1 |  1000/ 2928 batches | lr 7.00 | ms/batch 115.35 | loss  6.26 | ppl   521.74
| epoch   1 |  1200/ 2928 batches | lr 7.00 | ms/batch 115.37 | loss  6.19 | ppl   486.39
| epoch   1 |  1400/ 2928 batches | lr 7.00 | ms/batch 115.49 | loss  6.10 | ppl   447.44
| epoch   1 |  1600/ 2928 batches | lr 7.00 | ms/batch 115.46 | loss  6.09 | ppl   439.86
| epoch   1 |  1800/ 2928 batches | lr 7.00 | ms/batch 115.43 | loss  5.99 | ppl   400.12
| epoch   1 |  2000/ 2928 batches | lr 7.00 | ms/batch 115.50 | loss  5.97 | ppl   392.90
| epoch   1 |  2200/ 2928 batches | lr 7.00 | ms/batch 115.47 | loss  5.85 | ppl   346.28
| epoch   

## Evaluate the Models

In [11]:
# Evaluate Stock Model
test_loss = evaluate(best_wikitext_stock_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  5.33 | test ppl   205.98


In [12]:
# Evaluate MoS Model
test_loss = evaluate(best_wikitext_MoS_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  5.33 | test ppl   205.53
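The perplexities printed above are the exponential of the mean per-token loss, and `evaluate` weights each chunk's loss by its number of rows before averaging. A small sketch of that calculation follows; the helper name and the chunk losses are made up for illustration.

```python
import math

def corpus_perplexity(chunk_losses, chunk_lengths):
    """Combine per-chunk mean losses into one loss and perplexity.

    Each chunk's mean loss is weighted by its length, as evaluate() does,
    so a short final chunk does not count as much as a full one.
    """
    total = sum(loss * n for loss, n in zip(chunk_losses, chunk_lengths))
    mean_loss = total / sum(chunk_lengths)
    return mean_loss, math.exp(mean_loss)

# Made-up chunk losses: two full chunks of 35 rows and a short final chunk of 10
mean_loss, ppl = corpus_perplexity([5.40, 5.30, 5.00], [35, 35, 10])

# A loss of exactly 5.33 corresponds to a perplexity of roughly 206
ppl_from_rounded_loss = math.exp(5.33)
```

The printed perplexities are computed from the unrounded loss, which is why two runs that both display a loss of 5.33 can report slightly different perplexities.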


# Penn Treebank

The second dataset we model is the [Penn Treebank](https://catalog.ldc.upenn.edu/LDC99T42).

In [13]:
# Importing relevant packages
from torchtext.datasets import PennTreebank
from data_import import pentree_bank

## Preparing Data

In [14]:
# Establishing hyperparameters
batch_size = 20
eval_batch_size = 10
chunk_length = 35


In [15]:
# Setting up vocabulary
train_iter, valid_iter, test_iter = PennTreebank()
corpus = pentree_bank.Corpus(train_iter, valid_iter, test_iter)

ntokens = len(corpus.dictionary) # the size of vocabulary

In [16]:
# Batch data
train_data = batching.batchify(corpus.train, batch_size, device)
val_data = batching.batchify(corpus.valid, eval_batch_size, device)
test_data = batching.batchify(corpus.test, eval_batch_size, device)

## Train the Models

In [17]:
# Train Stock Model
import copy

penntree_stock_model = transformer_model.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout, 1).to(device)
optimizer = torch.optim.SGD(penntree_stock_model.parameters(), lr=lr)

best_val_loss = float("inf")
best_penntree_stock_model = None
learning_rate = lr

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(penntree_stock_model, train_data, optimizer, learning_rate)
    val_loss = evaluate(penntree_stock_model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                    val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot the weights; plain assignment would only alias the model that keeps training
        best_penntree_stock_model = copy.deepcopy(penntree_stock_model)
    else:
        learning_rate = learning_rate / 1.75
        for g in optimizer.param_groups:
            g["lr"] = learning_rate

| epoch   1 |   200/ 1388 batches | lr 7.00 | ms/batch  9.46 | loss  6.98 | ppl  1078.86
| epoch   1 |   400/ 1388 batches | lr 7.00 | ms/batch  9.37 | loss  5.93 | ppl   376.72
| epoch   1 |   600/ 1388 batches | lr 7.00 | ms/batch  9.49 | loss  5.66 | ppl   288.54
| epoch   1 |   800/ 1388 batches | lr 7.00 | ms/batch  9.48 | loss  5.50 | ppl   245.16
| epoch   1 |  1000/ 1388 batches | lr 7.00 | ms/batch  9.43 | loss  5.39 | ppl   218.21
| epoch   1 |  1200/ 1388 batches | lr 7.00 | ms/batch  9.39 | loss  5.28 | ppl   197.20
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 13.61s | valid loss  5.28 | valid ppl   195.46
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1388 batches | lr 7.00 | ms/batch  9.49 | loss  5.19 | ppl   179.26
| epoch   2 |   400/ 1388 batches | lr 7.00 | ms/batch  9.43 | loss  5.12 | ppl   167.63
| epoch   2 |   600/ 1388 batches 

In [18]:
# Train the MoS Model
import copy

penntree_MoS_model = transformer_model.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout, num_softmaxes).to(device)
optimizer = torch.optim.SGD(penntree_MoS_model.parameters(), lr=lr)

best_val_loss = float("inf")
best_penntree_MoS_model = None
learning_rate = lr

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train(penntree_MoS_model, train_data, optimizer, learning_rate)
    val_loss = evaluate(penntree_MoS_model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                    val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot the weights; plain assignment would only alias the model that keeps training
        best_penntree_MoS_model = copy.deepcopy(penntree_MoS_model)
    else:
        learning_rate = learning_rate / 1.75
        for g in optimizer.param_groups:
            g["lr"] = learning_rate

| epoch   1 |   200/ 1388 batches | lr 7.00 | ms/batch 43.37 | loss  9.02 | ppl  8254.13
| epoch   1 |   400/ 1388 batches | lr 7.00 | ms/batch 43.36 | loss  6.73 | ppl   833.80
| epoch   1 |   600/ 1388 batches | lr 7.00 | ms/batch 43.11 | loss  6.14 | ppl   462.39
| epoch   1 |   800/ 1388 batches | lr 7.00 | ms/batch 43.12 | loss  5.74 | ppl   311.25
| epoch   1 |  1000/ 1388 batches | lr 7.00 | ms/batch 43.10 | loss  5.54 | ppl   254.95
| epoch   1 |  1200/ 1388 batches | lr 7.00 | ms/batch 43.07 | loss  5.42 | ppl   224.89
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 61.62s | valid loss  5.34 | valid ppl   207.97
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1388 batches | lr 7.00 | ms/batch 43.35 | loss  5.32 | ppl   203.38
| epoch   2 |   400/ 1388 batches | lr 7.00 | ms/batch 43.22 | loss  5.26 | ppl   192.88
| epoch   2 |   600/ 1388 batches 

## Evaluate the Models

In [19]:
# Evaluate Stock Model
test_loss = evaluate(best_penntree_stock_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  4.88 | test ppl   131.62


In [20]:
# Evaluate MoS Model
test_loss = evaluate(best_penntree_MoS_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  4.93 | test ppl   138.42


With only three epochs, the stock and MoS models perform about the same: test perplexity is 205.98 versus 205.53 on WikiText-2 and 131.62 versus 138.42 on Penn Treebank. If we train the models for many more epochs, say 50, the MoS model achieves a noticeably lower perplexity (those runs are not shown in this notebook), which suggests that it is useful in the encoder-only Transformer architecture.