In [None]:
from CookieTTS.utils.audio.stft import TacotronSTFT, STFT
from CookieTTS.utils.dataset.utils import load_wav_to_torch

import sys
sys.path.append('../_4_mtw/waveglow') # add WaveGlow to System path for easier importing
import os
from glob import glob
import shutil

import numpy as np
from scipy import signal
from scipy.io.wavfile import write
import torch

import matplotlib
%matplotlib inline
import matplotlib.pylab as plt
import IPython.display as ipd

# 1 - Initialize WaveGlow and Load Checkpoint/Weights

In [None]:
# Load WaveGlow
import json

# Checkpoint file holding the trained model weights; it is read with torch.load() further down this cell.
waveglow_path = r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_6_Flow_512C_ssvae2\best_val_model"
# JSON config for the same run; it is parsed below into the train/data/dist/waveglow config sections.
config_fpath = r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_6_Flow_512C_ssvae2\config.json"

def is_ax(config):
    """Tell whether a model config targets the Ax WaveGlow core.

    This is a cheap heuristic: the only thing inspected is whether the
    'upsample_first' key is present in `config` (its value is ignored).

    Args:
        config: dict of model hyper-parameters.

    Returns:
        bool: True when 'upsample_first' is a key of `config`, otherwise False.
    """
    return 'upsample_first' in config

# Parse the JSON config and pull out its sections
with open(config_fpath) as f:
    data = f.read()
config = json.loads(data)
train_config, data_config, dist_config = (
    config[section] for section in ("train_config", "data_config", "dist_config")
)

# Fill in values that some config files do not provide
data_config.setdefault('preempthasis', 0.0)
data_config['n_mel_channels'] = config["waveglow_config"].get('n_mel_channels', 160)

# The model constructor also receives the STFT/pre-emphasis settings from the data section
waveglow_config = dict(config["waveglow_config"])
waveglow_config.update(
    {key: data_config[key] for key in ('win_length', 'hop_length', 'preempthasis')}
)
print(waveglow_config)
print(f"Config File from '{config_fpath}' successfully loaded.")

# Import the model core that matches the config:
#   - Ax core when the config carries Ax-only keys (see is_ax)
#   - otherwise the "yoyo" efficient core or the original glow core
if is_ax(waveglow_config):
    from CookieTTS._4_mtw.waveglow.efficient_model_ax import WaveGlow
elif waveglow_config["yoyo"]:
    from CookieTTS._4_mtw.waveglow.efficient_model import WaveGlow
else:
    from CookieTTS._4_mtw.waveglow.glow import WaveGlow
from CookieTTS._4_mtw.waveglow.denoiser import Denoiser

# initialize model (randomly initialised weights, already on the GPU)
print("intializing WaveGlow model... ", end="")
waveglow = WaveGlow(**waveglow_config).cuda()
print("Done!")

# load checkpoint from file
print("loading WaveGlow checkpoint... ", end="")
# NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
# map_location='cpu' keeps the checkpoint tensors in host memory, so the GPU does not hold
# a second copy of the weights and loading does not depend on the device the file was saved from.
checkpoint = torch.load(waveglow_path, map_location='cpu')
waveglow.load_state_dict(checkpoint['model']) # overwrite initialized weights with checkpointed weights
waveglow.eval().half() # model is already on the GPU; switch to eval mode and half precision
waveglow.remove_weightnorm()
print("Done!")

print("initializing Denoiser... ", end="")
denoiser = Denoiser(waveglow, mu=0., var=2.0, stft_device='cpu', speaker_dependant=False)
print("Done!")

# Metadata stored next to the weights in the checkpoint
waveglow_iters = checkpoint['iteration']
print(f"WaveGlow trained for {waveglow_iters} iterations")
speaker_lookup = checkpoint['speaker_lookup'] # ids lookup (external speaker id -> internal id)
training_sigma = train_config['sigma']

# 2 - Setup STFT to generate spectrograms from audio files

In [None]:
# Build the STFT front-end that turns audio files into mel spectrograms.
# The positional argument order of TacotronSTFT is kept via the ordered key tuple below.
print('Initializing STFT...')
stft = TacotronSTFT(*(data_config[key] for key in (
    'filter_length', 'hop_length', 'win_length',
    'n_mel_channels', 'sampling_rate',
    'mel_fmin', 'mel_fmax',
)))
print('Done!')
def load_mel(path):
    """Load a mel spectrogram from an audio file or from a saved numpy array.

    Args:
        path: file path. '.wav' / '.flac' files are loaded with load_wav_to_torch and
              converted with the notebook-level `stft`; '.npy' files are loaded as-is.

    Returns:
        torch.Tensor: for audio files, the output of stft.mel_spectrogram on the
        normalised audio with a leading batch dimension added; for '.npy' files,
        the stored array converted to a tensor without any reshaping.

    Raises:
        ValueError: if the audio sampling rate differs from stft.sampling_rate,
                    or if the file extension is not supported.
    """
    if path.endswith(('.wav', '.flac')):
        audio, sampling_rate, max_audio_value = load_wav_to_torch(path)
        if sampling_rate != stft.sampling_rate:
            # the message has three placeholders: file, found SR, target SR
            raise ValueError("{} {} SR doesn't match target {} SR".format(
                path, sampling_rate, stft.sampling_rate))
        # scale by the value returned by the loader, then add a batch dimension
        audio_norm = (audio / max_audio_value).unsqueeze(0)
        return stft.mel_spectrogram(audio_norm)
    if path.endswith('.npy'):
        return torch.from_numpy(np.load(path))
    # fail loudly instead of hitting an UnboundLocalError on the return value
    raise ValueError(f"Unsupported file extension for '{path}'; expected .wav, .flac or .npy")

# 3 - Reconstruct Audio from Audio Spectrogram using WaveGlow/Flow

In [None]:
# --- Inputs ---
folder_path = r"H:\ClipperDatasetV2\SlicedDialogue\FiM\S1\s1e26"  # searched recursively
ext = '.mel.npy'                                                  # suffix of the files to vocode

# --- Sampling / denoising settings to sweep over ---
sigmas = [0.95]
denoise_strengths = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 3.0]

# --- Speaker conditioning ---
# External speaker ids are translated to the model's internal ids through the
# checkpoint's lookup table, then moved to the GPU as a LongTensor.
speaker_ids = torch.tensor([speaker_lookup[external_id] for external_id in (167,)]).cuda().long()

# --- Output options ---
display_audio = False
display_denoised_audio = False

save_outputs = True
output_folder = r'D:\Downloads\infer\WaveFlow\AR_6_Flow_512C_ssvae2'

def _to_int16_pcm(wav):
    """Scale a float waveform to 16-bit PCM, clipping at full scale instead of wrapping around."""
    return (wav * 2**15).clamp(-2**15, 2**15 - 1).detach().cpu().numpy().astype('int16')


# For every matching file: vocode it once per sigma, then denoise each result at every strength.
for audio_path in glob(os.path.join(folder_path, '**', f'*{ext}'), recursive=True):
    print(f"Audio Path:\n'{audio_path}'\n")
    mel_outputs_postnet = load_mel(audio_path).cuda()

    # one output folder per input file, named after the file (last extension stripped)
    output_path = os.path.join(output_folder, os.path.splitext(os.path.split(audio_path)[-1])[0])
    os.makedirs(output_path, exist_ok=True)

    # Copy the reference recording next to the generated files.
    # The .wav path is derived from `ext` so it stays correct if `ext` is changed.
    ground_truth_path = audio_path[:-len(ext)] + '.wav'
    if os.path.isfile(ground_truth_path):
        shutil.copy(ground_truth_path, os.path.join(output_path, 'Ground Truth.wav'))
    else:
        print(f"Ground truth audio not found, skipping copy: '{ground_truth_path}'")

    audios = []
    with torch.no_grad():
        for sigma in sigmas:
            with torch.random.fork_rng(devices=[0,]):
                torch.random.manual_seed(0)# use same Z / random seed during validation so results are more consistent and comparable.
                audio = waveglow.infer(mel_outputs_postnet, sigma=sigma, speaker_ids=speaker_ids, return_CPU=True).float()

            # replace any non-finite samples with silence so the file can still be written
            non_finite = ~torch.isfinite(audio)
            if non_finite.any():
                print('inf or nan found in audio')
                audio[non_finite] = 0.0
            audios.append(audio)

            print(f"sigma = {sigma}")
            if display_audio:
                ipd.display(ipd.Audio(audio[0].data.cpu().numpy(), rate=data_config['sampling_rate']))
            if save_outputs:
                save_path = os.path.join(output_path, f'denoise_{0.00:0.2f}_sigma_{sigma:0.2f}.wav')
                write(save_path, data_config['sampling_rate'], _to_int16_pcm(audio[0]))

        for audio, sigma in zip(audios, sigmas):
            for denoise_strength in denoise_strengths:
                audio_denoised = denoiser(audio, speaker_ids=speaker_ids, strength=denoise_strength)[:, 0]
                # check the denoiser output itself; fail if it contains any inf OR any nan
                assert torch.isfinite(audio_denoised).all(), 'inf or nan found in denoised audio'
                print(f"[Denoised Strength {denoise_strength}] [sigma {sigma}]")
                if display_denoised_audio:
                    ipd.display(ipd.Audio(audio_denoised.cpu().numpy(), rate=data_config['sampling_rate']))
                if save_outputs:
                    save_path = os.path.join(output_path, f'denoise_{denoise_strength:0.2f}_sigma_{sigma:0.2f}.wav')
                    write(save_path, data_config['sampling_rate'], _to_int16_pcm(audio_denoised[0]))
        print('\n')

# (Testing) Dynamic Time Warping for GTA Alignment

In [None]:
# Two independent random tensors of shape [1, 2, 700] used to smoke-test DTW below
# (target is drawn first, then pred).
target, pred = (torch.rand(1, 2, 700) for _ in range(2))

@torch.jit.script
def DTW(batch_pred, batch_target, scale_factor: int, range_: int):
    """
    Calculates a time-warp of `batch_pred` that reduces the L1 error against `batch_target`.

    For every item in the batch a set of shifted (and sub-frame interpolated) copies of the
    prediction is generated. Each time step then independently keeps whichever candidate has
    the lowest L1 error (summed over dim 1) against the target at that time step. The unshifted
    prediction is the starting candidate, and a candidate only replaces the current best when
    its error is strictly lower.

    Params:
        batch_pred: predicted tensor, expected to be [B, C, T] (linear interpolation needs a 3-D input per item).
        batch_target: target tensor with exactly the same shape as batch_pred.
        scale_factor: Scale factor for linear interpolation.
                      Values greater than 1 allow blends of neighbouring frames to be used.
        range_: Range around the target frame that predicted frames should be tested as possible candidates to output.
                Must be odd.
                If range is set to 1, then predicted frames with more than 0.5 distance cannot be used. (where 0.5 distance means blending the 2 frames together).
    Returns:
        Tensor with the same shape as batch_pred containing the warped prediction.
    """
    assert range_ % 2 == 1, 'range_ must be an odd integer.'
    assert batch_pred.shape == batch_target.shape, 'pred and target shapes do not match.'
    
    # output buffer with the same shape/dtype/device as the input, filled item by item below
    batch_pred_dtw = batch_pred * 0.
    for i, (pred, target) in enumerate(zip(batch_pred, batch_target)):
        # re-add the batch dimension so pad/interpolate see a 3-D tensor
        pred = pred.unsqueeze(0)
        target = target.unsqueeze(0)
        
        # shift pred into all aligned forms that might produce improved L1
        # pad range_//2 zeros on both ends of the last dim so shifted views keep the full length
        pred_pad = torch.nn.functional.pad(pred, (range_//2, range_//2))
        pred_expanded = torch.nn.functional.interpolate(pred_pad, scale_factor=float(scale_factor), mode='linear', align_corners=False)# [B, C, T] -> [B, C, T*s]
        
        p_shape = pred.shape
        pred_list = []
        # every offset j into the upsampled signal, sampled with stride scale_factor and cut
        # back to the original length, gives one candidate alignment
        for j in range(scale_factor*range_):
            pred_list.append(pred_expanded[:,:,j::scale_factor][:,:,:p_shape[2]])
        
        # start from the unmodified prediction and greedily keep the best candidate per time step
        pred_dtw = pred.clone()
        for pred_interpolated in pred_list:
            # errors are summed over dim 1 (keepdim) so the choice is made once per time step
            new_l1 = torch.nn.functional.l1_loss(pred_interpolated, target, reduction='none').sum(dim=1, keepdim=True)
            old_l1 = torch.nn.functional.l1_loss(pred_dtw, target, reduction='none').sum(dim=1, keepdim=True)
            pred_dtw = torch.where(new_l1 < old_l1, pred_interpolated, pred_dtw)
        batch_pred_dtw[i:i+1] = pred_dtw
    return batch_pred_dtw

# Warp the random prediction, then print the mean L1 error before and after warping.
pred_dtw = DTW(pred, target, scale_factor=4, range_=3)
for candidate in (pred, pred_dtw):
    print(torch.nn.functional.l1_loss(candidate, target))

In [None]:
import matplotlib
%matplotlib inline
import matplotlib.pylab as plt
import IPython.display as ipd
def plot_data(data, title=None, figsize=(20, 5)):
    """Show a 2-D array as an image with a colorbar.

    Args:
        data: 2-D array-like passed straight to plt.imshow.
        title: optional title drawn above the image; nothing is drawn when None.
        figsize: (width, height) of the figure in inches.
    """
    fig = plt.figure(figsize=figsize)

    ax = fig.add_subplot(111)
    plt.imshow(data, cmap='inferno', origin='lower',
                   interpolation='none')
    ax.set_aspect('equal')
    if title is not None:
        ax.set_title(title)

    # Invisible overlay axes (no ticks, no frame, transparent background);
    # it becomes the current axes when the colorbar is created.
    cax = fig.add_axes([0.12, 0.1, 0.78, 0.8])
    cax.get_xaxis().set_visible(False)
    cax.get_yaxis().set_visible(False)
    cax.patch.set_alpha(0)
    cax.set_frame_on(False)
    plt.colorbar(orientation='vertical')
    plt.show()

In [None]:
import random
# Read the GTA file list; the context manager makes sure the file handle is closed.
with open(r"G:\TwiBot\CookiePPPTTS\CookieTTS\_2_ttm\tacotron2\GTA_flist\map_train.txt", "r") as flist_file:
    filetext = flist_file.read().split("\n")

# Entries whose path contains one of these markers are skipped.
filter_str = [".mel100",".mel200",".mel300",".mel400",".mel500"]
# Also drop lines without a '|' separator (e.g. the empty entry produced by a trailing
# newline), which would otherwise raise an IndexError when the line is split below.
filetext = [x for x in filetext if "|" in x and not any(str_ in x for str_ in filter_str)]

# Pick a random window of consecutive entries to inspect.
rand_start = int(random.random()*len(filetext))
file_count = 10
for line in filetext[rand_start:rand_start+file_count]:
    # second '|' field is the predicted-mel path; remap the Linux mount point to the local drive
    pred_mel_path = line.split("|")[1].replace("\n","").replace("/media/cookie/Samsung 860 QVO/", "H:\\")
    
    # load the prediction and its target (same path with '.mel.npy' replaced by '.npy'),
    # each with a leading batch dimension added
    mel_pred = torch.from_numpy(np.load(pred_mel_path)).float().unsqueeze(0)
    mel_target = torch.from_numpy(np.load(pred_mel_path.replace('.mel.npy','.npy'))).float().unsqueeze(0)
    mel_pred_dtw = DTW(mel_pred, mel_target, scale_factor = 10, range_= 9)
    print(mel_pred.shape)
    # MSE against the target before and after warping
    print(
        torch.nn.functional.mse_loss(mel_pred, mel_target),
        torch.nn.functional.mse_loss(mel_pred_dtw, mel_target),
        sep='\n')
    # only this frame range is plotted
    start_frame = 0
    end_frame = 999
    plot_data(mel_pred[0][:,start_frame:end_frame].numpy())
    plot_data(mel_target[0][:,start_frame:end_frame].numpy())
    plot_data(mel_pred_dtw[0][:,start_frame:end_frame].numpy())
    print("\n\n\n")

# (Testing) Timestamps for Model inputs

In [None]:
# Dummy inputs: a random alignment matrix and a random 12-symbol sequence
alignments = torch.rand(1, 80, 12)
sequence = torch.rand(1, 12)

n_symbols = sequence.shape[1]
# For each row of alignments[0], find the symbol with the highest weight, then count how many
# rows picked each symbol: number of frames each letter took the maximum focus of the model.
dur_frames = torch.histc(torch.argmax(alignments[0], dim=1).float(), min=0, max=n_symbols-1, bins=n_symbols)

# convert from frames to seconds
# (presumably hop length in samples / sampling rate - confirm against the data config)
seconds_per_frame = 275.625/22050
dur_seconds = dur_frames * seconds_per_frame

# A letter ends after the summed durations of itself and every letter before it.
# cumsum replaces the manual running sum, which relied on end_times[-1] still being zero
# when the first element was computed.
end_times = torch.cumsum(dur_seconds, dim=0)
# calculate the start times by assuming the next letter starts the moment the last one ends.
start_times = torch.nn.functional.pad(end_times, (1,0))[:-1]

for i, (dur, start, end) in enumerate(zip(dur_seconds, start_times, end_times)):
    print(f"[Letter {i:02}]\nDuration:\t{dur:.3f}\nStart Time:\t{start:.3f}\nEnd Time:\t{end:.3f}\n")