fix: Add check for world size and parallelism enabled#1190
Conversation
Signed-off-by: Parth Chadha <pchadha@nvidia.com>
📝 WalkthroughWalkthroughAdds runtime validation in Changes
Sequence Diagram(s)sequenceDiagram
autonumber
actor U as Caller
participant P as Policy.__init__
participant C as Cluster
participant WB as WorkerBuilder/RayWorkerGroup
participant SL as Sharding Layout
U->>P: create Policy(config)
P->>C: world_size = cluster.world_size()
P->>P: compute pp, cp, tp, model_parallel_size = pp*cp*tp
alt world_size < model_parallel_size
P-->>U: raise ValueError (insufficient world size)
note right of P: No worker group constructed
else world_size % model_parallel_size != 0
P-->>U: raise ValueError (non-integer DP)
note right of P: No worker group constructed
else
P->>P: DP = world_size / model_parallel_size
P->>WB: select worker_builder_cls and env
P->>SL: construct sharding layout
P-->>U: initialized Policy
end
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Pre-merge checks and finishing touches✅ Passed checks (6 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip 👮 Agentic pre-merge checks are now available in preview!Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example: reviews:
pre_merge_checks:
custom_checks:
- name: "Undocumented Breaking Changes"
mode: "warning"
instructions: |
Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).Please share your feedback with us on this Discord post. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
nemo_rl/models/policy/lm_policy.py (1)
705-713: Make destructor resilient when init fails before setting worker_group.If validation raises early, del can run on a partially initialized object and raise AttributeError.
Apply this diff:
def __del__(self) -> None: """Shuts down the worker groups when the object is deleted or is garbage collected. This is an extra safety net in case the user forgets to call worker_group.shutdown() and the pointer to the object is lost due to leaving a function scope. It's always recommended that the user calls worker_group.shutdown(). """ - self.worker_group.shutdown() + try: + wg = getattr(self, "worker_group", None) + if wg is not None: + wg.shutdown() + except Exception: + # Best-effort cleanup; avoid raising from __del__ + pass
🧹 Nitpick comments (6)
nemo_rl/models/policy/lm_policy.py (1)
122-127: Slim error messages or move formatting into a helper to satisfy TRY003.Ruff TRY003 flags long f-strings in raises. Either shorten messages or centralize formatting (e.g., a small
_format_world_size_error(...)) and keep the exception text concise.Example (minimal change):
- raise ValueError( - f"World size ({actual_world_size}) is insufficient for the parallelism configuration. " - f"Required minimum world size: PP({pp_size}) * CP({cp_size}) * TP({tp_size}) = {model_parallel_size}. " - f"This would result in DP = {actual_world_size}/{model_parallel_size} = {actual_world_size / model_parallel_size:.3f}, but DP must be ≥ 1. " - f"Please either increase the number of GPUs/nodes or reduce the parallelism parameters." - ) + dp = actual_world_size / model_parallel_size + raise ValueError( + f"Insufficient world size ({actual_world_size}); need at least PP({pp_size})*CP({cp_size})*TP({tp_size})={model_parallel_size}. " + f"Computed DP={dp:.3f} < 1." + )Also applies to: 131-136
tests/unit/models/policy/test_policy_validation.py (5)
57-59: Remove unused parameter ‘pp’ or make intent explicit.
ppisn’t used in DTensor config. Rename to_ppto silence ARG001 and document PP=1 for DTensor.-def create_dtensor_config( - model_name: str, tp: int, pp: int = 1, cp: int = 1 -) -> PolicyConfig: +def create_dtensor_config( + model_name: str, tp: int, _pp: int = 1, cp: int = 1 +) -> PolicyConfig:Call sites can remain as
pp=1.
210-221: Drop try/except on success path; let pytest surface unexpected exceptions.Catching broad Exception (BLE001) hides useful tracebacks. Just construct Policy and assert the mock call.
- if should_pass: - # Should succeed without raising an exception - try: - policy = Policy(cluster=cluster, config=config, tokenizer=tokenizer) - # Verify the calculated DP makes sense - expected_dp = world_size // (1 * cp * tp) # PP=1 for DTensor - assert expected_dp >= 1, f"Expected DP should be >= 1, got {expected_dp}" - # Verify that worker group was created (validation passed) - mock_ray_worker_group.assert_called_once() - except Exception as e: - pytest.fail(f"Expected success for {description}, but got error: {e}") + if should_pass: + Policy(cluster=cluster, config=config, tokenizer=tokenizer) + expected_dp = world_size // (1 * cp * tp) # PP=1 for DTensor + assert expected_dp >= 1 + mock_ray_worker_group.assert_called_once()
302-313: Same: remove broad try/except in Megatron success path.- if should_pass: - # Should succeed without raising an exception - try: - policy = Policy(cluster=cluster, config=config, tokenizer=tokenizer) - # Verify the calculated DP makes sense - expected_dp = world_size // (pp * cp * tp) - assert expected_dp >= 1, f"Expected DP should be >= 1, got {expected_dp}" - # Verify that worker group was created (validation passed) - mock_ray_worker_group.assert_called_once() - except Exception as e: - pytest.fail(f"Expected success for {description}, but got error: {e}") + if should_pass: + Policy(cluster=cluster, config=config, tokenizer=tokenizer) + expected_dp = world_size // (pp * cp * tp) + assert expected_dp >= 1 + mock_ray_worker_group.assert_called_once()
165-183: Add tests for invalid zero/negative parallel sizes to lock in the new guard.Parametrize a few cases like (world_size=8, tp=0) and (tp=-1) for DTensor to assert ValueError.
Example additions:
# Invalid: non-positive TP/CP (8, 0, 1, False, "invalid", "Invalid: TP=0"), (8, -1, 1, False, "invalid", "Invalid: TP<0"), (8, 4, 0, False, "invalid", "Invalid: CP=0"),And assert on "must be a positive integer" in the error message.
245-275: Mirror zero/negative checks for Megatron (PP/TP/CP).Add cases like PP=0, TP=0 to ensure the same failure mode in Megatron.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
nemo_rl/models/policy/lm_policy.py(1 hunks)tests/unit/models/policy/test_policy_validation.py(1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Follow the Google Python Style Guide for all Python code
Target Python 3.12+ for all Python code in NeMo-RL
Indent Python code with 4 spaces; do not use tabs
Python filenames should be snake_case (e.g., some_file.py)
Class names should be PascalCase
Function and method names should be snake_case
Local variable names should be snake_case; if starting with a number, prefix with k (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE and prefixed with G_ (e.g., G_MY_GLOBAL)
Constants should be UPPER_SNAKE_CASE
Avoid shadowing variables declared in an outer scope
Initialize all externally visible members of a class in the constructor
For public interfaces used outside a file, prefer docstrings over comments
Use comments mainly for code within a function or interfaces local to a file
Commented-out code must include a nearby comment explaining usage and why it is commented out; otherwise remove before merging
Use Google-style docstrings for classes and functions (Sphinx-parseable)
Avoid using reflection when functionality can be easily achieved without it
Limit except clauses to the smallest specific set of exceptions possible
For duck-typing via try/except, keep the try body minimal and use else for main logic
Add the NVIDIA copyright header (with current year) at the top of all Python files, excluding tests/ and test-only scripts
Files:
nemo_rl/models/policy/lm_policy.pytests/unit/models/policy/test_policy_validation.py
nemo_rl/**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
nemo_rl/**/*.py: Do not set non-None configuration defaults in code; YAML is the single source of truth for defaults
Access required config attributes directly (e.g., policy_cfg["precision"]) and assume presence; do not introduce hidden defaults
Express configuration optionality via TypedDict using typing.NotRequired
When adding a new config key to a TypedDict subclass, document the key’s purpose, valid values/types, and recommended default in code
For any class or function decorated with @ray.remote, add '# pragma: no cover' on the class/def line (and on remote functions)
Files:
nemo_rl/models/policy/lm_policy.py
🧬 Code graph analysis (2)
nemo_rl/models/policy/lm_policy.py (3)
tests/unit/models/generation/test_vllm_generation.py (1)
cluster(221-232)tests/unit/utils/test_native_checkpoint.py (1)
cluster(96-109)nemo_rl/distributed/virtual_cluster.py (1)
world_size(357-358)
tests/unit/models/policy/test_policy_validation.py (4)
nemo_rl/models/policy/__init__.py (1)
PolicyConfig(141-163)nemo_rl/models/policy/lm_policy.py (1)
Policy(56-722)nemo_rl/distributed/virtual_cluster.py (3)
world_size(357-358)get_placement_groups(347-355)get_available_address_and_port(363-397)tests/unit/conftest.py (1)
tiny_llama_model_path(456-480)
🪛 Ruff (0.13.1)
nemo_rl/models/policy/lm_policy.py
122-127: Avoid specifying long messages outside the exception class
(TRY003)
131-136: Avoid specifying long messages outside the exception class
(TRY003)
tests/unit/models/policy/test_policy_validation.py
58-58: Unused function argument: pp
(ARG001)
213-213: Local variable policy is assigned to but never used
Remove assignment to unused variable policy
(F841)
219-219: Do not catch blind exception: Exception
(BLE001)
305-305: Local variable policy is assigned to but never used
Remove assignment to unused variable policy
(F841)
311-311: Do not catch blind exception: Exception
(BLE001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Post automodel integration comment / Comment on PR
- GitHub Check: Post submodule check comment / Comment on PR
Signed-off-by: Parth Chadha <pchadha@nvidia.com>
Signed-off-by: Parth Chadha <pchadha@nvidia.com> Signed-off-by: yuanhangs <yuanhangs@nvidia.com>
What does this PR do ?
Fixes #1182
Add a one line overview of what this PR aims to accomplish.
Issues
List issues that this PR closes (syntax):
Usage
# Add a code snippet demonstrating how to use thisBefore your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
Bug Fixes
Tests