[Relax][Frontend][TFLite] Support control-flow multi-subgraph operators#19616
Conversation
Add Relax TFLite frontend support for the TFLite control-flow and multi-subgraph operators from apache#19519 item F: CALL, IF, WHILE, and the no-op subset of CALL_ONCE. The implementation builds on the multi-subgraph import infrastructure from apache#19587. Referenced subgraphs are lowered to private Relax functions with isolated expression tables and shared lowering state, while the main TFLite subgraph remains the Relax main function. - Lower CALL targets to private Relax functions and emit ordinary Relax function calls. - Lower IF branches to private Relax functions and emit a private wrapper function containing Relax If. - Lower WHILE cond/body subgraphs to private Relax functions and emit a recursive private Relax loop function. - Support CALL_ONCE only when the init subgraph is empty; non-empty init subgraphs remain guarded because they may model resource initialization side effects. - Validate referenced subgraph input/output counts and static tensor metadata at CALL, IF, and WHILE boundaries. Added hand-built TFLite flatbuffer tests with structural equality coverage for CALL, IF, WHILE, empty CALL_ONCE, multi-output CALL, multi-output IF, and two-loop-var WHILE. Added unsupported-path tests for non-empty CALL_ONCE, invalid subgraph indices, arity mismatches, non-bool IF/WHILE conditions, and static metadata mismatches.
There was a problem hiding this comment.
Code Review
This pull request adds support for control-flow and call operators (CALL, CALL_ONCE, IF, WHILE) in the Relax TFLite frontend, enabling the translation of multi-subgraph TFLite models into Relax private functions. The review feedback highlights several critical issues regarding the handling of empty inputs or outputs in the TFLite flatbuffers Python API, where InputsAsNumpy() or OutputsAsNumpy() can return None and cause runtime errors during iteration or zipping. Additionally, a suggestion is made to use self.get_tensor_type_str instead of _decode_type for consistency.
tlopex
left a comment
There was a problem hiding this comment.
Thanks for the detailed implementation and tests. I think this needs one more correctness check before merge.
The subgraph lowering currently creates a fresh BlockBuilder for each referenced subgraph and then only copies subgraph_mod[function_name_hint] back into the parent builder. If that subgraph itself contains a CALL / IF / WHILE that lowers another private function, the nested private function can remain only in the inner BlockBuilder and be absent from the final module, while the copied function still references its GlobalVar.
Could you add a regression test for a nested subgraph case, for example:
main CALL -> subgraph A -> CALL subgraph B
and either ensure all generated private functions are copied back to the parent module or adjust lowering so nested subgraph functions are registered in shared state / the parent builder?
Also, please normalize empty TFLite input/output vectors before iterating or zipping them. Some FlatBuffer APIs can return None for empty vectors, so InputsAsNumpy(), OutputsAsNumpy(), and metadata checks should use an explicit is not None fallback to [] rather than relying on direct iteration.
|
Thanks @tlopex, addressed in the latest commit by fixing nested subgraph registration and normalizing empty FlatBuffer input/output vectors. The nested subgraph issue is fixed by registering generated private functions This verifies that both generated private functions are present in the final I also normalized empty FlatBuffer input/output vectors before iterating or |
…htable Import Support (#19639) ## Summary This PR adds incremental Relax TFLite frontend support for the resource variable initialization subset: - `VAR_HANDLE` - `ASSIGN_VARIABLE` - `READ_VARIABLE` It builds on the TFLite control-flow / multi-subgraph support from #19616, especially `CALL_ONCE`. TFLite commonly represents initialization through a `CALL_ONCE` init subgraph, then uses resource handles from the main subgraph to read initialized variables. This PR supports that constrained initialization pattern without introducing general mutable runtime state into Relax. The PR also adds explicit frontend guards for the TFLite builtin hashtable operators: - `HASHTABLE` - `HASHTABLE_IMPORT` - `HASHTABLE_FIND` - `HASHTABLE_SIZE` These operators are intentionally left unsupported for now. TFLite builtin hashtable kernels are not generic tensor maps: their runtime implementations cover the `int64 -> string` and `string -> int64` table variants, and correct import requires proper `TensorType.STRING` support. Rejecting the operators is safer than lowering a synthetic numeric table semantics that TFLite does not actually implement. ## Design ### Shared Initialization State The frontend now keeps resource initialization data in shared conversion state: - `conversion_state["resource_values"]` - `conversion_state["in_call_once_init"]` This state is shared by the main graph converter and the `CALL_ONCE` init subgraph converter. Each converter instance still keeps its own local `self.resource_handles` map, keyed by TFLite tensor name. Resource variables use `container + shared_name` from `VarHandleOptions` when present, falling back to the handle tensor name. This keeps tensor-name bindings scoped to each subgraph while allowing init subgraphs and the main graph to agree on the same logical resource. ### CALL_ONCE Init Subgraphs `CALL_ONCE` now accepts a non-empty init subgraph when all operators are in the supported initialization subset: - `VAR_HANDLE` - `ASSIGN_VARIABLE` The init subgraph still must have no inputs and no outputs. The converter first checks every operator against the allowlist, then converts the init subgraph with a fresh `ExprTable` and shared conversion state. The init subconverter deliberately shares the parent `BlockBuilder`. This is safe for the current subset because all supported init operators update importer state and return `None`; they do not emit Relax bindings. A comment documents that this should be revisited if future `CALL_ONCE` init operators emit Relax expressions. ### Resource Variables `VAR_HANDLE` is declarative. It registers the output resource tensor in the current converter's local `resource_handles` map and returns `None`. `ASSIGN_VARIABLE` is accepted only while converting a supported `CALL_ONCE` init subgraph. It resolves the resource handle through the init converter's local handle map and stores the assigned tensor expression in shared `conversion_state["resource_values"]`. `READ_VARIABLE` resolves the main graph resource handle and returns the initialized expression from shared state. If the resource has not been initialized by a supported `CALL_ONCE` path, the frontend raises `OpNotImplemented`. This supports the common static-initialization inference pattern while avoiding incorrect lowering for runtime mutation. ### Hashtable Operators `HASHTABLE` registers the table handle and validates the dtype pair against TFLite kernel constraints (`int64/string` or `string/int64`). `HASHTABLE_IMPORT` in a supported `CALL_ONCE` init subgraph captures static metadata (table size, key/value dtypes) but does not store actual string data, because Relax does not yet support `TensorType.STRING`. `HASHTABLE_SIZE` returns a scalar Relax constant for statically imported tables. `HASHTABLE_FIND` is rejected with `OpNotImplemented` because Relax cannot represent TFLite string tensors or the runtime lookup semantics. ## Operator Support | Operator | TFLite options | Relax lowering | Supported subset | |---|---|---|---| | `VAR_HANDLE` | `VarHandleOptions` | handle registration only | main graph and supported `CALL_ONCE` init subgraphs | | `ASSIGN_VARIABLE` | `AssignVariableOptions` | store initialized Relax expression in shared importer state | supported `CALL_ONCE` init subgraphs only | | `READ_VARIABLE` | `ReadVariableOptions` | return initialized Relax expression | resource must have supported static initialization | | `HASHTABLE` | `HashtableOptions` | handle registration + dtype validation | validates `int64/string` or `string/int64` pair, rejects other combinations | | `HASHTABLE_IMPORT` | `HashtableImportOptions` | store static metadata (size, key/value dtype) | `CALL_ONCE` init subgraphs only, constant key/value shape validation | | `HASHTABLE_FIND` | `HashtableFindOptions` | unsupported guard | requires future `TensorType.STRING` support in Relax | | `HASHTABLE_SIZE` | `HashtableSizeOptions` | scalar Relax constant | returns `[size]` int64 for statically imported tables | ## Safety Checks - `ASSIGN_VARIABLE` outside `CALL_ONCE` initialization raises `OpNotImplemented`. - `READ_VARIABLE` without supported initialization raises `OpNotImplemented`. - `CALL_ONCE` init subgraphs with inputs or outputs remain unsupported. - `CALL_ONCE` init subgraphs containing operators outside the resource-variable initialization allowlist remain unsupported. - TFLite builtin hashtable operators raise `OpNotImplemented` until the frontend can model their real int64/string table semantics. ## Not Included - Runtime `ASSIGN_VARIABLE` mutation in the main graph. - Runtime resource-state threading through Relax function parameters and returns. - Cross-subgraph resource handle aliasing beyond the static `container/shared_name` matching pattern. - Multiple runtime writes with ordering semantics. - TFLite builtin hashtable lowering. - `TensorType.STRING` import support. ## Tests The tests manually build minimal TFLite flatbuffers and compare imported Relax IR with `tvm.ir.assert_structural_equal`. Unsupported patterns use `pytest.raises`. | Test | Coverage | |---|---| | `test_resource_variable_call_once_init_read` | `CALL_ONCE` init subgraph with `VAR_HANDLE + ASSIGN_VARIABLE`, then main graph `READ_VARIABLE` | | `test_assign_variable_main_subgraph_unsupported` | runtime/main graph `ASSIGN_VARIABLE` guard | | `test_read_variable_uninitialized_unsupported` | `READ_VARIABLE` without supported initialization guard | | `test_hashtable_call_once_import_find_unsupported` | hashtable init/find path remains unsupported | | `test_hashtable_call_once_import_size_unsupported` | hashtable init/size path remains unsupported | | `test_hashtable_import_main_subgraph_unsupported` | main graph `HASHTABLE_IMPORT` remains unsupported | | `test_hashtable_size_uninitialized_unsupported` | uninitialized `HASHTABLE_SIZE` remains unsupported | Local validation: ```bash python -m py_compile \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m ruff format --check \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m ruff check \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m pytest \ tests/python/relax/test_frontend_tflite.py \ -k "resource_variable or read_variable_uninitialized or hashtable" -q python -m pytest \ tests/python/relax/test_frontend_tflite.py -q ``` Result: ```text py_compile: passed ruff format --check: files already formatted ruff check: All checks passed targeted resource/hashtable tests: 6 passed full test_frontend_tflite.py: 472 passed ```
## Summary This PR adds Relax TFLite frontend support for the TFLite builtin `STABLEHLO_WHILE` operator. `STABLEHLO_WHILE` uses StableHLO `BuiltinOptions2` to reference its condition and body region subgraphs. Its loop semantics otherwise match the existing TFLite `WHILE` importer path: loop-carried tensors are passed to the cond/body subgraphs, the cond subgraph returns a scalar bool, and the body subgraph returns the updated loop state. ## Design ### Shared While Lowering The native TFLite `WHILE` converter is refactored through a shared `_convert_while_like` helper. Native `WHILE` and `STABLEHLO_WHILE` now share the same validation and lowering path after their options are parsed: - native `WHILE` reads `WhileOptions` from `BuiltinOptions` - `STABLEHLO_WHILE` reads `StablehloWhileOptions` from `BuiltinOptions2` Both paths lower the referenced cond/body subgraphs to private Relax functions and emit a recursive private Relax function for the loop. ### Boundary Validation `STABLEHLO_WHILE` reuses the same guard-first checks as native `WHILE`: - loop input count must match op output count - cond subgraph input metadata must match loop-carried tensors - cond subgraph must have exactly one output - cond output must be a scalar bool tensor - body subgraph input and output metadata must match loop-carried tensors - referenced cond/body subgraph indices must be valid non-main subgraphs The recursive loop-function cache key now includes the generated function prefix. This prevents native `WHILE` and `STABLEHLO_WHILE` from accidentally sharing a cached loop wrapper if they reference the same cond/body subgraph indices. ## Operator Support | Operator | TFLite options | Relax lowering | Supported subset | |---|---|---|---| | `STABLEHLO_WHILE` | `StablehloWhileOptions.CondSubgraphIndex()`, `BodySubgraphIndex()` from `BuiltinOptions2` | recursive private Relax function | tensor loop-carried state, scalar bool cond output, matching cond/body interfaces | ## Tests The tests manually build a minimal StableHLO while TFLite flatbuffer and compare the imported Relax IR with `tvm.ir.assert_structural_equal`. Unsupported patterns use `pytest.raises`. | Test | Coverage | |---|---| | `test_stablehlo_while` | basic `STABLEHLO_WHILE` recursive private function lowering | | `test_stablehlo_while_non_bool_condition_unsupported` | cond output scalar bool guard | | `test_stablehlo_while_invalid_index_unsupported` | invalid cond/body subgraph index guard | | `test_stablehlo_while_output_count_mismatch_unsupported` | body output arity guard | | `test_stablehlo_while_input_metadata_mismatch_unsupported` | cond subgraph input metadata guard | | `test_stablehlo_while_output_metadata_mismatch_unsupported` | body subgraph output metadata guard | Local validation: ```bash python -m py_compile \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m ruff check \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m pytest \ tests/python/relax/test_frontend_tflite.py \ -k stablehlo_while -q python -m pytest \ tests/python/relax/test_frontend_tflite.py \ -k stablehlo -q ``` Result: ```text py_compile: passed ruff check: All checks passed stablehlo_while tests: 6 passed stablehlo tests: 84 passed ``` ## References - Issue #19519 item I: remaining StableHLO operators in TFLite - PR #19587: StableHLO region-based ops and multi-subgraph model support - PR #19616: TFLite control-flow / multi-subgraph support
Summary
This PR adds Relax TFLite frontend support for the TFLite builtin
control-flow / multi-subgraph operator family from #19519 item F:
CALL,IF,WHILE, andCALL_ONCE.It builds on the multi-subgraph import infrastructure merged in PR #19587.
The frontend already accepts TFLite models with extra subgraphs while converting
only
Subgraphs(0)into the Relaxmainfunction. This PR uses those extrasubgraphs as callable or control-flow regions for the TFLite control-flow
operators.
The supported subset is intentionally pure tensor and guard-first:
CALLlowers a referenced TFLite subgraph to a private Relax function andemits a direct call.
IFlowers the then/else subgraphs to private Relax functions and emits aprivate wrapper function containing Relax
If.WHILElowers the cond/body subgraphs to private Relax functions and emits arecursive private Relax function for the loop.
CALL_ONCEsupports the empty-init no-op subset and explicitly rejectsnon-empty or resource-like init patterns.
This PR does not model resource variable side effects. Those cases remain
explicitly guarded instead of being imported with incorrect pure functional
semantics.
Design
Shared Subgraph Lowering
The frontend now keeps shared conversion state across the main graph and
referenced subgraphs:
lowered_subgraphslowered_if_functionslowered_while_functionslowering_stackmodule_builderReferenced pure tensor subgraphs are lowered through a recursive
OperatorConverterusing an isolatedExprTable, so subgraph tensor bindingscannot overwrite bindings from the main graph. Lowered subgraphs are cached by
subgraph index and reused when the same region is referenced more than once.
Generated private functions are registered through the shared parent
module_builder, so nested cases such asmain CALL -> subgraph A -> CALL subgraph Bkeep all private functions in the final IRModule.Recursive ordinary
CALLsubgraphs are guarded withOpNotImplemented.WHILEuses a dedicated recursive wrapper function instead, because recursionis part of the intended Relax representation for the loop itself.
Boundary Validation
The control-flow converters validate subgraph boundaries before lowering:
IFandWHILEconditions must be scalar bool tensorsWHILEloop-carried input/output tensors must have matching metadataThe shared
_check_subgraph_interfacehelper is used byCALL,IF, andWHILEto keep arity and metadata checks consistent across the control-flowoperators.
_require_scalar_bool_tensoraccepts both frontendTensorWrapperobjects and raw TFLite tensors so caller and referenced-subgraph condition
checks use the same path.
These checks keep the first implementation conservative and make unsupported
cases fail with targeted
OpNotImplementeddiagnostics.Tuple Outputs
TFLite
CALL,IF, andWHILEmay produce multiple output tensors. Thefrontend maps those cases to Relax tuple returns:
This keeps the single-output IR simple while covering multi-output calls,
multi-output branches, and multi-variable loop state.
Operator Support
CALLCallOptions.Subgraph()IFIfOptions.ThenSubgraphIndex(),ElseSubgraphIndex()IfWHILEWhileOptions.CondSubgraphIndex(),BodySubgraphIndex()CALL_ONCECallOnceOptions.InitSubgraphIndex()Not Included
CALL_ONCEresource/variable initialization semantics.tf.cond/tf.while_loopsmoke tests.checks.
Tests
The tests manually build minimal TFLite flatbuffers and compare the imported
Relax IR with
tvm.ir.assert_structural_equal. Unsupported-boundary tests usepytest.raises.test_call_subgraphCALLto a pure tensor subgraphtest_call_subgraph_multi_outputCALLtuple return and output bindingtest_call_subgraph_nested_callCALLprivate function registrationtest_call_subgraph_invalid_index_unsupportedCALLsubgraph indextest_call_subgraph_io_mismatch_unsupportedCALLarity mismatchtest_call_subgraph_output_metadata_mismatch_unsupportedCALLoutput metadata guardtest_if_subgraphsIFbranch selectiontest_if_subgraphs_multi_outputIFtuple branch returnstest_if_subgraphs_non_bool_condition_unsupportedIFcondition dtype guardtest_if_subgraphs_invalid_index_unsupportedtest_if_subgraphs_output_count_mismatch_unsupportedtest_if_subgraphs_input_metadata_mismatch_unsupportedtest_if_subgraphs_output_metadata_mismatch_unsupportedtest_while_subgraphsWHILEloweringtest_while_subgraphs_repeated_cond_body_pairtest_while_subgraphs_two_loop_varstest_while_subgraphs_non_bool_condition_unsupportedWHILEcond output dtype guardtest_while_subgraphs_invalid_index_unsupportedtest_while_subgraphs_zero_loop_vars_unsupportedtest_while_subgraphs_loop_state_metadata_mismatch_unsupportedtest_while_subgraphs_output_count_mismatch_unsupportedtest_while_subgraphs_input_metadata_mismatch_unsupportedtest_while_subgraphs_output_metadata_mismatch_unsupportedtest_call_once_empty_init_subgraphCALL_ONCEno-op subsettest_call_once_non_empty_init_subgraph_unsupportedtest_call_once_inputs_outputs_unsupportedCALL_ONCEop I/O guardtest_call_once_init_subgraph_io_unsupportedtest_call_once_invalid_index_unsupportedLocal validation:
python -m ruff format --check \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m ruff check \ python/tvm/relax/frontend/tflite/tflite_frontend.py \ tests/python/relax/test_frontend_tflite.py python -m pytest \ tests/python/relax/test_frontend_tflite.py \ -k "call_subgraph or if_subgraphs or while_subgraphs or call_once" -q python -m pytest \ tests/python/relax/test_frontend_tflite.py -qResult:
References