
Add MXFP8 and NVFP4 quantization support to LLaMA3 #1500

Open
jomitchellnv wants to merge 6 commits into main from jm/mxfp8-nvfp4-llama3

Conversation

@jomitchellnv
Collaborator

@jomitchellnv jomitchellnv commented Mar 6, 2026

Summary

 - Add per-layer MXFP8 and NVFP4 quantization support to the LLaMA3 model and recipe, matching the existing ESM2 implementation
 - Extend `NVLlamaModel` with `set_recipes()` / `get_layer_autocast()` for mixed-precision per-layer control (FP8, FP4, BF16)
 - Add `quantization.py` module to the llama3_native_te recipe with layer precision resolution, regex generation for debug stats, and stats logging initialization
 - Integrate quantization into all three training scripts (`train_ddp.py`, `train_fsdp2.py`, `train_fsdp2_cp.py`)
 - Add comprehensive test coverage:
   - Recipe-level: 21 unit tests for `resolve_layer_precision`, `generate_layer_regex`, `update_quant_stats_config`
   - Training integration: FP8 training tests (FSDP2 BSHD/THD), FP8 stats logging tests (DDP/FSDP2, full/partial layers)
   - Model-level: 14 unit tests for `set_recipes` / `get_layer_autocast` (all precision combos, mixed, pickling)
   - Distributed: FP8 recipe attachment and state synchronization tests (single/multi-process, DDP/FSDP2)
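
The per-layer precision resolution summarized above can be sketched as follows. The name `resolve_layer_precision` comes from the PR, but the body is an illustrative reimplementation under assumed semantics (1-indexed layer lists, BF16 fallback for unclaimed layers), not the actual recipe code:

```python
def resolve_layer_precision(num_layers, fp8_enabled=False, fp4_enabled=False,
                            fp8_layers=None, fp4_layers=None):
    """Sketch: map each 1-indexed transformer layer to 'fp8', 'fp4', or 'bf16'."""
    all_layers = set(range(1, num_layers + 1))
    # If an explicit layer list is given, use it; otherwise the enabled flag
    # claims every layer for that precision.
    fp8 = set(fp8_layers) if fp8_layers is not None else (all_layers if fp8_enabled else set())
    fp4 = set(fp4_layers) if fp4_layers is not None else (all_layers if fp4_enabled else set())
    overlap = fp8 & fp4
    if overlap:
        raise ValueError(f"Layers assigned to both FP8 and FP4: {sorted(overlap)}")
    # Any layer not claimed by FP8 or FP4 falls back to BF16.
    return ["fp8" if i in fp8 else "fp4" if i in fp4 else "bf16"
            for i in range(1, num_layers + 1)]

print(resolve_layer_precision(4, fp8_enabled=True, fp8_layers=[1, 3]))
# ['fp8', 'bf16', 'fp8', 'bf16']
```

The overlap check mirrors the reviewer-discussed validation: a layer cannot be both FP8 and FP4, so conflicting assignments fail fast.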

Test plan

 - [ ] `pytest bionemo-recipes/models/llama3/tests/test_layer_quantization.py -v`
 - [ ] `pytest bionemo-recipes/models/llama3/tests/test_distributed_fp8.py -v` (requires FP8-capable GPU)
 - [ ] `pytest bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py -v`
 - [ ] `pytest bionemo-recipes/recipes/llama3_native_te/tests/test_train.py -v` (requires GPU)
 - [ ] `pre-commit run --all-files` passes

Usage

TODO: Add code snippet

Type of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Refactor
  • Documentation update
  • Other (please describe):

CI Pipeline Configuration

Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run.

  • ciflow:skip - Skip all CI tests for this PR
  • ciflow:notebooks - Run Jupyter notebooks execution tests for bionemo2
  • ciflow:slow - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2
  • ciflow:all - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2.
  • ciflow:all-recipes - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes.

Unit tests marked as @pytest.mark.multi_gpu or @pytest.mark.distributed are not run in the PR pipeline.

For more details, see CONTRIBUTING

Note

By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage.

Authorizing CI Runs

We use copy-pr-bot to manage authorization of CI
runs on NVIDIA's compute resources.

  • If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will
    automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123)
  • If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an
    /ok to test comment on the pull request to trigger CI. This will need to be done for each new commit.

Triggering Code Rabbit AI Review

To trigger a code review from CodeRabbit, comment on the pull request with one of its review commands; see https://docs.coderabbit.ai/reference/review-commands for the full list.

Pre-submit Checklist

  • I have tested these changes locally
  • I have updated the documentation accordingly
  • I have added/updated tests as needed
  • All existing tests pass successfully

Summary by CodeRabbit

  • New Features

    • Per-layer quantization: assign FP8/FP4/BF16 per transformer layer, attach runtime recipes, and nested autocast handling.
    • New quantization utilities for resolving layer precision, generating layer regexes, and initializing quant stats logging.
  • Documentation

    • Updated README, examples, and recipe docs for FP8/MXFP8/NVFP4, FP4 usage, and quantization stats configuration.
  • Tests

    • Added unit and distributed tests for layer quantization, recipe attachment/sync, regex/config updates, and stats logging.
  • Chores

    • Renamed fp8_stats_config → quant_stats_config; removed legacy FP8-only debugging module; updated training/perf logger flows.

@coderabbitai
Contributor

coderabbitai bot commented Mar 6, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c2b2fc4b-16bd-4942-b09d-d9be1b4a2281

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

This PR adds layer-wise quantization support (FP8/FP4/BF16): new config field for per-layer precision, recipe attachment APIs on TE models, per-layer autocast orchestration in forward, quantization utilities and logging init, training-script wiring, tests, docs/config updates, and removal of legacy FP8 debugging initialization.

Changes

Cohort / File(s) Summary
Core Model Changes
bionemo-recipes/models/llama3/modeling_llama_te.py, bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
Add layer_precision config field; add _fp8_recipe/_fp4_recipe attributes; new set_recipes() and get_layer_autocast(); change forward to use outer FP8 autocast with nested per-layer autocast contexts.
Quantization Utilities
bionemo-recipes/recipes/llama3_native_te/quantization.py
New helpers: generate_layer_regex(), update_quant_stats_config(), initialize_quant_stats_logging(), resolve_layer_precision() for per-layer assignment and stats config updates.
Training Pipelines
bionemo-recipes/recipes/llama3_native_te/train_ddp.py, .../train_fsdp2.py, .../train_fsdp2_cp.py
Resolve and store layer_precision, conditionally create FP8/FP4 recipes, wrap model init with TE quantized init when used, attach recipes via set_recipes(), initialize quant stats logging, remove per-call external TE autocast.
Config & Docs
bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml, .../README.md
Rename fp8_stats_config → quant_stats_config; add fp8_layers, fp4_layers, use_fp32_master_weights; doc updates for layer-wise FP8/FP4 and stats.
Stats & Debugging
bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml, .../fp8_debugging_stats.yaml, perf_logger.py, .../fp8_debugging.py
Add FP4/FP8 tensor-stats examples and expanded layer targets; remove legacy fp8_debugging.py; PerfLogger now reads quant_stats_config and gates debug API calls accordingly.
Tests — Model & Distributed
bionemo-recipes/models/llama3/tests/test_layer_quantization.py, bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
Add unit tests for set_recipes()/get_layer_autocast() and distributed FP8 recipe sync tests (torchrun orchestration).
Tests — Quantization & Train
bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py, .../tests/test_train.py, .../tests/test_perf_logger.py, .../tests/conftest.py
Add tests for resolve_layer_precision, regex generation, config mutation; update tests/configs to use quant_stats_config; add partial-layer stats logging sanity tests.
Removed
bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
Deleted legacy FP8 debugging initializer and its filesystem/logging/debug API wiring.
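
The `generate_layer_regex()` helper listed under Quantization Utilities above might look roughly like this sketch; the `model.layers.<n>.` module-naming convention is an assumption about the module tree, not taken from the PR diff:

```python
import re

def generate_layer_regex(layer_numbers):
    """Sketch: build a regex matching TE module names for the given layer indices.

    Assumes modules are named 'model.layers.<n>.<submodule>' (hypothetical)."""
    alternation = "|".join(str(n) for n in sorted(layer_numbers))
    return rf"model\.layers\.({alternation})\."

# Target debug stats at layers 0 and 2 only.
pattern = generate_layer_regex([0, 2])
assert re.search(pattern, "model.layers.2.self_attention.layernorm_qkv")
assert not re.search(pattern, "model.layers.1.mlp.fc1")
```

A regex like this is what the stats-logging config would consume to restrict tensor statistics to the FP8/FP4 subset of layers.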

Sequence Diagram(s)

sequenceDiagram
    participant Script as Training Script
    participant Config as Resolver
    participant Recipes as Recipe Factory
    participant Model as NVLlama Model
    participant Decoder as Decoder Layers
    participant TE as TransformerEngine Autocast

    Script->>Config: resolve_layer_precision(fp8_enabled, fp4_enabled, fp8_layers, fp4_layers)
    Config-->>Script: layer_precision list
    Script->>Recipes: create fp8_recipe / fp4_recipe (conditional)
    Recipes-->>Script: recipe objects or None
    Script->>Model: set_recipes(fp8_recipe, fp4_recipe)
    Model->>Model: store _fp8_recipe, _fp4_recipe and config.layer_precision
    Script->>Model: forward(input_ids,...)
    Model->>TE: enter outer FP8 autocast if _fp8_recipe
    TE->>Decoder: iterate layers
    Decoder->>Model: get_layer_autocast(layer_number)
    Model-->>Decoder: return per-layer context (nullcontext / FP4 autocast / BF16-disabled)
    Decoder->>TE: execute layer under returned autocast
    Decoder-->>Model: layer outputs
    Model-->>Script: logits / outputs
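The nested autocast selection in the diagram can be illustrated with a small Python sketch. The `fake_autocast` context manager is a stand-in for TransformerEngine's autocast contexts (the real code would enter `te.autocast` with the FP4 recipe, or a quantization-disabled autocast for BF16), and the list-based lookup mirrors the `layer_precision` field described above:

```python
from contextlib import contextmanager, nullcontext

entered = []

@contextmanager
def fake_autocast(kind):
    # Stand-in for a transformer_engine.pytorch autocast context.
    entered.append(kind)
    yield

def get_layer_autocast(layer_precision, layer_number):
    """Sketch of the dispatch: FP8 layers rely on the already-active outer
    FP8 autocast (nullcontext); FP4 layers get an FP4 autocast; BF16 layers
    get an autocast with quantization disabled."""
    precision = layer_precision[layer_number] if layer_precision is not None else None
    if precision == "fp8":
        return nullcontext()
    if precision == "fp4":
        return fake_autocast("fp4")
    return fake_autocast("disabled")

# Simulate the decoder loop running under an outer FP8 autocast.
for i in range(3):
    with get_layer_autocast(["fp8", "fp4", "bf16"], i):
        pass  # decoder layer forward would run here
print(entered)  # ['fp4', 'disabled']
```

Only the non-FP8 layers open an inner context, which matches the diagram: the FP8 path simply inherits the outer autocast.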

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs

Suggested reviewers

  • jstjohn
  • pstjohn
  • yzhang123
  • jwilber
  • cspades
  • dorotat-nv
  • trvachov
  • tshimko-nv
  • broland-hat
  • polinabinder1
  • savitha-eng

Poem

🐰 I stitched recipes thread by thread,

FP8 hops where frogs would dread,
FP4 whispers, BF16 stays calm,
Layers blink like a carrot charm,
Ranks align — a quantized drumbeat, tadah!

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)

  • Description check (❓ Inconclusive): The PR description includes a clear Summary section outlining all major changes, an explicit Test plan with pytest commands, CI guidance, and references to comprehensive test coverage. However, the Usage section contains only 'TODO: Add code snippet' and the Type of changes checkboxes remain unchecked, leaving required template elements incomplete. Resolution: complete the Usage section with a concrete code snippet demonstrating per-layer quantization setup, and check the appropriate 'Type of changes' checkbox (likely 'New feature').

✅ Passed checks (2 passed)

  • Title check (✅ Passed): The title clearly and concisely summarizes the main change, adding MXFP8 and NVFP4 quantization support to the LLaMA3 model, which aligns with the primary objective described in the PR summary.
  • Docstring Coverage (✅ Passed): Docstring coverage is 80.52%, above the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@jomitchellnv jomitchellnv force-pushed the jm/mxfp8-nvfp4-llama3 branch from cb05a45 to 4067915 on March 6, 2026 18:00
@jomitchellnv
Collaborator Author

PERF results for llama3-1B
[image: PERF results screenshot]

@jomitchellnv jomitchellnv changed the title [draft] Adds MXFP8 and NVFP4 to LLAMA3 Adds MXFP8 and NVFP4 to LLAMA3 Mar 6, 2026
@jomitchellnv jomitchellnv changed the title Adds MXFP8 and NVFP4 to LLAMA3 Add MXFP8 and NVFP4 quantization support to LLaMA3 Mar 6, 2026
@jomitchellnv
Collaborator Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 6, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (5)
bionemo-recipes/recipes/llama3_native_te/tests/conftest.py (1)

54-64: Use item.originalname for more robust test name matching.

The stats tests are not currently parametrized, but the reordering logic would silently break if they become parametrized in the future (pytest appends [param] suffixes to parametrized test names). Using item.originalname or stripping parameter suffixes makes this logic resilient to future changes.

Suggested change
 def pytest_collection_modifyitems(items):
     """Run FP8 stats logging tests first to avoid late debug initialization."""
     stats_test_names = {
         "test_sanity_ddp_fp8_stats_logging",
         "test_sanity_fsdp2_fp8_stats_logging",
         "test_sanity_ddp_fp8_partial_layers_stats_logging",
         "test_sanity_fsdp2_fp8_partial_layers_stats_logging",
     }
-    stats_tests = [item for item in items if item.name in stats_test_names]
-    other_tests = [item for item in items if item.name not in stats_test_names]
+    def _base_name(item):
+        return getattr(item, "originalname", item.name.split("[", 1)[0])
+
+    stats_tests = [item for item in items if _base_name(item) in stats_test_names]
+    other_tests = [item for item in items if _base_name(item) not in stats_test_names]
     items[:] = stats_tests + other_tests
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/tests/conftest.py` around lines 54 -
64, The reordering currently matches on item.name which breaks for parametrized
tests; change the comparisons in pytest_collection_modifyitems to use the test's
original name (e.g., use item.originalname or getattr(item, "originalname",
item.name)) when building stats_tests and other_tests so parameter suffixes like
"[param]" are ignored; update the two list comprehensions that reference
item.name and keep the rest of the function unchanged (symbols:
pytest_collection_modifyitems, stats_test_names, stats_tests, other_tests,
item.originalname).
bionemo-recipes/recipes/llama3_native_te/quantization.py (1)

86-94: Redundant YAML serialization and temp file cleanup consideration.

The config is serialized twice: once to the temp file (line 87) and again for logging (line 90). Consider reusing the serialized string. Also, the temp file with delete=False is never cleaned up - this is acceptable since it's a short-lived training run, but consider documenting this behavior or adding cleanup in the caller.

♻️ Optional: reuse serialized config
-    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
-    yaml.dump(config, temp_file, default_flow_style=False)
-    temp_file.close()
-
-    config_str = yaml.dump(config, default_flow_style=False)
+    config_str = yaml.dump(config, default_flow_style=False)
+    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
+    temp_file.write(config_str)
+    temp_file.close()
+
     logger.info(f"Created updated quant stats config at: {temp_file.name}")
     logger.info(f"Updated quant stats config contents:\n{config_str}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/quantization.py` around lines 86 -
94, The code serializes `config` twice (once writing to `temp_file` via
`yaml.dump(config, temp_file, ...)` and again into `config_str`) and the
temporary file created with `tempfile.NamedTemporaryFile(..., delete=False)` is
not documented or cleaned up; update the function that creates `temp_file` so
you serialize `config` once into a string (e.g., call `yaml.dump(config,
default_flow_style=False)` once), write that string to `temp_file` (instead of
calling `yaml.dump` twice) and reuse that string for the log message, and either
document that the returned filename must be removed by the caller or add
optional cleanup logic in the caller that removes the temp file when no longer
needed (refer to symbols `temp_file`, `config_str`, `yaml.dump`, and the
function that returns `temp_file.name`).
bionemo-recipes/recipes/llama3_native_te/perf_logger.py (1)

94-94: Consider renaming to clarify this is a boolean flag.

The attribute quant_stats_config stores a boolean (args.quant_stats_config.enabled), but the name suggests it holds a config object. Consider quant_stats_enabled for clarity and consistency with the previous fp8_stats_enabled naming convention.

Suggested rename for clarity
-        self.quant_stats_config = args.quant_stats_config.enabled
+        self.quant_stats_enabled = args.quant_stats_config.enabled

And update usages at lines 153 and 204 accordingly.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/perf_logger.py` at line 94, Rename
the boolean attribute self.quant_stats_config to self.quant_stats_enabled to
match the fp8_stats_enabled naming and avoid implying a config object; update
the constructor assignment (currently assigning args.quant_stats_config.enabled)
and replace every usage of self.quant_stats_config in the class (and any methods
that read it, e.g., where flags are checked or conditional logic runs) to
self.quant_stats_enabled, and adjust any related variable names, docs, and tests
that reference the old name so references remain consistent.
bionemo-recipes/models/llama3/modeling_llama_te.py (1)

191-211: Consider adding bounds check in get_layer_autocast.

If layer_number exceeds len(self.config.layer_precision), line 205 will raise an IndexError. Consider adding a bounds check or documenting this precondition.

Optional bounds check
     def get_layer_autocast(self, layer_number: int):
         """Return the appropriate TE autocast context manager for a given layer.
         ...
         """
-        precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None
+        if self.config.layer_precision is None or layer_number >= len(self.config.layer_precision):
+            precision = None
+        else:
+            precision = self.config.layer_precision[layer_number]
         if precision == "fp8":
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/llama3/modeling_llama_te.py` around lines 191 - 211,
get_layer_autocast currently indexes self.config.layer_precision without
validating layer_number; add a bounds check at the start of get_layer_autocast
to verify self.config.layer_precision is not None and 0 <= layer_number <
len(self.config.layer_precision), and if the check fails raise a clear exception
(e.g., IndexError or ValueError) with an explanatory message; keep the existing
behavior for valid indices (use precision =
self.config.layer_precision[layer_number] and return nullcontext() /
transformer_engine.pytorch.autocast(...) as before).
bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py (1)

24-26: sys.path manipulation for imports.

The sys.path.append pattern works but is fragile. Consider if there's a package structure that could avoid this, or add a comment explaining why this approach is necessary for the recipe test layout.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py` around
lines 24 - 26, The test currently mutates sys.path using
sys.path.append(Path(__file__).parent.parent.as_posix()) to import quantization,
which is fragile; either make the test an actual package (add an __init__.py in
the recipes/llama3_native_te/tests or recipes/llama3_native_te parent and import
via package path) or load the module via a test-time PYTHONPATH/pytest
configuration (e.g., set pythonpath in pytest.ini) so you don't need sys.path
manipulation, and if keeping the append must remain, replace it with a brief
comment above the line explaining why direct import is necessary for the recipe
layout and noting the intended workaround (package or pytest config) for future
maintainers; locate the sys.path.append line and the import of
generate_layer_regex/resolve_layer_precision/update_quant_stats_config in the
test file to apply the change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/recipes/llama3_native_te/quantization.py`:
- Around line 140-175: The function resolve_layer_precision lacks validation
that entries in fp8_layers and fp4_layers fall within 1..num_layers; add checks
after all_layers is defined to validate any provided fp8_layers and fp4_layers
(and similarly handle duplicates/overlap) by iterating the lists and raising
ValueError listing invalid layer numbers (or a clear message) if any number <1
or >num_layers or not an int; reference fp8_layers, fp4_layers, num_layers,
all_layers and raise the error before proceeding with the existing
overlapping/none checks so invalid indices are caught early.

In `@bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py`:
- Line 53: The file currently imports only resolve_layer_precision from
quantization but misses initialize_quant_stats_logging and does not initialize
quant stats; add initialize_quant_stats_logging to the imports and, where
argument parsing/setup occurs (same area as resolve_layer_precision usage), call
initialize_quant_stats_logging(process_logger, args.quant_stats_config) when
args.quant_stats_config.enabled is true so quantization stats are initialized
the same way as in train_fsdp2.py/train_ddp.py.

---

Nitpick comments:
In `@bionemo-recipes/models/llama3/modeling_llama_te.py`:
- Around line 191-211: get_layer_autocast currently indexes
self.config.layer_precision without validating layer_number; add a bounds check
at the start of get_layer_autocast to verify self.config.layer_precision is not
None and 0 <= layer_number < len(self.config.layer_precision), and if the check
fails raise a clear exception (e.g., IndexError or ValueError) with an
explanatory message; keep the existing behavior for valid indices (use precision
= self.config.layer_precision[layer_number] and return nullcontext() /
transformer_engine.pytorch.autocast(...) as before).

In `@bionemo-recipes/recipes/llama3_native_te/perf_logger.py`:
- Line 94: Rename the boolean attribute self.quant_stats_config to
self.quant_stats_enabled to match the fp8_stats_enabled naming and avoid
implying a config object; update the constructor assignment (currently assigning
args.quant_stats_config.enabled) and replace every usage of
self.quant_stats_config in the class (and any methods that read it, e.g., where
flags are checked or conditional logic runs) to self.quant_stats_enabled, and
adjust any related variable names, docs, and tests that reference the old name
so references remain consistent.

In `@bionemo-recipes/recipes/llama3_native_te/quantization.py`:
- Around line 86-94: The code serializes `config` twice (once writing to
`temp_file` via `yaml.dump(config, temp_file, ...)` and again into `config_str`)
and the temporary file created with `tempfile.NamedTemporaryFile(...,
delete=False)` is not documented or cleaned up; update the function that creates
`temp_file` so you serialize `config` once into a string (e.g., call
`yaml.dump(config, default_flow_style=False)` once), write that string to
`temp_file` (instead of calling `yaml.dump` twice) and reuse that string for the
log message, and either document that the returned filename must be removed by
the caller or add optional cleanup logic in the caller that removes the temp
file when no longer needed (refer to symbols `temp_file`, `config_str`,
`yaml.dump`, and the function that returns `temp_file.name`).

In `@bionemo-recipes/recipes/llama3_native_te/tests/conftest.py`:
- Around line 54-64: The reordering currently matches on item.name which breaks
for parametrized tests; change the comparisons in pytest_collection_modifyitems
to use the test's original name (e.g., use item.originalname or getattr(item,
"originalname", item.name)) when building stats_tests and other_tests so
parameter suffixes like "[param]" are ignored; update the two list
comprehensions that reference item.name and keep the rest of the function
unchanged (symbols: pytest_collection_modifyitems, stats_test_names,
stats_tests, other_tests, item.originalname).

In `@bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py`:
- Around line 24-26: The test currently mutates sys.path using
sys.path.append(Path(__file__).parent.parent.as_posix()) to import quantization,
which is fragile; either make the test an actual package (add an __init__.py in
the recipes/llama3_native_te/tests or recipes/llama3_native_te parent and import
via package path) or load the module via a test-time PYTHONPATH/pytest
configuration (e.g., set pythonpath in pytest.ini) so you don't need sys.path
manipulation, and if keeping the append must remain, replace it with a brief
comment above the line explaining why direct import is necessary for the recipe
layout and noting the intended workaround (package or pytest config) for future
maintainers; locate the sys.path.append line and the import of
generate_layer_regex/resolve_layer_precision/update_quant_stats_config in the
test file to apply the change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 39eb17c1-c1ff-4756-90d6-a1eae8aa0d07

📥 Commits

Reviewing files that changed from the base of the PR and between 470e10d and 049e821.

📒 Files selected for processing (18)
  • bionemo-recipes/models/llama3/modeling_llama_te.py
  • bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
  • bionemo-recipes/models/llama3/tests/test_layer_quantization.py
  • bionemo-recipes/recipes/llama3_native_te/README.md
  • bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml
  • bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
  • bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
  • bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
  • bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
  • bionemo-recipes/recipes/llama3_native_te/perf_logger.py
  • bionemo-recipes/recipes/llama3_native_te/quantization.py
  • bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
  • bionemo-recipes/recipes/llama3_native_te/train_ddp.py
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
💤 Files with no reviewable changes (1)
  • bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py

@jomitchellnv
Collaborator Author

@coderabbitai resolve

@jomitchellnv
Collaborator Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai bot commented Mar 6, 2026

✅ Actions performed

Comments resolved. Auto-approval is disabled; enable reviews.request_changes_workflow to approve automatically.

@coderabbitai
Contributor

coderabbitai bot commented Mar 6, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@jomitchellnv
Collaborator Author

Remaining

  • Awaiting convergence results and convergence updates inside README.md
  • Can do LLAMA3-8B PERF benchmarks (but not required for this MR)

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py (1)

92-100: Consider guarding layer_precision assignment for TE models only.

The layer_precision attribute is assigned to the config regardless of whether it's NVLlamaConfig or LlamaConfig. While this won't cause errors (Python allows dynamic attribute assignment), the attribute is only meaningful for TE models that use get_layer_autocast(). For clarity and consistency, consider wrapping this in a check.

♻️ Optional: Guard layer_precision assignment
 layer_precision = resolve_layer_precision(
     num_layers=config.num_hidden_layers,
     fp8_enabled=args.fp8_config.enabled,
     fp4_enabled=args.fp4_config.enabled,
     fp8_layers=OmegaConf.to_container(args.fp8_layers, resolve=True) if args.fp8_layers is not None else None,
     fp4_layers=OmegaConf.to_container(args.fp4_layers, resolve=True) if args.fp4_layers is not None else None,
 )
-config.layer_precision = layer_precision
+if args.use_te:
+    config.layer_precision = layer_precision
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py` around lines 92 -
100, The code always assigns config.layer_precision after calling
resolve_layer_precision, but layer_precision is only meaningful for TE models
(e.g., NVLlamaConfig) that use get_layer_autocast; update the block that calls
resolve_layer_precision and sets config.layer_precision to only set this
attribute when config is an instance of the TE-specific config class (e.g.,
NVLlamaConfig) — keep calling resolve_layer_precision as needed but guard the
assignment with an isinstance(config, NVLlamaConfig) (or equivalent TE config
check) so LlamaConfig instances are not mutated with a TE-only attribute.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/recipes/llama3_native_te/README.md`:
- Around line 68-77: The convergence benchmarks image paths in the README.md
reference a non-existent recipes/llama3 subdirectory; update the src attributes
in the <img> tags for the two convergence images (lingua-1b-loss-curve.png and
lingua-1b-step-time.png) to point to ../../../docs/docs/assets/images/llama3/…
(remove the recipes/ component) so they match the actual location; ensure you
modify the two src strings currently containing
.../assets/images/recipes/llama3/... to .../assets/images/llama3/... in the
README.md convergence section.

---

Nitpick comments:
In `@bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py`:
- Around line 92-100: The code always assigns config.layer_precision after
calling resolve_layer_precision, but layer_precision is only meaningful for TE
models (e.g., NVLlamaConfig) that use get_layer_autocast; update the block that
calls resolve_layer_precision and sets config.layer_precision to only set this
attribute when config is an instance of the TE-specific config class (e.g.,
NVLlamaConfig) — keep calling resolve_layer_precision as needed but guard the
assignment with an isinstance(config, NVLlamaConfig) (or equivalent TE config
check) so LlamaConfig instances are not mutated with a TE-only attribute.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6703b6b6-e632-489e-a1b5-96f595dfafc4

📥 Commits

Reviewing files that changed from the base of the PR and between 049e821 and 80e4897.

⛔ Files ignored due to path filters (3)
  • docs/docs/assets/images/llama3/lingua-1b-loss-curve.png is excluded by !**/*.png
  • docs/docs/assets/images/llama3/lingua-1b-step-time.png is excluded by !**/*.png
  • docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png is excluded by !**/*.png
📒 Files selected for processing (3)
  • bionemo-recipes/recipes/llama3_native_te/README.md
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py

@codecov

codecov bot commented Mar 6, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.45%. Comparing base (470e10d) to head (5645c23).
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1500   +/-   ##
=======================================
  Coverage   76.45%   76.45%           
=======================================
  Files         102      102           
  Lines        7952     7952           
=======================================
  Hits         6080     6080           
  Misses       1872     1872           

@jomitchellnv jomitchellnv force-pushed the jm/mxfp8-nvfp4-llama3 branch 2 times, most recently from f82b896 to 8dda89b Compare March 17, 2026 03:43
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
@jomitchellnv jomitchellnv force-pushed the jm/mxfp8-nvfp4-llama3 branch from 8dda89b to 6079013 Compare March 17, 2026 05:02
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
@jomitchellnv
Collaborator Author

@coderabbitai re-review

@coderabbitai
Contributor

coderabbitai bot commented Mar 24, 2026

@jomitchellnv Sure, I'll kick off a full re-review of the PR now!

✅ Actions performed

Full review triggered.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (5)
bionemo-recipes/recipes/llama3_native_te/quantization.py (1)

86-94: Consider simplifying the redundant YAML dump.

The config is serialized twice: once to write to the temp file (line 87) and again for logging (line 90). This is minor but could be optimized.

♻️ Suggested optimization
-    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
-    yaml.dump(config, temp_file, default_flow_style=False)
-    temp_file.close()
-
-    config_str = yaml.dump(config, default_flow_style=False)
+    config_str = yaml.dump(config, default_flow_style=False)
+    temp_file = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
+    temp_file.write(config_str)
+    temp_file.close()
+
     logger.info(f"Created updated quant stats config at: {temp_file.name}")
     logger.info(f"Updated quant stats config contents:\n{config_str}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/quantization.py` around lines 86 -
94, The code currently calls yaml.dump twice; instead, call yaml.dump(config,
default_flow_style=False) once into a string (e.g., config_str), write that
string to the NamedTemporaryFile (temp_file.write(config_str), or
temp_file.write(config_str.encode()) if the file is opened in binary mode), log
config_str with logger.info, close the temp_file, and return temp_file.name; update the logic
around temp_file, config_str, yaml.dump, logger.info, and return to use the
single serialized string.
bionemo-recipes/recipes/llama3_native_te/tests/test_train.py (1)

510-538: Consider adding log file assertions for consistency.

The DDP partial layers test only asserts directory existence, while the FSDP2 test also asserts that specific log files exist. For consistency and more thorough validation, consider adding the file existence checks here as well.

♻️ Suggested enhancement
     assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs").exists(), (
         "nvdlfw_inspect_statistics_logs directory was not created"
     )
+
+    # Verify log files exist (consistent with FSDP2 test)
+    assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
+    assert (quant_log_dir / "rank_0" / "nvdlfw_inspect_statistics_logs" / "nvdlfw_inspect_globalrank-0.log").exists()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/tests/test_train.py` around lines
510 - 538, Add assertions in test_sanity_ddp_fp8_partial_layers_stats_logging to
verify that the expected log files were created inside the directories (not just
the directories themselves): after the existing directory asserts for
quant_log_dir / "rank_0" / "nvdlfw_inspect_logs" and quant_log_dir / "rank_0" /
"nvdlfw_inspect_statistics_logs", assert that each directory contains at least
one file (e.g., using any(path.iterdir()), list(path.glob("*.log")), or
len(list(path.iterdir())) > 0). Reference the test
function name test_sanity_ddp_fp8_partial_layers_stats_logging and the Path
variables quant_log_dir, "rank_0", nvdlfw_inspect_logs, and
nvdlfw_inspect_statistics_logs to locate where to insert these checks.
bionemo-recipes/recipes/llama3_native_te/README.md (1)

97-99: Minor grammar nit: Use hyphen in compound adjective.

The phrase "Low Precision convergence benchmarks" should be "Low-Precision convergence benchmarks" for grammatical consistency with line 68.

📝 Suggested fix
-### Low Precision convergence benchmarks
+### Low-Precision convergence benchmarks
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/README.md` around lines 97 - 99,
Update the header text "Low Precision convergence benchmarks" to use a
hyphenated compound adjective: change the string "Low Precision convergence
benchmarks" to "Low-Precision convergence benchmarks" (look for that exact
header text in README.md, around the section titled "Low Precision convergence
benchmarks").
bionemo-recipes/models/llama3/modeling_llama_te.py (1)

158-159: Clarify the intentional recipe reset pattern.

The _fp8_recipe and _fp4_recipe attributes are set at lines 158-159 for constructor validation, then explicitly reset to None at lines 221-222 before post_init(). This appears intentional to ensure recipes (which are not serializable) are attached via set_recipes() after model sharding.

Consider adding a brief comment at lines 221-222 explaining this design decision:

📝 Suggested clarification
+        # Reset recipes to None - they must be re-attached via set_recipes() after model
+        # creation/sharding since Recipe objects are not serializable.
         self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None
         self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None

Also applies to: 221-222

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/llama3/modeling_llama_te.py` around lines 158 - 159,
Add a short inline comment where the code resets self._fp8_recipe and
self._fp4_recipe to None (before calling post_init()) explaining the intentional
pattern: these Recipe objects are set in the constructor only for validation but
are not serializable and must be attached later via set_recipes() after model
sharding; this clarifies why we clear them here and expect set_recipes() to
provide the runtime recipes. Ensure the comment references the attributes
_fp8_recipe and _fp4_recipe and the set_recipes() and post_init() flow so future
readers understand the design.
bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py (1)

248-314: Consider cleaning up temp files created by update_quant_stats_config.

The tests properly use tmp_path for input fixtures, but update_quant_stats_config creates output files in the system temp directory with delete=False. These files accumulate across test runs. Consider adding cleanup or using tmp_path for output verification.

♻️ Optional: Add temp file cleanup in tests
 def test_fp8_layers_updates_regex(fp8_only_config):
     """FP8 layer list should update the regex in the output config."""
     output_path = update_quant_stats_config(config_file=fp8_only_config, fp4_layers=None, fp8_layers=[1, 2, 3])
-    with open(output_path) as f:
-        result = yaml.safe_load(f)
-    regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
-    assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
-    assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2")
-    assert not re.search(regex, "model.model.layers.4.self_attention.proj")
+    try:
+        with open(output_path) as f:
+            result = yaml.safe_load(f)
+        regex = result["example_fp8_tensor_stat_collection"]["layers"]["layer_name_regex_pattern"]
+        assert re.search(regex, "model.model.layers.1.self_attention.layernorm_qkv")
+        assert re.search(regex, "model.model.layers.3.layernorm_mlp.fc2")
+        assert not re.search(regex, "model.model.layers.4.self_attention.proj")
+    finally:
+        Path(output_path).unlink(missing_ok=True)

Alternatively, consider creating a pytest fixture that wraps update_quant_stats_config with automatic cleanup.
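Such a wrapper could look like the sketch below. `update_quant_stats_config_stub` is a stand-in for the real recipe helper (which writes a temp YAML with `delete=False` and returns its path); the context manager is the hypothetical cleanup wrapper:

```python
import contextlib
import tempfile
from pathlib import Path


def update_quant_stats_config_stub(config_text: str) -> str:
    # Stand-in for the recipe's update_quant_stats_config: writes a temp
    # YAML file with delete=False and returns its path.
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
    tmp.write(config_text)
    tmp.close()
    return tmp.name


@contextlib.contextmanager
def tracked_quant_stats_config(config_text: str):
    """Yield the generated config path and unlink it on exit."""
    output_path = update_quant_stats_config_stub(config_text)
    try:
        yield output_path
    finally:
        Path(output_path).unlink(missing_ok=True)
```

A pytest fixture could wrap the same pattern, appending each returned path to a list during the test and unlinking them after `yield`.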

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py` around
lines 248 - 314, Tests call update_quant_stats_config which creates temp files
with delete=False and never cleans them up; update each test that calls
update_quant_stats_config (e.g., test_fp8_layers_updates_regex,
test_none_layers_disables_matching, test_fp4_section_disabled_fp8_still_updated,
test_original_file_not_modified, test_preserves_other_config_fields,
test_missing_section_is_skipped) to remove the produced file after assertions
(os.remove(output_path)) or, better, pass a tmp_path-based output location if
update_quant_stats_config supports an output path parameter; alternatively add a
small pytest fixture that wraps update_quant_stats_config and ensures the
returned temp file is unlinked in teardown to avoid accumulating temp files.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml`:
- Around line 44-45: Add a runtime validation that enforces the constraint
between quantized_model_init_kwargs.enabled and fp8_config.enabled: when
quantized_model_init_kwargs.enabled is true but fp8_config.enabled is false,
raise a clear error (ValueError/RuntimeError) with a message instructing the
user to enable fp8_config. Implement this check the same way as the existing
pattern in fp8_debugging.py (lines 49-53) — add it to the configuration/recipe
initialization path so it runs at startup and prevents passing recipe=None with
enabled=true to the quantized_model_init context manager.
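A minimal sketch of that fail-fast check, with the two booleans pulled out of the Hydra config (the function name and error wording are illustrative, not the existing fp8_debugging.py code):

```python
def validate_quantized_init(quantized_init_enabled: bool, fp8_enabled: bool) -> None:
    """Fail fast when quantized model init is requested without FP8 enabled."""
    if quantized_init_enabled and not fp8_enabled:
        raise ValueError(
            "quantized_model_init_kwargs.enabled=true requires fp8_config.enabled=true; "
            "enable fp8_config or disable quantized model init."
        )
```

Calling this during recipe startup surfaces the misconfiguration before `recipe=None` ever reaches the `quantized_model_init` context manager.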

---

Nitpick comments:
In `@bionemo-recipes/models/llama3/modeling_llama_te.py`:
- Around line 158-159: Add a short inline comment where the code resets
self._fp8_recipe and self._fp4_recipe to None (before calling post_init())
explaining the intentional pattern: these Recipe objects are set in the
constructor only for validation but are not serializable and must be attached
later via set_recipes() after model sharding; this clarifies why we clear them
here and expect set_recipes() to provide the runtime recipes. Ensure the comment
references the attributes _fp8_recipe and _fp4_recipe and the set_recipes() and
post_init() flow so future readers understand the design.

In `@bionemo-recipes/recipes/llama3_native_te/quantization.py`:
- Around line 86-94: The code currently calls yaml.dump twice; instead, call
yaml.dump(config, default_flow_style=False) once into a string (e.g.,
config_str), write that string to the NamedTemporaryFile
(temp_file.write(config_str), or temp_file.write(config_str.encode()) if the file
is opened in binary mode), log config_str with logger.info, close the temp_file, and return
temp_file.name; update the logic around temp_file, config_str, yaml.dump,
logger.info, and return to use the single serialized string.

In `@bionemo-recipes/recipes/llama3_native_te/README.md`:
- Around line 97-99: Update the header text "Low Precision convergence
benchmarks" to use a hyphenated compound adjective: change the string "Low
Precision convergence benchmarks" to "Low-Precision convergence benchmarks"
(look for that exact header text in README.md, around the section titled "Low
Precision convergence benchmarks").

In `@bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py`:
- Around line 248-314: Tests call update_quant_stats_config which creates temp
files with delete=False and never cleans them up; update each test that calls
update_quant_stats_config (e.g., test_fp8_layers_updates_regex,
test_none_layers_disables_matching, test_fp4_section_disabled_fp8_still_updated,
test_original_file_not_modified, test_preserves_other_config_fields,
test_missing_section_is_skipped) to remove the produced file after assertions
(os.remove(output_path)) or, better, pass a tmp_path-based output location if
update_quant_stats_config supports an output path parameter; alternatively add a
small pytest fixture that wraps update_quant_stats_config and ensures the
returned temp file is unlinked in teardown to avoid accumulating temp files.

In `@bionemo-recipes/recipes/llama3_native_te/tests/test_train.py`:
- Around line 510-538: Add assertions in
test_sanity_ddp_fp8_partial_layers_stats_logging to verify that the expected log
files were created inside the directories (not just the directories themselves):
after the existing directory asserts for quant_log_dir / "rank_0" /
"nvdlfw_inspect_logs" and quant_log_dir / "rank_0" /
"nvdlfw_inspect_statistics_logs", assert that each directory contains at least
one file (e.g., using any(path.iterdir()), list(path.glob("*.log")), or
len(list(path.iterdir())) > 0). Reference the test
function name test_sanity_ddp_fp8_partial_layers_stats_logging and the Path
variables quant_log_dir, "rank_0", nvdlfw_inspect_logs, and
nvdlfw_inspect_statistics_logs to locate where to insert these checks.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7e63c8a0-7ffd-4ff6-a4e8-711e7a796b54

📥 Commits

Reviewing files that changed from the base of the PR and between 46112e7 and a616dd8.

⛔ Files ignored due to path filters (4)
  • docs/docs/assets/images/llama3/lingua-1b-loss-curve.png is excluded by !**/*.png
  • docs/docs/assets/images/llama3/lingua-1b-step-time.png is excluded by !**/*.png
  • docs/docs/assets/images/llama3/llama3_1b_fsdp2_tflops.png is excluded by !**/*.png
  • docs/docs/assets/images/llama3/llama3_8gpu_tflops.png is excluded by !**/*.png
📒 Files selected for processing (18)
  • bionemo-recipes/models/llama3/modeling_llama_te.py
  • bionemo-recipes/models/llama3/tests/test_distributed_fp8.py
  • bionemo-recipes/models/llama3/tests/test_layer_quantization.py
  • bionemo-recipes/recipes/llama3_native_te/README.md
  • bionemo-recipes/recipes/llama3_native_te/fp4_debugging_stats.yaml
  • bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py
  • bionemo-recipes/recipes/llama3_native_te/fp8_debugging_stats.yaml
  • bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
  • bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py
  • bionemo-recipes/recipes/llama3_native_te/perf_logger.py
  • bionemo-recipes/recipes/llama3_native_te/quantization.py
  • bionemo-recipes/recipes/llama3_native_te/tests/conftest.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_perf_logger.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_quantization.py
  • bionemo-recipes/recipes/llama3_native_te/tests/test_train.py
  • bionemo-recipes/recipes/llama3_native_te/train_ddp.py
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
  • bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py
💤 Files with no reviewable changes (1)
  • bionemo-recipes/recipes/llama3_native_te/fp8_debugging.py

Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
x
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
x
Signed-off-by: Jonathan Mitchell <jomitchell@nvidia.com>
savitha-eng added a commit that referenced this pull request Mar 25, 2026
Replace broken fp8_first_last_bf16 mechanism with resolve_layer_precision()
from PR #1500. The old approach set config attributes that were never read by
the forward pass, causing all layers to default to FP8 regardless of setting.

Key changes:
- Delete fp8_debugging.py, add quantization.py with resolve_layer_precision()
  and initialize_quant_stats_logging()
- Add set_recipes()/get_layer_autocast() to OG2 model (from lepton branch),
  model now handles per-layer autocast internally
- Model constructor accepts fp8_recipe/fp4_recipe, set_recipes() called after
  FSDP wrapping since recipes aren't serializable
- Remove outer te.autocast() from training loop (model handles it)
- Rename fp8_stats_config -> quant_stats_config throughout
- Add _parse_layers_cfg() for CLI string support
- Add og2_7b_fp8_fl1_pq2.yaml with explicit fp8_layers=[2..31]
- Expand fp8_debugging_stats.yaml with all layer types + LogTensorStats

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
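The `_parse_layers_cfg()` helper mentioned above isn't shown in this thread; a hypothetical sketch of a CLI layer-list parser like it, assuming the `[2..31]` range notation from the commit message, might be:

```python
def parse_layers_cfg(spec):
    """Hypothetical parser: accepts None, a list of ints, a comma list like
    "1,2,5", or an inclusive range like "2..31", returning a list of ints."""
    if spec is None:
        return None
    if isinstance(spec, (list, tuple)):
        return [int(x) for x in spec]
    text = str(spec).strip()
    if ".." in text:
        lo, hi = text.split("..", 1)
        return list(range(int(lo), int(hi) + 1))
    return [int(x) for x in text.split(",") if x.strip()]
```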
Collaborator


what's going on here
