In [156]:
%matplotlib inline
from random import choice, randrange
import mxnet as mx
import numpy as np
ctx=mx.cpu(0)
import logging
logging.getLogger().setLevel(logging.DEBUG)

In [231]:
vocabulary=list("abcdef")
EOS='§'
SOS='#'
vocabulary.append(EOS)
vocabulary.append(SOS)
vocab_size=len(vocabulary)
MAX_STRING_LEN = 15
MAX_INPUST_LEN = 100
num_hidden=64
embed_size=64
batch_size=100
int2char = {i:c for i,c in enumerate(vocabulary)}
char2int = {c:i for i,c in enumerate(vocabulary)}

In [232]:
print(int2char)
print("vocab size: "+str(vocab_size))

{0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: '§', 7: '#'}
vocab size: 8


In [233]:
def generate_strings(min_len, max_len):
    random_length = randrange(min_len, max_len)
    random_char_list = [choice(vocabulary[:-2]) for _ in range(random_length)]
    random_string = ''.join(random_char_list) 
    return SOS+random_string+EOS

In [234]:
def text2ints(string):
    return [char2int[char] for char in string]

def ints2text(numbers):
    return ''.join([int2char[num] for num in numbers])

def int2onehot(numbers):
    return mx.nd.one_hot(mx.nd.array(numbers),vocab_size)

def onehot2int(matrix):
    fin=[]
    for vec in matrix:
        fin.append(int(vec.argmax(axis=0).asnumpy().tolist()[0]))
    return fin

In [235]:
string=generate_strings(MAX_STRING_LEN-2, MAX_STRING_LEN-1)
print(string, len(string))
assert ints2text(text2ints(string)) == string

#aeeadbbdfabec§ 15


In [236]:
set_size=10000
train_set = [text2ints(generate_strings(MAX_STRING_LEN-2, MAX_STRING_LEN-1)) for _ in range(set_size)]
inverse_train_set = [[char2int[SOS]]+sentence[1:-1][::-1]+[char2int[EOS]] for sentence in train_set]
eval_set = [text2ints(generate_strings(MAX_STRING_LEN-2, MAX_STRING_LEN-1)) for _ in range(set_size//10)]
inverse_eval_set = [[char2int[SOS]]+sentence[1:-1][::-1]+[char2int[EOS]] for sentence in eval_set]

#print(train_set[0])
#print(inverse_train_set[0])

In [237]:
train_iter = mx.io.NDArrayIter(
    data=mx.nd.one_hot(mx.nd.array(train_set),vocab_size),
    label=mx.nd.one_hot(mx.nd.array(inverse_train_set),vocab_size),
    batch_size=batch_size
)
eval_iter = mx.io.NDArrayIter(
    data=mx.nd.one_hot(mx.nd.array(eval_set),vocab_size),
    label=mx.nd.one_hot(mx.nd.array(inverse_eval_set),vocab_size),
    batch_size=batch_size
)

In [238]:
#def sym_gen(seq_len):
    

In [239]:
data = mx.sym.Variable('data')
label = mx.sym.Variable('softmax_label')

embed = mx.sym.Embedding(
    data=data,
    input_dim=vocab_size,
    output_dim=embed_size
)

In [240]:
bi_cell = mx.rnn.BidirectionalCell(
    mx.rnn.GRUCell(num_hidden=num_hidden, prefix="gru1_"),
    mx.rnn.GRUCell(num_hidden=num_hidden, prefix="gru2_"),
    output_prefix="bi_"
)

encoder = mx.rnn.ResidualCell(bi_cell)
        
_, encoder_state = encoder.unroll(
    length=MAX_STRING_LEN,
    inputs=embed,
    merge_outputs=False
)

encoder_state = mx.sym.concat(encoder_state[0][0],encoder_state[1][0])

decoder = mx.rnn.GRUCell(num_hidden=num_hidden*2)

rnn_output, decoder_state = decoder.unroll(
    length=num_hidden*2,
    inputs=encoder_state,
    merge_outputs=True
)

In [241]:
flat=mx.sym.Flatten(data=rnn_output)

fc=mx.sym.FullyConnected(
    data=flat,
    num_hidden=MAX_STRING_LEN*vocab_size
)
#drop=mx.sym.Dropout(data=fc, p=0.5)
act=mx.sym.Activation(data=fc, act_type='relu')


out = mx.sym.Reshape(data=act, shape=((0,MAX_STRING_LEN,vocab_size)))

#out=mx.sym.round(out)

In [242]:
net = mx.sym.LinearRegressionOutput(data=out, label=label)

In [243]:
model = mx.module.Module(net)
model.fit(
    train_data=train_iter,
    eval_data=eval_iter,
    eval_metric = 'acc',
    optimizer=mx.optimizer.Adam(rescale_grad=1/batch_size),
    #optimizer_params={'learning_rate':0.001, 'momentum':0.9},
    initializer=mx.initializer.Xavier(),
    batch_end_callback=mx.callback.Speedometer(batch_size, 10),
    num_epoch=8
)

INFO:root:Epoch[0] Batch [10]	Speed: 135.35 samples/sec	accuracy=0.881182
INFO:root:Epoch[0] Batch [20]	Speed: 122.81 samples/sec	accuracy=0.881942
INFO:root:Epoch[0] Batch [30]	Speed: 135.84 samples/sec	accuracy=0.878783
INFO:root:Epoch[0] Batch [40]	Speed: 159.03 samples/sec	accuracy=0.882467
INFO:root:Epoch[0] Batch [50]	Speed: 172.17 samples/sec	accuracy=0.881692
INFO:root:Epoch[0] Batch [60]	Speed: 170.01 samples/sec	accuracy=0.884733
INFO:root:Epoch[0] Batch [70]	Speed: 150.54 samples/sec	accuracy=0.884958
INFO:root:Epoch[0] Batch [80]	Speed: 104.80 samples/sec	accuracy=0.886033
INFO:root:Epoch[0] Batch [90]	Speed: 103.08 samples/sec	accuracy=0.887933
INFO:root:Epoch[0] Train-accuracy=0.887963
INFO:root:Epoch[0] Time cost=75.733
INFO:root:Epoch[0] Validation-accuracy=0.893317
INFO:root:Epoch[1] Batch [10]	Speed: 104.26 samples/sec	accuracy=0.892386
INFO:root:Epoch[1] Batch [20]	Speed: 104.51 samples/sec	accuracy=0.891442
INFO:root:Epoch[1] Batch [30]	Speed: 102.26 samples/sec	acc

In [244]:
import difflib

items=50

test_set = [text2ints(generate_strings(MAX_STRING_LEN-2, MAX_STRING_LEN-1)) for _ in range(items)]

#print(test_set[0])
#print(ints2text(test_set[0]))

inverse_test_set = [[char2int[SOS]]+sentence[1:-1][::-1]+[char2int[EOS]] for sentence in test_set]    

#print(inverse_test_set[0])
#print(ints2text(inverse_test_set[0]))
    
#print(mx.nd.one_hot(mx.nd.array(test_set[0]),vocab_size))
#print(mx.nd.one_hot(mx.nd.array(inverse_test_set[0]),vocab_size))

test_iter = mx.io.NDArrayIter(
    data=mx.nd.one_hot(mx.nd.array(test_set),vocab_size),
    label=mx.nd.one_hot(mx.nd.array(inverse_test_set),vocab_size),
    batch_size=1
)

#test_iter.reset()
predictions=model.predict(test_iter)

match_count=0
for i,pred in enumerate(predictions):
    matched = ints2text(onehot2int(mx.ndarray.round(predictions[i]))) == ints2text(inverse_test_set[i])
    if matched:
        match_count+=1
    else:       
        print(i)
        inverse=ints2text(inverse_test_set[i])
        print(inverse)
        inverse_pred=ints2text(onehot2int(mx.ndarray.round(predictions[i])))
        print(inverse_pred)
        print(matched)
        for i,s in enumerate(difflib.ndiff(inverse, inverse_pred)):
            if s[0]==' ': continue
            elif s[0]=='-':
                print(u'Delete "{}" from position {}'.format(s[-1],i))
            elif s[0]=='+':
                print(u'Add "{}" to position {}'.format(s[-1],i))    
        print("--------------------")

print("Matched %d/%d times" % (match_count,items))
    

Matched 50/50 times
