# Generating bird sounds with a VAE

## Imports

In [1]:
import os
import io

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from tqdm.notebook import tqdm
from IPython.display import display, Audio
from PIL import Image

import soundfile as sf
import librosa
import librosa.display

import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
from lightning.pytorch import loggers as pl_loggers

## Dataset

In [2]:
class BirdClefDataset(Dataset):
    """Mel-spectrogram dataset over a ``root_dir/<bird>/<file>.ogg`` folder layout.

    Each item is ``(mel_spec, bird, filename)``. ``mel_spec`` is the dB mel
    spectrogram of the (padded/truncated) clip, linearly mapped from
    ``[min_db, max_db]`` to ``[0, 1]`` by ``normalize`` and given a leading
    channel axis, then passed through ``transform`` if one is set.

    Args:
        root_dir: directory holding one sub-directory per bird species.
        bird_name: if given, only that species' sub-directory is used.
        transform: optional callable applied to the normalized array.
        num_samples: waveforms are zero-padded or truncated to this length.
        min_db, max_db: dB range mapped onto ``[0, 1]``.
        cache: if True, spectrograms are stored as ``.npy`` files in a ``cache``
            directory next to ``root_dir`` and reused on later accesses.
    """

    def __init__(
        self,
        root_dir,
        bird_name=None,
        transform=None,
        num_samples=65_500,
        min_db=-80,
        max_db=0,
        cache=True
    ):
        self.root_dir = root_dir
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db
        self.cache = cache

        if cache:
            self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(self.root_dir)), 'cache')
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.cache_dir = None

        if bird_name is not None:
            self.bird_folders = [bird_name]
        else:
            # Only species sub-directories; stray files next to them would make listdir fail below.
            self.bird_folders = sorted(
                entry for entry in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, entry))
            )

        self.audio_files = []
        for bird_folder in self.bird_folders:
            bird_path = os.path.join(root_dir, bird_folder)
            # Sorted so the item order (and therefore any seeded random_split) is
            # the same on every filesystem; os.listdir order is arbitrary.
            self.audio_files.extend(
                os.path.join(bird_path, file)
                for file in sorted(os.listdir(bird_path))
                if file.endswith('.ogg')
            )

    def get_spec(self, audio_path):
        """Load ``audio_path`` as mono at its native rate and return its dB mel spectrogram.

        The waveform is zero-padded or truncated to ``num_samples`` first. dB values
        are relative to the spectrogram's own maximum (``ref=np.max``).
        """
        waveform, sample_rate = librosa.load(audio_path, sr=None, mono=True)

        if len(waveform) < self.num_samples:
            pad_amount = self.num_samples - len(waveform)
            waveform = np.pad(waveform, (0, pad_amount))
        else:
            waveform = waveform[:self.num_samples]

        mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sample_rate)
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        return mel_spec

    def normalize(self, x):
        """Map dB values linearly so that ``min_db -> 0`` and ``max_db -> 1``."""
        return (x - self.min_db) / (self.max_db - self.min_db)

    def denormalize(self, x):
        """Rescale each item of a batch back onto ``[min_db, max_db]`` (returns numpy).

        This is NOT the exact inverse of ``normalize``. Each item (first axis) is
        min-max stretched so its own minimum maps to ``min_db`` and its maximum to
        ``max_db``. Items whose values are all equal cannot be stretched; they use
        the exact inverse ``min_db + x * (max_db - min_db)`` instead of producing NaN.

        Args:
            x: numpy array or torch tensor whose first axis is the batch axis.
        """
        if isinstance(x, torch.Tensor):
            x = x.cpu().detach().numpy()

        flat = x.reshape((x.shape[0], -1))
        row_min = flat.min(axis=-1, keepdims=True)
        row_max = flat.max(axis=-1, keepdims=True)
        span = row_max - row_min
        db_range = self.max_db - self.min_db

        has_range = span > 0
        safe_span = np.where(has_range, span, 1)
        stretched = self.min_db + (flat - row_min) / safe_span * db_range
        exact = self.min_db + flat * db_range

        return np.where(has_range, stretched, exact).reshape(x.shape)

    def cache_all(self):
        """Load every item once so that all spectrograms are written to the cache.

        Nothing is persisted when the dataset was built with ``cache=False``
        (``cache_dir`` is None); the items are then only computed and discarded.
        """
        for idx in range(len(self)):
            self[idx]

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return ``(mel_spec, bird, filename)`` for item ``idx``, using the cache when possible."""
        audio_path = self.audio_files[idx]

        cache_path = None
        if self.cache_dir is not None:
            # NOTE(review): the cache key ignores min_db/max_db and the bird folder;
            # clear the cache directory after changing those settings.
            cache_filename = f"{os.path.basename(audio_path)}_{self.num_samples}.npy"
            cache_path = os.path.join(self.cache_dir, cache_filename)

        if cache_path is not None and os.path.isfile(cache_path):
            mel_spec = np.load(cache_path)
        else:
            mel_spec = self.normalize(self.get_spec(audio_path))
            mel_spec = np.expand_dims(mel_spec, axis=0)
            # Write only freshly computed spectrograms so cache hits don't rewrite the file.
            if cache_path is not None:
                np.save(cache_path, mel_spec)

        if self.transform is not None:
            mel_spec = self.transform(mel_spec)

        folder, filename = os.path.split(audio_path)
        bird = os.path.basename(folder)

        return mel_spec, bird, filename

In [3]:
class BirdClefDataModule(pl.LightningDataModule):
    """Lightning data module around ``BirdClefDataset`` with a seeded train/validation split.

    All dataset options (``bird_name``, ``transform``, ``num_samples``, ``min_db``,
    ``max_db``, ``cache``) are forwarded to ``BirdClefDataset``. The split uses the
    fractions ``1 - validation_split`` / ``validation_split`` and a generator seeded
    with ``seed``.
    """

    def __init__(self,
                 root_dir,
                 batch_size=64,
                 validation_split=0.2,
                 num_workers=10,
                 bird_name=None,
                 transform=None,
                 num_samples=65_500,
                 min_db=-80,
                 max_db=0,
                 cache=True,
                 seed=0
                ):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.validation_split = validation_split
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db
        self.cache = cache
        self.seed = seed

        self.save_hyperparameters()

    def setup(self, stage=None):
        """Build the dataset, expose its (de)normalizers and, when fitting, the loaders."""
        # Forward every dataset option; previously only bird_name reached the dataset.
        dataset = BirdClefDataset(
            self.root_dir,
            bird_name=self.bird_name,
            transform=self.transform,
            num_samples=self.num_samples,
            min_db=self.min_db,
            max_db=self.max_db,
            cache=self.cache,
        )
        self.normalize = dataset.normalize
        self.denormalize = dataset.denormalize
        if stage == 'fit' or stage is None:
            train_dataset, validation_dataset = torch.utils.data.random_split(
                dataset,
                (1 - self.validation_split, self.validation_split),
                generator=torch.Generator().manual_seed(self.seed),
            )
            self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
            self.validation_loader = DataLoader(validation_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.validation_loader

    def illustration_dataloader(self, batch_size):
        """Unshuffled loader over the validation subset with a custom ``batch_size``.

        Requires ``setup`` to have run for the 'fit' stage.
        """
        illustrative_dataset = self.validation_loader.dataset
        return DataLoader(illustrative_dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers)

In [4]:
# Training data: only the "rerswa1" species folder under train_audio/.
root_directory = 'train_audio'
data_module = BirdClefDataModule(
    root_dir=root_directory,
    bird_name="rerswa1",
    batch_size=128,
)

## VAE model

In [5]:
class Flatten(nn.Module):
    """Collapse every dimension after the batch dimension: ``(N, ...) -> (N, prod(...))``."""

    def forward(self, input):
        # reshape (unlike view) also accepts non-contiguous tensors, copying only when needed.
        return input.reshape(input.size(0), -1)

class UnFlatten(nn.Module):
    """Reshape flat per-sample vectors into ``(N, channels, size, size)`` feature maps."""

    def __init__(self, size=16, channels=128):
        super().__init__()
        self.channels = channels
        self.size = size

    def forward(self, input):
        batch_size = input.size(0)
        target_shape = (batch_size, self.channels, self.size, self.size)
        return input.view(*target_shape)

In [6]:
class VAELoss(torch.nn.modules.loss._Loss):
    """VAE objective: reconstruction loss plus ``beta`` times the KL divergence to N(0, I).

    ``forward`` returns ``(total, reconstruction, kl)``. ``reduction`` is passed to
    ``reconstruction_loss`` and applied over the batch to the per-sample KL term.
    With ``reduction='none'``, all three results are per-sample tensors of shape ``(N,)``.
    """
    __constants__ = ['reduction']

    def __init__(self, beta=1.0, size_average=None, reduce=None, reduction: str = 'mean', reconstruction_loss=F.mse_loss):
        super().__init__(size_average, reduce, reduction)
        self.beta = beta
        self.reconstruction_loss = reconstruction_loss

    def forward(self, recon_x, x, mean, logvar):
        """Compute the loss terms.

        Args:
            recon_x, x: reconstruction and target, batch dimension first.
            mean, logvar: posterior parameters with latent dimensions on dim 1;
                ``logvar`` is ``log(sigma^2)``.
        """
        reconstruction_loss = self.reconstruction_loss(recon_x, x, reduction=self.reduction)
        # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over latent dimensions.
        kl_loss = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp(), dim=1)
        if self.reduction == 'none':
            # The element-wise reconstruction loss would not broadcast against the
            # per-sample KL, so average it over all non-batch dimensions.
            reconstruction_loss = reconstruction_loss.reshape(reconstruction_loss.shape[0], -1).mean(dim=1)
        elif self.reduction == 'mean':
            kl_loss = torch.mean(kl_loss)
        else:
            kl_loss = torch.sum(kl_loss)
        total_loss = reconstruction_loss + self.beta * kl_loss
        return total_loss, reconstruction_loss, kl_loss


In [7]:
class VariationalAutoEncoder(pl.LightningModule):
    """Convolutional variational auto-encoder for single-channel mel spectrograms.

    The encoder stacks one stride-2 ``Conv2d`` + activation per entry of ``layers``,
    flattens, and feeds two linear heads that output the latent mean and
    log-variance (``hidden_size`` values each). The decoder mirrors it with
    ``ConvTranspose2d`` layers and ends in a ``Sigmoid``, so outputs lie in (0, 1).

    NOTE(review): the bottleneck side is ``img_size // 2**(len(layers) + 1)`` and the
    decoder upsamples it ``len(layers)`` times. Inputs and outputs therefore have
    side ``img_size // 2`` (``img_size=256`` -> 128x128). Confirm this is intended.

    ``generate_on_epoch`` / ``reconstruct_on_epoch`` set how many prior samples /
    validation reconstructions are logged to TensorBoard after each validation run.
    """

    # Sample rate (Hz) used for spectrogram axes and audio reconstruction.
    # NOTE(review): assumes the dataset audio is 32 kHz (it is loaded with sr=None); confirm.
    sample_rate = 32000

    def __init__(
        self,
        input_channels=1,
        img_size=256,
        hidden_size=4096,
        layers=(16, 32, 64),
        learning_rate=0.001,
        beta=1e-3,
        activation=nn.ReLU,
        optimizer=torch.optim.Adam,
        reconstruction_loss=F.mse_loss,
        generate_on_epoch=4,
        reconstruct_on_epoch=4,
        seed=0
    ):
        super().__init__()

        # The bottleneck side is img_size // 2**(len(layers) + 1), so that divisor
        # (not 2**len(layers)) must divide img_size for the shapes to line up.
        divisor = 2 ** (len(layers) + 1)
        if img_size % divisor != 0:
            raise ValueError(
                f"An image of size {img_size} with {len(layers)} layers won't be reconstructed with the correct size"
            )

        self.hidden_size = hidden_size
        self.encoder = self.build_encoder(input_channels, layers, activation)
        bottleneck_side = img_size // divisor
        last_conv_size = bottleneck_side ** 2 * layers[-1]
        self.mean_layer = nn.Linear(last_conv_size, self.hidden_size)
        self.logvar_layer = nn.Linear(last_conv_size, self.hidden_size)
        self.decoder = self.build_decoder(input_channels, layers, img_size, last_conv_size, activation)
        self.learning_rate = learning_rate
        self.optim = optimizer
        self.criterion = VAELoss(beta=beta, reconstruction_loss=reconstruction_loss)
        self.generate_on_epoch = generate_on_epoch
        self.reconstruct_on_epoch = reconstruct_on_epoch
        self.seed = seed
        self.save_hyperparameters()

    def build_encoder(self, input_channels, channels_list, activation):
        """Stride-2 conv + activation per entry of ``channels_list``, then ``Flatten``."""
        layers = []
        in_channels = input_channels
        for out_channels in channels_list:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1))
            layers.append(activation())
            in_channels = out_channels
        layers.append(Flatten())
        return nn.Sequential(*layers)

    def build_decoder(self, input_channels, channels_list, img_size, last_conv_size, activation):
        """Linear + ``UnFlatten``, then one stride-2 transposed conv per encoder layer, ending in Sigmoid."""
        channels = list(channels_list)
        layers = [
            nn.Linear(self.hidden_size, last_conv_size),
            UnFlatten(img_size // (2 ** (len(channels) + 1)), channels[-1])
        ]
        for in_channels, out_channels in zip(channels[::-1], channels[-2::-1] + [input_channels]):
            layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(activation())
        # Replace the last hidden activation with a Sigmoid so outputs match the [0, 1] inputs.
        layers.pop()
        layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def encode(self, x):
        """Return ``(mean, logvar)`` of the approximate posterior for ``x``."""
        x = self.encoder(x)
        mean, logvar = self.mean_layer(x), self.logvar_layer(x)
        return mean, logvar

    def sample(self, mean, logvar, seed=None):
        """Reparameterized sample ``mean + std * eps`` with ``eps ~ N(0, I)``.

        ``seed`` makes ``eps`` reproducible via a dedicated generator on ``self.device``.
        """
        # logvar is log(sigma^2) (matching the KL term in VAELoss), so std = exp(logvar / 2).
        std = torch.exp(0.5 * logvar)
        if seed is None:
            epsilon = torch.randn_like(std)
        else:
            gen = torch.Generator(device=self.device).manual_seed(seed)
            epsilon = torch.empty_like(std).normal_(generator=gen)
        return mean + std * epsilon

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mean, logvar, z)``."""
        mean, logvar = self.encode(x)
        z = self.sample(mean, logvar)
        x = self.decode(z)
        return x, mean, logvar, z

    def configure_optimizers(self):
        optimizer = self.optim(self.parameters(), lr=self.learning_rate)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=50, threshold=1e-6, cooldown=50, min_lr=1e-6),
                "monitor": "val_loss",
                "frequency": 1,
                "interval": "epoch"
            }
        }

    def _shared_step(self, batch, stage):
        """Forward one batch, log losses and latent histograms under ``stage``, return the loss."""
        x, _birdname, _file = batch
        outputs, mean, logvar, z = self(x)
        loss, reconstruction_loss, kl_loss = self.criterion(outputs, x, mean, logvar)
        batch_size = x.shape[0]
        self.log(f'{stage}_loss', loss, on_epoch=True, on_step=True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_reconstruction', reconstruction_loss, on_epoch=True, on_step=True, batch_size=batch_size)
        self.log(f'{stage}_kl', kl_loss, on_epoch=True, on_step=True, batch_size=batch_size)
        tensorboard = self.logger.experiment
        tensorboard.add_histogram(f"{stage}_z", z, self.global_step)
        tensorboard.add_histogram(f"{stage}_mean", mean, self.global_step)
        tensorboard.add_histogram(f"{stage}_logvar", logvar, self.global_step)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val')

    def generate_specs(self, n=None):
        """Decode ``n`` prior samples (seeded) and return them denormalized to dB as numpy."""
        if n is None:
            n = self.trainer.datamodule.batch_size
        mean = torch.zeros(n, self.hidden_size, device=self.device)
        logvar = torch.zeros(n, self.hidden_size, device=self.device)
        z = self.sample(mean, logvar, seed=0)
        specs = self.decode(z)

        return self.trainer.datamodule.denormalize(specs[:, 0])

    @staticmethod
    def _figure_to_tensor(fig):
        """Render ``fig`` to PNG, close it, and return a uint8 ``(3, H, W)`` RGB tensor."""
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
        plt.close(fig)
        buffer.seek(0)
        image = Image.open(buffer).convert('RGB')
        return torch.tensor(np.array(image)).permute(2, 0, 1)

    def spec_to_img(self, spec):
        """Plot a dB mel spectrogram and return the figure as an image tensor."""
        fig, ax = plt.subplots()
        img = librosa.display.specshow(spec, x_axis='time', y_axis='mel', sr=self.sample_rate, ax=ax)
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
        ax.set(title='Mel-frequency spectrogram')
        return self._figure_to_tensor(fig)

    def reconstruction_to_img(self, original_spec, reconstructed_spec):
        """Plot original and reconstructed dB mel spectrograms side by side as an image tensor."""
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        img1 = librosa.display.specshow(original_spec, x_axis='time', y_axis='mel', sr=self.sample_rate, ax=axs[0])
        fig.colorbar(img1, ax=axs[0], format='%+2.0f dB')
        axs[0].set(title='Mel-frequency spectrogram - original')

        img2 = librosa.display.specshow(reconstructed_spec, x_axis='time', y_axis='mel', sr=self.sample_rate, ax=axs[1])
        fig.colorbar(img2, ax=axs[1], format='%+2.0f dB')
        axs[1].set(title='Mel-frequency spectrogram - reconstructed')
        return self._figure_to_tensor(fig)

    def _spec_to_audio(self, spec_db):
        """Invert a dB-scaled mel spectrogram to a waveform with librosa's ``mel_to_audio``."""
        # mel_to_audio expects a power (not dB) spectrogram and the sample rate of the mel basis.
        return librosa.feature.inverse.mel_to_audio(librosa.db_to_power(spec_db), sr=self.sample_rate)

    def interpolate(self, batch1, batch2):
        """Not implemented yet; returns None."""
        pass

    def reconstruct(self, batch):
        """Return the decoder output for ``batch`` (a fresh latent sample is drawn)."""
        output, _, _, _ = self(batch)
        return output

    def on_validation_end(self):
        """Log generated and reconstructed spectrograms and audio to TensorBoard."""
        tensorboard = self.logger.experiment
        # Logging only: the extra forward passes need no gradients.
        with torch.no_grad():
            self._log_generated(tensorboard)
            self._log_reconstructions(tensorboard)

    def _log_generated(self, tensorboard):
        """Log ``generate_on_epoch`` prior samples as images and audio."""
        specs = self.generate_specs(self.generate_on_epoch)
        for i, spec in enumerate(specs):
            tensorboard.add_image(f"generated_spectrogram_{i}", self.spec_to_img(spec), self.global_step)
            tensorboard.add_audio(f"generated_audio_{i}", self._spec_to_audio(spec), self.global_step, self.sample_rate)

    def _log_reconstructions(self, tensorboard):
        """Log the first ``reconstruct_on_epoch`` validation items next to their reconstructions."""
        datamodule = self.trainer.datamodule
        reconstruction_dataloader = datamodule.illustration_dataloader(self.reconstruct_on_epoch)
        originals, _birdnames, _files = next(iter(reconstruction_dataloader))
        originals = originals.to(self.device)
        reconstructed = self.reconstruct(originals)
        originals = datamodule.denormalize(originals)[:, 0]
        reconstructed = datamodule.denormalize(reconstructed)[:, 0]
        for i, (original_spec, reconstructed_spec) in enumerate(zip(originals, reconstructed)):
            tensorboard.add_image(
                f"reconstructed_spectrogram_{i}",
                self.reconstruction_to_img(original_spec, reconstructed_spec),
                self.global_step
            )
            tensorboard.add_audio(f"original_audio_{i}", self._spec_to_audio(original_spec), self.global_step, self.sample_rate)
            tensorboard.add_audio(f"reconstructed_audio_{i}", self._spec_to_audio(reconstructed_spec), self.global_step, self.sample_rate)

    def on_train_start(self):
        self.logger.log_hyperparams(self.hparams, {"hp/lr": self.trainer.lr_scheduler_configs[0].scheduler.optimizer.param_groups[0]["lr"]})


In [8]:
seed = 0
pl.seed_everything(seed, workers=True)

# NOTE(review): with hidden_size=8192 the mean/logvar heads are ~67M parameters
# each (~204M total, per the summary below). The recorded training run hit CUDA
# OOM on a ~5.8 GiB GPU; consider a smaller hidden_size or batch size.
model = VariationalAutoEncoder(
    seed=seed,
    input_channels=1,
    img_size=256,  # the model works on img_size // 2 = 128-pixel spectrograms
    layers=[32, 64, 128, 256, 512],
    hidden_size=8192,
    learning_rate=0.001,
    beta=5e-5,
    # Alternatives tried: activation=nn.ReLU, optimizer=torch.optim.AdamW,
    # reconstruction_loss=F.l1_loss
)

Seed set to 0


## Training

In [9]:
num_epochs = 1000

# hp metrics are logged manually in VariationalAutoEncoder.on_train_start.
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/", default_hp_metric=False)
# Record the learning rate once per epoch (driven by ReduceLROnPlateau).
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
# Keep only the single best checkpoint by validation loss.
model_chkpt = pl.callbacks.ModelCheckpoint(save_top_k=1, monitor="val_loss")

trainer = pl.Trainer(
    accelerator='auto',
    max_epochs=num_epochs,
    logger=tb_logger,
    callbacks=[lr_monitor, model_chkpt],
    log_every_n_steps=1,
    deterministic=True,
    # overfit_batches=10,         # debugging aid
    # enable_checkpointing=False,
)

# Optional batch-size search:
# tuner = pl.tuner.Tuner(trainer)
# tuner.scale_batch_size(model, mode="binsearch", datamodule=data_module, init_val=64)

# To resume: ckpt_path="logs/lightning_logs/version_0/checkpoints/epoch=294-step=62540.ckpt"
trainer.fit(model, data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params
--------------------------------------------
0 | encoder      | Sequential | 1.6 M 
1 | mean_layer   | Linear     | 67.1 M
2 | logvar_layer | Linear     | 67.1 M
3 | decoder      | Sequential | 68.7 M
4 | criterion    | VAELoss    | 0     
--------------------------------------------
204 M     Trainable params
0         Non-trainable params
204 M     Total params
817.947   Total estimated model params size (MB)


Sanity Checking: |                                                     | 0/? [00:00<?, ?it/s]

  return F.conv2d(input, weight, bias, self.stride,


Training: |                                                            | 0/? [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacty of 5.79 GiB of which 67.75 MiB is free. Process 1683681 has 3.12 GiB memory in use. Including non-PyTorch memory, this process has 1.99 GiB memory in use. Of the allocated memory 1.72 GiB is allocated by PyTorch, and 45.19 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF