feat: add bias mask #66
Conversation
force-pushed from bf96c43 to 9c24916
force-pushed from 9c24916 to 3af4a51
as discussed
implementations/attention.py
Outdated
@@ -9,13 +11,28 @@
# Similar to https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L213
def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
                        is_causal: bool, attention_mask: Union[torch.Tensor, None]):
not very important, but you can type-hint the output: -> torch.Tensor
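E.g. (a sketch only, body elided):

```python
from typing import Union
import torch

def attention_reference(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float,
                        is_causal: bool, attention_mask: Union[torch.Tensor, None]) -> torch.Tensor:
    ...  # body unchanged, only the return annotation is added
```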
TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
output,
q_batch_stride, q_head_stride, q_m_stride, q_k_stride,
k_batch_stride, k_head_stride, k_n_stride, k_k_stride,
k_batch_stride, k_head_stride, k_n_stride, k_k_stride,  # We name n,k instead of k,n because of the transpose
should we keep the comment?
imo yes
MASK_BATCH_SIZE: tl.constexpr,
MASK_HEAD_SIZE: tl.constexpr,
MASK_M_SIZE: tl.constexpr,
MASK_K_SIZE: tl.constexpr,
why are they constant? What makes them different from mask_*_stride for instance?
because we have conditions on them
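For context, a minimal standalone sketch (not this PR's kernel, names are illustrative) of why that matters: Triton resolves branches on `tl.constexpr` arguments at compile time, so each broadcast case gets its own specialized kernel instead of a per-element runtime check.

```python
import triton
import triton.language as tl

@triton.jit
def _mask_row_sketch(mask_ptr, out_ptr, mask_m_stride,
                     MASK_M_SIZE: tl.constexpr, BLOCK_M: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    if MASK_M_SIZE == 1:
        # broadcast case: a single mask row is reused for every query position
        m = tl.load(mask_ptr + tl.zeros([BLOCK_M], dtype=tl.int32))
    else:
        # one mask row per query position
        m = tl.load(mask_ptr + offs_m * mask_m_stride)
    tl.store(out_ptr + offs_m, m)
```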
implementations/attention.py
Outdated
@@ -197,7 +252,8 @@ class Attention(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor, sm_scale: float, is_causal: bool):
def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor,
            sm_scale: float, is_causal: bool, attention_mask: torch.Tensor = None):
attention_mask: torch.Tensor = None -> optional
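E.g. (a sketch of the suggested signature, everything else unchanged):

```python
from typing import Optional
import torch
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd

class Attention(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(ctx: FunctionCtx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, output: torch.Tensor,
                sm_scale: float, is_causal: bool, attention_mask: Optional[torch.Tensor] = None):
        ...  # body unchanged, only the attention_mask annotation becomes Optional
```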
implementations/attention.py
Outdated
assert attention_mask.size(0) == batch or attention_mask.size(0) == 1
assert attention_mask.size(1) == heads or attention_mask.size(1) == 1
assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1
assert attention_mask.size(3) == seq_length
can you add an error message that would serve as documentation? Basically it may look like "mask is neither matching QKt shape nor broadcastable on its XXX axis"
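Sketch of a possible wording (axis names are a suggestion, not final text):

```python
assert attention_mask.size(0) == batch or attention_mask.size(0) == 1, \
    "mask batch size does not match QK^t batch size and is not broadcastable (size 1) on the batch axis"
assert attention_mask.size(1) == heads or attention_mask.size(1) == 1, \
    "mask head count does not match QK^t head count and is not broadcastable (size 1) on the head axis"
assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1, \
    "mask query length does not match QK^t rows and is not broadcastable (size 1) on the query axis"
assert attention_mask.size(3) == seq_length, \
    "mask key length must match QK^t columns (no broadcast supported on this axis)"
```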
implementations/attention.py
Outdated
assert attention_mask.size(2) == seq_length or attention_mask.size(2) == 1
assert attention_mask.size(3) == seq_length

# Move inside kernel ?
no, keep it outside of Triton if it works.
If the trick is from the HF library, can you add a link to the source code?
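For reference, a self-contained sketch of what the clamp below does (the exact HF source link is something the author should confirm and add):

```python
import torch

# Additive mask with -inf on masked positions, as HF-style models produce it.
attention_mask = torch.full((1, 1, 1, 8), float("-inf"), dtype=torch.float16)
attention_mask[..., :4] = 0.0

# Clamping to the dtype's minimum finite value keeps -inf out of the fp16
# softmax input, which avoids inf/NaN surprises when scores and mask are added.
attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min)
```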
implementations/attention.py
Outdated
if MASK_M_SIZE == 1:
    m = tl.load(mask + offs_mask)
else:
    offs_mask += offs_m[:, None] * mask_m_stride
    m = tl.load(mask + offs_mask, eviction_policy="evict_first")
as discussed, we may want to add some comment
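Something along these lines maybe (suggested comments only, code unchanged from the snippet above):

```python
if MASK_M_SIZE == 1:
    # mask has a single row broadcast over every query position: load it once
    m = tl.load(mask + offs_mask)
else:
    # one mask row per query position; each row is read only once, so hint the
    # cache to evict it first and keep the reusable K/V tiles resident instead
    offs_mask += offs_m[:, None] * mask_m_stride
    m = tl.load(mask + offs_mask, eviction_policy="evict_first")
```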
implementations/attention.py
Outdated
heads,
seq_length,
q,
k,
v,
sm_scale,
attention_mask,
tmp,
output,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
output.stride(0), output.stride(1), output.stride(2), output.stride(3),
attention_mask.stride(0) if HAS_MASK else 0,
attention_mask.stride(1) if HAS_MASK else 0,
attention_mask.stride(2) if HAS_MASK else 0,
attention_mask.stride(3) if HAS_MASK else 0,
can you use named arguments? the call is getting very long
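For illustration (a sketch only: the kernel and grid names, plus any parameter names not visible in this diff, are assumptions, and some arguments are omitted):

```python
_fwd_kernel[grid](
    heads, seq_length,
    q, k, v,
    sm_scale,
    attention_mask,
    tmp, output,
    q_batch_stride=q.stride(0), q_head_stride=q.stride(1), q_m_stride=q.stride(2), q_k_stride=q.stride(3),
    k_batch_stride=k.stride(0), k_head_stride=k.stride(1), k_n_stride=k.stride(2), k_k_stride=k.stride(3),
    mask_batch_stride=attention_mask.stride(0) if HAS_MASK else 0,
    mask_head_stride=attention_mask.stride(1) if HAS_MASK else 0,
    mask_m_stride=attention_mask.stride(2) if HAS_MASK else 0,
    mask_k_stride=attention_mask.stride(3) if HAS_MASK else 0,
    # v/output strides and the constexpr mask sizes would be passed the same way
)
```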
test/test_attention.py
Outdated
batch, seq_length = shape
if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
    pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
if implementation == "original" and mask_fn != generate_none_mask:
elif to highlight that we chain our tests
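i.e. (the second skip message is illustrative):

```python
if implementation == "original" and (dtype == torch.bfloat16 or seq_length != 512):
    pytest.skip("Original Triton implementation only supports fp16 and seq_length=512")
elif implementation == "original" and mask_fn != generate_none_mask:
    pytest.skip("Original Triton implementation does not support attention masks")
```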
test/test_attention.py
Outdated
@@ -63,7 +80,7 @@ def test_mixed_stride():
v = torch.rand_like(q)
sm_scale = 0.3

expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False)
expected = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale, is_causal=False, attention_mask=None)
can you please add a dedicated test for the mask?
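Something like this maybe (a sketch: the import path, shapes, and the check are assumptions, and the existing mask generators could be reused instead):

```python
import torch
from implementations.attention import attention_reference

def test_attention_mask_reference():
    batch, heads, seq_length, d_head = 2, 4, 16, 8
    q = torch.rand(batch, heads, seq_length, d_head)
    k = torch.rand_like(q)
    v = torch.rand_like(q)
    sm_scale = 0.3
    # additive mask, broadcastable on batch and head axes, masking out the second half of the keys
    mask = torch.zeros(1, 1, seq_length, seq_length)
    mask[..., seq_length // 2:] = torch.finfo(torch.float32).min
    masked = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale,
                                 is_causal=False, attention_mask=mask)
    unmasked = attention_reference(q=q, k=k, v=v, output=torch.empty_like(q), sm_scale=sm_scale,
                                   is_causal=False, attention_mask=None)
    # masking half of the key positions must change the reference output
    assert not torch.allclose(masked, unmasked)
```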
implementations/attention.py
Outdated
assert attention_mask.size(3) == seq_length

# Move inside kernel ?
attention_mask = attention_mask.clamp(min=torch.finfo(attention_mask.dtype).min,
move inside kernel
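If it does move inside, one option (a sketch, not necessarily how this PR should do it) is to pass the dtype minimum as a scalar argument and clamp the loaded mask tile in the kernel:

```python
# host side: compute once and pass to the kernel as a regular argument
min_clamp_value = torch.finfo(attention_mask.dtype).min

# kernel side, inside the @triton.jit function, right after loading the mask tile:
#     m = tl.maximum(m, min_clamp_value)
```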
lgtm
PS: didn't refactor the mask creation methods, will do it in a separate PR