diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 583695fa..7b707bff 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -287,11 +287,11 @@ def _tpu_flash_attention( ) -> jax.Array: """TPU Flash Attention""" - block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) num_context_shards = mesh.shape["context"] query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards) key, _ = _reshape_data_for_flash(key, heads, num_context_shards) value, _ = _reshape_data_for_flash(value, heads, num_context_shards) + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, attention_kernel) q_axis_names = nn.logical_to_mesh_axes(axis_names_q) kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) diff --git a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py index 485d500e..569f194a 100644 --- a/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py +++ b/src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py @@ -127,6 +127,8 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict): ltx2_config["dtype"] = config.activations_dtype ltx2_config["weights_dtype"] = config.weights_dtype ltx2_config["attention_kernel"] = config.attention + ltx2_config["a2v_attention_kernel"] = getattr(config, "a2v_attention_kernel", "flash") + ltx2_config["v2a_attention_kernel"] = getattr(config, "v2a_attention_kernel", "dot_product") ltx2_config["precision"] = get_precision(config) ltx2_config["flash_block_sizes"] = get_flash_block_sizes(config) ltx2_config["flash_min_seq_length"] = getattr(config, "flash_min_seq_length", 4096)