In [None]:
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    from google.colab import drive, output
    drive.mount('/content/drive')
    import sys
    %cd '/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project'
    #sys.path.append('/content/drive/MyDrive/Deep Learning/AutoEncoders/Project/VQVAE_Working/data')
    #sys.path.append('/content/drive/MyDrive/Deep Learning/AutoEncoders/Project/VQVAE_Working/models')
    sys.path.append('/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project/Dataloader')
    sys.path.append('/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project/Models')
    sys.path.append('/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project/Modules')
    %load_ext autoreload
    %autoreload 1
    !pip install torchaudio
    !pip install umap
    !pip install wandb --upgrade
    !wandb login
    output.clear()

else:
    print('Not running on CoLab')


In [None]:
# from __future__ import print_function
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
# import IPython

from six.moves import xrange

import umap
import datetime
import gc

from scipy import signal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import torchaudio.transforms as audio_transform

#from ResidualStack import ResidualStack
#from Residual import Residual

from Jaguas_DataLoader import SoundscapeData
from Models import Model
from Models import Encoder
from Models import Decoder
from Models import VectorQuantizer
from Models import VectorQuantizerEMA

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# `device` is read as a module-level global by TestModel and TrainModel below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = xm.xla_device()

from datetime import timedelta
import wandb
from wandb import AlertLevel

# Release cached CUDA allocator blocks, run the Python garbage collector, then
# release again so memory freed by the collection is also returned.
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()


## Train

We use the hyperparameters from the author's code.

In [None]:
from scipy.io.wavfile import write

class TestModel:
    """Evaluation helper for a trained VQ-VAE.

    Pulls batches from an iterator, reconstructs them with the model, and
    offers plotting / waveform-export utilities for originals and
    reconstructions.

    Args:
        model: Object exposing `_encoder`, `_pre_vq_conv`, `_vq_vae` and
            `_decoder` callables (see `reconstruct`).
        iterator: Iterator yielding 3-tuples whose first item is the
            spectrogram batch.
        num_views: Default number of examples shown by `plot_reconstructions`.
    """

    def __init__(self, model, iterator, num_views):
        self._model = model
        self._iterator = iterator
        self.num_views = num_views

    @staticmethod
    def _to_int16(wave):
        """Peak-normalise `wave` and convert it to 16-bit PCM samples.

        An all-zero (or empty) signal is returned as zeros instead of dividing
        by a zero peak, which would produce NaNs and an undefined integer cast.
        """
        wave = np.asarray(wave)
        peak = np.max(np.abs(wave)) if wave.size else 0
        if peak == 0:
            return np.zeros(wave.shape, dtype=np.int16)
        return np.int16(wave / peak * 32767)

    @staticmethod
    def _wav_prefix(directory, prefix):
        """Return the file-name prefix for exported WAVs, inside `directory` if given."""
        if directory is None:
            return prefix
        return os.path.join(directory, prefix)

    def plot_waveform(self, waveform, n_rows=2, directory=None, sample_rate=22050):
        """Plot one waveform per row and optionally export each as a WAV file.

        Args:
            waveform: Array indexed as `waveform[i, 0]` -> 1-D signal.
            n_rows: Minimum number of subplot rows; grown automatically so
                every waveform gets its own row.
            directory: File-name prefix for the exported WAVs
                (`<directory><i>.wav`). Nothing is written when None.
            sample_rate: Sample rate written to the WAV header.
        """
        n_rows = max(n_rows, len(waveform), 1)
        # squeeze=False keeps `axs` 2-D even for a single row, so indexing
        # below is uniform.
        fig, axs = plt.subplots(n_rows, figsize=(10, 6), squeeze=False)
        for i in range(len(waveform)):
            axs[i, 0].plot(waveform[i, 0])
            if directory is not None:
                write(directory + str(i) + '.wav', sample_rate,
                      self._to_int16(waveform[i, 0]))
        plt.show()

    def waveform_generator(self, spec, n_fft=1028, win_length=1028, base_win=256, plot=False):
        """Invert a spectrogram batch to time-domain waveforms (NumPy array).

        The spectrogram is cast to complex128 before inversion. `base_win` and
        `plot` are accepted for backward compatibility but are not used: the
        previous implementation derived a hop length from `base_win` and never
        passed it on, so the transform's default hop length has always been in
        effect.
        """
        # detach first so the inversion is not tracked by autograd.
        spec = spec.detach().to("cpu").cdouble()
        transformation = audio_transform.InverseSpectrogram(n_fft=n_fft, win_length=win_length)
        waveform = transformation(spec)
        return waveform.cpu().detach().numpy()

    def plot_psd(self, waveform):
        """Overlay the power spectral density of every entry in `waveform`."""
        for wave in waveform:
            # plt.psd works on 1-D signals; flattening is a no-op for 1-D input
            # and lets batched (multi-dimensional) arrays be plotted as well.
            plt.psd(np.ravel(wave))

    def plot_reconstructions(self, imgs_original, imgs_reconstruction, num_views: int = None):
        """Show originals (top row) above reconstructions (bottom row).

        Args:
            imgs_original: Batch of original images.
            imgs_reconstruction: Batch of reconstructed images.
            num_views: Number of examples per row; defaults to the value given
                to the constructor.

        Returns:
            The matplotlib figure that was drawn.
        """
        if num_views is None:
            num_views = self.num_views
        stacked = torch.cat((imgs_original[0:num_views], imgs_reconstruction[0:num_views]), 0)
        img_grid = make_grid(stacked, nrow=num_views, pad_value=20)
        fig, ax = plt.subplots(figsize=(20, 5))
        # Only one channel of the grid is displayed.
        ax.imshow(img_grid[1, :, :].detach().cpu(), vmin=0, vmax=1)
        ax.axis("off")
        plt.show()
        return fig

    def reconstruct(self):
        """Reconstruct the next batch from the iterator.

        Returns:
            (originals, reconstructions, recon_error) where `recon_error` is
            the MSE between the two.

        Raises:
            StopIteration: when the iterator is exhausted.
        """
        self._model.eval()
        (valid_originals, _, _) = next(self._iterator)
        # Merge the first two axes into one batch axis, then add a channel axis.
        valid_originals = torch.reshape(
            valid_originals,
            (valid_originals.shape[0] * valid_originals.shape[1],
             valid_originals.shape[2], valid_originals.shape[3]))
        valid_originals = torch.unsqueeze(valid_originals, 1)
        valid_originals = valid_originals.to(device)

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            vq_output_eval = self._model._pre_vq_conv(self._model._encoder(valid_originals))
            _, valid_quantize, _, _ = self._model._vq_vae(vq_output_eval)
            valid_reconstructions = self._model._decoder(valid_quantize)
            recon_error = F.mse_loss(valid_originals, valid_reconstructions)

        return valid_originals, valid_reconstructions, recon_error

    def run(self, plot=True, wave_return=True, wave_plot=True, directory=None):
        """Reconstruct one batch and optionally plot / export it.

        Args:
            plot: Show the original-vs-reconstruction image grid.
            wave_return: Also invert both batches to waveforms.
            wave_plot: Plot the waveforms and export them as
                `originals<i>.wav` / `reconstructions<i>.wav`.
            directory: Folder for the exported WAVs (created if missing);
                the current working directory when None.

        Returns:
            (originals, reconstructions, wave_original, wave_reconstructions,
            error); the two wave entries are empty lists when
            `wave_return` is False.
        """
        wave_original = []
        wave_reconstructions = []
        originals, reconstructions, error = self.reconstruct()
        if plot:
            self.plot_reconstructions(originals, reconstructions)
        if wave_return:
            wave_original = self.waveform_generator(originals)
            wave_reconstructions = self.waveform_generator(reconstructions)
            if wave_plot:
                if directory is not None:
                    os.makedirs(directory, exist_ok=True)
                self.plot_waveform(wave_original, len(wave_original),
                                   directory=self._wav_prefix(directory, "originals"))
                self.plot_waveform(wave_reconstructions, len(wave_reconstructions),
                                   directory=self._wav_prefix(directory, "reconstructions"))

        return originals, reconstructions, wave_original, wave_reconstructions, error


class TrainModel:
    """Training driver for the VQ-VAE with optional Weights & Biases logging.

    Args:
        model: Model called as `vq_loss, reconstruction, perplexity = model(x)`.
    """

    def __init__(self, model):
        self._model = model
        # Assume W&B is usable until wandb_init() reports otherwise, so that
        # wandb_logging() keeps working for callers that set W&B up themselves.
        self._wandb_enabled = True

    @staticmethod
    def _build_run_name(config, keys):
        """Build the run name `VQ_<key>:<value>_...` from selected config keys.

        Keys missing from `config` are appended verbatim (no value, no
        separator).
        """
        run_name = "VQ_"
        for key in keys:
            if key in config:
                run_name += f"{key}:{config[key]}_"
            else:
                run_name += str(key)
        return run_name

    @staticmethod
    def _clear_output():
        """Clear the cell output on Colab; do nothing elsewhere.

        `output` is only imported by the Colab bootstrap cell, so it does not
        exist when the notebook runs outside Colab.
        """
        try:
            output.clear()
        except NameError:
            pass

    def wandb_init(self, config, keys=("audio_length", "win_length", "batch_size")):
        """Start a W&B run named after the selected config keys.

        Returns:
            (is_wandb_enable, run_name). `is_wandb_enable` is False when any
            W&B call failed; the exception is printed, not raised, so training
            can continue without logging.
        """
        # Built outside the try block so the name is always defined on return.
        run_name = self._build_run_name(config, keys)
        try:
            wandb.login()
            wandb.finish()  # close any run left over from a previous cell execution
            wandb.init(project=config.get("project", "VQ-VAE-Jaguas"),
                       config=config, name=run_name)
            wandb.watch(self._model, F.mse_loss, log="all", log_freq=1)
            is_wandb_enable = True
        except Exception as e:
            print(e)
            is_wandb_enable = False

        self._wandb_enabled = is_wandb_enable
        return is_wandb_enable, run_name

    def wandb_logging(self, dict, step=0):
        """Log every entry of `dict` to W&B at `step`; no-op when W&B is disabled."""
        if not getattr(self, "_wandb_enabled", True):
            return
        for name in dict:
            wandb.log({name: dict[name]}, step=step)

    def fordward(self, training_loader, test_loader, config):
        """Run the training loop (the method name keeps its historical spelling).

        Args:
            training_loader: Iterable yielding 3-tuples whose first item is the
                training batch.
            test_loader: Iterable used for the periodic reconstruction previews.
            config: Dict with at least "optimizer", "scheduler", "num_epochs"
                and "num_training_updates". Optional "low_error_threshold"
                (default 20) sets the reconstruction error below which the
                model is checkpointed and an alert is sent.

        Returns:
            (model, logs, run_name) where `logs` collects the exceptions that
            were skipped during data loading and previews.
        """
        iterator = iter(test_loader)
        wandb_enable, run_name = self.wandb_init(config)
        optimizer = config["optimizer"]
        scheduler = config["scheduler"]
        num_epochs = config["num_epochs"]
        num_updates = config["num_training_updates"]
        low_error_threshold = config.get("low_error_threshold", 20)

        tester = TestModel(self._model, iterator, 8)
        logs = []

        for epoch in range(num_epochs):
            iterator_train = iter(training_loader)
            for i in range(num_updates):
                # The preview below switches the model to eval mode.
                self._model.train()
                try:
                    (data, _, _) = next(iterator_train)
                except Exception as e:
                    # Best effort: skip batches that fail to load.
                    print("error")
                    print(e)
                    logs.append(e)
                    continue

                # Merge the first two axes into one batch axis, add a channel axis.
                data = torch.reshape(data, (data.shape[0] * data.shape[1], data.shape[2], data.shape[3]))
                data = torch.unsqueeze(data, 1)
                data = data.to(device)

                optimizer.zero_grad()
                vq_loss, data_recon, perplexity = self._model(data)

                recon_error = F.mse_loss(data_recon, data)
                loss = recon_error + vq_loss
                loss.backward()
                optimizer.step()

                # Plain floats: logging tensors would keep their graphs alive.
                loss_value = loss.item()
                recon_value = recon_error.item()
                vq_value = vq_loss.item()
                print(f'epoch: {epoch+1} of {num_epochs} \t iteration: {(i+1)} of {num_updates} \t loss: {np.round(loss_value,4)} \t recon_error: {np.round(recon_value,4)} \t vq_loss: {np.round(vq_value,4)}')

                step = epoch * num_updates + i
                self.wandb_logging({"loss": loss_value,
                                    "perplexity": perplexity.item(),
                                    "recon_error": recon_value,
                                    "vq_loss": vq_value}, step=step)

                if (i + 1) % 20 == 0:
                    try:
                        originals, reconstructions, test_error = tester.reconstruct()
                        fig = tester.plot_reconstructions(originals, reconstructions, 8)
                        if wandb_enable:
                            images = wandb.Image(fig, caption=f"recon_error: {np.round(test_error.item(),4)}")
                            # Same global step as the scalars; a per-epoch index
                            # would go backwards after the first epoch.
                            self.wandb_logging({"examples": images}, step=step)
                        plt.close(fig)  # avoid accumulating open figures
                    except Exception as e:
                        # A failed preview must not stop training nor skip the
                        # checkpoint check below.
                        print("error")
                        print(e)
                        logs.append(e)

                if recon_value < low_error_threshold:
                    if wandb_enable:
                        wandb.alert(
                            title='High accuracy',
                            text=f'Recon error {recon_value} is lower than {low_error_threshold}',
                            level=AlertLevel.WARN,
                            wait_duration=timedelta(minutes=5),
                        )
                    # NOTE(review): this overwrites the checkpoint on every
                    # iteration below the threshold -- confirm whether only the
                    # best error so far should be kept.
                    torch.save(self._model.state_dict(), f'{run_name}_low_error.pkl')

            scheduler.step()
            torch.cuda.empty_cache()
            now = datetime.datetime.now()
            torch.save(self._model.state_dict(), f'{now.hour}_{now.minute}_{run_name}_{epoch}.pkl')
            self._clear_output()
            print(optimizer.state_dict()["param_groups"][0]["lr"])

        wandb.finish()
        return self._model, logs, run_name

                




In [None]:
# Recordings used for training (Jaguas 2018, shared drive).
root_path = '/content/drive/Shareddrives/ConservacionBiologicaIA/Datos/Jaguas_2018'

# Hyperparameters and run metadata; runtime objects are added at the end of the cell.
config = dict(
    project="VQ-VAE-Jaguas",
    batch_size=14,
    num_epochs=6,
    num_hiddens=128,
    embedding_dim=8,
    num_embeddings=512,
    commitment_cost=0.25,
    decay=0.99,
    learning_rate=1e-2,
    dataset="Audios Jaguas",
    architecture="VQ-VAE",
)

model = Model(
    config["num_hiddens"],
    config["num_embeddings"],
    config["embedding_dim"],
    config["commitment_cost"],
    config["decay"],
).to(device)

dataset = SoundscapeData(root_path, audio_length=12, ext="wav", win_length=1028)

# NOTE(review): the training subset is the 20% share and the test subset the
# remaining 80% -- confirm this is intended.
n_train = round(len(dataset) * 0.2)
dataset_train, dataset_test = random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(1024),  # fixed seed -> reproducible split
)

training_loader = DataLoader(dataset_train, batch_size=config["batch_size"], shuffle=True)
test_loader = DataLoader(dataset_test, batch_size=config["batch_size"])

optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], amsgrad=False)
# Learning rate is multiplied by 0.1 every 2 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

# Runtime objects and derived values read by TrainModel.
config.update(
    optimizer=optimizer,
    scheduler=scheduler,
    audio_length=dataset.audio_length,
    num_training_updates=len(training_loader),
    win_length=dataset.win_length,
)


In [None]:
# Train the model and save the final weights. The timestamp is taken before
# training starts, so the file name reflects the start time of the run.
Training = TrainModel(model)
time = datetime.datetime.now()
model, logs, run_name = Training.fordward(training_loader, test_loader, config)
torch.save(model.state_dict(),f'{time.hour}_{time.minute}_{run_name}.pkl')
#np.savetxt("corrupted_files.csv", logs, delimiter=",")

# The CUDA memory report is only meaningful when a GPU is in use; printing it
# renders the table instead of showing an escaped one-line string repr.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))


In [None]:
# wandb.finish()
# Restore previously trained weights; tensors are deserialised onto the CPU and
# then copied into the existing model parameters.
# NOTE(review): the file name mentions "Embedding_256" -- confirm it matches the
# architecture built from `config` above.
checkpoint_path = 'Models/Best_Model_Embedding_256_VQ_audio_length 12_win_length 1028_batch_size 8__5.pkl'
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))

In [None]:
# Second recording set (Porce 2019), loaded with the same window length as the
# training configuration.
root_path = '/content/drive/Shareddrives/ConservacionBiologicaIA/Datos/Porce_2019'

dataset = SoundscapeData(root_path=root_path, audio_length=12, ext='WAV', win_length=config["win_length"])

# 10% of the files go to the first subset, the rest to the second; the fixed
# seed keeps the split reproducible.
n_train = round(len(dataset) * 0.10)
dataset_train, dataset_test = random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(1024),
)

training_loader = DataLoader(dataset_train, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(dataset_test, batch_size=config["batch_size"])
# Iterator consumed by the TestModel cell further down.
iterator = iter(test_loader)

In [None]:
# Take the first batch from a fresh pass over the test loader; only the first
# two items of the 3-tuple are kept.
first_pass = iter(test_loader)
spec, record, _ = next(first_pass)

In [None]:
# Power spectrogram (power=2) of the raw batch with a Hamming window.
to_power_spectrogram = audio_transform.Spectrogram(
    n_fft=1028,
    win_length=1028,
    window_fn=torch.hamming_window,
    power=2,
)
spec_2 = to_power_spectrogram(record)

In [None]:
# NOTE(review): `reconstruction` is not assigned by any cell above this one;
# its first visible assignment is in the later `Test.run(...)` cell, so this
# cell fails on a fresh kernel run top to bottom -- confirm the intended order.
# Swap the first two axes, then drop the leading axis if it has size 1.
a = reconstruction.permute(1,0,2,3)
b = a.squeeze(dim=0)
# Displayed as the cell output.
b.shape

In [None]:
# Reinterpret the real-valued tensor from the previous cell as complex64 on the
# CPU (imaginary part zero) and invert it to a waveform.
# NOTE(review): `b` is overwritten in place, so re-running this cell operates
# on the already-converted tensor; it also depends on the previous cell.
b = b.type(torch.complex64).to("cpu")
wav = audio_transform.InverseSpectrogram(n_fft=1028, win_length=1028, hop_length=514)(b)

In [None]:
# Reconstruct one batch, show the image grid and the waveforms, and export the
# waveforms as WAV files (done inside run()).
Test = TestModel(model, iterator, 8)
originals, reconstruction, wav_ori, wav_recons, error = Test.run(wave_plot=True)

# run() already returns the inverted waveforms, so they are reused instead of
# running the inverse spectrogram a second time on the same tensors.
a = wav_ori
b = wav_recons

# plt.psd expects a 1-D signal, so each batch is flattened before plotting.
waves = [a.ravel(), b.ravel()]
Test.plot_psd(waves)

In [None]:
# Manual inversion of the batch produced by the TestModel cell above.
# (audio_transform is already imported in the imports cell at the top.)
import torchaudio.transforms as audio_transform
# Cast both batches to complex128 and move them to the CPU. These assignments
# overwrite the tensors returned by Test.run().
reconstruction = reconstruction.cdouble()
reconstruction = reconstruction.to("cpu")
originals = originals.cdouble()
originals = originals.to("cpu")
# NOTE(review): every other cell uses n_fft=1028 / win_length=1028 for these
# spectrograms; 1025 here looks inconsistent -- confirm which value matches the
# number of frequency bins produced by the dataset.
transformation = audio_transform.InverseSpectrogram(n_fft=1025, win_length = 1025)
waveform = transformation(reconstruction)
waveform_original = transformation(originals)
# Convert to NumPy for plotting.
waveform = waveform.cpu().detach().numpy()
waveform_original = waveform_original.cpu().detach().numpy()

In [None]:
# Squared sample-wise difference between the first reconstructed and the first
# original waveform (kept for the optional PSD below).
diff = waveform[0] - waveform_original[0]
diff = diff**2

# plt.psd expects a 1-D signal, so the first example of each batch is flattened
# (the same example used for `diff`) instead of passing the whole batch array.
plt.psd(np.ravel(waveform_original[0]), label="original")
plt.psd(np.ravel(waveform[0]), label="reconstruction")
#plt.psd(diff[0])
plt.legend()
plt.show()