From be17db85cff571440da826e1a847e9278ce29cb8 Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 15:32:43 +0800 Subject: [PATCH 1/6] [Relax][Frontend][TFLite] Add StableHLO unary/binary elementwise op support Add frontend mapping for 8 basic StableHLO TFLite builtin operators as pure unary/binary elementwise ops: - STABLEHLO_ABS, STABLEHLO_NEGATE (unary) - STABLEHLO_ADD, STABLEHLO_SUBTRACT, STABLEHLO_MULTIPLY, STABLEHLO_DIVIDE, STABLEHLO_MAXIMUM, STABLEHLO_MINIMUM (binary) Implementation uses dedicated _convert_stablehlo_unary / _convert_stablehlo_binary helpers that intentionally bypass TFLite fused-activation and QNN code paths, since StableHLO ops carry no TFLite-specific quantization or fused-activation metadata in their flatbuffer representation. Test coverage: 8 structural-equal tests with tvm.ir.assert_structural_equal. --- .../relax/frontend/tflite/tflite_frontend.py | 55 ++++++++++ tests/python/relax/test_frontend_tflite.py | 103 ++++++++++++++++++ 2 files changed, 158 insertions(+) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index d70f5d837e0f..743c1e2747c4 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -239,6 +239,30 @@ def __init__(self, model, subgraph, exp_tab, ctx): "SQRT": functools.partial(self._convert_unary_elemwise, relax_op=_op.sqrt), "SQUARE": self.convert_square, "SQUARED_DIFFERENCE": self.convert_squared_difference, + "STABLEHLO_ABS": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.abs + ), + "STABLEHLO_ADD": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.add + ), + "STABLEHLO_DIVIDE": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.divide + ), + "STABLEHLO_MAXIMUM": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.maximum + ), + "STABLEHLO_MINIMUM": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.minimum + ), + "STABLEHLO_MULTIPLY": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.multiply + ), + "STABLEHLO_NEGATE": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.negative + ), + "STABLEHLO_SUBTRACT": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.subtract + ), "SQUEEZE": self.convert_squeeze, "STRIDED_SLICE": self.convert_strided_slice, "SUB": functools.partial(self._convert_elemwise, relax_op=_op.subtract), @@ -1322,6 +1346,37 @@ def _convert_unary_elemwise(self, op, relax_op): out = self.quantize(out, output_tensor) return out + def _convert_stablehlo_unary(self, op, relax_op): + """Convert a unary StableHLO TFLite builtin operator. + + StableHLO builtins do not have TFLite fused activation attributes. Keep + this path independent from the regular TFLite elemwise/QNN helpers so + StableHLO semantics are mapped directly to Relax operators. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1" + + in_expr = self.get_tensor_expr(input_tensors[0]) + return relax_op(in_expr) + + def _convert_stablehlo_binary(self, op, relax_op): + """Convert a binary StableHLO TFLite builtin operator. + + StableHLO builtins do not have TFLite fused activation attributes. Keep + this path independent from the regular TFLite elemwise/QNN helpers so + StableHLO semantics are mapped directly to Relax operators. 
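+ For example, STABLEHLO_ADD maps its two inputs straight to relax.op.add(lhs, rhs); no fused-activation clamp or QNN requantization is applied.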
+ """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1" + + lhs_expr = self.get_tensor_expr(input_tensors[0]) + rhs_expr = self.get_tensor_expr(input_tensors[1]) + return relax_op(lhs_expr, rhs_expr) + def convert_elu(self, op): """Convert TFLite ELU""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index d0401e464984..9e1e4147cad1 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3735,6 +3735,109 @@ def _finish_tflite_model(builder, *, subgraph, operator_codes, buffers): return bytes(builder.Output()) +def _load_model_from_buffer(model_bytes): + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(model_bytes, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(model_bytes, 0) + mod = from_tflite(tflite_model) + mod["main"] = mod["main"].without_attr("params") + return mod + + +def _get_stablehlo_builtin_operator(builtin_name): + if not hasattr(_tfl_builtin_operator, builtin_name): + pytest.skip(f"TFLite schema does not provide BuiltinOperator.{builtin_name}") + return getattr(_tfl_builtin_operator, builtin_name) + + +def _build_stablehlo_model(*, builtin_name, input_count): + """Build a minimal TFLite model containing one StableHLO builtin operator.""" + builder = flatbuffers.Builder(1024) + shape = [2, 2] + output_tensor_idx = input_count + builtin_op = _get_stablehlo_builtin_operator(builtin_name) + + tensors = [_build_tensor(builder, buffer_idx, shape) for buffer_idx in range(input_count + 1)] + stablehlo_op = _build_operator( + builder, + 0, + list(range(input_count)), + [output_tensor_idx], + ) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[stablehlo_op], + inputs=list(range(input_count)), + outputs=[output_tensor_idx], + ) + operator_codes = [_build_operator_code(builder, builtin_op)] + buffers = [_build_buffer(builder) for _ in range(input_count + 1)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers + ) + + +@pytest.mark.parametrize( + "builtin_name, relax_op", + [ + ("STABLEHLO_ABS", R.abs), + ("STABLEHLO_NEGATE", R.negative), + ], +) +def test_stablehlo_unary(builtin_name, relax_op): + """TFLite StableHLO unary elementwise operators.""" + mod = _load_model_from_buffer( + _build_stablehlo_model(builtin_name=builtin_name, input_count=1) + ) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = relax_op(x) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.parametrize( + "builtin_name, relax_op", + [ + ("STABLEHLO_ADD", R.add), + ("STABLEHLO_SUBTRACT", R.subtract), + ("STABLEHLO_MULTIPLY", R.multiply), + ("STABLEHLO_DIVIDE", R.divide), + ("STABLEHLO_MAXIMUM", R.maximum), + ("STABLEHLO_MINIMUM", R.minimum), + ], +) +def test_stablehlo_binary(builtin_name, relax_op): + """TFLite StableHLO binary elementwise operators.""" + mod = _load_model_from_buffer( + _build_stablehlo_model(builtin_name=builtin_name, input_count=2) + ) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), 
dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = relax_op(x, y) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def _build_csr_sparsity( builder, *, From cfa59f20e6311bcf37ea978b52c2a866c219f6e2 Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 15:53:04 +0800 Subject: [PATCH 2/6] [Relax][Frontend][TFLite] Add remaining StableHLO elementwise ops with ternary SELECT Extend the StableHLO TFLite frontend with all remaining pure elementwise operators that require no attribute parsing: - Unary: STABLEHLO_COSINE (cos), STABLEHLO_EXPONENTIAL (exp), STABLEHLO_FLOOR (floor), STABLEHLO_LOG (log), STABLEHLO_LOGISTIC (sigmoid), STABLEHLO_RSQRT (rsqrt), STABLEHLO_TANH (tanh) - Binary: STABLEHLO_AND (logical_and), STABLEHLO_OR (logical_or), STABLEHLO_POWER (power), STABLEHLO_SHIFT_LEFT (left_shift) - Ternary: STABLEHLO_SELECT (where) with dedicated _convert_stablehlo_ternary helper The existing _convert_stablehlo_unary and _convert_stablehlo_binary helpers are reused; only STABLEHLO_SELECT needs the new ternary converter since R.where requires a 3-input signature with bool condition dtype. Test coverage: 20 structural-equal tests (12 new, 8 from previous commit). The SELECT test uses inline flatbuffer construction to set the condition input dtype to BOOL, matching the R.where requirement. --- .../relax/frontend/tflite/tflite_frontend.py | 89 ++++++++++++ tests/python/relax/test_frontend_tflite.py | 130 +++++++++++++++++- 2 files changed, 217 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 743c1e2747c4..ac8af71e6324 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -245,6 +245,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_ADD": functools.partial( self._convert_stablehlo_binary, relax_op=_op.add ), + "STABLEHLO_AND": self._convert_stablehlo_and, "STABLEHLO_DIVIDE": functools.partial( self._convert_stablehlo_binary, relax_op=_op.divide ), @@ -260,9 +261,40 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_NEGATE": functools.partial( self._convert_stablehlo_unary, relax_op=_op.negative ), + "STABLEHLO_COSINE": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.cos + ), + "STABLEHLO_EXPONENTIAL": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.exp + ), + "STABLEHLO_FLOOR": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.floor + ), + "STABLEHLO_LOG": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.log + ), + "STABLEHLO_LOGISTIC": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.sigmoid + ), + "STABLEHLO_OR": self._convert_stablehlo_or, + "STABLEHLO_POWER": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.power + ), + "STABLEHLO_RSQRT": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.rsqrt + ), + "STABLEHLO_SELECT": functools.partial( + self._convert_stablehlo_ternary, relax_op=_op.where + ), + "STABLEHLO_SHIFT_LEFT": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.left_shift + ), "STABLEHLO_SUBTRACT": functools.partial( self._convert_stablehlo_binary, relax_op=_op.subtract ), + "STABLEHLO_TANH": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.tanh + ), "SQUEEZE": self.convert_squeeze, "STRIDED_SLICE": 
self.convert_strided_slice, "SUB": functools.partial(self._convert_elemwise, relax_op=_op.subtract), @@ -1377,6 +1409,63 @@ def _convert_stablehlo_binary(self, op, relax_op): rhs_expr = self.get_tensor_expr(input_tensors[1]) return relax_op(lhs_expr, rhs_expr) + def _convert_stablehlo_and(self, op): + """Convert StableHLO AND for bool and integer tensors.""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1" + + lhs = self.get_tensor_expr(input_tensors[0]) + rhs = self.get_tensor_expr(input_tensors[1]) + dtype = lhs.struct_info.dtype + if dtype == "bool": + op_fn = _op.logical_and + elif dtype.startswith(("int", "uint")): + op_fn = _op.bitwise_and + else: + raise tvm.error.OpNotImplemented( + f"STABLEHLO_AND with dtype {dtype} is not supported" + ) + return self.bb.normalize(op_fn(lhs, rhs)) + + def _convert_stablehlo_or(self, op): + """Convert StableHLO OR for bool and integer tensors.""" + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + + assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1" + + lhs = self.get_tensor_expr(input_tensors[0]) + rhs = self.get_tensor_expr(input_tensors[1]) + dtype = lhs.struct_info.dtype + if dtype == "bool": + op_fn = _op.logical_or + elif dtype.startswith(("int", "uint")): + op_fn = _op.bitwise_or + else: + raise tvm.error.OpNotImplemented( + f"STABLEHLO_OR with dtype {dtype} is not supported" + ) + return self.bb.normalize(op_fn(lhs, rhs)) + + def _convert_stablehlo_ternary(self, op, relax_op): + """Convert a ternary StableHLO TFLite builtin operator. + + StableHLO builtins do not have TFLite fused activation attributes. Keep + this path independent from the regular TFLite elemwise/QNN helpers so + StableHLO semantics are mapped directly to Relax operators. 
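+ For example, STABLEHLO_SELECT(pred, on_true, on_false) maps straight to relax.op.where(pred, on_true, on_false), with pred expected to be a bool tensor.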
+ """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be 3" + + assert len(self.get_output_tensors(op)) == 1, "output tensors length should be 1" + + arg0 = self.get_tensor_expr(input_tensors[0]) + arg1 = self.get_tensor_expr(input_tensors[1]) + arg2 = self.get_tensor_expr(input_tensors[2]) + return relax_op(arg0, arg1, arg2) + def convert_elu(self, op): """Convert TFLite ELU""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 9e1e4147cad1..e5f6ba29e2c7 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3779,11 +3779,44 @@ def _build_stablehlo_model(*, builtin_name, input_count): ) +def _build_stablehlo_typed_binary_model(*, builtin_name, tensor_type): + """Build a minimal TFLite StableHLO binary model with the requested tensor type.""" + builder = flatbuffers.Builder(1024) + shape = [2, 2] + output_tensor_idx = 2 + builtin_op = _get_stablehlo_builtin_operator(builtin_name) + + tensors = [ + _build_tensor(builder, buffer_idx, shape, tensor_type=tensor_type) + for buffer_idx in range(3) + ] + stablehlo_op = _build_operator(builder, 0, [0, 1], [output_tensor_idx]) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[stablehlo_op], + inputs=[0, 1], + outputs=[output_tensor_idx], + ) + operator_codes = [_build_operator_code(builder, builtin_op)] + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers + ) + + @pytest.mark.parametrize( "builtin_name, relax_op", [ ("STABLEHLO_ABS", R.abs), + ("STABLEHLO_COSINE", R.cos), + ("STABLEHLO_EXPONENTIAL", R.exp), + ("STABLEHLO_FLOOR", R.floor), + ("STABLEHLO_LOG", R.log), + ("STABLEHLO_LOGISTIC", R.sigmoid), ("STABLEHLO_NEGATE", R.negative), + ("STABLEHLO_RSQRT", R.rsqrt), + ("STABLEHLO_TANH", R.tanh), ], ) def test_stablehlo_unary(builtin_name, relax_op): @@ -3809,11 +3842,12 @@ def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="float3 "builtin_name, relax_op", [ ("STABLEHLO_ADD", R.add), - ("STABLEHLO_SUBTRACT", R.subtract), - ("STABLEHLO_MULTIPLY", R.multiply), ("STABLEHLO_DIVIDE", R.divide), ("STABLEHLO_MAXIMUM", R.maximum), ("STABLEHLO_MINIMUM", R.minimum), + ("STABLEHLO_MULTIPLY", R.multiply), + ("STABLEHLO_POWER", R.power), + ("STABLEHLO_SUBTRACT", R.subtract), ], ) def test_stablehlo_binary(builtin_name, relax_op): @@ -3838,6 +3872,98 @@ def main( tvm.ir.assert_structural_equal(mod, Expected) +@pytest.mark.parametrize( + "builtin_name, relax_op, dtype, tensor_type", + [ + ("STABLEHLO_AND", R.logical_and, "bool", _tfl_tensor_type.BOOL), + ("STABLEHLO_OR", R.logical_or, "bool", _tfl_tensor_type.BOOL), + ("STABLEHLO_AND", R.bitwise_and, "int32", _tfl_tensor_type.INT32), + ("STABLEHLO_OR", R.bitwise_or, "int32", _tfl_tensor_type.INT32), + ("STABLEHLO_SHIFT_LEFT", R.left_shift, "int32", _tfl_tensor_type.INT32), + ], +) +def test_stablehlo_typed_binary(builtin_name, relax_op, dtype, tensor_type): + """TFLite StableHLO binary elementwise operators with non-float dtype requirements.""" + mod = _load_model_from_buffer( + _build_stablehlo_typed_binary_model( + builtin_name=builtin_name, tensor_type=tensor_type + ) + ) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2), dtype=dtype), + y: R.Tensor((2, 2), dtype=dtype), + ) -> R.Tensor((2, 2), dtype=dtype): + 
R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype=dtype) = relax_op(x, y) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +@pytest.mark.parametrize( + "builtin_name, relax_op", + [ + ("STABLEHLO_SELECT", R.where), + ], +) +def test_stablehlo_ternary(builtin_name, relax_op): + """TFLite StableHLO ternary elementwise operators.""" + builder = flatbuffers.Builder(1024) + shape = [2, 2] + builtin_op = _get_stablehlo_builtin_operator(builtin_name) + + # First input (condition) must be bool for R.where + tensor_0 = _build_tensor(builder, 0, shape, tensor_type=_tfl_tensor_type.BOOL) + tensor_1 = _build_tensor(builder, 1, shape) + tensor_2 = _build_tensor(builder, 2, shape) + tensor_out = _build_tensor(builder, 3, shape) + tensors = [tensor_0, tensor_1, tensor_2, tensor_out] + + stablehlo_op = _build_operator( + builder, + 0, + [0, 1, 2], + [3], + ) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[stablehlo_op], + inputs=[0, 1, 2], + outputs=[3], + ) + operator_codes = [_build_operator_code(builder, builtin_op)] + buffers = [_build_buffer(builder) for _ in range(4)] + + mod = _load_model_from_buffer( + _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=operator_codes, buffers=buffers + ) + ) + + @I.ir_module + class Expected: + @R.function + def main( + c: R.Tensor((2, 2), dtype="bool"), + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = relax_op(c, x, y) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + def _build_csr_sparsity( builder, *, From 6de2fa811e93068a3c2d7bec0388cf2672854ea9 Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 16:25:11 +0800 Subject: [PATCH 3/6] [Relax][Frontend][TFLite] Add StableHLO options-based ops with BuiltinOptions2 support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce the first batch of StableHLO TFLite builtin operators that require BuiltinOptions2 attribute parsing: - STABLEHLO_CONVERT → R.astype (reads output dtype from tensor metadata) - STABLEHLO_CLAMP → R.minimum(R.maximum(x, min), max) (arg reordering) - STABLEHLO_CONCATENATE → R.concat with StablehloConcatenateOptions - STABLEHLO_BROADCAST_IN_DIM → R.broadcast_to with broadcast dimensions - STABLEHLO_IOTA → R.arange + R.reshape + R.broadcast_to - STABLEHLO_COMPARE → R.equal/greater/less/... with 6 comparison directions Add _get_stablehlo_options helper for parsing BuiltinOptions2 flatbuffers. R.clip was considered for CLAMP but rejected because it only accepts scalar PrimValue min/max, not tensor inputs. Test coverage: 32 structural-equal tests (20 previous + 12 new) passed. 
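As an illustration of the decompositions (a sketch of the emitted Relax, mirroring the structural-equal test expectations): STABLEHLO_CLAMP(min, x, max) lowers to R.minimum(R.maximum(x, min), max), and STABLEHLO_IOTA with iota_dimension=1 on a (2, 3) int32 output lowers to R.broadcast_to(R.reshape(R.arange(0, 3, 1, dtype="int32"), (1, 3)), (2, 3)).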
--- .../relax/frontend/tflite/tflite_frontend.py | 198 ++++++++- tests/python/relax/test_frontend_tflite.py | 395 ++++++++++++++++++ 2 files changed, 578 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index ac8af71e6324..f188c69304a9 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -246,36 +246,42 @@ def __init__(self, model, subgraph, exp_tab, ctx): self._convert_stablehlo_binary, relax_op=_op.add ), "STABLEHLO_AND": self._convert_stablehlo_and, - "STABLEHLO_DIVIDE": functools.partial( - self._convert_stablehlo_binary, relax_op=_op.divide - ), - "STABLEHLO_MAXIMUM": functools.partial( - self._convert_stablehlo_binary, relax_op=_op.maximum - ), - "STABLEHLO_MINIMUM": functools.partial( - self._convert_stablehlo_binary, relax_op=_op.minimum - ), - "STABLEHLO_MULTIPLY": functools.partial( - self._convert_stablehlo_binary, relax_op=_op.multiply - ), - "STABLEHLO_NEGATE": functools.partial( - self._convert_stablehlo_unary, relax_op=_op.negative - ), + "STABLEHLO_BROADCAST_IN_DIM": self._convert_stablehlo_broadcast_in_dim, + "STABLEHLO_CLAMP": self._convert_stablehlo_clamp, + "STABLEHLO_COMPARE": self._convert_stablehlo_compare, + "STABLEHLO_CONCATENATE": self._convert_stablehlo_concatenate, + "STABLEHLO_CONVERT": self._convert_stablehlo_convert, "STABLEHLO_COSINE": functools.partial( self._convert_stablehlo_unary, relax_op=_op.cos ), + "STABLEHLO_DIVIDE": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.divide + ), "STABLEHLO_EXPONENTIAL": functools.partial( self._convert_stablehlo_unary, relax_op=_op.exp ), "STABLEHLO_FLOOR": functools.partial( self._convert_stablehlo_unary, relax_op=_op.floor ), + "STABLEHLO_IOTA": self._convert_stablehlo_iota, "STABLEHLO_LOG": functools.partial( self._convert_stablehlo_unary, relax_op=_op.log ), "STABLEHLO_LOGISTIC": functools.partial( self._convert_stablehlo_unary, relax_op=_op.sigmoid ), + "STABLEHLO_MAXIMUM": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.maximum + ), + "STABLEHLO_MINIMUM": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.minimum + ), + "STABLEHLO_MULTIPLY": functools.partial( + self._convert_stablehlo_binary, relax_op=_op.multiply + ), + "STABLEHLO_NEGATE": functools.partial( + self._convert_stablehlo_unary, relax_op=_op.negative + ), "STABLEHLO_OR": self._convert_stablehlo_or, "STABLEHLO_POWER": functools.partial( self._convert_stablehlo_binary, relax_op=_op.power @@ -1466,6 +1472,168 @@ def _convert_stablehlo_ternary(self, op, relax_op): arg2 = self.get_tensor_expr(input_tensors[2]) return relax_op(arg0, arg1, arg2) + def _get_stablehlo_options(self, op, options_cls): + """Parse BuiltinOptions2 for a StableHLO TFLite builtin operator. + + Returns an initialized options object of the given class. + """ + from tflite.BuiltinOptions2 import BuiltinOptions2 + + op_options = op.BuiltinOptions2() + # Look up the expected BuiltinOptions2 enum value by matching the class + # name to an enum member (e.g. StablehloConcatenateOptions → 1). 
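+ # (BuiltinOptions2 is the second options union in the TFLite schema, added to carry the Stablehlo*Options tables.)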
+ options_type = getattr(BuiltinOptions2, options_cls.__name__, None) + if options_type is not None: + assert op.BuiltinOptions2Type() == options_type, ( + f"Unexpected BuiltinOptions2 type: expected " + f"{options_cls.__name__}, got {op.BuiltinOptions2Type()}" + ) + result = options_cls() + result.Init(op_options.Bytes, op_options.Pos) + return result + + def _convert_stablehlo_convert(self, op): + """Convert STABLEHLO_CONVERT to Relax (astype). + + Reads the output tensor dtype from the TFLite schema and applies + relax.op.astype. This path is intentionally separate from the + generic _convert_stablehlo_unary helper because the output dtype + is operator-level metadata, not a Relax op parameter. + """ + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1, "input tensors length should be 1" + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1, "output tensors length should be 1" + + in_expr = self.get_tensor_expr(input_tensors[0]) + output_dtype = self.get_tensor_type_str(output_tensors[0].tensor.Type()) + return self.bb.normalize(relax.op.astype(in_expr, output_dtype)) + + def _convert_stablehlo_clamp(self, op): + """Convert STABLEHLO_CLAMP to Relax. + + StableHLO clamp(min, operand, max) → R.minimum(R.maximum(operand, min), max). + """ + # NOTE: R.clip is not used here because it only accepts scalar PrimValue + # min/max, not tensor inputs. + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 3, "input tensors length should be 3" + + assert len(self.get_output_tensors(op)) == 1 + + min_expr = self.get_tensor_expr(input_tensors[0]) + operand_expr = self.get_tensor_expr(input_tensors[1]) + max_expr = self.get_tensor_expr(input_tensors[2]) + + clamped = self.bb.normalize(relax.op.maximum(operand_expr, min_expr)) + return self.bb.normalize(relax.op.minimum(clamped, max_expr)) + + def _convert_stablehlo_concatenate(self, op): + """Convert STABLEHLO_CONCATENATE to Relax.""" + from tflite.StablehloConcatenateOptions import StablehloConcatenateOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) >= 1, "input tensors length should be >= 1" + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloConcatenateOptions) + dim = opts.Dimension() + + in_exprs = [self.get_tensor_expr(t) for t in input_tensors] + return self.bb.normalize(relax.op.concat(in_exprs, axis=dim)) + + def _convert_stablehlo_broadcast_in_dim(self, op): + """Convert STABLEHLO_BROADCAST_IN_DIM to Relax.""" + from tflite.StablehloBroadcastInDimOptions import StablehloBroadcastInDimOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 1 + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1 + + opts = self._get_stablehlo_options(op, StablehloBroadcastInDimOptions) + broadcast_dims = [int(d) for d in opts.BroadcastDimensionsAsNumpy()] + + in_expr = self.get_tensor_expr(input_tensors[0]) + input_shape = [int(d) for d in self.get_tensor_shape(input_tensors[0])] + output_shape = [int(d) for d in self.get_tensor_shape(output_tensors[0])] + + # Map input dims to output dims via broadcast_dims, filling + # unmapped positions with 1 so broadcast_to covers them. 
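+ # e.g. input shape (3,) with broadcast_dimensions=[1] and output (2, 3) → intermediate shape [1, 3].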
+ intermediate_shape = [1] * len(output_shape) + for i, d in enumerate(broadcast_dims): + intermediate_shape[d] = input_shape[i] + + reshaped = self.bb.normalize(relax.op.reshape(in_expr, intermediate_shape)) + return self.bb.normalize(relax.op.broadcast_to(reshaped, output_shape)) + + def _convert_stablehlo_iota(self, op): + """Convert STABLEHLO_IOTA to Relax (arange + broadcast).""" + from tflite.StablehloIotaOptions import StablehloIotaOptions + + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1 + + opts = self._get_stablehlo_options(op, StablehloIotaOptions) + iota_dim = opts.IotaDimension() + + output_tensor = output_tensors[0] + output_shape = [int(d) for d in self.get_tensor_shape(output_tensor)] + output_dtype = self.get_tensor_type_str(output_tensor.tensor.Type()) + + # arange along the iota dimension + size = output_shape[iota_dim] + arange_1d = self.bb.normalize(relax.op.arange(0, size, 1, output_dtype)) + + # reshape to [1, ..., size, ..., 1] + broadcast_shape = [1] * len(output_shape) + broadcast_shape[iota_dim] = size + arange_reshaped = self.bb.normalize(relax.op.reshape(arange_1d, broadcast_shape)) + + # broadcast to full output shape + return self.bb.normalize(relax.op.broadcast_to(arange_reshaped, output_shape)) + + def _convert_stablehlo_compare(self, op): + """Convert STABLEHLO_COMPARE to Relax binary comparison ops.""" + from tflite.StablehloCompareOptions import StablehloCompareOptions + from tflite.StablehloComparisonDirection import StablehloComparisonDirection + from tflite.StablehloComparisonType import StablehloComparisonType + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2 + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloCompareOptions) + direction = opts.ComparisonDirection() + compare_type = opts.CompareType() + + # TOTALORDER compare is not expressible via Relax comparison ops.
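+ # (totalOrder additionally orders -0.0 before +0.0 and gives NaN a fixed position in the ordering, which R.equal and friends do not model.)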
+ if compare_type == StablehloComparisonType.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER: + raise tvm.error.OpNotImplemented( + "STABLEHLO_COMPARE with TOTALORDER comparison type is not supported" + ) + + _DIR = StablehloComparisonDirection + direction_map = { + _DIR.STABLEHLO_COMPARISON_DIRECTION_EQ: relax.op.equal, + _DIR.STABLEHLO_COMPARISON_DIRECTION_NE: relax.op.not_equal, + _DIR.STABLEHLO_COMPARISON_DIRECTION_GE: relax.op.greater_equal, + _DIR.STABLEHLO_COMPARISON_DIRECTION_GT: relax.op.greater, + _DIR.STABLEHLO_COMPARISON_DIRECTION_LE: relax.op.less_equal, + _DIR.STABLEHLO_COMPARISON_DIRECTION_LT: relax.op.less, + } + relax_fn = direction_map.get(direction) + if relax_fn is None: + raise tvm.error.OpNotImplemented( + f"Unsupported StableHLO comparison direction: {direction}" + ) + + lhs = self.get_tensor_expr(input_tensors[0]) + rhs = self.get_tensor_expr(input_tensors[1]) + return self.bb.normalize(relax_fn(lhs, rhs)) + def convert_elu(self, op): """Convert TFLite ELU""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index e5f6ba29e2c7..3d500d48f387 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3563,6 +3563,14 @@ def _get_tflite_schema_enum(enum_name): _tfl_buffer = _get_tflite_schema_module("Buffer") _tfl_conv2d_options = _get_tflite_schema_module("Conv2DOptions") _tfl_dilate_options = _get_tflite_schema_module("DilateOptions") + +# ── StableHLO BuiltinOptions2 schema modules ──────────────────────────── +_tfl_stablehlo_concat_opts = _get_tflite_schema_module("StablehloConcatenateOptions") +_tfl_stablehlo_bcast_opts = _get_tflite_schema_module("StablehloBroadcastInDimOptions") +_tfl_stablehlo_iota_opts = _get_tflite_schema_module("StablehloIotaOptions") +_tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions") +_tfl_stablehlo_comp_dir = _get_tflite_schema_module("StablehloComparisonDirection") +_tfl_stablehlo_comp_type = _get_tflite_schema_module("StablehloComparisonType") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions") _tfl_int32_vector = _get_tflite_schema_module("Int32Vector") @@ -3961,9 +3969,396 @@ def main( R.output(gv) return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + + + +def _build_stablehlo_convert_model(): + """STABLEHLO_CONVERT: float32 input -> int32 output.""" + builder = flatbuffers.Builder(1024) + shape = [2, 2] + + t_in = _build_tensor(builder, 0, shape, tensor_type=_tfl_tensor_type.FLOAT32) + t_out = _build_tensor(builder, 1, shape, tensor_type=_tfl_tensor_type.INT32) + tensors = [t_in, t_out] + + op_code = _build_operator_code( + builder, _get_stablehlo_builtin_operator("STABLEHLO_CONVERT") + ) + op = _build_operator(builder, 0, [0], [1]) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[op], + inputs=[0], + outputs=[1], + ) + buffers = [_build_buffer(builder) for _ in range(2)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_convert(): + """TFLite StableHLO CONVERT (astype float32 -> int32).""" + mod = _load_model_from_buffer(_build_stablehlo_convert_model()) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((2, 2), dtype="float32")) -> R.Tensor((2, 2), dtype="int32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: 
R.Tensor((2, 2), dtype="int32") = R.astype(x, dtype="int32") + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_clamp(): + """TFLite StableHLO CLAMP (clip with min/operand/max order).""" + mod = _load_model_from_buffer( + _build_stablehlo_model(builtin_name="STABLEHLO_CLAMP", input_count=3) + ) + + @I.ir_module + class Expected: + @R.function + def main( + m: R.Tensor((2, 2), dtype="float32"), + x: R.Tensor((2, 2), dtype="float32"), + M: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="float32"): + R.func_attr({"num_input": 3}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="float32") = R.minimum(R.maximum(x, m), M) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_stablehlo_concat_model(dimension, num_inputs): + """STABLEHLO_CONCATENATE with given dimension and number of inputs.""" + builder = flatbuffers.Builder(1024) + shape = [2, 2] + + # Build concat options + _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsStart(builder) + _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsAddDimension(builder, dimension) + concat_opts = _tfl_stablehlo_concat_opts.StablehloConcatenateOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_CONCATENATE") + op_code = _build_operator_code(builder, builtin_op) + + if dimension == 0: + out_shape = [num_inputs * shape[0], shape[1]] + else: + out_shape = [shape[0], num_inputs * shape[1]] + tensors = [ + _build_tensor(builder, i, shape) for i in range(num_inputs) + ] + [_build_tensor(builder, num_inputs, out_shape)] + + op = _build_operator( + builder, + 0, + list(range(num_inputs)), + [num_inputs], + builtin_options2_type=_tfl_builtin_options2.StablehloConcatenateOptions, + builtin_options2=concat_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[op], + inputs=list(range(num_inputs)), + outputs=[num_inputs], + ) + buffers = [_build_buffer(builder) for _ in range(num_inputs + 1)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +@pytest.mark.parametrize("dimension", [0, 1]) +def test_stablehlo_concatenate(dimension): + """TFLite StableHLO CONCATENATE with 2 inputs along given axis.""" + num_inputs = 2 + mod = _load_model_from_buffer( + _build_stablehlo_concat_model(dimension=dimension, num_inputs=num_inputs) + ) + + out_dim = (4, 2) if dimension == 0 else (2, 4) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor(out_dim, dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor(out_dim, dtype="float32") = R.concat((x, y), axis=dimension) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_stablehlo_broadcast_in_dim_model(input_shape, broadcast_dims, output_shape): + """STABLEHLO_BROADCAST_IN_DIM with given broadcast dimensions.""" + builder = flatbuffers.Builder(1024) + + # Build broadcast dimensions vector + _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsStartBroadcastDimensionsVector( + builder, len(broadcast_dims) + ) + for d in reversed(broadcast_dims): + builder.PrependInt64(d) + dims_vec = builder.EndVector() + + _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsStart(builder) + _tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsAddBroadcastDimensions( + builder, dims_vec + ) + bcast_opts = 
_tfl_stablehlo_bcast_opts.StablehloBroadcastInDimOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_BROADCAST_IN_DIM") + op_code = _build_operator_code(builder, builtin_op) + + t_in = _build_tensor(builder, 0, input_shape) + t_out = _build_tensor(builder, 1, output_shape) + tensors = [t_in, t_out] + + op = _build_operator( + builder, + 0, + [0], + [1], + builtin_options2_type=_tfl_builtin_options2.StablehloBroadcastInDimOptions, + builtin_options2=bcast_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[op], + inputs=[0], + outputs=[1], + ) + buffers = [_build_buffer(builder) for _ in range(2)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_broadcast_in_dim(): + """TFLite StableHLO BROADCAST_IN_DIM: (3,) -> (2, 3) with dims=[1].""" + mod = _load_model_from_buffer( + _build_stablehlo_broadcast_in_dim_model( + input_shape=[3], broadcast_dims=[1], output_shape=[2, 3] + ) + ) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor((3,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((2, 3), dtype="float32") = R.broadcast_to( + R.reshape(x, (1, 3)), (2, 3) + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_stablehlo_iota_model(iota_dimension, output_shape): + """STABLEHLO_IOTA with given iota dimension and output shape.""" + builder = flatbuffers.Builder(1024) + + _tfl_stablehlo_iota_opts.StablehloIotaOptionsStart(builder) + _tfl_stablehlo_iota_opts.StablehloIotaOptionsAddIotaDimension(builder, iota_dimension) + iota_opts = _tfl_stablehlo_iota_opts.StablehloIotaOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_IOTA") + op_code = _build_operator_code(builder, builtin_op) + + t_out = _build_tensor(builder, 0, output_shape, tensor_type=_tfl_tensor_type.INT32) + tensors = [t_out] + + op = _build_operator( + builder, + 0, + [], + [0], + builtin_options2_type=_tfl_builtin_options2.StablehloIotaOptions, + builtin_options2=iota_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[op], + inputs=[], + outputs=[0], + ) + buffers = [_build_buffer(builder)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_iota(): + """TFLite StableHLO IOTA: iota_dim=1, shape=(2, 3), dtype=int32.""" + mod = _load_model_from_buffer( + _build_stablehlo_iota_model(iota_dimension=1, output_shape=[2, 3]) + ) + + @I.ir_module + class Expected: + @R.function + def main() -> R.Tensor((2, 3), dtype="int32"): + R.func_attr({"num_input": 0}) + with R.dataflow(): + gv: R.Tensor((2, 3), dtype="int32") = R.broadcast_to( + R.reshape(R.arange(0, 3, 1, dtype="int32"), (1, 3)), (2, 3) + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def _build_stablehlo_compare_model(direction): + """STABLEHLO_COMPARE with given comparison direction.""" + builder = flatbuffers.Builder(1024) + + _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder) + _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection(builder, direction) + cmp_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE") + op_code = _build_operator_code(builder, builtin_op) + + shape = [2, 2] + t_lhs = _build_tensor(builder, 0, 
shape) + t_rhs = _build_tensor(builder, 1, shape) + t_out = _build_tensor(builder, 2, shape, tensor_type=_tfl_tensor_type.BOOL) + tensors = [t_lhs, t_rhs, t_out] + + op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions, + builtin_options2=cmp_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[op], + inputs=[0, 1], + outputs=[2], + ) + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +@pytest.mark.parametrize( + "direction_enum, relax_op", + [ + (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_EQ, R.equal), + (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_NE, R.not_equal), + (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GE, R.greater_equal), + (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_GT, R.greater), + (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LE, R.less_equal), + (_tfl_stablehlo_comp_dir.StablehloComparisonDirection.STABLEHLO_COMPARISON_DIRECTION_LT, R.less), + ], +) +def test_stablehlo_compare(direction_enum, relax_op): + """TFLite StableHLO COMPARE with various comparison directions.""" + mod = _load_model_from_buffer(_build_stablehlo_compare_model(direction_enum)) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((2, 2), dtype="float32"), + y: R.Tensor((2, 2), dtype="float32"), + ) -> R.Tensor((2, 2), dtype="bool"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + gv: R.Tensor((2, 2), dtype="bool") = relax_op(x, y) + R.output(gv) + return gv + tvm.ir.assert_structural_equal(mod, Expected) +def test_stablehlo_compare_totalorder_unsupported(): + """STABLEHLO_COMPARE with TOTALORDER type raises OpNotImplemented.""" + builder = flatbuffers.Builder(1024) + + _DIR = _tfl_stablehlo_comp_dir.StablehloComparisonDirection + _TYPE = _tfl_stablehlo_comp_type.StablehloComparisonType + + _tfl_stablehlo_compare_opts.StablehloCompareOptionsStart(builder) + _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddComparisonDirection( + builder, _DIR.STABLEHLO_COMPARISON_DIRECTION_EQ + ) + _tfl_stablehlo_compare_opts.StablehloCompareOptionsAddCompareType( + builder, _TYPE.STABLEHLO_COMPARISON_TYPE_FLOAT_TOTAL_ORDER + ) + cmp_opts = _tfl_stablehlo_compare_opts.StablehloCompareOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_COMPARE") + op_code = _build_operator_code(builder, builtin_op) + + shape = [2, 2] + t_lhs = _build_tensor(builder, 0, shape) + t_rhs = _build_tensor(builder, 1, shape) + t_out = _build_tensor(builder, 2, shape, tensor_type=_tfl_tensor_type.BOOL) + tensors = [t_lhs, t_rhs, t_out] + + op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloCompareOptions, + builtin_options2=cmp_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=tensors, + operators=[op], + inputs=[0, 1], + outputs=[2], + ) + buffers = [_build_buffer(builder) for _ in range(3)] + buf = _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, 
match="TOTALORDER"): + from_tflite(tflite_model) + + + + def _build_csr_sparsity( builder, *, From 7eb6ed3025585a938df311eac7e70d6262b5db12 Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 17:15:42 +0800 Subject: [PATCH 4/6] [Relax][Frontend][TFLite] Add StableHLO PAD and DYNAMIC_SLICE support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add frontend mapping for two StableHLO TFLite builtin operators that manipulate tensor shapes: - STABLEHLO_PAD → R.nn.pad with constant mode. Parses EdgePaddingLow, EdgePaddingHigh, and InteriorPadding from StablehloPadOptions. Raises OpNotImplemented when interior (dilation) padding is non-zero. - STABLEHLO_DYNAMIC_SLICE → R.dynamic_strided_slice. Reads SliceSizes from StablehloDynamicSliceOptions and start indices from scalar tensor inputs. Begin/end/strides are constructed as int64 1D tensors. Both ops extend the BuiltinOptions2 parsing infrastructure introduced in the previous commit, adding vector-attribute (PAD) and dynamic-input (DYNAMIC_SLICE) patterns. Test coverage: 33 structural-equal tests passed (31 previous + 2 new). --- .../relax/frontend/tflite/tflite_frontend.py | 105 ++++++ tests/python/relax/test_frontend_tflite.py | 342 ++++++++++++++++++ 2 files changed, 447 insertions(+) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index f188c69304a9..d49058abccf0 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -257,6 +257,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_DIVIDE": functools.partial( self._convert_stablehlo_binary, relax_op=_op.divide ), + "STABLEHLO_DYNAMIC_SLICE": self._convert_stablehlo_dynamic_slice, "STABLEHLO_EXPONENTIAL": functools.partial( self._convert_stablehlo_unary, relax_op=_op.exp ), @@ -283,6 +284,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): self._convert_stablehlo_unary, relax_op=_op.negative ), "STABLEHLO_OR": self._convert_stablehlo_or, + "STABLEHLO_PAD": self._convert_stablehlo_pad, "STABLEHLO_POWER": functools.partial( self._convert_stablehlo_binary, relax_op=_op.power ), @@ -1634,6 +1636,109 @@ def _convert_stablehlo_compare(self, op): rhs = self.get_tensor_expr(input_tensors[1]) return self.bb.normalize(relax_fn(lhs, rhs)) + def _convert_stablehlo_pad(self, op): + """Convert STABLEHLO_PAD to Relax (nn.pad). + + Maps edge padding to R.nn.pad with constant mode. Interior padding + (dilation) is not supported in the first version. + """ + from tflite.StablehloPadOptions import StablehloPadOptions + + input_tensors = self.get_input_tensors(op) + # operand + padding_value + assert len(input_tensors) == 2, "input tensors length should be 2" + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloPadOptions) + edge_low = [int(d) for d in opts.EdgePaddingLowAsNumpy()] + edge_high = [int(d) for d in opts.EdgePaddingHighAsNumpy()] + interior = [int(d) for d in opts.InteriorPaddingAsNumpy()] + + if any(d != 0 for d in interior): + raise tvm.error.OpNotImplemented( + "STABLEHLO_PAD with interior (dilation) padding is not supported" + ) + if any(d < 0 for d in edge_low) or any(d < 0 for d in edge_high): + raise tvm.error.OpNotImplemented( + "STABLEHLO_PAD with negative edge padding (crop) is not supported" + ) + + operand = self.get_tensor_expr(input_tensors[0]) + + # R.nn.pad only supports a static Python float pad_value. 
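+ # e.g. a constant float32 scalar buffer holding 0.0 folds into pad_value=0.0; a runtime (non-constant) pad value is rejected below.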
+ pad_value_tensor = input_tensors[1] + if not self.has_expr(pad_value_tensor.tensor_idx): + pad_val = float(self.get_tensor_value(pad_value_tensor)) + else: + raise tvm.error.OpNotImplemented( + "STABLEHLO_PAD with dynamic padding value is not supported" + ) + + # R.nn.pad with flat pad_width: [lo0, hi0, lo1, hi1, ...] + pad_width = [] + for lo, hi in zip(edge_low, edge_high): + pad_width.extend([lo, hi]) + + return self.bb.normalize( + relax.op.nn.pad(operand, pad_width=pad_width, pad_value=pad_val) + ) + + def _convert_stablehlo_dynamic_slice(self, op): + """Convert STABLEHLO_DYNAMIC_SLICE to Relax (dynamic_strided_slice). + + Start indices are assumed to be constant (non-dynamic) values stored + in the flatbuffer. Truly dynamic (runtime) start indices require + Relax arithmetic to compute begin/end from scalar inputs and are not + yet supported. + """ + from tflite.StablehloDynamicSliceOptions import StablehloDynamicSliceOptions + + input_tensors = self.get_input_tensors(op) + # operand + N start-index scalars + assert len(input_tensors) >= 2 + ndim = len(input_tensors) - 1 + assert len(self.get_output_tensors(op)) == 1 + + opts = self._get_stablehlo_options(op, StablehloDynamicSliceOptions) + slice_sizes = [int(d) for d in opts.SliceSizesAsNumpy()] + assert len(slice_sizes) == ndim + + operand = self.get_tensor_expr(input_tensors[0]) + + # Build constant 1D tensors for begin, end, strides + # (assumes start values are constant in the flatbuffer) + # TODO: support dynamic start indices via Relax arithmetic + if any(self.has_expr(t.tensor_idx) for t in input_tensors[1:]): + raise tvm.error.OpNotImplemented( + "STABLEHLO_DYNAMIC_SLICE with dynamic start indices is not supported" + ) + start_vals = [int(self.get_tensor_value(t)) for t in input_tensors[1:]] + operand_shape = [int(d) for d in self.get_tensor_shape(input_tensors[0])] + for start, size, dim in zip(start_vals, slice_sizes, operand_shape): + if start < 0 or start + size > dim: + raise tvm.error.OpNotImplemented( + "STABLEHLO_DYNAMIC_SLICE with out-of-bounds start indices is not supported" + ) + end_vals = [s + sz for s, sz in zip(start_vals, slice_sizes)] + stride_vals = [1] * ndim + + def _const_1d(values, dtype="int64"): + arr = np.array(values, dtype=dtype) + return self.bb.normalize( + relax.op.reshape( + relax.const(arr, dtype=dtype), + [len(values)], + ) + ) + + begin = _const_1d(start_vals) + end = _const_1d(end_vals) + strides = _const_1d(stride_vals) + + return self.bb.normalize( + relax.op.dynamic_strided_slice(operand, begin, end, strides) + ) + def convert_elu(self, op): """Convert TFLite ELU""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 3d500d48f387..42e2c3f02412 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3571,6 +3571,8 @@ def _get_tflite_schema_enum(enum_name): _tfl_stablehlo_compare_opts = _get_tflite_schema_module("StablehloCompareOptions") _tfl_stablehlo_comp_dir = _get_tflite_schema_module("StablehloComparisonDirection") _tfl_stablehlo_comp_type = _get_tflite_schema_module("StablehloComparisonType") +_tfl_stablehlo_pad_opts = _get_tflite_schema_module("StablehloPadOptions") +_tfl_stablehlo_dyn_slice_opts = _get_tflite_schema_module("StablehloDynamicSliceOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions") _tfl_int32_vector = 
_get_tflite_schema_module("Int32Vector") @@ -4359,6 +4361,346 @@ def test_stablehlo_compare_totalorder_unsupported(): + + +def _pad_vector(builder, values): + """Build a FlatBuffers int64 vector for pad options.""" + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector( + builder, len(values) + ) + for v in reversed(values): + builder.PrependInt64(v) + return builder.EndVector() + + +def _build_stablehlo_pad_model(edge_low, edge_high, interior): + """STABLEHLO_PAD with given padding vectors.""" + builder = flatbuffers.Builder(1024) + + # Build EdgePaddingLow vector + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector(builder, len(edge_low)) + for v in reversed(edge_low): + builder.PrependInt64(v) + lo_vec = builder.EndVector() + + # Build EdgePaddingHigh vector + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector(builder, len(edge_high)) + for v in reversed(edge_high): + builder.PrependInt64(v) + hi_vec = builder.EndVector() + + # Build InteriorPadding vector + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector(builder, len(interior)) + for v in reversed(interior): + builder.PrependInt64(v) + int_vec = builder.EndVector() + + _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec) + pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD") + op_code = _build_operator_code(builder, builtin_op) + + t_in = _build_tensor(builder, 0, [3, 3]) + # pad_value is a scalar tensor + t_pad_val = _build_tensor(builder, 1, []) + t_out = _build_tensor(builder, 2, [4, 4]) + tensors = [t_in, t_pad_val, t_out] + + op = _build_operator( + builder, 0, [0, 1], [2], + builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, + builtin_options2=pad_opts, + ) + subgraph = _build_subgraph( + builder, tensors=tensors, operators=[op], + inputs=[0], outputs=[2], + ) + buffers = [ + _build_buffer(builder), + _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()), + _build_buffer(builder), + ] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_pad(): + """TFLite StableHLO PAD: edge_low=[1,0], edge_high=[0,1], interior=[0,0].""" + mod = _load_model_from_buffer( + _build_stablehlo_pad_model(edge_low=[1, 0], edge_high=[0, 1], interior=[0, 0]) + ) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((3, 3), dtype="float32"), + ) -> R.Tensor((4, 4), dtype="float32"): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor((4, 4), dtype="float32") = R.nn.pad( + x, pad_width=[1, 0, 0, 1], pad_value=0.0 + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_pad_interior_unsupported(): + """STABLEHLO_PAD with interior padding raises OpNotImplemented.""" + builder = flatbuffers.Builder(1024) + + lo_vec = _pad_vector(builder, [0, 0]) + hi_vec = _pad_vector(builder, [0, 0]) + int_vec = _pad_vector(builder, [1, 0]) + + _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec) + 
_tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec) + pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD") + op_code = _build_operator_code(builder, builtin_op) + + t_in = _build_tensor(builder, 0, [3, 3]) + t_pv = _build_tensor(builder, 1, []) + t_out = _build_tensor(builder, 2, [3, 3]) + tensors = [t_in, t_pv, t_out] + + op = _build_operator( + builder, 0, [0, 1], [2], + builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, + builtin_options2=pad_opts, + ) + subgraph = _build_subgraph( + builder, tensors=tensors, operators=[op], + inputs=[0], outputs=[2], + ) + buffers = [ + _build_buffer(builder), + _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()), + _build_buffer(builder), + ] + buf = _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + with pytest.raises(tvm.error.OpNotImplemented, match="interior"): + from_tflite(tflite_model) + + +def test_stablehlo_pad_negative_unsupported(): + """STABLEHLO_PAD with negative edge padding raises OpNotImplemented.""" + builder = flatbuffers.Builder(1024) + + lo_vec = _pad_vector(builder, [-1, 0]) + hi_vec = _pad_vector(builder, [0, 0]) + int_vec = _pad_vector(builder, [0, 0]) + + _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingHigh(builder, hi_vec) + _tfl_stablehlo_pad_opts.StablehloPadOptionsAddInteriorPadding(builder, int_vec) + pad_opts = _tfl_stablehlo_pad_opts.StablehloPadOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_PAD") + op_code = _build_operator_code(builder, builtin_op) + + t_in = _build_tensor(builder, 0, [3, 3]) + t_pv = _build_tensor(builder, 1, []) + t_out = _build_tensor(builder, 2, [2, 3]) + tensors = [t_in, t_pv, t_out] + + op = _build_operator( + builder, 0, [0, 1], [2], + builtin_options2_type=_tfl_builtin_options2.StablehloPadOptions, + builtin_options2=pad_opts, + ) + subgraph = _build_subgraph( + builder, tensors=tensors, operators=[op], + inputs=[0], outputs=[2], + ) + buffers = [ + _build_buffer(builder), + _build_buffer(builder, np.array([0.0], dtype=np.float32).tobytes()), + _build_buffer(builder), + ] + buf = _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + with pytest.raises(tvm.error.OpNotImplemented, match="negative"): + from_tflite(tflite_model) + + +def _build_stablehlo_dynamic_slice_model(slice_sizes, start_vals): + """STABLEHLO_DYNAMIC_SLICE with given slice sizes and start indices.""" + builder = flatbuffers.Builder(1024) + ndim = len(slice_sizes) + + # Build SliceSizes vector + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector( + builder, ndim + ) + for v in reversed(slice_sizes): + builder.PrependInt64(v) + sizes_vec = builder.EndVector() + + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder) + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes( + builder, sizes_vec + ) + dyn_opts = 
_tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE") + op_code = _build_operator_code(builder, builtin_op) + + # operand + start indices + output + t_in = _build_tensor(builder, 0, [3, 3]) + start_tensors = [] + start_inputs = [] + start_buffers = [] + for i, sv in enumerate(start_vals): + bidx = 1 + i + start_tensors.append( + _build_tensor(builder, bidx, [], tensor_type=_tfl_tensor_type.INT32) + ) + start_inputs.append(bidx) + start_buffers.append( + _build_buffer(builder, np.array([sv], dtype=np.int32).tobytes()) + ) + out_idx = 1 + ndim + t_out = _build_tensor(builder, out_idx, slice_sizes) + tensors = [t_in, *start_tensors, t_out] + op_inputs = [0, *start_inputs] + + op = _build_operator( + builder, 0, op_inputs, [out_idx], + builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions, + builtin_options2=dyn_opts, + ) + subgraph = _build_subgraph( + builder, tensors=tensors, operators=[op], + inputs=[0], outputs=[out_idx], + ) + buffers = [_build_buffer(builder), *start_buffers, _build_buffer(builder)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes): + """STABLEHLO_DYNAMIC_SLICE with runtime start-index inputs.""" + builder = flatbuffers.Builder(1024) + ndim = len(slice_sizes) + + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStartSliceSizesVector( + builder, ndim + ) + for v in reversed(slice_sizes): + builder.PrependInt64(v) + sizes_vec = builder.EndVector() + + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsStart(builder) + _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsAddSliceSizes( + builder, sizes_vec + ) + dyn_opts = _tfl_stablehlo_dyn_slice_opts.StablehloDynamicSliceOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_DYNAMIC_SLICE") + op_code = _build_operator_code(builder, builtin_op) + + t_in = _build_tensor(builder, 0, [3, 3]) + start_tensors = [ + _build_tensor(builder, 1 + i, [], tensor_type=_tfl_tensor_type.INT32) + for i in range(ndim) + ] + out_idx = 1 + ndim + t_out = _build_tensor(builder, out_idx, slice_sizes) + start_inputs = list(range(1, 1 + ndim)) + tensors = [t_in, *start_tensors, t_out] + op_inputs = [0, *start_inputs] + + op = _build_operator( + builder, 0, op_inputs, [out_idx], + builtin_options2_type=_tfl_builtin_options2.StablehloDynamicSliceOptions, + builtin_options2=dyn_opts, + ) + subgraph = _build_subgraph( + builder, tensors=tensors, operators=[op], + inputs=op_inputs, outputs=[out_idx], + ) + buffers = [_build_buffer(builder) for _ in range(out_idx + 1)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +def test_stablehlo_dynamic_slice(): + """TFLite StableHLO DYNAMIC_SLICE: start=[0,1], sizes=[2,2] from (3,3).""" + mod = _load_model_from_buffer( + _build_stablehlo_dynamic_slice_model( + slice_sizes=[2, 2], start_vals=[0, 1] + ) + ) + + @I.ir_module + class Expected: + @R.function + def main( + x: R.Tensor((3, 3), dtype="float32"), + ) -> R.Tensor(dtype="float32", ndim=2): + R.func_attr({"num_input": 1}) + with R.dataflow(): + gv: R.Tensor(dtype="float32", ndim=2) = R.dynamic_strided_slice( + x, + R.reshape(R.const([0, 1], dtype="int64"), [2]), + R.reshape(R.const([2, 3], dtype="int64"), [2]), + R.reshape(R.const([1, 1], dtype="int64"), [2]), + ) + R.output(gv) + return gv + + 
tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_dynamic_slice_dynamic_starts_unsupported(): + """TFLite StableHLO DYNAMIC_SLICE with runtime starts is not supported yet.""" + buf = _build_stablehlo_dynamic_slice_with_dynamic_starts_model(slice_sizes=[2, 2]) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="dynamic start"): + from_tflite(tflite_model) + + +def test_stablehlo_dynamic_slice_out_of_bounds_unsupported(): + """TFLite StableHLO DYNAMIC_SLICE with out-of-bounds starts is not supported.""" + buf = _build_stablehlo_dynamic_slice_model(slice_sizes=[2, 2], start_vals=[0, 2]) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="out-of-bounds"): + from_tflite(tflite_model) + + def _build_csr_sparsity( builder, *, From cbb0dbe2d982f3106aa1cfcb57b83d4e4d0d070d Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 17:53:09 +0800 Subject: [PATCH 5/6] [Relax][Frontend][TFLite] Add StableHLO GATHER support (take-equivalent subset) Add frontend mapping for STABLEHLO_GATHER with a conservative take-equivalent implementation: - Parses 6 attributes from StablehloGatherOptions (OffsetDims, CollapsedSliceDims, StartIndexMap, IndexVectorDim, SliceSizes, IndicesAreSorted) - Only supports single-axis gather with index vector dim == rank(indices)-1 and slice_sizes matching R.take semantics - Validates offset_dims layout, output shape, and collapsed dims against expected R.take behavior; raises OpNotImplemented otherwise - Reshapes indices from [N, 1] to [N] before calling R.take Tests: 3 new (2 take-equivalent parametrized for axis 0/1, 1 error path for multi-dimensional start_index_map). Total: 38 stablehlo tests passed. --- .../relax/frontend/tflite/tflite_frontend.py | 81 ++++++++++ tests/python/relax/test_frontend_tflite.py | 152 ++++++++++++++++++ 2 files changed, 233 insertions(+) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index d49058abccf0..793061e0e3c4 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -264,6 +264,7 @@ def __init__(self, model, subgraph, exp_tab, ctx): "STABLEHLO_FLOOR": functools.partial( self._convert_stablehlo_unary, relax_op=_op.floor ), + "STABLEHLO_GATHER": self._convert_stablehlo_gather, "STABLEHLO_IOTA": self._convert_stablehlo_iota, "STABLEHLO_LOG": functools.partial( self._convert_stablehlo_unary, relax_op=_op.log @@ -1739,6 +1740,86 @@ def _const_1d(values, dtype="int64"): relax.op.dynamic_strided_slice(operand, begin, end, strides) ) + + def _convert_stablehlo_gather(self, op): + """Convert STABLEHLO_GATHER to Relax (take-equivalent subset only). + + Only handles gather patterns equivalent to R.take along a single axis. + Multi-dimensional gathers, index_vector_dim != rank(indices)-1, and + non-trivial slice_sizes raise OpNotImplemented. 
+ """ + from tflite.StablehloGatherOptions import StablehloGatherOptions + + input_tensors = self.get_input_tensors(op) + assert len(input_tensors) == 2, "input tensors length should be 2" + output_tensors = self.get_output_tensors(op) + assert len(output_tensors) == 1 + + opts = self._get_stablehlo_options(op, StablehloGatherOptions) + offset_dims = [int(d) for d in opts.OffsetDimsAsNumpy()] + collapsed_slice_dims = [int(d) for d in opts.CollapsedSliceDimsAsNumpy()] + start_index_map = [int(d) for d in opts.StartIndexMapAsNumpy()] + slice_sizes = [int(d) for d in opts.SliceSizesAsNumpy()] + index_vector_dim = int(opts.IndexVectorDim()) + + data_tensor, indices_tensor = input_tensors + data_shape = [int(d) for d in self.get_tensor_shape(data_tensor)] + indices_shape = [int(d) for d in self.get_tensor_shape(indices_tensor)] + output_shape = [int(d) for d in self.get_tensor_shape(output_tensors[0])] + + if len(start_index_map) != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER only supports one start_index_map entry" + ) + axis = start_index_map[0] + if axis < 0 or axis >= len(data_shape): + raise tvm.error.OpNotImplemented(f"Unsupported STABLEHLO_GATHER axis: {axis}") + if collapsed_slice_dims != [axis]: + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER only supports collapsed_slice_dims matching the gather axis" + ) + if len(slice_sizes) != len(data_shape): + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER slice_sizes must match operand rank" + ) + for i, (size, dim) in enumerate(zip(slice_sizes, data_shape)): + expected = 1 if i == axis else dim + if size != expected: + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER only supports take-equivalent slice_sizes" + ) + if index_vector_dim != len(indices_shape) - 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER only supports trailing index_vector_dim" + ) + if not indices_shape or indices_shape[index_vector_dim] != 1: + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER only supports index vector size 1" + ) + + indices_batch_shape = indices_shape[:index_vector_dim] + expected_offset_dims = list(range(axis)) + list( + range(axis + len(indices_batch_shape), len(data_shape) + len(indices_batch_shape) - 1) + ) + if offset_dims != expected_offset_dims: + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER offset_dims do not match Relax take output layout" + ) + + expected_output_shape = ( + data_shape[:axis] + indices_batch_shape + data_shape[axis + 1 :] + ) + if output_shape != expected_output_shape: + raise tvm.error.OpNotImplemented( + "STABLEHLO_GATHER output shape does not match Relax take semantics" + ) + + data = self.get_tensor_expr(data_tensor) + indices = self.get_tensor_expr(indices_tensor) + indices = self.bb.normalize(relax.op.reshape(indices, indices_batch_shape)) + return self.bb.normalize(relax.op.take(data, indices, axis=axis, mode="fast")) + + def convert_elu(self, op): """Convert TFLite ELU""" input_tensors = self.get_input_tensors(op) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 42e2c3f02412..7a7ea09c1056 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -3573,6 +3573,7 @@ def _get_tflite_schema_enum(enum_name): _tfl_stablehlo_comp_type = _get_tflite_schema_module("StablehloComparisonType") _tfl_stablehlo_pad_opts = _get_tflite_schema_module("StablehloPadOptions") _tfl_stablehlo_dyn_slice_opts = _get_tflite_schema_module("StablehloDynamicSliceOptions") 
+_tfl_stablehlo_gather_opts = _get_tflite_schema_module("StablehloGatherOptions") _tfl_dimension_metadata = _get_tflite_schema_module("DimensionMetadata") _tfl_fully_connected_options = _get_tflite_schema_module("FullyConnectedOptions") _tfl_int32_vector = _get_tflite_schema_module("Int32Vector") @@ -4359,6 +4360,157 @@ def test_stablehlo_compare_totalorder_unsupported(): from_tflite(tflite_model) +def _stablehlo_gather_i64_vector(builder, start_vector_fn, values): + start_vector_fn(builder, len(values)) + for value in reversed(values): + builder.PrependInt64(value) + return builder.EndVector() + + +def _build_stablehlo_gather_model( + *, + data_shape, + indices_shape, + output_shape, + offset_dims, + collapsed_slice_dims, + start_index_map, + index_vector_dim, + slice_sizes, +): + """Build a minimal STABLEHLO_GATHER TFLite model.""" + builder = flatbuffers.Builder(1024) + + offset_dims_vec = _stablehlo_gather_i64_vector( + builder, + _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartOffsetDimsVector, + offset_dims, + ) + collapsed_slice_dims_vec = _stablehlo_gather_i64_vector( + builder, + _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartCollapsedSliceDimsVector, + collapsed_slice_dims, + ) + start_index_map_vec = _stablehlo_gather_i64_vector( + builder, + _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartStartIndexMapVector, + start_index_map, + ) + slice_sizes_vec = _stablehlo_gather_i64_vector( + builder, + _tfl_stablehlo_gather_opts.StablehloGatherOptionsStartSliceSizesVector, + slice_sizes, + ) + + _tfl_stablehlo_gather_opts.StablehloGatherOptionsStart(builder) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddOffsetDims( + builder, offset_dims_vec + ) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddCollapsedSliceDims( + builder, collapsed_slice_dims_vec + ) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddStartIndexMap( + builder, start_index_map_vec + ) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddIndexVectorDim( + builder, index_vector_dim + ) + _tfl_stablehlo_gather_opts.StablehloGatherOptionsAddSliceSizes( + builder, slice_sizes_vec + ) + gather_opts = _tfl_stablehlo_gather_opts.StablehloGatherOptionsEnd(builder) + + builtin_op = _get_stablehlo_builtin_operator("STABLEHLO_GATHER") + op_code = _build_operator_code(builder, builtin_op) + + t_data = _build_tensor(builder, 0, data_shape) + t_indices = _build_tensor(builder, 1, indices_shape, tensor_type=_tfl_tensor_type.INT32) + t_out = _build_tensor(builder, 2, output_shape) + op = _build_operator( + builder, + 0, + [0, 1], + [2], + builtin_options2_type=_tfl_builtin_options2.StablehloGatherOptions, + builtin_options2=gather_opts, + ) + subgraph = _build_subgraph( + builder, + tensors=[t_data, t_indices, t_out], + operators=[op], + inputs=[0, 1], + outputs=[2], + ) + buffers = [_build_buffer(builder) for _ in range(3)] + return _finish_tflite_model( + builder, subgraph=subgraph, operator_codes=[op_code], buffers=buffers + ) + + +@pytest.mark.parametrize( + "axis, offset_dims, slice_sizes, output_shape", + [ + (0, [1], [1, 4], [2, 4]), + (1, [0], [3, 1], [3, 2]), + ], +) +def test_stablehlo_gather_take_equivalent(axis, offset_dims, slice_sizes, output_shape): + """TFLite StableHLO GATHER take-equivalent subset.""" + mod = _load_model_from_buffer( + _build_stablehlo_gather_model( + data_shape=[3, 4], + indices_shape=[2, 1], + output_shape=output_shape, + offset_dims=offset_dims, + collapsed_slice_dims=[axis], + start_index_map=[axis], + index_vector_dim=1, + slice_sizes=slice_sizes, + ) + ) + + 
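# The reversed PrependInt64 loop in _stablehlo_gather_i64_vector above is the
# standard flatbuffers idiom: vectors are written back to front.  A standalone
# sketch of the same pattern, assuming only the flatbuffers package
# (EndVector() takes no argument on flatbuffers >= 2.0, matching the helpers
# in this file):
import flatbuffers

fb = flatbuffers.Builder(64)
fb.StartVector(8, 3, 8)        # elem_size=8 bytes, 3 elements, 8-byte alignment
for v in reversed([1, 2, 3]):  # prepend in reverse so the vector reads in order
    fb.PrependInt64(v)
vec_offset = fb.EndVector()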
out_shape = tuple(output_shape) + + @I.ir_module + class Expected: + @R.function + def main( + data: R.Tensor((3, 4), dtype="float32"), + indices: R.Tensor((2, 1), dtype="int32"), + ) -> R.Tensor(out_shape, dtype="float32"): + R.func_attr({"num_input": 2}) + with R.dataflow(): + reshaped: R.Tensor((2,), dtype="int32") = R.reshape(indices, (2,)) + gv: R.Tensor(out_shape, dtype="float32") = R.take( + data, reshaped, axis=axis, mode="fast" + ) + R.output(gv) + return gv + + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_stablehlo_gather_complex_unsupported(): + """TFLite StableHLO GATHER with multi-dimensional start_index_map is unsupported.""" + buf = _build_stablehlo_gather_model( + data_shape=[3, 4], + indices_shape=[2, 2], + output_shape=[2], + offset_dims=[], + collapsed_slice_dims=[0, 1], + start_index_map=[0, 1], + index_vector_dim=1, + slice_sizes=[1, 1], + ) + if hasattr(tflite.Model, "Model"): + tflite_model = tflite.Model.Model.GetRootAsModel(buf, 0) + else: + tflite_model = tflite.Model.GetRootAsModel(buf, 0) + + with pytest.raises(tvm.error.OpNotImplemented, match="start_index_map"): + from_tflite(tflite_model) + + From 3060968dfb36270f953eda46df56f39fdaf5cef9 Mon Sep 17 00:00:00 2001 From: HoYi Date: Mon, 11 May 2026 20:16:34 +0800 Subject: [PATCH 6/6] [Relax][Frontend][TFLite] Fix StableHLO PAD vectors and slice constants --- .../relax/frontend/tflite/tflite_frontend.py | 7 +- tests/python/relax/test_frontend_tflite.py | 85 +++++++++++-------- 2 files changed, 51 insertions(+), 41 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 793061e0e3c4..0157f129be82 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -1725,12 +1725,7 @@ def _convert_stablehlo_dynamic_slice(self, op): def _const_1d(values, dtype="int64"): arr = np.array(values, dtype=dtype) - return self.bb.normalize( - relax.op.reshape( - relax.const(arr, dtype=dtype), - [len(values)], - ) - ) + return self.bb.normalize(relax.const(arr, dtype=dtype)) begin = _const_1d(start_vals) end = _const_1d(end_vals) diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 7a7ea09c1056..39e496183476 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -4510,16 +4510,9 @@ def test_stablehlo_gather_complex_unsupported(): with pytest.raises(tvm.error.OpNotImplemented, match="start_index_map"): from_tflite(tflite_model) - - - - - -def _pad_vector(builder, values): +def _pad_vector(builder, start_vector_fn, values): """Build a FlatBuffers int64 vector for pad options.""" - _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector( - builder, len(values) - ) + start_vector_fn(builder, len(values)) for v in reversed(values): builder.PrependInt64(v) return builder.EndVector() @@ -4529,23 +4522,21 @@ def _build_stablehlo_pad_model(edge_low, edge_high, interior): """STABLEHLO_PAD with given padding vectors.""" builder = flatbuffers.Builder(1024) - # Build EdgePaddingLow vector - _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector(builder, len(edge_low)) - for v in reversed(edge_low): - builder.PrependInt64(v) - lo_vec = builder.EndVector() - - # Build EdgePaddingHigh vector - _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector(builder, len(edge_high)) - for v in reversed(edge_high): - builder.PrependInt64(v) - hi_vec = 
builder.EndVector() - - # Build InteriorPadding vector - _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector(builder, len(interior)) - for v in reversed(interior): - builder.PrependInt64(v) - int_vec = builder.EndVector() + lo_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector, + edge_low, + ) + hi_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector, + edge_high, + ) + int_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector, + interior, + ) _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) @@ -4608,9 +4599,21 @@ def test_stablehlo_pad_interior_unsupported(): """STABLEHLO_PAD with interior padding raises OpNotImplemented.""" builder = flatbuffers.Builder(1024) - lo_vec = _pad_vector(builder, [0, 0]) - hi_vec = _pad_vector(builder, [0, 0]) - int_vec = _pad_vector(builder, [1, 0]) + lo_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector, + [0, 0], + ) + hi_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector, + [0, 0], + ) + int_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector, + [1, 0], + ) _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) @@ -4655,9 +4658,21 @@ def test_stablehlo_pad_negative_unsupported(): """STABLEHLO_PAD with negative edge padding raises OpNotImplemented.""" builder = flatbuffers.Builder(1024) - lo_vec = _pad_vector(builder, [-1, 0]) - hi_vec = _pad_vector(builder, [0, 0]) - int_vec = _pad_vector(builder, [0, 0]) + lo_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingLowVector, + [-1, 0], + ) + hi_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartEdgePaddingHighVector, + [0, 0], + ) + int_vec = _pad_vector( + builder, + _tfl_stablehlo_pad_opts.StablehloPadOptionsStartInteriorPaddingVector, + [0, 0], + ) _tfl_stablehlo_pad_opts.StablehloPadOptionsStart(builder) _tfl_stablehlo_pad_opts.StablehloPadOptionsAddEdgePaddingLow(builder, lo_vec) @@ -4819,9 +4834,9 @@ def main( with R.dataflow(): gv: R.Tensor(dtype="float32", ndim=2) = R.dynamic_strided_slice( x, - R.reshape(R.const([0, 1], dtype="int64"), [2]), - R.reshape(R.const([2, 3], dtype="int64"), [2]), - R.reshape(R.const([1, 1], dtype="int64"), [2]), + R.const([0, 1], dtype="int64"), + R.const([2, 3], dtype="int64"), + R.const([1, 1], dtype="int64"), ) R.output(gv) return gv
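The _const_1d simplification in this last patch works because relax.const on a
1-D NumPy array already yields a rank-1 constant, so the previous R.reshape
wrapper was a no-op. A quick check, assuming a standard TVM build with the
Relax Python API:

import numpy as np
from tvm import relax

c = relax.const(np.array([0, 1], dtype="int64"))
assert c.struct_info.ndim == 1  # already rank 1; no reshape needed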
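For reference on the two STABLEHLO_PAD error paths updated above: interior
padding inserts the pad value between neighbouring elements, and negative edge
padding removes elements, which is why both are rejected rather than mapped to
a plain pad. A minimal NumPy sketch of one-dimensional stablehlo.pad semantics
for non-negative padding; stablehlo_pad_1d is a hypothetical name, not
frontend code:

import numpy as np

def stablehlo_pad_1d(x, low, high, interior, value=0.0):
    # Assumes low, high, interior >= 0; the frontend additionally requires interior == 0.
    n = len(x)
    out = np.full(low + n + (n - 1) * interior + high, value, dtype=x.dtype)
    out[low : low + (n - 1) * (interior + 1) + 1 : interior + 1] = x
    return out

# interior=1 interleaves the pad value between elements:
assert stablehlo_pad_1d(np.array([1.0, 2.0, 3.0]), 1, 0, 1).tolist() == [0.0, 1.0, 0.0, 2.0, 0.0, 3.0]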