From d31bf2a1819060723f1fe220bda9f5c5ccbdf251 Mon Sep 17 00:00:00 2001 From: "xunsong.li" Date: Wed, 17 Jan 2024 09:21:32 +0000 Subject: [PATCH] add stage2 training codes --- .gitignore | 5 +- configs/train/stage2.yaml | 59 +++ src/dataset/dance_video.py | 137 +++++ src/pipelines/pipeline_pose2vid.py | 6 +- train_stage_2.py | 773 +++++++++++++++++++++++++++++ 5 files changed, 978 insertions(+), 2 deletions(-) create mode 100644 configs/train/stage2.yaml create mode 100644 src/dataset/dance_video.py create mode 100644 train_stage_2.py diff --git a/.gitignore b/.gitignore index ef98b73..569e4c1 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,7 @@ output/ mlruns/ data/ -*.pth \ No newline at end of file +*.pth +*.pt +*.pkl +*.bin \ No newline at end of file diff --git a/configs/train/stage2.yaml b/configs/train/stage2.yaml new file mode 100644 index 0000000..086fa1e --- /dev/null +++ b/configs/train/stage2.yaml @@ -0,0 +1,59 @@ +data: + train_bs: 1 + train_width: 512 + train_height: 512 + meta_paths: + - "./data/fashion_meta.json" + sample_rate: 4 + n_sample_frames: 24 + +solver: + gradient_accumulation_steps: 1 + mixed_precision: 'fp16' + enable_xformers_memory_efficient_attention: True + gradient_checkpointing: True + max_train_steps: 10000 + max_grad_norm: 1.0 + # lr + learning_rate: 1e-5 + scale_lr: False + lr_warmup_steps: 1 + lr_scheduler: 'constant' + + # optimizer + use_8bit_adam: True + adam_beta1: 0.9 + adam_beta2: 0.999 + adam_weight_decay: 1.0e-2 + adam_epsilon: 1.0e-8 + +val: + validation_steps: 20 + + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + +base_model_path: './pretrained_weights/stable-diffusion-v1-5' +vae_model_path: './pretrained_weights/sd-vae-ft-mse' +image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder' +mm_path: './pretrained_weights/mm_sd_v15_v2.ckpt' + +weight_dtype: 'fp16' # [fp16, fp32] +uncond_ratio: 0.1 +noise_offset: 0.05 +snr_gamma: 5.0 +enable_zero_snr: True +stage1_ckpt_dir: './exp_output/stage1' +stage1_ckpt_step: 980 + +seed: 12580 +resume_from_checkpoint: '' +checkpointing_steps: 2000 +exp_name: 'stage2' +output_dir: './exp_output' \ No newline at end of file diff --git a/src/dataset/dance_video.py b/src/dataset/dance_video.py new file mode 100644 index 0000000..7f68bb0 --- /dev/null +++ b/src/dataset/dance_video.py @@ -0,0 +1,137 @@ +import json +import random +from typing import List + +import numpy as np +import pandas as pd +import torch +import torchvision.transforms as transforms +from decord import VideoReader +from PIL import Image +from torch.utils.data import Dataset +from transformers import CLIPImageProcessor + + +class HumanDanceVideoDataset(Dataset): + def __init__( + self, + sample_rate, + n_sample_frames, + width, + height, + img_scale=(1.0, 1.0), + img_ratio=(0.9, 1.0), + drop_ratio=0.1, + data_meta_paths=["./data/fashion_meta.json"], + ): + super().__init__() + self.sample_rate = sample_rate + self.n_sample_frames = n_sample_frames + self.width = width + self.height = height + self.img_scale = img_scale + self.img_ratio = img_ratio + + vid_meta = [] + for data_meta_path in data_meta_paths: + vid_meta.extend(json.load(open(data_meta_path, "r"))) + self.vid_meta = vid_meta + + self.clip_image_processor = CLIPImageProcessor() + + self.pixel_transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + (height, width), + scale=self.img_scale, + ratio=self.img_ratio, + 
interpolation=transforms.InterpolationMode.BILINEAR, + ), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + self.cond_transform = transforms.Compose( + [ + transforms.RandomResizedCrop( + (height, width), + scale=self.img_scale, + ratio=self.img_ratio, + interpolation=transforms.InterpolationMode.BILINEAR, + ), + transforms.ToTensor(), + ] + ) + + self.drop_ratio = drop_ratio + + def augmentation(self, images, transform, state=None): + if state is not None: + torch.set_rng_state(state) + if isinstance(images, List): + transformed_images = [transform(img) for img in images] + ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w) + else: + ret_tensor = transform(images) # (c, h, w) + return ret_tensor + + def __getitem__(self, index): + video_meta = self.vid_meta[index] + video_path = video_meta["video_path"] + kps_path = video_meta["kps_path"] + + video_reader = VideoReader(video_path) + kps_reader = VideoReader(kps_path) + + assert len(video_reader) == len( + kps_reader + ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}" + + video_length = len(video_reader) + + clip_length = min( + video_length, (self.n_sample_frames - 1) * self.sample_rate + 1 + ) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace( + start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int + ).tolist() + + # read frames and kps + vid_pil_image_list = [] + pose_pil_image_list = [] + for index in batch_index: + img = video_reader[index] + vid_pil_image_list.append(Image.fromarray(img.asnumpy())) + img = kps_reader[index] + pose_pil_image_list.append(Image.fromarray(img.asnumpy())) + + ref_img_idx = random.randint(0, video_length - 1) + ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy()) + + # transform + state = torch.get_rng_state() + pixel_values_vid = self.augmentation( + vid_pil_image_list, self.pixel_transform, state + ) + pixel_values_pose = self.augmentation( + pose_pil_image_list, self.cond_transform, state + ) + pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state) + clip_ref_img = self.clip_image_processor( + images=ref_img, return_tensors="pt" + ).pixel_values[0] + + sample = dict( + video_dir=video_path, + pixel_values_vid=pixel_values_vid, + pixel_values_pose=pixel_values_pose, + pixel_values_ref_img=pixel_values_ref_img, + clip_ref_img=clip_ref_img, + ) + + return sample + + def __len__(self): + return len(self.vid_meta) diff --git a/src/pipelines/pipeline_pose2vid.py b/src/pipelines/pipeline_pose2vid.py index 847e289..51356f5 100644 --- a/src/pipelines/pipeline_pose2vid.py +++ b/src/pipelines/pipeline_pose2vid.py @@ -22,7 +22,6 @@ @dataclass class Pose2VideoPipelineOutput(BaseOutput): videos: Union[torch.Tensor, np.ndarray] - middle_results: Union[torch.Tensor, np.ndarray] class Pose2VideoPipeline(DiffusionPipeline): @@ -429,6 +428,11 @@ def __call__( noise_pred_text - noise_pred_uncond ) + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + # call the callback, if provided if i == len(timesteps) - 1 or ( (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 diff --git a/train_stage_2.py b/train_stage_2.py new file mode 100644 index 0000000..d7d9498 --- /dev/null +++ b/train_stage_2.py @@ -0,0 +1,773 @@ +import argparse +import copy +import logging +import math +import os +import os.path as osp +import random +import time +import warnings +from 
collections import OrderedDict +from datetime import datetime +from pathlib import Path +from tempfile import TemporaryDirectory + +import diffusers +import mlflow +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from einops import rearrange +from omegaconf import OmegaConf +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPVisionModelWithProjection + +from src.dataset.dance_video import HumanDanceVideoDataset +from src.models.mutual_self_attention import ReferenceAttentionControl +from src.models.pose_guider import PoseGuider +from src.models.unet_2d_condition import UNet2DConditionModel +from src.models.unet_3d import UNet3DConditionModel +from src.pipelines.pipeline_pose2vid import Pose2VideoPipeline +from src.utils.util import ( + delete_additional_ckpt, + import_filename, + read_frames, + save_videos_grid, + seed_everything, +) + +warnings.filterwarnings("ignore") + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +class Net(nn.Module): + def __init__( + self, + reference_unet: UNet2DConditionModel, + denoising_unet: UNet3DConditionModel, + pose_guider: PoseGuider, + reference_control_writer, + reference_control_reader, + ): + super().__init__() + self.reference_unet = reference_unet + self.denoising_unet = denoising_unet + self.pose_guider = pose_guider + self.reference_control_writer = reference_control_writer + self.reference_control_reader = reference_control_reader + + def forward( + self, + noisy_latents, + timesteps, + ref_image_latents, + clip_image_embeds, + pose_img, + uncond_fwd: bool = False, + ): + pose_cond_tensor = pose_img.to(device="cuda") + pose_fea = self.pose_guider(pose_cond_tensor) + + if not uncond_fwd: + ref_timesteps = torch.zeros_like(timesteps) + self.reference_unet( + ref_image_latents, + ref_timesteps, + encoder_hidden_states=clip_image_embeds, + return_dict=False, + ) + self.reference_control_reader.update(self.reference_control_writer) + + model_pred = self.denoising_unet( + noisy_latents, + timesteps, + pose_cond_fea=pose_fea, + encoder_hidden_states=clip_image_embeds, + ).sample + + return model_pred + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. 
+ # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +def log_validation( + vae, + image_enc, + net, + scheduler, + accelerator, + width, + height, + clip_length=24, + generator=None, +): + logger.info("Running validation... ") + + ori_net = accelerator.unwrap_model(net) + reference_unet = ori_net.reference_unet + denoising_unet = ori_net.denoising_unet + pose_guider = ori_net.pose_guider + + if generator is None: + generator = torch.manual_seed(42) + tmp_denoising_unet = copy.deepcopy(denoising_unet) + tmp_denoising_unet = tmp_denoising_unet.to(dtype=torch.float16) + + pipe = Pose2VideoPipeline( + vae=vae, + image_encoder=image_enc, + reference_unet=reference_unet, + denoising_unet=tmp_denoising_unet, + pose_guider=pose_guider, + scheduler=scheduler, + ) + pipe = pipe.to(accelerator.device) + + test_cases = [ + ( + "./configs/inference/ref_images/anyone-3.png", + "./configs/inference/pose_videos/anyone-video-1_kps.mp4", + ), + ( + "./configs/inference/ref_images/anyone-2.png", + "./configs/inference/pose_videos/anyone-video-2_kps.mp4", + ), + ] + + results = [] + for test_case in test_cases: + ref_image_path, pose_video_path = test_case + ref_name = Path(ref_image_path).stem + pose_name = Path(pose_video_path).stem + ref_image_pil = Image.open(ref_image_path).convert("RGB") + + pose_list = [] + pose_tensor_list = [] + pose_images = read_frames(pose_video_path) + pose_transform = transforms.Compose( + [transforms.Resize((height, width)), transforms.ToTensor()] + ) + for pose_image_pil in pose_images[:clip_length]: + pose_tensor_list.append(pose_transform(pose_image_pil)) + pose_list.append(pose_image_pil) + + pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w) + pose_tensor = pose_tensor.transpose(0, 1) + + pipeline_output = pipe( + ref_image_pil, + pose_list, + width, + height, + clip_length, + 20, + 3.5, + generator=generator, + ) + video = pipeline_output.videos + + # Concat it with pose tensor + pose_tensor = pose_tensor.unsqueeze(0) + video = torch.cat([video, pose_tensor], dim=0) + + results.append({"name": f"{ref_name}_{pose_name}", "vid": video}) + + del tmp_denoising_unet + del pipe + torch.cuda.empty_cache() + + return results + + +def main(cfg): + kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) + accelerator = Accelerator( + gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, + mixed_precision=cfg.solver.mixed_precision, + log_with="mlflow", + project_dir="./mlruns", + kwargs_handlers=[kwargs], + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if cfg.seed is not None: + seed_everything(cfg.seed) + + exp_name = cfg.exp_name + save_dir = f"{cfg.output_dir}/{exp_name}" + if accelerator.is_main_process: + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + inference_config_path = "./configs/inference/inference_v2.yaml" + infer_config = OmegaConf.load(inference_config_path) + + if cfg.weight_dtype == "fp16": + weight_dtype = torch.float16 + elif cfg.weight_dtype == "fp32": + weight_dtype = torch.float32 + else: + raise ValueError( + f"Do not support weight dtype: {cfg.weight_dtype} during training" + ) + + sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs) + if cfg.enable_zero_snr: + sched_kwargs.update( + rescale_betas_zero_snr=True, + timestep_spacing="trailing", + prediction_type="v_prediction", + ) + val_noise_scheduler = DDIMScheduler(**sched_kwargs) + sched_kwargs.update({"beta_schedule": "scaled_linear"}) + train_noise_scheduler = DDIMScheduler(**sched_kwargs) + + image_enc = CLIPVisionModelWithProjection.from_pretrained( + cfg.image_encoder_path, + ).to(dtype=weight_dtype, device="cuda") + vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to( + "cuda", dtype=weight_dtype + ) + reference_unet = UNet2DConditionModel.from_pretrained( + cfg.base_model_path, + subfolder="unet", + ).to(device="cuda", dtype=weight_dtype) + + denoising_unet = UNet3DConditionModel.from_pretrained_2d( + cfg.base_model_path, + cfg.mm_path, + subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container( + infer_config.unet_additional_kwargs + ), + ).to(device="cuda") + + pose_guider = PoseGuider( + conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256) + ).to(device="cuda", dtype=weight_dtype) + + stage1_ckpt_dir = cfg.stage1_ckpt_dir + stage1_ckpt_step = cfg.stage1_ckpt_step + denoising_unet.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, f"denoising_unet-{stage1_ckpt_step}.pth"), + map_location="cpu", + ), + strict=False, + ) + reference_unet.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, f"reference_unet-{stage1_ckpt_step}.pth"), + map_location="cpu", + ), + strict=False, + ) + pose_guider.load_state_dict( + torch.load( + os.path.join(stage1_ckpt_dir, f"pose_guider-{stage1_ckpt_step}.pth"), + map_location="cpu", + ), + strict=False, + ) + + # Freeze + vae.requires_grad_(False) + image_enc.requires_grad_(False) + reference_unet.requires_grad_(False) + denoising_unet.requires_grad_(False) + pose_guider.requires_grad_(False) + + # Set motion module learnable + for name, module in denoising_unet.named_modules(): + if "motion_modules" in name: + for params in module.parameters(): + params.requires_grad = True + + reference_control_writer = ReferenceAttentionControl( + reference_unet, + do_classifier_free_guidance=False, + mode="write", + fusion_blocks="full", + ) + reference_control_reader = ReferenceAttentionControl( + denoising_unet, + do_classifier_free_guidance=False, + mode="read", + fusion_blocks="full", + ) + + net = Net( + reference_unet, + denoising_unet, + pose_guider, + 
reference_control_writer, + reference_control_reader, + ) + + if cfg.solver.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + reference_unet.enable_xformers_memory_efficient_attention() + denoising_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError( + "xformers is not available. Make sure it is installed correctly" + ) + + if cfg.solver.gradient_checkpointing: + reference_unet.enable_gradient_checkpointing() + denoising_unet.enable_gradient_checkpointing() + + if cfg.solver.scale_lr: + learning_rate = ( + cfg.solver.learning_rate + * cfg.solver.gradient_accumulation_steps + * cfg.data.train_bs + * accelerator.num_processes + ) + else: + learning_rate = cfg.solver.learning_rate + + # Initialize the optimizer + if cfg.solver.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + trainable_params = list(filter(lambda p: p.requires_grad, net.parameters())) + logger.info(f"Total trainable params {len(trainable_params)}") + optimizer = optimizer_cls( + trainable_params, + lr=learning_rate, + betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), + weight_decay=cfg.solver.adam_weight_decay, + eps=cfg.solver.adam_epsilon, + ) + + # Scheduler + lr_scheduler = get_scheduler( + cfg.solver.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=cfg.solver.lr_warmup_steps + * cfg.solver.gradient_accumulation_steps, + num_training_steps=cfg.solver.max_train_steps + * cfg.solver.gradient_accumulation_steps, + ) + + train_dataset = HumanDanceVideoDataset( + width=cfg.data.train_width, + height=cfg.data.train_height, + n_sample_frames=cfg.data.n_sample_frames, + sample_rate=cfg.data.sample_rate, + img_scale=(1.0, 1.0), + data_meta_paths=cfg.data.meta_paths, + ) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4 + ) + + # Prepare everything with our `accelerator`. + ( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) = accelerator.prepare( + net, + optimizer, + train_dataloader, + lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil( + len(train_dataloader) / cfg.solver.gradient_accumulation_steps + ) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil( + cfg.solver.max_train_steps / num_update_steps_per_epoch + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + run_time = datetime.now().strftime("%Y%m%d-%H%M") + accelerator.init_trackers( + exp_name, + init_kwargs={"mlflow": {"run_name": run_time}}, + ) + # dump config file + mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml") + + # Train! + total_batch_size = ( + cfg.data.train_bs + * accelerator.num_processes + * cfg.solver.gradient_accumulation_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}") + logger.info( + f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}" + ) + logger.info( + f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}" + ) + logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if cfg.resume_from_checkpoint: + if cfg.resume_from_checkpoint != "latest": + resume_dir = cfg.resume_from_checkpoint + else: + resume_dir = save_dir + # Get the most recent checkpoint + dirs = os.listdir(resume_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.load_state(os.path.join(resume_dir, path)) + accelerator.print(f"Resuming from checkpoint {path}") + global_step = int(path.split("-")[1]) + + first_epoch = global_step // num_update_steps_per_epoch + resume_step = global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm( + range(global_step, cfg.solver.max_train_steps), + disable=not accelerator.is_local_main_process, + ) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, num_train_epochs): + train_loss = 0.0 + t_data_start = time.time() + for step, batch in enumerate(train_dataloader): + t_data = time.time() - t_data_start + with accelerator.accumulate(net): + # Convert videos to latent space + pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype) + with torch.no_grad(): + video_length = pixel_values_vid.shape[1] + pixel_values_vid = rearrange( + pixel_values_vid, "b f c h w -> (b f) c h w" + ) + latents = vae.encode(pixel_values_vid).latent_dist.sample() + latents = rearrange( + latents, "(b f) c h w -> b c f h w", f=video_length + ) + latents = latents * 0.18215 + + noise = torch.randn_like(latents) + if cfg.noise_offset > 0: + noise += cfg.noise_offset * torch.randn( + (latents.shape[0], latents.shape[1], 1, 1, 1), + device=latents.device, + ) + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint( + 0, + train_noise_scheduler.num_train_timesteps, + (bsz,), + device=latents.device, + ) + timesteps = timesteps.long() + + pixel_values_pose = batch["pixel_values_pose"] # (bs, f, c, H, W) + pixel_values_pose = pixel_values_pose.transpose( + 1, 2 + ) # (bs, c, f, H, W) + + uncond_fwd = random.random() < cfg.uncond_ratio + clip_image_list = [] + ref_image_list = [] + for batch_idx, (ref_img, clip_img) in enumerate( + zip( + batch["pixel_values_ref_img"], + batch["clip_ref_img"], + ) + ): + if uncond_fwd: + clip_image_list.append(torch.zeros_like(clip_img)) + else: + clip_image_list.append(clip_img) + ref_image_list.append(ref_img) + + with torch.no_grad(): + ref_img = torch.stack(ref_image_list, dim=0).to( + dtype=vae.dtype, device=vae.device + ) + ref_image_latents = vae.encode( + ref_img + ).latent_dist.sample() # (bs, d, 64, 64) + ref_image_latents = ref_image_latents * 0.18215 + + clip_img = torch.stack(clip_image_list, dim=0).to( + dtype=image_enc.dtype, device=image_enc.device + ) + clip_img = clip_img.to(device="cuda", dtype=weight_dtype) + clip_image_embeds = image_enc( + clip_img.to("cuda", dtype=weight_dtype) + ).image_embeds + clip_image_embeds = clip_image_embeds.unsqueeze(1) # (bs, 1, d) + + # add noise + noisy_latents = train_noise_scheduler.add_noise( + latents, noise, timesteps + ) + + # Get the target for loss depending on the prediction type + if train_noise_scheduler.prediction_type == 
"epsilon": + target = noise + elif train_noise_scheduler.prediction_type == "v_prediction": + target = train_noise_scheduler.get_velocity( + latents, noise, timesteps + ) + else: + raise ValueError( + f"Unknown prediction type {train_noise_scheduler.prediction_type}" + ) + + # ---- Forward!!! ----- + model_pred = net( + noisy_latents, + timesteps, + ref_image_latents, + clip_image_embeds, + pixel_values_pose, + uncond_fwd=uncond_fwd, + ) + + if cfg.snr_gamma == 0: + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="mean" + ) + else: + snr = compute_snr(train_noise_scheduler, timesteps) + if train_noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective requires that we add one to SNR values before we divide by them. + snr = snr + 1 + mse_loss_weights = ( + torch.stack( + [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1 + ).min(dim=1)[0] + / snr + ) + loss = F.mse_loss( + model_pred.float(), target.float(), reduction="none" + ) + loss = ( + loss.mean(dim=list(range(1, len(loss.shape)))) + * mse_loss_weights + ) + loss = loss.mean() + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean() + train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_( + trainable_params, + cfg.solver.max_grad_norm, + ) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if accelerator.sync_gradients: + reference_control_reader.clear() + reference_control_writer.clear() + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % cfg.val.validation_steps == 0: + if accelerator.is_main_process: + generator = torch.Generator(device=accelerator.device) + generator.manual_seed(cfg.seed) + + sample_dicts = log_validation( + vae=vae, + image_enc=image_enc, + net=net, + scheduler=val_noise_scheduler, + accelerator=accelerator, + width=cfg.data.train_width, + height=cfg.data.train_height, + clip_length=cfg.data.n_sample_frames, + generator=generator, + ) + + for sample_id, sample_dict in enumerate(sample_dicts): + sample_name = sample_dict["name"] + vid = sample_dict["vid"] + with TemporaryDirectory() as temp_dir: + out_file = Path( + f"{temp_dir}/{global_step:06d}-{sample_name}.gif" + ) + save_videos_grid(vid, out_file, n_rows=2) + mlflow.log_artifact(out_file) + + logs = { + "step_loss": loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + "td": f"{t_data:.2f}s", + } + t_data_start = time.time() + progress_bar.set_postfix(**logs) + + if global_step >= cfg.solver.max_train_steps: + break + # save model after each epoch + if accelerator.is_main_process: + save_path = os.path.join(save_dir, f"checkpoint-{global_step}") + delete_additional_ckpt(save_dir, 1) + accelerator.save_state(save_path) + # save motion module only + unwrap_net = accelerator.unwrap_model(net) + save_checkpoint( + unwrap_net.denoising_unet, + save_dir, + "motion_module", + global_step, + total_limit=3, + ) + + # Create the pipeline using the trained modules and save it. 
+    accelerator.wait_for_everyone()
+    accelerator.end_training()
+
+
+def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
+    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
+
+    if total_limit is not None:
+        checkpoints = os.listdir(save_dir)
+        checkpoints = [d for d in checkpoints if d.startswith(prefix)]
+        checkpoints = sorted(
+            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
+        )
+
+        if len(checkpoints) >= total_limit:
+            num_to_remove = len(checkpoints) - total_limit + 1
+            removing_checkpoints = checkpoints[0:num_to_remove]
+            logger.info(
+                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+            )
+            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+            for removing_checkpoint in removing_checkpoints:
+                removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
+                os.remove(removing_checkpoint)
+
+    mm_state_dict = OrderedDict()
+    state_dict = model.state_dict()
+    for key in state_dict:
+        if "motion_module" in key:
+            mm_state_dict[key] = state_dict[key]
+
+    torch.save(mm_state_dict, save_path)
+
+
+def decode_latents(vae, latents):
+    video_length = latents.shape[2]
+    latents = 1 / 0.18215 * latents
+    latents = rearrange(latents, "b c f h w -> (b f) c h w")
+    # video = self.vae.decode(latents).sample
+    video = []
+    for frame_idx in tqdm(range(latents.shape[0])):
+        video.append(vae.decode(latents[frame_idx : frame_idx + 1]).sample)
+    video = torch.cat(video)
+    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+    video = (video / 2 + 0.5).clamp(0, 1)
+    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+    video = video.cpu().float().numpy()
+    return video
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="./configs/train/stage2.yaml")
+    args = parser.parse_args()
+
+    if args.config[-5:] == ".yaml":
+        config = OmegaConf.load(args.config)
+    elif args.config[-3:] == ".py":
+        config = import_filename(args.config).cfg
+    else:
+        raise ValueError("Unsupported config file format; expected a .yaml or .py file")
+    main(config)
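
Usage (a minimal sketch, assuming the standard `accelerate launch` entry point for Accelerate-based scripts, and that the stage-1 checkpoints referenced by stage1_ckpt_dir/stage1_ckpt_step and the pretrained weights listed in configs/train/stage2.yaml are already in place):

    accelerate launch train_stage_2.py --config ./configs/train/stage2.yaml

Gradient accumulation and mixed precision are read from the solver section of the YAML; multi-GPU process settings come from the usual `accelerate config` setup.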