In [12]:
import json
import pandas as pd

import jieba

In [13]:
# Load the corpus: the file is read line by line and every non-blank line is
# parsed as one JSON object.
saoke = []
# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default text encoding (JSON text is normally UTF-8).
with open('SAOKE_DATA.json', 'r', encoding='utf-8') as f:
    for line in f:
        # A blank line (e.g. a trailing newline at end of file) is not valid JSON.
        if not line.strip():
            continue
        saoke.append(json.loads(line))

In [27]:
def replaceMisspred(predicate):
    '''Return the predicate unchanged, except that the placeholder '_'
    (a missing predicate) is replaced by the symbol 'P'.
    '''
    return 'P' if predicate == '_' else predicate

# def replaceMissinfo(aaa):
#     '''replace missing info for subjects/objects
#     '''
#     placeholder = ['Z','Y','X']
#     for i in range(len(aaa)):
#         if aaa[i] == '_':
#             aaa = aaa[:i] + placeholder.pop() + aaa[i+1:]
#     return aaa

In [33]:
# Turn every raw SAOKE sample into a dict holding the original and the
# tokenised form of both the sentence ('src') and its facts ('tgt').
data = []
for sample in saoke:
    # samples without any fact are skipped
    if sample['logic'] == []:
        continue

    # tokenised source sentence
    src_tokens = list(jieba.cut(sample['natural'], cut_all=False))

    # Serialise the fact list into one string before tokenising it:
    #   '$' separates facts, '@' separates the elements of one fact,
    #   '&' separates the objects of one fact.
    logic_list = [
        '@'.join([fact['subject'],
                  replaceMisspred(fact['predicate']),
                  '&'.join(fact['object'])])
        for fact in sample['logic']
    ]
    logic_str = '$'.join(logic_list)

    sample_processed = {
        'src_org': sample['natural'],
        'src': src_tokens,
        'tgt_org': sample['logic'],
        'tgt': list(jieba.cut(logic_str, cut_all=False)),
    }
    data.append(sample_processed)

In [37]:
#data[5]['tgt_org']

In [40]:
# Special tokens placed in front of the learned vocabulary: Lang.build_vocab
# prepends this list, so '<PAD>' gets index 0, '<UNK>' 1, '<EOS>' 2, '<SOS>' 3.
vocab_prefix = ['<PAD>', '<UNK>', '<EOS>', '<SOS>']

In [46]:
import numpy as np
from collections import Counter
from itertools import dropwhile

class Lang:
    """Vocabulary, plus an optional pre-trained embedding matrix, for a tokenised dataset.

    The lookup tables are filled by ``build_vocab``; ``build_emb_weight``
    additionally fills ``embedding_matrix`` from a pre-trained vector file.
    """

    def __init__(self, name, emb_pretrained_add=None, max_vocab_size=None):
        """
        @param name: free-form label stored on the instance
        @param emb_pretrained_add: path of the pre-trained vector file that is
            handed to ``load_emb_vectors`` (None when no embeddings are used)
        @param max_vocab_size: keep at most this many learned tokens (the
            entries of ``vocab_prefix`` are not counted); a falsy value means
            no limit
        """
        self.name = name
        self.word2index = None   # dict token -> index, set by build_vocab
        self.index2word = None   # list index -> token, set by build_vocab
        self.max_vocab_size = max_vocab_size
        self.vocab_size = None   # len(index2word), set by build_vocab
        self.emb_pretrained_add = emb_pretrained_add
        self.embedding_matrix = None  # set by build_emb_weight

    def build_vocab(self, data):
        """Build ``index2word`` / ``word2index`` from the 'src' and 'tgt' tokens of ``data``.

        Tokens that occur only once are dropped. The remaining tokens are
        ordered by decreasing frequency and placed after ``vocab_prefix``.
        If no token occurs more than once the vocabulary consists of
        ``vocab_prefix`` only.

        @param data: iterable of dicts that each hold token lists under the
            keys 'src' and 'tgt'
        """
        token_counter = Counter()
        for sample in data:
            token_counter.update(sample['src'])
            token_counter.update(sample['tgt'])
        print('The number of unique tokens totally in dataset: ', len(token_counter))

        # most_common() sorts by decreasing count, so filtering keeps that order.
        vocab = [token for token, count in token_counter.most_common() if count > 1]
        if self.max_vocab_size:
            vocab = vocab[:self.max_vocab_size]

        self.index2word = vocab_prefix + vocab
        self.word2index = dict(zip(self.index2word, range(len(self.index2word))))
        self.vocab_size = len(self.index2word)
        return None

    def build_emb_weight(self, emb_dim=300):
        """Fill ``embedding_matrix`` with pre-trained vectors for the known tokens.

        Rows of the ``vocab_prefix`` tokens, of tokens without a pre-trained
        vector, and of tokens whose vector is malformed stay all-zero.

        @param emb_dim: width of the embedding vectors (default 300)
        """
        words_emb_dict = load_emb_vectors(self.emb_pretrained_add)
        vocab_size = len(self.index2word)
        self.vocab_size = vocab_size
        emb_weight = np.zeros([vocab_size, emb_dim])
        for i in range(len(vocab_prefix), vocab_size):
            emb = words_emb_dict.get(self.index2word[i], None)
            if emb is None:
                continue
            # A vector of the wrong length must be skipped explicitly: a
            # 1-element vector would otherwise broadcast over the whole row.
            if np.shape(emb) != (emb_dim,):
                continue
            try:
                emb_weight[i] = emb
            except (ValueError, TypeError):
                # The stored values could not be converted to floats; the
                # token keeps its zero row (best effort, as before).
                continue
        self.embedding_matrix = emb_weight
        return None

def load_emb_vectors(fasttest_home, max_num_load=500000):
    """Read a whitespace-separated word-vector text file into a dict.

    Every non-blank line is split on whitespace; the first field becomes the
    key and the remaining fields are stored, unconverted, as a NumPy array of
    strings.

    @param fasttest_home: path of the vector file
    @param max_num_load: stop after this many lines of the file have been
        read (blank lines count towards the limit)
    @return: dict mapping word -> np.ndarray of the remaining fields
    """
    words_dict = {}
    # Explicit UTF-8 so non-ASCII words decode the same way on every platform.
    with open(fasttest_home, encoding='utf-8') as f:
        for num_row, line in enumerate(f):
            if num_row >= max_num_load:
                break
            s = line.split()
            if not s:
                # blank line: there is no word to index
                continue
            words_dict[s[0]] = np.asarray(s[1:])
    return words_dict


In [59]:
# Position of '<UNK>' in vocab_prefix; text2index uses it for tokens that are
# missing from the vocabulary.
UNK_token = 1
# Position of '<OOV>' in vocab_pred; text2symbolindex uses it for tokens that
# are not part of the symbol vocabulary.
oov_pred_index = 1
# Position of '<EOS>' in both vocab_prefix and vocab_pred; vocab_collate_func
# appends it to every index sequence.
EOS_token = 2

def text2index(data, key, word2index):
    """Convert the token lists stored under ``key`` into lists of vocabulary indices.

    @param data: iterable of samples; ``sample[key]`` is a sequence of tokens
    @param key: which field of every sample to convert
    @param word2index: dict mapping token -> index
    @return: one index list per sample; tokens that are not in ``word2index``
        are mapped to ``UNK_token``
    """
    indexdata = [
        [word2index.get(token, UNK_token) for token in sample[key]]
        for sample in data
    ]
    print('finish indexing')
    return indexdata

def construct_Lang(name, data, emb_pretrained_add = None, max_vocab_size = None):
    """Create a ``Lang``, build its vocabulary from ``data`` and, when a
    pre-trained embedding path is given, build its embedding matrix as well.

    @return: the populated ``Lang`` instance
    """
    lang = Lang(name, emb_pretrained_add, max_vocab_size)
    lang.build_vocab(data)
    # Without an embedding path the vocabulary is all there is to build.
    if not emb_pretrained_add:
        return lang
    lang.build_emb_weight()
    return lang

def text2symbolindex(data, key, word2index):
    """Convert the token lists stored under ``key`` into lists of symbol indices.

    @param data: iterable of samples; ``sample[key]`` is a sequence of tokens
    @param key: which field of every sample to convert
    @param word2index: dict mapping symbol -> index
    @return: one index list per sample; tokens that are not in ``word2index``
        are mapped to ``oov_pred_index``
    """
    def lookup(token):
        return word2index.get(token, oov_pred_index)

    indexdata = [list(map(lookup, sample[key])) for sample in data]
    print('finish')
    return indexdata

In [48]:
# Build the vocabulary over the 'src' and 'tgt' tokens of every sample; no
# embedding path is passed, so no embedding matrix is built.
trainLang = construct_Lang('train', data)

The number of unique tokens totally in dataset:  85719


In [50]:
# Source sentences as vocabulary indices (tokens outside the vocabulary -> UNK_token).
src_input_index = text2index(data, 'src', trainLang.word2index) 

finish indexing


In [51]:
# Serialised facts as vocabulary indices (tokens outside the vocabulary -> UNK_token).
tgt_input_index = text2index(data, 'tgt', trainLang.word2index) 

finish indexing


In [55]:
# Symbol vocabulary for the target side. The list position is the symbol index
# (see word2symbolindex): '<OOV>' sits at 1, matching oov_pred_index, and
# '<EOS>' at 2, matching EOS_token.
vocab_pred = ['<PAD>','<OOV>','<EOS>','ISA','DESC','IN','BIRTH',"DEATH", "=","$", "[","]","|","X","Y","Z","P","@"]

In [56]:
# symbol_index = []
# for o in symbol_dict:
#     symbol_index.append(trainLang.word2index[o])

In [57]:
# symbol -> position of that symbol in vocab_pred
word2symbolindex = {token: idx for idx, token in enumerate(vocab_pred)}

In [67]:
# Target tokens as symbol-vocabulary indices; every token that is not listed in
# vocab_pred becomes oov_pred_index.
label_symbolindex = text2symbolindex(data, 'tgt', word2symbolindex)

finish


In [79]:
# Check valid
# for idx,i in enumerate(label_input_symbolindex[0]):
#     if symbol_dict[i] == 'OOV':
#         continue
#     elif symbol_dict[i] == labels[0][idx]:
#         continue
#     else:
#         print(symbol_dict[i])

In [62]:
def copy_indicator(data, src_key='src', tgt_key='tgt'):
    """Build one 0/1 matrix per sample marking which target tokens equal which source tokens.

    For a sample with T target tokens and S source tokens the matrix has shape
    (T + 1, S + 1). Entry [m, n] is 1 when target token m equals source token
    n. The extra last row and last column are zero except for the
    bottom-right corner, which is always 1.

    @param data: iterable of samples holding token sequences under
        ``src_key`` and ``tgt_key``
    @return: list of np.ndarray, one per sample
    """
    def _build(sample):
        tgt_tokens = sample[tgt_key]
        src_tokens = sample[src_key]
        n_tgt, n_src = len(tgt_tokens), len(src_tokens)
        grid = np.zeros((n_tgt + 1, n_src + 1))
        for row, tgt_token in enumerate(tgt_tokens):
            for col, src_token in enumerate(src_tokens):
                if tgt_token == src_token:
                    grid[row, col] = 1
        # bottom-right corner pairs the two extra (end) positions
        grid[n_tgt, n_src] = 1
        return grid

    return [_build(sample) for sample in data]

In [63]:
# One copy-indicator matrix per sample (target positions x source positions).
indicator = copy_indicator(data, 'src', 'tgt')

In [65]:
#indicator[0]

In [68]:
# Sanity check: all per-sample structures should contain the same number of entries.
len(src_input_index),len(tgt_input_index),len(label_symbolindex),len(indicator),len(data)

(40806, 40806, 40806, 40806, 40806)

In [77]:
import numpy as np
import time
import os.path
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


class VocabDataset(Dataset):
    """
    Class that represents a train/validation/test dataset that's readable for PyTorch
    Note that this class inherits torch.utils.data.Dataset
    """
    def __init__(self, src_index, tgt_index, tgt_symbolindex, tgt_indicator, data, src_clip = None, tgt_clip = None):
        """
        @param src_index: list of source sequences, each a list of vocabulary indices
        @param tgt_index: list of target sequences, each a list of vocabulary indices
        @param tgt_symbolindex: list of target sequences, each a list of symbol indices
        @param tgt_indicator: list of 2-D arrays, one per sample, indexed
            [target position, source position] with one extra trailing row and
            column (the layout produced by copy_indicator)
        @param data: list of dicts holding the token lists under 'src' and 'tgt'
        @param src_clip: optional maximum number of source tokens per sample
        @param tgt_clip: optional maximum number of target tokens per sample
        """
        self.src_clip = src_clip
        self.tgt_clip = tgt_clip
        self.src_list, self.tgt_list = src_index, tgt_index
        self.data = data
        self.tgt_symbolindex, self.tgt_indicator  = tgt_symbolindex, tgt_indicator

        assert (len(self.src_list) == len(self.tgt_list) == len(self.tgt_symbolindex)== len(self.tgt_indicator))

    def __len__(self):
        return len(self.src_list)

    @staticmethod
    def _clip_indicator(indicator, tgt_length, src_length):
        """Cut an indicator matrix down to (tgt_length + 1, src_length + 1),
        keeping the trailing extra row and column."""
        indicator = np.asarray(indicator)
        rows, cols = indicator.shape
        if rows > tgt_length + 1:
            indicator = np.concatenate([indicator[:tgt_length], indicator[-1:]], axis=0)
        if cols > src_length + 1:
            indicator = np.concatenate([indicator[:, :src_length], indicator[:, -1:]], axis=1)
        return indicator

    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]

        @return: tuple (src, src_length, tgt, tgt_length, tgt_sym, tgt_ind,
            src_org, tgt_org); the lengths are those of the (possibly clipped)
            index lists, src_org / tgt_org are the matching token lists
        """
        src = self.src_list[key]
        tgt = self.tgt_list[key]
        src_org = self.data[key]['src']
        tgt_org = self.data[key]['tgt']
        tgt_sym = self.tgt_symbolindex[key]
        tgt_ind = self.tgt_indicator[key]

        if self.src_clip is not None:
            src = src[:self.src_clip]
            src_org = src_org[:self.src_clip]
        src_length = len(src)

        if self.tgt_clip is not None:
            tgt = tgt[:self.tgt_clip]
            tgt_org = tgt_org[:self.tgt_clip]
            tgt_sym = tgt_sym[:self.tgt_clip]
        tgt_length = len(tgt)

        # vocab_collate_func pads the indicator by (max - length - 1) on both
        # axes, i.e. it expects shape (tgt_length + 1, src_length + 1). When a
        # clip shortened a sequence the matrix has to be cut on that axis too,
        # without losing its trailing row / column.
        if self.src_clip is not None or self.tgt_clip is not None:
            tgt_ind = self._clip_indicator(tgt_ind, tgt_length, src_length)

        return src, src_length, tgt, tgt_length, tgt_sym, tgt_ind, src_org, tgt_org


def vocab_collate_func(batch):
    """
    Collate function for DataLoader: appends EOS_token to every index
    sequence, zero-pads everything to the longest item of the batch and
    converts the result to tensors on the GPU when one is available.

    Every batch item is an 8-tuple
    (src, src_length, tgt, tgt_length, tgt_sym, tgt_ind, src_org, tgt_org).

    @return: list of
        padded source indices, source lengths (EOS included),
        padded target indices, target lengths (EOS included),
        padded target symbol indices, padded indicator matrices,
        list of source token lists, list of target token lists
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Reported lengths include the EOS token appended below.
    src_length_list = [item[1] + 1 for item in batch]
    tgt_length_list = [item[3] + 1 for item in batch]
    batch_max_src_length = np.max(src_length_list)
    batch_max_tgt_length = np.max(tgt_length_list)

    def _eos_and_pad(indices, length, width):
        # `length` excludes EOS, hence the extra -1 in the pad amount
        return np.pad(np.array(indices + [EOS_token]),
                      pad_width=(0, width - length - 1),
                      mode="constant", constant_values=0)

    src_rows, tgt_rows, symbol_rows, indicator_rows = [], [], [], []
    src_org_list, tgt_org_list = [], []
    for src, src_len, tgt, tgt_len, tgt_sym, tgt_ind, src_org, tgt_org in batch:
        src_rows.append(_eos_and_pad(src, src_len, batch_max_src_length))
        tgt_rows.append(_eos_and_pad(tgt, tgt_len, batch_max_tgt_length))
        symbol_rows.append(_eos_and_pad(tgt_sym, tgt_len, batch_max_tgt_length))
        # the indicator is padded with zero rows (target axis) and zero
        # columns (source axis)
        indicator_rows.append(
            np.pad(np.array(tgt_ind),
                   pad_width=((0, batch_max_tgt_length - tgt_len - 1),
                              (0, batch_max_src_length - src_len - 1)),
                   mode="constant", constant_values=0))
        src_org_list.append(src_org)
        tgt_org_list.append(tgt_org)

    return [
        torch.from_numpy(np.array(src_rows)).to(device),
        torch.LongTensor(np.array(src_length_list)).to(device),
        torch.from_numpy(np.array(tgt_rows)).to(device),
        torch.LongTensor(np.array(tgt_length_list)).to(device),
        torch.from_numpy(np.array(symbol_rows)).to(device),
        torch.from_numpy(np.array(indicator_rows)).to(device),
        src_org_list,
        tgt_org_list,
    ]


In [78]:
# Wrap the indexed data in a Dataset (no clipping is requested) and iterate it
# one sample at a time, in the original order; vocab_collate_func appends EOS
# and pads each batch.
train_dataset = VocabDataset(src_input_index, tgt_input_index, label_symbolindex, indicator, data)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=1,
                                               collate_fn=vocab_collate_func,
                                               shuffle=False)

In [79]:
# Print the contents of the first batch only (the loop stops after one iteration).
for src_tensor, src_true_len, tgt_tensor, tgt_true_len, tgt_label_vocab, tgt_label_copy, src_org, tgt_org in train_loader:
    print(src_tensor[:10])
    print(src_true_len)
    print(tgt_tensor[:10])
    print(tgt_true_len)
    print(tgt_label_vocab[:10])
    print(tgt_label_copy[:10,:10])
    print(src_org[:10])
    print(tgt_org[:10])
    break

tensor([[   70,  5166,    63, 30553,    19, 14922,  4470,     5, 12243,  5682,
            11,  5682,    19,  4471, 21254, 21255,    39,     7,    47,  5056,
            63, 23252,    19, 27224,     7,  5683,    18,  3159,  3353, 45583,
            54,  1034, 23253,   239,     5,    39,    11,  3302,  4470,     7,
            18,  3354,  5683,   903,    51,    54,  1034,  1977, 13040,     5,
          1994,    39,     7,  4938, 45584,    54, 14922,    19, 13846, 18972,
             5, 12243,  5682,    11, 23252,    19, 21256,  9221, 45585,     7,
          3353,    47, 45586,  9619,    11, 17503,    19,     1, 45587,    11,
         45588, 30554,    19, 13847,    39,     5,  9222,    10,     2]])
tensor([89])
tensor([[    9,     4,    70,  5166,    63,     4,    13, 30553,     8, 14922,
          4470,    12,     5,    13, 12243,  5682,     8,  5682,     8,  4471,
         21254, 21255,    12,    39,     6,     9,     4,    47,  5056,    63,
             4,    13, 23252,     8, 27224, 