This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
/
rnn_layer.py
528 lines (461 loc) · 23.1 KB
/
rnn_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=no-member, invalid-name, protected-access, no-self-use
# pylint: disable=too-many-branches, too-many-arguments, no-self-use
# pylint: disable=too-many-lines, arguments-differ
"""Definition of various recurrent neural network layers."""
from __future__ import print_function
__all__ = ['RNN', 'LSTM', 'GRU']
from ... import ndarray
from .. import Block
from . import rnn_cell
class _RNNLayer(Block):
    """Implementation of recurrent layers.

    Shared base class for `RNN`, `LSTM` and `GRU`. Subclasses pass a ``mode``
    string ('rnn_relu', 'rnn_tanh', 'lstm' or 'gru') and implement
    `state_info`. Parameters are created once per (layer, direction) and are
    shared with an unrolled stack of gluon cells used as a fallback path.
    """
    def __init__(self, hidden_size, num_layers, layout,
                 dropout, bidirectional, input_size,
                 i2h_weight_initializer, h2h_weight_initializer,
                 i2h_bias_initializer, h2h_bias_initializer,
                 mode, **kwargs):
        super(_RNNLayer, self).__init__(**kwargs)
        assert layout == 'TNC' or layout == 'NTC', \
            "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._mode = mode
        self._layout = layout
        self._dropout = dropout
        self._dir = 2 if bidirectional else 1
        self._input_size = input_size
        self._i2h_weight_initializer = i2h_weight_initializer
        self._h2h_weight_initializer = h2h_weight_initializer
        self._i2h_bias_initializer = i2h_bias_initializer
        self._h2h_bias_initializer = h2h_bias_initializer
        # Number of gate blocks stacked along axis 0 of every weight and bias.
        self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

        # Flat parameter lists ordered layer-major: entries 0..dir-1 belong to
        # layer 0, the next `dir` entries to layer 1, and so on.
        self.i2h_weight = []
        self.h2h_weight = []
        self.i2h_bias = []
        self.h2h_bias = []
        ng, ni, nh = self._gates, input_size, hidden_size
        for i in range(num_layers):
            # 'l'/'r' name the two directions; only 'l' when unidirectional.
            for j in (['l', 'r'] if self._dir == 2 else ['l']):
                self.i2h_weight.append(
                    self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni),
                                    init=i2h_weight_initializer,
                                    allow_deferred_init=True))
                self.h2h_weight.append(
                    self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh),
                                    init=h2h_weight_initializer,
                                    allow_deferred_init=True))
                self.i2h_bias.append(
                    self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,),
                                    init=i2h_bias_initializer,
                                    allow_deferred_init=True))
                self.h2h_bias.append(
                    self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,),
                                    init=h2h_bias_initializer,
                                    allow_deferred_init=True))
            # Layers above the first consume the previous layer's output,
            # whose width is hidden_size * num_directions.
            ni = nh * self._dir

        self._unfused = self._unfuse()

    def __repr__(self):
        s = '{name}({mapping}, {_layout}'
        if self._num_layers != 1:
            s += ', num_layers={_num_layers}'
        if self._dropout != 0:
            s += ', dropout={_dropout}'
        if self._dir == 2:
            s += ', bidirectional'
        s += ')'
        shape = self.i2h_weight[0].shape
        # shape[1] is 0 while the input size is still deferred; show None then.
        mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates)
        return s.format(name=self.__class__.__name__,
                        mapping=mapping,
                        **self.__dict__)

    def state_info(self, batch_size=0):
        """Return a list of dicts (with at least a 'shape' key), one per
        recurrent state. Must be implemented by subclasses."""
        raise NotImplementedError

    def _unfuse(self):
        """Unfuses the fused RNN into a stack of rnn cells.

        The returned `SequentialRNNCell` is built with this layer's prefix and
        ParameterDict, so its cells reuse the same parameters created in
        `__init__`.
        """
        get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
                                                                  activation='relu',
                                                                  **kwargs),
                    'rnn_tanh': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
                                                                  activation='tanh',
                                                                  **kwargs),
                    'lstm': lambda **kwargs: rnn_cell.LSTMCell(self._hidden_size,
                                                               **kwargs),
                    'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size,
                                                             **kwargs)}[self._mode]

        stack = rnn_cell.SequentialRNNCell(prefix=self.prefix, params=self.params)
        with stack.name_scope():
            ni = self._input_size
            for i in range(self._num_layers):
                kwargs = {'input_size': ni,
                          'i2h_weight_initializer': self._i2h_weight_initializer,
                          'h2h_weight_initializer': self._h2h_weight_initializer,
                          'i2h_bias_initializer': self._i2h_bias_initializer,
                          'h2h_bias_initializer': self._h2h_bias_initializer}
                if self._dir == 2:
                    stack.add(rnn_cell.BidirectionalCell(
                        get_cell(prefix='l%d_'%i, **kwargs),
                        get_cell(prefix='r%d_'%i, **kwargs)))
                else:
                    stack.add(get_cell(prefix='l%d_'%i, **kwargs))

                # Dropout between layers only, never after the last one.
                if self._dropout > 0 and i != self._num_layers - 1:
                    stack.add(rnn_cell.DropoutCell(self._dropout))

                ni = self._hidden_size * self._dir

        return stack

    def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
        """Initial state for this cell.

        Parameters
        ----------
        batch_size: int
            Only required for `NDArray` API. Size of the batch ('N' in layout).
        func : callable, default `ndarray.zeros`
            Function for creating initial state.

            For Symbol API, func can be `symbol.zeros`, `symbol.uniform`,
            `symbol.var` etc. Use `symbol.var` if you want to directly
            feed input as states.

            For NDArray API, func can be `ndarray.zeros`, `ndarray.ones`, etc.
        **kwargs :
            Additional keyword arguments passed to func. For example
            `mean`, `std`, `dtype`, etc.

        Returns
        -------
        states : nested list of Symbol
            Starting states for the first RNN step.
        """
        states = []
        for i, info in enumerate(self.state_info(batch_size)):
            if info is not None:
                info.update(kwargs)
            else:
                info = kwargs
            states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
        return states

    def forward(self, inputs, states=None):
        """Run the recurrent layer over `inputs`.

        Parameters
        ----------
        inputs : NDArray
            Input sequence laid out according to `self._layout`.
        states : NDArray or list of NDArray, optional
            Initial recurrent states; zeros are used when omitted.

        Returns
        -------
        The output tensor alone when `states` was None, otherwise a tuple
        ``(output, new_states)``.

        Raises
        ------
        ValueError
            If the number of states or any state's shape does not match
            `state_info`.
        """
        batch_size = inputs.shape[self._layout.find('N')]
        skip_states = states is None
        if skip_states:
            states = self.begin_state(batch_size, ctx=inputs.context)
        if isinstance(states, ndarray.NDArray):
            states = [states]
        expected_info = self.state_info(batch_size)
        # zip() below would silently drop missing/extra states; fail clearly.
        if len(states) != len(expected_info):
            raise ValueError(
                "Invalid number of recurrent states. Expecting %d, got %d."%(
                    len(expected_info), len(states)))
        for state, info in zip(states, expected_info):
            if state.shape != info['shape']:
                raise ValueError(
                    "Invalid recurrent state shape. Expecting %s, got %s."%(
                        str(info['shape']), str(state.shape)))
        if self._input_size == 0:
            # Only the first layer's i2h weights depend on the (deferred) input
            # size; complete their initialization from the actual input width.
            for i in range(self._dir):
                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                self.i2h_weight[i]._finish_deferred_init()
        # Fused kernel on GPU (and for LSTM on CPU); otherwise unrolled cells.
        if inputs.context.device_type == 'gpu' or self._mode == 'lstm':
            out = self._forward_kernel(inputs, states)
        else:
            out = self._forward(inputs, states)

        # out is (output, state)
        return out[0] if skip_states else out

    def _forward(self, inputs, states):
        """forward using gluon cell"""
        ns = len(states)
        axis = self._layout.find('T')
        # Split each stacked state along axis 0 and interleave them per layer
        # into the flat tuple the unrolled cell stack expects.
        states = sum(zip(*((j for j in i) for i in states)), ())
        outputs, states = self._unfused.unroll(
            inputs.shape[axis], inputs, states,
            layout=self._layout, merge_outputs=True)
        # Re-stack the per-layer states back into one tensor per state kind.
        new_states = []
        for i in range(ns):
            state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
            new_states.append(state)

        return outputs, new_states

    def _forward_kernel(self, inputs, states):
        """ forward using CUDNN or CPU kernel"""
        if self._layout == 'NTC':
            inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
        ctx = inputs.context
        # Pack all weights (i2h, h2h per layer/direction) followed by all
        # biases into one flat vector -- presumably the layout ndarray.RNN
        # expects; confirm against the operator docs before reordering.
        params = sum(zip(self.i2h_weight, self.h2h_weight), ())
        params += sum(zip(self.i2h_bias, self.h2h_bias), ())
        params = (i.data(ctx).reshape((-1,)) for i in params)
        params = ndarray.concat(*params, dim=0)

        rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
                          num_layers=self._num_layers, bidirectional=self._dir == 2,
                          p=self._dropout, state_outputs=True, mode=self._mode)

        if self._mode == 'lstm':
            outputs, states = rnn[0], [rnn[1], rnn[2]]
        else:
            outputs, states = rnn[0], [rnn[1]]

        if self._layout == 'NTC':
            outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)

        return outputs, states
class RNN(_RNNLayer):
    r"""Applies a multi-layer Elman RNN with `tanh` or `ReLU` non-linearity to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        h_t = \tanh(w_{ih} * x_t + b_{ih}  +  w_{hh} * h_{(t-1)} + b_{hh})

    where :math:`h_t` is the hidden state at time `t`, and :math:`x_t` is the output
    of the previous layer at time `t` or :math:`input_t` for the first layer.
    If activation='relu', then `ReLU` is used instead of `tanh`.

    Parameters
    ----------
    hidden_size: int
        The number of features in the hidden state h.
    num_layers: int, default 1
        Number of recurrent layers.
    activation: {'relu' or 'tanh'}, default 'relu'
        The activation function to use.
    layout : str, default 'TNC'
        The format of input and output tensors. T, N and C stand for
        sequence length, batch size, and feature dimensions respectively.
    dropout: float, default 0
        If non-zero, introduces a dropout layer on the outputs of each
        RNN layer except the last layer.
    bidirectional: bool, default False
        If `True`, becomes a bidirectional RNN.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    input_size: int, default 0
        The number of expected features in the input x.
        If not specified, it will be inferred from input.
    prefix : str or None
        Prefix of this `Block`.
    params : ParameterDict or None
        Shared Parameters for this `Block`.


    Inputs:
        - **data**: input tensor with shape `(sequence_length, batch_size, input_size)`
          when `layout` is "TNC". For other layouts, dimensions are permuted accordingly
          using transpose() operator which adds performance overhead. Consider creating
          batches in TNC layout during data batching step.

        - **states**: initial recurrent state tensor with shape
          `(num_layers, batch_size, num_hidden)`. If `bidirectional` is True,
          shape will instead be `(2*num_layers, batch_size, num_hidden)`. If
          `states` is None, zeros will be used as default begin states.

    Outputs:
        - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)`
          when `layout` is "TNC". If `bidirectional` is True, output shape will instead
          be `(sequence_length, batch_size, 2*num_hidden)`
        - **out_states**: output recurrent state tensor with the same shape as `states`.
          If `states` is None `out_states` will not be returned.


    Examples
    --------
    >>> layer = mx.gluon.rnn.RNN(100, 3)
    >>> layer.initialize()
    >>> input = mx.nd.random.uniform(shape=(5, 3, 10))
    >>> # by default zeros are used as begin state
    >>> output = layer(input)
    >>> # manually specify begin state.
    >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> output, hn = layer(input, h0)
    """
    def __init__(self, hidden_size, num_layers=1, activation='relu',
                 layout='TNC', dropout=0, bidirectional=False,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 input_size=0, **kwargs):
        # Validate here; otherwise an unknown activation only surfaces as a bare
        # KeyError from the mode -> gate-count lookup in _RNNLayer.__init__.
        assert activation in ['relu', 'tanh'], \
            "activation must be either 'relu' or 'tanh'. Received %s"%activation
        super(RNN, self).__init__(hidden_size, num_layers, layout,
                                  dropout, bidirectional, input_size,
                                  i2h_weight_initializer, h2h_weight_initializer,
                                  i2h_bias_initializer, h2h_bias_initializer,
                                  'rnn_'+activation, **kwargs)

    def state_info(self, batch_size=0):
        """Describe the single recurrent state: shape
        (num_layers * num_directions, batch_size, hidden_size), layout 'LNC'."""
        return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                 '__layout__': 'LNC'}]
class LSTM(_RNNLayer):
    r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        f_t = sigmoid(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\
        o_t = sigmoid(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\
        c_t = f_t * c_{(t-1)} + i_t * g_t \\
        h_t = o_t * \tanh(c_t)
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the
    cell state at time `t`, :math:`x_t` is the hidden state of the previous
    layer at time `t` or :math:`input_t` for the first layer, and :math:`i_t`,
    :math:`f_t`, :math:`g_t`, :math:`o_t` are the input, forget, cell, and
    output gates, respectively.

    Parameters
    ----------
    hidden_size: int
        The number of features in the hidden state h.
    num_layers: int, default 1
        Number of recurrent layers.
    layout : str, default 'TNC'
        The format of input and output tensors. T, N and C stand for
        sequence length, batch size, and feature dimensions respectively.
    dropout: float, default 0
        If non-zero, introduces a dropout layer on the outputs of each
        RNN layer except the last layer.
    bidirectional: bool, default False
        If `True`, becomes a bidirectional RNN.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer, default 'zeros'
        Initializer for the bias vector.
    input_size: int, default 0
        The number of expected features in the input x.
        If not specified, it will be inferred from input.
    prefix : str or None
        Prefix of this `Block`.
    params : `ParameterDict` or `None`
        Shared Parameters for this `Block`.


    Inputs:
        - **data**: input tensor with shape `(sequence_length, batch_size, input_size)`
          when `layout` is "TNC". For other layouts, dimensions are permuted accordingly
          using transpose() operator which adds performance overhead. Consider creating
          batches in TNC layout during data batching step.
        - **states**: a list of two initial recurrent state tensors. Each has shape
          `(num_layers, batch_size, num_hidden)`. If `bidirectional` is True,
          shape will instead be `(2*num_layers, batch_size, num_hidden)`. If
          `states` is None, zeros will be used as default begin states.

    Outputs:
        - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)`
          when `layout` is "TNC". If `bidirectional` is True, output shape will instead
          be `(sequence_length, batch_size, 2*num_hidden)`
        - **out_states**: a list of two output recurrent state tensors with the same
          shape as in `states`. If `states` is None `out_states` will not be returned.


    Examples
    --------
    >>> layer = mx.gluon.rnn.LSTM(100, 3)
    >>> layer.initialize()
    >>> input = mx.nd.random.uniform(shape=(5, 3, 10))
    >>> # by default zeros are used as begin state
    >>> output = layer(input)
    >>> # manually specify begin state.
    >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> c0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> output, hn = layer(input, [h0, c0])
    """
    def __init__(self, hidden_size, num_layers=1, layout='TNC',
                 dropout=0, bidirectional=False, input_size=0,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 **kwargs):
        super(LSTM, self).__init__(hidden_size, num_layers, layout,
                                   dropout, bidirectional, input_size,
                                   i2h_weight_initializer, h2h_weight_initializer,
                                   i2h_bias_initializer, h2h_bias_initializer,
                                   'lstm', **kwargs)

    def state_info(self, batch_size=0):
        """Describe the two recurrent states (used as [h0, c0] in the class
        docstring example), each of shape
        (num_layers * num_directions, batch_size, hidden_size), layout 'LNC'."""
        return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                 '__layout__': 'LNC'},
                {'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                 '__layout__': 'LNC'}]
class GRU(_RNNLayer):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
        r_t = sigmoid(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
        i_t = sigmoid(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\
        n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)}+ b_{hn})) \\
        h_t = (1 - i_t) * n_t + i_t * h_{(t-1)} \\
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the hidden
    state of the previous layer at time `t` or :math:`input_t` for the first layer,
    and :math:`r_t`, :math:`i_t`, :math:`n_t` are the reset, input, and new gates, respectively.

    Parameters
    ----------
    hidden_size: int
        The number of features in the hidden state h.
    num_layers: int, default 1
        Number of recurrent layers.
    layout : str, default 'TNC'
        The format of input and output tensors. T, N and C stand for
        sequence length, batch size, and feature dimensions respectively.
    dropout: float, default 0
        If non-zero, introduces a dropout layer on the outputs of each
        RNN layer except the last layer.
    bidirectional: bool, default False
        If True, becomes a bidirectional RNN.
    i2h_weight_initializer : str or Initializer
        Initializer for the input weights matrix, used for the linear
        transformation of the inputs.
    h2h_weight_initializer : str or Initializer
        Initializer for the recurrent weights matrix, used for the linear
        transformation of the recurrent state.
    i2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    h2h_bias_initializer : str or Initializer
        Initializer for the bias vector.
    input_size: int, default 0
        The number of expected features in the input x.
        If not specified, it will be inferred from input.
    prefix : str or None
        Prefix of this `Block`.
    params : ParameterDict or None
        Shared Parameters for this `Block`.


    Inputs:
        - **data**: input tensor with shape `(sequence_length, batch_size, input_size)`
          when `layout` is "TNC". For other layouts, dimensions are permuted accordingly
          using transpose() operator which adds performance overhead. Consider creating
          batches in TNC layout during data batching step.
        - **states**: initial recurrent state tensor with shape
          `(num_layers, batch_size, num_hidden)`. If `bidirectional` is True,
          shape will instead be `(2*num_layers, batch_size, num_hidden)`. If
          `states` is None, zeros will be used as default begin states.

    Outputs:
        - **out**: output tensor with shape `(sequence_length, batch_size, num_hidden)`
          when `layout` is "TNC". If `bidirectional` is True, output shape will instead
          be `(sequence_length, batch_size, 2*num_hidden)`
        - **out_states**: output recurrent state tensor with the same shape as `states`.
          If `states` is None `out_states` will not be returned.


    Examples
    --------
    >>> layer = mx.gluon.rnn.GRU(100, 3)
    >>> layer.initialize()
    >>> input = mx.nd.random.uniform(shape=(5, 3, 10))
    >>> # by default zeros are used as begin state
    >>> output = layer(input)
    >>> # manually specify begin state.
    >>> h0 = mx.nd.random.uniform(shape=(3, 3, 100))
    >>> output, hn = layer(input, h0)
    """
    def __init__(self, hidden_size, num_layers=1, layout='TNC',
                 dropout=0, bidirectional=False, input_size=0,
                 i2h_weight_initializer=None, h2h_weight_initializer=None,
                 i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
                 **kwargs):
        super(GRU, self).__init__(hidden_size, num_layers, layout,
                                  dropout, bidirectional, input_size,
                                  i2h_weight_initializer, h2h_weight_initializer,
                                  i2h_bias_initializer, h2h_bias_initializer,
                                  'gru', **kwargs)

    def state_info(self, batch_size=0):
        """Describe the single recurrent state: shape
        (num_layers * num_directions, batch_size, hidden_size), layout 'LNC'."""
        return [{'shape': (self._num_layers * self._dir, batch_size, self._hidden_size),
                 '__layout__': 'LNC'}]