
[JAX] Handle meshes set with jax.set_mesh #2532

Merged
jberchtold-nvidia merged 2 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/fix-set-mesh
Dec 19, 2025


@jberchtold-nvidia
Collaborator

Description

Fixes an issue where JAX Mesh information could not be queried for meshes set via jax.set_mesh.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Check get_abstract_mesh() in addition to pxla thread resources when looking for mesh context

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 L2 jax

@greptile-apps
Contributor

greptile-apps bot commented Dec 18, 2025

Greptile Summary

Fixed a bug where TransformerEngine/JAX couldn't query mesh information from meshes set via jax.set_mesh(). Previously, the code only checked _PXLA_THREAD_RESOURCES.env.physical_mesh, which is populated when using with mesh: context managers but not when using jax.set_mesh().

Key Changes:

  • Added new _get_mesh() helper function that checks both thread resources and jax.sharding.get_abstract_mesh()
  • Replaced 6 direct accesses to _PXLA_THREAD_RESOURCES.env.physical_mesh with calls to _get_mesh()
  • Ensures backward compatibility by checking thread resources first, then falling back to abstract mesh
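
The fallback order in the list above can be sketched in plain Python. This is a minimal illustration, not the code in transformer_engine/jax/sharding.py: StubMesh, resolve_mesh, and the get_abstract_mesh parameter are hypothetical stand-ins for the JAX mesh object, the _get_mesh() helper, and jax.sharding.get_abstract_mesh(). Only the condition `mesh is not None and not mesh.empty` is taken from the review discussion on this PR.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Tuple


@dataclass(frozen=True)
class StubMesh:
    """Hypothetical stand-in for a JAX mesh; models only axis names and `empty`."""

    axis_names: Tuple[str, ...] = ()

    @property
    def empty(self) -> bool:
        # A mesh with no axes is treated as empty in this sketch.
        return not self.axis_names


def resolve_mesh(
    physical_mesh: Optional[StubMesh],
    get_abstract_mesh: Callable[[], StubMesh],
) -> StubMesh:
    """Return the thread-resources mesh if usable, else the abstract mesh."""
    # Check the `with mesh:` path first so existing behavior is unchanged.
    if physical_mesh is not None and not physical_mesh.empty:
        return physical_mesh
    # Otherwise fall back to whatever was registered via the set_mesh path.
    return get_abstract_mesh()


# Scenario 1: mesh set via a `with mesh:` context (thread resources populated).
context_mesh = StubMesh(("dp", "tp"))
scenario_1 = resolve_mesh(context_mesh, lambda: StubMesh())

# Scenario 2: mesh set via the set_mesh path (thread resources empty).
global_mesh = StubMesh(("fsdp",))
scenario_2 = resolve_mesh(StubMesh(), lambda: global_mesh)
```

In this sketch, scenario_1 resolves to the context mesh and scenario_2 resolves to the mesh supplied by the fallback callable, which mirrors the two scenarios in the sequence diagram.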

Impact:

  • Enables TE/JAX to work correctly with both mesh initialization patterns
  • No breaking changes to existing code using with mesh: context

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk
  • The fix is well-structured with a clear fallback pattern. The change is backwards compatible (checking thread resources first maintains existing behavior), localized to one file, and addresses a specific bug. Confidence reduced by 1 point due to one minor consideration around null safety.
  • No files require special attention

Important Files Changed

Filename: transformer_engine/jax/sharding.py
Overview: Added a _get_mesh() helper that checks both thread resources and meshes set via jax.set_mesh(), and replaced 6 direct _PXLA_THREAD_RESOURCES.env.physical_mesh accesses with the new helper.

Sequence Diagram

sequenceDiagram
    participant User
    participant JAX
    participant TE as TransformerEngine
    participant PXLA as PXLA Thread Resources
    participant GetAbstractMesh as jax.sharding.get_abstract_mesh()

    Note over User,GetAbstractMesh: Scenario 1: Mesh set via 'with mesh:' context
    User->>JAX: with mesh:
    JAX->>PXLA: Set physical_mesh
    User->>TE: Call TE function (e.g., is_mesh_available())
    TE->>TE: _get_mesh()
    TE->>PXLA: Check physical_mesh
    PXLA-->>TE: Returns mesh (not empty)
    TE-->>User: Returns mesh
    
    Note over User,GetAbstractMesh: Scenario 2: Mesh set via jax.set_mesh()
    User->>JAX: jax.set_mesh(mesh)
    JAX->>GetAbstractMesh: Store mesh in abstract context
    Note over PXLA: physical_mesh remains None or empty
    User->>TE: Call TE function (e.g., is_mesh_available())
    TE->>TE: _get_mesh()
    TE->>PXLA: Check physical_mesh
    PXLA-->>TE: Returns None or empty mesh
    TE->>GetAbstractMesh: Fallback to get_abstract_mesh()
    GetAbstractMesh-->>TE: Returns mesh from jax.set_mesh()
    TE-->>User: Returns mesh
Contributor

@greptile-apps greptile-apps bot left a comment

Additional Comments (1)

  1. transformer_engine/jax/sharding.py, lines 40-46

    style: Check whether the fallback to get_abstract_mesh() should happen when physical_mesh is explicitly set to an empty (but non-None) mesh. Currently, if someone uses with empty_mesh:, the code falls back to get_abstract_mesh(), which might return a different mesh set via jax.set_mesh(). Consider whether the fallback should trigger only when mesh is None, rather than whenever the check if mesh is not None and not mesh.empty: fails.
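
To make the question concrete, the sketch below compares the two candidate conditions when the thread-resources mesh is an explicitly empty, non-None mesh. StubMesh and both pick_* functions are hypothetical stand-ins written for illustration; they are not TransformerEngine or JAX APIs.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class StubMesh:
    """Hypothetical stand-in for a JAX mesh; models only axis names and `empty`."""

    axis_names: Tuple[str, ...] = ()

    @property
    def empty(self) -> bool:
        return not self.axis_names


def pick_current(physical: Optional[StubMesh], abstract: StubMesh) -> StubMesh:
    """Condition quoted in the review: use physical only if non-None and non-empty."""
    if physical is not None and not physical.empty:
        return physical
    return abstract


def pick_none_only(physical: Optional[StubMesh], abstract: StubMesh) -> StubMesh:
    """Alternative raised in the review: fall back only when physical is None."""
    if physical is None:
        return abstract
    return physical


empty_mesh = StubMesh()         # explicitly set, but has no axes
other_mesh = StubMesh(("dp",))  # stands in for a mesh set via jax.set_mesh()

current_choice = pick_current(empty_mesh, other_mesh)
none_only_choice = pick_none_only(empty_mesh, other_mesh)
```

Under the current condition the empty context mesh is ignored in favor of the other mesh, while the None-only condition returns the empty mesh unchanged. If physical_mesh defaults to an empty mesh rather than None when only jax.set_mesh() is used (the sequence diagram describes it as "None or empty"), a None-only check might never reach the fallback in that scenario.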

1 file reviewed, 1 comment

Collaborator

@tdophung tdophung left a comment

LGTM

@jberchtold-nvidia jberchtold-nvidia merged commit d46d5db into NVIDIA:main Dec 19, 2025
28 of 32 checks passed
@jberchtold-nvidia jberchtold-nvidia deleted the jberchtold/fix-set-mesh branch December 19, 2025 19:35

2 participants