
Support for packed layout with paged attention #132

Open
Corendos opened this issue Mar 4, 2025 · 10 comments

Comments


Corendos commented Mar 4, 2025

Hi all,

I've been playing with cudnn-frontend to test the Flash Attention kernel. Overall, it's easy to use and fast, but I've come across a limitation that I don't really understand.

It seems that the kernel can't be used in paged mode with packed tensors. This is something that other paged attention kernels support (and it makes a big difference in performance, since tokens can be batched per sequence).

So, two questions about that:

  1. Is this a limitation only in cudnn-frontend? I couldn't find such a limitation in the backend documentation.
  2. Are there plans to add this feature in the future?
@nvmbreughe

Hi @Corendos,

Thanks for your question. We are actually enabling this in the upcoming cuDNN frontend release. It's roughly a week out, so if you'd like to enable it earlier, you can just comment out L336 in scaled_dot_product_flash_attention.h and manually verify that your backend cuDNN version is v9.7 or later.

Please note that in this case only Q will have packed support. In theory, packing could be combined with paged K and V caches: it would then be the page tables themselves that would be packed. However, the amount of compression you get from this is minimal. But please let us know if you disagree or if you see other good use cases for that!
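
As a sanity check, here is a minimal sketch of that manual version verification (assuming cuDNN 9.x's major * 10000 + minor * 100 + patch version encoding, e.g. 90701 for 9.7.1, consistent with the "v90800" that shows up in the logs later in this thread):

#include <cudnn.h>
#include <cstdio>

int main() {
    // cuDNN 9.x encodes its version as major * 10000 + minor * 100 + patch.
    size_t version = cudnnGetVersion();
    if (version < 90700) {
        std::fprintf(stderr, "cuDNN backend %zu is older than 9.7.0\n", version);
        return 1;
    }
    std::printf("cuDNN backend %zu is recent enough\n", version);
    return 0;
}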


Corendos commented Mar 4, 2025

Hey @nvmbreughe, thanks for the quick answer!

Funny that you suggest that, because that's exactly what I tried! 😁
I was just not sure about the correctness of the output.

To give more context, we are currently working on an ML framework that uses XLA under the hood (through the PJRT abstraction). XLA has some support for cudnn flash attention, just not the paged version, so we hacked in support. It was working great for decode but quite slow for prefill due to the naive approach, so I was wondering if cudnn supported packed + paged like other kernels do. So overall that's great news!

There is still a small issue on the XLA side: I'm facing a CUDA error about misalignment.
I tweaked the cudnn frontend samples to reproduce it, but the error doesn't trigger. The log output of cudnn is the same in both cases, so I guess it's an issue in XLA; I just wanted to see if that rang a bell?

As for:

In theory, packing could be combined with paged K and V caches: it would then be the page tables themselves that would be packed. However, the amount of compression you get from this is minimal. But please let us know if you disagree or if you see other good use cases for that!

I was mainly looking for packed Q support, so I don't think compressing the page tables is really required.

Cheers 😁

@nvmbreughe

Happy to help, @Corendos.

Funny that you suggest that, because that's exactly what I tried!

Ha, great! That will work as long as the cuDNN backend is at least v9.7.

There is still a small issue on the XLA side: I'm facing a CUDA error about misalignment.
I tweaked the cudnn frontend samples to reproduce it, but the error doesn't trigger. The log output of cudnn is the same in both cases, so I guess it's an issue in XLA; I just wanted to see if that rang a bell?

Not sure. compute-sanitizer "may" tell you more, so maybe try running it through there?
Do you have a way to get the starting addresses of each tensor XLA allocated? Does it happen with paged caches only?


Corendos commented Mar 6, 2025

Not sure. compute-sanitizer "may" tell you more, so maybe try running it through there?
Do you have a way to get the starting addresses of each tensor XLA allocated? Does it happen with paged caches only?

OK, so this was a mistake on my side. Due to the way XLA works, the UID used when building the graph can differ from the one used when executing it. I introduced a mismatch, so the wrong tensors were passed.
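
For anyone hitting the same thing, here is a rough sketch of pinning the UIDs explicitly (assuming the cudnn-frontend 1.x graph API; `q_device_ptr` is a placeholder, and the dims/strides are taken from the logs below):

#include <cudnn_frontend.h>
#include <unordered_map>

namespace fe = cudnn_frontend;

constexpr int64_t Q_UID = 1;  // fixed once, shared by build and execution

void build_q(fe::graph::Graph& graph) {
    auto Q = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("Q")
                              .set_uid(Q_UID)  // explicit, not auto-assigned
                              .set_dim({17, 32, 1, 128})
                              .set_stride({4096, 128, 128, 1}));
    // ... add the SDPA node, validate, and build the execution plan ...
}

void run(fe::graph::Graph& graph, cudnnHandle_t handle,
         void* q_device_ptr, void* workspace) {
    // Key the variant pack by the same constant so the mapping cannot drift:
    std::unordered_map<int64_t, void*> variant_pack = {{Q_UID, q_device_ptr}};
    graph.execute(handle, variant_pack, workspace);
}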

Ha, great! That will work as long as the cuDNN backend is at least v9.7.

About that, I think I discovered a bug. From the documentation, I understand that in the case of a packed layout, Q's first dimension (let's call it T) can be different from the batch size B.

However, when I try to build a graph with such a difference, I get an error. Here are the cuDNN logs:

[cudnn_frontend] 
{"context":{"compute_data_type":"FLOAT","intermediate_data_type":"FLOAT","io_data_type":"HALF","name":"","sm_count":-1},"cudnn_backend_version":"9.7.1","cudnn_frontend_version":11000,"json_version":"1.0","nodes":[{"alibi_mask":false,"attn_scale_value":"3DB504F3","diagonal_alignment":"TOP_LEFT","dropout_probability":null,"inputs":{"K":"container_K","Page_table_K":"page_table_k","Page_table_V":"page_table_v","Q":"Q","SEQ_LEN_KV":"seq_kv","SEQ_LEN_Q":"seq_q","V":"container_V"},"is_inference":true,"left_bound":null,"max_seq_len_kv":4096,"name":"flash_attention","outputs":{"O":"flash_attention::O"},"padding_mask":true,"right_bound":null,"tag":"SDPA_FWD"}],"tensors":{"Q":{"data_type":null,"dim":[17,32,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"Q","pass_by_value":null,"reordering_type":"NONE","stride":[4096,128,128,1],"uid":1,"uid_assigned":true},"container_K":{"data_type":null,"dim":[32768,8,16,128],"is_pass_by_value":false,"is_virtual":false,"name":"container_K","pass_by_value":null,"reordering_type":"NONE","stride":[16384,128,1024,1],"uid":2,"uid_assigned":true},"container_V":{"data_type":null,"dim":[32768,8,16,128],"is_pass_by_value":false,"is_virtual":false,"name":"container_V","pass_by_value":null,"reordering_type":"NONE","stride":[16384,128,1024,1],"uid":3,"uid_assigned":true},"flash_attention::O":{"data_type":null,"dim":[17,32,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"flash_attention::O","pass_by_value":null,"reordering_type":"NONE","stride":[4096,128,128,1],"uid":4,"uid_assigned":true},"page_table_k":{"data_type":"INT32","dim":[16,1,256,1],"is_pass_by_value":false,"is_virtual":false,"name":"page_table_k","pass_by_value":null,"reordering_type":"NONE","stride":[256,256,1,1],"uid":9,"uid_assigned":true},"page_table_v":{"data_type":"INT32","dim":[16,1,256,1],"is_pass_by_value":false,"is_virtual":false,"name":"page_table_v","pass_by_value":null,"reordering_type":"NONE","stride":[256,256,1,1],"uid":10,"uid_assigned":true},"seq_kv":{"data_type":"INT32","dim":[17,1,1,1],"is_pass_by_value":false,"is_virtual":false,"name":"seq_kv","pass_by_value":null,"reordering_type":"NONE","stride":[1,1,1,1],"uid":8,"uid_assigned":true},"seq_q":{"data_type":"INT32","dim":[17,1,1,1],"is_pass_by_value":false,"is_virtual":false,"name":"seq_q","pass_by_value":null,"reordering_type":"NONE","stride":[1,1,1,1],"uid":7,"uid_assigned":true}}}
[cudnn_frontend] INFO: Validating SDPANode flash_attention...
[cudnn_frontend] INFO: Validating SDPANode flash_attention...
[cudnn_frontend] INFO: Inferrencing properties for Scaled_dot_product_flash_attention node  flash_attention...
[cudnn_frontend] INFO: Validating PagedCacheLoadNode paged_k_cache_operation...
[cudnn_frontend] INFO: Inferrencing properties for matmul node bmm1...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node attn_scale...
[cudnn_frontend] INFO:attn_scale::OUT_0 stride computed from bmm1::C
[cudnn_frontend] INFO: Inferrencing properties for pointwise node gen_row_idx_padding...
[cudnn_frontend] INFO:gen_row_idx_padding::OUT_0 stride computed from attn_scale::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node gen_col_idx_padding...
[cudnn_frontend] INFO:gen_col_idx_padding::OUT_0 stride computed from attn_scale::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node lt_row_sq_padding...
[cudnn_frontend] INFO:lt_row_sq_padding::OUT_0 stride computed from gen_row_idx_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node lt_col_skv_padding...
[cudnn_frontend] INFO:lt_col_skv_padding::OUT_0 stride computed from gen_col_idx_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node and_row_col_padding...
[cudnn_frontend] INFO:and_row_col_padding::OUT_0 stride computed from lt_col_skv_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node select_padding...
[cudnn_frontend] INFO:select_padding::OUT_0 stride computed from and_row_col_padding::OUT_0
[cudnn_frontend] INFO: Validating SoftmaxNode softmax...
[cudnn_frontend] INFO: Inferrencing properties for Softmax node softmax.
[cudnn_frontend] INFO: Inferrencing properties for reduction node M...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node sub...
[cudnn_frontend] INFO:sub_M stride computed from select_padding::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node exp...
[cudnn_frontend] INFO:exp_sub_M stride computed from sub_M
[cudnn_frontend] INFO: Inferrencing properties for reduction node sum...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node log...
[cudnn_frontend] INFO: Inferrencing properties for pointwise node add...
[cudnn_frontend] INFO: stride computed from log::OUT_0
[cudnn_frontend] INFO: Inferrencing properties for pointwise node div...
[cudnn_frontend] INFO: stride computed from exp_sub_M
[cudnn_frontend] INFO: Validating PagedCacheLoadNode paged_v_cache_operation...
[cudnn_frontend] INFO: Inferrencing properties for matmul node bmm2...
[cudnn_frontend] INFO: Creating cudnn tensors for node named 'flash_attention':
[cudnn_frontend] INFO: Creating Backend Tensor named 'attn_scale::IN_1' with UID 5
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["FLOAT"] Id: 5 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 1,1,1,1 ] Str [ 1,1,1,1 ] isVirtual: 0 isByValue: 1 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] INFO: Creating Backend Tensor named 'container_V' with UID 3
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["HALF"] Id: 3 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 32768,8,16,128 ] Str [ 16384,128,1024,1 ] isVirtual: 0 isByValue: 0 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] INFO: Creating Backend Tensor named 'container_K' with UID 2
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["HALF"] Id: 2 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 32768,8,16,128 ] Str [ 16384,128,1024,1 ] isVirtual: 0 isByValue: 0 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] INFO: Creating Backend Tensor named 'Q' with UID 1
[cudnn_frontend] INFO: Creating Backend Tensor named 'ragged_offset_q' with UID 12
[cudnn_frontend] CUDNN_BACKEND_TENSOR_DESCRIPTOR : Datatype: ["INT32"] Id: 12 nDims 4 VectorCount: 1 vectorDimension -1 Dim [ 17,1,1,1 ] Str [ 1,1,1,1 ] isVirtual: 0 isByValue: 0 Alignment: 16 reorder_type: ["NONE"]
[cudnn_frontend] ERROR: CUDNN_BACKEND_TENSOR_DESCRIPTOR cudnnFinalize failedptrDesc->finalize() cudnn_status: CUDNN_STATUS_BAD_PARAM. ["CUDNN_BACKEND_API_FAILED"] because (e.getCudnnStatus() != CUDNN_STATUS_SUCCESS) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/cudnn_interface.h:86
[cudnn_frontend] ERROR: detail::create_cudnn_tensor(tensor, tensors, potential_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/node_interface.h:395
[cudnn_frontend] ERROR: create_cudnn_tensors_node(uid_to_backend_tensors, potential_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/node_interface.h:242
[cudnn_frontend] ERROR: sub_node->create_cudnn_tensors_subtree(uid_to_backend_tensors, potential_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/node_interface.h:244
[cudnn_frontend] ERROR: create_cudnn_tensors_subtree(uid_to_tensors, start_uid, used_uids) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/graph_interface.h:566
[cudnn_frontend] ERROR: this->build_operation_graph(handle) at /mnt/hugo/cudnn-frontend/include/cudnn_frontend/graph_interface.h:1502

If you want to reproduce, here is a gist containing the modified sample I used: https://gist.github.com/Corendos/ab4712e1c53b72ff114b108635bc5c1f

I saw that there was a recent release of the cuDNN backend (9.8.0); do you know by any chance if this was fixed in that version?


Corendos commented Mar 6, 2025

After a bit more investigation, it seems that the error originates when the Q tensor is built here:

status = detail::finalize(m_tensor.pointer->get_backend_descriptor());

I also tried with the 9.8.0 release of the cuDNN backend, but the error still triggers.

I'll check whether the error also happens when the kernel is used in a non-paged way and keep you posted.


Corendos commented Mar 7, 2025

I just found out that you can enable logging in the cuDNN backend with CUDNN_LOGDEST_DBG=stdout CUDNN_LOGLEVEL_DBG=3, and here is the output:

cudnn_debug_log.txt

The interesting part being:

I! CuDNN (v90800 87) function cudnnBackendFinalize() called:
i!     descriptor: type=CUDNN_BACKEND_TENSOR_DESCRIPTOR:
i!         type: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=4;
i!         dimA: type=int; val=[256,32,1,128];
i!         strideA: type=int; val=[4096,128,128,1];
i!         uid: type=int64_t; val=1;
i!         alignmentInBytes: type=int64_t; val=16;
i!         isVirtual: type=bool; val=false;
i!         isByVal: type=bool; val=false;
i! Time: 2025-03-07T09:16:49.176468 (0d+0h+0m+1s since start)
i! Process=354365; Thread=354365; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v90800 87) function cudnnBackendFinalize() called:
i!     status: type=cudnnStatus_t; val=CUDNN_STATUS_BAD_PARAM (2000);
i! Time: 2025-03-07T09:16:49.176476 (0d+0h+0m+1s since start)
i! Process=354365; Thread=354365; GPU=NULL; Handle=NULL; StreamId=NULL.


E! CuDNN (v90800 87) function cudnnBackendFinalize() called:
e!     Info: Traceback contains 3 message(s)
e!         Error: CUDNN_STATUS_BAD_PARAM; Reason: CUDNN_ATTR_TENSOR_RAGGED_OFFSET_DESC ragged dim should match dim value + 1 of original tensor. All other offset dim values should be singleton. at: offset_dimA[dim] != this->_dimA[dim] + 1 && offset_dimA[dim] != 1
e!         Error: CUDNN_STATUS_BAD_PARAM; Reason: finalize_internal()
e!         Error: CUDNN_STATUS_BAD_PARAM; Reason: ptrDesc->finalize()
e! Time: 2025-03-07T09:16:49.176482 (0d+0h+0m+1s since start)
e! Process=354365; Thread=354365; GPU=NULL; Handle=NULL; StreamId=NULL.

So this seems like a potential bug: the whole point of packed tensors is to allow a Q tensor with more than one token per batch entry.

Is there a way to report this directly to the cuDNN backend team? I'd love to help if needed; please let me know how I can contribute! 😁


mnicely commented Mar 8, 2025

@Corendos thanks for reporting the bug. @nvmbreughe, can you create an NVBug next week?

@Corendos would you be interested in connecting to discuss your use cases?


steeve commented Mar 8, 2025

Please also note that we tried NOP-ing out the assertion inside cudnn_graph.so, and unfortunately it fails later.

@Anerudhan

Hi @Corendos / @steeve,

From the sample, it looks like there is a mismatch between our documentation and the expected shapes of the ragged offset and Q tensors.

Looking at the tensors:

"Q" -> dim [17,32,1,128]
"ragged_offset_q" -> dim [17,1,1,1]
"page_table_k" -> data_type INT32, dim [16,1,256,1]

The expectation is:

Q is [B,H,S,D]
Ragged offset is [B+1,1,1,1]
Page_table is [B,...]

The reason the ragged offset has B+1 entries is that the first offset starts at 0.
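
To illustrate, a small sketch of building such an offset (illustrative only: whether the entries count tokens, elements, or bytes depends on the layout; this just shows the B+1 shape):

#include <cstdint>
#include <vector>

// The B+1 ragged offset is the exclusive prefix sum of the per-sequence
// lengths, with a leading 0 entry.
std::vector<int32_t> make_ragged_offset(const std::vector<int32_t>& seq_len) {
    std::vector<int32_t> offset(seq_len.size() + 1, 0);  // B + 1 entries
    for (size_t i = 0; i < seq_len.size(); ++i)
        offset[i + 1] = offset[i] + seq_len[i];
    return offset;  // e.g. {5, 2, 7} -> {0, 5, 7, 14}
}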

Hope that makes sense.

Regards,
Anerudhan

@Corendos

Hi all!

@Corendos would you be interested in connecting to discuss your use cases?

@mnicely I would love to; how do you want to proceed?

The expectation is:

Ragged offset is [B+1,1,1,1]
Page_table is [B,...]

The reason the ragged offset has B+1 entries is that the first offset starts at 0.

It's true that the kernel documentation says that, but there is also the Supported Tensor Layouts section, which introduces a new name for a dimension. It says that in the case of a packed layout, Q has a shape called THD, with T = sum(seq_len), and this allows the batch size B and T to be different.

Also, in my understanding, forcing the ragged offset to be of size B + 1 and Q to be of size B is no different from the non-packed layout. In that case, you have a 1-to-1 mapping between Q "slots" and offsets, which is equivalent to non-packed. The use case I see for this kernel (and how other popular paged attention kernels work) is to allow prefill, where you process multiple input tokens per batch entry. In other words, you have T >> B, which currently seems to be impossible.
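
To make the numbers concrete, here is a hypothetical prefill example (plain arithmetic, not cuDNN API):

#include <numeric>
#include <vector>

int main() {
    // Two prefill requests with prompt lengths 100 and 50 packed together:
    std::vector<int> seq_len_q = {100, 50};  // B = 2
    int T = std::accumulate(seq_len_q.begin(), seq_len_q.end(), 0);
    // T == 150 >> B == 2: packed Q would be [T, H, D], while the page tables
    // and seq-length tensors stay [B, ...]. Tying Q's first dim to B instead
    // allows only one token per request, i.e. the decode-only case.
    return T == 150 ? 0 : 1;
}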
