Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Object] Unify StrMapNode and MapNode #5687

Merged
merged 9 commits into from Jun 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/arith/analyzer.h
Expand Up @@ -107,7 +107,7 @@ class ConstIntBound : public ObjectRef {
*/
class ConstIntBoundAnalyzer {
public:
using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectHash, ObjectEqual>;
using BoundMapType = std::unordered_map<PrimExpr, ConstIntBound, ObjectPtrHash, ObjectPtrEqual>;
/*!
* \brief analyze the expr
* \param expr The expression of interest.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/arith/int_set.h
Expand Up @@ -198,7 +198,7 @@ IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_m
*/
IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variables.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/driver/driver_api.h
Expand Up @@ -83,7 +83,7 @@ TVM_DLL runtime::Module build(const Map<Target, IRModule>& input, const Target&
* pass Target().
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<std::string, IRModule>& input, const Target& target_host);
TVM_DLL runtime::Module build(const Map<String, IRModule>& input, const Target& target_host);
} // namespace tvm

#endif // TVM_DRIVER_DRIVER_API_H_
4 changes: 2 additions & 2 deletions include/tvm/ir/attrs.h
Expand Up @@ -201,7 +201,7 @@ class Attrs : public ObjectRef {
class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
Map<String, ObjectRef> dict;

bool SEqualReduce(const DictAttrsNode* other, SEqualReducer equal) const {
return equal(dict, other->dict);
Expand Down Expand Up @@ -230,7 +230,7 @@ class DictAttrs : public Attrs {
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);
TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);

TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/error.h
Expand Up @@ -177,8 +177,8 @@ class ErrorReporter {

private:
std::vector<Error> errors_;
std::unordered_map<ObjectRef, std::vector<size_t>, ObjectHash, ObjectEqual> node_to_error_;
std::unordered_map<ObjectRef, GlobalVar, ObjectHash, ObjectEqual> node_to_gv_;
std::unordered_map<ObjectRef, std::vector<size_t>, ObjectPtrHash, ObjectPtrEqual> node_to_error_;
std::unordered_map<ObjectRef, GlobalVar, ObjectPtrHash, ObjectPtrEqual> node_to_gv_;
};

} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/function.h
Expand Up @@ -188,7 +188,7 @@ inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_va
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<std::string, ObjectRef> dict = {{attr_key, attr_value}};
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/module.h
Expand Up @@ -250,12 +250,12 @@ class IRModuleNode : public Object {
/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Map<std::string, GlobalVar> global_var_map_;
Map<String, GlobalVar> global_var_map_;

/*! \brief A map from string names to global type variables (ADT names)
* that ensures global uniqueness.
*/
Map<std::string, GlobalTypeVar> global_type_var_map_;
Map<String, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/transform.h
Expand Up @@ -100,7 +100,7 @@ class PassContextNode : public Object {
TraceFunc trace_func;

/*! \brief Pass specific configurations. */
Map<std::string, ObjectRef> config;
Map<String, ObjectRef> config;

PassContextNode() = default;

Expand Down
146 changes: 26 additions & 120 deletions include/tvm/node/container.h
Expand Up @@ -39,16 +39,40 @@ namespace tvm {

using runtime::Array;
using runtime::ArrayNode;
using runtime::Downcast;
using runtime::IterAdapter;
using runtime::make_object;
using runtime::Object;
using runtime::ObjectEqual;
using runtime::ObjectHash;
using runtime::ObjectPtr;
using runtime::ObjectPtrEqual;
using runtime::ObjectPtrHash;
using runtime::ObjectRef;
using runtime::String;
using runtime::StringObj;

struct ObjectHash {
size_t operator()(const ObjectRef& a) const {
if (const auto* str = a.as<StringObj>()) {
return String::HashBytes(str->data, str->size);
}
return ObjectPtrHash()(a);
}
};

struct ObjectEqual {
bool operator()(const ObjectRef& a, const ObjectRef& b) const {
if (a.same_as(b)) {
return true;
}
if (const auto* str_a = a.as<StringObj>()) {
if (const auto* str_b = b.as<StringObj>()) {
return String::memncmp(str_a->data, str_b->data, str_a->size, str_b->size) == 0;
}
}
return false;
}
};

/*! \brief map node content */
class MapNode : public Object {
public:
Expand All @@ -62,19 +86,6 @@ class MapNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(MapNode, Object);
};

/*! \brief specialized map node with string as key */
class StrMapNode : public Object {
public:
/*! \brief The corresponding conatiner type */
using ContainerType = std::unordered_map<std::string, ObjectRef>;

/*! \brief the data content */
ContainerType data;

static constexpr const char* _type_key = "StrMap";
TVM_DECLARE_FINAL_OBJECT_INFO(StrMapNode, Object);
};

/*!
* \brief Map container of NodeRef->NodeRef in DSL graph.
* Map implements copy on write semantics, which means map is mutable
Expand Down Expand Up @@ -249,97 +260,6 @@ class Map : public ObjectRef {
}
};

// specialize of string map
template <typename V, typename T1, typename T2>
class Map<std::string, V, T1, T2> : public ObjectRef {
public:
// for code reuse
Map() { data_ = make_object<StrMapNode>(); }
Map(Map<std::string, V>&& other) { // NOLINT(*)
data_ = std::move(other.data_);
}
Map(const Map<std::string, V>& other) : ObjectRef(other.data_) { // NOLINT(*)
}
explicit Map(ObjectPtr<Object> n) : ObjectRef(n) {}
template <typename IterType>
Map(IterType begin, IterType end) {
assign(begin, end);
}
Map(std::initializer_list<std::pair<std::string, V> > init) { // NOLINT(*)
assign(init.begin(), init.end());
}

template <typename Hash, typename Equal>
Map(const std::unordered_map<std::string, V, Hash, Equal>& init) { // NOLINT(*)
assign(init.begin(), init.end());
}
Map<std::string, V>& operator=(Map<std::string, V>&& other) {
data_ = std::move(other.data_);
return *this;
}
Map<std::string, V>& operator=(const Map<std::string, V>& other) {
data_ = other.data_;
return *this;
}
template <typename IterType>
void assign(IterType begin, IterType end) {
auto n = make_object<StrMapNode>();
for (IterType i = begin; i != end; ++i) {
n->data.emplace(std::make_pair(i->first, i->second));
}
data_ = std::move(n);
}
inline const V operator[](const std::string& key) const {
return DowncastNoCheck<V>(static_cast<const StrMapNode*>(data_.get())->data.at(key));
}
inline const V at(const std::string& key) const {
return DowncastNoCheck<V>(static_cast<const StrMapNode*>(data_.get())->data.at(key));
}
inline size_t size() const {
if (data_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(data_.get())->data.size();
}
inline size_t count(const std::string& key) const {
if (data_.get() == nullptr) return 0;
return static_cast<const StrMapNode*>(data_.get())->data.count(key);
}
inline StrMapNode* CopyOnWrite() {
if (data_.get() == nullptr || !data_.unique()) {
ObjectPtr<StrMapNode> n = make_object<StrMapNode>();
n->data = static_cast<const StrMapNode*>(data_.get())->data;
ObjectPtr<Object>(std::move(n)).swap(data_);
}
return static_cast<StrMapNode*>(data_.get());
}
inline void Set(const std::string& key, const V& value) {
StrMapNode* n = this->CopyOnWrite();
n->data[key] = value;
}
inline bool empty() const { return size() == 0; }
using ContainerType = StrMapNode;

struct ValueConverter {
using ResultType = std::pair<std::string, V>;
static inline ResultType convert(const std::pair<std::string, ObjectRef>& n) {
return std::make_pair(n.first, DowncastNoCheck<V>(n.second));
}
};

using iterator = IterAdapter<ValueConverter, StrMapNode::ContainerType::const_iterator>;

/*! \return begin iterator */
inline iterator begin() const {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.begin());
}
/*! \return end iterator */
inline iterator end() const {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.end());
}
/*! \return begin iterator */
inline iterator find(const std::string& key) const {
return iterator(static_cast<const StrMapNode*>(data_.get())->data.find(key));
}
};
} // namespace tvm

namespace tvm {
Expand All @@ -361,20 +281,6 @@ struct ObjectTypeChecker<Array<T> > {
static std::string TypeName() { return "List[" + ObjectTypeChecker<T>::TypeName() + "]"; }
};

template <typename V>
struct ObjectTypeChecker<Map<std::string, V> > {
static bool Check(const Object* ptr) {
if (ptr == nullptr) return true;
if (!ptr->IsInstance<StrMapNode>()) return false;
const StrMapNode* n = static_cast<const StrMapNode*>(ptr);
for (const auto& kv : n->data) {
if (!ObjectTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static std::string TypeName() { return "Map[str, " + ObjectTypeChecker<V>::TypeName() + ']'; }
};

template <typename K, typename V>
struct ObjectTypeChecker<Map<K, V> > {
static bool Check(const Object* ptr) {
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/node/node.h
Expand Up @@ -55,9 +55,9 @@ using runtime::Downcast;
using runtime::GetRef;
using runtime::make_object;
using runtime::Object;
using runtime::ObjectEqual;
using runtime::ObjectHash;
using runtime::ObjectPtr;
using runtime::ObjectPtrEqual;
using runtime::ObjectPtrHash;
using runtime::ObjectRef;
using runtime::PackedFunc;
using runtime::TVMArgs;
Expand Down
3 changes: 1 addition & 2 deletions include/tvm/node/reflection.h
Expand Up @@ -161,8 +161,7 @@ class ReflectionVTable {
* \param kwargs The field arguments.
* \return The created object.
*/
TVM_DLL ObjectRef CreateObject(const std::string& type_key,
const Map<std::string, ObjectRef>& kwargs);
TVM_DLL ObjectRef CreateObject(const std::string& type_key, const Map<String, ObjectRef>& kwargs);
/*!
* \brief Get an field object by the attr name.
* \param self The pointer to the object.
Expand Down
3 changes: 1 addition & 2 deletions include/tvm/relay/dataflow_matcher.h
Expand Up @@ -94,8 +94,7 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr);
*
* \return Return the paritioned Expr.
*/
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<std::string, ObjectRef> attrs,
PackedFunc check);
Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs, PackedFunc check);

} // namespace relay
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/expr_functor.h
Expand Up @@ -214,7 +214,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {

protected:
/*! \brief Internal map used for memoization. */
std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual> memo_;
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo_;
};

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/pattern_functor.h
Expand Up @@ -158,7 +158,7 @@ class PatternMutator : public ::tvm::relay::PatternFunctor<Pattern(const Pattern
virtual Constructor VisitConstructor(const Constructor& c);

private:
std::unordered_map<Var, Var, ObjectHash, ObjectEqual> var_map_;
std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_map_;
};

} // namespace relay
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
Expand Up @@ -286,7 +286,7 @@ TVM_DLL Pass AlterOpLayout();
* this specifies the desired layout for data then kernel for nn.conv2d.
* \return The pass.
*/
TVM_DLL Pass ConvertLayout(const Map<std::string, Array<String>>& desired_layouts);
TVM_DLL Pass ConvertLayout(const Map<String, Array<String>>& desired_layouts);

/*!
* \brief Legalizes an expr with another expression.
Expand Down