Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Error Message No.18】 part 3 of paddle/cinn/frontend/op_mappers/* -part #64409

Merged
merged 5 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions paddle/cinn/frontend/op_mappers/paddle/relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ReluOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Relu op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Relu op must be 1."));
auto out_name = op_desc.Output("Out").front();
auto x = ctx.GetVar(x_name);
auto out = ctx.Builder()->Relu(x);
Expand All @@ -34,9 +40,15 @@ void ReluOpMapper(const paddle::cpp::OpDesc& op_desc,

void Relu6OpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Relu6 op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Relu6 op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto threshold = utils::GetAttrOrDefault<float>(op_desc, "threshold", 6.0f);
Expand All @@ -49,11 +61,20 @@ void Relu6OpMapper(const paddle::cpp::OpDesc& op_desc,

void ReluGradOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input(paddle::GradVarName("Out")).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input(paddle::GradVarName("Out")).size(),
1UL,
phi::errors::InvalidArgument("The input of ReluGrad op must be 1."));
auto dout_name = op_desc.Input(paddle::GradVarName("Out")).front();
CHECK_EQ(op_desc.Input("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Out").size(),
1UL,
phi::errors::InvalidArgument("The input of ReluGrad op must be 1."));
auto out_name = op_desc.Input("Out").front();
CHECK_EQ(op_desc.Output(paddle::GradVarName("X")).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output(paddle::GradVarName("X")).size(),
1UL,
phi::errors::InvalidArgument("The output of ReluGrad op must be 1."));
auto dx_name = op_desc.Output(paddle::GradVarName("X")).front();

auto dout = ctx.GetVar(dout_name);
Expand Down
47 changes: 37 additions & 10 deletions paddle/cinn/frontend/op_mappers/paddle/reshape.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Reshape op must be 1."));
auto x_name = op_desc.Input("X").front();
auto x = ctx.GetVar(x_name);

Expand All @@ -33,7 +36,10 @@ void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc,

auto out = ctx.Builder()->Reshape(x, shape);

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Reshape op must be 1."));
auto out_name = op_desc.Output("Out").front();
ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
Expand All @@ -42,13 +48,19 @@ void ReshapeOpMapper(const paddle::cpp::OpDesc& op_desc,
void ReshapeGradOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
auto get_input_var = [&op_desc, &ctx](const std::string& op_name) {
CHECK_EQ(op_desc.Input(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input(op_name).size(),
1UL,
phi::errors::InvalidArgument("The input of ReshapeGrad op must be 1."));
auto var_name = op_desc.Input(op_name).front();
return ctx.GetVar(var_name);
};

auto get_output_name = [&op_desc](const std::string& op_name) {
CHECK_EQ(op_desc.Output(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The output of ReshapeGrad op must be 1."));
return op_desc.Output(op_name).front();
};

Expand All @@ -67,7 +79,10 @@ void ReshapeGradOpMapper(const paddle::cpp::OpDesc& op_desc,

void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Reshape2 op must be 1."));
auto x_name = op_desc.Input("X").front();
auto x = ctx.GetVar(x_name);

Expand All @@ -78,7 +93,10 @@ void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,

auto out = ctx.Builder()->Reshape(x, shape);

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Reshape2 op must be 1."));
auto out_name = op_desc.Output("Out").front();
ctx.AddVar(out_name, out);
ctx.AddVarModelToProgram(out_name, out->id);
Expand All @@ -89,7 +107,10 @@ void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,
// will be used in Reshape_grad, in this way, the framework can reuse
// the memory of X immediately the Reshape2_op is finished.
// Considering compatibility issues, we could not fix Reshape2_op
CHECK_EQ(op_desc.Output("XShape").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("XShape").size(),
1UL,
phi::errors::InvalidArgument("The output of Reshape2 op must be 1."));
auto xshape_name = op_desc.Output("XShape").front();

auto xshape = ctx.Builder()->Identity(x);
Expand All @@ -102,13 +123,19 @@ void Reshape2OpMapper(const paddle::cpp::OpDesc& op_desc,
void Reshape2GradOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
auto get_input_var = [&op_desc, &ctx](const std::string& op_name) {
CHECK_EQ(op_desc.Input(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The input of Reshape2Grad op must be 1."));
auto var_name = op_desc.Input(op_name).front();
return ctx.GetVar(var_name);
};

auto get_output_name = [&op_desc](const std::string& op_name) {
CHECK_EQ(op_desc.Output(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The output of Reshape2Grad op must be 1."));
return op_desc.Output(op_name).front();
};

Expand Down
12 changes: 9 additions & 3 deletions paddle/cinn/frontend/op_mappers/paddle/reverse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ReverseOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Reverse op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Reverse op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto axes = utils::GetAttrOrDefault<std::vector<int>>(
Expand Down
47 changes: 34 additions & 13 deletions paddle/cinn/frontend/op_mappers/paddle/roll.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,35 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/frontend/var_type_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void RollOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
// input
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Roll op must be 1."));
auto x_name = op_desc.Input("X").front();
// output
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Roll op must be 1."));
auto out_name = op_desc.Output("Out").front();

// attr shifts and axis
CHECK(op_desc.HasAttr("shifts"));
CHECK(op_desc.HasAttr("axis"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("shifts"),
true,
phi::errors::InvalidArgument("Roll op must have shifts attribute"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("axis"),
true,
phi::errors::InvalidArgument("Roll op must have axis attribute"));
std::vector<int> shifts = utils::ToShapeType(
utils::GetAttrOrDefault<std::vector<int64_t>>(op_desc, "shifts", {1}));
std::vector<int> axis = utils::ToShapeType(
Expand All @@ -44,8 +56,11 @@ void RollOpMapper(const paddle::cpp::OpDesc& op_desc,
// check axis and shifts and when axis is None, we should flatten x.
bool axis_None = false;
if (axis.size() == 0) {
CHECK_EQ(shifts.size(), 1)
<< "shifts.size() should be equal to 1 when axis is None";
PADDLE_ENFORCE_EQ(
shifts.size(),
1UL,
phi::errors::InvalidArgument(
"shifts.size() should be equal to 1 when axis is None"));
axis.push_back(0);
axis_None = true;
int reshape_num = 1;
Expand All @@ -55,19 +70,25 @@ void RollOpMapper(const paddle::cpp::OpDesc& op_desc,
vec_x_dims = std::vector<int>{reshape_num};
x = ctx.Builder()->Reshape(x, vec_x_dims);
} else {
CHECK_EQ(shifts.size(), axis.size())
<< "shifts.size() should be equal to axis.size()";
PADDLE_ENFORCE_EQ(shifts.size(),
axis.size(),
phi::errors::InvalidArgument(
"shifts.size() should be equal to axis.size()"));
}

// preprocessing the shifts and axis
int shifts_size = shifts.size();
std::unordered_map<int, int> axis_to_shifts;
for (int i = 0; i < shifts_size; ++i) {
int vec_x_dims_size = vec_x_dims.size();
CHECK_GE(axis[i], -vec_x_dims_size)
<< "axis value should be >= " << -vec_x_dims_size;
CHECK_LT(axis[i], vec_x_dims_size)
<< "axis value should be < " << vec_x_dims_size;
PADDLE_ENFORCE_GE(axis[i],
-vec_x_dims_size,
phi::errors::InvalidArgument(
"axis value should be >= -vec_x_dims_size"));
PADDLE_ENFORCE_LT(
axis[i],
vec_x_dims_size,
phi::errors::InvalidArgument("axis value should be < vec_x_dims_size"));
if (axis[i] < 0) {
axis[i] += vec_x_dims_size;
}
Expand Down
17 changes: 13 additions & 4 deletions paddle/cinn/frontend/op_mappers/paddle/scale.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/utils/string.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ScaleOpMapper(const paddle::cpp::OpDesc& op_desc,
const cinn::frontend::OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Scale op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Scale op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto bias = utils::GetAttrOrDefault<float>(op_desc, "bias", 0.0f);
Expand All @@ -38,7 +44,10 @@ void ScaleOpMapper(const paddle::cpp::OpDesc& op_desc,
absl::optional<Variable> out;
if (op_desc.HasInput("ScaleTensor") &&
!op_desc.Input("ScaleTensor").empty()) {
CHECK_EQ(op_desc.Input("ScaleTensor").size(), 1);
PADDLE_ENFORCE_EQ(
op_desc.Input("ScaleTensor").size(),
1UL,
phi::errors::InvalidArgument("The input of ScaleTensor must be 1."));
auto scale_name = op_desc.Input("ScaleTensor").front();
auto scale_tensor = ctx.GetVar(scale_name);

Expand Down
33 changes: 25 additions & 8 deletions paddle/cinn/frontend/op_mappers/paddle/scatter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,32 @@
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ScatterOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Scatter op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Input("Ids").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Ids").size(),
1UL,
phi::errors::InvalidArgument("The input of Scatter op must be 1."));
auto ids_name = op_desc.Input("Ids").front();
CHECK_EQ(op_desc.Input("Updates").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Updates").size(),
1UL,
phi::errors::InvalidArgument("The input of Scatter op must be 1."));
auto updates_name = op_desc.Input("Updates").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Scatter op must be 1."));
auto out_name = op_desc.Output("Out").front();

bool overwrite = utils::GetAttrOrDefault<bool>(op_desc, "overwrite", true);
Expand All @@ -38,16 +50,21 @@ void ScatterOpMapper(const paddle::cpp::OpDesc& op_desc,
const auto& input = ctx.GetVar(x_name);
auto indices = ctx.GetVar(ids_name);
const auto& updates = ctx.GetVar(updates_name);
CHECK(input->type == updates->type)
<< "checks whether the type of the input and the updates are the same.";
PADDLE_ENFORCE_EQ(input->type == updates->type,
true,
phi::errors::InvalidArgument(
"The type of input and updates should be the same."));
CHECK(indices->type == cinn::common::Int(32) ||
indices->type == cinn::common::Int(64))
<< "checks whether the data type of the indices is either int32 or int64";
if (indices->type == cinn::common::Int(64)) {
indices = ctx.Builder()->Cast(
indices, cinn::common::Type2Str(cinn::common::Int(32)));
}
CHECK_LE(indices->shape.size(), 2) << "Ids should be 0, 1 or 2 in scatter_op";
PADDLE_ENFORCE_LE(indices->shape.size(),
2UL,
phi::errors::InvalidArgument(
"The rank of indices should be less than 2."));
if (indices->shape.size() == 0) {
indices = ctx.Builder()->Reshape(indices, {1});
}
Expand Down
Loading