# Preparation

In [None]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = "5"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import cv2
import io
import gc
import yaml
import argparse
import torch
import torchvision
import diffusers
import json
from typing import List
from diffusers import StableDiffusionPipeline, AutoencoderKL, DDPMScheduler, ControlNetModel, DDIMScheduler

from src.utils import *
from src.keyframe_selection import get_keyframe_ind
from src.diffusion_hacked import apply_FRESCO_attn, apply_FRESCO_opt, disable_FRESCO_opt
from src.diffusion_hacked import get_flow_and_interframe_paras, get_intraframe_paras, get_flow_and_interframe_paras_warped
from src.pipe_FRESCO import inference, inference_extended
from src.tokenflow_utils import *
from src.pipe_FRESCO import inference_extended

# Model Utils

In [None]:
def get_models(config):
    """Create every model used for translation and freeze their weights.

    Reads ``gmflow_path``, ``sod_path``, ``controlnet_type``, ``sd_path``,
    ``num_inference_steps``, ``use_freeu`` and ``use_tokenflow`` from
    ``config``. An unsupported ``controlnet_type`` is replaced by ``'hed'``
    inside ``config`` itself.

    Returns:
        The tuple ``(pipe, frescoProc, controlnet, detector, flow_model,
        sod_model)``.
    """
    print('\n' + '=' * 100)
    print('creating models...')
    import sys
    # The vendored projects are imported below by their top-level module names.
    for vendored_dir in ("./src/ebsynth/deps/gmflow/", "./src/EGNet/", "./src/ControlNet/"):
        sys.path.append(vendored_dir)

    from gmflow.gmflow import GMFlow
    from model import build_model
    from annotator.hed import HEDdetector
    from annotator.canny import CannyDetector
    from annotator.midas import MidasDetector

    # ---- optical flow ----
    gmflow_options = dict(
        feature_channels=128,
        num_scales=1,
        upsample_factor=8,
        num_head=1,
        attention_type='swin',
        ffn_dim_expansion=4,
        num_transformer_layers=6,
    )
    flow_model = GMFlow(**gmflow_options).to('cuda')

    def keep_storage(storage, loc):
        # map_location callback: hand the deserialized storage back unchanged.
        return storage

    checkpoint = torch.load(config['gmflow_path'], map_location=keep_storage)
    # The checkpoint either wraps the weights under 'model' or is the weights.
    if 'model' in checkpoint:
        weights = checkpoint['model']
    else:
        weights = checkpoint
    flow_model.load_state_dict(weights, strict=False)
    flow_model.eval()
    print('create optical flow estimation model successfully!')

    # ---- saliency detection ----
    sod_model = build_model('resnet')
    sod_weights = torch.load(config['sod_path'])
    sod_model.load_state_dict(sod_weights)
    sod_model.to("cuda").eval()
    print('create saliency detection model successfully!')

    # ---- controlnet ----
    if config['controlnet_type'] not in ['hed', 'depth', 'canny']:
        print('unsupported control type, set to hed')
        config['controlnet_type'] = 'hed'
    control_type = config['controlnet_type']
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-" + control_type,
                                                 torch_dtype=torch.float16)
    controlnet.to("cuda")
    # Anything that is not depth or canny falls back to the HED detector.
    detector_classes = {'depth': MidasDetector, 'canny': CannyDetector}
    detector = detector_classes.get(control_type, HEDdetector)()
    print('create controlnet model-' + control_type + ' successfully!')

    # ---- diffusion model ----
    sd_path = config['sd_path']
    pipe_options = {'torch_dtype': torch.float16}
    if sd_path != 'stabilityai/stable-diffusion-2-1-base':
        # Every other checkpoint is paired with the sd-vae-ft-mse VAE.
        pipe_options['vae'] = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse",
                                                            torch_dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(sd_path, **pipe_options)
    # alternative kept from the original: DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")
    pipe.scheduler.set_timesteps(config['num_inference_steps'], device=pipe._execution_device)

    if config['use_freeu']:
        from src.free_lunch_utils import apply_freeu
        apply_freeu(pipe, b1=1.2, b2=1.5, s1=1.0, s2=1.0)

    if config['use_tokenflow']:
        set_tokenflow(pipe.unet, 'SDEdit')

    # Install the FRESCO attention processor with its controller switched off,
    # then apply the FRESCO optimisation hooks.
    frescoProc = apply_FRESCO_attn(pipe)
    frescoProc.controller.disable_controller()
    apply_FRESCO_opt(pipe)
    print('create diffusion model ' + sd_path + ' successfully!')

    # No gradients are needed for any of these networks.
    for frozen_module in (flow_model, sod_model, controlnet, pipe.unet):
        for weight in frozen_module.parameters():
            weight.requires_grad = False

    return pipe, frescoProc, controlnet, detector, flow_model, sod_model

def apply_control(x, detector, config):
    """Run ``detector`` on ``x`` the way ``config['controlnet_type']`` requires.

    * ``'depth'``: the detector returns a pair; only its first item is kept.
    * ``'canny'``: the detector is called with the thresholds 50 and 100.
    * anything else: the detector is called with ``x`` alone.

    Returns:
        The detected control map.
    """
    control_type = config['controlnet_type']
    if control_type == 'depth':
        depth_map, _unused = detector(x)
        return depth_map
    if control_type == 'canny':
        return detector(x, 50, 100)
    return detector(x)

# Config Utils

In [None]:
def make_configs(inputs:List[str] = None, edit_method = 'SDEdit', synth_method = 'Tokenflow', 
                 use_warp_noise = False, run_ebsynth = False, use_inv_prompts = False,
                 use_saliency = True, use_controlnet = True,
                 keyframe_select_mode = 'loop', keyframe_select_radix = 6, 
                 pnp_attn_t = 0.5, pnp_f_t = 0.8, 
                 inv_step = 500, inv_batch_size = 20, batch_size = 4) -> List[str]:
    """Write one ``config.yaml`` per (video, prompt) pair and return their paths.

    For every video listed in the prompts JSON file the reference YAML config
    is loaded, overridden with the per-video metadata and the arguments given
    here, and dumped to three places: the inversion-latent directory, the
    per-video output directory and one directory per prompt. Missing
    directories are created.

    Args:
        inputs: Video names without extension to process; ``None`` processes
            every video in the prompts file.
        edit_method: Stored as ``edit_mode`` and used in the output directory name.
        synth_method: ``'Tokenflow'`` sets ``use_tokenflow``; also used in the
            output directory name.
        use_warp_noise: Stored as ``warp_noise``; adds a ``-warp`` suffix to
            the output directory name.
        run_ebsynth: Stored as ``run_ebsynth``.
        use_inv_prompts: When true, inversion prompts are read from the
            inversion-prompts JSON file and latents go under ``latents``;
            otherwise the inversion prompt is empty and ``latents-null`` is used.
        use_saliency: Stored as ``use_saliency``.
        use_controlnet: Stored as ``use_controlnet``.
        keyframe_select_mode: Stored as ``keyframe_select_mode``; used in the
            output directory name.
        keyframe_select_radix: Stored as ``keyframe_select_radix``; used in
            the per-prompt directory name.
        pnp_attn_t: Accepted for interface compatibility; not used here.
        pnp_f_t: Accepted for interface compatibility; not used here.
        inv_step: Used in the ``inv_step_<n>`` directory names.
        inv_batch_size: Stored as ``inv_batch_size``.
        batch_size: Stored as ``batch_size``.

    Returns:
        Paths of the per-prompt ``config.yaml`` files, in the order written.
    """
    save_path = '/mnt/netdisk/linjunxin/fresco/'

    prompts_path = '/home/linjx/fresco/data/videos/prompts.json'
    inv_prompts_path = '/home/linjx/TokenFlow/data/tokenflow_supp_videos/inv_prompts.json'
    key_intervs_path = '/home/linjx/fresco/data/videos/key_interv.json'
    video_dir = '/home/linjx/fresco/data/videos'
    base_model_ref = '/home/linjx/fresco/data/videos/base_model.json'

    ref_yaml = '/mnt/netdisk/linjunxin/fresco/ref_config.yaml'
    refs = '/home/linjx/fresco/cfg.json'

    def _load_json(path):
        # Parse one JSON metadata file.
        with open(path, 'r') as f:
            return json.load(f)

    prompt_dict = _load_json(prompts_path)
    key_dict = _load_json(key_intervs_path)
    # Only needed (and only required to exist) when inversion prompts are used.
    inv_prompts_dict = _load_json(inv_prompts_path) if use_inv_prompts else {}
    model_dict = _load_json(base_model_ref)
    cfg_dict = _load_json(refs)

    # Built outside the f-string: reusing the enclosing quote character inside
    # a replacement field is a SyntaxError before Python 3.12.
    warp_suffix = '-warp' if use_warp_noise else ''
    run_dir_name = f'test-{synth_method}-{edit_method}-{keyframe_select_mode}{warp_suffix}'
    inv_path_name = 'latents' if use_inv_prompts else 'latents-null'

    config_paths = []

    for name, video_prompts in prompt_dict.items():
        file_name, _ext = os.path.splitext(name)
        if inputs is not None and file_name not in inputs:
            continue

        # Reloaded per video so that every video starts from a fresh dict.
        # NOTE(review): yaml.load/FullLoader is kept from the original; the
        # reference config is assumed to be a trusted local file.
        with open(ref_yaml, 'r') as f:
            config_yaml = yaml.load(f, Loader=yaml.FullLoader)

        config_yaml['file_path'] = os.path.join(video_dir, name)
        save_path_video = os.path.join(save_path, run_dir_name, file_name)
        os.makedirs(save_path_video, exist_ok=True)
        config_yaml['mininterv'] = key_dict[name][0]
        config_yaml['maxinterv'] = key_dict[name][1]
        config_yaml['inv_prompt'] = inv_prompts_dict[name] if use_inv_prompts else ''
        config_yaml['sd_path'] = model_dict[name]
        config_yaml['inv_batch_size'] = inv_batch_size
        config_yaml['run_ebsynth'] = run_ebsynth
        config_yaml['batch_size'] = batch_size
        config_yaml['use_tokenflow'] = synth_method == 'Tokenflow'
        config_yaml['use_controlnet'] = use_controlnet
        config_yaml['edit_mode'] = edit_method
        config_yaml['warp_noise'] = use_warp_noise
        config_yaml['keyframe_select_mode'] = keyframe_select_mode
        config_yaml['keyframe_select_radix'] = keyframe_select_radix
        config_yaml['use_saliency'] = use_saliency
        config_yaml['num_inference_steps'] = 20
        if file_name in cfg_dict:
            # Per-video overrides; videos without an entry keep the values
            # that came from the reference YAML.
            video_cfg = cfg_dict[file_name]
            config_yaml['cond_scale'] = video_cfg['control_scales']
            config_yaml['controlnet_type'] = video_cfg['control']
            config_yaml['num_warmup_steps'] = int(config_yaml['num_inference_steps'] * video_cfg['strength'])
            config_yaml['a_prompt'] = video_cfg['a_prompt']
            config_yaml['n_prompt'] = video_cfg['n_prompt']

        inv_latent_save_path = os.path.join(save_path, inv_path_name, file_name, f'inv_step_{inv_step}')
        os.makedirs(inv_latent_save_path, exist_ok=True)
        config_yaml['inv_save_path'] = inv_latent_save_path
        config_yaml['inv_latent_path'] = os.path.join(inv_latent_save_path, 'latents')

        print('=' * 100)

        # Video-level snapshots; 'save_path' and 'prompt' are not set by this
        # function until the per-prompt loop below.
        for snapshot_dir in (inv_latent_save_path, save_path_video):
            snapshot_path = os.path.join(snapshot_dir, 'config.yaml')
            with open(snapshot_path, 'w') as f:
                yaml.dump(config_yaml, f, default_flow_style=False)
            print(snapshot_path)

        for prompt in video_prompts:
            save_video_with_prompts = os.path.join(save_path_video, f'inv_step_{inv_step}',
                                                   prompt.replace(' ', '_'),
                                                   f'radix_{keyframe_select_radix}')
            os.makedirs(save_video_with_prompts, exist_ok=True)

            config_yaml['save_path'] = save_video_with_prompts + '/'
            config_yaml['prompt'] = prompt

            file_path = os.path.join(save_video_with_prompts, 'config.yaml')
            with open(file_path, 'w') as f:
                yaml.dump(config_yaml, f, default_flow_style=False)
            print(file_path)
            config_paths.append(file_path)

        print('=' * 100)
    return config_paths

# Make Configs

In [None]:
# Only videos whose file name (without extension) is listed here get configs.
inputs = ['music_input']

# Options forwarded to make_configs, one per line so a single setting is easy to flip.
_config_options = dict(
    edit_method='SDEdit',
    synth_method='Tokenflow',
    use_warp_noise=False,
    run_ebsynth=False,
    use_saliency=True,
    use_controlnet=True,
    keyframe_select_mode='loop',
    keyframe_select_radix=6,
)
configs = make_configs(inputs=inputs, **_config_options)
print(configs)

# Run Keyframe Translation

In [None]:
import shutil

# Load the first per-prompt config written by make_configs.
config_path = configs[0]
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

pipe, frescoProc, controlnet, detector, flow_model, sod_model = get_models(config)
device = pipe._execution_device
guidance_scale = 7.5
do_classifier_free_guidance = guidance_scale > 1
assert do_classifier_free_guidance
timesteps = pipe.scheduler.timesteps
# The same control scale is used at every inference step.
cond_scale = [config['cond_scale']] * config['num_inference_steps']
dilate = Dilate(device=device)

base_prompt = config['prompt']
# Checkpoints whose path mentions "Realistic"/"realistic" get photo-style
# additional/negative prompts; everything else gets the generic ones.
if 'Realistic' in config['sd_path'] or 'realistic' in config['sd_path']:
    a_prompt = ', RAW photo, subject, (high detailed skin:1.2), 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3, '
    n_prompt = '(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers:1.4), (deformed, distorted, disfigured:1.3), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation'
else:
    a_prompt = ', best quality, extremely detailed, '
    n_prompt = 'longbody, lowres, bad anatomy, bad hands, missing finger, extra digit, fewer digits, cropped, worst quality, low quality'

print('\n' + '=' * 100)
print('key frame selection for \"%s\"...'%(config['file_path']))

# The capture is opened only to query the frame count, so release it at once.
video_cap = cv2.VideoCapture(config['file_path'])
frame_num = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
video_cap.release()

# you can set extra_prompts for individual keyframe
# for example, extra_prompts[38] = ', closed eyes' to specify the person frame38 closes the eyes
extra_prompts = [''] * frame_num

keys = get_keyframe_ind(config['file_path'], frame_num, config['mininterv'], config['maxinterv'],
                        mode=config['keyframe_select_mode'], radix=config['keyframe_select_radix'],
                        extended=True)

# Split every key list into batches. The first two keys are put in front of
# the first batch; the remaining keys are cut into chunks of batch_size - 2.
# sublists_all[b] collects, for batch b, the chunk of every key list.
chunk_size = config['batch_size'] - 2
sublists_all = []
for key_list in keys:
    sublists = [key_list[i:i + chunk_size] for i in range(2, len(key_list), chunk_size)]
    sublists[0].insert(0, key_list[0])
    sublists[0].insert(1, key_list[1])
    while len(sublists_all) < len(sublists):
        sublists_all.append([])
    for batch_ind, keys_batch in enumerate(sublists):
        sublists_all[batch_ind].append(keys_batch)
print(f"split keyframes into batches {sublists_all}")

# keylists[b] is the concatenation of all chunks that belong to batch b.
keylists = []
for keys_group in sublists_all:
    keys_all = []
    for key in keys_group:
        keys_all += key
    keylists.append(keys_all)
print(f"split keyframes into groups {keylists}")

# Start from an empty 'keys' directory. shutil.rmtree replaces the former
# `rm -r` shell call: save_path embeds the prompt text, which the shell would
# misparse if it contains quotes or other metacharacters, and the old f-string
# reused its own quote character (a SyntaxError before Python 3.12).
keys_dir = config['save_path'] + 'keys'
os.makedirs(config['save_path'], exist_ok=True)
if os.path.exists(keys_dir):
    shutil.rmtree(keys_dir)
os.makedirs(keys_dir, exist_ok=True)
os.makedirs(config['save_path'] + 'video', exist_ok=True)

gc.collect()
torch.cuda.empty_cache()

In [None]:
dpi = 320

batch_ind = 0
imgs = []
img_idx = []
record_latent = []
# Prompts of the two frames carried over from the previous batch; filled at
# the end of every processed batch and read only in propagation mode.
ref_prompts = []
video_cap = cv2.VideoCapture(config['file_path'])
for i in range(frame_num):
    success, frame = video_cap.read()
    if not success:
        # The reported frame count can exceed what is actually decodable; fail
        # with a clear message instead of crashing inside cvtColor on None.
        video_cap.release()
        raise RuntimeError('failed to read frame %d of %d from "%s"' % (i, frame_num, config['file_path']))
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img = resize_image(frame, 512)
    Image.fromarray(img).save(os.path.join(config['save_path'], 'video/%04d.png'%(i)))

    # Without tokenflow only keyframes of the current batch are collected;
    # with tokenflow every frame is.
    if i not in keylists[batch_ind] and not config['use_tokenflow']:
        continue

    imgs += [img]
    img_idx += [i]

    # Keep collecting until the frame that closes the current batch: the last
    # frame of the video for the final tokenflow batch, otherwise the last key
    # of the batch. (Previously the final batch without tokenflow was flushed
    # at its first keyframe, contradicting the length assert below.)
    if config['use_tokenflow'] and batch_ind == len(keylists) - 1:
        closing_frame = frame_num - 1
    else:
        closing_frame = keylists[batch_ind][-1]
    if i != closing_frame:
        continue

    print(f'processing batch [{batch_ind + 1}/{len(keylists)}] with images {img_idx}')

    # Every batch after the first starts with two frames kept from the previous one.
    propagation_mode = batch_ind > 0

    prompts = [base_prompt + a_prompt + extra_prompts[ind] for ind in img_idx]
    if propagation_mode:
        if not config['use_tokenflow']:
            assert len(imgs) == len(keylists[batch_ind]) + 2
        else:
            assert len(imgs) ==  img_idx[-1] - keylists[batch_ind - 1][-1] + 2
        prompts = ref_prompts + prompts

    print('input of current batch:')
    imgs_torch = torch.cat([numpy2tensor(img) for img in imgs], dim=0)
    viz = torchvision.utils.make_grid(imgs_torch, len(imgs_torch), 1)
    visualize(viz.cpu(), dpi)

    # One control map per image, repeated to 3 channels.
    # NOTE(review): "* 0.5 + 0.5" presumably maps numpy2tensor's output from
    # [-1, 1] to [0, 1] -- confirm against src.utils.
    edges = torch.cat([numpy2tensor(apply_control(img, detector, config)[:, :, None]) for img in imgs], dim=0)
    edges = edges.repeat(1,3,1,1).cuda() * 0.5 + 0.5
    if do_classifier_free_guidance:
        # Duplicated along the batch axis for classifier-free guidance.
        edges = torch.cat([edges.to(pipe.unet.dtype)] * 2)

    # Translate frame numbers into positions inside imgs. In propagation mode
    # positions 0 and 1 (the carried-over frames) are prepended to every
    # group; the last collected frame is appended when the group lacks it.
    pos_map = {img_idx[i]:i for i in range(len(img_idx))}
    prefix = [0, 1] if propagation_mode else []
    keylists_pos = [prefix + [pos_map[key] for key in keygroup] + ([pos_map[img_idx[-1]]] if img_idx[-1] not in keygroup else [])
                    for keygroup in sublists_all[batch_ind]]

    print(f'keyframe indexes of images {sublists_all[batch_ind]}')
    print(f'keyframe indexes of position {keylists_pos}')

    gc.collect()
    torch.cuda.empty_cache()

    latents = inference_extended(pipe, controlnet, frescoProc, imgs, edges, timesteps, keylists_pos, n_prompt,
                                 prompts, config['end_opt_step'], config['batch_size'], propagation_mode, False, 
                                 config['warp_noise'], do_classifier_free_guidance, config['use_tokenflow'], False, 
                                 config['use_controlnet'], config['use_saliency'], cond_scale, config['num_inference_steps'], 
                                 config['num_warmup_steps'], config['seed'], guidance_scale, record_latent,
                                 flow_model=flow_model, sod_model=sod_model, dilate=dilate)

    with torch.no_grad(), torch.autocast(dtype=torch.float16, device_type='cuda'):
        # Skip the two carried-over frames; they were decoded with the previous batch.
        start = 2 if propagation_mode else 0
        size = len(latents)
        image = []
        # Decode in chunks of batch_size. A separate loop variable keeps the
        # frame counter `i` of the enclosing loop untouched.
        for chunk_start in range(start, size, config['batch_size']):
            chunk_end = min(size, chunk_start + config['batch_size'])
            image_batch = pipe.vae.decode(latents[chunk_start:chunk_end] / pipe.vae.config.scaling_factor, return_dict=False)[0]
            image.append(image_batch)
        image = torch.cat(image)
        image = torch.clamp(image, -1, 1)
        save_imgs = tensor2numpy(image)
        for ind, num in enumerate(img_idx[start:]):
            Image.fromarray(save_imgs[ind]).save(os.path.join(config['save_path'], 'keys/%04d.png'%(num)))
        print('results of current batch:')
        viz = torchvision.utils.make_grid(image, len(image), 1)
        visualize(viz.cpu(), dpi)

    # Carry the first and last frame (and their prompts) over to the next batch.
    batch_ind += 1
    imgs = [imgs[0], imgs[-1]]
    img_idx = [img_idx[0], img_idx[-1]]
    ref_prompts = [prompts[0], prompts[-1]]
    if batch_ind == len(keylists):
        gc.collect()
        torch.cuda.empty_cache()
        break

video_cap.release()

# Video Utils

In [None]:

def to_video_multi(roots:List[str], output:str, fps:int = 24, name:str = 'video') -> str:
    """Tile the ``keys`` frames of several result folders into one video.

    Frames are taken from ``<root>/keys`` of every root in sorted file-name
    order. An odd number of roots is laid out in a single row; an even number
    is laid out as rows of two. The video stops at the shortest frame list.

    Args:
        roots: Result directories, each containing a ``keys`` sub-directory.
        output: Directory the video file is written to.
        fps: Frame rate of the written video.
        name: File name of the video without the ``.mp4`` extension.

    Returns:
        Path of the written video file.

    Raises:
        ValueError: If ``roots`` is empty, a ``keys`` directory holds no
            files, or a frame cannot be read as an image.
    """
    if not roots:
        raise ValueError('to_video_multi() needs at least one root directory')

    nfile = len(roots)
    img_roots = [os.path.join(root, 'keys') for root in roots]

    file_lists = [sorted(os.listdir(img_root)) for img_root in img_roots]
    video_len = min(len(file_list) for file_list in file_lists)
    if video_len == 0:
        raise ValueError(f'no frames found in at least one of {img_roots}')

    # The first frame of the first root defines the tile size.
    ref_path = os.path.join(img_roots[0], file_lists[0][0])
    ref_frame = cv2.imread(ref_path)
    if ref_frame is None:
        raise ValueError(f'cannot read image {ref_path}')
    H, W = ref_frame.shape[:2]

    single_row = bool(nfile & 1)
    if single_row:
        size = (nfile * W, H)
    else:
        size = (2 * W, nfile // 2 * H)

    # NOTE(review): an MJPG fourcc is paired with an .mp4 file name -- confirm
    # that the resulting file plays in the intended players.
    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    video_root = os.path.join(output, name + '.mp4')
    videoWriter = cv2.VideoWriter(video_root, fourcc, fps, size, True)

    try:
        for i in range(video_len):
            frames = []
            for img_root, file_list in zip(img_roots, file_lists):
                frame_path = os.path.join(img_root, file_list[i])
                tile = cv2.imread(frame_path)
                if tile is None:
                    # cv2.imread signals failure with None, which would
                    # otherwise surface as an obscure error in hconcat.
                    raise ValueError(f'cannot read image {frame_path}')
                frames.append(tile)
            if single_row:
                frame = cv2.hconcat(frames)
            else:
                vframes = [cv2.hconcat([frames[k], frames[k+1]]) for k in range(0, nfile, 2)]
                frame = cv2.vconcat(vframes)
            videoWriter.write(frame)
    finally:
        # Always finalize the file, even when a frame fails to load.
        videoWriter.release()

    print('done!\n')
    return video_root

# Run Full Video Translation

In [None]:
import shlex

# With tokenflow every frame has been translated; otherwise one of the key
# lists returned by get_keyframe_ind is selected for blending.
# NOTE(review): this rebinds `keys`, so re-running the cell without re-running
# keyframe selection operates on the already-reduced value.
if config['use_tokenflow']:
    keys = list(range(frame_num))
else:
    keys = keys[config['num_inference_steps'] % len(keys)]

if config['use_tokenflow']:
    # Name the video after the part of the input file name before the first
    # '_' plus the settings used. Double quotes outside, single quotes inside:
    # reusing the same quote in an f-string is a SyntaxError before Python 3.12.
    video_name = os.path.basename(config['file_path']).split('_')[0]
    video_name += f"_{config['edit_mode']}_Tokenflow_{config['keyframe_select_mode']}"
    if config['warp_noise']:
        video_name += '_warp'
    if config['keyframe_select_radix'] == 1:
        video_name += '_key'
    video_root = to_video_multi(roots=[config['save_path']], output=config['save_path'], 
                                fps=24, name=video_name)
else:
    # The command is built and shown in both cases; it is executed only when
    # run_ebsynth is set. (Previously the "install ebsynth and run:" message
    # was printed without the command that should follow it.)
    if config['run_ebsynth']:
        print('translating full video with:')
    else:
        print('to translate full video with ebsynth, install ebsynth and run:')

    video_cap = cv2.VideoCapture(config['file_path'])
    fps = int(video_cap.get(cv2.CAP_PROP_FPS))
    video_cap.release()
    o_video = os.path.join(config['save_path'], 'blend.mp4')
    max_process = config['max_process']
    save_path = config['save_path']
    key_ind = ' '.join('%d' % (k) for k in keys)
    # The paths embed the prompt text, so quote them for the shell.
    cmd = (
        f'python video_blend.py {shlex.quote(save_path)} --key keys '
        f'--key_ind {key_ind} --output {shlex.quote(o_video)} --fps {fps} '
        f'--n_proc {max_process} -ps')

    print('\n```')
    print(cmd)
    print('```')

    if config['run_ebsynth']:
        os.system(cmd)

    print('\n' + '=' * 100)
    print('Done')