Commit 2b23ec8

yiyixuxu, patrickvonplaten, sayakpaul, and DN6 authored
add callbacks to denoising step (huggingface#5427)
* draft1
* update
* style
* move to the end of loop
* update
* update callbak_on_step_end_inputs
* Revert "update"
  This reverts commit 5f9b153.
* Revert "update callbak_on_step_end_inputs"
  This reverts commit 44889f4.
* update
* update test required_optional_params
* remove self.lora_scale
* img2img
* inpaint
* Apply suggestions from code review
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* fix
* apply feedbacks on img2img + inpaint: keep only important pipeline attributes
* depth
* pix2pix
* make _callback_tensor_inputs an class variable so that we can use it for testing
* add a basic tst for callback
* add a read-only tensor input timesteps + fix tests
* add second test for callback cfg
* sdxl
* sdxl img2img
* sdxl inpaint
* kandinsky prior
* kandinsky decoder
* kandinsky img2img + combined
* kandinsky inpaint
* fix copies
* fix
* consistent default inputs
* fix copies
* wuerstchen_prior prior
* test_wuerstchen_decoder + fix test for prior
* wuerstchen_combined pipeline + skip tests
* skip test for kandinsky combined
* lcm
* remove timesteps etc
* add doc string
* copies
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* make style and improve tests
* up
* up
* fix more
* fix cfg test
* tests for callbacks
* fix for real
* update
* lcm img2img
* add doc
* add doc page to index

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
1 parent 080081b commit 2b23ec8

62 files changed: +2514 -582 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -76,6 +76,8 @@
     title: Kandinsky
   - local: using-diffusers/controlnet
     title: ControlNet
+  - local: using-diffusers/callback
+    title: Callback
   - local: using-diffusers/shap-e
     title: Shap-E
   - local: using-diffusers/diffedit

docs/source/en/using-diffusers/callback.md

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Using callback

[[open-in-colab]]

Most 🤗 Diffusers pipelines now accept a `callback_on_step_end` argument that lets you change the default behavior of the denoising loop with custom functions. Here is an example of a callback function that disables classifier-free guidance (CFG) after 40% of the inference steps to save compute with a minimal tradeoff in performance.

```python
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    # adjust the batch_size of prompt_embeds according to guidance_scale
    if step_index == int(pipe.num_timesteps * 0.4):
        prompt_embeds = callback_kwargs["prompt_embeds"]
        # keep only the conditional half of the [negative, positive] batch
        prompt_embeds = prompt_embeds.chunk(2)[-1]

        # update guidance_scale and prompt_embeds
        pipe._guidance_scale = 0.0
        callback_kwargs["prompt_embeds"] = prompt_embeds
    return callback_kwargs
```

Your callback function takes the following arguments:
* `pipe` is the pipeline instance, which provides access to useful properties such as `num_timesteps` and `guidance_scale`. These properties are read-only, but you can change them by updating the underlying attributes. In this example, we disable CFG by setting `pipe._guidance_scale` to `0.0`.
* `step_index` and `timestep` tell you where you are in the denoising loop. In our example, we use `step_index` to decide when to turn off CFG.
* `callback_kwargs` is a dict that contains the tensor variables you can modify during the denoising loop. It only includes the variables named in the `callback_on_step_end_tensor_inputs` argument passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables, so check the pipeline class's `_callback_tensor_inputs` attribute for the list of variables you can modify. Common variables include `latents` and `prompt_embeds`. In our example, once `guidance_scale` is set to `0.0` the pipeline no longer doubles the batch for CFG, so `prompt_embeds` must be reduced to its conditional half for the batch sizes to match.
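
To make the calling convention concrete, here is a minimal sketch of a callback that only observes the loop and changes nothing. `make_step_logger` is a hypothetical helper, not part of Diffusers, and the final call is a stand-in for the pipeline (a plain list takes the place of a tensor). The four-argument signature and the returned `callback_kwargs` follow the description above.

```python
def make_step_logger(log):
    """Return a callback that records each step without changing any tensors.

    `make_step_logger` is a hypothetical helper, not part of Diffusers.
    """

    def log_step(pipe, step_index, timestep, callback_kwargs):
        # Record where we are in the loop and which tensors were passed in.
        log.append((step_index, int(timestep), sorted(callback_kwargs)))
        # Always return the dict so the pipeline can read tensors back from it.
        return callback_kwargs

    return log_step


# Stand-in invocation, mimicking how a pipeline would call the callback.
records = []
callback = make_step_logger(records)
returned = callback(None, 0, 981, {"latents": [0.0]})
```

Returning the dict unchanged is the simplest valid behavior; a callback that modifies a tensor would write it back under the same key before returning.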

Pass the callback function to the pipeline as the `callback_on_step_end` argument, along with `callback_on_step_end_tensor_inputs`.

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

generator = torch.Generator(device="cuda").manual_seed(1)
# `callback_dynamic_cfg` is the callback function defined earlier on this page
out = pipe(
    prompt,
    generator=generator,
    callback_on_step_end=callback_dynamic_cfg,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
)

out.images[0].save("out_custom_cfg.png")
```

Your callback function is executed at the end of each denoising step and can modify pipeline attributes and tensor variables for the next denoising step. This adds the "dynamic CFG" feature to the Stable Diffusion pipeline without modifying the pipeline code at all.
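
As a rough illustration of the compute saved, the sketch below counts UNet batch rows per image, assuming a step with CFG processes a doubled batch and a step without it processes a single row. The 50-step count is an assumed example value and `unet_rows` is a hypothetical helper; real savings depend on the model and hardware.

```python
def unet_rows(num_timesteps, cutoff_fraction=None):
    """Count UNet batch rows for one image: 2 per step with CFG, 1 without.

    `unet_rows` is a hypothetical helper for back-of-the-envelope estimates.
    """
    if cutoff_fraction is None:
        return 2 * num_timesteps  # CFG stays on for every step
    switch_step = int(num_timesteps * cutoff_fraction)
    # The callback runs at the END of `switch_step`, so steps 0..switch_step
    # still use CFG and only the later steps run with a single batch row.
    steps_with_cfg = min(switch_step + 1, num_timesteps)
    return 2 * steps_with_cfg + (num_timesteps - steps_with_cfg)


baseline = unet_rows(50)       # 2 * 50 = 100
dynamic = unet_rows(50, 0.4)   # 2 * 21 + 29 = 71
```

Under these assumptions, turning CFG off after 40% of the steps cuts the rows processed from 100 to 71.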

<Tip>

Currently we only support `callback_on_step_end`. If you have a solid use case and require a callback function with a different execution point, please open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!

</Tip>

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 104 additions & 26 deletions
@@ -110,6 +110,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
     model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
     _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
 
     def __init__(
         self,
@@ -500,17 +501,23 @@ def check_inputs(
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        if (callback_steps is None) or (
-            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
-        ):
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
+                f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
 
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
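
The check added in this hunk can be read in isolation as follows. This is a standalone sketch: `validate_tensor_inputs` and `ALLOWED` are hypothetical names, whereas the commit performs the same membership test inside `check_inputs` against `self._callback_tensor_inputs`.

```python
# Mirrors the class attribute added to AltDiffusionPipeline in this commit.
ALLOWED = ["latents", "prompt_embeds", "negative_prompt_embeds"]


def validate_tensor_inputs(requested, allowed=ALLOWED):
    """Raise ValueError if any requested name is not an allowed tensor input."""
    if requested is None:
        return
    unknown = [k for k in requested if k not in allowed]
    if unknown:
        raise ValueError(
            f"`callback_on_step_end_tensor_inputs` has to be in {allowed}, but found {unknown}"
        )


validate_tensor_inputs(["latents"])  # passes silently

try:
    validate_tensor_inputs(["latents", "timesteps"])
except ValueError as err:
    message = str(err)
```

Only the offending names are reported, so the message points directly at the entries to remove.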
@@ -581,6 +588,33 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()
 
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
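
The read-only properties above are what allow a callback to change loop behavior by writing to the underscore attributes. A minimal sketch of the pattern, using a hypothetical `ToyPipeline` class rather than a real Diffusers pipeline:

```python
class ToyPipeline:
    """Hypothetical stand-in showing the property pattern, not a Diffusers class."""

    def __init__(self, guidance_scale):
        self._guidance_scale = guidance_scale

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # Recomputed on every access, so it always tracks `_guidance_scale`.
        return self._guidance_scale > 1


pipe = ToyPipeline(guidance_scale=7.5)
before = pipe.do_classifier_free_guidance
pipe._guidance_scale = 0.0  # what a callback would do mid-loop
after = pipe.do_classifier_free_guidance
```

Because `do_classifier_free_guidance` is derived on each access instead of being cached in a local variable, the change takes effect the next time the loop reads it.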
@@ -599,11 +633,12 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
         clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -647,12 +682,6 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
                 plain tuple.
-            callback (`Callable`, *optional*):
-                A function that calls every `callback_steps` steps during inference. The function is called with the
-                following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-            callback_steps (`int`, *optional*, defaults to 1):
-                The frequency at which the `callback` function is called. If not specified, the callback is called at
-                every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
                 [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -663,6 +692,15 @@ def __call__(
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
 
         Examples:
 
@@ -673,16 +711,47 @@ def __call__(
                 second element is a list of `bool`s indicating whether the corresponding generated image contains
                 "not-safe-for-work" (nsfw) content.
         """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using"
+                " `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using"
+                " `callback_on_step_end`",
+            )
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
         # to deal with lora scaling and other possible forward hooks
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+            prompt,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+            callback_on_step_end_tensor_inputs,
         )
 
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
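
The hunk above moves `callback` and `callback_steps` out of the signature and into `**kwargs`, so existing callers keep working but get a deprecation notice. A minimal sketch of that pattern follows; `call` is a hypothetical stand-in for `__call__`, and the standard-library `warnings.warn` stands in for Diffusers' own `deprecate` helper.

```python
import warnings


def call(prompt, callback_on_step_end=None, **kwargs):
    """Hypothetical `__call__` stand-in that still tolerates the old argument names."""
    callback = kwargs.pop("callback", None)
    callback_steps = kwargs.pop("callback_steps", None)

    for name, value in (("callback", callback), ("callback_steps", callback_steps)):
        if value is not None:
            # Stand-in for Diffusers' `deprecate(...)` helper.
            warnings.warn(
                f"Passing `{name}` is deprecated, consider using `callback_on_step_end`",
                FutureWarning,
            )
    return callback, callback_steps


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = call("a prompt", callback=print, callback_steps=2)
```

Since the popped default for `callback_steps` is now `None` rather than `1`, the earlier `check_inputs` hunk also drops the branch that rejected `None`.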
@@ -692,29 +761,27 @@ def __call__(
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
 
         # 3. Encode input prompt
-        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
 
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             lora_scale=lora_scale,
-            clip_skip=clip_skip,
+            clip_skip=self.clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
         # 4. Prepare timesteps
@@ -739,33 +806,44 @@ def __call__(
 
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # predict the noise residual
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
-                    cross_attention_kwargs=cross_attention_kwargs,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
                     return_dict=False,
                 )[0]
 
                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
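
The step-end hook added to the loop can be sketched with plain Python values standing in for tensors. `run_loop` and `reset_latents_at_step_one` are hypothetical names. The commit gathers the requested variables with `locals()[k]`, while this sketch uses an explicit dict for readability.

```python
def run_loop(timesteps, callback_on_step_end=None, tensor_inputs=("latents",)):
    """Toy loop mirroring the step-end hook; plain values stand in for tensors."""
    latents = 0.0
    prompt_embeds = "embeds"

    for i, t in enumerate(timesteps):
        latents = latents + 1.0  # stand-in for the scheduler step

        if callback_on_step_end is not None:
            # Collect only the requested variables by name.
            available = {"latents": latents, "prompt_embeds": prompt_embeds}
            callback_kwargs = {k: available[k] for k in tensor_inputs}
            callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)

            # Keys missing from the returned dict keep their current values.
            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

    return latents, prompt_embeds


def reset_latents_at_step_one(pipe, step_index, timestep, callback_kwargs):
    # Overwrite the latents once, at the end of the second step.
    if step_index == 1:
        callback_kwargs["latents"] = 0.0
    return callback_kwargs


result = run_loop([999, 666, 333], reset_latents_at_step_one)
```

The `pop(name, current_value)` calls make every key optional in the returned dict, so a callback that only requested `latents` leaves `prompt_embeds` untouched.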
