In [1]:
from utils import config, bert_data
from utils.bert_batcher import *
from utils.train_util import *
from utils.write_result import *
from datetime import datetime as dt
from tqdm import tqdm
from beam.beam_search import *
from tensorboardX import SummaryWriter
import argparse

import logging
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("pytorch_pretrained_bert").setLevel(logging.ERROR)

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

os.environ['CUDA_LAUNCH_BLOCKING'] = "1" 

parser = argparse.ArgumentParser()
parser.add_argument('--key_attention', type=bool, default=False, help = 'True/False')
parser.add_argument('--intra_encoder', type=bool, default=True, help = 'True/False')
parser.add_argument('--intra_decoder', type=bool, default=True, help = 'True/False')
parser.add_argument('--transformer', type=bool, default=False, help = 'True/False')
parser.add_argument('--train_rl', type=bool, default=False, help = 'True/False')
parser.add_argument('--keywords', type=str, default='POS_FOP_keywords', 
                    help = 'POS_FOP_keywords / DEP_FOP_keywords / TextRank_keywords')

parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--rand_unif_init_mag', type=float, default=0.02)
parser.add_argument('--trunc_norm_init_std', type=float, default=0.0001)
parser.add_argument('--mle_weight', type=float, default=1.0)
parser.add_argument('--gound_truth_prob', type=float, default=0.1)

parser.add_argument('--max_enc_steps', type=int, default=500)
parser.add_argument('--max_dec_steps', type=int, default=60)
parser.add_argument('--min_dec_steps', type=int, default=4)
parser.add_argument('--max_epochs', type=int, default=10)
parser.add_argument('--vocab_size', type=int, default=50000)
parser.add_argument('--beam_size', type=int, default=16)
parser.add_argument('--batch_size', type=int, default=4)


parser.add_argument('--hidden_dim', type=int, default=512)
parser.add_argument('--emb_dim', type=int, default=768)
parser.add_argument('--gradient_accum', type=int, default=8)

parser.add_argument('--load_ckpt', type=str, default=None, help='0002000')
parser.add_argument('--word_emb_type', type=str, default='bert') # bert emb use word2Vec vocab
parser.add_argument('--pre_train_emb', type=bool, default=True)


opt = parser.parse_args(args=[])
config = re_config(opt)
config.rl_weight = 1 - config.mle_weight

if not config.transformer:
    loggerName = 'Pointer_generator_%s' % (config.word_emb_type)
else:
    loggerName = 'Transformer_%s' % (config.word_emb_type)
    
if config.intra_encoder and config.intra_decoder and True :
    loggerName = loggerName + '_Intra_Atten'
if config.key_attention:
    loggerName = loggerName + '_Key_Atten'
    
logger = getLogger(loggerName) 

if not config.transformer:
    writer = SummaryWriter('runs/Pointer-Generator/%s/exp' % config.word_emb_type)
else:
    writer = SummaryWriter('runs/Transformer/%s/exp' % config.word_emb_type)

I0401 21:09:31.885110 140421284529984 file_utils.py:35] PyTorch version 1.4.0 available.


We have  30526 bert tokens now
We have added 3 XL tokens


2020-04-01 21:09:35 - Pointer_generator_bert_Intra_Atten - INFO: - logger已啟動
I0401 21:09:35.579890 140421284529984 train_util.py:99] logger已啟動


In [2]:
train_loader, validate_loader, vocab = getDataLoader(logger, config)

2020-04-01 21:09:41 - Pointer_generator_bert_Intra_Atten - INFO: - train : 37771, test : 4197
I0401 21:09:41.552139 140421284529984 bert_batcher.py:172] train : 37771, test : 4197


In [3]:
from model import Model
import torch.nn as nn
import torch as T
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
from transformers import BertModel
import numpy as np

load_step = None
model = Model(pre_train_emb=config.pre_train_emb, 
              word_emb_type = config.word_emb_type, 
              vocab = vocab)

model = model.cuda()
bert_model = BertModel.from_pretrained('bert-base-uncased')
bert_model.resize_token_embeddings(len(bert_data.bert_tokenizer))
bert_model = bert_model.cuda() # [30522, 30523, 30524, 30525] 

optimizer = T.optim.Adam(model.parameters(), lr=config.lr)   
# optimizer = T.optim.Adagrad(model.parameters(),lr=config.lr, initial_accumulator_value=0.1)

load_model_path = config.save_model_path + '/%s/%s.tar' % (logger, config.load_ckpt)

if os.path.exists(load_model_path):
    model, optimizer, load_step = loadCheckpoint(logger, load_model_path, model, optimizer)

In [4]:
def train_one(model, config, batch):
        ''' Calculate Negative Log Likelihood Loss for the given batch. In order to reduce exposure bias,
                pass the previous generated token as input with a probability of 0.25 instead of ground truth label
        Args:
        :param enc_out: Outputs of the encoder for all time steps (batch_size, length_input_sequence, 2*hidden_size)
        :param enc_hidden: Tuple containing final hidden state & cell state of encoder. Shape of h & c: (batch_size, hidden_size)
        :param enc_padding_mask: Mask for encoder input; Tensor of size (batch_size, length_input_sequence) with values of 0 for pad tokens & 1 for others
        :param ct_e: encoder context vector for time_step=0 (eq 5 in https://arxiv.org/pdf/1705.04304.pdf)
        :param extra_zeros: Tensor used to extend vocab distribution for pointer mechanism
        :param enc_batch_extend_vocab: Input batch that stores OOV ids
        :param batch: batch object
        '''
        'Encoder data'
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, coverage, \
        ct_e, enc_key_batch, enc_key_lens= \
            get_input_from_batch(batch, config, batch_first = True)
 
#         enc_batch = model.embeds(enc_batch)  # Get embeddings for encoder input    
#         enc_key_batch = model.embeds(enc_key_batch)  # Get key embeddings for encoder input
        enc_batch = enc_batch.type(T.LongTensor).cuda() #  `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
        enc_key_batch = enc_key_batch.type(T.LongTensor).cuda() #  `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
        enc_padding_mask = enc_padding_mask.type(T.LongTensor).cuda() #  `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
#         enc_key_padding_mask = enc_key_padding_mask.type(T.LongTensor).cuda() #  `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
        
        
        # enc_padding_mask  `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
        # `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
#         print(enc_batch.shape)
        enc_batch = bert_model(enc_batch, attention_mask = enc_padding_mask)[-2:][0] 
#         enc_key_batch = bert_model(enc_key_batch, attention_mask = enc_key_padding_mask)[-2:][0]  

        enc_out, enc_hidden = model.encoder(enc_batch, enc_lens)
        
        'Decoder data'
        dec_batch, dec_padding_mask, dec_lens, max_dec_len, target_batch = \
        get_output_from_batch(batch, batch_first = True) # Get input and target batchs for training decoder
        step_losses = []
        s_t = (enc_hidden[0], enc_hidden[1])  # Decoder hidden states
        x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(START))  # Input to the decoder
        prev_s = None  # Used for intra-decoder attention (section 2.2 in https://arxiv.org/pdf/1705.04304.pdf)
        sum_temporal_srcs = None  # Used for intra-temporal attention (section 2.1 in https://arxiv.org/pdf/1705.04304.pdf)
        for t in range(min(max_dec_len, config.max_dec_steps)):
            use_gound_truth = get_cuda((T.rand(len(enc_out)) > config.gound_truth_prob)).long()  # Probabilities indicating whether to use ground truth labels instead of previous decoded tokens
            x_t = use_gound_truth * dec_batch[:, t] + (1 - use_gound_truth) * x_t  # Select decoder input based on use_ground_truth probabilities
            x_t = model.embeds(x_t)  
            final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = model.decoder(x_t, s_t, enc_out, enc_padding_mask,
                                                                                      ct_e, extra_zeros,
                                                                                      enc_batch_extend_vocab,
                                                                                      sum_temporal_srcs, prev_s, enc_key_batch, enc_key_lens)
            target = target_batch[:, t]
            log_probs = T.log(final_dist + config.eps)
            step_loss = F.nll_loss(log_probs, target, reduction="none", ignore_index=PAD)
            step_losses.append(step_loss)
            x_t = T.multinomial(final_dist,1).squeeze()  # Sample words from final distribution which can be used as input in next time step

            is_oov = (x_t >= config.vocab_size).long()  # Mask indicating whether sampled word is OOV
            x_t = (1 - is_oov) * x_t.detach() + (is_oov) * UNKNOWN_TOKEN  # Replace OOVs with [UNK] token

        losses = T.sum(T.stack(step_losses, 1), 1)  # unnormalized losses for each example in the batch; (batch_size)
        batch_avg_loss = losses / dec_lens  # Normalized losses; (batch_size)
        mle_loss = T.mean(batch_avg_loss)  # Average batch loss
        return mle_loss

In [None]:
@torch.autograd.no_grad()
def validate(validate_loader, config, model):
#     model.eval()
    losses = []
#     batch = next(iter(validate_loader))
    for batch in validate_loader:
        loss = train_one(model, config, batch)
        losses.append(loss.item())
#         break
#     model.train()
    ave_loss = sum(losses) / len(losses)
    return ave_loss

In [None]:
@torch.autograd.no_grad()
def calc_running_avg_loss(loss, running_avg_loss, decay=0.99):
    if running_avg_loss == 0:  # on the first iteration just take the loss
        running_avg_loss = loss
    else:
        running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
    running_avg_loss = min(running_avg_loss, 12)  # clip
    return running_avg_loss

In [None]:
from random import randint
@torch.autograd.no_grad()
def decode(writer, logger, step, config, model, batch, mode):
    # 動態取batch
    if mode == 'test':
        num = len(iter(batch))
        select_batch = None
        rand_b_id = randint(0,num-1)
#         logger.info('test_batch : ' + str(num)+ ' ' + str(rand_b_id))
        for idx, b in enumerate(batch):
            if idx == rand_b_id:
                select_batch = b
                break
#         select_batch = next(iter(batch))
        batch = select_batch
        if type(batch) == torch.utils.data.dataloader.DataLoader:
            batch = next(iter(batch))
    'Encoder data'
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, coverage, \
    ct_e, enc_key_batch, enc_key_lens= \
        get_input_from_batch(batch, config, batch_first = True)

    enc_batch = model.embeds(enc_batch)  # Get embeddings for encoder input    
    enc_key_batch = model.embeds(enc_key_batch)  # Get key embeddings for encoder input

    enc_out, enc_hidden = model.encoder(enc_batch, enc_lens)

    'Feed encoder data to predict'
    pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, 
                           enc_batch_extend_vocab, enc_key_batch, enc_key_lens, model, 
                           START, END, UNKNOWN_TOKEN)

    article_sents, decoded_sents, keywords_list, \
    ref_sents, long_seq_index = prepare_result(vocab, batch, pred_ids)

    rouge_l = write_rouge(writer, step, mode,article_sents, decoded_sents, \
                keywords_list, ref_sents, long_seq_index)

    write_bleu(writer, step, mode, article_sents, decoded_sents, \
               keywords_list, ref_sents, long_seq_index)

    write_group(writer, step, mode, article_sents, decoded_sents,\
                keywords_list, ref_sents, long_seq_index)

    return rouge_l

In [None]:
from random import randint
@torch.autograd.no_grad()
def avg_acc(writer, logger, epoch, config, model, dataloader):
    # 動態取batch
    num = len(iter(dataloader))
    avg_rouge_l = []
    for idx, batch in enumerate(dataloader):
        'Encoder data'
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, coverage, \
        ct_e, enc_key_batch, enc_key_lens= \
            get_input_from_batch(batch, config, batch_first = True)

        enc_batch = model.embeds(enc_batch)  # Get embeddings for encoder input    
        enc_key_batch = model.embeds(enc_key_batch)  # Get key embeddings for encoder input

        enc_out, enc_hidden = model.encoder(enc_batch, enc_lens)

        'Feed encoder data to predict'
        pred_ids = beam_search(enc_hidden, enc_out, enc_padding_mask, ct_e, extra_zeros, 
                               enc_batch_extend_vocab, enc_key_batch, enc_key_lens, model, 
                               START, END, UNKNOWN_TOKEN)

        article_sents, decoded_sents, keywords_list, \
        ref_sents, long_seq_index = prepare_result(vocab, batch, pred_ids)

        rouge_l = write_rouge(writer, None, None, article_sents, decoded_sents, \
                    keywords_list, ref_sents, long_seq_index, write = False)
        avg_rouge_l.append(rouge_l)


    avg_rouge_l = sum(avg_rouge_l) / num
    writer.add_scalars('scalar_avg/acc',  
                   {'testing_avg_acc': avg_rouge_l
                   }, epoch)

    return avg_rouge_l

In [None]:
write_train_para(writer, config)
logger.info('------Training START--------')
train_loss, val_loss, running_avg_loss, test_running_avg_loss = 0,0,0,0
step = 0
for epoch in range(config.max_epochs):
    for batch in train_loader:
        step += 1
        mle_loss = train_one(model, config, batch)
        rl_loss = T.FloatTensor([0]).cuda()
        (config.mle_weight * mle_loss + config.rl_weight * rl_loss).backward()  # 反向传播，计算当前梯度
        
        '''梯度累加就是，每次获取1个batch的数据，计算1次梯度，梯度不清空'''
        if step % ( config.gradient_accum) == 0: # gradient accumulation
#             clip_grad_norm_(model.parameters(), 5.0)                      
#             (config.mle_weight * mle_loss + config.rl_weight * rl_loss).backward()  # 反向传播，计算当前梯度
            optimizer.step() # 根据累计的梯度更新网络参数
            optimizer.zero_grad() # 清空过往梯度 

            

        if step%1000 == 0 :
            with T.autograd.no_grad():
                train_loss = mle_loss.item()
                val_loss = validate(validate_loader, config, model) # call batch by validate_loader
                running_avg_loss = calc_running_avg_loss(train_loss, running_avg_loss)
                logger.info('epoch %d: %d, training loss = %f, validation loss = %f, running_avg_loss loss = %f'
                            % (epoch, step, train_loss, val_loss, running_avg_loss))
                writer.add_scalars('scalar/Loss',  
                   {'train_loss': train_loss,
                    'val_loss': val_loss
                   }, step)
                writer.add_scalars('scalar_avg/loss',  
                   {'running_avg_loss': running_avg_loss
                   }, step)
            
        if step%5000 == 0:
            save_model(config, logger, model, optimizer, step, vocab, running_avg_loss, \
                       r_loss=0, title = loggerName)
        if step%1000 == 0 and step > 0:
            train_rouge_l_f = decode(writer, logger, step, config, model, batch, mode = 'train') # call batch by validate_loader
            test_rouge_l_f = decode(writer, logger, step, config, model, validate_loader, mode = 'test') # call batch by validate_loader
#             write_scalar(writer, logger, step, train_rouge_l_f, test_rouge_l_f)
            writer.add_scalars('scalar/Rouge-L',  
               {'train_rouge_l_f': train_rouge_l_f,
                'test_rouge_l_f': test_rouge_l_f
               }, step)
            logger.info('epoch %d: %d, train_rouge_l_f = %f, test_rouge_l_f = %f'
                            % (epoch, step, train_rouge_l_f, test_rouge_l_f))
    test_avg_acc = avg_acc(writer, logger, epoch, config, model, validate_loader)
    logger.info('epoch %d: %d, test_avg_acc = %f' % (epoch, step, test_avg_acc))


logger.info(u'------Training END--------')                
removeLogger(logger)


2020-04-01 21:09:58 - Pointer_generator_bert_Intra_Atten - INFO: - ------Training START--------
I0401 21:09:58.569795 140421284529984 <ipython-input-9-7d8e0123e269>:2] ------Training START--------
2020-04-01 21:18:13 - Pointer_generator_bert_Intra_Atten - INFO: - epoch 0: 1000, training loss = 4.553233, validation loss = 4.490609, running_avg_loss loss = 4.553233
I0401 21:18:13.358268 140421284529984 <ipython-input-9-7d8e0123e269>:27] epoch 0: 1000, training loss = 4.553233, validation loss = 4.490609, running_avg_loss loss = 4.553233
2020-04-01 21:18:21 - Pointer_generator_bert_Intra_Atten - INFO: - epoch 0: 1000, train_rouge_l_f = 0.044643, test_rouge_l_f = 0.054511
I0401 21:18:21.172528 140421284529984 <ipython-input-9-7d8e0123e269>:48] epoch 0: 1000, train_rouge_l_f = 0.044643, test_rouge_l_f = 0.054511


In [None]:
# vocab._word2id.values()
# len(bert_data.bert_tokenizer)

# bert_model.embeddings