# LSTM on PennTreeBank
-----
This is an example to show how to use MXNet low-level symbol to make a LSTM network.

We would like to thank Wojciech Zaremba for his work LSTM in Torch. The data is same to Wojciech used in Torch LSTM. https://github.com/wojzaremba/lstm

To get the data, please download directly from:

Training text: https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt

Validation text: https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt

Test text: https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt

In [1]:
import mxnet as mx
import numpy as np
import time


    Build LSTM Symbol

    Parameters:
    ----------
    num_hidden: int
        hidden unit in LSTM
    x: symbol
        input x
    prev_c: symbol
        previous cell
    prev_h: symbol
        previous hidden
    layer_prefix: str
        name prefix for layer
    t_prefix: str
        name prefix for time
    arg_param: dict: str->symbol
        arguments symbol for the lstm symbol
    aux_param: dict: str->symbol
        auxiliary states symbol for the lstm symbol

    Returns:
    --------
    output: symbol
        grouped lstm output [c, h]

    arg_param: dict: str->symbol
        arguments symbol of the lstm symbol

    aux_param: dict: str->symbol
        auxiliary states symbol of the lstm symbol


In [2]:
def lstm_symbol(num_hidden,
                x, prev_c, prev_h,
                layer_prefix, t_prefix,
                arg_param=None, aux_param=None,
                **kwargs):
    # name and variable
    i2h_name = "%s_i2h" % layer_prefix
    h2h_name = "%s_h2h" % layer_prefix
    exist_flag = True
    if arg_param == None or i2h_name + "_weight" not in arg_param:
        exist_flag = False

    if not exist_flag:
        if arg_param == None:
            arg_param = {}
        arg_param[i2h_name + "_weight"] = mx.sym.Variable(i2h_name + "_weight")
        arg_param[i2h_name + "_bias"] = mx.sym.Variable(i2h_name + "_bias")
        arg_param[h2h_name + "_weight"] = mx.sym.Variable(h2h_name + "_weight")
        arg_param[h2h_name + "_bias"] = mx.sym.Variable(h2h_name + "_bias")
    if not exist_flag:
        if aux_param == None:
            aux_param = {}
        aux_param[i2h_name + "_moving_mean"] = mx.sym.Variable(i2h_name + "_moving_mean")
        aux_param[i2h_name + "_moving_var"] = mx.sym.Variable(i2h_name + "_moving_var")
        aux_param[h2h_name + "_moving_mean"] = mx.sym.Variable(h2h_name + "_moving_mean")
        aux_param[h2h_name + "_moving_var"] = mx.sym.Variable(h2h_name + "_moving_var")

    # transform 
    i2h = mx.sym.FullyConnected(*[x,
                                  arg_param[i2h_name + "_weight"],
                                  arg_param[i2h_name + "_bias"]],
                                  num_hidden=num_hidden * 4,
                                  name=i2h_name)
    h2h = mx.sym.FullyConnected(*[prev_h,
                                  arg_param[h2h_name + "_weight"],
                                  arg_param[h2h_name + "_bias"]],
                                  num_hidden=num_hidden * 4,
                                  name=h2h_name)
    gates = i2h + h2h

    # gates
    slice_gates = mx.sym.SliceChannel(data=gates, num_outputs=4)
    in_gate = mx.sym.Activation(data=slice_gates[0], act_type="sigmoid")
    in_transform = mx.sym.Activation(data=slice_gates[1], act_type="tanh")
    forget_gate = mx.sym.Activation(data=slice_gates[2], act_type="sigmoid")
    out_gate = mx.sym.Activation(data=slice_gates[3], act_type="sigmoid")

    # cal states
    next_c = (forget_gate * prev_c) + (in_gate * in_transform)
    next_h = out_gate * mx.sym.Activation(data=next_c, act_type="tanh")
    # We need to block gradient to set 0 gradient back automatically
    next_c = mx.sym.BlockGrad(data=next_c, name="%s_%s_c" % (t_prefix, layer_prefix))
    next_h = mx.sym.BlockGrad(data=next_h, name="%s_%s_h" % (t_prefix, layer_prefix))
    # if you like you can add a dropout symbol here
    # next_h = mx.sym.Dropout(data=next_h, p=0.5)
    output = mx.symbol.Group([next_c, next_h])
    return (output, arg_param, aux_param)


    Build a multi-layer LSTM model for a single component in unrolled RNN

    Parameters:
    -----------
    num_layer: int
        layers of LSTM network
    num_hidden: int
        hidden unit in each LSTM layer
    num_embed: int
        dimention of word embedding
    num_label: int
        output label dimention
    prev_states: list of tuple (prev_c, prev_h)
        prev_states for each LSTM layer
    t_prefix: str
        prefix name of time
    embed_var: list of symbol
        vairable for embedding layer
    cls_var: list of symbol
        variable for linear classifier
    arg_param: dict: str->symbol
        arguments symbol of the lstm symbol
    aux_param: dict: str->symbol
        auxiliary states symbol of the lstm symbol

    Returns:
    layers : list of symbol
        layers of current component
    arg_param: dict: str->symbol
        arguments symbol of the lstm symbol

    aux_param: dict: str->symbol
        auxiliary states symbol of the lstm symbol


In [3]:
def create_model(num_layer, num_hidden, num_embed, num_label,
                 prev_states,
                 t_prefix,
                 embed_var, cls_var, arg_param=None, aux_param=None,
                 **kwargs):
    layers = []
    data = mx.sym.Variable("%s_data" % t_prefix)
    embed_layer = mx.sym.FullyConnected(*[data, embed_var[0], embed_var[1]],
                                        num_hidden=num_embed, name="embedding")
    for i in range(num_layer):
        layer_prefix = "layer_%d" % i
        prev_c, prev_h = prev_states[i]
        if i == 0:
            data = embed_layer
        else:
            data = layers[-1][1]
        args = None
        auxs = None
        if arg_param != None:
            args = arg_param
        if aux_param != None:
            auxs = aux_param
        lstm, arg_param, aux_param = lstm_symbol(num_hidden,
                                                 data, prev_c, prev_h,
                                                 layer_prefix, t_prefix,
                                                 args, auxs,
                                                 **kwargs)
        layers.append(lstm)
    fc = mx.sym.FullyConnected(*[layers[-1][1], cls_var[0], cls_var[1]],
                               num_hidden=num_label, name="cls")
    sm = mx.sym.Softmax(data=fc, name="%s" % t_prefix)
    layers.append(sm)
    return layers, arg_param, aux_param


    Setup Recurrent Network Symbol

    Parameters:
    -----------
    seq_len: int
        length of sequence
    num_layer: int
        layer of hidden lstm layers
    num_embed: int
        dimention of embeeding layer
    num_label: int
        dimention of output space
    models = []

    Returns:
    --------
    rnn: symbol
        A final symbol of RNN network


In [4]:
def setup_rnn_symbol(seq_len, num_layer, num_hidden, num_embed, num_label, **kwargs):
    models = []
    arg_param = None
    aux_param = None
    embed_var = [mx.sym.Variable("embed_weight"), mx.sym.Variable("embed_bias")]
    cls_var = [mx.sym.Variable("cls_weight"), mx.sym.Variable("cls_bias")]
    init_states = []

    for i in range(num_layer):
        init_c = mx.sym.Variable("init_c_%d" % i)
        init_h = mx.sym.Variable("init_h_%d" % i)
        init_states.append([init_c, init_h])

    for i in range(seq_len):
        t_prefix = "t_%d" % i
        if i == 0:
            states = init_states
        else:
            states = [(models[-1][j][0], models[-1][j][1]) for j in range(num_layer)]
        model, arg_param, aux_param = create_model(num_layer, num_hidden, num_embed, num_label,
                                                    states, t_prefix,
                                                    embed_var, cls_var,
                                                    arg_param, aux_param,
                                                    **kwargs)
        models.append(model)
    prob = mx.sym.Group([md[-1] for md in models])
    state = mx.sym.Group([models[-1][i] for i in range(num_layer)])
    rnn = mx.sym.Group([prob, state])
    return rnn


    Setup Recurrent Network Executor

    Parameters:
    -----------
    ctx: Context
        running context
    seq_len: int
        length of sequence
    num_layer: int
        layer of hidden lstm layers
    num_embed: int
        dimention of embeeding layer
    num_label: int
        dimention of output space
    batch_size: int
        number of batch_size
    Returns:
    --------
    rnn: executor
        A final RNN network


In [5]:
def setup_rnn(ctx, seq_len, num_layer, num_hidden, num_embed, num_label, batch_size,
              initializer=mx.init.Uniform(0.05)):

    # get symbol
    rnn_sym = setup_rnn_symbol(seq_len, num_layer, num_hidden, num_embed, num_label)
    input_shapes = {}
    for name in rnn_sym.list_arguments():
        if "init" in name:
            input_shapes[name] = (batch_size, num_hidden)
        if "data" in name:
            input_shapes[name] = (batch_size, num_label)
    # bind symbol
    rnn_model = rnn_sym.simple_bind(ctx=ctx, **input_shapes)
    # init weight
    names = rnn_sym.list_arguments()
    args = dict(zip(names, rnn_model.arg_arrays))
    grad = dict(zip(names, rnn_model.grad_arrays))
    for name, arr in args.items():
        if name.endswith("weight") or name.endswith("bias") or \
           name.endswith("gamma") or name.endswith("beta"):
            initializer(name, arr)
    # structure for later use
    param_array = []
    for i in range(len(names)):
        name = names[i]
        if name.endswith("weight") or name.endswith("bias") or \
           name.endswith("gamma") or name.endswith("beta"):
            param_array.append((i, args[name], grad[name]))
    
    init_states = [(args["init_c_%d" % i], args["init_h_%d" % i]) for i in range(num_layer)]
    last_states = [(rnn_model.outputs[seq_len + i * 2], rnn_model.outputs[seq_len + i *2 + 1]) for i in range(num_layer)]
    return (rnn_sym, rnn_model, param_array, init_states, last_states)

In [6]:
def Logloss(y, prob):
    #eps = 1e-6
    #return -np.sum(np.log(np.maximum(np.choose(y.astype("int32"), prob.T), eps)))
    loss = 0.0
    for i in range(prob.shape[0]):
        loss += -np.log(np.max(prob[i, y[i]], 1e-8))
    loss /= prob.shape[0]
    return loss

def set_onehot_input(onehot, xidx):
    onehot[:] = 0.
    onehot[np.arange(onehot.shape[0]), xidx.astype("int32")] = 1.

def load_data(path, dic=None):
    fi = open(path)
    content = fi.read()
    content = content.replace('\n', '<eos>')
    content = content.split(' ')
    print("Loading %s, size of data = %d" % (path, len(content)))
    x = np.zeros(len(content))
    if dic == None:
        dic = {}
    idx = 0
    for i in range(len(content)):
        word = content[i]
        if len(word) == 0:
            continue
        if not word in dic:
            dic[word] = idx
            idx += 1
        x[i] = dic[word]
    print("Unique token: %d" % len(dic))
    return x, dic

def replicate_data(x, batch_size):
    nbatch = int(x.shape[0] / batch_size)
    x_cut = x[:nbatch * batch_size]
    data = x_cut.reshape((nbatch, batch_size), order='F')
    return data

In [7]:
batch_size = 20
seq_len = 20
vocab = 10000
rnn_hidden = 200
embed = 200
num_layer = 2
num_round = 4
ctx = mx.cpu()
optimizer = mx.optimizer.SGD(learning_rate=0.01, wd=0.0001)
# rnn model
rnn_sym, rnn, param_array, init_states, last_states,  = setup_rnn(ctx=ctx, 
                                                                  seq_len=seq_len, 
                                                                  num_layer=num_layer, 
                                                                  num_hidden=rnn_hidden, 
                                                                  num_embed=embed, 
                                                                  num_label=vocab, 
                                                                  batch_size=batch_size)
seq_prob = [mx.nd.zeros(ctx=mx.cpu(), shape=rnn.outputs[i].shape) for i in range(seq_len)]
param_dict = dict(zip(rnn_sym.list_arguments(), rnn.arg_arrays))
# load data
X_train, dic = load_data("./data/ptb.train.txt")
X_val, _ = load_data("./data/ptb.valid.txt", dic)
X_train_batch = replicate_data(X_train, batch_size)
X_val_batch = replicate_data(X_val, batch_size)
onehot = np.zeros((batch_size, vocab), dtype='float32')


Loading ./data/ptb.train.txt, size of data = 929590
Unique token: 10000
Loading ./data/ptb.valid.txt, size of data = 73761
Unique token: 10000


In [10]:
def set_rnn_inputs(seq_len, idx, onehot, X, param_dict):
    for j in range(seq_len):
        data_key = "t_%d_data" % j
        label_key = "t_%d_label" % j
        next_idx = (idx + 1) % X.shape[0]
        x = X[idx, :]
        y = X[next_idx, :]
        set_onehot_input(onehot, x)
        param_dict[data_key][:] = onehot
        param_dict[label_key][:] = y
        idx += 1

def get_rnn_outputs(seq_len, rnn, seq_prob):
    for j in range(seq_len):
        seq_prob[j][:] = rnn.outputs[j]

def get_nll(seq_len, idx, X, seq_prob):
    nll = 0.
    for j in range(seq_len):
        next_idx = (idx + 1) % X.shape[0]
        y = X[next_idx, :]
        nll += Logloss(y, seq_prob[j].asnumpy())
    return nll
    

for i in range(num_round):
    nbatch = 0.
    nll = 0.
    # reset states
    for init_c, init_h in init_states:
        init_c[:] = 0.
        init_h[:] = 0.
    tic = time.time()
    # train
    while nbatch < X_train_batch.shape[0]:
        set_rnn_inputs(seq_len, nbatch, onehot, X_train_batch, param_dict)
        rnn.forward(is_train=True)
        get_rnn_outputs(seq_len, rnn, seq_prob)
        rnn.backward()
        for ind, weight, grad in param_array:
            optimizer.update(ind, weight, grad, None)
        for j in range(num_layer):
            init_states[j][0][:] = last_states[j][0]
            init_states[j][1][:] = last_states[j][1]
        nll += get_nll(seq_len, nbatch, X_train_batch, seq_prob)
        nbatch += seq_len
        if nbatch % 1000 == 0:
            print("Epoch [%d], Batch [%d]: NLL=%.3f, Prep=%.3f" % (i, nbatch, nll / nbatch, np.exp(nll / nbatch)))
    toc = time.time()
    print("Epoch [%d] Train: Time: %.3f sec, NLL=%.3f, Prep=%.3f" % (i, toc - tic, nll / nbatch, np.exp(nll / nbatch)))
    nbatch = 0
    nll = 0.
    for init_c, init_h in init_states:
        init_c[:] = 0.
        init_h[:] = 0.
    while nbatch < X_val_batch.shape[0]:
        set_rnn_inputs(seq_len, nbatch, onehot, X_val_batch, param_dict)
        rnn.forward(is_train=False)
        get_rnn_outputs(seq_len, rnn, seq_prob)
        nll += get_nll(seq_len, nbatch, X_val_batch, seq_prob)
        nbatch += seq_len
    print("Epoch [%d] Val: NLL=%.3f, Prep=%.3f" % (i, nll / nbatch, np.exp(nll / nbatch)))
    
        
    
    
    

Epoch [0], Batch [20]: NLL=8.503, Prep=4931.846
Epoch [0], Batch [40]: NLL=8.511, Prep=4971.079
Epoch [0], Batch [60]: NLL=8.366, Prep=4300.328
Epoch [0], Batch [80]: NLL=8.273, Prep=3917.564
Epoch [0], Batch [100]: NLL=8.241, Prep=3793.372
Epoch [0], Batch [120]: NLL=8.146, Prep=3448.532
Epoch [0], Batch [140]: NLL=8.062, Prep=3172.689
Epoch [0], Batch [160]: NLL=8.041, Prep=3105.142
Epoch [0], Batch [180]: NLL=8.107, Prep=3318.143
Epoch [0], Batch [200]: NLL=8.091, Prep=3264.713
Epoch [0], Batch [220]: NLL=8.025, Prep=3055.690
Epoch [0], Batch [240]: NLL=8.020, Prep=3040.329
Epoch [0], Batch [260]: NLL=7.993, Prep=2960.196
Epoch [0], Batch [280]: NLL=7.970, Prep=2892.389
Epoch [0], Batch [300]: NLL=8.021, Prep=3042.987
Epoch [0], Batch [320]: NLL=7.979, Prep=2918.540
Epoch [0], Batch [340]: NLL=7.951, Prep=2839.064
Epoch [0], Batch [360]: NLL=7.982, Prep=2927.875
Epoch [0], Batch [380]: NLL=7.989, Prep=2948.181
Epoch [0], Batch [400]: NLL=7.966, Prep=2880.162
Epoch [0], Batch [420]: 



KeyboardInterrupt: 

<mxnet.symbol.Symbol at 0x7f8c5f3df080>