[Relax][Frontend][TFLite] Support STABLEHLO_WHILE#19646
Conversation
Add Relax TFLite frontend support for STABLEHLO_WHILE by parsing StablehloWhileOptions from BuiltinOptions2 and reusing the existing TFLite while lowering path. Refactor the native WHILE converter through a shared _convert_while_like helper so both native WHILE and STABLEHLO_WHILE validate cond/body subgraph boundaries, loop-carried tensor metadata, scalar bool conditions, and output arity consistently. Include the function prefix in the cached recursive loop key so native and StableHLO while functions cannot collide when they reference the same subgraph indices. Add manually-built StableHLO while flatbuffer tests covering the recursive Relax private function lowering plus invalid cond/body index, non-bool condition, subgraph output count mismatch, input metadata mismatch, and body output metadata mismatch guards.
There was a problem hiding this comment.
Code Review
This pull request introduces support for the STABLEHLO_WHILE operator in the TVM Relax TFLite frontend by refactoring the existing WHILE lowering logic into a reusable helper function. It also adds comprehensive unit tests to validate the lowering process and error handling. The review feedback suggests adding a robustness check in _convert_stablehlo_while to handle cases where the operator options might be parsed as None, preventing potential attribute errors.
| def _convert_stablehlo_while(self, op): | ||
| """Convert STABLEHLO_WHILE to a recursive Relax private function.""" | ||
| from tflite.StablehloWhileOptions import StablehloWhileOptions | ||
|
|
||
| opts = self._get_stablehlo_options(op, StablehloWhileOptions) | ||
| return self._convert_while_like( | ||
| op, | ||
| "STABLEHLO_WHILE", | ||
| int(opts.CondSubgraphIndex()), | ||
| int(opts.BodySubgraphIndex()), | ||
| "tflite_stablehlo_while", | ||
| ) |
There was a problem hiding this comment.
In _convert_stablehlo_while, if opts is None (e.g., due to a malformed model or parsing failure), calling opts.CondSubgraphIndex() will raise an AttributeError. Adding a check to ensure opts is not None before extracting the subgraph indices would make the parser more robust and provide a clearer error message.
| def _convert_stablehlo_while(self, op): | |
| """Convert STABLEHLO_WHILE to a recursive Relax private function.""" | |
| from tflite.StablehloWhileOptions import StablehloWhileOptions | |
| opts = self._get_stablehlo_options(op, StablehloWhileOptions) | |
| return self._convert_while_like( | |
| op, | |
| "STABLEHLO_WHILE", | |
| int(opts.CondSubgraphIndex()), | |
| int(opts.BodySubgraphIndex()), | |
| "tflite_stablehlo_while", | |
| ) | |
| def _convert_stablehlo_while(self, op): | |
| """Convert STABLEHLO_WHILE to a recursive Relax private function.""" | |
| from tflite.StablehloWhileOptions import StablehloWhileOptions | |
| opts = self._get_stablehlo_options(op, StablehloWhileOptions) | |
| if opts is None: | |
| raise tvm.error.OpNotImplemented("STABLEHLO_WHILE requires valid StablehloWhileOptions") | |
| return self._convert_while_like( | |
| op, | |
| "STABLEHLO_WHILE", | |
| int(opts.CondSubgraphIndex()), | |
| int(opts.BodySubgraphIndex()), | |
| "tflite_stablehlo_while", | |
| ) |
Summary
This PR adds Relax TFLite frontend support for the TFLite builtin
STABLEHLO_WHILEoperator.STABLEHLO_WHILEuses StableHLOBuiltinOptions2to reference its conditionand body region subgraphs. Its loop semantics otherwise match the existing
TFLite
WHILEimporter path: loop-carried tensors are passed to the cond/bodysubgraphs, the cond subgraph returns a scalar bool, and the body subgraph
returns the updated loop state.
Design
Shared While Lowering
The native TFLite
WHILEconverter is refactored through a shared_convert_while_likehelper. NativeWHILEandSTABLEHLO_WHILEnow share thesame validation and lowering path after their options are parsed:
WHILEreadsWhileOptionsfromBuiltinOptionsSTABLEHLO_WHILEreadsStablehloWhileOptionsfromBuiltinOptions2Both paths lower the referenced cond/body subgraphs to private Relax functions
and emit a recursive private Relax function for the loop.
Boundary Validation
STABLEHLO_WHILEreuses the same guard-first checks as nativeWHILE:The recursive loop-function cache key now includes the generated function
prefix. This prevents native
WHILEandSTABLEHLO_WHILEfrom accidentallysharing a cached loop wrapper if they reference the same cond/body subgraph
indices.
Operator Support
STABLEHLO_WHILEStablehloWhileOptions.CondSubgraphIndex(),BodySubgraphIndex()fromBuiltinOptions2Tests
The tests manually build a minimal StableHLO while TFLite flatbuffer and compare
the imported Relax IR with
tvm.ir.assert_structural_equal. Unsupportedpatterns use
pytest.raises.test_stablehlo_whileSTABLEHLO_WHILErecursive private function loweringtest_stablehlo_while_non_bool_condition_unsupportedtest_stablehlo_while_invalid_index_unsupportedtest_stablehlo_while_output_count_mismatch_unsupportedtest_stablehlo_while_input_metadata_mismatch_unsupportedtest_stablehlo_while_output_metadata_mismatch_unsupportedLocal validation:
Result:
References