From 3636388dfb32294eac519cbd9b6a97fb9b5bc74e Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 14 Nov 2019 19:56:38 -0800 Subject: [PATCH] [Runtime] Make ADTObject POD container type --- include/tvm/runtime/container.h | 220 ++++++++++++++++++++++++++++++++ include/tvm/runtime/memory.h | 80 +++++++++++- include/tvm/runtime/vm.h | 29 ----- src/runtime/container.cc | 45 +++++++ src/runtime/vm/object.cc | 20 +-- src/runtime/vm/vm.cc | 10 +- 6 files changed, 355 insertions(+), 49 deletions(-) create mode 100644 include/tvm/runtime/container.h create mode 100644 src/runtime/container.cc diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h new file mode 100644 index 0000000000000..17e0965922346 --- /dev/null +++ b/include/tvm/runtime/container.h @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container.h + * \brief Common POD(plain old data) container types. + */ +#ifndef TVM_RUNTIME_CONTAINER_H_ +#define TVM_RUNTIME_CONTAINER_H_ +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/** + * @brief Base template for classes with array like memory layout. + * + * It provides general methods to access the memory. The memory + * layout is ArrayType + [ElemType]. The alignment of ArrayType + * and ElemType is handled by the memory allocator. + * + * @tparam ArrayType + * @tparam ElemType + */ +template +class InplaceArrayBase { + public: + /** + * @brief Initialize the elements in the array. + */ + void Init() { + CHECK_EQ(sizeof(ArrayType) % alignof(ElemType), 0); + for (size_t i = 0; i < Self()->size(); ++i) { + void* field_ptr = AddressOf(i); + new (field_ptr) ElemType(); + } + } + + /** + * @brief Initialize the elements in the array. + * + * @tparam Iterator Iterator type of the array. + * @param begin The begin iterator. + * @param end The end iterator. + */ + template + void Init(Iterator begin, Iterator end) { + CHECK_EQ(sizeof(ArrayType) % alignof(ElemType), 0); + ArrayType* self = Self(); + size_t num_elems = std::distance(begin, end); + if (num_elems != self->size()) { + LOG(FATAL) + << "Number of initializer values does not match number of elements\n"; + } + auto it = begin; + for (size_t i = 0; i < num_elems; ++i) { + void* field_ptr = AddressOf(i); + new (field_ptr) ElemType(*it); + ++it; + } + } + + /** + * @brief Initialize the elements in the array. + * + * @param init The initializer list of elements. + */ + void Init(std::initializer_list init) { + CHECK_EQ(sizeof(ArrayType) % alignof(ElemType), 0); + Init(init.begin(), init.end()); + } + + /*! + * \brief Access element at index + * \param idx The index of the element. + * \return Reference to ElemType at the index. + */ + ElemType& operator[](size_t idx) const { + size_t size = Self()->size(); + if (idx > size) { + LOG(FATAL) << "Index " << idx << " out of bounds " << size << "\n"; + } + return *(reinterpret_cast(AddressOf(idx))); + } + + /** + * @brief Destroy the Inplace Array Base object + */ + virtual ~InplaceArrayBase() { + if (!IsPOD()) { + size_t size = Self()->size(); + for (size_t i = 0; i < size; ++i) { + ElemType* fp = reinterpret_cast(AddressOf(i)); + fp->ElemType::~ElemType(); + } + } + } + + private: + /** + * @brief Check if the ElemType is Plain Old Data. + * + * @return If ElemType is POD. + */ + inline bool IsPOD() const { + return std::is_standard_layout::value && + std::is_trivial::value; + } + + /** + * @brief Return the self object for the array. + * + * @return Pointer to ArrayType. + */ + inline ArrayType* Self() const { + return static_cast(const_cast(this)); + } + + /** + * @brief Return the raw pointer to the element at idx. + * + * @param idx The index of the element. + * @return Raw pointer to the element. + */ + void* AddressOf(size_t idx) const { + const size_t kDataStart = sizeof(ArrayType); + ArrayType* self = Self(); + char* data_start = reinterpret_cast(self) + kDataStart; + return data_start + idx * sizeof(ElemType); + } +}; + +/*! \brief An object representing a structure or enumeration. */ +class ADTObj : public Object, public InplaceArrayBase { + public: + /*! \brief The tag representing the constructor used. */ + uint32_t tag_; + /*! \brief Number of fields in the ADT object. */ + uint32_t size_; + // The fields of the structure follows directly in memory. + + /** + * @brief The number of elements in the array. + */ + inline size_t size() const { return size_; } + + /** + * @brief Destroy the ADTObj object + */ + ~ADTObj() {} + + static constexpr const uint32_t _type_index = TypeIndex::kVMADT; + static constexpr const char* _type_key = "vm.ADT"; + TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); +}; + +/*! \brief reference to algebraic data type objects. */ +class ADT : public ObjectRef { + public: + /*! + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param fields The fields of the ADT object. + * \return The constructed ADT object reference. + */ + ADT(uint32_t tag, std::vector fields) + : ADT(tag, fields.begin(), fields.end()){}; + + /** + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param begin The begin iterator to the start of the fields array. + * \param end The end iterator to the end of the fields array. + * \return The constructed ADT object reference. + */ + template + ADT(uint32_t tag, Iterator begin, Iterator end); + + /** + * \brief construct an ADT object reference. + * \param tag The tag of the ADT object. + * \param init The initializer list of fields. + * \return The constructed ADT object reference. + */ + ADT(uint32_t tag, std::initializer_list init) + : ADT(tag, init.begin(), init.end()){}; + + /*! + * \brief construct a tuple object. + * \param fields The fields of the tuple. + * \return The constructed tuple type. + */ + static ADT Tuple(std::vector fields); + + TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_H_ diff --git a/include/tvm/runtime/memory.h b/include/tvm/runtime/memory.h index d28552eaf7fd3..8db14e9d6f192 100644 --- a/include/tvm/runtime/memory.h +++ b/include/tvm/runtime/memory.h @@ -23,6 +23,7 @@ #ifndef TVM_RUNTIME_MEMORY_H_ #define TVM_RUNTIME_MEMORY_H_ +#include #include #include #include "object.h" @@ -33,7 +34,7 @@ namespace runtime { * \brief Allocate an object using default allocator. * \param args arguments to the constructor. * \tparam T the node type. - * \return The NodePtr to the allocated object. + * \return The ObjectPtr to the allocated object. */ template inline ObjectPtr make_object(Args&&... args); @@ -73,6 +74,26 @@ class ObjAllocatorBase { ptr->deleter_ = Handler::Deleter(); return ObjectPtr(ptr); } + + /*! + * \tparam T The type to be allocated. + * \tparam ElemType The type to array element. + * \tparam Args The constructor signature. + * \param num_elems The number of array elements. + * \param args The arguments. + */ + template + inline ObjectPtr make_array(size_t num_elems, Args&&... args) { + using Handler = typename Derived::template Handler; + static_assert(std::is_base_of::value, + "make_node can only be used to create NodeBase"); + T* ptr = Handler::New(static_cast(this), + num_elems, + std::forward(args)...); + ptr->type_index_ = T::RuntimeTypeIndex(); + ptr->deleter_ = Handler::Deleter(); + return ObjectPtr(ptr); + } }; // Simple allocator that uses new/delete. @@ -124,11 +145,68 @@ class SimpleObjAllocator : }; }; +// Array allocator that uses new/delete. +class ArrayObjAllocator : + public ObjAllocatorBase { + public: + template + class Handler { + public: + using StorageType = typename std::aligned_union::type; + + template + static ArrayType* New(ArrayObjAllocator*, size_t num_elems, Args&&... args) { + // NOTE: the first argument is not needed for ArrayObjAllocator + // It is reserved for special allocators that needs to recycle + // the object to itself (e.g. in the case of object pool). + // + // In the case of an object pool, an allocator needs to create + // a special chunk memory that hides reference to the allocator + // and call allocator's release function in the deleter. + + // NOTE2: Use inplace new to allocate + // This is used to get rid of warning when deleting a virtual + // class with non-virtual destructor. + // We are fine here as we captured the right deleter during construction. + // This is also the right way to get storage type for an object pool. + size_t factor = sizeof(ArrayType) / sizeof(ElemType); + num_elems = (num_elems + factor - 1) / factor; + StorageType* data = new StorageType[num_elems+1]; + new (data) ArrayType(std::forward(args)...); + return reinterpret_cast(data); + } + + static Object::FDeleter Deleter() { + return Deleter_; + } + + private: + static void Deleter_(Object* objptr) { + // NOTE: this is important to cast back to ArrayType* + // because objptr and tptr may not be the same + // depending on how sub-class allocates the space. + ArrayType* tptr = static_cast(objptr); + // It is important to do tptr->ArrayType::~ArrayType(), + // so that we explicitly call the specific destructor + // instead of tptr->~ArrayType(), which could mean the intention + // call a virtual destructor(which may not be available and is not required). + tptr->ArrayType::~ArrayType(); + StorageType* p = reinterpret_cast(tptr); + delete []p; + } + }; +}; + template inline ObjectPtr make_object(Args&&... args) { return SimpleObjAllocator().make(std::forward(args)...); } +template +inline ObjectPtr make_array(size_t num_elems, Args&&... args) { + return ArrayObjAllocator().make_array(num_elems, std::forward(args)...); +} + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_MEMORY_H_ diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 317b53531c2da..41ecfb311261f 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -55,35 +55,6 @@ class Tensor : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Tensor, ObjectRef, TensorObj); }; - -/*! \brief An object representing a structure or enumeration. */ -class ADTObj : public Object { - public: - /*! \brief The tag representing the constructor used. */ - size_t tag; - /*! \brief The fields of the structure. */ - std::vector fields; - - static constexpr const uint32_t _type_index = TypeIndex::kVMADT; - static constexpr const char* _type_key = "vm.ADT"; - TVM_DECLARE_FINAL_OBJECT_INFO(ADTObj, Object); -}; - -/*! \brief reference to algebraic data type objects. */ -class ADT : public ObjectRef { - public: - ADT(size_t tag, std::vector fields); - - /*! - * \brief construct a tuple object. - * \param fields The fields of the tuple. - * \return The constructed tuple type. - */ - static ADT Tuple(std::vector fields); - - TVM_DEFINE_OBJECT_REF_METHODS(ADT, ObjectRef, ADTObj); -}; - /*! \brief An object representing a closure. */ class ClosureObj : public Object { public: diff --git a/src/runtime/container.cc b/src/runtime/container.cc new file mode 100644 index 0000000000000..7f2bee9698b3d --- /dev/null +++ b/src/runtime/container.cc @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/* + * \file src/runtime/container.cc + * \brief POD container type implementations. + */ +#include +#include +#include +#include "object_internal.h" +#include "runtime_base.h" + +namespace tvm { +namespace runtime { + +template +ADT::ADT(uint32_t tag, Iterator begin, Iterator end) { + size_t num_elems = std::distance(begin, end); + auto ptr = make_array(num_elems); + ptr->tag_ = tag; + ptr->size_ = num_elems; + ptr->Init(begin, end); + data_ = std::move(ptr); +} + +ADT ADT::Tuple(std::vector fields) { return ADT(0, fields); } + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vm/object.cc b/src/runtime/vm/object.cc index 12edf511db668..a0c8d64f5e5dd 100644 --- a/src/runtime/vm/object.cc +++ b/src/runtime/vm/object.cc @@ -22,6 +22,7 @@ * \brief VM related objects. */ #include +#include #include #include #include @@ -39,17 +40,6 @@ Tensor::Tensor(NDArray data) { data_ = std::move(ptr); } -ADT::ADT(size_t tag, std::vector fields) { - auto ptr = make_object(); - ptr->tag = tag; - ptr->fields = std::move(fields); - data_ = std::move(ptr); -} - -ADT ADT::Tuple(std::vector fields) { - return ADT(0, fields); -} - Closure::Closure(size_t func_index, std::vector free_vars) { auto ptr = make_object(); ptr->func_index = func_index; @@ -71,7 +61,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTTag") ObjectRef obj = args[0]; const auto* cell = obj.as(); CHECK(cell != nullptr); - *rv = static_cast(cell->tag); + *rv = static_cast(cell->tag_); }); TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") @@ -79,7 +69,7 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTNumberOfFields") ObjectRef obj = args[0]; const auto* cell = obj.as(); CHECK(cell != nullptr); - *rv = static_cast(cell->fields.size()); + *rv = static_cast(cell->size_); }); @@ -89,8 +79,8 @@ TVM_REGISTER_GLOBAL("_vmobj.GetADTFields") int idx = args[1]; const auto* cell = obj.as(); CHECK(cell != nullptr); - CHECK_LT(idx, cell->fields.size()); - *rv = cell->fields[idx]; + CHECK_LT(idx, cell->size_); + *rv = (*cell)[idx]; }); TVM_REGISTER_GLOBAL("_vmobj.Tensor") diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 463c5758ae022..895f9004581ae 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -747,7 +748,7 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, size_t arity = 0; for (Index i = 0; i < arg_count; i++) { if (const auto* obj = args[i].as()) { - arity += obj->fields.size(); + arity += obj->size_; } else { ++arity; } @@ -759,7 +760,8 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, int idx = 0; for (Index i = 0; i < arg_count; i++) { if (const auto* dt_cell = args[i].as()) { - for (auto obj : dt_cell->fields) { + for (size_t fi = 0; fi < dt_cell->size_; ++fi) { + auto obj = (*dt_cell)[fi]; const auto* tensor = obj.as(); CHECK(tensor != nullptr); setter(idx++, tensor->data); @@ -921,7 +923,7 @@ void VirtualMachine::RunLoop() { CHECK(tuple != nullptr) << "Object is not data type object, register " << instr.object << ", Object tag " << object->type_index(); - auto field = tuple->fields[instr.field_index]; + auto field = (*tuple)[instr.field_index]; WriteRegister(instr.dst, field); pc++; goto main_loop; @@ -933,7 +935,7 @@ void VirtualMachine::RunLoop() { << "Object is not data type object, register " << instr.get_tag.object << ", Object tag " << object->type_index(); - auto tag = data->tag; + auto tag = data->tag_; auto tag_tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); reinterpret_cast(tag_tensor->data)[0] = tag; WriteRegister(instr.dst, Tensor(tag_tensor));