remove enable_fixed_point option, separate MatMulFixedPointRewriter implementation
Aleksei-grovety committed Feb 2, 2024
1 parent cdb1702 commit 53d7a2d
Showing 5 changed files with 123 additions and 64 deletions.
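
For orientation, after this change fixed-point lowering is driven by "fixed_point_fraction_size" alone; the removed "enable_fixed_point" key no longer exists. A minimal sketch of the resulting user-facing flow, mirroring the updated test in tests/python/contrib/test_ethosu/test_codegen.py (shapes and the fraction size here are illustrative):

import tvm
from tvm import relay
from tvm.relay.op.contrib.ethosu import partition_for_ethosu

# A plain int16 nn.dense, the shape matched by matmul_fixed_point_pattern() below.
ifm = relay.var("ifm", shape=(1, 8), dtype="int16")
weights = relay.var("weights", shape=(8, 8), dtype="int16")
cpu_mod = tvm.IRModule.from_expr(relay.nn.dense(ifm, weights))

fract_size = 14  # number of fractional bits in the int16 fixed-point values
config = {"fixed_point_fraction_size": fract_size}  # no "enable_fixed_point"
with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
    ethosu_mod = partition_for_ethosu(cpu_mod)
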
44 changes: 37 additions & 7 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1402,17 +1402,20 @@ def callback(self, pre, post, node_map):
         return ethosu_fc
 
 
-class MatMulRewriter(DFPatternCallback):
-    """Legalize matrix multiplication to an NPU operator"""
+class MatrixMultiplicationRewriter(DFPatternCallback):
+    """Legalize matrix multiplication with two tensors into sequence of NPU operators"""
 
-    def __init__(self):
+    def __init__(
+        self,
+        params_class: Type,
+        pattern: CallPattern,
+    ):
         super().__init__(require_type=True)
-        self.pattern = (
-            wildcard().has_attr({"Composite": ethosu_patterns.MatMulParams.composite_name})
-        )(wildcard(), wildcard())
+        self.pattern = pattern
+        self.params_class = params_class
 
     def callback(self, pre, post, node_map):
-        params = ethosu_patterns.MatMulParams(post.op.body)
+        params = self.params_class(post.op.body)
         ifm = post.args[0]
         ifm2 = post.args[1]
         lut = relay.const([], dtype=params.ifm.dtype)
@@ -1497,6 +1500,32 @@ def callback(self, pre, post, node_map):
         return relay.reshape(concat, params.ofm.shape)
 
 
+class MatMulRewriter(MatrixMultiplicationRewriter):
+    """Convert ethos-u.matmul composite function to sequence of NPU operators"""
+
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.MatMulParams,
+            pattern=(
+                wildcard().has_attr({"Composite": ethosu_patterns.MatMulParams.composite_name})
+            )(wildcard(), wildcard()),
+        )
+
+
+class MatMulFixedPointRewriter(MatrixMultiplicationRewriter):
+    """Convert ethos-u.matmul_fixed_point composite function to sequence of NPU operators"""
+
+    def __init__(self):
+        super().__init__(
+            params_class=ethosu_patterns.MatMulFixedPointParams,
+            pattern=(
+                wildcard().has_attr(
+                    {"Composite": ethosu_patterns.MatMulFixedPointParams.composite_name}
+                )
+            )(wildcard(), wildcard()),
+        )
+
+
 class PadRewriter(DFPatternCallback):
     """Convert ethos-u.pad2d composite function to ethosu_depthwise_conv2d
     operator"""
@@ -1644,6 +1673,7 @@ def transform_npu_function(self, _, func: relay.Function) -> relay.Function:
             PartitionedSplitRewriter(),
             FullyConnectedRewriter(),
             MatMulRewriter(),
+            MatMulFixedPointRewriter(),
             SplitRewriter(),
             ChannelPadRewriter(),
             Conv2DRewriter(),
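
These classes are standard DFPatternCallbacks, and the legalization pass above applies each one over the function with the dataflow-pattern rewriter. A rough sketch of that mechanism (illustrative, not this file's verbatim code):

from tvm.relay.dataflow_pattern import rewrite

def apply_rewriters(func, rewriters):
    # Each rewriter carries a `pattern`; rewrite() replaces every match with
    # the result of the rewriter's callback(), e.g. the NPU op sequence.
    for rewriter in rewriters:  # e.g. [MatMulRewriter(), MatMulFixedPointRewriter()]
        func = rewrite(rewriter, func)
    return func
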
126 changes: 84 additions & 42 deletions python/tvm/relay/op/contrib/ethosu.py
@@ -1808,46 +1808,29 @@ class FullyConnectedParams:
     @requires_vela
     def __init__(self, func_body):
         from tvm.relay.backend.contrib.ethosu.util import QDenseArgs  # type: ignore
-        from tvm.relay.backend.contrib.ethosu.util import (
-            BiasAddArgs,
-            RequantArgs,
-            get_fixed_point_fraction_size,
-            is_fixed_point_enabled,
-        )
+        from tvm.relay.backend.contrib.ethosu.util import BiasAddArgs, RequantArgs
 
         self.activation = None
-        if is_fixed_point_enabled():
-            fract_scale = tvm.relay.Constant(
-                tvm.nd.array(np.array(1 / 2 ** get_fixed_point_fraction_size()))
-            )
-            fract_zero_point = tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32")))
-            qnn_dense = func_body
-            bias_add = None
+        if str(func_body.op.name) == "clip":
+            self.activation = func_body
+            requantize_op = self.activation.args[0]
         else:
-            if str(func_body.op.name) == "clip":
-                self.activation = func_body
-                requantize_op = self.activation.args[0]
-            else:
-                requantize_op = func_body
+            requantize_op = func_body
 
-            call = requantize_op.args[0]
-            if str(requantize_op.args[0].op.name) == "nn.bias_add":
-                bias_add = call
-                qnn_dense = call.args[0]
-            else:
-                bias_add = None
-                qnn_dense = call
+        call = requantize_op.args[0]
+        if str(requantize_op.args[0].op.name) == "nn.bias_add":
+            bias_add = call
+            qnn_dense = call.args[0]
+        else:
+            bias_add = None
+            qnn_dense = call
 
         # weights & biases are params as they should be constant
         self.weights = TensorParams(
             qnn_dense.args[QDenseArgs.WEIGHTS.value],
             None,
-            fract_scale
-            if is_fixed_point_enabled()
-            else qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value],
-            fract_zero_point
-            if is_fixed_point_enabled()
-            else qnn_dense.args[QDenseArgs.WEIGHTS_ZERO_POINT.value],
+            qnn_dense.args[QDenseArgs.WEIGHTS_SCALE.value],
+            qnn_dense.args[QDenseArgs.WEIGHTS_ZERO_POINT.value],
         )
         self.biases = (
             TensorParams(
@@ -1862,20 +1845,14 @@ def __init__(self, func_body):
         self.ifm = TensorParams(
             qnn_dense.args[QDenseArgs.IFM.value],
             None,
-            fract_scale if is_fixed_point_enabled() else qnn_dense.args[QDenseArgs.IFM_SCALE.value],
-            fract_zero_point
-            if is_fixed_point_enabled()
-            else qnn_dense.args[QDenseArgs.IFM_ZERO_POINT.value],
+            qnn_dense.args[QDenseArgs.IFM_SCALE.value],
+            qnn_dense.args[QDenseArgs.IFM_ZERO_POINT.value],
         )
         self.ofm = TensorParams(
             func_body,
             None,
-            fract_scale
-            if is_fixed_point_enabled()
-            else requantize_op.args[RequantArgs.OFM_SCALE.value],
-            fract_zero_point
-            if is_fixed_point_enabled()
-            else requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
+            requantize_op.args[RequantArgs.OFM_SCALE.value],
+            requantize_op.args[RequantArgs.OFM_ZERO_POINT.value],
         )
 
     def is_valid(self) -> bool:
@@ -1958,7 +1935,67 @@ def matmul_pattern():
     )
     req = is_op("qnn.requantize")(dense, is_constant(), is_constant(), is_constant(), is_constant())
     optional_clip = req.optional(is_op("clip"))
-    return optional_clip | is_op("nn.dense")(wildcard(), wildcard())
+    return optional_clip
+
+
+class MatMulFixedPointParams:
+    """
+    This class will parse a call to an ethos-u.matmul_fixed_point composite
+    function and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.matmul_fixed_point"
+
+    @requires_vela
+    def __init__(self, func_body):
+        from tvm.relay.backend.contrib.ethosu.util import QDenseArgs, get_fixed_point_fraction_size
+
+        self.fraction_size = get_fixed_point_fraction_size()
+        fract_scale = tvm.relay.Constant(tvm.nd.array(np.array(1 / 2**self.fraction_size)))
+        fract_zero_point = tvm.relay.Constant(tvm.nd.array(np.array(0, dtype="int32")))
+        dense = func_body
+
+        self.activation = None
+        self.weights = TensorParams(
+            dense.args[QDenseArgs.WEIGHTS.value],
+            None,
+            fract_scale,
+            fract_zero_point,
+        )
+        self.ifm = TensorParams(
+            dense.args[QDenseArgs.IFM.value],
+            None,
+            fract_scale,
+            fract_zero_point,
+        )
+        self.ofm = TensorParams(
+            dense,
+            None,
+            fract_scale,
+            fract_zero_point,
+        )
+
+    def is_valid(self) -> bool:
+        """
+        Checks whether matrix multiplication has compatible attributes with HW
+        """
+
+        if self.fraction_size < 0 or self.fraction_size > 16:
+            return False
+        if not check_valid_dtypes([self.ifm, self.ofm], supported_dtypes=[np.int16]):
+            return False
+        if not len(self.ifm.shape) == 2:
+            return False
+        if not len(self.ofm.shape) == 2:
+            return False
+        # The weights must be transposed
+        if self.ifm.shape[1] != self.weights.shape[1]:
+            return False
+        return True
+
+
+def matmul_fixed_point_pattern():
+    return is_op("nn.dense")(wildcard(), wildcard())
 
 
 class HardSwishParams:

class HardSwishParams:
@@ -2251,6 +2288,11 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
             matmul_pattern(),
             lambda pat: MatMulParams(pat).is_valid(),
         ),
+        (
+            MatMulFixedPointParams.composite_name,
+            matmul_fixed_point_pattern(),
+            lambda pat: MatMulFixedPointParams(pat).is_valid(),
+        ),
         (
             MaxPool2DParams.composite_name,
             qnn_maxpool2d_pattern(),
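
The fract_scale = 1 / 2**fraction_size constant with a zero point of 0 amounts to a plain Q-format reading of the int16 tensors: a real value x is stored as round(x * 2**fraction_size). A standalone illustration of that arithmetic (the helper names here are illustrative, not part of the patch):

import numpy as np

def to_fixed_point(x, fraction_size):
    # Store real-valued x as int16 with `fraction_size` fractional bits.
    return np.round(np.asarray(x) * 2.0**fraction_size).astype(np.int16)

def from_fixed_point(q, fraction_size):
    # Inverse mapping; the scale 1 / 2**fraction_size matches fract_scale above.
    return np.asarray(q, dtype=np.float64) / 2.0**fraction_size

q = to_fixed_point(1.5, 8)            # 1.5 * 2**8 = 384
assert from_fixed_point(q, 8) == 1.5  # 384 / 2**8 recovers 1.5 exactly
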
10 changes: 1 addition & 9 deletions src/relay/backend/contrib/ethosu/compiler_attrs.cc
@@ -42,7 +42,6 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
   Bool enable_cascader = Bool(false);
   Bool enable_striping = Bool(false);
   Bool disable_copying_constants = Bool(false);
-  Bool enable_fixed_point = Bool(false);
   Integer fixed_point_fraction_size = Integer(0);
   String dev_force_block_config;
   String dev_max_open_plans;
@@ -74,17 +73,10 @@ struct EthosUCompilerConfigNode : public tvm::AttrsNode<EthosUCompilerConfigNode
           "the linker script for section \".rodata.tvm\" that the constants are located in SRAM)")
       .set_default(Bool(false));
   String dev_warning = "Option is intended for development and debugging purposes only. ";
-  TVM_ATTR_FIELD(enable_fixed_point)
-      .describe(
-          "Whether calculation with fixed point is enabled. When this option "
-          "is "
-          "enabled, it is assumed that input data should be converted to fixed point "
-          "representation")
-      .set_default(Bool(false));
   TVM_ATTR_FIELD(fixed_point_fraction_size)
       .describe(
           "Fraction size refers to the number of bits used to represent the fractional part of a "
-          "fixed point number")
+          "fixed point number for non-quantized int16 operations")
       .set_default(Integer(0));
   TVM_ATTR_FIELD(dev_force_block_config)
       .describe((dev_warning + String("Force the block config to a given value; format = "
4 changes: 0 additions & 4 deletions tests/python/contrib/test_ethosu/infra.py
@@ -124,7 +124,6 @@ def create_test_runner(
     enable_cascader=False,
     enable_striping=False,
     workspace_pools=None,
-    enable_fixed_point=False,
     fixed_point_fraction_size=0,
 ):
 
@@ -170,7 +169,6 @@ def create_test_runner(
                 "accelerator_config": accel,
                 "enable_cascader": enable_cascader,
                 "enable_striping": enable_striping,
-                "enable_fixed_point": enable_fixed_point,
                 "fixed_point_fraction_size": fixed_point_fraction_size,
             },
             "tir.usmp.enable": enable_usmp,
@@ -337,7 +335,6 @@ def compare_ethosu_with_reference(
     output_tolerance=0,
     print_cmm=False,
     enable_cascader=None,
-    enable_fixed_point=False,
     fixed_point_fraction_size=0,
 ):
     if enable_cascader is None:
@@ -365,7 +362,6 @@ def compare_ethosu_with_reference(
         enable_cascader=enable_cascader,
         enable_striping=False,
         workspace_pools=workspace_pools,
-        enable_fixed_point=enable_fixed_point,
         fixed_point_fraction_size=fixed_point_fraction_size,
     )
     compiled_models = build_source(
3 changes: 1 addition & 2 deletions tests/python/contrib/test_ethosu/test_codegen.py
@@ -1661,7 +1661,7 @@ def convert_to_fixed_point(arr, fract_size):
     output_data = {"output": convert_to_fixed_point(output_data, fract_size)}
     tolerance = convert_to_fixed_point(tolerance, fract_size)
 
-    config = {"enable_fixed_point": True, "fixed_point_fraction_size": fract_size}
+    config = {"fixed_point_fraction_size": fract_size}
     with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}):
         ethosu_mod = partition_for_ethosu(cpu_mod)
 
@@ -1672,7 +1672,6 @@ def convert_to_fixed_point(arr, fract_size):
         accel_type,
         enable_cascader=False,
         output_tolerance=tolerance,
-        enable_fixed_point=True,
         fixed_point_fraction_size=fract_size,
     )
 
