Skip to content

make sure we initialize accelerator before model#1132

Merged
pstjohn merged 2 commits into
mainfrom
pstjohn/fix-esm2-accelerate-init
Sep 8, 2025
Merged

make sure we initialize accelerator before model#1132
pstjohn merged 2 commits into
mainfrom
pstjohn/fix-esm2-accelerate-init

Conversation

@pstjohn
Copy link
Copy Markdown
Collaborator

@pstjohn pstjohn commented Sep 8, 2025

We need to initialize the Accelerator object before creating TE layers or they all end up on a single device

Summary by CodeRabbit

  • New Features

    • Added a ready-to-run performance test preset for the esm2 t48 15B model with sensible defaults: model tag, step cap, batch sizes, learning rate, weight decay, warmup steps, and Weights & Biases logging.
  • Bug Fixes

    • Improved multi-GPU initialization by starting distributed state earlier, reducing setup issues and OOM risk without changing training behavior.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Sep 8, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Sep 8, 2025

Walkthrough

Adds a new Hydra YAML config for an esm2_t48_15B_UR50D performance test and moves early Accelerate initialization into train.py main() before model construction; removes the post-creation accelerator state log. No other training, dataset, or evaluation logic changed.

Changes

Cohort / File(s) Summary
Hydra perf test config
recipes/esm2_accelerate/hydra_config/L1_15B_perf_test.yaml
New config: adds defaults (items defaults, _self_), model_tag: nvidia/esm2_t48_15B_UR50D, stop_after_n_steps: 500, and trainer mapping (run_name, per-device batch sizes 12/12, report_to: wandb, learning_rate: 1.6e-4, weight_decay: 0.1, warmup_steps: 20_000).
Accelerate init adjustment
recipes/esm2_accelerate/train.py
Instantiate Accelerate (PartialState/Accelerator) early in main() before model/config creation and log local_process_index, num_processes, and device; remove the later accelerator state print. No other functional changes.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant CLI as CLI/Launcher
  participant Train as train.py
  participant Accel as Accelerate
  participant Model as Model/Config
  participant Trainer as Trainer/Loop
  participant Logger as Logger

  CLI->>Train: start main()
  Train->>Accel: create PartialState()/Accelerator (early)
  Accel-->>Train: provides device, process info
  Train->>Logger: log local_process_index, num_processes, device
  Train->>Model: build model/config on assigned device
  Train->>Trainer: configure dataloaders, optimizer, etc.
  Trainer->>Accel: prepare components (wrap with accelerator)
  Trainer->>Trainer: run training/eval (stop_after_n_steps)
  Note over Logger,Trainer: removed post-creation accelerator state print
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Suggested reviewers

  • jstjohn
  • jwilber
  • cspades
  • trvachov

A hop, a whisper, circuits bright,
I start the accel before midnight.
Fifteen billion fur and speed,
Steps capped quick for performance need.
Logs tidy, batches set just right—🥕

Pre-merge checks (2 passed, 1 warning)

❌ Failed Checks (1 warning)
Check Name Status Explanation Resolution
Description Check ⚠️ Warning The current description consists of a single sentence and does not follow the repository’s required template, missing sections such as Type of changes, CI Pipeline Configuration, Usage, and Pre-submit Checklist. Update the description to include the full template sections by adding detailed change descriptions, selecting the appropriate change type, specifying CI labels, demonstrating usage with a code snippet, and completing the pre-submit checklist.
✅ Passed Checks (2 passed)
Check Name Status Explanation
Title Check ✅ Passed The pull request title clearly and concisely summarizes the primary functional change, namely ensuring the accelerator is initialized before model creation, and avoids extraneous detail, making it easy for reviewers to understand the core update at a glance.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  - Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.
  - Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between febd9f5 and 710e405.

📒 Files selected for processing (1)
  • recipes/esm2_accelerate/train.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • recipes/esm2_accelerate/train.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch pstjohn/fix-esm2-accelerate-init

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (4)
recipes/esm2_accelerate/train.py (2)

39-42: Early Accelerate init fixes device placement; consider PartialState instead.

Creating an unused Accelerator just to set the current device works but is heavier than needed. PartialState initializes the distributed state and sets torch device without constructing a full Accelerator instance.

Apply this diff:

-    # We need to initialize the Accelerator manually prior to creating our model, otherwise we won't end up setting the
-    # current torch device and the model creation will all happen on a single GPU, typically leading to an OOM.
-    _ = Accelerator()
+    # Initialize Accelerate's distributed state early so torch device is set per process
+    state = PartialState()
+    logger.info(
+        "Accelerate initialized (local_process_index=%s, num_processes=%s, device=%s)",
+        state.local_process_index,
+        state.num_processes,
+        state.device,
+    )

Note: See the import change on Line 22.


22-22: Import PartialState (lighter) instead of Accelerator (unused).

Avoid constructing an unused Accelerator instance; use PartialState for early device init.

Apply this diff:

-from accelerate import Accelerator
+from accelerate import PartialState
recipes/esm2_accelerate/hydra_config/L1_15B_perf_test.yaml (2)

12-14: Use an int for warmup_steps (drop underscore).

Some YAML parsers/OmegaConf setups won’t accept numeric separators; “20_000” can become a string and break TrainingArguments’ int validation.

Apply this diff:

-  warmup_steps: 20_000
+  warmup_steps: 20000

7-14: Optional: align Trainer precision with model dtype.

Model is created in bfloat16; consider setting bf16: true (and tf32: true on Ampere/Hopper) for better perf and consistent autocast during training.

Example:

 trainer:
   run_name: "esm2_t48_15B_UR50D_perf"
   per_device_train_batch_size: 12
   per_device_eval_batch_size: 12
   report_to: "wandb"
   learning_rate: 1.6e-4
   weight_decay: 0.1
-  warmup_steps: 20000
+  warmup_steps: 20000
+  bf16: true
+  tf32: true
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0d30652 and febd9f5.

📒 Files selected for processing (2)
  • recipes/esm2_accelerate/hydra_config/L1_15B_perf_test.yaml (1 hunks)
  • recipes/esm2_accelerate/train.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Analyze (rust)
🔇 Additional comments (2)
recipes/esm2_accelerate/train.py (1)

22-22: Dependency confirmed: accelerate declared in recipes/esm2_accelerate/requirements.txt No further action needed.

recipes/esm2_accelerate/hydra_config/L1_15B_perf_test.yaml (1)

1-3: defaults.yaml present Verified that recipes/esm2_accelerate/hydra_config/defaults.yaml exists and resolves correctly.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn
Copy link
Copy Markdown
Collaborator Author

pstjohn commented Sep 8, 2025

/ok to test 710e405

@pstjohn pstjohn enabled auto-merge September 8, 2025 21:37
@pstjohn pstjohn added this pull request to the merge queue Sep 8, 2025
Merged via the queue into main with commit fca6ead Sep 8, 2025
19 checks passed
@pstjohn pstjohn deleted the pstjohn/fix-esm2-accelerate-init branch September 8, 2025 22:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants