Skip to content

Latest commit

 

History

History
57 lines (38 loc) · 1.76 KB

where_cn.rst

File metadata and controls

57 lines (38 loc) · 1.76 KB

where

.. py:function:: paddle.where(condition, x, y, name=None)




该OP返回一个根据输入 condition, 选择 xy 的元素组成的多维 Tensor

Out_i =
\left\{
\begin{aligned}
&X_i, & & if \ cond_i \ is \ True \\
&Y_i, & & if \ cond_i \ is \ False \\
\end{aligned}
\right.

Note

numpy.where(condition) 功能与 paddle.nonzero(condition, as_tuple=True) 相同。

参数

  • condition (Tensor)- 选择 xy 元素的条件 。为 True (非零值)时,选择 x ,否则选择 y
  • x (Tensor,Scalar,可选)- 多维 TensorScalar,数据类型为 float32float64int32int64xy 必须都给出或者都不给出。
  • y (Tensor,Scalar,可选)- 多维 TensorScalar,数据类型为 float32float64int32int64xy 必须都给出或者都不给出。
  • name (str,可选)- 具体用法请参见 :ref:`api_guide_Name` ,一般无需设置,默认值为None。

返回

Tensor,数据类型与 x 相同的 Tensor

代码示例

import paddle

x = paddle.to_tensor([0.9383, 0.1983, 3.2, 1.2])
y = paddle.to_tensor([1.0, 1.0, 1.0, 1.0])
out = paddle.where(x>1, x, y)

print(out)
#out: [1.0, 1.0, 3.2, 1.2]

out = paddle.where(x>1)
print(out)
#out: (Tensor(shape=[2, 1], dtype=int64, place=CPUPlace, stop_gradient=True,
#            [[2],
#             [3]]),)