Support for packed layout with paged attention #132
Hi @Corendos, thanks for your question. We are actually enabling this in the upcoming cuDNN frontend release. It's roughly a week out, so if you'd like to enable it earlier, you can comment out L336 in scaled_dot_product_flash_attention.h and manually verify that your backend cuDNN version is v9.7. Please note that in this case only Q will have packed support. In theory, packing could be combined with paged K and V caches: it would then be the page tables themselves that are packed. However, the amount of compression you get from this is minimal. But please let us know if you disagree, or if you see other good use cases for that!
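As a rough illustration of why packing the page tables buys so little, here is a back-of-the-envelope sketch; all sizes are made up for illustration and are not from this thread:

```cpp
#include <cstdio>

// Compare the memory spent on page tables when they are padded to the max
// sequence length versus packed to the actual lengths, next to the KV cache
// itself. All sizes below are illustrative assumptions.
int main() {
    const long long B = 32;              // batch size
    const long long page_size = 16;      // tokens per KV-cache page
    const long long max_seq_len = 4096;  // longest sequence in the batch
    const long long avg_seq_len = 1024;  // average sequence length

    // One table entry (a page index) per page per sequence.
    const long long padded_entries = B * (max_seq_len / page_size);  // 8192
    const long long packed_entries = B * (avg_seq_len / page_size);  // 2048

    // The KV cache dwarfs the tables: K and V, H heads, D dims, fp16.
    const long long H = 8, D = 128;
    const long long kv_bytes = 2 * B * avg_seq_len * H * D * 2;  // ~128 MB

    std::printf("padded page-table entries: %lld\n", padded_entries);
    std::printf("packed page-table entries: %lld\n", packed_entries);
    std::printf("KV cache bytes:            %lld\n", kv_bytes);
    // Even a 4x compression of the tables saves only a few tens of KB next
    // to a KV cache of ~128 MB, which is the point made above.
    return 0;
}
```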
Hey @nvmbreughe, thanks for the quick answer! Funny that you suggest that, because that's exactly what I tried! 😁 To give more context, we are currently working on an ML framework that uses XLA under the hood (through the PJRT abstraction). XLA has some support for cuDNN flash attention, just not the paged version, so we hacked support in. It was working great for decode but was quite slow for prefill due to the naive approach, so I was wondering if cuDNN supported packed + paged, as other kernels do. So overall that's great news! There is still a small issue on the XLA side: I'm faced with a CUDA error about misalignment.

As for packing the page tables: I was mainly looking for packed Q support, so I don't think compressing the page tables is really required. Cheers 😁
Happy to help, @Corendos.
Ha, great! That will work as long as the cuDNN backend is at least v9.7.
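For anyone who wants to double-check this at runtime, here is a minimal sketch using cuDNN's standard cudnnGetProperty API; the 9.7 threshold is the one stated above:

```cpp
#include <cudnn.h>
#include <cstdio>

// Minimal runtime check that the loaded cuDNN backend is at least v9.7,
// which the packed-Q workaround above requires.
int main() {
    int major = 0, minor = 0;
    cudnnGetProperty(MAJOR_VERSION, &major);
    cudnnGetProperty(MINOR_VERSION, &minor);
    std::printf("cuDNN backend: %d.%d\n", major, minor);
    if (major < 9 || (major == 9 && minor < 7)) {
        std::printf("cuDNN >= 9.7 is required for packed Q support\n");
        return 1;
    }
    return 0;
}
```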
Not sure. compute-sanitizer may tell you more, so maybe try running it through there?
Ok, so this was a mistake on my side. Due to the way XLA works, the UID used when building the graph and the one used when executing it can differ. I introduced a mismatch, so the wrong tensors were passed.
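For other readers hitting similar symptoms, here is a toy sketch of the failure mode (not XLA or cuDNN code): the variant pack binds device pointers by UID, so a build/execute UID mismatch silently binds the wrong buffers without any build-time error:

```cpp
#include <cstdint>
#include <unordered_map>

// Toy illustration of the UID mismatch described above. Execution binds
// device pointers to tensors via UID-keyed maps; nothing checks that the
// UIDs used here match the ones assigned when the graph was built.
using VariantPack = std::unordered_map<int64_t, void*>;

VariantPack make_variant_pack(void* q, void* k, void* v, void* o) {
    // These must be the same UIDs the graph was *built* with. If the
    // builder used different values, Q/K/V/O are silently swapped and the
    // error only shows up later as a CUDA fault (e.g., misalignment).
    const int64_t Q_UID = 1, K_UID = 2, V_UID = 3, O_UID = 4;
    return {{Q_UID, q}, {K_UID, k}, {V_UID, v}, {O_UID, o}};
}
```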
About that, I think I discovered a bug. From the documentation, I understand that in the case of packed layout, the total number of tokens T can be smaller than B * S. However, when I try to build a graph with such a difference, I get an error. Here are the cuDNN logs:
If you want to reproduce, here is a gist containing the modified sample I used: https://gist.github.com/Corendos/ab4712e1c53b72ff114b108635bc5c1f I saw that there was a recent release of the cuDNN backend (9.8.0); do you know by any chance if it was fixed in this version?
After a bit more investigation, it seems that the error originates when the graph is validated by the backend.

I also tried with the 9.8.0 release of the cuDNN backend, but the error still triggers. I'll try to see if the error also happens when the kernel is used in a non-paged way, and I'll keep you posted.
Just found out that you can enable logging in the cuDNN backend with the CUDNN_LOGLEVEL_DBG environment variable. The interesting part of the log is the dimension check that fails on the packed Q tensor.

So this seems like a potential issue, as the purpose of packed tensors is to allow a total number of tokens T smaller than B * S. Is there a way to report this directly to the cuDNN backend team? I'd love to help if needed, please let me know how I can contribute! 😁
@Corendos thanks for reporting the bug. @nvmbreughe, can you create an NVBug next week? @Corendos, would you be interested in connecting to discuss your use cases?
Please also note that we tried NOP-ing out the assertion inside cudnn_graph.so, and unfortunately it fails later.
From the sample, it looks like there is a mismatch in how we documented the ragged offset and the Q tensor. Looking at the tensors:

The expectation is that Q keeps its full (B, H, S, D) dimensions, with the ragged offset tensor (of size B + 1) encoding where each sequence's tokens start.

The reason the ragged offset is B + 1 is that the first offset starts at 0. Hope that makes sense.

Regards,
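For readers unfamiliar with ragged offsets, here is a minimal sketch of how the B + 1 offsets could be built from per-sequence lengths; in real cuDNN usage the offsets are scaled by the token stride, while this sketch simply counts tokens:

```cpp
#include <cstdint>
#include <vector>

// Build a ragged offset array for a packed layout: B + 1 entries, where
// entry 0 is always 0 and entry i is the index of the first token of
// sequence i. The last entry equals the total token count T.
std::vector<int64_t> ragged_offsets(const std::vector<int64_t>& seq_lens) {
    std::vector<int64_t> offsets(seq_lens.size() + 1, 0);
    for (size_t i = 0; i < seq_lens.size(); ++i) {
        offsets[i + 1] = offsets[i] + seq_lens[i];
    }
    return offsets;
}

// Example: seq_lens = {3, 5, 2}  ->  offsets = {0, 3, 8, 10}, so T = 10,
// which can be well below B * S_max = 3 * 5 = 15.
```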
Hi all!

@mnicely I would love to, how do you want to proceed?

It's true that the documentation of the kernel says that, but there is also the Supported Tensor Layouts part, which introduces a new name for a dimension. It says that in the case of packed layout, Q has a shape expressed with T, the total number of tokens, which can be smaller than B * S. Also, in my understanding, forcing the ragged offset to be of size B + 1 does not require the first dimension of Q to be exactly B * S.
Hi all,
I've been playing with cudnn-frontend to test the Flash Attention kernel. Overall, it's easy to use and fast, but I've come across a limitation that I don't really understand. It seems that the kernel can't be used in paged mode with packed tensors. This is something that other paged attention kernels support (and it makes a big difference in terms of performance, as tokens can be batched per sequence).

So, two questions about that: is this a limitation of cudnn-frontend? Because I couldn't find such a limitation in the backend doc. And if it is, is there a plan to support it?