In [14]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
import seaborn as sns
from keras.models import load_model
from collections import OrderedDict
import numpy as np
import plotly
from plotly.offline import iplot
import plotly.graph_objs as go
import sys
import os
# Set TensorFlow's native log threshold to '2' (documented by TensorFlow as
# hiding INFO and WARNING messages).
# NOTE(review): this assignment runs after `from keras.models import load_model`
# above, so the backend may already be imported and the setting may come too
# late to have any effect — confirm, and move it ahead of the keras import if so.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

def f1_score(p, r):
    """Combines a precision and a recall into their F1 score.

    Args:
        p (float): the precision
        r (float): the recall

    Returns:
        The harmonic mean ``2 * p * r / (p + r)``, or 0 when both inputs are
        zero (where that expression would divide by zero).
    """
    both_zero = p == 0 and r == 0
    return 0 if both_zero else 2 * p * r / (p + r)

def compute_metric(labels, predictions, mtype="micro"):
    """Computes the precision, recall and F1 scores.

    Args:
        labels      (np.ndarray): the real labels
        predictions (np.ndarray): the model predictions, laid out like
                                  ``labels``
        mtype           (string): the type of metric; one of 'micro'
                                  or 'macro'

    Returns:
        p  (float): the precision
        r  (float): the recall
        f1 (float): the F1 score

    Raises:
        ValueError: if ``mtype`` is neither 'micro' nor 'macro'
    """
    # Validate the mode before doing any work so that a misspelled mtype is
    # reported clearly instead of falling through to an unbound result.
    if mtype not in ('micro', 'macro'):
        raise ValueError(
            "mtype must be 'micro' or 'macro', got {!r}".format(mtype))

    if mtype == 'micro':
        # Pool every entry of both arrays into one flat binary problem.
        flat_labels = labels.flatten()
        flat_predictions = predictions.flatten()
        p = precision_score(flat_labels, flat_predictions)
        r = recall_score(flat_labels, flat_predictions)
        return p, r, f1_score(p, r)

    # 'macro': score each (label, prediction) pair obtained by iterating over
    # the first axis of both arrays, then average the per-pair scores.
    # NOTE(review): this averages over rows; if rows are samples and columns
    # are classes, conventional macro averaging would iterate over columns
    # instead — confirm which one is intended.
    ps, rs, f1s = [], [], []
    for label, pred in zip(labels, predictions):
        p = precision_score(label, pred)
        r = recall_score(label, pred)
        ps.append(p)
        rs.append(r)
        f1s.append(f1_score(p, r))

    return np.mean(ps), np.mean(rs), np.mean(f1s)

def compute_metric_per_instrument(labels, predictions):
    """Computes accuracy, precision, recall and F1 for each instrument.

    The transposes of ``labels`` and ``predictions`` are walked in parallel
    with the fixed list of 11 instrument codes, so the i-th column is scored
    under the i-th code (assumes 2-D arrays whose columns follow that
    order — TODO confirm against the data preparation code).

    Args:
        labels      (np.ndarray): the real labels
        predictions (np.ndarray): the model predictions

    Returns:
        Four OrderedDicts keyed by instrument code, in this order: accuracy,
        precision, recall, F1. A code with no matching column keeps its
        initial value of 0.0.
    """
    codes = ("cel", "cla", "flu", "gac", "gel", "org", "pia", "sax", "tru",
             "vio", "voi")

    # One independent, zero-initialised table per metric.
    accuracies = OrderedDict((code, 0.0) for code in codes)
    precisions = OrderedDict((code, 0.0) for code in codes)
    recalls = OrderedDict((code, 0.0) for code in codes)
    f1_scores = OrderedDict((code, 0.0) for code in codes)

    for truth, guess, code in zip(labels.T, predictions.T, codes):
        accuracy = accuracy_score(truth, guess)
        precision = precision_score(truth, guess)
        recall = recall_score(truth, guess)

        accuracies[code] = accuracy
        precisions[code] = precision
        recalls[code] = recall
        f1_scores[code] = f1_score(precision, recall)

    return accuracies, precisions, recalls, f1_scores

In [15]:
################################ Ground Truth #################################
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only point this at files that come from a trusted source.
# The context manager closes the file even if unpickling fails.
with open("y_test.pkl", "rb") as f:
    a = list(pickle.load(f).values())

# Stack the first element of every pickled value into one row of an
# (n_entries, 11) array (11 matches the number of instrument codes used by
# compute_metric_per_instrument).
y_test = np.zeros((len(a), 11))
for i, entry in enumerate(a):
    y_test[i] = entry[0]
###############################################################################


############################### Plotly figures ################################
# subplot_titles takes a sequence of titles; the trailing comma makes this a
# one-element tuple rather than a bare string.
fig_per_inst = plotly.tools.make_subplots(rows=1, cols=1,
        subplot_titles=("Accuracy vs F-1 Score",))
###############################################################################

# NOTE(security): same pickle caveat as above — trusted files only.
with open("predictions", "rb") as f:
    predictions_all_thresholds = pickle.load(f)

# Score every threshold's predictions and remember which one achieves the
# highest macro F1; that threshold is used for the per-instrument breakdown.
micro, macro = [], []
max_best_threshold, argmax_best_threshold = -1, -1
for predictions_threshold in predictions_all_thresholds:
    micro.append(compute_metric(y_test, predictions_threshold))
    macro.append(compute_metric(y_test, predictions_threshold, mtype='macro'))
    if macro[-1][2] > max_best_threshold:       #Useful for per_instrument
        argmax_best_threshold = len(macro) - 1
        max_best_threshold = macro[-1][2]

# NOTE(review): if no macro F1 ever exceeds -1 (e.g. every value is NaN),
# argmax_best_threshold stays -1 and the last threshold is selected below.
legend = True
inst_a, inst_p, inst_r, inst_f1 = compute_metric_per_instrument(
    y_test, predictions_all_thresholds[argmax_best_threshold])

# One bar per instrument code: with histfunc="sum" and a single y value per x
# category, each bar's height is that instrument's score.
x = list(inst_p.keys())
per_inst = [
        go.Histogram(
            histfunc = "sum",
            y = list(inst_a.values()),
            x = x,
            marker=dict(
                color='blue'
                ),
            showlegend = legend,
            name = "accuracy"
            ),
        go.Histogram(
            histfunc = "sum",
            y = list(inst_f1.values()),
            x = x,
            marker=dict(
                color='orange'
                ),
            showlegend = legend,
            name = "f1"
            )
        ]

# Both traces share the single subplot at row 1, column 1.
fig_per_inst.append_trace(per_inst[0], 1, 1)
fig_per_inst.append_trace(per_inst[1], 1, 1)


FileNotFoundError: [Errno 2] No such file or directory: 'y_test.pkl'

In [None]:
# Display the accuracy / F1 per-instrument figure assembled in the previous
# cell; this requires that cell to have run successfully so that
# fig_per_inst exists.
iplot(fig_per_inst, filename = "per_inst.html")