## This notebook will search for WaveGlow/WaveFlow sample folders and calcuate metrics relating to audio quality for each folder.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from glob import glob
from CookieTTS.utils.dataset.utils import load_wav_to_torch
from CookieTTS.utils.audio.stft import STFT
from CookieTTS.utils.audio.audio_processing import window_sumsquare, dynamic_range_compression, dynamic_range_decompression

Import requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit
Import of 'jit' requested from: 'numba.decorators', please update to use 'numba.core.decorators' or pin to Numba version 0.48.0. This alias will not be present in Numba version 0.50.0.[0m
  from numba.decorators import jit as optional_jit


In [2]:
directory = r"H:\TTCheckpoints\wave_logdir_symlink"
list_of_samples_folders = [x[0] for x in os.walk(directory) if x[0].endswith('samples')]

In [3]:
model_names = [x.split('H:\\TTCheckpoints\\wave_logdir_symlink\\')[1].split('\\samples')[0].replace('\\','.') for x in list_of_samples_folders]

In [4]:
win_lens = [300]
hop_len = 65
fil_len = 2400
sampling_rate = 48000

stfts = []
for win_len in win_lens:
    stft = STFT(filter_length=fil_len,
                   hop_length=hop_len,
                   win_length=win_len,).cuda()
    stfts.append(stft)

In [5]:
def get_spect(audio):
    """Take audio and convert to multi-res spectrogram"""
    melspec = []
    for stft in stfts:
        spect = stft.transform(audio.unsqueeze(0), return_phase=False)[0].squeeze(0)# -> [n_mel, dec_T]
        #print(spect.shape)
        melspec.append(spect)
    return torch.cat(melspec, dim=0)# [[n_mel, dec_T], ...] -> [n_stft*n_mel, dec_T]

In [6]:
import numpy as np
import iso226
import math
iso226_spl_from_freq = iso226.iso226_spl_itpl(L_N=60, hfe=True)# get InterpolatedUnivariateSpline for Perc Sound Pressure Level at Difference Frequencies with 60DB ref.
freq_weights = torch.tensor([(2**(60./10))/(2**(iso226_spl_from_freq(freq)/10)) for freq in np.linspace(0, sampling_rate//2, (fil_len//2)+1)])
freq_weights = freq_weights.cuda().repeat(len(win_lens))[:, None]# [B, n_mel, T]

In [12]:
total_models = len(model_names)
mean_best_vals = []
for i, (model_name, sample_folder) in enumerate(zip(model_names, list_of_samples_folders)):
    print(f'{i:3}/{total_models:<} - {model_name}')
    gt_folder = os.path.join(sample_folder, 'Ground Truth')
    assert os.path.exists(gt_folder)
    pred_folders = [x[0] for x in os.walk(sample_folder) if not x[0] == gt_folder][1:]
    pred_folders = [x for x in pred_folders if '000-' in x]# allow only samples where the iter is a multiple of 1000
    
    best_vals = []
    gt_audio_files = glob(os.path.join(gt_folder, '*.wav'))
    len_gt_audio_files = len(gt_audio_files)
    gt_spects = {}
    for i, gt_audio_file in enumerate(gt_audio_files):
        #print(os.path.split(gt_audio_file)[-1], end='... ')
        best_vals.append(1e5)
        gt_audio, sr, max_mag = load_wav_to_torch(gt_audio_file)
        gt_audio = gt_audio.cuda()
        gt_audio /= max_mag
        gt_spect = get_spect(gt_audio)
        gt_spect = dynamic_range_compression(gt_spect, clip_val=1e-5)
        gt_spects[os.path.split(gt_audio_file)[-1]] = gt_spect
    
    best_mean_val = 1e9
    best_pred_folder = ''
    len_pred_folders = len(pred_folders)
    for i, pred_folder in enumerate(pred_folders):
        print(f'{i+1}/{len_pred_folders}', end='\r')
        vals = []
        for pred_audio_file in glob(os.path.join(pred_folder, '*.wav')):
            if os.path.split(pred_audio_file)[-1] in gt_spects.keys():
                gt_spect = gt_spects[os.path.split(pred_audio_file)[-1]]
                pred_audio, sr, max_mag = load_wav_to_torch(pred_audio_file)
                pred_audio = pred_audio.cuda()
                pred_audio = pred_audio[:gt_audio.shape[0]]
                pred_audio /= max_mag
                pred_spect = get_spect(pred_audio)
                pred_spect = dynamic_range_compression(pred_spect, clip_val=1e-5)
                min_length = min(gt_spect.shape[1], pred_spect.shape[1])
                pred_MAE = torch.nn.functional.l1_loss(pred_spect[:, :min_length], gt_spect[:, :min_length], reduction='none')
                pred_MAE *= freq_weights
                pred_MAE = pred_MAE.mean().item()
                vals.append(pred_MAE)
        if len(vals):
            mean_val = sum(vals)/len(vals)
            if mean_val < best_mean_val:
                best_mean_val = mean_val
                best_pred_folder = pred_folder
    print(f'{best_mean_val} | {os.path.split(best_pred_folder)[-1]}\n')
    mean_best_vals.append(best_mean_val)

  0/100 - waveflow.WaveGlow2.WaveGlow6
0.43790902694066364 | 253000-2020_07_28-01_02_36

  1/100 - waveflow.WaveGlow2.WaveGlow5
0.4314320385456085 | 263000-2020_07_28-10_36_34

  2/100 - waveflow.WaveGlow2
0.45375850134425694 | 206000-2020_07_24-19_47_44

  3/100 - waveflow.WaveGlow2.WaveGlow4
0.5140067769421471 | 213000-2020_07_25-02_38_20

  4/100 - waveflow.WaveGlow2.WaveGlow3
0.46576651599672103 | 213000-2020_07_25-00_49_04

  5/100 - waveflow.WaveGlow.50G12F256C
0.7245897650718689 | 896000-2020_06_28-22_21_26

  6/100 - waveflow.WaveGlow
0.43754787955965313 | 368000-2020_07_16-14_42_31

  7/100 - waveflow.6thARSmall.with_res_skip
0.42919816289629253 | 810000-2020_06_22-17_55_11

  8/100 - waveflow.6thARSmall.with_res_skip_and_separable
0.4378726141793387 | 886000-2020_06_23-00_45_01

  9/100 - waveflow.6thARSmall
0.4263772964477539 | 992000-2020_06_24-11_44_29

 10/100 - waveflow.6thARSmall.without_res_skip
0.4458443011556353 | 840000-2020_06_22-08_54_50

 11/100 - waveflow.3rdLar

In [13]:
mean_best_vals = [float(x) if not 'N/A' in str(x) else 1e5 for x in mean_best_vals]
model_perfs = {k: v for k, v in zip(model_names, mean_best_vals)}
sorted_model_perfs = {k: v for k, v in sorted(model_perfs.items(), key=lambda item: item[1])}

In [14]:
# val MAE with [300] windows
print('\n'.join([f'{x[1]:.4f} | {x[0]:40}' for x in list(sorted_model_perfs.items())]))

0.3550 | waveflow.4thLargeKernels.AR_8_Flow_AEF4.1
0.3592 | waveflow.4thLargeKernels.AR_8_Flow_AEF  
0.3634 | waveflow.4thLargeKernels.AR_8_Flow_AEF4.1.1_nonsep
0.3737 | waveflow.4thLargeKernels.WG_12_Flow_AEF4.1
0.3743 | waveflow.4thLargeKernels.AR_8_Flow_AEF2 
0.3761 | waveflow.4thLargeKernels.AR_8_Flow_AEF4 
0.3796 | waveflow.4thLargeKernels.WG_24_Flow_AEF4.1
0.3825 | waveglow.outdir_EfficientSmallGlobalSpeakerEmbeddings.Testing2.1
0.3853 | waveflow.4thLargeKernels.AR_8_Flow_24Khz_Fmax
0.3859 | waveflow.4thLargeKernels.empthasis_AR_6_Flow_512C
0.3881 | waveglow.outdir_EfficientSmallGlobalSpeakerEmbeddings
0.3887 | waveflow.7thHeightDilated.32C12F        
0.3894 | waveflow.4thLargeKernels.AR_8_Flow.gt3  
0.3929 | waveflow.4thLargeKernels.GSIRU_384C_4Flow
0.3997 | waveglow.outdir_EfficientSmallGlobalSpeakerEmbeddings.Testing
0.4007 | waveflow.4thLargeKernels.AR_6_Flow_512C 
0.4007 | waveglow.outdir_EfficientLarge4         
0.4012 | waveglow.outdir_EfficientSmallSpeakerEmbeddings
0.405

In [23]:
# val MAE with [600, 1200] windows
print('\n'.join([f'{x[1]:.4f} | {x[0]:40}' for x in list(sorted_model_perfs.items())]))

0.3470 | waveflow.4thLargeKernels.AR_8_Flow_AEF4.1
0.3521 | waveflow.4thLargeKernels.AR_8_Flow_AEF  
0.3570 | waveflow.4thLargeKernels.AR_8_Flow_AEF4.1.1_nonsep
0.3700 | waveflow.4thLargeKernels.AR_8_Flow_AEF2 
0.3714 | waveflow.4thLargeKernels.AR_8_Flow.gt3  
0.3717 | waveglow.outdir_EfficientSmallGlobalSpeakerEmbeddings.Testing2.1
0.3742 | waveflow.4thLargeKernels.empthasis_AR_6_Flow_512C
0.3753 | waveflow.4thLargeKernels.AR_8_Flow_24Khz_Fmax
0.3770 | waveflow.4thLargeKernels.WG_12_Flow_AEF4.1
0.3798 | waveflow.4thLargeKernels.AR_8_Flow_AEF4 
0.3816 | waveflow.7thHeightDilated.32C12F        
0.3850 | waveglow.outdir_EfficientSmallGlobalSpeakerEmbeddings
0.3896 | waveflow.4thLargeKernels.GSIRU_384C_4Flow
0.3896 | waveflow.4thLargeKernels.AR_6_Flow_512C 
0.3904 | waveglow.outdir_EfficientLarge4         
0.3977 | waveglow.outdir_EfficientSmallGlobalSpeakerEmbeddings.Testing
0.3978 | waveglow.outdir_EfficientLarge          
0.3990 | waveflow.4thLargeKernels.AR_6_Flow      
0.4018 | waveg

In [None]:
# val MAE with [300, 600, 1200, 2400] windows
print('\n'.join([f'{x[1]:.4f} | {x[0]:40}' for x in list(sorted_model_perfs.items())]))