
Conversation

naoyam (Collaborator) commented Dec 13, 2025

Adds the Partition iter domain op, which represents the construction of a ragged iter domain from a regular iter domain. Similar to IterDomain::split, a RaggedIterDomain::partition static method is added that applies the Partition op.

The offsets/extents parameter is still limited to 1D. Will be extended in a follow-up PR.
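
For illustration, here is a minimal usage sketch. It assumes nvFuser's usual C++ test scaffolding (Fusion, FusionGuard, makeSymbolicTensor, set); the tensor names are hypothetical, and the resulting domain follows the doc comment in this PR.

// Sketch only: names are hypothetical; partition(axis, extents) follows this PR.
Fusion fusion;
FusionGuard fg(&fusion);

// tv0: [id{N}] -- the flat dimension to be partitioned (e.g., tokens).
TensorView* tv0 = makeSymbolicTensor(1);
// 1D tensor holding the extent of each component (e.g., tokens_per_expert).
TensorView* extents = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(tv0);
fusion.addInput(extents);

TensorView* tv1 = set(tv0);
// Per the declaration's doc comment: partition(0, ...) on tv1[id{N}] yields
// tv1[id{num_components}, ragged_id{extents}].
tv1->partition(0, extents);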

github-actions bot commented Dec 13, 2025

Review updated until commit 8a73bb2

Description

  • Adds Partition operation to split IterDomain into component and ragged dimensions

  • Implements RaggedIterDomain::partition() static method with validation checks

  • Adds TensorDomain::partition() and TensorView::partition() methods following split pattern

  • Includes comprehensive tests for partition functionality and validation

Changes walkthrough

Relevant files

Enhancement

internal_base_nodes.cpp (csrc/ir/internal_base_nodes.cpp) - Implement partition methods for RaggedIterDomain and TensorDomain

  • Added RaggedIterDomain::partition() static method creating component and ragged IterDomains
  • Added TensorDomain::partition() method following the split pattern
  • Includes validation for the input IterDomain and the extents tensor type and dimensionality
  • +75/-0

internal_nodes.cpp (csrc/ir/internal_nodes.cpp) - Define Partition expression class

  • Added Partition class definition with component and ragged outputs
  • Implemented the partition expression with an input IterDomain and an extents attribute
  • Added toString() and clone methods for the Partition operation
  • +41/-0

tensor_view.cpp (csrc/tensor_view.cpp) - Implement TensorView partition method

  • Added TensorView::partition() method with validation checks
  • Follows the pattern of TensorView::split with compute-position validation
  • Includes parallel-type validation and automatic Index type casting
  • +45/-0

dispatch.h (csrc/dispatch.h) - Add Partition to dispatch registry

  • Added Partition to the expression dispatch list for visitor-pattern support
  • +1/-0

interface_nodes.h (csrc/ir/interface_nodes.h) - Declare TensorView partition method

  • Added TensorView::partition() method declaration with documentation
  • The method partitions an axis into component and ragged dimensions using an extents tensor
  • +8/-0

internal_base_nodes.h (csrc/ir/internal_base_nodes.h) - Declare RaggedIterDomain and TensorDomain partition methods

  • Added RaggedIterDomain::partition() static method declaration with documentation
  • Added TensorDomain::partition() method declaration
  • Includes a TODO note for multi-dimensional extents support
  • +20/-0

internal_nodes.h (csrc/ir/internal_nodes.h) - Declare Partition expression class

  • Added Partition class declaration with component(), ragged(), in(), and extents() accessors
  • Inherits from Expr with clone and create functionality
  • +44/-0

Tests

test_ragged_iter_domain.cpp (tests/cpp/test_ragged_iter_domain.cpp) - Add comprehensive partition operation tests

  • Added PartitionBasic test verifying partition operation creation and structure
  • Added PartitionValidation test covering null inputs, type checks, and error cases
  • Added TensorViewPartition test verifying TensorView integration
  • +140/-0

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Partition Class Design

The Partition class stores the extents tensor as an attribute rather than an input. This design choice is documented in the code, but it is novel: in existing IterDomain expressions, inputs and outputs are always IterDomains, and no attribute is a tensor. While the choice appears intentional, it should be validated that it doesn't cause issues with IR traversal, cloning, or other IR operations that expect tensor dependencies to be properly tracked.

// Note: extents is held as an attribute rather than an input,
// despite being a TensorView. Inputs and outputs in the existing
// IterDomain exprs are always IterDomains. Intuitively, they
// transform input iteration spaces into output iteration spaces in
// some way. Since the extents tensor itself is not transformed in the
// Partition expr, it doesn't seem to be considered as an input. Note that in
// Split, the split factor is an attribute. That said, none
// of the existing exprs has tensors as attributes, which makes this
// choice less certain, with possible implications.
addAttribute(extents);
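
For context, an illustrative sketch (not verbatim nvFuser source) of the input/attribute convention being referenced; the Split constructor is simplified here.

// Illustrative sketch, simplified from the convention the note describes:
// Split wires its IterDomains through addInput/addOutput and keeps the
// scalar split factor as an attribute.
Split::Split(IterDomain* outer, IterDomain* inner, IterDomain* in, Val* factor) {
  addOutput(outer);
  addOutput(inner);
  addInput(in);
  addAttribute(factor);  // scalar factor: attribute, not input
}
// Partition follows the same convention; the novelty is that its attribute
// is a TensorView rather than a scalar.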
    Error Handling Completeness

    The partition method includes comprehensive validation checks, but the error messages could be more specific about the actual values that caused failures. For example, when reporting IterType mismatches, including the actual IterType value in the error message would help with debugging.

    NVF_ERROR_EQ(
        in->getIterType(),
        IterType::Iteration,
        "partition: only IterType::Iteration is supported, got ",
        in->getIterType(),
        " for IterDomain: ",
        in->toString());

    Test failures (partial, pipeline still running)

    • (Medium, 1) Large numerical mismatch in NVFuser PingPongCircularBuffering test (PingPongCircularBuffering.StageSlicePositionComputeAt)

Test name: PingPongCircularBuffering.StageSlicePositionComputeAt/stage_slice_position_4 (H100)

greptile-apps bot (Contributor) commented Dec 13, 2025

    Greptile Summary

    This PR adds the Partition iter domain operation to enable construction of ragged iter domains from regular iter domains. The implementation introduces:

    • New Partition Expr node (csrc/ir/internal_nodes.{h,cpp}): Transforms a single IterDomain into component and ragged dimensions, storing extents as an attribute (following Split's pattern)
    • RaggedIterDomain::partition static method (csrc/ir/internal_base_nodes.cpp): Core logic implementing partitioning with comprehensive validation (null checks, IterType/ParallelType constraints, extents tensor validation)
    • TensorDomain::partition instance method (csrc/ir/internal_base_nodes.cpp): Applies partition to a specific axis while maintaining domain consistency (follows split pattern)
    • TensorView::partition public API (csrc/tensor_view.cpp, csrc/ir/interface_nodes.h): User-facing wrapper with additional validation and automatic Index dtype casting
    • Comprehensive test coverage (tests/cpp/test_ragged_iter_domain.cpp): Basic functionality, all validation error cases, and integration with TensorView

    The implementation is well-integrated with existing infrastructure (dispatch system, IR patterns) and follows established conventions (similar to Split/Merge operations). The 1D extents limitation is noted with a TODO for future multi-dimensional support.

    Confidence Score: 5/5

    • This PR is safe to merge - introduces well-tested functionality with robust validation and clear patterns following established conventions
    • The implementation demonstrates high quality: (1) Comprehensive validation at multiple layers prevents invalid states (null checks, type checks, dimensionality checks); (2) Clear separation of concerns across layers (IR node, RaggedIterDomain static method, TensorDomain, TensorView wrapper); (3) Excellent test coverage including both happy path and all validation error cases; (4) Consistent with existing codebase patterns (Split/Merge operations, attribute storage for transform factors); (5) Well-documented with clear comments explaining design decisions (e.g., extents as attribute); (6) Logical correctness verified in axis insertion order matching test expectations
    • No files require special attention

    Important Files Changed

csrc/ir/internal_base_nodes.cpp - Implements core partition logic: RaggedIterDomain::partition() with comprehensive validation (null checks, IterType validation, extents type/dimensionality checks) and TensorDomain::partition(), which follows the split pattern for axis transformation. All error handling is explicit and clear.
csrc/ir/internal_nodes.cpp - Implements the Partition constructor with a detailed comment explaining the design decision to store extents as an attribute rather than an input (maintains consistency with IterDomain transform semantics). Implements toString/toInlineString following the pattern of similar ops. CLONE and CREATE macros are properly defined.
csrc/tensor_view.cpp - Implements the TensorView::partition() wrapper following the split pattern with proper validation checks (non-zero dims, compute position, producer position, Serial parallelization). Auto-casts extents to Index type if needed. Correctly delegates to domain()->partition().
tests/cpp/test_ragged_iter_domain.cpp - Adds comprehensive tests for the Partition operation: a basic functionality test, validation tests (null inputs, wrong dtypes, wrong dimensions, wrong IterType, RaggedIterDomain input), and a TensorView::partition integration test. All major error cases and the happy path are covered.

    Sequence Diagram

    sequenceDiagram
        participant User as User Code
        participant TV as TensorView
        participant TD as TensorDomain
        participant RID as RaggedIterDomain
        participant Partition as Partition Expr
        
        User->>TV: partition(axis, extents)
        activate TV
        TV->>TV: Validation (nDims, compute position, producer position, parallel type)
        TV->>TV: Auto-cast extents to Index if needed
        TV->>TD: partition(axis, extents)
        deactivate TV
        
        activate TD
        TD->>RID: partition(id, extents)
        deactivate TD
        
        activate RID
        RID->>RID: Validation (null checks, IterType, RaggedIterDomain check, extents type/dims)
        RID->>RID: Get extents domain
        RID->>RID: Create component IterDomain with extent = num_components
        RID->>RID: Create RaggedIterDomain with extents tensor
        RID->>Partition: Create Partition expr
        Partition-->>RID: Partition created
        RID-->>TD: {component_id, ragged_id}
        deactivate RID
        
        TD->>TD: Remove original axis from loop_domain
        TD->>TD: Insert ragged_id at axis position
        TD->>TD: Insert component_id at axis position
        TD-->>TV: Return
        
        TV-->>User: Return this (modified TensorView)
    

greptile-apps bot left a comment

8 files reviewed, 1 comment

greptile-apps bot left a comment

8 files reviewed, no comments

naoyam (Collaborator, Author) commented Dec 13, 2025

    !test

naoyam requested a review from wujingyue, December 13, 2025 08:05

Review comment on csrc/ir/interface_nodes.h:
    // Returns this TensorView with the axis replaced by component and ragged dims
    // e.g. partition(0, offsets) on tv[id{N}] results in:
    // tv[id{num_components}, ragged_id{extents}]
    TensorView* partition(int64_t axis, TensorView* offsets);
Reviewer (Collaborator):

Suggested change:
- TensorView* partition(int64_t axis, TensorView* offsets);
+ TensorView* partition(int64_t axis, TensorView* extents);

    Do we want to do extents first? tokens_per_expert will be extents anyway.

naoyam (author):

    Sure, that would make things a little simpler.


Review comment on csrc/ir/internal_nodes.h:

//! Extents tensor containing extent for each component
TensorView* extents() const {
  return attributeVal(0)->as<TensorView>();
}
Reviewer (Collaborator):

Should this be an input? Our convention seems to be to treat Val*s as inputs and everything else as attributes. Not sure whether/how it matters in practice.

naoyam (author):

Yeah, I'm less confident about this design, but at the moment an attribute seems more appropriate to me.

Inputs and outputs in the existing IterDomain exprs are always IterDomains. Intuitively, they take some existing iteration spaces and transform them into something else, which can be affine or non-affine.

    In that sense, since the offset tensor itself is not transformed in the Partition expr, it doesn't seem to be considered as an input.

    Note that in Split, the split factor is an attribute, so that would also suggest the offset tensor should be an attribute.

That said, I don't think any of the existing exprs has tensors as attributes, which makes me less confident about the possible implications of this design. It might bite us in cases where some fusion traversal misses the tensor because it is neither an input nor an output.

Reviewer (Collaborator):

    Sure. Note that scatter/gather might eventually need something similar, e.g.,

    values   indices TV
          \ /
        gather
          |
          i
    

The current implementation of gather leaves the logical and loop domains disconnected, creating quite a few special cases for gather/scatter in the code.

naoyam (author):

Right, the values tensor is not connected. I don't have an idea for an expression that would connect them. Not sure there is one, but I'm open to any suggestion!

Reviewer (Collaborator):

I was trying to describe the idea, but I'm sure it landed poorly :)

       ID j  TV indices
          \ /
       [gather]
          |
         ID i
    

    Gather (probably needs a better name) here is an IterDomain operation that connects IterDomain i in loop/logical and IterDomain j in root. TV indices is another input or an attribute of this gather. The math it does is j = indices[i].

    The same Gather IterDomain op can be reused for Scatter. However, i would be in loop and j would be in logical.
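
For concreteness, a hedged sketch of what such a Gather IterDomain op could look like, mirroring Partition's shape. The class name and constructor are purely hypothetical; nothing like this exists in the PR.

// Hypothetical sketch only -- not in this PR or in nvFuser today.
// Connects root IterDomain j to loop/logical IterDomain i via j = indices[i].
class GatherIterDomainOp : public Expr {
 public:
  GatherIterDomainOp(IterDomain* i, IterDomain* j, TensorView* indices) {
    addInput(j);   // root iteration space, indexed as j = indices[i]
    addOutput(i);  // loop/logical iteration space
    // Same open question as Partition: attribute vs. input for the tensor.
    addAttribute(indices);
  }
};
// For Scatter, the same op would be reused with i in loop and j in logical.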

Reviewer (Collaborator):

    I'm not aware of this. Scatter is supported in the greedy scheduler.

    FYI,

    token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id)
    outs_sorted_by_token_id = outs_sorted_by_expert_id[token_ids_sorted_by_expert_inverse_id]

Reviewer (Collaborator):

> Again, I'm skeptical.

    Sure. Will this tensor list exercise give you more confidence either way? Partition is similar -- it connects an IterDomain to a non-containing TensorView.

naoyam (author):

    Yeah, I'll probably hit some problems with Partition too. Will see how it goes.

naoyam (author):

> I'm not aware of this. Scatter is supported in the greedy scheduler.
>
> FYI,
>
> token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id)
> outs_sorted_by_token_id = outs_sorted_by_expert_id[token_ids_sorted_by_expert_inverse_id]

Interesting. Perhaps we should rewrite this with scatter.

Reviewer (Collaborator):

> Perhaps we should rewrite this with scatter.

    Definitely. @jjsjann123, do you remember what issue you ran into? In-place update?

wujingyue (Collaborator) left a comment:

    What would be a good milestone to celebrate? How about getting GroupedMmaOp to work with tensor list abstraction before I take over?

naoyam (Collaborator, Author) commented Dec 18, 2025

    !test

naoyam (Collaborator, Author) commented Dec 18, 2025

> What would be a good milestone to celebrate? How about getting GroupedMmaOp to work with tensor list abstraction before I take over?

    Two major tasks would be combine and multi-dim partition/combine. At that point, it should be possible to express the token shuffling operation for expert parallelism.

naoyam merged commit 195bd5e into main, Dec 18, 2025 (53 of 58 checks passed).
naoyam deleted the raggediterdomain_partition branch, December 18, 2025 21:32.
Review comment on csrc/ir/internal_nodes.cpp:

// Split, the split factor is an attribute. That said, none
// of the existing exprs has tensors as attributes, which makes this
// choice less certain, with possible implications.
addAttribute(extents);
Reviewer (Collaborator):

    Mostly just a mechanical question.

    If TensorView is added as an attribute here, should we also add the partition op into extents.uses_?

naoyam (author):

That's a very good point, thank you! I think you're right, but I'm really not confident about this design. A bit nervous about adding TensorViews as attributes...

Another option might be tracking this dependency as a tensor-level op. In a follow-up PR, I add asNested as a user-facing op, but so far I have not added a new Expr class and instead just reuse LoadStoreOp. Perhaps a new Expr class specifically for this operation should be added to keep track of the use dependency.

    I'll create an issue to remember this.
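
For concreteness, a hypothetical sketch of the uses_ registration discussed above; this is not in the PR, and it assumes Val::addUse is the right hook for this purpose.

// Hypothetical sketch, not in this PR: register the Partition expr as a use
// of the extents tensor so fusion traversals can still find it, even though
// extents is an attribute rather than an input.
addAttribute(extents);
extents->addUse(this);  // assumed hook; would make Partition show up in extents->uses()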
