From 3fbae3b6549a1d905120f8c1d65a5d8af1b88267 Mon Sep 17 00:00:00 2001
From: jinmanx
Date: Sun, 7 Dec 2025 01:57:24 -0800
Subject: [PATCH 1/2] remove padding

---
 src/tilegym/ops/cutile/softmax.py | 37 +++++--------------------
 tests/ops/test_softmax.py         |  2 ++
 2 files changed, 8 insertions(+), 31 deletions(-)

diff --git a/src/tilegym/ops/cutile/softmax.py b/src/tilegym/ops/cutile/softmax.py
index fcea4a3..1453b1c 100644
--- a/src/tilegym/ops/cutile/softmax.py
+++ b/src/tilegym/ops/cutile/softmax.py
@@ -31,7 +31,7 @@ def softmax_kernel(
     for row_idx in range(pid, n_rows, num_programs):
         # Load the row tile using index-based access
         # Note: TILE_SIZE is expected to be >= DIM_COLS (ensured by launch function)
-        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE))
+        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF)
 
         # Create mask for valid columns
         col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32)
@@ -79,7 +79,7 @@ def softmax_kernel_tma(
 
     for row_idx in range(pid, n_rows, num_programs):
         # Load the entire row in one tile (TILE_SIZE >= n_cols by design)
-        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE))
+        row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF)
 
         # Create mask for valid columns
         col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32)
@@ -125,20 +125,9 @@ def launch_softmax_kernel(input, output, TILE_SIZE=1024):
     n_rows, n_cols = input.shape
     original_n_cols = n_cols
 
-    # Pad tensors if needed to match TILE_SIZE
-    # This is required because ct.load expects shape to match tensor dimensions
-    needs_padding = n_cols < TILE_SIZE
-    if needs_padding:
-        padding = TILE_SIZE - n_cols
-        input_padded = torch.nn.functional.pad(input, (0, padding), value=-float('inf'))
-        output_padded = torch.nn.functional.pad(output, (0, padding), value=0.0)
-    else:
-        input_padded = input
-        output_padded = output
-
     # Ensure tensors are contiguous
-    input_padded = input_padded.contiguous()
-    output_padded = output_padded.contiguous()
+    input = input.contiguous()
+    output = output.contiguous()
 
     NUM_SM = torch.cuda.get_device_properties(input.device).multi_processor_count
     occupancy = 4  # Match @ct.kernel(occupancy=4)
@@ -150,18 +139,14 @@ def launch_softmax_kernel(input, output, TILE_SIZE=1024):
         grid,
         softmax_kernel,
         (
-            output_padded,
-            input_padded,
+            output,
+            input,
             n_rows,
             TILE_SIZE,
             original_n_cols,
         ),
     )
 
-    # Copy result back to original output tensor if padding was used
-    if needs_padding:
-        output.copy_(output_padded[:, :original_n_cols])
-
 
 def launch_softmax_kernel_tma(
     input,
@@ -191,12 +176,6 @@ def launch_softmax_kernel_tma(
     # Regular TMA path (single tile per row, persistent scheduling)
     softmax_kernel_forward = softmax_kernel_tma
 
-    # Pad if needed
-    if n_cols < TILE_SIZE:
-        padding = TILE_SIZE - n_cols
-        input = torch.nn.functional.pad(input, (0, padding), value=-float('inf'))
-        output = torch.nn.functional.pad(output, (0, padding), value=0.0)
-
     # Ensure tensors are contiguous
     input = input.contiguous()
     output = output.contiguous()
@@ -219,10 +198,6 @@ def launch_softmax_kernel_tma(
         ),
     )
 
-    # Trim output to original size
-    if original_n_cols < output.shape[-1]:
-        output = output[..., :original_n_cols]
-
 
 class Softmax(torch.autograd.Function):
     @staticmethod
diff --git a/tests/ops/test_softmax.py b/tests/ops/test_softmax.py
index 2b972d8..0126a29 100644
--- a/tests/ops/test_softmax.py
+++ b/tests/ops/test_softmax.py
@@ -21,6 +21,8 @@ def reference(x):
         (256, 1024 * 32, torch.float32),
         (256, 256, torch.float16),
         (256, 2048, torch.float16),
+        (256, 9, torch.float16),
+        (256, 1009, torch.float16),
     ])
 @pytest.mark.parametrize("backend", _backends)
 @pytest.mark.parametrize(

From ad4c538e9159b30711ac0e15fbe4630018ea8fb5 Mon Sep 17 00:00:00 2001
From: jinmanx
Date: Sun, 7 Dec 2025 02:10:23 -0800
Subject: [PATCH 2/2] remove padding logic

---
 src/tilegym/ops/cutile/softmax.py | 19 -------------------
 1 file changed, 19 deletions(-)

diff --git a/src/tilegym/ops/cutile/softmax.py b/src/tilegym/ops/cutile/softmax.py
index 1453b1c..421295b 100644
--- a/src/tilegym/ops/cutile/softmax.py
+++ b/src/tilegym/ops/cutile/softmax.py
@@ -30,18 +30,8 @@ def softmax_kernel(
 
     for row_idx in range(pid, n_rows, num_programs):
         # Load the row tile using index-based access
-        # Note: TILE_SIZE is expected to be >= DIM_COLS (ensured by launch function)
         row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF)
 
-        # Create mask for valid columns
-        col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32)
-        mask_col = col_offsets < DIM_COLS
-        mask = mask_col[None, :]
-
-        # Apply mask by setting invalid positions to -inf for numerical stability
-        neg_inf_tile = ct.full((1, TILE_SIZE), -np.inf, dtype=row.dtype)
-        row = ct.where(mask, row, neg_inf_tile)
-
         # Convert to float32 for computation
         row = ct.astype(row, torch.float32)
 
@@ -81,15 +71,6 @@ def softmax_kernel_tma(
         # Load the entire row in one tile (TILE_SIZE >= n_cols by design)
         row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF)
 
-        # Create mask for valid columns
-        col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32)
-        mask_col = col_offsets < n_cols
-        mask = mask_col[None, :]
-
-        # Apply mask by setting invalid positions to -inf
-        neg_inf_tile = ct.full((1, TILE_SIZE), -np.inf, dtype=row.dtype)
-        row = ct.where(mask, row, neg_inf_tile)
-
         # Convert to float32 for computation
         row = ct.astype(row, np.float32)
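
Note (reviewer sketch, not part of the patch series): dropping the explicit mask, the host-side padding, and the post-kernel trim relies on exp(-inf) == 0, so any out-of-bounds columns that ct.load fills via PaddingMode.NEG_INF contribute nothing to the softmax denominator and the valid columns come out unchanged. A minimal standalone check of that property, assuming only PyTorch (the sizes mirror the new (256, 9) test case):

    import torch

    n_cols, tile_size = 9, 1024
    x = torch.randn(1, n_cols)

    # Emulate what the kernel now sees: the row extended to TILE_SIZE with
    # -inf in the out-of-bounds columns.
    x_padded = torch.nn.functional.pad(x, (0, tile_size - n_cols), value=-float("inf"))

    ref = torch.softmax(x, dim=-1)
    out = torch.softmax(x_padded, dim=-1)[:, :n_cols]  # padded tail is exactly 0
    assert torch.allclose(ref, out)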