ESM-2 Accelerate Recipes #1080
Conversation
Force-pushed from c9c09ae to bd16031.
jomitchellnv left a comment:
LGTM a few nits
Walkthrough

Per-call ROPE embedding handling was refactored in the ESM encoder. A new training recipe (esm2_accelerate) was added with Hydra configs, datasets, metrics, callbacks, Docker/SLURM assets, requirements, and tests. The amplify_accelerate_te_fp8 recipe gained a standalone callback, a config tweak, test updates including multi-GPU, and a dataset API rename.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant User
participant Hydra
participant Trainer
participant Model
participant Dataset
participant Callback as StopAfterNStepsCallback
participant Metrics
User->>Hydra: Run train.py (config L0_sanity)
Hydra->>Trainer: Build TrainingArguments
Hydra->>Model: AutoConfig -> AutoModelForMaskedLM (bf16)
Hydra->>Dataset: create_datasets_and_collator(tokenizer, max_len)
Trainer->>Trainer: Initialize with model, datasets, data_collator, compute_metrics
Trainer->>Callback: Register StopAfterNStepsCallback(n)
Trainer->>Trainer: Train (resume if checkpoint found)
loop Each step
Trainer->>Callback: on_step_end(state.global_step)
Callback-->>Trainer: control.should_training_stop? (if >= n)
end
Trainer->>Metrics: compute_metrics(eval_pred) [if eval]
Trainer->>Trainer: Save metrics and checkpoint-last
```

```mermaid
sequenceDiagram
autonumber
participant Encoder as NVEsmEncoder
participant Buffer as te_rope_emb (CPU pinned buffer)
participant Hidden as hidden_states
participant Layer as TransformerLayer
Hidden->>Encoder: forward(hidden_states, seq_len)
alt Rotary enabled
Encoder->>Buffer: Access registered buffer
Encoder->>Encoder: Move/slice to device/dtype (non-blocking), length check
Encoder->>Layer: call(..., rotary_pos_emb=sliced_emb)
else Rotary disabled
Encoder->>Layer: call(..., rotary_pos_emb=None)
end
Layer-->>Encoder: outputs
Encoder-->>Hidden: outputs
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 18
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
recipes/amplify_accelerate_te_fp8/train.py (1)
82-82: Guard `destroy_process_group` to prevent errors when DDP was not initialized. Unconditional teardown can raise on single-process or CPU runs.
Apply:
```diff
- torch.distributed.destroy_process_group()
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
+     torch.distributed.destroy_process_group()
```
♻️ Duplicate comments (3)
models/esm2/src/esm/modeling_esm_te.py (1)
173-174: Confirm TE accepts `rotary_pos_emb=None` for non-ROPE paths. The upstream comment asked “is none acceptable?”. This path now passes `None` when `position_embedding_type != "rotary"`. Please confirm `TransformerLayer(..., rotary_pos_emb=None)` is a supported code path in the TE version you target. If TE is available, run a quick import-time check:

```bash
#!/bin/bash
python - <<'PY'
import transformer_engine.pytorch as te, torch
try:
    lyr = te.TransformerLayer(hidden_size=64, ffn_hidden_size=128, num_attention_heads=4, layernorm_epsilon=1e-5,
                              hidden_dropout=0.0, attention_dropout=0.0, qkv_weight_interleaved=True, layer_number=1,
                              layer_type="encoder", self_attn_mask_type="padding", activation="gelu",
                              attn_input_format="bshd", seq_length=128, micro_batch_size=1, num_gqa_groups=4,
                              fuse_qkv_params=True, params_dtype=torch.bfloat16, window_size=(-1, -1))
    x = torch.zeros(1, 8, 64, dtype=torch.bfloat16, device="cuda" if torch.cuda.is_available() else "cpu")
    m = torch.zeros(1, 1, 1, 8, dtype=torch.bool, device=x.device)
    lyr(x, m, rotary_pos_emb=None)
    print("OK: rotary_pos_emb=None accepted")
except Exception as e:
    print("FAIL:", e)
PY
```

recipes/esm2_accelerate/hydra_config/defaults.yaml (1)
20-21: Scheduler differs from original ESM-2 (WarmupAnnealDecay). Not blocking, but note the divergence from the upstream training schedule.
If you want parity, I can draft a Hydra-configurable Warmup+Anneal+MinLR scheduler adapter compatible with HF Trainer.
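For reference, a minimal sketch of what such an adapter could look like, assuming a `LambdaLR`-style multiplier handed to the HF `Trainer` via its `optimizers=(optimizer, scheduler)` argument; the function name, the cosine anneal, and the 10% floor are illustrative assumptions rather than the original ESM-2 schedule:

```python
# Illustrative warmup -> cosine anneal -> minimum-LR-floor scheduler (names and floor are assumptions).
import math

from torch.optim.lr_scheduler import LambdaLR


def get_warmup_anneal_min_lr(optimizer, num_warmup_steps: int, num_training_steps: int, min_lr_ratio: float = 0.1):
    """Linear warmup to the peak LR, cosine anneal down to min_lr_ratio * peak, then hold that floor."""

    def lr_lambda(step: int) -> float:
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)  # linear warmup from 0 to peak
        progress = min(1.0, (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # decays from 1 to 0
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine  # peak multiplier -> floor

    return LambdaLR(optimizer, lr_lambda)
```

The multipliers could then be exposed through the Hydra `trainer` section the same way the existing `TrainingArguments` fields are.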
recipes/esm2_accelerate/test_train.py (1)
142-145: Nit: the variable name in the assertion message is about accelerate, not deepspeed. Covered by a later diff; ensure both single- and multi-GPU messages are consistent.
🧹 Nitpick comments (29)
recipes/esm2_accelerate/.dockerignore (1)
1-9: Tighten Docker context excludes; add common heavy dirs and patterns. Recommend ignoring VCS, venvs, eggs, caches, and Python artifacts to shrink the build context and speed up builds.

```diff
 Dockerfile
 README.md
 checkpoint_export/
 outputs/
 .ruff_cache
 __pycache__
 .pytest_cache
 .ruff.toml
 .dockerignore
+.git
+.gitignore
+.venv/
+**/__pycache__/
+*.py[cod]
+*.egg-info/
+.mypy_cache/
+.cache/
+dist/
+build/
+data/
```

recipes/esm2_accelerate/accelerate_config/default.yaml (1)
3-11: Check `distributed_type: MULTI_GPU` with `num_processes: 1`. If you expect one process per GPU, this should be overridden by the launcher or set accordingly; otherwise you'll run single-process on multi-GPU.

How are tests invoking accelerate (CLI flags vs config override)? If config-driven, consider documenting that `num_processes` must match the GPU count for multi-GPU runs; see the sketch after this comment.
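As a point of reference only, a hypothetical multi-GPU override could look like the following (the file name `default_multi8.yaml` and the field values are assumptions to adapt, not a tested config):

```yaml
# Hypothetical default_multi8.yaml: one process per GPU on a single 8-GPU node.
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
num_machines: 1
num_processes: 8        # must match the per-node GPU count
machine_rank: 0
mixed_precision: bf16
downcast_bf16: 'no'
dynamo_config:
  dynamo_backend: "NO"
```

Alternatively, keep `num_processes: 1` in the checked-in config and always pass the count at launch time (e.g. `accelerate launch --num_processes 8 ...`).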
recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml (1)

1-23: Add FSDP2-friendly flags
- Under `fsdp_config`, add (verify FSDP v2 supports these keys):

  ```yaml
  fsdp_sync_module_states: true
  fsdp_limit_all_gathers: true
  ```

- Ensure `num_processes: 1` is intended only when it is always overridden via CLI

recipes/amplify_accelerate_te_fp8/callbacks.py (3)
31-34: Return the updated control object from the callback. Transformers callbacks typically return `control` to propagate changes. Make the return explicit.

```diff
 def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
     """Interrupt training after a specified number of steps."""
     if state.global_step >= self.max_steps:
         control.should_training_stop = True
+    return control
```
27-29: Validate `max_steps` early to avoid accidental no-ops. Guard against `max_steps <= 0`.

```diff
 def __init__(self, max_steps: int):
     """Initialize the callback."""
-    self.max_steps = max_steps
+    if max_steps <= 0:
+        raise ValueError(f"max_steps must be > 0, got {max_steps}")
+    self.max_steps = max_steps
```
20-34: Reduce duplication across recipes. This class duplicates `recipes/esm2_accelerate/callbacks.py`. Consider a shared utility (e.g., `recipes/common/callbacks.py`) to keep a single source.

recipes/esm2_accelerate/metrics.py (1)
26-33: Handle nested NumPy arrays as well.
`nested_cpu` skips NumPy arrays. If you keep this helper, extend it to convert NumPy arrays to tensors.

```diff
-    return tensors.detach().cpu() if isinstance(tensors, torch.Tensor) else tensors
+    if isinstance(tensors, torch.Tensor):
+        return tensors.detach().cpu()
+    try:
+        import numpy as np  # local import to avoid hard dep if unused
+        if isinstance(tensors, np.ndarray):
+            return torch.from_numpy(tensors).detach().cpu()
+    except Exception:
+        pass
+    return tensors
```

recipes/esm2_accelerate/Dockerfile (1)
8-8: Optional: minimize context churn. If the image only needs training scripts, consider copying just the recipe subtree to speed up builds and reduce cache invalidation.
recipes/esm2_accelerate/hydra_config/README.md (2)
3-4: Fix broken sentence wrap. Merge the split phrase.

```diff
-We could expand this to include configs for full-scale convergence testing, partial conv
-experiments, etc.
+We could expand this to include configs for full-scale convergence testing, partial convergence experiments, etc.
```

10-12: Grammar and clarity nits. Tighten spacing and casing.

```diff
-- Specifying `bf16` in the `TrainingArguments` class will override fp8 settings given to accelerate.
-  This causes issues with the `deepspeed` backend, since HF will check to make sure the bf16
+- Specifying `bf16` in the `TrainingArguments` class will override FP8 settings given to Accelerate.
+  This can cause issues with the `deepspeed` backend, since HF will check to make sure the bf16
   settings are the same between HF and Deepspeed settings.
```

recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1)
6-14: Make sanity runs deterministic; ensure eval triggers.
- Add a fixed `seed` for reproducibility.
- If not set in `defaults.yaml`, specify `evaluation_strategy: "steps"` so `eval_steps` is honored.

```diff
 trainer:
   run_name: "esm2_t6_8M_UR50D_sanity"
+  seed: 42
   per_device_train_batch_size: 2
   per_device_eval_batch_size: 2
   save_steps: 2
   eval_steps: 2
+  evaluation_strategy: "steps"
   logging_steps: 1
   report_to: "none"
   dataloader_num_workers: 0
```

recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1)
3-5: Sanity check: single-process FSDP. `distributed_type: FSDP` with `num_processes: 1` is fine for dev sanity, but won't shard across GPUs. If the intent is single-GPU testing, consider `distributed_type: NO` here and keep FSDP in a separate multi-GPU config.
recipes/esm2_accelerate/slurm.sh (2)
3-3: Directive/comment mismatch. `--ntasks-per-node=1` conflicts with the “one task per gpu” comment. Either:
- keep 1 task per node (accelerate spawns per-GPU workers), and fix the comment; or
- set --ntasks-per-node=8 and add --gpus-per-task=1.
```diff
-#SBATCH --ntasks-per-node=1 # n tasks per machine (one task per gpu) <required>
+#SBATCH --ntasks-per-node=1 # one task per node; accelerate spawns per-GPU workers
```
12-21: Consistent deferred expansion for SLURM vars inside $CMD. You escaped $SLURM_NODEID and $SLURM_SRUN_COMM_HOST but not $SLURM_NNODES. For consistency and to avoid host-side expansion, escape it too.

```diff
-    --num_machines "$SLURM_NNODES" \
+    --num_machines "\$SLURM_NNODES" \
```

recipes/amplify_accelerate_te_fp8/test_train.py (1)
213-249: Add an existence check for the accelerate config used in the multi-GPU test. Prevents opaque failures if bf16_config.yaml is missing.

```diff
 recipe_dir = Path(__file__).parent
 train_py = recipe_dir / "train.py"
-cmd = [
+accelerate_config_path = recipe_dir / "accelerate_config" / "bf16_config.yaml"
+assert accelerate_config_path.exists(), f"bf16_config.yaml not found at {accelerate_config_path}"
+
+cmd = [
     sys.executable,
     "-m",
     "accelerate.commands.launch",
     "--config_file",
-    str(recipe_dir / "accelerate_config" / "bf16_config.yaml"),
+    str(accelerate_config_path),
```

recipes/esm2_accelerate/train.py (1)
41-41: Optional: load from_pretrained if weights are desired. If you intend to compare TE vs HF with pretrained ESM-2 weights, consider `from_pretrained(args.model_tag)` rather than `from_config`.

```diff
-model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
+model = AutoModelForMaskedLM.from_pretrained(args.model_tag, config=config, trust_remote_code=True, torch_dtype=torch.bfloat16)
```
31-34: Returncontrolfrom the callback for consistency with HFCallbackHandler.Mutating
controlin-place usually works, but returning it is the recommended pattern and avoids surprises with multiple callbacks.Apply:
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): """Interrupt training after a specified number of steps.""" if state.global_step >= self.max_steps: control.should_training_stop = True + return control
27-29: Validatemax_steps> 0 to avoid immediate stop on step 0.Defensive check improves ergonomics and error messages.
Apply:
def __init__(self, max_steps: int): """Initialize the callback.""" - self.max_steps = max_steps + if not isinstance(max_steps, int) or max_steps <= 0: + raise ValueError(f"max_steps must be a positive int, got {max_steps!r}") + self.max_steps = int(max_steps)recipes/esm2_accelerate/dataset.py (2)
70-74: Use a sanepad_to_multiple_of(e.g., 8) instead ofmax_length.Padding to
max_lengthis redundant withpadding="max_length"and can be misleading.Apply:
- data_collator = DataCollatorForLanguageModeling( + data_collator = DataCollatorForLanguageModeling( tokenizer=tokenizer, mlm_probability=0.15, - pad_to_multiple_of=max_length, + pad_to_multiple_of=8, )
28-36:infinite_dataloaderassumessampler.set_epochexists. Guard for non-distributed samplers or drop this helper if unused.SequentialSampler does not implement
set_epoch.Apply:
def infinite_dataloader(dataloader, sampler): """Create an infinite iterator that automatically restarts at the end of each epoch.""" epoch = 0 while True: - sampler.set_epoch(epoch) # Update epoch for proper shuffling + if hasattr(sampler, "set_epoch"): + sampler.set_epoch(epoch) # Update epoch for proper shuffling for batch in dataloader: yield batch epoch += 1 # Increment epoch counter after completing one full passrecipes/esm2_accelerate/test_train.py (9)
24-26: Only set TRITON_LIBCUDA_PATH if unset; avoid overriding env in CI.Apply:
-# Set TRITON_LIBCUDA_PATH before check_fp8_support import that triggers Triton initialization -os.environ["TRITON_LIBCUDA_PATH"] = "/usr/local/cuda/lib64" +# Set TRITON_LIBCUDA_PATH before any TE import that may trigger Triton initialization +os.environ.setdefault("TRITON_LIBCUDA_PATH", "/usr/local/cuda/lib64")
27-36: Remove unused FP8 support check and import to reduce side effects.These values are unused and can fail on CPU-only CI.
Apply:
-import pytest -import torch -from hydra import compose, initialize_config_dir -from transformer_engine.pytorch.fp8 import check_fp8_support +import pytest +import torch +from hydra import compose, initialize_config_dir @@ -_fp8_available, _fp8_reason = check_fp8_support() +# (Intentionally not importing FP8/TE here; tests don't require it.)
67-71: Make checkpoint assertions robust to config changes.Asserting exactly 2 checkpoints is brittle. Check presence and minimum expected count.
Apply:
- assert len(checkpoint_dirs) == 2, ( - f"Expected 2 checkpoint directories, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}" - ) + assert len(checkpoint_dirs) >= 1, ( + f"Expected at least 1 checkpoint directory, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}" + )Optional: also assert specific ones if your save_steps are stable:
expected = {tmp_path / "checkpoint-2", tmp_path / "checkpoint-4"} assert expected & set(checkpoint_dirs), f"Missing expected checkpoints: {expected}"
90-96: Fix misleading comment (mentions checkpoint-10).Apply:
- # Remove the checkpoint-10 and checkpoint-last directories + # Remove the checkpoint-4 and checkpoint-last directories
146-148: Correct error message to reference accelerate config.Apply:
- assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}" + assert accelerate_config_path.exists(), f"accelerate config not found at {accelerate_config_path}"
171-172: Increase timeout for slow CI runners.Four minutes can be tight; 10 minutes is safer for cold environments.
Apply:
- timeout=240, + timeout=600,
202-204: Correct error message to reference accelerate config (multi-GPU).Apply:
- assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}" + assert accelerate_config_path.exists(), f"accelerate config not found at {accelerate_config_path}"
227-228: Increase timeout for multi-GPU run.Apply:
- timeout=240, + timeout=600,
151-163: Hydra config naming: consider dropping.yamlextension for--config-name.Hydra typically expects the name without extension; leaving as-is works if your decorator uses the same string, but normalizing reduces surprises.
Apply (optional):
- "--config-name", - "L0_sanity.yaml", + "--config-name", + "L0_sanity",
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
⛔ Files ignored due to path filters (1)
recipes/esm2_accelerate/train.parquetis excluded by!**/*.parquet
📒 Files selected for processing (24)
- models/esm2/src/esm/modeling_esm_te.py (2 hunks)
- recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1 hunks)
- recipes/amplify_accelerate_te_fp8/callbacks.py (1 hunks)
- recipes/amplify_accelerate_te_fp8/test_train.py (2 hunks)
- recipes/amplify_accelerate_te_fp8/train.py (1 hunks)
- recipes/esm2_accelerate/.dockerignore (1 hunks)
- recipes/esm2_accelerate/.ruff.toml (1 hunks)
- recipes/esm2_accelerate/Dockerfile (1 hunks)
- recipes/esm2_accelerate/README.md (1 hunks)
- recipes/esm2_accelerate/accelerate_config/default.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml (1 hunks)
- recipes/esm2_accelerate/callbacks.py (1 hunks)
- recipes/esm2_accelerate/dataset.py (1 hunks)
- recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1 hunks)
- recipes/esm2_accelerate/hydra_config/README.md (1 hunks)
- recipes/esm2_accelerate/hydra_config/defaults.yaml (1 hunks)
- recipes/esm2_accelerate/metrics.py (1 hunks)
- recipes/esm2_accelerate/requirements.txt (1 hunks)
- recipes/esm2_accelerate/slurm.sh (1 hunks)
- recipes/esm2_accelerate/test_train.py (1 hunks)
- recipes/esm2_accelerate/train.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
recipes/esm2_accelerate/dataset.py (1)
sub-packages/bionemo-size-aware-batching/tests/bionemo/size_aware_batching/conftest.py (1)
sampler(92-93)
recipes/esm2_accelerate/callbacks.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (2)
StopAfterNStepsCallback(20-34)on_step_end(31-34)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
StopAfterNStepsCallback(20-34)recipes/esm2_accelerate/dataset.py (1)
create_datasets_and_collator(38-76)recipes/esm2_accelerate/metrics.py (1)
compute_metrics(36-54)
recipes/amplify_accelerate_te_fp8/train.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
StopAfterNStepsCallback(20-34)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
recipes/esm2_accelerate/callbacks.py (2)
StopAfterNStepsCallback(20-34)on_step_end(31-34)
recipes/amplify_accelerate_te_fp8/test_train.py (1)
recipes/esm2_accelerate/test_train.py (1)
test_accelerate_launch_multi_gpu(194-233)
recipes/esm2_accelerate/test_train.py (2)
sub-packages/bionemo-testing/src/bionemo/testing/torch.py (1)
check_fp8_support(21-33)recipes/esm2_accelerate/train.py (1)
main(36-72)
🪛 LanguageTool
recipes/esm2_accelerate/hydra_config/README.md
[grammar] ~10-~10: There might be a mistake here.
Context: ...erride fp8 settings given to accelerate. This causes issues with the `deepspeed...
(QB_NEW_EN)
🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh
[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.
(SC1073)
[error] 24-24: Fix any mentioned problems and try again.
(SC1072)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (9)
recipes/esm2_accelerate/.ruff.toml (1)
1-1: Leave extend path unchanged.../.ruff.tomlcorrectly points torecipes/.ruff.toml(the common config for all recipes); there is no top-level.ruff.tomlto reference.Likely an incorrect or invalid review comment.
recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml (1)
1-22: Confirm Accelerate FSDP settings compatibility and optionally enable sync/init knobs
- Ensure your installed `accelerate` version supports `fsdp_version: 1` and the `fsdp_config` keys in `fsdp1_te.yaml` (e.g. by running a quick local test).
- To improve initialization determinism and communication behavior, consider adding:

  ```yaml
  fsdp_sync_module_states: true
  fsdp_limit_all_gathers: true
  ```

- Note that `num_processes: 1` is hardcoded here; override via `--num_processes` (or other CLI args) for multi-GPU runs.

recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml (1)
1-23: Add optional FSDP2 sync/all-gathers flags and verify Accelerate support
- Under `fsdp_config`, add:

  ```yaml
  fsdp_sync_module_states: true
  fsdp_limit_all_gathers: true
  ```

- Confirm your installed `accelerate` release recognizes `fsdp_version: 2`, e.g. run `python -c "import accelerate; print(accelerate.__version__)"` and check it meets the minimum version that added FSDP2 support.

recipes/esm2_accelerate/metrics.py (2)
35-55: Fixcompute_metricscontract and tensor handling.
- The signature must match HF Trainer: it should accept only `EvalPrediction` and return a dict.
- Current code may receive NumPy arrays; `torchmetrics.Perplexity` expects tensors.
- Returning `None` when `compute_result=False` will break logging.

```diff
-@torch.no_grad()
-def compute_metrics(eval_pred: transformers.EvalPrediction, compute_result: bool):
+@torch.no_grad()
+def compute_metrics(eval_pred: transformers.EvalPrediction):
@@
-    logits, labels = eval_pred
-    logits = nested_cpu(logits)
-    labels = nested_cpu(labels)
-    perplexity(logits, labels)
-
-    if compute_result:
-        loss = perplexity.compute()
-        perplexity.reset()
-        return {"perplexity": loss}
+    logits, labels = eval_pred
+    # Convert possible NumPy inputs to torch tensors on CPU
+    if not isinstance(logits, torch.Tensor):
+        logits = torch.from_numpy(logits)
+    if not isinstance(labels, torch.Tensor):
+        labels = torch.from_numpy(labels)
+    logits = logits.detach().cpu()
+    labels = labels.detach().cpu()
+
+    perplexity(logits, labels)
+    value = perplexity.compute()
+    perplexity.reset()
+    return {"perplexity": float(value) if isinstance(value, torch.Tensor) else value}
```
23-24: Confirm Perplexity input expectations (logits vs probabilities). Ensure `Perplexity` is configured for raw logits; otherwise apply `log_softmax` or adjust the metric's `normalize`/`log_prob` settings. Would you like me to check the exact TorchMetrics version semantics and align the code accordingly?
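If it helps, here is a small, hedged way to verify this locally; it only assumes the `torchmetrics.text.Perplexity` class already used in this module, and compares the metric on raw logits against `exp(cross_entropy)`:

```python
# Illustrative check: does Perplexity(ignore_index=-100) on raw logits agree with exp(cross-entropy)?
import torch
import torch.nn.functional as F
from torchmetrics.text import Perplexity

logits = torch.randn(2, 16, 33)            # (batch, seq_len, vocab) raw logits
labels = torch.randint(0, 33, (2, 16))
labels[:, :4] = -100                       # ignored positions, as produced by the MLM collator

ppl_metric = Perplexity(ignore_index=-100)(logits, labels)
ce = F.cross_entropy(logits.view(-1, 33), labels.view(-1), ignore_index=-100)
print(float(ppl_metric), float(torch.exp(ce)))  # if these differ, apply log_softmax before the metric
```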
recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1)
6-10: Verify correct FSDP transformer layer class
‘EsmLayer’ isn’t found in the repo and likely isn’t the HF ESM layer name. Confirm the actual HuggingFace ESM transformer layer class (for example, `EsmEncoderLayer`) and update `fsdp_transformer_layer_cls_to_wrap` accordingly to ensure TRANSFORMER_BASED_WRAP functions as intended.

recipes/amplify_accelerate_te_fp8/test_train.py (2)
37-41: Good multi-GPU gating.Skip mark correctly guards for >=2 GPUs.
183-197: Command construction and added trainer flags look good.List-form subprocess with explicit trainer overrides is solid.
recipes/amplify_accelerate_te_fp8/train.py (1)
47-51: Ignore typo warning. The `create_datasets_and_collator` signature in `recipes/amplify_accelerate_te_fp8/dataset.py` (line 29) defines `pretained_model`, matching its invocation in `train.py` (line 48), so this won't crash. Rename both occurrences only if you want to correct the spelling. Likely an incorrect or invalid review comment.
Actionable comments posted: 3
♻️ Duplicate comments (2)
recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1)
17-18: Align `dynamo_config` usage across configs and verify casing. Other configs in this PR reportedly use a scalar (e.g., `dynamo_config: "NO"`). Standardize on the mapping form repo-wide; also verify whether your Accelerate version prefers `no` vs `NO` for the backend and unify casing.

Run:

```bash
#!/bin/bash
# Find scalar (incorrect) usages and mapping (correct) usages
rg -nP -C2 '^\s*dynamo_config:\s*"(?:NO|INDUCTOR|no|inductor)"\s*$' --glob '!**/site-packages/**'
rg -nP -C2 '^\s*dynamo_config:\s*$\n^\s{2,}dynamo_backend:\s*(?:NO|INDUCTOR|no|inductor)\b' -U --glob '!**/site-packages/**'
```

recipes/esm2_accelerate/train.py (1)
49-57: Trainer will raise TypeError: compute_metrics expects 2 args, Trainer passes 1. Wrap it. This is not just a typing nuisance; with the current signature it will fail at runtime. Wrap with a lambda (or functools.partial).

Apply this diff:

```diff
-    compute_metrics=compute_metrics,
+    compute_metrics=lambda pred: compute_metrics(pred, compute_result=True),
```
🧹 Nitpick comments (2)
models/esm2/src/esm/modeling_esm_te.py (1)
164-176: Optional: cache per-device ROPE to avoid repeated H2D copies. The current code copies the CPU-pinned ROPE to the device on every forward. For long sequences this can be a measurable overhead. Cache the full-length embedding per (device, dtype) and slice each step.

Apply within forward:

```diff
-        if self.te_rope_emb is not None:
-            te_rope_emb = self.te_rope_emb.to(device=hidden_states.device,
-                                              dtype=hidden_states.dtype,
-                                              non_blocking=True)
+        if self.te_rope_emb is not None:
+            cache_key = (hidden_states.device, hidden_states.dtype)
+            if not hasattr(self, "_te_rope_cache"):
+                self._te_rope_cache = {}
+            te_rope_emb = self._te_rope_cache.get(cache_key)
+            if te_rope_emb is None:
+                te_rope_emb = self.te_rope_emb.to(device=hidden_states.device,
+                                                  dtype=hidden_states.dtype,
+                                                  non_blocking=True)
+                self._te_rope_cache[cache_key] = te_rope_emb
             seq_len = hidden_states.shape[1]
             if te_rope_emb.size(0) < seq_len:
                 raise RuntimeError(
                     f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
                     f"Increase max_position_embeddings."
                 )
             te_rope_emb = te_rope_emb[:seq_len]
```

Note: `_te_rope_cache` is a plain attribute (not a buffer), so it won't bloat checkpoints.
recipes/esm2_accelerate/train.py (1)
39-41: Align `config.torch_dtype` with the model dtype to keep TE params consistent. TE layers read `params_dtype` from `config.torch_dtype`. Set it explicitly to match model construction.

Apply this diff:

```diff
-model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
+config.torch_dtype = torch.bfloat16
+model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
- models/esm2/src/esm/modeling_esm_te.py (2 hunks)
- recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1 hunks)
- recipes/esm2_accelerate/train.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T15:58:24.658Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1080
File: recipes/esm2_accelerate/train.py:47-55
Timestamp: 2025-08-28T15:58:24.658Z
Learning: HuggingFace Transformers has a known typing issue where Trainer expects compute_metrics to have signature Callable[[EvalPrediction], Dict], but functions with additional parameters cause type checker warnings. This is a long-standing typing limitation in the framework, not a runtime error. Workarounds include wrapping functions with functools.partial, lambdas, or type casting.
Applied to files:
recipes/esm2_accelerate/train.py
🧬 Code graph analysis (1)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
StopAfterNStepsCallback(20-34)recipes/esm2_accelerate/dataset.py (1)
create_datasets_and_collator(38-76)recipes/esm2_accelerate/metrics.py (1)
compute_metrics(36-54)
🪛 GitHub Actions: BioNeMo Framework CI
models/esm2/src/esm/modeling_esm_te.py
[error] 142-149: Ruff format error: Invalid syntax due to indentation in modeling_esm_te.py (lines 142-149).
recipes/esm2_accelerate/train.py
[error] 44-45: Ruff format error: Invalid syntax/indentation around patch hunks in train.py (lines 44-45).
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (5)
recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (2)
17-18: Correct Accelerate schema fordynamo_config. LGTM.Using a mapping with
dynamo_backend: "NO"is valid and fixes the earlier schema issue.
3-3:distributed_type: MULTI_GPUwithnum_processes: 1—confirm intent.Accelerate typically runs one process per GPU for MULTI_GPU. If you intend single-GPU, set
distributed_type: "SINGLE_GPU"; if multi-GPU, setnum_processesto the GPU count (or override via CLI/launcher).Also applies to: 10-10
models/esm2/src/esm/modeling_esm_te.py (1)
164-176: Good: dtype/device match + explicit length guard on ROPE.Casting to hidden_states.dtype, non_blocking move, and the clear error on insufficient ROPE length are all correct and prevent mixed-precision/shape bugs.
recipes/esm2_accelerate/train.py (2)
35-35: Hydra config_name fixed (no .yaml).Good correction; this avoids config resolution issues.
59-59: LGTM: accelerator state logging.Safe and useful for debugging distributed setups.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
recipes/amplify_accelerate_te_fp8/train.py (1)
47-51: Rename `pretained_model` to `pretrained_model` across dataset and train files.

- In recipes/amplify_accelerate_te_fp8/dataset.py, update the function signature (line 29), docstring (line 36), and internal use (line 46) from `pretained_model` to `pretrained_model`.
- In recipes/amplify_accelerate_te_fp8/train.py (line 48), change the keyword argument to `pretrained_model=args.model_tag`.
♻️ Duplicate comments (5)
recipes/amplify_accelerate_te_fp8/train.py (1)
27-29: Use package-relative imports for local modules (callbacks, dataset, metrics). Prevents accidental shadowing on varied PYTHONPATH and ensures intra-package resolution.

```diff
-from callbacks import StopAfterNStepsCallback
-from dataset import create_datasets_and_collator
-from metrics import compute_metrics
+from .callbacks import StopAfterNStepsCallback
+from .dataset import create_datasets_and_collator
+from .metrics import compute_metrics
```

Ensure this directory is a package (has an `__init__.py`).
recipes/esm2_accelerate/README.md (1)
3-9: README example looks consistent with repo (default.yaml).The earlier nonexistent config references are gone; this is good.
recipes/esm2_accelerate/slurm.sh (2)
24-25: Invalid placeholder breaks the shell; provide a safe default and create the directory. Use parameter expansion and mkdir -p.

Apply:

```diff
-# Mount a persistent cache directory to cache dataset downloads and transformations.
-export CACHE_DIR=<cache_dir>
+# Mount a persistent cache directory to cache dataset downloads and transformations.
+export CACHE_DIR="${CACHE_DIR:-$PWD/.cache/huggingface}"
+mkdir -p "$CACHE_DIR"
```
10-15: Pre-flight check: ensure the accelerate config exists before srun. Fail fast with a clear message if missing.

Apply:

```diff
 ulimit -c 0
+if [ ! -f accelerate_config/fsdp2_te.yaml ]; then
+  echo "Missing accelerate_config/fsdp2_te.yaml"
+  exit 1
+fi
+
 export GPUS_PER_NODE=8
```

recipes/esm2_accelerate/train.py (1)
27-30: Make imports Hydra-proof (cwd changes).Use package-qualified imports with fallback.
Apply:
-from callbacks import StopAfterNStepsCallback -from dataset import create_datasets_and_collator -from metrics import compute_metrics +try: + from recipes.esm2_accelerate.callbacks import StopAfterNStepsCallback + from recipes.esm2_accelerate.dataset import create_datasets_and_collator + from recipes.esm2_accelerate.metrics import compute_metrics +except ImportError: + from callbacks import StopAfterNStepsCallback + from dataset import create_datasets_and_collator + from metrics import compute_metrics
🧹 Nitpick comments (8)
recipes/amplify_accelerate_te_fp8/train.py (1)
55-63: Guard callback construction when stop_after_n_steps is unset/0.Avoids passing None/0 to the callback and keeps Trainer config cleaner.
- trainer = Trainer( + callbacks = [] + if getattr(args, "stop_after_n_steps", None): + callbacks.append(StopAfterNStepsCallback(args.stop_after_n_steps)) + + trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, compute_metrics=compute_metrics, data_collator=data_collator, - callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)], + callbacks=callbacks, )recipes/esm2_accelerate/dataset.py (3)
66-68: Be explicit: set_transform should not return original columns.Explicitly set output_all_columns=False to avoid accidental passthrough if defaults change.
Apply:
- for dataset in [train_dataset, eval_dataset]: - dataset.set_transform(tokenize_function) + for dataset in (train_dataset, eval_dataset): + dataset.set_transform(tokenize_function, output_all_columns=False)
69-73: Redundant collator padding arg (already padded to max_length).Since tokenization uses padding="max_length", pad_to_multiple_of=max_length is unnecessary.
Apply:
- data_collator = DataCollatorForLanguageModeling( - tokenizer=tokenizer, - mlm_probability=0.15, - pad_to_multiple_of=max_length, - ) + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm_probability=0.15, + )
50-52: Ensuretrain.parquetis present at runtime or document its requirement
Add a runtime existence check (e.g.if not data_path.exists(): raise FileNotFoundError(...)) or clearly document thatrecipes/esm2_accelerate/train.parquetmust be packaged alongside the code.recipes/esm2_accelerate/slurm.sh (3)
8-9: Harden shell options.Add pipefail for safer error propagation.
Apply:
-set -x -e +set -euo pipefail +set -x
11-21: Allow configurable main port; avoid hardcoding.Support MAIN_PORT override.
Apply:
export GPUS_PER_NODE=8 +export MAIN_PORT="${MAIN_PORT:-12340}" export CMD="TRITON_CACHE_DIR=/tmp/triton_cache \ accelerate launch \ --config_file accelerate_config/fsdp2_te.yaml \ --machine_rank \"\$SLURM_NODEID\" \ --num_machines \"$SLURM_NNODES\" \ --main_process_ip \"\$SLURM_SRUN_COMM_HOST\" \ - --main_process_port 12340 \ + --main_process_port \"$MAIN_PORT\" \ --num_processes \"$(( $SLURM_NNODES * $GPUS_PER_NODE ))\" \ train.py "
26-29: .netrc mount may fail if file absent; create a zero-byte file with safe perms.Avoid srun mount errors.
Apply:
-srun \ +touch "$HOME/.netrc" && chmod 600 "$HOME/.netrc" +srun \recipes/esm2_accelerate/train.py (1)
38-45: Optional: Pin the model revision when trusting remote code. Passing `trust_remote_code=True` without a revision reduces reproducibility.
Would you like me to add a model_revision field to hydra_config and thread it through AutoConfig.from_pretrained(..., revision=args.model_revision)?
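A hedged sketch of what that could look like (the `model_revision` field is an assumption, not an existing config key):

```python
from transformers import AutoConfig

# Assumes a new `model_revision` entry in hydra_config/defaults.yaml, e.g.:
#   model_tag: <existing model tag>
#   model_revision: <pinned commit sha or tag>
config = AutoConfig.from_pretrained(
    args.model_tag,
    revision=args.model_revision,  # pin the remote code/weights for reproducibility
    trust_remote_code=True,
)
```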
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
- recipes/amplify_accelerate_te_fp8/train.py (2 hunks)
- recipes/esm2_accelerate/README.md (1 hunks)
- recipes/esm2_accelerate/accelerate_config/default.yaml (1 hunks)
- recipes/esm2_accelerate/dataset.py (1 hunks)
- recipes/esm2_accelerate/slurm.sh (1 hunks)
- recipes/esm2_accelerate/train.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- recipes/esm2_accelerate/accelerate_config/default.yaml
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T15:58:24.658Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1080
File: recipes/esm2_accelerate/train.py:47-55
Timestamp: 2025-08-28T15:58:24.658Z
Learning: HuggingFace Transformers has a known typing issue where Trainer expects compute_metrics to have signature Callable[[EvalPrediction], Dict], but functions with additional parameters cause type checker warnings. This is a long-standing typing limitation in the framework, not a runtime error. Workarounds include wrapping functions with functools.partial, lambdas, or type casting.
Applied to files:
recipes/esm2_accelerate/train.py
🧬 Code graph analysis (3)
recipes/esm2_accelerate/dataset.py (1)
sub-packages/bionemo-size-aware-batching/tests/bionemo/size_aware_batching/conftest.py (1)
sampler(92-93)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
StopAfterNStepsCallback(20-34)recipes/esm2_accelerate/dataset.py (1)
create_datasets_and_collator(36-75)recipes/esm2_accelerate/metrics.py (1)
compute_metrics(36-54)
recipes/amplify_accelerate_te_fp8/train.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
StopAfterNStepsCallback(20-34)
🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh
[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.
(SC1073)
[error] 24-24: Fix any mentioned problems and try again.
(SC1072)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (1)
recipes/amplify_accelerate_te_fp8/train.py (1)
82-83: Good: safely destroy the process group only when initialized.Prevents errors on non-distributed runs and aligns with best practices.
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Force-pushed from 738690a to 31e9681.
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
recipes/amplify_accelerate_te_fp8/dataset.py (1)
35-43: Docstring mismatch fordata_sizeparameter
Update the docstring to replace"small"with"parquet"and to note that"sanity"selects 10 examples, or adjust the code if you intend 100:- data_size: The size of the dataset to load. If "full", use and pre-process the full UR100P - CSV dataset. This takes a long time without a cached dataset. If "small", use and - pre-process the parquet version of the dataset, which is much faster than "full". If - "sanity", truncates the evaluation dataset to 100 examples. + data_size: The size of the dataset to load. If "full", use and pre-process the full UR100P + CSV dataset. This takes a long time without a cached dataset. If "parquet", use and + pre-process the parquet version of the dataset, which is much faster than "full". If + "sanity", truncates the evaluation dataset to 10 examples.Alternatively, if you want
"sanity"to yield 100 examples, change.select(range(10))to.select(range(100)).recipes/esm2_accelerate/metrics.py (1)
35-55: Fix Trainer incompatibility: wrong signature and numpy inputs cause runtime errorHugging Face’s Trainer calls compute_metrics(eval_pred) with numpy arrays. Your current signature requires a second arg and passes numpy arrays to a TorchMetrics module — this will fail. Make it single‑arg, convert numpy→torch, and compute perplexity directly.
Apply:
-@torch.no_grad() -def compute_metrics(eval_pred: transformers.EvalPrediction, compute_result: bool): - """Compute perplexity metrics for the evaluation set. - - Args: - eval_pred: A tuple containing the logits and labels for the evaluation set. - compute_result: A boolean indicating whether to compute the perplexity metrics. - - Returns: - A dictionary containing the perplexity metrics. - """ - logits, labels = eval_pred - logits = nested_cpu(logits) - labels = nested_cpu(labels) - perplexity(logits, labels) - - if compute_result: - loss = perplexity.compute() - perplexity.reset() - return {"perplexity": loss} +@torch.no_grad() +def compute_metrics(eval_pred: transformers.EvalPrediction): + """Return {"perplexity": float} computed from logits/labels provided by Trainer.""" + # EvalPrediction exposes .predictions and .label_ids; Trainer provides numpy arrays + logits = getattr(eval_pred, "predictions", eval_pred[0]) + labels = getattr(eval_pred, "label_ids", eval_pred[1]) + # numpy -> torch on CPU + if not isinstance(logits, torch.Tensor): + logits = torch.from_numpy(logits) + if not isinstance(labels, torch.Tensor): + labels = torch.from_numpy(labels) + # Flatten to (N*L, V) vs (N*L,) and ignore pad label -100 + vocab = logits.shape[-1] + loss = torch.nn.functional.cross_entropy( + logits.view(-1, vocab), labels.view(-1), ignore_index=-100 + ) + ppl = torch.exp(loss).item() + return {"perplexity": ppl}
♻️ Duplicate comments (6)
recipes/esm2_accelerate/accelerate_config/default.yaml (1)
17-18: Resolved: dynamo_config is now a mapping; backend value is valid.Switching to:
dynamo_config: dynamo_backend: "NO"matches accelerate’s schema and accepted enum values. ✔️ (huggingface.co)
recipes/amplify_accelerate_te_fp8/train.py (1)
27-27: Use a package-relative import for callbacks (duplicate).Avoid shadowing top-level modules; import from the recipe package.
Apply:
-from callbacks import StopAfterNStepsCallback +from .callbacks import StopAfterNStepsCallbackrecipes/esm2_accelerate/train.py (3)
27-30: Make imports Hydra-proof (CWD changes)Use package-qualified imports with a fallback so imports survive Hydra’s chdir.
Apply:
-from callbacks import StopAfterNStepsCallback -from dataset import create_datasets_and_collator -from metrics import compute_metrics +try: + from recipes.esm2_accelerate.callbacks import StopAfterNStepsCallback + from recipes.esm2_accelerate.dataset import create_datasets_and_collator + from recipes.esm2_accelerate.metrics import compute_metrics +except ImportError: + from callbacks import StopAfterNStepsCallback # type: ignore + from dataset import create_datasets_and_collator # type: ignore + from metrics import compute_metrics # type: ignore
41-41: torch_dtype is ignored by from_configRemove torch_dtype here and cast after creating TrainingArguments (or rely on training_args.bf16).
Apply:
-model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16) +model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)Then after TrainingArguments:
training_args = TrainingArguments(**args.trainer) + +if getattr(training_args, "bf16", False): + model.to(torch.bfloat16)
50-58: Trainer will crash: compute_metrics signature mismatchYou pass compute_metrics as-is, but the function currently expects two args. Bind the extra param or (preferred) adopt the single-arg version suggested in metrics.py.
If you keep the two-arg function, bind it:
- trainer = Trainer( + compute_metrics_fn = functools.partial(compute_metrics) # bind any extras if present + trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=eval_dataset, - compute_metrics=compute_metrics, + compute_metrics=compute_metrics_fn, data_collator=data_collator, callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)], )If you apply the metrics.py refactor (single-arg), simply keep compute_metrics=compute_metrics.
recipes/esm2_accelerate/slurm.sh (1)
24-24: Invalid placeholder breaks shellexport CACHE_DIR=<cache_dir> is invalid and fails ShellCheck. Provide a sane default and ensure it exists.
Apply:
-export CACHE_DIR=<cache_dir> +export CACHE_DIR="${CACHE_DIR:-$PWD/.cache/huggingface}" +mkdir -p "$CACHE_DIR"
🧹 Nitpick comments (14)
recipes/esm2_accelerate/hydra_config/README.md (3)
1-1: Clarify the title for context.Make it explicit that these configs are for the esm2_accelerate recipe.
-# Hydra configs +# ESM-2 Accelerate Hydra configs
6-6: Grammar/clarity.Tighten wording; “maps” reads better than “gets mapped.”
-The `trainer` section gets mapped directly to the `TrainingArguments` class. +The `trainer` section maps directly to the `TrainingArguments` class.
12-12: Add “Tested versions” to de-risk config drift.Document the exact package versions this note applies to (transformers, accelerate, deepspeed) so future upgrades don’t invalidate the guidance.
Proposed addition at end of file:
## Tested versions This behavior was validated with: - transformers: X.Y.Z - accelerate: A.B.C - deepspeed: D.E.F Update this section when bumping versions.recipes/esm2_accelerate/accelerate_config/default.yaml (1)
5-7: Make GPU selection explicit for reproducibility.Add
gpu_ids: allso launches deterministically use all local GPUs unless overridden (common in accelerate-generated configs). (github.com, cnblogs.com)enable_cpu_affinity: false +gpu_ids: all machine_rank: 0recipes/amplify_accelerate_te_fp8/dataset.py (2)
86-88: Dropping raw columns via set_transform: confirm downstream expectations.
set_transform(tokenize)returns only tokenized fields; raw "sequence"/"name" are no longer yielded. Ensure nothing downstream (metrics, debugging) expects those columns; otherwise, keep them out intentionally to avoid string tensors in batches.
62-66: Ensure local train.parquet is packaged or add a fallback
Thetrain.parquetfile is present underrecipes/amplify_accelerate_te_fp8—confirm it’s included in CI and package builds. Optionally implement a Hub-based fallback for clean environments or document that this local file is required.recipes/esm2_accelerate/metrics.py (3)
18-21: Add numpy import for safe conversion pathNeeded by the revised compute_metrics if you prefer np checks; otherwise you can omit.
Apply:
-import torch +import torch +import numpy as np # optional; keep if you add np-specific logic
23-23: Avoid global mutable metric stateThe global Perplexity instance isn’t needed with the stateless CE→exp computation and can introduce state bleed. Remove it or ensure exclusive use.
Apply:
-perplexity = Perplexity(ignore_index=-100, sync_on_compute=False) +# No global metric needed; compute perplexity from CE in compute_metrics.
26-33: nested_cpu doesn’t handle numpy arraysIf you keep nested_cpu, consider converting np.ndarray to torch tensors here to be robust when used elsewhere.
Apply:
def nested_cpu(tensors): """Move `tensors` to CPU (even if it's a nested list/tuple/dict of tensors).""" if isinstance(tensors, (list, tuple)): return type(tensors)(nested_cpu(t) for t in tensors) elif isinstance(tensors, Mapping): return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()}) - return tensors.detach().cpu() if isinstance(tensors, torch.Tensor) else tensors + if isinstance(tensors, torch.Tensor): + return tensors.detach().cpu() + try: + import numpy as np # local import to avoid hard dep if unused + if isinstance(tensors, np.ndarray): + return torch.from_numpy(tensors).cpu() + except Exception: + pass + return tensorsrecipes/esm2_accelerate/train.py (1)
19-25: Import functools for wrapping compute_metricsNeeded for the binding below.
Apply:
import hydra +import functools import torchrecipes/esm2_accelerate/slurm.sh (2)
3-3: Misleading comment for ntasks-per-nodeYou set 1 task per node, but the comment says one per GPU. Clarify to avoid ops confusion.
Apply:
-#SBATCH --ntasks-per-node=1 # n tasks per machine (one task per gpu) <required> +#SBATCH --ntasks-per-node=1 # one task per node (accelerate spawns per-GPU processes inside)
11-21: Harden distributed args and config path
- Validate accelerate_config exists before srun.
- Allow overriding the main port to avoid collisions on shared clusters.
Apply:
export GPUS_PER_NODE=8 -export CMD="TRITON_CACHE_DIR=/tmp/triton_cache \ +MAIN_PORT="\${MAIN_PORT:-12340}" +[ -f accelerate_config/fsdp2_te.yaml ] || { echo 'Missing accelerate_config/fsdp2_te.yaml' >&2; exit 1; } +export CMD="TRITON_CACHE_DIR=/tmp/triton_cache \ accelerate launch \ --config_file accelerate_config/fsdp2_te.yaml \ - --machine_rank \"\$SLURM_NODEID\" \ + --machine_rank \"\$SLURM_NODEID\" \ --num_machines \"$SLURM_NNODES\" \ - --main_process_ip \"\$SLURM_SRUN_COMM_HOST\" \ - --main_process_port 12340 \ + --main_process_ip \"\$SLURM_SRUN_COMM_HOST\" \ + --main_process_port $MAIN_PORT \ --num_processes \"$(( $SLURM_NNODES * $GPUS_PER_NODE ))\" \ train.py "recipes/esm2_accelerate/test_train.py (2)
146-148: Fix assertion messageMessage references deepspeed_config.yaml but you’re checking accelerate_config_path.
Apply:
- assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}" + assert accelerate_config_path.exists(), f"{accelerate_config} not found at {accelerate_config_path}"
171-172: CI flakiness: bump timeout240s is tight when kernels/JIT kick in. Consider 600s to reduce flakes.
Apply:
- timeout=240, + timeout=600,Also applies to: 227-228
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
⛔ Files ignored due to path filters (1)
recipes/esm2_accelerate/train.parquetis excluded by!**/*.parquet
📒 Files selected for processing (25)
models/esm2/src/esm/modeling_esm_te.py(2 hunks)recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml(1 hunks)recipes/amplify_accelerate_te_fp8/callbacks.py(1 hunks)recipes/amplify_accelerate_te_fp8/dataset.py(3 hunks)recipes/amplify_accelerate_te_fp8/test_train.py(2 hunks)recipes/amplify_accelerate_te_fp8/train.py(3 hunks)recipes/esm2_accelerate/.dockerignore(1 hunks)recipes/esm2_accelerate/.ruff.toml(1 hunks)recipes/esm2_accelerate/Dockerfile(1 hunks)recipes/esm2_accelerate/README.md(1 hunks)recipes/esm2_accelerate/accelerate_config/default.yaml(1 hunks)recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml(1 hunks)recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml(1 hunks)recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml(1 hunks)recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml(1 hunks)recipes/esm2_accelerate/callbacks.py(1 hunks)recipes/esm2_accelerate/dataset.py(1 hunks)recipes/esm2_accelerate/hydra_config/L0_sanity.yaml(1 hunks)recipes/esm2_accelerate/hydra_config/README.md(1 hunks)recipes/esm2_accelerate/hydra_config/defaults.yaml(1 hunks)recipes/esm2_accelerate/metrics.py(1 hunks)recipes/esm2_accelerate/requirements.txt(1 hunks)recipes/esm2_accelerate/slurm.sh(1 hunks)recipes/esm2_accelerate/test_train.py(1 hunks)recipes/esm2_accelerate/train.py(1 hunks)
✅ Files skipped from review due to trivial changes (1)
- recipes/esm2_accelerate/.dockerignore
🚧 Files skipped from review as they are similar to previous changes (16)
- recipes/esm2_accelerate/requirements.txt
- recipes/esm2_accelerate/README.md
- models/esm2/src/esm/modeling_esm_te.py
- recipes/esm2_accelerate/hydra_config/L0_sanity.yaml
- recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml
- recipes/esm2_accelerate/dataset.py
- recipes/esm2_accelerate/Dockerfile
- recipes/esm2_accelerate/.ruff.toml
- recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml
- recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml
- recipes/esm2_accelerate/hydra_config/defaults.yaml
- recipes/amplify_accelerate_te_fp8/test_train.py
- recipes/esm2_accelerate/callbacks.py
- recipes/amplify_accelerate_te_fp8/callbacks.py
- recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml
- recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T15:58:24.658Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1080
File: recipes/esm2_accelerate/train.py:47-55
Timestamp: 2025-08-28T15:58:24.658Z
Learning: HuggingFace Transformers has a known typing issue where Trainer expects compute_metrics to have signature Callable[[EvalPrediction], Dict], but functions with additional parameters cause type checker warnings. This is a long-standing typing limitation in the framework, not a runtime error. Workarounds include wrapping functions with functools.partial, lambdas, or type casting.
Applied to files:
recipes/esm2_accelerate/train.py
🧬 Code graph analysis (3)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
StopAfterNStepsCallback(20-34)recipes/esm2_accelerate/dataset.py (1)
create_datasets_and_collator(36-74)recipes/esm2_accelerate/metrics.py (1)
compute_metrics(36-54)
recipes/esm2_accelerate/test_train.py (2)
recipes/esm2_accelerate/train.py (1)
main(36-75)recipes/amplify_accelerate_te_fp8/test_train.py (1)
test_accelerate_launch(169-210)
recipes/amplify_accelerate_te_fp8/train.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
StopAfterNStepsCallback(20-34)
🪛 LanguageTool
recipes/esm2_accelerate/hydra_config/README.md
[grammar] ~10-~10: There might be a mistake here.
Context: ...erride fp8 settings given to accelerate. This causes issues with the `deepspeed...
(QB_NEW_EN)
🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh
[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.
(SC1073)
[error] 24-24: Fix any mentioned problems and try again.
(SC1072)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
- GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (6)
recipes/esm2_accelerate/accelerate_config/default.yaml (3)
4-4: downcast_bf16: 'no' looks correct.String-quoting avoids YAML bool coercion and matches accelerate examples. (github.com)
13-16: TPU fields are consistent with accelerate defaults.
tpu_env: [],tpu_use_cluster: false, andtpu_use_sudo: falsematch typical defaults; no change needed. (github.com)
3-10: Align MULTI_GPU mode with num_processes
distributed_type: MULTI_GPUwithnum_processes: 1launches only one process (one GPU) by default, which defeats the multi-GPU intent. Either:
- Switch to a single-GPU default (e.g.
distributed_type: NOorSINGLE_GPU,num_processes: 1)- Or set
num_processesto your per-node GPU count (e.g. 8) for a true multi-GPU default- Or split into dedicated configs (e.g. default_single.yaml, default_multi8.yaml) and reference them in your CI/SLURM scripts
Please confirm which variant this recipe should use in CI/SLURM.recipes/amplify_accelerate_te_fp8/dataset.py (1)
28-33: Param rename LGTM.The
pretained_model→pretrained_modelfix removes a distracting typo and aligns call sites.recipes/amplify_accelerate_te_fp8/train.py (2)
47-51: Arg rename usage LGTM. Call site correctly updated to `pretrained_model=args.model_tag`.
82-83: Guarded PG destroy LGTM. Safer in single-process and CPU-only runs; avoids errors if dist is not initialized.
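For reference, a minimal sketch of the guarded pattern being approved here; the actual code in train.py may be structured differently.
import torch.distributed as dist


def cleanup_distributed() -> None:
    # Only tear down the process group if one was actually created,
    # so single-process and CPU-only runs exit cleanly.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()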
@coderabbitai resolve
✅ Actions performed: Comments resolved.
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Signed-off-by: Peter St. John <peterc.stjohn@gmail.com>
31e9681 to
d68627a
Compare
Actionable comments posted: 23
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
recipes/esm2_native_te_nvfsdp_thd/train.py (1)
149-154: DDP does not accept device_mesh kwarg — remove to avoid TypeError.
torch.nn.parallel.DistributedDataParallel has no device_mesh parameter. Apply this diff:
 model = torch.nn.parallel.DistributedDataParallel(
     model,
     device_ids=[dist_config.local_rank],
     output_device=dist_config.local_rank,
-    device_mesh=device_mesh["fsdp"],
 )
recipes/esm2_native_te_mfsdp/scheduler.py (1)
19-46: Input validation and small-step robustness. Guard invalid configs and ensure the decay window is ≥1 step after warmup.
 def get_linear_schedule_with_warmup(
 @@
 ):
 @@
-    decay_steps = int(num_training_steps * 0.9)
+    if num_training_steps <= 0:
+        raise ValueError("num_training_steps must be > 0")
+    if num_warmup_steps < 0 or num_warmup_steps >= num_training_steps:
+        raise ValueError("Require 0 <= num_warmup_steps < num_training_steps")
+    decay_steps = max(num_warmup_steps + 1, int(num_training_steps * 0.9))
 @@
-    elif current_step > decay_steps:
+    elif current_step > decay_steps:
         return 0.1  # one tenth of peak learning rate after decay period
recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)
157-158: FSDP-aware grad clipping needed.
`clip_grad_norm_(model.parameters())` computes a local (shard) norm under FSDP and misreports/clips incorrectly across ranks.
Option A (preferred): use the FSDP/megatron wrapper's global-clip utility if available.
Option B: reduce the squared norm across the fsdp mesh group:
 fsdp_pg = device_mesh["fsdp"].get_group()
 local_sq = torch.zeros(1, device=device)
 for p in model.parameters():
     if p.grad is not None:
         local_sq += p.grad.detach().float().norm(2).pow(2)
 dist.all_reduce(local_sq, group=fsdp_pg)
 total_norm = local_sq.sqrt().item()
 if total_norm > 1.0:
     scale = 1.0 / (total_norm + 1e-6)
     for p in model.parameters():
         if p.grad is not None:
             p.grad.mul_(scale)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (1)
114-118: Finalize initialization: tie LM head weights when configured. After `to_empty` + manual resets, HF's weight tying isn't guaranteed to run.
 for module in model.modules():
     if hasattr(module, "reset_parameters"):
         module.reset_parameters()
+if getattr(model.config, "tie_word_embeddings", False) and hasattr(model, "tie_weights"):
+    model.tie_weights()
♻️ Duplicate comments (8)
recipes/esm2_accelerate/hydra_config/README.md (2)
3-4: Fix broken phrase and reference provided configs. Correct the split “partial conv … experiments” and point readers to the starter configs.
Apply this diff:
-We could expand this to include configs for full-scale convergence testing, partial conv
-experiments, etc.
+We can expand this to include configs for full-scale convergence testing and partial convergence experiments.
+See defaults.yaml and L0_sanity.yaml in this directory for starting points.
8-12: BF16 vs FP8 note: fix capitalization and add concrete guidance to avoid Accelerate/DeepSpeed conflicts.
Apply this diff:
-Notes:
-
-- Specifying `bf16` in the `TrainingArguments` class will override fp8 settings given to accelerate.
-  This causes issues with the `deepspeed` backend, since HF will check to make sure the bf16
-  settings are the same between HF and Deepspeed settings.
+Notes:
+
+- Specifying `bf16` in `TrainingArguments` will override FP8 settings provided to Accelerate.
+  This can cause issues with the DeepSpeed backend because Transformers validates that BF16
+  settings are consistent between Transformers and DeepSpeed.
+
+Recommended:
+- For FP8 runs via Accelerate/DeepSpeed, do not set `trainer.bf16`; keep it `false` and ensure DeepSpeed BF16 is disabled.
+- For BF16 runs, set BF16 consistently in both `trainer` and the DeepSpeed config.
Example Hydra snippets:
# FP8 run (disable BF16 everywhere)
trainer:
  bf16: false
# deepspeed_config:
#   bf16:
#     enabled: false

# BF16 run (enable in both places)
trainer:
  bf16: true
# deepspeed_config:
#   bf16:
#     enabled: true
recipes/esm2_accelerate/slurm.sh (3)
24-25: Fix invalid placeholder that breaks parsing; set a safe default and create the dir. The literal `<cache_dir>` causes SC1073/SC1072 and makes the script unusable. Apply:
-# Mount a persistent cache directory to cache dataset downloads and transformations.
-export CACHE_DIR=<cache_dir>
+# Mount a persistent cache directory to cache dataset downloads and transformations.
+export CACHE_DIR="${CACHE_DIR:-$PWD/.cache/huggingface}"
+mkdir -p "$CACHE_DIR"
14-15: Pre-flight check: ensure the accelerate config exists before srun. Fail fast if the referenced config file is missing.
Apply near the top (after ulimit):
+test -f accelerate_config/fsdp2_te.yaml || {
+  echo "Missing accelerate_config/fsdp2_te.yaml"; exit 1;
+}
26-29: cd into the repo inside the container; quote mounts; kill on bad exit. Without cd, train.py may not resolve; unquoted mounts can break on spaces; add kill-on-bad-exit for hygiene. Apply:
+WORKDIR="/workspace/bionemo/recipes/esm2_accelerate"
+# Build mounts safely; .netrc may not exist
+MOUNTS="${PWD}:/workspace/bionemo,${CACHE_DIR}:/root/.cache/huggingface"
+[ -f "$HOME/.netrc" ] && MOUNTS="$MOUNTS,$HOME/.netrc:/root/.netrc"
+
 srun \
-  --container-image=<image_name> \
-  --container-mounts=${PWD}:/workspace/bionemo,$HOME/.netrc:/root/.netrc,$CACHE_DIR:/root/.cache/huggingface \
-  bash -c "$CMD"
+  --container-image="$IMAGE" \
+  --container-mounts="$MOUNTS" \
+  --kill-on-bad-exit=1 \
+  bash -lc "cd \"$WORKDIR\" && $CMD"
recipes/esm2_accelerate/test_train.py (3)
32-32: Call the undecorated runner, not a Hydra-decorated main. Prevents decorator side-effects (cwd changes, Hydra re-inits) when passing a composed config.
-from train import main
+from train import run
-    main(sanity_config)
+    run(sanity_config)
-    main(sanity_config)
+    run(sanity_config)
Also applies to: 60-60, 101-101
159-161: Hydra CLI: drop “.yaml” in --config-name.
-    "--config-name",
-    "L0_sanity.yaml",
+    "--config-name",
+    "L0_sanity",
215-217: Hydra CLI: drop “.yaml” in --config-name.
-    "--config-name",
-    "L0_sanity.yaml",
+    "--config-name",
+    "L0_sanity",
🧹 Nitpick comments (24)
recipes/esm2_accelerate/hydra_config/README.md (1)
6-6: Clarify mapping to Transformers’ TrainingArguments. Apply this diff:
-The `trainer` section gets mapped directly to the `TrainingArguments` class.
+The `trainer` section maps 1:1 to the Hugging Face Transformers `TrainingArguments` class.
recipes/esm2_accelerate/slurm.sh (2)
3-3: Fix misleading comment for ntasks-per-node. The script runs one srun task per node; Accelerate then spawns per-GPU workers. Update the comment to avoid confusion.
Apply:
-#SBATCH --ntasks-per-node=1 # n tasks per machine (one task per gpu) <required>
+#SBATCH --ntasks-per-node=1 # one task per node; Accelerate spawns $GPUS_PER_NODE procs
20-21: Guard against missing train.py in the working dir. Catch mistakes early if the repo layout changes.
Apply (place before srun):
+[ -f "train.py" ] || { echo "train.py not found in \$PWD; ensure WORKDIR is correct."; }If you adopt WORKDIR from the other suggestion, use:
[ -f "$WORKDIR/train.py" ] || { ...; exit 1; }.recipes/esm2_accelerate/test_train.py (5)
146-148: Fix assertion message (not a deepspeed config).
-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"{accelerate_config} not found at {accelerate_config_path}"
171-172: Make the timeout configurable and slightly higher to reduce flakiness.
-    timeout=240,
+    timeout=int(os.getenv("BIONEMO_TEST_TIMEOUT", "300")),
Place near the imports:
# near other imports
# Optional: allow CI to extend timeouts
# TIMEOUT = int(os.getenv("BIONEMO_TEST_TIMEOUT", "300"))
202-204: Fix assertion message (not a deepspeed config).
-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"{accelerate_config} not found at {accelerate_config_path}"
54-55: Avoid a fixed MASTER_PORT to reduce EADDRINUSE in parallel CI. Example:
-    monkeypatch.setenv("MASTER_PORT", "29500")
+    import socket
+    with socket.socket() as s:
+        s.bind(("", 0))
+        monkeypatch.setenv("MASTER_PORT", str(s.getsockname()[1]))
57-59: Optional: lock down steps to make the checkpoint count deterministic.
-    sanity_config = compose(config_name="L0_sanity", overrides=[f"trainer.output_dir={tmp_path}"])
+    sanity_config = compose(
+        config_name="L0_sanity",
+        overrides=[
+            f"trainer.output_dir={tmp_path}",
+            "trainer.save_steps=2",
+            "stop_after_n_steps=5",
+        ],
+    )
recipes/esm2_native_te_mfsdp/.dockerignore (2)
1-9: Trim build context: add common heavy paths and artifacts to .dockerignore. Add typical large/secret directories and model artifacts to cut build time and reduce risk.
 Dockerfile
 README.md
 checkpoint_export/
 outputs/
 .ruff_cache
 __pycache__
 .pytest_cache
 .ruff.toml
 .dockerignore
+.git
+.venv/
+env/
+.DS_Store
+notebooks/
+data/
+datasets/
+logs/
+wandb/
+*.pt
+*.ckpt
+*.bin
+*.safetensors
+.cache/huggingface
+.cache/torch
+**/.ipynb_checkpoints
2-2: Reconsider excluding README.md. If you want docs inside the image (e.g., for debugging/help), keep README.md in the context.
recipes/esm2_native_te_mfsdp/hydra_config/defaults.yaml (1)
18-19: Prefer fp32 gradient reduce for bf16 stability. Unless you have a measured reason, reduce grads in fp32.
-  grad_reduce_in_fp32: false
+  grad_reduce_in_fp32: true
recipes/esm2_native_te_mfsdp/Dockerfile (2)
4-7: Harden the pip install step (cache concurrency, smaller layers). Lock the cache and avoid persisting pip wheels into the image layer.
-RUN --mount=type=secret,id=netrc,target=/root/.netrc \
-    --mount=type=cache,target=/root/.cache/pip \
+RUN --mount=type=secret,id=netrc,target=/root/.netrc \
+    --mount=type=cache,target=/root/.cache/pip,sharing=locked \
     --mount=type=bind,source=requirements.txt,target=/requirements.txt \
-    PIP_CONSTRAINT= pip install -r /requirements.txt
+    PIP_CONSTRAINT= pip install -r /requirements.txt --progress-bar off --no-cache-dir
9-10: Non-root user and cache dirs for HF/torch downloads. Running as root in containers is discouraged; also set cache envs for reuse across runs.
 WORKDIR /workspace/bionemo
-COPY . .
+COPY --link . .
Additional lines to add (outside this hunk):
ENV HF_HOME=/workspace/.cache/huggingface \
    TRANSFORMERS_CACHE=/workspace/.cache/huggingface \
    TORCH_HOME=/workspace/.cache/torch
RUN mkdir -p $HF_HOME $TORCH_HOME && chmod -R a+rwx /workspace/.cache
# Optional: switch to a non-root user if your runtime permits
# RUN useradd -m -u 1000 app && chown -R 1000:1000 /workspace
# USER 1000:1000
recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (1)
16-16: Add a tiny warmup to reduce early-step instability. For 250 steps, 10% warmup is reasonable; even 25 steps helps.
-  num_warmup_steps: 0
+  num_warmup_steps: 25
recipes/esm2_native_te_mfsdp/scheduler.py (1)
19-25: Name collision with transformers.get_linear_schedule_with_warmup. Same name, different behavior (plateau at 0.1). Consider renaming or clarifying in the docs to avoid confusion during imports.
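A minimal sketch of how a caller could disambiguate the two schedulers at import time; the local module path `scheduler` mirrors this recipe's layout and is an assumption.
# The transformers helper decays to zero, while the recipe's version plateaus at
# one tenth of the peak learning rate, so alias both to keep the intent obvious.
from transformers import get_linear_schedule_with_warmup as hf_linear_schedule
from scheduler import get_linear_schedule_with_warmup as linear_schedule_with_floor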
recipes/esm2_native_te_mfsdp/train_mfsdp.py (3)
162-162: Use `set_to_none=True` for cheaper zeroing. Reduces memory writes and can speed up training.
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_to_none=True)
64-66: Docstring return type should include `None`. Non-main ranks return `None`.
-    float: The loss value for the final batch.
+    float | None: The loss on global rank 0; None on other ranks.
191-192: Use the already-detached `loss_value` for the postfix. Avoids an extra device sync.
-    progress_bar.set_postfix({"loss": loss.item()})
+    progress_bar.set_postfix({"loss": loss_value})
144-144: Cheaper grad zeroing.
-    optimizer.zero_grad()
+    optimizer.zero_grad(set_to_none=True)
172-173: Use the detached value for the tqdm postfix.
-    progress_bar.set_postfix({"loss": loss.item()})
+    progress_bar.set_postfix({"loss": loss_value})
65-66: Docstring return type should include `None`.
-    float: The loss value for the final batch.
+    float | None: The loss on global rank 0; None on other ranks.
81-86: Harden cleanup of private mesh resources. These are private internals; guard with `hasattr` to avoid AttributeErrors across versions.
-    _mesh_resources.mesh_stack.clear()
-    _mesh_resources.child_to_root_mapping.clear()
-    _mesh_resources.root_to_flatten_mapping.clear()
-    _mesh_resources.flatten_name_to_root_dims.clear()
-    _mesh_resources.mesh_dim_group_options.clear()
+    for attr in (
+        "mesh_stack",
+        "child_to_root_mapping",
+        "root_to_flatten_mapping",
+        "flatten_name_to_root_dims",
+        "mesh_dim_group_options",
+    ):
+        if hasattr(_mesh_resources, attr):
+            getattr(_mesh_resources, attr).clear()
183-183: Fix misleading docstrings (use `torchrun`, not `accelerate`).
-    # Run 'accelerate launch train.py' as a subprocess
+    # Run with torchrun as a subprocess
Also applies to: 200-200, 220-220, 238-238
49-56: Consider a slightly higher timeout for slower CI runners. 240s can be tight under GPU resource contention; 360–600s is safer.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
⛔ Files ignored due to path filters (2)
recipes/esm2_accelerate/train.parquet is excluded by !**/*.parquet
recipes/esm2_native_te_mfsdp/train.parquet is excluded by !**/*.parquet
📒 Files selected for processing (42)
- .devcontainer/recipes/requirements.txt (1 hunks)
- models/esm2/src/esm/modeling_esm_te.py (2 hunks)
- recipes/README.md (2 hunks)
- recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1 hunks)
- recipes/amplify_accelerate_te_fp8/callbacks.py (1 hunks)
- recipes/amplify_accelerate_te_fp8/dataset.py (3 hunks)
- recipes/amplify_accelerate_te_fp8/test_train.py (2 hunks)
- recipes/amplify_accelerate_te_fp8/train.py (3 hunks)
- recipes/esm2_accelerate/Dockerfile (0 hunks)
- recipes/esm2_accelerate/README.md (1 hunks)
- recipes/esm2_accelerate/accelerate_config/default.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml (1 hunks)
- recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml (1 hunks)
- recipes/esm2_accelerate/callbacks.py (1 hunks)
- recipes/esm2_accelerate/dataset.py (1 hunks)
- recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1 hunks)
- recipes/esm2_accelerate/hydra_config/README.md (1 hunks)
- recipes/esm2_accelerate/hydra_config/defaults.yaml (1 hunks)
- recipes/esm2_accelerate/metrics.py (1 hunks)
- recipes/esm2_accelerate/requirements.txt (1 hunks)
- recipes/esm2_accelerate/slurm.sh (1 hunks)
- recipes/esm2_accelerate/test_train.py (1 hunks)
- recipes/esm2_accelerate/train.py (1 hunks)
- recipes/esm2_native_te_mfsdp/.dockerignore (1 hunks)
- recipes/esm2_native_te_mfsdp/.ruff.toml (1 hunks)
- recipes/esm2_native_te_mfsdp/Dockerfile (1 hunks)
- recipes/esm2_native_te_mfsdp/README.md (1 hunks)
- recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (2 hunks)
- recipes/esm2_native_te_mfsdp/hydra_config/L1_15B_perf_test.yaml (1 hunks)
- recipes/esm2_native_te_mfsdp/hydra_config/L1_650M.yaml (0 hunks)
- recipes/esm2_native_te_mfsdp/hydra_config/defaults.yaml (1 hunks)
- recipes/esm2_native_te_mfsdp/requirements.txt (1 hunks)
- recipes/esm2_native_te_mfsdp/scheduler.py (1 hunks)
- recipes/esm2_native_te_mfsdp/test_train.py (1 hunks)
- recipes/esm2_native_te_mfsdp/train_ddp.py (1 hunks)
- recipes/esm2_native_te_mfsdp/train_fsdp2.py (8 hunks)
- recipes/esm2_native_te_mfsdp/train_mfsdp.py (8 hunks)
- recipes/esm2_native_te_nvfsdp/README.md (0 hunks)
- recipes/esm2_native_te_nvfsdp/test_train.py (0 hunks)
- recipes/esm2_native_te_nvfsdp_thd/train.py (1 hunks)
💤 Files with no reviewable changes (4)
- recipes/esm2_native_te_nvfsdp/README.md
- recipes/esm2_native_te_mfsdp/hydra_config/L1_650M.yaml
- recipes/esm2_native_te_nvfsdp/test_train.py
- recipes/esm2_accelerate/Dockerfile
✅ Files skipped from review due to trivial changes (2)
- recipes/esm2_native_te_mfsdp/README.md
- recipes/esm2_accelerate/accelerate_config/default.yaml
🚧 Files skipped from review as they are similar to previous changes (17)
- recipes/esm2_accelerate/hydra_config/defaults.yaml
- recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml
- recipes/amplify_accelerate_te_fp8/dataset.py
- recipes/esm2_accelerate/callbacks.py
- recipes/esm2_accelerate/hydra_config/L0_sanity.yaml
- recipes/esm2_accelerate/README.md
- recipes/amplify_accelerate_te_fp8/callbacks.py
- recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml
- recipes/esm2_accelerate/dataset.py
- recipes/esm2_accelerate/requirements.txt
- recipes/esm2_accelerate/train.py
- recipes/esm2_accelerate/metrics.py
- recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml
- recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml
- models/esm2/src/esm/modeling_esm_te.py
- recipes/amplify_accelerate_te_fp8/train.py
- recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml
🧰 Additional context used
🧬 Code graph analysis (6)
recipes/esm2_native_te_mfsdp/train_ddp.py (2)
recipes/esm2_native_te_mfsdp/dataset.py (1)
create_dataloader(40-95)
recipes/esm2_native_te_mfsdp/scheduler.py (1)
get_linear_schedule_with_warmup(19-45)
recipes/esm2_native_te_mfsdp/test_train.py (3)
recipes/esm2_native_te_mfsdp/train_ddp.py (1)
main(53-181)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (1)
main(57-190)
recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)
main(56-199)
recipes/esm2_accelerate/test_train.py (3)
sub-packages/bionemo-testing/src/bionemo/testing/torch.py (1)
check_fp8_support(21-33)
recipes/esm2_accelerate/train.py (1)
main(36-75)
recipes/amplify_accelerate_te_fp8/test_train.py (2)
test_accelerate_launch(169-210)
test_accelerate_launch_multi_gpu(214-249)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (3)
recipes/esm2_native_te_mfsdp/dataset.py (1)
create_dataloader(40-95)
recipes/esm2_native_te_mfsdp/scheduler.py (1)
get_linear_schedule_with_warmup(19-45)
recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)
main(56-199)
recipes/amplify_accelerate_te_fp8/test_train.py (1)
recipes/esm2_accelerate/test_train.py (1)
test_accelerate_launch_multi_gpu(194-233)
recipes/esm2_native_te_mfsdp/train_mfsdp.py (3)
recipes/esm2_native_te_mfsdp/dataset.py (1)
create_dataloader(40-95)
recipes/esm2_native_te_mfsdp/scheduler.py (1)
get_linear_schedule_with_warmup(19-45)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (2)
main(57-190)
DistributedConfig(44-53)
🪛 LanguageTool
recipes/esm2_accelerate/hydra_config/README.md
[grammar] ~10-~10: There might be a mistake here.
Context: ...erride fp8 settings given to accelerate. This causes issues with the `deepspeed...
(QB_NEW_EN)
🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh
[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.
(SC1073)
[error] 24-24: Fix any mentioned problems and try again.
(SC1072)
🔇 Additional comments (14)
recipes/esm2_native_te_mfsdp/.ruff.toml (2)
1-1: Per-recipe Ruff inheritance looks good. Extending the parent config keeps rules consistent across recipes. No issues.
1-1: Confirmed base config path. `extend = "../.ruff.toml"` correctly resolves to `recipes/.ruff.toml` and that file exists. No change required—only switch to `extend = "../../.ruff.toml"` if you intended to reference the repo-root config.
recipes/esm2_native_te_mfsdp/requirements.txt (1)
7-7: Pin the transformers VCS dependency to a commit SHA for reproducibility. Installing from HEAD is non-deterministic in air-gapped/CI environments; pin to a known-good commit that contains the ESM-2 fix:
-transformers @ git+https://github.com/huggingface/transformers
+transformers @ git+https://github.com/huggingface/transformers@<commit_sha>
Run in your environment to confirm the pinned ref is used:
pip install -r recipes/esm2_native_te_mfsdp/requirements.txt
python - <<'PY'
import pkg_resources
dist = pkg_resources.get_distribution("transformers")
print("Direct URL:", dist._provider.egg_info)
PY
recipes/esm2_native_te_nvfsdp_thd/train.py (1)
84-84: Hydra config_name without extension — LGTM.
Consistent with repo-wide change.
recipes/esm2_native_te_mfsdp/hydra_config/L1_15B_perf_test.yaml (1)
5-5: Verify training scripts use the full model_name – check that in train_mfsdp.py, train_fsdp2.py, and train_ddp.py all calls to from_pretrained use args.model_name directly (e.g. AutoModel.from_pretrained(args.model_name)) and do not prepend “facebook/”, otherwise “nvidia/esm2_t48_15B_UR50D” will resolve to “facebook/nvidia/…”.
recipes/README.md (2)
138-139: Doc sample uses extensionless config_name — good.
Matches updated decorator usage.
recipes/amplify_accelerate_te_fp8/test_train.py (1)
37-40: Good multi-GPU skip marker.
Prevents spurious CI failures on 1-GPU runners.
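A minimal sketch of such a marker, assuming a pytest + torch setup; the actual marker in test_train.py may be written differently.
import pytest
import torch

# Skip multi-GPU tests on runners that expose fewer than two devices.
requires_multi_gpu = pytest.mark.skipif(
    torch.cuda.device_count() < 2,
    reason="Test requires at least 2 GPUs",
)


@requires_multi_gpu
def test_accelerate_launch_multi_gpu(tmp_path):
    ...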
recipes/esm2_accelerate/test_train.py (2)
32-32: Importing `main` is correct: train.py defines a Hydra-decorated `main` entrypoint and does not implement a `run()` function.
68-71: Make the checkpoint count assertion robust. Exact “== 2” is brittle across config changes. Assert presence instead.
-    assert len(checkpoint_dirs) == 2, (
-        f"Expected 2 checkpoint directories, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}"
-    )
+    assert len(checkpoint_dirs) >= 1, (
+        f"Expected at least 1 checkpoint directory, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}"
+    )
Likely an incorrect or invalid review comment.
recipes/esm2_native_te_mfsdp/hydra_config/defaults.yaml (1)
12-24: FSDP enabled solely by `fully_shard_kwargs`.
train_mfsdp.py unconditionally calls `fully_shard(..., **args.fully_shard_kwargs)` without any `use_fsdp` guard (lines 119–122).
recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (2)
18-19: LR 1e-2 is likely to diverge even for sanity runs. Recommend ≤1e-3 (ESM-2 pretrain baselines use ~4e-4 peak).
 adamw_kwargs:
-  lr: 1e-2
+  lr: 4e-4
5-5: Model path is valid.
The Hugging Face repo `nvidia/esm2_t6_8M_UR50D` exists; the canonical upstream fallback is `facebook/esm2_t6_8M_UR50D`.
recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)
113-117: Verify `EsmLayer` import compatibility. Ensure that in your deployed Transformers version, the module `transformers.models.esm.modeling_esm` actually defines `EsmLayer` by running, in the same environment where this recipe will execute, `from transformers.models.esm.modeling_esm import EsmLayer` and confirming no ImportError.
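A quick way to run that check, as a sketch; adapt it to however the recipe's environment is launched.
try:
    from transformers.models.esm.modeling_esm import EsmLayer  # noqa: F401
except ImportError as err:
    raise SystemExit(f"EsmLayer not importable from this transformers build: {err}")
else:
    print("EsmLayer import OK")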
recipes/esm2_native_te_mfsdp/train_ddp.py (1)
70-74: DDP doesn’t accept `device_mesh`; remove the arg (and the unused mesh). Passing `device_mesh` to `DistributedDataParallel` will error on most PyTorch versions.
-    device_mesh = init_device_mesh(
-        "cuda",
-        mesh_shape=(dist_config.world_size,),
-        mesh_dim_names=("ddp",),
-    )
 @@
 model = torch.nn.parallel.DistributedDataParallel(
     model,
     device_ids=[dist_config.local_rank],
     output_device=dist_config.local_rank,
-    device_mesh=device_mesh["ddp"],
 )
Also applies to: 103-108
⛔ Skipped due to learnings
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1078
File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108
Timestamp: 2025-08-28T16:40:04.274Z
Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@coderabbitai resolve
✅ Actions performed: Comments resolved.
Initial work on adding an accelerate recipe for ESM-2. This will let us more easily compare TE to FA2 performance with accelerate and their built-in distributed settings (FSDP2, etc.)
I noticed a bug with the ESM-2 model implementation that we'll need to upstream to the hf model repo
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Tests
Chores