In [40]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
# NOTE(review): LRU_CACHE_CAPACITY is commonly used to bound PyTorch's CPU (MKLDNN)
# primitive cache when input sizes vary between calls. It is set here *after*
# `import torch`; whether it still takes effect at this point is unverified —
# consider moving it above the torch imports if memory growth is observed.
os.environ["LRU_CACHE_CAPACITY"] = "3"
import random

from CookieTTS.utils.dataset.utils import load_wav_to_torch, load_filepaths_and_text

In [41]:
import os
import random
from glob import glob
from CookieTTS.utils.dataset.utils import load_wav_to_torch
from CookieTTS.utils.audio.stft import STFT
from CookieTTS.utils.audio.audio_processing import window_sumsquare, dynamic_range_compression, dynamic_range_decompression

class AudioDataset(torch.utils.data.Dataset):
    """Loads a wav file and returns its multi-resolution compressed magnitude spectrogram.

    One STFT is built per entry in ``model_config['window_lengths']``. All of them
    share ``filter_length`` and ``hop_length``, so their outputs can be stacked
    along the frequency axis (dim 0).

    Items are looked up by file path (``dataset[path]``), not by integer index.
    """
    def __init__(self, model_config):
        self.win_lens = model_config['window_lengths']
        self.hop_len = model_config['hop_length']
        self.fil_len = model_config['filter_length']
        # Sample rate the audio is expected to have. The spectrogram's frequency axis
        # (and LossFunction's per-bin weights, built from the same config) only line up
        # at this rate. None disables the check in __getitem__.
        self.sampling_rate = model_config.get('sampling_rate')

        # One STFT per window length; the FFT size (filter_length) is shared by all.
        self.stfts = [
            STFT(filter_length=self.fil_len, hop_length=self.hop_len, win_length=win_len)
            for win_len in self.win_lens
        ]

    def get_mel(self, audio):
        """Convert a waveform into a stacked multi-resolution magnitude spectrogram.

        Despite the name, no mel filterbank is applied. The result is each configured
        STFT's output, concatenated along dim 0.

        Args:
            audio: waveform tensor (assumed 1-D [n_samples] — TODO confirm).

        Returns:
            Tensor [n_windows * n_freq, dec_T].
        """
        spects = []
        for stft in self.stfts:
            # transform() is given a leading batch dim. [0] presumably selects the
            # magnitude output; squeeze(0) drops the batch dim -> [n_freq, dec_T].
            spect = stft.transform(audio.unsqueeze(0), return_phase=False)[0].squeeze(0)
            spects.append(spect)
        return torch.cat(spects, dim=0)  # [[n_freq, dec_T], ...] -> [n_windows*n_freq, dec_T]

    def __getitem__(self, audio_path):
        """Load ``audio_path`` and return its dynamic-range-compressed spectrogram.

        Raises:
            ValueError: if the file's sample rate differs from
                ``model_config['sampling_rate']`` (when one was configured).
        """
        audio, sampling_rate, max_mag = load_wav_to_torch(audio_path)  # load mono audio from file
        if self.sampling_rate is not None and sampling_rate != self.sampling_rate:
            raise ValueError(
                f"{audio_path} has sampling rate {sampling_rate} Hz, "
                f"but model_config['sampling_rate'] is {self.sampling_rate} Hz")
        audio = audio / max_mag  # normalize to range [-1, 1]

        spect = self.get_mel(audio)
        return dynamic_range_compression(spect)

In [42]:
import iso226
import math
from CookieTTS.utils.model.utils import get_mask_from_lengths

# https://www.desmos.com/calculator/4nac7kvt7p
def vol_rescale_loss(mel, power=0.5, min=-11.55):
    """Quadratically warp spectrogram values so quiet regions are squashed together.

    Computes ``mel + power / (-min * 2) * mel**2``. For negative inputs the
    quadratic term is positive and partially cancels ``mel``, pulling very quiet
    (very negative) values closer together. An MSE computed on the warped values
    therefore penalises errors in quiet parts less than errors in loud parts.

    Args:
        mel: tensor or scalar of (presumably log-magnitude) spectrogram values.
        power: warp strength; 0 leaves ``mel`` unchanged. At ``mel == min`` the
            local slope of the warp is ``1 - power``.
        min: reference floor value. It shadows the builtin but is kept for
            keyword compatibility.
    """
    scale = power / (-min * 2)
    return mel + scale * mel ** 2


class LossFunction(nn.Module):
    """Perceptually weighted squared error between two multi-resolution spectrograms.

    Each frequency bin's squared error is scaled by a weight derived from the
    ISO 226 equal-loudness contour at 60 phon. Bins where the contour's SPL is
    lower (where the ear is more sensitive) get larger weights.
    """
    def __init__(self, model_config):
        super(LossFunction, self).__init__()
        # Interpolator for the Sound Pressure Level (dB) at different frequencies
        # along the 60-phon equal-loudness contour.
        iso226_spl_from_freq = iso226.iso226_spl_itpl(L_N=60, hfe=True)
        # Frequency of every STFT bin: filter_length//2 + 1 bins spanning 0 .. Nyquist.
        n_freq = (model_config['filter_length'] // 2) + 1
        bin_freqs = np.linspace(0, model_config['sampling_rate'] // 2, n_freq)
        # weight = 2**((60 - SPL(f)) / 10): 1.0 where the contour sits at 60 dB.
        # Base 2 per 10 dB presumably follows the "loudness doubles every +10 dB" rule.
        freq_weights = torch.tensor(
            [(2**(60./10)) / (2**(iso226_spl_from_freq(freq)/10)) for freq in bin_freqs])
        # AudioDataset concatenates one spectrogram per window length along the frequency
        # axis, so tile the weights once per window -> [1, n_windows*n_freq, 1].
        freq_weights = freq_weights.repeat(len(model_config['window_lengths']))[None, :, None]
        # Register as a buffer so .to(device) / .cuda() / .half() move and cast it with the
        # module. A plain tensor attribute would stay on the CPU and break forward() on GPU.
        self.register_buffer('freq_weights', freq_weights)

        self.loud_loss_priority_str = model_config['loud_loss_priority']

    def forward(self, pred_spect, gt_spect):
        """Return the mean frequency-weighted squared error as a scalar tensor.

        Args:
            pred_spect, gt_spect: spectrograms of identical shape. Their frequency axis
                is dim -2 and must match the weights, e.g. [n_windows*n_freq, T]
                or [B, n_windows*n_freq, T].
        """
        if self.loud_loss_priority_str > 0:
            # Warp both inputs so errors in quiet regions count for less.
            pred_spect = vol_rescale_loss(pred_spect, power=self.loud_loss_priority_str)
            gt_spect = vol_rescale_loss(gt_spect, power=self.loud_loss_priority_str)

        sq_err = F.mse_loss(pred_spect, gt_spect, reduction='none')  # element-wise squared error
        weighted = sq_err * self.freq_weights  # [..., n_freq_total, T] * [1, n_freq_total, 1]
        return weighted.mean()

In [47]:
# Spectrogram / loss settings shared by AudioDataset and LossFunction.
model_config = dict(
    sampling_rate=48000,          # Hz; used to place LossFunction's per-bin frequency weights
    filter_length=2400,           # FFT size -> 2400//2 + 1 = 1201 frequency bins per window
    hop_length=600,               # samples between frames (12.5 ms at 48 kHz)
    window_lengths=[2400, 1200],  # one STFT per entry (50 ms and 25 ms at 48 kHz)
    loud_loss_priority=0.0,       # vol_rescale_loss strength; 0 disables the warp
)

In [48]:
# Both objects come from the same config, so the loss's per-bin weights line up
# with the spectrogram's frequency axis.
criterion = LossFunction(model_config)
data_loader = AudioDataset(model_config)  # indexed by wav path: data_loader[path]

In [49]:
# Score every wav under the source directory against the ground-truth recording using
# the frequency-weighted spectrogram MSE (lower is better). The glob also matches the
# ground-truth file itself, which acts as a sanity check: it should score 0.
gt_path = r"G:\TwiBot\wavegrad\src\00_04_12_Twilight_Neutral__The first performance was so full of energy, so highly charged, That magical lightning showered down on the crowd.wav"
pred_paths = glob(r"G:\TwiBot\wavegrad\src\**\*.wav", recursive=True)

mean_best_vals = []
gt_spec = data_loader[gt_path]
for pred_path in pred_paths:
    pred_spec = data_loader[pred_path]
    # Compare only the overlapping frames. Crop the reference into a separate variable
    # so the full-length gt_spec is reused for every prediction; reassigning gt_spec
    # here would shrink it permanently and make scores depend on glob order.
    min_length = min(pred_spec.shape[1], gt_spec.shape[1])
    pred_spec = pred_spec[:, :min_length]
    gt_spec_crop = gt_spec[:, :min_length]
    MAE = criterion(pred_spec, gt_spec_crop).item()  # NB: weighted *squared* error despite the name
    print(f'{MAE:.3f} : {pred_path}')
    mean_best_vals.append(MAE)

model_perfs = dict(zip(pred_paths, mean_best_vals))
sorted_model_perfs = dict(sorted(model_perfs.items(), key=lambda item: item[1]))

0.000 : G:\TwiBot\wavegrad\src\00_04_12_Twilight_Neutral__The first performance was so full of energy, so highly charged, That magical lightning showered down on the crowd.wav
1.952 : G:\TwiBot\wavegrad\src\output-68736.wav
1.616 : G:\TwiBot\wavegrad\src\output-74464.wav
1.863 : G:\TwiBot\wavegrad\src\output-80192.wav
1.708 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_103086.wav
1.625 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_106904.wav
2.616 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_110722.wav
3.205 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_11454.wav
1.814 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_114540.wav
1.140 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_343620.wav
1.734 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_345529.wav
1.763 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_347438.wav
1.427 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_349347.wav
1.181 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_351256.wav
1.187 : G:\TwiBot\wavegrad\src\output_dist_fp16_3_353165.wav
1.448 : G:

In [50]:
# One line per file, best (lowest loss) first: "<loss> | <path>".
ranking_lines = (f'{perf:.4f} | {path:40}' for path, perf in sorted_model_perfs.items())
print('\n'.join(ranking_lines))

0.0000 | G:\TwiBot\wavegrad\src\00_04_12_Twilight_Neutral__The first performance was so full of energy, so highly charged, That magical lightning showered down on the crowd.wav
1.0236 | G:\TwiBot\wavegrad\src\output_dist_fp16_5_450524.wav
1.0815 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_423798.wav
1.0880 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_383709.wav
1.1022 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_372255.wav
1.1089 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_412344.wav
1.1165 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_377982.wav
1.1178 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_335984.wav
1.1219 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_404708.wav
1.1231 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_402799.wav
1.1241 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_370346.wav
1.1395 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_343620.wav
1.1411 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_334075.wav
1.1460 | G:\TwiBot\wavegrad\src\output_dist_fp16_3_318803.wav
1.1493 | G:\TwiBo