# Dataset for Text Editing Task

This notebook mainly shows the code to generate a dataset for inserting operators. For example, the input consists of a sequence of positive real numbers (e.g., 112), and its corresponding outputs should be a correct math equation (e.g., 1+1=2) by inserting operators at the right positions.

## Notes
Both the input and output should be the same single real number when the input sequence length is 1. For example, "1"$\rightarrow$"1."  
In other cases, operators are necessary to hold the output equation. For example, "1 1 2"$\rightarrow$"1 + 1 = 2."
1. sequence length  
output_seq_len = 2 * input_seq_len - 1
2. vocab size  
This parameter stands for the unique number of digits involved in the inputs. For exmaple, samples range from 0 to 9 if input vocab size is 10.
3. data size

In [5]:
#!/usr/bin/env python
# -*- coding:utf-8 -*-

__author__ = 'Shining'
__email__ = 'mrshininnnnn@gmail.com'

In [6]:
# dependency
import os
import numpy as np
from utils import save_txt

In [7]:
# helper functions
def save_txt(path, line_list):
    with open(path, 'w', encoding='utf-8') as f: 
        for line in line_list: 
            f.write(line + '\n') 
    f.close()

In [8]:
# the calss to generate dataset
# for math operator intertion task
class MathematicalOperatorInsertion(): 
    """docstring for ClassName"""
    def __init__(self, operators):
        super(MathematicalOperatorInsertion, self).__init__()
        self.operators = operators
    
    def gen_base_dataset(self, vocab_size):
        # return a base dataset
        x = [str(i) for i in range(vocab_size)]
        y = x.copy()
        return x, y
    
    def gen_base_dict(self, vocab_size):
        # initialize a base value dict
        return {str(i):[] for i in range(vocab_size)}
        
    def gen_operation(self, vocab_size, seq_len):
        # a recursive function to geneate an operation
        # given the number of digits to involve
        a = np.random.choice(range(vocab_size))
        o = np.random.choice(self.operators)
        b = np.random.choice(range(vocab_size))
        if seq_len == 1:
            return [str(a)]
        else:
            out_set = self.gen_operation(vocab_size, seq_len-1)
            return out_set + [o, str(b)]
    
    def gen_operation_list(self, vocab_size, seq_len, data_size):
        # to control the data size
        counter = 1
        operations_pool = set()
        while True:
            # to avoid duplicates
            operation = self.gen_operation(vocab_size, seq_len)
            if ''.join(operation) in operations_pool:
                continue
            else:
                operations_pool.add(''.join(operation))
            # to avoid zero division error
            try: 
                # flost to int to string
                value = eval(' '.join(operation))
                if value % 1 != 0.:
                    continue
                else:
                    value = str(int(value))
                    # to keep vocab size
                    if value in self.value_dict: 
                        self.value_dict[value].append(operation)
                        if counter >= data_size:
                            break
                        else:
                            counter += 1
            except: 
                pass
            if len(operations_pool) >= self.space_size: 
                break
    
    def gen_equation_list(self):
        # generate the relational equation
        # given the value dict
        for v in self.value_dict:
            for x in self.value_dict[v]:
                y = x + ["=="] + [v]
                x = [i for i in y if i.isdigit()]
                self.xs.append(' '.join(x))
                self.ys.append(' '.join(y))

    def generate(self, vocab_size, seq_len, data_size):
        if seq_len == 0:
            return self.gen_base_dataset(
                vocab_size=vocab_size)
        # input sequences, # output sequences
        self.xs, self.ys = [], []
        # the max data size
        self.space_size = vocab_size**seq_len*len(self.operators)**(seq_len-1)
        # initialize a value dictionary
        # to save the value of each sequence
        self.value_dict = self.gen_base_dict(
            vocab_size=vocab_size)
        # insert operators and generate equations
        self.gen_operation_list(
            vocab_size=vocab_size, 
            seq_len=seq_len, 
            data_size=data_size)
        # generate relations given the value dict
        self.gen_equation_list()
        
        return self.xs, self.ys

In [9]:
# definition
vocab_size = 10
seq_len = 5 # must >= 0
data_size = 100
operators = ['+', '-', '*', '/']
print('space size', vocab_size**seq_len*len(operators)**(seq_len-1))

space size 25600000


In [10]:
# data generation
moi = MathematicalOperatorInsertion(operators)
xs, ys = moi.generate(
    vocab_size=vocab_size, 
    seq_len=seq_len, 
    data_size=data_size)

In [11]:
len(xs)

100

In [12]:
xs[-15:]

['5 6 9 9 3 8',
 '8 8 1 0 0 8',
 '2 6 6 2 3 8',
 '0 8 3 3 0 8',
 '2 6 8 2 0 8',
 '6 7 8 0 8 8',
 '4 0 4 3 3 8',
 '7 0 7 9 9 8',
 '9 6 2 7 0 8',
 '2 9 3 2 2 8',
 '0 0 9 0 9 9',
 '0 9 6 6 9 9',
 '3 3 5 1 2 9',
 '9 3 0 8 4 9',
 '0 4 8 4 3 9']

In [13]:
len(ys)

100

In [14]:
ys[-15:]

['5 + 6 * 9 / 9 - 3 == 8',
 '8 - 8 / 1 * 0 * 0 == 8',
 '2 + 6 + 6 / 2 - 3 == 8',
 '0 + 8 * 3 / 3 + 0 == 8',
 '2 * 6 - 8 / 2 - 0 == 8',
 '6 / 7 / 8 * 0 + 8 == 8',
 '4 + 0 + 4 / 3 * 3 == 8',
 '7 - 0 * 7 + 9 / 9 == 8',
 '9 - 6 - 2 + 7 + 0 == 8',
 '2 + 9 - 3 * 2 / 2 == 8',
 '0 + 0 / 9 + 0 + 9 == 9',
 '0 * 9 / 6 * 6 + 9 == 9',
 '3 + 3 + 5 - 1 * 2 == 9',
 '9 + 3 * 0 / 8 / 4 == 9',
 '0 + 4 + 8 / 4 + 3 == 9']

In [422]:
# to save dataset
data_size = len(xs)
print(len(ys), data_size)
outdir = 'raw'
outdir = os.path.join(outdir, 'vocab_size_{}'.format(vocab_size), 
                    'seq_len_{}'.format(seq_len+1), 
                    'data_size_{}'.format(data_size))
if not os.path.exists(outdir): 
    os.makedirs(outdir)
outdir

152 152


'raw/vocab_size_10/seq_len_3/data_size_152'

In [15]:
# save_txt(os.path.join(outdir, 'x.txt'), xs)
# save_txt(os.path.join(outdir, 'y.txt'), ys)