In [1]:
import datetime
import h5py
import librosa
import mir_eval
import numpy as np
import os
import pandas as pd
import peakutils
import scipy.signal
import soundfile as sf
import sys
sys.path.append('../src')
import localmodule

from matplotlib import pyplot as plt

In [2]:
# Define constants.
data_dir = localmodule.get_data_dir()
dataset_name = localmodule.get_dataset_name()
models_dir = localmodule.get_models_dir()
units = localmodule.get_units()
n_units = len(units)
n_false_alarms = 20
test_unit_id = 2
aug_kind_str = "all"

threshold_ids =\
      [[103, 102,  79, 115,   0,  91,   0, 106,   0, 135],
       [205,   0, 207,   0, 189, 170, 206, 193, 191, 184],
       [143,   0, 137, 150,   0, 158,   0, 145, 154, 118],
       [130, 158,   0, 136, 162, 133, 174, 134, 158, 140],
       [134, 131, 203, 126, 140, 131, 123, 121, 125, 139],
       [159, 186, 130, 133, 140, 161, 170, 141, 166, 168]]
threshold_ids = np.array(threshold_ids)

trials =\
      [[4, 8, 9, 6, 3, 0, 2, 1, 5, 7],
       [1, 3, 0, 9, 2, 6, 4, 8, 5, 7],
       [1, 4, 6, 8, 0, 9, 7, 5, 3, 2],
       [2, 9, 1, 5, 6, 3, 4, 8, 7, 0],
       [2, 8, 1, 0, 4, 9, 5, 7, 3, 6],
       [4, 9, 8, 6, 1, 2, 5, 3, 7, 0]]
trials = np.array(trials)

thresholds = 1.0 - np.concatenate((
    np.logspace(-9, -2, 141), np.delete(np.logspace(-2, 0, 81), 0)
))
n_thresholds = len(thresholds)

tolerance = 0.5 # in seconds

In [4]:
# Define directory for annotations.
annotations_name = "_".join([dataset_name, "annotations"])
annotations_dir = os.path.join(data_dir, annotations_name)


# Load annotation.
test_unit_str = units[test_unit_id]
annotation_path = os.path.join(annotations_dir,
    test_unit_str + ".txt")
annotation = pd.read_csv(annotation_path, '\t')
begin_times = np.array(annotation["Begin Time (s)"])
end_times = np.array(annotation["End Time (s)"])
relevant = 0.5 * (begin_times + end_times)
relevant = np.sort(relevant)
high_freqs = np.array(annotation["High Freq (Hz)"])
low_freqs = np.array(annotation["Low Freq (Hz)"])
mid_freqs = 0.5 * (high_freqs + low_freqs)
n_relevant = len(relevant)


# Define directory for model.
aug_str = "all"
model_name = "icassp-convnet"
if not aug_kind_str == "none":
    model_name = "_".join([model_name, "aug-" + aug_kind_str])
model_dir = os.path.join(models_dir, model_name)


# Select trial maximizing validation accuracy.
trial_id = trials[test_unit_id, -1]


# Selet threshold maximizing validation accuracy.
threshold_id = threshold_ids[test_unit_id, trial_id]


# Load prediction.
test_unit_str = units[test_unit_id]
unit_dir = os.path.join(model_dir, test_unit_str)
trial_str = "trial-" + str(trial_id)
trial_dir = os.path.join(unit_dir, trial_str)
prediction_name = "_".join([
    dataset_name,
    model_name,
    "test-" + test_unit_str,
    trial_str,
    "predict-" + test_unit_str,
    "full-predictions.csv"])
prediction_path = os.path.join(trial_dir, prediction_name)
prediction_df = pd.read_csv(prediction_path)
odf = np.array(prediction_df["Predicted probability"])
timestamps = np.array(prediction_df["Timestamp"])


# Select peaks.
threshold = thresholds[threshold_id]
peak_locations = peakutils.indexes(odf, thres=threshold, min_dist=3)
selected = timestamps[peak_locations]


# Match events.
selected_relevant = mir_eval.util.match_events(relevant, selected, tolerance)
tp_relevant_ids = list(zip(*selected_relevant))[0]
tp_relevant_times = [relevant[i] for i in tp_relevant_ids]
tp_selected_ids = list(zip(*selected_relevant))[1]
tp_selected_times = [selected[i] for i in tp_selected_ids]


# Find false negatives
fn_times = [relevant[i] for i in range(len(relevant))
    if i not in tp_relevant_ids]


# Sample false alarms uniformly in relative time (*not* physical time)
downsampling = int(len(fn_times) / n_false_alarms)
fa_times = fn_times[::downsampling]
if len(fa_times) > n_false_alarms:
    fa_times = fa_times[:n_false_alarms]

0.992920542156 7898 9113 3112


[1207.4754380000002,
 6287.2804070000002,
 9088.526170000001,
 14003.949269999999,
 20273.97006,
 24354.39515,
 28633.012220000001,
 32685.898550000002,
 34836.239730000001,
 36661.494579999999,
 37677.98173,
 38217.127869999997,
 38345.342089999998,
 38459.4833,
 38531.461760000006,
 38598.28787,
 38696.585269999996,
 38777.065819999996,
 38852.247019999995,
 38954.255410000005]