56 changes: 6 additions & 50 deletions src/tilegym/ops/cutile/softmax.py
@@ -30,17 +30,7 @@ def softmax_kernel(

     for row_idx in range(pid, n_rows, num_programs):
-        # Load the row tile using index-based access
-        # Note: TILE_SIZE is expected to be >= DIM_COLS (ensured by launch function)
-        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE))
-
-        # Create mask for valid columns
-        col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32)
-        mask_col = col_offsets < DIM_COLS
-        mask = mask_col[None, :]
-
-        # Apply mask by setting invalid positions to -inf for numerical stability
-        neg_inf_tile = ct.full((1, TILE_SIZE), -np.inf, dtype=row.dtype)
-        row = ct.where(mask, row, neg_inf_tile)
+        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF)
 
         # Convert to float32 for computation
         row = ct.astype(row, torch.float32)
@@ -79,16 +69,7 @@ def softmax_kernel_tma(

     for row_idx in range(pid, n_rows, num_programs):
         # Load the entire row in one tile (TILE_SIZE >= n_cols by design)
-        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE))
-
-        # Create mask for valid columns
-        col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32)
-        mask_col = col_offsets < n_cols
-        mask = mask_col[None, :]
-
-        # Apply mask by setting invalid positions to -inf
-        neg_inf_tile = ct.full((1, TILE_SIZE), -np.inf, dtype=row.dtype)
-        row = ct.where(mask, row, neg_inf_tile)
+        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF)
 
         # Convert to float32 for computation
         row = ct.astype(row, np.float32)
@@ -125,20 +106,9 @@ def launch_softmax_kernel(input, output, TILE_SIZE=1024):
     n_rows, n_cols = input.shape
     original_n_cols = n_cols
 
-    # Pad tensors if needed to match TILE_SIZE
-    # This is required because ct.load expects shape to match tensor dimensions
-    needs_padding = n_cols < TILE_SIZE
-    if needs_padding:
-        padding = TILE_SIZE - n_cols
-        input_padded = torch.nn.functional.pad(input, (0, padding), value=-float('inf'))
-        output_padded = torch.nn.functional.pad(output, (0, padding), value=0.0)
-    else:
-        input_padded = input
-        output_padded = output
-
     # Ensure tensors are contiguous
-    input_padded = input_padded.contiguous()
-    output_padded = output_padded.contiguous()
+    input = input.contiguous()
+    output = output.contiguous()
 
     NUM_SM = torch.cuda.get_device_properties(input.device).multi_processor_count
     occupancy = 4  # Match @ct.kernel(occupancy=4)
@@ -150,18 +120,14 @@ def launch_softmax_kernel(input, output, TILE_SIZE=1024):
         grid,
         softmax_kernel,
         (
-            output_padded,
-            input_padded,
+            output,
+            input,
             n_rows,
             TILE_SIZE,
             original_n_cols,
         ),
     )
 
-    # Copy result back to original output tensor if padding was used
-    if needs_padding:
-        output.copy_(output_padded[:, :original_n_cols])
-
 
 def launch_softmax_kernel_tma(
     input,
@@ -191,12 +157,6 @@ def launch_softmax_kernel_tma(
     # Regular TMA path (single tile per row, persistent scheduling)
     softmax_kernel_forward = softmax_kernel_tma
 
-    # Pad if needed
-    if n_cols < TILE_SIZE:
-        padding = TILE_SIZE - n_cols
-        input = torch.nn.functional.pad(input, (0, padding), value=-float('inf'))
-        output = torch.nn.functional.pad(output, (0, padding), value=0.0)
-
     # Ensure tensors are contiguous
     input = input.contiguous()
     output = output.contiguous()
@@ -219,10 +179,6 @@ def launch_softmax_kernel_tma(
         ),
     )
 
-    # Trim output to original size
-    if original_n_cols < output.shape[-1]:
-        output = output[..., :original_n_cols]
-
 
 class Softmax(torch.autograd.Function):
     @staticmethod
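Why NEG_INF padding is safe (a minimal plain-PyTorch sketch, not part of the diff): padding a row with -inf before softmax leaves the valid entries unchanged, because exp(-inf) == 0 contributes nothing to the normalizing sum, and the padded positions come out exactly zero. This is the property that lets the kernels load with ct.PaddingMode.NEG_INF and drop both the explicit mask and the host-side pad/copy-back/trim logic.

    import torch

    row = torch.randn(9)
    padded = torch.cat([row, torch.full((7,), float("-inf"))])  # pad 9 -> 16 columns

    ref = torch.softmax(row, dim=0)
    out = torch.softmax(padded, dim=0)

    assert torch.allclose(out[:9], ref)  # valid entries unchanged
    assert torch.all(out[9:] == 0)       # padded entries are exactly zero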
2 changes: 2 additions & 0 deletions tests/ops/test_softmax.py
@@ -21,6 +21,8 @@ def reference(x):
     (256, 1024 * 32, torch.float32),
     (256, 256, torch.float16),
     (256, 2048, torch.float16),
+    (256, 9, torch.float16),
+    (256, 1009, torch.float16),
 ])
 @pytest.mark.parametrize("backend", _backends)
 @pytest.mark.parametrize(
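The two new shapes cover a row width smaller than the tile (9) and one that is not a multiple of it (1009), so both kernels now exercise the NEG_INF padding path. A quick standalone check in the same spirit (hypothetical sketch, assuming a CUDA device and the launcher from this PR; tolerances are illustrative):

    import torch
    from tilegym.ops.cutile.softmax import launch_softmax_kernel

    x = torch.randn(256, 9, device="cuda", dtype=torch.float16)
    out = torch.empty_like(x)
    launch_softmax_kernel(x, out, TILE_SIZE=1024)

    ref = torch.softmax(x.float(), dim=-1).to(x.dtype)  # PyTorch reference
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)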