In [35]:
from torch.utils.cpp_extension import load
import torch

In [36]:
# JIT-compile the C++ binding and the CUDA kernel, then import the result as a
# Python module named `faster_attn`.
# NOTE(review): `-arch=sm_70` pins nvcc to compute capability 7.0; adjust it (or
# set TORCH_CUDA_ARCH_LIST) when building for a different GPU generation.
faster_attn = load(
    name='faster_attn',
    sources=[
        '../src/main.cpp',
        '../src/flash_attention_kernel.cu',
    ],
    extra_cuda_cflags=['-O3', '-arch=sm_70'],
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [37]:
from torch.utils import benchmark

In [38]:
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Measure how long ``f(*args, **kwargs)`` takes, in microseconds.

    The call is timed with ``torch.utils.benchmark.Timer`` via
    ``blocked_autorange()``; the mean of that measurement is scaled by 1e6
    before being returned.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f``.
        **kwargs: Keyword arguments forwarded to ``f``.

    Returns:
        The mean measured time multiplied by 1e6.
    """
    # The statement string is evaluated inside this namespace by the Timer.
    namespace = {"args": args, "kwargs": kwargs, "f": f}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6


In [39]:
def gen_data(batch=4, heads=8, seq_len=128, head_dim=64,
             dtype=torch.float32, device='cuda', seed=None):
    """Create random query, key and value tensors for attention benchmarks.

    Called with no arguments this produces the same configuration as before:
    three float32 CUDA tensors of shape ``(4, 8, 128, 64)``.

    Args:
        batch: Batch size (B).
        heads: Number of attention heads (H).
        seq_len: Sequence length (S).
        head_dim: Per-head embedding dimension (D).
        dtype: Element type of the generated tensors.
        device: Device on which the tensors are allocated.
        seed: Optional integer seed. When given, a dedicated generator is
            seeded with it so repeated calls return identical data; when
            ``None`` the global RNG state is used.

    Returns:
        A tuple ``(q, k, v)`` of tensors, each of shape
        ``(batch, heads, seq_len, head_dim)``.
    """
    shape = (batch, heads, seq_len, head_dim)

    generator = None
    if seed is not None:
        # A private generator keeps seeding from disturbing the global RNG.
        generator = torch.Generator(device=device).manual_seed(seed)

    # Freshly sampled tensors are already contiguous, so no extra
    # `.contiguous()` call is needed. Draw order is q, then k, then v.
    q = torch.randn(shape, generator=generator, dtype=dtype, device=device)
    k = torch.randn(shape, generator=generator, dtype=dtype, device=device)
    v = torch.randn(shape, generator=generator, dtype=dtype, device=device)
    return q, k, v

In [40]:
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F
# "highest" keeps float32 matmuls in full float32 precision. The other two
# settings, "high" and "medium", permit reduced-precision internal computation,
# which would make the float32 comparison below unfair.
torch.set_float32_matmul_precision("highest")
q, k, v = gen_data()

# Flash attention not supported
# NOTE(review): presumably because the inputs are float32 -- confirm against the
# warnings PyTorch emits when this block is enabled.
# with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
#     try:
#         flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, q, k, v)
#         print(f"The flash attention implementation runs in {flash_time:.3f} microseconds")
#     except RuntimeError:
#         print("FlashAttention is not supported. See warnings for reasons.")

# Custom CUDA extension. The result gets its own name so the memory-efficient
# measurement below does not overwrite it.
custom_time = benchmark_torch_function_in_microseconds(faster_attn.forward, q, k, v)
print(f"Custom implementation runs in {custom_time:.3f} microseconds")

# PyTorch memory-efficient SDPA backend.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        efficient_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, q, k, v)
        print(f"The memory efficient implementation runs in {efficient_time:.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

# PyTorch math SDPA backend, used as the baseline.
with sdpa_kernel(SDPBackend.MATH):
    math_time = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, q, k, v)
    print(f"The math implementation runs in {math_time:.3f} microseconds")



Custom implementation runs in 90.067 microseconds
The memory efficient implementation runs in 25.894 microseconds
The math implementation runs in 82.830 microseconds


In [22]:
# Run PyTorch's scaled_dot_product_attention on the benchmark inputs outside any
# sdpa_kernel context, and keep the result for inspection.
out = F.scaled_dot_product_attention(q, k, v)

In [24]:
# Display the dtype of the attention output (gen_data creates float32 inputs).
out.dtype

torch.float32