Use torch sdpa implementation in ASR mha #9590

Merged: 21 commits merged into NVIDIA:main on Oct 16, 2024

Conversation

@WoodieDudy (Contributor) commented Jul 2, 2024

Hi. I changed the MHA implementation for the ASR modules so that it uses torch.nn.functional.scaled_dot_product_attention.
This sped up the MHA forward pass by 27% and the backward pass by 17% on an A100.
PyTorch SDPA is continuously being optimized, so we benefit from the latest performance improvements.
My code uses the memory-efficient backend in SDPA because flash attention doesn't support a custom attention bias. There is ongoing work to contribute custom bias support in the flash-attention repository. PR. (A minimal sketch of pinning that backend is below.)
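
Not part of the PR diff, just a minimal sketch of the backend choice described above, assuming a CUDA device and a precomputed additive bias (rel_bias is a hypothetical name): when a floating-point attn_mask is passed, the flash backend is unavailable, and the memory-efficient backend can be selected explicitly via the sdp_kernel context manager (deprecated in newer PyTorch in favor of torch.nn.attention.sdpa_kernel).

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
rel_bias = torch.randn(2, 8, 128, 128, device="cuda", dtype=torch.float16)  # hypothetical additive attention bias

# Restrict SDPA to the memory-efficient kernel (flash cannot take a float bias).
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_bias)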

What else do I need to do to get this PR merged?

Usage

Here is the benchmark I used:

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark
from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention

torch.manual_seed(123)

device = "cuda"
batch_size = 32
seq_len = 1024
d_model = 512
n_head = 8

query = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
key = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
value = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)
mask = torch.ones(batch_size, seq_len, seq_len, device=device, requires_grad=False)
mask = torch.triu(mask, diagonal=1).bool()  # mask: True -> zero out, False -> leave unchanged
mask = None  # the benchmark below runs without a mask; remove this line to use the causal mask above
pos_emb = torch.rand(batch_size, seq_len, d_model, device=device, requires_grad=True)

attention_sdpa = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=True).to(device)
attention_original = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=False).to(device)
for original_param, sdpa_param in zip(attention_original.parameters(), attention_sdpa.parameters()):
    original_param.data.copy_(sdpa_param.data)

# attention_sdpa = torch.compile(attention_sdpa)
# attention_original = torch.compile(attention_original)


def measure_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='attention(query, key, value, mask, pos_emb);torch.cuda.synchronize()',
        setup='torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )

    with torch.no_grad():
        torch.cuda.synchronize()
        results = timer.blocked_autorange(min_run_time=10)
        forward_time = results.mean
        output = attention(query, key, value, mask, pos_emb)
    return forward_time, output


def measure_fwd_bwd_time(attention, query, key, value, mask, pos_emb):
    timer = benchmark.Timer(
        stmt='loss=attention(query, key, value, mask, pos_emb).sum();torch.cuda.synchronize();loss.backward();torch.cuda.synchronize()',
        globals={'attention': attention, 'query': query, 'key': key, 'value': value, 'mask': mask, 'pos_emb': pos_emb}
    )
    torch.cuda.synchronize()
    results = timer.blocked_autorange(min_run_time=10)
    fwd_bwd_time = results.mean
    return fwd_bwd_time


time_fwd_original, output_original = measure_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_sdpa, output_sdpa = measure_time(attention_sdpa, query, key, value, mask, pos_emb)

print(f"Original implementation time: {time_fwd_original:.6f} seconds")
print(f"SDPA implementation time: {time_fwd_sdpa:.6f} seconds")
print(f"SDPA boost {(time_fwd_original - time_fwd_sdpa) / time_fwd_original * 100:.2f}%")

time_fwd_bwd_original = measure_fwd_bwd_time(attention_original, query, key, value, mask, pos_emb)
time_fwd_bwd_sdpa = measure_fwd_bwd_time(attention_sdpa, query, key, value, mask, pos_emb)
time_bwd_original = time_fwd_bwd_original - time_fwd_original
time_bwd_sdpa = time_fwd_bwd_sdpa - time_fwd_sdpa

print(f"Original implementation backward time: {time_bwd_original:.6f} seconds")
print(f"SDPA implementation backward time: {time_bwd_sdpa:.6f} seconds")
print(f"SDPA backward boost {(time_bwd_original - time_bwd_sdpa) / time_bwd_original * 100:.2f}%")

print(f"Outputs are {'the same' if torch.allclose(output_original, output_sdpa, atol=1e-5) else 'different'}")

# Original implementation time: 0.049075 seconds
# SDPA implementation time: 0.035598 seconds
# SDPA boost 27.46%
# Original implementation backward time: 0.127004 seconds
# SDPA implementation backward time: 0.104986 seconds
# SDPA backward boost 17.34%
# Outputs are the same

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

cc @titu1994 @SeanNaren

Additional Information

github-actions bot added the ASR label Jul 2, 2024
@WoodieDudy (Contributor, Author)

I also attempted to run the tests in the repository but hit an issue: NaNs appear when a mask with a fully False row is passed to the MHA. With such a mask, filling matrix_bd with -inf via matrix_bd.masked_fill_(mask.logical_not(), float("-inf")) produces a row of only -inf, and after the softmax that entire row becomes NaN. I'm not sure how to resolve this, since the softmax and the multiplication by value happen inside torch.nn.functional.scaled_dot_product_attention, where I cannot intervene. In your implementation this is handled by manually zero-filling: attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0).

What can be done about this? Should we write a separate test that doesn't use such masks as input? One possible workaround is sketched below.
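
As an illustration only, and not necessarily what this PR ends up doing: keep the -inf bias, but zero out the output rows whose keys are all masked, mirroring the original masked_fill(mask, 0.0) behavior. The mask convention follows the comment above (True = attend, False = masked), and the helper name is hypothetical.

import torch
import torch.nn.functional as F

def sdpa_with_fully_masked_row_guard(q, k, v, matrix_bd, mask):
    # mask: (batch, 1, q_len, k_len), True = attend, False = masked out
    bias = matrix_bd.masked_fill(mask.logical_not(), float("-inf"))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
    # Rows with no attendable key produce NaN after the internal softmax;
    # overwrite those output rows with zeros, like the original implementation.
    fully_masked = mask.logical_not().all(dim=-1, keepdim=True)  # (batch, 1, q_len, 1)
    return out.masked_fill(fully_masked, 0.0)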

@vadimkantorov (Contributor) commented Jul 9, 2024

@SeanNaren @titu1994 An option to use SDPA is also a good thing because a Triton-based version of FAv2 with custom attn_bias support (FlexAttention) is being added to PyTorch core: pytorch/pytorch#130250 (comment). Conformer attention can therefore benefit in the future from the speed-ups and proper compilation of SDPA-related work in core PyTorch (a rough sketch of that API is below).
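
For illustration only, and not used by this PR: in recent PyTorch releases (2.5+), FlexAttention expresses a custom additive bias as a score_mod function instead of a materialized attn_mask. Depending on the version and hardware it may need torch.compile and a GPU to run efficiently; rel_bias and the shapes below are hypothetical.

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 128, 64
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)
rel_bias = torch.randn(H, S, S)  # hypothetical precomputed relative-position bias

def add_rel_bias(score, b, h, q_idx, kv_idx):
    # Add the bias for this (head, query, key) position to the raw score.
    return score + rel_bias[h, q_idx, kv_idx]

out = flex_attention(q, k, v, score_mod=add_rel_bias)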

@vadimkantorov (Contributor)

@SeanNaren @titu1994 haha, and now that FAv3 is out, PyTorch will probably integrate it as well in the near term - for maximum brr on H100 :) so having NeMo's Conformer automatically benefit from this new work would be awesome.

@WoodieDudy (Contributor, Author)

cc @redoctopus @jbalam-nv @okuchaiev

github-actions bot commented Aug 2, 2024

This PR is stale because it has been open for 14 days with no activity. Remove stale label or comment or update or this will be closed in 7 days.

github-actions bot added the stale label Aug 2, 2024
@vadimkantorov (Contributor)

stale bump

@VahidooX (Collaborator) commented Aug 7, 2024

Thanks for the contribution!
@titu1994 please take a look at this PR, as it looks like an interesting addition to speed up Conformer models.

Just some notes:

  • Please use -10000 instead of -inf if possible, as -inf may cause NaNs with some data types.

  • Please add it as an option to the config files, somewhere like here, so it can be controlled from the configs. Name is "use_pytorch_sdpa"?

  • I suggest making it True by default if we can make sure it works in all cases. @titu1994 what do you think?

  • Please evaluate one of the pretrained models on NGC on LibriSpeech test-other to make sure that it produces exactly the same output and accuracy.

  • You need to set the dropout to zero manually when the model is not in training mode, as SDPA does not respect that and always applies its dropout (see the sketch after this list).

  • Have you used matrix_ac in your code/calculations?
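
A minimal sketch of the dropout point above, assuming the module exposes a dropout_rate attribute (a hypothetical name, not necessarily the NeMo one): torch.nn.functional.scaled_dot_product_attention applies dropout_p unconditionally, so it has to be gated on self.training inside the module's forward.

import torch.nn.functional as F

def forward_attention_sdpa(self, q, k, v, attn_bias):
    # SDPA always applies dropout_p, so disable it explicitly outside of training.
    dropout_p = self.dropout_rate if self.training else 0.0
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p)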

@WoodieDudy (Contributor, Author) commented Aug 12, 2024

@VahidooX Big thanks for the review!

  • I replaced -inf with -10000.
  • Added use_pytorch_sdpa to the config.
  • Fixed the dropout handling for torch SDPA.
  • I don't compute matrix_ac manually, but it is computed under the hood of torch.nn.functional.scaled_dot_product_attention (see the reference implementation example); a rough sketch of how the two terms line up is below:
    attn_weight = q_with_bias_u @ key.transpose(-2, -1) * scale_factor
    # so matrix_ac is equivalent to attn_weight
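
A rough sketch of that split, with hypothetical variable names rather than the exact NeMo code: SDPA computes the content term (matrix_ac) internally from q_with_bias_u and key, while the positional term (matrix_bd) is handed in as an additive attn_mask, pre-scaled because SDPA applies the 1/sqrt(d_k) factor itself.

import math
import torch.nn.functional as F

def rel_pos_sdpa(q_with_bias_u, key, value, matrix_bd, d_k, dropout_p=0.0):
    # SDPA already scales q @ k^T by 1/sqrt(d_k), so the additive bias must be
    # scaled the same way to reproduce (matrix_ac + matrix_bd) / sqrt(d_k).
    attn_bias = matrix_bd / math.sqrt(d_k)
    return F.scaled_dot_product_attention(
        q_with_bias_u, key, value, attn_mask=attn_bias, dropout_p=dropout_p
    )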

Do you have a script for computing metrics on LibriSpeech, and reference metrics?

@titu1994 (Collaborator) left a comment

The PR looks great! Sorry for the delay in review.

I would be OK with making torch SDPA the default (True) only if we can add a test that runs this function twice, once with the flag set to True and once with False, and checks that the outputs have an MSE difference of 1e-4 or lower.

Could you add your example code as a unit test somewhere as a check? (A possible sketch is below.)
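
One possible shape for that parity test, mirroring the benchmark above and offered only as a sketch (the test eventually added to the repo may differ):

import torch
from nemo.collections.asr.parts.submodules.multi_head_attention import RelPositionMultiHeadAttention


def test_sdpa_matches_original():
    torch.manual_seed(0)
    n_head, d_model, seq_len, batch = 4, 64, 50, 2
    x = torch.rand(batch, seq_len, d_model)
    pos_emb = torch.rand(batch, seq_len, d_model)

    ref = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=False)
    new = RelPositionMultiHeadAttention(n_head, d_model, 0.0, None, None, use_pytorch_sdpa=True)
    new.load_state_dict(ref.state_dict())  # identical weights for both paths
    ref.eval()
    new.eval()

    with torch.no_grad():
        out_ref = ref(x, x, x, None, pos_emb)
        out_new = new(x, x, x, None, pos_emb)

    # MSE threshold suggested in the review comment above.
    assert torch.nn.functional.mse_loss(out_ref, out_new) < 1e-4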

@WoodieDudy (Contributor, Author)

Thanks @titu1994
Okay, I'll try to add tests.

nithinraok previously approved these changes Aug 12, 2024

@nithinraok (Collaborator) left a comment

Thanks very much for the PR!

LGTM, one minor question.

I evaluated the PR on the HF-Leaderboard datasets and observed no difference in WER. For the LibriSpeech test set on an A6000, it improved RTFx by 5%.

Review comment on nemo/collections/asr/modules/conformer_encoder.py (outdated, resolved)
@WoodieDudy (Contributor, Author)

I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward.
I think the random_length data in this test is incorrect, because random_length should have a single dimension of size batch, not two dimensions. Am I right?

random_length = torch.tensor([2, 2], dtype=torch.int64)

@VahidooX (Collaborator)

I'm working on the tests and ran into a problem with tests/collections/asr/test_conformer_encoder.py:test_stochastic_depth_forward. I think the random_length data in this test is incorrect, because random_length should have a single dimension of size batch, not two dimensions. Am I right?

random_length = torch.tensor([2, 2], dtype=torch.int64)

Yes, that looks incorrect.

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
WoodieDudy and others added 2 commits October 8, 2024 18:15
Signed-off-by: WoodieDudy <goshagks@gmail.com>
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
@@ -32,12 +32,14 @@
Part of this code is adopted from https://github.com/espnet/espnet
"""

import contextlib

Check notice (Code scanning / CodeQL): Unused import. Import of 'contextlib' is not used.
Signed-off-by: WoodieDudy <goshagks@gmail.com>
@WoodieDudy (Contributor, Author)

Let's merge

@nithinraok (Collaborator)

Rerunning CI as a lot of ASR tests are failing

@WoodieDudy (Contributor, Author)

Are the failed tests unrelated to the changes in this PR?

@ko3n1g (Collaborator) commented Oct 10, 2024

The relevant test, CICD NeMo / Nemo_CICD_Test (pull_request), has passed, so once this has approvals it can be merged.

@ko3n1g enabled auto-merge (squash) October 10, 2024 11:16
@WoodieDudy (Contributor, Author)

Merge?

@nithinraok enabled auto-merge (squash) October 15, 2024 18:09
@ko3n1g disabled auto-merge October 15, 2024 20:17
@ko3n1g enabled auto-merge (squash) October 15, 2024 20:18
@ko3n1g merged commit a3b835a into NVIDIA:main Oct 16, 2024
447 of 467 checks passed
yashaswikarnati pushed a commit that referenced this pull request Oct 20, 2024
* use pytorch sdpa

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* sdpa work

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: titu1994 <titu1994@users.noreply.github.com>

* sdpa flag to false & sdpa_backend arg

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>

* change arg name

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>

* fix config args

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>

* add condition on version

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>

* update condition on version

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* remove condition on torch version

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>

* move code to init

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>

* refactor

Signed-off-by: WoodieDudy <goshagks@gmail.com>

* Apply isort and black reformatting

Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>

* refactor

Signed-off-by: WoodieDudy <goshagks@gmail.com>

---------

Signed-off-by: WoodieDudy <goshagks@gmail.com>
Signed-off-by: titu1994 <titu1994@users.noreply.github.com>
Signed-off-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: titu1994 <titu1994@users.noreply.github.com>
Co-authored-by: WoodieDudy <WoodieDudy@users.noreply.github.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
artbataev pushed a commit to artbataev/NeMo that referenced this pull request Oct 22, 2024
akoumpa pushed a commit that referenced this pull request Oct 24, 2024
hainan-xv pushed a commit to hainan-xv/NeMo that referenced this pull request Nov 5, 2024