Skip to content

Commit

Permalink
Cache IPAdapter instances to avoid expensive KV extraction on every generation
Browse files Browse the repository at this point in the history

lllyasviel#335
  • Loading branch information
Panchovix committed Jul 13, 2024
1 parent 2ee4f7f commit f902bc6
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 5 deletions.
9 changes: 9 additions & 0 deletions extensions-builtin/sd_forge_controlnet/scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from modules.ui_components import InputAccordion
from modules.api.api import decode_base64_to_image
import gradio as gr
import time

from lib_controlnet import global_state, external_code
from lib_controlnet.external_code import ControlNetUnit, InputMode
Expand Down Expand Up @@ -506,11 +507,17 @@ def process_unit_before_every_sampling(self,
params.model.positive_advanced_weighting = soft_weighting.copy()
params.model.negative_advanced_weighting = soft_weighting.copy()

model_process_start_time = time.perf_counter()
cond, mask = params.preprocessor.process_before_every_sampling(p, cond, mask, *args, **kwargs)
model_process_end_time = time.perf_counter() - model_process_start_time
logger.debug(f"CN Preprocessor {params.preprocessor.name}: {model_process_end_time:.2f}s.")

params.model.advanced_mask_weighting = mask
model_process_start_time = time.perf_counter()

params.model.process_before_every_sampling(p, cond, mask, *args, **kwargs)
model_process_end_time = time.perf_counter() - model_process_start_time
logger.debug(f"CN Model {type(params.model).__name__}: {model_process_end_time:.2f}s.")

logger.info(f"ControlNet Method {params.preprocessor.name} patched.")
return
Expand Down Expand Up @@ -593,6 +600,8 @@ def on_ui_settings():
{"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_model_cache_size", shared.OptionInfo(
5, "Model cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_ipadapter_cache_size", shared.OptionInfo(
5, "IPAdapter cache size (requires restart)", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}, section=section))
shared.opts.add_option("control_net_no_detectmap", shared.OptionInfo(
False, "Do not append detectmap to output", gr.Checkbox, {"interactive": True}, section=section))
shared.opts.add_option("control_net_detectmap_autosaving", shared.OptionInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import contextlib
import os
import math
import time
from cachetools import LRUCache

import ldm_patched.modules.utils
import ldm_patched.modules.model_management
Expand All @@ -17,6 +19,9 @@
import torchvision.transforms as TT

from lib_ipadapter.resampler import PerceiverAttention, FeedForward, Resampler
from modules import shared

from lib_controlnet.logging import logger

# set the models directory backward compatible
GLOBAL_MODELS_DIR = os.path.join(folder_paths.models_dir, "ipadapter")
Expand Down Expand Up @@ -259,6 +264,28 @@ def NPToTensor(image):
return out

class IPAdapter(nn.Module):
_cache = LRUCache(maxsize=shared.opts.data.get("control_net_ipadapter_cache_size", 5))

# Factory method that caches instances per model file and layer configuration
@classmethod
def create(cls, model_filename, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024,
           clip_embeddings_dim=1024, clip_extra_context_tokens=4,
           is_sdxl=False, is_plus=False, is_full=False,
           is_faceid=False, is_instant_id=False):
    """Return an instance for ``model_filename``, reusing a cached one when possible.

    The cache key is the file name together with every constructor
    argument, so a request for the same file with different dimensions or
    flags builds a fresh instance instead of receiving one that was
    constructed for another configuration.

    Args:
        model_filename: Name identifying the IP-Adapter checkpoint.
        ipadapter_model: Loaded model data handed to the constructor. It is
            not part of the cache key.
        cross_attention_dim, output_cross_attention_dim, clip_embeddings_dim,
        clip_extra_context_tokens, is_sdxl, is_plus, is_full, is_faceid,
        is_instant_id: Forwarded unchanged to the constructor and included
            in the cache key (assumed hashable, i.e. plain ints/bools).

    Returns:
        The cached or newly constructed instance.
    """
    cache_key = (
        model_filename,
        cross_attention_dim, output_cross_attention_dim,
        clip_embeddings_dim, clip_extra_context_tokens,
        is_sdxl, is_plus, is_full, is_faceid, is_instant_id,
    )

    # Indexing (rather than a bare membership test) marks the entry as
    # recently used in the LRU cache.
    if cache_key in cls._cache:
        logger.info(f"IPAdapter: Using cached layers for {model_filename}.")
        return cls._cache[cache_key]

    logger.info(f"IPAdapter: Creating new layer instance for {model_filename}.")
    instance = cls(ipadapter_model, cross_attention_dim, output_cross_attention_dim,
                   clip_embeddings_dim, clip_extra_context_tokens,
                   is_sdxl, is_plus, is_full, is_faceid, is_instant_id)

    # Only retain the instance when model management allows layer caching.
    if ldm_patched.modules.model_management.enable_ipadapter_layer_cache():
        cls._cache[cache_key] = instance

    return instance

def __init__(self, ipadapter_model, cross_attention_dim=1024, output_cross_attention_dim=1024,
clip_embeddings_dim=1024, clip_extra_context_tokens=4,
is_sdxl=False, is_plus=False, is_full=False,
Expand Down Expand Up @@ -612,9 +639,10 @@ def INPUT_TYPES(s):
FUNCTION = "apply_ipadapter"
CATEGORY = "ipadapter"

def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None, weight_type="original",
def apply_ipadapter(self, ipadapter, model_filename, model, weight, clip_vision=None, image=None, weight_type="original",
noise=None, embeds=None, attn_mask=None, start_at=0.0, end_at=1.0, unfold_batch=False,
insightface=None, faceid_v2=False, weight_v2=False, instant_id=False):
apply_ipadapter_start = time.perf_counter()

self.dtype = torch.float16 if ldm_patched.modules.model_management.should_use_fp16() else torch.float32
self.device = ldm_patched.modules.model_management.get_torch_device()
Expand Down Expand Up @@ -720,7 +748,8 @@ def apply_ipadapter(self, ipadapter, model, weight, clip_vision=None, image=None

clip_embeddings_dim = clip_embed.shape[-1]

self.ipadapter = IPAdapter(
self.ipadapter = IPAdapter.create(
model_filename,
ipadapter,
cross_attention_dim=cross_attention_dim,
output_cross_attention_dim=output_cross_attention_dim,
Expand Down Expand Up @@ -799,6 +828,9 @@ def modifier(cnet, x_noisy, t, cond, batched_number):
set_model_patch_replace(work_model, patch_kwargs, ("middle", 0, index))
patch_kwargs["number"] += 1

apply_ipadapter_time = time.perf_counter() - apply_ipadapter_start
logger.debug(f"IPAdapter apply_ipadapter time: {apply_ipadapter_time:.2f}s")

return (work_model, )

class IPAdapterApplyFaceID(IPAdapterApply):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,18 @@ def try_build_from_state_dict(state_dict, ckpt_path):
if "ip_adapter" not in model.keys() or len(model["ip_adapter"]) == 0:
return None

o = IPAdapterPatcher(model)

model_filename = Path(ckpt_path).name.lower()
o = IPAdapterPatcher(model, model_filename)
if 'v2' in model_filename:
o.faceid_v2 = True
o.weight_v2 = True

return o

def __init__(self, state_dict):
def __init__(self, state_dict, model_filename):
    """Keep the IP-Adapter state dict and the file name it came from.

    Args:
        state_dict: IP-Adapter model data, stored as ``self.ip_adapter``.
        model_filename: File name of the checkpoint, stored as
            ``self.model_filename``.
    """
    super().__init__()
    self.ip_adapter = state_dict
    self.model_filename = model_filename
    # Both v2 switches start disabled.
    self.faceid_v2 = self.weight_v2 = False
Expand All @@ -146,6 +146,7 @@ def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):

unet = opIPAdapterApply(
ipadapter=self.ip_adapter,
model_filename=self.model_filename,
model=unet,
weight=self.strength,
start_at=self.start_percent,
Expand Down
3 changes: 3 additions & 0 deletions ldm_patched/modules/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,9 @@ def free_memory(memory_required, device, keep_loaded=[]):
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()

def enable_ipadapter_layer_cache():
    """Report whether IPAdapter layer instances may be kept in a cache.

    Returns:
        bool: ``True`` only while the module-level ``vram_state`` equals
        ``VRAMState.HIGH_VRAM``.
    """
    high_vram_active = vram_state == VRAMState.HIGH_VRAM
    return high_vram_active

def load_models_gpu(models, memory_required=0):
global vram_state

Expand Down
1 change: 1 addition & 0 deletions requirements_versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ pillow-avif-plugin==1.4.3
albumentations==1.4.3
pydantic==1.10.15
diffusers==0.25.0
cachetools==5.3.2

0 comments on commit f902bc6

Please sign in to comment.