In [1]:
from torch.utils.cpp_extension import load
import torch

In [31]:
# Build and import the custom attention extension from its C++ binding file
# and CUDA kernel file. The CUDA compiler flags fix the target architecture to
# sm_75 and turn on -O3 and fast-math.
faster_attn = load(
    name='faster_attn',
    sources=[
        '../src/main.cpp',
        '../src/flash_attention_kernel.cu',
    ],
    extra_cuda_cflags=['-O3', '-arch=sm_75', '--use_fast_math'],
)
# The bare string below is a kept note of an alternative flag set; because it
# is the cell's last expression, its repr is what the cell displays.
'''
CFLAGS = -arch=sm_75 -O3 -lineinfo --use_fast_math --keep-device-functions -std=c++17
'''

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


'\nCFLAGS = -arch=sm_75 -O3 -lineinfo --use_fast_math --keep-device-functions -std=c++17\n'

In [32]:
from torch.utils import benchmark

In [104]:
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` using ``torch.utils.benchmark.Timer``.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f`` on every run.
        **kwargs: Keyword arguments forwarded to ``f`` on every run.

    Returns:
        The object produced by ``Timer.timeit(599)``. Note that, despite the
        function's name, this is the timer's measurement object and not a
        plain number of microseconds.
    """
    # Names the timed statement is allowed to see when it is executed.
    timer_namespace = {"args": args, "kwargs": kwargs, "f": f}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=timer_namespace)
    # 599 is the number of invocations of the statement in the measurement.
    return timer.timeit(599)

In [105]:
def gen_data(batch=4, heads=8, seq_len=128, head_dim=64,
             dtype=torch.float32, device='cuda'):
    """Create random query, key and value tensors for attention benchmarks.

    Called with no arguments this reproduces the original fixed configuration
    (4 x 8 x 128 x 64, float32, on 'cuda'); every setting can now be
    overridden, e.g. ``device='cpu'`` on a machine without a GPU.

    Args:
        batch: Batch size (B).
        heads: Number of attention heads (H).
        seq_len: Sequence length (S).
        head_dim: Per-head embedding dimension (D).
        dtype: Element type of the generated tensors.
        device: Device on which the tensors are allocated.

    Returns:
        A tuple ``(q, k, v)`` of three independently sampled standard-normal
        tensors, each of shape ``(batch, heads, seq_len, head_dim)``.
    """
    shape = (batch, heads, seq_len, head_dim)
    # torch.randn allocates a fresh, already-contiguous tensor, so the extra
    # .contiguous() call the original made was a no-op and is omitted.
    # The tensors are drawn in q, k, v order, as before.
    q, k, v = (
        torch.randn(shape, device=device, dtype=dtype) for _ in range(3)
    )
    return q, k, v

In [111]:
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn import functional as F
# Keep float32 matmuls at full float32 internal precision so every
# implementation is compared at the same precision. Per the PyTorch docs,
# "highest" is the only mode that does this; "high" and "medium" both permit
# reduced-precision (TF32 / bfloat16-based) internal computation.
torch.set_float32_matmul_precision("highest")
q, k, v = gen_data()

# NOTE: SDPBackend.FLASH_ATTENTION is intentionally not benchmarked; it was
# marked as not supported for this setup. To retry it, wrap the same
# F.scaled_dot_product_attention timing call in
# `with sdpa_kernel(SDPBackend.FLASH_ATTENTION):` and catch RuntimeError.

# Custom extension kernel.
mes = benchmark_torch_function_in_microseconds(faster_attn.forward, q, k, v)
print("Custom implementation run: ")
print(mes)

# PyTorch SDPA restricted to the memory-efficient backend. A RuntimeError is
# reported as "not supported" instead of aborting the cell.
with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    try:
        mes = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, q, k, v)
        print("The memory efficient implementation runs:")
        print(mes)
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

# PyTorch SDPA restricted to the math backend.
with sdpa_kernel(SDPBackend.MATH):
    mes = benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, q, k, v)
    print("The math implementation run:")
    print(mes)



Custom implementation run: 
<torch.utils.benchmark.utils.common.Measurement object at 0x7d26481dda50>
f(*args, **kwargs)
  84.84 us
  1 measurement, 599 runs , 1 thread
The memory efficient implementation runs:
<torch.utils.benchmark.utils.common.Measurement object at 0x7d2606758610>
f(*args, **kwargs)
  25.66 us
  1 measurement, 599 runs , 1 thread
The math implementation run:
<torch.utils.benchmark.utils.common.Measurement object at 0x7d2606758650>
f(*args, **kwargs)
  85.68 us
  1 measurement, 599 runs , 1 thread


In [36]:
# Run PyTorch's scaled_dot_product_attention once on the q, k, v tensors from
# the benchmark cell and keep the result for inspection. No sdpa_kernel context
# is active here, so backend selection is left to PyTorch.
out = F.scaled_dot_product_attention(q, k, v)

In [24]:
# Display the element type of the attention output computed above.
out.dtype

torch.float32