Merge pull request #328 from alibaba/diffsynth
Diffsynth
chywang committed Aug 8, 2023
2 parents c058aae + 4208107 commit 1a89b04
Showing 33 changed files with 16,907 additions and 0 deletions.
156 changes: 156 additions & 0 deletions diffusion/DiffSynth/DiffSynth/attention.py
@@ -0,0 +1,156 @@
from typing import Optional
from diffusers import UNet2DConditionModel, ControlNetModel, AutoencoderKL
from diffusers.models.attention_processor import Attention, AttnProcessor
from einops import rearrange


class CrossFrameAttention(Attention):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: Optional[str] = None,
cross_attention_norm_num_groups: int = 32,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
_from_deprecated_attn_block=False,
processor: Optional["AttnProcessor"] = None,
):
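        # All arguments are passed through unchanged; the cross-frame behavior lives in forward().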
super().__init__(
query_dim,
cross_attention_dim,
heads,
dim_head,
dropout,
bias,
upcast_attention,
upcast_softmax,
cross_attention_norm,
cross_attention_norm_num_groups,
added_kv_proj_dim,
norm_num_groups,
spatial_norm_dim,
out_bias,
scale_qk,
only_cross_attention,
eps,
rescale_output_factor,
residual_connection,
_from_deprecated_attn_block,
processor,
)


@classmethod
def from_unet_attention(cls, attn: Attention):
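        # Read the layer dimensions from the trained projection weights so the replacement matches exactly.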
state_dict = attn.state_dict()
inner_dim, query_dim = state_dict["to_q.weight"].shape
cross_attention_dim = state_dict["to_k.weight"].shape[1]
heads = attn.heads
dim_head = inner_dim // heads
cross_frame_attn = cls(
query_dim=query_dim,
cross_attention_dim=cross_attention_dim,
heads=heads,
dim_head=dim_head
)
cross_frame_attn.load_state_dict(state_dict)
cross_frame_attn.to(
device=state_dict["to_q.weight"].device,
dtype=state_dict["to_q.weight"].dtype,
)
return cross_frame_attn


@classmethod
def from_vae_attention(cls, attn: Attention):
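        # VAE attention blocks use bias, upcast softmax, group norm, and a residual connection; mirror those flags.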
state_dict = attn.state_dict()
inner_dim, query_dim = state_dict["to_q.weight"].shape
cross_attention_dim = state_dict["to_k.weight"].shape[1]
heads = attn.heads
dim_head = inner_dim // heads
cross_frame_attn = cls(
query_dim=query_dim,
cross_attention_dim=cross_attention_dim,
heads=heads,
dim_head=dim_head,
bias=True,
upcast_softmax=True,
norm_num_groups=32,
eps=1e-06,
residual_connection=True,
_from_deprecated_attn_block=True
)
cross_frame_attn.load_state_dict(state_dict)
cross_frame_attn.to(
device=state_dict["to_q.weight"].device,
dtype=state_dict["to_q.weight"].dtype,
)
return cross_frame_attn


def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        if encoder_hidden_states is not None:
            raise ValueError("encoder_hidden_states must be None in CrossFrameAttention")
B = hidden_states.shape[0]
        if len(hidden_states.shape) == 3:
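            # Fold every frame (the batch axis) into one token sequence so self-attention spans all frames.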
hidden_states = rearrange(hidden_states, "B N D -> 1 (B N) D")
hidden_states = self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states = rearrange(hidden_states, "1 (B N) D -> B N D", B=B)
        elif len(hidden_states.shape) == 4:
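            # For 4-D feature maps, concatenate frames along the width axis so attention spans all frames.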
hidden_states = rearrange(hidden_states, "B D H W -> 1 D H (B W)")
hidden_states = self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states = rearrange(hidden_states, "1 D H (B W) -> B D H W", B=B)
else:
raise ValueError(f"The shape of hidden_states is {hidden_states.shape}.")
return hidden_states


def set_cross_frame_attention_unet(model):
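    # Recursively replace every self-attention module (named "attn1") with CrossFrameAttention.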
for module_name, module in model.named_children():
if isinstance(module, Attention) and module_name == "attn1":
setattr(model, module_name, CrossFrameAttention.from_unet_attention(module))
else:
            set_cross_frame_attention_unet(module)


def set_cross_frame_attention_vae(model):
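    # The VAE's attention blocks are all self-attention, so every Attention module is replaced.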
for module_name, module in model.named_children():
if isinstance(module, Attention):
setattr(model, module_name, CrossFrameAttention.from_vae_attention(module))
else:
            set_cross_frame_attention_vae(module)


def set_cross_frame_attention(model):
    if isinstance(model, (UNet2DConditionModel, ControlNetModel)):
set_cross_frame_attention_unet(model)
elif isinstance(model, AutoencoderKL):
set_cross_frame_attention_vae(model)
else:
        raise ValueError("Unsupported model architecture.")
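
For context, a minimal usage sketch (not part of this commit; the checkpoint name and the DiffSynth.attention import path are assumptions for illustration):

import torch
from diffusers import UNet2DConditionModel
from DiffSynth.attention import CrossFrameAttention, set_cross_frame_attention

# Load a UNet and swap its self-attention layers in place.
# "runwayml/stable-diffusion-v1-5" is an illustrative checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
set_cross_frame_attention(unet)
assert any(isinstance(m, CrossFrameAttention) for m in unet.modules())
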
76 changes: 76 additions & 0 deletions diffusion/DiffSynth/DiffSynth/controlnet_processors/__init__.py
@@ -0,0 +1,76 @@
from .video_level_processors import VideoControlnetImageProcesserOpenpose, VideoControlnetImageProcesserDepth
from controlnet_aux import ContentShuffleDetector, PidiNetDetector, HEDdetector, OpenposeDetector, MidasDetector
import torch, transformers
import numpy as np
from PIL import Image


class ControlnetImageProcesserDepth:
def __init__(self, model_path="Intel/dpt-large", device=torch.device("cuda:0"), threshold=None):
self.depth_estimator = transformers.pipeline(task="depth-estimation", model=model_path, device=device)
self.threshold = threshold

def __call__(self, image):
image = self.depth_estimator(image)['depth']
image = np.array(image)
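        # Expand the single-channel depth map to a 3-channel image for the ControlNet input.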
image = image[:, :, None]
image = np.concatenate([image, image, image], axis=2)
if self.threshold is not None:
            image[image < self.threshold] = 0
image = Image.fromarray(image)
return image


class ControlnetImageProcesserHED:
def __init__(self, model_path="lllyasviel/Annotators", detect_resolution=None, device="cuda"):
self.apply_softedge = HEDdetector.from_pretrained(model_path)
self.apply_softedge.netNetwork = self.apply_softedge.netNetwork.to(device)
self.detect_resolution = detect_resolution

def __call__(self, image):
detect_resolution = self.detect_resolution
if detect_resolution is None:
detect_resolution = min(image.size)
image = self.apply_softedge(
image,
detect_resolution=detect_resolution,
image_resolution=min(image.size)
)
return image


class ControlnetImageProcesserShuffle:
def __init__(self, seed=0):
self.seed = seed
self.processor = ContentShuffleDetector()

def __call__(self, image):
np.random.seed(self.seed)
image = self.processor(image)
return image


class ControlnetImageProcesserPose:
def __init__(self, detect_resolution=None, device="cuda"):
self.processor = OpenposeDetector.from_pretrained('lllyasviel/ControlNet')
self.detect_resolution = detect_resolution

def __call__(self, image):
detect_resolution = self.detect_resolution
if detect_resolution is None:
detect_resolution = min(image.size)
image = self.processor(
image,
detect_resolution=detect_resolution,
image_resolution=min(image.size),
hand_and_face=True
)
return image


class ControlnetImageProcesserTile:
def __init__(self, detect_resolution=None, device="cuda"):
pass

def __call__(self, image):
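        # The tile ControlNet conditions on the raw frame itself, so it is passed through unchanged.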
return image
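
A short usage sketch for these processors (hypothetical; the frame path and the DiffSynth.controlnet_processors import path are assumptions): each processor maps one PIL frame to the control image its ControlNet expects.

import torch
from PIL import Image
from DiffSynth.controlnet_processors import ControlnetImageProcesserDepth

# Turn a single video frame into a depth control image ("frame_0001.png" is illustrative).
processor = ControlnetImageProcesserDepth(device=torch.device("cuda:0"))
frame = Image.open("frame_0001.png").convert("RGB")
control_image = processor(frame)  # 3-channel PIL depth map, same size as the input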