feat: add bias mask #66
The first file in the diff (the Triton attention implementation):

```diff
@@ -1,3 +1,5 @@
+from typing import Union
+
 import torch
 import triton
 import triton.language as tl
```
```diff
@@ -9,13 +11,28 @@
 # Similar to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L213
-def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
+def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
+                        is_causal: bool, attention_mask: Union[torch.Tensor, None]):
     """
     Reference implementation for attention
     @param q: Query matrix size (batch, heads, seq_length, BLOCK_DHEAD)
     @param k: Key matrix size (batch, heads, seq_length, BLOCK_DHEAD)
     @param v: Value matrix size (batch, heads, seq_length, BLOCK_DHEAD)
     @param sm_scale: Scaling factor applied after operation QxK
+    @param attention_mask: Attention mask broadcastable to (batch, heads, seq_length, seq_length). Warning: this is
+        not a binary mask like the one you normally use; it is added directly to QxK.
     @return:
     """
     seq_length = q.size(2)
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+
+    if attention_mask is not None:
+        p += attention_mask
     if is_causal:
         M = torch.tril(torch.ones((seq_length, seq_length), device="cuda"))
         p = torch.where(M == 0, float("-inf"), p)
-    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
+    p = torch.nn.functional.softmax(p, dim=-1)
     ref_out = torch.matmul(p, v, out=output)
     return ref_out
```
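The additive convention mirrors the linked GPT-2 implementation: the mask holds 0 where attention is allowed and a very large negative value where it is blocked. A minimal sketch (not part of this diff) of converting a usual binary padding mask into that form:

```python
import torch

# Hypothetical example: binary padding mask (1 = attend, 0 = padding) for one sequence
binary_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # (batch=1, seq_length=4)
# 0 where attended, the most negative representable value where padded
additive_mask = (1.0 - binary_mask) * torch.finfo(torch.float32).min
# reshape so it broadcasts against QxK of shape (batch, heads, seq_length, seq_length)
additive_mask = additive_mask.reshape(1, 1, 1, -1)
```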
```diff
@@ -28,12 +45,19 @@ def _fwd_kernel(
     K,
     V,
     sm_scale,
+    mask,
     TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     output,
     q_batch_stride, q_head_stride, q_m_stride, q_k_stride,
-    k_batch_stride, k_head_stride, k_n_stride, k_k_stride,
+    k_batch_stride, k_head_stride, k_n_stride, k_k_stride,  # We name n,k instead of k,n because of the transpose
     v_batch_stride, v_head_stride, v_k_stride, v_n_stride,
     o_batch_stride, o_head_stride, o_m_stride, o_n_stride,
+    mask_batch_stride, mask_head_stride, mask_m_stride, mask_k_stride,
+    MASK_BATCH_SIZE: tl.constexpr,
+    MASK_HEAD_SIZE: tl.constexpr,
+    MASK_M_SIZE: tl.constexpr,
+    MASK_K_SIZE: tl.constexpr,
+    HAS_MASK: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_DHEAD: tl.constexpr,
```

Review comment (on the `k_batch_stride` line): should we keep the comment?
Reply: imo yes

Review comment (on lines +57 to +60, the `MASK_*_SIZE` parameters): why are they constant? What makes them different from `mask_*_stride` for instance?
Reply: because we have conditions on them
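The difference from `mask_*_stride` is that `tl.constexpr` arguments are baked into the compiled kernel, so the `if MASK_BATCH_SIZE == 1` style branches further down cost nothing at runtime. A toy sketch of that specialization (an assumption for illustration, not this PR's kernel):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _toy_kernel(x_ptr, out_ptr, ADD_ONE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    if ADD_ONE:  # resolved at compile time: the untaken branch is eliminated
        x = x + 1.0
    tl.store(out_ptr + offs, x)


x = torch.zeros(8, device="cuda")
out = torch.empty_like(x)
_toy_kernel[(1,)](x, out, ADD_ONE=True, BLOCK=8)   # compiles one specialized binary
_toy_kernel[(1,)](x, out, ADD_ONE=False, BLOCK=8)  # compiles another, without the add
```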
@@ -67,6 +91,17 @@ def _fwd_kernel( | |
@param o_head_stride: output matrix stride for head dimension | ||
@param o_m_stride: output matrix stride for rows | ||
@param o_n_stride: output matrix stride for columns | ||
@param mask: Attention mask matrix broadcastable to (batch, heads, seq_length, seq_length) | ||
@param mask_batch_stride: Matrix mask stride for batch dimension | ||
@param mask_head_stride: Matrix mask stride for head dimension | ||
@param mask_m_stride: Matrix mask stride for rows | ||
@param mask_k_stride: Matrix mask stride for columns | ||
@param MASK_BATCH_SIZE: Matrix mask size for batch dimension | ||
@param MASK_HEAD_SIZE: Matrix mask size for head dimension | ||
@param MASK_M_SIZE: Matrix mask size for rows | ||
@param MASK_K_SIZE: Matrix mask size for columns | ||
@param HAS_MASK: Whether the mask is applied | ||
@param IS_CAUSAL: Whether the mask is applied | ||
@param BLOCK_M: number of rows computed in a single instance for matrix Q | ||
@param BLOCK_DHEAD: number of columns per head | ||
@param BLOCK_N: number of rows computed at each loop in the main loop for matrix K and V | ||
|
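For readers unfamiliar with the stride arguments: they are the per-dimension element offsets PyTorch reports for a tensor, which the kernel uses to compute flat addresses. A quick illustration with hypothetical shapes:

```python
import torch

q = torch.empty(2, 48, 512, 64)  # (batch, heads, seq_length, dhead), contiguous
print(q.stride())  # (1572864, 32768, 64, 1): step sizes in elements for each dimension
```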
```diff
@@ -137,6 +172,26 @@ def _fwd_kernel(
         if IS_CAUSAL:
             qk += tl.where(offs_m[:, None] >= (n_row_offset + offs_n[None, :]), 0, float("-inf"))

+        if HAS_MASK:
+            mask_batch_idx = current_batch_idx
+            if MASK_BATCH_SIZE == 1:
+                mask_batch_idx = 0
+
+            mask_head_idx = current_head_idx
+            if MASK_HEAD_SIZE == 1:
+                mask_head_idx = 0
+
+            offs_mask = mask_batch_idx * mask_batch_stride \
+                + mask_head_idx * mask_head_stride \
+                + (offs_n[None, :] + n_row_offset) * mask_k_stride
+
+            if MASK_M_SIZE == 1:
+                # mask broadcast on rows: load the single mask row once for the whole tile
+                m = tl.load(mask + offs_mask)
+            else:
+                offs_mask += offs_m[:, None] * mask_m_stride
+                # each mask row is read only once, so evict it from cache first
+                m = tl.load(mask + offs_mask, eviction_policy="evict_first")
+
+            qk += m
+
         # We compute softmax normalization like in Milakov et al.
         # We renamed m (in the original article) to l to avoid confusions
         # We start with the current block qk
```

Review comment (on the `evict_first` load): as discussed, we may want to add some comment
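The offset arithmetic above builds a (BLOCK_M, BLOCK_N) tile of flat indices by broadcasting a row term against a column term. An equivalent PyTorch sketch with illustrative values (not kernel code):

```python
import torch

BLOCK_M, BLOCK_N, n_row_offset = 4, 4, 8
mask_m_stride, mask_k_stride = 512, 1  # strides of a (..., seq_length=512, seq_length) mask
offs_m = torch.arange(BLOCK_M)  # query rows handled by this program
offs_n = torch.arange(BLOCK_N)  # key columns of the current loop iteration
offs_mask = (offs_n[None, :] + n_row_offset) * mask_k_stride  # shape (1, BLOCK_N)
offs_mask = offs_mask + offs_m[:, None] * mask_m_stride       # broadcast to (BLOCK_M, BLOCK_N)
print(offs_mask.shape)  # torch.Size([4, 4]): one flat offset per (row, column) of the tile
```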
```diff
@@ -197,7 +252,8 @@ class Attention(torch.autograd.Function):

     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
+    def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor,
+                sm_scale: float, is_causal: bool, attention_mask: torch.Tensor = None):
         """
         Computes attention.
         FP32 input and output are not supported.
```
```diff
@@ -226,19 +282,41 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
         grid = (triton.cdiv(seq_length, BLOCK_M), batch * heads)
         tmp = torch.empty((batch * heads, seq_length), device=q.device, dtype=torch.float32)

+        HAS_MASK = False
+        if attention_mask is not None:
+            assert attention_mask.size(0) == batch or attention_mask.size(0) == 1
+            assert attention_mask.size(1) == heads or attention_mask.size(1) == 1
+            assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1
+            assert attention_mask.size(3) == seq_length
+
+            # Move inside kernel ?
+            attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min,
+                                                  max=torch.finfo(attention_mask.dtype).max)
+            HAS_MASK = True
```

Review comment (on the asserts): can you add an error message that would serve as documentation? Basically it may look like "mask neither matches QKt shape nor is broadcastable on its XXX axis"
Review comment (on `# Move inside kernel ?`): no, keep it outside of Triton if it works.
Review comment (on the clamp): move inside kernel
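The error messages the reviewer asks for might look like this (a hypothetical sketch, not code from the PR):

```python
assert attention_mask.size(0) in (1, batch), (
    f"mask batch size {attention_mask.size(0)} neither matches Q batch size {batch} "
    "nor is broadcastable (size 1) on the batch axis"
)
assert attention_mask.size(1) in (1, heads), (
    f"mask head size {attention_mask.size(1)} neither matches head count {heads} "
    "nor is broadcastable (size 1) on the head axis"
)
assert attention_mask.size(2) in (1, seq_length), (
    f"mask row size {attention_mask.size(2)} neither matches seq_length {seq_length} "
    "nor is broadcastable (size 1) on the row axis"
)
assert attention_mask.size(3) == seq_length, (
    f"mask column size {attention_mask.size(3)} must equal seq_length {seq_length}"
)
```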
The same hunk continues with the kernel launch:

```diff
         _fwd_kernel[grid](
             heads,
             seq_length,
             q,
             k,
             v,
             sm_scale,
+            attention_mask,
             tmp,
             output,
             q.stride(0), q.stride(1), q.stride(2), q.stride(3),
             k.stride(0), k.stride(1), k.stride(2), k.stride(3),
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
             output.stride(0), output.stride(1), output.stride(2), output.stride(3),
+            attention_mask.stride(0) if HAS_MASK else 0,
+            attention_mask.stride(1) if HAS_MASK else 0,
+            attention_mask.stride(2) if HAS_MASK else 0,
+            attention_mask.stride(3) if HAS_MASK else 0,
+            MASK_BATCH_SIZE=attention_mask.size(0) if HAS_MASK else 0,
+            MASK_HEAD_SIZE=attention_mask.size(1) if HAS_MASK else 0,
+            MASK_M_SIZE=attention_mask.size(2) if HAS_MASK else 0,
+            MASK_K_SIZE=attention_mask.size(3) if HAS_MASK else 0,
+            HAS_MASK=HAS_MASK,
             IS_CAUSAL=is_causal,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
```

Review comment (on the mask stride arguments): can you use named arguments? Because now it's very long.
```diff
@@ -250,5 +328,6 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
         return output


-def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool = False):
-    return Attention.apply(q, k, v, output, sm_scale, is_causal)
+def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
+                      is_causal: bool = False, attention_mask: torch.Tensor = None):
+    return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)
```
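A possible usage sketch of the extended entry point (shapes and `sm_scale` are illustrative assumptions; `attention_forward` is the wrapper defined above):

```python
import torch

batch, heads, seq_length, dhead = 2, 48, 512, 64
q = torch.rand(batch, heads, seq_length, dhead, device="cuda", dtype=torch.float16)
k, v = torch.rand_like(q), torch.rand_like(q)
output = torch.empty_like(q)
# additive mask, broadcastable to (batch, heads, seq_length, seq_length)
mask = torch.zeros(batch, 1, 1, seq_length, device="cuda", dtype=torch.float16)
mask[..., seq_length // 2:] = torch.finfo(torch.float16).min  # block the second half
attention_forward(q, k, v, output, sm_scale=0.125, is_causal=False, attention_mask=mask)
```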
The second file in the diff (the benchmark tests):

```diff
@@ -16,23 +16,41 @@ def original_triton_flash_attention(is_causal: bool, *args, **kwargs):


 implementations = {
-    "original": lambda q, k, v, output, sm_scale, is_causal: original_triton_flash_attention(is_causal, q, k, v, output, sm_scale),
-    "triton": lambda q, k, v, output, sm_scale, is_causal: attention_forward(q, k, v, output, sm_scale, is_causal),
-    "torch": lambda q, k, v, output, sm_scale, is_causal: attention_reference(q, k, v, output, sm_scale, is_causal),
+    "original": lambda q, k, v, output, sm_scale, is_causal, attention_mask: original_triton_flash_attention(is_causal, q, k, v, output, sm_scale),
+    "triton": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_forward(q, k, v, output, sm_scale, is_causal, attention_mask),
+    "torch": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_reference(q, k, v, output, sm_scale, is_causal, attention_mask),
 }


+def generate_broadcast_mask(batch, seq_length, dtype=torch.float32):
+    attention_mask = torch.randint(1, seq_length, (batch,), device="cuda")[:, None] > torch.arange(0, seq_length, device="cuda")[None, :]
+    attention_mask = attention_mask.to(dtype)
+    attention_mask = torch.reshape(attention_mask, (batch, 1, 1, seq_length))
+    attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
+    return attention_mask
+
+
+def generate_bias_mask(batch, seq_length, dtype=torch.float32):
+    return torch.rand((batch, 48, seq_length, seq_length), dtype=dtype, device="cuda")
+
+
+def generate_none_mask(batch, seq_length, dtype=torch.float32):
+    return None
+
+
 @set_seed()
 @pytest.mark.parametrize("shape", [(bs, seq_l) for bs in [1, 8, 32, 64] for seq_l in [16, 64, 128, 256, 384, 512]],
                          ids=lambda x: f"{x[0]}x{x[1]}")
-@pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"])
 # fp32 not yet possible because of a bug in triton
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
+@pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"])
+@pytest.mark.parametrize("mask_fn", [generate_bias_mask, generate_broadcast_mask, generate_none_mask], ids=["bias-mask", "broadcast-mask", "no-mask"])
 @pytest.mark.parametrize("implementation", implementations.keys())
-def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable, dtype: torch.dtype, is_causal: bool):
+def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable, mask_fn: Callable, dtype: torch.dtype, is_causal: bool):
     batch, seq_length = shape
     if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
         pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
+    if implementation == "original" and mask_fn != generate_none_mask:
+        pytest.skip("Original Triton implementation doesn't support masks")
+    if is_causal and mask_fn != generate_none_mask:
+        pytest.skip("Not supported")

     # batch, heads, seq_length, dhead
     mat_shape = (batch, 48, seq_length, 64)
```

Review comment (on the second skip): elif to highlight that we chain our tests
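What the reviewer's elif suggestion could look like (a hypothetical rewrite of the skip chain above):

```python
if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
    pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
elif implementation == "original" and mask_fn != generate_none_mask:
    pytest.skip("Original Triton implementation doesn't support masks")
elif is_causal and mask_fn != generate_none_mask:
    pytest.skip("Causal attention combined with a mask is not supported")
```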
```diff
@@ -43,6 +61,7 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable
         "output": torch.empty(mat_shape, device="cuda"),
         "sm_scale": 0.3,  # Scaling applied before softmax (sqrt(dhead) in Vaswani et al.)
         "is_causal": is_causal,
+        "attention_mask": mask_fn(batch, seq_length)
     }

     expected = attention_reference(**args)
```
```diff
@@ -63,7 +82,7 @@ def test_mixed_stride():
     v = torch.rand_like(q)
     sm_scale = 0.3

-    expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False)
+    expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False, attention_mask=None)

     output = torch.empty_like(q)
     attention_forward(q, k, v, output, sm_scale)
     assert torch.allclose(output, expected, atol=1e-2)
```

Review comment (on this test): can you please add a dedicated test for the mask?
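A dedicated mask test, as requested, might look like this (a hypothetical sketch reusing the file's helpers):

```python
@set_seed()
def test_attention_mask():
    batch, heads, seq_length, dhead = 2, 48, 128, 64
    q = torch.rand((batch, heads, seq_length, dhead), device="cuda", dtype=torch.float16)
    k, v = torch.rand_like(q), torch.rand_like(q)
    # full bias mask, added to QxK before the softmax
    mask = torch.rand((batch, heads, seq_length, seq_length), device="cuda", dtype=torch.float16)
    expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=0.3,
                                   is_causal=False, attention_mask=mask)
    output = torch.empty_like(q)
    attention_forward(q, k, v, output, sm_scale=0.3, attention_mask=mask)
    assert torch.allclose(output, expected, atol=1e-2)
```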
Review comment (presumably on `attention_forward`): not very important, but you can type-hint the output: `-> torch.Tensor`
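Applied to the wrapper, that suggestion would read (sketch):

```python
def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
                      is_causal: bool = False, attention_mask: torch.Tensor = None) -> torch.Tensor:
    return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)
```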