
[Relay, Topi] [TF, MXNet] Unravel Index operator #5082

Merged: 11 commits, Mar 23, 2020

Conversation

maheshambule
Contributor

@maheshambule maheshambule commented Mar 17, 2020

Adds support for the unravel_index op.
NumPy reference:
https://docs.scipy.org/doc/numpy/reference/generated/numpy.unravel_index.html

Added support for the TensorFlow and MXNet frontends.
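For context, a quick NumPy illustration of the operator's semantics: unravel_index converts flat (row-major) indices into a tuple of coordinate arrays for a target shape.

```python
import numpy as np

# Flat indices into a 7x6 grid, converted back to (row, col) coordinates.
coords = np.unravel_index(np.array([22, 41, 37]), (7, 6))
print(coords)  # (array([3, 6, 6]), array([4, 5, 1]))
```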

@maheshambule maheshambule changed the title Unravel index op [Relay, Topi] [TF, MXNet]Unravel Index operator Mar 17, 2020
@maheshambule
Contributor Author

cc: @kevinthesun, @jwfromm, @masahi. Please help with the review.

@@ -2509,9 +2515,7 @@ def _parse_param(self, key, value, name, shape):

array_ndim = len(np_array.shape)
if array_ndim == 0:
new_array = np.empty([1], dtype=np_array.dtype)
Contributor Author

Removed this because we want to pass the scalar as a scalar, not as a tensor of rank 1.
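To illustrate the distinction in plain NumPy (not the frontend code itself): a 0-d array is a scalar, while the removed line promoted it to a rank-1 tensor of shape (1,), which changes the value's type signature downstream.

```python
import numpy as np

np_array = np.array(3.0)        # 0-d array, i.e. a scalar constant
assert np_array.ndim == 0

# The removed line wrapped the scalar into a rank-1 tensor of shape (1,):
new_array = np.empty([1], dtype=np_array.dtype)
assert new_array.ndim == 1 and new_array.shape == (1,)
```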

@maheshambule maheshambule changed the title [Relay, Topi] [TF, MXNet]Unravel Index operator [Relay, Topi] [TF, MXNet] Unravel Index operator Mar 17, 2020
for (int v = GetConstInt(shape_shape[0]) - 1; v >= 0; --v) {
ret = tvm::if_then_else(i == v, indexmod(indices_divs.back(), shape[v]), ret);
cur_val = indexdiv(indices_divs.back(), shape[v]);
indices_divs.push_back(cur_val);
Contributor

Is there a reason that UnravelIndex from topi/include/topi/detail/ravel_unravel.h isn't used here?

Contributor Author

The function in this file returns all the coordinates for a given index. In the compute definition we only want the coordinate for the current compute index, not all of them. I was facing an issue extracting the current coordinate because the compute index, which is a Var, cannot be used directly to extract an Expr from an array of Exprs; I had to use the if_then_else construct for that. Please let me know if I am missing something here, or if there is an easier way to achieve this. I could modify the existing function to meet my purposes, for example by passing in the coordinate index I want to extract and returning just that coordinate. Please let me know if I should implement this.
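For reference, the div/mod recurrence that the if_then_else chain above walks through (indexmod for the coordinate, indexdiv for the remainder) can be sketched in plain Python. This is a hypothetical standalone helper for illustration, not the actual topi code:

```python
import numpy as np

def unravel_index_scalar(index, shape):
    """Row-major unravel: repeated mod/div from the last axis inward."""
    coords = []
    rem = index
    for dim in reversed(shape):
        coords.append(rem % dim)   # like indexmod(indices_divs.back(), shape[v])
        rem = rem // dim           # like indexdiv(indices_divs.back(), shape[v])
    return list(reversed(coords))

# Flat index 22 in a (7, 6) grid lands at row 3, column 4.
print(unravel_index_scalar(22, (7, 6)))  # [3, 4]

# Agrees with NumPy's reference implementation.
assert tuple(unravel_index_scalar(22, (7, 6))) == np.unravel_index(22, (7, 6))
```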

Contributor

Ah ok that makes sense. This implementation is good then, no need to change it.

@@ -321,6 +321,12 @@ struct ArgWhereAttrs : public tvm::AttrsNode<ArgWhereAttrs> {
}
}; // struct ArgWhereAttrs

/*! \brief Attributes used in unravel_index operators */
struct UnRavelIndexAttrs : public tvm::AttrsNode<UnRavelIndexAttrs> {
TVM_DECLARE_ATTRS(UnRavelIndexAttrs, "relay.attrs.UnRavelIndexAttrs") {
Contributor

I don't think there's any need to define an attribute type for an operator without attributes. Although argwhere seems to do the same thing you have, other operators without attributes just don't use one (see nn.batch_flatten as one example). I'd argue we should try to avoid defining unnecessary attrs to prevent bloat.

Contributor Author

Ok. Thanks. This is good to know. I have removed the attrs for both unravel_index and argwhere.

Contributor

@jwfromm jwfromm left a comment

LGTM

Member

@masahi masahi left a comment

Minor comments, once fixed I'll merge this.

@tqchen tqchen added the status: need update need update based on feedbacks label Mar 20, 2020
@masahi masahi merged commit fdc8b0d into apache:master Mar 23, 2020
@masahi
Member

masahi commented Mar 23, 2020

Thanks @maheshambule @jwfromm

@maheshambule
Contributor Author

Thanks @jwfromm, @masahi for review and comments.

@maheshambule maheshambule deleted the unravel_index_op branch March 23, 2020 11:37
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
* first cut unravel_index

* merge fixes

* change rates to dilations

* unravel_index op relay, topi, mxnet, tf

* doc changes

* small changes

* remove empty unravel and argwhere attrs

* remove empty unravel and argwhere attrs
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020