In [8]:
import math
import torch
from torch.nn import functional as F
from torch.utils.cpp_extension import load
import torch.nn as nn
import sys
import os

# Make the directory two levels above the current working directory importable,
# so the project-local `util` package resolves from inside this notebook.
project_root = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir, os.pardir)
)
if project_root not in sys.path:
    # Prepend so the project's packages win over same-named installed ones.
    sys.path.insert(0, project_root)
from util.test_util import test_eval

In [4]:
# Problem dimensions for the attention benchmark.
batch_size = 1
num_head = 12
head_embd = 64
start_len = 32  # number of leading sequence positions of k/v filled with random data
seq_len = 64

# Seed the generators so a fresh kernel reproduces exactly the same tensors;
# torch.manual_seed covers the CUDA generators as well as the CPU one.
torch.manual_seed(0)

# Allocate everything directly on the GPU instead of building on the host and
# copying with .cuda(), matching how the random k/v slices are generated.
device = "cuda"
full_shape = (batch_size, num_head, seq_len, head_embd)
filled_shape = (batch_size, num_head, start_len, head_embd)

q = torch.randn(full_shape, device=device)
k = torch.zeros(full_shape, device=device)
v = torch.zeros(full_shape, device=device)
# Only the first `start_len` positions receive random values; positions
# start_len..seq_len-1 of k and v stay zero.
k[:, :, :start_len, :] = torch.randn(filled_shape, device=device)
v[:, :, :start_len, :] = torch.randn(filled_shape, device=device)

In [5]:
def manual_attn(q, k, v):
    """Reference scaled dot-product attention built from plain tensor ops.

    Args:
        q: query tensor; its last dimension must match the last dimension of ``k``.
        k: key tensor; the last two dimensions are transposed before the product.
        v: value tensor; its second-to-last dimension must match that of ``k``.

    Returns:
        ``softmax(q @ k^T / sqrt(d)) @ v`` where ``d = k.size(-1)``.
    """
    # Scale by 1/sqrt(d) so the dot products do not grow with the embedding
    # width, which would otherwise saturate the softmax.
    scale = 1.0 / math.sqrt(k.size(-1))
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    # Normalize over the key axis so each query's weights sum to 1.
    weights = scores.softmax(dim=-1)
    # Weighted sum of the values.
    return torch.matmul(weights, v)

In [6]:
# JIT-build the CUDA source (path is relative to the working directory) and
# bind the resulting extension module; the first run compiles, so it is slow.
flash_attn_manual = load(
    name="flash_attn_manual",
    sources=["flash_attention.cu"],
    extra_cuda_cflags=["-O2"],
)

If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


In [9]:
# Argument tuple shared by both benchmarks. k and v are passed as copies so the
# originals stay untouched; q is passed as-is.
inputs = (q, torch.clone(k), torch.clone(v))

In [10]:
# Benchmark the pure-PyTorch reference attention with the project's
# test_eval helper (presumably a timing harness — see util.test_util).
section_title = "=== Profiling Basic Attention ==="
print(section_title)
test_eval(manual_attn, inputs)

=== Profiling Basic Attention ===
[31m100 iters, min = 0.0796 ms, max = 0.0999 ms, avg = 0.0820 ms[m


In [11]:
# Benchmark the hand-written CUDA kernel's forward entry point on the same
# inputs, using the same test_eval helper as the reference run.
section_title = "=== Profiling Flash Attention CUDA ==="
print(section_title)
flash_forward = flash_attn_manual.forward
test_eval(flash_forward, inputs)

=== Profiling Flash Attention CUDA ===
Max shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory: 49152, requested shared memory: 28672 \nMax shared memory