In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from keras.layers import GRU, Bidirectional, Dropout, Input, TimeDistributed, BatchNormalization, Dense, Conv1D, Activation, UpSampling1D
from keras.models import Model

In [None]:
class CRNN_segmenter:
    """Convolutional-recurrent network that marks onsets in a 1-D waveform.

    A waveform is padded/truncated to a fixed length, cut into
    non-overlapping frames, turned into magnitude spectra and fed to a
    Conv1D -> bidirectional GRU -> upsampling -> per-step sigmoid model.
    Onsets are reported wherever the model output rises sharply between two
    consecutive output steps.
    """

    # Fixed geometry of the network input/output. The values are tied
    # together: WAVE_LENGTH / FRAME_SIZE frames go in, UPSAMPLING output
    # steps come out per frame, so one output step spans
    # FRAME_SIZE / UPSAMPLING waveform samples.
    _WAVE_LENGTH = 50000  # samples kept from every waveform
    _FRAME_SIZE = 50      # samples per frame (= spectrum bins kept per frame)
    _UPSAMPLING = 5       # output steps produced per input frame
    _ONSET_JUMP = 0.5     # minimum rise between consecutive output steps

    def __init__(self, weights_path):
        """Build the network and load its weights.

        Args:
            weights_path: Path handed to the model's ``load_weights``.
        """
        n_frames = self._WAVE_LENGTH // self._FRAME_SIZE
        self.model = self.build_model((n_frames, self._FRAME_SIZE))
        self.model.load_weights(weights_path)

    def build_model(self, input_shape):
        """Assemble the Keras model.

        Args:
            input_shape: ``(frames, bins)`` shape of one input example
                (without the batch dimension).

        Returns:
            A ``Model`` mapping the input to one sigmoid value per output
            step; the time axis is ``_UPSAMPLING`` times longer than the
            input's.
        """
        X_input = Input(shape=input_shape)
        # Kernel size 1: mixes spectrum bins within a frame, not across time.
        X = Conv1D(196, kernel_size=1, strides=1)(X_input)
        X = BatchNormalization()(X)
        X = Activation("relu")(X)
        X = Bidirectional(GRU(units=256, return_sequences=True))(X)
        X = BatchNormalization()(X)
        X = UpSampling1D(self._UPSAMPLING)(X)
        X_output = TimeDistributed(Dense(1, activation="sigmoid"))(X)

        return Model(inputs=X_input, outputs=X_output)

    def pad(self, seq, max_length):
        """Force ``seq`` to exactly ``max_length`` samples.

        Args:
            seq: 1-D NumPy array.
            max_length: Target length.

        Returns:
            ``seq`` followed by zeros when it is too short, otherwise its
            first ``max_length`` elements.
        """
        if seq.shape[0] < max_length:
            seq = np.append(seq, [0] * (max_length - seq.shape[0]))
        else:
            seq = seq[:max_length]
        return seq

    def to_spectral(self, x, samples):
        """Convert a waveform into per-frame magnitude spectra.

        The signal is split into consecutive, non-overlapping frames of
        ``samples`` points. Each frame is zero-padded to ``2 * samples``
        points for the FFT and only the ``samples`` non-negative-frequency
        magnitudes are kept.

        Args:
            x: 1-D NumPy array holding the waveform.
            samples: Frame length in samples.

        Returns:
            Float array of shape ``(len(x) // samples, samples)``.
        """
        n_frames = x.shape[0] // samples
        xf = np.zeros((n_frames, samples))
        # NOTE(review): this bound leaves the final frame all-zero whenever
        # len(x) is an exact multiple of ``samples``. It looks like an
        # off-by-one, but it is kept because the weights loaded in __init__
        # were presumably trained on features built this way -- confirm
        # before changing.
        for start in range(0, x.shape[0] - samples, samples):
            spectrum = np.abs(np.fft.fft(x[start:start + samples], n=samples * 2))
            # For an even-length FFT the first half of the bins are exactly
            # the non-negative frequencies, so a slice replaces the fftfreq
            # mask. The row index must follow ``samples`` (it used to be a
            # hard-coded 50, which broke every other frame size).
            xf[start // samples, :] = spectrum[:samples]
        return xf

    def predict(self, wave, plot_result=False):
        """Detect onsets in ``wave``.

        Args:
            wave: Object exposing the waveform as a 1-D NumPy array in its
                ``y`` attribute.
            plot_result: When true, plot the padded waveform (scaled by 3)
                with a red vertical line at every detected onset.

        Returns:
            NumPy array of onset positions expressed in waveform samples.
        """
        signal = self.pad(wave.y, self._WAVE_LENGTH)
        n_frames = self._WAVE_LENGTH // self._FRAME_SIZE
        features = self.to_spectral(signal, self._FRAME_SIZE)
        features = features.reshape(1, n_frames, self._FRAME_SIZE)

        pred = self.model.predict(features)
        # Rise of the output between consecutive steps of the only example.
        diff = pred[0][1:] - pred[0][:-1]
        # Convert output-step indices to waveform sample positions.
        samples_per_step = self._FRAME_SIZE // self._UPSAMPLING
        onsets = np.where(diff > self._ONSET_JUMP)[0] * samples_per_step

        if plot_result:
            plt.figure(figsize=(18, 6))
            plt.plot(3 * signal)
            for onset in onsets:
                plt.axvline(onset, color="r")

        return onsets