In [None]:
class AudioDataset(Dataset):
    """Map-style dataset of audio clips laid out as ``root_dir/<class_name>/<file>``.

    Each immediate sub-directory of ``root_dir`` is one class. Its integer label
    is the directory's position in the sorted list ``self.classes``. Entries of
    ``root_dir`` that are not directories, and entries of a class directory
    that are not files, are ignored.

    Args:
        root_dir: Directory containing one sub-directory per class.
        data_transform: Optional callable applied to the NumPy waveform in
            ``__getitem__``; when given, its return value is what is yielded.
        sample_duration: Maximum clip length in seconds. Longer clips are
            truncated; shorter clips are returned unpadded.
        target_sample_rate: Sample rate (Hz) every clip is resampled to.
            Defaults to 16000.
    """

    def __init__(self, root_dir, data_transform=None, sample_duration=5,
                 target_sample_rate=16000):
        self.root_dir = root_dir
        self.data_transform = data_transform
        self.sample_duration = sample_duration
        self.target_sample_rate = target_sample_rate

        # Only directories are classes: a stray file (e.g. .DS_Store) would
        # otherwise make os.listdir(class_path) raise NotADirectoryError.
        self.classes = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )

        self.data = []    # file paths, one per sample
        self.labels = []  # class index for each entry of self.data (parallel list)
        for label, class_name in enumerate(self.classes):
            class_path = os.path.join(root_dir, class_name)
            # os.listdir order is arbitrary; sorting keeps the index -> file
            # mapping stable across runs and platforms.
            for filename in sorted(os.listdir(class_path)):
                filepath = os.path.join(class_path, filename)
                if os.path.isfile(filepath):
                    self.data.append(filepath)
                    self.labels.append(label)

    def __len__(self):
        """Return the number of audio files found under ``root_dir``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, resample and truncate clip ``idx``.

        Returns:
            ``(waveform, label)``. ``waveform`` is the NumPy array of the 2-D
            ``(channels, samples)`` tensor from ``torchaudio.load``, or whatever
            ``data_transform`` returns for it. ``label`` is a 0-d
            ``torch.long`` tensor holding the class index.
        """
        waveform, sample_rate = torchaudio.load(self.data[idx], normalize=True)

        # orig_freq must be the file's sample rate as reported by the loader,
        # not the number of samples in the clip.
        if sample_rate != self.target_sample_rate:
            resample_transform = T.Resample(
                orig_freq=sample_rate, new_freq=self.target_sample_rate
            )
            waveform = resample_transform(waveform)

        # Keep at most `sample_duration` seconds. int() keeps the slice bound
        # valid for fractional durations; slicing past the end is a no-op.
        max_samples = int(self.target_sample_rate * self.sample_duration)
        waveform = waveform[:, :max_samples]

        waveform_np = waveform.numpy()

        if self.data_transform is not None:
            waveform_np = self.data_transform(waveform_np)

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return waveform_np, label
