Skip to content

Commit

Permalink
[CINN] update check expr in arange
Browse files Browse the repository at this point in the history
  • Loading branch information
hxzd5568 committed May 11, 2024
1 parent 6354af2 commit 8db9727
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions paddle/cinn/hlir/op/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1371,33 +1371,48 @@ std::shared_ptr<framework::OpStrategy> StrategyForArangeSymbolic(
const std::vector<std::vector<ir::Dim>> &output_shapes,
const Target &target) {
auto attr_store = attrs.attr_store;
CHECK(attr_store.count("start"));
CHECK(attr_store.count("stop"));
CHECK(attr_store.count("step"));
CHECK(attr_store.count("dtype"));
PADDLE_ENFORCE_GT(attr_store.count("start"),
0U,
::common::errors::InvalidArgument(
"No start attribute in arange Op! Please check."));
PADDLE_ENFORCE_GT(attr_store.count("stop"),
0U,
::common::errors::InvalidArgument(
"No stop attribute in arange Op! Please check."));
PADDLE_ENFORCE_GT(attr_store.count("start"),
0U,
::common::errors::InvalidArgument(
"No end attribute in arange Op! Please check."));
PADDLE_ENFORCE_GT(attr_store.count("start"),
0U,
::common::errors::InvalidArgument(
"No dtype attribute in arange Op! Please check."));

auto start = absl::get<float>(attr_store.at("start"));
auto stop = absl::get<float>(attr_store.at("stop"));
auto step = absl::get<float>(attr_store.at("step"));
auto dtype =
cinn::common::Str2Type(absl::get<std::string>(attr_store.at("dtype")));

framework::CINNCompute arange_compute(
[=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty())
<< "The input argument of arange compute is empty! Please check.\n";
CINNValuePack pack_args = args[0];
framework::CINNCompute arange_compute([=](lang::Args args,
lang::RetValue *ret) {
PADDLE_ENFORCE_EQ(
!args.empty(),
true,
::common::errors::InvalidArgument(
"The input argument of arange compute is empty! Please check."));
CINNValuePack pack_args = args[0];

CHECK_EQ(pack_args.size(), 1U);
std::string tensor_name = pack_args[0].operator std::string();
CHECK_EQ(pack_args.size(), 1U);
std::string tensor_name = pack_args[0].operator std::string();

auto out = pe::Arange(start, stop, step, dtype, tensor_name);
std::vector<cinn::common::CINNValue> res;
auto stages = CreateStages({out});
res.push_back(cinn::common::CINNValue(out));
res.push_back(cinn::common::CINNValue(stages));
*ret = CINNValuePack{res};
});
auto out = pe::Arange(start, stop, step, dtype, tensor_name);
std::vector<cinn::common::CINNValue> res;
auto stages = CreateStages({out});
res.push_back(cinn::common::CINNValue(out));
res.push_back(cinn::common::CINNValue(stages));
*ret = CINNValuePack{res};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(
Expand Down

0 comments on commit 8db9727

Please sign in to comment.