Use torch sdpa implementation in ASR mha #9590
Conversation
I also attempted to run the tests in the repository but encountered an issue: NaNs appear when a mask with a fully False row is passed to MHA. With such a mask, every score in that row gets filled with -inf, so the softmax over that row produces NaN. What can be done about this? Should we write a separate test that doesn't use such masks as input?
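The NaN behavior described above can be reproduced in isolation (this is a hypothetical repro, not the repository's test code): a boolean mask row that is entirely False disallows every key position, so the softmax normalizes over all -inf scores and yields NaN.

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 1, 4, 8)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Boolean attn_mask: True = attend, False = masked out.
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
mask[0, 0, 0, :] = False      # first query row attends to nothing

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
# The fully masked row is expected to come out as NaN;
# the other rows remain finite.
print(out[0, 0, 0])
```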
@SeanNaren @titu1994 Using SDPA is also a good option because of the Triton-based version of FAv2 with custom
@SeanNaren @titu1994 haha, and now that FAv3 is out, PyTorch will probably integrate it too in the near term, for maximum brr on H100 :) so having NeMo's Conformer automatically benefit from this new work would be awesome
This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR; otherwise it will be closed in 7 days.
stale bump
Thanks for the contribution! Just some notes:
Force-pushed from 29ac6c0 to 741be10
@VahidooX Big thanks for the review!
Do you have a script for calculating metrics on LS, and reference metrics?
The PR looks great! Sorry for the delay in review.
I would be OK with making torch SDPA the default (True) only if we can add a test that runs this function twice, once with the flag set to True and once to False, and compares the outputs for an MSE difference of 1e-4 or lower.
Could you add your example code as a unit test somewhere as a check?
Thanks @titu1994
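The requested equivalence check could be sketched like this (function and test names here are illustrative, not NeMo's actual test code): run the same attention computation with an explicit eager implementation and with torch SDPA, then require the outputs to agree within the 1e-4 MSE tolerance.

```python
import torch
import torch.nn.functional as F

def eager_attention(q, k, v):
    # Reference path: explicit softmax(QK^T / sqrt(d)) V
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

def test_sdpa_matches_eager():
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
    ref = eager_attention(q, k, v)
    sdpa = F.scaled_dot_product_attention(q, k, v)
    mse = (ref - sdpa).pow(2).mean().item()
    assert mse < 1e-4, f"MSE too large: {mse}"

test_sdpa_matches_eager()
```

In the actual NeMo test this would run the MHA module itself with `use_pytorch_sdpa` toggled rather than bare functional calls.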
Thanks much for the PR!
LGTM, minor question.
I evaluated the PR on the HF-Leaderboard datasets and observed no difference in WER. For the LS test set on an A6000, it improved RTFx by 5%.
Review comment on examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml (outdated, resolved)
I'm working on the tests and ran into a problem with
Yes, that looks to be incorrect.
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
Let's merge
Rerunning CI, as a lot of ASR tests are failing.
Are the failed tests unrelated to the changes in this PR?
The relevant test
Merge?
* use pytorch sdpa
* sdpa work
* Apply isort and black reformatting
* sdpa flag to false & sdpa_backend arg
* Apply isort and black reformatting
* change arg name
* Apply isort and black reformatting
* fix config args
* Apply isort and black reformatting
* add condition on version
* Apply isort and black reformatting
* update condition on version
* remove condition on torch version
* Apply isort and black reformatting
* move code to init
* Apply isort and black reformatting
* refactor
* Apply isort and black reformatting
* refactor

Signed-off-by: WoodieDudy <goshagks@gmail.com>
Signed-off-by: titu1994 <titu1994@users.noreply.github.com>
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: titu1994 <titu1994@users.noreply.github.com>
Co-authored-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
Hi. I changed the MHA implementation in the ASR modules so that it uses
torch.nn.functional.scaled_dot_product_attention.
This accelerated the MHA forward pass by 27% and the backward pass by 17% on an A100.
PyTorch SDPA is continuously being optimized, so we benefit from the latest performance improvements.
My code uses the memory-efficient backend in SDPA because flash attention doesn't support a custom attention bias. There is ongoing work to contribute custom bias support in the flash-attention repository (see the open PR there).
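The reason the memory-efficient backend matters is that it accepts an additive float `attn_mask`, which is how a bias such as Conformer's relative positional term can be folded into attention; the flash backend rejects such a bias. A minimal sketch (illustrative shapes, not the PR's code):

```python
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 4, 10, 32) for _ in range(3))
bias = torch.randn(1, 4, 10, 10)  # e.g. a relative positional bias

# A float attn_mask is added to the attention scores before softmax.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

# Equivalent eager computation, for comparison:
scores = q @ k.transpose(-2, -1) / (32 ** 0.5) + bias
ref = torch.softmax(scores, dim=-1) @ v
print((out - ref).abs().max())  # expected to be tiny (float rounding)
```

On CUDA, the backend can additionally be pinned (e.g. via PyTorch's SDPA kernel-selection context manager), though the exact API varies across torch versions.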
What else do I need to do to merge this PR?
Usage
Here is my benchmark:
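The benchmark results themselves are not captured above; the following is only a hypothetical sketch of how such a forward+backward timing could be done (shapes and helper names are made up, and real measurements would need CUDA synchronization and larger sizes):

```python
import time
import torch
import torch.nn.functional as F

def bench(fn, n=10):
    # Warm up once, then average wall-clock time over n runs.
    fn()
    t0 = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - t0) / n

q, k, v = (torch.randn(4, 8, 256, 64, requires_grad=True) for _ in range(3))

def sdpa_step():
    out = F.scaled_dot_product_attention(q, k, v)
    out.sum().backward()

print(f"sdpa fwd+bwd: {bench(sdpa_step) * 1e3:.2f} ms")
```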
PR Type:
Who can review?
cc @titu1994 @SeanNaren
Additional Information