In [None]:
# Index of the CUDA device; later cells build the device string f'cuda:{gpu_id}'
# for both the lineart detector and the diffusion pipeline.
gpu_id = 1

In [None]:
from glob import glob
from os import path
from functools import partial
import sys

from diffusers import (
    StableDiffusionControlNetPixelSharingPipeline,
    ControlNetModel,
    DDIMScheduler,
    UniPCMultistepScheduler,
)
import numpy as np
import torch, torchvision
import torchvision.transforms.functional as F
from torchvision.utils import flow_to_image
import tensorflow_docs.vis.embed as embed
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

sys.path.append('scripts')
from datasets import get_sintel_data, get_vkitti2_data

def to_h264(video):
    '''
    Zero-pad a video so that its height and width are both even.

    The cells that call this pass the result straight to
    ``torchvision.io.write_video``; the padding is there because the encoder
    is expected to need even frame dimensions.

    Args:
        video: tensor laid out as (frames, height, width, channels).

    Returns:
        The input tensor itself when both spatial dimensions are already
        even; otherwise a new tensor with one row and/or one column of zeros
        appended at the bottom/right. The padding is allocated with the
        input's dtype and device, so the result keeps both.
    '''
    if video.shape[1] % 2 != 0:
        # new_zeros inherits dtype/device from `video`; a plain torch.zeros
        # would be float32 on CPU and either promote the video or fail on GPU.
        pad_row = video.new_zeros(*video.shape[:1], 1, *video.shape[2:])
        video = torch.cat((video, pad_row), dim=1)

    if video.shape[2] % 2 != 0:
        # Shape is taken after the height fix so the column matches the new height.
        pad_col = video.new_zeros(*video.shape[:2], 1, *video.shape[3:])
        video = torch.cat((video, pad_col), dim=2)

    return video

In [None]:
# Clip selection passed to the dataset loader; `fps` is the frame rate used
# when the preview videos are written.
start_frame = 0
num_frames = 10
fps = 5

# Load one Sintel sequence; swap in the commented line to use Virtual KITTI 2 instead.
data = get_sintel_data('./sintel', 'bandage_2', start_frame, num_frames)
# data = get_vkitti2_data('./vkitti2', 'Scene01', start_frame, num_frames)

# Unpack the loader output. Layouts implied by how later cells index these
# tensors (NOTE(review): confirm against scripts/datasets):
#   rgbs, conds -> (N, C, H, W); multiplied by 255 before being written as video
#   flows       -> (N', H, W, 2); permuted to channel-first for flow_to_image
#   occlusions  -> (N', H, W); a trailing channel axis is added for display
rgbs = data['rgbs']
flows = data['flows']
occlusions = data['occlusions']
conds = data['conds']
# The pipeline's `backward_coding` flag is the negation of the loader's `is_backward_flow`.
backward_coding = not data['is_backward_flow']

In [None]:
# Native resolution of the loaded clip. It is read from `data` rather than
# from `conds`, because `conds` is reassigned at the bottom of this cell:
# reading from it would make a re-run overwrite H, W with the resized values,
# and later cells use H, W as the native frame size.
H, W = data['conds'].shape[-2:]

# Scale so that the shorter side becomes 512 while keeping the aspect ratio.
if H < W:
    H_512, W_512 = 512, int(512 * W / H)
else:
    H_512, W_512 = int(512 * H / W), 512
# H_512, W_512 = H, W

# Round both sides down to a multiple of 8
# (presumably required by the diffusion model's latent downsampling -- TODO confirm).
H_q8 = (H_512 // 8) * 8
W_q8 = (W_512 // 8) * 8

# Always resize from the untouched loader tensors so the cell is idempotent.
rgbs = torch.nn.functional.interpolate(data['rgbs'], (H_q8, W_q8), mode='bilinear')
conds = torch.nn.functional.interpolate(data['conds'], (H_q8, W_q8), mode='bilinear')

In [None]:
# Top row: colour-coded optical flow next to the occlusion mask, both
# channel-last and joined along the width axis.
flow_panel = flow_to_image(flows.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
occlusion_panel = occlusions.unsqueeze(-1).repeat(1, 1, 1, 3) * 255
top_row = torch.cat((flow_panel, occlusion_panel), dim=2)

# Lower rows: conditioning frames and RGB frames (first frame dropped),
# resized to twice the native size so their width matches the top row.
double_size = (H*2, W*2)
cond_panel = torch.nn.functional.interpolate(
    conds[1:], double_size, mode='bilinear', antialias=True
).permute(0, 2, 3, 1) * 255
rgb_panel = torch.nn.functional.interpolate(
    rgbs[1:], double_size, mode='bilinear', antialias=True
).permute(0, 2, 3, 1) * 255

# Stack the three rows along the height axis.
viz = torch.cat((top_row, cond_panel, rgb_panel), dim=1)
viz.shape

In [None]:
# Encode the combined preview (padded to even dimensions by to_h264) at `fps`
# with CRF 18, then embed the resulting file in the notebook output.
torchvision.io.write_video('media/combined.mp4', to_h264(viz), fps, options={'crf': '18'})
embed.embed_file('media/combined.mp4')

In [None]:
from controlnet_aux import LineartDetector

# Lineart annotator, always invoked with coarse=False.
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
processor_partial = partial(processor, coarse=False)

# Run the detector on the GPU, one channel-last frame (scaled to 0-255) at a time.
processor.to(f'cuda:{gpu_id}')
frames_hwc = conds.permute(0, 2, 3, 1) * 255
pred_conds = np.stack([
    processor_partial(frame, output_type='np')
    for frame in frames_hwc
])

# Back to a channel-first tensor in [0, 1], resized to the resolution of `conds`.
pred_conds = torch.nn.functional.interpolate(
    torch.from_numpy(pred_conds).permute(0, 3, 1, 2) / 255,
    conds.shape[-2:],
    mode='bilinear',
)
# Move the detector off the GPU again now that it is no longer needed here.
processor.to('cpu')

# From here on the lineart maps are the conditioning images.
conds = pred_conds

In [None]:
# Encode the lineart conditioning frames (channel-last, scaled to 0-255) and embed the file.
# No even-dimension padding is applied here: `conds` was resized to sides that are multiples of 8.
torchvision.io.write_video('media/lineart.mp4', conds.permute(0, 2, 3, 1) * 255, fps, options={'crf': '18'})
embed.embed_file('media/lineart.mp4')

In [None]:
# Sanity check: shapes of the tensors handed to the pipeline, one per line.
print(flows.shape, occlusions.shape, conds.shape, sep='\n')

In [None]:
# ControlNet weights, loaded in half precision from a local directory.
# Exactly one of the checkpoint paths below should be left uncommented.
controlnet = ControlNetModel.from_pretrained(
    # "./hf-models/control_v11p_sd15_normalbae",
    # "./hf-models/control_v11p_sd15_openpose",
    "./hf-models/control_v11p_sd15_lineart",
    # "./hf-models/control_v11f1p_sd15_depth",
    torch_dtype=torch.float16
)


# Base Stable Diffusion checkpoint directory; it also provides the scheduler config.
# ckpt = './hf-models/stable-diffusion-v1-5'
ckpt = './hf-models/majicmixRealistic_betterV2V25'
# ckpt = './hf-models/majicmixRealistic_v6/'
# ckpt = './hf-models/xxmix9realistic'

# Scheduler built from the checkpoint's "scheduler" subfolder (DDIM; UniPC kept as an alternative).
scheduler = DDIMScheduler.from_pretrained(ckpt, subfolder="scheduler")
# scheduler = UniPCMultistepScheduler.from_pretrained(ckpt, subfolder="scheduler")

# Assemble the pipeline in half precision with the safety checker disabled.
# NOTE(review): StableDiffusionControlNetPixelSharingPipeline looks like it comes from a
# customised diffusers build rather than the upstream package -- confirm the install.
pipe = StableDiffusionControlNetPixelSharingPipeline.from_pretrained(
    ckpt, controlnet=controlnet,
    scheduler=scheduler, safety_checker=None, requires_safety_checker=False,
    torch_dtype=torch.float16
)
pipe.enable_xformers_memory_efficient_attention()
# Assigning to `_` keeps the pipeline repr out of the cell output.
_ = pipe.to(f'cuda:{gpu_id}')
# pipe.load_lora_weights("./civitai", weight_name="./bronze_statue.safetensors")

In [None]:
# Length of the Gaussian window built from this value; also forwarded to the pipeline as `kernel_width`.
kernel_width=10

In [None]:
# Gaussian window of length `kernel_width`; its standard deviation is a
# fixed fraction of the window length.
std = 0.5
kernel = torch.signal.windows.gaussian(kernel_width, std=std * kernel_width)

# Normalise so the weights sum to one.
kernel = kernel / kernel.sum()

plt.plot(kernel)

In [None]:
# Run the pipeline on the whole clip. `[None]` adds a leading axis of size 1 to
# the conditioning frames, flows and occlusion masks.
# NOTE(review): most keyword arguments below belong to the custom pixel-sharing
# pipeline; their exact semantics are not visible in this notebook -- confirm
# against the pipeline implementation before changing them.
output = pipe(
    # original_image=rgbs[None],
    # strength=0.9,
    image=conds[None],
    prompt='',
    negative_prompt='pink, spring, flower, chinese, baby',
    # Fixed seed so repeated runs are comparable.
    generator=torch.manual_seed(1234567),
    num_inference_steps=20,
    # 'pt' so that `output.images` is a tensor that can be resized below.
    output_type='pt',
    cpu_offload_text_encoder=True,
    unet_batch_size=10,
    vae_batch_size=10,
    harmonization_scale=0.8,
    # Normalised Gaussian window built in an earlier cell.
    kernel=kernel,
    guess_mode=True,
    
    # mixer=mixer,
    flows=flows[None],
    occlusions=occlusions[None],
    mix_mode='global_average', # global_average | convolution
    kernel_width=kernel_width,
    average_small_bucket=True,
    backward_coding=backward_coding,
)

video = output.images

# Resize back to the native (H, W) resolution and scale to 0-255 for encoding.
video = torch.nn.functional.interpolate(video, (H, W), mode='bilinear')
video *= 255
# Channel-last, on CPU, padded to even dimensions, then written and embedded.
torchvision.io.write_video('media/vcn.mp4', to_h264(video.permute(0, 2, 3, 1).cpu()), fps, options={'crf': '18'})
embed.embed_file('media/vcn.mp4')

In [None]:
# Display the `timer` attribute of the pipeline output
# (presumably per-stage timing collected by the custom pipeline -- TODO confirm).
output.timer