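# False Math Dataset

Generates arithmetic "equalities" whose two sides sum to different numbers (e.g. `4+3+5 = 1+2+7`), writes them to train/validation text files, and encodes those files into `MathDataset` files.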

In [4]:
import numpy as np
from math_generator import MathGenerator
from math_vocab import MathVocab
from math_dataset import MathDataset
from utils import make_dir

# --- Files -----------------------------------------------------------------
# Saved vocabulary; later cells load it through MathVocab.restore().
vocab_path = "data/math/math_vocab.data"

# Generated sentences as plain text: (train, valid).
tr_text_path, va_text_path = (
    "texts/math/false_math_train.txt",
    "texts/math/false_math_valid.txt",
)

# Encoded datasets built from the text files above: (train, valid).
tr_dataset_path, va_dataset_path = (
    "data/math/false_math_dataset_train.data",
    "data/math/false_math_dataset_valid.data",
)

# --- Parameters ------------------------------------------------------------
# Forwarded to MathDataset.build(max_len=...).
max_len = 32
# Constructor arguments of the sentence generator.
gen_min_sum, gen_max_sum = 10, 20
# Second argument of generate(), used once per split.
gen_num = 10000

## Generator

In [20]:
class FalseMathGenerator(MathGenerator):
    """Generator of deliberately false '=' sentences.

    The two sides of every sentence are expansions of two *different*
    numbers, so the stated equality never holds (e.g. ``4+3+5 = 1+2+7``).
    Only the '=' sign is implemented.
    """

    def _get_rand_nums(self, sign):
        """Draw the two distinct totals for the sides of one sentence.

        Args:
            sign: relation sign of the sentence; only '=' is supported.

        Returns:
            Tuple ``(n1, n2)`` with ``n1 != n2``; each value comes from
            ``np.random.randint(self._first, self._last)``, i.e. the upper
            bound is excluded.

        Raises:
            NotImplementedError: if ``sign`` is not '='.
            ValueError: if ``[self._first, self._last)`` contains fewer than
                two integers, so two distinct values cannot be drawn.
        """
        if sign != '=':
            raise NotImplementedError(
                "FalseMathGenerator supports only '=', got %r" % (sign,))

        # With a single candidate value the rejection loop below could never
        # terminate; fail loudly instead of hanging the kernel.
        if self._last - self._first < 2:
            raise ValueError(
                "need at least two integers in [%r, %r) to draw distinct numbers"
                % (self._first, self._last))

        # Rejection sampling: redraw both numbers until they differ.
        while True:
            n1 = np.random.randint(self._first, self._last)
            n2 = np.random.randint(self._first, self._last)
            if n2 != n1:
                return n1, n2

    def _generate_sent(self, sign):
        """Build one sentence whose left and right sides have different totals.

        Args:
            sign: relation sign placed between the two sides; only '=' is
                supported (see ``_get_rand_nums``).

        Returns:
            Whatever the inherited ``self._to_sent(exp1, sign, exp2)`` returns
            for the two expansions.

        Raises:
            NotImplementedError: propagated from ``_get_rand_nums`` for any
                sign other than '='.
        """
        num1, num2 = self._get_rand_nums(sign)
        # Call order is kept as-is: the inherited helpers are opaque here and
        # may consume random numbers.
        exp1 = self._to_sum_list(num1)
        exp2 = self._to_sum_list(num2)
        exp1 = self._sort_exp(exp1)
        exp2 = self._sort_exp(exp2)
        if sign == '=':
            # Sanity checks: each side still adds up to its own number, and
            # the two numbers differ, which is what makes the equality false.
            assert num1 == np.sum(exp1)
            assert num2 == np.sum(exp2)
            assert num2 != num1
        return self._to_sent(exp1, sign, exp2)

In [21]:
# Produce one text file of generated '=' sentences per split (train, then valid).
for path in (tr_text_path, va_text_path):
    generator = FalseMathGenerator(gen_min_sum, gen_max_sum)
    text = '\n'.join(generator.generate('=', gen_num))

    # Echo the target path followed by the first 100 characters as a preview.
    print('\n%s:\n'%path)
    preview = text[:100]
    print(preview + " ...")

    # make_dir is given the file path itself; presumably it creates the
    # parent folder - confirm in utils.
    make_dir(path)
    with open(path, 'w') as out_file:
        out_file.write(text)


texts/math/false_math_train.txt:

11 = 3+16
6+2+9+1+1 = 6+5
4+3+5 = 1+2+7
2+2+7+1 = 4+14+1
6+4+2+1 = 4+2+2+4+7
5+2+2+1 = 1+3+1+5+5
6+4 ...

texts/math/false_math_valid.txt:

2+5+2+1 = 1+13+2+1
3+10 = 2+7+1+2
4+7+1 = 12+2+5
2+9 = 1+3+3+5+2+3
3+10 = 3+16
16+1 = 6+2+2
6+3+3 =  ...


## False Dataset

In [22]:
voc = MathVocab()
voc.restore(vocab_path)

dataset_paths = [tr_dataset_path, va_dataset_path]

# Building may keep fewer rows than were generated (the recorded validation
# run kept 9999 of 10000). Require that more than 99% of gen_num survive;
# integer arithmetic gives exactly 9900 for gen_num = 10000.
min_rows = gen_num * 99 // 100

# Encode each text file with the shared vocabulary and store it next to its
# split: (train text -> train dataset), (valid text -> valid dataset).
for text_path, dataset_path in zip([tr_text_path, va_text_path], dataset_paths):
    dataset = MathDataset()
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset_path)
    print(dataset)
    assert dataset.shape[0] > min_rows, \
        "%s: only %d rows, expected more than %d" % (dataset_path, dataset.shape[0], min_rows)
    # Row width is max_len + 2; presumably the two extra positions hold the
    # <BOS>/<EOS> markers - confirm in MathDataset.
    assert dataset.shape[1] == max_len+2, \
        "%s: row width %d, expected %d" % (dataset_path, dataset.shape[1], max_len+2)

data/math/false_math_dataset_train.data
MathDataset:
  path: data/math/false_math_dataset_train.data
  shape: [10000, 34]
  data_limit: None
data/math/false_math_dataset_valid.data
MathDataset:
  path: data/math/false_math_dataset_valid.data
  shape: [9999, 34]
  data_limit: None


## Batch

In [23]:
# Reload the vocabulary and the training dataset from disk.
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathDataset()
dataset.restore(tr_dataset_path)

# Decode one batch of 15 rows back to text, showing each padding token as
# '_' so the fixed row width stays visible.
batch = dataset.get_next_batch(15)
for sent in batch:
    words = voc.to_words(sent)
    restored = "".join(words).replace('<PAD>', '_')
    print(restored)

<BOS>6+13 = 2+7+2+1<EOS>__________________
<BOS>11+1+1+1 = 1+18<EOS>_________________
<BOS>5+3+2 = 8+4+1+2<EOS>_________________
<BOS>2+1+2+1+3+1 = 19<EOS>________________
<BOS>6+1+1+5 = 12+1+3<EOS>________________
<BOS>10+1 = 1+3+1+1+7<EOS>________________
<BOS>18 = 4+1+3+6<EOS>____________________
<BOS>1+7+2+1+1+1 = 7+2+7+3<EOS>___________
<BOS>5+10+1 = 13<EOS>_____________________
<BOS>10+9 = 1+4+6+1+3<EOS>________________
<BOS>7+3+4 = 6+4<EOS>_____________________
<BOS>2+3+1+4+4+2 = 1+1+1+13+2<EOS>________
<BOS>1+6+1+1+10 = 6+5<EOS>________________
<BOS>3+4+5 = 2+2+2+3+9<EOS>_______________
<BOS>2+1+3+2+6+3 = 3+5+6+4<EOS>___________
