In [None]:
import os
import pathlib
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from scipy.io import wavfile

In [None]:
def load_model(model_path):
  """Return the Keras model stored at ``model_path``.

  Thin wrapper around ``tf.keras.models.load_model``; the path is passed
  through unchanged and the loaded model is returned as-is.
  """
  loaded_model = tf.keras.models.load_model(model_path)
  return loaded_model

In [None]:
class DrumAudio():
  """A drum recording loaded from a WAV file, with onset detection and
  per-kit (CY / HH / KD / SD) classification of spectrogram images.

  Attributes:
    freq_rate: sample rate in Hz, as returned by ``scipy.io.wavfile.read``.
    signal: 1-D array of samples; multi-channel files are averaged to mono.
    time_range: time in seconds of every sample (rounded to 6 decimals);
      always the same length as ``signal``.
  """
  def __init__(
        self,
        audio_path: str,
        **kwargs
        ):
    super(DrumAudio, self).__init__(**kwargs)

    self.freq_rate, signal = wavfile.read(audio_path)
    # wavfile.read returns shape (n_samples, n_channels) for multi-channel
    # files; downmix so the rest of the class can treat the signal as 1-D.
    if signal.ndim > 1:
      signal = signal.mean(axis=1)
    self.signal = signal

    # Derive times from integer sample indices instead of
    # np.arange(0, secs, 1/freq_rate): a float-step arange can produce one
    # extra element, leaving signal and time_range with different lengths.
    self.time_range = (np.arange(self.signal.shape[0]) / self.freq_rate).round(6)

  def detect_percussions(self, threshold=1.01, min_size=1024, time_decimals=1):
    """Locate percussive onsets as long runs of loud samples.

    A sample is "loud" when it is <= -threshold or >= threshold. Consecutive
    loud samples form a run; runs of at least ``min_size`` samples are kept,
    and of the runs whose start time rounds (to ``time_decimals`` decimals)
    to the same value, only the first is reported.

    Args:
      threshold: amplitude a sample must reach in either direction to count
        as loud. The default suits integer PCM data; a normalized float WAV
        presumably needs a smaller value — TODO confirm.
      min_size: minimum run length, in samples, for a run to be an onset.
      time_decimals: decimals used to round run start times when merging
        nearby onsets (1 -> 0.1 s bins).

    Returns:
      DataFrame with columns ``time`` (seconds) and ``index_position``
      (sample index where the run starts), indexed 0..n-1.
    """
    samples = pd.DataFrame({'signal': self.signal, 'time': self.time_range})
    samples['index_position'] = samples.index

    # Two-sided comparison rather than abs(): abs() of the most negative
    # integer sample (e.g. int16 -32768) overflows back to a negative value.
    loud = (samples['signal'] <= -threshold) | (samples['signal'] >= threshold)
    # The counter only advances on quiet samples, so every run of consecutive
    # loud samples shares a single run id.
    run_id = (~loud).astype(int).cumsum()

    runs = samples[loud].groupby(run_id[loud])
    run_starts = runs.first()
    run_starts = run_starts[runs.size() >= min_size]
    onsets = run_starts.groupby(run_starts['time'].round(time_decimals)).first()

    return onsets[['time', 'index_position']].reset_index(drop=True)

  def predict_kits(self, CY_model, HH_model, KD_model, SD_model):
    """Detect onsets and classify which kits sound at each of them.

    For every onset from ``detect_percussions``, a spectrogram of the samples
    ``[index_position - 1024, index_position + 4096 + 1024)`` (clipped at the
    start of the signal) is rendered to a PNG in a temporary directory,
    reloaded as a 100x100 image and scored by each model; a score above 0.5
    marks that kit as present.

    Args:
      CY_model, HH_model, KD_model, SD_model: Keras models, one per kit;
        presumably each outputs a single sigmoid score per image — TODO confirm.

    Returns:
      DataFrame with columns time, index_position, CY, HH, KD, SD (0/1),
      keeping only onsets where at least one kit was predicted, indexed 0..n-1.
    """
    kit_models = {'CY': CY_model, 'HH': HH_model, 'KD': KD_model, 'SD': SD_model}
    # copy() so the kit columns are added to an independent frame, not a view.
    transcription = self.detect_percussions().copy()
    predictions = {kit: [] for kit in kit_models}

    # A private temporary directory is always removed, even if plotting or
    # prediction raises; a fixed path would make the next call fail with
    # FileExistsError after an interrupted run.
    with tempfile.TemporaryDirectory() as image_dir:
      for index, position in zip(transcription.index, transcription['index_position']):
        position = int(position)
        window = self.signal[max(position - 1024, 0):position + 4096 + 1024]
        image_path = os.path.join(image_dir, f'{index}.png')

        fig, ax = plt.subplots(figsize=(8, 8))
        try:
          ax.specgram(window, Fs=self.freq_rate, scale_by_freq=True)
          ax.axis('tight')
          ax.axis('off')
          fig.savefig(image_path, bbox_inches='tight', pad_inches=0.0)
        finally:
          plt.close(fig)

        img = tf.keras.preprocessing.image.load_img(image_path, target_size=(100, 100))
        img_array = tf.expand_dims(tf.keras.preprocessing.image.img_to_array(img), 0)
        for kit, model in kit_models.items():
          # Extract the scalar explicitly: int() on a 1-element array is
          # deprecated in NumPy >= 1.25.
          score = np.asarray(model.predict(img_array)).ravel()[0]
          predictions[kit].append(int(score > 0.5))

    for kit, flags in predictions.items():
      transcription[kit] = np.asarray(flags, dtype=np.int64)

    has_kit = (transcription[list(kit_models)] != 0).any(axis=1)
    return transcription[has_kit].reset_index(drop=True)

In [None]:
# Load one classifier per kit, in the order CY, HH, KD, SD.
CY_model, HH_model, KD_model, SD_model = (
    load_model(path)
    for path in (
        '../models/CY_60epochs',
        '../models/HH_ResNet',
        '../models/KD_ResNet',
        '../models/SD_ResNet',
    )
)

In [None]:
# Read the example drum track; the WAV is loaded once, here.
audio = DrumAudio(audio_path='../data/drums_audio/MusicDelta_80sRock_Drum.wav')

In [None]:
# Detect onsets and run every kit classifier on each one.
transcription = audio.predict_kits(
    CY_model=CY_model,
    HH_model=HH_model,
    KD_model=KD_model,
    SD_model=SD_model,
)



In [None]:
# Onset table: one row per detected onset where at least one kit was predicted.
transcription

Unnamed: 0,time,index_position,CY,HH,KD,SD
0,0.000000,0,0,0,1,0
1,0.554989,24475,0,0,1,1
2,1.103810,48678,0,0,1,0
3,1.657279,73086,0,0,1,1
4,2.228571,98280,0,0,1,0
...,...,...,...,...,...,...
70,34.413220,1517623,0,0,1,1
71,34.514308,1522081,0,0,0,1
72,34.561247,1524151,0,0,0,1
73,34.663923,1528679,0,0,0,1


In [None]:
# First 50 rows of the transcription.
transcription.head(50)

Unnamed: 0,time,index_position,CY,HH,KD,SD
0,0.0,0,0,0,1,0
1,0.554989,24475,0,0,1,1
2,1.10381,48678,0,0,1,0
3,1.657279,73086,0,0,1,1
4,2.228571,98280,0,0,1,0
5,2.777052,122468,0,0,1,1
6,3.315986,146235,0,0,1,0
7,3.831859,168985,0,0,1,1
8,4.373129,192855,1,0,1,0
9,4.880884,215247,0,0,1,1


In [None]:
# Last 25 rows of the transcription.
transcription.tail(25)

Unnamed: 0,time,index_position,CY,HH,KD,SD
50,25.095488,1106711,0,0,1,0
51,25.60229,1129061,0,0,0,1
52,25.65678,1131464,0,0,1,1
53,26.204082,1155600,0,0,1,0
54,26.724399,1178546,0,0,1,0
55,26.758594,1180054,0,0,1,1
56,27.306576,1204220,0,0,1,0
57,27.827687,1227201,0,0,1,1
58,28.346939,1250100,0,0,1,0
59,28.388027,1251912,0,0,1,0
