In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
from IPython.display import Audio
import matplotlib.pyplot as plt
import soundfile as sf
import pickle

In [2]:
from models.codec.kmeans.kmeans_model import KMeans, KMeansEMA
from models.tts.soundstorm.soundstorm_model import SoundStorm
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from transformers import Wav2Vec2BertModel
import safetensors
from utils.util import load_config



In [3]:
from utils.g2p_new.g2p import phonemizer_g2p
LANG2CODE = {
    'zh': 349,
    'en': 350,
    'ja': 351,
    'ko': 352,
    'fr': 353,
    'de': 354,
}
def g2p(text, language):
    return phonemizer_g2p(text, language)

In [4]:
cfg = load_config("/opt/tiger/SpeechGeneration/egs/tts/SoundStorm/exp_config_16k_emilia_llama_add_phone.json")
print(cfg)

{'model_type': 'SoundStorm', 'dataset': ['emilia'], 'preprocess': {'hop_size': 320, 'sample_rate': 16000, 'processed_dir': '', 'valid_file': 'valid.json', 'train_file': 'train.json', 'use_phone_cond': True}, 'model': {'soundstorm': {'num_quantizer': 8, 'hidden_size': 1024, 'num_layers': 16, 'num_heads': 16, 'codebook_size': 1024, 'cfg_scale': 0.15, 'mask_layer_schedule': 'linear', 'use_cond_code': True, 'cond_codebook_size': 2048, 'cond_dim': 1024, 'use_llama_style': True, 'use_phone_cond': True, 'use_pretrained_model': True, 'pretrained_path': '/mnt/bn/yuacnwang-speech/ckpt/soundstorm/soundstorm_16k_kmeans_2048_emilia_50k_llama_add_phone/pretrained_ckpt/epoch-0009_step-0085000_loss-4.031860/model.safetensors', 'zero_init_cross_attn': True}, 'kmeans': {'type': 'kmeans_ema', 'stat_mean_var_path': '/mnt/bn/yuacnwang-speech/ckpt/semantic_kmeans/mls_wav2vec2bert_stats.pt', 'kmeans': {'codebook_size': 2048, 'codebook_dim': 1024, 'kmeans_init': True, 'kmeans_iters': 10, 'decay': 0.8, 'eps': 

In [5]:
def build_soundstorm(cfg, pretrained_path, device):
    soundstorm_model = SoundStorm(cfg=cfg.model.soundstorm)
    if ".bin" in pretrained_path:
        soundstorm_model .load_state_dict(torch.load(pretrained_path))
    elif ".safetensors" in pretrained_path:
        safetensors.torch.load_model(soundstorm_model, pretrained_path)
    soundstorm_model.eval()
    soundstorm_model.to(device)
    return soundstorm_model

def build_kmeans_model(cfg, device):
    if cfg.model.kmeans.type == "kmeans":
        kmeans_model = KMeans(cfg=cfg.model.kmeans.kmeans)
    elif cfg.model.kmeans.type == "kmeans_ema":
        kmeans_model = KMeansEMA(cfg=cfg.model.kmeans.kmeans)
    kmeans_model.eval()
    pretrained_path =cfg.model.kmeans.pretrained_path
    if ".bin" in pretrained_path:
        kmeans_model.load_state_dict(torch.load(pretrained_path))
    elif ".safetensors" in pretrained_path:
        safetensors.torch.load_model(kmeans_model, pretrained_path)
    kmeans_model.to(device)
    return kmeans_model

def build_semantic_model(cfg, device):
    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
    semantic_model.eval()
    semantic_model.to(device)
    # layer_idx = 15
    # if layer_idx == 23:
    #     output_idx = 0
    # else:
    #     output_idx = layer_idx + 2
    layer_idx = 15
    output_idx = 17
    stat_mean_var = torch.load(cfg.model.kmeans.stat_mean_var_path)
    semantic_mean = stat_mean_var["mean"]
    semantic_std = torch.sqrt(stat_mean_var["var"])
    semantic_mean = semantic_mean.to(device)
    semantic_std = semantic_std.to(device)
    # print(
    #     "semantic mean: ", semantic_mean, "semantic std: ", semantic_std
    # )
    return semantic_model, semantic_mean, semantic_std

def build_codec_model(cfg, device):
    codec_encoder = CodecEncoder(cfg=cfg.model.codec.encoder)
    codec_decoder = CodecDecoder(cfg=cfg.model.codec.decoder)
    codec_encoder.load_state_dict(
        torch.load(cfg.model.codec.encoder.pretrained_path)
    )
    codec_decoder.load_state_dict(
        torch.load(cfg.model.codec.decoder.pretrained_path)
    )
    # codec_decoder = codec_decoder.quantizer  # we only need the quantizer
    codec_encoder.eval()
    codec_decoder.eval()
    codec_encoder.to(device)
    codec_decoder.to(device)
    return codec_encoder, codec_decoder

In [7]:
device = torch.device("cuda:0")
soundstorm_pretrained_path = "/mnt/bn/yuacnwang-speech/ckpt/soundstorm/soundstorm_16k_kmeans_2048_emilia_50k_llama_add_phone/checkpoint/epoch-0005_step-0033000_loss-5.789692/model.safetensors"
soundstorm_model = build_soundstorm(cfg, soundstorm_pretrained_path, device)
semantic_model, semantic_mean, semantic_std = build_semantic_model(cfg, device)
kmeans_model = build_kmeans_model(cfg, device)
codec_encoder, codec_decoder = build_codec_model(cfg, device)

use diffllama nar cross


In [9]:
safetensors.torch.load_model(soundstorm_model, "/mnt/bn/yuacnwang-speech/ckpt/soundstorm/soundstorm_16k_kmeans_2048_emilia_50k_llama_add_phone/checkpoint/epoch-0005_step-0033000_loss-5.789692/model.safetensors")

(set(), [])

In [10]:
from transformers import SeamlessM4TFeatureExtractor
processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

In [11]:
@torch.no_grad()
def extract_acoustic_code(speech):
    vq_emb = codec_encoder(speech.unsqueeze(1))
    _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)
    acoustic_code = vq.permute(
        1, 2, 0
    )  # (num_quantizer, T, C) -> (T, C, num_quantizer)
    return acoustic_code

@torch.no_grad()
def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):
    vq_emb = semantic_model(
        input_features=input_features,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    feat = vq_emb.hidden_states[17]  # (B, T, C)
    feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)

    semantic_code, _ = kmeans_model.quantize(feat)  # (B, T)
    return semantic_code

@torch.no_grad()
def extract_features(speech, processor):
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
    input_features = inputs["input_features"][0]
    attention_mask = inputs["attention_mask"][0]
    return input_features, attention_mask

In [12]:
dataset_path = "/mnt/bn/yuacnwang-speech/dataset/Emilia/emilia/emilia_50k/wav"
dataset_wav_info = []
for subset in os.listdir(dataset_path):
    for wav in os.listdir(os.path.join(dataset_path, subset)):
        wav_path = os.path.join(dataset_path, subset, wav)
        dataset_wav_info.append(wav_path)
print(len(dataset_wav_info))

1632077


In [30]:
random_idx = 2223
wav_path = dataset_wav_info[random_idx]
meta_path = wav_path.replace("/wav/", "/meta/").replace(".wav", ".pkl")
with open(meta_path, 'rb') as f:
    meta_info = pickle.load(f)
uid = wav_path.split("/")[-1].split(".")[0]
print(uid)
speech, sr = librosa.load(wav_path, sr=16000)
print(meta_info)
Audio(speech, rate=sr)

Bilibili_105022844_BV18Q4y1W7Ea_0_58
{'text': '第一局,還是要付出一些代價,你才能去制衡,呃,制衡一下科技演員。我覺得首先,你雜技演員這個角色,我覺得該拿還是得拿,就不要,不要審。然後如果你真的想審的話,我覺得上周那個魔術師出現,我覺得後面的隊伍也真香模仿嘛。我覺得魔術師這個角色,第一局拋出來也不虧。', 'start': 623.203441426146, 'end': 642.591441426146, 'speaker': 'SPEAKER_02', 'language': 'zh', 'mos': {'wvmos': 0, 'dnsmos': 3.353840951067398, 'avg': 3.353840951067398}, 'uid': 'Bilibili_105022844_BV18Q4y1W7Ea_0_58'}


In [31]:
phone_id = g2p(meta_info['text'], meta_info['language'])[1]
phone_id = torch.tensor(phone_id, dtype=torch.long)
phone_id = torch.cat([torch.tensor(LANG2CODE[meta_info['language']], dtype=torch.long).reshape(1), phone_id]) # add language token
phone_mask = np.ones(len(phone_id))

In [32]:
# wav_path = "/mnt/bn/yuacnwang-speech/dataset/temp_test/1_douluo_yueyu_(Vocals).wav"
# uid = wav_path.split("/")[-1].split(".")[0]
# print(uid)
# speech, sr = librosa.load(wav_path, sr=16000)

In [33]:
input_fetures, attention_mask = extract_features(speech, processor)
input_fetures = input_fetures.unsqueeze(0)
attention_mask = attention_mask.unsqueeze(0)
semantic_code = extract_semantic_code(semantic_mean.to(device), semantic_std.to(device), input_fetures.to(device), attention_mask.to(device))

In [34]:
print(semantic_code.shape)

torch.Size([1, 969])


In [35]:
acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))

In [36]:
print(acoustic_code.shape)

torch.Size([1, 969, 8])


In [37]:
seq_len = min(semantic_code.shape[1], acoustic_code.shape[1])
semantic_code = semantic_code[:, :seq_len]
acoustic_code = acoustic_code[:, :seq_len, :]

In [38]:
cond = soundstorm_model.cond_emb(semantic_code.to(device))
print(cond.shape)

torch.Size([1, 969, 1024])


In [45]:
prompt = acoustic_code[:,:50*3,:]
predict = soundstorm_model.reverse_diffusion(cond=cond.to(device), prompt=prompt.to(device), temp=1.5, filter_thres=0.98, n_timesteps=[100, 4, 1, 1, 1, 1, 1, 1], cfg=1.0, rescale_cfg=1.0, phone_id=phone_id.unsqueeze(0).to(device))
print(predict.shape)

torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819])
torch.Size([1, 819, 8])
torch.Size([1, 819, 8])


In [49]:
vq_emb = codec_decoder.vq2emb(predict.permute(2,0,1))
recovered_audio = codec_decoder(vq_emb)
recovered_audio = recovered_audio[0][0].cpu().detach().numpy()
# sf.write("/opt/tiger/SpeechGeneration/temp_test_wav/target/{}.wav".format(uid), recovered_audio, samplerate=16000)
Audio(recovered_audio, rate=16000)

In [50]:
prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1))
recovered_prompt_audio = codec_decoder(prompt_vq_emb)
recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().detach().numpy()
# sf.write("/opt/tiger/SpeechGeneration/temp_test_wav/prompt/{}.wav".format(uid), recovered_prompt_audio, samplerate=16000)
Audio(recovered_prompt_audio, rate=16000)

In [51]:
combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
# sf.write("/opt/tiger/SpeechGeneration/temp_test_wav/combine/{}.wav".format(uid), combine_audio, samplerate=16000)
Audio(combine_audio, rate=16000)