# MathStyleDataset

In [1]:
import numpy as np
from math_generator import MathGenerator
from math_vocab import MathVocab
from math_dataset import MathDataset
from style_dataset import StyleDataset
from utils import make_dir

# Root folders: raw generated text goes under text_root, serialized
# vocab/dataset files under data_root.
text_root = "texts/math_style"
data_root = "data/math_style"

vocab_path = data_root + "/math_vocab.data"

# The "%s" placeholder in the text templates is filled with the style id
# (0 or 1) by the cells below; the dataset templates are handed to
# StyleDataset unformatted.
tr_text_path = text_root + "/math_%s_train.txt"
va_text_path = text_root + "/math_%s_valid.txt"
tr_dataset_path = data_root + "/math_dataset_%s_train.data"
va_dataset_path = data_root + "/math_dataset_%s_valid.data"

# Passed to dataset.build(); the dataset cell asserts a row width of max_len+2.
max_len = 32

# Arguments for MathGenerator(gen_min_sum, gen_max_sum) and gen.generate('=', gen_num).
# NOTE(review): presumably the sum range and the number of sentences per file — confirm in math_generator.
gen_min_sum = 10
gen_max_sum = 20
gen_num = 10000

## Generator

In [2]:
def apply_style(text, style):
    """Return `text` rendered in the given style.

    Style 0 returns the text unchanged. Any other style value rewrites
    every " = " (equals sign surrounded by single spaces) as a bare "=".
    """
    return text if style == 0 else text.replace(" = ", "=")
    
# One job per output file, in the order: style 0 train, style 0 valid,
# style 1 train, style 1 valid.
jobs = [(style_id, template % style_id)
        for style_id in [0,1]
        for template in [tr_text_path, va_text_path]]

for style, out_path in jobs:
    # A fresh generator is built for every file.
    generator = MathGenerator(gen_min_sum, gen_max_sum)
    corpus = apply_style('\n'.join(generator.generate('=', gen_num)), style)

    # Preview: the file name followed by the first 100 characters.
    print('\n%s:\n'%out_path)
    print(corpus[:100] + " ...")

    # NOTE(review): make_dir presumably creates the parent folder of out_path — confirm in utils.
    make_dir(out_path)
    with open(out_path, 'w') as out_file:
        out_file.write(corpus)


texts/math_style/math_0_train.txt:

13+1+1 = 1+3+5+3+3
3+12+2 = 2+14+1
4+1+1+9 = 15
2+13+2 = 1+3+1+1+5+1+5
5+2+2+2 = 4+6+1
7+1+3+4+1 = 7 ...

texts/math_style/math_0_valid.txt:

1+4+14 = 17+2
9+1+1+7 = 13+1+1+2+1
4+1+1+13 = 11+1+2+5
15+1+1 = 4+7+1+5
1+1+1+1+6+3 = 4+9
1+3+10+2 = ...

texts/math_style/math_1_train.txt:

8+1+1=7+2+1
6+1+3+3+3=15+1
1+1+11+3=12+4
2+1+8=11
5+3+1+2+2=3+1+9
5+5=1+9
14=2+3+2+3+4
11=6+2+1+2
1+ ...

texts/math_style/math_1_valid.txt:

13+1+2=11+5
1+12=12+1
4+1+2+3+2=1+1+10
4+2+4+1=7+2+2
2+10=2+8+2
18+1=1+2+16
5+3+1+1=3+2+3+2
1+7+6=4+ ...


## Vocab

In [3]:
# Vocabulary sources, in order: train style 0, train style 1,
# valid style 0, valid style 1.
corpus_files = []
for template in [tr_text_path, va_text_path]:
    for style_id in [0,1]:
        corpus_files.append(template % style_id)

voc = MathVocab()
voc.build(corpus_files)
voc.save(vocab_path)
print(voc)

# Expected symbols: the 13 characters of '0123456789 =+' plus the four
# special tokens UNK, EOS, BOS and PAD.
assert voc.size == 17, voc.size

MathVocab:
  size: 17
  _tokens_to_words: ['<UNK>', '<BOS>', '<EOS>', '<PAD>', '+', '1', '2', '=', ' ', '3', '4', '5', '6', '7', '8', '9', '0']


## Dataset

In [4]:
# Reload the vocabulary from disk rather than reusing the in-memory one.
voc = MathVocab()
voc.restore(vocab_path)

# Pair each text template with the dataset template it is saved under.
splits = zip([tr_text_path, va_text_path],
             [tr_dataset_path, va_dataset_path])

for text_path, dataset_path in splits:
    dataset = StyleDataset(MathDataset)
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset)

    # Sanity checks on both per-style parts: more than 9900 rows each, and
    # a row width of max_len+2.
    for part_name in ("_ds0", "_ds1"):
        part = getattr(dataset, part_name)
        assert part.shape[0] > 9900
        assert part.shape[1] == max_len+2

    # The reported data size must agree with the first dimension of the shape.
    assert dataset.get_data_size() == dataset.shape[0], dataset.get_data_size()

StyleDataset:
  path_0: data/math_style/math_dataset_0_train.data
  path_1: data/math_style/math_dataset_1_train.data
  shape: [19998, 34]

StyleDataset:
  path_0: data/math_style/math_dataset_0_valid.data
  path_1: data/math_style/math_dataset_1_valid.data
  shape: [20000, 34]



## Batch

In [6]:
# Restore the saved vocabulary and training dataset from disk.
voc = MathVocab()
voc.restore(vocab_path)

dataset = StyleDataset(MathDataset)
dataset.restore(tr_dataset_path)

# Fetch one batch of 14 and print each sentence next to its style label,
# showing every padding token as "_".
sents, styles = dataset.get_next_batch(14)
for style, sent in zip(styles, sents):
    decoded = "".join(voc.to_words(sent)).replace('<PAD>', '_')
    print("%d: %s" % (style, decoded))

0: <BOS>14 = 11+2+1<EOS>_____________________
0: <BOS>6+3+2+1 = 9+1+2<EOS>_________________
0: <BOS>4+1+5+6 = 2+2+3+3+3+3<EOS>___________
0: <BOS>13+1 = 1+1+8+3+1<EOS>________________
0: <BOS>16+2 = 10+1+1+6<EOS>_________________
0: <BOS>6+7 = 2+11<EOS>______________________
0: <BOS>1+9+3 = 2+11<EOS>____________________
1: <BOS>1+3+6+1=11<EOS>______________________
1: <BOS>11+2=3+1+9<EOS>______________________
1: <BOS>2+13+1=2+7+1+6<EOS>__________________
1: <BOS>2+8+6+1+1+1=1+1+3+2+6+2+4<EOS>_______
1: <BOS>6+5+1+1=11+1+1<EOS>__________________
1: <BOS>2+10=5+1+6<EOS>______________________
1: <BOS>2+10=3+3+6<EOS>______________________
