
Conversation

@pstjohn (Collaborator) commented Aug 27, 2025

Initial work on adding an accelerate recipe for ESM-2. This will let us more easily compare TE to FA2 performance with accelerate and its built-in distributed settings (FSDP2, etc.).

I noticed a bug in the ESM-2 model implementation that we'll need to upstream to the HF model repo.

Summary by CodeRabbit

  • New Features

    • Added an ESM-2 Accelerate recipe with Hydra configs, dataset/metrics utilities, callback to stop after N steps, and SLURM script for multi-node runs.
    • Provided FSDP-ready accelerate configs (single-process and single-node variants).
  • Bug Fixes

    • Corrected parameter name to pretrained_model and improved distributed teardown safety.
    • Improved rotary embedding handling for stability across devices and sequence lengths.
  • Documentation

    • Added READMEs for the new recipe and Hydra configs with run examples.
  • Tests

    • Introduced end-to-end and multi-GPU launch tests with improved logging and timeouts.
  • Chores

    • Updated dependencies, added lint/config files, dockerignore, and adjusted default accelerate settings.

@jomitchellnv (Collaborator) left a comment

LGTM, a few nits.

@coderabbitai coderabbitai bot (Contributor) commented Aug 28, 2025

Walkthrough

Per-call ROPE embedding handling was refactored in the ESM encoder. A new training recipe (esm2_accelerate) was added with Hydra configs, datasets, metrics, callbacks, Docker/SLURM assets, requirements, and tests. The amplify_accelerate_te_fp8 recipe gained a standalone callback, config tweak, test updates including multi-GPU, and a dataset API rename.

Changes

  • ESM2 encoder ROPE handling (models/esm2/src/esm/modeling_esm_te.py): Registers te_rope_emb as a non-persistent CPU pinned buffer; constructs/moves/slices ROPE per forward call; passes per-call rotary_pos_emb; handles absence of rotary.
  • Amplify TE FP8 training plumbing (recipes/amplify_accelerate_te_fp8/callbacks.py, recipes/amplify_accelerate_te_fp8/train.py, recipes/amplify_accelerate_te_fp8/dataset.py): Extracts StopAfterNStepsCallback to module and imports it; fixes pretrained_model arg in dataset API and usage; guards process group destroy; minor Hydra config name fix; tokenization transform simplified.
  • Amplify TE FP8 configs/tests (recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml, recipes/amplify_accelerate_te_fp8/test_train.py): Sets dynamo_backend: "NO"; refactors accelerate launch test; adds multi-GPU test with skip guard; standardizes trainer flags and timeout/error reporting.
  • ESM2 accelerate core code (recipes/esm2_accelerate/train.py, recipes/esm2_accelerate/dataset.py, recipes/esm2_accelerate/metrics.py, recipes/esm2_accelerate/callbacks.py): Adds Hydra-driven training entrypoint; local dataset pipeline and collator; perplexity metric utilities; StopAfterNStepsCallback.
  • ESM2 accelerate configs (recipes/esm2_accelerate/accelerate_config/*): Adds default and FSDP v1/v2 accelerate configs; sets bf16, local machine, wrap policies, and dynamo_backend: "NO".
  • ESM2 accelerate Hydra configs/docs (recipes/esm2_accelerate/hydra_config/defaults.yaml, recipes/esm2_accelerate/hydra_config/L0_sanity.yaml, recipes/esm2_accelerate/hydra_config/README.md, recipes/esm2_accelerate/README.md): Introduces base and sanity Hydra configs; documents usage and notes on bf16 vs fp8; example run instructions.
  • ESM2 accelerate infra (recipes/esm2_accelerate/Dockerfile, recipes/esm2_accelerate/.dockerignore, recipes/esm2_accelerate/.ruff.toml, recipes/esm2_accelerate/requirements.txt, recipes/esm2_accelerate/slurm.sh): Removes Dockerfile syntax directive; adds dockerignore and per-recipe Ruff config; adds training dependencies; adds SLURM script with containerized accelerate launch.
  • ESM2 accelerate tests (recipes/esm2_accelerate/test_train.py): Adds end-to-end training, resume-from-checkpoint, and accelerate single/multi-GPU launch tests with logging and assertions.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Hydra
  participant Trainer
  participant Model
  participant Dataset
  participant Callback as StopAfterNStepsCallback
  participant Metrics

  User->>Hydra: Run train.py (config L0_sanity)
  Hydra->>Trainer: Build TrainingArguments
  Hydra->>Model: AutoConfig -> AutoModelForMaskedLM (bf16)
  Hydra->>Dataset: create_datasets_and_collator(tokenizer, max_len)
  Trainer->>Trainer: Initialize with model, datasets, data_collator, compute_metrics
  Trainer->>Callback: Register StopAfterNStepsCallback(n)
  Trainer->>Trainer: Train (resume if checkpoint found)
  loop Each step
    Trainer->>Callback: on_step_end(state.global_step)
    Callback-->>Trainer: control.should_training_stop? (if >= n)
  end
  Trainer->>Metrics: compute_metrics(eval_pred) [if eval]
  Trainer->>Trainer: Save metrics and checkpoint-last
sequenceDiagram
  autonumber
  participant Encoder as NVEsmEncoder
  participant Buffer as te_rope_emb (CPU pinned buffer)
  participant Hidden as hidden_states
  participant Layer as TransformerLayer

  Hidden->>Encoder: forward(hidden_states, seq_len)
  alt Rotary enabled
    Encoder->>Buffer: Access registered buffer
    Encoder->>Encoder: Move/slice to device/dtype (non-blocking), length check
    Encoder->>Layer: call(..., rotary_pos_emb=sliced_emb)
  else Rotary disabled
    Encoder->>Layer: call(..., rotary_pos_emb=None)
  end
  Layer-->>Encoder: outputs
  Encoder-->>Hidden: outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • yzhang123
  • ohadmo
  • jwilber
  • cspades
  • dorotat-nv
  • trvachov
  • skothenhill-nv
  • sichu2023
  • tshimko-nv
  • broland-hat
  • polinabinder1

Poem

Hop hop, I ship with glee,
ROPE now moves per-call, set free.
New recipes sprout, configs bloom,
Trainers march—stop-steps go boom.
Checkpoints nest like carrots stacked,
Multi-GPU? I’ve got your back.
Bun-approved builds: thump-thump, intact! 🥕🐇


📜 Recent review details


📥 Commits

Reviewing files that changed from the base of the PR and between d68627a and e29a5d3.

📒 Files selected for processing (1)
  • recipes/amplify_accelerate_te_fp8/test_train.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • recipes/amplify_accelerate_te_fp8/test_train.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 18

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
recipes/amplify_accelerate_te_fp8/train.py (1)

82-82: Guard destroy_process_group to prevent errors when DDP was not initialized.

Unconditional teardown can raise on single-process or CPU runs.

Apply:

-    torch.distributed.destroy_process_group()
+    if torch.distributed.is_available() and torch.distributed.is_initialized():
+        torch.distributed.destroy_process_group()
♻️ Duplicate comments (3)
models/esm2/src/esm/modeling_esm_te.py (1)

173-174: Confirm TE accepts rotary_pos_emb=None for non-ROPE paths.

Upstream comment asked “is none acceptable?”. This path now passes None when position_embedding_type != "rotary". Please confirm TransformerLayer(..., rotary_pos_emb=None) is a supported code path in the TE version you target.

If TE is available, run a quick import-time check:

#!/bin/bash
python - <<'PY'
import transformer_engine.pytorch as te, torch
try:
    lyr = te.TransformerLayer(hidden_size=64, ffn_hidden_size=128, num_attention_heads=4,
                              layernorm_epsilon=1e-5, hidden_dropout=0.0, attention_dropout=0.0,
                              qkv_weight_interleaved=True, layer_number=1, layer_type="encoder",
                              self_attn_mask_type="padding", activation="gelu", attn_input_format="bshd",
                              seq_length=128, micro_batch_size=1, num_gqa_groups=4, fuse_qkv_params=True,
                              params_dtype=torch.bfloat16, window_size=(-1, -1))
    x = torch.zeros(1, 8, 64, dtype=torch.bfloat16, device="cuda" if torch.cuda.is_available() else "cpu")
    m = torch.zeros(1,1,1,8, dtype=torch.bool, device=x.device)
    lyr(x, m, rotary_pos_emb=None)
    print("OK: rotary_pos_emb=None accepted")
except Exception as e:
    print("FAIL:", e)
PY
recipes/esm2_accelerate/hydra_config/defaults.yaml (1)

20-21: Scheduler differs from original ESM-2 (WarmupAnnealDecay).

Not blocking, but note divergence from upstream training schedule.

If you want parity, I can draft a Hydra-configurable Warmup+Anneal+MinLR scheduler adapter compatible with HF Trainer.
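
For reference, a minimal sketch of such an adapter, assuming illustrative names (build_warmup_anneal_min_lr, warmup_steps, anneal_steps, min_lr_ratio) rather than the upstream ESM-2 schedule; it builds on torch.optim.lr_scheduler.LambdaLR and can be handed to HF Trainer via optimizers=(optimizer, scheduler):

import math

from torch.optim.lr_scheduler import LambdaLR


def build_warmup_anneal_min_lr(optimizer, warmup_steps: int, anneal_steps: int, min_lr_ratio: float = 0.1):
    """Linear warmup to the peak LR, then cosine anneal down to min_lr_ratio of the peak."""

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)  # ramp 0 -> 1 over warmup
        progress = min(1.0, (step - warmup_steps) / max(1, anneal_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # decays 1 -> 0 over anneal
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine  # peak -> min_lr_ratio floor

    return LambdaLR(optimizer, lr_lambda)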

recipes/esm2_accelerate/test_train.py (1)

142-145: Nit: variable name in assertion message is about accelerate, not deepspeed.

Covered by a later diff; ensure both single- and multi-GPU messages are consistent.

🧹 Nitpick comments (29)
recipes/esm2_accelerate/.dockerignore (1)

1-9: Tighten Docker context excludes; add common heavy dirs and patterns.

Recommend ignoring VCS, venvs, eggs, caches, and Python artifacts to shrink build context and speed builds.

 Dockerfile
 README.md
 checkpoint_export/
 outputs/
 .ruff_cache
 __pycache__
 .pytest_cache
 .ruff.toml
 .dockerignore
+.git
+.gitignore
+.venv/
+**/__pycache__/
+*.py[cod]
+*.egg-info/
+.mypy_cache/
+.cache/
+dist/
+build/
+data/
recipes/esm2_accelerate/accelerate_config/default.yaml (1)

3-11: Check distributed_type: MULTI_GPU with num_processes: 1.

If you expect 1 proc/GPU, this should be overridden by launcher or set accordingly; otherwise you'll run single-process on multi-GPU.

How are tests invoking accelerate (CLI flags vs config override)? If config-driven, consider documenting that num_processes must match GPU count for multi-GPU runs.

recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml (1)

1-23: Add FSDP2-friendly flags

  • Under fsdp_config, add
    fsdp_sync_module_states: true
    fsdp_limit_all_gathers: true
    (verify FSDP v2 supports these keys)
  • Ensure num_processes: 1 is intended only when always overridden via CLI
recipes/amplify_accelerate_te_fp8/callbacks.py (3)

31-34: Return the updated control object from callback.

Transformers callbacks typically return control to propagate changes. Make the return explicit.

     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         """Interrupt training after a specified number of steps."""
         if state.global_step >= self.max_steps:
             control.should_training_stop = True
+        return control

27-29: Validate max_steps early to avoid accidental no-ops.

Guard against max_steps <= 0.

     def __init__(self, max_steps: int):
         """Initialize the callback."""
-        self.max_steps = max_steps
+        if max_steps <= 0:
+            raise ValueError(f"max_steps must be > 0, got {max_steps}")
+        self.max_steps = max_steps

20-34: Reduce duplication across recipes.

This class duplicates recipes/esm2_accelerate/callbacks.py. Consider a shared utility (e.g., recipes/common/callbacks.py) to keep a single source.
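
A sketch of what the shared module could contain (recipes/common/callbacks.py is the reviewer's example path, not an existing package), folding in the validation and return-control nits above:

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments


class StopAfterNStepsCallback(TrainerCallback):
    """Stop training once the global step reaches max_steps."""

    def __init__(self, max_steps: int):
        if max_steps <= 0:
            raise ValueError(f"max_steps must be > 0, got {max_steps}")
        self.max_steps = max_steps

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        if state.global_step >= self.max_steps:
            control.should_training_stop = True
        return control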

recipes/esm2_accelerate/metrics.py (1)

26-33: Handle nested NumPy arrays as well.

nested_cpu skips NumPy arrays. If you keep this helper, extend it to convert NumPy arrays to tensors.

-    return tensors.detach().cpu() if isinstance(tensors, torch.Tensor) else tensors
+    if isinstance(tensors, torch.Tensor):
+        return tensors.detach().cpu()
+    try:
+        import numpy as np  # local import to avoid hard dep if unused
+        if isinstance(tensors, np.ndarray):
+            return torch.from_numpy(tensors).detach().cpu()
+    except Exception:
+        pass
+    return tensors
recipes/esm2_accelerate/Dockerfile (1)

8-8: Optional: minimize context churn.

If the image only needs training scripts, consider copying just the recipe subtree to speed builds and reduce cache invalidation.

recipes/esm2_accelerate/hydra_config/README.md (2)

3-4: Fix broken sentence wrap.

Merge the split phrase.

-We could expand this to include configs for full-scale convergence testing, partial conv
-experiments, etc.
+We could expand this to include configs for full-scale convergence testing, partial convergence experiments, etc.

10-12: Grammar and clarity nits.

Tighten spacing and casing.

-- Specifying `bf16` in the `TrainingArguments` class will override fp8 settings given to accelerate.
-  This causes issues with the `deepspeed` backend, since HF will check to make sure the bf16
+- Specifying `bf16` in the `TrainingArguments` class will override FP8 settings given to Accelerate.
+  This can cause issues with the `deepspeed` backend, since HF will check to make sure the bf16
   settings are the same between HF and Deepspeed settings.
recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1)

6-14: Make sanity runs deterministic; ensure eval triggers.

  • Add a fixed seed for reproducibility.
  • If not set in defaults.yaml, specify evaluation_strategy: "steps" so eval_steps is honored.
 trainer:
   run_name: "esm2_t6_8M_UR50D_sanity"
+  seed: 42
   per_device_train_batch_size: 2
   per_device_eval_batch_size: 2
   save_steps: 2
   eval_steps: 2
+  evaluation_strategy: "steps"
   logging_steps: 1
   report_to: "none"
   dataloader_num_workers: 0
recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1)

3-5: Sanity check: single-process FSDP.

distributed_type=FSDP with num_processes: 1 is fine for dev sanity, but won’t shard across GPUs. If the intent is single-GPU testing, consider distributed_type: NO here and keep FSDP in a separate multi-GPU config.

recipes/esm2_accelerate/slurm.sh (2)

3-3: Directive/comment mismatch.

--ntasks-per-node=1 conflicts with the “one task per gpu” comment. Either:

  • keep 1 task per node (accelerate spawns per-GPU workers), and fix the comment; or
  • set --ntasks-per-node=8 and add --gpus-per-task=1.
-#SBATCH --ntasks-per-node=1              # n tasks per machine (one task per gpu) <required>
+#SBATCH --ntasks-per-node=1              # one task per node; accelerate spawns per-GPU workers

12-21: Consistent deferred expansion for SLURM vars inside $CMD.

You escaped $SLURM_NODEID and $SLURM_SRUN_COMM_HOST but not $SLURM_NNODES. For consistency and to avoid host-side expansion, escape it too.

-    --num_machines "$SLURM_NNODES" \
+    --num_machines "\$SLURM_NNODES" \
recipes/amplify_accelerate_te_fp8/test_train.py (1)

213-249: Add existence check for accelerate config used in multi-GPU test.

Prevents opaque failures if bf16_config.yaml is missing.

     recipe_dir = Path(__file__).parent
     train_py = recipe_dir / "train.py"
 
-    cmd = [
+    accelerate_config_path = recipe_dir / "accelerate_config" / "bf16_config.yaml"
+    assert accelerate_config_path.exists(), f"bf16_config.yaml not found at {accelerate_config_path}"
+
+    cmd = [
         sys.executable,
         "-m",
         "accelerate.commands.launch",
         "--config_file",
-        str(recipe_dir / "accelerate_config" / "bf16_config.yaml"),
+        str(accelerate_config_path),
recipes/esm2_accelerate/train.py (1)

41-41: Optional: load from_pretrained if weights are desired.

If you intend to compare TE vs HF with pretrained ESM-2 weights, consider from_pretrained(args.model_tag) rather than from_config.

-    model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
+    model = AutoModelForMaskedLM.from_pretrained(args.model_tag, config=config, trust_remote_code=True, torch_dtype=torch.bfloat16)
recipes/esm2_accelerate/callbacks.py (2)

31-34: Return control from the callback for consistency with HF CallbackHandler.

Mutating control in-place usually works, but returning it is the recommended pattern and avoids surprises with multiple callbacks.

Apply:

     def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         """Interrupt training after a specified number of steps."""
         if state.global_step >= self.max_steps:
             control.should_training_stop = True
+        return control

27-29: Validate max_steps > 0 to avoid immediate stop on step 0.

Defensive check improves ergonomics and error messages.

Apply:

     def __init__(self, max_steps: int):
         """Initialize the callback."""
-        self.max_steps = max_steps
+        if not isinstance(max_steps, int) or max_steps <= 0:
+            raise ValueError(f"max_steps must be a positive int, got {max_steps!r}")
+        self.max_steps = int(max_steps)
recipes/esm2_accelerate/dataset.py (2)

70-74: Use a sane pad_to_multiple_of (e.g., 8) instead of max_length.

Padding to max_length is redundant with padding="max_length" and can be misleading.

Apply:

-    data_collator = DataCollatorForLanguageModeling(
+    data_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
         mlm_probability=0.15,
-        pad_to_multiple_of=max_length,
+        pad_to_multiple_of=8,
     )

28-36: infinite_dataloader assumes sampler.set_epoch exists. Guard for non-distributed samplers or drop this helper if unused.

SequentialSampler does not implement set_epoch.

Apply:

 def infinite_dataloader(dataloader, sampler):
     """Create an infinite iterator that automatically restarts at the end of each epoch."""
     epoch = 0
     while True:
-        sampler.set_epoch(epoch)  # Update epoch for proper shuffling
+        if hasattr(sampler, "set_epoch"):
+            sampler.set_epoch(epoch)  # Update epoch for proper shuffling
         for batch in dataloader:
             yield batch
         epoch += 1  # Increment epoch counter after completing one full pass
recipes/esm2_accelerate/test_train.py (9)

24-26: Only set TRITON_LIBCUDA_PATH if unset; avoid overriding env in CI.

Apply:

-# Set TRITON_LIBCUDA_PATH before check_fp8_support import that triggers Triton initialization
-os.environ["TRITON_LIBCUDA_PATH"] = "/usr/local/cuda/lib64"
+# Set TRITON_LIBCUDA_PATH before any TE import that may trigger Triton initialization
+os.environ.setdefault("TRITON_LIBCUDA_PATH", "/usr/local/cuda/lib64")

27-36: Remove unused FP8 support check and import to reduce side effects.

These values are unused and can fail on CPU-only CI.

Apply:

-import pytest
-import torch
-from hydra import compose, initialize_config_dir
-from transformer_engine.pytorch.fp8 import check_fp8_support
+import pytest
+import torch
+from hydra import compose, initialize_config_dir
@@
-_fp8_available, _fp8_reason = check_fp8_support()
+# (Intentionally not importing FP8/TE here; tests don't require it.)

67-71: Make checkpoint assertions robust to config changes.

Asserting exactly 2 checkpoints is brittle. Check presence and minimum expected count.

Apply:

-    assert len(checkpoint_dirs) == 2, (
-        f"Expected 2 checkpoint directories, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}"
-    )
+    assert len(checkpoint_dirs) >= 1, (
+        f"Expected at least 1 checkpoint directory, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}"
+    )

Optional: also assert specific ones if your save_steps are stable:

expected = {tmp_path / "checkpoint-2", tmp_path / "checkpoint-4"}
assert expected & set(checkpoint_dirs), f"Missing expected checkpoints: {expected}"

90-96: Fix misleading comment (mentions checkpoint-10).

Apply:

-    # Remove the checkpoint-10 and checkpoint-last directories
+    # Remove the checkpoint-4 and checkpoint-last directories

146-148: Correct error message to reference accelerate config.

Apply:

-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"accelerate config not found at {accelerate_config_path}"

171-172: Increase timeout for slow CI runners.

Four minutes can be tight; 10 minutes is safer for cold environments.

Apply:

-        timeout=240,
+        timeout=600,

202-204: Correct error message to reference accelerate config (multi-GPU).

Apply:

-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"accelerate config not found at {accelerate_config_path}"

227-228: Increase timeout for multi-GPU run.

Apply:

-        timeout=240,
+        timeout=600,

151-163: Hydra config naming: consider dropping .yaml extension for --config-name.

Hydra typically expects the name without extension; leaving as-is works if your decorator uses the same string, but normalizing reduces surprises.

Apply (optional):

-        "--config-name",
-        "L0_sanity.yaml",
+        "--config-name",
+        "L0_sanity",
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 6d79c72 and c0a5486.

⛔ Files ignored due to path filters (1)
  • recipes/esm2_accelerate/train.parquet is excluded by !**/*.parquet
📒 Files selected for processing (24)
  • models/esm2/src/esm/modeling_esm_te.py (2 hunks)
  • recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1 hunks)
  • recipes/amplify_accelerate_te_fp8/callbacks.py (1 hunks)
  • recipes/amplify_accelerate_te_fp8/test_train.py (2 hunks)
  • recipes/amplify_accelerate_te_fp8/train.py (1 hunks)
  • recipes/esm2_accelerate/.dockerignore (1 hunks)
  • recipes/esm2_accelerate/.ruff.toml (1 hunks)
  • recipes/esm2_accelerate/Dockerfile (1 hunks)
  • recipes/esm2_accelerate/README.md (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/default.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml (1 hunks)
  • recipes/esm2_accelerate/callbacks.py (1 hunks)
  • recipes/esm2_accelerate/dataset.py (1 hunks)
  • recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1 hunks)
  • recipes/esm2_accelerate/hydra_config/README.md (1 hunks)
  • recipes/esm2_accelerate/hydra_config/defaults.yaml (1 hunks)
  • recipes/esm2_accelerate/metrics.py (1 hunks)
  • recipes/esm2_accelerate/requirements.txt (1 hunks)
  • recipes/esm2_accelerate/slurm.sh (1 hunks)
  • recipes/esm2_accelerate/test_train.py (1 hunks)
  • recipes/esm2_accelerate/train.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (7)
recipes/esm2_accelerate/dataset.py (1)
sub-packages/bionemo-size-aware-batching/tests/bionemo/size_aware_batching/conftest.py (1)
  • sampler (92-93)
recipes/esm2_accelerate/callbacks.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (2)
  • StopAfterNStepsCallback (20-34)
  • on_step_end (31-34)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
  • StopAfterNStepsCallback (20-34)
recipes/esm2_accelerate/dataset.py (1)
  • create_datasets_and_collator (38-76)
recipes/esm2_accelerate/metrics.py (1)
  • compute_metrics (36-54)
recipes/amplify_accelerate_te_fp8/train.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
  • StopAfterNStepsCallback (20-34)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
recipes/esm2_accelerate/callbacks.py (2)
  • StopAfterNStepsCallback (20-34)
  • on_step_end (31-34)
recipes/amplify_accelerate_te_fp8/test_train.py (1)
recipes/esm2_accelerate/test_train.py (1)
  • test_accelerate_launch_multi_gpu (194-233)
recipes/esm2_accelerate/test_train.py (2)
sub-packages/bionemo-testing/src/bionemo/testing/torch.py (1)
  • check_fp8_support (21-33)
recipes/esm2_accelerate/train.py (1)
  • main (36-72)
🪛 LanguageTool
recipes/esm2_accelerate/hydra_config/README.md

[grammar] ~10-~10: There might be a mistake here.
Context: ...erride fp8 settings given to accelerate. This causes issues with the `deepspeed...

(QB_NEW_EN)

🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh

[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.

(SC1073)


[error] 24-24: Fix any mentioned problems and try again.

(SC1072)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (9)
recipes/esm2_accelerate/.ruff.toml (1)

1-1: Leave extend path unchanged. ../.ruff.toml correctly points to recipes/.ruff.toml (the common config for all recipes); there is no top-level .ruff.toml to reference.

Likely an incorrect or invalid review comment.

recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml (1)

1-22: Confirm Accelerate FSDP settings compatibility and optionally enable sync/init knobs

  • Ensure your installed accelerate version supports fsdp_version: 1 and the fsdp_config keys in fsdp1_te.yaml (e.g. by running a quick local test).
  • To improve initialization determinism and communication behavior, consider adding:
    fsdp_sync_module_states: true
    fsdp_limit_all_gathers: true
  • Note that num_processes: 1 is hardcoded here; override via --num_processes (or other CLI args) for multi-GPU runs.
recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml (1)

1-23: Add optional FSDP2 sync/all-gathers flags and verify Accelerate support

  • Under fsdp_config, add:
    fsdp_sync_module_states: true
    fsdp_limit_all_gathers: true
  • Confirm your installed accelerate release recognizes fsdp_version: 2 (e.g. run
    python -c "import accelerate; print(accelerate.__version__)"
    and check it meets the minimum version that added FSDP2 support).
recipes/esm2_accelerate/metrics.py (2)

35-55: Fix compute_metrics contract and tensor handling.

  • Signature must match HF Trainer: it should accept only EvalPrediction and return a dict.
  • Current code may receive NumPy arrays; torchmetrics.Perplexity expects tensors.
  • Returning None when compute_result=False will break logging.


-@torch.no_grad()
-def compute_metrics(eval_pred: transformers.EvalPrediction, compute_result: bool):
+@torch.no_grad()
+def compute_metrics(eval_pred: transformers.EvalPrediction):
@@
-    logits, labels = eval_pred
-    logits = nested_cpu(logits)
-    labels = nested_cpu(labels)
-    perplexity(logits, labels)
-
-    if compute_result:
-        loss = perplexity.compute()
-        perplexity.reset()
-        return {"perplexity": loss}
+    logits, labels = eval_pred
+    # Convert possible NumPy inputs to torch tensors on CPU
+    if not isinstance(logits, torch.Tensor):
+        logits = torch.from_numpy(logits)
+    if not isinstance(labels, torch.Tensor):
+        labels = torch.from_numpy(labels)
+    logits = logits.detach().cpu()
+    labels = labels.detach().cpu()
+
+    perplexity(logits, labels)
+    value = perplexity.compute()
+    perplexity.reset()
+    return {"perplexity": float(value) if isinstance(value, torch.Tensor) else value}

23-24: Confirm Perplexity input expectations (logits vs probabilities).

Ensure Perplexity is configured for raw logits; otherwise apply log_softmax or adjust the metric’s normalize/log_prob settings.

Would you like me to check the exact TorchMetrics version semantics and align the code accordingly?
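
One quick local check, a sketch assuming torchmetrics.text.Perplexity in the installed version and comparing it against exp(cross-entropy) on random data:

import torch
import torch.nn.functional as F
from torchmetrics.text import Perplexity

logits = torch.randn(2, 8, 33)       # (batch, seq_len, vocab) raw logits
labels = torch.randint(0, 33, (2, 8))
labels[:, :2] = -100                 # ignored (masked) positions

metric = Perplexity(ignore_index=-100)
metric_value = metric(logits, labels)

ce = F.cross_entropy(logits.reshape(-1, 33), labels.reshape(-1), ignore_index=-100)
print(metric_value, torch.exp(ce))   # values should agree if the metric expects raw logits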

recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1)

6-10: Verify correct FSDP transformer layer class
‘EsmLayer’ isn’t defined in this repo; it is the layer class name from the HuggingFace ESM implementation. Confirm it matches the transformer layer class of the model this config actually wraps (the TE model here uses TransformerLayer) and update fsdp_transformer_layer_cls_to_wrap accordingly to ensure TRANSFORMER_BASED_WRAP functions as intended.

recipes/amplify_accelerate_te_fp8/test_train.py (2)

37-41: Good multi-GPU gating.

Skip mark correctly guards for >=2 GPUs.


183-197: Command construction and added trainer flags look good.

List-form subprocess with explicit trainer overrides is solid.

recipes/amplify_accelerate_te_fp8/train.py (1)

47-51: Ignore typo warning. The create_datasets_and_collator signature in recipes/amplify_accelerate_te_fp8/dataset.py (line 29) defines pretained_model, matching its invocation in train.py (line 48), so this won’t crash. Rename both occurrences only if you want to correct the spelling.

Likely an incorrect or invalid review comment.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

♻️ Duplicate comments (2)
recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1)

17-18: Align dynamo_config usage across configs and verify casing.

Other configs in this PR reportedly use a scalar (e.g., dynamo_config: "NO"). Standardize on the mapping form repo-wide; also verify whether your Accelerate version prefers no vs NO for the backend and unify casing.

Run:

#!/bin/bash
# Find scalar (incorrect) usages and mapping (correct) usages
rg -nP -C2 '^\s*dynamo_config:\s*"(?:NO|INDUCTOR|no|inductor)"\s*$' --glob '!**/site-packages/**'
rg -nP -C2 '^\s*dynamo_config:\s*$\n^\s{2,}dynamo_backend:\s*(?:NO|INDUCTOR|no|inductor)\b' -U --glob '!**/site-packages/**'
recipes/esm2_accelerate/train.py (1)

49-57: Trainer will raise TypeError: compute_metrics expects 2 args, Trainer passes 1. Wrap it.

This is not just a typing nuisance; with the current signature it will fail at runtime. Wrap with a lambda (or functools.partial).

Apply this diff:

-        compute_metrics=compute_metrics,
+        compute_metrics=lambda pred: compute_metrics(pred, compute_result=True),
🧹 Nitpick comments (2)
models/esm2/src/esm/modeling_esm_te.py (1)

164-176: Optional: cache per-device ROPE to avoid repeated H2D copies.

Current code copies CPU-pinned ROPE to device every forward. For long seqs this can be a measurable overhead. Cache the full-length per (device, dtype) and slice each step.

Apply within forward:

-        if self.te_rope_emb is not None:
-            te_rope_emb = self.te_rope_emb.to(device=hidden_states.device,
-                                              dtype=hidden_states.dtype,
-                                              non_blocking=True)
+        if self.te_rope_emb is not None:
+            cache_key = (hidden_states.device, hidden_states.dtype)
+            if not hasattr(self, "_te_rope_cache"):
+                self._te_rope_cache = {}
+            te_rope_emb = self._te_rope_cache.get(cache_key)
+            if te_rope_emb is None:
+                te_rope_emb = self.te_rope_emb.to(device=hidden_states.device,
+                                                  dtype=hidden_states.dtype,
+                                                  non_blocking=True)
+                self._te_rope_cache[cache_key] = te_rope_emb
             seq_len = hidden_states.shape[1]
             if te_rope_emb.size(0) < seq_len:
                 raise RuntimeError(
                     f"ROPE length {te_rope_emb.size(0)} < input seq length {seq_len}. "
                     f"Increase max_position_embeddings."
                 )
             te_rope_emb = te_rope_emb[:seq_len]

Note: _te_rope_cache is a plain attribute (not a buffer), so it won’t bloat checkpoints.

recipes/esm2_accelerate/train.py (1)

39-41: Align config.torch_dtype with model dtype to keep TE params consistent.

TE layers read params_dtype from config.torch_dtype. Set it explicitly to match model construction.

Apply this diff:

-    model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
+    config.torch_dtype = torch.bfloat16
+    model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between c0a5486 and fccdc23.

📒 Files selected for processing (3)
  • models/esm2/src/esm/modeling_esm_te.py (2 hunks)
  • recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1 hunks)
  • recipes/esm2_accelerate/train.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T15:58:24.658Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1080
File: recipes/esm2_accelerate/train.py:47-55
Timestamp: 2025-08-28T15:58:24.658Z
Learning: HuggingFace Transformers has a known typing issue where Trainer expects compute_metrics to have signature Callable[[EvalPrediction], Dict], but functions with additional parameters cause type checker warnings. This is a long-standing typing limitation in the framework, not a runtime error. Workarounds include wrapping functions with functools.partial, lambdas, or type casting.

Applied to files:

  • recipes/esm2_accelerate/train.py
🧬 Code graph analysis (1)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
  • StopAfterNStepsCallback (20-34)
recipes/esm2_accelerate/dataset.py (1)
  • create_datasets_and_collator (38-76)
recipes/esm2_accelerate/metrics.py (1)
  • compute_metrics (36-54)
🪛 GitHub Actions: BioNeMo Framework CI
models/esm2/src/esm/modeling_esm_te.py

[error] 142-149: Ruff format error: Invalid syntax due to indentation in modeling_esm_te.py (lines 142-149).

recipes/esm2_accelerate/train.py

[error] 44-45: Ruff format error: Invalid syntax/indentation around patch hunks in train.py (lines 44-45).

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (5)
recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (2)

17-18: Correct Accelerate schema for dynamo_config. LGTM.

Using a mapping with dynamo_backend: "NO" is valid and fixes the earlier schema issue.


3-3: distributed_type: MULTI_GPU with num_processes: 1—confirm intent.

Accelerate typically runs one process per GPU for MULTI_GPU. If you intend single-GPU, set distributed_type: "SINGLE_GPU"; if multi-GPU, set num_processes to the GPU count (or override via CLI/launcher).

Also applies to: 10-10

models/esm2/src/esm/modeling_esm_te.py (1)

164-176: Good: dtype/device match + explicit length guard on ROPE.

Casting to hidden_states.dtype, non_blocking move, and the clear error on insufficient ROPE length are all correct and prevent mixed-precision/shape bugs.

recipes/esm2_accelerate/train.py (2)

35-35: Hydra config_name fixed (no .yaml).

Good correction; this avoids config resolution issues.


59-59: LGTM: accelerator state logging.

Safe and useful for debugging distributed setups.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
recipes/amplify_accelerate_te_fp8/train.py (1)

47-51: Rename pretained_model to pretrained_model across dataset and train files

  • In recipes/amplify_accelerate_te_fp8/dataset.py, update the function signature (line 29), docstring (line 36) and internal use (line 46) from pretained_model to pretrained_model.
  • In recipes/amplify_accelerate_te_fp8/train.py (line 48), change the keyword argument to pretrained_model=args.model_tag.
♻️ Duplicate comments (5)
recipes/amplify_accelerate_te_fp8/train.py (1)

27-29: Use package-relative imports for local modules (callbacks, dataset, metrics).

Prevents accidental shadowing on varied PYTHONPATH and ensures intra-package resolution.

-from callbacks import StopAfterNStepsCallback
-from dataset import create_datasets_and_collator
-from metrics import compute_metrics
+from .callbacks import StopAfterNStepsCallback
+from .dataset import create_datasets_and_collator
+from .metrics import compute_metrics

Ensure this directory is a package (it has an __init__.py).

recipes/esm2_accelerate/README.md (1)

3-9: README example looks consistent with repo (default.yaml).

The earlier nonexistent config references are gone; this is good.

recipes/esm2_accelerate/slurm.sh (2)

24-25: Invalid placeholder breaks shell; provide safe default and create directory.

Use parameter expansion and mkdir -p.

Apply:

-# Mount a persistent cache directory to cache dataset downloads and transformations.
-export CACHE_DIR=<cache_dir>
+# Mount a persistent cache directory to cache dataset downloads and transformations.
+export CACHE_DIR="${CACHE_DIR:-$PWD/.cache/huggingface}"
+mkdir -p "$CACHE_DIR"

10-15: Pre-flight check: ensure accelerate config exists before srun.

Fail fast with a clear message if missing.

Apply:

 ulimit -c 0
 
+if [ ! -f accelerate_config/fsdp2_te.yaml ]; then
+  echo "Missing accelerate_config/fsdp2_te.yaml"
+  exit 1
+fi
+
 export GPUS_PER_NODE=8
recipes/esm2_accelerate/train.py (1)

27-30: Make imports Hydra-proof (cwd changes).

Use package-qualified imports with fallback.

Apply:

-from callbacks import StopAfterNStepsCallback
-from dataset import create_datasets_and_collator
-from metrics import compute_metrics
+try:
+    from recipes.esm2_accelerate.callbacks import StopAfterNStepsCallback
+    from recipes.esm2_accelerate.dataset import create_datasets_and_collator
+    from recipes.esm2_accelerate.metrics import compute_metrics
+except ImportError:
+    from callbacks import StopAfterNStepsCallback
+    from dataset import create_datasets_and_collator
+    from metrics import compute_metrics
🧹 Nitpick comments (8)
recipes/amplify_accelerate_te_fp8/train.py (1)

55-63: Guard callback construction when stop_after_n_steps is unset/0.

Avoids passing None/0 to the callback and keeps Trainer config cleaner.

-    trainer = Trainer(
+    callbacks = []
+    if getattr(args, "stop_after_n_steps", None):
+        callbacks.append(StopAfterNStepsCallback(args.stop_after_n_steps))
+
+    trainer = Trainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         compute_metrics=compute_metrics,
         data_collator=data_collator,
-        callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)],
+        callbacks=callbacks,
     )
recipes/esm2_accelerate/dataset.py (3)

66-68: Be explicit: set_transform should not return original columns.

Explicitly set output_all_columns=False to avoid accidental passthrough if defaults change.

Apply:

-    for dataset in [train_dataset, eval_dataset]:
-        dataset.set_transform(tokenize_function)
+    for dataset in (train_dataset, eval_dataset):
+        dataset.set_transform(tokenize_function, output_all_columns=False)

69-73: Redundant collator padding arg (already padded to max_length).

Since tokenization uses padding="max_length", pad_to_multiple_of=max_length is unnecessary.

Apply:

-    data_collator = DataCollatorForLanguageModeling(
-        tokenizer=tokenizer,
-        mlm_probability=0.15,
-        pad_to_multiple_of=max_length,
-    )
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm_probability=0.15,
+    )

50-52: Ensure train.parquet is present at runtime or document its requirement
Add a runtime existence check (e.g. if not data_path.exists(): raise FileNotFoundError(...)) or clearly document that recipes/esm2_accelerate/train.parquet must be packaged alongside the code.
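
A minimal sketch of the suggested guard (data_path mirrors the local-parquet lookup; adjust to the actual variable in dataset.py):

from pathlib import Path

data_path = Path(__file__).parent / "train.parquet"
if not data_path.exists():
    raise FileNotFoundError(
        f"Expected the bundled dataset at {data_path}; ship train.parquet with the recipe "
        "or point the config at another data source."
    )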

recipes/esm2_accelerate/slurm.sh (3)

8-9: Harden shell options.

Add pipefail for safer error propagation.

Apply:

-set -x -e
+set -euo pipefail
+set -x

11-21: Allow configurable main port; avoid hardcoding.

Support MAIN_PORT override.

Apply:

 export GPUS_PER_NODE=8
+export MAIN_PORT="${MAIN_PORT:-12340}"
 export CMD="TRITON_CACHE_DIR=/tmp/triton_cache \
     accelerate launch \
     --config_file accelerate_config/fsdp2_te.yaml \
     --machine_rank \"\$SLURM_NODEID\" \
     --num_machines \"$SLURM_NNODES\" \
     --main_process_ip \"\$SLURM_SRUN_COMM_HOST\" \
-    --main_process_port 12340 \
+    --main_process_port \"$MAIN_PORT\" \
     --num_processes \"$(( $SLURM_NNODES * $GPUS_PER_NODE ))\" \
     train.py
 "

26-29: .netrc mount may fail if file absent; create a zero-byte file with safe perms.

Avoid srun mount errors.

Apply:

-srun \
+touch "$HOME/.netrc" && chmod 600 "$HOME/.netrc"
+srun \
recipes/esm2_accelerate/train.py (1)

38-45: Optional: Pin model revision when trusting remote code.

Passing trust_remote_code=True without a revision reduces reproducibility.

Would you like me to add a model_revision field to hydra_config and thread it through AutoConfig.from_pretrained(..., revision=args.model_revision)?
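
For illustration, a sketch of how that could look in train.py (model_revision is a hypothetical Hydra field, not an existing key in this recipe):

from transformers import AutoConfig, AutoModelForMaskedLM

config = AutoConfig.from_pretrained(
    args.model_tag,
    trust_remote_code=True,
    revision=getattr(args, "model_revision", "main"),  # pin a tag or commit SHA for reproducibility
)
model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)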

📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between fccdc23 and cc60c39.

📒 Files selected for processing (6)
  • recipes/amplify_accelerate_te_fp8/train.py (2 hunks)
  • recipes/esm2_accelerate/README.md (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/default.yaml (1 hunks)
  • recipes/esm2_accelerate/dataset.py (1 hunks)
  • recipes/esm2_accelerate/slurm.sh (1 hunks)
  • recipes/esm2_accelerate/train.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • recipes/esm2_accelerate/accelerate_config/default.yaml
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T15:58:24.658Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1080
File: recipes/esm2_accelerate/train.py:47-55
Timestamp: 2025-08-28T15:58:24.658Z
Learning: HuggingFace Transformers has a known typing issue where Trainer expects compute_metrics to have signature Callable[[EvalPrediction], Dict], but functions with additional parameters cause type checker warnings. This is a long-standing typing limitation in the framework, not a runtime error. Workarounds include wrapping functions with functools.partial, lambdas, or type casting.

Applied to files:

  • recipes/esm2_accelerate/train.py
🧬 Code graph analysis (3)
recipes/esm2_accelerate/dataset.py (1)
sub-packages/bionemo-size-aware-batching/tests/bionemo/size_aware_batching/conftest.py (1)
  • sampler (92-93)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
  • StopAfterNStepsCallback (20-34)
recipes/esm2_accelerate/dataset.py (1)
  • create_datasets_and_collator (36-75)
recipes/esm2_accelerate/metrics.py (1)
  • compute_metrics (36-54)
recipes/amplify_accelerate_te_fp8/train.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
  • StopAfterNStepsCallback (20-34)
🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh

[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.

(SC1073)


[error] 24-24: Fix any mentioned problems and try again.

(SC1072)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (1)
recipes/amplify_accelerate_te_fp8/train.py (1)

82-83: Good: safely destroy the process group only when initialized.

Prevents errors on non-distributed runs and aligns with best practices.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/esm2-accelerate-tests branch from 738690a to 31e9681 on August 28, 2025 at 18:10
@pstjohn pstjohn enabled auto-merge August 28, 2025 18:16
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
recipes/amplify_accelerate_te_fp8/dataset.py (1)

35-43: Docstring mismatch for data_size parameter
Update the docstring to replace "small" with "parquet" and to note that "sanity" selects 10 examples, or adjust the code if you intend 100:

-        data_size: The size of the dataset to load. If "full", use and pre-process the full UR100P
-            CSV dataset. This takes a long time without a cached dataset. If "small", use and
-            pre-process the parquet version of the dataset, which is much faster than "full". If
-            "sanity", truncates the evaluation dataset to 100 examples.
+        data_size: The size of the dataset to load. If "full", use and pre-process the full UR100P
+            CSV dataset. This takes a long time without a cached dataset. If "parquet", use and
+            pre-process the parquet version of the dataset, which is much faster than "full". If
+            "sanity", truncates the evaluation dataset to 10 examples.

Alternatively, if you want "sanity" to yield 100 examples, change .select(range(10)) to .select(range(100)).

recipes/esm2_accelerate/metrics.py (1)

35-55: Fix Trainer incompatibility: wrong signature and numpy inputs cause runtime error

Hugging Face’s Trainer calls compute_metrics(eval_pred) with numpy arrays. Your current signature requires a second arg and passes numpy arrays to a TorchMetrics module — this will fail. Make it single‑arg, convert numpy→torch, and compute perplexity directly.

Apply:

-@torch.no_grad()
-def compute_metrics(eval_pred: transformers.EvalPrediction, compute_result: bool):
-    """Compute perplexity metrics for the evaluation set.
- 
-    Args:
-        eval_pred: A tuple containing the logits and labels for the evaluation set.
-        compute_result: A boolean indicating whether to compute the perplexity metrics.
-
-    Returns:
-        A dictionary containing the perplexity metrics.
-    """
-    logits, labels = eval_pred
-    logits = nested_cpu(logits)
-    labels = nested_cpu(labels)
-    perplexity(logits, labels)
-
-    if compute_result:
-        loss = perplexity.compute()
-        perplexity.reset()
-        return {"perplexity": loss}
+@torch.no_grad()
+def compute_metrics(eval_pred: transformers.EvalPrediction):
+    """Return {"perplexity": float} computed from logits/labels provided by Trainer."""
+    # EvalPrediction exposes .predictions and .label_ids; Trainer provides numpy arrays
+    logits = getattr(eval_pred, "predictions", eval_pred[0])
+    labels = getattr(eval_pred, "label_ids", eval_pred[1])
+    # numpy -> torch on CPU
+    if not isinstance(logits, torch.Tensor):
+        logits = torch.from_numpy(logits)
+    if not isinstance(labels, torch.Tensor):
+        labels = torch.from_numpy(labels)
+    # Flatten to (N*L, V) vs (N*L,) and ignore pad label -100
+    vocab = logits.shape[-1]
+    loss = torch.nn.functional.cross_entropy(
+        logits.view(-1, vocab), labels.view(-1), ignore_index=-100
+    )
+    ppl = torch.exp(loss).item()
+    return {"perplexity": ppl}
♻️ Duplicate comments (6)
recipes/esm2_accelerate/accelerate_config/default.yaml (1)

17-18: Resolved: dynamo_config is now a mapping; backend value is valid.

Switching to:

dynamo_config:
  dynamo_backend: "NO"

matches accelerate’s schema and accepted enum values. ✔️ (huggingface.co)

recipes/amplify_accelerate_te_fp8/train.py (1)

27-27: Use a package-relative import for callbacks (duplicate).

Avoid shadowing top-level modules; import from the recipe package.

Apply:

-from callbacks import StopAfterNStepsCallback
+from .callbacks import StopAfterNStepsCallback
recipes/esm2_accelerate/train.py (3)

27-30: Make imports Hydra-proof (CWD changes)

Use package-qualified imports with a fallback so imports survive Hydra’s chdir.

Apply:

-from callbacks import StopAfterNStepsCallback
-from dataset import create_datasets_and_collator
-from metrics import compute_metrics
+try:
+    from recipes.esm2_accelerate.callbacks import StopAfterNStepsCallback
+    from recipes.esm2_accelerate.dataset import create_datasets_and_collator
+    from recipes.esm2_accelerate.metrics import compute_metrics
+except ImportError:
+    from callbacks import StopAfterNStepsCallback  # type: ignore
+    from dataset import create_datasets_and_collator  # type: ignore
+    from metrics import compute_metrics  # type: ignore

41-41: torch_dtype is ignored by from_config

Remove torch_dtype here and cast after creating TrainingArguments (or rely on training_args.bf16).

Apply:

-model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True, torch_dtype=torch.bfloat16)
+model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

Then after TrainingArguments:

 training_args = TrainingArguments(**args.trainer)
+
+if getattr(training_args, "bf16", False):
+    model.to(torch.bfloat16)

50-58: Trainer will crash: compute_metrics signature mismatch

You pass compute_metrics as-is, but the function currently expects two args. Bind the extra param or (preferred) adopt the single-arg version suggested in metrics.py.

If you keep the two-arg function, bind it:

-    trainer = Trainer(
+    compute_metrics_fn = functools.partial(compute_metrics)  # bind any extras if present
+    trainer = Trainer(
         model=model,
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
-        compute_metrics=compute_metrics,
+        compute_metrics=compute_metrics_fn,
         data_collator=data_collator,
         callbacks=[StopAfterNStepsCallback(args.stop_after_n_steps)],
     )

If you apply the metrics.py refactor (single-arg), simply keep compute_metrics=compute_metrics.

recipes/esm2_accelerate/slurm.sh (1)

24-24: Invalid placeholder breaks shell

export CACHE_DIR=<cache_dir> is invalid and fails ShellCheck. Provide a sane default and ensure it exists.

Apply:

-export CACHE_DIR=<cache_dir>
+export CACHE_DIR="${CACHE_DIR:-$PWD/.cache/huggingface}"
+mkdir -p "$CACHE_DIR"
🧹 Nitpick comments (14)
recipes/esm2_accelerate/hydra_config/README.md (3)

1-1: Clarify the title for context.

Make it explicit that these configs are for the esm2_accelerate recipe.

-# Hydra configs
+# ESM-2 Accelerate Hydra configs

6-6: Grammar/clarity.

Tighten wording; “maps” reads better than “gets mapped.”

-The `trainer` section gets mapped directly to the `TrainingArguments` class.
+The `trainer` section maps directly to the `TrainingArguments` class.

12-12: Add “Tested versions” to de-risk config drift.

Document the exact package versions this note applies to (transformers, accelerate, deepspeed) so future upgrades don’t invalidate the guidance.

Proposed addition at end of file:

## Tested versions
This behavior was validated with:
- transformers: X.Y.Z
- accelerate: A.B.C
- deepspeed: D.E.F

Update this section when bumping versions.
recipes/esm2_accelerate/accelerate_config/default.yaml (1)

5-7: Make GPU selection explicit for reproducibility.

Add gpu_ids: all so launches deterministically use all local GPUs unless overridden (common in accelerate-generated configs). (github.com, cnblogs.com)

 enable_cpu_affinity: false
+gpu_ids: all
 machine_rank: 0
recipes/amplify_accelerate_te_fp8/dataset.py (2)

86-88: Dropping raw columns via set_transform: confirm downstream expectations.

set_transform(tokenize) returns only tokenized fields; raw "sequence"/"name" are no longer yielded. Ensure nothing downstream (metrics, debugging) expects those columns; otherwise, keep them out intentionally to avoid string tensors in batches.


62-66: Ensure local train.parquet is packaged or add a fallback
The train.parquet file is present under recipes/amplify_accelerate_te_fp8—confirm it’s included in CI and package builds. Optionally implement a Hub-based fallback for clean environments or document that this local file is required.
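
One possible shape for such a fallback (sketch only; "your-org/your-dataset" is a placeholder Hub id, not a real dataset):

from pathlib import Path

from datasets import load_dataset

local_parquet = Path(__file__).parent / "train.parquet"
if local_parquet.exists():
    dataset = load_dataset("parquet", data_files=str(local_parquet), split="train")
else:
    # Placeholder Hub id -- substitute the dataset this parquet file was derived from.
    dataset = load_dataset("your-org/your-dataset", split="train")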

recipes/esm2_accelerate/metrics.py (3)

18-21: Add numpy import for safe conversion path

Needed by the revised compute_metrics if you prefer np checks; otherwise you can omit.

Apply:

-import torch
+import torch
+import numpy as np  # optional; keep if you add np-specific logic

23-23: Avoid global mutable metric state

The global Perplexity instance isn’t needed with the stateless CE→exp computation and can introduce state bleed. Remove it or ensure exclusive use.

Apply:

-perplexity = Perplexity(ignore_index=-100, sync_on_compute=False)
+# No global metric needed; compute perplexity from CE in compute_metrics.
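
A minimal sketch of the stateless alternative (assuming eval_pred.predictions holds the raw logits and labels use -100 for ignored positions):

import math

import torch
import torch.nn.functional as F
from transformers import EvalPrediction


def compute_metrics(eval_pred: EvalPrediction) -> dict:
    logits = torch.as_tensor(eval_pred.predictions).float()
    labels = torch.as_tensor(eval_pred.label_ids)
    # Mean cross-entropy over unmasked tokens; perplexity is its exponential.
    ce = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
    return {"cross_entropy": ce.item(), "perplexity": math.exp(ce.item())}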

26-33: nested_cpu doesn’t handle numpy arrays

If you keep nested_cpu, consider converting np.ndarray to torch tensors here to be robust when used elsewhere.

Apply:

 def nested_cpu(tensors):
     """Move `tensors` to CPU (even if it's a nested list/tuple/dict of tensors)."""
     if isinstance(tensors, (list, tuple)):
         return type(tensors)(nested_cpu(t) for t in tensors)
     elif isinstance(tensors, Mapping):
         return type(tensors)({k: nested_cpu(t) for k, t in tensors.items()})
-    return tensors.detach().cpu() if isinstance(tensors, torch.Tensor) else tensors
+    if isinstance(tensors, torch.Tensor):
+        return tensors.detach().cpu()
+    try:
+        import numpy as np  # local import to avoid hard dep if unused
+        if isinstance(tensors, np.ndarray):
+            return torch.from_numpy(tensors).cpu()
+    except Exception:
+        pass
+    return tensors
recipes/esm2_accelerate/train.py (1)

19-25: Import functools for wrapping compute_metrics

Needed for the binding below.

Apply:

 import hydra
+import functools
 import torch
recipes/esm2_accelerate/slurm.sh (2)

3-3: Misleading comment for ntasks-per-node

You set 1 task per node, but the comment says one per GPU. Clarify to avoid ops confusion.

Apply:

-#SBATCH --ntasks-per-node=1    	          # n tasks per machine (one task per gpu) <required>
+#SBATCH --ntasks-per-node=1    	          # one task per node (accelerate spawns per-GPU processes inside)

11-21: Harden distributed args and config path

  • Validate accelerate_config exists before srun.
  • Allow overriding the main port to avoid collisions on shared clusters.

Apply:

 export GPUS_PER_NODE=8
-export CMD="TRITON_CACHE_DIR=/tmp/triton_cache \
+MAIN_PORT="\${MAIN_PORT:-12340}"
+[ -f accelerate_config/fsdp2_te.yaml ] || { echo 'Missing accelerate_config/fsdp2_te.yaml' >&2; exit 1; }
+export CMD="TRITON_CACHE_DIR=/tmp/triton_cache \
     accelerate launch \
     --config_file accelerate_config/fsdp2_te.yaml \
-    --machine_rank \"\$SLURM_NODEID\" \
+    --machine_rank \"\$SLURM_NODEID\" \
     --num_machines \"$SLURM_NNODES\" \
-    --main_process_ip \"\$SLURM_SRUN_COMM_HOST\" \
-    --main_process_port 12340 \
+    --main_process_ip \"\$SLURM_SRUN_COMM_HOST\" \
+    --main_process_port $MAIN_PORT \
     --num_processes \"$(( $SLURM_NNODES * $GPUS_PER_NODE ))\" \
     train.py
 "
recipes/esm2_accelerate/test_train.py (2)

146-148: Fix assertion message

Message references deepspeed_config.yaml but you’re checking accelerate_config_path.

Apply:

-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"{accelerate_config} not found at {accelerate_config_path}"

171-172: CI flakiness: bump timeout

240s is tight when kernels/JIT kick in. Consider 600s to reduce flakes.

Apply:

-        timeout=240,
+        timeout=600,

Also applies to: 227-228

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 738690a and 31e9681.

⛔ Files ignored due to path filters (1)
  • recipes/esm2_accelerate/train.parquet is excluded by !**/*.parquet
📒 Files selected for processing (25)
  • models/esm2/src/esm/modeling_esm_te.py (2 hunks)
  • recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1 hunks)
  • recipes/amplify_accelerate_te_fp8/callbacks.py (1 hunks)
  • recipes/amplify_accelerate_te_fp8/dataset.py (3 hunks)
  • recipes/amplify_accelerate_te_fp8/test_train.py (2 hunks)
  • recipes/amplify_accelerate_te_fp8/train.py (3 hunks)
  • recipes/esm2_accelerate/.dockerignore (1 hunks)
  • recipes/esm2_accelerate/.ruff.toml (1 hunks)
  • recipes/esm2_accelerate/Dockerfile (1 hunks)
  • recipes/esm2_accelerate/README.md (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/default.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml (1 hunks)
  • recipes/esm2_accelerate/callbacks.py (1 hunks)
  • recipes/esm2_accelerate/dataset.py (1 hunks)
  • recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1 hunks)
  • recipes/esm2_accelerate/hydra_config/README.md (1 hunks)
  • recipes/esm2_accelerate/hydra_config/defaults.yaml (1 hunks)
  • recipes/esm2_accelerate/metrics.py (1 hunks)
  • recipes/esm2_accelerate/requirements.txt (1 hunks)
  • recipes/esm2_accelerate/slurm.sh (1 hunks)
  • recipes/esm2_accelerate/test_train.py (1 hunks)
  • recipes/esm2_accelerate/train.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • recipes/esm2_accelerate/.dockerignore
🚧 Files skipped from review as they are similar to previous changes (16)
  • recipes/esm2_accelerate/requirements.txt
  • recipes/esm2_accelerate/README.md
  • models/esm2/src/esm/modeling_esm_te.py
  • recipes/esm2_accelerate/hydra_config/L0_sanity.yaml
  • recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml
  • recipes/esm2_accelerate/dataset.py
  • recipes/esm2_accelerate/Dockerfile
  • recipes/esm2_accelerate/.ruff.toml
  • recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml
  • recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml
  • recipes/esm2_accelerate/hydra_config/defaults.yaml
  • recipes/amplify_accelerate_te_fp8/test_train.py
  • recipes/esm2_accelerate/callbacks.py
  • recipes/amplify_accelerate_te_fp8/callbacks.py
  • recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml
  • recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-28T15:58:24.658Z
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1080
File: recipes/esm2_accelerate/train.py:47-55
Timestamp: 2025-08-28T15:58:24.658Z
Learning: HuggingFace Transformers has a known typing issue where Trainer expects compute_metrics to have signature Callable[[EvalPrediction], Dict], but functions with additional parameters cause type checker warnings. This is a long-standing typing limitation in the framework, not a runtime error. Workarounds include wrapping functions with functools.partial, lambdas, or type casting.

Applied to files:

  • recipes/esm2_accelerate/train.py
🧬 Code graph analysis (3)
recipes/esm2_accelerate/train.py (3)
recipes/esm2_accelerate/callbacks.py (1)
  • StopAfterNStepsCallback (20-34)
recipes/esm2_accelerate/dataset.py (1)
  • create_datasets_and_collator (36-74)
recipes/esm2_accelerate/metrics.py (1)
  • compute_metrics (36-54)
recipes/esm2_accelerate/test_train.py (2)
recipes/esm2_accelerate/train.py (1)
  • main (36-75)
recipes/amplify_accelerate_te_fp8/test_train.py (1)
  • test_accelerate_launch (169-210)
recipes/amplify_accelerate_te_fp8/train.py (1)
recipes/amplify_accelerate_te_fp8/callbacks.py (1)
  • StopAfterNStepsCallback (20-34)
🪛 LanguageTool
recipes/esm2_accelerate/hydra_config/README.md

[grammar] ~10-~10: There might be a mistake here.
Context: ...erride fp8 settings given to accelerate. This causes issues with the `deepspeed...

(QB_NEW_EN)

🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh

[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.

(SC1073)


[error] 24-24: Fix any mentioned problems and try again.

(SC1072)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: unit-tests (models/esm2, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (recipes/esm2_accelerate, nvcr.io/nvidia/pytorch:25.06-py3)
  • GitHub Check: unit-tests (recipes/amplify_accelerate_te_fp8, nvcr.io/nvidia/pytorch:25.06-py3)
🔇 Additional comments (6)
recipes/esm2_accelerate/accelerate_config/default.yaml (3)

4-4: downcast_bf16: 'no' looks correct.

String-quoting avoids YAML bool coercion and matches accelerate examples. (github.com)
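
A quick demonstration of the coercion that quoting avoids (PyYAML follows YAML 1.1 here):

import yaml

print(yaml.safe_load("downcast_bf16: no"))    # {'downcast_bf16': False} -- bare `no` becomes a bool
print(yaml.safe_load("downcast_bf16: 'no'"))  # {'downcast_bf16': 'no'}  -- quoting preserves the string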


13-16: TPU fields are consistent with accelerate defaults.

tpu_env: [], tpu_use_cluster: false, and tpu_use_sudo: false match typical defaults; no change needed. (github.com)


3-10: Align MULTI_GPU mode with num_processes
distributed_type: MULTI_GPU with num_processes: 1 launches only one process (one GPU) by default, which defeats the multi-GPU intent. Either:

  • Switch to a single-GPU default (e.g. distributed_type: NO with num_processes: 1)
  • Or set num_processes to your per-node GPU count (e.g. 8) for a true multi-GPU default
  • Or split into dedicated configs (e.g. default_single.yaml, default_multi8.yaml) and reference them in your CI/SLURM scripts.

Please confirm which variant this recipe should use in CI/SLURM.
recipes/amplify_accelerate_te_fp8/dataset.py (1)

28-33: Param rename LGTM.

The pretained_model → pretrained_model fix removes a distracting typo and aligns call sites.

recipes/amplify_accelerate_te_fp8/train.py (2)

47-51: Arg rename usage LGTM.

Call site correctly updated to pretrained_model=args.model_tag.


82-83: Guarded PG destroy LGTM.

Safer in single-process and CPU-only runs; avoids errors if dist not initialized.
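
For reference, a typical guarded teardown looks like this (sketch; the PR's exact code may differ):

import torch.distributed as dist

# Only tear down the process group if one was actually initialized.
if dist.is_available() and dist.is_initialized():
    dist.destroy_process_group()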

@pstjohn
Collaborator Author

pstjohn commented Aug 28, 2025

@coderabbitai resolve

@coderabbitai
Contributor

coderabbitai bot commented Aug 28, 2025

✅ Actions performed

Comments resolved.

@pstjohn pstjohn added this pull request to the merge queue Aug 28, 2025
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to a conflict with the base branch Aug 28, 2025
@pstjohn pstjohn added this pull request to the merge queue Aug 28, 2025
pstjohn and others added 8 commits August 28, 2025 13:44
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Peter St. John <peterc.stjohn@gmail.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/esm2-accelerate-tests branch from 31e9681 to d68627a Compare August 28, 2025 20:47
@pstjohn pstjohn removed this pull request from the merge queue due to a manual request Aug 28, 2025
@pstjohn pstjohn enabled auto-merge August 28, 2025 20:47
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 23

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
recipes/esm2_native_te_nvfsdp_thd/train.py (1)

149-154: DDP does not accept device_mesh kwarg — remove to avoid TypeError.
torch.nn.parallel.DistributedDataParallel has no device_mesh parameter.

Apply this diff:

         model = torch.nn.parallel.DistributedDataParallel(
             model,
             device_ids=[dist_config.local_rank],
             output_device=dist_config.local_rank,
-            device_mesh=device_mesh["fsdp"],
         )
recipes/esm2_native_te_mfsdp/scheduler.py (1)

19-46: Input validation and small-step robustness

Guard invalid configs and ensure decay window is ≥1 step after warmup.

 def get_linear_schedule_with_warmup(
@@
 ):
@@
-    decay_steps = int(num_training_steps * 0.9)
+    if num_training_steps <= 0:
+        raise ValueError("num_training_steps must be > 0")
+    if num_warmup_steps < 0 or num_warmup_steps >= num_training_steps:
+        raise ValueError("Require 0 <= num_warmup_steps < num_training_steps")
+    decay_steps = max(num_warmup_steps + 1, int(num_training_steps * 0.9))
@@
-        elif current_step > decay_steps:
+        elif current_step > decay_steps:
             return 0.1  # one tenth of peak learning rate after decay period
recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)

157-158: FSDP-aware grad clipping needed.

clip_grad_norm_(model.parameters()) computes a local (shard) norm under FSDP and misreports/clips incorrectly across ranks.

Option A (preferred): use the FSDP/megatron wrapper’s global-clip utility if available.

Option B: reduce the squared-norm across the fsdp mesh group:

fsdp_pg = device_mesh["fsdp"].get_group()
local_sq = torch.zeros(1, device=device)
for p in model.parameters():
    if p.grad is not None:
        local_sq += p.grad.detach().float().norm(2).pow(2)
dist.all_reduce(local_sq, group=fsdp_pg)
total_norm = local_sq.sqrt().item()
if total_norm > 1.0:
    scale = 1.0 / (total_norm + 1e-6)
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scale)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (1)

114-118: Finalize initialization: tie LM head weights when configured.

After to_empty + manual resets, HF’s weight tying isn’t guaranteed to run.

     for module in model.modules():
         if hasattr(module, "reset_parameters"):
             module.reset_parameters()
+    if getattr(model.config, "tie_word_embeddings", False) and hasattr(model, "tie_weights"):
+        model.tie_weights()
♻️ Duplicate comments (8)
recipes/esm2_accelerate/hydra_config/README.md (2)

3-4: Fix broken phrase and reference provided configs.

Correct the split “partial conv … experiments” and point readers to the starter configs.

Apply this diff:

-We could expand this to include configs for full-scale convergence testing, partial conv
-experiments, etc.
+We can expand this to include configs for full-scale convergence testing and partial convergence experiments.
+See defaults.yaml and L0_sanity.yaml in this directory for starting points.

8-12: BF16 vs FP8 note: fix capitalization and add concrete guidance to avoid Accelerate/DeepSpeed conflicts.

Apply this diff:

-Notes:
-
-- Specifying `bf16` in the `TrainingArguments` class will override fp8 settings given to accelerate.
-  This causes issues with the `deepspeed` backend, since HF will check to make sure the bf16
-  settings are the same between HF and Deepspeed settings.
+Notes:
+
+- Specifying `bf16` in `TrainingArguments` will override FP8 settings provided to Accelerate.
+  This can cause issues with the DeepSpeed backend because Transformers validates that BF16
+  settings are consistent between Transformers and DeepSpeed.
+
+Recommended:
+- For FP8 runs via Accelerate/DeepSpeed, do not set `trainer.bf16`; keep it `false` and ensure DeepSpeed BF16 is disabled.
+- For BF16 runs, set BF16 consistently in both `trainer` and the DeepSpeed config.

Example Hydra snippets:

# FP8 run (disable BF16 everywhere)
trainer:
  bf16: false
# deepspeed_config:
#   bf16:
#     enabled: false

# BF16 run (enable in both places)
trainer:
  bf16: true
# deepspeed_config:
#   bf16:
#     enabled: true
recipes/esm2_accelerate/slurm.sh (3)

24-25: Fix invalid placeholder that breaks parsing; set safe default and create the dir.

The literal <cache_dir> causes SC1073/SC1072 and makes the script unusable.

Apply:

-# Mount a persistent cache directory to cache dataset downloads and transformations.
-export CACHE_DIR=<cache_dir>
+# Mount a persistent cache directory to cache dataset downloads and transformations.
+export CACHE_DIR="${CACHE_DIR:-$PWD/.cache/huggingface}"
+mkdir -p "$CACHE_DIR"

14-15: Pre-flight check: ensure accelerate config exists before srun.

Fail fast if the referenced config file is missing.

Apply near the top (after ulimit):

+test -f accelerate_config/fsdp2_te.yaml || {
+  echo "Missing accelerate_config/fsdp2_te.yaml"; exit 1;
+}

26-29: cd into repo inside container; quote mounts; kill on bad exit.

Without cd, train.py may not resolve; unquoted mounts can break on spaces; add kill-on-bad-exit for hygiene.

Apply:

+WORKDIR="/workspace/bionemo/recipes/esm2_accelerate"
+# Build mounts safely; .netrc may not exist
+MOUNTS="${PWD}:/workspace/bionemo,${CACHE_DIR}:/root/.cache/huggingface"
+[ -f "$HOME/.netrc" ] && MOUNTS="$MOUNTS,$HOME/.netrc:/root/.netrc"
+
 srun \
-  --container-image=<image_name> \
-  --container-mounts=${PWD}:/workspace/bionemo,$HOME/.netrc:/root/.netrc,$CACHE_DIR:/root/.cache/huggingface \
-  bash -c "$CMD"
+  --container-image="$IMAGE" \
+  --container-mounts="$MOUNTS" \
+  --kill-on-bad-exit=1 \
+  bash -lc "cd \"$WORKDIR\" && $CMD"
recipes/esm2_accelerate/test_train.py (3)

32-32: Call the undecorated runner, not a Hydra-decorated main.

Prevents decorator side-effects (cwd changes, Hydra re-inits) when passing a composed config.

-from train import main
+from train import run
-    main(sanity_config)
+    run(sanity_config)
-    main(sanity_config)
+    run(sanity_config)

Also applies to: 60-60, 101-101


159-161: Hydra CLI: drop “.yaml” in --config-name.

-        "--config-name",
-        "L0_sanity.yaml",
+        "--config-name",
+        "L0_sanity",

215-217: Hydra CLI: drop “.yaml” in --config-name.

-        "--config-name",
-        "L0_sanity.yaml",
+        "--config-name",
+        "L0_sanity",
🧹 Nitpick comments (24)
recipes/esm2_accelerate/hydra_config/README.md (1)

6-6: Clarify mapping to Transformers’ TrainingArguments.

Apply this diff:

-The `trainer` section gets mapped directly to the `TrainingArguments` class.
+The `trainer` section maps 1:1 to the Hugging Face Transformers `TrainingArguments` class.
recipes/esm2_accelerate/slurm.sh (2)

3-3: Fix misleading comment for ntasks-per-node.

The script runs one srun task per node; Accelerate then spawns per-GPU workers. Update the comment to avoid confusion.

Apply:

-#SBATCH --ntasks-per-node=1    	          # n tasks per machine (one task per gpu) <required>
+#SBATCH --ntasks-per-node=1                  # one task per node; Accelerate spawns $GPUS_PER_NODE procs

20-21: Guard against missing train.py in working dir.

Catch mistakes early if the repo layout changes.

Apply (place before srun):

+[ -f "train.py" ] || { echo "train.py not found in \$PWD; ensure WORKDIR is correct."; }

If you adopt WORKDIR from the other suggestion, use: [ -f "$WORKDIR/train.py" ] || { ...; exit 1; }.

recipes/esm2_accelerate/test_train.py (5)

146-148: Fix assertion message (not a deepspeed config).

-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"{accelerate_config} not found at {accelerate_config_path}"

171-172: Make timeout configurable and slightly higher to reduce flakiness.

-        timeout=240,
+        timeout=int(os.getenv("BIONEMO_TEST_TIMEOUT", "300")),

Place near the imports:

# near other imports
# Optional: allow CI to extend timeouts
# TIMEOUT = int(os.getenv("BIONEMO_TEST_TIMEOUT", "300"))

202-204: Fix assertion message (not a deepspeed config).

-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"{accelerate_config} not found at {accelerate_config_path}"

54-55: Avoid fixed MASTER_PORT to reduce EADDRINUSE in parallel CI.

Example:

-    monkeypatch.setenv("MASTER_PORT", "29500")
+    import socket
+    with socket.socket() as s:
+        s.bind(("", 0))
+        monkeypatch.setenv("MASTER_PORT", str(s.getsockname()[1]))

57-59: Optional: lock down steps to make checkpoint count deterministic.

-        sanity_config = compose(config_name="L0_sanity", overrides=[f"trainer.output_dir={tmp_path}"])
+        sanity_config = compose(
+            config_name="L0_sanity",
+            overrides=[
+                f"trainer.output_dir={tmp_path}",
+                "trainer.save_steps=2",
+                "stop_after_n_steps=5",
+            ],
+        )
recipes/esm2_native_te_mfsdp/.dockerignore (2)

1-9: Trim build context: add common heavy paths and artifacts to .dockerignore

Add typical large/secret directories and model artifacts to cut build time and reduce risk.

 Dockerfile
 README.md
 checkpoint_export/
 outputs/
 .ruff_cache
 __pycache__
 .pytest_cache
 .ruff.toml
 .dockerignore
+ .git
+ .venv/
+ env/
+ .DS_Store
+ notebooks/
+ data/
+ datasets/
+ logs/
+ wandb/
+ *.pt
+ *.ckpt
+ *.bin
+ *.safetensors
+ .cache/huggingface
+ .cache/torch
+ **/.ipynb_checkpoints

2-2: Reconsider excluding README.md

If you want docs inside the image (e.g., for debugging/help), keep README.md in context.

recipes/esm2_native_te_mfsdp/hydra_config/defaults.yaml (1)

18-19: Prefer fp32 gradient reduce for bf16 stability

Unless you have a measured reason, reduce grads in fp32.

-  grad_reduce_in_fp32: false
+  grad_reduce_in_fp32: true
recipes/esm2_native_te_mfsdp/Dockerfile (2)

4-7: Harden pip install step (cache concurrency, smaller layers)

Lock the cache and avoid persisting pip wheels into the image layer.

-RUN --mount=type=secret,id=netrc,target=/root/.netrc \
-    --mount=type=cache,target=/root/.cache/pip \
+RUN --mount=type=secret,id=netrc,target=/root/.netrc \
+    --mount=type=cache,target=/root/.cache/pip,sharing=locked \
     --mount=type=bind,source=requirements.txt,target=/requirements.txt \
-    PIP_CONSTRAINT= pip install -r /requirements.txt
+    PIP_CONSTRAINT= pip install -r /requirements.txt --progress-bar off --no-cache-dir

9-10: Non-root and cache dirs for HF/torch downloads

Running as root in containers is discouraged; also set cache envs for reuse across runs.

 WORKDIR /workspace/bionemo
-COPY . .
+COPY --link . .

Additional lines to add (outside this hunk):

ENV HF_HOME=/workspace/.cache/huggingface \
    TRANSFORMERS_CACHE=/workspace/.cache/huggingface \
    TORCH_HOME=/workspace/.cache/torch
RUN mkdir -p $HF_HOME $TORCH_HOME && chmod -R a+rwx /workspace/.cache
# Optional: switch to non-root user if your runtime permits
# RUN useradd -m -u 1000 app && chown -R 1000:1000 /workspace
# USER 1000:1000
recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (1)

16-16: Add a tiny warmup to reduce early-step instability

For 250 steps, 10% warmup is reasonable; even 25 steps helps.

-  num_warmup_steps: 0
+  num_warmup_steps: 25
recipes/esm2_native_te_mfsdp/scheduler.py (1)

19-25: Name collision with transformers.get_linear_schedule_with_warmup

Same name, different behavior (plateau at 0.1). Consider renaming or clarifying in docs to avoid confusion during imports.
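
One low-friction option is aliasing at the import site (sketch; the local module path is assumed):

# Keep both schedulers importable without shadowing each other.
from transformers import get_linear_schedule_with_warmup as hf_linear_schedule_with_warmup
from scheduler import get_linear_schedule_with_warmup as plateau_linear_schedule_with_warmup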

recipes/esm2_native_te_mfsdp/train_mfsdp.py (3)

162-162: Use set_to_none=True for cheaper zeroing.

Reduces memory writes and can speed up training.

-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)

64-66: Docstring return type should include None.

Non-main ranks return None.

-        float: The loss value for the final batch.
+        float | None: The loss on global rank 0; None on other ranks.

191-192: Use the already-detached loss_value for the postfix.

Avoids extra device sync.

-            progress_bar.set_postfix({"loss": loss.item()})
+            progress_bar.set_postfix({"loss": loss_value})
recipes/esm2_native_te_mfsdp/train_ddp.py (2)

144-144: Cheaper grad zeroing.

-        optimizer.zero_grad()
+        optimizer.zero_grad(set_to_none=True)

172-173: Use detached value for tqdm postfix.

-            progress_bar.set_postfix({"loss": loss.item()})
+            progress_bar.set_postfix({"loss": loss_value})
recipes/esm2_native_te_mfsdp/train_fsdp2.py (1)

65-66: Docstring return type should include None.

-        float: The loss value for the final batch.
+        float | None: The loss on global rank 0; None on other ranks.
recipes/esm2_native_te_mfsdp/test_train.py (3)

81-86: Harden cleanup of private mesh resources.

These are private internals; guard with hasattr to avoid AttributeErrors across versions.

-    _mesh_resources.mesh_stack.clear()
-    _mesh_resources.child_to_root_mapping.clear()
-    _mesh_resources.root_to_flatten_mapping.clear()
-    _mesh_resources.flatten_name_to_root_dims.clear()
-    _mesh_resources.mesh_dim_group_options.clear()
+    for attr in (
+        "mesh_stack",
+        "child_to_root_mapping",
+        "root_to_flatten_mapping",
+        "flatten_name_to_root_dims",
+        "mesh_dim_group_options",
+    ):
+        if hasattr(_mesh_resources, attr):
+            getattr(_mesh_resources, attr).clear()

183-183: Fix misleading docstrings (use torchrun, not accelerate).

-    # Run 'accelerate launch train.py' as a subprocess
+    # Run with torchrun as a subprocess

Also applies to: 200-200, 220-220, 238-238


49-56: Consider a slightly higher timeout for slower CI runners.

240s can be tight under GPU resource contention; 360–600s is safer.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 31e9681 and d68627a.

⛔ Files ignored due to path filters (2)
  • recipes/esm2_accelerate/train.parquet is excluded by !**/*.parquet
  • recipes/esm2_native_te_mfsdp/train.parquet is excluded by !**/*.parquet
📒 Files selected for processing (42)
  • .devcontainer/recipes/requirements.txt (1 hunks)
  • models/esm2/src/esm/modeling_esm_te.py (2 hunks)
  • recipes/README.md (2 hunks)
  • recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml (1 hunks)
  • recipes/amplify_accelerate_te_fp8/callbacks.py (1 hunks)
  • recipes/amplify_accelerate_te_fp8/dataset.py (3 hunks)
  • recipes/amplify_accelerate_te_fp8/test_train.py (2 hunks)
  • recipes/amplify_accelerate_te_fp8/train.py (3 hunks)
  • recipes/esm2_accelerate/Dockerfile (0 hunks)
  • recipes/esm2_accelerate/README.md (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/default.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml (1 hunks)
  • recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml (1 hunks)
  • recipes/esm2_accelerate/callbacks.py (1 hunks)
  • recipes/esm2_accelerate/dataset.py (1 hunks)
  • recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1 hunks)
  • recipes/esm2_accelerate/hydra_config/README.md (1 hunks)
  • recipes/esm2_accelerate/hydra_config/defaults.yaml (1 hunks)
  • recipes/esm2_accelerate/metrics.py (1 hunks)
  • recipes/esm2_accelerate/requirements.txt (1 hunks)
  • recipes/esm2_accelerate/slurm.sh (1 hunks)
  • recipes/esm2_accelerate/test_train.py (1 hunks)
  • recipes/esm2_accelerate/train.py (1 hunks)
  • recipes/esm2_native_te_mfsdp/.dockerignore (1 hunks)
  • recipes/esm2_native_te_mfsdp/.ruff.toml (1 hunks)
  • recipes/esm2_native_te_mfsdp/Dockerfile (1 hunks)
  • recipes/esm2_native_te_mfsdp/README.md (1 hunks)
  • recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (2 hunks)
  • recipes/esm2_native_te_mfsdp/hydra_config/L1_15B_perf_test.yaml (1 hunks)
  • recipes/esm2_native_te_mfsdp/hydra_config/L1_650M.yaml (0 hunks)
  • recipes/esm2_native_te_mfsdp/hydra_config/defaults.yaml (1 hunks)
  • recipes/esm2_native_te_mfsdp/requirements.txt (1 hunks)
  • recipes/esm2_native_te_mfsdp/scheduler.py (1 hunks)
  • recipes/esm2_native_te_mfsdp/test_train.py (1 hunks)
  • recipes/esm2_native_te_mfsdp/train_ddp.py (1 hunks)
  • recipes/esm2_native_te_mfsdp/train_fsdp2.py (8 hunks)
  • recipes/esm2_native_te_mfsdp/train_mfsdp.py (8 hunks)
  • recipes/esm2_native_te_nvfsdp/README.md (0 hunks)
  • recipes/esm2_native_te_nvfsdp/test_train.py (0 hunks)
  • recipes/esm2_native_te_nvfsdp_thd/train.py (1 hunks)
💤 Files with no reviewable changes (4)
  • recipes/esm2_native_te_nvfsdp/README.md
  • recipes/esm2_native_te_mfsdp/hydra_config/L1_650M.yaml
  • recipes/esm2_native_te_nvfsdp/test_train.py
  • recipes/esm2_accelerate/Dockerfile
✅ Files skipped from review due to trivial changes (2)
  • recipes/esm2_native_te_mfsdp/README.md
  • recipes/esm2_accelerate/accelerate_config/default.yaml
🚧 Files skipped from review as they are similar to previous changes (17)
  • recipes/esm2_accelerate/hydra_config/defaults.yaml
  • recipes/esm2_accelerate/accelerate_config/fsdp1_hf.yaml
  • recipes/amplify_accelerate_te_fp8/dataset.py
  • recipes/esm2_accelerate/callbacks.py
  • recipes/esm2_accelerate/hydra_config/L0_sanity.yaml
  • recipes/esm2_accelerate/README.md
  • recipes/amplify_accelerate_te_fp8/callbacks.py
  • recipes/esm2_accelerate/accelerate_config/fsdp1_te.yaml
  • recipes/esm2_accelerate/dataset.py
  • recipes/esm2_accelerate/requirements.txt
  • recipes/esm2_accelerate/train.py
  • recipes/esm2_accelerate/metrics.py
  • recipes/amplify_accelerate_te_fp8/accelerate_config/bf16_config.yaml
  • recipes/esm2_accelerate/accelerate_config/fsdp2_te.yaml
  • models/esm2/src/esm/modeling_esm_te.py
  • recipes/amplify_accelerate_te_fp8/train.py
  • recipes/esm2_accelerate/accelerate_config/fsdp2_hf.yaml
🧰 Additional context used
🧬 Code graph analysis (6)
recipes/esm2_native_te_mfsdp/train_ddp.py (2)
recipes/esm2_native_te_mfsdp/dataset.py (1)
  • create_dataloader (40-95)
recipes/esm2_native_te_mfsdp/scheduler.py (1)
  • get_linear_schedule_with_warmup (19-45)
recipes/esm2_native_te_mfsdp/test_train.py (3)
recipes/esm2_native_te_mfsdp/train_ddp.py (1)
  • main (53-181)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (1)
  • main (57-190)
recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)
  • main (56-199)
recipes/esm2_accelerate/test_train.py (3)
sub-packages/bionemo-testing/src/bionemo/testing/torch.py (1)
  • check_fp8_support (21-33)
recipes/esm2_accelerate/train.py (1)
  • main (36-75)
recipes/amplify_accelerate_te_fp8/test_train.py (2)
  • test_accelerate_launch (169-210)
  • test_accelerate_launch_multi_gpu (214-249)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (3)
recipes/esm2_native_te_mfsdp/dataset.py (1)
  • create_dataloader (40-95)
recipes/esm2_native_te_mfsdp/scheduler.py (1)
  • get_linear_schedule_with_warmup (19-45)
recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)
  • main (56-199)
recipes/amplify_accelerate_te_fp8/test_train.py (1)
recipes/esm2_accelerate/test_train.py (1)
  • test_accelerate_launch_multi_gpu (194-233)
recipes/esm2_native_te_mfsdp/train_mfsdp.py (3)
recipes/esm2_native_te_mfsdp/dataset.py (1)
  • create_dataloader (40-95)
recipes/esm2_native_te_mfsdp/scheduler.py (1)
  • get_linear_schedule_with_warmup (19-45)
recipes/esm2_native_te_mfsdp/train_fsdp2.py (2)
  • main (57-190)
  • DistributedConfig (44-53)
🪛 LanguageTool
recipes/esm2_accelerate/hydra_config/README.md

[grammar] ~10-~10: There might be a mistake here.
Context: ...erride fp8 settings given to accelerate. This causes issues with the `deepspeed...

(QB_NEW_EN)

🪛 Shellcheck (0.10.0)
recipes/esm2_accelerate/slurm.sh

[error] 24-24: Couldn't parse this variable assignment. Fix to allow more checks.

(SC1073)


[error] 24-24: Fix any mentioned problems and try again.

(SC1072)

🔇 Additional comments (14)
recipes/esm2_native_te_mfsdp/.ruff.toml (2)

1-1: Per-recipe Ruff inheritance looks good.

Extending the parent config keeps rules consistent across recipes. No issues.


1-1: Confirmed base config path. extend = "../.ruff.toml" correctly resolves to recipes/.ruff.toml and that file exists. No change required—only switch to extend = "../../.ruff.toml" if you intended to reference the repo-root config.

recipes/esm2_native_te_mfsdp/requirements.txt (1)

7-7: Pin the transformers VCS dependency to a commit SHA for reproducibility.

Installing from HEAD is non-deterministic in air-gapped/CI environments; pin to a known-good commit that contains the ESM-2 fix:

-transformers @ git+https://github.com/huggingface/transformers
+transformers @ git+https://github.com/huggingface/transformers@<commit_sha>

Run in your environment to confirm the pinned ref is used:

pip install -r recipes/esm2_native_te_mfsdp/requirements.txt
python - <<'PY'
import importlib.metadata
import json

# direct_url.json records the VCS URL and commit pip installed from (absent for index installs).
raw = importlib.metadata.distribution("transformers").read_text("direct_url.json")
print(json.loads(raw) if raw else "not installed from a direct URL")
PY
recipes/esm2_native_te_nvfsdp_thd/train.py (1)

84-84: Hydra config_name without extension — LGTM.
Consistent with repo-wide change.

recipes/esm2_native_te_mfsdp/hydra_config/L1_15B_perf_test.yaml (1)

5-5: Verify training scripts use the full model_name. Check that train_mfsdp.py, train_fsdp2.py, and train_ddp.py pass args.model_name to from_pretrained unchanged (e.g. AutoModel.from_pretrained(args.model_name)) and do not prepend "facebook/"; otherwise "nvidia/esm2_t48_15B_UR50D" would become the invalid id "facebook/nvidia/…".

recipes/README.md (1)

138-139: Doc sample uses extensionless config_name — good.
Matches updated decorator usage.

recipes/amplify_accelerate_te_fp8/test_train.py (1)

37-40: Good multi-GPU skip marker.
Prevents spurious CI failures on 1-GPU runners.
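
The pattern in question is roughly the following (sketch; the actual marker and reason text may differ):

import pytest
import torch

requires_multi_gpu = pytest.mark.skipif(
    torch.cuda.device_count() < 2, reason="requires at least 2 GPUs"
)


@requires_multi_gpu
def test_accelerate_launch_multi_gpu(tmp_path):
    ...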

recipes/esm2_accelerate/test_train.py (2)

32-32: Importing main is correct: train.py defines a Hydra-decorated main entrypoint and does not implement a run() function.


68-71: Make checkpoint count assertion robust.

Exact “== 2” is brittle across config changes. Assert presence instead.

-    assert len(checkpoint_dirs) == 2, (
-        f"Expected 2 checkpoint directories, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}"
-    )
+    assert len(checkpoint_dirs) >= 1, (
+        f"Expected at least 1 checkpoint directory, found {len(checkpoint_dirs)}: {[d.name for d in checkpoint_dirs]}"
+    )

Likely an incorrect or invalid review comment.

recipes/esm2_native_te_mfsdp/hydra_config/defaults.yaml (1)

12-24: FSDP enabled solely by fully_shard_kwargs
train_mfsdp.py unconditionally calls fully_shard(..., **args.fully_shard_kwargs) without any use_fsdp guard (lines 119–122).

recipes/esm2_native_te_mfsdp/hydra_config/L0_sanity.yaml (2)

18-19: LR 1e-2 is likely to diverge even for sanity runs

Recommend ≤1e-3 (ESM-2 pretrain baselines use ~4e-4 peak).

 adamw_kwargs:
-  lr: 1e-2
+  lr: 4e-4

5-5: Model path is valid
The Hugging Face repo nvidia/esm2_t6_8M_UR50D exists; the canonical upstream fallback is facebook/esm2_t6_8M_UR50D.
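
If a fallback is ever wanted, it could look like this sketch (whether the NVIDIA repo needs trust_remote_code is an assumption here):

from transformers import AutoModelForMaskedLM

try:
    model = AutoModelForMaskedLM.from_pretrained("nvidia/esm2_t6_8M_UR50D", trust_remote_code=True)
except (OSError, ValueError):
    # Canonical upstream checkpoint as a fallback.
    model = AutoModelForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")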

recipes/esm2_native_te_mfsdp/train_mfsdp.py (1)

113-117: Verify EsmLayer import compatibility. Ensure that in your deployed Transformers version, the module transformers.models.esm.modeling_esm actually defines EsmLayer by running, in the same environment where this recipe will execute:

from transformers.models.esm.modeling_esm import EsmLayer

and confirming no ImportError.

recipes/esm2_native_te_mfsdp/train_ddp.py (1)

70-74: DDP doesn’t accept device_mesh; remove arg (and the unused mesh).

Passing device_mesh to DistributedDataParallel will error on most PyTorch versions.

-    device_mesh = init_device_mesh(
-        "cuda",
-        mesh_shape=(dist_config.world_size,),
-        mesh_dim_names=("ddp",),
-    )
@@
     model = torch.nn.parallel.DistributedDataParallel(
         model,
         device_ids=[dist_config.local_rank],
         output_device=dist_config.local_rank,
-        device_mesh=device_mesh["ddp"],
     )

Also applies to: 103-108

⛔ Skipped due to learnings
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1078
File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108
Timestamp: 2025-08-28T16:40:04.274Z
Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn
Collaborator Author

pstjohn commented Aug 28, 2025

@coderabbitai resolve

@coderabbitai
Contributor

coderabbitai bot commented Aug 28, 2025

✅ Actions performed

Comments resolved.

@pstjohn pstjohn added this pull request to the merge queue Aug 28, 2025
Merged via the queue into NVIDIA:main with commit 720a732 Aug 29, 2025
17 checks passed
@pstjohn pstjohn deleted the pstjohn/esm2-accelerate-tests branch August 29, 2025 00:36
@coderabbitai coderabbitai bot mentioned this pull request Sep 8, 2025
9 tasks