Add ESM-2 model gradient tests#1077
Conversation
jomitchellnv
left a comment
There was a problem hiding this comment.
some nits but overall looks good to me.
c822a99 to
96db758
Compare
WalkthroughReplaces uv-based editable installs with plain pip in two Dockerfiles, updates esm2 dependencies (megatron moved to a git source, adds megatron-fsdp, enables transformer_engine, unpins nemo), extracts test fixture logic into Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant Pytest
participant TestModule
participant Torchrun
participant Worker0 as Worker 0
participant Worker1 as Worker 1
Pytest->>TestModule: invoke distributed test
TestModule->>Torchrun: launch torchrun (nproc, strategy, backend)
rect rgba(200,230,255,0.20)
note right of Torchrun: spawn worker processes
Torchrun->>Worker0: start process
Torchrun->>Worker1: start process
end
rect rgba(220,255,220,0.18)
note over Worker0,Worker1: init model, data, strategy (DDP/FSDP)
Worker0->>Worker0: forward + backward, capture grads
Worker1->>Worker1: forward + backward, capture grads
Worker0-->>Worker1: sync per strategy
end
Worker0-->>Torchrun: emit logs
Worker1-->>Torchrun: emit logs
rect rgba(255,245,200,0.15)
Torchrun->>TestModule: return exit code + output
TestModule->>TestModule: compare outputs & grads
TestModule-->>Pytest: report pass/fail (include stdout/stderr on failure)
end
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI Review profile: CHILL Plan: Pro 💡 Knowledge Base configuration:
You can enable these sources in your CodeRabbit configuration. 📒 Files selected for processing (5)
🚧 Files skipped from review as they are similar to previous changes (5)
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Actionable comments posted: 3
♻️ Duplicate comments (1)
models/esm2/tests/test_distributed_strategies.py (1)
29-33: Consider a decorator that encodes the exact GPU count needed.Matches prior feedback; improves clarity vs. a generic “requires at least 2 GPUs”.
Example:
def requires_ngpus(n: int): return pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < n, reason=f"Test requires {n} GPUs")
🧹 Nitpick comments (9)
models/amplify/Dockerfile (1)
7-11: Remove unused uv and fix cache mount to pip cache.You no longer use uv, but still copy the binary and mount its cache path. This bloats the image and misses pip’s cache.
Apply:
-COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ ... -RUN --mount=type=cache,target=/root/.cache/uv \ - PIP_CONSTRAINT= pip install -e . +RUN --mount=type=cache,target=/root/.cache/pip \ + PIP_CONSTRAINT= pip install -e .models/esm2/Dockerfile (1)
2-6: Drop uv and switch cache mount to pip cache.uv is no longer invoked; keep the image lean and benefit from pip’s cache.
Apply:
-COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ ... -RUN --mount=type=cache,target=/root/.cache/uv \ - PIP_CONTRAINT= pip install -e . +RUN --mount=type=cache,target=/root/.cache/pip \ + PIP_CONSTRAINT= pip install -e .models/esm2/pyproject.toml (1)
21-21: Avoid declaring bare 'torch' in a CUDA container unless constrained.Unpinned 'torch' can trigger CPU wheel resolution in some contexts. Rely on the base image or add an upper/lower bound.
If you keep it, constrain to the container’s Torch (e.g., >=2.4,<2.6) or drop it and document the requirement in the Dockerfile.
models/esm2/tests/conftest.py (1)
91-93: Minor: add a docstring and type hint to the fixture for clarity.Helps IDEs and future readers; no behavior change.
-@pytest.fixture -def input_data(tokenizer): - return get_input_data(tokenizer) +@pytest.fixture +def input_data(tokenizer) -> dict: + """Single batch of tokenized ESM2 inputs for tests.""" + return get_input_data(tokenizer)models/esm2/tests/test_distributed_strategies.py (5)
235-247: Use tolerances in assert_close to reduce flakiness across backends.Small numeric drift across DDP/FSDP/TE is common. Set rtol/atol explicitly.
- torch.testing.assert_close(fsdp.loss, ddp.loss, msg=lambda x: f"Loss mismatch: {x}") - torch.testing.assert_close(fsdp.logits, ddp.logits, msg=lambda x: f"Logits mismatch: {x}") + torch.testing.assert_close(fsdp.loss, ddp.loss, rtol=1e-4, atol=1e-6, msg=lambda x: f"Loss mismatch: {x}") + torch.testing.assert_close(fsdp.logits, ddp.logits, rtol=1e-4, atol=1e-6, msg=lambda x: f"Logits mismatch: {x}") ... - torch.testing.assert_close(ddp_grad, fsdp_grad, msg=lambda x: f"Gradient mismatch for {name}: {x}") + torch.testing.assert_close(ddp_grad, fsdp_grad, rtol=1e-4, atol=1e-6, msg=lambda x: f"Gradient mismatch for {name}: {x}")
249-250: Guard the negative test with equal shape and finite checks.If a param is all zeros on both paths or contains NaNs, this can give false signals.
- assert not torch.allclose(ddp_grad, torch.roll(fsdp_grad, -1, -1)) + assert ddp_grad.shape == fsdp_grad.shape + assert torch.isfinite(ddp_grad).all() and torch.isfinite(fsdp_grad).all() + assert not torch.allclose(ddp_grad, torch.roll(fsdp_grad, -1, -1))
54-61: Increase subprocess timeout to reduce CI flakiness.240s is tight on shared runners for first-time HF/TE model loads. Suggest 600s.
- timeout=240, + timeout=600,Apply similarly in the multi-GPU test (Lines 88-95).
102-115: Remove duplicate import of argparse.argparse is imported at Line 16 already.
- import argparse
221-233: Ensure process group teardown on exceptions.If an assertion fails before destroy_process_group(), ranks can hang. Use try/finally.
- ddp, ddp_grads = run_forward_backward(...) - fsdp, fsdp_grads = run_forward_backward(...) - ... - dist.destroy_process_group() + try: + ddp, ddp_grads = run_forward_backward(...) + fsdp, fsdp_grads = run_forward_backward(...) + ... + finally: + dist.destroy_process_group()
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
models/amplify/Dockerfile(1 hunks)models/esm2/Dockerfile(1 hunks)models/esm2/pyproject.toml(1 hunks)models/esm2/tests/conftest.py(2 hunks)models/esm2/tests/test_distributed_strategies.py(1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
models/esm2/tests/test_distributed_strategies.py (1)
models/esm2/tests/conftest.py (2)
get_input_data(37-88)input_data(92-93)
🔇 Additional comments (3)
models/esm2/tests/conftest.py (1)
37-89: Good extraction of reusable data prep into a helper.Seeded tokenization, batching, and collator setup are clean and deterministic.
models/esm2/tests/test_distributed_strategies.py (2)
151-165: Tokenizer/model pairing: verify TE and HF variants use identical tokenization.When --test_te is set, the model comes from "nvidia/esm2..." but inputs are tokenized with "facebook/esm2...". If vocab/tokenization differ, comparisons will be invalid.
If needed, select tokenizer per backend:
tok_id = "nvidia/esm2_t6_8M_UR50D" if use_te else "facebook/esm2_t6_8M_UR50D" input_data = get_input_data(AutoTokenizer.from_pretrained(tok_id))
172-179: DDP does not accept device_mesh; this likely raises TypeError.Standard torch.nn.parallel.DistributedDataParallel has no device_mesh kwarg. Pass process_group if needed; keep device_ids/output_device only.
- model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[dist_config.local_rank], - output_device=dist_config.local_rank, - device_mesh=device_mesh["dp"], - ) + model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[dist_config.local_rank], + output_device=dist_config.local_rank, + )⛔ Skipped due to learnings
Learnt from: pstjohn PR: NVIDIA/bionemo-framework#1078 File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108 Timestamp: 2025-08-28T16:40:04.315Z Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
f5a5c88 to
4209f20
Compare
Adds gradient tests to ensure the gradients we get from ddp, fsdp2, and nvfsdp are consistent
Summary by CodeRabbit