In [1]:
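# One-off conversion kept for reference: load the training checkpoint's
# "state_dict" into the UNet and re-save it as ./cunet.bin, the file the
# next cell passes to UNet(model_path=...).
# import torch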
# from musetalk.models.unet import UNet,PositionalEncoding

# device="cuda"

# model_path = "checkpoints/unet_checkpoint_step000380000.pth"

# model_bin_path = "./cunet.bin"
# unet = UNet(unet_config="./models/musetalk/musetalk.json", model_path=None, device=device, use_float16=True)
# # gpu
# weights = torch.load(model_path)["state_dict"]
# unet.model.load_state_dict(weights)

# # cpu version
# # weights = torch.load(model_path, map_location="cpu")["state_dict"]
# # unet.model.load_state_dict(weights)

# torch.save(unet.model.state_dict(), model_bin_path)

In [8]:
import torch
from musetalk.whisper.audio2feature import Audio2Feature
from musetalk.models.vae import VAE
from musetalk.models.unet import UNet,PositionalEncoding

# Runtime configuration shared by every model below.
device = "cuda"
use_float16 = True

# Audio feature extractor (Whisper "tiny" checkpoint) and image VAE.
audio_processor = Audio2Feature(model_path="./models/whisper/tiny.pt", device=device)
# Pass the flag through instead of hard-coding True so the precision is
# controlled from one place.
vae = VAE(model_path="./models/sd-vae-ft-mse/", device=device, use_float16=use_float16)

# UNet weights exported to ./cunet.bin (see the commented-out cell above).
model_bin_path = "./cunet.bin"
unet = UNet(
    unet_config="./models/musetalk/musetalk.json",
    model_path=model_bin_path,
    device=device,
    use_float16=use_float16,
)

# Positional encoding applied to the audio features before the UNet; a single
# fixed timestep (0) is fed to the UNet for every batch.
pe = PositionalEncoding(d_model=384)
timesteps = torch.tensor([0], device=device)

In [9]:
from musetalk.utils.preprocessing import get_landmark_and_bbox_from_frames, coord_placeholder
import cv2

# video_file = "data/video/sun.mp4"
video_file = "/data/apps/MuseTalk/data/256/driver.mp4"

# Decode the whole driver video into memory, one array per frame.
video_stream = cv2.VideoCapture(video_file)
if not video_stream.isOpened():
    # Without this check a bad path silently yields zero frames and the
    # failure only shows up later, far from its cause.
    video_stream.release()
    raise IOError(f"cannot open video: {video_file}")

try:
    fps = video_stream.get(cv2.CAP_PROP_FPS)
    frames = []
    while True:
        still_reading, frame = video_stream.read()
        if not still_reading:
            break
        frames.append(frame.copy())
finally:
    # Release the capture even if decoding raises part-way through.
    video_stream.release()

if not frames:
    raise RuntimeError(f"no frames could be read from: {video_file}")

# Detect one face bounding box per frame; bbox_shift is passed through to the
# detector helper unchanged (0 here).
bbox_shift = 0
bbox_list, frame_list = get_landmark_and_bbox_from_frames(frames, bbox_shift)

print(f"frame_list: {len(frame_list)}, bbox_list: {len(bbox_list)}")

100%|██████████| 300/300 [01:08<00:00,  4.38it/s]

frame_list: 300, bbox_list: 300





In [10]:
import cv2
from tqdm import tqdm

def bbox_process(cords):
    """Clamp the top-left corner of every box to be non-negative.

    Args:
        cords: iterable of 4-element boxes ``(x1, y1, x2, y2)``.

    Returns:
        A new list of ``(x1, y1, x2, y2)`` tuples in which ``x1`` and ``y1``
        are replaced by ``max(value, 0)``; ``x2`` and ``y2`` are passed
        through untouched.
    """
    return [
        (max(left, 0), max(top, 0), right, bottom)
        for left, top, right, bottom in cords
    ]

# Clamp negative box origins: a negative start index in the crop slice below
# would be interpreted relative to the end of the axis.
bbox_list = bbox_process(bbox_list)

# Crop each detected face, resize it to 256x256 and encode it with the VAE.
# NOTE(review): frames whose box equals coord_placeholder are skipped, so
# input_latent_list / input_face_list can end up shorter than
# bbox_list / frame_list; any code indexing them together must account for it.
input_latent_list = []
input_face_list = []
num_pairs = min(len(bbox_list), len(frame_list))
# Passing total lets tqdm render a real progress bar for the zip iterator.
for bbox, frame in tqdm(zip(bbox_list, frame_list), total=num_pairs):
    if bbox == coord_placeholder:
        continue
    x1, y1, x2, y2 = bbox
    crop_frame = frame[y1:y2, x1:x2]
    crop_frame = cv2.resize(crop_frame, (256, 256), interpolation=cv2.INTER_LANCZOS4)
    input_face_list.append(crop_frame.copy())
    latents = vae.get_latents_for_unet(crop_frame)
    input_latent_list.append(latents)
    

300it [00:06, 45.16it/s]


In [11]:
# Driving audio: extract its features with the audio processor, then cut them
# into chunks using the video frame rate (presumably one chunk per output
# frame -- confirm against Audio2Feature.feature2chunks).
audio_path = "data/audio/sun.wav"
whisper_feature = audio_processor.audio2feat(audio_path)
whisper_chunks = audio_processor.feature2chunks(
    feature_array=whisper_feature,
    fps=fps,
)

video in 30.0 FPS, audio idx in 50FPS


In [12]:
import copy
import numpy as np
from PIL import Image
from musetalk.utils.blending import get_image, get_image_prepare_material,get_image_blending

batch_size = 1

# Latents were only produced for frames with a detected face, so build the
# matching (bbox, frame) list with the same filter. Indexing bbox_list /
# frame_list directly would drift out of step with input_latent_list as soon
# as one frame had been skipped.
face_targets = [
    (bbox, frame)
    for bbox, frame in zip(bbox_list, frame_list)
    if bbox != coord_placeholder
]
if not input_latent_list:
    raise RuntimeError("no face latents available; cannot run inference")
if len(face_targets) != len(input_latent_list):
    raise RuntimeError(
        f"{len(face_targets)} usable frames but {len(input_latent_list)} latents; "
        "re-run the face-cropping cell"
    )

num_faces = len(input_latent_list)
model_dtype = unet.model.dtype
whisper_length = len(whisper_chunks)
results = []

# Inference only: disable autograd so no graph is kept alive per batch.
with torch.no_grad():
    for i in tqdm(range(0, whisper_length, batch_size)):
        # Output-frame indices covered by this batch (the last batch may be short).
        batch_indices = range(i, min(whisper_length, i + batch_size))

        audio_feature_batch = np.stack(whisper_chunks[i:i + batch_size])
        audio_feature_batch = torch.from_numpy(audio_feature_batch)
        audio_feature_batch = audio_feature_batch.to(device=unet.device, dtype=model_dtype)  # torch, B, 5*N,384
        # Re-cast after the positional encoding in case it promoted the dtype.
        audio_feature_batch = pe(audio_feature_batch).to(dtype=model_dtype)

        # The driver frames are cycled (modulo) when the audio is longer than the video.
        latent_batch = torch.cat(
            [input_latent_list[idx % num_faces] for idx in batch_indices], dim=0
        )
        latent_batch = latent_batch.to(device=unet.device, dtype=model_dtype)

        pred_latents = unet.model(
            latent_batch,
            timesteps.to(dtype=model_dtype),
            encoder_hidden_states=audio_feature_batch,
        ).sample

        recon = vae.decode_latents(pred_latents)

        for offset, res_frame in enumerate(recon):
            idx = i + offset
            # Use the per-frame index (not the batch start) so every frame in a
            # batch is pasted onto its own original frame and box.
            bbox, src_frame = face_targets[idx % num_faces]
            # Copy so the blending helper never touches the cached source frame.
            ori_frame = copy.deepcopy(src_frame)
            x1, y1, x2, y2 = bbox
            res_frame = cv2.resize(res_frame.astype(np.uint8), (x2 - x1, y2 - y1))
            combine_frame = get_image(ori_frame, res_frame, bbox)
            results.append(combine_frame)


100%|██████████| 676/676 [02:12<00:00,  5.10it/s]


In [13]:
vfile = "gen_test.mp4"

# Fail with a clear message instead of an IndexError on results[0].
if not results:
    raise RuntimeError("no frames were generated; nothing to write")

# Frame size comes from the first result; VideoWriter takes (width, height).
frame_h, frame_w = results[0].shape[:2]
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(vfile, fourcc, fps, (frame_w, frame_h))
if not out.isOpened():
    # A writer that failed to open would otherwise drop every frame silently.
    out.release()
    raise IOError(f"cannot open video writer for: {vfile}")

try:
    for f in results:
        out.write(f.astype(np.uint8))
finally:
    # Always finalize the container, even if a write raises.
    out.release()