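# MathDataset

This notebook generates random addition equations (e.g. `12+4 = 2+1+4+9`) for the train and validation splits. It then builds a character vocabulary from them and encodes each split into a fixed-width dataset. Finally, it decodes one batch back to text as a sanity check.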

In [15]:
import random

import numpy as np

from math_generator import MathGenerator
from math_vocab import MathVocab
from math_dataset import MathDataset
from utils import make_dir

# Vocabulary shared by the train and validation datasets.
vocab_path      = "data/math/math_vocab.data"

# Generated equation texts (one equation per line) and the encoded datasets built from them.
tr_text_path    = "texts/math/math_train.txt"
va_text_path    = "texts/math/math_valid.txt"
tr_dataset_path = "data/math/math_dataset_train.data"
va_dataset_path = "data/math/math_dataset_valid.data"

# Encoded rows are max_len + 2 tokens wide (the Batch output shows the extra two are <BOS>/<EOS>).
max_len     = 32
# Bounds passed to MathGenerator; every equation in the sample output sums to a value in this range.
# NOTE(review): confirm in MathGenerator whether gen_max_sum is inclusive.
gen_min_sum = 10
gen_max_sum = 20
# Number of equations generated for each split.
gen_num     = 10000

# Seed the RNGs so Restart & Run All regenerates identical text files.
# NOTE(review): MathGenerator's RNG source isn't visible here, so both the stdlib and the NumPy
# global generators are seeded -- confirm which one it actually uses.
# Re-running only the Generator cell (without this one) will still produce different data.
gen_seed = 42
random.seed(gen_seed)
np.random.seed(gen_seed)

## Generator

In [16]:
for out_path in (tr_text_path, va_text_path):
    # Each split gets its own generator and gen_num equations, written one per line.
    generator = MathGenerator(gen_min_sum, gen_max_sum)
    split_text = '\n'.join(generator.generate('=', gen_num))

    # Show a short preview (first 100 characters) of what is about to be written.
    print(f'\n{out_path}:\n')
    print(f"{split_text[:100]} ...")

    make_dir(out_path)
    with open(out_path, 'w') as out_file:
        out_file.write(split_text)


texts/math/math_train.txt:

12+4 = 2+1+4+9
1+1+4+4 = 5+4+1
9+1+3 = 3+3+1+5+1
10+4+1+3+1 = 4+15
5+1+5 = 11
4+13+1+1 = 1+4+8+6
9+4 ...

texts/math/math_valid.txt:

9+2 = 3+2+2+2+2
6+2+1+2 = 1+8+2
12+3+4 = 18+1
2+2+6+2 = 5+2+5
1+10 = 1+1+1+4+3+1
7+1+2+4+1 = 5+1+9
1 ...


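## Vocab

The vocabulary is built from both generated text files, so run the Generator cell first. The saved execution counts show this cell ran before it.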

In [14]:
# Build the vocabulary from both the train and validation texts, then persist it.
voc = MathVocab()
voc.build([tr_text_path, va_text_path])
voc.save(vocab_path)
print(voc)

# Equations contain only digits, ' ', '=' and '+'; the vocab adds four special tokens on top.
# Deriving the size documents where the expected count comes from (13 + 4 == 17).
equation_chars = '0123456789 =+'
special_tokens = ['<UNK>', '<BOS>', '<EOS>', '<PAD>']
expected_vocab_size = len(equation_chars) + len(special_tokens)
assert voc.size == expected_vocab_size, \
    f"expected {expected_vocab_size} tokens, got {voc.size}"

MathVocab:
  size: 17
  _tokens_to_words: ['<UNK>', '<BOS>', '<EOS>', '<PAD>', '+', '1', ' ', '2', '=', '3', '4', '5', '6', '7', '8', '9', '0']


## Dataset

In [17]:
# Reload the vocab from disk so this cell does not depend on the Vocab cell's in-memory `voc`.
voc = MathVocab()
voc.restore(vocab_path)

# Build may drop some lines (the validation split shown kept 9999 of 10000),
# presumably those exceeding max_len -- TODO confirm in MathDataset.build.
# Require that more than 99% of the generated equations survive.
# Integer arithmetic keeps this exactly 9900 for gen_num=10000.
min_rows = gen_num * 99 // 100

for text_path, dataset_path in zip([tr_text_path, va_text_path],
                                   [tr_dataset_path, va_dataset_path]):
    dataset = MathDataset()
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset_path)
    print(dataset)
    assert dataset.shape[0] > min_rows, \
        f"{dataset_path}: only {dataset.shape[0]} of {gen_num} rows kept"
    # Each row holds max_len tokens plus <BOS> and <EOS>.
    assert dataset.shape[1] == max_len + 2, \
        f"{dataset_path}: row width {dataset.shape[1]} != max_len + 2"

data/math/math_dataset_train.data
MathDataset:
  shape: [10000, 34]
  data_limit: None
data/math/math_dataset_valid.data
MathDataset:
  shape: [9999, 34]
  data_limit: None


## Batch

In [18]:
# Reload both artifacts from disk so this cell stands on its own.
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathDataset()
dataset.restore(tr_dataset_path)

# Round-trip check: decode one batch of 15 rows back to text, showing '<PAD>' as '_'.
for token_row in dataset.get_next_batch(15):
    decoded = "".join(voc.to_words(token_row)).replace('<PAD>', '_')
    print(decoded)

<BOS>4+6 = 2+2+6<EOS>_____________________
<BOS>13+3 = 1+14+1<EOS>___________________
<BOS>1+15+1 = 5+1+11<EOS>_________________
<BOS>7+11 = 13+3+2<EOS>___________________
<BOS>9+1+2 = 11+1<EOS>____________________
<BOS>1+1+3+1+4 = 1+4+2+3<EOS>_____________
<BOS>4+2+9+4 = 1+11+2+2+2+1<EOS>__________
<BOS>1+8+4 = 1+10+2<EOS>__________________
<BOS>7+1+1+4 = 1+1+1+5+2+1+2<EOS>_________
<BOS>1+8+1 = 8+2<EOS>_____________________
<BOS>3+16 = 13+2+4<EOS>___________________
<BOS>1+10+1 = 1+2+2+7<EOS>________________
<BOS>7+11+1 = 19<EOS>_____________________
<BOS>2+12+1+1+1 = 4+13<EOS>_______________
<BOS>4+3+1+2 = 3+7<EOS>___________________
