fix: optimized TP plan lookup in NeMo-RL by qualname #1547

Merged
hemildesai merged 2 commits into main from zhiyul/fix_tp_plan_lookup
Mar 15, 2026

Conversation

@ZhiyuLi-Nvidia (Contributor) commented Mar 15, 2026

Problem

When NeMo-RL uses Automodel as a submodule, the optimized TP plan (e.g. for LlamaForCausalLM) was silently skipped,
causing ~16 GB extra peak GPU memory during SFT with FSDP2+TP.

Root Cause

NeMo-RL auto-sets force_hf=True for models whose adapter does not implement convert_single_tensor_to_hf. This triggers _get_mixin_wrapped_class() in model_init.py, which creates a new class object via type(name, (HFCheckpointingMixin, model_class), ...). The wrapper copies __module__ and __qualname__ from the original but is a different Python object, so type(model) in PARALLELIZE_FUNCTIONS (identity check) returns False and the
default plan is used.
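
A minimal, self-contained sketch of the failure mode (the classes and the dict contents here are stand-ins, not the actual Automodel symbols; only the type(...) call mirrors what _get_mixin_wrapped_class() is described as doing):

    class HFCheckpointingMixin:  # stand-in for the real mixin
        pass

    class LlamaForCausalLM:  # stand-in for the real model class
        pass

    # The wrapper is a brand-new class object that copies __module__ and
    # __qualname__ from the original.
    wrapped_cls = type(
        LlamaForCausalLM.__name__,
        (HFCheckpointingMixin, LlamaForCausalLM),
        {
            "__module__": LlamaForCausalLM.__module__,
            "__qualname__": LlamaForCausalLM.__qualname__,
        },
    )

    PARALLELIZE_FUNCTIONS = {LlamaForCausalLM: "optimized plan"}  # keyed by class object

    print(wrapped_cls.__qualname__ == LlamaForCausalLM.__qualname__)  # True
    print(wrapped_cls in PARALLELIZE_FUNCTIONS)  # False: same name, different object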

Fix

Key PARALLELIZE_FUNCTIONS by the f"{cls.__module__}.{cls.__qualname__}" string instead of by the class object, and look entries up via _get_class_qualname(model_cls).

This replaces the two-strategy lookup that existed before:

  • model_cls in PARALLELIZE_FUNCTIONS — exact class identity (failed for mixin-wrapped classes)
  • model_cls.__name__ in {k.__name__ for k in PARALLELIZE_FUNCTIONS} — bare-name fallback (also failed, and could false-positive on unrelated classes from different packages with the same name)

_get_class_qualname covers both cases correctly: exact identity matches by definition, and mixin-wrapped classes preserve __module__/__qualname__ so the string key still matches. Using the full module.qualname also eliminates the false-positive risk of the old __name__-only fallback.
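
A hedged sketch of the new keying scheme, reusing the stand-ins from the sketch above (the real helper lives in optimized_tp_plans.py; the plan value is illustrative):

    def _get_class_qualname(cls: type) -> str:
        # Full "module.qualname" string, per the description above.
        return f"{cls.__module__}.{cls.__qualname__}"

    PARALLELIZE_FUNCTIONS = {_get_class_qualname(LlamaForCausalLM): "optimized plan"}

    # Both the original class and the mixin-wrapped copy resolve to the same plan.
    print(PARALLELIZE_FUNCTIONS.get(_get_class_qualname(LlamaForCausalLM)))  # optimized plan
    print(PARALLELIZE_FUNCTIONS.get(_get_class_qualname(wrapped_cls)))       # optimized plan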

Validation

Llama 3.1 8B SFT, FSDP2+TP=4, 1N8G:

  • [MEM_TRAIN][after_fwd_bwd] peak: 51.93 GB → 35.66 GB (−16 GB)
  • lm_head output: DTensor=False (all-gathered) → DTensor=True, Shard(dim=2) (sharded, no extra all-gather); see the check sketched below
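
For reference, a hedged sketch of how that output check can be expressed (public DTensor API in recent torch; observing an actual DTensor requires an initialized device mesh):

    import torch
    from torch.distributed.tensor import DTensor

    def describe_output(out: torch.Tensor) -> str:
        # Reports whether the logits stayed sharded (DTensor) or were gathered.
        if isinstance(out, DTensor):
            return f"DTensor=True placements={out.placements}"
        return "DTensor=False (all-gathered plain tensor)"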

Made with Cursor

fix: key PARALLELIZE_FUNCTIONS by qualname to fix optimized TP plan lookup in NeMo-RL

~16 GB extra peak GPU memory when using NeMo-RL with torch 2.10 /
transformers v5 for Llama-style models with FSDP2+TP.

NeMo-RL auto-sets force_hf=True for LlamaForCausalLM because its adapter
does not implement convert_single_tensor_to_hf (required for weight syncing).

force_hf=True triggers _get_mixin_wrapped_class() in model_init.py, which
creates a new class via type(...) wrapping the original with HFCheckpointingMixin.
The wrapper preserves __module__ and __qualname__ from the original class but
is a different Python object.

PARALLELIZE_FUNCTIONS was keyed by class objects, so the lookup:

    elif model_cls in PARALLELIZE_FUNCTIONS:

uses Python identity (is). This silently fails for the mixin-wrapped class,
falls through to the default plan which uses ColwiseParallel(output_layouts=Replicate())
for lm_head — triggering an all-gather that adds ~8 GB per forward and backward pass.
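
A hedged sketch of the two lm_head styles in question (the output_layouts values follow the description here; the actual plan dict lives in optimized_tp_plans.py):

    from torch.distributed.tensor import Replicate, Shard
    from torch.distributed.tensor.parallel import ColwiseParallel

    # Default plan: replicate the lm_head output, i.e. all-gather the
    # vocab-sharded logits on every rank -- the ~8 GB per pass noted above.
    default_lm_head = ColwiseParallel(output_layouts=Replicate())

    # Optimized plan: keep logits sharded on the vocab dim (dim 2 of [B, S, V])
    # and return the DTensor itself, matching placements=(Shard(dim=2),).
    optimized_lm_head = ColwiseParallel(output_layouts=Shard(2), use_local_output=False)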

Note: this does not affect automodel standalone or older NeMo-RL (fp8/nemo-rl-tot)
because those do not auto-set force_hf=True, so type(model) is the exact
CustomLlamaForCausalLM object already stored in the dict.

- optimized_tp_plans.py: add _get_class_qualname() helper returning
  f"{cls.__module__}.{cls.__qualname__}" and key PARALLELIZE_FUNCTIONS
  by that string instead of the class object.
- parallelizer.py: replace the identity lookup with
  PARALLELIZE_FUNCTIONS.get(_get_class_qualname(model_cls)), using a
  walrus operator to keep the if/elif/else chain flat (sketched below).
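
A runnable sketch of that flattened lookup (the fallback helper and the surrounding function are illustrative, not the real parallelizer.py code):

    def _get_class_qualname(cls: type) -> str:
        return f"{cls.__module__}.{cls.__qualname__}"

    def _apply_default_plan(model):  # hypothetical fallback
        return "default plan"

    PARALLELIZE_FUNCTIONS = {}  # keyed by "module.qualname" strings after this change

    def parallelize(model):
        model_cls = type(model)
        # The walrus keeps the chain flat: one dict lookup, bound in the condition.
        if (func := PARALLELIZE_FUNCTIONS.get(_get_class_qualname(model_cls))) is not None:
            return func(model)
        return _apply_default_plan(model)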

Llama 3.1 8B SFT, FSDP2+TP=4, 1N8G:

  [LM_HEAD_FWD] peak: 33.579 GB → 29.665 GB  (-3.9 GB)
  [MEM_TRAIN][after_fwd_bwd] peak: 51.93 GB → 35.66 GB  (-16.3 GB)
  lm_head output: DTensor=False (gathered) → DTensor=True placements=(Shard(dim=2),)

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
copy-pr-bot (Bot) commented Mar 15, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ZhiyuLi-Nvidia changed the title from "fix: key PARALLELIZE_FUNCTIONS by qualname to fix optimized TP plan lookup in NeMo-RL" to "fix: optimized TP plan lookup in NeMo-RL by qualname" Mar 15, 2026
@ZhiyuLi-Nvidia (Contributor, Author) commented:

/ok to test 8830bd5

fix tests

Signed-off-by: Zhiyu Li <zhiyul@NVIDIA.com>
@ZhiyuLi-Nvidia (Contributor, Author) commented:

/ok to test 7962235

@hemildesai hemildesai merged commit 385073c into main Mar 15, 2026
52 checks passed
@hemildesai hemildesai deleted the zhiyul/fix_tp_plan_lookup branch March 15, 2026 23:37
@yuki-97 yuki-97 mentioned this pull request Mar 24, 2026
linnanwang pushed a commit that referenced this pull request Apr 24, 2026