
Megatron refactor POC #1592

Draft
ashors1 wants to merge 19 commits into main from ashors/policy-refactor

Conversation

ashors1 (Contributor) commented Dec 2, 2025

What does this PR do ?

Add a one-line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests.
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

Signed-off-by: ashors1 <ashors@nvidia.com>
ashors1 mentioned this pull request Dec 2, 2025
)


def broadcast_object_across_pp_ranks(obj: Any) -> Any:
Contributor:
We might want to move these utils to bridge for common testing; I think some of the utils here are also used in bridge. But we can change that later.

Contributor Author:
Good point. I removed broadcast_object_across_pp_ranks in favor of mbridge's broadcast_obj_from_pp_rank.

Contributor Author:
Oops, I didn't look at the mbridge code closely enough and missed that broadcast_obj_from_pp_rank is a class method. I've added broadcast_obj_from_pp_rank back to nemo-rl for now, but eventually we should plan to merge the two.
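
For context, a minimal sketch of what such a helper typically does, assuming Megatron-Core's parallel_state API and that the object originates on the last pipeline stage (both assumptions, not this PR's actual code):

from typing import Any

import torch
from megatron.core import parallel_state


def broadcast_object_across_pp_ranks(obj: Any) -> Any:
    """Broadcast a picklable object from the last pipeline stage to all PP ranks."""
    pp_group = parallel_state.get_pipeline_model_parallel_group()
    # Global rank of the last stage within this pipeline-parallel group.
    src_rank = torch.distributed.get_process_group_ranks(pp_group)[-1]
    # broadcast_object_list fills the list in place on non-source ranks.
    container = [obj if parallel_state.is_pipeline_last_stage() else None]
    torch.distributed.broadcast_object_list(container, src=src_rank, group=pp_group)
    return container[0]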

)


def validate_model_paths(config: PolicyConfig) -> tuple[str, str, bool]:
Contributor:
This seems like a general util.

Contributor Author:
Do you think we should move it elsewhere? The function itself is pretty specific to Megatron and is only used during init.
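
If it stays Megatron-specific, the rough shape would be something like the sketch below; the parameter and return-field names here are invented for illustration, not taken from PolicyConfig:

import os


def validate_model_paths(
    model_name_or_path: str, megatron_ckpt_path: str | None
) -> tuple[str, str, bool]:
    """Return (hf_path, ckpt_path, use_megatron_ckpt) after basic init-time checks."""
    use_megatron_ckpt = megatron_ckpt_path is not None
    if use_megatron_ckpt and not os.path.isdir(megatron_ckpt_path):
        raise FileNotFoundError(f"Megatron checkpoint not found: {megatron_ckpt_path}")
    return model_name_or_path, megatron_ckpt_path or "", use_megatron_ckpt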

model_cfg = cfg_from_pretrained.model
cfg_from_pretrained.logger = LoggerConfig()

# Apply parallelism settings
Contributor:
This overriding part needs some work; right now it seems hacky and hard to maintain.

Contributor Author:
Do you think the refactor should be improved, or the actual logic for overriding hyperparameters? If the latter, could we address it in a separate PR, since this PR focuses on refactoring the existing logic?
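
For concreteness, the override in question amounts to something like the sketch below (the Megatron field names are real; the policy-config key layout is illustrative):

def apply_parallelism_overrides(model_cfg, policy_cfg: dict) -> None:
    """Copy user-specified parallelism settings onto the pretrained model config."""
    megatron_cfg = policy_cfg["megatron_cfg"]  # illustrative key layout
    for field in (
        "tensor_model_parallel_size",
        "pipeline_model_parallel_size",
        "context_parallel_size",
    ):
        # Last write wins; there is no cross-field validation, which is what
        # makes this style of override hard to maintain as settings accumulate.
        setattr(model_cfg, field, megatron_cfg[field])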




def setup_reference_model_state(
Contributor:
How is the reference model setup different from the actor's? Can they share some components? Right now the flows seem very different.

Contributor Author:
Yeah, the flows for the two are quite different. Personally, I feel that attempting to share components between them would be somewhat contrived, because there isn't much obvious code duplication right now, but happy to discuss this further if you'd like.
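
For what it's worth, the one piece the two flows could plausibly share is the generic "frozen copy" pattern, sketched below; this is a textbook RLHF illustration, not the PR's setup_reference_model_state:

import copy

import torch


def make_reference_model(actor_model: torch.nn.Module) -> torch.nn.Module:
    """Create a frozen copy of the actor to serve as the reference policy."""
    ref_model = copy.deepcopy(actor_model)
    ref_model.eval()  # disable dropout and other train-time behavior
    for p in ref_model.parameters():
        p.requires_grad_(False)  # reference weights never update
    return ref_model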


return output_tensor, processor_fn_wrapped

def forward_maybe_backward(
Contributor:
This needs a docstring; it's confusing to have different levels of forward stuff.

Contributor Author:
Added some docstrings and renamed forward_maybe_backward to megatron_forward_backward for clarity. Please take another look and let me know if things are clearer.
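
Roughly, the renamed function is a thin wrapper over Megatron-Core's pipeline schedule; a sketch of the shape, assuming the standard get_forward_backward_func API (the argument plumbing here is illustrative):

from megatron.core.pipeline_parallel import get_forward_backward_func


def megatron_forward_backward(
    forward_step_fn, data_iterator, model, num_microbatches, seq_length, mbs, forward_only
):
    """Run Megatron's pipeline schedule over one global batch.

    With forward_only=True this is a pure inference pass (e.g. logprob
    computation); otherwise gradients are accumulated across microbatches.
    """
    forward_backward_func = get_forward_backward_func()
    return forward_backward_func(
        forward_step_func=forward_step_fn,
        data_iterator=data_iterator,
        model=model,
        num_microbatches=num_microbatches,
        seq_length=seq_length,
        micro_batch_size=mbs,
        forward_only=forward_only,
    )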

if hasattr(module, "_inference_key_value_memory"):
module._inference_key_value_memory = None

if gbs is None:
Contributor:
This fallback seems unsafe. Are there any cases where gbs here will differ from the config value?

Contributor Author:
Yes. We call train when running validation for DPO, for example, and we want to allow a validation batch size that differs from the training batch size; similarly for the microbatch size. We could make things explicit and require that users pass gbs and mbs to train. WDYT?

Contributor:
Yep, that sounds better.
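
The explicit version would look roughly like this (parameter names are illustrative):

class MegatronPolicyWorker:
    def train(self, data, gbs: int, mbs: int) -> dict:
        """Run one pass over `data` with the given global/micro batch sizes.

        Requiring gbs and mbs lets callers such as DPO validation use batch
        sizes that differ from the training config, with no implicit fallback.
        """
        if gbs % mbs != 0:
            raise ValueError(f"gbs={gbs} must be divisible by mbs={mbs}")
        ...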

Signed-off-by: ashors1 <ashors@nvidia.com>
Signed-off-by: Anna Shors <ashors@nvidia.com>
terrykong linked an issue Dec 4, 2025 that may be closed by this pull request
Signed-off-by: ashors1 <ashors@nvidia.com>
github-actions bot commented Dec 5, 2025

⚠️ File Consistency Check

Check based on commit: 71dea6b (PR #1592 from ashors/policy-refactor)

⚠️ DTensor Policy Worker Synchronization Warning

The file nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py was modified in this PR, but nemo_rl/models/policy/workers/dtensor_policy_worker.py was not updated.

Why this matters:
These files contain related DTensor policy worker implementations that should be kept synchronized to ensure consistency across different versions.

Action required:

  • Please review if the changes in nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py should also be applied to nemo_rl/models/policy/workers/dtensor_policy_worker.py
  • Update nemo_rl/models/policy/workers/dtensor_policy_worker.py if necessary to maintain consistency
  • If the files are intentionally different, please add a comment in the PR explaining why

Files to check:

  • Modified: nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py
  • Not modified: nemo_rl/models/policy/workers/dtensor_policy_worker.py

This check ensures that related file implementations remain synchronized across the codebase. If you believe this warning is incorrect or the files should intentionally differ, please add a comment explaining the reasoning.

Signed-off-by: ashors1 <ashors@nvidia.com>

Development

Successfully merging this pull request may close these issues:

MegatronPolicyWorker refactor