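# MathStyleDataset

Builds a toy two-style text dataset of arithmetic equations. Style 0 writes equations with a spaced equals sign (`8+1+9 = 18`), and style 1 writes them compactly (`8+1+9=18`). The notebook runs in four steps:

1. Generate train/valid text files for each style.
2. Build a shared character vocabulary.
3. Encode the texts into fixed-length datasets.
4. Decode a sample batch as a sanity check.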

In [1]:
import random

import numpy as np

from math_generator import MathGenerator
from math_vocab import MathVocab
from math_style_dataset import MathStyleDataset
from utils import make_dir

# --- Paths -----------------------------------------------------------------
# Every '%s' placeholder below is filled with a style id (0 or 1) by the cells
# that use these templates.
vocab_path = "data/math_style/math_vocab.data"

tr_text_path = "texts/math_style/math_%s_train.txt"
va_text_path = "texts/math_style/math_%s_valid.txt"

tr_dataset_path = "data/math_style/math_dataset_%s_train.data"
va_dataset_path = "data/math_style/math_dataset_%s_valid.data"

# --- Generation / encoding parameters ---------------------------------------
# Max sentence length in tokens; encoded rows are max_len + 2 wide
# (room for <BOS> and <EOS>).
max_len = 32
# Passed to MathGenerator; presumably the range of each equation's total.
gen_min_sum, gen_max_sum = 10, 20
# Number of equations requested per generated text file.
gen_num = 10000

## Generator

In [3]:
def apply_style(text, style):
    """Render generated equations in the requested surface style.

    Style 0 leaves the text untouched (spaced ``" = "``); any other style value
    collapses every ``" = "`` to a bare ``"="``, e.g. ``"7+2 = 9"`` -> ``"7+2=9"``.

    Args:
        text: Newline-joined equations as produced by the generator.
        style: Style id; only ``0`` keeps the original spacing.

    Returns:
        The restyled text.
    """
    return text if style == 0 else text.replace(" = ", "=")
    
# Seed the global RNGs once so "Restart & Run All" regenerates identical
# train/valid files instead of silently overwriting them with new data.
# NOTE(review): MathGenerator's internals are not visible here; this assumes it
# draws from Python's `random` and/or NumPy's global RNG — confirm in
# math_generator.py.
gen_seed = 0
random.seed(gen_seed)
np.random.seed(gen_seed)

for style in [0, 1]:
    for path_template in [tr_text_path, va_text_path]:
        path = path_template % style
        gen = MathGenerator(gen_min_sum, gen_max_sum)
        sents = gen.generate('=', gen_num)
        text = apply_style('\n'.join(sents), style)

        # Preview the start of each file.
        print('\n%s:\n' % path)
        print(text[:100] + " ...")

        make_dir(path)
        with open(path, 'w') as f:
            f.write(text)


texts/math_style/math_0_train.txt:

8+1+9 = 18
16 = 1+1+12+2
7+2+1 = 1+9
16 = 5+11
2+13 = 1+13+1
18 = 8+2+7+1
15+1 = 10+6
7+6+1 = 4+1+8+ ...

texts/math_style/math_0_valid.txt:

2+2+2+3+1+1 = 1+1+2+7
1+7+6 = 3+4+2+5
6+6+1+1+1 = 12+1+1+1
1+9+1 = 4+6+1
16+1+1 = 5+12+1
12+5 = 1+6+ ...

texts/math_style/math_1_train.txt:

9+4+1=7+1+5+1
1+7+4=7+5
4+8+4+2=12+6
1+15+1=4+13
3+1+6+2=1+10+1
6+1+2+1+2=2+10
11+1+1=4+9
8+4=4+4+1+ ...

texts/math_style/math_1_valid.txt:

4+1+5+1=1+1+1+2+6
2+3+6=2+1+3+1+4
1+11+7=1+6+4+8
1+6+8+2=1+12+4
1+18=17+1+1
8+2+1=4+2+5
2+10=7+1+4
8 ...


## Vocab

In [4]:
voc = MathVocab()
# Build one shared vocabulary over all four text files, in the order
# train/style0, train/style1, valid/style0, valid/style1.
voc.build([
    template % style_id
    for template in (tr_text_path, va_text_path)
    for style_id in (0, 1)
])
voc.save(vocab_path)
print(voc)
# '0123456789 =+' (13 chars) + UNK, EOS, BOS, PAD
assert voc.size == 17, voc.size

MathVocab:
  size: 17
  _tokens_to_words: ['<UNK>', '<BOS>', '<EOS>', '<PAD>', '+', '1', '2', '=', ' ', '3', '4', '5', '6', '7', '8', '9', '0']


## Dataset

In [50]:
voc = MathVocab()
voc.restore(vocab_path)

dataset_paths = [tr_dataset_path, va_dataset_path]

# Minimum rows expected per style. Derived from gen_num instead of a literal
# 9900, so changing gen_num in the config keeps this check meaningful. It
# tolerates up to 1% loss, presumably sentences that build() drops for
# exceeding max_len. Evaluates to 9900 for gen_num = 10000.
min_rows = gen_num - gen_num // 100

for text_path, dataset_path in zip([tr_text_path, va_text_path], dataset_paths):
    dataset = MathStyleDataset()
    # The paths still carry a '%s' placeholder. build()/save() appear to fill
    # in the style id themselves (see path_0 / path_1 in the printed summary).
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset)
    # Each style's array must have enough rows and be max_len + 2 wide
    # (<BOS> + up to max_len tokens + <EOS>, padded).
    for ds in (dataset._ds0, dataset._ds1):
        assert ds.shape[0] > min_rows, ds.shape
        assert ds.shape[1] == max_len + 2, ds.shape

MathStyleDataset:
  path_0: data/math_style/math_dataset_0_train.data
  path_1: data/math_style/math_dataset_1_train.data
  shape: [20000, 34]

MathStyleDataset:
  path_0: data/math_style/math_dataset_0_valid.data
  path_1: data/math_style/math_dataset_1_valid.data
  shape: [20000, 34]



## Batch

In [51]:
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathStyleDataset()
dataset.restore(tr_dataset_path)

# Decode one training batch back to text as a sanity check; padding tokens
# are rendered as '_' so the fixed row width stays visible.
sents, styles = dataset.get_next_batch(14)
for sent, style in zip(sents, styles):
    restored = "".join(voc.to_words(sent)).replace('<PAD>', '_')
    print("%d: %s" % (style, restored))

0: <BOS>1+7+6 = 2+12<EOS>____________________
0: <BOS>1+1+2+8 = 2+9+1<EOS>_________________
0: <BOS>1+1+12 = 3+6+2+1+2<EOS>______________
0: <BOS>7+2+8 = 5+6+1+4+1<EOS>_______________
0: <BOS>1+9 = 2+1+1+6<EOS>___________________
0: <BOS>1+2+8 = 2+9<EOS>_____________________
0: <BOS>3+11 = 9+3+2<EOS>____________________
1: <BOS>3+5+2=3+1+6<EOS>_____________________
1: <BOS>1+15=8+4+4<EOS>______________________
1: <BOS>2+8+2+2=7+7<EOS>_____________________
1: <BOS>8+2+2=8+2+1+1<EOS>___________________
1: <BOS>3+1+4+7=1+1+1+12<EOS>________________
1: <BOS>16=6+7+2+1<EOS>______________________
1: <BOS>1+7+2+9=18+1<EOS>____________________
