In [1]:
import os
import numpy as np
import pandas as pd

import warnings
import tensorflow as tf
import matplotlib.pyplot as plt
from tqdm import tqdm

import wavio
from scipy.io import wavfile
from librosa.core import resample, to_mono
from glob import glob
import sounddevice as sd

import kapre
from kapre.composed import get_melspectrogram_layer
from kapre.time_frequency import STFT, Magnitude, ApplyFilterbank, MagnitudeToDecibel
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
from tensorflow.keras.layers import TimeDistributed, LayerNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import load_model

In [2]:
# Notebook-wide settings. dt, sample_f and threshold are bound as default
# arguments by the functions defined in later cells.
src_root = 'raw_data'  # not referenced in the visible cells — presumably the raw-audio folder; TODO confirm
dst_root = 'clean'  # not referenced in the visible cells — presumably the cleaned-audio folder; TODO confirm
dt = 1.0  # analysis window length; multiplied by the sample rate to get samples per window
sample_f = 44100  # sample rate that audio is resampled to before prediction
dummy_file = 'finger_snaps_1_4'  # not referenced in the visible cells
threshold = 120  # envelope() keeps only samples whose rolling mean magnitude exceeds this

In [3]:
# Load the saved network. The kapre layer classes are passed through
# `custom_objects` so Keras can resolve them by name while deserialising.
model = load_model(
    "models/conv2d.h5",
    custom_objects={
        'STFT': STFT,
        'Magnitude': Magnitude,
        'ApplyFilterbank': ApplyFilterbank,
        'MagnitudeToDecibel': MagnitudeToDecibel,
    },
)

In [4]:
def downsample_mono(path, sf):
    """Read a WAV file, mix it down to a single channel and resample it.

    Args:
        path: Path of the WAV file; passed straight to ``wavio.read``.
        sf: Target sample rate handed to ``resample`` as ``target_sr``.

    Returns:
        Tuple ``(sf, wav)`` where ``wav`` is a 1-D ``np.int16`` array.
    """
    obj = wavio.read(path)
    # Fortran order makes the transpose taken below C-contiguous.
    wav = obj.data.astype(np.float32, order='F')
    rate = obj.rate

    if wav.ndim > 1 and wav.shape[1] > 1:
        # Mix down *any* multi-channel layout, not only stereo. Leaving a
        # (samples, channels) array in place would make resample() operate
        # along the channel axis instead of along time.
        wav = to_mono(wav.T)
    else:
        # Already a single channel: just drop the channel axis if present.
        wav = to_mono(wav.reshape(-1))

    wav = resample(wav, orig_sr=rate, target_sr=sf)

    # Saturate before the integer cast: samples outside the int16 range
    # (wider source bit depth, resampler overshoot) would otherwise wrap
    # around and turn loud peaks into garbage.
    int16_range = np.iinfo(np.int16)
    wav = np.clip(wav, int16_range.min, int16_range.max).astype(np.int16)
    return sf, wav

In [5]:
def envelope(y, rate, t = threshold):
    """Flag the samples whose smoothed magnitude exceeds a threshold.

    The absolute signal is averaged over a centred rolling window of
    ``int(rate/30)`` samples (partial windows at the edges are allowed via
    ``min_periods=1``).

    Args:
        y: 1-D sequence of samples.
        rate: Sample rate; only used to size the rolling window.
        t: Threshold the rolling mean must exceed (strictly) to be kept.

    Returns:
        Tuple ``(mask, y_mean)``: ``mask`` is a list of Python bools, one per
        sample, and ``y_mean`` is the pandas Series of rolling means.
    """
    magnitude = np.abs(pd.Series(y))
    window = int(rate/30)
    y_mean = magnitude.rolling(window=window, min_periods=1, center=True).mean()
    mask = [bool(level > t) for level in y_mean]
    return mask, y_mean

In [6]:
def m_pred(wav,model=model,sf=sample_f, dt=dt):
    """Run the model on one audio clip and return the winning class.

    The clip is reshaped to a single column and wrapped in a batch of size
    one. Clips shorter than ``int(sf*dt)`` samples are zero-padded at the end
    to that length; longer clips are passed through at their full length.

    Args:
        wav: NumPy array holding the clip's samples.
        model: Object exposing ``predict``; defaults to the notebook's model.
        sf: Sample rate used to compute the minimum clip length.
        dt: Window duration used to compute the minimum clip length.

    Returns:
        Tuple ``(y_pred, y_mean)``: the argmax index and the prediction
        averaged over the batch axis.
    """
    min_len = int(sf*dt)
    column = wav.reshape(-1, 1)
    n_samples = column.shape[0]

    # Zero-initialised one-element batch: whatever the clip does not fill
    # stays as trailing padding.
    X_batch = np.zeros(shape=(1, max(n_samples, min_len), 1), dtype=np.float32)
    X_batch[0, :n_samples, :] = column

    scores = model.predict(X_batch)
    y_mean = np.mean(scores, axis=0)
    y_pred = np.argmax(y_mean)
    return y_pred, y_mean

In [35]:
# Defaults for predict_wav() and ltbm(): the folder that is searched
# recursively, and a substring the chosen file path must contain.
pr = 'predict'
pr_file_name = 'p1'

In [36]:
def predict_wav(src = pr, fn = pr_file_name, sf = sample_f,t=threshold,dt=dt):
    """Classify a recording window by window and collect the class-0 windows.

    The first file under ``src`` (searched recursively, in sorted order)
    whose path contains ``fn`` is loaded, resampled to ``sf`` and cut into
    consecutive windows of ``int(dt*rate)`` samples. A trailing partial
    window is dropped. In each window the samples rejected by ``envelope``
    are removed, the remainder is left-aligned in a zero-filled buffer of
    the window length, and that buffer is classified with ``m_pred``.

    Args:
        src: Directory to search.
        fn: Substring that must occur in the file path.
        sf: Sample rate to resample to.
        t: Threshold forwarded to ``envelope``.
        dt: Window duration.

    Returns:
        List of 1-based window numbers whose predicted class index is 0.

    Raises:
        FileNotFoundError: If no file under ``src`` matches ``fn``.
        ValueError: If ``dt`` and the sample rate give an empty window.
    """
    # Sorted for a deterministic pick; directories are excluded because the
    # recursive glob returns them too and they can match ``fn``.
    candidates = sorted(glob('{}/**'.format(src), recursive=True))
    wav_path = [x for x in candidates if fn in x and os.path.isfile(x)]
    if not wav_path:
        raise FileNotFoundError(
            'no file containing {!r} found under {!r}'.format(fn, src))
    rate, wav = downsample_mono(wav_path[0], sf)

    delta_sample = int(dt*rate)
    if delta_sample <= 0:
        raise ValueError('dt * sample rate must be at least one sample')
    trunc = wav.shape[0] % delta_sample

    seconds = []

    for cnt, start in enumerate(range(0, wav.shape[0]-trunc, delta_sample)):
        raw_sample = wav[start:start+delta_sample]
        mask, _ = envelope(raw_sample, rate, t)
        noise_free_sample = raw_sample[mask]
        sample = np.zeros(shape = (delta_sample,),dtype = np.int16)
        sample[:noise_free_sample.shape[0]] = noise_free_sample
        # Forward the caller's rate and window length; relying on m_pred's
        # defaults would silently use the notebook globals instead.
        y_pred, y_mean = m_pred(sample, sf=rate, dt=dt)
        if y_pred == 0:
            seconds.append(cnt+1)
        print(f"{cnt}:{cnt+1}  {y_pred}   {y_mean}")

    return seconds

In [37]:
# Classify the default recording; ltbm() reads `pred` as a global.
pred = predict_wav()

0:1  1   [0.01952885 0.98047113]
1:2  1   [0.00236903 0.99763095]
2:3  1   [0.00209377 0.9979062 ]
3:4  1   [0.01742334 0.98257667]
4:5  1   [0.2336937 0.7663063]
5:6  1   [0.01120897 0.988791  ]
6:7  1   [0.00525028 0.9947497 ]
7:8  1   [0.03295229 0.9670477 ]
8:9  1   [0.00242132 0.9975787 ]
9:10  1   [0.00736307 0.9926369 ]
10:11  1   [0.00126906 0.99873096]
11:12  1   [0.01271176 0.98728824]
12:13  1   [0.02675934 0.9732406 ]
13:14  1   [0.01569761 0.9843024 ]
14:15  0   [0.9957163  0.00428369]
15:16  1   [0.13545096 0.86454904]
16:17  1   [0.15179084 0.84820914]
17:18  1   [0.0401395 0.9598605]
18:19  1   [0.24288289 0.75711715]
19:20  1   [0.0218384 0.9781616]
20:21  1   [0.01696103 0.983039  ]
21:22  1   [0.0350857 0.9649143]
22:23  1   [0.15179084 0.84820914]
23:24  1   [0.05619758 0.94380236]
24:25  1   [0.01117476 0.98882526]
25:26  0   [0.92171645 0.07828353]
26:27  0   [0.55029196 0.44970804]
27:28  1   [0.15179084 0.84820914]
28:29  1   [0.47675008 0.5232499 ]
29:30  1   [

In [38]:
import time

def ltbm(src = pr, fn = pr_file_name, sf = sample_f, snaps = None):
    """Play a recording and print "Snap" as each flagged second elapses.

    The file is located and loaded the same way as in ``predict_wav``.
    Playback is started with ``sd.play``, which is not waited on here, and
    the loop then ticks once per second of audio.

    Args:
        src: Directory to search.
        fn: Substring that must occur in the file path.
        sf: Sample rate to resample to; also used for playback.
        snaps: Iterable of 1-based second numbers to announce, as returned
            by ``predict_wav``. Defaults to the notebook global ``pred``.

    Raises:
        FileNotFoundError: If no file under ``src`` matches ``fn``.
    """
    if snaps is None:
        # Backward-compatible fallback to the result of an earlier cell.
        snaps = pred

    # Sorted for a deterministic pick; directories are excluded because the
    # recursive glob returns them too and they can match ``fn``.
    candidates = sorted(glob('{}/**'.format(src), recursive=True))
    wav_path = [x for x in candidates if fn in x and os.path.isfile(x)]
    if not wav_path:
        raise FileNotFoundError(
            'no file containing {!r} found under {!r}'.format(fn, src))
    rate, wav = downsample_mono(wav_path[0], sf)

    flagged = set(snaps)
    sd.play(wav, rate)
    print(snaps)

    started = time.monotonic()
    # Seconds are numbered 1..n like predict_wav's labels. Counting from 0
    # would never reach a snap flagged in the final second.
    for sec in range(1, wav.shape[0] // rate + 1):
        # Sleep until the absolute deadline so per-iteration overhead does
        # not accumulate into drift against the audio.
        time.sleep(max(0.0, started + sec - time.monotonic()))
        if sec in flagged:
            print("Snap")
            

In [39]:
# Play the default recording and announce the seconds flagged in `pred`.
ltbm()



[15, 26, 27]
Snap
Snap
Snap
