In [6]:
import os
import numpy as np
import torch
from torch.utils import data
from ipynb.fs.full.data_preprocess import serialized_test_folder, serialized_train_folder, serialized_val_folder


In [9]:
class SignalDataset(data.Dataset):
    """
    Audio sample reader.

    Every entry of the selected folder is loaded with ``np.load`` and indexed
    along its first axis: index 0 is used as the clean signal, index 1 as the
    noisy signal and index 2 as the ``acc`` signal.
    """

    def __init__(self, data_type, data_path=None):
        """
        Args:
            data_type (str): 'train' or 'val' select the matching serialized
                folder; any other value falls back to the test folder.
            data_path (str, optional): folder to read instead of the
                serialized folder implied by ``data_type``. Defaults to None,
                which keeps the original folder selection.

        Raises:
            FileNotFoundError: if the selected folder does not exist.
        """
        if data_path is None:
            if data_type == 'train':
                data_path = serialized_train_folder
            elif data_type == 'val':
                data_path = serialized_val_folder
            else:
                data_path = serialized_test_folder
        if not os.path.exists(data_path):
            raise FileNotFoundError('The {} data folder does not exist!'.format(data_type))

        self.data_type = data_type
        # os.listdir returns entries in arbitrary, filesystem-dependent order;
        # sort so that a given index always refers to the same file.
        self.file_names = [os.path.join(data_path, filename) for filename in sorted(os.listdir(data_path))]

    def reference_batch(self, batch_size):
        """
        Randomly selects a reference batch from dataset.
        Reference batch is used for calculating statistics for virtual batch normalization operation.

        Files are drawn with replacement (``np.random.choice`` default), so
        ``batch_size`` may exceed the number of files.

        Args:
            batch_size(int): batch size

        Returns:
            ref_batch: float32 tensor of shape ``(batch_size, *file_shape[:-1])``
        """
        ref_file_names = np.random.choice(self.file_names, batch_size)
        ref_batch = np.stack([np.load(f) for f in ref_file_names])
        # Drop the trailing axis of each stored array; np.squeeze raises
        # ValueError if that axis does not have size 1.
        ref_batch = np.squeeze(ref_batch, axis=-1)
        return torch.from_numpy(ref_batch).float()

    def __getitem__(self, idx):
        """
        Load the sample at ``idx``.

        Returns:
            tuple: ``(pair, clean, noisy, acc)`` as float32 tensors, where
            ``pair`` is the whole loaded array and ``clean``/``noisy``/``acc``
            are rows 0/1/2 of it, each flattened to shape ``(1, L)``.
        """
        pair = np.load(self.file_names[idx])
        clean = pair[0].reshape(1, -1)
        noisy = pair[1].reshape(1, -1)
        acc = pair[2].reshape(1, -1)
        return (
            torch.from_numpy(pair).float(),
            torch.from_numpy(clean).float(),
            torch.from_numpy(noisy).float(),
            torch.from_numpy(acc).float(),
        )

    def __len__(self):
        """Return the number of files in the selected folder."""
        return len(self.file_names)
