Skip to content

Commit

Permalink
add PADDLE_ENFORCE_EQ,test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
tink2123 committed Aug 20, 2020
1 parent a2a22f5 commit 0a2aad0
Show file tree
Hide file tree
Showing 5 changed files with 460 additions and 23 deletions.
12 changes: 11 additions & 1 deletion paddle/fluid/operators/interpolate_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ static void Interpolate1DInferShapeCheck(framework::InferShapeContext* ctx) {
out_w = -1;
} else {
float scale_w = ctx->Attrs().Get<float>("scale_w");
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
if (scale_w > 0) {
// round down
out_w = (data_layout == DataLayout::kNCHW
Expand Down Expand Up @@ -157,6 +160,10 @@ static void Interpolate2DInferShapeCheck(framework::InferShapeContext* ctx) {
} else {
float scale_h = ctx->Attrs().Get<float>("scale_h");
float scale_w = ctx->Attrs().Get<float>("scale_w");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
if (scale_h > 0 && scale_w > 0) {
// round down
out_h = (data_layout == DataLayout::kNCHW
Expand Down Expand Up @@ -255,7 +262,10 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) {
float scale_d = ctx->Attrs().Get<float>("scale_d");
float scale_h = ctx->Attrs().Get<float>("scale_h");
float scale_w = ctx->Attrs().Get<float>("scale_w");

PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
if (scale_d > 0 && scale_h > 0 && scale_w > 0) {
// round down
out_d = (data_layout == DataLayout::kNCHW
Expand Down
44 changes: 44 additions & 0 deletions paddle/fluid/operators/interpolate_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -846,8 +846,14 @@ static void Interpolate1DCUDAFwd(const framework::ExecutionContext& ctx,
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
Expand Down Expand Up @@ -933,9 +939,17 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_w = ctx.Attr<float>("scale_w");
scale_h = ctx.Attr<float>("scale_h");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_w > 0 && scale_h > 0) {
out_h = static_cast<int>(in_h * scale_h);
Expand Down Expand Up @@ -1050,10 +1064,18 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx,
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_d = ctx.Attr<float>("scale_d");
scale_h = ctx.Attr<float>("scale_h");
scale_w = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_d > 0 && scale_h > 0 && scale_w > 0) {
out_d = static_cast<int>(in_d * scale_d);
Expand Down Expand Up @@ -1147,8 +1169,14 @@ static void Interpolate1DCUDABwd(const framework::ExecutionContext& ctx,
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
Expand Down Expand Up @@ -1233,9 +1261,17 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx,
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_w = ctx.Attr<float>("scale_w");
scale_h = ctx.Attr<float>("scale_h");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_w > 0 && scale_h > 0) {
out_h = static_cast<int>(in_h * scale_h);
Expand Down Expand Up @@ -1348,10 +1384,18 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx,
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_d = ctx.Attr<float>("scale_d");
scale_h = ctx.Attr<float>("scale_h");
scale_w = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_d > 0 && scale_h > 0 && scale_w > 0) {
out_d = static_cast<int>(in_d * scale_d);
Expand Down
60 changes: 52 additions & 8 deletions paddle/fluid/operators/interpolate_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -792,12 +792,18 @@ static void Interpolate1DCPUFwd(const framework::ExecutionContext& ctx,
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale = ctx.Attr<float>("scale_w");
scale_w = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
if (scale_w > 0) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Expand Down Expand Up @@ -866,9 +872,17 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx,
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_w = ctx.Attr<float>("scale_w");
scale_h = ctx.Attr<float>("scale_h");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_h > 0 && scale_w > 0) {
out_h = static_cast<int>(in_h * scale_h);
Expand Down Expand Up @@ -963,10 +977,18 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx,
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_d = ctx.Attr<float>("scale_d");
scale_h = ctx.Attr<float>("scale_h");
scale_w = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_w > 0 && scale_h > 0 && scale_d > 0) {
out_d = static_cast<int>(in_d * scale_d);
Expand Down Expand Up @@ -1046,12 +1068,18 @@ static void Interpolate1DCPUBwd(const framework::ExecutionContext& ctx,
auto scale_tensor = ctx.Input<Tensor>("Scale");
if (scale_tensor != nullptr) {
auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
scale = scale_data[0];
scale_w = scale_data[0];
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale = ctx.Attr<float>("scale_w");
scale_w = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(scale_w > 0, true, platform::errors::InvalidArgument(
"scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale > 0) {
out_w = static_cast<int>(in_w * scale);
if (scale_w > 0) {
out_w = static_cast<int>(in_w * scale_w);
}
auto out_size = ctx.Input<Tensor>("OutSize");
if (out_size != nullptr) {
Expand Down Expand Up @@ -1120,9 +1148,17 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx,
scale_w = scale_data[0];
scale_h = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_h = ctx.Attr<float>("scale_h");
scale_w = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_h > 0 && scale_w > 0) {
out_h = static_cast<int>(in_h * scale_h);
Expand Down Expand Up @@ -1216,10 +1252,18 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx,
scale_h = scale_data[0];
scale_w = scale_data[0];
}
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
} else {
scale_d = ctx.Attr<float>("scale_d");
scale_h = ctx.Attr<float>("scale_h");
scale_w = ctx.Attr<float>("scale_w");
PADDLE_ENFORCE_EQ(
scale_w > 0 && scale_h > 0 && scale_d > 0, true,
platform::errors::InvalidArgument("scale of Op(interpolate) "
"should be greater than 0."));
}
if (scale_d > 0 && scale_h > 0 && scale_w > 0) {
out_d = static_cast<int>(in_d * scale_d);
Expand Down
Loading

1 comment on commit 0a2aad0

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 Commit ID: 0a2aad0 contains failed CI.

Please sign in to comment.