In [1]:
import os
import shutil
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import cv2
import torch
import einops
import numpy as np
from PIL import Image
from pytorch_lightning import seed_everything
import torchvision.transforms as transformers

device = 'cuda' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from enum import Enum
from ControlNet.cldm.model import create_model,load_state_dict
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.util import HWC3
from safetensors.torch import load_file

from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor

eta = 0.0

In [3]:
model_dict = {
    'Stable Diffusion 1.5': '',
    'revAnimated_v11': 'models/revAnimated_v11.safetensors',
    'realisticVisionV20_v20': 'models/realisticVisionV20_v20.safetensors',
    'DGSpitzer/Cyberpunk-Anime-Diffusion': '~/YYF/all_models/Cyberpunk-Anime-Diffusion.safetensors',
    'wavymulder/Analog-Diffusion': 'analog-diffusion-1.0.safetensors',
    'Fictiverse/Stable_Diffusion_PaperCut_Model': '~/YYF/all_models/papercut_v1.ckpt',
}
local_model = ['Fictiverse/Stable_Diffusion_PaperCut_Model', 'wavymulder/Analog-Diffusion', 'DGSpitzer/Cyberpunk-Anime-Diffusion']
class ProcessingState(Enum):
    NULL = 0
    FIRST_IMG = 1
    KEY_IMGS = 2
class GlobalState:

    def __init__(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        if sd_model == self.sd_model:
            return
        self.sd_model = sd_model
        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        if control_type == 'HED':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', './models/control_sd15_hed.pth'),
                    location=device))
        elif control_type == 'canny':
            model.load_state_dict(
                load_state_dict('./models/control_sd15_canny.pth',
                                location=device))
        elif control_type == 'depth':
            model.load_state_dict(
                load_state_dict(huggingface_hub.hf_hub_download(
                    'lllyasviel/ControlNet', 'models/control_sd15_depth.pth'),
                    location=device))

        model.to(device)
        sd_model_path = model_dict[sd_model]
        if len(sd_model_path) > 0:
            # check if sd_model is repo_id/name otherwise use global REPO_NAME
            if sd_model.count('/') == 1:
                repo_name = sd_model

            model_ext = os.path.splitext(sd_model_path)[1]
            if model_ext == '.safetensors':
                model.load_state_dict(load_file(sd_model_path), strict=False)
            elif model_ext == '.ckpt' or model_ext == '.pth':
                model.load_state_dict(torch.load(sd_model_path)['state_dict'],
                                      strict=False)

        try:
            model.first_stage_model.load_state_dict(torch.load('./models/vae-ft-mse-840000-ema-pruned.ckpt')['state_dict'],strict=False)
        except Exception:
            print('Warning: We suggest you download the fine-tuned VAE',
                  'otherwise the generation quality will be degraded')

        self.ddim_v_sampler = DDIMVSampler(model)

    def clear_sd_model(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        if self.detector_type == control_type:
            return
        if control_type == 'HED':
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()
            low_threshold = canny_low
            high_threshold = canny_high

            def apply_canny(x):
                return canny_detector(x, low_threshold, high_threshold)

            self.detector = apply_canny

        elif control_type == 'depth':
            midas = MidasDetector()

            def apply_midas(x):
                detected_map, _ = midas(x)
                return detected_map

            self.detector = apply_midas

In [4]:
global_state = GlobalState()

In [5]:
# sd_model = 'revAnimated_v11'
sd_model = 'realisticVisionV20_v20' #真实风格
# sd_model = 'DGSpitzer/Cyberpunk-Anime-Diffusion' #赛博朋克风格
# sd_model = 'wavymulder/Analog-Diffusion' # 人物传记风格
# sd_model = 'Fictiverse/Stable_Diffusion_PaperCut_Model' # 剪纸风格
control_type = 'canny'
low_threshold = 50
high_threshold = 100

In [None]:
global_state.update_sd_model(sd_model, control_type)
global_state.update_controller(0,0,0,0,0)
# global_state.update_controller(0,0,0,0,0)
global_state.update_detector(control_type, low_threshold, high_threshold)

ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla-xformers' with 512 in_channels
building MemoryEfficientAttnBlock with 512 in_channels...


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.7.self_attn.q_proj.weight', 'vision_model.encoder.layers.9.layer_norm1.bias', 'vision_model.encoder.layers.8.layer_norm2.weight', 'vision_model.encoder.layers.13.layer_norm2.weight', 'vision_model.encoder.layers.1.self_attn.q_proj.weight', 'vision_model.embeddings.position_ids', 'vision_model.encoder.layers.11.layer_norm2.bias', 'vision_model.encoder.layers.11.mlp.fc1.bias', 'vision_model.encoder.layers.1.mlp.fc2.weight', 'vision_model.encoder.layers.3.self_attn.q_proj.bias', 'vision_model.encoder.layers.5.mlp.fc2.bias', 'vision_model.encoder.layers.4.self_attn.q_proj.weight', 'vision_model.encoder.layers.2.self_attn.q_proj.bias', 'vision_model.encoder.layers.14.mlp.fc1.weight', 'vision_model.encoder.layers.14.mlp.fc1.bias', 'vision_model.encoder.layers.13.self_attn.v_proj.weight', 'vision_model.encoder.layers.18.self_attn.v_proj.weight', 'v

Loaded model config from [./ControlNet/models/cldm_v15.yaml]
Loaded state_dict from [./models/control_sd15_canny.pth]


In [7]:
def save_feature_maps(blocks, i, feature_type="input_block"):
    block_idx = 0
    for block in blocks:
        if feature_type == "input_block":
            if "Downsample" in str(type(block[0])) and block_idx == 6:
                save_feature_map(block[0].down_output_feature,
                                 f"down_output_{block_idx}_time_{0}")
        elif feature_type == "output_block":
            if "ResBlock" in str(type(block[0])) and block_idx == 5:
                save_feature_map(block[0].out_layers_features,
                                 f"{feature_type}_{block_idx}_out_layers_features_time_{0}")
                
            if len(block) > 1 and "SpatialTransformer" in str(type(block[1])) and block_idx in [3,4,5]: # block:[resblock, spatial]
                save_feature_map(block[1].transformer_blocks[0].attn1.tmp_sim, f"attn_{block_idx}_frame_{0}")
        block_idx += 1
        
def save_feature_maps_callback(i, unet_model):
    save_feature_maps(unet_model.input_blocks, i, "input_block")
    save_feature_maps(unet_model.output_blocks, i, "output_block")
feature_maps_path = "/home/yfyuan/YYF/Rerender/exp/attn_map/batch_frame_attn_feature"

def save_feature_map(feature_map, filename):
    os.makedirs(feature_maps_path, exist_ok=True)
    save_path = os.path.join(feature_maps_path, f"{filename}.pt")
    torch.save(feature_map, save_path)

In [8]:
@torch.no_grad()
def single_inversion(x0, ddim_v_sampler, img_callback=None):
    steps = 1000
    model = ddim_v_sampler.model

    prompt = f""
    cond = {
        'c_concat': None,
        'c_crossattn': [
            model.get_learned_conditioning(
                [prompt]
            )
        ]
    }
    un_cond = {
        'c_concat': None,
        'c_crossattn': [
            model.get_learned_conditioning(
                ['']
            )
        ]
    }
    ddim_v_sampler.encode_ddim(x0, steps, cond, un_cond, img_callback=img_callback)

In [9]:
ddim_v_sampler = global_state.ddim_v_sampler
model = ddim_v_sampler.model
model.tokenizer

AttributeError: 'ControlLDM' object has no attribute 'tokenizer'

In [9]:
%%capture
ddim_v_sampler = global_state.ddim_v_sampler
model = ddim_v_sampler.model
detector = global_state.detector
controller = global_state.controller
model.control_scales = [0.9] * 13
model.to(device)

In [10]:
######### 批量出图  ##########
dir_list = os.listdir("/home/yfyuan/YYF/Rerender/exp/ImageNet/imagenet-r")
seed_everything(0)
first_strength = 1 - 0.95

processed = [dir_name[:-len("realisticVisionV20_v20_20")-1] for dir_name in os.listdir("/home/yfyuan/YYF/Rerender/exp/ImageNet/results")]

prompt = f"realistic photo"
a_prompt = "RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, high quality, film grain"
n_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"

# prompt = "((Papercut)), shapes, delicate patterns, silhouette, kirigami, sharp outline, Spread the picture"
# a_prompt = "(masterpiece, top quality, best quality)"
# n_prompt = "frame,decorations"

# prompt = "Cyberpunk style"
# a_prompt = "(masterpiece, best quality:1.2), ultra-high resolution, intricately detailed, futuristic cyberpunk car, dynamic lighting, sleek design, ((visually captivating)), lora:CyberPunkAI:.8 CyberPunkAI car on a neon-lit city street, (((cyberpunk atmosphere))) or speeding through a rain-soaked metropolis, (((cyberpunk aesthetics))), amidst a dense cybernetic cityscape, (((cyberpunk allure)))"
# n_prompt = "UnwantedElements, rough-hands-5, (grainy:1.4), (worst quality, low quality:1.4), blurry, unrealistic physique, inappropriate themes, badv4"

Global seed set to 0


In [11]:
print(processed)

['toy_poodle', 'pug', 'lemon', 'chihuahua', 'standard_poodle', 'mantis', 'afghan_hound', 'pineapple', 'cobra', 'junco', 'espresso', 'polar_bear', 'accordion', 'red_fox', 'beagle', 'cheeseburger', 'rugby_ball', 'wood_rabbit', 'guinea_pig', 'ant', 'ambulance', 'pig', 'beer_glass', 'lawn_mower', 'basketball', 'castle', 'acorn', 'beaver', 'hen', 'candle', 'grand_piano', 'guillotine', 'skunk', 'harp', 'badger', 'gibbon', 'mobile_phone', 'clown_fish', 'cowboy_hat', 'lab_coat', 'missile', 'mitten', 'revolver', 'sandal', 'soccer_ball', 'submarine', 'goldfish', 'tarantula', 'centipede', 'duck', 'iguana', 'husky', 'saint_bernard', 'scarf', 'vulture', 'tank', 'gorilla', 'chow_chow', 'birdhouse', 'koala', 'ostrich', 'bucket', 'king_penguin', 'hammerhead', 'cauldron', 'joystick', 'yorkshire_terrier', 'toucan', '', 'pickup_truck', 'bee', 'collie', 'flute', 'weimaraner', 'bagel', 'zebra', 'harmonica', 'porcupine', 'baseball_player', 'hippopotamus', 'lipstick', 'scuba_diver', 'tree_frog', 'school_bus'

In [None]:
from tqdm.notebook import tqdm
num_samples = 1
steps = 20
unet_model = model.model.diffusion_model
for i,dir_path in enumerate(tqdm(dir_list)):
    if dir_path in processed:
        continue
    print(f"{i} ",dir_path," gogogogogo!")
    img_dir = f"/home/yfyuan/YYF/Rerender/exp/ImageNet/imagenet-r/{dir_path}"
    results = f"/home/yfyuan/YYF/Rerender/exp/ImageNet/results/{dir_path}_{sd_model}_{steps}"
    img_names = os.listdir(img_dir)
    imgs_path = [os.path.join(img_dir, img) for img in img_names]
    os.makedirs(results, exist_ok=True)
    with torch.no_grad():
        def generate_first_img(x0, img, strength):
            samples, _ = ddim_v_sampler.sample(
                steps,
                num_samples,
                shape,
                cond,
                verbose=False,
                eta=eta,
                unconditional_guidance_scale=7.5,
                unconditional_conditioning=un_cond,
                controller=controller,
                x0=x0,
                strength=strength)
            x_samples = model.decode_first_stage(samples)
            x_samples_np = (
                    einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
                    127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
            return x_samples, x_samples_np
        
        for img_name, img_path in zip(img_names, imgs_path):
            # print(f"processing {img_name}")
            print(img_name)
            frame = cv2.imread(img_path)
            shape = frame.shape
            if "videogame" in img_name:
                continue
            if shape[0] > 650 or shape[1] > 650 or shape [2] > 650:
                continue
            if shape[0] < 300 or shape[1] < 300:
                continue
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = np.array(transformers.CenterCrop((frame.shape[0] // 64 * 64, frame.shape[1] // 64 * 64))(Image.fromarray(frame)))

            img = HWC3(frame)
            H, W, C = img.shape
            shape = (4, H // 8, W // 8)
            img_ = numpy2tensor(img)

            encoder_posterior = model.encode_first_stage(img_.to(device))
            x0 = model.get_first_stage_encoding(encoder_posterior).detach()

            detected_map = detector(img)
            detected_map = HWC3(detected_map)
            control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
            control = torch.stack([control for _ in range(num_samples)], dim=0)
            control = einops.rearrange(control, 'b h w c -> b c h w').clone()
            cond = {
                'c_concat': [control],
                'c_crossattn': [
                    model.get_learned_conditioning(
                        [prompt + ', ' + a_prompt] * num_samples)
                ]
            }
            un_cond = {
                'c_concat': [control],
                'c_crossattn':
                    [model.get_learned_conditioning([n_prompt] * num_samples)]
            }

            unet_model.unet_type = "denoising"
            controller.set_task('')

            def ddim_sampler_callback(i):
                save_feature_maps_callback(i, unet_model)
            single_inversion(x0, ddim_v_sampler, ddim_sampler_callback)

            Image.open(img_path).save(os.path.join(results, img_name))

            x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
            Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_sd.jpg"))

            controller.set_task('initfirst')
            controller.threshold_block_idx = [3,4,5]
            x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
            Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_attn345.jpg"))
            controller.set_task('')

            unet_model.unet_type = "spatial"
            x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
            Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_onlyspatial.jpg"))
            unet_model.unet_type = "denoising"

            unet_model.unet_type = "spatial"
            controller.set_task('initfirst')
            controller.threshold_block_idx = [3,4]
            x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
            Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_all.jpg"))

  0%|          | 0/200 [00:00<?, ?it/s]

48  ant  gogogogogo!
sketch_19.jpg
sculpture_8.jpg
embroidery_10.jpg
painting_3.jpg
cartoon_5.jpg
graffiti_11.jpg
graffiti_7.jpg
sculpture_11.jpg
sketch_8.jpg
origami_5.jpg
sketch_14.jpg
sculpture_3.jpg
cartoon_0.jpg
misc_6.jpg
tattoo_17.jpg
graffiti_2.jpg
sketch_3.jpg
toy_3.jpg
origami_0.jpg
sticker_5.jpg
art_8.jpg
tattoo_7.jpg
tattoo_22.jpg
embroidery_6.jpg
misc_1.jpg
tattoo_12.jpg
painting_7.jpg
graffiti_15.jpg
sticker_0.jpg
art_3.jpg
sculpture_15.jpg
graphic_0.jpg
tattoo_2.jpg
embroidery_1.jpg
origami_9.jpg
sketch_18.jpg
sculpture_7.jpg
misc_13.jpg
painting_2.jpg
cartoon_4.jpg
graffiti_10.jpg
sculpture_20.jpg
graffiti_6.jpg
sculpture_10.jpg
sketch_7.jpg
sketch_23.jpg
origami_4.jpg
sketch_13.jpg
sculpture_2.jpg
misc_5.jpg
tattoo_16.jpg
graffiti_1.jpg
sketch_2.jpg
toy_2.jpg
sticker_4.jpg
art_7.jpg
sculpture_19.jpg
graphic_4.jpg
tattoo_6.jpg
tattoo_21.jpg
embroidery_5.jpg
misc_0.jpg
tattoo_11.jpg
painting_6.jpg
graffiti_14.jpg
art_2.jpg
sculpture_14.jpg
tattoo_1.jpg
embroidery_0.jpg
o

In [None]:

# with torch.no_grad():
#     def generate_first_img(x0, img, strength):
#         samples, _ = ddim_v_sampler.sample(
#             steps,
#             num_samples,
#             shape,
#             cond,
#             verbose=False,
#             eta=eta,
#             unconditional_guidance_scale=7.5,
#             unconditional_conditioning=un_cond,
#             controller=controller,
#             x0=x0,
#             strength=strength)
#         x_samples = model.decode_first_stage(samples)
#         x_samples_np = (
#                 einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 +
#                 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
#         return x_samples, x_samples_np
    
#     for img_name, img_path in zip(img_names, imgs_path):
#         # print(f"processing {img_name}")
#         print(img_name)
#         frame = cv2.imread(img_path)
#         shape = frame.shape
#         if "videogame" in img_name:
#             continue
#         if shape[0] > 650 or shape[1] > 650 or shape [2] > 650:
#             continue
#         if shape[0] < 300 or shape[1] < 300:
#             continue
#         frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#         frame = np.array(transformers.CenterCrop((frame.shape[0] // 64 * 64, frame.shape[1] // 64 * 64))(Image.fromarray(frame)))

#         img = HWC3(frame)
#         H, W, C = img.shape
#         shape = (4, H // 8, W // 8)
#         img_ = numpy2tensor(img)
        
#         encoder_posterior = model.encode_first_stage(img_.to(device))
#         x0 = model.get_first_stage_encoding(encoder_posterior).detach()
        
#         detected_map = detector(img)
#         detected_map = HWC3(detected_map)
#         control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
#         control = torch.stack([control for _ in range(num_samples)], dim=0)
#         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
#         cond = {
#             'c_concat': [control],
#             'c_crossattn': [
#                 model.get_learned_conditioning(
#                     [prompt + ', ' + a_prompt] * num_samples)
#             ]
#         }
#         un_cond = {
#             'c_concat': [control],
#             'c_crossattn':
#                 [model.get_learned_conditioning([n_prompt] * num_samples)]
#         }
        
#         unet_model.unet_type = "denoising"
#         controller.set_task('')
        
#         def ddim_sampler_callback(i):
#             save_feature_maps_callback(i, unet_model)
#         single_inversion(x0, ddim_v_sampler, ddim_sampler_callback)
        
#         Image.open(img_path).save(os.path.join(results, img_name))
        
#         x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
#         Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_sd.jpg"))
        
#         controller.set_task('initfirst')
#         controller.threshold_block_idx = [3,4,5]
#         x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
#         Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_attn345.jpg"))
#         controller.set_task('')
    
#         unet_model.unet_type = "spatial"
#         x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
#         Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_onlyspatial.jpg"))
#         unet_model.unet_type = "denoising"
        
#         unet_model.unet_type = "spatial"
#         controller.set_task('initfirst')
#         controller.threshold_block_idx = [3,4]
#         x_samples, x_samples_np = generate_first_img(x0, img, first_strength)
#         Image.fromarray(x_samples_np[0]).save(os.path.join(results, img_name[:-4]+"_all.jpg"))

In [None]:
print(1)