In [25]:
import os

from torch.utils.data import Dataset
import pandas as pd
import torchaudio

In [28]:
class UrbanSoundDataset(Dataset):
    """Map-style dataset yielding fixed-length, transformed audio clips and labels.

    Every item goes through the same pipeline: load the file with
    ``torchaudio.load``, resample to ``target_sample_rate``, mix down to one
    channel, truncate or right-pad with zeros to exactly ``num_samples``
    samples, and finally apply ``transformation``.

    Args:
        annotations_file: Path of the CSV metadata file (read with
            ``pd.read_csv``). Columns are accessed by position: column 0 is
            used as the audio file name, column 5 as the fold number and
            column 6 as the label.
        audio_dir: Root directory that contains the ``fold<N>`` sub-directories.
        transformation: Callable applied to the prepared signal
            (e.g. a ``torchaudio.transforms.MelSpectrogram`` instance).
        target_sample_rate: Sample rate every clip is resampled to.
        num_samples: Exact number of samples (last dimension) each clip has
            before ``transformation`` is applied.
    """

    # Constructor
    def __init__(self, annotations_file, audio_dir, transformation, target_sample_rate, num_samples):
        self.annotations = pd.read_csv(annotations_file)
        self.audio_dir = audio_dir
        self.transformation = transformation
        self.target_sample_rate = target_sample_rate
        self.num_samples = num_samples

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.annotations)

    # (under hood using get item) ex: a_list[1] -> a_list.__getitem__(1)
    def __getitem__(self, index):
        """Load, normalise and transform the clip at row ``index``.

        Returns:
            A ``(signal, label)`` tuple where ``signal`` is the output of
            ``transformation`` and ``label`` is the value stored in column 6
            of the annotations row.
        """
        audio_sample_path = self._get_audio_sample_path(index)
        label = self._get_audio_sample_label(index)
        # Load in the audio file as signal:
        signal, sr = torchaudio.load(audio_sample_path)
        # Resample the signal if needed:
        signal = self._resample_if_necessary(signal, sr)
        # Transform the signal to mono for Spectrogram if needed:
        signal = self._mono_if_necessary(signal)
        # Truncate clips that are longer than num_samples:
        signal = self._cut_if_necessary(signal)
        # Zero-pad clips that are shorter than num_samples:
        signal = self._add_padding_if_necessary(signal)
        # Pass the signal to the transformation (MelSpectrogram):
        signal = self.transformation(signal)
        return signal, label

    def _cut_if_necessary(self, signal):
        """Keep only the first ``num_samples`` samples along the last dimension.

        Signals that are already short enough are returned unchanged.
        """
        if signal.shape[-1] > self.num_samples:
            signal = signal[..., :self.num_samples]
        return signal

    def _add_padding_if_necessary(self, signal):
        """Right-pad the last dimension with zeros up to ``num_samples``.

        Signals that are already long enough are returned unchanged.
        """
        length = signal.shape[-1]
        if length < self.num_samples:
            # new_zeros inherits dtype and device from `signal`, so the padded
            # tensor stays compatible with whatever was loaded.
            padded = signal.new_zeros((*signal.shape[:-1], self.num_samples))
            padded[..., :length] = signal
            signal = padded
        return signal

    def _resample_if_necessary(self, signal, sr):
        """Resample ``signal`` from ``sr`` to ``target_sample_rate`` when they differ."""
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            signal = resampler(signal)
        return signal

    def _mono_if_necessary(self, signal):
        """Average all channels (dimension 0) into one when there is more than one."""
        if signal.shape[0] > 1:
            signal = signal.mean(dim=0, keepdim=True)
        return signal

    # indexes below refer to the columns in the annotations file (.csv)
    def _get_audio_sample_path(self, index):
        """Build ``<audio_dir>/fold<column 5>/<column 0>`` for row ``index``."""
        fold = f"fold{self.annotations.iloc[index, 5]}"
        path = os.path.join(self.audio_dir, fold, self.annotations.iloc[index, 0])
        return path

    def _get_audio_sample_label(self, index):
        """Return the label stored in column 6 of row ``index``."""
        return self.annotations.iloc[index, 6]


In [30]:
if __name__ == "__main__":
    # Audio settings: the rate every clip is resampled to, and the number of
    # samples requested per clip (22050 samples at 22050 Hz is one second).
    SAMPLE_RATE = 22050
    NUM_SAMPLES = 22050

    # Location of the metadata table and of the fold<N> audio directories.
    ANNOTATIONS_FILE = "UrbanSound8K/metadata/UrbanSound8K.csv"
    AUDIO_DIR = "UrbanSound8K/audio/"

    # Mel spectrogram transform handed to the dataset as its transformation.
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=SAMPLE_RATE,
        n_fft=1024,
        hop_length=512,
        n_mels=64,
    )

    usd = UrbanSoundDataset(
        annotations_file=ANNOTATIONS_FILE,
        audio_dir=AUDIO_DIR,
        transformation=mel_spectrogram,
        target_sample_rate=SAMPLE_RATE,
        num_samples=NUM_SAMPLES,
    )

    # Quick smoke check: report the dataset size and fetch the first item.
    print(f"there are {len(usd)} samples in the dataset")
    signal, label = usd[0]


there are 8732 samples in the dataset
