Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
797808a
rename QuantizeAxis to QuantizeLayout
phu0ngng Apr 1, 2025
3656756
rename get_layout to get_data_layout
phu0ngng Apr 1, 2025
a859f4b
rename q_axis to q_layout
phu0ngng Apr 1, 2025
ae80e5e
rename layout to data_layout
phu0ngng Apr 1, 2025
88c67cd
add q_axis to quantize/
phu0ngng Apr 1, 2025
dc06cb8
format
phu0ngng Apr 1, 2025
76af138
q_axis to quantization.py
phu0ngng Apr 1, 2025
885599d
fixes for quantize/.*py
phu0ngng Apr 1, 2025
72b6149
TestQuantize passed
phu0ngng Apr 1, 2025
fd8cec0
TestActivation passed
phu0ngng Apr 2, 2025
3428fe0
TestFusedQuantize passed
phu0ngng Apr 2, 2025
3e3c51d
rename quantize_axis to flatten_axis
phu0ngng Apr 2, 2025
ac11aea
TestFusedDense passed
phu0ngng Apr 2, 2025
cf70d7b
rework flax layer
phu0ngng Apr 2, 2025
5d491e3
fix for axes_len>1 and most test passed
phu0ngng Apr 3, 2025
fc6be78
enabled 4 gpus
phu0ngng Apr 3, 2025
b7feb1c
added gated act to test encoder
phu0ngng Apr 3, 2025
fb7d001
use dact_lu
phu0ngng Apr 3, 2025
a591065
format
phu0ngng Apr 3, 2025
ce893b9
fix transpose constraint
phu0ngng Apr 3, 2025
a8def8e
fix gemm output shardings
phu0ngng Apr 3, 2025
32a6073
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 3, 2025
eb6749e
fix padding when flattening first dim needs to be padded
phu0ngng Apr 4, 2025
4fcc5ec
clean and minor fixes
phu0ngng Apr 4, 2025
16a54de
update test sizes so that padding is tested
phu0ngng Apr 4, 2025
b14b958
fix sharding constraint for dense
phu0ngng Apr 4, 2025
cc8b9de
add docstring
phu0ngng Apr 4, 2025
6cbd913
merge with main
phu0ngng Apr 4, 2025
9990fb6
rm output sharding as it can be done in flax module
phu0ngng Apr 4, 2025
ac221c5
sharding scale_inv for mxfp8
phu0ngng Apr 4, 2025
148a3a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2025
f90f28b
rm duplications
phu0ngng Apr 4, 2025
3e65ab5
cleanup and format
phu0ngng Apr 4, 2025
ae938a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions examples/jax/encoder/test_model_parallel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ def __call__(self, x, mask, disable_dropout=False):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
mlp_activations=("gelu", "linear"),
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

x = x.reshape(x.shape[0], -1)

if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device.
# Trigger all-gather to collect a complete tensor along the sequence dimension on each device.
x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
)
Expand Down Expand Up @@ -459,30 +460,30 @@ def setUpClass(cls):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
Expand All @@ -491,7 +492,7 @@ def test_te_delayed_scaling_fp8_with_sp(self):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
Expand All @@ -500,7 +501,7 @@ def test_te_mxfp8_with_sp(self):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785


if __name__ == "__main__":
Expand Down
Loading