This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-92] Support float16 in L2Normalization operator #10078

Merged
merged 6 commits into apache:master on Mar 20, 2018

Conversation

@haojin2 (Contributor) commented Mar 12, 2018:

Description

Add support for any datatype to the L2Normalization operator, as discussed in Issue #2302.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
    • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Code is well-documented:
    • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with this change

Changes

  • Change the L2Normalization operator from supporting only real_t to supporting any datatype
  • Add additional test cases for float16

@@ -294,7 +321,13 @@ class L2NormalizationProp : public OperatorProperty {
return {ResourceRequest::kTempSpace};
}

-  Operator* CreateOperator(Context ctx) const override;
+  Operator* CreateOperator(Context ctx) const override {
Member:
Does something still call this?

@haojin2 (Contributor, Author), Mar 12, 2018:
Honestly, I'm not really sure; a simple grep for "CreateOperator" under src/ turned up only this usage:
nnvm/legacy_op_util.cc:297: return OpStatePtr::Create(prop.ptr->CreateOperatorEx(ctx, &is, &it),

Member:
Ok, I see it is masked by your override of CreateOperatorEx()


-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                                std::vector<int> *in_type) const {
+  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
Member:
Since you're overriding CreateOperatorEx(), what ends up calling InferShape() and InferType(), which are normally called by the base class's CreateOperatorEx()?

@haojin2 (Contributor, Author):
I see; I've just added calls to InferType and InferShape to the code, and the PR will be updated soon.
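For readers following along: a minimal sketch of such an override, assuming the legacy OperatorProperty API; it mirrors the diff shown in this thread rather than quoting the PR verbatim.

Operator* L2NormalizationProp::CreateOperatorEx(Context ctx,
                                                std::vector<TShape> *in_shape,
                                                std::vector<int> *in_type) const {
  // Sketch only: overriding CreateOperatorEx bypasses the base-class
  // implementation, so shape/type inference must be re-run here.
  std::vector<TShape> out_shape, aux_shape;
  std::vector<int> out_type, aux_type;
  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
  CHECK(InferType(in_type, &out_type, &aux_type));
  // Dispatch on the inferred input dtype.
  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}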

Member:
Just FYI: usually, DType is determined within the Forward() and Backward() functions, using a type switch on the actual input blob at runtime.
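For context, the pattern being described looks roughly like this; the signature is the standard legacy Operator::Forward interface, and the body is a sketch, not this PR's code.

void Forward(const OpContext &ctx,
             const std::vector<TBlob> &in_data,
             const std::vector<OpReqType> &req,
             const std::vector<TBlob> &out_data,
             const std::vector<TBlob> &aux_states) override {
  // Dispatch on the dtype of the actual input blob at runtime.
  MSHADOW_REAL_TYPE_SWITCH(in_data[0].type_flag_, DType, {
    // Here DType is bound to the concrete element type (float, double,
    // mshadow::half::half_t, ...), so kernels can be instantiated per
    // dtype without fixing the type at operator-creation time.
  });
}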

Member:
I am not saying you need to change it, but if that were the case, you wouldn't have to override CreateOperatorEx(), which has nontrivial logic.

Member:
Where are InferShape() and InferType() being called?

@haojin2 force-pushed the master branch 2 times, most recently from b86f6d7 to a60d44b on March 12, 2018 at 23:07
@haojin2 (Contributor, Author) commented Mar 13, 2018:

I think this PR should be ready to merge. @rahul003, would you please take a look at it to double-check? Thanks!

@@ -26,13 +26,22 @@
 namespace mxnet {
 namespace op {
 template<>
-Operator* CreateOp<cpu>(L2NormalizationParam param) {
-  return new L2NormalizationOp<cpu>(param);
+Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
Member:
is it done this way elsewhere?

Member:
ok
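For reference, the dtype-parameterized factory pattern used by many MXNet operators looks roughly like this; a sketch that assumes L2NormalizationOp is templated on both device and element type.

template<>
Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
  Operator *op = nullptr;
  // Expand the dtype flag into a concrete element type, then construct
  // the operator specialized for that type.
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new L2NormalizationOp<cpu, DType>(param);
  });
  return op;
}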

-  DO_BIND_DISPATCH(CreateOp, param_);
+Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                                std::vector<int> *in_type) const {
+  std::vector<TShape> out_shape, aux_shape;
Contributor:
these checks are not necessary

@haojin2 (Contributor, Author):
Do you mean the checks for InferType and InferShape?

@cjolivier01 (Member):
Please add a JIRA ticket

std::vector<int> *aux_type) const override {
CHECK_EQ(in_type->size(), 1U);
int dtype = (*in_type)[0];
CHECK_NE(dtype, -1) << "Input must have specified type";
Contributor:
Please use mutual inference instead of terminating the program.
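A sketch of what mutual inference means here: if the input dtype is unknown, recover it from an already-inferred output dtype instead of aborting. Illustrative only, not the PR's exact code.

bool InferType(std::vector<int> *in_type,
               std::vector<int> *out_type,
               std::vector<int> *aux_type) const override {
  CHECK_EQ(in_type->size(), 1U);
  int dtype = (*in_type)[0];
  if (dtype == -1) {
    // Input dtype unknown: fall back to the output dtype if that has
    // already been inferred, rather than terminating via CHECK_NE.
    CHECK(!out_type->empty() && (*out_type)[0] != -1)
        << "Neither input nor output type is known";
    dtype = (*out_type)[0];
  }
  in_type->clear();
  in_type->push_back(dtype);
  out_type->clear();
  out_type->push_back(dtype);
  return true;
}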

@haojin2 (Contributor, Author):
Done.

@cjolivier01 cjolivier01 changed the title Support float16 in L2Normalization operator [MXNET-92] Support float16 in L2Normalization operator Mar 13, 2018
@cjolivier01 (Member) commented on the diff:
@@ -2396,21 +2396,22 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
   exe = out.simple_bind(ctx=ctx, data=in_data.shape)
   output = exe.forward(is_train=True, data=in_data)
   # compare numpy + mxnet
-  assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5)
+  assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype is 'float16' else 1e-5)
Can you also pass atol here? The default is 1e-20, which may make the test flaky if the numbers are small.

@haojin2 (Contributor, Author):
Done

@haojin2 force-pushed the master branch 3 times, most recently from b2296ae to 74a2fee on March 16, 2018 at 19:55
@haojin2 (Contributor, Author) commented Mar 19, 2018:

This PR should be good to merge. @cjolivier01 @piiswrong @anirudh2290 @reminisce @rahul003, would you please take another look and see whether it is good to go?

@@ -2397,21 +2397,22 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
   exe = out.simple_bind(ctx=ctx, data=in_data.shape)
   output = exe.forward(is_train=True, data=in_data)
   # compare numpy + mxnet
-  assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5)
+  assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype is 'float16' else 1e-5, atol=1e-20)
Member:
The default is 1e-20; can you make atol bigger than that, maybe 1e-5?

@haojin2 (Contributor, Author):
Done

@piiswrong piiswrong merged commit 1b71ce1 into apache:master Mar 20, 2018
ashokei pushed a commit to ashokei/incubator-mxnet that referenced this pull request Mar 27, 2018
* enable other dtype in l2 normalization

* Get rid of older code

* address code reviews: get rid of unnecessary checks

* address code reviews

* fix buggy InferType in L2Normalization

* address code review: change atol
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request Mar 30, 2018 (same commit message as above)

rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018 (same commit message as above)

zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018 (same commit message as above)