This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Relaxing type requirements for broadcast_like #17977

Merged
merged 3 commits into from May 4, 2020
11 changes: 10 additions & 1 deletion src/operator/tensor/broadcast_reduce_op_value.cc
@@ -138,7 +138,16 @@ NNVM_REGISTER_OP(broadcast_like)
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs"};
})
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2) << " in operator " << attrs.name;
std::vector<int> checked_in_attrs = { (*in_attrs)[0] };
bool ret = !type_is_none((*in_attrs)[1]) &&
Member:

Thanks for your contribution!

Is the condition !type_is_none((*in_attrs)[1]) necessary?

Contributor Author:

Well, I actually just found the issue; the code is copied from the merged PR #14097, which had a similar issue. I think it is better to keep the same assumption for all *_like ops that need relaxed type requirements.

Member:
Yes, it is necessary - what FInferType returns is whether it succeeded in inferring all the types (so that if all operators return true, we know that all types are inferred). That is why it is important not to lie, and to return true only if all types are really inferred (even if we do not actually do anything with the other type).

ElemwiseType<1, 1>(attrs, &checked_in_attrs, out_attrs);
(*in_attrs)[0] = checked_in_attrs[0];
return ret;
})
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::ObjectPtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
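The FInferType contract the reviewer describes, fill in whatever types can be deduced but return true only when every type is actually known, can be sketched in plain Python (a hypothetical model for illustration, not MXNet's real implementation; the -1 sentinel for "unknown" and the function name are assumptions):

```python
# Hypothetical model of the FInferType contract for a broadcast_like-style op.
# -1 stands for "type not yet inferred" (an illustrative convention, not
# MXNet's actual dtype encoding).
UNKNOWN = -1

def infer_broadcast_like_types(in_types, out_types):
    """The output takes the dtype of the first input (lhs); the second
    input's dtype is unconstrained, but must still be known before we
    may report success."""
    # Propagate the lhs dtype to the output, and vice versa.
    if in_types[0] == UNKNOWN and out_types[0] != UNKNOWN:
        in_types[0] = out_types[0]
    if out_types[0] == UNKNOWN and in_types[0] != UNKNOWN:
        out_types[0] = in_types[0]
    # Return True only if ALL types are really inferred, including rhs,
    # even though rhs's dtype does not constrain anything here.
    return (in_types[0] != UNKNOWN and
            in_types[1] != UNKNOWN and
            out_types[0] != UNKNOWN)
```

With both input dtypes known, the call fills in the output dtype and reports success; with the rhs dtype still unknown, it may still propagate the lhs dtype but must return False, matching the "do not lie" rule above.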
10 changes: 10 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -3119,6 +3119,16 @@ def test_reshape_like_different_types():
z = mx.nd.reshape_like(x, y)
assert_allclose(z.asnumpy(), [[0,0],[0,0],[0,0]])

@with_seed()
def test_broadcast_like_different_types():
x = mx.nd.zeros((2, 1))
y = mx.nd.ones((2, 2))

y = mx.nd.array(y).astype('int32')
z = mx.nd.broadcast_like(x, y)
assert_allclose(z.asnumpy(), [[0,0],[0,0]])
assert x.dtype == z.dtype
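As a rough analogy (not MXNet code), the behavior this test asserts, shape taken from the rhs and dtype taken from the lhs, matches NumPy's broadcast_to, which likewise leaves the input dtype untouched:

```python
import numpy as np

x = np.zeros((2, 1), dtype=np.float32)  # lhs: provides values and dtype
y = np.ones((2, 2), dtype=np.int32)     # rhs: provides only the target shape

z = np.broadcast_to(x, y.shape)
assert z.shape == (2, 2)
assert z.dtype == x.dtype  # dtype comes from x, regardless of y's int32
```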

@with_seed()
def test_flip():
for ndim in range(1, 6):