In [1]:
# Import Statements
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as fxnl

from torch.nn import TransformerEncoder, TransformerEncoderLayer

import torchtext as t_txt

from torchtext.data.utils import get_tokenizer

In [2]:
#  AIM: Train a sequence to sequence model.
# Define the model.

class TransformerModel(nn.Module):

    def __init__(self, n_token, n_inp, n_head, n_hid, n_layers, drop_out=0.5):
        super(TransformerModel, self).__init__()

        self.model_type = 'Transformer'
        self.source_mask = None
        self.positional_encoder = PositionalEncoding(n_inp, drop_out)

        encoder_layers = TransformerEncoderLayer(n_inp, n_head, n_hid, drop_out)

        self.transformer_encoder = TransformerEncoder(encoder_layer=encoder_layers, num_layers=n_layers)
        self.encoder = nn.Embedding(n_token, n_inp)
        self.n_inp = n_inp
        self.decoder = nn.Linear(n_inp, n_token)

        self.initialize_weights()
    
    def generate_square_subsequent_mask(self, size):
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def initialize_weights(self):
        initial_range = 0.1
        
        self.encoder.weight.data.uniform_(-initial_range, initial_range)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initial_range, initial_range)

    def forward(self, source):
        if self.source_mask is None or self.source_mask.size(0) != len(source):
            device = source.device
            mask = self.generate_square_subsequent_mask(size=len(source)).to(device)
            self.source_mask = mask
        
        source = self.encoder(source) * math.sqrt(self.n_inp)
        source = self.positional_encoder(source)
        out_put = self.transformer_encoder(source, self.source_mask)
        out_put = self.decoder(out_put)
        return out_put


In [3]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, drop_out=0.1, max_length=5000):
        super(PositionalEncoding, self).__init__()
        self.drop_out = nn.Dropout(p=drop_out)

        pe = torch.zeros(max_length, d_model)
        position = torch.arange(start=0, end=max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)

        self.register_buffer('pe', pe)
    
    def forward(self, input_tensor):
        input_tensor = input_tensor + self.pe[:input_tensor.size(0), :]
        return self.drop_out(input_tensor)

In [4]:
# Load and Batch Data

TEXT_DATA = t_txt.data.Field(tokenize=get_tokenizer("basic_english"), init_token='<sos>', eos_token='<eos', lower=True)

train_txt_dataset, valid_txt_dataset, test_txt_dataset = t_txt.datasets.WikiText2.splits(TEXT_DATA)
TEXT_DATA.build_vocab(train_txt_dataset)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, batches):
    data = TEXT_DATA.numericalize([data.examples[0].text])

    num_batch = data.size(0) // batch_size
    
    data = data.narrow(0, 0, num_batch * batches)
    data = data.view(batch_size, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_dataset = batchify(train_txt_dataset, batch_size)
valid_dataset = batchify(valid_txt_dataset, eval_batch_size)
test_dataset = batchify(test_txt_dataset, eval_batch_size)

In [5]:
# Functions to generate input and target sequence
bptt = 35
def get_batch(source, index):
    sequence_length = min(bptt, len(source) - 1 - index)
    data = source[index: index + sequence_length]
    target = source[index + 1: sequence_length + 1 + index].view(-1)
    return data, target

In [6]:
# Initiate an instance
# Hyper-Parameters

num_tokens = len(TEXT_DATA.vocab.stoi) # Size of the vocabulary
embedding_size = 200 # Embedding dimension/size
num_hidden = 200 # Dimension of the feed forward network model in nn.TransformerEncoder
num_layers = 2 # Number of nn.TranformerEncoderLayer
num_head = 2 # Number of heads in the multi-head attention models
drop_out = 0.2 # Drop out value
model = TransformerModel(num_tokens, embedding_size, num_head, num_hidden, num_layers, drop_out).to(device)

In [7]:
# Run model

loss_function = nn.CrossEntropyLoss()
learning_rate = 5.0
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)




In [8]:
# Train Function/ Method

def train():
    model.train() # Turn on train mode

    total_loss = 0.0 
    start_time = time.time()
    num_tokens = len(TEXT_DATA.vocab.stoi)

    for batch, idx in enumerate(range(0, train_dataset.size(0) - 1, bptt)):
        data, targets = get_batch(train_dataset, idx)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output.view(-1, num_tokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200

        if batch % log_interval == 0 and batch > 0:
            current_loss = total_loss / log_interval
            elapsed_time = time.time() - start_time

            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_dataset) // bptt, scheduler.get_lr()[0],
                    elapsed_time * 1000 / log_interval,
                    current_loss, math.exp(current_loss)))
            
            total_loss = 0
            start_time = time.time()


In [9]:
def evaluate_model(eval_model, data_source):
    eval_model.eval()
    total_loss = 0
    num_tokens = len(TEXT_DATA.vocab.stoi)

    with torch.no_grad():
        for idx in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, idx)
            output = eval_model(data)
            output_flat = output.view(-1, num_tokens)
            total_loss += len(data) * loss_function(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [10]:
best_validation_loss = float("inf")
epochs = 3
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    validation_loss = evaluate_model(model, valid_dataset)
    print("-" * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ' 'valid ppl {:8.2f}'.format(epoch,     (time.time() - epoch_start_time), validation_loss, math.exp(validation_loss)))
    print('-' * 89)

    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        best_model = model
    
    scheduler.step()


| epoch   1 |   200/ 2981 batches | lr 5.00 | ms/batch 70.81 | loss  8.00 | ppl  2982.90
| epoch   1 |   400/ 2981 batches | lr 5.00 | ms/batch 64.64 | loss  6.77 | ppl   872.78
| epoch   1 |   600/ 2981 batches | lr 5.00 | ms/batch 63.71 | loss  6.35 | ppl   575.07
| epoch   1 |   800/ 2981 batches | lr 5.00 | ms/batch 62.70 | loss  6.23 | ppl   505.45
| epoch   1 |  1000/ 2981 batches | lr 5.00 | ms/batch 65.03 | loss  6.12 | ppl   453.19
| epoch   1 |  1200/ 2981 batches | lr 5.00 | ms/batch 63.57 | loss  6.08 | ppl   436.93
| epoch   1 |  1400/ 2981 batches | lr 5.00 | ms/batch 64.17 | loss  6.04 | ppl   421.66
| epoch   1 |  1600/ 2981 batches | lr 5.00 | ms/batch 64.66 | loss  6.05 | ppl   425.42
| epoch   1 |  1800/ 2981 batches | lr 5.00 | ms/batch 66.05 | loss  5.95 | ppl   382.88
| epoch   1 |  2000/ 2981 batches | lr 5.00 | ms/batch 69.58 | loss  5.96 | ppl   389.01
| epoch   1 |  2200/ 2981 batches | lr 5.00 | ms/batch 66.22 | loss  5.84 | ppl   345.02
| epoch   1 |  2400/ 

In [11]:
test_loss = evaluate_model(best_model, test_dataset)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  5.38 | test ppl   218.08
