Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTen] Unify InferMeta(Shape) Function in pten and fluid op #38976

Merged
merged 32 commits into from Jan 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ffd6624
infermeta context init design
chenwhql Jan 13, 2022
651c00c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Jan 13, 2022
ac239ac
support infermeta called in fluid op
chenwhql Jan 15, 2022
5212fdd
add hasattr and attr methods
chenwhql Jan 19, 2022
f753d59
add dygraph GetVarPtrs support
chenwhql Jan 19, 2022
dcfa257
rename arg_map_context to arg_map_utils
chenwhql Jan 19, 2022
2b53e60
add registry for arg map func
chenwhql Jan 20, 2022
d3749af
Merge branch 'develop' into pten/upgrade_infermeta_design
chenwhql Jan 20, 2022
4f2bc42
resolve conflict
chenwhql Jan 20, 2022
7b64985
refactor op utils design
chenwhql Jan 21, 2022
b3f9bf9
Merge branch 'develop' into pten/upgrade_infermeta_design
chenwhql Jan 21, 2022
e0f4bed
polish meta config
chenwhql Jan 21, 2022
5b653b5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
chenwhql Jan 21, 2022
c9332c4
resolve conflict
chenwhql Jan 22, 2022
f8205c4
resolve conflict
chenwhql Jan 22, 2022
e9c98db
fix details
chenwhql Jan 22, 2022
974abd4
resolve conflict
chenwhql Jan 22, 2022
516fbe2
remove hasattr method
chenwhql Jan 22, 2022
fe90b26
resolve conflict
chenwhql Jan 22, 2022
c590ca1
resolve conflict
chenwhql Jan 22, 2022
1a18ee9
revert cmake order change
chenwhql Jan 22, 2022
9d88587
revert some change
chenwhql Jan 22, 2022
de66d00
resolve conflict
chenwhql Jan 23, 2022
adfbb98
change init pos
chenwhql Jan 23, 2022
cc0c3c8
fix compile failed
chenwhql Jan 23, 2022
610d612
fix typo
chenwhql Jan 23, 2022
a5a028a
fix inference failed
chenwhql Jan 23, 2022
c8dd7e6
fix windows compile failed
wanghuancoder Jan 24, 2022
e49964c
resolve conflict
chenwhql Jan 24, 2022
7bf344c
polish format
chenwhql Jan 24, 2022
2bcf72d
Merge branch 'develop' into pten/upgrade_infermeta_design
chenwhql Jan 25, 2022
8ef7ad6
resolve conflict
chenwhql Jan 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions cmake/pten.cmake
Expand Up @@ -243,3 +243,29 @@ function(register_kernels)
endif()
endforeach()
endfunction()

# Read a *_sig.cc source, find its op-util registration macro invocation
# (PT_REGISTER_API_NAME or PT_REGISTER_ARG_MAPPING_FN), rewrite it as the
# matching forward declaration, and append it to the generated header named
# by ${op_utils_header} (expected to be set in the including scope).
# TARGET is a path relative to ${CMAKE_CURRENT_SOURCE_DIR}.
function(append_op_util_declare TARGET)
  file(READ "${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}" target_content)
  # Match "<MACRO>( <lowercase_op_name>" — the open paren and any leading
  # whitespace before the op name are part of the captured text.
  string(REGEX MATCH "(PT_REGISTER_API_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
  string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
  # Fixed: the original additionally ran
  #   string(REPLACE "PT_REGISTER_API_NAME" "PT_REGISTER_API_NAME" ...)
  # which replaces the token with itself — a no-op — so API-name matches were
  # emitted unchanged anyway. The dead statement is dropped here.
  # NOTE(review): if a PT_DECLARE_API_NAME macro exists, replacing with it was
  # probably the original intent — TODO confirm against op_utils.h.
  string(APPEND util_declare ");")
  file(APPEND "${op_utils_header}" "${util_declare}")
endfunction()

# Gather every *_sig.cc argument-mapping source in the current directory,
# emit its forward declaration into the generated op-utils header (via
# append_op_util_declare), and bundle the sources into a cc_library called
# TARGET_NAME. Keyword arguments: DEPS (extra library dependencies) and
# EXCLUDES (accepted but currently unused).
function(register_op_utils TARGET_NAME)
  set(options "")
  set(oneValueArgs "")
  set(multiValueArgs EXCLUDES DEPS)
  cmake_parse_arguments(register_op_utils "${options}" "${oneValueArgs}"
                        "${multiValueArgs}" ${ARGN})

  set(utils_srcs)
  file(GLOB SIGNATURES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_sig.cc")
  foreach(sig_file IN LISTS SIGNATURES)
    append_op_util_declare("${sig_file}")
    list(APPEND utils_srcs "${CMAKE_CURRENT_SOURCE_DIR}/${sig_file}")
  endforeach()

  cc_library(${TARGET_NAME} SRCS ${utils_srcs} DEPS ${register_op_utils_DEPS})
endfunction()
6 changes: 3 additions & 3 deletions paddle/fluid/framework/CMakeLists.txt
Expand Up @@ -192,11 +192,11 @@ cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_va
IF(WITH_XPU)
cc_library(operator SRCS operator.cc DEPS xpu_op_list op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils)
pten pten_utils kernel_factory infershape_utils op_utils)
ELSE()
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils
pten pten_utils kernel_factory infershape_utils)
pten pten_utils kernel_factory infershape_utils op_utils)
ENDIF()

cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
Expand Down Expand Up @@ -404,7 +404,7 @@ cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tens
cc_library(generator SRCS generator.cc DEPS enforce place)

cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows_utils place pten var_type_traits pten_api_utils op_info)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS lod_tensor selected_rows_utils attribute place pten var_type_traits pten pten_api_utils op_info shape_inference)
cc_library(infershape_utils SRCS infershape_utils.cc DEPS pten_utils attribute shape_inference op_utils)

# Get the current working branch
execute_process(
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/framework/details/op_registry.h
Expand Up @@ -275,10 +275,8 @@ struct OpInfoFiller<T, kVarTypeInference> {
template <typename T>
struct OpInfoFiller<T, kShapeInference> {
void operator()(const char* op_type, OpInfo* info) const {
PADDLE_ENFORCE_EQ(
info->infer_shape_, nullptr,
platform::errors::AlreadyExists(
"Duplicate InferShapeFN of %s has been registered", op_type));
// Note: if fill InferShapeFN by this Filler, the infershape here
// will overwrite the op->InferShape func registered in kOperator Filler
info->infer_shape_ = [](InferShapeContext* ctx) {
T inference;
inference(ctx);
Expand Down
38 changes: 38 additions & 0 deletions paddle/fluid/framework/infershape_utils.cc
Expand Up @@ -15,11 +15,14 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/core/meta_tensor.h"

namespace paddle {
Expand Down Expand Up @@ -186,5 +189,40 @@ class CompatMetaTensor : public pten::MetaTensor {
bool is_runtime_;
};

// Build a pten::InferMetaContext for `op_type` from a fluid
// InferShapeContext, so a pten InferMeta function can serve as the op's
// InferShape. Inputs/outputs are wrapped as CompatMetaTensor; attributes are
// not forwarded yet (see TODOs below). Throws NotFound when the op has no
// registered ArgumentMappingFn.
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
                                             const std::string& op_type) {
  // 1. get kernel args
  InitDefaultKernelSignatureMap();
  auto arg_map_fn = pten::OpUtilsMap::Instance().GetArgumentMappingFn(op_type);
  PADDLE_ENFORCE_NOT_NULL(
      arg_map_fn, platform::errors::NotFound(
                      "The ArgumentMappingFn of %s op is not found.", op_type));
  InferShapeArgumentMappingContext arg_map_context(*ctx);
  auto signature = arg_map_fn(arg_map_context);
  VLOG(3) << "BuildInferMetaContext: op kernel signature - " << signature;

  // 2. build infermeta context
  pten::InferMetaContext infer_meta_context(ctx->IsRuntime());

  auto& input_names = std::get<0>(signature.args);
  auto& output_names = std::get<2>(signature.args);
  // TODO(chenweihang): support attrs in next pr
  // auto& attr_names = std::get<1>(signature.args);

  // Fixed: removed the stray, never-used duplicate declaration
  // `pten::InferMetaContext infer_mete_context;` (note the typo'd name)
  // that shadowed the real context constructed above.

  // TODO(chenweihang): support multiple inputs and outputs
  // (only the first Var of each input/output slot is wrapped below)
  for (auto& in_name : input_names) {
    infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
        ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
  }
  for (auto& out_name : output_names) {
    infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
        ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
  }
  // TODO(chenweihang): support attrs later

  return infer_meta_context;
}

} // namespace framework
} // namespace paddle
1 change: 0 additions & 1 deletion paddle/fluid/framework/infershape_utils.h
Expand Up @@ -26,7 +26,6 @@ class InferMetaContext;
namespace paddle {
namespace framework {

// TODO(chenweihang): impl this function in next PR
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type);

Expand Down
14 changes: 12 additions & 2 deletions paddle/fluid/framework/operator.cc
Expand Up @@ -31,6 +31,7 @@ limitations under the License. */
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/ops/compat/signatures.h"

namespace pten {
class DenseTensor;
Expand Down Expand Up @@ -1086,6 +1087,13 @@ bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
}

// Fallback InferShape for kernel-based ops: always throws PermissionDenied.
// Per the note on OpInfoFiller<T, kShapeInference>, an op is expected to
// either override this method or register an infer-shape functor, which
// overwrites the op's infer_shape_ entry — so reaching this body is an error.
void OperatorWithKernel::InferShape(InferShapeContext* ctx) const {
PADDLE_THROW(platform::errors::PermissionDenied(
"The default InferShape function of OperatorWithKernel is not allowed to "
"be called, please override corresponding InferShape function in the "
"specific operator."));
}

void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place,
const RuntimeContext& ctx) const {
Expand Down Expand Up @@ -1784,8 +1792,10 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(

// Map this op's fluid inputs/attrs/outputs to its pten kernel signature by
// invoking the op's registered ArgumentMappingFn.
// Fixed: the span carried both the deleted old body (KernelSignatureMap
// lookup) and the new one as diff residue — only the new body is kept — and
// the mapping fn is now null-checked before invocation, mirroring
// BuildInferMetaContext in infershape_utils.cc, instead of crashing on ops
// without a registered mapping.
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
    const ExecutionContext& ctx) const {
  // Ensure proto-derived default signatures are registered first.
  InitDefaultKernelSignatureMap();
  auto arg_map_fn = pten::OpUtilsMap::Instance().GetArgumentMappingFn(Type());
  PADDLE_ENFORCE_NOT_NULL(
      arg_map_fn, platform::errors::NotFound(
                      "The ArgumentMappingFn of %s op is not found.", Type()));
  ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
  return arg_map_fn(arg_mapping_ctx);
}

Scope* OperatorWithKernel::PreparePtenData(
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/framework/operator.h
Expand Up @@ -41,6 +41,7 @@ limitations under the License. */
#include "paddle/utils/flat_hash_map.h"

#include "paddle/pten/core/compat/arg_map_context.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"

Expand Down Expand Up @@ -468,8 +469,7 @@ class ExecutionArgumentMappingContext : public pten::ArgumentMappingContext {
}

// True iff input `name`'s Variable holds a LoDTensor (fluid's dense tensor).
// Fixed: the span contained both the old return (Tensor || LoDTensor) and
// the new one as diff residue — two return statements in one body; only the
// new LoDTensor-only check is kept.
// NOTE(review): dropping the framework::Tensor check presumably relies on
// Tensor aliasing LoDTensor in this codebase — confirm via var_type_traits.
bool IsDenseTensorInput(const std::string& name) const override {
  return ctx_.InputVar(name)->IsType<framework::LoDTensor>();
}

bool IsSelectedRowsInput(const std::string& name) const override {
Expand Down Expand Up @@ -550,7 +550,7 @@ class OperatorWithKernel : public OperatorBase {
bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
proto::VarType::Type data_type) const;

virtual void InferShape(InferShapeContext* ctx) const = 0;
virtual void InferShape(InferShapeContext* ctx) const;

void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
Expand Down
61 changes: 19 additions & 42 deletions paddle/fluid/framework/pten_utils.cc
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <sstream>

#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/kernel_factory.h"

Expand Down Expand Up @@ -89,48 +90,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
return pten::KernelKey(backend, layout, dtype);
}

KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr;
std::once_flag KernelSignatureMap::init_flag_;

// Lazily build and return the process-wide singleton on first access.
// Population happens exactly once (std::call_once): for every registered
// fluid op that both has a compatible pten kernel and carries an OpProto,
// a kernel signature is derived from the proto by KernelArgsNameMakerByOpProto.
KernelSignatureMap& KernelSignatureMap::Instance() {
std::call_once(init_flag_, [] {
// Never deleted — intentional process-lifetime singleton.
kernel_signature_map_ = new KernelSignatureMap();
for (const auto& pair : OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
// Keyed by the pten kernel name, not the fluid op name.
auto success = kernel_signature_map_->map_
.emplace(pten::TransToPtenKernelName(op_type),
std::move(maker.GetKernelSignature()))
.second;
// Two fluid ops must not map to the same pten kernel name.
PADDLE_ENFORCE_EQ(
success, true,
platform::errors::PermissionDenied(
"Kernel signature of the operator %s has been registered.",
op_type));
}
}
});
return *kernel_signature_map_;
}

// Whether a kernel signature has been recorded under `op_type`.
bool KernelSignatureMap::Has(const std::string& op_type) const {
  const bool found = map_.find(op_type) != map_.end();
  return found;
}

// Return the signature registered for `op_type`; throws NotFound when the
// op has no registered signature. The returned reference stays valid for
// the lifetime of the (never-destroyed) singleton map.
const KernelSignature& KernelSignatureMap::Get(
    const std::string& op_type) const {
  const auto iter = map_.find(op_type);
  PADDLE_ENFORCE_NE(
      iter, map_.end(),
      platform::errors::NotFound(
          "Operator `%s`'s kernel signature is not registered.", op_type));
  return iter->second;
}

const paddle::SmallVector<std::string>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) {
Expand Down Expand Up @@ -196,6 +155,24 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
GetOutputArgsNames());
}

std::once_flag kernel_sig_map_init_flag;

// Populate pten's DefaultKernelSignatureMap once: for every fluid op that has
// a compatible pten kernel and an OpProto, derive a default kernel signature
// from the proto and insert it keyed by the fluid op type. Safe to call from
// multiple sites; the work runs exactly once per process.
void InitDefaultKernelSignatureMap() {
  std::call_once(kernel_sig_map_init_flag, [] {
    for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
      const auto& op_type = pair.first;
      const auto* op_proto = pair.second.proto_;
      if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
          op_proto) {
        paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
        VLOG(10) << "Register kernel signature for " << op_type;
        // Fixed: GetKernelSignature() returns a prvalue; wrapping it in
        // std::move was redundant and blocks copy elision
        // (clang-tidy performance-move-const-arg), so it is passed directly.
        pten::DefaultKernelSignatureMap::Instance().Insert(
            op_type, maker.GetKernelSignature());
      }
    }
  });
}

void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place) {
if (!tensor->IsInitialized() || !(tensor->place() == place)) {
Expand Down
22 changes: 2 additions & 20 deletions paddle/fluid/framework/pten_utils.h
Expand Up @@ -44,26 +44,6 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(

/* Kernel Args parse */

// TODO(chenweihang): we can generate this map by proto info in compile time
class KernelSignatureMap {
public:
static KernelSignatureMap& Instance();

bool Has(const std::string& op_type) const;

const KernelSignature& Get(const std::string& op_type) const;

private:
KernelSignatureMap() = default;
DISABLE_COPY_AND_ASSIGN(KernelSignatureMap);

private:
static KernelSignatureMap* kernel_signature_map_;
static std::once_flag init_flag_;

paddle::flat_hash_map<std::string, KernelSignature> map_;
};

class KernelArgsNameMaker {
public:
virtual ~KernelArgsNameMaker() {}
Expand All @@ -72,6 +52,8 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
};

void InitDefaultKernelSignatureMap();

void SetAllocationForOutputTenosr(pten::DenseTensor* tensor,
const platform::Place& place);

Expand Down
7 changes: 0 additions & 7 deletions paddle/fluid/operators/scale_op.cc
Expand Up @@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include <string>
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/ops/compat/scale_args_fn.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -71,12 +70,6 @@ class ScaleOp : public framework::OperatorWithKernel {
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
framework::ExecutionArgumentMappingContext arg_mapping_ctx(ctx);
return pten::ScaleOpArgumentMapping(arg_mapping_ctx);
}
};

class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
Expand Down
16 changes: 7 additions & 9 deletions paddle/fluid/operators/sign_op.cc
Expand Up @@ -14,22 +14,17 @@ limitations under the License. */

#include "paddle/fluid/operators/sign_op.h"
#include <memory>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/core/infermeta_utils.h"
#include "paddle/pten/infermeta/unary.h"

namespace paddle {
namespace operators {

// sign operator definition; its CPU kernels (SignKernel<float>/<double>)
// are registered at the bottom of this file.
class SignOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

// Out takes exactly X's dims and shares X's LoD; both the "X" input and
// the "Out" output must be present or OP_INOUT_CHECK raises.
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "sign");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "sign");

ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};

template <typename AttrType>
Expand Down Expand Up @@ -64,9 +59,12 @@ class SignGradMaker : public framework::SingleGradOpMaker<T> {

namespace ops = paddle::operators;

// Bind sign's shape inference to a pten InferMeta function
// (UnchangedInferMetaNew — the name suggests the output meta is copied
// unchanged from the input; confirm against paddle/pten/infermeta/unary.h)
// instead of a hand-written InferShape override.
// NOTE(review): "DELCARE" looks misspelled, but it must match the macro's
// actual name where it is defined — confirm before renaming.
DELCARE_INFER_SHAPE_FUNCTOR(sign, SignInferShapeFunctor,
PT_INFER_META(pten::UnchangedInferMetaNew));
// Passing the functor as an extra registration argument fills the op's
// infer_shape_ entry (see the kShapeInference OpInfoFiller, which
// deliberately overwrites any previously registered InferShape).
REGISTER_OPERATOR(sign, ops::SignOp, ops::SignOpMaker<float>,
ops::SignGradMaker<paddle::framework::OpDesc>,
ops::SignGradMaker<paddle::imperative::OpBase>,
SignInferShapeFunctor);
REGISTER_OP_CPU_KERNEL(
sign, ops::SignKernel<paddle::platform::CPUDeviceContext, float>,
ops::SignKernel<paddle::platform::CPUDeviceContext, double>);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pybind/CMakeLists.txt
Expand Up @@ -2,7 +2,7 @@ set(PYBIND_DEPS pybind python proto_desc memory executor fleet_wrapper box_wrapp
feed_fetch_method pass generate_pass pass_builder parallel_executor profiler layer tracer engine scope_pool
analysis_predictor imperative_profiler imperative_flag save_load_util dlpack_tensor device_context
gloo_wrapper infer_io_utils heter_wrapper generator op_version_registry ps_gpu_wrapper custom_operator
cost_model cuda_graph_with_memory_pool fleet_executor global_utils)
cost_model cuda_graph_with_memory_pool fleet_executor global_utils pten_utils)

if (WITH_PSCORE)
set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/pybind.cc
Expand Up @@ -50,6 +50,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/save_load_util.h"
#include "paddle/fluid/framework/scope_pool.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/CMakeLists.txt
Expand Up @@ -21,7 +21,7 @@ add_subdirectory(ops)
add_subdirectory(tests)

# make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils)
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta lod_utils op_compat_infos)
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}")
Expand Down