-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NNVM] Add argmax and argmin operations from topi #1462
Conversation
e488393
to
aac3c26
Compare
20d534c
to
20bd14e
Compare
Thanks for the contribution, please request reviews from reviewers |
@nishi-t @srkreddy1238 can you please review this? |
a = f(x, axis=axis) | ||
if keepdims: | ||
shape = list(a.shape) | ||
shape.insert(axis,1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that it will not work well when axis is None
The following code may be helpful:
https://github.com/dmlc/tvm/blob/master/topi/tests/python/test_topi_reduce.py#L11-L15
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I am about to fix it, but see my comment below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I got it. thanks
verify_reduce((4, 4, 3), _with_keepdims(np.argmax), sym.argmax, otype='int32') | ||
verify_reduce((4, 4, 3), _with_keepdims(np.argmax), sym.argmax, oshape=(4,1,3), otype='int32', axis=1, keepdims=True) | ||
verify_reduce((4, 4, 3), _with_keepdims(np.argmin), sym.argmin, otype='int32') | ||
verify_reduce((4, 4, 3), _with_keepdims(np.argmin), sym.argmin, oshape=(4,1,3), otype='int32', axis=1, keepdims=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a pattern which keepdims is True and axis is None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, axis=None
doesn't work, but the wrong shape is not the only reason. Looking into sources I would say that this value of axis is forbidden for every reduce operation in NNVM. Axis is defined as struct ReduceParams { .. TShape axis; .. }
, but FieldEntryBase::Set
doesn't have specializations for TShape and the default version rejects Nones with
NNVMError: Invalid Parameter format for axis expect but value='None'
(note also missing _type name)
I would prefer to not include None usecases in the current PR, but solve this issue separately.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I got it.
@@ -114,6 +116,22 @@ def test_reduce(): | |||
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2)) | |||
verify_reduce((4, 4, 3), np.sum, sym.sum) | |||
|
|||
def _with_keepdims(f): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a matter of taste, but I'd suggest that the _with_keepdims definition is before all verify_reduce
def test_reduce():
def _with_keepdims(f):
...
verify_reduce((2, 3, 4), np.max, sym.max, axis=1, keepdims=True)
...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
return true; | ||
} | ||
|
||
NNVM_REGISTER_BASE_REDUCE_OP(argmax) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NNVM_REGISTER_BASE_REDUCE_OP
uses ReduceParam, but do argmax and argmin support the exclude option?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. While ops do use this option (see 01a7df4#diff-ef68c30ca3f3d21c6d5b8dd15cc570a1R315), it is not covered by tests. Should fix it soon.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, please check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that argmax doesn't have exclude option and the axis parameter is only an int value in mxnet and numpy. Are these little different from other reduction operators? If so, I think that argmax and argmin should have a parameter struct for them without ReduceParam.
@tqchen @srkreddy1238 Could you tell us your opinion about it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For now, it is nice to keep all reduction operator consistent, as I assume the logics are going to be the same
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tqchen Thank you for your comments. I got it.
@grwlf Thank you for your work. Please don't mind my above comment.
5e9b74d
to
cbada25
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me.
Makes argmax operation accessible from NNVM.
There is an issue regarding output types, since the test seems to treats integer value, returned by argmax, as a floating-point. I'd like to ask for advice here. Probably, graph_runtime requires additional changes.