/
build_module.py
317 lines (266 loc) · 11.6 KB
/
build_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""The build utils in python."""
import warnings
from typing import Union, Optional, List, Mapping
import tvm.tir
from tvm.runtime import Module
from tvm.runtime import ndarray
from tvm.ir import container
from tvm.tir import PrimFunc
from tvm.ir.module import IRModule
from tvm.te import tensor
from tvm.te import schedule
from tvm.target import Target
from tvm.tir.buffer import Buffer
from tvm.tir.expr import Var
from tvm.driver import _ffi_api as _driver_ffi
from . import _ffi_api as ffi
def get_binds(args, compact=False, binds=None):
    """Resolve the bind specification and argument list for the given arguments.

    Internal helper that forwards its arguments unchanged to ``ffi.get_binds``.

    Parameters
    ----------
    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    compact : bool
        If the statement has already bound to a compact buffer.

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    Returns
    -------
    binds: dict
        The bind specification

    arg_list: list
        The list of symbolic buffers of arguments.
    """
    # The FFI call yields a pair; unpack it into names that do not shadow
    # the ``binds`` parameter, then hand the pair back as a tuple.
    bind_spec, symbolic_args = ffi.get_binds(args, compact, binds)
    return bind_spec, symbolic_args
def schedule_to_module(
    sch: schedule.Schedule,
    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
    name: str = "main",
    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
) -> IRModule:
    """Form a function from the given schedule.

    This is a low-level function intended for testing purposes, and
    does not apply any optimization passes. In general, `tvm.lower`
    and `tvm.build` should be used instead.

    Parameters
    ----------
    sch : tvm.te.schedule.Schedule
        The given scheduler to form the raw body

    args : list of Buffer or Tensor or Var
        The argument lists to the function.

    name : str
        The name of result function, default name is "main"

    binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        The binds information

    Returns
    -------
    The body formed according to the given schedule
    """
    # Pure pass-through to the FFI implementation.
    formed_module = ffi.schedule_to_module(sch, args, name, binds)
    return formed_module
def lower(
    inp: Union[schedule.Schedule, PrimFunc, IRModule],
    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
    name: str = "main",
    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
    simple_mode: bool = False,
) -> IRModule:
    """Lowering step before build into target.

    Dispatches on the type of ``inp`` to the matching FFI lowering routine.

    Parameters
    ----------
    inp : Union[tvm.te.schedule.Schedule, tvm.tir.PrimFunc, IRModule]
        The TE schedule or TensorIR PrimFunc/IRModule to be built

    args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
        The argument lists to the function for TE schedule.
        It should be None if we want to lower TensorIR.

    name : str
        The name of the result function.

    binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]]
        Dictionary that maps the Tensor to Buffer which specified the data layout
        requirement of the function. By default, a new compact buffer is created
        for each tensor in the argument.

    simple_mode : bool
        Whether only output simple and compact statement, this will skip
        LoopPartition, api wrapper generation and Unrolling.

    Returns
    -------
    m : IRModule
        The result IRModule

    Raises
    ------
    ValueError
        If ``inp`` is not an IRModule, PrimFunc or Schedule.
    """
    if isinstance(inp, IRModule):
        # ``args``, ``name`` and ``binds`` are not forwarded for an IRModule.
        return ffi.lower_module(inp, simple_mode)
    if isinstance(inp, PrimFunc):
        # ``args`` and ``binds`` are not forwarded for a PrimFunc.
        return ffi.lower_primfunc(inp, name, simple_mode)
    if isinstance(inp, schedule.Schedule):
        return ffi.lower_schedule(inp, args, name, binds, simple_mode)
    # Build a single message string: passing the type as a second positional
    # argument would make the exception render as a tuple.
    raise ValueError(
        f"Expected input to be an IRModule, PrimFunc or Schedule, but got {type(inp)}"
    )
def build(
    inputs: Union[schedule.Schedule, PrimFunc, IRModule, Mapping[str, IRModule]],
    args: Optional[List[Union[Buffer, tensor.Tensor, Var]]] = None,
    target: Optional[Union[str, Target]] = None,
    target_host: Optional[Union[str, Target]] = None,
    runtime: Optional[
        "tvm.relay.backend.Runtime"
    ] = None,  # Type is annotated this way to avoid cyclic dependency
    name: Optional[str] = "default_function",
    binds: Optional[Mapping[tensor.Tensor, Buffer]] = None,
):
    """Build a function with arguments as signature. Code will be generated
    for devices coupled with target information.

    Parameters
    ----------
    inputs : Union[tvm.te.schedule.Schedule,
        tvm.tir.PrimFunc, IRModule, Mapping[str, IRModule]]
        The input to be built

    args : Optional[List[Union[tvm.tir.Buffer, tensor.Tensor, Var]]]
        The argument lists to the function.

    target : Optional[Union[str, Target]]
        The target and option of the compilation.

    target_host : Optional[Union[str, Target]]
        Host compilation target, if target is device.
        When TVM compiles device specific program such as CUDA,
        we also need host(CPU) side code to interact with the driver
        setup the dimensions and parameters correctly.
        target_host is used to specify the host side codegen target.
        By default, llvm is used if it is enabled,
        otherwise a stackvm interpreter is used.

    runtime : Optional[Runtime]
        Runtime to generate artifacts for

    name : Optional[str]
        The name of result function.

    binds : Optional[Mapping[tensor.Tensor, tvm.tir.Buffer]]
        Dictionary that maps the binding of symbolic buffer to Tensor.
        By default, a new buffer is created for each tensor in the argument.

    Returns
    -------
    ret : tvm.module
        A module that combines both host and device code.

    Raises
    ------
    ValueError
        If ``inputs`` is a Schedule and ``args`` is None, if ``inputs`` (or a
        key/value of a dict ``inputs``) has an unsupported type, or if a "crt"
        runtime with "system-lib" enabled is combined with a host target whose
        kind is neither "c" nor "llvm".

    Examples
    --------
    There are two typical example uses of this function depending on the type
    of the argument `inputs`:
    1. it is an IRModule.

    .. code-block:: python

        n = 2
        A = te.placeholder((n,), name='A')
        B = te.placeholder((n,), name='B')
        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s = tvm.te.create_schedule(C.op)
        m = tvm.lower(s, [A, B, C], name="test_add")
        rt_mod = tvm.build(m, target="llvm")

    2. it is a dict of compilation target to IRModule.

    .. code-block:: python

        n = 2
        A = te.placeholder((n,), name='A')
        B = te.placeholder((n,), name='B')
        C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
        s1 = tvm.te.create_schedule(C.op)
        with tvm.target.cuda() as cuda_tgt:
          s2 = topi.cuda.schedule_injective(cuda_tgt, [C])
          m1 = tvm.lower(s1, [A, B, C], name="test_add1")
          m2 = tvm.lower(s2, [A, B, C], name="test_add2")
          rt_mod = tvm.build({"llvm": m1, "cuda": m2})

    Note
    ----
    See the note on :any:`tvm.target` on target string format.
    """
    # Step 1: normalize every non-dict input into a single lowered IRModule.
    if isinstance(inputs, schedule.Schedule):
        if args is None:
            raise ValueError("args must be given for build from schedule")
        input_mod = lower(inputs, args, name=name, binds=binds)
    elif isinstance(inputs, (list, tuple, container.Array)):
        # Lower each element on its own and merge the results into one module.
        merged_mod = tvm.IRModule({})
        for x in inputs:
            merged_mod.update(lower(x))
        input_mod = merged_mod
    elif isinstance(inputs, PrimFunc):
        input_mod = lower(inputs, name=name)
    elif isinstance(inputs, tvm.IRModule):
        input_mod = lower(inputs)
    elif not isinstance(inputs, (dict, container.Map)):
        raise ValueError(
            "Inputs must be Schedule, IRModule or dict of target to IRModule, "
            f"but got {type(inputs)}."
        )

    if target_host is not None:
        warnings.warn(
            "target_host parameter is going to be deprecated. "
            "Please pass in tvm.target.Target(target, host=target_host) instead."
        )

    # Step 2: build the target -> module mapping. Dict inputs already are one.
    if not isinstance(inputs, (dict, container.Map)):
        target = Target.current() if target is None else target
        target = target if target else "llvm"
        target_input_mod = {target: input_mod}
    else:
        target_input_mod = inputs

    # Because modules can be created from a variety of sources, we annotate them
    # with the relevant attributes here to ensure they propagate
    annotated_mods = {}
    for tar, mod in target_input_mod.items():
        if not isinstance(tar, (str, Target)):
            raise ValueError("The key of inputs must be str or Target when inputs is dict.")
        if not isinstance(mod, tvm.IRModule):
            raise ValueError("inputs must be Schedule, IRModule, or dict of str to IRModule.")
        annotated_mods[tar] = mod.with_attr("runtime", runtime)

    annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)

    # Step 3: pick a host target if none was given. Prefer the first target
    # whose device type is CPU, otherwise fall back to llvm or stackvm.
    if not target_host:
        for tar, _ in annotated_mods.items():
            tar = Target(tar)
            device_type = ndarray.device(tar.kind.name, 0).device_type
            if device_type == ndarray.cpu(0).device_type:
                target_host = tar
                break
    if not target_host:
        target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

    annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)

    rt_mod_host = _driver_ffi.preprocess_module(annotated_mods, target_host)

    annotated_mods, target_host = Target.check_and_update_host_consist(annotated_mods, target_host)

    if not isinstance(target_host, Target):
        target_host = Target(target_host)

    # Step 4: for a CRT runtime with system-lib enabled, wrap the host module
    # in the metadata module matching the host target kind.
    if str(runtime) == "crt" and runtime["system-lib"]:
        host_kind = target_host.kind.name
        if host_kind == "c":
            create_csource_crt_metadata_module = tvm._ffi.get_global_func(
                "runtime.CreateCSourceCrtMetadataModule"
            )
            to_return = create_csource_crt_metadata_module([rt_mod_host], target_host, runtime)
        elif host_kind == "llvm":
            create_llvm_crt_metadata_module = tvm._ffi.get_global_func(
                "runtime.CreateLLVMCrtMetadataModule"
            )
            to_return = create_llvm_crt_metadata_module([rt_mod_host], target_host, runtime)
        else:
            # Without this branch ``to_return`` would be unbound and the
            # function would fail with an opaque UnboundLocalError below.
            raise ValueError(
                "A 'crt' runtime with 'system-lib' enabled requires a host target of "
                f"kind 'c' or 'llvm', but got '{host_kind}'."
            )
    else:
        to_return = rt_mod_host

    return OperatorModule.from_module(to_return, ir_module_by_target=annotated_mods, name=name)
class OperatorModule(Module):
    """Module returned by tvm.build(), extended to carry that function's additional outputs."""

    @classmethod
    def from_module(cls, mod, **kwargs):
        """Construct an instance of ``cls`` that takes over the handle of ``mod``.

        Parameters
        ----------
        mod : Module
            The module whose handle is transferred. Its ``handle`` attribute is
            set to None by this call.

        **kwargs
            Extra keyword arguments forwarded to the ``cls`` constructor.

        Returns
        -------
        An instance of ``cls`` built from the handle previously held by ``mod``.
        """
        # NOTE(areusch): It is generally unsafe to continue using `mod` from this point forward.
        # If an exception occurs in cls.__init__, handle will be deleted. For this reason,
        # mod.handle is cleared in the same step that captures the handle.
        handle, mod.handle = mod.handle, None
        return cls(handle, **kwargs)

    def __init__(self, handle, ir_module_by_target=None, name=None):
        """Initialize the base Module from ``handle`` and record the extra build outputs."""
        super().__init__(handle)
        self.ir_module_by_target = ir_module_by_target
        self.name = name