In [1]:
import os
import io

import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from tqdm.notebook import tqdm
from IPython.display import display, Audio
from PIL import Image

import soundfile as sf
import librosa
import librosa.display

import torchaudio
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import pytorch_lightning as pl
from lightning.pytorch import loggers as pl_loggers



from aquatk.embedding_extractors import VGGish

2024-03-14 00:19:53.162187: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-14 00:19:53.162232: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-14 00:19:53.164508: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-14 00:19:53.179327: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class BirdClefDataset(Dataset):
    """Dataset of normalized dB-scaled mel spectrograms built from ``.ogg`` files.

    Recordings are looked up as ``root_dir/<bird folder>/<file>.ogg``. Every
    item is a tuple ``(mel_spec, bird, filename)`` where ``bird`` is the name
    of the folder containing the recording and ``filename`` is the recording's
    base name.

    Args:
        root_dir: Directory that contains one sub-folder per bird.
        bird_name: If given, only this sub-folder of ``root_dir`` is used.
        transform: Optional callable applied to the spectrogram on every
            access (after it was computed or read back from the cache).
        num_samples: Every waveform is zero-padded or truncated to exactly
            this many samples before the spectrogram is computed.
        min_db: Lower end of the dB range used by ``normalize``/``denormalize``.
        max_db: Upper end of the dB range used by ``normalize``/``denormalize``.
        cache: When true, spectrograms are stored as ``.npy`` files in a
            ``cache`` directory created next to ``root_dir``.
    """

    def __init__(
        self,
        root_dir,
        bird_name=None,
        transform=None,
        num_samples=65_500,
        min_db=-80,
        max_db=0,
        cache=True
    ):
        self.root_dir = root_dir
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db

        if cache:
            # The cache lives beside root_dir (in its parent), not inside it.
            self.cache_dir = os.path.join(os.path.dirname(os.path.abspath(self.root_dir)), 'cache')
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.cache_dir = None

        if bird_name is not None:
            self.bird_folders = [bird_name]
        else:
            # Keep directories only: a stray file directly under root_dir
            # would otherwise make os.listdir() fail further down.
            self.bird_folders = sorted(
                entry for entry in os.listdir(root_dir)
                if os.path.isdir(os.path.join(root_dir, entry))
            )

        self.audio_files = []

        for bird_folder in self.bird_folders:
            bird_path = os.path.join(root_dir, bird_folder)
            # os.listdir() order is arbitrary; sorting keeps item indices (and
            # therefore any seeded split of this dataset) stable across runs.
            self.audio_files.extend(
                os.path.join(bird_path, file)
                for file in sorted(os.listdir(bird_path))
                if file.endswith('.ogg')
            )

    def get_spec(self, audio_path):
        """Load ``audio_path`` and return its dB-scaled mel spectrogram.

        The file is loaded as mono at its native sample rate, then padded with
        trailing zeros or truncated to ``self.num_samples`` samples. The dB
        conversion is relative to the spectrogram's own maximum.
        """
        waveform, sample_rate = librosa.load(audio_path, sr=None, mono=True)

        if len(waveform) < self.num_samples:
            pad_amount = self.num_samples - len(waveform)
            waveform = np.pad(waveform, (0, pad_amount))
        else:
            waveform = waveform[:self.num_samples]

        mel_spec = librosa.feature.melspectrogram(y=waveform, sr=sample_rate)
        mel_spec = librosa.power_to_db(mel_spec, ref=np.max)

        return mel_spec

    def normalize(self, x):
        """Linearly map ``[min_db, max_db]`` onto ``[0, 1]`` (no clipping)."""
        return (x - self.min_db) / (self.max_db - self.min_db)

    def denormalize(self, x):
        """Rescale each element along the first axis to ``[min_db, max_db]``.

        This is a per-sample min-max stretch, not the algebraic inverse of
        ``normalize``: for every index along axis 0 the smallest value maps to
        ``min_db`` and the largest to ``max_db``.

        Args:
            x: ``numpy`` array or ``torch.Tensor``; tensors are detached and
                moved to CPU first.

        Returns:
            ``numpy`` array with the same shape as ``x``. A sample whose values
            are all identical maps to ``min_db`` everywhere (instead of the
            NaNs a 0/0 division would produce).
        """
        if isinstance(x, torch.Tensor):
            x = x.cpu().detach().numpy()

        flattened_array = x.reshape((x.shape[0], -1))

        min_batch_values = flattened_array.min(axis=-1, keepdims=True)
        max_batch_values = flattened_array.max(axis=-1, keepdims=True)

        value_range = max_batch_values - min_batch_values
        # Guard against constant samples (range == 0) to avoid 0/0 -> NaN.
        safe_range = np.where(value_range > 0, value_range, 1)

        scaled = (flattened_array - min_batch_values) / safe_range
        normalized_array = self.min_db + scaled * (self.max_db - self.min_db)

        return normalized_array.reshape(x.shape)

    def cache_all(self):
        """Access every item once so that its spectrogram gets written to disk.

        Files are only written when the dataset was created with caching
        enabled (``self.cache_dir`` is not ``None``); the ``self.cache`` flag
        set here is not read anywhere in this class.
        """
        self.cache = True
        for idx in range(len(self)):
            self[idx]

    def __len__(self):
        """Number of ``.ogg`` files that were discovered."""
        return len(self.audio_files)

    def __getitem__(self, idx):
        """Return ``(mel_spec, bird, filename)`` for the recording at ``idx``.

        The spectrogram is read from the cache when a cached file exists;
        otherwise it is computed, normalized, given a leading axis of size 1
        and (if caching is enabled) saved. ``transform`` is applied last and
        its result is never cached.

        Raises:
            IOError: If an existing cache file cannot be read.
        """
        audio_path = self.audio_files[idx]

        cache_path = None
        if self.cache_dir is not None:
            # num_samples is part of the name so different clip lengths do not
            # share cache entries.
            cache_filename = f"{os.path.basename(audio_path)}_{self.num_samples}.npy"
            cache_path = os.path.join(self.cache_dir, cache_filename)

        if cache_path is not None and os.path.isfile(cache_path):
            try:
                mel_spec = np.load(cache_path)
            except Exception as e:
                raise IOError(f"Failed to read file: {cache_path}. Error: {e}") from e
        else:
            mel_spec = self.get_spec(audio_path)

            # Normalize mel_spec
            mel_spec = self.normalize(mel_spec)
            mel_spec = np.expand_dims(mel_spec, axis=0)

            # Save mel spectrogram to cache
            if cache_path is not None:
                np.save(cache_path, mel_spec)

        if self.transform is not None:
            mel_spec = self.transform(mel_spec)

        folder, filename = os.path.split(audio_path)
        _, bird = os.path.split(folder)

        return mel_spec, bird, filename

In [3]:
class BirdClefDataModule(pl.LightningDataModule):
    """Lightning data module that splits a ``BirdClefDataset`` into train/validation.

    Args:
        root_dir: Passed to ``BirdClefDataset`` as its root directory.
        batch_size: Batch size of the train and validation loaders.
        validation_split: Fraction of the dataset assigned to validation.
        num_workers: Worker processes used by every ``DataLoader``.
        bird_name: Optional single bird folder to restrict the dataset to.
        transform: Optional transform forwarded to the dataset.
        num_samples: Waveform length forwarded to the dataset.
        min_db: Lower dB bound forwarded to the dataset.
        max_db: Upper dB bound forwarded to the dataset.
        cache: Whether the dataset caches spectrograms on disk.
        seed: Seed of the generator that drives the random split.
    """

    def __init__(self,
                 root_dir,
                 batch_size=64,
                 validation_split=0.2,
                 num_workers=10,
                 bird_name=None,
                 transform=None,
                 num_samples=65_500,
                 min_db=-80,
                 max_db=0,
                 cache=True,
                 seed=0
                ):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.validation_split = validation_split
        self.bird_name = bird_name
        self.transform = transform
        self.num_samples = num_samples
        self.min_db = min_db
        self.max_db = max_db
        self.cache = cache
        self.seed = seed

        self.save_hyperparameters()

    def setup(self, stage=None):
        """Build the dataset and, for fit/validate (or no stage), the two loaders.

        Also exposes the dataset's ``normalize``/``denormalize`` methods on the
        module so callers can reuse the same dB range.
        """
        # Forward every dataset option; previously only root_dir and bird_name
        # were passed, so transform/num_samples/min_db/max_db/cache given to
        # this module were silently ignored.
        dataset = BirdClefDataset(
            self.root_dir,
            bird_name=self.bird_name,
            transform=self.transform,
            num_samples=self.num_samples,
            min_db=self.min_db,
            max_db=self.max_db,
            cache=self.cache,
        )
        self.normalize = dataset.normalize
        self.denormalize = dataset.denormalize

        # "validate" needs the validation loader as well, otherwise
        # val_dataloader() fails with AttributeError for that stage.
        if stage in (None, 'fit', 'validate'):
            # Fractional lengths + a seeded generator give a repeatable split.
            train_dataset, validation_dataset = torch.utils.data.random_split(
                dataset,
                (1 - self.validation_split, self.validation_split),
                generator=torch.Generator().manual_seed(self.seed),
            )
            self.train_loader = DataLoader(
                train_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=self.num_workers,
            )
            self.validation_loader = DataLoader(
                validation_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                num_workers=self.num_workers,
            )

    def train_dataloader(self):
        """Return the shuffled training loader created by ``setup``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the unshuffled validation loader created by ``setup``."""
        return self.validation_loader

    def illustration_dataloader(self, batch_size):
        """Return an unshuffled loader over the validation subset with a custom batch size."""
        illustrative_dataset = self.validation_loader.dataset
        illustrative_loader = DataLoader(
            illustrative_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
        return illustrative_loader

In [4]:
# Directory holding one sub-folder of .ogg recordings per bird.
root_directory = 'train_audio'

# Build the data module and create the seeded train/validation split right away.
data_module = BirdClefDataModule(root_dir=root_directory, batch_size=128)
data_module.setup()

In [None]:
# Turn every validation spectrogram back into a 16 kHz wav file so that an
# audio embedding model can be run on it afterwards.
val_dataset = data_module.validation_loader.dataset
ds = BirdClefDataset(root_dir='train_audio')
sample_rate = 32000  # assumes the source recordings are 32 kHz -- TODO confirm
validation_audio_dir = "validation_audio"
os.makedirs(validation_audio_dir, exist_ok=True)  # sf.write does not create directories

for mel_spec, bird, filename in tqdm(val_dataset):
    filename = filename[:-3] + "wav"
    out_path = f"{validation_audio_dir}/{filename}"
    if os.path.isfile(out_path):
        # Already exported by an earlier run.
        continue
    mel_db = ds.denormalize(mel_spec)
    # denormalize() yields dB values, while mel_to_audio expects a power mel
    # spectrogram computed at the analysis sample rate.
    mel_power = librosa.db_to_power(mel_db)
    audio = librosa.feature.inverse.mel_to_audio(mel_power, sr=sample_rate)
    # Sample rates must be passed by keyword in current librosa releases.
    audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
    sf.write(out_path, audio.T, 16000, format='wav')

  0%|          | 0/3388 [00:00<?, ?it/s]

  audio = librosa.resample(audio, sample_rate, 16000)


In [None]:
# Compute the VGGish embeddings of the exported validation audio once and
# reuse the saved array afterwards.
if os.path.isfile("validation_embeddings.npy"):
    # Load the cached array so `validation_embeddings` exists on a fresh kernel too.
    validation_embeddings = np.load("validation_embeddings.npy")
else:
    vggish_extractor = VGGish(checkpoint_path="vggish_model.ckpt", pca_params_path="vggish_pca_params.npz")
    try:
        validation_embeddings = vggish_extractor.get_embeddings("validation_audio")
        with open('validation_embeddings.npy', 'wb') as f:
            np.save(f, validation_embeddings)
    finally:
        # Release the extractor even if embedding extraction fails.
        vggish_extractor.cleanup()

In [None]:
# Turn every training spectrogram back into a 16 kHz wav file.
# NOTE: earlier versions of this cell looped over the validation subset; if
# "train_audio_processed" was filled by such a run, empty it before re-running,
# because files that already exist are skipped.
train_dataset = data_module.train_loader.dataset
ds = BirdClefDataset(root_dir='train_audio')
sample_rate = 32000  # assumes the source recordings are 32 kHz -- TODO confirm
train_audio_dir = "train_audio_processed"
os.makedirs(train_audio_dir, exist_ok=True)  # sf.write does not create directories

for mel_spec, bird, filename in tqdm(train_dataset):
    filename = filename[:-3] + "wav"
    out_path = f"{train_audio_dir}/{filename}"
    if os.path.isfile(out_path):
        # Already exported by an earlier run.
        continue
    mel_db = ds.denormalize(mel_spec)
    # denormalize() yields dB values, while mel_to_audio expects a power mel
    # spectrogram computed at the analysis sample rate.
    mel_power = librosa.db_to_power(mel_db)
    audio = librosa.feature.inverse.mel_to_audio(mel_power, sr=sample_rate)
    # Sample rates must be passed by keyword in current librosa releases.
    audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
    sf.write(out_path, audio.T, 16000, format='wav')

In [None]:
# Compute the VGGish embeddings of the exported training audio once and reuse
# the saved array afterwards.
# NOTE: earlier versions of this cell wrote the *validation* embeddings into
# train_embeddings.npy; delete a file produced that way before re-running.
if os.path.isfile("train_embeddings.npy"):
    # Load the cached array so `train_embeddings` exists on a fresh kernel too.
    train_embeddings = np.load("train_embeddings.npy")
else:
    vggish_extractor = VGGish(checkpoint_path="vggish_model.ckpt", pca_params_path="vggish_pca_params.npz")
    try:
        train_embeddings = vggish_extractor.get_embeddings("train_audio_processed")
        with open('train_embeddings.npy', 'wb') as f:
            np.save(f, train_embeddings)
    finally:
        # Release the extractor even if embedding extraction fails.
        vggish_extractor.cleanup()

In [None]:
from aquatk.metrics.frechet_distance import frechet_audio_distance

In [None]:
# Frechet audio distance between the train and validation embedding sets; each
# set is flattened to rows of 128 values first.
frechet_audio_distance(
    train_embeddings.reshape((-1, 128)),
    validation_embeddings.reshape((-1, 128)),
)