Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AnimateDiff 2.0.0-a ControlNet part #2661

Merged
merged 9 commits into from
Mar 2, 2024
18 changes: 18 additions & 0 deletions scripts/batch_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def __init__(self):
self.postprocess_batch_callbacks = [self.on_postprocess_batch]

def img2img_process_batch_hijack(self, p, *args, **kwargs):
try:
from scripts.animatediff_utils import get_animatediff_arg
ad_params = get_animatediff_arg(p)
if ad_params and ad_params.enable:
ad_params.is_i2i_batch = True
from scripts.animatediff_i2ibatch import animatediff_i2i_batch
return animatediff_i2i_batch(p, *args, **kwargs)
except ImportError:
pass

cn_is_batch, batches, output_dir, _ = get_cn_batches(p)
if not cn_is_batch:
return getattr(img2img, '__controlnet_original_process_batch')(p, *args, **kwargs)
Expand All @@ -31,6 +41,14 @@ def img2img_process_batch_hijack(self, p, *args, **kwargs):
self.dispatch_callbacks(self.postprocess_batch_callbacks, p)

def processing_process_images_hijack(self, p, *args, **kwargs):
try:
from scripts.animatediff_utils import get_animatediff_arg
ad_params = get_animatediff_arg(p)
if ad_params and ad_params.enable:
return getattr(processing, '__controlnet_original_process_images_inner')(p, *args, **kwargs)
except ImportError:
pass

if self.is_batch:
# we are in img2img batch tab, do a single batch iteration
return self.process_images_cn_batch(p, *args, **kwargs)
Expand Down
17 changes: 11 additions & 6 deletions scripts/controlmodel_ipadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@ class ImageEmbed(NamedTuple):
"""Image embed for a single image."""
cond_emb: torch.Tensor
uncond_emb: torch.Tensor
bypass_average: bool = False

def eval(self, cond_mark: torch.Tensor) -> torch.Tensor:
assert cond_mark.ndim == 4
assert self.cond_emb.ndim == self.uncond_emb.ndim == 3
assert self.cond_emb.shape[0] == self.uncond_emb.shape[0] == 1
assert self.uncond_emb.shape[0] == 1 or self.cond_emb.shape[0] == self.uncond_emb.shape[0]
assert self.cond_emb.shape[0] == 1 or self.cond_emb.shape[0] == cond_mark.shape[0]
cond_mark = cond_mark[:, :, :, 0].to(self.cond_emb)
device = cond_mark.device
dtype = cond_mark.dtype
Expand All @@ -26,7 +28,7 @@ def eval(self, cond_mark: torch.Tensor) -> torch.Tensor:
)

def average_of(*args: List[Tuple[torch.Tensor, torch.Tensor]]) -> "ImageEmbed":
conds, unconds = zip(*args)
conds, unconds, _ = zip(*args)
def average_tensors(tensors: List[torch.Tensor]) -> torch.Tensor:
return torch.sum(torch.stack(tensors), dim=0) / len(tensors)
return ImageEmbed(average_tensors(conds), average_tensors(unconds))
Expand Down Expand Up @@ -603,11 +605,14 @@ def hook(self, model, preprocessor_outputs, weight, start, end, dtype=torch.floa
self.dtype = dtype

self.ipadapter.to(device, dtype=self.dtype)
if isinstance(preprocessor_outputs, (list, tuple)):
preprocessor_outputs = preprocessor_outputs
if getattr(preprocessor_outputs, "bypass_average", False):
self.image_emb = preprocessor_outputs
else:
preprocessor_outputs = [preprocessor_outputs]
self.image_emb = ImageEmbed.average_of(*[self.get_image_emb(o) for o in preprocessor_outputs])
if isinstance(preprocessor_outputs, (list, tuple)):
preprocessor_outputs = preprocessor_outputs
else:
preprocessor_outputs = [preprocessor_outputs]
self.image_emb = ImageEmbed.average_of(*[self.get_image_emb(o) for o in preprocessor_outputs])
# From https://github.com/laksjdjf/IPAdapter-ComfyUI
if not self.sdxl:
number = 0 # index of to_kvs
Expand Down
186 changes: 159 additions & 27 deletions scripts/controlnet.py

Large diffs are not rendered by default.

16 changes: 16 additions & 0 deletions scripts/controlnet_model_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from scripts.adapter import PlugableAdapter, Adapter, StyleAdapter, Adapter_light
from scripts.controlnet_lllite import PlugableControlLLLite
from scripts.cldm import PlugableControlModel
from scripts.controlnet_sparsectrl import PlugableSparseCtrlModel
from scripts.controlmodel_ipadapter import PlugableIPAdapter
from scripts.logging import logger
from scripts.controlnet_diffusers import convert_from_diffuser_state_dict
Expand Down Expand Up @@ -132,6 +133,21 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel:
network.to(devices.dtype_unet)
return ControlModel(network, ControlModelType.ControlLoRA)

if "down_blocks.0.motion_modules.0.temporal_transformer.norm.weight" in state_dict: # sparsectrl
config = copy.deepcopy(controlnet_default_config)
if "input_hint_block.0.weight" in state_dict: # rgb
config['use_simplified_condition_embedding'] = True
config['conditioning_channels'] = 5
else: # scribble
config['use_simplified_condition_embedding'] = False
config['conditioning_channels'] = 4

config['use_fp16'] = devices.dtype_unet == torch.float16

network = PlugableSparseCtrlModel(config, state_dict)
network.to(devices.dtype_unet)
return ControlModel(network, ControlModelType.SparseCtrl)

if "controlnet_cond_embedding.conv_in.weight" in state_dict: # diffusers
state_dict = convert_from_diffuser_state_dict(state_dict)

Expand Down
99 changes: 99 additions & 0 deletions scripts/controlnet_sparsectrl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Tuple, List

import torch
from torch import nn
from torch.nn import functional as F

from scripts.cldm import PlugableControlModel, ControlNet, zero_module, conv_nd, TimestepEmbedSequential

class PlugableSparseCtrlModel(PlugableControlModel):
    """Pluggable wrapper that owns a :class:`SparseCtrl` network.

    Args:
        config: Keyword arguments forwarded verbatim to ``SparseCtrl``.
        state_dict: Optional checkpoint weights. When given, they are handed
            to ``SparseCtrl.load_state_dict`` right after construction.
    """

    def __init__(self, config, state_dict=None):
        # Only nn.Module's initialiser is invoked here;
        # PlugableControlModel.__init__ is not called by this class.
        nn.Module.__init__(self)
        self.config = config

        # The network is built and kept on the CPU.
        sparse_ctrl = SparseCtrl(**config).cpu()
        self.control_model = sparse_ctrl
        if state_dict is not None:
            sparse_ctrl.load_state_dict(state_dict, strict=False)

        self.gpu_component = None


class CondEmbed(nn.Module):
    """Convolutional embedding for the conditioning input.

    Layout: an input convolution, then for every consecutive pair of widths
    in ``block_out_channels`` a same-width convolution followed by a
    stride-2 convolution that moves to the next width, and finally an output
    convolution wrapped in ``zero_module``. SiLU is applied after the input
    convolution and after every intermediate convolution, but not after the
    output convolution.

    Args:
        dims: Dimensionality argument passed through to ``conv_nd``.
        conditioning_embedding_channels: Channel count of the produced embedding.
        conditioning_channels: Channel count of the incoming conditioning tensor.
        block_out_channels: Widths of the intermediate stages.
    """

    def __init__(
        self,
        dims: int,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        super().__init__()

        first_width = block_out_channels[0]
        self.conv_in = conv_nd(dims, conditioning_channels, first_width, kernel_size=3, padding=1)

        # Two convolutions per transition: keep the width, then halve the
        # resolution (stride 2) while switching to the next width.
        stages = []
        for width_in, width_out in zip(block_out_channels, block_out_channels[1:]):
            stages.append(conv_nd(dims, width_in, width_in, kernel_size=3, padding=1))
            stages.append(conv_nd(dims, width_in, width_out, kernel_size=3, padding=1, stride=2))
        self.blocks = nn.ModuleList(stages)

        last_width = block_out_channels[-1]
        self.conv_out = zero_module(
            conv_nd(dims, last_width, conditioning_embedding_channels, kernel_size=3, padding=1)
        )

    def forward(self, conditioning):
        """Return the embedding computed from ``conditioning``."""
        hidden = F.silu(self.conv_in(conditioning))
        for layer in self.blocks:
            hidden = F.silu(layer(hidden))
        return self.conv_out(hidden)


class SparseCtrl(ControlNet):
    """ControlNet variant with its own hint embedder and injected motion modules.

    Args:
        use_simplified_condition_embedding: When true, the hint block is a
            single ``zero_module``-wrapped convolution; otherwise it is a
            :class:`CondEmbed` stack.
        conditioning_channels: Channel count of the hint tensor fed to the
            hint block.
        **kwargs: Forwarded to ``ControlNet.__init__``. ``model_channels``
            (default 320 when absent) is also used as the hint block's
            output width.
    """

    def __init__(self, use_simplified_condition_embedding=True, conditioning_channels=4, **kwargs):
        # hint_channels is unused here because input_hint_block is replaced
        # below; it is set to 1 only so the parent constructor does not fail.
        super().__init__(hint_channels=1, **kwargs)
        self.use_simplified_condition_embedding = use_simplified_condition_embedding

        model_channels = kwargs.get("model_channels", 320)
        if use_simplified_condition_embedding:
            hint_embedder = zero_module(
                conv_nd(self.dims, conditioning_channels, model_channels, kernel_size=3, padding=1)
            )
        else:
            hint_embedder = CondEmbed(
                self.dims,
                model_channels,
                conditioning_channels=conditioning_channels,
            )
        self.input_hint_block = TimestepEmbedSequential(hint_embedder)

    def load_state_dict(self, state_dict, strict=False):
        """Load ControlNet weights and attach the bundled motion modules.

        Keys containing ``"motion_modules"`` are loaded into a freshly built
        ``MotionWrapper``; every other key is loaded into this network.

        NOTE(review): the ``strict`` argument is accepted but not consulted;
        both loads below always run with ``strict=True``. Nothing is returned.
        """
        mm_dict = {key: value for key, value in state_dict.items() if "motion_modules" in key}
        cn_dict = {key: value for key, value in state_dict.items() if "motion_modules" not in key}

        super().load_state_dict(cn_dict, strict=True)

        # Imported at call time rather than at module import time.
        from scripts.animatediff_mm import MotionWrapper, MotionModuleType
        sparsectrl_mm = MotionWrapper("", "", MotionModuleType.SparseCtrl)
        sparsectrl_mm.load_state_dict(mm_dict, strict=True)

        # Eight input blocks receive one motion module each; consecutive
        # pairs come from the same down block (two modules per down block).
        for position, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
            block_no, slot_no = divmod(position, 2)
            motion_module = sparsectrl_mm.down_blocks[block_no].motion_modules[slot_no]
            self.input_blocks[unet_idx].append(motion_module)

    @staticmethod
    def create_cond_mask(control_image_index: List[int], control_image_latents: torch.Tensor, video_length: int):
        """Scatter control latents over ``video_length`` frames and append a mask channel.

        Rows listed in ``control_image_index`` receive the first
        ``len(control_image_index)`` entries of ``control_image_latents``;
        all other rows stay zero. A single-channel mask holding 1.0 on those
        same rows is concatenated along dim 1.
        """
        per_frame_shape = control_image_latents.shape[1:]
        hint_cond = control_image_latents.new_zeros((video_length, *per_frame_shape))
        hint_cond[control_image_index] = control_image_latents[:len(control_image_index)]

        hint_cond_mask = control_image_latents.new_zeros((hint_cond.shape[0], 1, *hint_cond.shape[2:]))
        hint_cond_mask[control_image_index] = 1.0

        return torch.cat([hint_cond, hint_cond_mask], dim=1)

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        """Run the parent forward pass with ``x`` replaced by zeros of the same shape."""
        zeroed_x = torch.zeros_like(x, device=x.device)
        return super().forward(zeroed_x, hint, timesteps, context, y=y, **kwargs)
2 changes: 1 addition & 1 deletion scripts/controlnet_version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from scripts.logging import logger

version_flag = 'v1.1.440'
version_flag = 'v1.1.441'

logger.info(f"ControlNet {version_flag}")
# A simple trick to tell whether the user has updated and whether they have restarted the terminal.
Expand Down
1 change: 1 addition & 0 deletions scripts/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class ControlModelType(Enum):
IPAdapter = "IPAdapter, Hu Ye"
Controlllite = "Controlllite, Kohya"
InstantID = "InstantID, Qixun Wang"
SparseCtrl = "SparseCtrl, Yuwei Guo"

def is_controlnet(self) -> bool:
"""Returns whether the control model should be treated as ControlNet."""
Expand Down
4 changes: 4 additions & 0 deletions scripts/global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,10 @@ def select_control_type(
filtered_preprocessor_list += [
x for x in preprocessor_list if "invert" in x.lower()
]
if pattern in ["sparsectrl"]:
filtered_preprocessor_list += [
x for x in preprocessor_list if "scribble" in x.lower()
]
filtered_model_list = [
model for model in all_models
if model.lower() == "none" or
Expand Down
19 changes: 14 additions & 5 deletions scripts/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from scripts.logging import logger
from scripts.enums import ControlModelType, AutoMachine, HiResFixOption
from scripts.controlmodel_ipadapter import ImageEmbed
from scripts.controlnet_sparsectrl import SparseCtrl
from modules import devices, lowvram, shared, scripts

from ldm.modules.diffusionmodules.util import timestep_embedding, make_beta_schedule
Expand Down Expand Up @@ -384,17 +385,25 @@ def call_vae_using_process(p, x, batch_size=None, mask=None):
vae_output = vae_cache.get(x)
if vae_output is None:
with devices.autocast():
vae_output = p.sd_model.encode_first_stage(x)
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
vae_output = torch.stack([
p.sd_model.get_first_stage_encoding(
p.sd_model.encode_first_stage(torch.unsqueeze(img, 0).to(device=devices.device))
)[0].to(img.device)
for img in x
])
if torch.all(torch.isnan(vae_output)).item():
logger.info('ControlNet find Nans in the VAE encoding. \n '
'Now ControlNet will automatically retry.\n '
'To always start with 32-bit VAE, use --no-half-vae commandline flag.')
devices.dtype_vae = torch.float32
x = x.to(devices.dtype_vae)
p.sd_model.first_stage_model.to(devices.dtype_vae)
vae_output = p.sd_model.encode_first_stage(x)
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
vae_output = torch.stack([
p.sd_model.get_first_stage_encoding(
p.sd_model.encode_first_stage(torch.unsqueeze(img, 0).to(device=devices.device))
)[0].to(img.device)
for img in x
])
vae_cache.set(x, vae_output)
logger.info(f'ControlNet used {str(devices.dtype_vae)} VAE to encode {vae_output.shape}.')
latent = vae_output
Expand Down Expand Up @@ -571,7 +580,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
controlnet_context = context

# ControlNet inpaint protocol
if hint.shape[1] == 4:
if hint.shape[1] == 4 and not isinstance(control_model, SparseCtrl):
c = hint[:, 0:3, :, :]
m = hint[:, 3:4, :, :]
m = (m > 0.5).float()
Expand Down
1 change: 1 addition & 0 deletions scripts/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1316,6 +1316,7 @@ def run_model(self, img, res=512, **kwargs):
"T2I-Adapter": "none",
"IP-Adapter": "ip-adapter_clip_sd15",
"Instant_ID": "instant_id",
"SparseCtrl": "none",
}

preprocessor_filters_aliases = {
Expand Down
Loading