add stage2 training codes
lixunsong authored and songtao-liu-mt committed Jan 17, 2024
1 parent f6066e8 commit d31bf2a
Showing 5 changed files with 978 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -5,4 +5,7 @@ output/
 mlruns/
 data/
 
-*.pth
+*.pth
+*.pt
+*.pkl
+*.bin
59 changes: 59 additions & 0 deletions configs/train/stage2.yaml
@@ -0,0 +1,59 @@
data:
  train_bs: 1
  train_width: 512
  train_height: 512
  meta_paths:
    - "./data/fashion_meta.json"
  sample_rate: 4
  n_sample_frames: 24

solver:
  gradient_accumulation_steps: 1
  mixed_precision: 'fp16'
  enable_xformers_memory_efficient_attention: True
  gradient_checkpointing: True
  max_train_steps: 10000
  max_grad_norm: 1.0
  # lr
  learning_rate: 1e-5
  scale_lr: False
  lr_warmup_steps: 1
  lr_scheduler: 'constant'

  # optimizer
  use_8bit_adam: True
  adam_beta1: 0.9
  adam_beta2: 0.999
  adam_weight_decay: 1.0e-2
  adam_epsilon: 1.0e-8

val:
  validation_steps: 20


noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false

base_model_path: './pretrained_weights/stable-diffusion-v1-5'
vae_model_path: './pretrained_weights/sd-vae-ft-mse'
image_encoder_path: './pretrained_weights/sd-image-variations-diffusers/image_encoder'
mm_path: './pretrained_weights/mm_sd_v15_v2.ckpt'

weight_dtype: 'fp16' # [fp16, fp32]
uncond_ratio: 0.1
noise_offset: 0.05
snr_gamma: 5.0
enable_zero_snr: True
stage1_ckpt_dir: './exp_output/stage1'
stage1_ckpt_step: 980

seed: 12580
resume_from_checkpoint: ''
checkpointing_steps: 2000
exp_name: 'stage2'
output_dir: './exp_output'
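
The stage-2 training entry point that consumes this config is part of the commit but not rendered on this page. As a rough sketch of how a config like this is typically loaded (assuming OmegaConf for parsing and a diffusers DDIMScheduler built from noise_scheduler_kwargs; the zero-SNR handling shown is an assumption, not read from this diff):

# Hypothetical loader sketch; OmegaConf/diffusers usage is assumed, not taken from this commit.
from diffusers import DDIMScheduler
from omegaconf import OmegaConf

cfg = OmegaConf.load("./configs/train/stage2.yaml")

# Resolution and clip length consumed by the video dataset added below.
print(cfg.data.train_width, cfg.data.train_height, cfg.data.n_sample_frames)

# noise_scheduler_kwargs maps directly onto diffusers' DDIMScheduler arguments.
sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
if cfg.enable_zero_snr:
    # Assumed zero-terminal-SNR recipe (v-prediction + trailing spacing); verify against the training script.
    sched_kwargs.update(
        rescale_betas_zero_snr=True,
        timestep_spacing="trailing",
        prediction_type="v_prediction",
    )
train_noise_scheduler = DDIMScheduler(**sched_kwargs)
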
137 changes: 137 additions & 0 deletions src/dataset/dance_video.py
@@ -0,0 +1,137 @@
import json
import random
from typing import List

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor


class HumanDanceVideoDataset(Dataset):
    def __init__(
        self,
        sample_rate,
        n_sample_frames,
        width,
        height,
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        data_meta_paths=["./data/fashion_meta.json"],
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_sample_frames = n_sample_frames
        self.width = width
        self.height = height
        self.img_scale = img_scale
        self.img_ratio = img_ratio

        vid_meta = []
        for data_meta_path in data_meta_paths:
            vid_meta.extend(json.load(open(data_meta_path, "r")))
        self.vid_meta = vid_meta

        self.clip_image_processor = CLIPImageProcessor()

        self.pixel_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    (height, width),
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.cond_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    (height, width),
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio

    def augmentation(self, images, transform, state=None):
        if state is not None:
            torch.set_rng_state(state)
        if isinstance(images, List):
            transformed_images = [transform(img) for img in images]
            ret_tensor = torch.stack(transformed_images, dim=0)  # (f, c, h, w)
        else:
            ret_tensor = transform(images)  # (c, h, w)
        return ret_tensor

    def __getitem__(self, index):
        video_meta = self.vid_meta[index]
        video_path = video_meta["video_path"]
        kps_path = video_meta["kps_path"]

        video_reader = VideoReader(video_path)
        kps_reader = VideoReader(kps_path)

        assert len(video_reader) == len(
            kps_reader
        ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"

        video_length = len(video_reader)

        clip_length = min(
            video_length, (self.n_sample_frames - 1) * self.sample_rate + 1
        )
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(
            start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
        ).tolist()

        # read frames and kps
        vid_pil_image_list = []
        pose_pil_image_list = []
        for index in batch_index:
            img = video_reader[index]
            vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
            img = kps_reader[index]
            pose_pil_image_list.append(Image.fromarray(img.asnumpy()))

        ref_img_idx = random.randint(0, video_length - 1)
        ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())

        # transform
        state = torch.get_rng_state()
        pixel_values_vid = self.augmentation(
            vid_pil_image_list, self.pixel_transform, state
        )
        pixel_values_pose = self.augmentation(
            pose_pil_image_list, self.cond_transform, state
        )
        pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
        clip_ref_img = self.clip_image_processor(
            images=ref_img, return_tensors="pt"
        ).pixel_values[0]

        sample = dict(
            video_dir=video_path,
            pixel_values_vid=pixel_values_vid,
            pixel_values_pose=pixel_values_pose,
            pixel_values_ref_img=pixel_values_ref_img,
            clip_ref_img=clip_ref_img,
        )

        return sample

    def __len__(self):
        return len(self.vid_meta)
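
For orientation, a hypothetical usage sketch (not part of this commit): the constructor arguments mirror the data block of stage2.yaml, the meta JSON is expected to list video_path/kps_path pairs, and the import path is assumed from the repo layout shown in the file header.

# Hypothetical usage sketch; values copied from configs/train/stage2.yaml.
from torch.utils.data import DataLoader

from src.dataset.dance_video import HumanDanceVideoDataset  # assumed import path

dataset = HumanDanceVideoDataset(
    sample_rate=4,        # data.sample_rate
    n_sample_frames=24,   # data.n_sample_frames
    width=512,            # data.train_width
    height=512,           # data.train_height
    data_meta_paths=["./data/fashion_meta.json"],
)
loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

batch = next(iter(loader))
# pixel_values_vid / pixel_values_pose: (1, 24, 3, 512, 512) video and pose clips
# pixel_values_ref_img:                 (1, 3, 512, 512) randomly chosen reference frame
# clip_ref_img:                         (1, 3, 224, 224) from CLIPImageProcessor
print({k: tuple(v.shape) for k, v in batch.items() if hasattr(v, "shape")})
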
6 changes: 5 additions & 1 deletion src/pipelines/pipeline_pose2vid.py
@@ -22,7 +22,6 @@
 @dataclass
 class Pose2VideoPipelineOutput(BaseOutput):
     videos: Union[torch.Tensor, np.ndarray]
-    middle_results: Union[torch.Tensor, np.ndarray]


class Pose2VideoPipeline(DiffusionPipeline):
@@ -429,6 +428,11 @@ def __call__(
                         noise_pred_text - noise_pred_uncond
                     )
 
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+                )[0]
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or (
                     (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
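
The five added lines restore the standard diffusers denoising update: after classifier-free guidance, scheduler.step() maps the current latents from x_t to x_{t-1}. A stripped-down sketch of that loop pattern (generic placeholders, not this pipeline's full __call__ signature):

# Generic sketch of the denoising loop the added lines complete;
# `unet`, `scheduler`, `latents`, and `cond` are placeholders, not this pipeline's API.
import torch

@torch.no_grad()
def denoise(unet, scheduler, latents, cond, guidance_scale=3.5, num_steps=25):
    scheduler.set_timesteps(num_steps)
    for t in scheduler.timesteps:
        # Run the UNet on the unconditional and conditional branches at once.
        latent_in = torch.cat([latents] * 2)
        noise_pred = unet(latent_in, t, encoder_hidden_states=cond).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

        # Classifier-free guidance, matching the context lines above the insertion.
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        # The newly added step: x_t -> x_{t-1}.
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latents
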
