rename argument for SimpleRNN and SimpleRNNCell, fix sample code
chenfeiyu committed Aug 27, 2020
1 parent d843db8 commit d0f9fba
Showing 2 changed files with 44 additions and 29 deletions.
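
For callers, the rename is mechanical: code that passed nonlinearity= to SimpleRNN or SimpleRNNCell now passes activation=. A minimal sketch of the updated usage (a hypothetical call site, assuming the dygraph mode that the docstring examples below enable):

    import paddle

    paddle.disable_static()
    # The keyword is now `activation` (previously `nonlinearity`);
    # the accepted values, "tanh" and "relu", are unchanged.
    cell = paddle.nn.SimpleRNNCell(16, 32, activation="relu")
    x = paddle.randn((4, 16))       # batch of 4, input size 16
    prev_h = paddle.randn((4, 32))  # previous hidden state, hidden size 32
    y, h = cell(x, prev_h)          # the cell returns the new hidden state twice
    print(y.shape)                  # [4, 32]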
28 changes: 21 additions & 7 deletions python/paddle/fluid/layers/rnn.py
@@ -648,7 +648,12 @@ def _switch_grad(x, stop=False):
return (final_outputs, final_states)


-def birnn(cell_fw, cell_bw, inputs, initial_states, sequence_length, time_major,
+def birnn(cell_fw,
+          cell_bw,
+          inputs,
+          initial_states,
+          sequence_length=None,
+          time_major=False,
           **kwargs):
"""
birnn creates a bidirectional recurrent neural network specified by
@@ -686,8 +691,7 @@ def birnn(cell_fw, cell_bw, inputs, initial_states, sequence_length, time_major,
else the shape is `[batch_size, time_steps, size]`, where size is
`cell_fw.hidden_size + cell_bw.hidden_size`.
final_states (tuple): A tuple of the final states of the forward
-    cell and backward cell.
    cell and backward cell.
Examples:
@@ -696,12 +700,22 @@ def birnn(cell_fw, cell_bw, inputs, initial_states, sequence_length, time_major,
import paddle
paddle.disable_static()
-cell_fw = LSTMCell(16, 32)
-cell_bw = LSTMCell(16, 32)
-inputs = paddle.rand((2, 23, 16))
-outputs, final_states = paddle.nn.functional.birnn(cell_fw, cell_bw, inputs)
+cell_fw = paddle.nn.LSTMCell(16, 32)
+cell_bw = paddle.nn.LSTMCell(16, 32)
+inputs = paddle.rand((4, 23, 16))
+hf, cf = paddle.rand((4, 32)), paddle.rand((4, 32))
+hb, cb = paddle.rand((4, 32)), paddle.rand((4, 32))
+initial_states = ((hf, cf), (hb, cb))
+outputs, final_states = paddle.nn.functional.birnn(
+    cell_fw, cell_bw, inputs, initial_states)
"""
+if initial_states is None:
+    states_fw = cell_fw.get_initial_states(
+        batch_ref=inputs, batch_dim_idx=1 if time_major else 0)
+    states_bw = cell_bw.get_initial_states(
+        batch_ref=inputs, batch_dim_idx=1 if time_major else 0)
-states_fw, states_bw = initial_states
+else:
+    states_fw, states_bw = initial_states
outputs_fw, states_fw = rnn(cell_fw,
inputs,
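Since sequence_length and time_major are now optional and the default-state branch above builds states from the cells, birnn can be called without pre-built states. A sketch of that path (shapes follow the docstring example; passing None for initial_states is assumed to trigger get_initial_states on each cell, per the body above):

    import paddle

    paddle.disable_static()
    cell_fw = paddle.nn.LSTMCell(16, 32)
    cell_bw = paddle.nn.LSTMCell(16, 32)
    inputs = paddle.rand((4, 23, 16))  # [batch_size, time_steps, input_size]
    # None exercises the new default-state branch in birnn.
    outputs, final_states = paddle.nn.functional.birnn(
        cell_fw, cell_bw, inputs, None)
    print(outputs.shape)  # [4, 23, 64]: fw and bw outputs concatenated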
45 changes: 23 additions & 22 deletions python/paddle/nn/layer/rnn.py
@@ -280,7 +280,7 @@ class SimpleRNNCell(RNNCellBase):
Arguments:
input_size (int): The input size.
hidden_size (int): The hidden size.
-nonlinearity (str, optional): The activation in the SimpleRNN cell.
+activation (str, optional): The activation in the SimpleRNN cell.
It can be `tanh` or `relu`. Defaults to `tanh`.
weight_ih_attr (ParamAttr, optional): The parameter attribute for
`weight_ih`. Default: None.
@@ -342,7 +342,7 @@ class SimpleRNNCell(RNNCellBase):
def __init__(self,
input_size,
hidden_size,
-nonlinearity="tanh",
+activation="tanh",
weight_ih_attr=None,
weight_hh_attr=None,
bias_ih_attr=None,
@@ -371,13 +371,13 @@ def __init__(self,

self.input_size = input_size
self.hidden_size = hidden_size
-if nonlinearity not in ["tanh", "relu"]:
+if activation not in ["tanh", "relu"]:
raise ValueError(
"nonlinearity for SimpleRNNCell should be tanh or relu, "
"but get {}".format(nonlinearity))
self.nonlinearity = nonlinearity
self._nonlinear_fn = paddle.tanh \
if nonlinearity == "tanh" \
"activation for SimpleRNNCell should be tanh or relu, "
"but get {}".format(activation))
self.activation = activation
self._activation_fn = paddle.tanh \
if activation == "tanh" \
else F.relu

def forward(self, inputs, states=None):
@@ -390,7 +390,7 @@ def forward(self, inputs, states=None):
h2h = paddle.matmul(pre_h, self.weight_hh, transpose_y=True)
if self.bias_hh is not None:
h2h += self.bias_hh
-h = self._nonlinear_fn(i2h + h2h)
+h = self._activation_fn(i2h + h2h)
return h, h

@property
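
The rename does not change the cell's math: the step is still new_h = act(x·W_ihᵀ + b_ih + h·W_hhᵀ + b_hh). A condensed, standalone sketch of that computation (argument names follow the attributes in the diff; this is an illustration, not the actual class method):

    import paddle

    def simple_rnn_step(x, h, weight_ih, weight_hh, bias_ih, bias_hh,
                        act=paddle.tanh):
        # i2h: input-to-hidden projection; h2h: hidden-to-hidden projection.
        i2h = paddle.matmul(x, weight_ih, transpose_y=True) + bias_ih
        h2h = paddle.matmul(h, weight_hh, transpose_y=True) + bias_hh
        new_h = act(i2h + h2h)  # the class picks paddle.tanh or F.relu here
        return new_h, new_h     # the new state doubles as the cell output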
@@ -479,9 +479,10 @@ class LSTMCell(RNNCellBase):
x = paddle.randn((4, 16))
prev_h = paddle.randn((4, 32))
+prev_c = paddle.randn((4, 32))
cell = paddle.nn.LSTMCell(16, 32)
-y, h = cell(x, prev_h)
+y, (h, c) = cell(x, (prev_h, prev_c))
"""

@@ -758,7 +759,7 @@ class RNN(Layer):
prev_h = paddle.randn((4, 32))
cell = paddle.nn.SimpleRNNCell(16, 32)
-rnn = paddle.RNN(cell)
+rnn = paddle.nn.RNN(cell)
outputs, final_states = rnn(inputs, prev_h)
"""
@@ -848,9 +849,9 @@ class BiRNN(Layer):
import paddle
paddle.disable_static()
-cell_fw = LSTMCell(16, 32)
-cell_bw = LSTMCell(16, 32)
-rnn = BidirectionalRNN(cell_fw, cell_bw)
+cell_fw = paddle.nn.LSTMCell(16, 32)
+cell_bw = paddle.nn.LSTMCell(16, 32)
+rnn = paddle.nn.BiRNN(cell_fw, cell_bw)
inputs = paddle.rand((2, 23, 16))
outputs, final_states = rnn(inputs)
@@ -953,7 +954,7 @@ class SimpleRNN(RNNMixin):
input_size (int): The input size for the first layer's cell.
hidden_size (int): The hidden size for each layer's cell.
num_layers (int, optional): Number of layers. Defaults to 1.
-nonlinearity (str, optional): The activation in each SimpleRNN cell. It can be
+activation (str, optional): The activation in each SimpleRNN cell. It can be
`tanh` or `relu`. Defaults to `tanh`.
direction (str, optional): The direction of the network. It can be "forward",
"backward" and "bidirectional". Defaults to "forward".
@@ -1018,7 +1019,7 @@ def __init__(self,
input_size,
hidden_size,
num_layers=1,
-nonlinearity="tanh",
+activation="tanh",
direction="forward",
dropout=0.,
time_major=False,
@@ -1031,29 +1032,29 @@ def __init__(self,

if direction in ["forward", "backward"]:
is_reverse = direction == "backward"
-cell = SimpleRNNCell(input_size, hidden_size, nonlinearity,
+cell = SimpleRNNCell(input_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr, bias_ih_attr,
bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
for i in range(1, num_layers):
-cell = SimpleRNNCell(hidden_size, hidden_size, nonlinearity,
+cell = SimpleRNNCell(hidden_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr,
bias_ih_attr, bias_hh_attr)
self.append(RNN(cell, is_reverse, time_major))
elif direction == "bidirectional":
-cell_fw = SimpleRNNCell(input_size, hidden_size, nonlinearity,
+cell_fw = SimpleRNNCell(input_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr,
bias_ih_attr, bias_hh_attr)
-cell_bw = SimpleRNNCell(input_size, hidden_size, nonlinearity,
+cell_bw = SimpleRNNCell(input_size, hidden_size, activation,
weight_ih_attr, weight_hh_attr,
bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
for i in range(1, num_layers):
cell_fw = SimpleRNNCell(
-2 * hidden_size, hidden_size, nonlinearity, weight_ih_attr,
+2 * hidden_size, hidden_size, activation, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
cell_bw = SimpleRNNCell(
-2 * hidden_size, hidden_size, nonlinearity, weight_ih_attr,
+2 * hidden_size, hidden_size, activation, weight_ih_attr,
weight_hh_attr, bias_ih_attr, bias_hh_attr)
self.append(BiRNN(cell_fw, cell_bw, time_major))
else:
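At the layer level the rename reads the same; a sketch of constructing the multi-layer bidirectional variant with the new keyword (shapes are illustrative):

    import paddle

    paddle.disable_static()
    rnn = paddle.nn.SimpleRNN(16, 32, num_layers=2, activation="relu",
                              direction="bidirectional")
    inputs = paddle.rand((4, 23, 16))  # [batch_size, time_steps, input_size]
    outputs, final_states = rnn(inputs)
    print(outputs.shape)  # [4, 23, 64]: both directions concatenated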
