Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add fused rnn cell (#5004)
Browse files Browse the repository at this point in the history
* fused rnn

fix

fused

fix

* fix

* rnn example

* fix
  • Loading branch information
piiswrong authored Feb 16, 2017
1 parent 4b0665e commit 22673b6
Show file tree
Hide file tree
Showing 10 changed files with 893 additions and 106 deletions.
173 changes: 173 additions & 0 deletions example/rnn/cudnn_lstm_bucketing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import numpy as np
import mxnet as mx
import argparse

# Command-line interface for the PTB language-model example.
# Defaults reproduce the paper-ish small setup (2x200 LSTM, bucketed batches).
parser = argparse.ArgumentParser(description="Train RNN on Penn Tree Bank",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--test', default=False, action='store_true',
                    help='whether to do testing instead of training')
parser.add_argument('--model-prefix', type=str, default=None,
                    help='path to save/load model')
parser.add_argument('--load-epoch', type=int, default=0,
                    help='load from epoch')
# Network architecture.
parser.add_argument('--num-layers', type=int, default=2,
                    help='number of stacked RNN layers')
parser.add_argument('--num-hidden', type=int, default=200,
                    help='hidden layer size')
parser.add_argument('--num-embed', type=int, default=200,
                    help='embedding layer size')
# Execution context and distribution.
parser.add_argument('--gpus', type=str,
                    help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. ' \
                         'Increase batch size when using multiple gpus for best performance.')
parser.add_argument('--kv-store', type=str, default='device',
                    help='key-value store type')
# Optimization hyper-parameters.
parser.add_argument('--num-epochs', type=int, default=25,
                    help='max num of epochs')
parser.add_argument('--lr', type=float, default=0.01,
                    help='initial learning rate')
parser.add_argument('--optimizer', type=str, default='sgd',
                    help='the optimizer type')
parser.add_argument('--mom', type=float, default=0.0,
                    help='momentum for sgd')
parser.add_argument('--wd', type=float, default=0.00001,
                    help='weight decay for sgd')
parser.add_argument('--batch-size', type=int, default=32,
                    help='the batch size.')
parser.add_argument('--disp-batches', type=int, default=50,
                    help='show progress for every n batches')


#buckets = [32]
# Sentence-length buckets: each sentence is padded up to the smallest bucket
# that fits it; longer sentences are discarded by BucketSentenceIter.
buckets = [10, 20, 30, 40, 50, 60]

# First id handed out when building the vocabulary, and the id used for
# padding / out-of-bucket positions.
start_label = 1
invalid_label = 0

def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
    """Read a whitespace-tokenized corpus and encode tokens as integer ids.

    Parameters
    ----------
    fname : str
        Path to the text file, one sentence per line.
    vocab : dict or None
        Existing token->id mapping; a new one is built when None.
    invalid_label : int
        Id used by the encoder for invalid/padding tokens.
    start_label : int
        First id assigned to newly created vocabulary entries.

    Returns
    -------
    (sentences, vocab) as returned by ``mx.rnn.encode_sentences``.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original relied on GC), and build real lists of tokens: in
    # Python 3, filter() returns a one-shot iterator, which would break any
    # consumer that iterates a sentence more than once.
    with open(fname) as fin:
        lines = [[tok for tok in line.split(' ') if tok] for line in fin]
    sentences, vocab = mx.rnn.encode_sentences(lines, vocab=vocab,
                                               invalid_label=invalid_label,
                                               start_label=start_label)
    return sentences, vocab

def get_data(layout):
    """Load the PTB train/test splits and wrap them in bucketing iterators.

    ``layout`` is forwarded to BucketSentenceIter: 'NT' for batch-major,
    'TN' for time-major data/labels.
    """
    # Build the vocabulary from the training split, then reuse it for the
    # validation split so token ids agree between the two.
    train_sent, vocab = tokenize_text("./data/ptb.train.txt", start_label=start_label,
                                      invalid_label=invalid_label)
    val_sent, _ = tokenize_text("./data/ptb.test.txt", vocab=vocab, start_label=start_label,
                                invalid_label=invalid_label)

    train_iter = mx.rnn.BucketSentenceIter(train_sent, args.batch_size, buckets=buckets,
                                           invalid_label=invalid_label, layout=layout)
    val_iter = mx.rnn.BucketSentenceIter(val_sent, args.batch_size, buckets=buckets,
                                         invalid_label=invalid_label, layout=layout)
    return train_iter, val_iter, vocab


def train(args):
    """Train a cuDNN-fused LSTM language model on PTB with bucketing."""
    # FusedRNNCell consumes time-major ('TNC') input, hence layout 'TN' here.
    data_train, data_val, vocab = get_data('TN')

    cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, mode='lstm')

    def sym_gen(seq_len):
        # Build the unrolled network symbol for one bucket length.
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed')

        output, _ = cell.unroll(seq_len, inputs=embed, merge_outputs=True, layout='TNC')

        # Collapse (time, batch) into one axis so a single FC layer scores
        # every time step at once.
        pred = mx.sym.Reshape(output, shape=(-1, args.num_hidden))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)

    contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')] if args.gpus else mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_train.default_bucket_key,
        context             = contexts)

    # Optionally resume from a checkpoint saved with the fused cell.
    arg_params, aux_params = None, None
    if args.load_epoch:
        _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(
            cell, args.model_prefix, args.load_epoch)

    opt_params = {'learning_rate': args.lr,
                  'momentum': args.mom,
                  'wd': args.wd}

    model.fit(
        train_data          = data_train,
        eval_data           = data_val,
        eval_metric         = mx.metric.Perplexity(invalid_label),
        kvstore             = args.kv_store,
        optimizer           = args.optimizer,
        optimizer_params    = opt_params,
        initializer         = mx.init.Xavier(factor_type="in", magnitude=2.34),
        arg_params          = arg_params,
        aux_params          = aux_params,
        begin_epoch         = args.load_epoch,
        num_epoch           = args.num_epochs,
        batch_end_callback  = mx.callback.Speedometer(args.batch_size, args.disp_batches),
        epoch_end_callback  = mx.rnn.do_rnn_checkpoint(cell, args.model_prefix, 1)
                              if args.model_prefix else None)

def test(args):
    """Evaluate a checkpoint trained with FusedRNNCell using unfused cells.

    Loads the fused checkpoint through a SequentialRNNCell stack of plain
    LSTMCells (the checkpoint loader unpacks the fused weights) and scores
    perplexity on the validation set. Requires ``--model-prefix``.
    """
    # Fixed typo in the user-facing message: "specifiy" -> "specify".
    assert args.model_prefix, "Must specify path to load from"
    # Unfused cells unroll in batch-major ('NTC'), hence layout 'NT'.
    _, data_val, vocab = get_data('NT')

    stack = mx.rnn.SequentialRNNCell()
    for i in range(args.num_layers):
        stack.add(mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_l%d_'%i))

    def sym_gen(seq_len):
        # Build the unrolled network symbol for one bucket length.
        data = mx.sym.Variable('data')
        label = mx.sym.Variable('softmax_label')
        embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
                                 output_dim=args.num_embed, name='embed')

        outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

        # Collapse (batch, time) into one axis so a single FC layer scores
        # every time step at once.
        pred = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden))
        pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

        label = mx.sym.Reshape(label, shape=(-1,))
        pred = mx.sym.SoftmaxOutput(data=pred, label=label, name='softmax')

        return pred, ('data',), ('softmax_label',)

    if args.gpus:
        contexts = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    else:
        contexts = mx.cpu(0)

    model = mx.mod.BucketingModule(
        sym_gen             = sym_gen,
        default_bucket_key  = data_val.default_bucket_key,
        context             = contexts)
    model.bind(data_val.provide_data, data_val.provide_label, for_training=False)

    # note here we load using SequentialRNNCell instead of FusedRNNCell.
    _, arg_params, aux_params = mx.rnn.load_rnn_checkpoint(stack, args.model_prefix, args.load_epoch)
    model.set_params(arg_params, aux_params)

    model.score(data_val, mx.metric.Perplexity(invalid_label),
                batch_end_callback=mx.callback.Speedometer(args.batch_size, 5))

if __name__ == '__main__':
    import logging
    # DEBUG level so the per-batch Speedometer/metric callbacks are visible.
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)-15s %(message)s')

    args = parser.parse_args()
    if args.test:
        # Demonstrates how to load a model trained with CuDNN RNN and predict
        # with non-fused MXNet symbol
        test(args)
    else:
        train(args)
9 changes: 4 additions & 5 deletions example/rnn/lstm_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,15 @@ def tokenize_text(fname, vocab=None, invalid_label=-1, start_label=0):
def sym_gen(seq_len):
data = mx.sym.Variable('data')
label = mx.sym.Variable('softmax_label')
embed = mx.sym.Embedding(data=data, input_dim=len(vocab), output_dim=args.num_embed,name='embed')
embed = mx.sym.Embedding(data=data, input_dim=len(vocab),
output_dim=args.num_embed, name='embed')

stack = mx.rnn.SequentialRNNCell()
for i in range(args.num_layers):
stack.add(mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_l%d_'%i))
outputs, states = mx.rnn.rnn_unroll(stack, seq_len, inputs=embed)
outputs, states = stack.unroll(seq_len, inputs=embed, merge_outputs=True)

outputs = [mx.sym.expand_dims(x, axis=1) for x in outputs]
pred = mx.sym.Concat(*outputs, dim=1)
pred = mx.sym.Reshape(pred, shape=(-1, args.num_hidden))
pred = mx.sym.Reshape(outputs, shape=(-1, args.num_hidden))
pred = mx.sym.FullyConnected(data=pred, num_hidden=len(vocab), name='pred')

label = mx.sym.Reshape(label, shape=(-1,))
Expand Down
46 changes: 44 additions & 2 deletions python/mxnet/initializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=too-many-branches
# pylint: disable=too-many-branches, too-many-arguments
"""Initialization helper for mxnet"""
from __future__ import absolute_import, print_function

Expand Down Expand Up @@ -413,7 +413,7 @@ def __init__(self, factor_type="avg", slope=0.25):

@register
class Bilinear(Initializer):
"""docstring for Bilinear"""
"""Initialize weight for upsampling layer"""
def __init__(self):
super(Bilinear, self).__init__()

Expand All @@ -428,3 +428,45 @@ def _init_weight(self, _, arr):
y = (i / shape[3]) % shape[2]
weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
arr[:] = weight.reshape(shape)


@register
class FusedRNN(Initializer):
    """Initialize parameters for a fused RNN layer.

    Wraps another Initializer so it can be applied to the single packed
    parameter blob used by FusedRNNCell: the blob is unpacked into named
    per-gate weights, each is initialized, and the result is packed back.

    Parameters
    ----------
    init : Initializer
        initializer applied to the unpacked weights.
    num_hidden : int
        should be the same with arguments passed to FusedRNNCell.
    num_layers : int
        should be the same with arguments passed to FusedRNNCell.
    mode : str
        should be the same with arguments passed to FusedRNNCell.
    bidirectional : bool
        should be the same with arguments passed to FusedRNNCell.
    """
    def __init__(self, init, num_hidden, num_layers, mode, bidirectional=False):
        # Accept either an Initializer instance or its JSON serialization
        # (the [class_name, kwargs] form produced by Initializer.dumps()).
        if not isinstance(init, Initializer):
            klass, kwargs = json.loads(init)
            init = _INITIALIZER_REGISTRY[klass.lower()](**kwargs)
        super(FusedRNN, self).__init__(init=init.dumps(), num_hidden=num_hidden,
                                       num_layers=num_layers, mode=mode,
                                       bidirectional=bidirectional)
        self._num_hidden = num_hidden
        self._num_layers = num_layers
        self._bidirectional = bidirectional
        self._mode = mode
        self._init = init

    def _init_weight(self, _, arr):
        """Initialize ``arr``, the packed fused-RNN parameter array, in place."""
        # Local import to avoid a circular dependency between initializer
        # and rnn modules at import time.
        from .rnn import rnn_cell
        cell = rnn_cell.FusedRNNCell(self._num_hidden, self._num_layers,
                                     self._mode, self._bidirectional, prefix='')
        # Unpack the flat blob into named per-gate weights, initialize each
        # with the wrapped initializer, then write the packed result back.
        args = cell.unpack_weights({'parameters': arr})
        for name in args:
            desc = InitDesc(name)
            self._init(desc, args[name])
        arr[:] = cell.pack_weights(args)['parameters']

53 changes: 38 additions & 15 deletions python/mxnet/rnn/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,13 @@ class BucketSentenceIter(DataIter):
name of data
label_name : str, default 'softmax_label'
name of label
layout : str
format of data and label. 'NT' means (batch_size, length)
and 'TN' means (length, batch_size).
"""
def __init__(self, sentences, batch_size, invalid_label=-1, dtype='float32',
buckets=None, data_name='data', label_name='softmax_label'):
def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
data_name='data', label_name='softmax_label', dtype='float32',
layout='NTC'):
super(BucketSentenceIter, self).__init__()
if not buckets:
buckets = [i for i, j in enumerate(np.bincount([len(s) for s in sentences]))
Expand All @@ -90,7 +94,7 @@ def __init__(self, sentences, batch_size, invalid_label=-1, dtype='float32',

ndiscard = 0
self.data = [[] for _ in buckets]
for i in xrange(len(sentences)):
for i in range(len(sentences)):
buck = bisect.bisect_left(buckets, len(sentences[i]))
if buck == len(buckets):
ndiscard += 1
Expand All @@ -103,43 +107,62 @@ def __init__(self, sentences, batch_size, invalid_label=-1, dtype='float32',

print("WARNING: discarded %d sentences longer than the largest bucket."%ndiscard)

self.default_bucket_key = max(buckets)

self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]

self.batch_size = batch_size
self.buckets = buckets
self.data_name = data_name
self.label_name = label_name
self.dtype = dtype
self.invalid_label = invalid_label
self.nddata = []
self.ndlabel = []
self.major_axis = layout.find('N')
self.default_bucket_key = max(buckets)

if self.major_axis == 0:
self.provide_data = [(data_name, (batch_size, self.default_bucket_key))]
self.provide_label = [(label_name, (batch_size, self.default_bucket_key))]
elif self.major_axis == 1:
self.provide_data = [(data_name, (self.default_bucket_key, batch_size))]
self.provide_label = [(label_name, (self.default_bucket_key, batch_size))]
else:
raise ValueError("Invalid layout %s: Must by NT (batch major) or TN (time major)")

self.idx = []
for i, buck in enumerate(self.data):
self.idx.extend([(i, j) for j in range(0, len(buck) - batch_size + 1, batch_size)])
self.curr_idx = 0

self.reset()

def reset(self):
self.curr_idx = 0
random.shuffle(self.idx)
for buck in self.data:
np.random.shuffle(buck)

self.nddata = []
self.ndlabel = []
for buck in self.data:
label = np.empty_like(buck)
label[:, :-1] = buck[:, 1:]
label[:, -1] = self.invalid_label

This comment has been minimized.

Copy link
@ForrestGan

ForrestGan Nov 15, 2017

你好,这个地方,我理解为:
label在这里应该是一维数组吧,看赋值情况,这个变成了二维数组;

buck[:, 1:]这个为特征矩阵
buck[:,1]这个为label数组吧?
辛苦看下理解是否正确?

self.nddata.append(ndarray.array(buck, dtype=self.dtype))
self.ndlabel.append(ndarray.array(label, dtype=self.dtype))

def next(self):
if self.curr_idx == len(self.idx):
raise StopIteration
i, j = self.idx[self.curr_idx]
self.curr_idx += 1

data = self.data[i][j:j+self.batch_size]
label = np.empty_like(data)
label[:, :-1] = data[:, 1:]
label[:, -1] = self.invalid_label

if self.major_axis == 1:
data = self.nddata[i][j:j+self.batch_size].T
label = self.ndlabel[i][j:j+self.batch_size].T
else:
data = self.nddata[i][j:j+self.batch_size]
label = self.ndlabel[i][j:j+self.batch_size]

return DataBatch([ndarray.array(data, dtype=self.dtype)],
[ndarray.array(label, dtype=self.dtype)],
return DataBatch([data], [label],
bucket_key=self.buckets[i],
provide_data=[(self.data_name, data.shape)],
provide_label=[(self.label_name, label.shape)])
Expand Down
Loading

0 comments on commit 22673b6

Please sign in to comment.