In [41]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import torchaudio
import os

In [45]:
class ShipEarDataset(Dataset):
    """Dataset of fixed-length, transformed audio clips listed in an annotations CSV.

    Each item is loaded with ``torchaudio.load``, moved to ``device``, resampled
    to ``target_get_sample_rate`` when the file's rate differs, averaged down to
    a single channel, truncated or right-padded with zeros to ``num_samples``
    samples, and finally passed through ``transformation``.

    Args:
        annotations_file: Path of the CSV read with ``pandas.read_csv``. Column 0
            is joined onto ``audio_dir`` to locate the audio file; column 1 is
            returned as the label.
        audio_dir: Directory the file names in column 0 are relative to.
        transformation: Callable module applied to the prepared signal. It is
            moved to ``device`` with ``.to(device)``.
        target_get_sample_rate: Sample rate every clip is resampled to.
        num_samples: Number of samples per channel in the prepared signal.
        device: Device on which the signal, the resamplers and
            ``transformation`` are placed.
    """

    def __init__(self, annotations_file, audio_dir, transformation, target_get_sample_rate, num_samples, device):
        self.annotations = pd.read_csv(annotations_file)
        self.audio_dir = audio_dir
        self.device = device
        self.transformation = transformation.to(self.device)
        self.target_get_sample_rate = target_get_sample_rate
        self.num_samples = num_samples
        # One Resample transform per (source rate, target rate) pair, so the
        # transform is built once instead of on every __getitem__ call.
        self._resamplers = {}

    def __len__(self):
        """Return the number of rows in the annotations table."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Load, normalise and transform the clip at row ``index``.

        Returns:
            A ``(signal, label)`` pair: ``signal`` is the output of
            ``transformation`` applied to a ``(1, num_samples)`` tensor, and
            ``label`` is the value in column 1 of the annotations row.
        """
        audio_sample_path = self._get_audio_sample_path(index)
        signal, sr = torchaudio.load(audio_sample_path)
        signal = signal.to(self.device)
        # signal -> (num_channels, samples)
        label = self._get_audio_sample_label(index)
        signal = self._resample_if_necessary(signal, sr)
        signal = self._mix_down_if_necessary(signal)
        signal = self._cut_if_necessary(signal)
        signal = self._right_pad_if_necessary(signal)
        signal = self.transformation(signal)
        return signal, label

    def _cut_if_necessary(self, signal):
        """Keep only the first ``num_samples`` samples of a longer signal."""
        # signal -> Tensor -> (num_channels, samples)
        if signal.shape[1] > self.num_samples:
            signal = signal[:, :self.num_samples]
        return signal

    def _right_pad_if_necessary(self, signal):
        """Zero-pad the end of a shorter signal up to ``num_samples`` samples."""
        length_signal = signal.shape[1]
        if length_signal < self.num_samples:
            num_missing_samples = self.num_samples - length_signal
            # (left, right) padding for the last dimension only.
            last_dim_padding = (0, num_missing_samples)
            signal = torch.nn.functional.pad(signal, last_dim_padding)
        return signal

    def _resample_if_necessary(self, signal, sr):
        """Resample ``signal`` from ``sr`` to ``target_get_sample_rate``.

        The signal is returned untouched when the rates already match.
        Otherwise a ``Resample`` transform for this rate pair is fetched from
        the per-instance cache (and created on first use) and applied.
        """
        if sr == self.target_get_sample_rate:
            return signal
        cache_key = (sr, self.target_get_sample_rate)
        resampler = self._resamplers.get(cache_key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, self.target_get_sample_rate).to(self.device)
            self._resamplers[cache_key] = resampler
        return resampler(signal)

    def _mix_down_if_necessary(self, signal):
        """Average a multi-channel signal into one channel, keeping the channel axis."""
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

    def _get_audio_sample_path(self, index):
        """Join ``audio_dir`` with the file name stored in column 0 of row ``index``."""
        filename = self.annotations.iloc[index, 0]
        path = os.path.join(self.audio_dir, filename)
        return path

    def _get_audio_sample_label(self, index):
        """Return the value stored in column 1 of row ``index``."""
        return self.annotations.iloc[index, 1]

In [49]:
# Smoke test: build the dataset and inspect its first item.
annotations_file = "../label_process/label.csv"
# NOTE: machine-specific absolute path; point it at the local copy of the audio files.
audio_dir = r"E:\数据集\ShipEar\shipsEar_AUDIOS"

SAMPLE_RATE = 22050  # used for the mel-spectrogram and as the dataset's target rate
NUM_SAMPLES = 44100  # samples kept per clip (44100 / 22050 = 2 s at SAMPLE_RATE)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}.")

mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2048,
    hop_length=512,
    n_mels=128,
)

ship_ear_dataset = ShipEarDataset(annotations_file, audio_dir, mel_spectrogram, SAMPLE_RATE, NUM_SAMPLES, device)
# Use the built-in protocols rather than calling the dunder methods directly.
print(f"There are {len(ship_ear_dataset)} sample in the dataset.")

audio, label = ship_ear_dataset[0]
print(f"audio.shape: {audio.shape}")
print(f"label: {label}")

Using device: cuda.
There are 90 sample in the dataset.
audio.shape: torch.Size([1, 128, 87])
label: 0
