From 9ef4b53526e2c74294c365540bb8a54991a59ed6 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Mon, 29 Jun 2020 14:19:39 +0800 Subject: [PATCH 1/4] rand API: remove out, devive, stop_gradient; add name test=develop --- paddle/fluid/operators/gaussian_random_op.cc | 2 +- python/paddle/fluid/layers/nn.py | 100 ++++++----------- python/paddle/fluid/layers/tensor.py | 8 +- python/paddle/fluid/layers/utils.py | 14 +-- .../fluid/tests/unittests/test_rand_op.py | 58 +++++----- python/paddle/tensor/random.py | 103 +++++++----------- 6 files changed, 113 insertions(+), 172 deletions(-) diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 0f1e4de5cbb88..253078751ce66 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -98,7 +98,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { return; } - if (!(ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList"))) { + if (!ctx->HasInput("ShapeTensor") && !ctx->HasInputs("ShapeTensorList")) { PADDLE_ENFORCE_GT( shape.size(), 0UL, platform::errors::InvalidArgument( diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 11a4d933245dd..f5f0fb5c08687 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -10486,29 +10486,24 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'): # [2.8675377 , 2.2279181 , 0.79029655, 2.8447366 ]], dtype=float32) """ - helper = LayerHelper('gaussian_random', **locals()) - out = helper.create_variable_for_type_inference(dtype) - if not isinstance(shape, (list, tuple, Variable)): - raise TypeError( - "The type of 'shape' in fill_constant must be Variable, list or tuple, but " - "received %s." % (type(shape))) - c_dtype = convert_np_dtype_to_dtype_(dtype) + check_type(shape, 'shape', (list, tuple, Variable), 'gaussian_random') + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) + check_dtype(dtype, 'dtype', ['float32', 'float64'], 'gaussian_random') + + inputs = {} attrs = { 'mean': mean, 'std': std, 'seed': seed, - 'dtype': c_dtype, + 'dtype': dtype, 'use_mkldnn': False } - - inputs = {} utils._get_shape_tensor_inputs( - inputs=inputs, - helper=helper, - attrs=attrs, - shape=shape, - op_type='gaussian_random') + inputs=inputs, attrs=attrs, shape=shape, op_type='gaussian_random') + helper = LayerHelper('gaussian_random', **locals()) + out = helper.create_variable_for_type_inference(dtype) helper.append_op( type='gaussian_random', inputs=inputs, @@ -14863,7 +14858,8 @@ def gather_tree(ids, parents): @templatedoc() -def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): +def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0, + name=None): """ This OP initializes a variable with random values sampled from a uniform distribution in the range [min, max). @@ -14878,18 +14874,24 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): result=[[0.8505902, 0.8397286]] Args: - shape (list|tuple|Variable): The shape of the output Tensor, if the shape is a list or tuple, - its elements can be an integer - or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64. - If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be int32 or int64. - dtype(np.dtype|core.VarDesc.VarType|str, optional): The type of the output Tensor. Supported data types: float32, float64. - Default: float32. - min (float, optional): The lower bound on the range of random values to generate, the min is included in the range. Default -1.0. - max (float, optional): The upper bound on the range of random values to generate, the max is excluded in the range. Default 1.0. - seed (int, optional): Random seed used for generating samples. 0 means use a - seed generated by the system. Note that if seed is not 0, this - operator will always generate the same random numbers every time. - Default 0. + shape (list|tuple|Variable): The shape of the output Tensor, if the + shape is a list or tuple, its elements can be an integer or a + Tensor with the shape [1], and the type of the Tensor must be + int32 or int64. If the shape is a Variable, it is a 1-D Tensor, and + the type of the Tensor must be int32 or int64. + dtype(np.dtype|core.VarDesc.VarType|str, optional): The type of the + output Tensor. Supported data types: float32, float64. Default: float32. + min (float, optional): The lower bound on the range of random values + to generate, the min is included in the range. Default -1.0. + max (float, optional): The upper bound on the range of random values + to generate, the max is excluded in the range. Default 1.0. + seed (int, optional): Random seed used for generating samples. 0 means + use a seed generated by the system. Note that if seed is not 0, + this operator will always generate the same random numbers every + time. Default 0. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Returns: Variable: A Tensor of the specified shape filled with uniform_random values. @@ -14919,62 +14921,26 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0): var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32") result_4 = fluid.layers.uniform_random(var_shape_int32) - - """ check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random') if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random') - def get_new_shape_tensor(list_shape): - new_shape_tensor = [] - for dim in list_shape: - if isinstance(dim, Variable): - dim.stop_gradient = True - new_shape_tensor.append(dim) - else: - assert (isinstance(dim, int)) - temp_out = helper.create_variable_for_type_inference('int64') - fill_constant([1], 'int64', dim, force_cpu=True, out=temp_out) - new_shape_tensor.append(temp_out) - return new_shape_tensor - - def get_attr_shape(list_shape): - unk_dim_idx = -1 - attrs_shape = [] - for dim_idx, dim_size in enumerate(list_shape): - if isinstance(dim_size, Variable): - attrs_shape.append(-1) - else: - attrs_shape.append(dim_size) - assert dim_size > 0, ( - "Each dimension size given in shape must not be negative " - "except one unknown dimension.") - return attrs_shape - helper = LayerHelper("uniform_random", **locals()) inputs = dict() attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype} if in_dygraph_mode(): attrs['shape'] = shape else: - if isinstance(shape, Variable): - shape.stop_gradient = True - inputs["ShapeTensor"] = shape - elif isinstance(shape, (list, tuple)): - assert len(shape) > 0, ( - "The size of argument(shape) can't be zero.") - attrs["shape"] = get_attr_shape(shape) - if utils._contain_var(shape): - inputs['ShapeTensorList'] = get_new_shape_tensor(shape) + utils._get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random') out = helper.create_variable_for_type_inference(dtype) helper.append_op( type="uniform_random", inputs=inputs, attrs=attrs, outputs={"Out": out}) - - return helper.append_activation(out) + return out def unbind(input, axis=0): diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index a7f10584b73f9..10424089f7283 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -719,12 +719,8 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): 'fill_constant') helper = LayerHelper("fill_constant", **locals()) - inputs = utils._get_shape_tensor_inputs( - inputs=inputs, - helper=helper, - attrs=attrs, - shape=shape, - op_type='fill_constant') + utils._get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=shape, op_type='fill_constant') if out is None: out = helper.create_variable_for_type_inference(dtype=dtype) diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index df00a0f561ffc..88d57f4873cf2 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -282,7 +282,7 @@ def _contain_var(list_or_tuple): return False -def _get_shape_tensor_inputs(inputs, helper, attrs, shape, op_type): +def _get_shape_tensor_inputs(inputs, attrs, shape, op_type): from .tensor import fill_constant, cast def _get_attr_shape(list_shape): @@ -295,7 +295,7 @@ def _get_attr_shape(list_shape): return attr_shape def _get_shape_tensor(list_shape): - new_shape_tensor = [] + shape_tensor_list = [] for idx, dim in enumerate(list_shape): if isinstance(dim, Variable): dim.stop_gradient = True @@ -305,11 +305,11 @@ def _get_shape_tensor(list_shape): '(When type of shape in' + op_type + 'is list or tuple.)') if convert_dtype(dim.dtype) == 'int64': dim = cast(x=dim, dtype='int32') - new_shape_tensor.append(dim) + shape_tensor_list.append(dim) else: temp_out = fill_constant([1], 'int32', dim, force_cpu=True) - new_shape_tensor.append(temp_out) - return new_shape_tensor + shape_tensor_list.append(temp_out) + return shape_tensor_list if isinstance(shape, Variable): shape.stop_gradient = True @@ -325,8 +325,8 @@ def _get_shape_tensor(list_shape): attrs["shape"] = _get_attr_shape(shape) if _contain_var(shape): inputs['ShapeTensorList'] = _get_shape_tensor(shape) - - return inputs + else: + raise TypeError("Shape only supports Variable, or list, or tuple.") def _convert_to_tensor_list(old_list, dtype="int32"): diff --git a/python/paddle/fluid/tests/unittests/test_rand_op.py b/python/paddle/fluid/tests/unittests/test_rand_op.py index 4725e2fae2d0b..c8e0130b77dc6 100644 --- a/python/paddle/fluid/tests/unittests/test_rand_op.py +++ b/python/paddle/fluid/tests/unittests/test_rand_op.py @@ -47,71 +47,73 @@ def test_dtype(): self.assertRaises(TypeError, test_dtype) - def test_shape_list(): - rand(shape=[2.]) - - self.assertRaises(TypeError, test_shape_list) - - def test_shape_list2(): - rand(shape=[2, 3.]) - - self.assertRaises(TypeError, test_shape_list2) - - def test_device(): - rand(shape=[3, 4], device='device') - - self.assertRaises(ValueError, test_device) - class TestRandOp(unittest.TestCase): """ This class test the common usages of randop. - """ - def test_run(self): - use_cuda = False + def run_net(self, use_cuda=False): place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() exe = fluid.Executor(place) train_program = fluid.Program() startup_program = fluid.Program() with fluid.program_guard(train_program, startup_program): - result_1 = rand(shape=[3, 4]) + result_0 = rand([3, 4]) + result_1 = rand([3, 4], 'float64') + dim_1 = fluid.layers.fill_constant([1], "int64", 3) dim_2 = fluid.layers.fill_constant([1], "int32", 5) result_2 = rand(shape=[dim_1, dim_2]) + var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64") result_3 = rand(var_shape) + var_shape_int32 = fluid.data( name='var_shape_int32', shape=[2], dtype="int32") result_4 = rand(var_shape_int32) + exe.run(startup_program) x1 = np.array([3, 2]).astype('int64') x2 = np.array([4, 3]).astype('int32') - ret = exe.run(train_program, - feed={"var_shape": x1, - "var_shape_int32": x2}, - fetch_list=[result_1, result_2, result_3, result_4]) + ret = exe.run( + train_program, + feed={"var_shape": x1, + "var_shape_int32": x2}, + fetch_list=[result_1, result_1, result_2, result_3, result_4]) + + def test_run(self): + self.run_net(False) + if core.is_compiled_with_cuda(): + self.run_net(True) class TestRandOpForDygraph(unittest.TestCase): """ This class test the common usages of randop. - """ - def test_run(self): - use_cuda = False - with fluid.dygraph.guard(): - rand(shape=[3, 4]) + def run_net(self, use_cuda=False): + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + with fluid.dygraph.guard(place): + rand([3, 4]) + + rand([3, 4], 'float64') + dim_1 = fluid.layers.fill_constant([1], "int64", 3) dim_2 = fluid.layers.fill_constant([1], "int32", 5) rand(shape=[dim_1, dim_2]) + var_shape = fluid.dygraph.to_variable(np.array([3, 4])) rand(var_shape) + def test_run(self): + self.run_net(False) + if core.is_compiled_with_cuda(): + self.run_net(True) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index feb2f6afd000d..e04519a4bd7db 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -406,7 +406,7 @@ def randperm(n, return out -def rand(shape, out=None, dtype=None, device=None, stop_gradient=True): +def rand(shape, dtype=None, name=None): """ :alias_main: paddle.rand :alias: paddle.rand,paddle.tensor.rand,paddle.tensor.random.rand @@ -424,22 +424,19 @@ def rand(shape, out=None, dtype=None, device=None, stop_gradient=True): result=[[0.8505902, 0.8397286]] Args: - shape(list|tuple|Variable): Shape of the Tensor to be created. - The data type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple, - the elements of it should be integers or Tensors with shape [1]. - If ``shape`` is a Variable, it should be an 1-D Tensor . - out(Variable, optional): Optional output which can be any created - Variable that meets the requirements to store the result of operation. - if out is None, a new Varibale will be create to store the result. - dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output tensor - which can be float32, float64, if dytpe is `None`, the data - type of created tensor is `float32` - device(str, optional): This parameter specifies that the Tensor is created - on the GPU or CPU. - stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable, - default value is True. + shape(list|tuple|Variable): Shape of the Tensor to be created. The data + type is ``int32`` or ``int64`` . If ``shape`` is a list or tuple, + the elements of it should be integers or Tensors with shape [1]. If + ``shape`` is a Variable, it should be an 1-D Tensor . + dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the + output tensor which can be float32, float64, if dytpe is `None`, + the data type of created tensor is `float32` + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. Returns: - Variable: A Tensor of the specified shape filled with random numbers from a uniform distribution on the interval [0, 1). + Variable: A Tensor of the specified shape filled with random numbers + from a uniform distribution on the interval [0, 1). Raises: TypeError: The shape type should be list or tupple or Variable. @@ -447,54 +444,34 @@ def rand(shape, out=None, dtype=None, device=None, stop_gradient=True): Examples: .. code-block:: python - import paddle - import paddle.fluid as fluid - - # example 1: - # attr shape is a list which doesn't contain tensor Variable. - result_1 = paddle.rand(shape=[3, 4]) - - # example 2: - # attr shape is a list which contains tensor Variable. - dim_1 = fluid.layers.fill_constant([1],"int64",3) - dim_2 = fluid.layers.fill_constant([1],"int32",5) - result_2 = paddle.rand(shape=[dim_1, dim_2]) + import paddle + import numpy as np + + paddle.enable_imperative() + + # example 1: attr shape is a list which doesn't contain tensor Variable. + result_1 = paddle.rand(shape=[2, 3]) + # [[0.451152 , 0.55825245, 0.403311 ], + # [0.22550228, 0.22106001, 0.7877319 ]] + + # example 2: attr shape is a list which contains tensor Variable. + dim_1 = paddle.fill_constant([1], "int64", 2) + dim_2 = paddle.fill_constant([1], "int32", 3) + result_2 = paddle.rand(shape=[dim_1, dim_2, 2]) + # [[[0.8879919 0.25788337] + # [0.28826773 0.9712097 ] + # [0.26438272 0.01796806]] + # [[0.33633623 0.28654453] + # [0.79109055 0.7305809 ] + # [0.870881 0.2984597 ]]] + + # example 3: attr shape is a Variable, the data type must be int64 or int32. + var_shape = paddle.imperative.to_variable(np.array([2, 3])) + result_3 = paddle.rand(var_shape) + # [[0.22920267 0.841956 0.05981819] + # [0.4836288 0.24573246 0.7516129 ]] - # example 3: - # attr shape is a Variable, the data type must be int64 or int32. - var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64") - result_3 = paddle.rand(var_shape) - var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32") - result_4 = paddle.rand(var_shape_int32) """ if dtype is None: dtype = 'float32' - - check_dtype(dtype, 'dtype', ['float32', 'float64'], 'rand') - - check_type(shape, 'shape', (Variable, list, tuple), 'rand') - if isinstance(shape, Variable): - check_variable_and_dtype(shape, 'shape', ['int32', 'int64'], 'rand') - elif isinstance(shape, (list, tuple)): - for i, _shape in enumerate(shape): - if not isinstance(_shape, Variable): - check_type(_shape, '_shape', (int), 'rand') - else: - check_variable_and_dtype(_shape, 'shape[' + str(i) + ']', - ['int32', 'int64'], 'rand') - - if device not in [None, 'cpu', 'gpu']: - raise ValueError( - "The input device should in [None, 'cpu', 'gpu'], but received {}". - format(device)) - - helper = LayerHelper("rand", **locals()) - if out is None: - out = helper.create_variable_for_type_inference(dtype=dtype) - else: - check_variable_and_dtype(out, 'out', [dtype], 'rand') - out.stop_gradient = stop_gradient - - with device_guard(device): - out = uniform_random(shape, dtype, min=0., max=1.0) - return out + return uniform_random(shape, dtype, min=0.0, max=1.0, name=name) From 5144a1976ddd4dace1c64125213dfd3bc3b96b33 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Fri, 3 Jul 2020 09:56:20 +0800 Subject: [PATCH 2/4] test=develop --- python/paddle/fluid/layers/nn.py | 17 ++++++++++------- python/paddle/fluid/layers/tensor.py | 7 +------ python/paddle/fluid/layers/utils.py | 13 +++++++++++++ python/paddle/tensor/random.py | 2 -- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index dc253804a0103..bdb029872087a 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14920,20 +14920,23 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0, result_4 = fluid.layers.uniform_random(var_shape_int32) """ - check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random') if not isinstance(dtype, core.VarDesc.VarType): dtype = convert_np_dtype_to_dtype_(dtype) + + if in_dygraph_mode(): + shape = utils._convert_shape_to_list(shape) + return core.ops.uniform_random('shape', shape, 'min', min, 'max', max, + 'seed', seed, 'dtype', dtype) + + check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random') check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random') - helper = LayerHelper("uniform_random", **locals()) inputs = dict() attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype} - if in_dygraph_mode(): - attrs['shape'] = shape - else: - utils._get_shape_tensor_inputs( - inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random') + utils._get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=shape, op_type='uniform_random') + helper = LayerHelper("uniform_random", **locals()) out = helper.create_variable_for_type_inference(dtype) helper.append_op( type="uniform_random", inputs=inputs, attrs=attrs, diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 10424089f7283..5bae5a8c62abb 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -685,12 +685,7 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None): attrs['str_value'] = str(float(value)) if in_dygraph_mode(): - if isinstance(shape, (list, tuple)): - shape = list( - map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x, - shape)) - else: - shape = list(shape.numpy().astype(int)) + shape = utils._convert_shape_to_list(shape) if out is None: out = _varbase_creator(dtype=dtype) diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index 88d57f4873cf2..0d6965239e14b 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -345,3 +345,16 @@ def _convert_to_tensor_list(old_list, dtype="int32"): temp_out = fill_constant([1], dtype, ele, force_cpu=True) new_list_tensor.append(temp_out) return new_list_tensor + + +def _convert_shape_to_list(shape): + """ + Convert shape(list, tuple, variable) to list in imperative mode + """ + if isinstance(shape, (list, tuple)): + shape = list( + map(lambda x: x.numpy()[0] if isinstance(x, Variable) else x, + shape)) + else: + shape = list(shape.numpy().astype(int)) + return shape diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index e04519a4bd7db..67198ce46a36f 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -447,8 +447,6 @@ def rand(shape, dtype=None, name=None): import paddle import numpy as np - paddle.enable_imperative() - # example 1: attr shape is a list which doesn't contain tensor Variable. result_1 = paddle.rand(shape=[2, 3]) # [[0.451152 , 0.55825245, 0.403311 ], From a911e0ad3d7c7075f4c19761dffae5a35c045e44 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Fri, 3 Jul 2020 17:27:08 +0800 Subject: [PATCH 3/4] test=develop --- python/paddle/fluid/layers/nn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bdb029872087a..46d6c39621779 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -14925,8 +14925,9 @@ def uniform_random(shape, dtype='float32', min=-1.0, max=1.0, seed=0, if in_dygraph_mode(): shape = utils._convert_shape_to_list(shape) - return core.ops.uniform_random('shape', shape, 'min', min, 'max', max, - 'seed', seed, 'dtype', dtype) + return core.ops.uniform_random('shape', shape, 'min', + float(min), 'max', + float(max), 'seed', seed, 'dtype', dtype) check_type(shape, 'shape', (list, tuple, Variable), 'uniform_random') check_dtype(dtype, 'dtype', ('float32', 'float64'), 'uniform_random') From e5863fa9bc240f9714bf87e2d409b07bf0be869a Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Mon, 6 Jul 2020 13:44:18 +0800 Subject: [PATCH 4/4] test=develop --- python/paddle/tensor/random.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 67198ce46a36f..16bbb09b4f309 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -447,6 +447,7 @@ def rand(shape, dtype=None, name=None): import paddle import numpy as np + paddle.enable_imperative() # example 1: attr shape is a list which doesn't contain tensor Variable. result_1 = paddle.rand(shape=[2, 3]) # [[0.451152 , 0.55825245, 0.403311 ],