Skip to content

fix(tdt): preserve disable_cuda_graphs state across change_decoding_strategy() calls#15457

Closed
CodersAcademy006 wants to merge 5 commits into
NVIDIA-NeMo:mainfrom
CodersAcademy006:fix/tdt-cuda-graphs-decoding-strategy-15423
Closed

fix(tdt): preserve disable_cuda_graphs state across change_decoding_strategy() calls#15457
CodersAcademy006 wants to merge 5 commits into
NVIDIA-NeMo:mainfrom
CodersAcademy006:fix/tdt-cuda-graphs-decoding-strategy-15423

Conversation

@CodersAcademy006
Copy link
Copy Markdown

Problem

When disable_cuda_graphs() is called on the decoding computer, the disabled
state is lost whenever change_decoding_strategy() is invoked — because it
creates a brand new decoding_computer object with CUDA graphs re-enabled
by default (cuda_graphs_mode='full_graph').

This silently re-enables CUDA graphs even when the user explicitly disabled
them, re-introducing the torch.load() corruption bug reported in #15423.

Concretely, calling model.transcribe(timestamps=True) triggers
change_decoding_strategy() internally, which replaces the entire
decoding_computer object — resetting CUDA graph state in the process.

Root Cause

  • disable_cuda_graphs() state was only stored inside _decoding_computer
  • change_decoding_strategy() creates a completely new _decoding_computer
  • The new object always has CUDA graphs enabled by default
  • No state was preserved or restored across this replacement

Fix

Two minimal changes across two files:

  1. tdt_decoding.pyBeamBatchedTDTInfer.disable_cuda_graphs():
    Track the disabled state via a _cuda_graphs_disabled = True flag on the
    outer object so it survives _decoding_computer replacement.

  2. rnnt_models.pyEncDecRNNTModel.change_decoding_strategy():

    • Read _cuda_graphs_disabled from the current decoding_computer BEFORE
      replacing self.decoding
    • After the new self.decoding is created, restore the disabled state on
      the new decoding_computer
    • Both read and write are wrapped in try/except AttributeError to remain
      safe across all model types that call this method

Testing

The reproducer from #15423 no longer corrupts torch.load() after
transcribe(timestamps=True) with this fix applied:

model = nemo_asr.models.ASRModel.from_pretrained('nvidia/parakeet-tdt-0.6b-v3')
model = model.cuda().eval()
model.decoding.decoding.decoding_computer.disable_cuda_graphs()
model.transcribe(['/tmp/test.wav'], timestamps=True, return_hypotheses=True)

# torch.load now works correctly after this — previously failed with:
# TypeError: 'str' object is not callable
sd = OrderedDict((f'layer.{i}', torch.randn(32, 32, device='cuda')) for i in range(10))
buf = io.BytesIO()
torch.save(sd, buf)
buf.seek(0)
torch.load(buf, weights_only=False)  # ✅ Works

Fixes #15423

Notes

  • Zero breaking changes — all existing behaviour is preserved
  • The try/except AttributeError guards make this safe for all RNNT model
    variants, not just TDT
  • The deeper PyTorch-side corruption of torch.load() dispatch table entries
    during CUDA graph capture is a separate upstream issue and should be tracked
    in pytorch/pytorch

chtruong814 and others added 3 commits March 2, 2026 18:11
…Mo#15432)

* Add initial uv lock

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Add build docs

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix docs build

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix docstring formatting in magpietts.py

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix docs build

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Rename broken links files

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Add release docs jobs

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix docs

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Dry-run of docs publishing

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Revert "Dry-run of docs publishing"

This reverts commit 1c3aa19.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Revert "Revert "Dry-run of docs publishing""

This reverts commit 43c19ae.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix dry run

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix broken links

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Add retries for linke check

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix broken link

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix broken link

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Revert "Revert "Revert "Dry-run of docs publishing"""

This reverts commit a28f306.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Revert "Revert "Revert "Revert "Dry-run of docs publishing""""

This reverts commit 9353dcb.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Test nightly publish

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Increase docs broken link retry and timeout

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Revert "Test nightly publish"

This reverts commit 9c8e7a4.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Revert "Revert "Revert "Revert "Revert "Dry-run of docs publishing"""""

This reverts commit 850207e.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Fix docs footer

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Test nightly push

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

* Revert "Test nightly push"

This reverts commit 05894f2.

Signed-off-by: Charlie Truong <chtruong@nvidia.com>

---------

Signed-off-by: Charlie Truong <chtruong@nvidia.com>
Signed-off-by: Srijan Upadhyay <104912634+CodersAcademy006@users.noreply.github.com>
…trategy()

Signed-off-by: Srijan Upadhyay <104912634+CodersAcademy006@users.noreply.github.com>
Signed-off-by: Srijan Upadhyay <104912634+CodersAcademy006@users.noreply.github.com>
@CodersAcademy006 CodersAcademy006 force-pushed the fix/tdt-cuda-graphs-decoding-strategy-15423 branch from f06c118 to 78474bd Compare March 2, 2026 18:11
CodersAcademy006 and others added 2 commits March 2, 2026 18:11
Signed-off-by: CodersAcademy006 <CodersAcademy006@users.noreply.github.com>
@CodersAcademy006
Copy link
Copy Markdown
Author

@rsclafani Please look into this as well, Thank You. And please provide me what else i can change in this?

@artbataev
Copy link
Copy Markdown
Collaborator

Thanks for the contribution.
Unfortunately, I need to close this PR: the correct way to disable CUDA graphs is to change the values in the config.

from omegaconf import open_dict

# greedy decoding
with open_dict(model.cfg.decoding.greedy):
    model.cfg.decoding.greedy.use_cuda_graph_decoder = False

# beam decoding
with open_dict(model.cfg.decoding.beam):
    model.cfg.decoding.beam.allow_cuda_graphs = False

model.change_decoding_strategy(model.cfg.decoding)

We should not introduce such stateful changes - change_decoding_strategy expected to fully reinstantiate the decoder using the provided decoding parameters.

@artbataev artbataev closed this Mar 2, 2026
@CodersAcademy006
Copy link
Copy Markdown
Author

@artbataev I sincerely apologize if I moved forward without formal confirmation. Please guide me on the necessary steps to rectify this and ensure the issue is completely resolved.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

TDT decoder CUDA graph capture corrupts torch.load() — process-wide, irreversible after first model.transcribe()

3 participants