
Fix #8891 by supporting GPU-side batched CTC Greedy Decoding #9100

Merged
galv merged 13 commits into NVIDIA:main from dgalvez/batched-ctc-inference on May 9, 2024

Conversation

galv (Collaborator) commented May 2, 2024

What does this PR do?

I add a new flag "batched_inference" to the CTC greedy decoder.

Collection: ASR

Changelog

  • Support batched inference of greedy CTC decoding.
  • Add support for batched inference for label inputs as well.

Usage

Here is an example of turning this on with speech_to_text_eval.py:

batch_size=32
amp=true

python examples/asr/speech_to_text_eval.py pretrained_name=nvidia/parakeet-ctc-1.1b \
    dataset_manifest=/home/dgalvez/scratch/data/test_other_sorted_downward.json \
    batch_size=$batch_size output_filename=test_clean_decoded.jsonl \
    amp=$amp amp_dtype=bfloat16 use_cer=false num_workers=1 \
    return_hypotheses=false ctc_decoding.greedy.batched_inference=true
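
The same option can also be enabled programmatically. Below is a minimal sketch, assuming the flag is exposed under the model's decoding.greedy config (mirroring the CLI override above); the audio path is a placeholder:

import copy
import nemo.collections.asr as nemo_asr

# Load the pretrained CTC model used in the example above.
asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("nvidia/parakeet-ctc-1.1b")

# Copy the current decoding config and turn on the new flag from this PR.
decoding_cfg = copy.deepcopy(asr_model.cfg.decoding)
decoding_cfg.greedy.batched_inference = True
asr_model.change_decoding_strategy(decoding_cfg)

# "sample.wav" is a placeholder audio file.
transcriptions = asr_model.transcribe(["sample.wav"], batch_size=32)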

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

Additional Information

Here are the speedups in terms of throughput. First, I applied this diff to transcribe_speech.py in order to run multiple times:

modified   examples/asr/transcribe_speech.py
@@ -407,7 +407,33 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
                 override_cfg.augmentor = augmentor
                 override_cfg.text_field = cfg.gt_text_attr_name
                 override_cfg.lang_field = cfg.gt_lang_attr_name
-                transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
+
+                transcriptions = None
+
+                torch.cuda.cudart().cudaProfilerStart()
+                torch.cuda.nvtx.range_push("GALVEZ_START")
+                for i in range(5):
+                    if i == 1:
+                        # import nvtx
+                        # pr = nvtx.Profile()
+                        # pr.enable()  # begin annotating function calls
+                        # ctx = torch.autograd.profiler.emit_nvtx()
+                        # ctx.__enter__()
+                        # torch.cuda.cudart().cudaProfilerStart()
+                        pass
+                    import time
+                    del transcriptions
+                    start_time = time.time()
+                    transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
+                    end_time = time.time()
+                    print("RTFx=", 5.1 * 60 * 60 / (end_time - start_time))
+                    if i == 1:
+                        # pr.disable()
+                        # ctx.__exit__(None, None, None)
+                        # torch.cuda.cudart().cudaProfilerStop()
+                        pass
+                torch.cuda.nvtx.range_pop() # "GALVEZ_START"
+                torch.cuda.cudart().cudaProfilerStop()

     if cfg.dataset_manifest is not None:
         logging.info(f"Finished transcribing from manifest file: {cfg.dataset_manifest}")

Then I ran with and without this option turned on:

Without batched inference:

set -euo pipefail

batch_size=32
amp=true

echo "GALVEZ: Conformer CTC"
python examples/asr/speech_to_text_eval.py pretrained_name=nvidia/parakeet-ctc-1.1b \
    dataset_manifest=/home/dgalvez/scratch/data/test_other_sorted_downward.json \
    batch_size=$batch_size output_filename=test_clean_decoded.jsonl \
    amp=$amp amp_dtype=bfloat16 use_cer=false num_workers=1 \
    return_hypotheses=false ctc_decoding.greedy.batched_inference=false
RTFx= 1128.651983273993
RTFx= 1248.7464311388146
RTFx= 1327.39890551856
RTFx= 1344.5921792715505
RTFx= 1346.41703283495
[NeMo I 2024-05-02 15:07:14 speech_to_text_eval:210] Dataset WER/CER 3.76%/1.41%

With batched inference:

set -euo pipefail

batch_size=32
amp=true

echo "GALVEZ: Conformer CTC"
python examples/asr/speech_to_text_eval.py pretrained_name=nvidia/parakeet-ctc-1.1b \
    dataset_manifest=/home/dgalvez/scratch/data/test_other_sorted_downward.json \
    batch_size=$batch_size output_filename=test_clean_decoded.jsonl \
    amp=$amp amp_dtype=bfloat16 use_cer=false num_workers=1 \
    return_hypotheses=false ctc_decoding.greedy.batched_inference=true
RTFx= 1157.725053779039
RTFx= 1265.9191489003545
RTFx= 1361.67639098725
RTFx= 1354.9981484942944
RTFx= 1350.2222749634843
[NeMo I 2024-05-02 15:25:17 speech_to_text_eval:210] Dataset WER/CER 3.76%/1.41%

You can see that RTFx throughput improvements are only modest. However, in terms of inference latency, there is a large speed up at the P99 latency levels, because the sometimes long calls to max() on the CPU no longer happen. We don't have a great way to measure this right now with NeMo's current tooling, so you will just have to believe me for now. See #8891 for details about this problem.
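
For background, the underlying pattern can be illustrated with a small standalone PyTorch sketch (this is illustrative, not the PR's actual implementation): calling .max()/.item() per element forces a device-to-host synchronization every time, whereas a single batched argmax on the GPU followed by one copy back to the host avoids those stalls.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Dummy (batch, time, vocab) CTC log-probabilities standing in for the model output.
log_probs = torch.randn(32, 500, 1025, device=device)

def greedy_per_frame(log_probs):
    """Slow pattern: .item() forces a device-to-host sync for every single frame."""
    results = []
    for b in range(log_probs.shape[0]):
        labels = [log_probs[b, t].max(dim=-1).indices.item() for t in range(log_probs.shape[1])]
        results.append(labels)
    return results

def greedy_batched(log_probs):
    """Batched pattern: one argmax over the whole batch on the GPU, then a single copy to host."""
    return log_probs.argmax(dim=-1).cpu()

labels = greedy_batched(log_probs)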

@galv galv requested a review from titu1994 May 2, 2024 22:45
@github-actions github-actions bot added the ASR label May 2, 2024
@galv galv requested review from pzelasko and tbartley94 May 2, 2024 22:45
galv (Collaborator, Author) commented May 2, 2024

To verify that latency does indeed fall, I ran the above prefixed by nsys profile --env-var=NSYS_NVTX_PROFILER_REGISTER_ONLY=0 -t nvtx -c nvtx -p GALVEZ_START with batched_inference=false and batched_inference=true. I added an nvtx range just for the decoder.
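
For reference, an ad hoc nvtx annotation of that kind looks roughly like the sketch below (the decode function and tensor shapes are illustrative stand-ins, not the PR's code):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
log_probs = torch.randn(32, 500, 1025, device=device)

def decode_batch(log_probs):
    # Stand-in for the CTC greedy decode step being profiled.
    return log_probs.argmax(dim=-1).cpu()

torch.cuda.nvtx.range_push("decoder")  # shows up as the "decoder" range in the nsys report
labels = decode_batch(log_probs)
torch.cuda.nvtx.range_pop()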

batched_inference=true case:

Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Style	Range
93.9%	71.801 s	1	71.801 s	71.801 s	71.801 s	71.801 s	0 ns	PushPop	GALVEZ_START
6.1%	4.683 s	460	10.180 ms	1.173 ms	796.420 μs	117.883 ms	19.312 ms	PushPop	decoder

batched_inference=false case:

Time	Total Time	Instances	Avg	Med	Min	Max	StdDev	Style	Range
86.7%	79.410 s	1	79.410 s	79.410 s	79.410 s	79.410 s	0 ns	PushPop	GALVEZ_START
13.3%	12.217 s	460	26.558 ms	3.800 ms	2.125 ms	171.248 ms	38.294 ms	PushPop	decoder

You can see that the decoder takes less than half the time it used to (4.68 s vs. 12.22 s in total), and that its maximum runtime drops from 171 ms to 118 ms.

@galv galv added the Run CICD label May 2, 2024
titu1994 (Collaborator) left a comment

The new algorithm part looks great, but it shouldn't be fused into CTC greedy decoding; it should be treated as a separate strategy, similar to how RNNT does it.

@nithinraok up to you whether you want to fuse the two together like this with a flag, or use the RNNT strategy pattern here in CTC.

galv and others added 8 commits May 6, 2024 14:14
Fixes NVIDIA#8891

Basically, doing max() on the CPU one element at a time is very slow. It is
better to do it all on the GPU before copying the results over to the CPU.

@galv galv force-pushed the dgalvez/batched-ctc-inference branch from e193293 to 3ebd1f1 Compare May 6, 2024 21:14
@galv galv added Run CICD and removed Run CICD labels May 6, 2024
@galv galv requested a review from titu1994 May 6, 2024 23:30
titu1994 previously approved these changes May 7, 2024

titu1994 (Collaborator) left a comment

Looks far cleaner than before, minor comments but overall good enough to merge.

  1. Can we make this the default algorithm? If users want inference features that don't work in this mode, they can switch to the slower code path. NeMo 2.0 is the time to switch defaults if we want to.

  2. Can we name this greedy_batch, similar to RNNT?

@@ -213,7 +213,7 @@ def __init__(self, decoding_cfg, blank_id: int):
         self.batch_dim_index = self.cfg.get('batch_dim_index', 0)
         self.word_seperator = self.cfg.get('word_seperator', ' ')

-        possible_strategies = ['greedy', 'beam', 'pyctcdecode', 'flashlight']
+        possible_strategies = ['greedy', 'greedy_vectorized', 'beam', 'pyctcdecode', 'flashlight']
titu1994 (Collaborator) commented on this diff:
Hmm, can't we call it greedy_batched, similar to RNNT? I understand both are technically batched, and "vectorized" is more appropriate, but conformity is nice for both us and users.

galv (Collaborator, Author) replied:

Good observation. I changed it to "greedy_batched", and, generally speaking, changed "vectorized" to "batched" based on your observation about conformity.
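
For reference, after the rename the new path is selected through the strategy field rather than a per-greedy flag; a minimal sketch of what that might look like, reusing the model from the usage example above (illustrative, not the merged code):

import copy
import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("nvidia/parakeet-ctc-1.1b")

decoding_cfg = copy.deepcopy(asr_model.cfg.decoding)
decoding_cfg.strategy = "greedy_batched"  # renamed from "greedy_vectorized" in this review round
asr_model.change_decoding_strategy(decoding_cfg)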

@galv galv force-pushed the dgalvez/batched-ctc-inference branch from a4e373e to edc2b13 Compare May 7, 2024 00:50
galv (Collaborator, Author) commented May 7, 2024

Will need to get back to you about what it takes to make this the default algo. But the main concern is that many existing ".nemo" files will have the "strategy" field of the relevant config set to "greedy". It's not really clear to me how we can move those users of pretrained models over to the "greedy_batched" strategy. For what it's worth, my tests check that the interfaces are effectively the same, so I could just use GreedyBatchedCTCInfer whenever the "greedy" strategy is specified, unless that distresses you.

@galv galv added Run CICD and removed Run CICD labels May 7, 2024
pablo-garay (Collaborator) commented:

I see CICD pipeline passed: https://github.com/NVIDIA/NeMo/actions/runs/8976148083
Let me know if you need anything else or are ready to merge.

titu1994 (Collaborator) commented May 7, 2024

Hmm, I see your point. Let's do this:

  1. Set the default in the config to greedy_batch.
  2. Add a one-time log message to greedy informing users that the greedy_batch strategy is faster.
  3. All newer models will then automatically use the greedy_batch path.

I'm not averse to replacing "greedy" with greedy_batch, but then we need a different flag name to allow users to fall back to the older code path. Maybe we can call it "greedy_legacy" then?

galv and others added 2 commits May 8, 2024 11:21
Warn when using greedy rather than greedy_batched strategy.

galv (Collaborator, Author) commented May 8, 2024

@titu1994 done. To clarify, I went with the warn-once path rather than making "greedy" point to the new implementation.
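
For illustration, the warn-once behavior can be sketched roughly like this (the function name, guard variable, and message are illustrative, not the merged code):

import logging

_warned_greedy_is_slow = False  # module-level guard so the warning only fires once per process

def maybe_warn_greedy_is_slow():
    global _warned_greedy_is_slow
    if not _warned_greedy_is_slow:
        logging.warning(
            "CTC decoding strategy 'greedy' is slower than 'greedy_batched'. "
            "Set strategy='greedy_batched' for faster inference with the same results."
        )
        _warned_greedy_is_slow = True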

@galv galv added Run CICD and removed Run CICD labels May 8, 2024
@galv galv added Run CICD and removed Run CICD labels May 8, 2024
titu1994 (Collaborator) left a comment

Looks great! Thanks for the refactor!

@galv galv merged commit 363f9ec into NVIDIA:main May 9, 2024
129 checks passed