In [1]:
import numpy as np
import torch
import torch.nn as nn
from IPython.display import Audio
import librosa
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from utils.util import load_config

In [2]:
# Load the experiment configuration for the 24 kHz codec and print it for reference.
config_path = "egs/codec/codec_dac_vocos_24K/exp_config_base.json"
cfg = load_config(config_path)
print(cfg)

{'model_type': 'Codec', 'dataset': ['libritts-train-clean-100'], 'preprocess': {'hop_size': 480, 'sample_rate': 24000, 'max_length': 36000, 'processed_dir': '/home/t-zeqianju/yuancwang/temp_test_dataset', 'valid_file': 'valid.json', 'train_file': 'train.json'}, 'model': {'encoder': {'d_model': 96, 'up_ratios': [3, 4, 5, 8], 'out_channels': 256, 'use_tanh': False}, 'decoder': {'in_channel': 256, 'upsample_initial_channel': 1536, 'up_ratios': [8, 5, 4, 3], 'num_quantizers': 12, 'codebook_size': 1024, 'codebook_dim': 8, 'quantizer_type': 'fvq', 'quantizer_dropout': 0.5, 'commitment': 0.25, 'codebook_loss_weight': 1.0, 'use_l2_normlize': True, 'codebook_type': 'euclidean', 'kmeans_init': False, 'kmeans_iters': 10, 'decay': 0.8, 'eps': 0.5, 'threshold_ema_dead_code': 2, 'weight_init': False, 'use_vocos': True, 'vocos_dim': 512, 'vocos_intermediate_dim': 4096, 'vocos_num_layers': 30, 'n_fft': 1920, 'hop_size': 480, 'padding': 'same'}, 'period_gan': {'max_downsample_channels': 512, 'channels'

In [4]:
# Build the codec halves from the "model" section of the experiment config.
model_cfg = cfg.model
encoder = CodecEncoder(cfg=model_cfg.encoder)
decoder = CodecDecoder(cfg=model_cfg.decoder)

In [8]:
# Both weight files live in the same checkpoint directory: the encoder weights
# are read from pytorch_model.bin and the decoder weights from pytorch_model_1.bin.
# NOTE(review): the directory name says "320hopsize" while the printed config has
# hop_size 480 and the recorded run yields 740 frames for 355440 samples — confirm
# that this checkpoint really matches the config loaded above.
ckpt_dir = "/blob/v-yuancwang/codec_ckpt/codec_amphion/codec_dac_vocos_24K_320hopsize_12vq_30vocos_layers/checkpoint/epoch-0014_step-0016000_loss-27.834438"
encoder_ckpt_path = f"{ckpt_dir}/pytorch_model.bin"
decoder_ckpt_path = f"{ckpt_dir}/pytorch_model_1.bin"

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. map_location="cpu" keeps the tensors on the CPU regardless of
# the device they were saved from.
encoder_state = torch.load(encoder_ckpt_path, map_location="cpu")
decoder_state = torch.load(decoder_ckpt_path, map_location="cpu")

# This notebook only runs inference, so switch both modules out of training mode
# (affects train-only behaviour such as dropout).
encoder.eval()
decoder.eval()

encoder.load_state_dict(encoder_state)
# Kept as the last expression so the cell still displays the key-matching result.
decoder.load_state_dict(decoder_state)

<All keys matched successfully>

In [15]:
# Load the test utterance at 24 kHz (the sample_rate shown in the printed config)
# and render an audio player for the original signal.
test_wav_path = "/home/t-zeqianju/yuancwang/temp_test_dataset/libritts-train-clean-100/train-clean-100/39/121916/39_121916_000002_000002.wav"
target_sr = 24000
wav, sr = librosa.load(test_wav_path, sr=target_sr)
Audio(wav, rate=sr)

In [16]:
# Encode the waveform to a latent sequence and quantize it.
# This is inference only, so autograd is disabled: otherwise a computation graph
# covering the whole utterance is built and kept alive by the result tensors.
audio = torch.from_numpy(wav).unsqueeze(0)  # recorded run: (1, 355440)
print(audio.shape)

with torch.no_grad():
    # encode the audio to latent; a second unsqueeze feeds the encoder a 3-D tensor
    vq_emb = encoder(audio.unsqueeze(0))
    print(vq_emb.shape)

    # With vq=True the decoder call returns five values; only the quantized
    # latent and the code indices are used in this notebook.
    vq_post_emb, vq_id, _, _, _ = decoder(
        vq_emb, vq=True, eval_vq=True
    )

# latent after vq (same shape as the encoder output in the recorded run)
print(vq_post_emb.shape)
# vq id (recorded run: (12, 1, 740) — presumably (num_quantizers, batch, frames))
print(vq_id.shape)

torch.Size([1, 355440])


torch.Size([1, 256, 740])
torch.Size([1, 256, 740])
torch.Size([12, 1, 740])


In [18]:
# Decode the quantized latent back to a waveform. Autograd is disabled because
# this is inference only; the result can then be converted to numpy directly.
with torch.no_grad():
    recovered = decoder(vq_post_emb, vq=False)
print(recovered.shape)

# Drop the two leading singleton dims (recorded run: (1, 1, 355200)) to get a
# 1-D array for playback. The tensor and the array get separate names.
recovered_audio = recovered.squeeze(0).squeeze(0).numpy()
Audio(recovered_audio, rate=sr)

torch.Size([1, 1, 355200])


# Model Example
DAC-style codec with latent dimension 256, codebook (code) dimension 8, 8 codebooks of 1024 entries each, and a Vocos decoder.

In [2]:
# Encoder hyper-parameters. The up_ratios multiply to 2*4*5*5 = 200, presumably the
# overall hop size (the checkpoint directory loaded later is named "200hopsize").
encoder_kwargs = dict(
    d_model=96,
    up_ratios=[2, 4, 5, 5],
    out_channels=256,
)

# Decoder hyper-parameters: its up_ratios are the encoder's in reverse order and
# in_channels equals the encoder's out_channels.
decoder_kwargs = dict(
    in_channels=256,
    up_ratios=[5, 5, 4, 2],
    num_quantizers=8,
    codebook_size=1024,
    codebook_dim=8,
    quantizer_type="fvq",
    use_l2_normlize=True,  # spelling kept exactly as the keyword is written in this notebook
    use_vocos=True,
    vocos_dim=512,
    vocos_intermediate_dim=4096,
    vocos_num_layers=24,
)

encoder = CodecEncoder(**encoder_kwargs)
decoder = CodecDecoder(**decoder_kwargs)

In [3]:
def count_parameters(model):
    """Return the total number of elements over the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        # frozen parameters (requires_grad=False) are not counted
        if param.requires_grad:
            total += param.numel()
    return total
# Report the trainable parameter counts of both halves, in millions.
encoder_params_m = count_parameters(encoder) / 1e6
decoder_params_m = count_parameters(decoder) / 1e6
print("number of parameters of the encoder: {}M".format(encoder_params_m))
print("number of parameters of the decoder: {}M".format(decoder_params_m))

number of parameters of the encoder: 35.425472M
number of parameters of the decoder: 102.343074M


In [4]:
encoder_path = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_8layer_1024_fvq_8d_w_l2norm_vocos/encoder.bin"
decoder_path = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_8layer_1024_fvq_8d_w_l2norm_vocos/decoder.bin"

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the models in this notebook are built and run on the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
encoder_state = torch.load(encoder_path, map_location="cpu")
decoder_state = torch.load(decoder_path, map_location="cpu")

# Inference only: switch both modules out of training mode.
encoder.eval()
decoder.eval()

encoder.load_state_dict(encoder_state)
# Kept as the last expression so the cell still displays the key-matching result.
decoder.load_state_dict(decoder_state)

<All keys matched successfully>

In [5]:
# Load the test utterance at 16 kHz and render an audio player for the original signal.
test_wav_path = "/mnt/data2/wangyuancheng/tts_data/kss/2/2_0000.wav"
target_sr = 16000
wav, sr = librosa.load(test_wav_path, sr=target_sr)
Audio(wav, rate=sr)

In [6]:
# Encode the waveform to a latent sequence and quantize it.
# This is inference only, so autograd is disabled to avoid building a
# computation graph over the whole utterance.
audio = torch.from_numpy(wav).unsqueeze(0)  # recorded run: (1, 50156)
print(audio.shape)

with torch.no_grad():
    # encode the audio to latent; a second unsqueeze feeds the encoder a 3-D tensor
    vq_emb = encoder(audio.unsqueeze(0))
    print(vq_emb.shape)

    # With vq=True the decoder call returns five values; only the quantized
    # latent and the code indices are used in this notebook.
    vq_post_emb, vq_id, _, _, _ = decoder(
        vq_emb, vq=True, eval_vq=True
    )

# latent after vq
print(vq_post_emb.shape)
# vq id (recorded run: (8, 1, 251) — presumably (num_quantizers, batch, frames))
print(vq_id.shape)

torch.Size([1, 50156])
torch.Size([1, 256, 251])
torch.Size([1, 256, 251])
torch.Size([8, 1, 251])


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Decode the QUANTIZED latent back to a waveform. Decoding vq_emb (the encoder
# output before quantization) would bypass the codebooks, so the playback would
# not be the codec reconstruction.
with torch.no_grad():  # inference only, no autograd graph needed
    recovered = decoder(vq_post_emb, vq=False)
print(recovered.shape)

# Drop the two leading singleton dims (recorded run: (1, 1, 50200)) to get a
# 1-D array for playback. The tensor and the array get separate names.
recovered_audio = recovered.squeeze(0).squeeze(0).numpy()
Audio(recovered_audio, rate=sr)

torch.Size([1, 1, 50200])


# Model Example
Codec with latent dimension 128, codebook (code) dimension 128, 8 codebooks of 1024 entries each, and a Vocos decoder.

In [8]:
# Encoder hyper-parameters; this variant uses a 128-channel latent.
encoder_kwargs = dict(
    d_model=96,
    up_ratios=[2, 4, 5, 5],
    out_channels=128,
)

# Decoder hyper-parameters: in_channels equals the encoder's out_channels, and
# codebook_dim equals the latent width here.
decoder_kwargs = dict(
    in_channels=128,
    up_ratios=[5, 5, 4, 2],
    num_quantizers=8,
    codebook_size=1024,
    codebook_dim=128,
    quantizer_type="fvq",
    use_l2_normlize=True,  # spelling kept exactly as the keyword is written in this notebook
    use_vocos=True,
    vocos_dim=512,
    vocos_intermediate_dim=4096,
    vocos_num_layers=24,
)

encoder = CodecEncoder(**encoder_kwargs)
decoder = CodecDecoder(**decoder_kwargs)

In [9]:
def count_parameters(model):
    """Return the total number of elements over the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        # frozen parameters (requires_grad=False) are not counted
        if param.requires_grad:
            total += param.numel()
    return total
# Report the trainable parameter counts of both halves, in millions.
encoder_params_m = count_parameters(encoder) / 1e6
decoder_params_m = count_parameters(decoder) / 1e6
print("number of parameters of the encoder: {}M".format(encoder_params_m))
print("number of parameters of the decoder: {}M".format(decoder_params_m))

number of parameters of the encoder: 34.835392M
number of parameters of the decoder: 102.83037M


In [10]:
encoder_path = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_8layer_1024_fvq_128d_w_l2norm_vocos/encoder.bin"
decoder_path = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_8layer_1024_fvq_128d_w_l2norm_vocos/decoder.bin"

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the models in this notebook are built and run on the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
encoder_state = torch.load(encoder_path, map_location="cpu")
decoder_state = torch.load(decoder_path, map_location="cpu")

# Inference only: switch both modules out of training mode.
encoder.eval()
decoder.eval()

encoder.load_state_dict(encoder_state)
# Kept as the last expression so the cell still displays the key-matching result.
decoder.load_state_dict(decoder_state)

<All keys matched successfully>

In [11]:
# Load the test utterance at 16 kHz and render an audio player for the original signal.
test_wav_path = "/mnt/data2/wangyuancheng/tts_data/kss/2/2_0000.wav"
target_sr = 16000
wav, sr = librosa.load(test_wav_path, sr=target_sr)
Audio(wav, rate=sr)

In [12]:
# Encode the waveform to a latent sequence and quantize it.
# This is inference only, so autograd is disabled to avoid building a
# computation graph over the whole utterance.
audio = torch.from_numpy(wav).unsqueeze(0)  # recorded run: (1, 50156)
print(audio.shape)

with torch.no_grad():
    # encode the audio to latent; a second unsqueeze feeds the encoder a 3-D tensor
    vq_emb = encoder(audio.unsqueeze(0))
    print(vq_emb.shape)

    # With vq=True the decoder call returns five values; only the quantized
    # latent and the code indices are used in this notebook.
    vq_post_emb, vq_id, _, _, _ = decoder(
        vq_emb, vq=True, eval_vq=True
    )

# latent after vq
print(vq_post_emb.shape)
# vq id (recorded run: (8, 1, 251) — presumably (num_quantizers, batch, frames))
print(vq_id.shape)

torch.Size([1, 50156])
torch.Size([1, 128, 251])
torch.Size([1, 128, 251])
torch.Size([8, 1, 251])


In [13]:
# Decode the QUANTIZED latent back to a waveform. Decoding vq_emb (the encoder
# output before quantization) would bypass the codebooks, so the playback would
# not be the codec reconstruction.
with torch.no_grad():  # inference only, no autograd graph needed
    recovered = decoder(vq_post_emb, vq=False)
print(recovered.shape)

# Drop the two leading singleton dims (recorded run: (1, 1, 50200)) to get a
# 1-D array for playback. The tensor and the array get separate names.
recovered_audio = recovered.squeeze(0).squeeze(0).numpy()
Audio(recovered_audio, rate=sr)

torch.Size([1, 1, 50200])


# Model Example
A small codec model similar to EnCodec/SoundStream, with latent dimension 256, codebook (code) dimension 256, and 12 codebooks of 1024 entries each (plain VQ, no Vocos decoder).

In [14]:
# Encoder hyper-parameters for the small variant (d_model=32 instead of 96).
encoder_kwargs = dict(
    d_model=32,
    up_ratios=[2, 4, 5, 5],
    out_channels=256,
)

# Decoder hyper-parameters: plain "vq" quantizer without L2 normalisation, and
# the Vocos decoder path is turned off.
decoder_kwargs = dict(
    in_channels=256,
    up_ratios=[5, 5, 4, 2],
    upsample_initial_channel=512,
    num_quantizers=12,
    codebook_size=1024,
    codebook_dim=256,
    quantizer_type="vq",
    use_l2_normlize=False,  # spelling kept exactly as the keyword is written in this notebook
    use_vocos=False,
    commitment=0.15,
)

encoder = CodecEncoder(**encoder_kwargs)
decoder = CodecDecoder(**decoder_kwargs)

In [15]:
def count_parameters(model):
    """Return the total number of elements over the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        # frozen parameters (requires_grad=False) are not counted
        if param.requires_grad:
            total += param.numel()
    return total
# Report the trainable parameter counts of both halves, in millions.
encoder_params_m = count_parameters(encoder) / 1e6
decoder_params_m = count_parameters(decoder) / 1e6
print("number of parameters of the encoder: {}M".format(encoder_params_m))
print("number of parameters of the decoder: {}M".format(decoder_params_m))

number of parameters of the encoder: 4.206656M
number of parameters of the decoder: 4.730914M


In [16]:
encoder_path = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_12layer_1024_vq_wo_l2norm_wo_codebook_loss_drop_0_0_commit_0_15_small/encoder.bin"
decoder_path = "/mnt/data2/wangyuancheng/model_ckpts/codec/codec_16k_200hopsize_12layer_1024_vq_wo_l2norm_wo_codebook_loss_drop_0_0_commit_0_15_small/decoder.bin"

# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; the models in this notebook are built and run on the CPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
encoder_state = torch.load(encoder_path, map_location="cpu")
decoder_state = torch.load(decoder_path, map_location="cpu")

# Inference only: switch both modules out of training mode.
encoder.eval()
decoder.eval()

encoder.load_state_dict(encoder_state)
# Kept as the last expression so the cell still displays the key-matching result.
decoder.load_state_dict(decoder_state)

<All keys matched successfully>

In [17]:
# Load the test utterance at 16 kHz and render an audio player for the original signal.
test_wav_path = "/mnt/data2/wangyuancheng/tts_data/kss/2/2_0000.wav"
target_sr = 16000
wav, sr = librosa.load(test_wav_path, sr=target_sr)
Audio(wav, rate=sr)

In [18]:
# Encode the waveform to a latent sequence and quantize it.
# This is inference only, so autograd is disabled to avoid building a
# computation graph over the whole utterance.
audio = torch.from_numpy(wav).unsqueeze(0)  # recorded run: (1, 50156)
print(audio.shape)

with torch.no_grad():
    # encode the audio to latent; a second unsqueeze feeds the encoder a 3-D tensor
    vq_emb = encoder(audio.unsqueeze(0))
    print(vq_emb.shape)

    # With vq=True the decoder call returns five values; only the quantized
    # latent and the code indices are used in this notebook.
    vq_post_emb, vq_id, _, _, _ = decoder(
        vq_emb, vq=True, eval_vq=True
    )

# latent after vq
print(vq_post_emb.shape)
# vq id (recorded run: (12, 1, 251) — presumably (num_quantizers, batch, frames))
print(vq_id.shape)

torch.Size([1, 50156])
torch.Size([1, 256, 251])
torch.Size([1, 256, 251])
torch.Size([12, 1, 251])


In [19]:
# Decode the QUANTIZED latent back to a waveform. Decoding vq_emb (the encoder
# output before quantization) would bypass the codebooks, so the playback would
# not be the codec reconstruction.
with torch.no_grad():  # inference only, no autograd graph needed
    recovered = decoder(vq_post_emb, vq=False)
print(recovered.shape)

# Drop the two leading singleton dims (recorded run: (1, 1, 50200)) to get a
# 1-D array for playback. The tensor and the array get separate names.
recovered_audio = recovered.squeeze(0).squeeze(0).numpy()
Audio(recovered_audio, rate=sr)

torch.Size([1, 1, 50200])
