In [1]:
import time
import torch
import torch.nn.functional as F
import os
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 presumably makes kernel launches synchronous so that
# CUDA errors surface at the failing call — confirm this is also wanted for the timing cells below,
# since it may affect the measured numbers.
os.environ['CUDA_LAUNCH_BLOCKING']='1'

In [2]:
from torch.utils.cpp_extension import load_inline
from torch.profiler import profile, record_function, ProfilerActivity
def load_cuda(cuda_src, cpp_src, funcs, opt=False, verbose=False):
    """Build an inline extension from one CUDA source string and one C++ source string.

    Args:
        cuda_src: CUDA source code, passed to ``load_inline`` as the only CUDA source.
        cpp_src: C++ source code (declarations), passed as the only C++ source.
        funcs: names of the functions to expose, forwarded as ``functions``.
        opt: when truthy, ``-O3`` is passed as the only extra CUDA compiler flag.
        verbose: forwarded unchanged to ``load_inline``.

    Returns:
        Whatever ``load_inline`` returns for the extension named ``inline_ext``.
    """
    nvcc_flags = []
    if opt:
        nvcc_flags.append("-O3")

    return load_inline(
        name="inline_ext",
        cpp_sources=[cpp_src],
        cuda_sources=[cuda_src],
        functions=funcs,
        with_cuda=True,
        extra_cuda_cflags=nvcc_flags,
        verbose=verbose,
    )


In [6]:
# Baseline: time PyTorch's built-in softmax over the last dimension of a
# 1024 x 32768 float32 matrix that lives on the GPU.
matrix = torch.randn(1024, 32768, device='cuda', dtype=torch.float32)

# Warm-up call, kept outside the timed loop.
_ = torch.nn.functional.softmax(matrix, dim=-1)

# Ensure all CUDA operations are finished before timing starts.
torch.cuda.synchronize()

total_time = 0
n_iters = 5

for i in range(n_iters):
    # Synchronize before reading the clock so pending GPU work is not counted.
    torch.cuda.synchronize()
    # perf_counter() is monotonic and high-resolution; time.time() is a wall
    # clock that can jump and is not meant for measuring short intervals.
    start = time.perf_counter()
    _ = torch.nn.functional.softmax(matrix, dim=-1)
    # Synchronize again so the interval covers the full kernel execution.
    torch.cuda.synchronize()
    end = time.perf_counter()

    total_time += (end - start) * 1000
    # Prints the running total in milliseconds, not the per-iteration time.
    print(total_time)

print(f"Softmax computation time (average): {(total_time/n_iters):.3f} ms")

1.192331314086914
2.275705337524414
3.372669219970703
4.453182220458984
5.533456802368164
Softmax computation time (average): 1.107 ms


In [11]:
# Read the CUDA kernels from disk; the context manager closes the file handle
# instead of leaving it open until garbage collection.
with open("softmax.cu") as cuda_file:
    cuda_src = cuda_file.read()

# C++ declarations for the functions exposed to Python.
cpp_src = """
torch::Tensor naive_softmax(torch::Tensor input);
torch::Tensor online_normalizer_softmax(torch::Tensor input);
torch::Tensor share_memory_softmax(torch::Tensor input);
torch::Tensor warp_shuffle_softmax(torch::Tensor input);

"""
# Must list the same function names that are declared in cpp_src.
funcs = ["naive_softmax", "online_normalizer_softmax", "share_memory_softmax", "warp_shuffle_softmax"]
ext = load_cuda(cuda_src, cpp_src, funcs)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [12]:
# Sanity check: each row of a softmax output should sum to 1.
out = ext.share_memory_softmax(matrix)
out.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000], device='cuda:0')

In [22]:
# Time the custom share_memory_softmax kernel on the same matrix, using the
# same host-side timing procedure as the PyTorch baseline.
total_time = 0
n_iters = 5

for i in range(n_iters):
    # Synchronize before reading the clock so pending GPU work is not counted.
    torch.cuda.synchronize()
    # perf_counter() is monotonic and high-resolution; time.time() is a wall
    # clock that can jump and is not meant for measuring short intervals.
    start = time.perf_counter()
    _ = ext.share_memory_softmax(matrix)
    # Synchronize again so the interval covers the full kernel execution.
    torch.cuda.synchronize()
    end = time.perf_counter()

    total_time += (end - start) * 1000
    # Prints the running total in milliseconds, not the per-iteration time.
    print(total_time)

print(f"Softmax computation time (average): {(total_time/n_iters):.3f} ms")

1.3437271118164062
2.6297569274902344
3.8673877716064453
5.094289779663086
6.294965744018555
Softmax computation time (average): 1.259 ms
