In [38]:
import os
# Expose only physical GPU 3 to this process. CUDA reads this variable when it
# initializes, so it is set here, before torch is imported and used below.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import sys
# Make the project-local packages imported below (iattention, unet) importable;
# presumably they live in the parent directory of this notebook — confirm.
sys.path.append('..')

import cv2
import torch
import imageio
import numpy as np
from PIL import Image
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, CLIPImageProcessor
from diffusers import (
    DDIMScheduler, 
    AutoencoderKL, 
    ControlNetModel,
    StableDiffusionPipeline,
)
from denku import show_images
from controlnet_aux import HEDdetector

from iattention import IAttentionSDCPipeline
from iattention.attention_processors import register_stablediffuion_attention_control
from iattention.utils import correct_colors_hist
from unet.unet_2d_condition import UNet2DConditionModel

# Re-import edited modules (e.g. the local iattention / unet packages) before each
# cell executes, so code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def register_attention_colntrol(pipe, config):
    """Install interpolated-attention control on a ControlNet pipeline.

    Args:
        pipe: pipeline exposing ``set_storage_params``, ``unet`` and ``controlnet``.
        config: mapping with three sub-configs:
            ``'pipe_config'`` is handed to ``pipe.set_storage_params``;
            ``'unet_config'`` / ``'controlnet_config'`` are expanded as keyword
            arguments to ``register_stablediffuion_attention_control`` for the
            UNet and the ControlNet respectively (in that order).
    """
    pipe.set_storage_params(config['pipe_config'])
    targets = (('unet', 'unet_config'), ('controlnet', 'controlnet_config'))
    for attr_name, config_key in targets:
        model = getattr(pipe, attr_name)
        register_stablediffuion_attention_control(model, **config[config_key])

def register_coefs(pipe, coefs):
    """Attach the same ``coefs`` object to every up-block of ``pipe.unet``.

    The mapping is shared by reference (not copied), so every up-block sees
    the same object.
    """
    up_blocks = pipe.unet.up_blocks
    for block in up_blocks:
        setattr(block, 'coefs', coefs)
    
def get_capture_info(cap):
    """Read frame geometry, frame rate and length from an OpenCV capture.

    Args:
        cap: a ``cv2.VideoCapture`` (or anything with ``isOpened``/``get``).

    Returns:
        Tuple ``(height, width, fps, frame_count)``. ``height``, ``width`` and
        ``frame_count`` are ints. ``fps`` is returned as a float: container
        frame rates are often fractional (e.g. 29.97), and truncating them to
        ``int`` makes a re-encoded video run at the wrong speed.

    Raises:
        OSError: if the capture is not opened (e.g. a wrong input path). In
            that case OpenCV typically reports every property as 0, which
            would otherwise silently produce an empty result downstream.
    """
    if not cap.isOpened():
        raise OSError('Video capture is not opened; check the input video path.')
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = float(cap.get(cv2.CAP_PROP_FPS))
    # CAP_PROP_FRAME_COUNT can be an estimate for some containers/backends.
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    return height, width, fps, frame_count

In [3]:
pretrained_model_path = 'runwayml/stable-diffusion-v1-5'
pretrained_controlnet = 'lllyasviel/control_v11p_sd15_softedge'


def _ema_schedule(ema, eta, steps=25):
    """Return an EMA interpolation schedule dict.

    Fixed values: ``start_step=0``, ``const_steps=0``, and
    ``end_step == total_steps == steps``. NOTE(review): ``steps`` presumably
    has to match ``num_inference_steps`` used at generation time — confirm.
    """
    return {
        'interpolation_scheduler': 'ema',
        'ema': ema,
        'eta': eta,
        'start_step': 0,
        'end_step': steps,
        'const_steps': 0,
        'total_steps': steps,
    }


def _attention_probs_only():
    """Return fresh flags that interpolate only attention probs and the output projection."""
    return dict(key=False, query=False, value=False, attention_probs=True, out_linear=True)


# One schedule for the pipeline's stored state, plus per-model attention settings:
# the UNet is patched in down/mid/up blocks, the ControlNet only in down/mid.
interpolation_config = dict(
    pipe_config=_ema_schedule(ema=0.45, eta=0.75),
    unet_config={
        **_ema_schedule(ema=0.62, eta=0.87),
        'use_interpolation': _attention_probs_only(),
        'attention_res': 32,
        'allow_names': ['down', 'mid', 'up'],
    },
    controlnet_config={
        **_ema_schedule(ema=0.625, eta=0.875),
        'use_interpolation': _attention_probs_only(),
        'attention_res': 32,
        'allow_names': ['down', 'mid'],
    },
)

In [4]:
# CLIP tokenizer for the base model. A tokenizer has no weights, so no
# torch_dtype is passed (it would only be forwarded as an unused init kwarg).
tokenizer = CLIPTokenizer.from_pretrained(
    pretrained_model_path,
    subfolder="tokenizer",
)

In [5]:
# CLIP text encoder of the base model, loaded in half precision.
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_path, subfolder="text_encoder", torch_dtype=torch.float16,
)

In [6]:
# CLIP image preprocessor shipped with the base model. It holds preprocessing
# settings, not weights, so no torch_dtype is passed.
feature_extractor = CLIPImageProcessor.from_pretrained(
    pretrained_model_path,
    subfolder="feature_extractor",
)

In [7]:
# Base-model VAE in half precision.
vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", torch_dtype=torch.float16)

In [8]:
# Base-model UNet weights loaded into the project-local UNet2DConditionModel
# (imported from unet.unet_2d_condition), in half precision.
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_path, subfolder="unet", torch_dtype=torch.float16,
)

In [9]:
# DDIM noise scheduler configured from the base model's scheduler config.
scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")

In [10]:
# Soft-edge ControlNet (half precision) plus the HED edge detector that turns
# each video frame into its conditioning image.
controlnet = ControlNetModel.from_pretrained(pretrained_controlnet, torch_dtype=torch.float16)
processor = HEDdetector.from_pretrained('lllyasviel/Annotators')

In [11]:
# Assemble the interpolated-attention ControlNet pipeline from the loaded
# components and move it to the GPU. The safety checker is disabled (None).
pipe = IAttentionSDCPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
    unet=unet, scheduler=scheduler, feature_extractor=feature_extractor,
    controlnet=controlnet, safety_checker=None,
).to('cuda')

You have disabled the safety checker for <class 'iattention.stablediffusion_controlnet_pipeline.IAttentionSDCPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [61]:
# Blending coefficients attached to the UNet up-blocks by register_coefs.
# Keys presumably are up-block channel widths — confirm against the project UNet.
coefs = {
    channels: {'backbone_coef': backbone, 'skip_coef': skip}
    for channels, backbone, skip in (
        (1280, 1.4, 0.9),
        (640, 1.2, 0.1),
    )
}

In [62]:
register_attention_colntrol(pipe, interpolation_config)
register_coefs(pipe, coefs)

gen_config = {
    'prompt': 'Professional high-quality wide-angle digital art of an iron man in helmet. photorealistic, epic fantasy, dramatic lighting, cinematic, extremely high detail, cinematic lighting, trending on artstation, cgsociety, realistic rendering of Unreal Engine 5, 8k, 4k, HQ, wallpaper',
    'negative_prompt': 'lowres, worst quality, low quality',
    'seed': 17,
    'img_h': 512,
    'img_w': 512,
    'num_inference_steps': 25,
    'guess_mode': True,
    # When True, each generated frame is resized back to the source video resolution.
    'original_output_size': True,
    # NOTE(review): not read in this notebook — the HED `processor` is always used.
    'controlnet_processor': 'softedge',  # softedge, pose, norm, canny, depth
    'hist_normalize': 'lab',  # rgb, hsv
    'input_video_path': '../video/input/man.mp4',
    'output_video_path': '../video/output/man_ironman.mp4',
}

cap = cv2.VideoCapture(gen_config['input_video_path'])
height, width, fps, frame_count = get_capture_info(cap)

images = []
first_image = None
# Number of leading frames to stylize; raise it (up to frame_count) for the whole clip.
process_frames_count = 1

# Every frame is fed to the pipeline at the generation resolution.
img_size = (gen_config['img_w'], gen_config['img_h'])
# Read exactly the frames we process instead of reading one extra frame and breaking.
frames_to_process = min(process_frames_count, frame_count)

try:
    for frame_num in tqdm(range(frames_to_process), desc='frames'):
        ret, image = cap.read()
        if not ret or image is None:
            # Undecodable frame: skip it.
            continue

        # OpenCV decodes frames as BGR; PIL / the pipeline expect RGB.
        image = cv2.resize(image, img_size)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        condition_image = processor(Image.fromarray(image))

        result_img = pipe(
            image=condition_image,
            prompt=gen_config['prompt'],
            negative_prompt=gen_config['negative_prompt'],
            num_inference_steps=gen_config['num_inference_steps'],
            # Re-seed per frame so every frame is generated from the same seed.
            generator=torch.manual_seed(gen_config['seed']),
            guess_mode=gen_config['guess_mode'],
        ).images[0]
        result = np.array(result_img)

        # Match each later frame's colors to the first generated frame
        # (presumably to reduce color flicker between frames).
        if first_image is None:
            first_image = result.copy()
        else:
            result = correct_colors_hist(first_image, result, gen_config['hist_normalize'])

        # Restore the source resolution when requested; width/height come from the capture.
        if gen_config['original_output_size']:
            result = cv2.resize(result, (width, height))

        images.append(result)
finally:
    # Release the decoder handle even if generation fails mid-loop.
    cap.release()

PROCESSED FRAME: 1 | TOTAL FRAMES: 1

In [1]:
# show_images(images, n_rows=1, figsize=(15, 15))

In [51]:
# Encode the generated frames at the source frame rate. Fail with a clear message
# if nothing was generated, and create the output directory if it is missing.
if not images:
    raise RuntimeError('No frames were generated; check input_video_path and process_frames_count.')
os.makedirs(os.path.dirname(gen_config['output_video_path']) or '.', exist_ok=True)
imageio.mimwrite(gen_config['output_video_path'], images, fps=fps)

In [19]:
# images = pipe(
#     prompt='black cat, night, moon. high quality, extrimely high detail, cinematic, 8k, 4k',
#     num_inference_steps=50,
#     guidance_scale=7.5,
#     return_dict=True,
# ).images