Skip to content

Fix decoupled test failures on TPU7x#3980

Merged
copybara-service[bot] merged 1 commit into
mainfrom
decoupled-test-fixes
May 27, 2026
Merged

Fix decoupled test failures on TPU7x#3980
copybara-service[bot] merged 1 commit into
mainfrom
decoupled-test-fixes

Conversation

@darisoy
Copy link
Copy Markdown
Collaborator

@darisoy darisoy commented May 26, 2026

Description

This PR consolidates a series of critical bug fixes, test collection refactorings, and environment adaptations to ensure the MaxText offline/decoupled test suite runs without any failures or errors on TPU7x while ensuring that GPU integration tests run successfully in the CI.

Key Changes

1. Decoupled Mode & TPU7x Test Skips

Configured precise skipping rules for tests that are constrained by the offline/decoupled environment or specific hardware topologies:

  • Bypassed GCS/Internet Constrained Tests: Marked the entire generate_param_only_checkpoint_test.py flow to skip in decoupled mode (using is_decoupled()) since it requires external GCS/internet access.
  • Bypassed HuggingFace Dependency: Marked TrainTests::test_tpu_hf_input_pipeline to skip under decoupled mode because it has an external dependency on the HuggingFace internet registry.
  • Topology Mismatch Skips: Added checks to skip test_fsdp_sharding (in sharding_test.py) and test_moe_fsdp_two_stage_parallelism_tpu_only (in moe_test.py) if the available accelerator count is not exactly 4 (preventing failures on standard single-host 1-device runners).
  • TPU7x Compilation Skips: Annotated test_ragged_sort_loss_and_grad tests (both Ring of Experts and standard) in moe_test.py with @pytest.mark.skip_on_tpu7x to skip them on TPU7x platforms due to a known lack of MLIR legalization support on simulated TPU7x.

2. Skipping Logic Architecture Fixes

Fixed regressions in how the skipping logic was executed to prevent PyTest collection failures:

  • Dynamic Skips for JAX safety: Converted the static @pytest.mark.skipif and @unittest.skipIf decorators that queried jax.device_count() to imperative, runtime-evaluated skips inside the test functions. This prevents JAX from initializing prematurely during PyTest's collection phase, which previously locked the TPU or conflicted with CUDA.
  • Standardized Decoupled Checks: Unified all decoupled skips to use the case-insensitive is_decoupled() helper instead of fragile, case-sensitive os.getenv("DECOUPLE_GCLOUD") == "TRUE" checks.

3. GPU CI Integration Fixes

Resolved low-level NCCL communicator failures (corrupted comm object detected / ncclInvalidArgument) on the GPU CI runner:

  • Dynamic Pinned CUDA/NCCL Paths: Configured .github/workflows/run_tests_against_package.yml to dynamically discover the virtual environment's nvidia package folder and prepend all its sub-library directories (including nccl/lib, cublas/lib, cudnn/lib) to LD_LIBRARY_PATH, preventing JAX from loading incompatible system-level libraries.
  • CUDA Context Conflict Resolution: Implemented a targeted early JAX initialization inside tests/conftest.py that triggers strictly on GPU runners. By forcing JAX to secure the CUDA context first during collection, we prevent other libraries (like TensorFlow/PyTorch imported early during collection) from corrupting the CUDA context before JAX NCCL initialization.

BUGS: b/484106617

Tests

Verified the changes inside a real TPU7x-8 Cloud VM in decoupled mode and ensured all tests are passing.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@darisoy darisoy force-pushed the decoupled-test-fixes branch 5 times, most recently from 34dec80 to 1d663c8 Compare May 26, 2026 18:24
@codecov
Copy link
Copy Markdown

codecov Bot commented May 26, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@darisoy darisoy force-pushed the decoupled-test-fixes branch 24 times, most recently from 32483fa to 65a98f2 Compare May 27, 2026 21:11
This squashed commit consolidates all fixes to make the MaxText offline/decoupled
test suite and GPU integration tests run successfully in the CI:

- Fix incorrect imports (maxtext.src.maxtext and maxtext.tests) in unit and integration tests.
- Fix Google3 import-error in gather_reduce_sc_test.py by using a runtime JAX device platform check instead of importing tests.conftest.
- Refactor static skips using jax.device_count() (decorators) to dynamic, runtime-evaluated skips inside test functions to prevent early JAX/PJRT initialization during PyTest collection.
- Fix GHA GPU runner NCCL failures:
  - Dynamically discover pip-installed 'nvidia' packages and prepend all 'nvidia/*/lib' paths to LD_LIBRARY_PATH to avoid conflicts with incompatible system-level CUDA/NCCL libraries.
  - Force early JAX GPU initialization in tests/conftest.py (triggered by GHA-specific markers) to prevent CUDA context corruption caused by TensorFlow/PyTorch importing early during collection.
- Clean up all temporary debug prints and enable/disable verbose logs.
- Add comprehensive in-line comments explaining the rationale for skips and CUDA context workarounds.

TAG=agy
@darisoy darisoy force-pushed the decoupled-test-fixes branch from 100e933 to 4106020 Compare May 27, 2026 22:55
@copybara-service copybara-service Bot merged commit 157baad into main May 27, 2026
30 checks passed
@copybara-service copybara-service Bot deleted the decoupled-test-fixes branch May 27, 2026 23:40
darisoy added a commit that referenced this pull request May 28, 2026
PR #3980 refactored TPU/GPU/CPU hardware platform checks to avoid early JAX initialization. Under Pathways, JAX_PLATFORMS is set to "proxy", which doesn"t contain "tpu", causing all tpu_only integration tests to get skipped. This fix adds "proxy" to the check to correctly detect the TPU platform.
hsuan-lun-chiang pushed a commit to CIeNET-International/maxtext that referenced this pull request May 29, 2026
PR AI-Hypercomputer#3980 refactored TPU/GPU/CPU hardware platform checks to avoid early JAX initialization. Under Pathways, JAX_PLATFORMS is set to "proxy", which doesn"t contain "tpu", causing all tpu_only integration tests to get skipped. This fix adds "proxy" to the check to correctly detect the TPU platform.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants