In [53]:
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
import numpy as np
import torch
import torch.nn as nn
from IPython.display import Audio
import librosa

# Model Example
DAC codec with latent size 256, codebook dimension 8, 8 codebooks of size 1024 each, and a Vocos decoder.

In [38]:
# Width of the latent that the encoder emits and the decoder consumes.
latent_dim = 256
# Encoder ratios; their product (2*4*5*5 = 200) matches the "200hopsize"
# in the checkpoint directory name loaded later in this notebook.
encoder_ratios = [2, 4, 5, 5]

encoder_config = dict(
    d_model=96,
    up_ratios=list(encoder_ratios),
    out_channels=latent_dim,
)

# The decoder is given the same ratios in reverse order.
decoder_config = dict(
    in_channels=latent_dim,
    up_ratios=encoder_ratios[::-1],
    num_quantizers=8,
    codebook_size=1024,
    codebook_dim=8,
    quantizer_type="fvq",
    use_l2_normlize=True,  # keyword spelling is the library's own
    use_vocos=True,
    vocos_dim=512,
    vocos_intermediate_dim=4096,
    vocos_num_layers=24,
)

encoder = CodecEncoder(**encoder_config)
decoder = CodecDecoder(**decoder_config)

In [39]:
def count_parameters(model):
    """Return how many trainable scalar parameters ``model`` holds.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Report both model sizes in millions of parameters.
encoder_millions = count_parameters(encoder) / 1e6
decoder_millions = count_parameters(decoder) / 1e6
print("number of parameters of the encoder: {}M".format(encoder_millions))
print("number of parameters of the decoder: {}M".format(decoder_millions))

number of parameters of the encoder: 35.425472M
number of parameters of the decoder: 102.343074M


In [40]:
# Both checkpoints live in the same directory; derive the two paths from it.
ckpt_dir = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_8layer_1024_fvq_8d_w_l2norm_vocos"
encoder_path = ckpt_dir + "/encoder.bin"
decoder_path = ckpt_dir + "/decoder.bin"

# map_location="cpu" lets the checkpoints load even when they were saved from a
# device that is not available here (e.g. CUDA tensors on a CPU-only host).
# Nothing in this notebook moves the models off the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
encoder_state = torch.load(encoder_path, map_location="cpu")
decoder_state = torch.load(decoder_path, map_location="cpu")

# The following cells only run inference, so disable any train-mode behaviour.
encoder.eval()
decoder.eval()

encoder.load_state_dict(encoder_state)
# Kept as the last expression so the cell still displays the key-matching result.
decoder.load_state_dict(decoder_state)

<All keys matched successfully>

In [41]:
# Reference utterance used to exercise the codec.
test_wav_path = "/mnt/data2/wangyuancheng/tts_data/kss/2/2_0000.wav"
# Load at 16 kHz; the checkpoint directory name used in this section contains "16k".
target_sr = 16000
wav, sr = librosa.load(test_wav_path, sr=target_sr)
# Last expression of the cell: renders an inline player for the original audio.
Audio(wav, rate=sr)

In [42]:
# Waveform as a tensor with one leading axis added: (1, num_samples).
audio = torch.from_numpy(wav).unsqueeze(0)
print(audio.shape)

# Inference only: no_grad stops autograd from recording (and keeping alive)
# a graph for the forward passes below.
with torch.no_grad():
    # encode the audio to latent; the encoder is fed a (1, 1, num_samples) tensor
    # (presumably batch, channel, time - confirm against CodecEncoder)
    vq_emb = encoder(audio.unsqueeze(0))
    print(vq_emb.shape)

    # vq=True makes the decoder return the quantizer outputs rather than audio;
    # only the first two of the five returned values are used here.
    vq_post_emb, vq_id, _, _, _ = decoder(
        vq_emb, vq=True, eval_vq=True
    )

# latent after vq
print(vq_post_emb.shape)
# vq id
print(vq_id.shape)

torch.Size([1, 50156])
torch.Size([1, 256, 251])
torch.Size([1, 256, 251])
torch.Size([8, 1, 251])


In [43]:
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    # Decode the latent *after* vector quantization (vq_post_emb) so that the
    # reconstruction reflects what the codec actually transmits. The original
    # cell passed vq_emb, the pre-quantization encoder output, which bypasses
    # the codebooks.
    # NOTE(review): assumes decoder(x, vq=False) expects the quantized latent -
    # confirm against CodecDecoder.forward.
    recovered_wav = decoder(vq_post_emb, vq=False)
print(recovered_wav.shape)

# Drop the two leading singleton axes and hand a NumPy array to the player.
recovered_audio = recovered_wav.squeeze(0).squeeze(0).detach().cpu().numpy()
Audio(recovered_audio, rate=sr)

torch.Size([1, 1, 50200])


# Model Example
Codec with latent size 128, codebook dimension 128, 8 codebooks of size 1024 each, and a Vocos decoder.

In [44]:
# Width of the latent that the encoder emits and the decoder consumes.
latent_dim = 128
# Encoder ratios; their product (2*4*5*5 = 200) matches the "200hopsize"
# in the checkpoint directory name loaded later in this notebook.
encoder_ratios = [2, 4, 5, 5]

encoder_config = dict(
    d_model=96,
    up_ratios=list(encoder_ratios),
    out_channels=latent_dim,
)

# The decoder is given the same ratios in reverse order.
decoder_config = dict(
    in_channels=latent_dim,
    up_ratios=encoder_ratios[::-1],
    num_quantizers=8,
    codebook_size=1024,
    codebook_dim=128,
    quantizer_type="fvq",
    use_l2_normlize=True,  # keyword spelling is the library's own
    use_vocos=True,
    vocos_dim=512,
    vocos_intermediate_dim=4096,
    vocos_num_layers=24,
)

encoder = CodecEncoder(**encoder_config)
decoder = CodecDecoder(**decoder_config)

In [45]:
def count_parameters(model):
    """Return how many trainable scalar parameters ``model`` holds.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Report both model sizes in millions of parameters.
encoder_millions = count_parameters(encoder) / 1e6
decoder_millions = count_parameters(decoder) / 1e6
print("number of parameters of the encoder: {}M".format(encoder_millions))
print("number of parameters of the decoder: {}M".format(decoder_millions))

number of parameters of the encoder: 34.835392M
number of parameters of the decoder: 102.83037M


In [46]:
# Both checkpoints live in the same directory; derive the two paths from it.
ckpt_dir = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_8layer_1024_fvq_128d_w_l2norm_vocos"
encoder_path = ckpt_dir + "/encoder.bin"
decoder_path = ckpt_dir + "/decoder.bin"

# map_location="cpu" lets the checkpoints load even when they were saved from a
# device that is not available here (e.g. CUDA tensors on a CPU-only host).
# Nothing in this notebook moves the models off the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
encoder_state = torch.load(encoder_path, map_location="cpu")
decoder_state = torch.load(decoder_path, map_location="cpu")

# The following cells only run inference, so disable any train-mode behaviour.
encoder.eval()
decoder.eval()

encoder.load_state_dict(encoder_state)
# Kept as the last expression so the cell still displays the key-matching result.
decoder.load_state_dict(decoder_state)

<All keys matched successfully>

In [47]:
# Reference utterance used to exercise the codec.
test_wav_path = "/mnt/data2/wangyuancheng/tts_data/kss/2/2_0000.wav"
# Load at 16 kHz; the checkpoint directory name used in this section contains "16k".
target_sr = 16000
wav, sr = librosa.load(test_wav_path, sr=target_sr)
# Last expression of the cell: renders an inline player for the original audio.
Audio(wav, rate=sr)

In [48]:
# Waveform as a tensor with one leading axis added: (1, num_samples).
audio = torch.from_numpy(wav).unsqueeze(0)
print(audio.shape)

# Inference only: no_grad stops autograd from recording (and keeping alive)
# a graph for the forward passes below.
with torch.no_grad():
    # encode the audio to latent; the encoder is fed a (1, 1, num_samples) tensor
    # (presumably batch, channel, time - confirm against CodecEncoder)
    vq_emb = encoder(audio.unsqueeze(0))
    print(vq_emb.shape)

    # vq=True makes the decoder return the quantizer outputs rather than audio;
    # only the first two of the five returned values are used here.
    vq_post_emb, vq_id, _, _, _ = decoder(
        vq_emb, vq=True, eval_vq=True
    )

# latent after vq
print(vq_post_emb.shape)
# vq id
print(vq_id.shape)

torch.Size([1, 50156])
torch.Size([1, 128, 251])
torch.Size([1, 128, 251])
torch.Size([8, 1, 251])


In [49]:
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    # Decode the latent *after* vector quantization (vq_post_emb) so that the
    # reconstruction reflects what the codec actually transmits. The original
    # cell passed vq_emb, the pre-quantization encoder output, which bypasses
    # the codebooks.
    # NOTE(review): assumes decoder(x, vq=False) expects the quantized latent -
    # confirm against CodecDecoder.forward.
    recovered_wav = decoder(vq_post_emb, vq=False)
print(recovered_wav.shape)

# Drop the two leading singleton axes and hand a NumPy array to the player.
recovered_audio = recovered_wav.squeeze(0).squeeze(0).detach().cpu().numpy()
Audio(recovered_audio, rate=sr)

torch.Size([1, 1, 50200])


# Model Example
A small codec model, similar to EnCodec/SoundStream, with latent size 256, codebook dimension 256, and 12 codebooks of size 1024 each.

In [None]:
# Width of the latent that the encoder emits and the decoder consumes.
latent_dim = 256
# Encoder ratios; their product (2*4*5*5 = 200) matches the "200hopsize"
# in the checkpoint directory name loaded later in this notebook.
encoder_ratios = [2, 4, 5, 5]

encoder_config = dict(
    d_model=32,
    up_ratios=list(encoder_ratios),
    out_channels=latent_dim,
)

# The decoder is given the same ratios in reverse order.
decoder_config = dict(
    in_channels=latent_dim,
    up_ratios=encoder_ratios[::-1],
    upsample_initial_channel=512,
    num_quantizers=12,
    codebook_size=1024,
    codebook_dim=256,
    quantizer_type="vq",
    use_l2_normlize=False,  # keyword spelling is the library's own
    use_vocos=False,
    commitment=0.15,
)

encoder = CodecEncoder(**encoder_config)
decoder = CodecDecoder(**decoder_config)

In [None]:
def count_parameters(model):
    """Return how many trainable scalar parameters ``model`` holds.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


# Report both model sizes in millions of parameters.
encoder_millions = count_parameters(encoder) / 1e6
decoder_millions = count_parameters(decoder) / 1e6
print("number of parameters of the encoder: {}M".format(encoder_millions))
print("number of parameters of the decoder: {}M".format(decoder_millions))

In [None]:
# Both checkpoints live in the same directory; derive the two paths from it.
ckpt_dir = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_12layer_1024_vq_wo_l2norm_wo_codebook_loss_drop_0_0_commit_0_15_small"
encoder_path = ckpt_dir + "/encoder.bin"
decoder_path = ckpt_dir + "/decoder.bin"

# map_location="cpu" lets the checkpoints load even when they were saved from a
# device that is not available here (e.g. CUDA tensors on a CPU-only host).
# Nothing in this notebook moves the models off the CPU.
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
encoder_state = torch.load(encoder_path, map_location="cpu")
decoder_state = torch.load(decoder_path, map_location="cpu")

# The following cells only run inference, so disable any train-mode behaviour.
encoder.eval()
decoder.eval()

encoder.load_state_dict(encoder_state)
# Kept as the last expression so the cell displays the key-matching result.
decoder.load_state_dict(decoder_state)

In [None]:
# Reference utterance used to exercise the codec.
test_wav_path = "/mnt/data2/wangyuancheng/tts_data/kss/2/2_0000.wav"
# Load at 16 kHz; the checkpoint directory name used in this section contains "16k".
target_sr = 16000
wav, sr = librosa.load(test_wav_path, sr=target_sr)
# Last expression of the cell: renders an inline player for the original audio.
Audio(wav, rate=sr)

In [None]:
# Waveform as a tensor with one leading axis added: (1, num_samples).
audio = torch.from_numpy(wav).unsqueeze(0)
print(audio.shape)

# Inference only: no_grad stops autograd from recording (and keeping alive)
# a graph for the forward passes below.
with torch.no_grad():
    # encode the audio to latent; the encoder is fed a (1, 1, num_samples) tensor
    # (presumably batch, channel, time - confirm against CodecEncoder)
    vq_emb = encoder(audio.unsqueeze(0))
    print(vq_emb.shape)

    # vq=True makes the decoder return the quantizer outputs rather than audio;
    # only the first two of the five returned values are used here.
    vq_post_emb, vq_id, _, _, _ = decoder(
        vq_emb, vq=True, eval_vq=True
    )

# latent after vq
print(vq_post_emb.shape)
# vq id
print(vq_id.shape)

In [None]:
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    # Decode the latent *after* vector quantization (vq_post_emb) so that the
    # reconstruction reflects what the codec actually transmits. The original
    # cell passed vq_emb, the pre-quantization encoder output, which bypasses
    # the codebooks.
    # NOTE(review): assumes decoder(x, vq=False) expects the quantized latent -
    # confirm against CodecDecoder.forward.
    recovered_wav = decoder(vq_post_emb, vq=False)
print(recovered_wav.shape)

# Drop the two leading singleton axes and hand a NumPy array to the player.
recovered_audio = recovered_wav.squeeze(0).squeeze(0).detach().cpu().numpy()
Audio(recovered_audio, rate=sr)