[MXNET-92] Support float16 in L2Normalization operator #10078
Conversation
@@ -294,7 +321,13 @@ class L2NormalizationProp : public OperatorProperty {
     return {ResourceRequest::kTempSpace};
   }

-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
Does something still call this?
Honestly I'm not really sure; a simple grep for "CreateOperator" under src/ turned up only this usage:
nnvm/legacy_op_util.cc:297: return OpStatePtr::Create(prop.ptr->CreateOperatorEx(ctx, &is, &it),
Ok, I see it is masked by your override of CreateOperatorEx()
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                                std::vector<int> *in_type) const {
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
Since you're overriding CreateOperatorEx(), then what ends up calling InferShape(), InferType(), which is normally done by the base class' CreateOperatorEx()?
I see. I just added calls to InferType and InferShape; the PR will be updated soon.
Just FYI, usually, DType is determined within the Forward() and Backward() functions using the type switch from the actual input blob at runtime.
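To illustrate the runtime type switch described above, here is a minimal Python sketch (not MXNet code; `l2norm_forward` is a hypothetical name): the forward pass inspects the dtype of the actual input blob and computes in that type, much as MSHADOW_REAL_TYPE_SWITCH dispatches a templated kernel inside a C++ Forward().

```python
import numpy as np

def l2norm_forward(in_data, eps=1e-10):
    # Hypothetical sketch: pick the compute type from the actual input blob
    # at runtime, the way MSHADOW_REAL_TYPE_SWITCH does in a C++ Forward().
    if in_data.dtype not in (np.float16, np.float32, np.float64):
        raise TypeError("unsupported dtype: %s" % in_data.dtype)
    norm = np.sqrt(np.sum(in_data * in_data) + eps)
    # The output keeps the input's dtype, whatever it was.
    return (in_data / norm).astype(in_data.dtype)
```

With a float16 input the output stays float16; no dtype has to be fixed at operator-creation time under this scheme.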
I am not saying you need to change it, but if that were the case, you wouldn't have to override CreateOperatorEx(), which has nontrivial logic.
Where is InferShape(), InferType() being called?
Force-pushed b86f6d7 to a60d44b (compare)
I think this PR should be ready for merge. @rahul003, would you please take a look at it to double-check? Thanks!
@@ -26,13 +26,22 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<cpu>(L2NormalizationParam param) {
-  return new L2NormalizationOp<cpu>(param);
+Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
is it done this way elsewhere?
ok
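The dtype-parameterized CreateOp<cpu>(param, dtype) factory discussed above can be mimicked in Python. This is a hedged sketch, not MXNet API: the names are made up, and the type-code table is an assumption modeled on mshadow's kFloat32/kFloat64/kFloat16 codes. The factory returns an operator specialized for the requested dtype instead of one hard-coded float32 implementation.

```python
import numpy as np

# Assumed type codes, modeled on mshadow (kFloat32=0, kFloat64=1, kFloat16=2).
_DTYPE_BY_CODE = {0: np.float32, 1: np.float64, 2: np.float16}

def create_op(param, dtype_code):
    # Hypothetical analogue of CreateOp<cpu>(param, dtype): reject unknown
    # codes up front, then build an operator bound to the concrete dtype.
    if dtype_code not in _DTYPE_BY_CODE:
        raise ValueError("unsupported dtype code: %r" % dtype_code)
    np_dtype = _DTYPE_BY_CODE[dtype_code]

    def op(x):
        x = np.asarray(x, dtype=np_dtype)
        return x / np.sqrt(np.sum(x * x) + param.get("eps", 1e-10))

    return op
```

The closure plays the role of the templated L2NormalizationOp<xpu, DType> instance that DO_BIND_DISPATCH selects in the C++ code.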
src/operator/l2_normalization.cc (Outdated)
-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                                std::vector<int> *in_type) const {
+  std::vector<TShape> out_shape, aux_shape;
these checks are not necessary
Do you mean the checks for InferType and InferShape?
Please add a JIRA ticket
src/operator/l2_normalization-inl.h (Outdated)
+                 std::vector<int> *aux_type) const override {
+    CHECK_EQ(in_type->size(), 1U);
+    int dtype = (*in_type)[0];
+    CHECK_NE(dtype, -1) << "Input must have specified type";
Please use mutual inference instead of terminating the program.
Done.
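The mutual inference requested above can be sketched in a few lines of Python (a simplified model, not MXNet's actual ElemwiseType/type_assign code; the function name is made up): an unknown dtype, encoded as -1, is filled in from any known one, and only a genuine conflict is reported, instead of aborting whenever the input type happens to be unspecified.

```python
def infer_type(in_types, num_outputs=1):
    """Simplified sketch of mutual dtype inference; -1 means 'unknown'."""
    known = {t for t in in_types if t != -1}
    if len(known) > 1:
        # Conflicting known dtypes: report them, don't silently pick one.
        raise TypeError("inconsistent dtypes: %s" % sorted(known))
    dtype = known.pop() if known else -1
    # Propagate the known dtype (if any) to every unknown slot,
    # and give the outputs the same dtype as the inputs.
    in_types[:] = [dtype] * len(in_types)
    return in_types, [dtype] * num_outputs
```

An unspecified input no longer terminates the program; it simply stays unknown until some other slot pins the type down.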
@@ -2396,21 +2396,22 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
     exe = out.simple_bind(ctx=ctx, data=in_data.shape)
     output = exe.forward(is_train=True, data=in_data)
     # compare numpy + mxnet
-    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5)
+    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype is 'float16' else 1e-5)
Can you also pass atol here? The default is 1e-20, which may make the test flaky if the numbers are small.
Done
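For context on the atol request: numpy.testing.assert_allclose (which MXNet's assert_almost_equal resembles) passes when |actual - desired| <= atol + rtol * |desired|, so for near-zero desired values the rtol term vanishes and a tiny atol makes the check effectively exact. A small self-contained illustration of that rule:

```python
import numpy as np

def allclose(actual, desired, rtol, atol):
    # The tolerance rule used by numpy.testing.assert_allclose:
    # pass iff |actual - desired| <= atol + rtol * |desired|.
    return bool(np.all(np.abs(actual - desired) <= atol + rtol * np.abs(desired)))

desired = np.array([1e-8])   # a near-zero expected value
actual = desired + 1e-7      # tiny absolute error, huge relative error
```

With atol=1e-20 the rtol term (1e-5 * 1e-8 = 1e-13) dominates and the comparison fails; a larger atol absorbs noise on near-zero outputs.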
Force-pushed b2296ae to 74a2fee (compare)
This PR should be good for merge. @cjolivier01 @piiswrong @anirudh2290 @reminisce @rahul003, would you please take another look to see if it is good to go?
@@ -2397,21 +2397,22 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
     exe = out.simple_bind(ctx=ctx, data=in_data.shape)
     output = exe.forward(is_train=True, data=in_data)
     # compare numpy + mxnet
-    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5)
+    assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype is 'float16' else 1e-5, atol=1e-20)
The default is 1e-20; can you make atol bigger than that, maybe 1e-5?
Done
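One small aside on the test line itself: `dtype is 'float16'` compares object identity, not string equality. It happens to work when CPython interns the literals, but `==` is the reliable spelling. A quick sketch of the difference (CPython behavior; string interning is an implementation detail):

```python
# `is` checks object identity; equal strings need not be the same object.
built = "".join(["float", "16"])   # constructed at runtime, not interned
literal = "float16"

assert built == literal            # equality: always the right comparison
assert built is not literal        # identity: distinct objects here (CPython)
```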
* enable other dtype in l2 normalization
* Get rid of older code
* address code reviews: get rid of unnecessary checks
* address code reviews
* fix buggy InferType in L2Normalization
* address code review: change atol
Description
Add support for additional data types in the L2Normalization operator, as requested in Issue #2302.
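As a rough sketch of what the operator computes in its default instance mode (a hypothetical NumPy reference mirroring the np_out the unit test compares against; the function name is made up): each sample is divided by the L2 norm of its flattened elements.

```python
import numpy as np

def l2_normalize_instance(data, eps=1e-10):
    # Hypothetical NumPy reference for instance-mode L2 normalization:
    # each sample (along the first axis) is scaled to unit L2 norm.
    flat = data.reshape(data.shape[0], -1)
    norm = np.sqrt(np.sum(flat * flat, axis=1, keepdims=True) + eps)
    return (flat / norm).reshape(data.shape)
```

In float16, accumulating the squared sum loses precision, which is why the test loosens rtol to 1e-2 for that dtype.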
Checklist
Essentials
make lint
Changes