-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
attr_registry_map.h
134 lines (126 loc) · 4.25 KB
/
attr_registry_map.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/node/attr_registry_map.h
* \brief Attribute map used in registry.
*/
#ifndef TVM_NODE_ATTR_REGISTRY_MAP_H_
#define TVM_NODE_ATTR_REGISTRY_MAP_H_
#include <tvm/node/container.h>
#include <utility>
#include <vector>
namespace tvm {
/*!
 * \brief Type-erased attribute table addressed by the registry index of its keys.
 *
 * Every slot pairs a stored value with an integer marker. The accessors in this
 * class treat a slot whose marker is zero, or an index past the end of the
 * table, as "no attribute registered for this key".
 *
 * \tparam KeyType the type of the key. It is used through defined(),
 *         operator->, AttrRegistryIndex() and (for error messages) name.
 */
template <typename KeyType>
class AttrRegistryMapContainerMap {
 public:
  /*!
   * \brief Test whether an attribute value is stored for a key.
   * \param key The key to look up; an undefined key is reported as absent.
   * \return 1 if the key has a populated slot, 0 otherwise.
   */
  int count(const KeyType& key) const {
    if (!key.defined()) {
      return 0;
    }
    const uint32_t slot = key->AttrRegistryIndex();
    if (slot >= data_.size()) {
      return 0;
    }
    return data_[slot].second != 0 ? 1 : 0;
  }
  /*!
   * \brief Access the value stored for a key, failing a CHECK when it is missing.
   * \param key The key to look up; it must be defined.
   * \return A const reference to the stored, type-erased value.
   */
  const runtime::TVMRetValue& operator[](const KeyType& key) const {
    CHECK(key.defined());
    const uint32_t idx = key->AttrRegistryIndex();
    CHECK(idx < data_.size() && data_[idx].second != 0)
        << "Attribute " << attr_name_ << " has not been registered for " << key->name;
    return data_[idx].first;
  }
  /*!
   * \brief Look up the value stored for a key, falling back to a default.
   * \param key The key to look up; it must be defined.
   * \param def_value The value handed back when the key has no populated slot.
   * \return The stored value converted to ValueType, or def_value when absent.
   * \tparam ValueType The type the stored value is converted to.
   */
  template <typename ValueType>
  ValueType get(const KeyType& key, ValueType def_value) const {
    CHECK(key.defined());
    const uint32_t slot = key->AttrRegistryIndex();
    const bool missing = slot >= data_.size() || data_[slot].second == 0;
    if (missing) {
      return def_value;
    }
    return data_[slot].first;
  }

 private:
  /*! \brief Name of the attribute, used when reporting a missing entry. */
  String attr_name_;
  /*!
   * \brief Slot table indexed by AttrRegistryIndex().
   *  first is the stored value; a zero second marks the slot as empty.
   */
  std::vector<std::pair<runtime::TVMRetValue, int>> data_;
  /*! \brief Private default constructor; only the friends below can create instances. */
  AttrRegistryMapContainerMap() = default;
  template <typename, typename>
  friend class AttrRegistry;
  friend class OpRegEntry;
};
/*!
 * \brief Typed, read-only view over an AttrRegistryMapContainerMap.
 *
 * The view keeps a reference to the container rather than a copy, and converts
 * the values it reads to ValueType on the way out.
 *
 * \tparam KeyType The type of the key
 * \tparam ValueType The type the stored values are returned as.
 */
template <typename KeyType, typename ValueType>
class AttrRegistryMap {
 public:
  /*!
   * \brief Wrap an existing container.
   * \param map The container to read from; only a reference to it is kept.
   */
  explicit AttrRegistryMap(const AttrRegistryMapContainerMap<KeyType>& map) : map_(map) {}
  /*!
   * \brief Test whether the underlying container holds a value for a key.
   * \param key The key to look up.
   * \return 1 if key is contained in the map, 0 otherwise.
   */
  int count(const KeyType& key) const {
    return map_.count(key);
  }
  /*!
   * \brief Fetch the value stored for a key.
   * \param key The key to look up.
   * \return The stored value, converted to ValueType and returned by value.
   */
  ValueType operator[](const KeyType& key) const {
    const auto& stored = map_[key];
    return stored;
  }
  /*!
   * \brief Fetch the value stored for a key, falling back to a default.
   * \param key The key to look up.
   * \param def_value The value returned when the key has no stored value.
   * \return The stored value as ValueType, or def_value when the key is absent.
   */
  ValueType get(const KeyType& key, ValueType def_value) const {
    return map_.get(key, def_value);
  }

 protected:
  /*! \brief The wrapped container, held by reference. */
  const AttrRegistryMapContainerMap<KeyType>& map_;
};
} // namespace tvm
#endif // TVM_NODE_ATTR_REGISTRY_MAP_H_