In [1]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

# import os
# os.chdir('/content/drive/MyDrive/Colab Notebooks/swishnet')

# import sys
# sys.path.append(os.getcwd())

In [3]:
# Filename suffixes inserted before ".pt" when the dataset classes build their
# file paths.
# NOTE(review): the names suggest 16 kHz audio with a fixed 32000-sample cut vs.
# a random-length cut -- confirm against the preprocessing step that wrote the files.

# Used by train_data / eval_data.
suffix = "_16K_32000cut"

# Used by train_data_random / eval_data_random.
suffix_random = "_16K_random_cut"

In [4]:
class train_data(Dataset):
    """Training split of the four-class dataset (car / music / speech / noise).

    One ``.pt`` file per class is read with ``torch.load`` and every item in it
    is paired with an integer label: car=0, music=1, speech=2, noise=3.

    Args:
        root: Directory that contains the ``car/``, ``music/``, ``noise/`` and
            ``speech/`` sub-directories. Defaults to ``'./train'``.
        file_suffix: Text placed between ``train_<class>_data`` and ``.pt`` in
            each file name. ``None`` (default) falls back to the notebook-level
            ``suffix`` variable, looked up when the dataset is constructed.

    Attributes:
        train_<class>_data_dir: Path of the file loaded for each class.
        car_data, music_data, noise_data, speech_data: Per-class lists of
            ``(item, label)`` tuples.
        data: Concatenation of the per-class lists in the order
            car, music, speech, noise.
    """

    def __init__(self, root='./train', file_suffix=None):
        super().__init__()
        if file_suffix is None:
            # Keep the original behaviour: use the notebook's global suffix.
            file_suffix = suffix

        self.train_car_data_dir = f'{root}/car/train_car_data{file_suffix}.pt'
        self.train_music_data_dir = f'{root}/music/train_music_data{file_suffix}.pt'
        self.train_noise_data_dir = f'{root}/noise/train_noise_data{file_suffix}.pt'
        self.train_speech_data_dir = f'{root}/speech/train_speech_data{file_suffix}.pt'

        # NOTE(review): torch.load unpickles the file; only load files you trust.
        self.car_data = self._load_labelled(self.train_car_data_dir, 0)
        self.music_data = self._load_labelled(self.train_music_data_dir, 1)
        self.noise_data = self._load_labelled(self.train_noise_data_dir, 3)
        self.speech_data = self._load_labelled(self.train_speech_data_dir, 2)

        self.data = self.car_data + self.music_data + self.speech_data + self.noise_data

    @staticmethod
    def _load_labelled(path, label):
        """Load ``path`` and return a new list of ``(item, label)`` tuples.

        Building a fresh list (instead of assigning into the loaded object)
        works for any iterable container, e.g. a list of tensors or a single
        stacked tensor, and guarantees the result supports list ``+``.
        """
        return [(item, label) for item in torch.load(path)]

    def __len__(self):
        """Return the total number of labelled items across all classes."""
        return len(self.data)

    def __getitem__(self, x):
        """Return the ``(item, label)`` pair stored at index ``x``."""
        item, label = self.data[x]
        return item, label


In [None]:
class eval_data(Dataset):
    """Evaluation split of the four-class dataset (car / music / speech / noise).

    One ``.pt`` file per class is read with ``torch.load`` and every item in it
    is paired with an integer label: car=0, music=1, speech=2, noise=3.

    Args:
        root: Directory that contains the ``car/``, ``music/``, ``noise/`` and
            ``speech/`` sub-directories. Defaults to ``'./eval'``.
        file_suffix: Text placed between ``eval_<class>_data`` and ``.pt`` in
            each file name. ``None`` (default) falls back to the notebook-level
            ``suffix`` variable, looked up when the dataset is constructed.

    Attributes:
        eval_<class>_data_dir: Path of the file loaded for each class.
        car_data, music_data, noise_data, speech_data: Per-class lists of
            ``(item, label)`` tuples.
        data: Concatenation of the per-class lists in the order
            car, music, speech, noise.
    """

    def __init__(self, root='./eval', file_suffix=None):
        super().__init__()
        if file_suffix is None:
            # Keep the original behaviour: use the notebook's global suffix.
            file_suffix = suffix

        self.eval_car_data_dir = f'{root}/car/eval_car_data{file_suffix}.pt'
        self.eval_music_data_dir = f'{root}/music/eval_music_data{file_suffix}.pt'
        self.eval_noise_data_dir = f'{root}/noise/eval_noise_data{file_suffix}.pt'
        self.eval_speech_data_dir = f'{root}/speech/eval_speech_data{file_suffix}.pt'

        # NOTE(review): torch.load unpickles the file; only load files you trust.
        self.car_data = self._load_labelled(self.eval_car_data_dir, 0)
        self.music_data = self._load_labelled(self.eval_music_data_dir, 1)
        self.noise_data = self._load_labelled(self.eval_noise_data_dir, 3)
        self.speech_data = self._load_labelled(self.eval_speech_data_dir, 2)

        self.data = self.car_data + self.music_data + self.speech_data + self.noise_data

    @staticmethod
    def _load_labelled(path, label):
        """Load ``path`` and return a new list of ``(item, label)`` tuples.

        Building a fresh list (instead of assigning into the loaded object)
        works for any iterable container, e.g. a list of tensors or a single
        stacked tensor, and guarantees the result supports list ``+``.
        """
        return [(item, label) for item in torch.load(path)]

    def __len__(self):
        """Return the total number of labelled items across all classes."""
        return len(self.data)

    def __getitem__(self, x):
        """Return the ``(item, label)`` pair stored at index ``x``."""
        item, label = self.data[x]
        return item, label

In [6]:
class train_data_random(Dataset):
    """Training split loaded from the random-cut files, served at a fixed length.

    One ``.pt`` file per class is read with ``torch.load`` and every waveform
    is paired with an integer label: car=0, music=1, speech=2, noise=3.
    ``__getitem__`` zero-pads or truncates each waveform so that its last
    dimension equals ``target_length``.

    Args:
        root: Directory that contains the ``car/``, ``music/``, ``noise/`` and
            ``speech/`` sub-directories. Defaults to ``'./train'``.
        file_suffix: Text placed between ``train_<class>_data`` and ``.pt`` in
            each file name. ``None`` (default) falls back to the notebook-level
            ``suffix_random`` variable, looked up at construction time.
        target_length: Number of samples every returned waveform has along its
            last dimension. Defaults to 48000 (3 seconds @ 16 kHz).

    Attributes:
        train_<class>_data_dir: Path of the file loaded for each class.
        car_data, music_data, noise_data, speech_data: Per-class lists of
            ``(waveform, label)`` tuples.
        data: Concatenation of the per-class lists in the order
            car, music, speech, noise.
        target_length: See Args.
    """

    DEFAULT_TARGET_LENGTH = 48000  # 3 seconds @ 16 kHz

    def __init__(self, root='./train', file_suffix=None,
                 target_length=DEFAULT_TARGET_LENGTH):
        super().__init__()
        if file_suffix is None:
            # Keep the original behaviour: use the notebook's global suffix.
            file_suffix = suffix_random
        self.target_length = target_length

        self.train_car_data_dir = f'{root}/car/train_car_data{file_suffix}.pt'
        self.train_music_data_dir = f'{root}/music/train_music_data{file_suffix}.pt'
        self.train_noise_data_dir = f'{root}/noise/train_noise_data{file_suffix}.pt'
        self.train_speech_data_dir = f'{root}/speech/train_speech_data{file_suffix}.pt'

        # NOTE(review): torch.load unpickles the file; only load files you trust.
        self.car_data = self._load_labelled(self.train_car_data_dir, 0)
        self.music_data = self._load_labelled(self.train_music_data_dir, 1)
        self.noise_data = self._load_labelled(self.train_noise_data_dir, 3)
        self.speech_data = self._load_labelled(self.train_speech_data_dir, 2)

        self.data = self.car_data + self.music_data + self.speech_data + self.noise_data

    @staticmethod
    def _load_labelled(path, label):
        """Load ``path`` and return a new list of ``(waveform, label)`` tuples.

        Building a fresh list (instead of assigning into the loaded object)
        works for any iterable container of waveforms and guarantees the
        result supports list ``+``.
        """
        return [(waveform, label) for waveform in torch.load(path)]

    def __len__(self):
        """Return the total number of labelled waveforms across all classes."""
        return len(self.data)

    def __getitem__(self, x):
        """Return ``(waveform, label)`` with the waveform fixed to ``target_length``.

        The length is measured, padded and truncated along the LAST dimension.
        ``F.pad(t, (0, n))`` always pads the last dimension, so checking and
        slicing the same dimension keeps 1-D ``(time,)`` behaviour unchanged
        and also handles ``(channels, time)`` tensors consistently.
        """
        waveform, label = self.data[x]
        n_samples = waveform.shape[-1]

        if n_samples < self.target_length:
            # Too short: append zeros on the right.
            waveform = F.pad(waveform, (0, self.target_length - n_samples))
        elif n_samples > self.target_length:
            # Too long: keep only the first target_length samples.
            waveform = waveform[..., :self.target_length]

        return waveform, label


In [None]:
class eval_data_random(Dataset):
    """Evaluation split loaded from the random-cut files, served at a fixed length.

    One ``.pt`` file per class is read with ``torch.load`` and every waveform
    is paired with an integer label: car=0, music=1, speech=2, noise=3.
    ``__getitem__`` zero-pads or truncates each waveform so that its last
    dimension equals ``target_length``.

    Args:
        root: Directory that contains the ``car/``, ``music/``, ``noise/`` and
            ``speech/`` sub-directories. Defaults to ``'./eval'``.
        file_suffix: Text placed between ``eval_<class>_data`` and ``.pt`` in
            each file name. ``None`` (default) falls back to the notebook-level
            ``suffix_random`` variable, looked up at construction time.
        target_length: Number of samples every returned waveform has along its
            last dimension. Defaults to 48000 (3 seconds @ 16 kHz).

    Attributes:
        eval_<class>_data_dir: Path of the file loaded for each class.
        car_data, music_data, noise_data, speech_data: Per-class lists of
            ``(waveform, label)`` tuples.
        data: Concatenation of the per-class lists in the order
            car, music, speech, noise.
        target_length: See Args.
    """

    DEFAULT_TARGET_LENGTH = 48000  # 3 seconds @ 16 kHz

    def __init__(self, root='./eval', file_suffix=None,
                 target_length=DEFAULT_TARGET_LENGTH):
        super().__init__()
        if file_suffix is None:
            # Keep the original behaviour: use the notebook's global suffix.
            file_suffix = suffix_random
        self.target_length = target_length

        self.eval_car_data_dir = f'{root}/car/eval_car_data{file_suffix}.pt'
        self.eval_music_data_dir = f'{root}/music/eval_music_data{file_suffix}.pt'
        self.eval_noise_data_dir = f'{root}/noise/eval_noise_data{file_suffix}.pt'
        self.eval_speech_data_dir = f'{root}/speech/eval_speech_data{file_suffix}.pt'

        # NOTE(review): torch.load unpickles the file; only load files you trust.
        self.car_data = self._load_labelled(self.eval_car_data_dir, 0)
        self.music_data = self._load_labelled(self.eval_music_data_dir, 1)
        self.noise_data = self._load_labelled(self.eval_noise_data_dir, 3)
        self.speech_data = self._load_labelled(self.eval_speech_data_dir, 2)

        self.data = self.car_data + self.music_data + self.speech_data + self.noise_data

    @staticmethod
    def _load_labelled(path, label):
        """Load ``path`` and return a new list of ``(waveform, label)`` tuples.

        Building a fresh list (instead of assigning into the loaded object)
        works for any iterable container of waveforms and guarantees the
        result supports list ``+``.
        """
        return [(waveform, label) for waveform in torch.load(path)]

    def __len__(self):
        """Return the total number of labelled waveforms across all classes."""
        return len(self.data)

    def __getitem__(self, x):
        """Return ``(waveform, label)`` with the waveform fixed to ``target_length``.

        The length is measured, padded and truncated along the LAST dimension.
        ``F.pad(t, (0, n))`` always pads the last dimension, so checking and
        slicing the same dimension keeps 1-D ``(time,)`` behaviour unchanged
        and also handles ``(channels, time)`` tensors consistently.
        """
        waveform, label = self.data[x]
        n_samples = waveform.shape[-1]

        if n_samples < self.target_length:
            # Too short: append zeros on the right.
            waveform = F.pad(waveform, (0, self.target_length - n_samples))
        elif n_samples > self.target_length:
            # Too long: keep only the first target_length samples.
            waveform = waveform[..., :self.target_length]

        return waveform, label