In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# import sys
# sys.path.append('/scratch/tx443/NLU/project/NLU_OIE_UnifiedModels/Machine_Translation_NLP')
# # print(sys.path)

In [3]:
import json
import pandas as pd
import time

import jieba
import re

from config import vocab_pred, vocab_pred_size, vocab_prefix
from config import UNK_index, PAD_index, SOS_index, EOS_index 
from config import OOV_pred_index, PAD_pred_index, EOS_pred_index

In [4]:
def replaceMisspred(predicate):
    '''replace missing predicate
    '''
    if predicate == '_':
        return 'P'
    else:
        return predicate
    
def character_segmentation(string):
    res = []
    for part in list(jieba.cut(string, cut_all=False)):
        if re.match('^[\da-zA-Z]+$', part):
            res.append(part)
        else:
            res.extend(list(part))
    return res

# def replaceMissinfo(aaa):
#     '''replace missing info for subjects/objects
#     '''
#     placeholder = ['Z','Y','X']
#     for i in range(len(aaa)):
#         if aaa[i] == '_':
#             aaa = aaa[:i] + placeholder.pop() + aaa[i+1:]
#     return aaa

def load_preprocess_data(data_add):
    saoke = []
    with open(data_add, 'r') as f:
        for line in f:
            saoke.append(json.loads(line))
    data = []
    # list of dict
    for sample in saoke:
        # remove some exceptions with empty facts
        if sample['logic'] == []:
            continue
        # tokenize src sentence
        sample_processed = dict()
        sample_processed['src_org'] = sample['natural']
        #sample_processed['src'] = list(jieba.cut(sample['natural'], cut_all=False))
        sample_processed['src'] = character_segmentation(sample['natural'])
        
        # transform fact list into str and tokenize
        # $ separates facts; @ separate elements for one fact; & separate objects for one fact
        sample_processed['tgt_org'] = sample['logic']
        logic_list = []
        logic_set = set()
        for fact in sample['logic']:
            fact = '@'.join([fact['subject'], replaceMisspred(fact['predicate']), '&'.join(fact['object'])])
            if not fact in logic_set:
                logic_set.add(fact)
                logic_list.append(fact)
        sample_processed['tgt_list'] = logic_list #remove duplicates
        logic_str = '$'.join(logic_list)
        sample_processed['tgt'] = character_segmentation(logic_str)
        
        data.append(sample_processed)
    return data

In [5]:
import numpy as np
from collections import Counter
from itertools import dropwhile

class Lang:
    def __init__(self, name, emb_pretrained_add=None, max_vocab_size=None):
        self.name = name
        self.word2index = None #{"$PAD$": PAD_token, "$SOS$": SOS_token, "$EOS$": EOS_token, "$UNK$": UNK_token}
        #self.word2count = None #{"$PAD$": 0, "$SOS$" : 0, "$EOS$": 0, "$UNK$": 0}
        self.index2word = None #{PAD_token: "$PAD$", SOS_token: "$SOS$", EOS_token: "$EOS$", UNK_token: "$UNK$"}
        self.max_vocab_size = max_vocab_size  # Count SOS and EOS
        self.vocab_size = None
        self.emb_pretrained_add = emb_pretrained_add
        self.embedding_matrix = None

    def build_vocab(self, data):
        all_tokens = []
        for sample in data:
            all_tokens.extend(sample['src'])
            all_tokens.extend(sample['tgt'])  
        token_counter = Counter(all_tokens)
        print('The number of unique tokens totally in dataset: ', len(token_counter))
        # remove word with freq==1 
        for key, count in dropwhile(lambda key_count: key_count[1] > 1, token_counter.most_common()):
            del token_counter[key]
        
        if self.max_vocab_size:
            vocab, count = zip(*token_counter.most_common(self.max_vocab_size))
        else:
            vocab, count = zip(*token_counter.most_common())
        
        self.index2word = vocab_prefix + list(vocab)
        word2index = dict(zip(self.index2word, range(0, len(self.index2word)))) 
#         word2index = dict(zip(vocab, range(len(vocab_prefix),len(vocab_prefix)+len(vocab)))) 
#         for idx, token in enumerate(vocab_prefix):
#             word2index[token] = idx
        self.word2index = word2index
        self.vocab_size = len(self.index2word)
        return None 

    def build_emb_weight(self):
        words_emb_dict = load_emb_vectors(self.emb_pretrained_add)
        emb_weight = np.zeros([self.vocab_size, 300])
        for i in range(len(vocab_prefix), self.vocab_size):
            emb = words_emb_dict.get(self.index2word[i], None)
            if emb is not None:
                try:
                    emb_weight[i] = emb
                except:
                    pass
                    #print(len(emb), self.index2word[i], emb)
        self.embedding_matrix = emb_weight
        return None

def load_emb_vectors(fasttest_home):
    max_num_load = 500000
    words_dict = {}
    with open(fasttest_home) as f:
        for num_row, line in enumerate(f):
            if num_row >= max_num_load:
                break
            s = line.split()
            words_dict[s[0]] = np.asarray(s[1:])
    return words_dict

In [6]:
def text2index(data, key, word2index):
    '''
    transform tokens into index as input for both src and tgt
    '''
    indexdata = []
    for line in data:
        line = line[key]
        indexdata.append([word2index[c] if c in word2index.keys() else UNK_index for c in line])
        #indexdata[-1].append(EOS_index)
    print('finish indexing')
    return indexdata

def construct_Lang(name, data, emb_pretrained_add = None, max_vocab_size = None):
    lang = Lang(name, emb_pretrained_add, max_vocab_size)
    lang.build_vocab(data)
    if emb_pretrained_add:
        lang.build_emb_weight()
    return lang

def text2symbolindex(data, key, word2index):
    '''get generation label for tgt 
    '''
    indexdata = []
    for line in data:
        line = line[key]
        indexdata.append([word2index[c] if c in word2index.keys() else OOV_pred_index for c in line])
        #indexdata[-1].append(EOS_index)
    print('symbol label finish')
    return indexdata

def copy_indicator(data, src_key='src', tgt_key='tgt'):
    '''get copy label for tgt
    '''
    indicator = []
    for sample in data:
        tgt = sample[tgt_key]
        src = sample[src_key]
        matrix = np.zeros((len(tgt), len(src)), dtype=int)
        for m in range(len(tgt)):
            for n in range(len(src)):
                if tgt[m] == src[n]:
                    matrix[m,n] = 1
        indicator.append(matrix)
    return indicator

In [7]:
# data loader
from torch.utils.data import Dataset, DataLoader
from itertools import dropwhile

class VocabDataset(Dataset):
    """
    Class that represents a train/validation/test dataset that's readable for PyTorch
    Note that this class inherits torch.utils.data.Dataset
    """
    def __init__(self, src_index, tgt_index, tgt_symbolindex, tgt_indicator, data, src_clip=None, tgt_clip=None):
        """
        @param data_list: list of character
        @param target_list: list of targets

        """
        self.src_clip = src_clip
        self.tgt_clip = tgt_clip
        self.src_list, self.tgt_list = src_index, tgt_index
        self.data = data
        self.tgt_symbolindex, self.tgt_indicator  = tgt_symbolindex, tgt_indicator
        
        assert (len(self.src_list) == len(self.tgt_list) == len(self.tgt_symbolindex)== len(self.tgt_indicator))
        #self.word2index = word2index

    def __len__(self):
        return len(self.src_list)

    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        src = self.src_list[key]
        tgt = self.tgt_list[key]
        src_org = self.data[key]['src']
        tgt_org = self.data[key]['tgt']
        tgt_sym = self.tgt_symbolindex[key]
        tgt_ind = self.tgt_indicator[key]
        
        if self.src_clip is not None:
            src = src[:self.src_clip]
            src_org = src_org[:self.src_clip]
            tgt_ind = tgt_ind[:,:self.src_clip]
        src_length = len(src)

        if self.tgt_clip is not None:
            tgt = tgt[:self.tgt_clip]
            tgt_org = tgt_org[:self.tgt_clip]
            tgt_sym = tgt_sym[:self.tgt_clip]
            tgt_ind = tgt_ind[:self.tgt_clip,:]
        tgt_length = len(tgt)
        
        return src, src_length, tgt, tgt_length, tgt_sym, tgt_ind, src_org, tgt_org
        
        #return src_org, src_tensor, src_true_len, tgt_org, tgt_tensor, tgt_true_len, tgt_label_vocab, tgt_label_copy 

def vocab_collate_func(batch):
    """
    Customized function for DataLoader that dynamically pads the batch so that all
    data have the same length
    """
    src_list = []
    tgt_list = []
    src_length_list = []
    tgt_length_list = []
    tgt_symbol_list = []
    tgt_indicator_list = []
    src_org_list = []
    tgt_org_list = []
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for datum in batch:
        src_length_list.append(datum[1]) # 不用加1；eos不算
        tgt_length_list.append(datum[3]+1) 
    
    batch_max_src_length = np.max(src_length_list)
    batch_max_tgt_length = np.max(tgt_length_list)
    # padding
    for datum in batch:
        #+[EOS_index] -1
        padded_vec = np.pad(np.array(datum[0]), 
                                pad_width=((0, batch_max_src_length-datum[1])),
                                mode="constant", constant_values=PAD_index)
        src_list.append(padded_vec)
        
        padded_vec = np.pad(np.array(datum[2]+[EOS_index]),
                                pad_width=((0, batch_max_tgt_length-datum[3]-1)),
                                mode="constant", constant_values=PAD_index)
        tgt_list.append(padded_vec)
        
        padded_vec = np.pad(np.array(datum[4]+[EOS_pred_index]),
                                pad_width=((0, batch_max_tgt_length-datum[3]-1)),
                                mode="constant", constant_values=PAD_pred_index)
        tgt_symbol_list.append(padded_vec)
        
        indicator = np.pad(datum[5], pad_width=((0,1),(0,0)), 
                           mode='constant', constant_values=0)
        #indicator[-1,-1] = 1  -1
        padded_vec = np.pad(indicator,
                            pad_width=((0, batch_max_tgt_length-datum[3]-1),((0, batch_max_src_length-datum[1]))),
                            mode="constant", constant_values=0)
        #print(padded_vec.dtype, padded_vec.shape)
        tgt_indicator_list.append(padded_vec)
        
        src_org_list.append(datum[6])
        tgt_org_list.append(datum[7])
    
    # re-order
    ind_dec_order = np.argsort(src_length_list)[::-1]
    
    src_list = np.array(src_list)[ind_dec_order]
    src_length_list = np.array(src_length_list)[ind_dec_order]
    tgt_list = np.array(tgt_list)[ind_dec_order]
    tgt_length_list = np.array(tgt_length_list)[ind_dec_order]
    tgt_symbol_list = np.array(tgt_symbol_list)[ind_dec_order]
    #print(tgt_indicator_list[0].dtype, tgt_indicator_list[0][:5][:5])
    tgt_indicator_list = np.array(tgt_indicator_list)[ind_dec_order]
    #print(tgt_indicator_list.dtype, tgt_indicator_list.shape)
    src_org_list = [src_org_list[i] for i in ind_dec_order]
    tgt_org_list = [tgt_org_list[i] for i in ind_dec_order]
    
    #print(type(np.array(data_list)),type(np.array(label_list)))
    
    return [torch.from_numpy(src_list).to(device), 
            torch.LongTensor(src_length_list).to(device), 
            torch.from_numpy(tgt_list).to(device), 
            torch.LongTensor(tgt_length_list).to(device),
            torch.from_numpy(tgt_symbol_list).to(device),
            torch.from_numpy(tgt_indicator_list).to(device),
            src_org_list,
            tgt_org_list,           
           ]

In [8]:
# load data
data_add = '/scratch/tx443/NLU/project/SAOKE_DATA.json'
data = load_preprocess_data(data_add)

Building prefix dict from the default dictionary ...
Loading model from cache /state/partition1/job-844128/jieba.cache
Loading model cost 0.767 seconds.
Prefix dict has been built succesfully.


In [9]:
# split train val test
from sklearn.model_selection import train_test_split
train_data, val_test_data = train_test_split(data, test_size=0.3, random_state=42)
val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=45)

In [10]:
# sorted_train_data = sorted(train_data, key=lambda x: len(x['tgt']), reverse=False)
# train_data = sorted_train_data[0:3000]

# sorted_val_data = sorted(val_data, key=lambda x: len(x['tgt']), reverse=False)
# val_data = sorted_val_data[0:1000]

In [11]:
# build vocab from train for input indexing
trainLang = construct_Lang('train', train_data)

# build generation vocab for prediction
word2symbolindex = {}
for idx, token in enumerate(vocab_pred):
        word2symbolindex[token] = idx

# check
assert(UNK_index==trainLang.word2index['<UNK>'])
assert(PAD_index==trainLang.word2index['<PAD>'])
assert(SOS_index==trainLang.word2index['<SOS>'])
assert(EOS_index==trainLang.word2index['<EOS>'])

assert(OOV_pred_index==word2symbolindex['<OOV>'])
assert(PAD_pred_index==word2symbolindex['<PAD>'])
assert(EOS_pred_index==word2symbolindex['<EOS>'])

The number of unique tokens totally in dataset:  9364


In [12]:
# permute facts at this place; data['tgt']
def identity(facts_list):
    return facts_list

def reverse(facts_list):
    return facts_list[::-1]

def random_pm(facts_list):
    random_order = np.random.permutation(len(facts_list))
    return [facts_list[idx] for idx in random_order]

def last3_pm(facts_list):
    if len(facts_list) < 4:
        return facts_list
    else:
        return facts_list[-3:]+facts_list[:-3]

def permute_factOrder_tgt(data, pm_fn):
    data_len = len(data)
    for i in range(data_len):
        facts_list = data[i]['tgt_list']
        facts_list_pm = pm_fn(facts_list)
        data[i]['tgt'] = character_segmentation('$'.join(facts_list_pm))
    return None

# permute_factOrder_tgt(train_data, last3_pm)
# train_len = len(train_data)
# for i in range(train_len):
#     facts_list = train_data[i]['tgt_list']
#     facts_list_pm = facts_list[::-1]
#     train_data[i]['tgt'] = character_segmentation('$'.join(facts_list_pm))

In [13]:
# input indexing for src
start_time = time.time()
train_src_input_index = text2index(train_data, 'src', trainLang.word2index) 
val_src_input_index = text2index(val_data, 'src', trainLang.word2index) 
print(time.time()-start_time)

finish indexing
finish indexing
0.6389546394348145


In [14]:
# input indexing for tgt
train_tgt_input_index = text2index(train_data, 'tgt', trainLang.word2index) 
val_tgt_input_index = text2index(val_data, 'tgt', trainLang.word2index) 

finish indexing
finish indexing


In [15]:
# get generation label
train_label_symbolindex = text2symbolindex(train_data, 'tgt', word2symbolindex)
val_label_symbolindex = text2symbolindex(val_data, 'tgt', word2symbolindex)

symbol label finish
symbol label finish


In [16]:
# get copy label
start_time = time.time()
train_indicator = copy_indicator(train_data, 'src', 'tgt')
val_indicator = copy_indicator(val_data, 'src', 'tgt')
print(time.time()-start_time)

8.269481182098389


In [17]:
len(train_src_input_index),len(train_tgt_input_index),len(train_label_symbolindex),len(train_indicator),len(train_data)

(28564, 28564, 28564, 28564, 28564)

In [18]:
len(val_src_input_index),len(val_tgt_input_index),len(val_label_symbolindex),len(val_indicator),len(val_data)

(6121, 6121, 6121, 6121, 6121)

# Train

In [19]:
import time
import os
import torch.nn as nn
import torch
from torch import optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
# from Data_utils import VocabDataset, vocab_collate_func
# from preprocessing_util import preposs_toekn, Lang, text2index, construct_Lang
from config import device, embedding_freeze
import random
from evaluation import similarity_score, check_fact_same, predict_facts, evaluate_prediction
import pickle

In [20]:
def bridge(context):
    return State(context=context, batch_first=True)

def train(src_data, tgt_data, encoder, decoder, encoder_optimizer, decoder_optimizer, 
          teacher_forcing_ratio, vocab):
    src_org_batch, src_tensor, src_true_len = src_data
    tgt_org_batch, tgt_tensor, tgt_label_vocab, tgt_label_copy, tgt_true_len = tgt_data
    '''
    finish train for a batch
    ''' 
    encoder.train()
    decoder.train()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss = 0
    
    batch_size = src_tensor.size(0)
    encoder_context = encoder(src_tensor)
    state = bridge(encoder_context)
    
    decoder_input = torch.tensor([SOS_index]*batch_size, device=device).unsqueeze(1)
    step_log_likelihoods = []
    #print(decoder_hidden.size())
    #print('encoddddddddddder finishhhhhhhhhhhhhhh')
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        ### Teacher forcing: Feed the target as the next input
        decoding_token_index = 0
        tgt_max_len_batch = tgt_true_len.cpu().max().item()
        assert(tgt_max_len_batch==tgt_tensor.size(1))
        while decoding_token_index < tgt_max_len_batch:
            decoder_output, _ = decoder(decoder_input, state) # state update at each step
            #decoder_output = decoder_output.squeeze(1)

            decoding_label_vocab = tgt_label_vocab[:, decoding_token_index]
            decoding_label_copy = tgt_label_copy[:, decoding_token_index, :]
            copy_log_probs = decoder_output[:, vocab_pred_size:]+(decoding_label_copy.float()+1e-45).log()
            #mask sample which is copied only
            gen_mask = ((decoding_label_vocab!=OOV_pred_index) | (decoding_label_copy.sum(-1)==0)).float() 
            log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
            #mask log_prob value for oov_pred_index when label_vocab==oov_pred_index and is copied 
            generation_log_probs = decoder_output.gather(1, decoding_label_vocab.unsqueeze(1)) + log_gen_mask
            combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
            step_log_likelihood = torch.logsumexp(combined_gen_and_copy, dim=-1)
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))
            #loss += criterion(decoder_output, tgt_tensor[:,decoding_token_index])
            decoder_input = tgt_tensor[:,decoding_token_index].unsqueeze(1)  # Teacher forcing
            decoding_token_index += 1

    else:
        ### Without teacher forcing: use its own predictions as the next input
        decoding_token_index = 0
        tgt_max_len_batch = tgt_true_len.cpu().max().item()
        assert(tgt_max_len_batch==tgt_tensor.size(1))
        while decoding_token_index < tgt_max_len_batch:
            decoder_output, _ = decoder(decoder_input, state)
            #decoder_output = decoder_output.squeeze(1)
            
            decoding_label_vocab = tgt_label_vocab[:, decoding_token_index]
            decoding_label_copy = tgt_label_copy[:, decoding_token_index, :]
            copy_log_probs = decoder_output[:, vocab_pred_size:]+(decoding_label_copy.float()+1e-45).log()
            #mask sample which is copied only
            gen_mask = ((decoding_label_vocab!=OOV_pred_index)|(decoding_label_copy.sum(-1)==0)).float() 
            log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
            #mask log_prob value for oov_pred_index when label_vocab==oov_pred_index and is copied 
            generation_log_probs = decoder_output.gather(1, decoding_label_vocab.unsqueeze(1)) + log_gen_mask
            combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
            step_log_likelihood = torch.logsumexp(combined_gen_and_copy, dim=-1)
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

            topv, topi = decoder_output.topk(1, dim=-1)
            next_input = topi.detach().cpu().squeeze(1)
            decoder_input = []
            for i_batch in range(batch_size):
                pred_list = vocab_pred+src_org_batch[i_batch]
                next_input_token = pred_list[next_input[i_batch].item()]
                decoder_input.append(vocab.word2index.get(next_input_token, UNK_index))
            decoder_input = torch.tensor(decoder_input, device=device).unsqueeze(1)
            decoding_token_index += 1

    # average loss
    log_likelihoods = torch.cat(step_log_likelihoods, dim=-1)
    # mask padding for tgt
    tgt_pad_mask = sequence_mask(tgt_true_len).float()
    log_likelihoods = log_likelihoods*tgt_pad_mask
    loss = -log_likelihoods.sum()/batch_size
    loss.backward()

    ### TODO
    # clip for gradient exploding 
    encoder_optimizer.step()
    decoder_optimizer.step()

    return (loss*batch_size/tgt_pad_mask.sum()).item() #torch.div(loss, tgt_true_len.type_as(loss).mean()).item()  #/tgt_true_len.mean()


def trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
               teacher_forcing_ratio, tfr_decay_rate, model_save_info, tgt_max_len, 
               beam_size, vocab):
    
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    if model_save_info['model_path_for_resume'] is not None:
        check_point_state = torch.load(model_save_info['model_path_for_resume'])
        encoder.load_state_dict(check_point_state['encoder_state_dict'])
        encoder_optimizer.load_state_dict(check_point_state['encoder_optimizer_state_dict'])
        decoder.load_state_dict(check_point_state['decoder_state_dict'])
        decoder_optimizer.load_state_dict(check_point_state['decoder_optimizer_state_dict'])

    for epoch in range(num_epochs): 
        start_time = time.time()
        n_iter = -1
        losses = np.zeros((len(train_loader),))
        if tfr_decay_rate is not None:
            teacher_forcing_ratio *= tfr_decay_rate
        for src_tensor, src_true_len, tgt_tensor, tgt_true_len, tgt_label_vocab, tgt_label_copy, src_org_batch, tgt_org_batch in train_loader:
            n_iter += 1
            #print('start_step: ', n_iter)
            src_data = (src_org_batch, src_tensor, src_true_len)
            tgt_data = (tgt_org_batch, tgt_tensor, tgt_label_vocab, tgt_label_copy, tgt_true_len)
            loss = train(src_data, tgt_data, encoder, decoder, encoder_optimizer, 
                         decoder_optimizer, teacher_forcing_ratio, vocab)
            losses[n_iter] = loss
            if n_iter % 500 == 0:
                pass
                #print('Loss:', loss)
                #eva_start = time.time()
#                 precision, recall, val_loss = evaluate_batch(val_loader, encoder, decoder, tgt_max_len, vocab, vocab_pred_size)
#                 #print((time.time()-eva_start)/60)
#                 print('epoch: [{}/{}], step: [{}/{}], train_loss:{}, val_precision: {}, val_recall: {}, val_loss: {}'.format(
#                     epoch, num_epochs, n_iter, len(train_loader), loss, precision.mean(), recall.mean(), val_loss))
               # print('Decoder parameters grad:')
               # for p in decoder.named_parameters():
               #     print(p[0], ': ',  p[1].grad.data.abs().mean().item(), p[1].grad.data.abs().max().item(), p[1].data.abs().mean().item(), p[1].data.abs().max().item(), end=' ')
               # print('\n')
               # print('Encoder Parameters grad:')
               # for p in encoder.named_parameters():
               #     print(p[0], ': ',  p[1].grad.data.abs().mean().item(), p[1].grad.data.abs().max().item(), p[1].data.abs().mean().item(), p[1].data.abs().max().item(), end=' ')
               # print('\n')
        val_loss, src_org, tgt_org, tgt_pred = predict_facts(val_loader, encoder, decoder, tgt_max_len, vocab)
        precision, recall = evaluate_prediction(tgt_org, tgt_pred)
        epoch_train_time = (time.time()-start_time)/60
        print_str = 'epoch: [{}/{}]({}m), step: [{}/{}], train_loss:{}, val_precision: {}, val_recall: {}, val_loss: {}'.format(
            epoch, num_epochs, epoch_train_time, n_iter, len(train_loader), losses.mean(), precision.mean(), recall.mean(), val_loss)
        print_info.append(print_str)
        print(print_str)
        
        if (epoch+1) % model_save_info['epochs_per_save_model'] == 0:
            check_point_state = {
                'epoch': epoch,
                'encoder_state_dict': encoder.state_dict(),
                'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'decoder_optimizer_state_dict': decoder_optimizer.state_dict()
                }
            torch.save(check_point_state, '{}epoch_{}.pth'.format(model_save_info['model_path'], epoch))

    return None

In [22]:
paras = dict( 
    tgt_max_len = 130,
    max_src_len_dataloader =94,
    max_tgt_len_dataloader =127,
    
    hidden_size=512, 
    emb_size=300,
    num_layers=2, 
    num_heads=8, 
    inner_linear=2048,  
    prenormalized=False,
    dropout=0.0,
    layer_norm=True,
    weight_norm=False,
    stateful=None,
    classifier_type='copy', 

    teacher_forcing_ratio = 1,
    tfr_decay_rate = None, #'None means no decay'

    learning_rate = 1e-4,
    num_epochs = 30,
    batch_size = 64, 
    beam_size = 5,

    model_save_info = dict(
        model_path = 'nmt_models/T2/round3/',
        epochs_per_save_model = 1,
        model_path_for_resume = 'nmt_models/T2/round3/epoch_11.pth' #'nmt_models/epoch_0.pth'
        )
    )

In [23]:
tgt_max_len = paras['tgt_max_len']
max_src_len_dataloader = paras['max_src_len_dataloader']
max_tgt_len_dataloader = paras['max_tgt_len_dataloader']

hidden_size = paras['hidden_size']
emb_size = paras['emb_size']
num_layers = paras['num_layers']
num_heads= paras['num_heads'] 
inner_linear = paras['inner_linear']  
prenormalized = paras['prenormalized']
dropout = paras['dropout']
layer_norm = paras['layer_norm']
weight_norm = paras['weight_norm']
stateful = paras['stateful']
classifier_type = paras['classifier_type']

teacher_forcing_ratio = paras['teacher_forcing_ratio']
tfr_decay_rate = paras['tfr_decay_rate']

learning_rate = paras['learning_rate']
num_epochs = paras['num_epochs']
batch_size = paras['batch_size']
beam_size = paras['beam_size']
model_save_info = paras['model_save_info']

In [24]:
train_dataset = VocabDataset(train_src_input_index, train_tgt_input_index, 
                             train_label_symbolindex, train_indicator, train_data, 
                             max_src_len_dataloader, max_tgt_len_dataloader)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=vocab_collate_func,
                                               shuffle=True)

val_dataset = VocabDataset(val_src_input_index, val_tgt_input_index, 
                           val_label_symbolindex, val_indicator, val_data,
                           max_src_len_dataloader, max_tgt_len_dataloader)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=batch_size,
                                               collate_fn=vocab_collate_func,
                                               shuffle=False)

In [26]:
# make dir for saving models
from seq2seq.models.transformer import TransformerAttentionEncoder, TransformerAttentionDecoder, sequence_mask
from seq2seq.models.modules.state import State

if not os.path.exists(model_save_info['model_path']):
    os.makedirs(model_save_info['model_path'])
### save model hyperparameters
with open(model_save_info['model_path']+'model_params.pkl', 'wb') as f:
    model_hyparams = paras
    pickle.dump(model_hyparams, f)
print(model_hyparams)

encoder = TransformerAttentionEncoder(vocab_size=trainLang.vocab_size, hidden_size=hidden_size,
                                      embedding_size=emb_size, num_layers=num_layers,
                                      num_heads=num_heads, inner_linear=inner_linear,
                                      prenormalized=prenormalized, layer_norm=layer_norm,
                                      weight_norm=weight_norm, dropout=dropout
                                     )

decoder = TransformerAttentionDecoder(vocab_size=trainLang.vocab_size, hidden_size=hidden_size,
                                      embedding_size=emb_size, num_layers=num_layers, 
                                      num_heads=num_heads, dropout=dropout,
                                      inner_linear=inner_linear, prenormalized=prenormalized,
                                      stateful=stateful, layer_norm=layer_norm,
                                      weight_norm=weight_norm,
                                      classifier_type=classifier_type
                                     )
encoder, decoder = encoder.to(device), decoder.to(device)
print('Encoder:')
print(encoder)
print('Decoder:')
print(decoder)

{'tgt_max_len': 130, 'max_src_len_dataloader': 94, 'max_tgt_len_dataloader': 127, 'hidden_size': 512, 'emb_size': 300, 'num_layers': 2, 'num_heads': 8, 'inner_linear': 2048, 'prenormalized': False, 'dropout': 0.0, 'layer_norm': True, 'weight_norm': False, 'stateful': None, 'classifier_type': 'copy', 'teacher_forcing_ratio': 1, 'tfr_decay_rate': None, 'learning_rate': 0.0001, 'num_epochs': 30, 'batch_size': 64, 'beam_size': 5, 'model_save_info': {'model_path': 'nmt_models/T2/round3/', 'epochs_per_save_model': 1, 'model_path_for_resume': 'nmt_models/T2/round3/epoch_11.pth'}}
Encoder:
TransformerAttentionEncoder(
  (embedder): Embedding(8776, 300, padding_idx=0)
  (dropout): Dropout(p=0.0, inplace)
  (blocks): ModuleList(
    (0): EncoderBlock(
      (lnorm1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
      (lnorm2): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0)
      (attention): MultiHeadAttention(
        (linear

In [95]:
# learning_rate = 1e-4
print_info = []
trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
               teacher_forcing_ratio, tfr_decay_rate, model_save_info, tgt_max_len, 
               beam_size, trainLang)

epoch: [0/20](32.38282967805863m), step: [446/447], train_loss:1.9578206757837762, val_precision: 0.026465310724418708, val_recall: 0.029005142328128765, val_loss: 0
epoch: [1/20](32.314728510379794m), step: [446/447], train_loss:1.775627501858961, val_precision: 0.03538055774209997, val_recall: 0.03538456542459156, val_loss: 0
epoch: [2/20](32.23568295240402m), step: [446/447], train_loss:1.658952917828656, val_precision: 0.041918342620121445, val_recall: 0.0454709062289523, val_loss: 0
epoch: [3/20](32.30389209588369m), step: [446/447], train_loss:1.5639776212790402, val_precision: 0.048395390573956164, val_recall: 0.05098580738182111, val_loss: 0
epoch: [4/20](32.358332149187724m), step: [446/447], train_loss:1.4770425835445156, val_precision: 0.051291234043235344, val_recall: 0.053251951446690864, val_loss: 0
epoch: [5/20](32.289966630935666m), step: [446/447], train_loss:1.4012032376573123, val_precision: 0.0535907221171067, val_recall: 0.05719012422165502, val_loss: 0
epoch: [6/2

In [102]:
# learning_rate = 1e-4
print_info = []
trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
               teacher_forcing_ratio, tfr_decay_rate, model_save_info, tgt_max_len, 
               beam_size, trainLang)

epoch: [0/30](32.2570640206337m), step: [446/447], train_loss:0.6397109793069912, val_precision: 0.07990414212845187, val_recall: 0.08063335174509818, val_loss: 0
epoch: [1/30](32.25934288104375m), step: [446/447], train_loss:0.5995288051214794, val_precision: 0.07360357607866184, val_recall: 0.07953000728690976, val_loss: 0
epoch: [2/30](32.260063874721524m), step: [446/447], train_loss:0.5735238673836326, val_precision: 0.0737824420310942, val_recall: 0.08070097549349305, val_loss: 0
epoch: [3/30](32.282189428806305m), step: [446/447], train_loss:0.545445200847566, val_precision: 0.0811959378819369, val_recall: 0.08220211190722544, val_loss: 0
epoch: [4/30](32.31865317424138m), step: [446/447], train_loss:0.5276610225905775, val_precision: 0.07864370422367481, val_recall: 0.07758770164831266, val_loss: 0
epoch: [5/30](32.33748585383098m), step: [446/447], train_loss:0.5012625121163575, val_precision: 0.08142064510304994, val_recall: 0.08601231584404267, val_loss: 0
epoch: [6/30](32.2

KeyboardInterrupt: 

In [28]:
check_point_state = torch.load(model_save_info['model_path_for_resume'])
encoder.load_state_dict(check_point_state['encoder_state_dict'])
# encoder_optimizer.load_state_dict(check_point_state['encoder_optimizer_state_dict'])
decoder.load_state_dict(check_point_state['decoder_state_dict'])
# decoder_optimizer.load_state_dict(check_point_state['decoder_optimizer_state_dict'])

In [53]:
trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
           teacher_forcing_ratio, tfr_decay_rate, model_save_info, tgt_max_len, 
           beam_size, trainLang)

epoch: [0/7](5.341431697209676m), step: [446/447], train_loss:1.7722327897212649, val_precision: 0.17783911300085128, val_recall: 0.16988360982806333, val_loss: 0
epoch: [1/7](5.338186713059743m), step: [446/447], train_loss:0.7802168484235503, val_precision: 0.2634523226052388, val_recall: 0.25105687302223817, val_loss: 0
epoch: [2/7](5.354854818185171m), step: [446/447], train_loss:0.6185777706054499, val_precision: 0.30115507892423427, val_recall: 0.2845378305188793, val_loss: 0
epoch: [3/7](5.347332378228505m), step: [446/447], train_loss:0.5228940870537854, val_precision: 0.334055281972289, val_recall: 0.3126340658198198, val_loss: 0
epoch: [4/7](5.331371068954468m), step: [446/447], train_loss:0.4618112228860791, val_precision: 0.354471335993963, val_recall: 0.3192828346479711, val_loss: 0
epoch: [5/7](5.336230289936066m), step: [446/447], train_loss:0.4151985651294657, val_precision: 0.3551781143759579, val_recall: 0.32712683584273194, val_loss: 0
epoch: [6/7](5.339213335514069m

In [68]:
trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
           teacher_forcing_ratio, tfr_decay_rate, model_save_info, tgt_max_len, 
           beam_size, trainLang)

epoch: [0/7](5.338468774159749m), step: [446/447], train_loss:1.7606931774141539, val_precision: 0.20061614582117768, val_recall: 0.17853682105438354, val_loss: 0
epoch: [1/7](5.322386189301809m), step: [446/447], train_loss:0.7401965295175045, val_precision: 0.2946695995830124, val_recall: 0.2769553954368527, val_loss: 0
epoch: [2/7](5.345347181955973m), step: [446/447], train_loss:0.570334336968343, val_precision: 0.32535961288616083, val_recall: 0.2968977237070686, val_loss: 0
epoch: [3/7](5.349120012919108m), step: [446/447], train_loss:0.48958232072109076, val_precision: 0.3423187283953498, val_recall: 0.31296919058722683, val_loss: 0
epoch: [4/7](5.362560017903646m), step: [446/447], train_loss:0.4364301985008871, val_precision: 0.3684664166815776, val_recall: 0.34584178777529534, val_loss: 0
epoch: [5/7](5.36743247906367m), step: [446/447], train_loss:0.3962817492767735, val_precision: 0.3590141148219893, val_recall: 0.3499998172968276, val_loss: 0
epoch: [6/7](5.340313796202341

In [152]:
# permute_factOrder_tgt(val_data, reverse)

In [29]:
loader = train_loader
tgt_max_length = tgt_max_len
loss, src_org, tgt_org, tgt_pred = predict_facts(loader, encoder, decoder, tgt_max_length, trainLang)
precision, recall = evaluate_prediction(tgt_org, tgt_pred)

In [30]:
print(precision.mean(), recall.mean())

0.510060675022 0.494971937556


In [151]:
print(precision.mean(), recall.mean())

0.340255745725 0.334618997877


In [55]:
val_dataset1 = VocabDataset(val_src_input_index, val_tgt_input_index, 
                             val_label_symbolindex, val_indicator, val_data)
val_loader1 = torch.utils.data.DataLoader(dataset=val_dataset1,
                                               batch_size=batch_size,
                                               collate_fn=vocab_collate_func,
                                               shuffle=False)

train_dataset1 = VocabDataset(train_src_input_index, train_tgt_input_index, 
                             train_label_symbolindex, train_indicator, train_data, 
                             max_src_len_dataloader, max_tgt_len_dataloader)

train_loader1 = torch.utils.data.DataLoader(dataset=train_dataset1,
                                               batch_size=batch_size,
                                               collate_fn=vocab_collate_func,
                                               shuffle=False)

In [26]:
def similarity_score(fact1, fact2):
    elem1 = fact1.split('@')
    elem2 = fact2.split('@')
    n1 = len(elem1)
    n2 = len(elem2)
    sim = 0
    for i in range(min(n1,n2)):
        sim += difflib.SequenceMatcher(None,elem1[i],elem2[i]).ratio()
    return sim/max(n1,n2)

def check_fact_same(org_fact, pred_fact):
    org_fact_ele = org_fact.split('@')
    pred_fact_ele = pred_fact.split('@')
    if len(org_fact_ele) == len(pred_fact_ele):
        ele_num = len(org_fact_ele)
        if difflib.SequenceMatcher(None,org_fact,pred_fact).ratio() > 0.85:
            return True       
        ele_sim = np.zeros((ele_num,))
        for ele_i in range(ele_num):
            ele_sim[ele_i] = difflib.SequenceMatcher(None,org_fact_ele[ele_i],pred_fact_ele[ele_i]).ratio()
        if ele_sim.min() > 0.85:
            return True
    return False

In [67]:
tgt_max_length = 130
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    
encoder.eval()
decoder.eval()

tgt_pred = []
src_org = []
tgt_org = []
loss = 0
loader = val_loader

for src_tensor, src_true_len, tgt_tensor, tgt_true_len, tgt_label_vocab, tgt_label_copy, src_org_batch, tgt_org_batch in loader:
#     src_tensor, tgt_tensor, tgt_true_len = src_tensor.to(device), tgt_tensor.to(device), tgt_true_len.to(device)
#     tgt_label_vocab, tgt_label_copy = tgt_label_vocab.to(device), tgt_label_copy.to(device)
    
    batch_size = src_tensor.size(0)
    encoder_context = encoder(src_tensor)
    state = bridge(encoder_context)

    decoder_input = torch.tensor([SOS_index]*batch_size, device=device).unsqueeze(1)

    decoding_token_index = 0
    stop_flag = [False]*batch_size
    step_log_likelihoods = []
    tgt_pred_batch = [[] for i_batch in range(batch_size)]
    tgt_true_len_max = tgt_true_len.cpu().numpy().max()
    while decoding_token_index < tgt_max_length:
        decoder_output, _ = decoder(decoder_input, state)
        # compute loss 
        if decoding_token_index < tgt_true_len_max:
            decoding_label_vocab = tgt_label_vocab[:, decoding_token_index]
            decoding_label_copy = tgt_label_copy[:, decoding_token_index, :]
            copy_log_probs = decoder_output[:, vocab_pred_size:]+(decoding_label_copy.float()+1e-45).log()
            #mask sample which is copied only
            gen_mask = ((decoding_label_vocab!=OOV_pred_index) | (decoding_label_copy.sum(-1)==0)).float() 
            log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
            #mask log_prob value for oov_pred_index when label_vocab==oov_pred_index and is copied 
            generation_log_probs = decoder_output.gather(1, decoding_label_vocab.unsqueeze(1)) + log_gen_mask
            combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
            step_log_likelihood = torch.logsumexp(combined_gen_and_copy, dim=-1)
            step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        #
        topv, topi = decoder_output.topk(1, dim=-1)
        next_input = topi.detach().cpu().squeeze(1)
        decoder_input = []
        for i_batch in range(batch_size):
            pred_list = vocab_pred+src_org_batch[i_batch]
            next_input_token = pred_list[next_input[i_batch].item()]
            if next_input_token == vocab_pred[EOS_pred_index]:
                stop_flag[i_batch] = True
            if not stop_flag[i_batch]:
                tgt_pred_batch[i_batch].append(next_input_token)
            decoder_input.append(trainLang.word2index.get(next_input_token, UNK_index))
        decoder_input = torch.tensor(decoder_input, device=device).unsqueeze(1)
        decoding_token_index += 1
        if all(stop_flag):
            break
    log_likelihoods = torch.cat(step_log_likelihoods, dim=-1)
    # mask padding for tgt
    tgt_pad_mask = sequence_mask(tgt_true_len).float()
    log_likelihoods = log_likelihoods*tgt_pad_mask[:,:log_likelihoods.size(1)]
    loss += -(log_likelihoods.sum()/tgt_pad_mask.sum()).item()
    tgt_pred.extend(tgt_pred_batch)
    src_org.extend(src_org_batch)
    tgt_org.extend(tgt_org_batch)
loss = loss/len(loader)

In [68]:
len(tgt_pred)

1000

In [69]:
import difflib
from scipy.optimize import linear_sum_assignment

eval_len = len(tgt_pred)
precision = np.zeros((eval_len,))
recall = np.zeros((eval_len,))
Fscore = np.zeros((eval_len,))
for i in range(eval_len):
    org_facts = ''.join(tgt_org[i]).split('$')
    pred_facts = ''.join(tgt_pred[i]).split('$')
    pred_facts = list(set(pred_facts))
    org_facts_num = len(org_facts)
    pred_facts_num = len(pred_facts)
    org_match_num = np.zeros((org_facts_num))
    pred_match_num = np.zeros((pred_facts_num))
    similarity_ma = np.zeros((org_facts_num, pred_facts_num))
    for org_i in range(org_facts_num):
        for pred_i in range(pred_facts_num):
            similarity_ma[org_i, pred_i] = similarity_score(org_facts[org_i], pred_facts[pred_i])
    row_ind, col_ind = linear_sum_assignment(-similarity_ma)
    
    for org_i, pred_i in zip(row_ind, col_ind):
        org_fact = org_facts[org_i]
        pred_fact = pred_facts[pred_i]
        fact_same = check_fact_same(org_fact, pred_fact)
        if fact_same:
            org_match_num[org_i] = 1
            pred_match_num[pred_i] = 1
#     print(pred_match_num)
#     print(org_match_num)
    precision[i] = pred_match_num.mean()
    recall[i] = org_match_num.mean()
    Fscore[i] = 2*precision[i]*recall[i]/(precision[i]+recall[i]+1e-10)
if False:
    random_sample = np.random.randint(eval_len)
    print('src: ', src_org[random_sample])
    print('Ref: ', tgt_org[random_sample])
    print('pred: ', tgt_pred[random_sample])

In [70]:
print(precision.mean(), recall.mean(), Fscore.mean())

0.034 0.033 0.0333333333316


In [47]:
len(precision)

1000

In [71]:
random_sample = 300
print('src: ', ''.join(src_org[random_sample]))
print('Ref: ', ''.join(tgt_org[random_sample]))
print('pred: ', ''.join(tgt_pred[random_sample]))

src:  这从"嘉德本"可以看得很清楚。
Ref:  这@可以看得@很清楚
pred:  _@嘉@很清本


In [64]:
train_data[5]

{'src': ['067', '章', ' ', '矿', '洞', '抓', '人', '（', '求', '推', '荐', '票', '）'],
 'src_org': '067章 矿洞抓人（求推荐票）',
 'tgt': ['_', '@', '抓', '@', '人'],
 'tgt_list': ['_@抓@人'],
 'tgt_org': [{'object': ['人'],
   'place': '矿洞',
   'predicate': '抓',
   'qualifier': '_',
   'subject': '_',
   'time': '_'}]}

In [169]:
len(src_org[random_sample])

94

In [172]:
t1 = '企业@大力开发@自主知识产权'
t2= '_@大力开发@自主知识产权的新产品'
difflib.SequenceMatcher(None, t1, t2).ratio()

0.7741935483870968

In [175]:
t1_ele = t1.split('@')
t2_ele = t2.split('@')
ele_num = len(t1_ele)
ele_sim = np.zeros((ele_num,))
for ele_i in range(ele_num):
    ele_sim[ele_i] = difflib.SequenceMatcher(None,t1_ele[ele_i],t2_ele[ele_i]).ratio()
print(ele_sim)

[ 0.    1.    0.75]
