From 8db97275bc34ad4d785c5b04c29b2635799baf0e Mon Sep 17 00:00:00 2001 From: hxzd5568 Date: Sat, 11 May 2024 10:34:45 +0000 Subject: [PATCH] [CINN] update check expr in arange --- paddle/cinn/hlir/op/elementwise.cc | 51 +++++++++++++++++++----------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc index f206fe7af7cc97..f7baef34ede472 100644 --- a/paddle/cinn/hlir/op/elementwise.cc +++ b/paddle/cinn/hlir/op/elementwise.cc @@ -1371,10 +1371,22 @@ std::shared_ptr StrategyForArangeSymbolic( const std::vector> &output_shapes, const Target &target) { auto attr_store = attrs.attr_store; - CHECK(attr_store.count("start")); - CHECK(attr_store.count("stop")); - CHECK(attr_store.count("step")); - CHECK(attr_store.count("dtype")); + PADDLE_ENFORCE_GT(attr_store.count("start"), + 0U, + ::common::errors::InvalidArgument( + "No start attribute in arange Op! Please check.")); + PADDLE_ENFORCE_GT(attr_store.count("stop"), + 0U, + ::common::errors::InvalidArgument( + "No stop attribute in arange Op! Please check.")); + PADDLE_ENFORCE_GT(attr_store.count("start"), + 0U, + ::common::errors::InvalidArgument( + "No end attribute in arange Op! Please check.")); + PADDLE_ENFORCE_GT(attr_store.count("start"), + 0U, + ::common::errors::InvalidArgument( + "No dtype attribute in arange Op! Please check.")); auto start = absl::get(attr_store.at("start")); auto stop = absl::get(attr_store.at("stop")); @@ -1382,22 +1394,25 @@ std::shared_ptr StrategyForArangeSymbolic( auto dtype = cinn::common::Str2Type(absl::get(attr_store.at("dtype"))); - framework::CINNCompute arange_compute( - [=](lang::Args args, lang::RetValue *ret) { - CHECK(!args.empty()) - << "The input argument of arange compute is empty! Please check.\n"; - CINNValuePack pack_args = args[0]; + framework::CINNCompute arange_compute([=](lang::Args args, + lang::RetValue *ret) { + PADDLE_ENFORCE_EQ( + !args.empty(), + true, + ::common::errors::InvalidArgument( + "The input argument of arange compute is empty! Please check.")); + CINNValuePack pack_args = args[0]; - CHECK_EQ(pack_args.size(), 1U); - std::string tensor_name = pack_args[0].operator std::string(); + CHECK_EQ(pack_args.size(), 1U); + std::string tensor_name = pack_args[0].operator std::string(); - auto out = pe::Arange(start, stop, step, dtype, tensor_name); - std::vector res; - auto stages = CreateStages({out}); - res.push_back(cinn::common::CINNValue(out)); - res.push_back(cinn::common::CINNValue(stages)); - *ret = CINNValuePack{res}; - }); + auto out = pe::Arange(start, stop, step, dtype, tensor_name); + std::vector res; + auto stages = CreateStages({out}); + res.push_back(cinn::common::CINNValue(out)); + res.push_back(cinn::common::CINNValue(stages)); + *ret = CINNValuePack{res}; + }); auto strategy = std::make_shared(); strategy->AddImpl(