Skip to content

Commit

Permalink
add masked_fill api
Browse files Browse the repository at this point in the history
  • Loading branch information
jinyouzhi committed Sep 13, 2023
1 parent b02a42c commit e667702
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@
from .tensor.manipulation import view # noqa: F401
from .tensor.manipulation import view_as # noqa: F401
from .tensor.manipulation import unfold # noqa: F401
from .tensor.manipulation import masked_fill # noqa: F401
from .tensor.manipulation import masked_fill_ # noqa: F401
from .tensor.math import abs # noqa: F401
from .tensor.math import abs_ # noqa: F401
from .tensor.math import acos # noqa: F401
Expand Down Expand Up @@ -843,4 +845,6 @@
'i1e',
'polygamma',
'polygamma_',
'masked_fill',
'masked_fill_',
]
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@
from .manipulation import view # noqa: F401
from .manipulation import view_as # noqa: F401
from .manipulation import unfold # noqa: F401
from .manipulation import masked_fill # noqa: F401
from .manipulation import masked_fill_ # noqa: F401
from .math import abs # noqa: F401
from .math import abs_ # noqa: F401
from .math import acos # noqa: F401
Expand Down Expand Up @@ -673,6 +675,8 @@
'i1e',
'polygamma',
'polygamma_',
'masked_fill',
'masked_fill_',
]

# this list used in math_op_patch.py for magic_method bind
Expand Down
67 changes: 67 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4430,6 +4430,73 @@ def repeat_interleave(x, repeats, axis=None, name=None):
return out


def masked_fill(x, mask, value, name=None):
    """
    Fill elements of the input tensor with ``value`` where ``mask`` is True.

    The shape of ``mask`` must be broadcastable with the shape of ``x``.

    Args:
        x (Tensor): The destination tensor. Supported data types are int32,
            int64, float32 and float64.
        mask (Tensor): The boolean tensor indicating the positions to be
            filled. The data type of ``mask`` must be bool.
        value (Scalar or 0-D Tensor): The value used to fill the target tensor.
        name (str, optional): For details, please refer to
            :ref:`api_guide_Name`. Generally, no setting is required.
            Default: None.

    Returns:
        Tensor, same dimension and dtype as ``x``, with ``value`` written at
        the positions where ``mask`` is True.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> input_tensor = paddle.ones((3, 3), dtype="float32")
            >>> mask_tensor = paddle.to_tensor([[True, False, True],
            ...                                 [False, True, False],
            ...                                 [True, False, True]])
            >>> outplace_res = paddle.masked_fill(input_tensor, mask_tensor, 0)
            >>> print(outplace_res)
    """
    # Unwrap a 0-D eager Tensor into a Python scalar so that full_like
    # receives a plain number regardless of how the caller passed `value`.
    value = value.item(0) if isinstance(value, core.eager.Tensor) else value
    # NOTE: the previous revision also built an unused LayerHelper here;
    # the dtype checks below do not need it, so it has been removed.
    check_variable_and_dtype(
        x,
        'x',
        ['float32', 'float64', 'int32', 'int64'],
        'paddle.tensor.manipulation.masked_fill',
    )
    check_variable_and_dtype(
        mask,
        'mask',
        ['bool'],
        'paddle.tensor.manipulation.masked_fill',
    )
    # Select between a constant-filled tensor and `x`; `where` performs the
    # broadcasting of `mask` against `x`.
    y = paddle.full_like(x, value)
    return paddle.where(mask, y, x)


@inplace_apis_in_dygraph_only
def masked_fill_(x, mask, value, name=None):
    """
    Inplace version of ``masked_fill`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_paddle_masked_fill`.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> input_tensor = paddle.ones((3, 3), dtype="float32")
            >>> mask_tensor = paddle.to_tensor([[True, False, True],
            ...                                 [False, True, False],
            ...                                 [True, False, True]])
            >>> inplace_res = paddle.masked_fill_(input_tensor, mask_tensor, 0)
            >>> print(inplace_res)
    """
    y = paddle.full_like(x, value)
    # BUG FIX: the previous implementation did `x = paddle.where(mask, y, x)`,
    # which only rebinds the local name to an out-of-place result — the
    # caller's tensor was never mutated, violating the contract of an
    # inplace (`_`-suffixed) API. Writing the result back through
    # `paddle.assign(..., output=x)` stores into x's own buffer.
    return paddle.assign(paddle.where(mask, y, x), output=x)


def moveaxis(x, source, destination, name=None):
"""
Move the axis of tensor from ``source`` position to ``destination`` position.
Expand Down
12 changes: 12 additions & 0 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,18 @@ def inplace_api_processing(self, var):
return paddle.unsqueeze_(var, -1)


class TestDygraphInplaceMaskedFill(TestDygraphInplace):
    """Inplace-consistency test for paddle.masked_fill / paddle.masked_fill_."""

    def non_inplace_api_processing(self, var):
        # BUG FIX: `np.random.randint(2, var.shape)` passed the shape tuple as
        # the `high` bound instead of the output shape — `size=` must be used.
        # The mask is also wrapped in a paddle Tensor, since masked_fill
        # expects a bool Tensor, not a numpy array.
        self.value = np.random.uniform(-5, 5)
        self.mask = paddle.to_tensor(
            np.random.randint(2, size=var.shape).astype('bool')
        )
        return paddle.masked_fill(var, self.mask, self.value)

    def inplace_api_processing(self, var):
        self.value = np.random.uniform(-5, 5)
        self.mask = paddle.to_tensor(
            np.random.randint(2, size=var.shape).astype('bool')
        )
        return paddle.masked_fill_(var, self.mask, self.value)


class TestDygraphInplaceReshape(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.reshape(var, [-1])
Expand Down
137 changes: 137 additions & 0 deletions test/legacy_test/test_masked_fill_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.base import core
from paddle.static import Program, program_guard

DYNAMIC = 1
STATIC = 2


def _run_masked_fill(mode, x, mask, value, device='cpu'):
    """
    Run ``paddle.masked_fill`` in dynamic or static mode and return numpy output.

    Args:
        mode (int): DYNAMIC or STATIC.
        x (np.ndarray): Input data.
        mask (np.ndarray): Boolean mask, broadcastable against ``x``.
        value (float|int|np.ndarray): Fill value — a Python scalar or an array.
        device (str): 'cpu' or 'gpu'.

    Returns:
        np.ndarray: The masked-fill result.
    """
    if mode == DYNAMIC:
        paddle.disable_static()
        paddle.set_device(device)
        x_ = paddle.to_tensor(x)
        mask_ = paddle.to_tensor(mask)
        # A Python scalar is passed through as-is; an array becomes a Tensor.
        if isinstance(value, (float, int)):
            value_ = value
        else:
            value_ = paddle.to_tensor(value)
        res = paddle.masked_fill(x_, mask_, value_)
        return res.numpy()
    elif mode == STATIC:
        paddle.enable_static()
        with program_guard(Program(), Program()):
            x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
            mask_ = paddle.static.data(
                name="mask", shape=mask.shape, dtype=mask.dtype
            )
            feed = {'x': x, 'mask': mask}
            if isinstance(value, (float, int)):
                # BUG FIX: the scalar branch previously fed a 'value' entry
                # with no matching paddle.static.data placeholder; only the
                # declared inputs may appear in the feed dict.
                value_ = value
            else:
                value_ = paddle.static.data(
                    name="value", shape=value.shape, dtype=value.dtype
                )
                feed['value'] = value
            res = paddle.masked_fill(x_, mask_, value_)
            place = (
                paddle.CPUPlace() if device == 'cpu' else paddle.CUDAPlace(0)
            )
            exe = paddle.static.Executor(place)
            outs = exe.run(feed=feed, fetch_list=[res])
            return outs[0]


def check_dtype(input, desired_dtype):
    """Raise ValueError if ``input.dtype`` differs from ``desired_dtype``."""
    actual_dtype = input.dtype
    if actual_dtype == desired_dtype:
        return
    raise ValueError(
        f"The expected data type to be obtained is {desired_dtype}, "
        f"but got {actual_dtype}"
    )


def _np_masked_fill(x, mask, value):
y = np.full_like(x, value)
return np.where(mask, y, x)


class TestMaskedFillAPI(unittest.TestCase):
    """End-to-end tests for paddle.masked_fill against a numpy reference.

    NOTE(review): the assertions depend on the exact ordering of np.random
    calls under the fixed seed below — do not reorder the sampling lines.
    """

    def setUp(self):
        # Always test on CPU; add GPU only when this build has CUDA support.
        self.places = ['cpu']
        if core.is_compiled_with_cuda():
            self.places.append('gpu')

    def test_masked_fill(self):
        np.random.seed(7)
        for place in self.places:
            shape = (100, 100)
            # Cover all dtypes the API's dtype check accepts.
            for dt in (np.float64, np.float32, np.int64, np.int32):
                # Case 1: mask has the same shape as x.
                x = np.random.uniform((-5), 5, shape).astype(dt)
                mask = np.random.randint(2, size=shape).astype('bool')
                value = np.random.uniform((-5), 5)
                # Check both execution modes against the numpy reference.
                res = _run_masked_fill(DYNAMIC, x, mask, value, place)
                check_dtype(res, dt)
                np.testing.assert_allclose(res, _np_masked_fill(x, mask, value))
                res = _run_masked_fill(STATIC, x, mask, value, place)
                check_dtype(res, dt)
                np.testing.assert_allclose(res, _np_masked_fill(x, mask, value))
                # Case 2: mask of shape (100,) is broadcast against x.
                x = np.random.uniform((-5), 5, shape).astype(dt)
                mask = np.random.randint(2, size=shape[1:]).astype('bool')
                value = np.random.uniform((-5), 5)
                res = _run_masked_fill(DYNAMIC, x, mask, value, place)
                check_dtype(res, dt)
                np.testing.assert_allclose(res, _np_masked_fill(x, mask, value))
                res = _run_masked_fill(STATIC, x, mask, value, place)
                check_dtype(res, dt)
                np.testing.assert_allclose(res, _np_masked_fill(x, mask, value))


# Standard unittest entry point for running this file as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit e667702

Please sign in to comment.