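# Random Math Dataset

This notebook builds a dataset of random strings drawn from the characters `0-9`, space, `+` and `=`. It generates train/valid text files, encodes them into `MathDataset` files with the saved `MathVocab`, and decodes one batch as a sanity check.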

In [7]:
import numpy as np
from math_generator import MathGenerator
from math_vocab import MathVocab
from math_dataset import MathDataset
from utils import make_dir

vocab_path      = "data/math/math_vocab.data"

tr_text_path    = "texts/math/random_math_train.txt"
va_text_path    = "texts/math/random_math_valid.txt"
tr_dataset_path = "data/math/random_math_dataset_train.data"
va_dataset_path = "data/math/random_math_dataset_valid.data"

# Passed to MathDataset.build; the encoded datasets are asserted below to have
# max_len + 2 columns (room for the <BOS>/<EOS> markers seen in the decoded batch).
max_len     = 32
# Forwarded to the MathGenerator constructor; RandomMathGenerator._generate_sent
# does not read them itself (meaning defined by MathGenerator -- TODO confirm).
gen_min_sum = 10
gen_max_sum = 20
# Number of sentences requested per split (train and valid).
gen_num     = 10000
# np.random.choice samples list entries uniformly, so the 8 repeated '+' entries
# make '+' the most likely character (8/20 = 0.4 per draw).
vocab_list  = list("0123456789 ++++++++=")

# Seed the global NumPy RNG that RandomMathGenerator draws from, so that
# Restart & Run All regenerates exactly the same text files and datasets.
seed        = 0
np.random.seed(seed)

## Generator

In [14]:
class RandomMathGenerator(MathGenerator):
    """MathGenerator variant that emits unstructured random character strings.

    Each sentence is a uniform random draw (with replacement) from the
    module-level ``vocab_list``; no arithmetic structure is enforced.
    """

    def _generate_sent(self, _):
        # The positional argument is ignored (presumably an index/context
        # supplied by MathGenerator.generate -- TODO confirm).
        # NOTE(review): randint's upper bound is exclusive, so lengths fall in
        # [3, max_len - 1] and a sentence never fills all max_len slots.
        length = np.random.randint(3, max_len)
        chars = np.random.choice(vocab_list, size=length)
        return "".join(chars)

In [15]:
# Write one text file per split: gen_num random sentences, one per line.
for out_path in (tr_text_path, va_text_path):
    generator = RandomMathGenerator(gen_min_sum, gen_max_sum)
    lines = generator.generate('=', gen_num)
    joined = "\n".join(lines)

    # Preview the head of each file.
    print("\n{}:\n".format(out_path))
    print(joined[:100] + " ...")

    make_dir(out_path)
    with open(out_path, 'w') as fh:
        fh.write(joined)


texts/math/random_math_train.txt:

4+75++8 +0
73++7+=5+=+++3=3+=+46++3+7
+15+38+1++2+ 151+5+5+8+
+76++ 5 3+=03+++
1+4+
1+48+=+7+620++2+ ...

texts/math/random_math_valid.txt:

++08+1+98+++7=9 
0970+7284++40+3 +++2414+=+++=+6
72+ 3+6+0849+4++19+++80+=9++
+1++8416+31+7=+++8+++
 ...


## Random Dataset

In [16]:
voc = MathVocab()
voc.restore(vocab_path)

dataset_paths = [tr_dataset_path, va_dataset_path]
text_paths    = [tr_text_path, va_text_path]

# Require more than 99% of the requested sentences to survive encoding
# (rows are presumably dropped by MathDataset.build -- TODO confirm why).
# Derived from gen_num rather than hard-coded as 9900 so it tracks the config.
min_rows = 0.99 * gen_num

# Encode each split's text file into a fixed-width MathDataset and save it.
for text_path, dataset_path in zip(text_paths, dataset_paths):
    dataset = MathDataset()
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset_path)
    print(dataset)
    assert dataset.shape[0] > min_rows, \
        "%s: %s rows, expected more than %s" % (dataset_path, dataset.shape[0], min_rows)
    # max_len tokens plus <BOS> and <EOS>
    assert dataset.shape[1] == max_len + 2, \
        "%s: width %s, expected %s" % (dataset_path, dataset.shape[1], max_len + 2)
    assert dataset.shape[1] == max_len+2

data/math/random_math_dataset_train.data
MathDataset:
  path: data/math/random_math_dataset_train.data
  shape: [10000, 34]
  data_limit: None
data/math/random_math_dataset_valid.data
MathDataset:
  path: data/math/random_math_dataset_valid.data
  shape: [10000, 34]
  data_limit: None


## Batch

In [17]:
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathDataset()
dataset.restore(tr_dataset_path)

# Decode a small training batch back to text; padding is shown as '_' so the
# <BOS> ... <EOS> layout of each row is easy to eyeball.
batch = dataset.get_next_batch(15)
for row in batch:
    decoded = "".join(voc.to_words(row)).replace('<PAD>', '_')
    print(decoded)

<BOS>+15+78296+++  ++40++ 7 2<EOS>________
<BOS>1047<EOS>____________________________
<BOS>307++627=04++0+5+04<EOS>_____________
<BOS>0+1 +31+8<EOS>_______________________
<BOS>++7++<EOS>___________________________
<BOS>+37+ +15 +<EOS>______________________
<BOS>50++81+342562+29443+<EOS>____________
<BOS>==0<EOS>_____________________________
<BOS>+ 7+<EOS>____________________________
<BOS>8+=+++1+++44+0+8+7<EOS>______________
<BOS>29+13+865+4+6+ 4=+++6544100=6+<EOS>__
<BOS>6+5++5+=8++8451=5+6+3<EOS>___________
<BOS>+82 0+=1 0+77 + +1 9376++ ++65+<EOS>_
<BOS>5++099+6++09++5+<EOS>________________
<BOS>+2= 3=+739+0=+64+69<EOS>_____________
