In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output

In [2]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [3]:
import torch
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])

def data_process(raw_text_iter):
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
        
inv_map = {v: k for k, v in vocab.get_stoi().items()}

def tensor_to_tokens(my_tensor):
    x = [int(t) for t in my_tensor]
    return [inv_map[t] for t in x]
    

train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [5]:
bptt = 100
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

data, targets = get_batch(train_data, 0)
print(data.shape, targets.shape)

torch.Size([100, 20]) torch.Size([2000])


In [7]:
ntokens = len(vocab) # the size of vocabulary
emsize = 512 # embedding dimension
nhid = 512 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 4 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

In [55]:
data, targets = get_batch(train_data, i)
data

tensor([[ 3869,    19,     1,  ...,  1687,    19, 25204],
        [   21, 17322,  4347,  ...,    50, 14545,     2],
        [  780, 10360,     4,  ...,     8,    26,     5],
        ...,
        [19009,    47,    11,  ...,     0,  2452, 16017],
        [   56,    61,    15,  ...,     3,    41, 16805],
        [    1,   619,     5,  ...,    43,    10,    16]], device='cuda:0')

In [9]:
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
src, targets = get_batch(train_data, 0)
print("data ", data.shape)
src = model.encoder(src) * math.sqrt(model.ninp)
print("encoded ", src.shape)
src = model.pos_encoder(src)
output = model.transformer_encoder(src, src_mask)
output = model.decoder(output)


data  torch.Size([100, 20])
encoded  torch.Size([100, 20, 512])


In [76]:
import time

criterion = nn.CrossEntropyLoss()
lr = 4.5*10**-4 # learning rate
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.96), eps=10**(-8), weight_decay=4.5**-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        if data.size(0) != bptt:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if data.size(0) != bptt:
                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [77]:
best_val_loss = float("inf")
epochs = 100 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model

    scheduler.step()

| epoch   1 |   200/ 1024 batches | lr 0.01 | ms/batch 37.63 | loss  7.01 | ppl  1107.98
| epoch   1 |   400/ 1024 batches | lr 0.01 | ms/batch 39.40 | loss  6.92 | ppl  1015.60
| epoch   1 |   600/ 1024 batches | lr 0.01 | ms/batch 38.23 | loss  6.91 | ppl  1006.70
| epoch   1 |   800/ 1024 batches | lr 0.01 | ms/batch 37.99 | loss  6.89 | ppl   983.69
| epoch   1 |  1000/ 1024 batches | lr 0.01 | ms/batch 39.69 | loss  6.87 | ppl   964.82
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 40.71s | valid loss  6.83 | valid ppl   925.20
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 1024 batches | lr 0.01 | ms/batch 37.82 | loss  6.75 | ppl   856.87
| epoch   2 |   400/ 1024 batches | lr 0.01 | ms/batch 37.80 | loss  6.76 | ppl   860.73
| epoch   2 |   600/ 1024 batches | lr 0.01 | ms/batch 38.35 | loss  6.75 | ppl   855.95
| epoch   2 |   800/ 1024 batches 

| epoch  13 |   200/ 1024 batches | lr 0.01 | ms/batch 38.20 | loss  6.69 | ppl   803.27
| epoch  13 |   400/ 1024 batches | lr 0.01 | ms/batch 38.98 | loss  6.67 | ppl   789.08
| epoch  13 |   600/ 1024 batches | lr 0.01 | ms/batch 38.36 | loss  6.68 | ppl   796.20
| epoch  13 |   800/ 1024 batches | lr 0.01 | ms/batch 39.28 | loss  6.67 | ppl   784.63
| epoch  13 |  1000/ 1024 batches | lr 0.01 | ms/batch 40.92 | loss  6.65 | ppl   770.41
-----------------------------------------------------------------------------------------
| end of epoch  13 | time: 41.22s | valid loss  7.03 | valid ppl  1133.43
-----------------------------------------------------------------------------------------
| epoch  14 |   200/ 1024 batches | lr 0.01 | ms/batch 37.71 | loss  6.69 | ppl   803.40
| epoch  14 |   400/ 1024 batches | lr 0.01 | ms/batch 37.65 | loss  6.67 | ppl   788.28
| epoch  14 |   600/ 1024 batches | lr 0.01 | ms/batch 38.73 | loss  6.68 | ppl   796.66
| epoch  14 |   800/ 1024 batches 

| epoch  25 |   200/ 1024 batches | lr 0.00 | ms/batch 37.01 | loss  6.70 | ppl   814.02
| epoch  25 |   400/ 1024 batches | lr 0.00 | ms/batch 39.81 | loss  6.68 | ppl   799.99
| epoch  25 |   600/ 1024 batches | lr 0.00 | ms/batch 39.00 | loss  6.70 | ppl   811.56
| epoch  25 |   800/ 1024 batches | lr 0.00 | ms/batch 39.41 | loss  6.68 | ppl   799.59
| epoch  25 |  1000/ 1024 batches | lr 0.00 | ms/batch 37.80 | loss  6.67 | ppl   784.77
-----------------------------------------------------------------------------------------
| end of epoch  25 | time: 40.71s | valid loss  6.99 | valid ppl  1090.65
-----------------------------------------------------------------------------------------
| epoch  26 |   200/ 1024 batches | lr 0.00 | ms/batch 37.67 | loss  6.70 | ppl   815.77
| epoch  26 |   400/ 1024 batches | lr 0.00 | ms/batch 39.04 | loss  6.69 | ppl   800.44
| epoch  26 |   600/ 1024 batches | lr 0.00 | ms/batch 37.45 | loss  6.70 | ppl   812.97
| epoch  26 |   800/ 1024 batches 

| epoch  37 |   200/ 1024 batches | lr 0.00 | ms/batch 38.61 | loss  6.73 | ppl   838.19
| epoch  37 |   400/ 1024 batches | lr 0.00 | ms/batch 39.18 | loss  6.71 | ppl   823.19
| epoch  37 |   600/ 1024 batches | lr 0.00 | ms/batch 39.27 | loss  6.73 | ppl   834.63
| epoch  37 |   800/ 1024 batches | lr 0.00 | ms/batch 39.44 | loss  6.71 | ppl   822.08
| epoch  37 |  1000/ 1024 batches | lr 0.00 | ms/batch 40.61 | loss  6.70 | ppl   809.44
-----------------------------------------------------------------------------------------
| end of epoch  37 | time: 41.64s | valid loss  6.95 | valid ppl  1038.70
-----------------------------------------------------------------------------------------
| epoch  38 |   200/ 1024 batches | lr 0.00 | ms/batch 38.27 | loss  6.73 | ppl   840.29
| epoch  38 |   400/ 1024 batches | lr 0.00 | ms/batch 38.30 | loss  6.72 | ppl   825.74
| epoch  38 |   600/ 1024 batches | lr 0.00 | ms/batch 39.83 | loss  6.73 | ppl   836.86
| epoch  38 |   800/ 1024 batches 

| epoch  49 |   200/ 1024 batches | lr 0.00 | ms/batch 38.20 | loss  6.78 | ppl   878.66
| epoch  49 |   400/ 1024 batches | lr 0.00 | ms/batch 38.72 | loss  6.76 | ppl   859.76
| epoch  49 |   600/ 1024 batches | lr 0.00 | ms/batch 39.18 | loss  6.77 | ppl   867.62
| epoch  49 |   800/ 1024 batches | lr 0.00 | ms/batch 38.37 | loss  6.76 | ppl   860.28
| epoch  49 |  1000/ 1024 batches | lr 0.00 | ms/batch 38.70 | loss  6.74 | ppl   845.27
-----------------------------------------------------------------------------------------
| end of epoch  49 | time: 40.76s | valid loss  6.90 | valid ppl   989.19
-----------------------------------------------------------------------------------------
| epoch  50 |   200/ 1024 batches | lr 0.00 | ms/batch 38.26 | loss  6.78 | ppl   882.59
| epoch  50 |   400/ 1024 batches | lr 0.00 | ms/batch 38.37 | loss  6.76 | ppl   862.94
| epoch  50 |   600/ 1024 batches | lr 0.00 | ms/batch 38.46 | loss  6.77 | ppl   871.03
| epoch  50 |   800/ 1024 batches 

| epoch  61 |   200/ 1024 batches | lr 0.00 | ms/batch 37.99 | loss  6.83 | ppl   927.20
| epoch  61 |   400/ 1024 batches | lr 0.00 | ms/batch 37.40 | loss  6.81 | ppl   902.59
| epoch  61 |   600/ 1024 batches | lr 0.00 | ms/batch 38.36 | loss  6.82 | ppl   913.47
| epoch  61 |   800/ 1024 batches | lr 0.00 | ms/batch 38.27 | loss  6.81 | ppl   902.97
| epoch  61 |  1000/ 1024 batches | lr 0.00 | ms/batch 37.30 | loss  6.79 | ppl   888.02
-----------------------------------------------------------------------------------------
| end of epoch  61 | time: 40.01s | valid loss  6.84 | valid ppl   930.01
-----------------------------------------------------------------------------------------
| epoch  62 |   200/ 1024 batches | lr 0.00 | ms/batch 39.39 | loss  6.84 | ppl   931.29
| epoch  62 |   400/ 1024 batches | lr 0.00 | ms/batch 40.05 | loss  6.81 | ppl   907.23
| epoch  62 |   600/ 1024 batches | lr 0.00 | ms/batch 37.67 | loss  6.82 | ppl   917.18
| epoch  62 |   800/ 1024 batches 

| epoch  73 |   200/ 1024 batches | lr 0.00 | ms/batch 40.15 | loss  6.88 | ppl   971.00
| epoch  73 |   400/ 1024 batches | lr 0.00 | ms/batch 38.06 | loss  6.86 | ppl   955.48
| epoch  73 |   600/ 1024 batches | lr 0.00 | ms/batch 38.36 | loss  6.86 | ppl   950.03
| epoch  73 |   800/ 1024 batches | lr 0.00 | ms/batch 38.79 | loss  6.85 | ppl   943.15
| epoch  73 |  1000/ 1024 batches | lr 0.00 | ms/batch 39.07 | loss  6.84 | ppl   934.95
-----------------------------------------------------------------------------------------
| end of epoch  73 | time: 40.96s | valid loss  6.76 | valid ppl   861.19
-----------------------------------------------------------------------------------------
| epoch  74 |   200/ 1024 batches | lr 0.00 | ms/batch 37.94 | loss  6.88 | ppl   974.61
| epoch  74 |   400/ 1024 batches | lr 0.00 | ms/batch 37.31 | loss  6.87 | ppl   958.72
| epoch  74 |   600/ 1024 batches | lr 0.00 | ms/batch 39.09 | loss  6.86 | ppl   952.91
| epoch  74 |   800/ 1024 batches 

| epoch  85 |   200/ 1024 batches | lr 0.00 | ms/batch 38.42 | loss  6.92 | ppl  1007.81
| epoch  85 |   400/ 1024 batches | lr 0.00 | ms/batch 40.00 | loss  6.90 | ppl   988.74
| epoch  85 |   600/ 1024 batches | lr 0.00 | ms/batch 40.06 | loss  6.89 | ppl   978.38
| epoch  85 |   800/ 1024 batches | lr 0.00 | ms/batch 38.94 | loss  6.88 | ppl   975.94
| epoch  85 |  1000/ 1024 batches | lr 0.00 | ms/batch 38.51 | loss  6.87 | ppl   966.09
-----------------------------------------------------------------------------------------
| end of epoch  85 | time: 41.26s | valid loss  6.71 | valid ppl   819.34
-----------------------------------------------------------------------------------------
| epoch  86 |   200/ 1024 batches | lr 0.00 | ms/batch 39.48 | loss  6.92 | ppl  1010.01
| epoch  86 |   400/ 1024 batches | lr 0.00 | ms/batch 38.35 | loss  6.90 | ppl   990.31
| epoch  86 |   600/ 1024 batches | lr 0.00 | ms/batch 37.30 | loss  6.89 | ppl   980.38
| epoch  86 |   800/ 1024 batches 

| epoch  97 |   200/ 1024 batches | lr 0.00 | ms/batch 40.12 | loss  6.93 | ppl  1024.70
| epoch  97 |   400/ 1024 batches | lr 0.00 | ms/batch 38.61 | loss  6.91 | ppl   999.84
| epoch  97 |   600/ 1024 batches | lr 0.00 | ms/batch 38.83 | loss  6.91 | ppl   998.58
| epoch  97 |   800/ 1024 batches | lr 0.00 | ms/batch 38.26 | loss  6.90 | ppl   987.71
| epoch  97 |  1000/ 1024 batches | lr 0.00 | ms/batch 38.98 | loss  6.89 | ppl   979.63
-----------------------------------------------------------------------------------------
| end of epoch  97 | time: 41.03s | valid loss  6.69 | valid ppl   807.71
-----------------------------------------------------------------------------------------
| epoch  98 |   200/ 1024 batches | lr 0.00 | ms/batch 38.10 | loss  6.93 | ppl  1025.24
| epoch  98 |   400/ 1024 batches | lr 0.00 | ms/batch 38.56 | loss  6.91 | ppl  1000.58
| epoch  98 |   600/ 1024 batches | lr 0.00 | ms/batch 39.47 | loss  6.91 | ppl   999.39
| epoch  98 |   800/ 1024 batches 

In [58]:
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

| End of training | test loss  5.51 | test ppl   246.12


In [32]:
def data_process(raw_text_iter):
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

train,test,example = WikiText2()
example = data_process(example)
example = batchify(example, batch_size)

In [34]:
x,y = get_batch(example, 4)
print("x", x.shape)
print("y", y.shape)

x torch.Size([35, 20])
y torch.Size([700])


In [35]:
src_mask = model.generate_square_subsequent_mask(bptt).to(device)
output = model(x.to(device), src_mask)
labels = torch.argmax(output, 2).view(-1)
correct_predictions = labels.eq(y)
print(correct_predictions.sum().float() / correct_predictions.nelement())

tensor(0.2086, device='cuda:0')


In [36]:
t = x[:, 12:13]
yy = y.reshape(x.size())[:, 0:1]


output = model(t.to(device), src_mask)
labels = torch.argmax(output, 2).view(-1)
correct_predictions = labels.eq(yy)
print(correct_predictions.sum().float() / correct_predictions.nelement())

tensor(0.0131, device='cuda:0')


In [75]:
def write_sentence(xx):
    sentence = ""
    for word in tensor_to_tokens(xx.reshape(-1)):
        sentence+= word +" "
    print(sentence)
    
def complete_sentence(xx, length, src_mask):
    sentence = ""
    for word in tensor_to_tokens(xx.reshape(-1)):
        sentence+= word +" "
    sentence += "||"
    # crete new tokens
    for _ in range(length):
        out = model(xx.to(device), src_mask)
        labels = torch.argmax(out, 2).view(-1)
        sentence += tensor_to_tokens(labels.reshape(-1))[-1]+" "
        xx = torch.cat((xx[0:], labels.reshape(-1)[-1:].reshape(1,1)))
        src_mask = model.generate_square_subsequent_mask(len(xx)).to(device)
    
    print(sentence)
    
    
t = "At the time of his marriage, William's father, John Yeats, was studying law, but would later pursue art studies at Heatherley School of Fine Art, in London. William's mother, Susan Mary Pollexfen, came from Sligo, from a wealthy merchant family, which owned a"
t = torch.tensor(vocab(tokenizer(t)))
src_mask = model.generate_square_subsequent_mask(len(t)).to(device)
t = t.reshape([-1, 1]).to(device)

complete_sentence(t, 100, src_mask)  

at the time of his marriage , william ' s father , john yeats , was studying law , but would later pursue art studies at <unk> school of fine art , in london . william ' s mother , susan mary <unk> , came from <unk> , from a wealthy merchant family , which owned a ||private property . he was married to elizabeth , who was born in <unk> , and died soon after his father died . = = = family = = = = = = = = = = = = = = = = early life = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = = 
