Describe the bug
When using StableDiffusionXLImg2ImgPipeline with the SDXL-Turbo model, setting num_inference_steps=1 together with a strength below 1.0 (0.8 in the reproduction below) crashes the pipeline with a tensor error during generation. With a single step, int(num_inference_steps * strength) truncates to 0 for every strength below 1.0, so the problem is not limited to values above 0.5.
The root cause is in the get_timesteps method in pipeline_stable_diffusion_xl_img2img.py. Whenever int(num_inference_steps * strength) is 0, the method returns an empty timesteps tensor, which causes the latent tensor to become empty later in the pipeline.
Current Behavior
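The get_timesteps method calculates:
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
Then retrieves timesteps: timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
When num_inference_steps=1 and strength=0.8:
init_timestep = min(int(1 * 0.8), 1) = 0
t_start = max(1 - 0, 0) = 1
Resulting in timesteps = scheduler.timesteps[1 * scheduler.order:] which is an empty slice
Judging by the log below, this empty tensor propagates through the rest of the pipeline: the denoising loop runs zero iterations (0it), the latents end up with 0 elements, and the crash surfaces in the VAE decoder when attention tries to reshape them.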
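The arithmetic above can be checked without loading a model. The following is a minimal sketch, assuming a plain Python list stands in for scheduler.timesteps and a scheduler order of 1. The function name slice_timesteps and the timestep value 999 are illustrative and not part of diffusers.

```python
def slice_timesteps(timesteps, num_inference_steps, strength, order=1):
    """Mirror the slicing arithmetic described in this report (sketch only)."""
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start * order:]


# Hypothetical one-step schedule; 999 is a placeholder timestep value.
one_step = [999]

for strength in (0.3, 0.5, 0.8, 1.0):
    # int(1 * strength) truncates to 0 unless strength is exactly 1.0
    print(strength, slice_timesteps(one_step, 1, strength))
```

Under these assumptions, only strength=1.0 keeps the single timestep; every smaller value produces an empty list.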
Expected Behavior
The pipeline should ensure that at least one timestep is available regardless of the strength and inference step settings. A patch could include a check to ensure the timestep slice is never empty, falling back to a single timestep if necessary.
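One possible shape for such a check is to clamp init_timestep to at least 1. This is a hypothetical sketch, not the fix adopted by diffusers; it again uses a plain list in place of scheduler.timesteps, and the name clamped_timesteps is made up for illustration.

```python
def clamped_timesteps(timesteps, num_inference_steps, strength, order=1):
    """Slicing arithmetic with a guard that always keeps at least one step."""
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    # Hypothetical guard: never let the number of kept steps fall to zero.
    init_timestep = max(init_timestep, 1)
    t_start = max(num_inference_steps - init_timestep, 0)
    # Return the kept timesteps and how many steps will actually run.
    return timesteps[t_start * order:], num_inference_steps - t_start


# With one step and strength=0.8 the single timestep is now kept.
print(clamped_timesteps([999], 1, 0.8))
```

A trade-off of clamping is that very low strength values still run one denoising step, so the output may change more than the requested strength suggests. Raising a clear ValueError for such combinations would be an alternative design.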
Workaround
A temporary workaround is to monkey-patch the get_timesteps method:
import types

# Fix the get_timesteps method
original_get_timesteps = pipe.get_timesteps

def fixed_get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
    if denoising_start is None:
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        # Check if we'd get an empty slice and adjust if needed
        if t_start * self.scheduler.order >= len(self.scheduler.timesteps):
            t_start = 0  # Fallback to using all timesteps
        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
        # Extra safety check
        if len(timesteps) == 0:
            timesteps = self.scheduler.timesteps[0:1]  # Always use at least one timestep
        if hasattr(self.scheduler, "set_begin_index"):
            self.scheduler.set_begin_index(t_start * self.scheduler.order)
        return timesteps, num_inference_steps - t_start
    else:
        # original_get_timesteps is already bound to pipe, so self is not passed again
        return original_get_timesteps(num_inference_steps, strength, device, denoising_start)

# Apply the patch
pipe.get_timesteps = types.MethodType(fixed_get_timesteps, pipe)
This issue is particularly significant for users of SDXL-Turbo, which is designed for ultra-fast generation with just 1-4 inference steps.
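An alternative that avoids patching the pipeline is to pick num_inference_steps so that int(num_inference_steps * strength) is at least 1, which follows directly from the arithmetic in this report. The helper below is a minimal sketch; min_steps_for_strength is a hypothetical name, not a diffusers function.

```python
import math


def min_steps_for_strength(strength):
    """Smallest num_inference_steps with int(steps * strength) >= 1."""
    if not 0.0 < strength <= 1.0:
        raise ValueError("strength must be in the interval (0, 1]")
    steps = math.ceil(1.0 / strength)
    # Guard against floating-point rounding in steps * strength.
    while int(steps * strength) < 1:
        steps += 1
    return steps


for strength in (0.3, 0.5, 0.8, 1.0):
    print(strength, min_steps_for_strength(strength))
```

For strength=0.8 this yields 2 steps, so passing num_inference_steps=2 would leave one timestep after slicing instead of none.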
Reproduction
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

# Load model
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16
).to("mps")  # This can be any device: cpu, cuda, mps

# Load IP-Adapter (this is not required to trigger the bug, but shows a common use case)
pipe.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.safetensors"
)

# Load input image
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png").resize((512, 512))

# This will trigger the bug
try:
    output = pipe(
        prompt="high quality, best quality",
        image=image,
        strength=0.8,  # With 1 step, int(1 * strength) is 0 for any strength below 1.0
        num_inference_steps=1,
        guidance_scale=1.0
    )
except Exception as e:
    print(f"Error: {e}")

# The error occurs during the pipeline run, either as an empty tensor error
# or down the pipeline when it tries to reshape a tensor with 0 elements
Logs
.../.venv/bin/python .../ipadapter/tester.py
Loading pipeline components...: 100%|██████████| 7/7 [00:14<00:00, 2.04s/it]
0it [00:00, ?it/s]
Traceback (most recent call last):
File ".../ipadapter/tester.py", line 129, in <module>
new_image = pipeline(
^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py", line 1478, in __call__
image = self.vae.decode(latents, return_dict=False)[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
return method(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 327, in decode
decoded = self._decode(z).sample
^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 298, in _decode
dec = self.decoder(z)
^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/models/autoencoders/vae.py", line 345, in forward
sample = self.mid_block(sample, latent_embeds)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 761, in forward
hidden_states = attn(hidden_states, temb=temb)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 588, in forward
return self.processor(
^^^^^^^^^^^^^^^
File ".../.venv/lib/python3.12/site-packages/diffusers/models/attention_processor.py", line 3274, in __call__
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cannot reshape tensor of 0 elements into shape [0, -1, 1, 512] because the unspecified dimension size -1 can be any value and is ambiguous
Process finished with exit code 1
System Info
diffusers version: 0.24.0
Platform: macOS
Python version: 3.12
PyTorch version (GPU?): 2.1.0 (MPS)
Using GPU in script?: Yes (Apple MPS)
Using distributed or parallel set-up in script?: No