
Commit 2a8cf8e

Animatediff Proposal (huggingface#5413)
* draft design * clean up * clean up * clean up * clean up * clean up * clean up * clean up * clean up * clean up * update pipeline * clean up * clean up * clean up * add tests * change motion block * clean up * clean up * clean up * update * update * update * update * update * update * update * update * clean up * update * update * update model test * update * update * update * update * make style * update * fix embeddings * update * merge upstream * max fix copies * fix bug * fix mistake * add docs * update * clean up * update * clean up * clean up * fix docstrings * fix docstrings * update * update * clean up * update
1 parent 9ced784 commit 2a8cf8e

18 files changed, +3322 −1 lines changed

docs/source/en/_toctree.yml

+5-1
@@ -188,6 +188,8 @@
  title: UNet2DConditionModel
- local: api/models/unet3d-cond
  title: UNet3DConditionModel
+ - local: api/models/unet-motion
+ title: UNetMotionModel
- local: api/models/vq
  title: VQModel
- local: api/models/autoencoderkl
@@ -210,6 +212,8 @@
  title: Overview
- local: api/pipelines/alt_diffusion
  title: AltDiffusion
+ - local: api/pipelines/animatediff
+ title: AnimateDiff
- local: api/pipelines/attend_and_excite
  title: Attend-and-Excite
- local: api/pipelines/audio_diffusion
@@ -396,5 +400,5 @@
  title: Utilities
- local: api/image_processor
  title: VAE Image Processor
-  title: Internal classes
+  title: Internal classes
  title: API
docs/source/en/api/models/unet-motion.md

+13

@@ -0,0 +1,13 @@
# UNetMotionModel

The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al. for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on its number of dimensions and whether it is a conditional model or not. This variant extends the 2D UNet with motion modules so it can generate video frames.

The abstract from the paper is:

*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.*

## UNetMotionModel
[[autodoc]] UNetMotionModel

## UNet3DConditionOutput
[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput
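Not part of the committed page, but a rough sketch of how the new model classes fit together may help. The `from_unet2d` helper name, its keyword arguments, and the latent shapes below are assumptions based on this commit rather than documented API:

```python
# Illustrative sketch only. Assumes a `UNetMotionModel.from_unet2d` helper and that the
# motion UNet consumes 5D latents (batch, channels, num_frames, height, width), like
# UNet3DConditionModel does.
import torch
from diffusers import MotionAdapter, UNet2DConditionModel, UNetMotionModel

# A frozen 2D UNet from a Stable Diffusion 1.5 checkpoint
unet2d = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
# The motion modules trained by the AnimateDiff authors
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")

# Combine both into a motion-aware UNet (classmethod name assumed, not confirmed by this diff)
unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)

sample = torch.randn(1, 4, 16, 64, 64)           # 16 frames of 64x64 latents
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 77, 768)  # CLIP text embeddings for SD 1.5
with torch.no_grad():
    out = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # expected: torch.Size([1, 4, 16, 64, 64])
```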
docs/source/en/api/pipelines/animatediff.md

+108

@@ -0,0 +1,108 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Text-to-Video Generation with AnimateDiff

## Overview

[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725) by Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai

The abstract of the paper is the following:

*With the advance of text-to-image models (e.g., Stable Diffusion) and corresponding personalization techniques such as DreamBooth and LoRA, everyone can manifest their imagination into high-quality images at an affordable cost. Subsequently, there is a great demand for image animation techniques to further combine generated static images with motion dynamics. In this report, we propose a practical framework to animate most of the existing personalized text-to-image models once and for all, saving efforts in model-specific tuning. At the core of the proposed framework is to insert a newly initialized motion modeling module into the frozen text-to-image model and train it on video clips to distill reasonable motion priors. Once trained, by simply injecting this motion modeling module, all personalized versions derived from the same base T2I readily become text-driven models that produce diverse and personalized animated images. We conduct our evaluation on several public representative personalized text-to-image models across anime pictures and realistic photographs, and demonstrate that our proposed framework helps these models generate temporally smooth animation clips while preserving the domain and diversity of their outputs. Code and pre-trained weights will be publicly available at this https URL.*
23+
## Available Pipelines:
24+
25+
| Pipeline | Tasks | Demo
26+
|---|---|:---:|
27+
| [AnimateDiffPipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/animatediff/pipeline_animatediff.py) | *Text-to-Video Generation with AnimateDiff* |
28+
29+
## Usage example
30+
31+
AnimateDiff works with a MotionAdapter checkpoint and a Stable Diffusion model checkpoint. The MotionAdapter is a collection of Motion Modules that are responsible for adding coherent motion across image frames. These modules are applied after the Resnet and Attention blocks in Stable Diffusion UNet.
32+
33+
The following example demonstrates how to use a *MotionAdapter* checkpoint with Diffusers for inference based on StableDiffusion-1.4/1.5.
34+
35+
```python
36+
import torch
37+
from diffusers import MotionAdapter, AnimateDiffPipeline, DDIMScheduler
38+
from diffusers.utils import export_to_gif
39+
40+
# Load the motion adapter
41+
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
42+
# load SD 1.5 based finetuned model
43+
model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
44+
pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
45+
scheduler = DDIMScheduler.from_pretrained(
46+
model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
47+
)
48+
pipe.scheduler = scheduler
49+
50+
# enable memory savings
51+
pipe.enable_vae_slicing()
52+
pipe.enable_model_cpu_offload()
53+
54+
output = pipe(
55+
prompt=(
56+
"masterpiece, bestquality, highlydetailed, ultradetailed, sunset, "
57+
"orange sky, warm lighting, fishing boats, ocean waves seagulls, "
58+
"rippling water, wharf, silhouette, serene atmosphere, dusk, evening glow, "
59+
"golden hour, coastal landscape, seaside scenery"
60+
),
61+
negative_prompt="bad quality, worse quality",
62+
num_frames=16,
63+
guidance_scale=7.5,
64+
num_inference_steps=25,
65+
generator=torch.Generator("cpu").manual_seed(42),
66+
)
67+
frames = output.frames[0]
68+
export_to_gif(frames, "animation.gif")
69+
```
70+
71+
Here are some sample outputs:
72+
73+
<table>
74+
<tr>
75+
<td><center>
76+
masterpiece, bestquality, sunset.
77+
<br>
78+
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/animatediff-realistic-doc.gif"
79+
alt="masterpiece, bestquality, sunset"
80+
style="width: 300px;" />
81+
</center></td>
82+
</tr>
83+
</table>
84+
85+
<Tip>
86+
87+
AnimateDiff tends to work better with finetuned Stable Diffusion models. If you plan on using a scheduler that can clip samples, make sure to disable it by setting `clip_sample=False` in the scheduler as this can also have an adverse effect on generated samples.
88+
89+
</Tip>
90+
91+
## AnimateDiffPipeline
92+
[[autodoc]] AnimateDiffPipeline
93+
- all
94+
- __call__
95+
- enable_freeu
96+
- disable_freeu
97+
- enable_vae_slicing
98+
- disable_vae_slicing
99+
- enable_vae_tiling
100+
- disable_vae_tiling
101+
102+
## AnimateDiffPipelineOutput
103+
104+
[[autodoc]] pipelines.animatediff.AnimateDiffPipelineOutput
105+
106+
## Available checkpoints
107+
108+
Motion Adapter checkpoints can be found under [guoyww](https://huggingface.co/guoyww/). These checkpoints are meant to work with any model based on Stable Diffusion 1.4/1.5
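As a hedged illustration of that claim (this snippet is not part of the committed page), the same adapter can be paired with a different SD 1.5-derived base model by changing only `model_id`; the prompt and seed are arbitrary:

```python
# Same adapter as in the usage example above; only the base Stable Diffusion checkpoint changes.
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
model_id = "runwayml/stable-diffusion-v1-5"  # any SD 1.4/1.5-derived checkpoint should work here

pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter)
pipe.scheduler = DDIMScheduler.from_pretrained(
    model_id, subfolder="scheduler", clip_sample=False, timestep_spacing="linspace", steps_offset=1
)
pipe.enable_vae_slicing()
pipe.enable_model_cpu_offload()

frames = pipe(
    prompt="a sailboat drifting across a calm lake at sunrise, soft mist, golden light",
    negative_prompt="bad quality, worse quality",
    num_frames=16,
    num_inference_steps=25,
    guidance_scale=7.5,
    generator=torch.Generator("cpu").manual_seed(0),
).frames[0]
export_to_gif(frames, "sailboat.gif")
```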

src/diffusers/__init__.py

+6
@@ -79,6 +79,7 @@
        "AutoencoderTiny",
        "ControlNetModel",
        "ModelMixin",
+        "MotionAdapter",
        "MultiAdapter",
        "PriorTransformer",
        "T2IAdapter",
@@ -88,6 +89,7 @@
        "UNet2DConditionModel",
        "UNet2DModel",
        "UNet3DConditionModel",
+        "UNetMotionModel",
        "VQModel",
    ]
)
@@ -195,6 +197,7 @@
    [
        "AltDiffusionImg2ImgPipeline",
        "AltDiffusionPipeline",
+        "AnimateDiffPipeline",
        "AudioLDM2Pipeline",
        "AudioLDM2ProjectionModel",
        "AudioLDM2UNet2DConditionModel",
@@ -440,6 +443,7 @@
        AutoencoderTiny,
        ControlNetModel,
        ModelMixin,
+        MotionAdapter,
        MultiAdapter,
        PriorTransformer,
        T2IAdapter,
@@ -449,6 +453,7 @@
        UNet2DConditionModel,
        UNet2DModel,
        UNet3DConditionModel,
+        UNetMotionModel,
        VQModel,
    )
    from .optimization import (
@@ -537,6 +542,7 @@
    from .pipelines import (
        AltDiffusionImg2ImgPipeline,
        AltDiffusionPipeline,
+        AnimateDiffPipeline,
        AudioLDM2Pipeline,
        AudioLDM2ProjectionModel,
        AudioLDM2UNet2DConditionModel,

src/diffusers/models/__init__.py

+2
@@ -35,6 +35,7 @@
    _import_structure["unet_2d"] = ["UNet2DModel"]
    _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
    _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
+    _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
    _import_structure["vq_model"] = ["VQModel"]

if is_flax_available():
@@ -60,6 +61,7 @@
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
    from .unet_3d_condition import UNet3DConditionModel
+    from .unet_motion_model import MotionAdapter, UNetMotionModel
    from .vq_model import VQModel

if is_flax_available():

src/diffusers/models/attention.py

+22
@@ -20,6 +20,7 @@
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU
from .attention_processor import Attention
+from .embeddings import SinusoidalPositionalEmbedding
from .lora import LoRACompatibleLinear
from .normalization import AdaLayerNorm, AdaLayerNormZero

@@ -96,6 +97,10 @@ class BasicTransformerBlock(nn.Module):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
+        positional_embeddings (`str`, *optional*, defaults to `None`):
+            The type of positional embeddings to apply.
+        num_positional_embeddings (`int`, *optional*, defaults to `None`):
+            The maximum number of positional embeddings to apply.
    """

    def __init__(
@@ -115,6 +120,8 @@ def __init__(
        norm_type: str = "layer_norm",
        final_dropout: bool = False,
        attention_type: str = "default",
+        positional_embeddings: Optional[str] = None,
+        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
@@ -128,6 +135,16 @@
                f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
            )

+        if positional_embeddings and (num_positional_embeddings is None):
+            raise ValueError(
+                "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined."
+            )
+
+        if positional_embeddings == "sinusoidal":
+            self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
+        else:
+            self.pos_embed = None
+
        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        if self.use_ada_layer_norm:
@@ -207,6 +224,9 @@ def forward(
        else:
            norm_hidden_states = self.norm1(hidden_states)

+        if self.pos_embed is not None:
+            norm_hidden_states = self.pos_embed(norm_hidden_states)
+
        # 1. Retrieve lora scale.
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

@@ -234,6 +254,8 @@
        norm_hidden_states = (
            self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
        )
+        if self.pos_embed is not None:
+            norm_hidden_states = self.pos_embed(norm_hidden_states)

        attn_output = self.attn2(
            norm_hidden_states,
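The net effect of these changes is easier to see outside the diffusers class hierarchy. Below is a minimal, self-contained PyTorch sketch of the same pattern — LayerNorm, then a sinusoidal positional embedding, then self-attention over the frame axis — using `nn.MultiheadAttention` as a stand-in for the library's `Attention` block; all sizes and shapes are illustrative assumptions:

```python
import math
import torch
import torch.nn as nn


class SinusoidalPE(nn.Module):
    """Fixed sin/cos table added to a (batch, seq_len, dim) sequence, as in the diff above."""

    def __init__(self, dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(1, max_seq_length, dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)

    def forward(self, x):
        return x + self.pe[:, : x.shape[1]]


class TemporalSelfAttentionBlock(nn.Module):
    """Norm -> positional embedding -> self-attention, mirroring the new pos_embed hook."""

    def __init__(self, dim: int, num_heads: int, max_frames: int = 32):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.pos_embed = SinusoidalPE(dim, max_seq_length=max_frames)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, hidden_states):
        # Normalize first, then inject frame-position information, then attend (residual add).
        norm_hidden_states = self.pos_embed(self.norm(hidden_states))
        attn_output, _ = self.attn(norm_hidden_states, norm_hidden_states, norm_hidden_states)
        return hidden_states + attn_output


block = TemporalSelfAttentionBlock(dim=320, num_heads=8)
frames = torch.randn(2, 16, 320)   # (batch, num_frames, channels)
print(block(frames).shape)         # torch.Size([2, 16, 320])
```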

src/diffusers/models/embeddings.py

+27
@@ -251,6 +251,33 @@ def forward(self, x):
        return out


+class SinusoidalPositionalEmbedding(nn.Module):
+    """Apply positional information to a sequence of embeddings.
+
+    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
+    them.
+
+    Args:
+        embed_dim (`int`): Dimension of the positional embedding.
+        max_seq_length (`int`): Maximum sequence length over which to apply positional embeddings.
+    """
+
+    def __init__(self, embed_dim: int, max_seq_length: int = 32):
+        super().__init__()
+        position = torch.arange(max_seq_length).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
+        pe = torch.zeros(1, max_seq_length, embed_dim)
+        pe[0, :, 0::2] = torch.sin(position * div_term)
+        pe[0, :, 1::2] = torch.cos(position * div_term)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        _, seq_length, _ = x.shape
+        x = x + self.pe[:, :seq_length]
+        return x
+
+
class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
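A quick, hedged sanity check of the new embedding (this assumes the class is importable as `diffusers.models.embeddings.SinusoidalPositionalEmbedding` once the commit lands; the shapes are illustrative):

```python
import torch
from diffusers.models.embeddings import SinusoidalPositionalEmbedding

pe = SinusoidalPositionalEmbedding(embed_dim=64, max_seq_length=32)

x = torch.zeros(2, 16, 64)  # (batch_size, seq_length, embed_dim); here seq_length = 16 frames
out = pe(x)

print(out.shape)  # torch.Size([2, 16, 64]) — positional information is added, shape is unchanged
# With a zero input, position 0 receives sin(0) = 0 on even channels and cos(0) = 1 on odd channels:
print(out[0, 0, 0].item(), out[0, 0, 1].item())  # 0.0 1.0
```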

src/diffusers/models/transformer_temporal.py

+8
@@ -59,6 +59,10 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin):
            Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
        double_self_attention (`bool`, *optional*):
            Configure if each `TransformerBlock` should contain two self-attention layers.
+        positional_embeddings (`str`, *optional*):
+            The type of positional embeddings to apply to the sequence input.
+        num_positional_embeddings (`int`, *optional*):
+            The maximum length of the sequence over which to apply positional embeddings.
    """

    @register_to_config
@@ -77,6 +81,8 @@ def __init__(
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        double_self_attention: bool = True,
+        positional_embeddings: Optional[str] = None,
+        num_positional_embeddings: Optional[int] = None,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
@@ -101,6 +107,8 @@
                attention_bias=attention_bias,
                double_self_attention=double_self_attention,
                norm_elementwise_affine=norm_elementwise_affine,
+                positional_embeddings=positional_embeddings,
+                num_positional_embeddings=num_positional_embeddings,
            )
            for d in range(num_layers)
        ]
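For completeness, here is an illustrative construction of the temporal transformer with the new positional-embedding options turned on. The hyperparameters and the expected input layout (frames folded into the batch axis) are assumptions chosen to be self-consistent, not values taken from the commit:

```python
import torch
from diffusers.models.transformer_temporal import TransformerTemporalModel

temporal_block = TransformerTemporalModel(
    num_attention_heads=8,
    attention_head_dim=40,                # inner dim = 8 * 40 = 320, matching in_channels
    in_channels=320,
    num_layers=1,
    norm_num_groups=32,
    positional_embeddings="sinusoidal",   # new option plumbed through by this commit
    num_positional_embeddings=32,         # max number of frames the sin/cos table covers
)

num_frames = 16
# Frames are assumed to be folded into the batch axis: (batch * num_frames, channels, h, w)
hidden_states = torch.randn(2 * num_frames, 320, 8, 8)
with torch.no_grad():
    out = temporal_block(hidden_states, num_frames=num_frames).sample
print(out.shape)  # torch.Size([32, 320, 8, 8]) — same shape as the input
```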
