Skip to content

[Relax][Frontend][TFLite] Support control-flow multi-subgraph operators#19616

Merged
tlopex merged 3 commits into
apache:mainfrom
Aharrypotter:tflite_control_flow_multi_subgraph
May 27, 2026
Merged

[Relax][Frontend][TFLite] Support control-flow multi-subgraph operators#19616
tlopex merged 3 commits into
apache:mainfrom
Aharrypotter:tflite_control_flow_multi_subgraph

Conversation

@Aharrypotter
Copy link
Copy Markdown
Contributor

@Aharrypotter Aharrypotter commented May 26, 2026

Summary

This PR adds Relax TFLite frontend support for the TFLite builtin
control-flow / multi-subgraph operator family from #19519 item F:
CALL, IF, WHILE, and CALL_ONCE.

It builds on the multi-subgraph import infrastructure merged in PR #19587.
The frontend already accepts TFLite models with extra subgraphs while converting
only Subgraphs(0) into the Relax main function. This PR uses those extra
subgraphs as callable or control-flow regions for the TFLite control-flow
operators.

The supported subset is intentionally pure tensor and guard-first:

  • CALL lowers a referenced TFLite subgraph to a private Relax function and
    emits a direct call.
  • IF lowers the then/else subgraphs to private Relax functions and emits a
    private wrapper function containing Relax If.
  • WHILE lowers the cond/body subgraphs to private Relax functions and emits a
    recursive private Relax function for the loop.
  • CALL_ONCE supports the empty-init no-op subset and explicitly rejects
    non-empty or resource-like init patterns.

This PR does not model resource variable side effects. Those cases remain
explicitly guarded instead of being imported with incorrect pure functional
semantics.

Design

Shared Subgraph Lowering

The frontend now keeps shared conversion state across the main graph and
referenced subgraphs:

  • lowered_subgraphs
  • lowered_if_functions
  • lowered_while_functions
  • lowering_stack
  • module_builder

Referenced pure tensor subgraphs are lowered through a recursive
OperatorConverter using an isolated ExprTable, so subgraph tensor bindings
cannot overwrite bindings from the main graph. Lowered subgraphs are cached by
subgraph index and reused when the same region is referenced more than once.
Generated private functions are registered through the shared parent
module_builder, so nested cases such as main CALL -> subgraph A -> CALL subgraph B keep all private functions in the final IRModule.

Recursive ordinary CALL subgraphs are guarded with OpNotImplemented.
WHILE uses a dedicated recursive wrapper function instead, because recursion
is part of the intended Relax representation for the loop itself.

Boundary Validation

The control-flow converters validate subgraph boundaries before lowering:

  • referenced subgraph indices must be valid
  • op input/output arity must match the referenced subgraph interface
  • branch and loop tensor shape/dtype metadata must match the surrounding op
  • IF and WHILE conditions must be scalar bool tensors
  • WHILE loop-carried input/output tensors must have matching metadata

The shared _check_subgraph_interface helper is used by CALL, IF, and
WHILE to keep arity and metadata checks consistent across the control-flow
operators. _require_scalar_bool_tensor accepts both frontend TensorWrapper
objects and raw TFLite tensors so caller and referenced-subgraph condition
checks use the same path.

These checks keep the first implementation conservative and make unsupported
cases fail with targeted OpNotImplemented diagnostics.

Tuple Outputs

TFLite CALL, IF, and WHILE may produce multiple output tensors. The
frontend maps those cases to Relax tuple returns:

single output  -> tensor expression
multi output   -> Tuple(...)
op outputs     -> TupleGetItem(...)

This keeps the single-output IR simple while covering multi-output calls,
multi-output branches, and multi-variable loop state.

Operator Support

Operator TFLite options Relax lowering Supported subset
CALL CallOptions.Subgraph() private Relax function call pure tensor subgraphs, single or multiple outputs
IF IfOptions.ThenSubgraphIndex(), ElseSubgraphIndex() private wrapper function containing Relax If scalar bool condition, matching branch I/O metadata
WHILE WhileOptions.CondSubgraphIndex(), BodySubgraphIndex() recursive private Relax function scalar bool cond output, tensor loop-carried state
CALL_ONCE CallOnceOptions.InitSubgraphIndex() no-op for empty init subgraph empty init subgraph only

Not Included

  • Full CALL_ONCE resource/variable initialization semantics.
  • Resource, variant, hashtable, or variable tensor support.
  • TensorFlow-generated tf.cond / tf.while_loop smoke tests.
  • Dynamic-shape loop-state refinements beyond the current static metadata
    checks.

Tests

The tests manually build minimal TFLite flatbuffers and compare the imported
Relax IR with tvm.ir.assert_structural_equal. Unsupported-boundary tests use
pytest.raises.

Test Coverage
test_call_subgraph basic CALL to a pure tensor subgraph
test_call_subgraph_multi_output CALL tuple return and output binding
test_call_subgraph_nested_call nested CALL private function registration
test_call_subgraph_invalid_index_unsupported invalid CALL subgraph index
test_call_subgraph_io_mismatch_unsupported CALL arity mismatch
test_call_subgraph_output_metadata_mismatch_unsupported CALL output metadata guard
test_if_subgraphs basic IF branch selection
test_if_subgraphs_multi_output IF tuple branch returns
test_if_subgraphs_non_bool_condition_unsupported IF condition dtype guard
test_if_subgraphs_invalid_index_unsupported invalid then/else subgraph index
test_if_subgraphs_output_count_mismatch_unsupported branch output count guard
test_if_subgraphs_input_metadata_mismatch_unsupported branch input metadata guard
test_if_subgraphs_output_metadata_mismatch_unsupported branch output metadata guard
test_while_subgraphs basic recursive WHILE lowering
test_while_subgraphs_repeated_cond_body_pair shared cond/body loop function cache
test_while_subgraphs_two_loop_vars multi-variable loop state tuple path
test_while_subgraphs_non_bool_condition_unsupported WHILE cond output dtype guard
test_while_subgraphs_invalid_index_unsupported invalid cond/body subgraph index
test_while_subgraphs_zero_loop_vars_unsupported zero-loop-var guard
test_while_subgraphs_loop_state_metadata_mismatch_unsupported loop state metadata guard
test_while_subgraphs_output_count_mismatch_unsupported body output count guard
test_while_subgraphs_input_metadata_mismatch_unsupported cond/body input metadata guard
test_while_subgraphs_output_metadata_mismatch_unsupported cond/body output metadata guard
test_call_once_empty_init_subgraph empty CALL_ONCE no-op subset
test_call_once_non_empty_init_subgraph_unsupported non-empty init subgraph guard
test_call_once_inputs_outputs_unsupported CALL_ONCE op I/O guard
test_call_once_init_subgraph_io_unsupported init subgraph I/O guard
test_call_once_invalid_index_unsupported invalid init subgraph index

Local validation:

python -m ruff format --check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m ruff check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k "call_subgraph or if_subgraphs or while_subgraphs or call_once" -q

python -m pytest \
  tests/python/relax/test_frontend_tflite.py -q

Result:

ruff format --check: 2 files already formatted
ruff check: All checks passed
28 passed, 434 deselected
462 passed

References

Add Relax TFLite frontend support for the TFLite control-flow and multi-subgraph operators from apache#19519 item F: CALL, IF, WHILE, and the no-op subset of CALL_ONCE.

The implementation builds on the multi-subgraph import infrastructure from apache#19587. Referenced subgraphs are lowered to private Relax functions with isolated expression tables and shared lowering state, while the main TFLite subgraph remains the Relax main function.

- Lower CALL targets to private Relax functions and emit ordinary Relax function calls.

- Lower IF branches to private Relax functions and emit a private wrapper function containing Relax If.

- Lower WHILE cond/body subgraphs to private Relax functions and emit a recursive private Relax loop function.

- Support CALL_ONCE only when the init subgraph is empty; non-empty init subgraphs remain guarded because they may model resource initialization side effects.

- Validate referenced subgraph input/output counts and static tensor metadata at CALL, IF, and WHILE boundaries.

Added hand-built TFLite flatbuffer tests with structural equality coverage for CALL, IF, WHILE, empty CALL_ONCE, multi-output CALL, multi-output IF, and two-loop-var WHILE.

Added unsupported-path tests for non-empty CALL_ONCE, invalid subgraph indices, arity mismatches, non-bool IF/WHILE conditions, and static metadata mismatches.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds support for control-flow and call operators (CALL, CALL_ONCE, IF, WHILE) in the Relax TFLite frontend, enabling the translation of multi-subgraph TFLite models into Relax private functions. The review feedback highlights several critical issues regarding the handling of empty inputs or outputs in the TFLite flatbuffers Python API, where InputsAsNumpy() or OutputsAsNumpy() can return None and cause runtime errors during iteration or zipping. Additionally, a suggestion is made to use self.get_tensor_type_str instead of _decode_type for consistency.

Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
Copy link
Copy Markdown
Member

@tlopex tlopex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed implementation and tests. I think this needs one more correctness check before merge.

The subgraph lowering currently creates a fresh BlockBuilder for each referenced subgraph and then only copies subgraph_mod[function_name_hint] back into the parent builder. If that subgraph itself contains a CALL / IF / WHILE that lowers another private function, the nested private function can remain only in the inner BlockBuilder and be absent from the final module, while the copied function still references its GlobalVar.

Could you add a regression test for a nested subgraph case, for example:

main CALL -> subgraph A -> CALL subgraph B

and either ensure all generated private functions are copied back to the parent module or adjust lowering so nested subgraph functions are registered in shared state / the parent builder?

Also, please normalize empty TFLite input/output vectors before iterating or zipping them. Some FlatBuffer APIs can return None for empty vectors, so InputsAsNumpy(), OutputsAsNumpy(), and metadata checks should use an explicit is not None fallback to [] rather than relying on direct iteration.

@Aharrypotter
Copy link
Copy Markdown
Contributor Author

Thanks @tlopex, addressed in the latest commit by fixing nested subgraph registration and normalizing empty FlatBuffer input/output vectors.

The nested subgraph issue is fixed by registering generated private functions
through the shared parent module_builder, instead of leaving nested functions
inside the temporary subgraph BlockBuilder. I also added
test_call_subgraph_nested_call to cover:

main CALL -> subgraph A -> CALL subgraph B

This verifies that both generated private functions are present in the final
IRModule.

I also normalized empty FlatBuffer input/output vectors before iterating or
zipping them, including InputsAsNumpy(), OutputsAsNumpy(), and metadata
checks, using an explicit is not None fallback path. While touching that code,
I switched _get_subgraph_params to use self.get_tensor_type_str(...) for
consistency with the rest of the class.

@Aharrypotter Aharrypotter requested a review from tlopex May 27, 2026 02:53
@tlopex tlopex merged commit fa66213 into apache:main May 27, 2026
11 checks passed
tlopex pushed a commit that referenced this pull request May 29, 2026
…htable Import Support (#19639)

## Summary

This PR adds incremental Relax TFLite frontend support for the resource
variable initialization subset:

- `VAR_HANDLE`
- `ASSIGN_VARIABLE`
- `READ_VARIABLE`

It builds on the TFLite control-flow / multi-subgraph support from
#19616,
especially `CALL_ONCE`. TFLite commonly represents initialization
through a
`CALL_ONCE` init subgraph, then uses resource handles from the main
subgraph to
read initialized variables. This PR supports that constrained
initialization
pattern without introducing general mutable runtime state into Relax.

The PR also adds explicit frontend guards for the TFLite builtin
hashtable
operators:

- `HASHTABLE`
- `HASHTABLE_IMPORT`
- `HASHTABLE_FIND`
- `HASHTABLE_SIZE`

These operators are intentionally left unsupported for now. TFLite
builtin
hashtable kernels are not generic tensor maps: their runtime
implementations
cover the `int64 -> string` and `string -> int64` table variants, and
correct
import requires proper `TensorType.STRING` support. Rejecting the
operators is
safer than lowering a synthetic numeric table semantics that TFLite does
not
actually implement.

## Design

### Shared Initialization State

The frontend now keeps resource initialization data in shared conversion
state:

- `conversion_state["resource_values"]`
- `conversion_state["in_call_once_init"]`

This state is shared by the main graph converter and the `CALL_ONCE`
init
subgraph converter. Each converter instance still keeps its own local
`self.resource_handles` map, keyed by TFLite tensor name.

Resource variables use `container + shared_name` from `VarHandleOptions`
when
present, falling back to the handle tensor name. This keeps tensor-name
bindings
scoped to each subgraph while allowing init subgraphs and the main graph
to
agree on the same logical resource.

### CALL_ONCE Init Subgraphs

`CALL_ONCE` now accepts a non-empty init subgraph when all operators are
in the
supported initialization subset:

- `VAR_HANDLE`
- `ASSIGN_VARIABLE`

The init subgraph still must have no inputs and no outputs. The
converter first
checks every operator against the allowlist, then converts the init
subgraph
with a fresh `ExprTable` and shared conversion state.

The init subconverter deliberately shares the parent `BlockBuilder`.
This is
safe for the current subset because all supported init operators update
importer
state and return `None`; they do not emit Relax bindings. A comment
documents
that this should be revisited if future `CALL_ONCE` init operators emit
Relax
expressions.

### Resource Variables

`VAR_HANDLE` is declarative. It registers the output resource tensor in
the
current converter's local `resource_handles` map and returns `None`.

`ASSIGN_VARIABLE` is accepted only while converting a supported
`CALL_ONCE` init
subgraph. It resolves the resource handle through the init converter's
local
handle map and stores the assigned tensor expression in shared
`conversion_state["resource_values"]`.

`READ_VARIABLE` resolves the main graph resource handle and returns the
initialized expression from shared state. If the resource has not been
initialized by a supported `CALL_ONCE` path, the frontend raises
`OpNotImplemented`.

This supports the common static-initialization inference pattern while
avoiding
incorrect lowering for runtime mutation.

### Hashtable Operators

`HASHTABLE` registers the table handle and validates the dtype pair
against
TFLite kernel constraints (`int64/string` or `string/int64`).

`HASHTABLE_IMPORT` in a supported `CALL_ONCE` init subgraph captures
static
metadata (table size, key/value dtypes) but does not store actual string
data,
because Relax does not yet support `TensorType.STRING`.

`HASHTABLE_SIZE` returns a scalar Relax constant for statically imported
tables.

`HASHTABLE_FIND` is rejected with `OpNotImplemented` because Relax
cannot
represent TFLite string tensors or the runtime lookup semantics.

## Operator Support

| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `VAR_HANDLE` | `VarHandleOptions` | handle registration only | main
graph and supported `CALL_ONCE` init subgraphs |
| `ASSIGN_VARIABLE` | `AssignVariableOptions` | store initialized Relax
expression in shared importer state | supported `CALL_ONCE` init
subgraphs only |
| `READ_VARIABLE` | `ReadVariableOptions` | return initialized Relax
expression | resource must have supported static initialization |
| `HASHTABLE` | `HashtableOptions` | handle registration + dtype
validation | validates `int64/string` or `string/int64` pair, rejects
other combinations |
| `HASHTABLE_IMPORT` | `HashtableImportOptions` | store static metadata
(size, key/value dtype) | `CALL_ONCE` init subgraphs only, constant
key/value shape validation |
| `HASHTABLE_FIND` | `HashtableFindOptions` | unsupported guard |
requires future `TensorType.STRING` support in Relax |
| `HASHTABLE_SIZE` | `HashtableSizeOptions` | scalar Relax constant |
returns `[size]` int64 for statically imported tables |

## Safety Checks

- `ASSIGN_VARIABLE` outside `CALL_ONCE` initialization raises
  `OpNotImplemented`.
- `READ_VARIABLE` without supported initialization raises
`OpNotImplemented`.
- `CALL_ONCE` init subgraphs with inputs or outputs remain unsupported.
- `CALL_ONCE` init subgraphs containing operators outside the
resource-variable
  initialization allowlist remain unsupported.
- TFLite builtin hashtable operators raise `OpNotImplemented` until the
  frontend can model their real int64/string table semantics.

## Not Included

- Runtime `ASSIGN_VARIABLE` mutation in the main graph.
- Runtime resource-state threading through Relax function parameters and
  returns.
- Cross-subgraph resource handle aliasing beyond the static
  `container/shared_name` matching pattern.
- Multiple runtime writes with ordering semantics.
- TFLite builtin hashtable lowering.
- `TensorType.STRING` import support.

## Tests

The tests manually build minimal TFLite flatbuffers and compare imported
Relax
IR with `tvm.ir.assert_structural_equal`. Unsupported patterns use
`pytest.raises`.

| Test | Coverage |
|---|---|
| `test_resource_variable_call_once_init_read` | `CALL_ONCE` init
subgraph with `VAR_HANDLE + ASSIGN_VARIABLE`, then main graph
`READ_VARIABLE` |
| `test_assign_variable_main_subgraph_unsupported` | runtime/main graph
`ASSIGN_VARIABLE` guard |
| `test_read_variable_uninitialized_unsupported` | `READ_VARIABLE`
without supported initialization guard |
| `test_hashtable_call_once_import_find_unsupported` | hashtable
init/find path remains unsupported |
| `test_hashtable_call_once_import_size_unsupported` | hashtable
init/size path remains unsupported |
| `test_hashtable_import_main_subgraph_unsupported` | main graph
`HASHTABLE_IMPORT` remains unsupported |
| `test_hashtable_size_uninitialized_unsupported` | uninitialized
`HASHTABLE_SIZE` remains unsupported |

Local validation:

```bash
python -m py_compile \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m ruff format --check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m ruff check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k "resource_variable or read_variable_uninitialized or hashtable" -q

python -m pytest \
  tests/python/relax/test_frontend_tflite.py -q
```

Result:

```text
py_compile: passed
ruff format --check: files already formatted
ruff check: All checks passed
targeted resource/hashtable tests: 6 passed
full test_frontend_tflite.py: 472 passed
```
tlopex pushed a commit that referenced this pull request May 31, 2026
## Summary

This PR adds Relax TFLite frontend support for the TFLite builtin
`STABLEHLO_WHILE` operator.

`STABLEHLO_WHILE` uses StableHLO `BuiltinOptions2` to reference its
condition
and body region subgraphs. Its loop semantics otherwise match the
existing
TFLite `WHILE` importer path: loop-carried tensors are passed to the
cond/body
subgraphs, the cond subgraph returns a scalar bool, and the body
subgraph
returns the updated loop state.

## Design

### Shared While Lowering

The native TFLite `WHILE` converter is refactored through a shared
`_convert_while_like` helper. Native `WHILE` and `STABLEHLO_WHILE` now
share the
same validation and lowering path after their options are parsed:

- native `WHILE` reads `WhileOptions` from `BuiltinOptions`
- `STABLEHLO_WHILE` reads `StablehloWhileOptions` from `BuiltinOptions2`

Both paths lower the referenced cond/body subgraphs to private Relax
functions
and emit a recursive private Relax function for the loop.

### Boundary Validation

`STABLEHLO_WHILE` reuses the same guard-first checks as native `WHILE`:

- loop input count must match op output count
- cond subgraph input metadata must match loop-carried tensors
- cond subgraph must have exactly one output
- cond output must be a scalar bool tensor
- body subgraph input and output metadata must match loop-carried
tensors
- referenced cond/body subgraph indices must be valid non-main subgraphs

The recursive loop-function cache key now includes the generated
function
prefix. This prevents native `WHILE` and `STABLEHLO_WHILE` from
accidentally
sharing a cached loop wrapper if they reference the same cond/body
subgraph
indices.

## Operator Support

| Operator | TFLite options | Relax lowering | Supported subset |
|---|---|---|---|
| `STABLEHLO_WHILE` | `StablehloWhileOptions.CondSubgraphIndex()`,
`BodySubgraphIndex()` from `BuiltinOptions2` | recursive private Relax
function | tensor loop-carried state, scalar bool cond output, matching
cond/body interfaces |

## Tests

The tests manually build a minimal StableHLO while TFLite flatbuffer and
compare
the imported Relax IR with `tvm.ir.assert_structural_equal`. Unsupported
patterns use `pytest.raises`.

| Test | Coverage |
|---|---|
| `test_stablehlo_while` | basic `STABLEHLO_WHILE` recursive private
function lowering |
| `test_stablehlo_while_non_bool_condition_unsupported` | cond output
scalar bool guard |
| `test_stablehlo_while_invalid_index_unsupported` | invalid cond/body
subgraph index guard |
| `test_stablehlo_while_output_count_mismatch_unsupported` | body output
arity guard |
| `test_stablehlo_while_input_metadata_mismatch_unsupported` | cond
subgraph input metadata guard |
| `test_stablehlo_while_output_metadata_mismatch_unsupported` | body
subgraph output metadata guard |

Local validation:

```bash
python -m py_compile \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m ruff check \
  python/tvm/relax/frontend/tflite/tflite_frontend.py \
  tests/python/relax/test_frontend_tflite.py

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k stablehlo_while -q

python -m pytest \
  tests/python/relax/test_frontend_tflite.py \
  -k stablehlo -q
```

Result:

```text
py_compile: passed
ruff check: All checks passed
stablehlo_while tests: 6 passed
stablehlo tests: 84 passed
```

## References

- Issue #19519 item I: remaining StableHLO operators in TFLite
- PR #19587: StableHLO region-based ops and multi-subgraph model support
- PR #19616: TFLite control-flow / multi-subgraph support
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants