In [1]:
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

# Module-level registry filled by the forward hooks built by `hook_fn`:
# maps a module name to the attention map its processor last exposed as
# `processor.attn_map`.
# NOTE(review): neither AttnProcessor nor KVAttnProcessor in this notebook sets
# `attn_map`, so with those processors installed the hooks store nothing.
attn_maps = {}


def hook_fn(name):
    """Build a forward hook that moves ``module.processor.attn_map`` into ``attn_maps[name]``.

    The attribute is removed from the processor once stored, so a map is handed
    over at most once per forward pass. Processors without an ``attn_map``
    attribute are left untouched.
    """

    def forward_hook(module, input, output):
        processor = module.processor
        if not hasattr(processor, "attn_map"):
            return
        attn_maps[name] = processor.attn_map
        del processor.attn_map

    return forward_hook

def register_cross_attention_hook(unet):
    """Attach a ``hook_fn`` forward hook to every sub-module whose last dotted
    name component starts with ``attn2``.

    Hooks are registered in place; the same ``unet`` object is returned so the
    call can be chained.
    """
    targets = (
        (name, module)
        for name, module in unet.named_modules()
        if name.rsplit('.', 1)[-1].startswith('attn2')
    )
    for name, module in targets:
        module.register_forward_hook(hook_fn(name))
    return unet

def upscale(attn_map, target_size):
    """Reshape a token-wise attention map onto a spatial grid, resize it to
    ``target_size`` and softmax over tokens.

    Args:
        attn_map: tensor whose dim 0 is averaged away. The remaining 2-D tensor
            is treated as ``(num_pixels, num_tokens)`` and transposed
            (presumably ``(heads, pixels, tokens)`` on input; confirm with the
            producer). ``num_pixels`` must equal
            ``(H // (8 * 2**i)) * (W // (8 * 2**i))`` for some ``i`` in ``0..4``.
        target_size: ``(H, W)`` of the output map.

    Returns:
        float32 tensor of shape ``(num_tokens, H, W)``; values sum to 1 over dim 0.

    Raises:
        AssertionError: if no downsampling factor matches ``num_pixels``.
    """
    attn_map = torch.mean(attn_map, dim=0).permute(1, 0)
    num_pixels = attn_map.shape[1]

    temp_size = None
    for i in range(5):
        factor = 8 * 2 ** i
        candidate = (target_size[0] // factor, target_size[1] // factor)
        # Test the exact grid we reshape to, so this check and the `view`
        # below can never disagree (the old `* 64` test could).
        if min(candidate) > 0 and candidate[0] * candidate[1] == num_pixels:
            temp_size = candidate
            break

    if temp_size is None:
        # Explicit raise (not `assert`) so the check survives `python -O`.
        raise AssertionError(
            f"cannot infer a spatial grid for {num_pixels} attention pixels "
            f"at target size {tuple(target_size)}"
        )

    attn_map = attn_map.view(attn_map.shape[0], *temp_size)

    attn_map = F.interpolate(
        attn_map.unsqueeze(0).to(dtype=torch.float32),
        size=target_size,
        mode='bilinear',
        align_corners=False,
    )[0]

    return torch.softmax(attn_map, dim=0)

def get_net_attn_map(image_size, batch_size=2, instance_or_negative=False, detach=True):
    """Average every captured attention map into a single map of size ``image_size``.

    Reads the module-level ``attn_maps`` registry filled by the forward hooks.

    Args:
        image_size: ``(H, W)`` passed to ``upscale``.
        batch_size: number of chunks the stored batch dimension is split into.
        instance_or_negative: if True take chunk 0, otherwise chunk 1.
            NOTE(review): presumably chunk 0 is the negative/unconditional half
            under classifier-free guidance; confirm against the caller.
        detach: if True, detach each map from the autograd graph and move it
            to CPU; if False, use the maps as stored.

    Returns:
        Element-wise mean over layers of the ``upscale`` outputs.

    Raises:
        RuntimeError: if no attention maps have been captured.
    """
    if not attn_maps:
        raise RuntimeError("attn_maps is empty: register the hooks and run a forward pass first")

    idx = 0 if instance_or_negative else 1
    per_layer = []
    for attn_map in attn_maps.values():
        if detach:
            # `.cpu()` alone keeps the autograd graph alive; actually detach.
            attn_map = attn_map.detach().cpu()
        attn_map = torch.chunk(attn_map, batch_size)[idx].squeeze()
        per_layer.append(upscale(attn_map, image_size))

    return torch.mean(torch.stack(per_layer, dim=0), dim=0)

def attnmaps2images(net_attn_maps):
    """Convert each map in ``net_attn_maps`` into an 8-bit PIL image.

    Each map (presumably 2-D ``(H, W)``) is min-max normalised to ``[0, 255]``
    independently. A constant map (max == min) becomes an all-zero image
    instead of dividing by zero.

    Args:
        net_attn_maps: iterable of torch tensors, e.g. the rows of the tensor
            returned by ``get_net_attn_map``; tensors requiring grad are accepted.

    Returns:
        list of images, one per input map.
    """
    images = []
    for attn_map in net_attn_maps:
        arr = attn_map.detach().cpu().numpy()
        lo, hi = np.min(arr), np.max(arr)
        span = hi - lo
        if span > 0:
            normalized = (arr - lo) / span * 255
        else:
            # Constant (or NaN) map: nothing to normalise, emit black.
            normalized = np.zeros_like(arr)
        images.append(Image.fromarray(normalized.astype(np.uint8)))
    return images

def is_torch2_available():
    """Report whether ``F.scaled_dot_product_attention`` exists (PyTorch >= 2.0)."""
    try:
        F.scaled_dot_product_attention
    except AttributeError:
        return False
    return True

def get_generator(seed, device):
    """Create seeded ``torch.Generator`` object(s) on ``device``.

    ``seed`` may be None (returns None), an int (one generator), or a list of
    ints (a list of generators, one per seed, in order).
    """
    if seed is None:
        return None
    if isinstance(seed, list):
        return [torch.Generator(device).manual_seed(item) for item in seed]
    return torch.Generator(device).manual_seed(seed)

In [2]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

def cal_l1_loss(y_true, y_pred):
    """Mean absolute difference between two tensors (broadcast as torch does); returns a 0-d tensor."""
    diff = y_true - y_pred
    return diff.abs().mean()

class AttnProcessor(torch.nn.Module):
    r"""
    Plain attention processor built on ``F.scaled_dot_product_attention`` (PyTorch >= 2.0).

    Given the ``attn`` module it is attached to, it projects queries/keys/values
    with ``attn.to_q/to_k/to_v``, runs scaled dot-product attention per head and
    applies ``attn.to_out``. It records nothing.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        # hidden_size / cross_attention_dim are accepted for signature
        # compatibility only; they are not used.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def _to_heads(tensor, batch, heads, head_dim):
        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        return tensor.view(batch, -1, heads, head_dim).transpose(1, 2)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run attention for ``attn`` on ``hidden_states`` (self-attention when
        ``encoder_hidden_states`` is None, cross-attention otherwise).

        4-D inputs ``(batch, channel, height, width)`` are flattened to tokens
        and restored to the same shape on output.
        """
        shortcut = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        is_spatial = hidden_states.ndim == 4
        if is_spatial:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        source = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # F.scaled_dot_product_attention expects (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        projected_q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        projected_k = attn.to_k(context)
        projected_v = attn.to_v(context)

        head_dim = projected_k.shape[-1] // attn.heads
        q, k, v = (
            self._to_heads(t, batch_size, attn.heads, head_dim)
            for t in (projected_q, projected_k, projected_v)
        )

        # NOTE: attn.scale is not forwarded; SDPA uses its default 1/sqrt(head_dim).
        mixed = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        mixed = mixed.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim).to(q.dtype)

        out_proj, out_dropout = attn.to_out[0], attn.to_out[1]
        mixed = out_dropout(out_proj(mixed))

        if is_spatial:
            mixed = mixed.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            mixed = mixed + shortcut

        return mixed / attn.rescale_output_factor

class KVAttnProcessor(torch.nn.Module):
    r"""
    Scaled dot-product attention processor that additionally records the losses
    summed by ``GetLoss`` (the "AD Loss" of this notebook).

    Expected batch layout (as built by ``MyPipeline.sample``):
    index 0 = style latent, 1 = latent being optimised, 2 = content latent
    (``"Style"`` mode only).

    Attributes:
        mode: ``"Texture"`` or ``"Style"``; any other value records no loss.
        attnLoss: L1 distance between the optimised latent's attention output and
            the output of its queries attending over the style keys/values
            (None until the first call).
        queryLoss: ``"Style"`` mode only, L1 distance between the optimised
            latent's queries and the content latent's queries (None otherwise).
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
        mode="Texture"
    ):
        # Initialise nn.Module before assigning any attributes.
        super().__init__()
        self.mode = mode
        self.attnLoss = None
        self.queryLoss = None
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Run standard attention for ``attn`` and record ``attnLoss``/``queryLoss``.

        Returns the same value a plain SDPA processor would; the losses are side
        effects stored on ``self``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape)

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # F.scaled_dot_product_attention expects (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # Output of SDPA: (batch, heads, seq_len, head_dim).
        # NOTE: attn.scale is not forwarded; SDPA uses its default 1/sqrt(head_dim).
        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)

        if self.mode in ("Texture", "Style"):
            query_noise = query[1:2]
            # Queries of the optimised latent attending over the style keys/values.
            # NOTE(review): attention_mask is passed unsliced to this batch-1 call;
            # only safe when it is None or broadcastable.
            attn_style = F.scaled_dot_product_attention(query_noise, key[0:1], value[0:1], attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
            self.attnLoss = cal_l1_loss(hidden_states[1:2], attn_style)
            if self.mode == "Style":
                # Tie the optimised latent's queries to the content latent's queries.
                self.queryLoss = cal_l1_loss(query_noise, query[2:3])

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

In [None]:
import torch
import PIL.Image
from PIL import Image
from typing import Any, Callable, Dict, List, Optional, Union
import inspect
from torch import autocast
from tqdm import tqdm
from tqdm.auto import tqdm
import numpy as np
from torchvision import transforms as tfms
from diffusers import StableDiffusionPipeline, DDIMScheduler
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import deprecate, logging, BaseOutput
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, DDIMScheduler, AutoencoderKL

def GetLoss(unet, mode):
    """Sum the losses recorded by the KVAttnProcessor layers of ``unet``.

    Only processors named ``*attn1.processor`` inside up blocks (excluding
    ``up_blocks.1``) contribute, the same filter ``MyPipeline.set_attention``
    uses to install KVAttnProcessor. Each contributes ``attnLoss``; in
    ``"Style"`` mode it also contributes ``0.2 * queryLoss``.

    Returns:
        0-d float32 tensor on the device of the recorded losses (falls back to
        CUDA when available, else CPU, if no layer matches).
    """
    terms = []
    for name, processor in unet.attn_processors.items():
        if name.endswith("attn1.processor") and "up" in name and "up_blocks.1" not in name:
            terms.append(processor.attnLoss)
            if mode == "Style":
                terms.append(processor.queryLoss * 0.2)

    # Accumulate where the losses live instead of hard-coding `.cuda()`, so the
    # CPU code path works too.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    device = next((t.device for t in terms if isinstance(t, torch.Tensor)), default_device)
    loss = torch.tensor(0.0, dtype=torch.float32, device=device)
    for term in terms:
        loss = loss + term
    return loss

def GetLoss_EDIT(unet, mode):
    """Sum the losses recorded by every ``*attn2.processor`` of ``unet``.

    Each matching processor contributes ``attnLoss``; in ``"Style"`` mode it
    also contributes ``0.2 * queryLoss``.
    NOTE(review): ``MyPipeline.set_attention`` installs KVAttnProcessor only on
    attn1 layers, so attn2 processors need to be set up separately before this
    is usable.

    Returns:
        0-d float32 tensor on the device of the recorded losses (falls back to
        CUDA when available, else CPU, if no layer matches).
    """
    terms = []
    for name, processor in unet.attn_processors.items():
        if name.endswith("attn2.processor"):
            terms.append(processor.attnLoss)
            if mode == "Style":
                terms.append(processor.queryLoss * 0.2)

    # Accumulate where the losses live instead of hard-coding `.cuda()`.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    device = next((t.device for t in terms if isinstance(t, torch.Tensor)), default_device)
    loss = torch.tensor(0.0, dtype=torch.float32, device=device)
    for term in terms:
        loss = loss + term
    return loss

class MyPipeline:
    """Optimisation-based texture synthesis / style transfer on Stable Diffusion.

    ``sample`` does not denoise. Instead it optimises a latent directly with
    Adam, one step per scheduler timestep, to minimise the losses recorded by
    the KVAttnProcessor layers (summed by ``GetLoss``). It then decodes the
    latent with the VAE.
    """

    # Latent scaling constant applied after VAE encode and undone before decode.
    LATENT_SCALE = 0.18215

    def __init__(self, model_id="runwayml/stable-diffusion-v1-5"):
        """Load the Stable Diffusion components and freeze them on GPU when available.

        Args:
            model_id: Hub id or local path passed to
                ``StableDiffusionPipeline.from_pretrained``.
                NOTE(review): the ``runwayml`` repo may no longer be served by
                the Hub; pass a mirror or local copy here if loading fails.
        """
        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
        self.vae = self.pipe.vae
        self.tokenizer = self.pipe.tokenizer
        self.text_encoder = self.pipe.text_encoder
        self.unet = self.pipe.unet
        self.scheduler = self.pipe.scheduler

        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # All weights stay frozen: only the latent being optimised gets gradients.
        self.vae = self.vae.to(self.device).requires_grad_(False)
        self.text_encoder = self.text_encoder.to(self.device).requires_grad_(False)
        self.unet = self.unet.to(self.device).requires_grad_(False)

    def set_attention(self, mode):
        """Install KVAttnProcessor(mode) on the self-attention (attn1) layers of
        the up blocks except ``up_blocks.1``, and AttnProcessor everywhere else."""
        attn_procs = {}
        for name in self.unet.attn_processors.keys():
            is_loss_layer = name.endswith("attn1.processor") and "up" in name and "up_blocks.1" not in name
            attn_procs[name] = KVAttnProcessor(mode=mode) if is_loss_layer else AttnProcessor()
        self.unet.set_attn_processor(attn_procs)

    def _encode_image(self, path, size):
        """Load an image as RGB, resize to (size, size) and return its scaled VAE latent."""
        with Image.open(path) as img:
            rgb = img.convert('RGB').resize((size, size))
        # ToTensor gives [0, 1]; the VAE expects [-1, 1].
        pixels = tfms.ToTensor()(rgb).unsqueeze(0).to(self.device) * 2 - 1
        with torch.no_grad():
            return self.LATENT_SCALE * self.vae.encode(pixels).latent_dist.sample()

    def _empty_prompt_embeddings(self, copies):
        """Encode the empty prompt and repeat it ``copies`` times along dim 0."""
        text_input = self.tokenizer("", padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        with torch.no_grad():
            embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
        return torch.cat([embeddings] * copies)

    def _decode_latents(self, latents):
        """Decode scaled latents with the VAE into a list of uint8 PIL images."""
        # No graph needed for the final decode; the optimised latent requires grad.
        with torch.no_grad():
            decoded = self.vae.decode(1 / self.LATENT_SCALE * latents).sample
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        arrays = decoded.cpu().permute(0, 2, 3, 1).numpy()
        arrays = (arrays * 255).round().astype("uint8")
        return [Image.fromarray(arr) for arr in arrays]

    def sample(self, style_path, content_path=None, num_inference_steps=200, size=512, seed=None):
        """Optimise a random latent toward the style (and content) image and decode it.

        Args:
            style_path: path of the style / texture exemplar image.
            content_path: optional content image. When given the mode is
                ``"Style"`` (style transfer), otherwise ``"Texture"``.
            num_inference_steps: number of scheduler timesteps, i.e. Adam steps.
            size: images are resized to (size, size); the latent is size/8 square.
            seed: optional int seeding the initial latent noise.

        Returns:
            list of PIL images decoded from the optimised latent (one here).
        """
        mode = "Texture" if content_path is None else "Style"
        self.set_attention(mode)

        latent_size = int(size / 8)
        generator = None if seed is None else torch.Generator(self.device).manual_seed(seed)
        noisy_latents = randn_tensor((1, 4, latent_size, latent_size), generator=generator, device=self.device).to(torch.float32)

        style_latents = self._encode_image(style_path, size)
        content_latents = self._encode_image(content_path, size) if mode == "Style" else None

        # One (empty-prompt) embedding per latent in the UNet batch.
        text_embeddings = self._empty_prompt_embeddings(2 if mode == "Texture" else 3)

        self.scheduler.set_timesteps(num_inference_steps)
        optimizer = torch.optim.Adam([noisy_latents.requires_grad_(True)], lr=0.05)

        # Mixed precision only on CUDA (on CPU autocast("cuda") would be disabled anyway).
        with autocast("cuda", enabled=self.device.type == "cuda"):
            for t in tqdm(self.scheduler.timesteps, total=len(self.scheduler.timesteps)):
                optimizer.zero_grad()

                # Batch layout read by KVAttnProcessor: [style, optimised(, content)].
                batch = [style_latents.detach(), noisy_latents]
                if content_latents is not None:
                    batch.append(content_latents.detach())
                latent_model_input = torch.cat(batch)

                # The UNet output is discarded: the forward pass only lets the
                # KVAttnProcessor layers record their losses.
                self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)

                loss = GetLoss(self.unet, mode)
                loss.backward()
                optimizer.step()

        return self._decode_latents(noisy_latents)

In [None]:
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Build the pipeline once: loads the pretrained Stable Diffusion components via
# `from_pretrained` and moves them to the GPU when one is available.
pipe = MyPipeline()

In [None]:
# Texture generation from the exemplar image at `style_image_path`.
import os

style_image_path = "纹理/208.jpg"
output_path = "outputs/ADLoss.png"
images = pipe.sample(style_path=style_image_path, num_inference_steps=150, size=512)
# PIL's save does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
images[0].save(output_path)

In [6]:
# Style transfer with the AD loss: fill in the two image paths and uncomment the lines below.
# style_image_path = ""
# content_image_path = ""
# images = pipe.sample(style_path=style_image_path, content_path=content_image_path, num_inference_steps=500, size=512)
# images[0].save("outputs/ADLoss.png")