In [2]:
import os
import shutil
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import cv2
import torch
import einops
import numpy as np
from PIL import Image
from pytorch_lightning import seed_everything
import torchvision.transforms as transformers

# Compute device used by every later cell: the (single, pinned above) GPU when
# CUDA is available, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from enum import Enum
from ControlNet.cldm.model import create_model,load_state_dict
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.util import HWC3
from safetensors.torch import load_file

from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor

# DDIM `eta` handed to ddim_v_sampler.sample(); 0.0 is the deterministic DDIM
# setting by the usual DDIM convention (confirm against DDIMVSampler).
eta: float = 0.0

In [4]:
# Base Stable Diffusion checkpoints selectable through `sd_model`:
# display name -> weights file. An empty path makes GlobalState.update_sd_model
# skip loading extra base weights on top of the ControlNet checkpoint.
# Every path goes through os.path.expanduser because neither safetensors'
# load_file nor torch.load expands '~' on its own.
model_dict = {
    name: os.path.expanduser(path)
    for name, path in {
        'Stable Diffusion 1.5': '',
        'revAnimated_v11': 'models/revAnimated_v11.safetensors',
        'realisticVisionV20_v20': 'models/realisticVisionV20_v20.safetensors',
        'DGSpitzer/Cyberpunk-Anime-Diffusion': '~/YYF/all_models/Cyberpunk-Anime-Diffusion.safetensors',
        'wavymulder/Analog-Diffusion': 'analog-diffusion-1.0.safetensors',
        'Fictiverse/Stable_Diffusion_PaperCut_Model': '/home/yfyuan/YYF/all_models/papercut_v1.ckpt',
    }.items()
}
# Subset of model_dict keys; not read by any of the cells shown here.
local_model = ['Fictiverse/Stable_Diffusion_PaperCut_Model', 'wavymulder/Analog-Diffusion', 'DGSpitzer/Cyberpunk-Anime-Diffusion']
# Processing stage marker stored on GlobalState (NULL = 0, FIRST_IMG = 1,
# KEY_IMGS = 2). Built with the Enum functional API; equivalent to a class body.
ProcessingState = Enum(
    'ProcessingState',
    [('NULL', 0), ('FIRST_IMG', 1), ('KEY_IMGS', 2)],
    module=__name__,
)
class GlobalState:
    """Mutable container for the models and helpers shared across notebook cells.

    Attributes:
        sd_model: key of ``model_dict`` currently loaded, or None.
        control_type: ControlNet variant loaded together with ``sd_model``, or None.
        ddim_v_sampler: ``DDIMVSampler`` wrapping the loaded model, or None.
        detector_type: compared in ``update_detector`` but never assigned here.
        detector: callable turning an image into a control map, or None.
        controller: ``AttentionControl`` built by ``update_controller``, or None.
        processing_state: a ``ProcessingState``; set to NULL here and not
            otherwise touched by this class.
    """

    # ControlNet variants this class can load weights and detectors for.
    _CONTROL_TYPES = ('HED', 'canny', 'depth')

    def __init__(self):
        self.sd_model = None
        self.control_type = None
        self.ddim_v_sampler = None
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        """Replace ``self.controller`` with a fresh ``AttentionControl``."""
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    @staticmethod
    def _control_checkpoint(control_type):
        """Return a local file path to the ControlNet weights for ``control_type``."""
        if control_type == 'canny':
            return './models/control_sd15_canny.pth'
        # HED / depth weights come from the Hugging Face Hub (hf_hub_download
        # caches them locally). Imported lazily so the canny path never needs it.
        import huggingface_hub
        filename = {
            'HED': 'models/control_sd15_hed.pth',
            'depth': 'models/control_sd15_depth.pth',
        }[control_type]
        return huggingface_hub.hf_hub_download('lllyasviel/ControlNet', filename)

    def update_sd_model(self, sd_model, control_type):
        """Build a ControlNet model with ``sd_model`` base weights and wrap it in a sampler.

        Does nothing when the same ``(sd_model, control_type)`` pair is already
        loaded. ``sd_model`` / ``control_type`` are recorded only after a
        successful load, so a call that raised is retried next time.

        Args:
            sd_model: key of ``model_dict``. An empty path there means no extra
                base weights are loaded on top of the ControlNet checkpoint.
            control_type: one of 'HED', 'canny', 'depth'.

        Raises:
            ValueError: unknown ``control_type`` or unsupported checkpoint extension.
            KeyError: ``sd_model`` is not a key of ``model_dict``.
        """
        if control_type not in self._CONTROL_TYPES:
            raise ValueError(f'Unknown control_type {control_type!r}; '
                             f'expected one of {self._CONTROL_TYPES}')
        if sd_model == self.sd_model and control_type == self.control_type:
            return

        # Resolve and validate the base-model path before the expensive build.
        # expanduser: neither load_file nor torch.load expands '~' itself.
        sd_model_path = os.path.expanduser(model_dict[sd_model])
        model_ext = os.path.splitext(sd_model_path)[1]
        if sd_model_path and model_ext not in ('.safetensors', '.ckpt', '.pth'):
            raise ValueError(f'Unsupported checkpoint extension {model_ext!r} '
                             f'for {sd_model!r}: {sd_model_path}')

        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        model.load_state_dict(
            load_state_dict(self._control_checkpoint(control_type),
                            location=device))
        model.to(device)

        # Base weights are loaded non-strictly on top of the ControlNet model.
        if model_ext == '.safetensors':
            model.load_state_dict(load_file(sd_model_path), strict=False)
        elif model_ext in ('.ckpt', '.pth'):
            ckpt = torch.load(sd_model_path, map_location='cpu')
            # Some checkpoints nest weights under 'state_dict', others are bare.
            model.load_state_dict(ckpt.get('state_dict', ckpt), strict=False)

        try:
            vae_ckpt = torch.load('./models/vae-ft-mse-840000-ema-pruned.ckpt',
                                  map_location='cpu')
            model.first_stage_model.load_state_dict(vae_ckpt['state_dict'],
                                                    strict=False)
        except Exception as err:
            # Best effort: the fine-tuned VAE is optional, but report why it failed.
            print('Warning: We suggest you download the fine-tuned VAE',
                  'otherwise the generation quality will be degraded',
                  f'({type(err).__name__}: {err})')

        self.ddim_v_sampler = DDIMVSampler(model)
        self.sd_model = sd_model
        self.control_type = control_type

    def clear_sd_model(self):
        """Forget the loaded model/sampler and empty the CUDA cache when on GPU.

        ``torch.cuda.empty_cache`` can only return memory that nothing else
        (e.g. notebook globals holding ``model``) still references.
        """
        self.sd_model = None
        self.control_type = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        """Set ``self.detector`` to a callable image -> control map for ``control_type``.

        NOTE(review): ``detector_type`` is never assigned, so the early return
        below never fires and the detector is rebuilt on every call. As a
        consequence, new Canny thresholds always take effect.

        Raises:
            ValueError: unknown ``control_type``.
        """
        if self.detector_type == control_type:
            return
        if control_type == 'HED':
            from ControlNet.annotator.hed import HEDdetector
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()

            def apply_canny(x):
                return canny_detector(x, canny_low, canny_high)

            self.detector = apply_canny
        elif control_type == 'depth':
            from ControlNet.annotator.midas import MidasDetector
            midas = MidasDetector()

            def apply_midas(x):
                # The second value returned by the Midas detector is discarded.
                detected_map, _ = midas(x)
                return detected_map

            self.detector = apply_midas
        else:
            raise ValueError(f'Unknown control_type {control_type!r}; '
                             f'expected one of {self._CONTROL_TYPES}')

In [5]:
# One shared GlobalState instance that the following cells configure and read.
global_state: GlobalState = GlobalState()

In [6]:
# Base model: a key of `model_dict`. Alternatives:
#   'revAnimated_v11'
#   'DGSpitzer/Cyberpunk-Anime-Diffusion'          # cyberpunk style
#   'wavymulder/Analog-Diffusion'                  # biography / portrait style
#   'Fictiverse/Stable_Diffusion_PaperCut_Model'   # paper-cut style
sd_model = 'realisticVisionV20_v20'  # photorealistic style

# ControlNet conditioning type and the Canny (low, high) thresholds passed to
# update_detector.
control_type = 'canny'
low_threshold, high_threshold = 50, 100

In [7]:
# Build the ControlNet + base model, an attention controller with every
# AttentionControl argument set to 0, and the edge detector with this
# notebook's Canny thresholds.
global_state.update_sd_model(sd_model, control_type)
global_state.update_controller(inner_strength=0, mask_period=0,
                               cross_period=0, ada_period=0, warp_period=0)
global_state.update_detector(control_type, canny_low=low_threshold,
                             canny_high=high_threshold)

ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.23.self_attn.out_proj.bias', 'vision_model.encoder.layers.3.mlp.fc1.weight', 'vision_model.encoder.layers.10.mlp.fc2.bias', 'vision_model.encoder.layers.14.layer_norm1.weight', 'vision_model.encoder.layers.16.layer_norm1.weight', 'vision_model.encoder.layers.21.self_attn.k_proj.bias', 'vision_model.encoder.layers.21.mlp.fc2.bias', 'vision_model.encoder.layers.18.self_attn.v_proj.weight', 'vision_model.encoder.layers.5.self_attn.k_proj.weight', 'vision_model.encoder.layers.23.mlp.fc2.weight', 'vision_model.encoder.layers.11.self_attn.k_proj.weight', 'vision_model.encoder.layers.1.self_attn.k_proj.weight', 'vision_model.encoder.layers.12.self_attn.q_proj.bias', 'vision_model.encoder.layers.6.mlp.fc2.bias', 'vision_model.encoder.layers.1.layer_norm2.weight', 'vision_model.encoder.layers.0.mlp.fc2.weight', 'vision_model.encoder.layers.16.self_at

Loaded model config from [./ControlNet/models/cldm_v15.yaml]
Loaded state_dict from [./models/control_sd15_canny.pth]


In [8]:
def save_feature_maps(blocks, i, feature_type="input_block"):
    """Save a hard-coded selection of features cached on UNet ``blocks``.

    - ``feature_type == "input_block"``: ``down_output_feature`` of a
      Downsample first layer at index 6.
    - ``feature_type == "output_block"``: ``out_layers_features`` of a
      ResBlock first layer at index 5, plus ``attn1.tmp_sim`` of the first
      transformer block of every block whose second layer is a
      SpatialTransformer.
    - any other ``feature_type``: nothing is saved.

    NOTE(review): ``i`` is unused. Every file name carries the literal ``0``
    (``..._time_0`` / ``..._frame_0``), so each call overwrites the files of
    the previous call in the same directory. Confirm this is intended.
    """
    is_input = feature_type == "input_block"
    is_output = feature_type == "output_block"
    for block_idx, block in enumerate(blocks):
        if is_input:
            head = block[0]
            if "Downsample" in str(type(head)) and block_idx == 6:
                save_feature_map(head.down_output_feature,
                                 f"down_output_{block_idx}_time_{0}")
            continue
        if not is_output:
            continue
        head = block[0]
        if "ResBlock" in str(type(head)) and block_idx == 5:
            save_feature_map(head.out_layers_features,
                             f"{feature_type}_{block_idx}_out_layers_features_time_{0}")
        if len(block) > 1 and "SpatialTransformer" in str(type(block[1])):
            save_feature_map(block[1].transformer_blocks[0].attn1.tmp_sim,
                             f"attn_{block_idx}_frame_{0}")

def save_feature_maps_callback(i, unet_model):
    """Sampler hook: save features from the UNet input blocks, then its output blocks.

    ``i`` is forwarded unchanged to ``save_feature_maps``.
    """
    for blocks_attr, feature_type in (("input_blocks", "input_block"),
                                      ("output_blocks", "output_block")):
        save_feature_maps(getattr(unet_model, blocks_attr), i, feature_type)


def save_feature_map(feature_map, filename, save_dir=None):
    """Write ``feature_map`` with ``torch.save`` to ``<save_dir>/<filename>.pt``.

    Args:
        feature_map: object handed to ``torch.save`` (a captured UNet feature).
        filename: file stem, without the ``.pt`` extension.
        save_dir: target directory, created if missing. Defaults to the
            notebook-global ``feature_maps_path``, read at call time, so
            existing callers that repoint that global keep working.
    """
    if save_dir is None:
        save_dir = feature_maps_path
    os.makedirs(save_dir, exist_ok=True)
    print(f"saving feature map in {save_dir}")
    torch.save(feature_map, os.path.join(save_dir, f"{filename}.pt"))

In [9]:
@torch.no_grad()
def single_inversion(x0, ddim_v_sampler, img_callback=None):
    """Run ``ddim_v_sampler.encode_ddim`` on latent ``x0`` with empty-prompt conditioning.

    Conditional and unconditional dicts are both built from the empty prompt
    with no ControlNet input (``c_concat`` is None). ``encode_ddim`` receives
    1000 as its second argument (presumably the number of inversion steps;
    confirm against DDIMVSampler), no attention controller, and
    ``img_callback`` passed through. Returns None; any results come from
    ``encode_ddim``'s side effects and ``img_callback``.
    """
    model = ddim_v_sampler.model

    def empty_prompt_cond():
        # Fresh dict/list per call; text embedding of the empty prompt.
        return {
            'c_concat': None,
            'c_crossattn': [model.get_learned_conditioning([''])],
        }

    cond = empty_prompt_cond()
    un_cond = empty_prompt_cond()
    ddim_v_sampler.encode_ddim(x0, 1000, cond, un_cond, controller=None,
                               img_callback=img_callback)

In [10]:
%%capture
ddim_v_sampler = global_state.ddim_v_sampler
model = ddim_v_sampler.model
detector = global_state.detector
controller = global_state.controller
controller.set_task('')
model.control_scales = [0.9] * 13
model.to(device)

In [11]:
# ---- Batch image generation: strength and prompts ----
# Strength passed to generate_first_img for the first image. Kept as the
# literal expression 1 - 0.95 (not exactly 0.05 in floating point).
first_strength = 1 - 0.95

# Active prompt set: photorealistic.
prompt = "realistic photo"
a_prompt = (
    "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, "
    "soft lighting, high quality, film grain, Fujifilm XT3"
)
n_prompt = (
    "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, "
    "sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), "
    "(deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, "
    "wrong anatomy, extra limb, missing limb, floating limbs, "
    "disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
)

# Alternative prompt presets:
# prompt = f"in cartoon style"
#
# Paper-cut:
# prompt = "((Papercut)), shapes, delicate patterns, silhouette, kirigami, sharp outline, Spread the picture"
# a_prompt = "(masterpiece, top quality, best quality)"
# n_prompt = "frame,decorations"
#
# CG:
# prompt = "in CG style"
# a_prompt = "extremely detailed"
# n_prompt = "extra digit, fewer digits, cropped, worst quality, low quality"
#
# Unused: list of already-processed class directories.
# processed = [dir_name[:-len("realisticVisionV20_v20_20")-1] for dir_name in os.listdir("/home/yfyuan/YYF/Rerender/exp/ImageNet/results")]

In [None]:
# Stylize one ImageNet-R image with the ControlNet pipeline loaded above.
# Seeds tried: 0, 5, 8, 9.
seed_everything(0)
from tqdm.notebook import tqdm
num_samples = 1
steps = 20
unet_model = model.model.diffusion_model
dir_ = "goldfish"
img_name = "painting_31.jpg"
img_dir = f"/home/yfyuan/YYF/Rerender/exp/ImageNet/imagenet-r/{dir_}"
# save_feature_map() writes into the *global* `feature_maps_path` as it is at
# call time; this cell repoints it below for the inversion / denoising passes.
feature_maps_path = "/home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn"
feature_maps_path_denoising = os.path.join(feature_maps_path, f"{dir_}")
feature_maps_path_inv = os.path.join(feature_maps_path, f"{dir_}_inv")

controller.set_task('')
img_path = os.path.join(img_dir, img_name)
with torch.no_grad():
    def ddim_sampler_callback(i):
        """Sampler hook: save the selected UNet feature maps (``i`` is forwarded)."""
        save_feature_maps_callback(i, unet_model)

    def generate_first_img(x0, img, strength):
        """Sample from latent ``x0`` with ControlNet guidance and decode the result.

        Reads ``shape``, ``cond`` and ``un_cond`` from the notebook globals, so
        it must only be called after they are assigned below. ``img`` is unused.

        Returns:
            (x_samples, x_samples_np): the decoded batch, and the same batch as
            a ``(b, h, w, c)`` uint8 array computed as ``x * 127.5 + 127.5``
            clipped to [0, 255] (assumes decoder output in [-1, 1]).
        """
        samples, _ = ddim_v_sampler.sample(
            steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=7.5,
            unconditional_conditioning=un_cond,
            controller=controller,
            x0=x0,
            strength=strength,
            img_callback=None)
        x_samples = model.decode_first_stage(samples)
        x_samples_np = (
                einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        return x_samples, x_samples_np

    frame = cv2.imread(img_path)
    # cv2.imread signals a missing/unreadable file by returning None, not raising.
    if frame is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Centre-crop so H and W are multiples of 64 (presumably so the H/8 x W/8
    # latent divides evenly through the UNet's downsampling; confirm).
    crop_h = frame.shape[0] // 64 * 64
    crop_w = frame.shape[1] // 64 * 64
    if crop_h == 0 or crop_w == 0:
        raise ValueError(
            f"Image must be at least 64x64 pixels, got {frame.shape[:2]}: {img_path}")
    frame = np.array(transformers.CenterCrop((crop_h, crop_w))(Image.fromarray(frame)))

    img = HWC3(frame)
    H, W, C = img.shape
    # Latent shape handed to the sampler: 4 channels at 1/8 spatial resolution.
    shape = (4, H // 8, W // 8)
    img_ = numpy2tensor(img)

    encoder_posterior = model.encode_first_stage(img_.to(device))
    x0 = model.get_first_stage_encoding(encoder_posterior).detach()

    # Optional inversion pass: uncomment single_inversion to dump inversion
    # features into feature_maps_path_inv.
    unet_model.unet_type = "denoising"
    feature_maps_path = feature_maps_path_inv
    # single_inversion(x0, ddim_v_sampler, ddim_sampler_callback)

    feature_maps_path = feature_maps_path_denoising
    controller.set_task(['initfirst'])
    controller.batch_frame_attn_feature_path = feature_maps_path
    controller.threshold_block_idx = [3,4]

    detected_map = detector(img)
    detected_map = HWC3(detected_map)
    # Control hint: detector map divided by 255, repeated num_samples times, NCHW.
    control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
    control = torch.stack([control for _ in range(num_samples)], dim=0)
    control = einops.rearrange(control, 'b h w c -> b c h w').clone()
    cond = {
        'c_concat': [control],
        'c_crossattn': [
            model.get_learned_conditioning(
                [prompt + ', ' + a_prompt] * num_samples)
        ]
    }
    un_cond = {
        'c_concat': [control],
        'c_crossattn':
            [model.get_learned_conditioning([n_prompt] * num_samples)]
    }

    unet_model.unet_type = "spatial"

    x_samples, x_samples_np = generate_first_img(x0, img, first_strength)

saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
saving feature map in /home/yfyuan/YYF/Rerender/exp/attn_map/exp_attn/goldfish_inv
savi

In [None]:
# Preview the first generated sample inline.
first_sample_np = x_samples_np[0]
display(Image.fromarray(first_sample_np))

In [34]:
# Save the first sample as <dir_>_<steps>_<first>-<last threshold block>.png.
# PIL's Image.save does not create parent directories, so create it first.
out_dir = "/home/yfyuan/YYF/Rerender/exp/exp_all_attn"
os.makedirs(out_dir, exist_ok=True)
block_num = f"{controller.threshold_block_idx[0]}-{controller.threshold_block_idx[-1]}"
Image.fromarray(x_samples_np[0]).save(os.path.join(out_dir, f"{dir_}_{steps}_{block_num}.png"))