In [None]:
from google.colab import drive, output
drive.mount('/content/drive')
import sys
%cd '/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project'
#sys.path.append('/content/drive/MyDrive/Deep Learning/AutoEncoders/Project/VQVAE_Working/data')
#sys.path.append('/content/drive/MyDrive/Deep Learning/AutoEncoders/Project/VQVAE_Working/models')
sys.path.append('/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project/Dataloader')
sys.path.append('/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project/Models')
sys.path.append('/content/drive/MyDrive/PhD_Thesis_Experiments/DeepLearning/AutoEncoders/Project/Modules')
%load_ext autoreload
%autoreload 1
!pip install torchaudio
!pip install umap
!pip install wandb --upgrade
!wandb login
output.clear()

In [None]:
from __future__ import print_function
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import IPython

from six.moves import xrange
from torch.optim import lr_scheduler

import umap
import wandb
import torch
torch.cuda.empty_cache()
import gc
gc.collect()
torch.cuda.empty_cache()

from scipy import signal
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

from torch.utils.data import random_split
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import torch.nn.functional as F
import torchaudio.transforms as audio_transform

#from ResidualStack import ResidualStack
#from Residual import Residual

from Jaguas_DataLoader import SoundscapeData
from Models import Model
from Models import Encoder
from Models import Decoder
from Models import VectorQuantizer
from Models import VectorQuantizerEMA

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
#device = xm.xla_device()
print(device)

from datetime import timedelta
import wandb
from wandb import AlertLevel



## Train

We use the hyperparameters from the author's code.

In [None]:
class TestModel:
    """Evaluate a VQ-VAE on batches pulled from a test iterator.

    The model must expose the sub-modules ``_encoder``, ``_pre_vq_conv``,
    ``_vq_vae`` and ``_decoder``; the iterator must yield 3-tuples whose first
    element is a 4-D tensor (only that element is used).
    """

    def __init__(self, model, iterator, num_views=8):
        """Store the model, the batch iterator and the default number of views.

        Args:
            model: model to evaluate.
            iterator: iterator yielding ``(spectrograms, _, _)`` batches.
            num_views: default number of original/reconstruction pairs to plot.
                Defaults to 8 so that ``TestModel(model, iterator)`` works.
        """
        self._model = model
        self._iterator = iterator
        self.num_views = num_views

    def _model_device(self):
        """Return the device holding the model parameters.

        Falls back to the notebook-level ``device`` when the model exposes no
        parameters.
        """
        try:
            return next(self._model.parameters()).device
        except (StopIteration, AttributeError):
            return device

    def waveform_generator(self, spec, n_fft=1025, win_length=1025, audio_length=59, plot=False,
                           base_win=None):
        """Invert a spectrogram into a waveform with ``InverseSpectrogram``.

        Args:
            spec: spectrogram tensor to invert. The manual inversion cell of
                this notebook casts its input with ``.cdouble()`` first;
                callers presumably need to do the same -- TODO confirm.
            n_fft: FFT size handed to ``InverseSpectrogram``.
            win_length: window length handed to ``InverseSpectrogram``.
            audio_length: only used to derive the hop length when ``base_win``
                is given.
            plot: accepted for backward compatibility; currently unused.
            base_win: optional reference window length. When given, the hop
                length is ``round(base_win / win_length * 172.3 * audio_length)``;
                when omitted, ``hop_length=None`` is passed so torchaudio's
                default is used.

        Returns:
            The waveform as a NumPy array.
        """
        # Use the argument, not whatever global `originals` happens to exist.
        spec = spec.to("cpu")
        hop_length = None
        if base_win is not None:
            # NOTE(review): 172.3 is an unexplained constant kept from the
            # original formula -- confirm its meaning.
            hop_length = int(np.round(base_win / win_length * 172.3 * audio_length))
        transformation = audio_transform.InverseSpectrogram(n_fft=n_fft, win_length=win_length,
                                                            hop_length=hop_length)
        waveform = transformation(spec)
        return waveform.cpu().detach().numpy()

    def plot_psd(self, waveform):
        """Plot the power spectral density of ``waveform`` on the current axes."""
        plt.psd(waveform)

    def plot_reconstructions(self, imgs_original, imgs_reconstruction, num_views: int = None):
        """Show originals (top row) above their reconstructions (bottom row).

        Args:
            imgs_original: batch of original images.
            imgs_reconstruction: batch of reconstructed images.
            num_views: number of pairs to display; ``None`` (the default) uses
                ``self.num_views``.

        Returns:
            The matplotlib figure that was drawn.
        """
        if num_views is None:
            num_views = self.num_views
        stacked = torch.cat((imgs_original[0:num_views], imgs_reconstruction[0:num_views]), 0)
        img_grid = make_grid(stacked, nrow=num_views, pad_value=20)
        fig, ax = plt.subplots(figsize=(20, 5))
        # Channel 1 of the grid is displayed, as in the original code.
        ax.imshow(img_grid[1, :, :].detach().cpu(), vmin=0, vmax=1)
        ax.axis("off")
        plt.show()
        return fig

    def reconstruct(self):
        """Reconstruct the next batch of the iterator.

        The first two dimensions of the batch are merged and a channel
        dimension is inserted at position 1 before the batch is encoded,
        quantised and decoded. The model is switched to eval mode and no
        autograd graph is recorded.

        Returns:
            ``(originals, reconstructions, recon_error)`` where ``recon_error``
            is the MSE between both tensors.

        Raises:
            StopIteration: when the iterator is exhausted.
        """
        self._model.eval()
        (valid_originals, _, _) = next(self._iterator)
        valid_originals = torch.reshape(
            valid_originals,
            (valid_originals.shape[0] * valid_originals.shape[1],
             valid_originals.shape[2], valid_originals.shape[3]))
        valid_originals = torch.unsqueeze(valid_originals, 1)
        valid_originals = valid_originals.to(self._model_device())

        # Evaluation only: skip gradient tracking to save memory.
        with torch.no_grad():
            vq_output_eval = self._model._pre_vq_conv(self._model._encoder(valid_originals))
            _, valid_quantize, _, _ = self._model._vq_vae(vq_output_eval)
            valid_reconstructions = self._model._decoder(valid_quantize)
            recon_error = F.mse_loss(valid_originals, valid_reconstructions)

        return valid_originals, valid_reconstructions, recon_error

    def run(self):
        """Reconstruct one batch and plot it with the default number of views."""
        originals, reconstructions, error = self.reconstruct()
        self.plot_reconstructions(originals, reconstructions)


class TrainModel:
    """Train a VQ-VAE and (optionally) report progress to Weights & Biases.

    The model is called as ``vq_loss, reconstruction, perplexity = model(x)``.
    The run name and the checkpoint file name are taken from the notebook-level
    ``run_name`` variable.
    """

    # Every EVAL_EVERY updates a test batch is reconstructed and plotted.
    EVAL_EVERY = 500
    # Reconstruction error under which an alert is sent and a checkpoint saved.
    LOW_ERROR_THRESHOLD = 0.5

    def __init__(self, model):
        self._model = model
        # None: wandb_init not called yet, False: wandb unavailable, True: active.
        self._wandb_enabled = None

    def _model_device(self):
        """Return the device holding the model parameters.

        Falls back to the notebook-level ``device`` when the model exposes no
        parameters.
        """
        try:
            return next(self._model.parameters()).device
        except (StopIteration, AttributeError):
            return device

    def wandb_init(self, config):
        """Start a wandb run for the project "VQ-VAE-Jaguas".

        Any failure (no login, no network, wandb missing, ...) disables wandb
        for this trainer instead of aborting the training.

        Returns:
            True when the run was started, False otherwise.
        """
        try:
            wandb.login()
            wandb.finish()
            wandb.init(project="VQ-VAE-Jaguas", config=config)
            wandb.run.name = run_name
            wandb.run.save()
            wandb.watch(self._model, F.mse_loss, log="all", log_freq=1)
            is_wandb_enable = True
        except Exception as exc:  # KeyboardInterrupt / SystemExit still propagate
            print(f"wandb disabled: {exc!r}")
            is_wandb_enable = False

        self._wandb_enabled = is_wandb_enable
        return is_wandb_enable

    def wandb_logging(self, dict):
        """Log all entries of ``dict`` in one wandb step.

        Nothing is logged when ``wandb_init`` reported that wandb is unavailable.
        The parameter name shadows the built-in but is kept for compatibility.
        """
        if self._wandb_enabled is False:
            return
        # A single call keeps all metrics of one update on the same step.
        wandb.log(dict)

    def _log_reconstructions(self, iterator, logs):
        """Reconstruct one test batch, plot it and send the figure to wandb.

        Errors (for instance an exhausted test iterator) are printed and
        appended to ``logs``; they never interrupt the training.
        """
        try:
            test_ = TestModel(self._model, iterator, 8)
            originals, reconstructions, test_error = test_.reconstruct()
            fig = test_.plot_reconstructions(originals, reconstructions, 8)
            if self._wandb_enabled is not False:
                images = wandb.Image(fig, caption=f"recon_error: {np.round(test_error.item(),4)}")
                self.wandb_logging({"examples": images})
            # Release the figure; otherwise one figure per evaluation stays open.
            plt.close(fig)
        except Exception as e:
            print(e)
            logs.append(e)

    def _on_low_error(self, recon_value):
        """Send a wandb alert and checkpoint the model being trained."""
        if self._wandb_enabled is not False:
            wandb.alert(
                title='High accuracy',
                text=f'Recon error {recon_value} is lower than {self.LOW_ERROR_THRESHOLD}',
                level=AlertLevel.WARN,
                wait_duration=timedelta(minutes=5)
            )
        # Save the model owned by this trainer, not a global that may differ.
        torch.save(self._model.state_dict(), f'{run_name}_low_error.pkl')

    def fordward(self, training_loader, test_loader, config):
        """Run the training loop (method name kept as-is for existing callers).

        Args:
            training_loader: iterable of ``(spectrograms, _, _)`` batches.
            test_loader: iterable used for the periodic reconstruction plots.
            config: mapping providing "optimizer", "scheduler", "num_epochs"
                and "num_training_updates".

        Returns:
            ``(model, logs)`` where ``logs`` lists the exceptions raised while
            loading batches or evaluating.
        """
        iterator = iter(test_loader)
        self.wandb_init(config)
        optimizer = config["optimizer"]
        scheduler = config["scheduler"]
        target_device = self._model_device()

        logs = []
        # One persistent iterator: calling next(iter(loader)) on every step
        # would restart the loader and, without shuffling, always return the
        # same first batch.
        train_iterator = iter(training_loader)

        for epoch in range(config["num_epochs"]):
            for i in range(config["num_training_updates"]):
                self._model.train()
                try:
                    try:
                        (data, _, _) = next(train_iterator)
                    except StopIteration:
                        # Loader exhausted: start another pass over the data.
                        train_iterator = iter(training_loader)
                        (data, _, _) = next(train_iterator)
                except Exception as e:
                    # Best effort: record the failing batch and move on.
                    logs.append(e)
                    continue

                # Merge the first two dimensions and add a channel dimension.
                data = torch.reshape(data, (data.shape[0] * data.shape[1], data.shape[2], data.shape[3]))
                data = torch.unsqueeze(data, 1)
                data = data.to(target_device)

                optimizer.zero_grad()
                vq_loss, data_recon, perplexity = self._model(data)

                recon_error = F.mse_loss(data_recon, data)
                loss = recon_error + vq_loss
                loss.backward()

                optimizer.step()
                print(f'epoch: {epoch} of {config["num_epochs"]} \t iteration: {(i+1)} of {config["num_training_updates"]} \t loss: {np.round(loss.item(),4)} \t recon_error: {np.round(recon_error.item(),4)} \t vq_loss: {np.round(vq_loss.item(),4)}')
                # Log plain floats so no autograd graph is kept alive by the logger.
                metrics = {"loss": loss.item(),
                           "perplexity": perplexity.item(),
                           "recon_error": recon_error.item(),
                           "vq_loss": vq_loss.item()}
                self.wandb_logging(metrics)

                if (i + 1) % self.EVAL_EVERY == 0:
                    self._log_reconstructions(iterator, logs)

                if metrics["recon_error"] < self.LOW_ERROR_THRESHOLD:
                    self._on_low_error(metrics["recon_error"])

            scheduler.step()
            torch.cuda.empty_cache()

        if self._wandb_enabled is not False:
            wandb.finish()
        return self._model, logs

                




In [None]:
# Jaguas 2018 recordings on the shared drive.
root_path = '/content/drive/Shareddrives/ConservacionBiologicaIA/Datos/Jaguas_2018'

dataset = SoundscapeData(root_path, audio_length=59, ext="wav", win_length=257)

# Seeded split: the first part (10 % of the files) is used for training,
# the remainder for testing.
train_size = round(len(dataset) * 0.10)
test_size = len(dataset) - train_size
dataset_train, dataset_test = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(1024),
)

# Hyper-parameters and run metadata (also sent to wandb as the run config).
config = dict([
    ("project", "VQ-VAE-Jaguas"),
    ("batch_size", 8),
    ("num_epochs", 4),
    ("num_training_updates", len(dataset_train)),
    ("num_hiddens", 64),
    ("embedding_dim", 128),
    ("num_embeddings", 64),
    ("commitment_cost", 0.25),
    ("decay", 0.99),
    ("learning_rate", 1e-2),
    ("dataset", "Audios Jaguas"),
    ("architecture", "VQ-VAE"),
    ("win_length", dataset.win_length),
])

# Human-readable run identifier, reused for wandb and for checkpoint files.
run_name = "".join((
    "Audio_Length_", str(dataset.audio_length),
    "_secs_", "win_length_", str(dataset.win_length),
    "_Batch_size: ", str(config["batch_size"]),
    " num_hiddens: ", str(config["num_hiddens"]),
    " num_embeddings: ", str(config["num_embeddings"]),
    " embedding dim: ", str(config["embedding_dim"]),
))

training_loader = DataLoader(dataset_train, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(dataset_test, batch_size=config["batch_size"])

# Build the model from the config and move it to the selected device.
model_args = [config[key] for key in
              ("num_hiddens", "num_embeddings", "embedding_dim", "commitment_cost", "decay")]
model = Model(*model_args).to(device)

optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], amsgrad=False)
# The learning rate is multiplied by 0.1 at every scheduler step.
scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

# The training loop reads both objects back from the config.
config["optimizer"] = optimizer
config["scheduler"] = scheduler




In [None]:
# Train, then persist the final weights under the run name.
Training = TrainModel(model)
model, logs = Training.fordward(training_loader, test_loader, config)
torch.save(model.state_dict(), f'{run_name}.pkl')

# `logs` holds exception objects, which np.savetxt cannot format (it raises on
# the first entry); write one textual representation per line instead.
with open("corrupted_files.csv", "w") as log_file:
    for error in logs:
        log_file.write(f"{error!r}\n")

wandb.finish()
# Printed so the multi-line report is readable instead of an escaped repr.
print(torch.cuda.memory_summary(device=None, abbreviated=False))


In [None]:
# Reconstruct and plot one batch of the test set.
iterator = iter(test_loader)
# num_views is passed explicitly: the constructor takes three arguments.
tmod = TestModel(model, iterator, 8)
tmod.run()

In [None]:
# Trace the shape of one training batch through the preprocessing steps.
data, _, _ = next(iter(training_loader))
print(data.shape)

# Merge the first two dimensions.
merged_shape = (data.shape[0] * data.shape[1], data.shape[2], data.shape[3])
data = data.reshape(merged_shape)
print(data.shape)

# Insert a channel dimension and move the batch to the compute device.
data = data.unsqueeze(1).to(device)
print(data.shape)

In [None]:
# wandb.finish()
# Restore previously trained weights; tensors are mapped to the CPU on load.
checkpoint_path = 'Models/Few_layers_59_secs__Batch_size_8_num_hiddens_64 num_embeddings_64_embedding_dim_128.pkl'
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)

In [None]:
# Porce 2019 recordings on the shared drive.
root_path = '/content/drive/Shareddrives/ConservacionBiologicaIA/Datos/Porce_2019'

dataset = SoundscapeData(root_path=root_path, audio_length=59, ext='WAV', win_length=config["win_length"])

# Same seeded 10 % / 90 % split as for the first dataset.
train_size = round(len(dataset) * 0.10)
test_size = len(dataset) - train_size
dataset_train, dataset_test = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(1024),
)

batch_size = config["batch_size"]
training_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset_test, batch_size=batch_size)

# Iterator consumed by the evaluation cells.
iterator = iter(test_loader)

In [None]:
# Use the built-in next(): DataLoader iterators implement __next__, and the
# Python-2 style .next() method is not available on current PyTorch versions.
originals, record, sr = next(iterator)

In [None]:
# Skip one batch (built-in next(); the .next() method is not available on
# current DataLoader iterators), then reconstruct the following one.
next(iterator)
Test = TestModel(model, iterator, 1)
originals, reconstruction, error = Test.reconstruct()
# run() pulls yet another batch from the iterator and plots it.
Test.run()

In [None]:
import torchaudio.transforms as audio_transform

# Cast both batches to complex double on the CPU before inverting them.
reconstruction = reconstruction.cdouble().to("cpu")
originals = originals.cdouble().to("cpu")

# One inverse transform shared by the reconstruction and the original.
transformation = audio_transform.InverseSpectrogram(n_fft=1025, win_length=1025)

# Invert, then convert the results to NumPy arrays.
waveform = transformation(reconstruction).cpu().detach().numpy()
waveform_original = transformation(originals).cpu().detach().numpy()

In [None]:
import matplotlib.pyplot as plt

# First figure: waveform recovered from the original spectrogram.
_, ax_original = plt.subplots()
ax_original.plot(waveform_original[0, 0])
plt.show()

# Second figure: waveform recovered from the reconstruction.
_, ax_reconstruction = plt.subplots()
ax_reconstruction.plot(waveform[0, 0])


In [None]:
# Squared sample-wise difference between reconstruction and original.
diff = waveform[0] - waveform_original[0]
diff = diff**2

# plt.psd expects a 1-D signal, so plot the same [0, 0] waveform that the
# neighbouring cells plot and export instead of the whole batch array.
plt.psd(waveform_original[0, 0], label="original")
plt.psd(waveform[0, 0], label="reconstruction")
#plt.psd(diff[0])
plt.legend()
plt.show()

In [None]:
from scipy.io.wavfile import write

# Sample rate written to both files.
WAV_SAMPLE_RATE = 22050


def _to_int16_pcm(samples):
    """Peak-normalise ``samples`` and convert them to 16-bit PCM.

    A silent signal (peak of 0) is returned as zeros instead of dividing by
    zero, which would produce NaNs and undefined integer values.
    """
    peak = np.max(np.abs(samples))
    if peak == 0:
        return np.zeros(np.shape(samples), dtype=np.int16)
    return np.int16(samples / peak * 32767)


# Export the first waveform of each batch for listening tests.
scaled = _to_int16_pcm(waveform_original[0, 0])
write('waveform_original.wav', WAV_SAMPLE_RATE, scaled)

scaled2 = _to_int16_pcm(waveform[0, 0])
write('waveform.wav', WAV_SAMPLE_RATE, scaled2)

