I wrote a helper that lets someone use cuDNN attention from PyTorch seamlessly.
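For reference, this is roughly how I expect it to be used (the sizes here are just illustrative, not the ones from the benchmark below):

attn = make_cudnn_autograd(num_heads=16, head_dim=64, dtype=torch.bfloat16)

B, N, L = 2, 1024, 128
q = torch.randn(B, N, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
kv = torch.randn(B, N + L, 2, 16, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
seqlens_kv = torch.full((B,), N + L, dtype=torch.int32, device="cuda")

out = attn.apply(B, N, L, q, kv, seqlens_kv)  # out comes back as (B, H, N, D)
out.sum().backward()                          # q.grad is (B, N, H, D), kv.grad is (B, N + L, 2, H, D)

The helper itself: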
import cudnn
import torch
import math

# export CUDNN_FRONTEND_LOG_FILE=fe.log
# export CUDNN_FRONTEND_LOG_INFO=1
# import os
# os.environ["CUDNN_FRONTEND_LOG_FILE"] = "fe.log"
# os.environ["CUDNN_FRONTEND_LOG_INFO"] = "1"


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")
def make_cudnn_autograd(*, num_heads, head_dim, dtype):
    assert dtype in [torch.float16, torch.bfloat16], f"Invalid dtype {dtype}"
    dtype = convert_to_cudnn_type(dtype)
    # Match the naming in cuDNN's docs
    H, D = num_heads, head_dim
    del num_heads, head_dim

    cache = {}

    def init_or_check_tensor_attrs(tensor_name, tensor):
        # Cache shape/stride/dtype/device on first use; on later calls, verify
        # the incoming tensor still matches (cuDNN plans are shape-specialized).
        nonlocal cache
        for attr in ['shape', 'stride', 'dtype', 'device']:
            key = f'{tensor_name}_{attr}'
            if key not in cache:
                cache[key] = getattr(tensor, attr)
                if callable(cache[key]):
                    cache[key] = cache[key]()
            else:
                v = getattr(tensor, attr)
                v = v() if callable(v) else v
                assert cache[key] == v, f"Expected {cache[key]} but got {v}"
    class CuDNNAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, B, N, L, q, kv, seqlens_kv):
            assert q.shape == (B, N, H, D)
            assert kv.shape == (B, N + L, 2, H, D)
            assert seqlens_kv.shape == (B,)
            # CuDNN plans are compiled for a specific shape, stride, dtype,
            # so we need to verify those attributes
            init_or_check_tensor_attrs('q', q)
            init_or_check_tensor_attrs('kv', kv)
            init_or_check_tensor_attrs('seqlens_kv', seqlens_kv)
            q = q.permute(0, 2, 1, 3)  # B N H D -> B H N D
            kv_view = kv.permute(2, 0, 3, 1, 4)  # B S KV H D -> KV B H S D
            k_view, v_view = torch.unbind(kv_view, dim=0)
            assert not k_view.is_contiguous() and not v_view.is_contiguous(), "kv should not be contiguous (unnecessary copy)"
            assert k_view.shape == (B, H, (N + L), D), f"Got shape {k_view.shape} instead of {(B, H, (N + L), D)}"
            assert v_view.shape == (B, H, (N + L), D)
            # TODO: Is this safe?
            if 'stats' not in cache:
                cache['stats'] = torch.empty(B, H, N, 1, dtype=torch.float32, device=q.device)
                cache['seqlens_q'] = torch.tensor([N] * B, device=q.device, dtype=torch.int32).view(B, 1, 1, 1)
                cache['o'] = torch.empty_like(q)
            stats = cache['stats']
            seqlens_q = cache['seqlens_q']
            o = cache['o']
            seqlens_kv = seqlens_kv.view(B, 1, 1, 1)
            if 'compiled_graph_fwd' not in cache:
                print("Compiling CuDNN graphs ...")
                g_fwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cache['name_to_cu_tensor'] = {
                    'q_cu': g_fwd.tensor_like(q.detach()),
                    'k_cu': g_fwd.tensor_like(k_view.detach()),
                    'v_cu': g_fwd.tensor_like(v_view.detach()),
                    'seqlens_q_cu': g_fwd.tensor_like(seqlens_q.detach()),
                    'seqlens_kv_cu': g_fwd.tensor_like(seqlens_kv.detach())
                }
                cu_tens = cache['name_to_cu_tensor']
                o_forward, stats_forward = g_fwd.sdpa(
                    name="sdpa",
                    q=cu_tens['q_cu'],
                    k=cu_tens['k_cu'],
                    v=cu_tens['v_cu'],
                    is_inference=False,
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu'],
                    seq_len_kv=cu_tens['seqlens_kv_cu']
                )
                o_forward.set_output(True).set_dim(o.shape).set_stride(o.stride()).set_data_type(dtype)
                stats_forward.set_output(True).set_data_type(cudnn.data_type.FLOAT).set_dim(stats.shape).set_stride(stats.stride())
                cu_tens['o_forward_cu'] = o_forward
                cu_tens['stats_forward_cu'] = stats_forward

                def assert_cudnn_shape(tensor, expected_shape):
                    assert tuple(tensor.get_dim()) == expected_shape, f"Expected shape {expected_shape} but got {tensor.get_dim()}"

                assert_cudnn_shape(cu_tens['q_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_forward_cu'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_forward_cu'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu'], (B, 1, 1, 1))

                g_fwd.validate()
                g_fwd.build_operation_graph()
                g_fwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_fwd.check_support()
                g_fwd.build_plans()
                cache['compiled_graph_fwd'] = g_fwd
                g_bwd = cudnn.pygraph(
                    io_data_type=dtype,
                    intermediate_data_type=cudnn.data_type.FLOAT,
                    compute_data_type=cudnn.data_type.FLOAT,
                )
                cu_tens['q_cu_bwd'] = g_bwd.tensor_like(q.detach())
                cu_tens['k_cu_bwd'] = g_bwd.tensor_like(k_view.detach())
                cu_tens['v_cu_bwd'] = g_bwd.tensor_like(v_view.detach())
                cu_tens['o_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['dO_cu_bwd'] = g_bwd.tensor_like(o.detach())
                cu_tens['stats_cu_bwd'] = g_bwd.tensor_like(stats.detach())
                cu_tens['seqlens_q_cu_bwd'] = g_bwd.tensor_like(seqlens_q.detach())
                cu_tens['seqlens_kv_cu_bwd'] = g_bwd.tensor_like(seqlens_kv.detach())
                dQ_bwd_cu, dK_bwd_cu, dV_bwd_cu = g_bwd.sdpa_backward(
                    name="sdpa_backward",
                    q=cu_tens['q_cu_bwd'],
                    k=cu_tens['k_cu_bwd'],
                    v=cu_tens['v_cu_bwd'],
                    o=cu_tens['o_cu_bwd'],
                    dO=cu_tens['dO_cu_bwd'],
                    stats=cu_tens['stats_cu_bwd'],
                    attn_scale=1.0 / math.sqrt(D),
                    use_causal_mask=False,
                    use_padding_mask=True,
                    seq_len_q=cu_tens['seqlens_q_cu_bwd'],
                    seq_len_kv=cu_tens['seqlens_kv_cu_bwd']
                )
                # TODO: Is this safe?
                # cache['dQ'] = torch.empty_like(q).contiguous()
                # cache['dK'] = torch.empty_like(k_view).contiguous()
                # cache['dV'] = torch.empty_like(v_view).contiguous()
                cache['dQ'] = torch.empty_like(q)
                cache['dK'] = torch.empty_like(k_view)
                cache['dV'] = torch.empty_like(v_view)
                dQ_bwd_cu.set_output(True).set_dim(cache['dQ'].size()).set_stride(cache['dQ'].stride())
                dK_bwd_cu.set_output(True).set_dim(cache['dK'].size()).set_stride(cache['dK'].stride())
                dV_bwd_cu.set_output(True).set_dim(cache['dV'].size()).set_stride(cache['dV'].stride())
                cu_tens['dQ_cu_bwd'] = dQ_bwd_cu
                cu_tens['dK_cu_bwd'] = dK_bwd_cu
                cu_tens['dV_cu_bwd'] = dV_bwd_cu

                assert_cudnn_shape(cu_tens['q_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['k_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['v_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dQ_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dK_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['dV_cu_bwd'], (B, H, N + L, D))
                assert_cudnn_shape(cu_tens['o_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['dO_cu_bwd'], (B, H, N, D))
                assert_cudnn_shape(cu_tens['stats_cu_bwd'], (B, H, N, 1))
                assert_cudnn_shape(cu_tens['seqlens_q_cu_bwd'], (B, 1, 1, 1))
                assert_cudnn_shape(cu_tens['seqlens_kv_cu_bwd'], (B, 1, 1, 1))

                g_bwd.validate()
                g_bwd.build_operation_graph()
                g_bwd.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
                g_bwd.check_support()
                g_bwd.build_plans()
                cache['compiled_graph_bwd'] = g_bwd

                # TODO: Is this safe?
                cache['workspace'] = torch.empty(
                    max(g_fwd.get_workspace_size(), g_bwd.get_workspace_size()),
                    device=q.device, dtype=torch.uint8
                )
            name_to_cu_tensor = cache['name_to_cu_tensor']
            variant_pack_forward = {
                name_to_cu_tensor[name]: tensor for name, tensor in [
                    ('q_cu', q),
                    ('k_cu', k_view),
                    ('v_cu', v_view),
                    ('o_forward_cu', o),
                    ('stats_forward_cu', stats),
                    ('seqlens_q_cu', seqlens_q),
                    ('seqlens_kv_cu', seqlens_kv)
                ]
            }
            cache['compiled_graph_fwd'].execute(variant_pack_forward, cache['workspace'])
            ctx.save_for_backward(q, k_view, v_view, o, stats, seqlens_kv)
            ctx.B, ctx.N, ctx.L = B, N, L
            ctx.dtype = dtype
            return o
        @staticmethod
        def backward(ctx, grad_output):
            q, k, v, o, stats, seqlens = ctx.saved_tensors
            B, N, L = ctx.B, ctx.N, ctx.L
            seqlens_q = cache['seqlens_q']
            cu_tens = cache['name_to_cu_tensor']
            assert tuple(grad_output.shape) == (B, H, N, D)
            assert tuple(grad_output.shape) == tuple(cu_tens['dO_cu_bwd'].get_dim())
            # For batch size 1, the stride can have two 1s in it; I think this is a PyTorch bug
            # https://discuss.pytorch.org/t/stride-has-2-1s-in-it/208036
            assert tuple(grad_output.stride())[1:] == tuple(cu_tens['dO_cu_bwd'].get_stride())[1:], f"{tuple(cu_tens['dO_cu_bwd'].get_stride())} (expected) != {tuple(grad_output.stride())} (actual) for shape {tuple(grad_output.shape)}"
            assert convert_to_cudnn_type(grad_output.dtype) == cu_tens['dO_cu_bwd'].get_data_type()
            variant_pack_backward = {
                cu_tens[name]: tensor for name, tensor in [
                    ('dQ_cu_bwd', cache['dQ']),
                    ('dK_cu_bwd', cache['dK']),
                    ('dV_cu_bwd', cache['dV']),
                    ('q_cu_bwd', q),
                    ('k_cu_bwd', k),
                    ('v_cu_bwd', v),
                    ('o_cu_bwd', o),
                    ('dO_cu_bwd', grad_output),
                    ('stats_cu_bwd', stats),
                    ('seqlens_q_cu_bwd', seqlens_q),
                    ('seqlens_kv_cu_bwd', seqlens)
                ]
            }
            cache['compiled_graph_bwd'].execute(variant_pack_backward, cache['workspace'])
            assert cache['dQ'].shape == (B, H, N, D)
            dQ = cache['dQ'].permute(0, 2, 1, 3)  # B H N D -> B N H D
            assert cache['dK'].shape == (B, H, N + L, D)
            assert cache['dV'].shape == (B, H, N + L, D)
            dKV = torch.stack([cache['dK'], cache['dV']], dim=2)
            assert dKV.shape == (B, H, 2, N + L, D)
            dKV = dKV.permute(0, 3, 2, 1, 4)  # B H 2 S D -> B S 2 H D, where S = N + L
            return None, None, None, dQ, dKV, None

    return CuDNNAttention
However, while this gets better forward-pass performance, it gets far worse backward-pass performance. Any thoughts on why this might be the case? I'm hoping there might be some obvious deficiency in my code.

(All timings below are in ms.)
attention-forward-performance:
   batch_size     CuDNN  FlashAttention
0         1.0  0.022976        0.033024
1         2.0  0.021664        0.039456
2         4.0  0.047680        0.058112
3         6.0  0.056800        0.072208
attention-backward-performance:
   batch_size     CuDNN  FlashAttention
0         2.0  0.386144        0.282272
1         4.0  0.741664        0.301184
2         6.0  1.108608        0.464320
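In case the measurement method matters: forward and backward are timed separately. A minimal sketch of the kind of CUDA-event loop that isolates a single backward call is below (the warmup/iteration counts and re-running the forward each step are illustrative, not necessarily exactly what my harness does):

def time_backward_ms(run_fwd, warmup=10, iters=50):
    # run_fwd() performs a fresh forward pass and returns the output tensor;
    # only the backward call sits between the CUDA events.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    times = []
    for i in range(warmup + iters):
        out = run_fwd()
        grad = torch.ones_like(out)
        torch.cuda.synchronize()
        start.record()
        out.backward(grad)
        end.record()
        torch.cuda.synchronize()
        if i >= warmup:
            times.append(start.elapsed_time(end))  # elapsed_time is in ms
    return sum(times) / len(times)

# e.g. time_backward_ms(lambda: attn.apply(B, N, L, q, kv, seqlens_kv))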