In [1]:
import copy
import math
import time

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import TransformerModel
import wikitext_data

# Data processing and model construction

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Build the corpus object from the project-local `wikitext_data` module.
# Later cells read `corpus.train`, `corpus.val`, `corpus.test` (token tensors)
# and `corpus.vocab` (`stoi` / `itos` lookups) from it.
# NOTE(review): `device` is handed to the constructor, presumably so the token
# tensors are created on that device -- confirm in wikitext_data.Corpus.
corpus = wikitext_data.Corpus(device)

In [4]:
# Window geometry: `in_seq_len` tokens are fed to the model and the following
# `out_seq_len` tokens are the prediction target.
in_seq_len, out_seq_len = 20, 5
# `stride` is forwarded to TextDataset; `batch_size` is forwarded to the DataLoaders.
stride, batch_size = 20, 20

In [5]:
# All three splits must be windowed identically, so the shared keyword
# arguments are written once instead of being copy-pasted per split.
# Each dataset item is an (input, target) pair: `in_seq_len` input tokens
# followed by the next `out_seq_len` tokens as the target.
_dataset_kwargs = dict(
    in_out_overlap=False,
    input_size=in_seq_len,
    seq_len=in_seq_len + out_seq_len,
    stride=stride,
)
train_data = wikitext_data.TextDataset(corpus.train, **_dataset_kwargs)
val_data = wikitext_data.TextDataset(corpus.val, **_dataset_kwargs)
test_data = wikitext_data.TextDataset(corpus.test, **_dataset_kwargs)

In [6]:
# One loader per split, all with the same batch size and in corpus order
# (no shuffling), so the first training batch is always the start of the corpus.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=False)
    for split in (train_data, val_data, test_data)
)

In [7]:
# Vocabulary size, read from the corpus' string-to-index table.
ntokens = len(corpus.vocab.stoi)

# Architecture hyper-parameters.
emsize = 200    # embedding dimension
nhid = 200      # width of the feed-forward sub-layer inside each transformer layer
nlayers = 2     # number of stacked transformer layers
nhead = 2       # number of attention heads
dropout = 0.2   # dropout probability

# Instantiate the network and move its parameters to the selected device.
model = (
    TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout)
    .to(device)
)

In [10]:
# Token-level cross-entropy with the default mean reduction.
criterion = nn.CrossEntropyLoss()
# NOTE(review): the logged training loss stays around 7.0 for the whole run and
# greedy decoding collapses to a single token; lr=5.0 may simply be too high
# for this setup -- worth confirming with a smaller value.
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# Decay the learning rate by 5% after every epoch. `step_size` is documented
# as an integer number of epochs, so pass 1 rather than the float 1.0.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

import time
def train():
    """Run one training epoch over `train_loader`, logging every 200 batches.

    Relies on notebook-level state: `model`, `optimizer`, `scheduler`,
    `criterion`, `train_loader`, `corpus`, `device`, `ntokens`, `out_seq_len`
    and the loop variable `epoch` (used only in the log line).

    Each batch arrives batch-first; it is transposed to sequence-first before
    being passed to the model. The target is shifted right with a leading
    `<sos>` token to form the decoder input (teacher forcing), and the
    unshifted target is what the loss is computed against.
    """
    model.train() # Turn on the train mode
    log_interval = 200
    num_batches = len(train_loader)
    sos_idx = corpus.vocab.stoi['<sos>']
    # Causal mask so decoder position t cannot attend to positions > t.
    trg_mask = model.generate_square_subsequent_mask(out_seq_len).to(device)

    running_loss = 0.
    batches_in_window = 0
    start_time = time.time()
    for batch, (data, target) in enumerate(train_loader):
        data = data.transpose(0, 1).contiguous()
        target = target.transpose(0, 1).contiguous()

        # Build the <sos> row directly with the target's dtype and device so
        # the concatenation never depends on implicit type promotion or on a
        # host-to-device copy.
        sos = torch.full((1, target.shape[1]), sos_idx,
                         dtype=target.dtype, device=target.device)
        target = torch.cat((sos, target), dim=0)

        trg_inp = target[:-1, :]            # decoder input: <sos> + target[:-1]
        trg = target[1:, :].reshape(-1)     # labels: the original target, flattened

        optimizer.zero_grad()
        output = model(data, trg_inp, src_mask=None, trg_mask=trg_mask)
        loss = criterion(output.view(-1, ntokens), trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        running_loss += loss.item()
        batches_in_window += 1
        if batch % log_interval == 0 and batch > 0:
            # Average over the batches actually accumulated: the first window
            # spans batches 0..log_interval inclusive, i.e. one more than
            # log_interval.
            cur_loss = running_loss / batches_in_window
            elapsed = time.time() - start_time
            # get_last_lr() reports the rate the optimizer is really using;
            # get_lr() is only meant to be called from inside scheduler.step().
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, num_batches, scheduler.get_last_lr()[0],
                    elapsed * 1000 / batches_in_window,
                    cur_loss, math.exp(cur_loss)))
            running_loss = 0.
            batches_in_window = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    """Return the mean per-token cross-entropy of `eval_model` on `data_source`.

    Args:
        eval_model: model to score; it is switched to eval mode and used both
            for the forward pass and to build the causal target mask.
        data_source: iterable of `(data, target)` batches (batch-first), e.g.
            one of the notebook's DataLoaders.

    Returns:
        float: loss averaged over every target token seen. Each batch's loss
        is weighted by its token count, so a shorter final batch does not skew
        the result (this assumes `criterion` uses mean reduction, as
        configured in this notebook).

    Raises:
        ValueError: if `data_source` yields no target tokens.
    """
    eval_model.eval() # Turn on the evaluation mode
    sos_idx = corpus.vocab.stoi['<sos>']
    # Take the mask from the model being evaluated, not the global `model`.
    trg_mask = eval_model.generate_square_subsequent_mask(out_seq_len).to(device)

    total_loss = 0.
    total_tokens = 0
    with torch.no_grad():
        for data, target in data_source:
            data = data.transpose(0, 1).contiguous()
            target = target.transpose(0, 1).contiguous()

            # <sos> row created with the target's dtype/device (see train()).
            sos = torch.full((1, target.shape[1]), sos_idx,
                             dtype=target.dtype, device=target.device)
            target = torch.cat((sos, target), dim=0)

            trg_inp = target[:-1, :]
            trg = target[1:, :].reshape(-1)
            output = eval_model(data, trg_inp, src_mask=None, trg_mask=trg_mask)
            output_flat = output.view(-1, ntokens)

            n_tokens = trg.numel()
            total_loss += criterion(output_flat, trg).item() * n_tokens
            total_tokens += n_tokens

    if total_tokens == 0:
        raise ValueError("evaluate() received no data to score")
    return total_loss / total_tokens

# Training

In [11]:
best_val_loss = float("inf")
epochs = 1 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_loader)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Snapshot the weights: a plain `best_model = model` would only alias
        # the live model, which keeps changing in subsequent epochs.
        best_model = copy.deepcopy(model)

    # Decay the learning rate once per epoch, after the optimizer updates.
    scheduler.step()

| epoch   1 |   200/ 5124 batches | lr 5.00 | ms/batch 15.23 | loss  7.08 | ppl  1193.70
| epoch   1 |   400/ 5124 batches | lr 5.00 | ms/batch 14.79 | loss  6.99 | ppl  1086.13
| epoch   1 |   600/ 5124 batches | lr 5.00 | ms/batch 14.74 | loss  6.95 | ppl  1038.35
| epoch   1 |   800/ 5124 batches | lr 5.00 | ms/batch 14.62 | loss  7.04 | ppl  1136.72
| epoch   1 |  1000/ 5124 batches | lr 5.00 | ms/batch 14.46 | loss  7.01 | ppl  1108.74
| epoch   1 |  1200/ 5124 batches | lr 5.00 | ms/batch 14.50 | loss  6.95 | ppl  1038.07
| epoch   1 |  1400/ 5124 batches | lr 5.00 | ms/batch 14.72 | loss  6.98 | ppl  1070.61
| epoch   1 |  1600/ 5124 batches | lr 5.00 | ms/batch 14.99 | loss  6.90 | ppl   991.79
| epoch   1 |  1800/ 5124 batches | lr 5.00 | ms/batch 15.22 | loss  6.98 | ppl  1074.23
| epoch   1 |  2000/ 5124 batches | lr 5.00 | ms/batch 14.45 | loss  6.94 | ppl  1027.74
| epoch   1 |  2200/ 5124 batches | lr 5.00 | ms/batch 14.75 | loss  6.95 | ppl  1038.82
| epoch   1 |  2400/ 

# Evaluating

In [12]:
# Grab the first training batch for inspection. The loaders do not shuffle,
# so this is always the beginning of the training corpus.
data, targets = next(iter(train_loader))

In [15]:
print("Input:")
# Map the first input sequence of the batch back to its tokens.
' '.join(corpus.vocab.itos[token_id] for token_id in data[0].tolist())

Input:


'= valkyria chronicles iii = senjō no valkyria 3 <unk> chronicles ( japanese 戦場のヴァルキュリア3 , lit . valkyria of the'

In [16]:
print("Target:")
# Map the matching target sequence back to its tokens.
' '.join(corpus.vocab.itos[token_id] for token_id in targets[0].tolist())

Target:


'battlefield 3 ) , commonly'

In [20]:
# Seed for generation: the first `in_seq_len` test tokens as a batch of one.
# Use `.to(device)` rather than `.cuda()` so the cell also runs on hosts
# without a GPU, matching the CPU fallback chosen for `device`.
sentence = corpus.test[0:in_seq_len].unsqueeze(0).to(device)
# Sequence-first layout, (in_seq_len, 1), as the encoder is fed below.
generated = sentence.transpose(0, 1)

In [21]:
# Sanity check: the seed should be (in_seq_len, 1).
generated.size()

torch.Size([20, 1])

In [22]:
print("Generating text with seed:")
# `generated` holds a single sequence in column 0; decode it back to tokens.
' '.join(corpus.vocab.itos[token_id] for token_id in generated[:, 0].tolist())

Generating text with seed:


'= robert <unk> = robert <unk> is an english film , television and theatre actor . he had a guest'

In [24]:
# Greedy decoding: encode the seed once, then grow the output one token at a
# time, always taking the most likely next token.
model.eval()
max_len = 10

# Inference only: disable autograd so no graph is built or kept alive across
# the decoding steps.
with torch.no_grad():
    src = model.emb_encoder(generated) * math.sqrt(model.ninp)
    src = model.pos_encoder(src)
    e_output = model.transformer_encoder(src, None)

    # Output buffer with the same dtype/device as the seed; slot 0 is <sos>.
    outputs = torch.zeros(max_len, dtype=generated.dtype, device=generated.device)
    outputs[0] = corpus.vocab.stoi['<sos>']

    for i in range(1, max_len):
        # Causal mask over the i tokens produced so far.
        trg_mask = model.generate_square_subsequent_mask(i).to(device)

        trg = model.emb_decoder(outputs[:i].unsqueeze(1)) * math.sqrt(model.ninp)
        trg = model.pos_encoder(trg)

        d_output = model.transformer_decoder(trg, e_output, trg_mask, None)
        out = model.decoder(d_output)
        # argmax over the vocabulary at the last position; softmax is
        # monotonic, so applying it first would not change the winner.
        out = torch.argmax(out, dim=-1).view(-1)[-1]
        outputs[i] = out
        

In [26]:
print("Generated text:")
# Decode the generated ids (including the leading <sos>) back to tokens.
' '.join(corpus.vocab.itos[token_id] for token_id in outputs.tolist())

Generated text:


'<sos> the the the the the the the the the'