-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
tensor.h
263 lines (237 loc) · 8.44 KB
/
tensor.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/te/tensor.h
* \brief Dataflow tensor object
*/
#ifndef TVM_TE_TENSOR_H_
#define TVM_TE_TENSOR_H_
#include <tvm/arith/bound.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
namespace tvm {
namespace te {
using arith::IntSet;
// NOTE(review): a using-directive in a public header makes every name of tvm::tir
// visible inside tvm::te for all includers. Kept as-is: removing it could break
// downstream code that relies on unqualified tir names through this header.
using namespace tvm::tir;
// Internal node container of Tensor (forward-declared because Tensor::operator->
// below returns a pointer to it before its definition further down this header).
class TensorNode;
// Internal node container for Operation (forward-declared; not defined in this header).
class OperationNode;
/*!
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public DataProducer {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(ObjectPtr<Object> n) : DataProducer(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorNode* operator->() const;
/*!
* \brief check if two tensors equals each other.
* \param other tensor to be checked.
* \return whether the two tensors equals each other.
*/
inline bool operator==(const Tensor& other) const;
/*!
* \brief check if two tensors are different.
* \param other tensor to be checked.
* \return whether the two tensors are different.
*/
inline bool operator!=(const Tensor& other) const;
/*! \return The dimension of the tensor */
inline size_t ndim() const;
/*!
* \brief Take elements from the tensor
* \param args The indices
* \return the result expression representing tensor read.
*/
template <typename... Args>
inline PrimExpr operator()(Args&&... args) const {
Array<PrimExpr> indices{std::forward<Args>(args)...};
return operator()(indices);
}
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr operator()(Array<PrimExpr> indices) const;
/*!
* \brief Take elements from the tensor
* \param indices the indices.
* \return the result expression representing tensor read.
*/
TVM_DLL PrimExpr operator()(Array<Var> indices) const;
/*!
* \brief data structure to represent a slice that fixes first k coordinates.
* This is used to enable syntax sugar of Tensor[x][y][z] to get the element.
*/
class Slice {
public:
// construct via tensor and indices
Slice(const Tensor& tensor, std::vector<PrimExpr> indices)
: tensor_(tensor), indices_(indices) {}
/*!
* \brief get i-th slice from the current slice.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
inline Slice operator[](PrimExpr i) {
std::vector<PrimExpr> other = indices_;
other.emplace_back(i);
return Slice(tensor_, other);
}
/*!
* \brief Convert slice to expression.
* This is only valid when all the coordinates are fully specified.
* \return the corresponding expression of this slice.
*/
inline operator PrimExpr() const { return tensor_(indices_); }
private:
const Tensor& tensor_;
std::vector<PrimExpr> indices_;
};
/*!
* \brief get i-th slice from the current Tensor.
* \param i the index of the coordinate
* \return the subsequent slice.
*/
inline Slice operator[](PrimExpr i) const { return Slice(*this, {i}); }
/*! \brief specify container node */
using ContainerType = TensorNode;
};
/*! \brief Operation that produces tensors */
class Operation : public tir::FunctionRef {
 public:
  /*! \brief default constructor */
  Operation() {}
  /*!
   * \brief construct an Operation handle from a node pointer.
   * \param n the node pointer; it is a sink parameter and is moved, not copied,
   *        into the base handle.
   */
  explicit Operation(ObjectPtr<Object> n) : FunctionRef(std::move(n)) {}
  /*!
   * \brief access the internal node container
   * \return the pointer to the internal node container
   */
  inline const OperationNode* operator->() const;
  /*!
   * \brief get the i-th output of the operation.
   * \param i the output index.
   * \return The i-th output.
   */
  TVM_DLL Tensor output(size_t i) const;
  /*! \brief specify container node */
  using ContainerType = OperationNode;
};
/*! \brief Node to represent a tensor */
class TensorNode : public DataProducerNode {
 public:
  /*! \brief Extent of each dimension of the tensor. */
  Array<PrimExpr> shape;
  /*! \brief Element data type stored in the tensor. */
  DataType dtype;
  /*! \brief Operation that produces this tensor; may be undefined (None). */
  Operation op;
  /*! \brief Which output of `op` this tensor corresponds to. */
  int value_index = 0;

  /*! \brief Default constructor; fields keep their default values. */
  TensorNode() {}

  /*!
   * \brief Report every field to an attribute visitor.
   *  Fields are visited in the order shape, dtype, op, value_index.
   * \param v The visitor receiving each named field.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
    v->Visit("op", &op);
    v->Visit("value_index", &value_index);
  }

  /*! \return The stored shape (DataProducerNode interface). */
  Array<PrimExpr> GetShape() const final {
    return shape;
  }
  /*! \return The stored element type (DataProducerNode interface). */
  DataType GetDataType() const final {
    return dtype;
  }
  /*! \return A human-readable name for the tensor (defined out of line). */
  TVM_DLL String GetNameHint() const final;
  /*!
   * \brief Factory for Tensor handles.
   * \param shape Extent of each dimension.
   * \param dtype Element data type.
   * \param op Producing operation.
   * \param value_index Output index within `op`.
   * \return The constructed Tensor.
   */
  TVM_DLL static Tensor make(Array<PrimExpr> shape, DataType dtype, Operation op, int value_index);

  static constexpr const char* _type_key = "Tensor";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorNode, DataProducerNode);
};
// Implementations of inline functions
inline const TensorNode* Tensor::operator->() const {
return static_cast<const TensorNode*>(get());
}
inline size_t Tensor::ndim() const { return (*this)->shape.size(); }
// Two tensors are equal when they share a node, or when at least one of them has a
// producing operation and both come from the same operation at the same output index.
inline bool Tensor::operator==(const Tensor& other) const {
  const auto* lhs = get();
  const auto* rhs = other.get();
  if (lhs == rhs) {
    return true;  // same node (including both undefined)
  }
  if (lhs == nullptr || rhs == nullptr) {
    return false;  // exactly one side is undefined
  }
  const bool has_producer = (*this)->op.defined() || other->op.defined();
  return has_producer && (*this)->op == other->op && (*this)->value_index == other->value_index;
}
// Inequality is defined as the exact negation of operator==.
inline bool Tensor::operator!=(const Tensor& other) const {
  const bool same = operator==(other);
  return !same;
}
// Macros that make every operator applied to a Tensor::Slice (e.g. `A[i][j] + 1`)
// work on expressions: each generated overload first converts the slice with
// Slice::operator PrimExpr() and then applies the operator to the resulting PrimExpr.
// NOTE(review): these macros are not #undef'd in this header, so they stay visible
// to every file that includes it.
//
// DEFINE_OVERLOAD_SLICE_UNARY_OP(Op): defines `Op slice` -> PrimExpr.
#define DEFINE_OVERLOAD_SLICE_UNARY_OP(Op) \
inline PrimExpr operator Op(const Tensor::Slice& a) { return Op a.operator PrimExpr(); }
// DEFINE_OVERLOAD_SLICE_BINARY_OP(Op): defines `slice Op x` and `x Op slice` for any
// operand type T, plus a non-template `slice Op slice`. With two slices both templates
// would match equally well (T = Slice), so the non-template overload is what makes
// that case unambiguous.
#define DEFINE_OVERLOAD_SLICE_BINARY_OP(Op) \
template <typename T> \
inline PrimExpr operator Op(const Tensor::Slice& a, const T& b) { \
return a.operator PrimExpr() Op b; \
} \
template <typename T> \
inline PrimExpr operator Op(const T& a, const Tensor::Slice& b) { \
return a Op b.operator PrimExpr(); \
} \
inline PrimExpr operator Op(const Tensor::Slice& a, const Tensor::Slice& b) { \
return a.operator PrimExpr() Op b.operator PrimExpr(); \
}
// Slice operator overloads, grouped by operator family.
// Unary: logical not, negation.
DEFINE_OVERLOAD_SLICE_UNARY_OP(!);
DEFINE_OVERLOAD_SLICE_UNARY_OP(-);
// Arithmetic.
DEFINE_OVERLOAD_SLICE_BINARY_OP(+);
DEFINE_OVERLOAD_SLICE_BINARY_OP(-);
DEFINE_OVERLOAD_SLICE_BINARY_OP(*);
// Comparison.
DEFINE_OVERLOAD_SLICE_BINARY_OP(==);
DEFINE_OVERLOAD_SLICE_BINARY_OP(!=);
DEFINE_OVERLOAD_SLICE_BINARY_OP(<); // NOLINT(*)
DEFINE_OVERLOAD_SLICE_BINARY_OP(<=);
DEFINE_OVERLOAD_SLICE_BINARY_OP(>); // NOLINT(*)
DEFINE_OVERLOAD_SLICE_BINARY_OP(>=);
// Logical.
DEFINE_OVERLOAD_SLICE_BINARY_OP(&&);
DEFINE_OVERLOAD_SLICE_BINARY_OP(||);
// Bit shifts.
DEFINE_OVERLOAD_SLICE_BINARY_OP(<<);
DEFINE_OVERLOAD_SLICE_BINARY_OP(>>);
} // namespace te
} // namespace tvm
namespace std {
/*! \brief Hash an Operation handle with ::tvm::ObjectPtrHash. */
template <>
struct hash<::tvm::te::Operation> : public ::tvm::ObjectPtrHash {};

/*!
 * \brief Hash a Tensor consistently with Tensor::operator==.
 *  A tensor that has a producing operation is hashed by that operation, so distinct
 *  tensor nodes that compare equal (same op and value_index) share a hash. Any other
 *  tensor, including an undefined one, is hashed by its own handle.
 */
template <>
struct hash<::tvm::te::Tensor> {
  std::size_t operator()(const ::tvm::te::Tensor& k) const {
    ::tvm::ObjectPtrHash hasher;
    const bool keyed_by_op = k.defined() && k->op.defined();
    return keyed_by_op ? hasher(k->op) : hasher(k);
  }
};
}  // namespace std
#endif // TVM_TE_TENSOR_H_