Code to replicate transformer architecture, from pytorch tutorial

In [None]:
# https://arxiv.org/pdf/1706.03762.pdf
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModel(nn.Module):

    def __init__(self, vocab_size, d_model, nhead, nhid, nlayers, device, dropout=0.5):
        super(TransformerModel, self).__init__()
        
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = TransformerEncoderLayer(d_model, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, vocab_size)

        self.init_weights()

        self.device = device
        self.to(self.device)

    def _generate_square_subsequent_mask(self, sz):
        src_mask = torch.tril(torch.ones(sz, sz))
        mask = src_mask.bool()
        src_mask = src_mask.masked_fill(mask, 0.0).masked_fill(~mask, float('-inf'))
        return src_mask

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, input):
        # input is of shape [bptt_len, batch_size]

        if self.src_mask is None or self.src_mask.size(0) != len(input):
            mask = self._generate_square_subsequent_mask(len(input)).to(device)
            self.src_mask = mask

        input = input.to(self.device)

        # refer paper
        # In the embedding layers, weights are multiplied by sqrt(d_model)
        embedded = self.embed(input) * math.sqrt(self.d_model)

        embedded = self.pos_encoder(embedded)
        # embedded is of shape [bptt_len, batch_size, d_model]

        output = self.transformer_encoder(embedded, self.src_mask)
        # output is of shape [bptt_len, batch_size, d_model]

        output = self.linear(output)
        # output shape is [bptt, batch_size, vocab_size]

        return output

In [2]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, seq_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(seq_len, d_model)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        i = torch.arange(0, d_model, 2).float()
        pe[:, 0::2] = torch.sin(position / torch.pow(10000, i / d_model))
        pe[:, 1::2] = torch.cos(position / torch.pow(10000, i / d_model))
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [3]:
import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.data import BPTTIterator
TEXT = torchtext.data.Field(tokenize=get_tokenizer("spacy"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 20
eval_batch_size = batch_size
bptt = 35

train_data, val_data, test_data = BPTTIterator.splits((train_txt, val_txt, test_txt), batch_size=batch_size, bptt_len=bptt, device=device)

In [4]:
ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, device, dropout)

In [5]:
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

import time
def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    ntokens = len(TEXT.vocab.stoi)
    for k, batch in enumerate(train_data):
        data, targets = batch.text, batch.target
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if k % log_interval == 0 and k > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/k {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, k, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()


def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    ntokens = len(TEXT.vocab.stoi)
    with torch.no_grad():
        for k, batch in enumerate(data_source):
            data, targets = batch.text, batch.target
            output = eval_model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += criterion(output_flat, targets.view(-1)).item()
    return total_loss / (len(data_source) - 1)

In [6]:
best_val_loss = float("inf")
epochs = 3 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model

    scheduler.step()

# output: torch.Size([35, 20, 200])



| epoch   1 |   200/   91 batches | lr 5.00 | ms/k 40.40 | loss  7.37 | ppl  1584.59
| epoch   1 |   400/   91 batches | lr 5.00 | ms/k 37.04 | loss  6.29 | ppl   539.04
| epoch   1 |   600/   91 batches | lr 5.00 | ms/k 37.12 | loss  5.97 | ppl   392.73
| epoch   1 |   800/   91 batches | lr 5.00 | ms/k 37.25 | loss  5.81 | ppl   332.01
| epoch   1 |  1000/   91 batches | lr 5.00 | ms/k 37.40 | loss  5.82 | ppl   336.10
| epoch   1 |  1200/   91 batches | lr 5.00 | ms/k 37.41 | loss  5.77 | ppl   318.97
| epoch   1 |  1400/   91 batches | lr 5.00 | ms/k 37.53 | loss  5.71 | ppl   302.11
| epoch   1 |  1600/   91 batches | lr 5.00 | ms/k 37.59 | loss  5.61 | ppl   272.73
| epoch   1 |  1800/   91 batches | lr 5.00 | ms/k 37.52 | loss  5.64 | ppl   280.33
| epoch   1 |  2000/   91 batches | lr 5.00 | ms/k 37.53 | loss  5.64 | ppl   280.91
| epoch   1 |  2200/   91 batches | lr 5.00 | ms/k 37.52 | loss  5.58 | ppl   266.32
| epoch   1 |  2400/   91 batches | lr 5.00 | ms/k 37.53 | loss  

In [7]:
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  4.80 | test ppl   120.95


In [8]:
with torch.no_grad():

    batch = next(iter(test_data))
    data, targets = batch.text, batch.target

    t = targets.view(bptt, -1)

    output = model(data)

    decoded = []

    topv, topi = output.topk(1)
    topi = topi.squeeze(1)

    decoder_input = topi.detach()

    for i in range(5):
        true = [TEXT.vocab.itos[w] for w in t[:, i]]
        decoded = [TEXT.vocab.itos[w] for w in decoder_input[:, i, 0]]

        print(true)
        print(decoded)
        print()


['<eos>', ' ', '=', 'robert', '<', 'unk', '>', '=', '<eos>', ' ', '<eos>', ' ', 'robert', '<', 'unk', '>', 'is', 'an', 'english', 'film', ',', 'television', 'and', 'theatre', 'actor', '.', 'he', 'had', 'a', 'guest', '@-@', 'starring', 'role', 'on', 'the']
['<eos>', ' ', '<eos>', '=', '<', 'unk', '>', ',', '=', ' ', '<eos>', ' ', '<eos>', '<', 'unk', '>', ',', 'a', 'asteroid', ',', 'was', 'and', 'series', 'the', ',', 'of', '<eos>', 'was', 'been', '<', 'stars', '<', '<', 'of', 'a']

[',', 'the', 'barbettes', 'and', 'their', 'supporting', 'structures', 'were', 'removed', 'beginning', 'in', 'early', '1943', 'and', 'the', 'openings', 'in', 'the', 'middle', 'deck', 'were', 'covered', 'by', '152', 'mm', 'plates', 'salvaged', 'from', 'the', 'turret', 'armour', '.', 'all', 'of', 'the']
['and', 'and', '766th', 'were', 'the', '<', 'the', 'of', 'killed', 'from', 'of', 'a', 'years', '.', 'the', '766th', 'of', 'a', '<', 'ages', '.', 'killed', 'by', 'the', '@,@', ')', 'of', 'to', 'the', '<', '.', 'of