In [1]:
import json
import itertools
import random
import glob
from IPython.display import Audio

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DistributedSampler, DataLoader
from torch.utils.tensorboard import SummaryWriter

In [3]:
from utils.utils import pad_or_trim, AttrDict, get_logger, save_checkpoint,load_checkpoint
from utils.hifigan_utils import plot_spectrogram

In [4]:
from models.encoder import LF0Encoder, PPGEncoder, PhonemeEncoder
from models.decoder import MelDecoder
from models.vocoder.diffwave_models import DiffWave
from models.vocoder.hifigan_models import Generator
from models.vocoder.cfm_models import CFMWave

from models.svc import GANSVC, DIFFSVC, CFMSVC

from datasets.dataset_diff import SVCDIFFDataset
from datasets.dataset_gan import SVCGANDataset

In [5]:
import logging
import torchaudio
import os
import time
import pickle
import numpy as np
from math import sqrt

from IPython.display import Audio
%matplotlib inline

import matplotlib.pyplot as plt
from config import (
    data_path, TARGET_SAMPLE_RATE, root_path, 
    WHISPER_DIM, HUBERT_DIM, MCEP_DIM,
    MAX_MEL_LENGTH, MAX_AUDIO_LENGTH
)

In [6]:
def count_parameters(model):
    """Return the number of trainable scalar parameters in ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set; frozen parameters are
    ignored. Returns 0 when the model has no trainable parameters.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

In [48]:
# Directory that holds one sub-folder per training run.
experiment_path = os.path.join(root_path, "experiment")

# Vocoder family to evaluate; compared case-insensitively against
# "gan", "diff" and "cfm" by the cells below.
model_name = "cfm"

# Prefer the second GPU (index 1) when the host has one, otherwise fall back
# to GPU 0, and to the CPU when CUDA is unavailable. Hard-coding "cuda:1"
# breaks on single-GPU machines: is_available() is True there, but the device
# ordinal is invalid as soon as a module is moved onto it.
if torch.cuda.is_available():
    gpu_index = 1 if torch.cuda.device_count() > 1 else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

# Output widths of the three feature encoders; the same values are passed to
# the mel decoder as its per-stream input widths.
lf0_out_dims = 8
ppg_out_dims = 128
pho_out_dims = 128

In [49]:
# One encoder per conditioning stream. Each is moved to `device` right after
# construction; the construction order is kept as-is.
lf0_encoder = LF0Encoder(
    in_dims=1,
    out_dims=lf0_out_dims,
).to(device)
ppg_encoder = PPGEncoder(
    output_dims=ppg_out_dims,
    d_model=WHISPER_DIM,
).to(device)
pho_encoder = PhonemeEncoder(
    output_dims=pho_out_dims,
    d_model=HUBERT_DIM,
).to(device)

# Decoder that receives the three encoder output widths and emits
# MCEP_DIM-sized features.
mel_decoder = MelDecoder(
    lf0_dims=lf0_out_dims,
    ppg_dims=ppg_out_dims,
    phoneme_dims=pho_out_dims,
    output_dims=MCEP_DIM,
).to(device)

In [50]:
def _find_generator_checkpoint(run_name):
    """Return the path of a generator checkpoint (``g*``) for one training run.

    Looks in ``<experiment_path>/<run_name>/models``. Matches are sorted so the
    choice is deterministic, and the lexicographically last one is returned
    (presumably the most recent step if file names carry zero-padded step
    counters -- confirm against the training script).

    Raises:
        FileNotFoundError: if no file matches, instead of the bare IndexError
            that indexing an empty glob result would produce.
    """
    pattern = os.path.join(experiment_path, run_name, "models", "g*")
    candidates = sorted(glob.glob(pattern))
    if not candidates:
        raise FileNotFoundError(f"No generator checkpoint matches {pattern!r}")
    return candidates[-1]


def _waveform_vocoder_hparams():
    """Build the hyper-parameter set shared by the "diff" and "cfm" branches.

    The two branches used byte-identical settings, so they are defined once
    here. The result is passed to the vocoder constructor and also supplies
    the STFT/mel settings of the test dataset.
    """
    return AttrDict(
        # Training params
        batch_size=16,
        learning_rate=2e-4,
        max_grad_norm=None,

        # Data params
        sample_rate=22050,
        n_mels=80,
        n_fft=1024,
        hop_samples=256,
        crop_mel_frames=62,  # Probably an error in paper.

        # Model params
        residual_layers=30,
        residual_channels=64,
        dilation_cycle_length=10,
        unconditional=False,
        noise_schedule=np.linspace(1e-4, 0.05, 50).tolist(),
        inference_noise_schedule=[0.0001, 0.001, 0.01, 0.05, 0.2, 0.5],

        # unconditional sample len
        audio_len=22050 * 5,  # unconditional_synthesis_samples
    )


def _waveform_testset(hparams):
    """Create the M4Singer test split used by the "diff" and "cfm" branches."""
    return SVCDIFFDataset(
        dataset="M4Singer",
        dataset_type="test",
        args=None,
        n_fft=hparams.n_fft,
        num_mels=hparams.n_mels,
        hop_size=hparams.hop_samples,
        win_size=hparams.hop_samples * 4,
        sampling_rate=TARGET_SAMPLE_RATE,
        fmin=20,
        fmax=TARGET_SAMPLE_RATE / 2.0,
        mel_crop_length=MAX_MEL_LENGTH,
        audio_crop_length=MAX_AUDIO_LENGTH,
    )


# Assemble the full conversion model (`svc`), locate its checkpoint
# (`model_path`), and build the matching test set (`testset`) for the
# selected vocoder family.
model_kind = model_name.lower()

if model_kind == "gan":
    model_path = _find_generator_checkpoint("HIFI_GAN_20230416_234705")

    # HiFi-GAN settings live in a JSON file rather than inline.
    with open("configs/hifigan_config.json", encoding="utf-8") as f:
        json_config = json.load(f)
    h = AttrDict(json_config)

    generator = Generator(h).to(device)

    svc = GANSVC(
        lf0_encoder, ppg_encoder, pho_encoder,
        decoder=mel_decoder, vocoder=generator,
    ).to(device)

    testset = SVCGANDataset(
        dataset="M4Singer",
        dataset_type="test",
        args=None,
        n_fft=h.n_fft,
        num_mels=h.num_mels,
        hop_size=h.hop_size,
        win_size=h.win_size,
        sampling_rate=TARGET_SAMPLE_RATE,
        fmin=h.fmin,
        fmax=h.fmax,
        fmax_loss=h.fmax_loss,
        mel_crop_length=MAX_MEL_LENGTH,
    )
elif model_kind == "diff":
    model_path = _find_generator_checkpoint("DIFFWAVE_20230423_214609")
    h = _waveform_vocoder_hparams()

    diffwave = DiffWave(h).to(device)

    svc = DIFFSVC(
        lf0_encoder, ppg_encoder, pho_encoder,
        decoder=mel_decoder, vocoder=diffwave,
        noise_schedule=h.noise_schedule,
        inference_noise_schedule=h.inference_noise_schedule,
        audio_len=MAX_AUDIO_LENGTH,
    ).to(device)

    testset = _waveform_testset(h)
elif model_kind == "cfm":
    model_path = _find_generator_checkpoint("CFMWAVE_20230506_165603")
    h = _waveform_vocoder_hparams()

    cfmwave = CFMWave(h).to(device)

    svc = CFMSVC(
        lf0_encoder, ppg_encoder, pho_encoder,
        decoder=mel_decoder, vocoder=cfmwave,
        audio_len=MAX_AUDIO_LENGTH,
    ).to(device)

    testset = _waveform_testset(h)
else:
    # Fail here with a clear message rather than with a NameError on `svc`
    # or `model_path` in a later cell.
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected 'gan', 'diff' or 'cfm'"
    )

In [51]:
# Read the checkpoint file selected for the chosen model.
ckpt = load_checkpoint(model_path, device)

# This notebook only runs inference, so put every sub-module into evaluation
# mode (disables dropout and freezes normalisation statistics where present).
svc.eval()

# Kept as the last expression so the cell still displays the key-matching
# report returned by load_state_dict.
svc.load_state_dict(ckpt['svc'])

<All keys matched successfully>

In [43]:
# Display the number of trainable parameters of the assembled `svc` model.
count_parameters(svc)

36739706

In [12]:
# Fixed pair of test-set items so that repeated runs convert the same clips.
# For a random pair, draw each index with random.choice(range(len(testset))).
source_idx, target_idx = 70, 160

In [52]:
# Fetch the source and target items. The GAN dataset yields eight fields per
# item (the third one is not used here); the other dataset yields seven.
if model_name.lower() != 'gan':
    (
        mel_s, wav_source,
        mel_mask_source, wav_mask_source,
        lf0_source, ppg_source, pho_source,
    ) = testset[source_idx]

    (
        mel_t, wav_target,
        mel_mask_target, wav_mask_target,
        lf0_target, ppg_target, pho_target,
    ) = testset[target_idx]
else:
    (
        mel_s, wav_source, _,
        mel_mask_source, wav_mask_source,
        lf0_source, ppg_source, pho_source,
    ) = testset[source_idx]

    (
        mel_t, wav_target, _,
        mel_mask_target, wav_mask_target,
        lf0_target, ppg_target, pho_target,
    ) = testset[target_idx]

In [14]:
%%timeit
if model_name.lower() == 'gan':
    wav_conversion, mel_conversion =  svc(lf0_source.unsqueeze(0).unsqueeze(0).to(device),
        ppg_target.unsqueeze(0).to(device),
        pho_source.unsqueeze(0).to(device)) 

53.2 ns ± 0.561 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)


In [38]:
%%timeit
if model_name.lower() == 'diff':
    wav_conversion =  svc.inference(lf0_source.unsqueeze(0).unsqueeze(0).to(device),
        ppg_target.unsqueeze(0).to(device),
        pho_source.unsqueeze(0).to(device),fast_sampling=False) 

7min 12s ± 4.9 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [56]:
%%timeit
if model_name.lower() == 'cfm':
    wav_conversion =  svc.inference(lf0_source.unsqueeze(0).unsqueeze(0).to(device),
        ppg_target.unsqueeze(0).to(device),
        pho_source.unsqueeze(0).to(device),steps=100)

2.51 s ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
# Check the dimensions of the converted waveform before it is written to disk.
wav_conversion.shape

torch.Size([1, 256000])

In [23]:
def _write_clip(filename, waveform):
    """Write ``waveform`` to ``filename`` at TARGET_SAMPLE_RATE."""
    torchaudio.save(filename, waveform, sample_rate=TARGET_SAMPLE_RATE)


# Source and target clips get a leading axis added before saving; the
# converted clip is moved to the CPU first.
_write_clip("source.wav", wav_source.unsqueeze(0))
_write_clip("target.wav", wav_target.unsqueeze(0))
_write_clip("convert.wav", wav_conversion.cpu())

In [27]:
# Inline player for the converted clip saved to convert.wav by the cell above.
Audio("convert.wav")