Skip to content

[BugFix] Fix doc of argmax #4231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions docs/api/paddle/argmax_cn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ argmax
::::::::
- **x** (Tensor) - 输入的多维 ``Tensor`` ,支持的数据类型:float32、float64、int16、int32、int64、uint8。
- **axis** (int,可选) - 指定对输入Tensor进行运算的轴, ``axis`` 的有效范围是[-R, R),R是输入 ``x`` 的维度个数, ``axis`` 为负数时,进行计算的 ``axis`` 与 ``axis`` + R 一致。默认值为None, 将会对输入的 `x` 进行平铺展开,返回最大值的索引。
- **keepdim** (bool,可选)- 是否保留进行最大值索引操作的轴,默认值为False。
- **keepdim** (bool,可选)- 是否在输出Tensor中保留减小的维度。如果 keepdim 为True,则输出Tensor和 x 具有相同的维度(减少的维度除外,减少的维度的大小为1),默认值为False。
- **dtype** (np.dtype|str,可选)- 输出Tensor的数据类型,可选值为int32,int64,默认值为int64,将返回int64类型的结果。
- **name** (str,可选) – 具体用法请参见 :ref:`api_guide_Name` ,一般无需设置,默认值为None。

Expand All @@ -35,9 +35,12 @@ argmax
x = paddle.to_tensor(data)
out1 = paddle.argmax(x)
print(out1) # 2
out2 = paddle.argmax(x, axis=1)
out2 = paddle.argmax(x, axis=0)
print(out2)
# [2 3 1]
# [2, 2, 0, 1]
out3 = paddle.argmax(x, axis=-1)
print(out3)
# [2 3 1]
out4 = paddle.argmax(x, axis=0, keepdim=True)
print(out4)
# [[2, 2, 0, 1]]