Skip to content

Commit

Permalink
randn API: remove out, devive, stop_gradient; add name (#25409)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhupengyang committed Jul 14, 2020
1 parent 41d2247 commit 2502925
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 166 deletions.
50 changes: 33 additions & 17 deletions python/paddle/fluid/layers/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10416,20 +10416,28 @@ def uniform_random_batch_size_like(input,


@templatedoc()
def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
def gaussian_random(shape,
mean=0.0,
std=1.0,
seed=0,
dtype='float32',
name=None):
"""
Generate a random tensor whose data is drawn from a Gaussian distribution.

Args:
shape (tuple[int] | list[int] | Variable | list[Variable]): Shape of the generated random tensor.

mean (float): Mean of the random tensor, defaults to 0.0.

std (float): Standard deviation of the random tensor, defaults to 1.0.

seed (int): ${seed_comment}

dtype(np.dtype | core.VarDesc.VarType | str): Output data type, float32 or float64.
shape(list|tuple|Variable): Shape of the Tensor to be created. The data
type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple,
the elements of it should be integers or Tensors with shape [1]. If
``shape`` is a Variable, it should be an 1-D Tensor .
mean(float): Mean of the random tensor, defaults to 0.0.
std(float): Standard deviation of the random tensor, defaults to 1.0.
seed(int): ${seed_comment}
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output
tensor, which can be float32, float64. Default is float32.
name(str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Default is None.

Returns:
Variable: Random tensor whose data is drawn from a Gaussian distribution, dtype: flaot32 or float64 as specified.
Expand Down Expand Up @@ -10492,11 +10500,16 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
# array([[2.3060477 , 2.676496 , 3.9911983 , 0.9990833 ],
# [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32)
"""

check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random')
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random')

if in_dygraph_mode():
shape = utils._convert_shape_to_list(shape)
return core.ops.gaussian_random('shape', shape, 'mean', mean, 'std',
std, 'seed', seed, 'dtype', dtype)

check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random/randn')
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random/randn')

inputs = {}
attrs = {
Expand All @@ -10507,7 +10520,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
'use_mkldnn': False
}
utils._get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='gaussian_random')
inputs=inputs,
attrs=attrs,
shape=shape,
op_type='gaussian_random/randn')

helper = LayerHelper('gaussian_random', **locals())
out = helper.create_variable_for_type_inference(dtype)
Expand Down Expand Up @@ -15011,13 +15027,13 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0,
float(min), 'max',
float(max), 'seed', seed, 'dtype', dtype)

check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random')
check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random')
check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random/rand')
check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random/rand')

inputs = dict()
attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
utils._get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random')
inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random/rand')

helper = LayerHelper("uniform_random", **locals())
out = helper.create_variable_for_type_inference(dtype)
Expand Down
121 changes: 50 additions & 71 deletions python/paddle/fluid/tests/unittests/test_randn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,92 +17,71 @@
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle import Program, program_guard


class TestRandnOp(unittest.TestCase):
def test_api(self):
x1 = paddle.randn(shape=[1000, 784], dtype='float32')
x2 = paddle.randn(shape=[1000, 784], dtype='float64')
x3 = fluid.layers.fill_constant(
shape=[1000, 784], dtype='float32', value=0)
paddle.randn(shape=[1000, 784], out=x3, dtype='float32')
x4 = paddle.randn(shape=[1000, 784], dtype='float32', device='cpu')
x5 = paddle.randn(shape=[1000, 784], dtype='float32', device='gpu')
x6 = paddle.randn(
shape=[1000, 784],
dtype='float32',
device='gpu',
stop_gradient=False)

place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(fluid.default_main_program(),
feed={},
fetch_list=[x1, x2, x3, x4, x5, x6])

self.assertAlmostEqual(np.mean(res[0]), .0, delta=0.1)
self.assertAlmostEqual(np.std(res[0]), 1., delta=0.1)
self.assertAlmostEqual(np.mean(res[1]), .0, delta=0.1)
self.assertAlmostEqual(np.std(res[1]), 1., delta=0.1)
self.assertAlmostEqual(np.mean(res[2]), .0, delta=0.1)
self.assertAlmostEqual(np.std(res[2]), 1., delta=0.1)
self.assertAlmostEqual(np.mean(res[3]), .0, delta=0.1)
self.assertAlmostEqual(np.std(res[3]), 1., delta=0.1)
self.assertAlmostEqual(np.mean(res[4]), .0, delta=0.1)
self.assertAlmostEqual(np.std(res[4]), 1., delta=0.1)
self.assertAlmostEqual(np.mean(res[5]), .0, delta=0.1)
self.assertAlmostEqual(np.std(res[5]), 1., delta=0.1)
shape = [1000, 784]
train_program = Program()
startup_program = Program()
with program_guard(train_program, startup_program):
x1 = paddle.randn(shape, 'float32')
x2 = paddle.randn(shape, 'float64')

dim_1 = paddle.fill_constant([1], "int64", 20)
dim_2 = paddle.fill_constant([1], "int32", 50)
x3 = paddle.randn([dim_1, dim_2, 784])

var_shape = paddle.nn.data('X', [2], 'int32')
x4 = paddle.randn(var_shape)

place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
exe = paddle.Executor(place)
res = exe.run(train_program,
feed={'X': np.array(
shape, dtype='int32')},
fetch_list=[x1, x2, x3, x4])

for out in res:
self.assertAlmostEqual(np.mean(out), .0, delta=0.1)
self.assertAlmostEqual(np.std(out), 1., delta=0.1)


class TestRandnOpForDygraph(unittest.TestCase):
def test_api(self):
shape = [1000, 784]
place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda(
) else paddle.CPUPlace()
with paddle.imperative.guard(place):
x1 = paddle.randn(shape, 'float32')
x2 = paddle.randn(shape, 'float64')

dim_1 = paddle.fill_constant([1], "int64", 20)
dim_2 = paddle.fill_constant([1], "int32", 50)
x3 = paddle.randn(shape=[dim_1, dim_2, 784])

var_shape = paddle.imperative.to_variable(np.array(shape))
x4 = paddle.randn(var_shape)

for out in [x1, x2, x3, x4]:
self.assertAlmostEqual(np.mean(out.numpy()), .0, delta=0.1)
self.assertAlmostEqual(np.std(out.numpy()), 1., delta=0.1)


class TestRandnOpError(unittest.TestCase):
def test_error(self):
with program_guard(Program(), Program()):

# The argument shape's size of randn_op should not be 0.
def test_shape_size():
out = paddle.randn(shape=[])

self.assertRaises(AssertionError, test_shape_size)
self.assertRaises(AssertionError, paddle.randn, [])

# The argument shape's type of randn_op should be list or tuple.
def test_shape_type():
out = paddle.randn(shape=1)

self.assertRaises(TypeError, test_shape_type)

# The argument dtype of randn_op should be float32 or float64.
def test_dtype_float16():
out = paddle.randn(shape=[1, 2], dtype='float16')

self.assertRaises(TypeError, test_dtype_float16)
self.assertRaises(TypeError, paddle.randn, 1)

# The argument dtype of randn_op should be float32 or float64.
def test_dtype_int32():
out = paddle.randn(shape=[1, 2], dtype='int32')

self.assertRaises(TypeError, test_dtype_int32)

# The argument dtype of randn_op should be float32 or float64.
def test_dtype_int64():
out = paddle.randn(shape=[1, 2], dtype='int64')

self.assertRaises(TypeError, test_dtype_int64)

# The argument dtype of randn_op should be float32 or float64.
def test_dtype_uint8():
out = paddle.randn(shape=[1, 2], dtype='uint8')

self.assertRaises(TypeError, test_dtype_uint8)

# The argument dtype of randn_op should be float32 or float64.
def test_dtype_bool():
out = paddle.randn(shape=[1, 2], dtype='bool')

self.assertRaises(TypeError, test_dtype_bool)
self.assertRaises(TypeError, paddle.randn, [1, 2], 'int32')


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 2502925

Please sign in to comment.