Skip to content

[Unity][Op] Extend relax.op.take op to match behavior with topi.take. #14481

Merged
vinx13 merged 3 commits intoapache:unityfrom
sunggg:extend_take
Apr 5, 2023
Merged

[Unity][Op] Extend relax.op.take op to match behavior with topi.take. #14481
vinx13 merged 3 commits intoapache:unityfrom
sunggg:extend_take

Conversation

@sunggg
Copy link
Contributor

@sunggg sunggg commented Apr 4, 2023

Currently, relax.op.take implements semantics of torch.take which has a stronger restriction than its backing topi.take which seems to follow the semantics of numpy.take.
i.e., relax.op.take enforces its input indices to be 1-d tensor.

To match its behavior with topi.take, this PR extends the semantics of relax.op.take by removing such restriction without breaking existing test cases (except one for checking 1-d tensor restriction).

As a result, this PR can support onnx.gather which was the main motivation of my previous PR #14457 by resolving the problem in my previous stab.

cc. @jwfromm @yongwww @MasterJH5574

@tvm-bot
Copy link
Collaborator

tvm-bot commented Apr 4, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

Copy link
Member

@yongwww yongwww left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me

@tqchen
Copy link
Member

tqchen commented Apr 4, 2023

as long as the existing behavior are preserved, i think we can support multi-dim index

@sunggg
Copy link
Contributor Author

sunggg commented Apr 4, 2023

@tqchen I checked that existing tests are passed, so I believe the existing behavior is intact.

@tqchen
Copy link
Member

tqchen commented Apr 4, 2023

@MasterJH5574 would be great if you can help to take a quick look at this one

!(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) {

if (indices_sinfo->IsUnknownDtype()) {
// TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for warning?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving as LOG(WARNING) looks fine to me.

@MasterJH5574
Copy link
Contributor

@sunggg Sorry for the delay and thanks for the extension!

@vinx13 vinx13 merged commit 8e4f94a into apache:unity Apr 5, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants