In [1]:
import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(
    output_ptr, input_ptr,
    M, N,  # M rows, N columns (M itself is not read inside the kernel)
    input_row_stride, input_col_stride,  # Input strides, in elements
    output_row_stride, output_col_stride,  # Output strides, in elements
    BLOCK_SIZE: tl.constexpr,  # Tile width in columns; power of 2 and must be >= N
):
    """Softmax over one row of an (M, N) matrix per program.

    Program ``i`` (``tl.program_id(0)``) loads columns ``[0, BLOCK_SIZE)`` of
    row ``i`` as a single tile, masks off columns ``>= N``, and writes the
    normalized row through ``output_ptr`` using the output strides. Columns at
    index ``>= BLOCK_SIZE`` are never read or written, so the launcher must
    choose ``BLOCK_SIZE >= N``.
    """
    # Widen to int64 before multiplying by strides: with int32 indices the
    # element offset overflows once it exceeds 2**31 - 1 (large tensors, or
    # transposed views whose column stride is large).
    row_idx = tl.program_id(0).to(tl.int64)

    col_offsets = tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the row (when N is not a power of 2) are masked out.
    mask = col_offsets < N
    col_idx = col_offsets.to(tl.int64)

    # Addressing through both strides lets the same code serve contiguous,
    # transposed and sliced inputs.
    input_ptrs = input_ptr + row_idx * input_row_stride + col_idx * input_col_stride
    # Masked lanes read -inf: they cannot win the max and add exp(-inf) = 0 to the sum.
    row = tl.load(input_ptrs, mask=mask, other=-float('inf'))
    # Promote to at least float32 through Triton's binary-op type promotion
    # (fp16/bf16 -> fp32, fp32 stays fp32, fp64 stays fp64) so exp/sum are not
    # evaluated in half precision.
    row = row + tl.zeros([BLOCK_SIZE], dtype=tl.float32)

    # Numerically stable softmax: shift by the row max before exponentiating.
    numerator = tl.exp(row - tl.max(row, axis=0))
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    output_ptrs = output_ptr + row_idx * output_row_stride + col_idx * output_col_stride
    # tl.store converts the values to the output pointer's element type.
    tl.store(output_ptrs, softmax_output, mask=mask)

class Softmax(torch.autograd.Function):
    """Softmax over the last dimension of a 2-D tensor, computed by a Triton kernel.

    ``forward`` launches ``softmax_kernel`` with one program per row.
    ``backward`` uses plain PyTorch ops on the saved forward output, so the op
    can take part in autograd.
    """

    @staticmethod
    def forward(ctx, x):
        """Return ``softmax(x, dim=-1)`` for a 2-D tensor ``x`` of shape (M, N).

        The result is allocated with ``torch.empty_like(x)`` (same shape, dtype
        and device). Input and output strides are passed to the kernel, so
        non-contiguous views such as transposes and strided slices are accepted.

        Raises:
            ValueError: if ``x`` is not 2-D.
        """
        if x.dim() != 2:
            raise ValueError(f"Softmax expects a 2-D tensor, got shape {tuple(x.shape)}")
        M, N = x.shape
        output = torch.empty_like(x)

        if x.numel() > 0:
            # PyTorch strides are counted in elements, not bytes, which is what
            # the kernel's pointer arithmetic expects.
            input_row_stride, input_col_stride = x.stride()
            output_row_stride, output_col_stride = output.stride()

            # The kernel handles each row as ONE tile of BLOCK_SIZE columns and
            # only touches columns [0, BLOCK_SIZE), so BLOCK_SIZE must be >= N.
            # Capping it below N would silently normalize over a prefix of the
            # row and leave the tail of `output` uninitialized. tl.arange needs
            # a power-of-2 length; 64 is kept as the minimum tile width.
            BLOCK_SIZE = max(triton.next_power_of_2(N), 64)

            # Wider tiles get more warps to limit per-thread register pressure
            # (heuristic from the Triton fused-softmax tutorial).
            num_warps = 4
            if BLOCK_SIZE >= 2048:
                num_warps = 8
            if BLOCK_SIZE >= 4096:
                num_warps = 16

            # One program per row.
            softmax_kernel[(M,)](
                output, x,
                M, N,
                input_row_stride, input_col_stride,
                output_row_stride, output_col_stride,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )

        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient of row-wise softmax: ``dx = y * (dy - sum(dy * y, dim=-1))``."""
        (y,) = ctx.saved_tensors
        return y * (grad_output - (grad_output * y).sum(dim=-1, keepdim=True))

In [3]:
# Test cases
def test_softmax_triton(x, atol=None, rtol=None):
    """Check ``Softmax.apply(x)`` against ``torch.softmax`` over the last dimension.

    ``x`` is moved to CUDA if needed. The reference is computed in float32, and
    the Triton result is upcast to float32 before comparing, so float16 and
    bfloat16 inputs can be checked as well as float32.

    Args:
        x: 2-D tensor to test.
        atol, rtol: comparison tolerances. When omitted they are chosen from
            ``x.dtype``: 1e-3 for float16, 1e-2 for bfloat16, 1e-5 otherwise.

    Raises:
        AssertionError: if the two results differ beyond the tolerances.
    """
    # Ensure x is on CUDA
    if x.device.type != 'cuda':
        x = x.cuda()
    print(f"\nTesting softmax on tensor of shape {x.shape}, strides {x.stride()}")
    triton_output = Softmax.apply(x)
    torch_output = torch.softmax(x.float(), dim=-1)  # Softmax over the last dimension

    # Low-precision outputs are rounded when stored, so they need looser tolerances.
    default_tol = {torch.float16: 1e-3, torch.bfloat16: 1e-2}.get(x.dtype, 1e-5)
    atol = default_tol if atol is None else atol
    rtol = default_tol if rtol is None else rtol

    # torch.allclose requires both tensors to share a dtype, so compare in float32.
    assert torch.allclose(triton_output.float(), torch_output, atol=atol, rtol=rtol), \
        f"Mismatch!\nTriton:\n{triton_output}\nTorch:\n{torch_output}"
    print("Match!")

# Seed the RNG so every run of this cell tests the same inputs.
torch.manual_seed(0)

# 1. Row-major (default PyTorch layout): strides (N, 1).
# Softmax over the last dimension (columns) reads each row contiguously.
x_row_major = torch.randn(128, 512, device='cuda', dtype=torch.float32)
test_softmax_triton(x_row_major)

# 2. Transposed (column-major) view.
# To take a softmax over dim 0 of a (512, 128) matrix, transpose it so that
# dimension becomes the last one. `.T` does NOT copy: the result is a
# (128, 512) view with strides (1, 128), i.e. column-major and non-contiguous.
# The kernel still handles it because it addresses elements through the strides.
x_col_major_logical = torch.randn(512, 128, device='cuda', dtype=torch.float32)
test_softmax_triton(x_col_major_logical.T)

# 3. Non-contiguous view from slicing: taking every other column of a
# (256, 256) tensor gives a (256, 128) view with strides (256, 2).
x_sliced = torch.randn(256, 256, device='cuda', dtype=torch.float32)[:, ::2]
print(f"\nOriginal sliced tensor strides: {x_sliced.stride()}")
test_softmax_triton(x_sliced)

# 4. Permuted 3-D tensor: (32, 64, 128) -> permute(0, 2, 1) -> (32, 128, 64).
# The kernel only accepts 2-D input, so the leading dims are flattened. Because
# the permuted tensor is non-contiguous, `reshape` has to copy it, so the kernel
# actually receives a contiguous (4096, 64) tensor with strides (64, 1).
# A softmax over the middle dimension (128) would instead need another permute
# or a kernel that reduces along a different stride.
x_permuted = torch.randn(32, 64, 128, device='cuda', dtype=torch.float32).permute(0, 2, 1)
test_softmax_triton(x_permuted.reshape(-1, x_permuted.shape[-1]))


Testing softmax on tensor of shape torch.Size([128, 512]), strides (512, 1)
Match!

Testing softmax on tensor of shape torch.Size([128, 512]), strides (1, 128)
Match!

Original sliced tensor strides: (256, 2)

Testing softmax on tensor of shape torch.Size([256, 128]), strides (256, 2)
Match!

Testing softmax on tensor of shape torch.Size([4096, 64]), strides (64, 1)
Match!
