Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Runtime] Introduce runtime::Array #5585

Merged
merged 4 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 6 additions & 280 deletions include/tvm/node/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@

namespace tvm {

using runtime::Array;
using runtime::ArrayNode;
using runtime::IterAdapter;
using runtime::make_object;
using runtime::Object;
using runtime::ObjectEqual;
Expand All @@ -46,16 +49,6 @@ using runtime::ObjectRef;
using runtime::String;
using runtime::StringObj;

/*! \brief array node content in array */
class ArrayNode : public Object {
public:
/*! \brief the data content */
std::vector<ObjectRef> data;

static constexpr const char* _type_key = "Array";
TVM_DECLARE_FINAL_OBJECT_INFO(ArrayNode, Object);
};

/*! \brief map node content */
class MapNode : public Object {
public:
Expand All @@ -82,273 +75,6 @@ class StrMapNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object);
};

/*!
* \brief iterator adapter that adapts TIter to return another type.
* \tparam Converter a struct that contains converting function
* \tparam TIter the content iterator type.
*/
template <typename Converter, typename TIter>
class IterAdapter {
public:
using difference_type = typename std::iterator_traits<TIter>::difference_type;
using value_type = typename Converter::ResultType;
using pointer = typename Converter::ResultType*;
using reference = typename Converter::ResultType&; // NOLINT(*)
using iterator_category = typename std::iterator_traits<TIter>::iterator_category;

explicit IterAdapter(TIter iter) : iter_(iter) {}
inline IterAdapter& operator++() {
++iter_;
return *this;
}
inline IterAdapter operator+(difference_type offset) const { return IterAdapter(iter_ + offset); }

template <typename T = IterAdapter>
typename std::enable_if<std::is_same<iterator_category, std::random_access_iterator_tag>::value,
typename T::difference_type>::type inline
operator-(const IterAdapter& rhs) const {
return iter_ - rhs.iter_;
}

inline bool operator==(IterAdapter other) const { return iter_ == other.iter_; }
inline bool operator!=(IterAdapter other) const { return !(*this == other); }
inline const value_type operator*() const { return Converter::convert(*iter_); }

private:
TIter iter_;
};

/*!
* \brief Array container of NodeRef in DSL graph.
* Array implements copy on write semantics, which means array is mutable
* but copy will happen when array is referenced in more than two places.
*
* operator[] only provide const acces, use Set to mutate the content.
* \tparam T The content NodeRef type.
*/
template <typename T,
typename = typename std::enable_if<std::is_base_of<ObjectRef, T>::value>::type>
class Array : public ObjectRef {
public:
/*!
* \brief default constructor
*/
Array() { data_ = make_object<ArrayNode>(); }
/*!
* \brief move constructor
* \param other source
*/
Array(Array<T>&& other) : ObjectRef() { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief copy constructor
* \param other source
*/
Array(const Array<T>& other) : ObjectRef() { // NOLINT(*)
data_ = std::move(other.data_);
}
/*!
* \brief constructor from pointer
* \param n the container pointer
*/
explicit Array(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template <typename IterType>
Array(IterType begin, IterType end) {
assign(begin, end);
}
/*!
* \brief constructor from initializer list
* \param init The initalizer list
*/
Array(std::initializer_list<T> init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
*/
Array(const std::vector<T>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
/*!
* \brief Constructs a container with n elements. Each element is a copy of val
* \param n The size of the container
* \param val The init value
*/
explicit Array(size_t n, const T& val) {
auto tmp_node = make_object<ArrayNode>();
for (size_t i = 0; i < n; ++i) {
tmp_node->data.push_back(val);
}
data_ = std::move(tmp_node);
}
/*!
* \brief move assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(Array<T>&& other) {
data_ = std::move(other.data_);
return *this;
}
/*!
* \brief copy assign operator
* \param other The source of assignment
* \return reference to self.
*/
Array<T>& operator=(const Array<T>& other) {
data_ = other.data_;
return *this;
}
/*!
* \brief reset the array to content from iterator.
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_object<ArrayNode>();
for (IterType it = begin; it != end; ++it) {
n->data.push_back(T(*it));
}
data_ = std::move(n);
}
/*!
* \brief Read i-th element from array.
* \param i The index
* \return the i-th element.
*/
inline const T operator[](size_t i) const {
return DowncastNoCheck<T>(static_cast<const ArrayNode*>(data_.get())->data[i]);
}
/*! \return The size of the array */
inline size_t size() const {
if (data_.get() == nullptr) return 0;
return static_cast<const ArrayNode*>(data_.get())->data.size();
}
/*!
* \brief copy on write semantics
* Do nothing if current handle is the unique copy of the array.
* Otherwise make a new copy of the array to ensure the current handle
* hold a unique copy.
*
* \return Handle to the internal node container(which ganrantees to be unique)
*/
inline ArrayNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
ObjectPtr<ArrayNode> n = make_object<ArrayNode>();
n->data = static_cast<ArrayNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
return static_cast<ArrayNode*>(data_.get());
}
/*!
* \brief push a new item to the back of the list
* \param item The item to be pushed.
*/
inline void push_back(const T& item) {
ArrayNode* n = this->CopyOnWrite();
n->data.push_back(item);
}
/*!
* \brief Resize the array.
* \param size The new size.
*/
inline void resize(size_t size) {
ArrayNode* n = this->CopyOnWrite();
n->data.resize(size);
}
/*!
* \brief set i-th element of the array.
* \param i The index
* \param value The value to be setted.
*/
inline void Set(size_t i, const T& value) {
ArrayNode* n = this->CopyOnWrite();
n->data[i] = value;
}
/*! \return whether array is empty */
inline bool empty() const { return size() == 0; }
/*!
* \brief Helper function to apply fmutate to mutate an array.
* \param fmutate The transformation function T -> T.
* \tparam F the type of the mutation function.
* \note This function performs copy on write optimization.
*/
template <typename F>
inline void MutateByApply(F fmutate) {
ArrayNode* ptr = static_cast<ArrayNode*>(data_.get());
if (ptr == nullptr) return;
if (data_.unique()) {
// Copy on write optimization.
// Perform inplace update because this is an unique copy.
for (size_t i = 0; i < ptr->data.size(); ++i) {
// It is important to use move here
// to make prevent the element's ref count from increasing
// so fmutate itself can perform copy-on-write optimization
T old_elem = DowncastNoCheck<T>(std::move(ptr->data[i]));
T new_elem = fmutate(std::move(old_elem));
ptr->data[i] = std::move(new_elem);
}
} else {
// lazily trigger copy if there is element change.
ObjectPtr<ArrayNode> copy;
for (size_t i = 0; i < ptr->data.size(); ++i) {
T old_elem = DowncastNoCheck<T>(ptr->data[i]);
T new_elem = fmutate(old_elem);
if (!new_elem.same_as(ptr->data[i])) {
// copy the old array
if (copy == nullptr) {
copy = runtime::make_object<ArrayNode>(*ptr);
}
copy->data[i] = std::move(new_elem);
}
}
// replace the data with the new copy.
if (copy != nullptr) {
data_ = std::move(copy);
}
}
}

/*! \brief specify container node */
using ContainerType = ArrayNode;

struct ValueConverter {
using ResultType = T;
static inline T convert(const ObjectRef& n) { return DowncastNoCheck<T>(n); }
};
using iterator = IterAdapter<ValueConverter, std::vector<ObjectRef>::const_iterator>;

using reverse_iterator =
IterAdapter<ValueConverter, std::vector<ObjectRef>::const_reverse_iterator>;

/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const ArrayNode*>(data_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const ArrayNode*>(data_.get())->data.end());
}
/*! \return rbegin iterator */
inline reverse_iterator rbegin() const {
return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rbegin());
}
/*! \return rend iterator */
inline reverse_iterator rend() const {
return reverse_iterator(static_cast<const ArrayNode*>(data_.get())->data.rend());
}
};

/*!
* \brief Map container of NodeRef->NodeRef in DSL graph.
* Map implements copy on write semantics, which means map is mutable
Expand Down Expand Up @@ -404,8 +130,8 @@ class Map : public ObjectRef {
assign(init.begin(), init.end());
}
/*!
* \brief constructor from vector
* \param init The vector
* \brief constructor from unordered_map
* \param init The unordered_map
*/
template <typename Hash, typename Equal>
Map(const std::unordered_map<K, V, Hash, Equal>& init) { // NOLINT(*)
Expand Down Expand Up @@ -625,7 +351,7 @@ struct ObjectTypeChecker<Array<T> > {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<ArrayNode>()) return false;
const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
for (const auto& p : n->data) {
for (const ObjectRef& p : *n) {
if (!ObjectTypeChecker<T>::Check(p.get())) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
"If `axis = None`, all axis of dimension 1 get squeezed;"
"Else, the dimension in axes get squeezed."
"It is an error if an axis does not has dimension 1.")
.set_default(NullValue<Array<Integer> >());
.set_default(NullValue<Array<Integer>>());
}
}; // struct SqueezeAttrs

Expand Down
Loading