In [None]:
import random
import os
import numpy as np
import sklearn
import torch
from torch.cuda import manual_seed_all
from torch.backends import cudnn
import matplotlib as mpl
from matplotlib import pyplot as plt
import torchaudio
import torchaudio.transforms as T
from torchdata.datapipes.iter import FileLister, FileOpener

PROJECT_ROOT_DIR = "."
CHAPTER_ID = "data"
AUDIO_PATH = os.path.join(PROJECT_ROOT_DIR, "audio", CHAPTER_ID)
os.makedirs(AUDIO_PATH, exist_ok=True)

# Seed every RNG the augmentation pipeline draws from so that
# Restart Kernel -> Run All reproduces the same augmented samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
manual_seed_all(SEED)  # silently ignored when CUDA is unavailable
# Trade a little GPU speed for deterministic cuDNN kernels.
cudnn.deterministic = True
cudnn.benchmark = False

In [None]:
# Pre-spectrogram (waveform-domain) augmentations.
# These are examples; tune or replace them based on domain knowledge.

# TimeStretch is a phase vocoder: it operates on a *complex STFT*, not on raw
# samples, so the waveform is round-tripped through torch.stft / torch.istft.
# n_freq and hop_length must match the STFT that produced the input.
STRETCH_N_FFT = 1024
STRETCH_HOP_LENGTH = STRETCH_N_FFT // 4
time_stretch = T.TimeStretch(hop_length=STRETCH_HOP_LENGTH, n_freq=STRETCH_N_FFT // 2 + 1)

def stretch_waveform(waveform, rate=1.2):
    """Change the duration of ``waveform`` without changing its pitch.

    Args:
        waveform: real tensor of shape ``(time,)`` or ``(channels, time)``
            (the shapes ``torch.stft`` accepts).
        rate: speed factor; ``rate > 1.0`` speeds up (shorter output),
            ``rate < 1.0`` slows down (longer output).

    Returns:
        Real tensor with the same leading dimensions and roughly
        ``waveform.size(-1) / rate`` samples.
    """
    window = torch.hann_window(STRETCH_N_FFT, device=waveform.device, dtype=waveform.dtype)
    spec = torch.stft(
        waveform,
        n_fft=STRETCH_N_FFT,
        hop_length=STRETCH_HOP_LENGTH,
        window=window,
        return_complex=True,
    )
    stretched = time_stretch(spec, rate)
    return torch.istft(
        stretched,
        n_fft=STRETCH_N_FFT,
        hop_length=STRETCH_HOP_LENGTH,
        window=window,
    )

PITCH_SHIFT_SAMPLE_RATE = 44100
PITCH_SHIFT_N_STEPS = 2
# Pre-built transform for the common case, reused across calls.
pitch_shift = T.PitchShift(sample_rate=PITCH_SHIFT_SAMPLE_RATE, n_steps=PITCH_SHIFT_N_STEPS)  # Shift up by 2 semitones

def shift_pitch(waveform, sample_rate, n_steps=2):
    """Shift the pitch of ``waveform`` by ``n_steps`` semitones.

    Args:
        waveform: audio tensor whose last dimension is time.
        sample_rate: the true sample rate of ``waveform`` in Hz. It is honoured:
            the cached ``pitch_shift`` transform is used only when both the rate
            and ``n_steps`` match its configuration. Otherwise
            ``torchaudio.functional.pitch_shift`` is called with these values.
        n_steps: semitones to shift (positive = up).

    Returns:
        The pitch-shifted waveform.
    """
    if sample_rate == PITCH_SHIFT_SAMPLE_RATE and n_steps == PITCH_SHIFT_N_STEPS:
        return pitch_shift(waveform)
    return torchaudio.functional.pitch_shift(waveform, sample_rate, n_steps)

def scale_volume(waveform, factor=1.5):
    """Apply a constant gain to ``waveform``.

    ``factor > 1`` makes the signal louder and ``factor < 1`` quieter; the
    result is a new tensor.
    """
    scaled = waveform * factor
    return scaled

def crop_waveform(waveform, crop_size):
    """Return a random window of exactly ``crop_size`` samples along the last axis.

    Every start offset in ``[0, length - crop_size]`` is equally likely.
    Inputs shorter than ``crop_size`` are right-padded with zeros, so the
    output length is always ``crop_size`` (which keeps batches rectangular).

    Args:
        waveform: tensor whose last dimension is time (any number of leading dims).
        crop_size: number of samples to keep.
    """
    length = waveform.size(-1)
    if length < crop_size:
        return torch.nn.functional.pad(waveform, (0, crop_size - length))
    # randint's upper bound is exclusive; +1 so the final window can be chosen.
    start = torch.randint(0, length - crop_size + 1, (1,)).item()
    return waveform[..., start:start + crop_size]

def apply_reverb(waveform, sample_rate=44100, reverb_time=0.3, wet_level=0.3):
    """Add synthetic reverberation to ``waveform``.

    ``torchaudio.transforms`` provides no reverb transform, so a room impulse
    response is synthesised here. It is Gaussian noise under an exponential
    envelope that has decayed by 60 dB after ``reverb_time`` seconds, and it
    is convolved with the signal via FFT.

    Args:
        waveform: real tensor whose last dimension is time.
        sample_rate: sample rate of ``waveform`` in Hz.
        reverb_time: decay time (RT60) of the synthetic impulse response, in seconds.
        wet_level: share of reverberant signal in the output, expected in [0, 1].
            A value of 0 returns an unmodified copy of the input.

    Returns:
        Tensor with the same shape and dtype as ``waveform``.
    """
    import math

    n_samples = waveform.size(-1)
    if reverb_time <= 0 or wet_level <= 0 or n_samples == 0:
        return waveform.clone()

    ir_len = max(1, int(round(reverb_time * sample_rate)))
    t = torch.arange(ir_len, dtype=waveform.dtype, device=waveform.device) / sample_rate
    # An amplitude factor of 1/1000 is -60 dB, reached at t == reverb_time.
    envelope = torch.exp(-math.log(1000.0) * t / reverb_time)
    impulse_response = torch.randn(ir_len, dtype=waveform.dtype, device=waveform.device) * envelope
    impulse_response = impulse_response / impulse_response.norm().clamp_min(1e-12)

    # Linear (not circular) convolution: the FFT length must cover len(x) + len(h) - 1.
    n_fft = n_samples + ir_len - 1
    wet = torch.fft.irfft(
        torch.fft.rfft(waveform, n=n_fft) * torch.fft.rfft(impulse_response, n=n_fft),
        n=n_fft,
    )[..., :n_samples]
    return (1.0 - wet_level) * waveform + wet_level * wet

def time_shift(waveform, shift):
    """Circularly rotate ``waveform`` by ``shift`` samples along its last axis.

    A positive ``shift`` moves samples later in time. Samples pushed past the
    end wrap around to the beginning.
    """
    return waveform.roll(shift, -1)

def add_noise(waveform, noise_level=0.005):
    """Return ``waveform`` plus zero-mean Gaussian noise of standard deviation ``noise_level``."""
    return waveform + noise_level * torch.randn_like(waveform)

# Augment on the fly, stochastically.
# Again, these are just examples and do not necessarily use the helpers above.
def augment_waveform(data, p=0.5):
    """Randomly augment one ``(waveform, sample_rate)`` pair in the waveform domain.

    Each augmentation below is applied independently with probability ``p``:
      * additive Gaussian noise (std 0.005),
      * a circular time shift of up to +/-5000 samples,
      * a random gain drawn uniformly from [0.8, 1.5).

    The input tensor is never modified. Every step builds a new tensor, so
    cached or shared waveforms are not corrupted.

    Returns:
        ``(augmented_waveform, sample_rate)``, with ``sample_rate`` passed through unchanged.
    """
    waveform, sample_rate = data
    if torch.rand(1).item() < p:
        waveform = waveform + torch.randn_like(waveform) * 0.005
    if torch.rand(1).item() < p:
        # randint's upper bound is exclusive; 5001 makes the range symmetric.
        shift = torch.randint(-5000, 5001, (1,)).item()
        waveform = torch.roll(waveform, shifts=shift, dims=-1)
    if torch.rand(1).item() < p:
        gain = torch.empty(1).uniform_(0.8, 1.5).item()
        waveform = waveform * gain
    return waveform, sample_rate


In [None]:
# Create a MelSpectrogram transformation
SPECTROGRAM_SAMPLE_RATE = 44100  # rate the mel filterbank below is built for

mel_spectrogram_transform = T.MelSpectrogram(
    sample_rate=SPECTROGRAM_SAMPLE_RATE,
    n_fft=1024,                # FFT size (yields n_fft // 2 + 1 frequency bins)
    hop_length=512,            # Hop length between windows
    n_mels=64                  # Number of Mel bands
)

def waveform_to_spectrogram(data):
    """Convert one ``(waveform, sample_rate)`` pair into a mel spectrogram.

    The mel filterbank assumes ``SPECTROGRAM_SAMPLE_RATE``. Clips recorded at any
    other rate are resampled first; otherwise their frequencies would be mapped
    onto the wrong mel bands.

    Returns:
        The mel spectrogram tensor; the sample rate is dropped.
    """
    waveform, sample_rate = data
    if sample_rate != SPECTROGRAM_SAMPLE_RATE:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=SPECTROGRAM_SAMPLE_RATE
        )
    return mel_spectrogram_transform(waveform)

In [None]:
# Post-spectrogram (spectrogram-domain) augmentations.

# Example augmentations, could add more
TIME_MASK_PARAM = 10  # upper bound on the time-mask width (see torchaudio TimeMasking docs)
time_mask = T.TimeMasking(time_mask_param=TIME_MASK_PARAM)

FREQ_MASK_PARAM = 8   # upper bound on the frequency-mask width (see torchaudio FrequencyMasking docs)
freq_mask = T.FrequencyMasking(freq_mask_param=FREQ_MASK_PARAM)

# hybridizes two sounds
def mixup(spectrogram1, spectrogram2, alpha=0.2, return_lambda=False):
    """Blend two spectrograms with a mixup coefficient.

    As in mixup (Zhang et al., 2018), the weight ``lam`` is drawn from
    ``Beta(alpha, alpha)`` and the result is
    ``lam * spectrogram1 + (1 - lam) * spectrogram2``.

    Args:
        spectrogram1, spectrogram2: tensors with broadcast-compatible shapes
            (e.g. both cropped/padded to the same number of frames).
        alpha: Beta concentration; small values keep ``lam`` near 0 or 1.
            ``alpha <= 0`` disables mixing (``lam = 0``, returns ``spectrogram2``).
        return_lambda: if True, also return ``lam`` so the caller can mix the
            labels with the same weight.

    Returns:
        The mixed tensor, or ``(mixed, lam)`` when ``return_lambda`` is True.
    """
    if alpha > 0:
        lam = torch.distributions.Beta(alpha, alpha).sample().item()
    else:
        lam = 0.0
    mixed = lam * spectrogram1 + (1 - lam) * spectrogram2
    return (mixed, lam) if return_lambda else mixed

# Randomised masking: each mask can be applied with its own probability.
def augment_spectrogram(spectrogram, p_time=1.0, p_freq=1.0):
    """Apply time masking and then frequency masking to a spectrogram.

    ``time_mask`` is applied with probability ``p_time`` and ``freq_mask`` with
    probability ``p_freq``, independently. The defaults (1.0) always apply both
    masks and draw no extra random numbers. Use e.g. ``p_time=p_freq=0.5`` for
    the stochastic treatment the waveform augmentations use.

    Returns:
        The (possibly) masked spectrogram.
    """
    augmented = spectrogram
    if p_time >= 1.0 or torch.rand(1).item() < p_time:
        augmented = time_mask(augmented)  # Apply time masking
    if p_freq >= 1.0 or torch.rand(1).item() < p_freq:
        augmented = freq_mask(augmented)  # Apply frequency masking
    return augmented

In [None]:
# Decode audio files
def decode_audio(file_tuple):
    """Decode one ``(path, stream)`` item produced by ``FileOpener``.

    The audio is loaded by path with ``torchaudio.load``. That does not need
    the stream ``FileOpener`` already opened, so the stream is closed here
    (even if decoding fails) instead of being leaked.

    Returns:
        ``(waveform, sample_rate)`` as returned by ``torchaudio.load``.
    """
    file_path, stream = file_tuple
    try:
        waveform, sample_rate = torchaudio.load(file_path)
    finally:
        stream.close()
    return waveform, sample_rate

In [None]:
# List audio files
# UrbanSound8K ships its clips in fold1/ ... fold10/ sub-directories, so the
# listing must recurse.
# NOTE(review): torchdata DataPipes are deprecated in recent torchdata
# releases; consider migrating to torch.utils.data.Dataset/DataLoader.
URBANSOUND_AUDIO_DIR = "./UrbanSound8K/audio"
BATCH_SIZE = 32
N_PREVIEW_BATCHES = 1  # batches to print; set to None to walk the whole dataset

# List audio files (FileLister's directory parameter is `root`)
file_list_dp = FileLister(root=URBANSOUND_AUDIO_DIR, masks="*.wav", recursive=True)

# Open audio files lazily
file_opener_dp = FileOpener(file_list_dp, mode="rb")

# Decode audio
audio_dp = file_opener_dp.map(decode_audio)

# Augment raw waveform (pre-spectrogram)
augmented_waveform = audio_dp.map(augment_waveform)

# Transform waveforms to spectrograms
audio_spectrogram_dp = augmented_waveform.map(waveform_to_spectrogram)

# Apply data augmentation (post-spectrogram)
augmented_dp = audio_spectrogram_dp.map(augment_spectrogram)

# Batch and process chunks of data
chunked_dp = augmented_dp.batch(batch_size=BATCH_SIZE)

# Preview only the first few batches instead of printing the whole dataset.
for batch_idx, batch in enumerate(chunked_dp):
    if N_PREVIEW_BATCHES is not None and batch_idx >= N_PREVIEW_BATCHES:
        break
    for spectrogram in batch:
        print(spectrogram.size())  # Process each spectrogram