Changing naive attention to SDPA gives wrong result for batched llama example #22
Comments
For the `_attn` defined in PainlessInferenceAcceleration/pia/lookahead/models/llama/modeling_llama_batch.py, lines 299 to 325 in 6280cb2:

```python
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
    if query.size(0) == 1:
        attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                     key.squeeze(0).transpose(-1, -2))
    else:
        attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
        attn_weights = attn_weights.add_(attention_mask)
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
```

which is a very standard masked SDPA... Why is it not equivalent to using `F.scaled_dot_product_attention`?
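For what it's worth, the two branches really do compute the same scores. Here is a quick standalone check of my own (not from the repo) that `baddbmm` with the mask as its `input` argument equals an unfused matmul plus mask:

```python
# Standalone sanity check (my own sketch, not from the repo): baddbmm folds
# the additive mask into the matmul, so both branches of _attn compute
# scores = q @ k^T + mask.
import torch

q = torch.randn(32, 9, 128)      # (heads, q_len, head_dim), batch dim already squeezed away
k = torch.randn(32, 9, 128)      # (heads, kv_len, head_dim)
mask = torch.randn(32, 9, 9)     # additive attention mask, one per head in this toy case

fused = torch.baddbmm(mask, q, k.transpose(-1, -2))      # batch-size-1 branch
unfused = torch.matmul(q, k.transpose(-1, -2)) + mask    # general branch
print(torch.allclose(fused, unfused, atol=1e-5))         # expect True
```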
OK I see, the original `_attn` does seem to match SDPA in isolation. Here's a simple test:

```python
import torch
import torch.nn.functional as F


def attn(query, key, value, attention_mask=None):
    if query.size(0) == 1:
        attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                     key.squeeze(0).transpose(-1, -2))
    else:
        attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
        attn_weights = attn_weights.add_(attention_mask)
    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output


def sdp_attn(query, key, value, attention_mask=None):
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=1.0)


@torch.inference_mode()
def main():
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

    batch_size = 2
    q_len, kv_len = 9, 9
    head_num = 32
    hidden_size = 128
    q_shape = [batch_size, head_num, q_len, hidden_size]
    kv_shape = [batch_size, head_num, kv_len, hidden_size]
    mask_shape = [batch_size, 1, q_len, kv_len]
    dtype = torch.float16
    device = "cuda"

    query = torch.randn(q_shape, dtype=dtype).to(device)
    key = torch.randn(kv_shape, dtype=dtype).to(device)
    value = torch.randn(kv_shape, dtype=dtype).to(device)
    attention_mask = torch.zeros(mask_shape, dtype=dtype).to(device)

    out = attn(query, key, value, attention_mask)
    sdp_out = sdp_attn(query, key, value, attention_mask)
    print(torch.allclose(out, sdp_out))  # True


if __name__ == "__main__":
    main()
```
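For completeness, the same comparison can be repeated against the fused kernels by flipping the backend flags. A small sketch of my own, meant to be appended at the end of `main()` above (it reuses `query`, `key`, `value`, `attention_mask`, and `out`):

```python
    # Hypothetical follow-up inside main(): re-enable the fused kernels and
    # compare again; the fp16 fused kernels may need a looser tolerance.
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_math_sdp(False)
    fused_out = sdp_attn(query, key, value, attention_mask)
    print(torch.allclose(out, fused_out, atol=1e-3))
```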
Now with

```python
def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
    with torch.backends.cuda.sdp_kernel(enable_math=False):
        return F.scaled_dot_product_attention(query, key, value,
                                              attn_mask=attention_mask, scale=1.0), None
```

I can get correct results.
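Side note: on newer PyTorch releases `torch.backends.cuda.sdp_kernel` is deprecated in favor of `torch.nn.attention.sdpa_kernel`; a rough equivalent of the same restriction (my sketch, with the same `scale=1.0` assumption as above) would be:

```python
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
    # Allow only the fused kernels, mirroring enable_math=False above.
    with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, scale=1.0), None
```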
I attempted to swap in FlashAttention for batched llama by simply changing `self._attn()` to `self._sdp_attn()` inside `LlamaAttention.forward()`:

PainlessInferenceAcceleration/pia/lookahead/models/llama/modeling_llama_batch.py, lines 372 to 375 in 6280cb2
PainlessInferenceAcceleration/pia/lookahead/models/llama/modeling_llama_batch.py, lines 404 to 407 in 6280cb2

where `_sdp_attn` is defined as:

PainlessInferenceAcceleration/pia/lookahead/models/llama/modeling_llama_batch.py, lines 327 to 329 in 6280cb2
However, the model generates wrong results. The original llama_batch_example.py gives:

The modified model gives:

So is `LlamaAttention._attn()` doing something extra beyond just standard attention?