Description
Describe the bug
I'm experimenting with attention masking in Stable Diffusion (so that padding tokens aren't considered for cross-attention), and I found that UNet2DConditionModel doesn't work when given an attention_mask.
For the attn1 blocks (self-attention), the target sequence length differs from the mask's current length (the target is 4096 latent tokens, but a typical CLIP mask is only 77). The padding routine pads by appending target_length zeros to the end of the last dimension, which results in a sequence length of 4096 + 77 rather than the desired 4096. I think it should be:
- attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
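For illustration, here is the shape arithmetic on a dummy mask (a toy sketch, not the actual diffusers code path; current_length = 77 and target_length = 4096 as above):

import torch
import torch.nn.functional as F

attention_mask = torch.ones(1, 77)   # CLIP-style padding mask, current_length = 77
target_length = 4096                 # 64x64 latent tokens seen by the attn1 blocks

# current behaviour: appends target_length zeros, so the mask grows to 77 + 4096 = 4173
padded = F.pad(attention_mask, (0, target_length), value=0.0)
print(padded.shape)  # torch.Size([1, 4173])

# proposed fix: pad only the missing amount, giving the desired 4096
fixed = F.pad(attention_mask, (0, target_length - attention_mask.shape[-1]), value=0.0)
print(fixed.shape)   # torch.Size([1, 4096])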
encoder_attention_mask works fine: it's passed to the attn2 block and no padding ends up being necessary.
It seems that this would additionally fail if current_length were greater than target_length, since you can't pad by a negative amount, but I don't know that that's a practical concern.
(I know that particular masking isn't even semantically valid, but that's orthogonal to this issue!)
Reproduction
# given a Stable Diffusion pipeline
# given te_mask = tokenizer_output.attention_mask
# given text_embeddings = text_encoder_output.last_hidden_state
pipeline.unet(latent_input, timestep, text_embeddings, attention_mask=te_mask).sample
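A fuller standalone sketch that should hit the same path (the checkpoint, prompt, and latent shape are my assumptions; any SD 1.x pipeline at 512x512, where attn1 sees 4096 tokens, should behave the same):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

tok = pipe.tokenizer(
    ["a photo of a cat"], padding="max_length",
    max_length=pipe.tokenizer.model_max_length, return_tensors="pt"
)
te_mask = tok.attention_mask.to("cuda")                           # shape (1, 77)
text_embeddings = pipe.text_encoder(tok.input_ids.to("cuda"))[0]  # shape (1, 77, 768)

latent_input = torch.randn(1, 4, 64, 64, dtype=torch.float16, device="cuda")
timestep = torch.tensor([10], device="cuda")

# fails in the attn1 (self-attention) blocks: the mask gets padded to 4096 + 77 instead of 4096
pipe.unet(latent_input, timestep, text_embeddings, attention_mask=te_mask).sample

# for comparison, routing the same mask to attn2 via encoder_attention_mask works fine
pipe.unet(latent_input, timestep, text_embeddings, encoder_attention_mask=te_mask).sample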
Logs
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-6.8.0-55-generic-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.10.11
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.28.1
- Transformers version: 4.48.3
- Accelerate version: 1.3.0
- PEFT version: not installed
- Bitsandbytes version: 0.45.2
- Safetensors version: 0.5.2
- xFormers version: 0.0.29.post2
- Accelerator: NVIDIA GeForce RTX 3060, 12288 MiB; NVIDIA GeForce RTX 4060 Ti, 16380 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?: No
Who can help?
No response