-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-100] Support float16 in Correlation operator #10125
Conversation
3ad1bdd
to
250d5eb
Compare
250d5eb
to
e671bd2
Compare
Tensor<xpu, 4, DType> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, DType>(s); | ||
Tensor<xpu, 4, DType> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, DType>(s); | ||
tmp1 = DType(0.0f); | ||
tmp2 = DType(0.0f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Preferably, Use static_cast.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, will address this issue shortly
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like static_cast is causing some compilation errors, taking a look at it now.
type_assign(&(*out_type)[1], dtype); | ||
type_assign(&(*out_type)[2], dtype); | ||
|
||
TYPE_ASSIGN_CHECK(*in_type, 0, dtype); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do you need both type_assign and TYPE_ASSIGN_CHECK ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was advice from Jun, he suggested that we can do a mutual inference by reducing all datatypes to one and then assign the reduced type back to everything.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the reduced datatype is not -1 then all inputs and outputs will then have the same datatype at the end of this function.
float16 doesn't work with static_cast |
Description
Add support for any datatype for Correlation operator, which is not mentioned in issue #2302.
Checklist
Essentials
make lint
)Changes