In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Example usage: draw three random (2, 3, 8) tensors on `device` and pass
# them to scaled_dot_product_attention as query, key and value.
query, key, value = (torch.randn(2, 3, 8, device=device) for _ in range(3))
F.scaled_dot_product_attention(query, key, value)

tensor([[[ 0.0748, -0.2111,  0.0108, -0.2959, -0.3132,  0.4317, -0.1025,
           0.2680],
         [ 0.2194,  0.7728, -0.4848, -0.2770,  0.5635,  0.5530, -0.3188,
           0.3752],
         [ 0.2338,  0.7983, -0.4732, -0.2727,  0.5849,  0.5092, -0.3505,
           0.3789]],

        [[-0.2411, -0.0841, -0.1971, -0.4740, -0.4923,  1.2068,  0.2419,
          -1.0348],
         [-0.2571, -0.0342, -0.2399, -0.1845, -0.4531,  1.0531,  0.0447,
          -0.9572],
         [-0.3720,  0.2098, -0.5014,  0.8150, -0.4192,  0.5959, -0.6231,
          -0.7724]]])

In [3]:
# Let's define a helpful benchmarking function:
import torch.utils.benchmark as benchmark


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Time ``f(*args, **kwargs)`` and report the mean runtime in microseconds.

    The call is measured with ``torch.utils.benchmark.Timer`` using
    ``blocked_autorange()``; the mean of that measurement is scaled by 1e6.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to ``f`` on every invocation.
        **kwargs: Keyword arguments forwarded to ``f`` on every invocation.

    Returns:
        The measurement's mean multiplied by 1e6 (seconds -> microseconds).
    """
    # The timed statement resolves `f`, `args` and `kwargs` from this namespace.
    namespace = {"args": args, "kwargs": kwargs, "f": f}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    measurement = timer.blocked_autorange()
    return measurement.mean * 1e6


# Let's define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

# query, key and value all use the same layout:
# (batch_size, num_heads, max_sequence_len, embed_dimension).
input_shape = (batch_size, num_heads, max_sequence_len, embed_dimension)
query, key, value = (
    torch.rand(input_shape, device=device, dtype=dtype) for _ in range(3)
)

# Baseline: time the call without selecting any specific backend.
default_runtime_us = benchmark_torch_function_in_microseconds(
    F.scaled_dot_product_attention, query, key, value
)
print(f"The default implementation runs in {default_runtime_us:.3f} microseconds")

# Let's explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper: for each backend, the sdp_kernel keyword arguments
# that switch that backend on and the other two off.
# NOTE(review): recent PyTorch releases deprecate sdp_kernel in favour of
# torch.nn.attention.sdpa_kernel -- confirm against the installed version.
all_backends_disabled = {
    "enable_math": False,
    "enable_flash": False,
    "enable_mem_efficient": False,
}
backend_map = {
    SDPBackend.MATH: {**all_backends_disabled, "enable_math": True},
    SDPBackend.FLASH_ATTENTION: {**all_backends_disabled, "enable_flash": True},
    SDPBackend.EFFICIENT_ATTENTION: {
        **all_backends_disabled,
        "enable_mem_efficient": True,
    },
}

# (backend, label used in the report line, message printed when the backend
# raises RuntimeError). A message of None means the error is not handled here
# and propagates to the caller.
backend_reports = (
    (SDPBackend.MATH, "math", None),
    (
        SDPBackend.FLASH_ATTENTION,
        "flash attention",
        "FlashAttention is not supported. See warnings for reasons.",
    ),
    (
        SDPBackend.EFFICIENT_ATTENTION,
        "memory efficient",
        "EfficientAttention is not supported. See warnings for reasons.",
    ),
)

for backend, label, unsupported_message in backend_reports:
    with sdp_kernel(**backend_map[backend]):
        try:
            runtime_us = benchmark_torch_function_in_microseconds(
                F.scaled_dot_product_attention, query, key, value
            )
            print(f"The {label} implementation runs in {runtime_us:.3f} microseconds")
        except RuntimeError:
            if unsupported_message is None:
                # No fallback message for this backend: let the error surface.
                raise
            print(unsupported_message)

The default implementation runs in 56866768.684 microseconds
