Changing naive attention to SDPA gives wrong result for batched llama example #22

Open · learning-chip opened this issue Mar 5, 2024 · 3 comments



learning-chip commented Mar 5, 2024

I attempted to swap in FlashAttention for the batched llama example by simply changing self._attn() to self._sdp_attn() at the two call sites inside LlamaAttention.forward():

attn_output, attn_weights = self._attn(query_states,
                                       key_states,
                                       value_states,
                                       attention_mask=attention_mask)

attn_output, attn_weights = self._attn(query_states,
                                       past_key[:, :, :max_len],
                                       past_value[:, :, :max_len],
                                       attention_mask=attention_mask)
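
(For clarity, the modified call sites just swap the method name; a sketch, using the same variable names as above:)

    attn_output, attn_weights = self._sdp_attn(
        query_states, key_states, value_states, attention_mask=attention_mask)

    attn_output, attn_weights = self._sdp_attn(
        query_states, past_key[:, :, :max_len], past_value[:, :, :max_len],
        attention_mask=attention_mask)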

where _sdp_attn is defined as:

    def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
        with torch.backends.cuda.sdp_kernel(enable_math=False):
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask), None
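
(Side note for anyone copying this on a newer PyTorch: torch.backends.cuda.sdp_kernel has since been deprecated in favor of torch.nn.attention.sdpa_kernel, which selects the allowed backends explicitly. A rough equivalent of the above, assuming PyTorch 2.3+, would be something like:)

    from torch.nn.attention import SDPBackend, sdpa_kernel  # module-level import

    def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
        # same effect as enable_math=False above: allow only the fused SDPA backends
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask), None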

However, the model then generates wrong results. The original llama_batch_example.py gives:

lookahead:False time:3.326s speed:35.5token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
...

The modified model gives:

lookahead:False time:3.271s speed:39.1token/s response:['the    “ nobody nobody     “ nobody   “ nobody  “ nobody   “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody  “ nobody ', 'nobody “ nobody “ nobody “ nobody to nobody. Unterscheidung the. Unterscheidung nobody. Unterscheidung ( ,  “ nobody, MS nobody, MS nobodyMS nobodyMS nobodyMS nobodyMS nobodyMS nobody. Unterscheidung,MS nobodyMS nobody. Unterscheidung,MS nobodyMS nobodyMS nobodyMS nobody. UnterscheidungMS nobodyMS']

So is LlamaAttention._attn() doing something extra beyond standard attention?


learning-chip commented Mar 5, 2024

For the LlamaAttention._attn() implementation:

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # attns = {}
        # attn_weights = torch.matmul(self.norm_coef*query, key.transpose(-1, -2))
        # attn_weights = attn_weights.masked_fill_(attention_mask==0.0, -65504.0)
        # print(attention_mask.shape, query.shape, key.shape)
        # coef is divide by weight and bias after loading
        if query.size(0) == 1:
            if self.normed:
                attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                             key.squeeze(0).transpose(-1, -2))
            else:
                attn_weights = torch.baddbmm(attention_mask.squeeze(0), self.norm_coef * query.squeeze(0),
                                             key.squeeze(0).transpose(-1, -2))
        else:
            if self.normed:
                attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
            else:
                attn_weights = torch.matmul(key, self.norm_coef * query.transpose(-1, -2)).transpose(-1, -2)
            attn_weights = attn_weights.add_(attention_mask)
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

self.norm_coef is never defined, which means the else branches that reference it are never entered (self.normed must always be true). So the code is equivalent to the following (I checked that the code below gives identical results):

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        if query.size(0) == 1:
            attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                         key.squeeze(0).transpose(-1, -2))
        else:
            attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
            attn_weights = attn_weights.add_(attention_mask)
            
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

which is a very standard masked SDPA... Why is it not equivalent to using scaled_dot_product_attention?

learning-chip (Author) commented:

OK, I see: the original _attn() function is missing the usual 1/sqrt(head_dim) scaling factor (presumably it is already folded into the projection weights at load time, as the "coef is divide by weight and bias after loading" comment suggests). Setting scale=1.0 for scaled_dot_product_attention fixes the problem.
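
In other words, writing the additive mask as $M$ and the head dimension as $d$ (= 128 in the test below), the reference _attn computes

$$\mathrm{softmax}\left(QK^\top + M\right)V,$$

while F.scaled_dot_product_attention uses $\text{scale} = 1/\sqrt{d}$ by default, i.e.

$$\mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d}} + M\right)V,$$

so passing scale=1.0 removes the extra $1/\sqrt{d}$ factor and the two match.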

Here's a simple test:

import torch
import torch.nn.functional as F

def attn(query, key, value, attention_mask=None):
    if query.size(0) == 1:
        attn_weights = torch.baddbmm(attention_mask.squeeze(0), query.squeeze(0),
                                     key.squeeze(0).transpose(-1, -2))
    else:
        attn_weights = torch.matmul(key, query.transpose(-1, -2)).transpose(-1, -2)
        attn_weights = attn_weights.add_(attention_mask)

    attn_weights = F.softmax(attn_weights, dim=-1)
    attn_output = torch.matmul(attn_weights, value)
    return attn_output


def sdp_attn(query, key, value, attention_mask=None):
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, scale=1.0)


@torch.inference_mode()
def main():
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)

    batch_size = 2
    q_len, kv_len = 9, 9
    head_num = 32
    hidden_size = 128
    
    q_shape = [batch_size, head_num, q_len, hidden_size]
    kv_shape = [batch_size, head_num, kv_len, hidden_size]
    mask_shape = [batch_size, 1, q_len, kv_len]

    dtype = torch.float16
    device = "cuda"
    query = torch.randn(q_shape, dtype=dtype).to(device)
    key = torch.randn(kv_shape, dtype=dtype).to(device)
    value = torch.randn(kv_shape, dtype=dtype).to(device)
    attention_mask = torch.zeros(mask_shape, dtype=dtype).to(device)
    
    out = attn(query, key, value, attention_mask)
    sdp_out = sdp_attn(query, key, value, attention_mask)

    print(torch.allclose(out, sdp_out)) # True

if __name__ == "__main__":
    main()
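
(As a sanity check, adding these lines at the end of main() should show the mismatch when the default scale is used instead:)

    # contrast: with the default scale (1/sqrt(head_dim)) the outputs should no longer match
    sdp_out_default = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask)
    print(torch.allclose(out, sdp_out_default))  # expected: False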

learning-chip (Author) commented:

Now with

    def _sdp_attn(self, query, key, value, attention_mask=None, head_mask=None):
        with torch.backends.cuda.sdp_kernel(enable_math=False): 
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, scale=1.0), None

I can get correct results:

lookahead:False time:3.198s speed:36.9token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]



lookahead:False time:1.500s speed:78.7token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]



lookahead:True time:1.305s speed:90.4token/s response:["I'm here to help you.\nI'm here to help you with any questions or problems you might have. I'm a highly advanced AI language model, so I can provide information, answer questions, and even help you with your daily tasks.\n\nIs there something specific you would like to", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]



lookahead:True time:0.976s speed:120.9token/s response:["I'm here to help you.\nI'm just an AI, I don't have personal experiences or emotions like humans do, but I'm here to assist you in any way I can. Is there something specific you would like to know or discuss?\n\nPlease let me know if", "about a chicken crossing the road\n\nSure! Here's a classic one:\n\nWhy did the chicken cross the road?\n\nTo get to the other side... of the bar!\n\nI hope you found that one amusing!</s></s></s></s></s></s></s></s></s></s>"]
