Skip to content

Commit

Permalink
Align where op with torch (#5850)
Browse files Browse the repository at this point in the history
* rename single_clinet prelu api name

* align where op with torch

* code format

* revert prelu code

* revert prelu code

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
BBuf and oneflow-ci-bot committed Aug 12, 2021
1 parent 42149cd commit 326a19c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
10 changes: 7 additions & 3 deletions python/oneflow/nn/modules/where.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@


@register_tensor_op("where")
def where_op(condition, x, y):
def where_op(condition, x=None, y=None):
"""Return a tensor of elements selected from either :attr:`x` or :attr:`y`, depending on :attr:`condition`.
If the element in condition is larger than 0,
it will take the `x` element, else it will take the `y` element
.. note::
If :attr:`x` is None and :attr:`y` is None, flow.where(condition) is
identical to flow.nonzero(condition, as_tuple=True).
The tensors :attr:`condition`, :attr:`x`, :attr:`y` must be broadcastable.
It will take the `x` element, else it will take the `y` element.
Args:
condition (IntTensor): When 1 (nonzero), yield x, otherwise yield y
Expand Down Expand Up @@ -58,6 +59,9 @@ def where_op(condition, x, y):
"""

if x == None and y == None:
return flow.nonzero(condition, as_tuple=True)

assert condition.dtype == flow.int32 or condition.dtype == flow.int8
if isinstance(x, int) or isinstance(x, float):
x = flow.Tensor(
Expand Down
16 changes: 16 additions & 0 deletions python/oneflow/test/modules/test_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,21 @@ def _test_where_broadcast_x_backward(test_case, device):
test_case.assertTrue(np.allclose(x.grad.numpy(), x_grad, 1e-05, 1e-05))


def _test_where_x_y_none(test_case, device):
condition = flow.Tensor(
np.array([[[-0.462, 0.3139], [0.3898, -0.7197], [0.0478, -0.1657]]]),
dtype=flow.float32,
device=flow.device(device),
requires_grad=True,
)
of_out = flow.where(condition)
of_nonzero = flow.nonzero(condition, as_tuple=True)
for i in range(len(of_out)):
test_case.assertTrue(
np.allclose(of_out[i].numpy(), of_nonzero[i].numpy(), 1e-05, 1e-05)
)


@flow.unittest.skip_unless_1n1d()
class TestWhere(flow.unittest.TestCase):
def test_where(test_case):
Expand All @@ -186,6 +201,7 @@ def test_where(test_case):
_test_where_backward,
_test_where_broadcast_backward,
_test_where_broadcast_x_backward,
_test_where_x_y_none,
]
arg_dict["device"] = ["cpu", "cuda"]
for arg in GenArgList(arg_dict):
Expand Down

0 comments on commit 326a19c

Please sign in to comment.