In [1]:
import copy
import math
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

import torchtext
from torchtext.data.utils import get_tokenizer

In [2]:
class Transformer(nn.Module):
    """Encoder-only transformer language model.

    Pipeline: token embedding (scaled by sqrt(num_inputs)) -> positional
    encoding -> stack of ``TransformerEncoderLayer`` with a causal mask ->
    linear projection to vocabulary logits.

    Args:
        num_token: vocabulary size (embedding rows and output width).
        num_inputs: embedding / model dimension.
        num_heads: attention heads per encoder layer.
        num_hidden: feed-forward width inside each encoder layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability for the positional encoding and encoder layers.
    """

    def __init__(self, num_token, num_inputs, num_heads, num_hidden, num_layers, dropout=0.3):
        super(Transformer, self).__init__()
        self.model_name = 'transformer'
        # Lazily built causal mask, cached between forward calls.
        self.mask_source = None
        self.position_enc = PosEnc(num_inputs, dropout)
        layers_enc = TransformerEncoderLayer(num_inputs, num_heads, num_hidden, dropout)
        self.enc_transformer = TransformerEncoder(layers_enc, num_layers)
        self.enc = nn.Embedding(num_token, num_inputs)
        self.num_inputs = num_inputs
        self.dec = nn.Linear(num_inputs, num_token)
        self.init_params()

    def _gen_sqr_nxt_mask(self, size, device=None):
        """Return a (size, size) additive causal mask.

        Entries on and below the diagonal are 0.0; entries above it are -inf,
        so position i cannot attend to positions > i.

        Args:
            size: sequence length.
            device: optional device on which to create the mask.
        """
        # triu(diagonal=1) keeps the strictly-upper triangle (-inf) and zeroes the rest.
        return torch.triu(torch.full((size, size), float('-inf'), device=device), diagonal=1)

    def init_params(self):
        """Uniformly initialise embedding and output weights; zero the output bias."""
        initial_rng = 0.12
        self.enc.weight.data.uniform_(-initial_rng, initial_rng)
        self.dec.bias.data.zero_()
        self.dec.weight.data.uniform_(-initial_rng, initial_rng)

    def forward(self, source):
        """Compute vocabulary logits for every position of ``source``.

        Args:
            source: integer token ids; its first dimension is used as the
                sequence length for the causal mask.

        Returns:
            The output of the final linear layer applied to the encoder output.
        """
        seq_len = len(source)
        cached = self.mask_source
        # Rebuild the mask when the length changes OR when the cached mask sits
        # on a different device than the input (e.g. after model.to(...)).
        if cached is None or cached.size(0) != seq_len or cached.device != source.device:
            self.mask_source = self._gen_sqr_nxt_mask(seq_len, device=source.device)

        source = self.enc(source) * math.sqrt(self.num_inputs)
        source = self.position_enc(source)
        op = self.enc_transformer(source, self.mask_source)
        op = self.dec(op)
        return op

In [3]:
class PosEnc(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    Even feature columns hold sin(pos * freq), odd columns hold cos(pos * freq).
    The table is stored as the non-trainable buffer ``p_enc`` with shape
    (size_limit, 1, d_m).

    Args:
        d_m: feature dimension of the inputs (even or odd).
        dropout: dropout probability applied after adding the encoding.
        size_limit: number of positions precomputed in the table.
    """

    def __init__(self, d_m, dropout=0.2, size_limit=5000):
        super(PosEnc, self).__init__()
        self.dropout = nn.Dropout(dropout)
        pos = torch.arange(0, size_limit, dtype=torch.float).unsqueeze(1)
        divider = torch.exp(torch.arange(0, d_m, 2).float() * (-math.log(10000.0) / d_m))
        angles = pos * divider
        p_enc = torch.zeros(size_limit, d_m)
        p_enc[:, 0::2] = torch.sin(angles)
        # For odd d_m there is one fewer odd column than even columns, so the
        # cosine term is truncated to fit instead of raising a shape error.
        p_enc[:, 1::2] = torch.cos(angles[:, :d_m // 2])
        # Insert a singleton dimension at index 1 so the table broadcasts over it.
        self.register_buffer('p_enc', p_enc.unsqueeze(1))

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions to ``x``, then apply dropout."""
        return self.dropout(x + self.p_enc[:x.size(0), :])

In [4]:
# Text field configured with the "basic_english" tokenizer, lowercasing, and
# '<sos>' / '<eos>' as the init / end-of-sequence tokens.
# NOTE(review): `torchtext.data.Field` is the legacy torchtext API; confirm the
# pinned torchtext version still exposes it at this path.
TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"), lower=True, eos_token='<eos>', init_token='<sos>')
# WikiText2 train / validation / test splits, processed through TEXT.
training_text, validation_text, testing_text = torchtext.datasets.WikiText2.splits(TEXT)
# The vocabulary is built from the training split only.
TEXT.build_vocab(training_text)
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def gen_batches(text_dataset, batch_size):
    """Numericalize a split and lay it out as ``batch_size`` parallel columns.

    The text of the split's first example is converted to ids with ``TEXT``.
    Trailing ids that do not fill a complete row across all columns are
    dropped, and the remainder is reshaped so each column is one contiguous
    slice of the original stream. The result is moved to ``device``.
    """
    token_ids = TEXT.numericalize([text_dataset.examples[0].text])
    # Number of time steps each column will hold.
    steps_per_column = token_ids.size(0) // batch_size
    # Keep only the ids that fit exactly into batch_size columns.
    usable_ids = token_ids.narrow(0, 0, steps_per_column * batch_size)
    # view -> (batch_size, steps); transpose -> (steps, batch_size).
    columns = usable_ids.view(batch_size, -1).t().contiguous()
    return columns.to(device)

# Batch sizes passed to gen_batches for the training split and for the
# validation / test splits respectively.
training_batch_size = 32
evaluation_batch_size = 16

# Batchify each split once up front; gen_batches also moves the result to `device`.
training_data = gen_batches(training_text, training_batch_size)
validation_data = gen_batches(validation_text, evaluation_batch_size)
testing_data = gen_batches(testing_text, evaluation_batch_size)

In [5]:
# Maximum number of time steps fed to the model in one chunk.
max_seq_len = 64


def return_batch(src, k):
    """Slice one (input, target) chunk out of ``src`` starting at row ``k``.

    Args:
        src: tensor whose first dimension is time.
        k: index of the first input row.

    Returns:
        (sequence_data, sequence_label): ``sequence_data`` is
        ``src[k : k + n]`` and ``sequence_label`` is the same window shifted
        one row forward and flattened to 1-D, where
        ``n = min(max_seq_len, len(src) - 1 - k)``.
    """
    # Leave at least one row after the window so every input has a target.
    sequence_length = min(max_seq_len, len(src) - 1 - k)
    sequence_data = src[k:k+sequence_length]
    # reshape (not view) so a non-contiguous `src` is flattened instead of raising.
    sequence_label = src[k+1:k+1+sequence_length].reshape(-1)
    return sequence_data, sequence_label

In [6]:
# Model hyperparameters.
num_tokens = len(TEXT.vocab.stoi) # vocabulary size
embedding_size = 256 # dimension of embedding layer
num_hidden_params = 256 # transformer encoder's hidden (feed forward) layer dimension
num_layers = 2 # num of transformer encoder layers within transformer encoder
num_heads = 2 # num of heads in (multi head) attention models
dropout = 0.25 # value (fraction) of dropout
# Training objective and optimisation settings.
loss_func = nn.CrossEntropyLoss()
lrate = 4.0 # learning rate
# Build the model and move it to the selected device.
transformer_model = Transformer(num_tokens, embedding_size, num_heads, num_hidden_params, num_layers, 
                                     dropout).to(device)
# Plain SGD over all model parameters.
optim_module = torch.optim.SGD(transformer_model.parameters(), lr=lrate)
# StepLR with step_size=1.0 and gamma=0.88; the training loop calls
# sched_module.step() once per epoch.
sched_module = torch.optim.lr_scheduler.StepLR(optim_module, 1.0, gamma=0.88)

In [7]:
def train_model(epoch=None):
    """Run one training pass over ``training_data``.

    Walks the data in chunks of ``max_seq_len`` time steps. For each chunk it
    computes the cross-entropy loss, back-propagates, clips the gradient norm
    to 0.6 and takes an optimizer step. Every 100 chunks it prints the mean
    loss and perplexity of that interval.

    Args:
        epoch: epoch number shown in the progress line. When omitted, the
            notebook-level loop variable ``ep`` is used, so existing
            ``train_model()`` calls keep working.
    """
    transformer_model.train()
    vocab_size = len(TEXT.vocab.stoi)
    log_every = 100
    chunks_per_epoch = len(training_data) // max_seq_len
    running_loss = 0.
    for b, i in enumerate(range(0, training_data.size(0) - 1, max_seq_len)):
        train_data_batch, train_label_batch = return_batch(training_data, i)
        optim_module.zero_grad()
        op = transformer_model(train_data_batch)
        loss_curr = loss_func(op.view(-1, vocab_size), train_label_batch)
        loss_curr.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(transformer_model.parameters(), 0.6)
        optim_module.step()

        running_loss += loss_curr.item()
        if b % log_every == 0 and b > 0:
            loss_interval = running_loss / log_every
            # Fall back to the global `ep` only when no epoch was passed in.
            shown_epoch = ep if epoch is None else epoch
            print(f"epoch {shown_epoch}, {b}/{chunks_per_epoch} batches, training loss {loss_interval:.2f}, training perplexity {math.exp(loss_interval):.2f}")
            running_loss = 0

def eval_model(eval_model_obj, eval_data_source):
    """Return the mean per-time-step loss of a model over a batchified split.

    The model is switched to eval mode and run without gradient tracking over
    consecutive chunks of ``max_seq_len`` steps. Each chunk's loss is weighted
    by its length, and the sum is divided by ``len(eval_data_source) - 1``.
    """
    eval_model_obj.eval()
    vocab_size = len(TEXT.vocab.stoi)
    last_start = eval_data_source.size(0) - 1
    weighted_loss = 0.
    with torch.no_grad():
        for start in range(0, last_start, max_seq_len):
            inputs, targets = return_batch(eval_data_source, start)
            logits = eval_model_obj(inputs).view(-1, vocab_size)
            # Weight by chunk length so a short final chunk is not over-counted.
            chunk_loss = loss_func(logits, targets).item()
            weighted_loss += len(inputs) * chunk_loss
    return weighted_loss / (len(eval_data_source) - 1)

In [8]:
# Train for `eps` epochs, validating after each one and keeping a snapshot of
# the model with the lowest validation loss seen so far.
min_validation_loss = float("inf")
eps = 50
best_model_so_far = None

for ep in range(1, eps + 1):
    train_model()
    validation_loss = eval_model(transformer_model, validation_data)
    print()
    print(f"epoch {ep:}, validation loss {validation_loss:.2f}, validation perplexity {math.exp(validation_loss):.2f}")
    print()

    if validation_loss < min_validation_loss:
        min_validation_loss = validation_loss
        # Deep-copy the model: a plain assignment only aliases the live model,
        # so later epochs would silently overwrite the "best" weights.
        best_model_so_far = copy.deepcopy(transformer_model)

    # Advance the learning-rate schedule once per epoch.
    sched_module.step()

epoch 1, 100/1018 batches, training loss 8.63, training perplexity 5614.45
epoch 1, 200/1018 batches, training loss 7.23, training perplexity 1380.31
epoch 1, 300/1018 batches, training loss 6.79, training perplexity 892.50
epoch 1, 400/1018 batches, training loss 6.55, training perplexity 701.84
epoch 1, 500/1018 batches, training loss 6.45, training perplexity 634.57
epoch 1, 600/1018 batches, training loss 6.32, training perplexity 553.86
epoch 1, 700/1018 batches, training loss 6.24, training perplexity 513.65
epoch 1, 800/1018 batches, training loss 6.13, training perplexity 459.07
epoch 1, 900/1018 batches, training loss 6.11, training perplexity 450.48
epoch 1, 1000/1018 batches, training loss 6.07, training perplexity 433.88

epoch 1, validation loss 5.82, validation perplexity 337.70

epoch 2, 100/1018 batches, training loss 5.98, training perplexity 395.15
epoch 2, 200/1018 batches, training loss 5.90, training perplexity 363.99
epoch 2, 300/1018 batches, training loss 5.83, 

epoch 11, 300/1018 batches, training loss 4.70, training perplexity 109.89
epoch 11, 400/1018 batches, training loss 4.70, training perplexity 109.53
epoch 11, 500/1018 batches, training loss 4.70, training perplexity 109.91
epoch 11, 600/1018 batches, training loss 4.70, training perplexity 110.13
epoch 11, 700/1018 batches, training loss 4.72, training perplexity 112.66
epoch 11, 800/1018 batches, training loss 4.57, training perplexity 96.17
epoch 11, 900/1018 batches, training loss 4.64, training perplexity 103.05
epoch 11, 1000/1018 batches, training loss 4.68, training perplexity 107.53

epoch 11, validation loss 5.07, validation perplexity 159.83

epoch 12, 100/1018 batches, training loss 4.72, training perplexity 112.17
epoch 12, 200/1018 batches, training loss 4.65, training perplexity 104.32
epoch 12, 300/1018 batches, training loss 4.65, training perplexity 104.52
epoch 12, 400/1018 batches, training loss 4.65, training perplexity 104.79
epoch 12, 500/1018 batches, training 

epoch 21, 500/1018 batches, training loss 4.43, training perplexity 83.93
epoch 21, 600/1018 batches, training loss 4.43, training perplexity 84.29
epoch 21, 700/1018 batches, training loss 4.46, training perplexity 86.58
epoch 21, 800/1018 batches, training loss 4.31, training perplexity 74.16
epoch 21, 900/1018 batches, training loss 4.38, training perplexity 80.12
epoch 21, 1000/1018 batches, training loss 4.41, training perplexity 82.29

epoch 21, validation loss 5.04, validation perplexity 153.79

epoch 22, 100/1018 batches, training loss 4.48, training perplexity 88.31
epoch 22, 200/1018 batches, training loss 4.41, training perplexity 82.67
epoch 22, 300/1018 batches, training loss 4.42, training perplexity 83.34
epoch 22, 400/1018 batches, training loss 4.42, training perplexity 83.06
epoch 22, 500/1018 batches, training loss 4.42, training perplexity 83.42
epoch 22, 600/1018 batches, training loss 4.42, training perplexity 83.23
epoch 22, 700/1018 batches, training loss 4.45, 

epoch 31, 800/1018 batches, training loss 4.26, training perplexity 70.57
epoch 31, 900/1018 batches, training loss 4.33, training perplexity 76.10
epoch 31, 1000/1018 batches, training loss 4.35, training perplexity 77.66

epoch 31, validation loss 5.01, validation perplexity 149.84

epoch 32, 100/1018 batches, training loss 4.43, training perplexity 84.29
epoch 32, 200/1018 batches, training loss 4.37, training perplexity 79.22
epoch 32, 300/1018 batches, training loss 4.38, training perplexity 79.67
epoch 32, 400/1018 batches, training loss 4.37, training perplexity 79.27
epoch 32, 500/1018 batches, training loss 4.38, training perplexity 79.94
epoch 32, 600/1018 batches, training loss 4.38, training perplexity 79.58
epoch 32, 700/1018 batches, training loss 4.41, training perplexity 82.00
epoch 32, 800/1018 batches, training loss 4.25, training perplexity 70.27
epoch 32, 900/1018 batches, training loss 4.33, training perplexity 75.91
epoch 32, 1000/1018 batches, training loss 4.36,


epoch 41, validation loss 4.99, validation perplexity 147.50

epoch 42, 100/1018 batches, training loss 4.44, training perplexity 84.77
epoch 42, 200/1018 batches, training loss 4.37, training perplexity 79.33
epoch 42, 300/1018 batches, training loss 4.38, training perplexity 79.99
epoch 42, 400/1018 batches, training loss 4.38, training perplexity 79.87
epoch 42, 500/1018 batches, training loss 4.38, training perplexity 80.10
epoch 42, 600/1018 batches, training loss 4.38, training perplexity 79.91
epoch 42, 700/1018 batches, training loss 4.41, training perplexity 82.19
epoch 42, 800/1018 batches, training loss 4.26, training perplexity 70.55
epoch 42, 900/1018 batches, training loss 4.33, training perplexity 75.83
epoch 42, 1000/1018 batches, training loss 4.36, training perplexity 78.10

epoch 42, validation loss 4.99, validation perplexity 147.36

epoch 43, 100/1018 batches, training loss 4.44, training perplexity 85.17
epoch 43, 200/1018 batches, training loss 4.38, training pe

In [9]:
# Evaluate the model selected during training on the test split and report
# its loss together with the corresponding perplexity (exp of the loss).
testing_loss = eval_model(best_model_so_far, testing_data)
print(f"testing loss {testing_loss:.2f}, testing perplexity {math.exp(testing_loss):.2f}")

testing loss 4.92, testing perplexity 136.85


In [10]:
# Persist only the parameters/buffers (state_dict), not the pickled module object.
mdl_pth = './transformer.pth'
torch.save(best_model_so_far.state_dict(), mdl_pth)

In [11]:
# Load the best trained model: rebuild the architecture with the same
# hyperparameters, then restore the saved weights into it.
transformer_cached = Transformer(num_tokens, embedding_size, num_heads, num_hidden_params, num_layers,
                                 dropout).to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine can also be loaded on a CPU-only one.
transformer_cached.load_state_dict(torch.load(mdl_pth, map_location=device))

<All keys matched successfully>

In [108]:
# Greedy text generation: repeatedly predict a word and insert it just before
# the trailing '_' placeholder.
ln = 10
sntc = 'It will _'
sntc_split = sntc.split()
torch.manual_seed(799)
# A freshly constructed module is in training mode, which keeps dropout active;
# switch to eval mode so inference is deterministic.
transformer_cached.eval()
with torch.no_grad():
    for i in range(ln):
        sntc = ' '.join(sntc_split)
        txt_ds = TEXT.numericalize([sntc_split])
        txt_ds = txt_ds.view(1, -1).t().contiguous().to(device)
        # NOTE(review): with k = i + 1 return_batch yields a window of length 1
        # (the word right before '_'), so each prediction is conditioned on a
        # single token rather than on the whole prefix. Confirm this is intended.
        ev_X, _ = return_batch(txt_ds, i+1)
        op = transformer_cached(ev_X)
        op_flat = op.view(-1, num_tokens)
        res = TEXT.vocab.itos[op_flat.argmax(1)[0].item()]
        sntc_split.insert(-1, res)
# `sntc` was joined before the final insert; [:-2] strips the trailing ' _'.
print(sntc[:-2])

It will be used to the first season , and the
