diff --git a/.github/actions/setup-miniconda/action.yml b/.github/actions/setup-miniconda/action.yml index 8a82ae8b17bf..cc755d3aad79 100644 --- a/.github/actions/setup-miniconda/action.yml +++ b/.github/actions/setup-miniconda/action.yml @@ -27,7 +27,7 @@ runs: - name: Get date id: get-date shell: bash - run: echo "::set-output name=today::$(/bin/date -u '+%Y%m%d')d" + run: echo "today=$(/bin/date -u '+%Y%m%d')d" >> $GITHUB_OUTPUT - name: Setup miniconda cache id: miniconda-cache uses: actions/cache@v2 @@ -143,4 +143,4 @@ runs: echo "There is ${AVAIL}KB free space left in $MOUNT, continue" fi fi - done \ No newline at end of file + done diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index c202cc628542..c833bc0319e1 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -5,7 +5,7 @@ on: branches: - main - doc-builder* - - v*-release + - v*-patch jobs: build: diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml index d06b576fa631..162b1ba83d66 100644 --- a/.github/workflows/pr_tests.yml +++ b/.github/workflows/pr_tests.yml @@ -36,11 +36,6 @@ jobs: runner: docker-cpu image: diffusers/diffusers-flax-cpu report: flax_cpu - - name: Fast ONNXRuntime CPU tests - framework: onnxruntime - runner: docker-cpu - image: diffusers/diffusers-onnxruntime-cpu - report: onnx_cpu - name: PyTorch Example CPU tests framework: pytorch_examples runner: docker-cpu @@ -69,8 +64,6 @@ jobs: run: | apt-get update && apt-get install libsndfile1-dev -y python -m pip install -e .[quality,test] - python -m pip install -U git+https://github.com/huggingface/transformers - python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -100,14 +93,6 @@ jobs: --make-reports=tests_${{ matrix.config.report }} \ tests - - name: Run fast ONNXRuntime CPU tests - if: ${{ matrix.config.framework == 'onnxruntime' }} - run: | - python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \ - -s -v -k "Onnx" \ - --make-reports=tests_${{ matrix.config.report }} \ - tests/ - - name: Run example PyTorch CPU tests if: ${{ matrix.config.framework == 'pytorch_examples' }} run: | @@ -125,56 +110,3 @@ jobs: with: name: pr_${{ matrix.config.report }}_test_reports path: reports - - run_fast_tests_apple_m1: - name: Fast PyTorch MPS tests on MacOS - runs-on: [ self-hosted, apple-m1 ] - - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Clean checkout - shell: arch -arch arm64 bash {0} - run: | - git clean -fxd - - - name: Setup miniconda - uses: ./.github/actions/setup-miniconda - with: - python-version: 3.9 - - - name: Install dependencies - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python -m pip install --upgrade pip - ${CONDA_RUN} python -m pip install -e .[quality,test] - ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu - ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate - ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers - - - name: Environment - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python utils/print_env.py - - - name: Run fast PyTorch tests on M1 (MPS) - shell: arch -arch arm64 bash {0} - env: - HF_HOME: /System/Volumes/Data/mnt/cache - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - run: | - ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: cat reports/tests_torch_mps_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v2 - with: - name: pr_torch_mps_test_reports - path: reports diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 2d4875b80ced..567cd5f5b0d4 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -17,6 +17,7 @@ jobs: run_slow_tests: strategy: fail-fast: false + max-parallel: 1 matrix: config: - name: Slow PyTorch CUDA tests on Ubuntu @@ -61,8 +62,6 @@ jobs: - name: Install dependencies run: | python -m pip install -e .[quality,test] - python -m pip install -U git+https://github.com/huggingface/transformers - python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -72,6 +71,9 @@ jobs: if: ${{ matrix.config.framework == 'pytorch' }} env: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms + CUBLAS_WORKSPACE_CONFIG: :16:8 + run: | python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ -s -v -k "not Flax and not Onnx" \ @@ -131,8 +133,6 @@ jobs: - name: Install dependencies run: | python -m pip install -e .[quality,test,training] - python -m pip install git+https://github.com/huggingface/accelerate - python -m pip install -U git+https://github.com/huggingface/transformers - name: Environment run: | diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml index 525df28cbaa8..adf4fc8a87bc 100644 --- a/.github/workflows/push_tests_fast.yml +++ b/.github/workflows/push_tests_fast.yml @@ -1,4 +1,4 @@ -name: Slow tests on main +name: Fast tests on main on: push: @@ -62,8 +62,6 @@ jobs: run: | apt-get update && apt-get install libsndfile1-dev -y python -m pip install -e .[quality,test] - python -m pip install -U git+https://github.com/huggingface/transformers - python -m pip install git+https://github.com/huggingface/accelerate - name: Environment run: | @@ -110,56 +108,3 @@ jobs: with: name: pr_${{ matrix.config.report }}_test_reports path: reports - - run_fast_tests_apple_m1: - name: Fast PyTorch MPS tests on MacOS - runs-on: [ self-hosted, apple-m1 ] - - steps: - - name: Checkout diffusers - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - name: Clean checkout - shell: arch -arch arm64 bash {0} - run: | - git clean -fxd - - - name: Setup miniconda - uses: ./.github/actions/setup-miniconda - with: - python-version: 3.9 - - - name: Install dependencies - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python -m pip install --upgrade pip - ${CONDA_RUN} python -m pip install -e .[quality,test] - ${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu - ${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate - ${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers - - - name: Environment - shell: arch -arch arm64 bash {0} - run: | - ${CONDA_RUN} python utils/print_env.py - - - name: Run fast PyTorch tests on M1 (MPS) - shell: arch -arch arm64 bash {0} - env: - HF_HOME: /System/Volumes/Data/mnt/cache - HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} - run: | - ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ - - - name: Failure short reports - if: ${{ failure() }} - run: cat reports/tests_torch_mps_failures_short.txt - - - name: Test suite reports artifacts - if: ${{ always() }} - uses: actions/upload-artifact@v2 - with: - name: pr_torch_mps_test_reports - path: reports diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml new file mode 100644 index 000000000000..6b95815f1ea5 --- /dev/null +++ b/.github/workflows/push_tests_mps.yml @@ -0,0 +1,68 @@ +name: Fast mps tests on main + +on: + push: + branches: + - main + +env: + DIFFUSERS_IS_CI: yes + HF_HOME: /mnt/cache + OMP_NUM_THREADS: 8 + MKL_NUM_THREADS: 8 + PYTEST_TIMEOUT: 600 + RUN_SLOW: no + +jobs: + run_fast_tests_apple_m1: + name: Fast PyTorch MPS tests on MacOS + runs-on: [ self-hosted, apple-m1 ] + + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - name: Clean checkout + shell: arch -arch arm64 bash {0} + run: | + git clean -fxd + + - name: Setup miniconda + uses: ./.github/actions/setup-miniconda + with: + python-version: 3.9 + + - name: Install dependencies + shell: arch -arch arm64 bash {0} + run: | + ${CONDA_RUN} python -m pip install --upgrade pip + ${CONDA_RUN} python -m pip install -e .[quality,test] + ${CONDA_RUN} python -m pip install torch torchvision torchaudio + ${CONDA_RUN} python -m pip install accelerate --upgrade + ${CONDA_RUN} python -m pip install transformers --upgrade + + - name: Environment + shell: arch -arch arm64 bash {0} + run: | + ${CONDA_RUN} python utils/print_env.py + + - name: Run fast PyTorch tests on M1 (MPS) + shell: arch -arch arm64 bash {0} + env: + HF_HOME: /System/Volumes/Data/mnt/cache + HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} + run: | + ${CONDA_RUN} python -m pytest -n 0 -s -v --make-reports=tests_torch_mps tests/ + + - name: Failure short reports + if: ${{ failure() }} + run: cat reports/tests_torch_mps_failures_short.txt + + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v2 + with: + name: pr_torch_mps_test_reports + path: reports diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5ce48793e9c2..9c5f0a10edd3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -125,14 +125,14 @@ Awesome! Tell us what problem it solved for you. You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=). -#### 2.3 Feedback. +#### 2.3 Feedback. Feedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the philosophy behind the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too much, hence restricting use cases, explain why and how it should be changed. If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions. You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=). -#### 2.4 Technical questions. +#### 2.4 Technical questions. Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on why this part of the code is difficult to understand. @@ -394,8 +394,8 @@ passes. You should run the tests impacted by your changes like this: ```bash $ pytest tests/.py ``` - -Before you run the tests, please make sure you install the dependencies required for testing. You can do so + +Before you run the tests, please make sure you install the dependencies required for testing. You can do so with this command: ```bash diff --git a/PHILOSOPHY.md b/PHILOSOPHY.md index fbad5948e17e..399cb0bfb47d 100644 --- a/PHILOSOPHY.md +++ b/PHILOSOPHY.md @@ -27,18 +27,18 @@ In a nutshell, Diffusers is built to be a natural extension of PyTorch. Therefor ## Simple over easy -As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library: +As PyTorch states, **explicit is better than implicit** and **simple is better than complex**. This design philosophy is reflected in multiple parts of the library: - We follow PyTorch's API with methods like [`DiffusionPipeline.to`](https://huggingface.co/docs/diffusers/main/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.to) to let the user handle device management. - Raising concise error messages is preferred to silently correct erroneous input. Diffusers aims at teaching the user, rather than making the library as easy to use as possible. - Complex model vs. scheduler logic is exposed instead of magically handled inside. Schedulers/Samplers are separated from diffusion models with minimal dependencies on each other. This forces the user to write the unrolled denoising loop. However, the separation allows for easier debugging and gives the user more control over adapting the denoising process or switching out diffusion models or schedulers. -- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. Dreambooth or textual inversion training +- Separately trained components of the diffusion pipeline, *e.g.* the text encoder, the unet, and the variational autoencoder, each have their own model class. This forces the user to handle the interaction between the different model components, and the serialization format separates the model components into different files. However, this allows for easier debugging and customization. Dreambooth or textual inversion training is very simple thanks to diffusers' ability to separate single components of the diffusion pipeline. ## Tweakable, contributor-friendly over abstraction -For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself). +For large parts of the library, Diffusers adopts an important design principle of the [Transformers library](https://github.com/huggingface/transformers), which is to prefer copy-pasted code over hasty abstractions. This design principle is very opinionated and stands in stark contrast to popular design principles such as [Don't repeat yourself (DRY)](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself). In short, just like Transformers does for modeling files, diffusers prefers to keep an extremely low level of abstraction and very self-contained code for pipelines and schedulers. -Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable. +Functions, long code blocks, and even classes can be copied across multiple files which at first can look like a bad, sloppy design choice that makes the library unmaintainable. **However**, this design has proven to be extremely successful for Transformers and makes a lot of sense for community-driven, open-source machine learning libraries because: - Machine Learning is an extremely fast-moving field in which paradigms, model architectures, and algorithms are changing rapidly, which therefore makes it very difficult to define long-lasting code abstractions. - Machine Learning practitioners like to be able to quickly tweak existing code for ideation and research and therefore prefer self-contained code over one that contains many abstractions. @@ -47,10 +47,10 @@ Functions, long code blocks, and even classes can be copied across multiple file At Hugging Face, we call this design the **single-file policy** which means that almost all of the code of a certain class should be written in a single, self-contained file. To read more about the philosophy, you can have a look at [this blog post](https://huggingface.co/blog/transformers-design-philosophy). -In diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such +In diffusers, we follow this philosophy for both pipelines and schedulers, but only partly for diffusion models. The reason we don't follow this design fully for diffusion models is because almost all diffusion pipelines, such as [DDPM](https://huggingface.co/docs/diffusers/v0.12.0/en/api/pipelines/ddpm), [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.12.0/en/api/pipelines/stable_diffusion/overview#stable-diffusion-pipelines), [UnCLIP (Dalle-2)](https://huggingface.co/docs/diffusers/v0.12.0/en/api/pipelines/unclip#overview) and [Imagen](https://imagen.research.google/) all rely on the same diffusion model, the [UNet](https://huggingface.co/docs/diffusers/api/models#diffusers.UNet2DConditionModel). -Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗. +Great, now you should have generally understood why 🧨 Diffusers is designed the way it is 🤗. We try to apply these design principles consistently across the library. Nevertheless, there are some minor exceptions to the philosophy or some unlucky design choices. If you have feedback regarding the design, we would ❤️ to hear it [directly on GitHub](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=). ## Design Philosophy in Details @@ -89,7 +89,7 @@ The following design principles are followed: - Models should by default have the highest precision and lowest performance setting. - To integrate new model checkpoints whose general architecture can be classified as an architecture that already exists in Diffusers, the existing model architecture shall be adapted to make it work with the new checkpoint. One should only create a new file if the model architecture is fundamentally different. - Models should be designed to be easily extendable to future changes. This can be achieved by limiting public function arguments, configuration arguments, and "foreseeing" future changes, *e.g.* it is usually better to add `string` "...type" arguments that can easily be extended to new future types instead of boolean `is_..._type` arguments. Only the minimum amount of changes shall be made to existing architectures to make a new model checkpoint work. -- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and +- The model design is a difficult trade-off between keeping code readable and concise and supporting many model checkpoints. For most parts of the modeling code, classes shall be adapted for new model checkpoints, while there are some exceptions where it is preferred to add new classes to make sure the code is kept concise and readable longterm, such as [UNet blocks](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py) and [Attention processors](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). ### Schedulers @@ -97,9 +97,9 @@ readable longterm, such as [UNet blocks](https://github.com/huggingface/diffuser Schedulers are responsible to guide the denoising process for inference as well as to define a noise schedule for training. They are designed as individual classes with loadable configuration files and strongly follow the **single-file policy**. The following design principles are followed: -- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). -- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained. -- One scheduler python file corresponds to one scheduler algorithm (as might be defined in a paper). +- All schedulers are found in [`src/diffusers/schedulers`](https://github.com/huggingface/diffusers/tree/main/src/diffusers/schedulers). +- Schedulers are **not** allowed to import from large utils files and shall be kept very self-contained. +- One scheduler python file corresponds to one scheduler algorithm (as might be defined in a paper). - If schedulers share similar functionalities, we can make use of the `#Copied from` mechanism. - Schedulers all inherit from `SchedulerMixin` and `ConfigMixin`. - Schedulers can be easily swapped out with the [`ConfigMixin.from_config`](https://huggingface.co/docs/diffusers/main/en/api/configuration#diffusers.ConfigMixin.from_config) method as explained in detail [here](./using-diffusers/schedulers.mdx). diff --git a/README.md b/README.md index 76d7df79c813..c2a3b04b57a8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@


- +

@@ -30,7 +30,7 @@ We recommend installing 🤗 Diffusers in a virtual environment from PyPi or Con ### PyTorch With `pip` (official package): - + ```bash pip install --upgrade diffusers[torch] ``` @@ -59,8 +59,9 @@ Generating outputs is super easy with 🤗 Diffusers. To generate an image from ```python from diffusers import DiffusionPipeline +import torch -pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) pipeline.to("cuda") pipeline("An image of a squirrel in Picasso style").images[0] ``` @@ -99,58 +100,14 @@ Check out the [Quickstart](https://huggingface.co/docs/diffusers/quicktour) to l | **Documentation** | **What can I learn?** | |---------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| Tutorial | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. | -| Loading | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. | -| Pipelines for inference | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. | -| Optimization | Guides for how to optimize your diffusion model to run faster and consume less memory. | +| [Tutorial](https://huggingface.co/docs/diffusers/tutorials/tutorial_overview) | A basic crash course for learning how to use the library's most important features like using models and schedulers to build your own diffusion system, and training your own diffusion model. | +| [Loading](https://huggingface.co/docs/diffusers/using-diffusers/loading_overview) | Guides for how to load and configure all the components (pipelines, models, and schedulers) of the library, as well as how to use different schedulers. | +| [Pipelines for inference](https://huggingface.co/docs/diffusers/using-diffusers/pipeline_overview) | Guides for how to use pipelines for different inference tasks, batched generation, controlling generated outputs and randomness, and how to contribute a pipeline to the library. | +| [Optimization](https://huggingface.co/docs/diffusers/optimization/opt_overview) | Guides for how to optimize your diffusion model to run faster and consume less memory. | | [Training](https://huggingface.co/docs/diffusers/training/overview) | Guides for how to train a diffusion model for different tasks with different training techniques. | - -## Supported pipelines - -| Pipeline | Paper | Tasks | -|---|---|:---:| -| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -| [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | -| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | -| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | -| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | -| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | -| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | -| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | -| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | -| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | -| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting | -| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | -| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [**Semantic Guidance**](https://arxiv.org/abs/2301.12247) | Text-Guided Generation | -| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | -| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | -| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | -| [stable_diffusion_panorama](./api/pipelines/stable_diffusion/panorama) | [**MultiDiffusion**](https://multidiffusion.github.io/) | Text-to-Panorama Generation | -| [stable_diffusion_pix2pix](./api/pipelines/stable_diffusion/pix2pix) | [**InstructPix2Pix**](https://github.com/timothybrooks/instruct-pix2pix) | Text-Guided Image Editing| -| [stable_diffusion_pix2pix_zero](./api/pipelines/stable_diffusion/pix2pix_zero) | [**Zero-shot Image-to-Image Translation**](https://pix2pixzero.github.io/) | Text-Guided Image Editing | -| [stable_diffusion_attend_and_excite](./api/pipelines/stable_diffusion/attend_and_excite) | [**Attend and Excite for Stable Diffusion**](https://attendandexcite.github.io/Attend-and-Excite/) | Text-to-Image Generation | -| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [**Self-Attention Guidance**](https://ku-cvlab.github.io/Self-Attention-Guidance) | Text-to-Image Generation | -| [stable_diffusion_image_variation](./stable_diffusion/image_variation) | [**Stable Diffusion Image Variations**](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) | Image-to-Image Generation | -| [stable_diffusion_latent_upscale](./stable_diffusion/latent_upscale) | [**Stable Diffusion Latent Upscaler**](https://twitter.com/StabilityAI/status/1590531958815064065) | Text-Guided Super Resolution Image-to-Image | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Depth-Conditional Stable Diffusion**](https://github.com/Stability-AI/stablediffusion#depth-conditional-stable-diffusion) | Depth-to-Image Generation | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | -| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | -| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation | -| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation | -| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | -| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation | -| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | -| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | -| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | -| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | - ## Contribution -We ❤️ contributions from the open-source community! +We ❤️ contributions from the open-source community! If you want to contribute to this library, please check out our [Contribution guide](https://github.com/huggingface/diffusers/blob/main/CONTRIBUTING.md). You can look out for [issues](https://github.com/huggingface/diffusers/issues) you'd like to tackle to contribute to the library. - See [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) for general opportunities to contribute @@ -160,6 +117,87 @@ You can look out for [issues](https://github.com/huggingface/diffusers/issues) y Also, say 👋 in our public Discord channel Join us on Discord. We discuss the hottest trends about diffusion models, help each other with contributions, personal projects or just hang out ☕. + +## Popular Tasks & Pipelines + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
TaskPipeline🤗 Hub
Unconditional Image Generation DDPM google/ddpm-ema-church-256
Text-to-ImageStable Diffusion Text-to-Image runwayml/stable-diffusion-v1-5
Text-to-Imageunclip kakaobrain/karlo-v1-alpha
Text-to-Imageif DeepFloyd/IF-I-XL-v1.0
Text-guided Image-to-ImageControlnet lllyasviel/sd-controlnet-canny
Text-guided Image-to-ImageInstruct Pix2Pix timbrooks/instruct-pix2pix
Text-guided Image-to-ImageStable Diffusion Image-to-Image runwayml/stable-diffusion-v1-5
Text-guided Image InpaintingStable Diffusion Inpaint runwayml/stable-diffusion-inpainting
Image VariationStable Diffusion Image Variation lambdalabs/sd-image-variations-diffusers
Super ResolutionStable Diffusion Upscale stabilityai/stable-diffusion-x4-upscaler
Super ResolutionStable Diffusion Latent Upscale stabilityai/sd-x2-latent-upscaler
+ +## Popular libraries using 🧨 Diffusers + +- https://github.com/microsoft/TaskMatrix +- https://github.com/invoke-ai/InvokeAI +- https://github.com/apple/ml-stable-diffusion +- https://github.com/Sanster/lama-cleaner +- https://github.com/IDEA-Research/Grounded-Segment-Anything +- https://github.com/ashawkey/stable-dreamfusion +- https://github.com/deep-floyd/IF +- https://github.com/bentoml/BentoML +- https://github.com/bmaltais/kohya_ss +- +3000 other amazing GitHub repositories 💪 + +Thank you for using us ❤️ + ## Credits This library concretizes previous work by many different authors and would not have been possible without their great research and implementations. We'd like to thank, in particular, the following implementations which have helped us in our development and without which the API could not have been as polished today: diff --git a/docker/diffusers-pytorch-cuda/Dockerfile b/docker/diffusers-pytorch-cuda/Dockerfile index 8087be429996..6b56403a6f94 100644 --- a/docker/diffusers-pytorch-cuda/Dockerfile +++ b/docker/diffusers-pytorch-cuda/Dockerfile @@ -26,7 +26,7 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ python3 -m pip install --no-cache-dir \ torch \ torchvision \ - torchaudio \ + torchaudio && \ python3 -m pip install --no-cache-dir \ accelerate \ datasets \ @@ -37,6 +37,9 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip && \ numpy \ scipy \ tensorboard \ - transformers + transformers \ + omegaconf \ + pytorch-lightning \ + xformers CMD ["/bin/bash"] diff --git a/docs/source/_config.py b/docs/source/_config.py index 9a4818ea8b1e..3d0d73dcb951 100644 --- a/docs/source/_config.py +++ b/docs/source/_config.py @@ -6,4 +6,4 @@ # ! pip install git+https://github.com/huggingface/diffusers.git """ -notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}] \ No newline at end of file +notebook_first_cells = [{"type": "code", "content": INSTALL_CONTENT}] diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index ccaaff7ca680..5084299bb0dd 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -26,8 +26,10 @@ title: Load and compare different schedulers - local: using-diffusers/custom_pipeline_overview title: Load community pipelines - - local: using-diffusers/kerascv - title: Load KerasCV Stable Diffusion checkpoints + - local: using-diffusers/using_safetensors + title: Load safetensors + - local: using-diffusers/other-formats + title: Load different Stable Diffusion formats title: Loading & Hub - sections: - local: using-diffusers/pipeline_overview @@ -42,6 +44,10 @@ title: Text-guided image-inpainting - local: using-diffusers/depth2img title: Text-guided depth-to-image + - local: using-diffusers/textual_inversion_inference + title: Textual inversion + - local: training/distributed_inference + title: Distributed inference with multiple GPUs - local: using-diffusers/reusing_seeds title: Improve image quality with deterministic generation - local: using-diffusers/reproducibility @@ -50,8 +56,6 @@ title: Community pipelines - local: using-diffusers/contribute_pipeline title: How to contribute a community pipeline - - local: using-diffusers/using_safetensors - title: Using safetensors - local: using-diffusers/stable_diffusion_jax_how_to title: Stable Diffusion in JAX/Flax - local: using-diffusers/weighted_prompts @@ -60,6 +64,10 @@ - sections: - local: training/overview title: Overview + - local: training/create_dataset + title: Create a dataset for training + - local: training/adapt_a_model + title: Adapt a model to a new task - local: training/unconditional_training title: Unconditional image generation - local: training/text_inversion @@ -124,6 +132,8 @@ - sections: - local: api/models title: Models + - local: api/attnprocessor + title: Attention Processor - local: api/diffusion_pipeline title: Diffusion Pipeline - local: api/logging @@ -134,16 +144,22 @@ title: Outputs - local: api/loaders title: Loaders + - local: api/utilities + title: Utilities title: Main Classes - sections: - local: api/pipelines/overview title: Overview - local: api/pipelines/alt_diffusion title: AltDiffusion + - local: api/pipelines/attend_and_excite + title: Attend and Excite - local: api/pipelines/audio_diffusion title: Audio Diffusion - local: api/pipelines/audioldm title: AudioLDM + - local: api/pipelines/controlnet + title: ControlNet - local: api/pipelines/cycle_diffusion title: Cycle Diffusion - local: api/pipelines/dance_diffusion @@ -152,26 +168,36 @@ title: DDIM - local: api/pipelines/ddpm title: DDPM + - local: api/pipelines/diffedit + title: DiffEdit - local: api/pipelines/dit title: DiT - local: api/pipelines/if title: IF + - local: api/pipelines/pix2pix + title: InstructPix2Pix + - local: api/pipelines/kandinsky + title: Kandinsky - local: api/pipelines/latent_diffusion title: Latent Diffusion + - local: api/pipelines/panorama + title: MultiDiffusion Panorama - local: api/pipelines/paint_by_example title: PaintByExample + - local: api/pipelines/pix2pix_zero + title: Pix2Pix Zero - local: api/pipelines/pndm title: PNDM - local: api/pipelines/repaint title: RePaint - - local: api/pipelines/stable_diffusion_safe - title: Safe Stable Diffusion - local: api/pipelines/score_sde_ve title: Score SDE VE + - local: api/pipelines/self_attention_guidance + title: Self-Attention Guidance - local: api/pipelines/semantic_stable_diffusion title: Semantic Guidance - local: api/pipelines/spectrogram_diffusion - title: "Spectrogram Diffusion" + title: Spectrogram Diffusion - sections: - local: api/pipelines/stable_diffusion/overview title: Overview @@ -185,31 +211,21 @@ title: Depth-to-Image - local: api/pipelines/stable_diffusion/image_variation title: Image-Variation - - local: api/pipelines/stable_diffusion/upscale - title: Super-Resolution + - local: api/pipelines/stable_diffusion/stable_diffusion_safe + title: Safe Stable Diffusion + - local: api/pipelines/stable_diffusion/stable_diffusion_2 + title: Stable Diffusion 2 - local: api/pipelines/stable_diffusion/latent_upscale title: Stable-Diffusion-Latent-Upscaler - - local: api/pipelines/stable_diffusion/pix2pix - title: InstructPix2Pix - - local: api/pipelines/stable_diffusion/attend_and_excite - title: Attend and Excite - - local: api/pipelines/stable_diffusion/pix2pix_zero - title: Pix2Pix Zero - - local: api/pipelines/stable_diffusion/self_attention_guidance - title: Self-Attention Guidance - - local: api/pipelines/stable_diffusion/panorama - title: MultiDiffusion Panorama - - local: api/pipelines/stable_diffusion/controlnet - title: Text-to-Image Generation with ControlNet Conditioning - - local: api/pipelines/stable_diffusion/model_editing - title: Text-to-Image Model Editing + - local: api/pipelines/stable_diffusion/upscale + title: Super-Resolution title: Stable Diffusion - - local: api/pipelines/stable_diffusion_2 - title: Stable Diffusion 2 - local: api/pipelines/stable_unclip title: Stable unCLIP - local: api/pipelines/stochastic_karras_ve title: Stochastic Karras VE + - local: api/pipelines/model_editing + title: Text-to-Image Model Editing - local: api/pipelines/text_to_video title: Text-to-Video - local: api/pipelines/text_to_video_zero @@ -218,6 +234,8 @@ title: UnCLIP - local: api/pipelines/latent_diffusion_uncond title: Unconditional Latent Diffusion + - local: api/pipelines/unidiffuser + title: UniDiffuser - local: api/pipelines/versatile_diffusion title: Versatile Diffusion - local: api/pipelines/vq_diffusion @@ -238,12 +256,16 @@ title: DPM Discrete Scheduler - local: api/schedulers/dpm_discrete_ancestral title: DPM Discrete Scheduler with ancestral sampling + - local: api/schedulers/dpm_sde + title: DPMSolverSDEScheduler - local: api/schedulers/euler_ancestral title: Euler Ancestral Scheduler - local: api/schedulers/euler title: Euler scheduler - local: api/schedulers/heun title: Heun Scheduler + - local: api/schedulers/multistep_dpm_solver_inverse + title: Inverse Multistep DPM-Solver - local: api/schedulers/ipndm title: IPNDM - local: api/schedulers/lms_discrete diff --git a/docs/source/en/api/attnprocessor.mdx b/docs/source/en/api/attnprocessor.mdx new file mode 100644 index 000000000000..7a4812e0961e --- /dev/null +++ b/docs/source/en/api/attnprocessor.mdx @@ -0,0 +1,42 @@ +# Attention Processor + +An attention processor is a class for applying different types of attention mechanisms. + +## AttnProcessor +[[autodoc]] models.attention_processor.AttnProcessor + +## AttnProcessor2_0 +[[autodoc]] models.attention_processor.AttnProcessor2_0 + +## LoRAAttnProcessor +[[autodoc]] models.attention_processor.LoRAAttnProcessor + +## LoRAAttnProcessor2_0 +[[autodoc]] models.attention_processor.LoRAAttnProcessor2_0 + +## CustomDiffusionAttnProcessor +[[autodoc]] models.attention_processor.CustomDiffusionAttnProcessor + +## AttnAddedKVProcessor +[[autodoc]] models.attention_processor.AttnAddedKVProcessor + +## AttnAddedKVProcessor2_0 +[[autodoc]] models.attention_processor.AttnAddedKVProcessor2_0 + +## LoRAAttnAddedKVProcessor +[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor + +## XFormersAttnProcessor +[[autodoc]] models.attention_processor.XFormersAttnProcessor + +## LoRAXFormersAttnProcessor +[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor + +## CustomDiffusionXFormersAttnProcessor +[[autodoc]] models.attention_processor.CustomDiffusionXFormersAttnProcessor + +## SlicedAttnProcessor +[[autodoc]] models.attention_processor.SlicedAttnProcessor + +## SlicedAttnAddedKVProcessor +[[autodoc]] models.attention_processor.SlicedAttnAddedKVProcessor \ No newline at end of file diff --git a/docs/source/en/api/diffusion_pipeline.mdx b/docs/source/en/api/diffusion_pipeline.mdx index 280802d6a89a..a47025a3e94a 100644 --- a/docs/source/en/api/diffusion_pipeline.mdx +++ b/docs/source/en/api/diffusion_pipeline.mdx @@ -12,36 +12,25 @@ specific language governing permissions and limitations under the License. # Pipelines -The [`DiffusionPipeline`] is the easiest way to load any pretrained diffusion pipeline from the [Hub](https://huggingface.co/models?library=diffusers) and to use it in inference. +The [`DiffusionPipeline`] is the easiest way to load any pretrained diffusion pipeline from the [Hub](https://huggingface.co/models?library=diffusers) and use it for inference. - One should not use the Diffusion Pipeline class for training or fine-tuning a diffusion model. Individual - components of diffusion pipelines are usually trained individually, so we suggest to directly work - with [`UNetModel`] and [`UNetConditionModel`]. +You shouldn't use the [`DiffusionPipeline`] class for training or finetuning a diffusion model. Individual +components (for example, [`UNetModel`] and [`UNetConditionModel`]) of diffusion pipelines are usually trained individually, so we suggest directly working with instead. -Any diffusion pipeline that is loaded with [`~DiffusionPipeline.from_pretrained`] will automatically -detect the pipeline type, *e.g.* [`StableDiffusionPipeline`] and consequently load each component of the -pipeline and pass them into the `__init__` function of the pipeline, *e.g.* [`~StableDiffusionPipeline.__init__`]. +The pipeline type (for example [`StableDiffusionPipeline`]) of any diffusion pipeline loaded with [`~DiffusionPipeline.from_pretrained`] is automatically +detected and pipeline components are loaded and passed to the `__init__` function of the pipeline. Any pipeline object can be saved locally with [`~DiffusionPipeline.save_pretrained`]. ## DiffusionPipeline + [[autodoc]] DiffusionPipeline - all - __call__ - device - to - components - -## ImagePipelineOutput -By default diffusion pipelines return an object of class - -[[autodoc]] pipelines.ImagePipelineOutput - -## AudioPipelineOutput -By default diffusion pipelines return an object of class - -[[autodoc]] pipelines.AudioPipelineOutput diff --git a/docs/source/en/api/logging.mdx b/docs/source/en/api/logging.mdx index b52c0434f42d..bb973db781ea 100644 --- a/docs/source/en/api/logging.mdx +++ b/docs/source/en/api/logging.mdx @@ -61,7 +61,7 @@ verbose to the most verbose), those levels (with their corresponding int values critical errors. - `diffusers.logging.ERROR` (int value, 40): only report errors. - `diffusers.logging.WARNING` or `diffusers.logging.WARN` (int value, 30): only reports error and - warnings. This the default level used by the library. + warnings. This is the default level used by the library. - `diffusers.logging.INFO` (int value, 20): reports error, warnings and basic information. - `diffusers.logging.DEBUG` (int value, 10): report all information. diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx index 2361fd4f6597..74291f9173ea 100644 --- a/docs/source/en/api/models.mdx +++ b/docs/source/en/api/models.mdx @@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License. # Models Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models. -The primary function of these models is to denoise an input sample, by modeling the distribution $p_\theta(\mathbf{x}_{t-1}|\mathbf{x}_t)$. +The primary function of these models is to denoise an input sample, by modeling the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\). The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub. ## ModelMixin diff --git a/docs/source/en/api/outputs.mdx b/docs/source/en/api/outputs.mdx index 9466f354541d..1e9fbedba35b 100644 --- a/docs/source/en/api/outputs.mdx +++ b/docs/source/en/api/outputs.mdx @@ -12,11 +12,11 @@ specific language governing permissions and limitations under the License. # BaseOutputs -All models have outputs that are instances of subclasses of [`~utils.BaseOutput`]. Those are -data structures containing all the information returned by the model, but that can also be used as tuples or +All models have outputs that are subclasses of [`~utils.BaseOutput`]. Those are +data structures containing all the information returned by the model, but they can also be used as tuples or dictionaries. -Let's see how this looks in an example: +For example: ```python from diffusers import DDIMPipeline @@ -25,31 +25,45 @@ pipeline = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32") outputs = pipeline() ``` -The `outputs` object is a [`~pipelines.ImagePipelineOutput`], as we can see in the -documentation of that class below, it means it has an image attribute. +The `outputs` object is a [`~pipelines.ImagePipelineOutput`] which means it has an image attribute. -You can access each attribute as you would usually do, and if that attribute has not been returned by the model, you will get `None`: +You can access each attribute as you normally would or with a keyword lookup, and if that attribute is not returned by the model, you will get `None`: ```python outputs.images -``` - -or via keyword lookup - -```python outputs["images"] ``` -When considering our `outputs` object as tuple, it only considers the attributes that don't have `None` values. -Here for instance, we could retrieve images via indexing: +When considering the `outputs` object as a tuple, it only considers the attributes that don't have `None` values. +For instance, retrieving an image by indexing into it returns the tuple `(outputs.images)`: ```python outputs[:1] ``` -which will return the tuple `(outputs.images)` for instance. + + +To check a specific pipeline or model output, refer to its corresponding API documentation. + + ## BaseOutput [[autodoc]] utils.BaseOutput - to_tuple + +## ImagePipelineOutput + +[[autodoc]] pipelines.ImagePipelineOutput + +## FlaxImagePipelineOutput + +[[autodoc]] pipelines.pipeline_flax_utils.FlaxImagePipelineOutput + +## AudioPipelineOutput + +[[autodoc]] pipelines.AudioPipelineOutput + +## ImageTextPipelineOutput + +[[autodoc]] ImageTextPipelineOutput \ No newline at end of file diff --git a/docs/source/en/api/pipelines/stable_diffusion/attend_and_excite.mdx b/docs/source/en/api/pipelines/attend_and_excite.mdx similarity index 100% rename from docs/source/en/api/pipelines/stable_diffusion/attend_and_excite.mdx rename to docs/source/en/api/pipelines/attend_and_excite.mdx diff --git a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx b/docs/source/en/api/pipelines/controlnet.mdx similarity index 67% rename from docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx rename to docs/source/en/api/pipelines/controlnet.mdx index fd5c87821c01..f9e4c3c47e3e 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx +++ b/docs/source/en/api/pipelines/controlnet.mdx @@ -22,7 +22,7 @@ The abstract of the paper is the following: *We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.* -This model was contributed by the amazing community contributor [takuma104](https://huggingface.co/takuma104) ❤️ . +This model was contributed by the community contributor [takuma104](https://huggingface.co/takuma104) ❤️ . Resources: @@ -33,7 +33,9 @@ Resources: | Pipeline | Tasks | Demo |---|---|:---:| -| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) +| [StableDiffusionControlNetPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet.py) | *Text-to-Image Generation with ControlNet Conditioning* | [Colab Example](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) +| [StableDiffusionControlNetImg2ImgPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py) | *Image-to-Image Generation with ControlNet Conditioning* | +| [StableDiffusionControlNetInpaintPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_controlnet_inpaint.py) | *Inpainting Generation with ControlNet Conditioning* | ## Usage example @@ -301,21 +303,22 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h ### ControlNet v1.1 -| Model Name | Control Image Overview| Control Image Example | Generated Image Example | -|---|---|---|---| -|[lllyasviel/control_v11p_sd15_canny](https://huggingface.co/lllyasviel/control_v11p_sd15_canny)
*Trained with canny edge detection* | A monochrome image with white edges on a black background.||| -|[lllyasviel/control_v11e_sd15_ip2p](https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p)
*Trained with pixel to pixel instruction* | No condition .||| -|[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint)
Trained with image inpainting | No condition.||| -|[lllyasviel/control_v11p_sd15_mlsd](https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd)
Trained with multi-level line segment detection | An image with annotated line segments.||| -|[lllyasviel/control_v11f1p_sd15_depth](https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth)
Trained with depth estimation | An image with depth information, usually represented as a grayscale image.||| -|[lllyasviel/control_v11p_sd15_normalbae](https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae)
Trained with surface normal estimation | An image with surface normal information, usually represented as a color-coded image.||| -|[lllyasviel/control_v11p_sd15_seg](https://huggingface.co/lllyasviel/control_v11p_sd15_seg)
Trained with image segmentation | An image with segmented regions, usually represented as a color-coded image.||| -|[lllyasviel/control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart)
Trained with line art generation | An image with line art, usually black lines on a white background.||| -|[lllyasviel/control_v11p_sd15s2_lineart_anime](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
Trained with anime line art generation | An image with anime-style line art.||| -|[lllyasviel/control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
Trained with human pose estimation | An image with human poses, usually represented as a set of keypoints or skeletons.||| -|[lllyasviel/control_v11p_sd15_scribble](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble)
Trained with scribble-based image generation | An image with scribbles, usually random or user-drawn strokes.||| -|[lllyasviel/control_v11p_sd15_softedge](https://huggingface.co/lllyasviel/control_v11p_sd15_softedge)
Trained with soft edge image generation | An image with soft edges, usually to create a more painterly or artistic effect.||| -|[lllyasviel/control_v11e_sd15_shuffle](https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle)
Trained with image shuffling | An image with shuffled patches or regions.||| +| Model Name | Control Image Overview| Condition Image | Control Image Example | Generated Image Example | +|---|---|---|---|---| +|[lllyasviel/control_v11p_sd15_canny](https://huggingface.co/lllyasviel/control_v11p_sd15_canny)
| *Trained with canny edge detection* | A monochrome image with white edges on a black background.||| +|[lllyasviel/control_v11e_sd15_ip2p](https://huggingface.co/lllyasviel/control_v11e_sd15_ip2p)
| *Trained with pixel to pixel instruction* | No condition .||| +|[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint)
| Trained with image inpainting | No condition.||| +|[lllyasviel/control_v11p_sd15_mlsd](https://huggingface.co/lllyasviel/control_v11p_sd15_mlsd)
| Trained with multi-level line segment detection | An image with annotated line segments.||| +|[lllyasviel/control_v11f1p_sd15_depth](https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth)
| Trained with depth estimation | An image with depth information, usually represented as a grayscale image.||| +|[lllyasviel/control_v11p_sd15_normalbae](https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae)
| Trained with surface normal estimation | An image with surface normal information, usually represented as a color-coded image.||| +|[lllyasviel/control_v11p_sd15_seg](https://huggingface.co/lllyasviel/control_v11p_sd15_seg)
| Trained with image segmentation | An image with segmented regions, usually represented as a color-coded image.||| +|[lllyasviel/control_v11p_sd15_lineart](https://huggingface.co/lllyasviel/control_v11p_sd15_lineart)
| Trained with line art generation | An image with line art, usually black lines on a white background.||| +|[lllyasviel/control_v11p_sd15s2_lineart_anime](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
| Trained with anime line art generation | An image with anime-style line art.||| +|[lllyasviel/control_v11p_sd15_openpose](https://huggingface.co/lllyasviel/control_v11p_sd15s2_lineart_anime)
| Trained with human pose estimation | An image with human poses, usually represented as a set of keypoints or skeletons.||| +|[lllyasviel/control_v11p_sd15_scribble](https://huggingface.co/lllyasviel/control_v11p_sd15_scribble)
| Trained with scribble-based image generation | An image with scribbles, usually random or user-drawn strokes.||| +|[lllyasviel/control_v11p_sd15_softedge](https://huggingface.co/lllyasviel/control_v11p_sd15_softedge)
| Trained with soft edge image generation | An image with soft edges, usually to create a more painterly or artistic effect.||| +|[lllyasviel/control_v11e_sd15_shuffle](https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle)
| Trained with image shuffling | An image with shuffled patches or regions.||| +|[lllyasviel/control_v11f1e_sd15_tile](https://huggingface.co/lllyasviel/control_v11f1e_sd15_tile)
| Trained with image tiling | A blurry image or part of an image .||| ## StableDiffusionControlNetPipeline [[autodoc]] StableDiffusionControlNetPipeline @@ -329,6 +332,30 @@ All checkpoints can be found under the authors' namespace [lllyasviel](https://h - disable_xformers_memory_efficient_attention - load_textual_inversion +## StableDiffusionControlNetImg2ImgPipeline +[[autodoc]] StableDiffusionControlNetImg2ImgPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention + - load_textual_inversion + +## StableDiffusionControlNetInpaintPipeline +[[autodoc]] StableDiffusionControlNetInpaintPipeline + - all + - __call__ + - enable_attention_slicing + - disable_attention_slicing + - enable_vae_slicing + - disable_vae_slicing + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention + - load_textual_inversion + ## FlaxStableDiffusionControlNetPipeline [[autodoc]] FlaxStableDiffusionControlNetPipeline - all diff --git a/docs/source/en/api/pipelines/diffedit.mdx b/docs/source/en/api/pipelines/diffedit.mdx new file mode 100644 index 000000000000..a7cd906e0e77 --- /dev/null +++ b/docs/source/en/api/pipelines/diffedit.mdx @@ -0,0 +1,360 @@ + + +# Zero-shot Diffusion-based Semantic Image Editing with Mask Guidance + +## Overview + +[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427) by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord. + +The abstract of the paper is the following: + +*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.* + +Resources: + +* [Paper](https://arxiv.org/abs/2210.11427). +* [Blog Post with Demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html). +* [Implementation on Github](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/). + +## Tips + +* The pipeline can generate masks that can be fed into other inpainting pipelines. Check out the code examples below to know more. +* In order to generate an image using this pipeline, both an image mask (manually specified or generated using `generate_mask`) +and a set of partially inverted latents (generated using `invert`) _must_ be provided as arguments when calling the pipeline to generate the final edited image. +Refer to the code examples below for more details. +* The function `generate_mask` exposes two prompt arguments, `source_prompt` and `target_prompt`, +that let you control the locations of the semantic edits in the final image to be generated. Let's say, +you wanted to translate from "cat" to "dog". In this case, the edit direction will be "cat -> dog". To reflect +this in the generated mask, you simply have to set the embeddings related to the phrases including "cat" to +`source_prompt_embeds` and "dog" to `target_prompt_embeds`. Refer to the code example below for more details. +* When generating partially inverted latents using `invert`, assign a caption or text embedding describing the +overall image to the `prompt` argument to help guide the inverse latent sampling process. In most cases, the +source concept is sufficently descriptive to yield good results, but feel free to explore alternatives. +Please refer to [this code example](#generating-image-captions-for-inversion) for more details. +* When calling the pipeline to generate the final edited image, assign the source concept to `negative_prompt` +and the target concept to `prompt`. Taking the above example, you simply have to set the embeddings related to +the phrases including "cat" to `negative_prompt_embeds` and "dog" to `prompt_embeds`. Refer to the code example +below for more details. +* If you wanted to reverse the direction in the example above, i.e., "dog -> cat", then it's recommended to: + * Swap the `source_prompt` and `target_prompt` in the arguments to `generate_mask`. + * Change the input prompt for `invert` to include "dog". + * Swap the `prompt` and `negative_prompt` in the arguments to call the pipeline to generate the final edited image. +* Note that the source and target prompts, or their corresponding embeddings, can also be automatically generated. Please, refer to [this discussion](#generating-source-and-target-embeddings) for more details. + +## Available Pipelines: + +| Pipeline | Tasks +|---|---| +| [StableDiffusionDiffEditPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py) | *Text-Based Image Editing* + + + +## Usage example + +### Based on an input image with a caption + +When the pipeline is conditioned on an input image, we first obtain partially inverted latents from the input image using a +`DDIMInverseScheduler` with the help of a caption. Then we generate an editing mask to identify relevant regions in the image using the source and target prompts. Finally, +the inverted noise and generated mask is used to start the generation process. + +First, let's load our pipeline: + +```py +import torch +from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionPix2PixZeroPipeline + +sd_model_ckpt = "stabilityai/stable-diffusion-2-1" +pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + sd_model_ckpt, + torch_dtype=torch.float16, + safety_checker=None, +) +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) +pipeline.enable_model_cpu_offload() +pipeline.enable_vae_slicing() +generator = torch.manual_seed(0) +``` + +Then, we load an input image to edit using our method: + +```py +from diffusers.utils import load_image + +img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" +raw_image = load_image(img_url).convert("RGB").resize((768, 768)) +``` + +Then, we employ the source and target prompts to generate the editing mask: + +```py +# See the "Generating source and target embeddings" section below to +# automate the generation of these captions with a pre-trained model like Flan-T5 as explained below. + +source_prompt = "a bowl of fruits" +target_prompt = "a basket of fruits" +mask_image = pipeline.generate_mask( + image=raw_image, + source_prompt=source_prompt, + target_prompt=target_prompt, + generator=generator, +) +``` + +Then, we employ the caption and the input image to get the inverted latents: + +```py +inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image, generator=generator).latents +``` + +Now, generate the image with the inverted latents and semantically generated mask: + +```py +image = pipeline( + prompt=target_prompt, + mask_image=mask_image, + image_latents=inv_latents, + generator=generator, + negative_prompt=source_prompt, +).images[0] +image.save("edited_image.png") +``` + +## Generating image captions for inversion + +The authors originally used the source concept prompt as the caption for generating the partially inverted latents. However, we can also leverage open source and public image captioning models for the same purpose. +Below, we provide an end-to-end example with the [BLIP](https://huggingface.co/docs/transformers/model_doc/blip) model +for generating captions. + +First, let's load our automatic image captioning model: + +```py +import torch +from transformers import BlipForConditionalGeneration, BlipProcessor + +captioner_id = "Salesforce/blip-image-captioning-base" +processor = BlipProcessor.from_pretrained(captioner_id) +model = BlipForConditionalGeneration.from_pretrained(captioner_id, torch_dtype=torch.float16, low_cpu_mem_usage=True) +``` + +Then, we define a utility to generate captions from an input image using the model: + +```py +@torch.no_grad() +def generate_caption(images, caption_generator, caption_processor): + text = "a photograph of" + + inputs = caption_processor(images, text, return_tensors="pt").to(device="cuda", dtype=caption_generator.dtype) + caption_generator.to("cuda") + outputs = caption_generator.generate(**inputs, max_new_tokens=128) + + # offload caption generator + caption_generator.to("cpu") + + caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0] + return caption +``` + +Then, we load an input image for conditioning and obtain a suitable caption for it: + +```py +from diffusers.utils import load_image + +img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" +raw_image = load_image(img_url).convert("RGB").resize((768, 768)) +caption = generate_caption(raw_image, model, processor) +``` + +Then, we employ the generated caption and the input image to get the inverted latents: + +```py +from diffusers import DDIMInverseScheduler, DDIMScheduler + +pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 +) +pipeline = pipeline.to("cuda") +pipeline.enable_model_cpu_offload() +pipeline.enable_vae_slicing() + +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + +generator = torch.manual_seed(0) +inv_latents = pipeline.invert(prompt=caption, image=raw_image, generator=generator).latents +``` + +Now, generate the image with the inverted latents and semantically generated mask from our source and target prompts: + +```py +source_prompt = "a bowl of fruits" +target_prompt = "a basket of fruits" + +mask_image = pipeline.generate_mask( + image=raw_image, + source_prompt=source_prompt, + target_prompt=target_prompt, + generator=generator, +) + +image = pipeline( + prompt=target_prompt, + mask_image=mask_image, + image_latents=inv_latents, + generator=generator, + negative_prompt=source_prompt, +).images[0] +image.save("edited_image.png") +``` + +## Generating source and target embeddings + +The authors originally required the user to manually provide the source and target prompts for discovering +edit directions. However, we can also leverage open source and public models for the same purpose. +Below, we provide an end-to-end example with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model +for generating source an target embeddings. + +**1. Load the generation model**: + +```py +import torch +from transformers import AutoTokenizer, T5ForConditionalGeneration + +tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl") +model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl", device_map="auto", torch_dtype=torch.float16) +``` + +**2. Construct a starting prompt**: + +```py +source_concept = "bowl" +target_concept = "basket" + +source_text = f"Provide a caption for images containing a {source_concept}. " +"The captions should be in English and should be no longer than 150 characters." + +target_text = f"Provide a caption for images containing a {target_concept}. " +"The captions should be in English and should be no longer than 150 characters." +``` + +Here, we're interested in the "bowl -> basket" direction. + +**3. Generate prompts**: + +We can use a utility like so for this purpose. + +```py +@torch.no_grad +def generate_prompts(input_prompt): + input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda") + + outputs = model.generate( + input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10 + ) + return tokenizer.batch_decode(outputs, skip_special_tokens=True) +``` + +And then we just call it to generate our prompts: + +```py +source_prompts = generate_prompts(source_text) +target_prompts = generate_prompts(target_text) +``` + +We encourage you to play around with the different parameters supported by the +`generate()` method ([documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.generation_tf_utils.TFGenerationMixin.generate)) for the generation quality you are looking for. + +**4. Load the embedding model**: + +Here, we need to use the same text encoder model used by the subsequent Stable Diffusion model. + +```py +from diffusers import StableDiffusionDiffEditPipeline + +pipeline = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 +) +pipeline = pipeline.to("cuda") +pipeline.enable_model_cpu_offload() +pipeline.enable_vae_slicing() + +generator = torch.manual_seed(0) +``` + +**5. Compute embeddings**: + +```py +import torch + +@torch.no_grad() +def embed_prompts(sentences, tokenizer, text_encoder, device="cuda"): + embeddings = [] + for sent in sentences: + text_inputs = tokenizer( + sent, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0] + embeddings.append(prompt_embeds) + return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0) + +source_embeddings = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder) +target_embeddings = embed_prompts(target_captions, pipeline.tokenizer, pipeline.text_encoder) +``` + +And you're done! Now, you can use these embeddings directly while calling the pipeline: + +```py +from diffusers import DDIMInverseScheduler, DDIMScheduler +from diffusers.utils import load_image + +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + +img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" +raw_image = load_image(img_url).convert("RGB").resize((768, 768)) + + +mask_image = pipeline.generate_mask( + image=raw_image, + source_prompt_embeds=source_embeds, + target_prompt_embeds=target_embeds, + generator=generator, +) + +inv_latents = pipeline.invert( + prompt_embeds=source_embeds, + image=raw_image, + generator=generator, +).latents + +images = pipeline( + mask_image=mask_image, + image_latents=inv_latents, + prompt_embeds=target_embeddings, + negative_prompt_embeds=source_embeddings, + generator=generator, +).images +images[0].save("edited_image.png") +``` + +## StableDiffusionDiffEditPipeline +[[autodoc]] StableDiffusionDiffEditPipeline + - all + - generate_mask + - invert + - __call__ \ No newline at end of file diff --git a/docs/source/en/api/pipelines/kandinsky.mdx b/docs/source/en/api/pipelines/kandinsky.mdx new file mode 100644 index 000000000000..b94937e4af85 --- /dev/null +++ b/docs/source/en/api/pipelines/kandinsky.mdx @@ -0,0 +1,351 @@ + + +# Kandinsky + +## Overview + +Kandinsky 2.1 inherits best practices from [DALL-E 2](https://arxiv.org/abs/2204.06125) and [Latent Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/latent_diffusion), while introducing some new ideas. + +It uses [CLIP](https://huggingface.co/docs/transformers/model_doc/clip) for encoding images and text, and a diffusion image prior (mapping) between latent spaces of CLIP modalities. This approach enhances the visual performance of the model and unveils new horizons in blending images and text-guided image manipulation. + +The Kandinsky model is created by [Arseniy Shakhmatov](https://github.com/cene555), [Anton Razzhigaev](https://github.com/razzant), [Aleksandr Nikolich](https://github.com/AlexWortega), [Igor Pavlov](https://github.com/boomb0om), [Andrey Kuznetsov](https://github.com/kuznetsoffandrey) and [Denis Dimitrov](https://github.com/denndimitrov) and the original codebase can be found [here](https://github.com/ai-forever/Kandinsky-2) + +## Available Pipelines: + +| Pipeline | Tasks | +|---|---| +| [pipeline_kandinsky.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py) | *Text-to-Image Generation* | +| [pipeline_kandinsky_inpaint.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py) | *Image-Guided Image Generation* | +| [pipeline_kandinsky_img2img.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py) | *Image-Guided Image Generation* | + +## Usage example + +In the following, we will walk you through some examples of how to use the Kandinsky pipelines to create some visually aesthetic artwork. + +### Text-to-Image Generation + +For text-to-image generation, we need to use both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`]. +The first step is to encode text prompts with CLIP and then diffuse the CLIP text embeddings to CLIP image embeddings, +as first proposed in [DALL-E 2](https://cdn.openai.com/papers/dall-e-2.pdf). +Let's throw a fun prompt at Kandinsky to see what it comes up with. + +```py +prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting" +``` + +First, let's instantiate the prior pipeline and the text-to-image pipeline. Both +pipelines are diffusion models. + + +```py +from diffusers import DiffusionPipeline +import torch + +pipe_prior = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16) +pipe_prior.to("cuda") + +t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) +t2i_pipe.to("cuda") +``` + +Now we pass the prompt through the prior to generate image embeddings. The prior +returns both the image embeddings corresponding to the prompt and negative/unconditional image +embeddings corresponding to an empty string. + +```py +generator = torch.Generator(device="cuda").manual_seed(12) +image_embeds, negative_image_embeds = pipe_prior(prompt, generator=generator).to_tuple() +``` + + + +The text-to-image pipeline expects both `image_embeds`, `negative_image_embeds` and the original +`prompt` as the text-to-image pipeline uses another text encoder to better guide the second diffusion +process of `t2i_pipe`. + +By default, the prior returns unconditioned negative image embeddings corresponding to the negative prompt of `""`. +For better results, you can also pass a `negative_prompt` to the prior. This will increase the effective batch size +of the prior by a factor of 2. + +```py +prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting" +negative_prompt = "low quality, bad quality" + +image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple() +``` + + + + +Next, we can pass the embeddings as well as the prompt to the text-to-image pipeline. Remember that +in case you are using a customized negative prompt, that you should pass this one also to the text-to-image pipelines +with `negative_prompt=negative_prompt`: + +```py +image = t2i_pipe(prompt, image_embeds=image_embeds, negative_image_embeds=negative_image_embeds).images[0] +image.save("cheeseburger_monster.png") +``` + +One cheeseburger monster coming up! Enjoy! + +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/cheeseburger.png) + +The Kandinsky model works extremely well with creative prompts. Here is some of the amazing art that can be created using the exact same process but with different prompts. + +```python +prompt = "bird eye view shot of a full body woman with cyan light orange magenta makeup, digital art, long braided hair her face separated by makeup in the style of yin Yang surrealism, symmetrical face, real image, contrasting tone, pastel gradient background" +``` +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/hair.png) + +```python +prompt = "A car exploding into colorful dust" +``` +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/dusts.png) + +```python +prompt = "editorial photography of an organic, almost liquid smoke style armchair" +``` +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/smokechair.png) + +```python +prompt = "birds eye view of a quilted paper style alien planet landscape, vibrant colours, Cinematic lighting" +``` +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/alienplanet.png) + + +### Text Guided Image-to-Image Generation + +The same Kandinsky model weights can be used for text-guided image-to-image translation. In this case, just make sure to load the weights using the [`KandinskyImg2ImgPipeline`] pipeline. + +**Note**: You can also directly move the weights of the text-to-image pipelines to the image-to-image pipelines +without loading them twice by making use of the [`~DiffusionPipeline.components`] function as explained [here](#converting-between-different-pipelines). + +Let's download an image. + +```python +from PIL import Image +import requests +from io import BytesIO + +# download image +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" +response = requests.get(url) +original_image = Image.open(BytesIO(response.content)).convert("RGB") +original_image = original_image.resize((768, 512)) +``` + +![img](https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg) + +```python +import torch +from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline + +# create prior +pipe_prior = KandinskyPriorPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 +) +pipe_prior.to("cuda") + +# create img2img pipeline +pipe = KandinskyImg2ImgPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) +pipe.to("cuda") + +prompt = "A fantasy landscape, Cinematic lighting" +negative_prompt = "low quality, bad quality" + +generator = torch.Generator(device="cuda").manual_seed(30) +image_embeds, negative_image_embeds = pipe_prior(prompt, negative_prompt, generator=generator).to_tuple() + +out = pipe( + prompt, + image=original_image, + image_embeds=image_embeds, + negative_image_embeds=negative_image_embeds, + height=768, + width=768, + strength=0.3, +) + +out.images[0].save("fantasy_land.png") +``` + +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/img2img_fantasyland.png) + + +### Text Guided Inpainting Generation + +You can use [`KandinskyInpaintPipeline`] to edit images. In this example, we will add a hat to the portrait of a cat. + +```py +from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline +from diffusers.utils import load_image +import torch +import numpy as np + +pipe_prior = KandinskyPriorPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 +) +pipe_prior.to("cuda") + +prompt = "a hat" +prior_output = pipe_prior(prompt) + +pipe = KandinskyInpaintPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16) +pipe.to("cuda") + +init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" +) + +mask = np.ones((768, 768), dtype=np.float32) +# Let's mask out an area above the cat's head +mask[:250, 250:-250] = 0 + +out = pipe( + prompt, + image=init_image, + mask_image=mask, + **prior_output, + height=768, + width=768, + num_inference_steps=150, +) + +image = out.images[0] +image.save("cat_with_hat.png") +``` +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/inpaint_cat_hat.png) + +### Interpolate + +The [`KandinskyPriorPipeline`] also comes with a cool utility function that will allow you to interpolate the latent space of different images and texts super easily. Here is an example of how you can create an Impressionist-style portrait for your pet based on "The Starry Night". + +Note that you can interpolate between texts and images - in the below example, we passed a text prompt "a cat" and two images to the `interplate` function, along with a `weights` variable containing the corresponding weights for each condition we interplate. + +```python +from diffusers import KandinskyPriorPipeline, KandinskyPipeline +from diffusers.utils import load_image +import PIL + +import torch + +pipe_prior = KandinskyPriorPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 +) +pipe_prior.to("cuda") + +img1 = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" +) + +img2 = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/starry_night.jpeg" +) + +# add all the conditions we want to interpolate, can be either text or image +images_texts = ["a cat", img1, img2] + +# specify the weights for each condition in images_texts +weights = [0.3, 0.3, 0.4] + +# We can leave the prompt empty +prompt = "" +prior_out = pipe_prior.interpolate(images_texts, weights) + +pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) +pipe.to("cuda") + +image = pipe(prompt, **prior_out, height=768, width=768).images[0] + +image.save("starry_cat.png") +``` +![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/starry_cat.png) + + +## Optimization + +Running Kandinsky in inference requires running both a first prior pipeline: [`KandinskyPriorPipeline`] +and a second image decoding pipeline which is one of [`KandinskyPipeline`], [`KandinskyImg2ImgPipeline`], or [`KandinskyInpaintPipeline`]. + +The bulk of the computation time will always be the second image decoding pipeline, so when looking +into optimizing the model, one should look into the second image decoding pipeline. + +When running with PyTorch < 2.0, we strongly recommend making use of [`xformers`](https://github.com/facebookresearch/xformers) +to speed-up the optimization. This can be done by simply running: + +```py +from diffusers import DiffusionPipeline +import torch + +t2i_pipe = DiffusionPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) +t2i_pipe.enable_xformers_memory_efficient_attention() +``` + +When running on PyTorch >= 2.0, PyTorch's SDPA attention will automatically be used. For more information on +PyTorch's SDPA, feel free to have a look at [this blog post](https://pytorch.org/blog/accelerated-diffusers-pt-20/). + +To have explicit control , you can also manually set the pipeline to use PyTorch's 2.0 efficient attention: + +```py +from diffusers.models.attention_processor import AttnAddedKVProcessor2_0 + +t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor2_0()) +``` + +The slowest and most memory intense attention processor is the default `AttnAddedKVProcessor` processor. +We do **not** recommend using it except for testing purposes or cases where very high determistic behaviour is desired. +You can set it with: + +```py +from diffusers.models.attention_processor import AttnAddedKVProcessor + +t2i_pipe.unet.set_attn_processor(AttnAddedKVProcessor()) +``` + +With PyTorch >= 2.0, you can also use Kandinsky with `torch.compile` which depending +on your hardware can signficantly speed-up your inference time once the model is compiled. +To use Kandinsksy with `torch.compile`, you can do: + +```py +t2i_pipe.unet.to(memory_format=torch.channels_last) +t2i_pipe.unet = torch.compile(t2i_pipe.unet, mode="reduce-overhead", fullgraph=True) +``` + +After compilation you should see a very fast inference time. For more information, +feel free to have a look at [Our PyTorch 2.0 benchmark](https://huggingface.co/docs/diffusers/main/en/optimization/torch2.0). + + + + + +## KandinskyPriorPipeline + +[[autodoc]] KandinskyPriorPipeline + - all + - __call__ + - interpolate + +## KandinskyPipeline + +[[autodoc]] KandinskyPipeline + - all + - __call__ + +## KandinskyImg2ImgPipeline + +[[autodoc]] KandinskyImg2ImgPipeline + - all + - __call__ + +## KandinskyInpaintPipeline + +[[autodoc]] KandinskyInpaintPipeline + - all + - __call__ diff --git a/docs/source/en/api/pipelines/stable_diffusion/model_editing.mdx b/docs/source/en/api/pipelines/model_editing.mdx similarity index 100% rename from docs/source/en/api/pipelines/stable_diffusion/model_editing.mdx rename to docs/source/en/api/pipelines/model_editing.mdx diff --git a/docs/source/en/api/pipelines/overview.mdx b/docs/source/en/api/pipelines/overview.mdx index 91716784f8fe..0ae3d897a3b1 100644 --- a/docs/source/en/api/pipelines/overview.mdx +++ b/docs/source/en/api/pipelines/overview.mdx @@ -46,7 +46,7 @@ available a colab notebook to directly try them out. |---|---|:---:|:---:| | [alt_diffusion](./alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | - | [audio_diffusion](./audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio_diffusion.git) | Unconditional Audio Generation | -| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) +| [controlnet](./api/pipelines/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) | [cycle_diffusion](./cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | @@ -113,105 +113,3 @@ each pipeline, one should look directly into the respective pipeline. **Note**: All pipelines have PyTorch's autograd disabled by decorating the `__call__` method with a [`torch.no_grad`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) decorator because pipelines should not be used for training. If you want to store the gradients during the forward pass, we recommend writing your own pipeline, see also our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). - -## Contribution - -We are more than happy about any contribution to the officially supported pipelines 🤗. We aspire -all of our pipelines to be **self-contained**, **easy-to-tweak**, **beginner-friendly** and for **one-purpose-only**. - -- **Self-contained**: A pipeline shall be as self-contained as possible. More specifically, this means that all functionality should be either directly defined in the pipeline file itself, should be inherited from (and only from) the [`DiffusionPipeline` class](.../diffusion_pipeline) or be directly attached to the model and scheduler components of the pipeline. -- **Easy-to-use**: Pipelines should be extremely easy to use - one should be able to load the pipeline and -use it for its designated task, *e.g.* text-to-image generation, in just a couple of lines of code. Most -logic including pre-processing, an unrolled diffusion loop, and post-processing should all happen inside the `__call__` method. -- **Easy-to-tweak**: Certain pipelines will not be able to handle all use cases and tasks that you might like them to. If you want to use a certain pipeline for a specific use case that is not yet supported, you might have to copy the pipeline file and tweak the code to your needs. We try to make the pipeline code as readable as possible so that each part –from pre-processing to diffusing to post-processing– can easily be adapted. If you would like the community to benefit from your customized pipeline, we would love to see a contribution to our [community-examples](https://github.com/huggingface/diffusers/tree/main/examples/community). If you feel that an important pipeline should be part of the official pipelines but isn't, a contribution to the [official pipelines](./overview) would be even better. -- **One-purpose-only**: Pipelines should be used for one task and one task only. Even if two tasks are very similar from a modeling point of view, *e.g.* image2image translation and in-painting, pipelines shall be used for one task only to keep them *easy-to-tweak* and *readable*. - -## Examples - -### Text-to-Image generation with Stable Diffusion - -```python -# make sure you're logged in with `huggingface-cli login` -from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler - -pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") -pipe = pipe.to("cuda") - -prompt = "a photo of an astronaut riding a horse on mars" -image = pipe(prompt).images[0] - -image.save("astronaut_rides_horse.png") -``` - -### Image-to-Image text-guided generation with Stable Diffusion - -The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images. - -```python -import requests -from PIL import Image -from io import BytesIO - -from diffusers import StableDiffusionImg2ImgPipeline - -# load the pipeline -device = "cuda" -pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to( - device -) - -# let's download an initial image -url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" - -response = requests.get(url) -init_image = Image.open(BytesIO(response.content)).convert("RGB") -init_image = init_image.resize((768, 512)) - -prompt = "A fantasy landscape, trending on artstation" - -images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images - -images[0].save("fantasy_landscape.png") -``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) - -### Tweak prompts reusing seeds and latents - -You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](https://github.com/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb) - - -### In-painting using Stable Diffusion - -The `StableDiffusionInpaintPipeline` lets you edit specific parts of an image by providing a mask and text prompt. - -```python -import PIL -import requests -import torch -from io import BytesIO - -from diffusers import StableDiffusionInpaintPipeline - - -def download_image(url): - response = requests.get(url) - return PIL.Image.open(BytesIO(response.content)).convert("RGB") - - -img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" -mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" - -init_image = download_image(img_url).resize((512, 512)) -mask_image = download_image(mask_url).resize((512, 512)) - -pipe = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-inpainting", - torch_dtype=torch.float16, -) -pipe = pipe.to("cuda") - -prompt = "Face of a yellow cat, high resolution, sitting on a park bench" -image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] -``` - -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) diff --git a/docs/source/en/api/pipelines/stable_diffusion/panorama.mdx b/docs/source/en/api/pipelines/panorama.mdx similarity index 93% rename from docs/source/en/api/pipelines/stable_diffusion/panorama.mdx rename to docs/source/en/api/pipelines/panorama.mdx index e0c7747a0193..044901f24bf3 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/panorama.mdx +++ b/docs/source/en/api/pipelines/panorama.mdx @@ -52,6 +52,14 @@ image = pipe(prompt).images[0] image.save("dolomites.png") ``` + + +While calling this pipeline, it's possible to specify the `view_batch_size` to have a >1 value. +For some GPUs with high performance, higher a `view_batch_size`, can speedup the generation +and increase the VRAM usage. + + + ## StableDiffusionPanoramaPipeline [[autodoc]] StableDiffusionPanoramaPipeline - __call__ diff --git a/docs/source/en/api/pipelines/stable_diffusion/pix2pix.mdx b/docs/source/en/api/pipelines/pix2pix.mdx similarity index 100% rename from docs/source/en/api/pipelines/stable_diffusion/pix2pix.mdx rename to docs/source/en/api/pipelines/pix2pix.mdx diff --git a/docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx b/docs/source/en/api/pipelines/pix2pix_zero.mdx similarity index 100% rename from docs/source/en/api/pipelines/stable_diffusion/pix2pix_zero.mdx rename to docs/source/en/api/pipelines/pix2pix_zero.mdx diff --git a/docs/source/en/api/pipelines/repaint.mdx b/docs/source/en/api/pipelines/repaint.mdx index 927398d0bf54..895d3011883c 100644 --- a/docs/source/en/api/pipelines/repaint.mdx +++ b/docs/source/en/api/pipelines/repaint.mdx @@ -60,7 +60,7 @@ pipe = pipe.to("cuda") generator = torch.Generator(device="cuda").manual_seed(0) output = pipe( - original_image=original_image, + image=original_image, mask_image=mask_image, num_inference_steps=250, eta=0.0, diff --git a/docs/source/en/api/pipelines/stable_diffusion/self_attention_guidance.mdx b/docs/source/en/api/pipelines/self_attention_guidance.mdx similarity index 100% rename from docs/source/en/api/pipelines/stable_diffusion/self_attention_guidance.mdx rename to docs/source/en/api/pipelines/self_attention_guidance.mdx diff --git a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx index 70731fd294b9..a163b57f2a84 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/overview.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/overview.mdx @@ -36,6 +36,7 @@ For more details about how Stable Diffusion works and how it differs from the ba | [StableDiffusionAttendAndExcitePipeline](./attend_and_excite) | **Experimental** – *Text-to-Image Generation * | | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://huggingface.co/spaces/AttendAndExcite/Attend-and-Excite) | [StableDiffusionPix2PixZeroPipeline](./pix2pix_zero) | **Experimental** – *Text-Based Image Editing * | | [Zero-shot Image-to-Image Translation](https://arxiv.org/abs/2302.03027) | [StableDiffusionModelEditingPipeline](./model_editing) | **Experimental** – *Text-to-Image Model Editing * | | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://arxiv.org/abs/2303.08084) +| [StableDiffusionDiffEditPipeline](./diffedit) | **Experimental** – *Text-Based Image Editing * | | [DiffEdit: Diffusion-based semantic image editing with mask guidance](https://arxiv.org/abs/2210.11427) diff --git a/docs/source/en/api/pipelines/stable_diffusion_2.mdx b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx similarity index 73% rename from docs/source/en/api/pipelines/stable_diffusion_2.mdx rename to docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx index e922072e4e31..7162626ebbde 100644 --- a/docs/source/en/api/pipelines/stable_diffusion_2.mdx +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_2.mdx @@ -71,6 +71,64 @@ image = pipe(prompt, guidance_scale=9, num_inference_steps=25).images[0] image.save("astronaut.png") ``` +#### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed": + +The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)** +claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion. + +The abstract reads as follows: + +*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR), +and some implementations of diffusion samplers do not start from the last timestep. +Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference. +We show that the flawed design causes real problems in existing implementations. +In Stable Diffusion, it severely limits the model to only generate images with medium brightness and +prevents it from generating very bright and dark samples. We propose a few simple fixes: +- (1) rescale the noise schedule to enforce zero terminal SNR; +- (2) train the model with v prediction; +- (3) change the sampler to always start from the last timestep; +- (4) rescale classifier-free guidance to prevent over-exposure. +These simple changes ensure the diffusion process is congruent between training and inference and +allow the model to generate samples more faithful to the original data distribution.* + +You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]: +- (1) rescale the noise schedule to enforce zero terminal SNR; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True) +``` +- (2) train the model with v prediction; +Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) +and `--prediction_type="v_prediction"`. +- (3) change the sampler to always start from the last timestep; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling="trailing") +``` +- (4) rescale classifier-free guidance to prevent over-exposure. +```py +pipe(..., guidance_rescale=0.7) +``` + +An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2) +which has been fine-tuned using the `"v_prediction"`. + +The checkpoint can then be run in inference as follows: + +```py +from diffusers import DiffusionPipeline, DDIMScheduler + +pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16) +pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_scaling="trailing" +) +pipe.to("cuda") + +prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" +image = pipeline(prompt, guidance_rescale=0.7).images[0] +``` + +## DDIMScheduler +[[autodoc]] DDIMScheduler + ### Image Inpainting - *Image Inpainting (512x512 resolution)*: [stabilityai/stable-diffusion-2-inpainting](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) with [`StableDiffusionInpaintPipeline`] diff --git a/docs/source/en/api/pipelines/stable_diffusion_safe.mdx b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.mdx similarity index 100% rename from docs/source/en/api/pipelines/stable_diffusion_safe.mdx rename to docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.mdx diff --git a/docs/source/en/api/pipelines/unidiffuser.mdx b/docs/source/en/api/pipelines/unidiffuser.mdx new file mode 100644 index 000000000000..10290e263e6d --- /dev/null +++ b/docs/source/en/api/pipelines/unidiffuser.mdx @@ -0,0 +1,204 @@ + + +# UniDiffuser + +The UniDiffuser model was proposed in [One Transformer Fits All Distributions in Multi-Modal Diffusion at Scale](https://arxiv.org/abs/2303.06555) by Fan Bao, Shen Nie, Kaiwen Xue, Chongxuan Li, Shi Pu, Yaole Wang, Gang Yue, Yue Cao, Hang Su, Jun Zhu. + +The abstract of the [paper](https://arxiv.org/abs/2303.06555) is the following: + +*This paper proposes a unified diffusion framework (dubbed UniDiffuser) to fit all distributions relevant to a set of multi-modal data in one model. Our key insight is -- learning diffusion models for marginal, conditional, and joint distributions can be unified as predicting the noise in the perturbed data, where the perturbation levels (i.e. timesteps) can be different for different modalities. Inspired by the unified view, UniDiffuser learns all distributions simultaneously with a minimal modification to the original diffusion model -- perturbs data in all modalities instead of a single modality, inputs individual timesteps in different modalities, and predicts the noise of all modalities instead of a single modality. UniDiffuser is parameterized by a transformer for diffusion models to handle input types of different modalities. Implemented on large-scale paired image-text data, UniDiffuser is able to perform image, text, text-to-image, image-to-text, and image-text pair generation by setting proper timesteps without additional overhead. In particular, UniDiffuser is able to produce perceptually realistic samples in all tasks and its quantitative results (e.g., the FID and CLIP score) are not only superior to existing general-purpose models but also comparable to the bespoken models (e.g., Stable Diffusion and DALL-E 2) in representative tasks (e.g., text-to-image generation).* + +Resources: + +* [Paper](https://arxiv.org/abs/2303.06555). +* [Original Code](https://github.com/thu-ml/unidiffuser). + +Available Checkpoints are: +- *UniDiffuser-v0 (512x512 resolution)* [thu-ml/unidiffuser-v0](https://huggingface.co/thu-ml/unidiffuser-v0) +- *UniDiffuser-v1 (512x512 resolution)* [thu-ml/unidiffuser-v1](https://huggingface.co/thu-ml/unidiffuser-v1) + +This pipeline was contributed by our community member [dg845](https://github.com/dg845). + +## Available Pipelines: + +| Pipeline | Tasks | Demo | Colab | +|:---:|:---:|:---:|:---:| +| [UniDiffuserPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pipeline_unidiffuser.py) | *Joint Image-Text Gen*, *Text-to-Image*, *Image-to-Text*,
*Image Gen*, *Text Gen*, *Image Variation*, *Text Variation* | [🤗 Spaces](https://huggingface.co/spaces/thu-ml/unidiffuser) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/unidiffuser.ipynb) | + +## Usage Examples + +Because the UniDiffuser model is trained to model the joint distribution of (image, text) pairs, it is capable of performing a diverse range of generation tasks. + +### Unconditional Image and Text Generation + +Unconditional generation (where we start from only latents sampled from a standard Gaussian prior) from a [`UniDiffuserPipeline`] will produce a (image, text) pair: + +```python +import torch + +from diffusers import UniDiffuserPipeline + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Unconditional image and text generation. The generation task is automatically inferred. +sample = pipe(num_inference_steps=20, guidance_scale=8.0) +image = sample.images[0] +text = sample.text[0] +image.save("unidiffuser_joint_sample_image.png") +print(text) +``` + +This is also called "joint" generation in the UniDiffusers paper, since we are sampling from the joint image-text distribution. + +Note that the generation task is inferred from the inputs used when calling the pipeline. +It is also possible to manually specify the unconditional generation task ("mode") manually with [`UniDiffuserPipeline.set_joint_mode`]: + +```python +# Equivalent to the above. +pipe.set_joint_mode() +sample = pipe(num_inference_steps=20, guidance_scale=8.0) +``` + +When the mode is set manually, subsequent calls to the pipeline will use the set mode without attempting the infer the mode. +You can reset the mode with [`UniDiffuserPipeline.reset_mode`], after which the pipeline will once again infer the mode. + +You can also generate only an image or only text (which the UniDiffuser paper calls "marginal" generation since we sample from the marginal distribution of images and text, respectively): + +```python +# Unlike other generation tasks, image-only and text-only generation don't use classifier-free guidance +# Image-only generation +pipe.set_image_mode() +sample_image = pipe(num_inference_steps=20).images[0] +# Text-only generation +pipe.set_text_mode() +sample_text = pipe(num_inference_steps=20).text[0] +``` + +### Text-to-Image Generation + +UniDiffuser is also capable of sampling from conditional distributions; that is, the distribution of images conditioned on a text prompt or the distribution of texts conditioned on an image. +Here is an example of sampling from the conditional image distribution (text-to-image generation or text-conditioned image generation): + +```python +import torch + +from diffusers import UniDiffuserPipeline + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Text-to-image generation +prompt = "an elephant under the sea" + +sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0) +t2i_image = sample.images[0] +t2i_image.save("unidiffuser_text2img_sample_image.png") +``` + +The `text2img` mode requires that either an input `prompt` or `prompt_embeds` be supplied. You can set the `text2img` mode manually with [`UniDiffuserPipeline.set_text_to_image_mode`]. + +### Image-to-Text Generation + +Similarly, UniDiffuser can also produce text samples given an image (image-to-text or image-conditioned text generation): + +```python +import torch + +from diffusers import UniDiffuserPipeline +from diffusers.utils import load_image + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Image-to-text generation +image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" +init_image = load_image(image_url).resize((512, 512)) + +sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) +i2t_text = sample.text[0] +print(i2t_text) +``` + +The `img2text` mode requires that an input `image` be supplied. You can set the `img2text` mode manually with [`UniDiffuserPipeline.set_image_to_text_mode`]. + +### Image Variation + +The UniDiffuser authors suggest performing image variation through a "round-trip" generation method, where given an input image, we first perform an image-to-text generation, and the perform a text-to-image generation on the outputs of the first generation. +This produces a new image which is semantically similar to the input image: + +```python +import torch + +from diffusers import UniDiffuserPipeline +from diffusers.utils import load_image + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Image variation can be performed with a image-to-text generation followed by a text-to-image generation: +# 1. Image-to-text generation +image_url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" +init_image = load_image(image_url).resize((512, 512)) + +sample = pipe(image=init_image, num_inference_steps=20, guidance_scale=8.0) +i2t_text = sample.text[0] +print(i2t_text) + +# 2. Text-to-image generation +sample = pipe(prompt=i2t_text, num_inference_steps=20, guidance_scale=8.0) +final_image = sample.images[0] +final_image.save("unidiffuser_image_variation_sample.png") +``` + +### Text Variation + + +Similarly, text variation can be performed on an input prompt with a text-to-image generation followed by a image-to-text generation: + +```python +import torch + +from diffusers import UniDiffuserPipeline + +device = "cuda" +model_id_or_path = "thu-ml/unidiffuser-v1" +pipe = UniDiffuserPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) +pipe.to(device) + +# Text variation can be performed with a text-to-image generation followed by a image-to-text generation: +# 1. Text-to-image generation +prompt = "an elephant under the sea" + +sample = pipe(prompt=prompt, num_inference_steps=20, guidance_scale=8.0) +t2i_image = sample.images[0] +t2i_image.save("unidiffuser_text2img_sample_image.png") + +# 2. Image-to-text generation +sample = pipe(image=t2i_image, num_inference_steps=20, guidance_scale=8.0) +final_prompt = sample.text[0] +print(final_prompt) +``` + +## UniDiffuserPipeline +[[autodoc]] UniDiffuserPipeline + - all + - __call__ diff --git a/docs/source/en/api/schedulers/ddim.mdx b/docs/source/en/api/schedulers/ddim.mdx index 51b0cc3e9a09..0db5e4f4e2b5 100644 --- a/docs/source/en/api/schedulers/ddim.mdx +++ b/docs/source/en/api/schedulers/ddim.mdx @@ -18,10 +18,71 @@ specific language governing permissions and limitations under the License. The abstract of the paper is the following: -Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, yet they require simulating a Markov chain for many steps to produce a sample. To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space. +*Denoising diffusion probabilistic models (DDPMs) have achieved high quality image generation without adversarial training, +yet they require simulating a Markov chain for many steps to produce a sample. +To accelerate sampling, we present denoising diffusion implicit models (DDIMs), a more efficient class of iterative implicit probabilistic models +with the same training procedure as DDPMs. In DDPMs, the generative process is defined as the reverse of a Markovian diffusion process. +We construct a class of non-Markovian diffusion processes that lead to the same training objective, but whose reverse process can be much faster to sample from. +We empirically demonstrate that DDIMs can produce high quality samples 10× to 50× faster in terms of wall-clock time compared to DDPMs, allow us to trade off +computation for sample quality, and can perform semantically meaningful image interpolation directly in the latent space.* The original codebase of this paper can be found here: [ermongroup/ddim](https://github.com/ermongroup/ddim). For questions, feel free to contact the author on [tsong.me](https://tsong.me/). +### Experimental: "Common Diffusion Noise Schedules and Sample Steps are Flawed": + +The paper **[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891)** +claims that a mismatch between the training and inference settings leads to suboptimal inference generation results for Stable Diffusion. + +The abstract reads as follows: + +*We discover that common diffusion noise schedules do not enforce the last timestep to have zero signal-to-noise ratio (SNR), +and some implementations of diffusion samplers do not start from the last timestep. +Such designs are flawed and do not reflect the fact that the model is given pure Gaussian noise at inference, creating a discrepancy between training and inference. +We show that the flawed design causes real problems in existing implementations. +In Stable Diffusion, it severely limits the model to only generate images with medium brightness and +prevents it from generating very bright and dark samples. We propose a few simple fixes: +- (1) rescale the noise schedule to enforce zero terminal SNR; +- (2) train the model with v prediction; +- (3) change the sampler to always start from the last timestep; +- (4) rescale classifier-free guidance to prevent over-exposure. +These simple changes ensure the diffusion process is congruent between training and inference and +allow the model to generate samples more faithful to the original data distribution.* + +You can apply all of these changes in `diffusers` when using [`DDIMScheduler`]: +- (1) rescale the noise schedule to enforce zero terminal SNR; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, rescale_betas_zero_snr=True) +``` +- (2) train the model with v prediction; +Continue fine-tuning a checkpoint with [`train_text_to_image.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) or [`train_text_to_image_lora.py`](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py) +and `--prediction_type="v_prediction"`. +- (3) change the sampler to always start from the last timestep; +```py +pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config, timestep_scaling="trailing") +``` +- (4) rescale classifier-free guidance to prevent over-exposure. +```py +pipe(..., guidance_rescale=0.7) +``` + +An example is to use [this checkpoint](https://huggingface.co/ptx0/pseudo-journey-v2) +which has been fine-tuned using the `"v_prediction"`. + +The checkpoint can then be run in inference as follows: + +```py +from diffusers import DiffusionPipeline, DDIMScheduler + +pipe = DiffusionPipeline.from_pretrained("ptx0/pseudo-journey-v2", torch_dtype=torch.float16) +pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, rescale_betas_zero_snr=True, timestep_scaling="trailing" +) +pipe.to("cuda") + +prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" +image = pipeline(prompt, guidance_rescale=0.7).images[0] +``` + ## DDIMScheduler [[autodoc]] DDIMScheduler diff --git a/docs/source/en/api/schedulers/dpm_sde.mdx b/docs/source/en/api/schedulers/dpm_sde.mdx new file mode 100644 index 000000000000..33ec514cef64 --- /dev/null +++ b/docs/source/en/api/schedulers/dpm_sde.mdx @@ -0,0 +1,23 @@ + + +# DPM Stochastic Scheduler inspired by Karras et. al paper + +## Overview + +Inspired by Stochastic Sampler from [Karras et. al](https://arxiv.org/abs/2206.00364). +Scheduler ported from @crowsonkb's https://github.com/crowsonkb/k-diffusion library: + +All credit for making this scheduler work goes to [Katherine Crowson](https://github.com/crowsonkb/) + +## DPMSolverSDEScheduler +[[autodoc]] DPMSolverSDEScheduler \ No newline at end of file diff --git a/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.mdx b/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.mdx new file mode 100644 index 000000000000..1b3348a5a3ea --- /dev/null +++ b/docs/source/en/api/schedulers/multistep_dpm_solver_inverse.mdx @@ -0,0 +1,22 @@ + + +# Inverse Multistep DPM-Solver (DPMSolverMultistepInverse) + +## Overview + +This scheduler is the inverted scheduler of [DPM-Solver: A Fast ODE Solver for Diffusion Probabilistic Model Sampling in Around 10 Steps](https://arxiv.org/abs/2206.00927) and [DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models +](https://arxiv.org/abs/2211.01095) by Cheng Lu, Yuhao Zhou, Fan Bao, Jianfei Chen, Chongxuan Li, and Jun Zhu. +The implementation is mostly based on the DDIM inversion definition of [Null-text Inversion for Editing Real Images using Guided Diffusion Models](https://arxiv.org/pdf/2211.09794.pdf) and the ad-hoc notebook implementation for DiffEdit latent inversion [here](https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/diffedit.ipynb). + +## DPMSolverMultistepInverseScheduler +[[autodoc]] DPMSolverMultistepInverseScheduler diff --git a/docs/source/en/api/utilities.mdx b/docs/source/en/api/utilities.mdx new file mode 100644 index 000000000000..16143a2a66a6 --- /dev/null +++ b/docs/source/en/api/utilities.mdx @@ -0,0 +1,23 @@ +# Utilities + +Utility and helper functions for working with 🤗 Diffusers. + +## randn_tensor + +[[autodoc]] diffusers.utils.randn_tensor + +## numpy_to_pil + +[[autodoc]] utils.pil_utils.numpy_to_pil + +## pt_to_pil + +[[autodoc]] utils.pil_utils.pt_to_pil + +## load_image + +[[autodoc]] utils.testing_utils.load_image + +## export_to_video + +[[autodoc]] utils.testing_utils.export_to_video \ No newline at end of file diff --git a/docs/source/en/conceptual/contribution.mdx b/docs/source/en/conceptual/contribution.mdx index 7b78d318b679..ea1d15f2124c 100644 --- a/docs/source/en/conceptual/contribution.mdx +++ b/docs/source/en/conceptual/contribution.mdx @@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License. We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation –not just code– are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it! -Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. Join us on Discord +Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. Join us on Discord Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility. diff --git a/docs/source/en/conceptual/evaluation.mdx b/docs/source/en/conceptual/evaluation.mdx index 2721adea0c16..6e5c14acad4e 100644 --- a/docs/source/en/conceptual/evaluation.mdx +++ b/docs/source/en/conceptual/evaluation.mdx @@ -37,7 +37,8 @@ We cover Diffusion models with the following pipelines: ## Qualitative Evaluation -Qualitative evaluation typically involves human assessment of generated images. Quality is measured across aspects such as compositionality, image-text alignment, and spatial relations. Common prompts provide a degree of uniformity for subjective metrics. DrawBench and PartiPrompts are prompt datasets used for qualitative benchmarking. DrawBench and PartiPrompts were introduced by [Imagen](https://imagen.research.google/) and [Parti](https://parti.research.google/) respectively. +Qualitative evaluation typically involves human assessment of generated images. Quality is measured across aspects such as compositionality, image-text alignment, and spatial relations. Common prompts provide a degree of uniformity for subjective metrics. +DrawBench and PartiPrompts are prompt datasets used for qualitative benchmarking. DrawBench and PartiPrompts were introduced by [Imagen](https://imagen.research.google/) and [Parti](https://parti.research.google/) respectively. From the [official Parti website](https://parti.research.google/): @@ -51,7 +52,13 @@ PartiPrompts has the following columns: - Category of the prompt (such as “Abstract”, “World Knowledge”, etc.) - Challenge reflecting the difficulty (such as “Basic”, “Complex”, “Writing & Symbols”, etc.) -These benchmarks allow for side-by-side human evaluation of different image generation models. Let’s see how we can use `diffusers` on a couple of PartiPrompts. +These benchmarks allow for side-by-side human evaluation of different image generation models. + +For this, the 🧨 Diffusers team has built **Open Parti Prompts**, which is a community-driven qualitative benchmark based on Parti Prompts to compare state-of-the-art open-source diffusion models: +- [Open Parti Prompts Game](https://huggingface.co/spaces/OpenGenAI/open-parti-prompts): For 10 parti prompts, 4 generated images are shown and the user selects the image that suits the prompt best. +- [Open Parti Prompts Leaderboard](https://huggingface.co/spaces/OpenGenAI/parti-prompts-leaderboard): The leaderboard comparing the currently best open-sourced diffusion models to each other. + +To manually compare images, let’s see how we can use `diffusers` on a couple of PartiPrompts. Below we show some prompts sampled across different challenges: Basic, Complex, Linguistic Structures, Imagination, and Writing & Symbols. Here we are using PartiPrompts as a [dataset](https://huggingface.co/datasets/nateraw/parti-prompts). diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 46a985ac2f8d..66548663827a 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -53,7 +53,7 @@ The library has three main components: |---|---|:---:| | [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | | [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | -| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | +| [controlnet](./api/pipelines/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | | [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | | [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | | [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | diff --git a/docs/source/en/installation.mdx b/docs/source/en/installation.mdx index 8639bcfca95b..218ccd7bc4f6 100644 --- a/docs/source/en/installation.mdx +++ b/docs/source/en/installation.mdx @@ -12,9 +12,9 @@ specific language governing permissions and limitations under the License. # Installation -Install 🤗 Diffusers for whichever deep learning library you’re working with. +Install 🤗 Diffusers for whichever deep learning library you're working with. -🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and flax. Follow the installation instructions below for the deep learning library you are using: +🤗 Diffusers is tested on Python 3.7+, PyTorch 1.7.0+ and Flax. Follow the installation instructions below for the deep learning library you are using: - [PyTorch](https://pytorch.org/get-started/locally/) installation instructions. - [Flax](https://flax.readthedocs.io/en/latest/) installation instructions. @@ -37,27 +37,28 @@ Activate the virtual environment: source .env/bin/activate ``` -Now you're ready to install 🤗 Diffusers with the following command: - -**For PyTorch** +🤗 Diffusers also relies on the 🤗 Transformers library, and you can install both with the following command: + + ```bash -pip install diffusers["torch"] +pip install diffusers["torch"] transformers ``` - -**For Flax** - + + ```bash -pip install diffusers["flax"] +pip install diffusers["flax"] transformers ``` + + ## Install from source -Before intsalling `diffusers` from source, make sure you have `torch` and `accelerate` installed. +Before installing 🤗 Diffusers from source, make sure you have `torch` and 🤗 Accelerate installed. -For `torch` installation refer to the `torch` [docs](https://pytorch.org/get-started/locally/#start-locally). +For `torch` installation, refer to the `torch` [installation](https://pytorch.org/get-started/locally/#start-locally) guide. -To install `accelerate` +To install 🤗 Accelerate: ```bash pip install accelerate @@ -74,7 +75,7 @@ The `main` version is useful for staying up-to-date with the latest developments For instance, if a bug has been fixed since the last official release but a new release hasn't been rolled out yet. However, this means the `main` version may not always be stable. We strive to keep the `main` version operational, and most issues are usually resolved within a few hours or a day. -If you run into a problem, please open an [Issue](https://github.com/huggingface/transformers/issues), so we can fix it even sooner! +If you run into a problem, please open an [Issue](https://github.com/huggingface/diffusers/issues/new/choose), so we can fix it even sooner! ## Editable install @@ -90,21 +91,22 @@ git clone https://github.com/huggingface/diffusers.git cd diffusers ``` -**For PyTorch** - -``` + + +```bash pip install -e ".[torch]" ``` - -**For Flax** - -``` + + +```bash pip install -e ".[flax]" ``` + + These commands will link the folder you cloned the repository to and your Python library paths. Python will now look inside the folder you cloned to in addition to the normal library paths. -For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the folder you cloned to: `~/diffusers/`. +For example, if your Python packages are typically installed in `~/anaconda3/envs/main/lib/python3.7/site-packages/`, Python will also search the `~/diffusers/` folder you cloned to. diff --git a/docs/source/en/optimization/fp16.mdx b/docs/source/en/optimization/fp16.mdx index d05c5aabea2b..8b3a62cba099 100644 --- a/docs/source/en/optimization/fp16.mdx +++ b/docs/source/en/optimization/fp16.mdx @@ -50,7 +50,6 @@ from diffusers import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", - torch_dtype=torch.float16, ) pipe = pipe.to("cuda") @@ -60,8 +59,10 @@ image = pipe(prompt).images[0] ``` + It is strongly discouraged to make use of [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) in any of the pipelines as it can lead to black images and is always slower than using pure float16 precision. + ## Sliced attention for additional memory savings @@ -83,7 +84,6 @@ from diffusers import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", - torch_dtype=torch.float16, ) pipe = pipe.to("cuda") @@ -110,7 +110,6 @@ from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", - torch_dtype=torch.float16, ) pipe = pipe.to("cuda") @@ -164,7 +163,6 @@ from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", - torch_dtype=torch.float16, ) @@ -189,7 +187,6 @@ from diffusers import StableDiffusionPipeline pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", - torch_dtype=torch.float16, ) @@ -202,6 +199,8 @@ image = pipe(prompt).images[0] **Note**: When using `enable_sequential_cpu_offload()`, it is important to **not** move the pipeline to CUDA beforehand or else the gain in memory consumption will only be minimal. See [this issue](https://github.com/huggingface/diffusers/issues/1934) for more information. +**Note**: `enable_sequential_cpu_offload()` is a stateful operation that installs hooks on the models. + ## Model offloading for fast inference and memory savings @@ -251,6 +250,11 @@ image = pipe(prompt).images[0] This feature requires `accelerate` version 0.17.0 or larger. +**Note**: `enable_model_cpu_offload()` is a stateful operation that installs hooks on the models and state on the pipeline. In order to properly offload +models after they are called, it is required that the entire pipeline is run and models are called in the order the pipeline expects them to be. Exercise caution +if models are re-used outside the context of the pipeline after hooks have been installed. See [accelerate](https://huggingface.co/docs/accelerate/v0.18.0/en/package_reference/big_modeling#accelerate.hooks.remove_hook_from_module) +for further docs on removing hooks. + ## Using Channels Last memory format Channels last memory format is an alternative way of ordering NCHW tensors in memory preserving dimensions ordering. Channels last tensors ordered in such a way that channels become the densest dimension (aka storing images pixel-per-pixel). Since not all operators currently support channels last format it may result in a worst performance, so it's better to try it and see if it works for your model. @@ -400,7 +404,14 @@ Here are the speedups we obtain on a few Nvidia GPUs when running the inference | A100-SXM4-40GB | 18.6it/s | 29.it/s | | A100-SXM-80GB | 18.7it/s | 29.5it/s | -To leverage it just make sure you have: +To leverage it just make sure you have: + + + +If you have PyTorch 2.0 installed, you shouldn't use xFormers! + + + - PyTorch > 1.12 - Cuda available - [Installed the xformers library](xformers). diff --git a/docs/source/en/optimization/torch2.0.mdx b/docs/source/en/optimization/torch2.0.mdx index 206ac4e447cc..6e8466fd6ecc 100644 --- a/docs/source/en/optimization/torch2.0.mdx +++ b/docs/source/en/optimization/torch2.0.mdx @@ -12,19 +12,21 @@ specific language governing permissions and limitations under the License. # Accelerated PyTorch 2.0 support in Diffusers -Starting from version `0.13.0`, Diffusers supports the latest optimization from the upcoming [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) release. These include: -1. Support for accelerated transformers implementation with memory-efficient attention – no extra dependencies required. +Starting from version `0.13.0`, Diffusers supports the latest optimization from [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). These include: +1. Support for accelerated transformers implementation with memory-efficient attention – no extra dependencies (such as `xformers`) required. 2. [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) support for extra performance boost when individual models are compiled. ## Installation -To benefit from the accelerated attention implementation and `torch.compile`, you just need to install the latest versions of PyTorch 2.0 from `pip`, and make sure you are on diffusers 0.13.0 or later. As explained below, `diffusers` automatically uses the attention optimizations (but not `torch.compile`) when available. + +To benefit from the accelerated attention implementation and `torch.compile()`, you just need to install the latest versions of PyTorch 2.0 from pip, and make sure you are on diffusers 0.13.0 or later. As explained below, diffusers automatically uses the optimized attention processor ([`AttnProcessor2_0`](https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/src/diffusers/models/attention_processor.py#L798)) (but not `torch.compile()`) +when PyTorch 2.0 is available. ```bash -pip install --upgrade torch torchvision diffusers +pip install --upgrade torch diffusers ``` -## Using accelerated transformers and torch.compile. +## Using accelerated transformers and `torch.compile`. 1. **Accelerated Transformers implementation** @@ -46,13 +48,13 @@ pip install --upgrade torch torchvision diffusers If you want to enable it explicitly (which is not required), you can do so as shown below. - ```Python + ```diff import torch from diffusers import DiffusionPipeline - from diffusers.models.attention_processor import AttnProcessor2_0 + + from diffusers.models.attention_processor import AttnProcessor2_0 pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") - pipe.unet.set_attn_processor(AttnProcessor2_0()) + + pipe.unet.set_attn_processor(AttnProcessor2_0()) prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] @@ -60,151 +62,383 @@ pip install --upgrade torch torchvision diffusers This should be as fast and memory efficient as `xFormers`. More details [in our benchmark](#benchmark). + It is possible to revert to the vanilla attention processor ([`AttnProcessor`](https://github.com/huggingface/diffusers/blob/1a5797c6d4491a879ea5285c4efc377664e0332d/src/diffusers/models/attention_processor.py#L402)), which can be helpful to make the pipeline more deterministic, or if you need to convert a fine-tuned model to other formats such as [Core ML](https://huggingface.co/docs/diffusers/v0.16.0/en/optimization/coreml#how-to-run-stable-diffusion-with-core-ml). To use the normal attention processor you can use the [`~diffusers.UNet2DConditionModel.set_default_attn_processor`] function: -2. **torch.compile** - - To get an additional speedup, we can use the new `torch.compile` feature. To do so, we simply wrap our `unet` with `torch.compile`. For more information and different options, refer to the - [torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). - - ```python + ```Python import torch from diffusers import DiffusionPipeline + from diffusers.models.attention_processor import AttnProcessor pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") - pipe.unet = torch.compile(pipe.unet) + pipe.unet.set_default_attn_processor() - batch_size = 10 - prompt = "A photo of an astronaut riding a horse on marse." + prompt = "a photo of an astronaut riding a horse on mars" + image = pipe(prompt).images[0] + ``` + +2. **torch.compile** + + To get an additional speedup, we can use the new `torch.compile` feature. Since the UNet of the pipeline is usually the most computationally expensive, we wrap the `unet` with `torch.compile` leaving rest of the sub-models (text encoder and VAE) as is. For more information and different options, refer to the + [torch compile docs](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html). + + ```python + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images ``` - Depending on the type of GPU, `compile()` can yield between 2-9% of _additional speed-up_ over the accelerated transformer optimizations. Note, however, that compilation is able to squeeze more performance improvements in more recent GPU architectures such as Ampere (A100, 3090), Ada (4090) and Hopper (H100). + Depending on the type of GPU, `compile()` can yield between **5% - 300%** of _additional speed-up_ over the accelerated transformer optimizations. Note, however, that compilation is able to squeeze more performance improvements in more recent GPU architectures such as Ampere (A100, 3090), Ada (4090) and Hopper (H100). - Compilation takes some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times. + Compilation takes some time to complete, so it is best suited for situations where you need to prepare your pipeline once and then perform the same type of inference operations multiple times. Calling the compiled pipeline on a different image size will re-trigger compilation which can be expensive. ## Benchmark -We conducted a simple benchmark on different GPUs to compare vanilla attention, xFormers, `torch.nn.functional.scaled_dot_product_attention` and `torch.compile+torch.nn.functional.scaled_dot_product_attention`. -For the benchmark we used the [stable-diffusion-v1-4](https://huggingface.co/CompVis/stable-diffusion-v1-4) model with 50 steps. The `xFormers` benchmark is done using the `torch==1.13.1` version, while the accelerated transformers optimizations are tested using nightly versions of PyTorch 2.0. The tables below summarize the results we got. - -Please refer to [our featured blog post in the PyTorch site](https://pytorch.org/blog/accelerated-diffusers-pt-20/) for more details. - -### FP16 benchmark - -The table below shows the benchmark results for inference using `fp16`. As we can see, `torch.nn.functional.scaled_dot_product_attention` is as fast as `xFormers` (sometimes slightly faster/slower) on all the GPUs we tested. -And using `torch.compile` gives further speed-up of up of 10% over `xFormers`, but it's mostly noticeable on the A100 GPU. - -___The time reported is in seconds.___ - -| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | -| --- | --- | --- | --- | --- | --- | --- | -| A100 | 1 | 2.69 | 2.7 | 1.98 | 2.47 | 8.52 | -| A100 | 2 | 3.21 | 3.04 | 2.38 | 2.78 | 8.55 | -| A100 | 4 | 5.27 | 3.91 | 3.89 | 3.53 | 9.72 | -| A100 | 8 | 9.74 | 7.03 | 7.04 | 6.62 | 5.83 | -| A100 | 10 | 12.02 | 8.7 | 8.67 | 8.45 | 2.87 | -| A100 | 16 | 18.95 | 13.57 | 13.55 | 13.20 | 2.73 | -| A100 | 32 (1) | OOM | 26.56 | 26.68 | 25.85 | 2.67 | -| A100 | 64 | | 52.51 | 53.03 | 50.93 | 3.01 | -| | | | | | | | -| A10 | 4 | 13.94 | 9.81 | 10.01 | 9.35 | 4.69 | -| A10 | 8 | 27.09 | 19 | 19.53 | 18.33 | 3.53 | -| A10 | 10 | 33.69 | 23.53 | 24.19 | 22.52 | 4.29 | -| A10 | 16 | OOM | 37.55 | 38.31 | 36.81 | 1.97 | -| A10 | 32 (1) | | 77.19 | 78.43 | 76.64 | 0.71 | -| A10 | 64 (1) | | 173.59 | 158.99 | 155.14 | 10.63 | -| | | | | | | | -| T4 | 4 | 38.81 | 30.09 | 29.74 | 27.55 | 8.44 | -| T4 | 8 | OOM | 55.71 | 55.99 | 53.85 | 3.34 | -| T4 | 10 | OOM | 68.96 | 69.86 | 65.35 | 5.23 | -| T4 | 16 | OOM | 111.47 | 113.26 | 106.93 | 4.07 | -| | | | | | | | -| V100 | 4 | 9.84 | 8.16 | 8.09 | 7.65 | 6.25 | -| V100 | 8 | OOM | 15.62 | 15.44 | 14.59 | 6.59 | -| V100 | 10 | OOM | 19.52 | 19.28 | 18.18 | 6.86 | -| V100 | 16 | OOM | 30.29 | 29.84 | 28.22 | 6.83 | -| | | | | | | | -| 3090 | 1 | 2.94 | 2.5 | 2.42 | 2.33 | 6.80 | -| 3090 | 4 | 10.04 | 7.82 | 7.72 | 7.38 | 5.63 | -| 3090 | 8 | 19.27 | 14.97 | 14.88 | 14.15 | 5.48 | -| 3090 | 10| 24.08 | 18.7 | 18.62 | 18.12 | 3.10 | -| 3090 | 16 | OOM | 29.06 | 28.88 | 28.2 | 2.96 | -| 3090 | 32 (1) | | 58.05 | 57.42 | 56.28 | 3.05 | -| 3090 | 64 (1) | | 126.54 | 114.27 | 112.21 | 11.32 | -| | | | | | | | -| 3090 Ti | 1 | 2.7 | 2.26 | 2.19 | 2.12 | 6.19 | -| 3090 Ti | 4 | 9.07 | 7.14 | 7.00 | 6.71 | 6.02 | -| 3090 Ti | 8 | 17.51 | 13.65 | 13.53 | 12.94 | 5.20 | -| 3090 Ti | 10 (2) | 21.79 | 16.85 | 16.77 | 16.44 | 2.43 | -| 3090 Ti | 16 | OOM | 26.1 | 26.04 | 25.53 | 2.18 | -| 3090 Ti | 32 (1) | | 51.78 | 51.71 | 50.91 | 1.68 | -| 3090 Ti | 64 (1) | | 112.02 | 102.78 | 100.89 | 9.94 | -| | | | | | | | -| 4090 | 1 | 4.47 | 3.98 | 1.28 | 1.21 | 69.60 | -| 4090 | 4 | 10.48 | 8.37 | 3.76 | 3.56 | 57.47 | -| 4090 | 8 | 14.33 | 10.22 | 7.43 | 6.99 | 31.60 | -| 4090 | 16 | | 17.07 | 14.98 | 14.58 | 14.59 | -| 4090 | 32 (1) | | 39.03 | 30.18 | 29.49 | 24.44 | -| 4090 | 64 (1) | | 77.29 | 61.34 | 59.96 | 22.42 | - - - -### FP32 benchmark - -The table below shows the benchmark results for inference using `fp32`. In this case, `torch.nn.functional.scaled_dot_product_attention` is faster than `xFormers` on all the GPUs we tested. - -Using `torch.compile` in addition to the accelerated transformers implementation can yield up to 19% performance improvement over `xFormers` in Ampere and Ada cards, and up to 20% (Ampere) or 28% (Ada) over vanilla attention. - -| GPU | Batch Size | Vanilla Attention | xFormers | PyTorch2.0 SDPA | SDPA + torch.compile | Speed over xformers (%) | Speed over vanilla (%) | -| --- | --- | --- | --- | --- | --- | --- | --- | -| A100 | 1 | 4.97 | 3.86 | 2.6 | 2.86 | 25.91 | 42.45 | -| A100 | 2 | 9.03 | 6.76 | 4.41 | 4.21 | 37.72 | 53.38 | -| A100 | 4 | 16.70 | 12.42 | 7.94 | 7.54 | 39.29 | 54.85 | -| A100 | 10 | OOM | 29.93 | 18.70 | 18.46 | 38.32 | | -| A100 | 16 | | 47.08 | 29.41 | 29.04 | 38.32 | | -| A100 | 32 | | 92.89 | 57.55 | 56.67 | 38.99 | | -| A100 | 64 | | 185.3 | 114.8 | 112.98 | 39.03 | | -| | | | | | | | -| A10 | 1 | 10.59 | 8.81 | 7.51 | 7.35 | 16.57 | 30.59 | -| A10 | 4 | 34.77 | 27.63 | 22.77 | 22.07 | 20.12 | 36.53 | -| A10 | 8 | | 56.19 | 43.53 | 43.86 | 21.94 | | -| A10 | 16 | | 116.49 | 88.56 | 86.64 | 25.62 | | -| A10 | 32 | | 221.95 | 175.74 | 168.18 | 24.23 | | -| A10 | 48 | | 333.23 | 264.84 | | 20.52 | | -| | | | | | | | -| T4 | 1 | 28.2 | 24.49 | 23.93 | 23.56 | 3.80 | 16.45 | -| T4 | 2 | 52.77 | 45.7 | 45.88 | 45.06 | 1.40 | 14.61 | -| T4 | 4 | OOM | 85.72 | 85.78 | 84.48 | 1.45 | | -| T4 | 8 | | 149.64 | 150.75 | 148.4 | 0.83 | | -| | | | | | | | -| V100 | 1 | 7.4 | 6.84 | 6.8 | 6.66 | 2.63 | 10.00 | -| V100 | 2 | 13.85 | 12.81 | 12.66 | 12.35 | 3.59 | 10.83 | -| V100 | 4 | OOM | 25.73 | 25.31 | 24.78 | 3.69 | | -| V100 | 8 | | 43.95 | 43.37 | 42.25 | 3.87 | | -| V100 | 16 | | 84.99 | 84.73 | 82.55 | 2.87 | | -| | | | | | | | -| 3090 | 1 | 7.09 | 6.78 | 5.34 | 5.35 | 21.09 | 24.54 | -| 3090 | 4 | 22.69 | 21.45 | 18.56 | 18.18 | 15.24 | 19.88 | -| 3090 | 8 | | 42.59 | 36.68 | 35.61 | 16.39 | | -| 3090 | 16 | | 85.35 | 72.93 | 70.18 | 17.77 | | -| 3090 | 32 (1) | | 162.05 | 143.46 | 138.67 | 14.43 | | -| | | | | | | | -| 3090 Ti | 1 | 6.45 | 6.19 | 4.99 | 4.89 | 21.00 | 24.19 | -| 3090 Ti | 4 | 20.32 | 19.31 | 17.02 | 16.48 | 14.66 | 18.90 | -| 3090 Ti | 8 | | 37.93 | 33.21 | 32.24 | 15.00 | | -| 3090 Ti | 16 | | 75.37 | 66.63 | 64.5 | 14.42 | | -| 3090 Ti | 32 (1) | | 142.55 | 128.89 | 124.92 | 12.37 | | -| | | | | | | | -| 4090 | 1 | 5.54 | 4.99 | 2.66 | 2.58 | 48.30 | 53.43 | -| 4090 | 4 | 13.67 | 11.4 | 8.81 | 8.46 | 25.79 | 38.11 | -| 4090 | 8 | | 19.79 | 17.55 | 16.62 | 16.02 | | -| 4090 | 16 | | 38.62 | 35.65 | 34.07 | 11.78 | | -| 4090 | 32 (1) | | 76.57 | 69.48 | 65.35 | 14.65 | | -| 4090 | 48 | | 114.44 | 106.3 | | 7.11 | | - - -(1) Batch Size >= 32 requires enable_vae_slicing() because of https://github.com/pytorch/pytorch/issues/81665. -This is required for PyTorch 1.13.1, and also for PyTorch 2.0 and large batch sizes. - -For more details about how this benchmark was run, please refer to [this PR](https://github.com/huggingface/diffusers/pull/2303) and to [the blog post](https://pytorch.org/blog/accelerated-diffusers-pt-20/). +We conducted a comprehensive benchmark with PyTorch 2.0's efficient attention implementation and `torch.compile` across different GPUs and batch sizes for five of our most used pipelines. We used `diffusers 0.17.0.dev0`, which [makes sure `torch.compile()` is leveraged optimally](https://github.com/huggingface/diffusers/pull/3313). + +### Benchmarking code + +#### Stable Diffusion text-to-image + +```python +from diffusers import DiffusionPipeline +import torch + +path = "runwayml/stable-diffusion-v1-5" + +run_compile = True # Set True / False + +pipe = DiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16) +pipe = pipe.to("cuda") +pipe.unet.to(memory_format=torch.channels_last) + +if run_compile: + print("Run torch compile") + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + +prompt = "ghibli style, a fantasy landscape with castles" + +for _ in range(3): + images = pipe(prompt=prompt).images +``` + +#### Stable Diffusion image-to-image + +```python +from diffusers import StableDiffusionImg2ImgPipeline +import requests +import torch +from PIL import Image +from io import BytesIO + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) + +path = "runwayml/stable-diffusion-v1-5" + +run_compile = True # Set True / False + +pipe = StableDiffusionImg2ImgPipeline.from_pretrained(path, torch_dtype=torch.float16) +pipe = pipe.to("cuda") +pipe.unet.to(memory_format=torch.channels_last) + +if run_compile: + print("Run torch compile") + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + +prompt = "ghibli style, a fantasy landscape with castles" + +for _ in range(3): + image = pipe(prompt=prompt, image=init_image).images[0] +``` + +#### Stable Diffusion - inpainting + +```python +from diffusers import StableDiffusionInpaintPipeline +import requests +import torch +from PIL import Image +from io import BytesIO + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +def download_image(url): + response = requests.get(url) + return Image.open(BytesIO(response.content)).convert("RGB") + + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) + +path = "runwayml/stable-diffusion-inpainting" + +run_compile = True # Set True / False + +pipe = StableDiffusionInpaintPipeline.from_pretrained(path, torch_dtype=torch.float16) +pipe = pipe.to("cuda") +pipe.unet.to(memory_format=torch.channels_last) + +if run_compile: + print("Run torch compile") + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + +prompt = "ghibli style, a fantasy landscape with castles" + +for _ in range(3): + image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] +``` + +#### ControlNet + +```python +from diffusers import StableDiffusionControlNetPipeline, ControlNetModel +import requests +import torch +from PIL import Image +from io import BytesIO + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +init_image = Image.open(BytesIO(response.content)).convert("RGB") +init_image = init_image.resize((512, 512)) + +path = "runwayml/stable-diffusion-v1-5" + +run_compile = True # Set True / False +controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) +pipe = StableDiffusionControlNetPipeline.from_pretrained( + path, controlnet=controlnet, torch_dtype=torch.float16 +) + +pipe = pipe.to("cuda") +pipe.unet.to(memory_format=torch.channels_last) +pipe.controlnet.to(memory_format=torch.channels_last) + +if run_compile: + print("Run torch compile") + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) + +prompt = "ghibli style, a fantasy landscape with castles" + +for _ in range(3): + image = pipe(prompt=prompt, image=init_image).images[0] +``` + +#### IF text-to-image + upscaling + +```python +from diffusers import DiffusionPipeline +import torch + +run_compile = True # Set True / False + +pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16) +pipe.to("cuda") +pipe_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-M-v1.0", variant="fp16", text_encoder=None, torch_dtype=torch.float16) +pipe_2.to("cuda") +pipe_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", torch_dtype=torch.float16) +pipe_3.to("cuda") + + +pipe.unet.to(memory_format=torch.channels_last) +pipe_2.unet.to(memory_format=torch.channels_last) +pipe_3.unet.to(memory_format=torch.channels_last) + +if run_compile: + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + pipe_2.unet = torch.compile(pipe_2.unet, mode="reduce-overhead", fullgraph=True) + pipe_3.unet = torch.compile(pipe_3.unet, mode="reduce-overhead", fullgraph=True) + +prompt = "the blue hulk" + +prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16) +neg_prompt_embeds = torch.randn((1, 2, 4096), dtype=torch.float16) + +for _ in range(3): + image = pipe(prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images + image_2 = pipe_2(image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=neg_prompt_embeds, output_type="pt").images + image_3 = pipe_3(prompt=prompt, image=image, noise_level=100).images +``` + +To give you a pictorial overview of the possible speed-ups that can be obtained with PyTorch 2.0 and `torch.compile()`, +here is a plot that shows relative speed-ups for the [Stable Diffusion text-to-image pipeline](StableDiffusionPipeline) across five +different GPU families (with a batch size of 4): + +![t2i_speedup](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/t2i_speedup.png) + +To give you an even better idea of how this speed-up holds for the other pipelines presented above, consider the following +plot that shows the benchmarking numbers from an A100 across three different batch sizes +(with PyTorch 2.0 nightly and `torch.compile()`): + +![a100_numbers](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/pt2_benchmarks/a100_numbers.png) + +_(Our benchmarking metric for the plots above is **number of iterations/second**)_ + +But we reveal all the benchmarking numbers in the interest of transparency! + +In the following tables, we report our findings in terms of the number of **_iterations processed per second_**. + +### A100 (batch size: 1) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 21.66 | 23.13 | 44.03 | 49.74 | +| SD - img2img | 21.81 | 22.40 | 43.92 | 46.32 | +| SD - inpaint | 22.24 | 23.23 | 43.76 | 49.25 | +| SD - controlnet | 15.02 | 15.82 | 32.13 | 36.08 | +| IF | 20.21 /
13.84 /
24.00 | 20.12 /
13.70 /
24.03 | ❌ | 97.34 /
27.23 /
111.66 | + +### A100 (batch size: 4) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 11.6 | 13.12 | 14.62 | 17.27 | +| SD - img2img | 11.47 | 13.06 | 14.66 | 17.25 | +| SD - inpaint | 11.67 | 13.31 | 14.88 | 17.48 | +| SD - controlnet | 8.28 | 9.38 | 10.51 | 12.41 | +| IF | 25.02 | 18.04 | ❌ | 48.47 | + +### A100 (batch size: 16) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 3.04 | 3.6 | 3.83 | 4.68 | +| SD - img2img | 2.98 | 3.58 | 3.83 | 4.67 | +| SD - inpaint | 3.04 | 3.66 | 3.9 | 4.76 | +| SD - controlnet | 2.15 | 2.58 | 2.74 | 3.35 | +| IF | 8.78 | 9.82 | ❌ | 16.77 | + +### V100 (batch size: 1) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 18.99 | 19.14 | 20.95 | 22.17 | +| SD - img2img | 18.56 | 19.18 | 20.95 | 22.11 | +| SD - inpaint | 19.14 | 19.06 | 21.08 | 22.20 | +| SD - controlnet | 13.48 | 13.93 | 15.18 | 15.88 | +| IF | 20.01 /
9.08 /
23.34 | 19.79 /
8.98 /
24.10 | ❌ | 55.75 /
11.57 /
57.67 | + +### V100 (batch size: 4) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 5.96 | 5.89 | 6.83 | 6.86 | +| SD - img2img | 5.90 | 5.91 | 6.81 | 6.82 | +| SD - inpaint | 5.99 | 6.03 | 6.93 | 6.95 | +| SD - controlnet | 4.26 | 4.29 | 4.92 | 4.93 | +| IF | 15.41 | 14.76 | ❌ | 22.95 | + +### V100 (batch size: 16) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 1.66 | 1.66 | 1.92 | 1.90 | +| SD - img2img | 1.65 | 1.65 | 1.91 | 1.89 | +| SD - inpaint | 1.69 | 1.69 | 1.95 | 1.93 | +| SD - controlnet | 1.19 | 1.19 | OOM after warmup | 1.36 | +| IF | 5.43 | 5.29 | ❌ | 7.06 | + +### T4 (batch size: 1) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 6.9 | 6.95 | 7.3 | 7.56 | +| SD - img2img | 6.84 | 6.99 | 7.04 | 7.55 | +| SD - inpaint | 6.91 | 6.7 | 7.01 | 7.37 | +| SD - controlnet | 4.89 | 4.86 | 5.35 | 5.48 | +| IF | 17.42 /
2.47 /
18.52 | 16.96 /
2.45 /
18.69 | ❌ | 24.63 /
2.47 /
23.39 | + +### T4 (batch size: 4) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 1.79 | 1.79 | 2.03 | 1.99 | +| SD - img2img | 1.77 | 1.77 | 2.05 | 2.04 | +| SD - inpaint | 1.81 | 1.82 | 2.09 | 2.09 | +| SD - controlnet | 1.34 | 1.27 | 1.47 | 1.46 | +| IF | 5.79 | 5.61 | ❌ | 7.39 | + +### T4 (batch size: 16) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 2.34s | 2.30s | OOM after 2nd iteration | 1.99s | +| SD - img2img | 2.35s | 2.31s | OOM after warmup | 2.00s | +| SD - inpaint | 2.30s | 2.26s | OOM after 2nd iteration | 1.95s | +| SD - controlnet | OOM after 2nd iteration | OOM after 2nd iteration | OOM after warmup | OOM after warmup | +| IF * | 1.44 | 1.44 | ❌ | 1.94 | + +### RTX 3090 (batch size: 1) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 22.56 | 22.84 | 23.84 | 25.69 | +| SD - img2img | 22.25 | 22.61 | 24.1 | 25.83 | +| SD - inpaint | 22.22 | 22.54 | 24.26 | 26.02 | +| SD - controlnet | 16.03 | 16.33 | 17.38 | 18.56 | +| IF | 27.08 /
9.07 /
31.23 | 26.75 /
8.92 /
31.47 | ❌ | 68.08 /
11.16 /
65.29 | + +### RTX 3090 (batch size: 4) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 6.46 | 6.35 | 7.29 | 7.3 | +| SD - img2img | 6.33 | 6.27 | 7.31 | 7.26 | +| SD - inpaint | 6.47 | 6.4 | 7.44 | 7.39 | +| SD - controlnet | 4.59 | 4.54 | 5.27 | 5.26 | +| IF | 16.81 | 16.62 | ❌ | 21.57 | + +### RTX 3090 (batch size: 16) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 1.7 | 1.69 | 1.93 | 1.91 | +| SD - img2img | 1.68 | 1.67 | 1.93 | 1.9 | +| SD - inpaint | 1.72 | 1.71 | 1.97 | 1.94 | +| SD - controlnet | 1.23 | 1.22 | 1.4 | 1.38 | +| IF | 5.01 | 5.00 | ❌ | 6.33 | + +### RTX 4090 (batch size: 1) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 40.5 | 41.89 | 44.65 | 49.81 | +| SD - img2img | 40.39 | 41.95 | 44.46 | 49.8 | +| SD - inpaint | 40.51 | 41.88 | 44.58 | 49.72 | +| SD - controlnet | 29.27 | 30.29 | 32.26 | 36.03 | +| IF | 69.71 /
18.78 /
85.49 | 69.13 /
18.80 /
85.56 | ❌ | 124.60 /
26.37 /
138.79 | + +### RTX 4090 (batch size: 4) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 12.62 | 12.84 | 15.32 | 15.59 | +| SD - img2img | 12.61 | 12,.79 | 15.35 | 15.66 | +| SD - inpaint | 12.65 | 12.81 | 15.3 | 15.58 | +| SD - controlnet | 9.1 | 9.25 | 11.03 | 11.22 | +| IF | 31.88 | 31.14 | ❌ | 43.92 | + +### RTX 4090 (batch size: 16) + +| **Pipeline** | **torch 2.0 -
no compile** | **torch nightly -
no compile** | **torch 2.0 -
compile** | **torch nightly -
compile** | +|:---:|:---:|:---:|:---:|:---:| +| SD - txt2img | 3.17 | 3.2 | 3.84 | 3.85 | +| SD - img2img | 3.16 | 3.2 | 3.84 | 3.85 | +| SD - inpaint | 3.17 | 3.2 | 3.85 | 3.85 | +| SD - controlnet | 2.23 | 2.3 | 2.7 | 2.75 | +| IF | 9.26 | 9.2 | ❌ | 13.31 | + +## Notes + +* Follow [this PR](https://github.com/huggingface/diffusers/pull/3313) for more details on the environment used for conducting the benchmarks. +* For the IF pipeline and batch sizes > 1, we only used a batch size of >1 in the first IF pipeline for text-to-image generation and NOT for upscaling. So, that means the two upscaling pipelines received a batch size of 1. + +*Thanks to [Horace He](https://github.com/Chillee) from the PyTorch team for their support in improving our support of `torch.compile()` in Diffusers.* \ No newline at end of file diff --git a/docs/source/en/quicktour.mdx b/docs/source/en/quicktour.mdx index d494b79dccd5..2a2a5a3ad903 100644 --- a/docs/source/en/quicktour.mdx +++ b/docs/source/en/quicktour.mdx @@ -33,7 +33,7 @@ The quicktour is a simplified version of the introductory 🧨 Diffusers [notebo Before you begin, make sure you have all the necessary libraries installed: ```bash -pip install --upgrade diffusers accelerate transformers +!pip install --upgrade diffusers accelerate transformers ``` - [🤗 Accelerate](https://huggingface.co/docs/accelerate/index) speeds up model loading for inference and training. @@ -121,9 +121,9 @@ Save the image by calling `save`: You can also use the pipeline locally. The only difference is you need to download the weights first: -``` -git lfs install -git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 +```bash +!git lfs install +!git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 ``` Then load the saved weights into the pipeline: diff --git a/docs/source/en/stable_diffusion.mdx b/docs/source/en/stable_diffusion.mdx index eebe0ec660f2..78fa848421d8 100644 --- a/docs/source/en/stable_diffusion.mdx +++ b/docs/source/en/stable_diffusion.mdx @@ -153,7 +153,7 @@ def get_inputs(batch_size=1): You'll also need a function that'll display each batch of images: ```python -from PIL import image +from PIL import Image def image_grid(imgs, rows=2, cols=2): @@ -246,7 +246,7 @@ image_grid(images, rows=2, cols=4) Pretty impressive! Let's tweak the second image - corresponding to the `Generator` with a seed of `1` - a bit more by adding some text about the age of the subject: ```python -prommpts = [ +prompts = [ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta", @@ -266,6 +266,6 @@ image_grid(images) In this tutorial, you learned how to optimize a [`DiffusionPipeline`] for computational and memory efficiency as well as improving the quality of generated outputs. If you're interested in making your pipeline even faster, take a look at the following resources: -- Enable [xFormers](./optimization/xformers) memory efficient attention mechanism for faster speed and reduced memory consumption. -- Learn how in [PyTorch 2.0](./optimization/torch2.0), [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 2-9% faster inference speed. -- Many optimization techniques for inference are also included in this memory and speed [guide](./optimization/fp16), such as memory offloading. \ No newline at end of file +- Learn how [PyTorch 2.0](./optimization/torch2.0) and [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) can yield 5 - 300% faster inference speed. On an A100 GPU, inference can be up to 50% faster! +- If you can't use PyTorch 2, we recommend you install [xFormers](./optimization/xformers). Its memory-efficient attention mechanism works great with PyTorch 1.13.1 for faster speed and reduced memory consumption. +- Other optimization techniques, such as model offloading, are covered in [this guide](./optimization/fp16). diff --git a/docs/source/en/training/adapt_a_model.mdx b/docs/source/en/training/adapt_a_model.mdx new file mode 100644 index 000000000000..f1af5fca57a2 --- /dev/null +++ b/docs/source/en/training/adapt_a_model.mdx @@ -0,0 +1,42 @@ +# Adapt a model to a new task + +Many diffusion systems share the same components, allowing you to adapt a pretrained model for one task to an entirely different task. + +This guide will show you how to adapt a pretrained text-to-image model for inpainting by initializing and modifying the architecture of a pretrained [`UNet2DConditionModel`]. + +## Configure UNet2DConditionModel parameters + +A [`UNet2DConditionModel`] by default accepts 4 channels in the [input sample](https://huggingface.co/docs/diffusers/v0.16.0/en/api/models#diffusers.UNet2DConditionModel.in_channels). For example, load a pretrained text-to-image model like [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) and take a look at the number of `in_channels`: + +```py +from diffusers import StableDiffusionPipeline + +pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") +pipeline.unet.config["in_channels"] +4 +``` + +Inpainting requires 9 channels in the input sample. You can check this value in a pretrained inpainting model like [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting): + +```py +from diffusers import StableDiffusionPipeline + +pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting") +pipeline.unet.config["in_channels"] +9 +``` + +To adapt your text-to-image model for inpainting, you'll need to change the number of `in_channels` from 4 to 9. + +Initialize a [`UNet2DConditionModel`] with the pretrained text-to-image model weights, and change `in_channels` to 9. Changing the number of `in_channels` means you need to set `ignore_mismatched_sizes=True` and `low_cpu_mem_usage=False` to avoid a size mismatch error because the shape is different now. + +```py +from diffusers import UNet2DConditionModel + +model_id = "runwayml/stable-diffusion-v1-5" +unet = UNet2DConditionModel.from_pretrained( + model_id, subfolder="unet", in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True +) +``` + +The pretrained weights of the other components from the text-to-image model are initialized from their checkpoints, but the input channel weights (`conv_in.weight`) of the `unet` are randomly initialized. It is important to finetune the model for inpainting because otherwise the model returns noise. diff --git a/docs/source/en/training/controlnet.mdx b/docs/source/en/training/controlnet.mdx index 94e3d969b80a..16a9ba95f057 100644 --- a/docs/source/en/training/controlnet.mdx +++ b/docs/source/en/training/controlnet.mdx @@ -33,7 +33,12 @@ cd diffusers pip install -e . ``` -Then navigate into the example folder and run: +Then navigate into the [example folder](https://github.com/huggingface/diffusers/tree/main/examples/controlnet) +```bash +cd examples/controlnet +``` + +Now run: ```bash pip install -r requirements.txt ``` @@ -64,6 +69,8 @@ The original dataset is hosted in the ControlNet [repo](https://huggingface.co/l Our training examples use [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) because that is what the original set of ControlNet models was trained on. However, ControlNet can be trained to augment any compatible Stable Diffusion model (such as [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4)) or [`stabilityai/stable-diffusion-2-1`](https://huggingface.co/stabilityai/stable-diffusion-2-1). +To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. + ## Training Download the following images to condition our training with: @@ -74,7 +81,9 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png ``` -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. + +The training script creates and saves a `diffusion_pytorch_model.bin` file in your repository. ```bash export MODEL_DIR="runwayml/stable-diffusion-v1-5" @@ -88,7 +97,8 @@ accelerate launch train_controlnet.py \ --learning_rate=1e-5 \ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ - --train_batch_size=4 + --train_batch_size=4 \ + --push_to_hub ``` This default configuration requires ~38GB VRAM. @@ -111,7 +121,8 @@ accelerate launch train_controlnet.py \ --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ --train_batch_size=1 \ - --gradient_accumulation_steps=4 + --gradient_accumulation_steps=4 \ + --push_to_hub ``` ## Training with multiple GPUs @@ -134,7 +145,8 @@ accelerate launch --mixed_precision="fp16" --multi_gpu train_controlnet.py \ --train_batch_size=4 \ --mixed_precision="fp16" \ --tracker_project_name="controlnet-demo" \ - --report_to=wandb + --report_to=wandb \ + --push_to_hub ``` ## Example results @@ -182,7 +194,8 @@ accelerate launch train_controlnet.py \ --train_batch_size=1 \ --gradient_accumulation_steps=4 \ --gradient_checkpointing \ - --use_8bit_adam + --use_8bit_adam \ + --push_to_hub ``` ## Training on a 12 GB GPU @@ -210,7 +223,8 @@ accelerate launch train_controlnet.py \ --gradient_checkpointing \ --use_8bit_adam \ --enable_xformers_memory_efficient_attention \ - --set_grads_to_none + --set_grads_to_none \ + --push_to_hub ``` When using `enable_xformers_memory_efficient_attention`, please make sure to install `xformers` by `pip install xformers`. @@ -274,7 +288,8 @@ accelerate launch train_controlnet.py \ --gradient_checkpointing \ --enable_xformers_memory_efficient_attention \ --set_grads_to_none \ - --mixed_precision fp16 + --mixed_precision fp16 \ + --push_to_hub ``` ## Inference diff --git a/docs/source/en/training/create_dataset.mdx b/docs/source/en/training/create_dataset.mdx new file mode 100644 index 000000000000..9c4f4de53904 --- /dev/null +++ b/docs/source/en/training/create_dataset.mdx @@ -0,0 +1,90 @@ +# Create a dataset for training + +There are many datasets on the [Hub](https://huggingface.co/datasets?task_categories=task_categories:text-to-image&sort=downloads) to train a model on, but if you can't find one you're interested in or want to use your own, you can create a dataset with the 🤗 [Datasets](hf.co/docs/datasets) library. The dataset structure depends on the task you want to train your model on. The most basic dataset structure is a directory of images for tasks like unconditional image generation. Another dataset structure may be a directory of images and a text file containing their corresponding text captions for tasks like text-to-image generation. + +This guide will show you two ways to create a dataset to finetune on: + +- provide a folder of images to the `--train_data_dir` argument +- upload a dataset to the Hub and pass the dataset repository id to the `--dataset_name` argument + + + +💡 Learn more about how to create an image dataset for training in the [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset) guide. + + + +## Provide a dataset as a folder + +For unconditional generation, you can provide your own dataset as a folder of images. The training script uses the [`ImageFolder`](https://huggingface.co/docs/datasets/en/image_dataset#imagefolder) builder from 🤗 Datasets to automatically build a dataset from the folder. Your directory structure should look like: + +```bash +data_dir/xxx.png +data_dir/xxy.png +data_dir/[...]/xxz.png +``` + +Pass the path to the dataset directory to the `--train_data_dir` argument, and then you can start training: + +```bash +accelerate launch train_unconditional.py \ + --train_data_dir \ + +``` + +## Upload your data to the Hub + + + +💡 For more details and context about creating and uploading a dataset to the Hub, take a look at the [Image search with 🤗 Datasets](https://huggingface.co/blog/image-search-datasets) post. + + + +Start by creating a dataset with the [`ImageFolder`](https://huggingface.co/docs/datasets/image_load#imagefolder) feature, which creates an `image` column containing the PIL-encoded images. + +You can use the `data_dir` or `data_files` parameters to specify the location of the dataset. The `data_files` parameter supports mapping specific files to dataset splits like `train` or `test`: + +```python +from datasets import load_dataset + +# example 1: local folder +dataset = load_dataset("imagefolder", data_dir="path_to_your_folder") + +# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset("imagefolder", data_files="path_to_zip_file") + +# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd) +dataset = load_dataset( + "imagefolder", + data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip", +) + +# example 4: providing several splits +dataset = load_dataset( + "imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]} +) +``` + +Then use the [`~datasets.Dataset.push_to_hub`] method to upload the dataset to the Hub: + +```python +# assuming you have ran the huggingface-cli login command in a terminal +dataset.push_to_hub("name_of_your_dataset") + +# if you want to push to a private repo, simply pass private=True: +dataset.push_to_hub("name_of_your_dataset", private=True) +``` + +Now the dataset is available for training by passing the dataset name to the `--dataset_name` argument: + +```bash +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --dataset_name="name_of_your_dataset" \ + +``` + +## Next steps + +Now that you've created a dataset, you can plug it into the `train_data_dir` (if your dataset is local) or `dataset_name` (if your dataset is on the Hub) arguments of a training script. + +For your next steps, feel free to try and use your dataset to train a model for [unconditional generation](uncondtional_training) or [text-to-image generation](text2image)! \ No newline at end of file diff --git a/docs/source/en/training/custom_diffusion.mdx b/docs/source/en/training/custom_diffusion.mdx index 08604f101ea2..ffee456de41f 100644 --- a/docs/source/en/training/custom_diffusion.mdx +++ b/docs/source/en/training/custom_diffusion.mdx @@ -33,7 +33,13 @@ cd diffusers pip install -e . ``` -Then cd in the example folder and run +Then cd into the [example folder](https://github.com/huggingface/diffusers/tree/main/examples/custom_diffusion) + +``` +cd examples/custom_diffusion +``` + +Now run ```bash pip install -r requirements.txt @@ -61,7 +67,7 @@ write_basic_config() ``` ### Cat example 😺 -Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it. +Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`. The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training. @@ -73,6 +79,8 @@ python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --nu **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___** +The script creates and saves model checkpoints and a `pytorch_custom_diffusion_weights.bin` file in your repository. + ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" export OUTPUT_DIR="path-to-save-model" @@ -92,7 +100,8 @@ accelerate launch train_custom_diffusion.py \ --lr_warmup_steps=0 \ --max_train_steps=250 \ --scale_lr --hflip \ - --modifier_token "" + --modifier_token "" \ + --push_to_hub ``` **Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.** @@ -124,7 +133,8 @@ accelerate launch train_custom_diffusion.py \ --scale_lr --hflip \ --modifier_token "" \ --validation_prompt=" cat sitting in a bucket" \ - --report_to="wandb" + --report_to="wandb" \ + --push_to_hub ``` Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details. @@ -160,7 +170,8 @@ accelerate launch train_custom_diffusion.py \ --max_train_steps=500 \ --num_class_images=200 \ --scale_lr --hflip \ - --modifier_token "+" + --modifier_token "+" \ + --push_to_hub ``` Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details. @@ -199,7 +210,8 @@ accelerate launch train_custom_diffusion.py \ --scale_lr --hflip --noaug \ --freeze_model crossattn \ --modifier_token "" \ - --enable_xformers_memory_efficient_attention + --enable_xformers_memory_efficient_attention \ + --push_to_hub ``` ## Inference diff --git a/docs/source/en/training/distributed_inference.mdx b/docs/source/en/training/distributed_inference.mdx new file mode 100644 index 000000000000..e85b3f11e238 --- /dev/null +++ b/docs/source/en/training/distributed_inference.mdx @@ -0,0 +1,91 @@ +# Distributed inference with multiple GPUs + +On distributed setups, you can run inference across multiple GPUs with 🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) or [PyTorch Distributed](https://pytorch.org/tutorials/beginner/dist_overview.html), which is useful for generating with multiple prompts in parallel. + +This guide will show you how to use 🤗 Accelerate and PyTorch Distributed for distributed inference. + +## 🤗 Accelerate + +🤗 [Accelerate](https://huggingface.co/docs/accelerate/index) is a library designed to make it easy to train or run inference across distributed setups. It simplifies the process of setting up the distributed environment, allowing you to focus on your PyTorch code. + +To begin, create a Python file and initialize an [`accelerate.PartialState`] to create a distributed environment; your setup is automatically detected so you don't need to explicitly define the `rank` or `world_size`. Move the [`DiffusionPipeline`] to `distributed_state.device` to assign a GPU to each process. + +Now use the [`~accelerate.PartialState.split_between_processes`] utility as a context manager to automatically distribute the prompts between the number of processes. + +```py +from accelerate import PartialState +from diffusers import DiffusionPipeline + +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +distributed_state = PartialState() +pipeline.to(distributed_state.device) + +with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt: + result = pipeline(prompt).images[0] + result.save(f"result_{distributed_state.process_index}.png") +``` + +Use the `--num_processes` argument to specify the number of GPUs to use, and call `accelerate launch` to run the script: + +```bash +accelerate launch run_distributed.py --num_processes=2 +``` + + + +To learn more, take a look at the [Distributed Inference with 🤗 Accelerate](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) guide. + + + +## PyTorch Distributed + +PyTorch supports [`DistributedDataParallel`](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) which enables data parallelism. + +To start, create a Python file and import `torch.distributed` and `torch.multiprocessing` to set up the distributed process group and to spawn the processes for inference on each GPU. You should also initialize a [`DiffusionPipeline`]: + +```py +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from diffusers import DiffusionPipeline + +sd = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) +``` + +You'll want to create a function to run inference; [`init_process_group`](https://pytorch.org/docs/stable/distributed.html?highlight=init_process_group#torch.distributed.init_process_group) handles creating a distributed environment with the type of backend to use, the `rank` of the current process, and the `world_size` or the number of processes participating. If you're running inference in parallel over 2 GPUs, then the `world_size` is 2. + +Move the [`DiffusionPipeline`] to `rank` and use `get_rank` to assign a GPU to each process, where each process handles a different prompt: + +```py +def run_inference(rank, world_size): + dist.init_process_group("nccl", rank=rank, world_size=world_size) + + sd.to(rank) + + if torch.distributed.get_rank() == 0: + prompt = "a dog" + elif torch.distributed.get_rank() == 1: + prompt = "a cat" + + image = sd(prompt).images[0] + image.save(f"./{'_'.join(prompt)}.png") +``` + +To run the distributed inference, call [`mp.spawn`](https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn) to run the `run_inference` function on the number of GPUs defined in `world_size`: + +```py +def main(): + world_size = 2 + mp.spawn(run_inference, args=(world_size,), nprocs=world_size, join=True) + + +if __name__ == "__main__": + main() +``` + +Once you've completed the inference script, use the `--nproc_per_node` argument to specify the number of GPUs to use and call `torchrun` to run the script: + +```bash +torchrun run_distributed.py --nproc_per_node=2 +``` \ No newline at end of file diff --git a/docs/source/en/training/dreambooth.mdx b/docs/source/en/training/dreambooth.mdx index c5a5a047d114..c26762d4a75d 100644 --- a/docs/source/en/training/dreambooth.mdx +++ b/docs/source/en/training/dreambooth.mdx @@ -64,6 +64,8 @@ snapshot_download( ) ``` +To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. + ## Finetuning @@ -76,7 +78,7 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit. Set the `INSTANCE_DIR` environment variable to the path of the directory containing the dog images. -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`] argument. The `instance_prompt` argument is a text prompt that contains a unique identifier, such as `sks`, and the class the image belongs to, which in this example is `a photo of a sks dog`. ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" @@ -98,7 +100,8 @@ accelerate launch train_dreambooth.py \ --learning_rate=5e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ - --max_train_steps=400 + --max_train_steps=400 \ + --push_to_hub ``` @@ -110,7 +113,7 @@ Before running the script, make sure you have the requirements installed: pip install -U -r requirements.txt ``` -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`] argument. The `instance_prompt` argument is a text prompt that contains a unique identifier, such as `sks`, and the class the image belongs to, which in this example is `a photo of a sks dog`. Now you can launch the training script with the following command: @@ -127,7 +130,8 @@ python train_dreambooth_flax.py \ --resolution=512 \ --train_batch_size=1 \ --learning_rate=5e-6 \ - --max_train_steps=400 + --max_train_steps=400 \ + --push_to_hub ``` @@ -161,7 +165,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -183,7 +188,8 @@ python train_dreambooth_flax.py \ --train_batch_size=1 \ --learning_rate=5e-6 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -219,13 +225,14 @@ accelerate launch train_dreambooth.py \ --class_prompt="a photo of dog" \ --resolution=512 \ --train_batch_size=1 \ - --use_8bit_adam + --use_8bit_adam \ --gradient_checkpointing \ --learning_rate=2e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -248,7 +255,8 @@ python train_dreambooth_flax.py \ --train_batch_size=1 \ --learning_rate=2e-6 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -387,7 +395,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` ### 12GB GPU @@ -418,7 +427,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` ### 8 GB GPU @@ -464,7 +474,8 @@ accelerate launch train_dreambooth.py \ --lr_warmup_steps=0 \ --num_class_images=200 \ --max_train_steps=800 \ - --mixed_precision=fp16 + --mixed_precision=fp16 \ + --push_to_hub ``` ## Inference @@ -488,3 +499,207 @@ image.save("dog-bucket.png") ``` You may also run inference from any of the [saved training checkpoints](#inference-from-a-saved-checkpoint). + +## IF + +You can use the lora and full dreambooth scripts to train the text to image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and the stage II upscaler +[IF model](https://huggingface.co/DeepFloyd/IF-II-L-v1.0). + +Note that IF has a predicted variance, and our finetuning scripts only train the models predicted error, so for finetuned IF models we switch to a fixed +variance schedule. The full finetuning scripts will update the scheduler config for the full saved model. However, when loading saved LoRA weights, you +must also update the pipeline's scheduler config. + +```py +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0") + +pipe.load_lora_weights("") + +# Update scheduler config to fixed variance schedule +pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small") +``` + +Additionally, a few alternative cli flags are needed for IF. + +`--resolution=64`: IF is a pixel space diffusion model. In order to operate on un-compressed pixels, the input images are of a much smaller resolution. + +`--pre_compute_text_embeddings`: IF uses [T5](https://huggingface.co/docs/transformers/model_doc/t5) for its text encoder. In order to save GPU memory, we pre compute all text embeddings and then de-allocate +T5. + +`--tokenizer_max_length=77`: T5 has a longer default text length, but the default IF encoding procedure uses a smaller number. + +`--text_encoder_use_attention_mask`: T5 passes the attention mask to the text encoder. + +### Tips and Tricks +We find LoRA to be sufficient for finetuning the stage I model as the low resolution of the model makes representing finegrained detail hard regardless. + +For common and/or not-visually complex object concepts, you can get away with not-finetuning the upscaler. Just be sure to adjust the prompt passed to the +upscaler to remove the new token from the instance prompt. I.e. if your stage I prompt is "a sks dog", use "a dog" for your stage II prompt. + +For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than +LoRA finetuning stage II. + +For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best. + +For stage II, we find that lower learning rates are also needed. + +We found experimentally that the DDPM scheduler with the default larger number of denoising steps to sometimes work better than the DPM Solver scheduler +used in the training scripts. + +### Stage II additional validation images + +The stage II validation requires images to upscale, we can download a downsized version of the training set: + +```py +from huggingface_hub import snapshot_download + +local_dir = "./dog_downsized" +snapshot_download( + "diffusers/dog-example-downsized", + local_dir=local_dir, + repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +### IF stage I LoRA Dreambooth +This training configuration requires ~28 GB VRAM. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_lora" + +accelerate launch train_dreambooth_lora.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=64 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --scale_lr \ + --max_train_steps=1200 \ + --validation_prompt="a sks dog" \ + --validation_epochs=25 \ + --checkpointing_steps=100 \ + --pre_compute_text_embeddings \ + --tokenizer_max_length=77 \ + --text_encoder_use_attention_mask +``` + +### IF stage II LoRA Dreambooth + +`--validation_images`: These images are upscaled during validation steps. + +`--class_labels_conditioning=timesteps`: Pass additional conditioning to the UNet needed for stage II. + +`--learning_rate=1e-6`: Lower learning rate than stage I. + +`--resolution=256`: The upscaler expects higher resolution inputs + +```sh +export MODEL_NAME="DeepFloyd/IF-II-L-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_upscale" +export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png" + +python train_dreambooth_lora.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=256 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-6 \ + --max_train_steps=2000 \ + --validation_prompt="a sks dog" \ + --validation_epochs=100 \ + --checkpointing_steps=500 \ + --pre_compute_text_embeddings \ + --tokenizer_max_length=77 \ + --text_encoder_use_attention_mask \ + --validation_images $VALIDATION_IMAGES \ + --class_labels_conditioning=timesteps +``` + +### IF Stage I Full Dreambooth +`--skip_save_text_encoder`: When training the full model, this will skip saving the entire T5 with the finetuned model. You can still load the pipeline +with a T5 loaded from the original model. + +`use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam. + +`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is +likely the learning rate can be increased with larger batch sizes. + +Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" + +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_if" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=64 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-7 \ + --max_train_steps=150 \ + --validation_prompt "a photo of sks dog" \ + --validation_steps 25 \ + --text_encoder_use_attention_mask \ + --tokenizer_max_length 77 \ + --pre_compute_text_embeddings \ + --use_8bit_adam \ + --set_grads_to_none \ + --skip_save_text_encoder \ + --push_to_hub +``` + +### IF Stage II Full Dreambooth + +`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as +1e-8. + +`--resolution=256`: The upscaler expects higher resolution inputs + +`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with +faces required large effective batch sizes. + +```sh +export MODEL_NAME="DeepFloyd/IF-II-L-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_upscale" +export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png" + +accelerate launch train_dreambooth.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=256 \ + --train_batch_size=2 \ + --gradient_accumulation_steps=6 \ + --learning_rate=5e-6 \ + --max_train_steps=2000 \ + --validation_prompt="a sks dog" \ + --validation_steps=150 \ + --checkpointing_steps=500 \ + --pre_compute_text_embeddings \ + --tokenizer_max_length=77 \ + --text_encoder_use_attention_mask \ + --validation_images $VALIDATION_IMAGES \ + --class_labels_conditioning timesteps \ + --push_to_hub +``` diff --git a/docs/source/en/training/instructpix2pix.mdx b/docs/source/en/training/instructpix2pix.mdx index ff34ec335656..03ba8f5635d6 100644 --- a/docs/source/en/training/instructpix2pix.mdx +++ b/docs/source/en/training/instructpix2pix.mdx @@ -24,7 +24,7 @@ The output is an "edited" image that reflects the edit instruction applied on th instructpix2pix-output

-The `train_instruct_pix2pix.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. +The `train_instruct_pix2pix.py` script (you can find the it [here](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py)) shows how to implement the training procedure and adapt it for Stable Diffusion. ***Disclaimer: Even though `train_instruct_pix2pix.py` implements the InstructPix2Pix training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix) we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.*** @@ -44,7 +44,12 @@ cd diffusers pip install -e . ``` -Then cd in the example folder and run +Then cd in the example folder +```bash +cd examples/instruct_pix2pix +``` + +Now run ```bash pip install -r requirements.txt ``` @@ -72,16 +77,16 @@ write_basic_config() ### Toy example As mentioned before, we'll use a [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) for training. The dataset -is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper. +is a smaller version of the [original dataset](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) used in the InstructPix2Pix paper. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. You'll also need to specify the dataset name in `DATASET_ID`: +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. You'll also need to specify the dataset name in `DATASET_ID`: ```bash export MODEL_NAME="runwayml/stable-diffusion-v1-5" export DATASET_ID="fusing/instructpix2pix-1000-samples" ``` -Now, we can launch training: +Now, we can launch training. The script saves all the components (`feature_extractor`, `scheduler`, `text_encoder`, `unet`, etc) in a subfolder in your repository. ```bash accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ @@ -95,7 +100,8 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ --learning_rate=5e-05 --max_grad_norm=1 --lr_warmup_steps=0 \ --conditioning_dropout_prob=0.05 \ --mixed_precision=fp16 \ - --seed=42 + --seed=42 \ + --push_to_hub ``` Additionally, we support performing validation inference to monitor training progress @@ -116,7 +122,8 @@ accelerate launch --mixed_precision="fp16" train_instruct_pix2pix.py \ --val_image_url="https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png" \ --validation_prompt="make the mountains snowy" \ --seed=42 \ - --report_to=wandb + --report_to=wandb \ + --push_to_hub ``` We recommend this type of validation as it can be useful for model debugging. Note that you need `wandb` installed to use this. You can install `wandb` by running `pip install wandb`. @@ -143,7 +150,8 @@ accelerate launch --mixed_precision="fp16" --multi_gpu train_instruct_pix2pix.py --learning_rate=5e-05 --lr_warmup_steps=0 \ --conditioning_dropout_prob=0.05 \ --mixed_precision=fp16 \ - --seed=42 + --seed=42 \ + --push_to_hub ``` ## Inference @@ -199,3 +207,5 @@ speed and quality during performance: Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example). + +If you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd). \ No newline at end of file diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index 7e3c3c0b2b68..1208178810a5 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -17,8 +17,7 @@ specific language governing permissions and limitations under the License. Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`]. We also -support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity. For more details on how we support -LoRA fine-tuning of the text encoder, refer to the discussion on [this PR](https://github.com/huggingface/diffusers/pull/2918). +support fine-tuning the text encoder for DreamBooth with LoRA in a limited capacity. Fine-tuning the text encoder for DreamBooth generally yields better results, but it can increase compute usage. @@ -52,7 +51,7 @@ Finetuning a model like Stable Diffusion, which has billions of parameters, can Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon. -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. You'll also need to set the `DATASET_NAME` environment variable to the name of the dataset you want to train on. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. You'll also need to set the `DATASET_NAME` environment variable to the name of the dataset you want to train on. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables are optional and specify where to save the model to on the Hub: @@ -69,7 +68,7 @@ There are some flags to be aware of before you start training: * `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)). * `--learning_rate=1e-04`, you can afford to use a higher learning rate than you normally would with LoRA. -Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)): +Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)). Training takes about 5 hours on a 2080 Ti GPU with 11GB of RAM, and it'll create and save model checkpoints and the `pytorch_lora_weights` in your repository. ```bash accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \ @@ -115,7 +114,7 @@ Load the LoRA weights from your finetuned model *on top of the base model weight ```py ->>> pipe.unet.load_attn_procs(model_path) +>>> pipe.unet.load_attn_procs(lora_model_path) >>> pipe.to("cuda") # use half the weights from the LoRA finetuned model and half the weights from the base model @@ -128,6 +127,26 @@ Load the LoRA weights from your finetuned model *on top of the base model weight >>> image.save("blue_pokemon.png") ``` + + +If you are loading the LoRA parameters from the Hub and if the Hub repository has +a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then +you can do: + +```py +from huggingface_hub.repocard import RepoCard + +lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +... +``` + + + + ## DreamBooth [DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type). @@ -140,9 +159,9 @@ Load the LoRA weights from your finetuned model *on top of the base model weight ### Training[[dreambooth-training]] -Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) with DreamBooth and LoRA with some 🐶 [dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ). Download and save these images to a directory. +Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) with DreamBooth and LoRA with some 🐶 [dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ). Download and save these images to a directory. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. -To start, specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. You'll also need to set `INSTANCE_DIR` to the path of the directory containing the images. +To start, specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. You'll also need to set `INSTANCE_DIR` to the path of the directory containing the images. The `OUTPUT_DIR` variables is optional and specifies where to save the model to on the Hub: @@ -158,7 +177,11 @@ There are some flags to be aware of before you start training: * `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)). * `--learning_rate=1e-04`, you can afford to use a higher learning rate than you normally would with LoRA. -Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)): +Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)). The script creates and saves model checkpoints and the `pytorch_lora_weights.bin` file in your repository. + +It's also possible to additionally fine-tune the text encoder with LoRA. This, in most cases, leads +to better results with a slight increase in the compute. To allow fine-tuning the text encoder with LoRA, +specify the `--train_text_encoder` while launching the `train_dreambooth_lora.py` script. ```bash accelerate launch train_dreambooth_lora.py \ @@ -179,12 +202,7 @@ accelerate launch train_dreambooth_lora.py \ --validation_epochs=50 \ --seed="0" \ --push_to_hub -``` - -It's also possible to additionally fine-tune the text encoder with LoRA. This, in most cases, leads -to better results with a slight increase in the compute. To allow fine-tuning the text encoder with LoRA, -specify the `--train_text_encoder` while launching the `train_dreambooth_lora.py` script. - +``` ### Inference[[dreambooth-inference]] @@ -208,7 +226,7 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m ```py ->>> pipe.unet.load_attn_procs(model_path) +>>> pipe.unet.load_attn_procs(lora_model_path) >>> pipe.to("cuda") # use half the weights from the LoRA finetuned model and half the weights from the base model @@ -222,4 +240,115 @@ Load the LoRA weights from your finetuned DreamBooth model *on top of the base m >>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0] >>> image.save("bucket-dog.png") +``` + +If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA +weights. For example: + +```python +from huggingface_hub.repocard import RepoCard +from diffusers import StableDiffusionPipeline +import torch + +lora_model_id = "sayakpaul/dreambooth-text-encoder-test" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +pipe = pipe.to("cuda") +pipe.load_lora_weights(lora_model_id) +image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] +``` + + + +If your LoRA parameters involve the UNet as well as the Text Encoder, then passing +`cross_attention_kwargs={"scale": 0.5}` will apply the `scale` value to both the UNet +and the Text Encoder. + + + +Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is preferred to [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] for loading LoRA parameters. This is because +[`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] can handle the following situations: + +* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do: + + ```py + pipe.load_lora_weights(lora_model_path) + ``` + +* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). + +**Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs, +refer to the respective docstrings. + +## Supporting A1111 themed LoRA checkpoints from Diffusers + +To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted +LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity. +In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/) +in Diffusers and perform inference with it. + +First, download a checkpoint. We'll use +[this one](https://civitai.com/models/13239/light-and-shadow) for demonstration purposes. + +```bash +wget https://civitai.com/api/download/models/15603 -O light_and_shadow.safetensors +``` + +Next, we initialize a [`~DiffusionPipeline`]: + +```python +import torch + +from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler + +pipeline = StableDiffusionPipeline.from_pretrained( + "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None +).to("cuda") +pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, use_karras_sigmas=True +) +``` + +We then load the checkpoint downloaded from CivitAI: + +```python +pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors") +``` + + + +If you're loading a checkpoint in the `safetensors` format, please ensure you have `safetensors` installed. + + + +And then it's time for running inference: + +```python +prompt = "masterpiece, best quality, 1girl, at dusk" +negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), " + "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts") + +images = pipeline(prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + num_inference_steps=15, + num_images_per_prompt=4, + generator=torch.manual_seed(0) +).images +``` + +Below is a comparison between the LoRA and the non-LoRA results: + +![lora_non_lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_non_lora_comparison.png) + +You have a similar checkpoint stored on the Hugging Face Hub, you can load it +directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so: + +```python +lora_model_id = "sayakpaul/civitai-light-shadow-lora" +lora_filename = "light_and_shadow.safetensors" +pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename) ``` \ No newline at end of file diff --git a/docs/source/en/training/text2image.mdx b/docs/source/en/training/text2image.mdx index dabb68397f78..eb8a120c0211 100644 --- a/docs/source/en/training/text2image.mdx +++ b/docs/source/en/training/text2image.mdx @@ -74,15 +74,27 @@ To load a checkpoint to resume training, pass the argument `--resume_from_checkp Launch the [PyTorch training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) for a fine-tuning run on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset like this. -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. - -{"path": "../../../../examples/text_to_image/README.md", -"language": "bash", -"start-after": "accelerate_snippet_start", -"end-before": "accelerate_snippet_end", -"dedent": 0} - +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export dataset_name="lambdalabs/pokemon-blip-captions" + +accelerate launch --mixed_precision="fp16" train_text_to_image.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$dataset_name \ + --use_ema \ + --resolution=512 --center_crop --random_flip \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --max_train_steps=15000 \ + --learning_rate=1e-05 \ + --max_grad_norm=1 \ + --lr_scheduler="constant" --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" \ + --push_to_hub +``` To finetune on your own dataset, prepare the dataset according to the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub), or you can [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder). @@ -105,8 +117,10 @@ accelerate launch train_text_to_image.py \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ - --lr_scheduler="constant" --lr_warmup_steps=0 \ - --output_dir=${OUTPUT_DIR} + --lr_scheduler="constant" + --lr_warmup_steps=0 \ + --output_dir=${OUTPUT_DIR} \ + --push_to_hub ``` #### Training with multiple GPUs @@ -129,8 +143,10 @@ accelerate launch --mixed_precision="fp16" --multi_gpu train_text_to_image.py \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ - --lr_scheduler="constant" --lr_warmup_steps=0 \ - --output_dir="sd-pokemon-model" + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --output_dir="sd-pokemon-model" \ + --push_to_hub ``` @@ -143,7 +159,7 @@ Before running the script, make sure you have the requirements installed: pip install -U -r requirements_flax.txt ``` -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. Now you can launch the [Flax training script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py) like this: @@ -159,7 +175,8 @@ python train_text_to_image_flax.py \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ - --output_dir="sd-pokemon-model" + --output_dir="sd-pokemon-model" \ + --push_to_hub ``` To finetune on your own dataset, prepare the dataset according to the format required by 🤗 [Datasets](https://huggingface.co/docs/datasets/index). You can [upload your dataset to the Hub](https://huggingface.co/docs/datasets/image_dataset#upload-dataset-to-the-hub), or you can [prepare a local folder with your files](https://huggingface.co/docs/datasets/image_dataset#imagefolder). @@ -179,7 +196,8 @@ python train_text_to_image_flax.py \ --max_train_steps=15000 \ --learning_rate=1e-05 \ --max_grad_norm=1 \ - --output_dir="sd-pokemon-model" + --output_dir="sd-pokemon-model" \ + --push_to_hub ``` diff --git a/docs/source/en/training/text_inversion.mdx b/docs/source/en/training/text_inversion.mdx index 4cbab9886045..a4fe4c2c4e5b 100644 --- a/docs/source/en/training/text_inversion.mdx +++ b/docs/source/en/training/text_inversion.mdx @@ -81,7 +81,7 @@ To resume training from a saved checkpoint, pass the following argument to the t ## Finetuning -For your training dataset, download these [images of a cat toy](https://huggingface.co/datasets/diffusers/cat_toy_example) and store them in a directory: +For your training dataset, download these [images of a cat toy](https://huggingface.co/datasets/diffusers/cat_toy_example) and store them in a directory. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. ```py from huggingface_hub import snapshot_download @@ -92,9 +92,9 @@ snapshot_download( ) ``` -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument, and the `DATA_DIR` environment variable to the path of the directory containing the images. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument, and the `DATA_DIR` environment variable to the path of the directory containing the images. -Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py): +Now you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py). The script creates and saves the following files to your repository: `learned_embeds.bin`, `token_identifier.txt`, and `type_of_concept.txt`. @@ -120,7 +120,8 @@ accelerate launch textual_inversion.py \ --learning_rate=5.0e-04 --scale_lr \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ - --output_dir="textual_inversion_cat" + --output_dir="textual_inversion_cat" \ + --push_to_hub ``` @@ -144,7 +145,7 @@ Before you begin, make sure you install the Flax specific dependencies: pip install -U -r requirements_flax.txt ``` -Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`~diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path`] argument. +Specify the `MODEL_NAME` environment variable (either a Hub model repository id or a path to the directory containing the model weights) and pass it to the [`pretrained_model_name_or_path`](https://huggingface.co/docs/diffusers/en/api/diffusion_pipeline#diffusers.DiffusionPipeline.from_pretrained.pretrained_model_name_or_path) argument. Then you can launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion_flax.py): @@ -161,7 +162,8 @@ python textual_inversion_flax.py \ --train_batch_size=1 \ --max_train_steps=3000 \ --learning_rate=5.0e-04 --scale_lr \ - --output_dir="textual_inversion_cat" + --output_dir="textual_inversion_cat" \ + --push_to_hub ``` @@ -245,7 +247,7 @@ from flax.training.common_utils import shard from diffusers import FlaxStableDiffusionPipeline model_path = "path-to-your-trained-model" -pipe, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16) +pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(model_path, dtype=jax.numpy.bfloat16) prompt = "A backpack" prng_seed = jax.random.PRNGKey(0) diff --git a/docs/source/en/training/unconditional_training.mdx b/docs/source/en/training/unconditional_training.mdx index 514932d4b22d..7a588cc4cc63 100644 --- a/docs/source/en/training/unconditional_training.mdx +++ b/docs/source/en/training/unconditional_training.mdx @@ -74,7 +74,9 @@ The full training state is saved in a subfolder in the `output_dir` every 500 st ## Finetuning -You're ready to launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) now! Specify the dataset name to finetune on with the `--dataset_name` argument and then save it to the path in `--output_dir`. +You're ready to launch the [training script](https://github.com/huggingface/diffusers/blob/main/examples/unconditional_image_generation/train_unconditional.py) now! Specify the dataset name to finetune on with the `--dataset_name` argument and then save it to the path in `--output_dir`. To use your own dataset, take a look at the [Create a dataset for training](create_dataset) guide. + +The training script creates and saves a `diffusion_pytorch_model.bin` file in your repository. @@ -139,83 +141,6 @@ accelerate launch --mixed_precision="fp16" --multi_gpu train_unconditional.py \ --learning_rate=1e-4 \ --lr_warmup_steps=500 \ --mixed_precision="fp16" \ - --logger="wandb" -``` - -## Finetuning with your own data - -There are two ways to finetune a model on your own dataset: - -- provide your own folder of images to the `--train_data_dir` argument -- upload your dataset to the Hub and pass the dataset repository id to the `--dataset_name` argument. - - - -💡 Learn more about how to create an image dataset for training in the [Create an image dataset](https://huggingface.co/docs/datasets/image_dataset) guide. - - - -Below, we explain both in more detail. - -### Provide the dataset as a folder - -If you provide your own dataset as a folder, the script expects the following directory structure: - -```bash -data_dir/xxx.png -data_dir/xxy.png -data_dir/[...]/xxz.png -``` - -Pass the path to the folder containing the images to the `--train_data_dir` argument and launch the training: - -```bash -accelerate launch train_unconditional.py \ - --train_data_dir \ - -``` - -Internally, the script uses the [`ImageFolder`](https://huggingface.co/docs/datasets/image_load#imagefolder) to automatically build a dataset from the folder. - -### Upload your data to the Hub - - - -💡 For more details and context about creating and uploading a dataset to the Hub, take a look at the [Image search with 🤗 Datasets](https://huggingface.co/blog/image-search-datasets) post. - - - -To upload your dataset to the Hub, you can start by creating one with the [`ImageFolder`](https://huggingface.co/docs/datasets/image_load#imagefolder) feature, which creates an `image` column containing the PIL-encoded images, from 🤗 Datasets: - -```python -from datasets import load_dataset - -# example 1: local folder -dataset = load_dataset("imagefolder", data_dir="path_to_your_folder") - -# example 2: local files (supported formats are tar, gzip, zip, xz, rar, zstd) -dataset = load_dataset("imagefolder", data_files="path_to_zip_file") - -# example 3: remote files (supported formats are tar, gzip, zip, xz, rar, zstd) -dataset = load_dataset( - "imagefolder", - data_files="https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip", -) - -# example 4: providing several splits -dataset = load_dataset( - "imagefolder", data_files={"train": ["path/to/file1", "path/to/file2"], "test": ["path/to/file3", "path/to/file4"]} -) -``` - -Then you can use the [`~datasets.Dataset.push_to_hub`] method to upload it to the Hub: - -```python -# assuming you have ran the huggingface-cli login command in a terminal -dataset.push_to_hub("name_of_your_dataset") - -# if you want to push to a private repo, simply pass private=True: -dataset.push_to_hub("name_of_your_dataset", private=True) -``` - -Now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the Hub. \ No newline at end of file + --logger="wandb" \ + --push_to_hub +``` \ No newline at end of file diff --git a/docs/source/en/tutorials/basic_training.mdx b/docs/source/en/tutorials/basic_training.mdx index 52ce7c71fa68..99221274f745 100644 --- a/docs/source/en/tutorials/basic_training.mdx +++ b/docs/source/en/tutorials/basic_training.mdx @@ -407,9 +407,9 @@ Once training is complete, take a look at the final 🦋 images 🦋 generated b ## Next steps -Unconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques by visiting the [🧨 Diffusers Training Examples](./training/overview) page. Here are some examples of what you can learn: +Unconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques by visiting the [🧨 Diffusers Training Examples](../training/overview) page. Here are some examples of what you can learn: -* [Textual Inversion](./training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into the generated image. -* [DreamBooth](./training/dreambooth), a technique for generating personalized images of a subject given several input images of the subject. -* [Guide](./training/text2image) to finetuning a Stable Diffusion model on your own dataset. -* [Guide](./training/lora) to using LoRA, a memory-efficient technique for finetuning really large models faster. +* [Textual Inversion](../training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into the generated image. +* [DreamBooth](../training/dreambooth), a technique for generating personalized images of a subject given several input images of the subject. +* [Guide](../training/text2image) to finetuning a Stable Diffusion model on your own dataset. +* [Guide](../training/lora) to using LoRA, a memory-efficient technique for finetuning really large models faster. diff --git a/docs/source/en/using-diffusers/conditional_image_generation.mdx b/docs/source/en/using-diffusers/conditional_image_generation.mdx index 0b5c02415d87..195aa2d6c360 100644 --- a/docs/source/en/using-diffusers/conditional_image_generation.mdx +++ b/docs/source/en/using-diffusers/conditional_image_generation.mdx @@ -20,12 +20,12 @@ The [`DiffusionPipeline`] is the easiest way to use a pre-trained diffusion syst Start by creating an instance of [`DiffusionPipeline`] and specify which pipeline [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) you would like to download. -In this guide, you'll use [`DiffusionPipeline`] for text-to-image generation with [Latent Diffusion](https://huggingface.co/CompVis/ldm-text2im-large-256): +In this guide, you'll use [`DiffusionPipeline`] for text-to-image generation with [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5): ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") +>>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") ``` The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. diff --git a/docs/source/en/using-diffusers/controlling_generation.mdx b/docs/source/en/using-diffusers/controlling_generation.mdx index b1ba17cd2c67..57b5640ffcd5 100644 --- a/docs/source/en/using-diffusers/controlling_generation.mdx +++ b/docs/source/en/using-diffusers/controlling_generation.mdx @@ -37,6 +37,28 @@ Unless otherwise mentioned, these are techniques that work with existing models 9. [Textual Inversion](#textual-inversion) 10. [ControlNet](#controlnet) 11. [Prompt Weighting](#prompt-weighting) +12. [Custom Diffusion](#custom-diffusion) +13. [Model Editing](#model-editing) +14. [DiffEdit](#diffedit) + +For convenience, we provide a table to denote which methods are inference-only and which require fine-tuning/training. + +| **Method** | **Inference only** | **Requires training /
fine-tuning** | **Comments** | +|:---:|:---:|:---:|:---:| +| [Instruct Pix2Pix](#instruct-pix2pix) | ✅ | ❌ | Can additionally be
fine-tuned for better
performance on specific
edit instructions. | +| [Pix2Pix Zero](#pix2pixzero) | ✅ | ❌ | | +| [Attend and Excite](#attend-and-excite) | ✅ | ❌ | | +| [Semantic Guidance](#semantic-guidance) | ✅ | ❌ | | +| [Self-attention Guidance](#self-attention-guidance) | ✅ | ❌ | | +| [Depth2Image](#depth2image) | ✅ | ❌ | | +| [MultiDiffusion Panorama](#multidiffusion-panorama) | ✅ | ❌ | | +| [DreamBooth](#dreambooth) | ❌ | ✅ | | +| [Textual Inversion](#textual-inversion) | ❌ | ✅ | | +| [ControlNet](#controlnet) | ✅ | ❌ | A ControlNet can be
trained/fine-tuned on
a custom conditioning. | +| [Prompt Weighting](#prompt-weighting) | ✅ | ❌ | | +| [Custom Diffusion](#custom-diffusion) | ❌ | ✅ | | +| [Model Editing](#model-editing) | ✅ | ❌ | | +| [DiffEdit](#diffedit) | ✅ | ❌ | | ## Instruct Pix2Pix @@ -137,13 +159,13 @@ See [here](../api/pipelines/stable_diffusion/panorama) for more information on h In addition to pre-trained models, Diffusers has training scripts for fine-tuning models on user-provided data. -### DreamBooth +## DreamBooth [DreamBooth](../training/dreambooth) fine-tunes a model to teach it about a new subject. I.e. a few pictures of a person can be used to generate images of that person in different styles. See [here](../training/dreambooth) for more information on how to use it. -### Textual Inversion +## Textual Inversion [Textual Inversion](../training/text_inversion) fine-tunes a model to teach it about a new concept. I.e. a few pictures of a style of artwork can be used to generate images in that style. @@ -165,3 +187,32 @@ Prompt weighting is a simple technique that puts more attention weight on certai input. For a more in-detail explanation and examples, see [here](../using-diffusers/weighted_prompts). + +## Custom Diffusion + +[Custom Diffusion](../training/custom_diffusion) only fine-tunes the cross-attention maps of a pre-trained +text-to-image diffusion model. It also allows for additionally performing textual inversion. It supports +multi-concept training by design. Like DreamBooth and Textual Inversion, Custom Diffusion is also used to +teach a pre-trained text-to-image diffusion model about new concepts to generate outputs involving the +concept(s) of interest. + +For more details, check out our [official doc](../training/custom_diffusion). + +## Model Editing + +[Paper](https://arxiv.org/abs/2303.08084) + +The [text-to-image model editing pipeline](../api/pipelines/stable_diffusion/model_editing) helps you mitigate some of the incorrect implicit assumptions a pre-trained text-to-image +diffusion model might make about the subjects present in the input prompt. For example, if you prompt Stable Diffusion to generate images for "A pack of roses", the roses in the generated images +are more likely to be red. This pipeline helps you change that assumption. + +To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing). + +## DiffEdit + +[Paper](https://arxiv.org/abs/2210.11427) + +[DiffEdit](../api/pipelines/stable_diffusion/diffedit) allows for semantic editing of input images along with +input prompts while preserving the original input images as much as possible. + +To know more details, check out the [official doc](../api/pipelines/stable_diffusion/model_editing). \ No newline at end of file diff --git a/docs/source/en/using-diffusers/inpaint.mdx b/docs/source/en/using-diffusers/inpaint.mdx index 41a6d4b7e1b2..228e14e84833 100644 --- a/docs/source/en/using-diffusers/inpaint.mdx +++ b/docs/source/en/using-diffusers/inpaint.mdx @@ -52,7 +52,7 @@ Now you can create a prompt to replace the mask with something else: ```python prompt = "Face of a yellow cat, high resolution, sitting on a park bench" -image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] +image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image).images[0] ``` `image` | `mask_image` | `prompt` | output | diff --git a/docs/source/en/using-diffusers/kerascv.mdx b/docs/source/en/using-diffusers/kerascv.mdx deleted file mode 100644 index 06981cc8fdd1..000000000000 --- a/docs/source/en/using-diffusers/kerascv.mdx +++ /dev/null @@ -1,179 +0,0 @@ - - -# Using KerasCV Stable Diffusion Checkpoints in Diffusers - - - -This is an experimental feature. - - - -[KerasCV](https://github.com/keras-team/keras-cv/) provides APIs for implementing various computer vision workflows. It -also provides the Stable Diffusion [v1 and v2](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion) -models. Many practitioners find it easy to fine-tune the Stable Diffusion models shipped by KerasCV. However, as of this writing, KerasCV offers limited support to experiment with Stable Diffusion models for inference and deployment. On the other hand, -Diffusers provides tooling dedicated to this purpose (and more), such as different [noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers), [flash attention](https://huggingface.co/docs/diffusers/optimization/xformers), and [other -optimization techniques](https://huggingface.co/docs/diffusers/optimization/fp16). - -How about fine-tuning Stable Diffusion models in KerasCV and exporting them such that they become compatible with Diffusers to combine the -best of both worlds? We have created a [tool](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers) that -lets you do just that! It takes KerasCV Stable Diffusion checkpoints and exports them to Diffusers-compatible checkpoints. -More specifically, it first converts the checkpoints to PyTorch and then wraps them into a -[`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview) which is ready -for inference. Finally, it pushes the converted checkpoints to a repository on the Hugging Face Hub. - -We welcome you to try out the tool [here](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers) -and share feedback via [discussions](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers/discussions/new). - -## Getting Started - -First, you need to obtain the fine-tuned KerasCV Stable Diffusion checkpoints. We provide an -overview of the different ways Stable Diffusion models can be fine-tuned [using `diffusers`](https://huggingface.co/docs/diffusers/training/overview). For the Keras implementation of some of these methods, you can check out these resources: - -* [Teach StableDiffusion new concepts via Textual Inversion](https://keras.io/examples/generative/fine_tune_via_textual_inversion/) -* [Fine-tuning Stable Diffusion](https://keras.io/examples/generative/finetune_stable_diffusion/) -* [DreamBooth](https://keras.io/examples/generative/dreambooth/) -* [Prompt-to-Prompt editing](https://github.com/miguelCalado/prompt-to-prompt-tensorflow) - -Stable Diffusion is comprised of the following models: - -* Text encoder -* UNet -* VAE - -Depending on the fine-tuning task, we may fine-tune one or more of these components (the VAE is almost always left untouched). Here are some common combinations: - -* DreamBooth: UNet and text encoder -* Classical text to image fine-tuning: UNet -* Textual Inversion: Just the newly initialized embeddings in the text encoder - -### Performing the Conversion - -Let's use [this checkpoint](https://huggingface.co/sayakpaul/textual-inversion-kerasio/resolve/main/textual_inversion_kerasio.h5) which was generated -by conducting Textual Inversion with the following "placeholder token": ``. - -On the tool, we supply the following things: - -* Path(s) to download the fine-tuned checkpoint(s) (KerasCV) -* An HF token -* Placeholder token (only applicable for Textual Inversion) - -
- -
- -As soon as you hit "Submit", the conversion process will begin. Once it's complete, you should see the following: - -
- -
- -If you click the [link](https://huggingface.co/sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline/tree/main), you -should see something like so: - -
- -
- -If you head over to the [model card of the repository](https://huggingface.co/sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline), the -following should appear: - -
- -
- - - -Note that we're not specifying the UNet weights here since the UNet is not fine-tuned during Textual Inversion. - - - -And that's it! You now have your fine-tuned KerasCV Stable Diffusion model in Diffusers 🧨. - -## Using the Converted Model in Diffusers - -Just beside the model card of the [repository](https://huggingface.co/sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline), -you'd notice an inference widget to try out the model directly from the UI 🤗 - -
- -
- -On the top right hand side, we provide a "Use in Diffusers" button. If you click the button, you should see the following code-snippet: - -```py -from diffusers import DiffusionPipeline - -pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline") -``` - -The model is in standard `diffusers` format. Let's perform inference! - -```py -from diffusers import DiffusionPipeline - -pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline") -pipeline.to("cuda") - -placeholder_token = "" -prompt = f"two {placeholder_token} getting married, photorealistic, high quality" -image = pipeline(prompt, num_inference_steps=50).images[0] -``` - -And we get: - -
- -
- -_**Note that if you specified a `placeholder_token` while performing the conversion, the tool will log it accordingly. Refer -to the model card of [this repository](https://huggingface.co/sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline) -as an example.**_ - -We welcome you to use the tool for various Stable Diffusion fine-tuning scenarios and let us know your feedback! Here are some examples -of Diffusers checkpoints that were obtained using the tool: - -* [sayakpaul/text-unet-dogs-kerascv_sd_diffusers_pipeline](https://huggingface.co/sayakpaul/text-unet-dogs-kerascv_sd_diffusers_pipeline) (DreamBooth with both the text encoder and UNet fine-tuned) -* [sayakpaul/unet-dogs-kerascv_sd_diffusers_pipeline](https://huggingface.co/sayakpaul/unet-dogs-kerascv_sd_diffusers_pipeline) (DreamBooth with only the UNet fine-tuned) - -## Incorporating Diffusers Goodies 🎁 - -Diffusers provides various options that one can leverage to experiment with different inference setups. One particularly -useful option is the use of a different noise scheduler during inference other than what was used during fine-tuning. -Let's try out the [`DPMSolverMultistepScheduler`](https://huggingface.co/docs/diffusers/main/en/api/schedulers/multistep_dpm_solver) -which is different from the one ([`DDPMScheduler`](https://huggingface.co/docs/diffusers/main/en/api/schedulers/ddpm)) used during -fine-tuning. - -You can read more details about this process in [this section](https://huggingface.co/docs/diffusers/using-diffusers/schedulers). - -```py -from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler - -pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline") -pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) -pipeline.to("cuda") - -placeholder_token = "" -prompt = f"two {placeholder_token} getting married, photorealistic, high quality" -image = pipeline(prompt, num_inference_steps=50).images[0] -``` - -
- -
- -One can also continue fine-tuning from these Diffusers checkpoints by leveraging some relevant tools from Diffusers. Refer [here](https://huggingface.co/docs/diffusers/training/overview) for -more details. For inference-specific optimizations, refer [here](https://huggingface.co/docs/diffusers/main/en/optimization/fp16). - -## Known Limitations - -* Only Stable Diffusion v1 checkpoints are supported for conversion in this tool. diff --git a/docs/source/en/using-diffusers/other-formats.mdx b/docs/source/en/using-diffusers/other-formats.mdx new file mode 100644 index 000000000000..8e606f13469d --- /dev/null +++ b/docs/source/en/using-diffusers/other-formats.mdx @@ -0,0 +1,191 @@ + + +# Load different Stable Diffusion formats + +Stable Diffusion models are available in different formats depending on the framework they're trained and saved with, and where you download them from. Converting these formats for use in 🤗 Diffusers allows you to use all the features supported by the library, such as [using different schedulers](schedulers) for inference, [building your custom pipeline](write_own_pipeline), and a variety of techniques and methods for [optimizing inference speed](./optimization/opt_overview). + + + +We highly recommend using the `.safetensors` format because it is more secure than traditional pickled files which are vulnerable and can be exploited to execute any code on your machine (learn more in the [Load safetensors](using_safetensors) guide). + + + +This guide will show you how to convert other Stable Diffusion formats to be compatible with 🤗 Diffusers. + +## PyTorch .ckpt + +The checkpoint - or `.ckpt` - format is commonly used to store and save models. The `.ckpt` file contains the entire model and is typically several GBs in size. While you can load and use a `.ckpt` file directly with the [`~StableDiffusionPipeline.from_ckpt`] method, it is generally better to convert the `.ckpt` file to 🤗 Diffusers so both formats are available. + +There are two options for converting a `.ckpt` file; use a Space to convert the checkpoint or convert the `.ckpt` file with a script. + +### Convert with a Space + +The easiest and most convenient way to convert a `.ckpt` file is to use the [SD to Diffusers](https://huggingface.co/spaces/diffusers/sd-to-diffusers) Space. You can follow the instructions on the Space to convert the `.ckpt` file. + +This approach works well for basic models, but it may struggle with more customized models. You'll know the Space failed if it returns an empty pull request or error. In this case, you can try converting the `.ckpt` file with a script. + +### Convert with a script + +🤗 Diffusers provides a [conversion script](https://github.com/huggingface/diffusers/blob/main/scripts/convert_original_stable_diffusion_to_diffusers.py) for converting `.ckpt` files. This approach is more reliable than the Space above. + +Before you start, make sure you have a local clone of 🤗 Diffusers to run the script and log in to your Hugging Face account so you can open pull requests and push your converted model to the Hub. + +```bash +huggingface-cli login +``` + +To use the script: + +1. Git clone the repository containing the `.ckpt` file you want to convert. For this example, let's convert this [TemporalNet](https://huggingface.co/CiaraRowles/TemporalNet) `.ckpt` file: + +```bash +git lfs install +git clone https://huggingface.co/CiaraRowles/TemporalNet +``` + +2. Open a pull request on the repository where you're converting the checkpoint from: + +```bash +cd TemporalNet && git fetch origin refs/pr/13:pr/13 +git checkout pr/13 +``` + +3. There are several input arguments to configure in the conversion script, but the most important ones are: + + - `checkpoint_path`: the path to the `.ckpt` file to convert. + - `original_config_file`: a YAML file defining the configuration of the original architecture. If you can't find this file, try searching for the YAML file in the GitHub repository where you found the `.ckpt` file. + - `dump_path`: the path to the converted model. + + For example, you can take the `cldm_v15.yaml` file from the [ControlNet](https://github.com/lllyasviel/ControlNet/tree/main/models) repository because the TemporalNet model is a Stable Diffusion v1.5 and ControlNet model. + +4. Now you can run the script to convert the `.ckpt` file: + +```bash +python ../diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path temporalnetv3.ckpt --original_config_file cldm_v15.yaml --dump_path ./ --controlnet +``` + +5. Once the conversion is done, upload your converted model and test out the resulting [pull request](https://huggingface.co/CiaraRowles/TemporalNet/discussions/13)! + +```bash +git push origin pr/13:refs/pr/13 +``` + +## Keras .pb or .h5 + + + +🧪 This is an experimental feature. Only Stable Diffusion v1 checkpoints are supported by the Convert KerasCV Space at the moment. + + + +[KerasCV](https://keras.io/keras_cv/) supports training for [Stable Diffusion](https://github.com/keras-team/keras-cv/blob/master/keras_cv/models/stable_diffusion) v1 and v2. However, it offers limited support for experimenting with Stable Diffusion models for inference and deployment whereas 🤗 Diffusers has a more complete set of features for this purpose, such as different [noise schedulers](https://huggingface.co/docs/diffusers/using-diffusers/schedulers), [flash attention](https://huggingface.co/docs/diffusers/optimization/xformers), and [other +optimization techniques](https://huggingface.co/docs/diffusers/optimization/fp16). + +The [Convert KerasCV](https://huggingface.co/spaces/sayakpaul/convert-kerascv-sd-diffusers) Space converts `.pb` or `.h5` files to PyTorch, and then wraps them in a [`StableDiffusionPipeline`] so it is ready for inference. The converted checkpoint is stored in a repository on the Hugging Face Hub. + +For this example, let's convert the [`sayakpaul/textual-inversion-kerasio`](https://huggingface.co/sayakpaul/textual-inversion-kerasio/tree/main) checkpoint which was trained with Textual Inversion. It uses the special token `` to personalize images with cats. + +The Convert KerasCV Space allows you to input the following: + +* Your Hugging Face token. +* Paths to download the UNet and text encoder weights from. Depending on how the model was trained, you don't necessarily need to provide the paths to both the UNet and text encoder. For example, Textual Inversion only requires the embeddings from the text encoder and a text-to-image model only requires the UNet weights. +* Placeholder token is only applicable for textual inversion models. +* The `output_repo_prefix` is the name of the repository where the converted model is stored. + +Click the **Submit** button to automatically convert the KerasCV checkpoint! Once the checkpoint is successfully converted, you'll see a link to the new repository containing the converted checkpoint. Follow the link to the new repository, and you'll see the Convert KerasCV Space generated a model card with an inference widget to try out the converted model. + +If you prefer to run inference with code, click on the **Use in Diffusers** button in the upper right corner of the model card to copy and paste the code snippet: + +```py +from diffusers import DiffusionPipeline + +pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline") +``` + +Then you can generate an image like: + +```py +from diffusers import DiffusionPipeline + +pipeline = DiffusionPipeline.from_pretrained("sayakpaul/textual-inversion-cat-kerascv_sd_diffusers_pipeline") +pipeline.to("cuda") + +placeholder_token = "" +prompt = f"two {placeholder_token} getting married, photorealistic, high quality" +image = pipeline(prompt, num_inference_steps=50).images[0] +``` + +## A1111 LoRA files + +[Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) (A1111) is a popular web UI for Stable Diffusion that supports model sharing platforms like [Civitai](https://civitai.com/). Models trained with the Low-Rank Adaptation (LoRA) technique are especially popular because they're fast to train and have a much smaller file size than a fully finetuned model. 🤗 Diffusers supports loading A1111 LoRA checkpoints with [`~loaders.LoraLoaderMixin.load_lora_weights`]: + +```py +from diffusers import DiffusionPipeline, UniPCMultistepScheduler +import torch + +pipeline = DiffusionPipeline.from_pretrained( + "andite/anything-v4.0", torch_dtype=torch.float16, safety_checker=None +).to("cuda") +pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) +``` + +Download a LoRA checkpoint from Civitai; this example uses the [Howls Moving Castle,Interior/Scenery LoRA (Ghibli Stlye)](https://civitai.com/models/14605?modelVersionId=19998) checkpoint, but feel free to try out any LoRA checkpoint! + +```bash +!wget https://civitai.com/api/download/models/19998 -O howls_moving_castle.safetensors +``` + +Load the LoRA checkpoint into the pipeline with the [`~loaders.LoraLoaderMixin.load_lora_weights`] method: + +```py +pipeline.load_lora_weights(".", weight_name="howls_moving_castle.safetensors") +``` + +Now you can use the pipeline to generate images: + +```py +prompt = "masterpiece, illustration, ultra-detailed, cityscape, san francisco, golden gate bridge, california, bay area, in the snow, beautiful detailed starry sky" +negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture" + +images = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=512, + num_inference_steps=25, + num_images_per_prompt=4, + generator=torch.manual_seed(0), +).images +``` + +Finally, create a helper function to display the images: + +```py +from PIL import Image + + +def image_grid(imgs, rows=2, cols=2): + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +image_grid(images) +``` + +
+ +
diff --git a/docs/source/en/using-diffusers/reproducibility.mdx b/docs/source/en/using-diffusers/reproducibility.mdx index 5bef10bfe190..b666dac72cbf 100644 --- a/docs/source/en/using-diffusers/reproducibility.mdx +++ b/docs/source/en/using-diffusers/reproducibility.mdx @@ -111,7 +111,7 @@ print(np.abs(image).sum()) The result is not the same even though you're using an identical seed because the GPU uses a different random number generator than the CPU. -To circumvent this problem, 🧨 Diffusers has a [`randn_tensor`](#diffusers.utils.randn_tensor) function for creating random noise on the CPU, and then moving the tensor to a GPU if necessary. The `randn_tensor` function is used everywhere inside the pipeline, allowing the user to **always** pass a CPU `Generator` even if the pipeline is run on a GPU. +To circumvent this problem, 🧨 Diffusers has a [`~diffusers.utils.randn_tensor`] function for creating random noise on the CPU, and then moving the tensor to a GPU if necessary. The `randn_tensor` function is used everywhere inside the pipeline, allowing the user to **always** pass a CPU `Generator` even if the pipeline is run on a GPU. You'll see the results are much closer now! @@ -147,9 +147,6 @@ susceptible to precision error propagation. Don't expect similar results across different GPU hardware or PyTorch versions. In this case, you'll need to run exactly the same hardware and PyTorch version for full reproducibility. -### randn_tensor -[[autodoc]] diffusers.utils.randn_tensor - ## Deterministic algorithms You can also configure PyTorch to use deterministic algorithms to create a reproducible pipeline. However, you should be aware that deterministic algorithms may be slower than nondeterministic ones and you may observe a decrease in performance. But if reproducibility is important to you, then this is the way to go! diff --git a/docs/source/en/using-diffusers/schedulers.mdx b/docs/source/en/using-diffusers/schedulers.mdx index e17d826c7dab..741d92bdd90d 100644 --- a/docs/source/en/using-diffusers/schedulers.mdx +++ b/docs/source/en/using-diffusers/schedulers.mdx @@ -28,18 +28,15 @@ The following paragraphs show how to do so with the 🧨 Diffusers library. ## Load pipeline -Let's start by loading the stable diffusion pipeline. -Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion. +Let's start by loading the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) model in the [`DiffusionPipeline`]: ```python from huggingface_hub import login from diffusers import DiffusionPipeline import torch -# first we need to login with our access token login() -# Now we can download the pipeline pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) ``` diff --git a/docs/source/en/using-diffusers/textual_inversion_inference.mdx b/docs/source/en/using-diffusers/textual_inversion_inference.mdx new file mode 100644 index 000000000000..9eca3e7e465c --- /dev/null +++ b/docs/source/en/using-diffusers/textual_inversion_inference.mdx @@ -0,0 +1,80 @@ +# Textual inversion + +[[open-in-colab]] + +The [`StableDiffusionPipeline`] supports textual inversion, a technique that enables a model like Stable Diffusion to learn a new concept from just a few sample images. This gives you more control over the generated images and allows you to tailor the model towards specific concepts. You can get started quickly with a collection of community created concepts in the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer). + +This guide will show you how to run inference with textual inversion using a pre-learned concept from the Stable Diffusion Conceptualizer. If you're interested in teaching a model new concepts with textual inversion, take a look at the [Textual Inversion](./training/text_inversion) training guide. + +Login to your Hugging Face account: + +```py +from huggingface_hub import notebook_login + +notebook_login() +``` + +Import the necessary libraries, and create a helper function to visualize the generated images: + +```py +import os +import torch + +import PIL +from PIL import Image + +from diffusers import StableDiffusionPipeline +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid +``` + +Pick a Stable Diffusion checkpoint and a pre-learned concept from the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer): + +```py +pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" +repo_id_embeds = "sd-concepts-library/cat-toy" +``` + +Now you can load a pipeline, and pass the pre-learned concept to it: + +```py +pipeline = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16).to("cuda") + +pipeline.load_textual_inversion(repo_id_embeds) +``` + +Create a prompt with the pre-learned concept by using the special placeholder token ``, and choose the number of samples and rows of images you'd like to generate: + +```py +prompt = "a grafitti in a favela wall with a on it" + +num_samples = 2 +num_rows = 2 +``` + +Then run the pipeline (feel free to adjust the parameters like `num_inference_steps` and `guidance_scale` to see how they affect image quality), save the generated images and visualize them with the helper function you created at the beginning: + +```py +all_images = [] +for _ in range(num_rows): + images = pipe(prompt, num_images_per_prompt=num_samples, num_inference_steps=50, guidance_scale=7.5).images + all_images.extend(images) + +grid = image_grid(all_images, num_samples, num_rows) +grid +``` + +
+ +
diff --git a/docs/source/en/using-diffusers/using_safetensors.mdx b/docs/source/en/using-diffusers/using_safetensors.mdx index b522f3236fbb..2015f2faf85a 100644 --- a/docs/source/en/using-diffusers/using_safetensors.mdx +++ b/docs/source/en/using-diffusers/using_safetensors.mdx @@ -1,87 +1,67 @@ -# What is safetensors ? +# Load safetensors -[safetensors](https://github.com/huggingface/safetensors) is a different format -from the classic `.bin` which uses Pytorch which uses pickle. It contains the -exact same data, which is just the model weights (or tensors). +[safetensors](https://github.com/huggingface/safetensors) is a safe and fast file format for storing and loading tensors. Typically, PyTorch model weights are saved or *pickled* into a `.bin` file with Python's [`pickle`](https://docs.python.org/3/library/pickle.html) utility. However, `pickle` is not secure and pickled files may contain malicious code that can be executed. safetensors is a secure alternative to `pickle`, making it ideal for sharing model weights. -Pickle is notoriously unsafe which allow any malicious file to execute arbitrary code. -The hub itself tries to prevent issues from it, but it's not a silver bullet. +This guide will show you how you load `.safetensor` files, and how to convert Stable Diffusion model weights stored in other formats to `.safetensor`. Before you start, make sure you have safetensors installed: -`safetensors` first and foremost goal is to make loading machine learning models *safe* -in the sense that no takeover of your computer can be done. - -Hence the name. - -# Why use safetensors ? - -**Safety** can be one reason, if you're attempting to use a not well known model and -you're not sure about the source of the file. - -And a secondary reason, is **the speed of loading**. Safetensors can load models much faster -than regular pickle files. If you spend a lot of times switching models, this can be -a huge timesave. - -Numbers taken AMD EPYC 7742 64-Core Processor +```bash +!pip install safetensors ``` -from diffusers import StableDiffusionPipeline -pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1") +If you look at the [`runwayml/stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main) repository, you'll see weights inside the `text_encoder`, `unet` and `vae` subfolders are stored in the `.safetensors` format. By default, 🤗 Diffusers automatically loads these `.safetensors` files from their subfolders if they're available in the model repository. -# Loaded in safetensors 0:00:02.033658 -# Loaded in Pytorch 0:00:02.663379 -``` +For more explicit control, you can optionally set `use_safetensors=True` (if `safetensors` is not installed, you'll get an error message asking you to install it): -This is for the entire loading time, the actual weights loading time to load 500MB: +```py +from diffusers import DiffusionPipeline +pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_safetensors=True) ``` -Safetensors: 3.4873ms -PyTorch: 172.7537ms -``` - -Performance in general is a tricky business, and there are a few things to understand: -- If you're using the model for the first time from the hub, you will have to download the weights. - That's extremely likely to be much slower than any loading method, therefore you will not see any difference -- If you're loading the model for the first time (let's say after a reboot) then your machine will have to - actually read the disk. It's likely to be as slow in both cases. Again the speed difference may not be as visible (this depends on hardware and the actual model). -- The best performance benefit is when the model was already loaded previously on your computer and you're switching from one model to another. Your OS, is trying really hard not to read from disk, since this is slow, so it will keep the files around in RAM, making it loading again much faster. Since safetensors is doing zero-copy of the tensors, reloading will be faster than pytorch since it has at least once extra copy to do. +However, model weights are not necessarily stored in separate subfolders like in the example above. Sometimes, all the weights are stored in a single `.safetensors` file. In this case, if the weights are Stable Diffusion weights, you can load the file directly with the [`~diffusers.loaders.FromCkptMixin.from_ckpt`] method: -# How to use safetensors ? +```py +from diffusers import StableDiffusionPipeline -If you have `safetensors` installed, and all the weights are available in `safetensors` format, \ -then by default it will use that instead of the pytorch weights. +pipeline = StableDiffusionPipeline.from_ckpt( + "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors" +) +``` -If you are really paranoid about this, the ultimate weapon would be disabling `torch.load`: -```python -import torch +## Convert to safetensors +Not all weights on the Hub are available in the `.safetensors` format, and you may encounter weights stored as `.bin`. In this case, use the [Convert Space](https://huggingface.co/spaces/diffusers/convert) to convert the weights to `.safetensors`. The Convert Space downloads the pickled weights, converts them, and opens a Pull Request to upload the newly converted `.safetensors` file on the Hub. This way, if there is any malicious code contained in the pickled files, they're uploaded to the Hub - which has a [security scanner](https://huggingface.co/docs/hub/security-pickle#hubs-security-scanner) to detect unsafe files and suspicious pickle imports - instead of your computer. -def _raise(): - raise RuntimeError("I don't want to use pickle") +You can use the model with the new `.safetensors` weights by specifying the reference to the Pull Request in the `revision` parameter (you can also test it in this [Check PR](https://huggingface.co/spaces/diffusers/check_pr) Space on the Hub), for example `refs/pr/22`: +```py +from diffusers import DiffusionPipeline -torch.load = lambda *args, **kwargs: _raise() +pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", revision="refs/pr/22") ``` -# I want to use model X but it doesn't have safetensors weights. +## Why use safetensors? -Just go to this [space](https://huggingface.co/spaces/diffusers/convert). -This will create a new PR with the weights, let's say `refs/pr/22`. +There are several reasons for using safetensors: -This space will download the pickled version, convert it, and upload it on the hub as a PR. -If anything bad is contained in the file, it's Huggingface hub that will get issues, not your own computer. -And we're equipped with dealing with it. +- Safety is the number one reason for using safetensors. As open-source and model distribution grows, it is important to be able to trust the model weights you downloaded don't contain any malicious code. The current size of the header in safetensors prevents parsing extremely large JSON files. +- Loading speed between switching models is another reason to use safetensors, which performs zero-copy of the tensors. It is especially fast compared to `pickle` if you're loading the weights to CPU (the default case), and just as fast if not faster when directly loading the weights to GPU. You'll only notice the performance difference if the model is already loaded, and not if you're downloading the weights or loading the model for the first time. -Then in order to use the model, even before the branch gets accepted by the original author you can do: + The time it takes to load the entire pipeline: -```python -from diffusers import DiffusionPipeline + ```py + from diffusers import StableDiffusionPipeline -pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", revision="refs/pr/22") -``` + pipeline = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1") + "Loaded in safetensors 0:00:02.033658" + "Loaded in PyTorch 0:00:02.663379" + ``` -or you can test it directly online with this [space](https://huggingface.co/spaces/diffusers/check_pr). + But the actual time it takes to load 500MB of the model weights is only: -And that's it ! + ```bash + safetensors: 3.4873ms + PyTorch: 172.7537ms + ``` -Anything unclear, concerns, or found a bugs ? [Open an issue](https://github.com/huggingface/diffusers/issues/new/choose) +- Lazy loading is also supported in safetensors, which is useful in distributed settings to only load some of the tensors. This format allowed the [BLOOM](https://huggingface.co/bigscience/bloom) model to be loaded in 45 seconds on 8 GPUs instead of 10 minutes with regular PyTorch weights. diff --git a/docs/source/en/using-diffusers/weighted_prompts.mdx b/docs/source/en/using-diffusers/weighted_prompts.mdx index c1316dc9f47d..58e670fbafe9 100644 --- a/docs/source/en/using-diffusers/weighted_prompts.mdx +++ b/docs/source/en/using-diffusers/weighted_prompts.mdx @@ -94,5 +94,15 @@ a try! If your favorite pipeline does not have a `prompt_embeds` input, please make sure to open an issue, the diffusers team tries to be as responsive as possible. +Compel 1.1.6 adds a utility class to simplify using textual inversions. Instantiate a `DiffusersTextualInversionManager` and pass it to Compel init: + +``` +textual_inversion_manager = DiffusersTextualInversionManager(pipe) +compel = Compel( + tokenizer=pipe.tokenizer, + text_encoder=pipe.text_encoder, + textual_inversion_manager=textual_inversion_manager) +``` + Also, please check out the documentation of the [compel](https://github.com/damian0815/compel) library for more information. diff --git a/docs/source/en/using-diffusers/write_own_pipeline.mdx b/docs/source/en/using-diffusers/write_own_pipeline.mdx index 3c993ed53a2a..be92980118b1 100644 --- a/docs/source/en/using-diffusers/write_own_pipeline.mdx +++ b/docs/source/en/using-diffusers/write_own_pipeline.mdx @@ -36,7 +36,7 @@ A pipeline is a quick and easy way to run a model for inference, requiring no mo That was super easy, but how did the pipeline do that? Let's breakdown the pipeline and take a look at what's happening under the hood. -In the example above, the pipeline contains a UNet model and a DDPM scheduler. The pipeline denoises an image by taking random noise the size of the desired output and passing it through the model several times. At each timestep, the model predicts the *noise residual* and the scheduler uses it to predict a less noisy image. The pipeline repeats this process until it reaches the end of the specified number of inference steps. +In the example above, the pipeline contains a [`UNet2DModel`] model and a [`DDPMScheduler`]. The pipeline denoises an image by taking random noise the size of the desired output and passing it through the model several times. At each timestep, the model predicts the *noise residual* and the scheduler uses it to predict a less noisy image. The pipeline repeats this process until it reaches the end of the specified number of inference steps. To recreate the pipeline with the model and scheduler separately, let's write our own denoising process. @@ -82,8 +82,8 @@ To recreate the pipeline with the model and scheduler separately, let's write ou >>> for t in scheduler.timesteps: ... with torch.no_grad(): ... noisy_residual = model(input, t).sample - >>> previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample - >>> input = previous_noisy_sample + ... previous_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample + ... input = previous_noisy_sample ``` This is the entire denoising process, and you can use this same pattern to write any diffusion system. @@ -96,7 +96,7 @@ To recreate the pipeline with the model and scheduler separately, let's write ou >>> image = (input / 2 + 0.5).clamp(0, 1) >>> image = image.cpu().permute(0, 2, 3, 1).numpy()[0] - >>> image = Image.fromarray((image * 255)).round().astype("uint8") + >>> image = Image.fromarray((image * 255).round().astype("uint8")) >>> image ``` @@ -287,4 +287,4 @@ This is really what 🧨 Diffusers is designed for: to make it intuitive and eas For your next steps, feel free to: * Learn how to [build and contribute a pipeline](using-diffusers/#contribute_pipeline) to 🧨 Diffusers. We can't wait and see what you'll come up with! -* Explore [existing pipelines](./api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately. \ No newline at end of file +* Explore [existing pipelines](./api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately. diff --git a/docs/source/zh/_toctree.yml b/docs/source/zh/_toctree.yml index 2d67d9c4a025..58f6ac09faef 100644 --- a/docs/source/zh/_toctree.yml +++ b/docs/source/zh/_toctree.yml @@ -4,51 +4,79 @@ - local: quicktour title: 快速入门 - local: stable_diffusion - title: Stable Diffusion + title: Effective and efficient diffusion - local: installation title: 安装 title: 开始 - sections: + - local: tutorials/tutorial_overview + title: Overview + - local: using-diffusers/write_own_pipeline + title: Understanding models and schedulers - local: tutorials/basic_training title: Train a diffusion model title: Tutorials - sections: - sections: + - local: using-diffusers/loading_overview + title: Overview - local: using-diffusers/loading - title: Loading Pipelines, Models, and Schedulers + title: Load pipelines, models, and schedulers - local: using-diffusers/schedulers - title: Using different Schedulers - - local: using-diffusers/configuration - title: Configuring Pipelines, Models, and Schedulers + title: Load and compare different schedulers - local: using-diffusers/custom_pipeline_overview - title: Loading and Adding Custom Pipelines + title: Load community pipelines - local: using-diffusers/kerascv - title: Using KerasCV Stable Diffusion Checkpoints in Diffusers + title: Load KerasCV Stable Diffusion checkpoints title: Loading & Hub - sections: + - local: using-diffusers/pipeline_overview + title: Overview - local: using-diffusers/unconditional_image_generation - title: Unconditional Image Generation + title: Unconditional image generation - local: using-diffusers/conditional_image_generation - title: Text-to-Image Generation + title: Text-to-image generation - local: using-diffusers/img2img - title: Text-Guided Image-to-Image + title: Text-guided image-to-image - local: using-diffusers/inpaint - title: Text-Guided Image-Inpainting + title: Text-guided image-inpainting - local: using-diffusers/depth2img - title: Text-Guided Depth-to-Image - - local: using-diffusers/controlling_generation - title: Controlling generation + title: Text-guided depth-to-image - local: using-diffusers/reusing_seeds - title: Reusing seeds for deterministic generation + title: Improve image quality with deterministic generation - local: using-diffusers/reproducibility - title: Reproducibility + title: Create reproducible pipelines - local: using-diffusers/custom_pipeline_examples - title: Community Pipelines + title: Community pipelines - local: using-diffusers/contribute_pipeline - title: How to contribute a Pipeline + title: How to contribute a community pipeline - local: using-diffusers/using_safetensors title: Using safetensors + - local: using-diffusers/stable_diffusion_jax_how_to + title: Stable Diffusion in JAX/Flax + - local: using-diffusers/weighted_prompts + title: Weighting Prompts title: Pipelines for Inference + - sections: + - local: training/overview + title: Overview + - local: training/unconditional_training + title: Unconditional image generation + - local: training/text_inversion + title: Textual Inversion + - local: training/dreambooth + title: DreamBooth + - local: training/text2image + title: Text-to-image + - local: training/lora + title: Low-Rank Adaptation of Large Language Models (LoRA) + - local: training/controlnet + title: ControlNet + - local: training/instructpix2pix + title: InstructPix2Pix Training + - local: training/custom_diffusion + title: Custom Diffusion + title: Training - sections: - local: using-diffusers/rl title: Reinforcement Learning @@ -59,6 +87,8 @@ title: Taking Diffusers Beyond Images title: Using Diffusers - sections: + - local: optimization/opt_overview + title: Overview - local: optimization/fp16 title: Memory and Speed - local: optimization/torch2.0 @@ -69,32 +99,26 @@ title: ONNX - local: optimization/open_vino title: OpenVINO + - local: optimization/coreml + title: Core ML - local: optimization/mps title: MPS - local: optimization/habana title: Habana Gaudi + - local: optimization/tome + title: Token Merging title: Optimization/Special Hardware -- sections: - - local: training/overview - title: Overview - - local: training/unconditional_training - title: Unconditional Image Generation - - local: training/text_inversion - title: Textual Inversion - - local: training/dreambooth - title: DreamBooth - - local: training/text2image - title: Text-to-image - - local: training/lora - title: Low-Rank Adaptation of Large Language Models (LoRA) - title: Training - sections: - local: conceptual/philosophy title: Philosophy + - local: using-diffusers/controlling_generation + title: Controlled generation - local: conceptual/contribution title: How to contribute? - local: conceptual/ethical_guidelines title: Diffusers' Ethical Guidelines + - local: conceptual/evaluation + title: Evaluating Diffusion Models title: Conceptual Guides - sections: - sections: @@ -118,6 +142,8 @@ title: AltDiffusion - local: api/pipelines/audio_diffusion title: Audio Diffusion + - local: api/pipelines/audioldm + title: AudioLDM - local: api/pipelines/cycle_diffusion title: Cycle Diffusion - local: api/pipelines/dance_diffusion @@ -128,6 +154,8 @@ title: DDPM - local: api/pipelines/dit title: DiT + - local: api/pipelines/if + title: IF - local: api/pipelines/latent_diffusion title: Latent Diffusion - local: api/pipelines/paint_by_example @@ -142,6 +170,8 @@ title: Score SDE VE - local: api/pipelines/semantic_stable_diffusion title: Semantic Guidance + - local: api/pipelines/spectrogram_diffusion + title: "Spectrogram Diffusion" - sections: - local: api/pipelines/stable_diffusion/overview title: Overview @@ -171,6 +201,8 @@ title: MultiDiffusion Panorama - local: api/pipelines/stable_diffusion/controlnet title: Text-to-Image Generation with ControlNet Conditioning + - local: api/pipelines/stable_diffusion/model_editing + title: Text-to-Image Model Editing title: Stable Diffusion - local: api/pipelines/stable_diffusion_2 title: Stable Diffusion 2 @@ -178,6 +210,10 @@ title: Stable unCLIP - local: api/pipelines/stochastic_karras_ve title: Stochastic Karras VE + - local: api/pipelines/text_to_video + title: Text-to-Video + - local: api/pipelines/text_to_video_zero + title: Text-to-Video Zero - local: api/pipelines/unclip title: UnCLIP - local: api/pipelines/latent_diffusion_uncond @@ -235,4 +271,4 @@ - local: api/experimental/rl title: RL Planning title: Experimental Features - title: API + title: API \ No newline at end of file diff --git a/docs/source/zh/index.mdx b/docs/source/zh/index.mdx index 4f952c5db79c..e1a2a3971d87 100644 --- a/docs/source/zh/index.mdx +++ b/docs/source/zh/index.mdx @@ -18,61 +18,84 @@ specific language governing permissions and limitations under the License. # 🧨 Diffusers -🤗Diffusers提供了预训练好的视觉和音频扩散模型,并可以作为推理和训练的模块化工具箱。 +🤗 Diffusers 是一个值得首选用于生成图像、音频甚至 3D 分子结构的,最先进的预训练扩散模型库。 +无论您是在寻找简单的推理解决方案,还是想训练自己的扩散模型,🤗 Diffusers 这一模块化工具箱都能对其提供支持。 +本库的设计更偏重于[可用而非高性能](conceptual/philosophy#usability-over-performance)、[简明而非简单](conceptual/philosophy#simple-over-easy)以及[易用而非抽象](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction)。 -更准确地说,🤗Diffusers提供了: -- 最先进的扩散管道,可以在推理中仅用几行代码运行(详情看[**Using Diffusers**](./using-diffusers/conditional_image_generation))或看[**管道**](#pipelines) 以获取所有支持的管道及其对应的论文的概述。 -- 可以在推理中交替使用的各种噪声调度程序,以便在推理过程中权衡如何选择速度和质量。有关更多信息,可以看[**Schedulers**](./api/schedulers/overview)。 -- 多种类型的模型,如U-Net,可用作端到端扩散系统中的构建模块。有关更多详细信息,可以看 [**Models**](./api/models) 。 -- 训练示例,展示如何训练最流行的扩散模型任务。更多相关信息,可以看[**Training**](./training/overview)。 +本库包含三个主要组件: +- 最先进的扩散管道 [diffusion pipelines](api/pipelines/overview),只需几行代码即可进行推理。 +- 可交替使用的各种噪声调度器 [noise schedulers](api/schedulers/overview),用于平衡生成速度和质量。 +- 预训练模型 [models](api/models),可作为构建模块,并与调度程序结合使用,来创建您自己的端到端扩散系统。 -## 🧨 Diffusers pipelines - -下表总结了所有官方支持的pipelines及其对应的论文,部分提供了colab,可以直接尝试一下。 + +## 🧨 Diffusers pipelines -| 管道 | 论文 | 任务 | Colab -|---|---|:---:|:---:| -| [alt_diffusion](./api/pipelines/alt_diffusion) | [**AltDiffusion**](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | -| [audio_diffusion](./api/pipelines/audio_diffusion) | [**Audio Diffusion**](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/audio_diffusion_pipeline.ipynb) -| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [**ControlNet with Stable Diffusion**](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/controlnet.ipynb) -| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [**Cycle Diffusion**](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | -| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | -| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | -| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | -| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | -| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | -| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | -| [paint_by_example](./api/pipelines/paint_by_example) | [**Paint by Example: Exemplar-based Image Editing with Diffusion Models**](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting | -| [pndm](./api/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | -| [score_sde_ve](./api/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [score_sde_vp](./api/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | -| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [**Semantic Guidance**](https://arxiv.org/abs/2301.12247) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/semantic-image-editing/blob/main/examples/SemanticGuidance.ipynb) -| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) -| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) -| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb) -| [stable_diffusion_panorama](./api/pipelines/stable_diffusion/panorama) | [**MultiDiffusion**](https://multidiffusion.github.io/) | Text-to-Panorama Generation | -| [stable_diffusion_pix2pix](./api/pipelines/stable_diffusion/pix2pix) | [**InstructPix2Pix**](https://github.com/timothybrooks/instruct-pix2pix) | Text-Guided Image Editing| -| [stable_diffusion_pix2pix_zero](./api/pipelines/stable_diffusion/pix2pix_zero) | [**Zero-shot Image-to-Image Translation**](https://pix2pixzero.github.io/) | Text-Guided Image Editing | -| [stable_diffusion_attend_and_excite](./api/pipelines/stable_diffusion/attend_and_excite) | [**Attend and Excite for Stable Diffusion**](https://attendandexcite.github.io/Attend-and-Excite/) | Text-to-Image Generation | -| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [**Self-Attention Guidance**](https://ku-cvlab.github.io/Self-Attention-Guidance) | Text-to-Image Generation | -| [stable_diffusion_image_variation](./stable_diffusion/image_variation) | [**Stable Diffusion Image Variations**](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) | Image-to-Image Generation | -| [stable_diffusion_latent_upscale](./stable_diffusion/latent_upscale) | [**Stable Diffusion Latent Upscaler**](https://twitter.com/StabilityAI/status/1590531958815064065) | Text-Guided Super Resolution Image-to-Image | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Depth-Conditional Stable Diffusion**](https://github.com/Stability-AI/stablediffusion#depth-conditional-stable-diffusion) | Depth-to-Image Generation | -| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [**Stable Diffusion 2**](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | -| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [**Safe Stable Diffusion**](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ml-research/safe-latent-diffusion/blob/main/examples/Safe%20Latent%20Diffusion.ipynb) -| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Text-to-Image Generation | -| [stable_unclip](./stable_unclip) | **Stable unCLIP** | Image-to-Image Text-Guided Generation | -| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | -| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125) | Text-to-Image Generation | -| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | -| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | -| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | -| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | - +下表汇总了当前所有官方支持的pipelines及其对应的论文. -**注意**: 管道是如何使用相应论文中提出的扩散模型的简单示例。 \ No newline at end of file +| 管道 | 论文/仓库 | 任务 | +|---|---|:---:| +| [alt_diffusion](./api/pipelines/alt_diffusion) | [AltCLIP: Altering the Language Encoder in CLIP for Extended Language Capabilities](https://arxiv.org/abs/2211.06679) | Image-to-Image Text-Guided Generation | +| [audio_diffusion](./api/pipelines/audio_diffusion) | [Audio Diffusion](https://github.com/teticio/audio-diffusion.git) | Unconditional Audio Generation | +| [controlnet](./api/pipelines/stable_diffusion/controlnet) | [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543) | Image-to-Image Text-Guided Generation | +| [cycle_diffusion](./api/pipelines/cycle_diffusion) | [Unifying Diffusion Models' Latent Space, with Applications to CycleDiffusion and Guidance](https://arxiv.org/abs/2210.05559) | Image-to-Image Text-Guided Generation | +| [dance_diffusion](./api/pipelines/dance_diffusion) | [Dance Diffusion](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation | +| [ddpm](./api/pipelines/ddpm) | [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation | +| [ddim](./api/pipelines/ddim) | [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | +| [if](./if) | [**IF**](./api/pipelines/if) | Image Generation | +| [if_img2img](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation | +| [if_inpainting](./if) | [**IF**](./api/pipelines/if) | Image-to-Image Generation | +| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation | +| [latent_diffusion](./api/pipelines/latent_diffusion) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752)| Super Resolution Image-to-Image | +| [latent_diffusion_uncond](./api/pipelines/latent_diffusion_uncond) | [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation | +| [paint_by_example](./api/pipelines/paint_by_example) | [Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://arxiv.org/abs/2211.13227) | Image-Guided Image Inpainting | +| [pndm](./api/pipelines/pndm) | [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation | +| [score_sde_ve](./api/pipelines/score_sde_ve) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [score_sde_vp](./api/pipelines/score_sde_vp) | [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation | +| [semantic_stable_diffusion](./api/pipelines/semantic_stable_diffusion) | [Semantic Guidance](https://arxiv.org/abs/2301.12247) | Text-Guided Generation | +| [stable_diffusion_text2img](./api/pipelines/stable_diffusion/text2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | +| [stable_diffusion_img2img](./api/pipelines/stable_diffusion/img2img) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | +| [stable_diffusion_inpaint](./api/pipelines/stable_diffusion/inpaint) | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | +| [stable_diffusion_panorama](./api/pipelines/stable_diffusion/panorama) | [MultiDiffusion](https://multidiffusion.github.io/) | Text-to-Panorama Generation | +| [stable_diffusion_pix2pix](./api/pipelines/stable_diffusion/pix2pix) | [InstructPix2Pix: Learning to Follow Image Editing Instructions](https://arxiv.org/abs/2211.09800) | Text-Guided Image Editing| +| [stable_diffusion_pix2pix_zero](./api/pipelines/stable_diffusion/pix2pix_zero) | [Zero-shot Image-to-Image Translation](https://pix2pixzero.github.io/) | Text-Guided Image Editing | +| [stable_diffusion_attend_and_excite](./api/pipelines/stable_diffusion/attend_and_excite) | [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://arxiv.org/abs/2301.13826) | Text-to-Image Generation | +| [stable_diffusion_self_attention_guidance](./api/pipelines/stable_diffusion/self_attention_guidance) | [Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939) | Text-to-Image Generation Unconditional Image Generation | +| [stable_diffusion_image_variation](./stable_diffusion/image_variation) | [Stable Diffusion Image Variations](https://github.com/LambdaLabsML/lambda-diffusers#stable-diffusion-image-variations) | Image-to-Image Generation | +| [stable_diffusion_latent_upscale](./stable_diffusion/latent_upscale) | [Stable Diffusion Latent Upscaler](https://twitter.com/StabilityAI/status/1590531958815064065) | Text-Guided Super Resolution Image-to-Image | +| [stable_diffusion_model_editing](./api/pipelines/stable_diffusion/model_editing) | [Editing Implicit Assumptions in Text-to-Image Diffusion Models](https://time-diffusion.github.io/) | Text-to-Image Model Editing | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-to-Image Generation | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Image Inpainting | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Depth-Conditional Stable Diffusion](https://github.com/Stability-AI/stablediffusion#depth-conditional-stable-diffusion) | Depth-to-Image Generation | +| [stable_diffusion_2](./api/pipelines/stable_diffusion_2) | [Stable Diffusion 2](https://stability.ai/blog/stable-diffusion-v2-release) | Text-Guided Super Resolution Image-to-Image | +| [stable_diffusion_safe](./api/pipelines/stable_diffusion_safe) | [Safe Stable Diffusion](https://arxiv.org/abs/2211.05105) | Text-Guided Generation | +| [stable_unclip](./stable_unclip) | Stable unCLIP | Text-to-Image Generation | +| [stable_unclip](./stable_unclip) | Stable unCLIP | Image-to-Image Text-Guided Generation | +| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation | +| [text_to_video_sd](./api/pipelines/text_to_video) | [Modelscope's Text-to-video-synthesis Model in Open Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary) | Text-to-Video Generation | +| [unclip](./api/pipelines/unclip) | [Hierarchical Text-Conditional Image Generation with CLIP Latents](https://arxiv.org/abs/2204.06125)(implementation by [kakaobrain](https://github.com/kakaobrain/karlo)) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Text-to-Image Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Image Variations Generation | +| [versatile_diffusion](./api/pipelines/versatile_diffusion) | [Versatile Diffusion: Text, Images and Variations All in One Diffusion Model](https://arxiv.org/abs/2211.08332) | Dual Image and Text Guided Generation | +| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation | diff --git a/docs/source/zh/installation.mdx b/docs/source/zh/installation.mdx index cda91df8a6cd..8cd3ad97cc21 100644 --- a/docs/source/zh/installation.mdx +++ b/docs/source/zh/installation.mdx @@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License. # 安装 -安装🤗 Diffusers 到你正在使用的任何深度学习框架中。 +在你正在使用的任意深度学习框架中安装 🤗 Diffusers 。 🤗 Diffusers已在Python 3.7+、PyTorch 1.7.0+和Flax上进行了测试。按照下面的安装说明,针对你正在使用的深度学习框架进行安装: @@ -21,11 +21,11 @@ specific language governing permissions and limitations under the License. ## 使用pip安装 -你需要在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装🤗 Diffusers 。 +你需要在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Diffusers 。 如果你对 Python 虚拟环境不熟悉,可以看看这个[教程](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). -使用虚拟环境你可以轻松管理不同的项目,避免了依赖项之间的兼容性问题。 +在虚拟环境中,你可以轻松管理不同的项目,避免依赖项之间的兼容性问题。 首先,在你的项目目录下创建一个虚拟环境: @@ -39,7 +39,7 @@ python -m venv .env source .env/bin/activate ``` -现在你就可以安装 🤗 Diffusers了!使用下边这个命令: +现在,你就可以安装 🤗 Diffusers了!使用下边这个命令: **PyTorch** @@ -55,7 +55,7 @@ pip install diffusers["flax"] ## 从源代码安装 -在从源代码安装 `diffusers` 之前,你先确定你已经安装了 `torch` 和 `accelerate`。 +在从源代码安装 `diffusers` 之前,确保你已经安装了 `torch` 和 `accelerate`。 `torch`的安装教程可以看 `torch` [文档](https://pytorch.org/get-started/locally/#start-locally). @@ -65,17 +65,17 @@ pip install diffusers["flax"] pip install accelerate ``` -从源码安装 🤗 Diffusers 使用以下命令: +从源码安装 🤗 Diffusers 需要使用以下命令: ```bash pip install git+https://github.com/huggingface/diffusers ``` 这个命令安装的是最新的 `main`版本,而不是最近的`stable`版。 -`main`是一直和最新进展保持一致的。比如,上次正式版发布了,有bug,新的正式版还没推出,但是`main`中可以看到这个bug被修复了。 -但是这也意味着 `main`版本并不总是稳定的。 +`main`是一直和最新进展保持一致的。比如,上次发布的正式版中有bug,在`main`中可以看到这个bug被修复了,但是新的正式版此时尚未推出。 +但是这也意味着 `main`版本不保证是稳定的。 -我们努力保持`main`版本正常运行,大多数问题都能在几个小时或一天之内解决 +我们努力保持`main`版本正常运行,大多数问题都能在几个小时或一天之内解决 如果你遇到了问题,可以提 [Issue](https://github.com/huggingface/transformers/issues),这样我们就能更快修复问题了。 @@ -105,8 +105,8 @@ pip install -e ".[torch]" pip install -e ".[flax]" ``` -这些命令将连接你克隆的版本库和你的 Python 库路径。 -现在,除了正常的库路径外,Python 还会在你克隆的文件夹内寻找。 +这些命令将连接到你克隆的版本库和你的 Python 库路径。 +现在,不只是在通常的库路径,Python 还会在你克隆的文件夹内寻找包。 例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.7/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。 @@ -116,32 +116,31 @@ pip install -e ".[flax]" -现在你可以用下面的命令轻松地将你克隆的🤗Diffusers仓库更新到最新版本。 +现在你可以用下面的命令轻松地将你克隆的 🤗 Diffusers 库更新到最新版本。 ```bash cd ~/diffusers/ git pull ``` -你的Python环境将在下次运行时找到`main`版本的🤗 Diffusers。 +你的Python环境将在下次运行时找到`main`版本的 🤗 Diffusers。 -## 注意遥测日志 +## 注意 Telemetry 日志 -我们的库会在使用`from_pretrained()`请求期间收集信息。这些数据包括Diffusers和PyTorch/Flax的版本,请求的模型或管道,以及预训练检查点的路径(如果它被托管在Hub上)。 +我们的库会在使用`from_pretrained()`请求期间收集 telemetry 信息。这些数据包括Diffusers和PyTorch/Flax的版本,请求的模型或管道类,以及预训练检查点的路径(如果它被托管在Hub上的话)。 +这些使用数据有助于我们调试问题并确定新功能的开发优先级。 +Telemetry 数据仅在从 HuggingFace Hub 中加载模型和管道时发送,而不会在本地使用期间收集。 -这些使用数据有助于我们调试问题并优先考虑新功能。 -当从HuggingFace Hub加载模型和管道时才会发送遥测数据,并且在本地使用时不会收集数据。 +我们知道,并不是每个人都想分享这些的信息,我们尊重您的隐私, +因此您可以通过在终端中设置 `DISABLE_TELEMETRY` 环境变量从而禁用 Telemetry 数据收集: -我们知道并不是每个人都想分享这些的信息,我们尊重您的隐私, -因此您可以通过在终端中设置“DISABLE_TELEMETRY”环境变量来禁用遥测数据的收集: - -在Linux/MacOS中: +Linux/MacOS : ```bash export DISABLE_TELEMETRY=YES ``` -在Windows中: +Windows : ```bash set DISABLE_TELEMETRY=YES ``` \ No newline at end of file diff --git a/examples/community/README.md b/examples/community/README.md old mode 100644 new mode 100755 index 8b5b1743203d..065b46f5410c --- a/examples/community/README.md +++ b/examples/community/README.md @@ -6,33 +6,38 @@ Please have a look at the following table to get an overview of all community examples. Click on the **Code Example** to get a copy-and-paste ready code example that you can try out. If a community doesn't work as expected, please open an issue and ping the author on it. -| Example | Description | Code Example | Colab | Author | -|:---------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------:| -| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) | -| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | -| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) | -| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | -| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) | -| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) -| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) | -| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | -| Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) | -| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | -| Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) | -| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting| [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) | -| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting| [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) | -| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - |[Stuti R.](https://github.com/kingstut) | -| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | -| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | -Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | -MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) | -| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - |[Ray Wang](https://wrong.wang) | -| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | -| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | -| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - |[Aengus (Duc-Anh)](https://github.com/aengusng8) | -| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | -| TensorRT Stable Diffusion Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - |[Asfiya Baig](https://github.com/asfiyab-nvidia) | - +| Example | Description | Code Example | Colab | Author | +|:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:| +| CLIP Guided Stable Diffusion | Doing CLIP guidance for text to image generation with Stable Diffusion | [CLIP Guided Stable Diffusion](#clip-guided-stable-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/CLIP_Guided_Stable_diffusion_with_diffusers.ipynb) | [Suraj Patil](https://github.com/patil-suraj/) | +| One Step U-Net (Dummy) | Example showcasing of how to use Community Pipelines (see https://github.com/huggingface/diffusers/issues/841) | [One Step U-Net](#one-step-unet) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Stable Diffusion Interpolation | Interpolate the latent space of Stable Diffusion between different prompts/seeds | [Stable Diffusion Interpolation](#stable-diffusion-interpolation) | - | [Nate Raw](https://github.com/nateraw/) | +| Stable Diffusion Mega | **One** Stable Diffusion Pipeline with all functionalities of [Text2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py), [Image2Image](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py) and [Inpainting](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py) | [Stable Diffusion Mega](#stable-diffusion-mega) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) | +| Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) +| Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) | +| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | +| Seed Resizing Stable Diffusion | Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. | [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) | +| Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image | [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | +| Multilingual Stable Diffusion | Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piñeros](https://github.com/juancopi81) | +| Image to Image Inpainting Stable Diffusion | Stable Diffusion Pipeline that enables the overlaying of two images and subsequent inpainting | [Image to Image Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Alex McKinney](https://github.com/vvvm23) | +| Text Based Inpainting Stable Diffusion | Stable Diffusion Inpainting Pipeline that enables passing a text prompt to generate the mask for inpainting | [Text Based Inpainting Stable Diffusion](#image-to-image-inpainting-stable-diffusion) | - | [Dhruv Karan](https://github.com/unography) | +| Bit Diffusion | Diffusion on discrete data | [Bit Diffusion](#bit-diffusion) | - | [Stuti R.](https://github.com/kingstut) | +| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) | +| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | + Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) | + MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) | +| Stable UnCLIP | Diffusion Pipeline for combining prior model (generate clip image embedding from text, UnCLIPPipeline `"kakaobrain/karlo-v1-alpha"`) and decoder pipeline (decode clip image embedding to image, StableDiffusionImageVariationPipeline `"lambdalabs/sd-image-variations-diffusers"` ). | [Stable UnCLIP](#stable-unclip) | - | [Ray Wang](https://wrong.wang) | +| UnCLIP Text Interpolation Pipeline | Diffusion Pipeline that allows passing two prompts and produces images while interpolating between the text-embeddings of the two prompts | [UnCLIP Text Interpolation Pipeline](#unclip-text-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | +| UnCLIP Image Interpolation Pipeline | Diffusion Pipeline that allows passing two images/image_embeddings and produces images while interpolating between their image-embeddings | [UnCLIP Image Interpolation Pipeline](#unclip-image-interpolation-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) | +| DDIM Noise Comparative Analysis Pipeline | Investigating how the diffusion models learn visual concepts from each noise level (which is a contribution of [P2 weighting (CVPR 2022)](https://arxiv.org/abs/2204.00227)) | [DDIM Noise Comparative Analysis Pipeline](#ddim-noise-comparative-analysis-pipeline) | - | [Aengus (Duc-Anh)](https://github.com/aengusng8) | +| CLIP Guided Img2Img Stable Diffusion Pipeline | Doing CLIP guidance for image to image generation with Stable Diffusion | [CLIP Guided Img2Img Stable Diffusion](#clip-guided-img2img-stable-diffusion) | - | [Nipun Jindal](https://github.com/nipunjindal/) | +| TensorRT Stable Diffusion Text to Image Pipeline | Accelerates the Stable Diffusion Text2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Text to Image Pipeline](#tensorrt-text2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | +| EDICT Image Editing Pipeline | Diffusion pipeline for text-guided image editing | [EDICT Image Editing Pipeline](#edict-image-editing-pipeline) | - | [Joqsan Azocar](https://github.com/Joqsan) | +| Stable Diffusion RePaint | Stable Diffusion pipeline using [RePaint](https://arxiv.org/abs/2201.0986) for inpainting. | [Stable Diffusion RePaint](#stable-diffusion-repaint ) | - | [Markus Pobitzer](https://github.com/Markus-Pobitzer) | +| TensorRT Stable Diffusion Image to Image Pipeline | Accelerates the Stable Diffusion Image2Image Pipeline using TensorRT | [TensorRT Stable Diffusion Image to Image Pipeline](#tensorrt-image2image-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | +| Stable Diffusion IPEX Pipeline | Accelerate Stable Diffusion inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion on IPEX](#stable-diffusion-on-ipex) | - | [Yingjie Han](https://github.com/yingjie-han/) | +| CLIP Guided Images Mixing Stable Diffusion Pipeline | Сombine images using usual diffusion models. | [CLIP Guided Images Mixing Using Stable Diffusion](#clip-guided-images-mixing-with-stable-diffusion) | - | [Karachev Denis](https://github.com/TheDenk) | +| TensorRT Stable Diffusion Inpainting Pipeline | Accelerates the Stable Diffusion Inpainting Pipeline using TensorRT | [TensorRT Stable Diffusion Inpainting Pipeline](#tensorrt-inpainting-stable-diffusion-pipeline) | - | [Asfiya Baig](https://github.com/asfiyab-nvidia) | To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly. ```py @@ -1161,3 +1166,510 @@ prompt = "a beautiful photograph of Mt. Fuji during cherry blossom" image = pipe(prompt).images[0] image.save('tensorrt_mt_fuji.png') ``` + +### EDICT Image Editing Pipeline + +This pipeline implements the text-guided image editing approach from the paper [EDICT: Exact Diffusion Inversion via Coupled Transformations](https://arxiv.org/abs/2211.12446). You have to pass: +- (`PIL`) `image` you want to edit. +- `base_prompt`: the text prompt describing the current image (before editing). +- `target_prompt`: the text prompt describing with the edits. + +```python +from diffusers import DiffusionPipeline, DDIMScheduler +from transformers import CLIPTextModel +import torch, PIL, requests +from io import BytesIO +from IPython.display import display + +def center_crop_and_resize(im): + + width, height = im.size + d = min(width, height) + left = (width - d) / 2 + upper = (height - d) / 2 + right = (width + d) / 2 + lower = (height + d) / 2 + + return im.crop((left, upper, right, lower)).resize((512, 512)) + +torch_dtype = torch.float16 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +# scheduler and text_encoder param values as in the paper +scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + set_alpha_to_one=False, + clip_sample=False, +) + +text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path="openai/clip-vit-large-patch14", + torch_dtype=torch_dtype, +) + +# initialize pipeline +pipeline = DiffusionPipeline.from_pretrained( + pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4", + custom_pipeline="edict_pipeline", + revision="fp16", + scheduler=scheduler, + text_encoder=text_encoder, + leapfrog_steps=True, + torch_dtype=torch_dtype, +).to(device) + +# download image +image_url = "https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg" +response = requests.get(image_url) +image = PIL.Image.open(BytesIO(response.content)) + +# preprocess it +cropped_image = center_crop_and_resize(image) + +# define the prompts +base_prompt = "A dog" +target_prompt = "A golden retriever" + +# run the pipeline +result_image = pipeline( + base_prompt=base_prompt, + target_prompt=target_prompt, + image=cropped_image, +) + +display(result_image) +``` + +Init Image + +![img2img_init_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1.jpeg) + +Output Image + +![img2img_edict_text_editing](https://huggingface.co/datasets/Joqsan/images/resolve/main/imagenet_dog_1_cropped_generated.png) + +### Stable Diffusion RePaint + +This pipeline uses the [RePaint](https://arxiv.org/abs/2201.09865) logic on the latent space of stable diffusion. It can +be used similarly to other image inpainting pipelines but does not rely on a specific inpainting model. This means you can use +models that are not specifically created for inpainting. + +Make sure to use the ```RePaintScheduler``` as shown in the example below. + +Disclaimer: The mask gets transferred into latent space, this may lead to unexpected changes on the edge of the masked part. +The inference time is a lot slower. + +```py +import PIL +import requests +import torch +from io import BytesIO +from diffusers import StableDiffusionPipeline, RePaintScheduler +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" +init_image = download_image(img_url).resize((512, 512)) +mask_image = download_image(mask_url).resize((512, 512)) +mask_image = PIL.ImageOps.invert(mask_image) +pipe = StableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, custom_pipeline="stable_diffusion_repaint", +) +pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") +prompt = "Face of a yellow cat, high resolution, sitting on a park bench" +image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] +``` + +### TensorRT Image2Image Stable Diffusion Pipeline + +The TensorRT Pipeline can be used to accelerate the Image2Image Stable Diffusion Inference run. + +NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes. + +```python +import requests +from io import BytesIO +from PIL import Image +import torch +from diffusers import DDIMScheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionImg2ImgPipeline + +# Use the DDIMScheduler scheduler here instead +scheduler = DDIMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1", + subfolder="scheduler") + + +pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", + custom_pipeline="stable_diffusion_tensorrt_img2img", + revision='fp16', + torch_dtype=torch.float16, + scheduler=scheduler,) + +# re-use cached folder to save ONNX models and TensorRT Engines +pipe.set_cached_folder("stabilityai/stable-diffusion-2-1", revision='fp16',) + +pipe = pipe.to("cuda") + +url = "https://pajoca.com/wp-content/uploads/2022/09/tekito-yamakawa-1.png" +response = requests.get(url) +input_image = Image.open(BytesIO(response.content)).convert("RGB") + +prompt = "photorealistic new zealand hills" +image = pipe(prompt, image=input_image, strength=0.75,).images[0] +image.save('tensorrt_img2img_new_zealand_hills.png') +``` + +### Stable Diffusion Reference + +This pipeline uses the Reference Control. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280). + +Based on [this issue](https://github.com/huggingface/diffusers/issues/3566), +- `EulerAncestralDiscreteScheduler` got poor results. + +```py +import torch +from diffusers import UniPCMultistepScheduler +from diffusers.utils import load_image + +input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + +pipe = StableDiffusionReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + +result_img = pipe(ref_image=input_image, + prompt="1girl", + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] +``` + +Reference Image + +![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png) + +Output Image of `reference_attn=True` and `reference_adain=False` + +![output_image](https://github.com/huggingface/diffusers/assets/24734142/813b5c6a-6d89-46ba-b7a4-2624e240eea5) + +Output Image of `reference_attn=False` and `reference_adain=True` + +![output_image](https://github.com/huggingface/diffusers/assets/24734142/ffc90339-9ef0-4c4d-a544-135c3e5644da) + +Output Image of `reference_attn=True` and `reference_adain=True` + +![output_image](https://github.com/huggingface/diffusers/assets/24734142/3c5255d6-867d-4d35-b202-8dfd30cc6827) + +### Stable Diffusion ControlNet Reference + +This pipeline uses the Reference Control with ControlNet. Refer to the [sd-webui-controlnet discussion: Reference-only Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1236)[sd-webui-controlnet discussion: Reference-adain Control](https://github.com/Mikubill/sd-webui-controlnet/discussions/1280). + +Based on [this issue](https://github.com/huggingface/diffusers/issues/3566), +- `EulerAncestralDiscreteScheduler` got poor results. +- `guess_mode=True` works well for ControlNet v1.1 + +```py +import cv2 +import torch +import numpy as np +from PIL import Image +from diffusers import UniPCMultistepScheduler +from diffusers.utils import load_image + +input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + +# get canny image +image = cv2.Canny(np.array(input_image), 100, 200) +image = image[:, :, None] +image = np.concatenate([image, image, image], axis=2) +canny_image = Image.fromarray(image) + +controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) +pipe = StableDiffusionControlNetReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + controlnet=controlnet, + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + +result_img = pipe(ref_image=input_image, + prompt="1girl", + image=canny_image, + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] +``` + +Reference Image + +![reference_image](https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png) + +Output Image + +![output_image](https://github.com/huggingface/diffusers/assets/24734142/7b9a5830-f173-4b92-b0cf-73d0e9c01d60) + + +### Stable Diffusion on IPEX + +This diffusion pipeline aims to accelarate the inference of Stable-Diffusion on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch). + +To use this pipeline, you need to: +1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch) + +**Note:** For each PyTorch release, there is a corresponding release of the IPEX. Here is the mapping relationship. It is recommended to install Pytorch/IPEX2.0 to get the best performance. + +|PyTorch Version|IPEX Version| +|--|--| +|[v2.0.\*](https://github.com/pytorch/pytorch/tree/v2.0.1 "v2.0.1")|[v2.0.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.0.100+cpu)| +|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)| + +You can simply use pip to install IPEX with the latest version. +```python +python -m pip install intel_extension_for_pytorch +``` +**Note:** To install a specific version, run with the following command: +``` +python -m pip install intel_extension_for_pytorch== -f https://developer.intel.com/ipex-whl-stable-cpu +``` + +2. After pipeline initialization, `prepare_for_ipex()` should be called to enable IPEX accelaration. Supported inference datatypes are Float32 and BFloat16. + +**Note:** The setting of generated image height/width for `prepare_for_ipex()` should be same as the setting of pipeline inference. +```python +pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex") +# For Float32 +pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference +# For BFloat16 +pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) #value of image height/width should be consistent with the pipeline inference +``` + +Then you can use the ipex pipeline in a similar way to the default stable diffusion pipeline. +```python +# For Float32 +image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()' +# For BFloat16 +with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): + image = pipe(prompt, num_inference_steps=20, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()' +``` + +The following code compares the performance of the original stable diffusion pipeline with the ipex-optimized pipeline. + +```python +import torch +import intel_extension_for_pytorch as ipex +from diffusers import StableDiffusionPipeline +import time + +prompt = "sailing ship in storm by Rembrandt" +model_id = "runwayml/stable-diffusion-v1-5" +# Helper function for time evaluation +def elapsed_time(pipeline, nb_pass=3, num_inference_steps=20): + # warmup + for _ in range(2): + images = pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images + #time evaluation + start = time.time() + for _ in range(nb_pass): + pipeline(prompt, num_inference_steps=num_inference_steps, height=512, width=512) + end = time.time() + return (end - start) / nb_pass + +############## bf16 inference performance ############### + +# 1. IPEX Pipeline initialization +pipe = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex") +pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) + +# 2. Original Pipeline initialization +pipe2 = StableDiffusionPipeline.from_pretrained(model_id) + +# 3. Compare performance between Original Pipeline and IPEX Pipeline +with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): + latency = elapsed_time(pipe) + print("Latency of StableDiffusionIPEXPipeline--bf16", latency) + latency = elapsed_time(pipe2) + print("Latency of StableDiffusionPipeline--bf16",latency) + +############## fp32 inference performance ############### + +# 1. IPEX Pipeline initialization +pipe3 = DiffusionPipeline.from_pretrained(model_id, custom_pipeline="stable_diffusion_ipex") +pipe3.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) + +# 2. Original Pipeline initialization +pipe4 = StableDiffusionPipeline.from_pretrained(model_id) + +# 3. Compare performance between Original Pipeline and IPEX Pipeline +latency = elapsed_time(pipe3) +print("Latency of StableDiffusionIPEXPipeline--fp32", latency) +latency = elapsed_time(pipe4) +print("Latency of StableDiffusionPipeline--fp32",latency) + +``` + +### CLIP Guided Images Mixing With Stable Diffusion + +![clip_guided_images_mixing_examples](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/main.png) + +CLIP guided stable diffusion images mixing pipline allows to combine two images using standard diffusion models. +This approach is using (optional) CoCa model to avoid writing image description. +[More code examples](https://github.com/TheDenk/images_mixing) + +## Example Images Mixing (with CoCa) +```python +import requests +from io import BytesIO + +import PIL +import torch +import open_clip +from open_clip import SimpleTokenizer +from diffusers import DiffusionPipeline +from transformers import CLIPFeatureExtractor, CLIPModel + + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + +# Loading additional models +feature_extractor = CLIPFeatureExtractor.from_pretrained( + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" +) +clip_model = CLIPModel.from_pretrained( + "laion/CLIP-ViT-B-32-laion2B-s34B-b79K", torch_dtype=torch.float16 +) +coca_model = open_clip.create_model('coca_ViT-L-14', pretrained='laion2B-s13B-b90k').to('cuda') +coca_model.dtype = torch.float16 +coca_transform = open_clip.image_transform( + coca_model.visual.image_size, + is_train = False, + mean = getattr(coca_model.visual, 'image_mean', None), + std = getattr(coca_model.visual, 'image_std', None), +) +coca_tokenizer = SimpleTokenizer() + +# Pipline creating +mixing_pipeline = DiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", + custom_pipeline="clip_guided_images_mixing_stable_diffusion", + clip_model=clip_model, + feature_extractor=feature_extractor, + coca_model=coca_model, + coca_tokenizer=coca_tokenizer, + coca_transform=coca_transform, + torch_dtype=torch.float16, +) +mixing_pipeline.enable_attention_slicing() +mixing_pipeline = mixing_pipeline.to("cuda") + +# Pipline running +generator = torch.Generator(device="cuda").manual_seed(17) + +def download_image(url): + response = requests.get(url) + return PIL.Image.open(BytesIO(response.content)).convert("RGB") + +content_image = download_image("https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir.jpg") +style_image = download_image("https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/gigachad.jpg") + +pipe_images = mixing_pipeline( + num_inference_steps=50, + content_image=content_image, + style_image=style_image, + noise_strength=0.65, + slerp_latent_style_strength=0.9, + slerp_prompt_style_strength=0.1, + slerp_clip_image_style_strength=0.1, + guidance_scale=9.0, + batch_size=1, + clip_guidance_scale=100, + generator=generator, +).images +``` + +![image_mixing_result](https://huggingface.co/datasets/TheDenk/images_mixing/resolve/main/boromir_gigachad.png) + +### Stable Diffusion Mixture + +This pipeline uses the Mixture. Refer to the [Mixture](https://arxiv.org/abs/2302.02412) paper for more details. + +```python +from diffusers import LMSDiscreteScheduler, DiffusionPipeline + +# Creater scheduler and model (similar to StableDiffusionPipeline) +scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) +pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_tiling") +pipeline.to("cuda") + +# Mixture of Diffusers generation +image = pipeline( + prompt=[[ + "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", + "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", + "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" + ]], + tile_height=640, + tile_width=640, + tile_row_overlap=0, + tile_col_overlap=256, + guidance_scale=8, + seed=7178915308, + num_inference_steps=50, +)["images"][0] +``` +![mixture_tiling_results](https://huggingface.co/datasets/kadirnar/diffusers_readme_images/resolve/main/mixture_tiling.png) + +### TensorRT Inpainting Stable Diffusion Pipeline + +The TensorRT Pipeline can be used to accelerate the Inpainting Stable Diffusion Inference run. + +NOTE: The ONNX conversions and TensorRT engine build may take up to 30 minutes. + +```python +import requests +from io import BytesIO +from PIL import Image +import torch +from diffusers import PNDMScheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionImg2ImgPipeline + +# Use the PNDMScheduler scheduler here instead +scheduler = PNDMScheduler.from_pretrained("stabilityai/stable-diffusion-2-inpainting", subfolder="scheduler") + + +pipe = StableDiffusionImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting", + custom_pipeline="stable_diffusion_tensorrt_inpaint", + revision='fp16', + torch_dtype=torch.float16, + scheduler=scheduler, + ) + +# re-use cached folder to save ONNX models and TensorRT Engines +pipe.set_cached_folder("stabilityai/stable-diffusion-2-inpainting", revision='fp16',) + +pipe = pipe.to("cuda") + +url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +response = requests.get(url) +input_image = Image.open(BytesIO(response.content)).convert("RGB") + +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" +response = requests.get(mask_url) +mask_image = Image.open(BytesIO(response.content)).convert("RGB") + +prompt = "a mecha robot sitting on a bench" +image = pipe(prompt, image=input_image, mask_image=mask_image, strength=0.75,).images[0] +image.save('tensorrt_inpaint_mecha_robot.png') +``` \ No newline at end of file diff --git a/examples/community/clip_guided_images_mixing_stable_diffusion.py b/examples/community/clip_guided_images_mixing_stable_diffusion.py new file mode 100644 index 000000000000..e4c52fe63f49 --- /dev/null +++ b/examples/community/clip_guided_images_mixing_stable_diffusion.py @@ -0,0 +1,456 @@ +# -*- coding: utf-8 -*- +import inspect +from typing import Optional, Union + +import numpy as np +import PIL +import torch +from torch.nn import functional as F +from torchvision import transforms +from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import ( + PIL_INTERPOLATION, + randn_tensor, +) + + +def preprocess(image, w, h): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def slerp(t, v0, v1, DOT_THRESHOLD=0.9995): + if not isinstance(v0, np.ndarray): + inputs_are_torch = True + input_device = v0.device + v0 = v0.cpu().numpy() + v1 = v1.cpu().numpy() + + dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1))) + if np.abs(dot) > DOT_THRESHOLD: + v2 = (1 - t) * v0 + t * v1 + else: + theta_0 = np.arccos(dot) + sin_theta_0 = np.sin(theta_0) + theta_t = theta_0 * t + sin_theta_t = np.sin(theta_t) + s0 = np.sin(theta_0 - theta_t) / sin_theta_0 + s1 = sin_theta_t / sin_theta_0 + v2 = s0 * v0 + s1 * v1 + + if inputs_are_torch: + v2 = torch.from_numpy(v2).to(input_device) + + return v2 + + +def spherical_dist_loss(x, y): + x = F.normalize(x, dim=-1) + y = F.normalize(y, dim=-1) + return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) + + +def set_requires_grad(model, value): + for param in model.parameters(): + param.requires_grad = value + + +class CLIPGuidedImagesMixingStableDiffusion(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + clip_model: CLIPModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler, DPMSolverMultistepScheduler], + feature_extractor: CLIPFeatureExtractor, + coca_model=None, + coca_tokenizer=None, + coca_transform=None, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + clip_model=clip_model, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + feature_extractor=feature_extractor, + coca_model=coca_model, + coca_tokenizer=coca_tokenizer, + coca_transform=coca_transform, + ) + self.feature_extractor_size = ( + feature_extractor.size + if isinstance(feature_extractor.size, int) + else feature_extractor.size["shortest_edge"] + ) + self.normalize = transforms.Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std) + set_requires_grad(self.text_encoder, False) + set_requires_grad(self.clip_model, False) + + def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = self.unet.config.attention_head_dim // 2 + self.unet.set_attention_slice(slice_size) + + def disable_attention_slicing(self): + self.enable_attention_slicing(None) + + def freeze_vae(self): + set_requires_grad(self.vae, False) + + def unfreeze_vae(self): + set_requires_grad(self.vae, True) + + def freeze_unet(self): + set_requires_grad(self.unet, False) + + def unfreeze_unet(self): + set_requires_grad(self.unet, True) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, dtype, device, generator=None): + if not isinstance(image, torch.Tensor): + raise ValueError(f"`image` has to be of type `torch.Tensor` but is {type(image)}") + + image = image.to(device=device, dtype=dtype) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor + init_latents = 0.18215 * init_latents + init_latents = init_latents.repeat_interleave(batch_size, dim=0) + + noise = randn_tensor(init_latents.shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + def get_image_description(self, image): + transformed_image = self.coca_transform(image).unsqueeze(0) + with torch.no_grad(), torch.cuda.amp.autocast(): + generated = self.coca_model.generate(transformed_image.to(device=self.device, dtype=self.coca_model.dtype)) + generated = self.coca_tokenizer.decode(generated[0].cpu().numpy()) + return generated.split("")[0].replace("", "").rstrip(" .,") + + def get_clip_image_embeddings(self, image, batch_size): + clip_image_input = self.feature_extractor.preprocess(image) + clip_image_features = torch.from_numpy(clip_image_input["pixel_values"][0]).unsqueeze(0).to(self.device).half() + image_embeddings_clip = self.clip_model.get_image_features(clip_image_features) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + image_embeddings_clip = image_embeddings_clip.repeat_interleave(batch_size, dim=0) + return image_embeddings_clip + + @torch.enable_grad() + def cond_fn( + self, + latents, + timestep, + index, + text_embeddings, + noise_pred_original, + original_image_embeddings_clip, + clip_guidance_scale, + ): + latents = latents.detach().requires_grad_() + + latent_model_input = self.scheduler.scale_model_input(latents, timestep) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings).sample + + if isinstance(self.scheduler, (PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler)): + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + beta_prod_t = 1 - alpha_prod_t + # compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_original_sample = (latents - beta_prod_t ** (0.5) * noise_pred) / alpha_prod_t ** (0.5) + + fac = torch.sqrt(beta_prod_t) + sample = pred_original_sample * (fac) + latents * (1 - fac) + elif isinstance(self.scheduler, LMSDiscreteScheduler): + sigma = self.scheduler.sigmas[index] + sample = latents - sigma * noise_pred + else: + raise ValueError(f"scheduler type {type(self.scheduler)} not supported") + + # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor + sample = 1 / 0.18215 * sample + image = self.vae.decode(sample).sample + image = (image / 2 + 0.5).clamp(0, 1) + + image = transforms.Resize(self.feature_extractor_size)(image) + image = self.normalize(image).to(latents.dtype) + + image_embeddings_clip = self.clip_model.get_image_features(image) + image_embeddings_clip = image_embeddings_clip / image_embeddings_clip.norm(p=2, dim=-1, keepdim=True) + + loss = spherical_dist_loss(image_embeddings_clip, original_image_embeddings_clip).mean() * clip_guidance_scale + + grads = -torch.autograd.grad(loss, latents)[0] + + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents.detach() + grads * (sigma**2) + noise_pred = noise_pred_original + else: + noise_pred = noise_pred_original - torch.sqrt(beta_prod_t) * grads + return noise_pred, latents + + @torch.no_grad() + def __call__( + self, + style_image: Union[torch.FloatTensor, PIL.Image.Image], + content_image: Union[torch.FloatTensor, PIL.Image.Image], + style_prompt: Optional[str] = None, + content_prompt: Optional[str] = None, + height: Optional[int] = 512, + width: Optional[int] = 512, + noise_strength: float = 0.6, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + batch_size: Optional[int] = 1, + eta: float = 0.0, + clip_guidance_scale: Optional[float] = 100, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + slerp_latent_style_strength: float = 0.8, + slerp_prompt_style_strength: float = 0.1, + slerp_clip_image_style_strength: float = 0.1, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError(f"You have passed {batch_size} batch_size, but only {len(generator)} generators.") + + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if isinstance(generator, torch.Generator) and batch_size > 1: + generator = [generator] + [None] * (batch_size - 1) + + coca_is_none = [ + ("model", self.coca_model is None), + ("tokenizer", self.coca_tokenizer is None), + ("transform", self.coca_transform is None), + ] + coca_is_none = [x[0] for x in coca_is_none if x[1]] + coca_is_none_str = ", ".join(coca_is_none) + # generate prompts with coca model if prompt is None + if content_prompt is None: + if len(coca_is_none): + raise ValueError( + f"Content prompt is None and CoCa [{coca_is_none_str}] is None." + f"Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline." + ) + content_prompt = self.get_image_description(content_image) + if style_prompt is None: + if len(coca_is_none): + raise ValueError( + f"Style prompt is None and CoCa [{coca_is_none_str}] is None." + f" Set prompt or pass Coca [{coca_is_none_str}] to DiffusionPipeline." + ) + style_prompt = self.get_image_description(style_image) + + # get prompt text embeddings for content and style + content_text_input = self.tokenizer( + content_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + content_text_embeddings = self.text_encoder(content_text_input.input_ids.to(self.device))[0] + + style_text_input = self.tokenizer( + style_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + style_text_embeddings = self.text_encoder(style_text_input.input_ids.to(self.device))[0] + + text_embeddings = slerp(slerp_prompt_style_strength, content_text_embeddings, style_text_embeddings) + + # duplicate text embeddings for each generation per prompt + text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0) + + # set timesteps + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + # Some schedulers like PNDM have timesteps as arrays + # It's more optimized to move all timesteps to correct device beforehand + self.scheduler.timesteps.to(self.device) + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, noise_strength, self.device) + latent_timestep = timesteps[:1].repeat(batch_size) + + # Preprocess image + preprocessed_content_image = preprocess(content_image, width, height) + content_latents = self.prepare_latents( + preprocessed_content_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator + ) + + preprocessed_style_image = preprocess(style_image, width, height) + style_latents = self.prepare_latents( + preprocessed_style_image, latent_timestep, batch_size, text_embeddings.dtype, self.device, generator + ) + + latents = slerp(slerp_latent_style_strength, content_latents, style_latents) + + if clip_guidance_scale > 0: + content_clip_image_embedding = self.get_clip_image_embeddings(content_image, batch_size) + style_clip_image_embedding = self.get_clip_image_embeddings(style_image, batch_size) + clip_image_embeddings = slerp( + slerp_clip_image_style_strength, content_clip_image_embedding, style_clip_image_embedding + ) + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + max_length = content_text_input.input_ids.shape[-1] + uncond_input = self.tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt") + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + # duplicate unconditional embeddings for each generation per prompt + uncond_embeddings = uncond_embeddings.repeat_interleave(batch_size, dim=0) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + + # get the initial random noise unless the user supplied it + + # Unlike in other pipelines, latents need to be generated in the target device + # for 1-to-1 results reproducibility with the CompVis implementation. + # However this currently doesn't work in `mps`. + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) + latents_dtype = text_embeddings.dtype + if latents is None: + if self.device.type == "mps": + # randn does not work reproducibly on mps + latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( + self.device + ) + else: + latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + + with self.progress_bar(total=num_inference_steps): + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform classifier free guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # perform clip guidance + if clip_guidance_scale > 0: + text_embeddings_for_guidance = ( + text_embeddings.chunk(2)[1] if do_classifier_free_guidance else text_embeddings + ) + noise_pred, latents = self.cond_fn( + latents, + t, + i, + text_embeddings_for_guidance, + noise_pred, + clip_image_embeddings, + clip_guidance_scale, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # Hardcode 0.18215 because stable-diffusion-2-base has not self.vae.config.scaling_factor + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/examples/community/edict_pipeline.py b/examples/community/edict_pipeline.py new file mode 100644 index 000000000000..ac977f79abec --- /dev/null +++ b/examples/community/edict_pipeline.py @@ -0,0 +1,264 @@ +from typing import Optional + +import torch +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DDIMScheduler, DiffusionPipeline, UNet2DConditionModel +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils import ( + deprecate, +) + + +class EDICTPipeline(DiffusionPipeline): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + mixing_coeff: float = 0.93, + leapfrog_steps: bool = True, + ): + self.mixing_coeff = mixing_coeff + self.leapfrog_steps = leapfrog_steps + + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def _encode_prompt( + self, prompt: str, negative_prompt: Optional[str] = None, do_classifier_free_guidance: bool = False + ): + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + prompt_embeds = self.text_encoder(text_inputs.input_ids.to(self.device)).last_hidden_state + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=self.device) + + if do_classifier_free_guidance: + uncond_tokens = "" if negative_prompt is None else negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(self.device)).last_hidden_state + + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def denoise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): + x = self.mixing_coeff * x + (1 - self.mixing_coeff) * y + y = self.mixing_coeff * y + (1 - self.mixing_coeff) * x + + return [x, y] + + def noise_mixing_layer(self, x: torch.Tensor, y: torch.Tensor): + y = (y - (1 - self.mixing_coeff) * x) / self.mixing_coeff + x = (x - (1 - self.mixing_coeff) * y) / self.mixing_coeff + + return [x, y] + + def _get_alpha_and_beta(self, t: torch.Tensor): + # as self.alphas_cumprod is always in cpu + t = int(t) + + alpha_prod = self.scheduler.alphas_cumprod[t] if t >= 0 else self.scheduler.final_alpha_cumprod + + return alpha_prod, 1 - alpha_prod + + def noise_step( + self, + base: torch.Tensor, + model_input: torch.Tensor, + model_output: torch.Tensor, + timestep: torch.Tensor, + ): + prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps + + alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep) + alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep) + + a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 + b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5 + + next_model_input = (base - b_t * model_output) / a_t + + return model_input, next_model_input.to(base.dtype) + + def denoise_step( + self, + base: torch.Tensor, + model_input: torch.Tensor, + model_output: torch.Tensor, + timestep: torch.Tensor, + ): + prev_timestep = timestep - self.scheduler.config.num_train_timesteps / self.scheduler.num_inference_steps + + alpha_prod_t, beta_prod_t = self._get_alpha_and_beta(timestep) + alpha_prod_t_prev, beta_prod_t_prev = self._get_alpha_and_beta(prev_timestep) + + a_t = (alpha_prod_t_prev / alpha_prod_t) ** 0.5 + b_t = -a_t * (beta_prod_t**0.5) + beta_prod_t_prev**0.5 + next_model_input = a_t * base + b_t * model_output + + return model_input, next_model_input.to(base.dtype) + + @torch.no_grad() + def decode_latents(self, latents: torch.Tensor): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + return image + + @torch.no_grad() + def prepare_latents( + self, + image: Image.Image, + text_embeds: torch.Tensor, + timesteps: torch.Tensor, + guidance_scale: float, + generator: Optional[torch.Generator] = None, + ): + do_classifier_free_guidance = guidance_scale > 1.0 + + image = image.to(device=self.device, dtype=text_embeds.dtype) + latent = self.vae.encode(image).latent_dist.sample(generator) + + latent = self.vae.config.scaling_factor * latent + + coupled_latents = [latent.clone(), latent.clone()] + + for i, t in tqdm(enumerate(timesteps), total=len(timesteps)): + coupled_latents = self.noise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) + + # j - model_input index, k - base index + for j in range(2): + k = j ^ 1 + + if self.leapfrog_steps: + if i % 2 == 0: + k, j = j, k + + model_input = coupled_latents[j] + base = coupled_latents[k] + + latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeds).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + base, model_input = self.noise_step( + base=base, + model_input=model_input, + model_output=noise_pred, + timestep=t, + ) + + coupled_latents[k] = model_input + + return coupled_latents + + @torch.no_grad() + def __call__( + self, + base_prompt: str, + target_prompt: str, + image: Image.Image, + guidance_scale: float = 3.0, + num_inference_steps: int = 50, + strength: float = 0.8, + negative_prompt: Optional[str] = None, + generator: Optional[torch.Generator] = None, + output_type: Optional[str] = "pil", + ): + do_classifier_free_guidance = guidance_scale > 1.0 + + image = self.image_processor.preprocess(image) + + base_embeds = self._encode_prompt(base_prompt, negative_prompt, do_classifier_free_guidance) + target_embeds = self._encode_prompt(target_prompt, negative_prompt, do_classifier_free_guidance) + + self.scheduler.set_timesteps(num_inference_steps, self.device) + + t_limit = num_inference_steps - int(num_inference_steps * strength) + fwd_timesteps = self.scheduler.timesteps[t_limit:] + bwd_timesteps = fwd_timesteps.flip(0) + + coupled_latents = self.prepare_latents(image, base_embeds, bwd_timesteps, guidance_scale, generator) + + for i, t in tqdm(enumerate(fwd_timesteps), total=len(fwd_timesteps)): + # j - model_input index, k - base index + for k in range(2): + j = k ^ 1 + + if self.leapfrog_steps: + if i % 2 == 1: + k, j = j, k + + model_input = coupled_latents[j] + base = coupled_latents[k] + + latent_model_input = torch.cat([model_input] * 2) if do_classifier_free_guidance else model_input + + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=target_embeds).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + base, model_input = self.denoise_step( + base=base, + model_input=model_input, + model_output=noise_pred, + timestep=t, + ) + + coupled_latents[k] = model_input + + coupled_latents = self.denoise_mixing_layer(x=coupled_latents[0], y=coupled_latents[1]) + + # either one is fine + final_latent = coupled_latents[0] + + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if output_type == "latent": + image = final_latent + else: + image = self.decode_latents(final_latent) + image = self.image_processor.postprocess(image, output_type=output_type) + + return image diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py index e912ad5244be..56fb903c7106 100644 --- a/examples/community/lpw_stable_diffusion.py +++ b/examples/community/lpw_stable_diffusion.py @@ -1,6 +1,6 @@ import inspect import re -from typing import Callable, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -8,32 +8,23 @@ from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -import diffusers -from diffusers import SchedulerMixin, StableDiffusionPipeline +from diffusers import DiffusionPipeline +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker -from diffusers.utils import logging - - -try: - from diffusers.utils import PIL_INTERPOLATION -except ImportError: - if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): - PIL_INTERPOLATION = { - "linear": PIL.Image.Resampling.BILINEAR, - "bilinear": PIL.Image.Resampling.BILINEAR, - "bicubic": PIL.Image.Resampling.BICUBIC, - "lanczos": PIL.Image.Resampling.LANCZOS, - "nearest": PIL.Image.Resampling.NEAREST, - } - else: - PIL_INTERPOLATION = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - "nearest": PIL.Image.NEAREST, - } +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) + + # ------------------------------------------------------------------------------ logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -144,7 +135,7 @@ def multiply_range(start_position, multiplier): return res -def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int): +def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int): r""" Tokenize a list of prompts and return its tokens with weights of each token. @@ -205,7 +196,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, pad, no_boseos def get_unweighted_text_embeddings( - pipe: StableDiffusionPipeline, + pipe: DiffusionPipeline, text_input: torch.Tensor, chunk_length: int, no_boseos_middle: Optional[bool] = True, @@ -245,7 +236,7 @@ def get_unweighted_text_embeddings( def get_weighted_text_embeddings( - pipe: StableDiffusionPipeline, + pipe: DiffusionPipeline, prompt: Union[str, List[str]], uncond_prompt: Optional[Union[str, List[str]]] = None, max_embeddings_multiples: Optional[int] = 3, @@ -261,7 +252,7 @@ def get_weighted_text_embeddings( Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. Args: - pipe (`StableDiffusionPipeline`): + pipe (`DiffusionPipeline`): Pipe to provide access to the tokenizer and the text encoder. prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. @@ -349,7 +340,7 @@ def get_weighted_text_embeddings( pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, ) - prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) + prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=text_embeddings.device) if uncond_prompt is not None: uncond_embeddings = get_unweighted_text_embeddings( pipe, @@ -357,7 +348,7 @@ def get_weighted_text_embeddings( pipe.tokenizer.model_max_length, no_boseos_middle=no_boseos_middle, ) - uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) + uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=uncond_embeddings.device) # assign weights to the prompts and normalize in the sense of mean # TODO: should we normalize by chunk or in a whole (current implementation)? @@ -377,30 +368,50 @@ def get_weighted_text_embeddings( return text_embeddings, None -def preprocess_image(image): +def preprocess_image(image, batch_size): w, h = image.size - w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) + image = np.vstack([image[None].transpose(0, 3, 1, 2)] * batch_size) image = torch.from_numpy(image) return 2.0 * image - 1.0 -def preprocess_mask(mask, scale_factor=8): - mask = mask.convert("L") - w, h = mask.size - w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 - mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) - mask = np.array(mask).astype(np.float32) / 255.0 - mask = np.tile(mask, (4, 1, 1)) - mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? - mask = 1 - mask # repaint white, keep black - mask = torch.from_numpy(mask) - return mask +def preprocess_mask(mask, batch_size, scale_factor=8): + if not isinstance(mask, torch.FloatTensor): + mask = mask.convert("L") + w, h = mask.size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"]) + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask, (4, 1, 1)) + mask = np.vstack([mask[None]] * batch_size) + mask = 1 - mask # repaint white, keep black + mask = torch.from_numpy(mask) + return mask + + else: + valid_mask_channel_sizes = [1, 3] + # if mask channel is fourth tensor dimension, permute dimensions to pytorch standard (B, C, H, W) + if mask.shape[3] in valid_mask_channel_sizes: + mask = mask.permute(0, 3, 1, 2) + elif mask.shape[1] not in valid_mask_channel_sizes: + raise ValueError( + f"Mask channel dimension of size in {valid_mask_channel_sizes} should be second or fourth dimension," + f" but received mask of shape {tuple(mask.shape)}" + ) + # (potentially) reduce mask channel dimension from 3 to 1 for broadcasting to latent shape + mask = mask.mean(dim=1, keepdim=True) + h, w = mask.shape[-2:] + h, w = (x - x % 8 for x in (h, w)) # resize to integer multiple of 8 + mask = torch.nn.functional.interpolate(mask, (h // scale_factor, w // scale_factor)) + return mask -class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): +class StableDiffusionLongPromptWeightingPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin +): r""" Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing weighting in prompt. @@ -429,66 +440,196 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ - if version.parse(version.parse(diffusers.__version__).base_version) >= version.parse("0.9.0"): - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: SchedulerMixin, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - requires_safety_checker=requires_safety_checker, + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." ) - self.__init__additional__() - else: + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: SchedulerMixin, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" ) - self.__init__additional__() + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config( + requires_safety_checker=requires_safety_checker, + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - def __init__additional__(self): - if not hasattr(self, "vae_scale_factor"): - setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1)) + # We'll offload the last model manually. + self.final_offload_hook = hook @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device def _execution_device(self): r""" Returns the device on which the pipeline's models will be executed. After calling `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module hooks. """ - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): + if not hasattr(self.unet, "_hf_hook"): return self.device for module in self.unet.modules(): if ( @@ -505,8 +646,10 @@ def _encode_prompt( device, num_images_per_prompt, do_classifier_free_guidance, - negative_prompt, - max_embeddings_multiples, + negative_prompt=None, + max_embeddings_multiples=3, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -526,47 +669,71 @@ def _encode_prompt( max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. """ - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * batch_size + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * batch_size + if batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + if prompt_embeds is None or negative_prompt_embeds is None: + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = self.maybe_convert_prompt(negative_prompt, self.tokenizer) + + prompt_embeds1, negative_prompt_embeds1 = get_weighted_text_embeddings( + pipe=self, + prompt=prompt, + uncond_prompt=negative_prompt if do_classifier_free_guidance else None, + max_embeddings_multiples=max_embeddings_multiples, ) + if prompt_embeds is None: + prompt_embeds = prompt_embeds1 + if negative_prompt_embeds is None: + negative_prompt_embeds = negative_prompt_embeds1 - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt if do_classifier_free_guidance else None, - max_embeddings_multiples=max_embeddings_multiples, - ) - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) if do_classifier_free_guidance: - bs_embed, seq_len, _ = uncond_embeddings.shape - uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + bs_embed, seq_len, _ = negative_prompt_embeds.shape + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - return text_embeddings + return prompt_embeds - def check_inputs(self, prompt, height, width, strength, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + def check_inputs( + self, + prompt, + height, + width, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -575,17 +742,42 @@ def check_inputs(self, prompt, height, width, strength, callback_steps): f" {type(callback_steps)}." ) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + def get_timesteps(self, num_inference_steps, strength, device, is_text2img): if is_text2img: return self.scheduler.timesteps.to(device), num_inference_steps else: # get the original timestep using init_timestep - offset = self.scheduler.config.get("steps_offset", 0) - init_timestep = int(num_inference_steps * strength) + offset - init_timestep = min(init_timestep, num_inference_steps) + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] - t_start = max(num_inference_steps - init_timestep + offset, 0) - timesteps = self.scheduler.timesteps[t_start:].to(device) return timesteps, num_inference_steps - t_start def run_safety_checker(self, image, device, dtype): @@ -599,7 +791,7 @@ def run_safety_checker(self, image, device, dtype): return image, has_nsfw_concept def decode_latents(self, latents): - latents = 1 / 0.18215 * latents + latents = 1 / self.vae.config.scaling_factor * latents image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 @@ -623,43 +815,51 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, + image, + timestep, + num_images_per_prompt, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): if image is None: - shape = ( - batch_size, - self.unet.config.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) + batch_size = batch_size * num_images_per_prompt + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) if latents is None: - if device.type == "mps": - # randn does not work reproducibly on mps - latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - latents = torch.randn(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents, None, None else: + image = image.to(device=self.device, dtype=dtype) init_latent_dist = self.vae.encode(image).latent_dist init_latents = init_latent_dist.sample(generator=generator) - init_latents = 0.18215 * init_latents - init_latents = torch.cat([init_latents] * batch_size, dim=0) + init_latents = self.vae.config.scaling_factor * init_latents + + # Expand init_latents for batch_size and num_images_per_prompt + init_latents = torch.cat([init_latents] * num_images_per_prompt, dim=0) init_latents_orig = init_latents - shape = init_latents.shape # add noise to latents using the timesteps - if device.type == "mps": - noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device) - else: - noise = torch.randn(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.add_noise(init_latents, noise, timestep) + noise = randn_tensor(init_latents.shape, generator=generator, device=self.device, dtype=dtype) + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents return latents, init_latents_orig, noise @torch.no_grad() @@ -675,15 +875,19 @@ def __call__( guidance_scale: float = 7.5, strength: float = 0.8, num_images_per_prompt: Optional[int] = 1, + add_predicted_noise: Optional[bool] = False, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -723,16 +927,26 @@ def __call__( `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -750,6 +964,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: `None` if cancelled by `is_cancelled_callback`, @@ -764,10 +982,18 @@ def __call__( width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, strength, callback_steps) + self.check_inputs( + prompt, height, width, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) # 2. Define call parameters - batch_size = 1 if isinstance(prompt, str) else len(prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -775,26 +1001,28 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt - text_embeddings = self._encode_prompt( + prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt, max_embeddings_multiples, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, ) - dtype = text_embeddings.dtype + dtype = prompt_embeds.dtype # 4. Preprocess image and mask if isinstance(image, PIL.Image.Image): - image = preprocess_image(image) + image = preprocess_image(image, batch_size) if image is not None: image = image.to(device=self.device, dtype=dtype) if isinstance(mask_image, PIL.Image.Image): - mask_image = preprocess_mask(mask_image, self.vae_scale_factor) + mask_image = preprocess_mask(mask_image, batch_size, self.vae_scale_factor) if mask_image is not None: mask = mask_image.to(device=self.device, dtype=dtype) - mask = torch.cat([mask] * batch_size * num_images_per_prompt) + mask = torch.cat([mask] * num_images_per_prompt) else: mask = None @@ -807,7 +1035,9 @@ def __call__( latents, init_latents_orig, noise = self.prepare_latents( image, latent_timestep, - batch_size * num_images_per_prompt, + num_images_per_prompt, + batch_size, + self.unet.config.in_channels, height, width, dtype, @@ -820,43 +1050,70 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 8. Denoising loop - for i, t in enumerate(self.progress_bar(timesteps)): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - if mask is not None: - # masking - init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) - latents = (init_latents_proper * mask) + (latents * (1 - mask)) - - # call the callback, if provided - if i % callback_steps == 0: - if callback is not None: - callback(i, t, latents) - if is_cancelled_callback is not None and is_cancelled_callback(): - return None - - # 9. Post-processing - image = self.decode_latents(latents) - - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) - - # 11. Convert to PIL - if output_type == "pil": + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + if mask is not None: + # masking + if add_predicted_noise: + init_latents_proper = self.scheduler.add_noise( + init_latents_orig, noise_pred_uncond, torch.tensor([t]) + ) + else: + init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t])) + latents = (init_latents_proper * mask) + (latents * (1 - mask)) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if i % callback_steps == 0: + if callback is not None: + callback(i, t, latents) + if is_cancelled_callback is not None and is_cancelled_callback(): + return None + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 11. Convert to PIL image = self.numpy_to_pil(image) + else: + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() if not return_dict: return image, has_nsfw_concept @@ -873,14 +1130,17 @@ def text2img( guidance_scale: float = 7.5, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function for text-to-image generation. @@ -908,13 +1168,20 @@ def text2img( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. latents (`torch.FloatTensor`, *optional*): Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image generation. Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -932,7 +1199,13 @@ def text2img( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Returns: + `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a @@ -950,12 +1223,15 @@ def text2img( eta=eta, generator=generator, latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, max_embeddings_multiples=max_embeddings_multiples, output_type=output_type, return_dict=return_dict, callback=callback, is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, + cross_attention_kwargs=cross_attention_kwargs, ) def img2img( @@ -968,13 +1244,16 @@ def img2img( guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function for image-to-image generation. @@ -1007,9 +1286,16 @@ def img2img( eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -1027,8 +1313,13 @@ def img2img( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" @@ -1044,12 +1335,15 @@ def img2img( num_images_per_prompt=num_images_per_prompt, eta=eta, generator=generator, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, max_embeddings_multiples=max_embeddings_multiples, output_type=output_type, return_dict=return_dict, callback=callback, is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, + cross_attention_kwargs=cross_attention_kwargs, ) def inpaint( @@ -1062,14 +1356,18 @@ def inpaint( num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, num_images_per_prompt: Optional[int] = 1, + add_predicted_noise: Optional[bool] = False, eta: Optional[float] = 0.0, - generator: Optional[torch.Generator] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, max_embeddings_multiples: Optional[int] = 3, output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, is_cancelled_callback: Optional[Callable[[], bool]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function for inpaint. @@ -1103,12 +1401,22 @@ def inpaint( usually at the expense of lower image quality. num_images_per_prompt (`int`, *optional*, defaults to 1): The number of images to generate per prompt. + add_predicted_noise (`bool`, *optional*, defaults to True): + Use predicted noise instead of random noise when constructing noisy versions of the original image in + the reverse diffusion process eta (`float`, *optional*, defaults to 0.0): Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. max_embeddings_multiples (`int`, *optional*, defaults to `3`): The max multiple length of prompt embeddings compared to the max output length of text encoder. output_type (`str`, *optional*, defaults to `"pil"`): @@ -1126,8 +1434,13 @@ def inpaint( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + `None` if cancelled by `is_cancelled_callback`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images, and the second element is a list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" @@ -1142,12 +1455,16 @@ def inpaint( guidance_scale=guidance_scale, strength=strength, num_images_per_prompt=num_images_per_prompt, + add_predicted_noise=add_predicted_noise, eta=eta, generator=generator, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, max_embeddings_multiples=max_embeddings_multiples, output_type=output_type, return_dict=return_dict, callback=callback, is_cancelled_callback=is_cancelled_callback, callback_steps=callback_steps, + cross_attention_kwargs=cross_attention_kwargs, ) diff --git a/examples/community/mixture.py b/examples/community/mixture.py new file mode 100644 index 000000000000..845ad76b6a2e --- /dev/null +++ b/examples/community/mixture.py @@ -0,0 +1,401 @@ +import inspect +from copy import deepcopy +from enum import Enum +from typing import List, Optional, Tuple, Union + +import torch +from ligo.segments import segment +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import LMSDiscreteScheduler + >>> from mixdiff import StableDiffusionTilingPipeline + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) + >>> pipeline = StableDiffusionTilingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler) + >>> pipeline.to("cuda:0") + + >>> image = pipeline( + >>> prompt=[[ + >>> "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", + >>> "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", + >>> "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" + >>> ]], + >>> tile_height=640, + >>> tile_width=640, + >>> tile_row_overlap=0, + >>> tile_col_overlap=256, + >>> guidance_scale=8, + >>> seed=7178915308, + >>> num_inference_steps=50, + >>> )["images"][0] + ``` +""" + + +def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): + """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image + + Returns a tuple with: + - Starting coordinates of rows in pixel space + - Ending coordinates of rows in pixel space + - Starting coordinates of columns in pixel space + - Ending coordinates of columns in pixel space + """ + px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap) + px_row_end = px_row_init + tile_height + px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap) + px_col_end = px_col_init + tile_width + return px_row_init, px_row_end, px_col_init, px_col_end + + +def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end): + """Translates coordinates in pixel space to coordinates in latent space""" + return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8 + + +def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): + """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image + + Returns a tuple with: + - Starting coordinates of rows in latent space + - Ending coordinates of rows in latent space + - Starting coordinates of columns in latent space + - Ending coordinates of columns in latent space + """ + px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices( + tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end) + + +def _tile2latent_exclusive_indices( + tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns +): + """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image + + Returns a tuple with: + - Starting coordinates of rows in latent space + - Ending coordinates of rows in latent space + - Starting coordinates of columns in latent space + - Ending coordinates of columns in latent space + """ + row_init, row_end, col_init, col_end = _tile2latent_indices( + tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + row_segment = segment(row_init, row_end) + col_segment = segment(col_init, col_end) + # Iterate over the rest of tiles, clipping the region for the current tile + for row in range(rows): + for column in range(columns): + if row != tile_row and column != tile_col: + clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices( + row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + row_segment = row_segment - segment(clip_row_init, clip_row_end) + col_segment = col_segment - segment(clip_col_init, clip_col_end) + # return row_init, row_end, col_init, col_end + return row_segment[0], row_segment[1], col_segment[0], col_segment[1] + + +class StableDiffusionExtrasMixin: + """Mixin providing additional convenience method to Stable Diffusion pipelines""" + + def decode_latents(self, latents, cpu_vae=False): + """Decodes a given array of latents into pixel space""" + # scale and decode the image latents with vae + if cpu_vae: + lat = deepcopy(latents).cpu() + vae = deepcopy(self.vae).cpu() + else: + lat = latents + vae = self.vae + + lat = 1 / 0.18215 * lat + image = vae.decode(lat).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + return self.numpy_to_pil(image) + + +class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + class SeedTilesMode(Enum): + """Modes in which the latents of a particular tile can be re-seeded""" + + FULL = "full" + EXCLUSIVE = "exclusive" + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[List[str]]], + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + seed: Optional[int] = None, + tile_height: Optional[int] = 512, + tile_width: Optional[int] = 512, + tile_row_overlap: Optional[int] = 256, + tile_col_overlap: Optional[int] = 256, + guidance_scale_tiles: Optional[List[List[float]]] = None, + seed_tiles: Optional[List[List[int]]] = None, + seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full", + seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None, + cpu_vae: Optional[bool] = False, + ): + r""" + Function to run the diffusion pipeline with tiling support. + + Args: + prompt: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure. + num_inference_steps: number of diffusions steps. + guidance_scale: classifier-free guidance. + seed: general random seed to initialize latents. + tile_height: height in pixels of each grid tile. + tile_width: width in pixels of each grid tile. + tile_row_overlap: number of overlap pixels between tiles in consecutive rows. + tile_col_overlap: number of overlap pixels between tiles in consecutive columns. + guidance_scale_tiles: specific weights for classifier-free guidance in each tile. + guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used. + seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter. + seed_tiles_mode: either "full" "exclusive". If "full", all the latents affected by the tile be overriden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overrriden. + seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overriden using the given seed. Takes priority over seed_tiles. + cpu_vae: the decoder from latent space to pixel space can require too mucho GPU RAM for large images. If you find out of memory errors at the end of the generation process, try setting this parameter to True to run the decoder in CPU. Slower, but should run without memory issues. + + Examples: + + Returns: + A PIL image with the generated image. + + """ + if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt): + raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}") + grid_rows = len(prompt) + grid_cols = len(prompt[0]) + if not all(len(row) == grid_cols for row in prompt): + raise ValueError("All prompt rows must have the same number of prompt columns") + if not isinstance(seed_tiles_mode, str) and ( + not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode) + ): + raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}") + if isinstance(seed_tiles_mode, str): + seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt] + modes = [mode.value for mode in self.SeedTilesMode] + if any(mode not in modes for row in seed_tiles_mode for mode in row): + raise ValueError(f"Seed tiles mode must be one of {modes}") + if seed_reroll_regions is None: + seed_reroll_regions = [] + batch_size = 1 + + # create original noisy latents using the timesteps + height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap) + width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap) + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) + generator = torch.Generator("cuda").manual_seed(seed) + latents = torch.randn(latents_shape, generator=generator, device=self.device) + + # overwrite latents for specific tiles if provided + if seed_tiles is not None: + for row in range(grid_rows): + for col in range(grid_cols): + if (seed_tile := seed_tiles[row][col]) is not None: + mode = seed_tiles_mode[row][col] + if mode == self.SeedTilesMode.FULL.value: + row_init, row_end, col_init, col_end = _tile2latent_indices( + row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + else: + row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices( + row, + col, + tile_width, + tile_height, + tile_row_overlap, + tile_col_overlap, + grid_rows, + grid_cols, + ) + tile_generator = torch.Generator("cuda").manual_seed(seed_tile) + tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) + latents[:, :, row_init:row_end, col_init:col_end] = torch.randn( + tile_shape, generator=tile_generator, device=self.device + ) + + # overwrite again for seed reroll regions + for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions: + row_init, row_end, col_init, col_end = _pixel2latent_indices( + row_init, row_end, col_init, col_end + ) # to latent space coordinates + reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll) + region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) + latents[:, :, row_init:row_end, col_init:col_end] = torch.randn( + region_shape, generator=reroll_generator, device=self.device + ) + + # Prepare scheduler + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # get prompts text embeddings + text_input = [ + [ + self.tokenizer( + col, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + for col in row + ] + for row in prompt + ] + text_embeddings = [[self.text_encoder(col.input_ids.to(self.device))[0] for col in row] for row in text_input] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 # TODO: also active if any tile has guidance scale + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + for i in range(grid_rows): + for j in range(grid_cols): + max_length = text_input[i][j].input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # Mask for tile weights strenght + tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size) + + # Diffusion timesteps + for i, t in tqdm(enumerate(self.scheduler.timesteps)): + # Diffuse each tile + noise_preds = [] + for row in range(grid_rows): + noise_preds_row = [] + for col in range(grid_cols): + px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices( + row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end] + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])[ + "sample" + ] + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + guidance = ( + guidance_scale + if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None + else guidance_scale_tiles[row][col] + ) + noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + noise_preds_row.append(noise_pred_tile) + noise_preds.append(noise_preds_row) + # Stitch noise predictions for all tiles + noise_pred = torch.zeros(latents.shape, device=self.device) + contributors = torch.zeros(latents.shape, device=self.device) + # Add each tile contribution to overall latents + for row in range(grid_rows): + for col in range(grid_cols): + px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices( + row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += ( + noise_preds[row][col] * tile_weights + ) + contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights + # Average overlapping areas with more than 1 contributor + noise_pred /= contributors + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # scale and decode the image latents with vae + image = self.decode_latents(latents, cpu_vae) + + return {"images": image} + + def _gaussian_weights(self, tile_width, tile_height, nbatches): + """Generates a gaussian mask of weights for tile contributions""" + import numpy as np + from numpy import exp, pi, sqrt + + latent_width = tile_width // 8 + latent_height = tile_height // 8 + + var = 0.01 + midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 + x_probs = [ + exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var) + for x in range(latent_width) + ] + midpoint = latent_height / 2 + y_probs = [ + exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var) + for y in range(latent_height) + ] + + weights = np.outer(y_probs, x_probs) + return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1)) diff --git a/examples/community/mixture_tiling.py b/examples/community/mixture_tiling.py new file mode 100644 index 000000000000..3e701cf607f5 --- /dev/null +++ b/examples/community/mixture_tiling.py @@ -0,0 +1,405 @@ +import inspect +from copy import deepcopy +from enum import Enum +from typing import List, Optional, Tuple, Union + +import torch +from tqdm.auto import tqdm + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from diffusers.utils import logging + + +try: + from ligo.segments import segment + from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +except ImportError: + raise ImportError("Please install transformers and ligo-segments to use the mixture pipeline") + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import LMSDiscreteScheduler, DiffusionPipeline + + >>> scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) + >>> pipeline = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, custom_pipeline="mixture_tiling") + >>> pipeline.to("cuda") + + >>> image = pipeline( + >>> prompt=[[ + >>> "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", + >>> "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", + >>> "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" + >>> ]], + >>> tile_height=640, + >>> tile_width=640, + >>> tile_row_overlap=0, + >>> tile_col_overlap=256, + >>> guidance_scale=8, + >>> seed=7178915308, + >>> num_inference_steps=50, + >>> )["images"][0] + ``` +""" + + +def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): + """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image + + Returns a tuple with: + - Starting coordinates of rows in pixel space + - Ending coordinates of rows in pixel space + - Starting coordinates of columns in pixel space + - Ending coordinates of columns in pixel space + """ + px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap) + px_row_end = px_row_init + tile_height + px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap) + px_col_end = px_col_init + tile_width + return px_row_init, px_row_end, px_col_init, px_col_end + + +def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end): + """Translates coordinates in pixel space to coordinates in latent space""" + return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8 + + +def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): + """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image + + Returns a tuple with: + - Starting coordinates of rows in latent space + - Ending coordinates of rows in latent space + - Starting coordinates of columns in latent space + - Ending coordinates of columns in latent space + """ + px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices( + tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end) + + +def _tile2latent_exclusive_indices( + tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns +): + """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image + + Returns a tuple with: + - Starting coordinates of rows in latent space + - Ending coordinates of rows in latent space + - Starting coordinates of columns in latent space + - Ending coordinates of columns in latent space + """ + row_init, row_end, col_init, col_end = _tile2latent_indices( + tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + row_segment = segment(row_init, row_end) + col_segment = segment(col_init, col_end) + # Iterate over the rest of tiles, clipping the region for the current tile + for row in range(rows): + for column in range(columns): + if row != tile_row and column != tile_col: + clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices( + row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + row_segment = row_segment - segment(clip_row_init, clip_row_end) + col_segment = col_segment - segment(clip_col_init, clip_col_end) + # return row_init, row_end, col_init, col_end + return row_segment[0], row_segment[1], col_segment[0], col_segment[1] + + +class StableDiffusionExtrasMixin: + """Mixin providing additional convenience method to Stable Diffusion pipelines""" + + def decode_latents(self, latents, cpu_vae=False): + """Decodes a given array of latents into pixel space""" + # scale and decode the image latents with vae + if cpu_vae: + lat = deepcopy(latents).cpu() + vae = deepcopy(self.vae).cpu() + else: + lat = latents + vae = self.vae + + lat = 1 / 0.18215 * lat + image = vae.decode(lat).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + + return self.numpy_to_pil(image) + + +class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin): + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: Union[DDIMScheduler, PNDMScheduler], + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + ): + super().__init__() + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + + class SeedTilesMode(Enum): + """Modes in which the latents of a particular tile can be re-seeded""" + + FULL = "full" + EXCLUSIVE = "exclusive" + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[List[str]]], + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + eta: Optional[float] = 0.0, + seed: Optional[int] = None, + tile_height: Optional[int] = 512, + tile_width: Optional[int] = 512, + tile_row_overlap: Optional[int] = 256, + tile_col_overlap: Optional[int] = 256, + guidance_scale_tiles: Optional[List[List[float]]] = None, + seed_tiles: Optional[List[List[int]]] = None, + seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full", + seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None, + cpu_vae: Optional[bool] = False, + ): + r""" + Function to run the diffusion pipeline with tiling support. + + Args: + prompt: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure. + num_inference_steps: number of diffusions steps. + guidance_scale: classifier-free guidance. + seed: general random seed to initialize latents. + tile_height: height in pixels of each grid tile. + tile_width: width in pixels of each grid tile. + tile_row_overlap: number of overlap pixels between tiles in consecutive rows. + tile_col_overlap: number of overlap pixels between tiles in consecutive columns. + guidance_scale_tiles: specific weights for classifier-free guidance in each tile. + guidance_scale_tiles: specific weights for classifier-free guidance in each tile. If None, the value provided in guidance_scale will be used. + seed_tiles: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard seed parameter. + seed_tiles_mode: either "full" "exclusive". If "full", all the latents affected by the tile be overriden. If "exclusive", only the latents that are affected exclusively by this tile (and no other tiles) will be overrriden. + seed_reroll_regions: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overriden using the given seed. Takes priority over seed_tiles. + cpu_vae: the decoder from latent space to pixel space can require too mucho GPU RAM for large images. If you find out of memory errors at the end of the generation process, try setting this parameter to True to run the decoder in CPU. Slower, but should run without memory issues. + + Examples: + + Returns: + A PIL image with the generated image. + + """ + if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt): + raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}") + grid_rows = len(prompt) + grid_cols = len(prompt[0]) + if not all(len(row) == grid_cols for row in prompt): + raise ValueError("All prompt rows must have the same number of prompt columns") + if not isinstance(seed_tiles_mode, str) and ( + not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode) + ): + raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}") + if isinstance(seed_tiles_mode, str): + seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt] + + modes = [mode.value for mode in self.SeedTilesMode] + if any(mode not in modes for row in seed_tiles_mode for mode in row): + raise ValueError(f"Seed tiles mode must be one of {modes}") + if seed_reroll_regions is None: + seed_reroll_regions = [] + batch_size = 1 + + # create original noisy latents using the timesteps + height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap) + width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap) + latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) + generator = torch.Generator("cuda").manual_seed(seed) + latents = torch.randn(latents_shape, generator=generator, device=self.device) + + # overwrite latents for specific tiles if provided + if seed_tiles is not None: + for row in range(grid_rows): + for col in range(grid_cols): + if (seed_tile := seed_tiles[row][col]) is not None: + mode = seed_tiles_mode[row][col] + if mode == self.SeedTilesMode.FULL.value: + row_init, row_end, col_init, col_end = _tile2latent_indices( + row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + else: + row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices( + row, + col, + tile_width, + tile_height, + tile_row_overlap, + tile_col_overlap, + grid_rows, + grid_cols, + ) + tile_generator = torch.Generator("cuda").manual_seed(seed_tile) + tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) + latents[:, :, row_init:row_end, col_init:col_end] = torch.randn( + tile_shape, generator=tile_generator, device=self.device + ) + + # overwrite again for seed reroll regions + for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions: + row_init, row_end, col_init, col_end = _pixel2latent_indices( + row_init, row_end, col_init, col_end + ) # to latent space coordinates + reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll) + region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) + latents[:, :, row_init:row_end, col_init:col_end] = torch.randn( + region_shape, generator=reroll_generator, device=self.device + ) + + # Prepare scheduler + accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) + extra_set_kwargs = {} + if accepts_offset: + extra_set_kwargs["offset"] = 1 + self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) + # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas + if isinstance(self.scheduler, LMSDiscreteScheduler): + latents = latents * self.scheduler.sigmas[0] + + # get prompts text embeddings + text_input = [ + [ + self.tokenizer( + col, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + for col in row + ] + for row in prompt + ] + text_embeddings = [[self.text_encoder(col.input_ids.to(self.device))[0] for col in row] for row in text_input] + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 # TODO: also active if any tile has guidance scale + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance: + for i in range(grid_rows): + for j in range(grid_cols): + max_length = text_input[i][j].input_ids.shape[-1] + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" + ) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]]) + + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # Mask for tile weights strenght + tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size) + + # Diffusion timesteps + for i, t in tqdm(enumerate(self.scheduler.timesteps)): + # Diffuse each tile + noise_preds = [] + for row in range(grid_rows): + noise_preds_row = [] + for col in range(grid_cols): + px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices( + row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end] + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])[ + "sample" + ] + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + guidance = ( + guidance_scale + if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None + else guidance_scale_tiles[row][col] + ) + noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + noise_preds_row.append(noise_pred_tile) + noise_preds.append(noise_preds_row) + # Stitch noise predictions for all tiles + noise_pred = torch.zeros(latents.shape, device=self.device) + contributors = torch.zeros(latents.shape, device=self.device) + # Add each tile contribution to overall latents + for row in range(grid_rows): + for col in range(grid_cols): + px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices( + row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap + ) + noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += ( + noise_preds[row][col] * tile_weights + ) + contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights + # Average overlapping areas with more than 1 contributor + noise_pred /= contributors + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + # scale and decode the image latents with vae + image = self.decode_latents(latents, cpu_vae) + + return {"images": image} + + def _gaussian_weights(self, tile_width, tile_height, nbatches): + """Generates a gaussian mask of weights for tile contributions""" + import numpy as np + from numpy import exp, pi, sqrt + + latent_width = tile_width // 8 + latent_height = tile_height // 8 + + var = 0.01 + midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 + x_probs = [ + exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / sqrt(2 * pi * var) + for x in range(latent_width) + ] + midpoint = latent_height / 2 + y_probs = [ + exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / sqrt(2 * pi * var) + for y in range(latent_height) + ] + + weights = np.outer(y_probs, x_probs) + return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1)) diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index c47f4c3194e8..aae199f91b9e 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -1,7 +1,7 @@ # Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/ import inspect -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np import PIL.Image @@ -11,6 +11,7 @@ from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.utils import ( PIL_INTERPOLATION, @@ -184,7 +185,14 @@ def prepare_mask_image(mask_image): def prepare_controlnet_conditioning_image( - controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype + controlnet_conditioning_image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance, ): if not isinstance(controlnet_conditioning_image, torch.Tensor): if isinstance(controlnet_conditioning_image, PIL.Image.Image): @@ -214,6 +222,9 @@ def prepare_controlnet_conditioning_image( controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype) + if do_classifier_free_guidance: + controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) + return controlnet_conditioning_image @@ -230,7 +241,7 @@ def __init__( text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - controlnet: ControlNetModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, @@ -254,6 +265,9 @@ def __init__( " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." ) + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -264,6 +278,7 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.register_to_config(requires_safety_checker=requires_safety_checker) @@ -522,6 +537,42 @@ def prepare_extra_step_kwargs(self, generator, eta): extra_step_kwargs["generator"] = generator return extra_step_kwargs + def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: + raise TypeError( + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: + image_batch_size = len(image) + else: + raise ValueError("controlnet condition image is not valid") + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + else: + raise ValueError("prompt or prompt_embeds are not valid") + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + def check_inputs( self, prompt, @@ -534,6 +585,7 @@ def check_inputs( negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, + controlnet_conditioning_scale=None, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -572,45 +624,35 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image) - controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor) - controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], PIL.Image.Image - ) - controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance( - controlnet_conditioning_image[0], torch.Tensor - ) - - if ( - not controlnet_cond_image_is_pil - and not controlnet_cond_image_is_tensor - and not controlnet_cond_image_is_pil_list - and not controlnet_cond_image_is_tensor_list - ): - raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" - ) - - if controlnet_cond_image_is_pil: - controlnet_cond_image_batch_size = 1 - elif controlnet_cond_image_is_tensor: - controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0] - elif controlnet_cond_image_is_pil_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) - elif controlnet_cond_image_is_tensor_list: - controlnet_cond_image_batch_size = len(controlnet_conditioning_image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}" - ) + # check controlnet condition image + if isinstance(self.controlnet, ControlNetModel): + self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds) + elif isinstance(self.controlnet, MultiControlNetModel): + if not isinstance(controlnet_conditioning_image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + if len(controlnet_conditioning_image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + for image_ in controlnet_conditioning_image: + self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if isinstance(self.controlnet, ControlNetModel): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif isinstance(self.controlnet, MultiControlNetModel): + if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False if isinstance(image, torch.Tensor) and not isinstance(mask_image, torch.Tensor): raise TypeError("if `image` is a tensor, `mask_image` must also be a tensor") @@ -630,6 +672,8 @@ def check_inputs( image_channels, image_height, image_width = image.shape elif image.ndim == 4: image_batch_size, image_channels, image_height, image_width = image.shape + else: + assert False if mask_image.ndim == 2: mask_image_batch_size = 1 @@ -797,7 +841,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: float = 1.0, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -897,6 +941,7 @@ def __call__( negative_prompt, prompt_embeds, negative_prompt_embeds, + controlnet_conditioning_scale, ) # 2. Define call parameters @@ -913,6 +958,9 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) + # 3. Encode input prompt prompt_embeds = self._encode_prompt( prompt, @@ -929,15 +977,37 @@ def __call__( mask_image = prepare_mask_image(mask_image) - controlnet_conditioning_image = prepare_controlnet_conditioning_image( - controlnet_conditioning_image, - width, - height, - batch_size * num_images_per_prompt, - num_images_per_prompt, - device, - self.controlnet.dtype, - ) + # condition image(s) + if isinstance(self.controlnet, ControlNetModel): + controlnet_conditioning_image = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=controlnet_conditioning_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + elif isinstance(self.controlnet, MultiControlNetModel): + controlnet_conditioning_images = [] + + for image_ in controlnet_conditioning_image: + image_ = prepare_controlnet_conditioning_image( + controlnet_conditioning_image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + controlnet_conditioning_images.append(image_) + + controlnet_conditioning_image = controlnet_conditioning_images + else: + assert False masked_image = image * (mask_image < 0.5) @@ -979,9 +1049,6 @@ def __call__( do_classifier_free_guidance, ) - if do_classifier_free_guidance: - controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2) - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -1007,15 +1074,10 @@ def __call__( t, encoder_hidden_states=prompt_embeds, controlnet_cond=controlnet_conditioning_image, + conditioning_scale=controlnet_conditioning_scale, return_dict=False, ) - down_block_res_samples = [ - down_block_res_sample * controlnet_conditioning_scale - for down_block_res_sample in down_block_res_samples - ] - mid_block_res_sample *= controlnet_conditioning_scale - # predict the noise residual noise_pred = self.unet( inpainting_latent_model_input, diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py new file mode 100644 index 000000000000..ca06136d7829 --- /dev/null +++ b/examples/community/stable_diffusion_controlnet_reference.py @@ -0,0 +1,822 @@ +# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280 +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import PIL.Image +import torch + +from diffusers import StableDiffusionControlNetPipeline +from diffusers.models import ControlNetModel +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import is_compiled_module, logging, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import cv2 + >>> import torch + >>> import numpy as np + >>> from PIL import Image + >>> from diffusers import UniPCMultistepScheduler + >>> from diffusers.utils import load_image + + >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + + >>> # get canny image + >>> image = cv2.Canny(np.array(input_image), 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + controlnet=controlnet, + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) + + >>> result_img = pipe(ref_image=input_image, + prompt="1girl", + image=canny_image, + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] + + >>> result_img.show() + ``` +""" + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +class StableDiffusionControlNetReferencePipeline(StableDiffusionControlNetPipeline): + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): + refimage = refimage.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + ref_image_latents = [ + self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + ref_image_latents = torch.cat(ref_image_latents, dim=0) + else: + ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator) + ref_image_latents = self.vae.config.scaling_factor * ref_image_latents + + # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method + if ref_image_latents.shape[0] < batch_size: + if not batch_size % ref_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) + + ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents + + # aligning device to prevent device errors when concating it with the latent model input + ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + return ref_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, + ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + attention_auto_machine_weight: float = 1.0, + gn_auto_machine_weight: float = 1.0, + style_fidelity: float = 0.5, + reference_attn: bool = True, + reference_adain: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + ref_image (`torch.FloatTensor`, `PIL.Image.Image`): + The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can + also be accepted as an image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + attention_auto_machine_weight (`float`): + Weight of using reference query for self attention's context. + If attention_auto_machine_weight=1.0, use reference query for all self attention's context. + gn_auto_machine_weight (`float`): + Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins. + style_fidelity (`float`): + style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important, + elif style_fidelity=0.0, prompt more important, else balanced. + reference_attn (`bool`): + Whether to use reference query for self attention's context. + reference_adain (`bool`): + Whether to use reference adain. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + else: + assert False + + # 5. Preprocess reference image + ref_image = self.prepare_image( + image=ref_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + # 6. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 7. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 8. Prepare reference latent variables + ref_image_latents = self.prepare_ref_latents( + ref_image, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Modify self attention and group norm + MODE = "write" + uc_mask = ( + torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt) + .type_as(ref_image_latents) + .bool() + ) + + def hacked_basic_transformer_inner_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.detach().clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + if attention_auto_machine_weight > self.attn_weight: + attn_output_uc = self.attn1( + norm_hidden_states, + encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), + # attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output_c = attn_output_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + attn_output_c[uc_mask] = self.attn1( + norm_hidden_states[uc_mask], + encoder_hidden_states=norm_hidden_states[uc_mask], + **cross_attention_kwargs, + ) + attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc + self.bank.clear() + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + def hacked_mid_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward(*args, **kwargs) + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + x_uc = (((x - mean) / std) * std_acc) + mean_acc + x_c = x_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + x_c[uc_mask] = x[uc_mask] + x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc + self.mean_bank = [] + self.var_bank = [] + return x + + def hack_CrossAttnDownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_DownBlock2D_forward(self, hidden_states, temb=None): + eps = 1e-6 + + output_states = () + + for i, resnet in enumerate(self.resnets): + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_CrossAttnUpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + # TODO(Patrick, William) - attention mask is not used + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + eps = 1e-6 + for i, resnet in enumerate(self.resnets): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + if reference_attn: + attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + if reference_adain: + gn_modules = [self.unet.mid_block] + self.unet.mid_block.gn_weight = 0 + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) + gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + module.gn_weight = float(w) / float(len(up_blocks)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward", None) is None: + module.original_forward = module.forward + if i == 0: + # mid_block + module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) + elif isinstance(module, CrossAttnDownBlock2D): + module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) + elif isinstance(module, DownBlock2D): + module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) + elif isinstance(module, CrossAttnUpBlock2D): + module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) + elif isinstance(module, UpBlock2D): + module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) + module.mean_bank = [] + module.var_bank = [] + module.gn_weight *= 2 + + # 11. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + controlnet_latent_model_input = latents + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + controlnet_latent_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + controlnet_latent_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # ref only part + noise = randn_tensor( + ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype + ) + ref_xt = self.scheduler.add_noise( + ref_image_latents, + noise, + t.reshape( + 1, + ), + ) + ref_xt = self.scheduler.scale_model_input(ref_xt, t) + + MODE = "write" + self.unet( + ref_xt, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + ) + + # predict the noise residual + MODE = "read" + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py new file mode 100644 index 000000000000..9abe16d56f10 --- /dev/null +++ b/examples/community/stable_diffusion_ipex.py @@ -0,0 +1,848 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import intel_extension_for_pytorch as ipex +import torch +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex") + + >>> # For Float32 + >>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference + >>> # For BFloat16 + >>> pipe.prepare_for_ipex(prompt, dtype=torch.bfloat16, height=512, width=512) #value of image height/width should be consistent with the pipeline inference + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> # For Float32 + >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()' + >>> # For BFloat16 + >>> with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16): + >>> image = pipe(prompt, num_inference_steps=num_inference_steps, height=512, width=512).images[0] #value of image height/width should be consistent with 'prepare_for_ipex()' + ``` +""" + + +class StableDiffusionIPEXPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion on IPEX. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def get_input_example(self, prompt, height=None, width=None, guidance_scale=7.5, num_images_per_prompt=1): + prompt_embeds = None + negative_prompt_embeds = None + negative_prompt = None + callback_steps = 1 + generator = None + latents = None + + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + + device = "cpu" + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 5. Prepare latent variables + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + self.unet.in_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + dummy = torch.ones(1, dtype=torch.int32) + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, dummy) + + unet_input_example = (latent_model_input, dummy, prompt_embeds) + vae_decoder_input_example = latents + + return unet_input_example, vae_decoder_input_example + + def prepare_for_ipex(self, promt, dtype=torch.float32, height=None, width=None, guidance_scale=7.5): + self.unet = self.unet.to(memory_format=torch.channels_last) + self.vae.decoder = self.vae.decoder.to(memory_format=torch.channels_last) + self.text_encoder = self.text_encoder.to(memory_format=torch.channels_last) + if self.safety_checker is not None: + self.safety_checker = self.safety_checker.to(memory_format=torch.channels_last) + + unet_input_example, vae_decoder_input_example = self.get_input_example(promt, height, width, guidance_scale) + + # optimize with ipex + if dtype == torch.bfloat16: + self.unet = ipex.optimize( + self.unet.eval(), dtype=torch.bfloat16, inplace=True, sample_input=unet_input_example + ) + self.vae.decoder = ipex.optimize(self.vae.decoder.eval(), dtype=torch.bfloat16, inplace=True) + self.text_encoder = ipex.optimize(self.text_encoder.eval(), dtype=torch.bfloat16, inplace=True) + if self.safety_checker is not None: + self.safety_checker = ipex.optimize(self.safety_checker.eval(), dtype=torch.bfloat16, inplace=True) + elif dtype == torch.float32: + self.unet = ipex.optimize( + self.unet.eval(), + dtype=torch.float32, + inplace=True, + sample_input=unet_input_example, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) + self.vae.decoder = ipex.optimize( + self.vae.decoder.eval(), + dtype=torch.float32, + inplace=True, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) + self.text_encoder = ipex.optimize( + self.text_encoder.eval(), + dtype=torch.float32, + inplace=True, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) + if self.safety_checker is not None: + self.safety_checker = ipex.optimize( + self.safety_checker.eval(), + dtype=torch.float32, + inplace=True, + level="O1", + weights_prepack=True, + auto_kernel_selection=False, + ) + else: + raise ValueError(" The value of 'dtype' should be 'torch.bfloat16' or 'torch.float32' !") + + # trace unet model to get better performance on IPEX + with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad(): + unet_trace_model = torch.jit.trace(self.unet, unet_input_example, check_trace=False, strict=False) + unet_trace_model = torch.jit.freeze(unet_trace_model) + self.unet.forward = unet_trace_model.forward + + # trace vae.decoder model to get better performance on IPEX + with torch.cpu.amp.autocast(enabled=dtype == torch.bfloat16), torch.no_grad(): + ave_decoder_trace_model = torch.jit.trace( + self.vae.decoder, vae_decoder_input_example, check_trace=False, strict=False + ) + ave_decoder_trace_model = torch.jit.freeze(ave_decoder_trace_model) + self.vae.decoder.forward = ave_decoder_trace_model.forward + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds)["sample"] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py new file mode 100644 index 000000000000..dbfb768f8b4f --- /dev/null +++ b/examples/community/stable_diffusion_reference.py @@ -0,0 +1,781 @@ +# Inspired by: https://github.com/Mikubill/sd-webui-controlnet/discussions/1236 and https://github.com/Mikubill/sd-webui-controlnet/discussions/1280 +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch + +from diffusers import StableDiffusionPipeline +from diffusers.models.attention import BasicTransformerBlock +from diffusers.models.unet_2d_blocks import CrossAttnDownBlock2D, CrossAttnUpBlock2D, DownBlock2D, UpBlock2D +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.utils import PIL_INTERPOLATION, logging, randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import UniPCMultistepScheduler + >>> from diffusers.utils import load_image + + >>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png") + + >>> pipe = StableDiffusionReferencePipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", + safety_checker=None, + torch_dtype=torch.float16 + ).to('cuda:0') + + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) + + >>> result_img = pipe(ref_image=input_image, + prompt="1girl", + num_inference_steps=20, + reference_attn=True, + reference_adain=True).images[0] + + >>> result_img.show() + ``` +""" + + +def torch_dfs(model: torch.nn.Module): + result = [model] + for child in model.children(): + result += torch_dfs(child) + return result + + +class StableDiffusionReferencePipeline(StableDiffusionPipeline): + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = (image - 0.5) / 0.5 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_ref_latents(self, refimage, batch_size, dtype, device, generator, do_classifier_free_guidance): + refimage = refimage.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + ref_image_latents = [ + self.vae.encode(refimage[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + ref_image_latents = torch.cat(ref_image_latents, dim=0) + else: + ref_image_latents = self.vae.encode(refimage).latent_dist.sample(generator=generator) + ref_image_latents = self.vae.config.scaling_factor * ref_image_latents + + # duplicate mask and ref_image_latents for each generation per prompt, using mps friendly method + if ref_image_latents.shape[0] < batch_size: + if not batch_size % ref_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {ref_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + ref_image_latents = ref_image_latents.repeat(batch_size // ref_image_latents.shape[0], 1, 1, 1) + + ref_image_latents = torch.cat([ref_image_latents] * 2) if do_classifier_free_guidance else ref_image_latents + + # aligning device to prevent device errors when concating it with the latent model input + ref_image_latents = ref_image_latents.to(device=device, dtype=dtype) + return ref_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + ref_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_auto_machine_weight: float = 1.0, + gn_auto_machine_weight: float = 1.0, + style_fidelity: float = 0.5, + reference_attn: bool = True, + reference_adain: bool = True, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + ref_image (`torch.FloatTensor`, `PIL.Image.Image`): + The Reference Control input condition. Reference Control uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to Reference Control as is. `PIL.Image.Image` can + also be accepted as an image. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + attention_auto_machine_weight (`float`): + Weight of using reference query for self attention's context. + If attention_auto_machine_weight=1.0, use reference query for all self attention's context. + gn_auto_machine_weight (`float`): + Weight of using reference adain. If gn_auto_machine_weight=2.0, use all reference adain plugins. + style_fidelity (`float`): + style fidelity of ref_uncond_xt. If style_fidelity=1.0, control more important, + elif style_fidelity=0.0, prompt more important, else balanced. + reference_attn (`bool`): + Whether to use reference query for self attention's context. + reference_adain (`bool`): + Whether to use reference adain. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + assert reference_attn or reference_adain, "`reference_attn` or `reference_adain` must be True." + + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, ref_image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess reference image + ref_image = self.prepare_image( + image=ref_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=prompt_embeds.dtype, + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare reference latent variables + ref_image_latents = self.prepare_ref_latents( + ref_image, + batch_size * num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Modify self attention and group norm + MODE = "write" + uc_mask = ( + torch.Tensor([1] * batch_size * num_images_per_prompt + [0] * batch_size * num_images_per_prompt) + .type_as(ref_image_latents) + .bool() + ) + + def hacked_basic_transformer_inner_forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if self.only_cross_attention: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + else: + if MODE == "write": + self.bank.append(norm_hidden_states.detach().clone()) + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if MODE == "read": + if attention_auto_machine_weight > self.attn_weight: + attn_output_uc = self.attn1( + norm_hidden_states, + encoder_hidden_states=torch.cat([norm_hidden_states] + self.bank, dim=1), + # attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output_c = attn_output_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + attn_output_c[uc_mask] = self.attn1( + norm_hidden_states[uc_mask], + encoder_hidden_states=norm_hidden_states[uc_mask], + **cross_attention_kwargs, + ) + attn_output = style_fidelity * attn_output_c + (1.0 - style_fidelity) * attn_output_uc + self.bank.clear() + else: + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states + + return hidden_states + + def hacked_mid_forward(self, *args, **kwargs): + eps = 1e-6 + x = self.original_forward(*args, **kwargs) + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append(mean) + self.var_bank.append(var) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank) / float(len(self.mean_bank)) + var_acc = sum(self.var_bank) / float(len(self.var_bank)) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + x_uc = (((x - mean) / std) * std_acc) + mean_acc + x_c = x_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + x_c[uc_mask] = x[uc_mask] + x = style_fidelity * x_c + (1.0 - style_fidelity) * x_uc + self.mean_bank = [] + self.var_bank = [] + return x + + def hack_CrossAttnDownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_DownBlock2D_forward(self, hidden_states, temb=None): + eps = 1e-6 + + output_states = () + + for i, resnet in enumerate(self.resnets): + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + output_states = output_states + (hidden_states,) + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + def hacked_CrossAttnUpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + eps = 1e-6 + # TODO(Patrick, William) - attention mask is not used + for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + eps = 1e-6 + for i, resnet in enumerate(self.resnets): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + if MODE == "write": + if gn_auto_machine_weight >= self.gn_weight: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + self.mean_bank.append([mean]) + self.var_bank.append([var]) + if MODE == "read": + if len(self.mean_bank) > 0 and len(self.var_bank) > 0: + var, mean = torch.var_mean(hidden_states, dim=(2, 3), keepdim=True, correction=0) + std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5 + mean_acc = sum(self.mean_bank[i]) / float(len(self.mean_bank[i])) + var_acc = sum(self.var_bank[i]) / float(len(self.var_bank[i])) + std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5 + hidden_states_uc = (((hidden_states - mean) / std) * std_acc) + mean_acc + hidden_states_c = hidden_states_uc.clone() + if do_classifier_free_guidance and style_fidelity > 0: + hidden_states_c[uc_mask] = hidden_states[uc_mask] + hidden_states = style_fidelity * hidden_states_c + (1.0 - style_fidelity) * hidden_states_uc + + if MODE == "read": + self.mean_bank = [] + self.var_bank = [] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + if reference_attn: + attn_modules = [module for module in torch_dfs(self.unet) if isinstance(module, BasicTransformerBlock)] + attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0]) + + for i, module in enumerate(attn_modules): + module._original_inner_forward = module.forward + module.forward = hacked_basic_transformer_inner_forward.__get__(module, BasicTransformerBlock) + module.bank = [] + module.attn_weight = float(i) / float(len(attn_modules)) + + if reference_adain: + gn_modules = [self.unet.mid_block] + self.unet.mid_block.gn_weight = 0 + + down_blocks = self.unet.down_blocks + for w, module in enumerate(down_blocks): + module.gn_weight = 1.0 - float(w) / float(len(down_blocks)) + gn_modules.append(module) + + up_blocks = self.unet.up_blocks + for w, module in enumerate(up_blocks): + module.gn_weight = float(w) / float(len(up_blocks)) + gn_modules.append(module) + + for i, module in enumerate(gn_modules): + if getattr(module, "original_forward", None) is None: + module.original_forward = module.forward + if i == 0: + # mid_block + module.forward = hacked_mid_forward.__get__(module, torch.nn.Module) + elif isinstance(module, CrossAttnDownBlock2D): + module.forward = hack_CrossAttnDownBlock2D_forward.__get__(module, CrossAttnDownBlock2D) + elif isinstance(module, DownBlock2D): + module.forward = hacked_DownBlock2D_forward.__get__(module, DownBlock2D) + elif isinstance(module, CrossAttnUpBlock2D): + module.forward = hacked_CrossAttnUpBlock2D_forward.__get__(module, CrossAttnUpBlock2D) + elif isinstance(module, UpBlock2D): + module.forward = hacked_UpBlock2D_forward.__get__(module, UpBlock2D) + module.mean_bank = [] + module.var_bank = [] + module.gn_weight *= 2 + + # 10. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # ref only part + noise = randn_tensor( + ref_image_latents.shape, generator=generator, device=device, dtype=ref_image_latents.dtype + ) + ref_xt = self.scheduler.add_noise( + ref_image_latents, + noise, + t.reshape( + 1, + ), + ) + ref_xt = self.scheduler.scale_model_input(ref_xt, t) + + MODE = "write" + self.unet( + ref_xt, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + ) + + # predict the noise residual + MODE = "read" + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/stable_diffusion_repaint.py b/examples/community/stable_diffusion_repaint.py new file mode 100644 index 000000000000..3fd63d4b213a --- /dev/null +++ b/examples/community/stable_diffusion_repaint.py @@ -0,0 +1,956 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers import AutoencoderKL, DiffusionPipeline, UNet2DConditionModel +from diffusers.configuration_utils import FrozenDict, deprecate +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + # masked_image = image * (mask >= 0.5) + masked_image = image + + return mask, masked_image + + +class StableDiffusionRepaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate( + "skip_prk_steps not set", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 + if unet.config.in_channels != 4: + logger.warning( + f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," + f" {self.__class__} assumes that `pipeline.unet` has 4 input channels: 4 for `num_channels_latents`," + ". If you did not intend to modify" + " this behavior, please check whether you have loaded the right checkpoint." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + jump_length: Optional[int] = 10, + jump_n_sample: Optional[int] = 10, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + jump_length (`int`, *optional*, defaults to 10): + The number of steps taken forward in time before going backward in time for a single jump ("j" in + RePaint paper). Take a look at Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + jump_n_sample (`int`, *optional*, defaults to 10): + The number of times we will make forward time jump for a given chosen time sample. Take a look at + Figure 9 and 10 in https://arxiv.org/pdf/2201.09865.pdf. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Examples: + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + >>> from diffusers import StableDiffusionPipeline, RePaintScheduler + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + >>> base_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/" + >>> img_url = base_url + "overture-creations-5sI6fQgYIuo.png" + >>> mask_url = base_url + "overture-creations-5sI6fQgYIuo_mask.png " + >>> init_image = download_image(img_url).resize((512, 512)) + >>> mask_image = download_image(mask_url).resize((512, 512)) + >>> pipe = DiffusionPipeline.from_pretrained( + ... "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, custom_pipeline="stable_diffusion_repaint", + ... ) + >>> pipe.scheduler = RePaintScheduler.from_config(pipe.scheduler.config) + >>> pipe = pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0] + ``` + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask_image is None: + raise ValueError("`mask_image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, jump_length, jump_n_sample, device) + self.scheduler.eta = eta + + timesteps = self.scheduler.timesteps + # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance=False, # We do not need duplicate mask and image + ) + + # 8. Check that sizes of mask, masked image and latents match + # num_channels_mask = mask.shape[1] + # num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} " + f" = Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + t_last = timesteps[0] + 1 + + # 10. Denoising loop + with self.progress_bar(total=len(timesteps)) as progress_bar: + for i, t in enumerate(timesteps): + if t >= t_last: + # compute the reverse: x_t-1 -> x_t + latents = self.scheduler.undo_step(latents, t_last, generator) + progress_bar.update() + t_last = t + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + masked_image_latents, + mask, + **extra_step_kwargs, + ).prev_sample + + # call the callback, if provided + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + t_last = t + + # 11. Post-processing + image = self.decode_latents(latents) + + # 12. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 13. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py new file mode 100755 index 000000000000..67c7c2d00fbf --- /dev/null +++ b/examples/community/stable_diffusion_tensorrt_img2img.py @@ -0,0 +1,1055 @@ +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from collections import OrderedDict +from copy import copy +from typing import List, Optional, Union + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import PIL +import tensorrt as trt +import torch +from huggingface_hub import snapshot_download +from onnx import shape_inference +from polygraphy import cuda +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx.loader import fold_constants +from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from polygraphy.backend.trt import util as trt_util +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionImg2ImgPipeline, + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import DIFFUSERS_CACHE, logging + + +""" +Installation instructions +python3 -m pip install --upgrade transformers diffusers>=0.16.0 +python3 -m pip install --upgrade tensorrt>=8.6.1 +python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +python3 -m pip install onnxruntime +""" + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} +if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool +else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool + +# Map of torch dtype -> numpy dtype +torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + + +def device_view(t): + return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) + + +def preprocess_image(image): + """ + image: torch.Tensor + """ + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h)) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).contiguous() + return 2.0 * image - 1.0 + + +class Engine: + def __init__(self, engine_path): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + + def __del__(self): + [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] + del self.engine + del self.context + del self.buffers + del self.tensors + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + workspace_size=0, + ): + logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + + config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] + if enable_preview: + # Faster dynamic shapes made optional since it increases engine build time. + config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) + if workspace_size > 0: + config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + engine = engine_from_network( + network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), + config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + logger.warning(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self): + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(trt_util.get_bindings_per_profile(self.engine)): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + + def infer(self, feed_dict, stream): + start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) + # shallow copy of ordered dict + device_buffers = copy(self.buffers) + for name, buf in feed_dict.items(): + assert isinstance(buf, cuda.DeviceView) + device_buffers[name] = buf + bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] + noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class Optimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + +class BaseModel: + def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): + self.model = model + self.name = "SD Model" + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + onnx_opt_graph = opt.cleanup(return_onnx=True) + return onnx_opt_graph + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_shape else self.min_image_shape + max_image_height = image_height if static_shape else self.max_image_shape + min_image_width = image_width if static_shape else self.min_image_shape + max_image_width = image_width if static_shape else self.max_image_shape + min_latent_height = latent_height if static_shape else self.min_latent_shape + max_latent_height = latent_height if static_shape else self.max_latent_shape + min_latent_width = latent_width if static_shape else self.min_latent_shape + max_latent_width = latent_width if static_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +def getOnnxPath(model_name, onnx_dir, opt=True): + return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") + + +def getEnginePath(model_name, engine_dir): + return os.path.join(engine_dir, model_name + ".plan") + + +def build_engines( + models: dict, + engine_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_shape=True, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + max_workspace_size=0, +): + built_engines = {} + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + if force_engine_rebuild or not os.path.exists(engine_path): + logger.warning("Building Engines...") + logger.warning("Engine build can take a while to complete") + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + if force_engine_rebuild or not os.path.exists(onnx_path): + logger.warning(f"Exporting model: {onnx_path}") + model = model_obj.get_model() + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.warning(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + logger.warning(f"Generating optimizing model: {onnx_opt_path}") + onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) + onnx.save(onnx_opt_graph, onnx_opt_path) + else: + logger.warning(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + engine = Engine(engine_path) + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + + if force_engine_rebuild or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_shape=static_shape, + ), + enable_preview=enable_preview, + timing_cache=timing_cache, + workspace_size=max_workspace_size, + ) + built_engines[model_name] = engine + + # Load and activate TensorRT engines + for model_name, model_obj in models.items(): + engine = built_engines[model_name] + engine.load() + engine.activate() + + return built_engines + + +def runEngine(engine, feed_dict, stream): + return engine.infer(feed_dict, stream) + + +class CLIP(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(CLIP, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "CLIP" + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings", "pooler_output"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt_onnx_graph = opt.cleanup(return_onnx=True) + return opt_onnx_graph + + +def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False): + return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class UNet(BaseModel): + def __init__( + self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4 + ): + super(UNet, self).__init__( + model=model, + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.name = "UNet" + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False): + return UNet( + model, + fp16=True, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + unet_dim=(9 if inpaint else 4), + ) + + +class VAE(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(VAE, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "VAE decoder" + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + + +def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.vae_encoder = model + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoder(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(VAEEncoder, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "VAE encoder" + + def get_model(self): + vae_encoder = TorchVAEEncoder(self.model) + return vae_encoder + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device) + + +def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class TensorRTStableDiffusionImg2ImgPipeline(StableDiffusionImg2ImgPipeline): + r""" + Pipeline for image-to-image generation using TensorRT accelerated Stable Diffusion. + + This model inherits from [`StableDiffusionImg2ImgPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + stages=["clip", "unet", "vae", "vae_encoder"], + image_height: int = 512, + image_width: int = 512, + max_batch_size: int = 16, + # ONNX export parameters + onnx_opset: int = 17, + onnx_dir: str = "onnx", + # TensorRT engine build parameters + engine_dir: str = "engine", + build_preview_features: bool = True, + force_engine_rebuild: bool = False, + timing_cache: str = "timing_cache", + ): + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) + + self.vae.forward = self.vae.decode + + self.stages = stages + self.image_height, self.image_width = image_height, image_width + self.inpaint = False + self.onnx_opset = onnx_opset + self.onnx_dir = onnx_dir + self.engine_dir = engine_dir + self.force_engine_rebuild = force_engine_rebuild + self.timing_cache = timing_cache + self.build_static_batch = False + self.build_dynamic_shape = False + self.build_preview_features = build_preview_features + + self.max_batch_size = max_batch_size + # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. + if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: + self.max_batch_size = 4 + + self.stream = None # loaded in loadResources() + self.models = {} # loaded in __loadModels() + self.engine = {} # loaded in build_engines() + + def __loadModels(self): + # Load pipeline models + self.embedding_dim = self.text_encoder.config.hidden_size + models_args = { + "device": self.torch_device, + "max_batch_size": self.max_batch_size, + "embedding_dim": self.embedding_dim, + "inpaint": self.inpaint, + } + if "clip" in self.stages: + self.models["clip"] = make_CLIP(self.text_encoder, **models_args) + if "unet" in self.stages: + self.models["unet"] = make_UNet(self.unet, **models_args) + if "vae" in self.stages: + self.models["vae"] = make_VAE(self.vae, **models_args) + if "vae_encoder" in self.stages: + self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args) + + @classmethod + def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + cls.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False): + super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings) + + self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) + self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) + self.timing_cache = os.path.join(self.cached_folder, self.timing_cache) + + # set device + self.torch_device = self._execution_device + logger.warning(f"Running inference on device: {self.torch_device}") + + # load models + self.__loadModels() + + # build engines + self.engine = build_engines( + self.models, + self.engine_dir, + self.onnx_dir, + self.onnx_opset, + opt_image_height=self.image_height, + opt_image_width=self.image_width, + force_engine_rebuild=self.force_engine_rebuild, + static_batch=self.build_static_batch, + static_shape=not self.build_dynamic_shape, + enable_preview=self.build_preview_features, + timing_cache=self.timing_cache, + ) + + return self + + def __initialize_timesteps(self, timesteps, strength): + self.scheduler.set_timesteps(timesteps) + offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(timesteps * strength) + offset + init_timestep = min(init_timestep, timesteps) + t_start = max(timesteps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device) + return timesteps, t_start + + def __preprocess_images(self, batch_size, images=()): + init_images = [] + for image in images: + image = image.to(self.torch_device).float() + image = image.repeat(batch_size, 1, 1, 1) + init_images.append(image) + return tuple(init_images) + + def __encode_image(self, init_image): + init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[ + "latent" + ] + init_latents = 0.18215 * init_latents + return init_latents + + def __encode_prompt(self, prompt, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + """ + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + text_input_ids_inp = device_view(text_input_ids) + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[ + "text_embeddings" + ].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + uncond_input_ids_inp = device_view(uncond_input_ids) + uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[ + "text_embeddings" + ] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def __denoise_latent( + self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None + ): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + sample_inp = device_view(latent_model_input) + timestep_inp = device_view(timestep_float) + embeddings_inp = device_view(text_embeddings) + noise_pred = runEngine( + self.engine["unet"], + {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, + self.stream, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def __decode_latent(self, latents): + images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def __loadResources(self, image_height, image_width, batch_size): + self.stream = cuda.Stream() + + # Allocate buffers for TensorRT engine bindings + for model_name, obj in self.models.items(): + self.engine[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + + """ + self.generator = generator + self.denoising_steps = num_inference_steps + self.guidance_scale = guidance_scale + + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}") + + if negative_prompt is None: + negative_prompt = [""] * batch_size + + if negative_prompt is not None and isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + assert len(prompt) == len(negative_prompt) + + if batch_size > self.max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + # load resources + self.__loadResources(self.image_height, self.image_width, batch_size) + + with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER): + # Initialize timesteps + timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength) + latent_timestep = timesteps[:1].repeat(batch_size) + + # Pre-process input image + if isinstance(image, PIL.Image.Image): + image = preprocess_image(image) + init_image = self.__preprocess_images(batch_size, (image,))[0] + + # VAE encode init image + init_latents = self.__encode_image(init_image) + + # Add noise to latents using timesteps + noise = torch.randn( + init_latents.shape, generator=self.generator, device=self.torch_device, dtype=torch.float32 + ) + latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) + + # CLIP text encoder + text_embeddings = self.__encode_prompt(prompt, negative_prompt) + + # UNet denoiser + latents = self.__denoise_latent(latents, text_embeddings, timesteps=timesteps, step_offset=t_start) + + # VAE decode latent + images = self.__decode_latent(latents) + + images = self.numpy_to_pil(images) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None) diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py new file mode 100755 index 000000000000..44f3bf5049b8 --- /dev/null +++ b/examples/community/stable_diffusion_tensorrt_inpaint.py @@ -0,0 +1,1088 @@ +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from collections import OrderedDict +from copy import copy +from typing import List, Optional, Union + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import PIL +import tensorrt as trt +import torch +from huggingface_hub import snapshot_download +from onnx import shape_inference +from polygraphy import cuda +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.onnx.loader import fold_constants +from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from polygraphy.backend.trt import util as trt_util +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion import ( + StableDiffusionInpaintPipeline, + StableDiffusionPipelineOutput, + StableDiffusionSafetyChecker, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image +from diffusers.schedulers import DDIMScheduler +from diffusers.utils import DIFFUSERS_CACHE, logging + + +""" +Installation instructions +python3 -m pip install --upgrade transformers diffusers>=0.16.0 +python3 -m pip install --upgrade tensorrt>=8.6.1 +python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +python3 -m pip install onnxruntime +""" + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} +if np.version.full_version >= "1.24.0": + numpy_to_torch_dtype_dict[np.bool_] = torch.bool +else: + numpy_to_torch_dtype_dict[np.bool] = torch.bool + +# Map of torch dtype -> numpy dtype +torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()} + + +def device_view(t): + return cuda.DeviceView(ptr=t.data_ptr(), shape=t.shape, dtype=torch_to_numpy_dtype_dict[t.dtype]) + + +def preprocess_image(image): + """ + image: torch.Tensor + """ + w, h = image.size + w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h)) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).contiguous() + return 2.0 * image - 1.0 + + +class Engine: + def __init__(self, engine_path): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + + def __del__(self): + [buf.free() for buf in self.buffers.values() if isinstance(buf, cuda.DeviceArray)] + del self.engine + del self.context + del self.buffers + del self.tensors + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + workspace_size=0, + ): + logger.warning(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + + config_kwargs["preview_features"] = [trt.PreviewFeature.DISABLE_EXTERNAL_TACTIC_SOURCES_FOR_CORE_0805] + if enable_preview: + # Faster dynamic shapes made optional since it increases engine build time. + config_kwargs["preview_features"].append(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805) + if workspace_size > 0: + config_kwargs["memory_pool_limits"] = {trt.MemoryPoolType.WORKSPACE: workspace_size} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + engine = engine_from_network( + network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), + config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + logger.warning(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self): + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(trt_util.get_bindings_per_profile(self.engine)): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + self.buffers[binding] = cuda.DeviceView(ptr=tensor.data_ptr(), shape=shape, dtype=dtype) + + def infer(self, feed_dict, stream): + start_binding, end_binding = trt_util.get_active_profile_bindings(self.context) + # shallow copy of ordered dict + device_buffers = copy(self.buffers) + for name, buf in feed_dict.items(): + assert isinstance(buf, cuda.DeviceView) + device_buffers[name] = buf + bindings = [0] * start_binding + [buf.ptr for buf in device_buffers.values()] + noerror = self.context.execute_async_v2(bindings=bindings, stream_handle=stream.ptr) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class Optimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + if return_onnx: + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() > 2147483648: + raise TypeError("ERROR: model size exceeds supported 2GB limit") + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + +class BaseModel: + def __init__(self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77): + self.model = model + self.name = "SD Model" + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_model(self): + return self.model + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + onnx_opt_graph = opt.cleanup(return_onnx=True) + return onnx_opt_graph + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_shape else self.min_image_shape + max_image_height = image_height if static_shape else self.max_image_shape + min_image_width = image_width if static_shape else self.min_image_shape + max_image_width = image_width if static_shape else self.max_image_shape + min_latent_height = latent_height if static_shape else self.min_latent_shape + max_latent_height = latent_height if static_shape else self.max_latent_shape + min_latent_width = latent_width if static_shape else self.min_latent_shape + max_latent_width = latent_width if static_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +def getOnnxPath(model_name, onnx_dir, opt=True): + return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") + + +def getEnginePath(model_name, engine_dir): + return os.path.join(engine_dir, model_name + ".plan") + + +def build_engines( + models: dict, + engine_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_shape=True, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + max_workspace_size=0, +): + built_engines = {} + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + # Export models to ONNX + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + if force_engine_rebuild or not os.path.exists(engine_path): + logger.warning("Building Engines...") + logger.warning("Engine build can take a while to complete") + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + if force_engine_rebuild or not os.path.exists(onnx_path): + logger.warning(f"Exporting model: {onnx_path}") + model = model_obj.get_model() + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.warning(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_engine_rebuild or not os.path.exists(onnx_opt_path): + logger.warning(f"Generating optimizing model: {onnx_opt_path}") + onnx_opt_graph = model_obj.optimize(onnx.load(onnx_path)) + onnx.save(onnx_opt_graph, onnx_opt_path) + else: + logger.warning(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, model_obj in models.items(): + engine_path = getEnginePath(model_name, engine_dir) + engine = Engine(engine_path) + onnx_path = getOnnxPath(model_name, onnx_dir, opt=False) + onnx_opt_path = getOnnxPath(model_name, onnx_dir) + + if force_engine_rebuild or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_shape=static_shape, + ), + enable_preview=enable_preview, + timing_cache=timing_cache, + workspace_size=max_workspace_size, + ) + built_engines[model_name] = engine + + # Load and activate TensorRT engines + for model_name, model_obj in models.items(): + engine = built_engines[model_name] + engine.load() + engine.activate() + + return built_engines + + +def runEngine(engine, feed_dict, stream): + return engine.infer(feed_dict, stream) + + +class CLIP(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(CLIP, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "CLIP" + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings", "pooler_output"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) + + def optimize(self, onnx_graph): + opt = Optimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt_onnx_graph = opt.cleanup(return_onnx=True) + return opt_onnx_graph + + +def make_CLIP(model, device, max_batch_size, embedding_dim, inpaint=False): + return CLIP(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class UNet(BaseModel): + def __init__( + self, model, fp16=False, device="cuda", max_batch_size=16, embedding_dim=768, text_maxlen=77, unet_dim=4 + ): + super(UNet, self).__init__( + model=model, + fp16=fp16, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.name = "UNet" + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +def make_UNet(model, device, max_batch_size, embedding_dim, inpaint=False, unet_dim=4): + return UNet( + model, + fp16=True, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + unet_dim=unet_dim, + ) + + +class VAE(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(VAE, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "VAE decoder" + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) + + +def make_VAE(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAE(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.vae_encoder = model + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoder(BaseModel): + def __init__(self, model, device, max_batch_size, embedding_dim): + super(VAEEncoder, self).__init__( + model=model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim + ) + self.name = "VAE encoder" + + def get_model(self): + vae_encoder = TorchVAEEncoder(self.model) + return vae_encoder + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_shape): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device) + + +def make_VAEEncoder(model, device, max_batch_size, embedding_dim, inpaint=False): + return VAEEncoder(model, device=device, max_batch_size=max_batch_size, embedding_dim=embedding_dim) + + +class TensorRTStableDiffusionInpaintPipeline(StableDiffusionInpaintPipeline): + r""" + Pipeline for inpainting using TensorRT accelerated Stable Diffusion. + + This model inherits from [`StableDiffusionInpaintPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + stages=["clip", "unet", "vae", "vae_encoder"], + image_height: int = 512, + image_width: int = 512, + max_batch_size: int = 16, + # ONNX export parameters + onnx_opset: int = 17, + onnx_dir: str = "onnx", + # TensorRT engine build parameters + engine_dir: str = "engine", + build_preview_features: bool = True, + force_engine_rebuild: bool = False, + timing_cache: str = "timing_cache", + ): + super().__init__( + vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker + ) + + self.vae.forward = self.vae.decode + + self.stages = stages + self.image_height, self.image_width = image_height, image_width + self.inpaint = True + self.onnx_opset = onnx_opset + self.onnx_dir = onnx_dir + self.engine_dir = engine_dir + self.force_engine_rebuild = force_engine_rebuild + self.timing_cache = timing_cache + self.build_static_batch = False + self.build_dynamic_shape = False + self.build_preview_features = build_preview_features + + self.max_batch_size = max_batch_size + # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. + if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: + self.max_batch_size = 4 + + self.stream = None # loaded in loadResources() + self.models = {} # loaded in __loadModels() + self.engine = {} # loaded in build_engines() + + def __loadModels(self): + # Load pipeline models + self.embedding_dim = self.text_encoder.config.hidden_size + models_args = { + "device": self.torch_device, + "max_batch_size": self.max_batch_size, + "embedding_dim": self.embedding_dim, + "inpaint": self.inpaint, + } + if "clip" in self.stages: + self.models["clip"] = make_CLIP(self.text_encoder, **models_args) + if "unet" in self.stages: + self.models["unet"] = make_UNet(self.unet, **models_args, unet_dim=self.unet.config.in_channels) + if "vae" in self.stages: + self.models["vae"] = make_VAE(self.vae, **models_args) + if "vae_encoder" in self.stages: + self.models["vae_encoder"] = make_VAEEncoder(self.vae, **models_args) + + @classmethod + def set_cached_folder(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + cls.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) + + def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False): + super().to(torch_device, silence_dtype_warnings=silence_dtype_warnings) + + self.onnx_dir = os.path.join(self.cached_folder, self.onnx_dir) + self.engine_dir = os.path.join(self.cached_folder, self.engine_dir) + self.timing_cache = os.path.join(self.cached_folder, self.timing_cache) + + # set device + self.torch_device = self._execution_device + logger.warning(f"Running inference on device: {self.torch_device}") + + # load models + self.__loadModels() + + # build engines + self.engine = build_engines( + self.models, + self.engine_dir, + self.onnx_dir, + self.onnx_opset, + opt_image_height=self.image_height, + opt_image_width=self.image_width, + force_engine_rebuild=self.force_engine_rebuild, + static_batch=self.build_static_batch, + static_shape=not self.build_dynamic_shape, + enable_preview=self.build_preview_features, + timing_cache=self.timing_cache, + ) + + return self + + def __initialize_timesteps(self, timesteps, strength): + self.scheduler.set_timesteps(timesteps) + offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(timesteps * strength) + offset + init_timestep = min(init_timestep, timesteps) + t_start = max(timesteps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.torch_device) + return timesteps, t_start + + def __preprocess_images(self, batch_size, images=()): + init_images = [] + for image in images: + image = image.to(self.torch_device).float() + image = image.repeat(batch_size, 1, 1, 1) + init_images.append(image) + return tuple(init_images) + + def __encode_image(self, init_image): + init_latents = runEngine(self.engine["vae_encoder"], {"images": device_view(init_image)}, self.stream)[ + "latent" + ] + init_latents = 0.18215 * init_latents + return init_latents + + def __encode_prompt(self, prompt, negative_prompt): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + """ + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + text_input_ids_inp = device_view(text_input_ids) + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = runEngine(self.engine["clip"], {"input_ids": text_input_ids_inp}, self.stream)[ + "text_embeddings" + ].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + uncond_input_ids_inp = device_view(uncond_input_ids) + uncond_embeddings = runEngine(self.engine["clip"], {"input_ids": uncond_input_ids_inp}, self.stream)[ + "text_embeddings" + ] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def __denoise_latent( + self, latents, text_embeddings, timesteps=None, step_offset=0, mask=None, masked_image_latents=None + ): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + sample_inp = device_view(latent_model_input) + timestep_inp = device_view(timestep_float) + embeddings_inp = device_view(text_embeddings) + noise_pred = runEngine( + self.engine["unet"], + {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp}, + self.stream, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def __decode_latent(self, latents): + images = runEngine(self.engine["vae"], {"latent": device_view(latents)}, self.stream)["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def __loadResources(self, image_height, image_width, batch_size): + self.stream = cuda.Stream() + + # Allocate buffers for TensorRT engine bindings + for model_name, obj in self.models.items(): + self.engine[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + ) + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + strength: float = 0.75, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + + """ + self.generator = generator + self.denoising_steps = num_inference_steps + self.guidance_scale = guidance_scale + + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) + + # Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + prompt = [prompt] + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"Expected prompt to be of type list or str but got {type(prompt)}") + + if negative_prompt is None: + negative_prompt = [""] * batch_size + + if negative_prompt is not None and isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + assert len(prompt) == len(negative_prompt) + + if batch_size > self.max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + # Validate image dimensions + mask_width, mask_height = mask_image.size + if mask_height != self.image_height or mask_width != self.image_width: + raise ValueError( + f"Input image height and width {self.image_height} and {self.image_width} are not equal to " + f"the respective dimensions of the mask image {mask_height} and {mask_width}" + ) + + # load resources + self.__loadResources(self.image_height, self.image_width, batch_size) + + with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER): + # Spatial dimensions of latent tensor + latent_height = self.image_height // 8 + latent_width = self.image_width // 8 + + # Pre-initialize latents + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents( + batch_size, + num_channels_latents, + self.image_height, + self.image_width, + torch.float32, + self.torch_device, + generator, + ) + + # Pre-process input images + mask, masked_image = self.__preprocess_images(batch_size, prepare_mask_and_masked_image(image, mask_image)) + # print(mask) + mask = torch.nn.functional.interpolate(mask, size=(latent_height, latent_width)) + mask = torch.cat([mask] * 2) + + # Initialize timesteps + timesteps, t_start = self.__initialize_timesteps(self.denoising_steps, strength) + + # VAE encode masked image + masked_latents = self.__encode_image(masked_image) + masked_latents = torch.cat([masked_latents] * 2) + + # CLIP text encoder + text_embeddings = self.__encode_prompt(prompt, negative_prompt) + + # UNet denoiser + latents = self.__denoise_latent( + latents, + text_embeddings, + timesteps=timesteps, + step_offset=t_start, + mask=mask, + masked_image_latents=masked_latents, + ) + + # VAE decode latent + images = self.__decode_latent(latents) + + images = self.numpy_to_pil(images) + return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=None) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py old mode 100644 new mode 100755 index aa7b5c12313b..b51f3176b958 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -54,8 +54,9 @@ """ Installation instructions -python3 -m pip install --upgrade tensorrt -python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +python3 -m pip install --upgrade transformers diffusers>=0.16.0 +python3 -m pip install --upgrade tensorrt>=8.6.1 +python3 -m pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com python3 -m pip install onnxruntime """ @@ -132,7 +133,7 @@ def build( config_kwargs["tactic_sources"] = [] engine = engine_from_network( - network_from_onnx_path(onnx_path), + network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]), config=CreateConfig(fp16=fp16, profiles=[p], load_timing_cache=timing_cache, **config_kwargs), save_timing_cache=timing_cache, ) @@ -633,6 +634,7 @@ def __init__( onnx_dir: str = "onnx", # TensorRT engine build parameters engine_dir: str = "engine", + build_preview_features: bool = True, force_engine_rebuild: bool = False, timing_cache: str = "timing_cache", ): @@ -652,7 +654,7 @@ def __init__( self.timing_cache = timing_cache self.build_static_batch = False self.build_dynamic_shape = False - self.build_preview_features = False + self.build_preview_features = build_preview_features self.max_batch_size = max_batch_size # TODO: Restrict batch size to 4 for larger image dimensions as a WAR for TensorRT limitation. diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 9b9ba5ab737f..d54b190bc7bd 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -55,11 +55,22 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__) +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step): logger.info("Running validation... ") @@ -156,6 +167,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler else: logger.warn(f"image logging not implemented for {tracker.name}") + return image_logs + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( @@ -177,6 +190,43 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") parser.add_argument( @@ -486,7 +536,6 @@ def parse_args(input_args=None): "--tracker_project_name", type=str, default="train_controlnet", - required=True, help=( "The `project_name` argument passed to Accelerator.init_trackers for" " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" @@ -930,7 +979,7 @@ def load_model_hook(models, input_dir): accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) - initial_global_step = global_step * args.gradient_accumulation_steps + initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 @@ -943,6 +992,7 @@ def load_model_hook(models, input_dir): disable=not accelerator.is_local_main_process, ) + image_logs = None for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(controlnet): @@ -1014,7 +1064,7 @@ def load_model_hook(models, input_dir): logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - log_validation( + image_logs = log_validation( vae, text_encoder, tokenizer, @@ -1040,6 +1090,12 @@ def load_model_hook(models, input_dir): controlnet.save_pretrained(args.output_dir) if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index aff361cb6e01..fa7e0193fc73 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -59,7 +59,7 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = logging.getLogger(__name__) diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 0954f3d6e789..0b5e23c015d5 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -56,7 +56,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 8447c7560720..5813c42cd5d3 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -43,6 +43,8 @@ from accelerate.utils import write_basic_config write_basic_config() ``` +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. + ### Dog toy example Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. @@ -80,7 +82,8 @@ accelerate launch train_dreambooth.py \ --learning_rate=5e-6 \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ - --max_train_steps=400 + --max_train_steps=400 \ + --push_to_hub ``` ### Training with prior-preservation loss @@ -109,7 +112,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -141,7 +145,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -176,7 +181,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` @@ -218,7 +224,8 @@ accelerate launch --mixed_precision="fp16" train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` ### Fine-tune text encoder with the UNet. @@ -251,7 +258,8 @@ accelerate launch train_dreambooth.py \ --lr_scheduler="constant" \ --lr_warmup_steps=0 \ --num_class_images=200 \ - --max_train_steps=800 + --max_train_steps=800 \ + --push_to_hub ``` ### Using DreamBooth for pipelines other than Stable Diffusion @@ -355,7 +363,7 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5). You can use the `Step` slider to see how the model learned the features of our subject while the model trained. -Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we +Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested to know more about how we enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918). With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth). @@ -387,6 +395,50 @@ Finally, we can run the model in inference. image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] ``` +If you are loading the LoRA parameters from the Hub and if the Hub repository has +a `base_model` tag (such as [this](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/blob/main/README.md?code=true#L4)), then +you can do: + +```py +from huggingface_hub.repocard import RepoCard + +lora_model_id = "patrickvonplaten/lora_dreambooth_dog_example" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +... +``` + +If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA +weights. For example: + +```python +from huggingface_hub.repocard import RepoCard +from diffusers import StableDiffusionPipeline +import torch + +lora_model_id = "sayakpaul/dreambooth-text-encoder-test" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +pipe = pipe.to("cuda") +pipe.load_lora_weights(lora_model_id) +image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] +``` + +Note that the use of [`LoraLoaderMixin.load_lora_weights`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights) is preferred to [`UNet2DConditionLoadersMixin.load_attn_procs`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs) for loading LoRA parameters. This is because +`LoraLoaderMixin.load_lora_weights` can handle the following situations: + +* LoRA parameters that don't have separate identifiers for the UNet and the text encoder (such as [`"patrickvonplaten/lora_dreambooth_dog_example"`](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example)). So, you can just do: + + ```py + pipe.load_lora_weights(lora_model_path) + ``` + +* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). + ## Training with Flax/JAX For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. @@ -481,3 +533,207 @@ More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_ ### Experimental results You can refer to [this blog post](https://huggingface.co/blog/dreambooth) that discusses some of DreamBooth experiments in detail. Specifically, it recommends a set of DreamBooth-specific tips and tricks that we have found to work well for a variety of subjects. + +## IF + +You can use the lora and full dreambooth scripts to train the text to image [IF model](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) and the stage II upscaler +[IF model](https://huggingface.co/DeepFloyd/IF-II-L-v1.0). + +Note that IF has a predicted variance, and our finetuning scripts only train the models predicted error, so for finetuned IF models we switch to a fixed +variance schedule. The full finetuning scripts will update the scheduler config for the full saved model. However, when loading saved LoRA weights, you +must also update the pipeline's scheduler config. + +```py +from diffusers import DiffusionPipeline + +pipe = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0") + +pipe.load_lora_weights("") + +# Update scheduler config to fixed variance schedule +pipe.scheduler = pipe.scheduler.__class__.from_config(pipe.scheduler.config, variance_type="fixed_small") +``` + +Additionally, a few alternative cli flags are needed for IF. + +`--resolution=64`: IF is a pixel space diffusion model. In order to operate on un-compressed pixels, the input images are of a much smaller resolution. + +`--pre_compute_text_embeddings`: IF uses [T5](https://huggingface.co/docs/transformers/model_doc/t5) for its text encoder. In order to save GPU memory, we pre compute all text embeddings and then de-allocate +T5. + +`--tokenizer_max_length=77`: T5 has a longer default text length, but the default IF encoding procedure uses a smaller number. + +`--text_encoder_use_attention_mask`: T5 passes the attention mask to the text encoder. + +### Tips and Tricks +We find LoRA to be sufficient for finetuning the stage I model as the low resolution of the model makes representing finegrained detail hard regardless. + +For common and/or not-visually complex object concepts, you can get away with not-finetuning the upscaler. Just be sure to adjust the prompt passed to the +upscaler to remove the new token from the instance prompt. I.e. if your stage I prompt is "a sks dog", use "a dog" for your stage II prompt. + +For finegrained detail like faces that aren't present in the original training set, we find that full finetuning of the stage II upscaler is better than +LoRA finetuning stage II. + +For finegrained detail like faces, we find that lower learning rates along with larger batch sizes work best. + +For stage II, we find that lower learning rates are also needed. + +We found experimentally that the DDPM scheduler with the default larger number of denoising steps to sometimes work better than the DPM Solver scheduler +used in the training scripts. + +### Stage II additional validation images + +The stage II validation requires images to upscale, we can download a downsized version of the training set: + +```py +from huggingface_hub import snapshot_download + +local_dir = "./dog_downsized" +snapshot_download( + "diffusers/dog-example-downsized", + local_dir=local_dir, + repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +### IF stage I LoRA Dreambooth +This training configuration requires ~28 GB VRAM. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_lora" + +accelerate launch train_dreambooth_lora.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=64 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --scale_lr \ + --max_train_steps=1200 \ + --validation_prompt="a sks dog" \ + --validation_epochs=25 \ + --checkpointing_steps=100 \ + --pre_compute_text_embeddings \ + --tokenizer_max_length=77 \ + --text_encoder_use_attention_mask +``` + +### IF stage II LoRA Dreambooth + +`--validation_images`: These images are upscaled during validation steps. + +`--class_labels_conditioning=timesteps`: Pass additional conditioning to the UNet needed for stage II. + +`--learning_rate=1e-6`: Lower learning rate than stage I. + +`--resolution=256`: The upscaler expects higher resolution inputs + +```sh +export MODEL_NAME="DeepFloyd/IF-II-L-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_upscale" +export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png" + +python train_dreambooth_lora.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=256 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-6 \ + --max_train_steps=2000 \ + --validation_prompt="a sks dog" \ + --validation_epochs=100 \ + --checkpointing_steps=500 \ + --pre_compute_text_embeddings \ + --tokenizer_max_length=77 \ + --text_encoder_use_attention_mask \ + --validation_images $VALIDATION_IMAGES \ + --class_labels_conditioning=timesteps +``` + +### IF Stage I Full Dreambooth +`--skip_save_text_encoder`: When training the full model, this will skip saving the entire T5 with the finetuned model. You can still load the pipeline +with a T5 loaded from the original model. + +`use_8bit_adam`: Due to the size of the optimizer states, we recommend training the full XL IF model with 8bit adam. + +`--learning_rate=1e-7`: For full dreambooth, IF requires very low learning rates. With higher learning rates model quality will degrade. Note that it is +likely the learning rate can be increased with larger batch sizes. + +Using 8bit adam and a batch size of 4, the model can be trained in ~48 GB VRAM. + +```sh +export MODEL_NAME="DeepFloyd/IF-I-XL-v1.0" + +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_if" + +accelerate launch train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a photo of sks dog" \ + --resolution=64 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-7 \ + --max_train_steps=150 \ + --validation_prompt "a photo of sks dog" \ + --validation_steps 25 \ + --text_encoder_use_attention_mask \ + --tokenizer_max_length 77 \ + --pre_compute_text_embeddings \ + --use_8bit_adam \ + --set_grads_to_none \ + --skip_save_text_encoder \ + --push_to_hub +``` + +### IF Stage II Full Dreambooth + +`--learning_rate=5e-6`: With a smaller effective batch size of 4, we found that we required learning rates as low as +1e-8. + +`--resolution=256`: The upscaler expects higher resolution inputs + +`--train_batch_size=2` and `--gradient_accumulation_steps=6`: We found that full training of stage II particularly with +faces required large effective batch sizes. + +```sh +export MODEL_NAME="DeepFloyd/IF-II-L-v1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="dreambooth_dog_upscale" +export VALIDATION_IMAGES="dog_downsized/image_1.png dog_downsized/image_2.png dog_downsized/image_3.png dog_downsized/image_4.png" + +accelerate launch train_dreambooth.py \ + --report_to wandb \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt="a sks dog" \ + --resolution=256 \ + --train_batch_size=2 \ + --gradient_accumulation_steps=6 \ + --learning_rate=5e-6 \ + --max_train_steps=2000 \ + --validation_prompt="a sks dog" \ + --validation_steps=150 \ + --checkpointing_steps=500 \ + --pre_compute_text_embeddings \ + --tokenizer_max_length=77 \ + --text_encoder_use_attention_mask \ + --validation_images $VALIDATION_IMAGES \ + --class_labels_conditioning timesteps \ + --push_to_hub +``` diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index a9449002ca80..01344662afa7 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import gc import hashlib import itertools import logging @@ -22,7 +23,6 @@ import warnings from pathlib import Path -import accelerate import numpy as np import torch import torch.nn.functional as F @@ -31,9 +31,10 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from huggingface_hub import create_repo, upload_folder +from huggingface_hub import create_repo, model_info, upload_folder from packaging import version from PIL import Image +from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm @@ -45,6 +46,7 @@ DDPMScheduler, DiffusionPipeline, DPMSolverMultistepScheduler, + StableDiffusionPipeline, UNet2DConditionModel, ) from diffusers.optimization import get_scheduler @@ -56,37 +58,115 @@ import wandb # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__) -def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +instance_prompt: {prompt} +tags: +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} +- text-to-image +- diffusers +- dreambooth +inference: true +--- + """ + model_card = f""" +# DreamBooth - {repo_id} + +This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). +You can find some example images in the following. \n +{img_str} + +DreamBooth for the text encoder was enabled: {train_text_encoder}. +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation( + text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds +): logger.info( f"Running validation... \n Generating {args.num_validation_images} images with prompt:" f" {args.validation_prompt}." ) + + pipeline_args = {} + + if vae is not None: + pipeline_args["vae"] = vae + + if text_encoder is not None: + text_encoder = accelerator.unwrap_model(text_encoder) + # create pipeline (note: unet and vae are loaded again in float32) pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, + text_encoder=text_encoder, unet=accelerator.unwrap_model(unet), - vae=vae, revision=args.revision, torch_dtype=weight_dtype, + **pipeline_args, ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": prompt_embeds, + "negative_prompt_embeds": negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + # run inference generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - images.append(image) + if args.validation_images is None: + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) + else: + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -104,6 +184,8 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight del pipeline torch.cuda.empty_cache() + return images + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( @@ -121,6 +203,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") @@ -425,6 +511,40 @@ def parse_args(input_args=None): " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." ), ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -447,6 +567,9 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + return args @@ -466,10 +589,16 @@ def __init__( class_num=None, size=512, center_crop=False, + encoder_hidden_states=None, + instance_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, ): self.size = size self.center_crop = center_crop self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): @@ -508,43 +637,59 @@ def __len__(self): def __getitem__(self, index): example = {} instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - example["instance_prompt_ids"] = self.tokenizer( - self.instance_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.instance_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask return example def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] + pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() @@ -554,6 +699,11 @@ def collate_fn(examples, with_prior_preservation=False): "input_ids": input_ids, "pixel_values": pixel_values, } + + if has_attention_mask: + attention_mask = torch.cat(attention_mask, dim=0) + batch["attention_mask"] = attention_mask + return batch @@ -574,6 +724,50 @@ def __getitem__(self, index): return example +def model_has_vae(args): + config_file_name = os.path.join("vae", AutoencoderKL.config_name) + if os.path.isdir(args.pretrained_model_name_or_path): + config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) + return os.path.isfile(config_file_name) + else: + files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings + return any(file.rfilename == config_file_name for file in files_in_repo) + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -693,43 +887,50 @@ def main(args): text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + + if model_has_vae(args): + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + else: + vae = None + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - for model in models: - sub_dir = "unet" if type(model) == type(unet) else "text_encoder" - model.save_pretrained(os.path.join(output_dir, sub_dir)) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - while len(models) > 0: - # pop models so that they are not loaded again - model = models.pop() - - if type(model) == type(text_encoder): - # load transformers style into model - load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") - model.config = load_model.config - else: - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - model.register_to_config(**load_model.config) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for model in models: + sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(text_encoder))): + # load transformers style into model + load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") + model.config = load_model.config + else: + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model - model.load_state_dict(load_model.state_dict()) - del load_model + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) + if vae is not None: + vae.requires_grad_(False) - vae.requires_grad_(False) if not args.train_text_encoder: text_encoder.requires_grad_(False) @@ -803,6 +1004,44 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.instance_prompt is not None: + pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + else: + pre_computed_instance_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_instance_prompt_encoder_hidden_states = None + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -813,6 +1052,9 @@ def load_model_hook(models, input_dir): tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, ) train_dataloader = torch.utils.data.DataLoader( @@ -858,8 +1100,10 @@ def load_model_hook(models, input_dir): weight_dtype = torch.bfloat16 # Move vae and text_encoder to device and cast to weight_dtype - vae.to(accelerator.device, dtype=weight_dtype) - if not args.train_text_encoder: + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) + + if not args.train_text_encoder and text_encoder is not None: text_encoder.to(accelerator.device, dtype=weight_dtype) # We need to recalculate our total training steps as the size of the training dataloader may have changed. @@ -929,37 +1173,65 @@ def load_model_hook(models, input_dir): continue with accelerator.accumulate(unet): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * vae.config.scaling_factor + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values - # Sample noise that we'll add to the latents + # Sample noise that we'll add to the model input if args.offset_noise: - noise = torch.randn_like(latents) + 0.1 * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, device=latents.device + noise = torch.randn_like(model_input) + 0.1 * torch.randn( + model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device ) else: - noise = torch.randn_like(latents) - bsz = latents.shape[0] + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep + # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample + + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -997,13 +1269,25 @@ def load_model_hook(models, input_dir): global_step += 1 if accelerator.is_main_process: + images = [] if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) + images = log_validation( + text_encoder, + tokenizer, + unet, + vae, + args, + accelerator, + weight_dtype, + epoch, + validation_prompt_encoder_hidden_states, + validation_prompt_negative_prompt_embeds, + ) logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -1015,15 +1299,46 @@ def load_model_hook(models, input_dir): # Create the pipeline using using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: + pipeline_args = {} + + if text_encoder is not None: + pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) + + if args.skip_save_text_encoder: + pipeline_args["text_encoder"] = None + pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), revision=args.revision, + **pipeline_args, ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.save_pretrained(args.output_dir) if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + prompt=args.instance_prompt, + repo_folder=args.output_dir, + pipeline=pipeline, + ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 1a4ca9153c80..0a664bafa573 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -36,7 +36,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") # Cache compiled models across invocations of this script. cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache")) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 805a8d1eea4d..caebfdac80bc 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import gc import hashlib import itertools import logging @@ -33,6 +34,7 @@ from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image +from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm @@ -48,19 +50,34 @@ UNet2DConditionModel, ) from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + SlicedAttnAddedKVProcessor, +) from diffusers.optimization import get_scheduler -from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__) -def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None): +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + prompt=str, + repo_folder=None, + pipeline: DiffusionPipeline = None, +): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) @@ -72,8 +89,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, train_text_encode base_model: {base_model} instance_prompt: {prompt} tags: -- stable-diffusion -- stable-diffusion-diffusers +- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} +- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} - text-to-image - diffusers - lora @@ -108,6 +125,10 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel else: raise ValueError(f"{model_class} is not supported.") @@ -387,6 +408,37 @@ def parse_args(input_args=None): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--pre_compute_text_embeddings", + action="store_true", + help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", + ) + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -409,6 +461,9 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if args.train_text_encoder and args.pre_compute_text_embeddings: + raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") + return args @@ -428,10 +483,16 @@ def __init__( class_num=None, size=512, center_crop=False, + encoder_hidden_states=None, + instance_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, ): self.size = size self.center_crop = center_crop self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): @@ -470,42 +531,57 @@ def __len__(self): def __getitem__(self, index): example = {} instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) - example["instance_prompt_ids"] = self.tokenizer( - self.instance_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) - example["class_prompt_ids"] = self.tokenizer( - self.class_prompt, - truncation=True, - padding="max_length", - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids + + if self.instance_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask return example def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: input_ids += [example["class_prompt_ids"] for example in examples] pixel_values += [example["class_images"] for example in examples] + if has_attention_mask: + attention_mask += [example["class_attention_mask"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() @@ -516,6 +592,10 @@ def collate_fn(examples, with_prior_preservation=False): "input_ids": input_ids, "pixel_values": pixel_values, } + + if has_attention_mask: + batch["attention_mask"] = attention_mask + return batch @@ -536,6 +616,40 @@ def __getitem__(self, index): return example +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -656,13 +770,22 @@ def main(args): text_encoder = text_encoder_cls.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + try: + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + except OSError: + # IF does not have a VAE so let's just set it to None + # We don't have to error out here + vae = None + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) # We only train the additional adapter LoRA layers - vae.requires_grad_(False) + if vae is not None: + vae.requires_grad_(False) text_encoder.requires_grad_(False) unet.requires_grad_(False) @@ -676,7 +799,8 @@ def main(args): # Move unet, vae and text_encoder to device and cast to weight_dtype unet.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=weight_dtype) + if vae is not None: + vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) if args.enable_xformers_memory_efficient_attention: @@ -707,7 +831,7 @@ def main(args): # Set correct lora layers unet_lora_attn_procs = {} - for name in unet.attn_processors.keys(): + for name, attn_processor in unet.attn_processors.items(): cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim if name.startswith("mid_block"): hidden_size = unet.config.block_out_channels[-1] @@ -718,13 +842,18 @@ def main(args): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - unet_lora_attn_procs[name] = LoRAAttnProcessor( + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + unet_lora_attn_procs[name] = lora_attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim ) unet.set_attn_processor(unet_lora_attn_procs) unet_lora_layers = AttnProcsLayers(unet.attn_processors) - accelerator.register_for_checkpointing(unet_lora_layers) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, @@ -733,24 +862,78 @@ def main(args): if args.train_text_encoder: text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_features, cross_attention_dim=None + hidden_size=module.out_proj.out_features, cross_attention_dim=None ) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) - temp_pipeline = StableDiffusionPipeline.from_pretrained( + temp_pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, text_encoder=text_encoder ) temp_pipeline._modify_text_encoder(text_lora_attn_procs) text_encoder = temp_pipeline.text_encoder - accelerator.register_for_checkpointing(text_encoder_lora_layers) del temp_pipeline - if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + if args.train_text_encoder: + text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() + unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + + for model in models: + state_dict = model.state_dict() + + if ( + text_encoder_lora_layers is not None + and text_encoder_keys is not None + and state_dict.keys() == text_encoder_keys + ): + # text encoder + text_encoder_lora_layers_to_save = state_dict + elif state_dict.keys() == unet_keys: + # unet + unet_lora_layers_to_save = state_dict + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, ) + def load_model_hook(models, input_dir): + # Note we DON'T pass the unet and text encoder here an purpose + # so that the we don't accidentally override the LoRA layers of + # unet_lora_layers and text_encoder_lora_layers which are stored in `models` + # with new torch.nn.Modules / weights. We simply use the pipeline class as + # an easy way to load the lora checkpoints + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + torch_dtype=weight_dtype, + ) + temp_pipeline.load_lora_weights(input_dir) + + # load lora weights into models + models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) + if len(models) > 1: + models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + + # delete temporary pipeline and pop models + del temp_pipeline + for _ in range(len(models)): + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -788,6 +971,44 @@ def main(args): eps=args.adam_epsilon, ) + if args.pre_compute_text_embeddings: + + def compute_text_embeddings(prompt): + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) + prompt_embeds = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + return prompt_embeds + + pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + validation_prompt_negative_prompt_embeds = compute_text_embeddings("") + + if args.validation_prompt is not None: + validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) + else: + validation_prompt_encoder_hidden_states = None + + if args.instance_prompt is not None: + pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) + else: + pre_computed_instance_prompt_encoder_hidden_states = None + + text_encoder = None + tokenizer = None + + gc.collect() + torch.cuda.empty_cache() + else: + pre_computed_encoder_hidden_states = None + validation_prompt_encoder_hidden_states = None + validation_prompt_negative_prompt_embeds = None + pre_computed_instance_prompt_encoder_hidden_states = None + # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -798,6 +1019,9 @@ def main(args): tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, ) train_dataloader = torch.utils.data.DataLoader( @@ -901,32 +1125,63 @@ def main(args): continue with accelerator.accumulate(unet): - # Convert images to latent space - latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() - latents = latents * vae.config.scaling_factor + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + if vae is not None: + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + else: + model_input = pixel_values # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape # Sample a random timestep for each image - timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep + # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + if args.pre_compute_text_embeddings: + encoder_hidden_states = batch["input_ids"] + else: + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if accelerator.unwrap_model(unet).config.in_channels == channels * 2: + noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + model_pred = unet( + noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels + ).sample + + # if model predicts variance, throw away the prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. + if model_pred.shape[1] == 6: + model_pred, _ = torch.chunk(model_pred, 2, dim=1) # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) + target = noise_scheduler.get_velocity(model_input, noise, timesteps) else: raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") @@ -963,17 +1218,10 @@ def main(args): progress_bar.update(1) global_step += 1 - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - # We combine the text encoder and UNet LoRA parameters with a simple - # custom logic. `accelerator.save_state()` won't know that. So, - # use `LoraLoaderMixin.save_lora_weights()`. - LoraLoaderMixin.save_lora_weights( - save_directory=save_path, - unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers, - ) + accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -993,20 +1241,50 @@ def main(args): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), + text_encoder=None if args.pre_compute_text_embeddings else accelerator.unwrap_model(text_encoder), revision=args.revision, torch_dtype=weight_dtype, ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + if args.pre_compute_text_embeddings: + pipeline_args = { + "prompt_embeds": validation_prompt_encoder_hidden_states, + "negative_prompt_embeds": validation_prompt_negative_prompt_embeds, + } + else: + pipeline_args = {"prompt": args.validation_prompt} + + if args.validation_images is None: + images = [ + pipeline(**pipeline_args, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + else: + images = [] + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": @@ -1029,7 +1307,12 @@ def main(args): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) - text_encoder = text_encoder.to(torch.float32) + unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + + if text_encoder is not None: + text_encoder = text_encoder.to(torch.float32) + text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, @@ -1041,13 +1324,27 @@ def main(args): pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline = pipeline.to(accelerator.device) # load attention processors - pipeline.load_attn_procs(args.output_dir) + pipeline.load_lora_weights(args.output_dir) # run inference + images = [] if args.validation_prompt and args.num_validation_images > 0: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [ @@ -1077,6 +1374,7 @@ def main(args): train_text_encoder=args.train_text_encoder, prompt=args.instance_prompt, repo_folder=args.output_dir, + pipeline=pipeline, ) upload_folder( repo_id=repo_id, diff --git a/examples/instruct_pix2pix/README.md b/examples/instruct_pix2pix/README.md index 94a7bd2a98f6..355d48193634 100644 --- a/examples/instruct_pix2pix/README.md +++ b/examples/instruct_pix2pix/README.md @@ -185,3 +185,5 @@ speed and quality during performance: Particularly, `image_guidance_scale` and `guidance_scale` can have a profound impact on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example). + +If you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd). \ No newline at end of file diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py index dc5a1c3081c0..0fbe45c80bfa 100644 --- a/examples/instruct_pix2pix/train_instruct_pix2pix.py +++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py @@ -51,7 +51,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__, log_level="INFO") diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 61312fb3a4b3..a5bfbbb7b12a 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -20,6 +20,7 @@ import random from pathlib import Path +import accelerate import datasets import numpy as np import torch @@ -28,30 +29,96 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder +from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer from onnxruntime.training.ortmodule import ORTModule +from packaging import version from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version +from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils.import_utils import is_xformers_available +if is_wandb_available(): + import wandb + + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.13.0.dev0") +check_min_version("0.17.0.dev0") logger = get_logger(__name__, log_level="INFO") +DATASET_NAME_MAPPING = { + "lambdalabs/pokemon-blip-captions": ("image", "text"), +} + + +def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): + logger.info("Running validation... ") + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=accelerator.unwrap_model(vae), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=accelerator.unwrap_model(unet), + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + images = [] + for i in range(len(args.validation_prompts)): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] + + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + elif tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") + for i, image in enumerate(images) + ] + } + ) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + torch.cuda.empty_cache() + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_pertubation", type=float, default=0, help="The scale of input pretubation. Recommended 0.1." + ) parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -110,6 +177,13 @@ def parse_args(): "value if set." ), ) + parser.add_argument( + "--validation_prompts", + type=str, + default=None, + nargs="+", + help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), + ) parser.add_argument( "--output_dir", type=str, @@ -191,6 +265,13 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) @@ -295,6 +376,22 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") + parser.add_argument( + "--validation_epochs", + type=int, + default=5, + help="Run validation every X epochs.", + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -312,13 +409,18 @@ def parse_args(): return args -dataset_name_mapping = { - "lambdalabs/pokemon-blip-captions": ("image", "text"), -} - - def main(): args = parse_args() + + if args.non_ema_revision is not None: + deprecate( + "non_ema_revision!=None", + "0.15.0", + message=( + "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" + " use `--variant=non_ema` instead." + ), + ) logging_dir = os.path.join(args.output_dir, args.logging_dir) accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) @@ -366,10 +468,34 @@ def main(): tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) @@ -383,17 +509,81 @@ def main(): ema_unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - ema_unet = EMAModel(ema_unet.parameters()) + ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - vae.enable_gradient_checkpointing() # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices @@ -426,6 +616,8 @@ def main(): eps=args.adam_epsilon, ) + optimizer = ORT_FP16_Optimizer(optimizer) + # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). @@ -455,7 +647,7 @@ def main(): column_names = dataset["train"].column_names # 6. Get the column names for input/target. - dataset_columns = dataset_name_mapping.get(args.dataset_name, None) + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) if args.image_column is None: image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] else: @@ -549,10 +741,10 @@ def collate_fn(examples): unet, optimizer, train_dataloader, lr_scheduler ) - unet = ORTModule(unet) - if args.use_ema: - accelerator.register_for_checkpointing(ema_unet) + ema_unet.to(accelerator.device) + + unet = ORTModule(unet) # For mixed precision training we cast the text_encoder and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. @@ -565,8 +757,6 @@ def collate_fn(examples): # Move text_encode and vae to gpu and cast to weight_dtype text_encoder.to(accelerator.device, dtype=weight_dtype) vae.to(accelerator.device, dtype=weight_dtype) - if args.use_ema: - ema_unet.to(accelerator.device) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -578,7 +768,9 @@ def collate_fn(examples): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - accelerator.init_trackers("text2image-fine-tune", config=vars(args)) + tracker_config = dict(vars(args)) + tracker_config.pop("validation_prompts") + accelerator.init_trackers(args.tracker_project_name, tracker_config) # Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -639,6 +831,13 @@ def collate_fn(examples): # Sample noise that we'll add to the latents noise = torch.randn_like(latents) + if args.noise_offset: + # https://www.crosslabs.org//blog/diffusion-with-offset-noise + noise += args.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1), device=latents.device + ) + if args.input_pertubation: + new_noise = noise + args.input_pertubation * torch.randn_like(noise) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -646,7 +845,10 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + if args.input_pertubation: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] @@ -660,8 +862,24 @@ def collate_fn(examples): raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() @@ -696,6 +914,26 @@ def collate_fn(examples): if global_step >= args.max_train_steps: break + if accelerator.is_main_process: + if args.validation_prompts is not None and epoch % args.validation_epochs == 0: + if args.use_ema: + # Store the UNet parameters temporarily and load the EMA parameters to perform inference. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + log_validation( + vae, + text_encoder, + tokenizer, + unet, + args, + accelerator, + weight_dtype, + global_step, + ) + if args.use_ema: + # Switch back to the original UNet parameters. + ema_unet.restore(unet.parameters()) + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: diff --git a/examples/research_projects/onnxruntime/textual_inversion/README.md b/examples/research_projects/onnxruntime/textual_inversion/README.md index 0ed34966e9f1..9f08983eaaad 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/README.md +++ b/examples/research_projects/onnxruntime/textual_inversion/README.md @@ -53,7 +53,19 @@ If you have already cloned the repo, then you won't need to go through these ste
-Now let's get our dataset.Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data. +Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example . + +Let's first download it locally: + +```py +from huggingface_hub import snapshot_download + +local_dir = "./cat" +snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes") +``` + +This will be our training data. +Now we can launch the training using ## Use ONNXRuntime to accelerate training In order to leverage onnxruntime to accelerate training, please use textual_inversion.py diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index a3d24066ad7a..7ff77118c38e 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -18,9 +18,9 @@ import math import os import random +import warnings from pathlib import Path -import datasets import numpy as np import PIL import torch @@ -31,6 +31,7 @@ from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder +from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer from onnxruntime.training.ortmodule import ORTModule # TODO: remove and import from diffusers.utils when the new version of diffusers is released @@ -55,6 +56,9 @@ from diffusers.utils.import_utils import is_xformers_available +if is_wandb_available(): + import wandb + if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): PIL_INTERPOLATION = { "linear": PIL.Image.Resampling.BILINEAR, @@ -75,14 +79,92 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.13.0.dev0") +check_min_version("0.17.0.dev0") logger = get_logger(__name__) -def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path): +def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): + img_str = "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- textual_inversion +inference: true +--- + """ + model_card = f""" +# Textual inversion text2image fine-tuning - {repo_id} +These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline (note: unet and vae are loaded again in float32) + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + unet=unet, + vae=vae, + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [] + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() + return images + + +def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path): logger.info("Saving embeddings") - learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] + learned_embeds = ( + accelerator.unwrap_model(text_encoder) + .get_input_embeddings() + .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] + ) learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} torch.save(learned_embeds_dict, save_path) @@ -96,10 +178,15 @@ def parse_args(): help="Save learned_embeds.bin every X updates steps.", ) parser.add_argument( - "--only_save_embeds", + "--save_as_full_pipeline", action="store_true", - default=False, - help="Save only the embeddings for the new concept.", + help="Save the complete stable diffusion pipeline.", + ) + parser.add_argument( + "--num_vectors", + type=int, + default=1, + help="How many textual inversion vectors shall be used to learn the concept.", ) parser.add_argument( "--pretrained_model_name_or_path", @@ -269,12 +356,22 @@ def parse_args(): default=4, help="Number of images that should be generated during validation with `validation_prompt`.", ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) parser.add_argument( "--validation_epochs", type=int, - default=50, + default=None, help=( - "Run validation every X epochs. Validation consists of running the prompt" + "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt" " `args.validation_prompt` multiple times: `args.num_validation_images`" " and logging the images." ), @@ -479,7 +576,6 @@ def main(): if args.report_to == "wandb": if not is_wandb_available(): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") - import wandb # Make one log on every process with the configuration for debugging. logging.basicConfig( @@ -489,11 +585,9 @@ def main(): ) logger.info(accelerator.state, main_process_only=False) if accelerator.is_local_main_process: - datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_warning() diffusers.utils.logging.set_verbosity_info() else: - datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() diffusers.utils.logging.set_verbosity_error() @@ -528,8 +622,19 @@ def main(): ) # Add the placeholder token in tokenizer - num_added_tokens = tokenizer.add_tokens(args.placeholder_token) - if num_added_tokens == 0: + placeholder_tokens = [args.placeholder_token] + + if args.num_vectors < 1: + raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}") + + # add dummy tokens for multi-vector + additional_tokens = [] + for i in range(1, args.num_vectors): + additional_tokens.append(f"{args.placeholder_token}_{i}") + placeholder_tokens += additional_tokens + + num_added_tokens = tokenizer.add_tokens(placeholder_tokens) + if num_added_tokens != args.num_vectors: raise ValueError( f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" " `placeholder_token` that is not already in the tokenizer." @@ -542,14 +647,16 @@ def main(): raise ValueError("The initializer token must be a single token.") initializer_token_id = token_ids[0] - placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) + placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens) # Resize the token embeddings as we are adding new special tokens to the tokenizer text_encoder.resize_token_embeddings(len(tokenizer)) # Initialise the newly added placeholder token with the embeddings of the initializer token token_embeds = text_encoder.get_input_embeddings().weight.data - token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] + with torch.no_grad(): + for token_id in placeholder_token_ids: + token_embeds[token_id] = token_embeds[initializer_token_id].clone() # Freeze vae and unet vae.requires_grad_(False) @@ -568,6 +675,13 @@ def main(): if args.enable_xformers_memory_efficient_attention: if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") @@ -591,6 +705,8 @@ def main(): eps=args.adam_epsilon, ) + optimizer = ORT_FP16_Optimizer(optimizer) + # Dataset and DataLoaders creation: train_dataset = TextualInversionDataset( data_root=args.train_data_dir, @@ -605,6 +721,15 @@ def main(): train_dataloader = torch.utils.data.DataLoader( train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) + if args.validation_epochs is not None: + warnings.warn( + f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}." + " Deprecated validation_epochs in favor of `validation_steps`" + f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}", + FutureWarning, + stacklevel=2, + ) + args.validation_steps = args.validation_epochs * len(train_dataset) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -626,6 +751,8 @@ def main(): ) text_encoder = ORTModule(text_encoder) + unet = ORTModule(unet) + vae = ORTModule(vae) # For mixed precision training we cast the unet and vae weights to half-precision # as these models are only used for inference, keeping weights in full precision is not required. @@ -663,7 +790,6 @@ def main(): logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 - # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": @@ -744,7 +870,9 @@ def main(): optimizer.zero_grad() # Let's make sure we don't update any embedding weights besides the newly added token - index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False + with torch.no_grad(): accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ index_no_updates @@ -752,72 +880,38 @@ def main(): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: + images = [] progress_bar.update(1) global_step += 1 if global_step % args.save_steps == 0: save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") - save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) + save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path) - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + images = log_validation( + text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch + ) + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: break - - if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - # create pipeline (note: unet and vae are loaded again in float32) - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - text_encoder=accelerator.unwrap_model(text_encoder), - revision=args.revision, - ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = ( - None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) - ) - prompt = args.num_validation_images * [args.validation_prompt] - images = pipeline(prompt, num_inference_steps=25, generator=generator).images - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - del pipeline - torch.cuda.empty_cache() - - # Create the pipeline using using the trained modules and save it. + # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - if args.push_to_hub and args.only_save_embeds: + if args.push_to_hub and not args.save_as_full_pipeline: logger.warn("Enabling full model saving because --push_to_hub=True was specified.") save_full_model = True else: - save_full_model = not args.only_save_embeds + save_full_model = args.save_as_full_pipeline if save_full_model: pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -829,9 +923,15 @@ def main(): pipeline.save_pretrained(args.output_dir) # Save the newly trained embeddings save_path = os.path.join(args.output_dir, "learned_embeds.bin") - save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) + save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path) if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) upload_folder( repo_id=repo_id, folder_path=args.output_dir, diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/README.md b/examples/research_projects/onnxruntime/unconditional_image_generation/README.md index 621e9a2fd69a..c28ecefc9a30 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/README.md +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/README.md @@ -34,7 +34,7 @@ In order to leverage onnxruntime to accelerate training, please use train_uncond The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxruntime: ```bash -accelerate launch train_unconditional_ort.py \ +accelerate launch train_unconditional.py \ --dataset_name="huggan/flowers-102-categories" \ --resolution=64 --center_crop --random_flip \ --output_dir="ddpm-ema-flowers-64" \ diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt b/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt index f366720afd11..ca21143c42d9 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/requirements.txt @@ -1,3 +1,4 @@ accelerate>=0.16.0 torchvision datasets +tensorboard \ No newline at end of file diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py index 1b38036d82c0..9dc46e864ae8 100644 --- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py +++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Optional +import accelerate import datasets import torch import torch.nn.functional as F @@ -14,7 +15,9 @@ from accelerate.utils import ProjectConfiguration from datasets import load_dataset from huggingface_hub import HfFolder, Repository, create_repo, whoami +from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer from onnxruntime.training.ortmodule import ORTModule +from packaging import version from torchvision import transforms from tqdm.auto import tqdm @@ -22,11 +25,12 @@ from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel -from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available +from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.13.0.dev0") +check_min_version("0.17.0.dev0") logger = get_logger(__name__, log_level="INFO") @@ -34,6 +38,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): """ Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. :param timesteps: a tensor of indices into the array to extract. :param broadcast_shape: a larger shape of K dimensions with the batch @@ -66,6 +71,12 @@ def parse_args(): default=None, help="The config of the Dataset, leave as None if there's only one config.", ) + parser.add_argument( + "--model_config_name_or_path", + type=str, + default=None, + help="The config of the UNet model to train, leave as None to use standard DDPM configuration.", + ) parser.add_argument( "--train_data_dir", type=str, @@ -251,6 +262,9 @@ def parse_args(): ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -295,6 +309,40 @@ def main(args): raise ImportError("Make sure to install wandb if you want to use it for logging during training.") import wandb + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if args.use_ema: + ema_model.save_pretrained(os.path.join(output_dir, "unet_ema")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel) + ema_model.load_state_dict(load_model.state_dict()) + ema_model.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", @@ -328,29 +376,33 @@ def main(args): os.makedirs(args.output_dir, exist_ok=True) # Initialize the model - model = UNet2DModel( - sample_size=args.resolution, - in_channels=3, - out_channels=3, - layers_per_block=2, - block_out_channels=(128, 128, 256, 256, 512, 512), - down_block_types=( - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "DownBlock2D", - "AttnDownBlock2D", - "DownBlock2D", - ), - up_block_types=( - "UpBlock2D", - "AttnUpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - "UpBlock2D", - ), - ) + if args.model_config_name_or_path is None: + model = UNet2DModel( + sample_size=args.resolution, + in_channels=3, + out_channels=3, + layers_per_block=2, + block_out_channels=(128, 128, 256, 256, 512, 512), + down_block_types=( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "AttnDownBlock2D", + "DownBlock2D", + ), + up_block_types=( + "UpBlock2D", + "AttnUpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + ) + else: + config = UNet2DModel.load_config(args.model_config_name_or_path) + model = UNet2DModel.from_config(config) # Create EMA for the model. if args.use_ema: @@ -360,8 +412,23 @@ def main(args): use_ema_warmup=True, inv_gamma=args.ema_inv_gamma, power=args.ema_power, + model_cls=UNet2DModel, + model_config=model.config, ) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + model.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + # Initialize the scheduler accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) if accepts_prediction_type: @@ -382,6 +449,8 @@ def main(args): eps=args.adam_epsilon, ) + optimizer = ORT_FP16_Optimizer(optimizer) + # Get the datasets: you can either provide your own training and evaluation files (see below) # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). @@ -434,10 +503,7 @@ def transform_images(examples): model, optimizer, train_dataloader, lr_scheduler ) - model = ORTModule(model) - if args.use_ema: - accelerator.register_for_checkpointing(ema_model) ema_model.to(accelerator.device) # We need to initialize the trackers we use, and also store our configuration. @@ -446,6 +512,8 @@ def transform_images(examples): run = os.path.split(__file__)[-1].split(".")[0] accelerator.init_trackers(run) + model = ORTModule(model) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) max_train_steps = args.num_epochs * num_update_steps_per_epoch @@ -552,7 +620,7 @@ def transform_images(examples): logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} if args.use_ema: - logs["ema_decay"] = ema_model.decay + logs["ema_decay"] = ema_model.cur_decay_value progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) progress_bar.close() @@ -563,8 +631,11 @@ def transform_images(examples): if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: unet = accelerator.unwrap_model(model) + if args.use_ema: + ema_model.store(unet.parameters()) ema_model.copy_to(unet.parameters()) + pipeline = DDPMPipeline( unet=unet, scheduler=noise_scheduler, @@ -575,18 +646,24 @@ def transform_images(examples): images = pipeline( generator=generator, batch_size=args.eval_batch_size, - output_type="numpy", num_inference_steps=args.ddpm_num_inference_steps, + output_type="numpy", ).images + if args.use_ema: + ema_model.restore(unet.parameters()) + # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") if args.logger == "tensorboard": - accelerator.get_tracker("tensorboard").add_images( - "test_samples", images_processed.transpose(0, 3, 1, 2), epoch - ) + if is_accelerate_version(">=", "0.17.0.dev0"): + tracker = accelerator.get_tracker("tensorboard", unwrap=True) + else: + tracker = accelerator.get_tracker("tensorboard") + tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch) elif args.logger == "wandb": + # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files accelerator.get_tracker("wandb").log( {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch}, step=global_step, @@ -594,7 +671,22 @@ def transform_images(examples): if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: # save the model + unet = accelerator.unwrap_model(model) + + if args.use_ema: + ema_model.store(unet.parameters()) + ema_model.copy_to(unet.parameters()) + + pipeline = DDPMPipeline( + unet=unet, + scheduler=noise_scheduler, + ) + pipeline.save_pretrained(args.output_dir) + + if args.use_ema: + ema_model.restore(unet.parameters()) + if args.push_to_hub: repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) diff --git a/examples/test_examples.py b/examples/test_examples.py index d4a5ef5046f0..59c96f44fe93 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -147,6 +147,32 @@ def test_dreambooth(self): self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_if(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth.py + --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --pre_compute_text_embeddings + --tokenizer_max_length=77 + --text_encoder_use_attention_mask + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + def test_dreambooth_checkpointing(self): instance_prompt = "photo" pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe" @@ -281,13 +307,52 @@ def test_dreambooth_lora_with_text_encoder(self): # save_pretrained smoke test self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) - # the names of the keys of the state dict should either start with `unet` - # or `text_encoder`. + # check `text_encoder` is present at all. lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) keys = lora_state_dict.keys() + is_text_encoder_present = any(k.startswith("text_encoder") for k in keys) + self.assertTrue(is_text_encoder_present) + + # the names of the keys of the state dict should either start with `unet` + # or `text_encoder`. is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys) self.assertTrue(is_correct_naming) + def test_dreambooth_lora_if_model(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora.py + --pretrained_model_name_or_path hf-internal-testing/tiny-if-pipe + --instance_data_dir docs/source/en/imgs + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --pre_compute_text_embeddings + --tokenizer_max_length=77 + --text_encoder_use_attention_mask + """.split() + + run_command(self._launch_args + test_args) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"unet"` in their names. + starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_unet) + def test_custom_diffusion(self): with tempfile.TemporaryDirectory() as tmpdir: test_args = f""" diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md index 406a64b3759f..160e73fa02bb 100644 --- a/examples/text_to_image/README.md +++ b/examples/text_to_image/README.md @@ -229,6 +229,21 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] image.save("pokemon.png") ``` +If you are loading the LoRA parameters from the Hub and if the Hub repository has +a `base_model` tag (such as [this](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/README.md?code=true#L4)), then +you can do: + +```py +from huggingface_hub.repocard import RepoCard + +lora_model_id = "sayakpaul/sd-model-finetuned-lora-t4" +card = RepoCard.load(lora_model_id) +base_model_id = card.data.to_dict()["base_model"] + +pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16) +... +``` + ## Training with Flax/JAX For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 1d6db2a6f1da..5195ef7848df 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -29,6 +29,7 @@ import transformers from accelerate import Accelerator from accelerate.logging import get_logger +from accelerate.state import AcceleratorState from accelerate.utils import ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder @@ -36,6 +37,7 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer +from transformers.utils import ContextManagers import diffusers from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel @@ -50,7 +52,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__, log_level="INFO") @@ -112,6 +114,9 @@ def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -302,6 +307,12 @@ def parse_args(): parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -461,10 +472,34 @@ def main(): tokenizer = CLIPTokenizer.from_pretrained( args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision ) - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision - ) - vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + + def deepspeed_zero_init_disabled_context_manager(): + """ + returns either a context list that includes one that will disable zero.Init or an empty context list + """ + deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None + if deepspeed_plugin is None: + return [] + + return [deepspeed_plugin.zero3_init_context_manager(enable=False)] + + # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. + # For this to work properly all models must be run through `accelerate.prepare`. But accelerate + # will try to assign the same optimizer with the same weights to all models during + # `deepspeed.initialize`, which of course doesn't work. + # + # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 + # frozen models from being partitioned during `zero.Init` which gets called during + # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding + # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. + with ContextManagers(deepspeed_zero_init_disabled_context_manager()): + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + unet = UNet2DConditionModel.from_pretrained( args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision ) @@ -801,7 +836,8 @@ def collate_fn(examples): noise += args.noise_offset * torch.randn( (latents.shape[0], latents.shape[1], 1, 1), device=latents.device ) - + if args.input_perturbation: + new_noise = noise + args.input_perturbation * torch.randn_like(noise) bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) @@ -809,12 +845,19 @@ def collate_fn(examples): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + if args.input_perturbation: + noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) + else: + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index c5dc71f0536e..2953896e5240 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -33,7 +33,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = logging.getLogger(__name__) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 39bdb4e59a52..523c6ae9fc93 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -47,7 +47,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__, log_level="INFO") @@ -239,6 +239,13 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." ) @@ -265,6 +272,12 @@ def parse_args(): parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", + ) parser.add_argument( "--hub_model_id", type=str, @@ -472,6 +485,30 @@ def main(): else: raise ValueError("xformers is not available. Make sure it is installed correctly") + def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + lora_layers = AttnProcsLayers(unet.attn_processors) # Enable TF32 for faster training on Ampere GPUs, @@ -718,6 +755,10 @@ def collate_fn(examples): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Get the target for loss depending on the prediction type + if args.prediction_type is not None: + # set prediction_type of scheduler if defined + noise_scheduler.register_to_config(prediction_type=args.prediction_type) + if noise_scheduler.config.prediction_type == "epsilon": target = noise elif noise_scheduler.config.prediction_type == "v_prediction": @@ -727,7 +768,23 @@ def collate_fn(examples): # Predict the noise residual and compute loss model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + # We first calculate the original loss. Then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() # Gather the losses across all processes for logging (if we use distributed training). avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() @@ -778,7 +835,9 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): images.append( @@ -834,7 +893,9 @@ def collate_fn(examples): pipeline.unet.load_attn_procs(args.output_dir) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 824759cc4ca9..4a193abc138f 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -77,7 +77,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__) @@ -176,10 +176,9 @@ def parse_args(): help="Save learned_embeds.bin every X updates steps.", ) parser.add_argument( - "--only_save_embeds", + "--save_as_full_pipeline", action="store_true", - default=True, - help="Save only the embeddings for the new concept.", + help="Save the complete stable diffusion pipeline.", ) parser.add_argument( "--num_vectors", @@ -286,6 +285,12 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -740,6 +745,7 @@ def main(): optimizer=optimizer, num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles * args.gradient_accumulation_steps, ) # Prepare everything with our `accelerator`. @@ -900,11 +906,11 @@ def main(): # Create the pipeline using the trained modules and save it. accelerator.wait_for_everyone() if accelerator.is_main_process: - if args.push_to_hub and args.only_save_embeds: + if args.push_to_hub and not args.save_as_full_pipeline: logger.warn("Enabling full model saving because --push_to_hub=True was specified.") save_full_model = True else: - save_full_model = not args.only_save_embeds + save_full_model = args.save_as_full_pipeline if save_full_model: pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index 19553ceb92ec..41db4ed2d004 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -56,7 +56,7 @@ # ------------------------------------------------------------------------------ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = logging.getLogger(__name__) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 836a38f96286..24a30796c785 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -28,7 +28,7 @@ # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.16.0") +check_min_version("0.17.0") logger = get_logger(__name__, log_level="INFO") diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py new file mode 100644 index 000000000000..de9879f7f03b --- /dev/null +++ b/scripts/convert_kandinsky_to_diffusers.py @@ -0,0 +1,1400 @@ +import argparse +import os +import tempfile + +import torch +from accelerate import load_checkpoint_and_dispatch + +from diffusers import UNet2DConditionModel +from diffusers.models.prior_transformer import PriorTransformer +from diffusers.models.vq_model import VQModel +from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel + + +""" +Example - From the diffusers root directory: + +Download weights: +```sh +$ wget https://huggingface.co/ai-forever/Kandinsky_2.1/blob/main/prior_fp16.ckpt +``` + +Convert the model: +```sh +python scripts/convert_kandinsky_to_diffusers.py \ + --prior_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/prior_fp16.ckpt \ + --clip_stat_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/ViT-L-14_stats.th \ + --text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/decoder_fp16.ckpt \ + --inpaint_text2img_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/inpainting_fp16.ckpt \ + --movq_checkpoint_path /home/yiyi_huggingface_co/Kandinsky-2/checkpoints_Kandinsky_2.1/movq_final.ckpt \ + --dump_path /home/yiyi_huggingface_co/dump \ + --debug decoder +``` +""" + + +# prior + +PRIOR_ORIGINAL_PREFIX = "model" + +# Uses default arguments +PRIOR_CONFIG = {} + + +def prior_model_from_original_config(): + model = PriorTransformer(**PRIOR_CONFIG) + + return model + + +def prior_original_checkpoint_to_diffusers_checkpoint(model, checkpoint, clip_stats_checkpoint): + diffusers_checkpoint = {} + + # .time_embed.0 -> .time_embedding.linear_1 + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.0.bias"], + } + ) + + # .clip_img_proj -> .proj_in + diffusers_checkpoint.update( + { + "proj_in.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.weight"], + "proj_in.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.clip_img_proj.bias"], + } + ) + + # .text_emb_proj -> .embedding_proj + diffusers_checkpoint.update( + { + "embedding_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.weight"], + "embedding_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_emb_proj.bias"], + } + ) + + # .text_enc_proj -> .encoder_hidden_states_proj + diffusers_checkpoint.update( + { + "encoder_hidden_states_proj.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.weight"], + "encoder_hidden_states_proj.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.text_enc_proj.bias"], + } + ) + + # .positional_embedding -> .positional_embedding + diffusers_checkpoint.update({"positional_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.positional_embedding"]}) + + # .prd_emb -> .prd_embedding + diffusers_checkpoint.update({"prd_embedding": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.prd_emb"]}) + + # .time_embed.2 -> .time_embedding.linear_2 + diffusers_checkpoint.update( + { + "time_embedding.linear_2.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.time_embed.2.bias"], + } + ) + + # .resblocks. -> .transformer_blocks. + for idx in range(len(model.transformer_blocks)): + diffusers_transformer_prefix = f"transformer_blocks.{idx}" + original_transformer_prefix = f"{PRIOR_ORIGINAL_PREFIX}.transformer.resblocks.{idx}" + + # .attn -> .attn1 + diffusers_attention_prefix = f"{diffusers_transformer_prefix}.attn1" + original_attention_prefix = f"{original_transformer_prefix}.attn" + diffusers_checkpoint.update( + prior_attention_to_diffusers( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + original_attention_prefix=original_attention_prefix, + attention_head_dim=model.attention_head_dim, + ) + ) + + # .mlp -> .ff + diffusers_ff_prefix = f"{diffusers_transformer_prefix}.ff" + original_ff_prefix = f"{original_transformer_prefix}.mlp" + diffusers_checkpoint.update( + prior_ff_to_diffusers( + checkpoint, diffusers_ff_prefix=diffusers_ff_prefix, original_ff_prefix=original_ff_prefix + ) + ) + + # .ln_1 -> .norm1 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm1.weight": checkpoint[ + f"{original_transformer_prefix}.ln_1.weight" + ], + f"{diffusers_transformer_prefix}.norm1.bias": checkpoint[f"{original_transformer_prefix}.ln_1.bias"], + } + ) + + # .ln_2 -> .norm3 + diffusers_checkpoint.update( + { + f"{diffusers_transformer_prefix}.norm3.weight": checkpoint[ + f"{original_transformer_prefix}.ln_2.weight" + ], + f"{diffusers_transformer_prefix}.norm3.bias": checkpoint[f"{original_transformer_prefix}.ln_2.bias"], + } + ) + + # .final_ln -> .norm_out + diffusers_checkpoint.update( + { + "norm_out.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.weight"], + "norm_out.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.final_ln.bias"], + } + ) + + # .out_proj -> .proj_to_clip_embeddings + diffusers_checkpoint.update( + { + "proj_to_clip_embeddings.weight": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.weight"], + "proj_to_clip_embeddings.bias": checkpoint[f"{PRIOR_ORIGINAL_PREFIX}.out_proj.bias"], + } + ) + + # clip stats + clip_mean, clip_std = clip_stats_checkpoint + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + diffusers_checkpoint.update({"clip_mean": clip_mean, "clip_std": clip_std}) + + return diffusers_checkpoint + + +def prior_attention_to_diffusers( + checkpoint, *, diffusers_attention_prefix, original_attention_prefix, attention_head_dim +): + diffusers_checkpoint = {} + + # .c_qkv -> .{to_q, to_k, to_v} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{original_attention_prefix}.c_qkv.weight"], + bias=checkpoint[f"{original_attention_prefix}.c_qkv.bias"], + split=3, + chunk_size=attention_head_dim, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .c_proj -> .to_out.0 + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{original_attention_prefix}.c_proj.weight"], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{original_attention_prefix}.c_proj.bias"], + } + ) + + return diffusers_checkpoint + + +def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix): + diffusers_checkpoint = { + # .c_fc -> .net.0.proj + f"{diffusers_ff_prefix}.net.{0}.proj.weight": checkpoint[f"{original_ff_prefix}.c_fc.weight"], + f"{diffusers_ff_prefix}.net.{0}.proj.bias": checkpoint[f"{original_ff_prefix}.c_fc.bias"], + # .c_proj -> .net.2 + f"{diffusers_ff_prefix}.net.{2}.weight": checkpoint[f"{original_ff_prefix}.c_proj.weight"], + f"{diffusers_ff_prefix}.net.{2}.bias": checkpoint[f"{original_ff_prefix}.c_proj.bias"], + } + + return diffusers_checkpoint + + +# done prior + +# unet + +# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can +# update then. + +UNET_CONFIG = { + "act_fn": "silu", + "attention_head_dim": 64, + "block_out_channels": (384, 768, 1152, 1536), + "center_input_sample": False, + "class_embed_type": "identity", + "cross_attention_dim": 768, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + ), + "downsample_padding": 1, + "dual_cross_attention": False, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 3, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "only_cross_attention": False, + "out_channels": 8, + "resnet_time_scale_shift": "scale_shift", + "sample_size": 64, + "up_block_types": ( + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "ResnetUpsampleBlock2D", + ), + "upcast_attention": False, + "use_linear_projection": False, +} + + +def unet_model_from_original_config(): + model = UNet2DConditionModel(**UNET_CONFIG) + + return model + + +def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + num_head_channels = UNET_CONFIG["attention_head_dim"] + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint)) + diffusers_checkpoint.update(unet_conv_in(checkpoint)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + num_head_channels=num_head_channels, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .input_blocks -> .down_blocks + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + num_head_channels=num_head_channels, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + num_head_channels=num_head_channels, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint)) + diffusers_checkpoint.update(unet_conv_out(checkpoint)) + + return diffusers_checkpoint + + +# done unet + +# inpaint unet + +# We are hardcoding the model configuration for now. If we need to generalize to more model configurations, we can +# update then. + +INPAINT_UNET_CONFIG = { + "act_fn": "silu", + "attention_head_dim": 64, + "block_out_channels": (384, 768, 1152, 1536), + "center_input_sample": False, + "class_embed_type": "identity", + "cross_attention_dim": 768, + "down_block_types": ( + "ResnetDownsampleBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + "SimpleCrossAttnDownBlock2D", + ), + "downsample_padding": 1, + "dual_cross_attention": False, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 9, + "layers_per_block": 3, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "only_cross_attention": False, + "out_channels": 8, + "resnet_time_scale_shift": "scale_shift", + "sample_size": 64, + "up_block_types": ( + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "SimpleCrossAttnUpBlock2D", + "ResnetUpsampleBlock2D", + ), + "upcast_attention": False, + "use_linear_projection": False, +} + + +def inpaint_unet_model_from_original_config(): + model = UNet2DConditionModel(**INPAINT_UNET_CONFIG) + + return model + + +def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + num_head_channels = UNET_CONFIG["attention_head_dim"] + + diffusers_checkpoint.update(unet_time_embeddings(checkpoint)) + diffusers_checkpoint.update(unet_conv_in(checkpoint)) + + # .input_blocks -> .down_blocks + + original_down_block_idx = 1 + + for diffusers_down_block_idx in range(len(model.down_blocks)): + checkpoint_update, num_original_down_blocks = unet_downblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_down_block_idx=diffusers_down_block_idx, + original_down_block_idx=original_down_block_idx, + num_head_channels=num_head_channels, + ) + + original_down_block_idx += num_original_down_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .input_blocks -> .down_blocks + + diffusers_checkpoint.update( + unet_midblock_to_diffusers_checkpoint( + model, + checkpoint, + num_head_channels=num_head_channels, + ) + ) + + # .output_blocks -> .up_blocks + + original_up_block_idx = 0 + + for diffusers_up_block_idx in range(len(model.up_blocks)): + checkpoint_update, num_original_up_blocks = unet_upblock_to_diffusers_checkpoint( + model, + checkpoint, + diffusers_up_block_idx=diffusers_up_block_idx, + original_up_block_idx=original_up_block_idx, + num_head_channels=num_head_channels, + ) + + original_up_block_idx += num_original_up_blocks + + diffusers_checkpoint.update(checkpoint_update) + + # done .output_blocks -> .up_blocks + + diffusers_checkpoint.update(unet_conv_norm_out(checkpoint)) + diffusers_checkpoint.update(unet_conv_out(checkpoint)) + + return diffusers_checkpoint + + +# done inpaint unet + +# text proj + +TEXT_PROJ_CONFIG = {} + + +def text_proj_from_original_config(): + model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG) + return model + + +# Note that the input checkpoint is the original text2img model checkpoint +def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint): + diffusers_checkpoint = { + # .text_seq_proj.0 -> .encoder_hidden_states_proj + "encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"], + "encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"], + # .clip_tok_proj -> .clip_extra_context_tokens_proj + "clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"], + "clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"], + # .proj_n -> .embedding_proj + "embedding_proj.weight": checkpoint["proj_n.weight"], + "embedding_proj.bias": checkpoint["proj_n.bias"], + # .ln_model_n -> .embedding_norm + "embedding_norm.weight": checkpoint["ln_model_n.weight"], + "embedding_norm.bias": checkpoint["ln_model_n.bias"], + # .clip_emb -> .clip_image_embeddings_project_to_time_embeddings + "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"], + "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"], + } + + return diffusers_checkpoint + + +# unet utils + + +# .time_embed -> .time_embedding +def unet_time_embeddings(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "time_embedding.linear_1.weight": checkpoint["time_embed.0.weight"], + "time_embedding.linear_1.bias": checkpoint["time_embed.0.bias"], + "time_embedding.linear_2.weight": checkpoint["time_embed.2.weight"], + "time_embedding.linear_2.bias": checkpoint["time_embed.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks.0 -> .conv_in +def unet_conv_in(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_in.weight": checkpoint["input_blocks.0.0.weight"], + "conv_in.bias": checkpoint["input_blocks.0.0.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.0 -> .conv_norm_out +def unet_conv_norm_out(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_norm_out.weight": checkpoint["out.0.weight"], + "conv_norm_out.bias": checkpoint["out.0.bias"], + } + ) + + return diffusers_checkpoint + + +# .out.2 -> .conv_out +def unet_conv_out(checkpoint): + diffusers_checkpoint = {} + + diffusers_checkpoint.update( + { + "conv_out.weight": checkpoint["out.2.weight"], + "conv_out.bias": checkpoint["out.2.bias"], + } + ) + + return diffusers_checkpoint + + +# .input_blocks -> .down_blocks +def unet_downblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_down_block_idx, original_down_block_idx, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.resnets" + original_down_block_prefix = "input_blocks" + + down_block = model.down_blocks[diffusers_down_block_idx] + + num_resnets = len(down_block.resnets) + + if down_block.downsamplers is None: + downsampler = False + else: + assert len(down_block.downsamplers) == 1 + downsampler = True + # The downsample block is also a resnet + num_resnets += 1 + + for resnet_idx_inc in range(num_resnets): + full_resnet_prefix = f"{original_down_block_prefix}.{original_down_block_idx + resnet_idx_inc}.0" + + if downsampler and resnet_idx_inc == num_resnets - 1: + # this is a downsample block + full_diffusers_resnet_prefix = f"down_blocks.{diffusers_down_block_idx}.downsamplers.0" + else: + # this is a regular resnet block + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if hasattr(down_block, "attentions"): + num_attentions = len(down_block.attentions) + diffusers_attention_prefix = f"down_blocks.{diffusers_down_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_down_block_prefix}.{original_down_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +# .middle_block -> .mid_block +def unet_midblock_to_diffusers_checkpoint(model, checkpoint, *, num_head_channels): + diffusers_checkpoint = {} + + # block 0 + + original_block_idx = 0 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.0", + resnet_prefix=f"middle_block.{original_block_idx}", + ) + ) + + original_block_idx += 1 + + # optional block 1 + + if hasattr(model.mid_block, "attentions") and model.mid_block.attentions[0] is not None: + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix="mid_block.attentions.0", + attention_prefix=f"middle_block.{original_block_idx}", + num_head_channels=num_head_channels, + ) + ) + original_block_idx += 1 + + # block 1 or block 2 + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, + diffusers_resnet_prefix="mid_block.resnets.1", + resnet_prefix=f"middle_block.{original_block_idx}", + ) + ) + + return diffusers_checkpoint + + +# .output_blocks -> .up_blocks +def unet_upblock_to_diffusers_checkpoint( + model, checkpoint, *, diffusers_up_block_idx, original_up_block_idx, num_head_channels +): + diffusers_checkpoint = {} + + diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.resnets" + original_up_block_prefix = "output_blocks" + + up_block = model.up_blocks[diffusers_up_block_idx] + + num_resnets = len(up_block.resnets) + + if up_block.upsamplers is None: + upsampler = False + else: + assert len(up_block.upsamplers) == 1 + upsampler = True + # The upsample block is also a resnet + num_resnets += 1 + + has_attentions = hasattr(up_block, "attentions") + + for resnet_idx_inc in range(num_resnets): + if upsampler and resnet_idx_inc == num_resnets - 1: + # this is an upsample block + if has_attentions: + # There is a middle attention block that we skip + original_resnet_block_idx = 2 + else: + original_resnet_block_idx = 1 + + # we add the `minus 1` because the last two resnets are stuck together in the same output block + full_resnet_prefix = ( + f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc - 1}.{original_resnet_block_idx}" + ) + + full_diffusers_resnet_prefix = f"up_blocks.{diffusers_up_block_idx}.upsamplers.0" + else: + # this is a regular resnet block + full_resnet_prefix = f"{original_up_block_prefix}.{original_up_block_idx + resnet_idx_inc}.0" + full_diffusers_resnet_prefix = f"{diffusers_resnet_prefix}.{resnet_idx_inc}" + + diffusers_checkpoint.update( + resnet_to_diffusers_checkpoint( + checkpoint, resnet_prefix=full_resnet_prefix, diffusers_resnet_prefix=full_diffusers_resnet_prefix + ) + ) + + if has_attentions: + num_attentions = len(up_block.attentions) + diffusers_attention_prefix = f"up_blocks.{diffusers_up_block_idx}.attentions" + + for attention_idx_inc in range(num_attentions): + full_attention_prefix = f"{original_up_block_prefix}.{original_up_block_idx + attention_idx_inc}.1" + full_diffusers_attention_prefix = f"{diffusers_attention_prefix}.{attention_idx_inc}" + + diffusers_checkpoint.update( + attention_to_diffusers_checkpoint( + checkpoint, + attention_prefix=full_attention_prefix, + diffusers_attention_prefix=full_diffusers_attention_prefix, + num_head_channels=num_head_channels, + ) + ) + + num_original_down_blocks = num_resnets - 1 if upsampler else num_resnets + + return diffusers_checkpoint, num_original_down_blocks + + +def resnet_to_diffusers_checkpoint(checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + diffusers_checkpoint = { + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.in_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.in_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.in_layers.2.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.in_layers.2.bias"], + f"{diffusers_resnet_prefix}.time_emb_proj.weight": checkpoint[f"{resnet_prefix}.emb_layers.1.weight"], + f"{diffusers_resnet_prefix}.time_emb_proj.bias": checkpoint[f"{resnet_prefix}.emb_layers.1.bias"], + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.out_layers.0.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.out_layers.0.bias"], + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.out_layers.3.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.out_layers.3.bias"], + } + + skip_connection_prefix = f"{resnet_prefix}.skip_connection" + + if f"{skip_connection_prefix}.weight" in checkpoint: + diffusers_checkpoint.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{skip_connection_prefix}.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{skip_connection_prefix}.bias"], + } + ) + + return diffusers_checkpoint + + +def attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix, num_head_channels): + diffusers_checkpoint = {} + + # .norm -> .group_norm + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + } + ) + + # .qkv -> .{query, key, value} + [q_weight, k_weight, v_weight], [q_bias, k_bias, v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.qkv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.qkv.bias"], + split=3, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_q.weight": q_weight, + f"{diffusers_attention_prefix}.to_q.bias": q_bias, + f"{diffusers_attention_prefix}.to_k.weight": k_weight, + f"{diffusers_attention_prefix}.to_k.bias": k_bias, + f"{diffusers_attention_prefix}.to_v.weight": v_weight, + f"{diffusers_attention_prefix}.to_v.bias": v_bias, + } + ) + + # .encoder_kv -> .{context_key, context_value} + [encoder_k_weight, encoder_v_weight], [encoder_k_bias, encoder_v_bias] = split_attentions( + weight=checkpoint[f"{attention_prefix}.encoder_kv.weight"][:, :, 0], + bias=checkpoint[f"{attention_prefix}.encoder_kv.bias"], + split=2, + chunk_size=num_head_channels, + ) + + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.add_k_proj.weight": encoder_k_weight, + f"{diffusers_attention_prefix}.add_k_proj.bias": encoder_k_bias, + f"{diffusers_attention_prefix}.add_v_proj.weight": encoder_v_weight, + f"{diffusers_attention_prefix}.add_v_proj.bias": encoder_v_bias, + } + ) + + # .proj_out (1d conv) -> .proj_attn (linear) + diffusers_checkpoint.update( + { + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][ + :, :, 0 + ], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + ) + + return diffusers_checkpoint + + +# TODO maybe document and/or can do more efficiently (build indices in for loop and extract once for each split?) +def split_attentions(*, weight, bias, split, chunk_size): + weights = [None] * split + biases = [None] * split + + weights_biases_idx = 0 + + for starting_row_index in range(0, weight.shape[0], chunk_size): + row_indices = torch.arange(starting_row_index, starting_row_index + chunk_size) + + weight_rows = weight[row_indices, :] + bias_rows = bias[row_indices] + + if weights[weights_biases_idx] is None: + assert weights[weights_biases_idx] is None + weights[weights_biases_idx] = weight_rows + biases[weights_biases_idx] = bias_rows + else: + assert weights[weights_biases_idx] is not None + weights[weights_biases_idx] = torch.concat([weights[weights_biases_idx], weight_rows]) + biases[weights_biases_idx] = torch.concat([biases[weights_biases_idx], bias_rows]) + + weights_biases_idx = (weights_biases_idx + 1) % split + + return weights, biases + + +# done unet utils + + +def prior(*, args, checkpoint_map_location): + print("loading prior") + + prior_checkpoint = torch.load(args.prior_checkpoint_path, map_location=checkpoint_map_location) + + clip_stats_checkpoint = torch.load(args.clip_stat_path, map_location=checkpoint_map_location) + + prior_model = prior_model_from_original_config() + + prior_diffusers_checkpoint = prior_original_checkpoint_to_diffusers_checkpoint( + prior_model, prior_checkpoint, clip_stats_checkpoint + ) + + del prior_checkpoint + del clip_stats_checkpoint + + load_checkpoint_to_model(prior_diffusers_checkpoint, prior_model, strict=True) + + print("done loading prior") + + return prior_model + + +def text2img(*, args, checkpoint_map_location): + print("loading text2img") + + text2img_checkpoint = torch.load(args.text2img_checkpoint_path, map_location=checkpoint_map_location) + + unet_model = unet_model_from_original_config() + + unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint) + + # text proj interlude + + # The original decoder implementation includes a set of parameters that are used + # for creating the `encoder_hidden_states` which are what the U-net is conditioned + # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull + # the parameters into the KandinskyTextProjModel class + text_proj_model = text_proj_from_original_config() + + text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint) + + load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True) + + del text2img_checkpoint + + load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True) + + print("done loading text2img") + + return unet_model, text_proj_model + + +def inpaint_text2img(*, args, checkpoint_map_location): + print("loading inpaint text2img") + + inpaint_text2img_checkpoint = torch.load( + args.inpaint_text2img_checkpoint_path, map_location=checkpoint_map_location + ) + + inpaint_unet_model = inpaint_unet_model_from_original_config() + + inpaint_unet_diffusers_checkpoint = inpaint_unet_original_checkpoint_to_diffusers_checkpoint( + inpaint_unet_model, inpaint_text2img_checkpoint + ) + + # text proj interlude + + # The original decoder implementation includes a set of parameters that are used + # for creating the `encoder_hidden_states` which are what the U-net is conditioned + # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull + # the parameters into the KandinskyTextProjModel class + text_proj_model = text_proj_from_original_config() + + text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint) + + load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True) + + del inpaint_text2img_checkpoint + + load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True) + + print("done loading inpaint text2img") + + return inpaint_unet_model, text_proj_model + + +# movq + +MOVQ_CONFIG = { + "in_channels": 3, + "out_channels": 3, + "latent_channels": 4, + "down_block_types": ("DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"), + "up_block_types": ("AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"), + "num_vq_embeddings": 16384, + "block_out_channels": (128, 256, 256, 512), + "vq_embed_dim": 4, + "layers_per_block": 2, + "norm_type": "spatial", +} + + +def movq_model_from_original_config(): + movq = VQModel(**MOVQ_CONFIG) + return movq + + +def movq_encoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv_in + diffusers_checkpoint.update( + { + "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"], + "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"], + } + ) + + # down_blocks + for down_block_idx, down_block in enumerate(model.encoder.down_blocks): + diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}" + down_block_prefix = f"encoder.down.{down_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(down_block.resnets): + diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # downsample + + # do not include the downsample when on the last down block + # There is no downsample on the last down block + if down_block_idx != len(model.encoder.down_blocks) - 1: + # There's a single downsample in the original checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv" + downsample_prefix = f"{down_block_prefix}.downsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(down_block, "attentions"): + for attention_idx, _ in enumerate(down_block.attentions): + diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{down_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion encoder + diffusers_attention_prefix = "encoder.mid_block.attentions.0" + attention_prefix = "encoder.mid.attn_1" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion encoder + resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"], + "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"], + # conv_out + "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"], + "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def movq_decoder_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + + # conv in + diffusers_checkpoint.update( + { + "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"], + "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"], + } + ) + + # up_blocks + + for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks): + # up_blocks are stored in reverse order in the VQ-diffusion checkpoint + orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx + + diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}" + up_block_prefix = f"decoder.up.{orig_up_block_idx}" + + # resnets + for resnet_idx, resnet in enumerate(up_block.resnets): + diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}" + resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + # upsample + + # there is no up sample on the last up block + if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1: + # There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples + # in the diffusers model. + diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv" + downsample_prefix = f"{up_block_prefix}.upsample.conv" + diffusers_checkpoint.update( + { + f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"], + f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"], + } + ) + + # attentions + + if hasattr(up_block, "attentions"): + for attention_idx, _ in enumerate(up_block.attentions): + diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}" + attention_prefix = f"{up_block_prefix}.attn.{attention_idx}" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, + diffusers_attention_prefix=diffusers_attention_prefix, + attention_prefix=attention_prefix, + ) + ) + + # mid block + + # mid block attentions + + # There is a single hardcoded attention block in the middle of the VQ-diffusion decoder + diffusers_attention_prefix = "decoder.mid_block.attentions.0" + attention_prefix = "decoder.mid.attn_1" + diffusers_checkpoint.update( + movq_attention_to_diffusers_checkpoint_spatial_norm( + checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix + ) + ) + + # mid block resnets + + for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets): + diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}" + + # the hardcoded prefixes to `block_` are 1 and 2 + orig_resnet_idx = diffusers_resnet_idx + 1 + # There are two hardcoded resnets in the middle of the VQ-diffusion decoder + resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}" + + diffusers_checkpoint.update( + movq_resnet_to_diffusers_checkpoint_spatial_norm( + resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix + ) + ) + + diffusers_checkpoint.update( + { + # conv_norm_out + "decoder.conv_norm_out.norm_layer.weight": checkpoint["decoder.norm_out.norm_layer.weight"], + "decoder.conv_norm_out.norm_layer.bias": checkpoint["decoder.norm_out.norm_layer.bias"], + "decoder.conv_norm_out.conv_y.weight": checkpoint["decoder.norm_out.conv_y.weight"], + "decoder.conv_norm_out.conv_y.bias": checkpoint["decoder.norm_out.conv_y.bias"], + "decoder.conv_norm_out.conv_b.weight": checkpoint["decoder.norm_out.conv_b.weight"], + "decoder.conv_norm_out.conv_b.bias": checkpoint["decoder.norm_out.conv_b.bias"], + # conv_out + "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"], + "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"], + } + ) + + return diffusers_checkpoint + + +def movq_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"], + f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"], + f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def movq_resnet_to_diffusers_checkpoint_spatial_norm(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix): + rv = { + # norm1 + f"{diffusers_resnet_prefix}.norm1.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm1.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm1.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm1.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_y.weight": checkpoint[f"{resnet_prefix}.norm1.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_y.bias": checkpoint[f"{resnet_prefix}.norm1.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm1.conv_b.weight": checkpoint[f"{resnet_prefix}.norm1.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm1.conv_b.bias": checkpoint[f"{resnet_prefix}.norm1.conv_b.bias"], + # conv1 + f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"], + f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"], + # norm2 + f"{diffusers_resnet_prefix}.norm2.norm_layer.weight": checkpoint[f"{resnet_prefix}.norm2.norm_layer.weight"], + f"{diffusers_resnet_prefix}.norm2.norm_layer.bias": checkpoint[f"{resnet_prefix}.norm2.norm_layer.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_y.weight": checkpoint[f"{resnet_prefix}.norm2.conv_y.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_y.bias": checkpoint[f"{resnet_prefix}.norm2.conv_y.bias"], + f"{diffusers_resnet_prefix}.norm2.conv_b.weight": checkpoint[f"{resnet_prefix}.norm2.conv_b.weight"], + f"{diffusers_resnet_prefix}.norm2.conv_b.bias": checkpoint[f"{resnet_prefix}.norm2.conv_b.bias"], + # conv2 + f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"], + f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"], + } + + if resnet.conv_shortcut is not None: + rv.update( + { + f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"], + f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"], + } + ) + + return rv + + +def movq_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"], + f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +def movq_attention_to_diffusers_checkpoint_spatial_norm(checkpoint, *, diffusers_attention_prefix, attention_prefix): + return { + # norm + f"{diffusers_attention_prefix}.spatial_norm.norm_layer.weight": checkpoint[ + f"{attention_prefix}.norm.norm_layer.weight" + ], + f"{diffusers_attention_prefix}.spatial_norm.norm_layer.bias": checkpoint[ + f"{attention_prefix}.norm.norm_layer.bias" + ], + f"{diffusers_attention_prefix}.spatial_norm.conv_y.weight": checkpoint[ + f"{attention_prefix}.norm.conv_y.weight" + ], + f"{diffusers_attention_prefix}.spatial_norm.conv_y.bias": checkpoint[f"{attention_prefix}.norm.conv_y.bias"], + f"{diffusers_attention_prefix}.spatial_norm.conv_b.weight": checkpoint[ + f"{attention_prefix}.norm.conv_b.weight" + ], + f"{diffusers_attention_prefix}.spatial_norm.conv_b.bias": checkpoint[f"{attention_prefix}.norm.conv_b.bias"], + # query + f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.q.bias"], + # key + f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.k.bias"], + # value + f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.v.bias"], + # proj_attn + f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][:, :, 0, 0], + f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj_out.bias"], + } + + +def movq_original_checkpoint_to_diffusers_checkpoint(model, checkpoint): + diffusers_checkpoint = {} + diffusers_checkpoint.update(movq_encoder_to_diffusers_checkpoint(model, checkpoint)) + + # quant_conv + + diffusers_checkpoint.update( + { + "quant_conv.weight": checkpoint["quant_conv.weight"], + "quant_conv.bias": checkpoint["quant_conv.bias"], + } + ) + + # quantize + diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding.weight"]}) + + # post_quant_conv + diffusers_checkpoint.update( + { + "post_quant_conv.weight": checkpoint["post_quant_conv.weight"], + "post_quant_conv.bias": checkpoint["post_quant_conv.bias"], + } + ) + + # decoder + diffusers_checkpoint.update(movq_decoder_to_diffusers_checkpoint(model, checkpoint)) + + return diffusers_checkpoint + + +def movq(*, args, checkpoint_map_location): + print("loading movq") + + movq_checkpoint = torch.load(args.movq_checkpoint_path, map_location=checkpoint_map_location) + + movq_model = movq_model_from_original_config() + + movq_diffusers_checkpoint = movq_original_checkpoint_to_diffusers_checkpoint(movq_model, movq_checkpoint) + + del movq_checkpoint + + load_checkpoint_to_model(movq_diffusers_checkpoint, movq_model, strict=True) + + print("done loading movq") + + return movq_model + + +def load_checkpoint_to_model(checkpoint, model, strict=False): + with tempfile.NamedTemporaryFile(delete=False) as file: + torch.save(checkpoint, file.name) + del checkpoint + if strict: + model.load_state_dict(torch.load(file.name), strict=True) + else: + load_checkpoint_and_dispatch(model, file.name, device_map="auto") + os.remove(file.name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + + parser.add_argument( + "--prior_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the prior checkpoint to convert.", + ) + parser.add_argument( + "--clip_stat_path", + default=None, + type=str, + required=False, + help="Path to the clip stats checkpoint to convert.", + ) + parser.add_argument( + "--text2img_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the text2img checkpoint to convert.", + ) + parser.add_argument( + "--movq_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the text2img checkpoint to convert.", + ) + parser.add_argument( + "--inpaint_text2img_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to the inpaint text2img checkpoint to convert.", + ) + parser.add_argument( + "--checkpoint_load_device", + default="cpu", + type=str, + required=False, + help="The device passed to `map_location` when loading checkpoints.", + ) + + parser.add_argument( + "--debug", + default=None, + type=str, + required=False, + help="Only run a specific stage of the convert script. Used for debugging", + ) + + args = parser.parse_args() + + print(f"loading checkpoints to {args.checkpoint_load_device}") + + checkpoint_map_location = torch.device(args.checkpoint_load_device) + + if args.debug is not None: + print(f"debug: only executing {args.debug}") + + if args.debug is None: + print("to-do") + elif args.debug == "prior": + prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location) + prior_model.save_pretrained(args.dump_path) + elif args.debug == "text2img": + unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location) + unet_model.save_pretrained(f"{args.dump_path}/unet") + text_proj_model.save_pretrained(f"{args.dump_path}/text_proj") + elif args.debug == "inpaint_text2img": + inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img( + args=args, checkpoint_map_location=checkpoint_map_location + ) + inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet") + inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj") + elif args.debug == "decoder": + decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location) + decoder.save_pretrained(f"{args.dump_path}/decoder") + else: + raise ValueError(f"unknown debug value : {args.debug}") diff --git a/scripts/convert_original_audioldm_to_diffusers.py b/scripts/convert_original_audioldm_to_diffusers.py index 189b165c0a01..a0d154d7e6ba 100644 --- a/scripts/convert_original_audioldm_to_diffusers.py +++ b/scripts/convert_original_audioldm_to_diffusers.py @@ -774,6 +774,8 @@ def load_pipeline_from_original_audioldm_ckpt( extract_ema: bool = False, scheduler_type: str = "ddim", num_in_channels: int = None, + model_channels: int = None, + num_head_channels: int = None, device: str = None, from_safetensors: bool = False, ) -> AudioLDMPipeline: @@ -784,23 +786,36 @@ def load_pipeline_from_original_audioldm_ckpt( global step count, which will likely fail for models that have undergone further fine-tuning. Therefore, it is recommended that you override the default values and/or supply an `original_config_file` wherever possible. - :param checkpoint_path: Path to `.ckpt` file. :param original_config_file: Path to `.yaml` config file - corresponding to the original architecture. - If `None`, will be automatically instantiated based on default values. - :param image_size: The image size that the model was trained on. Use 512 for original AudioLDM checkpoints. :param - prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for original - AudioLDM checkpoints. - :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically - inferred. - :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", - "euler-ancestral", "dpm", "ddim"]`. - :param extract_ema: Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract - the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually - yield higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning. - :param device: The device to use. Pass `None` to determine automatically. :param from_safetensors: If - `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors - instead of PyTorch. - :return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. + Args: + checkpoint_path (`str`): Path to `.ckpt` file. + original_config_file (`str`): + Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically + set to the audioldm-s-full-v2 config. + image_size (`int`, *optional*, defaults to 512): + The image size that the model was trained on. + prediction_type (`str`, *optional*): + The prediction type that the model was trained on. If `None`, will be automatically + inferred by looking for a key in the config. For the default config, the prediction type is `'epsilon'`. + num_in_channels (`int`, *optional*, defaults to None): + The number of UNet input channels. If `None`, it will be automatically inferred from the config. + model_channels (`int`, *optional*, defaults to None): + The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override + to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large. + num_head_channels (`int`, *optional*, defaults to None): + The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override + to 32 for the small and medium checkpoints, and 64 for the large. + scheduler_type (`str`, *optional*, defaults to 'pndm'): + Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", "euler-ancestral", "dpm", + "ddim"]`. + extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for + checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to + `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for + inference. Non-EMA weights are usually better to continue fine-tuning. + device (`str`, *optional*, defaults to `None`): + The device to use. Pass `None` to determine automatically. + from_safetensors (`str`, *optional*, defaults to `False`): + If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch. + return: An AudioLDMPipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ if not is_omegaconf_available(): @@ -837,6 +852,12 @@ def load_pipeline_from_original_audioldm_ckpt( if num_in_channels is not None: original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels + if model_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["model_channels"] = model_channels + + if num_head_channels is not None: + original_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = num_head_channels + if ( "parameterization" in original_config["model"]["params"] and original_config["model"]["params"]["parameterization"] == "v" @@ -960,6 +981,20 @@ def load_pipeline_from_original_audioldm_ckpt( type=int, help="The number of input channels. If `None` number of input channels will be automatically inferred.", ) + parser.add_argument( + "--model_channels", + default=None, + type=int, + help="The number of UNet model channels. If `None`, it will be automatically inferred from the config. Override" + " to 128 for the small checkpoints, 192 for the medium checkpoints and 256 for the large.", + ) + parser.add_argument( + "--num_head_channels", + default=None, + type=int, + help="The number of UNet head channels. If `None`, it will be automatically inferred from the config. Override" + " to 32 for the small and medium checkpoints, and 64 for the large.", + ) parser.add_argument( "--scheduler_type", default="ddim", @@ -1009,6 +1044,8 @@ def load_pipeline_from_original_audioldm_ckpt( extract_ema=args.extract_ema, scheduler_type=args.scheduler_type, num_in_channels=args.num_in_channels, + model_channels=args.model_channels, + num_head_channels=args.num_head_channels, from_safetensors=args.from_safetensors, device=args.device, ) diff --git a/scripts/convert_original_controlnet_to_diffusers.py b/scripts/convert_original_controlnet_to_diffusers.py index a9e05abd4cf1..9466bd27234c 100644 --- a/scripts/convert_original_controlnet_to_diffusers.py +++ b/scripts/convert_original_controlnet_to_diffusers.py @@ -75,6 +75,22 @@ ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + + # small workaround to get argparser to parse a boolean input as either true _or_ false + def parse_bool(string): + if string == "True": + return True + elif string == "False": + return False + else: + raise ValueError(f"could not parse string as bool {string}") + + parser.add_argument( + "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool + ) + + parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int) + args = parser.parse_args() controlnet = download_controlnet_from_original_ckpt( @@ -86,6 +102,8 @@ upcast_attention=args.upcast_attention, from_safetensors=args.from_safetensors, device=args.device, + use_linear_projection=args.use_linear_projection, + cross_attention_dim=args.cross_attention_dim, ) controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/scripts/convert_unidiffuser_to_diffusers.py b/scripts/convert_unidiffuser_to_diffusers.py new file mode 100644 index 000000000000..891d289d8c76 --- /dev/null +++ b/scripts/convert_unidiffuser_to_diffusers.py @@ -0,0 +1,776 @@ +# Convert the original UniDiffuser checkpoints into diffusers equivalents. + +import argparse +from argparse import Namespace + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from diffusers import ( + AutoencoderKL, + DPMSolverMultistepScheduler, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, +) + + +SCHEDULER_CONFIG = Namespace( + **{ + "beta_start": 0.00085, + "beta_end": 0.012, + "beta_schedule": "scaled_linear", + "solver_order": 3, + } +) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_resnet_paths +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.renew_vae_attention_paths +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +# Modified from diffusers.pipelines.stable_diffusion.convert_from_ckpt.assign_to_checkpoint +# config.num_head_channels => num_head_channels +def assign_to_checkpoint( + paths, + checkpoint, + old_checkpoint, + attention_paths_to_split=None, + additional_replacements=None, + num_head_channels=1, +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits + attention layers, and takes into account additional replacements that may arise. Assigns the weights to the new + checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // num_head_channels // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +# Copied from diffusers.pipelines.stable_diffusion.convert_from_ckpt.conv_attn_to_linear +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def create_vae_diffusers_config(config_type): + # Hardcoded for now + if args.config_type == "test": + vae_config = create_vae_diffusers_config_test() + elif args.config_type == "big": + vae_config = create_vae_diffusers_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + return vae_config + + +def create_unidiffuser_unet_config(config_type, version): + # Hardcoded for now + if args.config_type == "test": + unet_config = create_unidiffuser_unet_config_test() + elif args.config_type == "big": + unet_config = create_unidiffuser_unet_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + # Unidiffuser-v1 uses data type embeddings + if version == 1: + unet_config["use_data_type_embedding"] = True + return unet_config + + +def create_text_decoder_config(config_type): + # Hardcoded for now + if args.config_type == "test": + text_decoder_config = create_text_decoder_config_test() + elif args.config_type == "big": + text_decoder_config = create_text_decoder_config_big() + else: + raise NotImplementedError( + f"Config type {config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + return text_decoder_config + + +# Hardcoded configs for test versions of the UniDiffuser models, corresponding to those in the fast default tests. +def create_vae_diffusers_config_test(): + vae_config = { + "sample_size": 32, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D"], + "block_out_channels": [32, 64], + "latent_channels": 4, + "layers_per_block": 1, + } + return vae_config + + +def create_unidiffuser_unet_config_test(): + unet_config = { + "text_dim": 32, + "clip_img_dim": 32, + "num_text_tokens": 77, + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "num_layers": 2, + "dropout": 0.0, + "norm_num_groups": 32, + "attention_bias": False, + "sample_size": 16, + "patch_size": 2, + "activation_fn": "gelu", + "num_embeds_ada_norm": 1000, + "norm_type": "layer_norm", + "block_type": "unidiffuser", + "pre_layer_norm": False, + "use_timestep_embedding": False, + "norm_elementwise_affine": True, + "use_patch_pos_embed": False, + "ff_final_dropout": True, + "use_data_type_embedding": False, + } + return unet_config + + +def create_text_decoder_config_test(): + text_decoder_config = { + "prefix_length": 77, + "prefix_inner_dim": 32, + "prefix_hidden_dim": 32, + "vocab_size": 1025, # 1024 + 1 for new EOS token + "n_positions": 1024, + "n_embd": 32, + "n_layer": 5, + "n_head": 4, + "n_inner": 37, + "activation_function": "gelu", + "resid_pdrop": 0.1, + "embd_pdrop": 0.1, + "attn_pdrop": 0.1, + "layer_norm_epsilon": 1e-5, + "initializer_range": 0.02, + } + return text_decoder_config + + +# Hardcoded configs for the UniDiffuser V1 model at https://huggingface.co/thu-ml/unidiffuser-v1 +# See also https://github.com/thu-ml/unidiffuser/blob/main/configs/sample_unidiffuser_v1.py +def create_vae_diffusers_config_big(): + vae_config = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 3, + "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"], + "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"], + "block_out_channels": [128, 256, 512, 512], + "latent_channels": 4, + "layers_per_block": 2, + } + return vae_config + + +def create_unidiffuser_unet_config_big(): + unet_config = { + "text_dim": 64, + "clip_img_dim": 512, + "num_text_tokens": 77, + "num_attention_heads": 24, + "attention_head_dim": 64, + "in_channels": 4, + "out_channels": 4, + "num_layers": 30, + "dropout": 0.0, + "norm_num_groups": 32, + "attention_bias": False, + "sample_size": 64, + "patch_size": 2, + "activation_fn": "gelu", + "num_embeds_ada_norm": 1000, + "norm_type": "layer_norm", + "block_type": "unidiffuser", + "pre_layer_norm": False, + "use_timestep_embedding": False, + "norm_elementwise_affine": True, + "use_patch_pos_embed": False, + "ff_final_dropout": True, + "use_data_type_embedding": False, + } + return unet_config + + +# From https://huggingface.co/gpt2/blob/main/config.json, the GPT2 checkpoint used by UniDiffuser +def create_text_decoder_config_big(): + text_decoder_config = { + "prefix_length": 77, + "prefix_inner_dim": 768, + "prefix_hidden_dim": 64, + "vocab_size": 50258, # 50257 + 1 for new EOS token + "n_positions": 1024, + "n_embd": 768, + "n_layer": 12, + "n_head": 12, + "n_inner": 3072, + "activation_function": "gelu", + "resid_pdrop": 0.1, + "embd_pdrop": 0.1, + "attn_pdrop": 0.1, + "layer_norm_epsilon": 1e-5, + "initializer_range": 0.02, + } + return text_decoder_config + + +# Based on diffusers.pipelines.stable_diffusion.convert_from_ckpt.shave_segments.convert_ldm_vae_checkpoint +def convert_vae_to_diffusers(ckpt, diffusers_model, num_head_channels=1): + """ + Converts a UniDiffuser autoencoder_kl.pth checkpoint to a diffusers AutoencoderKL. + """ + # autoencoder_kl.pth ckpt is a torch state dict + vae_state_dict = torch.load(ckpt, map_location="cpu") + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint( + paths, + new_checkpoint, + vae_state_dict, + additional_replacements=[meta_path], + num_head_channels=num_head_channels, # not used in vae + ) + conv_attn_to_linear(new_checkpoint) + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_checkpoint) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +def convert_uvit_block_to_diffusers_block( + uvit_state_dict, + new_state_dict, + block_prefix, + new_prefix="transformer.transformer_", + skip_connection=False, +): + """ + Maps the keys in a UniDiffuser transformer block (`Block`) to the keys in a diffusers transformer block + (`UTransformerBlock`/`UniDiffuserBlock`). + """ + prefix = new_prefix + block_prefix + if skip_connection: + new_state_dict[prefix + ".skip.skip_linear.weight"] = uvit_state_dict[block_prefix + ".skip_linear.weight"] + new_state_dict[prefix + ".skip.skip_linear.bias"] = uvit_state_dict[block_prefix + ".skip_linear.bias"] + new_state_dict[prefix + ".skip.norm.weight"] = uvit_state_dict[block_prefix + ".norm1.weight"] + new_state_dict[prefix + ".skip.norm.bias"] = uvit_state_dict[block_prefix + ".norm1.bias"] + + # Create the prefix string for out_blocks. + prefix += ".block" + + # Split up attention qkv.weight into to_q.weight, to_k.weight, to_v.weight + qkv = uvit_state_dict[block_prefix + ".attn.qkv.weight"] + new_attn_keys = [".attn1.to_q.weight", ".attn1.to_k.weight", ".attn1.to_v.weight"] + new_attn_keys = [prefix + key for key in new_attn_keys] + shape = qkv.shape[0] // len(new_attn_keys) + for i, attn_key in enumerate(new_attn_keys): + new_state_dict[attn_key] = qkv[i * shape : (i + 1) * shape] + + new_state_dict[prefix + ".attn1.to_out.0.weight"] = uvit_state_dict[block_prefix + ".attn.proj.weight"] + new_state_dict[prefix + ".attn1.to_out.0.bias"] = uvit_state_dict[block_prefix + ".attn.proj.bias"] + new_state_dict[prefix + ".norm1.weight"] = uvit_state_dict[block_prefix + ".norm2.weight"] + new_state_dict[prefix + ".norm1.bias"] = uvit_state_dict[block_prefix + ".norm2.bias"] + new_state_dict[prefix + ".ff.net.0.proj.weight"] = uvit_state_dict[block_prefix + ".mlp.fc1.weight"] + new_state_dict[prefix + ".ff.net.0.proj.bias"] = uvit_state_dict[block_prefix + ".mlp.fc1.bias"] + new_state_dict[prefix + ".ff.net.2.weight"] = uvit_state_dict[block_prefix + ".mlp.fc2.weight"] + new_state_dict[prefix + ".ff.net.2.bias"] = uvit_state_dict[block_prefix + ".mlp.fc2.bias"] + new_state_dict[prefix + ".norm3.weight"] = uvit_state_dict[block_prefix + ".norm3.weight"] + new_state_dict[prefix + ".norm3.bias"] = uvit_state_dict[block_prefix + ".norm3.bias"] + + return uvit_state_dict, new_state_dict + + +def convert_uvit_to_diffusers(ckpt, diffusers_model): + """ + Converts a UniDiffuser uvit_v*.pth checkpoint to a diffusers UniDiffusersModel. + """ + # uvit_v*.pth ckpt is a torch state dict + uvit_state_dict = torch.load(ckpt, map_location="cpu") + + new_state_dict = {} + + # Input layers + new_state_dict["vae_img_in.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"] + new_state_dict["vae_img_in.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"] + new_state_dict["clip_img_in.weight"] = uvit_state_dict["clip_img_embed.weight"] + new_state_dict["clip_img_in.bias"] = uvit_state_dict["clip_img_embed.bias"] + new_state_dict["text_in.weight"] = uvit_state_dict["text_embed.weight"] + new_state_dict["text_in.bias"] = uvit_state_dict["text_embed.bias"] + + new_state_dict["pos_embed"] = uvit_state_dict["pos_embed"] + + # Handle data type token embeddings for UniDiffuser-v1 + if "token_embedding.weight" in uvit_state_dict and diffusers_model.use_data_type_embedding: + new_state_dict["data_type_pos_embed_token"] = uvit_state_dict["pos_embed_token"] + new_state_dict["data_type_token_embedding.weight"] = uvit_state_dict["token_embedding.weight"] + + # Also initialize the PatchEmbedding in UTransformer2DModel with the PatchEmbedding from the checkpoint. + # This isn't used in the current implementation, so might want to remove. + new_state_dict["transformer.pos_embed.proj.weight"] = uvit_state_dict["patch_embed.proj.weight"] + new_state_dict["transformer.pos_embed.proj.bias"] = uvit_state_dict["patch_embed.proj.bias"] + + # Output layers + new_state_dict["transformer.norm_out.weight"] = uvit_state_dict["norm.weight"] + new_state_dict["transformer.norm_out.bias"] = uvit_state_dict["norm.bias"] + + new_state_dict["vae_img_out.weight"] = uvit_state_dict["decoder_pred.weight"] + new_state_dict["vae_img_out.bias"] = uvit_state_dict["decoder_pred.bias"] + new_state_dict["clip_img_out.weight"] = uvit_state_dict["clip_img_out.weight"] + new_state_dict["clip_img_out.bias"] = uvit_state_dict["clip_img_out.bias"] + new_state_dict["text_out.weight"] = uvit_state_dict["text_out.weight"] + new_state_dict["text_out.bias"] = uvit_state_dict["text_out.bias"] + + # in_blocks + in_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "in_blocks" in layer} + for in_block_prefix in list(in_blocks_prefixes): + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, in_block_prefix) + + # mid_block + # Assume there's only one mid block + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, "mid_block") + + # out_blocks + out_blocks_prefixes = {".".join(layer.split(".")[:2]) for layer in uvit_state_dict if "out_blocks" in layer} + for out_block_prefix in list(out_blocks_prefixes): + convert_uvit_block_to_diffusers_block(uvit_state_dict, new_state_dict, out_block_prefix, skip_connection=True) + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +def convert_caption_decoder_to_diffusers(ckpt, diffusers_model): + """ + Converts a UniDiffuser caption_decoder.pth checkpoint to a diffusers UniDiffuserTextDecoder. + """ + # caption_decoder.pth ckpt is a torch state dict + checkpoint_state_dict = torch.load(ckpt, map_location="cpu") + decoder_state_dict = {} + # Remove the "module." prefix, if necessary + caption_decoder_key = "module." + for key in checkpoint_state_dict: + if key.startswith(caption_decoder_key): + decoder_state_dict[key.replace(caption_decoder_key, "")] = checkpoint_state_dict.get(key) + else: + decoder_state_dict[key] = checkpoint_state_dict.get(key) + + new_state_dict = {} + + # Encoder and Decoder + new_state_dict["encode_prefix.weight"] = decoder_state_dict["encode_prefix.weight"] + new_state_dict["encode_prefix.bias"] = decoder_state_dict["encode_prefix.bias"] + new_state_dict["decode_prefix.weight"] = decoder_state_dict["decode_prefix.weight"] + new_state_dict["decode_prefix.bias"] = decoder_state_dict["decode_prefix.bias"] + + # Internal GPT2LMHeadModel transformer model + for key, val in decoder_state_dict.items(): + if key.startswith("gpt"): + suffix = key[len("gpt") :] + new_state_dict["transformer" + suffix] = val + + missing_keys, unexpected_keys = diffusers_model.load_state_dict(new_state_dict) + for missing_key in missing_keys: + print(f"Missing key: {missing_key}") + for unexpected_key in unexpected_keys: + print(f"Unexpected key: {unexpected_key}") + + return diffusers_model + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--caption_decoder_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to caption decoder checkpoint to convert.", + ) + parser.add_argument( + "--uvit_checkpoint_path", default=None, type=str, required=False, help="Path to U-ViT checkpoint to convert." + ) + parser.add_argument( + "--vae_checkpoint_path", + default=None, + type=str, + required=False, + help="Path to VAE checkpoint to convert.", + ) + parser.add_argument( + "--pipeline_output_path", + default=None, + type=str, + required=True, + help="Path to save the output pipeline to.", + ) + parser.add_argument( + "--config_type", + default="test", + type=str, + help=( + "Config type to use. Should be 'test' to create small models for testing or 'big' to convert a full" + " checkpoint." + ), + ) + parser.add_argument( + "--version", + default=0, + type=int, + help="The UniDiffuser model type to convert to. Should be 0 for UniDiffuser-v0 and 1 for UniDiffuser-v1.", + ) + + args = parser.parse_args() + + # Convert the VAE model. + if args.vae_checkpoint_path is not None: + vae_config = create_vae_diffusers_config(args.config_type) + vae = AutoencoderKL(**vae_config) + vae = convert_vae_to_diffusers(args.vae_checkpoint_path, vae) + + # Convert the U-ViT ("unet") model. + if args.uvit_checkpoint_path is not None: + unet_config = create_unidiffuser_unet_config(args.config_type, args.version) + unet = UniDiffuserModel(**unet_config) + unet = convert_uvit_to_diffusers(args.uvit_checkpoint_path, unet) + + # Convert the caption decoder ("text_decoder") model. + if args.caption_decoder_checkpoint_path is not None: + text_decoder_config = create_text_decoder_config(args.config_type) + text_decoder = UniDiffuserTextDecoder(**text_decoder_config) + text_decoder = convert_caption_decoder_to_diffusers(args.caption_decoder_checkpoint_path, text_decoder) + + # Scheduler is the same for both the test and big models. + scheduler_config = SCHEDULER_CONFIG + scheduler = DPMSolverMultistepScheduler( + beta_start=scheduler_config.beta_start, + beta_end=scheduler_config.beta_end, + beta_schedule=scheduler_config.beta_schedule, + solver_order=scheduler_config.solver_order, + ) + + if args.config_type == "test": + # Make a small random CLIPTextModel + torch.manual_seed(0) + clip_text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(clip_text_encoder_config) + clip_tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + # Make a small random CLIPVisionModel and accompanying CLIPImageProcessor + torch.manual_seed(0) + clip_image_encoder_config = CLIPVisionConfig( + image_size=32, + patch_size=2, + num_channels=3, + hidden_size=32, + projection_dim=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + dropout=0.1, + attention_dropout=0.1, + initializer_range=0.02, + ) + image_encoder = CLIPVisionModelWithProjection(clip_image_encoder_config) + image_processor = CLIPImageProcessor(crop_size=32, size=32) + + # Note that the text_decoder should already have its token embeddings resized. + text_tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-GPT2Model") + eos = "<|EOS|>" + special_tokens_dict = {"eos_token": eos} + text_tokenizer.add_special_tokens(special_tokens_dict) + elif args.config_type == "big": + text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + + image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-base-patch32") + image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32") + + # Note that the text_decoder should already have its token embeddings resized. + text_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + eos = "<|EOS|>" + special_tokens_dict = {"eos_token": eos} + text_tokenizer.add_special_tokens(special_tokens_dict) + else: + raise NotImplementedError( + f"Config type {args.config_type} is not implemented, currently only config types" + " 'test' and 'big' are available." + ) + + pipeline = UniDiffuserPipeline( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_processor=image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + pipeline.save_pretrained(args.pipeline_output_path) diff --git a/setup.py b/setup.py index 5c26b246aa01..c718f221f763 100644 --- a/setup.py +++ b/setup.py @@ -95,8 +95,8 @@ "Jinja2", "k-diffusion>=0.0.12", "librosa", - "note-seq", "numpy", + "omegaconf", "parameterized", "protobuf>=3.20.3,<4", "pytest", @@ -112,6 +112,7 @@ "torch>=1.4", "torchvision", "transformers>=4.25.1", + "urllib3<=2.0.0", ] # this is a lookup table with items like: @@ -182,7 +183,7 @@ def run(self): extras = {} -extras["quality"] = deps_list("black", "isort", "ruff", "hf-doc-builder") +extras["quality"] = deps_list("urllib3", "black", "isort", "ruff", "hf-doc-builder") extras["docs"] = deps_list("hf-doc-builder") extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2") extras["test"] = deps_list( @@ -191,7 +192,7 @@ def run(self): "Jinja2", "k-diffusion", "librosa", - "note-seq", + "omegaconf", "parameterized", "pytest", "pytest-timeout", @@ -226,7 +227,7 @@ def run(self): setup( name="diffusers", - version="0.16.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + version="0.17.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) description="Diffusers", long_description=open("README.md", "r", encoding="utf-8").read(), long_description_content_type="text/markdown", diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index bb7381d65a54..1dd1cd31b7c2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.16.1" +__version__ = "0.17.0" from .configuration_utils import ConfigMixin from .utils import ( @@ -12,6 +12,7 @@ is_onnx_available, is_scipy_available, is_torch_available, + is_torchsde_available, is_transformers_available, is_transformers_version, is_unidecode_available, @@ -75,6 +76,7 @@ DDIMScheduler, DDPMScheduler, DEISMultistepScheduler, + DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler, DPMSolverSinglestepScheduler, EulerAncestralDiscreteScheduler, @@ -102,6 +104,13 @@ else: from .schedulers import LMSDiscreteScheduler +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 +else: + from .schedulers import DPMSolverSDEScheduler try: if not (is_torch_available() and is_transformers_available()): @@ -120,12 +129,20 @@ IFInpaintingSuperResolutionPipeline, IFPipeline, IFSuperResolutionPipeline, + ImageTextPipelineOutput, + KandinskyImg2ImgPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, LDMTextToImagePipeline, PaintByExamplePipeline, SemanticStableDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, @@ -145,6 +162,9 @@ TextToVideoZeroPipeline, UnCLIPImageVariationPipeline, UnCLIPPipeline, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, VersatileDiffusionPipeline, diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index af639de306ee..bb5adf3e9444 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -160,7 +160,7 @@ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool @classmethod def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): r""" - Instantiate a Python class from a config dictionary + Instantiate a Python class from a config dictionary. Parameters: config (`Dict[str, Any]`): @@ -170,9 +170,13 @@ def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_un Whether kwargs that are not consumed by the Python class should be returned or not. kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the Python class. - `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually - overwrite same named arguments of `config`. + Can be used to update the configuration object (after it is loaded) and initiate the Python class. + `**kwargs` are directly passed to the underlying scheduler/model's `__init__` method and eventually + overwrite same named arguments in `config`. + + Returns: + [`ModelMixin`] or [`SchedulerMixin`]: + A model or scheduler object instantiated from a config dictionary. Examples: @@ -258,59 +262,57 @@ def load_config( **kwargs, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: r""" - Instantiate a Python class from a config dictionary + Load a model or scheduler configuration. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an - organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ConfigMixin.save_config`], e.g., - `./my_model_directory/`. + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with + [`~ConfigMixin.save_config`]. cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. + Whether or not to resume downloading the model weights and configuration files. If set to False, any + incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). + Whether to only load local model weights and configuration files or not. If set to True, the model + won’t be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `transformers-cli login` (stored in `~/.huggingface`). + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. + The subfolder location of a model file within a larger model repository on the Hub or locally. return_unused_kwargs (`bool`, *optional*, defaults to `False): - Whether unused keyword arguments of the config shall be returned. + Whether unused keyword arguments of the config are returned. return_commit_hash (`bool`, *optional*, defaults to `False): - Whether the commit_hash of the loaded configuration shall be returned. - - + Whether the `commit_hash` of the loaded configuration are returned. - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). - - + Returns: + `dict`: + A dictionary of all the parameters stored in a JSON configuration file. - Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to - use this method in a firewalled environment. + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to use this method in a + firewalled environment. """ diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index 1269cf1578a6..19a843470ee1 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -19,8 +19,8 @@ "Jinja2": "Jinja2", "k-diffusion": "k-diffusion>=0.0.12", "librosa": "librosa", - "note-seq": "note-seq", "numpy": "numpy", + "omegaconf": "omegaconf", "parameterized": "parameterized", "protobuf": "protobuf>=3.20.3,<4", "pytest": "pytest", @@ -36,4 +36,5 @@ "torch": "torch>=1.4", "torchvision": "torchvision", "transformers": "transformers>=4.25.1", + "urllib3": "urllib3<=2.0.0", } diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 4598e1b4288c..17c083914753 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import Union +from typing import List, Optional, Union import numpy as np import PIL @@ -21,7 +21,7 @@ from PIL import Image from .configuration_utils import ConfigMixin, register_to_config -from .utils import CONFIG_NAME, PIL_INTERPOLATION +from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate class VaeImageProcessor(ConfigMixin): @@ -30,7 +30,8 @@ class VaeImageProcessor(ConfigMixin): Args: do_resize (`bool`, *optional*, defaults to `True`): - Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. + Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept + `height` and `width` arguments from `preprocess` method vae_scale_factor (`int`, *optional*, defaults to `8`): VAE scale factor. If `do_resize` is True, the image will be automatically resized to multiples of this factor. @@ -38,6 +39,8 @@ class VaeImageProcessor(ConfigMixin): Resampling filter to use when resizing the image. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1] + do_convert_rgb (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to RGB format. """ config_name = CONFIG_NAME @@ -49,11 +52,12 @@ def __init__( vae_scale_factor: int = 8, resample: str = "lanczos", do_normalize: bool = True, + do_convert_rgb: bool = False, ): super().__init__() @staticmethod - def numpy_to_pil(images): + def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image: """ Convert a numpy image or a batch of images to a PIL image. """ @@ -69,7 +73,19 @@ def numpy_to_pil(images): return pil_images @staticmethod - def numpy_to_pt(images): + def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray: + """ + Convert a PIL image or a list of PIL images to numpy arrays. + """ + if not isinstance(images, list): + images = [images] + images = [np.array(image).astype(np.float32) / 255.0 for image in images] + images = np.stack(images, axis=0) + + return images + + @staticmethod + def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: """ Convert a numpy image to a pytorch tensor """ @@ -80,9 +96,9 @@ def numpy_to_pt(images): return images @staticmethod - def pt_to_numpy(images): + def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: """ - Convert a numpy image to a pytorch tensor + Convert a pytorch tensor to a numpy image """ images = images.cpu().permute(0, 2, 3, 1).float().numpy() return images @@ -94,18 +110,46 @@ def normalize(images): """ return 2.0 * images - 1.0 - def resize(self, images: PIL.Image.Image) -> PIL.Image.Image: + @staticmethod + def denormalize(images): + """ + Denormalize an image array to [0,1] + """ + return (images / 2 + 0.5).clamp(0, 1) + + @staticmethod + def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image: + """ + Converts an image to RGB format. + """ + image = image.convert("RGB") + return image + + def resize( + self, + image: PIL.Image.Image, + height: Optional[int] = None, + width: Optional[int] = None, + ) -> PIL.Image.Image: """ Resize a PIL image. Both height and width will be downscaled to the next integer multiple of `vae_scale_factor` """ - w, h = images.size - w, h = (x - x % self.config.vae_scale_factor for x in (w, h)) # resize to integer multiple of vae_scale_factor - images = images.resize((w, h), resample=PIL_INTERPOLATION[self.config.resample]) - return images + if height is None: + height = image.height + if width is None: + width = image.width + + width, height = ( + x - x % self.config.vae_scale_factor for x in (width, height) + ) # resize to integer multiple of vae_scale_factor + image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample]) + return image def preprocess( self, image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + height: Optional[int] = None, + width: Optional[int] = None, ) -> torch.Tensor: """ Preprocess the image input, accepted formats are PIL images, numpy arrays or pytorch tensors" @@ -119,10 +163,11 @@ def preprocess( ) if isinstance(image[0], PIL.Image.Image): + if self.config.do_convert_rgb: + image = [self.convert_to_rgb(i) for i in image] if self.config.do_resize: - image = [self.resize(i) for i in image] - image = [np.array(i).astype(np.float32) / 255.0 for i in image] - image = np.stack(image, axis=0) # to np + image = [self.resize(i, height, width) for i in image] + image = self.pil_to_numpy(image) # to np image = self.numpy_to_pt(image) # to pt elif isinstance(image[0], np.ndarray): @@ -139,7 +184,12 @@ def preprocess( elif isinstance(image[0], torch.Tensor): image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0) - _, _, height, width = image.shape + _, channel, height, width = image.shape + + # don't need any preprocess if the image is latents + if channel == 4: + return image + if self.config.do_resize and ( height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0 ): @@ -165,17 +215,39 @@ def preprocess( def postprocess( self, - image, + image: torch.FloatTensor, output_type: str = "pil", + do_denormalize: Optional[List[bool]] = None, ): - if isinstance(image, torch.Tensor) and output_type == "pt": + if not isinstance(image, torch.Tensor): + raise ValueError( + f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor" + ) + if output_type not in ["latent", "pt", "np", "pil"]: + deprecation_message = ( + f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: " + "`pil`, `np`, `pt`, `latent`" + ) + deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) + output_type = "np" + + if output_type == "latent": + return image + + if do_denormalize is None: + do_denormalize = [self.config.do_normalize] * image.shape[0] + + image = torch.stack( + [self.denormalize(image[i]) if do_denormalize[i] else image[i] for i in range(image.shape[0])] + ) + + if output_type == "pt": return image image = self.pt_to_numpy(image) if output_type == "np": return image - elif output_type == "pil": + + if output_type == "pil": return self.numpy_to_pil(image) - else: - raise ValueError(f"Unsupported output_type {output_type}.") diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 0db716c012d8..6d273de5ca9d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -12,22 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings from collections import defaultdict from pathlib import Path from typing import Callable, Dict, List, Optional, Union import torch +import torch.nn.functional as F from huggingface_hub import hf_hub_download from .models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, + LoRAAttnAddedKVProcessor, LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SlicedAttnAddedKVProcessor, + XFormersAttnProcessor, ) from .utils import ( DIFFUSERS_CACHE, HF_HUB_OFFLINE, - TEXT_ENCODER_TARGET_MODULES, + TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, @@ -45,6 +54,8 @@ logger = logging.get_logger(__name__) +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" @@ -63,6 +74,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]): self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] + # we add a hook to state_dict() and load_state_dict() so that the # naming fits with `unet.attn_processors` def map_to(module, state_dict, *args, **kwargs): @@ -74,10 +88,19 @@ def map_to(module, state_dict, *args, **kwargs): return new_state_dict + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + + raise ValueError( + f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." + ) + def map_from(module, state_dict, *args, **kwargs): all_keys = list(state_dict.keys()) for key in all_keys: - replace_key = key.split(".processor")[0] + ".processor" + replace_key = remap_key(key, state_dict) new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") state_dict[new_key] = state_dict[key] del state_dict[key] @@ -87,6 +110,9 @@ def map_from(module, state_dict, *args, **kwargs): class UNet2DConditionLoadersMixin: + text_encoder_name = TEXT_ENCODER_NAME + unet_name = UNET_NAME + def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be @@ -158,6 +184,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): raise ValueError( @@ -225,6 +254,18 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: + is_new_lora_format = all( + key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() + ) + if is_new_lora_format: + # Strip the `"unet"` prefix. + is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys()) + if is_text_encoder_present: + warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)." + warnings.warn(warn_message) + unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)] + state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} + lora_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) @@ -232,11 +273,31 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict for key, value_dict in lora_grouped_dict.items(): rank = value_dict["to_k_lora.down.weight"].shape[0] - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - attn_processors[key] = LoRAAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + attn_processor = self + for sub_key in key.split("."): + attn_processor = getattr(attn_processor, sub_key) + + if isinstance( + attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) + ): + cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] + attn_processor_class = LoRAAttnAddedKVProcessor + else: + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): + attn_processor_class = LoRAXFormersAttnProcessor + else: + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + + attn_processors[key] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + network_alpha=network_alpha, ) attn_processors[key].load_state_dict(value_dict) elif is_custom_diffusion: @@ -405,7 +466,8 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): `str`: The converted prompt """ tokens = tokenizer.tokenize(prompt) - for token in tokens: + unique_tokens = set(tokens) + for token in unique_tokens: if token in tokenizer.added_tokens_encoder: replacement = token i = 1 @@ -418,7 +480,10 @@ def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): return prompt def load_textual_inversion( - self, pretrained_model_name_or_path: Union[str, Dict[str, torch.Tensor]], token: Optional[str] = None, **kwargs + self, + pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]], + token: Optional[Union[str, List[str]]] = None, + **kwargs, ): r""" Load textual inversion embeddings into the text encoder of stable diffusion pipelines. Both `diffusers` and @@ -431,7 +496,7 @@ def load_textual_inversion(
Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): + pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`): Can be either: - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. @@ -439,6 +504,14 @@ def load_textual_inversion( `"sd-concepts-library/low-poly-hd-logos-icons"`. - A path to a *directory* containing textual inversion weights, e.g. `./my_text_inversion_directory/`. + - A path to a *file* containing textual inversion weights, e.g. `./my_text_inversions.pt`. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + Or a list of those elements. + token (`str` or `List[str]`, *optional*): + Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a + list, then `token` must also be a list of equal length. weight_name (`str`, *optional*): Name of a custom weight file. This should be used in two cases: @@ -558,108 +631,138 @@ def load_textual_inversion( "framework": "pytorch", } - # 1. Load textual inversion file - model_file = None - # Let's first try to load .safetensors weights - if (use_safetensors and weight_name is None) or ( - weight_name is not None and weight_name.endswith(".safetensors") - ): - try: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - state_dict = safetensors.torch.load_file(model_file, device="cpu") - except Exception as e: - if not allow_pickle: - raise e + if not isinstance(pretrained_model_name_or_path, list): + pretrained_model_name_or_paths = [pretrained_model_name_or_path] + else: + pretrained_model_name_or_paths = pretrained_model_name_or_path - model_file = None + if isinstance(token, str): + tokens = [token] + elif token is None: + tokens = [None] * len(pretrained_model_name_or_paths) + else: + tokens = token - if model_file is None: - model_file = _get_model_file( - pretrained_model_name_or_path, - weights_name=weight_name or TEXT_INVERSION_NAME, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, + if len(pretrained_model_name_or_paths) != len(tokens): + raise ValueError( + f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}" + f"Make sure both lists have the same length." ) - state_dict = torch.load(model_file, map_location="cpu") - # 2. Load token and embedding correcly from file - if isinstance(state_dict, torch.Tensor): - if token is None: + valid_tokens = [t for t in tokens if t is not None] + if len(set(valid_tokens)) < len(valid_tokens): + raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}") + + token_ids_and_embeddings = [] + + for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens): + if not isinstance(pretrained_model_name_or_path, dict): + # 1. Load textual inversion file + model_file = None + # Let's first try to load .safetensors weights + if (use_safetensors and weight_name is None) or ( + weight_name is not None and weight_name.endswith(".safetensors") + ): + try: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME_SAFE, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = safetensors.torch.load_file(model_file, device="cpu") + except Exception as e: + if not allow_pickle: + raise e + + model_file = None + + if model_file is None: + model_file = _get_model_file( + pretrained_model_name_or_path, + weights_name=weight_name or TEXT_INVERSION_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + state_dict = torch.load(model_file, map_location="cpu") + else: + state_dict = pretrained_model_name_or_path + + # 2. Load token and embedding correcly from file + loaded_token = None + if isinstance(state_dict, torch.Tensor): + if token is None: + raise ValueError( + "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." + ) + embedding = state_dict + elif len(state_dict) == 1: + # diffusers + loaded_token, embedding = next(iter(state_dict.items())) + elif "string_to_param" in state_dict: + # A1111 + loaded_token = state_dict["name"] + embedding = state_dict["string_to_param"]["*"] + + if token is not None and loaded_token != token: + logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") + else: + token = loaded_token + + embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) + + # 3. Make sure we don't mess up the tokenizer or text encoder + vocab = self.tokenizer.get_vocab() + if token in vocab: raise ValueError( - "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`." + f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." ) - embedding = state_dict - elif len(state_dict) == 1: - # diffusers - loaded_token, embedding = next(iter(state_dict.items())) - elif "string_to_param" in state_dict: - # A1111 - loaded_token = state_dict["name"] - embedding = state_dict["string_to_param"]["*"] - - if token is not None and loaded_token != token: - logger.warn(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.") - else: - token = loaded_token - - embedding = embedding.to(dtype=self.text_encoder.dtype, device=self.text_encoder.device) + elif f"{token}_1" in vocab: + multi_vector_tokens = [token] + i = 1 + while f"{token}_{i}" in self.tokenizer.added_tokens_encoder: + multi_vector_tokens.append(f"{token}_{i}") + i += 1 - # 3. Make sure we don't mess up the tokenizer or text encoder - vocab = self.tokenizer.get_vocab() - if token in vocab: - raise ValueError( - f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder." - ) - elif f"{token}_1" in vocab: - multi_vector_tokens = [token] - i = 1 - while f"{token}_{i}" in self.tokenizer.added_tokens_encoder: - multi_vector_tokens.append(f"{token}_{i}") - i += 1 + raise ValueError( + f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." + ) - raise ValueError( - f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder." - ) + is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 - is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1 + if is_multi_vector: + tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])] + embeddings = [e for e in embedding] # noqa: C416 + else: + tokens = [token] + embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding] - if is_multi_vector: - tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])] - embeddings = [e for e in embedding] # noqa: C416 - else: - tokens = [token] - embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding] + # add tokens and get ids + self.tokenizer.add_tokens(tokens) + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + token_ids_and_embeddings += zip(token_ids, embeddings) - # add tokens and get ids - self.tokenizer.add_tokens(tokens) - token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + logger.info(f"Loaded textual inversion embedding for {token}.") - # resize token embeddings and set new embeddings + # resize token embeddings and set all new embeddings self.text_encoder.resize_token_embeddings(len(self.tokenizer)) - for token_id, embedding in zip(token_ids, embeddings): + for token_id, embedding in token_ids_and_embeddings: self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding - logger.info(f"Loaded textual inversion embedding for {token}.") - class LoraLoaderMixin: r""" @@ -672,8 +775,8 @@ class LoraLoaderMixin:
""" - text_encoder_name = "text_encoder" - unet_name = "unet" + text_encoder_name = TEXT_ENCODER_NAME + unet_name = UNET_NAME def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): r""" @@ -682,6 +785,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di + We support loading A1111 formatted LoRA checkpoints in a limited capacity. + This function is experimental and might change in the future. @@ -747,6 +852,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + # set lora scale to a reasonable default + self._lora_scale = 1.0 + if use_safetensors and not is_safetensors_available(): raise ValueError( "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors" @@ -806,25 +914,38 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di else: state_dict = pretrained_model_name_or_path_or_dict + # Convert kohya-ss Style LoRA attn procs to diffusers attn procs + network_alpha = None + if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): + state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) - - # Load the layers corresponding to UNet. - if all(key.startswith(self.unet_name) for key in keys): + if all(key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in keys): + # Load the layers corresponding to UNet. + unet_keys = [k for k in keys if k.startswith(self.unet_name)] logger.info(f"Loading {self.unet_name}.") - unet_lora_state_dict = {k: v for k, v in state_dict.items() if k.startswith(self.unet_name)} - self.unet.load_attn_procs(unet_lora_state_dict) + unet_lora_state_dict = { + k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys + } + self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) - # Load the layers corresponding to text encoder and make necessary adjustments. - elif all(key.startswith(self.text_encoder_name) for key in keys): - logger.info(f"Loading {self.text_encoder_name}.") + # Load the layers corresponding to text encoder and make necessary adjustments. + text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)] text_encoder_lora_state_dict = { - k: v for k, v in state_dict.items() if k.startswith(self.text_encoder_name) + k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } - attn_procs_text_encoder = self.load_attn_procs(text_encoder_lora_state_dict) - self._modify_text_encoder(attn_procs_text_encoder) + if len(text_encoder_lora_state_dict) > 0: + logger.info(f"Loading {self.text_encoder_name}.") + attn_procs_text_encoder = self._load_text_encoder_attn_procs( + text_encoder_lora_state_dict, network_alpha=network_alpha + ) + self._modify_text_encoder(attn_procs_text_encoder) + + # save lora attn procs of text encoder so that it can be easily retrieved + self._text_encoder_lora_attn_procs = attn_procs_text_encoder # Otherwise, we're dealing with the old format. This means the `state_dict` should only # contain the module names of the `unet` as its keys WITHOUT any prefix. @@ -832,11 +953,33 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() ): self.unet.load_attn_procs(state_dict) - deprecation_message = "You have saved the LoRA weights using the old format. This will be" - " deprecated soon. To convert the old LoRA weights to the new format, you can first load them" - " in a dictionary and then create a new dictionary like the following:" - " `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." - deprecate("legacy LoRA weights", "1.0.0", deprecation_message, standard_warn=False) + warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." + warnings.warn(warn_message) + + @property + def lora_scale(self) -> float: + # property function that returns the lora scale which can be set at run time by the pipeline. + # if _lora_scale has not been set, return 1 + return self._lora_scale if hasattr(self, "_lora_scale") else 1.0 + + @property + def text_encoder_lora_attn_procs(self): + if hasattr(self, "_text_encoder_lora_attn_procs"): + return self._text_encoder_lora_attn_procs + return + + def _remove_text_encoder_monkey_patch(self): + # Loop over the CLIPAttention module of text_encoder + for name, attn_module in self.text_encoder.named_modules(): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + # Loop over the LoRA layers + for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items(): + # Retrieve the q/k/v/out projection of CLIPAttention + module = attn_module.get_submodule(text_encoder_attr) + if hasattr(module, "old_forward"): + # restore original `forward` to remove monkey-patch + module.forward = module.old_forward + delattr(module, "old_forward") def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): r""" @@ -846,33 +989,46 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): attn_processors: Dict[str, `LoRAAttnProcessor`]: A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`]. """ - # Loop over the original attention modules. - for name, _ in self.text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - # Retrieve the module and its corresponding LoRA processor. - module = self.text_encoder.get_submodule(name) - # Construct a new function that performs the LoRA merging. We will monkey patch - # this forward pass. - lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - old_forward = module.forward - - def new_forward(x): - return old_forward(x) + lora_layer(x) - - # Monkey-patch. - module.forward = new_forward - - def _get_lora_layer_attribute(self, name: str) -> str: - if "q_proj" in name: - return "to_q_lora" - elif "v_proj" in name: - return "to_v_lora" - elif "k_proj" in name: - return "to_k_lora" - else: - return "to_out_lora" - def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + # First, remove any monkey-patch that might have been applied before + self._remove_text_encoder_monkey_patch() + + # Loop over the CLIPAttention module of text_encoder + for name, attn_module in self.text_encoder.named_modules(): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + # Loop over the LoRA layers + for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items(): + # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer. + module = attn_module.get_submodule(text_encoder_attr) + lora_layer = attn_processors[name].get_submodule(attn_proc_attr) + + # save old_forward to module that can be used to remove monkey-patch + old_forward = module.old_forward = module.forward + + # create a new scope that locks in the old_forward, lora_layer value for each new_forward function + # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060 + def make_new_forward(old_forward, lora_layer): + def new_forward(x): + result = old_forward(x) + self.lora_scale * lora_layer(x) + return result + + return new_forward + + # Monkey-patch. + module.forward = make_new_forward(old_forward, lora_layer) + + @property + def _lora_attn_processor_attr_to_text_encoder_attr(self): + return { + "to_q_lora": "q_proj", + "to_k_lora": "k_proj", + "to_v_lora": "v_proj", + "to_out_lora": "out_proj", + } + + def _load_text_encoder_attn_procs( + self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs + ): r""" Load pretrained attention processor layers for [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel). @@ -945,6 +1101,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): raise ValueError( @@ -1021,8 +1178,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - attn_processors[key] = LoRAAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + attn_processors[key] = attn_processor_class( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + network_alpha=network_alpha, ) attn_processors[key].load_state_dict(value_dict) @@ -1039,7 +1202,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict def save_lora_weights( self, save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, torch.nn.Module] = None, + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, @@ -1052,13 +1215,14 @@ def save_lora_weights( Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. - unet_lora_layers (`Dict[str, torch.nn.Module`]): + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the - serialization process easier and cleaner. - text_encoder_lora_layers (`Dict[str, torch.nn.Module`]): + serialization process easier and cleaner. Values can be both LoRA torch.nn.Modules layers or torch + weights. + text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state - dict. + dict. Values can be both LoRA torch.nn.Modules layers or torch weights. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on @@ -1086,15 +1250,22 @@ def save_function(weights, filename): # Create a flat dictionary. state_dict = {} if unet_lora_layers is not None: - unet_lora_state_dict = { - f"{self.unet_name}.{module_name}": param - for module_name, param in unet_lora_layers.state_dict().items() - } + weights = ( + unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers + ) + + unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()} state_dict.update(unet_lora_state_dict) + if text_encoder_lora_layers is not None: + weights = ( + text_encoder_lora_layers.state_dict() + if isinstance(text_encoder_lora_layers, torch.nn.Module) + else text_encoder_lora_layers + ) + text_encoder_lora_state_dict = { - f"{self.text_encoder_name}.{module_name}": param - for module_name, param in text_encoder_lora_layers.state_dict().items() + f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict) @@ -1108,6 +1279,56 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + def _convert_kohya_lora_to_diffusers(self, state_dict): + unet_state_dict = {} + te_state_dict = {} + network_alpha = None + + for key, value in state_dict.items(): + if "lora_down" in key: + lora_name = key.split(".")[0] + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + if lora_name_alpha in state_dict: + alpha = state_dict[lora_name_alpha].item() + if network_alpha is None: + network_alpha = alpha + elif network_alpha != alpha: + raise ValueError("Network alpha is not consistent") + + if lora_name.startswith("lora_unet_"): + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + unet_state_dict[diffusers_name] = value + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif lora_name.startswith("lora_te_"): + diffusers_name = key.replace("lora_te_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = value + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + + unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} + new_state_dict = {**unet_state_dict, **te_state_dict} + return new_state_dict, network_alpha + class FromCkptMixin: """This helper class allows to directly load .ckpt stable diffusion file_extension @@ -1150,10 +1371,10 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. - use_safetensors (`bool`, *optional* ): - If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the - default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if - `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the pipeline will load the `safetensors` weights if they're available **and** if the + `safetensors` library is installed. If set to `True`, the pipeline will forcibly load the models from + `safetensors` weights. If set to `False` the pipeline will *not* use `safetensors`. extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for @@ -1226,28 +1447,30 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs): file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] from_safetensors = file_extension == "safetensors" - if from_safetensors and use_safetensors is True: + if from_safetensors and use_safetensors is False: raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") # TODO: For now we only support stable diffusion stable_unclip = None + model_type = None controlnet = False if pipeline_name == "StableDiffusionControlNetPipeline": - model_type = "FrozenCLIPEmbedder" + # Model type will be inferred from the checkpoint. controlnet = True elif "StableDiffusion" in pipeline_name: - model_type = "FrozenCLIPEmbedder" + # Model type will be inferred from the checkpoint. + pass elif pipeline_name == "StableUnCLIPPipeline": - model_type == "FrozenOpenCLIPEmbedder" + model_type = "FrozenOpenCLIPEmbedder" stable_unclip = "txt2img" elif pipeline_name == "StableUnCLIPImg2ImgPipeline": - model_type == "FrozenOpenCLIPEmbedder" + model_type = "FrozenOpenCLIPEmbedder" stable_unclip = "img2img" elif pipeline_name == "PaintByExamplePipeline": - model_type == "PaintByExample" + model_type = "PaintByExample" elif pipeline_name == "LDMTextToImagePipeline": - model_type == "LDMTextToImage" + model_type = "LDMTextToImage" else: raise ValueError(f"Unhandled pipeline class: {pipeline_name}") @@ -1260,8 +1483,8 @@ def from_ckpt(cls, pretrained_model_link_or_path, **kwargs): ckpt_path = Path(pretrained_model_link_or_path) if not ckpt_path.is_file(): # get repo_id and (potentially nested) file path of ckpt in repo - repo_id = str(Path().joinpath(*ckpt_path.parts[:2])) - file_path = str(Path().joinpath(*ckpt_path.parts[2:])) + repo_id = "/".join(ckpt_path.parts[:2]) + file_path = "/".join(ckpt_path.parts[2:]) if file_path.startswith("blob/"): file_path = file_path[len("blob/") :] diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py new file mode 100644 index 000000000000..64759b706e2f --- /dev/null +++ b/src/diffusers/models/activations.py @@ -0,0 +1,12 @@ +from torch import nn + + +def get_activation(act_fn): + if act_fn in ["swish", "silu"]: + return nn.SiLU() + elif act_fn == "mish": + return nn.Mish() + elif act_fn == "gelu": + return nn.GELU() + else: + raise ValueError(f"Unsupported activation function: {act_fn}") diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index fb5f6f48b324..8805257ebe9a 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -11,188 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Callable, Optional +from typing import Any, Dict, Optional import torch import torch.nn.functional as F from torch import nn -from ..utils.import_utils import is_xformers_available +from ..utils import maybe_allow_in_graph +from .activations import get_activation from .attention_processor import Attention from .embeddings import CombinedTimestepLabelEmbeddings -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (`int`): The number of channels in the input and output. - num_head_channels (`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. - """ - - # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - norm_num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, bias=True) - - self._use_memory_efficient_attention_xformers = False - self._use_2_0_attn = True - self._attention_op = None - - def reshape_heads_to_batch_dim(self, tensor, merge_head_and_batch=True): - batch_size, seq_len, dim = tensor.shape - head_size = self.num_heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3) - if merge_head_and_batch: - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor, unmerge_head_and_batch=True): - head_size = self.num_heads - - if unmerge_head_and_batch: - batch_head_size, seq_len, dim = tensor.shape - batch_size = batch_head_size // head_size - - tensor = tensor.reshape(batch_size, head_size, seq_len, dim) - else: - batch_size, _, seq_len, dim = tensor.shape - - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim * head_size) - return tensor - - def set_use_memory_efficient_attention_xformers( - self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None - ): - if use_memory_efficient_attention_xformers: - if not is_xformers_available(): - raise ModuleNotFoundError( - ( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers" - ), - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" - " only available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - self._attention_op = attention_op - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - scale = 1 / math.sqrt(self.channels / self.num_heads) - - _use_2_0_attn = self._use_2_0_attn and not self._use_memory_efficient_attention_xformers - use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") and _use_2_0_attn - - query_proj = self.reshape_heads_to_batch_dim(query_proj, merge_head_and_batch=not use_torch_2_0_attn) - key_proj = self.reshape_heads_to_batch_dim(key_proj, merge_head_and_batch=not use_torch_2_0_attn) - value_proj = self.reshape_heads_to_batch_dim(value_proj, merge_head_and_batch=not use_torch_2_0_attn) - - if self._use_memory_efficient_attention_xformers: - # Memory efficient attention - hidden_states = xformers.ops.memory_efficient_attention( - query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op, scale=scale - ) - hidden_states = hidden_states.to(query_proj.dtype) - elif use_torch_2_0_attn: - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query_proj, key_proj, value_proj, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.to(query_proj.dtype) - else: - attention_scores = torch.baddbmm( - torch.empty( - query_proj.shape[0], - query_proj.shape[1], - key_proj.shape[1], - dtype=query_proj.dtype, - device=query_proj.device, - ), - query_proj, - key_proj.transpose(-1, -2), - beta=0, - alpha=scale, - ) - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - hidden_states = torch.bmm(attention_probs, value_proj) - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states, unmerge_head_and_batch=not use_torch_2_0_attn) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states - - +@maybe_allow_in_graph class BasicTransformerBlock(nn.Module): r""" A basic Transformer block. @@ -290,13 +121,13 @@ def __init__( def forward( self, - hidden_states, - attention_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - timestep=None, - cross_attention_kwargs=None, - class_labels=None, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, ): # Notice that normalization is always applied before the real computation in the following blocks. # 1. Self-Attention @@ -325,8 +156,6 @@ def forward( norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) - # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly - # prepare attention mask here attn_output = self.attn2( norm_hidden_states, @@ -517,15 +346,11 @@ def __init__( super().__init__() self.num_groups = num_groups self.eps = eps - self.act = None - if act_fn == "swish": - self.act = lambda x: F.silu(x) - elif act_fn == "mish": - self.act = nn.Mish() - elif act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "gelu": - self.act = nn.GELU() + + if act_fn is None: + self.act = None + else: + self.act = get_activation(act_fn) self.linear = nn.Linear(embedding_dim, out_dim * 2) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index b8787aed91f2..e0404a83cc9a 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from torch import nn -from ..utils import deprecate, logging +from ..utils import deprecate, logging, maybe_allow_in_graph from ..utils.import_utils import is_xformers_available @@ -31,6 +31,7 @@ xformers = None +@maybe_allow_in_graph class Attention(nn.Module): r""" A cross attention layer. @@ -60,9 +61,14 @@ def __init__( cross_attention_norm_num_groups: int = 32, added_kv_proj_dim: Optional[int] = None, norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, out_bias: bool = True, scale_qk: bool = True, only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block=False, processor: Optional["AttnProcessor"] = None, ): super().__init__() @@ -70,8 +76,15 @@ def __init__( cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.upcast_attention = upcast_attention self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection - self.scale = dim_head**-0.5 if scale_qk else 1.0 + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.heads = heads # for slice_size > 0 the attention score computation @@ -88,10 +101,15 @@ def __init__( ) if norm_num_groups is not None: - self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) else: self.group_norm = None + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + if cross_attention_norm is None: self.norm_cross = None elif cross_attention_norm == "layer_norm": @@ -139,7 +157,7 @@ def __init__( # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 if processor is None: processor = ( - AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor() + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() ) self.set_processor(processor) @@ -147,22 +165,29 @@ def set_use_memory_efficient_attention_xformers( self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None ): is_lora = hasattr(self, "processor") and isinstance( - self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor) + self.processor, + (LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor, LoRAAttnAddedKVProcessor), ) is_custom_diffusion = hasattr(self, "processor") and isinstance( self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor) ) + is_added_kv_processor = hasattr(self, "processor") and isinstance( + self.processor, + ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + SlicedAttnAddedKVProcessor, + XFormersAttnAddedKVProcessor, + LoRAAttnAddedKVProcessor, + ), + ) if use_memory_efficient_attention_xformers: - if self.added_kv_proj_dim is not None: - # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP - # which uses this type of cross attention ONLY because the attention mask of format - # [0, ..., -10.000, ..., 0, ...,] is not supported + if is_added_kv_processor and (is_lora or is_custom_diffusion): raise NotImplementedError( - "Memory efficient attention with `xformers` is currently not supported when" - " `self.added_kv_proj_dim` is defined." + f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}" ) - elif not is_xformers_available(): + if not is_xformers_available(): raise ModuleNotFoundError( ( "Refer to https://github.com/facebookresearch/xformers for more information on how to install" @@ -187,6 +212,8 @@ def set_use_memory_efficient_attention_xformers( raise e if is_lora: + # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers + # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0? processor = LoRAXFormersAttnProcessor( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, @@ -206,11 +233,23 @@ def set_use_memory_efficient_attention_xformers( processor.load_state_dict(self.processor.state_dict()) if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) + elif is_added_kv_processor: + # TODO(Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP + # which uses this type of cross attention ONLY because the attention mask of format + # [0, ..., -10.000, ..., 0, ...,] is not supported + # throw warning + logger.info( + "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation." + ) + processor = XFormersAttnAddedKVProcessor(attention_op=attention_op) else: processor = XFormersAttnProcessor(attention_op=attention_op) else: if is_lora: - processor = LoRAAttnProcessor( + attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + processor = attn_processor_class( hidden_size=self.processor.hidden_size, cross_attention_dim=self.processor.cross_attention_dim, rank=self.processor.rank, @@ -228,7 +267,15 @@ def set_use_memory_efficient_attention_xformers( if hasattr(self.processor, "to_k_custom_diffusion"): processor.to(self.processor.to_k_custom_diffusion.weight.device) else: - processor = AttnProcessor() + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() + if hasattr(F, "scaled_dot_product_attention") and self.scale_qk + else AttnProcessor() + ) self.set_processor(processor) @@ -243,7 +290,13 @@ def set_attention_slice(self, slice_size): elif self.added_kv_proj_dim is not None: processor = AttnAddedKVProcessor() else: - processor = AttnProcessor() + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) self.set_processor(processor) @@ -312,11 +365,14 @@ def get_attention_scores(self, query, key, attention_mask=None): beta=beta, alpha=self.scale, ) + del baddbmm_input if self.upcast_softmax: attention_scores = attention_scores.float() attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + attention_probs = attention_probs.to(dtype) return attention_probs @@ -338,7 +394,8 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, if attention_mask is None: return attention_mask - if attention_mask.shape[-1] != target_length: + current_length: int = attention_mask.shape[-1] + if current_length != target_length: if attention_mask.device.type == "mps": # HACK: MPS: Does not support padding by greater than dimension of input tensor. # Instead, we can manually construct the padding tensor. @@ -346,6 +403,10 @@ def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat([attention_mask, padding], dim=2) else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) if out_dim == 3: @@ -378,17 +439,37 @@ def norm_encoder_hidden_states(self, encoder_hidden_states): class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, + temb=None, ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -412,11 +493,19 @@ def __call__( # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): super().__init__() if rank > min(in_features, out_features): @@ -424,6 +513,10 @@ def __init__(self, in_features, out_features, rank=4): self.down = nn.Linear(in_features, rank, bias=False) self.up = nn.Linear(rank, out_features, bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) @@ -435,28 +528,61 @@ def forward(self, hidden_states): down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + return up_hidden_states.to(orig_dtype) class LoRAAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + r""" + Processor for implementing the LoRA attention mechanism. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query) @@ -480,10 +606,36 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states class CustomDiffusionAttnProcessor(nn.Module): + r""" + Processor for implementing attention for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + """ + def __init__( self, train_kv=True, @@ -562,6 +714,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a class AttnAddedKVProcessor: + r""" + Processor for performing attention-related computations with extra learnable key and value matrices for the text + encoder. + """ + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): residual = hidden_states hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) @@ -611,6 +768,11 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a class AttnAddedKVProcessor2_0: + r""" + Processor for performing scaled dot-product attention (enabled by default if you're using PyTorch 2.0), with extra + learnable key and value matrices for the text encoder. + """ + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -668,16 +830,203 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a return hidden_states -class XFormersAttnProcessor: +class LoRAAttnAddedKVProcessor(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text + encoder. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + scale * self.add_k_proj_lora( + encoder_hidden_states + ) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + scale * self.add_v_proj_lora( + encoder_hidden_states + ) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + scale * self.to_k_lora(hidden_states) + value = attn.to_v(hidden_states) + scale * self.to_v_lora(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersAttnAddedKVProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + def __init__(self, attention_op: Optional[Callable] = None): self.attention_op = attention_op def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): - batch_size, sequence_length, _ = ( + residual = hidden_states + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) + batch_size, sequence_length, _ = hidden_states.shape + + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + query = attn.head_to_batch_dim(query) + + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) + + if not attn.only_cross_attention: + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) + else: + key = encoder_hidden_states_key_proj + value = encoder_hidden_states_value_proj + + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale + ) + hidden_states = hidden_states.to(query.dtype) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) + hidden_states = hidden_states + residual + + return hidden_states + + +class XFormersAttnProcessor: + r""" + Processor for implementing memory efficient attention using xFormers. + + Args: + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + """ + + def __init__(self, attention_op: Optional[Callable] = None): + self.attention_op = attention_op + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, key_tokens, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask(attention_mask, key_tokens, batch_size) + if attention_mask is not None: + # expand our mask's singleton query_tokens dimension: + # [batch*heads, 1, key_tokens] -> + # [batch*heads, query_tokens, key_tokens] + # so that it can be added as a bias onto the attention scores that xformers computes: + # [batch*heads, query_tokens, key_tokens] + # we do this explicitly because xformers doesn't broadcast the singleton dimension for us. + _, query_tokens, _ = hidden_states.shape + attention_mask = attention_mask.expand(-1, query_tokens, -1) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) @@ -703,15 +1052,46 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, + attn: Attention, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) @@ -723,6 +1103,9 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -751,11 +1134,42 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a hidden_states = attn.to_out[0](hidden_states) # dropout hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states class LoRAXFormersAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None): + r""" + Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers. + + Args: + hidden_size (`int`, *optional*): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to + use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best + operator. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + + """ + + def __init__( + self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None + ): super().__init__() self.hidden_size = hidden_size @@ -763,17 +1177,33 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.rank = rank self.attention_op = attention_op - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__( + self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) query = attn.head_to_batch_dim(query).contiguous() @@ -798,10 +1228,131 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class LoRAAttnProcessor2_0(nn.Module): + r""" + Processor for implementing the LoRA attention mechanism using PyTorch 2.0's memory-efficient scaled dot-product + attention. + + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*): + The number of channels in the `encoder_hidden_states`. + rank (`int`, defaults to 4): + The dimension of the LoRA update matrices. + network_alpha (`int`, *optional*): + Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.rank = rank + + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + + def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + inner_dim = hidden_states.shape[-1] + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) + + head_dim = inner_dim // attn.heads + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states class CustomDiffusionXFormersAttnProcessor(nn.Module): + r""" + Processor for implementing memory efficient attention using xFormers for the Custom Diffusion method. + + Args: + train_kv (`bool`, defaults to `True`): + Whether to newly train the key and value matrices corresponding to the text features. + train_q_out (`bool`, defaults to `True`): + Whether to newly train query matrices corresponding to the latent image features. + hidden_size (`int`, *optional*, defaults to `None`): + The hidden size of the attention layer. + cross_attention_dim (`int`, *optional*, defaults to `None`): + The number of channels in the `encoder_hidden_states`. + out_bias (`bool`, defaults to `True`): + Whether to include the bias parameter in `train_q_out`. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + attention_op (`Callable`, *optional*, defaults to `None`): + The base + [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use + as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator. + """ + def __init__( self, train_kv=True, @@ -887,15 +1438,35 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a class SlicedAttnProcessor: + r""" + Processor for implementing sliced attention. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + def __init__(self, slice_size): self.slice_size = slice_size def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + query = attn.to_q(hidden_states) dim = query.shape[-1] query = attn.head_to_batch_dim(query) @@ -936,15 +1507,36 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a # dropout hidden_states = attn.to_out[1](hidden_states) + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + return hidden_states class SlicedAttnAddedKVProcessor: + r""" + Processor for implementing sliced attention with extra learnable key and value matrices for the text encoder. + + Args: + slice_size (`int`, *optional*): + The number of steps to compute attention. Uses as many slices as `attention_head_dim // slice_size`, and + `attention_head_dim` must be a multiple of the `slice_size`. + """ + def __init__(self, slice_size): self.slice_size = slice_size - def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) batch_size, sequence_length, _ = hidden_states.shape @@ -1019,8 +1611,34 @@ def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0, + XFormersAttnAddedKVProcessor, LoRAAttnProcessor, LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + LoRAAttnAddedKVProcessor, CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, ] + + +class SpatialNorm(nn.Module): + """ + Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002 + """ + + def __init__( + self, + f_channels, + zq_channels, + ): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f, zq): + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 1a8a204d80ce..a4894e78c43f 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -196,12 +196,14 @@ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[Decode return DecoderOutput(sample=decoded) def blend_v(self, a, b, blend_extent): - for y in range(min(a.shape[2], b.shape[2], blend_extent)): + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) return b def blend_h(self, a, b, blend_extent): - for x in range(min(a.shape[3], b.shape[3], blend_extent)): + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) return b diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 3ffbb04eb222..0b0ce0be547f 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -498,7 +498,7 @@ def forward( # timesteps does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) @@ -517,7 +517,7 @@ def forward( controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) - sample += controlnet_cond + sample = sample + controlnet_cond # 3. down down_block_res_samples = (sample,) @@ -551,21 +551,22 @@ def forward( for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): down_block_res_sample = controlnet_block(down_block_res_sample) - controlnet_down_block_res_samples += (down_block_res_sample,) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = controlnet_down_block_res_samples mid_block_res_sample = self.controlnet_mid_block(sample) # 6. scaling - if guess_mode: - scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1) # 0.1 to 1.0 - scales *= conditioning_scale + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0 + + scales = scales * conditioning_scale down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)] - mid_block_res_sample *= scales[-1] # last one + mid_block_res_sample = mid_block_res_sample * scales[-1] # last one else: down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples] - mid_block_res_sample *= conditioning_scale + mid_block_res_sample = mid_block_res_sample * conditioning_scale if self.config.global_pool_conditions: down_block_res_samples = [ diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index fa88bce305e6..4dd16f0dd5ff 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -18,6 +18,8 @@ import torch from torch import nn +from .activations import get_activation + def get_timestep_embedding( timesteps: torch.Tensor, @@ -171,14 +173,7 @@ def __init__( else: self.cond_proj = None - if act_fn == "silu": - self.act = nn.SiLU() - elif act_fn == "mish": - self.act = nn.Mish() - elif act_fn == "gelu": - self.act = nn.GELU() - else: - raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") + self.act = get_activation(act_fn) if out_dim is not None: time_embed_dim_out = out_dim @@ -188,14 +183,8 @@ def __init__( if post_act_fn is None: self.post_act = None - elif post_act_fn == "silu": - self.post_act = nn.SiLU() - elif post_act_fn == "mish": - self.post_act = nn.Mish() - elif post_act_fn == "gelu": - self.post_act = nn.GELU() else: - raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") + self.post_act = get_activation(post_act_fn) def forward(self, sample, condition=None): if condition is not None: @@ -352,7 +341,7 @@ def token_drop(self, labels, force_drop_ids=None): labels = torch.where(drop_ids, self.num_classes, labels) return labels - def forward(self, labels, force_drop_ids=None): + def forward(self, labels: torch.LongTensor, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (self.training and use_dropout) or (force_drop_ids is not None): labels = self.token_drop(labels, force_drop_ids) @@ -360,6 +349,33 @@ def forward(self, labels, force_drop_ids=None): return embeddings +class TextImageProjection(nn.Module): + def __init__( + self, + text_embed_dim: int = 1024, + image_embed_dim: int = 768, + cross_attention_dim: int = 768, + num_image_text_embeds: int = 10, + ): + super().__init__() + + self.num_image_text_embeds = num_image_text_embeds + self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim) + self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim) + + def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): + batch_size = text_embeds.shape[0] + + # image + image_text_embeds = self.image_embeds(image_embeds) + image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1) + + # text + text_embeds = self.text_proj(text_embeds) + + return torch.cat([image_text_embeds, text_embeds], dim=1) + + class CombinedTimestepLabelEmbeddings(nn.Module): def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): super().__init__() @@ -395,6 +411,24 @@ def forward(self, hidden_states): return hidden_states +class TextImageTimeEmbedding(nn.Module): + def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536): + super().__init__() + self.text_proj = nn.Linear(text_embed_dim, time_embed_dim) + self.text_norm = nn.LayerNorm(time_embed_dim) + self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) + + def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor): + # text + time_text_embeds = self.text_proj(text_embeds) + time_text_embeds = self.text_norm(time_text_embeds) + + # image + time_image_embeds = self.image_proj(image_embeds) + + return time_image_embeds + time_text_embeds + + class AttentionPooling(nn.Module): # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 521e99fdd69c..f6d6bc5711cd 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -17,6 +17,7 @@ import inspect import itertools import os +import re from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union @@ -77,8 +78,14 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: def get_parameter_dtype(parameter: torch.nn.Module): try: - parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) - return next(parameters_and_buffers).dtype + params = tuple(parameter.parameters()) + if len(params) > 0: + return params[0].dtype + + buffers = tuple(parameter.buffers()) + if len(buffers) > 0: + return buffers[0].dtype + except StopIteration: # For torch.nn.DataParallel compatibility in PyTorch 1.5 @@ -156,6 +163,7 @@ class ModelMixin(torch.nn.Module): config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _supports_gradient_checkpointing = False + _keys_to_ignore_on_load_unexpected = None def __init__(self): super().__init__() @@ -392,6 +400,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU + RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to + `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the @@ -400,10 +417,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant (`str`, *optional*): If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is ignored when using `from_flax`. - use_safetensors (`bool`, *optional* ): - If set to `True`, the pipeline will forcibly load the models from `safetensors` weights. If set to - `None` (the default). The pipeline will load using `safetensors` if safetensors weights are available - *and* if `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from + `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. @@ -433,6 +450,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P torch_dtype = kwargs.pop("torch_dtype", None) subfolder = kwargs.pop("subfolder", None) device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) @@ -504,6 +524,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder, device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, user_agent=user_agent, **kwargs, ) @@ -577,6 +600,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P if device_map is None: param_device = "cpu" state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) if len(missing_keys) > 0: @@ -586,6 +610,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" " those weights or else make sure your checkpoint file is correct." ) + unexpected_keys = [] empty_state_dict = model.state_dict() for param_name, param in state_dict.items(): @@ -593,6 +618,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P inspect.signature(set_module_tensor_to_device).parameters.keys() ) + if param_name not in empty_state_dict: + unexpected_keys.append(param_name) + continue + if empty_state_dict[param_name].shape != param.shape: raise ValueError( f"Cannot load {pretrained_model_name_or_path} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example." @@ -604,10 +633,28 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) else: set_module_tensor_to_device(model, param_name, param_device, value=param) + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + if len(unexpected_keys) > 0: + logger.warn( + f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" + ) + else: # else let accelerate handle loading and dispatching. # Load weights and dispatch according to the device_map # by default the device_map is None and the weights are loaded on the CPU - accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype) + accelerate.load_checkpoint_and_dispatch( + model, + model_file, + device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + ) loading_info = { "missing_keys": [], @@ -619,6 +666,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = cls.from_config(config, **unused_kwargs) state_dict = load_state_dict(model_file, variant=variant) + model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, @@ -797,3 +845,47 @@ def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) else: return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + def _convert_deprecated_attention_blocks(self, state_dict): + deprecated_attention_block_paths = [] + + def recursive_find_attn_block(name, module): + if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block: + deprecated_attention_block_paths.append(name) + + for sub_name, sub_module in module.named_children(): + sub_name = sub_name if name == "" else f"{name}.{sub_name}" + recursive_find_attn_block(sub_name, sub_module) + + recursive_find_attn_block("", self) + + # NOTE: we have to check if the deprecated parameters are in the state dict + # because it is possible we are loading from a state dict that was already + # converted + + for path in deprecated_attention_block_paths: + # group_norm path stays the same + + # query -> to_q + if f"{path}.query.weight" in state_dict: + state_dict[f"{path}.to_q.weight"] = state_dict.pop(f"{path}.query.weight") + if f"{path}.query.bias" in state_dict: + state_dict[f"{path}.to_q.bias"] = state_dict.pop(f"{path}.query.bias") + + # key -> to_k + if f"{path}.key.weight" in state_dict: + state_dict[f"{path}.to_k.weight"] = state_dict.pop(f"{path}.key.weight") + if f"{path}.key.bias" in state_dict: + state_dict[f"{path}.to_k.bias"] = state_dict.pop(f"{path}.key.bias") + + # value -> to_v + if f"{path}.value.weight" in state_dict: + state_dict[f"{path}.to_v.weight"] = state_dict.pop(f"{path}.value.weight") + if f"{path}.value.bias" in state_dict: + state_dict[f"{path}.to_v.bias"] = state_dict.pop(f"{path}.value.bias") + + # proj_attn -> to_out.0 + if f"{path}.proj_attn.weight" in state_dict: + state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight") + if f"{path}.proj_attn.bias" in state_dict: + state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias") diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index d9d539959c09..52f01552c528 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -20,18 +20,23 @@ import torch.nn as nn import torch.nn.functional as F +from .activations import get_activation from .attention import AdaGroupNorm +from .attention_processor import SpatialNorm class Upsample1D(nn.Module): - """ - An upsampling layer with an optional convolution. + """A 1D upsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -48,28 +53,31 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann elif use_conv: self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) - def forward(self, x): - assert x.shape[1] == self.channels + def forward(self, inputs): + assert inputs.shape[1] == self.channels if self.use_conv_transpose: - return self.conv(x) + return self.conv(inputs) - x = F.interpolate(x, scale_factor=2.0, mode="nearest") + outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest") if self.use_conv: - x = self.conv(x) + outputs = self.conv(outputs) - return x + return outputs class Downsample1D(nn.Module): - """ - A downsampling layer with an optional convolution. + """A 1D downsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -93,14 +101,17 @@ def forward(self, x): class Upsample2D(nn.Module): - """ - An upsampling layer with an optional convolution. + """A 2D upsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - use_conv_transpose: - out_channels: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + use_conv_transpose (`bool`, default `False`): + option to use a convolution transpose. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -162,14 +173,17 @@ def forward(self, hidden_states, output_size=None): class Downsample2D(nn.Module): - """ - A downsampling layer with an optional convolution. + """A 2D downsampling layer with an optional convolution. Parameters: - channels: channels in the inputs and outputs. - use_conv: a bool determining if a convolution is applied. - out_channels: - padding: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + padding (`int`, default `1`): + padding for the convolution. """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -209,6 +223,19 @@ def forward(self, hidden_states): class FirUpsample2D(nn.Module): + """A 2D FIR upsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_channels = out_channels if out_channels else channels @@ -309,6 +336,19 @@ def forward(self, hidden_states): class FirDownsample2D(nn.Module): + """A 2D FIR downsampling layer with an optional convolution. + + Parameters: + channels (`int`): + number of channels in the inputs and outputs. + use_conv (`bool`, default `False`): + option to use a convolution. + out_channels (`int`, optional): + number of output channels. Defaults to `channels`. + fir_kernel (`tuple`, default `(1, 3, 3, 1)`): + kernel for the FIR filter. + """ + def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): super().__init__() out_channels = out_channels if out_channels else channels @@ -395,7 +435,8 @@ def forward(self, x): x = F.pad(x, (self.pad,) * 4, self.pad_mode) weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) + weight[indices, indices] = kernel return F.conv2d(x, weight, stride=2) @@ -411,7 +452,8 @@ def forward(self, x): x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) indices = torch.arange(x.shape[1], device=x.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(x.shape[1], -1, -1) + weight[indices, indices] = kernel return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) @@ -460,7 +502,7 @@ def __init__( eps=1e-6, non_linearity="swish", skip_time_act=False, - time_embedding_norm="default", # default, scale_shift, ada_group + time_embedding_norm="default", # default, scale_shift, ada_group, spatial kernel=None, output_scale_factor=1.0, use_in_shortcut=None, @@ -487,6 +529,8 @@ def __init__( if self.time_embedding_norm == "ada_group": self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm1 = SpatialNorm(in_channels, temb_channels) else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) @@ -497,7 +541,7 @@ def __init__( self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) - elif self.time_embedding_norm == "ada_group": + elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") @@ -506,6 +550,8 @@ def __init__( if self.time_embedding_norm == "ada_group": self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) + elif self.time_embedding_norm == "spatial": + self.norm2 = SpatialNorm(out_channels, temb_channels) else: self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) @@ -513,14 +559,7 @@ def __init__( conv_2d_out_channels = conv_2d_out_channels or out_channels self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - elif non_linearity == "gelu": - self.nonlinearity = nn.GELU() + self.nonlinearity = get_activation(non_linearity) self.upsample = self.downsample = None if self.up: @@ -551,7 +590,7 @@ def __init__( def forward(self, input_tensor, temb): hidden_states = input_tensor - if self.time_embedding_norm == "ada_group": + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm1(hidden_states, temb) else: hidden_states = self.norm1(hidden_states) @@ -579,7 +618,7 @@ def forward(self, input_tensor, temb): if temb is not None and self.time_embedding_norm == "default": hidden_states = hidden_states + temb - if self.time_embedding_norm == "ada_group": + if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": hidden_states = self.norm2(hidden_states, temb) else: hidden_states = self.norm2(hidden_states) @@ -601,11 +640,6 @@ def forward(self, input_tensor, temb): return output_tensor -class Mish(torch.nn.Module): - def forward(self, hidden_states): - return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) - - # unet_rl.py def rearrange_dims(tensor): if len(tensor.shape) == 2: diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index fde1014bd2e7..ec4cb371845f 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Optional +from typing import Any, Dict, Optional import torch import torch.nn.functional as F @@ -213,11 +213,13 @@ def __init__( def forward( self, - hidden_states, - encoder_hidden_states=None, - timestep=None, - class_labels=None, - cross_attention_kwargs=None, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ): """ @@ -228,11 +230,17 @@ def forward( encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. - timestep ( `torch.long`, *optional*): + timestep ( `torch.LongTensor`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels conditioning. + encoder_attention_mask ( `torch.Tensor`, *optional* ). + Cross-attention mask, applied to encoder_hidden_states. Two formats supported: + Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0 + = keep, -10000 = discard. + If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -241,6 +249,29 @@ def forward( [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 1. Input if self.is_input_continuous: batch, _, height, width = hidden_states.shape @@ -264,7 +295,9 @@ def forward( for block in self.transformer_blocks: hidden_states = block( hidden_states, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, class_labels=class_labels, diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index a0f0e58f9103..3c04bffeeacc 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -17,6 +17,7 @@ import torch.nn.functional as F from torch import nn +from .activations import get_activation from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims @@ -55,14 +56,10 @@ def __init__( self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: + if non_linearity is None: self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) self.downsample = None if add_downsample: @@ -119,14 +116,10 @@ def __init__( self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: + if non_linearity is None: self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) self.upsample = None if add_upsample: @@ -194,14 +187,10 @@ def __init__( self.resnets = nn.ModuleList(resnets) - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = nn.Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - else: + if non_linearity is None: self.nonlinearity = None + else: + self.nonlinearity = get_activation(non_linearity) self.upsample = None if add_upsample: @@ -232,10 +221,7 @@ def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): super().__init__() self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) - if act_fn == "silu": - self.final_conv1d_act = nn.SiLU() - if act_fn == "mish": - self.final_conv1d_act = nn.Mish() + self.final_conv1d_act = get_activation(act_fn) self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) def forward(self, hidden_states, temb=None): @@ -300,7 +286,8 @@ def forward(self, hidden_states): hidden_states = F.pad(hidden_states, (self.pad,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel return F.conv1d(hidden_states, weight, stride=2) @@ -316,7 +303,8 @@ def forward(self, hidden_states, temb=None): hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) - weight[indices, indices] = self.kernel.to(weight) + kernel = self.kernel.to(weight)[None, :].expand(hidden_states.shape[1], -1) + weight[indices, indices] = kernel return F.conv_transpose1d(hidden_states, weight, stride=2, padding=self.pad * 2 + 1) diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 439c5c34b601..674e58d7180e 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -11,14 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional, Tuple import numpy as np import torch import torch.nn.functional as F from torch import nn -from .attention import AdaGroupNorm, AttentionBlock +from ..utils import is_torch_version +from .attention import AdaGroupNorm from .attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 from .dual_transformer_2d import DualTransformer2DModel from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D @@ -348,6 +349,7 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, ) elif up_block_type == "AttnUpDecoderBlock2D": return AttnUpDecoderBlock2D( @@ -360,6 +362,7 @@ def get_up_block( resnet_groups=resnet_groups, attn_num_head_channels=attn_num_head_channels, resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, ) elif up_block_type == "KUpBlock2D": return KUpBlock2D( @@ -395,7 +398,7 @@ def __init__( dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", + resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, @@ -427,12 +430,18 @@ def __init__( for _ in range(num_layers): if self.add_attention: attentions.append( - AttentionBlock( + Attention( in_channels, - num_head_channels=attn_num_head_channels, + heads=in_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else in_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) else: @@ -460,7 +469,7 @@ def forward(self, hidden_states, temb=None): hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: - hidden_states = attn(hidden_states) + hidden_states = attn(hidden_states, temb=temb) hidden_states = resnet(hidden_states, temb) return hidden_states @@ -552,15 +561,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = resnet(hidden_states, temb) return hidden_states @@ -652,16 +670,34 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) @@ -710,12 +746,17 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -733,7 +774,7 @@ def __init__( else: self.downsamplers = None - def forward(self, hidden_states, temb=None): + def forward(self, hidden_states, temb=None, upsample_size=None): output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -838,9 +879,14 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -855,12 +901,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + None, # timestep + None, # class_labels cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -868,15 +925,18 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -945,17 +1005,24 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1058,12 +1125,17 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1132,11 +1204,17 @@ def __init__( ) ) self.attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1338,17 +1416,24 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1449,30 +1534,65 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + for resnet, attn in zip(self.resnets, self.attentions): - # resnet - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) - output_states += (hidden_states,) + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1535,7 +1655,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1614,7 +1741,13 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): output_states = () @@ -1630,13 +1763,22 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + temb, attention_mask, cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, ) else: hidden_states = resnet(hidden_states, temb) @@ -1646,6 +1788,7 @@ def custom_forward(*inputs): emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if self.downsamplers is None: @@ -1701,12 +1844,17 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -1718,7 +1866,7 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1820,15 +1968,15 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1846,12 +1994,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + None, # timestep + None, # class_labels cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1859,7 +2018,10 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1931,7 +2093,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1950,12 +2119,13 @@ def __init__( dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", + resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, + temb_channels=None, ): super().__init__() resnets = [] @@ -1967,7 +2137,7 @@ def __init__( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, - temb_channels=None, + temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, @@ -1985,9 +2155,9 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb=None) + hidden_states = resnet(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2011,6 +2181,7 @@ def __init__( attn_num_head_channels=1, output_scale_factor=1.0, add_upsample=True, + temb_channels=None, ): super().__init__() resnets = [] @@ -2023,7 +2194,7 @@ def __init__( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, - temb_channels=None, + temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, @@ -2034,12 +2205,18 @@ def __init__( ) ) attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, - norm_num_groups=resnet_groups, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -2051,10 +2228,10 @@ def __init__( else: self.upsamplers = None - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb=None) - hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2106,11 +2283,17 @@ def __init__( ) self.attentions.append( - AttentionBlock( + Attention( out_channels, - num_head_channels=attn_num_head_channels, + heads=out_channels // attn_num_head_channels if attn_num_head_channels is not None else 1, + dim_head=attn_num_head_channels if attn_num_head_channels is not None else out_channels, rescale_output_factor=output_scale_factor, eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, ) ) @@ -2348,7 +2531,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2458,15 +2648,28 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - cross_attention_kwargs=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + for resnet, attn in zip(self.resnets, self.attentions): # resnet # pop res hidden states @@ -2474,15 +2677,34 @@ def forward( res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: - # attn - hidden_states = attn( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - **cross_attention_kwargs, - ) + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + mask, + cross_attention_kwargs, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -2553,7 +2775,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -2650,13 +2879,14 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: @@ -2674,13 +2904,22 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + temb, attention_mask, cross_attention_kwargs, + encoder_attention_mask, + **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -2690,6 +2929,7 @@ def custom_forward(*inputs): emb=temb, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if self.upsamplers is not None: @@ -2768,11 +3008,14 @@ def _to_4d(self, hidden_states, height, weight): def forward( self, - hidden_states, - encoder_hidden_states=None, - emb=None, - attention_mask=None, - cross_attention_kwargs=None, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} @@ -2786,6 +3029,7 @@ def forward( attn_output = self.attn1( norm_hidden_states, encoder_hidden_states=None, + attention_mask=attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) @@ -2800,6 +3044,7 @@ def forward( attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, **cross_attention_kwargs, ) attn_output = self._to_4d(attn_output, height, weight) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 38e0fa3b5b2e..dda21fd80479 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -16,14 +16,21 @@ import torch import torch.nn as nn -import torch.nn.functional as F import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config from ..loaders import UNet2DConditionLoadersMixin from ..utils import BaseOutput, logging +from .activations import get_activation from .attention_processor import AttentionProcessor, AttnProcessor -from .embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps +from .embeddings import ( + GaussianFourierProjection, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) from .modeling_utils import ModelMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -90,7 +97,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. encoder_hid_dim (`int`, *optional*, defaults to None): - If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`. + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to None): + If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. @@ -156,6 +167,7 @@ def __init__( norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, @@ -247,8 +259,32 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) - if encoder_hid_dim is not None: + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) else: self.encoder_hid_proj = None @@ -290,21 +326,20 @@ def __init__( self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.") + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") if time_embedding_act_fn is None: self.time_embed_act = None - elif time_embedding_act_fn == "swish": - self.time_embed_act = lambda x: F.silu(x) - elif time_embedding_act_fn == "mish": - self.time_embed_act = nn.Mish() - elif time_embedding_act_fn == "silu": - self.time_embed_act = nn.SiLU() - elif time_embedding_act_fn == "gelu": - self.time_embed_act = nn.GELU() else: - raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") + self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -458,16 +493,7 @@ def __init__( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - if act_fn == "swish": - self.conv_act = lambda x: F.silu(x) - elif act_fn == "mish": - self.conv_act = nn.Mish() - elif act_fn == "silu": - self.conv_act = nn.SiLU() - elif act_fn == "gelu": - self.conv_act = nn.GELU() - else: - raise ValueError(f"Unsupported activation function: {act_fn}") + self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None @@ -616,8 +642,10 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -625,12 +653,20 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + encoder_attention_mask (`torch.Tensor`): + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = + discard. Mask will be converted into a bias, which adds large negative values to attention scores + corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + added_cond_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time + embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and + `addition_embed_type` for more information. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: @@ -651,11 +687,27 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -682,7 +734,7 @@ def forward( # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) @@ -697,7 +749,7 @@ def forward( # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) @@ -707,12 +759,33 @@ def forward( if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) emb = emb + aug_emb + elif self.config.addition_embed_type == "text_image": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + + aug_emb = self.add_embedding(text_embs, image_embs) + emb = emb + aug_emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) - if self.encoder_hid_proj is not None: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) # 2. pre-process sample = self.conv_in(sample) @@ -727,6 +800,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -740,7 +814,7 @@ def forward( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples @@ -752,6 +826,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: @@ -778,6 +853,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 400c3030af90..dd4af0efcfd9 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -18,7 +18,8 @@ import torch import torch.nn as nn -from ..utils import BaseOutput, randn_tensor +from ..utils import BaseOutput, is_torch_version, randn_tensor +from .attention_processor import SpatialNorm from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -117,11 +118,20 @@ def custom_forward(*inputs): return custom_forward # down - for down_block in self.down_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) - - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) else: # down @@ -149,6 +159,7 @@ def __init__( layers_per_block=2, norm_num_groups=32, act_fn="silu", + norm_type="group", # group, spatial ): super().__init__() self.layers_per_block = layers_per_block @@ -164,16 +175,18 @@ def __init__( self.mid_block = None self.up_blocks = nn.ModuleList([]) + temb_channels = in_channels if norm_type == "spatial" else None + # mid self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, resnet_act_fn=act_fn, output_scale_factor=1, - resnet_time_scale_shift="default", + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, attn_num_head_channels=None, resnet_groups=norm_num_groups, - temb_channels=None, + temb_channels=temb_channels, ) # up @@ -196,19 +209,23 @@ def __init__( resnet_act_fn=act_fn, resnet_groups=norm_num_groups, attn_num_head_channels=None, - temb_channels=None, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, ) self.up_blocks.append(up_block) prev_output_channel = output_channel # out - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) self.gradient_checkpointing = False - def forward(self, z): + def forward(self, z, latent_embeds=None): sample = z sample = self.conv_in(sample) @@ -221,24 +238,42 @@ def custom_forward(*inputs): return custom_forward - # middle - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) - sample = sample.to(upscale_dtype) + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds, use_reentrant=False + ) + sample = sample.to(upscale_dtype) - # up - for up_block in self.up_blocks: - sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), sample, latent_embeds, use_reentrant=False + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) else: # middle - sample = self.mid_block(sample) + sample = self.mid_block(sample, latent_embeds) sample = sample.to(upscale_dtype) # up for up_block in self.up_blocks: - sample = up_block(sample) + sample = up_block(sample, latent_embeds) # post-process - sample = self.conv_norm_out(sample) + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) sample = self.conv_act(sample) sample = self.conv_out(sample) diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 65f734dccb2d..73158294ee6e 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -82,6 +82,7 @@ def __init__( norm_num_groups: int = 32, vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, + norm_type: str = "group", # group, spatial ): super().__init__() @@ -112,6 +113,7 @@ def __init__( layers_per_block=layers_per_block, act_fn=act_fn, norm_num_groups=norm_num_groups, + norm_type=norm_type, ) def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: @@ -131,8 +133,8 @@ def decode( quant, emb_loss, info = self.quantize(h) else: quant = h - quant = self.post_quant_conv(quant) - dec = self.decoder(quant) + quant2 = self.post_quant_conv(quant) + dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None) if not return_dict: return (dec,) diff --git a/src/diffusers/optimization.py b/src/diffusers/optimization.py index 657e085062e0..46e6125a0f55 100644 --- a/src/diffusers/optimization.py +++ b/src/diffusers/optimization.py @@ -34,6 +34,7 @@ class SchedulerType(Enum): POLYNOMIAL = "polynomial" CONSTANT = "constant" CONSTANT_WITH_WARMUP = "constant_with_warmup" + PIECEWISE_CONSTANT = "piecewise_constant" def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1): @@ -77,6 +78,48 @@ def lr_lambda(current_step: int): return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) +def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1): + """ + Create a schedule with a constant learning rate, using the learning rate set in optimizer. + + Args: + optimizer ([`~torch.optim.Optimizer`]): + The optimizer for which to schedule the learning rate. + step_rules (`string`): + The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate + if multiple 1 for the first 10 steps, mutiple 0.1 for the next 20 steps, multiple 0.01 for the next 30 + steps and multiple 0.005 for the other steps. + last_epoch (`int`, *optional*, defaults to -1): + The index of the last epoch when resuming training. + + Return: + `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. + """ + + rules_dict = {} + rule_list = step_rules.split(",") + for rule_str in rule_list[:-1]: + value_str, steps_str = rule_str.split(":") + steps = int(steps_str) + value = float(value_str) + rules_dict[steps] = value + last_lr_multiple = float(rule_list[-1]) + + def create_rules_function(rules_dict, last_lr_multiple): + def rule_func(steps: int) -> float: + sorted_steps = sorted(rules_dict.keys()) + for i, sorted_step in enumerate(sorted_steps): + if steps < sorted_step: + return rules_dict[sorted_steps[i]] + return last_lr_multiple + + return rule_func + + rules_func = create_rules_function(rules_dict, last_lr_multiple) + + return LambdaLR(optimizer, rules_func, last_epoch=last_epoch) + + def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after @@ -232,12 +275,14 @@ def lr_lambda(current_step: int): SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup, SchedulerType.CONSTANT: get_constant_schedule, SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup, + SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule, } def get_scheduler( name: Union[str, SchedulerType], optimizer: Optimizer, + step_rules: Optional[str] = None, num_warmup_steps: Optional[int] = None, num_training_steps: Optional[int] = None, num_cycles: int = 1, @@ -252,6 +297,8 @@ def get_scheduler( The name of the scheduler to use. optimizer (`torch.optim.Optimizer`): The optimizer that will be used during training. + step_rules (`str`, *optional*): + A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler. num_warmup_steps (`int`, *optional*): The number of warmup steps to do. This is not required by all schedulers (hence the argument being optional), the function will raise an error if it's unset and the scheduler type requires it. @@ -270,6 +317,9 @@ def get_scheduler( if name == SchedulerType.CONSTANT: return schedule_func(optimizer, last_epoch=last_epoch) + if name == SchedulerType.PIECEWISE_CONSTANT: + return schedule_func(optimizer, step_rules=step_rules, last_epoch=last_epoch) + # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 5c0c2337dc04..87709d5f616c 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -17,3 +17,13 @@ # It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401 +from .utils import deprecate + + +deprecate( + "pipelines_utils", + "0.22.0", + "Importing `DiffusionPipeline` or `ImagePipelineOutput` from diffusers.pipeline_utils is deprecated. Please import from diffusers.pipelines.pipeline_utils instead.", + standard_warn=False, + stacklevel=3, +) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 10da653a1377..9e68538f233c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -44,6 +44,11 @@ else: from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline from .audioldm import AudioLDMPipeline + from .controlnet import ( + StableDiffusionControlNetImg2ImgPipeline, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionControlNetPipeline, + ) from .deepfloyd_if import ( IFImg2ImgPipeline, IFImg2ImgSuperResolutionPipeline, @@ -52,14 +57,20 @@ IFPipeline, IFSuperResolutionPipeline, ) + from .kandinsky import ( + KandinskyImg2ImgPipeline, + KandinskyInpaintPipeline, + KandinskyPipeline, + KandinskyPriorPipeline, + ) from .latent_diffusion import LDMTextToImagePipeline from .paint_by_example import PaintByExamplePipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .stable_diffusion import ( CycleDiffusionPipeline, StableDiffusionAttendAndExcitePipeline, - StableDiffusionControlNetPipeline, StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, StableDiffusionImageVariationPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipeline, @@ -78,6 +89,7 @@ from .stable_diffusion_safe import StableDiffusionPipelineSafe from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline + from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder from .versatile_diffusion import ( VersatileDiffusionDualGuidedPipeline, VersatileDiffusionImageVariationPipeline, @@ -132,8 +144,8 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_flax_and_transformers_objects import * # noqa F403 else: + from .controlnet import FlaxStableDiffusionControlNetPipeline from .stable_diffusion import ( - FlaxStableDiffusionControlNetPipeline, FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index ff9474ffd43a..b79e4f72144b 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -22,7 +23,8 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring @@ -49,8 +51,23 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using Alt Diffusion. @@ -174,6 +191,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): @@ -288,6 +306,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -312,7 +331,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -369,7 +395,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -426,18 +452,29 @@ def _encode_prompt( return prompt_embeds def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept def decode_latents(self, latents): + warnings.warn( + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -545,6 +582,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -605,6 +643,11 @@ def __call__( A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. Examples: @@ -639,6 +682,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -647,6 +693,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps @@ -683,15 +730,20 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -699,24 +751,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: image = latents has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - # 10. Convert to PIL - image = self.numpy_to_pil(image) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] else: - # 8. Post-processing - image = self.decode_latents(latents) + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index dee4a91924f7..5903f97aca36 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -25,7 +26,7 @@ from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring @@ -68,6 +69,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -89,7 +95,7 @@ def preprocess(image): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker -class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class AltDiffusionImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided image to image generation using Alt Diffusion. @@ -202,6 +208,7 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -212,11 +219,8 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker, - ) + self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -298,6 +302,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -322,7 +327,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -379,7 +391,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -436,17 +448,32 @@ def _encode_prompt( return prompt_embeds def run_safety_checker(self, image, device, dtype): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) return image, has_nsfw_concept def decode_latents(self, latents): + warnings.warn( + ( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead" + ), + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image def prepare_extra_step_kwargs(self, generator, eta): @@ -524,21 +551,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective" + f" batch size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - init_latents = self.vae.config.scaling_factor * init_latents + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -572,7 +604,14 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, @@ -595,9 +634,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of @@ -674,6 +714,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -682,6 +725,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image @@ -714,7 +758,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -722,7 +767,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -730,27 +775,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type not in ["latent", "pt", "np", "pil"]: - deprecation_message = ( - f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " - "`pil`, `np`, `pt`, `latent`" - ) - deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) - output_type = "np" - - if output_type == "latent": + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: image = latents has_nsfw_concept = None + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] else: - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - has_nsfw_concept = False + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py index 1df76ed6c52c..629a2e7d32ca 100644 --- a/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py +++ b/src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py @@ -51,21 +51,6 @@ def __init__( super().__init__() self.register_modules(unet=unet, scheduler=scheduler, mel=mel, vqvae=vqvae) - def get_input_dims(self) -> Tuple: - """Returns dimension of input image - - Returns: - `Tuple`: (height, width) - """ - input_module = self.vqvae if self.vqvae is not None else self.unet - # For backwards compatibility - sample_size = ( - (input_module.config.sample_size, input_module.config.sample_size) - if type(input_module.config.sample_size) == int - else input_module.config.sample_size - ) - return sample_size - def get_default_steps(self) -> int: """Returns default number of steps recommended for inference @@ -123,8 +108,6 @@ def __call__( # For backwards compatibility if type(self.unet.config.sample_size) == int: self.unet.config.sample_size = (self.unet.config.sample_size, self.unet.config.sample_size) - input_dims = self.get_input_dims() - self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0]) if noise is None: noise = randn_tensor( ( @@ -234,7 +217,7 @@ def encode(self, images: List[Image.Image], steps: int = 50) -> np.ndarray: sample = torch.Tensor(sample).to(self.device) for t in self.progress_bar(torch.flip(self.scheduler.timesteps, (0,))): - prev_timestep = t - self.scheduler.num_train_timesteps // self.scheduler.num_inference_steps + prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps alpha_prod_t = self.scheduler.alphas_cumprod[t] alpha_prod_t_prev = ( self.scheduler.alphas_cumprod[prev_timestep] diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py new file mode 100644 index 000000000000..76ab63bdb116 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -0,0 +1,22 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_torch_available, + is_transformers_available, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 +else: + from .multicontrolnet import MultiControlNetModel + from .pipeline_controlnet import StableDiffusionControlNetPipeline + from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline + from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + + +if is_transformers_available() and is_flax_available(): + from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py new file mode 100644 index 000000000000..91d40b20124c --- /dev/null +++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -0,0 +1,66 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from ...models.controlnet import ControlNetModel, ControlNetOutput +from ...models.modeling_utils import ModelMixin + + +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states, + image, + scale, + class_labels, + timestep_cond, + attention_mask, + cross_attention_kwargs, + guess_mode, + return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py new file mode 100644 index 000000000000..89398b6f01f9 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -0,0 +1,1008 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> image = np.array(image) + + >>> # get canny image + >>> image = cv2.Canny(image, 100, 200) + >>> image = image[:, :, None] + >>> image = np.concatenate([image, image, image], axis=2) + >>> canny_image = Image.fromarray(image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image + ... ).images[0] + ``` +""" + + +class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + elif isinstance(controlnet, MultiControlNetModel): + images = [] + + for image_ in image: + image_ = self.prepare_image( + image=image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + images.append(image_) + + image = images + height, width = image[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py new file mode 100644 index 000000000000..0e984d8ae5e3 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -0,0 +1,1098 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + + >>> # download an image + >>> image = load_image( + ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" + ... ) + >>> np_image = np.array(image) + + >>> # get canny image + >>> np_image = cv2.Canny(np_image, 100, 200) + >>> np_image = np_image[:, :, None] + >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2) + >>> canny_image = Image.fromarray(np_image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking woman", + ... num_inference_steps=20, + ... generator=generator, + ... image=image, + ... control_image=canny_image, + ... ).images[0] + ``` +""" + + +def prepare_image(image): + if isinstance(image, torch.Tensor): + # Batch single image + if image.ndim == 3: + image = image.unsqueeze(0) + + image = image.to(dtype=torch.float32) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + return image + + +class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.8, + guess_mode: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The initial image will be used as the starting point for the image generation process. Can also accpet + image latents as `image`, if passing latents directly, it will not be encoded again. + control_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + # 4. Prepare image + image = self.image_processor.preprocess(image).to(dtype=torch.float32) + + # 5. Prepare controlnet_conditioning_image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, + latent_timestep, + batch_size, + num_images_per_prompt, + prompt_embeds.dtype, + device, + generator, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py new file mode 100644 index 000000000000..5ce2fd5543b8 --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -0,0 +1,1347 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import inspect +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion import StableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ... ) + >>> init_image = init_image.resize((512, 512)) + + >>> generator = torch.Generator(device="cpu").manual_seed(1) + + >>> mask_image = load_image( + ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ... ) + >>> mask_image = mask_image.resize((512, 512)) + + + >>> def make_inpaint_condition(image, image_mask): + ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + ... image[image_mask > 0.5] = -1.0 # set as masked pixel + ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + ... image = torch.from_numpy(image) + ... return image + + + >>> control_image = make_inpaint_condition(init_image, mask_image) + + >>> controlnet = ControlNetModel.from_pretrained( + ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> image = pipe( + ... "a handsome man with ray-ban sunglasses", + ... num_inference_steps=20, + ... generator=generator, + ... eta=1.0, + ... image=init_image, + ... mask_image=mask_image, + ... control_image=control_image, + ... ).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + + +class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + + + + This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as + [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting) + as well as default text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5). + Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on + those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint). + + + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + if isinstance(controlnet, (list, tuple)): + controlnet = MultiControlNetModel(controlnet) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + # the safety checker can offload the vae again + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # `prompt` needs more sophisticated handling when there are multiple + # conditionings. + if isinstance(self.controlnet, MultiControlNetModel): + if isinstance(prompt, list): + logger.warning( + f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" + " prompts. The conditionings will be fixed across the prompts." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + "For multiple controlnets: `image` must have the same length as the number of controlnets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + elif ( + isinstance(self.controlnet, MultiControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, MultiControlNetModel) + ): + if isinstance(controlnet_conditioning_scale, list): + if any(isinstance(i, list) for i in controlnet_conditioning_scale): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( + self.controlnet.nets + ): + raise ValueError( + "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" + " the same length as the number of controlnets" + ) + else: + assert False + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + "image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_control_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def prepare_mask_latents( + self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + guess_mode: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting + than for [`~StableDiffusionControlNetPipeline.__call__`]. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t])) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py new file mode 100644 index 000000000000..6003fc96b0ad --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -0,0 +1,537 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from functools import partial +from typing import Dict, List, Optional, Union + +import jax +import jax.numpy as jnp +import numpy as np +from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image +from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel + +from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel +from ...schedulers import ( + FlaxDDIMScheduler, + FlaxDPMSolverMultistepScheduler, + FlaxLMSDiscreteScheduler, + FlaxPNDMScheduler, +) +from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring +from ..pipeline_flax_utils import FlaxDiffusionPipeline +from ..stable_diffusion import FlaxStableDiffusionPipelineOutput +from ..stable_diffusion.safety_checker_flax import FlaxStableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import jax + >>> import numpy as np + >>> import jax.numpy as jnp + >>> from flax.jax_utils import replicate + >>> from flax.training.common_utils import shard + >>> from diffusers.utils import load_image + >>> from PIL import Image + >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel + + + >>> def image_grid(imgs, rows, cols): + ... w, h = imgs[0].size + ... grid = Image.new("RGB", size=(cols * w, rows * h)) + ... for i, img in enumerate(imgs): + ... grid.paste(img, box=(i % cols * w, i // cols * h)) + ... return grid + + + >>> def create_key(seed=0): + ... return jax.random.PRNGKey(seed) + + + >>> rng = create_key(0) + + >>> # get canny image + >>> canny_image = load_image( + ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" + ... ) + + >>> prompts = "best quality, extremely detailed" + >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" + + >>> # load control net and stable diffusion v1-5 + >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( + ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 + ... ) + >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( + ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 + ... ) + >>> params["controlnet"] = controlnet_params + + >>> num_samples = jax.device_count() + >>> rng = jax.random.split(rng, jax.device_count()) + + >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) + >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) + >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) + + >>> p_params = replicate(params) + >>> prompt_ids = shard(prompt_ids) + >>> negative_prompt_ids = shard(negative_prompt_ids) + >>> processed_image = shard(processed_image) + + >>> output = pipe( + ... prompt_ids=prompt_ids, + ... image=processed_image, + ... params=p_params, + ... prng_seed=rng, + ... num_inference_steps=50, + ... neg_prompt_ids=negative_prompt_ids, + ... jit=True, + ... ).images + + >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) + >>> output_images = image_grid(output_images, num_samples // 4, 4) + >>> output_images.save("generated_image.png") + ``` +""" + + +class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. + + This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`FlaxAutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`FlaxCLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), + specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`FlaxControlNetModel`]: + Provides additional conditioning to the unet during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or + [`FlaxDPMSolverMultistepScheduler`]. + safety_checker ([`FlaxStableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def __init__( + self, + vae: FlaxAutoencoderKL, + text_encoder: FlaxCLIPTextModel, + tokenizer: CLIPTokenizer, + unet: FlaxUNet2DConditionModel, + controlnet: FlaxControlNetModel, + scheduler: Union[ + FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler + ], + safety_checker: FlaxStableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + dtype: jnp.dtype = jnp.float32, + ): + super().__init__() + self.dtype = dtype + + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + def prepare_text_inputs(self, prompt: Union[str, List[str]]): + if not isinstance(prompt, (str, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + + return text_input.input_ids + + def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): + if not isinstance(image, (Image.Image, list)): + raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") + + if isinstance(image, Image.Image): + image = [image] + + processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) + + return processed_images + + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) + safety_model_params = unreplicate(safety_model_params) + else: + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts + + def _generate( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int, + guidance_scale: float, + latents: Optional[jnp.array] = None, + neg_prompt_ids: Optional[jnp.array] = None, + controlnet_conditioning_scale: float = 1.0, + ): + height, width = image.shape[-2:] + if height % 64 != 0 or width % 64 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") + + # get prompt text embeddings + prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] + + # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` + # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` + batch_size = prompt_ids.shape[0] + + max_length = prompt_ids.shape[-1] + + if neg_prompt_ids is None: + uncond_input = self.tokenizer( + [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" + ).input_ids + else: + uncond_input = neg_prompt_ids + negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] + context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) + + image = jnp.concatenate([image] * 2) + + latents_shape = ( + batch_size, + self.unet.config.in_channels, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if latents is None: + latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + def loop_body(step, args): + latents, scheduler_state = args + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + latents_input = jnp.concatenate([latents] * 2) + + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + timestep = jnp.broadcast_to(t, latents_input.shape[0]) + + latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) + + down_block_res_samples, mid_block_res_sample = self.controlnet.apply( + {"params": params["controlnet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + controlnet_cond=image, + conditioning_scale=controlnet_conditioning_scale, + return_dict=False, + ) + + # predict the noise residual + noise_pred = self.unet.apply( + {"params": params["unet"]}, + jnp.array(latents_input), + jnp.array(timestep, dtype=jnp.int32), + encoder_hidden_states=context, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents, scheduler_state + + scheduler_state = self.scheduler.set_timesteps( + params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape + ) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * params["scheduler"].init_noise_sigma + + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + + # scale and decode the image latents with vae + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample + + image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image + + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt_ids: jnp.array, + image: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: jax.random.KeyArray, + num_inference_steps: int = 50, + guidance_scale: Union[float, jnp.array] = 7.5, + latents: jnp.array = None, + neg_prompt_ids: jnp.array = None, + controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, + return_dict: bool = True, + jit: bool = False, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt_ids (`jnp.array`): + The prompt or prompts to guide the image generation. + image (`jnp.array`): + Array representing the ControlNet input condition. ControlNet use this input condition to generate + guidance to Unet. + params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights + prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + + height, width = image.shape[-2:] + + if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale[:, None] + + if isinstance(controlnet_conditioning_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. + controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] + + if jit: + images = _p_generate( + self, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + else: + images = self._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + images = np.asarray(images) + has_nsfw_concept = False + + if not return_dict: + return (images, has_nsfw_concept) + + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), + static_broadcasted_argnums=(0, 5), +) +def _p_generate( + pipe, + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, +): + return pipe._generate( + prompt_ids, + image, + params, + prng_seed, + num_inference_steps, + guidance_scale, + latents, + neg_prompt_ids, + controlnet_conditioning_scale, + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest) + + +def preprocess(image, dtype): + image = image.convert("RGB") + w, h = image.size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) + image = jnp.array(image).astype(dtype) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + return image diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py index 479ffa9e6635..cd1015dc03bb 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if.py @@ -7,6 +7,7 @@ import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -85,7 +86,7 @@ """ -class IFPipeline(DiffusionPipeline): +class IFPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -793,7 +794,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -803,10 +805,13 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -829,7 +834,7 @@ def __call__( # 11. Apply watermark if self.watermarker is not None: - self.watermarker.apply_watermark(image, self.unet.config.sample_size) + image = self.watermarker.apply_watermark(image, self.unet.config.sample_size) elif output_type == "pt": nsfw_detected = None watermark_detected = None diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py index fac4adeea463..6bae2071173b 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py @@ -9,6 +9,7 @@ import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -109,7 +110,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: """ -class IFImg2ImgPipeline(DiffusionPipeline): +class IFImg2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -918,7 +919,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -928,10 +930,13 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py index eed1bb43e5d8..0ee9c6ba331d 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -10,6 +10,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -112,7 +113,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: """ -class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline): +class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -759,10 +760,10 @@ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device image = [image] if isinstance(image[0], PIL.Image.Image): - image = [np.array(i).astype(np.float32) / 255.0 for i in image] + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] image = np.stack(image, axis=0) # to np - torch.from_numpy(image.transpose(0, 3, 1, 2)) + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) elif isinstance(image[0], np.ndarray): image = np.stack(image, axis=0) # to np if image.ndim == 5: @@ -1036,7 +1037,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -1046,10 +1048,13 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py index d3651f5169c1..9c1f71126ac5 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py @@ -9,6 +9,7 @@ import torch from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -112,7 +113,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: """ -class IFInpaintingPipeline(DiffusionPipeline): +class IFInpaintingPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -1033,7 +1034,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -1043,12 +1045,15 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(model_input.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 prev_intermediate_images = intermediate_images intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py index 5ea6a47082ae..6a90f2b765d4 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py @@ -10,6 +10,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -114,7 +115,7 @@ def resize(images: PIL.Image.Image, img_size: int) -> PIL.Image.Image: """ -class IFInpaintingSuperResolutionPipeline(DiffusionPipeline): +class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -795,10 +796,10 @@ def preprocess_image(self, image: PIL.Image.Image, num_images_per_prompt, device image = [image] if isinstance(image[0], PIL.Image.Image): - image = [np.array(i).astype(np.float32) / 255.0 for i in image] + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] image = np.stack(image, axis=0) # to np - torch.from_numpy(image.transpose(0, 3, 1, 2)) + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) elif isinstance(image[0], np.ndarray): image = np.stack(image, axis=0) # to np if image.ndim == 5: @@ -1143,7 +1144,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -1153,12 +1155,15 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 prev_intermediate_images = intermediate_images intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] intermediate_images = (1 - mask_image) * prev_intermediate_images + mask_image * intermediate_images diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py index a62a51b0972f..86d9574b97e1 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py @@ -10,6 +10,7 @@ import torch.nn.functional as F from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer +from ...loaders import LoraLoaderMixin from ...models import UNet2DConditionModel from ...schedulers import DDPMScheduler from ...utils import ( @@ -70,7 +71,7 @@ """ -class IFSuperResolutionPipeline(DiffusionPipeline): +class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin): tokenizer: T5Tokenizer text_encoder: T5EncoderModel @@ -664,10 +665,10 @@ def preprocess_image(self, image, num_images_per_prompt, device): image = [image] if isinstance(image[0], PIL.Image.Image): - image = [np.array(i).astype(np.float32) / 255.0 for i in image] + image = [np.array(i).astype(np.float32) / 127.5 - 1.0 for i in image] image = np.stack(image, axis=0) # to np - torch.from_numpy(image.transpose(0, 3, 1, 2)) + image = torch.from_numpy(image.transpose(0, 3, 1, 2)) elif isinstance(image[0], np.ndarray): image = np.stack(image, axis=0) # to np if image.ndim == 5: @@ -695,6 +696,8 @@ def preprocess_image(self, image, num_images_per_prompt, device): def __call__( self, prompt: Union[str, List[str]] = None, + height: int = None, + width: int = None, image: Union[PIL.Image.Image, np.ndarray, torch.FloatTensor] = None, num_inference_steps: int = 50, timesteps: List[int] = None, @@ -720,6 +723,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. image (`PIL.Image.Image`, `np.ndarray`, `torch.FloatTensor`): The image to be upscaled. num_inference_steps (`int`, *optional*, defaults to 50): @@ -806,8 +813,8 @@ def __call__( # 2. Define call parameters - height = self.unet.config.sample_size - width = self.unet.config.sample_size + height = height or self.unet.config.sample_size + width = width or self.unet.config.sample_size device = self._execution_device @@ -886,7 +893,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=noise_level, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -896,10 +904,13 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + if self.scheduler.config.variance_type not in ["learned", "learned_range"]: + noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1) + # compute the previous noisy sample x_t -> x_t-1 intermediate_images = self.scheduler.step( - noise_pred, t, intermediate_images, **extra_step_kwargs - ).prev_sample + noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/kandinsky/__init__.py b/src/diffusers/pipelines/kandinsky/__init__.py new file mode 100644 index 000000000000..c8eecba0c7f2 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/__init__.py @@ -0,0 +1,19 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import KandinskyPipeline, KandinskyPriorPipeline +else: + from .pipeline_kandinsky import KandinskyPipeline + from .pipeline_kandinsky_img2img import KandinskyImg2ImgPipeline + from .pipeline_kandinsky_inpaint import KandinskyInpaintPipeline + from .pipeline_kandinsky_prior import KandinskyPriorPipeline + from .text_encoder import MultilingualCLIP diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py new file mode 100644 index 000000000000..6de9cf4451de --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky.py @@ -0,0 +1,464 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import torch +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDIMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/Kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +class KandinskyPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ Decoder to generate the image from the latents. + """ + + def __init__( + self, + text_encoder: MultilingualCLIP, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + movq: VQModel, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=77, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.unet, + self.text_encoder, + self.movq, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + negative_image_embeds: Union[torch.FloatTensor, List[torch.FloatTensor]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + # YiYi notes: only reason this pipeline can't work with unclip scheduler is that can't pass down this argument + # need to use DDPM scheduler instead + # prev_timestep=prev_timestep, + generator=generator, + ).prev_sample + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py new file mode 100644 index 000000000000..f32528617e5a --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py @@ -0,0 +1,548 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from PIL import Image +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDIMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyImg2ImgPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "A red cartoon frog, 4k" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyImg2ImgPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/frog.png" + ... ) + + >>> image = pipe( + ... prompt, + ... image=init_image, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... strength=0.2, + ... ).images + + >>> image[0].save("red_frog.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +def prepare_image(pil_image, w=512, h=512): + pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) + arr = np.array(pil_image.convert("RGB")) + arr = arr.astype(np.float32) / 127.5 - 1 + arr = np.transpose(arr, [2, 0, 1]) + image = torch.from_numpy(arr).unsqueeze(0) + return image + + +class KandinskyImg2ImgPipeline(DiffusionPipeline): + """ + Pipeline for image-to-image generation using Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + movq=movq, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, latents, latent_timestep, shape, dtype, device, generator, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + + shape = latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + latents = self.add_noise(latents, noise, latent_timestep) + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.unet, + self.text_encoder, + self.movq, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # add_noise method to overwrite the one in schedule because it use a different beta schedule for adding noise vs sampling + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + betas = torch.linspace(0.0001, 0.02, 1000, dtype=torch.float32) + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod = alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + + return noisy_samples + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]], + image_embeds: torch.FloatTensor, + negative_image_embeds: torch.FloatTensor, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + strength: float = 0.3, + guidance_scale: float = 7.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor`, `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + strength (`float`, *optional*, defaults to 0.3): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + # 1. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + + # 2. get text and image embeddings + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # 3. pre-processing initial image + if not isinstance(image, list): + image = [image] + if not all(isinstance(i, (PIL.Image.Image, torch.Tensor)) for i in image): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support PIL image and pytorch tensor" + ) + + image = torch.cat([prepare_image(i, width, height) for i in image], dim=0) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + latents = self.movq.encode(image)["latents"] + latents = latents.repeat_interleave(num_images_per_prompt, dim=0) + + # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps_tensor, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + # the formular to calculate timestep for add_noise is taken from the original kandinsky repo + latent_timestep = int(self.scheduler.config.num_train_timesteps * strength) - 2 + + latent_timestep = torch.tensor([latent_timestep] * batch_size, dtype=timesteps_tensor.dtype, device=device) + + num_channels_latents = self.unet.config.in_channels + + height, width = get_new_h_w(height, width, self.movq_scale_factor) + + # 5. Create initial latent + latents = self.prepare_latents( + latents, + latent_timestep, + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + self.scheduler, + ) + + # 6. Denoising loop + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + # 7. post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py new file mode 100644 index 000000000000..04810ddb6e0a --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py @@ -0,0 +1,673 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import ( + XLMRobertaTokenizer, +) + +from ...models import UNet2DConditionModel, VQModel +from ...pipelines import DiffusionPipeline +from ...pipelines.pipeline_utils import ImagePipelineOutput +from ...schedulers import DDIMScheduler +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from .text_encoder import MultilingualCLIP + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyInpaintPipeline, KandinskyPriorPipeline + >>> from diffusers.utils import load_image + >>> import torch + >>> import numpy as np + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> prompt = "a hat" + >>> image_emb, zero_image_emb = pipe_prior(prompt, return_dict=False) + + >>> pipe = KandinskyInpaintPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 + ... ) + >>> pipe.to("cuda") + + >>> init_image = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> mask = np.ones((768, 768), dtype=np.float32) + >>> mask[:250, 250:-250] = 0 + + >>> out = pipe( + ... prompt, + ... image=init_image, + ... mask_image=mask, + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=50, + ... ) + + >>> image = out.images[0] + >>> image.save("cat_with_hat.png") + ``` +""" + + +def get_new_h_w(h, w, scale_factor=8): + new_h = h // scale_factor**2 + if h % scale_factor**2 != 0: + new_h += 1 + new_w = w // scale_factor**2 + if w % scale_factor**2 != 0: + new_w += 1 + return new_h * scale_factor, new_w * scale_factor + + +def prepare_mask(masks): + prepared_masks = [] + for mask in masks: + old_mask = deepcopy(mask) + for i in range(mask.shape[1]): + for j in range(mask.shape[2]): + if old_mask[0][i][j] == 1: + continue + if i != 0: + mask[:, i - 1, j] = 0 + if j != 0: + mask[:, i, j - 1] = 0 + if i != 0 and j != 0: + mask[:, i - 1, j - 1] = 0 + if i != mask.shape[1] - 1: + mask[:, i + 1, j] = 0 + if j != mask.shape[2] - 1: + mask[:, i, j + 1] = 0 + if i != mask.shape[1] - 1 and j != mask.shape[2] - 1: + mask[:, i + 1, j + 1] = 0 + prepared_masks.append(mask) + return torch.stack(prepared_masks, dim=0) + + +def prepare_mask_and_masked_image(image, mask, height, width): + r""" + Prepares a pair (mask, image) to be consumed by the Kandinsky inpaint pipeline. This means that those inputs will + be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for + the ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=Image.BICUBIC, reducing_gap=1) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + return mask, image + + +class KandinskyInpaintPipeline(DiffusionPipeline): + """ + Pipeline for text-guided image inpainting using Kandinsky2.1 + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + text_encoder ([`MultilingualCLIP`]): + Frozen text-encoder. + tokenizer ([`XLMRobertaTokenizer`]): + Tokenizer of class + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `unet` to generate image latents. + unet ([`UNet2DConditionModel`]): + Conditional U-Net architecture to denoise the image embedding. + movq ([`VQModel`]): + MoVQ image encoder and decoder + """ + + def __init__( + self, + text_encoder: MultilingualCLIP, + movq: VQModel, + tokenizer: XLMRobertaTokenizer, + unet: UNet2DConditionModel, + scheduler: DDIMScheduler, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + movq=movq, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.movq_scale_factor = 2 ** (len(self.movq.config.block_out_channels) - 1) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + text_mask = text_inputs.attention_mask.to(device) + + prompt_embeds, text_encoder_hidden_states = self.text_encoder( + input_ids=text_input_ids, attention_mask=text_mask + ) + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=77, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + uncond_text_input_ids = uncond_input.input_ids.to(device) + uncond_text_mask = uncond_input.attention_mask.to(device) + + negative_prompt_embeds, uncond_text_encoder_hidden_states = self.text_encoder( + input_ids=uncond_text_input_ids, attention_mask=uncond_text_mask + ) + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.unet, + self.text_encoder, + self.movq, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.movq]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + image_embeds: torch.FloatTensor, + negative_image_embeds: torch.FloatTensor, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 512, + num_inference_steps: int = 100, + guidance_scale: float = 4.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor`, `PIL.Image.Image` or `np.ndarray`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + mask_image (`PIL.Image.Image`,`torch.FloatTensor` or `np.ndarray`): + `Image`, or a tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. You can pass a pytorch tensor as mask only if the + image you passed is a pytorch tensor, and it should contain one color channel (L) instead of 3, so the + expected shape would be either `(B, 1, H, W,)`, `(B, H, W)`, `(1, H, W)` or `(H, W)` If image is an PIL + image or numpy array, mask should also be a either PIL image or numpy array. If it is a PIL image, it + will be converted to a single channel (luminance) before use. If it is a nummpy array, the expected + shape is `(H, W)`. + image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for text prompt, that will be used to condition the image generation. + negative_image_embeds (`torch.FloatTensor` or `List[torch.FloatTensor]`): + The clip image embeddings for negative text prompt, will be used to condition the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"` + (`np.array`) or `"pt"` (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple` + """ + + # Define call parameters + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + do_classifier_free_guidance = guidance_scale > 1.0 + + prompt_embeds, text_encoder_hidden_states, _ = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + if isinstance(image_embeds, list): + image_embeds = torch.cat(image_embeds, dim=0) + if isinstance(negative_image_embeds, list): + negative_image_embeds = torch.cat(negative_image_embeds, dim=0) + + if do_classifier_free_guidance: + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + negative_image_embeds = negative_image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0).to( + dtype=prompt_embeds.dtype, device=device + ) + + # preprocess image and mask + mask_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + + image = image.to(dtype=prompt_embeds.dtype, device=device) + image = self.movq.encode(image)["latents"] + + mask_image = mask_image.to(dtype=prompt_embeds.dtype, device=device) + + image_shape = tuple(image.shape[-2:]) + mask_image = F.interpolate( + mask_image, + image_shape, + mode="nearest", + ) + mask_image = prepare_mask(mask_image) + masked_image = image * mask_image + + mask_image = mask_image.repeat_interleave(num_images_per_prompt, dim=0) + masked_image = masked_image.repeat_interleave(num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + mask_image = mask_image.repeat(2, 1, 1, 1) + masked_image = masked_image.repeat(2, 1, 1, 1) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps_tensor = self.scheduler.timesteps + + num_channels_latents = self.movq.config.latent_channels + + # get h, w for latents + sample_height, sample_width = get_new_h_w(height, width, self.movq_scale_factor) + + # create initial latent + latents = self.prepare_latents( + (batch_size, num_channels_latents, sample_height, sample_width), + text_encoder_hidden_states.dtype, + device, + generator, + latents, + self.scheduler, + ) + + # Check that sizes of mask, masked image and latents match with expected + num_channels_mask = mask_image.shape[1] + num_channels_masked_image = masked_image.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + + for i, t in enumerate(self.progress_bar(timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = torch.cat([latent_model_input, masked_image, mask_image], dim=1) + + added_cond_kwargs = {"text_embeds": prompt_embeds, "image_embeds": image_embeds} + noise_pred = self.unet( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + if do_classifier_free_guidance: + noise_pred, variance_pred = noise_pred.split(latents.shape[1], dim=1) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + _, variance_pred_text = variance_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, variance_pred_text], dim=1) + + if not ( + hasattr(self.scheduler.config, "variance_type") + and self.scheduler.config.variance_type in ["learned", "learned_range"] + ): + noise_pred, _ = noise_pred.split(latents.shape[1], dim=1) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t, + latents, + generator=generator, + ).prev_sample + + # post-processing + image = self.movq.decode(latents, force_not_quantize=True)["sample"] + + if output_type not in ["pt", "np", "pil"]: + raise ValueError(f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}") + + if output_type in ["np", "pil"]: + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py new file mode 100644 index 000000000000..a0208d5858b1 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py @@ -0,0 +1,578 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection + +from ...models import PriorTransformer +from ...pipelines import DiffusionPipeline +from ...schedulers import UnCLIPScheduler +from ...utils import ( + BaseOutput, + is_accelerate_available, + logging, + randn_tensor, + replace_example_docstring, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPipeline, KandinskyPriorPipeline + >>> import torch + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained("kandinsky-community/kandinsky-2-1-prior") + >>> pipe_prior.to("cuda") + + >>> prompt = "red cat, 4k photo" + >>> out = pipe_prior(prompt) + >>> image_emb = out.image_embeds + >>> negative_image_emb = out.negative_image_embeds + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1") + >>> pipe.to("cuda") + + >>> image = pipe( + ... prompt, + ... image_embeds=image_emb, + ... negative_image_embeds=negative_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=100, + ... ).images + + >>> image[0].save("cat.png") + ``` +""" + +EXAMPLE_INTERPOLATE_DOC_STRING = """ + Examples: + ```py + >>> from diffusers import KandinskyPriorPipeline, KandinskyPipeline + >>> from diffusers.utils import load_image + >>> import PIL + + >>> import torch + >>> from torchvision import transforms + + >>> pipe_prior = KandinskyPriorPipeline.from_pretrained( + ... "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ... ) + >>> pipe_prior.to("cuda") + + >>> img1 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/cat.png" + ... ) + + >>> img2 = load_image( + ... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + ... "/kandinsky/starry_night.jpeg" + ... ) + + >>> images_texts = ["a cat", img1, img2] + >>> weights = [0.3, 0.3, 0.4] + >>> image_emb, zero_image_emb = pipe_prior.interpolate(images_texts, weights) + + >>> pipe = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) + >>> pipe.to("cuda") + + >>> image = pipe( + ... "", + ... image_embeds=image_emb, + ... negative_image_embeds=zero_image_emb, + ... height=768, + ... width=768, + ... num_inference_steps=150, + ... ).images[0] + + >>> image.save("starry_cat.png") + ``` +""" + + +@dataclass +class KandinskyPriorPipelineOutput(BaseOutput): + """ + Output class for KandinskyPriorPipeline. + + Args: + image_embeds (`torch.FloatTensor`) + clip image embeddings for text prompt + negative_image_embeds (`List[PIL.Image.Image]` or `np.ndarray`) + clip image embeddings for unconditional tokens + """ + + image_embeds: Union[torch.FloatTensor, np.ndarray] + negative_image_embeds: Union[torch.FloatTensor, np.ndarray] + + +class KandinskyPriorPipeline(DiffusionPipeline): + """ + Pipeline for generating image prior for Kandinsky + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + image_encoder ([`CLIPVisionModelWithProjection`]): + Frozen image-encoder. + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + scheduler ([`UnCLIPScheduler`]): + A scheduler to be used in combination with `prior` to generate image embedding. + """ + + def __init__( + self, + prior: PriorTransformer, + image_encoder: CLIPVisionModelWithProjection, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + scheduler: UnCLIPScheduler, + image_processor: CLIPImageProcessor, + ): + super().__init__() + + self.register_modules( + prior=prior, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + image_encoder=image_encoder, + image_processor=image_processor, + ) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INTERPOLATE_DOC_STRING) + def interpolate( + self, + images_and_prompts: List[Union[str, PIL.Image.Image, torch.FloatTensor]], + weights: List[float], + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + negative_prior_prompt: Optional[str] = None, + negative_prompt: Union[str] = "", + guidance_scale: float = 4.0, + device=None, + ): + """ + Function invoked when using the prior pipeline for interpolation. + + Args: + images_and_prompts (`List[Union[str, PIL.Image.Image, torch.FloatTensor]]`): + list of prompts and images to guide the image generation. + weights: (`List[float]`): + list of weights for each condition in `images_and_prompts` + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + negative_prior_prompt (`str`, *optional*): + The prompt not to guide the prior diffusion process. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. Ignored when not using guidance (i.e., ignored if + `guidance_scale` is less than `1`). + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + device = device or self.device + + if len(images_and_prompts) != len(weights): + raise ValueError( + f"`images_and_prompts` contains {len(images_and_prompts)} items and `weights` contains {len(weights)} items - they should be lists of same length" + ) + + image_embeddings = [] + for cond, weight in zip(images_and_prompts, weights): + if isinstance(cond, str): + image_emb = self( + cond, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ).image_embeds + + elif isinstance(cond, (PIL.Image.Image, torch.Tensor)): + if isinstance(cond, PIL.Image.Image): + cond = ( + self.image_processor(cond, return_tensors="pt") + .pixel_values[0] + .unsqueeze(0) + .to(dtype=self.image_encoder.dtype, device=device) + ) + + image_emb = self.image_encoder(cond)["image_embeds"] + + else: + raise ValueError( + f"`images_and_prompts` can only contains elements to be of type `str`, `PIL.Image.Image` or `torch.Tensor` but is {type(cond)}" + ) + + image_embeddings.append(image_emb * weight) + + image_emb = torch.cat(image_embeddings).sum(dim=0, keepdim=True) + + out_zero = self( + negative_prompt, + num_inference_steps=num_inference_steps, + num_images_per_prompt=num_images_per_prompt, + generator=generator, + latents=latents, + negative_prompt=negative_prior_prompt, + guidance_scale=guidance_scale, + ) + zero_image_emb = out_zero.negative_image_embeds if negative_prompt == "" else out_zero.image_embeds + + return KandinskyPriorPipelineOutput(image_embeds=image_emb, negative_image_embeds=zero_image_emb) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def get_zero_embed(self, batch_size=1, device=None): + device = device or self.device + zero_img = torch.zeros(1, 3, self.image_encoder.config.image_size, self.image_encoder.config.image_size).to( + device=device, dtype=self.image_encoder.dtype + ) + zero_image_emb = self.image_encoder(zero_img)["image_embeds"] + zero_image_emb = zero_image_emb.repeat(batch_size, 1) + return zero_image_emb + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + models = [ + self.image_encoder, + self.text_encoder, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if self.device != torch.device("meta") or not hasattr(self.text_encoder, "_hf_hook"): + return self.device + for module in self.text_encoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + ): + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + prompt_embeds = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + negative_prompt_embeds_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + negative_prompt_embeds = negative_prompt_embeds_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return prompt_embeds, text_encoder_hidden_states, text_mask + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + num_inference_steps: int = 25, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + guidance_scale: float = 4.0, + output_type: Optional[str] = "pt", # pt only + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generate image. Choose between: `"np"` (`np.array`) or `"pt"` + (`torch.Tensor`). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + + Examples: + + Returns: + [`KandinskyPriorPipelineOutput`] or `tuple` + """ + + if isinstance(prompt, str): + prompt = [prompt] + elif not isinstance(prompt, list): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif not isinstance(negative_prompt, list) and negative_prompt is not None: + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + # if the negative prompt is defined we double the batch size to + # directly retrieve the negative prompt embedding + if negative_prompt is not None: + prompt = prompt + negative_prompt + negative_prompt = 2 * negative_prompt + + device = self._execution_device + + batch_size = len(prompt) + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = guidance_scale > 1.0 + prompt_embeds, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # prior + self.scheduler.set_timesteps(num_inference_steps, device=device) + prior_timesteps_tensor = self.scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + latents = self.prepare_latents( + (batch_size, embedding_dim), + prompt_embeds.dtype, + device, + generator, + latents, + self.scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=prompt_embeds, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + latents = self.scheduler.step( + predicted_image_embedding, + timestep=t, + sample=latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + latents = self.prior.post_process_latents(latents) + + image_embeddings = latents + + # if negative prompt has been defined, we retrieve split the image embedding into two + if negative_prompt is None: + zero_embeds = self.get_zero_embed(latents.shape[0], device=latents.device) + else: + image_embeddings, zero_embeds = image_embeddings.chunk(2) + + if output_type not in ["pt", "np"]: + raise ValueError(f"Only the output types `pt` and `np` are supported not output_type={output_type}") + + if output_type == "np": + image_embeddings = image_embeddings.cpu().numpy() + zero_embeds = zero_embeds.cpu().numpy() + + if not return_dict: + return (image_embeddings, zero_embeds) + + return KandinskyPriorPipelineOutput(image_embeds=image_embeddings, negative_image_embeds=zero_embeds) diff --git a/src/diffusers/pipelines/kandinsky/text_encoder.py b/src/diffusers/pipelines/kandinsky/text_encoder.py new file mode 100644 index 000000000000..caa0029f00ca --- /dev/null +++ b/src/diffusers/pipelines/kandinsky/text_encoder.py @@ -0,0 +1,27 @@ +import torch +from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel + + +class MCLIPConfig(XLMRobertaConfig): + model_type = "M-CLIP" + + def __init__(self, transformerDimSize=1024, imageDimSize=768, **kwargs): + self.transformerDimensions = transformerDimSize + self.numDims = imageDimSize + super().__init__(**kwargs) + + +class MultilingualCLIP(PreTrainedModel): + config_class = MCLIPConfig + + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + self.transformer = XLMRobertaModel(config) + self.LinearTransformation = torch.nn.Linear( + in_features=config.transformerDimensions, out_features=config.numDims + ) + + def forward(self, input_ids, attention_mask): + embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0] + embs2 = (embs * attention_mask.unsqueeze(2)).sum(dim=1) / attention_mask.sum(dim=1)[:, None] + return self.LinearTransformation(embs2), embs diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index ca0a90a5b5ca..c8f3e8a9ee11 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -22,6 +23,7 @@ from diffusers.utils import is_accelerate_available +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler from ...utils import logging, randn_tensor @@ -184,6 +186,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): @@ -226,13 +229,17 @@ def _execution_device(self): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -255,8 +262,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -316,17 +328,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -355,6 +357,21 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + return image_latents + def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance): dtype = next(self.image_encoder.parameters()).dtype @@ -560,15 +577,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 11. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None - # 12. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 13. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 8c028b64a8c8..ed95163087a8 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -30,6 +30,7 @@ import torch from huggingface_hub import hf_hub_download, model_info, snapshot_download from packaging import version +from requests.exceptions import HTTPError from tqdm.auto import tqdm import diffusers @@ -296,8 +297,7 @@ def maybe_raise_or_warn( if not issubclass(model_cls, expected_class_obj): raise ValueError( - f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" - f" {expected_class_obj}" + f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}" ) else: logger.warning( @@ -355,6 +355,9 @@ def load_sub_model( provider: Any, sess_options: Any, device_map: Optional[Union[Dict[str, torch.device], str]], + max_memory: Optional[Dict[Union[int, str], Union[int, str]]], + offload_folder: Optional[Union[str, os.PathLike]], + offload_state_dict: bool, model_variants: Dict[str, str], name: str, from_flax: bool, @@ -417,6 +420,9 @@ def load_sub_model( # This makes sure that the weights won't be initialized which significantly speeds up loading. if is_diffusers_model or is_transformers_model: loading_kwargs["device_map"] = device_map + loading_kwargs["max_memory"] = max_memory + loading_kwargs["offload_folder"] = offload_folder + loading_kwargs["offload_state_dict"] = offload_state_dict loading_kwargs["variant"] = model_variants.pop(name, None) if from_flax: loading_kwargs["from_flax"] = True @@ -479,25 +485,31 @@ def register_modules(self, **kwargs): if module is None: register_dict = {name: (None, None)} else: - # register the original module, not the dynamo compiled one + # register the config from the original module, not the dynamo compiled one if is_compiled_module(module): - module = module._orig_mod + not_compiled_module = module._orig_mod + else: + not_compiled_module = module - library = module.__module__.split(".")[0] + library = not_compiled_module.__module__.split(".")[0] # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None - path = module.__module__.split(".") + module_path_items = not_compiled_module.__module__.split(".") + pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None + + path = not_compiled_module.__module__.split(".") is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) # if library is not in LOADABLE_CLASSES, then it is a custom module. # Or if it's a pipeline module, then the module is inside the pipeline # folder so we set the library to module name. - if library not in LOADABLE_CLASSES or is_pipeline_module: + if is_pipeline_module: library = pipeline_dir + elif library not in LOADABLE_CLASSES: + library = not_compiled_module.__module__ # retrieve class_name - class_name = module.__class__.__name__ + class_name = not_compiled_module.__class__.__name__ register_dict = {name: (library, class_name)} @@ -531,7 +543,7 @@ def save_pretrained( """ Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading - method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method. + method. The pipeline can easily be re-loaded using the [`~DiffusionPipeline.from_pretrained`] class method. Arguments: save_directory (`str` or `os.PathLike`): @@ -809,15 +821,24 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU + RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to + `True` when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): Speed up model loading by not initializing the weights and only loading the pre-trained weights. This also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, setting this argument to `True` will raise an error. - use_safetensors (`bool`, *optional* ): - If set to `True`, the pipeline will be loaded from `safetensors` weights. If set to `None` (the - default). The pipeline will load using `safetensors` if the safetensors weights are available *and* if - `safetensors` is installed. If the to `False` the pipeline will *not* use `safetensors`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the pipeline will load the `safetensors` weights if they're available **and** if the + `safetensors` library is installed. If set to `True`, the pipeline will forcibly load the models from + `safetensors` weights. If set to `False` the pipeline will *not* use `safetensors`. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the specific pipeline class. The overwritten components are then directly passed to the pipelines @@ -874,6 +895,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P provider = kwargs.pop("provider", None) sess_options = kwargs.pop("sess_options", None) device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False) @@ -1022,7 +1046,7 @@ def load_module(name, value): # 6.2 Define all importable classes is_pipeline_module = hasattr(pipelines, library_name) - importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name] + importable_classes = ALL_IMPORTABLE_CLASSES loaded_sub_model = None # 6.3 Use passed sub model or load class_name from library_name @@ -1047,6 +1071,9 @@ def load_module(name, value): provider=provider, sess_options=sess_options, device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, model_variants=model_variants, name=name, from_flax=from_flax, @@ -1075,8 +1102,8 @@ def load_module(name, value): return_cached_folder = kwargs.pop("return_cached_folder", False) if return_cached_folder: - message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.17.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`." - deprecate("return_cached_folder", "0.17.0", message) + message = f"Passing `return_cached_folder=True` is deprecated and will be removed in `diffusers=0.18.0`. Please do the following instead: \n 1. Load the cached_folder via `cached_folder={cls}.download({pretrained_model_name_or_path})`. \n 2. Load the pipeline by loading from the cached folder: `pipeline={cls}.from_pretrained(cached_folder)`." + deprecate("return_cached_folder", "0.18.0", message) return model, cached_folder return model @@ -1084,95 +1111,78 @@ def load_module(name, value): @classmethod def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: r""" - Download and cache a PyTorch diffusion pipeline from pre-trained pipeline weights. + Download and cache a PyTorch diffusion pipeline from pretrained pipeline weights. Parameters: - pretrained_model_name (`str` or `os.PathLike`, *optional*): - Should be a string, the *repo id* of a pretrained pipeline hosted inside a model repo on - https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like - `CompVis/ldm-text2im-large-256`. + pretrained_model_name (`str` or `os.PathLike`, *optional*): + A string, the repository id (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline + hosted on the Hub. custom_pipeline (`str`, *optional*): - - - - This is an experimental feature and is likely to change in the future. - - - Can be either: - - A string, the *repo id* of a custom pipeline hosted inside a model repo on - https://huggingface.co/. Valid repo ids have to be located under a user or organization name, - like `hf-internal-testing/diffusers-dummy-pipeline`. - - - - It is required that the model repo has a file, called `pipeline.py` that defines the custom - pipeline. - - + - A string, the repository id (for example `CompVis/ldm-text2im-large-256`) of a pretrained + pipeline hosted on the Hub. The repository must contain a file called `pipeline.py` that defines + the custom pipeline. - A string, the *file name* of a community pipeline hosted on GitHub under - https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to - match exactly the file name without `.py` located under the above link, *e.g.* - `clip_guided_stable_diffusion`. + [Community](https://github.com/huggingface/diffusers/tree/main/examples/community). Valid file + names must match the file name and not the pipeline script (`clip_guided_stable_diffusion` + instead of `clip_guided_stable_diffusion.py`). Community pipelines are always loaded from the + current `main` branch of GitHub. - + - A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory + must contain a file called `pipeline.py` that defines the custom pipeline. - Community pipelines are always loaded from the current `main` branch of GitHub. - - - - - A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`. - - + - It is required that the directory has a file, called `pipeline.py` that defines the custom - pipeline. + 🧪 This is an experimental feature and may change in the future. - + - For more information on how to load and create custom pipelines, please have a look at [Loading and - Adding Custom - Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview) + For more information on how to load and create custom pipelines, take a look at [How to contribute a + community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. + Whether or not to resume downloading the model weights and configuration files. If set to False, any + incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. output_loading_info(`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). + Whether to only load local model weights and configuration files or not. If set to True, the model + won’t be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `huggingface-cli login` (stored in `~/.huggingface`). + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub): The specific model version to use. It can be a branch name, a tag name, or a commit id similar to `revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a custom pipeline from GitHub. mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. specify the folder name here. + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. + Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. + + Returns: + `os.PathLike`: + A path to the downloaded pipeline. - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models) + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. @@ -1204,6 +1214,17 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: allow_patterns = None ignore_patterns = None + if not local_files_only: + try: + info = model_info( + pretrained_model_name, + use_auth_token=use_auth_token, + revision=revision, + ) + except HTTPError as e: + logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") + local_files_only = True + if not local_files_only: config_file = hf_hub_download( pretrained_model_name, @@ -1215,11 +1236,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: resume_download=resume_download, use_auth_token=use_auth_token, ) - info = model_info( - pretrained_model_name, - use_auth_token=use_auth_token, - revision=revision, - ) config_dict = cls._dict_from_json_file(config_file) @@ -1238,7 +1254,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # if the whole pipeline is cached we don't have to ping the Hub if revision in DEPRECATED_REVISION_ARGS and version.parse( version.parse(__version__).base_version - ) >= version.parse("0.17.0"): + ) >= version.parse("0.18.0"): warn_deprecated_model_variant( pretrained_model_name, use_auth_token, variant, revision, model_filenames ) @@ -1250,7 +1266,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: # allow all patterns from non-model folders # this enables downloading schedulers, tokenizers, ... - allow_patterns += [os.path.join(k, "*") for k in folder_names if k not in model_folder_names] + allow_patterns += [f"{k}/*" for k in folder_names if k not in model_folder_names] # also allow downloading config.json files with the model allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] diff --git a/src/diffusers/pipelines/repaint/pipeline_repaint.py b/src/diffusers/pipelines/repaint/pipeline_repaint.py index f4914c46db51..d2aa1d4f1f77 100644 --- a/src/diffusers/pipelines/repaint/pipeline_repaint.py +++ b/src/diffusers/pipelines/repaint/pipeline_repaint.py @@ -13,6 +13,7 @@ # limitations under the License. +import warnings from typing import List, Optional, Tuple, Union import numpy as np @@ -30,6 +31,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def _preprocess_image(image: Union[List, PIL.Image.Image, torch.Tensor]): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 60a6f1e70f4a..3ff7b8ee460b 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -33,7 +33,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): unet: UNet2DModel scheduler: ScoreSdeVeScheduler - def __init__(self, unet: UNet2DModel, scheduler: DiffusionPipeline): + def __init__(self, unet: UNet2DModel, scheduler: ScoreSdeVeScheduler): super().__init__() self.register_modules(unet=unet, scheduler=scheduler) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 3d5374875d12..911a5018de18 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -1,15 +1,17 @@ import inspect +import warnings from itertools import repeat from typing import Callable, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel -from ...pipeline_utils import DiffusionPipeline from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging, randn_tensor +from ..pipeline_utils import DiffusionPipeline from . import SemanticStableDiffusionPipelineOutput @@ -129,12 +131,33 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -681,20 +704,19 @@ def __call__( callback(i, t, latents) # 8. Post-processing - image = self.decode_latents(latents) - - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( - self.device - ) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) else: + image = latents has_nsfw_concept = None - if output_type == "pil": - image = self.numpy_to_pil(image) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md index be4c5d942b2e..66df9a811afb 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -61,7 +61,7 @@ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -image = pipe(prompt).sample[0] +image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -80,7 +80,7 @@ pipe = StableDiffusionPipeline.from_pretrained( ).to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -image = pipe(prompt).sample[0] +image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -99,7 +99,7 @@ pipe = StableDiffusionPipeline.from_pretrained( ).to("cuda") prompt = "a photo of an astronaut riding a horse on mars" -image = pipe(prompt).sample[0] +image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 6bc2b58b5fef..f39ae67a9aff 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -45,7 +45,6 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_cycle_diffusion import CycleDiffusionPipeline from .pipeline_stable_diffusion import StableDiffusionPipeline from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline - from .pipeline_stable_diffusion_controlnet import StableDiffusionControlNetPipeline from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy @@ -75,10 +74,12 @@ class StableDiffusionPipelineOutput(BaseOutput): except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import ( StableDiffusionDepth2ImgPipeline, + StableDiffusionDiffEditPipeline, StableDiffusionPix2PixZeroPipeline, ) else: from .pipeline_stable_diffusion_depth2img import StableDiffusionDepth2ImgPipeline + from .pipeline_stable_diffusion_diffedit import StableDiffusionDiffEditPipeline from .pipeline_stable_diffusion_pix2pix_zero import StableDiffusionPix2PixZeroPipeline @@ -128,7 +129,6 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput): from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline - from .pipeline_flax_stable_diffusion_controlnet import FlaxStableDiffusionControlNetPipeline from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 5961636dd197..e59b91e486f5 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -140,17 +140,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) @@ -204,8 +204,12 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: + is_attn_weight = "proj_attn.weight" in new_path or ("attentions" in new_path and "to_" in new_path) + shape = old_checkpoint[path["old"]].shape + if is_attn_weight and len(shape) == 3: checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + elif is_attn_weight and len(shape) == 4: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] @@ -335,41 +339,46 @@ def create_ldm_bert_config(original_config): return config -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): +def convert_ldm_unet_checkpoint( + checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False +): """ Takes a state dict and a config, and returns a converted checkpoint. """ - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." + if skip_extract_state_dict: + unet_state_dict = checkpoint else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - print(f"Checkpoint {path} has both EMA and non-EMA weights.") - print( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: + # extract state_dict for UNet + unet_state_dict = {} + keys = list(checkpoint.keys()) + + if controlnet: + unet_key = "control_model." + else: + unet_key = "model.diffusion_model." + + # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA + if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: + print(f"Checkpoint {path} has both EMA and non-EMA weights.") print( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" + " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." ) + for key in keys: + if key.startswith("model.diffusion_model"): + flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) + else: + if sum(k.startswith("model_ema") for k in keys) > 100: + print( + "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" + " weights (usually better for inference), please make sure to add the `--extract_ema` flag." + ) - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) new_checkpoint = {} @@ -723,8 +732,8 @@ def _copy_layers(hf_layers, pt_layers): return hf_model -def convert_ldm_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") +def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False): + text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) keys = list(checkpoint.keys()) @@ -952,17 +961,42 @@ def stable_unclip_image_noising_components( def convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=None, + cross_attention_dim=None, ): ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) ctrlnet_config["upcast_attention"] = upcast_attention ctrlnet_config.pop("sample_size") + if use_linear_projection is not None: + ctrlnet_config["use_linear_projection"] = use_linear_projection + + if cross_attention_dim is not None: + ctrlnet_config["cross_attention_dim"] = cross_attention_dim + controlnet_model = ControlNetModel(**ctrlnet_config) + # Some controlnet ckpt files are distributed independently from the rest of the + # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/ + if "time_embed.0.weight" in checkpoint: + skip_extract_state_dict = True + else: + skip_extract_state_dict = False + converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True + checkpoint, + ctrlnet_config, + path=checkpoint_path, + extract_ema=extract_ema, + controlnet=True, + skip_extract_state_dict=skip_extract_state_dict, ) controlnet_model.load_state_dict(converted_ctrl_checkpoint) @@ -988,6 +1022,7 @@ def download_from_original_stable_diffusion_ckpt( controlnet: Optional[bool] = None, load_safety_checker: bool = True, pipeline_class: DiffusionPipeline = None, + local_files_only=False, ) -> DiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -1033,6 +1068,8 @@ def download_from_original_stable_diffusion_ckpt( Whether to load the safety checker or not. Defaults to `True`. pipeline_class (`str`, *optional*, defaults to `None`): The pipeline class to use. Pass `None` to determine automatically. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). return: A StableDiffusionPipeline object representing the passed-in `.ckpt`/`.safetensors` file. """ @@ -1288,7 +1325,7 @@ def download_from_original_stable_diffusion_ckpt( feature_extractor=feature_extractor, ) elif model_type == "FrozenCLIPEmbedder": - text_model = convert_ldm_clip_checkpoint(checkpoint) + text_model = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") if load_safety_checker: @@ -1337,6 +1374,8 @@ def download_controlnet_from_original_ckpt( upcast_attention: Optional[bool] = None, device: str = None, from_safetensors: bool = False, + use_linear_projection: Optional[bool] = None, + cross_attention_dim: Optional[bool] = None, ) -> DiffusionPipeline: if not is_omegaconf_available(): raise ValueError(BACKENDS_MAPPING["omegaconf"][1]) @@ -1374,7 +1413,14 @@ def download_controlnet_from_original_ckpt( raise ValueError("`control_stage_config` not present in original config") controlnet_model = convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema + checkpoint, + original_config, + checkpoint_path, + image_size, + upcast_attention, + extract_ema, + use_linear_projection=use_linear_projection, + cross_attention_dim=cross_attention_dim, ) return controlnet_model diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py index e2accb6d2d2a..b8360f512405 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py @@ -13,7 +13,8 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +import warnings +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -24,7 +25,8 @@ from diffusers.utils import is_accelerate_available, is_accelerate_version from ...configuration_utils import FrozenDict -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import DDIMScheduler from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor @@ -38,6 +40,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -119,7 +126,7 @@ def compute_noise(scheduler, prev_latents, latents, timestep, noise_pred, eta): return noise -class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided image to image generation using Stable Diffusion. @@ -220,6 +227,8 @@ def __init__( safety_checker=safety_checker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -306,6 +315,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -330,7 +340,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -387,7 +404,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -504,19 +521,28 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -536,21 +562,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = image.shape[0] - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) - init_latents = self.vae.config.scaling_factor * init_latents + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -586,7 +617,14 @@ def __call__( self, prompt: Union[str, List[str]], source_prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, @@ -599,6 +637,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -606,9 +645,10 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image generation. - image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of @@ -654,6 +694,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -674,19 +718,23 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, num_images_per_prompt, do_classifier_free_guidance, prompt_embeds=prompt_embeds, + lora_scale=text_encoder_lora_scale, ) source_prompt_embeds = self._encode_prompt( source_prompt, device, num_images_per_prompt, do_classifier_free_guidance, None ) # 4. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -733,7 +781,10 @@ def __call__( dim=0, ) concat_noise_pred = self.unet( - concat_latent_model_input, t, encoder_hidden_states=concat_prompt_embeds + concat_latent_model_input, + t, + cross_attention_kwargs=cross_attention_kwargs, + encoder_hidden_states=concat_prompt_embeds, ).sample # perform guidance @@ -770,14 +821,19 @@ def __call__( callback(i, t, latents) # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py index 7035242a0cda..bec2424ece4d 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_controlnet.py @@ -12,526 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings -from functools import partial -from typing import Dict, List, Optional, Union +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works -import jax -import jax.numpy as jnp -import numpy as np -from flax.core.frozen_dict import FrozenDict -from flax.jax_utils import unreplicate -from flax.training.common_utils import shard -from PIL import Image -from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel +from ...utils import deprecate +from ..controlnet.pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline # noqa: F401 -from ...models import FlaxAutoencoderKL, FlaxControlNetModel, FlaxUNet2DConditionModel -from ...schedulers import ( - FlaxDDIMScheduler, - FlaxDPMSolverMultistepScheduler, - FlaxLMSDiscreteScheduler, - FlaxPNDMScheduler, -) -from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring -from ..pipeline_flax_utils import FlaxDiffusionPipeline -from . import FlaxStableDiffusionPipelineOutput -from .safety_checker_flax import FlaxStableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -# Set to True to use python for loop instead of jax.fori_loop for easier debugging -DEBUG = False - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> import jax - >>> import numpy as np - >>> import jax.numpy as jnp - >>> from flax.jax_utils import replicate - >>> from flax.training.common_utils import shard - >>> from diffusers.utils import load_image - >>> from PIL import Image - >>> from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel - - - >>> def image_grid(imgs, rows, cols): - ... w, h = imgs[0].size - ... grid = Image.new("RGB", size=(cols * w, rows * h)) - ... for i, img in enumerate(imgs): - ... grid.paste(img, box=(i % cols * w, i // cols * h)) - ... return grid - - - >>> def create_key(seed=0): - ... return jax.random.PRNGKey(seed) - - - >>> rng = create_key(0) - - >>> # get canny image - >>> canny_image = load_image( - ... "https://huggingface.co/datasets/YiYiXu/test-doc-assets/resolve/main/blog_post_cell_10_output_0.jpeg" - ... ) - - >>> prompts = "best quality, extremely detailed" - >>> negative_prompts = "monochrome, lowres, bad anatomy, worst quality, low quality" - - >>> # load control net and stable diffusion v1-5 - >>> controlnet, controlnet_params = FlaxControlNetModel.from_pretrained( - ... "lllyasviel/sd-controlnet-canny", from_pt=True, dtype=jnp.float32 - ... ) - >>> pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, revision="flax", dtype=jnp.float32 - ... ) - >>> params["controlnet"] = controlnet_params - - >>> num_samples = jax.device_count() - >>> rng = jax.random.split(rng, jax.device_count()) - - >>> prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples) - >>> negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples) - >>> processed_image = pipe.prepare_image_inputs([canny_image] * num_samples) - - >>> p_params = replicate(params) - >>> prompt_ids = shard(prompt_ids) - >>> negative_prompt_ids = shard(negative_prompt_ids) - >>> processed_image = shard(processed_image) - - >>> output = pipe( - ... prompt_ids=prompt_ids, - ... image=processed_image, - ... params=p_params, - ... prng_seed=rng, - ... num_inference_steps=50, - ... neg_prompt_ids=negative_prompt_ids, - ... jit=True, - ... ).images - - >>> output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:]))) - >>> output_images = image_grid(output_images, num_samples // 4, 4) - >>> output_images.save("generated_image.png") - ``` -""" - - -class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet Guidance. - - This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - Args: - vae ([`FlaxAutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`FlaxCLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.FlaxCLIPTextModel), - specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - controlnet ([`FlaxControlNetModel`]: - Provides additional conditioning to the unet during the denoising process. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], [`FlaxPNDMScheduler`], or - [`FlaxDPMSolverMultistepScheduler`]. - safety_checker ([`FlaxStableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - - def __init__( - self, - vae: FlaxAutoencoderKL, - text_encoder: FlaxCLIPTextModel, - tokenizer: CLIPTokenizer, - unet: FlaxUNet2DConditionModel, - controlnet: FlaxControlNetModel, - scheduler: Union[ - FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler - ], - safety_checker: FlaxStableDiffusionSafetyChecker, - feature_extractor: CLIPFeatureExtractor, - dtype: jnp.dtype = jnp.float32, - ): - super().__init__() - self.dtype = dtype - - if safety_checker is None: - logger.warn( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - - def prepare_text_inputs(self, prompt: Union[str, List[str]]): - if not isinstance(prompt, (str, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - text_input = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - - return text_input.input_ids - - def prepare_image_inputs(self, image: Union[Image.Image, List[Image.Image]]): - if not isinstance(image, (Image.Image, list)): - raise ValueError(f"image has to be of type `PIL.Image.Image` or list but is {type(image)}") - - if isinstance(image, Image.Image): - image = [image] - - processed_images = jnp.concatenate([preprocess(img, jnp.float32) for img in image]) - - return processed_images - - def _get_has_nsfw_concepts(self, features, params): - has_nsfw_concepts = self.safety_checker(features, params) - return has_nsfw_concepts - - def _run_safety_checker(self, images, safety_model_params, jit=False): - # safety_model_params should already be replicated when jit is True - pil_images = [Image.fromarray(image) for image in images] - features = self.feature_extractor(pil_images, return_tensors="np").pixel_values - - if jit: - features = shard(features) - has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) - has_nsfw_concepts = unshard(has_nsfw_concepts) - safety_model_params = unreplicate(safety_model_params) - else: - has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) - - images_was_copied = False - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept: - if not images_was_copied: - images_was_copied = True - images = images.copy() - - images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image - - if any(has_nsfw_concepts): - warnings.warn( - "Potential NSFW content was detected in one or more images. A black image will be returned" - " instead. Try again with a different prompt and/or seed." - ) - - return images, has_nsfw_concepts - def _generate( - self, - prompt_ids: jnp.array, - image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int, - guidance_scale: float, - latents: Optional[jnp.array] = None, - neg_prompt_ids: Optional[jnp.array] = None, - controlnet_conditioning_scale: float = 1.0, - ): - height, width = image.shape[-2:] - if height % 64 != 0 or width % 64 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 64 but are {height} and {width}.") - - # get prompt text embeddings - prompt_embeds = self.text_encoder(prompt_ids, params=params["text_encoder"])[0] - - # TODO: currently it is assumed `do_classifier_free_guidance = guidance_scale > 1.0` - # implement this conditional `do_classifier_free_guidance = guidance_scale > 1.0` - batch_size = prompt_ids.shape[0] - - max_length = prompt_ids.shape[-1] - - if neg_prompt_ids is None: - uncond_input = self.tokenizer( - [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="np" - ).input_ids - else: - uncond_input = neg_prompt_ids - negative_prompt_embeds = self.text_encoder(uncond_input, params=params["text_encoder"])[0] - context = jnp.concatenate([negative_prompt_embeds, prompt_embeds]) - - image = jnp.concatenate([image] * 2) - - latents_shape = ( - batch_size, - self.unet.config.in_channels, - height // self.vae_scale_factor, - width // self.vae_scale_factor, - ) - if latents is None: - latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) - else: - if latents.shape != latents_shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - def loop_body(step, args): - latents, scheduler_state = args - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - latents_input = jnp.concatenate([latents] * 2) - - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - timestep = jnp.broadcast_to(t, latents_input.shape[0]) - - latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t) - - down_block_res_samples, mid_block_res_sample = self.controlnet.apply( - {"params": params["controlnet"]}, - jnp.array(latents_input), - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - return_dict=False, - ) - - # predict the noise residual - noise_pred = self.unet.apply( - {"params": params["unet"]}, - jnp.array(latents_input), - jnp.array(timestep, dtype=jnp.int32), - encoder_hidden_states=context, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ).sample - - # perform guidance - noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0) - noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents, scheduler_state - - scheduler_state = self.scheduler.set_timesteps( - params["scheduler"], num_inference_steps=num_inference_steps, shape=latents_shape - ) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * params["scheduler"].init_noise_sigma - - if DEBUG: - # run with python for loop - for i in range(num_inference_steps): - latents, scheduler_state = loop_body(i, (latents, scheduler_state)) - else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) - - # scale and decode the image latents with vae - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample - - image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) - return image - - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt_ids: jnp.array, - image: jnp.array, - params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, - num_inference_steps: int = 50, - guidance_scale: Union[float, jnp.array] = 7.5, - latents: jnp.array = None, - neg_prompt_ids: jnp.array = None, - controlnet_conditioning_scale: Union[float, jnp.array] = 1.0, - return_dict: bool = True, - jit: bool = False, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt_ids (`jnp.array`): - The prompt or prompts to guide the image generation. - image (`jnp.array`): - Array representing the ControlNet input condition. ControlNet use this input condition to generate - guidance to Unet. - params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights - prng_seed (`jax.random.KeyArray` or `jax.Array`): Array containing random number generator key - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - latents (`jnp.array`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - controlnet_conditioning_scale (`float` or `jnp.array`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of - a plain tuple. - jit (`bool`, defaults to `False`): - Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument - exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. - """ - - height, width = image.shape[-2:] - - if isinstance(guidance_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. - guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - guidance_scale = guidance_scale[:, None] - - if isinstance(controlnet_conditioning_scale, float): - # Convert to a tensor so each device gets a copy. Follow the prompt_ids for - # shape information, as they may be sharded (when `jit` is `True`), or not. - controlnet_conditioning_scale = jnp.array([controlnet_conditioning_scale] * prompt_ids.shape[0]) - if len(prompt_ids.shape) > 2: - # Assume sharded - controlnet_conditioning_scale = controlnet_conditioning_scale[:, None] - - if jit: - images = _p_generate( - self, - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, - ) - else: - images = self._generate( - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, - ) - - if self.safety_checker is not None: - safety_params = params["safety_checker"] - images_uint8_casted = (images * 255).round().astype("uint8") - num_devices, batch_size = images.shape[:2] - - images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) - images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) - images = np.asarray(images) - - # block images - if any(has_nsfw_concept): - for i, is_nsfw in enumerate(has_nsfw_concept): - if is_nsfw: - images[i] = np.asarray(images_uint8_casted[i]) - - images = images.reshape(num_devices, batch_size, height, width, 3) - else: - images = np.asarray(images) - has_nsfw_concept = False - - if not return_dict: - return (images, has_nsfw_concept) - - return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) - - -# Static argnums are pipe, num_inference_steps. A change would trigger recompilation. -# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). -@partial( - jax.pmap, - in_axes=(None, 0, 0, 0, 0, None, 0, 0, 0, 0), - static_broadcasted_argnums=(0, 5), +deprecate( + "stable diffusion controlnet", + "0.22.0", + "Importing `FlaxStableDiffusionControlNetPipeline` from diffusers.pipelines.stable_diffusion.flax_pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import FlaxStableDiffusionControlNetPipeline` instead.", + standard_warn=False, + stacklevel=3, ) -def _p_generate( - pipe, - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, -): - return pipe._generate( - prompt_ids, - image, - params, - prng_seed, - num_inference_steps, - guidance_scale, - latents, - neg_prompt_ids, - controlnet_conditioning_scale, - ) - - -@partial(jax.pmap, static_broadcasted_argnums=(0,)) -def _p_get_has_nsfw_concepts(pipe, features, params): - return pipe._get_has_nsfw_concepts(features, params) - - -def unshard(x: jnp.ndarray): - # einops.rearrange(x, 'd b ... -> (d b) ...') - num_devices, batch_size = x.shape[:2] - rest = x.shape[2:] - return x.reshape(num_devices * batch_size, *rest) - - -def preprocess(image, dtype): - image = image.convert("RGB") - w, h = image.size - w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 - image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]) - image = jnp.array(image).astype(dtype) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - return image diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 67d3f44e6d4b..293ed7d981b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -33,6 +34,13 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess with 8->64 def preprocess(image): + warnings.warn( + ( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead" + ), + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 7347d70c4023..8368668ebea7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch @@ -20,6 +21,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -53,6 +55,20 @@ """ +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromCkptMixin): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -177,6 +193,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_vae_slicing(self): @@ -291,6 +308,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -315,7 +333,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -372,7 +397,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -429,18 +454,27 @@ def _encode_prompt( return prompt_embeds def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -548,6 +582,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -608,6 +643,11 @@ def __call__( A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. Examples: @@ -642,6 +682,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -650,6 +693,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps @@ -686,15 +730,20 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -702,24 +751,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: image = latents has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - # 10. Convert to PIL - image = self.numpy_to_pil(image) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] else: - # 8. Post-processing - image = self.decode_latents(latents) + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py index fba2a4e32f88..f76268463707 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py @@ -14,6 +14,7 @@ import inspect import math +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -21,7 +22,8 @@ from torch.nn import functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import KarrasDiffusionSchedulers @@ -228,6 +230,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -303,6 +306,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -327,7 +331,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -384,7 +395,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -442,19 +453,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -972,14 +992,19 @@ def __call__( callback(i, t, latents) # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 3bd7f82d7eb6..c7555e2ebad4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -12,1047 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import inspect -import os -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import PIL.Image -import torch -from torch import nn -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - -from ...loaders import TextualInversionLoaderMixin -from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel -from ...models.controlnet import ControlNetOutput -from ...models.modeling_utils import ModelMixin -from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - PIL_INTERPOLATION, - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, +# NOTE: This file is deprecated and will be removed in a future version. +# It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works +from ...utils import deprecate +from ..controlnet.multicontrolnet import MultiControlNetModel # noqa: F401 +from ..controlnet.pipeline_controlnet import StableDiffusionControlNetPipeline # noqa: F401 + + +deprecate( + "stable diffusion controlnet", + "0.22.0", + "Importing `StableDiffusionControlNetPipeline` or `MultiControlNetModel` from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet is deprecated. Please import `from diffusers import StableDiffusionControlNetPipeline` instead.", + standard_warn=False, + stacklevel=3, ) -from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -EXAMPLE_DOC_STRING = """ - Examples: - ```py - >>> # !pip install opencv-python transformers accelerate - >>> from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler - >>> from diffusers.utils import load_image - >>> import numpy as np - >>> import torch - - >>> import cv2 - >>> from PIL import Image - - >>> # download an image - >>> image = load_image( - ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png" - ... ) - >>> image = np.array(image) - - >>> # get canny image - >>> image = cv2.Canny(image, 100, 200) - >>> image = image[:, :, None] - >>> image = np.concatenate([image, image, image], axis=2) - >>> canny_image = Image.fromarray(image) - - >>> # load control net and stable diffusion v1-5 - >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) - >>> pipe = StableDiffusionControlNetPipeline.from_pretrained( - ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16 - ... ) - - >>> # speed up diffusion process with faster scheduler and memory optimization - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) - >>> # remove following line if xformers is not installed - >>> pipe.enable_xformers_memory_efficient_attention() - - >>> pipe.enable_model_cpu_offload() - - >>> # generate image - >>> generator = torch.manual_seed(0) - >>> image = pipe( - ... "futuristic-looking woman", num_inference_steps=20, generator=generator, image=canny_image - ... ).images[0] - ``` -""" - - -class MultiControlNetModel(ModelMixin): - r""" - Multiple `ControlNetModel` wrapper class for Multi-ControlNet - - This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be - compatible with `ControlNetModel`. - - Args: - controlnets (`List[ControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. You must set multiple - `ControlNetModel` as a list. - """ - - def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): - super().__init__() - self.nets = nn.ModuleList(controlnets) - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - controlnet_cond: List[torch.tensor], - conditioning_scale: List[float], - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guess_mode: bool = False, - return_dict: bool = True, - ) -> Union[ControlNetOutput, Tuple]: - for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): - down_samples, mid_sample = controlnet( - sample, - timestep, - encoder_hidden_states, - image, - scale, - class_labels, - timestep_cond, - attention_mask, - cross_attention_kwargs, - guess_mode, - return_dict, - ) - - # merge samples - if i == 0: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) - ] - mid_block_res_sample += mid_sample - - return down_block_res_samples, mid_block_res_sample - - -class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): - r""" - Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the - library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) - - In addition the pipeline inherits the following loading methods: - - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. - text_encoder ([`CLIPTextModel`]): - Frozen text-encoder. Stable Diffusion uses the text portion of - [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically - the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. - tokenizer (`CLIPTokenizer`): - Tokenizer of class - [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). - unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. - controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): - Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets - as a list, the outputs from each ControlNet are added together to create one combined additional - conditioning. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. - feature_extractor ([`CLIPImageProcessor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel], - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - requires_safety_checker: bool = True, - ): - super().__init__() - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." - ) - - if isinstance(controlnet, (list, tuple)): - controlnet = MultiControlNetModel(controlnet) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - controlnet=controlnet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing - def enable_vae_slicing(self): - r""" - Enable sliced VAE decoding. - - When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several - steps. This is useful to save some memory and allow larger batch sizes. - """ - self.vae.enable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing - def disable_vae_slicing(self): - r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_slicing() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. - - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - - def enable_sequential_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, - text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a - `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. - Note that offloading happens on a submodule basis. Memory savings are higher than with - `enable_model_cpu_offload`, but performance is lower. - """ - if is_accelerate_available(): - from accelerate import cpu_offload - else: - raise ImportError("Please install accelerate via `pip install accelerate`") - - device = torch.device(f"cuda:{gpu_id}") - - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: - cpu_offload(cpu_offloaded_model, device) - - if self.safety_checker is not None: - cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) - - def enable_model_cpu_offload(self, gpu_id=0): - r""" - Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared - to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` - method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with - `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. - """ - if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): - from accelerate import cpu_offload_with_hook - else: - raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") - - device = torch.device(f"cuda:{gpu_id}") - - hook = None - for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: - _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - - if self.safety_checker is not None: - # the safety checker can offload the vae again - _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) - - # control net hook has be manually offloaded as it alternates with unet - cpu_offload_with_hook(self.controlnet, device) - - # We'll offload the last model manually. - self.final_offload_hook = hook - - @property - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device - def _execution_device(self): - r""" - Returns the device on which the pipeline's models will be executed. After calling - `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module - hooks. - """ - if not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - """ - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) - prompt_embeds = prompt_embeds[0] - - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) - - # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance and negative_prompt_embeds is None: - uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = negative_prompt - - # textual inversion: procecss multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) - - max_length = prompt_embeds.shape[1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None - - negative_prompt_embeds = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - negative_prompt_embeds = negative_prompt_embeds[0] - - if do_classifier_free_guidance: - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - return prompt_embeds - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker - def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) - else: - has_nsfw_concept = None - return image, has_nsfw_concept - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents - def decode_latents(self, latents): - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - image, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - controlnet_conditioning_scale=1.0, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) - ): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." - ) - - # `prompt` needs more sophisticated handling when there are multiple - # conditionings. - if isinstance(self.controlnet, MultiControlNetModel): - if isinstance(prompt, list): - logger.warning( - f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}" - " prompts. The conditionings will be fixed across the prompts." - ) - - # Check `image` - if isinstance(self.controlnet, ControlNetModel): - self.check_image(image, prompt, prompt_embeds) - elif isinstance(self.controlnet, MultiControlNetModel): - if not isinstance(image, list): - raise TypeError("For multiple controlnets: `image` must be type `list`") - - # When `image` is a nested list: - # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) - elif any(isinstance(i, list) for i in image): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif len(image) != len(self.controlnet.nets): - raise ValueError( - "For multiple controlnets: `image` must have the same length as the number of controlnets." - ) - - for image_ in image: - self.check_image(image_, prompt, prompt_embeds) - else: - assert False - - # Check `controlnet_conditioning_scale` - if isinstance(self.controlnet, ControlNetModel): - if not isinstance(controlnet_conditioning_scale, float): - raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif isinstance(self.controlnet, MultiControlNetModel): - if isinstance(controlnet_conditioning_scale, list): - if any(isinstance(i, list) for i in controlnet_conditioning_scale): - raise ValueError("A single batch of multiple conditionings are supported at the moment.") - elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len( - self.controlnet.nets - ): - raise ValueError( - "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have" - " the same length as the number of controlnets" - ) - else: - assert False - - def check_image(self, image, prompt, prompt_embeds): - image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - - if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: - raise TypeError( - "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" - ) - - if image_is_pil: - image_batch_size = 1 - elif image_is_tensor: - image_batch_size = image.shape[0] - elif image_is_pil_list: - image_batch_size = len(image) - elif image_is_tensor_list: - image_batch_size = len(image) - - if prompt is not None and isinstance(prompt, str): - prompt_batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - prompt_batch_size = len(prompt) - elif prompt_embeds is not None: - prompt_batch_size = prompt_embeds.shape[0] - - if image_batch_size != 1 and image_batch_size != prompt_batch_size: - raise ValueError( - f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" - ) - - def prepare_image( - self, - image, - width, - height, - batch_size, - num_images_per_prompt, - device, - dtype, - do_classifier_free_guidance=False, - guess_mode=False, - ): - if not isinstance(image, torch.Tensor): - if isinstance(image, PIL.Image.Image): - image = [image] - - if isinstance(image[0], PIL.Image.Image): - images = [] - - for image_ in image: - image_ = image_.convert("RGB") - image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) - image_ = np.array(image_) - image_ = image_[None, :] - images.append(image_) - - image = images - - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - - image_batch_size = image.shape[0] - - if image_batch_size == 1: - repeat_by = batch_size - else: - # image batch size is the same as prompt batch size - repeat_by = num_images_per_prompt - - image = image.repeat_interleave(repeat_by, dim=0) - - image = image.to(device=device, dtype=dtype) - - if do_classifier_free_guidance and not guess_mode: - image = torch.cat([image] * 2) - - return image - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device) - - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents - - def _default_height_width(self, height, width, image): - # NOTE: It is possible that a list of images have different - # dimensions for each image, so just checking the first image - # is not _exactly_ correct, but it is simple. - while isinstance(image, list): - image = image[0] - - if height is None: - if isinstance(image, PIL.Image.Image): - height = image.height - elif isinstance(image, torch.Tensor): - height = image.shape[2] - - height = (height // 8) * 8 # round down to nearest multiple of 8 - - if width is None: - if isinstance(image, PIL.Image.Image): - width = image.width - elif isinstance(image, torch.Tensor): - width = image.shape[3] - - width = (width // 8) * 8 # round down to nearest multiple of 8 - - return height, width - - # override DiffusionPipeline - def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - safe_serialization: bool = False, - variant: Optional[str] = None, - ): - if isinstance(self.controlnet, ControlNetModel): - super().save_pretrained(save_directory, safe_serialization, variant) - else: - raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") - - @torch.no_grad() - @replace_example_docstring(EXAMPLE_DOC_STRING) - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 1.0, - guess_mode: bool = False, - ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. - instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, - `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): - The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If - the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can - also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If - height and/or width are passed, `image` is resized according to them. If multiple ControlNets are - specified in init, images must be passed as a list such that each element of the list can be correctly - batched for input to a single controlnet. - height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) - to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a - plain tuple. - callback (`Callable`, *optional*): - A function that will be called every `callback_steps` steps during inference. The function will be - called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. - callback_steps (`int`, *optional*, defaults to 1): - The frequency at which the `callback` function will be called. If not specified, the callback will be - called at every step. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): - The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added - to the residual in the original unet. If multiple ControlNets are specified in init, you can set the - corresponding scale as a list. - guess_mode (`bool`, *optional*, defaults to `False`): - In this mode, the ControlNet encoder will try best to recognize the content of the input image even if - you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. - - Examples: - - Returns: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. - When returning a tuple, the first element is a list with the generated images, and the second element is a - list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content, according to the `safety_checker`. - """ - # 0. Default height and width to unet - height, width = self._default_height_width(height, width, image) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - image, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - controlnet_conditioning_scale, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) - - # 3. Encode input prompt - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - # 4. Prepare image - if isinstance(self.controlnet, ControlNetModel): - image = self.prepare_image( - image=image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - elif isinstance(self.controlnet, MultiControlNetModel): - images = [] - - for image_ in image: - image_ = self.prepare_image( - image=image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - - images.append(image_) - - image = images - else: - assert False - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - # 6. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. - controlnet_latent_model_input = latents - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - else: - controlnet_latent_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - - down_block_res_samples, mid_block_res_sample = self.controlnet( - controlnet_latent_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, - conditioning_scale=controlnet_conditioning_scale, - guess_mode=guess_mode, - return_dict=False, - ) - - if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ).sample - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if output_type == "latent": - image = latents - has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) - else: - # 8. Post-processing - image = self.decode_latents(latents) - - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index c4f9ae59a4e9..002014681040 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -14,7 +14,8 @@ import contextlib import inspect -from typing import Callable, List, Optional, Union +import warnings +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -23,6 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -35,6 +37,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -128,6 +135,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -175,6 +183,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -199,7 +208,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -256,7 +272,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -314,19 +330,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -411,21 +436,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - init_latents = self.vae.config.scaling_factor * init_latents + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -462,6 +492,8 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui if isinstance(image[0], PIL.Image.Image): width, height = image[0].size + elif isinstance(image[0], np.ndarray): + width, height = image[0].shape[:-1] else: height, width = image[0].shape[-2:] @@ -500,7 +532,14 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, depth_map: Optional[torch.FloatTensor] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, @@ -515,6 +554,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -523,9 +563,12 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can accept image latents as `image` only if `depth_map` is not `None`. + depth_map (`torch.FloatTensor`, *optional*): + depth prediction that will be used as additional conditioning for the image generation process. If not + defined, it will automatically predicts the depth via `self.depth_estimator`. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of @@ -572,6 +615,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: @@ -631,6 +678,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -639,6 +689,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Prepare depth mask @@ -652,7 +703,7 @@ def __call__( ) # 5. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) # 6. Set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -677,7 +728,13 @@ def __call__( latent_model_input = torch.cat([latent_model_input, depth_mask], dim=1) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -685,7 +742,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -693,12 +750,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py new file mode 100644 index 000000000000..837811baae64 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py @@ -0,0 +1,1570 @@ +# Copyright 2023 DiffEdit Authors and Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL +import torch +from packaging import version +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, UNet2DConditionModel +from ...schedulers import DDIMInverseScheduler, KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + BaseOutput, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class DiffEditInversionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + latents (`torch.FloatTensor`) + inverted latents tensor + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `num_timesteps * batch_size` or numpy array of shape `(num_timesteps, + batch_size, height, width, num_channels)`. PIL images or numpy array present the denoised images of the + diffusion pipeline. + """ + + latents: torch.FloatTensor + images: Union[List[PIL.Image.Image], np.ndarray] + + +EXAMPLE_DOC_STRING = """ + + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> mask_prompt = "A bowl of fruits" + >>> prompt = "A bowl of pears" + + >>> mask_image = pipe.generate_mask(image=init_image, source_prompt=prompt, target_prompt=mask_prompt) + >>> image_latents = pipe.invert(image=init_image, prompt=mask_prompt).latents + >>> image = pipe(prompt=prompt, mask_image=mask_image, image_latents=image_latents).images[0] + ``` +""" + +EXAMPLE_INVERT_DOC_STRING = """ + ```py + >>> import PIL + >>> import requests + >>> import torch + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionDiffEditPipeline + + + >>> def download_image(url): + ... response = requests.get(url) + ... return PIL.Image.open(BytesIO(response.content)).convert("RGB") + + + >>> img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png" + + >>> init_image = download_image(img_url).resize((768, 768)) + + >>> pipe = StableDiffusionDiffEditPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.enable_model_cpu_offload() + + >>> prompt = "A bowl of fruits" + + >>> inverted_latents = pipe.invert(image=init_image, prompt=prompt).latents + ``` +""" + + +def auto_corr_loss(hidden_states, generator=None): + reg_loss = 0.0 + for i in range(hidden_states.shape[0]): + for j in range(hidden_states.shape[1]): + noise = hidden_states[i : i + 1, j : j + 1, :, :] + while True: + roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 + reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 + + if noise.shape[2] <= 8: + break + noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2) + return reg_loss + + +def kl_divergence(hidden_states): + return hidden_states.var() + hidden_states.mean() ** 2 - 1 - torch.log(hidden_states.var() + 1e-7) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +def preprocess_mask(mask, batch_size: int = 1): + if not isinstance(mask, torch.Tensor): + # preprocess mask + if isinstance(mask, PIL.Image.Image) or isinstance(mask, np.ndarray): + mask = [mask] + + if isinstance(mask, list): + if isinstance(mask[0], PIL.Image.Image): + mask = [np.array(m.convert("L")).astype(np.float32) / 255.0 for m in mask] + if isinstance(mask[0], np.ndarray): + mask = np.stack(mask, axis=0) if mask[0].ndim < 3 else np.concatenate(mask, axis=0) + mask = torch.from_numpy(mask) + elif isinstance(mask[0], torch.Tensor): + mask = torch.stack(mask, dim=0) if mask[0].ndim < 3 else torch.cat(mask, dim=0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + # Check mask shape + if batch_size > 1: + if mask.shape[0] == 1: + mask = torch.cat([mask] * batch_size) + elif mask.shape[0] > 1 and mask.shape[0] != batch_size: + raise ValueError( + f"`mask_image` with batch size {mask.shape[0]} cannot be broadcasted to batch size {batch_size} " + f"inferred by prompt inputs" + ) + + if mask.shape[1] != 1: + raise ValueError(f"`mask_image` must have 1 channel, but has {mask.shape[1]} channels") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("`mask_image` should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + return mask + + +class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion using DiffEdit. *This is an experimental feature*. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + as well as the following saving methods: + - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. + inverse_scheduler (`[DDIMInverseScheduler]`): + A scheduler to be used in combination with `unet` to fill in the unmasked part of the input latents + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPImageProcessor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor", "inverse_scheduler"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + inverse_scheduler: DDIMInverseScheduler, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "skip_prk_steps") and scheduler.config.skip_prk_steps is False: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration" + " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make" + " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face" + " Hub, it would be very nice if you could open a Pull request for the" + " `scheduler/scheduler_config.json` file" + ) + deprecate("skip_prk_steps not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["skip_prk_steps"] = True + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + inverse_scheduler=inverse_scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def check_inputs( + self, + prompt, + strength, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (strength is None) or (strength is not None and (strength < 0 or strength > 1)): + raise ValueError( + f"The value of `strength` should in [0.0, 1.0] but is, but is {strength} of type {type(strength)}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def check_source_inputs( + self, + source_prompt=None, + source_negative_prompt=None, + source_prompt_embeds=None, + source_negative_prompt_embeds=None, + ): + if source_prompt is not None and source_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_prompt`: {source_prompt} and `source_prompt_embeds`: {source_prompt_embeds}." + " Please make sure to only forward one of the two." + ) + elif source_prompt is None and source_prompt_embeds is None: + raise ValueError( + "Provide either `source_image` or `source_prompt_embeds`. Cannot leave all both of the arguments undefined." + ) + elif source_prompt is not None and ( + not isinstance(source_prompt, str) and not isinstance(source_prompt, list) + ): + raise ValueError(f"`source_prompt` has to be of type `str` or `list` but is {type(source_prompt)}") + + if source_negative_prompt is not None and source_negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `source_negative_prompt`: {source_negative_prompt} and `source_negative_prompt_embeds`:" + f" {source_negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if source_prompt_embeds is not None and source_negative_prompt_embeds is not None: + if source_prompt_embeds.shape != source_negative_prompt_embeds.shape: + raise ValueError( + "`source_prompt_embeds` and `source_negative_prompt_embeds` must have the same shape when passed" + f" directly, but got: `source_prompt_embeds` {source_prompt_embeds.shape} !=" + f" `source_negative_prompt_embeds` {source_negative_prompt_embeds.shape}." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def get_inverse_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + + # safety for t_start overflow to prevent empty timsteps slice + if t_start == 0: + return self.inverse_scheduler.timesteps, num_inference_steps + timesteps = self.inverse_scheduler.timesteps[:-t_start] + + return timesteps, num_inference_steps - t_start + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero.StableDiffusionPix2PixZeroPipeline.prepare_image_latents + def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + if image.shape[1] == 4: + latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) + + latents = self.vae.config.scaling_factor * latents + + if batch_size != latents.shape[0]: + if batch_size % latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_latents_per_image = batch_size // latents.shape[0] + latents = torch.cat([latents] * additional_latents_per_image, dim=0) + else: + raise ValueError( + f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts." + ) + else: + latents = torch.cat([latents], dim=0) + + return latents + + def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): + pred_type = self.inverse_scheduler.config.prediction_type + alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] + + beta_prod_t = 1 - alpha_prod_t + + if pred_type == "epsilon": + return model_output + elif pred_type == "sample": + return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) + elif pred_type == "v_prediction": + return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" + ) + + @torch.no_grad() + def generate_mask( + self, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + target_prompt: Optional[Union[str, List[str]]] = None, + target_negative_prompt: Optional[Union[str, List[str]]] = None, + target_prompt_embeds: Optional[torch.FloatTensor] = None, + target_negative_prompt_embeds: Optional[torch.FloatTensor] = None, + source_prompt: Optional[Union[str, List[str]]] = None, + source_negative_prompt: Optional[Union[str, List[str]]] = None, + source_prompt_embeds: Optional[torch.FloatTensor] = None, + source_negative_prompt_embeds: Optional[torch.FloatTensor] = None, + num_maps_per_mask: Optional[int] = 10, + mask_encode_strength: Optional[float] = 0.5, + mask_thresholding_ratio: Optional[float] = 3.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "np", + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function used to generate a latent mask given a mask prompt, a target prompt, and an image. + + Args: + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be used for computing the mask. + target_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation. If not defined, one has to pass + `prompt_embeds`. instead. + target_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + target_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + target_negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + source_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If + not defined, one has to pass `source_prompt_embeds` or `source_image` instead. + source_negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the semantic mask generation away from using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). If + not defined, one has to pass `source_negative_prompt_embeds` or `source_image` instead. + source_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings to guide the semantic mask generation. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from + `source_prompt` input argument. + source_negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings to negatively guide the semantic mask generation. Can be used to easily + tweak text inputs, *e.g.* prompt weighting. If not provided, text embeddings will be generated from + `source_negative_prompt` input argument. + num_maps_per_mask (`int`, *optional*, defaults to 10): + The number of noise maps sampled to generate the semantic mask using the method in [DiffEdit: + Diffusion-Based Semantic Image Editing with Mask Guidance](https://arxiv.org/pdf/2210.11427.pdf). + mask_encode_strength (`float`, *optional*, defaults to 0.5): + Conceptually, the strength of the noise maps sampled to generate the semantic mask using the method in + [DiffEdit: Diffusion-Based Semantic Image Editing with Mask Guidance]( + https://arxiv.org/pdf/2210.11427.pdf). Must be between 0 and 1. + mask_thresholding_ratio (`float`, *optional*, defaults to 3.0): + The maximum multiple of the mean absolute difference used to clamp the semantic guidance map before + mask binarization. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + `List[PIL.Image.Image]` or `np.array`: `List[PIL.Image.Image]` if `output_type` is `"pil"`, otherwise a + `np.array`. When returning a `List[PIL.Image.Image]`, the list will consist of a batch of single-channel + binary image with dimensions `(height // self.vae_scale_factor, width // self.vae_scale_factor)`, otherwise + the `np.array` will have shape `(batch_size, height // self.vae_scale_factor, width // + self.vae_scale_factor)`. + """ + + # 1. Check inputs (Provide dummy argument for callback_steps) + self.check_inputs( + target_prompt, + mask_encode_strength, + 1, + target_negative_prompt, + target_prompt_embeds, + target_negative_prompt_embeds, + ) + + self.check_source_inputs( + source_prompt, + source_negative_prompt, + source_prompt_embeds, + source_negative_prompt_embeds, + ) + + if (num_maps_per_mask is None) or ( + num_maps_per_mask is not None and (not isinstance(num_maps_per_mask, int) or num_maps_per_mask <= 0) + ): + raise ValueError( + f"`num_maps_per_mask` has to be a positive integer but is {num_maps_per_mask} of type" + f" {type(num_maps_per_mask)}." + ) + + if mask_thresholding_ratio is None or mask_thresholding_ratio <= 0: + raise ValueError( + f"`mask_thresholding_ratio` has to be positive but is {mask_thresholding_ratio} of type" + f" {type(mask_thresholding_ratio)}." + ) + + # 2. Define call parameters + if target_prompt is not None and isinstance(target_prompt, str): + batch_size = 1 + elif target_prompt is not None and isinstance(target_prompt, list): + batch_size = len(target_prompt) + else: + batch_size = target_prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompts + (cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None) + target_prompt_embeds = self._encode_prompt( + target_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + target_negative_prompt, + prompt_embeds=target_prompt_embeds, + negative_prompt_embeds=target_negative_prompt_embeds, + ) + + source_prompt_embeds = self._encode_prompt( + source_prompt, + device, + num_maps_per_mask, + do_classifier_free_guidance, + source_negative_prompt, + prompt_embeds=source_prompt_embeds, + negative_prompt_embeds=source_negative_prompt_embeds, + ) + + # 4. Preprocess image + image = preprocess(image).repeat_interleave(num_maps_per_mask, dim=0) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, _ = self.get_timesteps(num_inference_steps, mask_encode_strength, device) + encode_timestep = timesteps[0] + + # 6. Prepare image latents and add noise with specified strength + image_latents = self.prepare_image_latents( + image, batch_size * num_maps_per_mask, self.vae.dtype, device, generator + ) + noise = randn_tensor(image_latents.shape, generator=generator, device=device, dtype=self.vae.dtype) + image_latents = self.scheduler.add_noise(image_latents, noise, encode_timestep) + + latent_model_input = torch.cat([image_latents] * (4 if do_classifier_free_guidance else 2)) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, encode_timestep) + + # 7. Predict the noise residual + prompt_embeds = torch.cat([source_prompt_embeds, target_prompt_embeds]) + noise_pred = self.unet( + latent_model_input, + encode_timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + if do_classifier_free_guidance: + noise_pred_neg_src, noise_pred_source, noise_pred_uncond, noise_pred_target = noise_pred.chunk(4) + noise_pred_source = noise_pred_neg_src + guidance_scale * (noise_pred_source - noise_pred_neg_src) + noise_pred_target = noise_pred_uncond + guidance_scale * (noise_pred_target - noise_pred_uncond) + else: + noise_pred_source, noise_pred_target = noise_pred.chunk(2) + + # 8. Compute the mask from the absolute difference of predicted noise residuals + # TODO: Consider smoothing mask guidance map + mask_guidance_map = ( + torch.abs(noise_pred_target - noise_pred_source) + .reshape(batch_size, num_maps_per_mask, *noise_pred_target.shape[-3:]) + .mean([1, 2]) + ) + clamp_magnitude = mask_guidance_map.mean() * mask_thresholding_ratio + semantic_mask_image = mask_guidance_map.clamp(0, clamp_magnitude) / clamp_magnitude + semantic_mask_image = torch.where(semantic_mask_image <= 0.5, 0, 1) + mask_image = semantic_mask_image.cpu().numpy() + + # 9. Convert to Numpy array or PIL. + if output_type == "pil": + mask_image = self.image_processor.numpy_to_pil(mask_image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return mask_image + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) + def invert( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + num_inference_steps: int = 50, + inpaint_strength: float = 0.8, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + decode_latents: bool = False, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + lambda_auto_corr: float = 20.0, + lambda_kl: float = 20.0, + num_reg_steps: int = 0, + num_auto_corr_rolls: int = 5, + ): + r""" + Function used to generate inverted latents given a prompt and image. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch to produce the inverted latents, guided by `prompt`. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how far into the noising process to run latent inversion. Must be between 0 and + 1. When `strength` is 1, the inversion process will be run for the full number of iterations specified + in `num_inference_steps`. `image` will be used as a reference for the inversion process, adding more + noise the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + decode_latents (`bool`, *optional*, defaults to `False`): + Whether or not to decode the inverted latents into a generated image. Setting this argument to `True` + will decode all inverted latents for each timestep into a list of generated images. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.DiffEditInversionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + lambda_auto_corr (`float`, *optional*, defaults to 20.0): + Lambda parameter to control auto correction + lambda_kl (`float`, *optional*, defaults to 20.0): + Lambda parameter to control Kullback–Leibler divergence output + num_reg_steps (`int`, *optional*, defaults to 0): + Number of regularization loss steps + num_auto_corr_rolls (`int`, *optional*, defaults to 5): + Number of auto correction roll steps + + Examples: + + Returns: + [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] or + `tuple`: [`~pipelines.stable_diffusion.pipeline_stable_diffusion_diffedit.DiffEditInversionPipelineOutput`] + if `return_dict` is `True`, otherwise a `tuple`. When returning a tuple, the first element is the inverted + latents tensors ordered by increasing noise, and then second is the corresponding decoded images if + `decode_latents` is `True`, otherwise `None`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Preprocess image + image = preprocess(image) + + # 4. Prepare latent variables + num_images_per_prompt = 1 + latents = self.prepare_image_latents( + image, batch_size * num_images_per_prompt, self.vae.dtype, device, generator + ) + + # 5. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 6. Prepare timesteps + self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_inverse_timesteps(num_inference_steps, inpaint_strength, device) + + # 7. Noising loop where we obtain the intermediate noised latent image for each timestep. + num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order + inverted_latents = [latents.detach().clone()] + with self.progress_bar(total=num_inference_steps - 1) as progress_bar: + for i, t in enumerate(timesteps[:-1]): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # regularization of the noise prediction (not in original code or paper but borrowed from Pix2PixZero) + if num_reg_steps > 0: + with torch.enable_grad(): + for _ in range(num_reg_steps): + if lambda_auto_corr > 0: + for _ in range(num_auto_corr_rolls): + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_ac = auto_corr_loss(var_epsilon, generator=generator) + l_ac.backward() + + grad = var.grad.detach() / num_auto_corr_rolls + noise_pred = noise_pred - lambda_auto_corr * grad + + if lambda_kl > 0: + var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) + + # Derive epsilon from model output before regularizing to IID standard normal + var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) + + l_kld = kl_divergence(var_epsilon) + l_kld.backward() + + grad = var.grad.detach() + noise_pred = noise_pred - lambda_kl * grad + + noise_pred = noise_pred.detach() + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample + inverted_latents.append(latents.detach().clone()) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + assert len(inverted_latents) == len(timesteps) + latents = torch.stack(list(reversed(inverted_latents)), 1) + + # 8. Post-processing + image = None + if decode_latents: + image = self.decode_latents(latents.flatten(0, 1).detach()) + + # 9. Convert to PIL. + if decode_latents and output_type == "pil": + image = self.image_processor.numpy_to_pil(image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (latents, image) + + return DiffEditInversionPipelineOutput(latents=latents, images=image) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image_latents: torch.FloatTensor = None, + inpaint_strength: Optional[float] = 0.8, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask the generated image. White pixels in the mask + will be repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be + converted to a single channel (luminance) before use. If it's a tensor, it should contain one color + channel (L) instead of 3, so the expected shape would be `(B, 1, H, W)`. + image_latents (`PIL.Image.Image` or `torch.FloatTensor`): + Partially noised image latents from the inversion process to be used as inputs for image generation. + inpaint_strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength` + is 1, the denoising process will be run on the masked area for the full number of iterations specified + in `num_inference_steps`. `image_latents` will be used as a reference for the masked area, adding more + noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + + # 1. Check inputs + self.check_inputs( + prompt, + inpaint_strength, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if mask_image is None: + raise ValueError( + "`mask_image` input cannot be undefined. Use `generate_mask()` to compute `mask_image` from text prompts." + ) + if image_latents is None: + raise ValueError( + "`image_latents` input cannot be undefined. Use `invert()` to compute `image_latents` from input images." + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if cross_attention_kwargs is None: + cross_attention_kwargs = {} + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess mask + mask_image = preprocess_mask(mask_image, batch_size) + latent_height, latent_width = mask_image.shape[-2:] + mask_image = torch.cat([mask_image] * num_images_per_prompt) + mask_image = mask_image.to(device=device, dtype=prompt_embeds.dtype) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, inpaint_strength, device) + + # 6. Preprocess image latents + image_latents = preprocess(image_latents) + latent_shape = (self.vae.config.latent_channels, latent_height, latent_width) + if image_latents.shape[-3:] != latent_shape: + raise ValueError( + f"Each latent image in `image_latents` must have shape {latent_shape}, " + f"but has shape {image_latents.shape[-3:]}" + ) + if image_latents.ndim == 4: + image_latents = image_latents.reshape(batch_size, len(timesteps), *latent_shape) + if image_latents.shape[:2] != (batch_size, len(timesteps)): + raise ValueError( + f"`image_latents` must have batch size {batch_size} with latent images from {len(timesteps)} timesteps, " + f"but has batch size {image_latents.shape[0]} with latent images from {image_latents.shape[1]} timesteps." + ) + image_latents = image_latents.transpose(0, 1).repeat_interleave(num_images_per_prompt, dim=1) + image_latents = image_latents.to(device=device, dtype=prompt_embeds.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + latents = image_latents[0].detach().clone() + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # mask with inverted latents from appropriate timestep - use original image latent for last step + latents = latents * mask_image + image_latents[i] * (1 - mask_image) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index d543593fdbf5..640fd7f2d94b 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import PIL @@ -21,6 +22,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor @@ -118,6 +120,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) def enable_sequential_cpu_offload(self, gpu_id=0): @@ -183,19 +186,28 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -398,15 +410,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + else: + image = latents + has_nsfw_concept = None - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index c26ddf06cadc..e9e91b646ed5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import numpy as np @@ -72,6 +73,11 @@ def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -205,6 +211,7 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -215,11 +222,8 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config( - requires_safety_checker=requires_safety_checker, - ) + self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload def enable_sequential_cpu_offload(self, gpu_id=0): @@ -305,6 +309,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -329,7 +334,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -386,7 +398,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -442,18 +454,33 @@ def _encode_prompt( return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(dtype) - ) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) return image, has_nsfw_concept + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -532,21 +559,26 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) - ] - init_latents = torch.cat(init_latents, dim=0) + if image.shape[1] == 4: + init_latents = image + else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) - init_latents = self.vae.config.scaling_factor * init_latents + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: # expand init_latents for batch_size @@ -580,7 +612,14 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, strength: float = 0.8, num_inference_steps: Optional[int] = 50, guidance_scale: Optional[float] = 7.5, @@ -603,9 +642,10 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor` or `PIL.Image.Image`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. strength (`float`, *optional*, defaults to 0.8): Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the `strength`. The number of @@ -682,6 +722,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -690,6 +733,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image @@ -722,7 +766,8 @@ def __call__( t, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -730,7 +775,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -738,27 +783,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type not in ["latent", "pt", "np", "pil"]: - deprecation_message = ( - f"the output_type {output_type} is outdated. Please make sure to set it to one of these instead: " - "`pil`, `np`, `pt`, `latent`" - ) - deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False) - output_type = "np" - - if output_type == "latent": + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: image = latents has_nsfw_concept = None + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] else: - image = self.decode_latents(latents) - - if self.safety_checker is not None: - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - has_nsfw_concept = False + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - image = self.image_processor.postprocess(image, output_type=output_type) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index fb2e5dc424e3..b07a5555f1c7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -13,7 +13,8 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +import warnings +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -22,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -34,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_mask_and_masked_image(image, mask): +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -62,6 +64,13 @@ def prepare_mask_and_masked_image(image, mask): tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 dimensions: ``batch x channels x height x width``. """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + if isinstance(image, torch.Tensor): if not isinstance(mask, torch.Tensor): raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") @@ -109,8 +118,9 @@ def prepare_mask_and_masked_image(image, mask): # preprocess image if isinstance(image, (PIL.Image.Image, np.ndarray)): image = [image] - if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] image = [np.array(i.convert("RGB"))[None, :] for i in image] image = np.concatenate(image, axis=0) elif isinstance(image, list) and isinstance(image[0], np.ndarray): @@ -124,6 +134,7 @@ def prepare_mask_and_masked_image(image, mask): mask = [mask] if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) mask = mask.astype(np.float32) / 255.0 elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): @@ -135,12 +146,16 @@ def prepare_mask_and_masked_image(image, mask): masked_image = image * (mask < 0.5) + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + return mask, masked_image class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" - Pipeline for text-guided image inpainting using Stable Diffusion. *This is an experimental feature*. + Pipeline for text-guided image inpainting using Stable Diffusion. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) @@ -152,6 +167,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMi as well as the following saving methods: - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + + + It is recommended to use this pipeline with checkpoints that have been specifically fine-tuned for inpainting, such + as [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting). Default + text-to-image stable diffusion checkpoints, such as + [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) are also compatible with + this pipeline, but might be less performant. + + + Args: vae ([`AutoencoderKL`]): Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. @@ -251,14 +276,10 @@ def __init__( new_config = dict(unet.config) new_config["sample_size"] = 64 unet._internal_dict = FrozenDict(new_config) + # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4 if unet.config.in_channels != 9: - logger.warning( - f"You have loaded a UNet with {unet.config.in_channels} input channels, whereas by default," - f" {self.__class__} assumes that `pipeline.unet` has 9 input channels: 4 for `num_channels_latents`," - " 1 for `num_channels_mask`, and 4 for `num_channels_masked_image`. If you did not intend to modify" - " this behavior, please check whether you have loaded the right checkpoint." - ) + logger.info(f"You have loaded a UNet with {unet.config.in_channels} input channels which.") self.register_modules( vae=vae, @@ -270,6 +291,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -356,6 +378,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -380,7 +403,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -437,7 +467,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -495,13 +525,17 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -524,24 +558,32 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, height, width, + strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -579,8 +621,22 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + return_noise=False, + return_image_latents=False, + ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -588,14 +644,49 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." + ) + + if return_image_latents or (latents is None and not is_strength_max): + image = image.to(device=device, dtype=dtype) + image_latents = self._encode_vae_image(image=image, generator=generator) + if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # if strength is 1. then initialise the latents to noise, else initial to image + noise + latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep) + # if pure noise then scale the initial latents by the Scheduler's init sigma + latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents + else: + noise = latents.to(device) + latents = noise * self.scheduler.init_noise_sigma + + outputs = (latents,) + + if return_noise: + outputs += (noise,) + + if return_image_latents: + outputs += (image_latents,) + + return outputs + + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) else: - latents = latents.to(device) + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents - # scale the initial noise by the standard deviation required by the scheduler - latents = latents * self.scheduler.init_noise_sigma - return latents + return image_latents def prepare_mask_latents( self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance @@ -609,17 +700,7 @@ def prepare_mask_latents( mask = mask.to(device=device, dtype=dtype) masked_image = masked_image.to(device=device, dtype=dtype) - - # encode the mask image into latents space so we can concatenate it to the latents - if isinstance(generator, list): - masked_image_latents = [ - self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - masked_image_latents = torch.cat(masked_image_latents, dim=0) - else: - masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator) - masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + masked_image_latents = self._encode_vae_image(masked_image, generator=generator) # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: @@ -648,6 +729,16 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + @torch.no_grad() def __call__( self, @@ -656,6 +747,7 @@ def __call__( mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, height: Optional[int] = None, width: Optional[int] = None, + strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -669,6 +761,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -689,6 +782,13 @@ def __call__( The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -733,7 +833,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. - + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: ```py @@ -781,18 +884,13 @@ def __call__( prompt, height, width, + strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, ) - if image is None: - raise ValueError("`image` input cannot be undefined.") - - if mask_image is None: - raise ValueError("`mask_image` input cannot be undefined.") - # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -808,6 +906,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -816,18 +917,36 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) - # 4. Preprocess mask and image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) - - # 5. set timesteps + # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # check that number of inference steps is not < 1 - as this doesn't make sense + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 5. Preprocess mask and image + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels - latents = self.prepare_latents( + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + + latents_outputs = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, height, @@ -836,8 +955,18 @@ def __call__( device, generator, latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, ) + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + # 7. Prepare mask latent variables mask, masked_image_latents = self.prepare_mask_latents( mask, @@ -850,17 +979,25 @@ def __call__( generator, do_classifier_free_guidance, ) + init_image = init_image.to(device=device, dtype=masked_image_latents.dtype) + init_image = self._encode_vae_image(init_image, generator=generator) # 8. Check that sizes of mask, masked image and latents match - num_channels_mask = mask.shape[1] - num_channels_masked_image = masked_image_latents.shape[1] - if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + if num_channels_unet == 9: + # default case for runwayml/stable-diffusion-inpainting + num_channels_mask = mask.shape[1] + num_channels_masked_image = masked_image_latents.shape[1] + if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" + f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" + " `pipeline.unet` or your `mask_image` or `image` input." + ) + elif num_channels_unet != 4: raise ValueError( - f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" - f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" - f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}" - f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of" - " `pipeline.unet` or your `mask_image` or `image` input." + f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}." ) # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline @@ -875,10 +1012,18 @@ def __call__( # concat latents, mask, masked_image_latents in the channel dimension latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -886,7 +1031,16 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + + if i < len(timesteps) - 1: + init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t])) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -894,15 +1048,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 11. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 12. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 13. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 3ad1d5e92273..147d914fe6c1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -13,7 +13,8 @@ # limitations under the License. import inspect -from typing import Callable, List, Optional, Union +import warnings +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL @@ -22,6 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -121,7 +123,6 @@ class StableDiffusionInpaintPipelineLegacy( """ _optional_components = ["feature_extractor"] - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ def __init__( self, vae: AutoencoderKL, @@ -135,6 +136,13 @@ def __init__( ): super().__init__() + deprecation_message = ( + f"The class {self.__class__} is deprecated and will be removed in v1.0.0. You can achieve exactly the same functionality" + "by loading your model into `StableDiffusionInpaintPipeline` instead. See https://github.com/huggingface/diffusers/pull/3533" + "for more information." + ) + deprecate("legacy is outdated", "1.0.0", deprecation_message, standard_warn=False) + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: deprecation_message = ( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" @@ -209,6 +217,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -295,6 +304,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -319,7 +329,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -376,7 +393,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -434,19 +451,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -557,6 +583,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -621,6 +648,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: @@ -647,6 +678,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -655,6 +689,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image and mask @@ -690,7 +725,13 @@ def __call__( latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -698,7 +739,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # masking if add_predicted_noise: init_latents_proper = self.scheduler.add_noise( @@ -718,15 +759,19 @@ def __call__( # use original latents corresponding to unmasked portions of the image latents = (init_latents_orig * mask) + (latents * (1 - mask)) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 11. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 12. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 49944cdcd636..25102ae7cf4a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -20,6 +21,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers @@ -41,6 +43,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -136,13 +143,21 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) @torch.no_grad() def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 100, guidance_scale: float = 7.5, image_guidance_scale: float = 1.5, @@ -165,8 +180,9 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image`): - `Image`, or tensor representing an image batch which will be repainted according to `prompt`. + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be repainted according to `prompt`. Can also + accpet image latents as `image`, if passing latents directly, it will not be encoded again. num_inference_steps (`int`, *optional*, defaults to 100): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -287,8 +303,7 @@ def __call__( ) # 3. Preprocess image - image = preprocess(image) - height, width = image.shape[-2:] + image = self.image_processor.preprocess(image) # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -305,6 +320,10 @@ def __call__( generator, ) + height, width = image_latents.shape[-2:] + height = height * self.vae_scale_factor + width = width * self.vae_scale_factor + # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( @@ -346,7 +365,9 @@ def __call__( scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1) # predict the noise residual - noise_pred = self.unet(scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + noise_pred = self.unet( + scaled_latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False + )[0] # Hack: # For karras style schedulers the model does classifer free guidance using the @@ -376,7 +397,7 @@ def __call__( noise_pred = (noise_pred - latents) / (-sigma) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -384,15 +405,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 11. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 12. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -626,13 +651,17 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs @@ -655,8 +684,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -728,17 +762,21 @@ def prepare_image_latents( image = image.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - if isinstance(generator, list): - image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] - image_latents = torch.cat(image_latents, dim=0) + if image.shape[1] == 4: + image_latents = image else: - image_latents = self.vae.encode(image).latent_dist.mode() + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [self.vae.encode(image[i : i + 1]).latent_dist.mode() for i in range(batch_size)] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.mode() if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand image_latents for batch_size diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py index 99aca66db809..ab613dd4dfe4 100755 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py @@ -13,13 +13,15 @@ # limitations under the License. import importlib +import warnings from typing import Callable, List, Optional, Union import torch from k_diffusion.external import CompVisDenoiser, CompVisVDenoiser from k_diffusion.sampling import get_sigmas_karras -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...pipelines import DiffusionPipeline from ...schedulers import LMSDiscreteScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor @@ -111,6 +113,7 @@ def __init__( ) self.register_to_config(requires_safety_checker=requires_safety_checker) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) model = ModelWrapper(unet, scheduler.alphas_cumprod) if scheduler.config.prediction_type == "v_prediction": @@ -207,6 +210,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -231,7 +235,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -288,7 +299,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -346,19 +357,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -590,15 +610,19 @@ def model_fn(x, t): # 8. Run k-diffusion solver latents = self.sampler(model_fn, latents, sigmas) - # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 10. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py index 822bd49ce31c..d67a7f894886 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -20,6 +21,7 @@ import torch.nn.functional as F from transformers import CLIPTextModel, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import is_accelerate_available, logging, randn_tensor @@ -31,6 +33,11 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -91,6 +98,8 @@ def __init__( unet=unet, scheduler=scheduler, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -220,8 +229,13 @@ def _encode_prompt(self, prompt, device, do_classifier_free_guidance, negative_p # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -282,7 +296,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def __call__( self, prompt: Union[str, List[str]], - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]], + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 75, guidance_scale: float = 9.0, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -299,7 +320,7 @@ def __call__( Args: prompt (`str` or `List[str]`): The prompt or prompts to guide the image upscaling. - image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be upscaled. If it's a tensor, it can be either a latent output from a stable diffusion model, or an image tensor in the range `[-1, 1]`. It will be considered a `latent` if `image.shape[1]` is `4`; otherwise, it will be considered to be an @@ -404,7 +425,7 @@ def __call__( ) # 4. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) image = image.to(dtype=text_embeddings.dtype, device=device) if image.shape[1] == 3: # encode image if not in latent-space yet @@ -505,12 +526,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 10. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents - # 11. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py index b7ded03d529b..1d30b9ee0347 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py @@ -13,12 +13,14 @@ import copy import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import PNDMScheduler from ...schedulers.scheduling_utils import SchedulerMixin @@ -53,7 +55,7 @@ """ -class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image model editing using "Editing Implicit Assumptions in Text-to-Image Diffusion Models". @@ -129,6 +131,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) self.with_to_k = with_to_k @@ -234,6 +237,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -258,7 +262,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -315,7 +326,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -373,19 +384,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -707,6 +727,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -715,6 +738,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps @@ -767,24 +791,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - if output_type == "latent": + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: image = latents has_nsfw_concept = None - elif output_type == "pil": - # 8. Post-processing - image = self.decode_latents(latents) - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - - # 10. Convert to PIL - image = self.numpy_to_pil(image) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] else: - # 8. Post-processing - image = self.decode_latents(latents) + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py index 392b2a72a76f..3826447576d4 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py @@ -11,15 +11,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, PNDMScheduler +from ...schedulers import DDIMScheduler from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -48,7 +51,7 @@ """ -class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-image generation using "MultiDiffusion: Fusing Diffusion Paths for Controlled Image Generation". @@ -94,9 +97,6 @@ def __init__( ): super().__init__() - if isinstance(scheduler, PNDMScheduler): - logger.error("PNDMScheduler for this pipeline is currently not supported.") - if safety_checker is None and requires_safety_checker: logger.warning( f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" @@ -123,6 +123,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -198,6 +199,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -222,7 +224,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -279,7 +288,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -337,19 +346,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -441,10 +459,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def get_views(self, panorama_height, panorama_width, window_size=64, stride=8): # Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113) + # if panorama's height/width < window_size, num_blocks of height/width should return 1 panorama_height /= 8 panorama_width /= 8 - num_blocks_height = (panorama_height - window_size) // stride + 1 - num_blocks_width = (panorama_width - window_size) // stride + 1 + num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1 + num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_height > window_size else 1 total_num_blocks = int(num_blocks_height * num_blocks_width) views = [] for i in range(total_num_blocks): @@ -464,6 +483,7 @@ def __call__( width: Optional[int] = 2048, num_inference_steps: int = 50, guidance_scale: float = 7.5, + view_batch_size: int = 1, negative_prompt: Optional[Union[str, List[str]]] = None, num_images_per_prompt: Optional[int] = 1, eta: float = 0.0, @@ -498,6 +518,9 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. + view_batch_size (`int`, *optional*, defaults to 1): + The batch size to denoise splited views. For some GPUs with high performance, higher view batch size + can speedup the generation and increase the VRAM usage. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is @@ -571,6 +594,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -579,6 +605,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps @@ -599,7 +626,11 @@ def __call__( ) # 6. Define panorama grid and initialize views for synthesis. + # prepare batch grid views = self.get_views(height, width) + views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)] + views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch) + count = torch.zeros_like(latents) value = torch.zeros_like(latents) @@ -620,35 +651,55 @@ def __call__( # denoised (latent) crops are then averaged to produce the final latent # for the current timestep via MultiDiffusion. Please see Sec. 4.1 in the # MultiDiffusion paper for more details: https://arxiv.org/abs/2302.08113 - for h_start, h_end, w_start, w_end in views: + # Batch views denoise + for j, batch_view in enumerate(views_batch): + vb_size = len(batch_view) # get the latents corresponding to the current view coordinates - latents_for_view = latents[:, :, h_start:h_end, w_start:w_end] + latents_for_view = torch.cat( + [latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view] + ) + + # rematch block's scheduler status + self.scheduler.__dict__.update(views_scheduler_status[j]) # expand the latents if we are doing classifier free guidance latent_model_input = ( - torch.cat([latents_for_view] * 2) if do_classifier_free_guidance else latents_for_view + latents_for_view.repeat_interleave(2, dim=0) + if do_classifier_free_guidance + else latents_for_view ) latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # repeat prompt_embeds for batch + prompt_embeds_input = torch.cat([prompt_embeds] * vb_size) + # predict the noise residual noise_pred = self.unet( latent_model_input, t, - encoder_hidden_states=prompt_embeds, + encoder_hidden_states=prompt_embeds_input, cross_attention_kwargs=cross_attention_kwargs, ).sample # perform guidance if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, noise_pred_text = noise_pred[::2], noise_pred[1::2] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents_view_denoised = self.scheduler.step( + latents_denoised_batch = self.scheduler.step( noise_pred, t, latents_for_view, **extra_step_kwargs ).prev_sample - value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised - count[:, :, h_start:h_end, w_start:w_end] += 1 + + # save views scheduler status after sample + views_scheduler_status[j] = copy.deepcopy(self.scheduler.__dict__) + + # extract value from batch + for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip( + latents_denoised_batch.chunk(vb_size), batch_view + ): + value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised + count[:, :, h_start:h_end, w_start:w_end] += 1 # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113 latents = torch.where(count > 0, value / count, value) @@ -659,15 +710,19 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py index 6444ec7c8506..75ac4f777756 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union @@ -28,7 +29,8 @@ CLIPTokenizer, ) -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.attention_processor import Attention from ...schedulers import DDIMScheduler, DDPMScheduler, EulerAncestralDiscreteScheduler, LMSDiscreteScheduler @@ -175,6 +177,11 @@ class Pix2PixInversionPipelineOutput(BaseOutput, TextualInversionLoaderMixin): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -358,6 +365,7 @@ def __init__( inverse_scheduler=inverse_scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload @@ -439,6 +447,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -463,7 +472,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -520,7 +536,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -578,19 +594,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -617,7 +642,6 @@ def prepare_extra_step_kwargs(self, generator, eta): def check_inputs( self, prompt, - image, source_embeds, target_embeds, callback_steps, @@ -715,19 +739,25 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None image = image.to(device=device, dtype=dtype) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) + if image.shape[1] == 4: + latents = image - if isinstance(generator, list): - latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)] - latents = torch.cat(latents, dim=0) else: - latents = self.vae.encode(image).latent_dist.sample(generator) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0) + else: + latents = self.vae.encode(image).latent_dist.sample(generator) - latents = self.vae.config.scaling_factor * latents + latents = self.vae.config.scaling_factor * latents if batch_size != latents.shape[0]: if batch_size % latents.shape[0] == 0: @@ -792,7 +822,6 @@ def kl_divergence(self, hidden_states): def __call__( self, prompt: Optional[Union[str, List[str]]] = None, - image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, source_embeds: torch.Tensor = None, target_embeds: torch.Tensor = None, height: Optional[int] = None, @@ -893,7 +922,6 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, - image, source_embeds, target_embeds, callback_steps, @@ -1045,31 +1073,42 @@ def __call__( if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - # 11. Post-process the latents. - edited_image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 12. Run the safety checker. - edited_image, has_nsfw_concept = self.run_safety_checker(edited_image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 13. Convert to PIL. - if output_type == "pil": - edited_image = self.numpy_to_pil(edited_image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() if not return_dict: - return (edited_image, has_nsfw_concept) + return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=edited_image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) @torch.no_grad() @replace_example_docstring(EXAMPLE_INVERT_DOC_STRING) def invert( self, prompt: Optional[str] = None, - image: Union[torch.FloatTensor, PIL.Image.Image] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 50, guidance_scale: float = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, @@ -1093,8 +1132,9 @@ def invert( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image`, *optional*): - `Image`, or tensor representing an image batch which will be used for conditioning. + image (`torch.FloatTensor` `np.ndarray`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch which will be used for conditioning. Can also accpet + image latents as `image`, if passing latents directly, it will not be encoded again. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -1163,7 +1203,7 @@ def invert( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) # 4. Prepare latent variables latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) @@ -1251,16 +1291,13 @@ def invert( inverted_latents = latents.detach().clone() # 8. Post-processing - image = self.decode_latents(latents.detach()) + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 9. Convert to PIL. - if output_type == "pil": - image = self.numpy_to_pil(image) - if not return_dict: return (inverted_latents, image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py index ebac58e18f62..ba1c0d2b9d49 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py @@ -13,13 +13,15 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring @@ -140,6 +142,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -215,6 +218,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -239,7 +243,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -296,7 +307,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -354,19 +365,28 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is not None: - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) - else: - has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -619,7 +639,7 @@ def __call__( def get_map_size(module, input, output): nonlocal map_size - map_size = output.sample.shape[-2:] + map_size = output[0].shape[-2:] with self.unet.mid_block.attentions[0].register_forward_hook(get_map_size): with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -682,15 +702,19 @@ def get_map_size(module, input, output): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None - # 9. Run safety checker - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) if not return_dict: return (image, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 87014f52dfc2..0fda05ea5ec2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -13,16 +13,18 @@ # limitations under the License. import inspect -from typing import Any, Callable, List, Optional, Union +import warnings +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import PIL import torch -import torch.nn.functional as F from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel +from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline @@ -33,6 +35,11 @@ def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) if isinstance(image, torch.Tensor): return image elif isinstance(image, PIL.Image.Image): @@ -53,7 +60,7 @@ def preprocess(image): return image -class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided image super-resolution using Stable Diffusion 2. @@ -124,6 +131,8 @@ def __init__( watermarker=watermarker, feature_extractor=feature_extractor, ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, resample="bicubic") self.register_to_config(max_noise_level=max_noise_level) def enable_sequential_cpu_offload(self, gpu_id=0): @@ -215,6 +224,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -239,7 +249,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -296,7 +313,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -372,8 +389,13 @@ def prepare_extra_step_kwargs(self, generator, eta): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -426,14 +448,15 @@ def check_inputs( if ( not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image) + and not isinstance(image, np.ndarray) and not isinstance(image, list) ): raise ValueError( - f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + f"`image` has to be of type `torch.Tensor`, `np.ndarray`, `PIL.Image.Image` or `list` but is {type(image)}" ) - # verify batch size of prompt and image are same if image is a list or tensor - if isinstance(image, list) or isinstance(image, torch.Tensor): + # verify batch size of prompt and image are same if image is a list or tensor or numpy array + if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): if isinstance(prompt, str): batch_size = 1 else: @@ -477,7 +500,14 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, num_inference_steps: int = 75, guidance_scale: float = 9.0, noise_level: int = 20, @@ -492,6 +522,7 @@ def __call__( return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, ): r""" Function invoked when calling the pipeline for generation. @@ -500,7 +531,7 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch which will be upscaled. * num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the @@ -546,6 +577,10 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). Examples: ```py @@ -610,6 +645,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -618,10 +656,11 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Preprocess image - image = preprocess(image) + image = self.image_processor.preprocess(image) image = image.to(dtype=prompt_embeds.dtype, device=device) # 5. set timesteps @@ -678,8 +717,13 @@ def __call__( # predict the noise residual noise_pred = self.unet( - latent_model_input, t, encoder_hidden_states=prompt_embeds, class_labels=noise_level - ).sample + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=noise_level, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -687,7 +731,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -699,37 +743,39 @@ def __call__( # make sure the VAE is in float32 mode, as it overflows in float16 self.vae.to(dtype=torch.float32) - # TODO(Patrick, William) - clean up when attention is refactored - use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") - use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers + use_torch_2_0_or_xformers = self.vae.decoder.mid_block.attentions[0].processor in [ + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + ] # if xformers or torch_2_0 is used attention block does not need # to be in float32 which can save lots of memory - if not use_torch_2_0_attn and not use_xformers: + if not use_torch_2_0_or_xformers: self.vae.post_quant_conv.to(latents.dtype) self.vae.decoder.conv_in.to(latents.dtype) self.vae.decoder.mid_block.to(latents.dtype) else: latents = latents.float() - # 11. Convert to PIL - if output_type == "pil": - image = self.decode_latents(latents) - + # post-processing + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] image, has_nsfw_concept, _ = self.run_safety_checker(image, device, prompt_embeds.dtype) - - image = self.numpy_to_pil(image) - - # 11. Apply watermark - if self.watermarker is not None: - image = self.watermarker.apply_watermark(image) - elif output_type == "pt": - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample - has_nsfw_concept = None else: - image = self.decode_latents(latents) + image = latents has_nsfw_concept = None + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # 11. Apply watermark + if output_type == "pil" and self.watermarker is not None: + image = self.watermarker.apply_watermark(image) + # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py index fafb8d1d2800..e36ebfbb70f1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py @@ -13,13 +13,15 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers.models.clip.modeling_clip import CLIPTextModelOutput -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, PriorTransformer, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers @@ -48,7 +50,7 @@ """ -class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): """ Pipeline for text-to-image generation using stable unCLIP. @@ -136,6 +138,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -335,6 +338,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -359,7 +363,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -416,7 +427,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -474,8 +485,13 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -830,7 +846,8 @@ def __call__( timestep=t, sample=prior_latents, **prior_extra_step_kwargs, - ).prev_sample + return_dict=False, + )[0] if callback is not None and i % callback_steps == 0: callback(i, t, prior_latents) @@ -847,6 +864,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 8. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt=prompt, device=device, @@ -855,6 +875,7 @@ def __call__( negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 9. Prepare image embeddings @@ -903,7 +924,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=image_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -911,22 +933,22 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 14. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 15. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py index 22b7280f3679..0187c86b4239 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Any, Callable, Dict, List, Optional, Union import PIL @@ -21,7 +22,8 @@ from diffusers.utils.import_utils import is_accelerate_available -from ...loaders import TextualInversionLoaderMixin +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet2DConditionModel from ...models.embeddings import get_timestep_embedding from ...schedulers import KarrasDiffusionSchedulers @@ -61,7 +63,7 @@ """ -class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): """ Pipeline for text-guided image to image generation using stable unCLIP. @@ -138,6 +140,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -235,6 +238,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -259,7 +263,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -316,7 +327,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -429,8 +440,13 @@ def _encode_image( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -744,6 +760,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt=prompt, device=device, @@ -752,6 +771,7 @@ def __call__( negative_prompt=negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Encoder input image @@ -799,7 +819,8 @@ def __call__( encoder_hidden_states=prompt_embeds, class_labels=image_embeds, cross_attention_kwargs=cross_attention_kwargs, - ).sample + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: @@ -807,22 +828,23 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] if callback is not None and i % callback_steps == 0: callback(i, t, latents) # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: self.final_offload_hook.offload() - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) - if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 87e7b3e6c9eb..d770ee290517 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -363,8 +363,13 @@ def run_safety_checker(self, image, device, dtype, enable_safety_guidance): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 6fc89e945604..8bf4bafa4fe5 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -19,7 +19,7 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...loaders import TextualInversionLoaderMixin +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -73,7 +73,7 @@ def tensor2vid(video: torch.Tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) - return images -class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin): +class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -224,6 +224,7 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -248,7 +249,14 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -305,7 +313,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -591,6 +599,9 @@ def __call__( do_classifier_free_guidance = guidance_scale > 1.0 # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) prompt_embeds = self._encode_prompt( prompt, device, @@ -599,6 +610,7 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) # 4. Prepare timesteps diff --git a/src/diffusers/pipelines/unidiffuser/__init__.py b/src/diffusers/pipelines/unidiffuser/__init__.py new file mode 100644 index 000000000000..a774e3274030 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/__init__.py @@ -0,0 +1,20 @@ +from ...utils import ( + OptionalDependencyNotAvailable, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + ImageTextPipelineOutput, + UniDiffuserPipeline, + ) +else: + from .modeling_text_decoder import UniDiffuserTextDecoder + from .modeling_uvit import UniDiffuserModel, UTransformer2DModel + from .pipeline_unidiffuser import ImageTextPipelineOutput, UniDiffuserPipeline diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py new file mode 100644 index 000000000000..9b962f6e0656 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py @@ -0,0 +1,296 @@ +from typing import Optional + +import numpy as np +import torch +from torch import nn +from transformers import GPT2Config, GPT2LMHeadModel +from transformers.modeling_utils import ModuleUtilsMixin + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin + + +# Modified from ClipCaptionModel in https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py +class UniDiffuserTextDecoder(ModelMixin, ConfigMixin, ModuleUtilsMixin): + """ + Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to + generate text from the UniDiffuser image-text embedding. + + Parameters: + prefix_length (`int`): + Max number of prefix tokens that will be supplied to the model. + prefix_inner_dim (`int`): + The hidden size of the the incoming prefix embeddings. For UniDiffuser, this would be the hidden dim of the + CLIP text encoder. + prefix_hidden_dim (`int`, *optional*): + Hidden dim of the MLP if we encode the prefix. + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + """ + + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.bias", r"h\.\d+\.attn\.masked_bias"] + + @register_to_config + def __init__( + self, + prefix_length: int, + prefix_inner_dim: int, + prefix_hidden_dim: Optional[int] = None, + vocab_size: int = 50257, # Start of GPT2 config args + n_positions: int = 1024, + n_embd: int = 768, + n_layer: int = 12, + n_head: int = 12, + n_inner: Optional[int] = None, + activation_function: str = "gelu_new", + resid_pdrop: float = 0.1, + embd_pdrop: float = 0.1, + attn_pdrop: float = 0.1, + layer_norm_epsilon: float = 1e-5, + initializer_range: float = 0.02, + scale_attn_weights: bool = True, + use_cache: bool = True, + scale_attn_by_inverse_layer_idx: bool = False, + reorder_and_upcast_attn: bool = False, + ): + super().__init__() + + self.prefix_length = prefix_length + + if prefix_inner_dim != n_embd and prefix_hidden_dim is None: + raise ValueError( + f"`prefix_hidden_dim` cannot be `None` when `prefix_inner_dim`: {prefix_hidden_dim} and" + f" `n_embd`: {n_embd} are not equal." + ) + + self.prefix_inner_dim = prefix_inner_dim + self.prefix_hidden_dim = prefix_hidden_dim + + self.encode_prefix = ( + nn.Linear(self.prefix_inner_dim, self.prefix_hidden_dim) + if self.prefix_hidden_dim is not None + else nn.Identity() + ) + self.decode_prefix = ( + nn.Linear(self.prefix_hidden_dim, n_embd) if self.prefix_hidden_dim is not None else nn.Identity() + ) + + gpt_config = GPT2Config( + vocab_size=vocab_size, + n_positions=n_positions, + n_embd=n_embd, + n_layer=n_layer, + n_head=n_head, + n_inner=n_inner, + activation_function=activation_function, + resid_pdrop=resid_pdrop, + embd_pdrop=embd_pdrop, + attn_pdrop=attn_pdrop, + layer_norm_epsilon=layer_norm_epsilon, + initializer_range=initializer_range, + scale_attn_weights=scale_attn_weights, + use_cache=use_cache, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + self.transformer = GPT2LMHeadModel(gpt_config) + + def forward( + self, + input_ids: torch.Tensor, + prefix_embeds: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + """ + Args: + input_ids (`torch.Tensor` of shape `(N, max_seq_len)`): + Text tokens to use for inference. + prefix_embeds (`torch.Tensor` of shape `(N, prefix_length, 768)`): + Prefix embedding to preprend to the embedded tokens. + attention_mask (`torch.Tensor` of shape `(N, prefix_length + max_seq_len, 768)`, *optional*): + Attention mask for the prefix embedding. + labels (`torch.Tensor`, *optional*): + Labels to use for language modeling. + """ + embedding_text = self.transformer.transformer.wte(input_ids) + hidden = self.encode_prefix(prefix_embeds) + prefix_embeds = self.decode_prefix(hidden) + embedding_cat = torch.cat((prefix_embeds, embedding_text), dim=1) + + if labels is not None: + dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device) + labels = torch.cat((dummy_token, input_ids), dim=1) + out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask) + if self.prefix_hidden_dim is not None: + return out, hidden + else: + return out + + def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: + return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) + + def encode(self, prefix): + return self.encode_prefix(prefix) + + @torch.no_grad() + def generate_captions(self, features, eos_token_id, device): + """ + Generate captions given text embedding features. Returns list[L]. + + Args: + features (`torch.Tensor` of shape `(B, L, D)`): + Text embedding features to generate captions from. + eos_token_id (`int`): + The token ID of the EOS token for the text decoder model. + device: + Device to perform text generation on. + + Returns: + `List[str]`: A list of strings generated from the decoder model. + """ + + features = torch.split(features, 1, dim=0) + generated_tokens = [] + generated_seq_lengths = [] + for feature in features: + feature = self.decode_prefix(feature.to(device)) # back to the clip feature + # Only support beam search for now + output_tokens, seq_lengths = self.generate_beam( + input_embeds=feature, device=device, eos_token_id=eos_token_id + ) + generated_tokens.append(output_tokens[0]) + generated_seq_lengths.append(seq_lengths[0]) + generated_tokens = torch.stack(generated_tokens) + generated_seq_lengths = torch.stack(generated_seq_lengths) + return generated_tokens, generated_seq_lengths + + @torch.no_grad() + def generate_beam( + self, + input_ids=None, + input_embeds=None, + device=None, + beam_size: int = 5, + entry_length: int = 67, + temperature: float = 1.0, + eos_token_id: Optional[int] = None, + ): + """ + Generates text using the given tokenizer and text prompt or token embedding via beam search. This + implementation is based on the beam search implementation from the [original UniDiffuser + code](https://github.com/thu-ml/unidiffuser/blob/main/libs/caption_decoder.py#L89). + + Args: + eos_token_id (`int`, *optional*): + The token ID of the EOS token for the text decoder model. + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Tokenizer indices of input sequence tokens in the vocabulary. One of `input_ids` and `input_embeds` + must be supplied. + input_embeds (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + An embedded representation to directly pass to the transformer as a prefix for beam search. One of + `input_ids` and `input_embeds` must be supplied. + device: + The device to perform beam search on. + beam_size (`int`, *optional*, defaults to `5`): + The number of best states to store during beam search. + entry_length (`int`, *optional*, defaults to `67`): + The number of iterations to run beam search. + temperature (`float`, *optional*, defaults to 1.0): + The temperature to use when performing the softmax over logits from the decoding model. + + Returns: + `Tuple(torch.Tensor, torch.Tensor)`: A tuple of tensors where the first element is a tensor of generated + token sequences sorted by score in descending order, and the second element is the sequence lengths + corresponding to those sequences. + """ + # Generates text until stop_token is reached using beam search with the desired beam size. + stop_token_index = eos_token_id + tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=device, dtype=torch.int) + is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) + + if input_embeds is not None: + generated = input_embeds + else: + generated = self.transformer.transformer.wte(input_ids) + + for i in range(entry_length): + outputs = self.transformer(inputs_embeds=generated) + logits = outputs.logits + logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) + logits = logits.softmax(-1).log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + generated = generated.expand(beam_size, *generated.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + if tokens is None: + tokens = next_tokens + else: + tokens = tokens.expand(beam_size, *tokens.shape[1:]) + tokens = torch.cat((tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + tokens = tokens[next_tokens_source] + tokens = torch.cat((tokens, next_tokens), dim=1) + generated = generated[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + next_token_embed = self.transformer.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) + generated = torch.cat((generated, next_token_embed), dim=1) + is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() + if is_stopped.all(): + break + + scores = scores / seq_lengths + order = scores.argsort(descending=True) + # tokens tensors are already padded to max_seq_length + output_texts = [tokens[i] for i in order] + output_texts = torch.stack(output_texts, dim=0) + seq_lengths = torch.tensor([seq_lengths[i] for i in order], dtype=seq_lengths.dtype) + return output_texts, seq_lengths diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py new file mode 100644 index 000000000000..b7829f76ec12 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -0,0 +1,1196 @@ +import math +from typing import Optional, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models import ModelMixin +from ...models.attention import AdaLayerNorm, FeedForward +from ...models.attention_processor import Attention +from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed +from ...models.transformer_2d import Transformer2DModelOutput +from ...utils import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + logger.warning( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect." + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + # type: (torch.Tensor, float, float, float, float) -> torch.Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the normal distribution :math:`\mathcal{N}(\text{mean}, + \text{std}^2)` with values outside :math:`[a, b]` redrawn until they are within the bounds. The method used for + generating the random values works best when :math:`a \leq \text{mean} \leq b`. + + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + use_pos_embed=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.use_pos_embed = use_pos_embed + if self.use_pos_embed: + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + if self.use_pos_embed: + return latent + self.pos_embed + else: + return latent + + +class SkipBlock(nn.Module): + def __init__(self, dim: int): + super().__init__() + + self.skip_linear = nn.Linear(2 * dim, dim) + + # Use torch.nn.LayerNorm for now, following the original code + self.norm = nn.LayerNorm(dim) + + def forward(self, x, skip): + x = self.skip_linear(torch.cat([x, skip], dim=-1)) + x = self.norm(x) + + return x + + +# Modified to support both pre-LayerNorm and post-LayerNorm configurations +# Don't support AdaLayerNormZero for now +# Modified from diffusers.models.attention.BasicTransformerBlock +class UTransformerBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). Note that `BasicTransformerBlock` uses pre-LayerNorm, e.g. + `pre_layer_norm = True`. + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = True, + final_dropout: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + else: + norm_hidden_states = self.norm1(hidden_states) + else: + norm_hidden_states = hidden_states + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + attn_output = self.norm1(attn_output, timestep) + else: + attn_output = self.norm1(attn_output) + + hidden_states = attn_output + hidden_states + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + else: + norm_hidden_states = hidden_states + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + # Post-LayerNorm + if not self.pre_layer_norm: + attn_output = self.norm2(attn_output, timestep) if self.use_ada_layer_norm else self.norm2(attn_output) + + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + norm_hidden_states = self.norm3(hidden_states) + else: + norm_hidden_states = hidden_states + + ff_output = self.ff(norm_hidden_states) + + # Post-LayerNorm + if not self.pre_layer_norm: + ff_output = self.norm3(ff_output) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +# Like UTransformerBlock except with LayerNorms on the residual backbone of the block +# Modified from diffusers.models.attention.BasicTransformerBlock +class UniDiffuserBlock(nn.Module): + r""" + A modification of BasicTransformerBlock which supports pre-LayerNorm and post-LayerNorm configurations and puts the + LayerNorms on the residual backbone of the block. This matches the transformer block in the [original UniDiffuser + implementation](https://github.com/thu-ml/unidiffuser/blob/main/libs/uvit_multi_post_ln_v1.py#L104). + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to be used in feed-forward. + num_embeds_ada_norm (:obj: `int`, *optional*): + The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (:obj: `bool`, *optional*, defaults to `False`): + Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + norm_type (`str`, defaults to `"layer_norm"`): + The layer norm implementation to use. + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + pre_layer_norm: bool = False, + final_dropout: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + self.pre_layer_norm = pre_layer_norm + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + # 1. Self-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) + + def forward( + self, + hidden_states, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + timestep=None, + cross_attention_kwargs=None, + class_labels=None, + ): + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Pre-LayerNorm + if self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + # 1. Self-Attention + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + attn_output = self.attn1( + hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Following the diffusers transformer block implementation, put the LayerNorm on the + # residual backbone + # Post-LayerNorm + if not self.pre_layer_norm: + if self.use_ada_layer_norm: + hidden_states = self.norm1(hidden_states, timestep) + else: + hidden_states = self.norm1(hidden_states) + + if self.attn2 is not None: + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + # TODO (Birch-San): Here we should prepare the encoder_attention mask correctly + # prepare attention mask here + + # 2. Cross-Attention + attn_output = self.attn2( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + + # 3. Feed-forward + # Pre-LayerNorm + if self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + ff_output = self.ff(hidden_states) + + hidden_states = ff_output + hidden_states + + # Post-LayerNorm + if not self.pre_layer_norm: + hidden_states = self.norm3(hidden_states) + + return hidden_states + + +# Modified from diffusers.models.transformer_2d.Transformer2DModel +# Modify the transformer block structure to be U-Net like following U-ViT +# Only supports patch-style input and torch.nn.LayerNorm currently +# https://github.com/baofff/U-ViT +class UTransformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model based on the [U-ViT](https://github.com/baofff/U-ViT) architecture for image-like data. Compared + to [`Transformer2DModel`], this model has skip connections between transformer blocks in a "U"-shaped fashion, + similar to a U-Net. Supports only continuous (actual embeddings) inputs, which are embedded via a [`PatchEmbed`] + layer and then reshaped to (b, t, d). + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float() when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = 2, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Input + # Only support patch input of shape (batch_size, num_channels, height, width) for now + assert in_channels is not None and patch_size is not None, "Patch input requires in_channels and patch_size." + + assert sample_size is not None, "UTransformer2DModel over patched input must provide sample_size" + + # 2. Define input layers + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + + # 3. Define transformers blocks + # Modify this to have in_blocks ("downsample" blocks, even though we don't actually downsample), a mid_block, + # and out_blocks ("upsample" blocks). Like a U-Net, there are skip connections from in_blocks to out_blocks in + # a "U"-shaped fashion (e.g. first in_block to last out_block, etc.). + # Quick hack to make the transformer block type configurable + if block_type == "unidiffuser": + block_cls = UniDiffuserBlock + else: + block_cls = UTransformerBlock + self.transformer_in_blocks = nn.ModuleList( + [ + block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + for d in range(num_layers // 2) + ] + ) + + self.transformer_mid_block = block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ) + + # For each skip connection, we use a SkipBlock (concatenation + Linear + LayerNorm) to process the inputs + # before each transformer out_block. + self.transformer_out_blocks = nn.ModuleList( + [ + nn.ModuleDict( + { + "skip": SkipBlock( + inner_dim, + ), + "block": block_cls( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + final_dropout=ff_final_dropout, + ), + } + ) + for d in range(num_layers // 2) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + + # Following the UniDiffuser U-ViT implementation, we process the transformer output with + # a LayerNorm layer with per-element affine params + self.norm_out = nn.LayerNorm(inner_dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + timestep=None, + class_labels=None, + cross_attention_kwargs=None, + return_dict: bool = True, + hidden_states_is_embedding: bool = False, + unpatchify: bool = True, + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + hidden_states_is_embedding (`bool`, *optional*, defaults to `False`): + Whether or not hidden_states is an embedding directly usable by the transformer. In this case we will + ignore input handling (e.g. continuous, vectorized, etc.) and directly feed hidden_states into the + transformer blocks. + unpatchify (`bool`, *optional*, defaults to `True`): + Whether to unpatchify the transformer output. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # 0. Check inputs + + if not unpatchify and return_dict: + raise ValueError( + f"Cannot both define `unpatchify`: {unpatchify} and `return_dict`: {return_dict} since when" + f" `unpatchify` is {unpatchify} the returned output is of shape (batch_size, seq_len, hidden_dim)" + " rather than (batch_size, num_channels, height, width)." + ) + + # 1. Input + if not hidden_states_is_embedding: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + + # In ("downsample") blocks + skips = [] + for in_block in self.transformer_in_blocks: + hidden_states = in_block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + skips.append(hidden_states) + + # Mid block + hidden_states = self.transformer_mid_block(hidden_states) + + # Out ("upsample") blocks + for out_block in self.transformer_out_blocks: + hidden_states = out_block["skip"](hidden_states, skips.pop()) + hidden_states = out_block["block"]( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + # Don't support AdaLayerNorm for now, so no conditioning/scale/shift logic + hidden_states = self.norm_out(hidden_states) + # hidden_states = self.proj_out(hidden_states) + + if unpatchify: + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + else: + output = hidden_states + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class UniDiffuserModel(ModelMixin, ConfigMixin): + """ + Transformer model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is a + modification of [`UTransformer2DModel`] with input and output heads for the VAE-embedded latent image, the + CLIP-embedded image, and the CLIP-embedded prompt (see paper for more details). + + Parameters: + text_dim (`int`): The hidden dimension of the CLIP text model used to embed images. + clip_img_dim (`int`): The hidden dimension of the CLIP vision model used to embed prompts. + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input. + out_channels (`int`, *optional*): + The number of output channels; if `None`, defaults to `in_channels`. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups to use when performing Group Normalization. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + patch_size (`int`, *optional*, defaults to 2): + The patch size to use in the patch embedding. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + use_linear_projection (int, *optional*): TODO: Not used + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used in each + transformer block. + upcast_attention (`bool`, *optional*): + Whether to upcast the query and key to float32 when performing the attention calculation. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The Layer Normalization implementation to use. Defaults to `torch.nn.LayerNorm`. + block_type (`str`, *optional*, defaults to `"unidiffuser"`): + The transformer block implementation to use. If `"unidiffuser"`, has the LayerNorms on the residual + backbone of each transformer block; otherwise has them in the attention/feedforward branches (the standard + behavior in `diffusers`.) + pre_layer_norm (`bool`, *optional*): + Whether to perform layer normalization before the attention and feedforward operations ("pre-LayerNorm"), + as opposed to after ("post-LayerNorm"). The original UniDiffuser implementation is post-LayerNorm + (`pre_layer_norm = False`). + norm_elementwise_affine (`bool`, *optional*): + Whether to use learnable per-element affine parameters during layer normalization. + use_patch_pos_embed (`bool`, *optional*): + Whether to use position embeddings inside the patch embedding layer (`PatchEmbed`). + ff_final_dropout (`bool`, *optional*): + Whether to use a final Dropout layer after the feedforward network. + use_data_type_embedding (`bool`, *optional*): + Whether to use a data type embedding. This is only relevant for UniDiffuser-v1 style models; UniDiffuser-v1 + is continue-trained from UniDiffuser-v0 on non-publically-available data and accepts a `data_type` + argument, which can either be `1` to use the weights trained on non-publically-available data or `0` + otherwise. This argument is subsequently embedded by the data type embedding, if used. + """ + + @register_to_config + def __init__( + self, + text_dim: int = 768, + clip_img_dim: int = 512, + num_text_tokens: int = 77, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + block_type: str = "unidiffuser", + pre_layer_norm: bool = False, + use_timestep_embedding=False, + norm_elementwise_affine: bool = True, + use_patch_pos_embed=False, + ff_final_dropout: bool = True, + use_data_type_embedding: bool = False, + ): + super().__init__() + + # 0. Handle dimensions + self.inner_dim = num_attention_heads * attention_head_dim + + assert sample_size is not None, "UniDiffuserModel over patched input must provide sample_size" + self.sample_size = sample_size + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.patch_size = patch_size + # Assume image is square... + self.num_patches = (self.sample_size // patch_size) * (self.sample_size // patch_size) + + # 1. Define input layers + # 1.1 Input layers for text and image input + # For now, only support patch input for VAE latent image input + self.vae_img_in = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + use_pos_embed=use_patch_pos_embed, + ) + self.clip_img_in = nn.Linear(clip_img_dim, self.inner_dim) + self.text_in = nn.Linear(text_dim, self.inner_dim) + + # 1.2. Timestep embeddings for t_img, t_text + self.timestep_img_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_img_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + self.timestep_text_proj = Timesteps( + self.inner_dim, + flip_sin_to_cos=True, + downscale_freq_shift=0, + ) + self.timestep_text_embed = ( + TimestepEmbedding( + self.inner_dim, + 4 * self.inner_dim, + out_dim=self.inner_dim, + ) + if use_timestep_embedding + else nn.Identity() + ) + + # 1.3. Positional embedding + self.num_text_tokens = num_text_tokens + self.num_tokens = 1 + 1 + num_text_tokens + 1 + self.num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, self.inner_dim)) + self.pos_embed_drop = nn.Dropout(p=dropout) + trunc_normal_(self.pos_embed, std=0.02) + + # 1.4. Handle data type token embeddings for UniDiffuser-V1, if necessary + self.use_data_type_embedding = use_data_type_embedding + if self.use_data_type_embedding: + self.data_type_token_embedding = nn.Embedding(2, self.inner_dim) + self.data_type_pos_embed_token = nn.Parameter(torch.zeros(1, 1, self.inner_dim)) + + # 2. Define transformer blocks + self.transformer = UTransformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + out_channels=out_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + patch_size=patch_size, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + block_type=block_type, + pre_layer_norm=pre_layer_norm, + norm_elementwise_affine=norm_elementwise_affine, + use_patch_pos_embed=use_patch_pos_embed, + ff_final_dropout=ff_final_dropout, + ) + + # 3. Define output layers + patch_dim = (patch_size**2) * out_channels + self.vae_img_out = nn.Linear(self.inner_dim, patch_dim) + self.clip_img_out = nn.Linear(self.inner_dim, clip_img_dim) + self.text_out = nn.Linear(self.inner_dim, text_dim) + + @torch.jit.ignore + def no_weight_decay(self): + return {"pos_embed"} + + def forward( + self, + latent_image_embeds: torch.FloatTensor, + image_embeds: torch.FloatTensor, + prompt_embeds: torch.FloatTensor, + timestep_img: Union[torch.Tensor, float, int], + timestep_text: Union[torch.Tensor, float, int], + data_type: Optional[Union[torch.Tensor, float, int]] = 1, + encoder_hidden_states=None, + cross_attention_kwargs=None, + ): + """ + Args: + latent_image_embeds (`torch.FloatTensor` of shape `(batch size, latent channels, height, width)`): + Latent image representation from the VAE encoder. + image_embeds (`torch.FloatTensor` of shape `(batch size, 1, clip_img_dim)`): + CLIP-embedded image representation (unsqueezed in the first dimension). + prompt_embeds (`torch.FloatTensor` of shape `(batch size, seq_len, text_dim)`): + CLIP-embedded text representation. + timestep_img (`torch.long` or `float` or `int`): + Current denoising step for the image. + timestep_text (`torch.long` or `float` or `int`): + Current denoising step for the text. + data_type: (`torch.int` or `float` or `int`, *optional*, defaults to `1`): + Only used in UniDiffuser-v1-style models. Can be either `1`, to use weights trained on nonpublic data, + or `0` otherwise. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + cross_attention_kwargs (*optional*): + Keyword arguments to supply to the cross attention layers, if used. + + + Returns: + `tuple`: Returns relevant parts of the model's noise prediction: the first element of the tuple is tbe VAE + image embedding, the second element is the CLIP image embedding, and the third element is the CLIP text + embedding. + """ + batch_size = latent_image_embeds.shape[0] + + # 1. Input + # 1.1. Map inputs to shape (B, N, inner_dim) + vae_hidden_states = self.vae_img_in(latent_image_embeds) + clip_hidden_states = self.clip_img_in(image_embeds) + text_hidden_states = self.text_in(prompt_embeds) + + num_text_tokens, num_img_tokens = text_hidden_states.size(1), vae_hidden_states.size(1) + + # 1.2. Encode image timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_img): + timestep_img = torch.tensor([timestep_img], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_img = timestep_img * torch.ones(batch_size, dtype=timestep_img.dtype, device=timestep_img.device) + + timestep_img_token = self.timestep_img_proj(timestep_img) + # t_img_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_img_token = timestep_img_token.to(dtype=self.dtype) + timestep_img_token = self.timestep_img_embed(timestep_img_token) + timestep_img_token = timestep_img_token.unsqueeze(dim=1) + + # 1.3. Encode text timesteps to single token (B, 1, inner_dim) + if not torch.is_tensor(timestep_text): + timestep_text = torch.tensor([timestep_text], dtype=torch.long, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_text = timestep_text * torch.ones(batch_size, dtype=timestep_text.dtype, device=timestep_text.device) + + timestep_text_token = self.timestep_text_proj(timestep_text) + # t_text_token does not contain any weights and will always return f32 tensors + # but time_embedding might be fp16, so we need to cast here. + timestep_text_token = timestep_text_token.to(dtype=self.dtype) + timestep_text_token = self.timestep_text_embed(timestep_text_token) + timestep_text_token = timestep_text_token.unsqueeze(dim=1) + + # 1.4. Concatenate all of the embeddings together. + if self.use_data_type_embedding: + assert data_type is not None, "data_type must be supplied if the model uses a data type embedding" + if not torch.is_tensor(data_type): + data_type = torch.tensor([data_type], dtype=torch.int, device=vae_hidden_states.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + data_type = data_type * torch.ones(batch_size, dtype=data_type.dtype, device=data_type.device) + + data_type_token = self.data_type_token_embedding(data_type).unsqueeze(dim=1) + hidden_states = torch.cat( + [ + timestep_img_token, + timestep_text_token, + data_type_token, + text_hidden_states, + clip_hidden_states, + vae_hidden_states, + ], + dim=1, + ) + else: + hidden_states = torch.cat( + [timestep_img_token, timestep_text_token, text_hidden_states, clip_hidden_states, vae_hidden_states], + dim=1, + ) + + # 1.5. Prepare the positional embeddings and add to hidden states + # Note: I think img_vae should always have the proper shape, so there's no need to interpolate + # the position embeddings. + if self.use_data_type_embedding: + pos_embed = torch.cat( + [self.pos_embed[:, : 1 + 1, :], self.data_type_pos_embed_token, self.pos_embed[:, 1 + 1 :, :]], dim=1 + ) + else: + pos_embed = self.pos_embed + hidden_states = hidden_states + pos_embed + hidden_states = self.pos_embed_drop(hidden_states) + + # 2. Blocks + hidden_states = self.transformer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=None, + class_labels=None, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + hidden_states_is_embedding=True, + unpatchify=False, + )[0] + + # 3. Output + # Split out the predicted noise representation. + if self.use_data_type_embedding: + ( + t_img_token_out, + t_text_token_out, + data_type_token_out, + text_out, + img_clip_out, + img_vae_out, + ) = hidden_states.split((1, 1, 1, num_text_tokens, 1, num_img_tokens), dim=1) + else: + t_img_token_out, t_text_token_out, text_out, img_clip_out, img_vae_out = hidden_states.split( + (1, 1, num_text_tokens, 1, num_img_tokens), dim=1 + ) + + img_vae_out = self.vae_img_out(img_vae_out) + + # unpatchify + height = width = int(img_vae_out.shape[1] ** 0.5) + img_vae_out = img_vae_out.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + img_vae_out = torch.einsum("nhwpqc->nchpwq", img_vae_out) + img_vae_out = img_vae_out.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + img_clip_out = self.clip_img_out(img_clip_out) + + text_out = self.text_out(text_out) + + return img_vae_out, img_clip_out, text_out diff --git a/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py new file mode 100644 index 000000000000..ecc457b4cb94 --- /dev/null +++ b/src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py @@ -0,0 +1,1428 @@ +import inspect +import warnings +from dataclasses import dataclass +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from ...models import AutoencoderKL +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, +) +from ...utils.outputs import BaseOutput +from ..pipeline_utils import DiffusionPipeline +from .modeling_text_decoder import UniDiffuserTextDecoder +from .modeling_uvit import UniDiffuserModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess +def preprocess(image): + warnings.warn( + "The preprocess method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor.preprocess instead", + FutureWarning, + ) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +# New BaseOutput child class for joint image-text output +@dataclass +class ImageTextPipelineOutput(BaseOutput): + """ + Output class for joint image-text pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + text (`List[str]` or `List[List[str]]`) + List of generated text strings of length `batch_size` or a list of list of strings whose outer list has + length `batch_size`. Text generated by the diffusion pipeline. + """ + + images: Optional[Union[List[PIL.Image.Image], np.ndarray]] + text: Optional[Union[List[str], List[List[str]]]] + + +class UniDiffuserPipeline(DiffusionPipeline): + r""" + Pipeline for a bimodal image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model, which supports + unconditional text and image generation, text-conditioned image generation, image-conditioned text generation, and + joint image-text generation. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. This + is part of the UniDiffuser image representation, along with the CLIP vision encoding. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Similar to Stable Diffusion, UniDiffuser uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel) to encode text + prompts. + image_encoder ([`CLIPVisionModel`]): + UniDiffuser uses the vision portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel) to encode + images as part of its image representation, along with the VAE latent representation. + image_processor ([`CLIPImageProcessor`]): + CLIP image processor of class + [CLIPImageProcessor](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPImageProcessor), + used to preprocess the image before CLIP encoding it with `image_encoder`. + clip_tokenizer ([`CLIPTokenizer`]): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTokenizer) which + is used to tokenizer a prompt before encoding it with `text_encoder`. + text_decoder ([`UniDiffuserTextDecoder`]): + Frozen text decoder. This is a GPT-style model which is used to generate text from the UniDiffuser + embedding. + text_tokenizer ([`GPT2Tokenizer`]): + Tokenizer of class + [GPT2Tokenizer](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.GPT2Tokenizer) which + is used along with the `text_decoder` to decode text for text generation. + unet ([`UniDiffuserModel`]): + UniDiffuser uses a [U-ViT](https://github.com/baofff/U-ViT) model architecture, which is similar to a + [`Transformer2DModel`] with U-Net-style skip connections between transformer layers. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image and/or text latents. The + original UniDiffuser paper uses the [`DPMSolverMultistepScheduler`] scheduler. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + image_encoder: CLIPVisionModelWithProjection, + image_processor: CLIPImageProcessor, + clip_tokenizer: CLIPTokenizer, + text_decoder: UniDiffuserTextDecoder, + text_tokenizer: GPT2Tokenizer, + unet: UniDiffuserModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + if text_encoder.config.hidden_size != text_decoder.prefix_inner_dim: + raise ValueError( + f"The text encoder hidden size and text decoder prefix inner dim must be the same, but" + f" `text_encoder.config.hidden_size`: {text_encoder.config.hidden_size} and `text_decoder.prefix_inner_dim`: {text_decoder.prefix_inner_dim}" + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + image_encoder=image_encoder, + image_processor=image_processor, + clip_tokenizer=clip_tokenizer, + text_decoder=text_decoder, + text_tokenizer=text_tokenizer, + unet=unet, + scheduler=scheduler, + ) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + + self.num_channels_latents = vae.config.latent_channels + self.text_encoder_seq_len = text_encoder.config.max_position_embeddings + self.text_encoder_hidden_size = text_encoder.config.hidden_size + self.image_encoder_projection_dim = image_encoder.config.projection_dim + self.unet_resolution = unet.config.sample_size + + self.text_intermediate_dim = self.text_encoder_hidden_size + if self.text_decoder.prefix_hidden_dim is not None: + self.text_intermediate_dim = self.text_decoder.prefix_hidden_dim + + self.mode = None + + # TODO: handle safety checking? + self.safety_checker = None + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.image_encoder, self.text_decoder]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + # Add self.image_encoder, self.text_decoder to cpu_offloaded_models list + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae, self.image_encoder, self.text_decoder]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def _infer_mode(self, prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents): + r""" + Infer the generation task ('mode') from the inputs to `__call__`. If the mode has been manually set, the set + mode will be used. + """ + prompt_available = (prompt is not None) or (prompt_embeds is not None) + image_available = image is not None + input_available = prompt_available or image_available + + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + full_latents_available = latents is not None + image_latents_available = vae_latents_available and clip_latents_available + all_indv_latents_available = prompt_latents_available and image_latents_available + + if self.mode is not None: + # Preferentially use the mode set by the user + mode = self.mode + elif prompt_available: + mode = "text2img" + elif image_available: + mode = "img2text" + else: + # Neither prompt nor image supplied, infer based on availability of latents + if full_latents_available or all_indv_latents_available: + mode = "joint" + elif prompt_latents_available: + mode = "text" + elif image_latents_available: + mode = "img" + else: + # No inputs or latents available + mode = "joint" + + # Give warnings for ambiguous cases + if self.mode is None and prompt_available and image_available: + logger.warning( + f"You have supplied both a text prompt and image to the pipeline and mode has not been set manually," + f" defaulting to mode '{mode}'." + ) + + if self.mode is None and not input_available: + if vae_latents_available != clip_latents_available: + # Exactly one of vae_latents and clip_latents is supplied + logger.warning( + f"You have supplied exactly one of `vae_latents` and `clip_latents`, whereas either both or none" + f" are expected to be supplied. Defaulting to mode '{mode}'." + ) + elif not prompt_latents_available and not vae_latents_available and not clip_latents_available: + # No inputs or latents supplied + logger.warning( + f"No inputs or latents have been supplied, and mode has not been manually set," + f" defaulting to mode '{mode}'." + ) + + return mode + + # Functions to manually set the mode + def set_text_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") text generation.""" + self.mode = "text" + + def set_image_mode(self): + r"""Manually set the generation mode to unconditional ("marginal") image generation.""" + self.mode = "img" + + def set_text_to_image_mode(self): + r"""Manually set the generation mode to text-conditioned image generation.""" + self.mode = "text2img" + + def set_image_to_text_mode(self): + r"""Manually set the generation mode to image-conditioned text generation.""" + self.mode = "img2text" + + def set_joint_mode(self): + r"""Manually set the generation mode to unconditional joint image-text generation.""" + self.mode = "joint" + + def reset_mode(self): + r"""Removes a manually set mode; after calling this, the pipeline will infer the mode from inputs.""" + self.mode = None + + def _infer_batch_size( + self, + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ): + r"""Infers the batch size and multiplier depending on mode and supplied arguments to `__call__`.""" + if num_images_per_prompt is None: + num_images_per_prompt = 1 + if num_prompts_per_image is None: + num_prompts_per_image = 1 + + assert num_images_per_prompt > 0, "num_images_per_prompt must be a positive integer" + assert num_prompts_per_image > 0, "num_prompts_per_image must be a positive integer" + + if mode in ["text2img"]: + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + # Either prompt or prompt_embeds must be present for text2img. + batch_size = prompt_embeds.shape[0] + multiplier = num_images_per_prompt + elif mode in ["img2text"]: + if isinstance(image, PIL.Image.Image): + batch_size = 1 + else: + # Image must be available and type either PIL.Image.Image or torch.FloatTensor. + # Not currently supporting something like image_embeds. + batch_size = image.shape[0] + multiplier = num_prompts_per_image + elif mode in ["img"]: + if vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + multiplier = num_images_per_prompt + elif mode in ["text"]: + if prompt_latents is not None: + batch_size = prompt_latents.shape[0] + else: + batch_size = 1 + multiplier = num_prompts_per_image + elif mode in ["joint"]: + if latents is not None: + batch_size = latents.shape[0] + elif prompt_latents is not None: + batch_size = prompt_latents.shape[0] + elif vae_latents is not None: + batch_size = vae_latents.shape[0] + elif clip_latents is not None: + batch_size = clip_latents.shape[0] + else: + batch_size = 1 + + if num_images_per_prompt == num_prompts_per_image: + multiplier = num_images_per_prompt + else: + multiplier = min(num_images_per_prompt, num_prompts_per_image) + logger.warning( + f"You are using mode `{mode}` and `num_images_per_prompt`: {num_images_per_prompt} and" + f" num_prompts_per_image: {num_prompts_per_image} are not equal. Using batch size equal to" + f" `min(num_images_per_prompt, num_prompts_per_image) = {batch_size}." + ) + return batch_size, multiplier + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + # self.tokenizer => self.clip_tokenizer + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.clip_tokenizer( + prompt, + padding="max_length", + max_length=self.clip_tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.clip_tokenizer.batch_decode( + untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.clip_tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.prepare_image_latents + # Add num_prompts_per_image argument, sample from autoencoder moment distribution + def encode_image_vae_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + do_classifier_free_guidance, + generator=None, + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + * self.vae.config.scaling_factor + for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + # Scale image_latents by the VAE's scaling factor + image_latents = image_latents * self.vae.config.scaling_factor + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if do_classifier_free_guidance: + uncond_image_latents = torch.zeros_like(image_latents) + image_latents = torch.cat([image_latents, image_latents, uncond_image_latents], dim=0) + + return image_latents + + def encode_image_clip_latents( + self, + image, + batch_size, + num_prompts_per_image, + dtype, + device, + generator=None, + ): + # Map image to CLIP embedding. + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + preprocessed_image = self.image_processor.preprocess( + image, + return_tensors="pt", + ) + preprocessed_image = preprocessed_image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_prompts_per_image + if isinstance(generator, list): + image_latents = [ + self.image_encoder(**preprocessed_image[i : i + 1]).image_embeds for i in range(batch_size) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = self.image_encoder(**preprocessed_image).image_embeds + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand image_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {image_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + return image_latents + + # Note that the CLIP latents are not decoded for image generation. + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + # Rename: decode_latents -> decode_image_latents + def decode_image_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_text_latents( + self, batch_size, num_images_per_prompt, seq_len, hidden_size, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded prompt. + shape = (batch_size * num_images_per_prompt, seq_len, hidden_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shace (B, L, D) + latents = latents.repeat(num_images_per_prompt, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + # Rename prepare_latents -> prepare_image_vae_latents and add num_prompts_per_image argument. + def prepare_image_vae_latents( + self, + batch_size, + num_prompts_per_image, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size * num_prompts_per_image, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, C, H, W) + latents = latents.repeat(num_prompts_per_image, 1, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def prepare_image_clip_latents( + self, batch_size, num_prompts_per_image, clip_img_dim, dtype, device, generator, latents=None + ): + # Prepare latents for the CLIP embedded image. + shape = (batch_size * num_prompts_per_image, 1, clip_img_dim) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + # latents is assumed to have shape (B, L, D) + latents = latents.repeat(num_prompts_per_image, 1, 1) + latents = latents.to(device=device, dtype=dtype) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _split(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim) into two tensors of shape (B, C, H, W) + and (B, 1, clip_img_dim) + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + + img_vae, img_clip = x.split([img_vae_dim, self.image_encoder_projection_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + return img_vae, img_clip + + def _combine(self, img_vae, img_clip): + r""" + Combines a latent iamge img_vae of shape (B, C, H, W) and a CLIP-embedded image img_clip of shape (B, 1, + clip_img_dim) into a single tensor of shape (B, C * H * W + clip_img_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + return torch.concat([img_vae, img_clip], dim=-1) + + def _split_joint(self, x, height, width): + r""" + Splits a flattened embedding x of shape (B, C * H * W + clip_img_dim + text_seq_len * text_dim] into (img_vae, + img_clip, text) where img_vae is of shape (B, C, H, W), img_clip is of shape (B, 1, clip_img_dim), and text is + of shape (B, text_seq_len, text_dim). + """ + batch_size = x.shape[0] + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_intermediate_dim + + img_vae, img_clip, text = x.split([img_vae_dim, self.image_encoder_projection_dim, text_dim], dim=1) + + img_vae = torch.reshape(img_vae, (batch_size, self.num_channels_latents, latent_height, latent_width)) + img_clip = torch.reshape(img_clip, (batch_size, 1, self.image_encoder_projection_dim)) + text = torch.reshape(text, (batch_size, self.text_encoder_seq_len, self.text_intermediate_dim)) + return img_vae, img_clip, text + + def _combine_joint(self, img_vae, img_clip, text): + r""" + Combines a latent image img_vae of shape (B, C, H, W), a CLIP-embedded image img_clip of shape (B, L_img, + clip_img_dim), and a text embedding text of shape (B, L_text, text_dim) into a single embedding x of shape (B, + C * H * W + L_img * clip_img_dim + L_text * text_dim). + """ + img_vae = torch.reshape(img_vae, (img_vae.shape[0], -1)) + img_clip = torch.reshape(img_clip, (img_clip.shape[0], -1)) + text = torch.reshape(text, (text.shape[0], -1)) + return torch.concat([img_vae, img_clip, text], dim=-1) + + def _get_noise_pred( + self, + mode, + latents, + t, + prompt_embeds, + img_vae, + img_clip, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ): + r""" + Gets the noise prediction using the `unet` and performs classifier-free guidance, if necessary. + """ + if mode == "joint": + # Joint text-image generation + img_vae_latents, img_clip_latents, text_latents = self._split_joint(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, text_latents, timestep_img=t, timestep_text=t, data_type=data_type + ) + + x_out = self._combine_joint(img_vae_out, img_clip_out, text_out) + + if guidance_scale <= 1.0: + return x_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + _, _, text_out_uncond = self.unet( + img_vae_T, img_clip_T, text_latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + img_vae_out_uncond, img_clip_out_uncond, _ = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond) + + return guidance_scale * x_out + (1.0 - guidance_scale) * x_out_uncond + elif mode == "text2img": + # Text-conditioned image generation + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, img_clip_latents, prompt_embeds, timestep_img=t, timestep_text=0, data_type=data_type + ) + + img_out = self._combine(img_vae_out, img_clip_out) + + if guidance_scale <= 1.0: + return img_out + + # Classifier-free guidance + text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_latents, + img_clip_latents, + text_T, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond) + + return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond + elif mode == "img2text": + # Image-conditioned text generation + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=0, timestep_text=t, data_type=data_type + ) + + if guidance_scale <= 1.0: + return text_out + + # Classifier-free guidance + img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype) + img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype) + + img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet( + img_vae_T, img_clip_T, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond + elif mode == "text": + # Unconditional ("marginal") text generation (no CFG) + img_vae_out, img_clip_out, text_out = self.unet( + img_vae, img_clip, latents, timestep_img=max_timestep, timestep_text=t, data_type=data_type + ) + + return text_out + elif mode == "img": + # Unconditional ("marginal") image generation (no CFG) + img_vae_latents, img_clip_latents = self._split(latents, height, width) + + img_vae_out, img_clip_out, text_out = self.unet( + img_vae_latents, + img_clip_latents, + prompt_embeds, + timestep_img=t, + timestep_text=max_timestep, + data_type=data_type, + ) + + img_out = self._combine(img_vae_out, img_clip_out) + return img_out + + def check_latents_shape(self, latents_name, latents, expected_shape): + latents_shape = latents.shape + expected_num_dims = len(expected_shape) + 1 # expected dimensions plus the batch dimension + expected_shape_str = ", ".join(str(dim) for dim in expected_shape) + if len(latents_shape) != expected_num_dims: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {len(latents_shape)} dimensions." + ) + for i in range(1, expected_num_dims): + if latents_shape[i] != expected_shape[i - 1]: + raise ValueError( + f"`{latents_name}` should have shape (batch_size, {expected_shape_str}), but the current shape" + f" {latents_shape} has {latents_shape[i]} != {expected_shape[i - 1]} at dimension {i}." + ) + + def check_inputs( + self, + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + latents=None, + prompt_latents=None, + vae_latents=None, + clip_latents=None, + ): + # Check inputs before running the generative process. + if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if mode == "text2img": + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if mode == "img2text": + if image is None: + raise ValueError("`img2text` mode requires an image to be provided.") + + # Check provided latents + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + full_latents_available = latents is not None + prompt_latents_available = prompt_latents is not None + vae_latents_available = vae_latents is not None + clip_latents_available = clip_latents is not None + + if full_latents_available: + individual_latents_available = ( + prompt_latents is not None or vae_latents is not None or clip_latents is not None + ) + if individual_latents_available: + logger.warning( + "You have supplied both `latents` and at least one of `prompt_latents`, `vae_latents`, and" + " `clip_latents`. The value of `latents` will override the value of any individually supplied latents." + ) + # Check shape of full latents + img_vae_dim = self.num_channels_latents * latent_height * latent_width + text_dim = self.text_encoder_seq_len * self.text_encoder_hidden_size + latents_dim = img_vae_dim + self.image_encoder_projection_dim + text_dim + latents_expected_shape = (latents_dim,) + self.check_latents_shape("latents", latents, latents_expected_shape) + + # Check individual latent shapes, if present + if prompt_latents_available: + prompt_latents_expected_shape = (self.text_encoder_seq_len, self.text_encoder_hidden_size) + self.check_latents_shape("prompt_latents", prompt_latents, prompt_latents_expected_shape) + + if vae_latents_available: + vae_latents_expected_shape = (self.num_channels_latents, latent_height, latent_width) + self.check_latents_shape("vae_latents", vae_latents, vae_latents_expected_shape) + + if clip_latents_available: + clip_latents_expected_shape = (1, self.image_encoder_projection_dim) + self.check_latents_shape("clip_latents", clip_latents, clip_latents_expected_shape) + + if mode in ["text2img", "img"] and vae_latents_available and clip_latents_available: + if vae_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"Both `vae_latents` and `clip_latents` are supplied, but their batch dimensions are not equal:" + f" {vae_latents.shape[0]} != {clip_latents.shape[0]}." + ) + + if mode == "joint" and prompt_latents_available and vae_latents_available and clip_latents_available: + if prompt_latents.shape[0] != vae_latents.shape[0] or prompt_latents.shape[0] != clip_latents.shape[0]: + raise ValueError( + f"All of `prompt_latents`, `vae_latents`, and `clip_latents` are supplied, but their batch" + f" dimensions are not equal: {prompt_latents.shape[0]} != {vae_latents.shape[0]}" + f" != {clip_latents.shape[0]}." + ) + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + data_type: Optional[int] = 1, + num_inference_steps: int = 50, + guidance_scale: float = 8.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + num_prompts_per_image: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_latents: Optional[torch.FloatTensor] = None, + vae_latents: Optional[torch.FloatTensor] = None, + clip_latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. Required for text-conditioned image generation (`text2img`) mode. + image (`torch.FloatTensor` or `PIL.Image.Image`, *optional*): + `Image`, or tensor representing an image batch. Required for image-conditioned text generation + (`img2text`) mode. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + data_type (`int`, *optional*, defaults to 1): + The data type (either 0 or 1). Only used if you are loading a checkpoint which supports a data type + embedding; this is added for compatibility with the UniDiffuser-v1 checkpoint. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. Note that the original [UniDiffuser + paper](https://arxiv.org/pdf/2303.06555.pdf) uses a different definition of the guidance scale `w'`, + which satisfies `w = w' + 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). Used in text-conditioned image generation (`text2img`) mode. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. Used in `text2img` (text-conditioned image generation) and + `img` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. + num_prompts_per_image (`int`, *optional*, defaults to 1): + The number of prompts to generate per image. Used in `img2text` (image-conditioned text generation) and + `text` mode. If the mode is joint and both `num_images_per_prompt` and `num_prompts_per_image` are + supplied, `min(num_images_per_prompt, num_prompts_per_image)` samples will be generated. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for joint + image-text generation. Can be used to tweak the same generation with different prompts. If not + provided, a latents tensor will be generated by sampling using the supplied random `generator`. Note + that this is assumed to be a full set of VAE, CLIP, and text latents, if supplied, this will override + the value of `prompt_latents`, `vae_latents`, and `clip_latents`. + prompt_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for text + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + vae_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + clip_latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. Used in text-conditioned + image generation (`text2img`) mode. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. Used in text-conditioned image generation (`text2img`) mode. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.unidiffuser.ImageTextPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Returns: + [`~pipelines.unidiffuser.ImageTextPipelineOutput`] or `tuple`: + [`pipelines.unidiffuser.ImageTextPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is a list with the generated images, and the second element is a list + of generated texts. + """ + + # 0. Default height and width to unet + height = height or self.unet_resolution * self.vae_scale_factor + width = width or self.unet_resolution * self.vae_scale_factor + + # 1. Check inputs + # Recalculate mode for each call to the pipeline. + mode = self._infer_mode(prompt, prompt_embeds, image, latents, prompt_latents, vae_latents, clip_latents) + self.check_inputs( + mode, + prompt, + image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + + # 2. Define call parameters + batch_size, multiplier = self._infer_batch_size( + mode, + prompt, + prompt_embeds, + image, + num_images_per_prompt, + num_prompts_per_image, + latents, + prompt_latents, + vae_latents, + clip_latents, + ) + device = self._execution_device + reduce_text_emb_dim = self.text_intermediate_dim < self.text_encoder_hidden_size or self.mode != "text2img" + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + # Note that this differs from the formulation in the unidiffusers paper! + # do_classifier_free_guidance = guidance_scale > 1.0 + + # check if scheduler is in sigmas space + # scheduler_is_in_sigma_space = hasattr(self.scheduler, "sigmas") + + # 3. Encode input prompt, if available; otherwise prepare text latents + if latents is not None: + # Overwrite individual latents + vae_latents, clip_latents, prompt_latents = self._split_joint(latents, height, width) + + if mode in ["text2img"]: + # 3.1. Encode input prompt, if available + assert prompt is not None or prompt_embeds is not None + prompt_embeds = self._encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=multiplier, + do_classifier_free_guidance=False, # don't support standard classifier-free guidance for now + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + else: + # 3.2. Prepare text latent variables, if input not available + prompt_embeds = self.prepare_text_latents( + batch_size=batch_size, + num_images_per_prompt=multiplier, + seq_len=self.text_encoder_seq_len, + hidden_size=self.text_encoder_hidden_size, + dtype=self.text_encoder.dtype, # Should work with both full precision and mixed precision + device=device, + generator=generator, + latents=prompt_latents, + ) + + if reduce_text_emb_dim: + prompt_embeds = self.text_decoder.encode(prompt_embeds) + + # 4. Encode image, if available; otherwise prepare image latents + if mode in ["img2text"]: + # 4.1. Encode images, if available + assert image is not None, "`img2text` requires a conditioning image" + # Encode image using VAE + image_vae = preprocess(image) + height, width = image_vae.shape[-2:] + image_vae_latents = self.encode_image_vae_latents( + image=image_vae, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + do_classifier_free_guidance=False, # Copied from InstructPix2Pix, don't use their version of CFG + generator=generator, + ) + + # Encode image using CLIP + image_clip_latents = self.encode_image_clip_latents( + image=image, + batch_size=batch_size, + num_prompts_per_image=multiplier, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + ) + # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size) + image_clip_latents = image_clip_latents.unsqueeze(1) + else: + # 4.2. Prepare image latent variables, if input not available + # Prepare image VAE latents in latent space + image_vae_latents = self.prepare_image_vae_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + num_channels_latents=self.num_channels_latents, + height=height, + width=width, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=vae_latents, + ) + + # Prepare image CLIP latents + image_clip_latents = self.prepare_image_clip_latents( + batch_size=batch_size, + num_prompts_per_image=multiplier, + clip_img_dim=self.image_encoder_projection_dim, + dtype=prompt_embeds.dtype, + device=device, + generator=generator, + latents=clip_latents, + ) + + # 5. Set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + # max_timestep = timesteps[0] + max_timestep = self.scheduler.config.num_train_timesteps + + # 6. Prepare latent variables + if mode == "joint": + latents = self._combine_joint(image_vae_latents, image_clip_latents, prompt_embeds) + elif mode in ["text2img", "img"]: + latents = self._combine(image_vae_latents, image_clip_latents) + elif mode in ["img2text", "text"]: + latents = prompt_embeds + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + logger.debug(f"Scheduler extra step kwargs: {extra_step_kwargs}") + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # predict the noise residual + # Also applies classifier-free guidance as described in the UniDiffuser paper + noise_pred = self._get_noise_pred( + mode, + latents, + t, + prompt_embeds, + image_vae_latents, + image_clip_latents, + max_timestep, + data_type, + guidance_scale, + generator, + device, + height, + width, + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + gen_image = None + gen_text = None + if mode == "joint": + image_vae_latents, image_clip_latents, text_latents = self._split_joint(latents, height, width) + + # Map latent VAE image back to pixel space + gen_image = self.decode_image_latents(image_vae_latents) + + # Generate text using the text decoder + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + gen_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + elif mode in ["text2img", "img"]: + image_vae_latents, image_clip_latents = self._split(latents, height, width) + gen_image = self.decode_image_latents(image_vae_latents) + elif mode in ["img2text", "text"]: + text_latents = latents + output_token_list, seq_lengths = self.text_decoder.generate_captions( + text_latents, self.text_tokenizer.eos_token_id, device=device + ) + output_list = output_token_list.cpu().numpy() + gen_text = [ + self.text_tokenizer.decode(output[: int(length)], skip_special_tokens=True) + for output, length in zip(output_list, seq_lengths) + ] + + # 10. Convert to PIL + if output_type == "pil" and gen_image is not None: + gen_image = self.numpy_to_pil(gen_image) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (gen_image, gen_text) + + return ImageTextPipelineOutput(images=gen_image, text=gen_text) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 57e1abc7315b..f11729451299 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -7,6 +7,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin +from ...models.activations import get_activation from ...models.attention import Attention from ...models.attention_processor import ( AttentionProcessor, @@ -15,10 +16,17 @@ AttnProcessor, ) from ...models.dual_transformer_2d import DualTransformer2DModel -from ...models.embeddings import GaussianFourierProjection, TextTimeEmbedding, TimestepEmbedding, Timesteps +from ...models.embeddings import ( + GaussianFourierProjection, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) from ...models.transformer_2d import Transformer2DModel from ...models.unet_2d_condition import UNet2DConditionOutput -from ...utils import logging +from ...utils import is_torch_version, logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -42,6 +50,9 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlockFlat": @@ -98,6 +109,9 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, ): up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlockFlat": @@ -176,7 +190,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. encoder_hid_dim (`int`, *optional*, defaults to None): - If given, `encoder_hidden_states` will be projected from this dimension to `cross_attention_dim`. + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to None): + If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. @@ -247,6 +265,7 @@ def __init__( norm_eps: float = 1e-5, cross_attention_dim: Union[int, Tuple[int]] = 1280, encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, dual_cross_attention: bool = False, use_linear_projection: bool = False, @@ -344,8 +363,32 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) - if encoder_hid_dim is not None: + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) else: self.encoder_hid_proj = None @@ -387,21 +430,20 @@ def __init__( self.add_embedding = TextTimeEmbedding( text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None or 'text'.") + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") if time_embedding_act_fn is None: self.time_embed_act = None - elif time_embedding_act_fn == "swish": - self.time_embed_act = lambda x: F.silu(x) - elif time_embedding_act_fn == "mish": - self.time_embed_act = nn.Mish() - elif time_embedding_act_fn == "silu": - self.time_embed_act = nn.SiLU() - elif time_embedding_act_fn == "gelu": - self.time_embed_act = nn.GELU() else: - raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") + self.time_embed_act = get_activation(time_embedding_act_fn) self.down_blocks = nn.ModuleList([]) self.up_blocks = nn.ModuleList([]) @@ -555,16 +597,7 @@ def __init__( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) - if act_fn == "swish": - self.conv_act = lambda x: F.silu(x) - elif act_fn == "mish": - self.conv_act = nn.Mish() - elif act_fn == "silu": - self.conv_act = nn.SiLU() - elif act_fn == "gelu": - self.conv_act = nn.GELU() - else: - raise ValueError(f"Unsupported activation function: {act_fn}") + self.conv_act = get_activation(act_fn) else: self.conv_norm_out = None @@ -713,8 +746,10 @@ def forward( timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" @@ -722,12 +757,20 @@ def forward( sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + encoder_attention_mask (`torch.Tensor`): + (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = + discard. Mask will be converted into a bias, which adds large negative values to attention scores + corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + added_cond_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time + embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and + `addition_embed_type` for more information. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: @@ -748,11 +791,27 @@ def forward( logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -779,7 +838,7 @@ def forward( # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) + t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) @@ -794,7 +853,7 @@ def forward( # there might be better ways to encapsulate this. class_labels = class_labels.to(dtype=sample.dtype) - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) @@ -804,12 +863,35 @@ def forward( if self.config.addition_embed_type == "text": aug_emb = self.add_embedding(encoder_hidden_states) emb = emb + aug_emb + elif self.config.addition_embed_type == "text_image": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires" + " the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + + aug_emb = self.add_embedding(text_embs, image_embs) + emb = emb + aug_emb if self.time_embed_act is not None: emb = self.time_embed_act(emb) - if self.encoder_hid_proj is not None: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which" + " requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) # 2. pre-process sample = self.conv_in(sample) @@ -824,6 +906,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) @@ -837,7 +920,7 @@ def forward( down_block_res_samples, down_block_additional_residuals ): down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples += (down_block_res_sample,) + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples @@ -849,6 +932,7 @@ def forward( encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, ) if mid_block_additional_residual is not None: @@ -875,6 +959,7 @@ def forward( cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( @@ -1071,17 +1156,24 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1175,9 +1267,14 @@ def __init__( self.gradient_checkpointing = False def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used output_states = () for resnet, attn in zip(self.resnets, self.attentions): @@ -1192,12 +1289,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + None, # timestep + None, # class_labels cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1205,15 +1313,18 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states @@ -1282,7 +1393,14 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) else: hidden_states = resnet(hidden_states, temb) @@ -1379,15 +1497,15 @@ def __init__( def forward( self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - cross_attention_kwargs=None, - upsample_size=None, - attention_mask=None, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): - # TODO(Patrick, William) - attention mask is not used for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] @@ -1405,12 +1523,23 @@ def custom_forward(*inputs): return custom_forward - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states, + None, # timestep + None, # class_labels cross_attention_kwargs, + attention_mask, + encoder_attention_mask, + **ckpt_kwargs, )[0] else: hidden_states = resnet(hidden_states, temb) @@ -1418,7 +1547,10 @@ def custom_forward(*inputs): hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1514,15 +1646,24 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None - ): + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs, - ).sample + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] hidden_states = resnet(hidden_states, temb) return hidden_states @@ -1615,16 +1756,34 @@ def __init__( self.resnets = nn.ModuleList(resnets) def forward( - self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs=None + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, ): cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + hidden_states = self.resnets[0](hidden_states, temb) for attn, resnet in zip(self.attentions, self.resnets[1:]): # attn hidden_states = attn( hidden_states, encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, + attention_mask=mask, **cross_attention_kwargs, ) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 661a1bd3cf73..1d2e61d86b90 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Tuple, Union import numpy as np @@ -26,6 +27,7 @@ CLIPVisionModelWithProjection, ) +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor @@ -88,6 +90,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None and ( "dual_cross_attention" not in self.image_unet.config or not self.image_unet.config.dual_cross_attention @@ -329,8 +332,13 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -572,12 +580,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index e3a2ee370362..4450846300fc 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import numpy as np @@ -21,6 +22,7 @@ import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor @@ -71,6 +73,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) def enable_sequential_cpu_offload(self, gpu_id=0): r""" @@ -189,8 +192,13 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -414,12 +422,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 8. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents - # 9. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index 26b9be2bfa76..1fdb21f2b745 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -13,12 +13,14 @@ # limitations under the License. import inspect +import warnings from typing import Callable, List, Optional, Union import torch import torch.utils.checkpoint from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor @@ -76,6 +78,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) if self.text_unet is not None: self._swap_unet_attention_blocks() @@ -246,8 +249,13 @@ def normalize_embeddings(encoder_output): # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): + warnings.warn( + "The decode_latents method is deprecated and will be removed in a future version. Please" + " use VaeImageProcessor instead", + FutureWarning, + ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents).sample + image = self.vae.decode(latents, return_dict=False)[0] image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -488,12 +496,12 @@ def __call__( if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # 9. Post-processing - image = self.decode_latents(latents) + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents - # 10. Convert to PIL - if output_type == "pil": - image = self.numpy_to_pil(image) + image = self.image_processor.postprocess(image, output_type=output_type) if not return_dict: return (image,) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index e5d5bb40633f..05414e32fc9e 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. -from ..utils import OptionalDependencyNotAvailable, is_flax_available, is_scipy_available, is_torch_available +from ..utils import ( + OptionalDependencyNotAvailable, + is_flax_available, + is_scipy_available, + is_torch_available, + is_torchsde_available, +) try: @@ -27,6 +33,7 @@ from .scheduling_ddpm import DDPMScheduler from .scheduling_deis_multistep import DEISMultistepScheduler from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler + from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler from .scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from .scheduling_euler_discrete import EulerDiscreteScheduler @@ -72,3 +79,11 @@ from ..utils.dummy_torch_and_scipy_objects import * # noqa F403 else: from .scheduling_lms_discrete import LMSDiscreteScheduler + +try: + if not (is_torch_available() and is_torchsde_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 +else: + from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 6b62d8893482..bab6f8acea03 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -76,6 +76,42 @@ def alpha_bar(time_step): return torch.tensor(betas, dtype=torch.float32) +def rescale_zero_terminal_snr(betas): + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + class DDIMScheduler(SchedulerMixin, ConfigMixin): """ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising @@ -122,6 +158,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. sample_max_value (`float`, default `1.0`): the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, default `"leading"`): + The way the timesteps should be scaled. Refer to Table 2. of [Common Diffusion Noise Schedules and Sample + Steps are Flawed](https://arxiv.org/abs/2305.08891) for more information. + rescale_betas_zero_snr (`bool`, default `False`): + whether to rescale the betas to have zero terminal SNR (proposed by https://arxiv.org/pdf/2305.08891.pdf). + This can enable the model to generate very bright and dark samples instead of limiting it to samples with + medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -143,6 +187,8 @@ def __init__( dynamic_thresholding_ratio: float = 0.995, clip_sample_range: float = 1.0, sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + rescale_betas_zero_snr: bool = False, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -159,6 +205,10 @@ def __init__( else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -251,12 +301,26 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic ) self.num_inference_steps = num_inference_steps - step_ratio = self.config.num_train_timesteps // self.num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + + # "leading" and "trailing" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891 + if self.config.timestep_spacing == "leading": + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'." + ) + self.timesteps = torch.from_numpy(timesteps).to(device) - self.timesteps += self.config.steps_offset def step( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index a8a71fe420aa..5d24766d68c7 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -91,7 +91,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. variance_type (`str`): diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 3399ee2c54cb..e72b1bdc23b5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,6 +21,7 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput @@ -70,6 +71,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as stable-diffusion). + We also support the SDE variant of DPM-Solver and DPM-Solver++, which is a fast SDE solver for the reverse + diffusion SDE. Currently we only support the first-order and second-order solvers. We recommend using the + second-order `sde-dpmsolver++`. + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and @@ -103,10 +108,10 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): the threshold value for dynamic thresholding. Valid only when `thresholding=True` and `algorithm_type="dpmsolver++`. algorithm_type (`str`, default `dpmsolver++`): - the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++`. The `dpmsolver` type implements the - algorithms in https://arxiv.org/abs/2206.00927, and the `dpmsolver++` type implements the algorithms in - https://arxiv.org/abs/2211.01095. We recommend to use `dpmsolver++` with `solver_order=2` for guided - sampling (e.g. stable-diffusion). + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). solver_type (`str`, default `midpoint`): the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are @@ -118,6 +123,17 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. """ _compatibles = [e.name for e in KarrasDiffusionSchedulers] @@ -140,6 +156,8 @@ def __init__( solver_type: str = "midpoint", lower_order_final: bool = True, use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -167,7 +185,7 @@ def __init__( self.init_noise_sigma = 1.0 # settings for DPM-Solver - if algorithm_type not in ["dpmsolver", "dpmsolver++"]: + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: if algorithm_type == "deis": self.register_to_config(algorithm_type="dpmsolver++") else: @@ -187,7 +205,7 @@ def __init__( self.lower_order_nums = 0 self.use_karras_sigmas = use_karras_sigmas - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -197,8 +215,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic device (`str` or `torch.device`, optional): the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) @@ -320,9 +341,13 @@ def convert_model_output( Returns: `torch.FloatTensor`: the converted model output. """ + # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type == "dpmsolver++": + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": @@ -340,30 +365,42 @@ def convert_model_output( x0_pred = self._threshold_sample(x0_pred) return x0_pred + # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type == "dpmsolver": + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: if self.config.prediction_type == "epsilon": - return model_output + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = (sample - alpha_t * model_output) / sigma_t - return epsilon elif self.config.prediction_type == "v_prediction": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] epsilon = alpha_t * model_output + sigma_t * sample - return epsilon else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction` for the DPMSolverMultistepScheduler." ) + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + def dpm_solver_first_order_update( self, model_output: torch.FloatTensor, timestep: int, prev_timestep: int, sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the first-order DPM-Solver (equivalent to DDIM). @@ -388,6 +425,20 @@ def dpm_solver_first_order_update( x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output elif self.config.algorithm_type == "dpmsolver": x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def multistep_dpm_solver_second_order_update( @@ -396,6 +447,7 @@ def multistep_dpm_solver_second_order_update( timestep_list: List[int], prev_timestep: int, sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, ) -> torch.FloatTensor: """ One step for the second-order multistep DPM-Solver. @@ -447,6 +499,38 @@ def multistep_dpm_solver_second_order_update( - (sigma_t * (torch.exp(h) - 1.0)) * D0 - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) return x_t def multistep_dpm_solver_third_order_update( @@ -509,6 +593,7 @@ def step( model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, + generator=None, return_dict: bool = True, ) -> Union[SchedulerOutput, Tuple]: """ @@ -551,12 +636,21 @@ def step( self.model_outputs[i] = self.model_outputs[i + 1] self.model_outputs[-1] = model_output + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, timestep, prev_timestep, sample) + prev_sample = self.dpm_solver_first_order_update( + model_output, timestep, prev_timestep, sample, noise=noise + ) elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: timestep_list = [self.timesteps[step_index - 1], timestep] prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample + self.model_outputs, timestep_list, prev_timestep, sample, noise=noise ) else: timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py new file mode 100644 index 000000000000..b424ebbff262 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py @@ -0,0 +1,701 @@ +# Copyright 2023 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin): + """ + DPMSolverMultistepInverseScheduler is the reverse scheduler of [`DPMSolverMultistepScheduler`]. + + We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space + diffusion models, you can set both `algorithm_type="dpmsolver++"` and `thresholding=True` to use the dynamic + thresholding. Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + solver_order (`int`, default `2`): + the order of DPM-Solver; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + For pixel-space diffusion models, you can set both `algorithm_type=dpmsolver++` and `thresholding=True` to + use the dynamic thresholding. Note that the thresholding method is unsuitable for latent-space diffusion + models (such as stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++`. + algorithm_type (`str`, default `dpmsolver++`): + the algorithm type for the solver. Either `dpmsolver` or `dpmsolver++` or `sde-dpmsolver` or + `sde-dpmsolver++`. The `dpmsolver` type implements the algorithms in https://arxiv.org/abs/2206.00927, and + the `dpmsolver++` type implements the algorithms in https://arxiv.org/abs/2211.01095. We recommend to use + `dpmsolver++` or `sde-dpmsolver++` with `solver_order=2` for guided sampling (e.g. stable-diffusion). + solver_type (`str`, default `midpoint`): + the solver type for the second-order solver. Either `midpoint` or `heun`. The solver type slightly affects + the sample quality, especially for small number of steps. We empirically find that `midpoint` solvers are + slightly better, so we recommend to use the `midpoint` type. + lower_order_final (`bool`, default `True`): + whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically + find this trick can stabilize the sampling of DPM-Solver for steps < 15, especially for steps <= 10. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # settings for DPM-Solver + if algorithm_type not in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} does is not implemented for {self.__class__}") + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32).copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self.use_karras_sigmas = use_karras_sigmas + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.lambda_min_clipped) + self.noisiest_timestep = self.config.num_train_timesteps - 1 - clipped_idx + timesteps = ( + np.linspace(0, self.noisiest_timestep, num_inference_steps + 1).round()[:-1].copy().astype(np.int64) + ) + + if self.use_karras_sigmas: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = timesteps.copy().astype(np.int64) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + """ + Convert the model output to the corresponding type that the algorithm (DPM-Solver / DPM-Solver++) needs. + + DPM-Solver is designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to + discretize an integral of the data prediction model. So we need to first convert the model output to the + corresponding type to match the algorithm. + + Note that the algorithm type and the model type is decoupled. That is to say, we can use either DPM-Solver or + DPM-Solver++ for both noise prediction model and data prediction model. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the converted model output. + """ + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + model_output = model_output[:, :3] + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.config.prediction_type == "sample": + x0_pred = model_output + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned", "learned_range"]: + epsilon = model_output[:, :3] + else: + epsilon = model_output + elif self.config.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + elif self.config.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the DPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * epsilon) / alpha_t + x0_pred = self._threshold_sample(x0_pred) + epsilon = (sample - alpha_t * x0_pred) / sigma_t + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.FloatTensor, + timestep: int, + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the first-order DPM-Solver (equivalent to DDIM). + + See https://arxiv.org/abs/2206.00927 for the detailed derivation. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] + alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] + sigma_t, sigma_s = self.sigma_t[prev_timestep], self.sigma_t[timestep] + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ( + (alpha_t / alpha_s) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + noise: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + One step for the second-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] + m0, m1 = model_output_list[-1], model_output_list[-2] + lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 0.5 * (sigma_t * (torch.exp(h) - 1.0)) * D1 + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + ) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (alpha_t / alpha_s0) * sample + - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * D0 + - 2.0 * (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise + ) + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.FloatTensor], + timestep_list: List[int], + prev_timestep: int, + sample: torch.FloatTensor, + ) -> torch.FloatTensor: + """ + One step for the third-order multistep DPM-Solver. + + Args: + model_output_list (`List[torch.FloatTensor]`): + direct outputs from learned diffusion model at current and latter timesteps. + timestep (`int`): current and latter discrete timestep in the diffusion chain. + prev_timestep (`int`): previous discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + + Returns: + `torch.FloatTensor`: the sample tensor at the previous timestep. + """ + t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] + m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] + lambda_t, lambda_s0, lambda_s1, lambda_s2 = ( + self.lambda_t[t], + self.lambda_t[s0], + self.lambda_t[s1], + self.lambda_t[s2], + ) + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (sigma_t / sigma_s0) * sample + - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 + - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2 + ) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ( + (alpha_t / alpha_s0) * sample + - (sigma_t * (torch.exp(h) - 1.0)) * D0 + - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2 + ) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Step function propagating the sample with the multistep DPM-Solver. + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + Returns: + [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is + True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + prev_timestep = ( + self.noisiest_timestep if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + ) + lower_order_final = ( + (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + lower_order_second = ( + (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, timestep, sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, timestep, prev_timestep, sample, noise=noise + ) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + timestep_list = [self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, timestep_list, prev_timestep, sample, noise=noise + ) + else: + timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, timestep_list, prev_timestep, sample + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py new file mode 100644 index 000000000000..ae9229981152 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -0,0 +1,447 @@ +# Copyright 2023 Katherine Crowson, The HuggingFace Team and hlky. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torchsde + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get("w0", torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2**63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will use one BrownianTree per batch item, each + with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. + """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DPMSolverSDEScheduler(SchedulerMixin, ConfigMixin): + """ + Implements Stochastic Sampler (Algorithm 2) from Karras et al. (2022). Based on the original k-diffusion + implementation by Katherine Crowson: + https://github.com/crowsonkb/k-diffusion/blob/41b4cb6df0506694a7776af31349acf082bf6091/k_diffusion/sampling.py#L543 + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. beta_start (`float`): the + starting `beta` value of inference. beta_end (`float`): the final `beta` value. beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear` or `scaled_linear`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + noise_sampler_seed (`int`, *optional*, defaults to `None`): + The random seed to use for the noise sampler. If `None`, a random seed will be generated. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 2 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, # sensible defaults + beta_end: float = 0.012, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + prediction_type: str = "epsilon", + use_karras_sigmas: Optional[bool] = False, + noise_sampler_seed: Optional[int] = None, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # set all values + self.set_timesteps(num_train_timesteps, None, num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas + self.noise_sampler = None + self.noise_sampler_seed = noise_sampler_seed + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + if self.state_in_first_order: + pos = -1 + else: + pos = 0 + return indices[pos].item() + + def scale_model_input( + self, + sample: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + ) -> torch.FloatTensor: + """ + Args: + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + sample (`torch.FloatTensor`): input sample timestep (`int`, optional): current timestep + Returns: + `torch.FloatTensor`: scaled input sample + """ + step_index = self.index_for_timestep(timestep) + + sigma = self.sigmas[step_index] + sigma_input = sigma if self.state_in_first_order else self.mid_point_sigma + sample = sample / ((sigma_input**2 + 1) ** 0.5) + return sample + + def set_timesteps( + self, + num_inference_steps: int, + device: Union[str, torch.device] = None, + num_train_timesteps: Optional[int] = None, + ): + """ + Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, optional): + the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + self.num_inference_steps = num_inference_steps + + num_train_timesteps = num_train_timesteps or self.config.num_train_timesteps + + timesteps = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + + second_order_timesteps = self._second_order_timesteps(sigmas, log_sigmas) + + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + sigmas = torch.from_numpy(sigmas).to(device=device) + self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]]) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + timesteps = torch.from_numpy(timesteps) + second_order_timesteps = torch.from_numpy(second_order_timesteps) + timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)]) + timesteps[1::2] = second_order_timesteps + + if str(device).startswith("mps"): + # mps does not support float64 + self.timesteps = timesteps.to(device, dtype=torch.float32) + else: + self.timesteps = timesteps.to(device=device) + + # empty first order variables + self.sample = None + self.mid_point_sigma = None + + def _second_order_timesteps(self, sigmas, log_sigmas): + def sigma_fn(_t): + return np.exp(-_t) + + def t_fn(_sigma): + return -np.log(_sigma) + + midpoint_ratio = 0.5 + t = t_fn(sigmas) + delta_time = np.diff(t) + t_proposed = t[:-1] + delta_time * midpoint_ratio + sig_proposed = sigma_fn(t_proposed) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sig_proposed]) + return timesteps + + # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + @property + def state_in_first_order(self): + return self.sample is None + + def step( + self, + model_output: Union[torch.FloatTensor, np.ndarray], + timestep: Union[float, torch.FloatTensor], + sample: Union[torch.FloatTensor, np.ndarray], + return_dict: bool = True, + s_noise: float = 1.0, + ) -> Union[SchedulerOutput, Tuple]: + """ + Args: + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + model_output (Union[torch.FloatTensor, np.ndarray]): Direct output from learned diffusion model. + timestep (Union[float, torch.FloatTensor]): Current discrete timestep in the diffusion chain. + sample (Union[torch.FloatTensor, np.ndarray]): Current instance of sample being created by diffusion process. + return_dict (bool, optional): Option for returning tuple rather than SchedulerOutput class. Defaults to True. + s_noise (float, optional): Scaling factor for the noise added to the sample. Defaults to 1.0. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.SchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + step_index = self.index_for_timestep(timestep) + + # Create a noise sampler if it hasn't been created yet + if self.noise_sampler is None: + min_sigma, max_sigma = self.sigmas[self.sigmas > 0].min(), self.sigmas.max() + self.noise_sampler = BrownianTreeNoiseSampler(sample, min_sigma, max_sigma, self.noise_sampler_seed) + + # Define functions to compute sigma and t from each other + def sigma_fn(_t: torch.FloatTensor) -> torch.FloatTensor: + return _t.neg().exp() + + def t_fn(_sigma: torch.FloatTensor) -> torch.FloatTensor: + return _sigma.log().neg() + + if self.state_in_first_order: + sigma = self.sigmas[step_index] + sigma_next = self.sigmas[step_index + 1] + else: + # 2nd order + sigma = self.sigmas[step_index - 1] + sigma_next = self.sigmas[step_index] + + # Set the midpoint and step size for the current step + midpoint_ratio = 0.5 + t, t_next = t_fn(sigma), t_fn(sigma_next) + delta_time = t_next - t + t_proposed = t + delta_time * midpoint_ratio + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.config.prediction_type == "epsilon": + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) + pred_original_sample = sample - sigma_input * model_output + elif self.config.prediction_type == "v_prediction": + sigma_input = sigma if self.state_in_first_order else sigma_fn(t_proposed) + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) + elif self.config.prediction_type == "sample": + raise NotImplementedError("prediction_type not implemented yet: sample") + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + if sigma_next == 0: + derivative = (sample - pred_original_sample) / sigma + dt = sigma_next - sigma + prev_sample = sample + derivative * dt + else: + if self.state_in_first_order: + t_next = t_proposed + else: + sample = self.sample + + sigma_from = sigma_fn(t) + sigma_to = sigma_fn(t_next) + sigma_up = min(sigma_to, (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + ancestral_t = t_fn(sigma_down) + prev_sample = (sigma_fn(ancestral_t) / sigma_fn(t)) * sample - ( + t - ancestral_t + ).expm1() * pred_original_sample + prev_sample = prev_sample + self.noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * sigma_up + + if self.state_in_first_order: + # store for 2nd order step + self.sample = sample + self.mid_point_sigma = sigma_fn(t_next) + else: + # free for "first order mode" + self.sample = None + self.mid_point_sigma = None + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index 049e2b1dbd4d..7fa8eabb5a15 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,9 +21,13 @@ import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import logging from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): """ @@ -113,6 +117,21 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): lower_order_final (`bool`, default `True`): whether to use lower-order solvers in the final steps. For singlestep schedulers, we recommend to enable this to use up all the function evaluations. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. + lambda_min_clipped (`float`, default `-inf`): + the clipping threshold for the minimum value of lambda(t) for numerical stability. This is critical for + cosine (squaredcos_cap_v2) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. whether the model's output contains the predicted Gaussian variance. For example, OpenAI's + guided-diffusion (https://github.com/openai/guided-diffusion) predicts both mean and variance of the + Gaussian distribution in the model's output. DPM-Solver only needs the "mean" output because it is based on + diffusion ODEs. """ @@ -135,6 +154,9 @@ def __init__( algorithm_type: str = "dpmsolver++", solver_type: str = "midpoint", lower_order_final: bool = True, + use_karras_sigmas: Optional[bool] = False, + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, ): if trained_betas is not None: self.betas = torch.tensor(trained_betas, dtype=torch.float32) @@ -180,6 +202,7 @@ def __init__( self.model_outputs = [None] * solver_order self.sample = None self.order_list = self.get_order_list(num_train_timesteps) + self.use_karras_sigmas = use_karras_sigmas def get_order_list(self, num_inference_steps: int) -> List[int]: """ @@ -226,16 +249,34 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. """ self.num_inference_steps = num_inference_steps + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + clipped_idx = torch.searchsorted(torch.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) timesteps = ( - np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1) + np.linspace(0, self.config.num_train_timesteps - 1 - clipped_idx, num_inference_steps + 1) .round()[::-1][:-1] .copy() .astype(np.int64) ) + + if self.use_karras_sigmas: + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) + sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() + timesteps = np.flip(timesteps).copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) self.model_outputs = [None] * self.config.solver_order self.sample = None - self.orders = self.get_order_list(num_inference_steps) + + if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0: + logger.warn( + "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`." + ) + self.register_to_config(lower_order_final=True) + + self.order_list = self.get_order_list(num_inference_steps) # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: @@ -272,6 +313,44 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: return sample + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def convert_model_output( self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor ) -> torch.FloatTensor: @@ -297,6 +376,9 @@ def convert_model_output( # DPM-Solver++ needs to solve an integral of the data prediction model. if self.config.algorithm_type == "dpmsolver++": if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] x0_pred = (sample - sigma_t * model_output) / alpha_t elif self.config.prediction_type == "sample": @@ -317,6 +399,9 @@ def convert_model_output( # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": if self.config.prediction_type == "epsilon": + # DPM-Solver and DPM-Solver++ only need the "mean" output. + if self.config.variance_type in ["learned_range"]: + model_output = model_output[:, :3] return model_output elif self.config.prediction_type == "sample": alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] @@ -575,6 +660,12 @@ def step( self.model_outputs[-1] = model_output order = self.order_list[step_index] + + # For img2img denoising might start with order>1 which is not possible + # In this case make sure that the first two steps are both order=1 + while self.model_outputs[-order] is None: + order -= 1 + # For single-step solvers, we use the initial value at each time with order = 1. if order == 1: self.sample = sample diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 2b32cad39925..100e2012ea20 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -70,8 +70,6 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, - `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 68a8e1bddc01..0656475c3093 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -94,6 +94,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): `linear` or `scaled_linear`. trained_betas (`np.ndarray`, optional): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + This parameter controls whether to use Karras sigmas (Karras et al. (2022) scheme) for step sizes in the + noise schedule during the sampling process. If True, the sigmas will be determined according to a sequence + of noise levels {σi} as defined in Equation (5) of the paper https://arxiv.org/pdf/2206.00364.pdf. prediction_type (`str`, default `epsilon`, optional): prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 @@ -111,6 +115,7 @@ def __init__( beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + use_karras_sigmas: Optional[bool] = False, prediction_type: str = "epsilon", ): if trained_betas is not None: @@ -140,8 +145,8 @@ def __init__( # setable values self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) + self.use_karras_sigmas = use_karras_sigmas + self.set_timesteps(num_train_timesteps, None) self.derivatives = [] self.is_scale_input_called = False @@ -201,8 +206,15 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.num_inference_steps = num_inference_steps timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + log_sigmas = np.log(sigmas) sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + + if self.use_karras_sigmas: + sigmas = self._convert_to_karras(in_sigmas=sigmas) + timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.sigmas = torch.from_numpy(sigmas).to(device=device) @@ -214,6 +226,44 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic self.derivatives = [] + # copied from diffusers.schedulers.scheduling_euler_discrete._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(sigma) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + # copied from diffusers.schedulers.scheduling_euler_discrete._convert_to_karras + def _convert_to_karras(self, in_sigmas: torch.FloatTensor) -> torch.FloatTensor: + """Constructs the noise schedule of Karras et al. (2022).""" + + sigma_min: float = in_sigmas[-1].item() + sigma_max: float = in_sigmas[0].item() + + rho = 7.0 # 7.0 is the value used in the paper + ramp = np.linspace(0, 1, self.num_inference_steps) + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + def step( self, model_output: torch.FloatTensor, diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index 96af210f06b1..f2f97b38f3d3 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -89,7 +89,7 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): beta_end (`float`): the final `beta` value. beta_schedule (`str`): the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + `linear`, `scaled_linear`, `squaredcos_cap_v2` or `sigmoid`. eta (`float`): The weight of noise for added noise in a diffusion step. Its value is between 0.0 and 1.0 -0.0 is DDIM and 1.0 is DDPM scheduler respectively. diff --git a/src/diffusers/schedulers/scheduling_unclip.py b/src/diffusers/schedulers/scheduling_unclip.py index 6403ee3f1518..d44edcb1812a 100644 --- a/src/diffusers/schedulers/scheduling_unclip.py +++ b/src/diffusers/schedulers/scheduling_unclip.py @@ -75,6 +75,9 @@ def alpha_bar(time_step): class UnCLIPScheduler(SchedulerMixin, ConfigMixin): """ + NOTE: do not use this scheduler. The DDPM scheduler has been updated to support the changes made here. This + scheduler will be removed and replaced with DDPM. + This is a modified DDPM Scheduler specifically for the karlo unCLIP model. This scheduler has some minor variations in how it calculates the learned range variance and dynamically diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index a4121f75d850..0f95beb022ac 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -43,6 +43,7 @@ class KarrasDiffusionSchedulers(Enum): KDPM2AncestralDiscreteScheduler = 11 DEISMultistepScheduler = 12 UniPCMultistepScheduler = 13 + DPMSolverSDEScheduler = 14 @dataclass diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 340b96e29ac5..df9c7e882682 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,32 +1,16 @@ +import contextlib import copy -import os -import random +from random import random from typing import Any, Dict, Iterable, Optional, Union import numpy as np import torch -from .utils import deprecate +from .utils import deprecate, is_transformers_available -def enable_full_determinism(seed: int): - """ - Helper function for reproducible behavior during distributed training. See - - https://pytorch.org/docs/stable/notes/randomness.html for pytorch - """ - # set seed first - set_seed(seed) - - # Enable PyTorch deterministic mode. This potentially requires either the environment - # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, - # depending on the CUDA version, so we set them both here - os.environ["CUDA_LAUNCH_BLOCKING"] = "1" - os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - torch.use_deterministic_algorithms(True) - - # Enable CUDNN deterministic mode - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False +if is_transformers_available(): + import transformers def set_seed(seed: int): @@ -197,11 +181,19 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): self.cur_decay_value = decay one_minus_decay = 1 - decay + context_manager = contextlib.nullcontext + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + import deepspeed + for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) + if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): + context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) + + with context_manager(): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 1b8eca050c9e..36cbe82f79e7 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,7 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, - TEXT_ENCODER_TARGET_MODULES, + TEXT_ENCODER_ATTN_MODULE, WEIGHTS_NAME, ) from .deprecation_utils import deprecate @@ -70,6 +70,7 @@ is_tf_available, is_torch_available, is_torch_version, + is_torchsde_available, is_transformers_available, is_transformers_version, is_unidecode_available, @@ -100,6 +101,7 @@ torch_all_close, torch_device, ) + from .torch_utils import maybe_allow_in_graph from .testing_utils import export_to_video diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 1134ba6fb656..3c641a259a81 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,4 +30,4 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] -TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] +TEXT_ENCODER_ATTN_MODULE = ".self_attn" diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 014e193aa32a..e07b7cb27da7 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -450,6 +450,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class DPMSolverMultistepInverseScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class DPMSolverMultistepScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_torchsde_objects.py b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py new file mode 100644 index 000000000000..a81bbb316f32 --- /dev/null +++ b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py @@ -0,0 +1,17 @@ +# This file is autogenerated by the command `make fix-copies`, do not edit. +from ..utils import DummyObject, requires_backends + + +class DPMSolverSDEScheduler(metaclass=DummyObject): + _backends = ["torch", "torchsde"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "torchsde"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index bf4fe8d87ff9..95d07c081ccd 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -152,6 +152,81 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ImageTextPipelineOutput(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class KandinskyPriorPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LDMTextToImagePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -212,6 +287,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionControlNetImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionControlNetPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -242,6 +347,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionDiffEditPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableDiffusionImageVariationPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -527,6 +647,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class UniDiffuserModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class UniDiffuserTextDecoder(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index aa6c9c657a87..5b0952f0b514 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -21,12 +21,12 @@ import re import shutil import sys -from distutils.version import StrictVersion from pathlib import Path from typing import Dict, Optional, Union from urllib import request from huggingface_hub import HfFolder, cached_download, hf_hub_download, model_info +from packaging import version from .. import __version__ from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging @@ -43,7 +43,7 @@ def get_diffusers_versions(): url = "https://pypi.org/pypi/diffusers/json" releases = json.loads(request.urlopen(url).read())["releases"].keys() - return sorted(releases, key=StrictVersion) + return sorted(releases, key=lambda x: version.Version(x)) def init_hf_modules(): diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index 9cfc649c8b86..6e44370a378a 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -280,7 +280,7 @@ def _get_model_file( if ( revision in DEPRECATED_REVISION_ARGS and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) - and version.parse(version.parse(__version__).base_version) >= version.parse("0.17.0") + and version.parse(version.parse(__version__).base_version) >= version.parse("0.18.0") ): try: model_file = hf_hub_download( diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 2d90cb9747a7..4ded0f272462 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -287,6 +287,13 @@ except importlib_metadata.PackageNotFoundError: _bs4_available = False +_torchsde_available = importlib.util.find_spec("torchsde") is not None +try: + _torchsde_version = importlib_metadata.version("torchsde") + logger.debug(f"Successfully imported torchsde version {_torchsde_version}") +except importlib_metadata.PackageNotFoundError: + _torchsde_available = False + def is_torch_available(): return _torch_available @@ -372,6 +379,10 @@ def is_bs4_available(): return _bs4_available +def is_torchsde_available(): + return _torchsde_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -475,6 +486,11 @@ def is_bs4_available(): that match your environment. Please note that you may need to restart your runtime after installation. """ +# docstyle-ignore +TORCHSDE_IMPORT_ERROR = """ +{0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` +""" + BACKENDS_MAPPING = OrderedDict( [ @@ -495,6 +511,7 @@ def is_bs4_available(): ("tensorboard", (_tensorboard_available, TENSORBOARD_IMPORT_ERROR)), ("compel", (_compel_available, COMPEL_IMPORT_ERROR)), ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("torchsde", (_torchsde_available, TORCHSDE_IMPORT_ERROR)), ] ) diff --git a/src/diffusers/utils/pil_utils.py b/src/diffusers/utils/pil_utils.py index ad76a32230fb..15b97c73dcb7 100644 --- a/src/diffusers/utils/pil_utils.py +++ b/src/diffusers/utils/pil_utils.py @@ -23,6 +23,9 @@ def pt_to_pil(images): + """ + Convert a torch image to a PIL image. + """ images = (images / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).float().numpy() images = numpy_to_pil(images) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index d8fed5dec1c8..dcb80169de74 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,5 +1,6 @@ import inspect import logging +import multiprocessing import os import random import re @@ -26,6 +27,7 @@ is_opencv_available, is_torch_available, is_torch_version, + is_torchsde_available, ) from .logging import get_logger @@ -216,6 +218,13 @@ def require_note_seq(test_case): return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) +def require_torchsde(test_case): + """ + Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed. + """ + return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case) + + def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray: if isinstance(arry, str): # local_path = "/home/patrick_huggingface_co/" @@ -252,12 +261,14 @@ def load_pt(url: str): def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image: """ - Args: Loads `image` to a PIL Image. + + Args: image (`str` or `PIL.Image.Image`): The image to convert to the PIL Image format. Returns: - `PIL.Image.Image`: A PIL Image. + `PIL.Image.Image`: + A PIL Image. """ if isinstance(image, str): if image.startswith("http://") or image.startswith("https://"): @@ -469,6 +480,50 @@ def summary_failures_short(tr): config.option.tbstyle = orig_tbstyle +# Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787 +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. + """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", 600)) + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. + try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f'{results["error"]}') + + class CaptureLogger: """ Args: @@ -506,3 +561,27 @@ def __exit__(self, *exc): def __repr__(self): return f"captured: {self.out}\n" + + +def enable_full_determinism(): + """ + Helper function for reproducible behavior during distributed training. See + - https://pytorch.org/docs/stable/notes/randomness.html for pytorch + """ + # Enable PyTorch deterministic mode. This potentially requires either the environment + # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set, + # depending on the CUDA version, so we set them both here + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" + torch.use_deterministic_algorithms(True) + + # Enable CUDNN deterministic mode + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cuda.matmul.allow_tf32 = False + + +def disable_full_determinism(): + os.environ["CUDA_LAUNCH_BLOCKING"] = "0" + os.environ["CUBLAS_WORKSPACE_CONFIG"] = "" + torch.use_deterministic_algorithms(False) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index b9815cbceede..5f64bce25e78 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -25,6 +25,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +try: + from torch._dynamo import allow_in_graph as maybe_allow_in_graph +except (ImportError, ModuleNotFoundError): + + def maybe_allow_in_graph(cls): + return cls + def randn_tensor( shape: Union[Tuple, List], @@ -33,9 +40,9 @@ def randn_tensor( dtype: Optional["torch.dtype"] = None, layout: Optional["torch.layout"] = None, ): - """This is a helper function that allows to create random tensors on the desired `device` with the desired `dtype`. When - passing a list of generators one can seed each batched size individually. If CPU generators are passed the tensor - will always be created on CPU. + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. """ # device on which tensor is created defaults to device rand_device = device diff --git a/tests/models/test_activations.py b/tests/models/test_activations.py new file mode 100644 index 000000000000..4e8e51453e98 --- /dev/null +++ b/tests/models/test_activations.py @@ -0,0 +1,48 @@ +import unittest + +import torch +from torch import nn + +from diffusers.models.activations import get_activation + + +class ActivationsTests(unittest.TestCase): + def test_swish(self): + act = get_activation("swish") + + self.assertIsInstance(act, nn.SiLU) + + self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0) + self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) + + def test_silu(self): + act = get_activation("silu") + + self.assertIsInstance(act, nn.SiLU) + + self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0) + self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) + + def test_mish(self): + act = get_activation("mish") + + self.assertIsInstance(act, nn.Mish) + + self.assertEqual(act(torch.tensor(-200, dtype=torch.float32)).item(), 0) + self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) + + def test_gelu(self): + act = get_activation("gelu") + + self.assertIsInstance(act, nn.GELU) + + self.assertEqual(act(torch.tensor(-100, dtype=torch.float32)).item(), 0) + self.assertNotEqual(act(torch.tensor(-1, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(0, dtype=torch.float32)).item(), 0) + self.assertEqual(act(torch.tensor(20, dtype=torch.float32)).item(), 20) diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index db0d6c78d902..b438b2ddb4af 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -20,16 +20,13 @@ import torch from torch import nn -from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock +from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU from diffusers.models.embeddings import get_timestep_embedding from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D from diffusers.models.transformer_2d import Transformer2DModel from diffusers.utils import torch_device -torch.backends.cuda.matmul.allow_tf32 = False - - class EmbeddingsTests(unittest.TestCase): def test_timestep_embeddings(self): embedding_dim = 256 @@ -314,59 +311,6 @@ def test_restnet_with_kernel_sde_vp(self): assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) -class AttentionBlockTests(unittest.TestCase): - @unittest.skipIf( - torch_device == "mps", "Matmul crashes on MPS, see https://github.com/pytorch/pytorch/issues/84039" - ) - def test_attention_block_default(self): - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - sample = torch.randn(1, 32, 64, 64).to(torch_device) - attentionBlock = AttentionBlock( - channels=32, - num_head_channels=1, - rescale_output_factor=1.0, - eps=1e-6, - norm_num_groups=32, - ).to(torch_device) - with torch.no_grad(): - attention_scores = attentionBlock(sample) - - assert attention_scores.shape == (1, 32, 64, 64) - output_slice = attention_scores[0, -1, -3:, -3:] - - expected_slice = torch.tensor( - [-1.4975, -0.0038, -0.7847, -1.4567, 1.1220, -0.8962, -1.7394, 1.1319, -0.5427], device=torch_device - ) - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - - def test_attention_block_sd(self): - # This version uses SD params and is compatible with mps - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - - sample = torch.randn(1, 512, 64, 64).to(torch_device) - attentionBlock = AttentionBlock( - channels=512, - rescale_output_factor=1.0, - eps=1e-6, - norm_num_groups=32, - ).to(torch_device) - with torch.no_grad(): - attention_scores = attentionBlock(sample) - - assert attention_scores.shape == (1, 512, 64, 64) - output_slice = attention_scores[0, -1, -3:, -3:] - - expected_slice = torch.tensor( - [-0.6621, -0.0156, -3.2766, 0.8025, -0.8609, 0.2820, 0.0905, -1.1179, -3.2126], device=torch_device - ) - assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) - - class Transformer2DModelTests(unittest.TestCase): def test_spatial_transformer_default(self): torch.manual_seed(0) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6f1e85e15558..aaacf1e68f9f 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,18 +12,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import gc import os import tempfile import unittest import torch import torch.nn as nn +import torch.nn.functional as F from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor -from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + AttnProcessor2_0, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device def create_unet_lora_layers(unet: nn.Module): @@ -38,20 +48,48 @@ def create_unet_lora_layers(unet: nn.Module): elif name.startswith("down_blocks"): block_id = int(name[len("down_blocks.")]) hidden_size = unet.config.block_out_channels[block_id] - lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim + ) unet_lora_layers = AttnProcsLayers(lora_attn_procs) return lora_attn_procs, unet_lora_layers -def create_text_encoder_lora_layers(text_encoder: nn.Module): +def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): text_lora_attn_procs = {} + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + text_lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=module.out_proj.out_features, cross_attention_dim=None + ) + return text_lora_attn_procs + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) return text_encoder_lora_layers +def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): + for _, attn_proc in text_lora_attn_procs.items(): + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + + class LoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) @@ -135,13 +173,33 @@ def get_dummy_inputs(self): return noise, input_ids, pipeline_inputs + # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb + + def get_dummy_tokens(self): + max_seq_length = 77 + + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) + + prepared_inputs = {} + prepared_inputs["input_ids"] = inputs + return prepared_inputs + + def create_lora_weight_file(self, tmpdirname): + _, lora_components = self.get_dummy_components() + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + def test_lora_save_load(self): pipeline_components, lora_components = self.get_dummy_components() sd_pipe = StableDiffusionPipeline(**pipeline_components) sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) - noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + _, _, pipeline_inputs = self.get_dummy_inputs() original_images = sd_pipe(**pipeline_inputs).images orig_image_slice = original_images[0, -3:, -3:, -1] @@ -167,7 +225,7 @@ def test_lora_save_load_safetensors(self): sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) - noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + _, _, pipeline_inputs = self.get_dummy_inputs() original_images = sd_pipe(**pipeline_inputs).images orig_image_slice = original_images[0, -3:, -3:, -1] @@ -195,7 +253,7 @@ def test_lora_save_load_legacy(self): sd_pipe = sd_pipe.to(torch_device) sd_pipe.set_progress_bar_config(disable=None) - noise, input_ids, pipeline_inputs = self.get_dummy_inputs() + _, _, pipeline_inputs = self.get_dummy_inputs() original_images = sd_pipe(**pipeline_inputs).images orig_image_slice = original_images[0, -3:, -3:, -1] @@ -212,3 +270,200 @@ def test_lora_save_load_legacy(self): # Outputs shouldn't match. self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + def test_text_encoder_lora_monkey_patch(self): + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 32) + + # create lora_attn_procs with zeroed out up.weights + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=False) + + # monkey patch + pipe._modify_text_encoder(text_attn_procs) + + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs + gc.collect() + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" + + # create lora_attn_procs with randn up.weights + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=True) + + # monkey patch + pipe._modify_text_encoder(text_attn_procs) + + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs + gc.collect() + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" + + def test_text_encoder_lora_remove_monkey_patch(self): + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 32) + + # create lora_attn_procs with randn up.weights + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=True) + + # monkey patch + pipe._modify_text_encoder(text_attn_procs) + + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs + gc.collect() + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 32) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora outputs should be different to without lora outputs" + + # remove monkey patch + pipe._remove_text_encoder_monkey_patch() + + # inference with removed lora + outputs_without_lora_removed = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora_removed.shape == (1, 77, 32) + + assert torch.allclose( + outputs_without_lora, outputs_without_lora_removed + ), "remove lora monkey patch should restore the original outputs" + + def test_text_encoder_lora_scale(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + lora_images_with_scale = sd_pipe(**pipeline_inputs, cross_attention_kwargs={"scale": 0.5}).images + lora_image_with_scale_slice = lora_images_with_scale[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse( + torch.allclose(torch.from_numpy(lora_image_slice), torch.from_numpy(lora_image_with_scale_slice)) + ) + + def test_lora_unet_attn_processors(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.create_lora_weight_file(tmpdirname) + + pipeline_components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # check if vanilla attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, (AttnProcessor, AttnProcessor2_0)) + + # load LoRA weight file + sd_pipe.load_lora_weights(tmpdirname) + + # check if lora attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + attn_proc_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + self.assertIsInstance(module.processor, attn_proc_class) + + @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") + def test_lora_unet_attn_processors_with_xformers(self): + with tempfile.TemporaryDirectory() as tmpdirname: + self.create_lora_weight_file(tmpdirname) + + pipeline_components, _ = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # enable XFormers + sd_pipe.enable_xformers_memory_efficient_attention() + + # check if xFormers attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, XFormersAttnProcessor) + + # load LoRA weight file + sd_pipe.load_lora_weights(tmpdirname) + + # check if lora attention processors are used + for _, module in sd_pipe.unet.named_modules(): + if isinstance(module, Attention): + self.assertIsInstance(module.processor, LoRAXFormersAttnProcessor) + + @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU") + def test_lora_save_load_with_xformers(self): + pipeline_components, lora_components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**pipeline_components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + _, _, pipeline_inputs = self.get_dummy_inputs() + + # enable XFormers + sd_pipe.enable_xformers_memory_efficient_attention() + + original_images = sd_pipe(**pipeline_inputs).images + orig_image_slice = original_images[0, -3:, -3:, -1] + + with tempfile.TemporaryDirectory() as tmpdirname: + LoraLoaderMixin.save_lora_weights( + save_directory=tmpdirname, + unet_lora_layers=lora_components["unet_lora_layers"], + text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], + ) + self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) + sd_pipe.load_lora_weights(tmpdirname) + + lora_images = sd_pipe(**pipeline_inputs).images + lora_image_slice = lora_images[0, -3:, -3:, -1] + + # Outputs shouldn't match. + self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 4a94a77fcabb..adc18e003a56 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -15,6 +15,7 @@ import inspect import tempfile +import traceback import unittest import unittest.mock as mock from typing import Dict, List, Tuple @@ -27,7 +28,31 @@ from diffusers.models import UNet2DConditionModel from diffusers.training_utils import EMAModel from diffusers.utils import logging, torch_device -from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu +from diffusers.utils.testing_utils import CaptureLogger, require_torch_2, run_test_in_subprocess + + +# Will be run via run_test_in_subprocess +def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): + error = None + try: + init_dict, model_class = in_queue.get(timeout=timeout) + + model = model_class(**init_dict) + model.to(torch_device) + model = torch.compile(model) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = model_class.from_pretrained(tmpdirname) + new_model.to(torch_device) + + assert new_model.__class__ == model_class + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() class ModelUtilsTest(unittest.TestCase): @@ -235,20 +260,11 @@ def test_from_save_pretrained_variant(self): max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") - @require_torch_gpu + @require_torch_2 def test_from_save_pretrained_dynamo(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - - model = self.model_class(**init_dict) - model.to(torch_device) - model = torch.compile(model) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - new_model = self.model_class.from_pretrained(tmpdirname) - new_model.to(torch_device) - - assert new_model.__class__ == self.model_class + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + inputs = [init_dict, self.model_class] + run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=inputs) def test_from_save_pretrained_dtype(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -268,7 +284,7 @@ def test_from_save_pretrained_dtype(self): new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype) assert new_model.dtype == dtype - def test_determinism(self): + def test_determinism(self, expected_max_diff=1e-5): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) model.to(torch_device) @@ -288,7 +304,7 @@ def test_determinism(self): out_1 = out_1[~np.isnan(out_1)] out_2 = out_2[~np.isnan(out_2)] max_diff = np.amax(np.abs(out_1 - out_2)) - self.assertLessEqual(max_diff, 1e-5) + self.assertLessEqual(max_diff, expected_max_diff) def test_output(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index f954d876fa76..9fb1a61011e3 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -23,9 +23,6 @@ from .test_modeling_common import ModelTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False - - class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): model_class = UNet1DModel @@ -152,7 +149,7 @@ def test_unet_1d_maestro(self): output_sum = output.abs().sum() output_max = output.abs().max() - assert (output_sum - 224.0896).abs() < 4e-2 + assert (output_sum - 224.0896).abs() < 0.5 assert (output_max - 0.0607).abs() < 4e-4 diff --git a/tests/models/test_models_unet_2d.py b/tests/models/test_models_unet_2d.py index c20b0ef7d0a4..92a5664daa2b 100644 --- a/tests/models/test_models_unet_2d.py +++ b/tests/models/test_models_unet_2d.py @@ -21,12 +21,14 @@ from diffusers import UNet2DModel from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device +from diffusers.utils.testing_utils import enable_full_determinism from .test_modeling_common import ModelTesterMixin logger = logging.get_logger(__name__) -torch.backends.cuda.matmul.allow_tf32 = False + +enable_full_determinism() class Unet2DModelTests(ModelTesterMixin, unittest.TestCase): @@ -246,10 +248,6 @@ def test_output_pretrained_ve_mid(self): model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256") model.to(torch_device) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - batch_size = 4 num_channels = 3 sizes = (256, 256) @@ -262,7 +260,7 @@ def test_output_pretrained_ve_mid(self): output_slice = output[0, -3:, -3:, -1].flatten().cpu() # fmt: off - expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114]) + expected_output_slice = torch.tensor([-4842.8691, -6499.6631, -3800.1953, -7978.2686, -10980.7129, -20028.8535, 8148.2822, 2342.2905, 567.7608]) # fmt: on self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2)) @@ -271,10 +269,6 @@ def test_output_pretrained_ve_large(self): model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update") model.to(torch_device) - torch.manual_seed(0) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(0) - batch_size = 4 num_channels = 3 sizes = (32, 32) diff --git a/tests/models/test_models_unet_2d_condition.py b/tests/models/test_models_unet_2d_condition.py index 2576297762a8..8a3d9dd16fd5 100644 --- a/tests/models/test_models_unet_2d_condition.py +++ b/tests/models/test_models_unet_2d_condition.py @@ -20,6 +20,7 @@ import torch from parameterized import parameterized +from pytest import mark from diffusers import UNet2DConditionModel from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor @@ -33,12 +34,14 @@ torch_device, ) from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism from .test_modeling_common import ModelTesterMixin logger = logging.get_logger(__name__) -torch.backends.cuda.matmul.allow_tf32 = False + +enable_full_determinism() def create_lora_layers(model, mock_weights: bool = True): @@ -416,6 +419,76 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma assert processor.is_run assert processor.number == 123 + @parameterized.expand( + [ + # fmt: off + [torch.bool], + [torch.long], + [torch.float], + # fmt: on + ] + ) + def test_model_xattn_mask(self, mask_dtype): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + keepall_mask = torch.ones(*cond.shape[:-1], device=cond.device, dtype=mask_dtype) + full_cond_keepallmask_out = model(**{**inputs_dict, "encoder_attention_mask": keepall_mask}).sample + assert full_cond_keepallmask_out.allclose( + full_cond_out + ), "a 'keep all' mask should give the same result as no mask" + + trunc_cond = cond[:, :-1, :] + trunc_cond_out = model(**{**inputs_dict, "encoder_hidden_states": trunc_cond}).sample + assert not trunc_cond_out.allclose( + full_cond_out + ), "discarding the last token from our cond should change the result" + + batch, tokens, _ = cond.shape + mask_last = (torch.arange(tokens) < tokens - 1).expand(batch, -1).to(cond.device, mask_dtype) + masked_cond_out = model(**{**inputs_dict, "encoder_attention_mask": mask_last}).sample + assert masked_cond_out.allclose( + trunc_cond_out + ), "masking the last token from our cond should be equivalent to truncating that token out of the condition" + + # see diffusers.models.attention_processor::Attention#prepare_attention_mask + # note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks. + # since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric. + # maybe it's fine that this only works for the unclip use-case. + @mark.skip( + reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length." + ) + def test_model_xattn_padding(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)}) + model.to(torch_device) + model.eval() + + cond = inputs_dict["encoder_hidden_states"] + with torch.no_grad(): + full_cond_out = model(**inputs_dict).sample + assert full_cond_out is not None + + batch, tokens, _ = cond.shape + keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool) + keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample + assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result" + + trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool) + trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample + assert trunc_mask_out.allclose( + keeplast_out + ), "a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask." + def test_lora_processors(self): # enable deterministic behavior for gradient checkpointing init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -442,8 +515,8 @@ def test_lora_processors(self): sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - assert (sample1 - sample2).abs().max() < 1e-4 - assert (sample3 - sample4).abs().max() < 1e-4 + assert (sample1 - sample2).abs().max() < 3e-3 + assert (sample3 - sample4).abs().max() < 3e-3 # sample 2 and sample 3 should be different assert (sample2 - sample3).abs().max() > 1e-4 @@ -587,7 +660,7 @@ def test_lora_on_off(self): new_sample = model(**inputs_dict).sample assert (sample - new_sample).abs().max() < 1e-4 - assert (sample - old_sample).abs().max() < 1e-4 + assert (sample - old_sample).abs().max() < 3e-3 @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), @@ -642,7 +715,7 @@ def test_custom_diffusion_processors(self): with torch.no_grad(): sample2 = model(**inputs_dict).sample - assert (sample1 - sample2).abs().max() < 1e-4 + assert (sample1 - sample2).abs().max() < 3e-3 def test_custom_diffusion_save_load(self): # enable deterministic behavior for gradient checkpointing @@ -677,7 +750,7 @@ def test_custom_diffusion_save_load(self): assert (sample - new_sample).abs().max() < 1e-4 # custom diffusion and no custom diffusion should be the same - assert (sample - old_sample).abs().max() < 1e-4 + assert (sample - old_sample).abs().max() < 3e-3 @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), @@ -957,7 +1030,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice): output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() expected_output_slice = torch.tensor(expected_slice) - assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + assert torch_all_close(output_slice, expected_output_slice, atol=3e-3) @parameterized.expand( [ diff --git a/tests/models/test_models_unet_3d_condition.py b/tests/models/test_models_unet_3d_condition.py index f245045bb3bb..762c4975da51 100644 --- a/tests/models/test_models_unet_3d_condition.py +++ b/tests/models/test_models_unet_3d_condition.py @@ -29,12 +29,14 @@ torch_device, ) from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism from .test_modeling_common import ModelTesterMixin +enable_full_determinism() + logger = logging.get_logger(__name__) -torch.backends.cuda.matmul.allow_tf32 = False def create_lora_layers(model, mock_weights: bool = True): @@ -224,11 +226,11 @@ def test_lora_processors(self): sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - assert (sample1 - sample2).abs().max() < 1e-4 - assert (sample3 - sample4).abs().max() < 1e-4 + assert (sample1 - sample2).abs().max() < 3e-3 + assert (sample3 - sample4).abs().max() < 3e-3 # sample 2 and sample 3 should be different - assert (sample2 - sample3).abs().max() > 1e-4 + assert (sample2 - sample3).abs().max() > 3e-3 def test_lora_save_load(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() @@ -259,7 +261,7 @@ def test_lora_save_load(self): with torch.no_grad(): new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - assert (sample - new_sample).abs().max() < 1e-4 + assert (sample - new_sample).abs().max() < 5e-4 # LoRA and no LoRA should NOT be the same assert (sample - old_sample).abs().max() > 1e-4 @@ -293,7 +295,7 @@ def test_lora_save_load_safetensors(self): with torch.no_grad(): new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample - assert (sample - new_sample).abs().max() < 1e-4 + assert (sample - new_sample).abs().max() < 3e-4 # LoRA and no LoRA should NOT be the same assert (sample - old_sample).abs().max() > 1e-4 @@ -365,7 +367,7 @@ def test_lora_on_off(self): new_sample = model(**inputs_dict).sample assert (sample - new_sample).abs().max() < 1e-4 - assert (sample - old_sample).abs().max() < 1e-4 + assert (sample - old_sample).abs().max() < 3e-3 @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), diff --git a/tests/models/test_models_vae.py b/tests/models/test_models_vae.py index 6cb71bebb9c0..fe27e138f5fa 100644 --- a/tests/models/test_models_vae.py +++ b/tests/models/test_models_vae.py @@ -21,11 +21,13 @@ from diffusers import AutoencoderKL from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism from .test_modeling_common import ModelTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): @@ -225,7 +227,7 @@ def test_stable_diffusion(self, seed, expected_slice, expected_slice_mps): output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice) - assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + assert torch_all_close(output_slice, expected_output_slice, atol=3e-3) @parameterized.expand( [ @@ -271,7 +273,7 @@ def test_stable_diffusion_mode(self, seed, expected_slice, expected_slice_mps): output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu() expected_output_slice = torch.tensor(expected_slice_mps if torch_device == "mps" else expected_slice) - assert torch_all_close(output_slice, expected_output_slice, atol=1e-3) + assert torch_all_close(output_slice, expected_output_slice, atol=3e-3) @parameterized.expand( [ @@ -319,8 +321,9 @@ def test_stable_diffusion_decode_fp16(self, seed, expected_slice): assert torch_all_close(output_slice, expected_output_slice, atol=5e-3) - @parameterized.expand([13, 16, 27]) + @parameterized.expand([(13,), (16,), (27,)]) @require_torch_gpu + @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): model = self.get_sd_vae_model(fp16=True) encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True) @@ -336,8 +339,9 @@ def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed): assert torch_all_close(sample, sample_2, atol=1e-1) - @parameterized.expand([13, 16, 37]) + @parameterized.expand([(13,), (16,), (37,)]) @require_torch_gpu + @unittest.skipIf(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.") def test_stable_diffusion_decode_xformers_vs_2_0(self, seed): model = self.get_sd_vae_model() encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64)) @@ -375,5 +379,5 @@ def test_stable_diffusion_encode_sample(self, seed, expected_slice): output_slice = sample[0, -1, -3:, -3:].flatten().cpu() expected_output_slice = torch.tensor(expected_slice) - tolerance = 1e-3 if torch_device != "mps" else 1e-2 + tolerance = 3e-3 if torch_device != "mps" else 1e-2 assert torch_all_close(output_slice, expected_output_slice, atol=tolerance) diff --git a/tests/models/test_models_vq.py b/tests/models/test_models_vq.py index 015d2abfc6fa..8ea6ef77ce63 100644 --- a/tests/models/test_models_vq.py +++ b/tests/models/test_models_vq.py @@ -19,11 +19,12 @@ from diffusers import VQModel from diffusers.utils import floats_tensor, torch_device +from diffusers.utils.testing_utils import enable_full_determinism from .test_modeling_common import ModelTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class VQModelTests(ModelTesterMixin, unittest.TestCase): diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 812d83e2f241..32f7ae8a9a8e 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -20,7 +20,10 @@ from diffusers import UNet2DConditionModel from diffusers.training_utils import EMAModel -from diffusers.utils.testing_utils import skip_mps, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device + + +enable_full_determinism() class EMAModelTests(unittest.TestCase): diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py index 4f0e2c5aecfd..c2cd6f4a04f4 100644 --- a/tests/others/test_image_processor.py +++ b/tests/others/test_image_processor.py @@ -42,7 +42,7 @@ def to_np(self, image): return image def test_vae_image_processor_pt(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_pt = self.dummy_sample input_np = self.to_np(input_pt) @@ -59,7 +59,7 @@ def test_vae_image_processor_pt(self): ), f"decoded output does not match input for output_type {output_type}" def test_vae_image_processor_np(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1) for output_type in ["pt", "np", "pil"]: @@ -72,7 +72,7 @@ def test_vae_image_processor_np(self): ), f"decoded output does not match input for output_type {output_type}" def test_vae_image_processor_pil(self): - image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + image_processor = VaeImageProcessor(do_resize=False, do_normalize=True) input_np = self.dummy_sample.cpu().numpy().transpose(0, 2, 3, 1) input_pil = image_processor.numpy_to_pil(input_np) diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion.py b/tests/pipelines/altdiffusion/test_alt_diffusion.py index 4d19621f0c2c..1344d33a2552 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion.py @@ -26,19 +26,21 @@ RobertaSeriesModelWithTransformation, ) from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class AltDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = AltDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -125,6 +127,12 @@ def get_dummy_inputs(self, device, seed=0): } return inputs + def test_attention_slicing_forward_pass(self): + super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + def test_alt_diffusion_ddim(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator diff --git a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py index 144107ec1c97..61457e6ca01f 100644 --- a/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py +++ b/tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py @@ -33,10 +33,10 @@ RobertaSeriesModelWithTransformation, ) from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase): @@ -123,6 +123,7 @@ def test_stable_diffusion_img2img_default_case(self): tokenizer.model_max_length = 77 init_image = self.dummy_image.to(device) + init_image = init_image / 2 + 0.5 # make sure here that pndm scheduler skips prk alt_pipe = AltDiffusionImg2ImgPipeline( @@ -134,7 +135,7 @@ def test_stable_diffusion_img2img_default_case(self): safety_checker=None, feature_extractor=self.dummy_extractor, ) - alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=False) + alt_pipe.image_processor = VaeImageProcessor(vae_scale_factor=alt_pipe.vae_scale_factor, do_normalize=True) alt_pipe = alt_pipe.to(device) alt_pipe.set_progress_bar_config(disable=None) @@ -250,7 +251,7 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self): assert image.shape == (504, 760, 3) expected_slice = np.array([0.9358, 0.9397, 0.9599, 0.9901, 1.0000, 1.0000, 0.9882, 1.0000, 1.0000]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow @@ -296,4 +297,4 @@ def test_stable_diffusion_img2img_pipeline_default(self): assert image.shape == (512, 768, 3) # img2img is flaky across GPUs even in fp32, so using MAE here - assert np.abs(expected_image - image).max() < 1e-3 + assert np.abs(expected_image - image).max() < 1e-2 diff --git a/tests/pipelines/audio_diffusion/test_audio_diffusion.py b/tests/pipelines/audio_diffusion/test_audio_diffusion.py index 0eb6252410f5..8c20f011cb86 100644 --- a/tests/pipelines/audio_diffusion/test_audio_diffusion.py +++ b/tests/pipelines/audio_diffusion/test_audio_diffusion.py @@ -30,10 +30,10 @@ UNet2DModel, ) from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class PipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py index ec72108fafc9..0825fc36a266 100644 --- a/tests/pipelines/audioldm/test_audioldm.py +++ b/tests/pipelines/audioldm/test_audioldm.py @@ -37,11 +37,15 @@ UNet2DConditionModel, ) from diffusers.utils import slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS from ..test_pipelines_common import PipelineTesterMixin +enable_full_determinism() + + class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = AudioLDMPipeline params = TEXT_TO_AUDIO_PARAMS @@ -413,4 +417,4 @@ def test_audioldm_lms(self): audio_slice = audio[27780:27790] expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212]) max_diff = np.abs(expected_slice - audio_slice).max() - assert max_diff < 1e-2 + assert max_diff < 3e-2 diff --git a/tests/pipelines/controlnet/__init__.py b/tests/pipelines/controlnet/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py similarity index 80% rename from tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py rename to tests/pipelines/controlnet/test_controlnet.py index 70b3652fce77..9915998be24e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -15,6 +15,7 @@ import gc import tempfile +import traceback import unittest import numpy as np @@ -25,22 +26,83 @@ AutoencoderKL, ControlNetModel, DDIMScheduler, + EulerDiscreteScheduler, StableDiffusionControlNetPipeline, UNet2DConditionModel, ) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device from diffusers.utils.import_utils import is_xformers_available -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_2, + require_torch_gpu, + run_test_in_subprocess, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_stable_diffusion_compile(in_queue, out_queue, timeout): + error = None + try: + _ = in_queue.get(timeout=timeout) + + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.to("cuda") + pipe.set_progress_bar_config(disable=None) + + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + pipe.controlnet.to(memory_format=torch.channels_last) + pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "bird" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin + output = pipe(prompt, image, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (768, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out_full.npy" + ) + assert np.abs(expected_image - image).max() < 1.0 -class StableDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -149,6 +211,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unitt pipeline_class = StableDiffusionControlNetPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) @@ -301,7 +364,7 @@ def test_save_load_optional_components(self): @slow @require_torch_gpu -class StableDiffusionControlNetPipelineSlowTests(unittest.TestCase): +class ControlNetPipelineSlowTests(unittest.TestCase): def tearDown(self): super().tearDown() gc.collect() @@ -332,7 +395,7 @@ def test_canny(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 9e-2 def test_depth(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth") @@ -359,7 +422,7 @@ def test_depth(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 8e-1 def test_hed(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-hed") @@ -386,7 +449,7 @@ def test_hed(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/man_hed_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 8e-2 def test_mlsd(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-mlsd") @@ -413,7 +476,7 @@ def test_mlsd(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/room_mlsd_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 5e-2 def test_normal(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-normal") @@ -440,7 +503,7 @@ def test_normal(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/cute_toy_normal_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 5e-2 def test_openpose(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose") @@ -467,7 +530,7 @@ def test_openpose(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/chef_pose_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 8e-2 def test_scribble(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-scribble") @@ -494,7 +557,7 @@ def test_scribble(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bag_scribble_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 8e-2 def test_seg(self): controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg") @@ -521,7 +584,7 @@ def test_seg(self): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/house_seg_out.npy" ) - assert np.abs(expected_image - image).max() < 5e-3 + assert np.abs(expected_image - image).max() < 8e-2 def test_sequential_cpu_offloading(self): torch.cuda.empty_cache() @@ -585,6 +648,74 @@ def test_canny_guess_mode(self): expected_slice = np.array([0.2724, 0.2846, 0.2724, 0.3843, 0.3682, 0.2736, 0.4675, 0.3862, 0.2887]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_canny_guess_mode_euler(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "" + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ) + + output = pipe( + prompt, + image, + generator=generator, + output_type="np", + num_inference_steps=3, + guidance_scale=3.0, + guess_mode=True, + ) + + image = output.images[0] + assert image.shape == (768, 512, 3) + + image_slice = image[-3:, -3:, -1] + expected_slice = np.array([0.1655, 0.1721, 0.1623, 0.1685, 0.1711, 0.1646, 0.1651, 0.1631, 0.1494]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @require_torch_2 + def test_stable_diffusion_compile(self): + run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=None) + + def test_v11_shuffle_global_pool_conditions(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11e_sd15_shuffle") + + pipe = StableDiffusionControlNetPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "New York" + image = load_image( + "https://huggingface.co/lllyasviel/control_v11e_sd15_shuffle/resolve/main/images/control.png" + ) + + output = pipe( + prompt, + image, + generator=generator, + output_type="np", + num_inference_steps=3, + guidance_scale=7.0, + ) + + image = output.images[0] + assert image.shape == (512, 640, 3) + + image_slice = image[-3:, -3:, -1] + expected_slice = np.array([0.1338, 0.1597, 0.1202, 0.1687, 0.1377, 0.1017, 0.2070, 0.1574, 0.1348]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @slow @require_torch_gpu diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py new file mode 100644 index 000000000000..de8f578a3cce --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -0,0 +1,367 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/ + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + StableDiffusionControlNetImg2ImgPipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel +from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS.union({"control_image"}) + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + control_image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + image = floats_tensor(control_image.shape, rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + +class StableDiffusionMultiControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet1 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + controlnet2 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + controlnet = MultiControlNetModel([controlnet1, controlnet2]) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + + control_image = [ + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + ] + + image = floats_tensor(control_image[0].shape, rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_save_pretrained_raise_not_implemented_exception(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + with tempfile.TemporaryDirectory() as tmpdir: + try: + # save_pretrained is not implemented for Multi-ControlNet + pipe.save_pretrained(tmpdir) + except NotImplementedError: + pass + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_float16(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_local(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_optional_components(self): + ... + + +@slow +@require_torch_gpu +class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + prompt = "evil space-punk bird" + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" + ).resize((512, 512)) + + output = pipe( + prompt, + image, + control_image=control_image, + generator=generator, + output_type="np", + num_inference_steps=50, + strength=0.6, + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/img2img.npy" + ) + + assert np.abs(expected_image - image).max() < 9e-2 diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py new file mode 100644 index 000000000000..0f8808bcb728 --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -0,0 +1,509 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This model implementation is heavily based on: + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDIMScheduler, + StableDiffusionControlNetInpaintPipeline, + UNet2DConditionModel, +) +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel +from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu + +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset({"control_image"}) # skip `image` and `mask` for now, only test for control_image + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + control_image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + init_image = init_image.cpu().permute(0, 2, 3, 1)[0] + + image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64)) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "mask_image": mask_image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + +class ControlNetSimpleInpaintPipelineFastTests(ControlNetInpaintPipelineFastTests): + pipeline_class = StableDiffusionControlNetInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + +class MultiControlNetInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionControlNetInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=9, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + torch.manual_seed(0) + controlnet1 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + controlnet2 = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + cross_attention_dim=32, + conditioning_embedding_out_channels=(16, 32), + ) + torch.manual_seed(0) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + controlnet = MultiControlNetModel([controlnet1, controlnet2]) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + + control_image = [ + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ), + ] + init_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + init_image = init_image.cpu().permute(0, 2, 3, 1)[0] + + image = Image.fromarray(np.uint8(init_image)).convert("RGB").resize((64, 64)) + mask_image = Image.fromarray(np.uint8(init_image + 4)).convert("RGB").resize((64, 64)) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + "mask_image": mask_image, + "control_image": control_image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3) + + def test_save_pretrained_raise_not_implemented_exception(self): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + with tempfile.TemporaryDirectory() as tmpdir: + try: + # save_pretrained is not implemented for Multi-ControlNet + pipe.save_pretrained(tmpdir) + except NotImplementedError: + pass + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_float16(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_local(self): + ... + + # override PipelineTesterMixin + @unittest.skip("save pretrained not implemented") + def test_save_load_optional_components(self): + ... + + +@slow +@require_torch_gpu +class ControlNetInpaintPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_canny(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny") + + pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None, controlnet=controlnet + ) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + image = load_image( + "https://huggingface.co/lllyasviel/sd-controlnet-canny/resolve/main/images/bird.png" + ).resize((512, 512)) + + mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main" + "/stable_diffusion_inpaint/input_bench_mask.png" + ).resize((512, 512)) + + prompt = "pitch black hole" + + control_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png" + ).resize((512, 512)) + + output = pipe( + prompt, + image=image, + mask_image=mask_image, + control_image=control_image, + generator=generator, + output_type="np", + num_inference_steps=3, + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/inpaint.npy" + ) + + assert np.abs(expected_image - image).max() < 9e-2 + + def test_inpaint(self): + controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint") + + pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(33) + + init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ) + init_image = init_image.resize((512, 512)) + + mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ) + mask_image = mask_image.resize((512, 512)) + + prompt = "a handsome man with ray-ban sunglasses" + + def make_inpaint_condition(image, image_mask): + image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + image[image_mask > 0.5] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image + + control_image = make_inpaint_condition(init_image, mask_image) + + output = pipe( + prompt, + image=init_image, + mask_image=mask_image, + control_image=control_image, + guidance_scale=9.0, + eta=1.0, + generator=generator, + num_inference_steps=20, + output_type="np", + ) + image = output.images[0] + + assert image.shape == (512, 512, 3) + + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/boy_ray_ban.npy" + ) + + assert np.abs(expected_image - image).max() < 9e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py b/tests/pipelines/controlnet/test_flax_controlnet.py similarity index 98% rename from tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py rename to tests/pipelines/controlnet/test_flax_controlnet.py index 268c01320177..4ad75b407acc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_flax_controlnet.py +++ b/tests/pipelines/controlnet/test_flax_controlnet.py @@ -30,7 +30,7 @@ @slow @require_flax -class FlaxStableDiffusionControlNetPipelineIntegrationTests(unittest.TestCase): +class FlaxControlNetPipelineIntegrationTests(unittest.TestCase): def tearDown(self): # clean up the VRAM after each test super().tearDown() diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index 5db90a3aa740..0ba86daa61fc 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -21,13 +21,13 @@ from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps from ..pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @@ -103,7 +103,7 @@ def test_save_load_local(self): @skip_mps def test_dict_tuple_outputs_equivalent(self): - return super().test_dict_tuple_outputs_equivalent() + return super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-3) @skip_mps def test_save_load_optional_components(self): @@ -113,6 +113,9 @@ def test_save_load_optional_components(self): def test_attention_slicing_forward_pass(self): return super().test_attention_slicing_forward_pass() + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 319bd778e3b2..0861d7daab29 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -19,13 +19,13 @@ import torch from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel -from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): @@ -87,6 +87,18 @@ def test_inference(self): max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) + def test_dict_tuple_outputs_equivalent(self): + super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-3) + + def test_save_load_local(self): + super().test_save_load_local(expected_max_difference=3e-3) + + def test_save_load_optional_components(self): + super().test_save_load_optional_components(expected_max_difference=3e-3) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index 5e3e47cb74fb..a3c290215114 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -19,10 +19,10 @@ import torch from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel -from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class DDPMPipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/deepfloyd_if/test_if.py b/tests/pipelines/deepfloyd_if/test_if.py index bf01c2350d22..2e7383067eec 100644 --- a/tests/pipelines/deepfloyd_if/test_if.py +++ b/tests/pipelines/deepfloyd_if/test_if.py @@ -28,6 +28,7 @@ IFSuperResolutionPipeline, ) from diffusers.models.attention_processor import AttnAddedKVProcessor +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import floats_tensor, load_numpy, require_torch_gpu, skip_mps, slow, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS @@ -42,8 +43,6 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.T batch_params = TEXT_TO_IMAGE_BATCH_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - test_xformers_attention = False - def get_dummy_components(self): return self._get_dummy_components() @@ -68,7 +67,7 @@ def test_save_load_optional_components(self): @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder - self._test_save_load_float16(expected_max_diff=1e-1) + super().test_save_load_float16(expected_max_diff=1e-1) def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) @@ -81,6 +80,13 @@ def test_inference_batch_single_identical(self): expected_max_diff=1e-2, ) + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) + @slow @require_torch_gpu diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img.py b/tests/pipelines/deepfloyd_if/test_if_img2img.py index b4c99a8ab93a..ec4598906a6f 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img.py @@ -20,6 +20,7 @@ from diffusers import IFImg2ImgPipeline from diffusers.utils import floats_tensor +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import skip_mps, torch_device from ..pipeline_params import ( @@ -37,8 +38,6 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, uni batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - test_xformers_attention = False - def get_dummy_components(self): return self._get_dummy_components() @@ -63,14 +62,21 @@ def get_dummy_inputs(self, device, seed=0): def test_save_load_optional_components(self): self._test_save_load_optional_components() + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) + @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder - self._test_save_load_float16(expected_max_diff=1e-1) + super().test_save_load_float16(expected_max_diff=1e-1) @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_float16_inference(self): - self._test_float16_inference(expected_max_diff=1e-1) + super().test_float16_inference(expected_max_diff=1e-1) def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) diff --git a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py index 626ab321f895..500557108aed 100644 --- a/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_img2img_superresolution.py @@ -20,6 +20,7 @@ from diffusers import IFImg2ImgSuperResolutionPipeline from diffusers.utils import floats_tensor +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import skip_mps, torch_device from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS @@ -34,8 +35,6 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineT batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"original_image"}) required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - test_xformers_attention = False - def get_dummy_components(self): return self._get_superresolution_dummy_components() @@ -59,13 +58,20 @@ def get_dummy_inputs(self, device, seed=0): return inputs + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) + def test_save_load_optional_components(self): self._test_save_load_optional_components() @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder - self._test_save_load_float16(expected_max_diff=1e-1) + super().test_save_load_float16(expected_max_diff=1e-1) def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting.py b/tests/pipelines/deepfloyd_if/test_if_inpainting.py index 37d818c7a910..1317fcb64e81 100644 --- a/tests/pipelines/deepfloyd_if/test_if_inpainting.py +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting.py @@ -20,6 +20,7 @@ from diffusers import IFInpaintingPipeline from diffusers.utils import floats_tensor +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import skip_mps, torch_device from ..pipeline_params import ( @@ -37,8 +38,6 @@ class IFInpaintingPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - test_xformers_attention = False - def get_dummy_components(self): return self._get_dummy_components() @@ -62,13 +61,20 @@ def get_dummy_inputs(self, device, seed=0): return inputs + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) + def test_save_load_optional_components(self): self._test_save_load_optional_components() @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder - self._test_save_load_float16(expected_max_diff=1e-1) + super().test_save_load_float16(expected_max_diff=1e-1) def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) diff --git a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py index 30062cb2f8d0..961a22675f33 100644 --- a/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_inpainting_superresolution.py @@ -20,6 +20,7 @@ from diffusers import IFInpaintingSuperResolutionPipeline from diffusers.utils import floats_tensor +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import skip_mps, torch_device from ..pipeline_params import ( @@ -37,8 +38,6 @@ class IFInpaintingSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipeli batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS.union({"original_image"}) required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - test_xformers_attention = False - def get_dummy_components(self): return self._get_superresolution_dummy_components() @@ -64,13 +63,20 @@ def get_dummy_inputs(self, device, seed=0): return inputs + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) + def test_save_load_optional_components(self): self._test_save_load_optional_components() @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder - self._test_save_load_float16(expected_max_diff=1e-1) + super().test_save_load_float16(expected_max_diff=1e-1) def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) diff --git a/tests/pipelines/deepfloyd_if/test_if_superresolution.py b/tests/pipelines/deepfloyd_if/test_if_superresolution.py index 14acfa5415c2..52fb38308892 100644 --- a/tests/pipelines/deepfloyd_if/test_if_superresolution.py +++ b/tests/pipelines/deepfloyd_if/test_if_superresolution.py @@ -20,6 +20,7 @@ from diffusers import IFSuperResolutionPipeline from diffusers.utils import floats_tensor +from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import skip_mps, torch_device from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS @@ -34,8 +35,6 @@ class IFSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMi batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} - test_xformers_attention = False - def get_dummy_components(self): return self._get_superresolution_dummy_components() @@ -57,13 +56,20 @@ def get_dummy_inputs(self, device, seed=0): return inputs + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=1e-3) + def test_save_load_optional_components(self): self._test_save_load_optional_components() @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") def test_save_load_float16(self): # Due to non-determinism in save load of the hf-internal-testing/tiny-random-t5 text encoder - self._test_save_load_float16(expected_max_diff=1e-1) + super().test_save_load_float16(expected_max_diff=1e-1) def test_attention_slicing_forward_pass(self): self._test_attention_slicing_forward_pass(expected_max_diff=1e-2) diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py index d8098178f339..4937915696b4 100644 --- a/tests/pipelines/dit/test_dit.py +++ b/tests/pipelines/dit/test_dit.py @@ -21,7 +21,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel from diffusers.utils import is_xformers_available, load_numpy, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import ( CLASS_CONDITIONED_IMAGE_GENERATION_BATCH_PARAMS, @@ -30,7 +30,7 @@ from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): diff --git a/tests/pipelines/kandinsky/__init__.py b/tests/pipelines/kandinsky/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py new file mode 100644 index 000000000000..239433910b45 --- /dev/null +++ b/tests/pipelines/kandinsky/test_kandinsky.py @@ -0,0 +1,282 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from transformers import XLMRobertaTokenizerFast + +from diffusers import DDIMScheduler, KandinskyPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel +from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP +from diffusers.utils import floats_tensor, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu + +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +enable_full_determinism() + + +class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = KandinskyPipeline + params = [ + "prompt", + "image_embeds", + "negative_image_embeds", + ] + batch_params = ["prompt", "negative_prompt", "image_embeds", "negative_image_embeds"] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "guidance_scale", + "negative_prompt", + "num_inference_steps", + "return_dict", + "guidance_scale", + "num_images_per_prompt", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = XLMRobertaTokenizerFast.from_pretrained("YiYiXu/tiny-random-mclip-base") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = MCLIPConfig( + numDims=self.cross_attention_dim, + transformerDimensions=self.text_embedder_hidden_size, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + num_attention_heads=4, + num_hidden_layers=5, + vocab_size=1005, + ) + + text_encoder = MultilingualCLIP(config) + text_encoder = text_encoder.eval() + + return text_encoder + + @property + def dummy_unet(self): + torch.manual_seed(0) + + model_kwargs = { + "in_channels": 4, + # Out channels is double in channels because predicts mean and variance + "out_channels": 8, + "addition_embed_type": "text_image", + "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"), + "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "layers_per_block": 1, + "encoder_hid_dim": self.text_embedder_hidden_size, + "encoder_hid_dim_type": "text_image_proj", + "cross_attention_dim": self.cross_attention_dim, + "attention_head_dim": 4, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": None, + } + + model = UNet2DConditionModel(**model_kwargs) + return model + + @property + def dummy_movq_kwargs(self): + return { + "block_out_channels": [32, 64], + "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 8, + "norm_type": "spatial", + "num_vq_embeddings": 12, + "out_channels": 3, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + + @property + def dummy_movq(self): + torch.manual_seed(0) + model = VQModel(**self.dummy_movq_kwargs) + return model + + def get_dummy_components(self): + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + unet = self.dummy_unet + movq = self.dummy_movq + + scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_schedule="linear", + beta_start=0.00085, + beta_end=0.012, + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + prediction_type="epsilon", + thresholding=False, + ) + + components = { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "unet": unet, + "scheduler": scheduler, + "movq": movq, + } + return components + + def get_dummy_inputs(self, device, seed=0): + image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed)).to(device) + negative_image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed + 1)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "image_embeds": image_embeds, + "negative_image_embeds": negative_image_embeds, + "generator": generator, + "height": 64, + "width": 64, + "guidance_scale": 4.0, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_kandinsky(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_from_tuple = pipe( + **self.get_dummy_inputs(device), + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array( + [0.328663, 1.0, 0.23216873, 1.0, 0.92717564, 0.4639046, 0.96894777, 0.31713378, 0.6293953] + ) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert ( + np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + + +@slow +@require_torch_gpu +class KandinskyPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_kandinsky_text2img(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/kandinsky/kandinsky_text2img_cat_fp16.npy" + ) + + pipe_prior = KandinskyPriorPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ) + pipe_prior.to(torch_device) + + pipeline = KandinskyPipeline.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + prompt = "red cat, 4k photo" + + generator = torch.Generator(device="cuda").manual_seed(0) + image_emb, zero_image_emb = pipe_prior( + prompt, + generator=generator, + num_inference_steps=5, + negative_prompt="", + ).to_tuple() + + generator = torch.Generator(device="cuda").manual_seed(0) + output = pipeline( + prompt, + image_embeds=image_emb, + negative_image_embeds=zero_image_emb, + generator=generator, + num_inference_steps=100, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (512, 512, 3) + + assert_mean_pixel_difference(image, expected_image) diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py new file mode 100644 index 000000000000..94817b3eed4b --- /dev/null +++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import XLMRobertaTokenizerFast + +from diffusers import DDIMScheduler, KandinskyImg2ImgPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel +from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu + +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +enable_full_determinism() + + +class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = KandinskyImg2ImgPipeline + params = ["prompt", "image_embeds", "negative_image_embeds", "image"] + batch_params = [ + "prompt", + "negative_prompt", + "image_embeds", + "negative_image_embeds", + "image", + ] + required_optional_params = [ + "generator", + "height", + "width", + "strength", + "guidance_scale", + "negative_prompt", + "num_inference_steps", + "return_dict", + "guidance_scale", + "num_images_per_prompt", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = XLMRobertaTokenizerFast.from_pretrained("YiYiXu/tiny-random-mclip-base") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = MCLIPConfig( + numDims=self.cross_attention_dim, + transformerDimensions=self.text_embedder_hidden_size, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + num_attention_heads=4, + num_hidden_layers=5, + vocab_size=1005, + ) + + text_encoder = MultilingualCLIP(config) + text_encoder = text_encoder.eval() + + return text_encoder + + @property + def dummy_unet(self): + torch.manual_seed(0) + + model_kwargs = { + "in_channels": 4, + # Out channels is double in channels because predicts mean and variance + "out_channels": 8, + "addition_embed_type": "text_image", + "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"), + "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "layers_per_block": 1, + "encoder_hid_dim": self.text_embedder_hidden_size, + "encoder_hid_dim_type": "text_image_proj", + "cross_attention_dim": self.cross_attention_dim, + "attention_head_dim": 4, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": None, + } + + model = UNet2DConditionModel(**model_kwargs) + return model + + @property + def dummy_movq_kwargs(self): + return { + "block_out_channels": [32, 64], + "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 8, + "norm_type": "spatial", + "num_vq_embeddings": 12, + "out_channels": 3, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + + @property + def dummy_movq(self): + torch.manual_seed(0) + model = VQModel(**self.dummy_movq_kwargs) + return model + + def get_dummy_components(self): + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + unet = self.dummy_unet + movq = self.dummy_movq + + ddim_config = { + "num_train_timesteps": 1000, + "beta_schedule": "linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "clip_sample": False, + "set_alpha_to_one": False, + "steps_offset": 0, + "prediction_type": "epsilon", + "thresholding": False, + } + + scheduler = DDIMScheduler(**ddim_config) + + components = { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "unet": unet, + "scheduler": scheduler, + "movq": movq, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed)).to(device) + negative_image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed + 1)).to(device) + # create init_image + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((256, 256)) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "image": init_image, + "image_embeds": image_embeds, + "negative_image_embeds": negative_image_embeds, + "generator": generator, + "height": 64, + "width": 64, + "num_inference_steps": 10, + "guidance_scale": 7.0, + "strength": 0.2, + "output_type": "np", + } + return inputs + + def test_kandinsky_img2img(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_from_tuple = pipe( + **self.get_dummy_inputs(device), + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array( + [0.61474943, 0.6073539, 0.43308544, 0.5928269, 0.47493595, 0.46755973, 0.4613838, 0.45368797, 0.50119233] + ) + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert ( + np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + + +@slow +@require_torch_gpu +class KandinskyImg2ImgPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_kandinsky_img2img(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/kandinsky/kandinsky_img2img_frog.npy" + ) + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + ) + prompt = "A red cartoon frog, 4k" + + pipe_prior = KandinskyPriorPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ) + pipe_prior.to(torch_device) + + pipeline = KandinskyImg2ImgPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 + ) + pipeline = pipeline.to(torch_device) + + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + image_emb, zero_image_emb = pipe_prior( + prompt, + generator=generator, + num_inference_steps=5, + negative_prompt="", + ).to_tuple() + + output = pipeline( + prompt, + image=init_image, + image_embeds=image_emb, + negative_image_embeds=zero_image_emb, + generator=generator, + num_inference_steps=100, + height=768, + width=768, + strength=0.2, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (768, 768, 3) + + assert_mean_pixel_difference(image, expected_image) diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py new file mode 100644 index 000000000000..46926479ae06 --- /dev/null +++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py @@ -0,0 +1,313 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import XLMRobertaTokenizerFast + +from diffusers import DDIMScheduler, KandinskyInpaintPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel +from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP +from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu + +from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference + + +enable_full_determinism() + + +class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = KandinskyInpaintPipeline + params = ["prompt", "image_embeds", "negative_image_embeds", "image", "mask_image"] + batch_params = [ + "prompt", + "negative_prompt", + "image_embeds", + "negative_image_embeds", + "image", + "mask_image", + ] + required_optional_params = [ + "generator", + "height", + "width", + "latents", + "guidance_scale", + "negative_prompt", + "num_inference_steps", + "return_dict", + "guidance_scale", + "num_images_per_prompt", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = XLMRobertaTokenizerFast.from_pretrained("YiYiXu/tiny-random-mclip-base") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = MCLIPConfig( + numDims=self.cross_attention_dim, + transformerDimensions=self.text_embedder_hidden_size, + hidden_size=self.text_embedder_hidden_size, + intermediate_size=37, + num_attention_heads=4, + num_hidden_layers=5, + vocab_size=1005, + ) + + text_encoder = MultilingualCLIP(config) + text_encoder = text_encoder.eval() + + return text_encoder + + @property + def dummy_unet(self): + torch.manual_seed(0) + + model_kwargs = { + "in_channels": 9, + # Out channels is double in channels because predicts mean and variance + "out_channels": 8, + "addition_embed_type": "text_image", + "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"), + "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"), + "mid_block_type": "UNetMidBlock2DSimpleCrossAttn", + "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2), + "layers_per_block": 1, + "encoder_hid_dim": self.text_embedder_hidden_size, + "encoder_hid_dim_type": "text_image_proj", + "cross_attention_dim": self.cross_attention_dim, + "attention_head_dim": 4, + "resnet_time_scale_shift": "scale_shift", + "class_embed_type": None, + } + + model = UNet2DConditionModel(**model_kwargs) + return model + + @property + def dummy_movq_kwargs(self): + return { + "block_out_channels": [32, 64], + "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 1, + "norm_num_groups": 8, + "norm_type": "spatial", + "num_vq_embeddings": 12, + "out_channels": 3, + "up_block_types": [ + "AttnUpDecoderBlock2D", + "UpDecoderBlock2D", + ], + "vq_embed_dim": 4, + } + + @property + def dummy_movq(self): + torch.manual_seed(0) + model = VQModel(**self.dummy_movq_kwargs) + return model + + def get_dummy_components(self): + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + unet = self.dummy_unet + movq = self.dummy_movq + + scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_schedule="linear", + beta_start=0.00085, + beta_end=0.012, + clip_sample=False, + set_alpha_to_one=False, + steps_offset=1, + prediction_type="epsilon", + thresholding=False, + ) + + components = { + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "unet": unet, + "scheduler": scheduler, + "movq": movq, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed)).to(device) + negative_image_embeds = floats_tensor((1, self.cross_attention_dim), rng=random.Random(seed + 1)).to(device) + # create init_image + image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((256, 256)) + # create mask + mask = np.ones((64, 64), dtype=np.float32) + mask[:32, :32] = 0 + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "image": init_image, + "mask_image": mask, + "image_embeds": image_embeds, + "negative_image_embeds": negative_image_embeds, + "generator": generator, + "height": 64, + "width": 64, + "num_inference_steps": 2, + "guidance_scale": 4.0, + "output_type": "np", + } + return inputs + + def test_kandinsky_inpaint(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.images + + image_from_tuple = pipe( + **self.get_dummy_inputs(device), + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + + print(f"image.shape {image.shape}") + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array( + [0.8326919, 0.73790467, 0.20918581, 0.9309612, 0.5511791, 0.43713328, 0.5513321, 0.49922934, 0.59497786] + ) + + assert ( + np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}" + assert ( + np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}" + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + +@slow +@require_torch_gpu +class KandinskyInpaintPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_kandinsky_inpaint(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/kandinsky/kandinsky_inpaint_cat_with_hat_fp16.npy" + ) + + init_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/kandinsky/cat.png" + ) + mask = np.ones((768, 768), dtype=np.float32) + mask[:250, 250:-250] = 0 + + prompt = "a hat" + + pipe_prior = KandinskyPriorPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1-prior", torch_dtype=torch.float16 + ) + pipe_prior.to(torch_device) + + pipeline = KandinskyInpaintPipeline.from_pretrained( + "kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16 + ) + pipeline = pipeline.to(torch_device) + pipeline.set_progress_bar_config(disable=None) + + generator = torch.Generator(device="cpu").manual_seed(0) + image_emb, zero_image_emb = pipe_prior( + prompt, + generator=generator, + num_inference_steps=5, + negative_prompt="", + ).to_tuple() + + output = pipeline( + prompt, + image=init_image, + mask_image=mask, + image_embeds=image_emb, + negative_image_embeds=zero_image_emb, + generator=generator, + num_inference_steps=100, + height=768, + width=768, + output_type="np", + ) + + image = output.images[0] + + assert image.shape == (768, 768, 3) + + assert_mean_pixel_difference(image, expected_image) diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py new file mode 100644 index 000000000000..d9c260eabc06 --- /dev/null +++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from torch import nn +from transformers import ( + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionConfig, + CLIPVisionModelWithProjection, +) + +from diffusers import KandinskyPriorPipeline, PriorTransformer, UnCLIPScheduler +from diffusers.utils import torch_device +from diffusers.utils.testing_utils import enable_full_determinism, skip_mps + +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = KandinskyPriorPipeline + params = ["prompt"] + batch_params = ["prompt", "negative_prompt"] + required_optional_params = [ + "num_images_per_prompt", + "generator", + "num_inference_steps", + "latents", + "negative_prompt", + "guidance_scale", + "output_type", + "return_dict", + ] + test_xformers_attention = False + + @property + def text_embedder_hidden_size(self): + return 32 + + @property + def time_input_dim(self): + return 32 + + @property + def block_out_channels_0(self): + return self.time_input_dim + + @property + def time_embed_dim(self): + return self.time_input_dim * 4 + + @property + def cross_attention_dim(self): + return 100 + + @property + def dummy_tokenizer(self): + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + return tokenizer + + @property + def dummy_text_encoder(self): + torch.manual_seed(0) + config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=self.text_embedder_hidden_size, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + return CLIPTextModelWithProjection(config) + + @property + def dummy_prior(self): + torch.manual_seed(0) + + model_kwargs = { + "num_attention_heads": 2, + "attention_head_dim": 12, + "embedding_dim": self.text_embedder_hidden_size, + "num_layers": 1, + } + + model = PriorTransformer(**model_kwargs) + # clip_std and clip_mean is initialized to be 0 so PriorTransformer.post_process_latents will always return 0 - set clip_std to be 1 so it won't return 0 + model.clip_std = nn.Parameter(torch.ones(model.clip_std.shape)) + return model + + @property + def dummy_image_encoder(self): + torch.manual_seed(0) + config = CLIPVisionConfig( + hidden_size=self.text_embedder_hidden_size, + image_size=224, + projection_dim=self.text_embedder_hidden_size, + intermediate_size=37, + num_attention_heads=4, + num_channels=3, + num_hidden_layers=5, + patch_size=14, + ) + + model = CLIPVisionModelWithProjection(config) + return model + + @property + def dummy_image_processor(self): + image_processor = CLIPImageProcessor( + crop_size=224, + do_center_crop=True, + do_normalize=True, + do_resize=True, + image_mean=[0.48145466, 0.4578275, 0.40821073], + image_std=[0.26862954, 0.26130258, 0.27577711], + resample=3, + size=224, + ) + + return image_processor + + def get_dummy_components(self): + prior = self.dummy_prior + image_encoder = self.dummy_image_encoder + text_encoder = self.dummy_text_encoder + tokenizer = self.dummy_tokenizer + image_processor = self.dummy_image_processor + + scheduler = UnCLIPScheduler( + variance_type="fixed_small_log", + prediction_type="sample", + num_train_timesteps=1000, + clip_sample=True, + clip_sample_range=10.0, + ) + + components = { + "prior": prior, + "image_encoder": image_encoder, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "scheduler": scheduler, + "image_processor": image_processor, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "horse", + "generator": generator, + "guidance_scale": 4.0, + "num_inference_steps": 2, + "output_type": "np", + } + return inputs + + def test_kandinsky_prior(self): + device = "cpu" + + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe = pipe.to(device) + + pipe.set_progress_bar_config(disable=None) + + output = pipe(**self.get_dummy_inputs(device)) + image = output.image_embeds + + image_from_tuple = pipe( + **self.get_dummy_inputs(device), + return_dict=False, + )[0] + + image_slice = image[0, -10:] + image_from_tuple_slice = image_from_tuple[0, -10:] + + assert image.shape == (1, 32) + + expected_slice = np.array( + [-0.0532, 1.7120, 0.3656, -1.0852, -0.8946, -1.1756, 0.4348, 0.2482, 0.5146, -0.1156] + ) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + + @skip_mps + def test_inference_batch_single_identical(self): + test_max_difference = torch_device == "cpu" + relax_max_difference = True + test_mean_pixel_difference = False + + self._test_inference_batch_single_identical( + test_max_difference=test_max_difference, + relax_max_difference=relax_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) + + @skip_mps + def test_attention_slicing_forward_pass(self): + test_max_difference = torch_device == "cpu" + test_mean_pixel_difference = False + + self._test_attention_slicing_forward_pass( + test_max_difference=test_max_difference, + test_mean_pixel_difference=test_mean_pixel_difference, + ) diff --git a/tests/pipelines/karras_ve/test_karras_ve.py b/tests/pipelines/karras_ve/test_karras_ve.py index 391e61a2b9c9..142058bcd710 100644 --- a/tests/pipelines/karras_ve/test_karras_ve.py +++ b/tests/pipelines/karras_ve/test_karras_ve.py @@ -19,10 +19,10 @@ import torch from diffusers import KarrasVePipeline, KarrasVeScheduler, UNet2DModel -from diffusers.utils.testing_utils import require_torch, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch, slow, torch_device -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class KarrasVePipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion.py b/tests/pipelines/latent_diffusion/test_latent_diffusion.py index 05ff4162e5c6..88dc8ef9b17b 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion.py @@ -21,13 +21,20 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel -from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_numpy, + nightly, + require_torch_gpu, + slow, + torch_device, +) from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase): diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py index f1aa2f08efba..d21ead543af8 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_superresolution.py @@ -21,10 +21,10 @@ from diffusers import DDIMScheduler, LDMSuperResolutionPipeline, UNet2DModel, VQModel from diffusers.utils import PIL_INTERPOLATION, floats_tensor, load_image, slow, torch_device -from diffusers.utils.testing_utils import require_torch +from diffusers.utils.testing_utils import enable_full_determinism, require_torch -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class LDMSuperResolutionPipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py b/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py index aa7b33730d18..ff8670ea2950 100644 --- a/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py +++ b/tests/pipelines/latent_diffusion/test_latent_diffusion_uncond.py @@ -20,10 +20,10 @@ from transformers import CLIPTextConfig, CLIPTextModel from diffusers import DDIMScheduler, LDMPipeline, UNet2DModel, VQModel -from diffusers.utils.testing_utils import require_torch, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch, slow, torch_device -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class LDMPipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py index 17feba59e8e4..14c16644889e 100644 --- a/tests/pipelines/paint_by_example/test_paint_by_example.py +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -25,19 +25,20 @@ from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder from diffusers.utils import floats_tensor, load_image, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = PaintByExamplePipeline params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess def get_dummy_components(self): torch.manual_seed(0) @@ -160,6 +161,9 @@ def test_paint_by_example_image_tensor(self): assert out_1.shape == (1, 64, 64, 3) assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2 + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu diff --git a/tests/pipelines/pipeline_params.py b/tests/pipelines/pipeline_params.py index a0ac6c641c0b..7c5ffa2ca24b 100644 --- a/tests/pipelines/pipeline_params.py +++ b/tests/pipelines/pipeline_params.py @@ -22,6 +22,10 @@ TEXT_TO_IMAGE_BATCH_PARAMS = frozenset(["prompt", "negative_prompt"]) +TEXT_TO_IMAGE_IMAGE_PARAMS = frozenset([]) + +IMAGE_TO_IMAGE_IMAGE_PARAMS = frozenset(["image"]) + IMAGE_VARIATION_PARAMS = frozenset( [ "image", diff --git a/tests/pipelines/pndm/test_pndm.py b/tests/pipelines/pndm/test_pndm.py index bed5fea561dc..c2595713933c 100644 --- a/tests/pipelines/pndm/test_pndm.py +++ b/tests/pipelines/pndm/test_pndm.py @@ -19,10 +19,10 @@ import torch from diffusers import PNDMPipeline, PNDMScheduler, UNet2DModel -from diffusers.utils.testing_utils import require_torch, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch, slow, torch_device -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class PNDMPipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py index 4f98675bc5af..e372cf979ebb 100644 --- a/tests/pipelines/repaint/test_repaint.py +++ b/tests/pipelines/repaint/test_repaint.py @@ -20,13 +20,21 @@ import torch from diffusers import RePaintPipeline, RePaintScheduler, UNet2DModel -from diffusers.utils.testing_utils import load_image, load_numpy, nightly, require_torch_gpu, skip_mps, torch_device +from diffusers.utils.testing_utils import ( + enable_full_determinism, + load_image, + load_numpy, + nightly, + require_torch_gpu, + skip_mps, + torch_device, +) from ..pipeline_params import IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_INPAINTING_PARAMS from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class RepaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): diff --git a/tests/pipelines/score_sde_ve/test_score_sde_ve.py b/tests/pipelines/score_sde_ve/test_score_sde_ve.py index 036ecc3f6bf3..32505253f6c7 100644 --- a/tests/pipelines/score_sde_ve/test_score_sde_ve.py +++ b/tests/pipelines/score_sde_ve/test_score_sde_ve.py @@ -19,10 +19,10 @@ import torch from diffusers import ScoreSdeVePipeline, ScoreSdeVeScheduler, UNet2DModel -from diffusers.utils.testing_utils import require_torch, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, require_torch, slow, torch_device -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class ScoreSdeVeipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py index ba42b1fe9c5f..9e810616dc56 100644 --- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py +++ b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py @@ -25,10 +25,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline from diffusers.utils import floats_tensor, nightly, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class SafeDiffusionPipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py b/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py index 3b64ea2d2fc1..cc8690eb87ca 100644 --- a/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py +++ b/tests/pipelines/spectrogram_diffusion/test_spectrogram_diffusion.py @@ -22,18 +22,22 @@ from diffusers import DDPMScheduler, MidiProcessor, SpectrogramDiffusionPipeline from diffusers.pipelines.spectrogram_diffusion import SpectrogramContEncoder, SpectrogramNotesEncoder, T5FilmDecoder from diffusers.utils import require_torch_gpu, skip_mps, slow, torch_device -from diffusers.utils.testing_utils import require_note_seq, require_onnxruntime +from diffusers.utils.testing_utils import enable_full_determinism, require_note_seq, require_onnxruntime from ..pipeline_params import TOKENS_TO_AUDIO_GENERATION_BATCH_PARAMS, TOKENS_TO_AUDIO_GENERATION_PARAMS from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() MIDI_FILE = "./tests/fixtures/elise_format0.mid" +# The note-seq package throws an error on import because the default installed version of Ipython +# is not compatible with python 3.8 which we run in the CI. +# https://github.com/huggingface/diffusers/actions/runs/4830121056/jobs/8605954838#step:7:98 +@unittest.skip("The note-seq package currently throws an error on import") class SpectrogramDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = SpectrogramDiffusionPipeline required_optional_params = PipelineTesterMixin.required_optional_params - { diff --git a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py index 05b72ab6a0fd..9a54c21c0a21 100644 --- a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py @@ -23,16 +23,20 @@ from diffusers import AutoencoderKL, CycleDiffusionPipeline, DDIMScheduler, UNet2DConditionModel from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class CycleDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = CycleDiffusionPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - { "negative_prompt", @@ -42,6 +46,8 @@ class CycleDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): } required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"source_prompt"}) + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -100,6 +106,7 @@ def get_dummy_components(self): def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: @@ -265,4 +272,4 @@ def test_cycle_diffusion_pipeline(self): ) image = output.images - assert np.abs(image - expected_image).max() < 1e-2 + assert np.abs(image - expected_image).max() < 2e-2 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index fcfcd84c5d48..93abe7ae58bc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -17,6 +17,7 @@ import gc import tempfile import time +import traceback import unittest import numpy as np @@ -36,22 +37,63 @@ UNet2DConditionModel, logging, ) -from diffusers.models.attention_processor import AttnProcessor +from diffusers.models.attention_processor import AttnProcessor, LoRAXFormersAttnProcessor from diffusers.utils import load_numpy, nightly, slow, torch_device -from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu +from diffusers.utils.testing_utils import ( + CaptureLogger, + enable_full_determinism, + require_torch_2, + require_torch_gpu, + run_test_in_subprocess, +) +from ...models.test_lora_layers import create_unet_lora_layers from ...models.test_models_unet_2d_condition import create_lora_layers -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_stable_diffusion_compile(in_queue, out_queue, timeout): + error = None + try: + inputs = in_queue.get(timeout=timeout) + torch_device = inputs.pop("torch_device") + seed = inputs.pop("seed") + inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed) + + sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) + sd_pipe.scheduler = DDIMScheduler.from_config(sd_pipe.scheduler.config) + sd_pipe = sd_pipe.to(torch_device) + + sd_pipe.unet.to(memory_format=torch.channels_last) + sd_pipe.unet = torch.compile(sd_pipe.unet, mode="reduce-overhead", fullgraph=True) + + sd_pipe.set_progress_bar_config(disable=None) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.38019, 0.28647, 0.27321, 0.40377, 0.38290, 0.35446, 0.39218, 0.38165, 0.42239]) + assert np.abs(image_slice - expected_slice).max() < 5e-3 + except Exception: + error = f"{traceback.format_exc()}" -torch.backends.cuda.matmul.allow_tf32 = False + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() -class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -250,6 +292,45 @@ def test_stable_diffusion_negative_prompt_embeds(self): assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_stable_diffusion_prompt_embeds_with_plain_negative_prompt_list(self): + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + negative_prompt = 3 * ["this is a negative prompt"] + inputs["negative_prompt"] = negative_prompt + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = sd_pipe(**inputs) + image_slice_1 = output.images[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs(torch_device) + inputs["negative_prompt"] = negative_prompt + prompt = 3 * [inputs.pop("prompt")] + + text_inputs = sd_pipe.tokenizer( + prompt, + padding="max_length", + max_length=sd_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_inputs = text_inputs["input_ids"].to(torch_device) + + prompt_embeds = sd_pipe.text_encoder(text_inputs)[0] + + inputs["prompt_embeds"] = prompt_embeds + + # forward + output = sd_pipe(**inputs) + image_slice_2 = output.images[0, -3:, -3:, -1] + + assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_stable_diffusion_ddim_factor_8(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -287,6 +368,56 @@ def test_stable_diffusion_pndm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + @unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda") + def test_stable_diffusion_attn_processors(self): + # disable_full_determinism() + device = "cuda" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + + # run normal sd pipe + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run xformers attention + sd_pipe.enable_xformers_memory_efficient_attention() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run attention slicing + sd_pipe.enable_attention_slicing() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run vae attention slicing + sd_pipe.enable_vae_slicing() + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run lora attention + attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) + attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} + sd_pipe.unet.set_attn_processor(attn_processors) + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # run lora xformers attention + attn_processors, _ = create_unet_lora_layers(sd_pipe.unet) + attn_processors = { + k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim) + for k, v in attn_processors.items() + } + attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()} + sd_pipe.unet.set_attn_processor(attn_processors) + image = sd_pipe(**inputs).images + assert image.shape == (1, 64, 64, 3) + + # enable_full_determinism() + def test_stable_diffusion_no_safety_checker(self): pipe = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None @@ -499,12 +630,17 @@ def test_stable_diffusion_height_width_opt(self): image_shape = output.images[0].shape[:2] assert image_shape == (192, 192) + def test_attention_slicing_forward_pass(self): + super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu class StableDiffusionPipelineSlowTests(unittest.TestCase): - def tearDown(self): - super().tearDown() + def setUp(self): gc.collect() torch.cuda.empty_cache() @@ -533,7 +669,7 @@ def test_stable_diffusion_1_1_pndm(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.43625, 0.43554, 0.36670, 0.40660, 0.39703, 0.38658, 0.43936, 0.43557, 0.40592]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 3e-3 def test_stable_diffusion_1_4_pndm(self): sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") @@ -546,7 +682,7 @@ def test_stable_diffusion_1_4_pndm(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.57400, 0.47841, 0.31625, 0.63583, 0.58306, 0.55056, 0.50825, 0.56306, 0.55748]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 3e-3 def test_stable_diffusion_ddim(self): sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) @@ -574,7 +710,7 @@ def test_stable_diffusion_lms(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.10542, 0.09620, 0.07332, 0.09015, 0.09382, 0.07597, 0.08496, 0.07806, 0.06455]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 3e-3 def test_stable_diffusion_dpm(self): sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) @@ -588,7 +724,7 @@ def test_stable_diffusion_dpm(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.03503, 0.03494, 0.01087, 0.03128, 0.02552, 0.00803, 0.00742, 0.00372, 0.00000]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 3e-3 def test_stable_diffusion_attention_slicing(self): torch.cuda.reset_peak_memory_stats() @@ -863,7 +999,17 @@ def test_stable_diffusion_textual_inversion(self): ) max_diff = np.abs(expected_image - image).max() - assert max_diff < 5e-2 + assert max_diff < 8e-1 + + @require_torch_2 + def test_stable_diffusion_compile(self): + seed = 0 + inputs = self.get_inputs(torch_device, seed=seed) + # Can't pickle a Generator object + del inputs["generator"] + inputs["torch_device"] = torch_device + inputs["seed"] = seed + run_test_in_subprocess(test_case=self, target_func=_test_stable_diffusion_compile, inputs=inputs) @slow @@ -985,7 +1131,7 @@ def test_stable_diffusion_ddim(self): "/stable_diffusion_text2img/stable_diffusion_1_4_ddim.npy" ) max_diff = np.abs(expected_image - image).max() - assert max_diff < 1e-3 + assert max_diff < 3e-3 def test_stable_diffusion_lms(self): sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(torch_device) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py index 3bfa5810428a..e16478f06112 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_image_variation.py @@ -30,19 +30,24 @@ UNet2DConditionModel, ) from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class StableDiffusionImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionImageVariationPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionImageVariationPipeline params = IMAGE_VARIATION_PARAMS batch_params = IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset([]) + # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): torch.manual_seed(0) @@ -143,6 +148,9 @@ def test_stable_diffusion_img_variation_multiple_images(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu @@ -183,7 +191,7 @@ def test_stable_diffusion_img_variation_pipeline_default(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.84491, 0.90789, 0.75708, 0.78734, 0.83485, 0.70099, 0.66938, 0.68727, 0.61379]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 6e-3 def test_stable_diffusion_img_variation_intermediate_state(self): number_of_steps = 0 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 4262114c78eb..eefbc83ce9d7 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -15,6 +15,7 @@ import gc import random +import traceback import unittest import numpy as np @@ -31,22 +32,65 @@ StableDiffusionImg2ImgPipeline, UNet2DConditionModel, ) -from diffusers.image_processor import VaeImageProcessor from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_2, + require_torch_gpu, + run_test_in_subprocess, + skip_mps, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_img2img_compile(in_queue, out_queue, timeout): + error = None + try: + inputs = in_queue.get(timeout=timeout) + torch_device = inputs.pop("torch_device") + seed = inputs.pop("seed") + inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed) + + pipe = StableDiffusionImg2ImgPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) + + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin + assert image.shape == (1, 512, 768, 3) + expected_slice = np.array([0.0593, 0.0607, 0.0851, 0.0582, 0.0636, 0.0721, 0.0751, 0.0981, 0.0781]) + assert np.abs(expected_slice - image_slice).max() < 1e-3 + except Exception: + error = f"{traceback.format_exc()}" -torch.backends.cuda.matmul.allow_tf32 = False + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() -class StableDiffusionImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -96,33 +140,20 @@ def get_dummy_components(self): } return components - def get_dummy_inputs(self, device, seed=0, input_image_type="pt", output_type="np"): + def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image / 2 + 0.5 if str(device).startswith("mps"): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - - if input_image_type == "pt": - input_image = image - elif input_image_type == "np": - input_image = image.cpu().numpy().transpose(0, 2, 3, 1) - elif input_image_type == "pil": - input_image = image.cpu().numpy().transpose(0, 2, 3, 1) - input_image = VaeImageProcessor.numpy_to_pil(input_image) - else: - raise ValueError(f"unsupported input_image_type {input_image_type}.") - - if output_type not in ["pt", "np", "pil"]: - raise ValueError(f"unsupported output_type {output_type}") - inputs = { "prompt": "A painting of a squirrel eating a burger", - "image": input_image, + "image": image, "generator": generator, "num_inference_steps": 2, "guidance_scale": 6.0, - "output_type": output_type, + "output_type": "numpy", } return inputs @@ -130,7 +161,6 @@ def test_stable_diffusion_img2img_default_case(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -147,7 +177,6 @@ def test_stable_diffusion_img2img_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -166,7 +195,6 @@ def test_stable_diffusion_img2img_multiple_init_images(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -188,7 +216,6 @@ def test_stable_diffusion_img2img_k_lms(self): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe.image_processor = VaeImageProcessor(vae_scale_factor=sd_pipe.vae_scale_factor, do_normalize=False) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) @@ -215,37 +242,10 @@ def test_save_load_optional_components(self): @skip_mps def test_attention_slicing_forward_pass(self): - return super().test_attention_slicing_forward_pass() - - @skip_mps - def test_pt_np_pil_outputs_equivalent(self): - device = "cpu" - components = self.get_dummy_components() - sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) + return super().test_attention_slicing_forward_pass(expected_max_diff=5e-3) - output_pt = sd_pipe(**self.get_dummy_inputs(device, output_type="pt"))[0] - output_np = sd_pipe(**self.get_dummy_inputs(device, output_type="np"))[0] - output_pil = sd_pipe(**self.get_dummy_inputs(device, output_type="pil"))[0] - - assert np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() <= 1e-4 - assert np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() <= 1e-4 - - @skip_mps - def test_image_types_consistent(self): - device = "cpu" - components = self.get_dummy_components() - sd_pipe = StableDiffusionImg2ImgPipeline(**components) - sd_pipe = sd_pipe.to(device) - sd_pipe.set_progress_bar_config(disable=None) - - output_pt = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pt"))[0] - output_np = sd_pipe(**self.get_dummy_inputs(device, input_image_type="np"))[0] - output_pil = sd_pipe(**self.get_dummy_inputs(device, input_image_type="pil"))[0] - - assert np.abs(output_pt - output_np).max() <= 1e-4 - assert np.abs(output_pil - output_np).max() <= 1e-2 + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) @slow @@ -495,6 +495,16 @@ def test_img2img_safety_checker_works(self): assert out.nsfw_content_detected[0], f"Safety checker should work for prompt: {inputs['prompt']}" assert np.abs(out.images[0]).sum() < 1e-5 # should be all zeros + @require_torch_2 + def test_img2img_compile(self): + seed = 0 + inputs = self.get_inputs(torch_device, seed=seed) + # Can't pickle a Generator object + del inputs["generator"] + inputs["torch_device"] = torch_device + inputs["seed"] = seed + run_test_in_subprocess(test_case=self, target_func=_test_img2img_compile, inputs=inputs) + @nightly @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 290d9b0a9134..f761f245883f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -15,6 +15,7 @@ import gc import random +import traceback import unittest import numpy as np @@ -32,19 +33,62 @@ ) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint import prepare_mask_and_masked_image from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import ( + enable_full_determinism, + require_torch_2, + require_torch_gpu, + run_test_in_subprocess, +) +from ...models.test_models_unet_2d_condition import create_lora_layers from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_inpaint_compile(in_queue, out_queue, timeout): + error = None + try: + inputs = in_queue.get(timeout=timeout) + torch_device = inputs.pop("torch_device") + seed = inputs.pop("seed") + inputs["generator"] = torch.Generator(device=torch_device).manual_seed(seed) + + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = PNDMScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.unet.to(memory_format=torch.channels_last) + pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) -torch.backends.cuda.matmul.allow_tf32 = False + image = pipe(**inputs).images + image_slice = image[0, 253:256, 253:256, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272]) + + assert np.abs(expected_slice - image_slice).max() < 3e-3 + except Exception: + error = f"{traceback.format_exc()}" + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() -class StableDiffusionInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + +class StableDiffusionInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): torch.manual_seed(0) @@ -151,6 +195,133 @@ def test_stable_diffusion_inpaint_image_tensor(self): assert out_pil.shape == (1, 64, 64, 3) assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2 + def test_stable_diffusion_inpaint_lora(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(torch_device) + sd_pipe.set_progress_bar_config(disable=None) + + # forward 1 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + # set lora layers + lora_attn_procs = create_lora_layers(sd_pipe.unet) + sd_pipe.unet.set_attn_processor(lora_attn_procs) + sd_pipe = sd_pipe.to(torch_device) + + # forward 2 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0}) + image = output.images + image_slice_1 = image[0, -3:, -3:, -1] + + # forward 3 + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5}) + image = output.images + image_slice_2 = image[0, -3:, -3:, -1] + + assert np.abs(image_slice - image_slice_1).max() < 1e-2 + assert np.abs(image_slice - image_slice_2).max() > 1e-2 + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + def test_stable_diffusion_inpaint_strength_zero_test(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + + # check that the pipeline raises value error when num_inference_steps is < 1 + inputs["strength"] = 0.01 + with self.assertRaises(ValueError): + sd_pipe(**inputs).images + + +class StableDiffusionSimpleInpaintPipelineFastTests(StableDiffusionInpaintPipelineFastTests): + pipeline_class = StableDiffusionInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset([]) + # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + ) + scheduler = PNDMScheduler(skip_prk_steps=True) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + return components + + def test_stable_diffusion_inpaint(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionInpaintPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4925, 0.4967, 0.4100, 0.5234, 0.5322, 0.4532, 0.5805, 0.5877, 0.4151]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + + @unittest.skip("skipped here because area stays unchanged due to mask") + def test_stable_diffusion_inpaint_lora(self): + ... + @slow @require_torch_gpu @@ -199,7 +370,7 @@ def test_stable_diffusion_inpaint_ddim(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.0427, 0.0460, 0.0483, 0.0460, 0.0584, 0.0521, 0.1549, 0.1695, 0.1794]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 6e-4 def test_stable_diffusion_inpaint_fp16(self): pipe = StableDiffusionInpaintPipeline.from_pretrained( @@ -234,7 +405,7 @@ def test_stable_diffusion_inpaint_pndm(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 5e-3 def test_stable_diffusion_inpaint_k_lms(self): pipe = StableDiffusionInpaintPipeline.from_pretrained( @@ -252,7 +423,7 @@ def test_stable_diffusion_inpaint_k_lms(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.9314, 0.7575, 0.9432, 0.8885, 0.9028, 0.7298, 0.9811, 0.9667, 0.7633]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 6e-3 def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self): torch.cuda.empty_cache() @@ -274,6 +445,71 @@ def test_stable_diffusion_inpaint_with_sequential_cpu_offloading(self): # make sure that less than 2.2 GB is allocated assert mem_bytes < 2.2 * 10**9 + @require_torch_2 + def test_inpaint_compile(self): + seed = 0 + inputs = self.get_inputs(torch_device, seed=seed) + # Can't pickle a Generator object + del inputs["generator"] + inputs["torch_device"] = torch_device + inputs["seed"] = seed + run_test_in_subprocess(test_case=self, target_func=_test_inpaint_compile, inputs=inputs) + + def test_stable_diffusion_inpaint_pil_input_resolution_test(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input image to a random size (one that would cause a tensor mismatch error) + inputs["image"] = inputs["image"].resize((127, 127)) + inputs["mask_image"] = inputs["mask_image"].resize((127, 127)) + inputs["height"] = 128 + inputs["width"] = 128 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, inputs["height"], inputs["width"], 3) + + def test_stable_diffusion_inpaint_strength_test(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input strength + inputs["strength"] = 0.75 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, 253:256, 253:256, -1].flatten() + expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943]) + assert np.abs(expected_slice - image_slice).max() < 3e-3 + + def test_stable_diffusion_simple_inpaint_ddim(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + image = pipe(**inputs).images + + image_slice = image[0, 253:256, 253:256, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.5157, 0.6858, 0.6873, 0.4619, 0.6416, 0.6898, 0.3702, 0.5960, 0.6935]) + + assert np.abs(expected_slice - image_slice).max() < 6e-4 + @nightly @require_torch_gpu @@ -371,168 +607,526 @@ def test_inpaint_dpm(self): class StableDiffusionInpaintingPrepareMaskAndMaskedImageTests(unittest.TestCase): def test_pil_inputs(self): - im = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + height, width = 32, 32 + im = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) im = Image.fromarray(im) - mask = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5 mask = Image.fromarray((mask * 255).astype(np.uint8)) - t_mask, t_masked = prepare_mask_and_masked_image(im, mask) + t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True) self.assertTrue(isinstance(t_mask, torch.Tensor)) self.assertTrue(isinstance(t_masked, torch.Tensor)) + self.assertTrue(isinstance(t_image, torch.Tensor)) self.assertEqual(t_mask.ndim, 4) self.assertEqual(t_masked.ndim, 4) + self.assertEqual(t_image.ndim, 4) - self.assertEqual(t_mask.shape, (1, 1, 32, 32)) - self.assertEqual(t_masked.shape, (1, 3, 32, 32)) + self.assertEqual(t_mask.shape, (1, 1, height, width)) + self.assertEqual(t_masked.shape, (1, 3, height, width)) + self.assertEqual(t_image.shape, (1, 3, height, width)) self.assertTrue(t_mask.dtype == torch.float32) self.assertTrue(t_masked.dtype == torch.float32) + self.assertTrue(t_image.dtype == torch.float32) self.assertTrue(t_mask.min() >= 0.0) self.assertTrue(t_mask.max() <= 1.0) self.assertTrue(t_masked.min() >= -1.0) self.assertTrue(t_masked.min() <= 1.0) + self.assertTrue(t_image.min() >= -1.0) + self.assertTrue(t_image.min() >= -1.0) self.assertTrue(t_mask.sum() > 0.0) def test_np_inputs(self): - im_np = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8) + height, width = 32, 32 + + im_np = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) im_pil = Image.fromarray(im_np) - mask_np = np.random.randint(0, 255, (32, 32), dtype=np.uint8) > 127.5 + mask_np = ( + np.random.randint( + 0, + 255, + ( + height, + width, + ), + dtype=np.uint8, + ) + > 127.5 + ) mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) - t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) + t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image( + im_pil, mask_pil, height, width, return_image=True + ) self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all()) + self.assertTrue((t_image_np == t_image_pil).all()) def test_torch_3D_2D_inputs(self): - im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint( + 0, + 255, + ( + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_3D_3D_inputs(self): - im_tensor = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint( + 0, + 255, + ( + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_2D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint( + 0, + 255, + ( + 1, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_3D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint( + 0, + 255, + ( + 1, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_4D_inputs(self): - im_tensor = torch.randint(0, 255, (1, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (1, 1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint( + 0, + 255, + ( + 1, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 1, + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0][0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np) + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_batch_4D_3D(self): - im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint( + 0, + 255, + ( + 2, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 2, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy() for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) + t_image_np = torch.cat([n[2] for n in nps]) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_batch_4D_4D(self): - im_tensor = torch.randint(0, 255, (2, 3, 32, 32), dtype=torch.uint8) - mask_tensor = torch.randint(0, 255, (2, 1, 32, 32), dtype=torch.uint8) > 127.5 + height, width = 32, 32 + + im_tensor = torch.randint( + 0, + 255, + ( + 2, + 3, + height, + width, + ), + dtype=torch.uint8, + ) + mask_tensor = ( + torch.randint( + 0, + 255, + ( + 2, + 1, + height, + width, + ), + dtype=torch.uint8, + ) + > 127.5 + ) im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy()[0] for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image(im_tensor / 127.5 - 1, mask_tensor) - nps = [prepare_mask_and_masked_image(i, m) for i, m in zip(im_nps, mask_nps)] + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True + ) + nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) + t_image_np = torch.cat([n[2] for n in nps]) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_shape_mismatch(self): + height, width = 32, 32 + # test height and width with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(3, 32, 32), torch.randn(64, 64)) + prepare_mask_and_masked_image( + torch.randn( + 3, + height, + width, + ), + torch.randn(64, 64), + height, + width, + return_image=True, + ) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 64, 64)) + prepare_mask_and_masked_image( + torch.randn( + 2, + 3, + height, + width, + ), + torch.randn(4, 64, 64), + height, + width, + return_image=True, + ) # test batch dim with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.randn(2, 3, 32, 32), torch.randn(4, 1, 64, 64)) + prepare_mask_and_masked_image( + torch.randn( + 2, + 3, + height, + width, + ), + torch.randn(4, 1, 64, 64), + height, + width, + return_image=True, + ) def test_type_mismatch(self): + height, width = 32, 32 + # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.rand(3, 32, 32).numpy()) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ), + torch.rand( + 3, + height, + width, + ).numpy(), + height, + width, + return_image=True, + ) # test tensors-only with self.assertRaises(TypeError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32).numpy(), torch.rand(3, 32, 32)) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ).numpy(), + torch.rand( + 3, + height, + width, + ), + height, + width, + return_image=True, + ) def test_channels_first(self): + height, width = 32, 32 + # test channels first for 3D tensors with self.assertRaises(AssertionError): - prepare_mask_and_masked_image(torch.rand(32, 32, 3), torch.rand(3, 32, 32)) + prepare_mask_and_masked_image( + torch.rand(height, width, 3), + torch.rand( + 3, + height, + width, + ), + height, + width, + return_image=True, + ) def test_tensor_range(self): + height, width = 32, 32 + # test im <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, 32, 32) * 2, torch.rand(32, 32)) + prepare_mask_and_masked_image( + torch.ones( + 3, + height, + width, + ) + * 2, + torch.rand( + height, + width, + ), + height, + width, + return_image=True, + ) # test im >= -1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.ones(3, 32, 32) * (-2), torch.rand(32, 32)) + prepare_mask_and_masked_image( + torch.ones( + 3, + height, + width, + ) + * (-2), + torch.rand( + height, + width, + ), + height, + width, + return_image=True, + ) # test mask <= 1 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * 2) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ), + torch.ones( + height, + width, + ) + * 2, + height, + width, + return_image=True, + ) # test mask >= 0 with self.assertRaises(ValueError): - prepare_mask_and_masked_image(torch.rand(3, 32, 32), torch.ones(32, 32) * -1) + prepare_mask_and_masked_image( + torch.rand( + 3, + height, + width, + ), + torch.ones( + height, + width, + ) + * -1, + height, + width, + return_image=True, + ) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py index f56fa31a9601..fa00a0d201af 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py @@ -34,10 +34,10 @@ VQModel, ) from diffusers.utils import floats_tensor, load_image, nightly, slow, torch_device -from diffusers.utils.testing_utils import load_numpy, preprocess_image, require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, preprocess_image, require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class StableDiffusionInpaintLegacyPipelineFastTests(unittest.TestCase): @@ -435,7 +435,7 @@ def test_stable_diffusion_inpaint_legacy_pndm(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.5665, 0.6117, 0.6430, 0.4057, 0.4594, 0.5658, 0.1596, 0.3106, 0.4305]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 3e-3 def test_stable_diffusion_inpaint_legacy_batched(self): pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained( @@ -468,8 +468,8 @@ def test_stable_diffusion_inpaint_legacy_batched(self): [0.3592432, 0.4233033, 0.3914635, 0.31014425, 0.3702293, 0.39412856, 0.17526966, 0.2642669, 0.37480092] ) - assert np.abs(expected_slice_0 - image_slice_0).max() < 1e-4 - assert np.abs(expected_slice_1 - image_slice_1).max() < 1e-4 + assert np.abs(expected_slice_0 - image_slice_0).max() < 3e-3 + assert np.abs(expected_slice_1 - image_slice_1).max() < 3e-3 def test_stable_diffusion_inpaint_legacy_k_lms(self): pipe = StableDiffusionInpaintPipelineLegacy.from_pretrained( @@ -487,7 +487,7 @@ def test_stable_diffusion_inpaint_legacy_k_lms(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.4534, 0.4467, 0.4329, 0.4329, 0.4339, 0.4220, 0.4244, 0.4332, 0.4426]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 3e-3 def test_stable_diffusion_inpaint_legacy_intermediate_state(self): number_of_steps = 0 diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py index 8915f524d972..691427b1c6eb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py @@ -31,20 +31,29 @@ StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel, ) +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import floats_tensor, load_image, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class StableDiffusionInstructPix2PixPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionInstructPix2PixPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -158,6 +167,7 @@ def test_stable_diffusion_pix2pix_multiple_init_images(self): image = np.array(inputs["image"]).astype(np.float32) / 255.0 image = torch.from_numpy(image).unsqueeze(0).to(device) + image = image / 2 + 0.5 image = image.permute(0, 3, 1, 2) inputs["image"] = image.repeat(2, 1, 1, 1) @@ -191,6 +201,31 @@ def test_stable_diffusion_pix2pix_euler(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + # Overwrite the default test_latents_inputs because pix2pix encode the image differently + def test_latents_input(self): + components = self.get_dummy_components() + pipe = StableDiffusionInstructPix2PixPipeline(**components) + pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0] + + vae = components["vae"] + inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt") + + for image_param in self.image_latents_params: + if image_param in inputs.keys(): + inputs[image_param] = vae.encode(inputs[image_param]).latent_dist.mode() + + out_latents_inputs = pipe(**inputs)[0] + + max_diff = np.abs(out - out_latents_inputs).max() + self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image") + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py index 546b1d21252c..4eccb871a0cb 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_k_diffusion.py @@ -21,10 +21,10 @@ from diffusers import StableDiffusionKDiffusionPipeline from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() @slow diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py index bafad63ec2db..f47a70c4ece8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_model_editing.py @@ -29,20 +29,22 @@ UNet2DConditionModel, ) from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() @skip_mps -class StableDiffusionModelEditingPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionModelEditingPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionModelEditingPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -174,6 +176,12 @@ def test_stable_diffusion_model_editing_pndm(self): with self.assertRaises(ValueError): _ = sd_pipe(**inputs).images + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=5e-3) + + def test_attention_slicing_forward_pass(self): + super().test_attention_slicing_forward_pass(expected_max_diff=5e-3) + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py index 3ead4fe55bab..32541c980a15 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_panorama.py @@ -30,20 +30,22 @@ UNet2DConditionModel, ) from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() @skip_mps -class StableDiffusionPanoramaPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPanoramaPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -129,7 +131,7 @@ def test_inference_batch_consistent(self): # override to speed the overall test timing up. def test_inference_batch_single_identical(self): - super().test_inference_batch_single_identical(batch_size=2) + super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=3.25e-3) def test_stable_diffusion_panorama_negative_prompt(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -150,6 +152,24 @@ def test_stable_diffusion_panorama_negative_prompt(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_panorama_views_batch(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionPanoramaPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = sd_pipe(**inputs, view_batch_size=2) + image = output.images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_panorama_euler(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -173,15 +193,22 @@ def test_stable_diffusion_panorama_euler(self): def test_stable_diffusion_panorama_pndm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() - components["scheduler"] = PNDMScheduler() + components["scheduler"] = PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ) sd_pipe = StableDiffusionPanoramaPipeline(**components) sd_pipe = sd_pipe.to(device) sd_pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - # the pipeline does not expect pndm so test if it raises error. - with self.assertRaises(ValueError): - _ = sd_pipe(**inputs).images + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + + expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @slow diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py index 661926daaa3e..6f41d2c43c8e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py @@ -32,21 +32,28 @@ StableDiffusionPix2PixZeroPipeline, UNet2DConditionModel, ) +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import floats_tensor, load_numpy, slow, torch_device -from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, load_image, load_pt, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import ( + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() @skip_mps -class StableDiffusionPix2PixZeroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionPix2PixZeroPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPix2PixZeroPipeline - params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"image"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS @classmethod def setUpClass(cls): @@ -127,6 +134,7 @@ def get_dummy_inputs(self, device, seed=0): def get_dummy_inversion_inputs(self, device, seed=0): dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device) + dummy_image = dummy_image / 2 + 0.5 generator = torch.manual_seed(seed) inputs = { @@ -142,6 +150,24 @@ def get_dummy_inversion_inputs(self, device, seed=0): } return inputs + def get_dummy_inversion_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"): + inputs = self.get_dummy_inversion_inputs(device, seed) + + if input_image_type == "pt": + image = inputs["image"] + elif input_image_type == "np": + image = VaeImageProcessor.pt_to_numpy(inputs["image"]) + elif input_image_type == "pil": + image = VaeImageProcessor.pt_to_numpy(inputs["image"]) + image = VaeImageProcessor.numpy_to_pil(image) + else: + raise ValueError(f"unsupported input_image_type {input_image_type}") + + inputs["image"] = image + inputs["output_type"] = output_type + + return inputs + def test_save_load_optional_components(self): if not hasattr(self.pipeline_class, "_optional_components"): return @@ -278,6 +304,41 @@ def test_stable_diffusion_pix2pix_zero_ddpm(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_outputs_equivalent(self): + device = torch_device + components = self.get_dummy_components() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + output_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pt")).images + output_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="np")).images + output_pil = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, output_type="pil")).images + + max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() + self.assertLess(max_diff, 1e-4, "`output_type=='pt'` generate different results from `output_type=='np'`") + + max_diff = np.abs(np.array(output_pil[0]) - (output_np[0] * 255).round()).max() + self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`") + + def test_stable_diffusion_pix2pix_zero_inversion_pt_np_pil_inputs_equivalent(self): + device = torch_device + components = self.get_dummy_components() + sd_pipe = StableDiffusionPix2PixZeroPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + out_input_pt = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="pt")).images + out_input_np = sd_pipe.invert(**self.get_dummy_inversion_inputs_by_type(device, input_image_type="np")).images + out_input_pil = sd_pipe.invert( + **self.get_dummy_inversion_inputs_by_type(device, input_image_type="pil") + ).images + + max_diff = np.abs(out_input_pt - out_input_np).max() + self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`") + + assert_mean_pixel_difference(out_input_pil, out_input_np, expected_max_diff=1) + # Non-determinism caused by the scheduler optimizing the latent inputs during inference @unittest.skip("non-deterministic pipeline") def test_inference_batch_single_identical(self): diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py index 73859bdbf7d8..91719ce7676f 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_sag.py @@ -27,19 +27,21 @@ UNet2DConditionModel, ) from diffusers.utils import slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class StableDiffusionSAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionSAGPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionSAGPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS test_cpu_offload = False def get_dummy_components(self): @@ -111,6 +113,9 @@ def get_dummy_inputs(self, device, seed=0): } return inputs + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py index 623dbde99469..33cc7f638ec2 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion.py @@ -33,19 +33,21 @@ logging, ) from diffusers.utils import load_numpy, nightly, slow, torch_device -from diffusers.utils.testing_utils import CaptureLogger, require_torch_gpu +from diffusers.utils.testing_utils import CaptureLogger, enable_full_determinism, require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class StableDiffusion2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusion2PipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -206,6 +208,27 @@ def test_stable_diffusion_k_euler(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_unflawed(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + components["scheduler"] = DDIMScheduler.from_config( + components["scheduler"].config, timestep_spacing="trailing" + ) + sd_pipe = StableDiffusionPipeline(**components) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["guidance_rescale"] = 0.7 + inputs["num_inference_steps"] = 10 + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 64, 64, 3) + expected_slice = np.array([0.4736, 0.5405, 0.4705, 0.4955, 0.5675, 0.4812, 0.5310, 0.4967, 0.5064]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_long_prompt(self): components = self.get_dummy_components() components["scheduler"] = LMSDiscreteScheduler.from_config(components["scheduler"].config) @@ -244,6 +267,12 @@ def test_stable_diffusion_long_prompt(self): assert cap_logger.out.count("@") == 25 assert cap_logger_3.out == "" + def test_attention_slicing_forward_pass(self): + super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu @@ -278,7 +307,7 @@ def test_stable_diffusion_default_ddim(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 7e-3 def test_stable_diffusion_pndm(self): pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base") @@ -292,7 +321,7 @@ def test_stable_diffusion_pndm(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.49493, 0.47896, 0.40798, 0.54214, 0.53212, 0.48202, 0.47656, 0.46329, 0.48506]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 7e-3 def test_stable_diffusion_k_lms(self): pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base") @@ -306,7 +335,7 @@ def test_stable_diffusion_k_lms(self): assert image.shape == (1, 512, 512, 3) expected_slice = np.array([0.10440, 0.13115, 0.11100, 0.10141, 0.11440, 0.07215, 0.11332, 0.09693, 0.10006]) - assert np.abs(image_slice - expected_slice).max() < 1e-4 + assert np.abs(image_slice - expected_slice).max() < 3e-3 def test_stable_diffusion_attention_slicing(self): torch.cuda.reset_peak_memory_stats() diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py index 846e251f3ce2..304ddacd2c36 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py @@ -29,16 +29,36 @@ from diffusers.utils import load_numpy, skip_mps, slow from diffusers.utils.testing_utils import require_torch_gpu -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False @skip_mps -class StableDiffusionAttendAndExcitePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionAttendAndExcitePipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionAttendAndExcitePipeline test_attention_slicing = False params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"}) + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + # Attend and excite requires being able to run a backward pass at + # inference time. There's no deterministic backward operator for pad + + @classmethod + def setUpClass(cls): + super().setUpClass() + torch.use_deterministic_algorithms(False) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + torch.use_deterministic_algorithms(True) def get_dummy_components(self): torch.manual_seed(0) @@ -138,17 +158,45 @@ def test_inference(self): max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) + def test_cpu_offload_forward_pass(self): + super().test_cpu_offload_forward_pass(expected_max_diff=5e-4) + def test_inference_batch_consistent(self): # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches self._test_inference_batch_consistent(batch_sizes=[1, 2]) def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(batch_size=2) + self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=7e-4) + + def test_dict_tuple_outputs_equivalent(self): + super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-3) + + def test_pt_np_pil_outputs_equivalent(self): + super().test_pt_np_pil_outputs_equivalent(expected_max_diff=5e-4) + + def test_save_load_local(self): + super().test_save_load_local(expected_max_difference=5e-4) + + def test_save_load_optional_components(self): + super().test_save_load_optional_components(expected_max_difference=4e-4) @require_torch_gpu @slow class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase): + # Attend and excite requires being able to run a backward pass at + # inference time. There's no deterministic backward operator for pad + + @classmethod + def setUpClass(cls): + super().setUpClass() + torch.use_deterministic_algorithms(False) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + torch.use_deterministic_algorithms(True) + def tearDown(self): super().tearDown() gc.collect() diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 7a5e02a42af4..f393967c7de4 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -49,22 +49,29 @@ slow, torch_device, ) -from diffusers.utils.testing_utils import require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps -from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, + TEXT_GUIDED_IMAGE_VARIATION_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, +) +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() @skip_mps -class StableDiffusionDepth2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionDepth2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionDepth2ImgPipeline test_save_load_optional_components = False params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS def get_dummy_components(self): torch.manual_seed(0) @@ -132,7 +139,7 @@ def get_dummy_components(self): backbone_config=backbone_config, backbone_featmap_shape=[1, 384, 24, 24], ) - depth_estimator = DPTForDepthEstimation(depth_estimator_config) + depth_estimator = DPTForDepthEstimation(depth_estimator_config).eval() feature_extractor = DPTFeatureExtractor.from_pretrained( "hf-internal-testing/tiny-random-DPTForDepthEstimation" ) @@ -359,6 +366,9 @@ def test_stable_diffusion_depth2img_pil(self): def test_attention_slicing_forward_pass(self): return super().test_attention_slicing_forward_pass() + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=7e-3) + @slow @require_torch_gpu @@ -399,7 +409,7 @@ def test_stable_diffusion_depth2img_pipeline_default(self): assert image.shape == (1, 480, 640, 3) expected_slice = np.array([0.5435, 0.4992, 0.3783, 0.4411, 0.5842, 0.4654, 0.3786, 0.5077, 0.4655]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 6e-1 def test_stable_diffusion_depth2img_pipeline_k_lms(self): pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( @@ -417,7 +427,7 @@ def test_stable_diffusion_depth2img_pipeline_k_lms(self): assert image.shape == (1, 480, 640, 3) expected_slice = np.array([0.6363, 0.6274, 0.6309, 0.6370, 0.6226, 0.6286, 0.6213, 0.6453, 0.6306]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 8e-4 def test_stable_diffusion_depth2img_pipeline_ddim(self): pipe = StableDiffusionDepth2ImgPipeline.from_pretrained( @@ -435,7 +445,7 @@ def test_stable_diffusion_depth2img_pipeline_ddim(self): assert image.shape == (1, 480, 640, 3) expected_slice = np.array([0.6424, 0.6524, 0.6249, 0.6041, 0.6634, 0.6420, 0.6522, 0.6555, 0.6436]) - assert np.abs(expected_slice - image_slice).max() < 1e-4 + assert np.abs(expected_slice - image_slice).max() < 5e-4 def test_stable_diffusion_depth2img_intermediate_state(self): number_of_steps = 0 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py new file mode 100644 index 000000000000..1de80d60d8e8 --- /dev/null +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py @@ -0,0 +1,399 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import random +import tempfile +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDIMInverseScheduler, + DDIMScheduler, + DPMSolverMultistepInverseScheduler, + DPMSolverMultistepScheduler, + StableDiffusionDiffEditPipeline, + UNet2DConditionModel, +) +from diffusers.utils import load_image, slow +from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device + +from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin + + +enable_full_determinism() + + +class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableDiffusionDiffEditPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"} + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - {"image"} | {"image_latents"} + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + cross_attention_dim=32, + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + ) + scheduler = DDIMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_one=False, + ) + inverse_scheduler = DDIMInverseScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + clip_sample=False, + set_alpha_to_zero=False, + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + sample_size=128, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=512, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "unet": unet, + "scheduler": scheduler, + "inverse_scheduler": inverse_scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "safety_checker": None, + "feature_extractor": None, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + mask = floats_tensor((1, 16, 16), rng=random.Random(seed)).to(device) + latents = floats_tensor((1, 2, 4, 16, 16), rng=random.Random(seed)).to(device) + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "a dog and a newt", + "mask_image": mask, + "image_latents": latents, + "generator": generator, + "num_inference_steps": 2, + "inpaint_strength": 1.0, + "guidance_scale": 6.0, + "output_type": "numpy", + } + + return inputs + + def get_dummy_mask_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "image": image, + "source_prompt": "a cat and a frog", + "target_prompt": "a dog and a newt", + "generator": generator, + "num_inference_steps": 2, + "num_maps_per_mask": 2, + "mask_encode_strength": 1.0, + "guidance_scale": 6.0, + "output_type": "numpy", + } + + return inputs + + def get_dummy_inversion_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "image": image, + "prompt": "a cat and a frog", + "generator": generator, + "num_inference_steps": 2, + "inpaint_strength": 1.0, + "guidance_scale": 6.0, + "decode_latents": True, + "output_type": "numpy", + } + return inputs + + def test_save_load_optional_components(self): + if not hasattr(self.pipeline_class, "_optional_components"): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # set all optional components to None and update pipeline config accordingly + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components}) + + inputs = self.get_dummy_inputs(torch_device) + output = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(torch_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(output - output_loaded).max() + self.assertLess(max_diff, 1e-4) + + def test_mask(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_mask_inputs(device) + mask = pipe.generate_mask(**inputs) + mask_slice = mask[0, -3:, -3:] + + self.assertEqual(mask.shape, (1, 16, 16)) + expected_slice = np.array([0] * 9) + max_diff = np.abs(mask_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + self.assertEqual(mask[0, -3, -4], 0) + + def test_inversion(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inversion_inputs(device) + image = pipe.invert(**inputs).images + image_slice = image[0, -1, -3:, -3:] + + self.assertEqual(image.shape, (2, 32, 32, 3)) + expected_slice = np.array( + [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799], + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=5e-3) + + def test_inversion_dpm(self): + device = "cpu" + + components = self.get_dummy_components() + + scheduler_args = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear"} + components["scheduler"] = DPMSolverMultistepScheduler(**scheduler_args) + components["inverse_scheduler"] = DPMSolverMultistepInverseScheduler(**scheduler_args) + + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inversion_inputs(device) + image = pipe.invert(**inputs).images + image_slice = image[0, -1, -3:, -3:] + + self.assertEqual(image.shape, (2, 32, 32, 3)) + expected_slice = np.array( + [0.5150, 0.5134, 0.5043, 0.5376, 0.4694, 0.51050, 0.5015, 0.4407, 0.4799], + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + +@require_torch_gpu +@slow +class StableDiffusionDiffEditPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + @classmethod + def setUpClass(cls): + raw_image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png" + ) + + raw_image = raw_image.convert("RGB").resize((768, 768)) + + cls.raw_image = raw_image + + def test_stable_diffusion_diffedit_full(self): + generator = torch.manual_seed(0) + + pipe = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + source_prompt = "a bowl of fruit" + target_prompt = "a bowl of pears" + + mask_image = pipe.generate_mask( + image=self.raw_image, + source_prompt=source_prompt, + target_prompt=target_prompt, + generator=generator, + ) + + inv_latents = pipe.invert( + prompt=source_prompt, image=self.raw_image, inpaint_strength=0.7, generator=generator + ).latents + + image = pipe( + prompt=target_prompt, + mask_image=mask_image, + image_latents=inv_latents, + generator=generator, + negative_prompt=source_prompt, + inpaint_strength=0.7, + output_type="numpy", + ).images[0] + + expected_image = ( + np.array( + load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/diffedit/pears.png" + ).resize((768, 768)) + ) + / 255 + ) + assert np.abs((expected_image - image).max()) < 5e-1 + + def test_stable_diffusion_diffedit_dpm(self): + generator = torch.manual_seed(0) + + pipe = StableDiffusionDiffEditPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16 + ) + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.inverse_scheduler = DPMSolverMultistepInverseScheduler.from_config(pipe.scheduler.config) + pipe.enable_model_cpu_offload() + pipe.set_progress_bar_config(disable=None) + + source_prompt = "a bowl of fruit" + target_prompt = "a bowl of pears" + + mask_image = pipe.generate_mask( + image=self.raw_image, + source_prompt=source_prompt, + target_prompt=target_prompt, + generator=generator, + ) + + inv_latents = pipe.invert( + prompt=source_prompt, + image=self.raw_image, + inpaint_strength=0.7, + generator=generator, + num_inference_steps=25, + ).latents + + image = pipe( + prompt=target_prompt, + mask_image=mask_image, + image_latents=inv_latents, + generator=generator, + negative_prompt=source_prompt, + inpaint_strength=0.7, + num_inference_steps=25, + output_type="numpy", + ).images[0] + + expected_image = ( + np.array( + load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + "/diffedit/pears.png" + ).resize((768, 768)) + ) + / 255 + ) + assert np.abs((expected_image - image).max()) < 5e-1 diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py index 2fa8b9045f43..37c254f367f3 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_inpaint.py @@ -24,19 +24,23 @@ from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel from diffusers.utils import floats_tensor, load_image, load_numpy, torch_device -from diffusers.utils.testing_utils import require_torch_gpu, slow +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class StableDiffusion2InpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusion2InpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionInpaintPipeline params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): torch.manual_seed(0) @@ -130,6 +134,9 @@ def test_stable_diffusion_inpaint(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=3e-3) + @slow @require_torch_gpu @@ -172,7 +179,7 @@ def test_stable_diffusion_inpaint_pipeline(self): image = output.images[0] assert image.shape == (512, 512, 3) - assert np.abs(expected_image - image).max() < 1e-3 + assert np.abs(expected_image - image).max() < 9e-3 def test_stable_diffusion_inpaint_pipeline_fp16(self): init_image = load_image( diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py index aff1c1cdbde9..b94aaca4258a 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_latent_upscale.py @@ -29,16 +29,16 @@ UNet2DConditionModel, ) from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() -class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionLatentUpscalePipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionLatentUpscalePipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - { "height", @@ -49,6 +49,11 @@ class StableDiffusionLatentUpscalePipelineFastTests(PipelineTesterMixin, unittes } required_optional_params = PipelineTesterMixin.required_optional_params - {"num_images_per_prompt"} batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) + test_cpu_offload = True @property @@ -159,8 +164,26 @@ def test_inference(self): max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) + def test_attention_slicing_forward_pass(self): + super().test_attention_slicing_forward_pass(expected_max_diff=7e-3) + + def test_cpu_offload_forward_pass(self): + super().test_cpu_offload_forward_pass(expected_max_diff=3e-3) + + def test_dict_tuple_outputs_equivalent(self): + super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-3) + def test_inference_batch_single_identical(self): - self._test_inference_batch_single_identical(relax_max_difference=False) + super().test_inference_batch_single_identical(expected_max_diff=7e-3) + + def test_pt_np_pil_outputs_equivalent(self): + super().test_pt_np_pil_outputs_equivalent(expected_max_diff=3e-3) + + def test_save_load_local(self): + super().test_save_load_local(expected_max_difference=3e-3) + + def test_save_load_optional_components(self): + super().test_save_load_optional_components(expected_max_difference=3e-3) @require_torch_gpu diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py index 747809a4fb2e..7100e5023a5d 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -24,10 +24,10 @@ from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, StableDiffusionUpscalePipeline, UNet2DConditionModel from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class StableDiffusionUpscalePipelineFastTests(unittest.TestCase): diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py index 083640a87ba9..21862ba6a216 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_v_pred.py @@ -30,10 +30,10 @@ UNet2DConditionModel, ) from diffusers.utils import load_numpy, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() class StableDiffusion2VPredictionPipelineFastTests(unittest.TestCase): @@ -382,7 +382,30 @@ def test_stable_diffusion_text2img_pipeline_v_pred_default(self): image = output.images[0] assert image.shape == (768, 768, 3) - assert np.abs(expected_image - image).max() < 7.5e-2 + assert np.abs(expected_image - image).max() < 9e-1 + + def test_stable_diffusion_text2img_pipeline_unflawed(self): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/" + "sd2-text2img/lion_galaxy.npy" + ) + + pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1") + pipe.scheduler = DDIMScheduler.from_config( + pipe.scheduler.config, timestep_scaling="trailing", rescale_betas_zero_snr=True + ) + pipe.to(torch_device) + pipe.enable_attention_slicing() + pipe.set_progress_bar_config(disable=None) + + prompt = "A lion in galaxies, spirals, nebulae, stars, smoke, iridescent, intricate detail, octane render, 8k" + + generator = torch.manual_seed(0) + output = pipe(prompt=prompt, guidance_scale=7.5, guidance_rescale=0.7, generator=generator, output_type="np") + image = output.images[0] + + assert image.shape == (768, 768, 3) + assert np.abs(expected_image - image).max() < 5e-1 def test_stable_diffusion_text2img_pipeline_v_pred_fp16(self): expected_image = load_numpy( diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py index c614fa48055e..09e31aacfbc9 100644 --- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py +++ b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py @@ -28,9 +28,6 @@ from diffusers.utils.testing_utils import require_torch_gpu -torch.backends.cuda.matmul.allow_tf32 = False - - class SafeDiffusionPipelineFastTests(unittest.TestCase): def tearDown(self): # clean up the VRAM after each test diff --git a/tests/pipelines/stable_unclip/test_stable_unclip.py b/tests/pipelines/stable_unclip/test_stable_unclip.py index 891323d22fe0..4bbbad757edf 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip.py @@ -13,16 +13,21 @@ UNet2DConditionModel, ) from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer -from diffusers.utils.testing_utils import load_numpy, require_torch_gpu, slow, torch_device +from diffusers.utils.testing_utils import enable_full_determinism, load_numpy, require_torch_gpu, slow, torch_device -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference -class StableUnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +enable_full_determinism() + + +class StableUnCLIPPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableUnCLIPPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS # TODO(will) Expected attn_bias.stride(1) == 0 to be true, but got false test_xformers_attention = False diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index 69e3225ced52..741343066133 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -18,6 +18,7 @@ from diffusers.pipelines.stable_diffusion.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.testing_utils import ( + enable_full_determinism, floats_tensor, load_image, load_numpy, @@ -29,15 +30,23 @@ from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS from ..test_pipelines_common import ( + PipelineLatentTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference, ) -class StableUnCLIPImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +enable_full_determinism() + + +class StableUnCLIPImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableUnCLIPImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + image_params = frozenset( + [] + ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess + image_latents_params = frozenset([]) def get_dummy_components(self): embedder_hidden_size = 32 diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 168ff8106c52..cd3700d0ccdf 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -20,6 +20,7 @@ import shutil import sys import tempfile +import traceback import unittest import unittest.mock as mock @@ -35,6 +36,7 @@ from diffusers import ( AutoencoderKL, + ConfigMixin, DDIMPipeline, DDIMScheduler, DDPMPipeline, @@ -44,6 +46,7 @@ EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler, + ModelMixin, PNDMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, @@ -58,16 +61,82 @@ CONFIG_NAME, WEIGHTS_NAME, floats_tensor, - is_flax_available, + is_compiled_module, nightly, require_torch_2, slow, torch_device, ) -from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, load_numpy, require_compel, require_torch_gpu +from diffusers.utils.testing_utils import ( + CaptureLogger, + enable_full_determinism, + get_tests_dir, + load_numpy, + require_compel, + require_flax, + require_torch_gpu, + run_test_in_subprocess, +) -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() + + +# Will be run via run_test_in_subprocess +def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): + error = None + try: + # 1. Load models + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + model = torch.compile(model) + scheduler = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, scheduler) + + # previous diffusers versions stripped compilation off + # compiled modules + assert is_compiled_module(ddpm.unet) + + ddpm.to(torch_device) + ddpm.set_progress_bar_config(disable=None) + + with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) + new_ddpm.to(torch_device) + + generator = torch.Generator(device=torch_device).manual_seed(0) + image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + generator = torch.Generator(device=torch_device).manual_seed(0) + new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images + + assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + except Exception: + error = f"{traceback.format_exc()}" + + results = {"error": error} + out_queue.put(results, timeout=timeout) + out_queue.join() + + +class CustomEncoder(ModelMixin, ConfigMixin): + def __init__(self): + super().__init__() + + +class CustomPipeline(DiffusionPipeline): + def __init__(self, encoder: CustomEncoder, scheduler: DDIMScheduler): + super().__init__() + self.register_modules(encoder=encoder, scheduler=scheduler) class DownloadTests(unittest.TestCase): @@ -333,7 +402,7 @@ def test_cached_files_are_used_when_no_internet(self): with mock.patch("requests.request", return_value=response_mock): # Download this model to make sure it's in the cache. pipe = StableDiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None, local_files_only=True + "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None ) comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")} @@ -575,6 +644,102 @@ def test_text_inversion_download(self): out = pipe(prompt, num_inference_steps=1, output_type="numpy").images assert out.shape == (1, 128, 128, 3) + # multi embedding load + with tempfile.TemporaryDirectory() as tmpdirname1: + with tempfile.TemporaryDirectory() as tmpdirname2: + ten = {"<*****>": torch.ones((32,))} + torch.save(ten, os.path.join(tmpdirname1, "learned_embeds.bin")) + + ten = {"<******>": 2 * torch.ones((1, 32))} + torch.save(ten, os.path.join(tmpdirname2, "learned_embeds.bin")) + + pipe.load_textual_inversion([tmpdirname1, tmpdirname2]) + + token = pipe.tokenizer.convert_tokens_to_ids("<*****>") + assert token == num_tokens + 8, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32 + assert pipe._maybe_convert_prompt("<*****>", pipe.tokenizer) == "<*****>" + + token = pipe.tokenizer.convert_tokens_to_ids("<******>") + assert token == num_tokens + 9, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64 + assert pipe._maybe_convert_prompt("<******>", pipe.tokenizer) == "<******>" + + prompt = "hey <*****> <******>" + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # single token state dict load + ten = {"": torch.ones((32,))} + pipe.load_textual_inversion(ten) + + token = pipe.tokenizer.convert_tokens_to_ids("") + assert token == num_tokens + 10, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 32 + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" + + prompt = "hey " + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # multi embedding state dict load + ten1 = {"": torch.ones((32,))} + ten2 = {"": 2 * torch.ones((1, 32))} + + pipe.load_textual_inversion([ten1, ten2]) + + token = pipe.tokenizer.convert_tokens_to_ids("") + assert token == num_tokens + 11, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 32 + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" + + token = pipe.tokenizer.convert_tokens_to_ids("") + assert token == num_tokens + 12, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 64 + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == "" + + prompt = "hey " + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # auto1111 multi-token state dict load + ten = { + "string_to_param": { + "*": torch.cat([3 * torch.ones((1, 32)), 4 * torch.ones((1, 32)), 5 * torch.ones((1, 32))]) + }, + "name": "", + } + + pipe.load_textual_inversion(ten) + + token = pipe.tokenizer.convert_tokens_to_ids("") + token_1 = pipe.tokenizer.convert_tokens_to_ids("_1") + token_2 = pipe.tokenizer.convert_tokens_to_ids("_2") + + assert token == num_tokens + 13, "Added token must be at spot `num_tokens`" + assert token_1 == num_tokens + 14, "Added token must be at spot `num_tokens`" + assert token_2 == num_tokens + 15, "Added token must be at spot `num_tokens`" + assert pipe.text_encoder.get_input_embeddings().weight[-3].sum().item() == 96 + assert pipe.text_encoder.get_input_embeddings().weight[-2].sum().item() == 128 + assert pipe.text_encoder.get_input_embeddings().weight[-1].sum().item() == 160 + assert pipe._maybe_convert_prompt("", pipe.tokenizer) == " _1 _2" + + prompt = "hey " + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + + # multiple references to multi embedding + ten = {"": torch.ones(3, 32)} + pipe.load_textual_inversion(ten) + + assert ( + pipe._maybe_convert_prompt(" ", pipe.tokenizer) == " _1 _2 _1 _2" + ) + + prompt = "hey " + out = pipe(prompt, num_inference_steps=1, output_type="numpy").images + assert out.shape == (1, 128, 128, 3) + def test_download_ignore_files(self): # Check https://huggingface.co/hf-internal-testing/tiny-stable-diffusion-pipe-ignore-files/blob/72f58636e5508a218c6b3f60550dc96445547817/model_index.json#L4 with tempfile.TemporaryDirectory() as tmpdirname: @@ -663,9 +828,25 @@ def test_local_custom_pipeline_file(self): # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102 assert output_str == "This is a local test" + def test_custom_model_and_pipeline(self): + pipe = CustomPipeline( + encoder=CustomEncoder(), + scheduler=DDIMScheduler(), + ) + + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + + pipe_new = CustomPipeline.from_pretrained(tmpdirname) + pipe_new.save_pretrained(tmpdirname) + + assert dict(pipe_new.config) == dict(pipe.config) + @slow @require_torch_gpu def test_download_from_git(self): + # Because adaptive_avg_pool2d_backward_cuda + # does not have a deterministic implementation. clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K" feature_extractor = CLIPImageProcessor.from_pretrained(clip_model_id) @@ -1281,35 +1462,7 @@ def test_from_save_pretrained(self): @require_torch_2 def test_from_save_pretrained_dynamo(self): - # 1. Load models - model = UNet2DModel( - block_out_channels=(32, 64), - layers_per_block=2, - sample_size=32, - in_channels=3, - out_channels=3, - down_block_types=("DownBlock2D", "AttnDownBlock2D"), - up_block_types=("AttnUpBlock2D", "UpBlock2D"), - ) - model = torch.compile(model) - scheduler = DDPMScheduler(num_train_timesteps=10) - - ddpm = DDPMPipeline(model, scheduler) - ddpm.to(torch_device) - ddpm.set_progress_bar_config(disable=None) - - with tempfile.TemporaryDirectory() as tmpdirname: - ddpm.save_pretrained(tmpdirname) - new_ddpm = DDPMPipeline.from_pretrained(tmpdirname) - new_ddpm.to(torch_device) - - generator = torch.Generator(device=torch_device).manual_seed(0) - image = ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images - - generator = torch.Generator(device=torch_device).manual_seed(0) - new_image = new_ddpm(generator=generator, num_inference_steps=5, output_type="numpy").images - - assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" + run_test_in_subprocess(test_case=self, target_func=_test_from_save_pretrained_dynamo, inputs=None) def test_from_pretrained_hub(self): model_path = "google/ddpm-cifar10-32" @@ -1377,15 +1530,13 @@ def test_output_format(self): assert isinstance(images, list) assert isinstance(images[0], PIL.Image.Image) + @require_flax def test_from_flax_from_pt(self): pipe_pt = StableDiffusionPipeline.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None ) pipe_pt.to(torch_device) - if not is_flax_available(): - raise ImportError("Make sure flax is installed.") - from diffusers import FlaxStableDiffusionPipeline with tempfile.TemporaryDirectory() as tmpdirname: @@ -1449,7 +1600,7 @@ def test_weighted_prompts_compel(self): f"/compel/forest_{i}.npy" ) - assert np.abs(image - expected_image).max() < 1e-2 + assert np.abs(image - expected_image).max() < 3e-1 @nightly diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0278092282ba..fac04bdbe30f 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -8,18 +8,17 @@ from typing import Callable, Union import numpy as np +import PIL import torch import diffusers from diffusers import DiffusionPipeline +from diffusers.image_processor import VaeImageProcessor from diffusers.utils import logging from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.testing_utils import require_torch, torch_device -torch.backends.cuda.matmul.allow_tf32 = False - - def to_np(tensor): if isinstance(tensor, torch.Tensor): tensor = tensor.detach().cpu().numpy() @@ -27,6 +26,135 @@ def to_np(tensor): return tensor +class PipelineLatentTesterMixin: + """ + This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes. + It provides a set of common tests for PyTorch pipeline that has vae, e.g. + equivalence of different input and output types, etc. + """ + + @property + def image_params(self) -> frozenset: + raise NotImplementedError( + "You need to set the attribute `image_params` in the child test class. " + "`image_params` are tested for if all accepted input image types (i.e. `pt`,`pil`,`np`) are producing same results" + ) + + @property + def image_latents_params(self) -> frozenset: + raise NotImplementedError( + "You need to set the attribute `image_latents_params` in the child test class. " + "`image_latents_params` are tested for if passing latents directly are producing same results" + ) + + def get_dummy_inputs_by_type(self, device, seed=0, input_image_type="pt", output_type="np"): + inputs = self.get_dummy_inputs(device, seed) + + def convert_to_pt(image): + if isinstance(image, torch.Tensor): + input_image = image + elif isinstance(image, np.ndarray): + input_image = VaeImageProcessor.numpy_to_pt(image) + elif isinstance(image, PIL.Image.Image): + input_image = VaeImageProcessor.pil_to_numpy(image) + input_image = VaeImageProcessor.numpy_to_pt(input_image) + else: + raise ValueError(f"unsupported input_image_type {type(image)}") + return input_image + + def convert_pt_to_type(image, input_image_type): + if input_image_type == "pt": + input_image = image + elif input_image_type == "np": + input_image = VaeImageProcessor.pt_to_numpy(image) + elif input_image_type == "pil": + input_image = VaeImageProcessor.pt_to_numpy(image) + input_image = VaeImageProcessor.numpy_to_pil(input_image) + else: + raise ValueError(f"unsupported input_image_type {input_image_type}.") + return input_image + + for image_param in self.image_params: + if image_param in inputs.keys(): + inputs[image_param] = convert_pt_to_type( + convert_to_pt(inputs[image_param]).to(device), input_image_type + ) + + inputs["output_type"] = output_type + + return inputs + + def test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4): + self._test_pt_np_pil_outputs_equivalent(expected_max_diff=expected_max_diff) + + def _test_pt_np_pil_outputs_equivalent(self, expected_max_diff=1e-4, input_image_type="pt"): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + output_pt = pipe( + **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pt") + )[0] + output_np = pipe( + **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="np") + )[0] + output_pil = pipe( + **self.get_dummy_inputs_by_type(torch_device, input_image_type=input_image_type, output_type="pil") + )[0] + + max_diff = np.abs(output_pt.cpu().numpy().transpose(0, 2, 3, 1) - output_np).max() + self.assertLess( + max_diff, expected_max_diff, "`output_type=='pt'` generate different results from `output_type=='np'`" + ) + + max_diff = np.abs(np.array(output_pil[0]) - (output_np * 255).round()).max() + self.assertLess(max_diff, 2.0, "`output_type=='pil'` generate different results from `output_type=='np'`") + + def test_pt_np_pil_inputs_equivalent(self): + if len(self.image_params) == 0: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + out_input_pt = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0] + out_input_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0] + out_input_pil = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pil"))[0] + + max_diff = np.abs(out_input_pt - out_input_np).max() + self.assertLess(max_diff, 1e-4, "`input_type=='pt'` generate different result from `input_type=='np'`") + max_diff = np.abs(out_input_pil - out_input_np).max() + self.assertLess(max_diff, 1e-2, "`input_type=='pt'` generate different result from `input_type=='np'`") + + def test_latents_input(self): + if len(self.image_latents_params) == 0: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.image_processor = VaeImageProcessor(do_resize=False, do_normalize=False) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + out = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="pt"))[0] + + vae = components["vae"] + inputs = self.get_dummy_inputs_by_type(torch_device, input_image_type="pt") + generator = inputs["generator"] + for image_param in self.image_latents_params: + if image_param in inputs.keys(): + inputs[image_param] = ( + vae.encode(inputs[image_param]).latent_dist.sample(generator) * vae.config.scaling_factor + ) + out_latents_inputs = pipe(**inputs)[0] + + max_diff = np.abs(out - out_latents_inputs).max() + self.assertLess(max_diff, 1e-4, "passing latents as image input generate different result from passing image") + + @require_torch class PipelineTesterMixin: """ @@ -115,7 +243,7 @@ def tearDown(self): gc.collect() torch.cuda.empty_cache() - def test_save_load_local(self): + def test_save_load_local(self, expected_max_difference=1e-4): components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.to(torch_device) @@ -134,7 +262,7 @@ def test_save_load_local(self): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, 1e-4) + self.assertLess(max_diff, expected_max_difference) def test_pipeline_call_signature(self): self.assertTrue( @@ -215,7 +343,7 @@ def _test_inference_batch_consistent( for arg in additional_params_copy_to_batched_inputs: batched_inputs[arg] = inputs[arg] - batched_inputs["output_type"] = None + batched_inputs["output_type"] = "np" if self.pipeline_class.__name__ == "DanceDiffusionPipeline": batched_inputs.pop("output_type") @@ -235,8 +363,8 @@ def _test_inference_batch_consistent( logger.setLevel(level=diffusers.logging.WARNING) - def test_inference_batch_single_identical(self, batch_size=3): - self._test_inference_batch_single_identical(batch_size=batch_size) + def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=1e-4): + self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff) def _test_inference_batch_single_identical( self, @@ -318,7 +446,7 @@ def _test_inference_batch_single_identical( if test_mean_pixel_difference: assert_mean_pixel_difference(output_batch[0][0], output[0][0]) - def test_dict_tuple_outputs_equivalent(self): + def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4): components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.to(torch_device) @@ -328,7 +456,7 @@ def test_dict_tuple_outputs_equivalent(self): output_tuple = pipe(**self.get_dummy_inputs(torch_device), return_dict=False)[0] max_diff = np.abs(to_np(output) - to_np(output_tuple)).max() - self.assertLess(max_diff, 1e-4) + self.assertLess(max_diff, expected_max_difference) def test_components_function(self): init_components = self.get_dummy_components() @@ -338,10 +466,7 @@ def test_components_function(self): self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") - def test_float16_inference(self): - self._test_float16_inference() - - def _test_float16_inference(self, expected_max_diff=1e-2): + def test_float16_inference(self, expected_max_diff=1e-2): components = self.get_dummy_components() pipe = self.pipeline_class(**components) pipe.to(torch_device) @@ -358,10 +483,7 @@ def _test_float16_inference(self, expected_max_diff=1e-2): self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.") @unittest.skipIf(torch_device != "cuda", reason="float16 requires CUDA") - def test_save_load_float16(self): - self._test_save_load_float16() - - def _test_save_load_float16(self, expected_max_diff=1e-2): + def test_save_load_float16(self, expected_max_diff=1e-2): components = self.get_dummy_components() for name, module in components.items(): if hasattr(module, "half"): @@ -394,7 +516,7 @@ def _test_save_load_float16(self, expected_max_diff=1e-2): max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading." ) - def test_save_load_optional_components(self): + def test_save_load_optional_components(self, expected_max_difference=1e-4): if not hasattr(self.pipeline_class, "_optional_components"): return @@ -426,7 +548,7 @@ def test_save_load_optional_components(self): output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() - self.assertLess(max_diff, 1e-4) + self.assertLess(max_diff, expected_max_difference) @unittest.skipIf(torch_device != "cuda", reason="CUDA and CPU are required to switch devices") def test_to_device(self): @@ -460,8 +582,8 @@ def test_to_dtype(self): model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) - def test_attention_slicing_forward_pass(self): - self._test_attention_slicing_forward_pass() + def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3): + self._test_attention_slicing_forward_pass(expected_max_diff=expected_max_diff) def _test_attention_slicing_forward_pass( self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 @@ -492,7 +614,7 @@ def _test_attention_slicing_forward_pass( torch_device != "cuda" or not is_accelerate_available() or is_accelerate_version("<", "0.14.0"), reason="CPU offload is only available with CUDA and `accelerate v0.14.0` or higher", ) - def test_cpu_offload_forward_pass(self): + def test_cpu_offload_forward_pass(self, expected_max_diff=1e-4): if not self.test_cpu_offload: return @@ -509,7 +631,7 @@ def test_cpu_offload_forward_pass(self): output_with_offload = pipe(**inputs)[0] max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max() - self.assertLess(max_diff, 1e-4, "CPU offloading should not affect the inference results") + self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results") @unittest.skipIf( torch_device != "cuda" or not is_xformers_available(), @@ -584,7 +706,7 @@ def test_num_images_per_prompt(self): if key in self.batch_params: inputs[key] = batch_size * [inputs[key]] - images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt).images + images = pipe(**inputs, num_images_per_prompt=num_images_per_prompt)[0] assert images.shape[0] == batch_size * num_images_per_prompt @@ -592,8 +714,8 @@ def test_num_images_per_prompt(self): # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. -def assert_mean_pixel_difference(image, expected_image): +def assert_mean_pixel_difference(image, expected_image, expected_max_diff=10): image = np.asarray(DiffusionPipeline.numpy_to_pil(image)[0], dtype=np.float32) expected_image = np.asarray(DiffusionPipeline.numpy_to_pil(expected_image)[0], dtype=np.float32) avg_diff = np.abs(image - expected_image).mean() - assert avg_diff < 10, f"Error image deviates {avg_diff} pixels on average" + assert avg_diff < expected_max_diff, f"Error image deviates {avg_diff} pixels on average" diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py index b59653694616..8b4bae2275e5 100644 --- a/tests/pipelines/text_to_video/test_text_to_video.py +++ b/tests/pipelines/text_to_video/test_text_to_video.py @@ -27,12 +27,13 @@ UNet3DConditionModel, ) from diffusers.utils import load_numpy, skip_mps, slow +from diffusers.utils.testing_utils import enable_full_determinism from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin -torch.backends.cuda.matmul.allow_tf32 = False +enable_full_determinism() @skip_mps @@ -140,7 +141,7 @@ def test_text_to_video_default_case(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_attention_slicing_forward_pass(self): - self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False) + self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3) # (todo): sayakpaul @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.") diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py index 5c9181c08e3f..393c3ba1635d 100644 --- a/tests/pipelines/unclip/test_unclip.py +++ b/tests/pipelines/unclip/test_unclip.py @@ -23,12 +23,15 @@ from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel from diffusers.utils import load_numpy, nightly, slow, torch_device -from diffusers.utils.testing_utils import require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +enable_full_determinism() + + class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = UnCLIPPipeline params = TEXT_TO_IMAGE_PARAMS - { diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index c1b8be9cd49e..75a26250807b 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -37,12 +37,15 @@ ) from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel from diffusers.utils import floats_tensor, load_numpy, slow, torch_device -from diffusers.utils.testing_utils import load_image, require_torch_gpu, skip_mps +from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, skip_mps from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +enable_full_determinism() + + class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = UnCLIPImageVariationPipeline params = IMAGE_VARIATION_PARAMS - {"height", "width", "guidance_scale"} @@ -516,4 +519,4 @@ def test_unclip_image_variation_karlo(self): assert image.shape == (256, 256, 3) - assert_mean_pixel_difference(image, expected_image) + assert_mean_pixel_difference(image, expected_image, 15) diff --git a/tests/pipelines/unidiffuser/__init__.py b/tests/pipelines/unidiffuser/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py new file mode 100644 index 000000000000..06cb451281c9 --- /dev/null +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -0,0 +1,673 @@ +import gc +import random +import unittest + +import numpy as np +import torch +from PIL import Image +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTokenizer, + CLIPVisionModelWithProjection, + GPT2Tokenizer, +) + +from diffusers import ( + AutoencoderKL, + DPMSolverMultistepScheduler, + UniDiffuserModel, + UniDiffuserPipeline, + UniDiffuserTextDecoder, +) +from diffusers.utils import floats_tensor, load_image, randn_tensor, slow, torch_device +from diffusers.utils.testing_utils import require_torch_gpu + +from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +class UniDiffuserPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = UniDiffuserPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + + def get_dummy_components(self): + unet = UniDiffuserModel.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="unet", + ) + + scheduler = DPMSolverMultistepScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + solver_order=3, + ) + + vae = AutoencoderKL.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="vae", + ) + + text_encoder = CLIPTextModel.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="text_encoder", + ) + clip_tokenizer = CLIPTokenizer.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="clip_tokenizer", + ) + + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="image_encoder", + ) + # From the Stable Diffusion Image Variation pipeline tests + image_processor = CLIPImageProcessor(crop_size=32, size=32) + # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip") + + text_tokenizer = GPT2Tokenizer.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="text_tokenizer", + ) + text_decoder = UniDiffuserTextDecoder.from_pretrained( + "hf-internal-testing/unidiffuser-diffusers-test", + subfolder="text_decoder", + ) + + components = { + "vae": vae, + "text_encoder": text_encoder, + "image_encoder": image_encoder, + "image_processor": image_processor, + "clip_tokenizer": clip_tokenizer, + "text_decoder": text_decoder, + "text_tokenizer": text_tokenizer, + "unet": unet, + "scheduler": scheduler, + } + + return components + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + image = image.cpu().permute(0, 2, 3, 1)[0] + image = Image.fromarray(np.uint8(image)).convert("RGB") + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "an elephant under the sea", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + } + return inputs + + def get_fixed_latents(self, device, seed=0): + if type(device) == str: + device = torch.device(device) + generator = torch.Generator(device=device).manual_seed(seed) + # Hardcode the shapes for now. + prompt_latents = randn_tensor((1, 77, 32), generator=generator, device=device, dtype=torch.float32) + vae_latents = randn_tensor((1, 4, 16, 16), generator=generator, device=device, dtype=torch.float32) + clip_latents = randn_tensor((1, 1, 32), generator=generator, device=device, dtype=torch.float32) + + latents = { + "prompt_latents": prompt_latents, + "vae_latents": vae_latents, + "clip_latents": clip_latents, + } + return latents + + def get_dummy_inputs_with_latents(self, device, seed=0): + # image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + # image = image.cpu().permute(0, 2, 3, 1)[0] + # image = Image.fromarray(np.uint8(image)).convert("RGB") + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg", + ) + image = image.resize((32, 32)) + latents = self.get_fixed_latents(device, seed=seed) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "an elephant under the sea", + "image": image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "prompt_latents": latents.get("prompt_latents"), + "vae_latents": latents.get("vae_latents"), + "clip_latents": latents.get("clip_latents"), + } + return inputs + + def test_unidiffuser_default_joint_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_joint_no_cfg_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + # Set guidance scale to 1.0 to turn off CFG + inputs["guidance_scale"] = 1.0 + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_text2img_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete image for text-conditioned image generation + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_unidiffuser_default_image_0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img' + unidiffuser_pipe.set_image_mode() + assert unidiffuser_pipe.mode == "img" + + inputs = self.get_dummy_inputs(device) + # Delete prompt and image for unconditional ("marginal") text generation. + del inputs["prompt"] + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5760, 0.6270, 0.6571, 0.4966, 0.4638, 0.5663, 0.5254, 0.5068, 0.5715]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_unidiffuser_default_text_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img' + unidiffuser_pipe.set_text_mode() + assert unidiffuser_pipe.mode == "text" + + inputs = self.get_dummy_inputs(device) + # Delete prompt and image for unconditional ("marginal") text generation. + del inputs["prompt"] + del inputs["image"] + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_img2text_v0(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_joint_v1(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1") + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + inputs["data_type"] = 1 + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_default_text2img_v1(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1") + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete image for text-conditioned image generation + del inputs["image"] + image = unidiffuser_pipe(**inputs).images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + + def test_unidiffuser_default_img2text_v1(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1") + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = " no no no " + assert text[0][:10] == expected_text_prefix + + def test_unidiffuser_text2img_multiple_images(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs(device) + # Delete image for text-conditioned image generation + del inputs["image"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + image = unidiffuser_pipe(**inputs).images + assert image.shape == (2, 32, 32, 3) + + def test_unidiffuser_img2text_multiple_prompts(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + text = unidiffuser_pipe(**inputs).text + + assert len(text) == 3 + + def test_unidiffuser_text2img_multiple_images_with_latents(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete image for text-conditioned image generation + del inputs["image"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + image = unidiffuser_pipe(**inputs).images + assert image.shape == (2, 32, 32, 3) + + def test_unidiffuser_img2text_multiple_prompts_with_latents(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + unidiffuser_pipe = UniDiffuserPipeline(**components) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete text for image-conditioned text generation + del inputs["prompt"] + inputs["num_images_per_prompt"] = 2 + inputs["num_prompts_per_image"] = 3 + text = unidiffuser_pipe(**inputs).text + + assert len(text) == 3 + + def test_inference_batch_single_identical(self): + super().test_inference_batch_single_identical(expected_max_diff=2e-4) + + @require_torch_gpu + def test_unidiffuser_default_joint_v1_cuda_fp16(self): + device = "cuda" + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained( + "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16 + ) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'joint' + unidiffuser_pipe.set_joint_mode() + assert unidiffuser_pipe.mode == "joint" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + inputs["data_type"] = 1 + sample = unidiffuser_pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5049, 0.5498, 0.5854, 0.3052, 0.4460, 0.6489, 0.5122, 0.4810, 0.6138]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + expected_text_prefix = '" This This' + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + @require_torch_gpu + def test_unidiffuser_default_text2img_v1_cuda_fp16(self): + device = "cuda" + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained( + "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16 + ) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'text2img' + unidiffuser_pipe.set_text_to_image_mode() + assert unidiffuser_pipe.mode == "text2img" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["image"] + inputs["data_type"] = 1 + sample = unidiffuser_pipe(**inputs) + image = sample.images + assert image.shape == (1, 32, 32, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.5054, 0.5498, 0.5854, 0.3052, 0.4458, 0.6489, 0.5122, 0.4810, 0.6138]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3 + + @require_torch_gpu + def test_unidiffuser_default_img2text_v1_cuda_fp16(self): + device = "cuda" + unidiffuser_pipe = UniDiffuserPipeline.from_pretrained( + "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16 + ) + unidiffuser_pipe = unidiffuser_pipe.to(device) + unidiffuser_pipe.set_progress_bar_config(disable=None) + + # Set mode to 'img2text' + unidiffuser_pipe.set_image_to_text_mode() + assert unidiffuser_pipe.mode == "img2text" + + inputs = self.get_dummy_inputs_with_latents(device) + # Delete prompt and image for joint inference. + del inputs["prompt"] + inputs["data_type"] = 1 + text = unidiffuser_pipe(**inputs).text + + expected_text_prefix = '" This This' + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + +@slow +@require_torch_gpu +class UniDiffuserPipelineSlowTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0, generate_latents=False): + generator = torch.manual_seed(seed) + image = load_image( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg" + ) + inputs = { + "prompt": "an elephant under the sea", + "image": image, + "generator": generator, + "num_inference_steps": 3, + "guidance_scale": 8.0, + "output_type": "numpy", + } + if generate_latents: + latents = self.get_fixed_latents(device, seed=seed) + for latent_name, latent_tensor in latents.items(): + inputs[latent_name] = latent_tensor + return inputs + + def get_fixed_latents(self, device, seed=0): + if type(device) == str: + device = torch.device(device) + latent_device = torch.device("cpu") + generator = torch.Generator(device=latent_device).manual_seed(seed) + # Hardcode the shapes for now. + prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32) + vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32) + clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32) + + # Move latents onto desired device. + prompt_latents = prompt_latents.to(device) + vae_latents = vae_latents.to(device) + clip_latents = clip_latents.to(device) + + latents = { + "prompt_latents": prompt_latents, + "vae_latents": vae_latents, + "clip_latents": clip_latents, + } + return latents + + def test_unidiffuser_default_joint_v1(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_inputs(device=torch_device, generate_latents=True) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1 + + expected_text_prefix = "a living room" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + def test_unidiffuser_default_text2img_v1(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1 + + def test_unidiffuser_default_img2text_v1(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1") + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["prompt"] + sample = pipe(**inputs) + text = sample.text + + expected_text_prefix = "An astronaut" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + def test_unidiffuser_default_joint_v1_fp16(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + # inputs = self.get_dummy_inputs(device) + inputs = self.get_inputs(device=torch_device, generate_latents=True) + # Delete prompt and image for joint inference. + del inputs["prompt"] + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + text = sample.text + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520]) + assert np.abs(image_slice.flatten() - expected_img_slice).max() < 2e-1 + + expected_text_prefix = "a living room" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix + + def test_unidiffuser_default_text2img_v1_fp16(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["image"] + sample = pipe(**inputs) + image = sample.images + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, -3:, -3:, -1] + expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005]) + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1 + + def test_unidiffuser_default_img2text_v1_fp16(self): + pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(device=torch_device, generate_latents=True) + del inputs["prompt"] + sample = pipe(**inputs) + text = sample.text + + expected_text_prefix = "An astronaut" + assert text[0][: len(expected_text_prefix)] == expected_text_prefix diff --git a/tests/pipelines/vq_diffusion/test_vq_diffusion.py b/tests/pipelines/vq_diffusion/test_vq_diffusion.py index d97a7b2f6564..3f5ef16cff72 100644 --- a/tests/pipelines/vq_diffusion/test_vq_diffusion.py +++ b/tests/pipelines/vq_diffusion/test_vq_diffusion.py @@ -189,7 +189,7 @@ def test_vq_diffusion_classifier_free_sampling(self): expected_slice = np.array([0.6693, 0.6075, 0.4959, 0.5701, 0.5583, 0.4333, 0.6171, 0.5684, 0.4988]) - assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_slice.flatten() - expected_slice).max() < 2.0 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 @@ -225,4 +225,4 @@ def test_vq_diffusion_classifier_free_sampling(self): image = output.images[0] assert image.shape == (256, 256, 3) - assert np.abs(expected_image - image).max() < 1e-2 + assert np.abs(expected_image - image).max() < 2.0 diff --git a/tests/schedulers/test_scheduler_ddim.py b/tests/schedulers/test_scheduler_ddim.py index e9c85314d558..156b02b2208e 100644 --- a/tests/schedulers/test_scheduler_ddim.py +++ b/tests/schedulers/test_scheduler_ddim.py @@ -69,6 +69,14 @@ def test_clip_sample(self): for clip_sample in [True, False]: self.check_over_configs(clip_sample=clip_sample) + def test_timestep_spacing(self): + for timestep_spacing in ["trailing", "leading"]: + self.check_over_configs(timestep_spacing=timestep_spacing) + + def test_rescale_betas_zero_snr(self): + for rescale_betas_zero_snr in [True, False]: + self.check_over_configs(rescale_betas_zero_snr=rescale_betas_zero_snr) + def test_thresholding(self): self.check_over_configs(thresholding=False) for threshold in [0.5, 1.0, 2.0]: diff --git a/tests/schedulers/test_scheduler_dpm_multi.py b/tests/schedulers/test_scheduler_dpm_multi.py index c1593bae3908..c9935780b983 100644 --- a/tests/schedulers/test_scheduler_dpm_multi.py +++ b/tests/schedulers/test_scheduler_dpm_multi.py @@ -29,6 +29,8 @@ def get_scheduler_config(self, **kwargs): "algorithm_type": "dpmsolver++", "solver_type": "midpoint", "lower_order_final": False, + "lambda_min_clipped": -float("inf"), + "variance_type": None, } config.update(**kwargs) @@ -165,16 +167,20 @@ def test_prediction_type(self): self.check_over_configs(prediction_type=prediction_type) def test_solver_order_and_type(self): - for algorithm_type in ["dpmsolver", "dpmsolver++"]: + for algorithm_type in ["dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++"]: for solver_type in ["midpoint", "heun"]: for order in [1, 2, 3]: for prediction_type in ["epsilon", "sample"]: - self.check_over_configs( - solver_order=order, - solver_type=solver_type, - prediction_type=prediction_type, - algorithm_type=algorithm_type, - ) + if algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + if order == 3: + continue + else: + self.check_over_configs( + solver_order=order, + solver_type=solver_type, + prediction_type=prediction_type, + algorithm_type=algorithm_type, + ) sample = self.full_loop( solver_order=order, solver_type=solver_type, @@ -187,6 +193,14 @@ def test_lower_order_final(self): self.check_over_configs(lower_order_final=True) self.check_over_configs(lower_order_final=False) + def test_lambda_min_clipped(self): + self.check_over_configs(lambda_min_clipped=-float("inf")) + self.check_over_configs(lambda_min_clipped=-5.1) + + def test_variance_type(self): + self.check_over_configs(variance_type=None) + self.check_over_configs(variance_type="learned_range") + def test_inference_steps(self): for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) diff --git a/tests/schedulers/test_scheduler_dpm_sde.py b/tests/schedulers/test_scheduler_dpm_sde.py new file mode 100644 index 000000000000..7906c8d5d4e9 --- /dev/null +++ b/tests/schedulers/test_scheduler_dpm_sde.py @@ -0,0 +1,168 @@ +import torch + +from diffusers import DPMSolverSDEScheduler +from diffusers.utils import torch_device +from diffusers.utils.testing_utils import require_torchsde + +from .test_schedulers import SchedulerCommonTest + + +@require_torchsde +class DPMSolverSDESchedulerTest(SchedulerCommonTest): + scheduler_classes = (DPMSolverSDEScheduler,) + num_inference_steps = 10 + + def get_scheduler_config(self, **kwargs): + config = { + "num_train_timesteps": 1100, + "beta_start": 0.0001, + "beta_end": 0.02, + "beta_schedule": "linear", + "noise_sampler_seed": 0, + } + + config.update(**kwargs) + return config + + def test_timesteps(self): + for timesteps in [10, 50, 100, 1000]: + self.check_over_configs(num_train_timesteps=timesteps) + + def test_betas(self): + for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]): + self.check_over_configs(beta_start=beta_start, beta_end=beta_end) + + def test_schedules(self): + for schedule in ["linear", "scaled_linear"]: + self.check_over_configs(beta_schedule=schedule) + + def test_prediction_type(self): + for prediction_type in ["epsilon", "v_prediction"]: + self.check_over_configs(prediction_type=prediction_type) + + def test_full_loop_no_noise(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["mps"]: + assert abs(result_sum.item() - 167.47821044921875) < 1e-2 + assert abs(result_mean.item() - 0.2178705964565277) < 1e-3 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 171.59352111816406) < 1e-2 + assert abs(result_mean.item() - 0.22342906892299652) < 1e-3 + else: + assert abs(result_sum.item() - 162.52383422851562) < 1e-2 + assert abs(result_mean.item() - 0.211619570851326) < 1e-3 + + def test_full_loop_with_v_prediction(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(prediction_type="v_prediction") + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps) + + model = self.dummy_model() + sample = self.dummy_sample_deter * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for i, t in enumerate(scheduler.timesteps): + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["mps"]: + assert abs(result_sum.item() - 124.77149200439453) < 1e-2 + assert abs(result_mean.item() - 0.16226289014816284) < 1e-3 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 128.1663360595703) < 1e-2 + assert abs(result_mean.item() - 0.16688326001167297) < 1e-3 + else: + assert abs(result_sum.item() - 119.8487548828125) < 1e-2 + assert abs(result_mean.item() - 0.1560530662536621) < 1e-3 + + def test_full_loop_device(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["mps"]: + assert abs(result_sum.item() - 167.46957397460938) < 1e-2 + assert abs(result_mean.item() - 0.21805934607982635) < 1e-3 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 171.59353637695312) < 1e-2 + assert abs(result_mean.item() - 0.22342908382415771) < 1e-3 + else: + assert abs(result_sum.item() - 162.52383422851562) < 1e-2 + assert abs(result_mean.item() - 0.211619570851326) < 1e-3 + + def test_full_loop_device_karras_sigmas(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + if torch_device in ["mps"]: + assert abs(result_sum.item() - 176.66974135742188) < 1e-2 + assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 + elif torch_device in ["cuda"]: + assert abs(result_sum.item() - 177.63653564453125) < 1e-2 + assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 + else: + assert abs(result_sum.item() - 170.3135223388672) < 1e-2 + assert abs(result_mean.item() - 0.23003872730981811) < 1e-2 diff --git a/tests/schedulers/test_scheduler_dpm_single.py b/tests/schedulers/test_scheduler_dpm_single.py index 9dff04e7c998..66be3d5d00ad 100644 --- a/tests/schedulers/test_scheduler_dpm_single.py +++ b/tests/schedulers/test_scheduler_dpm_single.py @@ -28,6 +28,8 @@ def get_scheduler_config(self, **kwargs): "sample_max_value": 1.0, "algorithm_type": "dpmsolver++", "solver_type": "midpoint", + "lambda_min_clipped": -float("inf"), + "variance_type": None, } config.update(**kwargs) @@ -114,6 +116,22 @@ def full_loop(self, scheduler=None, **config): return sample + def test_full_uneven_loop(self): + scheduler = DPMSolverSinglestepScheduler(**self.get_scheduler_config()) + num_inference_steps = 50 + model = self.dummy_model() + sample = self.dummy_sample_deter + scheduler.set_timesteps(num_inference_steps) + + # make sure that the first t is uneven + for i, t in enumerate(scheduler.timesteps[3:]): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2574) < 1e-3 + def test_timesteps(self): for timesteps in [25, 50, 100, 999, 1000]: self.check_over_configs(num_train_timesteps=timesteps) @@ -179,6 +197,14 @@ def test_lower_order_final(self): self.check_over_configs(lower_order_final=True) self.check_over_configs(lower_order_final=False) + def test_lambda_min_clipped(self): + self.check_over_configs(lambda_min_clipped=-float("inf")) + self.check_over_configs(lambda_min_clipped=-5.1) + + def test_variance_type(self): + self.check_over_configs(variance_type=None) + self.check_over_configs(variance_type="learned_range") + def test_inference_steps(self): for num_inference_steps in [1, 2, 3, 5, 10, 50, 100, 999, 1000]: self.check_over_forward(num_inference_steps=num_inference_steps, time_step=0) @@ -189,12 +215,24 @@ def test_full_loop_no_noise(self): assert abs(result_mean.item() - 0.2791) < 1e-3 + def test_full_loop_with_karras(self): + sample = self.full_loop(use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.2248) < 1e-3 + def test_full_loop_with_v_prediction(self): sample = self.full_loop(prediction_type="v_prediction") result_mean = torch.mean(torch.abs(sample)) assert abs(result_mean.item() - 0.1453) < 1e-3 + def test_full_loop_with_karras_and_v_prediction(self): + sample = self.full_loop(prediction_type="v_prediction", use_karras_sigmas=True) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_mean.item() - 0.0649) < 1e-3 + def test_fp16_support(self): scheduler_class = self.scheduler_classes[0] scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) diff --git a/tests/schedulers/test_scheduler_lms.py b/tests/schedulers/test_scheduler_lms.py index ca3574e9ee63..2682886a788d 100644 --- a/tests/schedulers/test_scheduler_lms.py +++ b/tests/schedulers/test_scheduler_lms.py @@ -113,3 +113,28 @@ def test_full_loop_device(self): assert abs(result_sum.item() - 1006.388) < 1e-2 assert abs(result_mean.item() - 1.31) < 1e-3 + + def test_full_loop_device_karras_sigmas(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config, use_karras_sigmas=True) + + scheduler.set_timesteps(self.num_inference_steps, device=torch_device) + + model = self.dummy_model() + sample = self.dummy_sample_deter.to(torch_device) * scheduler.init_noise_sigma + sample = sample.to(torch_device) + + for t in scheduler.timesteps: + sample = scheduler.scale_model_input(sample, t) + + model_output = model(sample, t) + + output = scheduler.step(model_output, t, sample) + sample = output.prev_sample + + result_sum = torch.sum(torch.abs(sample)) + result_mean = torch.mean(torch.abs(sample)) + + assert abs(result_sum.item() - 3812.9927) < 2e-2 + assert abs(result_mean.item() - 4.9648) < 1e-3