In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
from IPython.display import Audio
import matplotlib.pyplot as plt
import soundfile as sf
import pickle

In [None]:
from models.codec.kmeans.kmeans_model import KMeans, KMeansEMA
from models.tts.soundstorm.soundstorm_model import SoundStorm
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from models.tts.text2semantic.t2s_model import T2SLlama
from transformers import Wav2Vec2BertModel
import safetensors
from utils.util import load_config

In [None]:
from utils.g2p_new.g2p import phonemizer_g2p
# Language code -> token id prepended to a phone sequence (see "add language token" below).
LANG2CODE = dict(
    zip(
        ("zh", "en", "ja", "ko", "fr", "de"),
        range(349, 355),
    )
)


def g2p(text, language):
    """Convert ``text`` in ``language`` to phonemes with ``phonemizer_g2p``.

    Thin pass-through; callers in this notebook take element ``[1]`` of the
    result as the list of phone ids.
    """
    return phonemizer_g2p(text, language)

In [None]:
# Experiment configs: SoundStorm (semantic tokens -> codec tokens) and
# Text2Semantic (phones -> semantic tokens), loaded in that order.
cfg, t2s_cfg = (
    load_config(config_path)
    for config_path in (
        "/opt/tiger/SpeechGeneration/egs/tts/SoundStorm/exp_config_16k_emilia_llama.json",
        "/opt/tiger/SpeechGeneration/egs/tts/Text2Semantic/exp_config_16k_emilia.json",
    )
)

In [None]:
def build_soundstorm(cfg, pretrained_path, device):
    """Build a SoundStorm model from ``cfg`` and load weights from ``pretrained_path``.

    Args:
        cfg: experiment config; ``cfg.model.soundstorm`` is passed to ``SoundStorm``.
        pretrained_path: checkpoint path ending in ``.bin`` (a ``torch.load``
            state dict) or ``.safetensors``.
        device: device the model is moved to.

    Returns:
        The model, in eval mode, on ``device``.

    Raises:
        ValueError: if ``pretrained_path`` has neither supported extension
            (previously the model was silently returned with random weights).
    """
    soundstorm_model = SoundStorm(cfg=cfg.model.soundstorm)
    if pretrained_path.endswith(".bin"):
        # map_location="cpu": a checkpoint saved on some GPU would otherwise be
        # materialised on that GPU first (or fail on a CPU-only host).
        soundstorm_model.load_state_dict(
            torch.load(pretrained_path, map_location="cpu")
        )
    elif pretrained_path.endswith(".safetensors"):
        # `import safetensors` alone does not load the `torch` submodule.
        import safetensors.torch

        safetensors.torch.load_model(soundstorm_model, pretrained_path)
    else:
        raise ValueError(
            f"Unsupported checkpoint format (expected .bin or .safetensors): {pretrained_path!r}"
        )
    soundstorm_model.eval()
    soundstorm_model.to(device)
    return soundstorm_model

def build_kmeans_model(cfg, device):
    """Build the k-means quantizer described by ``cfg.model.kmeans`` and load its weights.

    ``cfg.model.kmeans.type`` selects the class ("kmeans" -> ``KMeans``,
    "kmeans_ema" -> ``KMeansEMA``), which is built from ``cfg.model.kmeans.kmeans``.
    Weights come from ``cfg.model.kmeans.pretrained_path`` (``.bin`` or ``.safetensors``).

    Returns:
        The model, in eval mode, on ``device``.

    Raises:
        ValueError: for an unknown ``type`` (previously an ``UnboundLocalError``)
            or an unsupported checkpoint extension (previously the model was
            silently left untrained).
    """
    kmeans_cfg = cfg.model.kmeans
    if kmeans_cfg.type == "kmeans":
        kmeans_model = KMeans(cfg=kmeans_cfg.kmeans)
    elif kmeans_cfg.type == "kmeans_ema":
        kmeans_model = KMeansEMA(cfg=kmeans_cfg.kmeans)
    else:
        raise ValueError(
            f"Unknown kmeans type {kmeans_cfg.type!r}; expected 'kmeans' or 'kmeans_ema'"
        )
    kmeans_model.eval()

    pretrained_path = kmeans_cfg.pretrained_path
    if pretrained_path.endswith(".bin"):
        # Load on CPU first so tensors saved from another GPU do not land there.
        kmeans_model.load_state_dict(torch.load(pretrained_path, map_location="cpu"))
    elif pretrained_path.endswith(".safetensors"):
        # `import safetensors` alone does not load the `torch` submodule.
        import safetensors.torch

        safetensors.torch.load_model(kmeans_model, pretrained_path)
    else:
        raise ValueError(
            f"Unsupported checkpoint format (expected .bin or .safetensors): {pretrained_path!r}"
        )
    kmeans_model.to(device)
    return kmeans_model

def build_semantic_model(cfg, device):
    """Load the w2v-bert-2.0 encoder and the feature normalisation statistics.

    The file at ``cfg.model.kmeans.stat_mean_var_path`` must contain ``"mean"``
    and ``"var"`` entries; the returned std is ``sqrt(var)``.
    ``extract_semantic_code`` reads ``hidden_states[17]`` of this model (the
    original notes recorded this as layer_idx = 15, output_idx = layer_idx + 2).

    Returns:
        (semantic_model, semantic_mean, semantic_std), all on ``device``;
        the model is in eval mode.
    """
    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
    semantic_model.eval()
    semantic_model.to(device)

    # map_location="cpu": stats saved from another GPU would otherwise be
    # materialised there before being moved to `device`.
    stat_mean_var = torch.load(cfg.model.kmeans.stat_mean_var_path, map_location="cpu")
    semantic_mean = stat_mean_var["mean"].to(device)
    semantic_std = torch.sqrt(stat_mean_var["var"]).to(device)
    return semantic_model, semantic_mean, semantic_std

def build_codec_model(cfg, device):
    """Build the codec encoder/decoder pair and load their pretrained weights.

    Weights are ``torch.load`` state dicts from
    ``cfg.model.codec.encoder.pretrained_path`` and
    ``cfg.model.codec.decoder.pretrained_path``. The whole decoder is kept
    because this notebook uses its ``quantizer``, ``vq2emb`` and forward pass.

    Returns:
        (codec_encoder, codec_decoder), both in eval mode on ``device``.
    """
    codec_encoder = CodecEncoder(cfg=cfg.model.codec.encoder)
    codec_decoder = CodecDecoder(cfg=cfg.model.codec.decoder)
    # map_location="cpu": checkpoints saved on some GPU would otherwise be
    # materialised on that GPU first (or fail on a CPU-only host).
    codec_encoder.load_state_dict(
        torch.load(cfg.model.codec.encoder.pretrained_path, map_location="cpu")
    )
    codec_decoder.load_state_dict(
        torch.load(cfg.model.codec.decoder.pretrained_path, map_location="cpu")
    )
    for module in (codec_encoder, codec_decoder):
        module.eval()
        module.to(device)
    return codec_encoder, codec_decoder

def build_t2s_model(cfg, device):
    """Instantiate ``T2SLlama`` from ``cfg.model.t2sllama`` in eval mode on ``device``.

    No weights are loaded here; in this notebook the checkpoint is loaded
    afterwards with ``safetensors.torch.load_model``.
    """
    model = T2SLlama(cfg=cfg.model.t2sllama).eval()
    return model.to(device)

In [None]:
device = torch.device("cuda:1")

# NOTE: these SoundStorm weights (step 123000) are overwritten a few cells
# below by the step-127000 checkpoint via safetensors.torch.load_model.
soundstorm_pretrained_path = (
    "/mnt/bn/yuacnwang-speech/ckpt/soundstorm/soundstorm_16k_kmeans_2048_emilia_50k_llama/"
    "checkpoint/epoch-0011_step-0123000_loss-4.518502/model.safetensors"
)
soundstorm_model = build_soundstorm(cfg, soundstorm_pretrained_path, device)
semantic_model, semantic_mean, semantic_std = build_semantic_model(cfg, device)
kmeans_model = build_kmeans_model(cfg, device)
codec_encoder, codec_decoder = build_codec_model(cfg, device)
t2s_model = build_t2s_model(t2s_cfg, device)

In [None]:
# build_semantic_model already returns the stats on `device`; this re-move is a
# no-op in a top-down run but keeps the cell safe if they came from elsewhere.
semantic_mean, semantic_std = (stat.to(device) for stat in (semantic_mean, semantic_std))

In [None]:
# `import safetensors` alone does not load the `safetensors.torch` submodule; import
# it explicitly instead of relying on transformers having imported it as a side effect.
import safetensors.torch

# Replace the step-123000 SoundStorm weights loaded by build_soundstorm with the
# step-127000 checkpoint, and load the T2S weights (build_t2s_model loads none).
safetensors.torch.load_model(soundstorm_model, "/mnt/bn/yuacnwang-speech/ckpt/soundstorm/soundstorm_16k_kmeans_2048_emilia_50k_llama/checkpoint/epoch-0011_step-0127000_loss-5.334249/model.safetensors")
safetensors.torch.load_model(t2s_model, "/mnt/bn/yuacnwang-speech/ckpt/text2semantic/t2s_16k_kmeans_2048_emilia_50k/checkpoint/epoch-0001_step-0076000_loss-1.535404/model.safetensors")

In [None]:
from transformers import SeamlessM4TFeatureExtractor

# Feature extractor paired with the w2v-bert-2.0 semantic encoder built above.
processor = SeamlessM4TFeatureExtractor.from_pretrained(
    "facebook/w2v-bert-2.0"
)

In [None]:
@torch.no_grad()
def extract_acoustic_code(speech):
    """Encode a batch of waveforms into codec token indices.

    Args:
        speech: waveform tensor; callers pass (B, num_samples), and a channel
            dimension is inserted here before the codec encoder.

    Returns:
        Codes laid out with time on dim 1 and quantizer on the last dim
        (callers slice ``[:, :T, :]`` and invert with ``permute(2, 0, 1)``).
    """
    latent = codec_encoder(speech.unsqueeze(1))
    # Second element of the quantizer's 5-tuple holds the code indices,
    # presumably as (num_quantizer, B, T) — TODO confirm.
    _, codes, _, _, _ = codec_decoder.quantizer(latent)
    return codes.permute(1, 2, 0)  # -> (B, T, num_quantizer) under that assumption

@torch.no_grad()
def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):
    """Quantise w2v-bert hidden states into k-means semantic tokens.

    Hidden state 17 of ``semantic_model`` is z-normalised with ``semantic_mean``
    and ``semantic_std`` (cast to the features' dtype/device), then quantised
    with ``kmeans_model.quantize``.

    Returns:
        Semantic token ids, annotated in the original as (B, T).
    """
    hidden = semantic_model(
        input_features=input_features,
        attention_mask=attention_mask,
        output_hidden_states=True,
    ).hidden_states[17]  # (B, T, C) per the original annotation
    mean, std = semantic_mean.to(hidden), semantic_std.to(hidden)
    semantic_code, _ = kmeans_model.quantize((hidden - mean) / std)
    return semantic_code

@torch.no_grad()
def extract_features(speech, processor, sampling_rate=16000):
    """Run ``processor`` on a single utterance and drop the batch dimension.

    Args:
        speech: waveform of one utterance, already at ``sampling_rate``.
        processor: feature extractor (here ``SeamlessM4TFeatureExtractor``),
            called with ``return_tensors="pt"``.
        sampling_rate: sample rate of ``speech``. Defaults to 16000, the rate
            every ``librosa.load`` call in this notebook uses.

    Returns:
        (input_features, attention_mask) of the first batch item.
    """
    inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")
    return inputs["input_features"][0], inputs["attention_mask"][0]

In [None]:
dataset_path = "/mnt/bn/yuacnwang-speech/dataset/Emilia/emilia/emilia_50k/wav"
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so a
# fixed index (random_idx below) always selects the same utterance.
dataset_wav_info = [
    os.path.join(dataset_path, subset, wav)
    for subset in sorted(os.listdir(dataset_path))
    for wav in sorted(os.listdir(os.path.join(dataset_path, subset)))
]
print(len(dataset_wav_info))

In [None]:
random_idx = 888
wav_path = dataset_wav_info[random_idx]
# Metadata sits in a parallel "meta" tree, same stem, .pkl suffix.
meta_path = wav_path.replace("/wav/", "/meta/").replace(".wav", ".pkl")
# NOTE(review): pickle.load can execute arbitrary code — only use on trusted dataset files.
with open(meta_path, 'rb') as meta_file:
    meta_info = pickle.load(meta_file)
uid = wav_path.rsplit("/", 1)[-1].split(".", 1)[0]
print(uid)
speech, sr = librosa.load(wav_path, sr=16000)
print(meta_info)
Audio(speech, rate=sr)

In [None]:
# Phone ids for the dataset sample's transcript, in the sample's own language.
phone_id = torch.tensor(g2p(meta_info['text'], meta_info['language'])[1], dtype=torch.long)
# Prepend the matching language token (LANG2CODE id) and move to the model device.
phone_id = torch.cat(
    [torch.tensor([LANG2CODE[meta_info['language']]], dtype=torch.long), phone_id]
).to(device)

In [None]:
text = meta_info["text"]
# 'coninue' looks like a typo for 'continue', but it is kept so this transcript
# lands next to the audio the later cells write under the same directory.
out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/coninue/txt/{}.txt".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # open(..., "w") fails on a missing directory
with open(out_path, "w", encoding="utf-8") as f:
    # write(), not writelines(): `text` is one str, not an iterable of lines.
    f.write(text)

#### Test data (overrides the dataset sample above)

In [None]:
wav_path = "/mnt/bn/yuacnwang-speech/dataset/temp_test/ns2_10.wav"
uid = wav_path.rsplit("/", 1)[-1].split(".", 1)[0]
# Transcript of the test clip; converted to phones in the next cell.
text = "For a few miles, she followed the line hitherto presumably occupied by the coast of Algeria, but no land appeared to the south."
speech, sr = librosa.load(wav_path, sr=16000)
Audio(speech, rate=16000)

In [None]:
# The test transcript is English, so g2p and the language token must both use
# 'en'. Previously the token came from meta_info['language'] — the metadata of
# the earlier dataset sample — which need not be English.
language = 'en'
phone_id = torch.tensor(g2p(text, language)[1], dtype=torch.long)
phone_id = torch.cat([torch.tensor(LANG2CODE[language], dtype=torch.long).reshape(1), phone_id]).to(device)  # add language token

In [None]:
# Semantic tokens for the whole test utterance (batch of one).
# (Fixes the misspelled `input_fetures` variable name.)
input_features, attention_mask = extract_features(speech, processor)
input_features = input_features.unsqueeze(0).to(device)
attention_mask = attention_mask.unsqueeze(0).to(device)
semantic_code = extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask)

## SoundStorm Reconstruction

In [None]:
acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))
# The semantic and codec tokenizers may yield different frame counts; trim both
# to the shorter one so they are aligned along time (dim 1).
seq_len = min(semantic_code.shape[1], acoustic_code.shape[1])
semantic_code, acoustic_code = semantic_code[:, :seq_len], acoustic_code[:, :seq_len, :]
cond = soundstorm_model.cond_emb(semantic_code.to(device))

In [None]:
# cond = soundstorm_model.cond_emb(semantic_code.to(device))
# prompt = acoustic_code[:,:50*3,:]
# predict = soundstorm_model.reverse_diffusion(cond=cond.to(device), prompt=prompt.to(device), temp=1.5, filter_thres=0.98, n_timesteps=[50, 10, 1, 1, 1, 1, 1, 1], cfg=1.0, rescale_cfg=1.0, phone_id=phone_id.unsqueeze(0).to(device))

In [None]:
# vq_emb = codec_decoder.vq2emb(predict.permute(2,0,1))
# recovered_audio = codec_decoder(vq_emb)
# recovered_audio = recovered_audio[0][0].cpu().detach().numpy()
# Audio(recovered_audio, rate=16000)

In [None]:
# prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1))
# recovered_prompt_audio = codec_decoder(prompt_vq_emb)
# recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().detach().numpy()
# Audio(recovered_prompt_audio, rate=16000)

In [None]:
# combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
# Audio(combine_audio, rate=16000)

## SoundStorm Continuation TTS

In [None]:
prompt_len = 3 * 50  # prompt length in tokens; presumably 3 s at 50 frames/s — TODO confirm frame rate

In [None]:
# Sample semantic tokens for phone_id, conditioned on the first prompt_len
# semantic tokens of the reference utterance.
predict_semantic = t2s_model.sample_hf(
    phone_ids=phone_id.unsqueeze(0),
    prompt_ids=semantic_code[:, :prompt_len],
    temperature=1.0,
    top_k=100,
    top_p=0.8,
)

In [None]:
# Prompt tokens followed by the generated continuation, joined along the last (time) axis.
combine_semantic_code = torch.cat((semantic_code[:, :prompt_len], predict_semantic), dim=-1)

In [None]:
# Re-extract codec tokens for the full reference (the earlier acoustic_code was
# trimmed to seq_len); only the first prompt_len frames are used as the prompt.
acoustic_code = extract_acoustic_code(
    torch.tensor(speech).unsqueeze(0).to(device)
)

In [None]:
# SoundStorm conditioning from prompt + generated semantic tokens.
cond = soundstorm_model.cond_emb(
    combine_semantic_code
)
print(cond.shape)

In [None]:
# Acoustic prompt: codec tokens of the first prompt_len frames of the reference.
prompt = acoustic_code[:, :prompt_len, :]
predict = soundstorm_model.reverse_diffusion(
    cond=cond,
    prompt=prompt,
    temp=1.5,
    filter_thres=0.98,
    n_timesteps=[50, 10, 1, 1, 1, 1, 1, 1],
    cfg=1.0,
    rescale_cfg=1.0,
)
print(predict.shape)

In [None]:
# no_grad: nothing in this notebook disables requires_grad on the codec, so a
# plain forward pass would record an autograd graph that is never used.
with torch.no_grad():
    vq_emb = codec_decoder.vq2emb(predict.permute(2, 0, 1))  # undo extract_acoustic_code's permute(1, 2, 0)
    recovered_audio = codec_decoder(vq_emb)
recovered_audio = recovered_audio[0][0].cpu().detach().numpy()
out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/coninue/target/{}.wav".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # sf.write does not create directories
sf.write(out_path, recovered_audio, 16000)
Audio(recovered_audio, rate=16000)

In [None]:
# Decode the acoustic prompt alone; no_grad avoids building an unused autograd graph.
with torch.no_grad():
    prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2, 0, 1))
    recovered_prompt_audio = codec_decoder(prompt_vq_emb)
recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().detach().numpy()
out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/coninue/prompt/{}.wav".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # sf.write does not create directories
sf.write(out_path, recovered_prompt_audio, 16000)
Audio(recovered_prompt_audio, rate=16000)

In [None]:
# Prompt audio followed by the generated continuation.
combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/coninue/combine/{}.wav".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # sf.write does not create directories
sf.write(out_path, combine_audio, 16000)
Audio(combine_audio, rate=16000)

## SoundStorm Cross-Lingual TTS (English prompt, Chinese target)

In [None]:
prompt_wav_path = "/mnt/bn/yuacnwang-speech/dataset/temp_test/biden_2.wav"
prompt_speech, sr = librosa.load(prompt_wav_path, sr=16000)
uid = prompt_wav_path.split("/")[-1].split(".")[0]

# English prompt: language token followed by its phones.
prompt_text = "We do not break, we never give in, we never backdown."
prompt_phone_id = g2p(prompt_text, 'en')[1]
prompt_phone_id = torch.tensor(prompt_phone_id, dtype=torch.long)
prompt_phone_id = torch.cat([torch.tensor(LANG2CODE['en'], dtype=torch.long).reshape(1), prompt_phone_id])

# Chinese target: phones only — no language token is inserted between prompt and
# target. NOTE(review): the disabled line below used LANG2CODE['en'] for 'zh' text.
target_text = "我的名字叫做拜登"
target_phone_id = g2p(target_text, 'zh')[1]
target_phone_id = torch.tensor(target_phone_id, dtype=torch.long)
# target_phone_id = torch.cat([torch.tensor(LANG2CODE['en'], dtype=torch.long).reshape(1), target_phone_id])

# Use the `device` the models were built on rather than re-hard-coding "cuda:1"
# here, which would silently diverge if that setting were changed.
phone_id = torch.cat([prompt_phone_id, target_phone_id]).to(device)
text = prompt_text + target_text

out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/cross/txt/{}.txt".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # open(..., "w") fails on a missing directory
with open(out_path, "w", encoding="utf-8") as f:
    # write(), not writelines(): target_text is one str.
    f.write(target_text)
Audio(prompt_speech, rate=16000)

In [None]:
# Semantic tokens for the whole prompt utterance (batch of one).
# (Fixes the misspelled `input_fetures` variable name.)
input_features, attention_mask = extract_features(prompt_speech, processor)
input_features = input_features.unsqueeze(0).to(device)
attention_mask = attention_mask.unsqueeze(0).to(device)
prompt_semantic_code = extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask)

In [None]:
# Sample target semantic tokens conditioned on the full prompt's semantic tokens.
predict_semantic = t2s_model.sample_hf(
    phone_ids=phone_id.unsqueeze(0),
    prompt_ids=prompt_semantic_code,
    temperature=0.95,
    top_k=512,
    top_p=0.9,
)
# Whole prompt followed by the generated target, along the last (time) axis.
semantic_code = torch.cat((prompt_semantic_code, predict_semantic), dim=-1)

In [None]:
acoustic_code = extract_acoustic_code(
    torch.tensor(prompt_speech).unsqueeze(0).to(device)
)
print(acoustic_code.shape)
# NOTE(review): unlike the reconstruction cell, the prompt's codec tokens are not
# trimmed to the prompt's semantic length; if the tokenizers' frame counts differ,
# prompt and cond may be misaligned — confirm.
cond = soundstorm_model.cond_emb(
    semantic_code.to(device)
)
print(cond.shape)

In [None]:
# The whole prompt utterance's codec tokens serve as the acoustic prompt.
prompt = acoustic_code
predict = soundstorm_model.reverse_diffusion(
    cond=cond,
    prompt=prompt,
    temp=1.5,
    filter_thres=0.98,
    n_timesteps=[50, 10, 1, 1, 1, 1, 1, 1],
    cfg=1.0,
    rescale_cfg=1.0,
)
print(predict.shape)

In [None]:
# no_grad: nothing in this notebook disables requires_grad on the codec, so a
# plain forward pass would record an autograd graph that is never used.
with torch.no_grad():
    vq_emb = codec_decoder.vq2emb(predict.permute(2, 0, 1))  # undo extract_acoustic_code's permute(1, 2, 0)
    recovered_audio = codec_decoder(vq_emb)
recovered_audio = recovered_audio[0][0].cpu().detach().numpy()
out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/cross/target/{}.wav".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # sf.write does not create directories
sf.write(out_path, recovered_audio, 16000)
Audio(recovered_audio, rate=16000)

In [None]:
# Decode the acoustic prompt alone; no_grad avoids building an unused autograd graph.
with torch.no_grad():
    prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2, 0, 1))
    recovered_prompt_audio = codec_decoder(prompt_vq_emb)
recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().detach().numpy()
out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/cross/prompt/{}.wav".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # sf.write does not create directories
sf.write(out_path, recovered_prompt_audio, 16000)
Audio(recovered_prompt_audio, rate=16000)

In [None]:
# Prompt audio followed by the generated cross-lingual target.
combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
out_path = "/mnt/bn/yuacnwang-speech/test_result/soundstorm_kmeans_2048_ar_76k_nar_127k/cross/combine/{}.wav".format(uid)
os.makedirs(os.path.dirname(out_path), exist_ok=True)  # sf.write does not create directories
sf.write(out_path, combine_audio, 16000)
Audio(combine_audio, rate=16000)