Optimize TMA inner-reduction and add TMA serial-split #5867
Conversation
Greptile Overview

Greptile Summary

This PR optimizes the TMA inner-reduction scheduler by implementing serial splits for large reduction dimensions that don't fit in shared memory, achieving performance improvements of up to ~25% reduction in overhead compared to the previous implementation.
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant mayUseTma
    participant getReductionHeuristics
    participant getTmaSplit
    participant scheduleReduction
    User->>mayUseTma: Check if TMA can be used
    mayUseTma->>mayUseTma: Check total_reduction_bytes >= 16384
    mayUseTma->>mayUseTma: Verify 16-byte alignment
    mayUseTma-->>User: Return true/false
    alt TMA enabled
        User->>getReductionHeuristics: Get TMA heuristics
        getReductionHeuristics->>getTmaSplit: Find split factor
        getTmaSplit->>getTmaSplit: Calculate bounds<br/>(lower_elem_bound, upper_elem_bound)
        getTmaSplit->>getTmaSplit: Search divisors up to sqrt(numel)
        getTmaSplit->>getTmaSplit: Check alignment & bounds
        alt Split factor found
            getTmaSplit-->>getReductionHeuristics: Return split factor
            getReductionHeuristics->>getReductionHeuristics: Set threads_per_block<br/>(256 or 512 based on mufu)
            getReductionHeuristics->>getReductionHeuristics: Calculate unroll_factor
            getReductionHeuristics-->>User: Return TmaInnerReductionParams
            User->>scheduleReduction: Schedule with params
            scheduleReduction->>scheduleReduction: Cache inputs with TMA
            scheduleReduction->>scheduleReduction: Apply tma_split_factor split
            scheduleReduction->>scheduleReduction: Apply TIDx, Unroll, Unswitch splits
            scheduleReduction->>scheduleReduction: Propagate to TMA & non-TMA TVs
            scheduleReduction-->>User: Scheduled fusion
        else No split factor found
            getTmaSplit-->>getReductionHeuristics: Return 0
            getReductionHeuristics-->>User: Return nullptr (fallback to non-TMA)
        end
    end
```
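A minimal sketch of the eligibility check summarized in the diagram above. The function name, parameters, and exact thresholds here are illustrative assumptions, not the actual `mayUseTma()` implementation:

```cpp
#include <cstdint>

// Hypothetical sketch of the TMA eligibility check described in the diagram.
bool mayUseTmaSketch(
    int64_t reduction_numel, // elements reduced per row
    int64_t dtype_size) {    // bytes per element
  // Skip TMA for small reductions where its setup cost isn't amortized.
  constexpr int64_t kMinReductionBytes = 16384;
  const int64_t total_reduction_bytes = reduction_numel * dtype_size;
  if (total_reduction_bytes < kMinReductionBytes) {
    return false;
  }
  // The contiguous reduction extent is assumed to need to be a multiple of
  // 16 bytes to satisfy TMA alignment requirements.
  if (total_reduction_bytes % 16 != 0) {
    return false;
  }
  return true;
}
```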
4 files reviewed, 1 comment
Review updated until commit 6f28318

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide
Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Serial Split Logic
```cpp
NVF_ERROR(has_iter_axis && has_red_axis);

if (rparams->tma_split_factor > 1) {
  // Serial split of the reduction axis by tma_split_factor.
  reduction_tv->split(1, rparams->tma_split_factor, false);
```
Do we need to consider a divisible split for better performance when setting this tma_split_factor?
Yes, this is actually needed for correctness, not just performance. Lowering will fail if the reduction size is not divisible by the split. I believe this is a restriction on 1D TMA.
I added a new function getTmaSplit() to search for a valid split size. This fixes failures with large non-divisible sizes.
Note that the TMA scheduler will now return nullptr during heuristic checking if the splitting fails. The splitting logic is complicated, and I don't think we should duplicate it in mayUseTma().
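For illustration, here is a standalone sketch of the kind of divisor search described in this thread and in the sequence diagram. The function name, bound parameters, and exact alignment rule are assumptions, not the real `getTmaSplit()`:

```cpp
#include <cstdint>

// Illustrative divisor search: find a split factor such that the reduction
// splits evenly and each chunk is 16-byte aligned and within element bounds.
// Returns 0 when no valid factor exists (caller falls back to non-TMA).
int64_t getTmaSplitSketch(
    int64_t reduction_numel,
    int64_t dtype_size,
    int64_t lower_elem_bound,   // smallest chunk worth a TMA load (assumed)
    int64_t upper_elem_bound) { // largest chunk that fits in smem (assumed)
  for (int64_t factor = 2; factor * factor <= reduction_numel; ++factor) {
    if (reduction_numel % factor != 0) {
      continue; // only divisible splits are legal for 1D TMA
    }
    const int64_t chunk = reduction_numel / factor;
    const bool aligned = (chunk * dtype_size) % 16 == 0;
    if (aligned && chunk >= lower_elem_bound && chunk <= upper_elem_bound) {
      return factor;
    }
  }
  return 0;
}
```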
```cpp
// Apply the same serial split to each TMA-cached input tensor.
for (auto tma_tv : tma_tvs) {
  tma_tv->split(1, rparams->tma_split_factor, false);
}
```
We shouldn't need to manually split tma_tv; TransformPropagator should be able to propagate the transforms.
It fails to propagate for bfloat16. Claude thinks this is due to the casts blocking something in the maximum spanning tree.
Thanks for the measurement. Can you also check the performance of …
1 file reviewed, 1 comment
No files reviewed, no comments
Moved over to #5887 to preserve the serial-split code on this PR.
Optimize the TMA inner-reduction scheduler.
To support large shapes that don't fit into smem, this implements a serial split of the TMA axis. This is an alternative to the grid reduction approach I tried previously, which is slower because the partial results have to be synchronized across blocks.
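As a rough illustration of the idea (plain C++, not the generated kernel or the actual scheduler code; the function and variable names are made up), a serial split processes the reduction in tma_split_factor chunks, so only one chunk needs to be staged in shared memory at a time:

```cpp
#include <cstdint>
#include <vector>

// Conceptual model of a serial split of a reduction of size `numel`:
// the outer loop over chunks is serial, and each chunk corresponds to one
// TMA load into shared memory followed by an in-block reduction.
float serialSplitReduce(const std::vector<float>& x, int64_t tma_split_factor) {
  const int64_t numel = static_cast<int64_t>(x.size());
  const int64_t chunk = numel / tma_split_factor; // assumes a divisible split
  float acc = 0.0f;
  for (int64_t s = 0; s < tma_split_factor; ++s) { // serial outer loop
    for (int64_t i = 0; i < chunk; ++i) {          // per-chunk reduction
      acc += x[s * chunk + i];
    }
  }
  return acc;
}
```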
TMA is now at most ~25% slower than non-TMA in the worst case (previously it was 2x slower or worse), and a few percentage points faster for some large shapes:
Other notes: