feat: RMSNorm backward pass kernels #29

Merged
hannahli-nv merged 13 commits into NVIDIA:main from aghilann:rms-norm-backward2
Jan 6, 2026

Conversation

@aghilann
Contributor

@aghilann aghilann commented Jan 4, 2026

Description

Adds the RMSNorm backward pass to TileGym, the first backward kernel implementation.

Implementation:

  • Single CuTile kernel (rms_norm_backward_kernel_dx) computes dx row-parallel and stores intermediate values (dy * x * rstd) into a float32 temp_buffer
  • dw is computed via temp_buffer.sum(dim=0) using PyTorch's optimized reduction (avoids a second kernel with different access pattern)
  • All accumulations are done in FP32 regardless of input dtype for numerical stability
  • Added PyTorch reference impl (rms_norm_backward_torch) for testing/benchmarking
  • Added test_rmsnorm_backward.py and bench_rmsnorm_backward.py
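The math behind the split above can be sketched in plain PyTorch. This is an illustrative reconstruction (the function name, signature, and `eps` default are assumptions, not the actual TileGym code): `dx` follows from differentiating `y = x * rstd * w` through `rstd`, and `dw` is exactly the column-sum of the `dy * x * rstd` values the kernel stages in `temp_buffer`.

```python
import torch

def rms_norm_backward_ref(dy, x, w, eps=1e-6):
    # Hypothetical reference implementation; accumulates in FP32
    # regardless of input dtype, mirroring the PR's design.
    x32, dy32, w32 = x.float(), dy.float(), w.float()
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    dyw = dy32 * w32
    # dx_j = rstd * (dy_j*w_j - x_j * rstd^2 * mean_i(dy_i*w_i*x_i)),
    # the extra term coming from rstd's dependence on x.
    dx = rstd * (dyw - x32 * rstd.pow(2) * (dyw * x32).mean(dim=-1, keepdim=True))
    # dw reduces across rows: this product is what the kernel writes
    # to the FP32 temp_buffer before temp_buffer.sum(dim=0).
    dw = (dy32 * x32 * rstd).sum(dim=0)
    return dx.to(x.dtype), dw.to(w.dtype)
```

Staging `dy * x * rstd` row-parallel and letting PyTorch do the column reduction avoids writing a second kernel with a transposed access pattern.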

Performance

2-5x faster than PyTorch across all dtypes.
Both the Torch reference and the CuTile kernel still accumulate in FP32.

bfloat16

| N | CuTile (GB/s) | PyTorch (GB/s) |
|---|---|---|
| 1024 | 1537.03 | 472.44 |
| 2048 | 2280.10 | 533.60 |
| 4096 | 2813.37 | 515.98 |
| 8192 | 3640.76 | 526.04 |
| 16384 | 3916.75 | 542.96 |

float16

| N | CuTile (GB/s) | PyTorch (GB/s) |
|---|---|---|
| 1024 | 1563.98 | 484.67 |
| 2048 | 2361.91 | 535.59 |
| 4096 | 2927.92 | 520.76 |
| 8192 | 3566.86 | 532.16 |
| 16384 | 4023.40 | 549.59 |

float32

| N | CuTile (GB/s) | PyTorch (GB/s) |
|---|---|---|
| 1024 | 2255.25 | 853.33 |
| 2048 | 2922.93 | 931.21 |
| 4096 | 3892.06 | 901.09 |
| 8192 | 3645.70 | 922.79 |
| 16384 | 2638.26 | 955.57 |

CI Configuration

config:
  build: true
  test: ["ops", "benchmark"]

Checklist

  • Code formatted (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed

@copy-pr-bot

copy-pr-bot Bot commented Jan 4, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@aghilann aghilann changed the title from "feat: RMSNorm backward pass" to "feat: RMSNorm backward pass kernels" Jan 4, 2026
@vgoklani

vgoklani commented Jan 4, 2026

why is the bfloat16 implementation so much slower?

@aghilann aghilann marked this pull request as draft January 4, 2026 17:01
@aghilann
Contributor Author

aghilann commented Jan 4, 2026

why is the bfloat16 implementation so much slower?

I changed the design and removed the second kernel altogether, since it was a huge bottleneck and there was an easier way. I honestly have no clue why bf16 was so much slower than fp16, since I deleted the kernel before investigating it.

@vgoklani

vgoklani commented Jan 4, 2026

We use all bfloat16 so this would be a nonstarter.

@aghilann aghilann marked this pull request as ready for review January 4, 2026 22:17
@aghilann
Contributor Author

aghilann commented Jan 4, 2026

@vgoklani Yes, I'm aware; that's why I changed the algorithm (it was slow regardless of BF16 vs FP16, though I'm not sure why BF16 was so much worse). Anyway, with my new implementation, BF16 and FP16 performance are now pretty similar.

@aghilann
Contributor Author

aghilann commented Jan 6, 2026

@hannahli-nv could I get a review + CI on this PR?

@hannahli-nv
Collaborator

/ok to test 24e65b8

@hannahli-nv
Collaborator

/ok to test 077e8fd

@aghilann
Contributor Author

aghilann commented Jan 6, 2026

Oops, I don't know how that # made it in there. Thanks for fixing it.

Collaborator

@hannahli-nv hannahli-nv left a comment

Overall LGTM, thanks for the contribution!

@hannahli-nv hannahli-nv merged commit d901cf4 into NVIDIA:main Jan 6, 2026
10 checks passed