From 2dcfd6134017010ce79ba650ab34e27b0620323e Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 19 Jun 2020 19:36:08 -0700 Subject: [PATCH] [Target] Introduce Target Id Registry (#5838) --- include/tvm/ir/op.h | 8 - include/tvm/node/attr_registry_map.h | 2 + include/tvm/runtime/registry.h | 8 + include/tvm/target/target_id.h | 306 +++++++++++++++++++++++++++ src/node/attr_registry.h | 1 + src/target/target_id.cc | 164 ++++++++++++++ tests/cpp/target_test.cc | 88 ++++++++ 7 files changed, 569 insertions(+), 8 deletions(-) create mode 100644 include/tvm/target/target_id.h create mode 100644 src/target/target_id.cc create mode 100644 tests/cpp/target_test.cc diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 2bc2c90c7854..9a91302a98b4 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -340,14 +340,6 @@ class OpAttrMap : public AttrRegistryMap { explicit OpAttrMap(const AttrRegistryMapContainerMap& map) : TParent(map) {} }; -#define TVM_STRINGIZE_DETAIL(x) #x -#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x) -#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__)) -/*! - * \brief Macro to include current line as string - */ -#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) - // internal macros to make #define TVM_OP_REGISTER_VAR_DEF static DMLC_ATTRIBUTE_UNUSED ::tvm::OpRegEntry& __make_##Op diff --git a/include/tvm/node/attr_registry_map.h b/include/tvm/node/attr_registry_map.h index 748b3a80969c..9c554af9bc21 100644 --- a/include/tvm/node/attr_registry_map.h +++ b/include/tvm/node/attr_registry_map.h @@ -23,6 +23,8 @@ #ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_ #define TVM_NODE_ATTR_REGISTRY_MAP_H_ +#include + #include #include diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h index 4a5a21088222..86e3706b2058 100644 --- a/include/tvm/runtime/registry.h +++ b/include/tvm/runtime/registry.h @@ -289,6 +289,14 @@ class Registry { #define TVM_REGISTER_GLOBAL(OpName) \ TVM_STR_CONCAT(TVM_FUNC_REG_VAR_DEF, __COUNTER__) = ::tvm::runtime::Registry::Register(OpName) +#define TVM_STRINGIZE_DETAIL(x) #x +#define TVM_STRINGIZE(x) TVM_STRINGIZE_DETAIL(x) +#define TVM_DESCRIBE(...) describe(__VA_ARGS__ "\n\nFrom:" __FILE__ ":" TVM_STRINGIZE(__LINE__)) +/*! + * \brief Macro to include current line as string + */ +#define TVM_ADD_FILELINE "\n\nDefined in " __FILE__ ":L" TVM_STRINGIZE(__LINE__) + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_REGISTRY_H_ diff --git a/include/tvm/target/target_id.h b/include/tvm/target/target_id.h new file mode 100644 index 000000000000..93c88c758006 --- /dev/null +++ b/include/tvm/target/target_id.h @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/target/target_id.h + * \brief Target id registry + */ +#ifndef TVM_TARGET_TARGET_ID_H_ +#define TVM_TARGET_TARGET_ID_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace detail { +template +struct ValueTypeInfoMaker; +} + +/*! \brief Perform schema validation */ +TVM_DLL void TargetValidateSchema(const Map& config); + +template +class TargetIdAttrMap; + +/*! \brief Target Id, specifies the kind of the target */ +class TargetIdNode : public Object { + public: + /*! \brief Name of the target id */ + String name; + /*! \brief Stores the required type_key and type_index of a specific attr of a target */ + struct ValueTypeInfo { + String type_key; + uint32_t type_index; + std::unique_ptr key; + std::unique_ptr val; + }; + + static constexpr const char* _type_key = "TargetId"; + TVM_DECLARE_FINAL_OBJECT_INFO(TargetIdNode, Object); + + private: + uint32_t AttrRegistryIndex() const { return index_; } + String AttrRegistryName() const { return name; } + /*! \brief Perform schema validation */ + void ValidateSchema(const Map& config) const; + /*! \brief A hash table that stores the type information of each attr of the target key */ + std::unordered_map key2vtype_; + /*! \brief Index used for internal lookup of attribute registry */ + uint32_t index_; + friend void TargetValidateSchema(const Map&); + friend class TargetId; + template + friend class AttrRegistry; + template + friend class AttrRegistryMapContainerMap; + friend class TargetIdRegEntry; + template + friend struct detail::ValueTypeInfoMaker; +}; + +/*! + * \brief Managed reference class to TargetIdNode + * \sa TargetIdNode + */ +class TargetId : public ObjectRef { + public: + /*! \brief Get the attribute map given the attribute name */ + template + static inline TargetIdAttrMap GetAttrMap(const String& attr_name); + /*! + * \brief Retrieve the TargetId given its name + * \param target_id_name Name of the target id + * \return The TargetId requested + */ + TVM_DLL static const TargetId& Get(const String& target_id_name); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetId, ObjectRef, TargetIdNode); + + private: + /*! \brief Mutable access to the container class */ + TargetIdNode* operator->() { return static_cast(data_.get()); } + TVM_DLL static const AttrRegistryMapContainerMap& GetAttrMapContainer( + const String& attr_name); + template + friend class AttrRegistry; + friend class TargetIdRegEntry; +}; + +/*! + * \brief Map used to store meta-information about TargetId + * \tparam ValueType The type of the value stored in map + */ +template +class TargetIdAttrMap : public AttrRegistryMap { + public: + using TParent = AttrRegistryMap; + using TParent::count; + using TParent::get; + using TParent::operator[]; + explicit TargetIdAttrMap(const AttrRegistryMapContainerMap& map) : TParent(map) {} +}; + +/*! + * \brief Helper structure to register TargetId + * \sa TVM_REGISTER_TARGET_ID + */ +class TargetIdRegEntry { + public: + /*! + * \brief Register additional attributes to target_id. + * \param attr_name The name of the attribute. + * \param value The value to be set. + * \param plevel The priority level of this attribute, + * an higher priority level attribute + * will replace lower priority level attribute. + * Must be bigger than 0. + * + * Cannot set with same plevel twice in the code. + * + * \tparam ValueType The type of the value to be set. + */ + template + inline TargetIdRegEntry& set_attr(const String& attr_name, const ValueType& value, + int plevel = 10); + /*! + * \brief Register a valid configuration option and its ValueType for validation + * \param key The configuration key + * \tparam ValueType The value type to be registered + */ + template + inline TargetIdRegEntry& add_attr_option(const String& key); + /*! \brief Set name of the TargetId to be the same as registry if it is empty */ + inline TargetIdRegEntry& set_name(); + /*! + * \brief Register or get a new entry. + * \param target_id_name The name of the TargetId. + * \return the corresponding entry. + */ + TVM_DLL static TargetIdRegEntry& RegisterOrGet(const String& target_id_name); + + private: + TargetId id_; + String name; + + /*! \brief private constructor */ + explicit TargetIdRegEntry(uint32_t reg_index) : id_(make_object()) { + id_->index_ = reg_index; + } + /*! + * \brief update the attribute TargetIdAttrMap + * \param key The name of the attribute + * \param value The value to be set + * \param plevel The priority level + */ + TVM_DLL void UpdateAttr(const String& key, TVMRetValue value, int plevel); + template + friend class AttrRegistry; + friend class TargetId; +}; + +#define TVM_TARGET_ID_REGISTER_VAR_DEF \ + static DMLC_ATTRIBUTE_UNUSED ::tvm::TargetIdRegEntry& __make_##TargetId + +/*! + * \def TVM_REGISTER_TARGET_ID + * \brief Register a new target id, or set attribute of the corresponding target id. + * + * \param TargetIdName The name of target id + * + * \code + * + * TVM_REGISTER_TARGET_ID("llvm") + * .set_attr("TPreCodegenPass", a-pre-codegen-pass) + * .add_attr_option("system_lib") + * .add_attr_option("mtriple") + * .add_attr_option("mattr"); + * + * \endcode + */ +#define TVM_REGISTER_TARGET_ID(TargetIdName) \ + TVM_STR_CONCAT(TVM_TARGET_ID_REGISTER_VAR_DEF, __COUNTER__) = \ + ::tvm::TargetIdRegEntry::RegisterOrGet(TargetIdName).set_name() + +namespace detail { +template class Container> +struct is_specialized : std::false_type { + using type = std::false_type; +}; + +template