Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen] Unify InferMeta(Shape) Function in pten and fluid op #38976

Merged
merged 32 commits into from Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ffd6624
infermeta context init design
chenwhql Jan 13, 2022
651c00c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Jan 13, 2022
ac239ac
support infermeta called in fluid op
chenwhql Jan 15, 2022
5212fdd
add hasattr and attr methods
chenwhql Jan 19, 2022
f753d59
add dygraph GetVarPtrs support
chenwhql Jan 19, 2022
dcfa257
rename arg_map_context to arg_map_utils
chenwhql Jan 19, 2022
2b53e60
add registry for arg map func
chenwhql Jan 20, 2022
d3749af
Merge branch 'develop' into pten/upgrade_infermeta_design
chenwhql Jan 20, 2022
4f2bc42
resolve conflit
chenwhql Jan 20, 2022
7b64985
refactor op utils design
chenwhql Jan 21, 2022
b3f9bf9
Merge branch 'develop' into pten/upgrade_infermeta_design
chenwhql Jan 21, 2022
e0f4bed
polish meta config
chenwhql Jan 21, 2022
5b653b5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Jan 21, 2022
c9332c4
resolve conflit
chenwhql Jan 22, 2022
f8205c4
resolve conflit
chenwhql Jan 22, 2022
e9c98db
fix details
chenwhql Jan 22, 2022
974abd4
resolve conflit
chenwhql Jan 22, 2022
516fbe2
remove hasattr method
chenwhql Jan 22, 2022
fe90b26
resolve conflit
chenwhql Jan 22, 2022
c590ca1
resolve conflit
chenwhql Jan 22, 2022
1a18ee9
revert cmake order change
chenwhql Jan 22, 2022
9d88587
revert some change
chenwhql Jan 22, 2022
de66d00
resolve conflit
chenwhql Jan 23, 2022
adfbb98
change init pos
chenwhql Jan 23, 2022
cc0c3c8
fix compile failed
chenwhql Jan 23, 2022
610d612
fix typo
chenwhql Jan 23, 2022
a5a028a
fix inference failed
chenwhql Jan 23, 2022
c8dd7e6
fix windows compile failed
wanghuancoder Jan 24, 2022
e49964c
resolve conflit
chenwhql Jan 24, 2022
7bf344c
polish format
chenwhql Jan 24, 2022
2bcf72d
Merge branch 'develop' into pten/upgrade_infermeta_design
chenwhql Jan 25, 2022
8ef7ad6
resolve conflit
chenwhql Jan 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Expand Up @@ -199,11 +199,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory)
pten pten_utils kernel_factory infershape_utils)
ELSE()
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory)
pten pten_utils kernel_factory infershape_utils)
ENDIF()

cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
Expand Down Expand Up @@ -414,6 +414,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc DEPS enforce place)

cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten pten_api_utils op_info shape_inference)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_api_utils op_info)

# Get the current working branch
Expand Down
4 changes: 0 additions & 4 deletions paddle/fluid/framework/details/op_registry.h
Expand Up @@ -275,10 +275,6 @@ struct OpInfoFiller<T, kVarTypeInference> {
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->infer_shape_, nullptr,
platform::errors::AlreadyExists(
"Duplicate InferShapeFN of %s has been registered", op_type));
info->infer_shape_ = [](InferShapeContext* ctx) {
T inference;
inference(ctx);
Expand Down
225 changes: 225 additions & 0 deletions paddle/fluid/framework/infershape_utils.cc
@@ -0,0 +1,225 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/infershape_utils.h"

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/core/op_utils.h"

namespace paddle {
namespace framework {

// Adapts a fluid InferShapeContext to the pten::ArgumentMappingContext
// interface, letting argument-mapping functions query inputs/outputs
// during shape inference. Attribute lookup is not wired up yet.
class InferShapeArgumentMappingContext : public pten::ArgumentMappingContext {
 public:
  explicit InferShapeArgumentMappingContext(const InferShapeContext& ctx)
      : ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    return ctx_.HasInput(name);
  }

  bool HasOutput(const std::string& name) const override {
    return ctx_.HasOutput(name);
  }

  // TODO(chenweihang): impl this method later
  bool HasAttr(const std::string& name) const override { return false; }

  size_t InputSize(const std::string& name) const override {
    return ctx_.Inputs(name).size();
  }

  size_t OutputSize(const std::string& name) const override {
    return ctx_.Outputs(name).size();
  }

  // Only the first variable's type is inspected; multi-variable inputs are
  // assumed to be homogeneous.
  bool IsDenseTensorInput(const std::string& name) const override {
    return ctx_.GetInputsVarType(name).front() == proto::VarType::LOD_TENSOR;
  }

  bool IsSelectedRowsInput(const std::string& name) const override {
    return ctx_.GetInputsVarType(name).front() == proto::VarType::SELECTED_ROWS;
  }

 private:
  const InferShapeContext& ctx_;
};

// TODO(chenweihang): Support SelectedRows later
// TODO(chenweihang): Support TensorArray later
// Adapter exposing a fluid variable through the pten::MetaTensor interface.
// The wrapped InferShapeVarPtr holds either a runtime Variable* or a
// compile-time VarDesc*; every accessor branches on is_runtime_ to pick
// the right representation.
class CompatMetaTensor : public pten::MetaTensor {
 public:
  // var: Variable* when is_runtime is true, VarDesc* otherwise.
  CompatMetaTensor(InferShapeVarPtr var, bool is_runtime)
      : var_(std::move(var)), is_runtime_(is_runtime) {}

  CompatMetaTensor() = default;
  CompatMetaTensor(const CompatMetaTensor&) = default;
  CompatMetaTensor(CompatMetaTensor&&) = default;
  // Assignment is disabled: an instance stays bound to one variable.
  CompatMetaTensor& operator=(const CompatMetaTensor&) = delete;
  CompatMetaTensor& operator=(CompatMetaTensor&&) = delete;

  int64_t numel() const override {
    if (is_runtime_) {
      auto* var = BOOST_GET_CONST(Variable*, var_);
      return var->Get<Tensor>().numel();
    } else {
      auto* var = BOOST_GET_CONST(VarDesc*, var_);
      // NOTE(review): ElementSize() is used as the element count here —
      // confirm VarDesc::ElementSize() returns numel, not per-element size.
      return var->ElementSize();
    }
  }

  DDim dims() const override {
    if (is_runtime_) {
      auto* var = BOOST_GET_CONST(Variable*, var_);
      return var->Get<LoDTensor>().dims();
    } else {
      auto* var = BOOST_GET_CONST(VarDesc*, var_);
      return make_ddim(var->GetShape());
    }
  }

  pten::DataType dtype() const override {
    if (is_runtime_) {
      auto* var = BOOST_GET_CONST(Variable*, var_);
      return var->Get<LoDTensor>().dtype();
    } else {
      auto* var = BOOST_GET_CONST(VarDesc*, var_);
      // VarDesc stores a proto type; convert it to the pten enum.
      return pten::TransToPtenDataType(var->GetDataType());
    }
  }

  DataLayout layout() const override {
    if (is_runtime_) {
      auto* var = BOOST_GET_CONST(Variable*, var_);
      return var->Get<LoDTensor>().layout();
    } else {
      // VarDesc has no layout field, so compile-time queries are rejected.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported get layout for VarDesc now."));
    }
  }

  void set_dims(const DDim& dims) override {
    if (is_runtime_) {
      auto* var = BOOST_GET(Variable*, var_);
      LoDTensor* tensor = var->GetMutable<LoDTensor>();
      // Write through to the underlying DenseTensor's meta directly.
      pten::CompatibleDenseTensorUtils::GetMutableMeta(
          static_cast<pten::DenseTensor*>(tensor))
          ->dims = dims;
    } else {
      auto* var = BOOST_GET(VarDesc*, var_);
      var->SetShape(vectorize(dims));
    }
  }

  void set_dtype(pten::DataType dtype) override {
    if (is_runtime_) {
      auto* var = BOOST_GET(Variable*, var_);
      LoDTensor* tensor = var->GetMutable<LoDTensor>();
      pten::CompatibleDenseTensorUtils::GetMutableMeta(
          static_cast<pten::DenseTensor*>(tensor))
          ->dtype = dtype;
    } else {
      auto* var = BOOST_GET(VarDesc*, var_);
      var->SetDataType(pten::TransToProtoVarType(dtype));
    }
  }

  void set_layout(DataLayout layout) override {
    if (is_runtime_) {
      auto* var = BOOST_GET(Variable*, var_);
      LoDTensor* tensor = var->GetMutable<LoDTensor>();
      pten::CompatibleDenseTensorUtils::GetMutableMeta(
          static_cast<pten::DenseTensor*>(tensor))
          ->layout = layout;
    } else {
      // Symmetric with layout(): no layout on VarDesc.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported set layout for VarDesc now."));
    }
  }

  // Copies LoD information from another CompatMetaTensor: the full runtime
  // LoD at runtime, or just the LoD level at compile time.
  void share_lod(const MetaTensor& meta_tensor) override {
    if (is_runtime_) {
      auto* var = BOOST_GET(Variable*, var_);
      LoDTensor* tensor = var->GetMutable<LoDTensor>();
      pten::CompatibleDenseTensorUtils::GetMutableMeta(
          static_cast<pten::DenseTensor*>(tensor))
          ->lod =
          static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
    } else {
      auto* var = BOOST_GET(VarDesc*, var_);
      var->SetLoDLevel(static_cast<const CompatMetaTensor&>(meta_tensor)
                           .GetCompileTimeLoD());
    }
  }

 private:
  // Valid only when is_runtime_ is true.
  const LoD& GetRuntimeLoD() const {
    auto* var = BOOST_GET_CONST(Variable*, var_);
    return var->Get<LoDTensor>().lod();
  }
  // Valid only when is_runtime_ is false.
  int32_t GetCompileTimeLoD() const {
    auto* var = BOOST_GET_CONST(VarDesc*, var_);
    return var->GetLoDLevel();
  }

  InferShapeVarPtr var_;
  bool is_runtime_;
};

// Builds a pten::InferMetaContext for `op_type` from a fluid
// InferShapeContext so pten InferMeta functions can serve fluid ops:
//   1. looks up the op's registered ArgumentMappingFn (enforced non-null)
//      and derives the kernel signature from the shape-inference context;
//   2. wraps the first variable of each signature input/output in a
//      CompatMetaTensor and emplaces it into the context.
// Attributes and multi-variable inputs/outputs are not supported yet.
// Fix: removed a stray, unused duplicate local
// `pten::InferMetaContext infer_mete_context;` (typo'd name) left from an
// earlier revision.
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type) {
  // 1. get kernel args
  auto arg_map_fn = pten::OpUtils::Instance().GetArgumentMappingFn(op_type);
  PADDLE_ENFORCE_NOT_NULL(
      arg_map_fn, platform::errors::NotFound(
                      "The ArgumentMappingFn of %s op is not found.", op_type));
  InferShapeArgumentMappingContext arg_map_context(*ctx);
  auto signature = arg_map_fn(arg_map_context);
  VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

  // 2. build infermeta context
  pten::InferMetaContext infer_meta_context;

  auto& input_names = std::get<0>(signature.args);
  auto& output_names = std::get<2>(signature.args);
  // auto& attr_names = std::get<1>(signature.args);

  // TODO(chenweihang): support multiple inputs and outputs
  for (auto& in_name : input_names) {
    infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
        ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
  }
  for (auto& out_name : output_names) {
    infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
        ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
  }
  // TODO(chenweihang): support attrs later

  return infer_meta_context;
}

} // namespace framework
} // namespace paddle
44 changes: 44 additions & 0 deletions paddle/fluid/framework/infershape_utils.h
@@ -0,0 +1,44 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/pten/core/op_utils.h"

// Forward declaration so includers of this header do not need the full
// pten infermeta definitions.
namespace pten {
class InferMetaContext;
}  // namespace pten

namespace paddle {
namespace framework {

// Translates a fluid InferShapeContext for `op_type` into a
// pten::InferMetaContext (inputs/outputs wrapped as meta tensors) so a
// pten InferMeta function can be invoked from a fluid op.
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type);

// Declares an InferShapeBase functor named `functor_name` that bridges a
// fluid InferShape call to the pten InferMeta function `fn`. Note that
// `op_type` is stringized (#op_type), so it is passed unquoted.
// NOTE(review): macro name is misspelled ("DELCARE" -> "DECLARE"); kept
// as-is because registration sites reference this exact spelling.
#define DELCARE_INFER_SHAPE_FUNCTOR(op_type, functor_name, fn) \
  struct functor_name : public paddle::framework::InferShapeBase { \
    void operator()( \
        paddle::framework::InferShapeContext* ctx) const override { \
      auto infer_meta_context = \
          paddle::framework::BuildInferMetaContext(ctx, #op_type); \
      fn(&infer_meta_context); \
    } \
  }

}  // namespace framework
}  // namespace paddle
10 changes: 9 additions & 1 deletion paddle/fluid/framework/operator.cc
Expand Up @@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/op_utils.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -1097,6 +1098,13 @@ bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
}

// Default InferShape for OperatorWithKernel. It unconditionally throws
// PermissionDenied: per the error message, concrete operators are expected
// to override InferShape, and hitting this body means none was provided.
// (Previously this was pure virtual; it now has a throwing default body.)
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
  PADDLE_THROW(platform::errors::PermissionDenied(
      "The default InferShape function of OperatorWithKernel is not allowed to "
      "be called, please override corresponding InferShape function in the "
      "specific operator."));
}

void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {
Expand Down Expand Up @@ -1796,7 +1804,7 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(

KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const {
return KernelSignatureMap::Instance().Get(
return pten::KernelSignatureMap::Instance().Get(
pten::TransToPtenKernelName(Type()));
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.h
Expand Up @@ -555,7 +555,7 @@ class OperatorWithKernel : public OperatorBase {
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;

virtual void InferShape(InferShapeContext* ctx) const = 0;
virtual void InferShape(InferShapeContext* ctx) const;

void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
Expand Down