diff --git a/docs/attention_blocks_flowchart.md b/docs/attention_blocks_flowchart.md new file mode 100644 index 00000000..69816ac7 --- /dev/null +++ b/docs/attention_blocks_flowchart.md @@ -0,0 +1,30 @@ +# Attention block sizes + +## Description +- "block_q": Block sizes (HBM TO VMEM and VREG) to tile along Q sequence in forward pass +- "block_kv_compute" : Sub Block size (VMEM to VREG) of "block_kv" where compute is performed in forward pass. It must be factor or same as "block_kv" +- "block_kv" : Block sizes (HBM TO VMEM) to tile along KV sequence in forward pass +- "block_q_dkv" : Block sizes along Q sequence in backward pass with fused kernel to compute gradient of q, k , v. It must be factor or same as block_q +- "block_kv_dkv" : Block sizes along KV sequence in backward pass. It must be factor or same as block_kv +- "block_kv_dkv_compute" : Sub Block Sizes of block_kv_dkv, must be factor or same as "block_kv_dkv" +- "block_q_dq" : Block sizes along Q sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_q" +- "block_kv_dq" : Block sizes along KV to tiline on KV sequence in backward pass with unfused kernel to compute gradient of just q. it must be factor or same as "block_kv" +- "use_fused_bwd_kernel" : This means fused bwd kernel is used where DQ, DK, DV are computed in single kernel. It usually more perfomant but comes with slight HBM memory overhead. + +## Flowchart + +Maxdiffusion automatically adheres to this flowchart to ensure working, and there is a log that will inform you on the modifications that maxdiffusion makes to the specified block sizes. + +![alt text](attention_blocks_flowchart.png) + +> "tokamax_flash" uses the splash attention implementation in [tokamax-repo](https://github.com/openxla/tokamax/blob/main/tokamax/_src/ops/experimental/tpu/splash_attention/splash_attention_kernel.py) This kernel only supports fused backward pass where gradients for q,k,v are computed in a single kernel so "block_q_dq" and "block_kv_dq" are not used + +## How block sizes matter for perfomance and accuracy + +Block sizes key to saturating HBM bandwidth and ensuring maximum possible overlap of computation on cores with HBM use and VMEM to VREG. It is highly recommended to tune them. + +Block sizes also have an effect on the sequence length. Sequence length is multiple of resolution and number of frames (video), along with VAE scale down factors and patchifying ratios. This sequence length or shard of this sequence length needs to be multiple of the block sizes specified. Therefore maxdiffusion pads the sequence lengths to the nearest multiple of the block sizes. It is advisable to choose block sizes which are factor of sequence length, atleast for the Q block sizes. + +> In cross attention Image or Video tokens are attending to text tokens sequence length of text tokens is really small and potentially smaller than specified block size so KV block sizes are overwritten to safe values. + +> KV block sizes must be multiple of 128 since the size of register is 8x128 and in attention KV sequence dim lies on 128 for the multiplications as K is transposed. \ No newline at end of file diff --git a/docs/attention_blocks_flowchart.png b/docs/attention_blocks_flowchart.png new file mode 100644 index 00000000..bed28e63 Binary files /dev/null and b/docs/attention_blocks_flowchart.png differ diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index 07e9bd15..48c6ca44 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -501,17 +501,26 @@ def get_flash_block_sizes(config): """Create custom flash attention BlockSizes.""" flash_block_sizes = None if len(config.flash_block_sizes.keys()) > 0: - use_fused_bwd_kernel = config.flash_block_sizes.get("use_fused_bwd_kernel", False) + attention_is_tokamax = "tokamax" in config.attention + user_block_sizes:Dict[str, int] = config.flash_block_sizes + if attention_is_tokamax: + max_logging.log("Tokamax kernel specified, Note: Tokamax only supports fused backward kernel." + "Hence following flash block properties specified will be ignored:" + f"block_q: {user_block_sizes['block_q']}," + f"block_q_dq: {user_block_sizes.get('block_q_dq')}," + f"block_kv_dq: {user_block_sizes.get('block_kv_dq')}," + f"use_fused_bwd_kernel: {user_block_sizes.get('use_fused_bwd_kernel')}" + ) flash_block_sizes = splash_attention_kernel.BlockSizes( - block_q=config.flash_block_sizes["block_q"], - block_kv_compute=config.flash_block_sizes["block_kv_compute"], - block_kv=config.flash_block_sizes["block_kv"], - block_q_dkv=config.flash_block_sizes["block_q_dkv"], - block_kv_dkv=config.flash_block_sizes["block_kv_dkv"], - block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"], - block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"), - block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"), - use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"), + block_q=user_block_sizes.get("block_q_dkv", user_block_sizes["block_kv"]) if attention_is_tokamax else user_block_sizes["block_q"], + block_kv_compute=user_block_sizes["block_kv_compute"], + block_kv=user_block_sizes["block_kv"], + block_q_dkv=user_block_sizes["block_q_dkv"], + block_kv_dkv=user_block_sizes["block_kv_dkv"], + block_kv_dkv_compute=user_block_sizes["block_kv_dkv_compute"], + block_q_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_q_dq"), + block_kv_dq=None if attention_is_tokamax else value_or_none(user_block_sizes, "block_kv_dq"), + use_fused_bwd_kernel=True if attention_is_tokamax else value_or_none(user_block_sizes, "use_fused_bwd_kernel"), ) return flash_block_sizes diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 0dc4a9bf..6a578899 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -233,14 +233,15 @@ def _tpu_flash_attention( if flash_block_sizes and key.shape[1] == query.shape[1]: block_sizes = flash_block_sizes else: + block_size_q = flash_block_sizes.block_q if flash_block_sizes else q_max_block_size block_sizes = splash_attention_kernel.BlockSizes( - block_q=min(q_max_block_size, query.shape[2]), + block_q=block_size_q, block_kv_compute=min(kv_max_block_size, key.shape[2]), block_kv=min(kv_max_block_size, key.shape[2]), - block_q_dkv=min(q_max_block_size, query.shape[2]), + block_q_dkv=block_size_q, block_kv_dkv=min(kv_max_block_size, key.shape[2]), block_kv_dkv_compute=min(kv_max_block_size, query.shape[2]), - block_q_dq=None if attention_kernel == "tokamax_flash" else block_sizes.block_q_dq, + block_q_dq=None if attention_kernel == "tokamax_flash" else block_size_q, block_kv_dq=None if attention_kernel == "tokamax_flash" else min(kv_max_block_size, query.shape[2]), use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False, ) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index d40edfad..47a41234 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -179,69 +179,69 @@ def test_wan_block(self): dummy_encoder_hidden_states = jnp.ones((batch_size, 512, dim)) dummy_temb = jnp.ones((batch_size, 6, dim)) - with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - wan_block = WanTransformerBlock( - rngs=rngs, - dim=dim, - ffn_dim=ffn_dim, - num_heads=num_heads, - qk_norm=qk_norm, - cross_attn_norm=cross_attn_norm, - eps=eps, - attention="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) + + wan_block = WanTransformerBlock( + rngs=rngs, + dim=dim, + ffn_dim=ffn_dim, + num_heads=num_heads, + qk_norm=qk_norm, + cross_attn_norm=cross_attn_norm, + eps=eps, + attention="flash", + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + with mesh: dummy_output = wan_block(dummy_hidden_states, dummy_encoder_hidden_states, dummy_temb, dummy_rotary_emb) assert dummy_output.shape == dummy_hidden_states.shape def test_wan_attention(self): - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), - ], - unittest=True, - ) - config = pyconfig.config - - batch_size = 1 - channels = 16 - frames = 21 - height = 90 - width = 160 - hidden_states_shape = (batch_size, frames, height, width, channels) - dummy_hidden_states = jnp.ones(hidden_states_shape) - wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) - dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) - - key = jax.random.key(0) - rngs = nnx.Rngs(key) - devices_array = create_device_mesh(config) - - flash_block_sizes = get_flash_block_sizes(config) - - mesh = Mesh(devices_array, config.mesh_axes) - batch_size = 1 - query_dim = 5120 - with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - attention = FlaxWanAttention( - rngs=rngs, - query_dim=query_dim, - heads=40, - dim_head=128, - attention_kernel="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, + for attention_kernel in ["flash", "tokamax_flash"]: + pyconfig.initialize( + [ + None, + os.path.join(THIS_DIR, "..", "configs", "base_wan_14b.yml"), + f"attention={attention_kernel}" + ], + unittest=True ) - dummy_hidden_states_shape = (batch_size, 75600, query_dim) + config = pyconfig.config + batch_size = 1 + channels = 16 + frames = 21 + height = 90 + width = 160 + hidden_states_shape = (batch_size, frames, height, width, channels) + dummy_hidden_states = jnp.ones(hidden_states_shape) + wan_rot_embed = WanRotaryPosEmbed(attention_head_dim=128, patch_size=[1, 2, 2], max_seq_len=1024) + dummy_rotary_emb = wan_rot_embed(dummy_hidden_states) + + key = jax.random.key(0) + rngs = nnx.Rngs(key) + devices_array = create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + batch_size = 1 + query_dim = 5120 + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + flash_block_sizes = get_flash_block_sizes(config) + attention = FlaxWanAttention( + rngs=rngs, + query_dim=query_dim, + heads=40, + dim_head=128, + attention_kernel=attention_kernel, + mesh=mesh, + flash_block_sizes=flash_block_sizes, + ) + dummy_hidden_states_shape = (batch_size, 75600, query_dim) - dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) - dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) - dummy_output = attention( - hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb - ) - assert dummy_output.shape == dummy_hidden_states_shape + dummy_hidden_states = jnp.ones(dummy_hidden_states_shape) + dummy_encoder_hidden_states = jnp.ones(dummy_hidden_states_shape) + dummy_output = attention( + hidden_states=dummy_hidden_states, encoder_hidden_states=dummy_encoder_hidden_states, rotary_emb=dummy_rotary_emb + ) + assert dummy_output.shape == dummy_hidden_states_shape # dot product try: