In [None]:
import os
import re
import tempfile
from TTS.api import TTS
from argparse import Namespace
import torch
from omegaconf import OmegaConf

import imageio
import glob
import pickle
import cv2
import copy
from tqdm import tqdm
import numpy as np
from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from moviepy.editor import *

from musetalk.utils.utils import get_file_type,get_video_fps,datagen
from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder,get_bbox_range
from musetalk.utils.blending import get_image
from musetalk.utils.utils import load_all_model


In [None]:
# Make the bundled static ffmpeg build discoverable by prepending its
# directory to PATH for this process (and any subprocess it spawns).
# NOTE(review): the directory name ends in "-stati", which looks like a
# truncated "-static" -- confirm it matches the folder actually on disk.
ffmpeg_path = "./ffmpeg-6.1-amd64-stati"
if ffmpeg_path is None:
    print("please download ffmpeg-static and export to FFMPEG_PATH. \nFor example: export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static")
else:
    # PATH may be unset, so fall back to an empty string instead of None.
    _current_path = os.environ.get("PATH", "")
    # Compare against whole PATH entries (not substrings) and use the
    # platform separator so this also works where it is not ':'.
    if ffmpeg_path not in _current_path.split(os.pathsep):
        print("add ffmpeg to path")
        os.environ["PATH"] = (
            f"{ffmpeg_path}{os.pathsep}{_current_path}" if _current_path else ffmpeg_path
        )

In [None]:
class TTSTalker():
    """Wrapper around the `TTS` XTTS-v2 model that synthesises speech to a wav file."""

    def __init__(self) -> None:
        """Load the XTTS-v2 model onto the GPU when available, else the CPU."""
        # Get device
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Init TTS
        self.tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

    def test(self, text, audio, language='en'):
        """Synthesise `text` into a new temporary ``.wav`` file.

        Args:
            text: Text to speak.
            audio: Reference recording, forwarded as ``speaker_wav``.
            language: Language code forwarded to the model (default ``'en'``).

        Returns:
            Path of the generated wav file. The file is created with
            ``delete=False``, so the caller is responsible for removing it.
        """
        # Only the unique file name is needed: close the handle right away so
        # no descriptor is leaked and the TTS backend can (re)open the path
        # for writing, including on platforms that lock open files.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tempf:
            output_path = tempf.name

        self.tts.tts_to_file(text, speaker_wav=audio, language=language, file_path=output_path)

        return output_path

In [None]:
# Checkpoint locations read later by the optional face / background
# enhancement step.
gfgan_model_path = './gfpgan/weights/GFPGANv1.4.pth'
realesrgan_path = './realesrgan/weights/RealESRGAN_x4plus.pth'

# Load the MuseTalk components once so every later call can reuse them.
(audio_processor, vae, unet, pe) = load_all_model()

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Constant timestep tensor (value 0) passed to the UNet on every batch.
timesteps = torch.tensor([0], device=device)

# Text-to-speech helper used to generate the driving audio.
tts_talker = TTSTalker()

In [None]:
def is_valid_image(file):
    """Check whether `file` is a generated frame name such as ``00000012.png``.

    The whole name must be exactly eight digits followed by ``.png``; names
    with leading or trailing extras (``00000012.png.bak``) are rejected.

    Args:
        file: Bare file name (no directory component).

    Returns:
        A truthy ``re.Match`` when the name is valid, otherwise ``None``.
    """
    # fullmatch anchors both ends; plain match() only anchored the start and
    # therefore let through anything that merely began with a valid name.
    return re.fullmatch(r'\d{8}\.png', file)

In [None]:
def _build_face_restorer(enhance_background):
    """Create the GFPGAN restorer, optionally with a RealESRGAN background upsampler."""
    bg_upsampler = None
    if enhance_background:
        if not torch.cuda.is_available():  # CPU
            import warnings
            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
                          'If you really want to use it, please modify the corresponding codes.')
        else:
            # NOTE(review): the network is built with scale=2 while the
            # checkpoint file is named "x4plus" -- confirm the two match.
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            bg_upsampler = RealESRGANer(
                scale=2,
                model_path=realesrgan_path,
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=True)  # need to set False in CPU mode

    return GFPGANer(
        model_path=gfgan_model_path,
        upscale=1,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=bg_upsampler)


@torch.no_grad()
def inference(audio_path, video_path, bbox_shift, enhance_face=False, enhance_background=False):
    """Generate a lip-synced video of the reference footage speaking `audio_path`.

    Steps: collect the reference frames, extract audio features, obtain face
    boxes and VAE latents (cached in ``landmarks.pkl`` next to the frames),
    run the UNet batch by batch, paste every generated face back into its
    source frame, then encode the frames and mux in the audio.

    Args:
        audio_path: Path of the driving audio file.
        video_path: Either a video file or a directory of frames whose file
            names are integers (e.g. ``00000001.png``). For a video file the
            frames are extracted into an ``images`` folder next to it.
        bbox_shift: Forwarded to ``get_landmark_and_bbox`` when landmarks are
            not already cached.
        enhance_face: Run GFPGAN on every composed frame.
        enhance_background: Also attach a RealESRGAN background upsampler to
            GFPGAN. Only has an effect together with ``enhance_face`` and
            when CUDA is available.

    Returns:
        Path of the final video (with audio) written under ``/tmp``.

    Raises:
        FileNotFoundError: If `audio_path` does not exist, or the
            intermediate silent video could not be produced.
    """
    # Fail before any expensive work if the driving audio is missing.
    if not os.path.exists(audio_path):
        raise FileNotFoundError(f"Audio file not found: {audio_path}")

    args_dict={"result_dir":'/tmp', "fps":24, "batch_size":8, "output_vid_name":'', "use_saved_coord":True}#same with inferenece script
    args = Namespace(**args_dict)

    # Name of the directory that contains `video_path` (the persona folder in
    # this notebook's layout), obtained without hard-coding the '/' separator.
    input_basename = os.path.basename(os.path.dirname(os.path.abspath(video_path)))
    audio_basename  = os.path.basename(audio_path).split('.')[0]
    output_basename = f"{input_basename}_{audio_basename}"

    result_img_save_path = os.path.join(args.result_dir, output_basename) # related to video & audio inputs
    os.makedirs(result_img_save_path,exist_ok =True)

    output_temp_vid_name = os.path.join(args.result_dir, output_basename+"_temp.mp4")
    if args.output_vid_name=="":
        output_vid_name = os.path.join(args.result_dir, output_basename+".mp4")
    else:
        output_vid_name = os.path.join(args.result_dir, args.output_vid_name)

    ############################################## extract frames from source video ##############################################
    if get_file_type(video_path)=="video":
        # Frames go to "<video dir>/images". The previous code read the
        # non-existent `args.video_path` here and raised AttributeError.
        save_dir_full = os.path.join(os.path.dirname(os.path.abspath(video_path)), "images")
        os.makedirs(save_dir_full,exist_ok = True)
        # The cache lives with the frames; "<video file>/landmarks.pkl" is
        # not a valid location when the input is a file.
        landmarks_save_path = os.path.join(save_dir_full, "landmarks.pkl") # only related to video input

        reader = imageio.get_reader(video_path)
        try:
            for i, im in enumerate(reader):
                imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
        finally:
            reader.close()
        input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
        fps = get_video_fps(video_path)
    else: # input img folder
        landmarks_save_path = os.path.join(video_path, "landmarks.pkl") # only related to video input
        input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
        input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
        fps = args.fps

    ############################################## extract audio feature ##############################################
    whisper_feature = audio_processor.audio2feat(audio_path)
    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)

    ############################################## preprocess input image  ##############################################

    if os.path.exists(landmarks_save_path) and args.use_saved_coord:
        print("using extracted coordinates")
        # SECURITY: pickle.load can execute arbitrary code; only point this
        # function at frame folders whose landmarks.pkl you created yourself.
        # NOTE: the cache does not record `bbox_shift`, so a cache built with
        # a different shift is reused unchanged.
        with open(landmarks_save_path,'rb') as f:
            cached = pickle.load(f)
        frame_list_cycle = cached["frame_list_cycle"]
        coord_list_cycle = cached["coord_list_cycle"]
        input_latent_list_cycle = cached["input_latent_list_cycle"]
    else:
        print("extracting landmarks...time consuming")
        coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)

        input_latent_list = []
        for bbox, frame in zip(coord_list, frame_list):
            if bbox == coord_placeholder:
                continue
            x1, y1, x2, y2 = bbox
            crop_frame = frame[y1:y2, x1:x2]
            crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
            latents = vae.get_latents_for_unet(crop_frame)
            input_latent_list.append(latents)

        # to smooth the first and the last frame
        frame_list_cycle = frame_list + frame_list[::-1]
        coord_list_cycle = coord_list + coord_list[::-1]
        input_latent_list_cycle = input_latent_list + input_latent_list[::-1]

        with open(landmarks_save_path, 'wb') as f:
            cached = {'frame_list_cycle': frame_list_cycle, 'coord_list_cycle': coord_list_cycle, 'input_latent_list_cycle': input_latent_list_cycle}
            pickle.dump(cached, f)

    ############################################## inference batch by batch ##############################################
    print("inferencing talking images...")
    video_num = len(whisper_chunks)
    batch_size = args.batch_size
    gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
    res_frame_list = []
    for whisper_batch, latent_batch in tqdm(gen,total=int(np.ceil(float(video_num)/batch_size))):
        tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
        audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
        audio_feature_batch = pe(audio_feature_batch)

        pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
        recon = vae.decode_latents(pred_latents)
        res_frame_list.extend(recon)

    ############################################## pad to full image ##############################################

    restorer = None
    if enhance_face:
        print("enhancing talking images...")
        restorer = _build_face_restorer(enhance_background)

    print("padding talking image to original video...")
    # Remember exactly which frames this run wrote, so leftovers from an
    # earlier run in the same result folder can never end up in the video.
    written_frame_paths = []
    for i, res_frame in enumerate(tqdm(res_frame_list)):
        bbox = coord_list_cycle[i%(len(coord_list_cycle))]
        x1, y1, x2, y2 = bbox
        try:
            res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
        except Exception:
            # Resizing fails for unusable boxes (presumably frames without a
            # detected face); such frames are skipped. Catching Exception
            # rather than everything keeps Ctrl-C working.
            continue

        ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
        combine_frame = get_image(ori_frame,res_frame,bbox)

        if restorer is not None:
            # gfgan
            # NOTE(review): this conversion treats combine_frame as RGB, yet
            # the frame is later saved with cv2.imwrite (BGR convention) --
            # confirm the channel order returned by get_image.
            img = cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR)

            # restore faces and background if necessary
            cropped_faces, restored_faces, r_img = restorer.enhance(
                img,
                has_aligned=False,
                only_center_face=False,
                paste_back=True)

            combine_frame = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)

        frame_path = f"{result_img_save_path}/{str(i).zfill(8)}.png"
        if cv2.imwrite(frame_path,combine_frame):
            written_frame_paths.append(frame_path)

    images = [imageio.imread(frame_path) for frame_path in written_frame_paths]

    imageio.mimwrite(output_temp_vid_name, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')

    # Check that the silent intermediate video was actually produced
    if not os.path.exists(output_temp_vid_name):
        raise FileNotFoundError(f"Input video file not found: {output_temp_vid_name}")

    # Mux audio and video; always release the ffmpeg readers afterwards.
    video_clip = VideoFileClip(output_temp_vid_name)
    try:
        audio_clip = AudioFileClip(audio_path)
        try:
            # Set the audio to the video and write the output video
            video_clip.set_audio(audio_clip).write_videofile(output_vid_name)
        finally:
            audio_clip.close()
    finally:
        video_clip.close()

    return output_vid_name

In [None]:
# Text the persona will speak.
input_text = "It will be good to get back to the Sleeping Lion. After a fortnight going up and down the Still River, chasing a bad lead on a missing blacksmith, you can almost feel the warmth of the inn’s hearth when Gloomhaven’s walls come into view. You are almost home."
# Show the working directory, since every path below is relative to it.
# (The call parentheses were missing, which printed the function object.)
print(os.getcwd())
persona_path = './persona/'

# Each persona folder holds reference frames, a voice sample and a config.
persona = 'Melanie'
reference_video = os.path.join(persona_path, persona, 'images')
reference_audio = os.path.join(persona_path, persona, persona + '.mp3')
configs = os.path.join(persona_path, persona, 'config.yaml')
bbox_shift = OmegaConf.load(configs)['bbox_shift']

# Create audio based on text
processed_audio = tts_talker.test(input_text, reference_audio)



In [None]:
# Render the lip-synced video from the synthesised speech and the persona's
# reference frames.
inference(
    audio_path=processed_audio,
    video_path=reference_video,
    bbox_shift=bbox_shift,
)