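# Data Preprocessing

Builds the `rec` (recurrent inference) dataset from the `aec` splits. Each training pair is turned into a sampled source state plus its next edit action (`<sub>`/`<insert>`/`<delete>` + `<pos_k>` + token, or `<done>`). Source and target vocabularies are then built, and the processed splits and vocabularies are saved to `data.json` and `vocab.json`.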

In [1]:
#!/usr/bin/env python
# -*- coding:utf-8 -*-

__author__ = 'Shining'
__email__ = 'mrshininnnnn@gmail.com'

In [2]:
# dependency
# public
import os
import numpy as np
import Levenshtein
from collections import Counter
# private
from utils import *

In [3]:
# define parameters
method = 'rec'
num_size = 10
seq_len = 5
data_size = 10000
# Fix numpy's global RNG so the sampled training pairs are reproducible across runs.
# The reference gen_rec_pair in this notebook picks its step with np.random.choice.
# NOTE(review): assumes the utils implementation also draws from np.random — confirm.
seed = 0
np.random.seed(seed)

In [9]:
# load path: aec/num_size_<n>/seq_len_<n>/data_size_<n> for the current parameters
subdirs = ['{}_{}'.format(name, value)
           for name, value in (('num_size', num_size),
                               ('seq_len', seq_len),
                               ('data_size', data_size))]
indir = os.path.join('aec', *subdirs)
indir

'aec/num_size_10/seq_len_5/data_size_10000'

In [10]:
# save path: same num_size/seq_len/data_size layout as the input, rooted at the method name
outdir = os.path.join(method,
                      'num_size_{}'.format(num_size),
                      'seq_len_{}'.format(seq_len),
                      'data_size_{}'.format(data_size))
# exist_ok=True replaces the exists()/makedirs() pair, which could race and raise FileExistsError
os.makedirs(outdir, exist_ok=True)
outdir

'rec/num_size_10/seq_len_5/data_size_10000'

In [11]:
# load raw dataset: one text file per split (train/val/test) and side (x = source, y = target)
def _raw_path(split, side):
    """Path of the raw '<split>_<side>.txt' file inside indir."""
    return os.path.join(indir, '{}_{}.txt'.format(split, side))

raw_train_xs = load_txt(_raw_path('train', 'x'))
raw_train_ys = load_txt(_raw_path('train', 'y'))
raw_val_xs = load_txt(_raw_path('val', 'x'))
raw_val_ys = load_txt(_raw_path('val', 'y'))
raw_test_xs = load_txt(_raw_path('test', 'x'))
raw_test_ys = load_txt(_raw_path('test', 'y'))

In [12]:
# check data size: every split must have exactly one label per sample.
# Train pairs are built with zip(), which would silently drop an unmatched tail,
# and val/test samples are paired with their labels by position.
raw_splits = {
    'train': (raw_train_xs, raw_train_ys),
    'val': (raw_val_xs, raw_val_ys),
    'test': (raw_test_xs, raw_test_ys),
}
for split_name, (split_xs, split_ys) in raw_splits.items():
    print(split_name, 'sample size', len(split_xs))
    print(split_name, 'label size', len(split_ys))
    assert len(split_xs) == len(split_ys), \
        '{} split has {} samples but {} labels'.format(split_name, len(split_xs), len(split_ys))

train sample size 7000
train label size 7000
val sample size 1500
val label size 1500
test sample size 1500
test label size 1500


### Helper Functions

In [13]:
# def levenshtein_editops_list(source, target):
#     unique_elements = sorted(set(source + target)) 
#     char_list = [chr(i) for i in range(len(unique_elements))]
#     if len(unique_elements) > len(char_list):
#         raise Exception("too many elements")
#     else:
#         unique_element_map = {ele:char_list[i]  for i, ele in enumerate(unique_elements)}
#     source_str = ''.join([unique_element_map[ele] for ele in source])
#     target_str = ''.join([unique_element_map[ele] for ele in target])
#     transform_list = Levenshtein.editops(source_str, target_str)
#     return transform_list

# def gen_rec_pair(x: list, y: list) -> list:
#     # white space tokenization
#     x = x.split()
#     y = y.split()
#     xs = [x.copy()]
#     ys_ = []
#     editops = levenshtein_editops_list(x, y)
#     c = 0 
#     for tag, i, j in editops: 
#         i += c
#         if tag == 'replace':
#             y_ = ['<sub>', '<pos_{}>'.format(i), y[j]]
#             x[i] = y[j]
#         elif tag == 'delete': 
#             y_ = ['<delete>', '<pos_{}>'.format(i), '<done>']
#             del x[i]
#             c -= 1
#         elif tag == 'insert': 
#             y_ = ['<insert>', '<pos_{}>'.format(i), y[j]]
#             x.insert(i, y[j]) 
#             c += 1
#         xs.append(x.copy()) 
#         ys_.append(y_)
#     ys_.append(['<done>']*3)
#     index = np.random.choice(range(len(xs)))
#     x = xs[index]
#     y_ = ys_[index]
#     return x, y_, y

### Train

In [14]:
# gen_rec_pair yields one (src, pred, tgt) triple per raw training pair; unzip into parallel tuples
rec_pairs = [gen_rec_pair(src, tgt) for src, tgt in zip(raw_train_xs, raw_train_ys)]
train_xs, train_ys_, train_ys = zip(*rec_pairs)

In [15]:
# take a look at the last ten generated training examples (oldest of the ten first)
for back in range(10, 0, -1):
    k = -back
    print('src:', train_xs[k])
    print('tgt:', train_ys[k])
    print('pred:', train_ys_[k])
    print()

src: ['5', '-', '7', '+', '5', '*', '4', '==', '8']
tgt: ['5', '-', '7', '+', '5', '*', '2', '==', '8']
pred: ['<sub>', '<pos_6>', '2']

src: ['-', '+', '+', '5', '+', '11', '+', '+', '2', '8']
tgt: ['-', '10', '+', '5', '+', '11', '+', '2', '==', '8']
pred: ['<sub>', '<pos_1>', '10']

src: ['-', '10', '+', '6', '*', '4', '-', '3', '==', '11']
tgt: ['-', '10', '+', '6', '*', '4', '-', '3', '==', '11']
pred: ['<done>', '<done>', '<done>']

src: ['8', '+', '11', '-', '2', '-', '8', '==', '9']
tgt: ['8', '+', '11', '-', '2', '-', '8', '==', '9']
pred: ['<done>', '<done>', '<done>']

src: ['-', '2', '*', '2', '+', '9', '+', '4', '==', '9']
tgt: ['-', '2', '*', '2', '+', '9', '+', '4', '==', '9']
pred: ['<done>', '<done>', '<done>']

src: ['-', '4', '+', '7', '+', '7', '/', '9', '==', '4']
tgt: ['-', '4', '+', '7', '+', '9', '/', '9', '==', '4']
pred: ['<sub>', '<pos_5>', '9']

src: ['2', '+', '6', '-', '6', '6', '+', '11', '==', '7']
tgt: ['2', '+', '6', '-', '6', '+', '5', '==', '7']
pred

In [19]:
# fraction of training examples whose target action is the terminal '<done>' triple
train_ys_.count(['<done>'] * 3) / len(train_ys_)

0.4705714285714286

In [17]:
# source vocabulary frequency distribution over every token of every training source
counter = Counter(token for src in train_xs for token in src)

print(len(counter))
print(counter.most_common())

15
[('+', 8658), ('-', 8606), ('==', 6227), ('2', 4895), ('3', 4257), ('4', 4047), ('6', 3836), ('5', 3692), ('8', 3525), ('7', 3140), ('10', 3091), ('*', 3077), ('9', 3042), ('/', 2775), ('11', 2672)]


In [23]:
# source tokens in string order (so '10' sorts before '2')
src_vocab_list = sorted(counter)
print(src_vocab_list)

['*', '+', '-', '/', '10', '11', '2', '3', '4', '5', '6', '7', '8', '9', '==']


In [24]:
# source vocabulary dictionary: '<pad>' takes index 0, then each token in sorted order
src_vocab2idx_dict = {'<pad>': 0}  # to pad sequence length
src_vocab2idx_dict.update(
    (token, idx)
    for idx, token in enumerate(src_vocab_list, start=len(src_vocab2idx_dict)))

print(src_vocab2idx_dict)

{'<pad>': 0, '*': 1, '+': 2, '-': 3, '/': 4, '10': 5, '11': 6, '2': 7, '3': 8, '4': 9, '5': 10, '6': 11, '7': 12, '8': 13, '9': 14, '==': 15}


In [25]:
# target vocabulary frequency distribution
counter = Counter()
for y_ in train_ys_:
    counter.update(y_)

print(len(counter))
print(counter.most_common())

28
[('<done>', 10888), ('<sub>', 1570), ('<delete>', 1096), ('<insert>', 1070), ('<pos_1>', 509), ('<pos_4>', 479), ('<pos_3>', 459), ('<pos_7>', 457), ('<pos_6>', 442), ('<pos_2>', 426), ('<pos_5>', 424), ('<pos_0>', 415), ('-', 408), ('+', 407), ('==', 287), ('2', 162), ('4', 150), ('3', 148), ('/', 145), ('*', 140), ('6', 138), ('<pos_8>', 125), ('8', 121), ('5', 115), ('9', 115), ('10', 111), ('7', 103), ('11', 90)]


In [32]:
# target vocabulary: observed action tokens plus every position tag <pos_0> .. <pos_{2*seq_len-1}>,
# so positions never seen in training still get an index; sorted as strings
position_tokens = {'<pos_{}>'.format(p) for p in range(seq_len * 2)}
tgt_vocab_list = sorted(set(counter) | position_tokens)
print(tgt_vocab_list)

['*', '+', '-', '/', '10', '11', '2', '3', '4', '5', '6', '7', '8', '9', '<delete>', '<done>', '<insert>', '<pos_0>', '<pos_1>', '<pos_2>', '<pos_3>', '<pos_4>', '<pos_5>', '<pos_6>', '<pos_7>', '<pos_8>', '<pos_9>', '<sub>', '==']


In [486]:
# target vocabulary dictionary: two special tokens first, then the sorted target tokens
tgt_vocab2idx_dict = {
    '<pad>': 0,  # to pad sequence length
    '<s>': 1,  # to mark the start of a sequence
}
first_free = len(tgt_vocab2idx_dict)
for offset, token in enumerate(tgt_vocab_list):
    tgt_vocab2idx_dict[token] = first_free + offset

print(tgt_vocab2idx_dict)

{'<pad>': 0, '<s>': 1, '*': 2, '+': 3, '-': 4, '/': 5, '10': 6, '11': 7, '2': 8, '3': 9, '4': 10, '5': 11, '6': 12, '7': 13, '8': 14, '9': 15, '<delete>': 16, '<done>': 17, '<insert>': 18, '<pos_0>': 19, '<pos_1>': 20, '<pos_2>': 21, '<pos_3>': 22, '<pos_4>': 23, '<pos_5>': 24, '<pos_6>': 25, '<pos_7>': 26, '<pos_8>': 27, '<sub>': 28, '==': 29}


### Val

In [487]:
# white space tokenization of the validation split (sources first, then targets)
val_xs, val_ys = map(white_space_tokenizer, (raw_val_xs, raw_val_ys))

In [488]:
# take a look at the last ten validation pairs (oldest of the ten first)
for back in range(10, 0, -1):
    print('src:', val_xs[-back])
    print('tgt:', val_ys[-back])
    print()

src: ['5', '/', '9', '3', '2', '/', '3']
tgt: ['5', '/', '9', '*', '9', '-', '2', '==', '3']

src: ['-', '5', '-', '4', '+', '2', '*', '10', '==', '6']
tgt: ['-', '7', '-', '7', '+', '2', '*', '10', '==', '6']

src: ['-', '+', '-', '9', '+', '6', '6', '9', '==', '2']
tgt: ['10', '-', '9', '+', '6', '/', '6', '==', '2']

src: ['8', '-', '7', '-', '6', '9', '+', '10', '==', '2']
tgt: ['8', '-', '7', '-', '9', '+', '10', '==', '2']

src: ['-', '4', '-', '*', '2', '9', '+', '8', '==', '11']
tgt: ['-', '4', '-', '2', '+', '9', '+', '8', '==', '11']

src: ['-', '8', '*', '+', '2', '+', '7', '==', '4']
tgt: ['-', '8', '+', '3', '+', '2', '+', '7', '==', '4']

src: ['8', '4', '*', '2', '-', '10', '-', '==', '11']
tgt: ['4', '*', '6', '-', '2', '-', '11', '==', '11']

src: ['3', '+', '-', '2', '-', '9', '==', '5']
tgt: ['3', '+', '7', '-', '2', '-', '3', '==', '5']

src: ['10', '8', '2', '*', '7', '-', '6', '==', '10']
tgt: ['10', '-', '2', '+', '7', '-', '5', '==', '10']

src: ['*', '9', '8', 

### Test

In [489]:
# white space tokenization of the test split (sources first, then targets)
test_xs, test_ys = map(white_space_tokenizer, (raw_test_xs, raw_test_ys))

In [490]:
# take a look at the last ten test pairs (oldest of the ten first)
for back in range(10, 0, -1):
    print('src:', test_xs[-back])
    print('tgt:', test_ys[-back])
    print()

src: ['8', '-', '11', '*', '==', '2']
tgt: ['8', '-', '6', '/', '4', '*', '4', '==', '2']

src: ['4', '-', '7', '+', '8', '8', '==', '10']
tgt: ['4', '+', '5', '-', '7', '+', '8', '==', '10']

src: ['-', '-', '8', '3', '+', '6', '4', '10', '==', '8']
tgt: ['-', '5', '-', '3', '+', '6', '+', '10', '==', '8']

src: ['8', '8', '/', '4', '-', '4', '7', '7']
tgt: ['8', '*', '8', '/', '4', '-', '9', '==', '7']

src: ['7', '/', '2', '*', '8', '2', '/', '9', '4']
tgt: ['9', '/', '2', '*', '8', '/', '9', '==', '4']

src: ['3', '*', '5', '10', '-', '9', '==', '7']
tgt: ['3', '*', '2', '+', '10', '-', '9', '==', '7']

src: ['2', '*', '4', '+', '9', '6', '-', '7', '7', '==', '3']
tgt: ['2', '*', '2', '+', '6', '-', '7', '==', '3']

src: ['-', '+', '+', '-', '3', '-', '3', '==', '7']
tgt: ['11', '+', '2', '-', '3', '-', '3', '==', '7']

src: ['4', '-', '3', '+', '+', '2', '9', '2', '+', '2', '3']
tgt: ['-', '3', '+', '2', '+', '2', '+', '2', '==', '3']

src: ['2', '/', '5', '*', '*', '5', '+', '2',

In [491]:
# combine data sets to a dict
# train keeps all three outputs of gen_rec_pair: the sampled source state ('xs'),
# its next edit action ('ys_') and the full target ('ys'). Without 'xs' and 'ys_'
# the training inputs and action labels would never reach data.json.
# NOTE(review): train key names mirror the variable names — confirm against the training loader.
train_dict = {'xs': train_xs, 'ys_': train_ys_, 'ys': train_ys}

val_dict = {'xs': val_xs, 'ys': val_ys}

test_dict = {'xs': test_xs, 'ys': test_ys}

data_dict = {'train': train_dict, 'val': val_dict, 'test': test_dict}

vocab_dict = {'src': src_vocab2idx_dict, 'tgt': tgt_vocab2idx_dict}

In [492]:
# save output as json: split data and both vocabularies are written into outdir
data_path, vocab_path = (os.path.join(outdir, name) for name in ('data.json', 'vocab.json'))

for path, payload in ((data_path, data_dict), (vocab_path, vocab_dict)):
    save_json(path, payload)

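## Archive Code

Exploratory scratch cells kept for reference. They are not needed to produce the files saved above.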

In [437]:
# reload the raw training text and inspect one example (index 4) as whitespace tokens
raw_train_xs, raw_train_ys = (load_txt(os.path.join(indir, fname))
                              for fname in ('train_x.txt', 'train_y.txt'))
x, y = raw_train_xs[4].split(), raw_train_ys[4].split()
print(x)
print(y)
print()
print()

['-', '2', '+', '+', '*', '11', '-', '6', '5', '4', '9']
['-', '2', '+', '5', '+', '11', '-', '5', '==', '9']



In [11]:
import Levenshtein 

def levenshtein_editops_list(source, target):
    """Return the Levenshtein edit operations that turn token list `source` into `target`.

    `Levenshtein.editops` works on strings, so every distinct token in either list
    is first mapped to its own single character (assigned in sorted token order),
    and both lists are encoded as strings of those characters.

    Args:
        source: list of tokens to edit (must support `+` with `target` and sorting).
        target: list of tokens to reach.

    Returns:
        The result of `Levenshtein.editops` on the encoded strings; callers in this
        notebook unpack each entry as (tag, source_pos, target_pos).

    Raises:
        Exception: if there are more distinct tokens than Unicode code points.
    """
    import sys  # local import: only needed for the code-point bound below

    unique_elements = sorted(set(source + target))
    # chr() accepts sys.maxunicode + 1 code points, which is the real encoding capacity
    if len(unique_elements) > sys.maxunicode + 1:
        raise Exception("too many elements")
    unique_element_map = {ele: chr(code) for code, ele in enumerate(unique_elements)}
    source_str = ''.join(unique_element_map[ele] for ele in source)
    target_str = ''.join(unique_element_map[ele] for ele in target)
    return Levenshtein.editops(source_str, target_str)

# sanity check: replaying each sample's edit ops on its source must reproduce its target
for sample_idx in range(len(raw_train_xs)):
    src = raw_train_xs[sample_idx].split()
    tgt = raw_train_ys[sample_idx].split()

    editops = levenshtein_editops_list(src, tgt)
    # earlier inserts/deletes move later elements, so positions are corrected by a running shift
    shift = 0
    for tag, pos, tgt_pos in editops:
        pos += shift
        if tag == 'replace':
            src[pos] = tgt[tgt_pos]
        elif tag == 'delete':
            del src[pos]
            shift -= 1
        elif tag == 'insert':
            src.insert(pos, tgt[tgt_pos])
            shift += 1
    if src != tgt:
        # distinct loop names ensure this reports the sample index, not an edit position
        print(sample_idx)
        print('src', src)
        print('tgt', tgt)
        break
else:
    print('all {} training pairs reproduced'.format(len(raw_train_xs)))

In [33]:
# inspect training sample 4 again as whitespace tokens
x, y = (line.split() for line in (raw_train_xs[4], raw_train_ys[4]))
for tokens in (x, y):
    print(tokens)
print()

['-', '2', '+', '+', '*', '11', '-', '6', '5', '4', '9']
['-', '2', '+', '5', '+', '11', '-', '5', '==', '9']



In [445]:
# for online end2end: replay every edit op on a COPY of x and record each intermediate state.
# x and y are left untouched, so later cells still see the original sample.
state = x.copy()
xs = [state.copy()]
editops = levenshtein_editops_list(x, y)
# earlier inserts/deletes move later elements, so positions are corrected by a running shift
shift = 0
for tag, i, j in editops:
    i += shift
    if tag == 'replace':
        state[i] = y[j]
    elif tag == 'delete':
        del state[i]
        shift -= 1
    elif tag == 'insert':
        state.insert(i, y[j])
        shift += 1
    xs.append(state.copy())

for step in xs:
    print(step)

['-', '2', '+', '+', '*', '11', '-', '6', '5', '4', '9']
['-', '2', '+', '5', '*', '11', '-', '6', '5', '4', '9']
['-', '2', '+', '5', '+', '11', '-', '6', '5', '4', '9']
['-', '2', '+', '5', '+', '11', '-', '5', '4', '9']
['-', '2', '+', '5', '+', '11', '-', '5', '==', '9']


In [34]:
# for offline recurrent inference: encode only the FIRST edit op as an
# (<action>, <pos_k>, token) triple; deletes put '<done>' in the token slot
editops = levenshtein_editops_list(x, y)
if not editops:
    # x already equals y: there is no edit to take, so emit the all-'<done>'
    # triple that this notebook uses as the terminal action
    y_ = ['<done>'] * 3
else:
    tag, i, j = editops[0]
    if tag == 'replace':
        y_ = ['<sub>', '<pos_{}>'.format(i), y[j]]
    elif tag == 'delete':
        y_ = ['<delete>', '<pos_{}>'.format(i), '<done>']
    elif tag == 'insert':
        y_ = ['<insert>', '<pos_{}>'.format(i), y[j]]
print(y)
print(x)
print(y_)

['-', '2', '+', '5', '+', '11', '-', '5', '==', '9']
['-', '2', '+', '+', '*', '11', '-', '6', '5', '4', '9']
['<sub>', '<pos_3>', '5']


In [455]:
# for online recurrent inference: replay every edit op on a COPY of x, recording each
# intermediate state (xs) and the action taken from it (ys_). x and y are left untouched.
state = x.copy()
xs = [state.copy()]
ys_ = []
editops = levenshtein_editops_list(x, y)
# earlier inserts/deletes move later elements, so positions are corrected by a running shift
shift = 0
for tag, i, j in editops:
    i += shift
    if tag == 'replace':
        y_ = ['<sub>', '<pos_{}>'.format(i), y[j]]
        state[i] = y[j]
    elif tag == 'delete':
        y_ = ['<delete>', '<pos_{}>'.format(i), '<done>']
        del state[i]
        shift -= 1
    elif tag == 'insert':
        y_ = ['<insert>', '<pos_{}>'.format(i), y[j]]
        state.insert(i, y[j])
        shift += 1

    xs.append(state.copy())
    ys_.append(y_)

# after replaying every op the state should equal y, so its action is the terminal '<done>' triple
ys_.append(['<done>'] * 3)

for step_x, step_y_ in zip(xs, ys_):
    print(step_x)
    print(step_y_)

['-', '2', '+', '+', '*', '11', '-', '6', '5', '4', '9']
['<sub>', '<pos_3>', '5']
['-', '2', '+', '5', '*', '11', '-', '6', '5', '4', '9']
['<sub>', '<pos_4>', '+']
['-', '2', '+', '5', '+', '11', '-', '6', '5', '4', '9']
['<delete>', '<pos_7>', '<done>']
['-', '2', '+', '5', '+', '11', '-', '5', '4', '9']
['<sub>', '<pos_8>', '==']
['-', '2', '+', '5', '+', '11', '-', '5', '==', '9']
['<done>', '<done>', '<done>']
