In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
from IPython.display import Audio
import matplotlib.pyplot as plt
import soundfile as sf

In [2]:
from models.tts.gpt_tts.gpt_tts import GPTTTS
from models.tts.gpt_tts.g2p_old_en import process, PHPONE2ID
from g2p_en import G2p
from models.codec.codec_latent.codec_latent import LatentCodecEncoder, LatentCodecDecoderWithTimbre
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from utils.util import load_config

In [3]:
cfg = load_config("egs/tts/LaCoTTS/exp_config_base.json")
print(cfg)

{'model_type': 'LaCoTTS', 'dataset': ['Your Dataset Name'], 'preprocess': {'hop_size': 200, 'sample_rate': 16000, 'processed_dir': '', 'valid_file': 'valid.json', 'train_file': 'train.json'}, 'model': {'latent_codec': {'encoder': {'d_mel': 128, 'd_model': 96, 'num_blocks': 4, 'out_channels': 256, 'use_tanh': False}, 'decoder': {'in_channels': 256, 'num_quantizers': 1, 'codebook_size': 8192, 'codebook_dim': 8, 'quantizer_type': 'fvq', 'use_l2_normlize': True, 'vocos_dim': 512, 'vocos_intermediate_dim': 4096, 'vocos_num_layers': 16, 'ln_before_vq': True, 'use_pe': False}, 'pretrained_ckpt': '...'}, 'wav_codec': {'encoder': {'d_model': 96, 'out_channels': 128, 'up_ratios': [2, 4, 5, 5], 'use_tanh': False}, 'decoder': {'in_channels': 128, 'upsample_initial_channel': 1536, 'num_quantizers': 8, 'codebook_size': 1024, 'codebook_dim': 128, 'quantizer_type': 'fvq', 'quantizer_dropout': 0.5, 'use_l2_normlize': True, 'use_vocos': True, 'vocos_dim': 512, 'vocos_intermediate_dim': 4096, 'vocos_num_

In [4]:
wav_codec_enc = CodecEncoder(
    cfg=cfg.model.wav_codec.encoder
)
wav_codec_dec = CodecDecoder(
    cfg=cfg.model.wav_codec.decoder
)

In [5]:
latent_codec_enc = LatentCodecEncoder(
    cfg=cfg.model.latent_codec.encoder
)
latent_codec_dec = LatentCodecDecoderWithTimbre(
    cfg=cfg.model.latent_codec.decoder
)

In [6]:
wav_codec_enc.load_state_dict(torch.load("ckpts/wav_codec/wav_codec_enc.bin"))
wav_codec_dec.load_state_dict(torch.load("ckpts/wav_codec/wav_codec_dec.bin"))
latent_codec_enc.load_state_dict(torch.load("ckpts/latent_codec/latent_codec_enc.bin"))
latent_codec_dec.load_state_dict(torch.load("ckpts/latent_codec/latent_codec_dec.bin"))

wav_codec_enc.eval()
wav_codec_dec.eval()
latent_codec_enc.eval()
latent_codec_dec.eval()

wav_codec_enc.cuda()
wav_codec_dec.cuda()
latent_codec_enc.cuda()
latent_codec_dec.cuda()

# requires_grad false
wav_codec_enc.requires_grad_(False)
wav_codec_dec.requires_grad_(False)
latent_codec_enc.requires_grad_(False)
latent_codec_dec.requires_grad_(False)

LatentCodecDecoderWithTimbre(
  (quantizer): ResidualVQ(
    (quantizers): ModuleList(
      (0): FactorizedVectorQuantize(
        (in_project): Conv1d(256, 8, kernel_size=(1,), stride=(1,))
        (out_project): Conv1d(8, 256, kernel_size=(1,), stride=(1,))
        (codebook): Embedding(8192, 8)
      )
    )
  )
  (model): Sequential(
    (0): VocosBackbone(
      (embed): Conv1d(256, 512, kernel_size=(7,), stride=(1,), padding=(3,))
      (norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
      (convnext): ModuleList(
        (0-15): 16 x ConvNeXtBlock(
          (dwconv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), padding=(3,), groups=512)
          (norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
          (pwconv1): Linear(in_features=512, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=4096, out_features=512, bias=True)
        )
      )
      (final_layer_norm): LayerNorm((512,), eps=

In [7]:
source_wav = librosa.load("examples/ref/1.wav", sr=16000)[0]

In [8]:
# speech tokenize
prompt_wav = torch.FloatTensor(source_wav).cuda().unsqueeze(0)
# wav to latent
vq_emb = wav_codec_enc(prompt_wav.unsqueeze(1))
vq_emb = latent_codec_enc(vq_emb)
# latent to token
(
    _,
    vq_indices,
    _,
    _,
    _,
    _,
) = latent_codec_dec(vq_emb, vq=True, eval_vq=False, return_spk_embs=False)
prompt_id = vq_indices[0,:,:]
prompt_id.shape

torch.Size([1, 272])

In [9]:
code = prompt_id[0].cpu().numpy().tolist()
code = [int(c) for c in code]
code = np.array(code, dtype=np.int16)
# save code as small as possible
code.tofile("examples/ref/1.code")
# compute the size of the code
print(os.path.getsize("examples/ref/1.code"))
print(os.path.getsize("examples/ref/1.wav"))

544
108844
