In [None]:
## For visualisation
!pip install denku==0.1.3

In [None]:
import os
# GPU selection must happen before torch is imported (next lines of this cell),
# otherwise CUDA has already enumerated the devices.
# `setdefault` keeps "1" as the default but lets a CUDA_VISIBLE_DEVICES value exported
# in the launching shell win (index "1" does not exist on single-GPU machines).
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "1")
# Disable Hugging Face tokenizers' internal parallelism (avoids fork-related warnings).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import sys
# Make the local wan_* modules importable; presumably they live one directory up — TODO confirm.
sys.path.append('..')

import torch
from denku import show_images
from diffusers import (
    AutoencoderKLWan,
    FlowMatchEulerDiscreteScheduler,
    UniPCMultistepScheduler
)
from diffusers.utils import export_to_video, load_video
from controlnet_aux import HEDdetector, CannyDetector
from transformers import UMT5EncoderModel, T5TokenizerFast


from wan_controlnet import WanControlnet
from wan_transformer import CustomWanTransformer3DModel
from wan_controlnet_pipeline import WanControlnetPipeline

# Re-import edited modules (e.g. the local wan_* modules imported above) before each
# cell runs, so changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Hub repo ids: the base Wan2.1 model (Diffusers layout, loaded per-subfolder below)
# and the HED ControlNet checkpoint.
base_model_path, controlnet_model_path = (
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
    "TheDenk/wan2.1-t2v-1.3b-controlnet-hed-v1",
)

In [None]:
# Load every base-model component from its subfolder of the Diffusers checkpoint.
# Text side: tokenizer + UMT5 encoder (bfloat16).
tokenizer = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer")
text_encoder = UMT5EncoderModel.from_pretrained(
    base_model_path,
    subfolder="text_encoder",
    torch_dtype=torch.bfloat16,
)
# The VAE is loaded in float32, unlike the bfloat16 encoder and transformer.
vae = AutoencoderKLWan.from_pretrained(
    base_model_path,
    subfolder="vae",
    torch_dtype=torch.float32,
)
transformer = CustomWanTransformer3DModel.from_pretrained(
    base_model_path,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(base_model_path, subfolder="scheduler")
# Alternative (unused): UniPC scheduler with flow sigmas.
# flow_shift = 3.0 # 5.0 for 720P, 3.0 for 480P
# scheduler = UniPCMultistepScheduler(prediction_type='flow_prediction', use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift)

In [None]:
# HED ControlNet weights, in the same bfloat16 dtype as the transformer.
controlnet = WanControlnet.from_pretrained(
    controlnet_model_path,
    torch_dtype=torch.bfloat16,
)

In [None]:
# Assemble the ControlNet-conditioned Wan pipeline from the components loaded above.
pipe = WanControlnetPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    vae=vae,
    controlnet=controlnet,
    scheduler=scheduler,
)
# Model CPU offload manages device placement itself: each sub-model is moved to the GPU
# only while it runs. Do NOT call `pipe.to("cuda")` beforehand — that loads every
# component onto the GPU at once and defeats the memory savings.
# On a GPU with enough memory, use `pipe.to("cuda")` *instead of* this line.
pipe.enable_model_cpu_offload()

In [None]:
# Supported control types -> controlnet_aux annotator classes.
# Defined before the factory below, which looks types up in it.
controlnet_mapping = {
    'hed': HEDdetector,
    'canny': CannyDetector,
}

def init_controlnet_processor(controlnet_type, device='cuda'):
    """Create the annotator that turns RGB video frames into control frames.

    Args:
        controlnet_type: a key of ``controlnet_mapping`` ('hed' or 'canny').
        device: device the annotator's weights are moved to (default 'cuda').
            Not used for 'canny', which is constructed without pretrained weights.

    Returns:
        The annotator instance; it is called once per frame to produce a control frame.

    Raises:
        KeyError: if ``controlnet_type`` is not a key of ``controlnet_mapping``.
    """
    if controlnet_type not in controlnet_mapping:
        raise KeyError(
            f"Unknown controlnet_type {controlnet_type!r}; "
            f"expected one of {sorted(controlnet_mapping)}"
        )
    detector_cls = controlnet_mapping[controlnet_type]
    if controlnet_type == 'canny':
        # Constructed directly: nothing to download or move to a device.
        return detector_cls()
    # Annotators with learned weights are loaded from the lllyasviel/Annotators repo.
    return detector_cls.from_pretrained('lllyasviel/Annotators').to(device=device)

controlnet_processor = init_controlnet_processor("hed")

In [None]:
# Source clip whose annotator output (control frames) will guide generation.
video_path = "../resources/physical-1.mp4"
num_frames = 81  # also the number of frames requested from the pipeline below

video_frames = load_video(video_path)[:num_frames]
# Slicing silently returns fewer frames for a short clip, which would leave the control
# sequence shorter than the `num_frames` the pipeline is asked to generate.
if len(video_frames) < num_frames:
    raise ValueError(
        f"{video_path} provides only {len(video_frames)} frames, "
        f"but num_frames={num_frames} were requested."
    )
controlnet_frames = [controlnet_processor(x) for x in video_frames]

# Preview every 20th frame: source frames first, then their control frames.
show_images(video_frames[::20], figsize=(16, 8))
show_images(controlnet_frames[::20], figsize=(16, 8))

In [None]:
prompt = "In a cozy kitchen, a golden retriever wearing a white chef's hat and a blue apron stands at the table, holding a sharp kitchen knife and skillfully slicing fresh tomatoes. Its tail sways gently, and its gaze is focused and gentle. There are already several neatly arranged tomatoes on the wooden chopping board in front of me. The kitchen has soft lighting, with various kitchen utensils hanging on the walls and several pots of green plants placed on the windowsill."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

# Fixed seed so the sampled video is reproducible.
generator = torch.Generator(device="cuda").manual_seed(42)

# ControlNet conditioning: the control frames, the fraction of the denoising schedule
# during which they are applied (0.0 -> 0.8), their weight, and the stride parameter.
controlnet_kwargs = dict(
    controlnet_frames=controlnet_frames,
    controlnet_guidance_start=0.0,
    controlnet_guidance_end=0.8,
    controlnet_weight=0.8,
    controlnet_stride=3,
)

pipe_result = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=num_frames,
    guidance_scale=5.0,
    num_inference_steps=50,
    generator=generator,
    output_type="pil",
    **controlnet_kwargs,
)
# First (only) video of the batch, as a list of PIL frames.
output = pipe_result.frames[0]

show_images(output[::20], figsize=(16, 8))

In [None]:
# Write the generated frames to an MP4 file at 16 fps.
output_path = "output.mp4"
export_to_video(
    output,
    output_path,
    fps=16,
)