In [1]:
import datetime
import mir_eval
import numpy as np
import os
import pandas as pd
import sys

sys.path.append("../src")
import localmodule

In [7]:
# Root folder holding the annotations and the baseline predictions.
data_dir = "/Users/vl238/spl2017_data"
dataset_name = localmodule.get_dataset_name()

# Sub-folders are named "<dataset>_annotations" and
# "<dataset>_baseline-predictions".
annotations_name = dataset_name + "_annotations"
predictions_name = dataset_name + "_baseline-predictions"
annotations_dir = os.path.join(data_dir, annotations_name)
predictions_dir = os.path.join(data_dir, predictions_name)

# Units, cross-validation folds and the labels to be ignored.
units = localmodule.get_units()
folds = localmodule.fold_units()
n_folds = len(folds)
negative_labels = localmodule.get_negative_labels()

# Number of detection thresholds to sweep.
n_thresholds = 50 # 100

# Command-line variant, disabled while running interactively:
# args = sys.argv[1:]
# fold_id = args[0]
# tolerance = int(args[1]) / 1000

# Fold under study; each fold is indexed as (test, training, validation).
fold_id = 0
fold = folds[fold_id]
test_units = fold[0]
training_units = fold[1]
val_units = fold[2]
training_val_units = training_units + val_units

# Peak statistics gathered over the training and validation units.
# unit_peak_times and unit_peak_values receive exactly one entry per unit, in
# the order of training_val_units, so they can be indexed by unit position.
# unit_minima and unit_maxima only receive entries for units that have at
# least one peak.
unit_minima = []
unit_maxima = []
unit_peak_times = []
unit_peak_values = []

for unit_str in training_val_units:
    # Load the prediction matrix stored as "<unit>.npy".
    prediction_name = unit_str + ".npy"
    prediction_path = os.path.join(predictions_dir, prediction_name)
    prediction_matrix = np.load(prediction_path)

    # Column 0 is read as timestamps, column 1 as the detection function.
    timestamps = prediction_matrix[:, 0]
    odf = prediction_matrix[:, 1]

    # Keep the time and value of every picked peak.
    peak_locations = localmodule.pick_peaks(odf)
    peak_times = timestamps[peak_locations]
    peak_values = odf[peak_locations]
    unit_peak_times.append(peak_times)
    unit_peak_values.append(peak_values)

    # np.min and np.max raise ValueError on a zero-size array, so a unit
    # without any picked peak contributes nothing to the global value range.
    if np.size(peak_values) > 0:
        unit_min = np.min(peak_values)
        unit_minima.append(unit_min)
        unit_max = np.max(peak_values)
        unit_maxima.append(unit_max)


# Range of peak values over all collected units.
global_minimum, global_maximum = min(unit_minima), max(unit_maxima)
global_delta = global_maximum - global_minimum

# Multipliers are 1 minus a log-spaced sequence going from 0.5 down to 0.1,
# i.e. they increase from 0.5 to 0.9 and get denser towards the upper end.
log_start = np.log10(0.5)
log_stop = np.log10(0.1)
threshold_multipliers = 1 - np.logspace(log_start, log_stop, n_thresholds)

# Map the multipliers onto the observed range of peak values.
thresholds = global_minimum + threshold_multipliers * global_delta

# Matching window used when detected peaks are matched to annotations.
tolerance = 0.5

# Sweep the detection thresholds on one unit and report detection metrics.
# unit_id indexes training_val_units (and the per-unit peak lists built from
# it); only the first unit is evaluated here.
for unit_id in [0]: # range(len(training_val_units))
    unit_str = training_val_units[unit_id]
    annotation_name = unit_str + ".txt"
    annotation_path = os.path.join(annotations_dir, annotation_name)
    # The separator must be passed by keyword: pandas 2.0 made every
    # read_csv argument after the path keyword-only.
    df = pd.read_csv(annotation_path, sep="\t")

    # Discard rows whose "Calls" label is one of the negative labels, then
    # represent each remaining annotation by the midpoint of its interval.
    relevant_rows = df.loc[~df["Calls"].isin(negative_labels)]
    begin_times = np.array(relevant_rows["Begin Time (s)"])
    end_times = np.array(relevant_rows["End Time (s)"])
    relevant = 0.5 * (begin_times+end_times)
    n_relevant = len(relevant)
    print("Relevant = {}".format(n_relevant))

    peak_times = unit_peak_times[unit_id]
    peak_values = unit_peak_values[unit_id]
    # Go from the highest threshold (fewest detections) to the lowest.
    for threshold in reversed(thresholds):
        selected = peak_times[np.where(peak_values>threshold)]
        # Each matched (annotation, detection) pair counts as a true positive.
        selected_relevant = mir_eval.util.match_events(relevant, selected, tolerance)
        true_positives = len(selected_relevant)
        n_selected = len(selected)
        false_positives = n_selected - true_positives
        false_negatives = n_relevant - true_positives
        # Without any true positive all three metrics are reported as zero,
        # which also avoids dividing by zero below.
        if n_selected == 0 or true_positives == 0:
            precision = 0.0
            recall = 0.0
            fmeasure = 0.0
        else:
            precision = true_positives / n_selected
            recall = true_positives / n_relevant
            fmeasure = 2*precision*recall / (precision+recall)
        print("")
        print(str(datetime.datetime.now()))
        print("Threshold = {}".format(threshold))
        print("Selected = {}".format(n_selected))
        print("True pos = {}".format(true_positives))
        print("False pos = {}".format(false_positives))
        print("False neg = {}".format(false_negatives))
        print("Precision = {}%".format(100*precision))
        print("Recall = {}%".format(100*recall))
        print("F-measure = {}%".format(100*fmeasure))

Relevant = 4730

2017-07-31 16:36:20.653315
Threshold = 0.8506
Selected = 6
True pos = 0
False pos = 6
False neg = 4730
Precision = 0.0%
Recall = 0.0%
F-measure = 0.0%

2017-07-31 16:36:20.654819
Threshold = 0.8487806904976041
Selected = 6
True pos = 0
False pos = 6
False neg = 4730
Precision = 0.0%
Recall = 0.0%
F-measure = 0.0%

2017-07-31 16:36:20.658080
Threshold = 0.8469091745746165
Selected = 7
True pos = 0
False pos = 7
False neg = 4730
Precision = 0.0%
Recall = 0.0%
F-measure = 0.0%

2017-07-31 16:36:20.659711
Threshold = 0.8449839541295062
Selected = 7
True pos = 0
False pos = 7
False neg = 4730
Precision = 0.0%
Recall = 0.0%
F-measure = 0.0%

2017-07-31 16:36:20.665219
Threshold = 0.84300348807162
Selected = 8
True pos = 0
False pos = 8
False neg = 4730
Precision = 0.0%
Recall = 0.0%
F-measure = 0.0%

2017-07-31 16:36:20.667485
Threshold = 0.8409661910875779
Selected = 9
True pos = 0
False pos = 9
False neg = 4730
Precision = 0.0%
Recall = 0.0%
F-measure = 0.0%

2017-07-31 16

KeyboardInterrupt: 