From 3af4a514395e142b957a7b0d489aed41c5446090 Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Wed, 28 Sep 2022 14:03:43 +0200 Subject: [PATCH 1/9] feat: add bias mask --- implementations/attention.py | 58 ++++++++++++++++++++++++++++++++---- optimizer/attention.py | 4 +-- test/models/data_utils.py | 6 +++- test/test_attention.py | 31 ++++++++++++++----- 4 files changed, 83 insertions(+), 16 deletions(-) diff --git a/implementations/attention.py b/implementations/attention.py index 3ed7890d..51e1ff1a 100644 --- a/implementations/attention.py +++ b/implementations/attention.py @@ -1,3 +1,5 @@ +from typing import Union + import torch import triton import triton.language as tl @@ -9,17 +11,31 @@ # Similar to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L213 -def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool): +def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool, attention_mask: Union[torch.Tensor, None]): + """ + Reference implementation for attention + @param q: Query matrix size (batch, heads, seq_length, BLOCK_DHEAD) + @param k: Key matrix size (batch, heads, seq_length, BLOCK_DHEAD) + @param v: Value matrix size (batch, heads, seq_length, BLOCK_DHEAD) + @param sm_scale: Scaling factor applied after operation QxK + @param attention_mask: Size (batch, 1, 1, seq_length) or (batch, heads, seq_length, seq_length). Warning the mask isn't a binary mask + like the one you use normally. This mask is directly added to QxK. + @return: + """ seq_length = q.size(2) p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + + if attention_mask is not None: + p += attention_mask if is_causal: M = torch.tril(torch.ones((seq_length, seq_length), device="cuda")) p = torch.where(M == 0, float("-inf"), p) - p = torch.softmax(p.float(), dim=-1).to(q.dtype) + p = torch.nn.functional.softmax(p, dim=-1) ref_out = torch.matmul(p, v, out=output) return ref_out + @triton.jit def _fwd_kernel( heads, @@ -28,12 +44,16 @@ def _fwd_kernel( K, V, sm_scale, + mask, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug output, q_batch_stride, q_head_stride, q_m_stride, q_k_stride, - k_batch_stride, k_head_stride, k_n_stride, k_k_stride, + k_batch_stride, k_head_stride, k_n_stride, k_k_stride, # We name n,k instead of k,n because of the transpose v_batch_stride, v_head_stride, v_k_stride, v_n_stride, o_batch_stride, o_head_stride, o_m_stride, o_n_stride, + mask_batch_stride, mask_head_stride, mask_m_stride, mask_k_stride, + HAS_MASK: tl.constexpr, + IS_MASK_BROADCAST: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DHEAD: tl.constexpr, @@ -137,6 +157,18 @@ def _fwd_kernel( if IS_CAUSAL: qk += tl.where(offs_m[:, None] >= (n_row_offset + offs_n[None, :]), 0, float("-inf")) + + if HAS_MASK: + offs_mask = current_batch_idx * mask_batch_stride \ + + (offs_n[None, :] + n_row_offset) * mask_k_stride + if IS_MASK_BROADCAST: + m = tl.load(mask + offs_mask) + else: + offs_mask += offs_m[:, None] * mask_m_stride \ + + current_head_idx * mask_head_stride + m = tl.load(mask + offs_mask, eviction_policy="evict_first") + qk += m + # We compute softmax normalization like in Milakov et al. 
# We renamed m (in the original article) to l to avoid confusions # We start with the current block qk @@ -197,7 +229,7 @@ class Attention(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool): + def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool, attention_mask: torch.Tensor = None): """ Computes attention. FP32 input and output are not supported. @@ -226,6 +258,13 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, grid = (triton.cdiv(seq_length, BLOCK_M), batch * heads) tmp = torch.empty((batch * heads, seq_length), device=q.device, dtype=torch.float32) + HAS_MASK = False + IS_MASK_BROADCAST = False + if attention_mask is not None: + assert attention_mask.size() == (batch, heads, seq_length, seq_length) or attention_mask.size() == (batch, 1, 1, seq_length) + HAS_MASK = True + IS_MASK_BROADCAST = attention_mask.size() != (batch, heads, seq_length, seq_length) + _fwd_kernel[grid]( heads, seq_length, @@ -233,12 +272,19 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, k, v, sm_scale, + attention_mask, tmp, output, q.stride(0), q.stride(1), q.stride(2), q.stride(3), k.stride(0), k.stride(1), k.stride(2), k.stride(3), v.stride(0), v.stride(1), v.stride(2), v.stride(3), output.stride(0), output.stride(1), output.stride(2), output.stride(3), + attention_mask.stride(0) if HAS_MASK else 0, + attention_mask.stride(1) if HAS_MASK else 0, + attention_mask.stride(2) if HAS_MASK else 0, + attention_mask.stride(3) if HAS_MASK else 0, + HAS_MASK=HAS_MASK, + IS_MASK_BROADCAST=IS_MASK_BROADCAST, IS_CAUSAL=is_causal, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, @@ -250,5 +296,5 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return output -def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool = False): - return Attention.apply(q, k, v, output, sm_scale, is_causal) +def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool = False, attention_mask: torch.Tensor = None): + return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask) diff --git a/optimizer/attention.py b/optimizer/attention.py index acc0ce8d..66f8d004 100644 --- a/optimizer/attention.py +++ b/optimizer/attention.py @@ -4,8 +4,8 @@ from utils.extended_matcher import replace_pattern -def attention_wrapper(q, k, v, output, sm_scale, is_causal, *args): - return attention_forward(q, k, v, output, sm_scale, is_causal=is_causal) +def attention_wrapper(q, k, v, output, sm_scale, is_causal, attention_mask): + return attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask) torch.fx.wrap('attention_wrapper') diff --git a/test/models/data_utils.py b/test/models/data_utils.py index a8669ab2..be6868e3 100644 --- a/test/models/data_utils.py +++ b/test/models/data_utils.py @@ -3,6 +3,10 @@ import torch +def get_attention_mask(shape: (int, int)) -> torch.Tensor: + return torch.randint(1, shape[1], (shape[0],), device="cuda")[:, None] > torch.arange(0, shape[1], device="cuda")[ + None, :] + def get_input_causal(shape: (int, int)) -> Dict[str, torch.Tensor]: batch, seq_length = shape mask = torch.tril(torch.ones((batch, seq_length, 
seq_length), dtype=torch.int64, device="cuda")) @@ -16,6 +20,6 @@ def get_input_causal(shape: (int, int)) -> Dict[str, torch.Tensor]: def get_input_non_causal(shape: (int, int)) -> Dict[str, torch.Tensor]: return { "input_ids": torch.randint(2, 1000, size=shape, dtype=torch.int64, device="cuda"), - "attention_mask": torch.ones(size=shape, dtype=torch.int64, device="cuda"), + "attention_mask": get_attention_mask(shape).to(torch.int64), "token_type_ids": torch.ones(size=shape, dtype=torch.int64, device="cuda") } diff --git a/test/test_attention.py b/test/test_attention.py index 8fad6465..f489f555 100644 --- a/test/test_attention.py +++ b/test/test_attention.py @@ -16,23 +16,39 @@ def original_triton_flash_attention(is_causal: bool, *args, **kwargs): implementations = { - "original": lambda q, k, v, output, sm_scale, is_causal: original_triton_flash_attention(is_causal, q, k, v, output, sm_scale), - "triton": lambda q, k, v, output, sm_scale, is_causal: attention_forward(q, k, v, output, sm_scale, is_causal), - "torch": lambda q, k, v, output, sm_scale, is_causal: attention_reference(q, k, v, output, sm_scale, is_causal), + "original": lambda q, k, v, output, sm_scale, is_causal, attention_mask: original_triton_flash_attention(is_causal, q, k, v, output, sm_scale), + "triton": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_forward(q, k, v, output, sm_scale, is_causal, attention_mask), + "torch": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_reference(q, k, v, output, sm_scale, is_causal, attention_mask), } +def generate_broadcast_mask(batch, seq_length, dtype=torch.float32): + attention_mask = torch.randint(0, 2, size=(batch, seq_length), device="cuda").to(dtype) + attention_mask = torch.reshape(attention_mask, (batch, 1, 1, seq_length)) + attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min + return attention_mask + +def generate_bias_mask(batch, seq_length, dtype=torch.float32): + return torch.rand((batch, 48, seq_length, seq_length), dtype=dtype, device="cuda") + +def generate_none_mask(batch, seq_length, dtype=torch.float32): + return None @set_seed() @pytest.mark.parametrize("shape", [(bs, seq_l) for bs in [1, 8, 32, 64] for seq_l in [16, 64, 128, 256, 384, 512]], ids=lambda x: f"{x[0]}x{x[1]}") -@pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"]) # fp32 not yet possible because of a bug in triton @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +@pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"]) +@pytest.mark.parametrize("mask_fn", [generate_bias_mask, generate_broadcast_mask, generate_none_mask], ids=["bias-mask", "broadcast-mask", 'no-mask']) @pytest.mark.parametrize("implementation", implementations.keys()) -def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable, dtype: torch.dtype, is_causal: bool): +def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable, mask_fn: Callable, dtype: torch.dtype, is_causal: bool): batch, seq_length = shape if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512): pytest.skip("Original Triton implementation only supports fp16 and seq_length=512") + if implementation == "original" and mask_fn != generate_none_mask: + pytest.skip("Original Triton implementation doesn't support masks") + if is_causal and mask_fn != generate_none_mask: + pytest.skip("Not supported") # batch, heads, seq_length, dhead mat_shape = (batch, 
48, seq_length, 64) @@ -43,6 +59,7 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable "output": torch.empty(mat_shape, device="cuda"), "sm_scale": 0.3, # Scaling applied before softmax (sqrt(dhead) in Vaswani et al.) "is_causal": is_causal, + "attention_mask": mask_fn(batch, seq_length) } expected = attention_reference(**args) @@ -51,7 +68,7 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable func = implementations[implementation] value = benchmark(func, **cast_args) - assert torch.allclose(value.float(), expected, atol=1e-1) + assert torch.allclose(value.float(), expected, atol=1e-2) @set_seed() @@ -63,7 +80,7 @@ def test_mixed_stride(): v = torch.rand_like(q) sm_scale = 0.3 - expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False) + expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False, attention_mask=None) output = torch.empty_like(q) attention_forward(q, k, v, output, sm_scale) assert torch.allclose(output, expected, atol=1e-2) From 5b86b05f18b02f45957d151b1a70b5aff03d21b5 Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Thu, 29 Sep 2022 12:54:50 +0200 Subject: [PATCH 2/9] fix: ensure mask finite values --- implementations/attention.py | 2 ++ test/test_attention.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/implementations/attention.py b/implementations/attention.py index 51e1ff1a..d28e97e8 100644 --- a/implementations/attention.py +++ b/implementations/attention.py @@ -262,6 +262,8 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, IS_MASK_BROADCAST = False if attention_mask is not None: assert attention_mask.size() == (batch, heads, seq_length, seq_length) or attention_mask.size() == (batch, 1, 1, seq_length) + # Move inside kernel ? 
+ attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min, max=torch.finfo(attention_mask.dtype).max) HAS_MASK = True IS_MASK_BROADCAST = attention_mask.size() != (batch, heads, seq_length, seq_length) diff --git a/test/test_attention.py b/test/test_attention.py index f489f555..805b90fd 100644 --- a/test/test_attention.py +++ b/test/test_attention.py @@ -22,7 +22,9 @@ def original_triton_flash_attention(is_causal: bool, *args, **kwargs): } def generate_broadcast_mask(batch, seq_length, dtype=torch.float32): - attention_mask = torch.randint(0, 2, size=(batch, seq_length), device="cuda").to(dtype) + attention_mask = torch.randint(1, seq_length, (batch,), device="cuda")[:, None] > torch.arange(0, seq_length, device="cuda")[ + None, :] + attention_mask = attention_mask.to(dtype) attention_mask = torch.reshape(attention_mask, (batch, 1, 1, seq_length)) attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min return attention_mask @@ -68,7 +70,7 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable func = implementations[implementation] value = benchmark(func, **cast_args) - assert torch.allclose(value.float(), expected, atol=1e-2) + assert torch.allclose(value.float(), expected, atol=1e-1) @set_seed() From 1e8e6794af26061dd8669cef81fbd76d7ec6c62b Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Thu, 29 Sep 2022 15:57:28 +0200 Subject: [PATCH 3/9] feat: support mask broadcast --- implementations/attention.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/implementations/attention.py b/implementations/attention.py index d28e97e8..8f845612 100644 --- a/implementations/attention.py +++ b/implementations/attention.py @@ -52,8 +52,8 @@ def _fwd_kernel( v_batch_stride, v_head_stride, v_k_stride, v_n_stride, o_batch_stride, o_head_stride, o_m_stride, o_n_stride, mask_batch_stride, mask_head_stride, mask_m_stride, mask_k_stride, + mask_batch_size: tl.constexpr, mask_head_size: tl.constexpr, mask_m_size: tl.constexpr, mask_k_size: tl.constexpr, HAS_MASK: tl.constexpr, - IS_MASK_BROADCAST: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DHEAD: tl.constexpr, @@ -159,13 +159,22 @@ def _fwd_kernel( if HAS_MASK: - offs_mask = current_batch_idx * mask_batch_stride \ + mask_batch_idx = current_batch_idx, + if mask_batch_size == 1: + mask_batch_idx = 0 + + mask_head_idx = current_head_idx + if mask_head_size == 1: + mask_head_idx = 0 + + offs_mask = mask_batch_idx * mask_batch_stride \ + + mask_head_idx * mask_head_stride \ + (offs_n[None, :] + n_row_offset) * mask_k_stride - if IS_MASK_BROADCAST: + + if mask_m_size == 1: m = tl.load(mask + offs_mask) else: - offs_mask += offs_m[:, None] * mask_m_stride \ - + current_head_idx * mask_head_stride + offs_mask += offs_m[:, None] * mask_m_stride m = tl.load(mask + offs_mask, eviction_policy="evict_first") qk += m @@ -259,13 +268,15 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tmp = torch.empty((batch * heads, seq_length), device=q.device, dtype=torch.float32) HAS_MASK = False - IS_MASK_BROADCAST = False if attention_mask is not None: - assert attention_mask.size() == (batch, heads, seq_length, seq_length) or attention_mask.size() == (batch, 1, 1, seq_length) + assert attention_mask.size(0) == batch or attention_mask.size(0) == 1 + assert attention_mask.size(1) == heads or attention_mask.size(1) == 1 + assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1 + assert 
attention_mask.size(3) == seq_length + # Move inside kernel ? attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min, max=torch.finfo(attention_mask.dtype).max) HAS_MASK = True - IS_MASK_BROADCAST = attention_mask.size() != (batch, heads, seq_length, seq_length) _fwd_kernel[grid]( heads, @@ -285,8 +296,11 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask.stride(1) if HAS_MASK else 0, attention_mask.stride(2) if HAS_MASK else 0, attention_mask.stride(3) if HAS_MASK else 0, + attention_mask.size(0) if HAS_MASK else 0, + attention_mask.size(1) if HAS_MASK else 0, + attention_mask.size(2) if HAS_MASK else 0, + attention_mask.size(3) if HAS_MASK else 0, HAS_MASK=HAS_MASK, - IS_MASK_BROADCAST=IS_MASK_BROADCAST, IS_CAUSAL=is_causal, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, From a55e66fa783bcc6a42430aa24ff9d22532148277 Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Thu, 29 Sep 2022 16:19:50 +0200 Subject: [PATCH 4/9] fix: minor --- implementations/attention.py | 33 +++++++++++++++++++-------------- test/test_torchdynamo_bert.py | 2 ++ 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/implementations/attention.py b/implementations/attention.py index 8f845612..d28c15fc 100644 --- a/implementations/attention.py +++ b/implementations/attention.py @@ -11,7 +11,8 @@ # Similar to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L213 -def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool, attention_mask: Union[torch.Tensor, None]): +def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, + is_causal: bool, attention_mask: Union[torch.Tensor, None]): """ Reference implementation for attention @param q: Query matrix size (batch, heads, seq_length, BLOCK_DHEAD) @@ -35,7 +36,6 @@ def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, outpu return ref_out - @triton.jit def _fwd_kernel( heads, @@ -52,7 +52,10 @@ def _fwd_kernel( v_batch_stride, v_head_stride, v_k_stride, v_n_stride, o_batch_stride, o_head_stride, o_m_stride, o_n_stride, mask_batch_stride, mask_head_stride, mask_m_stride, mask_k_stride, - mask_batch_size: tl.constexpr, mask_head_size: tl.constexpr, mask_m_size: tl.constexpr, mask_k_size: tl.constexpr, + MASK_BATCH_SIZE: tl.constexpr, + MASK_HEAD_SIZE: tl.constexpr, + MASK_M_SIZE: tl.constexpr, + MASK_K_SIZE: tl.constexpr, HAS_MASK: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, @@ -157,21 +160,20 @@ def _fwd_kernel( if IS_CAUSAL: qk += tl.where(offs_m[:, None] >= (n_row_offset + offs_n[None, :]), 0, float("-inf")) - if HAS_MASK: mask_batch_idx = current_batch_idx, - if mask_batch_size == 1: + if MASK_BATCH_SIZE == 1: mask_batch_idx = 0 mask_head_idx = current_head_idx - if mask_head_size == 1: + if MASK_HEAD_SIZE == 1: mask_head_idx = 0 offs_mask = mask_batch_idx * mask_batch_stride \ + mask_head_idx * mask_head_stride \ + (offs_n[None, :] + n_row_offset) * mask_k_stride - if mask_m_size == 1: + if MASK_M_SIZE == 1: m = tl.load(mask + offs_mask) else: offs_mask += offs_m[:, None] * mask_m_stride @@ -238,7 +240,8 @@ class Attention(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) - def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool, attention_mask: torch.Tensor = None): + def forward(ctx: 
FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, + sm_scale: float, is_causal: bool, attention_mask: torch.Tensor = None): """ Computes attention. FP32 input and output are not supported. @@ -275,7 +278,8 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, assert attention_mask.size(3) == seq_length # Move inside kernel ? - attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min, max=torch.finfo(attention_mask.dtype).max) + attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min, + max=torch.finfo(attention_mask.dtype).max) HAS_MASK = True _fwd_kernel[grid]( @@ -296,10 +300,10 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask.stride(1) if HAS_MASK else 0, attention_mask.stride(2) if HAS_MASK else 0, attention_mask.stride(3) if HAS_MASK else 0, - attention_mask.size(0) if HAS_MASK else 0, - attention_mask.size(1) if HAS_MASK else 0, - attention_mask.size(2) if HAS_MASK else 0, - attention_mask.size(3) if HAS_MASK else 0, + MASK_BATCH_SIZE=attention_mask.size(0) if HAS_MASK else 0, + MASK_HEAD_SIZE=attention_mask.size(1) if HAS_MASK else 0, + MASK_M_SIZE=attention_mask.size(2) if HAS_MASK else 0, + MASK_K_SIZE=attention_mask.size(3) if HAS_MASK else 0, HAS_MASK=HAS_MASK, IS_CAUSAL=is_causal, BLOCK_M=BLOCK_M, @@ -312,5 +316,6 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, return output -def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool = False, attention_mask: torch.Tensor = None): +def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, + is_causal: bool = False, attention_mask: torch.Tensor = None): return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask) diff --git a/test/test_torchdynamo_bert.py b/test/test_torchdynamo_bert.py index 3056d587..d9d2bc14 100644 --- a/test/test_torchdynamo_bert.py +++ b/test/test_torchdynamo_bert.py @@ -32,6 +32,8 @@ class Implementation: "dynamo_cuda_graphs": Implementation(get_model_dynamo_cuda_graphs, is_causal=False), "dynamo_optimized": Implementation(get_model_optimized, is_causal=False), "dynamo_optimized_cuda_graphs": Implementation(get_model_optimized_cuda_graphs, is_causal=False), + # In this implementation both causal mask and the assume causal mask optimization will be applied, leads to slower + # benchmark "dynamo_optimizer_cuda_graphs_causal": Implementation(get_model_optimized_causal_cuda_graphs, is_causal=True), } From 05ab0f2aa2131df0e7a3904900a46d26541a6ecb Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Thu, 29 Sep 2022 16:27:15 +0200 Subject: [PATCH 5/9] fix: docs --- implementations/attention.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/implementations/attention.py b/implementations/attention.py index d28c15fc..11eafb2a 100644 --- a/implementations/attention.py +++ b/implementations/attention.py @@ -19,7 +19,8 @@ def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, outpu @param k: Key matrix size (batch, heads, seq_length, BLOCK_DHEAD) @param v: Value matrix size (batch, heads, seq_length, BLOCK_DHEAD) @param sm_scale: Scaling factor applied after operation QxK - @param attention_mask: Size (batch, 1, 1, seq_length) or (batch, heads, seq_length, seq_length). 
Warning the mask isn't a binary mask + @param attention_mask: Attention mask broadcastable to (batch, heads, seq_length, seq_length). Warning the mask + isn't a binary mask like the one you use normally. This mask is directly added to QxK. @return: """ @@ -90,6 +91,17 @@ def _fwd_kernel( @param o_head_stride: output matrix stride for head dimension @param o_m_stride: output matrix stride for rows @param o_n_stride: output matrix stride for columns + @param mask: Attention mask matrix broadcastable to (batch, heads, seq_length, seq_length) + @param mask_batch_stride: Matrix mask stride for batch dimension + @param mask_head_stride: Matrix mask stride for head dimension + @param mask_m_stride: Matrix mask stride for rows + @param mask_k_stride: Matrix mask stride for columns + @param MASK_BATCH_SIZE: Matrix mask size for batch dimension + @param MASK_HEAD_SIZE: Matrix mask size for head dimension + @param MASK_M_SIZE: Matrix mask size for rows + @param MASK_K_SIZE: Matrix mask size for columns + @param HAS_MASK: Whether the mask is applied + @param IS_CAUSAL: Whether the mask is applied @param BLOCK_M: number of rows computed in a single instance for matrix Q @param BLOCK_DHEAD: number of columns per head @param BLOCK_N: number of rows computed at each loop in the main loop for matrix K and V From 4b2ad457586210126aaaab3d88f85d314167588c Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Thu, 29 Sep 2022 16:49:05 +0200 Subject: [PATCH 6/9] fix: fix tests --- test/test_attention.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_attention.py b/test/test_attention.py index 805b90fd..94a63e4f 100644 --- a/test/test_attention.py +++ b/test/test_attention.py @@ -49,8 +49,6 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable pytest.skip("Original Triton implementation only supports fp16 and seq_length=512") if implementation == "original" and mask_fn != generate_none_mask: pytest.skip("Original Triton implementation doesn't support masks") - if is_causal and mask_fn != generate_none_mask: - pytest.skip("Not supported") # batch, heads, seq_length, dhead mat_shape = (batch, 48, seq_length, 64) From 3cd945a252b53bed15d99b354917bf288e874e1f Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Fri, 30 Sep 2022 09:44:37 +0200 Subject: [PATCH 7/9] fix: perf enhancements --- implementations/attention.py | 86 ++++++++++++++++++----------------- test/test_attention.py | 15 ++++-- test/test_torchdynamo_bert.py | 2 +- 3 files changed, 56 insertions(+), 47 deletions(-) diff --git a/implementations/attention.py b/implementations/attention.py index 11eafb2a..d5e77bf2 100644 --- a/implementations/attention.py +++ b/implementations/attention.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Optional import torch import triton @@ -12,7 +12,7 @@ # Similar to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L213 def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, - is_causal: bool, attention_mask: Union[torch.Tensor, None]): + is_causal: bool, attention_mask: Union[torch.Tensor, None]) -> torch.Tensor: """ Reference implementation for attention @param q: Query matrix size (batch, heads, seq_length, BLOCK_DHEAD) @@ -53,6 +53,7 @@ def _fwd_kernel( v_batch_stride, v_head_stride, v_k_stride, v_n_stride, o_batch_stride, o_head_stride, o_m_stride, o_n_stride, mask_batch_stride, mask_head_stride, mask_m_stride, mask_k_stride, + min_clamp_value, 
MASK_BATCH_SIZE: tl.constexpr, MASK_HEAD_SIZE: tl.constexpr, MASK_M_SIZE: tl.constexpr, @@ -157,6 +158,18 @@ def _fwd_kernel( n_end = seq_length if IS_CAUSAL: n_end = (m_block_idx + 1) * BLOCK_M, + + if HAS_MASK: + mask_batch_idx = current_batch_idx, + if MASK_BATCH_SIZE == 1: + mask_batch_idx = 0 + + mask_head_idx = current_head_idx + if MASK_HEAD_SIZE == 1: + mask_head_idx = 0 + + offs_base_mask = mask_batch_idx * mask_batch_stride + mask_head_idx * mask_head_stride + # loop over k, v and update accumulator # n_row_offset is the row offset on dimension N of the current block # It's used for both the N dimension of K and V because they are handled at the same time @@ -173,23 +186,16 @@ def _fwd_kernel( qk += tl.where(offs_m[:, None] >= (n_row_offset + offs_n[None, :]), 0, float("-inf")) if HAS_MASK: - mask_batch_idx = current_batch_idx, - if MASK_BATCH_SIZE == 1: - mask_batch_idx = 0 - - mask_head_idx = current_head_idx - if MASK_HEAD_SIZE == 1: - mask_head_idx = 0 - - offs_mask = mask_batch_idx * mask_batch_stride \ - + mask_head_idx * mask_head_stride \ - + (offs_n[None, :] + n_row_offset) * mask_k_stride - + offs_mask = offs_base_mask + (offs_n[None, :] + n_row_offset) * mask_k_stride + # If it's a broadcast we only load vector size BLOCK_N else a matrix size (BLOCK_M, BLOCK_N) if MASK_M_SIZE == 1: m = tl.load(mask + offs_mask) else: offs_mask += offs_m[:, None] * mask_m_stride + # The mask matrix is never reused m = tl.load(mask + offs_mask, eviction_policy="evict_first") + # Avoids NaN + m = tl.where(m == float("-inf"), min_clamp_value, m) qk += m # We compute softmax normalization like in Milakov et al. @@ -253,7 +259,7 @@ class Attention(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, - sm_scale: float, is_causal: bool, attention_mask: torch.Tensor = None): + sm_scale: float, is_causal: bool, attention_mask: Optional[torch.Tensor] = None): """ Computes attention. FP32 input and output are not supported. @@ -284,34 +290,32 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, HAS_MASK = False if attention_mask is not None: - assert attention_mask.size(0) == batch or attention_mask.size(0) == 1 - assert attention_mask.size(1) == heads or attention_mask.size(1) == 1 - assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1 - assert attention_mask.size(3) == seq_length - - # Move inside kernel ? 
- attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min, - max=torch.finfo(attention_mask.dtype).max) + assert attention_mask.size(0) == batch or attention_mask.size(0) == 1, "Incompatible broadcast batch dimension" + assert attention_mask.size(1) == heads or attention_mask.size(1) == 1, "Incompatible broadcast heads dimension" + assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1, "Incompatible broadcast seq_length dimension" + assert attention_mask.size(3) == seq_length, "Last size of mask must be seq_length to broadcast on QK^t" + HAS_MASK = True _fwd_kernel[grid]( - heads, - seq_length, - q, - k, - v, - sm_scale, - attention_mask, - tmp, - output, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - output.stride(0), output.stride(1), output.stride(2), output.stride(3), - attention_mask.stride(0) if HAS_MASK else 0, - attention_mask.stride(1) if HAS_MASK else 0, - attention_mask.stride(2) if HAS_MASK else 0, - attention_mask.stride(3) if HAS_MASK else 0, + heads=heads, + seq_length=seq_length, + Q=q, + K=k, + V=v, + sm_scale=sm_scale, + mask=attention_mask, + TMP=tmp, + output=output, + q_batch_stride=q.stride(0), q_head_stride=q.stride(1), q_m_stride=q.stride(2), q_k_stride=q.stride(3), + k_batch_stride=k.stride(0), k_head_stride=k.stride(1), k_n_stride=k.stride(2), k_k_stride=k.stride(3), + v_batch_stride=v.stride(0), v_head_stride=v.stride(1), v_k_stride=v.stride(2), v_n_stride=v.stride(3), + o_batch_stride=output.stride(0), o_head_stride=output.stride(1), o_m_stride=output.stride(2), o_n_stride=output.stride(3), + mask_batch_stride=attention_mask.stride(0) if HAS_MASK else 0, + mask_head_stride=attention_mask.stride(1) if HAS_MASK else 0, + mask_m_stride=attention_mask.stride(2) if HAS_MASK else 0, + mask_k_stride=attention_mask.stride(3) if HAS_MASK else 0, + min_clamp_value=torch.finfo(attention_mask.dtype).min, MASK_BATCH_SIZE=attention_mask.size(0) if HAS_MASK else 0, MASK_HEAD_SIZE=attention_mask.size(1) if HAS_MASK else 0, MASK_M_SIZE=attention_mask.size(2) if HAS_MASK else 0, @@ -329,5 +333,5 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, - is_causal: bool = False, attention_mask: torch.Tensor = None): + is_causal: bool = False, attention_mask: Optional[torch.Tensor] = None): return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask) diff --git a/test/test_attention.py b/test/test_attention.py index 94a63e4f..833565f1 100644 --- a/test/test_attention.py +++ b/test/test_attention.py @@ -47,7 +47,7 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable batch, seq_length = shape if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512): pytest.skip("Original Triton implementation only supports fp16 and seq_length=512") - if implementation == "original" and mask_fn != generate_none_mask: + elif implementation == "original" and mask_fn != generate_none_mask: pytest.skip("Original Triton implementation doesn't support masks") # batch, heads, seq_length, dhead @@ -74,13 +74,18 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable @set_seed() def test_mixed_stride(): # Column major - q = torch.transpose(torch.rand((4, 48, 64, 512), dtype=torch.float16, device="cuda"), -1, -2) + q 
= torch.rand((4, 48, 64, 512), dtype=torch.float16, device="cuda").transpose(-1, -2) # Interlaced batch - k = torch.transpose(torch.rand((48, 4, 512, 64), dtype=torch.float16, device="cuda"), 0, 1) + k = torch.rand((48, 4, 512, 64), dtype=torch.float16, device="cuda").transpose(0, 1) v = torch.rand_like(q) + mask = torch.rand((48, 4, 512, 512), dtype=torch.float16, device="cuda").transpose(0, 1).transpose(-1, -2) sm_scale = 0.3 - expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False, attention_mask=None) + assert not q.is_contiguous() + assert not k.is_contiguous() + assert not mask.is_contiguous() + + expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False, attention_mask=mask) output = torch.empty_like(q) - attention_forward(q, k, v, output, sm_scale) + attention_forward(q, k, v, output, sm_scale, attention_mask=mask) assert torch.allclose(output, expected, atol=1e-2) diff --git a/test/test_torchdynamo_bert.py b/test/test_torchdynamo_bert.py index d9d2bc14..7a7169a0 100644 --- a/test/test_torchdynamo_bert.py +++ b/test/test_torchdynamo_bert.py @@ -33,7 +33,7 @@ class Implementation: "dynamo_optimized": Implementation(get_model_optimized, is_causal=False), "dynamo_optimized_cuda_graphs": Implementation(get_model_optimized_cuda_graphs, is_causal=False), # In this implementation both causal mask and the assume causal mask optimization will be applied, leads to slower - # benchmark + # benchmark. It's not needed if we are sure the mask is causal, we can use the "assume causal mask optimization". "dynamo_optimizer_cuda_graphs_causal": Implementation(get_model_optimized_causal_cuda_graphs, is_causal=True), } From 985817b2550039332d652fabde13fe262aa593d7 Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Fri, 30 Sep 2022 11:18:45 +0200 Subject: [PATCH 8/9] fix: param error attention --- implementations/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/implementations/attention.py b/implementations/attention.py index d5e77bf2..bc77db67 100644 --- a/implementations/attention.py +++ b/implementations/attention.py @@ -315,7 +315,7 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask_head_stride=attention_mask.stride(1) if HAS_MASK else 0, mask_m_stride=attention_mask.stride(2) if HAS_MASK else 0, mask_k_stride=attention_mask.stride(3) if HAS_MASK else 0, - min_clamp_value=torch.finfo(attention_mask.dtype).min, + min_clamp_value=torch.finfo(attention_mask.dtype).min if HAS_MASK else 0, MASK_BATCH_SIZE=attention_mask.size(0) if HAS_MASK else 0, MASK_HEAD_SIZE=attention_mask.size(1) if HAS_MASK else 0, MASK_M_SIZE=attention_mask.size(2) if HAS_MASK else 0, From 8cbce46b1b7035f8d8a5a097a20ff9fe1c98266e Mon Sep 17 00:00:00 2001 From: gaetansnl Date: Fri, 30 Sep 2022 17:27:47 +0200 Subject: [PATCH 9/9] fix: test --- test/test_torchdynamo_bert.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/test/test_torchdynamo_bert.py b/test/test_torchdynamo_bert.py index 7a7169a0..b6a7c9b1 100644 --- a/test/test_torchdynamo_bert.py +++ b/test/test_torchdynamo_bert.py @@ -87,16 +87,17 @@ def test_benchmark_implementations(benchmark, model_reference_fp32, shape: (int, @set_seed() -def test_support_shape_change(model_reference_fp32): +@pytest.mark.parametrize("name", implementations.keys()) +def test_support_shape_change(name, model_reference_fp32): """Test that the model can handle shape changes without 
being reloaded/rebuilt.""" - for name, implementation in implementations.items(): - model_tested = implementation.model() - for shape in [(1, 64), (8, 256), (16, 256), (16, 64)]: - pytorch_input = get_input_causal(shape) if implementation.is_causal else get_input_non_causal(shape) - with torch.inference_mode(): - expected = model_reference_fp32(**pytorch_input) - with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True): - result = model_tested(**pytorch_input) - max_diff = torch.max(torch.abs(result["last_hidden_state"].float() - expected["last_hidden_state"])) - assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"], - atol=1e-1, rtol=1e-1), f"[{name}] failed with shape {shape}, max diff: {max_diff}" + implementation = implementations[name] + model_tested = implementation.model() + for shape in [(1, 64), (8, 256), (16, 256), (16, 64)]: + pytorch_input = get_input_causal(shape) if implementation.is_causal else get_input_non_causal(shape) + with torch.inference_mode(): + expected = model_reference_fp32(**pytorch_input) + with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True): + result = model_tested(**pytorch_input) + max_diff = torch.max(torch.abs(result["last_hidden_state"].float() - expected["last_hidden_state"])) + assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"], + atol=1e-1, rtol=1e-1), f"[{name}] failed with shape {shape}, max diff: {max_diff}" \ No newline at end of file
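Usage sketch for the mask API added in this series. The mask is additive: it is summed directly into QK^T before the softmax (it is not a 0/1 keep mask), it must be broadcastable to (batch, heads, seq_length, seq_length), and its last dimension must equal seq_length. The snippet below is a minimal illustration only; the `implementations.attention` import path is inferred from the file layout in the diffs, and 1/sqrt(d_head) is the conventional scale (the tests simply pass 0.3).

    import torch
    from implementations.attention import attention_forward  # path assumed from the repository layout

    batch, heads, seq_length, d_head = 8, 48, 512, 64
    q = torch.rand((batch, heads, seq_length, d_head), dtype=torch.float16, device="cuda")
    k, v = torch.rand_like(q), torch.rand_like(q)
    output = torch.empty_like(q)
    sm_scale = 1.0 / d_head ** 0.5

    # Padding mask, 1 = keep / 0 = drop, shape (batch, seq_length), e.g. from a tokenizer.
    keep = torch.randint(0, 2, (batch, seq_length), device="cuda").to(torch.float16)
    # Additive form expected by the kernel: 0 where kept, a very negative value where dropped,
    # reshaped so it broadcasts over heads and query positions.
    additive_mask = ((1.0 - keep) * torch.finfo(torch.float16).min).view(batch, 1, 1, seq_length)
    attention_forward(q, k, v, output, sm_scale, is_causal=False, attention_mask=additive_mask)

    # A dense bias (one value per head/query/key position) is accepted as well.
    bias = torch.rand((batch, heads, seq_length, seq_length), dtype=torch.float16, device="cuda")
    attention_forward(q, k, v, output, sm_scale, attention_mask=bias)

In the broadcast case the kernel loads a single BLOCK_N-wide vector per key block; in the dense case it loads a (BLOCK_M, BLOCK_N) tile with eviction_policy="evict_first" because the tile is never reused, so padding-style masks are the cheaper of the two.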
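On the finite-value handling: patch 2 clamps the mask on the host, and patch 7 replaces that with an in-kernel tl.where(m == float("-inf"), min_clamp_value, m) guarded by the "Avoids NaN" comment. The likely failure mode being avoided (an assumption, not stated in the commits) is that a fully masked softmax row turns into 0/0:

    import torch

    row = torch.full((4,), float("-inf"))
    print(torch.softmax(row, dim=-1))  # tensor([nan, nan, nan, nan])

    # Replacing -inf with the most negative finite value keeps the result well defined,
    # which is what the host-side clamp (patch 2) and the in-kernel tl.where (patch 7) do.
    finite = row.clamp(min=torch.finfo(row.dtype).min)
    print(torch.softmax(finite, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])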