-
Notifications
You must be signed in to change notification settings - Fork 251
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #328 from alibaba/diffsynth
Diffsynth
- Loading branch information
Showing
33 changed files
with
16,907 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
from typing import Optional | ||
from diffusers import UNet2DConditionModel, ControlNetModel, AutoencoderKL | ||
from diffusers.models.attention_processor import Attention, AttnProcessor | ||
from einops import rearrange | ||
|
||
|
||
class CrossFrameAttention(Attention):
    """Attention layer that folds the batch axis into the token axis.

    Before the attention processor runs, all batch entries are merged into a
    single sample, so the processor sees the tokens of every batch entry at
    once; afterwards the original batch layout is restored.  (Presumably the
    batch entries are video frames -- confirm against the callers.)
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block=False,
        processor: Optional["AttnProcessor"] = None,
    ):
        """Build the layer; every argument is handed unchanged to ``Attention``."""
        # Forward by keyword rather than by position: a positional call
        # silently binds values to the wrong parameters as soon as the base
        # class signature gains or reorders an argument.
        super().__init__(
            query_dim=query_dim,
            cross_attention_dim=cross_attention_dim,
            heads=heads,
            dim_head=dim_head,
            dropout=dropout,
            bias=bias,
            upcast_attention=upcast_attention,
            upcast_softmax=upcast_softmax,
            cross_attention_norm=cross_attention_norm,
            cross_attention_norm_num_groups=cross_attention_norm_num_groups,
            added_kv_proj_dim=added_kv_proj_dim,
            norm_num_groups=norm_num_groups,
            spatial_norm_dim=spatial_norm_dim,
            out_bias=out_bias,
            scale_qk=scale_qk,
            only_cross_attention=only_cross_attention,
            eps=eps,
            rescale_output_factor=rescale_output_factor,
            residual_connection=residual_connection,
            _from_deprecated_attn_block=_from_deprecated_attn_block,
            processor=processor,
        )

    @classmethod
    def _from_attention(cls, attn: Attention, **extra_kwargs):
        """Clone ``attn`` into a new instance of ``cls``, weights included.

        The layer sizes are read back from the shapes of the ``to_q`` and
        ``to_k`` projection weights; ``extra_kwargs`` are passed to the
        constructor on top of those sizes.
        """
        state_dict = attn.state_dict()
        q_weight = state_dict["to_q.weight"]
        inner_dim, query_dim = q_weight.shape
        cross_attention_dim = state_dict["to_k.weight"].shape[1]
        heads = attn.heads
        cross_frame_attn = cls(
            query_dim=query_dim,
            cross_attention_dim=cross_attention_dim,
            heads=heads,
            dim_head=inner_dim // heads,
            **extra_kwargs,
        )
        cross_frame_attn.load_state_dict(state_dict)
        # Match the device and dtype of the layer being replaced.
        cross_frame_attn.to(device=q_weight.device, dtype=q_weight.dtype)
        return cross_frame_attn

    @classmethod
    def from_unet_attention(cls, attn: Attention):
        """Return a ``CrossFrameAttention`` copy of ``attn`` using default options."""
        return cls._from_attention(attn)

    @classmethod
    def from_vae_attention(cls, attn: Attention):
        """Return a ``CrossFrameAttention`` copy of ``attn`` with the VAE options.

        Compared with :meth:`from_unet_attention` this enables projection
        biases, softmax upcasting, a 32-group norm, ``eps=1e-06``, the
        residual connection and the deprecated-attention-block flag.
        """
        return cls._from_attention(
            attn,
            bias=True,
            upcast_softmax=True,
            norm_num_groups=32,
            eps=1e-06,
            residual_connection=True,
            _from_deprecated_attn_block=True,
        )

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Run the attention processor over all batch entries jointly.

        Args:
            hidden_states: 3-D ``(B, N, D)`` or 4-D ``(B, D, H, W)`` input.
            encoder_hidden_states: must be ``None``.
            attention_mask: forwarded to the processor unchanged.
            **cross_attention_kwargs: forwarded to the processor unchanged.

        Returns:
            The processor output, rearranged back to the input's batch layout.

        Raises:
            Warning: if ``encoder_hidden_states`` is not ``None``.
            ValueError: if ``hidden_states`` is neither 3-D nor 4-D.
        """
        if encoder_hidden_states is not None:
            # Raised (not merely emitted) as in the original code, so callers
            # that catch ``Warning`` keep working.
            raise Warning("encoder_hidden_states is not None in CrossFrameAttention")
        batch_size = hidden_states.shape[0]
        rank = len(hidden_states.shape)
        if rank == 3:
            # Concatenate the tokens of all batch entries into one sequence.
            fold, unfold = "B N D -> 1 (B N) D", "1 (B N) D -> B N D"
        elif rank == 4:
            # Lay the batch entries side by side along the width axis.
            fold, unfold = "B D H W -> 1 D H (B W)", "1 D H (B W) -> B D H W"
        else:
            raise ValueError(f"The shape of hidden_states is {hidden_states.shape}.")
        hidden_states = rearrange(hidden_states, fold)
        hidden_states = self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
        return rearrange(hidden_states, unfold, B=batch_size)
|
||
|
||
def set_cross_frame_attention_unet(model):
    """Recursively replace every ``Attention`` child named ``attn1`` in place.

    Matching children are swapped for
    ``CrossFrameAttention.from_unet_attention(child)``; every other child is
    descended into.  Nothing is returned; ``model`` is mutated.
    """
    for child_name, child in model.named_children():
        is_target = child_name == "attn1" and isinstance(child, Attention)
        if not is_target:
            # Not a layer to replace: keep searching further down the tree.
            set_cross_frame_attention_unet(child)
            continue
        setattr(model, child_name, CrossFrameAttention.from_unet_attention(child))
|
||
|
||
def set_cross_frame_attention_vae(model):
    """Recursively replace every ``Attention`` child of ``model`` in place.

    Each ``Attention`` child is swapped for
    ``CrossFrameAttention.from_vae_attention(child)``; every other child is
    descended into.  Nothing is returned; ``model`` is mutated.
    """
    for child_name, child in model.named_children():
        if not isinstance(child, Attention):
            # Not an attention layer: keep searching further down the tree.
            set_cross_frame_attention_vae(child)
            continue
        setattr(model, child_name, CrossFrameAttention.from_vae_attention(child))
|
||
|
||
def set_cross_frame_attention(model):
    """Install cross-frame attention into ``model``, chosen by its type.

    ``UNet2DConditionModel`` and ``ControlNetModel`` instances go through
    ``set_cross_frame_attention_unet``; ``AutoencoderKL`` instances go
    through ``set_cross_frame_attention_vae``.

    Raises:
        Warning: if ``model`` is none of the supported types.
    """
    if isinstance(model, (UNet2DConditionModel, ControlNetModel)):
        set_cross_frame_attention_unet(model)
        return
    if isinstance(model, AutoencoderKL):
        set_cross_frame_attention_vae(model)
        return
    raise Warning("Unsupported model architecture.")
76 changes: 76 additions & 0 deletions
76
diffusion/DiffSynth/DiffSynth/controlnet_processors/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from .video_level_processors import VideoControlnetImageProcesserOpenpose, VideoControlnetImageProcesserDepth | ||
from controlnet_aux import ContentShuffleDetector, PidiNetDetector, HEDdetector, OpenposeDetector, MidasDetector | ||
import torch, transformers | ||
import numpy as np | ||
from PIL import Image | ||
|
||
|
||
class ControlnetImageProcesserDepth:
    """Turn an image into a three-channel depth map for ControlNet."""

    def __init__(self, model_path="Intel/dpt-large", device=torch.device("cuda:0"), threshold=None):
        """Create the depth-estimation pipeline.

        Args:
            model_path: model identifier handed to ``transformers.pipeline``.
            device: device handed to ``transformers.pipeline``.
            threshold: if not ``None``, depth values below it are zeroed.
        """
        self.threshold = threshold
        self.depth_estimator = transformers.pipeline(
            task="depth-estimation",
            model=model_path,
            device=device,
        )

    def __call__(self, image):
        """Return the estimated depth of ``image`` as a PIL image."""
        prediction = self.depth_estimator(image)
        depth = np.array(prediction['depth'])
        # Replicate the single depth channel three times along a new last axis.
        depth_rgb = np.repeat(depth[:, :, None], 3, axis=2)
        if self.threshold is not None:
            depth_rgb[depth_rgb < self.threshold] = 0
        return Image.fromarray(depth_rgb)
|
||
|
||
class ControlnetImageProcesserHED:
    """Turn an image into a HED soft-edge map for ControlNet."""

    def __init__(self, model_path="lllyasviel/Annotators", detect_resolution=None, device="cuda"):
        """Load the HED detector and move its network to ``device``.

        Args:
            model_path: identifier handed to ``HEDdetector.from_pretrained``.
            detect_resolution: detection resolution; ``None`` means the
                shorter side of each input image is used.
            device: device the detector network is moved to.
        """
        detector = HEDdetector.from_pretrained(model_path)
        detector.netNetwork = detector.netNetwork.to(device)
        self.apply_softedge = detector
        self.detect_resolution = detect_resolution

    def __call__(self, image):
        """Run the detector on ``image`` and return its output."""
        short_side = min(image.size)
        if self.detect_resolution is None:
            resolution = short_side
        else:
            resolution = self.detect_resolution
        return self.apply_softedge(
            image,
            detect_resolution=resolution,
            image_resolution=short_side,
        )
|
||
|
||
class ControlnetImageProcesserShuffle:
    """Apply ``ContentShuffleDetector`` with a fixed random seed."""

    def __init__(self, seed=0):
        """Store ``seed`` and create the shuffle detector."""
        self.processor = ContentShuffleDetector()
        self.seed = seed

    def __call__(self, image):
        """Return the processed ``image``.

        NumPy's global random generator is re-seeded with ``self.seed`` on
        every call, before the detector runs.
        """
        np.random.seed(self.seed)
        return self.processor(image)
|
||
|
||
class ControlnetImageProcesserPose:
    """Turn an image into an OpenPose annotation for ControlNet."""

    def __init__(self, detect_resolution=None, device="cuda"):
        """Load the OpenPose detector.

        Args:
            detect_resolution: detection resolution; ``None`` means the
                shorter side of each input image is used.
            device: accepted but not used by this constructor.
        """
        # NOTE(review): ``device`` is never applied to the detector here --
        # confirm whether the detector should be moved to it.
        self.detect_resolution = detect_resolution
        self.processor = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')

    def __call__(self, image):
        """Run the detector on ``image`` with ``hand_and_face=True``."""
        short_side = min(image.size)
        if self.detect_resolution is None:
            resolution = short_side
        else:
            resolution = self.detect_resolution
        return self.processor(
            image,
            detect_resolution=resolution,
            image_resolution=short_side,
            hand_and_face=True,
        )
|
||
|
||
class ControlnetImageProcesserTile:
    """Pass-through processor: the input image is returned unchanged."""

    def __init__(self, detect_resolution=None, device="cuda"):
        """Accept ``detect_resolution`` and ``device`` but use neither."""

    def __call__(self, image):
        """Return ``image`` as given."""
        return image
Oops, something went wrong.