# High-level RNN (GRU) MXNet Example — IMDB sentiment classification

In [1]:
import os
import sys
import numpy as np
import mxnet as mx
from common.params_lstm import *
from common.utils import *

In [2]:
# Report the software/hardware environment so the timings below can be compared.
for _label, _value in (("OS: ", sys.platform),
                       ("Python: ", sys.version),
                       ("Numpy: ", np.__version__),
                       ("MXNet: ", mx.__version__)):
    print(_label, _value)
# Queried last and on its own so the lines above are printed even if the GPU lookup fails.
print("GPU: ", get_gpu_name())

OS:  linux
Python:  3.5.2 |Anaconda custom (64-bit)| (default, Jul  2 2016, 17:53:06) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
Numpy:  1.13.3
MXNet:  0.11.0
GPU:  ['Tesla K80']


In [3]:
def create_symbol(num_classes=2):
    """Build the symbolic graph for a single-layer GRU text classifier.

    Token ids in the 'data' input are embedded, unrolled through a GRU for
    MAXLEN steps, and the output of the final step feeds a fully-connected
    layer followed by a softmax loss.

    Parameters
    ----------
    num_classes : int, optional
        Number of output units of the final FullyConnected layer
        (default 2, the original hard-coded value).

    Returns
    -------
    mx.symbol.Symbol
        A SoftmaxOutput symbol named "softmax" whose inputs are the
        variables "data" and "softmax_label".
    """
    # https://mxnet.incubator.apache.org/api/python/rnn.html
    data = mx.symbol.Variable('data')
    embedded_step = mx.symbol.Embedding(data=data, input_dim=MAXFEATURES, output_dim=EMBEDSIZE)
    gru_cell = mx.rnn.GRUCell(num_hidden=NUMHIDDEN)
    # unroll() runs the cell over all MAXLEN steps. With begin_state omitted it is
    # expected to create its own initial hidden state (zeros by default, per the
    # MXNet rnn docs), so no separate begin_state() call is needed. Calling the
    # cell directly, as gru_cell(x, states), would process only one time step.
    outputs, _ = gru_cell.unroll(length=MAXLEN, inputs=embedded_step, merge_outputs=False)
    # Classify from the output of the last time step only.
    fc1 = mx.symbol.FullyConnected(data=outputs[-1], num_hidden=num_classes)
    input_y = mx.symbol.Variable('softmax_label')
    return mx.symbol.SoftmaxOutput(data=fc1, label=input_y, name="softmax")

In [4]:
def init_model(m):
    """Wrap symbol ``m`` in an MXNet Module and make it ready to train.

    The module runs on the first GPU when the GPU flag is set, otherwise on
    the CPU. It is bound for (BATCHSIZE, MAXLEN) inputs and (BATCHSIZE,)
    labels. Weights get a uniform Xavier (Glorot) initializer, and an Adam
    optimizer is attached using LR, BETA_1, BETA_2 and EPS.

    Returns the initialised ``mx.mod.Module``.
    """
    ctx = [mx.gpu(0)] if GPU else mx.cpu()
    module = mx.mod.Module(context=ctx, symbol=m)
    module.bind(
        data_shapes=[('data', (BATCHSIZE, MAXLEN))],
        label_shapes=[('softmax_label', (BATCHSIZE,))],
    )
    # Glorot-uniform initializer
    module.init_params(initializer=mx.init.Xavier(rnd_type='uniform'))
    adam_params = (
        ('learning_rate', LR),
        ('beta1', BETA_1),
        ('beta2', BETA_2),
        ('epsilon', EPS),
    )
    module.init_optimizer(optimizer='Adam', optimizer_params=adam_params)
    return module

In [5]:
%%time
# Data into format for library
x_train, x_test, y_train, y_test = imdb_for_library(seq_len=MAXLEN, max_features=MAXFEATURES)


def wrapper_db(args):
    """Wrap one mini-batch pair as an ``mx.io.DataBatch``.

    A custom mini-batch generator is used instead of ``mx.io.NDArrayIter``
    for consistency; this function adapts each of its items to what
    ``Module.forward`` consumes.

    Parameters
    ----------
    args : sequence
        ``args[0]`` holds the batch inputs and ``args[1]`` the batch labels.
        Presumably numpy arrays; TODO confirm against the generator.

    Returns
    -------
    mx.io.DataBatch
        Batch with one data entry and one label entry, each converted with
        ``mx.nd.array``.
    """
    return mx.io.DataBatch(data=[mx.nd.array(args[0])], label=[mx.nd.array(args[1])])


print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)

Downloading https://s3.amazonaws.com/text-datasets/imdb.npz
Done.
Extracting files...
Done.
Trimming to 30000 max-features
Padding to length 150
(25000, 150) (25000, 150) (25000,) (25000,)
int32 int32 int32 int32
CPU times: user 5.53 s, sys: 284 ms, total: 5.81 s
Wall time: 7.11 s


In [6]:
%%time
# Build the symbolic network graph (nothing is loaded; parameters are created
# later, when the module is bound and initialised).
sym = create_symbol()

CPU times: user 172 ms, sys: 0 ns, total: 172 ms
Wall time: 168 ms


In [7]:
%%time
# Bind the symbol into a Module, initialise its weights and attach the Adam optimizer.
model = init_model(m=sym)

CPU times: user 860 ms, sys: 864 ms, total: 1.72 s
Wall time: 1.74 s


In [8]:
%%time
# Train and log accuracy.
# Each epoch draws mini-batches from yield_mb (called with shuffle=True) and
# wraps them as DataBatch objects via wrapper_db.
metric = mx.metric.create('acc')
for epoch in range(EPOCHS):
    metric.reset()
    for batch in map(wrapper_db, yield_mb(x_train, y_train, BATCHSIZE, shuffle=True)):
        model.forward(batch, is_train=True)
        # Accuracy comes from this training-mode forward pass, i.e. before
        # the batch's gradient step below is applied.
        model.update_metric(metric, batch.label)
        model.backward()
        model.update()
    print('Epoch %d, Training %s' % (epoch, metric.get()))

Epoch 0, Training ('accuracy', 0.7556891025641026)
Epoch 1, Training ('accuracy', 0.91290064102564106)
Epoch 2, Training ('accuracy', 0.95789262820512822)
CPU times: user 1min 17s, sys: 14.2 s, total: 1min 31s
Wall time: 1min 27s


In [9]:
%%time
# Score the test set. shuffle=False keeps predictions in the same order as y_test.
# NOTE(review): NDArrayIter pads a final partial batch by default; Module.predict
# is expected to drop that padding so one row is returned per sample. Confirm.
y_prob = model.predict(mx.io.NDArrayIter(x_test, batch_size=BATCHSIZE, shuffle=False))
# Predicted class = index of the highest-scoring output column.
y_guess = np.argmax(y_prob.asnumpy(), axis=-1)

CPU times: user 10.1 s, sys: 2.13 s, total: 12.2 s
Wall time: 11.5 s


In [10]:
# Fraction of test samples whose predicted class equals the label.
# A length mismatch would otherwise surface as an obscure error from the comparison.
assert len(y_guess) == len(y_test), "prediction/label count mismatch"
print("Accuracy: ", np.mean(y_guess == y_test))

Accuracy:  0.85864
