In [23]:
try:
    import iso226
except:
    !git clone https://github.com/jacobbaylesssmc/iso226
    !cd iso226; python3 -m pip install ./

In [6]:
# Scratch check of scientific-notation formatting with two decimals.
formatted_value = f'{1e-4:.2e}'
print(formatted_value)

1.00e-04


In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from CookieTTS.utils.audio.stft import STFT

class ISO_226(torch.nn.Module):
    """Frequency-dependent gain filter derived from ISO 226 loudness contours.

    The audio is transformed with an STFT, every frequency bin is multiplied
    by a fixed per-bin weight, and the result is transformed back.
    `forward` applies the weights, `inverse` applies their reciprocal
    (with very small weights clamped, see `__init__`).

    Args:
        sampling_rate: sample rate in Hz; bins are spaced from 0 to
            `sampling_rate//2`.
        filter_length: STFT filter length; `filter_length//2 + 1` weights
            are generated.
        hop_length: STFT hop length.
        win_length: STFT window length.
        stft_device: device on which the STFT and the weights live.
    """

    def __init__(self, sampling_rate=48000, filter_length=2400, hop_length=600, win_length=2400, stft_device='cpu'):
        super(ISO_226, self).__init__()
        self.stft_device = stft_device
        self.stft = STFT(filter_length=filter_length,
                         hop_length=hop_length,
                         win_length=win_length).to(device=self.stft_device)

        # Interpolator giving the sound pressure level (dB) at a frequency for
        # the 60 dB reference contour.
        iso226_spl_from_freq = iso226.iso226_spl_itpl(L_N=60, hfe=True)
        # Per-bin weight: 10**(60/10) / 10**(SPL(f)/10), one value per STFT bin.
        self.freq_weights = torch.tensor([(10**(60./10))/(10**(iso226_spl_from_freq(freq)/10)) for freq in np.linspace(0, sampling_rate//2, (filter_length//2)+1)])
        self.freq_weights = self.freq_weights.to(self.stft_device)[None, :, None]# [1, n_freq, 1]
        # Bins whose weight is below 0.008 are not boosted back on inversion:
        # they are swapped for 1e5 so their inverse weight becomes 1e-5.
        freq_weights = self.freq_weights.clone()
        freq_weights[freq_weights<0.008] = 1e5
        self.inv_freq_weights = 1/freq_weights

    def _apply_weights(self, in_audio, weights):
        """Scale the STFT of `in_audio` by `weights` and resynthesise it.

        Args:
            in_audio: audio tensor, on any device.
                NOTE(review): presumably shaped [B, T] -- confirm against STFT.
            weights: tensor broadcastable against the first output of
                `self.stft.transform`.

        Returns:
            The filtered audio, on the same device as `in_audio`.
        """
        with torch.no_grad():
            audio = in_audio.to(self.stft_device).float()
            # presumably magnitudes and phases -- confirm against STFT
            audio_spec, audio_angles = self.stft.transform(audio)
            audio_spec *= weights
            audio = self.stft.inverse(audio_spec, audio_angles).squeeze(1)
            # Tensor.to is not in-place: the result has to be returned for the
            # output to actually land on the caller's device.
            return audio.to(in_audio.device)

    def forward(self, in_audio):
        """Apply the per-bin weights to `in_audio`; gradients are not tracked."""
        return self._apply_weights(in_audio, self.freq_weights)

    def inverse(self, in_audio):
        """Apply the reciprocal per-bin weights to `in_audio`; gradients are not tracked."""
        return self._apply_weights(in_audio, self.inv_freq_weights)

In [25]:
from CookieTTS.utils.dataset.utils import load_wav_to_torch
import IPython.display as ipd

In [26]:
# Sample rate (Hz) used both to build the filter and to play audio back.
sampling_rate = 48000
# Run the STFT on the GPU when one is present, otherwise fall back to the CPU
# so the notebook still runs on machines without CUDA.
stft_device = 'cuda' if torch.cuda.is_available() else 'cpu'
iso_226 = ISO_226(sampling_rate=sampling_rate, filter_length=1200, hop_length=24, win_length=1200, stft_device=stft_device)

In [27]:
audio_path = r"H:\TTCheckpoints\waveflow\4thLargeKernels\AR_8_Flow_AEF4.1\samples\Ground Truth\NAN_751.wav"

# Load the clip and add a leading batch axis, then play it back untouched.
# NOTE(review): playback uses the notebook-level `sampling_rate`, not the `sr`
# returned by the loader -- confirm the file really is at that rate.
audio, sr, max_mag = load_wav_to_torch(audio_path)
audio = audio[None].repeat(1, 1)
ipd.display(ipd.Audio(audio[0].cpu().numpy(), rate=sampling_rate))

# Run the filter, then its inverse, listening to the result after each stage.
for apply_stage in (iso_226, iso_226.inverse):
    audio = apply_stage(audio)
    ipd.display(ipd.Audio(audio[0].cpu().numpy(), rate=sampling_rate))

In [28]:
import time

# Average wall-clock seconds per `iso_226.inverse` call.
n_runs = 100

# CUDA kernels are launched asynchronously; synchronise before starting and
# before stopping the clock so the measurement covers the actual GPU work.
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(n_runs):
    # Keep the result in its own name: feeding the output back into `audio`
    # would re-filter it 100 times and make this cell non-repeatable.
    benchmark_out = iso_226.inverse(audio)
if torch.cuda.is_available():
    torch.cuda.synchronize()
end = time.perf_counter()
print((end-start)/n_runs)

0.04198033094406128


In [29]:
# Print the ISO 226 curve (SPL/10 per STFT bin) and the derived per-bin weights
# in plain decimal notation.
torch.set_printoptions(sci_mode=False)
try:
    filter_length = 1200
    sampling_rate = 48000
    # Interpolator giving the sound pressure level (dB) at a frequency for the
    # 60 dB reference contour.
    iso226_spl_from_freq = iso226.iso226_spl_itpl(L_N=60, hfe=True)
    # Evaluate the curve once and reuse it for both printed tensors.
    bin_freqs = np.linspace(0, sampling_rate//2, (filter_length//2)+1)
    spl_over_10 = [iso226_spl_from_freq(freq)/10 for freq in bin_freqs]
    print(
        torch.tensor(spl_over_10),
        torch.tensor([(10**(60./10))/(10**value) for value in spl_over_10]),
        sep='\n\n'
    )
finally:
    # Switch scientific notation back on even if the cell fails part-way.
    torch.set_printoptions(sci_mode=True)

tensor([14.9858,  9.4177,  8.2053,  7.6105,  7.2474,  6.9864,  6.7938,  6.6454,
         6.5254,  6.4260,  6.3451,  6.2807,  6.2283,  6.1831,  6.1424,  6.1060,
         6.0739,  6.0460,  6.0224,  6.0033,  5.9887,  5.9787,  5.9742,  5.9759,
         5.9846,  6.0012,  6.0259,  6.0572,  6.0931,  6.1313,  6.1699,  6.2068,
         6.2400,  6.2686,  6.2925,  6.3112,  6.3246,  6.3323,  6.3342,  6.3298,
         6.3189,  6.3015,  6.2783,  6.2503,  6.2182,  6.1832,  6.1462,  6.1080,
         6.0697,  6.0321,  5.9962,  5.9627,  5.9318,  5.9032,  5.8769,  5.8527,
         5.8305,  5.8102,  5.7915,  5.7745,  5.7589,  5.7447,  5.7316,  5.7197,
         5.7087,  5.6986,  5.6895,  5.6814,  5.6740,  5.6675,  5.6619,  5.6570,
         5.6529,  5.6494,  5.6467,  5.6447,  5.6433,  5.6424,  5.6422,  5.6425,
         5.6433,  5.6447,  5.6465,  5.6489,  5.6517,  5.6551,  5.6588,  5.6631,
         5.6678,  5.6730,  5.6786,  5.6846,  5.6911,  5.6979,  5.7052,  5.7129,
         5.7210,  5.7294,  5.7383,  5.74