Merged
7 changes: 3 additions & 4 deletions src/maxdiffusion/models/attention_flax.py
@@ -35,6 +35,7 @@
 from . import quantizations
 from .modeling_flax_utils import get_activation
 
+LOG2E = math.log2(math.e)
 
 Array = common_types.Array
 Mesh = common_types.Mesh
@@ -591,9 +592,7 @@ def wrap_ulysses_attention(query, key, value):
 heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile)
 
 if use_base2_exp:
-  query_scaled = query * 1.44269504
-else:
-  query_scaled = query
+  query = query * LOG2E
 
 query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq)
 key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv)
@@ -612,7 +611,7 @@ def wrap_ulysses_attention(query, key, value):
 )
 
 vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0))
-attention_output = vmapped_splash(query_scaled, key, value)
+attention_output = vmapped_splash(query, key, value)
 attention_output = jnp.swapaxes(attention_output, 2, 3)
 attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)
 else:
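Note on the change above (a hedged illustration, not part of the PR): scaling the query by LOG2E = log2(e) ≈ 1.44269504 lets an attention kernel compute its softmax with the base-2 exponential exp2 instead of exp, since e**x == 2**(x * log2(e)) and the query/key matmul is linear, so the constant can be folded into the query before the kernel runs. The helper names below are illustrative assumptions, not the PR's splash kernel.

import math
import jax.numpy as jnp

LOG2E = math.log2(math.e)  # 1.4426950408889634

def softmax_base_e(scores):
  # Reference softmax over the last axis, using the natural exponential.
  scores = scores - scores.max(axis=-1, keepdims=True)
  weights = jnp.exp(scores)
  return weights / weights.sum(axis=-1, keepdims=True)

def softmax_base2(prescaled_scores):
  # Same softmax, but the scores were pre-multiplied by log2(e),
  # so exp2 reproduces exp of the original scores.
  prescaled_scores = prescaled_scores - prescaled_scores.max(axis=-1, keepdims=True)
  weights = jnp.exp2(prescaled_scores)
  return weights / weights.sum(axis=-1, keepdims=True)

query = jnp.array([[0.3, -1.2, 2.0]])
key = jnp.eye(3)
# Folding the constant into the query is equivalent to scaling the logits.
assert jnp.allclose(softmax_base_e(query @ key.T),
                    softmax_base2((query * LOG2E) @ key.T), atol=1e-6)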
2 changes: 1 addition & 1 deletion src/maxdiffusion/pyconfig.py
@@ -214,7 +214,7 @@ def user_init(raw_keys):
 # Verify qkv is sharded across sequence.
 attention = raw_keys["attention"]
 uses_ring_attention = "ring" in attention
-uses_ulysses_attention = attention == "ulysses"
+uses_ulysses_attention = "ulysses" in attention
 uses_uniform_sequence_sharding = raw_keys["attention_sharding_uniform"]
 if uses_ring_attention or uses_ulysses_attention or uses_uniform_sequence_sharding:
   max_logging.log(
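A hedged sketch of what the relaxed pyconfig check above allows: the substring test triggers the sequence-sharding verification for any attention mode whose name contains "ulysses", not only the exact string "ulysses". The standalone helper and the mode names other than "ulysses" are illustrative assumptions, not taken from the PR.

def needs_sequence_sharding_check(attention: str, attention_sharding_uniform: bool) -> bool:
  # Mirrors the updated logic in user_init: any mode whose name contains "ring"
  # or "ulysses", or explicit uniform sharding, requires qkv sharded across sequence.
  uses_ring_attention = "ring" in attention
  uses_ulysses_attention = "ulysses" in attention
  return uses_ring_attention or uses_ulysses_attention or attention_sharding_uniform

assert needs_sequence_sharding_check("ulysses", False)
# A hypothetical combined mode name: matched by the substring check,
# but missed by the old equality check (attention == "ulysses").
assert needs_sequence_sharding_check("ring_ulysses", False)
assert not needs_sequence_sharding_check("flash", False)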