In [1]:
import torch

In [2]:
import custom_softmax_cuda

In [3]:
custom_softmax_cuda.safe_softmax(torch.tensor([[1.0, 3.0], [2.0, 2.0], [3.0, 1.0]], device="cuda:0", dtype=torch.float32))

tensor([[0.1192, 0.8808],
        [0.5000, 0.5000],
        [0.8808, 0.1192]], device='cuda:0')

In [4]:
# Test cases
def test_softmax_cuda(x):
    # Ensure x is on CUDA
    if x.device.type != 'cuda':
        x = x.cuda()
    print(f"\nTesting softmax on tensor of shape {x.shape}, strides {x.stride()}")
    cuda_output = custom_softmax_cuda.safe_softmax(x)
    torch_output = torch.softmax(x.float(), dim=-1) # Softmax over the last dimension

    # Compare
    if torch.allclose(cuda_output, torch_output, atol=1e-5):
        print("Match!")
    else:
        print("Mismatch!")


# 1. Row-major (default PyTorch)
# Softmax over the last dimension (columns) is natural for row-major.
x_row_major = torch.randn(128, 512, device='cuda', dtype=torch.float32)
test_softmax_cuda(x_row_major)

# 2. Column-major for the 'row' dimension (i.e., transposed from default)
# If you want to perform softmax over the 'rows' of a logically column-major matrix,
# you should transpose it first to make the 'rows' the last dimension in PyTorch.
# Or, if your original matrix is (M, N) and you want softmax over dim 0 (rows),
# you can transpose it to (N, M) and then apply softmax over dim 1.
x_col_major_logical = torch.randn(512, 128, device='cuda', dtype=torch.float32)
# To apply softmax over dim 0 (rows) of x_col_major_logical,
# you would effectively transpose it to make the 'rows' the last dimension
# for the Triton kernel which expects softmax over the last dim.
# So, x_col_major_logical.T becomes (128, 512) and is row-major.
test_softmax_cuda(x_col_major_logical.T)

# Another example: a non-contiguous tensor due to slicing
x_sliced = torch.randn(256, 256, device='cuda', dtype=torch.float32)[:, ::2] # x_sliced is (256, 128) but non-contiguous
print(f"\nOriginal sliced tensor strides: {x_sliced.stride()}")
# The current kernel works fine with non-contiguous strides as it uses the provided strides.
test_softmax_cuda(x_sliced)

# A more complex permutation
x_permuted = torch.randn(8, 64, 8, device='cuda', dtype=torch.float32).permute(0, 2, 1) # (8, 32, 256)
# To apply softmax over the last dimension (64), this is fine.
# If we wanted softmax over the middle dimension (128), we'd need to permute again
# or re-design the kernel to iterate over a different stride.
test_softmax_cuda(x_permuted.reshape(-1, x_permuted.shape[-1])) # Flatten to 2D for the current kernel


Testing softmax on tensor of shape torch.Size([128, 512]), strides (512, 1)
Match!

Testing softmax on tensor of shape torch.Size([128, 512]), strides (1, 128)
Match!

Original sliced tensor strides: (256, 2)

Testing softmax on tensor of shape torch.Size([256, 128]), strides (256, 2)
Match!

Testing softmax on tensor of shape torch.Size([64, 64]), strides (64, 1)
Match!


In [6]:
from torch.utils.cpp_extension import load

In [7]:
jit_softmax = load("jit_softmax", sources=["online_softmax.cu", "safe_softmax.cu", "binding.cpp"])

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [10]:
torch.sum(jit_softmax.safe_softmax(torch.randn(512, 128, device='cuda', dtype=torch.float32)), 1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 