In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pandas as pd

import jieba
import re

from config import vocab_pred, vocab_pred_size, vocab_prefix
from config import UNK_index, PAD_index, SOS_index, EOS_index 
from config import OOV_pred_index, PAD_pred_index, EOS_pred_index

In [3]:
import os
# NOTE(review): the path is relative to the current working directory, so
# re-running this cell on a live kernel tries to enter a nested
# 'Machine_Translation_NLP' directory — presumably unintended; confirm.
os.chdir('Machine_Translation_NLP')

In [5]:
def replaceMisspred(predicate):
    '''Substitute the placeholder 'P' for a missing predicate.

    A predicate equal to '_' is treated as missing; any other value is
    returned unchanged.
    '''
    return 'P' if predicate == '_' else predicate
    
def character_segmentation(string):
    '''Tokenize text into characters while keeping alphanumeric runs whole.

    The string is first segmented with jieba (accurate mode). A segment made
    only of digits and ASCII letters is kept as a single token; every other
    segment is split into its individual characters.

    Returns:
        list of str tokens in their original order.
    '''
    res = []
    for part in jieba.cut(string, cut_all=False):
        # Raw string: '\d' inside a plain literal is an invalid escape sequence.
        if re.match(r'^[\da-zA-Z]+$', part):
            res.append(part)
        else:
            # extending with a str appends one character at a time
            res.extend(part)
    return res

# def replaceMissinfo(aaa):
#     '''replace missing info for subjects/objects
#     '''
#     placeholder = ['Z','Y','X']
#     for i in range(len(aaa)):
#         if aaa[i] == '_':
#             aaa = aaa[:i] + placeholder.pop() + aaa[i+1:]
#     return aaa

def load_preprocess_data(data_add):
    '''Read a JSON-lines file and tokenize every sample.

    Each non-blank line of the file at ``data_add`` is parsed as one JSON
    object with the keys 'natural' (source sentence) and 'logic' (list of
    facts, each with 'subject', 'predicate' and 'object'). Samples whose
    'logic' list is empty are dropped.

    Returns:
        list of dict with the keys
        'src_org' - the original sentence,
        'src'     - tokens of the sentence (see character_segmentation),
        'tgt_org' - the original fact list,
        'tgt'     - tokens of the serialized facts.
    '''
    data = []
    # Explicit UTF-8 so the text decodes the same regardless of the platform's
    # default locale encoding.
    with open(data_add, 'r', encoding='utf-8') as f:
        for line in f:
            # tolerate blank lines (e.g. a trailing newline at end of file)
            if not line.strip():
                continue
            sample = json.loads(line)
            # remove some exceptions with empty facts
            if sample['logic'] == []:
                continue

            sample_processed = dict()
            # tokenize src sentence
            sample_processed['src_org'] = sample['natural']
            sample_processed['src'] = character_segmentation(sample['natural'])

            # transform fact list into str and tokenize
            # '$' separates facts; '@' separates the elements of one fact;
            # '&' separates the objects of one fact
            sample_processed['tgt_org'] = sample['logic']
            logic_list = []
            for fact in sample['logic']:
                logic_list.append('@'.join([fact['subject'],
                                            replaceMisspred(fact['predicate']),
                                            '&'.join(fact['object'])]))
            logic_str = '$'.join(logic_list)
            sample_processed['tgt'] = character_segmentation(logic_str)

            data.append(sample_processed)
    return data

In [None]:
import numpy as np
from collections import Counter
from itertools import dropwhile

class Lang:
    '''Vocabulary (and optional pretrained embedding matrix) for a dataset.

    Attributes are filled in by build_vocab / build_emb_weight:
        word2index, index2word - token <-> index maps, special tokens first
        vocab_size             - len(index2word)
        embedding_matrix       - (vocab_size, emb_dim) numpy array or None
    '''

    def __init__(self, name, emb_pretrained_add=None, max_vocab_size=None):
        self.name = name
        self.word2index = None
        self.index2word = None
        # Cap on the number of corpus tokens kept; the special tokens in
        # vocab_prefix are added on top of this number.
        self.max_vocab_size = max_vocab_size
        self.vocab_size = None
        self.emb_pretrained_add = emb_pretrained_add
        self.embedding_matrix = None

    def build_vocab(self, data):
        '''Build the token/index maps from the 'src' and 'tgt' tokens of data.

        Tokens that occur only once are discarded. The remaining tokens are
        ordered by decreasing frequency and truncated to max_vocab_size when
        that is set. An empty result (no token seen more than once) yields a
        vocabulary made of the special tokens only.
        '''
        token_counter = Counter()
        for sample in data:
            token_counter.update(sample['src'])
            token_counter.update(sample['tgt'])
        print('The number of unique tokens totally in dataset: ', len(token_counter))

        # remove word with freq==1; most_common() keeps frequency order
        vocab = [token for token, count in token_counter.most_common() if count > 1]
        if self.max_vocab_size:
            vocab = vocab[:self.max_vocab_size]

        self.index2word = vocab_prefix + vocab
        self.word2index = {token: idx for idx, token in enumerate(self.index2word)}
        self.vocab_size = len(self.index2word)
        return None

    def build_emb_weight(self, emb_dim=300):
        '''Fill embedding_matrix from the pretrained vectors file.

        Rows of the special tokens, of tokens without a pretrained vector and
        of tokens whose vector is not numeric or not of length emb_dim stay
        all-zero.

        Args:
            emb_dim: dimensionality of the pretrained vectors (default 300).
        '''
        words_emb_dict = load_emb_vectors(self.emb_pretrained_add)
        emb_weight = np.zeros([self.vocab_size, emb_dim])
        for i in range(len(vocab_prefix), self.vocab_size):
            emb = words_emb_dict.get(self.index2word[i], None)
            if emb is None:
                continue
            try:
                vector = np.asarray(emb, dtype=float)
            except ValueError:
                # non-numeric fields, e.g. a multi-word entry in the vectors file
                continue
            # Explicit shape check: a 1-element vector would otherwise be
            # silently broadcast across the whole row.
            if vector.shape != (emb_dim,):
                continue
            emb_weight[i] = vector
        self.embedding_matrix = emb_weight
        return None

def load_emb_vectors(fasttest_home):
    '''Load word vectors from a whitespace-separated text file.

    Each line is expected to hold a word followed by its vector components.
    Only the first 500000 rows of the file are read. Blank lines and lines
    whose components are not numeric are skipped.

    Args:
        fasttest_home: path of the vectors file (read as UTF-8).

    Returns:
        dict mapping word -> 1-D float64 numpy array.
    '''
    max_num_load = 500000
    words_dict = {}
    with open(fasttest_home, encoding='utf-8') as f:
        for num_row, line in enumerate(f):
            if num_row >= max_num_load:
                break
            s = line.split()
            if not s:
                # blank line: nothing to index
                continue
            try:
                # Parse to floats once here instead of keeping an array of strings.
                vector = np.asarray(s[1:], dtype=float)
            except ValueError:
                # e.g. a word containing whitespace shifts text into the vector fields
                continue
            words_dict[s[0]] = vector
    return words_dict

In [None]:
def text2index(data, key, word2index):
    '''Map the tokens stored under ``key`` of every sample to vocabulary indices.

    Used for both src and tgt inputs. Tokens missing from ``word2index`` are
    replaced by UNK_index. Returns one list of indices per sample.
    '''
    encoded = []
    for sample in data:
        token_ids = []
        for token in sample[key]:
            if token in word2index:
                token_ids.append(word2index[token])
            else:
                token_ids.append(UNK_index)
        encoded.append(token_ids)
    print('finish indexing')
    return encoded

def construct_Lang(name, data, emb_pretrained_add = None, max_vocab_size = None):
    '''Create a Lang, build its vocabulary from ``data`` and return it.

    The pretrained embedding matrix is only built when ``emb_pretrained_add``
    is given (truthy).
    '''
    language = Lang(name, emb_pretrained_add, max_vocab_size)
    language.build_vocab(data)
    if not emb_pretrained_add:
        return language
    language.build_emb_weight()
    return language

def text2symbolindex(data, key, word2index):
    '''Build the generation labels for tgt.

    Every token stored under ``key`` is looked up in ``word2index`` (the
    generation vocabulary); tokens that are not in it are labelled with
    OOV_pred_index. Returns one list of labels per sample.
    '''
    labels = []
    for sample in data:
        sample_labels = []
        for token in sample[key]:
            if token in word2index:
                sample_labels.append(word2index[token])
            else:
                sample_labels.append(OOV_pred_index)
        labels.append(sample_labels)
    print('symbol label finish')
    return labels

def copy_indicator(data, src_key='src', tgt_key='tgt'):
    '''Build the copy labels for tgt.

    For every sample an integer matrix of shape (len(tgt), len(src)) is
    produced whose entry [m, n] is 1 when target token m equals source token
    n, and 0 otherwise. Returns the list of matrices, one per sample.
    '''
    matrices = []
    for sample in data:
        tgt_tokens = sample[tgt_key]
        src_tokens = sample[src_key]
        grid = np.zeros((len(tgt_tokens), len(src_tokens)), dtype=int)
        for row, tgt_token in enumerate(tgt_tokens):
            for col, src_token in enumerate(src_tokens):
                if tgt_token == src_token:
                    grid[row, col] = 1
        matrices.append(grid)
    return matrices

In [93]:
# load data
# The dataset location can be overridden with the SAOKE_DATA_PATH environment
# variable; the original cluster path stays the default.
data_add = os.environ.get('SAOKE_DATA_PATH', '/scratch/tx443/NLU/project/SAOKE_DATA.json')
data = load_preprocess_data(data_add)

In [94]:
# split train val test
# 70% of the samples go to train; the remaining 30% is halved into validation
# and test. The fixed random_state values make the split reproducible.
from sklearn.model_selection import train_test_split
train_data, val_test_data = train_test_split(data, test_size=0.3, random_state=42)
val_data, test_data = train_test_split(val_test_data, test_size=0.5, random_state=45)

In [95]:
# sorted_train_data = sorted(train_data, key=lambda x: len(x['tgt']), reverse=False)
# train_data = sorted_train_data[0:3000]

In [96]:
# build vocab from train for input indexing
trainLang = construct_Lang('train', train_data)

# build generation vocab for prediction: token -> its position in vocab_pred
word2symbolindex = {token: idx for idx, token in enumerate(vocab_pred)}

# check that the special-token indices from config match the input vocabulary
assert UNK_index == trainLang.word2index['<UNK>']
assert PAD_index == trainLang.word2index['<PAD>']
assert SOS_index == trainLang.word2index['<SOS>']
assert EOS_index == trainLang.word2index['<EOS>']

# ... and the generation vocabulary
assert OOV_pred_index == word2symbolindex['<OOV>']
assert PAD_pred_index == word2symbolindex['<PAD>']
assert EOS_pred_index == word2symbolindex['<EOS>']

The number of unique tokens totally in dataset:  9364


In [97]:
# permute facts at this place; data['tgt']
# NOTE(review): facts_list_pm (the reversed fact order) is computed but never
# used — the join below takes the un-reversed facts_list, so the target is
# re-serialized in its original order. Confirm whether the permutation was
# meant to be applied.
train_len = len(train_data)
for i in range(train_len):
    # rebuild the serialized fact string and split it back into facts on '$'
    facts_list = ''.join(train_data[i]['tgt']).split('$')
    facts_list_pm = facts_list[::-1]
    train_data[i]['tgt'] = character_segmentation('$'.join(facts_list))

In [98]:
# input indexing for src
# Validation text is indexed with the train vocabulary, so validation tokens
# that are absent from it become UNK_index.
train_src_input_index = text2index(train_data, 'src', trainLang.word2index) 
val_src_input_index = text2index(val_data, 'src', trainLang.word2index) 

finish indexing
finish indexing


In [99]:
# input indexing for tgt
# Targets share the input vocabulary (trainLang) with the sources.
train_tgt_input_index = text2index(train_data, 'tgt', trainLang.word2index) 
val_tgt_input_index = text2index(val_data, 'tgt', trainLang.word2index) 

finish indexing
finish indexing


In [100]:
# get generation label
# Target tokens outside the generation vocabulary are labelled OOV_pred_index.
train_label_symbolindex = text2symbolindex(train_data, 'tgt', word2symbolindex)
val_label_symbolindex = text2symbolindex(val_data, 'tgt', word2symbolindex)

symbol label finish
symbol label finish


In [101]:
# get copy label
# One (len(tgt), len(src)) 0/1 matrix per sample, marking which source
# positions hold the same token as each target position.
train_indicator = copy_indicator(train_data, 'src', 'tgt')
val_indicator = copy_indicator(val_data, 'src', 'tgt')

In [102]:
# Sanity check: every per-sample train list should have the same length.
len(train_src_input_index),len(train_tgt_input_index),len(train_label_symbolindex),len(train_indicator),len(train_data)

(28564, 28564, 28564, 28564, 28564)

In [103]:
# Sanity check: every per-sample validation list should have the same length.
len(val_src_input_index),len(val_tgt_input_index),len(val_label_symbolindex),len(val_indicator),len(val_data)

(6121, 6121, 6121, 6121, 6121)

# Train

In [104]:
import time
import os
import torch.nn as nn
import torch
from torch import optim
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
# from Data_utils import VocabDataset, vocab_collate_func
# from preprocessing_util import preposs_toekn, Lang, text2index, construct_Lang
from Multilayers_Encoder import EncoderRNN
from Multilayers_Decoder import DecoderAtten, sequence_mask
from config import device, embedding_freeze
import random
from evaluation import similarity_score, check_fact_same, predict_facts, evaluate_prediction
import pickle

In [105]:
def _step_log_likelihood(decoder_output, decoding_label_vocab, decoding_label_copy):
    '''Log-likelihood of the gold token at one decoding step (generate + copy).

    Columns of ``decoder_output`` from ``vocab_pred_size`` onwards are combined
    with the copy labels; the column selected by ``decoding_label_vocab`` is the
    generation term. Both are merged with logsumexp.
    (Assumes decoder_output holds log-probabilities over generation symbols
    followed by source positions — TODO confirm against DecoderAtten.)
    '''
    # log(label + 1e-45) is ~0 for positions labelled 1 and hugely negative for 0,
    # so only source positions holding the gold token contribute.
    copy_log_probs = decoder_output[:, vocab_pred_size:] + (decoding_label_copy.float() + 1e-45).log()
    # mask sample which is copied only: the generation term is kept unless the
    # label is OOV and at least one source position can be copied
    gen_mask = ((decoding_label_vocab != OOV_pred_index) | (decoding_label_copy.sum(-1) == 0)).float()
    log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
    generation_log_probs = decoder_output.gather(1, decoding_label_vocab.unsqueeze(1)) + log_gen_mask
    combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
    return torch.logsumexp(combined_gen_and_copy, dim=-1)


def _greedy_next_input(decoder_output, src_org_batch, vocab, batch_size):
    '''Turn the top-1 prediction of every sample into the next decoder input.

    The predicted index addresses ``vocab_pred + src_org_batch[i]``; the
    resulting token is mapped back to the input vocabulary (UNK_index when
    it is not in ``vocab.word2index``).
    '''
    topv, topi = decoder_output.topk(1, dim=-1)
    next_input = topi.detach().cpu().squeeze(1)
    next_ids = []
    for i_batch in range(batch_size):
        pred_list = vocab_pred + src_org_batch[i_batch]
        next_input_token = pred_list[next_input[i_batch].item()]
        next_ids.append(vocab.word2index.get(next_input_token, UNK_index))
    return torch.tensor(next_ids, device=device).unsqueeze(1)


def train(src_data, tgt_data, encoder, decoder, encoder_optimizer, decoder_optimizer, 
          teacher_forcing_ratio, vocab):
    '''Run one optimization step on a single batch.

    Args:
        src_data: tuple (src_org_batch, src_tensor, src_true_len).
        tgt_data: tuple (tgt_org_batch, tgt_tensor, tgt_label_vocab,
            tgt_label_copy, tgt_true_len).
        encoder, decoder: the models; both are put in training mode.
        encoder_optimizer, decoder_optimizer: their optimizers.
        teacher_forcing_ratio: probability of feeding the gold target token as
            the next decoder input for this batch instead of the model's own
            greedy prediction.
        vocab: object whose ``word2index`` maps tokens to input indices.

    Returns:
        float, the negative log-likelihood summed over the batch and divided
        by the sum of the target padding mask (presumably the number of
        non-padded target tokens — depends on sequence_mask).
    '''
    src_org_batch, src_tensor, src_true_len = src_data
    tgt_org_batch, tgt_tensor, tgt_label_vocab, tgt_label_copy, tgt_true_len = tgt_data

    encoder.train()
    decoder.train()

    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    batch_size = src_tensor.size(0)
    encoder_hidden, encoder_cell = encoder.initHidden(batch_size)
    encoder_outputs, encoder_hidden, encoder_cell = encoder(src_tensor, encoder_hidden, src_true_len, encoder_cell)

    decoder_input = torch.tensor([[SOS_index]*batch_size], device=device).transpose(0,1)
    decoder_hidden, decoder_cell = encoder_hidden, decoder.initHidden(batch_size)
    step_log_likelihoods = []
    # decided once per batch, not per step
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    tgt_max_len_batch = tgt_true_len.cpu().max().item()
    assert tgt_max_len_batch == tgt_tensor.size(1)
    for decoding_token_index in range(tgt_max_len_batch):
        decoder_output, decoder_hidden, _, decoder_cell = decoder(
            decoder_input, decoder_hidden, src_true_len, encoder_outputs, decoder_cell)

        step_log_likelihood = _step_log_likelihood(
            decoder_output,
            tgt_label_vocab[:, decoding_token_index],
            tgt_label_copy[:, decoding_token_index, :])
        step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

        if use_teacher_forcing:
            # Teacher forcing: feed the target as the next input
            decoder_input = tgt_tensor[:, decoding_token_index].unsqueeze(1)
        else:
            # Without teacher forcing: use its own predictions as the next input
            decoder_input = _greedy_next_input(decoder_output, src_org_batch, vocab, batch_size)

    # average loss
    log_likelihoods = torch.cat(step_log_likelihoods, dim=-1)
    # mask padding for tgt
    tgt_pad_mask = sequence_mask(tgt_true_len).float()
    log_likelihoods = log_likelihoods*tgt_pad_mask
    loss = -log_likelihoods.sum()/batch_size
    loss.backward()

    ### TODO
    # clip for gradient exploding 
    encoder_optimizer.step()
    decoder_optimizer.step()

    # report the loss per masked-in target token rather than per sample
    return (loss*batch_size/tgt_pad_mask.sum()).item()


def trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
               teacher_forcing_ratio, model_save_info, tgt_max_len, beam_size, vocab):
    '''Train the encoder/decoder for ``num_epochs`` epochs, validating after each.

    Args:
        train_loader, val_loader: iterables yielding 8-tuples
            (src_tensor, src_true_len, tgt_tensor, tgt_true_len,
             tgt_label_vocab, tgt_label_copy, src_org_batch, tgt_org_batch).
        encoder, decoder: models to optimize (one Adam optimizer each).
        num_epochs, learning_rate, teacher_forcing_ratio: training settings.
        model_save_info: dict; when 'model_path_for_resume' is not None the
            model and optimizer states are restored from that checkpoint.
        tgt_max_len: forwarded to predict_facts for validation.
        beam_size: accepted for interface compatibility; not used here.
        vocab: forwarded to train() and predict_facts().

    Prints the mean train loss and the validation precision/recall/loss after
    every epoch. Returns None.
    '''
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

    if model_save_info['model_path_for_resume'] is not None:
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored onto the currently configured device.
        check_point_state = torch.load(model_save_info['model_path_for_resume'],
                                       map_location=device)
        encoder.load_state_dict(check_point_state['encoder_state_dict'])
        encoder_optimizer.load_state_dict(check_point_state['encoder_optimizer_state_dict'])
        decoder.load_state_dict(check_point_state['decoder_state_dict'])
        decoder_optimizer.load_state_dict(check_point_state['decoder_optimizer_state_dict'])

    for epoch in range(num_epochs): 
        # stays -1 when the loader yields no batch, as reported in the summary line
        n_iter = -1
        losses = np.zeros((len(train_loader),))
        for n_iter, batch in enumerate(train_loader):
            (src_tensor, src_true_len, tgt_tensor, tgt_true_len,
             tgt_label_vocab, tgt_label_copy, src_org_batch, tgt_org_batch) = batch
            src_data = (src_org_batch, src_tensor, src_true_len)
            tgt_data = (tgt_org_batch, tgt_tensor, tgt_label_vocab, tgt_label_copy, tgt_true_len)
            losses[n_iter] = train(src_data, tgt_data, encoder, decoder, encoder_optimizer, 
                                   decoder_optimizer, teacher_forcing_ratio, vocab)

        val_loss, src_org, tgt_org, tgt_pred = predict_facts(val_loader, encoder, decoder, tgt_max_len, vocab)
        precision, recall = evaluate_prediction(tgt_org, tgt_pred)
        print('epoch: [{}/{}], step: [{}/{}], train_loss:{}, val_precision: {}, val_recall: {}, val_loss: {}'.format(
            epoch, num_epochs, n_iter, len(train_loader), losses.mean(), precision.mean(), recall.mean(), val_loss))

        # Checkpoint saving is currently disabled:
#         if (epoch+1) % model_save_info['epochs_per_save_model'] == 0:
#             check_point_state = {
#                 'epoch': epoch,
#                 'encoder_state_dict': encoder.state_dict(),
#                 'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
#                 'decoder_state_dict': decoder.state_dict(),
#                 'decoder_optimizer_state_dict': decoder_optimizer.state_dict()
#                 }
#             torch.save(check_point_state, '{}epoch_{}.pth'.format(model_save_info['model_path'], epoch))

    return None

In [106]:
# All hyperparameters of the run in one place; the dict is also pickled next to
# the model so a run can be reproduced.
paras = dict( 
    # tgt_max_len is passed to predict_facts; the *_dataloader values are passed
    # to VocabDataset
    tgt_max_len = 130,
    max_src_len_dataloader =94,
    max_tgt_len_dataloader =127,

    # model sizes
    emb_size = 200,
    en_hidden_size = 128,
    en_num_layers = 2,
    en_num_direction = 2,
    de_hidden_size = 256,
    de_num_layers = 3,
    rnn_type = 'GRU', # {LSTM, GRU}
    attention_type = 'dot_prod', #'dot_prod', general, concat #dot-product need pre-process
    # 1 means every batch is decoded with teacher forcing
    teacher_forcing_ratio = 1,

    # optimization
    learning_rate = 1e-3,
    num_epochs = 40,
    batch_size = 64, 
    # passed to trainIters, which does not use it in this notebook
    beam_size = 5,
    dropout_rate = 0.0,

    # where hyperparameters are stored; a non-None model_path_for_resume makes
    # trainIters restore that checkpoint
    model_save_info = dict(
        model_path = 'nmt_models/model1/',
        epochs_per_save_model = 2,
        model_path_for_resume = None #'nmt_models/epoch_0.pth'
        )
    )

In [107]:
# Unpack the hyperparameter dict into module-level names used by the cells below.
tgt_max_len = paras['tgt_max_len']
max_src_len_dataloader = paras['max_src_len_dataloader']
max_tgt_len_dataloader = paras['max_tgt_len_dataloader']

teacher_forcing_ratio = paras['teacher_forcing_ratio']
emb_size = paras['emb_size']
en_hidden_size = paras['en_hidden_size']
en_num_layers = paras['en_num_layers']
en_num_direction = paras['en_num_direction']
de_hidden_size = paras['de_hidden_size']
de_num_layers = paras['de_num_layers']

learning_rate = paras['learning_rate']
num_epochs = paras['num_epochs']
batch_size = paras['batch_size']
rnn_type = paras['rnn_type']
attention_type = paras['attention_type']
beam_size = paras['beam_size']
model_save_info = paras['model_save_info']
dropout_rate = paras['dropout_rate']

In [108]:
# Wrap the indexed data in datasets/loaders. Training batches are shuffled,
# validation batches keep their order.
# NOTE(review): VocabDataset and vocab_collate_func are not imported anywhere
# in this notebook — the `from Data_utils import ...` line in the import cell
# is commented out — so this cell presumably fails on a fresh kernel; confirm
# and restore the import.
train_dataset = VocabDataset(train_src_input_index, train_tgt_input_index, 
                             train_label_symbolindex, train_indicator, train_data, 
                             max_src_len_dataloader, max_tgt_len_dataloader)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               collate_fn=vocab_collate_func,
                                               shuffle=True)

val_dataset = VocabDataset(val_src_input_index, val_tgt_input_index, 
                           val_label_symbolindex, val_indicator, val_data,
                           max_src_len_dataloader, max_tgt_len_dataloader)

val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                               batch_size=batch_size,
                                               collate_fn=vocab_collate_func,
                                               shuffle=False)

In [109]:
# make dir for saving models
# exist_ok avoids the check-then-create race and makes the cell safe to re-run
os.makedirs(model_save_info['model_path'], exist_ok=True)
### save model hyperparameters
model_hyparams = paras
# os.path.join yields the same path whether or not model_path ends with a slash
with open(os.path.join(model_save_info['model_path'], 'model_params.pkl'), 'wb') as f:
    pickle.dump(model_hyparams, f)
print(model_hyparams)

# read all data
### save srcLang and tgtLang

#for src; keep original src_org and index based on vocab src_tensor

#for tgt; vocab_pred_label, copy_label


# test_dataset = VocabDataset(test_data)
# test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
#                                            batch_size=BATCH_SIZE,
#                                            collate_fn=vocab_collate_func,
#                                            shuffle=False)

# embedding_src_weight = torch.from_numpy(srcLang.embedding_matrix).type(torch.FloatTensor).to(device)
# embedding_tgt_weight = torch.from_numpy(tgtLang.embedding_matrix).type(torch.FloatTensor).to(device)
# print(embedding_src_weight.size(), embedding_tgt_weight.size())

# Instantiate the encoder and decoder from the hyperparameters above; both
# receive trainLang.vocab_size, and the decoder additionally receives
# vocab_pred_size.
encoder = EncoderRNN(trainLang.vocab_size, emb_size, en_hidden_size, en_num_layers, 
                     en_num_direction, (de_num_layers, de_hidden_size), rnn_type=rnn_type, 
                     dropout_rate=dropout_rate)
decoder = DecoderAtten(trainLang.vocab_size, vocab_pred_size, emb_size, de_hidden_size, 
                       de_num_layers, (en_num_layers, en_num_direction, en_hidden_size), 
                       rnn_type=rnn_type, atten_type=attention_type, 
                       dropout_rate=dropout_rate)

# move both models to the configured device and print their structure
encoder, decoder = encoder.to(device), decoder.to(device)
print('Encoder:')
print(encoder)
print('Decoder:')
print(decoder)

{'tgt_max_len': 130, 'max_src_len_dataloader': 94, 'max_tgt_len_dataloader': 127, 'emb_size': 200, 'en_hidden_size': 128, 'en_num_layers': 2, 'en_num_direction': 2, 'de_hidden_size': 256, 'de_num_layers': 3, 'rnn_type': 'GRU', 'attention_type': 'dot_prod', 'teacher_forcing_ratio': 1, 'learning_rate': 0.001, 'num_epochs': 40, 'batch_size': 64, 'beam_size': 5, 'dropout_rate': 0.0, 'model_save_info': {'model_path': 'nmt_models/model1/', 'epochs_per_save_model': 2, 'model_path_for_resume': None}}
dot_prod
Encoder:
EncoderRNN(
  (embedding): Embedding(8776, 200)
  (dropout): Dropout(p=0.0)
  (gru): GRU(200, 128, num_layers=2, batch_first=True, bidirectional=True)
  (transform_en_hid): Linear(in_features=512, out_features=768, bias=False)
)
Decoder:
DecoderAtten(
  (dropout): Dropout(p=0.0)
  (embedding): Embedding(8776, 200)
  (gru): GRU(200, 256, num_layers=3, batch_first=True)
  (atten): AttentionLayer()
  (linear): Linear(in_features=512, out_features=256, bias=True)
  (copy_mech): CopyM

In [110]:
# Run training; one summary line (train loss, validation precision/recall/loss)
# is printed per epoch.
trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
           teacher_forcing_ratio, model_save_info, tgt_max_len, beam_size, trainLang)

epoch: [0/40], step: [446/447], train_loss:1.713569215613457, val_precision: 0.20918120646034197, val_recall: 0.20564652405446393, val_loss: 0
epoch: [1/40], step: [446/447], train_loss:0.688048206786448, val_precision: 0.281360057013545, val_recall: 0.2913922229271028, val_loss: 0
epoch: [2/40], step: [446/447], train_loss:0.5351252585849506, val_precision: 0.3015026361211625, val_recall: 0.32026617848378997, val_loss: 0
epoch: [3/40], step: [446/447], train_loss:0.46199089588734926, val_precision: 0.34572786894711416, val_recall: 0.333882645155313, val_loss: 0
epoch: [4/40], step: [446/447], train_loss:0.4080293056548842, val_precision: 0.356959809418558, val_recall: 0.35049893289560013, val_loss: 0
epoch: [5/40], step: [446/447], train_loss:0.3710157597891703, val_precision: 0.35694701112933425, val_recall: 0.35594693521911297, val_loss: 0
epoch: [6/40], step: [446/447], train_loss:0.33719327822494294, val_precision: 0.3583590864535443, val_recall: 0.35972051601033794, val_loss: 0
e

KeyboardInterrupt: 

In [None]:
# trainIters(train_loader, val_loader, encoder, decoder, num_epochs, learning_rate, 
#            teacher_forcing_ratio, model_save_info, tgt_max_len, beam_size, trainLang)

In [111]:
# Decode the validation set with the current model and score the predicted
# facts against the references.
loader = val_loader
tgt_max_length = tgt_max_len
loss, src_org, tgt_org, tgt_pred = predict_facts(loader, encoder, decoder, tgt_max_length, trainLang)
precision, recall = evaluate_prediction(tgt_org, tgt_pred)

In [112]:
# mean per-sample precision and recall on the validation set
print(precision.mean(), recall.mean())

0.352119460166 0.354060595688


In [None]:
# Validation loader with batch_size=2.
# NOTE(review): unlike the datasets built earlier, no max src/tgt lengths are
# passed here — confirm that VocabDataset provides defaults for them.
val_dataset1 = VocabDataset(val_src_input_index, val_tgt_input_index, 
                             val_label_symbolindex, val_indicator, val_data)
# Use the dataset created in this cell; the loader previously wrapped
# val_dataset, which left val_dataset1 unused.
val_loader1 = torch.utils.data.DataLoader(dataset=val_dataset1,
                                               batch_size=2,
                                               collate_fn=vocab_collate_func,
                                               shuffle=False)

In [None]:
# difflib must be in scope before these functions are called; the notebook
# otherwise only imports it in a later cell.
import difflib


def similarity_score(fact1, fact2):
    '''Similarity of two serialized facts, between 0 and 1.

    Both facts are split on '@'. Elements are compared position by position
    with difflib.SequenceMatcher and the ratios are summed; the sum is divided
    by the larger element count, so unmatched extra elements count as 0.
    '''
    elem1 = fact1.split('@')
    elem2 = fact2.split('@')
    n1 = len(elem1)
    n2 = len(elem2)
    sim = 0
    for i in range(min(n1,n2)):
        sim += difflib.SequenceMatcher(None,elem1[i],elem2[i]).ratio()
    # str.split always returns at least one element, so max(n1, n2) >= 1
    return sim/max(n1,n2)


def check_fact_same(org_fact, pred_fact):
    '''Decide whether a predicted fact matches a reference fact.

    Facts with a different number of '@'-separated elements never match.
    Otherwise they match when either the whole strings have a SequenceMatcher
    ratio above 0.85, or the mean of the element-wise ratios is above 0.85.

    Returns:
        bool
    '''
    org_fact_ele = org_fact.split('@')
    pred_fact_ele = pred_fact.split('@')
    if len(org_fact_ele) == len(pred_fact_ele):
        ele_num = len(org_fact_ele)
        # cheap whole-string comparison first
        if difflib.SequenceMatcher(None,org_fact,pred_fact).ratio() > 0.85:
            return True       
        ele_sim = np.zeros((ele_num,))
        for ele_i in range(ele_num):
            ele_sim[ele_i] = difflib.SequenceMatcher(None,org_fact_ele[ele_i],pred_fact_ele[ele_i]).ratio()
        if ele_sim.mean() > 0.85:
            return True
    return False

In [44]:
from scipy.optimize import linear_sum_assignment

# Greedy decoding of the validation set, written out inline: collects the
# predicted token sequences and the average per-batch loss.
tgt_max_length = 130
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")    
encoder.eval()
decoder.eval()

tgt_pred = []
src_org = []
tgt_org = []
loss = 0
loader = val_loader

# Inference only: disabling autograd avoids building a graph for every decoding
# step; the computed values are unchanged.
with torch.no_grad():
    for src_tensor, src_true_len, tgt_tensor, tgt_true_len, tgt_label_vocab, tgt_label_copy, src_org_batch, tgt_org_batch in loader:
        batch_size = src_tensor.size(0)
        encoder_hidden, encoder_cell = encoder.initHidden(batch_size)
        encoder_outputs, encoder_hidden, encoder_cell = encoder(src_tensor, encoder_hidden, src_true_len, encoder_cell)
        decoder_input = torch.tensor([SOS_index]*batch_size, device=device).unsqueeze(1)
        decoder_hidden, decoder_cell = encoder_hidden, decoder.initHidden(batch_size)

        decoding_token_index = 0
        # stop_flag[i] becomes True once sample i has produced the EOS symbol
        stop_flag = [False]*batch_size
        step_log_likelihoods = []
        tgt_pred_batch = [[] for i_batch in range(batch_size)]
        tgt_true_len_max = tgt_true_len.cpu().numpy().max()
        while decoding_token_index < tgt_max_length:
            decoder_output, decoder_hidden, _, decoder_cell = decoder(decoder_input, decoder_hidden, src_true_len, encoder_outputs, decoder_cell)

            # compute loss (only for steps that have a reference label)
            if decoding_token_index < tgt_true_len_max:
                decoding_label_vocab = tgt_label_vocab[:, decoding_token_index]
                decoding_label_copy = tgt_label_copy[:, decoding_token_index, :]
                copy_log_probs = decoder_output[:, vocab_pred_size:]+(decoding_label_copy.float()+1e-45).log()
                #mask sample which is copied only
                gen_mask = ((decoding_label_vocab!=OOV_pred_index) | (decoding_label_copy.sum(-1)==0)).float() 
                log_gen_mask = (gen_mask + 1e-45).log().unsqueeze(-1)
                #mask log_prob value for oov_pred_index when label_vocab==oov_pred_index and is copied 
                generation_log_probs = decoder_output.gather(1, decoding_label_vocab.unsqueeze(1)) + log_gen_mask
                combined_gen_and_copy = torch.cat((generation_log_probs, copy_log_probs), dim=-1)
                step_log_likelihood = torch.logsumexp(combined_gen_and_copy, dim=-1)
                step_log_likelihoods.append(step_log_likelihood.unsqueeze(1))

            # greedy choice: the predicted index addresses vocab_pred followed by
            # the tokens of the sample's own source sentence
            topv, topi = decoder_output.topk(1, dim=-1)
            next_input = topi.detach().cpu().squeeze(1)
            decoder_input = []
            for i_batch in range(batch_size):
                pred_list = vocab_pred+src_org_batch[i_batch]
                next_input_token = pred_list[next_input[i_batch].item()]
                if next_input_token == vocab_pred[EOS_pred_index]:
                    stop_flag[i_batch] = True
                if not stop_flag[i_batch]:
                    tgt_pred_batch[i_batch].append(next_input_token)
                decoder_input.append(trainLang.word2index.get(next_input_token, UNK_index))
            decoder_input = torch.tensor(decoder_input, device=device).unsqueeze(1)
            decoding_token_index += 1
            if all(stop_flag):
                break
        log_likelihoods = torch.cat(step_log_likelihoods, dim=-1)
        # mask padding for tgt; the mask is cut to the number of decoded steps
        # because decoding may stop before the longest reference ends
        tgt_pad_mask = sequence_mask(tgt_true_len).float()
        log_likelihoods = log_likelihoods*tgt_pad_mask[:,:log_likelihoods.size(1)]
        loss += -(log_likelihoods.sum()/tgt_pad_mask.sum()).item()
        tgt_pred.extend(tgt_pred_batch)
        src_org.extend(src_org_batch)
        tgt_org.extend(tgt_org_batch)
loss = loss/len(loader)

NameError: name 'difflib' is not defined

In [55]:
import difflib
from Multilayers_Decoder import sequence_mask

# Per-sample precision/recall of predicted facts against reference facts.
eval_len = len(tgt_pred)
precision = np.zeros((eval_len,))
recall = np.zeros((eval_len,))
for i in range(eval_len):
    # rebuild the serialized strings and split them into facts on '$'
    org_facts = ''.join(tgt_org[i]).split('$')
    pred_facts = ''.join(tgt_pred[i]).split('$')
    org_facts_num = len(org_facts)
    pred_facts_num = len(pred_facts)
    org_match_num = np.zeros((org_facts_num))
    pred_match_num = np.zeros((pred_facts_num))
    # pairwise similarity between every reference fact and every predicted fact
    similarity_ma = np.zeros((org_facts_num, pred_facts_num))
    for org_i in range(org_facts_num):
        for pred_i in range(pred_facts_num):
            similarity_ma[org_i, pred_i] = similarity_score(org_facts[org_i], pred_facts[pred_i])
    # linear_sum_assignment minimizes cost, so the similarity is negated to get
    # the one-to-one pairing with the highest total similarity
    row_ind, col_ind = linear_sum_assignment(-similarity_ma)
    
    # a paired reference/prediction only counts when check_fact_same accepts it
    for org_i, pred_i in zip(row_ind, col_ind):
        org_fact = org_facts[org_i]
        pred_fact = pred_facts[pred_i]
        fact_same = check_fact_same(org_fact, pred_fact)
        if fact_same:
            org_match_num[org_i] = 1
            pred_match_num[pred_i] = 1
#     print(pred_match_num)
#     print(org_match_num)
    # precision: share of predicted facts matched; recall: share of reference facts matched
    precision[i] = pred_match_num.mean()
    recall[i] = org_match_num.mean()
# disabled debugging aid: print one random sample
if False:
    random_sample = np.random.randint(eval_len)
    print('src: ', src_org[random_sample])
    print('Ref: ', tgt_org[random_sample])
    print('pred: ', tgt_pred[random_sample])

In [56]:
# mean per-sample precision
print(precision.mean())

0.359069622131


In [57]:
# mean per-sample recall
print(recall.mean())

0.342696988792


In [90]:
# Keep references to the current score arrays under separate names.
# These are aliases, not copies: in-place changes to precision/recall would
# show up here as well.
precision1 = precision
recall1 = recall