Skip to content

Commit

Permalink
add registry for arg map func
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Jan 20, 2022
1 parent dcfa257 commit 2b53e60
Show file tree
Hide file tree
Showing 20 changed files with 243 additions and 162 deletions.
26 changes: 26 additions & 0 deletions cmake/pten_kernel.cmake → cmake/pten.cmake
Expand Up @@ -184,3 +184,29 @@ function(register_kernels)
endif()
endforeach()
endfunction()

function(append_arg_mapping_declare TARGET)
file(READ ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET} target_content)
string(REGEX MATCH "PT_REGISTER_ARG_MAPPING_FN\\([ \t\r\n]*[a-z0-9_]*, [ \t\r\n]*[a-z0-9_]*" arg_mapping_registrar "${target_content}")
string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" arg_mapping_declare "${arg_mapping_registrar}")
string(APPEND arg_mapping_declare ");")
message(STATUS "${arg_mapping_declare}")
file(APPEND ${arg_mapping_fns_header} "${arg_mapping_declare}")
endfunction()

function(register_arg_mapping_fns TARGET_NAME)
set(arg_mapping_srcs)
set(options "")
set(oneValueArgs "")
set(multiValueArgs EXCLUDES DEPS)
cmake_parse_arguments(register_arg_mapping_fns "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})

file(GLOB SIGNATURES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_sig.cc")
foreach(target ${SIGNATURES})
append_arg_mapping_declare(${target})
list(APPEND arg_mapping_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${target})
endforeach()

cc_library(${TARGET_NAME} SRCS ${arg_mapping_srcs} DEPS ${register_arg_mapping_fns_DEPS})
endfunction()
4 changes: 2 additions & 2 deletions paddle/fluid/framework/infershape_utils.cc
Expand Up @@ -16,7 +16,7 @@ limitations under the License. */

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/pten/core/arg_map_utils.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/compat_utils.h"
#include "paddle/pten/core/convert_utils.h"
#include "paddle/pten/core/dense_tensor.h"
Expand Down Expand Up @@ -195,7 +195,7 @@ class CompatMetaTensor : public pten::MetaTensor {
pten::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
const std::string& op_type) {
// 1. get kernel args
auto arg_map_fn = pten::OpUtils::Instance().GetArgumentMappingFn(op_type);
auto arg_map_fn = pten::OpUtilsMap::Instance().Get(op_type).arg_mapping_fn;
PADDLE_ENFORCE_NOT_NULL(
arg_map_fn, platform::errors::NotFound(
"The ArgumentMappingFn of %s op is not found.", op_type));
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/framework/operator.cc
Expand Up @@ -1808,7 +1808,9 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(

KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const {
return pten::KernelSignatureMap::Instance().Get(
// only init DefaultKernelSignatureMap when pten kernel needed
framework::IntiDefaultKernelSignatureMap();
return pten::DefaultKernelSignatureMap::Instance().Get(
pten::TransToPtenKernelName(Type()));
}

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/operator.h
Expand Up @@ -40,7 +40,7 @@ limitations under the License. */
#include "paddle/fluid/platform/variant.h"
#include "paddle/utils/flat_hash_map.h"

#include "paddle/pten/core/arg_map_utils.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"

Expand Down
44 changes: 8 additions & 36 deletions paddle/fluid/framework/pten_utils.cc
Expand Up @@ -155,51 +155,23 @@ KernelSignature KernelArgsNameMakerByOpProto::GetKernelSignature() {
GetOutputArgsNames());
}

} // namespace framework
} // namespace paddle

// NOTE: [Why the definition of KernelSignatureMap method isn't in op_utils.cc]
// - In order to avoid introducing fluid proto and op-related dependencies
// into pten
namespace pten {
std::once_flag kernel_sig_map_init_flag;

KernelSignatureMap& KernelSignatureMap::Instance() {
std::call_once(init_flag_, [] {
kernel_signature_map_ = new KernelSignatureMap();
void IntiDefaultKernelSignatureMap() {
std::call_once(kernel_sig_map_init_flag, [] {
for (const auto& pair : paddle::framework::OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto) {
paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success = kernel_signature_map_->map_
.emplace(TransToPtenKernelName(op_type),
std::move(maker.GetKernelSignature()))
.second;
PADDLE_ENFORCE_EQ(
success, true,
paddle::platform::errors::PermissionDenied(
"Kernel signature of the operator %s has been registered.",
op_type));
pten::DefaultKernelSignatureMap::Instance().Insert(op_type,
std::move(maker.GetKernelSignature());
}
}
});
return *kernel_signature_map_;
}

bool KernelSignatureMap::Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}

const KernelSignature& KernelSignatureMap::Get(
const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it, map_.end(),
paddle::platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}

} // namespace pten
} // namespace framework
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/framework/pten_utils.h
Expand Up @@ -25,7 +25,7 @@ limitations under the License. */
#include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/core/arg_map_utils.h"
#include "paddle/pten/core/arg_map_context.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h"
#include "paddle/utils/small_vector.h"
Expand All @@ -52,5 +52,7 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
};

void IntiDefaultKernelSignatureMap();

} // namespace framework
} // namespace paddle
5 changes: 4 additions & 1 deletion paddle/pten/CMakeLists.txt
Expand Up @@ -3,6 +3,9 @@
# float16.h/complex.h/bfloat16.h into pten
include_directories(${PADDLE_SOURCE_DIR}/paddle/fluid/platform)

# pten auto cmake utils
include(pten)

# paddle experimental common components
add_subdirectory(common)

Expand All @@ -24,7 +27,7 @@ add_subdirectory(tests)

# make an unity target for compile deps
set(PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory
kernel_context arg_map_utils infermeta op_utils)
kernel_context arg_map_context infermeta op_utils)
get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}")
Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/core/CMakeLists.txt
Expand Up @@ -8,15 +8,15 @@ endif()

cc_library(kernel_factory SRCS kernel_factory.cc DEPS enforce convert_utils)
cc_library(kernel_context SRCS kernel_context.cc DEPS enforce pten_context)
cc_library(arg_map_utils SRCS arg_map_utils.cc DEPS enforce)
cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)

cc_library(tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce)
cc_library(tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector)
cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)

cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_utils enforce convert_utils)
cc_library(op_utils SRCS op_utils.cc DEPS arg_map_context enforce convert_utils)

# Will remove once we implemented MKLDNN_Tensor
if(WITH_MKLDNN)
Expand Down
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/core/arg_map_utils.h"
#include "paddle/pten/core/arg_map_context.h"

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/string_helper.h"
Expand Down
Expand Up @@ -69,6 +69,4 @@ class ArgumentMappingContext {
virtual bool IsSelectedRowsInput(const std::string& name) const = 0;
};

#define PT_REGISTER_ARGUMENT_MAPPING_FN()

} // namespace pten
31 changes: 4 additions & 27 deletions paddle/pten/core/kernel_registry.h
Expand Up @@ -24,6 +24,7 @@
#include "paddle/pten/core/kernel_def.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_utils.h"
#include "paddle/pten/core/macros.h"

#include "paddle/fluid/platform/enforce.h"

Expand Down Expand Up @@ -158,33 +159,6 @@ struct KernelRegistrar {
}
};

#define PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
_PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)

#define _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)

#ifdef __COUNTER__
#define PT_ID __COUNTER__
#else
#define PT_ID __LINE__
#endif

#if defined(_WIN32)
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define UNUSED __attribute__((unused))
#endif

#define PT_CONCATENATE(arg1, arg2) PT_CONCATENATE1(arg1, arg2)
#define PT_CONCATENATE1(arg1, arg2) PT_CONCATENATE2(arg1, arg2)
#define PT_CONCATENATE2(arg1, arg2) arg1##arg2
#define PT_EXPAND(x) x

/**
* Reference:
*
Expand Down Expand Up @@ -834,6 +808,9 @@ struct KernelRegistrar {
* to avoid being removed by linker
*/
#define PT_DECLARE_KERNEL(kernel_name, backend, layout) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_declare_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_DECLARE_KERNEL must be called in global namespace."); \
extern int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout(); \
UNUSED static int \
__declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \
Expand Down
31 changes: 31 additions & 0 deletions paddle/pten/core/macros.h
Expand Up @@ -14,6 +14,8 @@ limitations under the License. */

#pragma once

namespace pten {

// Disable the copy and assignment operator for a class.
#ifndef DISABLE_COPY_AND_ASSIGN
#define DISABLE_COPY_AND_ASSIGN(classname) \
Expand All @@ -23,3 +25,32 @@ limitations under the License. */
classname& operator=(const classname&) = delete; \
classname& operator=(classname&&) = delete
#endif

#define PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
_PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg)

#define _PT_STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
struct __test_global_namespace_##uniq_name##__ {}; \
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
__test_global_namespace_##uniq_name##__>::value, \
msg)

#ifdef __COUNTER__
#define PT_ID __COUNTER__
#else
#define PT_ID __LINE__
#endif

#if defined(_WIN32)
#define UNUSED
#define __builtin_expect(EXP, C) (EXP)
#else
#define UNUSED __attribute__((unused))
#endif

#define PT_CONCATENATE(arg1, arg2) PT_CONCATENATE1(arg1, arg2)
#define PT_CONCATENATE1(arg1, arg2) PT_CONCATENATE2(arg1, arg2)
#define PT_CONCATENATE2(arg1, arg2) arg1##arg2
#define PT_EXPAND(x) x

} // namespace pten

0 comments on commit 2b53e60

Please sign in to comment.