In [1]:
import os
# WORKING_DIR = './' # change the working directory to the project's absolute path
# os.chdir(WORKING_DIR)

import librosa
import numpy as np
import torch
import torch.optim
import torch.utils.data
from transformers import Wav2Vec2FeatureExtractor
import pickle

from dataset.dataset_config import dataset_config
from loss.loss import UniTalkerLoss
from models.unitalker import UniTalker
from utils.utils import get_parser, get_template_verts, get_audio_encoder_dim
from utils import config

def get_all_audios(audio_root:str):
    """Collect every ``.wav`` / ``.mp3`` file below ``audio_root``.

    The directory tree is walked recursively. Each hit is returned as a path
    relative to ``audio_root``, and the list is sorted lexicographically.
    Suffix matching is case-sensitive. A missing root yields an empty list.
    """
    audio_suffixes = ('.wav', '.mp3')
    found = [
        os.path.relpath(os.path.join(dir_path, file_name), audio_root)
        for dir_path, _subdirs, file_names in os.walk(audio_root)
        for file_name in file_names
        if file_name.endswith(audio_suffixes)
    ]
    found.sort()
    return found

def split_long_audio(
    audio: np.ndarray,
    processor:Wav2Vec2FeatureExtractor,
    window_sec: float = 25,
    overlap_sec: float = 5,
    sr: int = 16000,
):
    """Cut a waveform into overlapping windows and preprocess each one.

    Windows are ``window_sec`` seconds long, and consecutive windows overlap
    by ``overlap_sec`` seconds. Window ``k`` therefore starts at
    ``k * (window_sec - overlap_sec)`` seconds. The final window may be
    shorter than ``window_sec``.

    Each window is passed through ``processor(window, sampling_rate=sr)``,
    and the squeezed ``input_values`` are kept.

    The defaults (25 s window, 5 s overlap) are the values ``merge_out_list``
    uses to stitch the per-window outputs back together; keep the two in sync.

    Args:
        audio: waveform samples at ``sr`` Hz, sliced along the first axis.
        processor: feature extractor applied to every window.
        window_sec: window length in seconds.
        overlap_sec: overlap between consecutive windows in seconds. Must be
            smaller than ``window_sec``.
        sr: sample rate of ``audio`` in Hz.

    Returns:
        List of preprocessed windows in temporal order (always at least one).

    Raises:
        ValueError: if ``overlap_sec`` is not smaller than ``window_sec``.
    """
    if overlap_sec >= window_sec:
        # A non-positive stride would divide by zero or never advance.
        raise ValueError('overlap_sec must be smaller than window_sec')
    stride_sec = window_sec - overlap_sec
    total_sec = len(audio) / sr
    # The first window covers window_sec seconds;
    # every further stride_sec of audio needs one more window.
    num_windows = max(0, int(np.ceil((total_sec - window_sec) / stride_sec))) + 1
    window_len = int(window_sec * sr)
    step = int(stride_sec * sr)

    splits = []
    for k in range(num_windows):
        start = k * step
        window = audio[start:start + window_len]
        splits.append(np.squeeze(processor(window, sampling_rate=sr).input_values))
    return splits

def merge_out_list(out_list: list, fps:int, window_sec: float = 25, overlap_sec: float = 5):
    """Stitch per-window model outputs into one continuous frame sequence.

    This is the counterpart of ``split_long_audio``. Consecutive pieces share
    ``overlap_sec * fps`` frames. In that overlap, the earlier piece is faded
    out linearly while the next piece is faded in. Every piece except the last
    is expected to span exactly ``window_sec * fps`` frames; otherwise the
    slice assignments below fail with a shape error.

    Args:
        out_list: per-window arrays of shape (frames, channels), in temporal
            order.
        fps: frame rate of the pieces.
        window_sec: window length in seconds (must match ``split_long_audio``).
        overlap_sec: overlap length in seconds (must match
            ``split_long_audio``).

    Returns:
        A single piece unchanged (same object). Otherwise a new array of shape
        ``(len(out_list[-1]) + stride * (len(out_list) - 1), channels)``, where
        ``stride = (window_sec - overlap_sec) * fps``, using the dtype of the
        last piece.

    Raises:
        ValueError: if ``out_list`` is empty.
    """
    if not out_list:
        raise ValueError('out_list must contain at least one segment')
    if len(out_list) == 1:
        return out_list[0]

    window = int(round(window_sec * fps))
    overlap = int(round(overlap_sec * fps))
    stride = window - overlap
    # Linear cross-fade weights over the overlapping frames.
    fade_out = np.linspace(1, 0, overlap)[:, np.newaxis]
    fade_in = 1 - fade_out

    last = out_list[-1]
    merged_out = np.empty((len(last) + stride * (len(out_list) - 1), last.shape[-1]),
                          dtype=last.dtype)
    merged_out[:window] = out_list[0]
    end = window  # exclusive end of the frames written so far
    for out_piece in out_list[1:]:
        merged_out[end - overlap:end] = (
            fade_out * merged_out[end - overlap:end] + fade_in * out_piece[:overlap])
        # For the last piece this slice is clipped by the array end,
        # which matches its remaining length.
        merged_out[end:end + stride] = out_piece[overlap:]
        end += stride
    return merged_out

# Dataset key -> local index used to pick the speaker condition for that dataset
# (read below as `local_condition_idx`).
condition_id_config = dict(
    D0=3,
    D1=3,
    D2=0,
    D3=0,
    D4=0,
    D5=0,
    D6=4,
    D7=0,
)
# Dataset key -> index passed to `get_template_verts` to load the template mesh.
# The values currently coincide with condition_id_config. The two dicts must stay
# separate objects, though, because the selected entry of each is later replaced
# by a tensor independently.
template_id_config = dict(condition_id_config)



# Load the inference config. A single shared ("common") speaker condition is forced,
# so only the 'common' branch of the condition resolution below runs in this notebook.
cfg = config.load_cfg_from_cfg_file('./config/unitalker.yaml')
cfg.condition_id = 'common'
cfg.dataset = cfg.dataset.split(',')
cfg.demo_dataset = cfg.demo_dataset.split(',')

print('Weight path: ', cfg.weight_path)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
checkpoint = torch.load(cfg.weight_path, map_location='cpu')
# The identity count is the number of rows in the checkpoint's style-embedding table.
cfg.identity_num = len(checkpoint['decoder.learnable_style_emb.weight'])

dataset_name = 'D1' # FLAME

# Resolve the template mesh and the condition id for the selected dataset.
# The selected entries of template_id_config / condition_id_config are replaced
# in place by tensors; the inference cell reads those tensors.
start_idx = 0
annot_type = dataset_config[dataset_name]['annot_type']
id_num = dataset_config[dataset_name]['subjects']
end_idx = start_idx + id_num
local_condition_idx = condition_id_config[dataset_name]
template_idx = template_id_config[dataset_name]
template = get_template_verts(cfg.data_root, dataset_name, template_idx)
template_id_config[dataset_name] = torch.Tensor(template.reshape(1, -1))
if cfg.condition_id == 'each':
    condition_id_config[dataset_name] = torch.tensor(
        start_idx + local_condition_idx).reshape(1)
elif cfg.condition_id == 'common':
    # The last style embedding serves as the shared condition.
    condition_id_config[dataset_name] = torch.tensor(cfg.identity_num - 1).reshape(1)
else:
    try:
        condition_id = int(cfg.condition_id)
        condition_id_config[dataset_name] = torch.tensor(
            condition_id).reshape(1)
    except ValueError:
        if cfg.condition_id not in dataset_config.keys():
            raise ValueError(
                'condition_id must be "each", "common", an integer or a dataset '
                f'name, got {cfg.condition_id!r}') from None
        condition_id_config[dataset_name] = torch.tensor(
            start_idx + local_condition_idx).reshape(1)
# Leftover from iterating over several datasets; kept so `start_idx` ends with the same value.
start_idx = end_idx


cfg.audio_encoder_feature_dim = get_audio_encoder_dim(cfg.audio_encoder_repo)

model = UniTalker(cfg)
# strict=False would otherwise silently ignore key mismatches. Report them so that a
# partially loaded model is noticed (e.g. audio-encoder weights whose parameter names
# differ between transformers versions, as the from_pretrained warnings suggest).
load_result = model.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys:
    print('Missing keys when loading checkpoint:', load_result.missing_keys)
if load_result.unexpected_keys:
    print('Unexpected keys in checkpoint:', load_result.unexpected_keys)
model.eval()
model.cuda()

processor = Wav2Vec2FeatureExtractor.from_pretrained(cfg.audio_encoder_repo)
# NOTE(review): loss_module is not used by the visible notebook cells.
loss_module = UniTalkerLoss(cfg).cuda()


  from .autonotebook import tqdm as notebook_tqdm


Weight path:  ./pretrained_models/UniTalker-B-D0-D7.pt


Some weights of the model checkpoint at microsoft/wavlm-base-plus were not used when initializing WavLMModel: ['encoder.pos_conv_embed.conv.weight_g', 'encoder.pos_conv_embed.conv.weight_v']
- This IS expected if you are initializing WavLMModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing WavLMModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of WavLMModel were not initialized from the model checkpoint at microsoft/wavlm-base-plus and are newly initialized: ['encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictio

In [2]:
# settings
wav_path = './test_audios/can_you_feel_the_love_tonight_clip.wav'
# NOTE: despite the .npy suffix, this file is written with pickle (see below) and must be
# read back with pickle.load (np.load only handles it with allow_pickle=True).
save_path = './can_you_feel_the_love_tonight_clip.npy'
fps = 30


# run unitalker model
with torch.no_grad():
    # Output head to decode; the exp/jaw column slices below assume this head's layout.
    annot_type = 'flame_params_from_dadhead'
    if annot_type not in model.out_head_dict:
        raise KeyError(f'model has no output head named {annot_type!r}')

    audio_data, sr = librosa.load(wav_path, sr=16000)
    # split_long_audio already runs `processor` on every window, so the raw waveform is
    # passed in directly instead of being preprocessed (and normalized) twice.
    audio_data_splits = split_long_audio(audio_data, processor)

    # NOTE(review): `template` is not passed to any call in this cell.
    template = template_id_config[dataset_name].cuda()
    scale = dataset_config[dataset_name]['scale']
    template = scale * template
    condition_id = condition_id_config[dataset_name].cuda()

    out_list = []
    for audio_split in audio_data_splits:
        audio_tensor = torch.Tensor(audio_split[None]).cuda()

        # Number of motion frames covering this window at the target fps.
        frame_num = round(audio_tensor.shape[-1] / sr * fps)
        hidden_states = model.audio_encoder(
            audio_tensor, frame_num=frame_num, interpolate_pos=model.interpolate_pos)
        hidden_states = hidden_states.last_hidden_state
        decoder_out = model.decoder(hidden_states, condition_id, frame_num)
        out_motion = model.out_head_dict[annot_type](decoder_out)[0]
        out_list.append(out_motion.cpu().numpy())

    # Cross-fade the overlapping windows back into one continuous sequence.
    out = merge_out_list(out_list, fps=fps)


# Out-of-range numpy slices silently return fewer columns, so check the width first.
if out.shape[-1] < 403:
    raise ValueError(
        f'expected at least 403 output channels for the exp/jaw slices, got {out.shape[-1]}')

out_dict = {
    'exp': out[:, 300:400], # extract expression coefficients
    'jaw': out[:, 400:403], # extract jaw pose
    'fps': fps
}

# save the sequence of driving signals (later used in Gaussian Dejavu)
with open(save_path, 'wb') as f:
    pickle.dump(out_dict, f)