
Conversation


@tbqh tbqh commented Jan 23, 2026

Optimize the TMA inner-reduction scheduler.

To support large shapes that don't fit into smem, I implemented a serial split of the TMA axis. This is an alternative to grid reduction, which I tried previously but found slower because the partial results have to be synchronized somehow.

TMA is now ~25% slower than non-TMA in the worst case (previously it was 2x slower or worse), and a few percentage points faster for some large shapes:

Test                                                   TMA (ns)     non-TMA        Diff
--------------------------------------------------------------------------------------
TmaInnerReductionTest.Sum/float_4096                      87552        71968 +    21.7%    // TMA is 21.7% slower here
TmaInnerReductionTest.Sum/float_8192                     158816       142976 +    11.1%
TmaInnerReductionTest.Sum/float_16384                    302560       294912 +     2.6%
TmaInnerReductionTest.Sum/float_32768                    584736       587488      -0.5%
TmaInnerReductionTest.Sum/float_65536                   1138528      1170080      -2.7%    // TMA is faster here
TmaInnerReductionTest.Sum/float_131072                  2251840      2296352      -1.9%
TmaInnerReductionTest.Sum/__bfloat_8192                   87552        69888 +    25.3%
TmaInnerReductionTest.Sum/__bfloat_16384                 159744       143168 +    11.6%
TmaInnerReductionTest.Sum/__bfloat_32768                 302240       294848 +     2.5%
TmaInnerReductionTest.Sum/__bfloat_65536                 584064       589760      -1.0%
TmaInnerReductionTest.Sum/__bfloat_131072               1205568      1164160 +     3.6%
TmaInnerReductionTest.Sum/__bfloat_262144               2376288      2300416 +     3.3%

Other notes:

  • Completely removed vectorization and moved its factor into unroll. Vectorization was always slower when loading fp32 from shared memory, probably due to bank conflicts. Now that I am writing this, I realize fp16 should probably keep vectorization so that each thread owns one bank.

@tbqh tbqh requested a review from liqiangxl January 23, 2026 11:24

greptile-apps bot commented Jan 23, 2026

Greptile Overview

Greptile Summary

This PR optimizes the TMA inner-reduction scheduler by implementing a serial split for large reduction dimensions that don't fit in shared memory, narrowing TMA's worst-case gap versus non-TMA from over 2x slower to roughly 25%.

  • Replaced the smem-fit eligibility check with a minimum 16KB transfer size threshold
  • Added a getTmaSplit() algorithm that finds split factors via a divisor-pair search up to sqrt(numel)
  • Removed vectorization from the TMA path and consolidated its factor into unroll to avoid shared memory bank conflicts
  • Added heuristic-based thread block sizing (256 or 512 threads, depending on the presence of MUFU computation)
  • Updated test expectations to handle the new eligibility criteria and non-power-of-2 reduction shapes

Confidence Score: 4/5

  • Safe to merge with careful testing of edge cases in split factor algorithm
  • The implementation is well-structured with clear algorithmic improvements. The getTmaSplit() function uses a mathematically sound divisor-pair search strategy. Thread block heuristics are derived from benchmarking. However, the complexity of the split factor algorithm and interaction between alignment constraints, bounds, and divisor finding warrants thorough testing, especially for edge cases with unusual reduction sizes.
  • Pay attention to csrc/scheduler/reduction_tma.cpp - verify the getTmaSplit() algorithm handles all edge cases correctly, particularly unusual reduction sizes and alignment requirements

Important Files Changed

Filename Overview
csrc/scheduler/reduction.cpp Simplified TMA eligibility check by replacing smem fit constraint with minimum transfer size threshold
csrc/scheduler/reduction_tma.cpp Added serial split algorithm for large reductions, removed vectorization in favor of unroll, optimized thread count heuristics
csrc/scheduler/reduction_tma.h Renamed vectorization_factor to tma_split_factor to reflect new serial split functionality
tests/cpp/test_reduction.cpp Updated test expectations to match new TMA eligibility criteria and handle non-power-2 shapes

Sequence Diagram

sequenceDiagram
    participant User
    participant mayUseTma
    participant getReductionHeuristics
    participant getTmaSplit
    participant scheduleReduction

    User->>mayUseTma: Check if TMA can be used
    mayUseTma->>mayUseTma: Check total_reduction_bytes >= 16384
    mayUseTma->>mayUseTma: Verify 16-byte alignment
    mayUseTma-->>User: Return true/false

    alt TMA enabled
        User->>getReductionHeuristics: Get TMA heuristics
        getReductionHeuristics->>getTmaSplit: Find split factor
        getTmaSplit->>getTmaSplit: Calculate bounds<br/>(lower_elem_bound, upper_elem_bound)
        getTmaSplit->>getTmaSplit: Search divisors up to sqrt(numel)
        getTmaSplit->>getTmaSplit: Check alignment & bounds
        alt Split factor found
            getTmaSplit-->>getReductionHeuristics: Return split factor
            getReductionHeuristics->>getReductionHeuristics: Set threads_per_block<br/>(256 or 512 based on mufu)
            getReductionHeuristics->>getReductionHeuristics: Calculate unroll_factor
            getReductionHeuristics-->>User: Return TmaInnerReductionParams
            User->>scheduleReduction: Schedule with params
            scheduleReduction->>scheduleReduction: Cache inputs with TMA
            scheduleReduction->>scheduleReduction: Apply tma_split_factor split
            scheduleReduction->>scheduleReduction: Apply TIDx, Unroll, Unswitch splits
            scheduleReduction->>scheduleReduction: Propagate to TMA & non-TMA TVs
            scheduleReduction-->>User: Scheduled fusion
        else No split factor found
            getTmaSplit-->>getReductionHeuristics: Return 0
            getReductionHeuristics-->>User: Return nullptr (fallback to non-TMA)
        end
    end

@greptile-apps greptile-apps bot left a comment

4 files reviewed, 1 comment


github-actions bot commented Jan 23, 2026

Review updated until commit 6f28318

Description

  • Implement TMA serial split to support large reduction shapes that exceed shared memory

  • Replace smem size checks with total reduction bytes threshold (16KB minimum)

  • Optimize thread block configuration from 128 to 256 threads per block

  • Simplify non-TMA scheduling pattern by removing vectorization splits

  • Add warp padding for improved memory access patterns

Changes walkthrough

Relevant files

Enhancement

csrc/scheduler/reduction.cpp: Update TMA eligibility criteria based on bytes (+4/-10)

  • Update mayUseTma() to use total reduction bytes instead of element count
  • Remove small reduction size restriction (< 128 elements)
  • Remove shared memory size restriction for reduction dimension
  • Add 16KB minimum total reduction bytes requirement

csrc/scheduler/reduction_tma.cpp: Implement TMA serial split and optimize scheduling (+30/-17)

  • Add TMA split factor calculation based on half shared memory capacity
  • Increase threads per block from 128 to 256
  • Implement serial split for large reduction dimensions
  • Add cacheAndForkOutputs call for memory optimization
  • Simplify non-TMA scheduling by removing vectorization splits
  • Add warp padding for memory access optimization

csrc/scheduler/reduction_tma.h: Rename vectorization factor parameter (+2/-2)

  • Rename vectorization_factor to tma_split_factor
  • Update parameter documentation to reflect new purpose

Tests

tests/cpp/test_reduction.cpp: Update TMA test expectations for new criteria (+4/-10)

  • Update expectTmaUsed() to match new mayUseTma() logic
  • Remove small reduction size check (< 128 elements)
  • Remove shared memory size check for reduction dimension
  • Add 16KB total reduction bytes minimum requirement

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Serial Split Logic

The serial split implementation adds a new outer split dimension when tma_split_factor > 1, but there is no explicit synchronization between the serial iterations. Need to verify that the reduction correctly accumulates results across the serial splits without data races or missing reductions.

    if (rparams->tma_split_factor > 1) {
      reduction_tv->split(1, rparams->tma_split_factor, false);
      reduction_tv->axis(1)->parallelize(ParallelType::Serial);

      for (auto tma_tv : tma_tvs) {
        tma_tv->split(1, rparams->tma_split_factor, false);
      }
    }

Performance Regression Risk

The change from 128 to 256 threads_per_block and the removal of vectorization_factor could cause performance regressions on some hardware configurations or problem sizes. The PR shows good results for the tested cases, but broader testing across different GPU architectures and problem sizes would be valuable.

    int64_t threads_per_block = 256;
    int64_t unroll_factor = scheduler_utils::lastPow2(target_vect_unroll);

    NVF_ERROR(has_iter_axis && has_red_axis);

    if (rparams->tma_split_factor > 1) {
      reduction_tv->split(1, rparams->tma_split_factor, false);
Collaborator
Do we need to consider a divisible split for better performance when setting this tma_split_factor?

Collaborator Author

Yes, this is actually needed for correctness, not just performance. Lowering will fail if the reduction size is not divisible by the split. I believe this is a restriction on 1D TMA.

I added a new function getTmaSplit() to search for a valid split size. This fixes failures with large non-divisible sizes.

Note that the TMA scheduler will now return nullptr during heuristic checking if the splitting fails. Since the splitting logic is complicated, I don't think we should duplicate it in mayUseTma().


    for (auto tma_tv : tma_tvs) {
      tma_tv->split(1, rparams->tma_split_factor, false);
    }
Collaborator

There is no need to manually split tma_tv; TransformPropagator should be able to propagate the transforms.

Collaborator Author

It fails to propagate for bfloat16. Claude thinks this is due to casts blocking something in the maximum spanning tree.

@liqiangxl
Collaborator

TMA is now ~25% slower than non-TMA in the worst case (previously was 2x slower or worse), and small percentage points faster for a few large shapes: […]

Thanks for the measurement. Can you also check the performance of https://github.com/NVIDIA/Fuser/blob/main/benchmarks/python/test_reduction.py? It includes both inner and outer reduction, and we only need to check inner reduction with a batch size of 16384.

@liqiangxl liqiangxl closed this Jan 23, 2026
@liqiangxl liqiangxl reopened this Jan 23, 2026
@greptile-apps greptile-apps bot left a comment

1 file reviewed, 1 comment

@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

@tbqh tbqh closed this Jan 28, 2026

@tbqh tbqh commented Jan 28, 2026

Moved over to #5887 to preserve the serial split code on this PR.
