# Math Random Dataset

In [28]:
import numpy as np
from math_generator import MathGenerator
from math_vocab import MathVocab
from math_dataset import MathDataset
from utils import make_dir

# Vocabulary file; loaded further down with MathVocab.restore().
vocab_path = "data/math/math_vocab.data"

# Generated sentence files (written newline-separated) and the encoded
# datasets built from them, one pair each for the train and valid splits.
tr_text_path = "texts/math/math_random_train.txt"
va_text_path = "texts/math/math_random_valid.txt"
tr_dataset_path = "data/math/math_random_dataset_train.data"
va_dataset_path = "data/math/math_random_dataset_valid.data"

# Exclusive upper bound for the random sentence length, and the max_len
# value handed to MathDataset.build().
max_len = 32

# Constructor arguments for the generator; their meaning is defined by
# MathGenerator, which is not part of this notebook.
gen_min_sum = 10
gen_max_sum = 20

# Second argument to gen.generate() -- presumably the number of sentences
# produced per split (verify against MathGenerator).
gen_num = 10000

# Pool of characters that random sentences are sampled from. '+' is listed
# several times, so uniform sampling over this list picks it more often
# than any single digit, the space or '='.
vocab_list = list("0123456789 ++++++++=")

## Generator

In [29]:
class RandomMathGenerator(MathGenerator):
    """MathGenerator variant whose sentences are random character strings.

    Only ``_generate_sent`` is overridden; everything else comes from
    ``MathGenerator``. Sentences are built from the module-level
    ``vocab_list`` and bounded by the module-level ``max_len``.
    """

    def _generate_sent(self, _):
        """Return a single random sentence as a string.

        The positional argument is ignored. The length is drawn with
        ``np.random.randint(3, max_len)``, i.e. from 3 up to ``max_len - 1``
        inclusive. Each character is then sampled independently, with
        replacement, from the entries of ``vocab_list``.
        """
        n_chars = np.random.randint(3, max_len)
        chars = np.random.choice(vocab_list, n_chars)
        return "".join(chars)

In [30]:
# Produce one file of random sentences per split (train, then valid),
# using a fresh generator instance for each.
for path in (tr_text_path, va_text_path):
    gen = RandomMathGenerator(gen_min_sum, gen_max_sum)
    sents = gen.generate('=', gen_num)
    text = '\n'.join(sents)

    # Echo the destination and the first 100 characters as a sanity check.
    preview = text[:100] + " ..."
    print('\n%s:\n' % path)
    print(preview)

    # make_dir receives the file path itself; presumably it creates the
    # missing parent directory -- confirm in utils.make_dir.
    make_dir(path)
    with open(path, 'w') as out_file:
        out_file.write(text)


texts/math/math_random_train.txt:

++5
90+404++80++=+ 1+=++7+48
+52+7+1+++9034 =+ +9+++
++2+723+++41+2+91738
000+2+312
+++82+56+++
+5+8 ...

texts/math/math_random_valid.txt:

98+ 098
03+098
=1 =7++=+415=
+082+64+9=3++8++8+091=+ 0741
9++ +=4
+3+++4+80++=++3
5308+
9=+=032+1=1
 ...


## Random Dataset

In [31]:
# Load the previously saved vocabulary used to encode the text files.
voc = MathVocab()
voc.restore(vocab_path)

dataset_paths = [tr_dataset_path, va_dataset_path]

# Pair every text file with the dataset file it is encoded into.
for text_path, dataset_path in zip([tr_text_path, va_text_path], dataset_paths):
    dataset = MathDataset()
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset_path)
    print(dataset)
    # Sanity checks on the encoded matrix: nearly all generated sentences
    # must be present, and each row is max_len + 2 wide (presumably the
    # two extra slots hold the <BOS>/<EOS> markers -- confirm in MathDataset).
    assert dataset.shape[0] > 9900
    assert dataset.shape[1] == max_len + 2

data/math/math_random_dataset_train.data
MathDataset:
  path: data/math/math_random_dataset_train.data
  shape: [10000, 34]
  data_limit: None
data/math/math_random_dataset_valid.data
MathDataset:
  path: data/math/math_random_dataset_valid.data
  shape: [10000, 34]
  data_limit: None


## Batch

In [32]:
# Reload vocabulary and training dataset from disk so this cell does not
# depend on objects left over from the cells above.
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathDataset()
dataset.restore(tr_dataset_path)

# Decode a batch of 15 encoded sentences back to text for inspection.
batch = dataset.get_next_batch(15)
for sent in batch:
    words = voc.to_words(sent)
    # Render padding as '_' so the padded width of each row stays visible.
    restored = "".join(words).replace('<PAD>', '_')
    print(restored)

<BOS>6+1<EOS>_____________________________
<BOS>6+++5+9498++=+9+=0+4+++9++9<EOS>_____
<BOS>4++233131+ +8967+++19+2+5 649+8<EOS>_
<BOS>6=9804 <EOS>_________________________
<BOS>5+39++5+ 9++1 + 89+= 27+ +++0<EOS>___
<BOS>1=4 012+8<EOS>_______________________
<BOS>06++8817++ 5+5++50 236<EOS>__________
<BOS>0+1+5++8++0235++8++ 9<EOS>___________
<BOS> +++158++ +2++<EOS>__________________
<BOS>5++807141+7+0<EOS>___________________
<BOS>+++++232+ + +4+92243 +++0+85+9<EOS>__
<BOS>2++14+18++6807+283+<EOS>_____________
<BOS>17+66++563  <EOS>____________________
<BOS>+50841 +60+ 1 2+<EOS>________________
<BOS>7 +14+3+++2+++++=+044 <EOS>__________
