Partition operation for RaggedIterDomain #5674
Conversation
Review updated until commit 8a73bb2

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- 🔒 No security concerns identified
- ⚡ Recommended focus areas for review: Partition Class Design
Test failures (partial, pipeline still running)
- (Medium, 1) Large numerical mismatch in NVFuser PingPongCircularBuffering test (PingPongCircularBuffering.StageSlicePositionComputeAt)

| Test Name | H100 | Source |
|---|---|---|
| PingPongCircularBuffering.StageSlicePositionComputeAt/stage_slice_position_4 | ❌ | Link |
Greptile Summary

This PR adds the `Partition` iter domain op, which constructs a `RaggedIterDomain` from an `IterDomain`.

The implementation is well-integrated with existing infrastructure (dispatch system, IR patterns) and follows established conventions (similar to Split/Merge operations). The 1D extents limitation is noted with a TODO for future multi-dimensional support.

Confidence Score: 5/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
participant User as User Code
participant TV as TensorView
participant TD as TensorDomain
participant RID as RaggedIterDomain
participant Partition as Partition Expr
User->>TV: partition(axis, extents)
activate TV
TV->>TV: Validation (nDims, compute position, producer position, parallel type)
TV->>TV: Auto-cast extents to Index if needed
TV->>TD: partition(axis, extents)
deactivate TV
activate TD
TD->>RID: partition(id, extents)
deactivate TD
activate RID
RID->>RID: Validation (null checks, IterType, RaggedIterDomain check, extents type/dims)
RID->>RID: Get extents domain
RID->>RID: Create component IterDomain with extent = num_components
RID->>RID: Create RaggedIterDomain with extents tensor
RID->>Partition: Create Partition expr
Partition-->>RID: Partition created
RID-->>TD: {component_id, ragged_id}
deactivate RID
TD->>TD: Remove original axis from loop_domain
TD->>TD: Insert ragged_id at axis position
TD->>TD: Insert component_id at axis position
TD-->>TV: Return
TV-->>User: Return this (modified TensorView)
```
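To make the flow above concrete, here is a minimal usage sketch in the style of the NVFuser C++ tests. It assumes the test helper `makeSymbolicTensor` and the usual `Fusion`/`FusionGuard` boilerplate; the shapes and names are illustrative, not taken from the PR's tests.

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

// tv: [i{N}]; extents: a 1D tensor holding the extent of each component.
TensorView* tv = makeSymbolicTensor(1);
TensorView* extents = makeSymbolicTensor(1, DataType::Index);
fusion.addInput(tv);
fusion.addInput(extents);

// Replaces axis 0 with a component IterDomain and a RaggedIterDomain:
//   tv: [i{N}]  ->  tv: [i{num_components}, ragged_i{extents}]
tv->partition(0, extents);
```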
8 files reviewed, 1 comment
8 files reviewed, no comments
!test
csrc/ir/interface_nodes.h (Outdated)

```cpp
// Returns this TensorView with the axis replaced by component and ragged dims
// e.g. partition(0, offsets) on tv[id{N}] results in:
// tv[id{num_components}, ragged_id{extents}]
TensorView* partition(int64_t axis, TensorView* offsets);
```
Suggested change:

```diff
-TensorView* partition(int64_t axis, TensorView* offsets);
+TensorView* partition(int64_t axis, TensorView* extents);
```
Do we want to do extents first? tokens_per_expert will be extents anyway.
Sure, that would make things a little simpler.
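For context on why extents is the natural parameter, a hedged illustration of the MoE case mentioned above (`tokens_tv` is a hypothetical tensor; only `tokens_per_expert` comes from the discussion):

```cpp
// tokens_per_expert already holds per-component extents, e.g. {3, 1, 2},
// whereas offsets would be the cumulative form {0, 3, 4, 6}. Taking
// extents directly avoids computing a prefix sum first.
TensorView* tokens_per_expert = makeSymbolicTensor(1, DataType::Index);
tokens_tv->partition(/*axis=*/0, /*extents=*/tokens_per_expert);
```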
```cpp
//! Extents tensor containing extent for each component
TensorView* extents() const {
  return attributeVal(0)->as<TensorView>();
}
```
Should this be an input? Our convention seems to be to treat `Val*`s as inputs and others as attributes. Not sure whether/how it matters in practice.
Yeah, I'm less confident with this design, but at this moment an attribute seems more appropriate to me.

Inputs and outputs in the existing IterDomain exprs are always IterDomains. Intuitively, they take some existing iteration spaces and transform them into something else, which can be affine or non-affine.

In that sense, since the offset tensor itself is not transformed by the Partition expr, it doesn't seem right to consider it an input.

Note that in Split, the split factor is an attribute, so that would also suggest the offset tensor should be an attribute.

That said, I don't think any of the existing exprs have tensors as attributes, which makes me less confident about the possible implications of this design. It might bite us in some cases where a fusion traversal misses the tensor because it is neither an input nor an output.
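For reference, a rough sketch of the Split analogy being drawn here. The constructor signatures are simplified assumptions (the real ones take an `IrBuilderPasskey` and further parameters); the point is only which values are inputs versus attributes:

```cpp
// Split: the transformed IterDomains are inputs/outputs; the split
// factor, a scalar Val, is stored as an attribute.
Split::Split(IterDomain* outer, IterDomain* inner, IterDomain* in, Val* factor) {
  addOutput(outer);
  addOutput(inner);
  addInput(in);
  addAttribute(factor);
}

// Partition, by analogy: the extents TensorView is not itself an
// iteration space being transformed, so it too becomes an attribute.
Partition::Partition(IterDomain* component, RaggedIterDomain* ragged,
                     IterDomain* in, TensorView* extents) {
  addOutput(component);
  addOutput(ragged);
  addInput(in);
  addAttribute(extents);
}
```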
Sure. Note that scatter/gather might eventually need something similar, e.g.,

```
values  indices TV
     \  /
    gather
      |
      i
```

The current implementation of gather leaves the logical and loop domains disconnected, creating quite a few special cases for gather/scatter in the code.
Right, the values tensor is not connected. I don't have any idea for an expression that would connect them. Not sure if there's anything, but open to any suggestions!
I was trying to describe the idea but I'm sure it landed poorly :)

```
ID j   TV indices
    \  /
  [gather]
      |
    ID i
```

Gather (probably needs a better name) here is an IterDomain operation that connects IterDomain i in loop/logical and IterDomain j in root. TV indices is another input or an attribute of this gather. The math it does is `j = indices[i]`.

The same Gather IterDomain op can be reused for Scatter; however, i would be in loop and j would be in logical.
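A rough sketch of what such an op could look like, mirroring how Partition stores its tensor. This is entirely hypothetical; the class name `GatherId` and its exact base-class plumbing are assumptions, not part of the PR:

```cpp
// Hypothetical IterDomain op connecting loop/logical ID i to root ID j
// via j = indices[i]. The indices TensorView is kept as an attribute,
// just like Partition's extents tensor.
class GatherId : public Expr {
 public:
  GatherId(IrBuilderPasskey passkey, IterDomain* j, IterDomain* i,
           TensorView* indices)
      : Expr(passkey) {
    addOutput(j);
    addInput(i);
    addAttribute(indices);
  }

  // For Scatter, the same op could be reused with i in loop and j in
  // logical, as described above.
  TensorView* indices() const {
    return attributeVal(0)->as<TensorView>();
  }
};
```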
I'm not aware of this. Scatter is supported in the greedy scheduler.
FYI,
Fuser/benchmarks/python/layers_for_inference_benchmark.py, lines 634 to 635 in 9c2023d:

```python
token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id)
outs_sorted_by_token_id = outs_sorted_by_expert_id[token_ids_sorted_by_expert_inverse_id]
```
> Again, I'm skeptical.
Sure. Will this tensor list exercise give you more confidence either way? Partition is similar -- it connects an IterDomain to a non-containing TensorView.
Yeah, I'll probably hit some problems with Partition too. Will see how it goes.
> I'm not aware of this. Scatter is supported in the greedy scheduler.
> FYI,
> Fuser/benchmarks/python/layers_for_inference_benchmark.py, lines 634 to 635 in 9c2023d:
>
> ```python
> token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id)
> outs_sorted_by_token_id = outs_sorted_by_expert_id[token_ids_sorted_by_expert_inverse_id]
> ```

Interesting. Perhaps we should rewrite this with scatter.
> Perhaps, we should rewrite this with scatter.
Definitely. @jjsjann123, do you remember what issue you ran into? In-place update?
wujingyue left a comment
What would be a good milestone to celebrate? How about getting GroupedMmaOp to work with tensor list abstraction before I take over?
!test

Two major tasks would be
```cpp
// Split, the split factor is an attribute. However, that said, none
// of the existing exprs has tensors as attributes, which makes this
// choice less certain with possible implications.
addAttribute(extents);
```
Mostly just a mechanical question.

If a TensorView is added as an attribute here, should we also add the partition op into `extents.uses_`?
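To make the question concrete, a hedged sketch of the registration being asked about. Whether `Val::addUse` is the right hook (or exists in the right form) is itself part of the question; treat the call as an assumption:

```cpp
// If extents lives only in attributes(), traversals that follow
// definition()/uses() chains will not see Partition as a consumer.
// Registering the use explicitly would look roughly like:
addAttribute(extents);
// extents->addUse(this);  // Partition would then appear in extents.uses_
```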
That's a very good point, thank you! I think you're right, but I'm really not confident with this design. A bit nervous about adding TensorViews as attributes...

Another option might be tracking this dependency as a tensor-level op. In a follow-up PR, I add asNested as a user-facing op, but so far I reuse LoadStoreOp rather than adding a new Expr class. Perhaps a new Expr class specifically for this operation should be added to keep track of the use dependency.

I'll create an issue to remember this.
Adds the `Partition` iter domain op to represent the construction of a ragged iter domain from an iter domain. Similar to `IterDomain::split`, `RaggedIterDomain::partition` is added by using the `Partition` op.

The offsets/extents parameter is still limited to 1D. Will be extended in a follow-up PR.