Add ESM-2 model gradient tests #1077

Merged
pstjohn merged 5 commits into NVIDIA-BioNeMo:main from pstjohn:pstjohn/add-esm2-gradient-tests
Sep 2, 2025

Conversation

Collaborator

@pstjohn pstjohn commented Aug 26, 2025

Adds gradient tests to ensure that the gradients we get from DDP, FSDP2, and nvFSDP are consistent.

Summary by CodeRabbit

  • Chores
    • Simplified image builds by switching to standard pip for editable installs.
  • Dependencies
    • Updated core ML dependencies and enabled/unpinned select components for better compatibility.
  • Tests
    • Added comprehensive distributed-training validation tests comparing strategies across single- and multi-GPU setups with automatic GPU gating and detailed checks.
  • Refactor
    • Extracted test data generation into a reusable helper to streamline test setup.

Comment thread models/esm2/tests/test_distributed_strategies.py Outdated
@pstjohn pstjohn requested a review from ohadmo as a code owner August 27, 2025 22:46
Comment thread models/esm2/tests/test_distributed_strategies.py
Comment thread models/esm2/tests/test_distributed_strategies.py
Comment thread models/esm2/tests/test_distributed_strategies.py
Comment thread models/esm2/tests/test_distributed_strategies.py Outdated
Comment thread models/esm2/tests/test_distributed_strategies.py
Collaborator

@jomitchellnv jomitchellnv left a comment

some nits but overall looks good to me.

@pstjohn pstjohn force-pushed the pstjohn/add-esm2-gradient-tests branch from c822a99 to 96db758 Compare September 2, 2025 19:45
Contributor

coderabbitai Bot commented Sep 2, 2025

Walkthrough

Replaces uv-based editable installs with plain pip in two Dockerfiles, updates esm2 dependencies (megatron moved to a git source, adds megatron-fsdp, enables transformer_engine, unpins nemo), extracts test fixture logic into get_input_data, and adds a new distributed strategies test module comparing DDP and FSDP via torchrun.

Changes

| Cohort / File(s) | Summary of Changes |
| --- | --- |
| Docker install command changes: models/amplify/Dockerfile, models/esm2/Dockerfile | Replace uv pip install --system --break-system-packages -e . with PIP_CONSTRAINT= pip install -e .; remove uv wrapper and --system/--break-system-packages flags while preserving cache mount usage. |
| ESM2 dependency updates: models/esm2/pyproject.toml | Update dependencies: use git-based megatron-core source, add megatron-fsdp, unpin nemo_toolkit[lightning], and enable transformer_engine[pytorch]. |
| Test fixture refactor: models/esm2/tests/conftest.py | Extract data-generation logic into new get_input_data(tokenizer) helper; input_data pytest fixture now delegates to that helper; behavior unchanged. |
| Distributed strategies tests: models/esm2/tests/test_distributed_strategies.py | Add new test module that runs torchrun to compare DDP vs FSDP (single- and multi-GPU), parameterizes over strategies/backends, gates multi-GPU tests, prints stdout/stderr on failures, and includes a main workflow for end-to-end gradient/output comparisons. |

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Pytest
  participant TestModule
  participant Torchrun
  participant Worker0 as Worker 0
  participant Worker1 as Worker 1

  Pytest->>TestModule: invoke distributed test
  TestModule->>Torchrun: launch torchrun (nproc, strategy, backend)
  rect rgba(200,230,255,0.20)
    note right of Torchrun: spawn worker processes
    Torchrun->>Worker0: start process
    Torchrun->>Worker1: start process
  end

  rect rgba(220,255,220,0.18)
    note over Worker0,Worker1: init model, data, strategy (DDP/FSDP)
    Worker0->>Worker0: forward + backward, capture grads
    Worker1->>Worker1: forward + backward, capture grads
    Worker0-->>Worker1: sync per strategy
  end

  Worker0-->>Torchrun: emit logs
  Worker1-->>Torchrun: emit logs

  rect rgba(255,245,200,0.15)
    Torchrun->>TestModule: return exit code + output
    TestModule->>TestModule: compare outputs & grads
    TestModule-->>Pytest: report pass/fail (include stdout/stderr on failure)
  end
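
The launch-and-report flow in the diagram can be sketched roughly as follows. This is a minimal illustration, not the code from the PR: the helper names build_torchrun_cmd and run_and_report, the worker script path, and the --strategy script flag are all assumptions made for the example. Only the torchrun options --standalone and --nproc_per_node are real launcher flags.

```python
import subprocess
import sys


def build_torchrun_cmd(script, nproc, strategy):
    """Build a torchrun command line; --strategy is a hypothetical script flag."""
    return [
        "torchrun",
        "--standalone",
        f"--nproc_per_node={nproc}",
        script,
        "--strategy",
        strategy,
    ]


def run_and_report(cmd, timeout=240):
    """Run cmd and echo the captured output before failing on a non-zero exit."""
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    if result.returncode != 0:
        # Surface worker logs so a failing rank can be diagnosed from the test output.
        print(f"STDOUT:\n{result.stdout}")
        print(f"STDERR:\n{result.stderr}")
        raise AssertionError(f"Command failed with exit code {result.returncode}")
    return result
```

A test body would then call run_and_report(build_torchrun_cmd(...)); subprocess.run raises subprocess.TimeoutExpired if the workers hang past the timeout, which is why the timeout value matters on slow runners.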

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

A rabbit swaps uv for pip with a nimble little bound,
Megatron now comes from git, new engines all around.
I pulled the fixture out, then launched torchrun's parade,
DDP and FSDP, in gradients they waltz and trade. 🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f5a5c88 and 4209f20.

📒 Files selected for processing (5)
  • models/amplify/Dockerfile (1 hunks)
  • models/esm2/Dockerfile (1 hunks)
  • models/esm2/pyproject.toml (1 hunks)
  • models/esm2/tests/conftest.py (2 hunks)
  • models/esm2/tests/test_distributed_strategies.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • models/esm2/pyproject.toml
  • models/amplify/Dockerfile
  • models/esm2/tests/test_distributed_strategies.py
  • models/esm2/Dockerfile
  • models/esm2/tests/conftest.py

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
models/esm2/tests/test_distributed_strategies.py (1)

29-33: Consider a decorator that encodes the exact GPU count needed.

Matches prior feedback; improves clarity vs. a generic “requires at least 2 GPUs”.

Example:

import pytest
import torch

def requires_ngpus(n: int):
    return pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < n,
                              reason=f"Test requires {n} GPUs")
🧹 Nitpick comments (9)
models/amplify/Dockerfile (1)

7-11: Remove unused uv and fix cache mount to pip cache.

You no longer use uv, but still copy the binary and mount its cache path. This bloats the image and misses pip’s cache.

Apply:

-COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 ...
-RUN --mount=type=cache,target=/root/.cache/uv \
-    PIP_CONSTRAINT= pip install -e .
+RUN --mount=type=cache,target=/root/.cache/pip \
+    PIP_CONSTRAINT= pip install -e .
models/esm2/Dockerfile (1)

2-6: Drop uv and switch cache mount to pip cache.

uv is no longer invoked; keep the image lean and benefit from pip’s cache.

Apply:

-COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
 ...
-RUN --mount=type=cache,target=/root/.cache/uv \
-    PIP_CONTRAINT= pip install -e .
+RUN --mount=type=cache,target=/root/.cache/pip \
+    PIP_CONSTRAINT= pip install -e .
models/esm2/pyproject.toml (1)

21-21: Avoid declaring bare 'torch' in a CUDA container unless constrained.

Unpinned 'torch' can trigger CPU wheel resolution in some contexts. Rely on the base image or add an upper/lower bound.

If you keep it, constrain to the container’s Torch (e.g., >=2.4,<2.6) or drop it and document the requirement in the Dockerfile.

models/esm2/tests/conftest.py (1)

91-93: Minor: add a docstring and type hint to the fixture for clarity.

Helps IDEs and future readers; no behavior change.

-@pytest.fixture
-def input_data(tokenizer):
-    return get_input_data(tokenizer)
+@pytest.fixture
+def input_data(tokenizer) -> dict:
+    """Single batch of tokenized ESM2 inputs for tests."""
+    return get_input_data(tokenizer)
models/esm2/tests/test_distributed_strategies.py (5)

235-247: Use tolerances in assert_close to reduce flakiness across backends.

Small numeric drift across DDP/FSDP/TE is common. Set rtol/atol explicitly.

-    torch.testing.assert_close(fsdp.loss, ddp.loss, msg=lambda x: f"Loss mismatch: {x}")
-    torch.testing.assert_close(fsdp.logits, ddp.logits, msg=lambda x: f"Logits mismatch: {x}")
+    torch.testing.assert_close(fsdp.loss, ddp.loss, rtol=1e-4, atol=1e-6, msg=lambda x: f"Loss mismatch: {x}")
+    torch.testing.assert_close(fsdp.logits, ddp.logits, rtol=1e-4, atol=1e-6, msg=lambda x: f"Logits mismatch: {x}")
 ...
-        torch.testing.assert_close(ddp_grad, fsdp_grad, msg=lambda x: f"Gradient mismatch for {name}: {x}")
+        torch.testing.assert_close(ddp_grad, fsdp_grad, rtol=1e-4, atol=1e-6, msg=lambda x: f"Gradient mismatch for {name}: {x}")
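
For readers unfamiliar with how rtol and atol combine, the check that torch.testing.assert_close applies per element is |actual - expected| <= atol + rtol * |expected|. The pure-Python sketch below mirrors that criterion on plain lists of floats so it runs without a GPU. The helper names is_close and mismatched_grads are hypothetical, and the tolerance defaults are simply the values suggested above, not values taken from the merged test.

```python
import math


def is_close(actual, expected, rtol=1e-4, atol=1e-6):
    """Elementwise criterion: |actual - expected| <= atol + rtol * |expected|."""
    return math.fabs(actual - expected) <= atol + rtol * math.fabs(expected)


def mismatched_grads(ref_grads, test_grads, rtol=1e-4, atol=1e-6):
    """Return the parameter names whose gradients differ beyond tolerance."""
    bad = []
    for name, ref in ref_grads.items():
        test = test_grads[name]
        same_len = len(ref) == len(test)
        if not same_len or not all(is_close(t, r, rtol, atol) for t, r in zip(test, ref)):
            bad.append(name)
    return bad
```

Because the relative term scales with the magnitude of the expected value, large gradients tolerate proportionally larger absolute drift, while atol covers entries near zero.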

249-250: Guard the negative test with equal shape and finite checks.

If a param is all zeros on both paths or contains NaNs, this can give false signals.

-        assert not torch.allclose(ddp_grad, torch.roll(fsdp_grad, -1, -1))
+        assert ddp_grad.shape == fsdp_grad.shape
+        assert torch.isfinite(ddp_grad).all() and torch.isfinite(fsdp_grad).all()
+        assert not torch.allclose(ddp_grad, torch.roll(fsdp_grad, -1, -1))
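
The two failure modes mentioned here can be reproduced without torch. The sketch below uses plain lists; roll and allclose are hypothetical stand-ins that imitate torch.roll on a 1-D tensor and torch.allclose with its default tolerances (rtol=1e-5, atol=1e-8), and passes_negative_check is a made-up name for the assertion being discussed.

```python
def roll(values, shift):
    """Circularly shift a list, imitating torch.roll on a 1-D tensor."""
    n = len(values)
    if n == 0:
        return []
    shift %= n
    return values[-shift:] + values[:-shift] if shift else list(values)


def allclose(a, b, rtol=1e-5, atol=1e-8):
    """Imitate torch.allclose: every pair must satisfy |x - y| <= atol + rtol * |y|."""
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )


def passes_negative_check(grad):
    """True when the gradient is NOT close to a rolled copy of itself."""
    return not allclose(grad, roll(grad, -1))
```

An all-zero gradient equals its own rolled copy, so the negative check fails even though nothing is wrong, whereas a gradient containing NaN is never close to anything, so the check passes while hiding a real problem. That is the motivation for the finite check in the suggestion.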

54-61: Increase subprocess timeout to reduce CI flakiness.

240s is tight on shared runners for first-time HF/TE model loads. Suggest 600s.

-        timeout=240,
+        timeout=600,

Apply similarly in the multi-GPU test (Lines 88-95).


102-115: Remove duplicate import of argparse.

argparse is imported at Line 16 already.

-    import argparse

221-233: Ensure process group teardown on exceptions.

If an assertion fails before destroy_process_group(), ranks can hang. Use try/finally.

-    ddp, ddp_grads = run_forward_backward(...)
-    fsdp, fsdp_grads = run_forward_backward(...)
-    ...
-    dist.destroy_process_group()
+    try:
+        ddp, ddp_grads = run_forward_backward(...)
+        fsdp, fsdp_grads = run_forward_backward(...)
+        ...
+    finally:
+        dist.destroy_process_group()
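
An alternative to an inline try/finally is to wrap setup and teardown in a context manager. The sketch below is generic on purpose so that it runs anywhere: process_group is a hypothetical helper, and the init and destroy callables are assumptions standing in for torch.distributed.init_process_group and torch.distributed.destroy_process_group.

```python
from contextlib import contextmanager


@contextmanager
def process_group(init, destroy):
    """Call init() on entry and always call destroy() on exit, even on error."""
    init()
    try:
        yield
    finally:
        # Runs whether the body returned normally or raised (e.g. a failed assertion).
        destroy()


# Intended usage (assumes torch.distributed is imported as dist):
#
#     with process_group(lambda: dist.init_process_group(backend="nccl"),
#                        dist.destroy_process_group):
#         ...  # forward/backward passes and assertions
```

Either form gives the same guarantee: a failed assertion on one rank still reaches the teardown call instead of leaving the other ranks waiting.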
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a29272f and 96db758.

📒 Files selected for processing (5)
  • models/amplify/Dockerfile (1 hunks)
  • models/esm2/Dockerfile (1 hunks)
  • models/esm2/pyproject.toml (1 hunks)
  • models/esm2/tests/conftest.py (2 hunks)
  • models/esm2/tests/test_distributed_strategies.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
models/esm2/tests/test_distributed_strategies.py (1)
models/esm2/tests/conftest.py (2)
  • get_input_data (37-88)
  • input_data (92-93)
🔇 Additional comments (3)
models/esm2/tests/conftest.py (1)

37-89: Good extraction of reusable data prep into a helper.

Seeded tokenization, batching, and collator setup are clean and deterministic.

models/esm2/tests/test_distributed_strategies.py (2)

151-165: Tokenizer/model pairing: verify TE and HF variants use identical tokenization.

When --test_te is set, the model comes from "nvidia/esm2..." but inputs are tokenized with "facebook/esm2...". If vocab/tokenization differ, comparisons will be invalid.

If needed, select tokenizer per backend:

tok_id = "nvidia/esm2_t6_8M_UR50D" if use_te else "facebook/esm2_t6_8M_UR50D"
input_data = get_input_data(AutoTokenizer.from_pretrained(tok_id))

172-179: DDP does not accept device_mesh; this likely raises TypeError.

Standard torch.nn.parallel.DistributedDataParallel has no device_mesh kwarg. Pass process_group if needed; keep device_ids/output_device only.

-            model = torch.nn.parallel.DistributedDataParallel(
-                model,
-                device_ids=[dist_config.local_rank],
-                output_device=dist_config.local_rank,
-                device_mesh=device_mesh["dp"],
-            )
+            model = torch.nn.parallel.DistributedDataParallel(
+                model,
+                device_ids=[dist_config.local_rank],
+                output_device=dist_config.local_rank,
+            )
⛔ Skipped due to learnings
Learnt from: pstjohn
PR: NVIDIA/bionemo-framework#1078
File: recipes/esm2_native_te_mfsdp/train_ddp.py:103-108
Timestamp: 2025-08-28T16:40:04.315Z
Learning: PyTorch DistributedDataParallel constructor accepts a device_mesh parameter in recent versions, which supports advanced distributed training scenarios and nvFSDP configurations.
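
Since the bot's suggestion and the recorded learning disagree, one way to settle the question for a given environment is to inspect the constructor signature rather than rely on either claim. The sketch below is a generic helper (accepts_kwarg is a hypothetical name); in practice it would be called as accepts_kwarg(torch.nn.parallel.DistributedDataParallel, "device_mesh") against the installed PyTorch.

```python
import inspect


def accepts_kwarg(func, name):
    """Return True if func declares a parameter called name or takes **kwargs."""
    params = inspect.signature(func).parameters
    if name in params:
        return True
    # A **kwargs catch-all would also accept the keyword without declaring it.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())
```

Passing a class works because inspect.signature resolves it to the constructor's parameters without self.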

Comment thread models/esm2/Dockerfile Outdated
Comment thread models/esm2/pyproject.toml Outdated
Comment thread models/esm2/tests/test_distributed_strategies.py
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/add-esm2-gradient-tests branch from f5a5c88 to 4209f20 Compare September 2, 2025 20:02
@pstjohn pstjohn enabled auto-merge September 2, 2025 20:03
@pstjohn pstjohn added this pull request to the merge queue Sep 2, 2025
Merged via the queue into NVIDIA-BioNeMo:main with commit 2d3d3c8 Sep 2, 2025
16 checks passed
@pstjohn pstjohn deleted the pstjohn/add-esm2-gradient-tests branch September 2, 2025 21:19

3 participants