Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/operator/rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1583,8 +1583,11 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
int dtype = in_types[rnn_enum::kData];
int itype = dtype;
if (param.use_sequence_length) {
itype = in_types[rnn_enum::kSequenceLength];
if (param.mode == rnn_enum::kLstm) itype -= 1;
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
if (param.mode != rnn_enum::kLstm) {
seq_len_input_idx -= 1;
}
itype = in_types[seq_len_input_idx];
}

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
Expand Down Expand Up @@ -1649,7 +1652,7 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
// Hacky. This relies on fact that seq-len type is either the last input,
// or we aren't using seq-len input and this type should be same as dtype.
// Would prefer direct access to RNNParam object here but not sure how to get.
int itype = inputs[inputs.size()-1].type_flag_;
int itype = outputs[outputs.size()-1].type_flag_;

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
MSHADOW_TYPE_SWITCH(itype, IType, {
Expand All @@ -1669,6 +1672,15 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
}
}


if (param.use_sequence_length) {
size_t seq_len_input_idx = rnn_enum::kSequenceLength;
if (param.mode != rnn_enum::kLstm) {
seq_len_input_idx -= 1;
}
in_data.push_back(outputs[seq_len_input_idx]);
}

op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
});
});
Expand Down
58 changes: 36 additions & 22 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,19 +227,6 @@ def forward(self, inpt):


def check_layer_bidirectional_varseqlen(size, in_size):
class RefBiLSTMVarSeqLen(gluon.Block):
def __init__(self, size, **kwargs):
super(RefBiLSTMVarSeqLen, self).__init__(**kwargs)
with self.name_scope():
self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')

def forward(self, inpt, sequence_length):
fwd = self._lstm_fwd(inpt)
bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True)
bwd = self._lstm_bwd(bwd_inpt)
bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True)
return nd.concat(fwd, bwd, dim=2)
weights = {}
for d in ['l', 'r']:
weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
Expand All @@ -248,31 +235,58 @@ def forward(self, inpt, sequence_length):
weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))

net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_')
ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_')
ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False, prefix='lstm_ref_')
net.initialize()
ref_net.initialize()
net_params = net.collect_params()
ref_net_params = ref_net.collect_params()
for k in weights:
net_params[k].set_data(weights[k])
ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])

ref_net_params[k.replace("lstm_", "lstm_ref_")].set_data(weights[k])

batch_size = 10
num_timesteps = 11
data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
data_np = data.asnumpy()

# TODO: figure out why int32 doesn't work here
sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")

net_output = net(data, sequence_length=sequence_length).asnumpy()
ref_net_output = ref_net(data, sequence_length).asnumpy()
sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("int32")
sequence_length_np = sequence_length.asnumpy().astype("int32")

# Reference net is processing batch elements one at a time, so that it is "perfectly sized"
# Because of that, we need to accumulate gradients in reference net.
for p in ref_net.collect_params().values():
p.grad_req = 'add'

ref_net_output = []
with autograd.record():
net_output = net(data.copy(), sequence_length=sequence_length.copy())

for b in range(batch_size):
data_slice = mx.nd.array(data_np[:sequence_length_np[b], b, :]).reshape(sequence_length_np[b], 1, in_size)
ref_output_slice = ref_net(data_slice)
ref_net_output.append(ref_output_slice)

net_output_np = net_output.asnumpy()

# TODO: test state return value as well output
# Only compare the valid sections for each batch entry
for b in range(batch_size):
assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
assert_allclose(net_output_np[:sequence_length_np[b], b], ref_net_output[b].asnumpy().squeeze(1),
rtol=1e-2, atol=1e-6)

# Now test backward
net_output.backward()

for ref_output_slice in ref_net_output:
ref_output_slice.backward()

ref_net_params = ref_net.collect_params()

for k in weights:
net_grad = net_params[k].grad()
ref_net_grad = ref_net_params[k.replace('lstm_', 'lstm_ref_')].grad()
assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
rtol=1e-2, atol=1e-6)


@with_seed()
Expand Down