Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add bidirectional test
Browse files Browse the repository at this point in the history
  • Loading branch information
szha authored and leezu committed Aug 4, 2018
1 parent 953e1bb commit fc6c23b
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import mxnet as mx
from mxnet import gluon
from mxnet import gluon, nd
import numpy as np
import copy
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -82,6 +82,7 @@ def test_lstm_cpu_inference():
model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
model_cell = model._unfuse()
model.initialize(mx.init.One())

y = model(x).asnumpy()
y_cell = model_cell.unroll(2, x, layout='TNC', merge_outputs=True)[0].asnumpy()

Expand Down Expand Up @@ -242,6 +243,45 @@ def test_bidirectional():
assert outs == [(10, 200), (10, 200), (10, 200)]


def test_layer_bidirectional():
class RefBiLSTM(gluon.Block):
def __init__(self, size, **kwargs):
super(RefBiLSTM, self).__init__(**kwargs)
with self.name_scope():
self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')

def forward(self, inpt):
fwd = self._lstm_fwd(inpt)
bwd_inpt = nd.flip(inpt, 0)
bwd = self._lstm_bwd(bwd_inpt)
bwd = nd.flip(bwd, 0)
return nd.concat(fwd, bwd, dim=2)

size = 7
in_size = 5
weights = {}
for d in ['l', 'r']:
weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size))
weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))

net = gluon.rnn.LSTM(size, bidirectional=True, prefix='lstm_')
ref_net = RefBiLSTM(size, prefix='lstm_')
net.initialize()
ref_net.initialize()
net_params = net.collect_params()
ref_net_params = ref_net.collect_params()
for k in weights:
net_params[k].set_data(weights[k])
ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])

data = mx.random.uniform(shape=(3, 10, in_size))
assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy())



def test_zoneout():
cell = gluon.rnn.ZoneoutCell(gluon.rnn.RNNCell(100, prefix='rnn_'), zoneout_outputs=0.5,
zoneout_states=0.5)
Expand Down

0 comments on commit fc6c23b

Please sign in to comment.