In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import torch
import torchaudio
import plotly.express as px
from IPython.display import Audio
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
# The condition relies only on torch.cuda.is_available(); it must not reference
# names (such as a GPU-count variable) that are not defined at this point.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Loading an audio file  
Load an audio file from the UrbanSound8K dataset, or provide a short `.wav` file of your own.

In [None]:
# Read the clip from disk; the loader hands back the waveform and its sample rate
audio_sample_path = Path('../data/audio/fold1/106905-8-0-1.wav')
signal, sr = torchaudio.load(str(audio_sample_path))
# Collapse to a single channel by averaging over axis 0
signal = signal.mean(dim=0)
# Resample the clip to 22050 Hz and record the new rate
resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=22050)
signal = resampler(signal)
sr = 22050
# Last expression of the cell: render an inline audio player
Audio(signal, rate=sr)

### Data Augmentation
With librosa, we use two functions to slightly modify the sample:  
- Pitch shifting 
- Time stretching

In [None]:
import librosa
# Random pitch offset in whole semitones, one of -2, -1, 0, 1, 2
# (the upper bound passed to randint is exclusive)
pitch_shift = np.random.randint(low=-2, high=3)
# Random speed factor between 0.9 and 1.2: > 1 speeds the clip up, < 1 slows it down
time_stretch = 0.9 + (1.2 - 0.9) * np.random.random()
print(pitch_shift, time_stretch)

In [None]:
# Apply the pitch shift first, then the time stretch. librosa operates on NumPy
# arrays, so the tensor is converted on the way in and wrapped again on the way out.
shifted = librosa.effects.pitch_shift(signal.numpy(), sr=sr, n_steps=pitch_shift)
stretched = librosa.effects.time_stretch(shifted, rate=time_stretch)
signal_augmented = torch.tensor(stretched)
# Trim or zero-pad the augmented clip so that it has exactly as many samples
# as the original signal
delta_ln = len(signal) - len(signal_augmented)
if delta_ln >= 0:
    # augmented clip is not longer than the original: append trailing zeros
    signal_augmented = torch.hstack([signal_augmented, torch.zeros(delta_ln)])
else:
    # augmented clip is longer: keep only the leading samples
    signal_augmented = signal_augmented[:len(signal)]
# Last expression of the cell: render an inline audio player
Audio(signal_augmented, rate=sr)

### Mel Spectrogram 

In [None]:
from torchaudio.transforms import MelSpectrogram
# Spectrogram settings: FFT size, hop between successive frames, number of Mel bands
N_FFT = 1024
HOP_LENGTH = 256
N_MELS = 128
# Build the transform once; it is applied to both the original and the augmented signal.
# center=False is passed through as-is (see the torchaudio MelSpectrogram docs for its
# effect on frame padding).
transform = MelSpectrogram(sample_rate=sr, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS, center=False)

In [None]:
# Mel spectrograms of the original and of the augmented signal
mel_spec = transform(signal)
mel_spec_augmented = transform(signal_augmented)
# Scale each spectrogram by its own maximum so that the largest value becomes 1
mel_spec = mel_spec / mel_spec.max()
mel_spec_augmented = mel_spec_augmented / mel_spec_augmented.max()

In [None]:
# Visualize with plotly: stack the two spectrograms along a third axis and use
# that axis as the animation frame (frame 0: original, frame 1: augmented)
mel_specs = torch.dstack((mel_spec, mel_spec_augmented))
fig = px.imshow(mel_specs, aspect="auto", animation_frame=2)
fig