Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
rename to dynamic_unroll.
Browse files Browse the repository at this point in the history
  • Loading branch information
BullDemonKing committed Feb 15, 2019
1 parent 363407f commit 4c96b8a
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/mkldnn
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 0f053c to 290226
10 changes: 5 additions & 5 deletions python/mxnet/gluon/contrib/rnn/rnn_cell.py
Expand Up @@ -323,8 +323,8 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
# pylint: enable= arguments-differ


def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
layout='TNC', valid_length=None):
def dynamic_unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
layout='TNC', valid_length=None):
"""Unrolls an RNN cell across time steps.
Currently, 'TNC' is a preferred layout. unroll on the input of this layout
Expand Down Expand Up @@ -376,9 +376,9 @@ def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
>>> state_shape = (batch_size, input_size)
>>> states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(2)]
>>> valid_length = mx.nd.array([2, 3])
>>> output, states = mx.gluon.contrib.rnn.rnn_cell.unroll(cell, rnn_data, states,
valid_length=valid_length,
layout='TNC')
>>> output, states = mx.gluon.contrib.rnn.rnn_cell.dynamic_unroll(cell, rnn_data, states,
valid_length=valid_length,
layout='TNC')
>>> print(output)
[[[ 0.00767238 0.00023103 0.03973929 -0.00925503 -0.05660512]
[ 0.00881535 0.05428379 -0.02493718 -0.01834097 0.02189514]]
Expand Down
5 changes: 3 additions & 2 deletions tests/python/unittest/test_gluon_contrib.py
Expand Up @@ -324,8 +324,9 @@ def __init__(self, cell_type, hidden_size, layout, prefix=None, params=None):
def hybrid_forward(self, F, inputs, states, valid_length):
if isinstance(valid_length, list) and len(valid_length) == 0:
valid_length = None
return contrib.rnn.rnn_cell.unroll(self.cell, inputs, states,
valid_length=valid_length, layout=self.layout)
return contrib.rnn.rnn_cell.dynamic_unroll(self.cell, inputs, states,
valid_length=valid_length,
layout=self.layout)

def check_unroll(cell_type, num_states, layout):
batch_size = 20
Expand Down

0 comments on commit 4c96b8a

Please sign in to comment.