Support Stable Audio 3 model.#14010
No actionable comments were generated in the recent review. 🎉
📝 Walkthrough
This PR adds complete Stable Audio 3 model support to ComfyUI across the inference pipeline. It introduces a new exponential Fourier feature embedder alongside learnable variants, enhances the transformer architecture with RMSNorm, differential attention, and feat_scale modulation, adds memory tokens and shared global conditioning to ContinuousTransformer, implements a StableAudio3 model wrapper with local/global conditioning logic, detects SA3 configurations from checkpoint state_dicts, and integrates the SA3AudioVAE and T5‑Gemma text encoder for end-to-end support.
🚥 Pre-merge checks | ✅ 3 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
comfy/ldm/audio/dit.py (2)
428-438: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win
Repeat differential K on the head axis.
When `differential=True` and `dim_context != dim`, `k` is shaped `(B, 2, kv_h, M, D)`. The `repeat_interleave(..., dim=1)` at Line 431 duplicates the differential branch axis, so Line 435 no longer unpacks into `(k, k_diff)` and cross-attention blows up as soon as `h != kv_h`.
💡 Suggested fix
 if h != kv_h:
     # Repeat interleave kv_heads to match q_heads
     heads_per_kv_head = h // kv_h
-    k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
+    if self.differential:
+        k = k.repeat_interleave(heads_per_kv_head, dim=2)
+    else:
+        k = k.repeat_interleave(heads_per_kv_head, dim=1)
+    v = v.repeat_interleave(heads_per_kv_head, dim=1)
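The axis argument above can be checked with shape bookkeeping alone. The following is a minimal sketch, not the PR's code: `repeat_interleave_shape` is a hypothetical helper that only mirrors how `repeat_interleave` scales one dimension, and the sizes are made up for illustration.

```python
from typing import Tuple


def repeat_interleave_shape(shape: Tuple[int, ...], repeats: int, dim: int) -> Tuple[int, ...]:
    """Return the shape after repeating every slice along `dim` `repeats` times."""
    out = list(shape)
    out[dim] *= repeats
    return tuple(out)


# Hypothetical sizes: batch, kv heads, context length, head dim, query heads.
B, kv_h, M, D = 1, 2, 5, 8
h = 8
heads_per_kv_head = h // kv_h

# Differential K as described in the review: (B, 2, kv_h, M, D).
k_shape = (B, 2, kv_h, M, D)

wrong = repeat_interleave_shape(k_shape, heads_per_kv_head, dim=1)
right = repeat_interleave_shape(k_shape, heads_per_kv_head, dim=2)

print(wrong)  # (1, 8, 2, 5, 8): the branch axis grew, so it no longer splits into two
print(right)  # (1, 2, 8, 5, 8): two branches, each with h heads
```

Repeating on `dim=2` keeps the branch axis at size 2, which is what the later unpack into `(k, k_diff)` relies on.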
798-809: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win
Forward `local_add_cond` through the patched-block path too.
The normal path at Lines 807-809 threads `local_add_cond` into each `TransformerBlock`, but the `patches_replace["dit"]` wrapper drops it entirely. That makes SA3 local conditioning silently disappear whenever a custom node patches these blocks.
As per coding guidelines, "Core ML/diffusion engine. Focus on: Backward compatibility (breaking changes affect all custom nodes)".
💡 Suggested fix
 def block_wrap(args):
     out = {}
-    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
+    out["img"] = layer(
+        args["img"],
+        rotary_pos_emb=args["pe"],
+        global_cond=args["vec"],
+        local_add_cond=args.get("local_add_cond"),
+        context=args["txt"],
+        transformer_options=args["transformer_options"],
+    )
     return out
-out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
+out = blocks_replace[("double_block", i)](
+    {
+        "img": x,
+        "txt": context,
+        "vec": global_cond,
+        "pe": rotary_pos_emb,
+        "local_add_cond": local_add_cond,
+        "transformer_options": transformer_options,
+    },
+    {"original_block": block_wrap},
+)
 x = out["img"]
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@comfy/ldm/audio/dit.py` around lines 798 - 809, The patched-block path drops local_add_cond causing SA3 local conditioning to be lost; update the wrapper used for ("double_block", i) in blocks_replace so block_wrap and the call to blocks_replace[("double_block", i)] forward the local_add_cond parameter into layer (same as the else path): accept local_add_cond in block_wrap's args and pass it into layer, and when invoking blocks_replace[("double_block", i)] include the current local_add_cond in the context dict so x = out["img"] receives the conditioned output.
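The failure mode is easy to reproduce in miniature. This is a toy sketch under stated assumptions: `layer`, `custom_patch`, and `run_patched` are hypothetical stand-ins (plain floats instead of tensors), and only the `args`/`original_block` calling pattern is taken from the suggested fix.

```python
def layer(x, global_cond=None, local_add_cond=None):
    """Stand-in for a transformer block: adds local conditioning when it is given."""
    return x + (local_add_cond or 0)


def custom_patch(args, extra):
    """A pass-through patch, as a custom node might register it."""
    return extra["original_block"](args)


def run_patched(x, local_add_cond, forward_local):
    def block_wrap(args):
        kwargs = {"global_cond": args["vec"]}
        if forward_local:
            # .get() keeps older patches working if they rebuild args without the key.
            kwargs["local_add_cond"] = args.get("local_add_cond")
        return {"img": layer(args["img"], **kwargs)}

    args = {"img": x, "vec": None}
    if forward_local:
        args["local_add_cond"] = local_add_cond
    return custom_patch(args, {"original_block": block_wrap})["img"]


print(run_patched(1.0, 0.5, forward_local=False))  # 1.0, conditioning silently lost
print(run_patched(1.0, 0.5, forward_local=True))   # 1.5, conditioning applied
```

No error is raised in the first call, which is why the review flags the drop as silent.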
🧹 Nitpick comments (3)
comfy/model_base.py (2)
860-860: 💤 Low value
Consider documenting the magic constant.
The constant `10.7666` appears to be a sample-rate or hop-length conversion factor specific to Stable Audio 3 (similar to `21.53` used in StableAudio1). Adding a comment explaining what this represents would improve maintainability.
    # Conversion factor from latent temporal dimension to seconds (model-specific)
    seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 10.7666))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@comfy/model_base.py` at line 860, The line computing seconds_total uses a magic constant 10.7666 without explanation; update the code around seconds_total (referencing seconds_total, noise.shape[-1], and the literal 10.7666) to document or replace the magic number: add a brief comment stating that 10.7666 is the model-specific conversion factor from latent temporal dimension to seconds (Stable Audio 3), or extract it into a well-named constant (e.g., STABLE_AUDIO3_LATENT_TO_SECONDS) and use that constant in the calculation so the intent is clear and maintainable.
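One way the extraction could look, as a sketch only. The constant name follows the suggestion above; `default_seconds_total` is a hypothetical helper. The note that `44100 / 4096` matches the value is an arithmetic observation, not something the PR confirms.

```python
# Latent frames per second of audio: dividing a latent length by this gives seconds.
# Assumption (not confirmed by the PR): the value is numerically consistent with
# a 44100 Hz sample rate and a 4096x downsampling ratio, since 44100 / 4096 = 10.7666015625.
STABLE_AUDIO3_LATENT_TO_SECONDS = 10.7666


def default_seconds_total(latent_length: int) -> int:
    """Fallback for `seconds_total` when the caller does not provide one."""
    return int(latent_length / STABLE_AUDIO3_LATENT_TO_SECONDS)


print(default_seconds_total(1077))  # 100
```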
826-848: 💤 Low value
Consider addressing or documenting the TODO comments.
The TODO comments on lines 835 and 846 about scaling when shapes don't match suggest incomplete error handling. While the code works for matching shapes, consider either:
- Implementing the scaling logic (similar to line 262 in base class `concat_cond`)
- Documenting why scaling is not needed for this model
- Converting to a more specific comment if this is intentional
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@comfy/model_base.py` around lines 826 - 848, The concat_cond method currently has TODOs about scaling when image/mask spatial shapes don't match noise; update concat_cond (and its use of process_latent_in and utils.resize_to_batch_size) to either implement explicit scaling logic mirroring the base class concat_cond behavior (ensure image and mask are resized to noise.shape using the same resizing/resampling rules as base class) or replace the TODOs with a precise comment/docstring explaining why scaling is unnecessary for this model (including expected tensor shapes and invariants). Ensure the mask handling (mean reduction and invert) remains consistent and mention utils.resize_to_batch_size as the resizing call to use so future readers know where resizing happens.
comfy/model_detection.py (1)
119-157: 💤 Low value
Consider adding a comment explaining the attention detection logic.
The logic for detecting attention configuration (lines 131-137) checks for different normalization layer parameters:
- `q_norm.weight` indicates LayerNorm
- `q_norm.gamma` indicates RMSNorm
While the logic is correct (these are mutually exclusive keys), adding a brief comment would improve maintainability and clarify why these checks are separate.
    # Detect attention normalization type: LayerNorm uses .weight, RMSNorm uses .gamma
    if '{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) in state_dict:
        unet_config["attn_kwargs"] = {"qk_norm": "ln", "feat_scale": True}
    rms_norm = state_dict.get('{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix), None)
    if rms_norm is not None:
        unet_config["attn_kwargs"] = {"qk_norm": "rms", "differential": differential}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@comfy/model_detection.py` around lines 119 - 157, Add a brief clarifying comment above the attention normalization detection block in model_detection.py explaining that presence of '{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) indicates LayerNorm (uses .weight) while presence of '{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix) indicates RMSNorm (uses .gamma), and that these are mutually exclusive; place this comment just before the if checking q_norm.weight and the rms_norm = state_dict.get(...) line so readers of the unet_config population (e.g., where attn_kwargs, norm_type, and num_heads are set) understand why the separate branches exist.
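The detection rule can be exercised on plain dictionaries. This is a minimal sketch: `detect_attn_kwargs` is a hypothetical function, the key names and returned dictionaries are copied from the snippet above, and the early returns assume the two keys never coexist (the PR uses two separate `if` statements instead).

```python
def detect_attn_kwargs(state_dict, key_prefix="", differential=False):
    """Infer attention kwargs from which q_norm parameter the checkpoint contains."""
    base = "{}transformer.layers.0.self_attn.q_norm.".format(key_prefix)
    # LayerNorm stores its scale as `.weight`.
    if base + "weight" in state_dict:
        return {"qk_norm": "ln", "feat_scale": True}
    # RMSNorm stores its scale as `.gamma`.
    if state_dict.get(base + "gamma") is not None:
        return {"qk_norm": "rms", "differential": differential}
    return None


ln_ckpt = {"model.transformer.layers.0.self_attn.q_norm.weight": [1.0]}
rms_ckpt = {"model.transformer.layers.0.self_attn.q_norm.gamma": [1.0]}

print(detect_attn_kwargs(ln_ckpt, "model."))
print(detect_attn_kwargs(rms_ckpt, "model.", differential=True))
```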
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@comfy/ldm/audio/dit.py`:
- Around line 779-782: When prepending memory tokens in the block using
self.num_memory_tokens and memory_tokens (from
comfy.ops.cast_to_input(self.memory_tokens, x)), also extend the attention mask
to account for those new tokens: if mask is not None, create a memory_mask of
valid bits (ones/True) shaped (batch, self.num_memory_tokens) on the same
device/dtype as mask and concat it to the front of mask along dim=1; if mask is
None, create a new mask of ones with shape (batch, seq_len +
self.num_memory_tokens) or at least prepend a memory_mask and use the existing
mask afterwards. Ensure the batch sizing uses the same batch variable used for x
so the attention masking path receives matching shapes after x =
torch.cat((memory_tokens, x), dim=1).
In `@comfy/ldm/audio/embedders.py`:
- Around line 42-51: The forward method in the ExpoFourier embedder currently
casts the sinusoidal features back to the original input dtype (in_dtype), which
quantizes integer timesteps and destroys the signal; modify the return so the
embedding stays in a floating dtype instead of converting to in_dtype —
specifically, in the forward function (variables: in_dtype, t, args) remove the
final .to(in_dtype) cast and return the concatenated cos/sin tensor as a
floating type (e.g., keep t.float() / torch.float32) so callers receive
floating-point features.
---
Outside diff comments:
In `@comfy/ldm/audio/dit.py`:
- Around line 798-809: The patched-block path drops local_add_cond causing SA3
local conditioning to be lost; update the wrapper used for ("double_block", i)
in blocks_replace so block_wrap and the call to blocks_replace[("double_block",
i)] forward the local_add_cond parameter into layer (same as the else path):
accept local_add_cond in block_wrap's args and pass it into layer, and when
invoking blocks_replace[("double_block", i)] include the current local_add_cond
in the context dict so x = out["img"] receives the conditioned output.
---
Nitpick comments:
In `@comfy/model_base.py`:
- Line 860: The line computing seconds_total uses a magic constant 10.7666
without explanation; update the code around seconds_total (referencing
seconds_total, noise.shape[-1], and the literal 10.7666) to document or replace
the magic number: add a brief comment stating that 10.7666 is the model-specific
conversion factor from latent temporal dimension to seconds (Stable Audio 3), or
extract it into a well-named constant (e.g., STABLE_AUDIO3_LATENT_TO_SECONDS)
and use that constant in the calculation so the intent is clear and
maintainable.
- Around line 826-848: The concat_cond method currently has TODOs about scaling
when image/mask spatial shapes don't match noise; update concat_cond (and its
use of process_latent_in and utils.resize_to_batch_size) to either implement
explicit scaling logic mirroring the base class concat_cond behavior (ensure
image and mask are resized to noise.shape using the same resizing/resampling
rules as base class) or replace the TODOs with a precise comment/docstring
explaining why scaling is unnecessary for this model (including expected tensor
shapes and invariants). Ensure the mask handling (mean reduction and invert)
remains consistent and mention utils.resize_to_batch_size as the resizing call
to use so future readers know where resizing happens.
In `@comfy/model_detection.py`:
- Around line 119-157: Add a brief clarifying comment above the attention
normalization detection block in model_detection.py explaining that presence of
'{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) indicates
LayerNorm (uses .weight) while presence of
'{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix) indicates
RMSNorm (uses .gamma), and that these are mutually exclusive; place this comment
just before the if checking q_norm.weight and the rms_norm = state_dict.get(...)
line so readers of the unet_config population (e.g., where attn_kwargs,
norm_type, and num_heads are set) understand why the separate branches exist.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 8ef9c209-172e-4b4e-bb8d-9ae9b9ef78b3
📒 Files selected for processing (7)
comfy/latent_formats.py
comfy/ldm/audio/dit.py
comfy/ldm/audio/embedders.py
comfy/model_base.py
comfy/model_detection.py
comfy/sd.py
comfy/supported_models.py
    if self.num_memory_tokens > 0:
        memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1)
        x = torch.cat((memory_tokens, x), dim=1)
Extend mask when memory tokens are prepended.
This branch lengthens x by num_memory_tokens, but mask keeps the pre-memory length. Any masked call will then hit a shape mismatch in the attention output masking path, and the memory tokens never get valid mask bits.
💡 Suggested fix
 if self.num_memory_tokens > 0:
     memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1)
     x = torch.cat((memory_tokens, x), dim=1)
+    if mask is not None:
+        memory_mask = torch.ones((batch, self.num_memory_tokens), device=device, dtype=torch.bool)
+        mask = torch.cat((memory_mask, mask), dim=-1)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@comfy/ldm/audio/dit.py` around lines 779 - 782, When prepending memory tokens
in the block using self.num_memory_tokens and memory_tokens (from
comfy.ops.cast_to_input(self.memory_tokens, x)), also extend the attention mask
to account for those new tokens: if mask is not None, create a memory_mask of
valid bits (ones/True) shaped (batch, self.num_memory_tokens) on the same
device/dtype as mask and concat it to the front of mask along dim=1; if mask is
None, create a new mask of ones with shape (batch, seq_len +
self.num_memory_tokens) or at least prepend a memory_mask and use the existing
mask afterwards. Ensure the batch sizing uses the same batch variable used for x
so the attention masking path receives matching shapes after x =
torch.cat((memory_tokens, x), dim=1).
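The shape contract can be sketched with nested lists standing in for tensors. `prepend_memory_mask` is a hypothetical helper, not part of the PR; it assumes a boolean mask of shape `(batch, seq_len)` where `True` means "attend".

```python
def prepend_memory_mask(mask, num_memory_tokens):
    """Prepend always-valid entries so the mask matches the lengthened sequence."""
    if mask is None or num_memory_tokens == 0:
        return mask
    return [[True] * num_memory_tokens + list(row) for row in mask]


num_memory_tokens = 2
mask = [
    [True, True, False],   # sample 0: last position is padding
    [True, False, False],  # sample 1: last two positions are padding
]

extended = prepend_memory_mask(mask, num_memory_tokens)
print(extended[0])  # [True, True, True, True, False]
print(extended[1])  # [True, True, True, False, False]
```

Every row grows by `num_memory_tokens`, matching the sequence length after the memory tokens are concatenated in front of `x`.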
    def forward(self, t):
        in_dtype = t.dtype
        t = t.float()
        if t.dim() == 1:
            t = t.unsqueeze(-1)
        half_dim = self.dim // 2
        ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
        freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
        args = t * freqs * 2 * math.pi
        return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype)
Keep expo Fourier features in floating point.
Line 51 casts the sinusoidal features back to t.dtype. If the caller passes integer timesteps, this quantizes the embedding before to_timestep_embed sees it and collapses most of the signal. Keep the output in a floating dtype instead of round-tripping through the input dtype.
💡 Suggested fix
 def forward(self, t):
-    in_dtype = t.dtype
+    out_dtype = t.dtype if t.is_floating_point() else torch.float32
     t = t.float()
     if t.dim() == 1:
         t = t.unsqueeze(-1)
     half_dim = self.dim // 2
     ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
     freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
     args = t * freqs * 2 * math.pi
-    return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype)
+    return torch.cat([args.cos(), args.sin()], dim=-1).to(out_dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@comfy/ldm/audio/embedders.py` around lines 42 - 51, The forward method in the
ExpoFourier embedder currently casts the sinusoidal features back to the
original input dtype (in_dtype), which quantizes integer timesteps and destroys
the signal; modify the return so the embedding stays in a floating dtype instead
of converting to in_dtype — specifically, in the forward function (variables:
in_dtype, t, args) remove the final .to(in_dtype) cast and return the
concatenated cos/sin tensor as a floating type (e.g., keep t.float() /
torch.float32) so callers receive floating-point features.
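A pure-Python sketch shows how much is lost by the integer round-trip. `expo_fourier` re-implements the frequency ramp from the quoted `forward` with the `math` module; `dim`, `min_freq`, and `max_freq` are illustrative values, not the model's configuration, and `int()` stands in for a float-to-integer tensor cast (both truncate toward zero).

```python
import math


def expo_fourier(t, dim=8, min_freq=1.0, max_freq=100.0):
    """Exponentially spaced Fourier features for a scalar timestep."""
    half_dim = dim // 2
    ramp = [i / (half_dim - 1) for i in range(half_dim)]  # like linspace(0, 1, half_dim)
    log_min, log_max = math.log(min_freq), math.log(max_freq)
    freqs = [math.exp(r * (log_max - log_min) + log_min) for r in ramp]
    args = [t * f * 2 * math.pi for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


features = expo_fourier(3)               # integer timestep, float features
quantized = [int(v) for v in features]   # what casting back to an integer dtype does

print(len(set(features)), "distinct float values")
print(sorted(set(quantized)), "distinct values after the integer cast")
```

Because cosine and sine stay within [-1, 1], truncation can only leave -1, 0, or 1, so nearly all of the embedding's resolution is discarded.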
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@comfy/ldm/audio/vae_sa3.py`:
- Around line 521-540: The encode/decode pair (_pad, encode, decode) currently
pads inputs but doesn't preserve the original length, so decode(encode(x))
returns the padded length; fix by returning/propagating the original length or
pad count from encode (e.g., have encode return (latent, orig_len) or attach
pad_len to the returned tuple) and modify decode to accept that metadata and
crop the reconstructed waveform to orig_len (use patch_size to compute pad_len
if needed) instead of storing any state on the module; update all call sites to
pass through the length/pad metadata accordingly (also apply same change to the
similar methods around lines 590-600).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 9a1802d4-8600-4050-94b3-3cafda4589f1
📒 Files selected for processing (2)
comfy/ldm/audio/vae_sa3.py
comfy/text_encoders/sa3.py
    def _pad(self, x):
        pad_len = (self.patch_size - x.shape[-1] % self.patch_size) % self.patch_size
        if pad_len > 0:
            x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)
        return x

    def encode(self, x):
        x = self._pad(x)
        B, C, T = x.shape
        h = self.patch_size
        L = T // h
        # b c (l h) -> b (c h) l
        return x.reshape(B, C, L, h).permute(0, 1, 3, 2).reshape(B, C * h, L)

    def decode(self, x):
        B, Ch, L = x.shape
        h = self.patch_size
        C = Ch // h
        # b (c h) l -> b c (l h)
        return x.reshape(B, C, h, L).permute(0, 1, 3, 2).reshape(B, C, L * h)
Preserve the original sample length across pretransform round-trips.
_pad() extends any non-256-aligned waveform, but the encode/decode path never carries that pre-pad length forward, so decode(encode(x)) returns ceil(T / 256) * 256 samples instead of T. That will shift timing for most real inputs. Please thread the original length or pad count alongside the latent and crop in decode() rather than storing it on module state.
Also applies to: 590-600
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@comfy/ldm/audio/vae_sa3.py` around lines 521 - 540, The encode/decode pair
(_pad, encode, decode) currently pads inputs but doesn't preserve the original
length, so decode(encode(x)) returns the padded length; fix by
returning/propagating the original length or pad count from encode (e.g., have
encode return (latent, orig_len) or attach pad_len to the returned tuple) and
modify decode to accept that metadata and crop the reconstructed waveform to
orig_len (use patch_size to compute pad_len if needed) instead of storing any
state on the module; update all call sites to pass through the length/pad
metadata accordingly (also apply same change to the similar methods around lines
590-600).
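The length bookkeeping can be sketched on a single-channel list of samples. This is an illustration under stated assumptions: `encode` and `decode` here are simplified stand-ins that group samples into patches rather than reproducing the `b (c h) l` layout, `PATCH_SIZE = 256` is taken from the "non-256-aligned" wording above, and returning `orig_len` alongside the latent is one possible way to thread the metadata.

```python
PATCH_SIZE = 256  # assumed from the review text


def pad_len(num_samples, patch_size=PATCH_SIZE):
    """Zeros needed to reach the next multiple of patch_size (0 if already aligned)."""
    return (patch_size - num_samples % patch_size) % patch_size


def encode(samples, patch_size=PATCH_SIZE):
    """Pad, split into patches, and return the original length alongside them."""
    orig_len = len(samples)
    padded = list(samples) + [0.0] * pad_len(orig_len, patch_size)
    patches = [padded[i:i + patch_size] for i in range(0, len(padded), patch_size)]
    return patches, orig_len


def decode(patches, orig_len):
    """Flatten the patches and crop the padding back off."""
    flat = [s for patch in patches for s in patch]
    return flat[:orig_len]


wave = [0.1] * 1000
patches, orig_len = encode(wave)
print(len(patches) * PATCH_SIZE)        # 1024 samples after padding
print(len(decode(patches, orig_len)))   # 1000 samples after cropping
```

Without the crop, the round-trip would return 1024 samples for a 1000-sample input, which is the timing shift the comment describes.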
    attn = torch.matmul(xq * self.scale, xk.transpose(-2, -1))
    attn = torch.tanh(attn / self.softcap) * self.softcap
    if attention_mask is not None:
        attn = attn + attention_mask
    attn = torch.nn.functional.softmax(attn.float(), dim=-1).to(xq.dtype)
    out = torch.matmul(attn, xv).transpose(1, 2).reshape(B, S, self.inner_size)
Sliding-window attention is still O(S²) and the tokenizer makes it easy to hit.
This path still materializes both attn and sw_mask as dense S x S tensors, so long prompts will run out of memory before the window helps. Line 190 then makes that user-reachable by leaving token length effectively unbounded. Please clamp accepted sequence length until there is a true windowed/chunked attention path here.
As per coding guidelines, comfy/**: Core ML/diffusion engine. Focus on: Memory management and GPU resource handling.
Also applies to: 77-84, 184-192
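A back-of-the-envelope estimate makes the quadratic growth concrete. This sketch assumes float32 scores (the quoted code calls `attn.float()` before the softmax); the head count, `MAX_TOKENS`, and both helper names are hypothetical, not values from the PR.

```python
def dense_attn_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Memory for one dense (batch, heads, S, S) attention score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


MAX_TOKENS = 512  # hypothetical clamp, to be tuned against real memory budgets


def clamp_tokens(tokens, max_len=MAX_TOKENS):
    """Truncate a token list so the dense attention matrix stays bounded."""
    return tokens[:max_len]


for seq_len in (512, 8192):
    mib = dense_attn_bytes(batch=1, heads=8, seq_len=seq_len) / 2**20
    print(seq_len, "tokens ->", mib, "MiB")  # 512 -> 8.0 MiB, 8192 -> 2048.0 MiB

print(len(clamp_tokens(list(range(10_000)))))  # 512
```

A 16x longer prompt costs 256x the memory per layer, regardless of the sliding window, as long as the scores are materialized densely.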
No description provided.