Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMC][UMA] Support using UMA with TVMC #14165

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/autotuner.py
Expand Up @@ -48,7 +48,7 @@


@register_parser
def add_tune_parser(subparsers, _, json_params):
def add_tune_parser(subparsers, _, json_params, argv): # pylint: disable=unused-argument
"""Include parser for 'tune' subcommand"""

parser = subparsers.add_parser("tune", help="auto-tune a model")
Expand Down
45 changes: 41 additions & 4 deletions python/tvm/driver/tvmc/compiler.py
Expand Up @@ -47,17 +47,20 @@
from .transform import generate_transform_args, parse_graph_transform_args, apply_graph_transforms
from .shape_parser import parse_shape_string
from .workspace_pools import generate_workspace_pools_args, workspace_pools_recombobulate
from .extensions import load_extensions, get_extensions
from .arguments import TVMCSuppressedArgumentParser

# pylint: disable=invalid-name
logger = logging.getLogger("TVMC")


@register_parser
def add_compile_parser(subparsers, _, json_params):
def add_compile_parser(subparsers, main_parser, json_params, argv):
"""Include parser for 'compile' subcommand"""

parser = subparsers.add_parser("compile", help="compile a model.")
parser.set_defaults(func=drive_compile)

parser.add_argument(
"--cross-compiler",
default="",
Expand Down Expand Up @@ -114,16 +117,13 @@ def add_compile_parser(subparsers, _, json_params):
"e.g. '--pass-config tir.add_lower_pass=opt_level1,pass1,opt_level2,pass2'.",
)

generate_target_args(parser)
parser.add_argument(
"--tuning-records",
metavar="PATH",
default="",
help="path to an auto-tuning log file by AutoTVM. If not presented, "
"the fallback/tophub configs will be used.",
)
generate_registry_args(parser, Executor, "graph")
generate_registry_args(parser, Runtime, "cpp")

parser.add_argument("-v", "--verbose", action="count", default=0, help="increase verbosity.")
# TODO (@leandron) This is a path to a physical file, but
Expand Down Expand Up @@ -177,9 +177,42 @@ def add_compile_parser(subparsers, _, json_params):
for one_entry in json_params:
parser.set_defaults(**one_entry)

parser.add_argument(
"--experimental-tvmc-extension",
default=[],
action="append",
help="path from which to load packages named tvmc_extension which implement the "
"TVMCExtension interface.",
)
disposable_parser = TVMCSuppressedArgumentParser(main_parser)
try:
known_args, _ = disposable_parser.parse_known_args(argv)
except TVMCException:
known_args = None
try:
ext_dirs = known_args.experimental_tvmc_extension
except AttributeError:
ext_dirs = []
_handle_extensions(ext_dirs)

generate_target_args(parser)
generate_registry_args(parser, Executor, "graph")
generate_registry_args(parser, Runtime, "cpp")

generate_workspace_pools_args(parser)


def _handle_extensions(extra_paths):
extension_paths = extra_paths
if os.environ.get("TVMC_EXTENSION_DIR", None):
extension_paths.append(os.environ["TVMC_EXTENSION_DIR"])

load_extensions(extension_paths)
for ext in get_extensions():
for uma_backend in ext.uma_backends():
uma_backend.register()


def drive_compile(args):
"""Invoke tvmc.compiler module with command line arguments

Expand Down Expand Up @@ -411,6 +444,10 @@ def compile_model(
# dump which operations are offloaded to which backend
dump_operation_offloads(mod, initial_relay, dump_offloads)

for ext in get_extensions():
for uma_backend in ext.uma_backends():
mod = uma_backend.partition(mod)

if tuning_records and os.path.exists(tuning_records):
logger.debug("tuning records file provided: %s", tuning_records)

Expand Down
16 changes: 14 additions & 2 deletions python/tvm/driver/tvmc/composite_target.py
Expand Up @@ -32,6 +32,7 @@


from tvm.driver.tvmc import TVMCException
from tvm.driver.tvmc.extensions import get_extensions


# pylint: disable=invalid-name
Expand Down Expand Up @@ -87,7 +88,18 @@ def get_codegen_names():
list of str
all registered targets
"""
return list(REGISTERED_CODEGEN.keys())
return list(get_all_codegens().keys())


def get_all_codegens():
codegens = REGISTERED_CODEGEN
for ext in get_extensions():
for uma_backend in ext.uma_backends():
codegens[uma_backend.target_name] = {
"config_key": None,
"pass_pipeline": uma_backend.partition,
}
return codegens


def get_codegen_by_target(name):
Expand All @@ -104,6 +116,6 @@ def get_codegen_by_target(name):
requested target codegen information
"""
try:
return REGISTERED_CODEGEN[name]
return get_all_codegens()[name]
except KeyError:
raise TVMCException("Composite target %s is not defined in TVMC." % name)
128 changes: 128 additions & 0 deletions python/tvm/driver/tvmc/extensions.py
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Allows to extend TVMC with external code.
"""
import sys
import importlib
import inspect
import pkgutil
import warnings
import copy
from abc import abstractmethod


_EXTENSIONS = []


class TVMCExtension(object):
@abstractmethod
def uma_backends(self):
return []


def get_extensions():
"""Returns all loaded extensions."""

for ext in _EXTENSIONS:
yield ext


def load_extensions(paths):
"""
Loads extensions from the given locations.

Extensions must implement the `TVMCExtension` interface and be stored in a directory called
`tvmc_extension`.
"""

path_backup = copy.copy(sys.path)
sys.path.extend(paths)

top_modules = []
try:
mod = importlib.import_module("tvmc_extension")
top_modules.append(mod)
except ImportError:
pass

sys.path.clear()
sys.path.extend(path_backup)

extension_classes = _scan_all(top_modules)
for ext_cls in extension_classes:
_EXTENSIONS.append(ext_cls())


def _scan_all(top_level):
scanned_extensions = []
for mdl in top_level:
for importer, modname, _ in pkgutil.walk_packages(
path=mdl.__path__, prefix=mdl.__name__ + ".", onerror=lambda x: None
):
try:
module_name = modname.rsplit(".", 1)[-1]
# If module's name starts with "_", do not load the module.
# But if the module's name starts with a "__", then load the
# module.
if module_name.startswith("_") and not module_name.startswith("__"):
continue
lhutton1 marked this conversation as resolved.
Show resolved Hide resolved

with warnings.catch_warnings(record=True) as recorded_warnings:
if sys.version_info < (3, 10):
m = importer.find_module(modname) # type: ignore
assert m is not None
loaded_mod = m.load_module(modname)
else:
spec = importer.find_spec(modname)
assert spec is not None
if modname in sys.modules:
loaded_mod = sys.modules[modname]
else:
loaded_mod = importlib.util.module_from_spec(spec)
if loaded_mod is not None:
spec.loader.exec_module(loaded_mod)
sys.modules[modname] = loaded_mod

if len(recorded_warnings) > 0:
for warning in recorded_warnings:
warnings.showwarning(
message=warning.message,
category=warning.category,
filename=warning.filename,
lineno=warning.lineno,
file=warning.file,
line=warning.line,
)

if loaded_mod is not None:
for _name, obj in inspect.getmembers(loaded_mod):
if _is_concrete_extension_type(obj):
scanned_extensions.append(obj)
except ImportError as err:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be good to add a test for this case

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried adding this, but encountered an issue with race conditions. The way the sys.path is currently modified by extensions.py is problematic when more than one test is run in parallel. Is there already an approach for serializing tests in TVM? There doesn't seem to be an upstream solution: pytest-dev/pytest-xdist#18

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we instead avoid modifying sys.path when loading a module? I had a naive search and wondered if this could help? cc @Mousius @leandron who might know more about these things

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While such an approach should work for single source files, I'm not sure how that would work for packages or namespace packages.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay, I had a think about this but wasn't able to come up with a better solution myself, so I won't block on this

warnings.warn(
message=f"\n"
f"\tError importing extension '{modname}'.\n"
f"\t\t{type(err).__name__} : {err}",
category=UserWarning,
)

return scanned_extensions


def _is_concrete_extension_type(obj):
return inspect.isclass(obj) and issubclass(obj, TVMCExtension) and not inspect.isabstract(obj)
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/main.py
Expand Up @@ -80,7 +80,7 @@ def _main(argv):

subparser = parser.add_subparsers(title="commands")
for make_subparser in REGISTERED_PARSER:
make_subparser(subparser, parser, json_config_values)
make_subparser(subparser, parser, json_config_values, argv)

# Finally, add help for the main parser.
parser.add_argument("-h", "--help", action="help", help="show this help message and exit.")
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/driver/tvmc/micro.py
Expand Up @@ -46,7 +46,7 @@


@register_parser
def add_micro_parser(subparsers, main_parser, json_params):
def add_micro_parser(subparsers, main_parser, json_params, argv):
"""Includes parser for 'micro' context and associated subcommands:
create-project (create), build, and flash.
"""
Expand Down Expand Up @@ -163,7 +163,7 @@ def _add_parser(parser, platform):

disposable_parser = TVMCSuppressedArgumentParser(main_parser)
try:
known_args, _ = disposable_parser.parse_known_args()
known_args, _ = disposable_parser.parse_known_args(argv)
except TVMCException:
return

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/driver/tvmc/runner.py
Expand Up @@ -64,7 +64,7 @@


@register_parser
def add_run_parser(subparsers, main_parser, json_params):
def add_run_parser(subparsers, main_parser, json_params, argv):
"""Include parser for 'run' subcommand"""

# Use conflict_handler='resolve' to allow '--list-options' option to be properly overriden when
Expand Down Expand Up @@ -157,7 +157,7 @@ def add_run_parser(subparsers, main_parser, json_params):

disposable_parser = TVMCSuppressedArgumentParser(main_parser)
try:
known_args, _ = disposable_parser.parse_known_args()
known_args, _ = disposable_parser.parse_known_args(argv)
except TVMCException:
return

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/micro/model_library_format.py
Expand Up @@ -666,4 +666,4 @@ def export_model_library_format(

_make_tar(tempdir.path, file_name, modules)

return file_name
return str(file_name)
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/uma/api/lower.py
Expand Up @@ -72,7 +72,7 @@ def _get_tensors(te_cached_func):

compiler_attr = relay_prim_func.attrs["Compiler"]
target = tvm.target.Target.current()
if target.kind.name != compiler_attr:
if target is None or target.kind.name != compiler_attr:
target = tvm.target.Target(compiler_attr)

tir_prim_func = tir_prim_func.with_attr("target", target)
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/backend/contrib/uma/api/partitioner.py
Expand Up @@ -73,7 +73,10 @@ def register(self) -> None:
register_pattern_table(self.target_name, self._pattern_table)

def partition(
self, mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
self,
mod: tvm.IRModule,
params: Optional[Dict[str, tvm.runtime.NDArray]] = None,
mod_name: Optional[str] = "default",
) -> tvm.IRModule:
"""Partition the relay graph in parts supported and unsupported by the
target hardware accelerator.
Expand Down Expand Up @@ -102,7 +105,7 @@ def partition(
pass_sequence.append(relay.transform.AnnotateTarget(self.target_name))
if self.merge_compiler_regions:
pass_sequence.append(relay.transform.MergeCompilerRegions())
pass_sequence.append(relay.transform.PartitionGraph())
pass_sequence.append(relay.transform.PartitionGraph(mod_name=mod_name))
pass_sequence.extend(
[p[1] for p in self._relay_passes if p[0] == PassPhase.POST_PARTITIONING_0]
)
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/backend/contrib/uma/backend.py
Expand Up @@ -290,6 +290,9 @@ def register(self) -> None:
self._tir_to_runtime.register()

def partition(
self, mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None
self,
mod: tvm.IRModule,
params: Optional[Dict[str, tvm.runtime.NDArray]] = None,
mod_name: Optional[str] = "default",
) -> tvm.IRModule:
return self._relay_to_relay.partition(mod, params)
return self._relay_to_relay.partition(mod, params, mod_name=mod_name)
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import nonexistingmodule

nonexistingmodule.value = 1