Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions tests/jax/distributed_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import jax
from jax.experimental.pjit import pjit, _UNSPECIFIED
from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED

from transformer_engine.jax.sharding import MeshResource

Expand Down Expand Up @@ -154,13 +154,15 @@ def compare_ops(
grad_args = tuple(range(len(inputs)))

target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()
target_jitter = jax.jit(
target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings
)
target_fwd, target_grads = target_jitter(*inputs, **kwargs)
target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text()

ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)
ref_jitter = jax.jit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
ref_fwd, ref_grads = ref_jitter(*inputs, **kwargs)

assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)

Expand Down
26 changes: 17 additions & 9 deletions tests/jax/test_distributed_layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,12 @@ def ref_func(x, gamma, beta):
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
x_named_sharding = NamedSharding(mesh, x_pspec)
g_named_sharding = NamedSharding(mesh, g_pspec)
b_named_sharding = NamedSharding(mesh, b_pspec)
x_ = jax.device_put(x, x_named_sharding)
gamma_ = jax.device_put(gamma, g_named_sharding)
beta_ = jax.device_put(beta, b_named_sharding)

with warnings.catch_warnings(record=True) as warns:
try:
Expand All @@ -148,8 +151,11 @@ def ref_func(x, gamma, beta):
grad_args=(0, 1, 2),
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec, b_pspec),
out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
in_shardings=(x_named_sharding, g_named_sharding, b_named_sharding),
out_shardings=(
None,
(x_named_sharding, g_named_sharding, b_named_sharding),
),
)
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
Expand Down Expand Up @@ -210,8 +216,10 @@ def ref_func(x, gamma):
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
x_named_sharding = NamedSharding(mesh, x_pspec)
g_named_sharding = NamedSharding(mesh, g_pspec)
x_ = jax.device_put(x, x_named_sharding)
gamma_ = jax.device_put(gamma, g_named_sharding)

with warnings.catch_warnings(record=True) as warns:
try:
Expand All @@ -223,8 +231,8 @@ def ref_func(x, gamma):
grad_args=(0, 1),
metric_fwd_dtype=q_dtype,
metric_bwd_dtype=q_dtype,
in_shardings=(x_pspec, g_pspec),
out_shardings=(None, (x_pspec, g_pspec)),
in_shardings=(x_named_sharding, g_named_sharding),
out_shardings=(None, (x_named_sharding, g_named_sharding)),
)
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
Expand Down
10 changes: 6 additions & 4 deletions tests/jax/test_distributed_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,10 @@ def impl_test_softmax(
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))
x_named_sharding = NamedSharding(mesh, x_pspec)
mask_named_sharding = NamedSharding(mesh, mask_pspec)
x_ = jax.device_put(x, x_named_sharding)
mask_ = jax.device_put(mask, mask_named_sharding)

with warnings.catch_warnings(record=True) as warns:
try:
Expand All @@ -116,8 +118,8 @@ def impl_test_softmax(
grad_args=(0,),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)),
in_shardings=(x_named_sharding, mask_named_sharding),
out_shardings=(None, x_named_sharding),
)
except AssertionError as err:
# Softmax should still produce the correct numerical result with
Expand Down
Loading