Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[microTVM] Add support for AutoTVM #8715

Merged
merged 78 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
a9a67ef
Initial commit of API server impl.
areusch Mar 11, 2021
14678ab
initial commit of api client
areusch Mar 11, 2021
98f3468
Add TVM-side glue code to use Project API
areusch Mar 11, 2021
687d4d5
Change tvm.micro.Session to use Project API
areusch Jul 8, 2021
a091433
Rework how crt_config.h is used on the host.
areusch Jul 8, 2021
f3fb2c0
Modify Transport infrastructure to work with Project API
areusch Jul 8, 2021
0ffb0c3
Add host microTVM API server
areusch Jul 8, 2021
cffa8f9
Zephyr implementation of microTVM API server
areusch Jul 8, 2021
26e78ca
consolidate CcompilerAnnotator
areusch Jul 20, 2021
3d97f6a
Allow model library format with c backend, add test.
areusch Jul 20, 2021
142bc0e
Update unit tests
areusch May 27, 2021
8678346
fix incorrect doc
areusch Jul 7, 2021
ef0d331
Delete old Zephyr build infrastructure
areusch Jul 8, 2021
b739908
Delete old build abstractions
areusch Jul 8, 2021
cab83a9
Delete old Transport implementations and simplify module
areusch Jul 8, 2021
1d09e96
lint
areusch Jul 9, 2021
9d79ef7
ASF header
areusch Jul 20, 2021
367525c
address gromero comments
areusch Jul 28, 2021
2da7b1b
final fixes?
areusch Aug 4, 2021
c4efe81
fix is_shutdown
areusch Aug 5, 2021
b340703
fix user-facing API
areusch Aug 5, 2021
adcfe4b
fix TempDirectory / operator
areusch Aug 5, 2021
226082d
Update micro_tflite tutorial
areusch Aug 5, 2021
1780f91
lint
areusch Aug 5, 2021
8e39635
fix test_crt and test_link_params
areusch Aug 5, 2021
0ea7bd7
undo global micro import, hopefully fix fixture
areusch Aug 5, 2021
ac61a56
lint
areusch Aug 5, 2021
9e32c8e
fix more tests
areusch Aug 5, 2021
8de0606
Add session_constructor_args to tracker request() function.
areusch Jan 8, 2021
7853a4f
Generate entry_func symbol in C host codegen.
areusch Jan 8, 2021
891c91d
print MeasureErrorNo enum value in MeasureResult repr
areusch Jan 8, 2021
9cbfae1
Add microTVM session constructor.
areusch Jan 8, 2021
57dd7c1
add build_kwargs as a Builder constructor arg.
areusch Jan 8, 2021
c3f287f
Add do_fork option to Builder, to support stateful builders
areusch Jan 8, 2021
82a32ca
Checkin module_loader used to build and flash microTVM for autotuning.
areusch Jan 8, 2021
90317d4
Import micro into top-level when enabled.
areusch Feb 26, 2021
5d85120
Add tvm.contrib.random.random_fill to microTVM.
areusch Jan 21, 2021
990e9c9
Move compilation to runner :O
areusch Feb 26, 2021
7f19d44
Add a tutorial for AutoTVM with microcontrollers.
areusch Jan 8, 2021
6df9658
Fix si_prefix in autotuner callback
areusch Jan 21, 2021
76c8396
black format and git-clang-format
areusch Feb 27, 2021
2b09fac
Switch tutorial back to qemu version
areusch Feb 27, 2021
1afb549
improve error reporting so CI will show test error
areusch Mar 5, 2021
fd22020
black format
areusch Mar 5, 2021
80fb240
autotvm is working
areusch Aug 5, 2021
6791644
fix tutorial
mehrdadh Aug 11, 2021
4e631f1
fix dependencies
mehrdadh Aug 11, 2021
97c2cc4
fix auto tune issue
mehrdadh Aug 11, 2021
387103e
merge main
mehrdadh Aug 11, 2021
099a493
lint
mehrdadh Aug 11, 2021
42cf68f
address comments
mehrdadh Aug 11, 2021
93f46a2
fix lint
mehrdadh Aug 11, 2021
f798111
test crt and zephyr added
mehrdadh Aug 11, 2021
ca773a4
Merge branch 'main' into microtvm-project-autotune
mehrdadh Aug 11, 2021
530e6ae
fix func registery size
mehrdadh Aug 11, 2021
49951e1
Merge branch 'main' into microtvm-project-autotune
mehrdadh Aug 12, 2021
b497a80
moved autotune test and fixed
mehrdadh Aug 12, 2021
fe7a03a
fix crt test
mehrdadh Aug 12, 2021
f4155cd
address comments
mehrdadh Aug 12, 2021
a8539dc
change relay text
mehrdadh Aug 12, 2021
8503086
change relay in text_zephyr
mehrdadh Aug 12, 2021
12aa292
class added
mehrdadh Aug 12, 2021
31372fd
changed relay module in tutorial and cleanup
mehrdadh Aug 12, 2021
69bf13f
address comments
mehrdadh Aug 13, 2021
72ae64f
address TK comments
mehrdadh Aug 13, 2021
539e1bc
change fork
mehrdadh Aug 13, 2021
818f822
final comments
mehrdadh Aug 13, 2021
9ccbbb2
Merge branch 'main' into microtvm-project-autotune
mehrdadh Aug 13, 2021
b22dd00
retrigger due to flahy test
mehrdadh Aug 14, 2021
38ba16b
Merge branch 'main' into microtvm-project-autotune
mehrdadh Aug 24, 2021
86768ff
Merge branch 'main' into microtvm-project-autotune
mehrdadh Aug 25, 2021
274e482
Merge branch 'main' into microtvm-project-autotune
mehrdadh Aug 27, 2021
0205908
Merge branch 'main' into microtvm-project-autotune
mehrdadh Sep 6, 2021
6889ea1
fix tutorial
mehrdadh Sep 7, 2021
c14d7b4
retrigger
mehrdadh Sep 8, 2021
b0b44f6
merge fix
mehrdadh Sep 8, 2021
ff49394
fix changes due to merge
mehrdadh Sep 8, 2021
38509fa
merge conflict fix
mehrdadh Sep 8, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/bundle_deploy/crt_config/crt_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
#define TVM_CRT_MAX_REGISTERED_MODULES 2

/*! Size of the global function registry, in bytes. */
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved

/*! Maximum packet size, in bytes, including the length header. */
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 512
Expand Down
1 change: 1 addition & 0 deletions apps/microtvm/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ tensorflow-estimator = {version = "^2.1", optional = true}
# TFLite frontend
tflite = {version = "2.1.0", optional = true}
wheel = "*"
cloudpickle = "^1.6.0"
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved


[tool.poetry.extras]
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/crt/error_codes.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ typedef enum {
kTvmErrorFunctionCallNumArguments = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 0),
kTvmErrorFunctionCallWrongArgType = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 1),
kTvmErrorFunctionCallNotImplemented = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 2),
kTvmErrorFunctionCallInvalidArg = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionCall, 3),

// Time Evaluator - times functions for use with debug runtime.
kTvmErrorTimeEvaluatorBadHandle = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryTimeEvaluator, 0),
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
# Contrib initializers
from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel

if support.libinfo().get("USE_MICRO", "OFF") == "ON":
from . import micro

# NOTE: This file should be python2 compatible so we can
# raise proper error message when user run the package using
# an older version of the python
Expand Down
33 changes: 29 additions & 4 deletions python/tvm/autotvm/measure/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code"""
import enum
import multiprocessing
from collections import namedtuple

Expand Down Expand Up @@ -52,8 +53,19 @@ class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost"
The absolute time stamp when we finish measurement.
"""

def __repr__(self):
error_no_str = (
str(self.error_no)
if self.error_no not in MeasureErrorNo
else str(MeasureErrorNo(self.error_no))
)
return (
f"{self.__class__.__name__}(costs={self.costs!r}, error_no={error_no_str}, "
f"all_cost={self.all_cost}, timestamp={self.timestamp!r})"
)

class MeasureErrorNo(object):

class MeasureErrorNo(enum.IntEnum):
"""Error type for MeasureResult"""

NO_ERROR = 0 # no error
Expand All @@ -77,12 +89,15 @@ class Builder(object):
n_parallel: int, optional
The number of tasks submitted in parallel
By default it will use all cpu cores
build_kwargs: dict, optional
Keyword args given to the build function.
"""

def __init__(self, timeout=10, n_parallel=None):
def __init__(self, timeout=10, n_parallel=None, build_kwargs=None):
self.timeout = timeout
self.n_parallel = n_parallel or multiprocessing.cpu_count()
self.build_kwargs = {}
self.user_build_kwargs = build_kwargs if build_kwargs is not None else {}
self.runner_build_kwargs = None
self.task = None

def set_task(self, task, build_kwargs=None):
Expand All @@ -97,7 +112,17 @@ def set_task(self, task, build_kwargs=None):
The additional kwargs for build function
"""
self.task = task
self.build_kwargs = build_kwargs
self.build_kwargs = dict(build_kwargs.items()) if build_kwargs is not None else {}
if any(k in self.build_kwargs for k in self.user_build_kwargs):
logging.warn(
"Overriding these runner-supplied kwargs with user-supplied:\n%s",
"\n".join(
f" * {k}: from {build_kwargs[k]!r} to {self.user_build_kwargs[k]!r}"
for k in sorted([k for k in build_kwargs if k in self.user_build_kwargs])
),
)
for k, v in self.user_build_kwargs.items():
self.build_kwargs[k] = v

def build(self, measure_inputs):
"""Build programs
Expand Down
27 changes: 24 additions & 3 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,22 @@ class LocalBuilder(Builder):
The timeout of a compilation
n_parallel: int
The number of tasks run in parallel. "None" will use all cpu cores
build_kwargs: dict
If supplied, additional kwargs passed to build_func. Overrides any build_kwargs supplied
by the Runner.
build_func: callable or str
If is 'default', use default build function
If is 'ndk', use function for android ndk
If id 'stackvm', use function for stackvm
If is callable, use it as custom build function, expect lib_format field.
do_fork: bool
If False, do not fork when building. Requires n_parallel=1.
"""

def __init__(self, timeout=10, n_parallel=None, build_func="default"):
super(LocalBuilder, self).__init__(timeout, n_parallel)
def __init__(
self, timeout=10, n_parallel=None, build_kwargs=None, build_func="default", do_fork=False
):
super(LocalBuilder, self).__init__(timeout, n_parallel, build_kwargs)

if isinstance(build_func, str):
if build_func == "default":
Expand All @@ -99,6 +106,11 @@ def __init__(self, timeout=10, n_parallel=None, build_func="default"):
else:
raise ValueError("Invalid build_func" + build_func)
self.build_func = _WrappedBuildFunc(build_func)
if not do_fork:
assert n_parallel in (
None,
1,
), f"if do_fork=False, need n_parallel=None or 1; got {n_parallel}"
self.executor = PopenPoolExecutor(
timeout=timeout, initializer=reset_global_scope, initargs=(AutotvmGlobalScope.current,)
)
Expand Down Expand Up @@ -518,7 +530,16 @@ def __call__(self, measure_input, tmp_dir, **kwargs):
)
# TODO(tvm-team) consider linline _build_func_common
func, arg_info = _build_func_common(measure_input, **kwargs)
func.export_library(filename, self.build_func)
if self.build_func.output_format == ".model-library-format":
# Late import to preserve autoTVM with USE_MICRO OFF
try:
from tvm import micro # pylint: disable=import-outside-toplevel
except ImportError:
raise ImportError("Requires USE_MICRO")

micro.export_model_library_format(func, filename)
else:
func.export_library(filename, self.build_func)
except Exception as e: # pylint: disable=broad-except
return BuildResult(None, None, e, time.time() - tic)
return BuildResult(filename, arg_info, None, time.time() - tic)
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/tuner/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ def __del__(self):

if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
sys.stdout.write(
"\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) "
"| %.2f s" % (prefix, 0, 0, 0, total, time.time() - tic)
"\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) "
"| %.2f s" % (prefix, 0, 0, si_prefix, 0, total, time.time() - tic)
)
sys.stdout.flush()

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
"""MicroTVM module for bare-metal backends"""

from .build import autotvm_build_func
from .build import AutoTvmModuleLoader
from .build import get_standalone_crt_dir
from .model_library_format import export_model_library_format, UnsupportedInModelLibraryFormatError
from .project import generate_project, GeneratedProject, TemplateProject
Expand Down
55 changes: 55 additions & 0 deletions python/tvm/micro/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@

"""Defines top-level glue functions for building microTVM artifacts."""

import json
import logging
import os
import pathlib

from .._ffi import libinfo
from .. import rpc as _rpc


_LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -57,3 +60,55 @@ def get_standalone_crt_dir() -> str:
raise CrtNotFoundError()

return STANDALONE_CRT_DIR


class AutoTvmModuleLoader:
"""MicroTVM AutoTVM Module Loader

Parameters
----------
template_project_dir : str
project template path

project_options : dict
project generation option
"""

def __init__(self, template_project_dir: str, project_options: dict = None):
self._project_options = project_options

if isinstance(template_project_dir, pathlib.Path):
self._template_project_dir = str(template_project_dir)
elif not isinstance(template_project_dir, str):
raise TypeError(f"Incorrect type {type(template_project_dir)}.")

def __call__(self, remote_kw, build_result):
with open(build_result.filename, "rb") as build_file:
build_result_bin = build_file.read()

tracker = _rpc.connect_tracker(remote_kw["host"], remote_kw["port"])
remote = tracker.request(
remote_kw["device_key"],
priority=remote_kw["priority"],
session_timeout=remote_kw["timeout"],
session_constructor_args=[
"tvm.micro.compile_and_create_micro_session",
build_result_bin,
self._template_project_dir,
json.dumps(self._project_options),
],
)
system_lib = remote.get_function("runtime.SystemLib")()
yield remote, system_lib
try:
remote.get_function("tvm.micro.destroy_micro_session")()
except tvm.error.TVMError as exception:
_LOG.warning("Error destroying remote session: %s", str(exception), exc_info=1)


def autotvm_build_func():
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
"""A dummy build function which causes autotvm to use a different export format."""


# A sentinel value for the output format.
autotvm_build_func.output_format = ".model-library-format"
17 changes: 10 additions & 7 deletions python/tvm/micro/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,9 @@ def __init__(self, api_client):
if not self._info["is_template"]:
raise NotATemplateProjectError()

def generate_project(self, graph_executor_factory, project_dir, options):
"""Generate a project given GraphRuntimeFactory."""
model_library_dir = utils.tempdir()
model_library_format_path = model_library_dir.relpath("model.tar")
export_model_library_format(graph_executor_factory, model_library_format_path)

def generate_project_from_mlf(self, model_library_format_path, project_dir, options):
self._api_client.generate_project(
model_library_format_path=model_library_format_path,
model_library_format_path=str(model_library_format_path),
standalone_crt_dir=get_standalone_crt_dir(),
project_dir=project_dir,
options=options,
Expand All @@ -119,6 +114,14 @@ def generate_project(self, graph_executor_factory, project_dir, options):
def info(self):
return self._info

def generate_project(self, graph_executor_factory, project_dir, options):
"""Generate a project given GraphRuntimeFactory."""
model_library_dir = utils.tempdir()
model_library_format_path = model_library_dir.relpath("model.tar")
export_model_library_format(graph_executor_factory, model_library_format_path)

return self.generate_project_from_mlf(model_library_format_path, project_dir, options)


def generate_project(
template_project_dir: typing.Union[pathlib.Path, str],
Expand Down
73 changes: 72 additions & 1 deletion python/tvm/micro/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@

"""Defines a top-level glue class that operates the Transport and Flasher classes."""

import json
import logging
import sys

from ..error import register_error
from .._ffi import get_global_func
from .._ffi import get_global_func, register_func
from ..contrib import graph_executor
from ..contrib import utils
from ..contrib.debugger import debug_executor
from ..rpc import RPCSession
from . import project
from .transport import IoTimeoutError
from .transport import TransportLogger

Expand Down Expand Up @@ -234,3 +237,71 @@ def create_local_debug_executor(graph_json_str, mod, device, dump_root=None):
graph_json_str,
dump_root=dump_root,
)


RPC_SESSION = None


@register_func("tvm.micro.compile_and_create_micro_session")
def compile_and_create_micro_session(
mod_src_bytes: bytes,
template_project_dir: str,
project_options: dict = None,
):
"""Compile the given libraries and sources into a MicroBinary, then invoke create_micro_session.

Parameters
----------
mod_src_bytes : bytes
The content of a tarfile which contains the TVM-generated sources which together form the
SystemLib. This tar is expected to be created by export_library. The tar will be extracted
into a directory and the sources compiled into a MicroLibrary using the Compiler.

template_project_dir: str
The path to a template microTVM Project API project which is used to generate the embedded
project that is built and flashed onto the target device.

project_options: dict
Options for the microTVM API Server contained in template_project_dir.
"""
global RPC_SESSION
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe you can avoid the global here by returning an object whose destructor closes the session. Is there a reason this would not work?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm you mean __del__?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. I'm not sure clear on the lifetime of the RPC_SESSION object though. It looks like the destroyer for it is never called?


temp_dir = utils.tempdir()
# Keep temp directory for generate project
temp_dir.set_keep_for_debug(True)
mehrdadh marked this conversation as resolved.
Show resolved Hide resolved
model_library_format_path = temp_dir / "model.tar.gz"
with open(model_library_format_path, "wb") as mlf_f:
mlf_f.write(mod_src_bytes)

try:
template_project = project.TemplateProject.from_directory(template_project_dir)
generated_project = template_project.generate_project_from_mlf(
model_library_format_path,
temp_dir / "generated-project",
options=json.loads(project_options),
)
except Exception as exception:
logging.error("Project Generate Error: %s", str(exception))
raise exception

generated_project.build()
generated_project.flash()
transport = generated_project.transport()

RPC_SESSION = Session(transport_context_manager=transport)
RPC_SESSION.__enter__()
return RPC_SESSION._rpc._sess


@register_func
def destroy_micro_session():
"""Destroy RPC session for microTVM autotune."""
global RPC_SESSION

if RPC_SESSION is not None:
exc_type, exc_value, traceback = RPC_SESSION.__exit__(None, None, None)
RPC_SESSION = None
if (exc_type, exc_value, traceback) != (None, None, None):
exc = exc_type(exc_value) # See PEP 3109
exc.__traceback__ = traceback
raise exc
Loading