diff --git a/.github/workflows/run_tests_internal.yml b/.github/workflows/run_tests_internal.yml index 7fa9a7286..8132a3a13 100644 --- a/.github/workflows/run_tests_internal.yml +++ b/.github/workflows/run_tests_internal.yml @@ -67,4 +67,4 @@ jobs: FINAL_PYTEST_MARKER="${{ inputs.pytest_marker }} and not scheduled_only" fi python3 -m pip install -e . --no-dependencies && - python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0 + LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536' python3 -m pytest -v -m "${FINAL_PYTEST_MARKER}" --durations=0 diff --git a/tests/train_compile_test.py b/tests/train_compile_test.py index 06498e868..90dc92a83 100644 --- a/tests/train_compile_test.py +++ b/tests/train_compile_test.py @@ -493,7 +493,7 @@ def test_moe_deepseek_scanned_bf16(self): "megablox=False", "per_device_batch_size=2", "max_target_length=1024", - "attention=dot_product", # Change to flash attention once it works for MLA + "attention=flash", "dtype=bfloat16", "weight_dtype=bfloat16", "scan_layers=True", @@ -518,7 +518,7 @@ def test_moe_deepseek_unscanned_bf16(self): "megablox=False", "per_device_batch_size=1", "max_target_length=1024", - "attention=dot_product", # Change to flash attention once it works for MLA + "attention=flash", "dtype=bfloat16", "weight_dtype=bfloat16", "scan_layers=False", @@ -541,7 +541,7 @@ def test_moe_deepseek_with_device_limit(self): "megablox=False", "per_device_batch_size=1", "max_target_length=1024", - "attention=dot_product", # Change to flash attention once it works for MLA + "attention=flash", "dtype=bfloat16", "weight_dtype=bfloat16", "n_routing_groups=8", @@ -565,7 +565,7 @@ def test_moe_deepseek_without_device_limit(self): "megablox=False", "per_device_batch_size=1", "max_target_length=1024", - "attention=dot_product", # Change to flash attention once it works for MLA + "attention=flash", "dtype=bfloat16", "weight_dtype=bfloat16", "n_routing_groups=-1", @@ -585,7 +585,7 @@ def test_moe_deepseek_pipeline_subset(self): "compile_topology_num_slices=8", "use_iota_embed=true", "model_name=deepseek3-671b", - "megablox=False", # dropless not yet supported (b/418313093) + "megablox=True", "sparse_matmul=False", "capacity_factor=1", "per_device_batch_size=1",