diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index 58f923512d2e..b89bedbc6d45 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -43,7 +43,7 @@ #define TVM_CRT_MAX_REGISTERED_MODULES 2 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 /*! Maximum packet size, in bytes, including the length header. */ #define TVM_CRT_MAX_PACKET_SIZE_BYTES 512 diff --git a/apps/microtvm/pyproject.toml b/apps/microtvm/pyproject.toml index 8bfae0a157cd..98c769be48f5 100644 --- a/apps/microtvm/pyproject.toml +++ b/apps/microtvm/pyproject.toml @@ -111,6 +111,7 @@ tensorflow-estimator = {version = "^2.1", optional = true} # TFLite frontend tflite = {version = "2.1.0", optional = true} wheel = "*" +cloudpickle = "^1.6.0" [tool.poetry.extras] diff --git a/include/tvm/runtime/crt/error_codes.h b/include/tvm/runtime/crt/error_codes.h index d1a8619e8233..776691c4c7fc 100644 --- a/include/tvm/runtime/crt/error_codes.h +++ b/include/tvm/runtime/crt/error_codes.h @@ -93,6 +93,7 @@ typedef enum { kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0), kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1), kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2), + kTvmErrorFunctionCallInvalidArg = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 3), // Time Evaluator - times functions for use with debug runtime. kTvmErrorTimeEvaluatorBadHandle = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryTimeEvaluator, 0), diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 55a228882691..57374c54b297 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -68,6 +68,9 @@ # Contrib initializers from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel +if support.libinfo().get("USE_MICRO", "OFF") == "ON": + from . import micro + # NOTE: This file should be python2 compatible so we can # raise proper error message when user run the package using # an older version of the python diff --git a/python/tvm/autotvm/measure/measure.py b/python/tvm/autotvm/measure/measure.py index 8438b807d46e..ea7de35ad9e8 100644 --- a/python/tvm/autotvm/measure/measure.py +++ b/python/tvm/autotvm/measure/measure.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name """User facing API for specifying how to measure the generated code""" +import enum import multiprocessing from collections import namedtuple @@ -52,8 +53,19 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost" The absolute time stamp when we finish measurement. """ + def __repr__(self): + error_no_str = ( + str(self.error_no) + if self.error_no not in MeasureErrorNo + else str(MeasureErrorNo(self.error_no)) + ) + return ( + f"{self.__class__.__name__}(costs={self.costs!r}, error_no={error_no_str}, " + f"all_cost={self.all_cost}, timestamp={self.timestamp!r})" + ) -class MeasureErrorNo(object): + +class MeasureErrorNo(enum.IntEnum): """Error type for MeasureResult""" NO_ERROR = 0 # no error @@ -77,12 +89,15 @@ class Builder(object): n_parallel: int, optional The number of tasks submitted in parallel By default it will use all cpu cores + build_kwargs: dict, optional + Keyword args given to the build function. """ - def __init__(self, timeout=10, n_parallel=None): + def __init__(self, timeout=10, n_parallel=None, build_kwargs=None): self.timeout = timeout self.n_parallel = n_parallel or multiprocessing.cpu_count() - self.build_kwargs = {} + self.user_build_kwargs = build_kwargs if build_kwargs is not None else {} + self.runner_build_kwargs = None self.task = None def set_task(self, task, build_kwargs=None): @@ -97,7 +112,17 @@ def set_task(self, task, build_kwargs=None): The additional kwargs for build function """ self.task = task - self.build_kwargs = build_kwargs + self.build_kwargs = dict(build_kwargs.items()) if build_kwargs is not None else {} + if any(k in self.build_kwargs for k in self.user_build_kwargs): + logging.warn( + "Overriding these runner-supplied kwargs with user-supplied:\n%s", + "\n".join( + f" * {k}: from {build_kwargs[k]!r} to {self.user_build_kwargs[k]!r}" + for k in sorted([k for k in build_kwargs if k in self.user_build_kwargs]) + ), + ) + for k, v in self.user_build_kwargs.items(): + self.build_kwargs[k] = v def build(self, measure_inputs): """Build programs diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 42e046aefb4a..efe45daa1464 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -79,15 +79,22 @@ class LocalBuilder(Builder): The timeout of a compilation n_parallel: int The number of tasks run in parallel. "None" will use all cpu cores + build_kwargs: dict + If supplied, additional kwargs passed to build_func. Overrides any build_kwargs supplied + by the Runner. build_func: callable or str If is 'default', use default build function If is 'ndk', use function for android ndk If id 'stackvm', use function for stackvm If is callable, use it as custom build function, expect lib_format field. + do_fork: bool + If False, do not fork when building. Requires n_parallel=1. """ - def __init__(self, timeout=10, n_parallel=None, build_func="default"): - super(LocalBuilder, self).__init__(timeout, n_parallel) + def __init__( + self, timeout=10, n_parallel=None, build_kwargs=None, build_func="default", do_fork=False + ): + super(LocalBuilder, self).__init__(timeout, n_parallel, build_kwargs) if isinstance(build_func, str): if build_func == "default": @@ -99,6 +106,11 @@ def __init__(self, timeout=10, n_parallel=None, build_func="default"): else: raise ValueError("Invalid build_func" + build_func) self.build_func = _WrappedBuildFunc(build_func) + if not do_fork: + assert n_parallel in ( + None, + 1, + ), f"if do_fork=False, need n_parallel=None or 1; got {n_parallel}" self.executor = PopenPoolExecutor( timeout=timeout, initializer=reset_global_scope, initargs=(AutotvmGlobalScope.current,) ) @@ -518,7 +530,16 @@ def __call__(self, measure_input, tmp_dir, **kwargs): ) # TODO(tvm-team) consider linline _build_func_common func, arg_info = _build_func_common(measure_input, **kwargs) - func.export_library(filename, self.build_func) + if self.build_func.output_format == ".model-library-format": + # Late import to preserve autoTVM with USE_MICRO OFF + try: + from tvm import micro # pylint: disable=import-outside-toplevel + except ImportError: + raise ImportError("Requires USE_MICRO") + + micro.export_model_library_format(func, filename) + else: + func.export_library(filename, self.build_func) except Exception as e: # pylint: disable=broad-except return BuildResult(None, None, e, time.time() - tic) return BuildResult(filename, arg_info, None, time.time() - tic) diff --git a/python/tvm/autotvm/tuner/callback.py b/python/tvm/autotvm/tuner/callback.py index dc75de206d05..40ee24e077b4 100644 --- a/python/tvm/autotvm/tuner/callback.py +++ b/python/tvm/autotvm/tuner/callback.py @@ -145,8 +145,8 @@ def __del__(self): if logger.level < logging.DEBUG: # only print progress bar in non-debug mode sys.stdout.write( - "\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) " - "| %.2f s" % (prefix, 0, 0, 0, total, time.time() - tic) + "\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) " + "| %.2f s" % (prefix, 0, 0, si_prefix, 0, total, time.time() - tic) ) sys.stdout.flush() diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py index 88dcde8ceaf0..2aea9d3fd61d 100644 --- a/python/tvm/micro/__init__.py +++ b/python/tvm/micro/__init__.py @@ -16,6 +16,8 @@ # under the License. """MicroTVM module for bare-metal backends""" +from .build import autotvm_build_func +from .build import AutoTvmModuleLoader from .build import get_standalone_crt_dir from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError from .project import generate_project, GeneratedProject, TemplateProject diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 16e7ed24cb4f..7da9daf958c6 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -17,10 +17,13 @@ """Defines top-level glue functions for building microTVM artifacts.""" +import json import logging import os +import pathlib from .._ffi import libinfo +from .. import rpc as _rpc _LOG = logging.getLogger(__name__) @@ -57,3 +60,55 @@ def get_standalone_crt_dir() -> str: raise CrtNotFoundError() return STANDALONE_CRT_DIR + + +class AutoTvmModuleLoader: + """MicroTVM AutoTVM Module Loader + + Parameters + ---------- + template_project_dir : str + project template path + + project_options : dict + project generation option + """ + + def __init__(self, template_project_dir: str, project_options: dict = None): + self._project_options = project_options + + if isinstance(template_project_dir, pathlib.Path): + self._template_project_dir = str(template_project_dir) + elif not isinstance(template_project_dir, str): + raise TypeError(f"Incorrect type {type(template_project_dir)}.") + + def __call__(self, remote_kw, build_result): + with open(build_result.filename, "rb") as build_file: + build_result_bin = build_file.read() + + tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"]) + remote = tracker.request( + remote_kw["device_key"], + priority=remote_kw["priority"], + session_timeout=remote_kw["timeout"], + session_constructor_args=[ + "tvm.micro.compile_and_create_micro_session", + build_result_bin, + self._template_project_dir, + json.dumps(self._project_options), + ], + ) + system_lib = remote.get_function("runtime.SystemLib")() + yield remote, system_lib + try: + remote.get_function("tvm.micro.destroy_micro_session")() + except tvm.error.TVMError as exception: + _LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1) + + +def autotvm_build_func(): + """A dummy build function which causes autotvm to use a different export format.""" + + +# A sentinel value for the output format. +autotvm_build_func.output_format = ".model-library-format" diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py index b1f2b49d972e..8a62c9b5f9ba 100644 --- a/python/tvm/micro/project.py +++ b/python/tvm/micro/project.py @@ -101,14 +101,9 @@ def __init__(self, api_client): if not self._info["is_template"]: raise NotATemplateProjectError() - def generate_project(self, graph_executor_factory, project_dir, options): - """Generate a project given GraphRuntimeFactory.""" - model_library_dir = utils.tempdir() - model_library_format_path = model_library_dir.relpath("model.tar") - export_model_library_format(graph_executor_factory, model_library_format_path) - + def generate_project_from_mlf(self, model_library_format_path, project_dir, options): self._api_client.generate_project( - model_library_format_path=model_library_format_path, + model_library_format_path=str(model_library_format_path), standalone_crt_dir=get_standalone_crt_dir(), project_dir=project_dir, options=options, @@ -119,6 +114,14 @@ def generate_project(self, graph_executor_factory, project_dir, options): def info(self): return self._info + def generate_project(self, graph_executor_factory, project_dir, options): + """Generate a project given GraphRuntimeFactory.""" + model_library_dir = utils.tempdir() + model_library_format_path = model_library_dir.relpath("model.tar") + export_model_library_format(graph_executor_factory, model_library_format_path) + + return self.generate_project_from_mlf(model_library_format_path, project_dir, options) + def generate_project( template_project_dir: typing.Union[pathlib.Path, str], diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index d4ad5b84fb76..abe7aff766e2 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -17,14 +17,17 @@ """Defines a top-level glue class that operates the Transport and Flasher classes.""" +import json import logging import sys from ..error import register_error -from .._ffi import get_global_func +from .._ffi import get_global_func, register_func from ..contrib import graph_executor +from ..contrib import utils from ..contrib.debugger import debug_executor from ..rpc import RPCSession +from . import project from .transport import IoTimeoutError from .transport import TransportLogger @@ -234,3 +237,71 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None): graph_json_str, dump_root=dump_root, ) + + +RPC_SESSION = None + + +@register_func("tvm.micro.compile_and_create_micro_session") +def compile_and_create_micro_session( + mod_src_bytes: bytes, + template_project_dir: str, + project_options: dict = None, +): + """Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session. + + Parameters + ---------- + mod_src_bytes : bytes + The content of a tarfile which contains the TVM-generated sources which together form the + SystemLib. This tar is expected to be created by export_library. The tar will be extracted + into a directory and the sources compiled into a MicroLibrary using the Compiler. + + template_project_dir: str + The path to a template microTVM Project API project which is used to generate the embedded + project that is built and flashed onto the target device. + + project_options: dict + Options for the microTVM API Server contained in template_project_dir. + """ + global RPC_SESSION + + temp_dir = utils.tempdir() + # Keep temp directory for generate project + temp_dir.set_keep_for_debug(True) + model_library_format_path = temp_dir / "model.tar.gz" + with open(model_library_format_path, "wb") as mlf_f: + mlf_f.write(mod_src_bytes) + + try: + template_project = project.TemplateProject.from_directory(template_project_dir) + generated_project = template_project.generate_project_from_mlf( + model_library_format_path, + temp_dir / "generated-project", + options=json.loads(project_options), + ) + except Exception as exception: + logging.error("Project Generate Error: %s", str(exception)) + raise exception + + generated_project.build() + generated_project.flash() + transport = generated_project.transport() + + RPC_SESSION = Session(transport_context_manager=transport) + RPC_SESSION.__enter__() + return RPC_SESSION._rpc._sess + + +@register_func +def destroy_micro_session(): + """Destroy RPC session for microTVM autotune.""" + global RPC_SESSION + + if RPC_SESSION is not None: + exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None) + RPC_SESSION = None + if (exc_type, exc_value, traceback) != (None, None, None): + exc = exc_type(exc_value) # See PEP 3109 + exc.__traceback__ = traceback + raise exc diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index a9834391ed88..045bf7904885 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -366,7 +366,9 @@ def text_summary(self): res += separate_line return res - def request(self, key, priority=1, session_timeout=0, max_retry=5): + def request( + self, key, priority=1, session_timeout=0, max_retry=5, session_constructor_args=None + ): """Request a new connection from the tracker. Parameters @@ -384,6 +386,11 @@ def request(self, key, priority=1, session_timeout=0, max_retry=5): max_retry : int, optional Maximum number of times to retry before give up. + + session_constructor_args : list, optional + List of additional arguments to passed as the remote session constructor. + The first element of the list is always a string specifying the name of + the session constructor, the following args are the positional args to that function. """ last_err = None for _ in range(max_retry): @@ -395,7 +402,13 @@ def request(self, key, priority=1, session_timeout=0, max_retry=5): if value[0] != base.TrackerCode.SUCCESS: raise RuntimeError("Invalid return value %s" % str(value)) url, port, matchkey = value[1] - return connect(url, port, matchkey, session_timeout) + return connect( + url, + port, + matchkey, + session_timeout, + session_constructor_args=session_constructor_args, + ) except socket.error as err: self.close() last_err = err diff --git a/python/tvm/support.py b/python/tvm/support.py index 800bfe4e2546..1adbee09c52c 100644 --- a/python/tvm/support.py +++ b/python/tvm/support.py @@ -29,7 +29,14 @@ def libinfo(): info: Dict[str, str] The dictionary of compile-time info. """ - return {k: v for k, v in GetLibInfo().items()} # pylint: disable=unnecessary-comprehension + get_lib_info_func = get_global_func("support.GetLibInfo", allow_missing=True) + if get_lib_info_func is not None: + lib_info = get_lib_info_func() + if lib_info is None: + return {} + else: + return {} + return dict(lib_info.items()) class FrontendTestModule(Module): diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 04721ee6d705..ea986a3bf096 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -395,6 +395,8 @@ int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMVal return 0; } +int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, + int* ret_type_code); tvm_crt_error_t TVMInitializeRuntime() { int idx = 0; tvm_crt_error_t error = kTvmErrorNoError; @@ -432,6 +434,10 @@ tvm_crt_error_t TVMInitializeRuntime() { error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0); } + if (error == kTvmErrorNoError) { + error = TVMFuncRegisterGlobal("tvm.contrib.random.random_fill", &TVMContribRandomFill, 0); + } + if (error != kTvmErrorNoError) { TVMPlatformMemoryFree(registry_backing_memory, dev); } @@ -563,3 +569,20 @@ release_and_return : { __attribute__((weak)) tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) { return kTvmErrorFunctionCallNotImplemented; } + +// Fill the tensor in args[0] with random data using TVMPlatformGenerateRandom. +// Named to correspond with the analogous function in the C++ runtime. +int TVMContribRandomFill(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val, + int* ret_type_code) { + if (num_args != 1) { + return kTvmErrorFunctionCallNumArguments; + } + + if (type_codes[0] != kTVMDLTensorHandle) { + return kTvmErrorFunctionCallWrongArgType; + } + + DLTensor* tensor = (DLTensor*)args[0].v_handle; + TVMNDArray arr = {*tensor}; + return TVMNDArray_RandomFill(&arr); +} diff --git a/src/runtime/crt/common/ndarray.c b/src/runtime/crt/common/ndarray.c index c97f7658938f..16bde3227f7c 100644 --- a/src/runtime/crt/common/ndarray.c +++ b/src/runtime/crt/common/ndarray.c @@ -47,18 +47,22 @@ int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, return 0; } +int64_t TVMNDArray_DataSizeBytes(TVMNDArray* array) { + int64_t num_elems = 1; + int32_t idx; + for (idx = 0; idx < array->dl_tensor.ndim; ++idx) { + num_elems *= array->dl_tensor.shape[idx]; + } + return (num_elems * array->dl_tensor.dtype.bits + 7) / 8; +} + int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev, TVMNDArray* array) { int status = TVMNDArray_Create(ndim, shape, dtype, dev, array); if (status != 0) { return status; } - int64_t num_elems = 1; - int32_t idx; - for (idx = 0; idx < array->dl_tensor.ndim; ++idx) { - num_elems *= shape[idx]; - } - int total_elem_bytes = (num_elems * dtype.bits + 7) / 8; + int total_elem_bytes = TVMNDArray_DataSizeBytes(array); array->dl_tensor.data = TVMBackendAllocWorkspace(kDLCPU, 0, total_elem_bytes, dtype.code, dtype.bits); memset(array->dl_tensor.data, 0, total_elem_bytes); @@ -136,6 +140,15 @@ int TVMNDArray_CreateView(TVMNDArray* arr, const tvm_index_t* shape, int32_t ndi return 0; } +int TVMNDArray_RandomFill(TVMNDArray* arr) { + int64_t num_bytes = TVMNDArray_DataSizeBytes(arr); + if (num_bytes < 0 || num_bytes > SIZE_MAX) { + return kTvmErrorFunctionCallInvalidArg; + } + + return TVMPlatformGenerateRandom(arr->dl_tensor.data, (size_t)num_bytes); +} + int TVMNDArray_Release(TVMNDArray* arr) { tvm_crt_error_t err; DLDevice dev = {kDLCPU, 0}; diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h index 7949aea6f171..aa718a303744 100644 --- a/src/runtime/crt/crt_config-template.h +++ b/src/runtime/crt/crt_config-template.h @@ -37,7 +37,7 @@ #define TVM_CRT_MAX_ARGS 10 /*! Size of the global function registry, in bytes. */ -#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 250 +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512 /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h index f878477e7b42..e5869ed2a303 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/common/ndarray.h @@ -44,6 +44,10 @@ typedef struct TVMNDArray { int TVMNDArray_Create(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev, TVMNDArray* array); +int64_t TVMNDArray_DataSizeBytes(TVMNDArray* array); + +int TVMNDArray_RandomFill(TVMNDArray* array); + int TVMNDArray_Empty(int32_t ndim, const tvm_index_t* shape, DLDataType dtype, DLDevice dev, TVMNDArray* array); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index dc849b8fa6b3..80ace929b881 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -59,6 +59,17 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) { function_names_.push_back(global_symbol.value()); CodeGenC::AddFunction(f); + if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + function_names_.push_back(runtime::symbol::tvm_module_main); + stream << "// CodegenC: NOTE: Auto-generated entry function\n"; + PrintFuncPrefix(); + stream << " " << tvm::runtime::symbol::tvm_module_main + << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " + << "int* out_ret_tcode, void* resource_handle) {\n"; + stream << " return " << global_symbol.value() + << "(args, arg_type_ids, num_args, out_ret_value, out_ret_tcode, resource_handle);\n"; + stream << "}\n"; + } } void CodeGenCHost::DeclareParameters(Map params) { diff --git a/tests/micro/zephyr/test_zephyr.py b/tests/micro/zephyr/test_zephyr.py index 5a7e69e3c7f9..6085318b5b50 100644 --- a/tests/micro/zephyr/test_zephyr.py +++ b/tests/micro/zephyr/test_zephyr.py @@ -15,10 +15,6 @@ # specific language governing permissions and limitations # under the License. -import contextlib -import copy -import datetime -import glob import logging import os import pathlib @@ -401,5 +397,143 @@ def test_tensors(sess): test_tensors(sess) +@tvm.testing.requires_micro +def test_autotune_conv2d(temp_dir, platform, west_cmd, tvm_debug): + """Test AutoTune for microTVM Zephyr""" + import tvm.relay as relay + + model, zephyr_board = PLATFORMS[platform] + + # Create a Relay model + data_shape = (1, 3, 16, 16) + weight_shape = (8, 3, 5, 5) + data = relay.var("data", relay.TensorType(data_shape, "float32")) + weight = relay.var("weight", relay.TensorType(weight_shape, "float32")) + y = relay.nn.conv2d( + data, + weight, + padding=(2, 2), + kernel_size=(5, 5), + kernel_layout="OIHW", + out_dtype="float32", + ) + f = relay.Function([data, weight], y) + mod = tvm.IRModule.from_expr(f) + mod = relay.transform.InferType()(mod) + + data_sample = np.random.rand(data_shape[0], data_shape[1], data_shape[2], data_shape[3]).astype( + "float32" + ) + weight_sample = np.random.rand( + weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3] + ).astype("float32") + params = {mod["main"].params[1].name_hint: weight_sample} + + target = tvm.target.target.micro(model) + pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}) + with pass_context: + tasks = tvm.autotvm.task.extract_from_program(mod["main"], {}, target) + assert len(tasks) > 0 + + repo_root = pathlib.Path( + subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip() + ) + template_project_dir = repo_root / "apps" / "microtvm" / "zephyr" / "template_project" + module_loader = tvm.micro.AutoTvmModuleLoader( + template_project_dir=template_project_dir, + project_options={ + "zephyr_board": zephyr_board, + "west_cmd": west_cmd, + "verbose": 1, + "project_type": "host_driven", + }, + ) + builder = tvm.autotvm.LocalBuilder( + n_parallel=1, + build_kwargs={"build_option": {"tir.disable_vectorize": True}}, + do_fork=True, + build_func=tvm.micro.autotvm_build_func, + ) + runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + + measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) + + log_path = pathlib.Path("zephyr_autotune.log") + if log_path.exists(): + log_path.unlink() + + n_trial = 10 + for task in tasks: + tuner = tvm.autotvm.tuner.GATuner(task) + tuner.tune( + n_trial=n_trial, + measure_option=measure_option, + callbacks=[ + tvm.autotvm.callback.log_to_file(str(log_path)), + tvm.autotvm.callback.progress_bar(n_trial, si_prefix="M"), + ], + si_prefix="M", + ) + + # Build without tuning + with pass_context: + lowered = tvm.relay.build(mod, target=target, params=params) + + temp_dir = utils.tempdir() + project = tvm.micro.generate_project( + str(template_project_dir), + lowered, + temp_dir / "project", + { + "zephyr_board": zephyr_board, + "west_cmd": west_cmd, + "verbose": 1, + "project_type": "host_driven", + }, + ) + project.build() + project.flash() + + with tvm.micro.Session(project.transport()) as session: + graph_mod = tvm.micro.create_local_graph_executor( + lowered.get_graph_json(), session.get_system_lib(), session.device + ) + graph_mod.set_input(**lowered.get_params()) + graph_mod.run(data=data_sample) + expected_output = graph_mod.get_output(0).numpy() + del graph_mod + + # Build using autotune logs + with tvm.autotvm.apply_history_best(str(log_path)): + with pass_context: + lowered_tuned = tvm.relay.build(mod, target=target, params=params) + + temp_dir = utils.tempdir() + project = tvm.micro.generate_project( + str(template_project_dir), + lowered_tuned, + temp_dir / "project", + { + "zephyr_board": zephyr_board, + "west_cmd": west_cmd, + "verbose": 1, + "project_type": "host_driven", + }, + ) + project.build() + project.flash() + + with tvm.micro.Session(project.transport()) as session: + graph_mod = tvm.micro.create_local_graph_executor( + lowered_tuned.get_graph_json(), session.get_system_lib(), session.device + ) + graph_mod.set_input(**lowered_tuned.get_params()) + graph_mod.run(data=data_sample) + output = graph_mod.get_output(0).numpy() + del graph_mod + + tvm.testing.assert_allclose(output, expected_output, rtol=1e-4, atol=1e-5) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 586e9fbfb91e..5c6eb922fa17 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -219,5 +219,108 @@ def test_platform_timer(): assert len(result.results) == 3 +@tvm.testing.requires_micro +def test_autotune(): + """Verify that autotune works with micro.""" + import tvm.relay as relay + + data = relay.var("data", relay.TensorType((1, 3, 64, 64), "float32")) + weight = relay.var("weight", relay.TensorType((8, 3, 5, 5), "float32")) + y = relay.nn.conv2d( + data, + weight, + padding=(2, 2), + kernel_size=(5, 5), + kernel_layout="OIHW", + out_dtype="float32", + ) + f = relay.Function([data, weight], y) + mod = tvm.IRModule.from_expr(f) + mod = relay.transform.InferType()(mod) + + main_func = mod["main"] + shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params} + type_dict = {p.name_hint: p.checked_type.dtype for p in main_func.params} + + weight_data = np.ones(shape_dict["weight"]).astype(type_dict["weight"]) + input_data = np.ones(shape_dict["data"]).astype(type_dict["data"]) + params = {"weight": weight_data} + inputs = {"data": input_data} + + target = tvm.target.target.micro("host") + template_project_dir = os.path.join(tvm.micro.get_standalone_crt_dir(), "template", "host") + + pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}) + with pass_context: + tasks = tvm.autotvm.task.extract_from_program(mod["main"], {}, target) + assert len(tasks) > 0 + + module_loader = tvm.micro.AutoTvmModuleLoader( + template_project_dir=template_project_dir, + project_options={}, + ) + builder = tvm.autotvm.LocalBuilder( + n_parallel=1, + build_kwargs={"build_option": {"tir.disable_vectorize": True}}, + do_fork=True, + build_func=tvm.micro.autotvm_build_func, + ) + runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + + measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) + + tune_log_file = pathlib.Path("crt_autotune.log") + if tune_log_file.exists(): + tune_log_file.unlink() + + num_trials = 10 + for task in tasks: + tuner = tvm.autotvm.tuner.GATuner(task) + tuner.tune( + n_trial=num_trials, + measure_option=measure_option, + callbacks=[ + tvm.autotvm.callback.log_to_file(str(tune_log_file)), + tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"), + ], + si_prefix="M", + ) + + # Build without tuning + with pass_context: + lowered = tvm.relay.build(mod, target=TARGET, params=params) + + temp_dir = tvm.contrib.utils.tempdir() + project = tvm.micro.generate_project(template_project_dir, lowered, temp_dir / "project") + project.build() + with tvm.micro.Session(project.transport()) as session: + graph_mod = tvm.micro.create_local_graph_executor( + lowered.get_graph_json(), session.get_system_lib(), session.device + ) + graph_mod.set_input(**lowered.get_params()) + graph_mod.run(**inputs) + expected_output = graph_mod.get_output(0).numpy() + del graph_mod + + # Build using autotune logs + with tvm.autotvm.apply_history_best(str(tune_log_file)): + with pass_context: + lowered_tuned = tvm.relay.build(mod, target=target, params=params) + + temp_dir = tvm.contrib.utils.tempdir() + project = tvm.micro.generate_project(template_project_dir, lowered_tuned, temp_dir / "project") + project.build() + with tvm.micro.Session(project.transport()) as session: + graph_mod = tvm.micro.create_local_graph_executor( + lowered_tuned.get_graph_json(), session.get_system_lib(), session.device + ) + graph_mod.set_input(**lowered_tuned.get_params()) + graph_mod.run(**inputs) + output = graph_mod.get_output(0).numpy() + del graph_mod + + tvm.testing.assert_allclose(output, expected_output, rtol=1e-4, atol=1e-5) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index 523065465172..7bf4d72b047e 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -88,7 +88,12 @@ def save_object(names): with open(path_runtime_py, "w") as fo: fo.write(runtime_py) - subprocess.check_call("python3 %s %s %s" % (path_runtime_py, path_dso, dtype), shell=True) + proc = subprocess.run( + [sys.executable, path_runtime_py, path_dso, dtype], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + assert proc.returncode == 0, f"{proc.args} exited with {proc.returncode}: {proc.stdout}" @tvm.testing.requires_gpu diff --git a/tutorials/micro/micro_autotune.py b/tutorials/micro/micro_autotune.py new file mode 100644 index 000000000000..136bcfeaec80 --- /dev/null +++ b/tutorials/micro/micro_autotune.py @@ -0,0 +1,250 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +.. _tutorial-micro-autotune: + +Autotuning with micro TVM +========================= +**Author**: `Andrew Reusch `_, `Mehrdad Hessar ` + +This tutorial explains how to autotune a model using the C runtime. +""" + +import numpy as np +import subprocess +import pathlib + +import tvm + +#################### +# Defining the model +#################### +# +# To begin with, define a model in Relay to be executed on-device. Then create an IRModule from relay model and +# fill parameters with random numbers. +# + +data_shape = (1, 3, 10, 10) +weight_shape = (6, 3, 5, 5) + +data = tvm.relay.var("data", tvm.relay.TensorType(data_shape, "float32")) +weight = tvm.relay.var("weight", tvm.relay.TensorType(weight_shape, "float32")) + +y = tvm.relay.nn.conv2d( + data, + weight, + padding=(2, 2), + kernel_size=(5, 5), + kernel_layout="OIHW", + out_dtype="float32", +) +f = tvm.relay.Function([data, weight], y) + +relay_mod = tvm.IRModule.from_expr(f) +relay_mod = tvm.relay.transform.InferType()(relay_mod) + +weight_sample = np.random.rand( + weight_shape[0], weight_shape[1], weight_shape[2], weight_shape[3] +).astype("float32") +params = {"weight": weight_sample} + +####################### +# Defining the target # +####################### +# Now we define the TVM target that describes the execution environment. This looks very similar +# to target definitions from other microTVM tutorials. +# +# When running on physical hardware, choose a target and a board that +# describe the hardware. There are multiple hardware targets that could be selected from +# PLATFORM list in this tutorial. You can chose the platform by passing --platform argument when running +# this tutorial. +# +TARGET = tvm.target.target.micro("host") + +# Compiling for physical hardware +# -------------------------------------------------------------------------- +# When running on physical hardware, choose a TARGET and a BOARD that describe the hardware. The +# STM32L4R5ZI Nucleo target and board is chosen in the example below. +# +# TARGET = tvm.target.target.micro("stm32l4r5zi") +# BOARD = "nucleo_l4r5zi" + +######################### +# Extracting tuning tasks +######################### +# Not all operators in the Relay program printed above can be tuned. Some are so trivial that only +# a single implementation is defined; others don't make sense as tuning tasks. Using +# `extract_from_program`, you can produce a list of tunable tasks. +# +# Because task extraction involves running the compiler, we first configure the compiler's +# transformation passes; we'll apply the same configuration later on during autotuning. + +pass_context = tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}) +with pass_context: + tasks = tvm.autotvm.task.extract_from_program(relay_mod["main"], {}, TARGET) +assert len(tasks) > 0 + +###################### +# Configuring microTVM +###################### +# Before autotuning, we need to define a module loader and then pass that to +# a `tvm.autotvm.LocalBuilder`. Then we create a `tvm.autotvm.LocalRunner` and use +# both builder and runner to generates multiple measurements for auto tunner. +# +# In this tutorial, we have the option to use x86 host as an example or use different targets +# from Zephyr RTOS. If you choose pass `--platform=host` to this tutorial it will uses x86. You can +# choose other options by choosing from `PLATFORM` list. +# + +repo_root = pathlib.Path( + subprocess.check_output(["git", "rev-parse", "--show-toplevel"], encoding="utf-8").strip() +) + +module_loader = tvm.micro.AutoTvmModuleLoader( + template_project_dir=repo_root / "src" / "runtime" / "crt" / "host", + project_options={}, +) +builder = tvm.autotvm.LocalBuilder( + n_parallel=1, + build_kwargs={"build_option": {"tir.disable_vectorize": True}}, + do_fork=True, + build_func=tvm.micro.autotvm_build_func, +) +runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + +measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) + +# Compiling for physical hardware +# -------------------------------------------------------------------------- +# module_loader = tvm.micro.AutoTvmModuleLoader( +# template_project_dir=repo_root / "apps" / "microtvm" / "zephyr" / "template_project", +# project_options={ +# "zephyr_board": BOARD, +# "west_cmd": "west", +# "verbose": 1, +# "project_type": "host_driven", +# }, +# ) +# builder = tvm.autotvm.LocalBuilder( +# n_parallel=1, +# build_kwargs={"build_option": {"tir.disable_vectorize": True}}, +# do_fork=False, +# build_func=tvm.micro.autotvm_build_func, +# ) +# runner = tvm.autotvm.LocalRunner(number=1, repeat=1, timeout=0, module_loader=module_loader) + +# measure_option = tvm.autotvm.measure_option(builder=builder, runner=runner) + +################ +# Run Autotuning +################ +# Now we can run autotuning separately on each extracted task. + +num_trials = 10 +for task in tasks: + tuner = tvm.autotvm.tuner.GATuner(task) + tuner.tune( + n_trial=num_trials, + measure_option=measure_option, + callbacks=[ + tvm.autotvm.callback.log_to_file("microtvm_autotune.log"), + tvm.autotvm.callback.progress_bar(num_trials, si_prefix="M"), + ], + si_prefix="M", + ) + +############################ +# Timing the untuned program +############################ +# For comparison, let's compile and run the graph without imposing any autotuning schedules. TVM +# will select a randomly-tuned implementation for each operator, which should not perform as well as +# the tuned operator. + +with pass_context: + lowered = tvm.relay.build(relay_mod, target=TARGET, params=params) + +temp_dir = tvm.contrib.utils.tempdir() + +project = tvm.micro.generate_project( + str(repo_root / "src" / "runtime" / "crt" / "host"), lowered, temp_dir / "project" +) + +# Compiling for physical hardware +# -------------------------------------------------------------------------- +# project = tvm.micro.generate_project( +# str(repo_root / "apps" / "microtvm" / "zephyr" / "template_project"), +# lowered, +# temp_dir / "project", +# { +# "zephyr_board": BOARD, +# "west_cmd": "west", +# "verbose": 1, +# "project_type": "host_driven", +# }, +# ) + +project.build() +project.flash() +with tvm.micro.Session(project.transport()) as session: + debug_module = tvm.micro.create_local_debug_executor( + lowered.get_graph_json(), session.get_system_lib(), session.device + ) + debug_module.set_input(**lowered.get_params()) + print("########## Build without Autotuning ##########") + debug_module.run() + del debug_module + +########################## +# Timing the tuned program +########################## +# Once autotuning completes, you can time execution of the entire program using the Debug Runtime: + +with tvm.autotvm.apply_history_best("microtvm_autotune.log"): + with pass_context: + lowered_tuned = tvm.relay.build(relay_mod, target=TARGET, params=params) + +temp_dir = tvm.contrib.utils.tempdir() + +project = tvm.micro.generate_project( + str(repo_root / "src" / "runtime" / "crt" / "host"), lowered_tuned, temp_dir / "project" +) + +# Compiling for physical hardware +# -------------------------------------------------------------------------- +# project = tvm.micro.generate_project( +# str(repo_root / "apps" / "microtvm" / "zephyr" / "template_project"), +# lowered_tuned, +# temp_dir / "project", +# { +# "zephyr_board": BOARD, +# "west_cmd": "west", +# "verbose": 1, +# "project_type": "host_driven", +# }, +# ) + +project.build() +project.flash() +with tvm.micro.Session(project.transport()) as session: + debug_module = tvm.micro.create_local_debug_executor( + lowered_tuned.get_graph_json(), session.get_system_lib(), session.device + ) + debug_module.set_input(**lowered_tuned.get_params()) + print("########## Build with Autotuning ##########") + debug_module.run() + del debug_module