rename nvfsdp to mfsdp globally #1137

Conversation
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Walkthrough

References to nvFSDP were renamed to megatron-fsdp / mfsdp across documentation, configs, scripts, and tests. Geneformer checkpointing and training APIs were updated to use the renamed use_mfsdp flag.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant User
    participant Hydra as Hydra Config
    participant Trainer as Train Script
    participant Model
    participant Dist as Dist Backend
    participant Ckpt as Checkpoint IO
    User->>Hydra: Provide config (training.use_mfsdp)
    Hydra-->>Trainer: Resolved params
    alt training.use_mfsdp = true
        Trainer->>Dist: Initialize process group + enable fully_shard (mfsdp)
        Trainer->>Model: Construct and wrap (mfsdp)
        opt Load checkpoint
            Trainer->>Ckpt: load_checkpoint(..., use_mfsdp=true)
            Ckpt-->>Trainer: Restore model/optimizer via distributed checkpoint
        end
    else DDP path
        Trainer->>Dist: Initialize process group (DDP)
        Trainer->>Model: Construct and wrap (DDP)
        opt Load checkpoint
            Trainer->>Ckpt: load_checkpoint(..., use_mfsdp=false)
            Ckpt-->>Trainer: Restore model/optimizer (file-based)
        end
    end
    loop Train steps
        Trainer->>Model: forward/backward/step
        opt Periodic save
            alt mfsdp
                Trainer->>Ckpt: save_checkpoint(..., use_mfsdp=true)
                Ckpt-->>Trainer: Saved via distributed checkpoint
            else DDP
                Trainer->>Ckpt: save_checkpoint(..., use_mfsdp=false)
                Ckpt-->>Trainer: Saved to files
            end
        end
    end
    opt Final save
        alt mfsdp
            Trainer->>Ckpt: save_final_model(..., use_mfsdp=true)
            Ckpt->>Model: Gather parameters to rank 0 and unwrap
            Ckpt-->>Trainer: Final model saved (mfsdp)
        else DDP
            Trainer->>Ckpt: save_final_model(..., use_mfsdp=false)
            Ckpt-->>Trainer: Final model saved (DDP)
        end
    end
```
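For orientation, here is the same branching rendered as a short Python sketch. The helper names (setup_distributed, wrap_with_mfsdp, wrap_with_ddp, build_model, build_optimizer, training_step, dataloader) are hypothetical stand-ins, not the recipe's actual API; load_checkpoint, save_checkpoint, and save_final_model mirror the signatures quoted in the review below.

```python
# Illustrative sketch of the diagram's control flow; helper names are hypothetical.
def train(cfg):
    dist_config = setup_distributed(cfg)           # init process group

    if cfg.training.use_mfsdp:
        model = wrap_with_mfsdp(build_model(cfg))  # fully_shard path
    else:
        model = wrap_with_ddp(build_model(cfg))    # DDP path
    optimizer = build_optimizer(model, cfg)

    if cfg.training.resume_from_checkpoint:
        load_checkpoint(use_mfsdp=cfg.training.use_mfsdp, model=model,
                        optimizer=optimizer, ckpt_dir=cfg.ckpt_dir,
                        dist_config=dist_config)

    for step, batch in enumerate(dataloader(cfg)):
        training_step(model, optimizer, batch)     # forward/backward/step
        if step > 0 and step % cfg.training.save_every_n_steps == 0:
            save_checkpoint(use_mfsdp=cfg.training.use_mfsdp, model=model,
                            optimizer=optimizer, ckpt_dir=cfg.ckpt_dir,
                            dist_config=dist_config, step=step)

    # Final save: the mfsdp path gathers params to rank 0 and unwraps first.
    save_final_model(model=model, use_mfsdp=cfg.training.use_mfsdp,
                     save_directory=cfg.final_dir,
                     is_main_process=dist_config.is_main_process())
```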
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
bionemo-recipes.md (3)
197-205: Docs still show distributed.nvfsdp config block — update to mfsdp.
This contradicts the rename and can confuse users. Apply:
```diff
-# Distributed training
-distributed:
-  backend: nccl
-  nvfsdp:
-    enable: true
-    sharding_strategy: zero3
+# Distributed training
+distributed:
+  backend: nccl
+  mfsdp:
+    enable: true
+    sharding_strategy: zero3
```
245-247: Rename example WandB run to match mfsdp.

```diff
-wandb:
-  name: "esm2_nvfsdp_benchmark"
+wandb:
+  name: "esm2_mfsdp_benchmark"
```
414-417: Bottom “Examples” still references esm2_native_te_nvfsdp/.
Update folder name and wording.

```diff
-- **`esm2_native_te_nvfsdp/`**: Comprehensive example showing vanilla PyTorch with TE and nvFSDP
+- **`esm2_native_te_mfsdp/`**: Comprehensive example showing vanilla PyTorch with TE and megatron-fsdp
```

recipes/README.md (2)
197-205: Config snippet still uses nvfsdp — switch to mfsdp.
Prevents new users from copying stale keys.

```diff
-# Distributed training
-distributed:
-  backend: nccl
-  nvfsdp:
-    enable: true
-    sharding_strategy: zero3
+# Distributed training
+distributed:
+  backend: nccl
+  mfsdp:
+    enable: true
+    sharding_strategy: zero3
```
245-247: Benchmark example name still uses nvfsdp.

```diff
-wandb:
-  name: "esm2_nvfsdp_benchmark"
+wandb:
+  name: "esm2_mfsdp_benchmark"
```

recipes/geneformer_native_te_mfsdp_fp8/test_monotonic_decreasing_loss.py (1)
115-143: Fix train.py path resolution and make overrides handling robust

Running via a bare "train.py" relies on the cwd; in CI this can fail to find the script. Also, the parameter named kwargs is a list, and None would crash with star-expansion. Use an absolute path based on __file__ and guard against None.
Apply:
```diff
-    def _run_training_with_config(self, config_name, kwargs=None):
+    def _run_training_with_config(self, config_name, overrides=None):
@@
-        # Run the training script with the specified config and override num_train_steps
-        cmd = [
-            sys.executable,
-            "train.py",
+        # Resolve train.py relative to this test file
+        this_dir = os.path.dirname(__file__)
+        train_script = os.path.join(this_dir, "train.py")
+        arg_overrides = overrides or []
+        # Run the training script with the specified config and override num_train_steps
+        cmd = [
+            sys.executable,
+            train_script,
             "--config-name",
             config_name,
-            *kwargs,
+            *arg_overrides,
             "training.num_train_steps=50",  # Override to 50 steps regardless of config
             "training.resume_from_checkpoint=false",
         ]
```

recipes/geneformer_native_te_mfsdp_fp8/test_distributed_checkpointing.py (1)
92-119: Add timeouts to subprocess.run to prevent CI hangs

Several torchrun invocations lack timeouts; a stalled NCCL job can hang the suite indefinitely. Mirror the 10–15 min guard you used elsewhere.
Example changes (replicate across calls):
```diff
-        result1 = subprocess.run(cmd_phase1, check=False, capture_output=True, text=True, env=env)
+        result1 = subprocess.run(
+            cmd_phase1, check=False, capture_output=True, text=True, env=env, timeout=900
+        )
@@
-        result2 = subprocess.run(cmd_phase2, check=False, capture_output=True, text=True, env=env)
+        result2 = subprocess.run(
+            cmd_phase2, check=False, capture_output=True, text=True, env=env, timeout=900
+        )
```

Also applies to: 181-206, 449-451, 608-610, 727-729, 864-866, 1005-1007, 1124-1126, 1261-1263
♻️ Duplicate comments (1)
recipes/geneformer_native_te_mfsdp_fp8/hydra_config/4b.yaml (1)
16-16: Consistent flag rename.
Same verification as above to catch any remaining use_nvfsdp references.
🧹 Nitpick comments (20)
recipes/esm2_native_te_mfsdp_thd/hydra_config/L1_650M.yaml (1)
12-12: W&B run name rename looks good. Name updated to mfsdp; consistent with the PR goal.
If you want uniform naming across all ESM2 configs, consider a consistent suffix pattern like "mfsdp".
recipes/geneformer_native_te_mfsdp_fp8/hydra_config/defaults.yaml (1)
22-22: Defaults switched to use_mfsdp — ensure a backward-compatibility shim if users have old configs.

Optional: tolerate either key, preferring use_mfsdp, to reduce breakage.
Apply a small adapter before reading flags:
```diff
-# read flag directly
-use_mfsdp = cfg.training.use_mfsdp
+# tolerate legacy key
+use_mfsdp = cfg.training.get("use_mfsdp")
+if use_mfsdp is None:
+    use_mfsdp = cfg.training.get("use_nvfsdp", False)
```

recipes/geneformer_native_te_mfsdp_fp8/AGENT_DOCUMENTATION.md (1)
185-185: Unreleased pre-release version pinned—loosen or remove version constraint
There is no stable PyPI release of megatron-fsdp (latest is 0.1.0rc0 as of Aug 22, 2025); pinning 0.1.0rc1 may fail. In docs, prefer omitting the version or using a range (e.g., >=0.1.0rc0) instead.

recipes/geneformer_native_te_mfsdp_fp8/hydra_config/l0_sanity.yaml (1)
16-16: Switched to use_mfsdp — LGTM. Mirror the defaults.yaml shim suggestion if you expect older runs.
recipes/esm2_native_te_mfsdp/hydra_config/L1_15B_perf_test.yaml (1)
11-11: Perf test W&B name migrated to mfsdp — LGTM. Consider standardizing project naming across configs to ease filtering in W&B dashboards.
recipes/geneformer_native_te_mfsdp_fp8/checkpoint.py (7)
44-51: Rename parameter looks consistent; tighten dist_config typing
dist_config is typed as Dict[str, Any] but used as an object with .local_rank and .is_main_process(). Consider a lightweight Protocol to make intent clear to type checkers.

```diff
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Tuple, Protocol
+
+class _DistConfig(Protocol):
+    local_rank: int
+    def is_main_process(self) -> bool: ...
@@
-def load_checkpoint(
-    use_mfsdp: bool,
-    model: torch.nn.Module,
-    optimizer: torch.optim.Optimizer,
-    ckpt_dir: str,
-    dist_config: Dict[str, Any],
+def load_checkpoint(
+    use_mfsdp: bool,
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    ckpt_dir: str,
+    dist_config: _DistConfig,
```
100-116: Validation logic OK; add an explicit rank-0 check to reduce log spam

The .metadata directory check is appropriate. To avoid N identical warnings per rank, gate warnings/info on the main process for MFSDP too.

```diff
-def _validate_checkpoint(checkpoint_path: str, use_mfsdp: bool, logger: logging.Logger) -> bool:
+def _validate_checkpoint(checkpoint_path: str, use_mfsdp: bool, logger: logging.Logger, *, is_main: bool = True) -> bool:
@@
-    if not os.path.isdir(checkpoint_path):
-        logger.warning(f"mfsdp checkpoint should be a directory, but {checkpoint_path} is not")
+    if not os.path.isdir(checkpoint_path):
+        if is_main:
+            logger.warning(f"mfsdp checkpoint should be a directory, but {checkpoint_path} is not")
         return False
@@
-    if not os.path.isfile(metadata_path):
-        logger.warning(f"mfsdp checkpoint missing .metadata file at {metadata_path}")
+    if not os.path.isfile(metadata_path):
+        if is_main:
+            logger.warning(f"mfsdp checkpoint missing .metadata file at {metadata_path}")
         return False
```

And pass is_main=dist_config.is_main_process() at call sites.
138-160: Corrupted checkpoint cleanup is good; consider guarding by main process + barrier

For MFSDP/DDP, multiple ranks may concurrently remove the same path. Prefer cleanup on rank 0 plus a barrier to avoid races.
```diff
+import torch.distributed as dist
@@
-    if os.path.exists(checkpoint_path):
+    if dist.is_initialized() and dist.get_rank() != 0:
+        pass  # only rank0 performs cleanup
+    elif os.path.exists(checkpoint_path):
@@
-    except (OSError, PermissionError) as e:
+    except (OSError, PermissionError) as e:
         logger.error(f"Could not clean up existing checkpoint {checkpoint_path}: {e}")
         return
+    if dist.is_initialized():
+        dist.barrier()
```
166-172: MFSDP save path: ensure all ranks participate; consider a post-save barrier before validation

DCP save must run on all ranks (looks correct). To prevent validation races on networked filesystems, add a barrier before validating.
```diff
     if use_mfsdp:
         torch.distributed.checkpoint.save(
             {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
             checkpoint_id=checkpoint_path,
         )
         logger.info(f"Successfully saved mfsdp checkpoint to {checkpoint_path}")
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
```
182-185: Validate only on main process to reduce FS churn

After saving, validate once on rank 0 (both MFSDP and DDP).
```diff
-    if not _validate_checkpoint(checkpoint_path, use_mfsdp, logger):
+    is_main = (not torch.distributed.is_initialized()) or torch.distributed.get_rank() == 0
+    if is_main and not _validate_checkpoint(checkpoint_path, use_mfsdp, logger, is_main=True):
         logger.error(f"Saved checkpoint {checkpoint_path} failed validation!")
```
199-212: Gather uses private MFSDP internals; add guards and a barrier

These private methods may drift. Add attribute checks and a barrier to fail fast with a clear message and to synchronize before returning.
```diff
 def _gather_mfsdp_parameters(model, logger):
     """Helper function to gather mfsdp parameters across all processes."""
     logger.info("Starting mfsdp parameter gathering...")
-    # These collective operations must run on ALL processes
-    model._replace_param_with_raw_if_needed()
-    model.all_gather_pipeline.all_gather_params(list(model.module.parameters()))
+    # These collective operations must run on ALL processes
+    missing = [name for name in ["_replace_param_with_raw_if_needed", "all_gather_pipeline", "param_and_grad_buffer"] if not hasattr(model, name)]
+    if missing:
+        raise RuntimeError(f"Model does not expose expected MFSDP internals: {missing}")
+    model._replace_param_with_raw_if_needed()
+    model.all_gather_pipeline.all_gather_params(list(model.module.parameters()))
@@
-    logger.info("mfsdp parameter gathering completed")
+    if torch.distributed.is_initialized():
+        torch.distributed.barrier()
+    logger.info("mfsdp parameter gathering completed")
```
221-247: Final save flow is sound; minor: prefer using _get_underlying_model for symmetry

Using _get_underlying_model(model) in both branches keeps behavior consistent if wrapping changes.

```diff
-    underlying_model = model.module
+    underlying_model = _get_underlying_model(model)
```

recipes/geneformer_native_te_mfsdp_fp8/train.py (3)
166-176: DDP fallback OK; consider static_graph=True if applicable

If the model graph is static, enabling static_graph=True can improve performance in recent PyTorch versions.

```diff
-    model = torch.nn.parallel.DistributedDataParallel(
+    model = torch.nn.parallel.DistributedDataParallel(
         model,
         device_ids=[dist_config.local_rank],
         output_device=dist_config.local_rank,
         find_unused_parameters=False,  # More efficient for static graphs
         broadcast_buffers=True,  # Important for normalization layers
+        static_graph=True,
     )
```
243-252: Periodic checkpointing passes use_mfsdp; consider guarding by main for the DDP path

Looks good. Optional: call save_checkpoint only on rank 0 when not using MFSDP to avoid redundant file checks.

```diff
-        if step % cfg.training.save_every_n_steps == 0 and step > 0:  # Skip step 0
+        if step % cfg.training.save_every_n_steps == 0 and step > 0:  # Skip step 0
             # For mfsdp, always use distributed checkpointing
-            save_checkpoint(
+            if cfg.training.use_mfsdp or dist_config.is_main_process():
+                save_checkpoint(
                     use_mfsdp=cfg.training.use_mfsdp,
                     model=model,
                     optimizer=optimizer,
                     ckpt_dir=ckpt_dir,
                     dist_config=dist_config,
                     logger=logger,
                     step=step,
-            )
+                )
```
273-281: Final save flow is correct; add a barrier before and after for MFSDP

To ensure all ranks finish parameter gathering and to avoid early teardown, add barriers around save_final_model.

```diff
-    final_model_dir = os.path.join(ckpt_dir, "final_model")
-    save_final_model(
+    final_model_dir = os.path.join(ckpt_dir, "final_model")
+    if dist.is_initialized():
+        dist.barrier()
+    save_final_model(
         model=model,
         use_mfsdp=cfg.training.use_mfsdp,
         save_directory=final_model_dir,
         logger=logger,
         is_main_process=dist_config.is_main_process(),
     )
+    if dist.is_initialized():
+        dist.barrier()
```

recipes/geneformer_native_te_mfsdp_fp8/test_train.py (1)
365-397: Renamed MFSDP (non-TE) entrypoint: OK; 60s timeout may be tight

On slower CI with cold caches, 60s can flake. Consider 120s to reduce noise.
```diff
-            timeout=60,  # 1 minute timeout
+            timeout=120,  # allow slower CI nodes
```

recipes/geneformer_native_te_mfsdp_fp8/test_monotonic_decreasing_loss.py (2)
180-183: Avoid duplicate Hydra overrides

training.resume_from_checkpoint=false is already appended in _run_training_with_config; drop the duplicate here to prevent confusion if defaults change later.
```diff
-        loss_values, error = self._run_training_with_config(
-            "l0_sanity",
-            ["model.use_te_layers=false", "training.use_mfsdp=false", "training.resume_from_checkpoint=false"],
-        )
+        loss_values, error = self._run_training_with_config(
+            "l0_sanity",
+            ["model.use_te_layers=false", "training.use_mfsdp=false"],
+        )
```
48-55: Make loss parsing robust to scientific notation and negatives

The current regex misses values like 1.2e-3 or -0.1. Minor but cheap to harden.
- pattern = r"Step (\d+) loss: ([\d.]+)" + pattern = r"Step\s+(\d+)\s+loss:\s+([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)"recipes/geneformer_native_te_mfsdp_fp8/test_distributed_checkpointing.py (2)
169-179: Docstring says sanity_te_mfsdp but TE layers are not enabled; align overrides

The two-process mfsdp checkpoint test docs claim "sanity_te_mfsdp," but the command lacks model.use_te_layers=true. Align intent vs. config.
```diff
     cmd_phase1 = [
         "torchrun",
         "--nproc_per_node=2",
         train_script,
         "--config-name",
         "l0_sanity",
-        "training.use_mfsdp=true",
+        "training.use_mfsdp=true",
+        "model.use_te_layers=true",
@@
     cmd_phase2 = [
         "torchrun",
         "--nproc_per_node=2",
         train_script,
         "--config-name",
         "l0_sanity",
-        "training.use_mfsdp=true",
+        "training.use_mfsdp=true",
+        "model.use_te_layers=true",
```

Also applies to: 191-203
36-46: Set NCCL safety envs to reduce deadlock risk on multi-GPU

Helps tests fail fast instead of hanging on fabric/network issues.
os.environ["WANDB_DISABLED"] = "true" os.environ["WANDB_MODE"] = "disabled" +os.environ.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1") +os.environ.setdefault("TORCH_NCCL_BLOCKING_WAIT", "1")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
⛔ Files ignored due to path filters (2)
- recipes/esm2_native_te_mfsdp_thd/train.parquet is excluded by !**/*.parquet
- recipes/geneformer_native_te_mfsdp_fp8/genecorpus_500_samples.parquet is excluded by !**/*.parquet
📒 Files selected for processing (23)
- bionemo-recipes.md (4 hunks)
- models/.ruff.toml (0 hunks)
- recipes/README.md (3 hunks)
- recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (1 hunks)
- recipes/esm2_native_te_mfsdp/hydra_config/L1_15B_perf_test.yaml (1 hunks)
- recipes/esm2_native_te_mfsdp/train_ddp.py (2 hunks)
- recipes/esm2_native_te_mfsdp/train_mfsdp.py (2 hunks)
- recipes/esm2_native_te_mfsdp_thd/README.md (0 hunks)
- recipes/esm2_native_te_mfsdp_thd/hydra_config/L0_sanity.yaml (1 hunks)
- recipes/esm2_native_te_mfsdp_thd/hydra_config/L1_15B_perf_test.yaml (1 hunks)
- recipes/esm2_native_te_mfsdp_thd/hydra_config/L1_650M.yaml (1 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/AGENT_DOCUMENTATION.md (4 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/README.md (5 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/checkpoint.py (10 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/hydra_config/106m.yaml (1 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/hydra_config/10m.yaml (1 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/hydra_config/4b.yaml (1 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/hydra_config/defaults.yaml (1 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/hydra_config/l0_sanity.yaml (1 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/test_distributed_checkpointing.py (33 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/test_monotonic_decreasing_loss.py (5 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/test_train.py (5 hunks)
- recipes/geneformer_native_te_mfsdp_fp8/train.py (6 hunks)
💤 Files with no reviewable changes (2)
- models/.ruff.toml
- recipes/esm2_native_te_mfsdp_thd/README.md
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T16:40:04.315Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1078
File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108
Timestamp: 2025-08-28T16:40:04.315Z
Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.
Applied to files:
- recipes/esm2_native_te_mfsdp/train_mfsdp.py
- recipes/esm2_native_te_mfsdp/train_ddp.py
- recipes/geneformer_native_te_mfsdp_fp8/train.py
🧬 Code graph analysis (2)
recipes/geneformer_native_te_mfsdp_fp8/checkpoint.py (1)
recipes/geneformer_native_te_mfsdp_fp8/train.py (1)
is_main_process(81-83)
recipes/geneformer_native_te_mfsdp_fp8/train.py (1)
recipes/geneformer_native_te_mfsdp_fp8/checkpoint.py (2)
save_checkpoint (138-196)
save_final_model (221-259)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: unit-tests (recipes/esm2_native_te_mfsdp, svcbionemo023/bionemo-framework:pytorch25.06-py3-squash...
- GitHub Check: unit-tests (recipes/geneformer_native_te_mfsdp_fp8, svcbionemo023/bionemo-framework:pytorch25.06-...
- GitHub Check: unit-tests (recipes/esm2_native_te_mfsdp_thd, svcbionemo023/bionemo-framework:pytorch25.06-py3-sq...
🔇 Additional comments (16)
recipes/geneformer_native_te_mfsdp_fp8/hydra_config/106m.yaml (1)
16-16: No stale use_nvfsdp references detected
Running rg -nP -C2 '\buse_nvfsdp\b' --hidden returned no matches; the config rename to use_mfsdp is complete.

recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (1)
11-11: W&B run name updated to mfsdp_sanity — LGTM.

recipes/geneformer_native_te_mfsdp_fp8/AGENT_DOCUMENTATION.md (1)
58-58: Doc text aligned to mFSDP — looks consistent.

Also applies to: 77-77
recipes/esm2_native_te_mfsdp/train_mfsdp.py (2)
58-58: Docstring rename to mfsdp—LGTM.
70-70: Dependency naming verified
Requirements list "megatron-fsdp" and code imports "megatron_fsdp", as expected under Python's naming conventions.

recipes/esm2_native_te_mfsdp_thd/hydra_config/L0_sanity.yaml (1)
11-11: WandB run name rename looks good.
Name now reflects mfsdp; consistent with the PR intent.

recipes/esm2_native_te_mfsdp_thd/hydra_config/L1_15B_perf_test.yaml (1)
11-11: Perf run name updated correctly.
Matches the new naming convention.

recipes/esm2_native_te_mfsdp/train_ddp.py (1)
54-67: DDP device_mesh is supported by the current Docker base images: NVIDIA PyTorch containers 25.06 and 25.01 correspond to PyTorch 2.8.0a0 and 2.6.0a0 respectively (both ≥ 2.2) (docs.nvidia.com); torch.nn.parallel.DistributedDataParallel includes a device_mesh parameter since PyTorch 2.7+ (docs.pytorch.org). There is no explicit torch== pin in requirements*.txt; versions are managed via the base image.

recipes/geneformer_native_te_mfsdp_fp8/hydra_config/10m.yaml (1)
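For context, a minimal sketch of device_mesh-based DDP construction, assuming a torchrun launch and a PyTorch version where DDP accepts device_mesh (the Linear model is a placeholder, not the recipe's model):

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torchrun has set RANK/WORLD_SIZE/LOCAL_RANK.
dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# A 1-D mesh over all ranks stands in for device_ids/process_group.
mesh = init_device_mesh("cuda", (dist.get_world_size(),))
model = torch.nn.Linear(16, 16).cuda()
ddp_model = DDP(model, device_mesh=mesh)
```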
16-16: Ensure removal of legacy nvfsdp keys
The use_mfsdp flag has been migrated to training.use_mfsdp; please manually verify that no nvfsdp or use_nvfsdp entries remain in any Hydra configs (notably the distributed choice in the distributed group, around line 198) and update them to use the new key.

recipes/geneformer_native_te_mfsdp_fp8/README.md (1)
8-10: README rename is thorough and coherent.
mfsdp terminology, flags, and examples are consistent.

Also applies to: 37-38, 48-50, 125-130, 141-144
recipes/README.md (1)
118-130: Docstring header updated correctly.
Reflects megatron-fsdp in the embedded example.

recipes/geneformer_native_te_mfsdp_fp8/checkpoint.py (1)
78-84: Cannot verify MFSDP DCP load/save – please run tests with PyTorch

Test collection failed due to a missing torch module; run the MFSDP distributed checkpoint tests in an environment with PyTorch installed and confirm that the optimizer state is correctly restored across ranks.
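As a minimal single-process sketch of the DCP round-trip in question, mirroring the save call quoted earlier in this review (the path, model, and dimensions are arbitrary; the real tests run this under torchrun across multiple ranks):

```python
import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters())
model(torch.randn(2, 4)).sum().backward()
optimizer.step()  # populate optimizer state before saving

dcp.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()},
         checkpoint_id="/tmp/dcp_roundtrip")

# dcp.load restores in place into the provided state dicts.
state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
dcp.load(state, checkpoint_id="/tmp/dcp_roundtrip")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
dist.destroy_process_group()
```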
151-165: MFSDP enablement path looks correct

Using fully_shard with explicit unit modules and a device mesh matches the rename intent. No blockers spotted.
213-220: Resume path passes use_mfsdp correctly

The handoff to load_checkpoint aligns with the new flag. Good.
482-486: Test runner order includes both MFSDP variants: good

Nice coverage to run TE and non-TE MFSDP smoke tests when CUDA is available.
311-343: Hydra override confirmed – training.use_mfsdp is defined in l0_sanity.yaml, so the override will apply.
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Actionable comments posted: 1
♻️ Duplicate comments (2)
recipes/README.md (2)
52-52: Examples entry now consistent with mfsdp. This addresses the earlier inconsistency noted in prior reviews.

414-414: Bottom "Examples" list fixed to mfsdp. The earlier mismatch is resolved.
🧹 Nitpick comments (2)
recipes/README.md (1)
118-130: Minor wording tweak for accuracy around "ZeRO-3". To avoid mixing DeepSpeed terminology with FSDP, suggest clarifying as "full sharding (ZeRO-3 equivalent)".
```diff
-  - We use megatron-fsdp ZeRO-3 for maximum memory efficiency
+  - We use megatron-fsdp full sharding (ZeRO-3 equivalent) for maximum memory efficiency
```

recipes/esm2_native_te_mfsdp_thd/test_thd_format.py (1)
122-129: Re-enable a non-flaky MLM masking check gated by sample size

Instead of commenting out, keep the ratio assertion only when the token count is large enough; this preserves coverage without flakiness on tiny batches.
Apply this diff:
```diff
-    # masked_positions = (labels != -100).sum()
-    # total_positions = labels.numel()
-    # masking_ratio = masked_positions.float() / total_positions
+    masked_positions = (labels != -100).sum().item()
+    total_positions = labels.numel()
+    masking_ratio = masked_positions / max(1, total_positions)
@@
-    # TODO(jomitchell): Add this back if you have a larger dataset and this isn't as flaky.
-    # assert 0.05 <= masking_ratio <= 0.25, f"MLM masking ratio should be ~15%, got {masking_ratio:.1%}"
+    # Gate ratio check by sample size to avoid flakiness on tiny batches
+    if total_positions >= 100:
+        assert 0.05 <= masking_ratio <= 0.25, (
+            f"MLM masking ratio should be ~15%, got {masking_ratio:.1%} (N={total_positions})"
+        )
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- recipes/README.md (6 hunks)
- recipes/esm2_native_te_mfsdp_thd/test_thd_format.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T16:40:04.315Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1078
File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108
Timestamp: 2025-08-28T16:40:04.315Z
Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.
Applied to files:
recipes/README.md
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: unit-tests (recipes/esm2_native_te_mfsdp, svcbionemo023/bionemo-framework:pytorch25.06-py3-squash...
- GitHub Check: unit-tests (recipes/geneformer_native_te_mfsdp_fp8, svcbionemo023/bionemo-framework:pytorch25.06-...
- GitHub Check: unit-tests (recipes/esm2_native_te_mfsdp_thd, svcbionemo023/bionemo-framework:pytorch25.06-py3-sq...
🔇 Additional comments (3)
recipes/README.md (2)
5-5: PyPI link looks correct. The switch to "megatron-fsdp" and the PyPI URL is valid (latest 0.1.0rc0 as of Aug 22, 2025). (data.safetycli.com)
245-245: W&B name updated to mfsdp — good. Matches the rename and keeps dashboards consistent.
recipes/esm2_native_te_mfsdp_thd/test_thd_format.py (1)
122-129: Verify RNG determinism for DataCollatorForLanguageModeling

Confirm whether your installed version of DataCollatorForLanguageModeling supports a seed keyword; if it doesn't, explicitly invoke torch.manual_seed(...) (and seed any other relevant RNGs) before calling the collator so that masking is deterministic.
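If the seed kwarg turns out to be unavailable, a portable fallback is to reseed torch's global RNG immediately before each collator call, since the collator's masking draws from it in typical recent transformers versions. A minimal sketch (the bert-base-uncased tokenizer is an arbitrary stand-in):

```python
import torch
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

# Reseeding before each call makes the masked positions reproducible.
torch.manual_seed(0)
batch_a = collator([tokenizer("hello world")])
torch.manual_seed(0)
batch_b = collator([tokenizer("hello world")])
assert torch.equal(batch_a["labels"], batch_b["labels"])  # identical masking
```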
Rename folders and files to no longer use the deprecated nvFSDP name; they now use 'megatron-fsdp' or 'mfsdp'.
Summary by CodeRabbit
New Features/Improvements
Refactor
Documentation
Tests
Chores