Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dgalvez/cuda graphs greedy rnnt inference squash #8191

Merged

Conversation

galv
Copy link
Collaborator

@galv galv commented Jan 17, 2024

What does this PR do ?

Speeds up RNN-T greedy decoding greatly by eliminating the 90% of the time that the GPU is idle, waiting on the CPU, via cuda graphs with conditional nodes.

Here are some results for transcribing librispeech test other, a 5.4 hour dataset, on an A100, with bfloat16, at batch size 16:

image

You can see that we get a 3.125x speed up with a 600 million parameter model, and a 2.65x speedup with a 1.1 billion parameter model.

This benchmark comes from running the following. Note that I exclude the time required to create the cuda graph from the timing measurement. This fits an inference use case where that is a one time task.

for model_name in stt_en_fastconformer_transducer_xlarge nvidia/parakeet-rnnt-1.1b; do
for use_cuda_graph_decoder in false true; do
for amp in true; do
for batch_size in 16; do
echo "GALVEZ: ${model_name} fast:${use_cuda_graph_decoder}"
$(which python) examples/asr/speech_to_text_eval.py  pretrained_name=$model_name dataset_manifest=/home/dgalvez/scratch/data/test_other.json  batch_size=$batch_size  output_filename=test_other_decoded.js\
onl  amp=$amp  amp_dtype=bfloat16  rnnt_decoding.greedy.use_cuda_graph_decoder=$use_cuda_graph_decoder  use_cer=false num_workers=1
done
done
done
done

This is a squashing of #7976 . I wanted to squash that, but unfortunately I reference a few commits in various bugs I filed, and didn't want the links to break.

Collection: ASR. Adds some utils for cuda-python to common.

Changelog

  • Adds a new class RNNTGreedyDecodeCudaGraph, which uses cuda graphs with conditional nodes to remove the CPU overhead.

Usage

  • You can potentially add a usage example below
python examples/asr/speech_to_text_eval.py \
  pretrained_name=stt_en_fastconformer_transducer_xlarge \
  dataset_manifest=/home/dgalvez/scratch/data/test_other.json \
  batch_size=16 \
  output_filename=test_other_decoded.jsonl \
  amp=true  amp_dtype=bfloat16 \
  rnnt_decoding.greedy.loop_labels=false \
  rnnt_decoding.greedy.use_cuda_graph_decoder=true \
  use_cer=false num_workers=1

Jenkins CI

To run Jenkins, a NeMo User with write access must comment jenkins on the PR.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
Contributor guidelines contains specific people who can review PRs to various areas.

Additional Information

@github-actions github-actions bot added core Changes to NeMo Core ASR labels Jan 17, 2024
@nithinraok
Copy link
Collaborator

jenkins

@nithinraok
Copy link
Collaborator

@titu1994 for RNNT review.
@galv instead of go_very_fast can you change to enable_cuda_graphs=true something like that?

@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch from 3a79df5 to d5d0a15 Compare January 26, 2024 22:24
@galv
Copy link
Collaborator Author

galv commented Jan 30, 2024

This PR is ready for review. @artbataev @titu1994 would you be willing?

@titu1994
Copy link
Collaborator

I can review this tomorrow @artbataev could you review it too ?

@artbataev
Copy link
Collaborator

Cool, I will review the PR today or tomorrow.

@galv Please fix DCO (anyway, you will need to fix it for merging).

Copy link

@github-advanced-security github-advanced-security bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CodeQL found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.

Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall it looks decent but needs some minor fixes. Needs removal of nvtx, docstring/comments for complex functions, and dynamic import guard instead of directly importing cuda python without check.

Apart from that, just wanted to note that is incredibly complex work, serious kudos to developing this @galv

examples/asr/transcribe_speech.py Outdated Show resolved Hide resolved
nemo/collections/asr/models/rnnt_models.py Outdated Show resolved Hide resolved
nemo/collections/asr/modules/rnnt.py Outdated Show resolved Hide resolved
torch.cuda.nvtx.range_pop()

torch.cuda.nvtx.range_push("Convert to Hypotheses")
hypotheses = [
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this the section you mentioned you wanted to speedup with numba CPU ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ended up doing it with pytorch, which was definitely fast enough. One tricky bit about numba was the lack of bfloat16 support. And we can end up with scores in bfloat16 if the model is running in bfloat16.



def ASSERT_DRV(err):
if isinstance(err, cuda.CUresult):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Inside any function that uses cuda python, check if HAVE_CUDA_PYTHON, and if not call check_cuda_python_cuda_graphs_conditional_nodes_supported() to give users a meaningful error.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not that simple, though. It's not just about whether we can import cuda. We also need to check that the appropriate version of cuda-python itself and the cuda driver are installed. https://github.com/NVIDIA/NeMo/pull/8191/files#diff-acab1a9f3d702862ddbe5720bfa6c7fd0a57f7c3dc0b59eb9878ed5cd1e3513aR28-R45

Maybe what you want in this case is something like https://github.com/NVIDIA/NeMo/blob/main/nemo/core/utils/k2_guard.py. We will simply expect those developing with cuda-python to do from cuda_python_guard.cuda import ... instead from cuda import ...

@titu1994
Copy link
Collaborator

Jenkins

conf["decoding"]["greedy"]["max_symbols"] = 5
conf["decoding"]["greedy"]["loop_labels"] = False

with tempfile.NamedTemporaryFile() as fp:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change_decoding_strategy should be enough. Saving/restoring models takes a lot of time, it's better to avoid serialization if possible in unit tests.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the tip. It sped up test execution almost 3 times.

Copy link
Collaborator

@artbataev artbataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good to me. Agree with comments from @titu1994
@galv Please, clean up prints, unused imports, and make tests using data on our CI system (not local).

@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch from 9328d1f to e5827b1 Compare February 9, 2024 00:49
@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch 2 times, most recently from 7b3aaf7 to e2d4174 Compare February 14, 2024 23:39
This uses CUDA 12.3's conditional node support.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch from e2d4174 to 5b31417 Compare February 14, 2024 23:48
titu1994
titu1994 previously approved these changes Feb 15, 2024
Copy link
Collaborator

@titu1994 titu1994 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR looks great! Thanks! @artbataev please go through it once more and merge when ready

@artbataev
Copy link
Collaborator

jenkins

Copy link
Collaborator

@artbataev artbataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@galv, the issue still exists.

  • See details about decoding with cuda=1 below
  • Also, please change the PR name (e.g., "Cuda graphs rnnt inference") and description to match the current code (e.g., go_very_fast is no more valid).
  • If possible, also please move the with_conditional_node code to a separate file to make it reusable (but this is not a blocker)

Issue with cuda != 0

I still see an error, but the error is different now. Did you test it? Do you need a multi-gpu machine?

I do not want to block the PR, but I think that it is a bug and should be fixed if possible.

Cuda 1 + cuda graphs: Fail

python examples/asr/speech_to_text_eval.py  pretrained_name=stt_en_fastconformer_transducer_large  dataset_manifest=test_other.json  batch_size=16  output_filename=test_other_decoded.jsonl  amp=false amp_dtype=bfloat16  rnnt_decoding.greedy.use_cuda_graph_decoder=true rnnt_decoding.greedy.loop_labels=false cuda=1

...
in _reinitialize
logp = self.caller._joint_step(self.f, g, log_normalize=None)[:, 0, 0, :]
...
RuntimeError: CUDA error: operation not permitted when stream is capturing
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
During handling of the above exception, another exception occurred:
...
RuntimeError: Capture must end on the same stream it began on.

Cuda 0 + cuda graphs: OK

python examples/asr/speech_to_text_eval.py  pretrained_name=stt_en_fastconformer_transducer_large  dataset_manifest=test_other.json  batch_size=16  output_filename=test_other_decoded.jsonl  amp=false amp_dtype=bfloat16  rnnt_decoding.greedy.use_cuda_graph_decoder=true rnnt_decoding.greedy.loop_labels=false cuda=0

Cuda 1 + no cuda graphs (both loop frames/loop labels): OK

@artbataev
Copy link
Collaborator

artbataev commented Feb 21, 2024

@galv as I see, the issue can be fixed when passing appropriate device to cuda streams initializers and getters:

def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
    ...
    body_stream = torch.cuda.Stream(device=device)
    previous_stream = torch.cuda.current_stream(device=device)

def _reinitialize(...):
    ...
    with torch.cuda.stream(torch.cuda.Stream(device=device)), torch.inference_mode(), torch.cuda.graph(self.graph):
    ... 
    capture_status, _, graph, _, _ = cu_call(
                cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream)
    )

@galv
Copy link
Collaborator Author

galv commented Feb 22, 2024

@artbataev thank you for the initial suggestion. It works when the decoder has not been run yet. However, it doesn't work if the decoder has already been run. You can see my failing test here: 36b3273

Clearly something obscure is happening here. The commit message provides more details. I've spent a few hours trying to debug this so I need to stop for the day.

@artbataev
Copy link
Collaborator

artbataev commented Feb 22, 2024

@Galvi tried some changes, and it seems I can get it to work.
But I'm wondering why these changes are required and why everything works when creating a graph for the first time on any device. There may be some bugs in PyTorch.

  1. Pass stream explicitly to torch.cuda.graph. After this change, I'm able to run the test, but the final comparison of the results fails (the results seem to be incorrect for the second run)
  2. Pass the device explicitly to all calls torch.cuda.current_stream. After this change, the test is passed.
# Always create a new stream, because the per-thread default stream disallows stream capture to a graph.
stream_for_graph = torch.cuda.Stream(self.device)
with torch.cuda.stream(stream_for_graph), torch.inference_mode(), torch.cuda.graph(self.graph, stream=stream_for_graph):
    ... # capture graph
   # pass device explicitly
   capture_status, _, graph, _, _ = cu_call(
       cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=self.device).cuda_stream)
   )
   ...

@contextlib.contextmanager
def with_conditional_node(while_loop_kernel, while_loop_args, while_loop_conditional_handle, device):
    ...
    # pass device explicitly here and in other calls
    capture_status, _, graph, _, _ = cu_call(cudart.cudaStreamGetCaptureInfo(torch.cuda.current_stream(device=device).cuda_stream))
    ... 

You can see the full commit here: artbataev@77fc36e

Thank you, Vladimir.

Signed-off-by: Daniel Galvez <dgalvez@computelab-frontend-3.nvidia.com>
@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch from 0c509e7 to 7bbbe3d Compare February 22, 2024 19:41
@galv
Copy link
Collaborator Author

galv commented Feb 22, 2024

jenkins

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch from 7117260 to 118c01a Compare February 22, 2024 20:24
It will crash in cuda-python 12.4.0.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch from dc3d1ff to fb2bd7a Compare February 22, 2024 20:57
@galv
Copy link
Collaborator Author

galv commented Feb 22, 2024

jenkins

@galv
Copy link
Collaborator Author

galv commented Feb 23, 2024

jenkins

Previous failure seems to be a spurious failure caused by git clone failing.

@artbataev I incorporated your change after verifying it on a multi-GPU machine. Thank you again. I made one more commit fb2bd7a as well which makes this work with cuda-python version 12.4.0 and greater. It turns out that the bug fix in that version makes the phGraph_out variable not writable. So I must not use my workaround when cuda-python > 12.3.0. Things are well tested at this point.

artbataev
artbataev previously approved these changes Feb 23, 2024
Copy link
Collaborator

@artbataev artbataev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the really speed-of-light decoding with Cuda graphs!

@artbataev
Copy link
Collaborator

@galv I manually restarted Jenkins, but it is still waiting for an executor

@artbataev
Copy link
Collaborator

@galv please fix the test failing on Jenkins (the guard is needed)

FAILED tests/collections/asr/decoding/test_cuda_graph_rnnt_greedy_decoding.py::test_change_devices - ImportError: Found cuda-python 12.3.0rc4+8.gcb4e395, but at least version 12.3.0 is needed.

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
@galv galv force-pushed the dgalvez/cuda-graphs-greedy-rnnt-inference-squash branch from 1738de1 to e099042 Compare February 23, 2024 18:30
@galv
Copy link
Collaborator Author

galv commented Feb 23, 2024

jenkins

Sorry for missing the guard in that test. Hopefully things go through now.

@artbataev
Copy link
Collaborator

jenkins

@galv galv merged commit 96878f1 into NVIDIA:main Feb 26, 2024
8 checks passed
yaoyu-33 pushed a commit that referenced this pull request Feb 26, 2024
* Speed up RNN-T greedy decoding with cuda graphs

This uses CUDA 12.3's conditional node support.

Initialize cuda tensors lazily on first call of __call__ instead of __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

#8191 (comment)

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
zpx01 pushed a commit to zpx01/NeMo that referenced this pull request Mar 8, 2024
* Speed up RNN-T greedy decoding with cuda graphs

This uses CUDA 12.3's conditional node support.

Initialize cuda tensors lazily on first call of __call__ instead of __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
NVIDIA#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

NVIDIA#8191 (comment)

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Signed-off-by: Zeeshan Patel <zeeshanp@berkeley.edu>
JRD971000 pushed a commit that referenced this pull request Mar 15, 2024
* Speed up RNN-T greedy decoding with cuda graphs

This uses CUDA 12.3's conditional node support.

Initialize cuda tensors lazily on first call of __call__ instead of __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

#8191 (comment)

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Signed-off-by: ataghibakhsh <ataghibakhsh@nvidia.com>
pablo-garay pushed a commit that referenced this pull request Mar 19, 2024
* Speed up RNN-T greedy decoding with cuda graphs

This uses CUDA 12.3's conditional node support.

Initialize cuda tensors lazily on first call of __call__ instead of __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

#8191 (comment)

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Signed-off-by: Pablo Garay <pagaray@nvidia.com>
rohitrango pushed a commit to rohitrango/NeMo that referenced this pull request Jun 25, 2024
* Speed up RNN-T greedy decoding with cuda graphs

This uses CUDA 12.3's conditional node support.

Initialize cuda tensors lazily on first call of __call__ instead of __init__.

We don't know what device is going to be used at construction time,
and we can't rely on torch.nn.Module.to() to work here. See here:
NVIDIA#8436

This fixes an error "Expected all tensors to be on the same device,
but found at least two devices" that happens when you call to() on your
torch.nn.Module after constructing it.

NVIDIA#8191 (comment)

Signed-off-by: Daniel Galvez <dgalvez@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ASR core Changes to NeMo Core
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants