Fix decoupled test failures on TPU7x#3980
Merged
Merged
Conversation
34dec80 to
1d663c8
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
32483fa to
65a98f2
Compare
khatwanimohit
approved these changes
May 27, 2026
igorts-git
approved these changes
May 27, 2026
This squashed commit consolidates all fixes to make the MaxText offline/decoupled test suite and GPU integration tests run successfully in the CI: - Fix incorrect imports (maxtext.src.maxtext and maxtext.tests) in unit and integration tests. - Fix Google3 import-error in gather_reduce_sc_test.py by using a runtime JAX device platform check instead of importing tests.conftest. - Refactor static skips using jax.device_count() (decorators) to dynamic, runtime-evaluated skips inside test functions to prevent early JAX/PJRT initialization during PyTest collection. - Fix GHA GPU runner NCCL failures: - Dynamically discover pip-installed 'nvidia' packages and prepend all 'nvidia/*/lib' paths to LD_LIBRARY_PATH to avoid conflicts with incompatible system-level CUDA/NCCL libraries. - Force early JAX GPU initialization in tests/conftest.py (triggered by GHA-specific markers) to prevent CUDA context corruption caused by TensorFlow/PyTorch importing early during collection. - Clean up all temporary debug prints and enable/disable verbose logs. - Add comprehensive in-line comments explaining the rationale for skips and CUDA context workarounds. TAG=agy
100e933 to
4106020
Compare
4 tasks
darisoy
added a commit
that referenced
this pull request
May 28, 2026
PR #3980 refactored TPU/GPU/CPU hardware platform checks to avoid early JAX initialization. Under Pathways, JAX_PLATFORMS is set to "proxy", which doesn"t contain "tpu", causing all tpu_only integration tests to get skipped. This fix adds "proxy" to the check to correctly detect the TPU platform.
hsuan-lun-chiang
pushed a commit
to CIeNET-International/maxtext
that referenced
this pull request
May 29, 2026
PR AI-Hypercomputer#3980 refactored TPU/GPU/CPU hardware platform checks to avoid early JAX initialization. Under Pathways, JAX_PLATFORMS is set to "proxy", which doesn"t contain "tpu", causing all tpu_only integration tests to get skipped. This fix adds "proxy" to the check to correctly detect the TPU platform.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR consolidates a series of critical bug fixes, test collection refactorings, and environment adaptations to ensure the MaxText offline/decoupled test suite runs without any failures or errors on TPU7x while ensuring that GPU integration tests run successfully in the CI.
Key Changes
1. Decoupled Mode & TPU7x Test Skips
Configured precise skipping rules for tests that are constrained by the offline/decoupled environment or specific hardware topologies:
generate_param_only_checkpoint_test.pyflow to skip in decoupled mode (usingis_decoupled()) since it requires external GCS/internet access.TrainTests::test_tpu_hf_input_pipelineto skip under decoupled mode because it has an external dependency on the HuggingFace internet registry.test_fsdp_sharding(insharding_test.py) andtest_moe_fsdp_two_stage_parallelism_tpu_only(inmoe_test.py) if the available accelerator count is not exactly 4 (preventing failures on standard single-host 1-device runners).test_ragged_sort_loss_and_gradtests (both Ring of Experts and standard) inmoe_test.pywith@pytest.mark.skip_on_tpu7xto skip them on TPU7x platforms due to a known lack of MLIR legalization support on simulated TPU7x.2. Skipping Logic Architecture Fixes
Fixed regressions in how the skipping logic was executed to prevent PyTest collection failures:
@pytest.mark.skipifand@unittest.skipIfdecorators that queriedjax.device_count()to imperative, runtime-evaluated skips inside the test functions. This prevents JAX from initializing prematurely during PyTest's collection phase, which previously locked the TPU or conflicted with CUDA.is_decoupled()helper instead of fragile, case-sensitiveos.getenv("DECOUPLE_GCLOUD") == "TRUE"checks.3. GPU CI Integration Fixes
Resolved low-level NCCL communicator failures (
corrupted comm object detected/ncclInvalidArgument) on the GPU CI runner:.github/workflows/run_tests_against_package.ymlto dynamically discover the virtual environment'snvidiapackage folder and prepend all its sub-library directories (includingnccl/lib,cublas/lib,cudnn/lib) toLD_LIBRARY_PATH, preventing JAX from loading incompatible system-level libraries.tests/conftest.pythat triggers strictly on GPU runners. By forcing JAX to secure the CUDA context first during collection, we prevent other libraries (like TensorFlow/PyTorch imported early during collection) from corrupting the CUDA context before JAX NCCL initialization.BUGS: b/484106617
Tests
Verified the changes inside a real TPU7x-8 Cloud VM in decoupled mode and ensured all tests are passing.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.