diff --git a/implementations/attention.py b/implementations/attention.py
index 3ed7890d..bc77db67 100644
--- a/implementations/attention.py
+++ b/implementations/attention.py
@@ -1,3 +1,5 @@
+from typing import Union, Optional
+
 import torch
 import triton
 import triton.language as tl
@@ -9,13 +11,28 @@
 
 
 # Similar to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L213
-def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
+def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
+                        is_causal: bool, attention_mask: Union[torch.Tensor, None]) -> torch.Tensor:
+    """
+    Reference implementation for attention.
+    @param q: Query matrix of size (batch, heads, seq_length, BLOCK_DHEAD)
+    @param k: Key matrix of size (batch, heads, seq_length, BLOCK_DHEAD)
+    @param v: Value matrix of size (batch, heads, seq_length, BLOCK_DHEAD)
+    @param sm_scale: Scaling factor applied after the QxK^t operation
+    @param attention_mask: Attention mask broadcastable to (batch, heads, seq_length, seq_length).
+        Warning: this is not the usual binary mask;
+        it is added directly to QxK^t.
+    @return: the attention output, also written into `output`
+    """
     seq_length = q.size(2)
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+
+    if attention_mask is not None:
+        p += attention_mask
     if is_causal:
         M = torch.tril(torch.ones((seq_length, seq_length), device="cuda"))
         p = torch.where(M == 0, float("-inf"), p)
-    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
+    p = torch.nn.functional.softmax(p, dim=-1)
     ref_out = torch.matmul(p, v, out=output)
     return ref_out
 
@@ -28,12 +45,20 @@ def _fwd_kernel(
     K,
     V,
     sm_scale,
+    mask,
     TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     output,
     q_batch_stride, q_head_stride, q_m_stride, q_k_stride,
-    k_batch_stride, k_head_stride, k_n_stride, k_k_stride,
+    k_batch_stride, k_head_stride, k_n_stride, k_k_stride,  # We name n,k instead of k,n because of the transpose
     v_batch_stride, v_head_stride, v_k_stride, v_n_stride,
     o_batch_stride, o_head_stride, o_m_stride, o_n_stride,
+    mask_batch_stride, mask_head_stride, mask_m_stride, mask_k_stride,
+    min_clamp_value,
+    MASK_BATCH_SIZE: tl.constexpr,
+    MASK_HEAD_SIZE: tl.constexpr,
+    MASK_M_SIZE: tl.constexpr,
+    MASK_K_SIZE: tl.constexpr,
+    HAS_MASK: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_DHEAD: tl.constexpr,
@@ -67,6 +92,17 @@ def _fwd_kernel(
     @param o_head_stride: output matrix stride for head dimension
     @param o_m_stride: output matrix stride for rows
     @param o_n_stride: output matrix stride for columns
+    @param mask: Attention mask matrix broadcastable to (batch, heads, seq_length, seq_length)
+    @param mask_batch_stride: Mask matrix stride for batch dimension
+    @param mask_head_stride: Mask matrix stride for head dimension
+    @param mask_m_stride: Mask matrix stride for rows
+    @param mask_k_stride: Mask matrix stride for columns
+    @param MASK_BATCH_SIZE: Mask matrix size for batch dimension
+    @param MASK_HEAD_SIZE: Mask matrix size for head dimension
+    @param MASK_M_SIZE: Mask matrix size for rows
+    @param MASK_K_SIZE: Mask matrix size for columns
+    @param HAS_MASK: Whether a mask is applied
+    @param IS_CAUSAL: Whether causal masking is applied
     @param BLOCK_M: number of rows computed in a single instance for matrix Q
     @param BLOCK_DHEAD: number of columns per head
     @param BLOCK_N: number of rows computed at each loop in the main loop for matrix K and V
@@ -122,6 +158,18 @@ def _fwd_kernel(
     n_end = seq_length
     if IS_CAUSAL:
         n_end = (m_block_idx + 1) * BLOCK_M,
+
+    if HAS_MASK:
+        mask_batch_idx = current_batch_idx
+        if MASK_BATCH_SIZE == 1:
+            mask_batch_idx = 0
+
+        mask_head_idx = current_head_idx
+        if MASK_HEAD_SIZE == 1:
+            mask_head_idx = 0
+
+        offs_base_mask = mask_batch_idx * mask_batch_stride + mask_head_idx * mask_head_stride
+
     # loop over k, v and update accumulator
     # n_row_offset is the row offset on dimension N of the current block
     # It's used for both the N dimension of K and V because they are handled at the same time
@@ -137,6 +185,19 @@ def _fwd_kernel(
         if IS_CAUSAL:
             qk += tl.where(offs_m[:, None] >= (n_row_offset + offs_n[None, :]), 0, float("-inf"))
 
+        if HAS_MASK:
+            offs_mask = offs_base_mask + (offs_n[None, :] + n_row_offset) * mask_k_stride
+            # If it's a broadcast, we only load a vector of size BLOCK_N, otherwise a matrix of size (BLOCK_M, BLOCK_N)
+            if MASK_M_SIZE == 1:
+                m = tl.load(mask + offs_mask)
+            else:
+                offs_mask += offs_m[:, None] * mask_m_stride
+                # The mask matrix is never reused
+                m = tl.load(mask + offs_mask, eviction_policy="evict_first")
+            # Avoids NaN
+            m = tl.where(m == float("-inf"), min_clamp_value, m)
+            qk += m
+
         # We compute softmax normalization like in Milakov et al.
         # We renamed m (in the original article) to l to avoid confusions
         # We start with the current block qk
@@ -197,7 +258,8 @@ class Attention(torch.autograd.Function):
 
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
+    def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor,
+                sm_scale: float, is_causal: bool, attention_mask: Optional[torch.Tensor] = None):
         """
         Computes attention.
         FP32 input and output are not supported.
@@ -226,19 +288,39 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
         grid = (triton.cdiv(seq_length, BLOCK_M), batch * heads)
         tmp = torch.empty((batch * heads, seq_length), device=q.device, dtype=torch.float32)
 
+        HAS_MASK = False
+        if attention_mask is not None:
+            assert attention_mask.size(0) == batch or attention_mask.size(0) == 1, "Incompatible broadcast batch dimension"
+            assert attention_mask.size(1) == heads or attention_mask.size(1) == 1, "Incompatible broadcast heads dimension"
+            assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1, "Incompatible broadcast seq_length dimension"
+            assert attention_mask.size(3) == seq_length, "Last size of mask must be seq_length to broadcast on QK^t"
+
+            HAS_MASK = True
+
         _fwd_kernel[grid](
-            heads,
-            seq_length,
-            q,
-            k,
-            v,
-            sm_scale,
-            tmp,
-            output,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            output.stride(0), output.stride(1), output.stride(2), output.stride(3),
+            heads=heads,
+            seq_length=seq_length,
+            Q=q,
+            K=k,
+            V=v,
+            sm_scale=sm_scale,
+            mask=attention_mask,
+            TMP=tmp,
+            output=output,
+            q_batch_stride=q.stride(0), q_head_stride=q.stride(1), q_m_stride=q.stride(2), q_k_stride=q.stride(3),
+            k_batch_stride=k.stride(0), k_head_stride=k.stride(1), k_n_stride=k.stride(2), k_k_stride=k.stride(3),
+            v_batch_stride=v.stride(0), v_head_stride=v.stride(1), v_k_stride=v.stride(2), v_n_stride=v.stride(3),
+            o_batch_stride=output.stride(0), o_head_stride=output.stride(1), o_m_stride=output.stride(2), o_n_stride=output.stride(3),
+            mask_batch_stride=attention_mask.stride(0) if HAS_MASK else 0,
+            mask_head_stride=attention_mask.stride(1) if HAS_MASK else 0,
+            mask_m_stride=attention_mask.stride(2) if HAS_MASK else 0,
+            mask_k_stride=attention_mask.stride(3) if HAS_MASK else 0,
+            min_clamp_value=torch.finfo(attention_mask.dtype).min if HAS_MASK else 0,
+            MASK_BATCH_SIZE=attention_mask.size(0) if HAS_MASK else 0,
+            MASK_HEAD_SIZE=attention_mask.size(1) if HAS_MASK else 0,
+            MASK_M_SIZE=attention_mask.size(2) if HAS_MASK else 0,
+            MASK_K_SIZE=attention_mask.size(3) if HAS_MASK else 0,
+            HAS_MASK=HAS_MASK,
             IS_CAUSAL=is_causal,
             BLOCK_M=BLOCK_M,
             BLOCK_N=BLOCK_N,
@@ -250,5 +332,6 @@ def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
         return output
 
 
-def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool = False):
-    return Attention.apply(q, k, v, output, sm_scale, is_causal)
+def attention_forward(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
+                      is_causal: bool = False, attention_mask: Optional[torch.Tensor] = None):
+    return Attention.apply(q, k, v, output, sm_scale, is_causal, attention_mask)
diff --git a/optimizer/attention.py b/optimizer/attention.py
index acc0ce8d..66f8d004 100644
--- a/optimizer/attention.py
+++ b/optimizer/attention.py
@@ -4,8 +4,8 @@
 from utils.extended_matcher import replace_pattern
 
 
-def attention_wrapper(q, k, v, output, sm_scale, is_causal, *args):
-    return attention_forward(q, k, v, output, sm_scale, is_causal=is_causal)
+def attention_wrapper(q, k, v, output, sm_scale, is_causal, attention_mask):
+    return attention_forward(q, k, v, output, sm_scale, is_causal=is_causal, attention_mask=attention_mask)
 
 
 torch.fx.wrap('attention_wrapper')
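Two illustrative sketches follow. They are not part of the patch, only a reading aid for the kernel changes above; names, shapes and the sm_scale value are assumptions made for the example.

First, how a caller might build the additive mask the new attention_mask argument expects, starting from a HuggingFace-style binary padding mask (1 = keep, 0 = padded). Kept positions contribute 0 to QxK^t, padded positions get the dtype minimum so softmax drives them to ~0; this is the same recipe as generate_broadcast_mask in the tests below. The import assumes the repo root is on the Python path.

import torch

from implementations.attention import attention_forward  # module extended by this patch

batch, heads, seq_length, d_head = 2, 48, 512, 64
q = torch.rand((batch, heads, seq_length, d_head), dtype=torch.float16, device="cuda")
k = torch.rand_like(q)
v = torch.rand_like(q)
output = torch.empty_like(q)

# Binary padding mask (batch, seq_length): pretend the second half of every sequence is padding.
binary_mask = torch.ones((batch, seq_length), dtype=torch.float16, device="cuda")
binary_mask[:, seq_length // 2:] = 0

# Additive mask broadcastable to (batch, heads, seq_length, seq_length), as checked by the asserts in forward().
additive_mask = (1.0 - binary_mask) * torch.finfo(torch.float16).min
additive_mask = additive_mask.reshape(batch, 1, 1, seq_length)

attention_forward(q, k, v, output, sm_scale=d_head ** -0.5, is_causal=False, attention_mask=additive_mask)

Second, a plausible reading of the "# Avoids NaN" clamp and the min_clamp_value argument in the kernel: a row whose positions are all masked with -inf (for example a fully padded sequence) makes softmax return NaN, while replacing -inf by the dtype minimum keeps the result finite. Plain PyTorch, independent of the Triton kernel:

import torch

scores = torch.zeros(4)  # QxK^t values of one fully masked row
all_masked = torch.full((4,), float("-inf"))

# Raw -inf mask: every entry becomes -inf and softmax yields NaN for the whole row.
print(torch.softmax(scores + all_masked, dim=-1))  # tensor([nan, nan, nan, nan])

# Clamping -inf to the dtype minimum, as the kernel does with min_clamp_value, keeps softmax finite.
min_clamp_value = torch.finfo(torch.float32).min
clamped = torch.where(all_masked == float("-inf"), torch.full_like(all_masked, min_clamp_value), all_masked)
print(torch.softmax(scores + clamped, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])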
diff --git a/test/models/data_utils.py b/test/models/data_utils.py
index a8669ab2..be6868e3 100644
--- a/test/models/data_utils.py
+++ b/test/models/data_utils.py
@@ -3,6 +3,10 @@
 import torch
 
 
+def get_attention_mask(shape: (int, int)) -> torch.Tensor:
+    return torch.randint(1, shape[1], (shape[0],), device="cuda")[:, None] > torch.arange(0, shape[1], device="cuda")[
+        None, :]
+
 def get_input_causal(shape: (int, int)) -> Dict[str, torch.Tensor]:
     batch, seq_length = shape
     mask = torch.tril(torch.ones((batch, seq_length, seq_length), dtype=torch.int64, device="cuda"))
@@ -16,6 +20,6 @@ def get_input_causal(shape: (int, int)) -> Dict[str, torch.Tensor]:
 def get_input_non_causal(shape: (int, int)) -> Dict[str, torch.Tensor]:
     return {
         "input_ids": torch.randint(2, 1000, size=shape, dtype=torch.int64, device="cuda"),
-        "attention_mask": torch.ones(size=shape, dtype=torch.int64, device="cuda"),
+        "attention_mask": get_attention_mask(shape).to(torch.int64),
         "token_type_ids": torch.ones(size=shape, dtype=torch.int64, device="cuda")
     }
diff --git a/test/test_attention.py b/test/test_attention.py
index 8fad6465..833565f1 100644
--- a/test/test_attention.py
+++ b/test/test_attention.py
@@ -16,23 +16,39 @@ def original_triton_flash_attention(is_causal: bool, *args, **kwargs):
 
 
 implementations = {
-    "original": lambda q, k, v, output, sm_scale, is_causal: original_triton_flash_attention(is_causal, q, k, v, output, sm_scale),
-    "triton": lambda q, k, v, output, sm_scale, is_causal: attention_forward(q, k, v, output, sm_scale, is_causal),
-    "torch": lambda q, k, v, output, sm_scale, is_causal: attention_reference(q, k, v, output, sm_scale, is_causal),
+    "original": lambda q, k, v, output, sm_scale, is_causal, attention_mask: original_triton_flash_attention(is_causal, q, k, v, output, sm_scale),
+    "triton": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_forward(q, k, v, output, sm_scale, is_causal, attention_mask),
+    "torch": lambda q, k, v, output, sm_scale, is_causal, attention_mask: attention_reference(q, k, v, output, sm_scale, is_causal, attention_mask),
 }
 
 
+def generate_broadcast_mask(batch, seq_length, dtype=torch.float32):
+    attention_mask = torch.randint(1, seq_length, (batch,), device="cuda")[:, None] > torch.arange(0, seq_length, device="cuda")[
+        None, :]
+    attention_mask = attention_mask.to(dtype)
+    attention_mask = torch.reshape(attention_mask, (batch, 1, 1, seq_length))
+    attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
+    return attention_mask
+
+
+def generate_bias_mask(batch, seq_length, dtype=torch.float32):
+    return torch.rand((batch, 48, seq_length, seq_length), dtype=dtype, device="cuda")
+
+
+def generate_none_mask(batch, seq_length, dtype=torch.float32):
+    return None
+
 
 @set_seed()
 @pytest.mark.parametrize("shape", [(bs, seq_l) for bs in [1, 8, 32, 64] for seq_l in [16, 64, 128, 256, 384, 512]],
                          ids=lambda x: f"{x[0]}x{x[1]}")
-@pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"])
 # fp32 not yet possible because of a bug in triton
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"])
+@pytest.mark.parametrize("is_causal", [True, False], ids=["causal", "non-causal"])
+@pytest.mark.parametrize("mask_fn", [generate_bias_mask, generate_broadcast_mask, generate_none_mask], ids=["bias-mask", "broadcast-mask", "no-mask"])
 @pytest.mark.parametrize("implementation", implementations.keys())
-def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable, dtype: torch.dtype, is_causal: bool):
+def test_benchmark_masked(benchmark, shape: (int, int),
+                          implementation: Callable, mask_fn: Callable, dtype: torch.dtype, is_causal: bool):
     batch, seq_length = shape
     if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
         pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
+    elif implementation == "original" and mask_fn != generate_none_mask:
+        pytest.skip("Original Triton implementation doesn't support masks")
 
     # batch, heads, seq_length, dhead
     mat_shape = (batch, 48, seq_length, 64)
@@ -43,6 +59,7 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable
         "output": torch.empty(mat_shape, device="cuda"),
         "sm_scale": 0.3,  # Scaling applied before softmax (sqrt(dhead) in Vaswani et al.)
         "is_causal": is_causal,
+        "attention_mask": mask_fn(batch, seq_length)
     }
 
     expected = attention_reference(**args)
@@ -57,13 +74,18 @@ def test_benchmark_masked(benchmark, shape: (int, int), implementation: Callable
 @set_seed()
 def test_mixed_stride():
     # Column major
-    q = torch.transpose(torch.rand((4, 48, 64, 512), dtype=torch.float16, device="cuda"), -1, -2)
+    q = torch.rand((4, 48, 64, 512), dtype=torch.float16, device="cuda").transpose(-1, -2)
     # Interlaced batch
-    k = torch.transpose(torch.rand((48, 4, 512, 64), dtype=torch.float16, device="cuda"), 0, 1)
+    k = torch.rand((48, 4, 512, 64), dtype=torch.float16, device="cuda").transpose(0, 1)
     v = torch.rand_like(q)
+    mask = torch.rand((48, 4, 512, 512), dtype=torch.float16, device="cuda").transpose(0, 1).transpose(-1, -2)
     sm_scale = 0.3
 
-    expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False)
+    assert not q.is_contiguous()
+    assert not k.is_contiguous()
+    assert not mask.is_contiguous()
+
+    expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False, attention_mask=mask)
     output = torch.empty_like(q)
-    attention_forward(q, k, v, output, sm_scale)
+    attention_forward(q, k, v, output, sm_scale, attention_mask=mask)
     assert torch.allclose(output, expected, atol=1e-2)
diff --git a/test/test_torchdynamo_bert.py b/test/test_torchdynamo_bert.py
index 3056d587..b6a7c9b1 100644
--- a/test/test_torchdynamo_bert.py
+++ b/test/test_torchdynamo_bert.py
@@ -32,6 +32,8 @@ class Implementation:
     "dynamo_cuda_graphs": Implementation(get_model_dynamo_cuda_graphs, is_causal=False),
     "dynamo_optimized": Implementation(get_model_optimized, is_causal=False),
     "dynamo_optimized_cuda_graphs": Implementation(get_model_optimized_cuda_graphs, is_causal=False),
+    # With this implementation both the causal mask and the "assume causal mask" optimization are applied, which slows
+    # down the benchmark. If we are sure the mask is causal, the "assume causal mask" optimization alone is enough.
     "dynamo_optimizer_cuda_graphs_causal": Implementation(get_model_optimized_causal_cuda_graphs, is_causal=True),
 }
 
@@ -85,16 +87,17 @@ def test_benchmark_implementations(benchmark, model_reference_fp32, shape: (int,
 
 
 @set_seed()
-def test_support_shape_change(model_reference_fp32):
+@pytest.mark.parametrize("name", implementations.keys())
+def test_support_shape_change(name, model_reference_fp32):
     """Test that the model can handle shape changes without being reloaded/rebuilt."""
-    for name, implementation in implementations.items():
-        model_tested = implementation.model()
-        for shape in [(1, 64), (8, 256), (16, 256), (16, 64)]:
-            pytorch_input = get_input_causal(shape) if implementation.is_causal else get_input_non_causal(shape)
-            with torch.inference_mode():
-                expected = model_reference_fp32(**pytorch_input)
-                with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
-                    result = model_tested(**pytorch_input)
-                max_diff = torch.max(torch.abs(result["last_hidden_state"].float() - expected["last_hidden_state"]))
-                assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"],
-                                      atol=1e-1, rtol=1e-1), f"[{name}] failed with shape {shape}, max diff: {max_diff}"
+    implementation = implementations[name]
+    model_tested = implementation.model()
+    for shape in [(1, 64), (8, 256), (16, 256), (16, 64)]:
+        pytorch_input = get_input_causal(shape) if implementation.is_causal else get_input_non_causal(shape)
+        with torch.inference_mode():
+            expected = model_reference_fp32(**pytorch_input)
+            with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16, cache_enabled=True):
+                result = model_tested(**pytorch_input)
+            max_diff = torch.max(torch.abs(result["last_hidden_state"].float() - expected["last_hidden_state"]))
+            assert torch.allclose(result["last_hidden_state"].float(), expected["last_hidden_state"],
+                                  atol=1e-1, rtol=1e-1), f"[{name}] failed with shape {shape}, max diff: {max_diff}"
\ No newline at end of file