In [1]:
# Environment for the whole kernel. This cell runs before `torch` is imported in the next
# cell, so the values are already in place when the libraries read them.
# NOTE: keep comments on their own lines; text after a `%env` value becomes part of the value.

# Hugging Face hub cache directory (site-specific absolute path). The loading cell uses
# local_files_only=True, so the checkpoints must already exist in this cache.
%env HF_HUB_CACHE=/ist-nas/ist-share/vision/huggingface_hub/
# Expose only physical GPU index 1 to this process.
%env CUDA_VISIBLE_DEVICES=1
# cuBLAS workspace configuration; presumably set because seed_everything() enables
# torch.use_deterministic_algorithms(True), which needs it for deterministic CUDA matmuls.
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

env: HF_HUB_CACHE=/ist-nas/ist-share/vision/huggingface_hub/
env: CUDA_VISIBLE_DEVICES=1
env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
import os
import gc
import time
import yaml
import psutil
import random
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel
from diffusers.utils import load_video, export_to_video
from diffusers.video_processor import VideoProcessor
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
from transformers import UMT5EncoderModel, AutoTokenizer
from accelerate import cpu_offload
from tqdm.notebook import tqdm
from dataclasses import dataclass, asdict, field
from contextlib import contextmanager
from typing import Dict

def seed_everything(seed: int = 42) -> None:
    """Seed every RNG used in this notebook and switch PyTorch to deterministic mode.

    Args:
        seed: value handed to Python's `random`, NumPy and the PyTorch CPU/CUDA generators.

    Side effects:
        * writes ``PYTHONHASHSEED`` into ``os.environ``,
        * enables ``torch.use_deterministic_algorithms``,
        * turns TF32 off for CUDA matmuls and for cuDNN.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every generator: Python, NumPy, torch (CPU), torch (current GPU), torch (all GPUs).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Reproducibility switches: deterministic kernels only, and no TF32 math.
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
# Seed the RNGs before any model is loaded or any noise is drawn.
seed_everything(0)

# Checkpoint repo id plus the precision / device shared by every component below.
MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
DTYPE = torch.bfloat16
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Placeholders: the loading cell only instantiates the components that are still None.
vae = scheduler = tokenizer = text_encoder = transformer = None
video_processor = mask_processor = None

print(f"using device: {DEVICE}, dtype: {DTYPE}")

using device: cuda, dtype: torch.bfloat16


In [3]:
# Load each pipeline component at most once per kernel session: a component that is
# already in memory (not None) is reused, which keeps re-running this cell cheap.
# local_files_only=True means nothing is downloaded; the checkpoints must be in the HF cache.
if vae is None:
    vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=DTYPE, local_files_only=True)
    # vae.enable_tiling()
if scheduler is None:
    scheduler = UniPCMultistepScheduler.from_pretrained(MODEL_ID, subfolder="scheduler", local_files_only=True)
if tokenizer is None:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer", model_max_length=512, local_files_only=True)
if text_encoder is None:
    text_encoder = UMT5EncoderModel.from_pretrained(MODEL_ID, subfolder="text_encoder", torch_dtype=DTYPE, local_files_only=True)
if transformer is None:
    transformer = WanTransformer3DModel.from_pretrained(MODEL_ID, subfolder="transformer", torch_dtype=DTYPE, local_files_only=True)
    # Store the transformer weights in fp8 and cast them to DTYPE for computation.
    transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=DTYPE)

# Compression factors of the VAE, derived from its down-sampling schedule.
# Computed unconditionally (not only on a fresh load) so the names always exist when the
# processors below are built, even if `vae` was already populated before this cell ran.
# (`temperal_downsample` is the attribute's actual spelling on the VAE object.)
vae_scale_factor_temporal = 2 ** sum(vae.temperal_downsample)
vae_scale_factor_spatial = 2 ** len(vae.temperal_downsample)

models = [text_encoder, transformer, vae]
for m in models:
    cpu_offload(m, DEVICE) # automatically move unused models to cpu

# One processor for RGB frames, one for masks (grayscale, binarised, not normalised).
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)
mask_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial, do_binarize=True, do_normalize=False, do_convert_grayscale=True)

print(f'vae scale factor temporal: {vae_scale_factor_temporal}, spatial: {vae_scale_factor_spatial}')

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

vae scale factor temporal: 4, spatial: 8


In [4]:
@contextmanager
def track_memory_usage():
    """Context manager that reports GPU and CPU memory around the wrapped block.

    Yields:
        dict: empty while the block runs; filled on exit with
            ``"gpu_gb"`` - peak CUDA memory allocated on ``DEVICE`` during the block, and
            ``"cpu_gb"`` - resident set size of this process *at exit* (not a peak),
            both in GiB rounded to two decimals.

    The same figures, plus the start/current values, are printed on exit. The statistics
    are recorded in a ``finally`` clause, so they are reported even if the block raises.
    When CUDA is not in use (``DEVICE`` is ``"cpu"``), the GPU figures are reported as 0.0
    instead of calling into ``torch.cuda``, which would fail.
    """
    gib = 1024 ** 3
    process = psutil.Process(os.getpid())
    # DEVICE falls back to "cpu" on machines without CUDA; torch.cuda.* must not be called then.
    use_cuda = torch.cuda.is_available() and torch.device(DEVICE).type == "cuda"

    peak_memory = {}
    if use_cuda:
        # Reset the counter so the peak reflects only what happens inside the block.
        torch.cuda.reset_peak_memory_stats(DEVICE)
        start_gpu = torch.cuda.memory_allocated(DEVICE) / gib
    else:
        start_gpu = 0.0
    start_cpu = process.memory_info().rss / gib

    try:
        yield peak_memory
    finally:
        if use_cuda:
            peak_gpu = torch.cuda.max_memory_allocated(DEVICE) / gib
            current_gpu = torch.cuda.memory_allocated(DEVICE) / gib
        else:
            peak_gpu = current_gpu = 0.0
        peak_memory["gpu_gb"] = round(peak_gpu, 2)
        peak_memory["cpu_gb"] = round(process.memory_info().rss / gib, 2)
        print(f"gpu peak: {peak_memory['gpu_gb']} gb, current: {current_gpu:.2f} gb, start: {start_gpu:.2f} gb")
        print(f'cpu rss: {peak_memory["cpu_gb"]} gb, start: {start_cpu:.2f} gb')


In [5]:
@torch.no_grad()
def _encode_prompt(text):
    """Encode `text` with the module-level tokenizer and text encoder.

    The prompt is padded/truncated to ``tokenizer.model_max_length`` tokens. After encoding,
    the embeddings at padding positions are replaced by zeros (trim to the real token count,
    then zero-pad back to full length).

    Returns:
        Tensor of shape [batch, model_max_length, hidden] on DEVICE in DTYPE.
    """
    max_len = tokenizer.model_max_length
    inputs = tokenizer(text, padding="max_length", max_length=max_len, truncation=True, return_tensors="pt")
    attention_mask = inputs.attention_mask
    seq_lens = attention_mask.gt(0).sum(dim=1).long()  # number of non-padding tokens per prompt
    embeds = text_encoder(inputs.input_ids.to(DEVICE), attention_mask.to(DEVICE)).last_hidden_state
    embeds = embeds.to(DEVICE, DTYPE)
    trimmed = [u[:v] for u, v in zip(embeds, seq_lens)]
    return torch.stack([
        torch.cat([u, u.new_zeros(max_len - u.size(0), u.size(1))]) for u in trimmed
    ], dim=0)


def _latent_norm_stats(device, dtype):
    """Return ``(mean, 1/std)`` of the VAE latent space, shaped [1, z_dim, 1, 1, 1] for broadcasting."""
    mean = torch.tensor(vae.config.latents_mean).view(1, vae.config.z_dim, 1, 1, 1).to(device, dtype)
    inv_std = 1.0 / torch.tensor(vae.config.latents_std).view(1, vae.config.z_dim, 1, 1, 1).to(device, dtype)
    return mean, inv_std


@torch.no_grad()
def sdedit_video_inpainting_pipeline(
    input_video_frames,   # List[PIL.Image.Image]
    input_mask_frames,    # mask frames; the run cell passes load_video(...) output, like the video
    height,
    width,
    prompt,
    neg_prompt,
    strength,
    num_inference_steps,
    guidance_scale,
    dir,
    debug,
):
    """SDEdit-style masked video editing with the module-level Wan components.

    The input video is encoded to latents, noised to the timestep implied by `strength`, and
    denoised with classifier-free guidance. After every scheduler step the latents are
    re-composited: where the mask is 1 the generated content is kept, where it is 0 the
    (appropriately re-noised) original latents are restored.

    Uses the notebook globals ``scheduler``, ``vae``, ``transformer``, ``tokenizer``,
    ``text_encoder``, ``video_processor`` and ``mask_processor``.

    Args:
        input_video_frames: frames of the source video.
        input_mask_frames: frames of the mask video; resized to the latent grid internally.
        height, width: pixel size the frames are pre-processed to.
        prompt, neg_prompt: positive / negative text prompts for guidance.
        strength: fraction of the schedule that is actually run; with N inference steps the
            last ``int(N * strength)`` timesteps are used (capped at N).
        num_inference_steps: length of the full scheduler schedule.
        guidance_scale: classifier-free guidance weight.
        dir: directory for the debug latent dumps (name shadows the builtin; kept for callers).
        debug: when truthy, intermediate latents are saved under `dir` with ``torch.save``.

    Returns:
        The result of ``video_processor.postprocess_video`` on the decoded video (a batch of
        videos; the run cell takes element ``[0]``).

    Raises:
        ValueError: if `strength` / `num_inference_steps` leave no timestep to run.
    """
    print('1. preparing timesteps...')
    scheduler.set_timesteps(num_inference_steps, DEVICE)
    original_timesteps = scheduler.timesteps

    # Keep only the tail of the schedule: the smaller `strength`, the later (less noisy) the start.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    start_step = max(num_inference_steps - init_timestep, 0)
    timesteps = original_timesteps[start_step * scheduler.order :]
    if len(timesteps) == 0:
        # Without this guard the run would continue with empty latents and fail obscurely later.
        raise ValueError(
            f"strength={strength} with num_inference_steps={num_inference_steps} selects no "
            "denoising steps; increase strength or num_inference_steps"
        )
    num_inference_steps = num_inference_steps - start_step
    latent_timestep = timesteps[:1] # shape: [1]
    print(f'scheduler config: {scheduler.config}')
    print(f'active timesteps (strength={strength}, {len(timesteps)}/{len(original_timesteps)} steps): {timesteps.cpu().numpy()}')

    print(f'\n2. encoding video frames...')
    # [B, C, F, H, W]; presumably in [-1, 1] since this processor keeps the default
    # normalisation (the mask processor disables it explicitly) - TODO confirm.
    video_tensor = video_processor.preprocess_video(input_video_frames, height, width).to(DEVICE, DTYPE)

    num_channels_latents = transformer.config.in_channels # 16
    shape = (
        1,                                                              # b
        num_channels_latents,                                           # latent c
        (video_tensor.size(2) - 1) // vae_scale_factor_temporal + 1,    # latent f
        height // vae_scale_factor_spatial,                             # latent h
        width // vae_scale_factor_spatial,                              # latent w
    )
    # Encode one video at a time and sample from the posterior.
    initial_latents = [vae.encode(vid.unsqueeze(0)).latent_dist.sample() for vid in video_tensor]
    initial_latents = torch.cat(initial_latents, dim=0).to(DTYPE)

    # Normalise the latents with the statistics stored in the VAE config.
    latents_mean, latents_inv_std = _latent_norm_stats(DEVICE, DTYPE)
    init_latents = (initial_latents - latents_mean) * latents_inv_std

    # The same noise tensor is reused below whenever the original latents are re-noised.
    noise = torch.randn(shape).to(DEVICE)
    latents = scheduler.add_noise(init_latents, noise, latent_timestep)
    print('latents', latents.shape)

    print(f'\n3. preparing masks...')
    mask_tensor = mask_processor.preprocess_video(input_mask_frames, height, width).to(DEVICE, DTYPE) # [B, 1, F, H, W], (0, 1)

    # Resize the pixel-space mask to the latent grid (frames, height, width).
    mask = torch.nn.functional.interpolate(
        mask_tensor,
        size=shape[2:],
        mode='trilinear'
    ).to(DEVICE, DTYPE)
    print('mask', mask.shape)

    print(f'\n4. encoding prompts...')
    # Both embeddings come back in DTYPE already, so no further cast is needed. (The original
    # code called `.to(DTYPE)` here and discarded the result, which had no effect.)
    prompt_embeds = _encode_prompt(prompt)
    negative_prompt_embeds = _encode_prompt(neg_prompt)
    print('prompt', prompt_embeds.shape)
    print('negative prompt', negative_prompt_embeds.shape)

    print(f'\n5. denoising...')
    if debug: torch.save(latents, f'{dir}/latents_before_denoising.pt')
    keep_mask = 1 - mask                 # 1 where the original content must be preserved
    last_index = len(timesteps) - 1
    progress_bar = tqdm(timesteps, total=num_inference_steps)
    for i, t in enumerate(progress_bar):
        latent_model_input = latents.to(DTYPE)
        timestep = t.expand(1) # convert scalar to [t]

        # Classifier-free guidance: one conditional and one unconditional forward pass.
        noise_pred = transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            return_dict=False,
        )[0]
        noise_pred_uncond = transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=negative_prompt_embeds,
            return_dict=False,
        )[0]
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

        # x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        # sdedit: outside the mask, restore the original latents noised to the *next* timestep
        # (or the clean latents after the final step) so they match the current noise level.
        if i < last_index:
            known_latents = scheduler.add_noise(init_latents, noise, torch.tensor([timesteps[i + 1]]))
        else:
            known_latents = init_latents
        latents = keep_mask * known_latents + mask * latents
    if debug: torch.save(latents, f'{dir}/latents_after_denoising.pt')

    print(f'\n6. decoding...')
    latents = latents.to(vae.dtype)
    # Undo the latent normalisation applied before noising.
    latents_mean, latents_inv_std = _latent_norm_stats(latents.device, latents.dtype)
    latents = latents / latents_inv_std + latents_mean

    if debug: torch.save(latents, f'{dir}/latents_before_decode.pt')
    output = vae.decode(latents, return_dict=False)[0] # [1, 3, F, H, W]
    # Batch of videos in the processor's default output format - exact layout not shown here.
    output = video_processor.postprocess_video(output)
    return output

In [6]:
@dataclass
class Config:
    """Settings for one SDEdit inpainting run, plus the statistics collected during it.

    The whole object is dumped to YAML via ``asdict``; ``pipeline_kwargs`` exposes the subset
    of fields that ``sdedit_video_inpainting_pipeline`` accepts as keyword arguments.
    """

    # Output directory, stamped with month-day / hour-minute. Built by a factory so the
    # timestamp is taken when the instance is created, not once when the class is defined,
    # and with a single strftime call so both parts come from the same instant.
    dir: str = field(default_factory=lambda: time.strftime("output/sdedit/%m%d/%H%M"))
    # Source video; the run cell derives the mask path from it by replacing 'video' with 'mask'.
    input_video: str = "input/landmark/process/16fps_720x1280_centercrop_41/man_video.mp4"
    height: int = 1280
    width: int = 720
    strength: float = 0.7
    num_inference_steps: int = 10
    guidance_scale: float = 7.0
    # prompt: str = "A man is speaking straight to the camera. He is bald, has beard, and is wearing a white shirt. His mouth opens and closes, naturally revealing his teeth as he gives his speech. He is eloquently pronouncing each word, moving his head and changing his facial expression as he talks."
    prompt: str = "A man with beautiful teeth speaking"
    # prompt: str = "A professional man speaking with a flawless, radiant smile—symmetrical white teeth, no gaps or imperfections, and a natural-looking dental appearance."
    neg_prompt: str = "色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，杂乱的背景，三条腿，背景人很多，倒着走"
    # Frame rate used when exporting the result.
    fps: int = 16

    # Filled in after the run (e.g. "infer_seconds", "gpu_gb", "cpu_gb").
    timing_stats: Dict[str, float] = field(default_factory=dict)
    memory_stats: Dict[str, float] = field(default_factory=dict)

    @property
    def pipeline_kwargs(self):
        """Return the fields forwarded to the pipeline as a ``{name: value}`` dict."""
        names = (
            'dir', 'height', 'width', 'strength',
            'num_inference_steps', 'guidance_scale', 'prompt', 'neg_prompt',
        )
        return {name: getattr(self, name) for name in names}

def _write_config(cfg, path):
    """Dump the dataclass `cfg` (settings plus any stats gathered so far) to YAML at `path`."""
    with open(path, "w") as handle:
        yaml.safe_dump(asdict(cfg), handle, sort_keys=False)


# Create the run directory and record the settings before the expensive part starts,
# so the configuration is on disk even if inference fails.
config = Config()
os.makedirs(config.dir, exist_ok=True)
config_file = f'{config.dir}/config.yml'
_write_config(config, config_file)

# Run the pipeline while measuring wall-clock time and memory.
# The mask video path is the input path with 'video' replaced by 'mask'.
start_inference = time.time()
with track_memory_usage() as peak_memory:
    inpainted_video = sdedit_video_inpainting_pipeline(
        input_video_frames=load_video(config.input_video),
        input_mask_frames=load_video(config.input_video.replace('video', 'mask')),
        debug=False,
        **config.pipeline_kwargs,
    )[0]  # first (only) video of the returned batch
infer_time = time.time() - start_inference

# Attach the measurements and rewrite the config file with them included.
config.timing_stats["infer_seconds"] = round(infer_time, 2)
config.memory_stats.update(peak_memory)
_write_config(config, config_file)

print('\nexporting to video...')
export_to_video(inpainted_video, f'{config.dir}/out_nb.mp4', fps=config.fps, quality=10)

1. preparing timesteps...
scheduler config: FrozenDict([('num_train_timesteps', 1000), ('beta_start', 0.0001), ('beta_end', 0.02), ('beta_schedule', 'linear'), ('trained_betas', None), ('solver_order', 2), ('prediction_type', 'flow_prediction'), ('thresholding', False), ('dynamic_thresholding_ratio', 0.995), ('sample_max_value', 1.0), ('predict_x0', True), ('solver_type', 'bh2'), ('lower_order_final', True), ('disable_corrector', []), ('solver_p', None), ('use_karras_sigmas', False), ('use_exponential_sigmas', False), ('use_beta_sigmas', False), ('use_flow_sigmas', True), ('flow_shift', 3.0), ('timestep_spacing', 'linspace'), ('steps_offset', 0), ('final_sigmas_type', 'zero'), ('rescale_betas_zero_snr', False), ('_class_name', 'UniPCMultistepScheduler'), ('_diffusers_version', '0.33.0.dev0')])
active timesteps (strength=0.7, 7/10 steps): [874 817 749 666 562 428 249]

2. encoding video frames...
latents torch.Size([1, 16, 11, 160, 90])

3. preparing masks...
mask torch.Size([1, 1, 11, 

  0%|          | 0/7 [00:00<?, ?it/s]


6. decoding...
gpu peak: 10.29 gb, current: 0.06 gb, start: 0.00 gb
cpu rss: 39.49 gb, start: 38.08 gb

exporting to video...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


'output/sdedit/0519/1236/out_nb.mp4'

In [None]:
# Scratch cell: inspect latents saved by an earlier run (presumably one made with debug=True,
# which writes `latents_before_decode.pt`). NOTE(review): the path is hard-coded to a past
# run directory, so this cell is not reproducible from a fresh kernel.
# weights_only=True restricts unpickling to tensors/plain containers, which is all this file
# should contain, instead of allowing arbitrary pickled objects.
latents = torch.load('output/sdedit/0514/2157/latents_before_decode.pt', weights_only=True)
latents.shape



torch.Size([1, 16, 6, 240, 136])