Loading local model to change reference model #13

Open · wants to merge 3 commits into base: main
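Across all three scripts the change is the same: resolve the target device once with get_torch_device() instead of hard-coding "cuda", load the reference UNet and VAE through a single load_models() helper so a locally stored base model can be swapped in, and call torch_gc() after each generation to release GPU memory. The new module src/models/model_util is not part of this diff, so the following is only a minimal sketch of what the device helper might look like, assuming the common CUDA, then MPS, then CPU fallback that the call sites imply.

```python
# Minimal sketch, NOT the PR's actual code: src/models/model_util is not
# shown in this diff. Assumes the usual CUDA -> MPS -> CPU fallback.
import torch

def get_torch_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():  # Apple silicon
        return torch.device("mps")
    return torch.device("cpu")
```

With the device resolved once, every subsequent .to(device) call works the same on CUDA, MPS, or CPU machines, whereas the old hard-coded .cuda() calls raised on machines without an NVIDIA GPU.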
scripts/audio2vid.py (38 changes: 21 additions, 17 deletions)

@@ -21,7 +21,7 @@
 
 from configs.prompts.test_cases import TestCasesDict
 from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
+from src.models.model_util import load_models, torch_gc, get_torch_device
 from src.models.unet_3d import UNet3DConditionModel
 from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import get_fps, read_frames, save_videos_grid
@@ -56,21 +56,24 @@ def main():
         weight_dtype = torch.float16
     else:
         weight_dtype = torch.float32
 
+    device = get_torch_device()
+
     audio_infer_config = OmegaConf.load(config.audio_inference_config)
     # prepare model
     a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
     a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
-    a2m_model.cuda().eval()
-
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
-
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    a2m_model.to(device).eval()
+
+    (_,_,unet,_,vae,) = load_models(
+        config.pretrained_base_model_path,
+        scheduler_name="",
+        v2=False,
+        v_pred=False,
+        weight_dtype=weight_dtype,
+    )
+    vae = vae.to(device, dtype=weight_dtype)
+    reference_unet = unet.to(dtype=weight_dtype, device=device)
 
     inference_config_path = config.inference_config
     infer_config = OmegaConf.load(inference_config_path)
@@ -79,14 +82,14 @@ def main():
         config.motion_module_path,
         subfolder="unet",
         unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device=device, dtype=weight_dtype) # not use cross attention
 
     image_enc = CLIPVisionModelWithProjection.from_pretrained(
         config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
     sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
     scheduler = DDIMScheduler(**sched_kwargs)
@@ -115,7 +118,7 @@ def main():
         pose_guider=pose_guider,
         scheduler=scheduler,
     )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    pipe = pipe.to(device, dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -145,7 +148,7 @@ def main():
         ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
 
         sample = prepare_audio_feature(audio_path, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
-        sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
+        sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().to(device)
         sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
 
         # inference
@@ -218,6 +221,7 @@ def main():
         stream = ffmpeg.input(save_path)
         audio = ffmpeg.input(audio_path)
         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
+        torch_gc()
         os.remove(save_path)
 
 if __name__ == "__main__":
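The diff shows load_models() only at its call sites: it takes the base model path plus scheduler_name, v2, v_pred, and weight_dtype arguments and returns a 5-tuple from which the scripts use element 2 as the UNet and element 4 as the VAE. The v2/v_pred flags suggest a loader that can also parse single-file Stable Diffusion checkpoints, which is presumably how a "local model" replaces the reference UNet, but the sketch below covers only a Diffusers-layout directory, and everything beyond the visible signature is an assumption.

```python
# Hypothetical reconstruction from the call sites alone; the real helper
# in src/models/model_util is not shown in this diff. Tuple positions 0,
# 1, and 3 are guesses; only indices 2 (unet) and 4 (vae) are used above.
import torch
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

def load_models(pretrained_model_path, scheduler_name="", v2=False,
                v_pred=False, weight_dtype=torch.float32):
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet", torch_dtype=weight_dtype)
    scheduler = (DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
                 if scheduler_name else None)
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae", torch_dtype=weight_dtype)
    # v2 / v_pred would matter when converting a single-file SD checkpoint;
    # they are accepted but unused in this directory-only sketch.
    return tokenizer, text_encoder, unet, scheduler, vae
```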
scripts/pose2vid.py (29 changes: 16 additions, 13 deletions)

@@ -20,7 +20,7 @@
 
 from configs.prompts.test_cases import TestCasesDict
 from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
+from src.models.model_util import load_models, torch_gc, get_torch_device
 from src.models.unet_3d import UNet3DConditionModel
 from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import get_fps, read_frames, save_videos_grid
@@ -53,14 +53,17 @@ def main():
     else:
         weight_dtype = torch.float32
 
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
+    device = get_torch_device()
 
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    (_,_,unet,_,vae,) = load_models(
+        config.pretrained_base_model_path,
+        scheduler_name="",
+        v2=False,
+        v_pred=False,
+        weight_dtype=weight_dtype,
+    )
+    vae = vae.to(device, dtype=weight_dtype)
+    reference_unet = unet.to(dtype=weight_dtype, device=device)
 
     inference_config_path = config.inference_config
     infer_config = OmegaConf.load(inference_config_path)
@@ -69,13 +72,13 @@ def main():
         config.motion_module_path,
         subfolder="unet",
         unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device=device, dtype=weight_dtype) # not use cross attention
 
     image_enc = CLIPVisionModelWithProjection.from_pretrained(
         config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
     sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
     scheduler = DDIMScheduler(**sched_kwargs)
@@ -104,7 +107,7 @@ def main():
         pose_guider=pose_guider,
         scheduler=scheduler,
     )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    pipe = pipe.to(device, dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -191,7 +194,7 @@ def main():
         stream = ffmpeg.input(save_path)
         audio = ffmpeg.input(audio_output)
         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
-
+        torch_gc()
         os.remove(save_path)
         os.remove(audio_output)
 
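The same three hunks repeat in scripts/vid2vid.py below. One practical point the device change raises: the scripts still pick weight_dtype independently of the device, and float16 kernels are missing or slow for many ops on CPU (and some on MPS), so a device-aware dtype choice is a common companion to a helper like get_torch_device(). The snippet below is illustrative only; the PR leaves the original weight_dtype logic unchanged.

```python
# Illustrative pattern, not part of the PR: pick half precision only
# when the resolved device is CUDA, since CPU fp16 support is spotty.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16 if device.type == "cuda" else torch.float32
model = torch.nn.Linear(4, 4).to(device, dtype=weight_dtype)
```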
scripts/vid2vid.py (29 changes: 16 additions, 13 deletions)

@@ -20,7 +20,7 @@
 
 from configs.prompts.test_cases import TestCasesDict
 from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
+from src.models.model_util import load_models, torch_gc, get_torch_device
 from src.models.unet_3d import UNet3DConditionModel
 from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import get_fps, read_frames, save_videos_grid
@@ -54,14 +54,17 @@ def main():
     else:
         weight_dtype = torch.float32
 
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
+    device = get_torch_device()
 
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    (_,_,unet,_,vae,) = load_models(
+        config.pretrained_base_model_path,
+        scheduler_name="",
+        v2=False,
+        v_pred=False,
+        weight_dtype=weight_dtype,
+    )
+    vae = vae.to(device, dtype=weight_dtype)
+    reference_unet = unet.to(dtype=weight_dtype, device=device)
 
     inference_config_path = config.inference_config
     infer_config = OmegaConf.load(inference_config_path)
@@ -70,13 +73,13 @@ def main():
         config.motion_module_path,
         subfolder="unet",
         unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device=device, dtype=weight_dtype) # not use cross attention
 
     image_enc = CLIPVisionModelWithProjection.from_pretrained(
         config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
     sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
     scheduler = DDIMScheduler(**sched_kwargs)
@@ -105,7 +108,7 @@ def main():
         pose_guider=pose_guider,
         scheduler=scheduler,
     )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    pipe = pipe.to(device, dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -224,7 +227,7 @@ def main():
         stream = ffmpeg.input(save_path)
         audio = ffmpeg.input(audio_output)
         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
-
+        torch_gc()
        os.remove(save_path)
        os.remove(audio_output)
 
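In all three scripts torch_gc() runs right after ffmpeg muxes the audio back in and before the temporary file is deleted, so cached GPU memory is released between test cases in a batch run instead of accumulating. Its implementation is again an assumption, since model_util is not in the diff; the common pattern pairs Python garbage collection with emptying the CUDA allocator cache.

```python
# Assumed implementation of torch_gc(); the real one lives in
# src/models/model_util, which this diff does not include.
import gc
import torch

def torch_gc() -> None:
    gc.collect()                  # drop Python-side references first
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached blocks to the driver
        torch.cuda.ipc_collect()
```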