In [12]:
from diffusers import StableDiffusionPipeline
import torch
from transformers import CLIPTokenizer, CLIPTextModel
from torchvision import transforms
from transformers import CLIPImageProcessor
from PIL import Image
from diffusers.models.attention import Attention as CrossAttention
from datasets import load_dataset
import os
from tqdm import tqdm
from PIL import Image
from safetensors.torch import safe_open
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

In [13]:
class ImageProjModel(torch.nn.Module):
    """Projection Model"""

    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
        super().__init__()

        self.generator = None
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
            -1, self.clip_extra_context_tokens, self.cross_attention_dim
        )
        clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
        return clip_extra_context_tokens


In [14]:
class IPAdapterModule:
    def __init__(self, image_encoder_path, ip_ckpt, device, num_tokens=4, cross_attention_dim=768):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.ip_ckpt = ip_ckpt
        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim

        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        self.image_proj_model = self.init_proj()
        self.load_ip_adapter_weights()

    def init_proj(self):# или путь к вашей реализации
        image_proj_model = ImageProjModel(
            cross_attention_dim=self.cross_attention_dim,
            clip_embeddings_dim=self.image_encoder.config.projection_dim,
            clip_extra_context_tokens=self.num_tokens,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def load_ip_adapter_weights(self):
        if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(self.ip_ckpt, map_location="cpu")
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        self._attn_weights = state_dict["ip_adapter"]  # храним отдельно, применим в set_ip_adapter

    @torch.no_grad()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
            clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
        else:
            clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)

        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_ip_adapter(self, unet):
        from attention_processor import IPAttnProcessor, AttnProcessor  
        attn_procs = {}
        for name in unet.attn_processors.keys():
            if name.endswith("attn1.processor"):
                attn_procs[name] = AttnProcessor()
                continue

            cross_attention_dim = self.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]

            ip_proc = IPAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                num_tokens=self.num_tokens,
            ).to(self.device, dtype=torch.float16)
            attn_procs[name] = ip_proc

        unet.set_attn_processor(attn_procs)

        # загружаем веса IP-адаптера (после set_attn_processor)
        ip_layers = torch.nn.ModuleList(unet.attn_processors.values())
        ip_layers.load_state_dict(self._attn_weights)

    def set_scale(self, unet, scale):
        for attn_proc in unet.attn_processors.values():
            if hasattr(attn_proc, "scale"):
                attn_proc.scale = scale

In [26]:
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
device = "cuda"
# Создаём pipeline и получаем UNet
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None,).to("cuda")
ckpt_path = "C:/comfy/ComfyUI_windows_portable/ComfyUI/models/checkpoints/dreamshaper_8.safetensors"   # или путь к SDXL
unet = pipe.unet
state_dict = load_file(ckpt_path)  # возвращает словарь {param_name: tensor}

# Фильтруем и загружаем только веса UNet
unet_keys = [k for k in state_dict.keys() if k.startswith("model.diffusion_model.")]
mapped_state_dict = {
    k.replace("model.diffusion_model.", ""): v for k, v in state_dict.items() if k.startswith("model.diffusion_model.")
}
missing, unexpected = unet.load_state_dict(mapped_state_dict, strict=False)
print(f"Missing: {missing}, Unexpected: {unexpected}")

# Инициализируем модуль
ip_adapter = IPAdapterModule(
    image_encoder_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    ip_ckpt="C:/comfy/ComfyUI_windows_portable/ComfyUI/models/ipadapter/ip-adapter_sd15.safetensors",
    device="cuda",
    num_tokens=4,
    cross_attention_dim=768
)



Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


Missing: ['conv_in.weight', 'conv_in.bias', 'time_embedding.linear_1.weight', 'time_embedding.linear_1.bias', 'time_embedding.linear_2.weight', 'time_embedding.linear_2.bias', 'down_blocks.0.attentions.0.norm.weight', 'down_blocks.0.attentions.0.norm.bias', 'down_blocks.0.attentions.0.proj_in.weight', 'down_blocks.0.attentions.0.proj_in.bias', 'down_blocks.0.attentions.0.transformer_blocks.0.norm1.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.norm1.bias', 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_k.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_v.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0.bias', 'down_blocks.0.attentions.0.transformer_blocks.0.norm2.weight', 'down_blocks.0.attentions.0.transformer_blocks.0.norm2.bias', 'down_blocks.0.attentions.0.transformer_blocks.0.at

In [27]:
# 3. Применяем патч к UNet
ip_adapter.set_ip_adapter(pipe.unet)

# 4. Загрузка изображения для conditioning
image_path = "example.png"  # путь к conditioning изображению
image = Image.open(image_path).convert("RGB")

# 5. Получаем image prompt embeddings
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(pil_image=image)

# 6. Подготовка текстового промпта
prompt = "a majestic castle in the mountains at sunset,  "
negative_prompt = "blurry, low quality"

# 7. Получаем текстовые эмбеддинги (вне @torch.inference_mode, если в train loop)
with torch.no_grad():
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        device=device,
        do_classifier_free_guidance=True,
        num_images_per_prompt=1
    )

    # # 8. Объединяем текст и визуальные эмбеддинги
    # prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
    # negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

# 9. Генерация изображения
generator = torch.manual_seed(42)
output = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    guidance_scale=7.5,
    num_inference_steps=30,
    generator=generator
)

# 10. Сохраняем изображение
output.images[0].save("generated_with_ip_adapter.png")

  0%|          | 0/30 [00:00<?, ?it/s]

In [None]:
ip_adapter.set_ip_adapter(unet)


In [3]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.enable_attention_slicing()

Couldn't connect to the Hub: (MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /api/models/runwayml/stable-diffusion-v1-5 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x000001DF8008C050>: Failed to resolve \'huggingface.co\' ([Errno 11001] getaddrinfo failed)"))'), '(Request ID: ee3c77e8-2274-4225-bfe7-8a54e51be369)').
Will try to load from local cache.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder

prompt = "a majestic cat in a futuristic city"
tokens = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77)
text_embed = text_encoder(tokens.input_ids.to(device))[0]  # shape: [1, 77, 768]

In [5]:
def build_patch(block_to_conditioning: dict, weight: float = 1.0):
    def attn_patch(attn_module, encoder_hidden_states, **kwargs):
        # определяем, нужен ли патч
        block_name = getattr(attn_module, "block_name", None)
        if block_name not in block_to_conditioning:
            return encoder_hidden_states

        cond = block_to_conditioning[block_name]  # твой внешний conditioning
        if cond.shape != encoder_hidden_states.shape:
            raise ValueError("Shape mismatch in attention patch")
        
        # смешиваем conditioning и оригинальный контекст
        return (1 - weight) * encoder_hidden_states + weight * cond

    return attn_patch

In [6]:
from types import MethodType

def patch_unet_cross_attention(unet, conditioning, target_blocks=None, weight=1.0):
    """
    Подменяет encoder_hidden_states в CrossAttention модулях UNet.
    target_blocks — список имён модулей, которые мы хотим патчить, 
                    например ["mid_block.attentions.0", ...]. Если None — все.
    """
    for name, module in unet.named_modules():
        # патчим только CrossAttention
        if not isinstance(module, CrossAttention):
            continue
        # если указан список блоков — фильтруем по нему
        if target_blocks is not None:
            if not any(name.startswith(tb) for tb in target_blocks):
                continue

        # сохраняем старый forward
        old_forward = module.forward

        def make_new_forward(old_fwd):
            def new_forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
                # если есть encoder_hidden_states — смешиваем
                if encoder_hidden_states is not None:
                    # conditioning: [B, seq_len, D], encoder_hidden_states: same shape
                    encoder_hidden_states = (
                        (1 - weight) * encoder_hidden_states + weight * conditioning
                    )
                # вызываем оригинальный forward
                return old_fwd(hidden_states, encoder_hidden_states=encoder_hidden_states, **kwargs)
            return new_forward

        # связываем новый метод
        module.forward = MethodType(make_new_forward(old_forward), module)
        print(f"Patched CrossAttention at {name}")

In [8]:
# Пример использования:
device = torch.device("cuda")
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(device)
pipe.enable_attention_slicing()

# Получаем текстовый эмбеддинг
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder
prompt = "an astronaut riding a horse, oil painting style"
tokens = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=77)
text_embed = text_encoder(tokens.input_ids.to(device))[0]  # [1, 77, 768]

# Патчим только средний блок (можно указать несколько)
patch_unet_cross_attention(
    pipe.unet,
    conditioning=text_embed,
    target_blocks=["mid_block.attentions.0"],
    weight=1.0
)

# Генерация
generator = torch.manual_seed(42)
image = pipe(
    prompt=prompt,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=generator
).images[0]

image.save("patched_output.png")

Couldn't connect to the Hub: (MaxRetryError('HTTPSConnectionPool(host=\'huggingface.co\', port=443): Max retries exceeded with url: /api/models/runwayml/stable-diffusion-v1-5 (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x000001E0596D6210>: Failed to resolve \'huggingface.co\' ([Errno 11001] getaddrinfo failed)"))'), '(Request ID: b931a02d-85b3-4f9d-88df-8ec056ce6774)').
Will try to load from local cache.


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Patched CrossAttention at mid_block.attentions.0.transformer_blocks.0.attn1
Patched CrossAttention at mid_block.attentions.0.transformer_blocks.0.attn2


  0%|          | 0/30 [00:00<?, ?it/s]