[FEA] Add SwiGLU backward pass implementation, test cases and benchmark #46
Conversation
Force-pushed: 6b335d1 to 4a976eb
Hi @Weili-0234, thank you for your contribution.

/ok to test 4a976eb
Hi @hannahli-nv, thank you for running the CI! I noticed there is 1 failure out of 284 tests: the (16, 512, 2048) float16 case. From the raw error output, the max difference (0.0156) is just slightly above the float16 default atol (0.01). This appears to be inherent FP16 precision variance affecting 1 element out of 16M. Since this failing case looks like a tolerance artifact rather than a correctness issue, we could either loosen the atol for that configuration or remove the case.

Let me know what you prefer. I'm happy to make the change myself, or feel free to push a fix directly to the PR branch if that's easier. Also, I've submitted the CLA document as requested. Thank you!
Hi @Weili-0234, we have received your signed CLA file.
```python
@pytest.mark.parametrize("backend", _backends)
def test_op(self, batch_size, seq_len, hidden_size, intermediate_size, backend, arch):
    """Test for functional correctness of SwiGLU implementation"""
def test_forward(self, batch_size, seq_len, hidden_size, intermediate_size, backend, arch):
```
Please rename the functions in test/ops/test_swiglu.py so that they start with "test_op", because only tests whose names contain the substring "test_op" will be run in the TileGym CI.
For example, you can rename "test_forward" to "test_op_forward".
Please also update the other function names in this file.
Thx.
```python
def get_providers():
    providers = [("torch", "PyTorch", ("green", "-"))]
    if is_backend_available("cutile"):
        providers.insert(0, ("tilegym", "TileGym", ("orange", "-")))
```
Please don't name the backend as "TileGym", just set it as "CuTile" like the other benchmark files do.
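A sketch of the suggested rename (only the display label changes to "CuTile"; whether the provider key also changes is an assumption not settled here):

```python
if is_backend_available("cutile"):
    providers.insert(0, ("tilegym", "CuTile", ("orange", "-")))
```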
```python
line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"fused-swiglu-mlp-{mode_name}-{model_name}-bs{batch_size}",
```
Please add "-GBps" at the end of plot_name so that it can be recognized by the CI summary page. You can refer to bench_silu_and_mul.py as an example.
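For example, the suffixed plot name would look like this (sketch):

```python
plot_name=f"fused-swiglu-mlp-{mode_name}-{model_name}-bs{batch_size}-GBps",
```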
```python
line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"silu-and-mul-{mode_name}-hidden{hidden_size}-{dtype_name}",
```
Same as commented in tests/benchmark/bench_fused_swiglu_mlp.py.
```python
line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"swiglu-{mode_name}-hidden{hidden_size}-{dtype_name}",
```
Same as commented in tests/benchmark/bench_fused_swiglu_mlp.py.
- Remove failing test case (16,512,2048,float16) due to FP16 precision on sm120
- Rename TileGym backend to CuTile in bench_fused_swiglu_mlp.py
- Add -GBps suffix to plot_name in all benchmark files for CI recognition
- Add new benchmarks to tests/benchmark/README.md
- Rename test functions in test_swiglu.py to start with test_op_ prefix

Requested-by: @hannahli-nv
Hi @hannahli-nv

/ok to test 8b35a72
- Add silu_and_mul_backward_kernel_row_wise CuTile kernel
- Add SiLUAndMulFunction autograd wrapper for silu_and_mul
- Add swiglu_backward_kernel CuTile kernel
- Implement SiLUMulFunction.backward() for swiglu
- Update PartiallyFusedSwiGLUMLP to use torch.matmul when requires_grad=True

Uses recomputation strategy to save memory during backward pass.
- Expand test_silu_and_mul.py with regular and irregular shape tests
- Expand test_swiglu.py with backward tests and irregular shapes
- Add test_fused_swiglu_backward.py for PartiallyFusedSwiGLUMLP
Coverage includes:
- Regular: bs={8,16,32}, seq_len={512,1024}, hidden={512,1024,2048}
- Irregular: prime batch (7,11,13), odd seq_len (100,333), non-power-of-2 hidden (1000,1500,3000)
66 tests passing, 1 skipped (OOM on 5070)
- Add bench_silu_and_mul_backward.py for silu_and_mul fwd/bwd
- Add bench_fused_swiglu_mlp.py for end-to-end MLP benchmark
- Add bench_swiglu_backward.py for SwiGLU activation benchmark

Benchmarks show ~1.7x speedup over PyTorch for forward, ~1.35x speedup for backward pass.
- Change ylabel from 'ms' to 'GB/s'
- Add memory bandwidth calculation for forward/backward/full modes
- Consistent with existing TileGym benchmark style (bench_silu_and_mul.py)
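For context, a minimal sketch of this kind of bandwidth metric (a generic bytes-over-time formula; the benchmark's exact byte accounting may differ):

```python
def gbps(total_bytes: float, ms: float) -> float:
    """Memory bandwidth in GB/s given bytes moved and elapsed milliseconds."""
    return total_bytes * 1e-9 / (ms * 1e-3)

# Example: a forward silu_and_mul over an (N, 2*d) float16 input that writes (N, d)
# reads 2*N*d elements and writes N*d elements, 2 bytes each:
# total_bytes = (2 * N * d + N * d) * 2
```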
Follow tests/ops/README.md guidelines: test methods should be named test_op.
- Remove failing test case (16,512,2048,float16) due to FP16 precision on sm120
- Rename TileGym backend to CuTile in bench_fused_swiglu_mlp.py
- Add -GBps suffix to plot_name in all benchmark files for CI recognition
- Add new benchmarks to tests/benchmark/README.md
- Rename test functions in test_swiglu.py to start with test_op_ prefix

Requested-by: @hannahli-nv
Force-pushed: 8b35a72 to b40796b
/ok to test b40796b
Description
This PR adds backward pass support for PartiallyFusedSwiGLUMLP and related SwiGLU operations, enabling training with these fused kernels. The implementation uses a recomputation strategy (similar to Liger-Kernel's SwiGLU backward) that recomputes intermediate activations during backward instead of saving them, reducing memory usage.

Related Issue
NVIDIA/cutile-python#15
Implementation
Code Changes
- silu_and_mul.py: Added backward kernel and autograd wrapper
  - silu_and_mul_backward_kernel_row_wise: CuTile kernel for the backward pass
  - SiLUAndMulFunction: torch.autograd.Function wrapper
  - Updated silu_and_mul() to use autograd when requires_grad=True
- swiglu.py: Added backward kernel and completed the autograd implementation
  - swiglu_backward_kernel: CuTile kernel for the 2-input SwiGLU variant
  - Implemented SiLUMulFunction.backward()
- fused_swiglu.py: Updated PartiallyFusedSwiGLUMLP.forward()
  - Checks requires_grad and falls back to torch.matmul for training (since tilegym.ops.matmul doesn't have backward support yet)
  - Inference continues to use tilegym.ops.matmul

Gradient Math
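For reference, the standard SiLU-and-mul gradients that the backward pass computes (notation assumed: gate g, up projection u, logistic sigmoid σ, upstream gradient ∂L/∂y):

```math
y = \mathrm{SiLU}(g)\,u, \qquad \mathrm{SiLU}(g) = g\,\sigma(g)
```
```math
\frac{\partial L}{\partial u} = \frac{\partial L}{\partial y}\; g\,\sigma(g),
\qquad
\frac{\partial L}{\partial g} = \frac{\partial L}{\partial y}\; u\,\sigma(g)\bigl(1 + g\,(1 - \sigma(g))\bigr)
```

The recomputation strategy described above saves only the inputs and re-derives σ(g) and SiLU(g) in backward. A minimal plain-PyTorch sketch of that pattern (illustrative only; the PR's actual implementation launches CuTile kernels, and the class name here is hypothetical):

```python
import torch


class SiLUMulRecompute(torch.autograd.Function):
    """Illustrative recomputation-based autograd wrapper for y = silu(gate) * up."""

    @staticmethod
    def forward(ctx, gate, up):
        # Save only the inputs; intermediates are recomputed during backward.
        ctx.save_for_backward(gate, up)
        return torch.nn.functional.silu(gate) * up

    @staticmethod
    def backward(ctx, grad_out):
        gate, up = ctx.saved_tensors
        sig = torch.sigmoid(gate)      # recomputed sigma(gate)
        silu = gate * sig              # recomputed SiLU(gate)
        grad_up = grad_out * silu
        grad_gate = grad_out * up * sig * (1 + gate * (1 - sig))
        return grad_gate, grad_up
```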
Performance
Benchmarks on NVIDIA GeForce RTX 5070 Ti (16GB, CUDA 13.1).
Forward Pass - float32 (GB/s)
Forward Pass - bfloat16 (GB/s)
Backward Pass - float32 (GB/s)
Backward Pass - bfloat16 (GB/s)
Full (Forward + Backward) - float32 (GB/s)
Full (Forward + Backward) - bfloat16 (GB/s)
Run benchmarks:
Testing
Verification Strategy
The testing implementation strictly mirrors the pattern in tests/ops/test_rmsnorm_backward.py:

- Tests inherit common.PyTestCase and methods are named test_op (following tests/ops/README.md).
- self.assertCorrectness() (in test_silu_and_mul.py) is used to verify gradients against the PyTorch reference locally.
- silu_and_mul: validates the backward pass by passing the gradient arg to assertCorrectness.
- PartiallyFusedSwiGLUMLP: manually compares gradients (.grad) of inputs and weights against a vanilla PyTorch implementation (F.linear + F.silu); a sketch of such a reference follows below.
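A minimal sketch of that kind of vanilla PyTorch reference (class and attribute names here are illustrative, not necessarily the test's exact code):

```python
import torch
import torch.nn.functional as F


class ReferenceSwiGLUMLP(torch.nn.Module):
    """Plain-PyTorch SwiGLU MLP used as a gradient reference (F.linear + F.silu)."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
```

Gradients of the fused module's inputs and weights can then be checked against this reference with torch.testing.assert_close after calling backward() on both.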
Coverage
Run tests:
Usage
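A hypothetical training-path usage sketch (the import path and constructor arguments below are assumptions for illustration; check the module for the actual API):

```python
import torch
from tilegym.ops import PartiallyFusedSwiGLUMLP  # import path assumed

# Constructor signature assumed for illustration.
mlp = PartiallyFusedSwiGLUMLP(hidden_size=2048, intermediate_size=8192).cuda()

x = torch.randn(16, 512, 2048, device="cuda", requires_grad=True)
y = mlp(x)                      # falls back to torch.matmul when grads are needed
loss = y.float().pow(2).mean()
loss.backward()                 # exercises the backward pass added in this PR
print(x.grad.shape)
```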
CI Configuration
Checklist
- Code formatted (./format.sh)