This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

fix custom op for backward compatibility #8721

Closed · wants to merge 2 commits
2 changes: 1 addition & 1 deletion python/mxnet/operator.py
@@ -563,7 +563,7 @@ def infer_storage_type_backward(self, in_stype):
        list of aux stypes calculated from in_stype,
        in the same order as declared in list_auxiliary_states.
        """
-       return in_stype, [in_stype[0]]*len(self.list_outputs()), \
+       return in_stype, [in_stype[0]]*len(self.list_arguments()), \
            [in_stype[0]]*len(self.list_auxiliary_states())

    def list_outputs(self):
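
Why the one-word change matters: backward produces one input gradient per argument, so the default infer_storage_type_backward must size the in-grad stype list with len(self.list_arguments()), not len(self.list_outputs()); for any operator whose argument count differs from its output count, the old default returned a wrongly sized list. A minimal sketch under that reading (hypothetical TwoInputProp, not part of this PR):

    import mxnet as mx

    # Hypothetical two-input, one-output op that relies on the default
    # infer_storage_type_backward (no override).
    class TwoInputProp(mx.operator.CustomOpProp):
        def list_arguments(self):
            return ['lhs', 'rhs']   # two inputs -> backward needs two in-grad stypes
        def list_outputs(self):
            return ['output']       # one output

    # Old default sized the in-grad stype list by len(list_outputs())  -> 1 entry.
    # New default sizes it by len(list_arguments()) -> 2 entries, one per input.
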
37 changes: 37 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -3652,6 +3652,43 @@ def create_operator(self, ctx, shapes, dtypes):
    assert (x.grad.stype == 'csr')
    assert (y.stype == 'csr')
    assert (aux.stype == 'csr')
    # test for backward compatibility, i.e. the correctness of the default
    # implementation of storage-type inference in a custom operator
    class Mult(mx.operator.CustomOp):
        def forward(self, is_train, req, in_data, out_data, aux):
            self.assign(out_data[0], req[0], in_data[0]*in_data[1])

        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
            self.assign(in_grad[0], req[0], in_data[1]*out_grad[0])
            self.assign(in_grad[1], req[1], in_data[0]*out_grad[0])

    @mx.operator.register("mult")
    class MultProp(mx.operator.CustomOpProp):
        def __init__(self):
            super(MultProp, self).__init__(need_top_grad=True)
Review comment from a member: Found an issue with custom_op when used with need_top_grad=False. Fixing here: #8725


        def list_arguments(self):
            return ['lhs', 'rhs']

        def list_outputs(self):
            return ['output']

        def infer_shape(self, in_shape):
            return in_shape, [in_shape[0]], []

        def create_operator(self, ctx, shapes, dtypes):
            return Mult()

    lhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
    rhs = mx.nd.array(np.random.uniform(-1, 1, size=(4, 10)))
    lhs.attach_grad()
    rhs.attach_grad()
    with mx.contrib.autograd.train_section():
        y = mx.nd.Custom(lhs, rhs, op_type='mult')
        y.backward()
    mx.nd.waitall()
    assert_almost_equal(rhs.asnumpy(), lhs.grad.asnumpy())
    assert_almost_equal(lhs.asnumpy(), rhs.grad.asnumpy())
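
For reference, the math behind the asserts: y = lhs * rhs and y.backward() seeds the output gradient with ones, so the product rule gives d(lhs) = rhs and d(rhs) = lhs, which is exactly what the two assert_almost_equal calls verify. A numpy-only cross-check of that gradient math (illustrative, not part of the PR):

    import numpy as np

    a = np.random.uniform(-1, 1, size=(4, 10))   # stands in for lhs
    b = np.random.uniform(-1, 1, size=(4, 10))   # stands in for rhs
    g = np.ones_like(a)                          # out_grad seeded by y.backward()
    assert np.allclose(b * g, b)                 # d(lhs) = rhs * out_grad = rhs
    assert np.allclose(a * g, a)                 # d(rhs) = lhs * out_grad = lhs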

def test_psroipooling():
    for num_rois in [1, 2]: