feat: RMSNorm backward pass kernels #29
Conversation
why is the bf16 version so much slower than fp16?
I changed the design and removed the second kernel altogether, since it was a huge bottleneck and there was an easier way. I honestly have no clue why bf16 was so much slower than fp16, since I deleted the kernel before investigating it.
We use bfloat16 for everything, so this would be a nonstarter.
@vgoklani Yes, I'm aware; that's why I changed the algorithm (it was slow regardless of BF16 vs FP16, though I'm not sure why BF16 was so much worse). Anyway, with my new implementation, BF16 and FP16 performance is now pretty similar.
@hannahli-nv could I get a review + CI on this PR?
/ok to test 24e65b8
/ok to test 077e8fd
Oops, don't know how that # made it in there. Thanks for fixing it.
hannahli-nv left a comment
Overall LGTM, thanks for the contribution!
Description
Adds RMSNorm backward pass to TileGym - the first backward kernel implementation.
Implementation:
- The kernel (`rms_norm_backward_kernel_dx`) computes dx row-parallel and stores the intermediate values (dy * x * rstd) into a float32 `temp_buffer`
- dw is then computed as `temp_buffer.sum(dim=0)` using PyTorch's optimized reduction (avoids a second kernel with a different access pattern)
- A Torch reference implementation (`rms_norm_backward_torch`) is included for testing/benchmarking (a pure-PyTorch sketch of the same formulas follows this list)
- Tests and benchmarks live in `test_rmsnorm_backward.py` and `bench_rmsnorm_backward.py`
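For readers who want the formulas behind that dx/dw split, here is a minimal pure-PyTorch sketch of an RMSNorm backward that accumulates in float32. The function name, argument order, and eps default are assumptions for illustration; this is not the actual `rms_norm_backward_torch` from the PR, just the standard gradients that the dx kernel and the `temp_buffer` reduction realize.

```python
import torch

def rms_norm_backward_sketch(dy, x, weight, eps=1e-6):
    """Illustrative RMSNorm backward (not the PR's rms_norm_backward_torch).

    Assumes x and dy of shape (M, N) and weight of shape (N,).
    """
    # Upcast so every reduction accumulates in float32, regardless of input dtype.
    x32, dy32, w32 = x.float(), dy.float(), weight.float()

    # Forward quantities: rstd = 1 / sqrt(mean(x^2) + eps), x_hat = x * rstd.
    rstd = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    x_hat = x32 * rstd

    # Weight gradient: the per-element contributions (dy * x * rstd) summed over
    # the row dimension -- the same quantity the kernel stages in temp_buffer.
    dw = (dy32 * x_hat).sum(dim=0)

    # Input gradient: rstd * (dy*w - x_hat * mean(dy*w*x_hat)), computed per row.
    g = dy32 * w32
    dx = rstd * (g - x_hat * (g * x_hat).mean(dim=-1, keepdim=True))

    return dx.to(x.dtype), dw.to(weight.dtype)

# Quick self-check against autograd.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 128, requires_grad=True)
    w = torch.randn(128, requires_grad=True)
    eps = 1e-6
    y = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w
    dy = torch.randn_like(y)
    y.backward(dy)
    dx, dw = rms_norm_backward_sketch(dy, x.detach(), w.detach(), eps)
    print(torch.allclose(dx, x.grad, atol=1e-5), torch.allclose(dw, w.grad, atol=1e-5))
```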
Performance
2-5x faster than PyTorch across all dtypes.
Both the Torch reference and the cuTILE kernel still accumulate in FP32.
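As a quick, hypothetical illustration (not code from the PR) of why the accumulation dtype matters: adding many small values directly into a bfloat16 accumulator stalls once the running sum is large relative to bf16's short mantissa, while a float32 accumulator keeps the low-order contributions.

```python
import torch

# Illustration only: summing small increments in bf16 vs. accumulating in fp32.
torch.manual_seed(0)
vals = torch.rand(4096) * 1e-3          # small positive float32 increments

acc_bf16 = torch.zeros((), dtype=torch.bfloat16)
for v in vals.to(torch.bfloat16):
    acc_bf16 = acc_bf16 + v             # each add rounds to bf16 precision

acc_fp32 = vals.sum()                   # float32 accumulation, cast once at the end if needed
print(f"bf16 running sum: {acc_bf16.item():.4f}, fp32 sum: {acc_fp32.item():.4f}")
```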
[Benchmark plots: bfloat16, float16, float32]
CI Configuration
Checklist
Code is formatted (./format.sh)