[Frontend][ONNX] Fix SplitToSequence keepdims=0 and uneven last chunk #19341
tlopex merged 3 commits into apache:main
Conversation
- Remove NotImplementedError for keepdims=0
- Track split_is_scalar to correctly apply keepdims logic
- Replace ValueError for uneven chunks with index-based splitting
- Add tests for keepdims=0, uneven last chunk, and keepdims-ignored cases

Fixes apache#18945
Code Review
This pull request updates the ONNX SplitToSequence implementation in the Relax frontend to support keepdims=0 and uneven last chunks, along with adding corresponding test cases. Feedback highlights a logic error where keepdims is incorrectly applied when a split input is provided, an issue with symbolic dimension handling where chunk size is misinterpreted as the number of sections, and a potential invalid argument error when the dimension size is smaller than the chunk size. Additionally, a redundant test case was identified for removal.
```python
else:
    split = chunk_size
```
When dim_size is symbolic, setting split = chunk_size is incorrect because relax.op.split interprets an integer argument as the number of sections, not the chunk size. Additionally, since SplitToSequence returns a relax.Tuple, the number of outputs must be statically known at import time. If the number of chunks depends on a dynamic dimension, it cannot be represented as a fixed-size Tuple. This case should likely raise a NotImplementedError as it is not supported by the current Relax frontend architecture.
```python
if not keepdims and split_is_scalar:
    output = bb.normalize(output)
    n = len(output.struct_info.fields)
    squeezed = [
        relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis])
        for i in range(n)
    ]
    return relax.Tuple(squeezed)
```
The ONNX specification for SplitToSequence states that the keepdims attribute is ignored if the split input is specified. The current implementation applies keepdims=0 whenever split is a scalar, even if it was provided as an input. This also contradicts the PR's own test case test_split_to_sequence_keepdims_ignored_when_split_provided (lines 5506-5527), which expects keepdims to be ignored for a scalar split input. The logic should only apply keepdims when the split input is NOT provided (i.e., len(inputs) == 1). Additionally, it is recommended to emit the output before performing TupleGetItem to avoid redundant expression nodes in the IR.
```diff
-if not keepdims and split_is_scalar:
-    output = bb.normalize(output)
+# keepdims=0 applies only when split is not provided (defaults to 1)
+if not keepdims and len(inputs) == 1:
+    output = bb.emit(output)
     n = len(output.struct_info.fields)
     squeezed = [
         relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis])
         for i in range(n)
     ]
     return relax.Tuple(squeezed)
```
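The squeeze-per-chunk behavior under discussion can be sketched in plain numpy, independent of the Relax implementation. The function name `split_to_sequence_ref` and its signature are hypothetical, for illustration only; it squeezes the split axis only for the default chunk size of 1, consistent with the review's point that `keepdims` is ignored when the `split` input is provided.

```python
import numpy as np

# Hypothetical reference semantics for SplitToSequence (numpy sketch).
def split_to_sequence_ref(data, split=1, axis=0, keepdims=1):
    dim = data.shape[axis]
    # Index-based split so the last chunk may be smaller than `split`.
    indices = list(range(split, dim, split))
    chunks = np.split(data, indices, axis=axis)
    if not keepdims and split == 1:
        # Only a size-1 split axis can be squeezed without losing data.
        chunks = [np.squeeze(c, axis=axis) for c in chunks]
    return chunks

x = np.arange(6).reshape(3, 2)
out = split_to_sequence_ref(x, split=1, axis=0, keepdims=0)
# Three chunks, each squeezed from shape (1, 2) down to (2,).
```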
```python
indices = list(range(chunk_size, dim_size_int, chunk_size))
split = indices if indices else dim_size_int // chunk_size
```
If dim_size_int is less than chunk_size (e.g., dim=1, chunk=2), indices will be an empty list and dim_size_int // chunk_size will evaluate to 0. Passing 0 as the number of sections to relax.op.split is invalid. It should be 1 in this case to return the entire tensor as a single chunk, which is consistent with the ONNX specification for uneven last chunks.
```diff
 indices = list(range(chunk_size, dim_size_int, chunk_size))
-split = indices if indices else dim_size_int // chunk_size
+split = indices if indices else 1
```
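The edge case the reviewer describes can be checked directly with numpy (shown here as a standalone sketch, not the Relax code path): with `dim=1` and `chunk=2`, the index list is empty, so the integer fallback must be 1 section, not `dim // chunk == 0`.

```python
import numpy as np

# dim_size smaller than chunk_size: range() produces no split indices.
dim_size_int, chunk_size = 1, 2
indices = list(range(chunk_size, dim_size_int, chunk_size))  # -> []
split = indices if indices else 1  # fallback to a single section

x = np.arange(3).reshape(1, 3)
# np.split accepts either an index list or an integer section count;
# with split == 1 the whole tensor comes back as one chunk.
chunks = np.split(x, split, axis=0)
```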
```python
@pytest.mark.parametrize("split", [2, [16, 48]])
def test_split_to_sequence(split):
    split_to_sequence_node = helper.make_node(
        "SplitToSequence",
        ["data", "split"],
        ["output"],
        axis=0,
    )
    split_shape = [len(split)] if isinstance(split, list) else ()
    split_node = make_constant_node(
        "split", TensorProto.INT64, split_shape, [split] if isinstance(split, int) else split
    )
    graph = helper.make_graph(
        [split_node, split_to_sequence_node],
        "test_split_to_sequence",
        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [64, 32])],
        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [32, 32])],
    )
    model = helper.make_model(graph, producer_name="test_split_to_sequence")
    check_correctness(model)
```
…st chunk

Addressed Gemini code review:
- Use len(inputs)==1 instead of split_is_scalar for keepdims condition
- Use bb.emit instead of bb.normalize before TupleGetItem
- Use split=1 for empty indices edge case instead of dim_size//chunk_size
- Raise NotImplementedError for dynamic dim size with scalar split
- Remove duplicate test_split_to_sequence test
Hi @OmarAzizi
@tlopex You are right that the integer approach works, but the calculation needed to be
For
Updated the implementation accordingly.
Summary

Fixes two spec violations in SplitToSequence:

- keepdims=0 was raising NotImplementedError. The fix squeezes the split axis from each chunk when split is scalar and keepdims=0, per the spec.
- An uneven last chunk was raising ValueError. The spec states: "The last chunk alone may be smaller than 'split' if the input size is not divisible by 'split'." Fixed by using index-based splitting via range(chunk_size, dim_size, chunk_size) instead of a section count.

Reference: https://onnx.ai/onnx/operators/onnx__SplitToSequence.html

Closes part of #18945
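The index-based splitting described above can be demonstrated with numpy: for chunk size 3 on a length-7 axis, `range(chunk_size, dim_size, chunk_size)` yields the split indices `[3, 6]`, so the chunks have sizes 3, 3, and 1, with the last chunk smaller as the spec allows.

```python
import numpy as np

# Index-based splitting: split points, not a section count.
chunk_size, dim_size = 3, 7
indices = list(range(chunk_size, dim_size, chunk_size))  # [3, 6]
chunks = np.split(np.arange(dim_size), indices)
sizes = [len(c) for c in chunks]  # sizes 3, 3, 1 -- uneven last chunk
```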