In [1]:
from tensorflow.keras.models import load_model
from kapre.time_frequency import STFT, Magnitude, ApplyFilterbank, MagnitudeToDecibel
from sklearn.preprocessing import LabelEncoder
import numpy as np
from clean import envelope,downsample_mono
from glob import glob
import argparse
import wavio
import os
import pandas as pd
from tqdm import tqdm



In [2]:
# Echo the working directory: the relative paths used below
# ('models/cnn.h5', './test/Piano', 'logs') are resolved against it.
working_dir = os.getcwd()
print(working_dir)

/home/andy/NN/Licenta


In [3]:
def _split_into_windows(signal, step):
    """Cut a signal into consecutive ``step``-sample windows for the model.

    The signal is flattened to 1-D and split into ceil(len / step) windows;
    the last window is zero-padded up to ``step`` samples. An empty signal
    yields a single all-zero window, so the caller always has model input.

    Args:
        signal: array-like of samples (flattened before windowing).
        step: window length in samples; must be positive.

    Returns:
        float32 ndarray of shape ``(n_windows, step, 1)``.

    Raises:
        ValueError: if ``step`` is not positive (e.g. ``sr * dt < 1``).
    """
    if step <= 0:
        raise ValueError('window length must be positive, got {}'.format(step))
    samples = np.asarray(signal).reshape(-1)
    n_windows = max(1, -(-samples.shape[0] // step))  # ceil division, at least one
    windows = np.zeros(n_windows * step, dtype=np.float32)
    windows[:samples.shape[0]] = samples  # tail stays zero -> padded last window
    return windows.reshape(n_windows, step, 1)


def make_prediction(args):
    """Classify every .wav file under ``args.src_dir`` and save per-file scores.

    Each file is loaded with ``downsample_mono`` at ``args.sr`` and masked with
    ``envelope`` (``args.threshold``). What remains is cut into windows of
    ``int(args.sr * args.dt)`` samples (last one zero-padded) and run through
    the Keras model at ``args.model_fn``. The model outputs of all windows are
    averaged, and the argmax of that mean is printed as the predicted class.

    Args:
        args: namespace providing ``model_fn``, ``src_dir``, ``sr``, ``dt``,
            ``threshold`` and ``pred_fn``.

    Raises:
        ValueError: if ``int(args.sr * args.dt)`` is not positive.

    Side effects:
        Prints one "Actual class / Predicted class" line per file. Writes the
        stacked mean scores (one row per file, in sorted path order) to
        ``logs/<pred_fn>`` with ``np.save``, creating ``logs/`` if missing.
    """
    model = load_model(args.model_fn,
        custom_objects={'STFT': STFT,
                        'Magnitude': Magnitude,
                        'ApplyFilterbank': ApplyFilterbank,
                        'MagnitudeToDecibel': MagnitudeToDecibel})
    # Keep only regular files whose name ends in .wav (any case). A plain
    # "'.wav' in path" check also matched directories like 'x.wav/' and files
    # like 'a.wav.bak', which would then be handed to downsample_mono.
    found = glob('{}/**'.format(args.src_dir), recursive=True)
    wav_paths = sorted(p.replace(os.sep, '/') for p in found
                       if os.path.isfile(p) and p.lower().endswith('.wav'))
    if not wav_paths:
        print('No .wav files found under {}'.format(args.src_dir))
    # classes[i] is the name printed when the averaged output's argmax is i.
    classes = ['Other', 'Piano']
    # Window length in samples; the same for every file.
    step = int(args.sr * args.dt)
    results = []

    for wav_fn in tqdm(wav_paths, total=len(wav_paths)):
        rate, wav = downsample_mono(wav_fn, args.sr)
        mask, _ = envelope(wav, rate, threshold=args.threshold)
        clean_wav = wav[mask]
        if clean_wav.shape[0] == 0:
            # An empty batch would make model.predict fail. Predict on one
            # silent (zero-padded) window instead, so results keeps exactly
            # one row per entry of wav_paths.
            print('Warning: no samples above threshold in {}'.format(wav_fn))
        X_batch = _split_into_windows(clean_wav, step)
        window_scores = model.predict(X_batch)
        # Average the per-window outputs into one score vector for the file.
        y_mean = np.mean(window_scores, axis=0)
        pred_idx = np.argmax(y_mean)
        # The parent directory name is reported as the actual class.
        real_class = os.path.dirname(wav_fn).split('/')[-1]
        print('Actual class: {}, Predicted class: {}'.format(real_class, classes[pred_idx]))
        results.append(y_mean)

    # np.save does not create missing directories.
    os.makedirs('logs', exist_ok=True)
    np.save(os.path.join('logs', args.pred_fn), np.array(results))

In [4]:
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Audio Classification Prediction')
    parser.add_argument('--model_fn', type=str, default='models/cnn.h5',
                        help='model file to make predictions')
    parser.add_argument('--pred_fn', type=str, default='y_pred',
                        help='fn to write predictions in logs dir')
    parser.add_argument('--src_dir', type=str, default='./test/Piano',
                        help='directory containing wavfiles to predict')
    parser.add_argument('--dt', type=float, default=1.0,
                        help='time in seconds to sample audio')
    parser.add_argument('--sr', type=int, default=16000,
                        help='sample rate of clean audio')
    # type=float (previously str): a command-line value such as
    # "--threshold 20" must reach envelope() as a number, not the string "20".
    parser.add_argument('--threshold', type=float, default=50,
                        help='threshold magnitude for np.int16 dtype')
    # parse_known_args tolerates unrecognised argv entries, presumably the
    # flags the Jupyter kernel itself is launched with when run in a notebook.
    args, _ = parser.parse_known_args()
    print(args)
    make_prediction(args)

Namespace(dt=1.0, model_fn='models/cnn.h5', pred_fn='y_pred', sr=16000, src_dir='./test/Piano', threshold=50)


 33%|███▎      | 1/3 [00:05<00:11,  5.91s/it]

Actual class: Piano, Predicted class: Piano


 67%|██████▋   | 2/3 [00:24<00:13, 13.37s/it]

Actual class: Piano, Predicted class: Piano


100%|██████████| 3/3 [00:34<00:00, 11.49s/it]

Actual class: Piano, Predicted class: Other





In [5]:
# Show the entries of the prediction source directory (in the order the OS returns them).
src_entries = os.listdir(args.src_dir)
print(src_entries)

['Edith Piaf - Non, Je ne regrette rien - EASY Piano Tutorial by PlutaX.wav', "GUNS N' ROSES - NOVEMBER RAIN - Piano Tutorial.wav", 'Great White - The Angel Song.wav']
