39 changes: 15 additions & 24 deletions tests/attention_test.py
@@ -23,27 +23,22 @@
 
 from absl.testing import parameterized
 from flax import nnx
-from flax.linen import partitioning as nn_partitioning
 import jax
 import jax.numpy as jnp
-from jax.sharding import AxisType, Mesh, NamedSharding, PartitionSpec as P
-from MaxText import max_utils
+from jax.sharding import AxisType, Mesh
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import (
     AttentionType,
-    DECODING_ACTIVE_SEQUENCE_INDICATOR,
-    EP_AS_CONTEXT,
     MODEL_MODE_AUTOREGRESSIVE,
     MODEL_MODE_PREFILL,
     MODEL_MODE_TRAIN,
-    ShardMode,
 )
 from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.layers.attention_mla import MLA
 from MaxText.layers.attention_op import ChunkedCausalMask, _generate_chunk_attention_mask, _make_bidirectional_block_mask
 from MaxText.layers.attentions import Attention
-from MaxText.sharding import maybe_shard_with_name
 import numpy as np
 import pytest
 
@@ -693,15 +688,13 @@ def test_tpu_flash_attention_context_parallel(
     )
     nnx.update(attention_as_mha_flash_cp, generic_state)
 
-    mha_generic_flash_cp_output = (
-        attention_test_util.forward_with_context_expert_parallelism(
-            cfg_cp,
-            mesh_cp,
-            attention_as_mha_flash_cp,
-            lnx,
-            decoder_segment_ids,
-            decoder_positions,
-        )
-    )
+    mha_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism(
+        cfg_cp,
+        mesh_cp,
+        attention_as_mha_flash_cp,
+        lnx,
+        decoder_segment_ids,
+        decoder_positions,
+    )
 
     # This removes all sharding information and makes them standard NumPy arrays.
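The de-sharding step that comment refers to lies past the truncated hunk. As an illustration only, with variable names assumed from the surrounding test and tolerances that are not taken from the PR, the conversion typically looks like:

import jax
import numpy as np

# jax.device_get gathers a (possibly sharded) jax.Array back to the host
# and returns a plain NumPy array with no sharding metadata attached.
baseline = jax.device_get(mha_generic_output)  # assumed baseline output from earlier in the test
cp_result = jax.device_get(mha_generic_flash_cp_output)

# Once both are ordinary NumPy arrays, a plain elementwise tolerance check works.
np.testing.assert_allclose(baseline, cp_result, rtol=1e-2, atol=1e-2)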
@@ -1479,15 +1472,13 @@ def test_tpu_flash_attention_context_parallel(
         rngs=self.nnx_rng,
     )
     nnx.update(attention_as_mla_flash_cp, generic_state)
-    mla_generic_flash_cp_output = (
-        attention_test_util.forward_with_context_expert_parallelism(
-            cfg_cp,
-            mesh_cp,
-            attention_as_mla_flash_cp,
-            lnx,
-            decoder_segment_ids,
-            decoder_positions,
-        )
-    )
+    mla_generic_flash_cp_output = attention_test_util.forward_with_context_expert_parallelism(
+        cfg_cp,
+        mesh_cp,
+        attention_as_mla_flash_cp,
+        lnx,
+        decoder_segment_ids,
+        decoder_positions,
+    )
 
     # This removes all sharding information and makes them standard NumPy arrays.
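The import hunk above keeps only AxisType and Mesh from jax.sharding, which is all the test needs to build the device mesh it hands to the helper. A minimal sketch of such a context-parallel mesh, with an axis name that is illustrative rather than taken from the PR:

import jax
import numpy as np
from jax.sharding import AxisType, Mesh

# Illustrative 1-D mesh: all devices sit on a single "context" axis,
# which a context-parallel attention splits the sequence dimension over.
devices = np.array(jax.devices())
mesh_cp = Mesh(
    devices,
    axis_names=("context",),
    axis_types=(AxisType.Explicit,),
)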
1 change: 1 addition & 0 deletions tests/integration_tests/train_tests.py
@@ -417,6 +417,7 @@ def test_gpu_synthetic_model_ag_once(self):
 
   @pytest.mark.integration_test
   @pytest.mark.gpu_only
+  @pytest.mark.scheduled_only
   def test_gpu_zero1_gradient_accumulation(self):
     os.environ["NVTE_FUSED_ATTN"] = "1"  # Enable fused attention
     zero1_ga = [  # tests Zero-1 optimizer sharding with gradient accumulation
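The new @pytest.mark.scheduled_only marker moves this GPU test out of per-commit runs. How the marker is registered and filtered is not part of this diff; a sketch of the usual pytest wiring, assuming MaxText follows the common pattern:

# conftest.py (sketch; MaxText's actual conftest/CI configuration may differ)
def pytest_configure(config):
    # Register the custom marker so pytest does not warn about it.
    config.addinivalue_line(
        "markers",
        "scheduled_only: run only in scheduled (e.g. nightly) CI jobs",
    )

A per-commit CI job would then deselect the test with pytest -m "integration_test and not scheduled_only", while the scheduled job runs without that filter.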