diff --git a/src/tilegym/ops/cutile/softmax.py b/src/tilegym/ops/cutile/softmax.py index fcea4a3..421295b 100644 --- a/src/tilegym/ops/cutile/softmax.py +++ b/src/tilegym/ops/cutile/softmax.py @@ -30,17 +30,7 @@ def softmax_kernel( for row_idx in range(pid, n_rows, num_programs): # Load the row tile using index-based access - # Note: TILE_SIZE is expected to be >= DIM_COLS (ensured by launch function) - row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE)) - - # Create mask for valid columns - col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32) - mask_col = col_offsets < DIM_COLS - mask = mask_col[None, :] - - # Apply mask by setting invalid positions to -inf for numerical stability - neg_inf_tile = ct.full((1, TILE_SIZE), -np.inf, dtype=row.dtype) - row = ct.where(mask, row, neg_inf_tile) + row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF) # Convert to float32 for computation row = ct.astype(row, torch.float32) @@ -79,16 +69,7 @@ def softmax_kernel_tma( for row_idx in range(pid, n_rows, num_programs): # Load the entire row in one tile (TILE_SIZE >= n_cols by design) - row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE)) - - # Create mask for valid columns - col_offsets = ct.arange(TILE_SIZE, dtype=torch.int32) - mask_col = col_offsets < n_cols - mask = mask_col[None, :] - - # Apply mask by setting invalid positions to -inf - neg_inf_tile = ct.full((1, TILE_SIZE), -np.inf, dtype=row.dtype) - row = ct.where(mask, row, neg_inf_tile) + row = ct.load(input, index=(row_idx, 0), shape=(1, TILE_SIZE), padding_mode=ct.PaddingMode.NEG_INF) # Convert to float32 for computation row = ct.astype(row, np.float32) @@ -125,20 +106,9 @@ def launch_softmax_kernel(input, output, TILE_SIZE=1024): n_rows, n_cols = input.shape original_n_cols = n_cols - # Pad tensors if needed to match TILE_SIZE - # This is required because ct.load expects shape to match tensor dimensions - needs_padding = n_cols < TILE_SIZE - if needs_padding: - padding = TILE_SIZE - n_cols - input_padded = torch.nn.functional.pad(input, (0, padding), value=-float('inf')) - output_padded = torch.nn.functional.pad(output, (0, padding), value=0.0) - else: - input_padded = input - output_padded = output - # Ensure tensors are contiguous - input_padded = input_padded.contiguous() - output_padded = output_padded.contiguous() + input = input.contiguous() + output = output.contiguous() NUM_SM = torch.cuda.get_device_properties(input.device).multi_processor_count occupancy = 4 # Match @ct.kernel(occupancy=4) @@ -150,18 +120,14 @@ def launch_softmax_kernel(input, output, TILE_SIZE=1024): grid, softmax_kernel, ( - output_padded, - input_padded, + output, + input, n_rows, TILE_SIZE, original_n_cols, ), ) - # Copy result back to original output tensor if padding was used - if needs_padding: - output.copy_(output_padded[:, :original_n_cols]) - def launch_softmax_kernel_tma( input, @@ -191,12 +157,6 @@ def launch_softmax_kernel_tma( # Regular TMA path (single tile per row, persistent scheduling) softmax_kernel_forward = softmax_kernel_tma - # Pad if needed - if n_cols < TILE_SIZE: - padding = TILE_SIZE - n_cols - input = torch.nn.functional.pad(input, (0, padding), value=-float('inf')) - output = torch.nn.functional.pad(output, (0, padding), value=0.0) - # Ensure tensors are contiguous input = input.contiguous() output = output.contiguous() @@ -219,10 +179,6 @@ def launch_softmax_kernel_tma( ), ) - # Trim output to original size - if original_n_cols < output.shape[-1]: - output = output[..., :original_n_cols] - class Softmax(torch.autograd.Function): @staticmethod diff --git a/tests/ops/test_softmax.py b/tests/ops/test_softmax.py index 2b972d8..0126a29 100644 --- a/tests/ops/test_softmax.py +++ b/tests/ops/test_softmax.py @@ -21,6 +21,8 @@ def reference(x): (256, 1024 * 32, torch.float32), (256, 256, torch.float16), (256, 2048, torch.float16), + (256, 9, torch.float16), + (256, 1009, torch.float16), ]) @pytest.mark.parametrize("backend", _backends) @pytest.mark.parametrize(