In [1]:
import torch
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel
import matplotlib.pyplot as plt
import math
# This notebook aims to inject attention maps by hand.
#NOTE: Last tested working diffusers version is diffusers==0.4.1, https://github.com/huggingface/diffusers/releases/tag/v0.4.1

In [2]:
#Init CLIP tokenizer and model
model_path_clip = "/mfs/xiangchendong19/stable-diffusion-ckpt/xiaoyang.ckpt"
clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch.float16)
clip = clip_model.text_model

#Init diffusion model
auth_token = "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" #Replace this with huggingface auth token as a string if model is not already downloaded
model_path_diffusion = "/mfs/xiangchendong19/stable-diffusion-ckpt"
unet = UNet2DConditionModel.from_pretrained(model_path_diffusion, subfolder="unet", use_auth_token=None, revision="fp16", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained(model_path_diffusion, subfolder="vae", use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)

#Move to GPU
device = "cuda:3"
unet.to(device)
vae.to(device)
clip.to(device)
print("Loaded all models")

ValueError: Calling CLIPTokenizer.from_pretrained() with the path to a single file or url is not supported for this tokenizer. Use a model identifier or the path to a directory instead.

In [17]:
# Debug switch; not read by any of the visible cells.
Debug = False


# time_maps is meant to hold a series of recorded attention maps
time_maps = []
# Attention maps appended by save_map(); cleared at the start of every stablediffusion() call.
record_attn_maps = []
import os
def mkdir(path):
    """Create the directory `path` (including missing parents) if nothing exists there.

    Args:
        path: Directory path to create.

    If `path` already exists (as a directory or as any other file-system entry)
    the call is a no-op.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the race where something else creates the directory
        # between the existence check above and this call.
        os.makedirs(path, exist_ok=True)


In [86]:
import numpy as np
import random
from PIL import Image
from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm
from torch import autocast
from difflib import SequenceMatcher

def init_attention_weights(weight_tuples):
    """Build a per-token weight vector and attach it to the UNet attention modules.

    Args:
        weight_tuples: Iterable of ``(token_index, weight)`` pairs. Pairs whose index
            falls outside ``[0, clip_tokenizer.model_max_length)`` are ignored.

    Every ``CrossAttention`` module whose name contains "attn2" receives the weight
    vector (moved to `device`) as ``last_attn_slice_weights``; modules whose name
    contains "attn1" get ``None``.
    """
    max_len = clip_tokenizer.model_max_length
    token_weights = torch.ones(max_len)

    for token_index, weight in weight_tuples:
        if 0 <= token_index < max_len:
            token_weights[token_index] = weight

    for layer_name, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" in layer_name:
            layer.last_attn_slice_weights = token_weights.to(device)
        if "attn1" in layer_name:
            layer.last_attn_slice_weights = None
    


def save_map(attn_map):
    """Record `attn_map` at the end of the module-level `record_attn_maps` list."""
    record_attn_maps.extend([attn_map])
    
# NOTE(review): these module-level counters are shadowed by the locals of the same
# name inside init_attention_func() and are not read anywhere in the visible cells.
attn_layers = 0
attn2_layers = 0
def init_attention_func():
    """Patch every ``CrossAttention`` module of the global `unet` for map injection.

    Each matching module gets:
      * ``attn_order``  - 1-based position among the CrossAttention modules,
      * ``attn_type``   - "attn2" if the module name contains "attn2", else "attn1",
      * ``inject`` / ``inject_scale`` / ``inject_map_dicts`` / ``inject_func``
        reset to their inactive defaults,
      * ``_attention`` and ``_sliced_attention`` replaced by the functions below.
    """
    #ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
    def new_attention(self, query, key, value):
        """Attention that lets ``inject_func`` replace the softmax for one call."""
        # TODO: use baddbmm for better performance
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale

        # Only hand the raw scores to inject_func when one is configured; otherwise an
        # armed `inject` flag with inject_func=None would raise a TypeError.
        if self.inject and self.inject_func is not None:
            attn_slice = self.inject_func(attention_scores, self.inject_map_dicts, self.inject_scale)
        else:
            attn_slice = attention_scores.softmax(dim=-1)
        # The flag is one-shot: it has to be re-armed before every forward pass.
        self.inject = False

        hidden_states = torch.matmul(attn_slice, value)
        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    def new_sliced_attention(self, query, key, value, sequence_length, dim):
        """Sliced attention variant; does not apply ``inject_func``.

        The original author noted that this code path never seemed to be entered.
        The optional ``last_attn_*`` / ``use_last_attn_*`` attributes are read with
        getattr defaults because nothing in this function's setup initialises them.
        """
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
            )  # TODO: use baddbmm for better performance
            attn_slice = attn_slice.softmax(dim=-1)
            if getattr(self, "use_last_attn_slice", False):
                last_mask = getattr(self, "last_attn_slice_mask", None)
                if last_mask is not None:
                    new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    attn_slice = attn_slice * (1 - last_mask) + new_attn_slice * last_mask
                else:
                    attn_slice = self.last_attn_slice
                self.use_last_attn_slice = False

            last_weights = getattr(self, "last_attn_slice_weights", None)
            if getattr(self, "use_last_attn_weights", False) and last_weights is not None:
                attn_slice = attn_slice * last_weights
                self.use_last_attn_weights = False

            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice
        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    attn_layers = 0
    for name, module in unet.named_modules():
        if type(module).__name__ != "CrossAttention":
            continue
        attn_layers += 1
        module.attn_order = attn_layers
        module.attn_type = "attn2" if "attn2" in name else "attn1"
        module.inject = False
        module.inject_scale = 0.0
        module.inject_map_dicts = None
        module.inject_func = None
        # Bind the replacement functions to this instance so `self` is the module.
        module._sliced_attention = new_sliced_attention.__get__(module, type(module))
        module._attention = new_attention.__get__(module, type(module))

    
def inject_map_by_hand(inject=True):
    """Set the ``inject`` flag on every "attn2" ``CrossAttention`` module of `unet`.

    Args:
        inject: Value stored in each module's ``inject`` attribute.
    """
    attn2_layers = (
        layer
        for layer_name, layer in unet.named_modules()
        if type(layer).__name__ == "CrossAttention" and "attn2" in layer_name
    )
    for layer in attn2_layers:
        layer.inject = inject
            
def set_inject_map_dicts(map_dicts):
    """Store `map_dicts` as ``inject_map_dicts`` on every "attn2" ``CrossAttention`` module of `unet`."""
    for layer_name, layer in unet.named_modules():
        if type(layer).__name__ != "CrossAttention":
            continue
        if "attn2" not in layer_name:
            continue
        layer.inject_map_dicts = map_dicts


def set_inject_func(func):
    """Store `func` as ``inject_func`` on every "attn2" ``CrossAttention`` module of `unet`."""
    for layer_name, layer in unet.named_modules():
        is_cross_attention = type(layer).__name__ == "CrossAttention"
        if is_cross_attention and "attn2" in layer_name:
            setattr(layer, "inject_func", func)
            
def set_inject_scale(scale):
    """Store `scale` as ``inject_scale`` on every "attn2" ``CrossAttention`` module of `unet`."""
    targets = [
        layer
        for layer_name, layer in unet.named_modules()
        if "attn2" in layer_name and type(layer).__name__ == "CrossAttention"
    ]
    for layer in targets:
        layer.inject_scale = scale



In [87]:
@torch.no_grad()
def stablediffusion(prompt, prompt_edit_token_weights=[], 
                    prompt_edit_tokens_start=0.0, prompt_edit_tokens_end=1.0, 
                    inject_map_dicts = [], inject_scale=0.0, inject_func=None,
                    guidance_scale=7.5, steps=50, seed=None, width=512, height=512, 
                    init_image=None, init_image_strength=0.5):
    """Generate one image with classifier-free guidance and optional attention injection.

    Args:
        prompt: Text prompt for the conditional pass.
        prompt_edit_token_weights: ``(token_index, weight)`` pairs forwarded to
            ``init_attention_weights``. Not mutated here.
        prompt_edit_tokens_start, prompt_edit_tokens_end: Injection is armed for a step
            when ``t / scheduler.num_train_timesteps`` lies inside this closed interval.
        inject_map_dicts: Object handed to ``inject_func`` as its second argument.
            Not mutated here.
        inject_scale: Value handed to ``inject_func`` as its third argument.
        inject_func: Callable invoked by the patched attention as
            ``inject_func(attention_scores, inject_map_dicts, inject_scale)``; it must
            return the attention weights. ``None`` disables injection.
        guidance_scale: Classifier-free guidance scale.
        steps: Number of scheduler inference steps.
        seed: RNG seed; a random one below ``2**32 - 1`` is drawn when ``None``.
        width, height: Output size in pixels, rounded down to a multiple of 64.
        init_image: Optional PIL image for img2img.
        init_image_strength: Fraction of the schedule that is run on top of
            `init_image`; the resulting start step is clamped to ``[0, steps - 1]``.

    Returns:
        The generated ``PIL.Image.Image``.
    """
    # Forget the maps recorded by a previous run.
    record_attn_maps.clear()

    #Change size to multiple of 64 to prevent size mismatches inside model
    width = width - width % 64
    height = height - height % 64
    
    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None: seed = random.randrange(2**32 - 1)
    # manual_seed_all() seeds the global CUDA generators and returns None, so there is
    # no Generator object to pass around: all sampling below uses the seeded global RNG.
    torch.cuda.manual_seed_all(seed)
    
    #Set inference timesteps to scheduler
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
    scheduler.set_timesteps(steps)
    
    #Preprocess image if it exists (img2img)
    if init_image is not None:
        #Resize and transpose for numpy b h w c -> torch b c h w
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))
        
        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if init_image.shape[1] > 3:
            # Every channel (alpha included) was rescaled to [-1, 1] above; bring alpha
            # back to [0, 1] before blending. White is +1 in the [-1, 1] pixel range.
            alpha = (init_image[:, 3:] + 1.0) / 2.0
            init_image = init_image[:, :3] * alpha + (1.0 - alpha)
            
        #Move image to GPU
        init_image = init_image.to(device)
        
        #Encode image
        # autocast expects a device *type* ("cuda"), not an indexed device such as "cuda:3".
        with autocast("cuda"):
            init_latent = vae.encode(init_image).latent_dist.sample() * 0.18215
            
        t_start = steps - int(steps * init_image_strength)
        # Keep the start step inside the schedule: a strength > 1 would otherwise produce
        # a negative index (silently counted from the end) and a tiny strength an
        # out-of-range one.
        t_start = min(max(t_start, 0), steps - 1)
            
    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_start = 0
    
    #Generate random normal noise
    noise = torch.randn(init_latent.shape, device=device)
    latent = scheduler.add_noise(init_latent, noise, torch.tensor([scheduler.timesteps[t_start]], device=device)).to(device)
    #Process clip
    with autocast("cuda"):
        tokens_unconditional = clip_tokenizer("", padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length", max_length=clip_tokenizer.model_max_length, truncation=True, return_tensors="pt", return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

        # (Re)patch the attention modules, then configure them for this run.
        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)
        set_inject_func(inject_func)
        set_inject_scale(inject_scale)
        set_inject_map_dicts(inject_map_dicts)
        
        timesteps = scheduler.timesteps[t_start:]

        for t in tqdm(timesteps, total=len(timesteps)):
            t_scale = t / scheduler.num_train_timesteps
            latent_model_input = scheduler.scale_model_input(latent, t)
            
            #Predict the unconditional noise residual
            noise_pred_uncond = unet(latent_model_input, t, encoder_hidden_states=embedding_unconditional).sample
            
            # Arm injection for the conditional pass only, and only when there is a
            # function to inject with.
            if inject_func is not None and prompt_edit_tokens_start <= t_scale <= prompt_edit_tokens_end:
                inject_map_by_hand()
            
            noise_pred_cond = unet(latent_model_input, t, encoder_hidden_states=embedding_conditional).sample
                
            #Perform guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            
            # Pass the timestep value (as for scale_model_input), not its integer index.
            latent = scheduler.step(noise_pred, t, latent).prev_sample

        #scale and decode the image latents with vae
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    # [-1, 1] -> [0, 1], then b c h w -> b h w c uint8
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).round().astype("uint8")
    
    return Image.fromarray(image)



In [88]:
def prompt_token(prompt, index):
    """Return the decoded text of the token at position `index` of the tokenized `prompt`.

    The prompt is padded/truncated to ``clip_tokenizer.model_max_length``; the padded
    token count is printed as a side effect.
    """
    encoded = clip_tokenizer(
        prompt,
        padding="max_length",
        max_length=clip_tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
        return_overflowing_tokens=True,
    )
    token_ids = encoded.input_ids[0]
    print(len(token_ids))
    selected = token_ids[index:index+1]
    return clip_tokenizer.decode(selected)

In [89]:
# Show which token sits at the same position in two different prompts.
prompt_word_index = 6
for sample_prompt in ("a cat sitting on a car", "a smiling dog sitting on a car"):
    print(prompt_token(sample_prompt, prompt_word_index))

77
car
77
a


In [90]:
# prepare inject map dicts
animal_dict = {
    "word_index": 2,
    "inject_map": inject_map_animal_64,
    "inject_map_32": inject_map_animal_32
}
car_dict = {
    "word_index":6,
    "inject_map": inject_map_car_64,
    "inject_map_32": inject_map_car_32
}
only_animal = [animal_dict]
only_car = [car_dict]
animal_car = [animal_dict, car_dict]

map_dict_list = [only_animal, only_car, animal_car]



In [91]:
# Fixed seed passed to every stablediffusion() call in the generation loop below.
# Note that it is larger than the 2**32 - 1 bound stablediffusion() uses for random seeds.
seed = 24839640267
# Root directory for the generated images.
mkdir("./hand_inject")

In [None]:
# Sweep every injection function over the integer injection scales 0..4 and save one
# image per (function, scale) pair.
# NOTE(review): `inject_functions` is not defined in the visible cells.
output_dir = "./hand_inject/func-scale/"
mkdir(output_dir)  # loop-invariant: create the output directory once, up front
inject_map_dicts = only_animal
for i, func in enumerate(inject_functions):
    for scale in range(5):
        res = stablediffusion("a hamster sitting on a car",
                inject_map_dicts = inject_map_dicts, inject_scale=scale, inject_func=func,
                seed=seed, steps=50, prompt_edit_tokens_start=0.8)
        res.save(f"{output_dir}on-ani-cs0.8-ste50-scal{scale}-func{i+1}.png")

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]