In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
from IPython.display import Audio
import matplotlib.pyplot as plt
import soundfile as sf
import pickle

In [2]:
from models.codec.kmeans.kmeans_model import KMeans, KMeansEMA
from models.tts.soundstorm.soundstorm_model import SoundStorm
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from transformers import Wav2Vec2BertModel
import safetensors
from utils.util import load_config



In [3]:
cfg = load_config("/opt/tiger/SpeechGeneration/egs/tts/SoundStorm/exp_config_16k_emilia_llama_new_semantic.json")
print(cfg)

{'model_type': 'SoundStorm', 'dataset': ['emilia'], 'preprocess': {'hop_size': 320, 'sample_rate': 16000, 'processed_dir': '', 'valid_file': 'valid.json', 'train_file': 'train.json', 'use_phone_cond': False}, 'model': {'soundstorm': {'num_quantizer': 8, 'hidden_size': 1024, 'num_layers': 16, 'num_heads': 16, 'codebook_size': 1024, 'cfg_scale': 0.15, 'mask_layer_schedule': 'linear', 'use_cond_code': True, 'cond_codebook_size': 2048, 'cond_dim': 1024, 'use_llama_style': True, 'use_phone_cond': False, 'use_pretrained_model': False}, 'kmeans': {'type': 'kmeans_ema', 'stat_mean_var_path': '/mnt/bn/yuacnwang-speech/ckpt/semantic_kmeans/emilia_wav2vec2bert_stats_10k.pt', 'kmeans': {'codebook_size': 2048, 'codebook_dim': 1024, 'kmeans_init': True, 'kmeans_iters': 10, 'decay': 0.8, 'eps': 1e-05}, 'pretrained_path': '/mnt/bn/yuacnwang-speech/ckpt/semantic_kmeans/semantic_kmeans_emilia_50k_2048_stable/model.safetensors'}, 'codec': {'encoder': {'d_model': 96, 'up_ratios': [4, 4, 4, 5], 'out_channe

In [4]:
def build_soundstorm(cfg, pretrained_path, device):
    soundstorm_model = SoundStorm(cfg=cfg.model.soundstorm)
    if ".bin" in pretrained_path:
        soundstorm_model .load_state_dict(torch.load(pretrained_path))
    elif ".safetensors" in pretrained_path:
        safetensors.torch.load_model(soundstorm_model, pretrained_path)
    soundstorm_model.eval()
    soundstorm_model.to(device)
    return soundstorm_model

def build_kmeans_model(cfg, device):
    if cfg.model.kmeans.type == "kmeans":
        kmeans_model = KMeans(cfg=cfg.model.kmeans.kmeans)
    elif cfg.model.kmeans.type == "kmeans_ema":
        kmeans_model = KMeansEMA(cfg=cfg.model.kmeans.kmeans)
    kmeans_model.eval()
    pretrained_path =cfg.model.kmeans.pretrained_path
    if ".bin" in pretrained_path:
        kmeans_model.load_state_dict(torch.load(pretrained_path))
    elif ".safetensors" in pretrained_path:
        safetensors.torch.load_model(kmeans_model, pretrained_path)
    kmeans_model.to(device)
    return kmeans_model

def build_semantic_model(cfg, device):
    semantic_model = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0")
    semantic_model.eval()
    semantic_model.to(device)
    # layer_idx = 15
    # if layer_idx == 23:
    #     output_idx = 0
    # else:
    #     output_idx = layer_idx + 2
    layer_idx = 15
    output_idx = 17
    stat_mean_var = torch.load(cfg.model.kmeans.stat_mean_var_path)
    semantic_mean = stat_mean_var["mean"]
    semantic_std = torch.sqrt(stat_mean_var["var"])
    semantic_mean = semantic_mean.to(device)
    semantic_std = semantic_std.to(device)
    # print(
    #     "semantic mean: ", semantic_mean, "semantic std: ", semantic_std
    # )
    return semantic_model, semantic_mean, semantic_std

def build_codec_model(cfg, device):
    codec_encoder = CodecEncoder(cfg=cfg.model.codec.encoder)
    codec_decoder = CodecDecoder(cfg=cfg.model.codec.decoder)
    codec_encoder.load_state_dict(
        torch.load(cfg.model.codec.encoder.pretrained_path)
    )
    codec_decoder.load_state_dict(
        torch.load(cfg.model.codec.decoder.pretrained_path)
    )
    # codec_decoder = codec_decoder.quantizer  # we only need the quantizer
    codec_encoder.eval()
    codec_decoder.eval()
    codec_encoder.to(device)
    codec_decoder.to(device)
    return codec_encoder, codec_decoder

In [6]:
device = torch.device("cuda:0")
soundstorm_pretrained_path = "/mnt/bn/yuacnwang-speech/ckpt/soundstorm/soundstorm_16k_kmeans_2048_emilia_50k_llama_new_semantic/checkpoint/epoch-0002_step-0028000_loss-6.033800/model.safetensors"
soundstorm_model = build_soundstorm(cfg, soundstorm_pretrained_path, device)
semantic_model, semantic_mean, semantic_std = build_semantic_model(cfg, device)
kmeans_model = build_kmeans_model(cfg, device)
codec_encoder, codec_decoder = build_codec_model(cfg, device)

In [7]:
from transformers import SeamlessM4TFeatureExtractor
processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")

In [8]:
@torch.no_grad()
def extract_acoustic_code(speech):
    vq_emb = codec_encoder(speech.unsqueeze(1))
    _, vq, _, _, _ = codec_decoder.quantizer(vq_emb)
    acoustic_code = vq.permute(
        1, 2, 0
    )  # (num_quantizer, T, C) -> (T, C, num_quantizer)
    return acoustic_code

@torch.no_grad()
def extract_semantic_code(semantic_mean, semantic_std, input_features, attention_mask):
    vq_emb = semantic_model(
        input_features=input_features,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    feat = vq_emb.hidden_states[17]  # (B, T, C)
    feat = (feat - semantic_mean.to(feat)) / semantic_std.to(feat)

    semantic_code, _ = kmeans_model.quantize(feat)  # (B, T)
    return semantic_code

@torch.no_grad()
def extract_features(speech, processor):
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt")
    input_features = inputs["input_features"][0]
    attention_mask = inputs["attention_mask"][0]
    return input_features, attention_mask

In [9]:
dataset_path = "/mnt/bn/yuacnwang-speech/dataset/Emilia/emilia/emilia_50k/wav"
dataset_wav_info = []
for subset in os.listdir(dataset_path):
    for wav in os.listdir(os.path.join(dataset_path, subset)):
        wav_path = os.path.join(dataset_path, subset, wav)
        dataset_wav_info.append(wav_path)
print(len(dataset_wav_info))

1632077


In [39]:
random_idx = 777
wav_path = dataset_wav_info[random_idx]
meta_path = wav_path.replace("/wav/", "/meta/").replace(".wav", ".pkl")
with open(meta_path, 'rb') as f:
    meta_info = pickle.load(f)
uid = wav_path.split("/")[-1].split(".")[0]
print(uid)
speech, sr = librosa.load(wav_path, sr=16000)
print(meta_info)
Audio(speech, rate=sr)

BiliBili_BV18K4y1q7kN_984
{'text': '一氣四倍處於一個良好的環境,所以我們整個艙內其實是有加溫的這個空氣進行催促的,包括我們艙內都有空調的這個送風,可以保證整個艙段的這個環境條件能滿足一氣四倍的這個工作環境條件要求,所以目前不管外面的這個環境怎麼樣,其實火箭裡面我們是有保障的這個條件。', 'start': 12283.934865874362, 'end': 12303.386865874363, 'speaker': 'SPEAKER_35', 'language': 'zh', 'mos': {'wvmos': 0, 'dnsmos': 3.298829834195812, 'avg': 3.298829834195812}, 'uid': 'BiliBili_BV18K4y1q7kN_984'}


In [40]:
# wav_path = "/mnt/bn/yuacnwang-speech/dataset/temp_test/1_douluo_yueyu_(Vocals).wav"
# uid = wav_path.split("/")[-1].split(".")[0]
# print(uid)
# speech, sr = librosa.load(wav_path, sr=16000)

In [41]:
input_fetures, attention_mask = extract_features(speech, processor)
input_fetures = input_fetures.unsqueeze(0).to(device)
attention_mask = attention_mask.unsqueeze(0).to(device)
semantic_code = extract_semantic_code(semantic_mean.to(device), semantic_std.to(device), input_fetures, attention_mask)

In [42]:
print(semantic_code.shape)

torch.Size([1, 972])


In [43]:
acoustic_code = extract_acoustic_code(torch.tensor(speech).unsqueeze(0).to(device))

In [44]:
print(acoustic_code.shape)

torch.Size([1, 972, 8])


In [45]:
seq_len = min(semantic_code.shape[1], acoustic_code.shape[1])
semantic_code = semantic_code[:, :seq_len]
acoustic_code = acoustic_code[:, :seq_len, :]

In [46]:
cond = soundstorm_model.cond_emb(semantic_code)
print(cond.shape)

torch.Size([1, 972, 1024])


In [51]:
prompt = acoustic_code[:,:50*8,:]
predict = soundstorm_model.reverse_diffusion(cond=cond, prompt=prompt, temp=1.5, filter_thres=0.98, n_timesteps=[20, 4, 1, 1, 1, 1, 1, 1], cfg=1.0, rescale_cfg=1.0)
print(predict.shape)

torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572])
torch.Size([1, 572, 8])
torch.Size([1, 572, 8])


In [52]:
vq_emb = codec_decoder.vq2emb(predict.permute(2,0,1))
recovered_audio = codec_decoder(vq_emb)
recovered_audio = recovered_audio[0][0].cpu().detach().numpy()
# sf.write("/opt/tiger/SpeechGeneration/temp_test_wav/target/{}.wav".format(uid), recovered_audio, samplerate=16000)
Audio(recovered_audio, rate=16000)

In [53]:
prompt_vq_emb = codec_decoder.vq2emb(prompt.permute(2,0,1))
recovered_prompt_audio = codec_decoder(prompt_vq_emb)
recovered_prompt_audio = recovered_prompt_audio[0][0].cpu().detach().numpy()
# sf.write("/opt/tiger/SpeechGeneration/temp_test_wav/prompt/{}.wav".format(uid), recovered_prompt_audio, samplerate=16000)
Audio(recovered_prompt_audio, rate=16000)

In [54]:
combine_audio = np.concatenate([recovered_prompt_audio, recovered_audio])
# sf.write("/opt/tiger/SpeechGeneration/temp_test_wav/combine/{}.wav".format(uid), combine_audio, samplerate=16000)
Audio(combine_audio, rate=16000)