Skip to content

[JAX] Fix imports in test for deprecated jax.experimental.pjit#2274

Merged
KshitijLakhani merged 4 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/fix/deprecated-pjit-failure
Oct 17, 2025
Merged

[JAX] Fix imports in test for deprecated jax.experimental.pjit#2274
KshitijLakhani merged 4 commits into
NVIDIA:mainfrom
KshitijLakhani:klakhani/fix/deprecated-pjit-failure

Conversation

@KshitijLakhani
Copy link
Copy Markdown
Collaborator

@KshitijLakhani KshitijLakhani commented Oct 15, 2025

Description

Deprecated jax.experimental.pjit causes import errors in TE L1 JAX tests
This PR fixes these errors by correcting the imports

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Import UNSPECIFIED from a different module and use jit instead of pjit as done here

Pass NamedSharding instead of PartitionSpec in the distributed softmax and layernorm tests so that the jitted function being created in compare_ops() receives mesh info as part of the in and out shardings

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

KshitijLakhani and others added 3 commits October 16, 2025 11:19
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…that when the in and out sharding is used to create a jitted function, it has the mesh info

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/deprecated-pjit-failure branch from 65ad0c0 to 7c3cc09 Compare October 16, 2025 18:27
@KshitijLakhani KshitijLakhani marked this pull request as ready for review October 16, 2025 22:18
@jberchtold-nvidia
Copy link
Copy Markdown
Collaborator

The UNSPECIFIED and pjit changes look good to me.

Can you clarify this portion?

Pass NamedSharding instead of PartitionSpec in the distributed softmax and layernorm tests so that the jitted function being created in compare_ops() receives mesh info as part of the in and out shardings

Is that a difference in jit's signature that requires NamedSharding instead of the PartitionSpec? We always use NamedSharding to specifying sharding in other places, so the changes look good, just want to make sure I'm following the reasoning for the change. Thanks!

@KshitijLakhani
Copy link
Copy Markdown
Collaborator Author

KshitijLakhani commented Oct 17, 2025

The UNSPECIFIED and pjit changes look good to me.

Can you clarify this portion?

Pass NamedSharding instead of PartitionSpec in the distributed softmax and layernorm tests so that the jitted function being created in compare_ops() receives mesh info as part of the in and out shardings

Is that a difference in jit's signature that requires NamedSharding instead of the PartitionSpec? We always use NamedSharding to specifying sharding in other places, so the changes look good, just want to make sure I'm following the reasoning for the change. Thanks!

Sure @jberchtold-nvidia !
So seems like pjit() accepted PartitionSpec as args, however, when I made the move to jit from pjit to accommodate for the pjit deprecation, I saw failures because the in_shardings and out_shardings being passed to compare_ops were of type PartitionSpec.
So, I could either set the mesh in compare_ops() or just ensure that the distributed layer norm and distributed softmax tests pass NamedSharding objects to compare_ops() as args - I chose the latter.

@KshitijLakhani KshitijLakhani merged commit 9dd6192 into NVIDIA:main Oct 17, 2025
10 of 12 checks passed
@KshitijLakhani KshitijLakhani deleted the klakhani/fix/deprecated-pjit-failure branch October 17, 2025 03:45
KshitijLakhani added a commit that referenced this pull request Oct 17, 2025
* Fix imports in test for deprecated jax.experimental.pjit

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix: Pass NamedSharding instead of PartitionSpec to compare_ops() so that when the in and out sharding is used to create a jitted function, it has the mesh info

Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kshitij  Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants