
Conversation

@pstjohn (Collaborator) commented Sep 5, 2025

Adds partial convergence tests to the esm2_accelerate recipe, similar to those used in the mfsdp recipe.

Summary by CodeRabbit

  • New Features
    • Added configurable warmup steps (default 0) to training.
  • Chores
    • Increased default training duration (more steps).
    • Reduced frequency of saving, evaluation, and logging to lower overhead.
  • Tests
    • Improved distributed run stability by using dynamic, collision-free ports.
    • Added parsing of final training loss from output and assertions to ensure expected convergence.
    • Streamlined test overrides for faster, deterministic sanity runs.

@coderabbitai coderabbitai bot (Contributor) commented Sep 5, 2025

Caution

Review failed

An error occurred during the review process. Please try again later.


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
recipes/esm2_accelerate/test_train.py (1)

260-286: Remove stray breakpoint() that will hang CI.

This will stall pytest in non-interactive runs.

-        final_train_loss = extract_final_train_loss(combined_output)
-        breakpoint()
+        final_train_loss = extract_final_train_loss(combined_output)

Also consider the same stop_after_n_steps=50 override here:

         f"trainer.output_dir={tmp_path}",
         "trainer.do_eval=False",
+        "stop_after_n_steps=50",

Also applies to: 293-299

🧹 Nitpick comments (7)
recipes/esm2_accelerate/train.py (2)

67-78: Resume logic + metrics saving looks solid; also save trainer state.

Add trainer.save_state() so RNG/optimizer/scheduler states are captured alongside checkpoint-last for more robust resumes.

         train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
         logger.info("Training complete. Metrics: %s", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_model(str(Path(training_args.output_dir) / "checkpoint-last"))
+        trainer.save_state()
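
For context, a minimal sketch of the resume-and-save flow this suggestion targets, assuming the recipe uses the standard Hugging Face Trainer and TrainingArguments; the train_and_save helper name is illustrative, not taken from train.py:

    from pathlib import Path
    from transformers.trainer_utils import get_last_checkpoint

    def train_and_save(trainer, training_args):
        # Resume from the newest checkpoint in output_dir if one exists.
        last_checkpoint = None
        if Path(training_args.output_dir).is_dir():
            last_checkpoint = get_last_checkpoint(training_args.output_dir)

        train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_model(str(Path(training_args.output_dir) / "checkpoint-last"))
        # save_state() persists optimizer/scheduler/RNG state alongside the weights,
        # which is what makes a later resume deterministic.
        trainer.save_state()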

79-83: Evaluation gate may skip eval unintentionally.

Keying off do_eval alone can suppress eval even when evaluation_strategy != "no". Consider checking the strategy or presence of eval_dataset.

-    if training_args.do_eval:
+    # Run eval when explicitly requested or when the strategy isn't "no"
+    from transformers.training_args import IntervalStrategy
+    if training_args.do_eval or getattr(training_args, "evaluation_strategy", IntervalStrategy.NO) != IntervalStrategy.NO:
         eval_result = trainer.evaluate()
         logger.info("Evaluation complete. Metrics: %s", eval_result)
         trainer.save_metrics("eval", eval_result)
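
A small helper illustrating the suggested gate, assuming a recent Transformers release where IntervalStrategy is importable from transformers.training_args; should_run_eval is a hypothetical name, not part of the recipe:

    from transformers.training_args import IntervalStrategy

    def should_run_eval(training_args) -> bool:
        # Evaluate when explicitly requested, or whenever the eval strategy
        # is anything other than "no" (i.e. "steps" or "epoch").
        strategy = getattr(training_args, "evaluation_strategy", IntervalStrategy.NO)
        return bool(training_args.do_eval) or strategy != IntervalStrategy.NO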
recipes/esm2_accelerate/test_train.py (5)

44-74: Harden train_loss parsing (support scientific notation and varied formats).

Broaden the regex to avoid false negatives.

-    pattern = r'\{[^{}]*[\'"]train_loss[\'"]:\s*([0-9.]+)[^{}]*\}'
+    # number: optional sign, decimal, optional exponent
+    number = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
+    pattern = rf'\{{[^{{}}]*[\'"]train_loss[\'"]:\s*({number})[^{{}}]*\}}'
...
-        simple_pattern = r'[\'"]train_loss[\'"]:\s*([0-9.]+)'
+        simple_pattern = rf'[\'"]train_loss[\'"]\s*[:=]\s*({number})'
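
For reference, a self-contained version of the hardened parser; the helper name matches the one used in test_train.py, but the body is only a sketch built around the regex suggested above:

    import re

    NUMBER = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'

    def extract_final_train_loss(output: str) -> float:
        """Return the last train_loss reported in a metrics-dict-style log line."""
        pattern = rf'\{{[^{{}}]*[\'"]train_loss[\'"]:\s*({NUMBER})[^{{}}]*\}}'
        matches = re.findall(pattern, output)
        if not matches:
            # Fall back to a looser "train_loss: <number>" / "train_loss = <number>" form.
            matches = re.findall(rf'[\'"]train_loss[\'"]\s*[:=]\s*({NUMBER})', output)
        if not matches:
            raise ValueError("train_loss not found in training output")
        return float(matches[-1])

    # Example: extract_final_train_loss("{'train_loss': 2.41e+00, 'epoch': 1.0}") returns 2.41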

16-23: Avoid port collisions: pick an actually free port instead of random range.

Bind to port 0 to get a free port and reuse it. Apply in env and CLI.

 import os
-import random
+import random
+import socket
...
-    monkeypatch.setenv("MASTER_PORT", f"{random.randint(20000, 40000)}")
+    def _get_free_port():
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("127.0.0.1", 0))
+            return s.getsockname()[1]
+    monkeypatch.setenv("MASTER_PORT", str(_get_free_port()))
...
-        "--main_process_port",
-        f"{random.randint(20000, 40000)}",
+        "--main_process_port",
+        str(_get_free_port()),
...
-        "--main_process_port",
-        f"{random.randint(20000, 40000)}",
+        "--main_process_port",
+        str(_get_free_port()),

Also applies to: 87-87, 199-201, 269-271
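
As a usage sketch, the same helper can back a small pytest fixture so every distributed test gets its own port; the fixture name is illustrative and does not exist in the suite:

    import socket
    import pytest

    def _get_free_port() -> int:
        # Bind to port 0 and let the OS hand back an unused port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]

    @pytest.fixture
    def master_port(monkeypatch) -> int:
        port = _get_free_port()
        monkeypatch.setenv("MASTER_PORT", str(port))
        return port

There is still a small window between closing the probe socket and launching the process, but this is far less collision-prone than picking from a random range.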


188-189: Fix assertion message to reference Accelerate config.

Minor clarity nit.

-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"accelerate config not found at {accelerate_config_path}"

191-207: Accelerate test may exceed 240s with 250 steps — consider overriding or bumping timeout.

Add a smaller step cap to keep tests snappy and deterministic.

         str(train_py),
         "--config-name",
         "L0_sanity.yaml",
         f"model_tag={model_tag}",
         f"trainer.output_dir={tmp_path}",
         "trainer.do_eval=False",
+        "stop_after_n_steps=50",

Alternatively, increase timeout below to e.g. 600.


131-131: Update stale comment to match current step count.

Reflects stop_after_n_steps=4.

-    # Remove the checkpoint-10 and checkpoint-last directories
+    # Remove the checkpoint-4 and checkpoint-last directories
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between d1801d6 and 32fda43.

📒 Files selected for processing (3)
  • recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1 hunks)
  • recipes/esm2_accelerate/test_train.py (7 hunks)
  • recipes/esm2_accelerate/train.py (3 hunks)
🔇 Additional comments (4)
recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (2)

3-3: Hydra: Good addition of _self_ to defaults.

Ensures local config precedence and avoids surprising overrides.


6-6: Sanity config now runs 250 steps — verify CI runtime or add test overrides.

At 250 steps with eval/save at 1000, default runs won’t checkpoint/eval and may push past the 240s subprocess timeout used in tests. Either:

  • bump the test timeout, or
  • pass a smaller stop_after_n_steps override in accelerate tests (see suggested diffs in test file comments).

Also applies to: 12-17

recipes/esm2_accelerate/train.py (1)

42-46: Verify from_config kwargs support for your Transformers version.

Some versions don’t accept trust_remote_code/torch_dtype in AutoModelForMaskedLM.from_config(...) and will raise TypeError. If unsupported, either drop the extra kwargs or switch to from_pretrained(...).

Option A (keep random init, remove unsupported kwargs):

-    model = AutoModelForMaskedLM.from_config(
-        config,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16,
-    )
+    model = AutoModelForMaskedLM.from_config(config)
+    # Optionally cast after init if needed:
+    # model = model.to(torch.bfloat16)

Option B (load pretrained weights):

-    model = AutoModelForMaskedLM.from_config(
-        config,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16,
-    )
+    model = AutoModelForMaskedLM.from_pretrained(
+        args.model_tag,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+    )
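
If the recipe wants to stay on from_config, a version-tolerant variant of Option A could look like the sketch below (build_model is a hypothetical helper, and the kwarg support is the assumption being hedged):

    import torch
    from transformers import AutoConfig, AutoModelForMaskedLM

    def build_model(model_tag: str):
        config = AutoConfig.from_pretrained(model_tag, trust_remote_code=True)
        try:
            # Newer Transformers releases forward these kwargs through from_config.
            return AutoModelForMaskedLM.from_config(
                config, trust_remote_code=True, torch_dtype=torch.bfloat16
            )
        except TypeError:
            # Older releases reject the extra kwargs: build with defaults, then cast.
            return AutoModelForMaskedLM.from_config(config).to(torch.bfloat16)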
recipes/esm2_accelerate/test_train.py (1)

223-233: Good check on parsed train_loss.

Nice visibility and clear failure output.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn force-pushed the pstjohn/esm2-accelerate-partial-conv branch from f48cd3f to 7bea7f2 on September 8, 2025 at 13:36
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1)

6-6: 250 steps may be heavy for a “sanity” preset.

If the intent is a long sanity run, fine; otherwise consider keeping this small (e.g., 4–20) and rely on overrides for longer tests, to avoid accidental long CI/dev runs.

Can you confirm CI/test jobs always override stop_after_n_steps back to a low value as shown in test_train.py?

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32fda43 and 7bea7f2.

📒 Files selected for processing (2)
  • recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1 hunks)
  • recipes/esm2_accelerate/test_train.py (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • recipes/esm2_accelerate/test_train.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: unit-tests (recipes/esm2_accelerate, svcbionemo023/bionemo-framework:pytorch25.06-py3-squashed-zstd)
🔇 Additional comments (3)
recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (3)

12-14: High save/eval intervals effectively disable them in a 250-step run.

Likely intentional to speed L0. Just ensure any resume/checkpoint tests explicitly override save_steps (they do in tests) so artifacts exist.

Double-check no downstream tooling assumes at least one eval/checkpoint occurs when using this preset.


3-3: Confirm hydra-core version for _self_ semantics
No hydra-core or omegaconf entry was found in requirements*.txt or pyproject.toml; ensure the repo is targeting Hydra 1.2+ so _self_ behaves as intended.


17-17: warmup_steps override safe—no warmup_ratio conflict
No warmup_ratio found in any scheduler config; overriding warmup_steps to 0 safely disables warmup.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn added this pull request to the merge queue Sep 8, 2025
Merged via the queue into NVIDIA:main with commit 0d30652 Sep 8, 2025
15 checks passed
@pstjohn pstjohn deleted the pstjohn/esm2-accelerate-partial-conv branch September 8, 2025 17:39