fix: optimized TP plan lookup in NeMo-RL by qualname #1547
Merged
hemildesai merged 2 commits into main on Mar 15, 2026
Conversation
fix: key PARALLELIZE_FUNCTIONS by qualname to fix optimized TP plan lookup in NeMo-RL
Symptom: ~16 GB of extra peak GPU memory when using NeMo-RL with torch 2.10 /
transformers v5 for Llama-style models with FSDP2+TP.
NeMo-RL auto-sets force_hf=True for LlamaForCausalLM because its adapter
does not implement convert_single_tensor_to_hf (required for weight syncing).
force_hf=True triggers _get_mixin_wrapped_class() in model_init.py, which
creates a new class via type(...) wrapping the original with HFCheckpointingMixin.
The wrapper preserves __module__ and __qualname__ from the original class but
is a different Python object.
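For illustration, a minimal sketch of that wrapping (stand-in class names; the real mixin and model classes come from Automodel / transformers):

```python
# Sketch of what _get_mixin_wrapped_class() effectively does: build a new
# class via type(...) while preserving the original's module/qualname.
class HFCheckpointingMixin:          # stand-in for the real mixin
    pass

class CustomLlamaForCausalLM:        # stand-in for the original model class
    pass

Wrapped = type(
    CustomLlamaForCausalLM.__name__,
    (HFCheckpointingMixin, CustomLlamaForCausalLM),
    {
        "__module__": CustomLlamaForCausalLM.__module__,
        "__qualname__": CustomLlamaForCausalLM.__qualname__,
    },
)

assert Wrapped is not CustomLlamaForCausalLM                      # new class object
assert Wrapped.__module__ == CustomLlamaForCausalLM.__module__    # metadata preserved
assert Wrapped.__qualname__ == CustomLlamaForCausalLM.__qualname__
```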
PARALLELIZE_FUNCTIONS was keyed by class objects, so the lookup:
elif model_cls in PARALLELIZE_FUNCTIONS:
compares keys by hash/equality, which for classes defaults to object identity (is).
This silently fails for the mixin-wrapped class and falls through to the default
plan, which uses ColwiseParallel(output_layouts=Replicate()) for lm_head,
triggering an all-gather that adds ~8 GB per forward and backward pass.
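For context, a sketch of the two lm_head styles in torch's tensor-parallel API: the default entry mirrors the plan described above, while the optimized entry is an assumption about how the tuned plan keeps the logits sharded:

```python
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import ColwiseParallel

# Default plan: lm_head is column-parallel but the logits are all-gathered,
# materializing the full [batch, seq, vocab] tensor on every TP rank.
default_lm_head = ColwiseParallel(output_layouts=Replicate())

# Sketch of the optimized behavior: keep the logits as a DTensor sharded on
# the vocab dimension (Shard(-1) is Shard(dim=2) for a 3D output).
optimized_lm_head = ColwiseParallel(output_layouts=Shard(-1), use_local_output=False)
```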
Note: this does not affect automodel standalone or older NeMo-RL (fp8/nemo-rl-tot)
because those do not auto-set force_hf=True, so type(model) is the exact
CustomLlamaForCausalLM object already stored in the dict.
- optimized_tp_plans.py: add _get_class_qualname() helper returning
f"{cls.__module__}.{cls.__qualname__}" and key PARALLELIZE_FUNCTIONS
by that string instead of the class object.
- parallelizer.py: replace identity lookup with
PARALLELIZE_FUNCTIONS.get(_get_class_qualname(model_cls)) using a
walrus operator to keep the if/elif/else chain flat.
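A runnable sketch of that lookup, with a stub entry standing in for the real registry (the key string and the plan-function call signature are assumptions):

```python
from typing import Callable

def _get_class_qualname(cls: type) -> str:
    return f"{cls.__module__}.{cls.__qualname__}"

# Keyed by "module.qualname" strings instead of class objects
# (stub entry, illustrative only).
PARALLELIZE_FUNCTIONS: dict[str, Callable] = {
    "transformers.models.llama.modeling_llama.LlamaForCausalLM": lambda model: model,
}

def parallelize(model):
    model_cls = type(model)
    # Walrus keeps the if/elif/else chain flat, as in parallelizer.py.
    if (fn := PARALLELIZE_FUNCTIONS.get(_get_class_qualname(model_cls))) is not None:
        return fn(model)   # optimized, model-specific TP plan
    return model           # fall back to the default TP plan
```

Because the mixin-wrapped class preserves __module__/__qualname__, both the original class and its wrapper resolve to the same string key and hit the optimized branch.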
Llama 3.1 8B SFT, FSDP2+TP=4, 1N8G:
[LM_HEAD_FWD] peak: 33.579 GB → 29.665 GB (-3.9 GB)
[MEM_TRAIN][after_fwd_bwd] peak: 51.93 GB → 35.66 GB (-16.3 GB)
lm_head output: DTensor=False (gathered) → DTensor=True placements=(Shard(dim=2),)
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Author (Contributor): /ok to test 8830bd5
Author (Contributor): /ok to test 7962235
hemildesai approved these changes on Mar 15, 2026
linnanwang pushed a commit that referenced this pull request on Apr 24, 2026
* fix: key PARALLELIZE_FUNCTIONS by qualname to fix optimized TP plan lookup in NeMo-RL
* fix tests
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
---------
Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Problem
When NeMo-RL uses Automodel as a submodule, the optimized TP plan (e.g. for `LlamaForCausalLM`) is silently skipped, causing ~16 GB of extra peak GPU memory during SFT with FSDP2+TP.
Root Cause
NeMo-RL auto-sets `force_hf=True` for models whose adapter does not implement `convert_single_tensor_to_hf`. This triggers `_get_mixin_wrapped_class()` in `model_init.py`, which creates a new class object via `type(name, (HFCheckpointingMixin, model_class), ...)`. The wrapper copies `__module__` and `__qualname__` from the original but is a different Python object, so `type(model) in PARALLELIZE_FUNCTIONS` (identity check) returns `False` and the default plan is used.
Fix
Key `PARALLELIZE_FUNCTIONS` by the `f"{cls.__module__}.{cls.__qualname__}"` string instead of the class object, and look up via `_get_class_qualname(model_cls)`. This replaces the two-strategy lookup that existed before:

- `model_cls in PARALLELIZE_FUNCTIONS` — exact class identity (failed for mixin-wrapped classes)
- `model_cls.__name__ in {k.__name__ for k in PARALLELIZE_FUNCTIONS}` — bare-name fallback (also failed, and could false-positive on unrelated classes from different packages with the same name)

`_get_class_qualname` covers both cases correctly: exact identity matches by definition, and mixin-wrapped classes preserve `__module__`/`__qualname__`, so the string key still matches. Using the full `module.qualname` also eliminates the false-positive risk of the old `__name__`-only fallback.
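A small illustration of the false-positive risk the bare-name fallback had, using hypothetical stand-in classes that share a name but live in different scopes:

```python
def _get_class_qualname(cls: type) -> str:
    return f"{cls.__module__}.{cls.__qualname__}"

class _PkgA:
    class LlamaForCausalLM: ...   # stand-in for the registered model class

class _PkgB:
    class LlamaForCausalLM: ...   # stand-in for an unrelated class with the same name

plans_by_class = {_PkgA.LlamaForCausalLM: "optimized_plan"}

# Old bare-name fallback: the unrelated class matches (false positive).
assert _PkgB.LlamaForCausalLM.__name__ in {k.__name__ for k in plans_by_class}

# module.qualname key: the unrelated class no longer matches.
plans_by_qualname = {_get_class_qualname(k): v for k, v in plans_by_class.items()}
assert _get_class_qualname(_PkgB.LlamaForCausalLM) not in plans_by_qualname
```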
Validation

Llama 3.1 8B SFT, FSDP2+TP=4, 1N8G:

- `[MEM_TRAIN][after_fwd_bwd]` peak: 51.93 GB → 35.66 GB (−16 GB)
- `lm_head` output: `DTensor=False` (all-gathered) → `DTensor=True, Shard(dim=2)` (sharded, no extra all-gather)

Made with Cursor