[JAX] Fix imports in test for deprecated jax.experimental.pjit#2274
Conversation
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
…that when the in and out sharding is used to create a jitted function, it has the mesh info Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
65ad0c0 to
7c3cc09
Compare
for more information, see https://pre-commit.ci
|
The UNSPECIFIED and pjit changes look good to me. Can you clarify this portion?
Is that a difference in jit's signature that requires NamedSharding instead of the PartitionSpec? We always use NamedSharding to specifying sharding in other places, so the changes look good, just want to make sure I'm following the reasoning for the change. Thanks! |
Sure @jberchtold-nvidia ! |
* Fix imports in test for deprecated jax.experimental.pjit Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix: Pass NamedSharding instead of PartitionSpec to compare_ops() so that when the in and out sharding is used to create a jitted function, it has the mesh info Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com> Signed-off-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kshitij Janardan Lakhani <klakhani@login-eos01.eos.clusters.nvidia.com>
Description
Deprecated jax.experimental.pjit causes import errors in TE L1 JAX tests
This PR fixes these errors by correcting the imports
Type of change
Changes
Import UNSPECIFIED from a different module and use jit instead of pjit as done here
Pass NamedSharding instead of PartitionSpec in the distributed softmax and layernorm tests so that the jitted function being created in compare_ops() receives mesh info as part of the in and out shardings
Checklist: