Support Stable Audio 3 model. #14010

Merged
comfyanonymous merged 3 commits into master from temp_pr1
May 20, 2026
Conversation

@comfyanonymous
Member

No description provided.

@coderabbitai

coderabbitai Bot commented May 20, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: cdbb6142-5e15-4091-8e6d-c97e94b5ab36

📥 Commits

Reviewing files that changed from the base of the PR and between 1ba7da0 and c70d4a3.

📒 Files selected for processing (1)
  • comfy/ldm/audio/vae_sa3.py

📝 Walkthrough

Walkthrough

This PR adds complete Stable Audio 3 model support to ComfyUI across the inference pipeline. It introduces a new exponential Fourier feature embedder alongside learnable variants, enhances the transformer architecture with RMSNorm, differential attention, and feat_scale modulation, adds memory tokens and shared global conditioning to ContinuousTransformer, implements a StableAudio3 model wrapper with local/global conditioning logic, detects SA3 configurations from checkpoint state_dicts, and integrates the SA3AudioVAE and T5‑Gemma text encoder for end-to-end support.

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 2.08% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ❓ Inconclusive | No pull request description was provided by the author, making it impossible to assess whether a description exists or relate it to the changeset. | Consider adding a description explaining the Stable Audio 3 support implementation, key changes, and any relevant context for reviewers. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title 'Support Stable Audio 3 model' clearly and concisely summarizes the main objective of the changeset, which adds comprehensive support for the Stable Audio 3 model across multiple files and components. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
comfy/ldm/audio/dit.py (2)

428-438: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Repeat differential K on the head axis.

When differential=True and dim_context != dim, k is shaped (B, 2, kv_h, M, D). The repeat_interleave(..., dim=1) at Line 431 duplicates the differential branch axis, so Line 435 no longer unpacks into (k, k_diff) and cross-attention blows up as soon as h != kv_h.

💡 Suggested fix
         if h != kv_h:
             # Repeat interleave kv_heads to match q_heads
             heads_per_kv_head = h // kv_h
-            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
+            if self.differential:
+                k = k.repeat_interleave(heads_per_kv_head, dim=2)
+            else:
+                k = k.repeat_interleave(heads_per_kv_head, dim=1)
+            v = v.repeat_interleave(heads_per_kv_head, dim=1)
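
The effect of the axis choice can be checked with shape arithmetic alone. The following is a framework-free sketch, not ComfyUI code: `repeat_interleave_shape` is a hypothetical helper that only mirrors how `Tensor.repeat_interleave(n, dim=d)` with an integer `n` multiplies the size of axis `d`, and the tensor sizes are made-up illustrative values.

```python
def repeat_interleave_shape(shape, repeats, dim):
    """Shape after tensor.repeat_interleave(repeats, dim=dim) with an int repeat count."""
    out = list(shape)
    out[dim] *= repeats
    return tuple(out)

# Illustrative sizes (assumptions): batch, kv heads, query heads, context length, head dim.
B, kv_h, h, M, D = 2, 4, 8, 16, 64
heads_per_kv_head = h // kv_h

# Differential K layout described in the review: (B, 2, kv_h, M, D).
k_diff_shape = (B, 2, kv_h, M, D)

# dim=1 grows the branch axis from 2 to 4, so it no longer unpacks into (k, k_diff).
wrong = repeat_interleave_shape(k_diff_shape, heads_per_kv_head, dim=1)
# dim=2 expands the kv heads to match the query heads and keeps the two branches.
right = repeat_interleave_shape(k_diff_shape, heads_per_kv_head, dim=2)

print(wrong)  # (2, 4, 4, 16, 64)
print(right)  # (2, 2, 8, 16, 64)
```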

798-809: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Forward local_add_cond through the patched-block path too.

The normal path at Lines 807-809 threads local_add_cond into each TransformerBlock, but the patches_replace["dit"] wrapper drops it entirely. That makes SA3 local conditioning silently disappear whenever a custom node patches these blocks.

💡 Suggested fix
                 def block_wrap(args):
                     out = {}
-                    out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"], transformer_options=args["transformer_options"])
+                    out["img"] = layer(
+                        args["img"],
+                        rotary_pos_emb=args["pe"],
+                        global_cond=args["vec"],
+                        local_add_cond=args.get("local_add_cond"),
+                        context=args["txt"],
+                        transformer_options=args["transformer_options"],
+                    )
                     return out

-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)](
+                    {
+                        "img": x,
+                        "txt": context,
+                        "vec": global_cond,
+                        "pe": rotary_pos_emb,
+                        "local_add_cond": local_add_cond,
+                        "transformer_options": transformer_options,
+                    },
+                    {"original_block": block_wrap},
+                )
                 x = out["img"]
As per coding guidelines, "Core ML/diffusion engine. Focus on: Backward compatibility (breaking changes affect all custom nodes)".
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/ldm/audio/dit.py` around lines 798 - 809, The patched-block path drops
local_add_cond causing SA3 local conditioning to be lost; update the wrapper
used for ("double_block", i) in blocks_replace so block_wrap and the call to
blocks_replace[("double_block", i)] forward the local_add_cond parameter into
layer (same as the else path): accept local_add_cond in block_wrap's args and
pass it into layer, and when invoking blocks_replace[("double_block", i)]
include the current local_add_cond in the context dict so x = out["img"]
receives the conditioned output.
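
A minimal sketch of the forwarding pattern, using stand-ins rather than the real classes: `layer` here is a toy function that takes the place of a `TransformerBlock`, and `passthrough_patch` is a hypothetical custom-node patch. The point it illustrates is that reading the new key with `.get()` lets patches that build their own argument dict without `local_add_cond` keep working.

```python
def layer(x, rotary_pos_emb=None, global_cond=None, local_add_cond=None,
          context=None, transformer_options=None):
    """Toy stand-in for a TransformerBlock: adds local conditioning when given."""
    return x + (local_add_cond if local_add_cond is not None else 0)

def block_wrap(args):
    # .get() tolerates older patches whose args dict lacks the new key.
    return {"img": layer(
        args["img"],
        rotary_pos_emb=args["pe"],
        global_cond=args["vec"],
        local_add_cond=args.get("local_add_cond"),
        context=args["txt"],
        transformer_options=args["transformer_options"],
    )}

def passthrough_patch(args, extra):
    """Hypothetical custom-node patch that simply calls the original block."""
    return extra["original_block"](args)

args = {"img": 1.0, "txt": None, "vec": None, "pe": None,
        "local_add_cond": 0.5, "transformer_options": {}}
out = passthrough_patch(args, {"original_block": block_wrap})

# A patch written before the new key existed passes a dict without it.
legacy_args = {k: v for k, v in args.items() if k != "local_add_cond"}
out_legacy = passthrough_patch(legacy_args, {"original_block": block_wrap})

print(out["img"], out_legacy["img"])  # 1.5 1.0
```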
🧹 Nitpick comments (3)
comfy/model_base.py (2)

860-860: 💤 Low value

Consider documenting the magic constant.

The constant 10.7666 appears to be a sample-rate or hop-length conversion factor specific to Stable Audio 3 (similar to 21.53 used in StableAudio1). Adding a comment explaining what this represents would improve maintainability.

# Conversion factor from latent temporal dimension to seconds (model-specific)
seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 10.7666))
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/model_base.py` at line 860, The line computing seconds_total uses a
magic constant 10.7666 without explanation; update the code around seconds_total
(referencing seconds_total, noise.shape[-1], and the literal 10.7666) to
document or replace the magic number: add a brief comment stating that 10.7666
is the model-specific conversion factor from latent temporal dimension to
seconds (Stable Audio 3), or extract it into a well-named constant (e.g.,
STABLE_AUDIO3_LATENT_TO_SECONDS) and use that constant in the calculation so the
intent is clear and maintainable.
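
One way the named-constant variant could look, as a sketch under stated assumptions. The PR does not say where 10.7666 comes from; 44100 / 4096 happens to reproduce it to four decimals, so the sample rate and samples-per-latent-frame values below are guesses used only for illustration, and `default_seconds_total` is a hypothetical helper name.

```python
# Assumption: 44.1 kHz audio and 4096 samples per latent frame. Neither number
# is stated in the PR; they are one pair that reproduces the literal 10.7666.
SAMPLE_RATE = 44100
SAMPLES_PER_LATENT_FRAME = 4096

# Latent frames per second of audio (constant name taken from the review suggestion).
STABLE_AUDIO3_LATENT_TO_SECONDS = SAMPLE_RATE / SAMPLES_PER_LATENT_FRAME

def default_seconds_total(latent_length):
    """Latent frames -> whole seconds, mirroring int(noise.shape[-1] / 10.7666)."""
    return int(latent_length / STABLE_AUDIO3_LATENT_TO_SECONDS)

print(round(STABLE_AUDIO3_LATENT_TO_SECONDS, 4))  # 10.7666
print(default_seconds_total(1024))                # 95
```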

826-848: 💤 Low value

Consider addressing or documenting the TODO comments.

The TODO comments on lines 835 and 846 about scaling when shapes don't match suggest incomplete error handling. While the code works for matching shapes, consider either:

  • Implementing the scaling logic (similar to line 262 in base class concat_cond)
  • Documenting why scaling is not needed for this model
  • Converting to a more specific comment if this is intentional
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/model_base.py` around lines 826 - 848, The concat_cond method currently
has TODOs about scaling when image/mask spatial shapes don't match noise; update
concat_cond (and its use of process_latent_in and utils.resize_to_batch_size) to
either implement explicit scaling logic mirroring the base class concat_cond
behavior (ensure image and mask are resized to noise.shape using the same
resizing/resampling rules as base class) or replace the TODOs with a precise
comment/docstring explaining why scaling is unnecessary for this model
(including expected tensor shapes and invariants). Ensure the mask handling
(mean reduction and invert) remains consistent and mention
utils.resize_to_batch_size as the resizing call to use so future readers know
where resizing happens.
comfy/model_detection.py (1)

119-157: 💤 Low value

Consider adding a comment explaining the attention detection logic.

The logic for detecting attention configuration (lines 131-137) checks for different normalization layer parameters:

  • q_norm.weight indicates LayerNorm
  • q_norm.gamma indicates RMSNorm

While the logic is correct (these are mutually exclusive keys), adding a brief comment would improve maintainability and clarify why these checks are separate.

# Detect attention normalization type: LayerNorm uses .weight, RMSNorm uses .gamma
if '{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) in state_dict:
    unet_config["attn_kwargs"] = {"qk_norm": "ln", "feat_scale": True}
rms_norm = state_dict.get('{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix), None)
if rms_norm is not None:
    unet_config["attn_kwargs"] = {"qk_norm": "rms", "differential": differential}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/model_detection.py` around lines 119 - 157, Add a brief clarifying
comment above the attention normalization detection block in model_detection.py
explaining that presence of
'{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) indicates
LayerNorm (uses .weight) while presence of
'{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix) indicates
RMSNorm (uses .gamma), and that these are mutually exclusive; place this comment
just before the if checking q_norm.weight and the rms_norm = state_dict.get(...)
line so readers of the unet_config population (e.g., where attn_kwargs,
norm_type, and num_heads are set) understand why the separate branches exist.
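
The detection rule can be exercised on plain dictionaries. This is a sketch, not the code in `model_detection.py`: `detect_attn_kwargs` is a hypothetical helper, `differential` is taken as a parameter because its origin is not shown in the snippet, and the early return is equivalent to the two separate checks only under the stated assumption that the keys are mutually exclusive.

```python
def detect_attn_kwargs(state_dict, key_prefix="", differential=False):
    """Infer attention kwargs from which q_norm parameter the checkpoint contains."""
    base = "{}transformer.layers.0.self_attn.q_norm.".format(key_prefix)
    # LayerNorm stores its scale as .weight
    if base + "weight" in state_dict:
        return {"qk_norm": "ln", "feat_scale": True}
    # RMSNorm stores its scale as .gamma
    if base + "gamma" in state_dict:
        return {"qk_norm": "rms", "differential": differential}
    return None

# Toy state dicts; only key presence matters for detection.
ln_sd = {"model.transformer.layers.0.self_attn.q_norm.weight": None}
rms_sd = {"model.transformer.layers.0.self_attn.q_norm.gamma": None}

ln_cfg = detect_attn_kwargs(ln_sd, "model.")
rms_cfg = detect_attn_kwargs(rms_sd, "model.", differential=True)
none_cfg = detect_attn_kwargs({}, "model.")

print(ln_cfg)   # {'qk_norm': 'ln', 'feat_scale': True}
print(rms_cfg)  # {'qk_norm': 'rms', 'differential': True}
```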
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@comfy/ldm/audio/dit.py`:
- Around line 779-782: When prepending memory tokens in the block using
self.num_memory_tokens and memory_tokens (from
comfy.ops.cast_to_input(self.memory_tokens, x)), also extend the attention mask
to account for those new tokens: if mask is not None, create a memory_mask of
valid bits (ones/True) shaped (batch, self.num_memory_tokens) on the same
device/dtype as mask and concat it to the front of mask along dim=1; if mask is
None, create a new mask of ones with shape (batch, seq_len +
self.num_memory_tokens) or at least prepend a memory_mask and use the existing
mask afterwards. Ensure the batch sizing uses the same batch variable used for x
so the attention masking path receives matching shapes after x =
torch.cat((memory_tokens, x), dim=1).

In `@comfy/ldm/audio/embedders.py`:
- Around line 42-51: The forward method in the ExpoFourier embedder currently
casts the sinusoidal features back to the original input dtype (in_dtype), which
quantizes integer timesteps and destroys the signal; modify the return so the
embedding stays in a floating dtype instead of converting to in_dtype —
specifically, in the forward function (variables: in_dtype, t, args) remove the
final .to(in_dtype) cast and return the concatenated cos/sin tensor as a
floating type (e.g., keep t.float() / torch.float32) so callers receive
floating-point features.

---

Outside diff comments:
In `@comfy/ldm/audio/dit.py`:
- Around line 798-809: The patched-block path drops local_add_cond causing SA3
local conditioning to be lost; update the wrapper used for ("double_block", i)
in blocks_replace so block_wrap and the call to blocks_replace[("double_block",
i)] forward the local_add_cond parameter into layer (same as the else path):
accept local_add_cond in block_wrap's args and pass it into layer, and when
invoking blocks_replace[("double_block", i)] include the current local_add_cond
in the context dict so x = out["img"] receives the conditioned output.

---

Nitpick comments:
In `@comfy/model_base.py`:
- Line 860: The line computing seconds_total uses a magic constant 10.7666
without explanation; update the code around seconds_total (referencing
seconds_total, noise.shape[-1], and the literal 10.7666) to document or replace
the magic number: add a brief comment stating that 10.7666 is the model-specific
conversion factor from latent temporal dimension to seconds (Stable Audio 3), or
extract it into a well-named constant (e.g., STABLE_AUDIO3_LATENT_TO_SECONDS)
and use that constant in the calculation so the intent is clear and
maintainable.
- Around line 826-848: The concat_cond method currently has TODOs about scaling
when image/mask spatial shapes don't match noise; update concat_cond (and its
use of process_latent_in and utils.resize_to_batch_size) to either implement
explicit scaling logic mirroring the base class concat_cond behavior (ensure
image and mask are resized to noise.shape using the same resizing/resampling
rules as base class) or replace the TODOs with a precise comment/docstring
explaining why scaling is unnecessary for this model (including expected tensor
shapes and invariants). Ensure the mask handling (mean reduction and invert)
remains consistent and mention utils.resize_to_batch_size as the resizing call
to use so future readers know where resizing happens.

In `@comfy/model_detection.py`:
- Around line 119-157: Add a brief clarifying comment above the attention
normalization detection block in model_detection.py explaining that presence of
'{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) indicates
LayerNorm (uses .weight) while presence of
'{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix) indicates
RMSNorm (uses .gamma), and that these are mutually exclusive; place this comment
just before the if checking q_norm.weight and the rms_norm = state_dict.get(...)
line so readers of the unet_config population (e.g., where attn_kwargs,
norm_type, and num_heads are set) understand why the separate branches exist.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 8ef9c209-172e-4b4e-bb8d-9ae9b9ef78b3

📥 Commits

Reviewing files that changed from the base of the PR and between 78b5dec and 83ec7ec.

📒 Files selected for processing (7)
  • comfy/latent_formats.py
  • comfy/ldm/audio/dit.py
  • comfy/ldm/audio/embedders.py
  • comfy/model_base.py
  • comfy/model_detection.py
  • comfy/sd.py
  • comfy/supported_models.py

Comment thread comfy/ldm/audio/dit.py
Comment on lines +779 to 782
if self.num_memory_tokens > 0:
    memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1)
    x = torch.cat((memory_tokens, x), dim=1)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Extend mask when memory tokens are prepended.

This branch lengthens x by num_memory_tokens, but mask keeps the pre-memory length. Any masked call will then hit a shape mismatch in the attention output masking path, and the memory tokens never get valid mask bits.

💡 Suggested fix
         if self.num_memory_tokens > 0:
             memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1)
             x = torch.cat((memory_tokens, x), dim=1)
+            if mask is not None:
+                memory_mask = torch.ones((batch, self.num_memory_tokens), device=device, dtype=torch.bool)
+                mask = torch.cat((memory_mask, mask), dim=-1)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/ldm/audio/dit.py` around lines 779 - 782, When prepending memory tokens
in the block using self.num_memory_tokens and memory_tokens (from
comfy.ops.cast_to_input(self.memory_tokens, x)), also extend the attention mask
to account for those new tokens: if mask is not None, create a memory_mask of
valid bits (ones/True) shaped (batch, self.num_memory_tokens) on the same
device/dtype as mask and concat it to the front of mask along dim=1; if mask is
None, create a new mask of ones with shape (batch, seq_len +
self.num_memory_tokens) or at least prepend a memory_mask and use the existing
mask afterwards. Ensure the batch sizing uses the same batch variable used for x
so the attention masking path receives matching shapes after x =
torch.cat((memory_tokens, x), dim=1).

Comment on lines +42 to +51
def forward(self, t):
    in_dtype = t.dtype
    t = t.float()
    if t.dim() == 1:
        t = t.unsqueeze(-1)
    half_dim = self.dim // 2
    ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
    freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
    args = t * freqs * 2 * math.pi
    return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep expo Fourier features in floating point.

Line 51 casts the sinusoidal features back to t.dtype. If the caller passes integer timesteps, this quantizes the embedding before to_timestep_embed sees it and collapses most of the signal. Keep the output in a floating dtype instead of round-tripping through the input dtype.

💡 Suggested fix
     def forward(self, t):
-        in_dtype = t.dtype
+        out_dtype = t.dtype if t.is_floating_point() else torch.float32
         t = t.float()
         if t.dim() == 1:
             t = t.unsqueeze(-1)
         half_dim = self.dim // 2
         ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
         freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
         args = t * freqs * 2 * math.pi
-        return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype)
+        return torch.cat([args.cos(), args.sin()], dim=-1).to(out_dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def forward(self, t):
-    in_dtype = t.dtype
-    t = t.float()
-    if t.dim() == 1:
-        t = t.unsqueeze(-1)
-    half_dim = self.dim // 2
-    ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
-    freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
-    args = t * freqs * 2 * math.pi
-    return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype)
+def forward(self, t):
+    out_dtype = t.dtype if t.is_floating_point() else torch.float32
+    t = t.float()
+    if t.dim() == 1:
+        t = t.unsqueeze(-1)
+    half_dim = self.dim // 2
+    ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32)
+    freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq))
+    args = t * freqs * 2 * math.pi
+    return torch.cat([args.cos(), args.sin()], dim=-1).to(out_dtype)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/ldm/audio/embedders.py` around lines 42 - 51, The forward method in the
ExpoFourier embedder currently casts the sinusoidal features back to the
original input dtype (in_dtype), which quantizes integer timesteps and destroys
the signal; modify the return so the embedding stays in a floating dtype instead
of converting to in_dtype — specifically, in the forward function (variables:
in_dtype, t, args) remove the final .to(in_dtype) cast and return the
concatenated cos/sin tensor as a floating type (e.g., keep t.float() /
torch.float32) so callers receive floating-point features.
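
To illustrate why an integer round-trip loses the signal, here is a framework-free sketch of the same feature formula using only `math`. The `dim`, `min_freq`, and `max_freq` defaults are illustrative assumptions, not the embedder's real configuration, and `int()` stands in for a cast back to an integer dtype (both truncate toward zero).

```python
import math

def expo_fourier(t, dim=8, min_freq=1.0, max_freq=100.0):
    """Exponentially spaced Fourier features for a scalar t (pure-Python sketch)."""
    half = dim // 2
    ramp = [i / (half - 1) for i in range(half)]  # like torch.linspace(0, 1, half)
    log_lo, log_hi = math.log(min_freq), math.log(max_freq)
    freqs = [math.exp(r * (log_hi - log_lo) + log_lo) for r in ramp]
    args = [t * f * 2 * math.pi for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]

# An integer timestep, as in the case the review warns about.
feats = expo_fourier(3)

# Casting cos/sin values back to an integer dtype truncates them toward zero,
# so every feature collapses to -1, 0, or 1.
as_int = [int(v) for v in feats]

print(sorted(set(as_int)))
print(sum(1 for v in feats if 0 < abs(v) < 1), "of", len(feats), "features carried fractional signal")
```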

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@comfy/ldm/audio/vae_sa3.py`:
- Around line 521-540: The encode/decode pair (_pad, encode, decode) currently
pads inputs but doesn't preserve the original length, so decode(encode(x))
returns the padded length; fix by returning/propagating the original length or
pad count from encode (e.g., have encode return (latent, orig_len) or attach
pad_len to the returned tuple) and modify decode to accept that metadata and
crop the reconstructed waveform to orig_len (use patch_size to compute pad_len
if needed) instead of storing any state on the module; update all call sites to
pass through the length/pad metadata accordingly (also apply same change to the
similar methods around lines 590-600).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9a1802d4-8600-4050-94b3-3cafda4589f1

📥 Commits

Reviewing files that changed from the base of the PR and between 83ec7ec and 1ba7da0.

📒 Files selected for processing (2)
  • comfy/ldm/audio/vae_sa3.py
  • comfy/text_encoders/sa3.py

Comment on lines +521 to +540
def _pad(self, x):
    pad_len = (self.patch_size - x.shape[-1] % self.patch_size) % self.patch_size
    if pad_len > 0:
        x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1)
    return x

def encode(self, x):
    x = self._pad(x)
    B, C, T = x.shape
    h = self.patch_size
    L = T // h
    # b c (l h) -> b (c h) l
    return x.reshape(B, C, L, h).permute(0, 1, 3, 2).reshape(B, C * h, L)

def decode(self, x):
    B, Ch, L = x.shape
    h = self.patch_size
    C = Ch // h
    # b (c h) l -> b c (l h)
    return x.reshape(B, C, h, L).permute(0, 1, 3, 2).reshape(B, C, L * h)

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Preserve the original sample length across pretransform round-trips.

_pad() extends any non-256-aligned waveform, but the encode/decode path never carries that pre-pad length forward, so decode(encode(x)) returns ceil(T / 256) * 256 samples instead of T. That will shift timing for most real inputs. Please thread the original length or pad count alongside the latent and crop in decode() rather than storing it on module state.

Also applies to: 590-600

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy/ldm/audio/vae_sa3.py` around lines 521 - 540, The encode/decode pair
(_pad, encode, decode) currently pads inputs but doesn't preserve the original
length, so decode(encode(x)) returns the padded length; fix by
returning/propagating the original length or pad count from encode (e.g., have
encode return (latent, orig_len) or attach pad_len to the returned tuple) and
modify decode to accept that metadata and crop the reconstructed waveform to
orig_len (use patch_size to compute pad_len if needed) instead of storing any
state on the module; update all call sites to pass through the length/pad
metadata accordingly (also apply same change to the similar methods around lines
590-600).
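
The length-threading idea can be sketched without tensors. This is a simplified single-channel illustration on Python lists, not the module's code: it skips the channel permutation entirely, and the `(patches, orig_len)` return shape is one possible way to carry the metadata, chosen here only for illustration.

```python
def encode(samples, patch_size):
    """Zero-pad to a multiple of patch_size, split into patches, report the original length."""
    orig_len = len(samples)
    # Same pad formula as _pad(): 0 when the length is already aligned.
    pad_len = (patch_size - orig_len % patch_size) % patch_size
    padded = list(samples) + [0.0] * pad_len
    patches = [padded[i:i + patch_size] for i in range(0, len(padded), patch_size)]
    return patches, orig_len

def decode(patches, orig_len):
    """Flatten the patches and crop the padding using the threaded length."""
    flat = [s for patch in patches for s in patch]
    return flat[:orig_len]

x = [float(i) for i in range(10)]          # 10 samples, not a multiple of 4
patches, orig_len = encode(x, patch_size=4)
uncropped_len = sum(len(p) for p in patches)
restored = decode(patches, orig_len)

print(uncropped_len, len(restored))  # 12 10
```

Because the length travels with the latent instead of living on the module, two interleaved encode calls cannot overwrite each other's metadata.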

Comment on lines +53 to +58
attn = torch.matmul(xq * self.scale, xk.transpose(-2, -1))
attn = torch.tanh(attn / self.softcap) * self.softcap
if attention_mask is not None:
    attn = attn + attention_mask
attn = torch.nn.functional.softmax(attn.float(), dim=-1).to(xq.dtype)
out = torch.matmul(attn, xv).transpose(1, 2).reshape(B, S, self.inner_size)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Sliding-window attention is still O(S²) and the tokenizer makes it easy to hit.

This path still materializes both attn and sw_mask as dense S x S tensors, so long prompts will run out of memory before the window helps. Line 190 then makes that user-reachable by leaving token length effectively unbounded. Please clamp accepted sequence length until there is a true windowed/chunked attention path here.

As per coding guidelines, comfy/**: Core ML/diffusion engine. Focus on: Memory management and GPU resource handling.

Also applies to: 77-84, 184-192
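
A back-of-the-envelope sketch of how the dense score matrix scales, and of what a length clamp derived from a memory budget might look like. Both helpers are hypothetical, the head count and budget are made-up values, and 4 bytes per element is used because the softmax above runs on `attn.float()`.

```python
import math

def dense_attn_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes needed for one dense (batch, heads, S, S) attention score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

def max_seq_len_for_budget(budget_bytes, batch, heads, bytes_per_elem=4):
    """Largest S whose dense score tensor fits in budget_bytes."""
    return math.isqrt(budget_bytes // (batch * heads * bytes_per_elem))

# Illustrative numbers (assumptions): 1 prompt, 8 heads, float32 scores.
for seq_len in (1024, 8192, 32768):
    gib = dense_attn_bytes(1, 8, seq_len) / 2**30
    print(seq_len, "tokens ->", gib, "GiB for the score matrix alone")

clamp = max_seq_len_for_budget(2 * 2**30, batch=1, heads=8)
print("clamp for a 2 GiB budget:", clamp)  # 8192
```

The estimate covers only one score tensor; a dense sliding-window mask of the same shape and any intermediate copies add to it.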

@comfyanonymous comfyanonymous merged commit f9c84c9 into master May 20, 2026
16 checks passed
@comfyanonymous comfyanonymous deleted the temp_pr1 branch May 20, 2026 15:34