44 changes: 24 additions & 20 deletions src/tilegym/ops/cutile/swiglu.py
@@ -6,6 +6,7 @@
import cuda.tile as ct
import torch
import torch.nn as nn
from cuda.tile._numeric_semantics import RoundingMode as RMd

from tilegym.backend import register_impl

@@ -15,28 +16,29 @@


def sigmoid(x):
return 1.0 / (1.0 + ct.exp(-x))
denom = ct.add(1.0, ct.exp(-x), flush_to_zero=True)
return ct.truediv(1.0, denom, flush_to_zero=True, rounding_mode=RMd.APPROX)
@aghilann (Contributor Author), Feb 23, 2026:
A good chunk of the savings came from RMd.APPROX without losing precision; verified via tests.


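For context on the "verified via tests" claim above, here is a minimal sketch of the kind of check that could confirm the approximate-rounding sigmoid stays within the usual elementwise tolerances. The `reference_swiglu_fp32` helper below is hypothetical (an fp32 SiLU(a) * b reference written for illustration); only `swiglu_forward` comes from this file.

```python
import torch
import torch.nn.functional as F

def reference_swiglu_fp32(a, b):
    # Hypothetical reference: SiLU(a) * b computed in float32, cast back to the input dtype.
    return (F.silu(a.float()) * b.float()).to(a.dtype)

a = torch.randn(128, 4096, device="cuda", dtype=torch.float16)
b = torch.randn_like(a)
out = swiglu_forward(a, b)  # forward wrapper defined later in this file
torch.testing.assert_close(out, reference_swiglu_fp32(a, b), atol=1e-2, rtol=1e-2)
```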

def silu(x):
return x * sigmoid(x)
return ct.mul(x, sigmoid(x), flush_to_zero=True)


def ceildiv(a, b):
return -(a // -b)


@ct.kernel
def swiglu_forward_kernel(a, b, c, TILE_SIZE: ct.Constant[int]):
row = ct.bid(0)
col = ct.bid(1)

a_tile = ct.load(a, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO)
b_tile = ct.load(b, index=(row, col), shape=(1, TILE_SIZE), padding_mode=PAD_ZERO)
offsets = ct.arange(TILE_SIZE, dtype=ct.int32)

# Sigmoid requires type float32
c_tile = silu(a_tile.astype(ct.float32)).astype(a.dtype) * b_tile
ct.store(c, index=(row, col), tile=c_tile)
a_tile = ct.gather(a, (row, offsets), check_bounds=True, padding_value=0.0)
@aghilann (Contributor Author):
A good chunk of the perf improvements came from gather/scatter vs load/store.

b_tile = ct.gather(b, (row, offsets), check_bounds=True, padding_value=0.0)


a_tile_f32 = a_tile.astype(ct.float32)
c_tile = silu(a_tile_f32).astype(a.dtype) * b_tile
ct.scatter(c, (row, offsets), c_tile, check_bounds=True)

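As a rough mental model of the bounds-checked gather used above (a conceptual sketch, not the cuTile API's actual semantics), the per-row indexing with zero padding behaves roughly like this PyTorch helper when `offsets` runs past the last column:

```python
import torch

def gather_row_padded(x, row, offsets, padding_value=0.0):
    # Conceptual equivalent of ct.gather(x, (row, offsets), check_bounds=True,
    # padding_value=0.0): in-bounds columns are read, out-of-bounds positions
    # return the padding value instead of faulting.
    in_bounds = offsets < x.shape[1]
    safe = torch.where(in_bounds, offsets, torch.zeros_like(offsets))
    vals = x[row, safe]
    return torch.where(in_bounds, vals, torch.full_like(vals, padding_value))
```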

def swiglu_forward(a, b):
@@ -51,18 +53,16 @@ def swiglu_forward(a, b):
c = torch.empty_like(a)
n_rows = a.shape[0]

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
TILE_N = ceildiv(NUM_SMS, n_rows)
TILE_SIZE = next_power_of_2(int(n_cols / TILE_N))
grid = (n_rows, ceildiv(n_cols, TILE_SIZE), 1)
TILE_SIZE = next_power_of_2(n_cols)
grid = (n_rows,)
ct.launch(
torch.cuda.current_stream(),
grid,
swiglu_forward_kernel,
(
a.data,
b.data,
c.data,
a,
b,
c,
TILE_SIZE,
),
)
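The new launch configuration pads each row up to the next power of two and launches one program per row, relying on the bounds-checked gather/scatter to handle the padded tail. The `next_power_of_2` helper is not shown in this hunk; a minimal sketch of its assumed behavior:

```python
def next_power_of_2(n: int) -> int:
    # Assumed behavior: smallest power of two greater than or equal to n.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

# Example: with n_cols = 3000, TILE_SIZE becomes 4096 and grid = (n_rows,).
```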
@@ -88,8 +88,12 @@ def swiglu_backward_kernel(dc, a, b, da, db, TILE_SIZE: ct.Constant[int]):
a_tile_f32 = a_tile.astype(ct.float32)
b_tile_f32 = b_tile.astype(ct.float32)

# NOTE: sigmoid is intentionally inlined here to preserve current backward
# kernel behavior and benchmark baselines. Forward already uses
# the shared `sigmoid()` helper; backward will switch to it in a follow-up
# optimization PR that re-benchmarks backward performance.
# Compute sigmoid(a) and silu(a)
sigmoid_a = sigmoid(a_tile_f32)
sigmoid_a = 1.0 / (1.0 + ct.exp(-a_tile_f32))
@aghilann (Contributor Author), Feb 23, 2026:
Inlined this for now because I didn't want to modify the backward kernel in this PR; that would require re-benchmarking it as well. I have additional optimizations planned that I'll include in a separate PR, which will also make use of the new sigmoid implementation I added.

Collaborator:
Could you add a clear comment here? The current changes make the forward and backward code a bit confusing. You can delete the comment when the backward PR is ready.

@aghilann (Contributor Author), Mar 7, 2026:
Hey @hannahli-nv, added a short comment. I will remove it in a subsequent PR when I optimize the SwiGLU backward kernel. Ready for review.

silu_a = a_tile_f32 * sigmoid_a

# db = dc * silu(a)
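The backward hunk above is truncated here. For reference, this is a sketch of the gradients the kernel is computing for c = silu(a) * b, written in plain PyTorch as an illustration of the math rather than the kernel's actual code:

```python
import torch

def swiglu_backward_reference(dc, a, b):
    # Gradients of c = silu(a) * b, accumulated in float32.
    a32, b32, dc32 = a.float(), b.float(), dc.float()
    sig = torch.sigmoid(a32)
    silu_a = a32 * sig
    db = dc32 * silu_a                       # db = dc * silu(a)
    dsilu = sig * (1.0 + a32 * (1.0 - sig))  # d/da [a * sigmoid(a)]
    da = dc32 * b32 * dsilu
    return da.to(a.dtype), db.to(b.dtype)
```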
23 changes: 15 additions & 8 deletions tests/benchmark/bench_swiglu.py
@@ -38,14 +38,15 @@ def get_supported_backends():
return [p for p in ALL_BACKENDS if p is not None]


def create_benchmark_config(batch_size, M):
def create_benchmark_config(batch_size, M, dtype):
"""Create a benchmark configuration for given parameters"""
available_backends = get_supported_backends()
if not available_backends:
return None

backends, names, styles = zip(*available_backends)

dtype_name = str(dtype).split(".")[-1]
return triton.testing.Benchmark(
x_names=["N"],
x_vals=[2**i for i in range(10, 15)], # 1024 to 16384
@@ -54,38 +55,44 @@ def create_benchmark_config(batch_size, M):
line_names=list(names),
styles=list(styles),
ylabel="GB/s",
plot_name=f"swiglu-batch{batch_size}-M{M}-GBps",
plot_name=f"swiglu-batch{batch_size}-M{M}-{dtype_name}-GBps",
args={
"batch_size": batch_size,
"M": M,
"dtype": dtype,
},
)


@triton.testing.perf_report(
[
create_benchmark_config(batch_size, M)
for batch_size in [1, 8] # Different batch sizes
create_benchmark_config(batch_size, M, dtype)
for batch_size in [1, 4, 8] # Different batch sizes
for M in [128, 4096] # Different rows
for dtype in [torch.float16, torch.bfloat16, torch.float32]
@aghilann (Contributor Author):
Most benchmarks test across various dtypes; I thought this one should too.

]
)
def bench_swiglu(
batch_size,
M,
N,
backend,
dtype,
device=DEVICE,
):
dtype = torch.float16

# Generate input data: two tensors for SwiGLU operation
a = torch.randn(batch_size, M, N, device=device, dtype=dtype)
b = torch.randn(batch_size, M, N, device=device, dtype=dtype)

# Use unified dispatch system
fn = lambda: tilegym.ops.get_swiglu(backend=backend)(a, b)
ref = lambda: reference_swiglu(a, b)
torch.testing.assert_close(fn(), ref(), atol=1e-2, rtol=1e-2)
if dtype is torch.float32:
ref = lambda: reference_swiglu(a, b)
atol, rtol = 1e-5, 1e-5
else:
ref = lambda: reference_swiglu(a, b)
atol, rtol = 1e-2, 1e-2
torch.testing.assert_close(fn(), ref(), atol=atol, rtol=rtol)

# Benchmark the function
ms = triton.testing.do_bench_cudagraph(fn)
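The tail of the benchmark is cut off here. The GB/s figure on the y-axis is presumably derived from the measured time along the lines of the sketch below, counting the three same-shaped tensors that move through memory (a and b read, c written); the exact formula in the file may differ.

```python
def gbps(ms, a, b, c):
    # Total bytes moved: read a and b, write c; convert ms to seconds.
    total_bytes = sum(t.numel() * t.element_size() for t in (a, b, c))
    return total_bytes / (ms * 1e-3) / 1e9
```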