In [35]:
import os
import pickle
import numpy as np
import tensorflow as tf
import madmom
import mir_eval

from modules.labels import get_label_vector
from modules.madmom_cnn_prep import cnn_preprocessor
from datasets import Dataset
from modules.analysis_funcs import get_idx_to_fold, get_segmented_data, get_test_peaks
from analyze_detection import evaluate

%load_ext autoreload
%autoreload 2

FPS = 100
CONTEXT = 7

# Load Madmom normalization
def cnn_normalize(frames):
    inv_std = np.load("models/bock2013pret_inv_std.npy")
    mean = np.load("models/bock2013pret_mean.npy")
    frames_normalized = (frames - np.reshape(mean, (1,80,3)))*np.reshape(inv_std, (1,80,3))
    return frames_normalized

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
ds0 = Dataset("initslurtest")
ds1 = Dataset("slurtest_add_1")

audio_fnames = ds0.get_audio_paths() + ds1.get_audio_paths()
label_fnames = ds0.get_annotation_paths() + ds1.get_annotation_paths()

audios = [madmom.audio.signal.load_wave_file(filename)[0] for filename in audio_fnames]
sample_rates = [madmom.audio.signal.load_wave_file(filename)[1] for filename in audio_fnames]
onset_schedules = [np.loadtxt(label_fname, usecols=0) for label_fname in label_fnames]

  file_sample_rate, signal = wavfile.read(filename, mmap=True)


In [38]:
base_path = "results/cnn-training-220409/"
folds_path = base_path + "folds.pkl"

model_name = "added-sample-gen-nostandard"

with open(folds_path, "rb") as f:
    folds = pickle.load(f)

itf = get_idx_to_fold(folds)

TOL = 0.025


In [39]:
CD_list = []
FN_list = []
FP_list = []
for r in range(len(itf.keys())):
    fold = itf[r]
    rec_name = os.path.basename(audio_fnames[r])
    x = get_segmented_data(audio_fnames[r])
    training_name = "finetune-dense-dropout"
    model = tf.keras.models.load_model(base_path + "fold_" + str(fold) + "_" + model_name + "_model")
    out = model.predict(x)
    peaks = get_test_peaks(out, 1./FPS)
    [CD,FN,FP,doubles,merged] = evaluate(onset_schedules[r], peaks, tol_sec=TOL)
    CD_list.append(CD)
    FN_list.append(FN)
    FP_list.append(FP)

    scores = mir_eval.onset.evaluate(onset_schedules[r], peaks, window=TOL)
    print(scores.keys())
    print(rec_name + "\t" + "F-score: {:.2f}".format(100*scores["F-measure"]))



odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest01.wav	F-score: 88.06
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest02.wav	F-score: 91.73
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest03.wav	F-score: 92.19
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest04.wav	F-score: 91.72
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest05.wav	F-score: 83.12
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest06.wav	F-score: 96.49
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest07.wav	F-score: 94.42
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest08.wav	F-score: 94.18
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest09.wav	F-score: 84.97
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest10.wav	F-score: 49.18
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest11.wav	F-score: 64.00
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest12.wav	F-score: 67.42


  file_sample_rate, signal = wavfile.read(filename, mmap=True)


odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest13.wav	F-score: 81.19
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest14.wav	F-score: 82.35
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest15.wav	F-score: 86.61
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest16.wav	F-score: 81.78
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest17.wav	F-score: 60.71
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest18.wav	F-score: 51.28
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest19.wav	F-score: 90.95
odict_keys(['F-measure', 'Precision', 'Recall'])
stormhatten_IR2.wav	F-score: 81.94
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest02_IR1.wav	F-score: 94.49
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest01_IR2.wav	F-score: 93.02
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest03_FK1.wav	F-score: 94.49
odict_keys(['F-measure', 'Precision', 'Recall'])
6xtpsg_220319.wav	F-score: 78.72
odict_keys(['F-measure', 'Precis

  file_sample_rate, signal = wavfile.read(filename, mmap=True)


odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest04_FK1.wav	F-score: 93.51
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest01_IR1.wav	F-score: 92.06
odict_keys(['F-measure', 'Precision', 'Recall'])
63an_start_220306.wav	F-score: 81.82
odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest08_FK1.wav	F-score: 96.91


  file_sample_rate, signal = wavfile.read(filename, mmap=True)


odict_keys(['F-measure', 'Precision', 'Recall'])
slurtest03_IR1.wav	F-score: 90.91
odict_keys(['F-measure', 'Precision', 'Recall'])
stormhatten_IR1.wav	F-score: 87.79


In [40]:
np.sum(CD_list)/(np.sum(CD_list)+.5*(np.sum(FP_list) + np.sum(FN_list)))

0.8505654281098546