Summary
Running generate_wan.py with enable_profiler=True on a multi-process TPU pod (≥2 JAX processes) hangs during the profiling generation pass until a barrier times out with DEADLINE_EXCEEDED and the job aborts. The warm-up (non-profiled) pass completes normally on all processes; only the profiled pass deadlocks. The bug does not show up on single-process / single-device runs, which is why it wasn't caught earlier.
Filing per the request in #426 to track the root cause with reproduction details.
Environment
- MaxDiffusion: main at commit 9616d1c (bug present; introduced in 589c3d5).
- Workload: WAN2.2-I2V, 1920×1088×81, num_inference_steps=8, mesh DP1·FSDP1·CP4·TP4, attention=ulysses_custom.
- Hardware: multi-host TPU pod (4 JAX processes / 16 chips). Originally observed on a v4-16 pod; also reproducible on v6e-16. Any ≥2-process pod triggers it.
- JAX: jax[tpu] (multi-host, one process per host).
Steps to reproduce
On every host of the pod (e.g. gcloud compute tpus tpu-vm ssh --worker=all --command=…), run the same invocation so all processes join the mesh:
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
enable_profiler=True \
num_inference_steps=8 \
height=1088 width=1920 \
per_device_batch_size=0.0625 \
ici_data_parallelism=1 ici_fsdp_parallelism=1 \
ici_context_parallelism=4 ici_tensor_parallelism=4 \
attention=ulysses_custom \
run_name=wan22_profiler_repro \
output_dir=/tmp/wan_output base_output_directory=/tmp/wan_output
Only two ingredients are essential to trigger it: more than one JAX process and enable_profiler=True. Resolution, step count, and mesh are incidental (the values above are just what we ran).
Observed behavior
Warm-up finishes on all processes, then the profiling pass deadlocks and aborts (log from the reproduction on a 4-process / 16-device pod):
generation_time: 522.09s # warm-up ok on all 4 processes
# hung — generation_time_with_profiler never logged
DEADLINE_EXCEEDED: Barrier timed out.
# of tasks that reached the barrier: 3/4.
The first task at the barrier: 0.
Aborted (core dumped)
Process 0 blocks inside call_pipeline; processes 1–3 return early, reach the shutdown barrier first (3/4), and the job times out.
Root cause
In generate_wan.run(), the profiling pass is gated on max_utils.profiler_enabled(config). That helper carries an implicit jax.process_index() == 0 guard (via _jax_profiler_enabled / _ml_diagnostics_profiler_enabled), so it returns True only on the coordinator. As a result, only process 0 enters the profiling block and calls call_pipeline, which issues collective ops (AllReduce/AllGather) that require every process to participate. Processes 1 through N-1 skip the block and return, so the collective never forms a quorum → indefinite hang → barrier timeout.
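To illustrate the failure mode without a TPU pod, here is a minimal single-machine simulation. It assumes Python threads as stand-ins for JAX processes and a threading.Barrier as a stand-in for a cross-process collective; this is not how JAX implements collectives, only an analogy for "completes only when every participant arrives". The hypothetical simulate() function runs a warm-up pass that every process enters, then a profiled pass gated either on process_index == 0 (the buggy gate) or on a flag that is the same on all processes (the fixed gate).

```python
import threading


def simulate(num_processes: int, gate_on_coordinator_only: bool,
             timeout: float = 1.0) -> dict[int, list[str]]:
    """Simulate a warm-up pass plus a profiled pass across N 'processes'."""
    # Stand-in for a collective (AllReduce/AllGather): every party must arrive.
    # threading.Barrier is cyclic, so the same object serves both passes.
    collective = threading.Barrier(num_processes, timeout=timeout)
    results: dict[int, list[str]] = {}
    lock = threading.Lock()

    def worker(process_index: int) -> None:
        status = []
        # Warm-up pass: all processes participate, so the quorum forms.
        collective.wait()
        status.append("warmup_ok")

        # Buggy gate mirrors a helper that is True only on process 0;
        # the fixed gate uses a flag that is True on every process.
        enabled = (process_index == 0) if gate_on_coordinator_only else True
        if enabled:
            try:
                collective.wait()
                status.append("profiled_ok")
            except threading.BrokenBarrierError:
                # The waiting process times out because the others never arrive.
                status.append("profiled_timeout")

        with lock:
            results[process_index] = status

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(num_processes)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


if __name__ == "__main__":
    print("buggy:", simulate(4, gate_on_coordinator_only=True))
    print("fixed:", simulate(4, gate_on_coordinator_only=False))
```

With the coordinator-only gate, process 0's wait times out while processes 1-3 finish early, which mirrors the "3/4 tasks reached the barrier" pattern in the log. With the shared gate, all four processes complete both passes.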
Expected behavior
With enable_profiler=True, all processes should enter the profiling pass together so the collectives complete, matching the warm-up pass behavior.
Fix
Proposed in #426: gate the profiling pass on the raw original_enable_profiler boolean (saved from config.enable_profiler before warm-up temporarily disables it) instead of on profiler_enabled(config). Because that flag is taken directly from the config, it is True on every process, so all processes run the profiling pass. The jax.process_index() == 0 guard inside profiler_enabled is correct for profiler context management and is left untouched. (1-line change; verified end-to-end: the profiled pass then completes on all processes and emits the XLA trace.)
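A hedged sketch of the gating change, assuming a simplified run() whose structure may well differ from the real generate_wan.run(). Here profiler_enabled is a hypothetical stand-in for max_utils.profiler_enabled that reproduces only its coordinator-only behavior; call_pipeline and trace_log are placeholders, and restoring config.enable_profiler after warm-up is an assumption about the surrounding code.

```python
from types import SimpleNamespace


def profiler_enabled(config, process_index: int) -> bool:
    # Hypothetical stand-in for max_utils.profiler_enabled(config): like the
    # real helper, it is True only on the coordinator (process 0).
    return bool(config.enable_profiler) and process_index == 0


def run(config, process_index: int, call_pipeline, trace_log: list) -> None:
    # Save the raw flag, then disable profiling for the warm-up pass.
    original_enable_profiler = config.enable_profiler
    config.enable_profiler = False
    call_pipeline(config)  # warm-up: every process participates
    config.enable_profiler = original_enable_profiler  # assumed restore step

    # Buggy gate (before the fix): `if profiler_enabled(config):`, which is
    # False on processes 1..N-1, so only process 0 entered the profiled pass.
    # Fixed gate: the raw flag has the same value on every process.
    if original_enable_profiler:
        # Trace start/stop stays coordinator-only, as in the real helper.
        tracing = profiler_enabled(config, process_index)
        if tracing:
            trace_log.append(f"start trace on process {process_index}")
        call_pipeline(config)  # profiled pass: every process participates
        if tracing:
            trace_log.append(f"stop trace on process {process_index}")


if __name__ == "__main__":
    for pid in range(4):
        calls, traces = [], []
        cfg = SimpleNamespace(enable_profiler=True)
        run(cfg, pid, lambda c: calls.append(c.enable_profiler), traces)
        print(pid, calls, traces)
```

In this sketch every process calls the pipeline twice (warm-up, then profiled), while only process 0 records trace start/stop, which is the split the proposed fix aims for: a uniform gate for collectives, a coordinator-only guard for the profiler itself.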