In [15]:
import pickle
from sklearn.metrics import roc_curve, precision_recall_curve, roc_auc_score, average_precision_score
import matplotlib.pyplot as plt
from evaluator import binarize_Y, interpolate_precision, get_interpolated_avg_precision
import itertools

In [16]:
# Base directory for saved figures; the plotting helpers append 'ovr/' or 'ovo/'.
outdir = 'nonrandom_plots/'

# Pickled test-set predictions for the two models being compared, plus the shared ground truth.
best_predfile = 'best_results/best_nonrandom/Y_test_pred.pickle'
actual_Yfile = 'best_results/best_nonrandom/Y_test.pickle'
base_predfile = 'results/nb_2019_03_11_23_41/Y_test_pred.pickle'

# Legend names for the two models (best model first, baseline second).
best_mdl, base_mdl = 'Regularized SVM', 'NB'

In [17]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file handle is closed deterministically via ``with`` (the bare
    ``pickle.load(open(...))`` form leaves the handle open until GC).

    NOTE: ``pickle.load`` can execute arbitrary code; only use this on trusted,
    locally produced result files such as the paths configured above.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


best_pred = _load_pickle(best_predfile)
base_pred = _load_pickle(base_predfile)
actual_Y = _load_pickle(actual_Yfile)

In [18]:
# (predictions, legend name) per model; the two derived lists stay index-aligned.
_model_pairs = [(best_pred, best_mdl), (base_pred, base_mdl)]
Y_predicted_lst = [preds for preds, _name in _model_pairs]
mdlname_lst = [name for _preds, name in _model_pairs]
# Label ids 0-4 (not referenced in the cells shown here; kept unchanged).
ovr_labels = list(range(5))

# General label set: the pooled brain group (0, 1, 2) first, then every single label 0-5.
GEN_LABELS = [(0, 1, 2)] + list(range(6))
GEN_LAB_STR = {
    (0, 1, 2): 'brain',
    0: 'forebrain',
    1: 'midbrain',
    2: 'hindbrain',
    3: 'heart',
    4: 'limb',
    5: 'others',
}

# Fine-grained labels exclude 'others' (5); display names are taken from
# GEN_LAB_STR so the two mappings cannot drift apart.
FINE_LABELS = list(range(5))
FINE_LAB_STR = {lab: GEN_LAB_STR[lab] for lab in FINE_LABELS}

In [19]:
def plot_ovr(Y_actual, Y_predicted_lst, mdlname_lst, pos_labels_lst, label_to_str, outdir, metric = 'auc'):
    """Plot one-vs-rest ROC or precision-recall curves, one figure per positive group.

    For each entry of ``pos_labels_lst`` the labels are binarised with
    ``binarize_Y`` (that entry vs. everything else). One curve per model is drawn
    on a shared figure, which is saved to ``<outdir>/ovr/<metric>_<name>``.

    Parameters
    ----------
    Y_actual : ground-truth labels, passed unchanged to ``binarize_Y``.
    Y_predicted_lst : per-model predictions, index-aligned with ``mdlname_lst``.
    mdlname_lst : model names shown in the legend.
    pos_labels_lst : positive classes to plot; each entry is a single label or a
        tuple of labels pooled into one positive class.
    label_to_str : maps each entry of ``pos_labels_lst`` to a display name, which
        is also used in the output filename.
    outdir : base output directory; the ``ovr`` subdirectory is created if missing.
    metric : ``'auc'`` for ROC curves or ``'prc'`` for precision-recall curves.

    Raises
    ------
    ValueError
        If ``metric`` is neither ``'auc'`` nor ``'prc'``.
    """
    import os

    titles = {'auc': 'ROC curve', 'prc': 'Precision-recall curve'}
    if metric not in titles:
        # Previously an unknown metric surfaced later as an unbound-name error.
        raise ValueError("metric must be 'auc' or 'prc', got %r" % (metric,))
    title = titles[metric]

    ovr_dir = os.path.join(outdir, 'ovr')
    os.makedirs(ovr_dir, exist_ok=True)  # savefig does not create missing directories

    for pos_labels in pos_labels_lst:
        labelstr = label_to_str[pos_labels]
        # A tuple pools several labels into a single positive class.
        pos_label_list = list(pos_labels) if isinstance(pos_labels, tuple) else [pos_labels]

        plt.clf()
        plt.title('%s (%s vs. others)' % (title, labelstr))
        for Y_predicted, mdlname in zip(Y_predicted_lst, mdlname_lst):
            bin_Y_actual, bin_Y_pred = binarize_Y(Y_actual, Y_predicted, pos_label_list)
            if metric == 'auc':
                metric_estimate = roc_auc_score(bin_Y_actual, bin_Y_pred)
                fpr, tpr, _ = roc_curve(bin_Y_actual, bin_Y_pred)
                plt.plot(fpr, tpr, label='%s (%.2f)' % (mdlname, metric_estimate))
            else:
                metric_estimate = get_interpolated_avg_precision(bin_Y_actual, bin_Y_pred)
                # precision_recall_curve returns (precision, recall, thresholds);
                # recall goes on the x-axis to match the axis labels below.
                precision, recall, _ = precision_recall_curve(bin_Y_actual, bin_Y_pred)
                # NOTE(review): assumes interpolate_precision returns an array
                # aligned element-wise with `recall` — confirm in evaluator.
                precision = interpolate_precision(precision)
                plt.plot(recall, precision, label='%s (%.2f)' % (mdlname, metric_estimate))

        if metric == 'auc':
            plt.plot([0, 1], [0, 1], color='black')  # chance-level diagonal
            plt.ylabel('True positive rate')
            plt.xlabel('False positive rate')
            plt.legend(loc='lower right')
        else:
            plt.ylabel('Precision')
            plt.xlabel('Recall')
            plt.legend(loc='upper right')
        plt.savefig(os.path.join(ovr_dir, '%s_%s' % (metric, labelstr)))
        
def plot_ovo(Y_actual, Y_predicted_lst, mdlname_lst, labels, label_to_str, outdir, metric = 'auc'):
    """Plot one-vs-one ROC or precision-recall curves, one figure per label pair.

    For every unordered pair ``(a, b)`` from ``itertools.combinations(labels, 2)``,
    ``a`` is the positive label and ``b`` the negative label passed to
    ``binarize_Y``. One curve per model is drawn on a shared figure, which is
    saved to ``<outdir>/ovo/<metric>_<a>_vs_<b>``.

    Parameters
    ----------
    Y_actual : ground-truth labels, passed unchanged to ``binarize_Y``.
    Y_predicted_lst : per-model predictions, index-aligned with ``mdlname_lst``.
    mdlname_lst : model names shown in the legend.
    labels : labels to pair up.
    label_to_str : maps each label to a display name, which is also used in the
        output filename.
    outdir : base output directory; the ``ovo`` subdirectory is created if missing.
    metric : ``'auc'`` for ROC curves or ``'prc'`` for precision-recall curves.

    Raises
    ------
    ValueError
        If ``metric`` is neither ``'auc'`` nor ``'prc'``.
    """
    import itertools
    import os

    titles = {'auc': 'ROC curve', 'prc': 'Precision-recall curve'}
    if metric not in titles:
        # Previously an unknown metric surfaced later as an unbound-name error.
        raise ValueError("metric must be 'auc' or 'prc', got %r" % (metric,))
    title = titles[metric]

    ovo_dir = os.path.join(outdir, 'ovo')
    os.makedirs(ovo_dir, exist_ok=True)  # savefig does not create missing directories

    for pos_label, neg_label in itertools.combinations(labels, 2):
        pos_str = label_to_str[pos_label]
        neg_str = label_to_str[neg_label]

        plt.clf()
        plt.title('%s (%s vs. %s)' % (title, pos_str, neg_str))
        for Y_predicted, mdlname in zip(Y_predicted_lst, mdlname_lst):
            bin_Y_actual, bin_Y_pred = binarize_Y(Y_actual, Y_predicted, [pos_label], [neg_label])
            if metric == 'auc':
                metric_estimate = roc_auc_score(bin_Y_actual, bin_Y_pred)
                fpr, tpr, _ = roc_curve(bin_Y_actual, bin_Y_pred)
                plt.plot(fpr, tpr, label='%s (%.2f)' % (mdlname, metric_estimate))
            else:
                metric_estimate = get_interpolated_avg_precision(bin_Y_actual, bin_Y_pred)
                # precision_recall_curve returns (precision, recall, thresholds);
                # recall goes on the x-axis to match the axis labels below.
                precision, recall, _ = precision_recall_curve(bin_Y_actual, bin_Y_pred)
                # NOTE(review): assumes interpolate_precision returns an array
                # aligned element-wise with `recall` — confirm in evaluator.
                precision = interpolate_precision(precision)
                plt.plot(recall, precision, label='%s (%.2f)' % (mdlname, metric_estimate))

        if metric == 'auc':
            plt.plot([0, 1], [0, 1], color='black')  # chance-level diagonal
            plt.ylabel('True positive rate')
            plt.xlabel('False positive rate')
            plt.legend(loc='lower right')
        else:
            plt.ylabel('Precision')
            plt.xlabel('Recall')
            plt.legend(loc='upper right')
        plt.savefig(os.path.join(ovo_dir, '%s_%s_vs_%s' % (metric, pos_str, neg_str)))

In [20]:
# One-vs-rest figures over the general label set: ROC first, then precision-recall.
for ovr_metric in ('auc', 'prc'):
    plot_ovr(actual_Y, Y_predicted_lst, mdlname_lst, GEN_LABELS, GEN_LAB_STR, outdir, ovr_metric)

In [21]:
# Pairwise (one-vs-one) figures over the fine-grained labels: ROC first, then precision-recall.
for ovo_metric in ('auc', 'prc'):
    plot_ovo(actual_Y, Y_predicted_lst, mdlname_lst, FINE_LABELS, FINE_LAB_STR, outdir, ovo_metric)