add partial conv tests to esm2_accelerate recipe #1122
Conversation
Caution: Review failed. An error occurred during the review process. Please try again later.
Actionable comments posted: 0
Caution: Some comments are outside the diff and can't be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
recipes/esm2_accelerate/test_train.py (1)
260-286: Remove stray `breakpoint()` that will hang CI. This will stall pytest in non-interactive runs.

```diff
-    final_train_loss = extract_final_train_loss(combined_output)
-    breakpoint()
+    final_train_loss = extract_final_train_loss(combined_output)
```

Also consider the same `stop_after_n_steps=50` override here:

```diff
     f"trainer.output_dir={tmp_path}",
     "trainer.do_eval=False",
+    "stop_after_n_steps=50",
```

Also applies to: 293-299
🧹 Nitpick comments (7)
recipes/esm2_accelerate/train.py (2)
67-78: Resume logic + metrics saving looks solid; also save trainer state.

Add `trainer.save_state()` so RNG/optimizer/scheduler states are captured alongside `checkpoint-last` for more robust resumes.

```diff
 train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
 logger.info("Training complete. Metrics: %s", train_result.metrics)
 trainer.save_metrics("train", train_result.metrics)
 trainer.save_model(str(Path(training_args.output_dir) / "checkpoint-last"))
+trainer.save_state()
```
79-83: Evaluation gate may skip eval unintentionally.

Keying off `do_eval` alone can suppress eval even when `evaluation_strategy != "no"`. Consider checking the strategy or the presence of `eval_dataset`.

```diff
-if training_args.do_eval:
+# Run eval when explicitly requested or when the strategy isn't "no"
+from transformers.training_args import IntervalStrategy
+if training_args.do_eval or getattr(training_args, "evaluation_strategy", IntervalStrategy.NO) != IntervalStrategy.NO:
     eval_result = trainer.evaluate()
     logger.info("Evaluation complete. Metrics: %s", eval_result)
     trainer.save_metrics("eval", eval_result)
```

recipes/esm2_accelerate/test_train.py (5)
44-74: Harden `train_loss` parsing (support scientific notation and varied formats).

Broaden the regex to avoid false negatives (a standalone sketch of the broadened pattern follows the diff).

```diff
-    pattern = r'\{[^{}]*[\'"]train_loss[\'"]:\s*([0-9.]+)[^{}]*\}'
+    # number: optional sign, decimal, optional exponent
+    number = r'[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
+    pattern = rf'\{{[^{{}}]*[\'"]train_loss[\'"]:\s*({number})[^{{}}]*\}}'
     ...
-    simple_pattern = r'[\'"]train_loss[\'"]:\s*([0-9.]+)'
+    simple_pattern = rf'[\'"]train_loss[\'"]\s*[:=]\s*({number})'
```
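A minimal standalone sketch (not part of the review diff) showing how the broadened pattern handles scientific-notation losses; the sample log line is illustrative only.

```python
import re

# Optional sign, decimal, optional exponent -- mirrors the suggested `number` fragment.
number = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
pattern = rf'\{{[^{{}}]*[\'"]train_loss[\'"]:\s*({number})[^{{}}]*\}}'

# Hypothetical trainer output line with a scientific-notation loss.
sample = "{'train_runtime': 12.3, 'train_loss': 2.5e-03, 'epoch': 1.0}"

match = re.search(pattern, sample)
assert match is not None
print(float(match.group(1)))  # 0.0025
```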
16-23: Avoid port collisions: pick an actually free port instead of a random range.

Bind to port 0 to get a free port and reuse it. Apply in env and CLI.

```diff
 import os
-import random
+import random
+import socket
 ...
-    monkeypatch.setenv("MASTER_PORT", f"{random.randint(20000, 40000)}")
+    def _get_free_port():
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("127.0.0.1", 0))
+            return s.getsockname()[1]
+    monkeypatch.setenv("MASTER_PORT", str(_get_free_port()))
 ...
-    "--main_process_port",
-    f"{random.randint(20000, 40000)}",
+    "--main_process_port",
+    str(_get_free_port()),
 ...
-    "--main_process_port",
-    f"{random.randint(20000, 40000)}",
+    "--main_process_port",
+    str(_get_free_port()),
```

Also applies to: 87-87, 199-201, 269-271
188-189: Fix assertion message to reference the Accelerate config.

Minor clarity nit.

```diff
-    assert accelerate_config_path.exists(), f"deepspeed_config.yaml not found at {accelerate_config_path}"
+    assert accelerate_config_path.exists(), f"accelerate config not found at {accelerate_config_path}"
```
191-207: Accelerate test may exceed 240s with 250 steps; consider overriding or bumping the timeout.

Add a smaller step cap to keep tests snappy and deterministic.

```diff
     str(train_py),
     "--config-name",
     "L0_sanity.yaml",
     f"model_tag={model_tag}",
     f"trainer.output_dir={tmp_path}",
     "trainer.do_eval=False",
+    "stop_after_n_steps=50",
```

Alternatively, increase `timeout` below to e.g. 600.
131-131: Update stale comment to match the current step count.

Reflects `stop_after_n_steps=4`.

```diff
-    # Remove the checkpoint-10 and checkpoint-last directories
+    # Remove the checkpoint-4 and checkpoint-last directories
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (3)
- `recipes/esm2_accelerate/hydra_config/L0_sanity.yaml` (1 hunks)
- `recipes/esm2_accelerate/test_train.py` (7 hunks)
- `recipes/esm2_accelerate/train.py` (3 hunks)
🔇 Additional comments (4)
recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (2)
3-3: Hydra: Good addition of `_self_` to defaults. Ensures local config precedence and avoids surprising overrides.
6-6: Sanity config now runs 250 steps; verify CI runtime or add test overrides. At 250 steps with eval/save at 1000, default runs won't checkpoint/eval and may push past the 240s subprocess timeout used in tests. Either:
- bump the test timeout, or
- pass a smaller `stop_after_n_steps` override in accelerate tests (see suggested diffs in test file comments).
Also applies to: 12-17
recipes/esm2_accelerate/train.py (1)
42-46: Verify `from_config` kwargs support for your Transformers version.

Some versions don't accept `trust_remote_code`/`torch_dtype` in `AutoModelForMaskedLM.from_config(...)` and will raise `TypeError`. If unsupported, either drop the extra kwargs or switch to `from_pretrained(...)`; a version-agnostic sketch follows the two options below.

Option A (keep random init, remove unsupported kwargs):

```diff
-    model = AutoModelForMaskedLM.from_config(
-        config,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16,
-    )
+    model = AutoModelForMaskedLM.from_config(config)
+    # Optionally cast after init if needed:
+    # model = model.to(torch.bfloat16)
```

Option B (load pretrained weights):

```diff
-    model = AutoModelForMaskedLM.from_config(
-        config,
-        trust_remote_code=True,
-        torch_dtype=torch.bfloat16,
-    )
+    model = AutoModelForMaskedLM.from_pretrained(
+        args.model_tag,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16,
+    )
```
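A minimal, hedged sketch of the version-agnostic fallback mentioned above (not from the review diff): it tries the extra kwargs and falls back to a plain `from_config` plus an explicit dtype cast if the installed transformers version rejects them. The `build_model` helper name is hypothetical.

```python
import torch
from transformers import AutoConfig, AutoModelForMaskedLM


def build_model(model_tag: str):
    """Randomly initialized MLM model; tolerant of older transformers versions."""
    config = AutoConfig.from_pretrained(model_tag, trust_remote_code=True)
    try:
        # Newer transformers accept these kwargs on from_config(...).
        return AutoModelForMaskedLM.from_config(
            config, trust_remote_code=True, torch_dtype=torch.bfloat16
        )
    except TypeError:
        # Older versions reject the extra kwargs: build plainly, then cast.
        return AutoModelForMaskedLM.from_config(config).to(torch.bfloat16)
```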
recipes/esm2_accelerate/test_train.py (1)

223-233: Good check on parsed `train_loss`. Nice visibility and clear failure output.
Force-pushed from f48cd3f to 7bea7f2.
Actionable comments posted: 0
🧹 Nitpick comments (1)
recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (1)
6-6: 250 steps may be heavy for a “sanity” preset. If the intent is a long sanity run, fine; otherwise consider keeping this small (e.g., 4–20) and relying on overrides for longer tests, to avoid accidental long CI/dev runs.
Can you confirm CI/test jobs always override `stop_after_n_steps` back to a low value, as shown in test_train.py?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `recipes/esm2_accelerate/hydra_config/L0_sanity.yaml` (1 hunks)
- `recipes/esm2_accelerate/test_train.py` (7 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- recipes/esm2_accelerate/test_train.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: unit-tests (recipes/esm2_accelerate, svcbionemo023/bionemo-framework:pytorch25.06-py3-squashed-zstd)
🔇 Additional comments (3)
recipes/esm2_accelerate/hydra_config/L0_sanity.yaml (3)
12-14: High save/eval intervals effectively disable them in a 250-step run. Likely intentional to speed up L0. Just ensure any resume/checkpoint tests explicitly override `save_steps` (they do in tests) so artifacts exist.

Double-check that no downstream tooling assumes at least one eval/checkpoint occurs when using this preset.
3-3: Confirm the hydra-core version for `_self_` semantics.

No `hydra-core` or `omegaconf` entry was found in `requirements*.txt` or `pyproject.toml`; ensure the repo is targeting Hydra 1.2+ so `_self_` behaves as intended.
17-17: `warmup_steps` override is safe; no `warmup_ratio` conflict.

No `warmup_ratio` found in any scheduler config; overriding `warmup_steps` to 0 safely disables warmup.
Adds partial convergence tests to the esm2_accelerate recipe, similar to those used in the mfsdp recipe.
Summary by CodeRabbit