In [1]:
import warnings
# Silence every warning category for the whole session.
# NOTE(review): this also hides deprecation warnings from torch and friends.
warnings.filterwarnings("ignore")
import sys
import os
# Expose only GPU index 1 to this process. The variable is assigned before
# `import torch` further down, so it is in place before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Per the CUDA documentation, "1" makes kernel launches synchronous, which
# gives accurate stack traces for device-side errors at the cost of speed.
# NOTE(review): presumably a debugging aid -- confirm before long training runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import time
import re


import torch as T



import torch.nn as nn
import torch.nn.functional as F


from model import Model


from data_util import config, data
from data_util.batcher import Batcher
from data_util.data import Vocab


from train_util import *
from torch.distributions import Categorical
from rouge import Rouge
from numpy import random
import argparse
import torchsnooper
import logging

# -------- Test Packages -------
from beam_search import *
import shutil
from tensorboardX import SummaryWriter
from nltk.translate.bleu_score import corpus_bleu

# Notebook-local overrides of the shared `config` module.
# A tiny batch keeps the encoder/decoder summary tables below readable.
config.batch_size = 2
config.lr = 1e-4
# NOTE(review): attribute name is spelled "gound" here; presumably it must
# match whatever the training code reads -- confirm before renaming.
config.gound_truth_prob = 0.1

# Both intra-attention flags are switched off for this session.
config.intra_decoder = False
config.intra_encoder = False

# Logger

In [2]:
from datetime import datetime as dt

def getLogger(loggerName, loggerPath):
    """Return a DEBUG-level logger that writes to both the console and a file.

    Calling this again with the same ``loggerName`` (for example by re-running
    a notebook cell) reconfigures the logger instead of stacking another pair
    of handlers, so each message is emitted once per destination.

    Args:
        loggerName: Name passed to ``logging.getLogger``; a falsy value selects
            the root logger.
        loggerPath: Path of the log file. Missing parent directories are
            created; a bare file name (no directory part) is written to the
            current working directory.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(loggerName)
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s: - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S')

    # logging.getLogger returns the same object for the same name, so handlers
    # from an earlier call are still attached. Close and detach them first;
    # otherwise every record would be duplicated once per previous call.
    for stale in logger.handlers[:]:
        stale.close()
        logger.removeHandler(stale)

    # File output. os.path.dirname() yields '' for a bare file name and
    # os.makedirs('') raises, so only create directories when there is one.
    directory = os.path.dirname(loggerPath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # The messages logged below contain non-ASCII text, so pin the encoding
    # rather than relying on the platform default.
    fh = logging.FileHandler(loggerPath, encoding='utf-8')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)

    # Console output.
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)

    # Attach both handlers (console first, then file).
    logger.addHandler(ch)
    logger.addHandler(fh)

    logger.info(u'logger已啟動')
    return logger

def removeLogger(logger):
    """Emit a shutdown message, then close and detach all handlers of ``logger``.

    Args:
        logger: The ``logging.Logger`` to shut down.
    """
    logger.info(u'logger已關閉')
    # Iterate over a snapshot: removeHandler mutates logger.handlers.
    for attached in list(logger.handlers):
        attached.close()
        logger.removeHandler(attached)

# View batch data

In [3]:
def test_batch():
    """Print the summary text and key words of every example the batcher yields.

    Builds a ``Vocab`` and a training-mode ``Batcher`` from ``config`` and keeps
    pulling batches until ``next_batch()`` returns ``None``.
    NOTE(review): whether ``None`` is ever returned with ``single_pass=False``
    depends on ``Batcher`` and is not visible here.
    """
    vocab = Vocab(config.vocab_path, config.vocab_size)
    batcher = Batcher(config.train_data_path, vocab, mode='train',
                      batch_size=config.batch_size, single_pass=False)
    current = batcher.next_batch()
    # with torchsnooper.snoop():
    while current is not None:
        for example in current.example_list:
            # The review text and the sentence list are read but not printed.
            review_text = str(example.original_review)
            summary_text = str(example.original_summary)
            keyword_text = str(example.key_words)
            summary_sents = example.original_summary_sents
            # print("original_review_sents:", review_text)
            print("original_summary_sents : ", summary_text)
            print("key_words : ", keyword_text)
            print('------------------------------------------------------------\n')
        current = batcher.next_batch()
# test_batch()

# Get Bin Information

In [4]:
# Load the split sizes recorded in the bin-info file. The parsing below reads
# the first three lines positionally as "<label>:<count>" for train, test and
# validation, in that order.
with open(config.bin_info,'r',encoding='utf-8') as f:
    lines = f.readlines()

# Echo the raw lines (each keeps its trailing newline).
for line in lines:
    print(line)

train_num = int(lines[0].split(":")[1])
test_num = int(lines[1].split(":")[1])
val_num = int(lines[2].split(":")[1])
# The commented-out writer calls kept with this cell:
# f.write("train : %s\n"%(len(flit_key_train_df)))
# f.write("test : %s\n"%(len(flit_key_test_df)))
# f.write("valid : %s\n"%(len(flit_key_valid_df)))


train : 29540

test : 5847

valid : 4243



# Summary Encoder

In [5]:
from torchsummaryX import summary
from model import Encoder,Model
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = T.device("cuda" if T.cuda.is_available() else "cpu") # PyTorch v0.4.0
# Stand-alone encoder moved to `device`; it is only referenced by the
# commented-out summary(...) call at the end of this cell.
encoder = Encoder().to(device)    

# Vocabulary and training batcher built from the shared config.
# NOTE: `vocab` defined here is also read by the next cell.
vocab = Vocab(config.vocab_path, config.vocab_size)
batcher = Batcher(config.train_data_path, vocab, mode='train',
                       batch_size=config.batch_size, single_pass=False)
# Pull a single batch to use as sample input.
batch = batcher.next_batch()
# get_enc_data returns nine encoder-side values for the batch.
enc_batch, enc_lens, enc_padding_mask, enc_key_batch, enc_key_lens, enc_key_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(batch)
# A throw-away Model is constructed only to call its `embeds` on the batch;
# `enc_batch` is rebound from the get_enc_data output to the result of embeds.
# NOTE(review): meaning of the first two Model arguments is not visible here.
enc_batch = Model(False,'word2Vec',vocab).embeds(enc_batch) #Get embeddings for encoder input

# Disabled: torchsummaryX summary of the encoder on the sample batch.
# summary(encoder, enc_batch, enc_lens) # encoder summary

# Summary Decoder

In [6]:
from torchsummaryX import summary
from model import Decoder,Model
from train_util import *
# Decoder summary: run the encoder on one training batch, take a single
# decoder step, and let torchsummaryX tabulate the decoder's layers.
device = T.device("cuda" if T.cuda.is_available() else "cpu") # PyTorch v0.4.0
# decoder = Decoder().to(device)

# Build the vocabulary BEFORE the model: Model(...) takes it as an argument,
# and creating it here keeps this cell from depending on a `vocab` left over
# from an earlier cell.
vocab = Vocab(config.vocab_path, config.vocab_size)
model = Model(False,'word2Vec',vocab)
batcher = Batcher(config.train_data_path, vocab, mode='train',
                       batch_size=config.batch_size, single_pass=False)
batch = batcher.next_batch()
enc_batch, enc_lens, enc_padding_mask, enc_key_batch, enc_key_lens, enc_key_padding_mask, enc_batch_extend_vocab, extra_zeros, context = get_enc_data(batch)
enc_batch = model.embeds(enc_batch) #Get embeddings for encoder input
enc_out, enc_hidden = model.encoder(enc_batch, enc_lens)

print('enc_out',enc_out.shape)
print('enc_hidden',enc_hidden[0].shape)

# train_batch_MLE
dec_batch, max_dec_len, dec_lens, target_batch = get_dec_data(batch)                        #Get input and target batchs for training decoder
step_losses = []
s_t = (enc_hidden[0], enc_hidden[1])                                                        #Decoder hidden states
# x_t is the decoder's batch input at each time step; it starts as a vector
# filled with token id 2.
# NOTE(review): 2 is presumably the start-of-decoding id -- confirm against Vocab.
x_t = get_cuda(T.LongTensor(len(enc_out)).fill_(2))                             #Input to the decoder
prev_s = None                                                                               #Used for intra-decoder attention (section 2.2 in DEEP REINFORCED MODEL - https://arxiv.org/pdf/1705.04304.pdf)
sum_temporal_srcs = None
# NOTE(review): despite its name, this holds the id of data.PAD_TOKEN.
unk_id = vocab.word2id(data.PAD_TOKEN)

# Embed the keyword ids once, outside the decode loop. Inside the loop the
# already-embedded tensor would be passed back into model.embeds on the second
# step; only the unconditional `break` below used to hide that.
enc_key_batch = model.embeds(enc_key_batch)

for t in range(min(max_dec_len, config.max_dec_steps)):
    # 0/1 mask per example: 1 -> feed the ground-truth token, 0 -> keep x_t.
    # NOTE(review): the 0.25 threshold is hard-coded and does not read
    # config.gound_truth_prob -- confirm which one is intended.
    use_gound_truth = get_cuda((T.rand(len(enc_out)) > 0.25)).long()                        #Probabilities indicating whether to use ground truth labels instead of previous decoded tokens
    # use_gound_truth * dec_batch[:, t] : ground-truth token for this time step
    # (1 - use_gound_truth) * x_t       : token carried over from the previous step
    x_t = use_gound_truth * dec_batch[:, t] + (1 - use_gound_truth) * x_t                   #Select decoder input based on use_ground_truth probabilities
    x_t = model.embeds(x_t)
    print('x_t',x_t.shape)
    final_dist, s_t, ct_e, sum_temporal_srcs, prev_s = model.decoder(
        x_t, s_t, enc_out, enc_padding_mask, context,
        extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s,
        enc_key_batch, enc_key_lens)

    # Disabled MLE loss / sampling steps, kept for reference:
#     target = target_batch[:, t]
#     log_probs = T.log(final_dist + config.eps)
#     step_loss = F.nll_loss(log_probs, target, reduction="none", ignore_index=unk_id)
#     step_losses.append(step_loss)
#     x_t = T.multinomial(final_dist,1).squeeze()  # Sample words from final distribution which can be used as input in next time step

#     is_oov = (x_t >= config.vocab_size).long()  # Mask indicating whether sampled word is OOV
#     x_t = (1 - is_oov) * x_t.detach() + (is_oov) * unk_id  # Replace OOVs with [UNK] token

    # Summarise the decoder using the tensors from this first step, then stop:
    # one step is enough to obtain the layer table.
    decoder_summary = summary(model.decoder, x_t, s_t, enc_out, enc_padding_mask, context, extra_zeros, enc_batch_extend_vocab, sum_temporal_srcs, prev_s,enc_key_batch, enc_key_lens) # decoder summary
    break
decoder_summary


enc_out torch.Size([2, 200, 1024])
enc_hidden torch.Size([2, 512])
x_t torch.Size([2, 300])
                            Kernel Shape    Output Shape     Params  Mult-Adds
Layer                                                                         
0_x_context                  [1324, 300]        [2, 300]     397.5k     397.2k
1_lstm                                 -        [2, 512]  1.667072M  1.662976M
2_enc_attention.Linear_W_h  [1024, 1024]  [2, 200, 1024]  1.048576M  1.048576M
3_enc_attention.Linear_W_s  [1024, 1024]       [2, 1024]    1.0496M  1.048576M
4_enc_attention.Linear_v       [1024, 1]     [2, 200, 1]     1.024k     1.024k
5_dec_attention                        -        [2, 512]          -          -
6_p_gen_linear                 [2860, 1]          [2, 1]     2.861k      2.86k
7_V                          [2048, 512]        [2, 512]  1.049088M  1.048576M
8_V1                        [512, 50000]      [2, 50000]     25.65M      25.6M
---------------------------------------

Unnamed: 0_level_0,Kernel Shape,Output Shape,Params,Mult-Adds
Layer,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0_x_context,"[1324, 300]","[2, 300]",397500.0,397200.0
1_lstm,-,"[2, 512]",1667072.0,1662976.0
2_enc_attention.Linear_W_h,"[1024, 1024]","[2, 200, 1024]",1048576.0,1048576.0
3_enc_attention.Linear_W_s,"[1024, 1024]","[2, 1024]",1049600.0,1048576.0
4_enc_attention.Linear_v,"[1024, 1]","[2, 200, 1]",1024.0,1024.0
5_dec_attention,-,"[2, 512]",,
6_p_gen_linear,"[2860, 1]","[2, 1]",2861.0,2860.0
7_V,"[2048, 512]","[2, 512]",1049088.0,1048576.0
8_V1,"[512, 50000]","[2, 50000]",25650000.0,25600000.0
