
[FEA] Add SwiGLU backward pass implementation, test cases and benchmark#46

Merged
hannahli-nv merged 6 commits into NVIDIA:main from Weili-0234:feat/swiglu-backward-pass on Feb 9, 2026

Conversation

@Weili-0234
Contributor

Description

This PR adds backward-pass support for PartiallyFusedSwiGLUMLP and the related SwiGLU operations, enabling training with these fused kernels. The implementation uses a recomputation strategy (similar to Liger-Kernel's SwiGLU backward): intermediate activations are recomputed during the backward pass instead of being saved from the forward pass, which reduces memory usage.

Related Issue

NVIDIA/cutile-python#15

Implementation

Code Changes

  1. silu_and_mul.py: Added backward kernel and autograd wrapper

    • silu_and_mul_backward_kernel_row_wise: CuTile kernel for backward pass
    • SiLUAndMulFunction: torch.autograd.Function wrapper
    • Updated silu_and_mul() to use autograd when requires_grad=True
  2. swiglu.py: Added backward kernel and completed autograd implementation

    • swiglu_backward_kernel: CuTile kernel for the 2-input SwiGLU variant
    • Implemented SiLUMulFunction.backward()
  3. fused_swiglu.py: Updated PartiallyFusedSwiGLUMLP.forward()

    • Auto-detects requires_grad and falls back to torch.matmul for training, since tilegym.ops.matmul doesn't have backward support yet (a simplified dispatch sketch follows this list)
    • Inference path still uses optimized tilegym.ops.matmul
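
For illustration, the training/inference dispatch can be thought of as the minimal sketch below. The helper name `_project` is hypothetical, and the tilegym.ops.matmul signature is assumed to mirror torch.matmul; the actual PartiallyFusedSwiGLUMLP.forward() differs in detail.

```python
import torch
import tilegym.ops  # assumed import path for the optimized matmul

def _project(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper sketching the dispatch, not the actual TileGym code."""
    if torch.is_grad_enabled() and (x.requires_grad or weight.requires_grad):
        # Training path: tilegym.ops.matmul has no backward yet, so use
        # torch.matmul, which participates in autograd.
        return torch.matmul(x, weight.t())
    # Inference path: keep the optimized CuTile matmul
    # (signature assumed to mirror torch.matmul).
    return tilegym.ops.matmul(x, weight.t())
```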

Gradient Math

Forward:  c = silu(a) * b = a * σ(a) * b

Backward (recomputation):
  σ(a) = sigmoid(a)
  d_silu/da = σ(a) * (1 + a * (1 - σ(a)))
  
  da = dc * b * d_silu/da
  db = dc * silu(a)
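
A minimal PyTorch-level sketch of this recomputation scheme (illustrative only; the class name is a placeholder, and in the real implementation the backward math runs inside the CuTile kernels such as silu_and_mul_backward_kernel_row_wise):

```python
import torch

class _SiLUMulSketch(torch.autograd.Function):
    """Placeholder autograd.Function mirroring the gradient math above."""

    @staticmethod
    def forward(ctx, a, b):
        # Save only the inputs; silu(a) is recomputed in backward.
        ctx.save_for_backward(a, b)
        return torch.nn.functional.silu(a) * b

    @staticmethod
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        sig = torch.sigmoid(a)               # recompute sigma(a)
        silu_a = a * sig                     # recompute silu(a)
        d_silu = sig * (1 + a * (1 - sig))   # d silu(a) / da
        da = dc * b * d_silu
        db = dc * silu_a
        return da, db
```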

Performance

Benchmarks on NVIDIA GeForce RTX 5070 Ti (16GB, CUDA 13.1).

Forward Pass - float32 (GB/s)

| M | hidden=2048 CuTile | hidden=2048 PyTorch | hidden=4096 CuTile | hidden=4096 PyTorch | hidden=11008 CuTile | hidden=11008 PyTorch |
|---|---|---|---|---|---|---|
| 1024 | 677 | 513 | 681 | 557 | 630 | 458 |
| 2048 | 711 | 540 | 733 | 501 | 655 | 464 |
| 4096 | 742 | 495 | 754 | 463 | 669 | 464 |
| 8192 | 760 | 461 | 770 | 464 | 676 | 465 |
| 16384 | 775 | 464 | 778 | 467 | 681 | 465 |

Forward Pass - bfloat16 (GB/s)

| M | hidden=2048 CuTile | hidden=2048 PyTorch | hidden=4096 CuTile | hidden=4096 PyTorch | hidden=11008 CuTile | hidden=11008 PyTorch |
|---|---|---|---|---|---|---|
| 1024 | 664 | 361 | 675 | 423 | 647 | 420 |
| 2048 | 716 | 417 | 705 | 453 | 698 | 421 |
| 4096 | 733 | 449 | 743 | 417 | 725 | 437 |
| 8192 | 758 | 411 | 762 | 434 | 742 | 445 |
| 16384 | 774 | 431 | 773 | 445 | 752 | 449 |

Backward Pass - float32 (GB/s)

| M | hidden=2048 CuTile | hidden=2048 PyTorch | hidden=4096 CuTile | hidden=4096 PyTorch | hidden=11008 CuTile | hidden=11008 PyTorch |
|---|---|---|---|---|---|---|
| 1024 | 344 | 217 | 284 | 166 | 250 | 135 |
| 2048 | 286 | 165 | 262 | 139 | 250 | 136 |
| 4096 | 263 | 138 | 258 | 136 | 250 | 136 |
| 8192 | 258 | 136 | 259 | 136 | 251 | 136 |
| 16384 | 258 | 136 | 258 | 136 | 251 | 136 |

Backward Pass - bfloat16 (GB/s)

| M | hidden=2048 CuTile | hidden=2048 PyTorch | hidden=4096 CuTile | hidden=4096 PyTorch | hidden=11008 CuTile | hidden=11008 PyTorch |
|---|---|---|---|---|---|---|
| 1024 | 344 | 201 | 342 | 198 | 263 | 138 |
| 2048 | 347 | 198 | 284 | 154 | 251 | 135 |
| 4096 | 288 | 155 | 261 | 135 | 251 | 135 |
| 8192 | 263 | 135 | 258 | 135 | 252 | 135 |
| 16384 | 258 | 135 | 258 | 136 | 253 | 135 |

Full (Forward + Backward) - float32 (GB/s)

| M | hidden=2048 CuTile | hidden=2048 PyTorch | hidden=4096 CuTile | hidden=4096 PyTorch | hidden=11008 CuTile | hidden=11008 PyTorch |
|---|---|---|---|---|---|---|
| 1024 | 450 | 307 | 374 | 233 | 328 | 185 |
| 2048 | 374 | 232 | 349 | 191 | 328 | 185 |
| 4096 | 350 | 191 | 344 | 185 | 328 | 185 |
| 8192 | 344 | 185 | 346 | 186 | 330 | 185 |
| 16384 | 344 | 185 | 345 | 186 | 330 | 185 |

Full (Forward + Backward) - bfloat16 (GB/s)

| M | hidden=2048 CuTile | hidden=2048 PyTorch | hidden=4096 CuTile | hidden=4096 PyTorch | hidden=11008 CuTile | hidden=11008 PyTorch |
|---|---|---|---|---|---|---|
| 1024 | 240 | 298 | 467 | 285 | 348 | 189 |
| 2048 | 452 | 285 | 374 | 214 | 335 | 183 |
| 4096 | 345 | 214 | 349 | 184 | 335 | 184 |
| 8192 | 350 | 183 | 344 | 184 | 336 | 184 |
| 16384 | 346 | 183 | 345 | 184 | 338 | 184 |

Run benchmarks:

python tests/benchmark/bench_silu_and_mul_backward.py
python tests/benchmark/bench_swiglu_backward.py
python tests/benchmark/bench_fused_swiglu_mlp.py

Testing

Verification Strategy

The testing implementation strictly mirrors the pattern in tests/ops/test_rmsnorm_backward.py:

  1. Naming Convention: Test class inherits from common.PyTestCase and methods are named test_op (following tests/ops/README.md).
  2. Correctness Check: Uses self.assertCorrectness() (in test_silu_and_mul.py) to verify results against a PyTorch reference computed in the test.
  3. Gradient Verification:
    • For silu_and_mul: Validates backward pass by passing gradient arg to assertCorrectness.
    • For PartiallyFusedSwiGLUMLP: Manually compares gradients (.grad) of inputs and weights against a vanilla PyTorch implementation (F.linear + F.silu); a simplified sketch follows this list.
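
A simplified, self-contained version of that comparison (plain functions with placeholder weight names rather than the actual common.PyTestCase harness) might look like this:

```python
import torch
import torch.nn.functional as F

def reference_swiglu_mlp(x, w_gate, w_up, w_down):
    # Vanilla PyTorch SwiGLU MLP: down_proj(silu(gate_proj(x)) * up_proj(x)).
    return F.linear(F.silu(F.linear(x, w_gate)) * F.linear(x, w_up), w_down)

def check_input_grad(fused_mlp, x, w_gate, w_up, w_down, atol=1e-2, rtol=1e-2):
    """Compare the input gradient of a fused MLP against the reference.

    `fused_mlp` and the weight arguments are placeholders; the real test
    instantiates PartiallyFusedSwiGLUMLP and also compares weight .grad.
    """
    x_fused = x.detach().clone().requires_grad_(True)
    fused_mlp(x_fused).sum().backward()

    x_ref = x.detach().clone().requires_grad_(True)
    reference_swiglu_mlp(x_ref, w_gate, w_up, w_down).sum().backward()

    torch.testing.assert_close(x_fused.grad, x_ref.grad, atol=atol, rtol=rtol)
```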

Coverage

  • 66 passed, 1 skipped (OOM on large hidden size)
  • Expanded test coverage with:
    • Regular shapes: bs={8,16,32}, seq_len={512,1024}, hidden={512,1024,2048,4096}
    • Irregular shapes: prime batch sizes (7,11,13), odd seq_len (100,127,333), non-power-of-2 hidden (300,500,750,1000,1500,3000)

Run tests:

pytest tests/ops/test_silu_and_mul.py tests/ops/test_swiglu.py tests/ops/test_fused_swiglu_backward.py -v

Usage

mlp = PartiallyFusedSwiGLUMLP(config)
output = mlp(x)  # x.requires_grad needs to be True
loss = criterion(output, target)
loss.backward() 

CI Configuration

config:
  build: true
  # valid options are "ops" and "benchmark"
  test: ["ops", "benchmark"]

Checklist

  • Code formatted and imports sorted via repo specifications (./format.sh)
  • Documentation updated (if needed)
  • CI configuration reviewed

@copy-pr-bot

copy-pr-bot Bot commented Feb 1, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@Weili-0234 force-pushed the feat/swiglu-backward-pass branch from 6b335d1 to 4a976eb on February 1, 2026 at 08:19
@hannahli-nv
Collaborator

Hi @Weili-0234, thank you for your contribution.
As this is your first time contributing to TileGym, please submit your signed CLA document as described in CONTRIBUTING.md.
Thank you very much.

@hannahli-nv
Collaborator

/ok to test 4a976eb

@Weili-0234
Contributor Author

Hi @hannahli-nv, thank you for running the CI!

I noticed there's 1 failure out of 284 tests in tests/ops/test_silu_and_mul.py, where I expanded the test coverage with additional shape combinations. The failing test case is:

test_op[cutile-16-512-2048-torch.float16]

Here's the raw error output:

AssertionError:
	*** OUTPUT 0 DID NOT MATCH THE REFERENCE (rtol=0.0, atol=0.01) ***
		allclose: False
		matched: 16777215 / 16777216 [100.00%]
		ref range:    -1.4516e+01 :  1.7969e+01
		test range:   -1.4523e+01 :  1.7953e+01
		|ref| range:   0.0000e+00 :  1.7969e+01
		|test| range:  0.0000e+00 :  1.7953e+01
		max absolute difference:  1.5625e-02
		max relative change:      1.4286e-01
		max max mean change:      1.2500e-01
		max arith mean change:    1.3333e-01
		shape: torch.Size([16, 512, 2048]) stride: (1048576, 2048, 1) dtype: torch.float16
		mismatched indices:tensor([[   7,  192, 2033]])
self = <tests.ops.test_silu_and_mul.Test_SiLUAndMul object at 0x7ef3d86b9b70>
batch_size = 16, seq_len = 512, hidden_size = 2048, dtype = torch.float16
backend = 'cutile', arch = 'sm120'

I see that tests/common.py has built-in dtype tolerance detection via get_dtype_tolerances():

  • float16: rtol=1e-2, atol=1e-2
  • bfloat16: rtol=1e-2, atol=2e-2
  • float32: rtol=1e-5, atol=1e-8

The max absolute difference (0.015625) is just slightly above the float16 default atol (0.01). It equals 2^-6, which is exactly one float16 ULP for magnitudes in [16, 32), the range where the extreme values (~17.97) lie, so this is consistent with inherent FP16 rounding (a single one-ULP difference) affecting 1 element out of ~16.8M.

Since this failing case (16, 512, 2048, torch.float16) is one I added (not in the original test suite), I can either:

  • Remove this specific test case
  • Increase atol to 2e-2 to match bfloat16's tolerance
  • Revert all my test expansions and keep only the original 3 cases

Let me know what you prefer. I'm happy to make the change myself, or feel free to push a fix directly to the PR branch if that's easier.

Also, I've submitted the CLA document as requested. Thank you!

@hannahli-nv
Collaborator

hannahli-nv commented Feb 4, 2026

> Since this failing case (16, 512, 2048, torch.float16) is one I added (not in the original test suite), I can either:
>
>   • Remove this specific test case
>   • Increase atol to 2e-2 to match bfloat16's tolerance
>   • Revert all my test expansions and keep only the original 3 cases
>
> Let me know what you prefer. I'm happy to make the change myself, or feel free to push a fix directly to the PR branch if that's easier.

Hi @Weili-0234, we have received your signed CLA file.
Of the options you listed, please go with removing the specific test case.
Additionally, there are some more requested changes:

Comment thread on tests/ops/test_swiglu.py (outdated):
@pytest.mark.parametrize("backend", _backends)
def test_op(self, batch_size, seq_len, hidden_size, intermediate_size, backend, arch):
"""Test for functional correctness of SwiGLU implementation"""
def test_forward(self, batch_size, seq_len, hidden_size, intermediate_size, backend, arch):
Collaborator

Please rename the functions in tests/ops/test_swiglu.py so that they start with "test_op", because only tests whose names contain the substring "test_op" are run in the TileGym CI.
For example, you can rename "test_forward" to "test_op_forward".
Please also update the other function names in this file.
Thanks.

def get_providers():
    providers = [("torch", "PyTorch", ("green", "-"))]
    if is_backend_available("cutile"):
        providers.insert(0, ("tilegym", "TileGym", ("orange", "-")))
Collaborator

Please don't name the backend "TileGym"; set it to "CuTile" as the other benchmark files do.

line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"fused-swiglu-mlp-{mode_name}-{model_name}-bs{batch_size}",
Collaborator

Please add "-GBps" at the end of plot_name so it can be recognized by the CI summary page. You can refer to bench_silu_and_mul.py as an example.

line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"silu-and-mul-{mode_name}-hidden{hidden_size}-{dtype_name}",
Collaborator

Same as commented in tests/benchmark/bench_fused_swiglu_mlp.py.

line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"swiglu-{mode_name}-hidden{hidden_size}-{dtype_name}",
Collaborator

Same as commented in tests/benchmark/bench_fused_swiglu_mlp.py.

Weili-0234 added a commit to Weili-0234/TileGym that referenced this pull request Feb 4, 2026
- Remove failing test case (16,512,2048,float16) due to FP16 precision on sm120
- Rename TileGym backend to CuTile in bench_fused_swiglu_mlp.py
- Add -GBps suffix to plot_name in all benchmark files for CI recognition
- Add new benchmarks to tests/benchmark/README.md
- Rename test functions in test_swiglu.py to start with test_op_ prefix

Requested-by: @hannahli-nv
@Weili-0234
Contributor Author

Hi @hannahli-nv,
Thank you for the detailed review! I've addressed all your feedback in the latest push.
Please let me know if there's anything else that needs adjustment!

@hannahli-nv
Collaborator

/ok to test 8b35a72

@hannahli-nv (Collaborator) left a comment

Overall LGTM!

- Add silu_and_mul_backward_kernel_row_wise CuTile kernel
- Add SiLUAndMulFunction autograd wrapper for silu_and_mul
- Add swiglu_backward_kernel CuTile kernel
- Implement SiLUMulFunction.backward() for swiglu
- Update PartiallyFusedSwiGLUMLP to use torch.matmul when requires_grad=True

Uses recomputation strategy to save memory during backward pass.
- Expand test_silu_and_mul.py with regular and irregular shape tests
- Expand test_swiglu.py with backward tests and irregular shapes
- Add test_fused_swiglu_backward.py for PartiallyFusedSwiGLUMLP

Coverage includes:
- Regular: bs={8,16,32}, seq_len={512,1024}, hidden={512,1024,2048}
- Irregular: prime batch (7,11,13), odd seq_len (100,333), non-power-of-2 hidden (1000,1500,3000)

66 tests passing, 1 skipped (OOM on 5070)
- Add bench_silu_and_mul_backward.py for silu_and_mul fwd/bwd
- Add bench_fused_swiglu_mlp.py for end-to-end MLP benchmark
- Add bench_swiglu_backward.py for SwiGLU activation benchmark

Benchmarks show ~1.7x speedup over PyTorch for forward,
~1.35x speedup for backward pass.
- Change ylabel from 'ms' to 'GB/s'
- Add memory bandwidth calculation for forward/backward/full modes
- Consistent with existing TileGym benchmark style (bench_silu_and_mul.py)
Follow tests/ops/README.md guidelines: test methods should be named test_op.
- Remove failing test case (16,512,2048,float16) due to FP16 precision on sm120
- Rename TileGym backend to CuTile in bench_fused_swiglu_mlp.py
- Add -GBps suffix to plot_name in all benchmark files for CI recognition
- Add new benchmarks to tests/benchmark/README.md
- Rename test functions in test_swiglu.py to start with test_op_ prefix

Requested-by: @hannahli-nv
@hannahli-nv force-pushed the feat/swiglu-backward-pass branch from 8b35a72 to b40796b on February 9, 2026 at 06:39
@hannahli-nv
Collaborator

/ok to test b40796b

@hannahli-nv merged commit aaab450 into NVIDIA:main on Feb 9, 2026
18 checks passed