In [1]:
import time
import torch
import torch.nn.functional as F

In [2]:
# Benchmark problem size. q/k/v are laid out as (batch, heads, seq_len, head_dim),
# the layout F.scaled_dot_product_attention consumes.
bz = 32
seq_len = 2048
dims = 64  # per-head dimension
n_heads = 8

# Seed every device's RNG so inputs (and dropout masks) are reproducible
# under Restart & Run All.
torch.manual_seed(0)

# Allocate directly on the GPU. `torch.randn(...).cuda()` would first build each
# 64 MiB fp16 tensor on the host and then copy it to the device.
q, k, v = (
    torch.randn(bz, n_heads, seq_len, dims, dtype=torch.float16, device="cuda")
    for _ in range(3)
)

In [3]:
# Sanity-check the inputs before benchmarking. The cells below assume q, k and v
# share one (batch, heads, seq_len, head_dim) layout, and the flash backend
# only accepts fp16/bf16 inputs.
assert q.shape == k.shape == v.shape == (bz, n_heads, seq_len, dims)
assert q.dtype == k.dtype == v.dtype == torch.float16
q.shape

torch.Size([32, 8, 2048, 64])

In [4]:
# Benchmark knobs shared by both attention variants:
#   dropout_rate - dropout probability applied to the attention weights
#   num_trials   - number of timed iterations per variant
dropout_rate, num_trials = 0.2, 10

In [7]:
# Standard (explicit "math") attention vs. FlashAttention via
# F.scaled_dot_product_attention. Each variant gets untimed warm-up calls first,
# so one-off costs (kernel/cuBLAS initialisation, caching-allocator growth) are
# not billed to the timed loop.
# NOTE(review): dropout is active, so the two outputs are not numerically
# comparable; only the timings are.


def standard_attention(query, key, value, p):
    """Explicit attention: softmax(Q K^T / sqrt(d)) -> dropout -> @ V.

    The 1/sqrt(d) factor matches F.scaled_dot_product_attention's default
    scale, so both benchmarked paths compute the same operation.
    """
    # Scale q (seq x d) rather than the (seq x seq) score matrix. This is
    # mathematically equivalent and avoids another full-size intermediate.
    scale = query.size(-1) ** -0.5
    weights = (query * scale) @ key.transpose(-2, -1)
    weights = weights.softmax(dim=-1)
    weights = F.dropout(weights, p=p, training=True)
    return weights @ value


def cuda_timeit(fn, n_trials, n_warmup=3):
    """Call `fn` n_warmup times untimed, then n_trials times timed.

    Returns (elapsed_seconds, last_result); last_result is None when
    n_trials == 0. torch.cuda.synchronize() brackets the timed loop because
    CUDA kernels are launched asynchronously.
    """
    for _ in range(n_warmup):
        fn()
    result = None
    torch.cuda.synchronize()
    start = time.perf_counter()  # monotonic, high-resolution clock for benchmarks
    for _ in range(n_trials):
        result = fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start, result


# standard attention
std_seconds, x = cuda_timeit(lambda: standard_attention(q, k, v, dropout_rate), num_trials)
x = x.transpose(1, 2)  # a view (no copy); keeps `x` in the same layout as before
print(f"Standard Attention {std_seconds} seconds")


# flash attention
# Restrict SDPA to the FlashAttention backend only. The legacy
# torch.backends.cuda.sdp_kernel is deprecated on newer PyTorch, and there it
# also exposes an enable_cudnn flag (default True) that the math/mem-efficient
# switches alone do not turn off.
try:
    from torch.nn.attention import SDPBackend, sdpa_kernel

    flash_only = sdpa_kernel(SDPBackend.FLASH_ATTENTION)
except ImportError:  # older PyTorch without torch.nn.attention.sdpa_kernel
    flash_only = torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False
    )

with flash_only:
    flash_seconds, out = cuda_timeit(
        lambda: F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_rate),
        num_trials,
    )
print(f"Flash Attention took {flash_seconds} seconds")

Standard Attention 0.18029260635375977 seconds
Flash Attention took 0.05571484565734863 seconds
