# Multi-Band Auto-Encoder
### This Network should learn to compress the information from multiple 2400 channel spectrograms of different window lengths into a low dimensional latent space.

### - This latent space can then be used to train a `Text -> Feature -> Vocoder` setup in place of typical Mel-Spectrograms.

### - The latent space should contain multiple windows' worth of information, and potentially encode the types of noise occurring in the frame.

---

- ### Updated to add ISO 226 volume scaling for features (the model was focusing far too much on the 11 kHz+ channels).

---

# -1 - Install ISO226

In [None]:
try:
    import iso226
except:
    !git clone https://github.com/jacobbaylesssmc/iso226
    !cd iso226; python3 -m pip install ./

# 0 - Import Dependencies/Modules


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
os.environ["LRU_CACHE_CAPACITY"] = "3"
import random

from CookieTTS.utils.dataset.utils import load_wav_to_torch, load_filepaths_and_text

---

# 1 - Load Dataset

In [None]:
import os
import random
from glob import glob
from CookieTTS.utils.dataset.utils import load_wav_to_torch
from CookieTTS.utils.audio.stft import STFT
from CookieTTS.utils.audio.audio_processing import window_sumsquare, dynamic_range_compression, dynamic_range_decompression

class AudioDataset(torch.utils.data.Dataset):
    """Dataset that yields log-compressed multi-resolution spectrograms of audio files.

    Every file below `directory` matching `wildcard_filter` (searched recursively)
    becomes one item. An item is loaded, peak-normalised, randomly cropped to at
    most `max_len_s` seconds and passed through one STFT per configured window
    length; the per-STFT outputs are concatenated along the channel axis.

    Args:
        model_config: dict providing 'max_len_s', 'window_lengths', 'hop_length'
            and 'filter_length'.
        gpu_dataloading: when truthy, the STFT modules and the audio are moved to
            CUDA before the spectrograms are computed.
        directory: root directory searched for audio files.
        wildcard_filter: glob pattern used to select files.
    """
    def __init__(self, model_config, gpu_dataloading, directory, wildcard_filter = "*.wav"):
        self.gpu_dataloading = gpu_dataloading
        print("Finding Audio Files...")
        # glob() returns paths in filesystem order, which differs between machines
        # and runs. Sorting first makes the seeded shuffle below reproducible.
        self.audio_files = sorted(glob( os.path.join(directory, "**", wildcard_filter), recursive=True))
        print("Done")
        
        random.seed(1234)
        random.shuffle(self.audio_files)
        self.len = len(self.audio_files)
        
        self.max_len_s = model_config['max_len_s']
        self.win_lens = model_config['window_lengths']
        self.hop_len = model_config['hop_length']
        self.fil_len = model_config['filter_length']
        
        # One STFT per window length; all share the same filter and hop length
        # so their outputs can be concatenated channel-wise.
        self.stfts = []
        for win_len in self.win_lens:
            stft = STFT(filter_length=self.fil_len,
                           hop_length=self.hop_len,
                           win_length=win_len,)
            stft = stft.cuda() if self.gpu_dataloading else stft
            self.stfts.append(stft)
        
        self.directory = directory
        self.wildcard_filter = wildcard_filter
    
    def get_mel(self, audio):
        """Take audio and convert to multi-res spectrogram.

        Runs every STFT in `self.stfts` on `audio` and concatenates the results
        along dim 0.
        """
        melspec = []
        for stft in self.stfts:
            spect = stft.transform(audio.unsqueeze(0), return_phase=False)[0].squeeze(0)# -> [n_mel, dec_T]
            melspec.append(spect)
        return torch.cat(melspec, dim=0)# [[n_mel, dec_T], ...] -> [n_stft*n_mel, dec_T]
    
    def __getitem__(self, index):
        """Return `(spect, spect_length)` for the file at `index`.

        `spect_length` is `spect.shape[1]`, the number of spectrogram frames.
        """
        audio_path = self.audio_files[index]
        audio, sampling_rate, max_mag = load_wav_to_torch(audio_path) # load mono audio from file
        audio = audio / max_mag # normalize to range [-1, 1]
        
        # Randomly crop clips that are longer than the configured maximum.
        max_samples = int(self.max_len_s*sampling_rate)
        if audio.shape[0] > max_samples:
            max_start = audio.shape[0] - max_samples
            start = int(torch.rand(1)*max_start) # plain int offset, truncated toward zero
            audio = audio[start:start+max_samples]
        
        audio = audio.cuda() if self.gpu_dataloading else audio
        
        spect = self.get_mel(audio)
        spect = dynamic_range_compression(spect)
        
        spect_length = spect.shape[1]
        return (spect, spect_length)
    
    def __len__(self):
        """Number of audio files found at construction time."""
        return self.len

In [None]:
class MelCollate():
    """Collate dataset items into one long sequence instead of a padded batch.

    Each item is `(spect, spect_length)` with `spect` shaped [n_stft*n_mel, dec_T].
    The spectrograms are concatenated along the time axis, so the result always
    has batch size 1 and needs no padding.
    """
    def __init__(self):
        pass
    
    def __call__(self, batch):
        """Return `(b_spect, spect_lengths)` for a list of dataset items.

        b_spect:       [1, n_stft*n_mel, sum(dec_T)]
        spect_lengths: 1-element tensor holding sum(dec_T)
        """
        spects = [item[0] for item in batch]
        # Only the total length is needed. The previous `max(*lengths)` was unused
        # and raised TypeError when the batch held a single item.
        total_length = sum(spect.shape[1] for spect in spects)
        b_spect = torch.cat(spects, dim=1).unsqueeze(0)# [1, n_stft*n_mel, sum(dec_T)]
        
        spect_lengths = torch.tensor([total_length,])
        model_inputs = (b_spect, spect_lengths)
        return model_inputs

---

# 2 - Init Model

In [None]:
class TemporalBlock(nn.Module):
    """Stack of 1D convolutions with an optional residual connection.

    The first convolution maps `input_dim` channels to `n_dim` and the last one
    maps to `output_dim`; with a single layer it maps `input_dim` straight to
    `output_dim`. Each convolution is padded by `(kernel_w - 1)//2`. The
    activation follows every convolution except the last, and dropout (when
    enabled and in training mode) follows every convolution.

    Note: the default `act_func` instance is created once in the signature and
    is therefore shared by every block built without an explicit `act_func`.
    """
    def __init__(self, input_dim, output_dim, n_layers, n_dim, kernel_w, bias=True, act_func=nn.LeakyReLU(negative_slope=0.1, inplace=True), dropout=0.0, res=False):
        super().__init__()
        padding = (kernel_w - 1)//2
        convs = [
            nn.Conv1d(
                input_dim if idx == 0 else n_dim,
                n_dim if idx < n_layers - 1 else output_dim,
                kernel_w, padding=padding, bias=bias,
            )
            for idx in range(n_layers)
        ]
        self.layers = nn.ModuleList(convs)
        self.act_func = act_func
        self.dropout = dropout
        self.res = res
        if self.res:
            assert input_dim == output_dim, 'residual connection requires input_dim and output_dim to match.'
    
    def forward(self, x): # [B, in_dim, T]
        """Apply the convolution stack, adding the input back when `res` is set."""
        residual = x
        final_idx = len(self.layers) - 1
        use_dropout = self.training and self.dropout > 0.0
        
        for idx, conv in enumerate(self.layers):
            x = conv(x)
            if idx != final_idx:
                # no activation after the final convolution
                x = self.act_func(x)
            if use_dropout:
                x = F.dropout(x, p=self.dropout, training=self.training, inplace=True)
        if self.res:
            x += residual
        return x # [B, out_dim, T]


class Coder(nn.Module):
    """Sequence of TemporalBlocks mapping `input_dim` channels to `output_dim` channels.

    Used for both the encoder and the decoder of the auto-encoder.

    Args:
        model_config: dict providing 'n_blocks', 'n_dim', 'n_layers',
            'bottleneck_n_layers', 'kernel_width', 'bias', 'dropout' and
            'residual_connections'. 'filter_length', 'window_lengths' and
            'latent_dim' are only read when the matching dim argument is None.
        input_dim: input channels. Defaults to
            (filter_length//2 + 1) * len(window_lengths).
        output_dim: output channels. Defaults to model_config['latent_dim'].
        output_batchnorm: when True, a non-affine BatchNorm1d is applied to the output.
    """
    def __init__(self, model_config, input_dim=None, output_dim=None, output_batchnorm=False):
        super(Coder, self).__init__()
        self.input_dim = ((model_config['filter_length']//2) + 1) * len(model_config['window_lengths']) if input_dim is None else input_dim
        self.output_dim = model_config['latent_dim'] if output_dim is None else output_dim
        # Fallback device, only used when the module owns no parameters or buffers.
        self.device = "cuda"
        
        self.temporalblocks = nn.ModuleList()
        for i in range(model_config['n_blocks']):
            b_first_block = bool(i == 0)
            b_last_block = bool(i+1 == model_config['n_blocks'])
            in_dim  = self.input_dim if b_first_block else model_config['n_dim']
            out_dim = self.output_dim if b_last_block else model_config['n_dim']
            # A residual connection is only possible when the channel count is preserved.
            res = bool(model_config['residual_connections'] and in_dim == out_dim)
            # First and last blocks change the channel count and use their own depth.
            n_layers = model_config['bottleneck_n_layers'] if b_first_block or b_last_block else model_config['n_layers']
            temp_block = TemporalBlock(in_dim, out_dim, n_layers, model_config['n_dim'],
                                       model_config['kernel_width'], bias = model_config['bias'],
                                       dropout = model_config['dropout'], res=res)
            self.temporalblocks.append(temp_block)
        
        if output_batchnorm:
            self.bn_out = nn.BatchNorm1d(self.output_dim, momentum=0.05, affine=False)
    
    def _input_device(self):
        """Device the input must live on: wherever this module's weights are."""
        for tensor in self.parameters():
            return tensor.device
        for tensor in self.buffers():
            return tensor.device
        return self.device
    
    def forward(self, spect):
        """Run `spect` ([B, input_dim, T]) through all blocks -> [B, output_dim, T]."""
        assert spect.shape[1] == self.input_dim, f'input Tensor is wrong shape ({spect.shape}). Expected {self.input_dim} channels.'
        # Follow the module's own device rather than a hard-coded "cuda", so the
        # coder also works on CPU or on a non-default GPU.
        spect = spect.to(self._input_device())
        
        for block in self.temporalblocks:
            spect = block(spect)
        
        spect = spect.clone()
        
        if hasattr(self, 'bn_out'):
            spect = self.bn_out(spect)
        return spect


class AutoEncoder(nn.Module):
    """Spectrogram auto-encoder: an encoder Coder followed by a decoder Coder.

    The encoder maps (filter_length//2 + 1) * len(window_lengths) spectrogram
    channels to `latent_dim` channels (with output batch-norm); the decoder maps
    them back.

    `forward` works on pre-computed spectrograms. The `encode_*` helpers work on
    raw audio and need `self.stfts` to be filled by the caller with one STFT
    module per window length.
    """
    def __init__(self, model_config):
        super(AutoEncoder, self).__init__()
        self.in_out_dim = ((model_config['filter_length']//2) + 1) * len(model_config['window_lengths'])
        self.latent_dim = model_config['latent_dim']
        
        self.encoder = Coder(model_config, self.in_out_dim, self.latent_dim, output_batchnorm=True)
        self.decoder = Coder(model_config, self.latent_dim, self.in_out_dim)
        
        # STFT front-ends for get_specs(); empty until the caller assigns them.
        self.stfts = []
    
    def _module_device(self):
        """Device of this model's weights, or None when it owns no tensors."""
        for tensor in self.parameters():
            return tensor.device
        for tensor in self.buffers():
            return tensor.device
        return None
    
    def get_specs(self, audio):
        """Convert a [T] audio tensor into a compressed multi-resolution spectrogram.

        Raises:
            RuntimeError: if `self.stfts` has not been populated.
        """
        stfts = getattr(self, 'stfts', None)
        if not stfts:
            raise RuntimeError("AutoEncoder.stfts is empty; assign one STFT module per window length before encoding audio.")
        spects = []
        for stft in stfts:
            spect = stft.transform(audio.unsqueeze(0), return_phase=False)[0].squeeze(0)# -> [n_mel, dec_T]
            spects.append(spect)
        spect = torch.cat(spects, dim=0)# [[n_mel, dec_T], ...] -> [n_stft*n_mel, T//hop_len]
        spect = dynamic_range_compression(spect)# change to clamped log-scale magnitudes
        return spect.unsqueeze(0)# -> [1, n_stft*n_mel, T//hop_len]
    
    def encode_audiopath(self, audio_path):
        """Load the audio file at `audio_path` and return its latent representation."""
        audio, sampling_rate, max_mag = load_wav_to_torch(audio_path) # load mono audio from file
        audio = audio / max_mag # normalize to range [-1, 1]
        return self.encode_audio(audio)
    
    def encode_audio(self, audio):
        """Encode a [T] audio tensor into the learned latent representation Z."""
        # Put the audio on the model's device (previously a hard-coded .cuda()).
        # NOTE(review): the STFT modules in self.stfts must be on that device too.
        device = self._module_device()
        audio = audio.cuda() if device is None else audio.to(device)
        spect = self.get_specs(audio)# -> [B, n_stfts*n_fft, T]
        z = self.encoder(spect)# -> [B, z_dim, T]
        return z
    
    def forward(self, inputs):
        """Reconstruct spectrograms.

        Args:
            inputs: `(spect, spect_lengths)`; only `spect` is used here.

        Returns:
            A clone of the decoder output.
        """
        spect, spect_lengths = inputs
        
        z = self.encoder(spect)
        
        rec_spect = self.decoder(z)
        
        return rec_spect.clone()

In [None]:
import torch
import numpy as np
import iso226
import math
from CookieTTS.utils.model.utils import get_mask_from_lengths

# https://www.desmos.com/calculator/4nac7kvt7p
# Squash smaller values together so that mse loss is lower on quieter parts of the spectrogram.
def vol_rescale_loss(mel, power=0.5, min=-11.55):
    """Add a quadratic term to `mel`: mel + power/(-2*min) * mel**2.

    `power` scales the quadratic term; `min` is the reference (most negative)
    value that normalises it.
    """
    quadratic_gain = power/(-min*2)
    return mel + quadratic_gain*(mel**2)


class LossFunction(nn.Module):
    """Frequency-weighted, length-masked mean squared error between spectrograms.

    Each frequency bin is weighted by 2**(60/10) / 2**(SPL(freq)/10), where
    SPL(freq) comes from the ISO 226 interpolator built with L_N=60. A bin has
    weight 1 where the contour sits at 60 dB and half that for every extra 10 dB.
    The per-bin weights are tiled once per STFT window length.

    Args:
        model_config: dict providing 'sampling_rate', 'filter_length',
            'window_lengths' and 'loud_loss_priority'.
    """
    def __init__(self, model_config):
        super(LossFunction, self).__init__()
        iso226_spl_from_freq = iso226.iso226_spl_itpl(L_N=60, hfe=True)# get InterpolatedUnivariateSpline for Perc Sound Pressure Level at Difference Frequencies with 60DB ref.
        n_bins = (model_config['filter_length']//2)+1
        bin_freqs = np.linspace(0, model_config['sampling_rate']//2, n_bins)
        freq_weights = torch.tensor([(2**(60./10))/(2**(iso226_spl_from_freq(freq)/10)) for freq in bin_freqs])
        # Keep the weights on the GPU when there is one (as before), but do not
        # crash on CPU-only machines.
        if torch.cuda.is_available():
            freq_weights = freq_weights.cuda()
        self.freq_weights = freq_weights.repeat(len(model_config['window_lengths']))[None, :, None]# [1, n_mel, 1], broadcast over [B, n_mel, T]
        
        self.loud_loss_priority_str = model_config['loud_loss_priority']
    
    def forward(self, y, x):
        """Return the scalar loss.

        Args:
            y: predicted spectrogram, [B, n_mel, T].
            x: `(gt_spect, lengths)` as produced by the collate function.
        """
        gt_spect, lengths = x
        pred_spect = y
        # Everything follows the prediction's device instead of a hard-coded .cuda().
        device = pred_spect.device
        gt_spect = gt_spect.to(device)
        
        mask = get_mask_from_lengths(lengths.to(device))
        mask = mask.expand(gt_spect.size(1), mask.size(0), mask.size(1))
        mask = mask.permute(1, 0, 2)# -> [B, n_mel, T]
        
        if self.loud_loss_priority_str > 0:
            pred_spect = vol_rescale_loss(pred_spect, power=self.loud_loss_priority_str)
            gt_spect = vol_rescale_loss(gt_spect, power=self.loud_loss_priority_str)
        
        # Squared error (the local used to be misnamed "MAE").
        sq_err = F.mse_loss(pred_spect, gt_spect, reduction='none')
        sq_err = sq_err * self.freq_weights.to(device)# [B, n_mel, T] * [1, n_mel, 1]
        sq_err = torch.masked_select(sq_err, mask)# [B, n_mel, T] -> [n_mel*sum(n_frames)]
        
        return sq_err.mean()

# 2.9 - Plot Data

In [None]:
import matplotlib
%matplotlib inline
import matplotlib.pylab as plt
import IPython.display as ipd

def plot_data(data, title=None, figsize=(20, 7.5), range_=[-11.6, 2.0]):
    """
    data: list([height, width], [height, width], ...)
    """
    #for i in range(len(data)):
    #    data[i][0,0] = range_[0]
    #    data[i][0,1] = range_[1]
    fig, axes = plt.subplots(1, len(data), figsize=figsize)
    for i in range(len(data)):
        if title:
            axes[i].set_title(title[i])
        axes[i].imshow(data[i], aspect='auto', origin='bottom', 
                       interpolation='none')
    plt.show()
    %matplotlib inline

---

# 3 - Train and Eval

Config

---
```
----- Previous Models -----
AEF1 - 160 Channels with 12*5 Coder Layers, Learned Mean/STD
AEF4 - 160 Channels with  3*1 Coder Layers, Learned Mean/STD
AEF5 - 512 Channels with  3*3 Coder Layers, Zero Mean Unit Variance
AEF6 - 256 Channels with  3*3 Coder Layers, Zero Mean Unit Variance
AEF7 - 128 Channels with  3*3 Coder Layers, Zero Mean Unit Variance
AEF8 - 192 Channels with  3*3 Coder Layers, Zero Mean Unit Variance
```
---

In [None]:
model_config = {
    # --- run identity / data ---
    "model_name": "AEF8",  # used in checkpoint file names
    "audio_directory": "/media/cookie/Samsung 860 QVO/ClipperDatasetV2",  # alt: r"H:\ClipperDatasetV2\SlicedDialogue"
    "wildcard_filter": "*.wav",
    "batch_size": 8,
    "max_len_s": 6.0,  # longer clips are randomly cropped to this many seconds
    # --- optimisation ---
    "learning_rate": 5e-5,  # initial Adam lr; the training loop overwrites it every iteration
    "latent_dim": 192,
    # Squash smaller values so the model prioritises louder parts of the spectrogram.
    # 0.0 = off, 1.0 = the quietest parts contribute nearly 0.0 loss.
    "loud_loss_priority": 0.1,
    # --- STFT front-end ---
    "sampling_rate": 48000,
    "window_lengths": [600, 1200, 2400],
    "hop_length": 600,
    "filter_length": 2400,
    # --- Coder architecture ---
    "n_blocks": 1,  # previously 3
    "n_layers": 3,
    "bottleneck_n_layers": 1,
    "n_dim": 256,
    "kernel_width": 1,
    "residual_connections": True,
    "bias": True,
    "dropout": 0.0,
}

# Compute spectrograms on the GPU inside the dataset.
gpu_dataloading = True

---

The rest

In [None]:
from torch.utils.data import DataLoader

# Number of passes over the dataset made by the training loop.
n_epochs = 200

# Index every matching audio file below the configured directory.
dataset = AudioDataset(
    model_config,
    gpu_dataloading,
    directory=model_config['audio_directory'],
    wildcard_filter=model_config['wildcard_filter'],
)

In [None]:
# Initialize Training
collate_fn = MelCollate()
train_loader = DataLoader(
    dataset,
    batch_size=model_config['batch_size'],
    shuffle=True,
    drop_last=True,
    pin_memory=False,
    # No worker processes when items are built on the GPU
    # (presumably because CUDA work inside workers is problematic; verify).
    num_workers=0 if gpu_dataloading else 12,
    collate_fn=collate_fn,
)

# Loss used by the training loop.
criterion = LossFunction(model_config)

In [None]:
# Initialize/Reset Model
model = AutoEncoder(model_config)
model = model.cuda()

optimizer = torch.optim.Adam(params=model.parameters(), lr=model_config['learning_rate'])

# Training state: global step counter and running statistics.
iteration, avg_improvement, avg_training_loss = 0, 0.0, 2.0

In [None]:
#checkpoint_path = "MelAutoEncoder_50000_AEF7.pt"
#checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
#model.load_state_dict(checkpoint_dict['model'])
#optimizer.load_state_dict(checkpoint_dict['optimizer'])
#iteration = checkpoint_dict['iteration']
#model_config = checkpoint_dict['model_config']

In [None]:
model.train()

for epoch in range(n_epochs):
    print(f"Epoch: {epoch}")
    for i, batch in enumerate(train_loader):
        # Step-wise learning-rate decay keyed on the global iteration count.
        if iteration > 120000:
            learning_rate = 0.0625e-4
        elif iteration > 110000:
            learning_rate = 0.125e-4
        elif iteration > 100000:
            learning_rate = 0.25e-4
        elif iteration > 75000:
            learning_rate = 0.5e-4
        elif iteration > 50000:
            learning_rate = 1e-4
        else:
            learning_rate = 2e-4
        
        # Global scale applied on top of the schedule.
        learning_rate = learning_rate * 0.74
        
        for param_group in optimizer.param_groups:
            param_group.update(lr=learning_rate)
        
        # One optimisation step.
        model.zero_grad()
        outputs = model(batch)
        loss = criterion(outputs, batch)
        reduced_loss = loss.item()
        loss.backward()
        optimizer.step()
        
        # Exponential moving average of the loss, for logging only.
        avg_training_loss = avg_training_loss*0.99 + reduced_loss*(1-0.99)
        if iteration % 100:
            print(".", end='')
        else:
            print(f"\n[Iter {iteration:<6}] [Training Loss {reduced_loss:5.3f} Avg {avg_training_loss:5.3f}]", end='')
        
        # Every 10k iterations: show the first 400 frames of target vs reconstruction.
        if iteration % 10000 == 0:
            plot_data(
                [spect[:, :400].float().cpu().detach().numpy() for spect in (batch[0][0], outputs[0])],
                figsize=(24, 48),
            )
        
        # Every 25k iterations (including iteration 0): write a checkpoint.
        if iteration % 25000 == 0:
            filepath = "/media/cookie/Samsung PM961/TwiBot/CookiePPPTTS/CookieTTS/scripts/MelAutoEncoder"+f"_{iteration}_{model_config['model_name']}.pt"
            saving_dict = {
                'model': model.state_dict(),
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate,
                'model_config': model_config,
            }
            torch.save(saving_dict, filepath)
        iteration += 1

In [None]:
# save model
# Windows alternative:
#filepath = r"G:\TwiBot\CookiePPPTTS\CookieTTS\scripts\MelAutoEncoder"+f"_{iteration}.pt"
filepath = "/media/cookie/Samsung PM961/TwiBot/CookiePPPTTS/CookieTTS/scripts/MelAutoEncoder"+f"_{iteration}_{model_config['model_name']}.pt"

print(f"Saving checkpoint to '{filepath}'")
# Everything needed to resume training or rebuild the model for inference.
saving_dict = dict(
    model=model.state_dict(),
    iteration=iteration,
    optimizer=optimizer.state_dict(),
    learning_rate=learning_rate,
    model_config=model_config,
)
torch.save(saving_dict, filepath)
print("Done")

---

# 4 - Convert Dataset to new Latent features.

#### `.wav` -> `.npy`

```
AEF1 - 160 Channels with 12*5 Coder Layers, Learned Mean/STD
AEF4 - 160 Channels with  3*1 Coder Layers, Learned Mean/STD
AEF5 - 512 Channels with  3*3 Coder Layers, Zero Mean Unit Variance
AEF6 - 256 Channels with  3*3 Coder Layers, Zero Mean Unit Variance
AEF7 - 128 Channels with  3*3 Coder Layers, Zero Mean Unit Variance
AEF8 - 192 Channels with  3*3 Coder Layers, Zero Mean Unit Variance

(edited)
```

In [None]:
checkpoint_path = "MelAutoEncoder_150000_AEF6.pt"
# Load onto the CPU first; the weights move to the GPU with the model below.
checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
model_config = checkpoint_dict['model_config']
iteration = checkpoint_dict['iteration']

# load model
# Rebuild the architecture from the checkpoint's own config, then restore weights.
model = AutoEncoder(model_config)
model = model.cuda()
model.load_state_dict(checkpoint_dict['model'])
_ = model.eval()

In [None]:
hop_len = 600
# update as needed!

# One CUDA STFT per window length, attached to the model for encode_audio().
model.stfts = [
    STFT(
        filter_length=model_config['filter_length'],
        hop_length=hop_len,
        win_length=window_length,
    ).cuda()
    for window_length in model_config['window_lengths']
]

In [None]:
from glob import glob
directory = "/media/cookie/Samsung 860 QVO/ClipperDatasetV2"

# Encode every .wav below `directory` and save the latent next to it as .npy.
with torch.no_grad():
    audiopaths = glob( os.path.join(directory, "**", "*.wav"), recursive=True)
    len_audiopaths = len(audiopaths)
    for i, audiopath in enumerate(audiopaths):
        latent_z = model.encode_audiopath(audiopath).squeeze(0)
        print(f'{i:6}/{len_audiopaths:<6} {latent_z.shape}', end='\r')
        # Swap only the file extension. str.replace('.wav', '.npy') would also
        # rewrite any '.wav' that appears in a directory name along the path.
        new_save_path = os.path.splitext(audiopath)[0] + '.npy'
        #print(torch.from_numpy(np.load(new_save_path)).shape)
        np.save(new_save_path, latent_z.data.squeeze().float().cpu().numpy())
    print("\nDone!")