-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
int_set.h
229 lines (216 loc) · 7.49 KB
/
int_set.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arith/int_set.h
* \brief Integer set
*/
#ifndef TVM_ARITH_INT_SET_H_
#define TVM_ARITH_INT_SET_H_
#include <tvm/ir/expr.h>
#include <tvm/tir/expr.h>
#include <unordered_map>
namespace tvm {
namespace arith {
using tir::IterVar;
using tir::Var;
using tir::VarNode;
//-----------------------------------------------
// Integer set data structure.
//
// This is an API built on top of the base
// integer analysis API to provide set analysis.
//------------------------------------------------
/*!
 * \brief Sign classification of an integer expression.
 *
 * This is the result type of IntSet::sign_type(). The enumerators carry no
 * explicit initializers, so they are numbered 0 through 3 in declaration order.
 */
enum SignType {
  kPositive,
  kNegative,
  kZero,
  kUnknown
};
/*!
* \brief Base class of all integer set containers.
*
* Represents a set of integers in one dimension. Only type metadata is
* declared in this class body; concrete set representations are presumably
* provided by subclasses defined elsewhere (confirm against the implementation).
* \sa IntSet
*/
class IntSetNode : public Object {
public:
/*! \brief Type key string associated with this node type. */
static constexpr const char* _type_key = "IntSet";
/*!
* \brief Type trait flag, set to false for this node type.
* \note NOTE(review): presumably tells the object system that this type does
* not provide a structural-equality reduce method; confirm against the
* Object protocol documentation.
*/
static constexpr bool _type_has_method_sequal_reduce = false;
// Project macro (defined outside this header) invoked with this class and its
// parent class Object.
TVM_DECLARE_BASE_OBJECT_INFO(IntSetNode, Object);
};
/*!
* \brief Managed reference to IntSetNode.
*
* Apart from the two constructors and operator->, every member below is only
* declared in this class; the definitions live elsewhere.
* \sa IntSetNode
*/
class IntSet : public ObjectRef {
public:
/*! \brief Default constructor; the ObjectRef base is default-constructed. */
IntSet() {}
/*!
* \brief Construct from an object pointer, which is passed on to ObjectRef.
* \param n The object pointer to wrap.
*/
explicit IntSet(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Access the internal node container.
* \note The pointer returned by get() is converted with static_cast, i.e.
* without a runtime type check.
* \return The pointer to the internal node container.
*/
const IntSetNode* operator->() const { return static_cast<const IntSetNode*>(get()); }
/*!
* \brief Find a range that covers the region.
* \param max_range The range to be covered.
* \note NOTE(review): the exact role of max_range is not visible from this
* declaration; confirm against the implementation.
* \return The covering range.
*/
Range cover_range(Range max_range) const;
/*! \return Lower bound of the set. */
PrimExpr min() const;
/*! \return Upper bound of the set. */
PrimExpr max() const;
/*! \return Whether the set represents nothing. */
bool is_nothing() const;
/*! \return Whether the set represents everything. */
bool is_everything() const;
/*! \return Whether the set is a single point. */
bool is_single_point() const;
/*! \return Whether the set is proved to be bigger than 0. */
bool can_prove_positive() const;
/*! \return Whether the set is proved to be smaller than 0. */
bool can_prove_negative() const;
/*! \return Whether the set is proved to be smaller than or equal to 0. */
bool can_prove_non_positive() const;
/*! \return Whether the set is proved to be larger than or equal to 0. */
bool can_prove_non_negative() const;
/*! \return The sign of the elements in the integer set. */
SignType sign_type() const;
/*!
* \brief The single point value; call only if is_single_point() is true.
* \return The point value.
*/
PrimExpr point_value() const;
/*!
* \brief Try to match IntSet with range r.
*
* \param r The range to match against.
* \note It is guaranteed that IntSet::range(r).match_range(r) == true
* \return true if we can prove they are the same.
*/
bool match_range(const Range& r) const;
/*! \return The set that contains nothing. */
static IntSet nothing();
/*! \return The set that contains everything. */
static IntSet everything();
/*!
* \brief Construct a point set.
* \param point The point in the set.
* \return The constructed single-point set.
*/
static IntSet single_point(PrimExpr point);
/*!
* \brief Construct an integer set from a vector expression.
* \param vec The vector expression, can also be single point.
* \return The result set containing the indices in the vector.
*/
static IntSet vector(PrimExpr vec);
/*!
* \brief Construct a set representing a range.
* \param r The range.
* \return The constructed set.
*/
static IntSet range(Range r);
/*!
* \brief Construct a set representing an interval.
* \param min The minimum value of the interval.
* \param max The maximum value of the interval.
* \return The constructed set.
*/
static IntSet interval(PrimExpr min, PrimExpr max);
};
//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>.
*
* \param dom_map The domain map to convert, keyed by VarNode pointers.
* \return The converted map, keyed by Var.
*/
Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Find a symbolic integer set that contains all possible values of
* e given the domain of each iteration variable.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable, keyed by IterVar.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet(PrimExpr, const Map<IterVar, IntSet>&), but takes a
* std::unordered_map keyed by VarNode pointers.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e, const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);
/*!
* \brief Find a symbolic integer set that contains the union over
* all the possible conditional values in dom_map.
*
* \param r The initial range.
* \param dom_map The domain of each variable, keyed by IterVar.
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(Range r, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Find a symbolic integer set that contains the union over
* all the possible conditional values in dom_map.
*
* \param s The initial set.
* \param dom_map The domain of each variable, keyed by VarNode pointers.
* \return An integer set that can cover all the possible values.
*/
IntSet EvalSet(IntSet s, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Same as EvalSet(Range, const Map<IterVar, IntSet>&), but takes a
* std::unordered_map keyed by VarNode pointers.
*
* \param r The range to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of r.
*/
IntSet EvalSet(Range r, const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Map from PrimExpr to IntSet.
*
* The map is instantiated with ObjectPtrHash as the hasher and ObjectPtrEqual
* as the key-equality predicate.
* \note NOTE(review): this presumably means keys are compared by object
* identity rather than by structure; confirm against the definitions of those
* functors.
*/
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectPtrHash, ObjectPtrEqual>;
/*!
* \brief Find the integer set of every sub-expression, given the
* domain of each iteration variable.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable, keyed by VarNode pointers.
* \return The map from each expression to the set of its possible values.
*/
ExprIntSetMap EvalSetForEachSubExpr(PrimExpr e,
const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Create a union set of all sets.
* \param sets The sets to be unioned.
* \return The set after union.
*/
IntSet Union(const Array<IntSet>& sets);
/*!
* \brief Create an intersection set of all sets.
* \param sets The sets to be intersected.
* \return The set after intersection.
*/
IntSet Intersect(const Array<IntSet>& sets);
} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_INT_SET_H_