In [1]:
import csv
import os
import numpy as np
import queue
import librosa
import time

import torch

import threading

import scipy.signal as scipy_signal

import datetime

In [17]:
class DataLoader:
    """Stream padded log-mel spectrogram batches produced on a background thread.

    The metadata CSV is read eagerly in ``__init__``: column 0 is a WAV path and
    column 1 a duration (presumably seconds — TODO confirm); no header row is
    expected. ``batch_generator()`` starts a worker thread that converts each
    file to a spectrogram, packs ``(x, y)`` torch batches and pushes them into a
    bounded queue; the generator yields them until the epoch is exhausted.

    Args:
        metadata_path: path of the ``path,duration`` CSV file.
        class_name_to_class_idx_dict: maps a class name (the second
            ``/``-separated component of each WAV path) to an integer label.
        batch_size: examples per batch; the last batch of an epoch may be smaller.
        queue_size: maximum number of prepared batches waiting in the queue.
    """

    # Sentinel the producer enqueues after its last batch (or after failing),
    # so the consumer never has to poll a flag with timeouts.
    _END_OF_DATA = object()

    def __init__(self, metadata_path, class_name_to_class_idx_dict,
                 batch_size=64, queue_size=10):
        self.queue = queue.Queue(maxsize=queue_size)
        self.batch_size = batch_size

        # Sample rate (Hz) assumed by both the STFT sizes and the mel filterbank.
        self.fs = 44100 / 2

        # STFT window length and overlap, in milliseconds and in samples.
        self.nsc_in_ms = 40
        self.nov_in_ms = 0
        self.nsc_in_sample = int(self.nsc_in_ms / 1000 * self.fs)
        self.nov_in_sample = int(self.nov_in_ms / 1000 * self.fs)

        self.num_mels = 160

        # Keyword arguments: librosa >= 0.10 made sr / n_fft keyword-only.
        self.mel_band = librosa.filters.mel(sr=self.fs, n_fft=self.nsc_in_sample,
                                            n_mels=self.num_mels)

        # Set True by batch_generator() and False once the producer has finished.
        self.batch_flag = None
        # Exception caught on the producer thread; re-raised by batch_generator().
        self._producer_error = None

        # newline='' is what the csv module expects; blank lines are skipped so
        # they cannot produce ragged rows.
        with open(metadata_path, newline='') as file:
            rows = [row for row in csv.reader(file) if row]

        self.file_names = np.array([row[0] for row in rows])
        self.durations = np.asarray([float(row[1]) for row in rows], dtype=float)

        self.class_name_to_class_idx_dict = class_name_to_class_idx_dict

    def shuffle_dataset(self, sort_by_duration=True):
        """Reorder ``file_names`` and ``durations`` together, in place.

        By default the files are sorted by ascending duration (deterministic,
        not random), so each batch holds clips of similar length and needs
        little padding. Pass ``sort_by_duration=False`` to apply a uniformly
        random permutation instead.
        """
        if sort_by_duration:
            order = np.argsort(self.durations)
        else:
            order = np.random.permutation(len(self.file_names))
        self.file_names = self.file_names[order]
        self.durations = self.durations[order]

    def start_loading(self):
        """Run ``batching_thread()`` on a new worker thread and return that thread.

        The thread is a daemon so a consumer that stops iterating early (leaving
        the producer blocked on a full queue) cannot keep the interpreter alive.
        """
        worker = threading.Thread(target=self.batching_thread, daemon=True)
        worker.start()
        return worker

    def batch_generator(self):
        """Start the producer thread and yield ``(x, y)`` batches until it finishes.

        Raises:
            Whatever exception the producer thread hit (e.g. ``KeyError`` for a
            class name missing from the mapping), instead of waiting forever.
        """
        self.batch_flag = True
        self._producer_error = None
        self.start_loading()

        while True:
            item = self.queue.get()
            if item is self._END_OF_DATA:
                break
            yield item

        error, self._producer_error = self._producer_error, None
        if error is not None:
            raise error

    def batching_thread(self):
        """Producer: convert every file to a mel spectrogram and enqueue batches.

        Always finishes by clearing ``batch_flag`` and enqueuing the end-of-data
        sentinel, even on failure; the exception is stored for the consumer.
        """
        try:
            mel_buffer_list, label_buffer_list = [], []

            for file_name in self.file_names:
                mel_buffer_list.append(self.wav_path_to_mel(file_name))
                # Class = second '/'-separated path component; presumably paths
                # look like 'train/<class>/<file>.wav' — confirm against metadata.
                class_label = file_name.split('/')[1]
                label_buffer_list.append(self.class_name_to_class_idx_dict[class_label])

                if len(mel_buffer_list) == self.batch_size:
                    self.queue.put(self.pack_batch(mel_buffer_list, label_buffer_list))
                    mel_buffer_list, label_buffer_list = [], []

            # Flush the final, possibly smaller, batch.
            if mel_buffer_list:
                self.queue.put(self.pack_batch(mel_buffer_list, label_buffer_list))
        except Exception as exc:
            self._producer_error = exc
        finally:
            self.batch_flag = False
            self.queue.put(self._END_OF_DATA)

    def wav_path_to_mel(self, wav_path):
        """Return the normalized log-mel spectrogram of one WAV file.

        The signal is loaded at ``self.fs``, cut into STFT frames of
        ``nsc_in_sample`` samples with ``nov_in_sample`` overlap, projected onto
        ``num_mels`` mel bands and mapped to ``(20*log10(mag) + 160) / 160``.
        Magnitudes are floored at 1e-8 (-160 dB), so the floor maps to 0 and
        0 dB maps to 1.

        Returns:
            ndarray of shape (num_mels, n_frames).
        """
        # Resample to self.fs: nperseg and mel_band were sized for that rate, so
        # loading at a different native rate would misplace every mel band.
        data, fs = librosa.load(wav_path, sr=self.fs)

        _, _, Zxx = scipy_signal.stft(data, fs=fs,
                                      nperseg=self.nsc_in_sample,
                                      noverlap=self.nov_in_sample)

        # (num_mels, n_freq) @ (n_freq, n_frames) -> (num_mels, n_frames)
        mel_magnitude = np.matmul(self.mel_band, np.abs(Zxx))
        return (20 * np.log10(np.maximum(mel_magnitude, 1e-8)) + 160) / 160

    def pack_batch(self, mels, labels):
        """Left-pad spectrograms to a common length and stack them into tensors.

        Args:
            mels: list of arrays shaped (num_mels, T_i).
            labels: list of integer class indices, one per spectrogram.

        Returns:
            ``(x, y)``: ``x`` is a float64 tensor (B, num_mels, max T_i) with
            zeros prepended to shorter examples; ``y`` is a tensor of ``labels``.
        """
        max_time_step = max(mel.shape[1] for mel in mels)
        x = np.zeros([len(mels), self.num_mels, max_time_step])

        for i, mel in enumerate(mels):
            # Explicit start index: a `-T:` slice would select the whole axis when T == 0.
            x[i, :, max_time_step - mel.shape[1]:] = mel

        return torch.tensor(x), torch.tensor(labels)
        
        

In [18]:
# Class names are the sub-directories of 'train'. Sorting makes the
# name -> index mapping deterministic; skipping files and hidden entries
# (e.g. '.DS_Store', '.ipynb_checkpoints') keeps them from becoming bogus
# classes that would shift every later label index.
classes = sorted(
    entry for entry in os.listdir('train')
    if not entry.startswith('.') and os.path.isdir(os.path.join('train', entry))
)

class_name_to_class_idx_dict = {class_name: i for i, class_name in enumerate(classes)}

In [19]:
# Build the training loader; the metadata CSV is read eagerly by the constructor.
train_dataloader = DataLoader(
    metadata_path='train_metadata.csv',
    class_name_to_class_idx_dict=class_name_to_class_idx_dict,
)

In [20]:
# Despite its name, shuffle_dataset() sorts the files by ascending duration (no
# randomness), so each batch groups clips of similar length and needs little padding.
train_dataloader.shuffle_dataset()

In [21]:
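# Do not call train_dataloader.start_loading() here: batch_generator() starts the loader thread itself, so a separate call would launch a second producer feeding the same queue.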

In [22]:
# batch_generator() is a generator function, so this line only creates the
# generator; the producer thread is launched on the first iteration below.
batch_generator = train_dataloader.batch_generator()

In [23]:
# Drain one epoch, printing when each batch arrives and its (B, num_mels, T) shape.
for batch in batch_generator:
    print(f'[TIME: {datetime.datetime.now()}] [SHAPE: {batch[0].shape}] ')

[TIME: 2020-02-13 14:56:44.765798] [SHAPE: torch.Size([64, 160, 14])] 
[TIME: 2020-02-13 14:56:44.802988] [SHAPE: torch.Size([64, 160, 14])] 
[TIME: 2020-02-13 14:56:44.838573] [SHAPE: torch.Size([64, 160, 14])] 
[TIME: 2020-02-13 14:56:44.876701] [SHAPE: torch.Size([64, 160, 16])] 
[TIME: 2020-02-13 14:56:44.918466] [SHAPE: torch.Size([64, 160, 18])] 
[TIME: 2020-02-13 14:56:44.960463] [SHAPE: torch.Size([64, 160, 20])] 
[TIME: 2020-02-13 14:56:45.008338] [SHAPE: torch.Size([64, 160, 22])] 
[TIME: 2020-02-13 14:56:45.055209] [SHAPE: torch.Size([64, 160, 24])] 
[TIME: 2020-02-13 14:56:45.111652] [SHAPE: torch.Size([64, 160, 26])] 
[TIME: 2020-02-13 14:56:45.159095] [SHAPE: torch.Size([64, 160, 27])] 
[TIME: 2020-02-13 14:56:45.212599] [SHAPE: torch.Size([64, 160, 29])] 
[TIME: 2020-02-13 14:56:45.272070] [SHAPE: torch.Size([64, 160, 31])] 
[TIME: 2020-02-13 14:56:45.326519] [SHAPE: torch.Size([64, 160, 32])] 
[TIME: 2020-02-13 14:56:45.381926] [SHAPE: torch.Size([64, 160, 35])] 
[TIME: