In [1]:
from data.dataloader_v1 import AudioDataset

In [2]:
# Shell escape: print the notebook kernel's current working directory.
! pwd

/home/prs392/codes/incubator/non-invertible-audio-feature-generation/development/param/openl3_librispeech/data_loader


In [3]:
import os
import random

import torch
from torch.utils.data import Dataset

import numpy as np


class AudioDataset(Dataset):
    """Paired embedding / spectrogram dataset backed by ``.npy`` files.

    ``root_dir`` must contain two directory trees, ``embeddings_6144`` and
    ``spectrograms``, holding ``.npy`` files at identical relative paths.
    One dataset item corresponds to one such file pair; each access draws a
    single random entry along the first axis of the stored arrays.

    Args:
        root_dir: Directory containing ``embeddings_6144/`` and
            ``spectrograms/``.
        transform: Stored on ``self.transform``. It is NOT applied anywhere
            in this class; it is kept only so existing callers that pass it
            keep working.

    Raises:
        AssertionError: If the two trees do not contain the same set of
            ``.npy`` relative paths.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        self.embeddings_dir = os.path.join(self.root_dir, 'embeddings_6144')
        self.spectrograms_dir = os.path.join(self.root_dir, 'spectrograms')

        # Paths are relative to the respective base directory, so files that
        # live in sub-directories can be re-opened in __getitem__. For a flat
        # layout a relative path is just the file name.
        self.list_of_embedding_file_names = self._list_npy_files(self.embeddings_dir)
        spectrogram_file_names = self._list_npy_files(self.spectrograms_dir)

        embedding_set = set(self.list_of_embedding_file_names)
        spectrogram_set = set(spectrogram_file_names)
        if embedding_set != spectrogram_set:
            # Raised explicitly (instead of a bare ``assert``) so the check
            # still runs under ``python -O``; the exception type is unchanged.
            unmatched = sorted(embedding_set ^ spectrogram_set)
            raise AssertionError(
                "embeddings_6144 and spectrograms do not contain the same "
                ".npy files; {} unmatched, e.g. {}".format(len(unmatched), unmatched[:5])
            )

    @staticmethod
    def _list_npy_files(base_dir):
        """Return the sorted relative paths of every ``.npy`` file under ``base_dir``."""
        found = []
        for current_dir, _sub_dirs, files in os.walk(base_dir):
            for file in files:
                if file.endswith(".npy"):
                    full_path = os.path.join(current_dir, file)
                    found.append(os.path.relpath(full_path, base_dir))
        # os.walk order is filesystem dependent; sorting makes the
        # index -> file mapping reproducible between runs and machines.
        found.sort()
        return found

    def __len__(self):
        """Number of embedding/spectrogram file pairs."""
        return len(self.list_of_embedding_file_names)

    def __getitem__(self, idx):
        """Load file pair ``idx`` and return one randomly chosen entry from it.

        Returns:
            A tuple ``(emb_tensor, spec_tensor, index_tensor)``:
            ``emb[i]`` as a tensor, ``spec[i]`` as a tensor with its axes
            permuted ``(2, 0, 1)``, and ``i`` wrapped in a tensor, where
            ``i`` is drawn uniformly from ``range(emb.shape[0])``.
        """
        file_name = self.list_of_embedding_file_names[idx]

        emb_path = os.path.join(self.embeddings_dir, file_name)
        spec_path = os.path.join(self.spectrograms_dir, file_name)

        with open(emb_path, 'rb') as f:
            emb = np.load(f)

        with open(spec_path, 'rb') as f:
            spec = np.load(f)

        # The index is drawn from the embedding array and reused for the
        # spectrogram array.
        # NOTE(review): assumes both arrays have the same first-axis length
        # and that it is non-zero - confirm against the preprocessing step.
        n = emb.shape[0]
        i = random.randrange(n)

        emb_tensor = torch.from_numpy(emb[i])
        # spec[i] must be 3-D; its last axis is moved to the front.
        spec_tensor = torch.from_numpy(spec[i]).permute(2, 0, 1)

        return emb_tensor, spec_tensor, torch.tensor(i)


In [None]:
# Smoke test: build the dataset and print the tensor shapes and the randomly
# drawn index for the leading items.
audio_dataset = AudioDataset(root_dir='/scratch/prs392/incubator/data/LibriSpeech/train-clean-360')

# Visit indices 0..100 at most (101 items), or fewer if the dataset is smaller.
for i in range(min(len(audio_dataset), 101)):
    sample, spec, j = audio_dataset[i]
    print(sample.shape, spec.shape, j)

torch.Size([6144]) torch.Size([1, 128, 199]) tensor(139)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(58)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(22)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(110)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(65)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(49)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(29)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(12)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(80)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(74)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(55)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(78)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(0)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(48)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(36)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(124)
torch.Size([6144]) torch.Size([1, 128, 199]) tensor(19)
torch.Size([6144]) torch.Size([1, 128, 199]) t