# Dataset for the Arithmetic Equation Simplification (AES) Task

In [1]:
#!/usr/bin/env python
# -*- coding:utf-8 -*-

__author__ = 'Shining'
__email__ = 'mrshininnnnn@gmail.com'

In [2]:
# dependency
# public
import os
import numpy as np
from tqdm import tqdm
# private
from utils import save_txt

In [3]:
# Class to generate the dataset for the Arithmetic Equation Simplification (AES) task
class ArithmeticEquationSimplification():
    """Random generator of (input, target) pairs for the AES task.

    A target ``y`` is a space-separated equation such as
    ``'- 25 / 28 * 28 + 27 == 2'``.  The matching input ``x`` is the same
    equation in which a random subset of the numbers has been replaced by a
    parenthesised two-operand expression with the same value, e.g.
    ``'( 49 - 47 )'`` for ``2``.

    Args:
        operators: operator tokens (e.g. ``['+', '-', '*', '/']``).  They are
            spliced into strings that are passed to ``eval``, so they must
            come from trusted code, never from external input.
        num_size: controls the operand range.  Positive operands are
            ``2 .. num_size + 1`` and negative operands are
            ``-num_size .. -2``.
    """
    def __init__(self, operators, num_size):
        super().__init__()
        self.operators = operators
        self.pos_digits = np.arange(2, num_size+2).tolist()
        self.neg_digits = np.arange(-num_size, -1).tolist()
        self.digits = self.pos_digits + self.neg_digits

    @staticmethod
    def _value_key(value):
        """Return the dictionary key for an evaluated value, or None.

        Whole-valued floats (the result of an exact division) are normalised
        to their integer spelling so that ``4 / 2`` maps to ``'2'`` rather
        than ``'2.0'``.  Floats with a fractional part (and inf/nan) yield
        ``None``.
        """
        if isinstance(value, float):
            if not value.is_integer():
                return None
            return str(int(value))
        return str(value)

    def gen_base_dict(self):
        """Return a fresh ``{str(number): []}`` dict over the positive operands."""
        return {str(i): [] for i in self.pos_digits}

    def expand_base_dict(self):
        """Fill ``self.base_dict`` with every ``( a o b )`` equal to a key.

        ``a`` ranges over all operands, ``b`` over the positive operands and
        ``o`` over the operators.  Expressions are stored space-separated,
        with a leading minus split off as its own token (``'( - 2 + 4 )'``).
        ``self.base_dict`` must already exist (``generate`` creates it).
        """
        for a in self.digits:
            for o in self.operators:
                for b in self.pos_digits:
                    tokens = [str(a), o, str(b)]
                    try:
                        value = eval(''.join(tokens))
                    except ArithmeticError:
                        # e.g. division by zero: no usable value
                        continue
                    key = self._value_key(value)
                    if key is None or key not in self.base_dict:
                        continue
                    tokens[0] = tokens[0].replace('-', '- ')
                    self.base_dict[key].append(
                        '( {} )'.format(' '.join(tokens)))

    def gen_operation(self, seq_len):
        """Sample one operation as a token list ``[a, o, b, o, b, ...]``.

        Args:
            seq_len: number of operands; must be at least 1.

        Returns:
            A list of ``2 * seq_len - 1`` tokens.  The first operand may be
            negative, all following operands are positive.

        Raises:
            ValueError: if ``seq_len`` is smaller than 1.
        """
        if seq_len < 1:
            raise ValueError(
                'seq_len must be at least 1, got {}'.format(seq_len))
        operation = [str(np.random.choice(self.digits))]
        for _ in range(seq_len - 1):
            o = np.random.choice(self.operators)
            b = np.random.choice(self.pos_digits)
            operation += [o, str(b)]
        return operation

    def gen_operation_list(self, seq_len, data_size):
        """Collect ``data_size`` unique operations into ``self.value_dict``.

        An operation is kept only if it evaluates to a whole number that is
        a key of ``self.value_dict`` (this keeps the vocabulary fixed).
        ``self.value_dict`` must already exist (``generate`` creates it).

        Args:
            seq_len: number of operands per operation; must be at least 1.
            data_size: number of operations to collect.

        Raises:
            ValueError: if ``seq_len`` is smaller than 1.
            RuntimeError: if every possible operation has been tried before
                ``data_size`` valid ones were found.
        """
        if seq_len < 1:
            raise ValueError(
                'seq_len must be at least 1, got {}'.format(seq_len))
        # upper bound on the number of distinct operations that can be drawn
        max_distinct = len(self.digits) * (
            len(self.operators) * len(self.pos_digits)) ** (seq_len - 1)
        # every operation tried so far, accepted or rejected
        operations_pool = set()
        for _ in tqdm(range(data_size)):
            while True:
                if len(operations_pool) >= max_distinct:
                    raise RuntimeError(
                        'operation space exhausted after {} candidates; '
                        'cannot collect {} samples'.format(
                            len(operations_pool), data_size))
                operation = self.gen_operation(seq_len)
                expression = ''.join(operation)
                # to avoid duplicates
                if expression in operations_pool:
                    continue
                operations_pool.add(expression)
                try:
                    value = eval(expression)
                except ArithmeticError:
                    # e.g. division by zero: draw another operation
                    continue
                # float to int to string; fractional results are rejected
                key = self._value_key(value)
                # to keep vocab size
                if key is not None and key in self.value_dict:
                    self.value_dict[key].append(operation)
                    break

    def gen_equation_list(self):
        """Render ``self.value_dict`` as ``'<operation> == <value>'`` strings.

        Tokens are space-separated and a leading minus on the first operand
        becomes its own token.
        """
        ys = []
        for v, operations in self.value_dict.items():
            for operation in operations:
                tokens = operation[0].replace('-', '- ').split()
                tokens += operation[1:]
                tokens += ['==', v]
                ys.append(' '.join(tokens))
        return ys

    def replace_numbers(self, ys):
        """Build the input sequences by expanding random numbers in ``ys``.

        For every equation a random number of its number tokens (possibly
        none) is replaced by a random equivalent expression taken from
        ``self.base_dict``.  Numbers without any known expansion are left
        untouched.

        Args:
            ys: equations as produced by ``gen_equation_list``.

        Returns:
            A list of strings aligned with ``ys``.
        """
        xs = []
        for y in ys:
            y = y.split()
            # only numbers that have at least one expansion can be replaced
            num_idx = [
                i for i, token in enumerate(y)
                if token.isdigit() and self.base_dict.get(token)]
            if num_idx:
                num_to_replace = np.random.choice(range(len(num_idx)+1))
                idx_to_replace = np.random.choice(
                    num_idx, num_to_replace, False)
                for i in idx_to_replace:
                    y[i] = np.random.choice(self.base_dict[y[i]])
            xs.append(' '.join(y))
        return xs

    def generate(self, seq_len, data_size):
        """Generate a dataset of ``data_size`` equation pairs.

        ``self.base_dict`` and ``self.value_dict`` are rebuilt on every call.

        Args:
            seq_len: number of operands on the left-hand side of a target.
            data_size: number of pairs to generate.

        Returns:
            ``(xs, ys)``: input sequences and output sequences, aligned by
            index.
        """
        self.base_dict = self.gen_base_dict()
        self.value_dict = self.gen_base_dict()
        self.expand_base_dict()
        self.gen_operation_list(
            seq_len=seq_len,
            data_size=data_size)
        ys = self.gen_equation_list()
        xs = self.replace_numbers(ys)

        return xs, ys

In [4]:
# data parameters
N = 100  # num_size of the generator: operands span 2..N+1 and -N..-2
L = 5  # generate() is called with seq_len = L-1, i.e. 4 operands per equation
D = 10000  # number of equations to generate
operators = ['+', '-', '*', '/']  # operator tokens sampled between operands

In [5]:
# Build the generator over the configured operators and operand range, then
# sample D unique equations with L-1 operands on the left-hand side.
aes = ArithmeticEquationSimplification(operators=operators, num_size=N)
xs, ys = aes.generate(seq_len=L-1, data_size=D)

100%|██████████| 10000/10000 [00:20<00:00, 476.52it/s]


In [6]:
# number of generated input sequences
len(xs)

10000

In [7]:
# preview the first input sequences (numbers may be expanded into parentheses)
xs[:15]

['82 + ( - 5 + 45 ) - 26 - 94 == ( - 40 + 42 )',
 '- ( 12 + 13 ) / ( - 9 + 37 ) * ( 46 - 18 ) + ( 3 + 24 ) == ( 49 - 47 )',
 '53 + 89 - 96 - 44 == 2',
 '3 * 10 - 68 + 40 == 2',
 '29 + 48 - 44 - 31 == 2',
 '( - 28 + 51 ) + ( 35 + 53 ) / 4 - ( 37 + 6 ) == ( 75 - 73 )',
 '- ( - 31 + 58 ) + ( - 6 + 74 ) - 35 - 4 == ( - 27 + 29 )',
 '40 - 90 - ( - 42 + 77 ) + 87 == ( - 62 + 64 )',
 '- ( 51 - 40 ) + ( - 7 + 45 ) - 4 - 21 == 2',
 '54 - 10 + 23 - ( - 33 + 98 ) == 2',
 '- ( 69 - 57 ) - ( - 12 + 32 ) + ( 71 + 4 ) - ( 16 + 25 ) == ( - 25 + 27 )',
 '- 72 / ( 17 + 62 ) * 79 + ( 77 - 3 ) == ( 13 - 11 )',
 '( - 10 + 82 ) - ( - 57 + 98 ) + 44 - 73 == 2',
 '( 63 + 12 ) + ( 47 - 22 ) - ( - 10 + 38 ) - ( 11 + 59 ) == ( 11 - 9 )',
 '- ( - 37 + 64 ) / 9 - ( - 76 + 100 ) + 29 == 2']

In [8]:
# number of generated output sequences
len(ys)

10000

In [9]:
# preview the first output sequences (the simplified equations)
ys[:15]

['82 + 40 - 26 - 94 == 2',
 '- 25 / 28 * 28 + 27 == 2',
 '53 + 89 - 96 - 44 == 2',
 '3 * 10 - 68 + 40 == 2',
 '29 + 48 - 44 - 31 == 2',
 '23 + 88 / 4 - 43 == 2',
 '- 27 + 68 - 35 - 4 == 2',
 '40 - 90 - 35 + 87 == 2',
 '- 11 + 38 - 4 - 21 == 2',
 '54 - 10 + 23 - 65 == 2',
 '- 12 - 20 + 75 - 41 == 2',
 '- 72 / 79 * 79 + 74 == 2',
 '72 - 41 + 44 - 73 == 2',
 '75 + 25 - 28 - 70 == 2',
 '- 27 / 9 - 24 + 29 == 2']

In [10]:
# fraction of pairs whose input is already identical to its target
# (no number was replaced)
sum(x == y for x, y in zip(xs, ys)) / len(xs)

0.1732

In [11]:
# train val test split
# Pair every input with its target so one shuffle keeps them aligned.
dataset = np.array(list(zip(xs, ys)))
data_size = len(dataset)
indices = np.random.permutation(data_size)
# 70% train, 15% validation; the test split takes whatever remains
train_size = int(0.7*data_size)
val_size = int(0.15*data_size)
test_size = data_size - train_size - val_size
# cut the shuffled indices at the two split boundaries
train_idxes, val_idxes, test_idxes = np.split(
    indices, [train_size, train_size+val_size])
trainset, valset, testset = (
    dataset[idxes] for idxes in (train_idxes, val_idxes, test_idxes))
print('train size', train_size, trainset.shape)
print('val size', val_size, valset.shape)
print('test size', test_size, testset.shape)

train size 7000 (7000, 2)
val size 1500 (1500, 2)
test size 1500 (1500, 2)


In [12]:
# to save dataset
# The output directory encodes the data parameters: aes/<N>N/<L>L/<D>D
outdir = 'aes/'
outdir = os.path.join(
    outdir, 
    '{}N'.format(N), 
    '{}L'.format(L), 
    '{}D'.format(D))
# exist_ok avoids the check-then-create race and keeps re-runs of this cell safe
os.makedirs(outdir, exist_ok=True)
outdir

'aes/100N/5L/10000D'

In [13]:
# Write each split as two files: column 0 holds the inputs (x) and column 1
# the targets (y).  save_txt comes from the project's utils module; it is
# presumably a one-sequence-per-line writer -- confirm against utils.
split_files = [
    ('train_x.txt', trainset[:, 0]),
    ('train_y.txt', trainset[:, 1]),
    ('val_x.txt', valset[:, 0]),
    ('val_y.txt', valset[:, 1]),
    ('test_x.txt', testset[:, 0]),
    ('test_y.txt', testset[:, 1]),
]
for file_name, sequences in split_files:
    save_txt(os.path.join(outdir, file_name), sequences)