Skip to content

[Relax][Frontend][TFLite] Support STABLEHLO_WHILE#19646

Merged
tlopex merged 1 commit into
apache:mainfrom
Aharrypotter:relax-tflite-stablehlo-while
May 31, 2026
Merged

[Relax][Frontend][TFLite] Support STABLEHLO_WHILE#19646
tlopex merged 1 commit into
apache:mainfrom
Aharrypotter:relax-tflite-stablehlo-while

Conversation

@Aharrypotter
Copy link
Copy Markdown
Contributor

Summary

This PR adds Relax TFLite frontend support for the TFLite builtin
STABLEHLO_WHILE operator.

STABLEHLO_WHILE uses StableHLO BuiltinOptions2 to reference its condition
and body region subgraphs. Its loop semantics otherwise match the existing
TFLite WHILE importer path: loop-carried tensors are passed to the cond/body
subgraphs, the cond subgraph returns a scalar bool, and the body subgraph
returns the updated loop state.

Design

Shared While Lowering

The native TFLite WHILE converter is refactored through a shared
_convert_while_like helper. Native WHILE and STABLEHLO_WHILE now share the
same validation and lowering path after their options are parsed:

  • native WHILE reads WhileOptions from BuiltinOptions
  • STABLEHLO_WHILE reads StablehloWhileOptions from BuiltinOptions2

Both paths lower the referenced cond/body subgraphs to private Relax functions
and emit a recursive private Relax function for the loop.

Boundary Validation

STABLEHLO_WHILE reuses the same guard-first checks as native WHILE:

  • loop input count must match op output count
  • cond subgraph input metadata must match loop-carried tensors
  • cond subgraph must have exactly one output
  • cond output must be a scalar bool tensor
  • body subgraph input and output metadata must match loop-carried tensors
  • referenced cond/body subgraph indices must be valid non-main subgraphs

The recursive loop-function cache key now includes the generated function
prefix. This prevents native WHILE and STABLEHLO_WHILE from accidentally
sharing a cached loop wrapper if they reference the same cond/body subgraph
indices.

Operator Support

Operator TFLite options Relax lowering Supported subset
STABLEHLO_WHILE StablehloWhileOptions.CondSubgraphIndex(), BodySubgraphIndex() from BuiltinOptions2 recursive private Relax function tensor loop-carried state, scalar bool cond output, matching cond/body interfaces

Tests

The tests manually build a minimal StableHLO while TFLite flatbuffer and compare
the imported Relax IR with tvm.ir.assert_structural_equal. Unsupported
patterns use pytest.raises.

Test Coverage
test_stablehlo_while basic STABLEHLO_WHILE recursive private function lowering
test_stablehlo_while_non_bool_condition_unsupported cond output scalar bool guard
test_stablehlo_while_invalid_index_unsupported invalid cond/body subgraph index guard
test_stablehlo_while_output_count_mismatch_unsupported body output arity guard
test_stablehlo_while_input_metadata_mismatch_unsupported cond subgraph input metadata guard
test_stablehlo_while_output_metadata_mismatch_unsupported body subgraph output metadata guard

Local validation:

python -m py_compile \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m ruff check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k stablehlo_while -q

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k stablehlo -q

Result:

py_compile: passed
ruff check: All checks passed
stablehlo_while tests: 6 passed
stablehlo tests: 84 passed

References

Add Relax TFLite frontend support for STABLEHLO_WHILE by parsing StablehloWhileOptions from BuiltinOptions2 and reusing the existing TFLite while lowering path.

Refactor the native WHILE converter through a shared _convert_while_like helper so both native WHILE and STABLEHLO_WHILE validate cond/body subgraph boundaries, loop-carried tensor metadata, scalar bool conditions, and output arity consistently. Include the function prefix in the cached recursive loop key so native and StableHLO while functions cannot collide when they reference the same subgraph indices.

Add manually-built StableHLO while flatbuffer tests covering the recursive Relax private function lowering plus invalid cond/body index, non-bool condition, subgraph output count mismatch, input metadata mismatch, and body output metadata mismatch guards.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the STABLEHLO_WHILE operator in the TVM Relax TFLite frontend by refactoring the existing WHILE lowering logic into a reusable helper function. It also adds comprehensive unit tests to validate the lowering process and error handling. The review feedback suggests adding a robustness check in _convert_stablehlo_while to handle cases where the operator options might be parsed as None, preventing potential attribute errors.

Comment on lines +2165 to +2176
def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions

opts = self._get_stablehlo_options(op, StablehloWhileOptions)
return self._convert_while_like(
op,
"STABLEHLO_WHILE",
int(opts.CondSubgraphIndex()),
int(opts.BodySubgraphIndex()),
"tflite_stablehlo_while",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In _convert_stablehlo_while, if opts is None (e.g., due to a malformed model or parsing failure), calling opts.CondSubgraphIndex() will raise an AttributeError. Adding a check to ensure opts is not None before extracting the subgraph indices would make the parser more robust and provide a clearer error message.

Suggested change
def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions
opts = self._get_stablehlo_options(op, StablehloWhileOptions)
return self._convert_while_like(
op,
"STABLEHLO_WHILE",
int(opts.CondSubgraphIndex()),
int(opts.BodySubgraphIndex()),
"tflite_stablehlo_while",
)
def _convert_stablehlo_while(self, op):
"""Convert STABLEHLO_WHILE to a recursive Relax private function."""
from tflite.StablehloWhileOptions import StablehloWhileOptions
opts = self._get_stablehlo_options(op, StablehloWhileOptions)
if opts is None:
raise tvm.error.OpNotImplemented("STABLEHLO_WHILE requires valid StablehloWhileOptions")
return self._convert_while_like(
op,
"STABLEHLO_WHILE",
int(opts.CondSubgraphIndex()),
int(opts.BodySubgraphIndex()),
"tflite_stablehlo_while",
)

@tlopex tlopex merged commit 99488d9 into apache:main May 31, 2026
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants