In [None]:
%matplotlib inline
import mxnet as mx

# Device on which the module is bound and trained: CPU device 0.
ctx = mx.cpu(0)

import logging
# Lower the root logger threshold to DEBUG so messages logged during the run
# are not filtered out.
logging.getLogger().setLevel(logging.DEBUG)

from masked_bucket_io import MaskedBucketSentenceIter
from xutils import read_content, load_vocab, sentence2id


In [None]:
# Model / training hyper-parameters shared by the cells below.
# NOTE: the encoder in s2s_unroll wraps its bidirectional GRU in a ResidualCell,
# which adds the embedded input to the recurrent output, so embed_size has to
# stay equal to num_hidden.
num_hidden = 512    # decoder GRU width; each encoder direction uses num_hidden // 2
embed_size = 512    # output_dim of both the source and the target embedding
batch_size = 100    # samples per batch handed to the bucketing iterator
dataset_size = 10000

In [None]:
# Every combination of lengths 8, 16, ..., 120 on both sides. sym_gen later
# reads each key as (source_len, target_len).
bucket_stride = 8
bucket_lengths = range(8, 128, bucket_stride)
buckets = [
    (source_len, target_len)
    for source_len in bucket_lengths
    for target_len in bucket_lengths
]

# Reserved tokens and the ids handed to load_vocab for them.
bos_word = '<s>'
eos_word = '</s>'
unk_word = '<unk>'
special_words = {unk_word: 1, bos_word: 2, eos_word: 3}

# Corpus names; the vocabulary file of each side is "<name>.pkl".
train_source = 'english'
train_target = 'italian'
source_vocab_path = train_source + ".pkl"
target_vocab_path = train_target + ".pkl"

# word -> id vocabularies plus the reverse id -> word tables used for printing.
# Id 0 is mapped to the empty string in the reverse tables
# (presumably the padding id -- TODO confirm against the iterator).
source_vocab = load_vocab(source_vocab_path, special_words)
inverted_source_vocab = dict((source_vocab[token], token) for token in source_vocab)
inverted_source_vocab[0] = ''

target_vocab = load_vocab(target_vocab_path, special_words)
inverted_target_vocab = dict((target_vocab[token], token) for token in target_vocab)
inverted_target_vocab[0] = ''

In [None]:
# Bucketing iterator over the parallel corpus: corpus names, both vocabularies,
# the bucket list and the batch size go in positionally; the tokenising and
# file-reading helpers from xutils are plugged in by keyword.
data_train = MaskedBucketSentenceIter(
    train_source, train_target,
    source_vocab, target_vocab,
    buckets,
    batch_size,
    text2id=sentence2id,
    read_content=read_content,
    max_read_sample=100000,
)

In [None]:
# Sanity check: pull the first batch and print every sample of it as words.
data_train.reset()
item = data_train.next()
print(item)

for sample in range(batch_size):
    # item.data[0] is decoded with the source vocabulary.
    print("source")
    print([inverted_source_vocab[int(tok.asnumpy())] for tok in item.data[0][sample]])
    # item.data[1] is printed as raw integers.
    print("mask")
    print([int(flag.asnumpy()) for flag in item.data[1][sample]])
    # item.data[2] and item.label[0] are decoded with the target vocabulary.
    print("target")
    print([inverted_target_vocab[int(tok.asnumpy())] for tok in item.data[2][sample]])
    print("label")
    print([inverted_target_vocab[int(tok.asnumpy())] for tok in item.label[0][sample]])
    print()

# Rewind so the training cell starts from the first batch again.
data_train.reset()

In [None]:
def s2s_unroll(
    source_len,
    target_len,
    input_names,
    output_names,
    source_vocab_size,
    target_vocab_size,
    **kwargs):
    """Build the unrolled encoder-decoder symbol for one bucket.

    The encoder is a bidirectional GRU (``num_hidden // 2`` units per
    direction) wrapped in a ``ResidualCell``; the final states of both
    directions are concatenated and used as the initial state of a GRU
    decoder that reads the embedded ``target`` sequence. Each decoder step is
    projected onto the target vocabulary and fed to a softmax.

    The layer widths come from the notebook-level ``embed_size`` and
    ``num_hidden`` variables.

    Parameters
    ----------
    source_len : int
        Number of encoder time steps to unroll.
    target_len : int
        Number of decoder time steps to unroll.
    input_names : list of str
        Returned unchanged as the second element of the result.
    output_names : list of str
        Returned unchanged as the third element of the result.
    source_vocab_size : int
        ``input_dim`` of the source embedding.
    target_vocab_size : int
        ``input_dim`` of the target embedding and width of the output
        projection.
    **kwargs
        Accepted for call compatibility and ignored.

    Returns
    -------
    tuple
        ``(net, input_names, output_names)`` where ``net`` is the
        ``SoftmaxOutput`` symbol with output shape
        ``(batch, target_len, target_vocab_size)``.
    """
    source = mx.sym.Variable('source')
    target = mx.sym.Variable('target')
    label = mx.sym.Variable('target_softmax_label')

    # Embedding weights are explicit variables so their names stay the same
    # for every bucket.
    source_embed = mx.sym.Embedding(
        data=source,
        input_dim=source_vocab_size,
        output_dim=embed_size,
        weight=mx.sym.Variable('source_embed_weight')
    )

    target_embed = mx.sym.Embedding(
        data=target,
        input_dim=target_vocab_size,
        output_dim=embed_size,
        weight=mx.sym.Variable('target_embed_weight')
    )

    bi_cell = mx.rnn.BidirectionalCell(
        mx.rnn.GRUCell(num_hidden=num_hidden//2, prefix="gru1_"),
        mx.rnn.GRUCell(num_hidden=num_hidden//2, prefix="gru2_"),
        output_prefix="bi_"
    )

    encoder = mx.rnn.ResidualCell(bi_cell)

    # Only the final states are needed; the per-step encoder outputs are dropped.
    _, encoder_state = encoder.unroll(
        length=source_len,
        inputs=source_embed,
        merge_outputs=False
    )

    # One state entry per direction; join them into a single num_hidden-wide
    # vector for the decoder.
    encoder_state = mx.sym.concat(encoder_state[0][0], encoder_state[1][0])

    decoder = mx.rnn.GRUCell(num_hidden=num_hidden)

    # begin_state is a list of state symbols (a GRU has exactly one).
    rnn_output, _ = decoder.unroll(
        length=target_len,
        inputs=target_embed,
        begin_state=[encoder_state],
        merge_outputs=True
    )

    # Project every time step separately with one shared, explicitly named
    # layer. The weight shape is then independent of target_len, so the same
    # parameters can be shared by all buckets.
    step_hidden = mx.sym.Reshape(data=rnn_output, shape=(-1, num_hidden))

    fc = mx.sym.FullyConnected(
        data=step_hidden,
        num_hidden=target_vocab_size,
        name='pred'
    )
    # NOTE(review): the ReLU in front of the softmax is kept from the original
    # network; confirm it is intended.
    act = mx.sym.Activation(data=fc, act_type='relu')

    # Back to (batch, target_len, target_vocab_size).
    out = mx.sym.Reshape(data=act, shape=(-1, target_len, target_vocab_size))

    # preserve_shape=True normalises over the last (vocabulary) axis for each
    # time step, matching a label of shape (batch, target_len).
    net = mx.sym.SoftmaxOutput(
        data=out,
        label=label,
        preserve_shape=True,
        name='target_softmax'
    )

    return net, input_names, output_names


In [None]:
def sym_gen(source_vocab_size, target_vocab_size):
    """Return a per-bucket symbol generator bound to the two vocabulary sizes.

    The returned callable takes a bucket key whose element 0 is the source
    length and element 1 the target length, and returns whatever
    ``s2s_unroll`` returns for that bucket.
    """
    def _bucket_symbol(bucket_key):
        source_len = bucket_key[0]
        target_len = bucket_key[1]
        return s2s_unroll(
            source_len=source_len,
            target_len=target_len,
            input_names=['source', 'target'],
            output_names=['target_softmax_label'],
            source_vocab_size=source_vocab_size,
            target_vocab_size=target_vocab_size,
        )

    return _bucket_symbol

In [None]:
# One symbol per bucket, generated on demand. Both vocabulary sizes are passed
# as len(vocab) + 1 (presumably to leave room for id 0 -- TODO confirm).
symbol_generator = sym_gen(len(source_vocab) + 1, len(target_vocab) + 1)

model = mx.mod.BucketingModule(
    sym_gen=symbol_generator,
    default_bucket_key=data_train.default_bucket_key,
    context=ctx,
)

In [None]:
# Train for one epoch.
# - The network output is (batch, target_len, vocab), so accuracy has to take
#   the argmax over axis 2; the default 'acc' metric uses axis 1, which is the
#   time axis here.
# - 1.0 / batch_size keeps the gradient rescale a float even where `/` is
#   integer division (Python 2), where 1/batch_size would be 0.
model.fit(
    train_data=data_train,
    eval_metric=mx.metric.Accuracy(axis=2),
    optimizer=mx.optimizer.Adam(rescale_grad=1.0 / batch_size),
    initializer=mx.initializer.Xavier(),
    batch_end_callback=mx.callback.Speedometer(batch_size, 10),
    num_epoch=1
)

In [None]:
import difflib

# NOTE(review): generate_chars_train_eval_sets, generate_chars_iterators,
# ints2text and onehot2int are not defined or imported in the cells visible
# here; confirm where they come from before running this cell.
testset_size = 100

test_set, inverse_test_set, _, _ = generate_chars_train_eval_sets(dataset_size=testset_size)
test_iter = generate_chars_iterators(train_set=test_set, label_set=inverse_test_set, batch_size=1)
predictions = model.predict(test_iter)

# Compare every decoded prediction with its decoded reference; for each
# mismatch print both strings and a character-level diff.
match_count = 0
for idx, pred in enumerate(predictions):
    matched = ints2text(onehot2int(mx.ndarray.round(predictions[idx]))) == ints2text(inverse_test_set[idx])
    if not matched:
        print(idx)
        inverse = ints2text(inverse_test_set[idx])
        print(inverse)
        inverse_pred = ints2text(onehot2int(mx.ndarray.round(predictions[idx])))
        print(inverse_pred)
        print(matched)
        # ndiff entries start with ' ', '-' or '+'; unchanged characters
        # (' ') are not reported.
        for pos, op in enumerate(difflib.ndiff(inverse, inverse_pred)):
            if op[0] == '-':
                print(u'Delete "{}" from position {}'.format(op[-1], pos))
            elif op[0] == '+':
                print(u'Add "{}" to position {}'.format(op[-1], pos))
        print("--------------------")
    else:
        match_count += 1

print("Matched %d/%d times" % (match_count, testset_size))
    