generate_wan: profiling pass deadlocks on multi-process TPU (profiler gated to process 0 only) #435

@edgexyz

Summary

Running generate_wan.py with enable_profiler=True on a multi-process TPU pod (≥2 JAX processes) hangs during the profiling generation pass until the job aborts with a barrier DEADLINE_EXCEEDED. The warm-up (non-profiled) pass completes normally on all processes; only the profiled pass deadlocks. The bug does not appear on single-process / single-device runs, which is why it wasn't caught earlier.

Filing per the request in #426 to track the root cause with reproduction details.

Environment

  • MaxDiffusion: main at commit 9616d1c (bug present; introduced in 589c3d5).
  • Workload: WAN2.2-I2V, 1920×1088×81, num_inference_steps=8, mesh DP1·FSDP1·CP4·TP4, attention=ulysses_custom.
  • Hardware: multi-host TPU pod (4 JAX processes / 16 chips). Originally observed on a v4-16 pod; also reproducible on v6e-16. Any ≥2-process pod triggers it.
  • JAX: jax[tpu] (multi-host, one process per host).

Steps to reproduce

On every host of the pod (e.g. gcloud compute tpus tpu-vm ssh --worker=all --command=…), run the same invocation so all processes join the mesh:

python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_27b.yml \
  enable_profiler=True \
  num_inference_steps=8 \
  height=1088 width=1920 \
  per_device_batch_size=0.0625 \
  ici_data_parallelism=1 ici_fsdp_parallelism=1 \
  ici_context_parallelism=4 ici_tensor_parallelism=4 \
  attention=ulysses_custom \
  run_name=wan22_profiler_repro \
  output_dir=/tmp/wan_output base_output_directory=/tmp/wan_output

Only two ingredients are essential to trigger it: more than one JAX process and enable_profiler=True. Resolution, step count, and mesh are incidental (the values above are just what we ran).

Observed behavior

Warm-up finishes on all processes, then the profiling pass deadlocks and aborts (log from the reproduction on a 4-process / 16-device pod):

generation_time: 522.09s        # warm-up ok on all 4 processes
# hung — generation_time_with_profiler never logged

DEADLINE_EXCEEDED: Barrier timed out.
# of tasks that reached the barrier: 3/4.
The first task at the barrier: 0.
Aborted (core dumped)

Process 0 blocks inside call_pipeline; processes 1–3 return early, reach the shutdown barrier first (3/4), and the job times out.

Root cause

In generate_wan.run(), the profiling pass is gated on max_utils.profiler_enabled(config). That helper carries an implicit jax.process_index() == 0 guard (via _jax_profiler_enabled / _ml_diagnostics_profiler_enabled), so it returns True only on the coordinator. As a result, only process 0 enters the profiling block and calls call_pipeline, which issues collective ops (AllReduce/AllGather) that require every process to participate. All other processes skip the block and return, so the collective never forms a quorum → indefinite hang → barrier timeout.
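
The divergence can be illustrated without a TPU. The following is a minimal sketch, not the MaxDiffusion code: threads stand in for JAX processes, a threading.Barrier stands in for the collective inside call_pipeline, and the local profiler_enabled function is a hypothetical stand-in that only mimics the process-0 guard described above. The short timeout plays the role of the barrier deadline.

```python
import threading

NUM_PROCESSES = 4  # matches the 4-process pod in the report


def profiler_enabled(enable_profiler: bool, process_index: int) -> bool:
    # Hypothetical stand-in for max_utils.profiler_enabled:
    # true only on the coordinator (process 0).
    return enable_profiler and process_index == 0


def simulate(gate, timeout_s: float = 0.5) -> dict:
    """Run one thread per simulated process; those passing `gate` join a collective."""
    # The barrier needs all NUM_PROCESSES participants, like an AllReduce.
    collective = threading.Barrier(NUM_PROCESSES)
    results = {}

    def worker(idx: int) -> None:
        if not gate(idx):
            results[idx] = "skipped"  # returns early, never joins the collective
            return
        try:
            collective.wait(timeout=timeout_s)
            results[idx] = "completed"
        except threading.BrokenBarrierError:
            results[idx] = "timed out"  # no quorum was ever formed

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(NUM_PROCESSES)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Gating on the process-0-only helper: only process 0 waits, and it times out.
buggy = simulate(lambda idx: profiler_enabled(True, idx))
print(buggy)
```

In this simulation process 0 is the only participant in the collective, so it waits until the timeout fires while the other three return immediately, which mirrors the hang described in this section.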

Expected behavior

With enable_profiler=True, all processes should enter the profiling pass together so the collectives complete, matching the warm-up pass behavior.

Fix

Proposed in #426: gate the profiling pass on the raw original_enable_profiler boolean (saved from config.enable_profiler before warm-up temporarily disables it) instead of profiler_enabled(config). That value is True on all processes, so every process runs the profiling pass. The jax.process_index() == 0 guard inside profiler_enabled is correct for profiler context management and is left untouched. (1-line change; verified end-to-end — profiled pass then completes on all processes and emits the XLA trace.)
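
The shape of the proposed gating can be sketched as follows. This is an illustration under stated assumptions, not the actual diff in #426: Config, the body of run, and the restore step after warm-up are hypothetical simplifications, and process_index is passed in explicitly instead of calling jax.process_index() so the snippet runs anywhere.

```python
from dataclasses import dataclass


@dataclass
class Config:
    # Hypothetical stand-in for the real config object.
    enable_profiler: bool


def profiler_enabled(config: Config, process_index: int) -> bool:
    # Mimics the process-0 guard; still appropriate for deciding who records the trace.
    return config.enable_profiler and process_index == 0


def run(config: Config, process_index: int) -> list:
    passes = []

    # Save the raw flag before warm-up temporarily disables profiling.
    original_enable_profiler = config.enable_profiler
    config.enable_profiler = False
    passes.append("warmup")
    config.enable_profiler = original_enable_profiler  # assumed restore step

    # Gate on the raw boolean, which has the same value on every process,
    # so all processes enter the profiled pass and its collectives together.
    if original_enable_profiler:
        tracing = profiler_enabled(config, process_index)  # process 0 only
        passes.append("profiled+trace" if tracing else "profiled")
    return passes


for idx in range(4):
    print(idx, run(Config(enable_profiler=True), idx))
```

The design point is that the decision to run the pass and the decision to record a trace are separated: the first must agree across processes, while the second can remain coordinator-only.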
