In [1]:
import collections
import json
import os
import pickle

import utils

In [2]:
# Special tokens. creat_vocab places them, in this order, at the head of the
# vocabulary: padding, unknown character, sequence start, end of sequence.
pad_flag, unk_flag, begin_flag, end_flag = '[PAD]', '[UNK]', '[GO]', '[EOS]'


def read_json_file(path):
    """Load and return the JSON document stored at ``path`` (UTF-8).

    Returns None instead of raising when ``path`` does not exist.
    """
    if os.path.exists(path):
        with open(path, 'r', encoding='utf8') as handle:
            return json.load(handle)
    return None

def read_text_file(path):
    """Return the whole UTF-8 text of ``path``, or None if it does not exist."""
    if os.path.exists(path):
        with open(path, 'r', encoding='utf8') as handle:
            return handle.read()
    return None
    
def save_text(path, text):
    """Write ``text`` to ``path`` as UTF-8, replacing any existing content."""
    with open(path, mode='w', encoding='utf-8') as out:
        out.write(text)
    
def get_q_a_from_dict(qa_dict, q_key='Question', a_key='Answer'):
    """Return the ``(question, answer)`` pair stored in one QA record.

    A missing key yields ``''``; present values are returned unchanged.
    """
    return qa_dict.get(q_key, ''), qa_dict.get(a_key, '')
    
def creat_vocab(path, splits=('train', 'val')):
    """Build a character vocabulary from the QA JSON file at ``path``.

    Every character of the answer and question of each record in ``splits``
    is counted. The vocabulary is the four special tokens followed by the
    characters in descending frequency; ties keep first-seen order.
    Whitespace characters are dropped so the list can be stored one entry
    per line (see read_vocab).

    Args:
        path: JSON file holding a dict that maps split names to lists of
            QA records.
        splits: split keys to scan. Defaults to ('train', 'val').

    Returns:
        list[str]: vocabulary, where the list index is the token id.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    json_data = read_json_file(path)
    if json_data is None:
        # read_json_file signals a missing file with None. Fail clearly here
        # instead of with "'NoneType' object is not subscriptable".
        raise FileNotFoundError(path)

    char_counts = collections.Counter()
    for split in splits:
        for qa in json_data[split]:
            q, a = get_q_a_from_dict(qa)
            # str() in case a stored value is not a string (e.g. a number).
            char_counts.update(str(a) + str(q))

    flags = [pad_flag, unk_flag, begin_flag, end_flag]
    # most_common() orders equal counts by first occurrence, which matches
    # a stable sort of the insertion-ordered counts.
    chars = [c for c, _ in char_counts.most_common()]
    # Every entry is a single char or a flag, so strip() is either a no-op or
    # empty. Drop the empty ones, because a '\n' token would break the
    # newline-separated vocab file.
    return [c for c in flags + chars if c.strip()]

def read_vocab(qa_path, vocab_path, ignore_exist=False):
    """Return the vocabulary list, using ``vocab_path`` as a cache.

    If ``vocab_path`` exists and ``ignore_exist`` is False, the file is read
    with one token per line. Otherwise, or when the cached file is empty, the
    vocabulary is rebuilt from ``qa_path`` with creat_vocab and written to
    ``vocab_path``.

    Args:
        qa_path: QA JSON file used to build the vocabulary.
        vocab_path: newline-separated vocabulary cache file.
        ignore_exist: if True, always rebuild and overwrite the cache.

    Returns:
        list[str]: vocabulary, where the list index is the token id.
    """
    if not ignore_exist and os.path.exists(vocab_path):
        vocab_txt = read_text_file(vocab_path) or ''
        # A trailing newline (e.g. added by an editor) would otherwise become
        # a spurious '' token and change vocab_size. Tokens never contain
        # '\n', because creat_vocab drops whitespace characters.
        vocab_txt = vocab_txt.rstrip('\n')
        if vocab_txt:
            return vocab_txt.split('\n')
        # Empty cache file: fall through and rebuild it.
    vocab = creat_vocab(qa_path)
    save_text(vocab_path, "\n".join(vocab))
    return vocab
    
vocab = read_vocab(
    qa_path=utils.path.fm_qa_json_path,
    vocab_path=utils.path.fm_qa_vocab_path,
    ignore_exist=False,
)
vocab_size = len(vocab)

# Lookup tables: id -> character and character -> id.
i2c = dict(enumerate(vocab))
c2i = {char: idx for idx, char in i2c.items()}

# Special-token ids. The fallbacks are the slots creat_vocab gives the flags.
pad_id, unk_id, begin_id, end_id = (
    c2i.get(flag, fallback)
    for flag, fallback in ((pad_flag, 0), (unk_flag, 1), (begin_flag, 2), (end_flag, 3))
)


In [3]:
# Quick first-pass preprocessing for now.

In [7]:
def read_q_a_text(qa_path, model='train'):
    """Return parallel lists of questions and answers for one data split.

    Args:
        qa_path: path of the QA JSON file.
        model: split key to read, e.g. 'train' or 'val'. Despite its name it
            selects a data split, not a model; the name is kept for
            backward compatibility.

    Returns:
        tuple[list, list]: ``(q_list, a_list)`` in record order.

    Raises:
        FileNotFoundError: if ``qa_path`` does not exist.
    """
    json_data = read_json_file(qa_path)
    if json_data is None:
        # read_json_file returns None for a missing file; report it clearly.
        raise FileNotFoundError(qa_path)
    pairs = [get_q_a_from_dict(qa) for qa in json_data[model]]
    q_list = [q for q, _ in pairs]
    a_list = [a for _, a in pairs]
    return q_list, a_list


def text_to_ids(texts, max_len=32, add_begin=True, add_end=True):
    """Encode each text as a list of exactly ``max_len`` character ids.

    Each item is converted with str() (matching creat_vocab) and truncated to
    ``max_len - 2`` characters. Two slots are always reserved, whichever
    flags are added, so the output matches the original encoding. The ids
    are then optionally prefixed with ``begin_id`` and/or suffixed with
    ``end_id``, and right-padded with ``pad_id``. Characters missing from
    ``c2i`` map to ``unk_id``.

    Args:
        texts: iterable of texts (non-strings are str()-converted).
        max_len: length of every returned id list.
        add_begin: prepend the [GO] id.
        add_end: append the [EOS] id.

    Returns:
        list[list[int]]: one fixed-length id list per input text.
    """
    limit = max(max_len, 0)
    # Clamp so that a tiny max_len cannot produce a negative slice
    # (line[:-1]) that keeps almost the whole text.
    keep = max(limit - 2, 0)
    ids = []
    for line in texts:
        id_line = [c2i.get(c, unk_id) for c in str(line)[:keep]]
        if add_begin:
            id_line.insert(0, begin_id)
        if add_end:
            id_line.append(end_id)
        # Only matters when max_len < 2: the flags alone would overflow.
        id_line = id_line[:limit]
        id_line += [pad_id] * (limit - len(id_line))
        ids.append(id_line)
    return ids


max_len = 32


def _encode_split(split):
    """Read ``split`` and return ``(q, a, q_ids, a_input_ids, a_label_ids)``.

    Questions get both [GO] and [EOS]. Answers are encoded twice: decoder
    input with [GO] only, and decoder label with [EOS] only (the same
    characters, shifted by one position).
    """
    q_texts, a_texts = read_q_a_text(utils.path.fm_qa_json_path, model=split)
    q_ids = text_to_ids(texts=q_texts, max_len=max_len, add_begin=True, add_end=True)
    a_input_ids = text_to_ids(texts=a_texts, max_len=max_len, add_begin=True, add_end=False)
    a_label_ids = text_to_ids(texts=a_texts, max_len=max_len, add_begin=False, add_end=True)
    return q_texts, a_texts, q_ids, a_input_ids, a_label_ids


q_train, a_train, q_train_ids, a_train_input_ids, a_train_label_ids = _encode_split('train')
q_val, a_val, q_val_ids, a_val_input_ids, a_val_label_ids = _encode_split('val')

# Pickle directly instead of calling save_pickle: that helper is defined in a
# later cell, so calling it here raised NameError on Restart & Run All.
pickle_outputs = [
    (q_train_ids, utils.path.fm_qa_train_src_path),
    (a_train_input_ids, utils.path.fm_qa_train_trg_input_path),
    (a_train_label_ids, utils.path.fm_qa_train_trg_label_path),
    (q_val_ids, utils.path.fm_qa_val_src_path),
    (a_val_input_ids, utils.path.fm_qa_val_trg_input_path),
    (a_val_label_ids, utils.path.fm_qa_val_trg_label_path),
]
for ids_data, out_path in pickle_outputs:
    with open(out_path, "wb") as out_file:
        pickle.dump(ids_data, out_file)


In [5]:
# Save / load data with pickle.
import pickle


def save_pickle(data, file_name):
    """Serialize ``data`` to ``file_name`` with pickle, overwriting the file.

    The file is closed even if pickling raises.
    """
    with open(file_name, "wb") as f:
        pickle.dump(data, f)


def load_pickle(file_name):
    """Load and return the object pickled in ``file_name``.

    Security: pickle.load can execute arbitrary code. Use it only on files
    this project wrote itself, never on untrusted input.
    """
    # Read-only mode is enough; "rb+" also demanded write permission.
    with open(file_name, "rb") as f:
        return pickle.load(f)

In [3]:
# Scratch check: character frequencies over train + val, counted the same
# way as creat_vocab. json_data used to be undefined here, and a stray ":a"
# made the cell a SyntaxError.
json_data = read_json_file(utils.path.fm_qa_json_path)
char_dict = {}
for split in ('train', 'val'):
    for qa in json_data[split]:
        q, a = get_q_a_from_dict(qa)
        for c in str(a) + str(q):
            char_dict[c] = char_dict.get(c, 0) + 1

In [4]:
# Scratch rebuild of the vocabulary from char_dict, mirroring creat_vocab.
# Use the shared constants; the literal list had the typo '[ESO]', so
# c2i.get(end_flag) would have missed it.
flags = [pad_flag, unk_flag, begin_flag, end_flag]

char_items = sorted(char_dict.items(), key=lambda x: x[1], reverse=True)
chars = [item[0] for item in char_items]
# This cell overwrites the global `vocab`, so filter whitespace exactly like
# creat_vocab. A '\n' entry would corrupt the one-token-per-line vocab file.
vocab = [c for c in flags + chars if c.strip()]

In [6]:
# Overwrite the cached vocab file that read_vocab loads (one token per line).
# The old bare `path` name was never defined; the paths live on utils.path.
save_text(utils.path.fm_qa_vocab_path, "\n".join(vocab))