From 497b1d5eb3809ff8d9be25de3877c7880c88fb3c Mon Sep 17 00:00:00 2001 From: ZiyueHuang Date: Mon, 20 Nov 2017 10:28:35 +0800 Subject: [PATCH 1/2] fix custom op for backward compatibility --- python/mxnet/operator.py | 2 +- tests/python/unittest/test_operator.py | 36 ++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/python/mxnet/operator.py b/python/mxnet/operator.py index f515bf83b881..2f8c2c21c274 100644 --- a/python/mxnet/operator.py +++ b/python/mxnet/operator.py @@ -563,7 +563,7 @@ def infer_storage_type_backward(self, in_stype): list of aux stypes calculated from in_stype, in the same order as declared in list_auxiliary_states. """ - return in_stype, [in_stype[0]]*len(self.list_outputs()), \ + return in_stype, [in_stype[0]]*len(self.list_arguments()), \ [in_stype[0]]*len(self.list_auxiliary_states()) def list_outputs(self): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 55a3a5721851..5b87a9cff8d9 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3652,6 +3652,42 @@ def create_operator(self, ctx, shapes, dtypes): assert (x.grad.stype == 'csr') assert (y.stype == 'csr') assert (aux.stype == 'csr') + + # test for backward compatibility, i.e. the correctness of default implementation of + # infer storage in custom operator + class Mult(mx.operator.CustomOp): + def forward(self, is_train, req, in_data, out_data, aux): + self.assign(out_data[0], req[0], in_data[0]*in_data[1]) + + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + self.assign(in_grad[0], req[0], in_data[1]*out_grad[0]) + self.assign(in_grad[1], req[1], in_data[0]*out_grad[0]) + + @mx.operator.register("mult") + class MultProp(mx.operator.CustomOpProp): + def __init__(self): + super(MultProp, self).__init__(need_top_grad=True) + + def list_arguments(self): + return ['lhs', 'rhs'] + + def list_outputs(self): + return ['output'] + + def infer_shape(self, in_shape): + return in_shape, [in_shape[0]], [] + + def create_operator(self, ctx, shapes, dtypes): + return Mult() + + lhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10))) + rhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10))) + lhs.attach_grad() + rhs.attach_grad() + with mx.contrib.autograd.train_section(): + y = mx.nd.Custom(lhs, rhs, op_type='mult') + y.backward() + mx.nd.waitall() def test_psroipooling(): for num_rois in [1, 2]: From 3cc64c16ba46634b767711ef8c462b9b007b598e Mon Sep 17 00:00:00 2001 From: ZiyueHuang Date: Mon, 20 Nov 2017 14:22:41 +0800 Subject: [PATCH 2/2] address comments and add assert in test --- tests/python/unittest/test_operator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 5b87a9cff8d9..c2bc5193ca4f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -3652,7 +3652,6 @@ def create_operator(self, ctx, shapes, dtypes): assert (x.grad.stype == 'csr') assert (y.stype == 'csr') assert (aux.stype == 'csr') - # test for backward compatibility, i.e. the correctness of default implementation of # infer storage in custom operator class Mult(mx.operator.CustomOp): @@ -3688,6 +3687,8 @@ def create_operator(self, ctx, shapes, dtypes): y = mx.nd.Custom(lhs, rhs, op_type='mult') y.backward() mx.nd.waitall() + assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy()) + assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy()) def test_psroipooling(): for num_rois in [1, 2]: