diff --git a/tests/attention_test.py b/tests/attention_test.py
index 3ce54e8251..c5e7c1bbab 100644
--- a/tests/attention_test.py
+++ b/tests/attention_test.py
@@ -23,27 +23,22 @@
 from absl.testing import parameterized
 from flax import nnx
-from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
-from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P
-from MaxText import max_utils
+from jax.sharding import AxisType, Mesh
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import (
     AttentionType,
     DECODING_ACTIVE_SEQUENCE_INDICATOR,
-    EP_AS_CONTEXT,
     MODEL_MODE_AUTOREGRESSIVE,
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
-    ShardMode,
 )
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.layers.attention_mla import MLA
 from MaxText.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask
 from MaxText.layers.attentions import Attention
-from MaxText.sharding import maybe_shard_with_name
 import numpy as np
 import pytest
@@ -693,15 +688,13 @@ def test_tpu_flash_attention_context_parallel(
     )
     nnx.update(attention_as_mha_flash_cp, generic_state)
-    mha_generic_flash_cp_output = (
-        attention_test_util.forward_with_context_expert_parallelism(
-            cfg_cp,
-            mesh_cp,
-            attention_as_mha_flash_cp,
-            lnx,
-            decoder_segment_ids,
-            decoder_positions,
-        )
+    mha_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism(
+        cfg_cp,
+        mesh_cp,
+        attention_as_mha_flash_cp,
+        lnx,
+        decoder_segment_ids,
+        decoder_positions,
     )

     # This removes all sharding information and makes them standard NumPy arrays.
@@ -1479,15 +1472,13 @@ def test_tpu_flash_attention_context_parallel(
         rngs=self.nnx_rng,
     )
     nnx.update(attention_as_mla_flash_cp, generic_state)
-    mla_generic_flash_cp_output = (
-        attention_test_util.forward_with_context_expert_parallelism(
-            cfg_cp,
-            mesh_cp,
-            attention_as_mla_flash_cp,
-            lnx,
-            decoder_segment_ids,
-            decoder_positions,
-        )
+    mla_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism(
+        cfg_cp,
+        mesh_cp,
+        attention_as_mla_flash_cp,
+        lnx,
+        decoder_segment_ids,
+        decoder_positions,
     )

     # This removes all sharding information and makes them standard NumPy arrays.
diff --git a/tests/integration_tests/train_tests.py b/tests/integration_tests/train_tests.py
index d29338e502..cc81a8e063 100644
--- a/tests/integration_tests/train_tests.py
+++ b/tests/integration_tests/train_tests.py
@@ -417,6 +417,7 @@ def test_gpu_synthetic_model_ag_once(self):

   @pytest.mark.integration_test
   @pytest.mark.gpu_only
+  @pytest.mark.scheduled_only
   def test_gpu_zero1_gradient_accumulation(self):
     os.environ["NVTE_FUSED_ATTN"] = "1"  # Enable fused attention
     zero1_ga = [  # tests Zero-1 optimizer sharding with gradient accumulation