Skip to content

Commit

Permalink
zeros_like API: remove device; input -> x (#25413)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Jul 14, 2020
1 parent f795a1b commit f8eccb0
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 60 deletions.
17 changes: 10 additions & 7 deletions python/paddle/fluid/layers/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,14 +1454,17 @@ def zeros_like(x, out=None):
with `x`.
Args:
x(Variable): The input tensor which specifies shape and dtype, the input data dtype could be bool, float32, float64, int32, int64.
out(Variable, optional): If is :attr:`None` , the op will create the variable as output, the data type and shape of \
this variable will be same as input :attr:`x`. If is a tensor, the data type and shape need to be same as input :attr:`x`.
The default value is :attr:`None` .
x(Variable): The input tensor which specifies shape and dtype, the
input data dtype could be bool, float32, float64, int32, int64.
out(Variable, optional): If is :attr:`None` , the op will create the
variable as output, the data type and shape of this variable will
be same as input :attr:`x`. If is a tensor, the data type and shape
need to be same as input :attr:`x`. The default value is :attr:`None` .
Returns:
Variable: The N-D tensor, the element in tensor is related to input data type, if the input data type is bool, \
the output value is False, otherwise is zero. The output shape is the same as the input.
Variable: The N-D tensor, the element in tensor is related to input
data type, if the input data type is bool, the output value is
False, otherwise is zero. The output shape is the same as the input.
Examples:
.. code-block:: python
Expand All @@ -1480,7 +1483,7 @@ def zeros_like(x, out=None):
else:
check_variable_and_dtype(
out, "out", ['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like')
'zeros_like')

helper.append_op(
type='fill_zeros_like', inputs={'X': [x]}, outputs={'Out': [out]})
Expand Down
82 changes: 82 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zeros_like_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import zeros_like
from paddle.fluid import core, Program, program_guard


class TestZerosLikeAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x = paddle.data('x', [3, 4])
self.assertRaises(TypeError, zeros_like, x, 'int8')


class TestZerosLikeAPI(unittest.TestCase):
def test_api(self):
shape = [3, 4]
startup_program = Program()
train_program = Program()
with program_guard(train_program, startup_program):
x = paddle.data('X', shape)

# 'bool', 'float32', 'float64', 'int32', 'int64'
out1 = zeros_like(x)
out2 = zeros_like(x, np.bool)
out3 = zeros_like(x, 'float64')
out4 = zeros_like(x, 'int32')
out5 = zeros_like(x, 'int64')

place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
outs = exe.run(train_program,
feed={'X': np.ones(shape).astype('float32')},
fetch_list=[out1, out2, out3, out4, out5])

for i, dtype in enumerate(
[np.float32, np.bool, np.float64, np.int32, np.int64]):
self.assertEqual(outs[i].dtype, dtype)
self.assertEqual((outs[i] == np.zeros(shape, dtype)).all(), True)


class TestZerosLikeImpeartive(unittest.TestCase):
def test_out(self):
shape = [3, 4]
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with paddle.imperative.guard(place):
x = paddle.imperative.to_variable(np.ones(shape))
for dtype in [np.bool, np.float32, np.float64, np.int32, np.int64]:
out = zeros_like(x, dtype)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(),
True)

out = paddle.tensor.zeros_like(x)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(),
True)

out = paddle.tensor.creation.zeros_like(x)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(),
True)


if __name__ == "__main__":
unittest.main()
76 changes: 23 additions & 53 deletions python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def full_like(x, fill_value, dtype=None, name=None):
helper = LayerHelper("full_like", **locals())
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'full_like')
'full_like/zeros_like')
out = helper.create_variable_for_type_inference(dtype=dtype)

helper.append_op(
Expand All @@ -107,7 +107,7 @@ def full_like(x, fill_value, dtype=None, name=None):
attrs={'value': fill_value,
"dtype": dtype},
outputs={'Out': [out]})

out.stop_gradient = True
return out


Expand Down Expand Up @@ -254,74 +254,44 @@ def zeros(shape, dtype=None, name=None):
return fill_constant(value=0.0, shape=shape, dtype=dtype, name=name)


def zeros_like(input, dtype=None, device=None, name=None):
def zeros_like(x, dtype=None, name=None):
"""
:alias_main: paddle.zeros_like
:alias: paddle.zeros_like,paddle.tensor.zeros_like,paddle.tensor.creation.zeros_like
:alias: paddle.zeros_like, paddle.tensor.zeros_like, paddle.tensor.creation.zeros_like
This function creates a zeros tensor which has identical shape and dtype
with `input`.
Args:
input(Variable): The input tensor which specifies shape and dtype.The dtype of input can be
bool, float32, float64, int32, int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type can be set bool, float32, float64, int32, int64.
The default value is None, the dtype is the same as input.
device(str, optional): Which device to run the operator. The :attr:`device` must be
None, 'cpu', 'gpu'. If :attr:`device` is None, it will be choose the device that the user set in
the paddle program. Default value is None.
name(str, optional): The name of output variable, normally there is no need for user to set this this property.
Default value is None, the framework set the name of output variable.
x(Variable): The input tensor which specifies shape and dtype. The
dtype of input can be bool, float16, float32, float64, int32, int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type can
be set bool, float16, float32, float64, int32, int64. The default
value is None, the dtype is the same as input.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
out(Variable): The tensor variable storing the output.
Raise:
TypeError: If dtype is not bool, float16, float32, float64, int32 or int64.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
x = fluid.data(name='x', dtype='float32', shape=[3])
data = paddle.ones_like(x) # data=[1.0, 1.0, 1.0]
data1 = paddle.ones_like(input=x, device="gpu") #data1=[1.0, 1.0. 1.0]
import paddle
import numpy as np
"""
paddle.enable_imperative()
helper = LayerHelper("zeros_like", **locals())
x = paddle.imperative.to_variable(np.array([1,2,3], dtype='float32'))
out1 = paddle.zeros_like(x) # [1.0, 1.0, 1.0]
out2 = paddle.zeros_like(x, dtype='int32') # [1, 1, 1]
attrs = {"value": 0.0}
var_dtype = None
if dtype is not None:
check_dtype(dtype, 'create data type',
['bool', 'float32', 'float64', 'int32', 'int64'],
'zeros_like')
var_dtype = convert_np_dtype_to_dtype_(dtype)
attrs["dtype"] = var_dtype
else:
var_dtype = input.dtype

out = helper.create_variable_for_type_inference(dtype=var_dtype)

if device is not None:
if device not in ['cpu', 'gpu']:
raise ValueError(
"The value of 'device' in zeros_op must be cpu or gpu, but received %s."
% (device))
with fluid.device_guard(device):
helper.append_op(
type='fill_any_like',
inputs={'X': [input]},
attrs=attrs,
outputs={'Out': [out]})
return out
helper.append_op(
type='fill_any_like',
inputs={'X': [input]},
attrs=attrs,
outputs={'Out': [out]})
out.stop_gradient = True
return out
"""
return full_like(x=x, fill_value=0, dtype=dtype, name=name)


def eye(num_rows,
Expand Down

0 comments on commit f8eccb0

Please sign in to comment.