In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import os
from IPython.display import Audio
import matplotlib.pyplot as plt
import soundfile as sf

In [2]:
from models.tts.gpt_tts.gpt_tts import GPTTTS
from models.tts.gpt_tts.g2p_old_en import process, PHPONE2ID
from g2p_en import G2p
from models.codec.codec_latent.codec_latent import LatentCodecEncoder, LatentCodecDecoderWithTimbre
from models.codec.amphion_codec.codec import CodecEncoder, CodecDecoder
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from utils.util import load_config

In [3]:
# Hyperparameters for NaturalSpeech2 and its codec, read from the recipe's JSON config.
model_config_path = "egs/tts/NaturalSpeech2/model_config.json"
cfg = load_config(model_config_path)
print(cfg)

{'model': {'reference_encoder': {'encoder_layer': 6, 'encoder_hidden': 512, 'encoder_head': 8, 'conv_filter_size': 2048, 'conv_kernel_size': 9, 'encoder_dropout': 0.2, 'use_skip_connection': False, 'use_new_ffn': True, 'ref_in_dim': 256, 'ref_out_dim': 512, 'use_query_emb': True, 'num_query_emb': 32}, 'diffusion': {'diffusion_type': 'diffusion', 'beta_min': 0.05, 'beta_max': 20, 'sigma': 1.0, 'noise_factor': 1.0, 'ode_solve_method': 'euler', 'diff_model_type': 'WaveNet', 'diff_wavenet': {'input_size': 256, 'hidden_size': 512, 'out_size': 256, 'num_layers': 40, 'cross_attn_per_layer': 3, 'dilation_cycle': 2, 'attn_head': 8, 'drop_out': 0.2}}, 'prior_encoder': {'encoder_layer': 6, 'encoder_hidden': 512, 'encoder_head': 8, 'conv_filter_size': 2048, 'conv_kernel_size': 9, 'encoder_dropout': 0.2, 'use_skip_connection': False, 'use_new_ffn': True, 'vocab_size': 256, 'cond_dim': 512, 'duration_predictor': {'input_size': 512, 'filter_size': 512, 'kernel_size': 3, 'conv_layers': 30, 'cross_attn

In [4]:
# Build the codec encoder/decoder from the `model.codec` config section.
# Weights are loaded in the following cells.
codec_cfg = cfg.model.codec
codec_enc = CodecEncoder(cfg=codec_cfg.encoder)
codec_dec = CodecDecoder(cfg=codec_cfg.decoder)

In [5]:
# One-off conversion: pull the codec encoder/decoder weights out of a training
# checkpoint and re-export them as standalone state_dict files under ckpts/ns2/.
#
# The default is an absolute cluster path; set CODEC_CKPT_PATH to point at a
# local copy of the checkpoint when running elsewhere.
pretrained_model_path = os.environ.get(
    "CODEC_CKPT_PATH",
    "/blob/v-yuancwang/codec_ckpt/codec_yc/codec_16k_200hopsize_12layer_1024_vq_wo_l2norm_wo_codebook_loss_drop_0_0_commit_0_15_small/checkpoint-32000.pt",
)
# NOTE(review): this is a full pickle load (no weights_only=True) because a
# training checkpoint may contain non-tensor objects. Only load trusted files:
# unpickling can execute arbitrary code.
checkpoint = torch.load(pretrained_model_path, map_location="cpu")
codec_enc.load_state_dict(checkpoint["model"]["CodecEnc"])
codec_dec.load_state_dict(checkpoint["model"]["generator"])
# torch.save does not create missing directories; make sure the export dir exists.
os.makedirs("ckpts/ns2", exist_ok=True)
torch.save(codec_enc.state_dict(), "ckpts/ns2/codec_enc.bin")
torch.save(codec_dec.state_dict(), "ckpts/ns2/codec_dec.bin")

In [None]:
# Reload the standalone codec weights written by the conversion cell.
# map_location="cpu" makes loading independent of the device the tensors were
# saved from (matching the checkpoint load above); the modules move to the GPU
# afterwards. weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state_dict needs, and avoids executing arbitrary
# pickled code.
codec_enc.load_state_dict(
    torch.load("ckpts/ns2/codec_enc.bin", map_location="cpu", weights_only=True)
)
codec_dec.load_state_dict(
    torch.load("ckpts/ns2/codec_dec.bin", map_location="cpu", weights_only=True)
)

# Inference only: switch to eval mode, move to the GPU, and freeze all parameters.
for codec_module in (codec_enc, codec_dec):
    codec_module.eval()
    codec_module.cuda()
    codec_module.requires_grad_(False)

In [7]:
# Instantiate NaturalSpeech2 from the `model` section of the config.
# Pretrained weights are loaded in the next cell.
ns2_model = NaturalSpeech2(cfg=cfg.model)

In [8]:
# Load the NaturalSpeech2 state_dict onto the CPU first, as the checkpoint-conversion
# cell does, so loading does not depend on the device it was saved from. Then move
# it to the GPU and freeze it for inference. weights_only=True limits unpickling to
# tensors/containers, which is sufficient for a state_dict and avoids running
# arbitrary pickled code.
ns2_model.load_state_dict(
    torch.load("ckpts/ns2/ns2_model.bin", map_location="cpu", weights_only=True)
)
ns2_model.eval()
ns2_model.cuda()
ns2_model.requires_grad_(False)

NaturalSpeech2(
  (reference_encoder): ReferenceEncoder(
    (in_linear): Linear(in_features=256, out_features=512, bias=True)
    (transformer_encoder): TransformerEncoder(
      (position_emb): PositionalEncoding()
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (ffn): TransformerFFNLayer(
            (ffn_1): Conv1d(512, 2048, kernel_size=(9,), stride=(1,), padding=(4,))
            (ffn_2): Linear(in_features=2048, out_features=512, bias=True)
          )
        )
      )
      (last_ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (query_embs): Embedding(32, 512)
    (query_attn): MultiheadAttention(
      (out_proj): NonDynamicall

In [9]:
# Grapheme-to-phoneme converter from g2p_en; passed to `process` in the next cell
# to phonemize the target text.
g2p = G2p()

In [10]:
# Sentence to synthesize.
target_text = "Come, come returned Hawkeye, uncasing his honest countenance, the better to assure the wavering confidence of his companion. You may see a skin which, if it be not as white as one of the gentle ones, has no tinge of red to it that the winds of the heaven and the sun have not bestowed. Now, let us to business."
# `process` (project helper) phonemizes the text with `g2p`; element [1] of each
# entry in `txt_struct` holds that entry's phones. `txt` is not used below.
txt_struct, txt = process(target_text, g2p)

# Flatten every entry's phones into one sequence, then drop the first and last
# phone (presumably boundary tokens added by `process` -- TODO confirm).
all_phones = []
for word_entry in txt_struct:
    all_phones.extend(word_entry[1])
phone_seq = all_phones[1:-1]

# Map each phone symbol to its integer id.
phone_ids = [PHPONE2ID[phone] for phone in phone_seq]

In [11]:
# Reference utterance for the prompt, loaded (and resampled if needed) at 16 kHz.
ref_wa_path = "examples/ref/1.wav"
ref_wav, sr = librosa.load(ref_wa_path, sr=16000)

# numpy (T,) -> float32 GPU tensor of shape (1, T).
ref_wav = torch.from_numpy(ref_wav).float().cuda().unsqueeze(0)

# Insert a singleton axis at dim 1 -> (1, 1, T) before encoding
# (presumably the channel axis expected by CodecEncoder -- verify).
ref_latent = codec_enc(ref_wav[:, None, :])

In [12]:
# Swap axes 1 and 2 of the codec latent; the printed shape below shows the
# result is (1, frames, 256).
# NOTE(review): this rebinds `ref_latent`, so re-running this cell on its own
# transposes it back. Re-run the reference-encoding cell first.
ref_latent = ref_latent.transpose(1, 2)

# All-ones mask over (batch, frames): every reference frame is attended to.
batch_size, num_frames = ref_latent.shape[:2]
ref_mask = torch.ones(batch_size, num_frames, device="cuda")
print(ref_latent.shape, ref_mask.shape)

torch.Size([1, 272, 256]) torch.Size([1, 272])


In [13]:
# Turn the phone-id list into a (1, num_phones) int64 tensor on the GPU.
# torch.as_tensor + reshape(1, -1) keeps this cell idempotent. Re-running it on an
# already-converted (1, N) tensor is a no-op, whereas torch.tensor(...).unsqueeze(0)
# re-copies the tensor (with a warning) and yields (1, 1, N).
# dtype=torch.long also keeps an empty phone list integer-typed; torch.tensor([])
# would default to float32.
phone_ids = torch.as_tensor(phone_ids, dtype=torch.long).reshape(1, -1).cuda()
print(phone_ids.shape)

torch.Size([1, 261])


In [None]:
# Sampling settings for NaturalSpeech2 inference.
INFERENCE_STEPS = 200
SIGMA = 1.2  # passed through as `sigma`; presumably the sampling noise scale -- verify
INFERENCE_SEED = 0

# Diffusion sampling is presumably stochastic (random initial noise). Seed the
# torch RNGs, including CUDA, so that Restart & Run All reproduces the same latent.
torch.manual_seed(INFERENCE_SEED)

x0, prior_out = ns2_model.inference(
    phone_id=phone_ids,
    x_ref=ref_latent,
    x_ref_mask=ref_mask,
    inference_steps=INFERENCE_STEPS,
    sigma=SIGMA,
)
print(x0.shape)

In [15]:
# Decode the generated latent and the reference latent back to waveforms.
# Swapping axes 1 and 2 undoes the earlier transpose, restoring the layout that
# codec_enc produced.
recon_wav, recon_ref_wav = (
    codec_dec(latent.transpose(2, 1), vq=False) for latent in (x0, ref_latent)
)

In [16]:
# Listen to the synthesized speech at 16 kHz, the same rate used to load the reference.
synth_wav_np = recon_wav.squeeze().cpu().numpy()
Audio(data=synth_wav_np, rate=16000)