diff --git a/README.md b/README.md index 5ddcc323e..48785fdc5 100755 --- a/README.md +++ b/README.md @@ -210,6 +210,7 @@ After installation completes, run the training script. --xla_enable_async_all_gather=true \ --xla_tpu_scoped_vmem_limit_kib=65536 \ --xla_tpu_enable_async_all_to_all=true \ + --xla_tpu_enable_latency_hiding_scheduler=true \ --xla_tpu_enable_all_experimental_scheduler_features=true \ --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ --xla_tpu_host_transfer_overlap_limit=24 \ @@ -329,6 +330,7 @@ After installation completes, run the training script. --xla_enable_async_all_gather=true \ --xla_tpu_scoped_vmem_limit_kib=65536 \ --xla_tpu_enable_async_all_to_all=true \ + --xla_tpu_enable_latency_hiding_scheduler=true \ --xla_tpu_enable_all_experimental_scheduler_features=true \ --xla_tpu_enable_scheduler_memory_pressure_tracking=true \ --xla_tpu_host_transfer_overlap_limit=24 \ @@ -597,6 +599,23 @@ To generate images, run the following command: Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` tuning is optional and can still be used for hardware-specific tuning. + For TPU multihost 2D context parallelism, use `attention="ulysses_ring"`. + This shards self-attention sequence over `context` x `tensor`, runs the Ulysses all-to-all over the `tensor` + mesh axis, and reuses Tokamax ring attention over the `context` mesh axis. The number of attention heads must + be divisible by `ici_tensor_parallelism`; a typical multihost setup uses DCN context for the ring axis and ICI + tensor for the Ulysses axis. On TPU7x, keep `dcn_tensor_parallelism=1` and set `ici_tensor_parallelism >= 2` + so the dual chiplets exposed as two JAX devices are grouped by Ulysses rather than ring. 
+ + ```bash + python src/maxdiffusion/generate_wan.py \ + src/maxdiffusion/configs/base_wan_i2v_27b.yml \ + attention="ulysses_ring" \ + dcn_context_parallelism=<ring_size> \ + ici_context_parallelism=1 \ + ici_tensor_parallelism=<ulysses_size> \ + ... + ``` + + In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly `~10%` compared with flash attention, with about `~20s` lower latency on the v6e-8 and v7x-8 TPU setup. ### Caching Mechanisms diff --git a/bench_attn.sh b/bench_attn.sh new file mode 100644 index 000000000..b3ae8ae39 --- /dev/null +++ b/bench_attn.sh @@ -0,0 +1,169 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ── config ──────────────────────────────────────────────────────────────────── +REPO_DIR=/dev/shm/maxdiffusion +VENV=$REPO_DIR/.venv-tpu +CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml +RESULTS_ROOT=$REPO_DIR/bench_results +OUTPUT_ROOT=$REPO_DIR/bench_outputs + +WORKER1_USER=sa_112155357684894056033 +WORKER1_IP=10.154.0.59 +WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB +SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine +SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts +SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no \ + -o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes \ + -o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS" + +export LIBTPU_INIT_ARGS=\ +'--xla_tpu_dvfs_p_state=7 '\ +'--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '\ +'--xla_tpu_megacore_fusion_allow_ags=false '\ +'--xla_enable_async_collective_permute=true '\ +'--xla_tpu_enable_ag_backward_pipelining=true '\ +'--xla_tpu_enable_data_parallel_all_reduce_opt=true '\ +'--xla_tpu_data_parallel_opt_different_sized_ops=true '\ +'--xla_tpu_enable_async_collective_fusion=true '\ +'--xla_tpu_enable_async_collective_fusion_multiple_steps=true '\ +'--xla_tpu_overlap_compute_collective_tc=true '\ 
+'--xla_enable_async_all_gather=true '\ +'--xla_tpu_scoped_vmem_limit_kib=65536 '\ +'--xla_tpu_enable_async_all_to_all=true '\ +'--xla_tpu_enable_latency_hiding_scheduler=true '\ +'--xla_tpu_enable_all_experimental_scheduler_features=true '\ +'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\ +'--xla_tpu_host_transfer_overlap_limit=24 '\ +'--xla_tpu_aggressive_opt_barrier_removal=ENABLED '\ +'--xla_lhs_prioritize_async_depth_over_stall=ENABLED '\ +'--xla_should_allow_loop_variant_parameter_in_chain=ENABLED '\ +'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\ +'--xla_max_concurrent_host_send_recv=100 '\ +'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\ +'--xla_latency_hiding_scheduler_rerun=5 '\ +'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\ +'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\ +'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\ +'--xla_tpu_assign_all_reduce_scatter_layout=true '\ +'--xla_max_concurrent_async_collective_permutes=16 '\ +'--xla_tpu_enable_ici_ag_pipelining=true' + +source "$VENV/bin/activate" +export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-} +export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface +export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub +export HF_HUB_ENABLE_HF_TRANSFER=1 +export HF_HUB_OFFLINE=1 +export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax +export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla +export TMPDIR=/dev/shm/maxdiffusion_cache/tmp + +# ── helper ──────────────────────────────────────────────────────────────────── +# run_case +# pure (flash/tokamax_ring/ulysses): ici_a=cp, ici_b unused +# hybrid (ulysses_ring): ici_a=ring, ici_b=ulysses +run_case() { + local run_name="$1" + local attention="$2" + local ici_dp="$3" + local ici_a="$4" + local ici_b="$5" + local pdb="$6" # per_device_batch_size + + local results_dir="$RESULTS_ROOT/$run_name" + rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name" + mkdir -p "$results_dir" + rm -f 
/tmp/libtpu_lockfile + mkdir -p "$TMPDIR" + + local ici_cp ici_tp ici_ring ici_ulysses + if [[ "$attention" == "ulysses_ring" || "$attention" == "ulysses_tokamax_ring" ]]; then + ici_cp=1; ici_tp=1; ici_ring="$ici_a"; ici_ulysses="$ici_b" + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses pdb=$pdb) ──" + else + ici_cp="$ici_a"; ici_tp=1; ici_ring=1; ici_ulysses=1 + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──" + fi + + # Only capture profiler for bs=2 (per_device_batch_size=0.125) + local profiler_args="enable_profiler=False" + if [[ "$pdb" == "0.125" ]]; then + profiler_args="enable_profiler=True skip_first_n_steps_for_profiler=5 profiler_steps=10" + fi + + local common_args="run_name=$run_name \ + attention=$attention \ + ici_data_parallelism=$ici_dp \ + ici_fsdp_parallelism=1 \ + ici_context_parallelism=$ici_cp \ + ici_tensor_parallelism=$ici_tp \ + ici_ring_parallelism=$ici_ring \ + ici_ulysses_parallelism=$ici_ulysses \ + dcn_data_parallelism=1 \ + dcn_fsdp_parallelism=1 \ + dcn_context_parallelism=1 \ + dcn_tensor_parallelism=1 \ + dcn_ring_parallelism=1 \ + dcn_ulysses_parallelism=1 \ + height=720 width=1280 num_frames=81 num_inference_steps=40 \ + per_device_batch_size=$pdb \ + output_dir=$OUTPUT_ROOT \ + scan_layers=True \ + write_metrics=False \ + write_timing_metrics=False \ + $profiler_args" + + local remote_cmd + remote_cmd="$(printf '%q' "set -euo pipefail +export LIBTPU_INIT_ARGS='$LIBTPU_INIT_ARGS' +source $VENV/bin/activate +export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH:-} +export HF_HOME=$HF_HOME +export HF_HUB_CACHE=$HF_HUB_CACHE +export HF_HUB_ENABLE_HF_TRANSFER=1 +export HF_HUB_OFFLINE=1 +export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR +export XLA_CACHE_DIR=$XLA_CACHE_DIR +export TMPDIR=$TMPDIR +rm -f /tmp/libtpu_lockfile +mkdir -p $TMPDIR $results_dir $OUTPUT_ROOT +cd $results_dir +python -u 
$REPO_DIR/src/maxdiffusion/generate_wan.py $CONFIG $common_args 2>&1 | tee $results_dir/worker1.log")" + + /usr/bin/ssh $SSH_OPTS "$WORKER1_USER@$WORKER1_IP" "bash -lc $remote_cmd" & + local remote_pid=$! + + cd "$results_dir" + set +e + python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \ + "$CONFIG" \ + $common_args \ + 2>&1 | tee "$results_dir/worker0.log" + local local_status=$? + + wait "$remote_pid" + local remote_status=$? + set -e + + echo "[$(date -u +%T)] $run_name done — local=$local_status remote=$remote_status" + echo "────────────────────────────────────────────────────────────────────────" +} + +# ── run matrix ──────────────────────────────────────────────────────────────── +# Columns: run_name attention dp a b per_device_bs +# +# All runs use 16 total devices (2 hosts × 8): dp=2 with cp=8 (or ring×ulysses=8), +# except the final case which uses dp=1 with ring×ulysses=16 +# Batch sizes: 0.0625=bs1, 0.125=bs2, 0.25=bs4, 0.5=bs8 + +for pdb in 0.0625 0.125 0.25 0.5; do + bs=$(python3 -c "print(int($pdb * 16))") + run_case "flash_dp2_cp8_bs${bs}" flash 2 8 1 $pdb + run_case "tokamax_ring_dp2_cp8_bs${bs}" tokamax_ring 2 8 1 $pdb + run_case "ulysses_dp2_cp8_bs${bs}" ulysses 2 8 1 $pdb + run_case "ulysses_ring_dp2_r2u4_bs${bs}" ulysses_ring 2 2 4 $pdb # ring=2, ulysses=4, dp=2 + run_case "ulysses_ring_dp2_r4u2_bs${bs}" ulysses_ring 2 4 2 $pdb # ring=4, ulysses=2, dp=2 + run_case "ulysses_ring_dp1_r4u4_bs${bs}" ulysses_ring 1 4 4 $pdb # ring=4, ulysses=4, dp=1 +done + +echo "[$(date -u +%T)] All benchmark runs complete." 
diff --git a/bench_attn_v78.sh b/bench_attn_v78.sh new file mode 100755 index 000000000..47bc1f77b --- /dev/null +++ b/bench_attn_v78.sh @@ -0,0 +1,143 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ── config (TPU v7-8, single host, 8 chips) ────────────────────────────────── +REPO_DIR=/mnt/data/sagarchapara/workspace/maxdiffusion +VENV=/mnt/data/sagarchapara/workspace/venv +CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml +RESULTS_ROOT=$REPO_DIR/bench_results +OUTPUT_ROOT=$REPO_DIR/bench_outputs + +PRETRAINED_ORBAX_DIR=/mnt/data/sagarchapara/workspace/wan22_orbax_cache +mkdir -p "$PRETRAINED_ORBAX_DIR" + +export LIBTPU_INIT_ARGS=\ +'--xla_tpu_dvfs_p_state=7 '\ +'--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '\ +'--xla_tpu_megacore_fusion_allow_ags=false '\ +'--xla_enable_async_collective_permute=true '\ +'--xla_tpu_enable_ag_backward_pipelining=true '\ +'--xla_tpu_enable_data_parallel_all_reduce_opt=true '\ +'--xla_tpu_data_parallel_opt_different_sized_ops=true '\ +'--xla_tpu_enable_async_collective_fusion=true '\ +'--xla_tpu_enable_async_collective_fusion_multiple_steps=true '\ +'--xla_tpu_overlap_compute_collective_tc=true '\ +'--xla_enable_async_all_gather=true '\ +'--xla_tpu_scoped_vmem_limit_kib=65536 '\ +'--xla_tpu_enable_async_all_to_all=true '\ +'--xla_tpu_enable_latency_hiding_scheduler=true '\ +'--xla_tpu_enable_all_experimental_scheduler_features=true '\ +'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\ +'--xla_tpu_host_transfer_overlap_limit=24 '\ +'--xla_tpu_aggressive_opt_barrier_removal=ENABLED '\ +'--xla_lhs_prioritize_async_depth_over_stall=ENABLED '\ +'--xla_should_allow_loop_variant_parameter_in_chain=ENABLED '\ +'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\ +'--xla_max_concurrent_host_send_recv=100 '\ +'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\ +'--xla_latency_hiding_scheduler_rerun=5 '\ +'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\ 
+'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\ +'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\ +'--xla_tpu_assign_all_reduce_scatter_layout=true '\ +'--xla_max_concurrent_async_collective_permutes=16 '\ +'--xla_tpu_enable_ici_ag_pipelining=true' + +source "$VENV/bin/activate" +export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-} +export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface +export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub +export HF_HUB_ENABLE_HF_TRANSFER=1 +export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax +export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla +export TMPDIR=/dev/shm/maxdiffusion_cache/tmp +mkdir -p "$TMPDIR" "$HF_HOME" "$HF_HUB_CACHE" "$JAX_COMPILATION_CACHE_DIR" "$XLA_CACHE_DIR" + +# ── helper (single host - no SSH) ──────────────────────────────────────────── +# run_case +# pure (flash/tokamax_ring/ulysses): ici_a=cp, ici_b unused +# hybrid (ulysses_ring): ici_a=ring, ici_b=ulysses +run_case() { + local run_name="$1" + local attention="$2" + local ici_dp="$3" + local ici_a="$4" + local ici_b="$5" + local pdb="$6" + + local results_dir="$RESULTS_ROOT/$run_name" + rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name" + mkdir -p "$results_dir" + rm -f /tmp/libtpu_lockfile + mkdir -p "$TMPDIR" + + local ici_cp ici_tp ici_ring ici_ulysses + if [[ "$attention" == "ulysses_ring" || "$attention" == "ulysses_tokamax_ring" ]]; then + ici_cp=1; ici_tp=1; ici_ring="$ici_a"; ici_ulysses="$ici_b" + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses pdb=$pdb) ──" + else + ici_cp="$ici_a"; ici_tp=1; ici_ring=1; ici_ulysses=1 + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──" + fi + + # Profiler only for bs=2 (pdb=0.25 with 8 devices) + local profiler_args="enable_profiler=False" + if [[ "$pdb" == "0.25" ]]; then + profiler_args="enable_profiler=True skip_first_n_steps_for_profiler=5 
profiler_steps=10" + fi + + local common_args="run_name=$run_name \ + attention=$attention \ + ici_data_parallelism=$ici_dp \ + ici_fsdp_parallelism=1 \ + ici_context_parallelism=$ici_cp \ + ici_tensor_parallelism=$ici_tp \ + ici_ring_parallelism=$ici_ring \ + ici_ulysses_parallelism=$ici_ulysses \ + dcn_data_parallelism=1 \ + dcn_fsdp_parallelism=1 \ + dcn_context_parallelism=1 \ + dcn_tensor_parallelism=1 \ + dcn_ring_parallelism=1 \ + dcn_ulysses_parallelism=1 \ + pretrained_orbax_dir=$PRETRAINED_ORBAX_DIR \ + height=720 width=1280 num_frames=81 num_inference_steps=40 \ + per_device_batch_size=$pdb \ + output_dir=$OUTPUT_ROOT \ + scan_layers=True \ + write_metrics=False \ + write_timing_metrics=False \ + $profiler_args" + + cd "$results_dir" + set +e + python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \ + "$CONFIG" \ + $common_args \ + 2>&1 | tee "$results_dir/worker0.log" + local status=$? + set -e + + echo "[$(date -u +%T)] $run_name done — status=$status" + echo "────────────────────────────────────────────────────────────────────────" +} + +# ── run matrix (TPU v7-8: 1 host × 8 chips = 8 devices) ────────────────────── +# Parallelism rule: dp × fsdp × cp × tp × ring × ulysses = 8 +# +# Pure modes: dp=2, cp=4 (2×4=8) +# 2D (2×2): dp=2, ring=2, ulysses=2 (2×2×2=8) +# +# Batch sizes: pdb × 8 devices = total_bs +# 0.125 → bs1, 0.25 → bs2, 0.5 → bs4 + +for pdb in 0.125 0.25 0.5; do + bs=$(python3 -c "print(int($pdb * 8))") + + run_case "flash_dp2_cp4_bs${bs}" flash 2 4 1 $pdb + run_case "tokamax_ring_dp2_cp4_bs${bs}" tokamax_ring 2 4 1 $pdb + run_case "ulysses_dp2_cp4_bs${bs}" ulysses 2 4 1 $pdb + run_case "ulysses_ring_dp2_r2u2_bs${bs}" ulysses_ring 2 2 2 $pdb +done + +echo "[$(date -u +%T)] All benchmark runs complete." 
diff --git a/bench_remaining.sh b/bench_remaining.sh new file mode 100755 index 000000000..cb206b70f --- /dev/null +++ b/bench_remaining.sh @@ -0,0 +1,159 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ── config ──────────────────────────────────────────────────────────────────── +REPO_DIR=/dev/shm/maxdiffusion +VENV=$REPO_DIR/.venv-tpu +CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml +RESULTS_ROOT=$REPO_DIR/bench_results +OUTPUT_ROOT=$REPO_DIR/bench_outputs + +WORKER1_USER=sa_112155357684894056033 +WORKER1_IP=10.154.0.59 +WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB +SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine +SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts +SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no \ + -o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes \ + -o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS" + +export LIBTPU_INIT_ARGS=\ +'--xla_tpu_dvfs_p_state=7 '\ +'--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '\ +'--xla_tpu_megacore_fusion_allow_ags=false '\ +'--xla_enable_async_collective_permute=true '\ +'--xla_tpu_enable_ag_backward_pipelining=true '\ +'--xla_tpu_enable_data_parallel_all_reduce_opt=true '\ +'--xla_tpu_data_parallel_opt_different_sized_ops=true '\ +'--xla_tpu_enable_async_collective_fusion=true '\ +'--xla_tpu_enable_async_collective_fusion_multiple_steps=true '\ +'--xla_tpu_overlap_compute_collective_tc=true '\ +'--xla_enable_async_all_gather=true '\ +'--xla_tpu_scoped_vmem_limit_kib=65536 '\ +'--xla_tpu_enable_async_all_to_all=true '\ +'--xla_tpu_enable_latency_hiding_scheduler=true '\ +'--xla_tpu_enable_all_experimental_scheduler_features=true '\ +'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\ +'--xla_tpu_host_transfer_overlap_limit=24 '\ +'--xla_tpu_aggressive_opt_barrier_removal=ENABLED '\ +'--xla_lhs_prioritize_async_depth_over_stall=ENABLED '\ 
+'--xla_should_allow_loop_variant_parameter_in_chain=ENABLED '\ +'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\ +'--xla_max_concurrent_host_send_recv=100 '\ +'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\ +'--xla_latency_hiding_scheduler_rerun=5 '\ +'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\ +'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\ +'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\ +'--xla_tpu_assign_all_reduce_scatter_layout=true '\ +'--xla_max_concurrent_async_collective_permutes=16' + +source "$VENV/bin/activate" +export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-} +export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface +export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub +export HF_HUB_ENABLE_HF_TRANSFER=1 +export HF_HUB_OFFLINE=1 +export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax +export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla +export TMPDIR=/dev/shm/maxdiffusion_cache/tmp + +# ── helper ──────────────────────────────────────────────────────────────────── +# For hybrid `ulysses_ring`, the 4th and 5th args are the ring and ulysses sizes +# (which use dedicated mesh axes); context and tensor are pinned to 1. +# For pure `ulysses` and `ring`, the 4th arg is context size and 5th is unused. 
+run_case() { + local run_name="$1" + local attention="$2" + local ici_dp="$3" + local ici_a="$4" # context (pure) | ring (hybrid) + local ici_b="$5" # unused (pure) | ulysses (hybrid) + + local results_dir="$RESULTS_ROOT/$run_name" + rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name" + mkdir -p "$results_dir" + rm -f /tmp/libtpu_lockfile + mkdir -p "$TMPDIR" + + local ici_cp ici_tp ici_ring ici_ulysses + if [[ "$attention" == "ulysses_ring" || "$attention" == "ulysses_tokamax_ring" ]]; then + ici_cp=1; ici_tp=1; ici_ring="$ici_a"; ici_ulysses="$ici_b" + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses) ──" + else + ici_cp="$ici_a"; ici_tp=1; ici_ring=1; ici_ulysses=1 + echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp) ──" + fi + + local common_args="run_name=$run_name \ + attention=$attention \ + ici_data_parallelism=$ici_dp \ + ici_fsdp_parallelism=1 \ + ici_context_parallelism=$ici_cp \ + ici_tensor_parallelism=$ici_tp \ + ici_ring_parallelism=$ici_ring \ + ici_ulysses_parallelism=$ici_ulysses \ + dcn_data_parallelism=1 \ + dcn_fsdp_parallelism=1 \ + dcn_context_parallelism=1 \ + dcn_tensor_parallelism=1 \ + dcn_ring_parallelism=1 \ + dcn_ulysses_parallelism=1 \ + height=720 width=1280 num_frames=81 num_inference_steps=40 \ + per_device_batch_size=0.125 \ + output_dir=$OUTPUT_ROOT \ + scan_layers=True \ + write_metrics=False \ + write_timing_metrics=False \ + enable_profiler=True \ + skip_first_n_steps_for_profiler=5 \ + profiler_steps=10" + + # launch worker 1 in background via SSH + local remote_cmd + remote_cmd="$(printf '%q' "set -euo pipefail +export LIBTPU_INIT_ARGS='$LIBTPU_INIT_ARGS' +source $VENV/bin/activate +export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH:-} +export HF_HOME=$HF_HOME +export HF_HUB_CACHE=$HF_HUB_CACHE +export HF_HUB_ENABLE_HF_TRANSFER=1 +export HF_HUB_OFFLINE=1 +export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR +export 
XLA_CACHE_DIR=$XLA_CACHE_DIR +export TMPDIR=$TMPDIR +rm -f /tmp/libtpu_lockfile +mkdir -p $TMPDIR $results_dir $OUTPUT_ROOT +cd $results_dir +python -u $REPO_DIR/src/maxdiffusion/generate_wan.py $CONFIG $common_args 2>&1 | tee $results_dir/worker1.log")" + + /usr/bin/ssh $SSH_OPTS "$WORKER1_USER@$WORKER1_IP" "bash -lc $remote_cmd" & + local remote_pid=$! + + cd "$results_dir" + # Disable -e for the python call so a single run failing doesn't kill the whole sequence. + set +e + python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \ + "$CONFIG" \ + $common_args \ + 2>&1 | tee "$results_dir/worker0.log" + local local_status=$? + + wait "$remote_pid" + local remote_status=$? + set -e + + echo "[$(date -u +%T)] $run_name done — local=$local_status remote=$remote_status" + echo "────────────────────────────────────────────────────────────────────────" +} + +# ── run matrix ──────────────────────────────────────────────────────────────── +# Pure runs: args = dp cp _ (cp on context axis) +# Hybrid runs: args = dp ring ulysses (dedicated ring + ulysses axes) +# run_name attention dp a b +run_case ulysses_dp2_cp8 ulysses 2 8 1 +run_case ring_dp2_cp8 ring 2 8 1 +run_case ulysses_ring_dp2_cp8_4x2 ulysses_ring 2 4 2 +run_case ulysses_ring_dp2_cp8_2x4 ulysses_ring 2 2 4 +run_case ulysses_ring_dp1_cp16_4x4 ulysses_ring 1 4 4 + +echo "[$(date -u +%T)] All benchmark runs complete." diff --git a/docs/tpu_multihost_wan_bench.md b/docs/tpu_multihost_wan_bench.md new file mode 100644 index 000000000..9a7436be9 --- /dev/null +++ b/docs/tpu_multihost_wan_bench.md @@ -0,0 +1,215 @@ +# TPU Multihost WAN Benchmarks + +This note shows how to connect to a TPU v7x-16 multihost VM and run WAN attention comparisons on both workers. 
+ +The examples below assume: + +- TPU name: `rish-tpu-7x16` +- Zone: `europe-west2-a` +- Project: `tpu-prod-env-one-vm` +- Repo path on both workers: `/dev/shm/maxdiffusion` +- Venv path on both workers: `/dev/shm/maxdiffusion/.venv-tpu` +- HF cache path on both workers: `/dev/shm/maxdiffusion_cache/huggingface` + +For the current `rish-tpu-7x16` allocation, the worker endpoints resolved to: + +- worker 0: `sa_112155357684894056033@10.154.0.62` +- worker 1: `sa_112155357684894056033@10.154.0.59` +- worker 0 host key alias: `tpu.2884514015978940116-0-IE5JIi` +- worker 1 host key alias: `tpu.2884514015978940116-1-jgVFrB` + +## Connect + +Set the project and zone: + +```bash +gcloud config set project tpu-prod-env-one-vm +gcloud config set compute/zone europe-west2-a +``` + +Check that both workers are reachable: + +```bash +gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \ + --worker=all \ + --internal-ip \ + --zone=europe-west2-a \ + --quiet \ + --command='hostname' +``` + +Print the raw ssh command for a single worker when you need it: + +```bash +gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \ + --worker=0 \ + --internal-ip \ + --zone=europe-west2-a \ + --dry-run +``` + +## Stage Code On Worker 1 + +If worker 0 already has the repo, venv, and WAN checkpoint cache, mirror them to worker 1 from worker 0: + +```bash +rsync -a --delete \ + --exclude='.git' \ + --exclude='.venv' \ + --exclude='.venv-tpu' \ + --exclude='__pycache__' \ + --exclude='*.pyc' \ + -e 'ssh -T -i ~/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=tpu.2884514015978940116-1-jgVFrB -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=~/.ssh/google_compute_known_hosts' \ + /dev/shm/maxdiffusion/ \ + sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion/ +``` + +```bash +rsync -a --delete \ + -e 'ssh -T -i ~/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=tpu.2884514015978940116-1-jgVFrB -o 
IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=~/.ssh/google_compute_known_hosts' \ + /dev/shm/maxdiffusion/.venv-tpu/ \ + sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion/.venv-tpu/ +``` + +```bash +rsync -a --partial --delete \ + -e 'ssh -T -i ~/.ssh/google_compute_engine -o CheckHostIP=no -o HashKnownHosts=no -o HostKeyAlias=tpu.2884514015978940116-1-jgVFrB -o IdentitiesOnly=yes -o StrictHostKeyChecking=no -o UserKnownHostsFile=~/.ssh/google_compute_known_hosts' \ + /dev/shm/maxdiffusion_cache/huggingface/hub/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/ \ + sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion_cache/huggingface/hub/models--Wan-AI--Wan2.2-T2V-A14B-Diffusers/ +``` + +## Smoke Test + +Run a multihost JAX initialization smoke test on both workers: + +```bash +gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \ + --worker=all \ + --internal-ip \ + --zone=europe-west2-a \ + --quiet \ + --command=' + set -e + source /dev/shm/maxdiffusion/.venv-tpu/bin/activate + export PYTHONPATH=/dev/shm/maxdiffusion/src:${PYTHONPATH} + python - <<'"'"'PY'"'"' +import socket +import jax + +jax.distributed.initialize() +print( + f"host={socket.gethostname()} " + f"process_index={jax.process_index()} " + f"process_count={jax.process_count()} " + f"local_device_count={jax.local_device_count()} " + f"device_count={jax.device_count()}" +) +PY' +``` + +## Run WAN Comparison Jobs + +All commands below use: + +- model config: `src/maxdiffusion/configs/base_wan_27b.yml` +- checkpoint: `Wan-AI/Wan2.2-T2V-A14B-Diffusers` +- global batch size: `2` +- per-device batch size: `0.125` +- total devices: `16` + +Common environment: + +```bash +export REPO_DIR=/dev/shm/maxdiffusion +export VENV=/dev/shm/maxdiffusion/.venv-tpu +export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface +export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub +export HF_HUB_ENABLE_HF_TRANSFER=1 +export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax 
+export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla +export TMPDIR=/dev/shm/maxdiffusion_cache/tmp +export RESULTS_ROOT=/dev/shm/maxdiffusion/bench_results +export CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml +export COMMON_ARGS='height=720 width=1280 num_frames=81 num_inference_steps=40 per_device_batch_size=0.125 output_dir=/dev/shm/maxdiffusion/bench_outputs scan_layers=True enable_profiler=False' +``` + +Helper to run one job on both workers: + +```bash +run_case() { + local run_name="$1" + local attention="$2" + local ici_dp="$3" + local ici_cp="$4" + local ici_tp="$5" + + gcloud alpha compute tpus tpu-vm ssh rish-tpu-7x16 \ + --worker=all \ + --internal-ip \ + --zone=europe-west2-a \ + --quiet \ + --command=" + set -e + source $VENV/bin/activate + export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH} + export HF_HOME=$HF_HOME + export HF_HUB_CACHE=$HF_HUB_CACHE + export HF_HUB_ENABLE_HF_TRANSFER=$HF_HUB_ENABLE_HF_TRANSFER + export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR + export XLA_CACHE_DIR=$XLA_CACHE_DIR + export TMPDIR=$TMPDIR + mkdir -p $RESULTS_ROOT /dev/shm/maxdiffusion/bench_outputs + cd $RESULTS_ROOT + python $REPO_DIR/src/maxdiffusion/generate_wan.py \ + $CONFIG \ + run_name=$run_name \ + attention=$attention \ + ici_data_parallelism=$ici_dp \ + ici_fsdp_parallelism=1 \ + ici_context_parallelism=$ici_cp \ + ici_tensor_parallelism=$ici_tp \ + dcn_data_parallelism=1 \ + dcn_fsdp_parallelism=1 \ + dcn_context_parallelism=1 \ + dcn_tensor_parallelism=1 \ + $COMMON_ARGS \ + 2>&1 | tee $RESULTS_ROOT/$run_name.\$(hostname).log + " +} +``` + +Run matrix: + +```bash +run_case ulysses_dp2_cp8 ulysses 2 8 1 +run_case ring_dp2_cp8 ring 2 8 1 +run_case ulysses_ring_dp2_cp8_2x4 ulysses_ring 2 2 4 +run_case ulysses_ring_dp2_cp8_4x2 ulysses_ring 2 4 2 +run_case ulysses_ring_dp1_cp16_4x4 ulysses_ring 1 4 4 +``` + +## Topology Note + +TPU v7x exposes dual chiplets as two JAX devices. 
For `ulysses_ring`, keep the dual-chip pairing inside the Ulysses group by setting `ici_tensor_parallelism >= 2`. + +That means: + +- `2x4` uses tensor `4`, so the dual-chip pairing is inside the Ulysses side. +- `4x2` uses tensor `2`, so the dual-chip pairing is still inside the Ulysses side. +- `4x4` uses tensor `4`, so the dual-chip pairing is still inside the Ulysses side. + +The plain `ring` baseline has no Ulysses group, so it cannot preserve that property by construction. + +## Read Results + +Pull the timing summary from the per-host logs: + +```bash +rg -n "compile_time:|generation_time:|generation time per video:|TIMING SUMMARY" /dev/shm/maxdiffusion/bench_results/*.log +``` + +If you want a single run's logs from both workers: + +```bash +ls -1 /dev/shm/maxdiffusion/bench_results/ulysses_ring_dp2_cp8_2x4.* +``` diff --git a/docs/tpu_wan_bench_guide.md b/docs/tpu_wan_bench_guide.md new file mode 100644 index 000000000..99124f891 --- /dev/null +++ b/docs/tpu_wan_bench_guide.md @@ -0,0 +1,179 @@ +# WAN 2.2 Attention Benchmarking Guide (TPU v7x-16) + +## Setup + +Two-host TPU v7x slice: 2 × 8 chips = 16 devices total. 
+ +- **Host 0** (local): coordinator, runs `generate_wan.py` + SSH to host 1 +- **Host 1** (worker): `sa_112155357684894056033@10.154.0.59` + +## Quick Start + +### Run the full attention benchmark + +```bash +nohup bash /dev/shm/maxdiffusion/bench_attn.sh \ + > /dev/shm/maxdiffusion/bench_results/bench_attn.log 2>&1 & + +# Monitor +tail -f /dev/shm/maxdiffusion/bench_results/bench_attn.log + +# Check results as they come in +grep -E "compile_time|Inference:|done —" /dev/shm/maxdiffusion/bench_results/bench_attn.log +``` + +### Run a single attention mode + +Use the one-shot helper scripts in `/tmp/`: + +```bash +# Ulysses dp=2 cp=8 +nohup bash /tmp/run_ulysses_v3.sh > /dev/shm/maxdiffusion/bench_results/myrun.log 2>&1 & + +# Monitor +tail -f /dev/shm/maxdiffusion/bench_results/myrun.log +``` + +Or call `generate_wan.py` directly on both hosts (see [Multihost section](#multihost-runs)). + +--- + +## Benchmark Matrix (`bench_attn.sh`) + +6 attention strategies × 4 batch sizes = 24 runs. +Profiler captured only for bs=2 to avoid disk overhead. 
+ +| Attention | dp | cp/ring/ulysses | Notes | +|-----------|-----|-----------------|-------| +| `flash` | 2 | cp=8 | local flash per shard, no KV rotation | +| `tokamax_ring` | 2 | cp=8 | Tokamax ring kernel, KV rotated across context axis | +| `ulysses` | 2 | cp=8 | Ulysses all-to-all in BSHD layout | +| `ulysses_ring` | 2 | ring=2, ulysses=4 | Hybrid 2D: ulysses intra-chip, ring cross-chip | +| `ulysses_ring` | 2 | ring=4, ulysses=2 | Hybrid 2D: alternative split | +| `ulysses_ring` | 1 | ring=4, ulysses=4 | Hybrid 2D: full 16-chip seq sharding | + +Batch sizes: `per_device_batch_size` ∈ {0.0625, 0.125, 0.25, 0.5} +→ total videos = per_device_batch_size × 16 devices ∈ {1, 2, 4, 8} + +Results land in: +``` +/dev/shm/maxdiffusion/bench_results/<run_name>/worker0.log +/dev/shm/maxdiffusion/bench_results/<run_name>/worker1.log +``` + +--- + +## View Profiler Traces (xprof) + +```bash +source /dev/shm/maxdiffusion/.venv-tpu/bin/activate + +python -c " +from xprof.server import main +import sys +sys.argv=['xprof', + '--logdir', '/dev/shm/maxdiffusion/bench_outputs/<run_name>/tensorboard', + '--port', '6006'] +main() +" +``` + +Then open **http://localhost:6006/** + +Pick the **latest timestamp** under `plugins/profile/` — it's from the warm (second) inference round with XLA cache hot. + +To switch runs, kill the server (`pkill -f "xprof"`) and relaunch with the new logdir. + +--- + +## Multihost Runs + +Both hosts must launch `generate_wan.py` simultaneously — JAX distributed requires all workers to connect within the timeout window (~5 min). + +The benchmark scripts handle this automatically: the local host runs the job directly while SSH-ing the same command to worker 1 in the background, then `wait`s for both. 
+ +### Sync code changes to worker 1 + +```bash +SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine +SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts +WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB +SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no \ + -o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes \ + -o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS" + +rsync -a --exclude='__pycache__' --exclude='*.pyc' \ + -e "/usr/bin/ssh $SSH_OPTS" \ + /dev/shm/maxdiffusion/src/ sa_112155357684894056033@10.154.0.59:/dev/shm/maxdiffusion/src/ +``` + +**Always rsync before launching a run after code changes.** + +### Clear TPU locks after a crash + +If a run crashes mid-flight, the TPU vfio devices may stay locked: + +```bash +# Local +ps -ef | grep generate_wan | grep -v grep | awk '{print $2}' | xargs -r kill -9 +rm -f /tmp/libtpu_lockfile + +# Remote +/usr/bin/ssh $SSH_OPTS sa_112155357684894056033@10.154.0.59 \ + 'ps -ef | grep generate_wan | grep -v grep | awk '"'"'{print $2}'"'"' | xargs -r kill -9; rm -f /tmp/libtpu_lockfile' +``` + +Then wait ~5 seconds before relaunching. 
+ +--- + +## Key Config Parameters + +Set in `src/maxdiffusion/configs/base_wan_27b.yml` or overridden on the command line: + +| Parameter | Description | +|-----------|-------------| +| `attention` | `flash`, `tokamax_ring`, `ulysses`, `ulysses_ring` | +| `ici_data_parallelism` | Data parallel replicas within a host | +| `ici_context_parallelism` | Sequence shards (flash/ring/ulysses) | +| `ici_ring_parallelism` | Ring axis size (ulysses_ring only) | +| `ici_ulysses_parallelism` | Ulysses axis size (ulysses_ring only) | +| `per_device_batch_size` | Videos per device; total = × 16 | +| `num_inference_steps` | Denoising steps (40 for full quality) | +| `enable_profiler` | Capture xprof trace | +| `skip_first_n_steps_for_profiler` | Warmup steps before profiling (5) | +| `profiler_steps` | Steps to profile (10) | + +**Parallelism rule**: product of all ICI axes must equal 8 (chips per host): +- `ici_dp × ici_fsdp × ici_cp × ici_tp × ici_ring × ici_ulysses = 8` + +For `ulysses_ring`, set `ici_context_parallelism=1` and use `ici_ring` + `ici_ulysses` instead. + +--- + +## XLA Flags + +All performance-critical flags are set in `LIBTPU_INIT_ARGS` in the benchmark scripts. 
Notable ones: + +| Flag | Value | Purpose | +|------|-------|---------| +| `xla_tpu_enable_latency_hiding_scheduler` | true | Overlap compute and collectives | +| `xla_latency_hiding_scheduler_rerun` | 5 | LHS scheduling passes | +| `xla_enable_async_collective_permute` | true | Async KV rotation for ring attention | +| `xla_max_concurrent_async_collective_permutes` | 16 | Max in-flight ring permutes | +| `xla_tpu_enable_async_all_to_all` | true | Async Ulysses all-to-all | +| `xla_tpu_enable_ici_ag_pipelining` | true | Pipeline ICI all-gathers | +| `xla_tpu_scoped_vmem_limit_kib` | 65536 | VMEM budget per op | + +--- + +## Result Parsing + +```bash +# Summary table across all runs +grep -E "compile_time|Inference:|generation_time per video" \ + /dev/shm/maxdiffusion/bench_results/*/worker0.log + +# Just inference time +grep "Inference:" /dev/shm/maxdiffusion/bench_results/*/worker0.log | sort -t: -k3 -n +``` diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py index 6b1e0754e..e7d1bd863 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_2_2.py @@ -20,7 +20,13 @@ from ..pipelines.wan.wan_pipeline_2_2 import WanPipeline2_2 from .. 
import max_logging import orbax.checkpoint as ocp -from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding +from flax import nnx +from maxdiffusion.checkpointing.checkpointing_utils import ( + add_sharding_to_struct, + get_cpu_mesh_and_sharding, + create_orbax_checkpoint_manager, + WAN_CHECKPOINT, +) from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer @@ -83,20 +89,100 @@ def load_diffusers_checkpoint(self): pipeline = WanPipeline2_2.from_pretrained(self.config) return pipeline + def _get_pretrained_orbax_dir(self) -> str: + return getattr(self.config, "pretrained_orbax_dir", "") + + def save_pretrained_checkpoint(self, pretrained_dir: str, pipeline: WanPipeline2_2): + """Save pretrained weights (no optimizer state) to orbax for fast subsequent loads.""" + max_logging.log(f"Saving pretrained WAN 2.2 weights to orbax at {pretrained_dir}") + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + _, low_state, _ = nnx.split(pipeline.low_noise_transformer, nnx.Param, ...) + _, high_state, _ = nnx.split(pipeline.high_noise_transformer, nnx.Param, ...) 
+ low_params = low_state.to_pure_dict() + high_params = high_state.to_pure_dict() + wan_config = json.loads(pipeline.low_noise_transformer.to_json_string()) + pretrained_mgr.save( + 0, + args=ocp.args.Composite( + wan_config=ocp.args.JsonSave(wan_config), + low_noise_transformer_state=ocp.args.StandardSave(low_params), + high_noise_transformer_state=ocp.args.StandardSave(high_params), + ), + ) + pretrained_mgr.wait_until_finished() + max_logging.log(f"Pretrained weights saved to {pretrained_dir}") + + def load_pretrained_from_orbax(self, pretrained_dir: str) -> Tuple[Optional[object], Optional[int]]: + """Load pretrained weights from orbax cache if available.""" + try: + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + step = pretrained_mgr.latest_step() + if step is None: + max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}") + return None, None + max_logging.log(f"Found pretrained orbax checkpoint (step {step}) in {pretrained_dir}") + mesh, replicated_sharding = get_cpu_mesh_and_sharding() + metadatas = pretrained_mgr.item_metadata(step) + low_meta = metadatas.low_noise_transformer_state + high_meta = metadatas.high_noise_transformer_state + target_shardings_low = jax.tree_util.tree_map(lambda x: replicated_sharding, low_meta) + target_shardings_high = jax.tree_util.tree_map(lambda x: replicated_sharding, high_meta) + with mesh: + abstract_low = jax.tree_util.tree_map(add_sharding_to_struct, low_meta, target_shardings_low) + abstract_high = jax.tree_util.tree_map(add_sharding_to_struct, high_meta, target_shardings_high) + max_logging.log("Restoring pretrained WAN 2.2 weights from orbax") + restored = pretrained_mgr.restore( + step, + args=ocp.args.Composite( + wan_config=ocp.args.JsonRestore(), + low_noise_transformer_state=ocp.args.StandardRestore(abstract_low), + 
high_noise_transformer_state=ocp.args.StandardRestore(abstract_high), + ), + ) + return restored, step + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {e}") + return None, None + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: + pretrained_dir = self._get_pretrained_orbax_dir() + + # 1. Fast path: load from pretrained orbax cache (skips diffusers entirely). + if pretrained_dir: + restored, loaded_step = self.load_pretrained_from_orbax(pretrained_dir) + if restored is not None: + max_logging.log("Loading WAN 2.2 pipeline from pretrained orbax checkpoint") + pipeline = WanPipeline2_2.from_checkpoint(self.config, restored) + return pipeline, None, loaded_step + + # 2. Try training checkpoint from checkpoint_dir. restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") + max_logging.log("Loading WAN pipeline from training checkpoint") pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] else: - max_logging.log("No checkpoint found, loading default pipeline.") + # 3. Slow path: load from diffusers, then cache to orbax for next time. 
+ max_logging.log("No checkpoint found, loading pipeline from diffusers.") pipeline = self.load_diffusers_checkpoint() + if pretrained_dir: + self.save_pretrained_checkpoint(pretrained_dir, pipeline) return pipeline, opt_state, step diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py index ce3cc7bb1..1845f5e8b 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer_i2v_2p2.py @@ -20,7 +20,13 @@ from ..pipelines.wan.wan_pipeline_i2v_2p2 import WanPipelineI2V_2_2 from .. import max_logging import orbax.checkpoint as ocp -from maxdiffusion.checkpointing.checkpointing_utils import add_sharding_to_struct, get_cpu_mesh_and_sharding +from flax import nnx +from maxdiffusion.checkpointing.checkpointing_utils import ( + add_sharding_to_struct, + get_cpu_mesh_and_sharding, + create_orbax_checkpoint_manager, + WAN_CHECKPOINT, +) from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer @@ -83,20 +89,100 @@ def load_diffusers_checkpoint(self): pipeline = WanPipelineI2V_2_2.from_pretrained(self.config) return pipeline + def _get_pretrained_orbax_dir(self) -> str: + return getattr(self.config, "pretrained_orbax_dir", "") + + def save_pretrained_checkpoint(self, pretrained_dir: str, pipeline: WanPipelineI2V_2_2): + """Save pretrained weights (no optimizer state) to orbax for fast subsequent loads.""" + max_logging.log(f"Saving pretrained WAN 2.2 I2V weights to orbax at {pretrained_dir}") + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + _, low_state, _ = nnx.split(pipeline.low_noise_transformer, nnx.Param, ...) + _, high_state, _ = nnx.split(pipeline.high_noise_transformer, nnx.Param, ...) 
+ low_params = low_state.to_pure_dict() + high_params = high_state.to_pure_dict() + wan_config = json.loads(pipeline.low_noise_transformer.to_json_string()) + pretrained_mgr.save( + 0, + args=ocp.args.Composite( + wan_config=ocp.args.JsonSave(wan_config), + low_noise_transformer_state=ocp.args.StandardSave(low_params), + high_noise_transformer_state=ocp.args.StandardSave(high_params), + ), + ) + pretrained_mgr.wait_until_finished() + max_logging.log(f"Pretrained weights saved to {pretrained_dir}") + + def load_pretrained_from_orbax(self, pretrained_dir: str) -> Tuple[Optional[object], Optional[int]]: + """Load pretrained weights from orbax cache if available.""" + try: + pretrained_mgr = create_orbax_checkpoint_manager( + pretrained_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=WAN_CHECKPOINT, + use_async=False, + ) + step = pretrained_mgr.latest_step() + if step is None: + max_logging.log(f"No pretrained orbax checkpoint found in {pretrained_dir}") + return None, None + max_logging.log(f"Found pretrained orbax checkpoint (step {step}) in {pretrained_dir}") + mesh, replicated_sharding = get_cpu_mesh_and_sharding() + metadatas = pretrained_mgr.item_metadata(step) + low_meta = metadatas.low_noise_transformer_state + high_meta = metadatas.high_noise_transformer_state + target_shardings_low = jax.tree_util.tree_map(lambda x: replicated_sharding, low_meta) + target_shardings_high = jax.tree_util.tree_map(lambda x: replicated_sharding, high_meta) + with mesh: + abstract_low = jax.tree_util.tree_map(add_sharding_to_struct, low_meta, target_shardings_low) + abstract_high = jax.tree_util.tree_map(add_sharding_to_struct, high_meta, target_shardings_high) + max_logging.log("Restoring pretrained WAN 2.2 I2V weights from orbax") + restored = pretrained_mgr.restore( + step, + args=ocp.args.Composite( + wan_config=ocp.args.JsonRestore(), + low_noise_transformer_state=ocp.args.StandardRestore(abstract_low), + 
high_noise_transformer_state=ocp.args.StandardRestore(abstract_high), + ), + ) + return restored, step + except Exception as e: # pylint: disable=broad-except + max_logging.log(f"Failed to load pretrained orbax checkpoint from {pretrained_dir}: {e}") + return None, None + def load_checkpoint(self, step=None) -> Tuple[WanPipelineI2V_2_2, Optional[dict], Optional[int]]: + pretrained_dir = self._get_pretrained_orbax_dir() + + # 1. Fast path: load from pretrained orbax cache (skips diffusers entirely). + if pretrained_dir: + restored, loaded_step = self.load_pretrained_from_orbax(pretrained_dir) + if restored is not None: + max_logging.log("Loading WAN 2.2 I2V pipeline from pretrained orbax checkpoint") + pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored) + return pipeline, None, loaded_step + + # 2. Try training checkpoint from checkpoint_dir. restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") + max_logging.log("Loading WAN pipeline from training checkpoint") pipeline = WanPipelineI2V_2_2.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] else: - max_logging.log("No checkpoint found, loading default pipeline.") + # 3. Slow path: load from diffusers, then cache to orbax for next time. 
+ max_logging.log("No checkpoint found, loading pipeline from diffusers.") pipeline = self.load_diffusers_checkpoint() + if pretrained_dir: + self.save_pretrained_checkpoint(pretrained_dir, pipeline) return pipeline, opt_state, step diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index feae6e933..c5205ac53 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -38,6 +38,8 @@ FSDP = "fsdp" CONTEXT = "context" TENSOR = "tensor" +RING = "ring" +ULYSSES = "ulysses" # Logical axis names for model parameters and activations. BATCH = "activation_batch" LENGTH = "activation_length" @@ -94,3 +96,15 @@ [CROSS_ATTN_Q_LENGTH, CONTEXT], [CROSS_ATTN_KV_LENGTH, CONTEXT], ] + +### Common axis rules for 2D Ulysses + ring attention ### +# Sequence is sharded across both `ring` (cross-chip rotation) and `ulysses` +# (intra-chip all-to-all) axes. Weight TP is intentionally not used here. +ULYSSES_RING_ATTENTION_AXIS_RULES = [ + [SELF_ATTN_HEAD, None], + [SELF_ATTN_Q_LENGTH, [RING, ULYSSES]], + [SELF_ATTN_KV_LENGTH, [RING, ULYSSES]], + [CROSS_ATTN_HEAD, None], + [CROSS_ATTN_Q_LENGTH, [RING, ULYSSES]], + [CROSS_ATTN_KV_LENGTH, None], +] diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 7ffb659c8..720d59576 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -61,7 +61,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses, ulysses_ring use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 0 diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 
9e59ba9ce..ebf7350a7 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses, ulysses_ring use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 0 diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index f80c15515..323e127a6 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -61,7 +61,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses, ulysses_ring use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 4096 @@ -77,14 +77,14 @@ attention_sharding_uniform: True dropout: 0.0 flash_block_sizes: { - "block_q" : 512, - "block_kv_compute" : 512, - "block_kv" : 512, - "block_q_dkv" : 512, - "block_kv_dkv" : 512, - "block_kv_dkv_compute" : 512, - "block_q_dq" : 512, - "block_kv_dq" : 512, + "block_q" : 2048, + "block_kv_compute" : 1024, + "block_kv" : 2048, + "block_q_dkv" : 2048, + "block_kv_dkv" : 2048, + "block_kv_dkv_compute" : 1024, + "block_q_dq" : 2048, + "block_kv_dq" : 2048, "use_fused_bwd_kernel": False, } # Use on v6e @@ -142,7 +142,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu' skip_jax_distributed_system: False # Parallelism -mesh_axes: ['data', 'fsdp', 'context', 'tensor'] +mesh_axes: ['data', 'fsdp', 'context', 'tensor', 'ring', 'ulysses'] # batch : batch dimension of data and activations # hidden : @@ -159,17 +159,17 @@ 
mesh_axes: ['data', 'fsdp', 'context', 'tensor'] logical_axis_rules: [ ['batch', ['data', 'fsdp']], ['activation_batch', ['data', 'fsdp']], - ['activation_self_attn_heads', ['context', 'tensor']], - ['activation_cross_attn_q_length', ['context', 'tensor']], - ['activation_length', 'context'], + ['activation_self_attn_heads', ['context', 'ring', 'ulysses']], + ['activation_cross_attn_q_length', ['context', 'ring', 'ulysses']], + ['activation_length', ['context', 'ring', 'ulysses']], ['activation_heads', 'tensor'], ['mlp','tensor'], - ['embed', ['context', 'fsdp']], + ['embed', ['context', 'fsdp', 'ring', 'ulysses']], ['heads', 'tensor'], ['norm', 'tensor'], - ['conv_batch', ['data', 'context', 'fsdp']], + ['conv_batch', ['data', 'context', 'fsdp', 'ring', 'ulysses']], ['out_channels', 'tensor'], - ['conv_out', 'context'], + ['conv_out', ['context', 'ring', 'ulysses']], ] vae_logical_axis_rules: [ ['activation_batch', 'redundant'], @@ -184,7 +184,7 @@ vae_logical_axis_rules: [ ['conv_out', 'vae_spatial'], ['conv_in', 'vae_spatial'], ] -data_sharding: [['data', 'fsdp', 'context', 'tensor']] +data_sharding: [['data', 'fsdp', 'context', 'tensor', 'ring', 'ulysses']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. @@ -194,10 +194,14 @@ dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 dcn_context_parallelism: -1 dcn_tensor_parallelism: 1 +dcn_ring_parallelism: 1 +dcn_ulysses_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: 1 ici_context_parallelism: -1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +ici_ring_parallelism: 1 +ici_ulysses_parallelism: 1 allow_split_physical_axes: False @@ -249,6 +253,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. 
checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. +# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index b136c7a9e..8b2535bde 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses, ulysses_ring use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 4096 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index 4af011879..4cd2ec49d 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -60,7 +60,7 @@ jit_initializers: True # Set true to load weights from pytorch from_pt: True split_head_dim: True -attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses, ulysses_ring use_base2_exp: True use_experimental_scheduler: True flash_min_seq_length: 4096 @@ -244,6 +244,10 @@ names_which_can_be_offloaded: [] # checkpoint every number of samples, -1 means don't checkpoint. checkpoint_every: -1 checkpoint_dir: "" +# Directory to cache pretrained weights as an orbax checkpoint for fast inference loads. 
+# On first run (slow, diffusers load), weights are saved here automatically. +# On subsequent runs, weights are loaded from here instead (~10x faster). +pretrained_orbax_dir: "" # enables one replica to read the ckpt then broadcast to the rest enable_single_replica_ckpt_restoring: False diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 3a885fba0..59fc092ff 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -375,6 +375,13 @@ def create_device_mesh(config, devices=None, logging=True): config.ici_context_parallelism, config.ici_tensor_parallelism, ] + # Optional `ring` and `ulysses` axes for hybrid Ulysses+Ring attention. + # `ulysses` is placed last so it is the fastest-varying axis: on TPU v7x + # adjacent device IDs are intra-chip, keeping the Ulysses all-to-all on + # the high-bandwidth chip-pair link. + if "ici_ring_parallelism" in config.get_keys() and "ici_ulysses_parallelism" in config.get_keys(): + dcn_parallelism += [config.dcn_ring_parallelism, config.dcn_ulysses_parallelism] + ici_parallelism += [config.ici_ring_parallelism, config.ici_ulysses_parallelism] else: dcn_parallelism = [ config.dcn_data_parallelism, diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 53596eeaf..2613e388b 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -57,6 +57,8 @@ CROSS_ATTN_Q_LENGTH = common_types.CROSS_ATTN_Q_LENGTH CROSS_ATTN_KV_LENGTH = common_types.CROSS_ATTN_KV_LENGTH +ULYSSES_RING_ATTENTION_KERNELS = {"ulysses_ring", "ulysses_tokamax_ring"} + def _coerce_tokamax_block_sizes(block_sizes): # Tokamax requires fused bwd; convert if needed. 
@@ -155,6 +157,29 @@ def _unflatten_heads(tensor, heads): return tensor +def _unflatten_heads_bshd(tensor, heads): + # reshapes from [b, s, h * d] to [b, s, h, d] + batch, seq, heads_and_dim_head = tensor.shape + return tensor.reshape(batch, seq, heads, heads_and_dim_head // heads) + + +def _bhsd_to_bshd(tensor): + return jnp.transpose(tensor, (0, 2, 1, 3)) + + +def _bshd_to_bhsd(tensor): + return jnp.transpose(tensor, (0, 2, 1, 3)) + + +def _bshd_axis_names(axis_names: AxisNames) -> AxisNames: + batch_axis, head_axis, seq_axis, d_axis = axis_names + return (batch_axis, seq_axis, head_axis, d_axis) + + +def _bshd_as_bhsd_shape(tensor): + return jax.ShapeDtypeStruct((tensor.shape[0], tensor.shape[2], tensor.shape[1], tensor.shape[3]), tensor.dtype) + + def _reshape_data_for_flash(tensor, heads, num_context_shards=1): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. @@ -178,6 +203,29 @@ def _reshape_data_for_flash(tensor, heads, num_context_shards=1): return jnp.pad(tensor, pad_width), org_seq_len +def _reshape_data_for_ulysses(tensor, heads, num_context_shards=1): + """Reshape attention inputs to BSHD for Ulysses all-to-all.""" + if tensor.ndim == 3: + tensor = _unflatten_heads_bshd(tensor, heads) + elif tensor.ndim == 4: + # Older call sites may still pass BHSD. Keep accepting that format, but + # prefer BSHD so the Ulysses all-to-all runs before the expensive transpose. 
+ if tensor.shape[1] == heads and tensor.shape[2] != heads: + tensor = _bhsd_to_bshd(tensor) + else: + raise ValueError(f"Ulysses attention expects rank-3 or rank-4 inputs, got rank {tensor.ndim}.") + + org_seq_len = tensor.shape[1] + if num_context_shards <= 1: + return tensor, org_seq_len + rem = org_seq_len % num_context_shards + if rem == 0: + return tensor, org_seq_len + pad_width = [(0, 0)] * tensor.ndim + pad_width[1] = (0, num_context_shards - rem) + return jnp.pad(tensor, pad_width), org_seq_len + + def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. @@ -525,18 +573,19 @@ def _ulysses_attention( """Ulysses sequence-parallel attention. Tensors arrive sequence-sharded on the context axis. Inside a shard_map the - all-to-all collectives trade sequence shards for head shards, run local - splash attention on the full sequence with a subset of heads, then all-to-all - back. + all-to-all collectives trade sequence shards for head shards in BSHD layout, + run local splash attention on the full sequence with a subset of heads, then + all-to-all back. """ axis_name = "context" num_shards = mesh.shape[axis_name] - # Reshape to [b, h, s, d] and pad sequence for even context-axis splitting. - query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_shards) - key, _ = _reshape_data_for_flash(key, heads, num_shards) - value, _ = _reshape_data_for_flash(value, heads, num_shards) - num_heads = query.shape[1] + # Keep tensors in [b, s, h, d] through the Ulysses all-to-all. This avoids + # materializing a full-size BHSD layout before the context redistribution. 
+ query, orig_q_seq_len = _reshape_data_for_ulysses(query, heads, num_shards) + key, _ = _reshape_data_for_ulysses(key, heads, num_shards) + value, _ = _reshape_data_for_ulysses(value, heads, num_shards) + num_heads = query.shape[2] # Ulysses only redistributes existing heads across the context mesh; unlike # the earlier draft, we fail fast instead of padding synthetic heads. if num_heads % num_shards != 0: @@ -544,10 +593,12 @@ def _ulysses_attention( "Ulysses attention requires the number of heads to be divisible by the context shard count, " f"got heads={num_heads} and context_shards={num_shards}." ) - block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") + block_sizes = _select_flash_block_sizes( + _bshd_as_bhsd_shape(query), _bshd_as_bhsd_shape(key), flash_block_sizes, dtype, "flash" + ) - q_axis_names = nn.logical_to_mesh_axes(axis_names_q) - kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) + q_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_q)) + kv_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_kv)) @functools.partial( jax.shard_map, @@ -559,25 +610,21 @@ def _ulysses_attention( def wrap_ulysses_attention(query, key, value): # Swap sharding modes: each device gives up a slice of sequence and gathers # a slice of heads, so the local splash kernel sees the full sequence. - query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) - key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) - value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) + query = jax.lax.all_to_all(query, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + + # Splash expects [b, h, s, d]. 
Transpose only after the all-to-all, when + # the local head count is smaller by the context shard count. + query = _bshd_to_bhsd(query) + key = _bshd_to_bhsd(key) + value = _bshd_to_bhsd(value) # Run the same local splash kernel as standard TPU flash attention, but now # on full-sequence / fewer-heads tensors produced by the all-to-all above. - uses_fused_kernel = block_sizes.use_fused_bwd_kernel - block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv) - block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv) - if uses_fused_kernel: - block_q_sizes += (block_sizes.block_q_dkv,) - block_kv_sizes += (block_sizes.block_kv_dkv,) - else: - block_q_sizes += (block_sizes.block_q_dq,) - block_kv_sizes += (block_sizes.block_kv_dq,) - - block_q = max(*block_q_sizes) + block_q = max(block_sizes.block_q, block_sizes.block_q_dkv) query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) - block_kv = max(*block_kv_sizes) + block_kv = max(block_sizes.block_kv, block_sizes.block_kv_dkv) key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) value, _, _ = _pad_data_for_flash(value, heads, block_kv) @@ -625,7 +672,8 @@ def wrap_ulysses_attention(query, key, value): # Restore the original layout expected by the rest of the model: # head-sharded / full-sequence -> sequence-sharded / full-heads. 
- attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + attention_output = _bhsd_to_bshd(attention_output) + attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) return attention_output devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) @@ -635,8 +683,156 @@ def wrap_ulysses_attention(query, key, value): f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) x = wrap_ulysses_attention(query, key, value) - x = x[:, :, :orig_q_seq_len, :] - x = _reshape_heads_to_head_dim(x) + x = x[:, :orig_q_seq_len, :, :] + x = x.reshape(x.shape[0], x.shape[1], -1) + + return x + + +def _ulysses_ring_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + heads: int, + mesh: Mesh, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, + flash_block_sizes: BlockSizes, + dtype: jnp.dtype = jnp.float32, + mask_padding_tokens: bool = True, + residual_checkpoint_name: str | None = None, + attention_mask: jax.Array = None, + ulysses_axis: str = "ulysses", + ring_axis: str = "ring", + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, +) -> jax.Array: + """2D context-parallel attention using Ulysses all-to-all plus Tokamax ring. + + Inputs are sequence-sharded over both the ring and Ulysses mesh axes. The + Ulysses all-to-all trades the Ulysses-axis sequence shard for a head shard, + leaving sequence sharded only over the ring axis. Tokamax ring attention then + rotates K/V across the ring axis. 
+ """ + if ulysses_axis not in mesh.shape: + raise ValueError(f"Ulysses ring attention requires mesh axis {ulysses_axis!r}, got mesh axes {mesh.shape}.") + if ring_axis not in mesh.shape: + raise ValueError(f"Ulysses ring attention requires mesh axis {ring_axis!r}, got mesh axes {mesh.shape}.") + + num_ulysses_shards = mesh.shape[ulysses_axis] + num_ring_shards = mesh.shape[ring_axis] + num_sequence_shards = num_ulysses_shards * num_ring_shards + + query, orig_q_seq_len = _reshape_data_for_ulysses(query, heads, num_sequence_shards) + key, _ = _reshape_data_for_ulysses(key, heads, num_sequence_shards) + value, _ = _reshape_data_for_ulysses(value, heads, num_sequence_shards) + + num_heads = query.shape[2] + if num_heads % num_ulysses_shards != 0: + raise ValueError( + "Ulysses ring attention requires the number of heads to be divisible by the Ulysses shard count, " + f"got heads={num_heads} and ulysses_shards={num_ulysses_shards}." + ) + block_sizes = _select_flash_block_sizes( + _bshd_as_bhsd_shape(query), _bshd_as_bhsd_shape(key), flash_block_sizes, dtype, "tokamax_ring" + ) + + q_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_q)) + kv_axis_names = nn.logical_to_mesh_axes(_bshd_axis_names(axis_names_kv)) + + @functools.partial( + jax.shard_map, + mesh=mesh, + in_specs=(q_axis_names, kv_axis_names, kv_axis_names), + out_specs=q_axis_names, + check_vma=False, + ) + def wrap_ulysses_ring_attention(query, key, value): + query = jax.lax.all_to_all(query, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True) + key = jax.lax.all_to_all(key, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True) + value = jax.lax.all_to_all(value, axis_name=ulysses_axis, split_axis=2, concat_axis=1, tiled=True) + + query = _bshd_to_bhsd(query) + key = _bshd_to_bhsd(key) + value = _bshd_to_bhsd(value) + + uses_fused_kernel = block_sizes.use_fused_bwd_kernel + block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv) + block_kv_sizes = 
(block_sizes.block_kv, block_sizes.block_kv_dkv) + if uses_fused_kernel: + block_q_sizes += (block_sizes.block_q_dkv,) + block_kv_sizes += (block_sizes.block_kv_dkv,) + else: + block_q_sizes += (block_sizes.block_q_dq,) + block_kv_sizes += (block_sizes.block_kv_dq,) + + block_q = max(*block_q_sizes) + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) + block_kv = max(*block_kv_sizes) + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) + value, _, _ = _pad_data_for_flash(value, heads, block_kv) + + q_padded_len = query.shape[2] + q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) + q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) + + kv_padded_len = key.shape[2] + kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) + kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + + if attention_mask is not None: + mask_len = min(key_seq_len, attention_mask.shape[1]) + kv_mask_for_batch = attention_mask[0, :mask_len] + if key_seq_len > mask_len: + extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) + kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) + if kv_padded_len > key_seq_len: + padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) + kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) + else: + kv_mask_padded = kv_mask_for_batch + kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) + + segment_ids = tokamax_splash_base.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + if not mask_padding_tokens: + segment_ids = None + + mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) + splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( + mask=mask, + is_mqa=False, + config=convert_to_tokamax_splash_config( + block_sizes, + residual_checkpoint_name=residual_checkpoint_name, + use_base2_exp=use_base2_exp, + 
use_experimental_scheduler=use_experimental_scheduler, + ), + save_residuals=False, + ring_axis=ring_axis, + rotate_segment_ids=False, + ) + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + attention_output = vmapped_splash(query, key, value, segment_ids) + attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + + attention_output = _bhsd_to_bshd(attention_output) + return jax.lax.all_to_all( + attention_output, + axis_name=ulysses_axis, + split_axis=1, + concat_axis=2, + tiled=True, + ) + + devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) + if not (query.shape[0] / devices_in_batch_sharding).is_integer(): + max_logging.log( + "Warning, batch dimension should be shardable among the devices in data and fsdp" + f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" + ) + x = wrap_ulysses_ring_attention(query, key, value) + x = x[:, :orig_q_seq_len, :, :] + x = x.reshape(x.shape[0], x.shape[1], -1) return x @@ -761,9 +957,9 @@ def _apply_attention( """Routes to different attention kernels.""" _check_attention_inputs(query, key, value) seq_len_idx = 1 - if query.ndim == 4: + if query.ndim == 4 and attention_kernel != "ulysses" and attention_kernel not in ULYSSES_RING_ATTENTION_KERNELS: seq_len_idx = 2 - if attention_kernel in ["flash", "tokamax_flash", "ulysses"]: + if attention_kernel in ["flash", "tokamax_flash", "ulysses"] or attention_kernel in ULYSSES_RING_ATTENTION_KERNELS: can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and key.shape[seq_len_idx] >= flash_min_seq_length @@ -790,6 +986,23 @@ def _apply_attention( residual_checkpoint_name=residual_checkpoint_name, attention_mask=attention_mask, ) + elif attention_kernel in ULYSSES_RING_ATTENTION_KERNELS: + return _ulysses_ring_attention( + query, + key * scale, + value, + heads, + mesh, + axis_names_q, + axis_names_kv, + 
flash_block_sizes, + dtype, + mask_padding_tokens=mask_padding_tokens, + residual_checkpoint_name=residual_checkpoint_name, + attention_mask=attention_mask, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ) elif attention_kernel in ["flash", "tokamax_flash"]: return _tpu_flash_attention( query, @@ -1196,6 +1409,8 @@ def __init__( axis_names_kv = (BATCH, CROSS_ATTN_HEAD, CROSS_ATTN_KV_LENGTH, D_KV) if attention_kernel == "tokamax_ring" and not is_self_attention: attention_kernel = "tokamax_flash" # do not use ring attention for cross attention + if attention_kernel in ULYSSES_RING_ATTENTION_KERNELS and not is_self_attention: + attention_kernel = "tokamax_flash" # Ulysses+ring is only valid for sequence-parallel self-attention. self.added_kv_proj_dim = added_kv_proj_dim # New for I2V self.image_seq_len = image_seq_len # New for I2V @@ -1420,11 +1635,17 @@ def __call__( if rotary_emb is not None: with self.conditional_named_scope("attn_rope"): - query_proj = _unflatten_heads(query_proj, self.heads) - key_proj = _unflatten_heads(key_proj, self.heads) - value_proj = _unflatten_heads(value_proj, self.heads) - # output of _unflatten_heads Batch, heads, seq_len, head_dim - query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) + if self.attention_op.attention_kernel == "ulysses" or self.attention_op.attention_kernel in ULYSSES_RING_ATTENTION_KERNELS: + query_proj = _unflatten_heads_bshd(query_proj, self.heads) + key_proj = _unflatten_heads_bshd(key_proj, self.heads) + value_proj = _unflatten_heads_bshd(value_proj, self.heads) + query_proj, key_proj = self._apply_rope(query_proj, key_proj, _bhsd_to_bshd(rotary_emb)) + else: + query_proj = _unflatten_heads(query_proj, self.heads) + key_proj = _unflatten_heads(key_proj, self.heads) + value_proj = _unflatten_heads(value_proj, self.heads) + # output of _unflatten_heads Batch, heads, seq_len, head_dim + query_proj, key_proj = self._apply_rope(query_proj, key_proj, 
rotary_emb) query_proj = checkpoint_name(query_proj, "query_proj") key_proj = checkpoint_name(key_proj, "key_proj") diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 031fe2fe0..50a82607b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -630,7 +630,6 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): vae_devices_array = flat_devices.reshape(total_devices // vae_spatial, vae_spatial) vae_mesh = Mesh(vae_devices_array, ("redundant", "vae_spatial")) - vae_mesh.vae_spatial_axis_name = "vae_spatial" max_logging.log( f"Created VAE specific mesh with axes ('redundant', 'vae_spatial') to support spatial sharding of {vae_spatial}." ) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 9488f8946..b21181351 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -46,12 +46,27 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t common_components = cls._create_common_components(config, vae_only) low_noise_transformer, high_noise_transformer = None, None if not vae_only and load_transformer: + # Restructure the combined checkpoint into per-transformer checkpoints. + # create_sharded_logical_transformer expects {"wan_config": ..., "wan_state": ...}. 
+ if restored_checkpoint is not None: + low_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["low_noise_transformer_state"], + } + high_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["high_noise_transformer_state"], + } + else: + low_noise_ckpt = None + high_noise_ckpt = None + low_noise_transformer = super().load_transformer( devices_array=common_components["devices_array"], mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=low_noise_ckpt, subfolder="transformer_2", ) high_noise_transformer = super().load_transformer( @@ -59,7 +74,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=high_noise_ckpt, subfolder="transformer", ) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index d8398f58f..e5feba3f9 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -50,12 +50,27 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t low_noise_transformer, high_noise_transformer = None, None if not vae_only: if load_transformer: + # Restructure the combined checkpoint into per-transformer checkpoints. + # create_sharded_logical_transformer expects {"wan_config": ..., "wan_state": ...}. 
+ if restored_checkpoint is not None: + high_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["high_noise_transformer_state"], + } + low_noise_ckpt = { + "wan_config": restored_checkpoint["wan_config"], + "wan_state": restored_checkpoint["low_noise_transformer_state"], + } + else: + high_noise_ckpt = None + low_noise_ckpt = None + high_noise_transformer = super().load_transformer( devices_array=common_components["devices_array"], mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=high_noise_ckpt, subfolder="transformer", ) low_noise_transformer = super().load_transformer( @@ -63,7 +78,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t mesh=common_components["mesh"], rngs=common_components["rngs"], config=config, - restored_checkpoint=restored_checkpoint, + restored_checkpoint=low_noise_ckpt, subfolder="transformer_2", ) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 08de21c41..cfbdd8407 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -37,6 +37,7 @@ RING_ATTENTION_AXIS_RULES, SEQUENCE_PARALLEL_AXIS_RULES, ULYSSES_ATTENTION_AXIS_RULES, + ULYSSES_RING_ATTENTION_AXIS_RULES, ) _ALLOWED_MODEL_NAMES = {WAN2_1, WAN2_2, LTX2_VIDEO} @@ -213,10 +214,11 @@ def user_init(raw_keys): raw_keys["vae_logical_axis_rules"] = _lists_to_tuples(raw_keys["vae_logical_axis_rules"]) # Verify qkv is sharded across sequence. 
attention = raw_keys["attention"] - uses_ring_attention = "ring" in attention + uses_ulysses_ring_attention = attention in ("ulysses_ring", "ulysses_tokamax_ring") + uses_ring_attention = "ring" in attention and not uses_ulysses_ring_attention uses_ulysses_attention = attention == "ulysses" uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"] - if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding: + if uses_ring_attention or uses_ulysses_attention or uses_ulysses_ring_attention or uses_uniform_sequence_sharding: max_logging.log( "Adding sequence sharding to q and kv if not already present because " f"{attention=} requires it or attention_sharding_uniform={uses_uniform_sequence_sharding} is set." @@ -232,7 +234,12 @@ def user_init(raw_keys): if kv_seq_sharding not in logical_axis_rules: logical_axis_rules.append(kv_seq_sharding) max_logging.log(f"Adding key/value sequence axis rule {kv_seq_sharding}") - if uses_ring_attention: + if uses_ulysses_ring_attention: + for ulysses_ring_attention_axis_rule in ULYSSES_RING_ATTENTION_AXIS_RULES: + if ulysses_ring_attention_axis_rule not in logical_axis_rules: + max_logging.log(f"Adding ulysses ring attention axis rule {ulysses_ring_attention_axis_rule}") + new_rules.append(ulysses_ring_attention_axis_rule) + elif uses_ring_attention: for ring_attention_axis_rule in RING_ATTENTION_AXIS_RULES: if ring_attention_axis_rule not in logical_axis_rules: max_logging.log(f"Adding ring attention axis rule {ring_attention_axis_rule}") diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 5c95dff8b..0242d31f7 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -43,6 +43,10 @@ def _ulysses_mesh(self): devices = np.array(jax.devices()[:2]).reshape(1, 1, 2, 1) return Mesh(devices, ("data", "fsdp", "context", "tensor")) + def _ulysses_ring_mesh(self): + devices = 
np.array(jax.devices()[:4]).reshape(1, 1, 2, 2) + return Mesh(devices, ("data", "fsdp", "ring", "ulysses")) + def _ulysses_axis_rules(self): return ( (attention_flax.BATCH, "data"), @@ -52,6 +56,15 @@ def _ulysses_axis_rules(self): (attention_flax.D_KV, None), ) + def _ulysses_ring_axis_rules(self): + return ( + (attention_flax.BATCH, "data"), + (attention_flax.SELF_ATTN_HEAD, None), + (attention_flax.SELF_ATTN_Q_LENGTH, ("ring", "ulysses")), + (attention_flax.SELF_ATTN_KV_LENGTH, ("ring", "ulysses")), + (attention_flax.D_KV, None), + ) + def _flash_axis_rules(self): return ( (attention_flax.BATCH, "data"), @@ -228,7 +241,9 @@ def test_ulysses_attention_round_trips_query_when_heads_are_divisible(self): length = 5 heads = 4 head_depth = 4 - query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape( + batch, length, heads * head_depth + ) key = query + 1000.0 value = query + 2000.0 mesh = self._ulysses_mesh() @@ -280,7 +295,9 @@ def test_ulysses_attention_raises_when_heads_are_not_divisible_by_context_shards length = 5 heads = 3 head_depth = 4 - query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape( + batch, length, heads * head_depth + ) key = query + 1000.0 value = query + 2000.0 mesh = self._ulysses_mesh() @@ -318,7 +335,9 @@ def test_ulysses_attention_matches_flash_attention_with_same_local_kernel(self): length = 6 heads = 4 head_depth = 3 - query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape( + batch, length, heads * head_depth + ) key = query + 100.0 value = query + 200.0 mesh = self._ulysses_mesh() @@ -391,7 +410,9 @@ def 
test_ulysses_attention_uses_attention_mask_for_segment_ids(self): length = 5 heads = 4 head_depth = 3 - query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape(batch, length, heads * head_depth) + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape( + batch, length, heads * head_depth + ) key = query + 100.0 value = query + 200.0 attention_mask = jnp.array([[1, 0, 1, 0, 1]], dtype=jnp.int32) @@ -441,6 +462,103 @@ def fake_kernel(q, k, v, segment_ids): self.assertEqual(output.shape, query.shape) self.assertTrue(jnp.array_equal(output, expected)) + @unittest.skipIf(len(jax.devices()) < 4, "Ulysses ring attention layout test requires at least 4 devices.") + def test_ulysses_ring_attention_round_trips_query_when_heads_are_divisible(self): + """2D Ulysses+ring attention should preserve layout across its collectives.""" + batch = 2 + length = 8 + heads = 4 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape( + batch, length, heads * head_depth + ) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + def fake_make_ring_attention(**unused_kwargs): + def fake_kernel(q, k, v, segment_ids): + del k, v, segment_ids + return q + + return fake_kernel + + with ( + mesh, + nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()), + mock.patch.object( + attention_flax.tokamax_ring_attention_kernel, + "make_ring_attention", + side_effect=fake_make_ring_attention, + ), + ): + output = attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ) + + 
self.assertEqual(output.shape, query.shape) + self.assertTrue(jnp.array_equal(output, query)) + + def test_ulysses_ring_attention_raises_when_heads_are_not_divisible_by_ulysses_shards(self): + """The all-to-all head split requires heads to divide the Ulysses axis.""" + if len(jax.devices()) < 4: + self.skipTest("Ulysses ring attention validation test requires at least 4 devices.") + batch = 2 + length = 8 + heads = 3 + head_depth = 4 + query = jnp.arange(batch * length * heads * head_depth, dtype=jnp.float32).reshape( + batch, length, heads * head_depth + ) + key = query + 1000.0 + value = query + 2000.0 + mesh = self._ulysses_ring_mesh() + + with mesh, nn_partitioning.axis_rules(self._ulysses_ring_axis_rules()): + with self.assertRaisesRegex( + ValueError, + r"heads=3 and ulysses_shards=2", + ): + attention_flax._ulysses_ring_attention( + query, + key, + value, + heads=heads, + mesh=mesh, + axis_names_q=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_Q_LENGTH, + attention_flax.D_KV, + ), + axis_names_kv=( + attention_flax.BATCH, + attention_flax.SELF_ATTN_HEAD, + attention_flax.SELF_ATTN_KV_LENGTH, + attention_flax.D_KV, + ), + flash_block_sizes=self._ulysses_block_sizes(), + dtype=jnp.float32, + ) + if __name__ == "__main__": absltest.main()