In [1]:
import os
import io

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from tqdm.notebook import tqdm
from IPython.display import display, Audio
from PIL import Image

import soundfile as sf
import librosa
import librosa.display

import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
from lightning.pytorch import loggers as pl_loggers



from aquatk.embedding_extractors import VGGish
from aquatk.metrics.frechet_distance import frechet_audio_distance

2024-03-14 10:24:17.688263: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-14 10:24:17.688289: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-14 10:24:17.689298: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-14 10:24:17.699024: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Dataset embeddings

In [2]:
class BirdClefDataset(Dataset):
    """Dataset of fixed-length, dB-scaled mel spectrograms built from `.ogg` recordings.

    `root_dir` is expected to contain one sub-folder per bird, each holding `.ogg`
    files. Every item is a tuple `(mel_spec, bird, filename)` where `mel_spec` is a
    NumPy array with a leading channel axis, linearly rescaled with `normalize`.

    Args:
        root_dir: Folder containing one sub-folder per bird.
        bird_name: If given, only this sub-folder is used.
        transform: Optional callable applied to the spectrogram before it is returned.
        num_samples: Number of audio samples kept per recording (padded or truncated).
        min_db: dB value mapped to 0 by `normalize`.
        max_db: dB value mapped to 1 by `normalize`.
        cache: If true, spectrograms are stored as `.npy` files in a `cache` folder
            created next to `root_dir`.
    """

    def __init__(
        self,
        root_dir,
        bird_name=None,
        transform=None,
        num_samples=65_500,
        min_db=-80,
        max_db=0,
        cache=True
    ):
        self.root_dir = root_dir
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db

        if cache:
            self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(self.root_dir)), 'cache')
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.cache_dir = None

        if bird_name is not None:
            self.bird_folders = [bird_name]
        else:
            # Only keep sub-directories: a stray file in root_dir would otherwise make
            # the os.listdir() call below fail.
            self.bird_folders = sorted(
                entry for entry in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, entry))
            )

        self.audio_files = []

        for bird_folder in self.bird_folders:
            bird_path = os.path.join(root_dir, bird_folder)
            # os.listdir() order is arbitrary; sorting keeps indices (and therefore any
            # seeded random split of this dataset) reproducible.
            self.audio_files.extend(
                os.path.join(bird_path, file)
                for file in sorted(os.listdir(bird_path))
                if file.endswith('.ogg')
            )

    def get_spec(self, audio_path):
        """Load `audio_path` and return its mel spectrogram in dB.

        The recording is loaded as mono at its native sample rate, zero-padded or
        truncated to `num_samples`, converted to a mel spectrogram with librosa's
        default parameters and expressed in dB relative to its own maximum.
        """
        waveform, sample_rate = librosa.load(audio_path, sr=None, mono=True)

        if len(waveform) < self.num_samples:
            pad_amount = self.num_samples - len(waveform)
            waveform = np.pad(waveform, (0, pad_amount))
        else:
            waveform = waveform[:self.num_samples]

        mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sample_rate)
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        return mel_spec

    def normalize(self, x):
        """Linearly map `[min_db, max_db]` onto `[0, 1]`."""
        return (x - self.min_db) / (self.max_db - self.min_db)

    def denormalize(self, x):
        """Rescale every item of a batch to span `[min_db, max_db]`.

        This is a per-item min-max rescaling, not the exact inverse of `normalize`:
        each item's own minimum is mapped to `min_db` and its maximum to `max_db`.

        Args:
            x: Batch (first axis) of spectrograms, as a NumPy array or a torch tensor.

        Returns:
            NumPy array with the same shape as `x`. An item whose values are all
            equal is mapped to `min_db` everywhere.
        """
        if isinstance(x, torch.Tensor):
            x = x.cpu().detach().numpy()

        flattened_array = x.reshape((x.shape[0], -1))

        min_batch_values = flattened_array.min(axis=-1, keepdims=True)
        max_batch_values = flattened_array.max(axis=-1, keepdims=True)

        # A constant item has a zero range; dividing by it would fill the output with NaN.
        value_range = max_batch_values - min_batch_values
        value_range[value_range == 0] = 1

        normalized_array = self.min_db + ((flattened_array - min_batch_values) / value_range) * (self.max_db - self.min_db)

        normalized_batch = normalized_array.reshape(x.shape)

        return normalized_batch

    def cache_all(self):
        """Visit every item once so that its spectrogram is computed (and cached if enabled)."""
        self.cache = True
        for idx in range(len(self)):
            self.__getitem__(idx)

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return `(mel_spec, bird, filename)` for the recording at `idx`.

        Raises:
            IOError: If a cached spectrogram exists but cannot be read.
        """
        audio_path = self.audio_files[idx]

        # NOTE(review): the cache key only holds the file name and num_samples, so a
        # cache written with other min_db/max_db values would be reused as is.
        cache_path = None
        if self.cache_dir is not None:
            cache_filename = f"{os.path.basename(audio_path)}_{self.num_samples}.npy"
            cache_path = os.path.join(self.cache_dir, cache_filename)

        if cache_path is not None and os.path.isfile(cache_path):
            try:
                mel_spec = np.load(cache_path)
            except Exception as e:
                raise IOError(f"Failed to read file: {cache_path}. Error: {e}") from e
        else:
            mel_spec = self.get_spec(audio_path)

            # Normalize mel_spec and add the channel axis
            mel_spec = self.normalize(mel_spec)
            mel_spec = np.expand_dims(mel_spec, axis=0)

            # Save mel spectrogram to cache
            if cache_path is not None:
                np.save(cache_path, mel_spec)

        if self.transform is not None:
            mel_spec = self.transform(mel_spec)

        folder, filename = os.path.split(audio_path)
        basedir, bird = os.path.split(folder)

        return mel_spec, bird, filename

In [3]:
class BirdClefDataModule(pl.LightningDataModule):
    """Lightning data module wrapping `BirdClefDataset` with a seeded train/validation split.

    Args:
        root_dir: Folder containing one sub-folder of `.ogg` files per bird.
        batch_size: Batch size of the train and validation loaders.
        validation_split: Fraction of the dataset held out for validation.
        num_workers: Worker processes used by every loader.
        bird_name: If given, restrict the dataset to this bird's folder.
        transform: Optional callable forwarded to the dataset.
        num_samples: Audio samples kept per recording, forwarded to the dataset.
        min_db: Lower dB bound, forwarded to the dataset.
        max_db: Upper dB bound, forwarded to the dataset.
        cache: Whether the dataset caches spectrograms on disk.
        seed: Seed of the generator used for the random split.
    """

    def __init__(self,
                 root_dir,
                 batch_size=64,
                 validation_split=0.2,
                 num_workers=10,
                 bird_name=None,
                 transform=None,
                 num_samples=65_500,
                 min_db=-80,
                 max_db=0,
                 cache=True,
                 seed=0
                ):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.validation_split = validation_split
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db
        self.cache = cache
        self.seed = seed

        self.save_hyperparameters()

    def setup(self, stage=None):
        """Create the dataset and, for fit/validate (or no stage), the two loaders."""
        # Forward every dataset option; previously only bird_name was passed, so the
        # other constructor arguments of this module were silently ignored.
        dataset = BirdClefDataset(
            self.root_dir,
            bird_name=self.bird_name,
            transform=self.transform,
            num_samples=self.num_samples,
            min_db=self.min_db,
            max_db=self.max_db,
            cache=self.cache,
        )
        self.normalize = dataset.normalize
        self.denormalize = dataset.denormalize
        # 'validate' is included because val_dataloader() needs the split as well.
        if stage in (None, 'fit', 'validate'):
            train_dataset, validation_dataset = torch.utils.data.random_split(
                dataset,
                (1 - self.validation_split, self.validation_split),
                torch.Generator().manual_seed(self.seed)
            )
            self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
            self.validation_loader = DataLoader(validation_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the shuffled training loader built by `setup`."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader built by `setup`."""
        return self.validation_loader

    def illustration_dataloader(self, batch_size):
        """Return an unshuffled loader over the validation subset with a custom batch size."""
        illustrative_dataset = self.validation_loader.dataset
        illustrative_loader = DataLoader(illustrative_dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers)
        return illustrative_loader

In [4]:
# Build the data module over the local `train_audio` folder (relative to the notebook's
# working directory) with batches of 128 items.
root_directory = 'train_audio'
data_module = BirdClefDataModule(root_directory, batch_size=128)
# Called without a stage, so setup() also creates the train and validation loaders.
data_module.setup()

## VAE inference

In [5]:
class Flatten(nn.Module):
    """Collapse every axis after the first one into a single axis."""

    def forward(self, input):
        leading = input.shape[0]
        # view() (rather than reshape) keeps the requirement that the memory layout allows it.
        return input.view(leading, -1)

class UnFlatten(nn.Module):
    """Reshape a flat batch into `(batch, channels, size, size)` feature maps."""

    def __init__(self, size=16, channels=128):
        super().__init__()
        self.size, self.channels = size, channels

    def forward(self, input):
        target_shape = (input.size(0), self.channels, self.size, self.size)
        return input.view(*target_shape)

In [6]:
class VAEGANLoss(torch.nn.modules.loss._Loss):
    """Joint objectives of the encoder, decoder and discriminator of a VAE-GAN.

    The discriminator is trained with target 0 for original inputs and target 1 for
    generated ones (reconstructions and prior samples); the decoder is trained with
    target 0 on generated inputs, i.e. to have them classified as originals.

    Args:
        beta: Weight of the KL term in the encoder loss.
        size_average, reduce, reduction: Forwarded to the `_Loss` base class;
            `reduction` is applied to every term.
        reconstruction_loss: Callable `(input, target, reduction=...)`.
        reconstruction_decoder_weight: Weight of the reconstruction term in the
            decoder loss, relative to its two adversarial terms.
    """
    __constants__ = ['reduction']

    def __init__(self, beta=1.0, size_average=None, reduce=None, reduction: str = 'mean', reconstruction_loss=F.mse_loss, reconstruction_decoder_weight=100):
        super().__init__(size_average, reduce, reduction)
        self.reconstruction_decoder_weight = reconstruction_decoder_weight
        self.reconstruction_loss = reconstruction_loss
        self.beta = beta

    def forward(self, reconstructed_x, x, mean, logvar, sampled_classif, reconstructed_classif, original_classif):
        """Return the three aggregate losses followed by their seven individual terms."""
        how = self.reduction

        def bce_against(logits, generated):
            # Target is 1 for the "generated" label and 0 for the "original" label.
            target = torch.ones_like(logits) if generated else torch.zeros_like(logits)
            return F.binary_cross_entropy_with_logits(logits, target, reduction=how)

        # --- encoder: reconstruction + beta * KL(q(z|x) || N(0, I)) ---
        reconstruction_loss = self.reconstruction_loss(reconstructed_x, x, reduction=how)
        kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1)
        if how == 'mean':
            kl_loss = torch.mean(kl_loss)
        elif how != 'none':
            kl_loss = torch.sum(kl_loss)
        loss_encoder = reconstruction_loss + self.beta * kl_loss

        # --- discriminator: originals count twice so both labels weigh the same ---
        loss_discriminator_samples = bce_against(sampled_classif, generated=True)
        loss_discriminator_original = bce_against(original_classif, generated=False)
        loss_discriminator_reconstructed = bce_against(reconstructed_classif, generated=True)
        loss_discriminator = 1/4 * (loss_discriminator_samples + loss_discriminator_original * 2 + loss_discriminator_reconstructed)

        # --- decoder: fool the discriminator while still reconstructing the input ---
        loss_decoder_reconstructed = bce_against(reconstructed_classif, generated=False)
        loss_decoder_samples = bce_against(sampled_classif, generated=False)
        weight = self.reconstruction_decoder_weight
        loss_decoder = (loss_decoder_samples +
                        loss_decoder_reconstructed +
                        reconstruction_loss * weight)/(weight+2)

        return (
            loss_encoder,
            loss_decoder,
            loss_discriminator,
            reconstruction_loss,
            kl_loss,
            loss_decoder_reconstructed,
            loss_decoder_samples,
            loss_discriminator_samples,
            loss_discriminator_original,
            loss_discriminator_reconstructed
        )

In [7]:
class VAEGAN(pl.LightningModule):
    """VAE-GAN over single-channel spectrogram images, trained with manual optimization.

    The module holds a convolutional encoder, a transposed-convolution decoder and a
    convolutional discriminator, each with its own optimizer and exponential learning
    rate scheduler. After every validation run it logs generated and reconstructed
    spectrograms (images and audio) to the TensorBoard logger.

    Args:
        input_channels: Number of channels of the input images.
        img_size: Size parameter used for the divisibility check and for sizing the
            linear layers (see the note in `__init__`).
        hidden_size: Dimension of the latent space.
        layers: Output channels of the successive stride-2 convolutions.
        learning_rate: A single value, or an (encoder, decoder, discriminator) triple.
        lr_decay: Exponential decay factor per epoch; single value or triple as above.
        beta: Weight of the KL term in the encoder loss.
        activation: Activation module class inserted after every convolution.
        optimizer: Optimizer class used for the three optimizers.
        reconstruction_loss: Callable `(input, target, reduction=...)`.
        reconstruction_decoder_weight: Weight of the reconstruction term in the decoder loss.
        generate_on_epoch: Number of spectrograms generated after each validation run.
        reconstruct_on_epoch: Number of validation items reconstructed after each validation run.
        generator_too_good: Discriminator loss above which the decoder is not stepped.
        discriminator_too_good: Discriminator loss below which the discriminator is not stepped.
        seed: Stored on the module; not used by the code of this class.
    """

    # Sample rate (Hz) assumed when rendering spectrogram figures and audio previews.
    SAMPLE_RATE = 32000

    def __init__(
        self,
        input_channels=1,
        img_size=256,
        hidden_size=4096,
        layers=[16, 32, 64],
        learning_rate=0.001,
        lr_decay=1,
        beta=1e-3,
        activation=nn.ReLU,
        optimizer=torch.optim.Adam,
        reconstruction_loss=F.mse_loss,
        reconstruction_decoder_weight=100,
        generate_on_epoch=4,
        reconstruct_on_epoch=4,
        generator_too_good=.6, # value of the discriminator's loss above which the generator shouldn't be trained
        discriminator_too_good=.3, # value of the discriminator's loss below which the discriminator shouldn't be trained
        seed=0
    ):
        super(VAEGAN, self).__init__()

        if img_size % (2**len(layers)) != 0:
            raise ValueError(
                f"An image of size {img_size} with {len(layers)} layers won't be reconstructed with the correct size"
            )

        self.hidden_size = hidden_size
        self.encoder = self.build_encoder(input_channels, layers, activation)
        # NOTE(review): the encoder halves the spatial size len(layers) times, while the
        # flattened size below assumes one more halving. The linear layers therefore fit
        # square inputs of side img_size // 2 (128 with the defaults) -- confirm that this
        # is the intended meaning of img_size before changing either side.
        last_conv_size = (img_size // (2**(len(layers)+1)))**2 * layers[-1]
        self.mean_layer = nn.Linear(last_conv_size, self.hidden_size)
        self.logvar_layer = nn.Linear(last_conv_size, self.hidden_size)
        self.decoder = self.build_decoder(input_channels, layers, img_size, last_conv_size, activation)
        self.discriminator = self.build_discriminator(input_channels, layers, activation, last_conv_size)

        self.generator_too_good = generator_too_good
        self.discriminator_too_good = discriminator_too_good

        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
        self.optim = optimizer
        self.criterion = VAEGANLoss(
            beta=beta,
            reconstruction_loss=reconstruction_loss,
            reconstruction_decoder_weight=reconstruction_decoder_weight
        )
        self.generate_on_epoch = generate_on_epoch
        self.reconstruct_on_epoch = reconstruct_on_epoch
        self.seed = seed

        # Discriminator loss of the previous training step; drives the step gating.
        self.prev_disc_loss = None

        # Three optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        self.save_hyperparameters()

    def build_discriminator(self, input_channels, channels_list, activation, last_conv_size):
        """Stack of stride-2 convolutions followed by a linear layer producing one logit."""
        layers = []
        in_channels = input_channels
        for out_channels in channels_list:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1))
            layers.append(activation())
            in_channels = out_channels
        layers.append(Flatten())
        layers.append(nn.Linear(last_conv_size, 1))
        return nn.Sequential(*layers)

    def build_encoder(self, input_channels, channels_list, activation):
        """Stack of stride-2 convolutions whose output is flattened."""
        layers = []
        in_channels = input_channels
        for out_channels in channels_list:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1))
            layers.append(activation())
            in_channels = out_channels
        layers.append(Flatten())
        return nn.Sequential(*layers)

    def build_decoder(self, input_channels, channels_list, img_size, last_conv_size, activation):
        """Linear layer, un-flattening, then stride-2 transposed convolutions ending in a sigmoid."""
        layers = [
            nn.Linear(self.hidden_size, last_conv_size),
            UnFlatten(img_size // (2**(len(channels_list)+1)), channels_list[-1])
        ]
        # Walk the channel list backwards, finishing on the number of input channels.
        for in_channels, out_channels in zip(channels_list[::-1], channels_list[-2::-1]+[input_channels]):
            layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(activation())
        # The last activation is replaced by a sigmoid so outputs lie in [0, 1].
        layers.pop()
        layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def encode(self, x):
        """Return the `(mean, logvar)` pair produced by the encoder for `x`."""
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def sample(self, mean, logvar, seed=None):
        """Draw `z = mean + std * epsilon` with standard normal `epsilon`.

        Args:
            mean: Mean of the latent distribution.
            logvar: Second output of the encoder (see the note below).
            seed: If given, `epsilon` is drawn from a generator seeded with this value.
        """
        # NOTE(review): the KL term of the loss treats `logvar` as log(sigma^2), which
        # would call for exp(0.5 * logvar) here. Left unchanged because existing
        # checkpoints were trained with this definition.
        std = logvar.exp()
        if seed is None:
            epsilon = torch.randn_like(std)
        else:
            gen = torch.Generator(device=self.device).manual_seed(seed)
            epsilon = torch.empty_like(std).normal_(generator=gen)
        z = mean + std*epsilon
        return z

    def decode(self, x):
        """Decode latent vectors into images."""
        return self.decoder(x)

    def forward(self, originals):
        """Reconstruct `originals`, generate as many prior samples, and classify all three sets.

        Returns:
            `(reconstructed, sampled, mean, logvar, logits_originals,
            logits_reconstructed, logits_sampled)`.
        """
        # reconstruct from inputs
        mean, logvar = self.encode(originals)
        z = self.sample(mean, logvar)
        reconstructed = self.decode(z)

        # generate from scratch
        mean_ = torch.zeros([originals.shape[0], self.hidden_size], device=self.device)
        logvar_ = torch.zeros([originals.shape[0], self.hidden_size], device=self.device)
        z_ = self.sample(mean_, logvar_)
        sampled = self.decode(z_)

        # discriminate
        logits_originals = self.discriminator(originals)
        logits_reconstructed = self.discriminator(reconstructed)
        logits_sampled = self.discriminator(sampled)

        return reconstructed, sampled, mean, logvar, logits_originals, logits_reconstructed, logits_sampled

    def configure_optimizers(self):
        """Create one optimizer and one exponential scheduler per sub-network."""
        # Lists are accepted as well as tuples.
        if isinstance(self.learning_rate, (tuple, list)):
            lr_encoder, lr_decoder, lr_discriminator = self.learning_rate
        else:
            lr_encoder = lr_decoder = lr_discriminator = self.learning_rate

        if isinstance(self.lr_decay, (tuple, list)):
            lr_decay_encoder, lr_decay_decoder, lr_decay_discriminator = self.lr_decay
        else:
            lr_decay_encoder = lr_decay_decoder = lr_decay_discriminator = self.lr_decay

        optimizer_encoder = self.optim(
            list(self.encoder.parameters()) + list(self.mean_layer.parameters()) + list(self.logvar_layer.parameters()),
            lr=lr_encoder
        )
        optimizer_decoder = self.optim(self.decoder.parameters(), lr=lr_decoder)
        optimizer_discriminator = self.optim(self.discriminator.parameters(), lr=lr_discriminator)

        scheduler_encoder = {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer_encoder, gamma=lr_decay_encoder),
        }

        scheduler_decoder = {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer_decoder, gamma=lr_decay_decoder),
        }

        scheduler_discriminator = {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer_discriminator, gamma=lr_decay_discriminator),
        }

        return (
            [
                optimizer_encoder,
                optimizer_decoder,
                optimizer_discriminator
            ],
            [
                scheduler_encoder,
                scheduler_decoder,
                scheduler_discriminator
            ]
        )

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step of the three sub-networks and log the losses.

        Each loss is back-propagated in turn; the parameters that already received
        their gradient are frozen before the next backward pass so that every
        sub-network only keeps the gradient of its own loss.
        """
        optimizer_encoder, optimizer_decoder, optimizer_discriminator = self.optimizers()
        x, birdname, file = batch
        batch_size = x.shape[0]
        reconstructed, sampled, mean, logvar, logits_originals, logits_reconstructed, logits_sampled = self(x)
        (loss_encoder,
         loss_decoder,
         loss_discriminator,
         reconstruction_loss,
         kl_loss,
         loss_decoder_reconstructed,
         loss_decoder_samples,
         loss_discriminator_samples,
         loss_discriminator_original,
         loss_discriminator_reconstructed
        ) = self.criterion(
            reconstructed,
            x,
            mean,
            logvar,
            logits_sampled,
            logits_reconstructed,
            logits_originals
        )

        optimizer_encoder.zero_grad()
        optimizer_decoder.zero_grad()
        optimizer_discriminator.zero_grad()

        # we apply the loss to the encoder and then prevent it from being affected by other losses
        loss_encoder.backward(retain_graph=True)
        for group in optimizer_encoder.param_groups:
            for param in group['params']:
                param.requires_grad = False

        # we reset the decoder and discriminator that have been affected by the encoder's loss
        optimizer_decoder.zero_grad()
        optimizer_discriminator.zero_grad()

        # we apply the loss to the decoder and then prevent it from being affected by other losses
        loss_decoder.backward(retain_graph=True)
        for group in optimizer_decoder.param_groups:
            for param in group['params']:
                param.requires_grad = False

        # we reset the discriminator that has been affected by the decoder's loss
        optimizer_discriminator.zero_grad()
        loss_discriminator.backward()

        optimizer_encoder.step()
        # The decoder and discriminator are only stepped while the previous discriminator
        # loss says that their opponent is not lagging too far behind.
        if self.prev_disc_loss is None or self.prev_disc_loss < self.generator_too_good:
            optimizer_decoder.step()
            self.log('train/active/generator', 1, on_epoch=False, on_step=True, batch_size=batch_size)
        else:
            self.log('train/active/generator', 0, on_epoch=False, on_step=True, batch_size=batch_size)
        if self.prev_disc_loss is None or self.prev_disc_loss > self.discriminator_too_good:
            optimizer_discriminator.step()
            self.log('train/active/discriminator', 1, on_epoch=False, on_step=True, batch_size=batch_size)
        else:
            self.log('train/active/discriminator', 0, on_epoch=False, on_step=True, batch_size=batch_size)

        self.prev_disc_loss = loss_discriminator.detach()

        # resetting things to their normal states
        for group in optimizer_encoder.param_groups + optimizer_decoder.param_groups:
            for param in group['params']:
                param.requires_grad = True

        self.log('train/loss_encoder', loss_encoder, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('train/loss_decoder', loss_decoder, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('train/loss_discriminator', loss_discriminator, on_epoch=True, on_step=True, batch_size=batch_size)

        self.log('train/loss_reconstruction', reconstruction_loss, on_epoch=True, on_step=True, batch_size=batch_size)
        # The KL term gets its own key; it used to be logged under 'train/loss_reconstruction'.
        self.log('train/loss_kl', kl_loss, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('train/loss_decoder_reconstructed', loss_decoder_reconstructed, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('train/loss_decoder_samples', loss_decoder_samples, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('train/loss_discriminator_samples', loss_discriminator_samples, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('train/loss_discriminator_original', loss_discriminator_original, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('train/loss_discriminator_reconstructed', loss_discriminator_reconstructed, on_epoch=True, on_step=True, batch_size=batch_size)

        tensorboard = self.logger.experiment
        step = self.trainer.num_training_batches * self.current_epoch + batch_idx
        tensorboard.add_histogram("train_mean", mean, step)
        tensorboard.add_histogram("train_logvar", logvar, step)

        tensorboard.add_histogram("pred_originals", torch.sigmoid(logits_originals), step)
        tensorboard.add_histogram("pred_reconstructed", torch.sigmoid(logits_reconstructed), step)
        tensorboard.add_histogram("pred_sampled", torch.sigmoid(logits_sampled), step)

    def validation_step(self, batch, batch_idx):
        """Compute and log every loss term and the prediction histograms on a validation batch."""
        x, birdname, file = batch
        batch_size = x.shape[0]
        reconstructed, sampled, mean, logvar, logits_originals, logits_reconstructed, logits_sampled = self(x)

        (loss_encoder,
         loss_decoder,
         loss_discriminator,
         reconstruction_loss,
         kl_loss,
         loss_decoder_reconstructed,
         loss_decoder_samples,
         loss_discriminator_samples,
         loss_discriminator_original,
         loss_discriminator_reconstructed
        ) = self.criterion(
            reconstructed,
            x,
            mean,
            logvar,
            logits_sampled,
            logits_reconstructed,
            logits_originals
        )

        self.log('validation/loss_encoder', loss_encoder, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('validation/loss_decoder', loss_decoder, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('validation/loss_discriminator', loss_discriminator, on_epoch=True, on_step=False, batch_size=batch_size)

        self.log('validation/loss_reconstruction', reconstruction_loss, on_epoch=True, on_step=True, batch_size=batch_size)
        # The KL term gets its own key; it used to be logged under 'validation/loss_reconstruction'.
        self.log('validation/loss_kl', kl_loss, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('validation/loss_decoder_reconstructed', loss_decoder_reconstructed, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('validation/loss_decoder_samples', loss_decoder_samples, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('validation/loss_discriminator_samples', loss_discriminator_samples, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('validation/loss_discriminator_original', loss_discriminator_original, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log('validation/loss_discriminator_reconstructed', loss_discriminator_reconstructed, on_epoch=True, on_step=True, batch_size=batch_size)

        tensorboard = self.logger.experiment
        step = self.trainer.num_val_batches[0] * self.current_epoch + batch_idx
        tensorboard.add_histogram("val_mean", mean, step)
        tensorboard.add_histogram("val_logvar", logvar, step)

        tensorboard.add_histogram("validation/pred_originals", torch.sigmoid(logits_originals), step)
        tensorboard.add_histogram("validation/pred_reconstructed", torch.sigmoid(logits_reconstructed), step)
        tensorboard.add_histogram("validation/pred_sampled", torch.sigmoid(logits_sampled), step)

    def generate_specs(self, n=None):
        """Decode `n` latent vectors drawn with a fixed seed and return them in dB.

        Args:
            n: Number of spectrograms; defaults to the data module's batch size.

        Returns:
            NumPy array of `n` spectrograms rescaled by the data module's `denormalize`.
        """
        if n is None:
            n = self.trainer.datamodule.batch_size
        mean = torch.zeros([n, self.hidden_size], device=self.device)
        logvar = torch.zeros([n, self.hidden_size], device=self.device)
        # The result is converted to NumPy, so no autograd graph is needed.
        with torch.no_grad():
            z = self.sample(mean, logvar, seed=0)
            specs = self.decode(z)

        return self.trainer.datamodule.denormalize(specs[:, 0])

    @staticmethod
    def _figure_to_tensor(fig):
        """Render `fig` to PNG, close it, and return it as a (3, H, W) uint8 tensor."""
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
        buffer.seek(0)

        plt.close(fig)
        image = Image.open(buffer).convert('RGB')
        return torch.tensor(np.array(image)).permute(2, 0, 1)  # HWC -> CHW

    def _spec_to_audio(self, spec_db):
        """Invert a dB-scaled mel spectrogram to a waveform.

        `mel_to_audio` expects a power spectrogram, so the dB values are converted
        back to power first, and the sample rate is passed explicitly instead of
        relying on librosa's default.
        """
        mel_power = librosa.db_to_power(spec_db)
        return librosa.feature.inverse.mel_to_audio(mel_power, sr=self.SAMPLE_RATE)

    def spec_to_img(self, spec):
        """Plot one dB mel spectrogram and return the figure as an image tensor."""
        fig, ax = plt.subplots()
        img = librosa.display.specshow(spec, x_axis='time', y_axis='mel', sr=self.SAMPLE_RATE, ax=ax)
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
        ax.set(title='Mel-frequency spectrogram')

        return self._figure_to_tensor(fig)

    def reconstruction_to_img(self, original_spec, reconstructed_spec):
        """Plot an original and its reconstruction side by side and return an image tensor."""
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))  # Create subplots with 1 row and 2 columns
        img1 = librosa.display.specshow(original_spec, x_axis='time', y_axis='mel', sr=self.SAMPLE_RATE, ax=axs[0])
        fig.colorbar(img1, ax=axs[0], format='%+2.0f dB')
        axs[0].set(title='Mel-frequency spectrogram - original')

        img2 = librosa.display.specshow(reconstructed_spec, x_axis='time', y_axis='mel', sr=self.SAMPLE_RATE, ax=axs[1])
        fig.colorbar(img2, ax=axs[1], format='%+2.0f dB')
        axs[1].set(title='Mel-frequency spectrogram - reconstructed')

        return self._figure_to_tensor(fig)

    def reconstruct(self, batch):
        """Return only the reconstruction produced by `forward` for `batch`."""
        output, _, _, _, _, _, _ = self(batch)
        return output

    def on_validation_end(self):
        """Log generated and reconstructed spectrograms (images and audio) to TensorBoard."""
        tensorboard = self.logger.experiment

        specs = self.generate_specs(self.generate_on_epoch)
        for i, spec in enumerate(specs):
            tensorboard.add_image(f"generated_spectrogram_{i}", self.spec_to_img(spec), self.global_step)
            audio = self._spec_to_audio(spec)
            tensorboard.add_audio(f"generated_audio_{i}", audio, self.global_step, self.SAMPLE_RATE)

        reconstruction_dataloader = self.trainer.datamodule.illustration_dataloader(self.reconstruct_on_epoch)
        originals, birdnames, files = next(iter(reconstruction_dataloader))
        originals = originals.to(self.device)
        with torch.no_grad():
            reconstructed = self.reconstruct(originals)
        originals = self.trainer.datamodule.denormalize(originals)[:, 0]
        reconstructed = self.trainer.datamodule.denormalize(reconstructed)[:, 0]
        for i, (original_spec, reconstructed_spec) in enumerate(zip(originals, reconstructed)):
            tensorboard.add_image(
                f"reconstructed_spectrogram_{i}",
                self.reconstruction_to_img(original_spec, reconstructed_spec),
                self.global_step
            )
            original_audio = self._spec_to_audio(original_spec)
            reconstructed_audio = self._spec_to_audio(reconstructed_spec)
            tensorboard.add_audio(f"original_audio_{i}", original_audio, self.global_step, self.SAMPLE_RATE)
            tensorboard.add_audio(f"reconstructed_audio_{i}", reconstructed_audio, self.global_step, self.SAMPLE_RATE)

    def on_train_epoch_end(self):
        """Step the three schedulers by hand, as required with manual optimization."""
        for scheduler in self.lr_schedulers():
            scheduler.step()

    def on_train_start(self):
        """Send the stored hyperparameters to the logger."""
        self.logger.log_hyperparams(self.hparams)


In [8]:
# Generate audio clips from the trained VAE-GAN and store them as 16 kHz WAV files.
seed = 0
pl.seed_everything(seed, workers=True)

sample_rate = 32000          # rate the notebook assumes for the mel spectrograms
target_sample_rate = 16000   # rate of the written WAV files
n_generated = 1024
output_dir = "vaegan_inference"
os.makedirs(output_dir, exist_ok=True)  # sf.write does not create missing folders

model = VAEGAN.load_from_checkpoint("vaegan/lightning_logs/version_21/checkpoints/epoch=63-step=16479.ckpt")
model.eval()
mean = torch.zeros([n_generated, model.hidden_size]).to(model.device)
logvar = torch.zeros([n_generated, model.hidden_size]).to(model.device)
# Pure inference: skip building an autograd graph for the decoded samples.
with torch.no_grad():
    z = model.sample(mean, logvar, seed=0)
    specs = model.decode(z)
# The dataset is only used here for its denormalize() method.
ds = BirdClefDataset(root_dir='train_audio')
for i, mel_spec in enumerate(tqdm(specs)):
    filename = f"{i}.wav"
    output_path = os.path.join(output_dir, filename)
    # Already rendered by a previous run.
    if os.path.isfile(output_path):
        continue
    mel_spec = ds.denormalize(mel_spec)
    # mel_to_audio expects a power spectrogram, not dB values, and must be told the
    # sample rate instead of falling back to librosa's default.
    mel_power = librosa.db_to_power(mel_spec)
    audio = librosa.feature.inverse.mel_to_audio(mel_power, sr=sample_rate)
    # Sample rates are keyword-only arguments in recent librosa releases.
    audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_sample_rate)
    sf.write(output_path, audio.T, target_sample_rate, format='wav')

Seed set to 0


  0%|          | 0/1024 [00:00<?, ?it/s]

In [9]:
# Extract VGGish embeddings of the generated clips once and cache them on disk.
if not os.path.isfile("vaegan_embeddings.npy"):
    vggish_extractor = VGGish(checkpoint_path="vggish_model.ckpt", pca_params_path="vggish_pca_params.npz")
    try:
        vaegan_embeddings = vggish_extractor.get_embeddings("vaegan_inference")
        # Write to a temporary file first: an interrupted save must not leave a
        # truncated 'vaegan_embeddings.npy' that the check above would accept next time.
        with open('vaegan_embeddings.npy.tmp', 'wb') as f:
            np.save(f, vaegan_embeddings)
        os.replace('vaegan_embeddings.npy.tmp', 'vaegan_embeddings.npy')
    finally:
        # Release the extractor even when extraction or saving fails.
        vggish_extractor.cleanup()

  0%|▏                                         | 4/1024 [00:00<01:11, 14.18it/s]2024-03-14 10:24:28.788714: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-14 10:24:28.788739: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-14 10:24:28.789686: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-14 10:24:28.796366: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild Ten

INFO:tensorflow:Restoring parameters from vggish_model.ckpt


2024-03-14 10:24:33.975434: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
  0%|                                                  | 0/1024 [00:00<?, ?it/s]2024-03-14 10:24:35.057647: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8902
2024-03-14 10:24:35.272181: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2024-03-14 10:24:38.520748: W external/local_tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 566.38MiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-03-14 10:24:39.407183: W external/local_tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.09GiB with freed_by_count=0. The caller indicates that this is