feat: add bias mask #66

Merged
merged 9 commits into main from feat/bias-mask on Sep 30, 2022
Conversation

gaetansnl
Contributor

@gaetansnl gaetansnl commented Sep 28, 2022

  • Add mask support
  • Add mask broadcast support

PS: I didn't refactor the mask creation methods; I'll do that in a separate PR.

@gaetansnl gaetansnl marked this pull request as ready for review September 29, 2022 14:29
@pommedeterresautee pommedeterresautee added the enhancement (New feature or request) and model (Model scope, HF, etc.) labels Sep 29, 2022
Member

@pommedeterresautee pommedeterresautee left a comment


as discussed

@@ -9,13 +11,28 @@


# Similar to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L213
-def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
+def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
+                        is_causal: bool, attention_mask: Union[torch.Tensor, None]):
Member

Not very important, but you could also type-hint the output: -> torch.Tensor
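For reference, a minimal sketch of how the hinted signature could look; the parameter list is copied from the diff above and the body is elided:

from typing import Union
import torch

def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
                        is_causal: bool, attention_mask: Union[torch.Tensor, None]) -> torch.Tensor:
    ...  # body unchanged from the PR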

TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
output,
q_batch_stride, q_head_stride, q_m_stride, q_k_stride,
-k_batch_stride, k_head_stride, k_n_stride, k_k_stride,
+k_batch_stride, k_head_stride, k_n_stride, k_k_stride,  # We name n,k instead of k,n because of the transpose
Member

should we keep the comment?

Contributor Author

imo yes

Comment on lines +56 to +59
MASK_BATCH_SIZE: tl.constexpr,
MASK_HEAD_SIZE: tl.constexpr,
MASK_M_SIZE: tl.constexpr,
MASK_K_SIZE: tl.constexpr,
Member

why are they constant? What makes them different from mask_*_stride for instance?

Contributor Author

Because we have conditions on them in the kernel, so they need to be compile-time constants, unlike the mask_*_stride values, which are only used in address arithmetic.
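A toy sketch of the pattern (not the PR's kernel): a value declared as tl.constexpr is a compile-time constant, so a Python-level if on it is resolved when the kernel is specialized and only one branch is generated; strides stay as ordinary runtime arguments.

import triton
import triton.language as tl

@triton.jit
def _toy_add_mask(x_ptr, mask_ptr, out_ptr,
                  MASK_M_SIZE: tl.constexpr,  # 1 means the mask is broadcast along this axis
                  BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    if MASK_M_SIZE == 1:
        # branch resolved at compile time: a single mask value is reused for the whole block
        m = tl.load(mask_ptr)
    else:
        # general case: one mask value per element
        m = tl.load(mask_ptr + offs)
    tl.store(out_ptr + offs, x + m)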

@@ -197,7 +252,8 @@ class Attention(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float16)
-def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
+def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor,
+            sm_scale: float, is_causal: bool, attention_mask: torch.Tensor = None):
Member

attention_mask: torch.Tensor = None -> the annotation should be Optional[torch.Tensor]
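A sketch of the Optional form being suggested; only the annotation changes, the rest of the signature comes from the diff above:

import torch
from typing import Optional
from torch.autograd.function import FunctionCtx

def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor,
            sm_scale: float, is_causal: bool, attention_mask: Optional[torch.Tensor] = None):
    ...  # body unchanged from the PR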

Comment on lines 287 to 290
assert attention_mask.size(0) == batch or attention_mask.size(0) == 1
assert attention_mask.size(1) == heads or attention_mask.size(1) == 1
assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1
assert attention_mask.size(3) == seq_length
Member

Can you add error messages that would serve as documentation? Basically something like "mask neither matches the QKt shape nor is broadcastable on its XXX axis".
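One possible phrasing, mirroring the asserts above with messages added; the exact wording is illustrative, not necessarily what landed in the PR:

assert attention_mask.size(0) == batch or attention_mask.size(0) == 1, \
    "mask batch size neither matches QKt batch size nor is broadcastable (size 1)"
assert attention_mask.size(1) == heads or attention_mask.size(1) == 1, \
    "mask head dim neither matches QKt head count nor is broadcastable (size 1)"
assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1, \
    "mask M dim neither matches QKt sequence length nor is broadcastable (size 1)"
assert attention_mask.size(3) == seq_length, \
    "mask last dim must match QKt sequence length (no broadcast on this axis)"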

assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1
assert attention_mask.size(3) == seq_length

# Move inside kernel ?
Member

No, keep it outside of Triton if it works.
If the trick comes from the HF library, can you add a link to the source code?

Comment on lines 188 to 192
if MASK_M_SIZE == 1:
    m = tl.load(mask + offs_mask)
else:
    offs_mask += offs_m[:, None] * mask_m_stride
    m = tl.load(mask + offs_mask, eviction_policy="evict_first")
Member

As discussed, we may want to add a comment here.
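One way such a comment could read (a sketch; the offsets and strides come from the diff above):

if MASK_M_SIZE == 1:
    # mask is broadcast over the query (M) axis: the same row is reused by every
    # row of the block, so the default cache behaviour is fine
    m = tl.load(mask + offs_mask)
else:
    # one mask row per query row: each value is read only once by this program,
    # so hint the cache to evict it first rather than keep streaming data around
    offs_mask += offs_m[:, None] * mask_m_stride
    m = tl.load(mask + offs_mask, eviction_policy="evict_first")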

Comment on lines 298 to 314
heads,
seq_length,
q,
k,
v,
sm_scale,
attention_mask,
tmp,
output,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
output.stride(0), output.stride(1), output.stride(2), output.stride(3),
attention_mask.stride(0) if HAS_MASK else 0,
attention_mask.stride(1) if HAS_MASK else 0,
attention_mask.stride(2) if HAS_MASK else 0,
attention_mask.stride(3) if HAS_MASK else 0,
Member

Can you use named arguments? The call is getting very long.
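One way to shorten the call, as an alternative if keyword arguments are not convenient for the Triton launcher: a small, hypothetical helper (not part of the PR) that expands the four strides of a tensor, or zeros when the mask is absent.

def strides_or_zeros(t, ndim=4):
    # hypothetical helper: returns t.stride(0..ndim-1), or zeros when t is None
    return tuple(t.stride(i) for i in range(ndim)) if t is not None else (0,) * ndim

# the launch above could then pass:
#   *strides_or_zeros(q), *strides_or_zeros(k), *strides_or_zeros(v),
#   *strides_or_zeros(output), *strides_or_zeros(attention_mask if HAS_MASK else None),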

batch, seq_length = shape
if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
    pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
if implementation == "original" and mask_fn != generate_none_mask:
Member

Use elif to highlight that we chain these checks.
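A sketch of the suggested chaining; the second skip message is illustrative, it is not shown in the hunk:

if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
    pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
elif implementation == "original" and mask_fn != generate_none_mask:
    pytest.skip("Original Triton implementation does not support attention masks")  # illustrative message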

@@ -63,7 +80,7 @@ def test_mixed_stride():
v = torch.rand_like(q)
sm_scale = 0.3

-expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False)
+expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False, attention_mask=None)
Member

Can you please add a dedicated test for the mask?
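A sketch of what such a test could look like, assuming attention_reference and the Attention autograd function from this PR are importable from the module under test and that the kernel writes into the output buffer; the shapes, the broadcast mask, and the tolerance are illustrative:

import torch

def test_attention_mask_broadcast():
    batch, heads, seq_length, d_head = 2, 4, 128, 64
    q = torch.rand((batch, heads, seq_length, d_head), dtype=torch.float16, device="cuda")
    k = torch.rand_like(q)
    v = torch.rand_like(q)
    sm_scale = 0.3
    # additive mask, broadcast over the batch and head axes, masking the second half of the keys
    attention_mask = torch.zeros((1, 1, seq_length, seq_length), dtype=torch.float16, device="cuda")
    attention_mask[..., seq_length // 2:] = torch.finfo(torch.float16).min

    expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale,
                                   is_causal=False, attention_mask=attention_mask)
    output = torch.empty_like(q)
    Attention.apply(q, k, v, output, sm_scale, False, attention_mask)
    assert torch.allclose(expected, output, atol=1e-2)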

assert attention_mask.size(3) == seq_length

# Move inside kernel ?
attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min,
Contributor Author

move inside kernel
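For context, a small illustration of what the clamp on the quoted line does, assuming the mask is additive and may contain -inf for fully masked positions: clamping to the dtype's minimum finite value presumably avoids -inf propagating into the softmax inside the kernel.

import torch

attention_mask = torch.tensor([[0.0, float("-inf")]], dtype=torch.float16)
clamped = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min)
print(clamped)  # tensor([[0., -65504.]], dtype=torch.float16): -inf becomes the smallest finite fp16 value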

@gaetansnl gaetansnl mentioned this pull request Sep 30, 2022
Member

@pommedeterresautee pommedeterresautee left a comment

lgtm

@pommedeterresautee
Member

================================================================================ 11 passed, 143 deselected, 598 warnings in 255.70s (0:04:15) =================================================================================

@pommedeterresautee pommedeterresautee merged commit bbb6e12 into main Sep 30, 2022
@pommedeterresautee pommedeterresautee deleted the feat/bias-mask branch September 30, 2022 15:59
@pommedeterresautee pommedeterresautee linked an issue Oct 1, 2022 that may be closed by this pull request
Labels
enhancement (New feature or request), model (Model scope, HF, etc.)

Successfully merging this pull request may close these issues.

Manage input mask in Flash Attention
2 participants