In [1]:
from CookieTTS.utils.audio.stft import TacotronSTFT, STFT
from CookieTTS.utils.dataset.utils import load_wav_to_torch, DTW

import sys
sys.path.append('../_4_mtw/waveglow') # add WaveGlow to System path for easier importing
import os
from glob import glob
import shutil

import numpy as np
from scipy import signal
from scipy.io.wavfile import write
import torch

import matplotlib
%matplotlib inline
import matplotlib.pylab as plt
import IPython.display as ipd

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit


# 1 - Initialize WaveGlow and Load Checkpoint/Weights

In [2]:
# Load WaveGlow
def load_waveglow(
    waveglow_path=r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_HDN\best_val_model",
    config_fpath=r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_HDN\config.json",):
    """Build a WaveGlow/WaveFlow vocoder plus its Denoiser from a training config and checkpoint.

    Args:
        waveglow_path: Checkpoint file; must contain the keys 'model' (state dict),
            'iteration' and 'speaker_lookup'.
        config_fpath: Training JSON config with 'train_config', 'data_config'
            and 'waveglow_config' sections.

    Returns:
        Tuple ``(waveglow, denoiser, speaker_lookup, training_sigma,
        waveglow_iters, waveglow_config, data_config)``. ``waveglow`` is on the
        GPU in eval mode; ``data_config`` has defaults filled in for options
        that older configs do not define.
    """
    import json

    def is_ax(config):
        """Return True if the model uses the Ax WaveGlow core (its config defines 'upsample_first')."""
        return 'upsample_first' in config

    # Load config file
    with open(config_fpath) as f:
        config = json.load(f)
    train_config = config["train_config"]
    data_config = config["data_config"]
    # Older configs predate these options; fall back to their "disabled" values.
    data_config.setdefault('preempthasis', 0.0)
    data_config.setdefault('use_logvar_channels', False)
    data_config.setdefault('load_hidden_from_disk', False)
    data_config.setdefault('iso226_empthasis', False)
    # The vocoder's own setting decides the mel channel count (160 if unspecified).
    data_config['n_mel_channels'] = config["waveglow_config"].get('n_mel_channels', 160)
    waveglow_config = {
        **config["waveglow_config"],
        'win_length': data_config['win_length'],
        'hop_length': data_config['hop_length'],
        'preempthasis': data_config['preempthasis'],
        'n_mel_channels': data_config["n_mel_channels"],
        'use_logvar_channels': data_config["use_logvar_channels"],
        'load_hidden_from_disk': data_config["load_hidden_from_disk"],
        'iso226_empthasis': data_config["iso226_empthasis"]
    }
    print(waveglow_config)
    print(f"Config File from '{config_fpath}' successfully loaded.")

    # import the correct model core
    if is_ax(waveglow_config):
        from CookieTTS._4_mtw.waveglow.efficient_model_ax import WaveGlow
    elif waveglow_config["yoyo"]:
        from CookieTTS._4_mtw.waveglow.efficient_model import WaveGlow
    else:
        from CookieTTS._4_mtw.waveglow.glow import WaveGlow
    from CookieTTS._4_mtw.waveglow.denoiser import Denoiser

    # initialize model
    print(f"intializing WaveGlow model... ", end="")
    waveglow = WaveGlow(**waveglow_config).cuda()
    print(f"Done!")

    # load checkpoint from file
    print(f"loading WaveGlow checkpoint... ", end="")
    # Deserialise onto the CPU. Without map_location, tensors saved from the GPU
    # are restored onto it, adding the whole checkpoint to peak VRAM on top of
    # the model just built (this notebook has hit "CUDA error: out of memory").
    # load_state_dict copies the weights into the CUDA parameters regardless.
    checkpoint = torch.load(waveglow_path, map_location='cpu')
    waveglow.load_state_dict(checkpoint['model'])  # overwrite initialized weights with checkpointed weights
    waveglow.eval()  # already on the GPU from construction above
    # Optional inference tweaks, currently disabled:
    #waveglow.half()
    #waveglow.remove_weightnorm()
    waveglow_iters = checkpoint['iteration']
    speaker_lookup = checkpoint['speaker_lookup']  # ids lookup
    del checkpoint  # nothing else in the checkpoint is needed; free it now
    print(f"Done!")

    print(f"initializing Denoiser... ", end="")
    # Conditioning width doubles when log-variance channels are appended to the mel.
    cond_channels = waveglow_config['n_mel_channels']*(waveglow_config['use_logvar_channels']+1)
    denoiser = Denoiser(waveglow, n_mel_channels=cond_channels, mu=0.0, var=1.0, stft_device='cpu', speaker_dependant=False)
    print(f"Done!")
    print(f"WaveGlow trained for {waveglow_iters} iterations")
    training_sigma = train_config['sigma']
    return waveglow, denoiser, speaker_lookup, training_sigma, waveglow_iters, waveglow_config, data_config

# Load the default checkpoint up front: the next cells read `data_config`
# (for the STFT) and `speaker_lookup` from this call.
(
    waveglow, denoiser, speaker_lookup, training_sigma,
    waveglow_iters, waveglow_config, data_config,
) = load_waveglow(
    config_fpath=r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_DTW2\config.json",
    waveglow_path=r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_DTW2\best_val_model",
)

{'shift_spect': 0.0, 'scale_spect': 1.0, 'preceived_vol_scaling': False, 'waveflow': True, 'channel_mixing': 'permute', 'mix_first': False, 'n_flows': 8, 'n_group': 20, 'n_early_every': 16, 'n_early_size': 2, 'memory_efficient': 0.0, 'spect_scaling': False, 'upsample_mode': 'normal', 'WN_config': {'gated_unit': 'GTU', 'n_layers': 8, 'n_channels': 128, 'kernel_size_w': 7, 'kernel_size_h': 7, 'n_layers_dilations_w': None, 'n_layers_dilations_h': 1, 'speaker_embed_dim': 96, 'rezero': False, 'cond_layers': 3, 'cond_activation_func': 'lrelu', 'cond_out_activation_func': True, 'negative_slope': 0.5, 'cond_hidden_channels': 256, 'cond_kernel_size': 1, 'cond_padding_mode': 'zeros', 'seperable_conv': True, 'res_skip': True, 'merge_res_skip': False, 'upsample_mode': 'linear'}, 'n_mel_channels': 160, 'speaker_embed': 96, 'cond_layers': 5, 'cond_activation_func': 'lrelu', 'negative_slope': 0.25, 'cond_hidden_channels': 512, 'cond_output_channels': 256, 'cond_residual': True, 'cond_res_rezero': Tru

# 2 - Set up STFT to generate spectrograms from audio files

In [3]:
# Setup for generating Spectrograms from Audio files
def load_mel(path):
    """Load a mel spectrogram from an audio file or from a saved ``.npy`` array.

    Args:
        path: A ``.wav``/``.flac`` file (converted with the global ``stft``), or a
            ``.npy`` file holding a precomputed spectrogram.

    Returns:
        A float tensor: ``stft.mel_spectrogram`` of the normalised waveform for
        audio input, or the stored array as float32 for ``.npy`` input.

    Raises:
        ValueError: if the audio sampling rate differs from ``stft.sampling_rate``,
            or if the file extension is not supported.
    """
    if path.endswith(('.wav', '.flac')):
        audio, sampling_rate, max_audio_value = load_wav_to_torch(path)
        if sampling_rate != stft.sampling_rate:
            raise ValueError("{} {} SR doesn't match target {} SR".format(
                path, sampling_rate, stft.sampling_rate))
        # Scale by the max value reported by the loader, then add a leading batch dim.
        audio_norm = (audio / max_audio_value).unsqueeze(0)
        melspec = stft.mel_spectrogram(audio_norm)
    elif path.endswith('.npy'):
        melspec = torch.from_numpy(np.load(path)).float()
    else:
        raise ValueError(f"Unsupported file type for load_mel: '{path}'")
    return melspec

print('Initializing STFT...')
# Arguments are passed positionally, so this key order must be kept.
stft = TacotronSTFT(*[data_config[key] for key in (
    'filter_length', 'hop_length', 'win_length', 'n_mel_channels',
    'sampling_rate', 'mel_fmin', 'mel_fmax',
)])
print('Done!')

Initializing STFT...
Done!


# 3 - Reconstruct Audio from Audio Spectrogram using WaveGlow/Flow

In [4]:
import gc

# Models to render, as parallel lists: checkpoint, its config, output sub-folder
# name, and the glob suffix that selects that model's input spectrograms.
waveglow_paths = [
#    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_AEF4.2_iso226\best_model",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_AEF4.1\best_model",
#    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_AEF\best_val_model",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_DTW2\best_val_model",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_6_Flow_512C_ssvae2_2\best_val_model",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_6_Flow_512C_ssvae2\best_model",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_DTW\best_model",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow\best_val_model_gt3",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_HDN\best_val_model",
]
config_fpaths = [
#    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_AEF4.2_iso226\config.json",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_AEF4.1\config.json",
#    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_AEF\config.json",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_DTW2\config.json",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_6_Flow_512C_ssvae2_2\config.json",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_6_Flow_512C_ssvae2\config.json",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_DTW\config.json",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow\config_original.json",
    r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_HDN\config.json",
]
output_dirnames = [
#    "AR_8_Flow_AEF4.2_iso226",
    "AR_8_Flow_AEF4.1_gt",
#    "AR_8_Flow_AEF_gt",
    "AR_8_Flow_DTW2",
    "AR_6_Flow_512C_ssvae2_2",
    "AR_6_Flow_512C_ssvae2",
    "AR_8_Flow_DTW",
    "AR_8_Flow_gt3",
    "AR_8_Flow_HDN",
]
exts = [
#    '__*.npy',
    '__*.npy',
#    '__*.npy',
    '__*.mel.npy',
    '__*.mel.npy',
    '__*.mel.npy',
    '__*.mel.npy',
    '__*.mel.npy',
    '__*.hdn.npy',
]

folder_paths = [
    r"H:\ClipperDatasetV2\SlicedDialogue\FiM\S2\s2e1",
    r"H:\ClipperDatasetV2\SlicedDialogue\FiM\S2\s2e2",
    r"H:\ClipperDatasetV2\SlicedDialogue\FiM\S4\s4e12",
    r"H:\ClipperDatasetV2\SlicedDialogue\FiM\S5\s5e18",
    r"H:\ClipperDatasetV2\SlicedDialogue\FiM\S9\s9e8",
]
gt_ext = '.npy'  # suffix of the ground-truth spectrogram paired with each input (used by DTW)
sigmas = [0.0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,]
denoise_strengths = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0]

# Speaker ids before mapping. Every checkpoint has its own `speaker_lookup`, so
# they are mapped to internal ids per model inside the loop below.
raw_speaker_ids = [0,]

use_DTW = False  # warp each input towards its ground-truth spectrogram before vocoding

display_audio = False
display_denoised_audio = False

save_outputs = True

output_root = r"D:\Downloads\infer\WaveFlow"
n_models_to_run = 1  # render only the first N models from the lists above; None = all

# For plain '.npy' globs, files ending with any of these suffixes are skipped.
npy_exclude_suffixes = ('.hdn.npy', '.mel.npy', '0.npy', 'gdur.npy', 'genc_out.npy', 'pdur.npy', 'penc_out.npy')


def _split_mel_logvar(spect, n_mel_channels, use_logvar_channels):
    """Split a loaded spectrogram [C, T] into ``(mel, logvar)`` for the vocoder.

    Returns None when C is neither n_mel_channels nor 2*n_mel_channels.
    ``logvar`` is None when the model does not use log-variance channels; when
    it does but the file only holds mels, a constant -4.9 log-variance is used.
    """
    n_channels = spect.shape[0]
    if n_channels == n_mel_channels * 2:
        mel, logvar = spect.chunk(2, dim=0)
        return mel, (logvar if use_logvar_channels else None)
    if n_channels == n_mel_channels:
        logvar = spect.new_full(spect.shape, -4.9) if use_logvar_channels else None
        return spect, logvar
    return None


def _gt_mel_path(audio_path, ext, gt_ext):
    """Path of the ground-truth spectrogram paired with ``audio_path``.

    ``ext`` is a glob such as '__*.mel.npy'; only the literal suffix after its
    last '*' ('.mel.npy') appears in real file names, so that suffix is swapped
    for ``gt_ext``. Paths without that suffix are returned unchanged.
    """
    suffix = ext.rsplit('*', 1)[-1]
    if suffix and audio_path.endswith(suffix):
        return audio_path[:-len(suffix)] + gt_ext
    return audio_path


def _write_wav_int16(path, audio_1d, sampling_rate):
    """Write a float waveform as 16-bit PCM, clamping first so int16 cannot wrap around."""
    pcm = (audio_1d.float().clamp(-1.0, 32767 / 32768) * 2**15).data.cpu().numpy().astype('int16')
    write(path, sampling_rate, pcm)


model_specs = list(zip(waveglow_paths, config_fpaths, output_dirnames, exts))
if n_models_to_run is not None:
    model_specs = model_specs[:n_models_to_run]

for waveglow_path, config_fpath, output_dirname, ext in model_specs:
    # Release the previous model first; otherwise it stays on the GPU while the
    # next one is built (this cell previously died with CUDA out of memory).
    waveglow = denoiser = None
    gc.collect()
    torch.cuda.empty_cache()
    waveglow, denoiser, speaker_lookup, training_sigma, waveglow_iters, waveglow_config, data_config = load_waveglow(
                                         waveglow_path=waveglow_path, config_fpath=config_fpath)

    # Map the requested speakers through *this* checkpoint's lookup table.
    speaker_ids = torch.tensor([speaker_lookup[x] for x in raw_speaker_ids]).cuda().long()

    output_folder = os.path.join(output_root, f"{output_dirname}_{waveglow_iters}")

    audio_paths = [path
                   for folder_path in folder_paths
                   for path in glob(os.path.join(folder_path, '**', f'*{ext}'), recursive=True)]
    if ext in ('.npy', '__*.npy'):
        audio_paths = [x for x in audio_paths if not x.endswith(npy_exclude_suffixes)]
    print(f'Generating Audio from {len(audio_paths)} Files...')
    for audio_path in audio_paths:
        print(f"Audio Path:\n'{audio_path}'")
        split = _split_mel_logvar(load_mel(audio_path).cuda(), waveglow.n_mel_channels,
                                  waveglow_config['use_logvar_channels'])
        if split is None:
            print(f"Saved file has Wrong Shape!\nPath: '{audio_path}'")
            continue
        mel_outputs_postnet, mel_logvars_postnet = split
        if use_DTW and not waveglow_config['load_hidden_from_disk']:
            gt_mel_outputs_postnet = load_mel(_gt_mel_path(audio_path, ext, gt_ext)).cuda()
            mel_outputs_postnet = DTW(mel_outputs_postnet.unsqueeze(0), gt_mel_outputs_postnet.unsqueeze(0), 8, 7).squeeze(0)
        if mel_logvars_postnet is not None:
            mel_outputs_postnet = torch.cat((mel_outputs_postnet, mel_logvars_postnet), dim=0)

        output_path = os.path.join(output_folder, os.path.splitext(os.path.basename(audio_path))[0])
        os.makedirs(output_path, exist_ok=True)
        wav_path = audio_path.replace('.hdn.npy','.wav').replace('.mel.npy','.wav').replace('.npy','.wav')
        shutil.copy(wav_path, os.path.join(output_path, 'Ground Truth.wav'))
        audios = []
        with torch.no_grad():
            for sigma in sigmas:
                with torch.random.fork_rng(devices=[0,]):
                    torch.random.manual_seed(0)# use same Z / random seed during validation so results are more consistent and comparable.
                    audio = waveglow.infer(mel_outputs_postnet.unsqueeze(0), sigma=sigma, speaker_ids=speaker_ids, return_CPU=True).float().clamp(min=-0.999, max=0.999)

                non_finite = ~torch.isfinite(audio)
                if non_finite.any():
                    print('inf or nan found in audio')
                    audio[non_finite] = 0.0
                audios.append(audio)
                if display_audio:
                    ipd.display(ipd.Audio(audio[0].data.cpu().numpy(), rate=data_config['sampling_rate']))
                if save_outputs:
                    save_path = os.path.join(output_path, f'denoise_{0.00:0.2f}_sigma_{sigma:0.2f}.wav')
                    _write_wav_int16(save_path, audio[0], data_config['sampling_rate'])

            for audio, sigma in zip(audios, sigmas):
                for denoise_strength in denoise_strengths:
                    audio_denoised = denoiser(audio, speaker_ids=speaker_ids, strength=denoise_strength)[:, 0]
                    # The input was made finite above, so non-finite output points at the denoiser.
                    assert torch.isfinite(audio_denoised).all(), \
                        f"inf or nan found in denoised audio (sigma={sigma}, strength={denoise_strength})"
                    if display_denoised_audio:
                        ipd.display(ipd.Audio(audio_denoised.cpu().numpy(), rate=data_config['sampling_rate']))
                    if save_outputs:
                        save_path = os.path.join(output_path, f'denoise_{denoise_strength:0.2f}_sigma_{sigma:0.2f}.wav')
                        _write_wav_int16(save_path, audio_denoised[0], data_config['sampling_rate'])
            print('')

{'shift_spect': 0.0, 'scale_spect': 1.0, 'preceived_vol_scaling': False, 'waveflow': True, 'channel_mixing': 'permute', 'mix_first': False, 'n_flows': 8, 'n_group': 20, 'n_early_every': 16, 'n_early_size': 2, 'memory_efficient': 0.0, 'spect_scaling': False, 'upsample_mode': 'normal', 'WN_config': {'gated_unit': 'GTU', 'n_layers': 8, 'n_channels': 256, 'kernel_size_w': 7, 'kernel_size_h': 7, 'n_layers_dilations_w': None, 'n_layers_dilations_h': 1, 'speaker_embed_dim': 0, 'rezero': False, 'transposed_conv_hidden_dim': 256, 'transposed_conv_kernel_size': [2, 3, 5], 'transposed_conv_scales': None, 'cond_layers': 0, 'cond_activation_func': 'lrelu', 'cond_out_activation_func': False, 'negative_slope': 0.5, 'cond_hidden_channels': 256, 'cond_kernel_size': 1, 'cond_padding_mode': 'zeros', 'seperable_conv': True, 'res_skip': True, 'merge_res_skip': False, 'upsample_mode': 'linear'}, 'n_mel_channels': 160, 'speaker_embed': 32, 'cond_layers': 4, 'cond_activation_func': 'lrelu', 'negative_slope': 

RuntimeError: CUDA error: out of memory

# (Testing) Blending GT and Pred Spectrograms

In [None]:
import torch

# Per-channel weight for the ground-truth spectrogram when blending with the
# prediction: 0 up to channel `min_`, rising linearly to 1 at channel `max_`,
# and 1 above it. Channels are numbered from 1 here.
min_, max_ = 110, 120
n_mel_channels = 160
channel_numbers = torch.arange(1, n_mel_channels + 1).float()
ramp = (channel_numbers - min_).clamp(min=0) / (max_ - min_)
gt_perc = ramp.clamp(max=1.0)
print(gt_perc)

# (Testing) Dynamic Time Warping for GTA Alignment

In [None]:
import torch
import numpy as np

# Toy inputs [B, C, T] for a quick DTW sanity check. Seeded so the printed
# losses are reproducible across kernel restarts.
torch.manual_seed(0)
target = torch.rand(1, 2, 700)
pred = torch.rand(1, 2, 700)

@torch.jit.script
def DTW(batch_pred: torch.Tensor, batch_target: torch.Tensor, scale_factor: int, range_: int) -> torch.Tensor:
    """
    Calculates the ideal time-warp for each frame to minimise L1 error against the target.

    Despite the name this is not classic dynamic time warping: every output time
    step is chosen independently (there is no monotonic path constraint), from a
    set of fractionally shifted copies of the prediction.

    NOTE(review): this local definition shadows the `DTW` imported from
    CookieTTS.utils.dataset.utils in the first cell for every cell run after it.

    Params:
        batch_pred: predicted features, [B, C, T].
        batch_target: target features, same shape as batch_pred.
        scale_factor: Scale factor for linear interpolation.
                      Values greater than 1 allow blends of neighbouring frames to be used.
        range_: Odd number of frames around the target frame whose (interpolated) predicted
                frames are tested as possible candidates to output.
                If range is set to 1, then predicted frames with more than 0.5 distance cannot
                be used (where 0.5 distance means blending the 2 frames together).
    Returns:
        A tensor shaped like batch_pred where each time step of each item is the candidate
        with the lowest L1 error (summed over channels) against batch_target.
    Fails an assertion if range_ is even or the input shapes differ.
    """
    assert range_ % 2 == 1, 'range_ must be an odd integer.'
    assert batch_pred.shape == batch_target.shape, 'pred and target shapes do not match.'
    
    # Output buffer with batch_pred's shape/dtype; every batch item is overwritten below.
    batch_pred_dtw = batch_pred * 0.
    for i, (pred, target) in enumerate(zip(batch_pred, batch_target)):
        # Restore a batch dim of 1 so pad/interpolate see [1, C, T].
        pred = pred.unsqueeze(0)
        target = target.unsqueeze(0)
        
        # shift pred into all aligned forms that might produce improved L1
        pred_pad = torch.nn.functional.pad(pred, (range_//2, range_//2))
        pred_expanded = torch.nn.functional.interpolate(pred_pad, scale_factor=float(scale_factor), mode='linear', align_corners=False)# [B, C, T] -> [B, C, T*s]
        
        p_shape = pred.shape
        pred_list = []
        # Each phase j of the upsampled signal, strided by scale_factor, is the prediction
        # shifted by a fraction of a frame; crop back to the original T frames.
        for j in range(scale_factor*range_):
            pred_list.append(pred_expanded[:,:,j::scale_factor][:,:,:p_shape[2]])
        
        # Start from the unshifted prediction.
        # NOTE(review): with align_corners=False and an even scale_factor, no phase j
        # appears to line up exactly with the original frames, so the unshifted
        # prediction is only available via this initial clone — confirm this is intended.
        pred_dtw = pred.clone()
        for pred_interpolated in pred_list:
            # Per-time-step L1 summed over channels -> [1, 1, T].
            new_l1 = torch.nn.functional.l1_loss(pred_interpolated, target, reduction='none').sum(dim=1, keepdim=True)
            old_l1 = torch.nn.functional.l1_loss(pred_dtw, target, reduction='none').sum(dim=1, keepdim=True)
            # Keep the candidate only at time steps where it is strictly better.
            pred_dtw = torch.where(new_l1 < old_l1, pred_interpolated, pred_dtw)
        batch_pred_dtw[i:i+1] = pred_dtw
    return batch_pred_dtw

pred_dtw = DTW(pred, target, 4, 3)
# L1 before and after warping. DTW only ever swaps in a frame that lowers the
# per-frame error, so the second value should not exceed the first.
for candidate in (pred, pred_dtw):
    print(torch.nn.functional.l1_loss(candidate, target))

In [None]:
import matplotlib
%matplotlib inline
import matplotlib.pylab as plt
import IPython.display as ipd
def plot_data(data, title=None, figsize=(20, 5)):
    """Show a 2-D array (e.g. a [n_mel, T] spectrogram slice) as a heatmap with a colorbar.

    Args:
        data: 2-D array-like passed to imshow; row 0 is drawn at the bottom.
        title: Optional title drawn above the image.
        figsize: Matplotlib figure size in inches.
    """
    fig = plt.figure(figsize=figsize)

    ax = fig.add_subplot(111)
    image = ax.imshow(data, cmap='inferno', origin='lower', interpolation='none')
    ax.set_aspect('equal')
    if title is not None:
        ax.set_title(title)

    # Invisible overlay axes; the colorbar takes its space from this axes,
    # presumably so the wide image axes is not shrunk by the colorbar.
    cax = fig.add_axes([0.12, 0.1, 0.78, 0.8])
    cax.get_xaxis().set_visible(False)
    cax.get_yaxis().set_visible(False)
    cax.patch.set_alpha(0)
    cax.set_frame_on(False)
    fig.colorbar(image, ax=cax, orientation='vertical')
    plt.show()
    plt.close(fig)  # callers plot in loops; don't accumulate open figures

In [2]:
import torch
torch.exp(torch.tensor(-0.4))  # e**-0.4, displayed as tensor(0.6703)

tensor(0.6703)

In [None]:
import random

# Each '|'-separated line of the map file holds the predicted mel path in column 1.
map_path = r"G:\TwiBot\CookiePPPTTS\CookieTTS\_2_ttm\tacotron2\GTA_flist2\map_val.txt"
with open(map_path, "r") as map_file:
    filetext = map_file.read().split("\n")
filter_str = [".mel100",".mel200",".mel300",".mel400",".mel500"]
# Drop blank lines (e.g. from a trailing newline) and the filtered mel variants.
filetext = [x for x in filetext if x.strip() and not any(str_ in x for str_ in filter_str)]

file_count = 20
use_random_start = False  # False: fixed window for repeatable comparisons
if use_random_start:
    rand_start = random.randint(0, max(0, len(filetext) - file_count))
else:
    rand_start = 10
for line in filetext[rand_start:rand_start+file_count]:
    pred_mel_path = line.split("|")[1].strip().replace("/media/cookie/Samsung 860 QVO/", "H:\\")

    mel_pred = torch.from_numpy(np.load(pred_mel_path)).float().unsqueeze(0)
    mel_pred[:, 120:, :] = 0.0
    mel_target = torch.from_numpy(np.load(pred_mel_path.replace('.mel.npy','.npy'))).float().unsqueeze(0)
    mel_target[:, 120:, :] = 0.0
    mel_pred_dtw = DTW(mel_pred, mel_target, scale_factor = 8, range_= 7)
    print(mel_pred.shape)
    print(
        torch.nn.functional.mse_loss(mel_pred, mel_target),
        torch.nn.functional.mse_loss(mel_pred_dtw, mel_target),
        sep='\n')
    start_frame = 0
    end_frame = 999
    plot_data(mel_pred[0][:,start_frame:end_frame].numpy())
    plot_data(mel_target[0][:,start_frame:end_frame].numpy())
    plot_data(mel_pred_dtw[0][:,start_frame:end_frame].numpy())
    print("\n\n\n")

# (Testing) Timestamps for Model inputs

In [None]:
# Toy stand-ins: alignments [B, mel_frames, text_len], sequence [B, text_len].
torch.manual_seed(0)
alignments = torch.rand(1, 80, 12)
sequence = torch.rand(1, 12)
samples_per_frame = 275.625  # presumably the hop length in samples — confirm against the model config
sample_rate_hz = 22050
# Number of frames for which each letter received the model's maximum attention.
dur_frames = torch.histc(torch.argmax(alignments[0], dim=1).float(), min=0, max=sequence.shape[1]-1, bins=sequence.shape[1])
dur_seconds = dur_frames * (samples_per_frame / sample_rate_hz)  # convert from frames to seconds
# Letters are assumed back-to-back: each ends at the running total of durations
# and the next starts the moment the previous one ends.
end_times = torch.cumsum(dur_seconds, dim=0)
start_times = torch.nn.functional.pad(end_times, (1,0))[:-1]
for i, (dur, start, end) in enumerate(zip(dur_seconds, start_times, end_times)):
    print(f"[Letter {i:02}]\nDuration:\t{dur:.3f}\nStart Time:\t{start:.3f}\nEnd Time:\t{end:.3f}\n")