Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/jax/encoder/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import transformer_engine
from transformer_engine_jax import get_device_compute_capability
from transformer_engine.common import recipe


@lru_cache
Expand All @@ -20,3 +21,21 @@ def is_fp8_supported():
"""Return if FP8 has hardware supported"""
gpu_arch = get_device_compute_capability(0)
return gpu_arch >= 90


@lru_cache
def is_mxfp8_supported():
    """Return True if MXFP8 (block-scaled FP8) has hardware support.

    MXFP8 requires compute capability 10.0+ (Blackwell); checks device 0 only.
    Result is cached, so a later change of visible devices is not reflected.
    """
    gpu_arch = get_device_compute_capability(0)
    return gpu_arch >= 100


def get_fp8_recipe_from_name_string(name: str):
    """Look up and instantiate an FP8 recipe by its name.

    Args:
        name: Recipe identifier; one of "DelayedScaling" or "MXFP8BlockScaling".

    Returns:
        A freshly constructed recipe object of the requested kind.

    Raises:
        ValueError: If *name* does not match a known recipe.
    """
    if name == "DelayedScaling":
        return recipe.DelayedScaling()
    if name == "MXFP8BlockScaling":
        return recipe.MXFP8BlockScaling()
    raise ValueError(f"Invalid fp8_recipe, got {name}")
8 changes: 7 additions & 1 deletion examples/jax/encoder/run_test_multiprocessing_encoder.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ wait

for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_fp8 --num-process=$NUM_GPUS --process-id=$i &
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait

# Launch the MXFP8 multiprocessing encoder test: one backgrounded pytest
# process per GPU (process-id i of NUM_GPUS), then join them all with `wait`.
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
85 changes: 63 additions & 22 deletions examples/jax/encoder/test_model_parallel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding

from common import is_bf16_supported, get_fp8_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
DEVICE_TP_AXIS = "model"
Expand Down Expand Up @@ -217,9 +218,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr


def get_params_sharding(sharding_rules, abs_var_collect, mesh):
Expand Down Expand Up @@ -272,6 +272,19 @@ def train_and_evaluate(args):
args.test_batch_size % num_gpu_dp == 0
), f"Test batch size needs to be multiple of {num_gpu_dp}"

if args.fp8_recipe == "MXFP8BlockScaling":
assert (
args.batch_size / num_gpu_dp % 32 == 0
), "Batch size needs to be multiple of 32 for MXFP8"
assert (
args.test_batch_size / num_gpu_dp % 32 == 0
), "Test batch size needs to be multiple of 32 for MXFP8"

if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None

device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(
devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
Expand All @@ -287,7 +300,9 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]

with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None),
):
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
Expand Down Expand Up @@ -371,21 +386,21 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for training (default: 64)",
help="input batch size for training (default: 128)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=64,
default=128,
metavar="N",
help="input batch size for testing (default: 64)",
help="input batch size for testing (default: 128)",
)
parser.add_argument(
"--max-seq-len",
type=int,
default=32,
default=64,
metavar="N",
help="maximum sequence length (default: 32)",
)
Expand Down Expand Up @@ -416,6 +431,12 @@ def encoder_parser(args):
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--fp8-recipe",
action="store_true",
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
Expand All @@ -426,7 +447,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

gpu_has_fp8, reason = te.fp8.is_fp8_available()
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand All @@ -437,29 +459,48 @@ def setUpClass(cls):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.50 and actual[1] > 0.76

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76

@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.50 and actual[1] > 0.76

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_sp(self):
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.45 and actual[1] > 0.79
assert actual[0] < 0.50 and actual[1] > 0.76

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
"""Test Transformer Engine with DelayedScaling FP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76

@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8_sp(self):
"""Test Transformer Engine with FP8 + SP"""
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
"""Test Transformer Engine with MXFP8 + SP"""
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
assert actual[0] < 0.50 and actual[1] > 0.76


if __name__ == "__main__":
Expand Down
60 changes: 45 additions & 15 deletions examples/jax/encoder/test_multigpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
from jax.experimental import mesh_utils
from jax.sharding import PartitionSpec, NamedSharding

from common import is_bf16_supported, get_fp8_recipe_from_name_string
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
from transformer_engine.jax.quantize import is_fp8_available, ScalingMode

from common import is_bf16_supported

DEVICE_DP_AXIS = "data"
PARAMS_KEY = "params"
Expand Down Expand Up @@ -198,9 +199,8 @@ def get_datasets(max_seq_len):
def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
)
func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr


def get_params_sharding(sharding_rules, abs_var_collect, mesh):
Expand Down Expand Up @@ -243,6 +243,18 @@ def train_and_evaluate(args):
num_gpu = jax.local_device_count()
assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
if args.fp8_recipe == "MXFP8BlockScaling":
assert (
args.batch_size / num_gpu % 32 == 0
), "Batch size needs to be multiple of 32 for MXFP8"
assert (
args.test_batch_size / num_gpu % 32 == 0
), "Test batch size needs to be multiple of 32 for MXFP8"

if args.use_fp8:
fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe)
else:
fp8_recipe = None

device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh:
Expand All @@ -257,7 +269,9 @@ def train_and_evaluate(args):
label_shape = [args.batch_size]

with te.fp8_autocast(
args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
enabled=args.use_fp8,
fp8_recipe=fp8_recipe,
mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None),
):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
Expand Down Expand Up @@ -344,16 +358,16 @@ def encoder_parser(args):
parser.add_argument(
"--batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for training (default: 128)",
help="input batch size for training (default: 256)",
)
parser.add_argument(
"--test-batch-size",
type=int,
default=128,
default=256,
metavar="N",
help="input batch size for testing (default: 128)",
help="input batch size for testing (default: 256)",
)
parser.add_argument(
"--max-seq-len",
Expand Down Expand Up @@ -389,14 +403,21 @@ def encoder_parser(args):
default=False,
help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--fp8-recipe",
action="store_true",
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)

return parser.parse_args(args)


class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

gpu_has_fp8, reason = te.fp8.is_fp8_available()
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand All @@ -407,14 +428,23 @@ def setUpClass(cls):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.535 and actual[1] > 0.73

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73

@unittest.skipIf(not gpu_has_fp8, reason)
def test_te_fp8(self):
"""Test Transformer Engine with FP8"""
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.535 and actual[1] > 0.73


if __name__ == "__main__":
Expand Down
Loading