
feat: add TE FusedAdam QuantizedTensor compatibility patch#1417

Merged
hemildesai merged 6 commits into main from worktree-hemil/te-fused-adam
Mar 2, 2026

Conversation

@hemildesai
Contributor

hemildesai commented Mar 1, 2026

Summary

  • Add nemo_automodel/shared/te_patches.py with a runtime monkey-patch for Transformer Engine's FusedAdam._initialize_state to handle QuantizedTensor parameters
  • Apply the patches early in TrainFinetuneRecipeForNextTokenPrediction.setup() (in train_ft.py), right after apply_cache_compatibility_patches(); a wiring sketch follows the example config below
  • Add unit tests covering idempotency, skip-when-TE-missing, patch-application, and skip-when-upstream-already-handles-it

Example config:

optimizer:
  _target_: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
  betas: [0.9, 0.95]
  eps: 1e-8
  lr: 1.0e-5
  weight_decay: 0
  master_weights: true
  store_param_remainders: true
  exp_avg_dtype: torch.bfloat16
  exp_avg_sq_dtype: torch.bfloat16
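
A minimal sketch of the wiring described in the summary, assuming a hypothetical apply_te_patches() entry point in nemo_automodel/shared/te_patches.py (only the module path is taken from this PR; the function name is an assumption):

# Hypothetical wiring sketch for train_ft.py.
from nemo_automodel.shared.te_patches import apply_te_patches  # assumed entry-point name


class TrainFinetuneRecipeForNextTokenPrediction:
    def setup(self):
        apply_cache_compatibility_patches()  # existing call in setup() (import not shown)
        apply_te_patches()  # no-op if TE is missing or the fix is already upstream
        # ... remainder of recipe setup ...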

Context

This is a workaround for NVIDIA/TransformerEngine#2535, specifically item 2: Out of Memory Error with Fused Optimizer and DTensor.

PyTorch introduced the JAX-like DTensor, and some workloads use TE's fused optimizer with this tensor type. The previous TE implementation used .empty_like, which works correctly for standard tensors but does not respect sharding for DTensor, resulting in full tensors being created on each device. The upstream fix switches to .empty with explicit shape specification.

The patch is idempotent and auto-skips if TE is not installed or if the upstream fix is already present.
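
A minimal sketch of the core behaviour of the patched state initialization, assuming a simplified signature (the real FusedAdam._initialize_state takes more arguments and also handles master weights, param remainders, and FP8 scales):

import torch


def _patched_initialize_state(optimizer, param, state_name, zero_buffer=False):
    # Sketch only: QuantizedTensor params are dequantized first, so shape,
    # dtype and device come from a plain high-precision tensor instead of
    # quantized metadata that may not describe a valid allocation target.
    values = param.dequantize() if hasattr(param, "dequantize") else param

    if zero_buffer:
        buf = torch.zeros_like(values, dtype=torch.float32)
    else:
        # Allocate with an explicit shape rather than empty_like, mirroring
        # the upstream fix in NVIDIA/TransformerEngine#2535.
        buf = torch.empty(values.shape, dtype=torch.float32, device=values.device)

    optimizer.state[param][state_name] = buf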

Test plan

  • uv run pytest tests/unit_tests/shared/test_te_patches.py -vs — 6/6 passing
  • Verify with a TE FusedAdam + QuantizedTensor workload on GPU

🤖 Generated with Claude Code

Add runtime monkey-patch for Transformer Engine's FusedAdam optimizer
to handle QuantizedTensor parameters whose .shape does not carry
correct metadata for tensor allocation. The patch dequantizes params
before creating optimizer state buffers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
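
One way such a monkey-patch stays idempotent (a sketch only; the sentinel name and structure here are assumptions, not the actual te_patches.py code):

import transformer_engine.pytorch.optimizers.fused_adam as te_fused_adam

_PATCH_SENTINEL = "_nemo_quantized_tensor_patch"  # hypothetical marker attribute


def apply_te_patches():
    # Install the patched _initialize_state at most once; repeated calls are no-ops.
    cls = te_fused_adam.FusedAdam
    if getattr(cls._initialize_state, _PATCH_SENTINEL, False):
        return
    setattr(_patched_initialize_state, _PATCH_SENTINEL, True)  # function as sketched above
    cls._initialize_state = _patched_initialize_state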
@copy-pr-bot

copy-pr-bot Bot commented Mar 1, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

hemildesai and others added 2 commits March 1, 2026 08:32
Update the skip-guard to check for all three specific lines from
NVIDIA/TransformerEngine#2535 (dequantize, zeros_like, empty_like)
instead of just the string "QuantizedTensor". Also update copyright
year to 2026 on new files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
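
A sketch of the kind of source inspection this guard describes (the marker strings are illustrative):

import inspect

# All three markers from the upstream fix in NVIDIA/TransformerEngine#2535
# must be present in the installed TE before the monkey-patch is skipped.
_UPSTREAM_FIX_MARKERS = ("dequantize", "zeros_like", "empty_like")


def _upstream_fix_present(fused_adam_cls):
    source = inspect.getsource(fused_adam_cls._initialize_state)
    return all(marker in source for marker in _UPSTREAM_FIX_MARKERS)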
Use the existing is_te_min_version utility to skip the monkey-patch
entirely when TE >= 2.12, where the fix from
NVIDIA/TransformerEngine#2535 is already included. Also update
copyright year to 2026 on new files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
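
Combined with the version check, the overall skip condition might look like this sketch (the import path for is_te_min_version is assumed):

from nemo_automodel.shared.utils import is_te_min_version  # import path assumed


def _should_skip_patch(fused_adam_cls):
    # TE >= 2.12 already ships the fix from NVIDIA/TransformerEngine#2535;
    # _upstream_fix_present is the source-inspection helper sketched above.
    return is_te_min_version("2.12") or _upstream_fix_present(fused_adam_cls)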
@hemildesai
Contributor Author

/ok to test e04fdff

hemildesai and others added 3 commits March 1, 2026 11:09
Exercise all code paths in the patched function body: regular params,
QuantizedTensor dequantization, zero_buffer, store_param_remainders,
scale creation for non-float32, and uint8/FP8 quantizer branch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
TE FusedAdam accepts torch.dtype kwargs (master_weight_dtype,
exp_avg_dtype, exp_avg_sq_dtype) but YAML configs produce strings.
Resolve them via dtype_from_str before instantiation, following the
same pattern used in _dist_setup.py for mp_policy dtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
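
A sketch of that resolution step; dtype_from_str stands in for the project's existing helper and is given a minimal fallback definition here:

import torch

# TE FusedAdam kwargs that YAML configs may provide as strings.
_TE_DTYPE_KWARGS = ("master_weight_dtype", "exp_avg_dtype", "exp_avg_sq_dtype")


def dtype_from_str(name):
    # Minimal stand-in for the project's helper: "torch.bfloat16" -> torch.bfloat16.
    return getattr(torch, name.removeprefix("torch."))


def _resolve_te_dtype_kwargs(optimizer_cfg):
    # Convert dtype strings in place before the optimizer is instantiated.
    for key in _TE_DTYPE_KWARGS:
        value = getattr(optimizer_cfg, key, None)
        if isinstance(value, str):
            setattr(optimizer_cfg, key, dtype_from_str(value))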
Cover all branches of the dtype resolution logic: resolving all three
TE FusedAdam dtype kwargs from strings, without torch prefix, preserving
existing torch.dtype objects, missing attrs, and partial attrs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>
@hemildesai
Contributor Author

/ok to test 81ff313

Collaborator

adil-a left a comment


LGTM, but I want to ask that we update the pinned TE version in the next release.

@thomasdhc the issue patched here is already fixed in TE top-of-tree, so we should update the pinned version accordingly as well. Thank you!

hemildesai merged commit 40fa3f8 into main on Mar 2, 2026
53 checks passed
hemildesai deleted the worktree-hemil/te-fused-adam branch on March 2, 2026 18:06
hemildesai added a commit that referenced this pull request Mar 4, 2026
* feat: add TE FusedAdam QuantizedTensor compatibility patch

Add runtime monkey-patch for Transformer Engine's FusedAdam optimizer
to handle QuantizedTensor parameters whose .shape does not carry
correct metadata for tensor allocation. The patch dequantizes params
before creating optimizer state buffers.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* fix: check full upstream fix lines in TE FusedAdam guard

Update the skip-guard to check for all three specific lines from
NVIDIA/TransformerEngine#2535 (dequantize, zeros_like, empty_like)
instead of just the string "QuantizedTensor". Also update copyright
year to 2026 on new files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* fix: skip FusedAdam patch when TE >= 2.12 using is_te_min_version

Use the existing is_te_min_version utility to skip the monkey-patch
entirely when TE >= 2.12, where the fix from
NVIDIA/TransformerEngine#2535 is already included. Also update
copyright year to 2026 on new files.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* test: add coverage for _patched_initialize_state behavior

Exercise all code paths in the patched function body: regular params,
QuantizedTensor dequantization, zero_buffer, store_param_remainders,
scale creation for non-float32, and uint8/FP8 quantizer branch.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* feat: resolve dtype strings in optimizer config for TE FusedAdam

TE FusedAdam accepts torch.dtype kwargs (master_weight_dtype,
exp_avg_dtype, exp_avg_sq_dtype) but YAML configs produce strings.
Resolve them via dtype_from_str before instantiation, following the
same pattern used in _dist_setup.py for mp_policy dtypes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

* test: add tests for optimizer dtype string resolution in build_optimizer

Cover all branches of the dtype resolution logic: resolving all three
TE FusedAdam dtype kwargs from strings, without torch prefix, preserving
existing torch.dtype objects, missing attrs, and partial attrs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Hemil Desai <hemild@nvidia.com>

---------

Signed-off-by: Hemil Desai <hemild@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
SwekeR-463 pushed a commit to SwekeR-463/Automodel that referenced this pull request Mar 11, 2026
…Mo#1417)

linnanwang pushed a commit that referenced this pull request Apr 24, 2026
