In [1]:
from itertools import product
# copy of plot method with added vmin and vmax parameters
# https://github.com/scikit-learn/scikit-learn/blob/cb7271339e56631fe47a22e259c98716f14f6894/sklearn/metrics/_plot/confusion_matrix.py#L82
def ConfusionMatrixDisplay_plot(
    self,
    *,
    include_values=True,
    cmap="viridis",
    vmin=None, vmax=None,
    xticks_rotation="horizontal",
    values_format=None,
    ax=None,
    colorbar=True,
):
    """Plot visualization.

    Variant of ``ConfusionMatrixDisplay.plot`` that additionally lets the
    caller pin the colour scale via ``vmin`` / ``vmax``.

    Parameters
    ----------
    include_values : bool, default=True
        Includes values in confusion matrix.
    cmap : str or matplotlib Colormap, default='viridis'
        Colormap recognized by matplotlib.
    vmin, vmax : float, default=None
        Lower / upper end of the colour scale, forwarded to ``ax.imshow``.
        If `None`, the minimum / maximum of the confusion matrix is used
        for the text-contrast threshold.
    xticks_rotation : {'vertical', 'horizontal'} or float,
                     default='horizontal'
        Rotation of xtick labels.
    values_format : str, default=None
        Format specification for values in confusion matrix. If `None`,
        the format specification is 'd' or '.2g' whichever is shorter.
    ax : matplotlib axes, default=None
        Axes object to plot on. If `None`, a new figure and axes is
        created.
    colorbar : bool, default=True
        Whether or not to add a colorbar to the plot.

    Returns
    -------
    display : :class:`~sklearn.metrics.ConfusionMatrixDisplay`
        ``self``, with ``im_``, ``text_``, ``figure_`` and ``ax_`` set.
    """
    if ax is None:
        # pyplot is only needed when we have to create the figure ourselves.
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    cm = self.confusion_matrix
    n_classes = cm.shape[0]
    self.im_ = ax.imshow(cm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
    self.text_ = None
    cmap_min, cmap_max = self.im_.cmap(0), self.im_.cmap(1.0)

    if include_values:
        self.text_ = np.empty_like(cm, dtype=object)

        # Print text with appropriate color depending on background.
        # The background colour follows the (possibly overridden) colour
        # scale, so the contrast threshold must be the midpoint of that
        # scale rather than the midpoint of the data range.
        scale_low = cm.min() if vmin is None else vmin
        scale_high = cm.max() if vmax is None else vmax
        thresh = (scale_high + scale_low) / 2.0

        for i, j in product(range(n_classes), range(n_classes)):
            color = cmap_max if cm[i, j] < thresh else cmap_min

            if values_format is None:
                text_cm = format(cm[i, j], ".2g")
                if cm.dtype.kind != "f":
                    # For non-float matrices prefer the plain integer form
                    # when it is shorter than the '.2g' rendering.
                    text_d = format(cm[i, j], "d")
                    if len(text_d) < len(text_cm):
                        text_cm = text_d
            else:
                text_cm = format(cm[i, j], values_format)

            # imshow puts column index on x and row index on y.
            self.text_[i, j] = ax.text(
                j, i, text_cm, ha="center", va="center", color=color
            )

    if self.display_labels is None:
        display_labels = np.arange(n_classes)
    else:
        display_labels = self.display_labels
    if colorbar:
        fig.colorbar(self.im_, ax=ax)
    ax.set(
        xticks=np.arange(n_classes),
        yticks=np.arange(n_classes),
        xticklabels=display_labels,
        yticklabels=display_labels,
        ylabel="True label",
        xlabel="Predicted label",
    )

    # Inverted y-limits keep row 0 at the top and show full cells at the edges.
    ax.set_ylim((n_classes - 0.5, -0.5))
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(xticks_rotation)

    self.figure_ = fig
    self.ax_ = ax
    return self

In [3]:
%matplotlib inline
import os
import numpy as np
from datetime import datetime
from IPython.display import display
from ipywidgets import interact, interact_manual
import ipywidgets as widgets
from study_common import *
from study_notebook_helper import *
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.metrics._plot.confusion_matrix import confusion_matrix
import itertools
import random

# Report the default figure resolution, then raise it for the inline plots below.
print(plt.rcParams["figure.dpi"])
plt.rcParams["figure.dpi"] = 150

# Number of taxonomy entries; used below as the side length of the confusion matrix.
d = len(TAXONOMY)
# Every "dataset-*" file in the results directory, with the prefix stripped.
study_datasets = [
    entry[len("dataset-"):]
    for entry in os.listdir(STUDY_RESULTS_PATH)
    if entry.startswith("dataset-")
]
@interact(dataset=study_datasets)
def main(dataset=study_datasets[0]):
    """Interactively pick a dataset and vote files, then plot their confusion matrix.

    Parameters
    ----------
    dataset : str
        Name of a ``dataset-*`` file in ``STUDY_RESULTS_PATH`` with the
        ``dataset-`` prefix removed; the ``.pickle`` suffix is stripped to
        obtain the display name.

    Returns
    -------
    None
        Everything is reported through prints, widgets and the plot.
    """
    dataset_name = dataset[:-len(".pickle")]

    print()

    # Names embedded in the "votes-<name>-<dataset>" files of this dataset.
    # The leading None is the dropdown entry that ends the selection.
    result_vote_names = [None] + [
        name[len("votes-"):-(len(dataset) + 1)]
        for name in os.listdir(STUDY_RESULTS_PATH)
        if name.startswith("votes-") and dataset_name in name
    ]
    if len(result_vote_names) <= 1:
        print("Oh no, there is no data yet! Please go and do the study!")
        return
    print("Chose votes to anaylze for " + dataset_name)

    def chose_votes(names):
        """Show a dropdown that either adds one more vote name or analyzes `names`."""
        @interact(result_name=result_vote_names)
        def analyze(result_name=result_vote_names[0]):
            if result_name is not None:
                # Another name was picked: nest a further dropdown for the next one.
                chose_votes(names + [result_name])
                return
            if len(names) == 0:
                print("Select a name!")
                return
            print(names)

            all_correct_labels = []
            all_votes = []

            for name in names:
                # NOTE(security): pickle.load can execute arbitrary code;
                # only point STUDY_RESULTS_PATH at trusted result files.
                # The dataset is re-read per name because the shuffle below
                # mutates it in place.
                with open(STUDY_RESULTS_PATH + "dataset-" + dataset, "rb") as data_file:
                    study: STUDY_TYPE = pickle.load(data_file)
                # Deterministic shuffle seeded by vote name + dataset;
                # presumably reproduces the order shown during the study -- TODO confirm.
                random.Random(name + dataset).shuffle(study[1])  # TODO NO!
                correct_labels = [entry[0] for entry in study[1]]
                with open(STUDY_RESULTS_PATH + "votes-" + name + "-" + dataset, "rb") as votes_file:
                    votes = pickle.load(votes_file)
                # for i in range(len(votes)):
                #     votes[i] = [4, 2, 3, 5, 0, 6, 1][votes[i]]  # convert results from old ordering to new if needed

                all_correct_labels += correct_labels
                all_votes += votes

            @interact(reorder_exponent=(1.0, 5.0))
            def plot_results(reorder_exponent=2):
                """Plot the confusion matrix; re-run on every slider change."""
                display_labels = [t[1] for t in TAXONOMY]
                # This function runs again on each slider move, so the shared
                # label lists must never be modified; work on local copies.
                plot_correct_labels = list(all_correct_labels)
                plot_votes = list(all_votes)

                # Raw counts in the original category order (rows: correct, columns: voted).
                confusion = [[0] * d for _ in range(d)]
                for correct, pred in zip(all_correct_labels, all_votes):
                    confusion[correct][pred] += 1

                if True:  # reorder the matrix columns to have everything as close as possible to the middle!
                    print(np.array(confusion))

                    def ordering_cost(order):
                        # Confusions between two categories are weighted by how
                        # far apart the ordering places them.
                        return sum(
                            (confusion[order[i1]][order[i2]] + confusion[order[i2]][order[i1]])
                            * (abs(i1 - i2) ** reorder_exponent)
                            for i1, i2 in itertools.combinations(range(d), 2)
                        )

                    # Score every permutation exactly once.
                    orderings = list(itertools.permutations(range(d), d))
                    costs = [ordering_cost(order) for order in orderings]
                    best_cost = min(costs)
                    best_ordering = orderings[costs.index(best_cost)]
                    tie_count = sum(1 for cost in costs if abs(cost - best_cost) < 0.001)
                    print("Amount of best options: " + str(best_cost) + " : " + str(tie_count))
                    # best_ordering = random.choice(list(itertools.permutations(range(d), d)))  # use a random one!
                    # best_ordering = (2, 1, 5, 0, 3, 6, 4)
                    # The searched ordering is deliberately overridden by a
                    # fixed one. It only fits seven categories, so fall back
                    # to the searched ordering for any other taxonomy size.
                    fixed_ordering = (1, 6, 3, 4, 2, 5, 0)
                    if len(fixed_ordering) == d:
                        best_ordering = fixed_ordering
                    print(best_ordering)
                    # Inverse permutation: original category index -> plotted position.
                    best_ordering_inv = [0 for _ in range(d)]
                    for position, category in enumerate(best_ordering):
                        best_ordering_inv[category] = position
                    print(best_ordering_inv)

                    display_labels = [TAXONOMY[best_ordering[ti]][1] for ti in range(len(TAXONOMY))]
                    plot_votes = [best_ordering_inv[vote] for vote in all_votes]
                    plot_correct_labels = [best_ordering_inv[label] for label in all_correct_labels]

                # https://github.com/scikit-learn/scikit-learn/blob/cb7271339e56631fe47a22e259c98716f14f6894/sklearn/metrics/_plot/confusion_matrix.py#L566
                # Explicit labels keep the matrix d x d even when a category never occurs.
                cm = confusion_matrix(
                    plot_correct_labels, plot_votes, labels=list(range(d)), normalize=None
                )  # normalize="true"
                disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
                matrix_plot = ConfusionMatrixDisplay_plot(disp,
                    include_values=True,
                    cmap='viridis',
                    # Upper end of the colour scale is the total vote count divided by the number of categories.
                    vmin=0, vmax=len(all_votes) / d,
                    ax=None,
                    xticks_rotation='vertical',# 'horizontal',
                    values_format=None,
                    colorbar=True,
                )
                # https://stackoverflow.com/a/43153984/4354423
                plt.setp(matrix_plot.ax_.get_xticklabels(), rotation=30, ha="right")
                print(f"Amount of correct answers: {sum(confusion[i][i] for i in range(len(TAXONOMY)))} of {len(all_votes)} (expected random: {len(all_votes) / len(TAXONOMY)})")
                print_html(ADD_LINKS_TO_RENDERED_IMAGES)
    chose_votes([])
        

72.0


interactive(children=(Dropdown(description='dataset', options=('vanzin_jEdit.pickle', 'junit-team_junit4.pickl…