# Generating bird sounds with a VAE + GAN loss

## Imports

In [1]:
import os
import io

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from tqdm.notebook import tqdm
from IPython.display import display, Audio
from PIL import Image

import soundfile as sf
import librosa
import librosa.display

import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
from lightning.pytorch import loggers as pl_loggers

## Dataset

In [2]:
class BirdClefDataset(Dataset):
    """Dataset of fixed-length, normalised mel spectrograms built from ``.ogg`` files.

    ``root_dir`` is expected to contain one sub-folder per bird, each holding
    ``.ogg`` recordings. Every item is a tuple ``(mel_spec, bird, filename)``
    where ``bird`` is the name of the folder containing the recording.

    Args:
        root_dir: Folder containing the per-bird sub-folders.
        bird_name: If given, only this sub-folder is used.
        transform: Optional callable applied to the normalised spectrogram.
        num_samples: Number of waveform samples kept per recording (shorter
            recordings are zero-padded, longer ones are truncated).
        min_db: Decibel value mapped to 0 by ``normalize``.
        max_db: Decibel value mapped to 1 by ``normalize``.
        cache: If true, normalised spectrograms are stored as ``.npy`` files in
            a ``cache`` folder created next to ``root_dir``.
    """

    def __init__(
        self,
        root_dir,
        bird_name=None,
        transform=None,
        num_samples=65_500,
        min_db=-80,
        max_db=0,
        cache=True
    ):
        self.root_dir = root_dir
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db

        self.cache_dir = self._make_cache_dir() if cache else None

        if bird_name is not None:
            self.bird_folders = [bird_name]
        else:
            # Only directories are bird folders; stray files in root_dir
            # (metadata, hidden files, ...) would otherwise crash the scan below.
            self.bird_folders = sorted(
                entry for entry in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, entry))
            )

        self.audio_files = []

        for bird_folder in self.bird_folders:
            bird_path = os.path.join(root_dir, bird_folder)
            # os.listdir order is arbitrary; sorting keeps indices (and hence
            # any seeded split of this dataset) stable across machines.
            self.audio_files.extend(
                os.path.join(bird_path, file)
                for file in sorted(os.listdir(bird_path))
                if file.endswith('.ogg')
            )

    def _make_cache_dir(self):
        """Create (if needed) and return the ``cache`` folder located next to ``root_dir``."""
        cache_dir = os.path.join(os.path.dirname(os.path.abspath(self.root_dir)), 'cache')
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir

    def get_spec(self, audio_path):
        """Return the mel spectrogram of ``audio_path`` in decibels.

        The recording is loaded as mono at its native sample rate, then
        zero-padded or truncated to exactly ``num_samples`` samples. Decibels
        are computed relative to the spectrogram's own maximum, so the loudest
        bin is 0 dB.
        """
        waveform, sample_rate = librosa.load(audio_path, sr=None, mono=True)

        if len(waveform) < self.num_samples:
            pad_amount = self.num_samples - len(waveform)
            waveform = np.pad(waveform, (0, pad_amount))
        else:
            waveform = waveform[:self.num_samples]

        mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sample_rate)
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        return mel_spec

    def normalize(self, x):
        """Linearly map decibel values so that ``min_db`` -> 0 and ``max_db`` -> 1."""
        return (x - self.min_db) / (self.max_db - self.min_db)

    def denormalize(self, x):
        """Rescale each sample of a batch to span ``[min_db, max_db]``.

        This is a per-sample min-max stretch, not the exact inverse of
        ``normalize``: the smallest value of every sample becomes ``min_db``
        and the largest becomes ``max_db``.

        Args:
            x: Batch (first axis) of spectrograms, as a tensor or numpy array.

        Returns:
            A numpy array with the same shape as ``x``. A sample whose values
            are all equal is mapped to ``min_db`` everywhere.
        """
        if isinstance(x, torch.Tensor):
            x = x.cpu().detach().numpy()

        flattened_array = x.reshape((x.shape[0], -1))

        min_batch_values = flattened_array.min(axis=-1, keepdims=True)
        max_batch_values = flattened_array.max(axis=-1, keepdims=True)

        # A constant sample has a zero range; dividing by it would yield NaN.
        value_range = max_batch_values - min_batch_values
        safe_range = np.where(value_range > 0, value_range, 1)

        normalized_array = self.min_db + ((flattened_array - min_batch_values) / safe_range) * (self.max_db - self.min_db)

        normalized_batch = normalized_array.reshape(x.shape)

        return normalized_batch

    def cache_all(self):
        """Compute and store the spectrogram of every recording in the cache.

        Caching is switched on if the dataset was created with ``cache=False``.
        """
        if self.cache_dir is None:
            self.cache_dir = self._make_cache_dir()
        for audio_path in self.audio_files:
            self._load_normalized(audio_path)

    def _load_normalized(self, audio_path):
        """Return the normalised spectrogram (with a leading channel axis), using the cache if enabled."""
        cache_path = None
        if self.cache_dir is not None:
            # NOTE(review): the cache key ignores the bird folder and
            # min_db/max_db, so equal file names or changed bounds reuse an
            # existing entry - confirm this is acceptable for the data used.
            cache_filename = f"{os.path.basename(audio_path)}_{self.num_samples}.npy"
            cache_path = os.path.join(self.cache_dir, cache_filename)

        if cache_path is not None and os.path.isfile(cache_path):
            try:
                return np.load(cache_path)
            except Exception as e:
                raise IOError(f"Failed to read file: {cache_path}. Error: {e}") from e

        mel_spec = self.get_spec(audio_path)

        # Normalize mel_spec
        mel_spec = self.normalize(mel_spec)
        mel_spec = np.expand_dims(mel_spec, axis=0)

        # Save mel spectrogram to cache
        if cache_path is not None:
            np.save(cache_path, mel_spec)

        return mel_spec

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return ``(mel_spec, bird, filename)`` for the recording at ``idx``."""
        audio_path = self.audio_files[idx]

        mel_spec = self._load_normalized(audio_path)

        if self.transform is not None:
            mel_spec = self.transform(mel_spec)

        folder, filename = os.path.split(audio_path)
        basedir, bird = os.path.split(folder)

        return mel_spec, bird, filename

In [3]:
class BirdClefDataModule(pl.LightningDataModule):
    """Lightning data module wrapping ``BirdClefDataset`` with a seeded train/validation split.

    Args:
        root_dir: Folder passed to ``BirdClefDataset``.
        batch_size: Batch size of the train and validation loaders.
        validation_split: Fraction of the dataset used for validation.
        num_workers: Number of worker processes of every loader.
        bird_name: Optional single bird folder to restrict the dataset to.
        transform: Optional transform forwarded to the dataset.
        num_samples: Waveform length forwarded to the dataset.
        min_db: Lower decibel bound forwarded to the dataset.
        max_db: Upper decibel bound forwarded to the dataset.
        cache: Whether the dataset caches spectrograms on disk.
        seed: Seed of the generator used for the random split.
    """

    def __init__(self,
                 root_dir,
                 batch_size=64,
                 validation_split=0.2,
                 num_workers=10,
                 bird_name=None,
                 transform=None,
                 num_samples=65_500,
                 min_db=-80,
                 max_db=0,
                 cache=True,
                 seed=0
                ):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.validation_split = validation_split
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db
        self.cache = cache
        self.seed = seed

        self.save_hyperparameters()

    def setup(self, stage=None):
        """Build the dataset and, for fitting or validating, the two loaders."""
        # Forward every dataset option; otherwise the values given to this
        # module would be silently replaced by the dataset's own defaults.
        dataset = BirdClefDataset(
            self.root_dir,
            bird_name=self.bird_name,
            transform=self.transform,
            num_samples=self.num_samples,
            min_db=self.min_db,
            max_db=self.max_db,
            cache=self.cache,
        )
        self.normalize = dataset.normalize
        self.denormalize = dataset.denormalize
        # 'validate' also needs the split, since val_dataloader() reads it.
        if stage in ('fit', 'validate', None):
            train_dataset, validation_dataset = torch.utils.data.random_split(
                dataset,
                (1 - self.validation_split, self.validation_split),
                torch.Generator().manual_seed(self.seed),
            )
            self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
            self.validation_loader = DataLoader(validation_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def train_dataloader(self):
        """Return the shuffled training loader built by ``setup``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the unshuffled validation loader built by ``setup``."""
        return self.validation_loader

    def illustration_dataloader(self, batch_size):
        """Return an unshuffled loader over the validation subset with a custom batch size."""
        illustrative_dataset = self.validation_loader.dataset
        illustrative_loader = DataLoader(illustrative_dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers)
        return illustrative_loader

In [4]:
# Folder holding one sub-folder of recordings per bird.
root_directory = 'train_audio'

# Add bird_name="rerswa1" to restrict the data to a single bird.
data_module = BirdClefDataModule(
    root_directory,
    batch_size=128,
)

## GAN model

In [5]:
class Flatten(nn.Module):
    """Collapse every axis after the first one into a single axis."""

    def forward(self, input):
        batch = input.shape[0]
        return input.view(batch, -1)

class UnFlatten(nn.Module):
    """Reshape a flat batch into square feature maps of shape (channels, size, size)."""

    def __init__(self, size=16, channels=128):
        super().__init__()
        self.channels = channels
        self.size = size

    def forward(self, input):
        target_shape = (input.size(0), self.channels, self.size, self.size)
        return input.view(*target_shape)

In [6]:
class GANLoss(torch.nn.modules.loss._Loss):
    """Binary cross-entropy GAN losses computed from discriminator logits.

    Label convention used here: generated samples are class 1 and original
    samples are class 0 for the discriminator, while the generator is rewarded
    when its samples are classified as 0.
    """

    __constants__ = ['reduction']

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean'):
        super().__init__(size_average, reduce, reduction)

    def forward(self, sampled_classif, original_classif):
        """Return ``(discriminator loss, generator loss, loss on samples, loss on originals)``.

        Args:
            sampled_classif: Discriminator logits for generated samples.
            original_classif: Discriminator logits for original samples.
        """
        def bce(logits, target):
            return F.binary_cross_entropy_with_logits(logits, target, reduction=self.reduction)

        generated_as_one = torch.ones_like(sampled_classif)
        generated_as_zero = torch.zeros_like(sampled_classif)
        original_as_zero = torch.zeros_like(original_classif)

        # Discriminator side: tell generated (1) apart from original (0).
        on_samples = bce(sampled_classif, generated_as_one)
        on_originals = bce(original_classif, original_as_zero)
        discriminator_loss = 1/2 * (on_samples + on_originals)

        # Generator side: have its samples classified as original (0).
        generator_loss = bce(sampled_classif, generated_as_zero)

        return discriminator_loss, generator_loss, on_samples, on_originals

In [7]:
class GAN(pl.LightningModule):
    """Convolutional GAN trained with manual optimisation and loss-gated updates.

    The generator maps a latent vector through a linear layer and one stride-2
    transposed convolution per entry of ``layers`` (in reverse order), ending
    with a sigmoid. The discriminator applies one stride-2 convolution per
    entry of ``layers`` followed by a linear layer producing a single logit.

    NOTE: the first feature map has side ``img_size // 2**(len(layers) + 1)``
    and every transposed convolution doubles it, so generated images have side
    ``2**len(layers) * (img_size // 2**(len(layers) + 1))`` (``img_size // 2``
    when the division is exact). The discriminator's linear layer is sized for
    that same feature map.

    Args:
        img_channels: Number of channels of the images.
        img_size: Size parameter from which the feature map sizes are derived.
        input_size: Dimension of the latent vector.
        layers: Channel counts of the discriminator's convolutions.
        learning_rate: One value, or a ``(generator, discriminator)`` tuple.
        lr_decay: Per-epoch exponential decay; one value or a
            ``(generator, discriminator)`` tuple.
        activation: Activation module class used between convolutions.
        optimizer: Optimizer class used for both networks.
        generate_on_epoch: Number of samples logged after each validation run.
        generator_too_good: The generator is only updated while the previous
            discriminator loss is below this value.
        discriminator_too_good: The discriminator is only updated while the
            previous discriminator loss is above this value.
        seed: Stored on the module and saved with the hyper-parameters.

    Raises:
        ValueError: If ``img_size`` is incompatible with the number of layers.
    """

    # Sample rate (Hz) assumed when rendering and logging generated spectrograms.
    SAMPLE_RATE = 32000

    def __init__(
        self,
        img_channels=1,
        img_size=256,
        input_size=4096,
        layers=[16, 32, 64],  # never mutated, only read
        learning_rate=0.001,
        lr_decay=1,
        activation=nn.ReLU,
        optimizer=torch.optim.Adam,
        generate_on_epoch=4,
        generator_too_good=.6, # value of the discriminator's loss above which the generator shouldn't be trained
        discriminator_too_good=.3, # value of the discriminator's loss below which the discriminator shouldn't be trained
        seed=0
    ):
        super().__init__()

        if img_size % (2**len(layers)) != 0:
            raise ValueError(f"An image of size {img_size} with {len(layers)} layers won't work")
        if img_size // (2**(len(layers)+1)) < 1:
            # A zero-sized feature map would build empty linear layers.
            raise ValueError(f"An image of size {img_size} is too small for {len(layers)} layers")

        self.input_size = input_size
        last_conv_size = (img_size // (2**(len(layers)+1)))**2 * layers[-1]
        self.generator = self.build_generator(img_channels, layers, img_size, last_conv_size, activation)
        self.discriminator = self.build_discriminator(img_channels, layers, activation, last_conv_size)

        self.generator_too_good = generator_too_good
        self.discriminator_too_good = discriminator_too_good

        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
        self.optim = optimizer
        self.criterion = GANLoss()
        self.generate_on_epoch = generate_on_epoch
        self.seed = seed

        # Discriminator loss of the previous training step, used to gate updates.
        self.prev_disc_loss = None

        # Both optimizers are stepped by hand in training_step.
        self.automatic_optimization = False

        self.save_hyperparameters()

    def build_discriminator(self, input_channels, channels_list, activation, last_conv_size):
        """Build the stride-2 convolution stack ending in a single-logit linear layer."""
        layers = []
        in_channels = input_channels
        for out_channels in channels_list:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1))
            layers.append(activation())
            in_channels = out_channels
        layers.append(Flatten())
        layers.append(nn.Linear(last_conv_size, 1))
        return nn.Sequential(*layers)

    def build_generator(self, input_channels, channels_list, img_size, last_conv_size, activation):
        """Build the linear + transposed-convolution stack ending in a sigmoid."""
        layers = [
            nn.Linear(self.input_size, last_conv_size),
            UnFlatten(img_size // (2**(len(channels_list)+1)), channels_list[-1])
        ]
        # Walk the channel list backwards, finishing on the image channels.
        for in_channels, out_channels in zip(channels_list[::-1], channels_list[-2::-1]+[input_channels]):
            layers.append(nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1))
            layers.append(activation())
        # The last activation is replaced by a sigmoid so outputs lie in [0, 1].
        layers.pop()
        layers.append(nn.Sigmoid())
        return nn.Sequential(*layers)

    def sample(self, batch_size, seed=None):
        """Draw ``batch_size`` standard-normal latent vectors on the module's device.

        When ``seed`` is given, a dedicated generator seeded with it is used.
        """
        if seed is None:
            z = torch.randn((batch_size, self.input_size), device=self.device)
        else:
            gen = torch.Generator(device=self.device).manual_seed(seed)
            z = torch.empty((batch_size, self.input_size), device=self.device).normal_(generator=gen)
        return z

    def generate(self, x):
        """Run the generator on latent vectors ``x``."""
        return self.generator(x)

    def forward(self, originals):
        """Generate one sample per original and classify both sets.

        Returns:
            ``(generated, logits_originals, logits_generated)``.
        """
        # generate
        z = self.sample(originals.shape[0])
        generated = self.generate(z)

        # discriminate
        logits_originals = self.discriminator(originals)
        logits_generated = self.discriminator(generated)

        return generated, logits_originals, logits_generated

    def configure_optimizers(self):
        """Create one optimizer and one exponential LR scheduler per network.

        Optimizers and schedulers are both returned in (generator,
        discriminator) order.
        """
        if isinstance(self.learning_rate, tuple):
            lr_generator, lr_discriminator = self.learning_rate
        else:
            lr_generator = lr_discriminator = self.learning_rate

        if isinstance(self.lr_decay, tuple):
            lr_decay_generator, lr_decay_discriminator = self.lr_decay
        else:
            lr_decay_generator = lr_decay_discriminator = self.lr_decay

        optimizer_generator = self.optim(self.generator.parameters(), lr=lr_generator)
        optimizer_discriminator = self.optim(self.discriminator.parameters(), lr=lr_discriminator)

        scheduler_generator = {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer_generator, gamma=lr_decay_generator),
        }
        scheduler_discriminator = {
            'scheduler': torch.optim.lr_scheduler.ExponentialLR(optimizer_discriminator, gamma=lr_decay_discriminator),
        }

        # Each scheduler must appear exactly once: on_train_epoch_end steps
        # every returned scheduler once per epoch.
        return (
            [
                optimizer_generator,
                optimizer_discriminator
            ],
            [
                scheduler_generator,
                scheduler_discriminator
            ]
        )

    def training_step(self, batch, batch_idx):
        """Run one manual optimisation step for both networks and log the losses.

        Each network is only stepped when the discriminator loss of the
        previous step is on the right side of its ``*_too_good`` threshold.
        """
        optimizer_generator, optimizer_discriminator = self.optimizers()
        x, birdname, file = batch
        generated, logits_originals, logits_generated = self(x)
        loss_discriminator, loss_generator, loss_discriminator_samples, loss_discriminator_original = self.criterion(logits_generated, logits_originals)

        optimizer_generator.zero_grad()
        optimizer_discriminator.zero_grad()

        # we apply the loss to the generator and then prevent it from being affected by other losses
        loss_generator.backward(retain_graph=True)
        for group in optimizer_generator.param_groups:
            for param in group['params']:
                param.requires_grad = False

        # we reset the discriminator that has been affected by the generator's loss
        optimizer_discriminator.zero_grad()

        loss_discriminator.backward()

        if self.prev_disc_loss is None or self.prev_disc_loss < self.generator_too_good:
            optimizer_generator.step()
            self.log('train/active/generator', 1, on_epoch=False, on_step=True, batch_size=x.shape[0])
        else:
            self.log('train/active/generator', 0, on_epoch=False, on_step=True, batch_size=x.shape[0])
        if self.prev_disc_loss is None or self.prev_disc_loss > self.discriminator_too_good:
            optimizer_discriminator.step()
            self.log('train/active/discriminator', 1, on_epoch=False, on_step=True, batch_size=x.shape[0])
        else:
            self.log('train/active/discriminator', 0, on_epoch=False, on_step=True, batch_size=x.shape[0])

        self.prev_disc_loss = loss_discriminator.detach()

        # resetting things to their normal state
        for group in optimizer_generator.param_groups:
            for param in group['params']:
                param.requires_grad = True

        self.log('train/loss_generator', loss_generator, on_epoch=True, on_step=True, batch_size=x.shape[0])
        self.log('train/loss_discriminator', loss_discriminator, on_epoch=True, on_step=True, batch_size=x.shape[0])

        self.log('train/loss_discriminator_samples', loss_discriminator_samples, on_epoch=True, on_step=True, batch_size=x.shape[0])
        self.log('train/loss_discriminator_original', loss_discriminator_original, on_epoch=True, on_step=True, batch_size=x.shape[0])

        tensorboard = self.logger.experiment

        tensorboard.add_histogram(
            "train/pred_originals",
            F.sigmoid(logits_originals),
            self.trainer.num_training_batches * self.current_epoch + batch_idx
        )

        tensorboard.add_histogram(
            "train/pred_generated",
            F.sigmoid(logits_generated),
            self.trainer.num_training_batches * self.current_epoch + batch_idx
        )

    def validation_step(self, batch, batch_idx):
        """Compute and log the losses and prediction histograms on a validation batch."""
        x, birdname, file = batch
        generated, logits_originals, logits_generated = self(x)
        loss_discriminator, loss_generator, loss_discriminator_samples, loss_discriminator_original = self.criterion(logits_generated, logits_originals)

        self.log('validation/loss_generator', loss_generator, on_epoch=True, on_step=True, batch_size=x.shape[0])
        self.log('validation/loss_discriminator', loss_discriminator, on_epoch=True, on_step=False, batch_size=x.shape[0])

        self.log('validation/loss_discriminator_samples', loss_discriminator_samples, on_epoch=True, on_step=True, batch_size=x.shape[0])
        self.log('validation/loss_discriminator_original', loss_discriminator_original, on_epoch=True, on_step=True, batch_size=x.shape[0])

        tensorboard = self.logger.experiment

        tensorboard.add_histogram(
            "validation/pred_originals",
            F.sigmoid(logits_originals),
            self.trainer.num_val_batches[0] * self.current_epoch + batch_idx
        )
        tensorboard.add_histogram(
            "validation/pred_generated",
            F.sigmoid(logits_generated),
            self.trainer.num_val_batches[0] * self.current_epoch + batch_idx
        )

    def generate_specs(self, n=None):
        """Generate ``n`` spectrograms from a fixed latent seed, rescaled by the data module.

        ``n`` defaults to the data module's batch size. The latent seed is
        fixed so that successive calls show the same latent points.
        """
        if n is None:
            n = self.trainer.datamodule.batch_size
        z = self.sample(n, seed=0)
        specs = self.generate(z)

        return self.trainer.datamodule.denormalize(specs[:, 0])

    def spec_to_img(self, spec):
        """Render a spectrogram with matplotlib and return it as a (3, H, W) uint8 tensor."""
        fig, ax = plt.subplots()
        img = librosa.display.specshow(spec, x_axis='time', y_axis='mel', sr=self.SAMPLE_RATE, ax=ax)
        fig.colorbar(img, ax=ax, format='%+2.0f dB')
        ax.set(title='Mel-frequency spectrogram')

        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
        buffer.seek(0)

        plt.close(fig)
        image = Image.open(buffer)
        image = image.convert('RGB')
        image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1)  # Convert to tensor and adjust dimensions

        return image_tensor

    def on_validation_end(self):
        """Log ``generate_on_epoch`` generated spectrograms as images and audio."""
        tensorboard = self.logger.experiment

        specs = self.generate_specs(self.generate_on_epoch)
        for i, spec in enumerate(specs):
            tensorboard.add_image(f"generated_spectrogram_{i}", self.spec_to_img(spec), self.global_step)
            # The spectrograms are in decibels, whereas mel_to_audio expects
            # power values; the sample rate must match the one used for logging.
            audio = librosa.feature.inverse.mel_to_audio(librosa.db_to_power(spec), sr=self.SAMPLE_RATE)
            tensorboard.add_audio(f"generated_audio_{i}", audio, self.global_step, self.SAMPLE_RATE)

    def on_train_epoch_end(self):
        """Step every LR scheduler once per epoch (required with manual optimisation)."""
        for scheduler in self.lr_schedulers():
            scheduler.step()

    def on_train_start(self):
        """Log the saved hyper-parameters to the logger."""
        self.logger.log_hyperparams(self.hparams)


In [8]:
# Run-level settings shared by the model and the trainer.
num_epochs = 100
seed = 0

pl.seed_everything(seed, workers=True)

# lr_decay is the per-epoch factor for which lr_decay ** num_epochs == 1/100.
model = GAN(
    seed=seed,
    img_channels=1,
    img_size=256,
    input_size=8192,
    layers=[32, 64, 128, 256, 256],
    learning_rate=1e-4,
    lr_decay=(1/100)**(1/num_epochs),
    generator_too_good=.65,
    discriminator_too_good=.5,
)

Seed set to 0


## Training

In [None]:
# TensorBoard logs are written under "gan/"; the learning rates are recorded once per epoch.
tb_logger = pl_loggers.TensorBoardLogger(save_dir="gan/", default_hp_metric=False)
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval='epoch')
training_callbacks = [lr_monitor]

trainer = pl.Trainer(
    accelerator='auto',
    max_epochs=num_epochs,
    deterministic=True,
    logger=tb_logger,
    log_every_n_steps=1,
    callbacks=training_callbacks,
)
trainer.fit(model, data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params
---------------------------------------------
0 | generator     | Sequential | 34.5 M
1 | discriminator | Sequential | 982 K 
2 | criterion     | GANLoss    | 0     
---------------------------------------------
35.5 M    Trainable params
0         Non-trainable params
35.5 M    Total params
142.073   Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

  return F.conv_transpose2d(


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]



Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]