In [39]:
from train_data_preparation import MfccPipeline

In [50]:
import numpy as np
import librosa
import tarfile
import soundfile as sf
import io
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split


class MfccPipeline():

    def __init__(self):
        self.PATH = "C:/Users/alexc/Downloads/nsynth-train.jsonwav.tar.gz"

        self.targets_list = [
            "bass", "brass", "flute", "guitar", "keyboard",
            "mallet", "organ", "reed", "string", "synth", "vocal"
        ]

    # recover wav files from tarball
    def get_files(self, PATH=None, num_files = 100):
        """Sequentially pull wav files from .tar.gz file
        :param PATH: path to compressed dataset file
        :param num_files: number of files to read
        :returns: 4 generators for data (wav file converted to list), sr (sample rate)
                target name (e.g. 'guitar', 'mallet', ect.), and target index
        """

        print("getting files")
        if PATH is None:
            PATH = self.PATH

        # open the tar file
        with tarfile.open(PATH, 'r:gz') as tar:

            # Initialize counter to count number of files pulled.
            index = 0
            while index < num_files:
                fname = tar.next()

                # Break if there are no more files.
                if fname is None:
                    break

                # Check that we're dealing with the proper format
                if fname.name.endswith(".wav"):

                    # Extract file
                    wav_file = tar.extractfile(fname).read()

                    # Convert bytes to a readable format
                    data, sr = sf.read(io.BytesIO(wav_file))

                    # Get target from filename
                    target = fname.name.split('/')[2].split('_')[0]

                    # yeild the 4 generators
                    yield data, sr, target, self.targets_list.index(target)
                    index += 1

    # collect raw data into an array
    def get_dataset(self, num_files=10):
        """Docstring pending
        """
        data = []
        data_generator = self.get_files(self.PATH, num_files)
        for i in range(num_files):
            data.append(next(data_generator))
        return data

    # calculate mfccs
    def get_mfccs(self, data_tuple):
        """Take a tuple of data (data, sr, target, target_index) and return the associated mfcc"""
        data = np.array(data_tuple[0])
        sr = data_tuple[1]
        mfcc = librosa.feature.mfcc(y=data, sr=sr)
        return mfcc

    # prepare the data for input to model
    def prepare_data(self, dataset):
        """Get the mfccs for each record in the dataset and the associated target values
        """

        X = np.array([self.get_mfccs(i) for i in dataset])
        t = to_categorical(np.array([i[3] for i in dataset]))

        return X, t

    def mfcc_pipeline(self, num_samples, validation_split=0.2):
        """Write this doc later"""
        # get the data from tar file
        data = self.get_dataset(num_samples)
        X, t = self.prepare_data(data)

        # split into train and validation set and return the data
        return train_test_split(X, t, test_size=validation_split)

In [51]:
pipe = MfccPipeline()

In [52]:
dir(pipe)

['PATH',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'get_dataset',
 'get_files',
 'get_mfccs',
 'mfcc_pipeline',
 'prepare_data',
 'targets_list']

In [54]:
X_train, X_test, y_train, y_test = pipe.mfcc_pipeline(num_samples = 1000)

getting files


In [55]:
X_train

array([[[-7.59742960e+01, -1.04606741e+02, -2.29081248e+02, ...,
         -5.66862667e+02, -5.66862667e+02, -5.66862667e+02],
        [ 1.36380829e+02,  1.44651938e+02,  1.68633632e+02, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 4.28400310e+00,  2.15398926e+00, -7.20601795e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        ...,
        [ 3.71962189e+00,  4.57976086e+00,  5.80338805e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [-1.37600947e+01, -1.34649780e+01, -7.51012292e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
        [ 5.59659274e+00,  5.11119904e+00,  1.10133784e+00, ...,
          0.00000000e+00,  0.00000000e+00,  0.00000000e+00]],

       [[-1.43798420e+02, -1.92296125e+02, -3.11512839e+02, ...,
         -5.13707716e+02, -5.15069914e+02, -5.16267922e+02],
        [ 1.39416394e+02,  1.45631181e+02,  1.64370976e+02, ...,
          5.82411216e+00,  3.95890164e