In [3]:
import math
import time

from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data


In [9]:
def cal_performance(pred, tgt, smoothing=False):
    """Compute the loss and the count of correct non-padding predictions.

    Args:
        pred: scores with classes along dim 1. Its rows must line up with
            ``tgt`` flattened to 1-D (presumably (batch*seq_len, n_class) --
            TODO confirm against the model).
        tgt: target class indices; flattened with ``view(-1)`` here.
        smoothing: forwarded unchanged to ``cal_loss``.

    Returns:
        ``(loss, n_correct)``. ``loss`` is whatever ``cal_loss`` returns.
        ``n_correct`` is a Python int counting positions whose arg-max
        prediction equals the target, skipping positions whose target is
        ``Constants.PAD``.
    """
    loss = cal_loss(pred, tgt, smoothing)

    # Index of the highest score in each row is the predicted class.
    _, predicted = pred.max(dim=1)
    flat_tgt = tgt.contiguous().view(-1)
    is_real_token = flat_tgt.ne(Constants.PAD)

    hits = predicted.eq(flat_tgt).masked_select(is_real_token)
    return loss, hits.sum().item()

def cal_loss(pred, tgt, smoothing, pad_idx=None):
    """Cross-entropy summed over non-padding target tokens.

    Args:
        pred: logits with classes along dim 1. Its rows must line up with
            ``tgt`` flattened to 1-D (presumably (batch*seq_len, n_class) --
            TODO confirm against the model).
        tgt: target class indices of any shape; flattened with ``view(-1)``.
        smoothing: if truthy, use label smoothing with ``eps = 0.1``. The true
            class gets ``1 - eps`` and the remaining ``eps`` is spread evenly
            over the other ``n_class - 1`` classes.
        pad_idx: target value to ignore. Defaults to ``Constants.PAD``.

    Returns:
        A scalar tensor holding the loss **summed** (not averaged) over
        non-padding tokens, in both branches. Callers divide by the token
        count themselves.
    """
    if pad_idx is None:
        # NOTE(review): `Constants` is not imported anywhere in this notebook.
        pad_idx = Constants.PAD

    tgt = tgt.contiguous().view(-1)

    if smoothing:
        eps = 0.1
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, tgt.view(-1, 1), 1)
        # Divide eps among the wrong classes so every target row sums to 1.
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        non_pad_mask = tgt.ne(pad_idx)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()
    else:
        # reduction='sum' keeps this branch consistent with the smoothed branch
        # and with callers that divide the total by the number of words.
        loss = F.cross_entropy(pred, tgt, ignore_index=pad_idx, reduction='sum')
    return loss

def train_epoch(model, training_data, optimizer, device, smoothing=False):
    """Run one optimisation pass over ``training_data``.

    Args:
        model: module put into train mode and called as ``model(src_seq)``.
        training_data: iterable of ``(src_seq, tgt_seq)`` tensor pairs.
        optimizer: optimizer whose gradients are cleared and stepped per batch.
        device: device every batch tensor is moved to.
        smoothing: forwarded to ``cal_performance`` (label smoothing on/off).

    Returns:
        ``(loss_per_word, accuracy)`` over all non-padding target tokens.
        Both are computed against ``tgt_seq[:, 1:]``. The per-word loss assumes
        ``cal_performance`` returns a loss summed over tokens. If no
        non-padding token was seen, returns ``(nan, 0.0)`` instead of
        dividing by zero.
    """
    model.train()

    total_loss = 0.0
    n_word_total = 0
    n_word_correct = 0
    for batch in tqdm(training_data, mininterval=2,
                      desc='  - (Training)   ', leave=False):
        src_seq, tgt_seq = map(lambda x: x.to(device), batch)

        # Drop the first target position (presumably a start token -- TODO confirm).
        tgt = tgt_seq[:, 1:]

        # forward
        optimizer.zero_grad()
        pred = model(src_seq)

        # backward
        loss, n_correct = cal_performance(pred, tgt, smoothing)
        loss.backward()

        # update parameters
        optimizer.step()

        total_loss += loss.item()
        n_word_total += tgt.ne(Constants.PAD).sum().item()
        n_word_correct += n_correct

    if n_word_total == 0:
        # Empty loader or all-padding targets: nothing to average over.
        return float('nan'), 0.0
    loss_per_word = total_loss / n_word_total
    accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy
    

def eval_epoch(model, validation_data, optimizer, device):
    """Evaluate ``model`` on ``validation_data`` without gradient tracking.

    Args:
        model: module put into eval mode and called as ``model(src_seq)``.
        validation_data: iterable of batches. Each batch holds either
            ``(src_seq, tgt_seq)`` (the layout ``train_epoch`` unpacks) or
            ``(src_seq, src_pos, tgt_seq, tgt_pos)``; position tensors are unused.
        optimizer: unused; kept so existing positional callers keep working.
        device: device every batch tensor is moved to.

    Returns:
        ``(loss_per_word, accuracy)`` over all non-padding target tokens in
        ``tgt_seq[:, 1:]``. If no non-padding token was seen, returns
        ``(nan, 0.0)`` instead of dividing by zero.
    """
    model.eval()

    total_loss = 0.0
    n_word_total = 0
    n_word_correct = 0

    with torch.no_grad():
        for batch in tqdm(
                validation_data, mininterval=2,
                desc='  - (Validation) ', leave=False):

            # prepare data
            tensors = [x.to(device) for x in batch]
            if len(tensors) == 4:
                src_seq, _, tgt_seq, _ = tensors
            else:
                src_seq, tgt_seq = tensors
            # Drop the first target position (presumably a start token -- TODO confirm).
            tgt = tgt_seq[:, 1:]

            # forward
            pred = model(src_seq)
            loss, n_correct = cal_performance(pred, tgt, smoothing=False)

            # note keeping
            total_loss += loss.item()
            n_word_total += tgt.ne(Constants.PAD).sum().item()
            n_word_correct += n_correct

    if n_word_total == 0:
        # Empty loader or all-padding targets: nothing to average over.
        return float('nan'), 0.0
    loss_per_word = total_loss / n_word_total
    accuracy = n_word_correct / n_word_total
    return loss_per_word, accuracy
    
    
def _print_epoch_stats(header, loss, accu, start):
    """Print perplexity, accuracy (%) and minutes elapsed since ``start`` for one phase.

    ``loss`` is clipped at 100 before ``math.exp`` so the perplexity cannot overflow.
    """
    print('{header}ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '
          'elapse: {elapse:3.3f} min'.format(
              header=header, ppl=math.exp(min(loss, 100)), accu=100 * accu,
              elapse=(time.time() - start) / 60))


def train(model, training_data, validation_data, optimizer, device, opt):
    """Train for ``opt['epoch']`` epochs, validating and optionally checkpointing each epoch.

    Args:
        model: module passed to ``train_epoch`` / ``eval_epoch``; its
            ``state_dict()`` is what gets checkpointed.
        training_data: loader passed to ``train_epoch``.
        validation_data: loader passed to ``eval_epoch``.
        optimizer: optimizer passed to both epoch functions.
        device: device passed to both epoch functions.
        opt: settings dict. It is also stored in each checkpoint under 'settings'.
            - 'epoch' (required): number of epochs.
            - 'label_smoothing': enables label smoothing. Falls back to the key
              'smoothing', then to False.
            - 'save_model': checkpoint path prefix (str). If falsy or missing,
              nothing is saved.
            - 'save_mode': 'all' saves every epoch with the accuracy in the file
              name; 'best' (default) overwrites ``<prefix>.chkpt`` whenever
              validation accuracy ties or beats the best so far.
    """
    smoothing = opt.get('label_smoothing', opt.get('smoothing', False))
    save_model = opt.get('save_model')
    save_mode = opt.get('save_mode', 'best')

    valid_accus = []
    for epoch_i in range(opt['epoch']):
        print('[ Epoch', epoch_i, ']')

        start = time.time()
        train_loss, train_accu = train_epoch(
            model, training_data, optimizer, device, smoothing=smoothing)
        _print_epoch_stats('  - (Training)   ', train_loss, train_accu, start)

        start = time.time()
        # eval_epoch's signature is (model, validation_data, optimizer, device).
        valid_loss, valid_accu = eval_epoch(model, validation_data, optimizer, device)
        _print_epoch_stats('  - (Validation) ', valid_loss, valid_accu, start)

        valid_accus.append(valid_accu)

        if not save_model:
            continue

        checkpoint = {
            'model': model.state_dict(),
            'settings': opt,
            'epoch': epoch_i}
        if save_mode == 'all':
            model_name = save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100 * valid_accu)
            torch.save(checkpoint, model_name)
        elif save_mode == 'best':
            model_name = save_model + '.chkpt'
            if valid_accu >= max(valid_accus):
                torch.save(checkpoint, model_name)
                print('    - [Info] The checkpoint file has been updated.')

#         if log_train_file and log_valid_file:
#             with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
#                 log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
#                     epoch=epoch_i, loss=train_loss,
#                     ppl=math.exp(min(train_loss, 100)), accu=100*train_accu))
#                 log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
#                     epoch=epoch_i, loss=valid_loss,
#                     ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu))

In [None]:
# Training run configuration.
# TODO: fill in every value left as None; the cell refuses to run until all are set.
D_WORD_EMBEDDING = None
D_H = None
D_S = None
SRC_VOCAB_SIZE = None
TGT_VOCAB_SIZE = None
DATA_PATH = None  # file previously written with torch.save
N_EPOCHS = None

_required = {
    'D_WORD_EMBEDDING': D_WORD_EMBEDDING,
    'D_H': D_H,
    'D_S': D_S,
    'SRC_VOCAB_SIZE': SRC_VOCAB_SIZE,
    'TGT_VOCAB_SIZE': TGT_VOCAB_SIZE,
    'DATA_PATH': DATA_PATH,
    'N_EPOCHS': N_EPOCHS,
}
_missing = [name for name, value in _required.items() if value is None]
if _missing:
    raise ValueError('Set these before training: ' + ', '.join(_missing))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): `Transformer` is not defined or imported anywhere in this notebook;
# it is presumably an nn.Module (train() calls .train()/.state_dict() on it).
transformer = Transformer(d_word_embedding=D_WORD_EMBEDDING, d_h=D_H, d_s=D_S,
                          src_vocab_size=SRC_VOCAB_SIZE,
                          tgt_vocab_size=TGT_VOCAB_SIZE).to(device)

# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted files.
data = torch.load(DATA_PATH)
# TODO: build `training_data` and `validation_data` loaders from `data`. That code is
# not part of this notebook, and both names must exist before `train` is called.

# Adam with library-default hyperparameters -- TODO tune.
optimizer = optim.Adam(transformer.parameters())

# Keys match what `train` reads.
opt = {
    'epoch': N_EPOCHS,
    'label_smoothing': False,
    'save_model': False,
    'save_mode': 'best',
}
train(transformer, training_data, validation_data, optimizer, device, opt)


In [42]:
# A 3x5 tensor of zeros (default dtype) used by the reshaping experiments below.
x = torch.zeros((3, 5))

In [43]:
# Display the zero tensor built above; a bare trailing expression renders its repr.
x

tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

In [48]:
# With x of shape (3, 5) from above, view(3, -1, 5) infers a middle axis of length 1.
# expand then broadcasts that axis to length 10 and returns a view (no data copy).
x_3d = x.view(3, -1, 5)
x_3d.expand(3, 10, 5)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])

In [1]:
import numpy as np

# Seeded generator so the permutation is reproducible on Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)
indices = rng.permutation(6)  # shuffled copy of 0..5
indices

array([5, 0, 2, 1, 3, 4])