[Yaml] add yaml for Uniform random and add unit test. (#41517)
* gather op

* add mod

* [Yaml] final state for uniform and uniform_random
2742195759 committed Apr 11, 2022
1 parent 9107dc6 commit cd2a4cd
Showing 6 changed files with 64 additions and 70 deletions.
76 changes: 7 additions & 69 deletions paddle/fluid/operators/uniform_random_op.cc
@@ -16,9 +16,11 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/phi/infermeta/nullary.h"

namespace paddle {
namespace operators {
@@ -122,74 +124,6 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "UniformRandomOp");

PADDLE_ENFORCE_LT(
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
platform::errors::InvalidArgument(
"The uniform_random's min must less then max. But received min = "
"%f great than or equal max = %f.",
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
platform::errors::InvalidArgument(
"The uniform_random's diag_num must greater than or "
"equal 0. But recevied diag_num (%d) < 0.",
ctx->Attrs().Get<int>("diag_num")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
platform::errors::InvalidArgument(
"The uniform_random's diag_step must greater than or "
"equal 0. But recevied diag_step (%d) < 0.",
ctx->Attrs().Get<int>("diag_step")));

if (ctx->HasInputs("ShapeTensorList")) {
// top priority shape source
auto inputs_name = ctx->Inputs("ShapeTensorList");
PADDLE_ENFORCE_GT(inputs_name.size(), 0,
platform::errors::InvalidArgument(
"Input(ShapeTensorList)'size of "
"Op(uniform_random) can't be zero."
"Please check the Attr(shape)'s size of"
"Op(fluid.layers.uniform_random).)"));
auto out_dims = std::vector<int>(inputs_name.size(), -1);
ctx->SetOutputDim("Out", phi::make_ddim(out_dims));

return;
}
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
if (ctx->HasInput("ShapeTensor") && shape.empty()) {
auto shape_dims = ctx->GetInputDim("ShapeTensor");
PADDLE_ENFORCE_EQ(
shape_dims.size(), 1,
platform::errors::InvalidArgument(
"ShapeError: Input(ShapeTensor)' dimension size of "
"Op(uniform_random) must be 1."
"But received ShapeTensor's dimensions = %d, shape = [%s]",
shape_dims.size(), shape_dims));
int num_ele = 1;
for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i];
}
auto vec_dims = std::vector<int64_t>(num_ele, -1);
auto out_dims = phi::make_ddim(vec_dims);
ctx->SetOutputDim("Out", out_dims);
return;
}

PADDLE_ENFORCE_EQ(shape.empty(), false,
platform::errors::InvalidArgument(
"if there is no Input(ShapeTensorList) and no "
"Input(ShapeTensor),the "
"attr(shape) information must "
"be set by Attr(shape)."));
std::vector<int64_t> tensor_shape;
tensor_shape.reserve(shape.size());
for (auto dim : shape) {
tensor_shape.push_back(static_cast<int64_t>(dim));
}
ctx->SetOutputDim("Out", phi::make_ddim(tensor_shape));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -274,12 +208,16 @@ class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
} // namespace operators
} // namespace paddle

DECLARE_INFER_SHAPE_FUNCTOR(uniform_random, UniformRandomInferShapeFunctor,
PD_INFER_META(phi::UniformRandomInferMeta));

REGISTER_OPERATOR(
uniform_random, paddle::operators::UniformRandomOp,
paddle::operators::UniformRandomOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::UniformRandomOpVarTypeInference);
paddle::operators::UniformRandomOpVarTypeInference,
UniformRandomInferShapeFunctor);

REGISTER_OP_CPU_KERNEL(
uniform_random_batch_size_like,
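The InferShape method deleted above resolved the output shape with a three-way priority: Input(ShapeTensorList) first, then Input(ShapeTensor), then Attr(shape). A minimal Python sketch of that order (a hypothetical helper, not part of this commit):

import numpy as np

def resolve_output_dims(shape_tensor_list=None, shape_tensor=None, shape_attr=None):
    # 1) ShapeTensorList wins: the rank equals the list length, but each
    #    dim is unknown until runtime, so it is marked -1.
    if shape_tensor_list:
        return [-1] * len(shape_tensor_list)
    # 2) ShapeTensor next (a 1-D tensor of dims): only its element count
    #    is known statically.
    if shape_tensor is not None and not shape_attr:
        return [-1] * int(np.asarray(shape_tensor).size)
    # 3) Otherwise the static Attr(shape) must be non-empty.
    if not shape_attr:
        raise ValueError("shape must come from ShapeTensorList, ShapeTensor, "
                         "or Attr(shape)")
    return [int(d) for d in shape_attr]

print(resolve_output_dims(shape_attr=[1000, 784]))  # [1000, 784]

The commit moves this shape resolution out of the operator and into the PHI infer-meta layer, registered through UniformRandomInferShapeFunctor above.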
12 changes: 12 additions & 0 deletions paddle/phi/infermeta/nullary.cc
@@ -63,6 +63,18 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) {
out->set_dtype(dtype);
}

void UniformRandomInferMeta(const IntArray& shape,
DataType dtype,
float min,
float max,
int seed,
MetaTensor* out) {
auto out_dims = phi::make_ddim(shape.GetData());
out->set_dims(out_dims);
out->set_dtype(dtype);
out->set_layout(DataLayout::NCHW);
}

void RandintInferMeta(
int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out) {
PADDLE_ENFORCE_NOT_NULL(
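As a user-level sketch of the contract the UniformRandomInferMeta added above encodes (assuming the public paddle.uniform API; note that min, max, and seed do not affect the metadata, only the sampled values):

import paddle

out = paddle.uniform([2, 3], dtype='float32', min=-1.0, max=1.0, seed=0)
assert out.shape == [2, 3]           # dims taken from `shape`
assert out.dtype == paddle.float32   # dtype taken from `dtype`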
7 changes: 7 additions & 0 deletions paddle/phi/infermeta/nullary.h
@@ -65,4 +65,11 @@ void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
DataType dtype,
MetaTensor* out);

void UniformRandomInferMeta(const IntArray& shape,
DataType dtype,
float min,
float max,
int seed,
MetaTensor* out);

} // namespace phi
18 changes: 18 additions & 0 deletions python/paddle/fluid/tests/unittests/test_uniform_random_op.py
@@ -26,6 +26,7 @@
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.framework import _test_eager_guard


def output_hist(out):
@@ -52,6 +53,7 @@ def output_hist_diag(out):
class TestUniformRandomOp_attr_tensorlist(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.python_api = paddle.uniform
self.new_shape = (1000, 784)
shape_tensor = []
for index, ele in enumerate(self.new_shape):
@@ -84,6 +86,7 @@ def init_attrs(self):
class TestUniformRandomOp_attr_tensorlist_int32(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.python_api = paddle.uniform
self.new_shape = (1000, 784)
shape_tensor = []
for index, ele in enumerate(self.new_shape):
@@ -110,6 +113,7 @@ def verify_output(self, outs):
class TestUniformRandomOp_attr_tensor(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.python_api = paddle.uniform
self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int64")}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -131,6 +135,7 @@ def verify_output(self, outs):
class TestUniformRandomOp_attr_tensor_int32(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.python_api = paddle.uniform
self.inputs = {"ShapeTensor": np.array([1000, 784]).astype("int32")}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -152,6 +157,7 @@ def verify_output(self, outs):
class TestUniformRandomOp(OpTest):
def setUp(self):
self.op_type = "uniform_random"
self.python_api = paddle.uniform
self.inputs = {}
self.init_attrs()
self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
@@ -174,6 +180,18 @@ def verify_output(self, outs):
np.allclose(
hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))

def test_check_api(self):
places = self._get_places()
for place in places:
with fluid.dygraph.base.guard(place=place):
out = self.python_api(self.attrs['shape'], 'float32',
self.attrs['min'], self.attrs['max'],
self.attrs['seed'])

def test_check_api_eager(self):
with _test_eager_guard():
self.test_check_api()


class TestUniformRandomOpError(unittest.TestCase):
def test_errors(self):
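The output_hist helper that verify_output relies on is collapsed out of this diff. A plausible sketch of the uniformity check it supplies, assuming ten histogram buckets over the test's sampling range (the (-5, 10) span is an assumption, not taken from this diff):

import numpy as np

def output_hist(out):
    # Bin the samples and normalize; a genuinely uniform sample should put
    # roughly 10% of the mass in each of the ten buckets, which the test
    # then compares against prob with atol=0.01.
    hist, _ = np.histogram(out, range=(-5, 10), bins=10)
    hist = hist.astype('float32') / float(out.size)
    prob = 0.1 * np.ones(10)
    return hist, prob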
9 changes: 8 additions & 1 deletion python/paddle/tensor/random.py
@@ -548,7 +548,14 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)

if paddle.in_dynamic_mode():
if in_dygraph_mode():
shape = utils.convert_shape_to_list(shape)
return _C_ops.final_state_uniform_random(shape, dtype,
float(min),
float(max), seed,
_current_expected_place())

if _in_legacy_dygraph():
shape = utils.convert_shape_to_list(shape)
return _C_ops.uniform_random('shape', shape, 'min',
float(min), 'max',
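With this change, uniform dispatches on the runtime mode: the new in_dygraph_mode() branch calls the generated final-state op, while _in_legacy_dygraph() keeps the old attribute-style _C_ops.uniform_random call. A short usage sketch (eager mode assumed):

import paddle

paddle.disable_static()  # eager dygraph takes the final-state branch
x = paddle.uniform(shape=[1000, 784], dtype='float32', min=-1.0, max=1.0)
print(x.shape)  # [1000, 784]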
12 changes: 12 additions & 0 deletions python/paddle/utils/code_gen/api.yaml
@@ -2035,6 +2035,18 @@
func : unfold
backward : unfold_grad

- api : uniform_random
args : (IntArray shape, DataType dtype, float min, float max, int seed, Place place={})
output : Tensor(out)
infer_meta :
func : UniformRandomInferMeta
param: [shape, dtype, min, max, seed]
kernel :
func : uniform_random
param: [shape, dtype, min, max, seed]
data_type : dtype
backend : place

# The `axis` argument of Python API paddle.unique is not vector
- api : unique
args : (Tensor x, bool return_index, bool return_inverse, bool return_counts, int[] axis, DataType dtype=DataType::INT64)
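Reading the uniform_random entry above: args defines the generated C++ API signature, the two param lists restrict which arguments reach the infer-meta and the kernel (place is excluded from both), data_type : dtype picks the kernel's dtype from the dtype argument, and backend : place routes execution to the device carried by place. The eager binding generated from this entry is what the new branch in random.py calls; a hedged sketch of that call (import locations and the accepted dtype form are assumptions):

import paddle
from paddle import _C_ops
from paddle.fluid.framework import _current_expected_place  # assumed location

x = _C_ops.final_state_uniform_random(
    [1000, 784],                  # shape (IntArray)
    paddle.float32,               # dtype (DataType); assumed accepted form
    -1.0,                         # min
    1.0,                          # max
    0,                            # seed
    _current_expected_place())    # place -> selects the kernel backend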
