In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
from IPython.display import Audio
import matplotlib.pyplot as plt
import soundfile as sf

In [2]:
from models.tts.gpt_tts.gpt_tts import GPTTTS
from models.tts.gpt_tts.g2p_old_en import process, PHPONE2ID
from g2p_en import G2p
from models.codec.codec_latent.codec_latent import LatentCodecEncoder, LatentCodecDecoderWithTimbre
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from utils.util import load_config



In [3]:
# Read the experiment configuration that describes the codecs and the GPT-TTS
# model, and print it so the hyper-parameters in use are visible in the notebook.
config_path = "egs/tts/LaCoTTS/exp_config_base.json"
cfg = load_config(config_path)
print(cfg)

{'model_type': 'GPTTTS', 'dataset': ['Your Dataset Name'], 'preprocess': {'hop_size': 200, 'sample_rate': 16000, 'processed_dir': '', 'valid_file': 'valid.json', 'train_file': 'train.json'}, 'model': {'latent_codec': {'encoder': {'d_mel': 128, 'd_model': 96, 'num_blocks': 4, 'out_channels': 256, 'use_tanh': False}, 'decoder': {'in_channels': 256, 'num_quantizers': 1, 'codebook_size': 8192, 'codebook_dim': 8, 'quantizer_type': 'fvq', 'use_l2_normlize': True, 'vocos_dim': 512, 'vocos_intermediate_dim': 4096, 'vocos_num_layers': 16, 'ln_before_vq': True, 'use_pe': False}, 'pretrained_ckpt': '...'}, 'wav_codec': {'encoder': {'d_model': 96, 'out_channels': 128, 'up_ratios': [2, 4, 5, 5], 'use_tanh': False}, 'decoder': {'in_channels': 128, 'upsample_initial_channel': 1536, 'num_quantizers': 8, 'codebook_size': 1024, 'codebook_dim': 128, 'quantizer_type': 'fvq', 'quantizer_dropout': 0.5, 'use_l2_normlize': True, 'use_vocos': True, 'vocos_dim': 512, 'vocos_intermediate_dim': 4096, 'vocos_num_l

# Audio Tokenizer: Convert Speech to Token

In [4]:
# Instantiate the waveform codec (encoder + decoder) from the matching
# sub-sections of the experiment config.
wav_codec_cfg = cfg.model.wav_codec
wav_codec_enc = CodecEncoder(cfg=wav_codec_cfg.encoder)
wav_codec_dec = CodecDecoder(cfg=wav_codec_cfg.decoder)

In [5]:
# Instantiate the latent codec (encoder + timbre-aware decoder) from the
# matching sub-sections of the experiment config.
latent_codec_cfg = cfg.model.latent_codec
latent_codec_enc = LatentCodecEncoder(cfg=latent_codec_cfg.encoder)
latent_codec_dec = LatentCodecDecoderWithTimbre(cfg=latent_codec_cfg.decoder)

In [7]:
# Load the pretrained weights of all four codec modules and put them in
# inference mode.
#
# `map_location="cpu"` makes the load independent of the device the checkpoints
# were saved from (consistent with how the GPT-TTS checkpoint is loaded), so
# this cell also runs on machines without a GPU.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
codec_checkpoints = [
    (wav_codec_enc, "ckpts/wav_codec/wav_codec_enc.bin"),
    (wav_codec_dec, "ckpts/wav_codec/wav_codec_dec.bin"),
    (latent_codec_enc, "ckpts/latent_codec/latent_codec_enc.bin"),
    (latent_codec_dec, "ckpts/latent_codec/latent_codec_dec.bin"),
]

for codec_module, ckpt_path in codec_checkpoints:
    codec_module.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    codec_module.eval()
    # Inference only: no gradients are needed for any codec parameter.
    codec_module.requires_grad_(False)
    # To run on GPU, additionally call: codec_module.cuda()

# Display the last module, as the original cell did.
latent_codec_dec

LatentCodecDecoderWithTimbre(
  (quantizer): ResidualVQ(
    (quantizers): ModuleList(
      (0): FactorizedVectorQuantize(
        (in_project): Conv1d(256, 8, kernel_size=(1,), stride=(1,))
        (out_project): Conv1d(8, 256, kernel_size=(1,), stride=(1,))
        (codebook): Embedding(8192, 8)
      )
    )
  )
  (model): Sequential(
    (0): VocosBackbone(
      (embed): Conv1d(256, 512, kernel_size=(7,), stride=(1,), padding=(3,))
      (norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
      (convnext): ModuleList(
        (0-15): 16 x ConvNeXtBlock(
          (dwconv): Conv1d(512, 512, kernel_size=(7,), stride=(1,), padding=(3,), groups=512)
          (norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
          (pwconv1): Linear(in_features=512, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (pwconv2): Linear(in_features=4096, out_features=512, bias=True)
        )
      )
      (final_layer_norm): LayerNorm((512,), eps=

# Latent Codec Language Model TTS

In [9]:
# Location of the GPT-TTS checkpoint. The default is the original path; set the
# GPT_TTS_CKPT environment variable to point at a different file without editing
# the notebook.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
gpt_tts_ckpt = os.environ.get(
    "GPT_TTS_CKPT",
    "/mnt/petrelfs/hehaorui/jiaqi/gpt-tts/exps/latent_codec_gpt_tts/checkpoint/epoch-0001_step-0136000_loss-7.047267/pytorch_model.bin",
)

gpt_tts = GPTTTS(cfg=cfg.model.gpt_tts)
# Load onto CPU first; the result of load_state_dict is displayed so that
# missing/unexpected keys are visible.
gpt_tts.load_state_dict(torch.load(gpt_tts_ckpt, map_location="cpu"))

<All keys matched successfully>

In [11]:
# Inference only: switch to eval mode and freeze every parameter.
# (To run on GPU, call gpt_tts.cuda() before this line.)
gpt_tts.eval().requires_grad_(False)

GPTTTS(
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(8846, 1024, padding_idx=8838)
      (layers): ModuleList(
        (0-11): 12 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=1024, out_features=4096, bias=False)
            (down_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (up_proj): Linear(in_features=1024, out_features=4096, bias=False)
            (act_fn): SiLUActivation()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layerno

# Inference

In [12]:
# Transcript of the reference (prompt) recording. It is phonemized together with
# the target text, so it should match what is spoken in the audio; the original
# contained the typo "erernity".
source_text = "and keeping eternity before the eyes"
# Text to be synthesized in the voice of the reference recording.
target_text = "Come, come returned Hawkeye, uncasing his honest countenance, the better to assure the wavering confidence of his companion. You may see a skin which, if it be not as white as one of the gentle ones, has no tinge of red to it that the winds of the heaven and the sun have not bestowed. Now, let us to business."

# Use the sample rate from the experiment config (16000 in the printed config)
# instead of repeating the literal.
prompt_sample_rate = cfg.preprocess.sample_rate
source_wav = librosa.load("examples/ref/1.wav", sr=prompt_sample_rate)[0]
Audio(source_wav, rate=prompt_sample_rate)

In [13]:
# Text tokenization: prompt transcript + target text -> phoneme ids.
text = source_text + " " + target_text
g2p = G2p()
txt_struct, txt = process(text, g2p)

# Flatten the per-word phoneme lists (second item of each entry) into one
# sequence and map every phoneme to its id.
phone_seq = [p for w in txt_struct for p in w[1]]
phone_id_list = [PHPONE2ID[p] for p in phone_seq]

# Put the ids on the same device as the language model rather than forcing CUDA:
# with the `.cuda()` calls of the models commented out, an unconditional
# `.cuda()` here causes a device mismatch (or fails outright without a GPU).
model_device = next(gpt_tts.parameters()).device
phone_id = torch.tensor(
    phone_id_list, dtype=torch.long, device=model_device
).unsqueeze(0)
phone_id.shape

In [15]:
# NOTE(review): disabled alternative text front-end (PhonemeBpeTokenizer with
# language tags), kept for reference only; `phone_id` is built with g2p_en in
# the text-tokenize cell. The tensor shown as this cell's output presumably
# stems from an earlier run with this code enabled - confirm or clear it.
# from utils.g2p import PhonemeBpeTokenizer
# text_tokenizer = PhonemeBpeTokenizer()
# lang2token ={
#     'zh': "[ZH]", 
#     'ja':"[JA]", 
#     "en":"[EN]", 
#     "fr":"[FR]",
#     "kr": "[KR]",
#     "de": "[DE]",
# }
# def g2p(text, language):
#     text = text.replace("\n","").strip("")
#     lang_token = lang2token[language]
#     text = lang_token + text + lang_token
#     return text_tokenizer.tokenize(text=f"{text}".strip(),language=language)
# phone = g2p(text, 'en')[1]
# phone = torch.tensor(phone, dtype=torch.long)
# phone_id = phone.unsqueeze(0)
# phone_id

tensor([[107,  20,  16,  62,  27,  25,  32,  51,  45,  16,  21,  55,  21,  55,
          30,  25,  34,  39,   7,  16,  63,  19,  25,  62,  22,  47,  55,  16,
          44,  48,  16,  96,  40,  16,  27,  48,  29,   8,  16,  27,  48,  29,
          16,  55,  51,  62,  34,  48,  55,  30,  20,  16,  62,  24,  47,  63,
          27,  96,   8,  16,  35,  30, 617,  18,  33, 105,  23,   7,  16,  24,
          51,  40,  16,  62,  46,  30,  48,  33,  34,  16,  62,  27,  18,  58,
          30,  34, 107, 107,  33,   8,  16,  44,  48,  16,  62,  19,  49,  34,
          48,  55,  16,  34,  51,  16,  48,  62,  57,  58,  55,  16,  44,  48,
          16,  62,  37,  97,  36,  48,  55,  51,  45,  16,  62,  27,  46,  30,
          22,  48,  20,  49,  30,  33,  16,  48,  36,  16,  24,  51,  40,  16,
          27,  48,  29,  62,  32,  42,  30,  26, 107,  10,  16, 624,  16,  29,
          97,  16,  33,  25,  16,  48,  16,  33,  27,  51,  30,  16,  37,  51,
          78,   8,  16,  51,  22,  16,  51,  34,  16

In [17]:
# Speech tokenization: reference waveform -> codec latent -> prompt tokens.
prompt_wav = torch.FloatTensor(source_wav).unsqueeze(0)

# Waveform -> latent: run the waveform encoder (on the waveform with an extra
# dimension inserted at position 1), then the latent encoder.
wav_emb = wav_codec_enc(prompt_wav.unsqueeze(1))
vq_emb = latent_codec_enc(wav_emb)

# Latent -> tokens: of the six values returned by the decoder only the token
# indices and the speaker embedding are used.
_, vq_indices, _, _, _, speaker_embedding = latent_codec_dec(
    vq_emb, vq=True, eval_vq=False, return_spk_embs=True
)

# Keep the first entry along the leading dimension of the indices.
prompt_id = vq_indices[0]
prompt_id.shape

torch.Size([1, 272])

In [18]:
# Store the prompt tokens as raw int16 values (2 bytes per token) and compare
# the resulting file size with the size of the reference wav.
prompt_codes = prompt_id[0].cpu().numpy()

# `astype` would silently wrap out-of-range values, so verify up front that
# every token index is representable as int16.
int16_range = np.iinfo(np.int16)
if prompt_codes.size and not (
    int16_range.min <= prompt_codes.min() and prompt_codes.max() <= int16_range.max
):
    raise ValueError("prompt token ids do not fit into int16")

code = prompt_codes.astype(np.int16)
code.tofile("examples/ref/1.code")

# Size in bytes of the token file vs. the original audio file.
print(os.path.getsize("examples/ref/1.code"))
print(os.path.getsize("examples/ref/1.wav"))

544
108844


In [19]:
# Sampling is stochastic (temperature / top-k / top-p). Seed torch's RNG so a
# "Restart & Run All" reproduces the same utterance; change the seed to obtain
# a different rendition.
# NOTE(review): assumes sample_hf draws from torch's global RNG - confirm.
torch.manual_seed(1234)

gen_tokens = gpt_tts.sample_hf(
    phone_id,
    prompt_id,
    max_length=3600,
    temperature=0.9,
    top_k=8192,
    top_p=0.85,
    repeat_penalty=1.0,
    # Guidance scales > 1.0 were observed to make the speech faster; set this
    # to 1.0 if that is not wanted. (Keyword spelled as in the original call.)
    classifer_free_guidance=1.25,
)

In [20]:
# Generated tokens -> latent: add a leading dimension to the token tensor, look
# up the corresponding embeddings, then decode them together with the speaker
# embedding extracted from the prompt.
gen_tokens_expanded = gen_tokens.unsqueeze(0)
vq_post_emb = latent_codec_dec.vq2emb(gen_tokens_expanded)
recovered_latent = latent_codec_dec(
    vq_post_emb,
    vq=False,
    speaker_embedding=speaker_embedding,
)

In [21]:
# Recovered latent -> waveform: decode the latent with the waveform codec
# decoder, called with vq=False.
recovered_audio = wav_codec_dec(recovered_latent, vq=False)

In [22]:
# Write the synthesized audio to disk and play it back.
recon_path = "examples/recon/1.wav"
# The output directory may not exist on a fresh checkout; sf.write would fail.
os.makedirs(os.path.dirname(recon_path), exist_ok=True)
sf.write(recon_path, recovered_audio.squeeze().cpu().numpy(), 16000)
Audio(recon_path, rate=16000)