v1.11.0 release #136
Conversation
d97adcd to 3ef994b
## cudnn frontend v1.11 release notes

cudnn frontend v1.11 is the preferred cudnn frontend version for cudnn version 9.8.0 and above. With cuDNN frontend v1.11, the minimum supported cudnn version is 9.0.0.

## New API

- cudnn frontend v1.11 adds a flexible score modifier to the Python SDPA API. Samples showcasing a soft cap of the attention scores and an arrow mask are available in [cudnn_frontend/test/python/test_flexible_sdpa.py](https://github.com/NVIDIA/cuDNN-frontend/blob/main/cudnn_frontend/test/python/test_flexible_sdpa.py). A sample usage of the score modifier is shown below:

```python
score_mod=partial(
    custom_mask,
    mod_tensor=mod_tensor,
    neg_inf=neg_inf_tensor,
    seq_len_q=seq_len_q,
    seq_len_kv=seq_len_kv,
)
```

- The Concatenate operation merges two or more tensors into one along the specified axis. The user may also specify an in-place merge.

```cpp
std::shared_ptr<Tensor_attributes> concatenate(std::vector<std::shared_ptr<Tensor_attributes>>, Concatenate_attributes);
```

- pip wheels compatible with the Windows x86_64 architecture are now available on [PyPI](https://pypi.org/project/nvidia-cudnn-frontend/).

- The SDPA paged attention API now supports a ragged Q tensor when used with cudnn version 9.7.0 and above.

## Improvements

- Users can now pass the CMake flag `-DCMAKE_CXX_FLAGS="-DNV_CUDNN_FRONTEND_DISABLE_LOGGING"` to disable logging in the cuDNN frontend.

- Added a new sample showcasing native CUDA graph creation from cudnn for the SDPA bprop operation, and fixed a bug when using the `update_cuda_graph` API to update the CUDA graph for the SDPA bprop operation.

- Updated the `create_container_and_page_table` example function to use the layout desired for the more performant kernel.

## Bug Fixes

- Fixed a memory leak in the test harness for some legacy tests that use ragged tensors.

- Fixed a bug introduced in the benchmarking script that prevented the SDPA cudnn operation from being executed; the `use_padding_mask` attribute had been made mandatory for the SDPA operation.
- Updated the paged attention sample so it no longer causes illegal memory access when the dimensions of the tensors in the sample are changed.

- Updated the DgradDReluBNBwdWeight sample to perform the correct operation for the dgrad + drelu fusion.
Whoops, forgot to actually finish off my review.
```cpp
INode::concatenate(std::vector<std::shared_ptr<Tensor_attributes>> x,
                   Concatenate_attributes attributes,
                   std::shared_ptr<Tensor_attributes> y) {
    for (auto& element : x) {
```
Nit: you could add a `reserve` here to pre-allocate space for the concatenation.
```cpp
auto [O, Stats] = graph.sdpa(q, k, v, attributes);
if (fn.has_value()) {
    attributes.set_score_mod(wrapper_function);
    callback_fn = fn;
```
Can `fn` be moved into `callback_fn` here?