17 changes: 8 additions & 9 deletions src/maxdiffusion/models/attention_flax.py
@@ -895,13 +895,12 @@ def __call__(
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states

-    with self.conditional_named_scope("attn_qkv_proj"):
-      with self.conditional_named_scope("proj_query"):
-        query_proj = self.query(hidden_states)
-      with self.conditional_named_scope("proj_key"):
-        key_proj = self.key(encoder_hidden_states)
-      with self.conditional_named_scope("proj_value"):
-        value_proj = self.value(encoder_hidden_states)
+    with jax.named_scope("query_proj"):
+      query_proj = self.query(hidden_states)
+    with jax.named_scope("key_proj"):
+      key_proj = self.key(encoder_hidden_states)
+    with jax.named_scope("value_proj"):
+      value_proj = self.value(encoder_hidden_states)

     if self.qk_norm:
       with self.conditional_named_scope("attn_q_norm"):
@@ -921,13 +920,13 @@ def __call__(
     key_proj = checkpoint_name(key_proj, "key_proj")
     value_proj = checkpoint_name(value_proj, "value_proj")

-    with self.conditional_named_scope("attn_compute"):
+    with jax.named_scope("apply_attention"):
       attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")

-    with self.conditional_named_scope("attn_out_proj"):
+    with jax.named_scope("proj_attn"):
       hidden_states = self.proj_attn(attn_output)
     hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states
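Note on the API these hunks converge on: `jax.named_scope` attaches a name to every operation traced under it, and the names surface in the lowered HLO metadata and as groupings in profiler traces. A minimal, self-contained sketch of the behavior; the function and shapes are illustrative, not taken from this repo:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
  # Ops traced under the scope carry "apply_attention" in their metadata.
  with jax.named_scope("apply_attention"):
    scores = q @ k.T / q.shape[-1] ** 0.5
    return jax.nn.softmax(scores) @ v

q = k = v = jnp.ones((8, 16))
# The scope name shows up in each op's metadata in the lowered module.
print(jax.jit(attention).lower(q, k, v).as_text())
```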
50 changes: 22 additions & 28 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -142,8 +142,8 @@ def __call__(
   ):
     timestep = self.timesteps_proj(timestep)
     temb = self.time_embedder(timestep)
-
-    timestep_proj = self.time_proj(self.act_fn(temb))
+    with jax.named_scope("time_proj"):
+      timestep_proj = self.time_proj(self.act_fn(temb))

     encoder_hidden_states = self.text_embedder(encoder_hidden_states)
     if encoder_hidden_states_image is not None:
@@ -186,7 +186,8 @@ def __init__(
     )

   def __call__(self, x: jax.Array) -> jax.Array:
-    x = self.proj(x)
+    with jax.named_scope("gelu"):
+      x = self.proj(x)
     return nnx.gelu(x)

Expand Down Expand Up @@ -244,12 +245,11 @@ def conditional_named_scope(self, name: str):
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
with self.conditional_named_scope("mlp_up_proj_and_gelu"):
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
with self.conditional_named_scope("mlp_down_proj"):
return self.proj_out(hidden_states) # output is (4, 75600, 5120)
with jax.named_scope("proj_out"):
return self.proj_out(hidden_states) # output is (4, 75600, 5120)


class WanTransformerBlock(nnx.Module):
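For reference, the `conditional_named_scope` helper visible in the context lines gates annotation behind a config flag by returning either a real scope or a no-op context; the hunks above bypass it in favor of unconditional scopes at these call sites. A stripped-down sketch of the pattern, assuming a boolean `enable_jax_named_scopes` flag as in the repo:

```python
import contextlib
import jax

class Scoped:
  """Gate profiler annotations behind a configuration flag."""

  def __init__(self, enable_jax_named_scopes: bool = False):
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    # A real named scope when annotations are on; a do-nothing
    # context manager (zero overhead at trace time) otherwise.
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
```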
@@ -354,48 +354,42 @@ def __call__(
       rngs: nnx.Rngs = None,
   ):
     with self.conditional_named_scope("transformer_block"):
-      with self.conditional_named_scope("adaln"):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-        )
+      shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+          (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+      )
       hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
       hidden_states = checkpoint_name(hidden_states, "hidden_states")
       encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

       # 1. Self-attention
-      with self.conditional_named_scope("self_attn"):
-        with self.conditional_named_scope("self_attn_norm"):
+      with jax.named_scope("attn1"):
         norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
             hidden_states.dtype
         )
-        with self.conditional_named_scope("self_attn_attn"):
         attn_output = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_hidden_states,
             rotary_emb=rotary_emb,
             deterministic=deterministic,
             rngs=rngs,
         )
-        with self.conditional_named_scope("self_attn_residual"):
         hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

       # 2. Cross-attention
-      norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-      attn_output = self.attn2(
-          hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
-      )
-      hidden_states = hidden_states + attn_output
+      with jax.named_scope('attn2'):
+        norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+        attn_output = self.attn2(
+            hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+        )
+        hidden_states = hidden_states + attn_output

       # 3. Feed-forward
-      with self.conditional_named_scope("mlp"):
-        with self.conditional_named_scope("mlp_norm"):
+      with jax.named_scope("ffn"):
         norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
             hidden_states.dtype
         )
-        with self.conditional_named_scope("mlp_ffn"):
-          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-        with self.conditional_named_scope("mlp_residual"):
-          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+        ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
             hidden_states.dtype
         )
       return hidden_states
@@ -543,6 +537,7 @@ def conditional_named_scope(self, name: str):
     """Return a JAX named scope if enabled, otherwise a null context."""
     return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

+  @jax.named_scope('WanModel')
   def __call__(
       self,
       hidden_states: jax.Array,
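`jax.named_scope` also works as a decorator, which is what the added `@jax.named_scope('WanModel')` line relies on: every op traced inside the decorated `__call__` nests under the `WanModel` scope, and inner scopes stack beneath it (e.g. `WanModel/proj_out`). An illustrative sketch:

```python
import jax
import jax.numpy as jnp

@jax.named_scope("WanModel")
def forward(x):
  with jax.named_scope("proj_out"):
    return x * 2.0

# Op metadata in the lowered module carries the nested scope path.
print(jax.jit(forward).lower(jnp.ones(4)).as_text())
```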
Expand Down Expand Up @@ -609,9 +604,8 @@ def layer_forward(hidden_states):
hidden_states = rematted_layer_forward(hidden_states)

shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
with self.conditional_named_scope("output_norm"):
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
with self.conditional_named_scope("output_proj"):
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
with jax.named_scope("proj_out"):
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
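The renamed scopes pay off when capturing a profile: scopes such as `WanModel`, `attn1`, and `ffn` appear as groupings in the TensorBoard/Perfetto timeline. A minimal capture sketch; the log directory and toy function are illustrative:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
  with jax.named_scope("ffn"):
    return jnp.tanh(x @ x.T)

x = jnp.ones((64, 64))
# Run a few steps under the profiler, then open the log directory
# in TensorBoard to browse the named-scope hierarchy.
with jax.profiler.trace("/tmp/jax-trace"):
  for _ in range(3):
    step(x).block_until_ready()
```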