In [1]:
import json
import pickle
import random
#import argparse
import numpy as np
import jieba

#### Reading tools

In [2]:
def read_data_json(filename, encoding='utf-8'):
    """Load and return the JSON document stored in ``filename``.

    The file is decoded with ``encoding`` (UTF-8 by default) instead of the
    platform's locale encoding, so files containing Chinese text load the same
    way on every OS.
    """
    with open(filename, 'r', encoding=encoding) as f:
        return json.load(f)

def write_data_json(data, filename):
    """Dump ``data`` into ``filename`` as pretty-printed JSON (4-space indent).

    The target file is created or truncated. json's default ``ensure_ascii``
    keeps the output ASCII-only, so non-ASCII text is written as escape
    sequences.
    """
    with open(filename, 'w') as handle:
        json.dump(obj=data, fp=handle, indent=4)
        
def read_json_line(filename, encoding='utf-8'):
    """Read a JSON-Lines file and return one parsed object per line.

    Lines that are empty or whitespace-only (e.g. a trailing newline at the end
    of the file) are skipped instead of making ``json.loads`` raise. The file is
    decoded with ``encoding`` (UTF-8 by default) rather than the locale default.
    """
    with open(filename, 'r', encoding=encoding) as f:
        return [json.loads(line) for line in f if line.strip()]

def read_math23k_json(filename, lines_per_record=7, encoding='utf-8'):
    """Parse a file of pretty-printed JSON objects that each span a fixed
    number of lines (7 in the Math23K release files).

    Consecutive groups of ``lines_per_record`` lines are joined and decoded
    with ``json.loads``. Trailing blank lines are ignored. A non-blank partial
    record at the end of the file raises ValueError; the previous version
    silently dropped it. The file is decoded with ``encoding`` (UTF-8 by
    default) rather than the locale default.
    """
    data_list = []
    buffer = []
    with open(filename, 'r', encoding=encoding) as f:
        for line in f:
            buffer.append(line)
            if len(buffer) == lines_per_record:
                data_list.append(json.loads(''.join(buffer)))
                buffer = []
    leftover = ''.join(buffer)
    if leftover.strip():
        raise ValueError('%s ends with an incomplete record (%d of %d lines)'
                         % (filename, len(buffer), lines_per_record))
    return data_list

#### Load training set and test set

In [3]:
# Raw Math23K splits (7-line JSON records) plus the per-problem annotation
# dictionary, which is keyed by problem id.
math23k_train = read_math23k_json("./data/math23k_train.json")
math23k_test = read_math23k_json("./data/math23k_test.json")
sni_dict = read_data_json("./data/sni_DNS.json")

#### Split words using jieba

In [4]:
# For every problem in both splits: copy the text stored in sni_dict into
# 'sni_text', then segment it with jieba in accurate mode (cut_all=False) and
# store the space-joined tokens in 'new_split'.
for dataset in (math23k_train, math23k_test):
    for elem in dataset:
        elem['sni_text'] = sni_dict[elem['id']]['text']
        elem['new_split'] = ' '.join(jieba.cut(elem['sni_text'], cut_all=False))

Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.898 seconds.
Prefix dict has been built succesfully.


#### Pre-processing tools

In [5]:
def joint_number(text_list):
    """Merge every 5-token window of the form ``'(' x y z ')'`` into one token.

    Jieba splits a fraction such as ``(1/5)`` into ``'(' '1' '/' '5' ')'``; this
    glues such windows back together into ``'(1/5)'``. Only the first and last
    tokens of the window are checked, so any three tokens between the
    parentheses are merged. Every other token is passed through unchanged.
    """
    merged = []
    pos = 0
    total = len(text_list)
    while pos < total:
        window_end = pos + 5
        if window_end <= total and text_list[pos] == '(' and text_list[pos + 4] == ')':
            merged.append(''.join(text_list[pos:window_end]))
            pos = window_end
        else:
            merged.append(text_list[pos])
            pos += 1
    return merged

def is_number(word):
    """Heuristically decide whether a segmented text token denotes a number.

    A token counts as numeric when it:
      - is wrapped in parentheses and contains a digit, e.g. '(1/2)';
      - contains '(', ')' and '/', does not end in a digit, and contains a
        digit (e.g. a fraction followed by a unit);
      - is a percentage such as '50%' (a bare '%' is not);
      - starts or ends with a digit, e.g. '3kg';
      - or is accepted by float() (note: this also accepts 'nan' / 'inf').

    Empty tokens return False. The previous version raised IndexError on them.
    """
    if not word:
        return False
    if word[0] == '(' and word[-1] == ')':
        return any(ch.isdigit() for ch in word)
    if '(' in word and ')' in word and '/' in word and not word[-1].isdigit():
        return any(ch.isdigit() for ch in word)
    if word[-1] == '%' and len(word) > 1:
        return True
    if word[0].isdigit() or word[-1].isdigit():
        return True
    try:
        float(word)
    except (TypeError, ValueError):
        return False
    return True
    
def split_num_and_unit(word):
    """Split ``word`` into (numeric characters, remaining characters).

    Digits and the characters ``. / ( )`` go to the first string; every other
    character goes to the second. Relative order is kept within each part,
    e.g. ``'12.5kg'`` -> ``('12.5', 'kg')``.
    """
    num_chars = []
    unit_chars = []
    for ch in word:
        target = num_chars if (ch.isdigit() or ch in './()') else unit_chars
        target.append(ch)
    return ''.join(num_chars), ''.join(unit_chars)

def mask_num(seg_text_list, equ_str):
    '''
    Replace the numbers in a segmented problem text, and the same numbers in
    its equation string, by placeholders ``temp_a``, ``temp_b``, ...

    Parameters
    ----------
    seg_text_list : list of str
        Problem text tokens (typically after ``joint_number``). Empty tokens
        are dropped.
    equ_str : str
        The equation string, e.g. ``'x=(11-1)*2'``.

    Returns
    -------
    mask_seg_text : list of str
        Text tokens in which every token accepted by ``is_number`` becomes
        ``temp_<letter>``. When the token contained '%', a '%' token follows
        the placeholder. Otherwise, when it contained a letter a-z
        (case-insensitive), its non-numeric part (see ``split_num_and_unit``)
        follows the placeholder.
    num_list : list of str
        The raw numeric strings in order of appearance (``temp_a`` is
        ``num_list[0]``). Percent tokens keep their '%'; tokens with letters
        are reduced to their numeric part.
    equ_list : list of str
        The equation tokenised into placeholders, 'PI', numeric literals and
        single characters, with '[' / ']' mapped to '(' / ')'.

    Only 26 placeholders exist; a 27th number raises IndexError.
    '''
    alphas = 'abcdefghijklmnopqrstuvwxyz'
    letters = set(alphas)

    # Mask numbers in the text; the placeholder index is the count so far.
    num_list = []
    mask_seg_text = []
    for word in seg_text_list:
        if word == '':
            continue
        if not is_number(word):
            mask_seg_text.append(word)
            continue
        mask_seg_text.append('temp_' + alphas[len(num_list)])
        if '%' in word:
            mask_seg_text.append('%')
        elif letters & set(word.lower()):
            # e.g. '5kg' -> placeholder followed by the unit token 'kg'.
            num, unit = split_num_and_unit(word)
            mask_seg_text.append(unit)
            word = num
        num_list.append(word)

    # Mask the same numbers in the equation, longest strings first so that
    # e.g. '12' is replaced before '2'. sorted() is stable (also with
    # reverse=True), so equal-length numbers keep their text order.
    by_length = sorted(enumerate(num_list), key=lambda pair: len(pair[1]), reverse=True)
    # A literal 3.14 becomes the PI token unless it is part of 3.14% / 3.1416.
    if '3.14%' not in equ_str and '3.1416' not in equ_str:
        equ_str = equ_str.replace('3.14', '&PI&', 15)
    for idx, num in by_length:
        # At most 15 occurrences are replaced per number.
        equ_str = equ_str.replace(num, '&temp_' + alphas[idx] + '&', 15)

    # Tokenise: '&'-delimited placeholders stay whole; elsewhere runs of
    # digits/'%'/'.' form one literal and every other character is a token.
    num_chars = set('0123456789%.')
    equ_list = []
    for piece in equ_str.split('&'):
        if 'temp' in piece or 'PI' in piece:
            equ_list.append(piece)
            continue
        literal = ''
        for char in piece:
            if char in num_chars:
                literal += char
                continue
            if literal:
                equ_list.append(literal)
                literal = ''
            equ_list.append(char)
        if literal:
            equ_list.append(literal)

    bracket_map = {'[': '(', ']': ')'}
    return mask_seg_text, num_list, [bracket_map.get(tok, tok) for tok in equ_list]



def num_list_processed(num_list):
    """Convert the raw number strings produced by ``mask_num`` to numbers.

    - contains '%'   -> drop the last character and divide by 100 ('60%' -> 0.6)
    - contains '('   -> evaluated as a Python expression ('(1/4)' -> 0.25)
    - anything else  -> keep the characters ``split_num_and_unit`` deems
      numeric and parse them with float()
    """
    values = []
    for raw in num_list:
        if '%' in raw:
            values.append(float(raw[:-1]) / 100)
        elif '(' in raw:
            # SECURITY NOTE(review): eval() executes dataset text as Python
            # code. Only acceptable for trusted data; a restricted arithmetic
            # parser (e.g. built on the ast module) would be safer.
            values.append(eval(raw))
        else:
            digits, _unit = split_num_and_unit(raw)
            values.append(float(digits))
    return values

def postfix_equation(equ_list):
    '''
    Convert an infix token list to postfix (RPN) with the shunting-yard
    algorithm.

    Precedence: '+' '-' (1) < '*' '/' (2) < '^' (3). '+', '-', '*' and '/' are
    left-associative. '^' is right-associative, so ``a ^ b ^ c`` becomes
    ``a b c ^ ^`` (a^(b^c)); the previous version produced (a^b)^c. Every other
    token (numbers, placeholders, and also 'x' / '=') is an operand and is
    copied to the output in order.
    '''
    precedence = {'^': 3, '*': 2, '/': 2, '+': 1, '-': 1}
    stack = []
    post_equ = []
    for elem in equ_list:
        if elem == '(':
            stack.append(elem)
        elif elem == ')':
            # Pop operators back to the matching '('; the '(' is discarded.
            op = stack.pop()
            while op != '(':
                post_equ.append(op)
                op = stack.pop()
        elif elem in precedence:
            while stack and stack[-1] != '(':
                top = stack[-1]
                if precedence[elem] > precedence[top]:
                    break
                # Right-associative '^' must not pop an equal-precedence '^'.
                if elem == '^' and top == '^':
                    break
                post_equ.append(stack.pop())
            stack.append(elem)
        else:
            post_equ.append(elem)
    while stack:
        post_equ.append(stack.pop())
    return post_equ

def post_solver(post_equ):
    '''
    Evaluate a postfix token list and return the top of the stack as a string.

    Operand tokens are pushed as strings. Each binary operator pops the right
    operand first, then the left one, converts both with float(), and pushes
    ``str(result)``. '^' is exponentiation. Tokens left below the result (such
    as a leading 'x', '=') are ignored.
    '''
    binary_ops = {
        '+': lambda lhs, rhs: lhs + rhs,
        '-': lambda lhs, rhs: lhs - rhs,
        '*': lambda lhs, rhs: lhs * rhs,
        '/': lambda lhs, rhs: lhs / rhs,
        '^': lambda lhs, rhs: lhs ** rhs,
    }
    operands = []
    for token in post_equ:
        apply_op = binary_ops.get(token)
        if apply_op is None:
            operands.append(str(token))
            continue
        rhs = float(operands.pop())
        lhs = float(operands.pop())
        operands.append(str(apply_op(lhs, rhs)))
    return operands.pop()
         
def solve_equation(equ_list):
    '''
    Evaluate an infix token list and return the result as a string.

    If the list contains '=', every token up to and including the first '='
    (e.g. an ``'x', '='`` prefix) is dropped before evaluation. The previous
    version always dropped exactly two tokens, which was only correct when '='
    sat at index 1.
    '''
    if '=' in equ_list:
        equ_list = equ_list[equ_list.index('=') + 1:]
    post_equ = postfix_equation(equ_list)
    return post_solver(post_equ)

def inverse_temp_to_num(equ_list, num_list):
    '''
    Substitute concrete values back into a template token list.

    A token containing 'temp' is replaced by ``str(num_list[k])``, where k is
    the alphabet position of its last character (temp_a -> 0). The token 'PI'
    becomes '3.14'. All other tokens are returned unchanged, in a new list.
    '''
    alphabet = "abcdefghijklmnopqrstuvwxyz"

    def resolve(token):
        if 'temp' in token:
            return str(num_list[alphabet.index(token[-1])])
        if token == 'PI':
            return '3.14'
        return token

    return [resolve(token) for token in equ_list]

def ans_num_joint(word):
    '''
    Tokenise an answer expression such as ``'((1)/(2))'`` and evaluate it.

    Each maximal run of digits, '.' and '-' becomes one token; every other
    character is a token of its own. The tokens are evaluated with
    ``solve_equation`` and the result is returned as a string.

    NOTE(review): because '-' is folded into number runs, a minus sign between
    two digits ('3-1') yields a single unparseable token.
    '''
    tokens = []
    run = ''
    for char in word:
        if char.isdigit() or char in ('.', '-'):
            run += char
            continue
        if run:
            tokens.append(run)
            run = ''
        tokens.append(char)
    # Fix: a trailing numeric run (e.g. the '2' in '(1)/2') used to be dropped.
    if run:
        tokens.append(run)
    return solve_equation(tokens)

def ans_decimal_exception(word):
    '''
    Evaluate a mixed-number answer: an integer part followed by a
    parenthesised fraction, e.g. ``'1((1)/(2))'`` -> ``1+((1)/(2))``.

    A '+' is inserted before the first '(' and the result is evaluated with
    ``ans_num_joint`` (the result is returned as a string).

    Raises ValueError when there is no '(' preceded by at least one character.
    Previously such input was silently mangled (``'5)'`` became ``'5+)'``).
    '''
    text = str(word)
    cut = text.find('(')
    if cut <= 0:
        raise ValueError('expected "<integer>(<fraction>)", got %r' % (text,))
    return ans_num_joint(text[:cut] + '+' + text[cut:])

def ans_process(word):
    '''
    Convert a Math23K answer field to a number.

    - anything float() accepts          -> that float ('3.5' -> 3.5)
    - contains '%'                       -> drop last char, divide by 100 ('50%' -> 0.5)
    - '(' ... ')'                         -> evaluated by ``ans_num_joint``
    - ends with ')' but not starting '('  -> evaluated by ``ans_decimal_exception``
    - anything else (including '')       -> -inf

    The two expression branches return the numeric *string* produced by the
    evaluator; callers wrap the result in float().
    '''
    try:
        return float(word)
    except (TypeError, ValueError):
        pass
    text = str(word)
    if not text:
        # Previously raised IndexError on text[0].
        return -float('inf')
    if '%' in text:
        return float(text[:-1]) / 100
    if text[0] == '(' and text[-1] == ')':
        return ans_num_joint(text)
    if text[0] != '(' and text[-1] == ')':
        return ans_decimal_exception(text)
    return -float('inf')

#### Normalizing equations

In [6]:
def norm_equ(equ_list):
    '''
    Canonicalise commutative runs of template variables in an infix token list
    (the result is meant to be converted to postfix afterwards).

    A run of two or three tokens containing 'temp' joined by the same operator,
    '+' or '*', is rewritten with its operands sorted. For example,
    ``temp_b + temp_a`` and ``temp_a + temp_b`` map to the same template.
    Three-term runs are tried before two-term runs, and '+' before '*'.

    A run is left as-is when the token directly before or after it could make
    reordering change the value:
      - '+' runs: neighbour in '/', '-', '*', '^'
      - '*' runs: neighbour in '/', '-', '^'
    In that case only the run's first token is copied and scanning resumes at
    the next token. Returns a new list; ``equ_list`` is not modified.

    Fixes over the previous version: the '*' branches repeated the '+'
    conditions and could never fire, and '^' was not a blocking neighbour, so
    ``temp_b + temp_a ^ temp_c`` was wrongly reordered.
    '''
    blockers = {'+': ('/', '-', '*', '^'), '*': ('/', '-', '^')}

    def is_run(start, op, n_terms):
        # True if equ_list[start:...] reads "temp op temp [op temp]".
        end = start + 2 * n_terms - 1
        if end > len(equ_list):
            return False
        operands = equ_list[start:end:2]
        operators = equ_list[start + 1:end:2]
        return all('temp' in tok for tok in operands) and all(tok == op for tok in operators)

    new_equ_list = []
    i = 0
    while i < len(equ_list):
        match = next(((op, n_terms)
                      for n_terms in (3, 2)
                      for op in ('+', '*')
                      if is_run(i, op, n_terms)), None)
        if match is None:
            new_equ_list.append(equ_list[i])
            i += 1
            continue
        op, n_terms = match
        width = 2 * n_terms - 1
        before = equ_list[i - 1] if i > 0 else None
        after = equ_list[i + width] if i + width < len(equ_list) else None
        if before in blockers[op] or after in blockers[op]:
            new_equ_list.append(equ_list[i])
            i += 1
            continue
        ordered = sorted(equ_list[i:i + width:2])
        canonical = [ordered[0]]
        for tok in ordered[1:]:
            canonical += [op, tok]
        new_equ_list += canonical
        i += width
    return new_equ_list

#### Process training and test sets

In [7]:
# Used as insertion-ordered sets: keys are space-joined templates, values are 1.
temp_dict, norm_temp_dict = {}, {}

In [8]:
for elem in math23k_train:
    pid = elem['id']
    # Re-join fractions split by jieba: '(' '1' '/' '5' ')' -> '(1/5)'.
    join_text_list = joint_number(elem['new_split'].split(' '))
    mask_seg_text, num_list, temp_equ_list = mask_num(join_text_list, elem['equation'])
    # e.g. temp_equ_list == ['x', '=', '(', 'temp_b', '-', '1', ')', '*', 'temp_a']
    new_num_list = num_list_processed(num_list)  # '60%' -> 0.6, '(1/4)' -> 0.25
    if '千' in temp_equ_list:
        # Special case kept from the original pipeline: truncate the template
        # at the first '千' token.
        temp_equ_list = temp_equ_list[:temp_equ_list.index('千')]
    norm_temp_equ_list = norm_equ(temp_equ_list)
    post_temp_equ_list = postfix_equation(norm_temp_equ_list)
    if sni_dict[pid]['norm_template'] != '':
        # Prefer the precomputed normalised template when one is available
        # (per the original author, the local normalisation misses one technique).
        post_temp_equ_list = ['x', '='] + sni_dict[pid]['norm_template']
    mapped_equ_list = inverse_temp_to_num(temp_equ_list, new_num_list)
    post_mapped_equ_list = inverse_temp_to_num(post_temp_equ_list, new_num_list)

    # Sanity check: evaluate the postfix template with the real numbers. Only
    # evaluation failures are reported; value mismatches are tolerated. float()
    # is inside the try so a complex result (negative base ** fraction) is
    # reported instead of crashing the loop.
    try:
        check_ans = float(post_solver(post_mapped_equ_list))
    except (ArithmeticError, IndexError, TypeError, ValueError):
        check_ans = -float('inf')  # e.g. unary minus, which post_solver cannot evaluate
        print('----', elem)
        print('----', ' '.join(mapped_equ_list))
        print('----', ' '.join(temp_equ_list))
        print('++++', ' '.join(norm_temp_equ_list))
        print('!!!!', ' '.join(num_list))
        print('----', ' '.join(post_mapped_equ_list))
        print('----', ' '.join(post_temp_equ_list))
        print()

    answer = float(ans_process(elem['ans']))

    elem['target_template'] = temp_equ_list
    temp_dict[' '.join(temp_equ_list)] = 1
    elem['target_norm_post_template'] = post_temp_equ_list
    norm_temp_dict[' '.join(post_temp_equ_list)] = 1
    elem['text'] = ' '.join(mask_seg_text)
    elem['num_list'] = new_num_list
    elem['answer'] = answer

---- {'id': '8883', 'new_split': '计算 ： 1 - ( - ( 1 / 2 ) ) = ．', 'original_text': '计算：1-(-(1/2))=．', 'segmented_text': '计算 ： 1 - ( - (1/2) ) = ．', 'equation': 'x=1-(-(1/2))', 'sni_text': '计算：1-(-(1/2))=．', 'ans': '1((1)/(2))'}
---- x = 1.0 - ( - 0.5 )
---- x = temp_a - ( - temp_b )
++++ x = temp_a - ( - temp_b )
!!!! 1 (1/2)
---- x = 1.0 0.5 - -
---- x = temp_a temp_b - -



In [9]:
for elem in math23k_test:
    pid = elem['id']
    # Re-join fractions split by jieba: '(' '1' '/' '5' ')' -> '(1/5)'.
    join_text_list = joint_number(elem['new_split'].split(' '))
    mask_seg_text, num_list, temp_equ_list = mask_num(join_text_list, elem['equation'])
    # e.g. temp_equ_list == ['x', '=', '(', 'temp_b', '-', '1', ')', '*', 'temp_a']
    new_num_list = num_list_processed(num_list)  # '60%' -> 0.6, '(1/4)' -> 0.25
    if '千' in temp_equ_list:
        # Special case kept from the original pipeline: truncate the template
        # at the first '千' token.
        temp_equ_list = temp_equ_list[:temp_equ_list.index('千')]
    norm_temp_equ_list = norm_equ(temp_equ_list)
    post_temp_equ_list = postfix_equation(norm_temp_equ_list)
    if sni_dict[pid]['norm_template'] != '':
        # Prefer the precomputed normalised template when one is available.
        post_temp_equ_list = ['x', '='] + sni_dict[pid]['norm_template']
    mapped_equ_list = inverse_temp_to_num(temp_equ_list, new_num_list)
    post_mapped_equ_list = inverse_temp_to_num(post_temp_equ_list, new_num_list)

    # Sanity check: evaluate the postfix template with the real numbers. Only
    # evaluation failures are reported; value mismatches are tolerated. float()
    # is inside the try so a complex result is reported instead of crashing.
    try:
        check_ans = float(post_solver(post_mapped_equ_list))
    except (ArithmeticError, IndexError, TypeError, ValueError):
        check_ans = -float('inf')  # e.g. unary minus, which post_solver cannot evaluate
        print('----', elem)
        print('----', ' '.join(mapped_equ_list))
        print('----', ' '.join(temp_equ_list))
        print('++++', ' '.join(norm_temp_equ_list))
        print('!!!!', ' '.join(num_list))
        print('----', ' '.join(post_mapped_equ_list))
        print('----', ' '.join(post_temp_equ_list))
        print()

    answer = float(ans_process(elem['ans']))

    elem['target_template'] = temp_equ_list
    temp_dict[' '.join(temp_equ_list)] = 1
    elem['target_norm_post_template'] = post_temp_equ_list
    norm_temp_dict[' '.join(post_temp_equ_list)] = 1
    elem['text'] = ' '.join(mask_seg_text)
    elem['num_list'] = new_num_list
    elem['answer'] = answer

In [10]:
# Number of distinct raw templates vs. distinct normalised postfix templates.
len(temp_dict), len(norm_temp_dict)

(3529, 3125)

In [11]:
import random
# Deterministic split: shuffle a copy of the training problems with seed 10 and
# hold out the first 1000 as the validation set. The test set is copied as-is.
random.seed(10)
train_shuffle = list(math23k_train)
random.shuffle(train_shuffle)
valid_set, train_set = train_shuffle[:1000], train_shuffle[1000:]
test_set = list(math23k_test)

In [12]:
# Persist the three processed splits.
for split, path in ((train_set, "./data/train23k_processed.json"),
                    (valid_set, "./data/valid23k_processed.json"),
                    (test_set, "./data/test23k_processed.json")):
    write_data_json(split, path)

In [13]:
# Training problems where len(num_list) differs from the number of distinct
# temp_* tokens used by the normalised postfix template.
cc_l = [
    problem
    for problem in math23k_train
    if len(problem['num_list'])
    != len({tok for tok in problem['target_norm_post_template'] if 'temp' in tok})
]
cc = len(cc_l)
print(cc)

1519


In [14]:
# Split sizes: (train, validation, test).
len(train_set), len(valid_set), len(test_set)

(21162, 1000, 1000)

In [15]:
# Problem ids of each split, used by the overlap checks.
train_ids = [elem['id'] for elem in train_set]
valid_ids = [elem['id'] for elem in valid_set]
# Fix: this loop used to iterate over `test_ids` itself (still empty), so
# test_ids stayed empty and every overlap check against it passed vacuously.
test_ids = [elem['id'] for elem in test_set]

In [16]:
# Training and validation ids must not overlap; expected output: set().
set(train_ids) & set(valid_ids)

set()

In [17]:
# Validation and test ids must not overlap; expected output: set().
# NOTE(review): this is only a real check if test_ids was filled from test_set.
set(valid_ids) & set(test_ids)

set()

In [18]:
# Training and test ids must not overlap; expected output: set().
# NOTE(review): this is only a real check if test_ids was filled from test_set.
set(train_ids) & set(test_ids)

set()