[JAX] Improve JAX tutorial documentation #2976

Merged
jberchtold-nvidia merged 18 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/improve-jax-tutorial
May 21, 2026

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented May 11, 2026

Description

Reworks the tutorial to focus on individual operations, their usage, and their performance. This makes the impact of each operation clearer to users, who can try the operations one at a time depending on which are bottlenecks in their models.

Additionally, this switches from .ipynb notebooks to .rst files with separate .py files. These are easier to test in CI, which ensures our docs do not become stale and always work with the latest TE version.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Rework existing tutorial and replace with new Dense-specific tutorial
  • Placeholders for Attention and MoE
  • Refactor .ipynb notebooks to .rst and .py files for similar appearance in docs but better testability in CI by running .py files

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR replaces the monolithic JAX integration notebook (te_jax_integration.ipynb) with a structured set of .rst + .py files, starting with a Dense GEMM tutorial. The refactor improves CI testability by running .py files directly with pytest instead of executing notebooks.

  • New Dense tutorial (dense.py + dense.rst): walks through swapping nn.Dense's GEMM for TE's quantized GEMM using make_dot_general_cls, covering single-GPU and 4-GPU DP+TP benchmarks.
  • Test coverage (test_dense.py): pytest tests for forward shape, numerical parity, and benchmarks; imports from dense are deferred into test bodies behind @requires_mxfp8 marks to prevent collection failures on non-Blackwell hardware.
  • CI integration: both L0_jax_unittest/test.sh and L0_jax_distributed_unittest/test.sh are updated to run pytest on docs/examples/jax/; placeholder .rst files added for Attention, Collective GEMM, and MoE tutorials.
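
The deferred-import pattern described for test_dense.py can be sketched as follows. This is a minimal illustration, not the PR's code: `requires_jax` is a hypothetical stand-in for `@requires_mxfp8` that checks whether a package is importable rather than whether the hardware supports MXFP8, and the test body is invented for the example.

```python
import importlib.util

import pytest

# Hypothetical stand-in for the PR's @requires_mxfp8 mark. The real mark
# checks for MXFP8 hardware support; this one only checks importability.
HAS_JAX = importlib.util.find_spec("jax") is not None
requires_jax = pytest.mark.skipif(not HAS_JAX, reason="jax is not installed")


@requires_jax
def test_matmul_shape():
    # Deferred import: this line runs only when the test itself runs, so an
    # unavailable dependency cannot break collection of the whole file.
    import jax.numpy as jnp

    x = jnp.ones((4, 8))
    w = jnp.ones((8, 16))
    assert (x @ w).shape == (4, 16)
```

A module-level import would run during collection, before any skip mark is evaluated, which is why the import sits inside the test body.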

Confidence Score: 5/5

Documentation-only refactor that is safe to merge; no production code paths are touched.

The change is entirely documentation and CI scripts. The pytest wiring is correct, the deferred-import pattern in test_dense.py properly handles non-Blackwell hardware, and cross-references in the RST files are valid. The two non-blocking notes are editorial nits that do not affect CI or runtime behavior.

No files require special attention for merge safety.

Important Files Changed

| Filename | Overview |
| --- | --- |
| docs/examples/jax/dense.py | New tutorial script demonstrating quantized Dense GEMMs via TE's make_dot_general_cls; well-structured with marker comments for RST literalinclude. |
| docs/examples/jax/test_dense.py | Pytest tests for dense.py; imports from dense are correctly deferred into test bodies behind @requires_mxfp8 marks so collection does not fail on non-Blackwell hardware. |
| docs/examples/jax/quickstart_jax_utils.py | Moved/expanded benchmark utility; speedometer lacks jax.block_until_ready() after the timing loop, which can cause JAX async dispatch to produce underestimated latency numbers. |
| docs/examples/jax/dense.out | Captured benchmark output with RST literalinclude markers; the inline regeneration instruction would silently strip those markers and break the Sphinx build. |
| docs/examples/jax/dense.rst | New RST tutorial page; literalinclude markers and cross-references are correct. |
| docs/examples/te_jax_integration.rst | New landing-page RST replacing the removed notebook; contains a truncated sentence in the Conventions section. |
| qa/L0_jax_unittest/test.sh | CI script correctly updated to point pytest at docs/examples/jax/ for the new tutorial tests. |
| qa/L0_jax_distributed_unittest/test.sh | Distributed CI script correctly targets docs/examples/jax/ with -k multi_gpu filter; auto-skips when fewer than 4 GPUs are available. |
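
The note on quickstart_jax_utils.py concerns JAX's asynchronous dispatch: a call can return before the device has finished, so a timer stopped right after the loop may miss queued work. A minimal sketch of a timing helper that waits for the result is shown below. `time_fn` and `matmul` are hypothetical names for illustration, not the PR's `speedometer`, and the sketch assumes the calls execute in order on a single device so that waiting on the last output is enough.

```python
import time

import jax
import jax.numpy as jnp


def time_fn(fn, *args, warmup=2, iters=10):
    """Return mean seconds per call, waiting for async device work to finish."""
    for _ in range(warmup):
        # Warmup absorbs compilation time; block so it does not leak into timing.
        jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    out = None
    for _ in range(iters):
        out = fn(*args)
    # Without this wait, the timer could stop while work is still queued.
    jax.block_until_ready(out)
    return (time.perf_counter() - start) / iters


@jax.jit
def matmul(x, w):
    return x @ w


x = jnp.ones((256, 256), dtype=jnp.bfloat16)
w = jnp.ones((256, 256), dtype=jnp.bfloat16)
mean_s = time_fn(matmul, x, w)
print(f"mean latency: {mean_s * 1e3:.3f} ms")
```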

Reviews (12): Last reviewed commit: "Merge branch 'main' into jberchtold/impr..."

Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/attention.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Comment thread docs/examples/jax_examples/moe.ipynb Outdated
Comment thread docs/examples/jax_examples/dense.ipynb Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment thread docs/examples/jax/dense.py
@jberchtold-nvidia
Collaborator Author

/te-ci L1 L0

Collaborator

@KshitijLakhani KshitijLakhani left a comment

Thanks for adding this skeleton.
I like the modular approach, concise explanation and benchmarking.

In general it looks good. Some items may need to be moved around, but I think that's going to be an evolving process.

Comment on lines +1 to +11
..
Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

JAX: Attention with TransformerEngine
=====================================

**TODO — Coming soon.**

`← Back to the JAX integration overview <../te_jax_integration.html>`_
Collaborator

Unrelated to attention, but it looks like you are renaming the dir to examples/jax_examples, whereas I think the PyTorch side is examples/pytorch?
I think we could stick with examples/jax. Thoughts?

Collaborator Author

Good point, updated to examples/jax

`Haiku/Flax interop
<https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html>`_ if you're on
a different stack.)
* **Baseline dtype.** bf16 for inputs and parameters.
Collaborator

Should we add the GB200 (arch) details here rather than in the example module, or is that by choice?
I think there's value in having all examples run on the same arch for consistency.

Comment thread docs/examples/jax/attention.rst
Comment thread docs/examples/jax/conftest.py Outdated
#
# See LICENSE for license information.

"""Pytest conftest for docs/examples/jax_examples.
Collaborator

I agree with the usage of pytest in general; however, I think examples/mnist currently uses Python's built-in unittest module for its test example.
@phu0ngng and @tdophung, it might be good to standardize and use pytest there too. Thoughts?

Collaborator Author

I see. In our main tests in tests/jax/ we use pytest. In examples/jax we use unittest instead, but we run those tests in CI with pytest examples/jax/...., because pytest can also run unittest tests.

I'm ok with standardizing and using pytest everywhere. We already have requirements.txt files for running the examples/jax/mnist or encoder tests, so we could add the pytest dependency there too.
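
For context on why the mixed setup works: pytest collects `unittest.TestCase` subclasses, so one test file can run under either runner. A minimal sketch, with a hypothetical test class and helper that are not taken from the examples directory:

```python
import unittest


class TestAddition(unittest.TestCase):
    """Plain unittest-style test; pytest can collect and run it unchanged."""

    def test_add(self):
        self.assertEqual(1 + 1, 2)


def run_suite():
    """Run the test case with the stdlib runner and return its result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestAddition)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

Saved as a file such as test_example.py (an assumed name), this can be run with either `python -m unittest test_example.py` or `pytest test_example.py`.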

Collaborator

I agree with standardizing the use of pytest

Comment thread docs/examples/jax_examples/dense.out Outdated
Comment thread docs/examples/jax_examples/dense.rst Outdated
and your performance comparison will not be accurate.


6. Multi-GPU: DP=2 / TP=2 on a single Dense
Collaborator

  1. Single GPU performance
    4, 5?
  2. Multi-GPU: DP=2 / TP=2 on a single Dense

Collaborator Author

Good catch, I had it broken into more sections and forgot to update the latest section numbers. Fixed now

Comment thread qa/L0_jax_unittest/test.sh Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 7464325 to 5432ec6 Compare May 15, 2026 16:43
Comment thread qa/L0_jax_unittest/test.sh Outdated
Comment thread qa/L1_jax_distributed_unittest/test.sh Outdated
jberchtold-nvidia and others added 2 commits May 15, 2026 09:48
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Comment thread docs/examples/jax/dense.rst Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Comment thread docs/examples/jax/test_dense.py Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 48884cd to 168cc63 Compare May 15, 2026 18:36
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 54b1a9c to 4c1fec9 Compare May 15, 2026 19:02
@jberchtold-nvidia
Collaborator Author

/te-ci

Comment thread docs/examples/jax/test_dense.py
Comment thread docs/examples/jax/dense.rst Outdated
Comment thread qa/L0_jax_unittest/test.sh
Comment thread qa/L1_jax_distributed_unittest/test.sh Outdated
Comment thread docs/examples/jax/dense.out
Comment thread docs/examples/jax/dense.rst Outdated
Comment thread docs/examples/te_jax_integration.rst Outdated
Comment thread docs/examples/jax/dense.rst Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/improve-jax-tutorial branch from 92d9e29 to c2c8444 Compare May 19, 2026 22:27
Dense — and the only code change was passing ``dot_general=te_dot_general_cls()``
into ``nn.Dense``.

The speedup depends on shape: large GEMMs benefit most. Very small GEMMs may
Collaborator

I wonder if we know the threshold for this. At what size do we start seeing a benefit?

Collaborator Author

I think it depends on the GPU type, whether the GEMM is square or narrow, and probably the cuBLAS version. I'll search through the cuBLAS docs in case they have anything I can link to that will stay up to date with the perf improvements in their latest version.
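
One way to see why the shape matters is a rough arithmetic-intensity estimate (FLOPs per byte moved). The sketch below is a roofline-style heuristic under stated assumptions, not a measured threshold: it assumes each operand and the output move through memory exactly once, 2 bytes per element (bf16), and hypothetical shapes.

```python
def gemm_arithmetic_intensity(m, k, n, bytes_per_elem=2):
    """FLOPs per byte for an (m, k) @ (k, n) GEMM.

    Assumes each input and the output is read or written exactly once,
    which ignores caching and tiling effects on real hardware.
    """
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


# Hypothetical shapes: a large square GEMM and a narrow one with few rows.
square = gemm_arithmetic_intensity(4096, 4096, 4096)
narrow = gemm_arithmetic_intensity(16, 4096, 4096)
print(f"square: {square:.1f} FLOPs/byte, narrow: {narrow:.1f} FLOPs/byte")
```

With these shapes the square GEMM does roughly 1365 FLOPs per byte and the narrow one roughly 16, so the narrow case is far more likely to be limited by memory bandwidth than by math throughput. Where the actual crossover falls still depends on the GPU and library version, as noted above.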

Comment thread docs/examples/jax/dense.rst Outdated
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
@jberchtold-nvidia
Collaborator Author

/te-ci

@tdophung
Collaborator

LGTM pending CI

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci

@jberchtold-nvidia jberchtold-nvidia merged commit 8c0f1d2 into NVIDIA:main May 21, 2026
11 of 13 checks passed
@jberchtold-nvidia jberchtold-nvidia deleted the jberchtold/improve-jax-tutorial branch May 21, 2026 17:06
KshitijLakhani pushed a commit that referenced this pull request May 21, 2026
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Teddy Do <tdophung@nvidia.com>
4 participants