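# Math False Dataset

Generates deliberately *false* arithmetic equations — two sum expressions whose totals differ (e.g. `3+7 = 3+7+3+2`) — writes train/validation text files, and encodes them into fixed-length datasets using the existing math vocabulary.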

In [24]:
import numpy as np
from math_generator import MathGenerator
from math_vocab import MathVocab
from math_dataset import MathDataset
from utils import make_dir

vocab_path      = "data/math/math_vocab.data"

tr_text_path    = "texts/math/math_false_train.txt"
va_text_path    = "texts/math/math_false_valid.txt"
tr_dataset_path = "data/math/math_false_dataset_train.data"
va_dataset_path = "data/math/math_false_dataset_valid.data"

max_len     = 32     # passed to MathDataset.build; dataset rows are max_len + 2 wide (asserted later)
gen_min_sum = 10     # bounds of the equation totals handed to FalseMathGenerator
gen_max_sum = 20     # (exact bound semantics are defined by MathGenerator — TODO confirm inclusive/exclusive)
gen_num     = 10000  # sentences requested per split (train / valid)
random_seed = 42     # fixes the generated data so Restart & Run All is reproducible

# FalseMathGenerator draws its totals from NumPy's global RNG (np.random.randint),
# so seeding it here makes the generated train/valid text files repeatable.
np.random.seed(random_seed)

## Generator

In [25]:
class FalseMathGenerator(MathGenerator):
    """Generator of *false* equations: two sum expressions with different totals.

    Overrides ``_get_rand_nums`` so the two totals are always distinct, and
    ``_generate_sent`` to build (and sanity-check) a sentence from them using
    the ``_to_sum_list`` / ``_sort_exp`` / ``_to_sent`` helpers inherited from
    MathGenerator. Only the '=' sign is supported.
    """

    def _get_rand_nums(self, sign):
        """Draw two distinct totals from ``[self._first, self._last)``.

        :param sign: relation sign; only '=' is supported.
        :returns: tuple ``(n1, n2)`` with ``n1 != n2``.
        :raises NotImplementedError: for any sign other than '='.
        :raises ValueError: if the range holds fewer than two integers; two
            distinct totals cannot exist then, and the rejection loop below
            would never terminate.
        """
        if sign != '=':
            raise NotImplementedError(
                "FalseMathGenerator only supports '=', got %r" % (sign,))
        # np.random.randint excludes its upper bound, so the range offers
        # (self._last - self._first) candidate values.
        if self._last - self._first < 2:
            raise ValueError(
                "need at least two candidate totals to build a false equation, "
                "got range [%r, %r)" % (self._first, self._last))
        while True:
            n1 = np.random.randint(self._first, self._last)
            n2 = np.random.randint(self._first, self._last)
            # Reject equal totals: the sentence must be a false equation.
            if n2 != n1:
                return n1, n2

    def _generate_sent(self, sign):
        """Build one false-equation sentence for ``sign``.

        Each total is split into summands by ``_to_sum_list``, passed through
        ``_sort_exp`` and rendered with ``_to_sent`` (all defined on
        MathGenerator).

        :raises NotImplementedError: for any sign other than '=' (raised by
            ``_get_rand_nums``).
        """
        num1, num2 = self._get_rand_nums(sign)
        # Call order is kept as-is so the RNG stream matches the original.
        exp1 = self._to_sum_list(num1)
        exp2 = self._to_sum_list(num2)
        exp1 = self._sort_exp(exp1)
        exp2 = self._sort_exp(exp2)
        if sign == '=':
            # Sanity checks: each expression reproduces its total, and the
            # totals differ, i.e. the equation is indeed false.
            assert num1 == np.sum(exp1)
            assert num2 == np.sum(exp2)
            assert num2 != num1
        return self._to_sent(exp1, sign, exp2)

In [26]:
# For each split, generate gen_num false equations with a fresh generator,
# echo a short preview, then write them one per line to the split's text file.
for path in (tr_text_path, va_text_path):
    gen = FalseMathGenerator(gen_min_sum, gen_max_sum)
    sents = gen.generate('=', gen_num)
    text = '\n'.join(sents)

    preview = text[:100] + " ..."
    print('\n%s:\n'%path)
    print(preview)

    make_dir(path)
    with open(path, mode='w') as out_file:
        out_file.write(text)


texts/math/math_false_train.txt:

3+7 = 3+7+3+2
2+4+2+5 = 7+2+1+2+2+1
1+8+1 = 9+2+1
6+4+6+1 = 5+3+3+3+2
8+6+5 = 12+6
1+12+1 = 4+4+4+1
 ...

texts/math/math_false_valid.txt:

1+3+12 = 1+5+8+3
4+4+6+1 = 4+8+1+2+1
1+1+4+5 = 1+2+4+2+3
6+6+1+1 = 10+5
5+6+1+4 = 1+10+2
13+1 = 1+1+ ...


## False Dataset

In [27]:
voc = MathVocab()
voc.restore(vocab_path)

dataset_paths = [tr_dataset_path, va_dataset_path]

# Building may drop a few sentences (presumably ones that do not fit in
# max_len tokens — TODO confirm in MathDataset.build), but at least 99% of
# the gen_num generated sentences are expected to survive. Derived from
# gen_num (integer arithmetic, 9900 for 10000) instead of a hard-coded 9900,
# so the check follows the config cell.
min_rows = gen_num * 99 // 100

for text_path, dataset_path in zip([tr_text_path, va_text_path], dataset_paths):
    dataset = MathDataset()
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset_path)
    print(dataset)
    assert dataset.shape[0] > min_rows, \
        "%s: %s rows, expected more than %s" % (dataset_path, dataset.shape[0], min_rows)
    # Rows are max_len tokens plus two extra slots (<BOS>/<EOS> in the sample output below).
    assert dataset.shape[1] == max_len + 2

data/math/math_false_dataset_train.data
MathDataset:
  path: data/math/math_false_dataset_train.data
  shape: [10000, 34]
  data_limit: None
data/math/math_false_dataset_valid.data
MathDataset:
  path: data/math/math_false_dataset_valid.data
  shape: [10000, 34]
  data_limit: None


## Batch

In [28]:
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathDataset()
dataset.restore(tr_dataset_path)

# Decode a handful of training rows back to text; '<PAD>' tokens are shown
# as '_' so each row's fixed width stays visible.
batch = dataset.get_next_batch(15)
for sent in batch:
    words = voc.to_words(sent)
    restored = "".join(words).replace('<PAD>', '_')
    print(restored)

<BOS>12+1+3 = 3+7<EOS>____________________
<BOS>1+10+1 = 1+3+1+2+1+6<EOS>____________
<BOS>8+7+3 = 4+5+6<EOS>___________________
<BOS>12+1+1 = 12+2+4<EOS>_________________
<BOS>12+1 = 13+6<EOS>_____________________
<BOS>1+8+3+1 = 5+4+5<EOS>_________________
<BOS>7+5 = 6+1+11<EOS>____________________
<BOS>4+1+1+1+2+2 = 8+9<EOS>_______________
<BOS>9+1+1 = 1+10+2<EOS>__________________
<BOS>2+15 = 1+6+2+2<EOS>__________________
<BOS>8+4 = 1+13+1+1<EOS>__________________
<BOS>3+1+1+1+5+1 = 1+3+1+3+7<EOS>_________
<BOS>19 = 6+3+1+3+4<EOS>__________________
<BOS>2+4+4 = 1+9+1<EOS>___________________
<BOS>1+1+14 = 11+1+3+3<EOS>_______________
