-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
function.h
200 lines (184 loc) · 7.89 KB
/
function.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/function.h
* \brief Relay Function.
*/
#ifndef TVM_RELAY_FUNCTION_H_
#define TVM_RELAY_FUNCTION_H_
#include <tvm/ir/function.h>
#include <tvm/relay/expr.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Relay Function container
* \sa Function
*/
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template paramaters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", ¶ms);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}
bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
// Important to make def equal first.
equal->MarkGraphNode();
return equal.DefEqual(params, other->params) &&
equal.DefEqual(type_params, other->type_params) && equal(ret_type, other->ret_type) &&
equal(attrs, other->attrs) && equal(body, other->body);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce.DefHash(params);
hash_reduce.DefHash(type_params);
hash_reduce(ret_type);
hash_reduce(attrs);
hash_reduce(body);
}
/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;
static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};
/*!
 * \brief Managed reference to FunctionNode.
 * \sa FunctionNode
 */
class Function : public BaseFunc {
public:
/*!
 * \brief Constructor
 * \param params The parameters of the function.
 * \param body The body of the function.
 * \param ret_type The return type of the function.
 * \param ty_params The type parameters.
 * \param attrs Additional function attributes. Defaults to NullValue<DictAttrs>().
 * \param span The span of the function. Defaults to a default-constructed Span.
 */
TVM_DLL Function(tvm::Array<Var> params, Expr body, Type ret_type, tvm::Array<TypeVar> ty_params,
tvm::DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
// NOTE(review): macro defined outside this file; presumably generates the usual
// reference-type members for Function backed by FunctionNode -- confirm at its definition.
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
// NOTE(review): macro defined outside this file; presumably adds copy-on-write
// access to the underlying FunctionNode -- confirm at its definition.
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionNode);
};
/*!
 * \brief Returns the function with given properties. A null property denotes 'no change'.
 * Returns function if all properties are unchanged. Otherwise, returns a copy with the new fields.
 *
 * Every optional argument defaults to an empty Optional, i.e. 'no change'.
 *
 * \param function The function to copy.
 * \param opt_params The (optional) params for the copied function. If none,
 * ret_function->params = function->params.
 * \param opt_body The (optional) body for the copied function. If none,
 * ret_function->body = function->body.
 * \param opt_ret_type The (optional) return type for the copied function. If none,
 * ret_function->ret_type = function->ret_type.
 * \param opt_ty_params The (optional) type params for the copied function. If none,
 * ret_function->type_params = function->type_params.
 * \param opt_attrs The (optional) attributes for the copied function. If none,
 * ret_function->attrs = function->attrs.
 * \param opt_virtual_device The (optional) virtual_device for the copied function. If none,
 * ret_function->virtual_device = function->virtual_device.
 * \param opt_span The (optional) span for the copied function. If none,
 * ret_function->span = function->span.
 * \return If all properties are null or the same as the property in the input function
 * (i.e., opt_params is null or opt_params.value() == function->params, etc.), then we return
 * function. Otherwise, we return a copy of function with the different fields overwritten. (i.e.,
 * if opt_params.value() != function->params, then ret_function->params = opt_params.value()).
 */
Function WithFields(Function function, Optional<Array<Var>> opt_params = Optional<Array<Var>>(),
Optional<Expr> opt_body = Optional<Expr>(),
Optional<Type> opt_ret_type = Optional<Type>(),
Optional<Array<TypeVar>> opt_ty_params = Optional<Array<TypeVar>>(),
Optional<DictAttrs> opt_attrs = Optional<DictAttrs>(),
Optional<VirtualDevice> opt_virtual_device = Optional<VirtualDevice>(),
Optional<Span> opt_span = Optional<Span>());
/*!
 * \brief Returns the Relay FunctionNode represented by base_func if it should be optimized,
 * otherwise returns nullptr.
 *
 * This means returns nullptr:
 *  - For PrimFuncs, since not Relay Functions.
 *  - For Functions marked for external compilation (with "Compiler").
 *  - For Functions marked as already having an external definition (with "ExternalSymbol").
 *  - For Functions marked as not to be optimized (with "SkipOptimization").
 *
 * \param base_func The function to inspect.
 * \return The FunctionNode behind base_func if it should be optimized, nullptr otherwise.
 *
 * TODO(mbs): Audit all enumerations of IRModule::functions to use this or some family of such.
 */
const FunctionNode* AsOptimizableFunctionNode(const BaseFunc& base_func);
/*!
 * \brief Names of the attributes that can be attached to a relay::Function.
 *
 * The keys are grouped by theme: compilation control first, then markers
 * describing what kind of function this is, then attached data.
 */
namespace attr {

// ---- Compilation control ----

/*!
 * \brief Names the compiler that should build this function.
 *  If the attribute is absent or equal to "default", the default compilation
 *  pipeline is used.
 */
constexpr const char* kCompiler = "Compiler";
/*! \brief Holds the unique external symbol used by external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Marks a function that should not be optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Marks a function that is to be inlined. */
constexpr const char* kInline = "Inline";

// ---- Function kind markers ----

/*! \brief Marks the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*! \brief Tells whether the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Marks a function to be treated as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Records that the function was created by the Pattern Partitioning Pass. */
constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
/*! \brief Marks a function that is composed only of reshape operations. */
constexpr const char* kReshapeOnly = "relay.reshape_only";

// ---- Attached data ----

/*! \brief Holds a Var to parameter/Constant mapping stored on a Function. */
constexpr const char* kParams = "__params__";

}  // namespace attr
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_FUNCTION_H_