# High-level RNN MXNet Example

In [2]:
import os
import sys
import numpy as np
import mxnet as mx
from mxnet.io import DataDesc
from common.params_lstm import *
from common.utils import *

In [3]:
# Restrict CUDA to the first device so the benchmark runs on a single GPU
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [4]:
# Record the software stack this run was produced with
stack_versions = (
    ("OS: ", sys.platform),
    ("Python: ", sys.version),
    ("Numpy: ", np.__version__),
    ("MXNet: ", mx.__version__),
)
for label, value in stack_versions:
    print(label, value)

# Hardware / driver details come from the project helpers
print("GPU: ", get_gpu_name())
print(get_cuda_version())
print("CuDNN Version ", get_cudnn_version())

OS:  linux
Python:  3.5.2 |Anaconda custom (64-bit)| (default, Jul  2 2016, 17:53:06) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
Numpy:  1.14.1
MXNet:  0.12.0
GPU:  ['Tesla P100-PCIE-16GB', 'Tesla P100-PCIE-16GB']
CUDA Version 8.0.61
CuDNN Version  6.0.21


In [5]:
def create_symbol(CUDNN=True,
                  maxf=MAXFEATURES, edim=EMBEDSIZE, nhid=NUMHIDDEN, maxl=MAXLEN):
    """Build the network symbol: Embedding -> single GRU layer -> Dense(2) -> softmax.

    Args:
        CUDNN: If True, use ``mx.rnn.FusedRNNCell`` in GRU mode; otherwise
            use the unfused ``mx.rnn.GRUCell``.
        maxf: ``input_dim`` of the embedding layer (vocabulary size).
        edim: ``output_dim`` of the embedding layer.
        nhid: Number of hidden units in the GRU.
        maxl: Number of time steps the GRU is unrolled for.

    Returns:
        A ``SoftmaxOutput`` symbol named "softmax" that reads its input from
        the variable 'data' and its labels from 'softmax_label'.
    """
    # https://mxnet.incubator.apache.org/api/python/rnn.html
    data = mx.symbol.Variable('data')
    embedded_step = mx.symbol.Embedding(data=data, input_dim=maxf, output_dim=edim)

    # Fusing RNN layers across time step into one kernel
    # Improves speed but is less flexible
    # Currently only supported if using cuDNN on GPU
    if not CUDNN:
        gru_cell = mx.rnn.GRUCell(num_hidden=nhid)
    else:
        gru_cell = mx.rnn.FusedRNNCell(num_hidden=nhid, num_layers=1, mode='gru')

    # Unroll the cell over all `maxl` time steps. No begin_state is passed, so
    # unroll() creates the initial state itself. The input is time-major
    # ('TNC': sequence length, batch size, feature dimensions) and
    # merge_outputs=False returns one output symbol per time step.
    outputs, states = gru_cell.unroll(length=maxl, inputs=embedded_step, merge_outputs=False, layout='TNC')

    # Classify from the output of the final time step only
    fc1 = mx.symbol.FullyConnected(data=outputs[-1], num_hidden=2)
    input_y = mx.symbol.Variable('softmax_label')
    m = mx.symbol.SoftmaxOutput(data=fc1, label=input_y, name="softmax")
    return m

In [6]:
def init_model(m, batchs=BATCHSIZE, maxl=MAXLEN, lr=LR, b1=BETA_1, b2=BETA_2, eps=EPS):
    """Wrap symbol ``m`` in a Module on GPU 0, initialise its weights and attach Adam.

    Args:
        m: Network symbol exposing a 'data' input and a 'softmax_label' label.
        batchs: Batch size used when binding the input and label shapes.
        maxl: Sequence length used when binding the input shape.
        lr: Adam learning rate.
        b1: Adam ``beta1``.
        b2: Adam ``beta2``.
        eps: Adam ``epsilon``.

    Returns:
        The bound ``mx.mod.Module`` with parameters and optimizer initialised.
    """
    devices = [mx.gpu(0)]
    module = mx.mod.Module(context=devices, symbol=m)

    # Input is bound time-major: (sequence length, batch size)
    data_desc = DataDesc(name='data', shape=(maxl, batchs), layout='TNC')
    label_desc = DataDesc(name='softmax_label', shape=(batchs,))
    module.bind(data_shapes=[data_desc], label_shapes=[label_desc])

    # Glorot-uniform initializer
    module.init_params(initializer=mx.init.Xavier(rnd_type='uniform'))

    adam_settings = (
        ('learning_rate', lr),
        ('beta1', b1),
        ('beta2', b2),
        ('epsilon', eps),
    )
    module.init_optimizer(optimizer='Adam', optimizer_params=adam_settings)
    return module

In [7]:
%%time
# Data into format for library
x_train, x_test, y_train, y_test = imdb_for_library(seq_len=MAXLEN, max_features=MAXFEATURES)
wrapper_db = lambda args: mx.io.DataBatch(data=[mx.nd.array(args[0])], label=[mx.nd.array(args[1])])

print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
print(x_train.dtype, x_test.dtype, y_train.dtype, y_test.dtype)

Preparing train set...
Preparing test set...
Trimming to 30000 max-features
Padding to length 150
(25000, 150) (25000, 150) (25000,) (25000,)
int32 int32 int32 int32
CPU times: user 5.59 s, sys: 391 ms, total: 5.98 s
Wall time: 5.98 s


In [8]:
%%time
# Load symbol
sym = create_symbol()

CPU times: user 43.3 ms, sys: 0 ns, total: 43.3 ms
Wall time: 42.6 ms


In [9]:
%%time
# Initialise model
model = init_model(sym)

CPU times: user 901 ms, sys: 521 ms, total: 1.42 s
Wall time: 1.43 s


In [10]:
%%time
# Main training loop: 12.7s
metric = mx.metric.create('acc')
for j in range(EPOCHS):
    metric.reset()
    for batch in map(wrapper_db, yield_mb_tn(x_train, y_train, BATCHSIZE, shuffle=True)):
        model.forward(batch) 
        model.update_metric(metric, batch.label)
        model.backward()              
        model.update()
    print('Epoch %d, Training %s' % (j, metric.get()))

Epoch 0, Training ('accuracy', 0.7873397435897436)
Epoch 1, Training ('accuracy', 0.9302083333333333)
Epoch 2, Training ('accuracy', 0.9705128205128205)
CPU times: user 21 s, sys: 4.39 s, total: 25.4 s
Wall time: 23.7 s
