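# MathDataset

This notebook builds the toy arithmetic dataset end to end:
- Generate train/valid text files of addition equations (e.g. `4+7 = 11`).
- Build a character-level vocabulary from them.
- Encode each file into a fixed-length `MathDataset`.
- Decode one batch as a sanity check.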

In [1]:
import random

import numpy as np

from math_dataset import MathDataset
from math_generator import MathGenerator
from math_vocab import MathVocab
from utils import make_dir

tr_text_path    = "texts/math/math_train.txt"
va_text_path    = "texts/math/math_valid.txt"
vocab_path      = "data/math/math_vocab.data"
tr_dataset_path = "data/math/math_dataset_train.data"
va_dataset_path = "data/math/math_dataset_valid.data"

# Arguments for MathGenerator(gen_min_sum, gen_max_sum) and the number of
# sentences generated per text file.
gen_min_sum = 10
gen_max_sum = 20
gen_num     = 1000

# Seed so that Restart & Run All regenerates the same texts, vocab and datasets.
# Which RNG MathGenerator draws from is not visible in this notebook, so both the
# stdlib and NumPy global generators are seeded -- TODO confirm which one it uses.
gen_seed = 42
random.seed(gen_seed)
np.random.seed(gen_seed)

## Generator

In [3]:
# Produce gen_num equations (e.g. "4+7 = 11") for each text file. A fresh
# MathGenerator is built per file. A short preview is printed, then the
# sentences are written newline-separated (no trailing newline).
for out_path in (tr_text_path, va_text_path):
    generator = MathGenerator(gen_min_sum, gen_max_sum)
    corpus = '\n'.join(generator.generate('=', gen_num))

    print('\n%s:\n' % out_path)
    print(corpus[:200] + " ...")

    make_dir(out_path)
    with open(out_path, 'w') as handle:
        handle.write(corpus)


texts/math/math_train.txt:

10+3+1+5 = 1+18
4+7+1+4 = 4+5+7
2+8+1+2 = 9+4
1+4+2+6+4 = 2+5+6+1+2+1
2+11+4+2 = 2+17
1+7+1+2 = 1+1+1+8
2+3+9+2 = 2+1+13
4+3+9 = 9+2+5
14 = 7+5+2
4+7+4 = 11+4
1+15 = 8+1+4+1+1+1
3+9+2+2 = 5+10+1
2+13  ...

texts/math/math_valid.txt:

5+11+1+1 = 4+9+4+1
1+3+2+3+1+6 = 4+10+2
16 = 3+7+5+1
4+7 = 11
2+1+3+5 = 9+2
3+1+3+3 = 1+5+4
2+1+12 = 7+8
6+1+1+4+2+4 = 10+6+1+1
5+1+11 = 1+16
1+7+2+2 = 10+2
2+2+14 = 13+1+1+1+2
9+8 = 17
6+2+2+3 = 13
2 ...


## Vocab

In [4]:
voc = MathVocab()
voc.build([tr_text_path, va_text_path])
print(voc)

# Expected vocabulary: the 13 characters used in the generated sentences plus
# the 4 special tokens. The order of the special tokens matches the printed vocab.
# Validate BEFORE saving so a wrong vocab is never written to vocab_path,
# where the Dataset cell would pick it up.
expected_chars = '0123456789 =+'
special_tokens = ['<UNK>', '<BOS>', '<EOS>', '<PAD>']
expected_size = len(expected_chars) + len(special_tokens)  # 17
assert voc.size == expected_size, \
    'unexpected vocab size %d (expected %d)' % (voc.size, expected_size)

voc.save(vocab_path)

MathVocab:
  size: 17
  _tokens_to_words: ['<UNK>', '<BOS>', '<EOS>', '<PAD>', '+', '1', ' ', '2', '=', '3', '4', '5', '6', '7', '8', '9', '0']


## Dataset

In [5]:
voc = MathVocab()
voc.restore(vocab_path)

# Number of characters kept between <BOS> and <EOS>. The decoded batch below
# shows 32 symbols between the markers and the dataset shape is [N, 34].
# NOTE(review): how build() treats sentences longer than max_len is not
# visible here -- confirm before raising gen_max_sum.
max_len = 32

text_paths    = [tr_text_path, va_text_path]
dataset_paths = [tr_dataset_path, va_dataset_path]

# Encode each text file with the shared vocab and save it to its paired dataset path.
for text_path, dataset_path in zip(text_paths, dataset_paths):
    dataset = MathDataset()
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset_path)
    print(dataset)

data/math/math_dataset_train.data
MathDataset:
  shape: [1000, 34]
  data_limit: None
data/math/math_dataset_valid.data
MathDataset:
  shape: [1000, 34]
  data_limit: None


## Batch

In [6]:
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathDataset()
dataset.restore(tr_dataset_path)

# Sanity check: decode 15 token sequences back to text. Padding is rendered
# as '_' so the fixed row length stays visible.
for tokens in dataset.get_next_batch(15):
    print("".join(voc.to_words(tokens)).replace('<PAD>', '_'))

<BOS>18+1 = 1+1+3+11+2+1<EOS>_____________
<BOS>3+1+7+3 = 10+2+2<EOS>________________
<BOS>1+3+6 = 2+8<EOS>_____________________
<BOS>1+1+11 = 3+3+1+1+5<EOS>______________
<BOS>2+1+3+2+1+7+2 = 2+16<EOS>____________
<BOS>19 = 1+3+7+8<EOS>____________________
<BOS>10 = 8+2<EOS>________________________
<BOS>12+6+1 = 1+6+12<EOS>_________________
<BOS>7+2+7 = 13+2+1<EOS>__________________
<BOS>1+4+9 = 9+1+1+3<EOS>_________________
<BOS>3+1+4+1+2 = 10+1<EOS>________________
<BOS>13 = 1+11+1<EOS>_____________________
<BOS>1+10 = 11<EOS>_______________________
<BOS>1+1+11 = 6+1+2+2+2<EOS>______________
<BOS>7+6+3 = 1+15<EOS>____________________
