diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py
index 6e61e762ee21..c42974593a6b 100644
--- a/python/tvm/driver/tvmc/compiler.py
+++ b/python/tvm/driver/tvmc/compiler.py
@@ -20,8 +20,12 @@
 """
 import logging
 import os.path
+import re
+import itertools
+from copy import deepcopy
 from typing import Any, Optional, Dict, List, Union, Callable, Sequence
 from pathlib import Path
+from collections import defaultdict
 
 import tvm
 from tvm import autotvm, auto_scheduler
@@ -31,6 +35,8 @@
 from tvm.ir.memory_pools import WorkspaceMemoryPools
 from tvm.target import Target
 from tvm.relay.backend import Executor, Runtime
+from tvm.relay.analysis.operations_distribution import analyze_operations_distribution
+from tvm.relay.transform.suffixes import tag_suffixes
 
 from . import composite_target, frontends, TVMCException
 from .model import TVMCModel, TVMCPackage
@@ -69,6 +75,16 @@ def add_compile_parser(subparsers, _, json_params):
         default="",
         help="comma separated list of formats to export the input model, e.g. 'asm,ll,tir,relay'.",
     )
+    parser.add_argument(
+        "--dump-offloads",
+        default="",
+        help="output a mapping of which operations of the initial Relay "
+        "will be transferred to which backend, indicating the composite "
+        "that includes those operations, "
+        "e.g. '--dump-offloads -' to dump to the console, "
+        "e.g. '--dump-offloads <path_to_file>' to dump to a file. "
+        "If not provided, no output is produced.",
+    )
     parser.add_argument(
         "--model-format",
         choices=frontends.get_frontend_names(),
@@ -171,6 +187,8 @@ def drive_compile(args):
 
     dump_code = [x.strip() for x in args.dump_code.split(",")] if args.dump_code else None
 
+    dump_offloads = args.dump_offloads if args.dump_offloads else ""
+
     additional_targets = reconstruct_target_args(args)
     workspace_pools_target, extra_targets = target_from_cli(args.target, additional_targets)
     transform_args = parse_graph_transform_args(args)
@@ -187,6 +205,7 @@
         cross_options=args.cross_compiler_options,
         output_format=args.output_format,
         dump_code=dump_code,
+        dump_offloads=dump_offloads,
         target_host=None,
         disabled_pass=args.disabled_pass,
         pass_context_configs=args.pass_config,
@@ -213,6 +232,7 @@ def compile_model(
     cross_options: Optional[str] = None,
     output_format: str = "so",
     dump_code: Optional[List[str]] = None,
+    dump_offloads: str = "",
     target_host: Optional[str] = None,
     disabled_pass: Optional[str] = None,
     pass_context_configs: Optional[List[str]] = None,
@@ -259,6 +279,10 @@
     dump_code : list[str], optional
         Dump the generated code for the specified source types, on the
         requested target. Choose from: ["asm", "ll", "tir", "relay"].
+    dump_offloads : str
+        Dump the information about the partitioning of the input model's layers by external codegen.
+        Can be '' to not dump at all, '-' to dump to the console
+        or '<path_to_file>' to dump to the specified file.
     target_host : str, optional
         The target of the host machine if host-side code needs to be generated.
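For context, a minimal sketch of driving the new option programmatically (hedged: it assumes a `tvmc_model` built elsewhere, as in the tests below, and a target string valid for the local build; all other `compile_model` arguments keep their defaults):

    from tvm.driver import tvmc

    # "-" prints the offload report to the console instead of writing a file
    tvmc.compiler.compile_model(tvmc_model, target="cmsis-nn,c", dump_offloads="-")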
@@ -313,6 +337,13 @@ def compile_model(
     if "tir" in dump_code:
         config, dumps = add_tir_to_dumps(config, dumps)
 
+    initial_relay = None
+    if dump_offloads != "":
+        # add suffixes to the span field for calls in Relay
+        mod = tag_suffixes(mod)
+        # remember the initial Relay
+        initial_relay = deepcopy(mod)
+
     tvm_target, extra_targets = target_from_cli(target, additional_target_options)
     tvm_target, target_host = Target.canon_target_and_host(tvm_target, target_host)
 
@@ -337,6 +368,10 @@
     for partition_function, opts in zip(partition_functions, partition_opts):
         mod = partition_function(mod, params, mod_name=mod_name, **opts)
 
+    if initial_relay:
+        # dump which operations are offloaded to which backend
+        dump_operation_offloads(mod, initial_relay, dump_offloads)
+
     if tuning_records and os.path.exists(tuning_records):
         logger.debug("tuning records file provided: %s", tuning_records)
 
@@ -496,3 +531,141 @@ def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."):
         dump_name = module_name + "." + dump_format
         with open(Path(dump_root, dump_name), "w") as f:
             f.write(dumps[dump_format])
+
+
+def dump_operation_offloads(mod: tvm.ir.IRModule, initial_mod: tvm.ir.IRModule, dump_path: str):
+    """This helper function forms a line-by-line output of the initial Relay lines,
+    indicating which operations are offloaded to which target
+    and indicating the composite that includes those operations;
+    the 'generic' target refers to operations executed on the host, e.g.
+    'target1 <- target1.qnn_conv2d'
+    'target1 <- %0 = qnn.conv2d(%tfl.quantize, %v_param_1, ...'
+    'target1 <- %1 = nn.bias_add(%0, %v_param_2, axis=3);'
+    'target1 <- %2 = qnn.requantize(%1, meta[relay.Constant]...'
+    'target2 <- target2.reshape'
+    'target2 <- %3 = reshape(%2, newshape=[1, 1001]);'
+    'generic <- %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 1]...'
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The partitioned IRModule with external global functions.
+    initial_mod : tvm.ir.IRModule
+        The initial IRModule that gets generated from a relay frontend.
+    dump_path : str
+        Value of the "dump_offloads" compiler attribute.
+        Can be a dash ("-"), a file path or an empty string to
+        print to the console, write to a file or do nothing, respectively.
+    """
+    print_to_console = dump_path == "-"
+    save_to_file = all([dump_path != "-", dump_path != ""])
+
+    if print_to_console or save_to_file:
+
+        operations_distribution = analyze_operations_distribution(mod)
+
+        def annotate_f(x):
+            ret = ""
+            if isinstance(x, relay.Call):
+                # if x.span.source_name.name is not in operations_distribution,
+                # this could mean that the span was not copied during the application of passes
+                # to the Relay, in which case we cannot associate the initial Relay string
+                # with the resulting Relay call
+                source_name = x.span.source_name.name
+                if source_name in operations_distribution:
+                    compiler_name, op_name, func_id = operations_distribution[source_name]
+                    ret = (
+                        f", compiler_name: {compiler_name}, op_name: {op_name}, "
+                        f"func_id: {func_id}"
+                    )
+                else:
+                    ret = ", compiler_name: unknown, op_name: unknown, func_id: unknown"
+            return ret
+
+        initial_relay_astext = initial_mod.astext(show_meta_data=False, annotate=annotate_f).split(
+            "\n"
+        )
+
+        # funcs_list is a list of internal composite/function IDs
+        # generated by analyze_operations_distribution().
+        # funcs_list helps keep the order of lines from the initial Relay.
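+        # An illustrative sketch of the structures built below (hedged, the
+        # values are invented): funcs_dict might end up as
+        #   {"3": ["ethos-u", "ethos-u.qnn_conv2d",
+        #          "%0 = qnn.conv2d(...)", "%1 = nn.bias_add(...)"],
+        #    "-1": ["generic", "%4 = nn.pad(...)"]}
+        # and funcs_list keeps the first-seen order of the IDs: ["3", "-1"].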
+        funcs_list = []
+
+        # target_statistic is a mapping of the target name to the
+        # number of initial Relay calls offloaded on the target
+        target_statistic = defaultdict(int)
+
+        # funcs_dict is a mapping of the generated analyze_operations_distribution
+        # internal composite/function IDs to a list, where:
+        # 1st element is
+        # (1a): target name - it could be "generic" or "unknown" or
+        # (1b): specific target name, like "ethos-u" or "cmsis-nn"
+        # 2nd element is
+        # (2a): corresponding initial Relay line for the case (1a) or
+        # (2b): the name of the target composite function in the other case (1b)
+        # 3rd element or subsequent ones are present only for the case (2b)
+        # and are the initial Relay lines included in the corresponding
+        # target composite function
+        funcs_dict = {}
+
+        # Here we group together initial Relay lines from the same composite
+        counter = itertools.count()
+        for s in initial_relay_astext:
+            result = re.search(
+                r"(compiler_name: )(.*)(, op_name: )(.*)(, func_id: )((.*)(?=;)|(.*))", s
+            )
+            if result:
+                target_name = result.group(2)
+                op_name = result.group(4)
+                func_id = result.group(6)
+                s = re.sub(r", compiler_name: (.*)", "", s).lstrip()
+                target_statistic[target_name] += 1
+
+                # create an identifier for each "unknown" case to keep the lines order
+                if func_id == "unknown":
+                    func_id = str(next(counter) * -1)
+
+                if func_id not in funcs_dict:
+                    funcs_list.append(func_id)
+                    funcs_dict[func_id] = [target_name]
+                    if target_name not in ["unknown", "generic"]:
+                        funcs_dict[func_id].append(op_name)
+
+                funcs_dict[func_id].append(s)
+
+        # Here we prepare the output for printing.
+        # The output in most cases keeps the original order of the Relay lines
+        # but some lines are moved to be in the corresponding composite group
+        output = []
+        total = 0
+        output.append("Total number of operators and distribution by targets")
+        output.append("Total:")
+        for target, statistic in target_statistic.items():
+            total += statistic
+            output.append(f"{target}: {statistic}")
+        output[1] += f" {total}"
+        output[len(target_statistic) + 1] += "\n"
+
+        for func_id in funcs_list:
+            _list = funcs_dict[func_id]
+            output.append(f"{_list[0]:10} <- {_list[1]}")
+            if _list[0] == "unknown":
+                output.append(
+                    "Warning: The above line means that some pass(es) "
+                    "in Relay partitioning"
+                )
+                output.append("do not copy the span when the call is recreated")
+                output.append(
+                    "and a line from initial Relay could not be associated "
+                    "with the resulting Relay"
+                )
+            for el in _list[2:]:
+                output.append(f"{_list[0]:10} <- {el}")
+
+        if print_to_console:
+            print("\n" + "\n".join(output))
+        if save_to_file:
+            file_path = os.path.abspath(dump_path)
+            os.makedirs(os.path.dirname(file_path), exist_ok=True)
+            with open(file_path, "w") as f:
+                f.write("\n".join(output))
diff --git a/python/tvm/relay/analysis/operations_distribution.py b/python/tvm/relay/analysis/operations_distribution.py
new file mode 100644
index 000000000000..fc983c8e7eed
--- /dev/null
+++ b/python/tvm/relay/analysis/operations_distribution.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities to analyze Relay and map the unique identifier of
+a Relay line to the tuple of compiler name, composite name
+and composite/function identifier."""
+import tvm
+from tvm import relay
+from tvm.relay.expr_functor import ExprVisitor
+
+
+class AnalyzeOperationsDistribution(ExprVisitor):
+    """A visitor pass that maintains the dictionary unique_op_ids, where
+    the tuple (compiler name, composite name, composite/function identifier)
+    corresponds to the unique identifier of a Relay line.
+    The TVMC compiler adds a unique Relay line identifier as a suffix
+    to the call span field using the tag_suffixes pass
+    if the --dump-offloads option is specified.
+
+    Attributes
+    ----------
+    unique_op_ids : Dict[str, list]
+        Maps the unique identifier of a Relay line, obtained from
+        the "span" field of the Call, to the list [compiler name,
+        composite name, internal composite/function identifier].
+    func_name : str
+        The name of the composite in the partitioned Relay, or
+        'generic' in case the Call has not been included in any composite.
+    func_id : int
+        Internal (inside unique_op_ids) composite/function identifier.
+    compiler_name : str
+        The name of the compiler (e.g. 'ethos-u' or 'cmsis-nn') or 'generic'
+        in case the Call has not been included in any composite.
+    """
+
+    def __init__(self):
+        self.unique_op_ids = {}
+        self.func_name = ""
+        self.func_id = 1
+        self.compiler_name = ""
+        super().__init__()
+
+    def extract(self, call: relay.Call):
+        self.compiler_name = "generic"
+        self.func_name = "generic"
+        if "Compiler" in call.attrs:
+            self.compiler_name = call.attrs["Compiler"]
+        self.visit(call)
+
+    def visit_call(self, call: relay.Call):
+        if isinstance(call.op, tvm.ir.Op):
+            if call.span:
+                src = call.span.source_name.name
+                self.unique_op_ids[src] = [self.compiler_name, self.func_name, self.func_id]
+                if self.func_name == "generic":
+                    self.func_id += 1
+        if isinstance(call.op, relay.Function):
+            self.func_name = call.op.attrs["Composite"]
+            self.func_id += 1
+        super().visit_call(call)
+
+
+def analyze_operations_distribution(mod):
+    """Traverses the partitioned graph to get the unique identifier
+    of the Relay line from the Call's span field.
+    The result is maintained in the dictionary unique_op_ids, where
+    the unique identifier obtained from the op's span corresponds to
+    the tuple (compiler name, composite name, composite/function identifier).
+    With this information we can annotate the textual representation
+    of the initial Relay by indicating into which target composite
+    and function the operators are converted.
+
+    Parameters
+    ----------
+    mod : tvm.ir.IRModule
+        The partitioned Relay graph, usually obtained with a
+        partition_for_<codegen> function.
+
+    Returns
+    -------
+    unique_op_ids : Dict[str, list]
+        Mapping from the unique identifier of a Relay line to the list of
+        compiler name, composite name and internal composite/function
+        identifier.
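+
+    Examples
+    --------
+    A hedged sketch of the returned mapping; the keys come from the span
+    names tagged by tag_suffixes, and the values are illustrative:
+    {"Conv2D2_PART_0": ["ethos-u", "ethos-u.qnn_conv2d", 3],
+     "Pad_PART_0": ["generic", "generic", 1]}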
+ """ + analyze = AnalyzeOperationsDistribution() + for _, func in mod.functions.items(): + analyze.extract(func) + return analyze.unique_op_ids diff --git a/python/tvm/relay/transform/suffixes.py b/python/tvm/relay/transform/suffixes.py new file mode 100644 index 000000000000..e2f7a3c224c1 --- /dev/null +++ b/python/tvm/relay/transform/suffixes.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"Add suffix to the relay.Call's span fields" +from collections import defaultdict + +import tvm + +from ..expr_functor import ExprMutator +from .. import expr as _expr + + +class _SuffixTagger(ExprMutator): + """A pass to traverse the Relay graph to add suffix to the call's span fields. + This making span an unique indicator of a Relay line and we can use it to + obtain the mapping between the Relay that gets generated from a relay frontend + and the Relay after partitioning. + """ + + def __init__(self): + ExprMutator.__init__(self) + # key: span or source name, value: counter, indexed from 0 + self.lookup = defaultdict(int) + self.suffix = "_PART_" + # a set to record hashes of an expressions which spans have been already rewritten + self.hashes = set() + + def _tag_suffix(self, span): + # To avoid error once we introduce the SequentialSpan in the future + """https://discuss.tvm.apache.org/ + t/pre-rfc-tvm-explorer-infrastructure/13457#pass-source-information-builder-6 + """ + # Don't need this if currently + if isinstance(span, tvm.relay.Span): + ori_name = span.source_name.name + new_name = ori_name + self.suffix + str(self.lookup[ori_name]) + self.lookup[ori_name] += 1 + return tvm.relay.Span( + tvm.relay.SourceName(new_name), + span.line, + span.end_line, + span.column, + span.end_column, + ) + return span + + def visit(self, expr): + if hasattr(expr, "span"): + return super().visit(expr) + return expr + + def visit_call(self, call): + new_args = [self.visit(arg) for arg in call.args] + new_op = self.visit(call.op) + if tvm.ir.structural_hash(call) not in self.hashes: + self.hashes.add(tvm.ir.structural_hash(call)) + expr__ = _expr.CallWithFields( + call, + new_op, + new_args, + call.attrs, + call.type_args, + None, + self._tag_suffix(call.span), + ) + else: + expr__ = _expr.CallWithFields( + call, new_op, new_args, call.attrs, call.type_args, None, call.span + ) + return expr__ + + +def tag_suffixes(mod): + """Traverses the Relay graph to add suffix to the call's span fields. + That making span as an unique indicator of a Relay call and we can use it to + obtain the mapping between the offloaded result and the frontend operators. + + Parameters + ---------- + tvm.ir.IRModule + The IRModule that gets generated from a relay frontend. + + Returns + ------- + tvm.ir.IRModule + The IRModule with call's span fields tagged with suffixes. 
+ """ + tagger = _SuffixTagger() + for global_var, func in mod.functions.items(): + func = tagger.visit(func) + mod.update_func(global_var, func) + return mod diff --git a/src/relay/backend/contrib/cmsisnn/extract_constants.cc b/src/relay/backend/contrib/cmsisnn/extract_constants.cc index c6ed7af9ff03..f82014d5d1f5 100644 --- a/src/relay/backend/contrib/cmsisnn/extract_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/extract_constants.cc @@ -206,6 +206,7 @@ class ExtractConstantsMutator : public MixedModeMutator { final_call = Call(new_func, new_args); } + final_call->span = call->span; return final_call; } diff --git a/src/relay/backend/contrib/cmsisnn/fuse_pads.cc b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc index 71c31c303588..0ef7091fc289 100644 --- a/src/relay/backend/contrib/cmsisnn/fuse_pads.cc +++ b/src/relay/backend/contrib/cmsisnn/fuse_pads.cc @@ -138,7 +138,7 @@ class FusePadsMutator : public MixedModeMutator { auto new_conv2d_args = conv2d_call->args; new_conv2d_args.erase(new_conv2d_args.begin()); new_conv2d_args.insert(new_conv2d_args.begin(), new_conv2d_input); - Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {}); + Call ret_call = Call(conv2d_call->op, new_conv2d_args, new_conv2d_attrs, {}, conv2d_call->span); return std::move(ret_call); } @@ -162,6 +162,7 @@ class FusePadsMutator : public MixedModeMutator { Function new_func = Function(FreeVars(new_body), new_body, func->ret_type, FreeTypeVars(new_body, mod_), func->attrs); ret_call = Call(new_func, post_call->args); + ret_call->span = call->span; } } diff --git a/src/relay/backend/contrib/cmsisnn/generate_constants.cc b/src/relay/backend/contrib/cmsisnn/generate_constants.cc index e08b61c457f9..3bdbb5d057eb 100644 --- a/src/relay/backend/contrib/cmsisnn/generate_constants.cc +++ b/src/relay/backend/contrib/cmsisnn/generate_constants.cc @@ -153,16 +153,17 @@ class GenerateConstantsMutator : public MixedModeMutator { // Conv2D arguments: data, weight, input_zp, weight_zp, input_sc, weight_sc Array conv2d_args = {conv2d_call->args[0], conv2d_kernel, conv2d_call->args[2], multiplier_const, conv2d_call->args[4], weight_scale}; - Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {}); + Call ret_call = Call(conv2d_call->op, conv2d_args, new_conv2d_attrs, {}, conv2d_call->span); if (bias_add_call) { - ret_call = - Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs, {}); + ret_call = Call(bias_add_call->op, {ret_call, bias_add_call->args[1]}, bias_add_call->attrs, + {}, bias_add_call->span); } Array requantize_args = {ret_call, req_inp_scale, shift_const, requantize_call->args[3], requantize_call->args[4]}; - ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {}); + ret_call = Call(requantize_call->op, requantize_args, requantize_call->attrs, {}, + requantize_call->span); if (clip_call) { - ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}); + ret_call = Call(clip_call->op, {ret_call}, clip_call->attrs, {}, clip_call->span); } return std::move(ret_call); } @@ -198,6 +199,7 @@ class GenerateConstantsMutator : public MixedModeMutator { } } + final_call->span = call->span; return final_call; } diff --git a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc index 0e2036505b6f..f64f485bfda2 100644 --- a/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc +++ b/src/relay/backend/contrib/cmsisnn/scalar_to_tensor_constant.cc @@ 
@@ -83,6 +83,7 @@ class ScalarToTensorConstantMutator : public MixedModeMutator {
                                  FreeTypeVars(new_body, mod_), func->attrs);
       mod_->Update(global_var, new_func);
       final_call = Call(global_var, call->args);
+      final_call->span = call->span;
     }
 
   // Substitute scalar constant with tensor constant in the call to composite function.
@@ -140,7 +141,7 @@
       String arg_name = scalar_arg.as<VarNode>()->name_hint();
       new_args.Set(i, Var(arg_name, tensor_arg->checked_type_));
     }
-    return Call(call->op, new_args, call->attrs, {});
+    return Call(call->op, new_args, call->attrs, {}, call->span);
   }
 
   // Replaces scalar constant with a tensor constant with same shape as that of the neighbouring
@@ -187,7 +188,7 @@
     if (new_args[0].same_as(new_args[1])) {
       new_args.erase(new_args.begin());
     }
-    return Call(new_func, new_args);
+    return Call(new_func, new_args, Attrs(), {}, call->span);
   }
 
  private:
diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index 3f1985b7ddfa..eb6f9ec00432 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -258,6 +258,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
     Array<Expr> compiler_begins = std::get<1>(target_n_args);
     Call new_call = Call(post_call->op, compiler_begins, post_call->attrs);
     new_call->checked_type_ = pre->checked_type_;
+    new_call->span = pre->span;
 
     // Update the target map.
     op_expr_to_target_[new_call] = target;
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index 844d08c66e03..e6ebec6ac4fa 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -28,7 +28,6 @@
 import os
 import struct
 import numpy as np
-import tflite.Model
 import math
 from enum import IntEnum
 import tensorflow as tf
@@ -311,7 +310,15 @@ def representative_dataset():
         converter.inference_output_type = tf.int8
         tflite_graph = converter.convert()
 
-    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    # Get TFLite model from buffer
+    try:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
 
     relay_module, params = relay.frontend.from_tflite(tflite_model)
     mod = partition_for_ethosu(relay_module, params)
diff --git a/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py b/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py
new file mode 100644
index 000000000000..2a9d88e41210
--- /dev/null
+++ b/tests/python/contrib/test_ethosu/test_pass_operations_distribution.py
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
+# See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import numpy as np
+
+from tvm import relay
+from tests.python.contrib.test_ethosu.infra import get_tflite_graph
+from tvm.relay.op.contrib.ethosu import partition_for_ethosu
+from tvm.relay.analysis.operations_distribution import analyze_operations_distribution
+from tvm.relay.transform.suffixes import tag_suffixes
+
+
+def test_operations_distribution_ethos():
+
+    tflite = pytest.importorskip("tflite")
+    tensorflow = pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+
+    import tensorflow as tf
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return tf.pad(
+            op,
+            [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]],
+            "CONSTANT",
+        )
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+
+    # Get TFLite model from buffer
+    try:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(tflite_model)
+
+    mod = tag_suffixes(mod)
+    mod = partition_for_ethosu(mod, params)
+    operations_distribution = analyze_operations_distribution(mod)
+
+    expected = {
+        "Pad_PART_0": ["generic", "generic", 1],
+        "Conv2D2_PART_2": ["ethos-u", "ethos-u.qnn_conv2d", 3],
+        "Conv2D2_PART_1": ["ethos-u", "ethos-u.qnn_conv2d", 3],
+        "Conv2D2_PART_0": ["ethos-u", "ethos-u.qnn_conv2d", 3],
+        "Identity_PART_0": ["ethos-u", "ethos-u.pad2d", 4],
+        "Pad_1_PART_0": ["ethos-u", "ethos-u.pad2d", 5],
+    }
+
+    assert operations_distribution == expected
+
+
+def test_operations_distribution_generic():
+
+    tflite = pytest.importorskip("tflite")
+    tensorflow = pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+
+    import tensorflow as tf
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+    dilations_out = 32
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=dilations_out,
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+
+    # Get TFLite model from buffer
+    try:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_graph, 0)
+    except AttributeError:
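+        # assumption: older versions of the tflite package expose Model only
+        # as a submodule, hence this fallback import path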
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(tflite_model)
+
+    mod = tag_suffixes(mod)
+    mod = partition_for_ethosu(mod, params)
+    operations_distribution = analyze_operations_distribution(mod)
+
+    expected = {
+        "Identity_PART_0": ["generic", "generic", 1],
+        "Pad_1_PART_0": ["generic", "generic", 2],
+        "Pad_PART_0": ["generic", "generic", 3],
+        "Conv2D2_PART_2": ["generic", "generic", 4],
+        "Conv2D2_PART_1": ["generic", "generic", 5],
+        "Conv2D2_PART_0": ["generic", "generic", 6],
+    }
+
+    assert operations_distribution == expected
+
+
+if __name__ == "__main__":
+    test_operations_distribution_ethos()
+    test_operations_distribution_generic()
diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py
index 61b1828aad99..f624984481da 100644
--- a/tests/python/driver/tvmc/test_compiler.py
+++ b/tests/python/driver/tvmc/test_compiler.py
@@ -16,6 +16,7 @@
 # under the License.
 import os
 import re
+import numpy as np
 import shutil
 import tarfile
 from os import path
@@ -29,6 +30,7 @@
 import tvm.testing
 from tvm.relay.op.contrib.ethosn import ethosn_available
 from tvm.relay.backend import Runtime, Executor
+from tvm import relay
 from tvm.contrib.target.vitis_ai import vitis_ai_available
 
 
@@ -49,6 +51,355 @@ def test_save_dumps(tmpdir_factory):
     assert path.exists("{}/{}".format(tmpdir, "fake_module.relay"))
 
 
+def test_save_dump_offloads_ethosu(tmp_path_factory):
+
+    tflite = pytest.importorskip("tflite")
+    tensorflow = pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+
+    import tensorflow as tf
+    import tflite.Model
+    from tvm.driver.tvmc.model import TVMCModel
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        weight_shape[2] = 3
+        weights1 = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        weights2 = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op1 = tf.nn.conv2d(
+            op,
+            filters=weights1,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op2 = tf.nn.conv2d(
+            op,
+            filters=weights2,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op = tf.math.add(op1, op2)
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[1]], [padding_out[2], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return op
+
+    from tests.python.contrib.test_ethosu.infra import get_tflite_graph
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    mod, params = relay.frontend.from_tflite(tflite_model)
+
+    tvmc_model = TVMCModel(mod, params)
+
+    output_dir = tmp_path_factory.mktemp("tmp")
+    output_file_name = os.path.join(str(output_dir), "list.txt")
+
+    tvmc.compiler.compile_model(
+        tvmc_model,
+        target="ethos-u,cmsis-nn,c",
+        runtime=Runtime("crt"),
+        tuning_records="",
+        package_path="module.tar",
+        executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}),
+        cross="",
+        cross_options="",
+        output_format="mlf",
+        dump_offloads=output_file_name,
+        disabled_pass=[""],
+        pass_context_configs=[
+            "tir.disable_vectorize=1",
+            "tir.usmp.enable=1",
"tir.usmp.algorithm=hill_climb", + "tir.disable_storage_rewrite=1", + "relay.frontend.fill_span=1", + ], + additional_target_options={ + "c": {"mcpu": "cortex-m55"}, + "cmsis-nn": {"mcpu": "cortex-m55"}, + "ethos-u": { + "accelerator_config": "ethos-u55-256", + }, + }, + ) + + expected = [ + r"Total number of operators and distribution by targets", + r"Total: 11", + r"ethos-u: 10", + r"generic: 1", + r"", + r"ethos-u <- ethos-u.qnn_conv2d", + r'ethos-u <- %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392157f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"ethos-u <- %1 = nn.bias_add(%0, %v_param_2, axis=3)", + r'ethos-u <- %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.11364f, -128, axis=3, out_dtype="int8")', + r"ethos-u <- ethos-u.qnn_conv2d", + r'ethos-u <- %3 = qnn.conv2d(%2, %v_param_3, -128, 0, 0.11364f, meta[relay.Constant][2], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"ethos-u <- %4 = nn.bias_add(%3, %v_param_4, axis=3)", + r'ethos-u <- %7 = qnn.requantize(%4, meta[relay.Constant][3], 0, 1.56803f, -128, axis=3, out_dtype="int8")', + r"ethos-u <- ethos-u.qnn_conv2d", + r'ethos-u <- %5 = qnn.conv2d(%2, %v_param_5, -128, 0, 0.11364f, meta[relay.Constant][4], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")', + r"ethos-u <- %6 = nn.bias_add(%5, %v_param_6, axis=3)", + r'ethos-u <- %8 = qnn.requantize(%6, meta[relay.Constant][5], 0, 1.20538f, -128, axis=3, out_dtype="int8")', + r"ethos-u <- ethos-u.add", + r"ethos-u <- %9 = qnn.add(%7, %8, 1.56803f, -128, 1.20538f, -128, 2.77341f, -128)", + r"generic <- nn.pad(%9, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])", + ] + + file_path = os.path.abspath(output_file_name) + # check that file file_path was created + assert os.path.exists(file_path) + with open(file_path, "r") as f: + for i, file_string in enumerate(f): + r_output = re.search(r"(.*)\(", file_string.strip(), re.DOTALL) + r_expected = re.search(r"(.*)\(", expected[i], re.DOTALL) + # check that there is the same sequence of operations and composites, + # combined with target names + if r_output and r_expected: + assert r_output.group(0) == r_expected.group(0) + else: + assert r_output == r_expected + + +def test_save_dump_offloads_cmsis(tmp_path_factory): + + tflite = pytest.importorskip("tflite") + tensorflow = pytest.importorskip("tensorflow") + pytest.importorskip("ethosu.vela") + + import tensorflow as tf + from tvm.driver.tvmc.model import TVMCModel + + inp = (224, 224, 9) + input_shape = (1, *inp) + kernel_shape = (3, 3) + padding = (1, 1, 1, 1) + padding_out = (1, 33, 33, 1) + + @tf.function + def simple_net(x): + weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3] + weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32) + op = tf.nn.conv2d( + x, + filters=weights, + strides=1, + padding="SAME", + data_format="NHWC", + dilations=1, + ) + op = tf.nn.relu(op) + op = tf.pad( + op, + [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]], + "CONSTANT", + ) + op = tf.pad( + op, + [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]], + "CONSTANT", + ) + return tf.pad( + op, + [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]], + "CONSTANT", + ) + + from tests.python.contrib.test_ethosu.infra import get_tflite_graph + + _, 
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    mod, params = relay.frontend.from_tflite(tflite_model)
+
+    tvmc_model = TVMCModel(mod, params)
+
+    output_dir = tmp_path_factory.mktemp("tmp")
+    output_file_name = os.path.join(str(output_dir), "list.txt")
+
+    tvmc.compiler.compile_model(
+        tvmc_model,
+        target="cmsis-nn,c",
+        runtime=Runtime("crt"),
+        tuning_records="",
+        package_path="module.tar",
+        executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}),
+        cross="",
+        cross_options="",
+        output_format="mlf",
+        dump_offloads=output_file_name,
+        disabled_pass=[""],
+        pass_context_configs=[
+            "tir.disable_vectorize=1",
+            "tir.usmp.enable=1",
+            "tir.usmp.algorithm=hill_climb",
+            "tir.disable_storage_rewrite=1",
+            "relay.frontend.fill_span=1",
+        ],
+        additional_target_options={
+            "c": {"mcpu": "cortex-m55"},
+            "cmsis-nn": {"mcpu": "cortex-m55"},
+        },
+    )
+
+    expected = [
+        r"Total number of operators and distribution by targets",
+        r"Total: 7",
+        r"cmsis-nn: 4",
+        r"generic: 3",
+        r"",
+        r"cmsis-nn <- cmsis-nn.qnn_conv2d",
+        r'cmsis-nn <- %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392157f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")',
+        r"cmsis-nn <- %1 = nn.bias_add(%0, %v_param_2, axis=3)",
+        r'cmsis-nn <- %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.113405f, -128, axis=3, out_dtype="int8")',
+        r"cmsis-nn <- %3 = clip(%2, a_min=-128f, a_max=127f)",
+        r"generic <- %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])",
+        r"generic <- %5 = nn.pad(%4, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+        r"generic <- nn.pad(%5, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+    ]
+
+    file_path = os.path.abspath(output_file_name)
+    # check that the file file_path was created
+    assert os.path.exists(file_path)
+    with open(file_path, "r") as f:
+        for i, file_string in enumerate(f):
+            r_output = re.search(r"(.*)\(", file_string.strip(), re.DOTALL)
+            r_expected = re.search(r"(.*)\(", expected[i], re.DOTALL)
+            # check that there is the same sequence of operations and composites,
+            # combined with target names
+            if r_output and r_expected:
+                assert r_output.group(0) == r_expected.group(0)
+            else:
+                assert r_output == r_expected
+
+
+def test_save_dump_offloads_generic(tmp_path_factory):
+
+    tflite = pytest.importorskip("tflite")
+    tensorflow = pytest.importorskip("tensorflow")
+    pytest.importorskip("ethosu.vela")
+
+    import tensorflow as tf
+    from tvm.driver.tvmc.model import TVMCModel
+
+    inp = (224, 224, 9)
+    input_shape = (1, *inp)
+    kernel_shape = (3, 3)
+    padding = (1, 1, 1, 1)
+    padding_out = (1, 33, 33, 1)
+
+    @tf.function
+    def simple_net(x):
+        weight_shape = [kernel_shape[0], kernel_shape[1], input_shape[3], 3]
+        weights = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+        op = tf.nn.conv2d(
+            x,
+            filters=weights,
+            strides=1,
+            padding="SAME",
+            data_format="NHWC",
+            dilations=1,
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding_out[2]], [padding_out[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        op = tf.pad(
+            op,
+            [[0, 0], [padding[0], padding[2]], [padding[1], padding[3]], [0, 0]],
+            "CONSTANT",
+        )
+        return tf.pad(
+            op,
+            [[0, 0], [padding_out[0], padding[2]], [padding[1], padding_out[3]], [0, 0]],
+            "CONSTANT",
+        )
+
+    from tests.python.contrib.test_ethosu.infra import get_tflite_graph
+
+    _, tflite_graph = get_tflite_graph(simple_net, [input_shape])
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+    mod, params = relay.frontend.from_tflite(tflite_model)
+
+    tvmc_model = TVMCModel(mod, params)
+
+    output_dir = tmp_path_factory.mktemp("tmp")
+    output_file_name = os.path.join(str(output_dir), "list.txt")
+
+    tvmc.compiler.compile_model(
+        tvmc_model,
+        target="c",
+        runtime=Runtime("crt"),
+        tuning_records="",
+        package_path="module.tar",
+        executor=Executor("aot", {"unpacked-api": 1, "interface-api": "c", "link-params": True}),
+        cross="",
+        cross_options="",
+        output_format="mlf",
+        dump_offloads=output_file_name,
+        disabled_pass=[""],
+        pass_context_configs=[
+            "tir.disable_vectorize=1",
+            "tir.usmp.enable=1",
+            "tir.usmp.algorithm=hill_climb",
+            "tir.disable_storage_rewrite=1",
+            "relay.frontend.fill_span=1",
+        ],
+        additional_target_options={
+            "c": {"mcpu": "cortex-m55"},
+        },
+    )
+
+    expected = [
+        r"Total number of operators and distribution by targets",
+        r"Total: 6",
+        r"generic: 6",
+        r"",
+        r'generic <- %0 = qnn.conv2d(%x, %v_param_1, -128, 0, 0.00392156f, meta[relay.Constant][0], padding=[1, 1, 1, 1], channels=3, kernel_size=[3, 3], data_layout="NHWC", kernel_layout="HWIO", out_dtype="int32")',
+        r"generic <- %1 = nn.bias_add(%0, %v_param_2, axis=3)",
+        r'generic <- %2 = qnn.requantize(%1, meta[relay.Constant][1], 0, 0.103975f, -128, axis=3, out_dtype="int8")',
+        r"generic <- %3 = nn.pad(%2, -128f, pad_width=[[0, 0], [1, 33], [33, 1], [0, 0]])",
+        r"generic <- %4 = nn.pad(%3, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+        r"generic <- nn.pad(%4, -128f, pad_width=[[0, 0], [1, 1], [1, 1], [0, 0]])",
+    ]
+
+    file_path = os.path.abspath(output_file_name)
+    # check that the file file_path was created
+    assert os.path.exists(file_path)
+    with open(file_path, "r") as f:
+        for i, file_string in enumerate(f):
+            r_output = re.search(r"(.*)\(", file_string.strip(), re.DOTALL)
+            r_expected = re.search(r"(.*)\(", expected[i], re.DOTALL)
+            # check that there is the same sequence of operations and composites,
+            # combined with target names
+            if r_output and r_expected:
+                assert r_output.group(0) == r_expected.group(0)
+            else:
+                assert r_output == r_expected
+
+
 # End to end tests for compilation
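For reference, a minimal end-to-end sketch of the new helpers introduced in this diff (hedged: `mod` and `params` are assumed to come from a Relay frontend, mirroring the tests above):

    from tvm.relay.op.contrib.ethosu import partition_for_ethosu
    from tvm.relay.analysis.operations_distribution import analyze_operations_distribution
    from tvm.relay.transform.suffixes import tag_suffixes

    mod = tag_suffixes(mod)                  # make every call span unique
    mod = partition_for_ethosu(mod, params)  # offload what the codegen supports
    print(analyze_operations_distribution(mod))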