Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/tilegym/autotune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: MIT

DISABLE_AUTOTUNE_ENV = "TILEGYM_DISABLE_AUTOTUNE"
_DISABLE_AUTOTUNE_TRUE_VALUES = frozenset({"1", "true", "yes", "on"})
_DISABLE_AUTOTUNE_FALSE_VALUES = frozenset({"0", "false", "no", "off"})


def is_autotune_disabled() -> bool:
"""Return whether autotune is disabled for the current process.

Autotune stays enabled by default. ``TILEGYM_DISABLE_AUTOTUNE`` is the
single public switch; ``1/true/yes/on`` disables autotune and
``0/false/no/off`` keeps the default enabled behavior. Operator code must
not read environment variables directly or alias removed ad hoc autotune
flags.
"""
# Local import keeps this function self-contained.
import os

disable_flag = os.environ.get(DISABLE_AUTOTUNE_ENV)
if disable_flag is None:
return False

disable_flag = disable_flag.strip().lower()
if disable_flag in _DISABLE_AUTOTUNE_TRUE_VALUES:
return True
if disable_flag in _DISABLE_AUTOTUNE_FALSE_VALUES:
return False

valid_values = ", ".join(sorted(_DISABLE_AUTOTUNE_TRUE_VALUES | _DISABLE_AUTOTUNE_FALSE_VALUES))
raise ValueError(f"{DISABLE_AUTOTUNE_ENV} must be one of {{{valid_values}}}; got {disable_flag!r}")


def is_autotune_enabled() -> bool:
"""Return the process-wide autotune policy."""
return not is_autotune_disabled()
2 changes: 1 addition & 1 deletion src/tilegym/ops/cutile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from .experimental.sparse_mla import tile_sparse_mla
from .experimental.swa_attention import tile_swa_attention
from .flash_decode import fmha_decode
from .moe import fused_moe_kernel as invoke_fused_moe_kernel
from .moe import invoke_fused_moe_kernel
from .moe_align_block import moe_align_block_size
from .recurrent_gated_delta_rule import recurrent_gated_delta_rule
from .rms_norm import get_rms_norm_module
Expand Down
139 changes: 100 additions & 39 deletions src/tilegym/ops/cutile/activation/geglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,98 +10,145 @@

from tilegym.backend import register_impl

from .gelu import GELU_EXACT
from .gelu import GELU_TANH
from .gelu import _gelu_fwd
from .gelu import _gelu_tanh_fwd
from .gelu import _normal_cdf
from .gelu import _normal_pdf
from .gelu import gelu_forward_ct
from .gelu import gelu_tanh_forward_ct
from .gelu import standard_normal_cdf_ct
from .gelu import standard_normal_pdf_ct

# Approximation mode constants
GELU_EXACT = 0
GELU_TANH = 1

def _gelu_bwd(x_val, dy_val, BLOCK_SIZE: ct.Constant[int]):
# dy * (Φ(x) + x * φ(x))
return dy_val * (_normal_cdf(x_val, BLOCK_SIZE) + x_val * _normal_pdf(x_val, BLOCK_SIZE))

def _gelu_bwd_ct(x_val, dy_val, BLOCK_SIZE: ct.Constant[int]):
"""
Compute GELU backward gradient: dy * (Φ(x) + x * φ(x))

Args:
x_val: Input value tile
dy_val: Output gradient tile
BLOCK_SIZE: Block size constant

Returns:
Gradient with respect to input
"""
cdf_val = standard_normal_cdf_ct(x_val, BLOCK_SIZE)
pdf_val = standard_normal_pdf_ct(x_val, BLOCK_SIZE)
x_pdf = x_val * pdf_val
grad_factor = cdf_val + x_pdf
return dy_val * grad_factor


@ct.kernel
def geglu_fwd_kernel(
def _geglu_fwd_kernel(
y,
x,
N: ct.Constant[int],
m_stride: ct.Constant[int],
my_stride: ct.Constant[int],
M_STRIDE: ct.Constant[int],
MY_STRIDE: ct.Constant[int],
BLOCK_SIZE: ct.Constant[int],
APPROXIMATE: ct.Constant[int],
):
"""
Forward kernel for GEGLU activation: output = a * GELU(b)
Forward kernel for GEGLU activation.

Computes: output = a * GELU(b)
where a is the left half and b is the right half of the input.
"""
bid = ct.bid(0)

# Compute global indices for this block
global_id = bid * BLOCK_SIZE + ct.arange(BLOCK_SIZE, dtype=ct.int32)

# Compute m_id (batch/row index) and n_offs (column offset)
# m_id = global_id // N
m_id = global_id // N
n_offs = global_id % N

left_ptr_offsets = m_id * m_stride + n_offs
right_ptr_offsets = m_id * m_stride + n_offs + N
out_ptr_offsets = m_id * my_stride + n_offs
# Compute strides for input and output
m_offs = m_id * M_STRIDE
my_offs = m_id * MY_STRIDE

# Calculate pointer offsets for left and right halves
left_ptr_offsets = m_offs + n_offs
right_ptr_offsets = m_offs + n_offs + N
out_ptr_offsets = my_offs + n_offs

# Load left and right halves using gather
a = ct.gather(x, (left_ptr_offsets,))
b = ct.gather(x, (right_ptr_offsets,))

# Compute a * GELU(b)
if APPROXIMATE == GELU_TANH:
out = a * _gelu_tanh_fwd(b, BLOCK_SIZE)
geglu_output = a * gelu_tanh_forward_ct(b, BLOCK_SIZE)
else:
out = a * _gelu_fwd(b, BLOCK_SIZE)
geglu_output = a * gelu_forward_ct(b, BLOCK_SIZE)

ct.scatter(y, (out_ptr_offsets,), out)
# Store output using scatter
ct.scatter(y, (out_ptr_offsets,), geglu_output)


@ct.kernel
def geglu_bwd_kernel(
def _geglu_bwd_kernel(
dx,
dy,
x,
N: ct.Constant[int],
m_stride: ct.Constant[int],
my_stride: ct.Constant[int],
M_STRIDE: ct.Constant[int],
MY_STRIDE: ct.Constant[int],
BLOCK_SIZE: ct.Constant[int],
APPROXIMATE: ct.Constant[int],
):
"""
Backward kernel for GEGLU: da = dy * GELU(b), db = dy * a * GELU'(b).
Backward kernel for GEGLU activation.

Computes gradients with respect to a and b:
- da = dy * GELU(b)
- db = dy * a * GELU'(b)
"""
bid = ct.bid(0)

# Compute global indices for this block
global_id = bid * BLOCK_SIZE + ct.arange(BLOCK_SIZE, dtype=ct.int32)

# Compute m_id (batch/row index) and n_offs (column offset)
m_id = global_id // N
n_offs = global_id % N

left_ptr_offsets = m_id * m_stride + n_offs
right_ptr_offsets = m_id * m_stride + n_offs + N
out_ptr_offsets = m_id * my_stride + n_offs
# Compute strides for input and output
m_offs = m_id * M_STRIDE
my_offs = m_id * MY_STRIDE

# Calculate pointer offsets
left_ptr_offsets = m_offs + n_offs
right_ptr_offsets = m_offs + n_offs + N
out_ptr_offsets = my_offs + n_offs

# Load input splits and output gradient
a = ct.gather(x, (left_ptr_offsets,))
b = ct.gather(x, (right_ptr_offsets,))
dy_val = ct.gather(dy, (out_ptr_offsets,))

# Compute GELU(b) for gradient of a
if APPROXIMATE == GELU_TANH:
gelu_b = _gelu_tanh_fwd(b, BLOCK_SIZE)
dy_da = gelu_tanh_forward_ct(b, BLOCK_SIZE)
else:
gelu_b = _gelu_fwd(b, BLOCK_SIZE)
dy_da = gelu_forward_ct(b, BLOCK_SIZE)

da = dy_val * gelu_b
db = a * _gelu_bwd(b, dy_val, BLOCK_SIZE)
# Compute gradients
da = dy_val * dy_da
db = a * _gelu_bwd_ct(b, dy_val, BLOCK_SIZE)

# Store gradients
ct.scatter(dx, (left_ptr_offsets,), da)
ct.scatter(dx, (right_ptr_offsets,), db)


class GegluFunction(torch.autograd.Function):
class _GEGLU(torch.autograd.Function):
@staticmethod
def forward(ctx, x, dim, approximate):
assert approximate == "none" or approximate == "tanh", "Only `none` or `tanh` activations are supported"
# Process input
assert x.is_contiguous()
assert x.shape[dim] % 2 == 0

Expand All @@ -110,32 +157,38 @@ def forward(ctx, x, dim, approximate):
y_shape = list(x_shape)
y_shape[dim] = y_shape[dim] // 2

# Flatten input and output for kernel processing
x_flat = x.view(-1)
y_flat = torch.empty(reduce(operator.mul, y_shape, 1), device=x.device, dtype=x.dtype)

# Compute strides
if dim == 0:
m_stride = 0
my_stride = 0
else:
m_stride = x.stride(dim - 1)
my_stride = reduce(operator.mul, y_shape[dim:], 1)
my_stride = reduce(operator.mul, y_shape[dim:], 1) # Stride for flattened y

# Compute dimensions
M = reduce(operator.mul, x_shape[:dim], 1)
N2 = reduce(operator.mul, x_shape[dim:], 1) // 2
n_elements = reduce(operator.mul, x_shape, 1) // 2

BLOCK_SIZE = 256
grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)

approximate_mode = GELU_TANH if approximate == "tanh" else GELU_EXACT

ct.launch(
torch.cuda.current_stream(),
grid,
geglu_fwd_kernel,
_geglu_fwd_kernel,
(y_flat, x_flat, N2, m_stride, my_stride, BLOCK_SIZE, approximate_mode),
)

# Reshape output back to expected shape
y = y_flat.view(y_shape)

ctx.save_for_backward(x, y)
ctx.M = M
ctx.N2 = N2
Expand All @@ -151,13 +204,17 @@ def backward(ctx, dy):
x, y = ctx.saved_tensors
dim = ctx.dim
approximate = ctx.approximate
M = ctx.M
N2 = ctx.N2
n_elements = ctx.n_elements

x_shape = x.shape
dx_flat = torch.empty_like(x.view(-1))

# Flatten tensors for kernel processing
x_flat = x.view(-1)
dy_flat = dy.view(-1)
dx_flat = torch.empty_like(x_flat)

# Compute strides
if dim == 0:
m_stride = 0
my_stride = 0
Expand All @@ -167,16 +224,20 @@ def backward(ctx, dy):

BLOCK_SIZE = 256
grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)

approximate_mode = GELU_TANH if approximate == "tanh" else GELU_EXACT

ct.launch(
torch.cuda.current_stream(),
grid,
geglu_bwd_kernel,
(dx_flat, dy.view(-1), x.view(-1), N2, m_stride, my_stride, BLOCK_SIZE, approximate_mode),
_geglu_bwd_kernel,
(dx_flat, dy_flat, x_flat, N2, m_stride, my_stride, BLOCK_SIZE, approximate_mode),
)

return dx_flat.view(x_shape), None, None
# Reshape output back to expected shape
dx = dx_flat.view(x_shape)

return dx, None, None


@register_impl("geglu", backend="cutile")
Expand All @@ -196,4 +257,4 @@ def geglu(input: torch.Tensor, dim=-1, approximate="none"):
dim: int
approximate: ``'none'`` or ``'tanh'``
"""
return GegluFunction.apply(input, dim, approximate)
return _GEGLU.apply(input, dim, approximate)
Loading
Loading