[MXNET-92] Support float16 in L2Normalization operator #10078
Changes from all commits
a4bbfe1
d313aab
c402b56
57646cb
8a7ccb5
88effef
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,13 +26,18 @@ | |
namespace mxnet { | ||
namespace op { | ||
template<> | ||
Operator* CreateOp<cpu>(L2NormalizationParam param) { | ||
return new L2NormalizationOp<cpu>(param); | ||
Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it done this way elsewhere? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
Operator* op = NULL; | ||
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { | ||
op = new L2NormalizationOp<cpu, DType>(param); | ||
}); | ||
return op; | ||
} | ||
|
||
// DO_BIND_DISPATCH comes from static_operator_common.h | ||
Operator* L2NormalizationProp::CreateOperator(Context ctx) const { | ||
DO_BIND_DISPATCH(CreateOp, param_); | ||
Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape, | ||
std::vector<int> *in_type) const { | ||
DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you're overriding CreateOperatorEx(), then what ends up calling InferShape(), InferType(), which is normally done by the base class' CreateOperatorEx()? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, just added calls to InferType and InferShape to the code, the PR will be updated soon. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just FYI, usually, DType is determined within the Forward() and Backward() functions using the type switch from the actual input blob at runtime. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not saying you need to change it, but if that were the case, you wouldn;t have to override CreateOpEx(), which has nontrivial logic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where is InferShape(), InferType() being called? |
||
} | ||
|
||
DMLC_REGISTER_PARAMETER(L2NormalizationParam); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does something still call this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Honestly I'm not really sure, with a simple grep for "CreateOperator" in src only this usage appeared:
nnvm/legacy_op_util.cc:297: return OpStatePtr::Create(prop.ptr->CreateOperatorEx(ctx, &is, &it),
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I see it is masked by your override of CreateOperatorEx()