diff --git a/paddle/fluid/operators/channel_shuffle_op.cc b/paddle/fluid/operators/channel_shuffle_op.cc index b14e30c1209fb..74b2e04e63f70 100644 --- a/paddle/fluid/operators/channel_shuffle_op.cc +++ b/paddle/fluid/operators/channel_shuffle_op.cc @@ -1,21 +1,22 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/unary.h" namespace paddle { @@ -62,25 +63,6 @@ class ChannelShuffleOpMaker : public framework::OpProtoAndCheckerMaker { class ChannelShuffleGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput(framework::GradVarName("Out")), true, - platform::errors::NotFound("Input(Out@Grad) should not be null")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput(framework::GradVarName("X")), true, - platform::errors::NotFound("Output(X@Grad) should not be null")); - - auto do_dims = ctx->GetInputDim(framework::GradVarName("Out")); - PADDLE_ENFORCE_EQ(do_dims.size(), 4, - platform::errors::InvalidArgument( - "Input should be a 4-D tensor of format [N, C, " - "H, W] or [N, H, W, C], but got %u.", - do_dims.size())); - - auto dx_dims = do_dims; - ctx->SetOutputDim(framework::GradVarName("X"), dx_dims); - } }; template @@ -110,4 +92,9 @@ REGISTER_OPERATOR(channel_shuffle, ops::ChannelShuffleOp, ops::ChannelShuffleGradOpMaker, ChannelShuffleInferShapeFunctor); -REGISTER_OPERATOR(channel_shuffle_grad, ops::ChannelShuffleGradOp); +DECLARE_INFER_SHAPE_FUNCTOR(channel_shuffle_grad, + ChannelShuffleGradInferShapeFunctor, + PD_INFER_META(phi::ChannelShuffleGradInferMeta)); + +REGISTER_OPERATOR(channel_shuffle_grad, ops::ChannelShuffleGradOp, + ChannelShuffleGradInferShapeFunctor);