# Drum Sample Classifier

This notebook will be used to develop a pipeline to process user-provided drum samples and apply the model's classification prediction.

# 0. Import packages

In [1]:
import os
import librosa
import tensorflow as tf
import numpy as np

# 1. Define test filenames & constants

In [2]:
# Root folder of the local drum one-shot library. The default is the original
# author's machine-specific path; set the DRUM_SAMPLES_DIR environment variable
# to point the notebook at a different copy of the library.
DRUMS_DIR: str = os.environ.get('DRUM_SAMPLES_DIR', '/Users/tyler/2TB SSD/Samples/ALL/ONE SHOTS/DRUMS')

# Kick test samples.
KICK_1_FILEPATH: str = os.path.join(DRUMS_DIR, 'KICKS', 'ELECTRONIC', 'Antidote Audio', 'Antidote - Kick 2.wav')
KICK_2_FILEPATH: str = os.path.join(DRUMS_DIR, 'KICKS', 'ELECTRONIC', 'Nitti Gritti', '808 Top Kick.wav')

# Snare test samples.
SNARE_1_FILEPATH: str = os.path.join(DRUMS_DIR, 'SNARES', 'ACOUSTIC', 'Apashe', 'Apashe_Acoustic_Snare.wav')
SNARE_2_FILEPATH: str = os.path.join(DRUMS_DIR, 'SNARES', 'TRAP', 'Kompany', 'Kompany - Snare 4.wav')

# Percussion test samples (note that PERC_2 is a tom one-shot from the TOMS folder).
PERC_1_FILEPATH: str = os.path.join(DRUMS_DIR, 'PERCS', 'BLOCKS, RIMS, ETC', 'Cymatics', 'Cymatics - 100k Perc 2.wav')
PERC_2_FILEPATH: str = os.path.join(DRUMS_DIR, 'TOMS', 'ELECTRONIC', 'PhaseOne', 'PhaseOne_Tom1.wav')

# Cymbal test samples (two crashes and one ride).
CYMBAL_1_FILEPATH: str = os.path.join(DRUMS_DIR, 'CYMBALS', 'CRASHES', 'ELECTRONIC', 'Au5', 'Au5_cymbal_crash_acoustic.wav')
CYMBAL_2_FILEPATH: str = os.path.join(DRUMS_DIR, 'CYMBALS', 'CRASHES', 'ELECTRONIC', 'UpSound', 'Crash 13.wav')
CYMBAL_3_FILEPATH: str = os.path.join(DRUMS_DIR, 'CYMBALS', 'RIDES', 'ACOUSTIC', 'Cymatics', 'Cymatics - Ride 1.wav')

In [3]:
# Audio format: every sample is loaded at 44.1 kHz and fixed to 3 seconds.
SAMPLE_RATE: int = 44_100
SAMPLE_LENGTH: int = 3 * SAMPLE_RATE  # 132300 samples

# STFT framing: 1024-sample frames, hop of one eighth of a frame (128 samples).
FRAME_LENGTH: int = 1 << 10
N_BINS: int = FRAME_LENGTH // 2 + 1  # 513 frequency bins
FRAME_STEP: int = FRAME_LENGTH // 8
N_FRAMES: int = (SAMPLE_LENGTH - FRAME_LENGTH) // FRAME_STEP + 1  # 1026 frames

# 2. Load & preprocess sample

## 2.1 Load sample

When building & training the model, I had to be careful about only using TensorFlow operations that would map to the Dataset properly. The tf.audio.decode_wav() function only supports 16-bit audio, which required an extra step of writing new 16-bit files from the 24-bit originals. Since I am not mapping these transformations to an entire dataset here, I should be fine using 24-bit audio samples with the librosa library. Ultimately, I am converting the audio to STFT spectrograms before passing it to the classification model, so the specific bit depth shouldn't matter.

In [4]:
def load_sample(filepath: str) -> tf.Tensor:
    """Read an audio file into a mono float32 tensor at SAMPLE_RATE.

    Args:
        filepath: Path of the audio file to load.

    Returns:
        The waveform as a float32 tensor, downmixed to mono and loaded at
        SAMPLE_RATE by librosa.
    """
    # librosa also returns the sample rate; it is not needed because the rate was requested explicitly.
    waveform, _ = librosa.load(filepath, sr=SAMPLE_RATE, mono=True)
    return tf.convert_to_tensor(waveform, dtype=tf.float32)

## 2.2 Pad & trim sample

In [5]:
def pad_sample(audio: tf.Tensor) -> tf.Tensor:
    """Force a waveform to exactly SAMPLE_LENGTH samples.

    Longer inputs are truncated; shorter inputs are right-padded with zeros.

    Args:
        audio: Waveform tensor (expected to be 1-D, as produced by load_sample).

    Returns:
        A float32 tensor of length SAMPLE_LENGTH.
    """
    trimmed = audio[:SAMPLE_LENGTH]

    # Number of samples still missing after trimming (zero when already long enough).
    missing = [SAMPLE_LENGTH] - tf.shape(trimmed)
    silence = tf.zeros(missing, dtype=tf.float32)
    return tf.concat([trimmed, silence], axis=0)

## 2.3 Normalize sample volume

In [14]:
def normalize(audio: tf.Tensor) -> tf.Tensor:
    """Peak-normalize a waveform so its largest absolute sample is 1.0.

    Args:
        audio: Floating-point waveform tensor.

    Returns:
        The waveform multiplied by 1 / max(|audio|). An all-zero (silent)
        input is returned as all zeros rather than NaN.
    """
    audio_max = tf.reduce_max(tf.abs(audio))
    # A plain 1 / audio_max is inf for silent input, and 0 * inf turns every
    # sample into NaN. divide_no_nan returns 0 when the denominator is 0, so
    # silence stays silence.
    scale_factor = tf.math.divide_no_nan(tf.ones_like(audio_max), audio_max)
    return audio * scale_factor

In [15]:
# Run the first kick sample through load -> pad/trim -> peak normalization.
kick_1_norm = normalize(pad_sample(load_sample(KICK_1_FILEPATH)))

In [16]:
# Display the normalized, fixed-length waveform tensor.
kick_1_norm

<tf.Tensor: shape=(132300,), dtype=float32, numpy=
array([ 0.09780365, -0.06092693, -0.43853027, ...,  0.        ,
        0.        ,  0.        ], dtype=float32)>

In [17]:
# Euclidean norm of the waveform; it need not be 1 because normalize() scales by peak amplitude, not by energy.
tf.norm(kick_1_norm)

<tf.Tensor: shape=(), dtype=float32, numpy=42.16108>

In [18]:
# Sanity check: after peak normalization the largest absolute sample should be 1.0.
tf.reduce_max(tf.abs(kick_1_norm))

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

## 2.4 Apply STFT

In [19]:
def apply_stft(audio: tf.Tensor) -> tf.Tensor:
    """Convert a waveform into a magnitude spectrogram shaped for the model.

    Args:
        audio: Waveform tensor (expected to be 1-D).

    Returns:
        The STFT magnitudes with a channels axis inserted at position 2 and a
        batch axis inserted at position 0. For a 1-D input of SAMPLE_LENGTH
        samples this is (1, frames, bins, 1).
    """
    complex_stft = tf.signal.stft(audio, frame_length=FRAME_LENGTH, frame_step=FRAME_STEP)
    magnitudes = tf.abs(complex_stft)
    # The convolutional network expects a channels dimension.
    with_channels = tf.expand_dims(magnitudes, axis=2)
    # The model also expects a leading batch dimension.
    return tf.expand_dims(with_channels, axis=0)

## 2.5 Wrap processing functions

In [20]:
def load_and_process(filepath: str) -> tf.Tensor:
    """Run the full preprocessing pipeline on one audio file.

    Steps, in order: load_sample, pad_sample, normalize, apply_stft.

    Args:
        filepath: Path of the audio file to process.

    Returns:
        The spectrogram tensor produced by apply_stft, ready for model.predict.
    """
    waveform = load_sample(filepath)
    fixed_length = pad_sample(waveform)
    peak_normalized = normalize(fixed_length)
    return apply_stft(peak_normalized)

## 2.6 Test

In [21]:
# End-to-end check of the preprocessing pipeline on the first kick sample.
kick_1_test_stft = load_and_process(KICK_1_FILEPATH)
# NOTE(review): this displays the whole tensor; `kick_1_test_stft.shape` would be a more compact check.
kick_1_test_stft

<tf.Tensor: shape=(1, 1026, 513, 1), dtype=float32, numpy=
array([[[[1.0555505e+00],
         [9.8869829e+00],
         [4.0773315e+01],
         ...,
         [2.9656422e-01],
         [3.0780557e-01],
         [3.1764221e-01]],

        [[6.9636750e+00],
         [1.7453072e+01],
         [6.6924355e+01],
         ...,
         [1.0898353e-01],
         [1.7584307e-01],
         [2.6164484e-01]],

        [[1.3303730e+00],
         [2.4632553e+01],
         [1.0118411e+02],
         ...,
         [3.8304087e-02],
         [7.3528975e-02],
         [1.3738275e-01]],

        ...,

        [[0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         ...,
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00]],

        [[0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00],
         ...,
         [0.0000000e+00],
         [0.0000000e+00],
         [0.0000000e+00]],

        [[0.0000000e+00],
         [0.0000000e+00],
         [0.000

# 3. Load saved model and make prediction

## 3.1 Load trained model

In [22]:
# Load the trained Keras model; the path is relative to the notebook's working directory.
model = tf.keras.models.load_model('model_training/trained_hypermodel')

## 3.2 Make prediction

In [23]:
# model.predict returns a NumPy array; for this single input it has shape (1, 4),
# i.e. one row holding the four class scores.
kick_1_pred: np.ndarray = model.predict(load_and_process(KICK_1_FILEPATH))
kick_1_pred



2023-02-25 12:43:02.984227: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2023-02-25 12:43:03.040764: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


array([[3.8817014e-18, 1.0000000e+00, 1.2962743e-10, 5.6914194e-21]],
      dtype=float32)

## 3.3 Translate prediction

In [24]:
def translate_prediction(prediction_array: list[float]) -> 'str | None':
    """Map a model prediction to its class label.

    Args:
        prediction_array: Class scores for a single sample, ordered as
            (cymbal, kick, perc, snare). A (1, 4) batch array as returned by
            model.predict is also accepted, because np.argmax operates on the
            flattened input. Batches of more than one sample are not
            supported: the flattened index would no longer be a class index.

    Returns:
        The label with the highest score, or None if the index of the maximum
        does not correspond to one of the four known classes.
    """
    class_names = ('cymbal', 'kick', 'perc', 'snare')
    # Compute the winning index once instead of re-running argmax per branch.
    class_index = int(np.argmax(prediction_array))
    if class_index < len(class_names):
        return class_names[class_index]
    return None

# 4. Test predictions on all test samples

As we can see from the tests below, the model is not perfect. We likely could have achieved greater accuracy with a larger and more diverse training dataset.

However, even though the model isn't perfect, it predicted the correct label in 17 of the 20 tests below (17/20 = 85%), which matches the training val_accuracy of 85%.

In [25]:
# Kick #1 from the local library (KICKS folder); intended label: 'kick'.
translate_prediction(model.predict(load_and_process(KICK_1_FILEPATH)))



'kick'

In [26]:
# Kick #2 from the local library (KICKS folder); intended label: 'kick'.
translate_prediction(model.predict(load_and_process(KICK_2_FILEPATH)))



'kick'

In [27]:
# Snare #1, an acoustic snare (SNARES/ACOUSTIC folder); intended label: 'snare'.
translate_prediction(model.predict(load_and_process(SNARE_1_FILEPATH)))



'snare'

In [28]:
# Snare #2, a trap snare (SNARES/TRAP folder); intended label: 'snare'.
translate_prediction(model.predict(load_and_process(SNARE_2_FILEPATH)))



'snare'

In [29]:
# Perc #1 (PERCS folder); intended label: 'perc'.
translate_prediction(model.predict(load_and_process(PERC_1_FILEPATH)))



'snare'

In [30]:
# Perc #2 is a tom one-shot (TOMS folder), presumably counted under the 'perc' class.
translate_prediction(model.predict(load_and_process(PERC_2_FILEPATH)))



'perc'

In [31]:
# Cymbal #1, a crash (CYMBALS/CRASHES folder); intended label: 'cymbal'.
translate_prediction(model.predict(load_and_process(CYMBAL_1_FILEPATH)))



'cymbal'

In [32]:
# Cymbal #2, a crash (CYMBALS/CRASHES folder); intended label: 'cymbal'.
translate_prediction(model.predict(load_and_process(CYMBAL_2_FILEPATH)))



'perc'

In [33]:
# Cymbal #3, a ride (CYMBALS/RIDES folder); intended label: 'cymbal'.
translate_prediction(model.predict(load_and_process(CYMBAL_3_FILEPATH)))



'cymbal'

In [34]:
# Sample from the preprocessed 16-bit folder; the 'cymbal_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/cymbal_001111.wav')))



'cymbal'

In [35]:
# Sample from the preprocessed 16-bit folder; the 'cymbal_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/cymbal_001986.wav')))



'cymbal'

In [36]:
# Sample from the preprocessed 16-bit folder; the 'cymbal_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/cymbal_001992.wav')))



'cymbal'

In [37]:
# Sample from the preprocessed 16-bit folder; the 'cymbal_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/cymbal_002741.wav')))



'cymbal'

In [38]:
# Sample from the preprocessed 16-bit folder; the 'cymbal_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/cymbal_009212.wav')))



'cymbal'

In [39]:
# Sample from the preprocessed 16-bit folder; the 'kick_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/kick_015418.wav')))



'kick'

In [40]:
# Sample from the preprocessed 16-bit folder; the 'kick_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/kick_016314.wav')))



'kick'

In [41]:
# Sample from the preprocessed 16-bit folder; the 'perc_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/perc_019687.wav')))



'cymbal'

In [42]:
# Sample from the preprocessed 16-bit folder; the 'perc_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/perc_019681.wav')))



'perc'

In [43]:
# Sample from the preprocessed 16-bit folder; the 'snare_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/snare_020045.wav')))



'snare'

In [44]:
# Sample from the preprocessed 16-bit folder; the 'snare_' filename prefix suggests the intended label.
translate_prediction(model.predict(load_and_process('model_training/preprocessed_samples_16bit/snare_020044.wav')))



'snare'