In [1]:
import os
import shutil
# Default the notebook to GPU 3, but let a CUDA_VISIBLE_DEVICES value exported
# by the launching environment win. CUDA reads this variable when it is
# initialised, so it is set above `import torch`.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
import cv2
import torch
import einops
import numpy as np
from PIL import Image
from pytorch_lightning import seed_everything
import torchvision.transforms as transformers

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
from enum import Enum
from ControlNet.cldm.model import create_model,load_state_dict
from ControlNet.annotator.canny import CannyDetector
from ControlNet.annotator.util import HWC3
from safetensors.torch import load_file

from src.controller import AttentionControl
from src.ddim_v_hacked import DDIMVSampler
from src.img_util import find_flat_region, numpy2tensor

import huggingface_hub
# REPO_NAME: default Hugging Face repo used by GlobalState.update_sd_model for
#   model_dict checkpoints whose key is not an 'owner/name' repo id.
# eta: presumably the DDIM eta (0.0 = deterministic); not used in the cells shown.
REPO_NAME, eta = 'Anonymous-sub/Rerender', 0.0

In [3]:
# Selectable base checkpoints: display name -> checkpoint file.
# An empty path loads no extra weights, so the base weights are the ones from
# the ControlNet checkpoint. 'owner/name' keys are downloaded from that Hugging
# Face repo; other keys from REPO_NAME.
model_dict = dict([
    ('Stable Diffusion 1.5', ''),
    ('revAnimated_v11', 'models/revAnimated_v11.safetensors'),
    ('realisticVisionV20_v20', 'models/realisticVisionV20_v20.safetensors'),
    ('DGSpitzer/Cyberpunk-Anime-Diffusion', 'Cyberpunk-Anime-Diffusion.safetensors'),
    ('wavymulder/Analog-Diffusion', 'analog-diffusion-1.0.safetensors'),
    ('Fictiverse/Stable_Diffusion_PaperCut_Model', 'papercut_v1.ckpt'),
])

# Models whose checkpoint is read from a local directory instead of the Hub.
local_model = [
    'Fictiverse/Stable_Diffusion_PaperCut_Model',
    'wavymulder/Analog-Diffusion',
    'DGSpitzer/Cyberpunk-Anime-Diffusion',
]
class ProcessingState(Enum):
    """Progress marker kept in ``GlobalState.processing_state``."""

    NULL = 0       # initial value set by GlobalState.__init__
    FIRST_IMG = 1  # presumably: first frame processed — confirm
    KEY_IMGS = 2   # presumably: key frames processed — confirm
class GlobalState:
    """Mutable holder for the models and helpers shared between notebook cells.

    Caches the ControlNet + Stable Diffusion model (wrapped in a DDIMVSampler),
    the control-map detector and the attention controller, so repeated
    ``update_*`` calls with unchanged arguments are cheap no-ops.
    """

    # File inside the 'lllyasviel/ControlNet' Hugging Face repo that holds the
    # ControlNet weights for each supported control type.
    _CONTROLNET_CHECKPOINTS = {
        'HED': 'models/control_sd15_hed.pth',
        'canny': 'models/control_sd15_canny.pth',
        'depth': 'models/control_sd15_depth.pth',
    }
    # Fallback directory for checkpoints of models listed in `local_model`;
    # override with the RERENDER_LOCAL_MODEL_DIR environment variable.
    _DEFAULT_LOCAL_MODEL_DIR = '/home/yfyuan/YYF/all_models'

    def __init__(self):
        self.sd_model = None
        self.ddim_v_sampler = None
        self.detector_type = None
        self.detector = None
        self.controller = None
        self.processing_state = ProcessingState.NULL
        # Control type the current sampler was built with; together with
        # `sd_model` it forms the cache key of update_sd_model.
        self._sd_control_type = None
        # (control_type, canny_low, canny_high) for 'canny', (control_type,)
        # otherwise: the cache key of update_detector.
        self._detector_key = None

    def update_controller(self, inner_strength, mask_period, cross_period,
                          ada_period, warp_period):
        """Replace ``self.controller`` with a fresh AttentionControl.

        All arguments are forwarded positionally to AttentionControl; see that
        class for their meaning.
        """
        self.controller = AttentionControl(inner_strength, mask_period,
                                           cross_period, ada_period,
                                           warp_period)

    def update_sd_model(self, sd_model, control_type):
        """(Re)build ``self.ddim_v_sampler`` for a base model and control type.

        Does nothing when the same (sd_model, control_type) pair is loaded.

        Args:
            sd_model: key of ``model_dict``. An empty checkpoint path loads no
                extra weights, keeping those from the ControlNet checkpoint.
            control_type: 'HED', 'canny' or 'depth'.

        Raises:
            ValueError: unsupported ``control_type`` or checkpoint extension.
            KeyError: ``sd_model`` is not a key of ``model_dict``.
        """
        if sd_model == self.sd_model and control_type == self._sd_control_type:
            return
        # Validate everything before the expensive model construction.
        controlnet_file = self._CONTROLNET_CHECKPOINTS.get(control_type)
        if controlnet_file is None:
            raise ValueError(
                f'Unsupported control_type {control_type!r}; expected one of '
                f'{sorted(self._CONTROLNET_CHECKPOINTS)}')
        sd_model_path = model_dict[sd_model]

        model = create_model('./ControlNet/models/cldm_v15.yaml').cpu()
        model.load_state_dict(
            load_state_dict(huggingface_hub.hf_hub_download(
                'lllyasviel/ControlNet', controlnet_file),
                location=device))
        model.to(device)

        if sd_model_path:
            self._load_base_weights(model, sd_model, sd_model_path)
        self._load_finetuned_vae(model)

        self.ddim_v_sampler = DDIMVSampler(model)
        # Record the cache key only after a successful load, so a failed
        # attempt is retried instead of being treated as already loaded.
        self.sd_model = sd_model
        self._sd_control_type = control_type

    def _load_base_weights(self, model, sd_model, sd_model_path):
        """Overlay the base checkpoint ``sd_model_path`` onto ``model``."""
        model_ext = os.path.splitext(sd_model_path)[1]
        if model_ext not in ('.safetensors', '.ckpt', '.pth'):
            raise ValueError(
                f'Unsupported checkpoint format {model_ext!r} for {sd_model!r}')

        if sd_model in local_model:
            local_dir = os.environ.get('RERENDER_LOCAL_MODEL_DIR',
                                       self._DEFAULT_LOCAL_MODEL_DIR)
            checkpoint = os.path.join(local_dir, sd_model_path)
        else:
            # 'owner/name' keys are Hugging Face repo ids; others live in REPO_NAME.
            repo_name = sd_model if sd_model.count('/') == 1 else REPO_NAME
            checkpoint = huggingface_hub.hf_hub_download(repo_name, sd_model_path)

        if model_ext == '.safetensors':
            state_dict = load_file(checkpoint)
        else:
            # map_location='cpu': weights are copied into the model's params
            # anyway, and this also works on hosts without the saving GPU.
            state_dict = torch.load(checkpoint, map_location='cpu')['state_dict']
        # strict=False: presumably the base checkpoint lacks the ControlNet
        # branch's keys, which must keep the weights loaded earlier.
        model.load_state_dict(state_dict, strict=False)

    @staticmethod
    def _load_finetuned_vae(model):
        """Best-effort load of stabilityai's fine-tuned VAE into ``model.first_stage_model``.

        Any failure only prints a warning; the VAE weights already in ``model``
        are kept.
        """
        try:
            vae_checkpoint = huggingface_hub.hf_hub_download(
                'stabilityai/sd-vae-ft-mse-original',
                'vae-ft-mse-840000-ema-pruned.ckpt')
            model.first_stage_model.load_state_dict(
                torch.load(vae_checkpoint, map_location='cpu')['state_dict'],
                strict=False)
        except Exception as exc:  # deliberate best-effort fallback
            print('Warning: We suggest you download the fine-tuned VAE',
                  'otherwise the generation quality will be degraded')
            print(f'  (cause: {exc!r})')

    def clear_sd_model(self):
        """Drop the cached sampler and release cached CUDA memory."""
        self.sd_model = None
        self._sd_control_type = None
        self.ddim_v_sampler = None
        if device == 'cuda':
            torch.cuda.empty_cache()

    def update_detector(self, control_type, canny_low=100, canny_high=200):
        """Install ``self.detector``: a callable taking an image and returning its detected map.

        The detector is rebuilt only when ``control_type`` changes or, for
        'canny', when either threshold changes.

        Args:
            control_type: 'HED', 'canny' or 'depth'.
            canny_low, canny_high: thresholds forwarded to CannyDetector;
                ignored for the other control types.

        Raises:
            ValueError: if ``control_type`` is unsupported.
        """
        if control_type == 'canny':
            key = (control_type, canny_low, canny_high)
        else:
            key = (control_type,)
        if control_type == self.detector_type and key == self._detector_key:
            return

        # HED/MiDaS are imported here because only CannyDetector is imported at
        # the top of the notebook; lazy import also avoids loading them unused.
        if control_type == 'HED':
            from ControlNet.annotator.hed import HEDdetector
            self.detector = HEDdetector()
        elif control_type == 'canny':
            canny_detector = CannyDetector()

            def apply_canny(x):
                return canny_detector(x, canny_low, canny_high)

            self.detector = apply_canny
        elif control_type == 'depth':
            from ControlNet.annotator.midas import MidasDetector
            midas = MidasDetector()

            def apply_midas(x):
                detected_map, _ = midas(x)
                return detected_map

            self.detector = apply_midas
        else:
            raise ValueError(
                f'Unsupported control_type {control_type!r}; expected one of '
                "['HED', 'canny', 'depth']")

        self.detector_type = control_type
        self._detector_key = key

In [4]:
# Single shared state object; the following cells configure it and read from it.
global_state = GlobalState()

In [5]:
# Base Stable Diffusion checkpoint (a key of `model_dict`). Alternatives:
#   'realisticVisionV20_v20'                      - realistic style
#   'DGSpitzer/Cyberpunk-Anime-Diffusion'         - cyberpunk style
#   'wavymulder/Analog-Diffusion'                 - biographical portrait style
#   'Fictiverse/Stable_Diffusion_PaperCut_Model'  - paper-cut style
sd_model = 'revAnimated_v11'

# ControlNet conditioning type and the Canny edge thresholds passed to its detector.
control_type = 'canny'
low_threshold, high_threshold = 50, 100

In [6]:
global_state.update_sd_model(sd_model, control_type)
# All-zero settings; presumably this disables AttentionControl's cross-frame
# behaviour — TODO confirm against src/controller.py.
global_state.update_controller(inner_strength=0, mask_period=0,
                               cross_period=0, ada_period=0, warp_period=0)
global_state.update_detector(control_type, canny_low=low_threshold,
                             canny_high=high_threshold)

No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.11.mlp.fc2.bias', 'vision_model.encoder.layers.4.mlp.fc1.bias', 'vision_model.encoder.layers.11.self_attn.k_proj.weight', 'vision_model.encoder.layers.12.mlp.fc2.bias', 'vision_model.encoder.layers.9.self_attn.k_proj.bias', 'vision_model.encoder.layers.11.layer_norm1.weight', 'vision_model.encoder.layers.18.self_attn.k_proj.bias', 'vision_model.encoder.layers.19.mlp.fc2.weight', 'vision_model.encoder.layers.2.mlp.fc1.weight', 'vision_model.encoder.layers.3.self_attn.k_proj.bias', 'vision_model.encoder.layers.14.mlp.fc1.bias', 'vision_model.encoder.layers.19.mlp.fc1.bias', 'vision_model.encoder.layers.5.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.layer_norm2.weight', 'vision_model.encoder.layers.23.mlp.fc2.weight', 'vision_model.encoder.layers.7.layer_norm1.bias', 'vision_model.encoder.layers.10.layer_norm2.bias', 'vision_model.en

Loaded model config from [./ControlNet/models/cldm_v15.yaml]
Loaded state_dict from [/home/yfyuan/.cache/huggingface/hub/models--lllyasviel--ControlNet/snapshots/e78a8c4a5052a238198043ee5c0cb44e22abb9f7/models/control_sd15_canny.pth]


In [7]:
%%capture
ddim_v_sampler = global_state.ddim_v_sampler
model = ddim_v_sampler.model
detector = global_state.detector
controller = global_state.controller
model.control_scales = [0.9] * 13
model.to(device)

seed_everything(0)

Global seed set to 0


In [8]:
# UNet of the latent diffusion model; its input/output block lists are walked
# by the feature-dumping callback below.
unet_model = model.model.diffusion_model

In [9]:
feature_maps_path = "./exp/attn_map/all_frames_inv_features"
os.makedirs(feature_maps_path, exist_ok=True)


def save_feature_map(feature_map, filename, directory=None):
    """Save a captured feature/attention map to ``<directory>/<filename>.pt``.

    Args:
        feature_map: object to save. Tensors are detached and stored as a
            compact CPU copy; anything else is saved as-is.
        filename: file stem, without the ``.pt`` extension.
        directory: output directory (created if missing); defaults to
            ``feature_maps_path``.

    Returns:
        The path that was written.
    """
    if directory is None:
        directory = feature_maps_path
    os.makedirs(directory, exist_ok=True)
    if isinstance(feature_map, torch.Tensor):
        feature_map = feature_map.detach()
        # torch.save serializes a view's whole underlying storage, so saving a
        # slice of a large activation would bloat the file. Copying off the
        # device (or cloning a CPU tensor) gives compact storage, and the file
        # then loads on hosts without the original GPU.
        if feature_map.device.type == 'cpu':
            feature_map = feature_map.clone()
        else:
            feature_map = feature_map.cpu()
    save_path = os.path.join(directory, f"{filename}.pt")
    torch.save(feature_map, save_path)
    return save_path

In [13]:
frames_path = "./result/pexels-koolshooters-7322716/video"
frames = sorted(os.listdir(frames_path))

# How many frames to invert. The exploratory run stopped after the first one;
# set to None to invert every frame in `frames_path`.
NUM_FRAMES = 1
# Number of DDIM inversion steps passed to encode_ddim.
DDIM_INVERSION_STEPS = 1000
# Which half of the UNet the callback walks: "input_block" or "output_block".
FEATURE_TYPE = "input_block"

# The text conditioning does not depend on the frame, so encode it once.
prompt = ""
prompt_embedding = model.get_learned_conditioning([prompt])
null_embedding = model.get_learned_conditioning([''])


def make_feature_callback(frame_idx, feature_type=FEATURE_TYPE):
    """Build an ``img_callback`` for ``encode_ddim`` that dumps UNet features.

    ``frame_idx`` is only used by the commented-out per-frame captures below.
    """
    if feature_type == "input_block":
        blocks = unet_model.input_blocks
    elif feature_type == "output_block":
        blocks = unet_model.output_blocks
    else:
        raise ValueError(f"unknown feature_type {feature_type!r}")

    def feature_callback(*_callback_args):
        # Arguments passed by the sampler are ignored; features are read from
        # attributes on the UNet modules (presumably set by the project's
        # patched forward passes — confirm).
        for block_idx, block in enumerate(blocks):
            if "Downsample" in str(type(block[0])):
                # NOTE(review): the name has no step or frame index, so if the
                # sampler calls this once per step, each call overwrites it.
                save_feature_map(block[0].down_output_feature,
                                 f"down_output_{block_idx}_time_0")

            # Alternative captures (uncomment as needed):
            # if "ResBlock" in str(type(block[0])):
            #     save_feature_map(block[0].out_layers_features,
            #                      f"{feature_type}_{block_idx}_out_layers_features_time_0")
            # if len(block) > 1 and "SpatialTransformer" in str(type(block[1])):
            #     save_feature_map(block[1].transformer_blocks[0].attn1.tmp_sim,
            #                      f"attn_{block_idx}_frame_{frame_idx}")
            # if len(block) > 1 and "SpatialTransformer" in str(type(block[1])):
            #     save_feature_map(block[0].out_layers_features,
            #                      f"{feature_type}_{block_idx}_out_layers_features_time_0")
            # if len(block) > 1 and "SpatialTransformer" in str(type(block[1])):
            #     save_feature_map(block[1].attention_output,
            #                      f"attn_output_{block_idx}_frame_{frame_idx}")

    return feature_callback


processed = 0
for idx, frame_name in enumerate(frames):
    img_path = os.path.join(frames_path, frame_name)
    print(img_path)
    frame_bgr = cv2.imread(img_path)
    if frame_bgr is None:
        # cv2.imread returns None for anything it cannot decode (stray files,
        # sub-directories); skip those instead of crashing in cvtColor.
        print("  skipped: not a readable image")
        continue
    img = HWC3(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))

    encoder_posterior = model.encode_first_stage(numpy2tensor(img).to(device))
    x0 = model.get_first_stage_encoding(encoder_posterior).detach()
    cond = {'c_concat': None, 'c_crossattn': [prompt_embedding]}
    un_cond = {'c_concat': None, 'c_crossattn': [null_embedding]}

    feature_callback = make_feature_callback(idx)
    ddim_v_sampler.encode_ddim(x0, DDIM_INVERSION_STEPS, cond, un_cond,
                               controller=None, img_callback=feature_callback)

    processed += 1
    if NUM_FRAMES is not None and processed >= NUM_FRAMES:
        break

./result/pexels-koolshooters-7322716/video/0000.png


DDIM Inversion:   0%|                                                                                                                                                    | 0/1000 [00:00<?, ?it/s]


In [11]:
# Reload one saved attention-output map for inspection. Resolve it relative to
# feature_maps_path (where save_feature_map writes) instead of a user-specific
# absolute path, and map it to CPU so inspecting it needs no GPU.
a = torch.load(os.path.join(feature_maps_path, "attn_output_7_frame_0.pt"),
               map_location="cpu")

In [12]:
# Shape of the reloaded map (displayed as torch.Size).
a.size()

torch.Size([1, 1280, 16, 18])