In [1]:
'''
This script handling the training process.
'''

import math
import time

from tqdm import tqdm
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import transformer.Constants as Constants
from dataset import TranslationDataset, paired_collate_fn
from transformer.Models import Transformer
from transformer.Optim import ScheduledOptim

In [2]:
def cal_performance(pred, gold, smoothing=False):
    ''' Apply label smoothing if needed '''

    loss = cal_loss(pred, gold, smoothing)

    pred = pred.max(1)[1]
    gold = gold.contiguous().view(-1)
    non_pad_mask = gold.ne(Constants.PAD)
    n_correct = pred.eq(gold)
    n_correct = n_correct.masked_select(non_pad_mask).sum().item()

    return loss, n_correct


def cal_loss(pred, gold, smoothing):
    ''' Calculate cross entropy loss, apply label smoothing if needed. '''

    gold = gold.contiguous().view(-1)

    if smoothing:
        eps = 0.1
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)

        non_pad_mask = gold.ne(Constants.PAD)
        loss = -(one_hot * log_prb).sum(dim=1)
        loss = loss.masked_select(non_pad_mask).sum()  # average later
    else:
        loss = F.cross_entropy(pred, gold, ignore_index=Constants.PAD, reduction='sum')

    return loss

In [3]:
def train_epoch(model, training_data, optimizer, device, smoothing):
    ''' Epoch operation in training phase'''

    model.train()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    for batch in tqdm(
            training_data, mininterval=2,
            desc='  - (Training)   ', leave=False):
        
        # prepare data
        src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
        gold = tgt_seq[:, 1:]

        # forward
        optimizer.zero_grad()
        pred = model(src_seq, src_pos, tgt_seq, tgt_pos)

        # backward
        loss, n_correct = cal_performance(pred, gold, smoothing=smoothing)
        loss.backward()

        # update parameters
        optimizer.step_and_update_lr()

        # note keeping
        total_loss += loss.item()

        non_pad_mask = gold.ne(Constants.PAD)
        n_word = non_pad_mask.sum().item()
        n_word_total += n_word
        n_word_correct += n_correct

    loss_per_word = total_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy


In [4]:
def eval_epoch(model, validation_data, device):
    ''' Epoch operation in evaluation phase '''

    model.eval()

    total_loss = 0
    n_word_total = 0
    n_word_correct = 0

    with torch.no_grad():
        for batch in tqdm(
                validation_data, mininterval=2,
                desc='  - (Validation) ', leave=False):

            # prepare data
            src_seq, src_pos, tgt_seq, tgt_pos = map(lambda x: x.to(device), batch)
            gold = tgt_seq[:, 1:]

            # forward
            pred = model(src_seq, src_pos, tgt_seq, tgt_pos)
            loss, n_correct = cal_performance(pred, gold, smoothing=False)

            # note keeping
            total_loss += loss.item()

            non_pad_mask = gold.ne(Constants.PAD)
            n_word = non_pad_mask.sum().item()
            n_word_total += n_word
            n_word_correct += n_correct

    loss_per_word = total_loss/n_word_total
    accuracy = n_word_correct/n_word_total
    return loss_per_word, accuracy

In [5]:
def train(model, training_data, validation_data, optimizer, device, opt):
    ''' Start training '''

    log_train_file = None
    log_valid_file = None

    if opt['log']:
        log_train_file = opt['log'] + '.train.log'
        log_valid_file = opt['log'] + '.valid.log'

        print('[Info] Training performance will be written to file: {} and {}'.format(
            log_train_file, log_valid_file))

        with open(log_train_file, 'w') as log_tf, open(log_valid_file, 'w') as log_vf:
            log_tf.write('epoch,loss,ppl,accuracy\n')
            log_vf.write('epoch,loss,ppl,accuracy\n')

    valid_accus = []
    for epoch_i in range(opt['epoch']):
        print('[ Epoch', epoch_i, ']')

        start = time.time()
        train_loss, train_accu = train_epoch(
            model, training_data, optimizer, device, smoothing=opt['label_smoothing'])
        print('  - (Training)   ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
              'elapse: {elapse:3.3f} min'.format(
                  ppl=math.exp(min(train_loss, 100)), accu=100*train_accu,
                  elapse=(time.time()-start)/60))

        start = time.time()
        valid_loss, valid_accu = eval_epoch(model, validation_data, device)
        print('  - (Validation) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '\
                'elapse: {elapse:3.3f} min'.format(
                    ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu,
                    elapse=(time.time()-start)/60))

        valid_accus += [valid_accu]

        model_state_dict = model.state_dict()
        checkpoint = {
            'model': model_state_dict,
            'settings': opt,
            'epoch': epoch_i}

        if opt['save_model']:
            if opt['save_mode'] == 'all':
                model_name = opt.save_model + '_accu_{accu:3.3f}.chkpt'.format(accu=100*valid_accu)
                torch.save(checkpoint, model_name)
            elif opt['save_mode'] == 'best':
                model_name = opt['save_model'] + '.chkpt'
                if valid_accu >= max(valid_accus):
                    torch.save(checkpoint, model_name)
                    print('    - [Info] The checkpoint file has been updated.')

        if log_train_file and log_valid_file:
            with open(log_train_file, 'a') as log_tf, open(log_valid_file, 'a') as log_vf:
                log_tf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=train_loss,
                    ppl=math.exp(min(train_loss, 100)), accu=100*train_accu))
                log_vf.write('{epoch},{loss: 8.5f},{ppl: 8.5f},{accu:3.3f}\n'.format(
                    epoch=epoch_i, loss=valid_loss,
                    ppl=math.exp(min(valid_loss, 100)), accu=100*valid_accu))

In [6]:
def prepare_dataloaders(data, batch_size):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        num_workers=2,
        batch_size=batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt']),
        num_workers=2,
        batch_size=batch_size,
        collate_fn=paired_collate_fn)
    return train_loader, valid_loader

In [7]:
# python train.py -data data/multi30k.atok.low.pt 
#                 -save_model trained
#                 -save_mode best
#                 -proj_share_weight
#                 -label_smoothing

opt = {'data': 'data/multi30k.atok.low.pt',
       'epoch': 10,
       'batch_size': 16,
       'd_model': 512,
       'd_inner_hid': 2048,
       'd_k': 64,
       'd_v': 64,
       'n_head': 8,
       'n_layers': 6,
       'n_warmup_steps': 4000,
       'dropout': 0.1,
       'embs_share_weight': False,
       'proj_share_weight': True,
       'log': None,
       'save_model': 'trained',
       'save_mode': 'best',
       'no_cuda': True,
       'label_smoothing': True}

opt['cuda'] = not opt['no_cuda']
opt['d_word_vec'] = opt['d_model']

#========= Loading Dataset =========#
data = torch.load(opt['data'])
opt['max_token_seq_len'] = data['settings'].max_token_seq_len

training_data, validation_data = prepare_dataloaders(data, opt['batch_size'])

opt['src_vocab_size'] = training_data.dataset.src_vocab_size
opt['tgt_vocab_size'] = training_data.dataset.tgt_vocab_size

print(opt['src_vocab_size'])
print(opt['tgt_vocab_size'])

#========= Preparing Model =========#
if opt['embs_share_weight']:
    assert training_data.dataset.src_word2idx == training_data.dataset.tgt_word2idx, \
        'The src/tgt word2idx table are different but asked to share word embedding.'

2911
3149


In [8]:
for i in training_data:
    for j in i:
        print(j.size())
        print()
    break

torch.Size([16, 24])

torch.Size([16, 24])

torch.Size([16, 23])

torch.Size([16, 23])



In [9]:
train_data = data['train']['tgt']
print(len(train_data))

for i in range(0, len(train_data)):
    print(train_data[i])
    if i == 10:
        break

29000
[2, 2782, 683, 291, 77, 2376, 2501, 2987, 2800, 2105, 818, 1, 1, 1177, 3]
[2, 3116, 77, 1208, 2867, 2559, 674, 1, 1177, 3]
[2, 674, 54, 404, 3066, 2800, 674, 1, 2935, 312, 1177, 3]
[2, 674, 1217, 2800, 170, 2380, 2716, 2268, 2303, 2812, 987, 1460, 1486, 674, 990, 1177, 3]
[2, 2782, 77, 2355, 1793, 1063, 1460, 2519, 2427, 245, 1177, 3]
[2, 674, 1217, 2800, 904, 86, 2423, 1454, 1942, 82, 2105, 1616, 1217, 142, 2716, 2553, 1177, 3]
[2, 674, 1217, 1253, 973, 1, 1, 2555, 1177, 3]
[2, 674, 1, 404, 1223, 1208, 2579, 3125, 82, 2999, 1, 1580, 431, 1, 1177, 3]
[2, 2423, 604, 1208, 2812, 749, 1614, 2084, 2555, 170, 1300, 1384, 1177, 3]
[2, 2499, 667, 358, 2800, 2105, 2664, 2303, 2557, 1177, 3]
[2, 2423, 1, 1208, 2056, 404, 1942, 1580, 1, 1866, 1177, 3]


In [10]:
valid_data = data['valid']['src']
print(len(valid_data))

for i in range(0, len(valid_data)):
    print(valid_data[i])
    if i == 10:
        break

1014
[2, 2199, 759, 1936, 1761, 1107, 2561, 1925, 355, 2199, 1023, 3]
[2, 2199, 1595, 2303, 2591, 2199, 1855, 809, 1014, 2199, 1556, 1096, 3]
[2, 2199, 14, 614, 2620, 585, 1014, 2199, 1883, 1391, 1064, 1096, 3]
[2, 1876, 1761, 1618, 201, 2199, 1330, 643, 1832, 2042, 1014, 2371, 1, 2214, 2117, 3]
[2, 2199, 2267, 1595, 614, 2199, 96, 2411, 2435, 2763, 1573, 2591, 2199, 1413, 650, 1096, 3]
[2, 2199, 2866, 2591, 2199, 96, 661, 1783, 2633, 2199, 1, 1114, 2708, 1, 1936, 2845, 1768, 1783, 543, 1409, 268, 720, 1063, 2199, 1, 1096, 3]
[2, 2199, 680, 1967, 2763, 2784, 79, 268, 985, 1967, 1096, 3]
[2, 2199, 711, 14, 614, 2199, 1, 613, 957, 2199, 1569, 1133, 2336, 2371, 1, 141, 1096, 3]
[2, 2199, 1595, 2591, 2199, 2822, 1983, 2763, 1415, 268, 764, 3]
[2, 2199, 182, 1883, 2591, 2199, 1, 1240, 430, 2230, 2633, 2199, 416, 482, 3]
[2, 2199, 711, 1332, 2763, 1075, 2103, 1014, 1809, 1, 1750, 1096, 3]


In [11]:
dict_data = data['dict']['tgt']
print(len(dict_data))

for k, v in dict_data.items():
    print(k)
    print(v)
    break

3149
<s>
2


In [12]:
device = torch.device('cuda' if opt['cuda'] else 'cpu')
print(device)

transformer = Transformer(
    opt['src_vocab_size'],
    opt['tgt_vocab_size'],
    opt['max_token_seq_len'],
    tgt_emb_prj_weight_sharing=opt['proj_share_weight'],
    emb_src_tgt_weight_sharing=opt['embs_share_weight'],
    d_k=opt['d_k'],
    d_v=opt['d_v'],
    d_model=opt['d_model'],
    d_word_vec=opt['d_word_vec'],
    d_inner=opt['d_inner_hid'],
    n_layers=opt['n_layers'],
    n_head=opt['n_head'],
    dropout=opt['dropout']).to(device)

cpu


In [13]:
for key in opt:
    print('{}: \n{}\n'.format(key, opt[key]))
    break
    
print(opt.keys())
print(opt['batch_size'])

# d_model的512应该是word embedding的长度，一个word用512个长度的vector表示

data: 
data/multi30k.atok.low.pt

dict_keys(['data', 'epoch', 'batch_size', 'd_model', 'd_inner_hid', 'd_k', 'd_v', 'n_head', 'n_layers', 'n_warmup_steps', 'dropout', 'embs_share_weight', 'proj_share_weight', 'log', 'save_model', 'save_mode', 'no_cuda', 'label_smoothing', 'cuda', 'd_word_vec', 'max_token_seq_len', 'src_vocab_size', 'tgt_vocab_size'])
16


In [14]:
optimizer = ScheduledOptim(
    optim.Adam(
        filter(lambda x: x.requires_grad, transformer.parameters()),
        betas=(0.9, 0.98), eps=1e-09),
    opt['d_model'], opt['n_warmup_steps'])

# TODO: 研究data的输入和输出，写一个dataloader
# TODO: 写一个只包含encoder的transformer

train(transformer, training_data, validation_data, optimizer, device, opt)

  - (Training)   :   0%|          | 0/1813 [00:00<?, ?it/s]

[ Epoch 0 ]
SIZE:
torch.Size([16, 25, 3149])


  - (Training)   :   0%|          | 1/1813 [00:02<1:08:14,  2.26s/it]

SIZE:
torch.Size([16, 20, 3149])
SIZE:
torch.Size([16, 23, 3149])


  - (Training)   :   0%|          | 3/1813 [00:06<1:05:58,  2.19s/it]

SIZE:
torch.Size([16, 21, 3149])
SIZE:
torch.Size([16, 24, 3149])


  - (Training)   :   0%|          | 5/1813 [00:10<1:03:30,  2.11s/it]Traceback (most recent call last):
  File "/miniconda3/envs/py36/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/miniconda3/envs/py36/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/miniconda3/envs/py36/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/miniconda3/envs/py36/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe
Traceback (most recent call last):
  File "/miniconda3/envs/py36/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/miniconda3/envs/py36/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/miniconda3/envs/py36/lib/python3.6/multi

KeyboardInterrupt: 