Skip to content

[Relax][Frontend][TFLite] Add initial StableHLO builtin operator support#19536

Merged
tlopex merged 6 commits into
apache:mainfrom
Aharrypotter:stablehlo_tflite_ops
May 11, 2026
Merged

[Relax][Frontend][TFLite] Add initial StableHLO builtin operator support#19536
tlopex merged 6 commits into
apache:mainfrom
Aharrypotter:stablehlo_tflite_ops

Conversation

@Aharrypotter
Copy link
Copy Markdown
Contributor

Summary

This PR adds initial Relax TFLite frontend support for 29 StableHLO builtin
operators from #19519 item I.

The covered subset includes pure elementwise ops, BuiltinOptions2 /
metadata-based ops, simple shape-manipulation ops, and a take-equivalent subset
of STABLEHLO_GATHER.

StableHLO builtins carry no TFLite-specific quantization or fused-activation
metadata, so the implementation uses dedicated converter helpers that bypass the
existing TFLite elemwise/QNN code paths.

Relates to #19519.

Changes

  1. Zero-attribute elementwise helpers

    • Add _convert_stablehlo_unary, _convert_stablehlo_binary, and
      _convert_stablehlo_ternary for pure elementwise mapping.
    • Register 20 ops: unary (ABS, NEGATE, COSINE, EXPONENTIAL, FLOOR,
      LOG, LOGISTIC, RSQRT, TANH), binary (ADD, SUBTRACT, MULTIPLY,
      DIVIDE, MAXIMUM, MINIMUM, POWER), ternary (SELECTR.where),
      and dtype-dispatched bitwise/logical ops (AND / OR → logical ops for
      bool or bitwise ops for integer, SHIFT_LEFTR.left_shift for integer).
  2. BuiltinOptions2 infrastructure

    • Add _get_stablehlo_options helper for parsing BuiltinOptions2 flatbuffers
      with enum validation via getattr(BuiltinOptions2, options_cls.__name__).
    • Register 6 ops: CONVERTR.astype, CLAMP
      R.minimum(R.maximum(...)), CONCATENATER.concat,
      BROADCAST_IN_DIMR.reshape + R.broadcast_to, IOTA
      R.arange + R.broadcast_to, and COMPARE → 6 comparison directions
      (TOTALORDER raises OpNotImplemented).
  3. Shape-manipulation ops

    • PADR.nn.pad in constant mode. The initial PAD path supports
      non-negative edge padding with zero interior padding and a constant scalar
      padding value. Interior padding, negative padding, and dynamic padding
      values raise OpNotImplemented.
    • DYNAMIC_SLICER.dynamic_strided_slice. The initial path supports
      constant, in-bound start indices only. Runtime start indices and
      out-of-bounds StableHLO clamping semantics are deferred.
  4. Indexing op

    • GATHERR.take for the take-equivalent subset only.
    • Parses the relevant StablehloGatherOptions attributes needed to validate
      this subset: offset_dims, collapsed_slice_dims, start_index_map,
      index_vector_dim, and slice_sizes.
    • Validates the gather axis, collapsed dims, offset dims, slice sizes, and
      output shape against the expected R.take layout. Multi-dimensional and
      non-take-equivalent gather patterns raise OpNotImplemented.
  5. Not included

    • STABLEHLO_RESHAPE, STABLEHLO_TRANSPOSE, and STABLEHLO_SLICE are left
      to another contributor who expressed interest in those ops.
    • The remaining Issue [Tracking Issue][TFLite] Remaining builtin operator coverage beyond #19412 #19519 StableHLO items are deferred to follow-up PRs:
      CBRT, REMAINDER, SCATTER, CONVOLUTION, DOT_GENERAL, REDUCE,
      REDUCE_WINDOW, DYNAMIC_UPDATE_SLICE, COMPOSITE, CUSTOM_CALL,
      RNG_BIT_GENERATOR, SORT, and WHILE.
    • More general or multi-dimensional STABLEHLO_GATHER patterns are also
      deferred to follow-up work.

Testing

All tests use manually-built minimal TFLite flatbuffers with
tvm.ir.assert_structural_equal. BuiltinOptions2 ops construct their options
via the FlatBuffers schema API, modeled after the existing DILATE test pattern.

python -m pytest tests/python/relax/test_frontend_tflite.py -k stablehlo -q

Result

  • 29 StableHLO operators registered in the Relax TFLite frontend.

  • 44 StableHLO test cases covering all registered ops, including
    structural-equal tests and unsupported/error-path checks:

    • COMPARE with TOTALORDER
    • PAD with interior padding, negative padding, and dynamic padding values
    • DYNAMIC_SLICE with runtime starts and out-of-bounds starts
    • non-take-equivalent or multi-dimensional GATHER
  • All StableHLO TFLite frontend tests pass locally.

References

…upport

Add frontend mapping for 8 basic StableHLO TFLite builtin
operators as pure unary/binary elementwise ops:

- STABLEHLO_ABS, STABLEHLO_NEGATE (unary)
- STABLEHLO_ADD, STABLEHLO_SUBTRACT, STABLEHLO_MULTIPLY, STABLEHLO_DIVIDE,
  STABLEHLO_MAXIMUM, STABLEHLO_MINIMUM (binary)

Implementation uses dedicated _convert_stablehlo_unary / _convert_stablehlo_binary
helpers that intentionally bypass TFLite fused-activation and QNN code paths,
since StableHLO ops carry no TFLite-specific quantization or fused-activation
metadata in their flatbuffer representation.

Test coverage: 8 structural-equal tests with tvm.ir.assert_structural_equal.
…h ternary SELECT

Extend the StableHLO TFLite frontend with all remaining pure elementwise
operators that require no attribute parsing:

- Unary: STABLEHLO_COSINE (cos), STABLEHLO_EXPONENTIAL (exp),
  STABLEHLO_FLOOR (floor), STABLEHLO_LOG (log), STABLEHLO_LOGISTIC (sigmoid),
  STABLEHLO_RSQRT (rsqrt), STABLEHLO_TANH (tanh)
- Binary: STABLEHLO_AND (logical_and), STABLEHLO_OR (logical_or),
  STABLEHLO_POWER (power), STABLEHLO_SHIFT_LEFT (left_shift)
- Ternary: STABLEHLO_SELECT (where) with dedicated
  _convert_stablehlo_ternary helper

The existing _convert_stablehlo_unary and _convert_stablehlo_binary helpers
are reused; only STABLEHLO_SELECT needs the new ternary converter since
R.where requires a 3-input signature with bool condition dtype.

Test coverage: 20 structural-equal tests (12 new, 8 from previous commit).
The SELECT test uses inline flatbuffer construction to set the condition
input dtype to BOOL, matching the R.where requirement.
…nOptions2 support

Introduce the first batch of StableHLO TFLite builtin operators that
require BuiltinOptions2 attribute parsing:

- STABLEHLO_CONVERT → R.astype (reads output dtype from tensor metadata)
- STABLEHLO_CLAMP → R.minimum(R.maximum(x, min), max) (arg reordering)
- STABLEHLO_CONCATENATE → R.concat with StablehloConcatenateOptions
- STABLEHLO_BROADCAST_IN_DIM → R.broadcast_to with broadcast dimensions
- STABLEHLO_IOTA → R.arange + R.reshape + R.broadcast_to
- STABLEHLO_COMPARE → R.equal/greater/less/... with 6 comparison directions

Add _get_stablehlo_options helper for parsing BuiltinOptions2 flatbuffers.
R.clip was considered for CLAMP but rejected because it only accepts
scalar PrimValue min/max, not tensor inputs.

Test coverage: 32 structural-equal tests (20 previous + 12 new) passed.
Add frontend mapping for two StableHLO TFLite builtin operators
that manipulate tensor shapes:

- STABLEHLO_PAD → R.nn.pad with constant mode. Parses EdgePaddingLow,
  EdgePaddingHigh, and InteriorPadding from StablehloPadOptions.
  Raises OpNotImplemented when interior (dilation) padding is non-zero.
- STABLEHLO_DYNAMIC_SLICE → R.dynamic_strided_slice. Reads SliceSizes
  from StablehloDynamicSliceOptions and start indices from scalar
  tensor inputs. Begin/end/strides are constructed as int64 1D tensors.

Both ops extend the BuiltinOptions2 parsing infrastructure introduced
in the previous commit, adding vector-attribute (PAD) and dynamic-input
(DYNAMIC_SLICE) patterns.

Test coverage: 33 structural-equal tests passed (31 previous + 2 new).
…nt subset)

Add frontend mapping for STABLEHLO_GATHER with a conservative
take-equivalent implementation:

- Parses 6 attributes from StablehloGatherOptions (OffsetDims,
  CollapsedSliceDims, StartIndexMap, IndexVectorDim, SliceSizes,
  IndicesAreSorted)
- Only supports single-axis gather with index vector dim == rank(indices)-1
  and slice_sizes matching R.take semantics
- Validates offset_dims layout, output shape, and collapsed dims against
  expected R.take behavior; raises OpNotImplemented otherwise
- Reshapes indices from [N, 1] to [N] before calling R.take

Tests: 3 new (2 take-equivalent parametrized for axis 0/1,
1 error path for multi-dimensional start_index_map).
Total: 38 stablehlo tests passed.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for a wide range of StableHLO operators in the TFLite frontend for TVM Relax, covering unary, binary, ternary, and more complex operations like gather and dynamic slice. The changes include the core conversion logic and comprehensive unit tests. Feedback points out a bug in a test helper function regarding FlatBuffers vector generation and suggests removing a redundant reshape call in the dynamic slice implementation.

Comment thread tests/python/relax/test_frontend_tflite.py Outdated
Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py Outdated
@Aharrypotter Aharrypotter marked this pull request as ready for review May 11, 2026 12:17
@tlopex tlopex merged commit c0406a5 into apache:main May 11, 2026
9 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants