diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 7769bf643d867..5a5ec20b2ce3b 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -10416,20 +10416,28 @@ def uniform_random_batch_size_like(input,
 
 @templatedoc()
-def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
+def gaussian_random(shape,
+                    mean=0.0,
+                    std=1.0,
+                    seed=0,
+                    dtype='float32',
+                    name=None):
     """
     Generate a random tensor whose data is drawn from a Gaussian distribution.
 
     Args:
-        shape (tuple[int] | list[int] | Variable | list[Variable]): Shape of the generated random tensor.
-
-        mean (float): Mean of the random tensor, defaults to 0.0.
-
-        std (float): Standard deviation of the random tensor, defaults to 1.0.
-
-        seed (int): ${seed_comment}
-
-        dtype(np.dtype | core.VarDesc.VarType | str): Output data type, float32 or float64.
+        shape(list|tuple|Variable): Shape of the Tensor to be created. The data
+            type is ``int32`` or ``int64``. If ``shape`` is a list or tuple,
+            its elements should be integers or Tensors with shape [1]. If
+            ``shape`` is a Variable, it should be a 1-D Tensor.
+        mean(float): Mean of the random tensor, defaults to 0.0.
+        std(float): Standard deviation of the random tensor, defaults to 1.0.
+        seed(int): ${seed_comment}
+        dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the
+            output tensor, which can be float32 or float64. Default is float32.
+        name(str, optional): Normally there is no need for the user to set this
+            property. For more information, please refer to
+            :ref:`api_guide_Name`. Default is None.
 
     Returns:
         Variable: Random tensor whose data is drawn from a Gaussian distribution, dtype: float32 or float64 as specified.
@@ -10492,11 +10500,16 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
         #       array([[2.3060477 , 2.676496  , 3.9911983 , 0.9990833 ],
         #              [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32)
     """
-
-    check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random')
     if not isinstance(dtype, core.VarDesc.VarType):
         dtype = convert_np_dtype_to_dtype_(dtype)
-    check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random')
+
+    if in_dygraph_mode():
+        shape = utils._convert_shape_to_list(shape)
+        return core.ops.gaussian_random('shape', shape, 'mean', mean, 'std',
+                                        std, 'seed', seed, 'dtype', dtype)
+
+    check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random/randn')
+    check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random/randn')
 
     inputs = {}
     attrs = {
@@ -10507,7 +10520,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
         'use_mkldnn': False
    }
     utils._get_shape_tensor_inputs(
-        inputs=inputs, attrs=attrs, shape=shape, op_type='gaussian_random')
+        inputs=inputs,
+        attrs=attrs,
+        shape=shape,
+        op_type='gaussian_random/randn')
 
     helper = LayerHelper('gaussian_random', **locals())
     out = helper.create_variable_for_type_inference(dtype)
@@ -15011,13 +15027,13 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
                                          float(min), 'max',
                                          float(max), 'seed', seed, 'dtype', dtype)
 
-    check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random')
-    check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random')
+    check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random/rand')
+    check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random/rand')
 
     inputs = dict()
     attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
     utils._get_shape_tensor_inputs(
-        inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random')
+        inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')
 
     helper = LayerHelper("uniform_random", **locals())
     out = helper.create_variable_for_type_inference(dtype)
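The dygraph fast path added to gaussian_random above is the core of this change: under in_dygraph_mode(), the requested shape is flattened to a plain Python list by utils._convert_shape_to_list (resolving any 1-element shape Tensors to ints) and dispatched directly to core.ops.gaussian_random, bypassing the static-graph checks. A minimal sketch of how that branch is exercised, assuming the 1.8-era fluid.dygraph API; the printed shapes are the only part of the output that is deterministic:

    import paddle.fluid as fluid

    with fluid.dygraph.guard(fluid.CPUPlace()):
        # Plain ints go straight through as the op's 'shape' attribute.
        x = fluid.layers.gaussian_random([2, 3], mean=0.0, std=1.0)

        # A 1-element int Tensor mixed into the list is resolved to a Python
        # int by utils._convert_shape_to_list before the C++ op runs.
        dim = fluid.layers.fill_constant([1], "int64", 4)
        y = fluid.layers.gaussian_random([dim, 3])

        print(x.shape, y.shape)  # [2, 3] [4, 3]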
diff --git a/python/paddle/fluid/tests/unittests/test_randn_op.py b/python/paddle/fluid/tests/unittests/test_randn_op.py
index 808e5a08fd65e..f65cc6dc53b7e 100644
--- a/python/paddle/fluid/tests/unittests/test_randn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_randn_op.py
@@ -17,92 +17,71 @@
 import unittest
 import numpy as np
 import paddle
-import paddle.fluid as fluid
 import paddle.fluid.core as core
-from paddle.fluid import Program, program_guard
+from paddle import Program, program_guard
 
 
 class TestRandnOp(unittest.TestCase):
     def test_api(self):
-        x1 = paddle.randn(shape=[1000, 784], dtype='float32')
-        x2 = paddle.randn(shape=[1000, 784], dtype='float64')
-        x3 = fluid.layers.fill_constant(
-            shape=[1000, 784], dtype='float32', value=0)
-        paddle.randn(shape=[1000, 784], out=x3, dtype='float32')
-        x4 = paddle.randn(shape=[1000, 784], dtype='float32', device='cpu')
-        x5 = paddle.randn(shape=[1000, 784], dtype='float32', device='gpu')
-        x6 = paddle.randn(
-            shape=[1000, 784],
-            dtype='float32',
-            device='gpu',
-            stop_gradient=False)
-
-        place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        res = exe.run(fluid.default_main_program(),
-                      feed={},
-                      fetch_list=[x1, x2, x3, x4, x5, x6])
-
-        self.assertAlmostEqual(np.mean(res[0]), .0, delta=0.1)
-        self.assertAlmostEqual(np.std(res[0]), 1., delta=0.1)
-        self.assertAlmostEqual(np.mean(res[1]), .0, delta=0.1)
-        self.assertAlmostEqual(np.std(res[1]), 1., delta=0.1)
-        self.assertAlmostEqual(np.mean(res[2]), .0, delta=0.1)
-        self.assertAlmostEqual(np.std(res[2]), 1., delta=0.1)
-        self.assertAlmostEqual(np.mean(res[3]), .0, delta=0.1)
-        self.assertAlmostEqual(np.std(res[3]), 1., delta=0.1)
-        self.assertAlmostEqual(np.mean(res[4]), .0, delta=0.1)
-        self.assertAlmostEqual(np.std(res[4]), 1., delta=0.1)
-        self.assertAlmostEqual(np.mean(res[5]), .0, delta=0.1)
-        self.assertAlmostEqual(np.std(res[5]), 1., delta=0.1)
+        shape = [1000, 784]
+        train_program = Program()
+        startup_program = Program()
+        with program_guard(train_program, startup_program):
+            x1 = paddle.randn(shape, 'float32')
+            x2 = paddle.randn(shape, 'float64')
+
+            dim_1 = paddle.fill_constant([1], "int64", 20)
+            dim_2 = paddle.fill_constant([1], "int32", 50)
+            x3 = paddle.randn([dim_1, dim_2, 784])
+
+            var_shape = paddle.nn.data('X', [2], 'int32')
+            x4 = paddle.randn(var_shape)
+
+        place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        exe = paddle.Executor(place)
+        res = exe.run(train_program,
+                      feed={'X': np.array(
+                          shape, dtype='int32')},
+                      fetch_list=[x1, x2, x3, x4])
+
+        for out in res:
+            self.assertAlmostEqual(np.mean(out), .0, delta=0.1)
+            self.assertAlmostEqual(np.std(out), 1., delta=0.1)
+
+
+class TestRandnOpForDygraph(unittest.TestCase):
+    def test_api(self):
+        shape = [1000, 784]
+        place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        with paddle.imperative.guard(place):
+            x1 = paddle.randn(shape, 'float32')
+            x2 = paddle.randn(shape, 'float64')
+
+            dim_1 = paddle.fill_constant([1], "int64", 20)
+            dim_2 = paddle.fill_constant([1], "int32", 50)
+            x3 = paddle.randn(shape=[dim_1, dim_2, 784])
+
+            var_shape = paddle.imperative.to_variable(np.array(shape))
+            x4 = paddle.randn(var_shape)
+
+            for out in [x1, x2, x3, x4]:
+                self.assertAlmostEqual(np.mean(out.numpy()), .0, delta=0.1)
+                self.assertAlmostEqual(np.std(out.numpy()), 1., delta=0.1)
 
 
 class TestRandnOpError(unittest.TestCase):
     def test_error(self):
         with program_guard(Program(), Program()):
             # The argument shape's size of randn_op should not be 0.
-            def test_shape_size():
-                out = paddle.randn(shape=[])
-
-            self.assertRaises(AssertionError, test_shape_size)
+            self.assertRaises(AssertionError, paddle.randn, [])
 
             # The argument shape's type of randn_op should be list or tuple.
-            def test_shape_type():
-                out = paddle.randn(shape=1)
-
-            self.assertRaises(TypeError, test_shape_type)
-
-            # The argument dtype of randn_op should be float32 or float64.
-            def test_dtype_float16():
-                out = paddle.randn(shape=[1, 2], dtype='float16')
-
-            self.assertRaises(TypeError, test_dtype_float16)
+            self.assertRaises(TypeError, paddle.randn, 1)
 
             # The argument dtype of randn_op should be float32 or float64.
-            def test_dtype_int32():
-                out = paddle.randn(shape=[1, 2], dtype='int32')
-
-            self.assertRaises(TypeError, test_dtype_int32)
-
-            # The argument dtype of randn_op should be float32 or float64.
-            def test_dtype_int64():
-                out = paddle.randn(shape=[1, 2], dtype='int64')
-
-            self.assertRaises(TypeError, test_dtype_int64)
-
-            # The argument dtype of randn_op should be float32 or float64.
-            def test_dtype_uint8():
-                out = paddle.randn(shape=[1, 2], dtype='uint8')
-
-            self.assertRaises(TypeError, test_dtype_uint8)
-
-            # The argument dtype of randn_op should be float32 or float64.
-            def test_dtype_bool():
-                out = paddle.randn(shape=[1, 2], dtype='bool')
-
-            self.assertRaises(TypeError, test_dtype_bool)
+            self.assertRaises(TypeError, paddle.randn, [1, 2], 'int32')
 
 
 if __name__ == "__main__":
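The rewritten static-graph test above covers the three ways a shape can now be supplied: a plain list, a list mixing ints with 1-element shape Tensors, and a 1-D int32/int64 Variable fed at run time. A condensed sketch of the fed-Variable pattern, using the same 2.0-alpha aliases as the test (paddle.nn.data, paddle.Executor); the printed shape is the illustrative outcome:

    import numpy as np
    import paddle
    import paddle.fluid.core as core
    from paddle import Program, program_guard

    main_prog, startup_prog = Program(), Program()
    with program_guard(main_prog, startup_prog):
        # A 1-D int32 Tensor holding the target shape, fed at run time.
        shape_var = paddle.nn.data('shape', [2], 'int32')
        out = paddle.randn(shape_var)

    place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() else paddle.CPUPlace()
    exe = paddle.Executor(place)
    res, = exe.run(main_prog,
                   feed={'shape': np.array([8, 16], dtype='int32')},
                   fetch_list=[out])
    print(res.shape)  # (8, 16)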
diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py
index 7b105cc01e1d8..8eabaa84ce3d3 100644
--- a/python/paddle/tensor/random.py
+++ b/python/paddle/tensor/random.py
@@ -21,7 +21,7 @@
 from ..fluid.layers.layer_function_generator import templatedoc
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
-from ..fluid.layers import uniform_random, utils
+from ..fluid.layers import utils, uniform_random, gaussian_random
 from ..fluid.layers.tensor import fill_constant
 from ..fluid.io import shuffle  #DEFINE_ALIAS
@@ -206,36 +206,23 @@ def get_attr_shape(list_shape):
     return out
 
 
-def randn(shape,
-          out=None,
-          dtype=None,
-          device=None,
-          stop_gradient=True,
-          name=None):
+def randn(shape, dtype=None, name=None):
     """
 	:alias_main: paddle.randn
 	:alias: paddle.randn,paddle.tensor.randn,paddle.tensor.random.randn
 
     This function returns a tensor filled with random numbers from a normal
-    distribution with mean 0 and variance 1 (also called the standard normal
+    distribution with mean 0 and standard deviation 1 (also called the standard normal
     distribution).
 
     Args:
-        shape(list|tuple): Shape of the generated random tensor.
-        out(Variable, optional): Optional output which can be any created Variable
-            that meets the requirements to store the result of operation. If the
-            out is `None`, a new Variable will be returned to store the result.
-            Default is None.
-        dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output
-            tensor, which can be float32, float64. if dtype is `None` , the data
-            type of output tensor is `float32` .
-            Default is None.
-        device(str, optional): Specific the output variable to be saved in cpu
-            or gpu memory. Supported None, 'cpu', 'gpu'.
-            If it is None, the output
-            variable will be automatically assigned devices. Default: None.
-        stop_gradient(bool, optional): Indicating if we stop gradient from current(out)
-            Variable. Default is True.
+        shape(list|tuple|Variable): Shape of the Tensor to be created. The data
+            type is ``int32`` or ``int64``. If ``shape`` is a list or tuple,
+            its elements should be integers or Tensors with shape [1]. If
+            ``shape`` is a Variable, it should be a 1-D Tensor.
+        dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the
+            output tensor, which can be float32 or float64. If dtype is `None`,
+            the data type of the output tensor is `float32`. Default is None.
         name(str, optional): Normally there is no need for the user to set this
            property. For more information, please refer to
            :ref:`api_guide_Name`. Default is None.
 
     Returns:
         Random tensor whose data is drawn from a standard normal distribution,
         dtype: float32 or float64 as specified.
 
-    Return type:
-        Variable
+    Return type: Variable
 
     Raises:
-        TypeError: If the type of `shape` is not list or tuple.
+        TypeError: If the type of `shape` is not Variable, list or tuple.
         TypeError: If the data type of `dtype` is not float32 or float64.
         ValueError: If the length of `shape` is not bigger than 0.
 
     Examples:
         .. code-block:: python
 
-          # declarative mode
-          import paddle
-          import paddle.fluid as fluid
+            import paddle
+            import numpy as np
 
-          data = paddle.randn([2, 4])
-          place = fluid.CPUPlace()
-          exe = fluid.Executor(place)
-          res, = exe.run(fluid.default_main_program(), feed={}, fetch_list=[data])
-          print(res)
-          # [[-1.4187592   0.7368311  -0.53748125 -0.0146909 ]
-          #  [-0.66294265 -1.3090698   0.1898754  -0.14065823]]
+            paddle.enable_imperative()
 
-        .. code-block:: python
+            # example 1: attr shape is a list which doesn't contain tensor Variable.
+            result_1 = paddle.randn(shape=[2, 3])
+            # [[-2.923464    0.11934398 -0.51249987]
+            #  [ 0.39632758  0.08177969  0.2692008 ]]
 
-          # imperative mode
-          import paddle
-          import paddle.fluid as fluid
-          import paddle.fluid.dygraph as dg
-
-          place = fluid.CPUPlace()
-          with dg.guard(place) as g:
-              x = paddle.randn([2, 4])
-              x_np = x.numpy()
-              print(x_np)
-          # [[ 1.5149173  -0.26234224 -0.592486    1.4523455 ]
-          #  [ 0.04581212 -0.85345626  1.1687907  -0.02512913]]
-    """
-    helper = LayerHelper("randn", **locals())
-    check_type(shape, 'shape', (list, tuple), 'randn')
-    assert len(shape) > 0, ("The size of argument(shape) can't be zero.")
+            # example 2: attr shape is a list which contains tensor Variable.
+            dim_1 = paddle.fill_constant([1], "int64", 2)
+            dim_2 = paddle.fill_constant([1], "int32", 3)
+            result_2 = paddle.randn(shape=[dim_1, dim_2, 2])
+            # [[[-2.8852394  -0.25898588]
+            #   [-0.47420555  0.17683524]
+            #   [-0.7989969   0.00754541]]
+            #  [[ 0.85201347  0.32320443]
+            #   [ 1.1399018   0.48336947]
+            #   [ 0.8086993   0.6868893 ]]]
 
+            # example 3: attr shape is a Variable, the data type must be int64 or int32.
+            var_shape = paddle.imperative.to_variable(np.array([2, 3]))
+            result_3 = paddle.randn(var_shape)
+            # [[-2.878077    0.17099959  0.05111201]
+            #  [-0.3761474  -1.044801    1.1870178 ]]
+    """
     if dtype is None:
         dtype = 'float32'
-    check_dtype(dtype, 'create data type', ['float32', 'float64'], 'randn')
-
-    if out is None:
-        out = helper.create_variable_for_type_inference(dtype=dtype)
-    else:
-        check_variable_and_dtype(out, 'out', [dtype], 'randn')
-
-    out.stop_gradient = stop_gradient
-
-    dtype = convert_np_dtype_to_dtype_(dtype)
-    seed = np.random.randint(0, 100)
-
-    with device_guard(device):
-        helper.append_op(
-            type='gaussian_random',
-            outputs={'Out': out},
-            attrs={
-                'shape': shape,
-                'mean': 0.0,
-                'std': 1.0,
-                'seed': seed,
-                'dtype': dtype,
-                'use_mkldnn': False
-            })
+    out = gaussian_random(
+        shape=shape, mean=0.0, std=1.0, seed=0, dtype=dtype, name=name)
+    out.stop_gradient = True
 
     return out
 
@@ -369,6 +331,7 @@ def randperm(n, dtype="int64", name=None):
     attrs = {'n': n, 'dtype': dtype, 'seed': 0}
     helper.append_op(
         type='randperm', inputs={}, outputs={'Out': out}, attrs=attrs)
+    out.stop_gradient = True
 
     return out
 
@@ -439,4 +402,7 @@ def rand(shape, dtype=None, name=None):
     """
     if dtype is None:
         dtype = 'float32'
-    return uniform_random(shape, dtype, min=0.0, max=1.0, name=name)
+
+    out = uniform_random(shape, dtype, min=0.0, max=1.0, name=name)
+    out.stop_gradient = True
+    return out
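With the rewrite above, randn becomes a thin wrapper over gaussian_random, and rand, randn, and randperm all pin stop_gradient=True on their outputs: sampling has no meaningful gradient, so the random tensor enters downstream computation as a constant. A short sketch of the user-visible effect, assuming the imperative API used in this PR's docstring examples:

    import paddle

    paddle.enable_imperative()

    noise = paddle.randn([2, 3])
    # The sample is detached: no gradient will flow back through it.
    print(noise.stop_gradient)  # True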