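# MathStyleDataset

Builds a two-style corpus of arithmetic equations. Style 0 writes equations as `a+b = c` (spaces around `=`), and style 1 writes them as `a+b=c`. The notebook does four things:

- generates train/valid text files for each style;
- builds one shared vocabulary;
- encodes the texts into `MathDataset`s;
- decodes a sample mixed batch as a sanity check.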

In [1]:
import numpy as np
from math_generator import MathGenerator
from math_vocab import MathVocab
from math_dataset import MathDataset
from utils import make_dir

vocab_path      = "data/math_style/math_vocab.data"

# Path templates: the %s placeholder is filled with the style index (0 or 1).
tr_text_path    = "texts/math_style/math_%s_train.txt"
va_text_path    = "texts/math_style/math_%s_valid.txt"
tr_dataset_path = "data/math_style/math_dataset_%s_train.data"
va_dataset_path = "data/math_style/math_dataset_%s_valid.data"

max_len     = 32     # max sentence length handed to MathDataset.build
gen_min_sum = 10     # first MathGenerator arg (name suggests the minimum equation total)
gen_max_sum = 20     # second MathGenerator arg (name suggests the maximum equation total)
gen_num     = 10000  # number of equations generated per text file

# Seed the global RNGs early so the generated text files are reproducible on
# Restart & Run All.
# NOTE(review): assumes MathGenerator draws from the `random` / `np.random`
# global state -- confirm; if it owns a private RNG, seed that instead.
import random

seed = 0
random.seed(seed)
np.random.seed(seed)

## Generator

In [3]:
def apply_style(text, style):
    """Return `text` rendered in the given style.

    Style 0 returns the text unchanged. Any other style collapses every
    " = " into "=", so "a+b = c" becomes "a+b=c".
    """
    if style != 0:
        text = text.replace(" = ", "=")
    return text
    
# Write one train and one valid text file per style. A fresh MathGenerator is
# created for every file; the style transform is applied to the joined text.
for style in [0, 1]:
    for path_template in [tr_text_path, va_text_path]:
        out_path = path_template % style
        sentences = MathGenerator(gen_min_sum, gen_max_sum).generate('=', gen_num)
        styled = apply_style('\n'.join(sentences), style)

        print('\n%s:\n' % out_path)
        print(styled[:100] + " ...")

        make_dir(out_path)
        with open(out_path, 'w') as out_file:
            out_file.write(styled)


texts/math_style/math_0_train.txt:

8+1+9 = 18
16 = 1+1+12+2
7+2+1 = 1+9
16 = 5+11
2+13 = 1+13+1
18 = 8+2+7+1
15+1 = 10+6
7+6+1 = 4+1+8+ ...

texts/math_style/math_0_valid.txt:

2+2+2+3+1+1 = 1+1+2+7
1+7+6 = 3+4+2+5
6+6+1+1+1 = 12+1+1+1
1+9+1 = 4+6+1
16+1+1 = 5+12+1
12+5 = 1+6+ ...

texts/math_style/math_1_train.txt:

9+4+1=7+1+5+1
1+7+4=7+5
4+8+4+2=12+6
1+15+1=4+13
3+1+6+2=1+10+1
6+1+2+1+2=2+10
11+1+1=4+9
8+4=4+4+1+ ...

texts/math_style/math_1_valid.txt:

4+1+5+1=1+1+1+2+6
2+3+6=2+1+3+1+4
1+11+7=1+6+4+8
1+6+8+2=1+12+4
1+18=17+1+1
8+2+1=4+2+5
2+10=7+1+4
8 ...


## Vocab

In [4]:
# A single vocabulary covers both splits and both styles, so all four text
# files (train/valid x style 0/1) are fed to it.
voc = MathVocab()
vocab_sources = [template % style
                 for template in [tr_text_path, va_text_path]
                 for style in [0, 1]]
voc.build(vocab_sources)
voc.save(vocab_path)
print(voc)
# '0123456789 =+' is 13 characters, plus <UNK>, <EOS>, <BOS>, <PAD>.
expected_vocab_size = 17
assert voc.size == expected_vocab_size, voc.size

MathVocab:
  size: 17
  _tokens_to_words: ['<UNK>', '<BOS>', '<EOS>', '<PAD>', '+', '1', '2', '=', ' ', '3', '4', '5', '6', '7', '8', '9', '0']


## Dataset

In [30]:
import numpy as np
import nltk
from utils import sent_to_words, make_dir
from vocabulary import BOS_CODE, EOS_CODE, PAD_CODE
import pickle

class MathStyleDataset:
    """A pair of MathDataset objects, one for style 0 and one for style 1.

    Every `path` argument is a %-format template. The style index (0 or 1)
    is substituted in to get the per-style path.
    """

    def __init__(self):
        self._ds0 = MathDataset()
        self._ds1 = MathDataset()

    def _per_style(self):
        """Yield (style_index, dataset) pairs, style 0 first."""
        yield 0, self._ds0
        yield 1, self._ds1

    def __str__(self):
        return "\n".join(str(ds) for _, ds in self._per_style())

    def build(self, path, vocab, max_len, min_len=1):
        """Build each style's dataset from `path % style`."""
        for style, ds in self._per_style():
            ds.build(path % style, vocab, max_len, min_len)

    def get_next_batch(self, bs):
        """Return `bs` rows: bs//2 from style 0 stacked above the rest from style 1.

        When `bs` is odd, style 1 contributes the extra row.
        """
        half = bs // 2
        parts = [self._ds0.get_next_batch(half),
                 self._ds1.get_next_batch(bs - half)]
        return np.concatenate(parts, axis=0)

    def save(self, path):
        """Save each style's dataset to `path % style`."""
        for style, ds in self._per_style():
            ds.save(path % style)

    def restore(self, path):
        """Restore each style's dataset from `path % style`."""
        for style, ds in self._per_style():
            ds.restore(path % style)
        


In [31]:
# Sanity check for MathStyleDataset.get_next_batch: np.concatenate along
# axis 0 stacks the first array's rows above the second array's rows.
first_rows = np.array([[1, 1]])
second_rows = np.array([[2, 2], [3, 3]])
stacked = np.concatenate((first_rows, second_rows), axis=0)
assert stacked.tolist() == [[1, 1], [2, 2], [3, 3]], stacked
stacked

array([[1, 1],
       [2, 2],
       [3, 3]])

In [32]:
voc = MathVocab()
voc.restore(vocab_path)

# Allow a 1% shortfall from gen_num. NOTE(review): presumably
# MathDataset.build drops sentences outside [min_len, max_len] -- confirm.
# Derived from gen_num so the check tracks that setting.
min_rows = gen_num * 99 // 100

for text_path, dataset_path in zip([tr_text_path, va_text_path],
                                   [tr_dataset_path, va_dataset_path]):
    dataset = MathStyleDataset()
    dataset.build(text_path, voc, max_len=max_len)
    dataset.save(dataset_path)
    print(dataset)
    for style_ds in (dataset._ds0, dataset._ds1):
        assert style_ds.shape[0] > min_rows, style_ds.shape
        # +2 for the <BOS>/<EOS> markers wrapped around every sentence.
        assert style_ds.shape[1] == max_len + 2, style_ds.shape

MathDataset:
  path: data/math_style/math_dataset_0_train.data
  shape: [10000, 34]
  data_limit: None
MathDataset:
  path: data/math_style/math_dataset_1_train.data
  shape: [10000, 34]
  data_limit: None
MathDataset:
  path: data/math_style/math_dataset_0_valid.data
  shape: [10000, 34]
  data_limit: None
MathDataset:
  path: data/math_style/math_dataset_1_valid.data
  shape: [10000, 34]
  data_limit: None


## Batch

In [34]:
voc = MathVocab()
voc.restore(vocab_path)

dataset = MathStyleDataset()
dataset.restore(tr_dataset_path)

# Decode a mixed batch back to text: the first half comes from style 0 and
# the rest from style 1. <PAD> is drawn as '_' so sentence lengths are visible.
batch = dataset.get_next_batch(14)
for sent in batch:
    print("".join(voc.to_words(sent)).replace('<PAD>', '_'))

<BOS>5+7 = 1+2+9<EOS>_____________________
<BOS>6+1+4+6+1 = 4+14<EOS>________________
<BOS>9+2+5+2 = 8+5+5<EOS>_________________
<BOS>2+2+9+5 = 2+10+6<EOS>________________
<BOS>5+4+1 = 4+1+5<EOS>___________________
<BOS>17+1 = 1+17<EOS>_____________________
<BOS>6+9 = 1+13+1<EOS>____________________
<BOS>7+9+1=10+7<EOS>______________________
<BOS>2+1+6+3+1=1+2+1+9<EOS>_______________
<BOS>10=8+2<EOS>__________________________
<BOS>8+1+7+1=6+11<EOS>____________________
<BOS>1+3+2+1+4+2=1+12<EOS>________________
<BOS>12+1+2=9+5+1<EOS>____________________
<BOS>16+1=5+4+8<EOS>______________________
