In [18]:
%matplotlib inline
# MXNet supplies the symbolic API (mx.sym), RNN cells (mx.rnn) and the Module
# trainer used by the cells below.
import mxnet as mx
# Handle to the first CPU device.
# NOTE(review): no visible cell passes `ctx` to the Module -- confirm it is still needed.
ctx=mx.cpu(0)
import logging
# Lower the root logger threshold to DEBUG so the INFO progress lines emitted
# during training are displayed.
logging.getLogger().setLevel(logging.DEBUG)
# Star import kept last. The later cells use vocab_size, max_string_len,
# generate_chars_train_eval_sets, generate_chars_iterators, ints2text and
# onehot2int without defining them; presumably they all come from here --
# TODO confirm and switch to explicit imports.
from char_utils import *

In [25]:
# Model dimensions: embedding width and GRU state width (per direction).
embed_size, num_hidden = 64, 64

# Data pipeline: number of generated samples and mini-batch size.
dataset_size, batch_size = 10000, 100

In [26]:
# Generate the training and evaluation splits. Each input set is paired with an
# "inverse" set that is used as the label below (presumably the transformed
# target string -- see char_utils for the exact meaning).
(train_set, inverse_train_set,
 eval_set, inverse_eval_set) = generate_chars_train_eval_sets(dataset_size=dataset_size)

# Wrap each split in a batched iterator; the training iterator is built first.
train_iter = generate_chars_iterators(
    train_set=train_set,
    label_set=inverse_train_set,
    batch_size=batch_size
)
eval_iter = generate_chars_iterators(
    train_set=eval_set,
    label_set=inverse_eval_set,
    batch_size=batch_size
)

In [27]:
# ---------------------------------------------------------------------------
# Network inputs.
# ---------------------------------------------------------------------------
data = mx.sym.Variable('data')
label = mx.sym.Variable('softmax_label')

# ---------------------------------------------------------------------------
# Encoder: token embedding followed by a bidirectional GRU.
# ---------------------------------------------------------------------------
embed = mx.sym.Embedding(data=data, input_dim=vocab_size, output_dim=embed_size)

bi_cell = mx.rnn.BidirectionalCell(
    mx.rnn.GRUCell(num_hidden=num_hidden, prefix="gru1_"),
    mx.rnn.GRUCell(num_hidden=num_hidden, prefix="gru2_"),
    output_prefix="bi_"
)

# NOTE(review): the per-step outputs of the unroll are discarded below and only
# the states are used, so the residual wrapper may have no effect on the final
# graph -- confirm before relying on it.
encoder = mx.rnn.ResidualCell(bi_cell)

# Unroll across the whole input string; keep only the returned states.
_, encoder_state = encoder.unroll(length=max_string_len, inputs=embed, merge_outputs=False)

# Join the first state of each of the two state groups into a single vector
# (presumably the forward and backward GRU final states -- TODO confirm).
encoder_state = mx.sym.concat(encoder_state[0][0], encoder_state[1][0])

# ---------------------------------------------------------------------------
# Decoder: a single GRU whose state width equals the concatenated encoder state.
# ---------------------------------------------------------------------------
decoder = mx.rnn.GRUCell(num_hidden=num_hidden * 2)

# NOTE(review): the unroll length is num_hidden*2 rather than max_string_len,
# i.e. the encoder state vector itself is fed in as the input sequence --
# confirm this is intentional.
rnn_output, decoder_state = decoder.unroll(
    length=num_hidden * 2, inputs=encoder_state, merge_outputs=True
)

# ---------------------------------------------------------------------------
# Output head: dense projection, ReLU, then reshape to one score per
# (position, vocabulary entry) and regress against the label.
# ---------------------------------------------------------------------------
flat = mx.sym.Flatten(data=rnn_output)
fc = mx.sym.FullyConnected(data=flat, num_hidden=max_string_len * vocab_size)
act = mx.sym.Activation(data=fc, act_type='relu')

# The leading 0 in the target shape leaves the first (batch) dimension as-is.
out = mx.sym.Reshape(data=act, shape=(0, max_string_len, vocab_size))

net = mx.sym.LinearRegressionOutput(data=out, label=label)

In [28]:
# Wrap the symbol in a Module and train it on the generated iterators.
model = mx.module.Module(net)
model.fit(
    train_data=train_iter,
    eval_data=eval_iter,
    # NOTE(review): 'acc' is applied to a regression output here; the high
    # value reported from the very first batches suggests it scores individual
    # output elements rather than whole strings -- confirm it measures what is
    # intended.
    eval_metric='acc',
    # Float division on purpose: under Python 2, `1/batch_size` is integer
    # division and evaluates to 0 for any batch_size > 1, which would scale
    # every gradient to zero. `1.0/batch_size` is correct on Python 2 and 3.
    optimizer=mx.optimizer.Adam(rescale_grad=1.0/batch_size),
    initializer=mx.initializer.Xavier(),
    # Log throughput and the running metric every 10 batches.
    batch_end_callback=mx.callback.Speedometer(batch_size, 10),
    num_epoch=8
)

INFO:root:Epoch[0] Batch [10]	Speed: 144.64 samples/sec	accuracy=0.880303
INFO:root:Epoch[0] Batch [20]	Speed: 153.50 samples/sec	accuracy=0.882650
INFO:root:Epoch[0] Batch [30]	Speed: 153.83 samples/sec	accuracy=0.881433
INFO:root:Epoch[0] Batch [40]	Speed: 144.31 samples/sec	accuracy=0.881317
INFO:root:Epoch[0] Batch [50]	Speed: 146.15 samples/sec	accuracy=0.882742
INFO:root:Epoch[0] Batch [60]	Speed: 141.91 samples/sec	accuracy=0.883058
INFO:root:Epoch[0] Batch [70]	Speed: 144.96 samples/sec	accuracy=0.884042
INFO:root:Epoch[0] Batch [80]	Speed: 158.92 samples/sec	accuracy=0.884625
INFO:root:Epoch[0] Batch [90]	Speed: 159.08 samples/sec	accuracy=0.886483
INFO:root:Epoch[0] Train-accuracy=0.885815
INFO:root:Epoch[0] Time cost=66.623
INFO:root:Epoch[0] Validation-accuracy=0.886467
INFO:root:Epoch[1] Batch [10]	Speed: 160.80 samples/sec	accuracy=0.887818
INFO:root:Epoch[1] Batch [20]	Speed: 163.24 samples/sec	accuracy=0.889400
INFO:root:Epoch[1] Batch [30]	Speed: 162.34 samples/sec	acc

In [31]:
import difflib

# Evaluate the trained model on a freshly generated set and report every
# prediction that does not decode to the expected string.
testset_size=100

# Only the first pair returned by the generator is used as the test set.
test_set, inverse_test_set, _, _ = generate_chars_train_eval_sets(dataset_size=testset_size)

# batch_size=1 so that each prediction row lines up with one test sample.
test_iter = generate_chars_iterators(train_set=test_set, label_set=inverse_test_set, batch_size=1)

predictions = model.predict(test_iter)

match_count = 0
for i, pred in enumerate(predictions):
    # Decode both strings once; round the regression output before mapping it
    # back to characters.
    inverse = ints2text(inverse_test_set[i])
    inverse_pred = ints2text(onehot2int(mx.ndarray.round(pred)))
    matched = inverse_pred == inverse
    if matched:
        match_count += 1
        continue

    # Mismatch report: sample index, expected text, predicted text, flag.
    print(i)
    print(inverse)
    print(inverse_pred)
    print(matched)
    # A dedicated index name is used so the outer sample index `i` is not
    # rebound. The reported position is the index of the entry within the
    # ndiff stream, not an offset into either string.
    for diff_pos, entry in enumerate(difflib.ndiff(inverse, inverse_pred)):
        code = entry[0]
        if code == ' ':
            continue
        elif code == '-':
            print(u'Delete "{}" from position {}'.format(entry[-1], diff_pos))
        elif code == '+':
            print(u'Add "{}" to position {}'.format(entry[-1], diff_pos))
    print("--------------------")

print("Matched %d/%d times" % (match_count,testset_size))
    

4
#edfbbabdebdcd§
#edfbbabdebdad§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
5
#dedcdcfcacecc§
#dedcdcfcaceac§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
13
#fbbcbeedbeacc§
#fbbcbeedbeaac§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
26
#eaacafbfecacc§
#eaacafbfecaac§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
41
#afefbabbecdcd§
#afefbabbecdad§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
55
#eebadeffabacf§
#eebadeffabaaf§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
59
#eaafcabeecaca§
#eaafcabeecaaa§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
60
#accaedabdcfcf§
#accaedabdcfaf§
False
Delete "c" from position 12
Add "a" to position 13
--------------------
62
#ffceffdebdeca§
#ffceffdebdeaa§
False
Delete "c" from position 12
Add "a" to position 13
------