From a164d89ca609d6cfb8697ddf30d2480b291eb042 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 14 Dec 2023 20:38:53 -0800 Subject: [PATCH] [Unity][nn.Module] Refactor `ExternModule` `nn.ExternModule` allows incorporation of handcrafted kernels into the compilation stack and being invoked by Relax just like TIR or any other ordinary operator. This PR simplifies its workflow. The system consists of the abstract base class `ExternModule` and its two derivatives: - `.o` (object files) can be linked using `ObjectModule`. - `.cpp` (C++ files) and `.cu` (CUDA files) can be compiled and linked into the system usung `SourceModule`. **Symbols, and shape/dtype inference.** To provide the system with sufficient information about the kernels, it is required to provide all symbols of an external module, as well as a method for each symbol that tells the system about the output dtype/shape of this symbol. Consider a case where function `my_func` accepts two tensors, `a` of shape `(x, y, 1)`, `b` of shape `(y, z, 5)`, and then produces a tensor `c` of shape `(x, y, z, 9)`, the shape/dtype inference function should look like: ```python def shape_dtype_inference(a, b): x, y, _ = a.shape _, z, _ = b.shape return nn.Tensor.placeholder((x, y, z, 9), dtype="float32") ``` Regarding the interface, the symbols and their corresponding shape/dtype inference function should be provided as a Python dictionary that maps each symbol to the function as below: ```python symbols={ "my_func": shape_dtype_inference, } ``` **Calling convention.** All external modules now follows "destination-passing-style" (DPS) calling convention, which means the returned tensors are pre-allocated by the system already and passed in as an argument of the external function. Reuse the example above, the implementation of `my_func` should include three parameters in its signature, where tensors are represented using DLTensor from DLPack, the de facto standard of in-memory representation of tensors. More info on DLPack: https://github.com/dmlc/dlpack/blob/v0.8/include/dlpack/dlpack.h#L163-L206. To expose the symbol, `TVM_DLL_EXPORT_TYPED_FUNC(symbol, function)` is guaranteed available: ```C++ // those headers are guaranteed to be available \#include \#include \#include namespace { // anonymous namespace hides the symbol `_my_func_impl` from other TUs int _my_func_impl(DLTensor* a, DLTensor* b, DLTensor* c) { // `a` and `b` are inputs, and `c` is the output } } // expose symbol `my_func` instead of `_my_func_impl` TVM_DLL_EXPORT_TYPED_FUNC(my_func, _my_func_impl); ``` **A compiler pass `AttachExternModules`.** It is introduced to attach a list of `nn.ExternModule`s into an IRModule at any stage of the compilation pipeline, and attach the compiled external modules as `runtime.Module`s into IRModule's `external_mods` attribute. It is required by linking in `relax.build`, but with the existence of this pass, source compilation can be deferred to arbitrary stage of TVM compilation. **Caveats.** It is required to call `nn.add_extern` to register external modules exactly once during `export_tvm`. Each symbol should be registered exactly once to avoid potential conflicts, and otherwise an error will be raised. This programming model might be a bit of constraint, and we will consider loose it slightly in the future. Also, for backward compatibility, `ExternModule`s are exported from `export_tvm` only when `allow_extern` flag is turned on. Otherwise, any external module will cause an exception asking to turn on the flag. --- 3rdparty/flashinfer | 2 +- python/tvm/contrib/cc.py | 2 +- python/tvm/relax/frontend/nn/__init__.py | 13 +- python/tvm/relax/frontend/nn/core.py | 381 ++++++----------- python/tvm/relax/frontend/nn/exporter.py | 314 ++++++++++++++ python/tvm/relax/frontend/nn/extern.py | 392 +++++++++++++++++ python/tvm/relax/frontend/nn/modules.py | 66 --- python/tvm/relax/frontend/nn/op.py | 119 ++---- python/tvm/relax/frontend/nn/spec.py | 393 +----------------- python/tvm/relax/transform/__init__.py | 1 + .../transform/attach_external_modules.py | 52 +++ python/tvm/relax/vm_build.py | 5 +- python/tvm/runtime/ndarray.py | 31 +- python/tvm/target/detect_target.py | 23 +- .../python/relax/frontend_nn_extern_module.cc | 69 +++ .../relax/test_frontend_nn_extern_module.py | 363 ++++++++-------- .../python/relax/test_frontend_nn_modules.py | 55 +-- 17 files changed, 1234 insertions(+), 1047 deletions(-) create mode 100644 python/tvm/relax/frontend/nn/exporter.py create mode 100644 python/tvm/relax/frontend/nn/extern.py create mode 100644 python/tvm/relax/transform/attach_external_modules.py create mode 100644 tests/python/relax/frontend_nn_extern_module.cc diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer index 11364ca4c3ce..e668648bf15b 160000 --- a/3rdparty/flashinfer +++ b/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 11364ca4c3ce651dd544efff3225906fe15c5b8a +Subproject commit e668648bf15b77360f1ca8478a54aa722622981c diff --git a/python/tvm/contrib/cc.py b/python/tvm/contrib/cc.py index 8ad70dc254ee..e678785cbfd5 100644 --- a/python/tvm/contrib/cc.py +++ b/python/tvm/contrib/cc.py @@ -82,7 +82,7 @@ def create_shared(output, objects, options=None, cc=None, cwd=None, ccache_env=N The compiler command. cwd : Optional[str] - The urrent working directory. + The current working directory. ccache_env : Optional[Dict[str, str]] The environment variable for ccache. Set `None` to disable ccache by default. diff --git a/python/tvm/relax/frontend/nn/__init__.py b/python/tvm/relax/frontend/nn/__init__.py index a2c6f39442db..5723e3d9ffc7 100644 --- a/python/tvm/relax/frontend/nn/__init__.py +++ b/python/tvm/relax/frontend/nn/__init__.py @@ -17,15 +17,9 @@ """A PyTorch-like API to build IRModules.""" # pylint: disable=redefined-builtin from . import op, spec -from .core import ( - Effect, - ExternModule, - Module, - ModuleList, - Parameter, - SourceModule, - Tensor, -) +from .core import Effect, Module, ModuleList, Parameter, Tensor +from .exporter import add_extern +from .extern import ExternModule, ObjectModule, SourceModule from .modules import ( GELU, Conv1D, @@ -35,7 +29,6 @@ KVCache, LayerNorm, Linear, - MultiLinear, ReLU, RMSNorm, SiLU, diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py index c7d745c72154..8ed0efe2cd04 100644 --- a/python/tvm/relax/frontend/nn/core.py +++ b/python/tvm/relax/frontend/nn/core.py @@ -23,10 +23,6 @@ - Effect, a non-user-facing class that encloses potential side effects, for example, IO, impure external function callings, inplace mutation, etc. """ -import os -import shutil -import sys -import tempfile from collections import OrderedDict from typing import ( TYPE_CHECKING, @@ -41,26 +37,28 @@ Union, ) -import numpy as np +import numpy as np # type: ignore from tvm import tir -from tvm._ffi.libinfo import find_include_path -from tvm.contrib import cc as _cc from tvm.ir import IRModule -from tvm.runtime import Device, NDArray, load_static_library, ndarray +from tvm.ir.transform import Pass +from tvm.runtime import Device, NDArray +from tvm.runtime import device as as_device +from tvm.runtime import ndarray from tvm.runtime.relax_vm import VirtualMachine from tvm.target import Target from ... import expr as rx from ...block_builder import BlockBuilder -from ...struct_info import ShapeStructInfo, TensorStructInfo +from ...struct_info import ShapeStructInfo, TensorStructInfo, TupleStructInfo from ._tensor_op import _TensorOp from .subroutine import SubroutineMixin if TYPE_CHECKING: - import torch + import torch # type: ignore from . import spec as _spec + from .extern import ExternModule _DEFAULT_DTYPE = "float32" @@ -123,6 +121,37 @@ def from_scalar(data: Union[int, float], dtype: str) -> "Tensor": """Construct a tensor from a scalar with dtype specified.""" return Tensor(_expr=rx.const(data, dtype=dtype)) + @staticmethod + def placeholder( + shape: Sequence[Union[int, tir.PrimExpr]], + dtype: str, + name: str = "tensor", + ) -> "Tensor": + """Create a placeholder tensor with given shape and dtype. A placeholder tensor should + never be created directly by users in usual cases, and the only exception is to indicate + the shape/dtype of return values of an external function. + """ + new_shape = [] + for expr in shape: + if isinstance(expr, (int, tir.IntImm)): + expr = int(expr) + assert expr >= 0 + new_shape.append(expr) + continue + if not isinstance(expr, tir.PrimExpr): + raise TypeError(f"Invalid shape: {shape}") + assert expr.dtype == "int64" + new_shape.append(expr) + return Tensor( + _expr=rx.Var( + name_hint=name, + struct_info=TensorStructInfo( + shape=new_shape, # type: ignore[arg-type] + dtype=dtype, + ), + ) + ) + @property def shape(self) -> List[Union[int, tir.PrimExpr]]: """Returns the shape of the tensor as a list of integers. @@ -195,7 +224,7 @@ def __init__( """ if dtype is None: dtype = get_default_dtype() - super().__init__(_expr=_tensor_placeholder("param", shape, dtype=dtype)._expr) + super().__init__(_expr=Tensor.placeholder(shape, dtype=dtype, name="param")._expr) self._data = None self.attrs = OrderedDict() @@ -240,8 +269,8 @@ def to(self, dtype: Optional[str] = None) -> None: # pylint: disable=invalid-na "data is not recommended. It might lead to potential precision loss " "or other unexpected behaviors" ) - self._expr = _tensor_placeholder( # pylint: disable=protected-access - "param", self.shape, dtype=dtype + self._expr = Tensor.placeholder( # pylint: disable=protected-access + self.shape, dtype=dtype, name="param" )._expr @@ -381,7 +410,18 @@ def export_tvm( self, spec: "_spec.ModuleSpecType", debug: bool = False, - ) -> Tuple[IRModule, List[Tuple[str, Parameter]]]: + allow_extern: bool = False, + ) -> Union[ + Tuple[ + IRModule, + List[Tuple[str, Parameter]], + ], + Tuple[ + IRModule, + List[Tuple[str, Parameter]], + List["ExternModule"], + ], + ]: """Export the module to TVM IRModule and parameters Parameters @@ -400,233 +440,68 @@ def export_tvm( params : Dict[str, tvm.nd.array] A dictionary of parameters corresponding to the weights of the model. + ext_mods : List[nn.ExternModule] """ - from . import spec as _spec # pylint: disable=import-outside-toplevel + # pylint: disable=import-outside-toplevel + from . import spec as _spec + from .exporter import Exporter + # pylint: enable=import-outside-toplevel spec = _spec.ModuleSpec.from_raw(spec, self) - mod, params = _spec.SpecBuilder().build(spec, debug=debug) + mod, params, ext_mods = Exporter(debug=debug).build(spec) + if allow_extern: + return mod, params, ext_mods + if ext_mods: + raise ValueError( + "`ExternModule`(s) exist when they are not allowed. " + "Turn on flag `allow_extern` to allow." + ) return mod, params def jit( # pylint: disable=too-many-arguments self, spec: "_spec.ModuleSpec", - target: Union[str, Target] = "llvm", device: Union[str, Device] = "cpu", - pipeline: str = "zero", + pipeline: Union[None, str, Pass] = "default_build", out_format: str = "torch", debug: bool = False, ) -> Any: """Just-in-time compilation of a nn.model to an executable""" - from tvm import relax # pylint: disable=import-outside-toplevel - - from . import spec as _spec # pylint: disable=import-outside-toplevel - # Convert nn.Module to IRModule - spec = _spec.ModuleSpec.from_raw(spec, self) - mod, params = _spec.SpecBuilder().build(spec, debug=debug) - - # Convert parameters - device = _str_to_device(device) - params_ndarray = _param_to_ndarray(params, device) + def _compile(spec, device, pipeline, debug): + # pylint: disable=import-outside-toplevel + from ...transform import AttachExternModules + from ...vm_build import build as relax_build + from . import spec as _spec + from .exporter import Exporter + + # pylint: enable=import-outside-toplevel + + spec = _spec.ModuleSpec.from_raw(spec, self) + mod, params, ext_mods = Exporter(debug=debug).build(spec) + mod = AttachExternModules(ext_mods)(mod) # pylint: disable=not-callable + vm = VirtualMachine( # pylint: disable=invalid-name + relax_build( + mod, + target=Target.from_device(device), + pipeline=pipeline, + ), + device, + ) + params = _param_to_ndarray(params, device) + return spec, vm, params - # Compile mod and feed it to VM - mod = relax.pipeline.get_pipeline(pipeline)(mod) # pylint: disable=no-value-for-parameter - vm = VirtualMachine( # pylint: disable=invalid-name - relax.build(mod, target=target), - device, - ) + device = as_device(device) + spec, vm, params = _compile(spec, device, pipeline, debug) # pylint: disable=invalid-name if out_format == "torch": from . import torch # pylint: disable=import-outside-toplevel - return torch.TorchModule(spec=spec, params=params_ndarray, vm=vm) + return torch.TorchModule(spec=spec, params=params, vm=vm) raise ValueError(f"Unknown out_format: {out_format}") -class ExternModule(Module): - """Base class for external module. Subclass it to import your external models. - Modules can nest within each other in a tree structure using regular attribute assignment.""" - - module_spec: "_spec.ExternModuleSpec" - - def __init__(self, module_spec: "_spec.ExternModuleSpec") -> None: - super().__init__() - self.module_spec = module_spec - - def get_extern_func(self, func_name: str) -> Callable: - """This method helps get the external funciton in external module by function name. - It will wrap the functions as other prebuilt operators. - - Parameters - ---------- - func_name : str - The name of the function to get. - - Returns - ------ - ret_func: Callable - The callable function to call. - """ - for function_spec in self.module_spec.functions: - if function_spec.symbol == func_name: - # pylint: disable=cell-var-from-loop, import-outside-toplevel, protected-access - from tvm.relax import Tuple as RxTuple - from tvm.relax import call_dps_packed - - from . import spec as _spec - from .op import _wrap_nested - - def extern_func( - *args: List[ - Union[_spec.Tensor, _spec.ConstInt, _spec.ConstFloat, _spec.ConstString] - ] - ) -> Tensor: - spec2var = {} - for arg, arg_spec in zip(args, function_spec.args): - if not isinstance(arg_spec, _spec.Tensor): - continue - for value, value_spec in zip(arg.shape, arg_spec.shape): - if isinstance(value_spec, str): - if not value_spec in spec2var: - spec2var[value_spec] = value - else: - if not spec2var[value_spec] == value: - raise ValueError( - f"Confilict vars {spec2var[value_spec]} and {value} " - f"for {value_spec} in {function_spec}" - ) - out_shape = [] - func_spec_ret = function_spec.ret - assert isinstance( - func_spec_ret, _spec.Tensor - ), "Only single return value is supported for now" - for value_spec in func_spec_ret.shape: - if isinstance(value_spec, int): - out_shape.append(value_spec) - elif isinstance(value_spec, str): - if not value_spec in spec2var: - raise ValueError(f"Undefined var {value_spec} in {function_spec}") - out_shape.append(spec2var[value_spec]) - out_sinfo = TensorStructInfo( - out_shape, # type: ignore[arg-type] - func_spec_ret.dtype, - ) - relax_args = [] - for arg, arg_spec in zip(args, function_spec.args): - if isinstance(arg_spec, _spec.Tensor): - relax_args.append(arg._expr) - elif isinstance(arg_spec, _spec.ConstInt): - if arg_spec.dtype is None: - relax_args.append(rx.PrimValue(int(arg))) - else: - relax_args.append(rx.PrimValue(tir.IntImm(arg_spec.dtype, arg))) - elif isinstance(arg_spec, _spec.ConstFloat): - if arg_spec.dtype is None: - relax_args.append(rx.PrimValue(float(arg))) - else: - relax_args.append(rx.PrimValue(tir.FloatImm(arg_spec.dtype, arg))) - elif isinstance(arg_spec, _spec.ConstString): - relax_args.append(rx.StringImm(arg)) - - ret_tensor = _wrap_nested( - call_dps_packed( - func_name, - args=RxTuple(relax_args), - out_sinfo=out_sinfo, - ), - func_name, - ) - assert isinstance(ret_tensor, Tensor) - return ret_tensor - - # pylint: enable=cell-var-from-loop, import-outside-toplevel, protected-access - - return extern_func - raise ValueError(f"Unknown function {func_name} in the external module:{self.module_spec}") - - -class SourceModule(ExternModule): - """Base class for source module. Subclass it to import your source models. - - See PR #16006 (https://github.com/apache/tvm/pull/16006) for a detailed example. - """ - - def __init__( # pylint: disable=too-many-arguments,too-many-locals - self, - source_code: str, - source_format: str, # "cpp", "cu" - functions: Dict[str, "_spec.ExternFunctionSpec"], - compile_options: Optional[List[str]] = None, - compiler: Optional[str] = None, - output_format: str = "obj", # "obj", "wasm" - ): - from . import spec as _spec # pylint: disable=import-outside-toplevel - - def _detect_input_suffix(source_format: str) -> str: - if source_format == "cpp": - return ".cpp" - if source_format == "cu": - return ".cu" - raise ValueError(f"Invalid source format: {source_format}") - - def _detect_output_suffix(output_format: str) -> str: - if output_format == "obj": - if _cc._is_linux_like(): # pylint: disable=protected-access - return ".o" - if _cc._is_windows_like(): # pylint: disable=protected-access - return ".obj" - raise ValueError(f"Unsupported platform: {sys.platform}") - if output_format == "wasm": - return ".wasm" - raise ValueError(f"Invalid output format: {output_format}") - - source_suffix = _detect_input_suffix(source_format) - output_suffix = _detect_output_suffix(output_format) - if compile_options is None: - compile_options = [] - for include_path in find_include_path(): - compile_options.append("-I") - compile_options.append(include_path) - compile_options.append("-c") - compile_options.append("-DDMLC_USE_FOPEN64=0") - compile_options.append("-DDMLC_USE_LOGGING_LIBRARY=") - with tempfile.TemporaryDirectory() as temp_dir: - source_file = f"main{source_suffix}" - with open( - os.path.join(temp_dir, f"main{source_suffix}"), "w", encoding="utf-8" - ) as file: - file.write(source_code) - output_file = f"main{output_suffix}" - if shutil.which("ccache"): - ccache_env = {"CCACHE_COMPILERCHECK": "content"} - else: - ccache_env = None - _cc.create_shared( - output=output_file, - objects=[source_file], - options=compile_options, - cc=compiler, - cwd=temp_dir, - ccache_env=ccache_env, - ) - func_names: List[str] = [] - func_specs: List[_spec.ExternFunctionSpec] = [] - for func_name, func_spec in functions.items(): - func_names.append(func_name) - func_specs.append(func_spec) - if func_spec.symbol is None: - func_spec.symbol = func_name - library = load_static_library( - os.path.join(temp_dir, f"main{output_suffix}"), func_names=func_names - ) - module_spec = _spec.ExternModuleSpec( - library=library, - functions=func_specs, - ) - super().__init__(module_spec=module_spec) - - class ModuleList(Module): """Holds submodules in a list.""" @@ -660,6 +535,38 @@ def forward(self, x): # pylint: disable=invalid-name return x +def wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]: + """Wrap the given relax.Expr, emit it using the current BlockBuilder, + and automatically handle nested cases if the expr represents a Tuple. + + Parameters + ---------- + expr : relax.Expr + The Expr to be wrapped. + + name : str + Name hint. + + Returns + ------- + result : Union[Tensor, Tuple[Tensor]] + The computed result. + """ + if not isinstance(expr, rx.DataflowVar): + expr = BlockBuilder.current().emit(expr, name) + if isinstance(expr.struct_info_, TensorStructInfo): + return Tensor(_expr=expr) + if isinstance(expr.struct_info_, TupleStructInfo): + return tuple( + wrap_nested( # type: ignore + rx.TupleGetItem(expr, i), + name=f"{name}.{i}", + ) + for i in range(len(expr.struct_info_.fields)) + ) + raise TypeError(f"Unsupported return type: {expr.struct_info_}") + + def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any], bool]): """Find attributes that satisfy the condition recursively""" for name, item in root.__dict__.items(): @@ -680,31 +587,6 @@ def _attribute_finder(root: Module, prefix: str, condition_yield: Callable[[Any] ) -def _tensor_placeholder( - name: str, shape: Sequence[Union[int, tir.PrimExpr]], dtype: str -) -> "Tensor": - new_shape = [] - for expr in shape: - if isinstance(expr, (int, tir.IntImm)): - expr = int(expr) - assert expr >= 0 - new_shape.append(expr) - continue - if not isinstance(expr, tir.PrimExpr): - raise TypeError(f"Invalid shape: {shape}") - assert expr.dtype == "int64" - new_shape.append(expr) - return Tensor( - _expr=rx.Var( - name_hint=name, - struct_info=TensorStructInfo( - shape=new_shape, # type: ignore[arg-type] - dtype=dtype, - ), - ) - ) - - def _from_dlpack(tensor) -> NDArray: try: return ndarray.from_dlpack(tensor) @@ -722,19 +604,6 @@ def _from_dlpack(tensor) -> NDArray: ) -def _str_to_device(device: Union[str, Device]) -> Device: - if isinstance(device, Device): - return device - split = device.split(":") - if len(split) > 2: - raise ValueError(f"Invalid device: {device}") - device_type = split[0] - device_id = 0 if len(split) == 1 else int(split[1]) - if device_type not in Device.STR2MASK: - raise ValueError(f"Unsupported device type: {device_type}") - return Device(Device.STR2MASK[device_type], device_id) - - def _param_to_ndarray(params: List[Tuple[str, Parameter]], device: Device) -> List[NDArray]: results = [] missing = [] diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py new file mode 100644 index 000000000000..416913def48b --- /dev/null +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Export `nn.Module` to TVM's IRModule.""" +import threading +import typing + +from tvm import tir +from tvm.ir import IRModule + +from ... import expr as rx +from ...block_builder import BlockBuilder +from ...struct_info import ShapeStructInfo, TupleStructInfo +from . import core, extern +from . import spec as _spec +from .modules import IOEffect + + +def add_extern(mod: extern.ExternModule) -> None: + """Add an external module to the exporter.""" + try: + exporter = Exporter.current() + except Exception as exception: + raise RuntimeError( + "`nn.add_extern` should only be invoked when exporting a module." + ) from exception + exporter.add_external_module(mod) + + +class Exporter: + """Builder of ModuleSpec, which exports an nn.Module to TVM IRModule.""" + + _tls = threading.local() + + builder: BlockBuilder + io_effect: core.Effect + extern_mods: typing.List[extern.ExternModule] + + def __init__(self, debug: bool) -> None: + self.builder = BlockBuilder() + self.io_effect = IOEffect() if debug else None + self.extern_mods = [] + + @staticmethod + def current() -> "Exporter": + """Get the current Exporter under the with scope.""" + assert hasattr(Exporter._tls, "current") + return Exporter._tls.current + + def __enter__(self) -> "Exporter": + assert not hasattr(Exporter._tls, "current") + Exporter._tls.current = self + return self + + def __exit__(self, exc_type, exc, traceback) -> None: + assert hasattr(Exporter._tls, "current") + delattr(Exporter._tls, "current") + + def add_external_module(self, mod: extern.ExternModule) -> None: + """Add an external module to the exporter.""" + # pylint: disable=protected-access + all_symbols: typing.List[str] = [] + for extern_mod in self.extern_mods: + all_symbols.extend(extern_mod._symbols.keys()) + duplicated_symbols = list(set(mod._symbols.keys()) & set(all_symbols)) + # pylint: enable=protected-access + if duplicated_symbols: + raise ValueError(f"Duplicate symbols: {duplicated_symbols}") + self.extern_mods.append(mod) + + def build( # pylint: disable=too-many-locals + self, + spec: _spec.ModuleSpec, + ) -> typing.Tuple[ + IRModule, + typing.List[typing.Tuple[str, core.Parameter]], + typing.List[extern.ExternModule], + ]: + """Build the ModuleSpec to TVM IRModule. Returns the IRModule and the parameters.""" + + # pylint: disable=protected-access + def _params() -> typing.List[typing.Tuple[str, core.Parameter]]: + params = [] + for name, param in core._attribute_finder( + spec.module, prefix="", condition_yield=lambda x: isinstance(x, core.Parameter) + ): + params.append((name, param)) + return params + + def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: + result = [] + if self.io_effect is not None: + result.append(("", self.io_effect)) + for name, effect in core._attribute_finder( + spec.module, "", condition_yield=lambda x: isinstance(x, core.Effect) + ): + result.append((name, effect)) + return result + + # pylint: enable=protected-access + + params = _params() + effects = _effects() + ext_mods = self.extern_mods + with self: + if effects: + with self.builder.function("_initialize_effect"): + with self.builder.dataflow(): + outputs = _emit_effect_init(self.builder, effects) + self.builder.emit_func_output(outputs, params=[]) + for method_name, method_spec in zip(spec.method_names, spec.method_specs): + len_args = len(method_spec.arg_specs) + len_effects = { + "packed": 1, + "none": 0, + "plain": len(effects), + }[method_spec.effect_mode] + with self.builder.function( + method_name, + attrs={"num_input": len_args + len_effects}, # type: ignore + ): + with self.builder.dataflow(): + outputs, inputs = _emit_method(self.builder, method_spec, params, effects) + self.builder.emit_func_output(outputs, inputs) + mod = self.builder.finalize() + return mod, params, ext_mods + + +def _emit_effect_init( + builder: BlockBuilder, + effects: typing.List[typing.Tuple[str, core.Effect]], +): + outputs = [] + for prefix, effect in effects: + inits = effect.emit_init(prefix, builder) + assert isinstance(inits, list) + outputs.extend(inits) + outputs = builder.emit_output(builder.emit(rx.Tuple(outputs))) + return outputs + + +def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many-statements + builder: BlockBuilder, + spec: _spec.MethodSpec, + params: typing.List[typing.Tuple[str, core.Parameter]], + effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]], +): + # pylint: disable=protected-access + def _unwrap_ret(expr: typing.Any) -> typing.Any: + if isinstance(expr, core.Tensor): + return expr._expr + if isinstance(expr, tuple): + return rx.Tuple([_unwrap_ret(x) for x in expr]) + if isinstance(expr, list): + return rx.Tuple([_unwrap_ret(x) for x in expr]) + raise TypeError(f"Unsupported return type: {type(expr)}") + + def _convert_input(arg): + if isinstance(arg, tir.Var): + return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg])) + if isinstance(arg, core.Tensor): + return arg._expr # pylint: disable=protected-access + if isinstance(arg, _spec.Tuple): + return rx.Var( + arg.name, + struct_info=TupleStructInfo( + [_convert_input(arg_i).struct_info for arg_i in arg.elements] + ), + ) + raise TypeError(f"Unsupported input type: {type(arg)}") + + def _params(mode: str) -> typing.List[rx.Var]: + inputs: typing.List[rx.Var] = [] + for name, param in params: + var = core.Tensor.placeholder(param.shape, param.dtype, name)._expr + inputs.append(var) + param._expr = var + if mode == "none": + return [] + if mode == "plain": + return inputs + if mode == "packed": + input_var = rx.Var( + "packed_params", + TupleStructInfo(fields=[x.struct_info for x in inputs]), + ) + for i, (name, param) in enumerate(params): + param._expr = builder.emit(rx.TupleGetItem(input_var, i), name_hint=name) + return [input_var] + raise ValueError(f"Invalid param_mode: {mode}") + + def _effects(mode: str) -> typing.List[rx.Var]: + unflat_inputs: typing.List[typing.List[rx.Var]] = [] + for name, effect in effects: + effect_input = effect.create(name) + effect.set_state(effect_input) + unflat_inputs.append(effect_input) + inputs: typing.List[rx.Var] = sum(unflat_inputs, []) + if mode == "none": + return [] + if mode == "plain": + return inputs + if mode == "packed": + input_var = rx.Var( + "packed_effects", + TupleStructInfo(fields=[x.struct_info for x in inputs]), + ) + i = 0 + for effect_input, (_, effect) in zip(unflat_inputs, effects): + updated_effect_input = [] + for effect_input_i in effect_input: + updated_effect_input.append( + builder.emit( + rx.TupleGetItem(input_var, i), + name_hint=effect_input_i.name_hint, + ) + ) + i += 1 + effect.set_state(updated_effect_input) + return [input_var] + + raise ValueError(f"Invalid effect_mode: {mode}") + + # pylint: enable=protected-access + + def _detuple(arg, var: rx.Var, builder: BlockBuilder): + if isinstance(arg, _spec.Tuple): + ret = [] + for i, elem in enumerate(arg.elements): + field = builder.emit(rx.TupleGetItem(var, i), name_hint=f"{arg.name}_{i}") + ret.append(_detuple(elem, field, builder)) + return type(arg.elements)(ret) + if isinstance(arg, core.Tensor): + return core.Tensor(_expr=var) + if isinstance(arg, tir.Var): + return arg + raise TypeError(f"Unsupported input type: {type(arg)}") + + # TODO(@junrushao): Warn if params/effects are used when their mode is "none" + explicit_inputs = _method_spec_to_inputs(spec) + inputs = [_convert_input(x) for x in explicit_inputs] + inputs = inputs + _effects(spec.effect_mode) + inputs = inputs + _params(spec.param_mode) + + for arg_idx, (arg, var) in enumerate(zip(explicit_inputs, inputs)): + if isinstance(arg, _spec.Tuple): + explicit_inputs[arg_idx] = _detuple(arg, var, builder) + + outputs = spec.method(*explicit_inputs) + effect_outputs = [] + for _, effect in effects: + effect_outputs.extend(effect.finalize()) + if effect_outputs and spec.effect_mode != "none": + outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)])) + else: + outputs = builder.emit_output(_unwrap_ret(outputs)) + return outputs, inputs + + +def _method_spec_to_inputs( + spec: _spec.MethodSpec, +) -> typing.List[typing.Union[tir.Var, core.Tensor]]: + """Convert the MethodSpec to a list of inputs to Module's method.""" + str2var: typing.Dict[str, tir.Var] = {} + + def _get_var(name: str) -> tir.Var: + if name in str2var: + return str2var[name] + var = tir.Var(name, "int64") + str2var[name] = var + return var + + def _convert_input(arg_name, arg_spec): + if isinstance(arg_spec, _spec.Int): + arg = _get_var(arg_name) + elif isinstance(arg_spec, _spec.Tensor): + arg = core.Tensor.placeholder( # pylint: disable=protected-access + shape=[_get_var(x) if isinstance(x, str) else x for x in arg_spec.shape], + dtype=arg_spec.dtype, + name=arg_name, + ) + elif isinstance(arg_spec, _spec.Tuple): + elements = type(arg_spec.elements)( + [ + _convert_input(arg_name=arg_name + f"_{i}", arg_spec=arg_spec.elements[i]) + for i in range(len(arg_spec.elements)) + ] + ) + arg = _spec.Tuple( + name=arg_name, + elements=elements, + ) + else: + raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}") + return arg + + args = [] + for arg_name, arg_spec in zip(spec.arg_names, spec.arg_specs): + arg = _convert_input(arg_name=arg_name, arg_spec=arg_spec) + args.append(arg) + return args diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py new file mode 100644 index 000000000000..2d20809d2371 --- /dev/null +++ b/python/tvm/relax/frontend/nn/extern.py @@ -0,0 +1,392 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""External modules to be linked into the exported IRModule.""" +import os +import shutil +import sys +import tempfile +from pathlib import Path +from typing import Callable, Dict, List, Optional, Union + +from tvm import tir +from tvm.contrib import cc as _cc +from tvm.runtime import Module, load_static_library + +from ...op import call_dps_packed +from . import core +from .core import wrap_nested + + +class ExternModule: + """The abstract base class for external modules. External modules are designed to help + incorporate user-provided handcrafted kernels into the exported TVM IRModule. + """ + + _symbols: Dict[str, Callable] + + def __init__(self, symbols: Dict[str, Callable]) -> None: + self._symbols = symbols + + def __getitem__(self, func_name: str) -> Callable: + _inference_function = self._symbols[func_name] + + def _call(*input_args): + def _convert(arg, name: str): + from tvm import relax as rx # pylint: disable=import-outside-toplevel + + if isinstance(arg, core.Tensor): + return arg._expr # pylint: disable=protected-access + if isinstance(arg, int): + return rx.PrimValue(tir.IntImm("int64", arg)) + if isinstance(arg, float): + return rx.PrimValue(tir.FloatImm("float64", arg)) + if isinstance(arg, str): + return rx.StringImm(arg) + if isinstance(arg, tir.PrimExpr): + return rx.PrimValue(arg) + if isinstance(arg, (tuple, list)): + return rx.Tuple([_convert(e, f"{name}_{i}") for i, e in enumerate(arg)]) + raise TypeError(f"Unsupported input type: {type(arg)}") + + rx_inputs = _convert(input_args, "input") + rx_outputs_sinfo = _convert(_inference_function(*input_args), "dummy").struct_info + return wrap_nested(call_dps_packed(func_name, rx_inputs, rx_outputs_sinfo), func_name) + + return _call + + def _load(self, path: Path) -> Module: + return load_static_library(str(path), func_names=list(self._symbols.keys())) + + def load(self) -> Module: + """Loads the external module into a TVM runtime module.""" + raise NotImplementedError + + +class ObjectModule(ExternModule): # pylint: disable=too-few-public-methods + """A subclass of `nn.ExternModule`, which allows + users to provide an object `.o` file to be linked into compiled + artifact; + """ + + def __init__( + self, + symbols: Dict[str, Callable], + filepath: Path, + ) -> None: + if not isinstance(filepath, Path): + filepath = Path(filepath) + if not filepath.is_file(): + raise ValueError(f"Not a file: {str(filepath)}") + self.filepath = filepath + super().__init__(symbols) + + def load(self) -> Module: + return self._load(self.filepath) + + +class SourceModule(ExternModule): # pylint: disable=too-few-public-methods + """A subclass of `nn.ExternModule`. It compiles C++/CUDA source code and link them into the + eventual IRModule. + + **Shape/dtype inference.** The `nn.ExternModule` system requires users to provide additional + information to work, namely, `symbols`. It is a dictionary that maps each symbol in the + external object file to its shape/dtype inference function. Consider a case where function + `my_func` accepts two tensors, `a` of shape `(x, y, 1)`, and `b` of shape `(y, z, 5)`, and + produces a tensor `c` of shape `(x, y, z, 9)`, the shape/dtype inference function should look + like: + + .. code-block:: python + + def shape_dtype_inference(a, b): + x, y, _ = a.shape + _, z, _ = b.shape + return nn.Tensor.placeholder((x, y, z, 9), dtype="float32") + + + and the `symbols` dictionary should be provided as: + + .. code-block:: python + + symbols={ + "my_func": shape_dtype_inference, + } + + + **Calling convention.** All external modules now follows "destination-passing-style" (DPS) + calling convention, which means the returned tensors are pre-allocated by the system already + and passed in as an argument of the external function. + + Reuse the example above, the implementation of `my_func` should include three parameters in + its signature, where tensors are represented using DLTensor from DLPack, the de facto standard + of in-memory representation of tensors. More details: + https://github.com/dmlc/dlpack/blob/v0.8/include/dlpack/dlpack.h#L163-L206. + + To expose the symbol, `TVM_DLL_EXPORT_TYPED_FUNC(symbol, function)` is guaranteed available: + + .. code-block:: C++ + // those headers are guaranteed to be available + #include + #include + #include + + namespace { + // anonymous namespace hides the symbol `_my_func_impl` from other translation units + int _my_func_impl(DLTensor* a, DLTensor* b, DLTensor* c) { + // `a` and `b` are inputs, and `c` is the output + } + } + // expose symbol `my_func` instead of `_my_func_impl` + TVM_DLL_EXPORT_TYPED_FUNC(my_func, _my_func_impl); + + + **A compiler pass `AttachExternModules`.** It is introduced to attach a list of + `nn.ExternModule`s into an IRModule at any stage of the compilation pipeline, + and attach the compiled external modules as `runtime.Module`s into IRModule's `external_mods` + attribute. It is required by linking in `relax.build`, but with the existence of this pass, + source compilation can be deferred to arbitrary stage of TVM compilation. + + **Caveats.** It is required to call `nn.add_extern` to register external modules exactly once + during `export_tvm`. Each symbol should be registered exactly once to avoid potential conflicts, + and otherwise an error will be raised. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + symbols: Dict[str, Callable], + source_code: Union[str, Path], + source_format: str, # "cpp", "cu" + compile_options: Optional[List[str]] = None, + compiler: Optional[str] = None, + output_format: str = "obj", # "obj", "wasm" + ): + """Constructs a `nn.SourceModule` from source code. + + Parameters + ---------- + symbols : Dict[str, Callable] + The dictionary that maps each symbol in the external object file to its shape/dtype + inference function. + + source_code : Union[str, Path] + Source code or path to the source code to be compiled. + + source_format : str + The source code format. It can be either "cpp" or "cu". + + compile_options : Optional[List[str]] + The compile options. If not provided, the default compile options will be used. + + compiler : Optional[str] + The compiler. If not provided, the default compiler will be used. On Windows, + compilation requires `clang` by default. + + output_format : str + The output format. It can be either "obj" or "wasm". "obj" is the default format, + which is a shared object file. "wasm" is the WebAssembly format, which is a binary + file. + """ + + def _detect_input_suffix(source_format: str) -> str: + if source_format == "cpp": + return ".cpp" + if source_format == "cu": + return ".cu" + raise ValueError(f"Invalid source format: {source_format}") + + def _detect_output_suffix(output_format: str) -> str: + if output_format == "obj": + if _cc._is_linux_like(): # pylint: disable=protected-access + return ".o" + if _cc._is_windows_like(): # pylint: disable=protected-access + return ".obj" + raise ValueError(f"Unsupported platform: {sys.platform}") + if output_format == "wasm": + return ".wasm" + raise ValueError(f"Invalid output format: {output_format}") + + def _detect_source_code(source_code) -> str: + if isinstance(source_code, Path): + path = source_code + if not path.is_file(): + raise ValueError(f"Not a file: {str(path)}") + else: + try: + path = Path(source_code) + except: # pylint: disable=bare-except + return source_code + if not path.is_file(): + return source_code + with path.open("r", encoding="utf-8") as file: + return file.read() + + self.source_code = _detect_source_code(source_code) + if compile_options is None: + self.compile_options = SourceModule.get_compile_options(source_format=source_format) + else: + self.compile_options = list(compile_options) + self.compiler = compiler + self.source_suffix = _detect_input_suffix(source_format) + self.output_suffix = _detect_output_suffix(output_format) + super().__init__(symbols) + + @staticmethod + def tvm_home() -> Path: + """Find TVM's home directory. If `TVM_HOME` environment variable is set, use it. + Otherwise, use the directory where the `tvm` Python package is installed. + As a sanity check, it is required to have `include` and `3rdparty` as direct subdirectories. + + Returns + ------- + tvm_home : pathlib.Path + The TVM home directory, and it is guaranteed to have `include` and `3rdparty` as + direct subdirectories. + """ + if os.environ.get("TVM_HOME", None): + tvm_path = Path(os.environ["TVM_HOME"]) + assert tvm_path.exists(), ( + "Using environment variable `TVM_HOME`, " + f"but directory not found: {str(tvm_path)}" + ) + assert tvm_path.is_dir(), ( + "Using environment variable `TVM_HOME`, " + f"but it is not a directory: {str(tvm_path)}" + ) + else: + import tvm # pylint: disable=import-outside-toplevel + + tvm_path = Path(tvm.__file__).parent + assert tvm_path.is_dir() + tvm_path = tvm_path.resolve() + while True: + exists_include = (tvm_path / "include").is_dir() + exists_3rdparty = (tvm_path / "3rdparty").is_dir() + if exists_include and exists_3rdparty: + return tvm_path.resolve() + parent = tvm_path.parent + if parent == tvm_path: + raise ValueError( + "Cannot detect TVM directory. " + "Please explicitly specify it by setting `TVM_HOME` environment variable, " + "and make sure it contains `include` and `3rdparty` as direct sub-directories." + ) + tvm_path = parent + return tvm_path.resolve() + + @staticmethod + def get_includes(tvm_pkg: Optional[List[str]] = None) -> List[Path]: + """Returns the default include paths according to `tvm_home()`. + By default, it includes TVM, DLPack, and DMLC-Core. With `tvm_pkg` provided, it also + includes the specified package under `tvm_home/3rdparty`. + + Parameters + ---------- + tvm_pkg : Optional[List[str]] + The list of packages to be included under `tvm_home/3rdparty`. Each element should be + a relative path to `tvm_home/3rdparty`. + + Returns + ------- + includes : List[pathlib.Path] + The list of include paths. + """ + tvm_home = SourceModule.tvm_home() + results = [ + tvm_home / "include", + tvm_home / "3rdparty/dlpack/include", + tvm_home / "3rdparty/dmlc-core/include", + ] + if tvm_pkg: + for relative in tvm_pkg: + results.append(tvm_home / "3rdparty" / relative) + for path in results: + assert path.exists(), f"Not found: {str(path)}" + assert path.is_dir(), f"Not a directory: {str(path)}" + return results + + @staticmethod + def get_compile_options( + source_format: str, + tvm_pkg: Optional[List[str]] = None, + ) -> List[str]: + """Returns the default compile options depending on `source_format`, including the default + inlcude paths w.r.t. `tvm_home()`, default flags to configure DMLC-Core, and by default, + it uses "-O3" and "-std=c++17". + + Parameters + ---------- + source_format : str + The source code format. It can be either "cpp" or "cu". + + tvm_pkg : Optional[List[str]] + The list of packages to be included under `tvm_home/3rdparty`. Each element should be + a relative path to `tvm_home/3rdparty`. + + Returns + ------- + compile_options : List[str] + The list of compilation flags. + """ + include_flags = [] + for include_path in SourceModule.get_includes(tvm_pkg=tvm_pkg): + include_flags += ["-I", str(include_path)] + if source_format == "cpp": + host_flags = [ + "-c", # generate object file + "-O3", + "-std=c++17", + # DMLC default + "-DDMLC_USE_FOPEN64=0", + "-DDMLC_USE_LOGGING_LIBRARY=", + ] + elif source_format == "cu": + host_flags = [ + "-c", # generate object file + "-O3", + "-std=c++17", + # DMLC default + "-DDMLC_USE_FOPEN64=0", + "-DDMLC_USE_LOGGING_LIBRARY=", + # Enable `-fPIC` for the host compiler + "-Xcompiler=-fPIC", + ] + else: + raise ValueError(f"Invalid source format: {source_format}") + return include_flags + host_flags + + def compile(self, output_path: Path) -> None: + """Compiles the source code in a provided directory and returns the compiled artifact.""" + with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + source_path = temp_dir / f"main{self.source_suffix}" + object_path = temp_dir / f"main{self.output_suffix}" + with source_path.open("w", encoding="utf-8") as file: + file.write(self.source_code) + _cc.create_shared( + output=str(object_path), + objects=[str(source_path)], + options=self.compile_options, + cc=self.compiler, + cwd=temp_dir, + ccache_env={"CCACHE_COMPILERCHECK": "content"} if shutil.which("ccache") else None, + ) + shutil.move(str(object_path), str(output_path)) + + def load(self) -> Module: + with tempfile.TemporaryDirectory() as temp_dir_str: + output_path = Path(temp_dir_str) / f"main{self.output_suffix}" + self.compile(output_path) + return self._load(output_path) diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 68b719d9adb6..b2c97a567ab8 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -18,8 +18,6 @@ """Builtin Modules.""" from typing import List, Optional, Sequence, Union -import numpy as np - from tvm import relax as rx from tvm import tir @@ -153,70 +151,6 @@ def to(self, dtype: Optional[str] = None) -> None: self.dtype = dtype # pylint: disable=attribute-defined-outside-init -class MultiLinear(Module): - """A layer that applies multiple linear transformations to the input.""" - - def __init__( - self, - in_features: int, - out_features: Sequence[int], - bias: bool = True, - dtype: Optional[str] = None, - out_dtype: Optional[str] = None, - ): - assert len(out_features) > 0 - total_out_features = sum(out_features) - - super().__init__() - self.in_features = in_features - self.out_features = out_features - self.out_dtype = out_dtype - self.weight = Parameter((total_out_features, in_features), dtype) - if bias: - self.bias = Parameter( - (total_out_features,), dtype=dtype if out_dtype is None else out_dtype - ) - else: - self.bias = None - - def forward(self, x: Tensor) -> Tensor: - """ - Forward method for linear layer. - - Parameters - ---------- - x : Tensor - The input tensor. - - Returns - ------- - ret : Tensor - The output tensor for the linear layer. - """ - sections = list(np.cumsum(self.out_features)[:-1]) - # x: [*B, in_features] - # w: [in_features, out_features] - w = op.permute_dims(self.weight) - # x: [*B, out_features] - x = op.matmul(x, w, out_dtype=self.out_dtype) - if self.bias is not None: - x = x + self.bias - results = op.split(x, sections, axis=-1) - return results - - def to(self, dtype: Optional[str] = None) -> None: - """ - Override to() such that we do not convert bias if there is `out_dtype`. - Otherwise, we might run into dtype mismatch when computing `x + self.bias` - since x is of type `out_dtype` and bias becomes `dtype`, potentially different. - """ - self.weight.to(dtype=dtype) - if self.bias is not None and self.out_dtype is None: - self.bias.to(dtype=dtype) - if dtype is not None and isinstance(getattr(self, "dtype", None), str): - self.dtype = dtype # pylint: disable=attribute-defined-outside-init - - class Conv1D(Module): """ Module for conv1d layer. diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 061465d0853d..75bc4574fcf1 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -28,45 +28,11 @@ from ... import expr as rx from ... import op as _op from ...block_builder import BlockBuilder -from ...struct_info import TensorStructInfo, TupleStructInfo -from .core import Tensor, get_default_dtype -from .spec import SpecBuilder +from .core import Tensor, get_default_dtype, wrap_nested IntExpr = Union[int, _tir.PrimExpr] -def _wrap_nested(expr: rx.Expr, name: str) -> Union[Tensor, Sequence[Tensor]]: - """Wrap the given relax.Expr, emit it using the current BlockBuilder, - and automatically handle nested cases if the expr represents a Tuple. - - Parameters - ---------- - expr : relax.Expr - The Expr to be wrapped. - - name : str - Name hint. - - Returns - ------- - result : Union[Tensor, Tuple[Tensor]] - The computed result. - """ - if not isinstance(expr, rx.DataflowVar): - expr = BlockBuilder.current().emit(expr, name) - if isinstance(expr.struct_info_, TensorStructInfo): - return Tensor(_expr=expr) - if isinstance(expr.struct_info_, TupleStructInfo): - return tuple( - _wrap_nested( - rx.TupleGetItem(expr, i), - name=f"{name}.{i}", - ) - for i in range(len(expr.struct_info_.fields)) - ) - raise TypeError(f"Unsupported return type: {expr.struct_info_}") - - def unsqueeze(x: Tensor, dim: int, name: str = "unsqueeze") -> Tensor: """Add a new axis to a tensor @@ -84,7 +50,7 @@ def unsqueeze(x: Tensor, dim: int, name: str = "unsqueeze") -> Tensor: result : Tensor Expanded result. """ - return _wrap_nested(_op.expand_dims(x._expr, dim), name) + return wrap_nested(_op.expand_dims(x._expr, dim), name) def concat(x: List[Tensor], dim: int, name: str = "concat") -> Tensor: @@ -106,7 +72,7 @@ def concat(x: List[Tensor], dim: int, name: str = "concat") -> Tensor: """ # Convert tensors to expressions. x = [t._expr for t in x] - return _wrap_nested(_op.concat(x, dim), name) + return wrap_nested(_op.concat(x, dim), name) def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor: @@ -134,7 +100,7 @@ def add(a: Tensor, b: Tensor, name: str = "add") -> Tensor: c = add(a, b) """ - return _wrap_nested(_op.add(a._expr, b._expr), name) + return wrap_nested(_op.add(a._expr, b._expr), name) def subtract(a: Tensor, b: Tensor, name: str = "subtract") -> Tensor: @@ -162,7 +128,7 @@ def subtract(a: Tensor, b: Tensor, name: str = "subtract") -> Tensor: c = subtract(a, b) """ - return _wrap_nested(_op.subtract(a._expr, b._expr), name) + return wrap_nested(_op.subtract(a._expr, b._expr), name) def multiply(a: Tensor, b: Tensor, name: str = "mul") -> Tensor: @@ -190,7 +156,7 @@ def multiply(a: Tensor, b: Tensor, name: str = "mul") -> Tensor: c = multiply(a, b) """ - return _wrap_nested(_op.multiply(a._expr, b._expr), name) + return wrap_nested(_op.multiply(a._expr, b._expr), name) def divide(a: Tensor, b: Tensor, name: str = "divide") -> Tensor: @@ -218,7 +184,7 @@ def divide(a: Tensor, b: Tensor, name: str = "divide") -> Tensor: c = divide(a, b) """ - return _wrap_nested(_op.divide(a._expr, b._expr), name) + return wrap_nested(_op.divide(a._expr, b._expr), name) def chunk(x: Tensor, chunks: int, dim: int = 0, name: str = "chunk") -> Tensor: @@ -240,7 +206,7 @@ def chunk(x: Tensor, chunks: int, dim: int = 0, name: str = "chunk") -> Tensor: result : Tuple[Tensor] A tuple with chunks elements containing slices of x. """ - return _wrap_nested(_op.split(x._expr, chunks, dim), name) + return wrap_nested(_op.split(x._expr, chunks, dim), name) def sum( @@ -274,7 +240,7 @@ def sum( result : Tensor The computed result. """ - return _wrap_nested(_op.sum(x._expr, axis, keepdims), name) + return wrap_nested(_op.sum(x._expr, axis, keepdims), name) def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "matmul") -> Tensor: @@ -309,7 +275,7 @@ def matmul(a: Tensor, b: Tensor, out_dtype: Optional[str] = None, name: str = "m c = matmul(a, b) """ - return _wrap_nested(_op.matmul(a._expr, b._expr, out_dtype=out_dtype), name) + return wrap_nested(_op.matmul(a._expr, b._expr, out_dtype=out_dtype), name) def conv1d( @@ -392,7 +358,7 @@ def conv1d( if bias is not None: conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1])) - return _wrap_nested(conv_out, name) + return wrap_nested(conv_out, name) def conv2d( @@ -450,7 +416,7 @@ def conv2d( if bias is not None: conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1, 1])) - return _wrap_nested(conv_out, name) + return wrap_nested(conv_out, name) def conv1d_transpose( @@ -530,7 +496,7 @@ def conv1d_transpose( if bias is not None: conv_out = _op.add(conv_out, _op.reshape(bias._expr, [1, -1, 1])) - return _wrap_nested(conv_out, name) + return wrap_nested(conv_out, name) def maximum(x1: Tensor, x2: Tensor, name: str = "maximum"): @@ -558,7 +524,7 @@ def maximum(x1: Tensor, x2: Tensor, name: str = "maximum"): c = maximum(a, b) """ - return _wrap_nested(_op.maximum(x1._expr, x2._expr), name) + return wrap_nested(_op.maximum(x1._expr, x2._expr), name) def minimum(x1: Tensor, x2: Tensor, name: str = "minimum"): @@ -586,7 +552,7 @@ def minimum(x1: Tensor, x2: Tensor, name: str = "minimum"): c = minimum(a, b) """ - return _wrap_nested(_op.minimum(x1._expr, x2._expr), name) + return wrap_nested(_op.minimum(x1._expr, x2._expr), name) def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to") -> Tensor: @@ -608,7 +574,7 @@ def broadcast_to(x: Tensor, shape: Sequence[IntExpr], name: str = "broadcast_to" result : Tensor The broadcasted tensor. """ - return _wrap_nested(_op.broadcast_to(x._expr, shape), name) + return wrap_nested(_op.broadcast_to(x._expr, shape), name) def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permute_dims") -> Tensor: @@ -630,7 +596,7 @@ def permute_dims(x: Tensor, axes: Optional[List[int]] = None, name: str = "permu result : Tensor The transposed result. """ - return _wrap_nested(_op.permute_dims(x._expr, axes=axes), name) + return wrap_nested(_op.permute_dims(x._expr, axes=axes), name) def reshape(x: Tensor, shape: Sequence[IntExpr], name="reshape") -> Tensor: @@ -668,7 +634,7 @@ def reshape(x: Tensor, shape: Sequence[IntExpr], name="reshape") -> Tensor: That is to say, in any case the dimension length of ``-1`` cannot be inferred in compile-time, an error will be thrown. """ - return _wrap_nested(_op.reshape(x._expr, shape), name) + return wrap_nested(_op.reshape(x._expr, shape), name) def repeat(x: Tensor, repeats: int, axis: Optional[int] = None, name="repeat") -> Tensor: @@ -705,7 +671,7 @@ def repeat(x: Tensor, repeats: int, axis: Optional[int] = None, name="repeat") - lv2 = repeat(x, repeats=2, axis=1) # lv2 == [[1., 1., 2., 2.], # [3., 3., 4., 4.]] """ - return _wrap_nested(_op.repeat(x._expr, repeats, axis), name) + return wrap_nested(_op.repeat(x._expr, repeats, axis), name) def squeeze(x: Tensor, axis: int = -1, name: str = "squeeze") -> Tensor: @@ -729,7 +695,7 @@ def squeeze(x: Tensor, axis: int = -1, name: str = "squeeze") -> Tensor: result : Tensor The squeezed result. """ - return _wrap_nested(_op.squeeze(x._expr, axis), name) + return wrap_nested(_op.squeeze(x._expr, axis), name) def take(x: Tensor, indices: Tensor, axis: Optional[int] = None, name="take") -> Tensor: @@ -759,7 +725,7 @@ def take(x: Tensor, indices: Tensor, axis: Optional[int] = None, name="take") -> ret : Tensor The taken result. """ - return _wrap_nested(_op.take(x._expr, indices._expr, axis), name) + return wrap_nested(_op.take(x._expr, indices._expr, axis), name) def astype(x: Tensor, dtype: str, name: str = "astype") -> Tensor: @@ -784,7 +750,7 @@ def astype(x: Tensor, dtype: str, name: str = "astype") -> Tensor: # If trying to cast to same dtype as x, skip casting. if x.dtype == dtype: return x - return _wrap_nested(_op.astype(x._expr, dtype), name) + return wrap_nested(_op.astype(x._expr, dtype), name) def relu(x: Tensor, name: str = "relu") -> Tensor: @@ -806,7 +772,7 @@ def relu(x: Tensor, name: str = "relu") -> Tensor: result : Tensor The computed result. """ - return _wrap_nested(_op.nn.relu(x._expr), name) + return wrap_nested(_op.nn.relu(x._expr), name) def silu(x: Tensor, name: str = "silu") -> Tensor: @@ -832,7 +798,7 @@ def silu(x: Tensor, name: str = "silu") -> Tensor: ---- The input tensor is required to have float dtype """ - return _wrap_nested(_op.nn.silu(x._expr), name) + return wrap_nested(_op.nn.silu(x._expr), name) def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") -> Tensor: @@ -867,7 +833,7 @@ def gelu(x: Tensor, approximate: Optional[str] = None, name: str = "gelu") -> Te gelu_out = _op.nn.gelu_tanh(x._expr) else: gelu_out = _op.nn.gelu(x._expr) - return _wrap_nested(gelu_out, name) + return wrap_nested(gelu_out, name) def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor: @@ -897,7 +863,7 @@ def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor: ---- The input tensor is required to have float dtype """ - return _wrap_nested(_op.nn.softmax(x._expr, axis), name) + return wrap_nested(_op.nn.softmax(x._expr, axis), name) def layer_norm( @@ -969,7 +935,7 @@ def layer_norm( else: bias = rx.const(np.zeros(normalized_shape), dtype=dtype) - return _wrap_nested( + return wrap_nested( _op.nn.layer_norm( x._expr, gamma=weight, @@ -1020,7 +986,7 @@ def rms_norm( result : Tensor The computed result. """ - return _wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, axes, epsilon), name) + return wrap_nested(_op.nn.rms_norm(x._expr, weight._expr, axes, epsilon), name) def group_norm( @@ -1079,7 +1045,7 @@ def group_norm( dim = len(x._expr.struct_info.shape) if axes is None: axes = list(range(2, dim)) - return _wrap_nested( + return wrap_nested( _op.nn.group_norm( x._expr, weight, bias, num_groups, channel_axis=channel_axis, axes=axes, epsilon=eps ), @@ -1110,7 +1076,7 @@ def triu(x: Tensor, diagonal: int = 0, name: str = "triu") -> Tensor: ret : Tensor The result tensor. """ - return _wrap_nested(_op.triu(x._expr, diagonal), name) + return wrap_nested(_op.triu(x._expr, diagonal), name) def full( @@ -1147,7 +1113,7 @@ def full( fill_value = rx.const(fill_value, dtype=dtype) else: fill_value = fill_value._expr - return _wrap_nested(_op.full(shape, fill_value, dtype), name) + return wrap_nested(_op.full(shape, fill_value, dtype), name) def zeros( @@ -1173,7 +1139,7 @@ def zeros( result : Tensor The result tensor. """ - return _wrap_nested(_op.zeros(shape, dtype), name) + return wrap_nested(_op.zeros(shape, dtype), name) def split( @@ -1200,7 +1166,7 @@ def split( result : Tuple[Tensor, ...] A list of sub-arrays as the outcome of splitting. """ - return _wrap_nested(_op.split(ary._expr, indices_or_sections, axis), name) + return wrap_nested(_op.split(ary._expr, indices_or_sections, axis), name) def pad( @@ -1233,7 +1199,7 @@ def pad( result : Tensor Padded output tensor. """ - return _wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_value=value, pad_mode=mode), name) + return wrap_nested(_op.nn.pad(x._expr, pad_width=pad, pad_value=value, pad_mode=mode), name) def get_timestep_embedding( @@ -1299,7 +1265,7 @@ def get_timestep_embedding( # Cast to proper output type emb = _op.astype(emb, dtype) - return _wrap_nested(emb, name) + return wrap_nested(emb, name) def scaled_dot_product_attention( @@ -1338,7 +1304,7 @@ def scaled_dot_product_attention( attn = _op.nn.attention( query._expr, key._expr, value._expr, causal_mask=causal_mask, scale=scale ) - return _wrap_nested(attn, name) + return wrap_nested(attn, name) def interpolate( @@ -1400,7 +1366,7 @@ def interpolate( else: coord_trans = "half_pixel" - return _wrap_nested( + return wrap_nested( _op.image.resize2d( x._expr, size, layout="NCHW", method=mode, coordinate_transformation_mode=coord_trans ), @@ -1426,7 +1392,7 @@ def ccl_allreduce(x: Tensor, op_type: str = "sum", name="ccl_allreduce"): result : Tensor The result tensor of allreduce. """ - return _wrap_nested(_op.ccl.allreduce(x._expr, op_type), name) + return wrap_nested(_op.ccl.allreduce(x._expr, op_type), name) def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"): @@ -1444,7 +1410,7 @@ def ccl_broadcast_from_worker0(x: Tensor, name="broadcast_from_worker"): result : Tensor The same tensor, which has been broadcast to all other workers. """ - return _wrap_nested(_op.ccl.broadcast_from_worker0(x._expr), name) + return wrap_nested(_op.ccl.broadcast_from_worker0(x._expr), name) def tensor_expr_op( @@ -1481,7 +1447,7 @@ def _convert(arg): return arg._expr # pylint: disable=protected-access return arg - return _wrap_nested( + return wrap_nested( BlockBuilder.current().emit_te( tensor_expr_func, *[_convert(arg) for arg in args], @@ -1517,13 +1483,14 @@ def debug_func(lineno: str, arg_0, arg_1, ...) -> None: # pylint: disable=import-outside-toplevel from tvm import relax as rx + from .exporter import Exporter from .modules import IOEffect # pylint: enable=import-outside-toplevel - if SpecBuilder.current().io_effect is None: + if Exporter.current().io_effect is None: raise RuntimeError("Debugging is only supported when debug mode is on.") - io: IOEffect = SpecBuilder.current().io_effect # type: ignore + io: IOEffect = Exporter.current().io_effect # type: ignore if _line_info is None: filename, line_number = inspect.getframeinfo(inspect.currentframe().f_back)[:2] diff --git a/python/tvm/relax/frontend/nn/spec.py b/python/tvm/relax/frontend/nn/spec.py index b7a97abf6969..210b16ce013a 100644 --- a/python/tvm/relax/frontend/nn/spec.py +++ b/python/tvm/relax/frontend/nn/spec.py @@ -16,17 +16,10 @@ # under the License. """Compilation specifications, for example, dynamic shape inputs.""" import inspect -import threading import typing -from tvm import tir -from tvm.ir import IRModule -from tvm.runtime import Module, load_static_library - -from ... import expr as rx -from ...block_builder import BlockBuilder -from ...struct_info import ShapeStructInfo, TupleStructInfo -from . import core +if typing.TYPE_CHECKING: + from .core import Module as nn_module_class ArgSpecType = typing.Union["Int", "Tensor"] MethodSpecType = typing.Union["MethodSpec", typing.Dict[str, ArgSpecType]] @@ -78,44 +71,6 @@ def __repr__(self) -> str: return self.elements.__repr__() -class ConstInt: # pylint: disable=too-few-public-methods - """An integer constant""" - - dtype: typing.Optional[str] - - def __init__(self, dtype: str = None) -> None: - self.dtype = dtype - - def __repr__(self) -> str: - if self.dtype is None: - return "const.int" - return f"const.int({self.dtype})" - - -class ConstFloat: # pylint: disable=too-few-public-methods - """A float constant""" - - dtype: typing.Optional[str] - - def __init__(self, dtype: str = None) -> None: - self.dtype = dtype - - def __repr__(self) -> str: - if self.dtype is None: - return "const.float" - return f"const.float({self.dtype})" - - -class ConstString: # pylint: disable=too-few-public-methods - """A string constant""" - - def __init__(self) -> None: - pass - - def __repr__(self) -> str: - return "const.string" - - class MethodSpec: """A spec for a compiled method""" @@ -222,60 +177,17 @@ def from_torch(args: typing.List[typing.Any], method: typing.Callable) -> "Metho return _method_spec_from_torch(args, method) - def as_inputs(self) -> typing.List[typing.Union[tir.Var, core.Tensor]]: - """Convert the MethodSpec to a list of inputs to Module's method.""" - str2var: typing.Dict[str, tir.Var] = {} - - def _get_var(name: str) -> tir.Var: - if name in str2var: - return str2var[name] - var = tir.Var(name, "int64") - str2var[name] = var - return var - - def _convert_input(arg_name, arg_spec): - if isinstance(arg_spec, Int): - arg = _get_var(arg_name) - elif isinstance(arg_spec, Tensor): - arg = core._tensor_placeholder( # pylint: disable=protected-access - name=arg_name, - shape=[_get_var(x) if isinstance(x, str) else x for x in arg_spec.shape], - dtype=arg_spec.dtype, - ) - elif isinstance(arg_spec, Tuple): - elements = type(arg_spec.elements)( - [ - _convert_input( - arg_name=arg_name + f"_tmp{i}", arg_spec=arg_spec.elements[i] - ) - for i in range(len(arg_spec.elements)) - ] - ) - arg = Tuple( - name=arg_name, - elements=elements, - ) - else: - raise TypeError(f"Invalid spec for argument {arg_name}: {arg_spec}") - return arg - - args = [] - for arg_name, arg_spec in zip(self.arg_names, self.arg_specs): - arg = _convert_input(arg_name=arg_name, arg_spec=arg_spec) - args.append(arg) - return args - class ModuleSpec: """A spec for a compiled nn.Module""" - module: core.Module + module: "nn_module_class" method_names: typing.List[str] method_specs: typing.List[MethodSpec] def __init__( self, - module: core.Module, + module: "nn_module_class", method_names: typing.List[str], method_specs: typing.List[MethodSpec], ) -> None: @@ -284,10 +196,9 @@ def __init__( self.method_specs = method_specs @staticmethod - def from_raw(spec: ModuleSpecType, module: core.Module) -> "ModuleSpec": + def from_raw(spec: ModuleSpecType, module: "nn_module_class") -> "ModuleSpec": """Create ModuleSpec from raw python dictionaries. - Examples -------- .. code-block:: python @@ -331,297 +242,3 @@ def __repr__(self) -> str: self.method_specs, ) ) - - -class ExternFunctionSpec: # pylint: disable=too-few-public-methods - """A spec for a compiled external function.""" - - args: typing.List[typing.Union[Tensor, ConstInt, ConstFloat, ConstString]] - ret: typing.Union[Tensor, typing.List[Tensor]] - symbol: typing.Optional[str] - - def __init__( - self, - args: typing.List[typing.Union[Tensor, ConstInt, ConstFloat, ConstString]], - ret: typing.Union[Tensor, typing.List[Tensor]], - symbol: typing.Optional[str] = None, - ) -> None: - self.args = args - self.ret = ret - self.symbol = symbol - - def __repr__(self) -> str: - arg_repr = ", ".join(arg.__repr__() for arg in self.args) - if isinstance(self.ret, list): - ret_repr = "(" + ", ".join(ret.__repr__() for ret in self.ret) + ")" - else: - ret_repr = self.ret.__repr__() - if self.symbol is None: - func = f"({arg_repr}) -> {ret_repr}" - else: - func = f"{self.symbol}({arg_repr}) -> {ret_repr}" - return f"ExternFunctionSpec: {func}" - - -class ExternModuleSpec: # pylint: disable=too-few-public-methods - """A spec for a compiled external Module.""" - - library: typing.Union[str, Module] - functions: typing.List[ExternFunctionSpec] - - def __init__( - self, - library: typing.Union[str, Module], - functions: typing.List[ExternFunctionSpec], - ) -> None: - self.library = library - self.functions = functions - - def load_library(self) -> Module: - """Load the external library into Module.""" - if isinstance(self.library, Module): - return self.library - return load_static_library( - self.library, - func_names=[func.symbol for func in self.functions], - ) - - def __repr__(self) -> str: - return f"ExternModuleSpec(library={self.library}):\n" + "\n".join( - f" {func.__repr__()}" for func in self.functions - ) - - -class SpecBuilder: - """Builder of ModuleSpec, which exports an nn.Module to TVM IRModule.""" - - _tls = threading.local() - - builder: BlockBuilder - io_effect: core.Effect - - def __init__(self) -> None: - from .modules import IOEffect # pylint: disable=import-outside-toplevel - - self.builder = BlockBuilder() - self.io_effect = IOEffect() - - @staticmethod - def current() -> "SpecBuilder": - """Get the current SpecBuilder under the with scope.""" - assert hasattr(SpecBuilder._tls, "current") - return SpecBuilder._tls.current - - def __enter__(self) -> "SpecBuilder": - assert not hasattr(SpecBuilder._tls, "current") - SpecBuilder._tls.current = self - return self - - def __exit__(self, exc_type, exc, traceback) -> None: - assert hasattr(SpecBuilder._tls, "current") - delattr(SpecBuilder._tls, "current") - - def build( # pylint: disable=too-many-locals - self, spec: ModuleSpec, debug: bool = False - ) -> typing.Tuple[IRModule, typing.List[typing.Tuple[str, core.Parameter]]]: - """Build the ModuleSpec to TVM IRModule. Returns the IRModule and the parameters.""" - - # pylint: disable=protected-access - def _params() -> typing.List[typing.Tuple[str, core.Parameter]]: - params = [] - for name, param in core._attribute_finder( - spec.module, prefix="", condition_yield=lambda x: isinstance(x, core.Parameter) - ): - params.append((name, param)) - return params - - def _effects() -> typing.List[typing.Tuple[str, core.Effect]]: - result = [] - if self.io_effect is not None: - result.append(("", self.io_effect)) - for name, effect in core._attribute_finder( - spec.module, "", condition_yield=lambda x: isinstance(x, core.Effect) - ): - result.append((name, effect)) - return result - - def _extern_modules() -> typing.List[core.ExternModule]: - result = [] - used = set() - for _, extern_module in core._attribute_finder( - spec.module, "", condition_yield=lambda x: isinstance(x, core.ExternModule) - ): - if extern_module not in used: - used.add(extern_module) - result.append(extern_module) - return result - - # pylint: enable=protected-access - - # Disable IO effects if not in debug mode. - if not debug: - self.io_effect = None - params = _params() - effects = _effects() - extern_modules = _extern_modules() - with self: - if effects: - with self.builder.function("_initialize_effect"): - with self.builder.dataflow(): - outputs = _emit_effect_init(self.builder, effects) - self.builder.emit_func_output(outputs, params=[]) - for method_name, method_spec in zip(spec.method_names, spec.method_specs): - len_args = len(method_spec.arg_specs) - len_effects = { - "packed": 1, - "none": 0, - "plain": len(effects), - }[method_spec.effect_mode] - with self.builder.function( - method_name, - attrs={"num_input": len_args + len_effects}, # type: ignore - ): - with self.builder.dataflow(): - outputs, inputs = _emit_method(self.builder, method_spec, params, effects) - self.builder.emit_func_output(outputs, inputs) - external_mods = [] - for extern_module in extern_modules: - external_mods.append(extern_module.module_spec.load_library()) - mod = self.builder.finalize() - if extern_modules: - original_external_mods = mod.get_attr("external_mods") - if original_external_mods is not None: - external_mods = original_external_mods + extern_modules - mod = mod.with_attr("external_mods", external_mods) - return mod, params - - -def _emit_effect_init( - builder: BlockBuilder, - effects: typing.List[typing.Tuple[str, core.Effect]], -): - outputs = [] - for prefix, effect in effects: - inits = effect.emit_init(prefix, builder) - assert isinstance(inits, list) - outputs.extend(inits) - outputs = builder.emit_output(builder.emit(rx.Tuple(outputs))) - return outputs - - -def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many-statements - builder: BlockBuilder, - spec: MethodSpec, - params: typing.List[typing.Tuple[str, core.Parameter]], - effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]], -): - # pylint: disable=protected-access - def _unwrap_ret(expr: typing.Any) -> typing.Any: - if isinstance(expr, core.Tensor): - return expr._expr - if isinstance(expr, tuple): - return rx.Tuple([_unwrap_ret(x) for x in expr]) - if isinstance(expr, list): - return rx.Tuple([_unwrap_ret(x) for x in expr]) - raise TypeError(f"Unsupported return type: {type(expr)}") - - def _convert_input(arg): - if isinstance(arg, tir.Var): - return rx.Var(arg.name, struct_info=ShapeStructInfo(values=[arg])) - if isinstance(arg, core.Tensor): - return arg._expr # pylint: disable=protected-access - if isinstance(arg, Tuple): - return rx.Var( - arg.name, - struct_info=TupleStructInfo( - [_convert_input(arg_i).struct_info for arg_i in arg.elements] - ), - ) - raise TypeError(f"Unsupported input type: {type(arg)}") - - def _params(mode: str) -> typing.List[rx.Var]: - inputs: typing.List[rx.Var] = [] - for name, param in params: - var = core._tensor_placeholder(name, param.shape, param.dtype)._expr - inputs.append(var) - param._expr = var - if mode == "none": - return [] - if mode == "plain": - return inputs - if mode == "packed": - input_var = rx.Var( - "packed_params", - TupleStructInfo(fields=[x.struct_info for x in inputs]), - ) - for i, (name, param) in enumerate(params): - param._expr = builder.emit(rx.TupleGetItem(input_var, i), name_hint=name) - return [input_var] - raise ValueError(f"Invalid param_mode: {mode}") - - def _effects(mode: str) -> typing.List[rx.Var]: - unflat_inputs: typing.List[typing.List[rx.Var]] = [] - for name, effect in effects: - effect_input = effect.create(name) - effect.set_state(effect_input) - unflat_inputs.append(effect_input) - inputs: typing.List[rx.Var] = sum(unflat_inputs, []) - if mode == "none": - return [] - if mode == "plain": - return inputs - if mode == "packed": - input_var = rx.Var( - "packed_effects", - TupleStructInfo(fields=[x.struct_info for x in inputs]), - ) - i = 0 - for effect_input, (_, effect) in zip(unflat_inputs, effects): - updated_effect_input = [] - for effect_input_i in effect_input: - updated_effect_input.append( - builder.emit( - rx.TupleGetItem(input_var, i), - name_hint=effect_input_i.name_hint, - ) - ) - i += 1 - effect.set_state(updated_effect_input) - return [input_var] - - raise ValueError(f"Invalid effect_mode: {mode}") - - # pylint: enable=protected-access - - def _detuple(arg, var: rx.Var, builder: BlockBuilder): - if isinstance(arg, Tuple): - ret = [] - for i, elem in enumerate(arg.elements): - field = builder.emit(rx.TupleGetItem(var, i), name_hint=f"{arg.name}_{i}") - ret.append(_detuple(elem, field, builder)) - return type(arg.elements)(ret) - if isinstance(arg, core.Tensor): - return core.Tensor(_expr=var) - if isinstance(arg, tir.Var): - return arg - raise TypeError(f"Unsupported input type: {type(arg)}") - - # TODO(@junrushao): Warn if params/effects are used when their mode is "none" - explicit_inputs = spec.as_inputs() - inputs = [_convert_input(x) for x in explicit_inputs] - inputs = inputs + _effects(spec.effect_mode) - inputs = inputs + _params(spec.param_mode) - - for arg_idx, (arg, var) in enumerate(zip(explicit_inputs, inputs)): - if isinstance(arg, Tuple): - explicit_inputs[arg_idx] = _detuple(arg, var, builder) - - outputs = spec.method(*explicit_inputs) - effect_outputs = [] - for _, effect in effects: - effect_outputs.extend(effect.finalize()) - if effect_outputs and spec.effect_mode != "none": - outputs = builder.emit_output(rx.Tuple([_unwrap_ret(outputs), rx.Tuple(effect_outputs)])) - else: - outputs = builder.emit_output(_unwrap_ret(outputs)) - return outputs, inputs diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py index b52ed75827d2..19316c76b83d 100644 --- a/python/tvm/relax/transform/__init__.py +++ b/python/tvm/relax/transform/__init__.py @@ -78,6 +78,7 @@ from .optimize_layout_transform import OptimizeLayoutTransform from .remove_redundant_reshape import RemoveRedundantReshape from .fast_math import FastMathTransform +from .attach_external_modules import AttachExternModules # Import to register the legalization functions. from . import legalize_ops, tuning_api diff --git a/python/tvm/relax/transform/attach_external_modules.py b/python/tvm/relax/transform/attach_external_modules.py new file mode 100644 index 000000000000..f4ca4075f94d --- /dev/null +++ b/python/tvm/relax/transform/attach_external_modules.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A pass that attaches external modules to the IRModule. + +Note: "external modules" here refers to `relax.frontend.nn.ExternModule`. +""" +from typing import TYPE_CHECKING, List + +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass + +if TYPE_CHECKING: + from tvm.relax.frontend.nn import ExternModule + + +@module_pass(opt_level=0, name="AttachExternalModules") +class AttachExternModules: # pylint: disable=too-few-public-methods + """Attach variable bounds to each Relax function, which primarily helps with memory planning.""" + + def __init__(self, extern_modules: List["ExternModule"]): + self.extern_modules = extern_modules + + def transform_module(self, mod: IRModule, _ctx: PassContext) -> IRModule: + """Entrypoint""" + from tvm.relax.frontend.nn import ( # pylint: disable=import-outside-toplevel + ExternModule, + ) + + def _load(ext_mod: ExternModule): + assert isinstance(ext_mod, ExternModule), f"Expected ExternModule, but got: {ext_mod}" + return ext_mod.load() + + mod_attrs = dict(mod.attrs) if mod.attrs else {} + external_mods = mod_attrs.get("external_mods", []) + for ext in self.extern_modules: + external_mods.append(_load(ext)) + mod = mod.with_attr("external_mods", external_mods) + return mod diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index e4d1fefbe7a3..62760b341711 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -317,13 +317,12 @@ def _extract_attrs(mod: tvm.IRModule): if not params: params = {} - ext_libs, constants = _extract_attrs(mod) - params.update(dict(constants)) - if pipeline is not None: if isinstance(pipeline, str): pipeline = relax.get_pipeline(pipeline) mod = pipeline(mod) + ext_libs, constants = _extract_attrs(mod) + params.update(dict(constants)) builder = relax.ExecBuilder() mod = _vmcodegen(builder, mod, exec_mode) return _vmlink( diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index d6902bc62574..d66f5ee951a5 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -18,6 +18,7 @@ """Runtime NDArray API""" import ctypes import warnings + import numpy as np try: @@ -25,24 +26,38 @@ except ImportError: ml_dtypes = None import tvm._ffi +from tvm._ffi.base import _FFI_MODE, _LIB, c_array, check_call, string_types +from tvm._ffi.runtime_ctypes import ( + DataType, + DataTypeCode, + Device, + TVMArray, + TVMArrayHandle, + tvm_shape_index_t, +) -from tvm._ffi.base import _LIB, check_call, c_array, string_types, _FFI_MODE -from tvm._ffi.runtime_ctypes import DataType, Device, TVMArray, TVMArrayHandle -from tvm._ffi.runtime_ctypes import DataTypeCode, tvm_shape_index_t from . import _ffi_api try: # pylint: disable=wrong-import-position if _FFI_MODE == "ctypes": raise ImportError() - from tvm._ffi._cy3.core import _set_class_ndarray, _make_array, _from_dlpack - from tvm._ffi._cy3.core import NDArrayBase + from tvm._ffi._cy3.core import ( + NDArrayBase, + _from_dlpack, + _make_array, + _set_class_ndarray, + ) except (RuntimeError, ImportError) as error: # pylint: disable=wrong-import-position if _FFI_MODE == "cython": raise error - from tvm._ffi._ctypes.ndarray import _set_class_ndarray, _make_array, _from_dlpack - from tvm._ffi._ctypes.ndarray import NDArrayBase + from tvm._ffi._ctypes.ndarray import ( + NDArrayBase, + _from_dlpack, + _make_array, + _set_class_ndarray, + ) @tvm._ffi.register_object("runtime.NDArray") @@ -324,6 +339,8 @@ def device(dev_type, dev_id=0): assert tvm.device("cpu", 1) == tvm.cpu(1) assert tvm.device("cuda", 0) == tvm.cuda(0) """ + if isinstance(dev_type, Device): + return dev_type if not isinstance(dev_id, int): raise ValueError(f"Invalid device id: {dev_id}") diff --git a/python/tvm/target/detect_target.py b/python/tvm/target/detect_target.py index 0cfcf17e6a13..aada61164215 100644 --- a/python/tvm/target/detect_target.py +++ b/python/tvm/target/detect_target.py @@ -17,10 +17,10 @@ """Detect target.""" from typing import Union -from . import Target from .._ffi import get_global_func from .._ffi.runtime_ctypes import Device from ..runtime.ndarray import device +from . import Target def _detect_metal(dev: Device) -> Target: @@ -74,6 +74,23 @@ def _detect_vulkan(dev: Device) -> Target: ) +def _detect_cpu(dev: Device) -> Target: # pylint: disable=unused-argument + """Detect the host CPU architecture.""" + return Target( + { + "kind": "llvm", + "mtriple": get_global_func( + "tvm.codegen.llvm.GetDefaultTargetTriple", + allow_missing=False, + )(), + "mcpu": get_global_func( + "tvm.codegen.llvm.GetHostCPUName", + allow_missing=False, + )(), + } + ) + + def detect_target_from_device(dev: Union[str, Device]) -> Target: """Detects Target associated with the given device. If the device does not exist, there will be an Error. @@ -102,11 +119,11 @@ def detect_target_from_device(dev: Union[str, Device]) -> Target: f"Cannot detect device `{dev}`. Please make sure the device and its driver " "is installed properly, and TVM is compiled with the driver" ) - target = SUPPORT_DEVICE[device_type](dev) - return target + return SUPPORT_DEVICE[device_type](dev) SUPPORT_DEVICE = { + "cpu": _detect_cpu, "cuda": _detect_cuda, "metal": _detect_metal, "vulkan": _detect_vulkan, diff --git a/tests/python/relax/frontend_nn_extern_module.cc b/tests/python/relax/frontend_nn_extern_module.cc new file mode 100644 index 000000000000..09adbe9780d6 --- /dev/null +++ b/tests/python/relax/frontend_nn_extern_module.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file test_extern_module.cc + * \brief Testing code to be compiled by Relax nn.SourceModule + */ +#include +#include +#include + +namespace { + +int _scalar_add(DLTensor* a, DLTensor* b, DLTensor* c) { + using namespace tvm::runtime; + ICHECK(a->ndim == 0); + ICHECK(b->ndim == 0); + ICHECK(c->ndim == 0); + ICHECK(DataType(a->dtype) == DataType::Float(32)); + ICHECK(DataType(b->dtype) == DataType::Float(32)); + ICHECK(DataType(c->dtype) == DataType::Float(32)); + float* a_data = static_cast(a->data); + float* b_data = static_cast(b->data); + float* c_data = static_cast(c->data); + *c_data = *a_data + *b_data; + return 0; +} + +int _test_sym(DLTensor* a, DLTensor* b, DLTensor* c) { + using namespace tvm::runtime; + ICHECK(a->ndim == 3); // [x, y, 1] + ICHECK(b->ndim == 3); // [y, z, 5] + ICHECK(c->ndim == 4); // [x, y, z, 9] + ICHECK(DataType(a->dtype) == DataType::Float(32)); + ICHECK(DataType(b->dtype) == DataType::Float(32)); + ICHECK(DataType(c->dtype) == DataType::Float(32)); + int x = a->shape[0]; + int y = a->shape[1]; + int z = b->shape[1]; + ICHECK(a->shape[0] == x); + ICHECK(a->shape[1] == y); + ICHECK(a->shape[2] == 1); + ICHECK(b->shape[0] == y); + ICHECK(b->shape[1] == z); + ICHECK(b->shape[2] == 5); + ICHECK(c->shape[0] == x); + ICHECK(c->shape[1] == y); + ICHECK(c->shape[2] == z); + ICHECK(c->shape[3] == 9); + return 0; +} +} // namespace +TVM_DLL_EXPORT_TYPED_FUNC(ext_scalar_add, _scalar_add); +TVM_DLL_EXPORT_TYPED_FUNC(ext_test_sym, _test_sym); diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index 2f00284a9141..6eaf1fbfc805 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -15,7 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring +import subprocess import tempfile +from pathlib import Path import numpy as np @@ -24,125 +26,62 @@ from tvm import relax from tvm.relax.frontend import nn from tvm.relax.frontend.nn import spec -from tvm.script import ir as I -from tvm.script import relax as R -from tvm.script import tir as T - -SOURCE_CODE = """ -#include -#include -#include - -namespace { - -int _scalar_add(DLTensor* a, DLTensor* b, DLTensor* c) { - using namespace tvm::runtime; - ICHECK(a->ndim == 0); - ICHECK(b->ndim == 0); - ICHECK(c->ndim == 0); - ICHECK(DataType(a->dtype) == DataType::Float(32)); - ICHECK(DataType(b->dtype) == DataType::Float(32)); - ICHECK(DataType(c->dtype) == DataType::Float(32)); - float* a_data = static_cast(a->data); - float* b_data = static_cast(b->data); - float* c_data = static_cast(c->data); - *c_data = *a_data + *b_data; - return 0; -} - -int _test_sym(DLTensor* a, DLTensor* b, DLTensor* c) { - using namespace tvm::runtime; - ICHECK(a->ndim == 3); - ICHECK(b->ndim == 3); - ICHECK(c->ndim == 4); - ICHECK(DataType(a->dtype) == DataType::Float(32)); - ICHECK(DataType(b->dtype) == DataType::Float(32)); - ICHECK(DataType(c->dtype) == DataType::Float(32)); - int x = a->shape[0]; - int y = a->shape[1]; - int z = b->shape[1]; - ICHECK(a->shape[0] == x); - ICHECK(a->shape[1] == y); - ICHECK(a->shape[2] == 1); - ICHECK(b->shape[0] == y); - ICHECK(b->shape[1] == z); - ICHECK(b->shape[2] == 5); - ICHECK(c->shape[0] == x); - ICHECK(c->shape[1] == y); - ICHECK(c->shape[2] == z); - ICHECK(c->shape[3] == 9); - return 0; -} - -} -TVM_DLL_EXPORT_TYPED_FUNC(ext_scalar_add, _scalar_add); -TVM_DLL_EXPORT_TYPED_FUNC(ext_test_sym, _test_sym); -""" - - -def test_extern_module(): - shape_a = ("x", "y", 1) - shape_b = ("y", "z", 5) - shape_c = ("x", "y", "z", 9) - dtype = "float32" - - class MyExtMod(nn.SourceModule): - def __init__(self): - super().__init__( - source_code=SOURCE_CODE, - source_format="cpp", - functions={ - "ext_scalar_add": spec.ExternFunctionSpec( - args=[ - spec.Tensor((), dtype), - spec.Tensor((), dtype), - ], - ret=spec.Tensor((), dtype), - ), - "ext_test_sym": spec.ExternFunctionSpec( - args=[ - spec.Tensor(shape_a, dtype), - spec.Tensor(shape_b, dtype), - ], - ret=spec.Tensor(shape_c, dtype), - ), - }, - compile_options=None, - compiler=None, - output_format="obj", - ) - - def scalar_add(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name - return self.get_extern_func("ext_scalar_add")(a, b) - - def test_sym(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name - return self.get_extern_func("ext_test_sym")(a, b) - - my_ext_mod = MyExtMod() - - class TestModule(nn.Module): - def __init__(self) -> None: - self.extern_matmul = my_ext_mod - - def scalar_add(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name - return self.extern_matmul.scalar_add(a, b) - - def test_sym(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name - return self.extern_matmul.test_sym(a, b) - - model = TestModule() - ir_module, _ = model.export_tvm( - spec={ - "scalar_add": { - "a": spec.Tensor((), dtype), - "b": spec.Tensor((), dtype), - }, - "test_sym": { - "a": spec.Tensor(shape_a, dtype), - "b": spec.Tensor(shape_b, dtype), - }, - } - ) +from tvm.relax.transform import AttachExternModules + + +def _infer_scalar_add(x, y): # pylint: disable=invalid-name + assert isinstance(x, nn.Tensor) + assert isinstance(y, nn.Tensor) + assert x.ndim == 0 and x.dtype == "float32" + assert y.ndim == 0 and y.dtype == "float32" + return nn.Tensor.placeholder(shape=(), dtype="float32") + + +def _infer_test_sym(a, b): # pylint: disable=invalid-name + def _var_equal(a, b): # pylint: disable=invalid-name + return tvm.ir.structural_equal(a, b, map_free_vars=True) + + assert isinstance(a, nn.Tensor) + assert isinstance(b, nn.Tensor) + assert a.ndim == 3 and a.dtype == "float32" # [x, y, 1] + assert b.ndim == 3 and b.dtype == "float32" # [y, z, 5] + x, y, z = a.shape[0], b.shape[0], b.shape[1] # pylint: disable=invalid-name + assert _var_equal(a.shape[0], x) + assert _var_equal(a.shape[1], y) + assert a.shape[2] == 1 + assert _var_equal(b.shape[0], y) + assert _var_equal(b.shape[1], z) + assert b.shape[2] == 5 + return nn.Tensor.placeholder(shape=(x, y, z, 9), dtype="float32") + + +def _test_scalar_add(func): + # pylint: disable=invalid-name + x = tvm.nd.array(np.array(1.0).astype("float32")) + y = tvm.nd.array(np.array(3.0).astype("float32")) + z = func(x, y).numpy() + # pylint: enable=invalid-name + assert z.ndim == 0 + assert z.dtype == "float32" + assert float(z) == 4.0 + + +def _test_infer_sym(func, x, y, z): # pylint: disable=invalid-name + # pylint: disable=invalid-name + a = tvm.nd.array(np.random.uniform(size=(x, y, 1)).astype("float32")) + b = tvm.nd.array(np.random.uniform(size=(y, z, 5)).astype("float32")) + c = func(a, b).numpy() + # pylint: enable=invalid-name + assert c.shape == (x, y, z, 9) + + +def _check_ir_equality(mod): + # pylint: disable=import-outside-toplevel + from tvm.script import ir as I + from tvm.script import relax as R + from tvm.script import tir as T + + # pylint: enable=import-outside-toplevel @I.ir_module class ExpectedModule: @@ -163,7 +102,9 @@ def scalar_add( def test_sym( a: R.Tensor(("x", "y", 1), dtype="float32"), b: R.Tensor(("y", "z", 5), dtype="float32") ) -> R.Tensor(("x", "y", "z", 9), dtype="float32"): - x, y, z = T.int64(), T.int64(), T.int64() + x = T.int64() + y = T.int64() + z = T.int64() R.func_attr({"num_input": 2}) with R.dataflow(): ext_test_sym = R.call_dps_packed( @@ -173,84 +114,138 @@ def test_sym( R.output(gv1) return gv1 - tvm.ir.assert_structural_equal(ir_module["scalar_add"], ExpectedModule["scalar_add"]) - tvm.ir.assert_structural_equal(ir_module["test_sym"], ExpectedModule["test_sym"]) - assert len(ir_module.attrs["external_mods"]) == 1 - assert ir_module.attrs["external_mods"][0].type_key == "static_library" - - scalar_a = tvm.nd.array(np.array(1.0, dtype="float32")) - scalar_b = tvm.nd.array(np.array(3.0, dtype="float32")) + tvm.ir.assert_structural_equal(ExpectedModule, mod) + + +def _compile_cc(src: Path, dst: Path): + # pylint: disable=import-outside-toplevel + from tvm._ffi.base import py_str + from tvm._ffi.libinfo import find_include_path + + # pylint: enable=import-outside-toplevel + + cmd = ["g++", str(src)] + for include_path in find_include_path(): + cmd += ["-I", include_path] + cmd += [ + "-DDMLC_USE_FOPEN64=0", + "-DDMLC_USE_LOGGING_LIBRARY=", + "-c", + "-fPIC", + "-o", + str(dst), + ] + with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc: + (out, _) = proc.communicate() + if proc.returncode != 0: + msg = "Compilation error:\n" + msg += py_str(out) + msg += "\nCommand line: " + " ".join(cmd) + raise RuntimeError(msg) + + +def test_extern_object(): + with tempfile.TemporaryDirectory() as temp_dir_str: + path = Path(temp_dir_str) / "main.o" + _compile_cc( + src=Path(__file__).parent / "frontend_nn_extern_module.cc", + dst=path, + ) - with tempfile.TemporaryDirectory() as temp_dir: - output_path = temp_dir + "/lib.so" - relax.build(ir_module, target="llvm").export_library(output_path) + class TestModule(nn.Module): + def __init__(self): + self.ext_mod = None + + def _get_ext_mod(self): + if self.ext_mod is None: + self.ext_mod = nn.ObjectModule( + { + "ext_scalar_add": _infer_scalar_add, + "ext_test_sym": _infer_test_sym, + }, + path, + ) + nn.add_extern(self.ext_mod) + return self.ext_mod + + def scalar_add(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name + return self._get_ext_mod()["ext_scalar_add"](a, b) + + def test_sym(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name + return self._get_ext_mod()["ext_test_sym"](a, b) + + mod, _, ext_mods = TestModule().export_tvm( + spec={ + "scalar_add": { + "a": spec.Tensor((), "float32"), + "b": spec.Tensor((), "float32"), + }, + "test_sym": { + "a": spec.Tensor(("x", "y", 1), "float32"), + "b": spec.Tensor(("y", "z", 5), "float32"), + }, + }, + allow_extern=True, + ) + _check_ir_equality(mod) + mod = AttachExternModules(ext_mods)(mod) # pylint: disable=not-callable compiled = tvm.runtime.relax_vm.VirtualMachine( - tvm.runtime.load_module(output_path), + relax.build(mod, target="llvm"), device=tvm.cpu(), ) - scalar_c = compiled["scalar_add"](scalar_a, scalar_b) + _test_scalar_add(compiled["scalar_add"]) + _test_infer_sym(compiled["test_sym"], x=3, y=4, z=2) + +def test_extern_source(): + source = Path(__file__).parent / "frontend_nn_extern_module.cc" -def test_extern_spec(): class TestModule(nn.Module): - def __init__(self) -> None: - self.ext_mod = nn.ExternModule( - spec.ExternModuleSpec( - library=tvm.runtime.Module(None), - functions=[ - spec.ExternFunctionSpec( - args=[ - spec.Tensor((2, 4), "float16"), - spec.ConstInt(), - spec.ConstInt("int32"), - spec.ConstFloat(), - spec.ConstFloat("float16"), - spec.ConstString(), - ], - ret=spec.Tensor((2, 4), "float16"), - symbol="test", - ) - ], + def __init__(self): + self.ext_mod = None + + def _get_ext_mod(self): + if self.ext_mod is None: + self.ext_mod = nn.SourceModule( + { + "ext_scalar_add": _infer_scalar_add, + "ext_test_sym": _infer_test_sym, + }, + source_code=source, + source_format="cpp", ) - ) - - def forward(self, x: nn.Tensor): - return self.ext_mod.get_extern_func("test")(x, 1, 2, 3.0, 4.0, "123") + nn.add_extern(self.ext_mod) + return self.ext_mod - @I.ir_module - class ExpectedModule: - I.module_attrs({"external_mods": [None]}) + def scalar_add(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name + return self._get_ext_mod()["ext_scalar_add"](a, b) - @R.function - def forward(x: R.Tensor((2, 4), dtype="float16")) -> R.Tensor((2, 4), dtype="float16"): - R.func_attr({"num_input": 1}) - with R.dataflow(): - test = R.call_dps_packed( - "test", - ( - x, - R.prim_value(1), - R.prim_value(T.int32(2)), - R.prim_value(T.float32(3)), - R.prim_value(T.float16(4)), - R.str("123"), - ), - out_sinfo=R.Tensor((2, 4), dtype="float16"), - ) - gv: R.Tensor((2, 4), dtype="float16") = test - R.output(gv) - return gv + def test_sym(self, a: nn.Tensor, b: nn.Tensor): # pylint: disable=invalid-name + return self._get_ext_mod()["ext_test_sym"](a, b) - model = TestModule() - ir_module, _ = model.export_tvm( + mod, _, ext_mods = TestModule().export_tvm( spec={ - "forward": { - "x": spec.Tensor((2, 4), "float16"), + "scalar_add": { + "a": spec.Tensor((), "float32"), + "b": spec.Tensor((), "float32"), }, - } + "test_sym": { + "a": spec.Tensor(("x", "y", 1), "float32"), + "b": spec.Tensor(("y", "z", 5), "float32"), + }, + }, + allow_extern=True, + ) + _check_ir_equality(mod) + mod = AttachExternModules(ext_mods)(mod) # pylint: disable=not-callable + compiled = tvm.runtime.relax_vm.VirtualMachine( + relax.build(mod, target="llvm"), + device=tvm.cpu(), ) - tvm.ir.assert_structural_equal(ir_module, ExpectedModule) + _test_scalar_add(compiled["scalar_add"]) + _test_infer_sym(compiled["test_sym"], x=3, y=4, z=2) if __name__ == "__main__": - tvm.testing.main() + test_extern_object() + test_extern_source() diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index ed1c851815bb..61fe95bccbfe 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -14,16 +14,17 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from typing import List, Tuple + import numpy as np import pytest -from typing import Tuple, List import tvm import tvm.testing from tvm import relax from tvm.ir import assert_structural_equal -from tvm.relax.frontend.nn import core, modules, spec from tvm.relax.frontend import nn +from tvm.relax.frontend.nn import core, modules, spec from tvm.script import ir as I from tvm.script import relax as R @@ -121,56 +122,6 @@ def forward( assert_structural_equal(tvm_mod["forward"], forward, True) -def test_multi_linear(): - @R.function - def forward( - x: R.Tensor((3, 5, 4), dtype="float32"), - weight: R.Tensor((60, 4), dtype="float32"), - ) -> R.Tuple( - R.Tensor((3, 5, 4), dtype="float32"), - R.Tensor((3, 5, 8), dtype="float32"), - R.Tensor((3, 5, 16), dtype="float32"), - R.Tensor((3, 5, 32), dtype="float32"), - ): - R.func_attr({"num_input": 1}) - with R.dataflow(): - permute_dims: R.Tensor((4, 60), dtype="float32") = R.permute_dims(weight, axes=None) - matmul: R.Tensor((3, 5, 60), dtype="float32") = R.matmul( - x, permute_dims, out_dtype="void" - ) - split: R.Tuple( - R.Tensor((3, 5, 4), dtype="float32"), - R.Tensor((3, 5, 8), dtype="float32"), - R.Tensor((3, 5, 16), dtype="float32"), - R.Tensor((3, 5, 32), dtype="float32"), - ) = R.split(matmul, indices_or_sections=[4, 12, 28], axis=-1) - split_0: R.Tensor((3, 5, 4), dtype="float32") = split[0] - split_1: R.Tensor((3, 5, 8), dtype="float32") = split[1] - split_2: R.Tensor((3, 5, 16), dtype="float32") = split[2] - split_3: R.Tensor((3, 5, 32), dtype="float32") = split[3] - gv: R.Tuple( - R.Tensor((3, 5, 4), dtype="float32"), - R.Tensor((3, 5, 8), dtype="float32"), - R.Tensor((3, 5, 16), dtype="float32"), - R.Tensor((3, 5, 32), dtype="float32"), - ) = (split_0, split_1, split_2, split_3) - R.output(gv) - return gv - - mod = modules.MultiLinear( - in_features=4, - out_features=[4, 8, 16, 32], - bias=False, - ) - tvm_mod, _ = mod.export_tvm( - spec={ - "forward": {"x": spec.Tensor((3, 5, 4), "float32")}, - }, - debug=False, - ) - assert_structural_equal(tvm_mod["forward"], forward, True) - - def test_conv1d(): @R.function def forward(