19 changes: 19 additions & 0 deletions README.md
@@ -210,6 +210,7 @@ After installation completes, run the training script.
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_latency_hiding_scheduler=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
@@ -329,6 +330,7 @@ After installation completes, run the training script.
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_latency_hiding_scheduler=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
@@ -597,6 +599,23 @@ To generate images, run the following command:

Ulysses requires `ici_context_parallelism` greater than 1, and the number of attention heads must be divisible by the context shard count. `flash_block_sizes` remains optional and can still be set for hardware-specific tuning.
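
A minimal pure-Ulysses invocation, as a sketch (`<context_shards>` is a placeholder; the flag names follow the `ulysses_ring` example below), might look like:

```bash
python src/maxdiffusion/generate_wan.py \
  src/maxdiffusion/configs/base_wan_i2v_27b.yml \
  attention="ulysses" \
  ici_context_parallelism=<context_shards> \
  ...
```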

For TPU multihost 2D context parallelism, use `attention="ulysses_ring"`.
This shards the self-attention sequence over `context` × `tensor`, running the Ulysses all-to-all over the
`tensor` mesh axis and reusing Tokamax ring attention over the `context` mesh axis. The number of attention
heads must be divisible by `ici_tensor_parallelism`; a typical multihost setup uses DCN context for the ring
axis and ICI tensor for the Ulysses axis. On TPU7x, keep `dcn_tensor_parallelism=1` and set
`ici_tensor_parallelism >= 2` so that the dual chiplets exposed as two JAX devices are grouped by Ulysses
rather than by the ring.

```bash
python src/maxdiffusion/generate_wan.py \
src/maxdiffusion/configs/base_wan_i2v_27b.yml \
attention="ulysses_ring" \
dcn_context_parallelism=<num_slices> \
ici_context_parallelism=1 \
ici_tensor_parallelism=<ulysses_shards_per_slice> \
...
```

In our Wan2.2 I2V benchmarks at 40 inference steps, 81 frames, and `720x1280` resolution, Ulysses improved inference time by roughly 10% compared with flash attention, with about 20s lower latency on the v6e-8 and v7x-8 TPU setups.

### Caching Mechanisms
169 changes: 169 additions & 0 deletions bench_attn.sh
@@ -0,0 +1,169 @@
#!/usr/bin/env bash
set -euo pipefail

# ── config ────────────────────────────────────────────────────────────────────
REPO_DIR=/dev/shm/maxdiffusion
VENV=$REPO_DIR/.venv-tpu
CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml
RESULTS_ROOT=$REPO_DIR/bench_results
OUTPUT_ROOT=$REPO_DIR/bench_outputs

WORKER1_USER=sa_112155357684894056033
WORKER1_IP=10.154.0.59
WORKER1_HOST_ALIAS=tpu.2884514015978940116-1-jgVFrB
SSH_KEY=/home/sagarchapara_google_com/.ssh/google_compute_engine
SSH_KNOWN_HOSTS=/home/sagarchapara_google_com/.ssh/google_compute_known_hosts
SSH_OPTS="-T -i $SSH_KEY -o CheckHostIP=no -o HashKnownHosts=no \
-o HostKeyAlias=$WORKER1_HOST_ALIAS -o IdentitiesOnly=yes \
-o StrictHostKeyChecking=no -o UserKnownHostsFile=$SSH_KNOWN_HOSTS"

export LIBTPU_INIT_ARGS=\
'--xla_tpu_dvfs_p_state=7 '\
'--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '\
'--xla_tpu_megacore_fusion_allow_ags=false '\
'--xla_enable_async_collective_permute=true '\
'--xla_tpu_enable_ag_backward_pipelining=true '\
'--xla_tpu_enable_data_parallel_all_reduce_opt=true '\
'--xla_tpu_data_parallel_opt_different_sized_ops=true '\
'--xla_tpu_enable_async_collective_fusion=true '\
'--xla_tpu_enable_async_collective_fusion_multiple_steps=true '\
'--xla_tpu_overlap_compute_collective_tc=true '\
'--xla_enable_async_all_gather=true '\
'--xla_tpu_scoped_vmem_limit_kib=65536 '\
'--xla_tpu_enable_async_all_to_all=true '\
'--xla_tpu_enable_latency_hiding_scheduler=true '\
'--xla_tpu_enable_all_experimental_scheduler_features=true '\
'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\
'--xla_tpu_host_transfer_overlap_limit=24 '\
'--xla_tpu_aggressive_opt_barrier_removal=ENABLED '\
'--xla_lhs_prioritize_async_depth_over_stall=ENABLED '\
'--xla_should_allow_loop_variant_parameter_in_chain=ENABLED '\
'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\
'--xla_max_concurrent_host_send_recv=100 '\
'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\
'--xla_latency_hiding_scheduler_rerun=5 '\
'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\
'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\
'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\
'--xla_tpu_assign_all_reduce_scatter_layout=true '\
'--xla_max_concurrent_async_collective_permutes=16 '\
'--xla_tpu_enable_ici_ag_pipelining=true'

source "$VENV/bin/activate"
export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-}
export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface
export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_OFFLINE=1
export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax
export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla
export TMPDIR=/dev/shm/maxdiffusion_cache/tmp

# ── helper ────────────────────────────────────────────────────────────────────
# run_case <run_name> <attention> <ici_dp> <ici_a> <ici_b> <per_device_batch_size>
# pure (flash/tokamax_ring/ulysses): ici_a=cp, ici_b unused
# hybrid (ulysses_ring): ici_a=ring, ici_b=ulysses
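# Examples (taken from the run matrix at the bottom of this script):
#   run_case flash_dp2_cp8_bs1         flash        2 8 1 0.0625   # pure: cp=8
#   run_case ulysses_ring_dp2_r2u4_bs1 ulysses_ring 2 2 4 0.0625   # hybrid: ring=2, ulysses=4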
run_case() {
local run_name="$1"
local attention="$2"
local ici_dp="$3"
local ici_a="$4"
local ici_b="$5"
local pdb="$6" # per_device_batch_size

local results_dir="$RESULTS_ROOT/$run_name"
rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name"
mkdir -p "$results_dir"
rm -f /tmp/libtpu_lockfile
mkdir -p "$TMPDIR"

local ici_cp ici_tp ici_ring ici_ulysses
if [[ "$attention" == "ulysses_ring" || "$attention" == "ulysses_tokamax_ring" ]]; then
ici_cp=1; ici_tp=1; ici_ring="$ici_a"; ici_ulysses="$ici_b"
echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses pdb=$pdb) ──"
else
ici_cp="$ici_a"; ici_tp=1; ici_ring=1; ici_ulysses=1
echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──"
fi

# Only capture profiler for bs=2 (per_device_batch_size=0.125)
local profiler_args="enable_profiler=False"
if [[ "$pdb" == "0.125" ]]; then
profiler_args="enable_profiler=True skip_first_n_steps_for_profiler=5 profiler_steps=10"
fi

local common_args="run_name=$run_name \
attention=$attention \
ici_data_parallelism=$ici_dp \
ici_fsdp_parallelism=1 \
ici_context_parallelism=$ici_cp \
ici_tensor_parallelism=$ici_tp \
ici_ring_parallelism=$ici_ring \
ici_ulysses_parallelism=$ici_ulysses \
dcn_data_parallelism=1 \
dcn_fsdp_parallelism=1 \
dcn_context_parallelism=1 \
dcn_tensor_parallelism=1 \
dcn_ring_parallelism=1 \
dcn_ulysses_parallelism=1 \
height=720 width=1280 num_frames=81 num_inference_steps=40 \
per_device_batch_size=$pdb \
output_dir=$OUTPUT_ROOT \
scan_layers=True \
write_metrics=False \
write_timing_metrics=False \
$profiler_args"

local remote_cmd
remote_cmd="$(printf '%q' "set -euo pipefail
export LIBTPU_INIT_ARGS='$LIBTPU_INIT_ARGS'
source $VENV/bin/activate
export PYTHONPATH=$REPO_DIR/src:\${PYTHONPATH:-}
export HF_HOME=$HF_HOME
export HF_HUB_CACHE=$HF_HUB_CACHE
export HF_HUB_ENABLE_HF_TRANSFER=1
export HF_HUB_OFFLINE=1
export JAX_COMPILATION_CACHE_DIR=$JAX_COMPILATION_CACHE_DIR
export XLA_CACHE_DIR=$XLA_CACHE_DIR
export TMPDIR=$TMPDIR
rm -f /tmp/libtpu_lockfile
mkdir -p $TMPDIR $results_dir $OUTPUT_ROOT
cd $results_dir
python -u $REPO_DIR/src/maxdiffusion/generate_wan.py $CONFIG $common_args 2>&1 | tee $results_dir/worker1.log")"

/usr/bin/ssh $SSH_OPTS "$WORKER1_USER@$WORKER1_IP" "bash -lc $remote_cmd" &
local remote_pid=$!

cd "$results_dir"
set +e
python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \
"$CONFIG" \
$common_args \
2>&1 | tee "$results_dir/worker0.log"
local local_status=$?

wait "$remote_pid"
local remote_status=$?
set -e

echo "[$(date -u +%T)] $run_name done — local=$local_status remote=$remote_status"
echo "────────────────────────────────────────────────────────────────────────"
}

# ── run matrix ────────────────────────────────────────────────────────────────
# Columns: run_name attention dp a b per_device_bs
#
# All runs use 16 total devices (2 hosts × 8): dp × cp = 16 for pure modes,
# dp × ring × ulysses = 16 for the hybrid runs.
# Batch sizes: 0.0625=bs1, 0.125=bs2, 0.25=bs4, 0.5=bs8

for pdb in 0.0625 0.125 0.25 0.5; do
bs=$(python3 -c "print(int($pdb * 16))")
run_case "flash_dp2_cp8_bs${bs}" flash 2 8 1 $pdb
run_case "tokamax_ring_dp2_cp8_bs${bs}" tokamax_ring 2 8 1 $pdb
run_case "ulysses_dp2_cp8_bs${bs}" ulysses 2 8 1 $pdb
run_case "ulysses_ring_dp2_r2u4_bs${bs}" ulysses_ring 2 2 4 $pdb # ring=2, ulysses=4, dp=2
run_case "ulysses_ring_dp2_r4u2_bs${bs}" ulysses_ring 2 4 2 $pdb # ring=4, ulysses=2, dp=2
run_case "ulysses_ring_dp1_r4u4_bs${bs}" ulysses_ring 1 4 4 $pdb # ring=4, ulysses=4, dp=1
done

echo "[$(date -u +%T)] All benchmark runs complete."
143 changes: 143 additions & 0 deletions bench_attn_v78.sh
@@ -0,0 +1,143 @@
#!/usr/bin/env bash
set -euo pipefail

# ── config (TPU v7-8, single host, 8 chips) ──────────────────────────────────
REPO_DIR=/mnt/data/sagarchapara/workspace/maxdiffusion
VENV=/mnt/data/sagarchapara/workspace/venv
CONFIG=$REPO_DIR/src/maxdiffusion/configs/base_wan_27b.yml
RESULTS_ROOT=$REPO_DIR/bench_results
OUTPUT_ROOT=$REPO_DIR/bench_outputs

PRETRAINED_ORBAX_DIR=/mnt/data/sagarchapara/workspace/wan22_orbax_cache
mkdir -p "$PRETRAINED_ORBAX_DIR"

export LIBTPU_INIT_ARGS=\
'--xla_tpu_dvfs_p_state=7 '\
'--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true '\
'--xla_tpu_megacore_fusion_allow_ags=false '\
'--xla_enable_async_collective_permute=true '\
'--xla_tpu_enable_ag_backward_pipelining=true '\
'--xla_tpu_enable_data_parallel_all_reduce_opt=true '\
'--xla_tpu_data_parallel_opt_different_sized_ops=true '\
'--xla_tpu_enable_async_collective_fusion=true '\
'--xla_tpu_enable_async_collective_fusion_multiple_steps=true '\
'--xla_tpu_overlap_compute_collective_tc=true '\
'--xla_enable_async_all_gather=true '\
'--xla_tpu_scoped_vmem_limit_kib=65536 '\
'--xla_tpu_enable_async_all_to_all=true '\
'--xla_tpu_enable_latency_hiding_scheduler=true '\
'--xla_tpu_enable_all_experimental_scheduler_features=true '\
'--xla_tpu_enable_scheduler_memory_pressure_tracking=true '\
'--xla_tpu_host_transfer_overlap_limit=24 '\
'--xla_tpu_aggressive_opt_barrier_removal=ENABLED '\
'--xla_lhs_prioritize_async_depth_over_stall=ENABLED '\
'--xla_should_allow_loop_variant_parameter_in_chain=ENABLED '\
'--xla_should_add_loop_invariant_op_in_chain=ENABLED '\
'--xla_max_concurrent_host_send_recv=100 '\
'--xla_tpu_scheduler_percent_shared_memory_limit=100 '\
'--xla_latency_hiding_scheduler_rerun=5 '\
'--xla_tpu_use_minor_sharding_for_major_trivial_input=true '\
'--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 '\
'--xla_tpu_spmd_rng_bit_generator_unsafe=true '\
'--xla_tpu_assign_all_reduce_scatter_layout=true '\
'--xla_max_concurrent_async_collective_permutes=16 '\
'--xla_tpu_enable_ici_ag_pipelining=true'

source "$VENV/bin/activate"
export PYTHONPATH=$REPO_DIR/src:${PYTHONPATH:-}
export HF_HOME=/dev/shm/maxdiffusion_cache/huggingface
export HF_HUB_CACHE=/dev/shm/maxdiffusion_cache/huggingface/hub
export HF_HUB_ENABLE_HF_TRANSFER=1
export JAX_COMPILATION_CACHE_DIR=/dev/shm/maxdiffusion_cache/jax
export XLA_CACHE_DIR=/dev/shm/maxdiffusion_cache/xla
export TMPDIR=/dev/shm/maxdiffusion_cache/tmp
mkdir -p "$TMPDIR" "$HF_HOME" "$HF_HUB_CACHE" "$JAX_COMPILATION_CACHE_DIR" "$XLA_CACHE_DIR"

# ── helper (single host - no SSH) ────────────────────────────────────────────
# run_case <run_name> <attention> <ici_dp> <ici_a> <ici_b> <per_device_batch_size>
# pure (flash/tokamax_ring/ulysses): ici_a=cp, ici_b unused
# hybrid (ulysses_ring): ici_a=ring, ici_b=ulysses
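# Example (taken from the run matrix at the bottom of this script):
#   run_case ulysses_ring_dp2_r2u2_bs2 ulysses_ring 2 2 2 0.25   # hybrid: ring=2, ulysses=2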
run_case() {
local run_name="$1"
local attention="$2"
local ici_dp="$3"
local ici_a="$4"
local ici_b="$5"
local pdb="$6"

local results_dir="$RESULTS_ROOT/$run_name"
rm -rf "$results_dir" "$OUTPUT_ROOT/$run_name"
mkdir -p "$results_dir"
rm -f /tmp/libtpu_lockfile
mkdir -p "$TMPDIR"

local ici_cp ici_tp ici_ring ici_ulysses
if [[ "$attention" == "ulysses_ring" || "$attention" == "ulysses_tokamax_ring" ]]; then
ici_cp=1; ici_tp=1; ici_ring="$ici_a"; ici_ulysses="$ici_b"
echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp ring=$ici_ring ulysses=$ici_ulysses pdb=$pdb) ──"
else
ici_cp="$ici_a"; ici_tp=1; ici_ring=1; ici_ulysses=1
echo "[$(date -u +%T)] ── Starting $run_name (attention=$attention dp=$ici_dp cp=$ici_cp pdb=$pdb) ──"
fi

# Profiler only for bs=2 (pdb=0.25 with 8 devices)
local profiler_args="enable_profiler=False"
if [[ "$pdb" == "0.25" ]]; then
profiler_args="enable_profiler=True skip_first_n_steps_for_profiler=5 profiler_steps=10"
fi

local common_args="run_name=$run_name \
attention=$attention \
ici_data_parallelism=$ici_dp \
ici_fsdp_parallelism=1 \
ici_context_parallelism=$ici_cp \
ici_tensor_parallelism=$ici_tp \
ici_ring_parallelism=$ici_ring \
ici_ulysses_parallelism=$ici_ulysses \
dcn_data_parallelism=1 \
dcn_fsdp_parallelism=1 \
dcn_context_parallelism=1 \
dcn_tensor_parallelism=1 \
dcn_ring_parallelism=1 \
dcn_ulysses_parallelism=1 \
pretrained_orbax_dir=$PRETRAINED_ORBAX_DIR \
height=720 width=1280 num_frames=81 num_inference_steps=40 \
per_device_batch_size=$pdb \
output_dir=$OUTPUT_ROOT \
scan_layers=True \
write_metrics=False \
write_timing_metrics=False \
$profiler_args"

cd "$results_dir"
set +e
python -u "$REPO_DIR/src/maxdiffusion/generate_wan.py" \
"$CONFIG" \
$common_args \
2>&1 | tee "$results_dir/worker0.log"
local status=$?
set -e

echo "[$(date -u +%T)] $run_name done — status=$status"
echo "────────────────────────────────────────────────────────────────────────"
}

# ── run matrix (TPU v7-8: 1 host × 8 chips = 8 devices) ──────────────────────
# Parallelism rule: dp × fsdp × cp × tp × ring × ulysses = 8
#
# Pure modes: dp=2, cp=4 (2×4=8)
# 2D (2×2): dp=2, ring=2, ulysses=2 (2×2×2=8)
#
# Batch sizes: pdb × 8 devices = total_bs
# 0.125 → bs1, 0.25 → bs2, 0.5 → bs4

for pdb in 0.125 0.25 0.5; do
bs=$(python3 -c "print(int($pdb * 8))")

run_case "flash_dp2_cp4_bs${bs}" flash 2 4 1 $pdb
run_case "tokamax_ring_dp2_cp4_bs${bs}" tokamax_ring 2 4 1 $pdb
run_case "ulysses_dp2_cp4_bs${bs}" ulysses 2 4 1 $pdb
run_case "ulysses_ring_dp2_r2u2_bs${bs}" ulysses_ring 2 2 2 $pdb
done

echo "[$(date -u +%T)] All benchmark runs complete."