In [1]:
# coding: utf-8
from src.train_and_evaluate import *
from src.models import *
import os
import re
import time
import torch.optim
from src.expressions_transfer import *
import json
import random
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def read_json(path, encoding='utf-8'):
    """Load and return the JSON document stored at ``path``.

    Args:
        path: file to read.
        encoding: text encoding used to decode the file. Defaults to UTF-8,
            the encoding required for interchanged JSON (RFC 8259) and the one
            this notebook writes its own JSON caches with. Without an explicit
            encoding, ``open`` falls back to the locale codec (e.g. cp936/GBK on
            a Chinese Windows install), which mis-decodes Chinese text.

    Returns:
        The decoded Python object.
    """
    with open(path, 'r', encoding=encoding) as f:
        return json.load(f)


In [3]:

# Raw Math23K problems (dicts with id / original_text / segmented_text /
# equation / ans, as the printed sample shows) plus the per-problem
# 'group_num' annotations. Each file is loaded exactly once.
data = load_raw_data("data/Math_23K.json")
group_data = read_json("data/Math_23K_processed.json")

print(data[1], len(data))

Reading lines...
Reading lines...
{'id': '2', 'original_text': '一个工程队挖土，第一天挖了316方，从第二天开始每天都挖230方，连续挖了6天，这个工程队一周共挖土多少方？', 'segmented_text': '一 个 工程队 挖土 ， 第一天 挖 了 316 方 ， 从 第 二 天 开始 每天 都 挖 230 方 ， 连续 挖 了 6 天 ， 这个 工程队 一周 共 挖土 多少 方 ？', 'equation': 'x=316+230*(6-1)', 'ans': '1466'} 23162


In [4]:
# Punctuation tokens that must survive cleaning even though most of them (the
# full-width Chinese marks) lie outside the character ranges kept below.
_KEEP_PUNCT = frozenset([
    "．", "？", "（", "）", ",", "：", "；", "！", "，", "“", "”",
    ".", "?", "。", "｡", "(", ")",
])
# Anything that is neither ASCII nor in the CJK Unified Ideographs range U+4E00-U+9FA5.
_UNSUPPORTED_CHARS = re.compile(r'[^\x00-\x7F\u4E00-\u9FA5]')


def remove_non_unicode_chars(text):
    """Split a space-separated token string and strip unsupported characters.

    Despite the name, this does not target "non-unicode" characters. Every
    character outside ASCII (U+0000-U+007F) and the CJK Unified Ideographs
    range U+4E00-U+9FA5 is deleted from each token. The exception is a token
    that is exactly one of the whitelisted punctuation marks, which is kept
    as is. Tokens that end up empty are dropped.

    Args:
        text: segmented sentence whose tokens are separated by single spaces.

    Returns:
        List of cleaned, non-empty tokens in their original order.
    """
    cleaned = []
    for token in text.split(' '):
        if token not in _KEEP_PUNCT:
            token = _UNSUPPORTED_CHARS.sub('', token)
        if token:
            cleaned.append(token)
    return cleaned

# Normalise every problem. The tokens come from the pre-segmented text with
# unsupported characters stripped. 'original_text' is rebuilt by joining those
# cleaned tokens, so it no longer contains the separating spaces.
cleaned_train_data = []
for item in data:
    tokens = remove_non_unicode_chars(item['segmented_text'])  # clean once, reuse for both fields
    cleaned_train_data.append({
        'id': item['id'],
        'original_text': ''.join(tokens),
        'segmented_text': tokens,
        'equation': item['equation'],
        'ans': item['ans'],
    })

In [5]:
import tqdm

In [6]:
# Cleaning maps raw problems 1:1, so no problem may be lost here.
assert len(cleaned_train_data) == len(data)
print(len(cleaned_train_data))

23162


In [7]:
from ltp import LTP

# LTP model used by transfer_ro_num for dependency parsing. Set the LTP_PATH
# environment variable to override the machine-specific default location.
LTP_PATH = os.environ.get("LTP_PATH", r"E:\research_v2\tools\Ltp_base2_v3_")
ltp = LTP(path=LTP_PATH)

In [8]:
def transfer_ro_num(data, min_generate_count=5):  # transfer num into "NUM"
    """Mask the numbers in each problem with "NUM", tag the equation, and parse it.

    Each element of ``data`` is a dict providing 'id', 'ans', 'equation' and
    'segmented_text' (a list of tokens, as produced by the cleaning step).

    For each problem:
      * A token that starts with a number (integer, decimal, percentage, or a
        parenthesised fraction such as "(5/16)") contributes that number to
        ``nums`` and becomes "NUM". Any text after the number becomes its own
        token.
      * The equation, minus its first two characters (the "x=" prefix of
        Math23K equations such as 'x=316+230*(6-1)'), is split into single
        characters and numbers. Every number occurring exactly once in
        ``nums`` is replaced by "N<index into nums>"; other numbers stay
        literal.
      * The number-split tokens are dependency-parsed with the global ``ltp``
        model.

    Args:
        data: iterable of problem dicts (see above).
        min_generate_count: a constant that appears in equations but not in
            the problem text is returned only if it occurs at least this many
            times. Defaults to 5, the previously hard-coded threshold.

    Returns:
        A tuple ``(pairs, generate_nums, copy_nums)``:
          * ``pairs``: list of dicts with keys id, original_text, num_text,
            infix_equation, parse, nums, num_pos and ans.
          * ``generate_nums``: the filtered constants, in first-seen order.
          * ``copy_nums``: the largest number of numbers found in one problem.
    """
    print("Transfer numbers...")
    # Raw strings: a backslash-d inside a normal literal is an invalid escape
    # sequence (DeprecationWarning, and a SyntaxWarning since Python 3.12).
    num_pattern = re.compile(r"\d*\(\d+/\d+\)\d*|\d+\.\d+%?|\d+%?")
    fraction_pattern = re.compile(r"\d*\(\d+/\d+\)\d*")
    plain_num_pattern = re.compile(r"\d+\.\d+%?|\d+%?")

    pairs = []
    generate_nums = []
    generate_nums_dict = {}
    copy_nums = 0
    for d in tqdm.tqdm(data, desc='parse the sentence ... '):
        problem_id = d['id']
        seg = d["segmented_text"]
        ans = d['ans']
        equations = d['equation'][2:]

        nums = []
        input_seq = []  # tokens with numbers masked as "NUM"
        ori_seg = []    # same tokenisation, numbers kept verbatim
        for s in seg:
            pos = num_pattern.search(s)
            if pos and pos.start() == 0:
                number = s[pos.start(): pos.end()]
                nums.append(number)
                input_seq.append("NUM")
                ori_seg.append(number)
                if pos.end() < len(s):
                    # Keep any unit/suffix glued to the number as a separate token.
                    input_seq.append(s[pos.end():])
                    ori_seg.append(s[pos.end():])
            elif len(s) > 0:
                input_seq.append(s)
                ori_seg.append(s)
        copy_nums = max(copy_nums, len(nums))

        # Longest fractions first, so that a mixed number like "1(1/2)" is
        # matched before the "(1/2)" it contains.
        nums_fraction = sorted((num for num in nums if fraction_pattern.search(num)),
                               key=len, reverse=True)

        def tag(num):
            # Numbers that occur exactly once become a pointer into `nums`;
            # repeated numbers are ambiguous and stay literal.
            return "N" + str(nums.index(num)) if nums.count(num) == 1 else num

        def seg_and_tag(st):  # seg the equation and tag the num
            res = []
            for n in nums_fraction:
                if n in st:
                    p_start = st.find(n)
                    p_end = p_start + len(n)
                    if p_start > 0:
                        res += seg_and_tag(st[:p_start])
                    res.append(tag(n))
                    if p_end < len(st):
                        res += seg_and_tag(st[p_end:])
                    return res
            pos_st = plain_num_pattern.search(st)
            if pos_st:
                p_start = pos_st.start()
                p_end = pos_st.end()
                if p_start > 0:
                    res += seg_and_tag(st[:p_start])
                res.append(tag(st[p_start:p_end]))
                if p_end < len(st):
                    res += seg_and_tag(st[p_end:])
                return res
            # No number left: emit the remaining characters one by one.
            res.extend(st)
            return res

        out_seq = seg_and_tag(equations)
        for s in out_seq:  # tag the num which is generated
            if s[0].isdigit() and s not in generate_nums and s not in nums:
                generate_nums.append(s)
                generate_nums_dict[s] = 0
            if s in generate_nums and s not in nums:
                generate_nums_dict[s] = generate_nums_dict[s] + 1

        num_pos = [i for i, tok in enumerate(input_seq) if tok == "NUM"]
        assert len(nums) == len(num_pos)

        words, hidden = ltp.seg([ori_seg], is_preseged=True)
        dep = ltp.dep(hidden)
        # Head index minus one, so the root comes out as -1. This assumes LTP's
        # dep() yields (dependent, head, relation) with 1-based heads; TODO
        # confirm for the LTP version in use.
        parse = [arc[1] - 1 for arc in dep[0]]

        assert len(words[0]) == len(input_seq)
        pairs.append({
            'id': problem_id,
            'original_text': words[0],
            'num_text': input_seq,
            'infix_equation': out_seq,
            'parse': parse,
            'nums': nums,
            'num_pos': num_pos,
            'ans': ans,
        })

    temp_g = [g for g in generate_nums if generate_nums_dict[g] >= min_generate_count]
    return pairs, temp_g, copy_nums

In [9]:
# Cache the slow LTP parsing step. The pairs live in PAIRS_CACHE. The two
# statistics derived alongside them (generate_nums, copy_nums) go in a small
# side file so they cannot silently drift from the cached data. The literal
# values below are only a fallback for caches written before the side file
# existed.
PAIRS_CACHE = './data/train_parse_pairs.json'
PAIRS_META_CACHE = './data/train_parse_meta.json'

if os.path.exists(PAIRS_CACHE):
    print('加载已经处理数据集...')  # "loading the already-processed dataset..."
    with open(PAIRS_CACHE, 'r', encoding='utf-8') as f:
        pairs = json.load(f)
    if os.path.exists(PAIRS_META_CACHE):
        with open(PAIRS_META_CACHE, 'r', encoding='utf-8') as f:
            meta = json.load(f)
        generate_nums, copy_nums = meta['generate_nums'], meta['copy_nums']
    else:
        generate_nums = ['1', '3.14']
        copy_nums = 15
else:
    print('... ing 处理数据集 ing ...')  # "processing the dataset"
    pairs, generate_nums, copy_nums = transfer_ro_num(cleaned_train_data)
    print(generate_nums, copy_nums)
    with open(PAIRS_CACHE, 'w', encoding='utf-8') as f:
        json.dump(pairs, f, ensure_ascii=False, indent=1)
    with open(PAIRS_META_CACHE, 'w', encoding='utf-8') as f:
        json.dump({'generate_nums': generate_nums, 'copy_nums': copy_nums}, f, ensure_ascii=False)

加载已经处理数据集...


# Persist the processed pairs so the cache check can skip the slow LTP parse
# on the next run. json.dump streams to the file instead of first building the
# whole serialized dataset as an extra in-memory string.
with open('./data/train_parse_pairs.json', 'w', encoding='utf-8') as file:
    json.dump(pairs, file, ensure_ascii=False, indent=1)

In [10]:
# Spot-check the last processed record: a dict with keys id, original_text,
# num_text, infix_equation, parse, nums, num_pos, ans.
pairs[-1]

{'id': '23162',
 'original_text': ['一年级',
  '和',
  '二年级',
  '学生',
  '到',
  '小精灵',
  '剧场',
  '看',
  '木偶戏',
  '，',
  '一年级',
  '有',
  '186',
  '人',
  '，',
  '二年级',
  '有',
  '235',
  '人',
  '．',
  '剧院',
  '共有',
  '500',
  '个',
  '座位',
  '，',
  '还有',
  '多少',
  '个',
  '空座位',
  '？'],
 'num_text': ['一年级',
  '和',
  '二年级',
  '学生',
  '到',
  '小精灵',
  '剧场',
  '看',
  '木偶戏',
  '，',
  '一年级',
  '有',
  'NUM',
  '人',
  '，',
  '二年级',
  '有',
  'NUM',
  '人',
  '．',
  '剧院',
  '共有',
  'NUM',
  '个',
  '座位',
  '，',
  '还有',
  '多少',
  '个',
  '空座位',
  '？'],
 'infix_equation': ['N2', '-', 'N0', '-', 'N1'],
 'parse': [3,
  2,
  0,
  7,
  7,
  6,
  4,
  -1,
  7,
  7,
  11,
  7,
  13,
  11,
  11,
  16,
  11,
  18,
  16,
  11,
  21,
  7,
  23,
  24,
  21,
  21,
  21,
  28,
  29,
  26,
  7],
 'nums': ['186', '235', '500'],
 'num_pos': [12, 17, 22],
 'ans': '79'}

In [11]:
# Each group_data entry pairs a problem 'id' with its 'group_num' index list.
group_data[0]

{'id': '1', 'group_num': [15, 16, 17, 32, 33, 34, 39, 40, 41]}

In [12]:
# The last entry's id should match the last problem in `data`, because the two
# lists are later zipped position by position.
group_data[-1]

{'id': '23162', 'group_num': [16, 17, 18, 19, 21, 22, 23, 24, 27, 28, 29, 30]}

In [13]:
# Convert each processed record into the positional tuple layout consumed by
# get_train_test_fold and then prepare_ro_data:
#   (num_text, prefix_equation, nums, num_pos, parse, original_text, id)
temp_pairs = []
for p in pairs:
    # num_text (numbers masked) and original_text must be token-aligned.
    # num_pos indexes the former and parse was computed on the latter.
    assert len(p['num_text']) == len(p['original_text']), (
        f"token length mismatch for problem {p['id']}: "
        f"{len(p['num_text'])} vs {len(p['original_text'])}")
    temp_pairs.append((p['num_text'], from_infix_to_prefix(p['infix_equation']), p['nums'],
                       p['num_pos'], p['parse'], p['original_text'], p['id']))

In [14]:
# Tuple layout: (num_text, prefix_equation, nums, num_pos, parse, original_text, id).
print(temp_pairs[1])

(['一', '个', '工程队', '挖土', '，', '第一天', '挖', '了', 'NUM', '方', '，', '从', '第', '二', '天', '开始', '每天', '都', '挖', 'NUM', '方', '，', '连续', '挖', '了', 'NUM', '天', '，', '这个', '工程队', '一周', '共', '挖土', '多少', '方', '？'], ['+', 'N0', '*', 'N1', '-', 'N2', '1'], ['316', '230', '6'], [8, 19, 25], [1, 2, 3, -1, 3, 6, 3, 6, 9, 6, 3, 18, 13, 14, 11, 11, 18, 18, 3, 20, 18, 3, 23, 3, 23, 26, 23, 3, 29, 32, 32, 32, 3, 34, 32, 3], ['一', '个', '工程队', '挖土', '，', '第一天', '挖', '了', '316', '方', '，', '从', '第', '二', '天', '开始', '每天', '都', '挖', '230', '方', '，', '连续', '挖', '了', '6', '天', '，', '这个', '工程队', '一周', '共', '挖土', '多少', '方', '？'], '2')


In [15]:
# Split files are named ori_path + {'train', 'valid', 'test'} + prefix,
# e.g. ./data/train23k_processed.json.
ori_path, prefix = './data/', '23k_processed.json'

In [16]:
def get_train_test_fold(ori_path,prefix,data,pairs,group):
    """Attach group annotations to each pair and split the pairs by problem id.

    Split membership comes from ``os.path.join(ori_path, mode + prefix)`` for
    mode in train/test (e.g. ./data/train23k_processed.json). Each file is a
    JSON list of records with an 'id' key. ``data``, ``pairs`` and ``group``
    must be aligned position by position (the same problem at the same index).
    Ids are compared as strings.

    Each returned entry is ``tuple(pair) + (group_num, group_id)``. Problems
    whose id is in neither the train nor the test file go to the valid fold.

    Returns:
        (train_fold, test_fold, valid_fold)

    Raises:
        ValueError: if the three inputs differ in length, or if a data id does
            not match the group id at the same position. Previously ``zip``
            silently truncated or misaligned the inputs.
    """
    if not len(data) == len(pairs) == len(group):
        raise ValueError(
            f"length mismatch: data={len(data)}, pairs={len(pairs)}, group={len(group)}")

    def _split_ids(mode):
        # A set gives O(1) membership; list lookups made the split O(N*M).
        return {str(item['id']) for item in read_json(os.path.join(ori_path, mode + prefix))}

    train_ids = _split_ids('train')
    test_ids = _split_ids('test')

    train_fold = []
    valid_fold = []
    test_fold = []
    for item, pair, g in zip(data, pairs, group):
        item_id = str(item['id'])
        if item_id != str(g['id']):
            raise ValueError(
                f"data/group misaligned: data id {item['id']!r} vs group id {g['id']!r}")
        entry = tuple(pair) + (g['group_num'], g['id'])
        if item_id in train_ids:
            train_fold.append(entry)
        elif item_id in test_ids:
            test_fold.append(entry)
        else:
            valid_fold.append(entry)
    return train_fold, test_fold, valid_fold

In [17]:
# Rebind `pairs` from the list of record dicts to the positional tuples built
# in temp_pairs. get_train_test_fold iterates each pair as a field sequence.
pairs = temp_pairs

In [18]:
# Attach group_num annotations and split by the ids in the train/test split
# files; everything else lands in valid_fold.
train_fold, test_fold, valid_fold = get_train_test_fold(
    ori_path=ori_path,
    prefix=prefix,
    data=data,
    pairs=pairs,
    group=group_data,
)

In [19]:
# Fold entries extend the temp_pairs tuple with (group_num, group id).
print(train_fold[1])

(['一', '个', '工程队', '挖土', '，', '第一天', '挖', '了', 'NUM', '方', '，', '从', '第', '二', '天', '开始', '每天', '都', '挖', 'NUM', '方', '，', '连续', '挖', '了', 'NUM', '天', '，', '这个', '工程队', '一周', '共', '挖土', '多少', '方', '？'], ['+', 'N0', '*', 'N1', '-', 'N2', '1'], ['316', '230', '6'], [8, 19, 25], [1, 2, 3, -1, 3, 6, 3, 6, 9, 6, 3, 18, 13, 14, 11, 11, 18, 18, 3, 20, 18, 3, 23, 3, 23, 26, 23, 3, 29, 32, 32, 32, 3, 34, 32, 3], ['一', '个', '工程队', '挖土', '，', '第一天', '挖', '了', '316', '方', '，', '从', '第', '二', '天', '开始', '每天', '都', '挖', '230', '方', '，', '连续', '挖', '了', '6', '天', '，', '这个', '工程队', '一周', '共', '挖土', '多少', '方', '？'], '2', [6, 7, 8, 15, 16, 17, 21, 22, 23], '2')


In [20]:
# Spot-check an example from the middle of the training split (index presumably arbitrary).
print(train_fold[10000])

(['NUM', '/', 'NUM', '的', '商', '，', '加上', 'NUM', '，', '再', '乘', 'NUM', '，', '积', '=', '？'], ['*', '+', '/', 'N1', 'N0', 'N2', 'N3'], ['4', '1.8', '3', '2'], [0, 2, 7, 11], [4, 2, 4, 2, 6, 4, -1, 6, 6, 10, 6, 10, 6, 6, 6, 6], ['4', '/', '1.8', '的', '商', '，', '加上', '3', '，', '再', '乘', '2', '，', '积', '=', '？'], '10941', [0, 1, 1, 2, 3, 6, 7, 8, 10, 11, 12, 12, 13, 14], '10941')


In [21]:
# Inspect the last entry of the test fold.
print(test_fold[-1])

(['广场', '新种', '了', '一批', '花木', '，', '其中', 'NUM', '是', '玫瑰', '，', 'NUM', '是', '月季', '．', '已知', '月季', '有', 'NUM', '棵', '，', '玫瑰', '有', '多少', '棵', '？'], ['*', '/', 'N2', 'N1', 'N0'], ['(5/16)', '(3/8)', '36'], [7, 11, 18], [1, 4, 1, 4, -1, 4, 8, 8, 4, 8, 8, 12, 8, 12, 8, 17, 17, 8, 19, 17, 17, 22, 17, 24, 22, 8], ['广场', '新种', '了', '一批', '花木', '，', '其中', '(5/16)', '是', '玫瑰', '，', '(3/8)', '是', '月季', '．', '已知', '月季', '有', '36', '棵', '，', '玫瑰', '有', '多少', '棵', '？'], '23146', [11, 12, 13, 17, 18, 19, 21, 22, 23, 24], '23146')


In [22]:
# Inspect the last entry of the valid fold (problems not listed in the train or test split files).
print(valid_fold[-1])

(['王', '叔叔', '从', '甲', '城', '到', '乙', '城', '，', '第一天', '行', '了', '全程', '的', 'NUM', '，', '第', '二', '天', '行', '了', '全程', '的', 'NUM', '，', '距', '乙', '城', '还有', 'NUM', '千米', '．', '甲', '乙', '两城', '相距', '多少', '千米', '？'], ['/', 'N2', '-', '-', '1', 'N0', 'N1'], ['40%', '(9/20)', '90'], [14, 23, 29], [1, 10, 5, 4, 2, 10, 7, 5, 5, 10, -1, 10, 14, 12, 10, 10, 17, 18, 19, 10, 19, 19, 21, 21, 10, 28, 27, 25, 10, 30, 28, 10, 34, 32, 35, 10, 37, 35, 10], ['王', '叔叔', '从', '甲', '城', '到', '乙', '城', '，', '第一天', '行', '了', '全程', '的', '40%', '，', '第', '二', '天', '行', '了', '全程', '的', '(9/20)', '，', '距', '乙', '城', '还有', '90', '千米', '．', '甲', '乙', '两城', '相距', '多少', '千米', '？'], '23144', [10, 11, 12, 18, 19, 21, 22, 23, 24, 25, 29, 30, 31, 32], '23144')


In [23]:
# BERT-based character encoder loaded from a local checkpoint directory. Set
# the BERT_PATH environment variable to override the machine-specific default.
BERT_PATH = os.environ.get(
    "BERT_PATH", r"E:\research_v2\tools\have_fine_tune\chinese_roberta_wwm_ext_D\No\model")
bert = EncoderChar(bert_path=BERT_PATH, bert_size=768, hidden_size=512, get_word_and_sent=True)
start = time.time()  # wall-clock start; presumably used for timing later cells

Some weights of the model checkpoint at E:\research_v2\tools\have_fine_tune\chinese_roberta_wwm_ext_D\No\model were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at E:\research_v2\tools\have_fine_tune\chinese_roberta_ww

In [24]:
from src.pre_data import prepare_ro_data

In [25]:
# Build the vocabularies and index the splits. NOTE: valid_fold (not
# test_fold) is passed as the evaluation split, so `test_pairs` holds the
# valid problems. The literal 5 is presumably a minimum word-frequency cutoff
# for the input vocabulary (cf. the "keep_words" log line); TODO confirm
# against prepare_ro_data in src.pre_data.
input_lang, output_lang, train_pairs, test_pairs = prepare_ro_data(train_fold, valid_fold, 5, generate_nums,
                                                                   copy_nums, bert.tokenizer, tree=True)

----------------------------------------  Indexing words...  ----------------------------------------
 <----------------------------  keep_words 3913 / 10492 = 0.3730  ----------------------------> 
Indexed 3933 words in input language, 23 words in output
Number of training data 21162
output_lang.index2word:  ['*', '-', '+', '/', '^', '1', '3.14', 'N0', 'N1', 'N2', 'N3', 'N4', 'N5', 'N6', 'N7', 'N8', 'N9', 'N10', 'N11', 'N12', 'N13', 'N14', 'UNK']
Number of testind data 1000


In [26]:
# Inspect one indexed training example. Its field layout is defined by
# prepare_ro_data (src.pre_data, not shown here).
print(train_pairs[1])

([57, 58, 59, 60, 48, 61, 62, 45, 24, 63, 48, 64, 65, 66, 67, 68, 69, 44, 62, 24, 63, 48, 70, 62, 45, 24, 67, 48, 71, 59, 72, 73, 60, 56, 63, 74], 36, [2, 7, 0, 8, 1, 9, 5], 7, ['316', '230', '6'], [8, 19, 25], [], [1, 2, 3, -1, 3, 6, 3, 6, 9, 6, 3, 18, 13, 14, 11, 11, 18, 18, 3, 20, 18, 3, 23, 3, 23, 26, 23, 3, 29, 32, 32, 32, 3, 34, 32, 3], [6, 7, 8, 15, 16, 17, 21, 22, 23], ['一', '个', '工程队', '挖土', '，', '第一天', '挖', '了', '316', '方', '，', '从', '第', '二', '天', '开始', '每天', '都', '挖', '230', '方', '，', '连续', '挖', '了', '6', '天', '，', '这个', '工程队', '一周', '共', '挖土', '多少', '方', '？'], array([[0, 1, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 1, 0, 0],
       [0, 0, 0, ..., 0, 1, 0]], dtype=int64), [(8, [14]), (19, [27]), (25, [34])])
