
prepare_attention_mask - incorrect padding? #11063

Open
cheald opened this issue Mar 14, 2025 · 1 comment
Labels
bug Something isn't working

Comments


cheald commented Mar 14, 2025

Describe the bug

I'm experimenting with attention masking in Stable Diffusion (so that padding tokens aren't considered for cross attention), and I found that UNet2DConditionModel doesn't work when given an attention_mask. The problem appears to be this padding call in prepare_attention_mask:

attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

For the attn1 blocks (self-attention), the target sequence length differs from the mask's current length (a target of 4096, versus the 77 tokens of a typical CLIP output). The padding routine appends target_length zeros to the end of the last dimension, which yields a sequence length of 4096 + 77 rather than the desired 4096. I think it should be:

- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
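
To make the length mismatch concrete, here is a small sketch using the 77/4096 numbers above (the shapes are illustrative, not pulled from a debugger):

import torch
import torch.nn.functional as F

current_length, target_length = 77, 4096   # CLIP mask length vs. attn1 sequence length
attention_mask = torch.ones(1, current_length)

# Current behaviour: appends target_length zeros to the last dimension.
padded = F.pad(attention_mask, (0, target_length), value=0.0)
print(padded.shape)  # torch.Size([1, 4173]) -> 77 + 4096

# Proposed fix: pad only the difference so the mask ends up at target_length.
fixed = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
print(fixed.shape)   # torch.Size([1, 4096])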

encoder_attention_mask works fine - it's passed to the attn2 block and no padding ends up being necessary.

It seems this would additionally fail if current_length were greater than target_length, since you can't pad by a negative amount, but I don't know whether that's a practical concern.
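
A hypothetical defensive variant of the fix (continuing the sketch above, not what diffusers currently does) that would also cover that case:

if current_length < target_length:
    attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
elif current_length > target_length:
    # truncate instead of attempting a negative pad
    attention_mask = attention_mask[..., :target_length]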

(I know that particular masking isn't even semantically valid, but that's orthogonal to this issue!)

Reproduction

# given a Stable Diffusion pipeline
# given te_mask = tokenizer_output.attention_mask
pipeline.unet(latent_input, timestep, text_encoder_output, attention_mask=te_mask).sample
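
A more self-contained version of the snippet above, in case it's useful (the model id, prompt, and latent shape here are placeholders, not from my actual script):

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

tokens = pipeline.tokenizer(
    "a photo of a cat",
    padding="max_length",
    max_length=pipeline.tokenizer.model_max_length,
    return_tensors="pt",
)
te_mask = tokens.attention_mask.to("cuda")
text_encoder_output = pipeline.text_encoder(tokens.input_ids.to("cuda"))[0]

latent_input = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor([10], device="cuda")

# Fails inside prepare_attention_mask for the attn1 (self-attention) blocks
# because of the padding behaviour described above.
pipeline.unet(latent_input, timestep, text_encoder_output, attention_mask=te_mask).sample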

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.10.11
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.28.1
  • Transformers version: 4.48.3
  • Accelerate version: 1.3.0
  • PEFT version: not installed
  • Bitsandbytes version: 0.45.2
  • Safetensors version: 0.5.2
  • xFormers version: 0.0.29.post2
  • Accelerator: NVIDIA GeForce RTX 3060, 12288 MiB
    NVIDIA GeForce RTX 4060 Ti, 16380 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?: No

Who can help?

No response

cheald added the bug label Mar 14, 2025
cheald commented Mar 15, 2025

Additionally, are there any examples of attention_mask being used with UNet2DConditionModel? Since it's a UNet, the attention blocks operate at progressively smaller sizes, yet only a single attention mask is accepted. It seems you would need to accept one mask per block level (or downsample the mask to match the block size at each depth). I've hacked up my local UNet2DConditionModel to do just that, and it seems to work, but I'd like to understand the intended usage of attention_mask as it stands.
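
For illustration, here's roughly the shape of what my local hack does (the helper and shapes are made up for this comment, not diffusers API): a spatial mask at the latent resolution gets max-pooled down to each attention resolution and flattened into a per-token mask.

import torch
import torch.nn.functional as F

def masks_per_depth(spatial_mask, resolutions):
    # spatial_mask: (batch, 1, H, W), 1.0 where tokens should be attended to
    out = {}
    for h, w in resolutions:
        pooled = F.adaptive_max_pool2d(spatial_mask, (h, w))
        out[(h, w)] = pooled.flatten(2).squeeze(1)  # (batch, h * w) token mask
    return out

# e.g. mask out the right half of a 64x64 latent, then build a mask per attention resolution
mask = torch.ones(1, 1, 64, 64)
mask[..., 32:] = 0.0
print({k: v.shape for k, v in masks_per_depth(mask, [(32, 32), (16, 16), (8, 8)]).items()})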
