In [1]:
import pickle
import math
import ipywidgets as widgets
from ipywidgets import interactive, fixed, IntProgress
from plotly_colors import colors
import random
import numpy as np
from IPython.display import display
from load_data import load_em_history, LoadedData, get_measures_from_history, labels_dict, \
    get_measures_mean_across_iterations, Measures
import itertools
import plotly.graph_objs as go
from plotly.subplots import make_subplots

In [2]:
def plot_em_results(arg: LoadedData, fig, row, col):
    """Add one line trace per EM measure recorded in a pickled history to a subplot.

    Args:
        arg: loaded run descriptor; ``arg.file_name`` is the path of the pickled EM
            history. ``None`` is accepted and ignored (nothing is plotted).
        fig: plotly figure built with ``make_subplots``; traces are added in place.
        row, col: 1-based coordinates of the subplot to draw into.

    Traces start hidden (``visible="legendonly"``) and share a legend group named
    after the measure, so one legend click toggles that measure in every subplot.
    Only the top-left subplot emits legend entries, to avoid duplicates.
    """
    if arg is None:
        return
    # NOTE(review): pickle.load can execute arbitrary code; only load history files
    # produced by this project.
    with open(arg.file_name, 'rb') as f:
        history = pickle.load(f)

    measures = get_measures_from_history(history)

    for name, measure in measures:
        # A private RNG seeded with the measure name yields the same color as the former
        # ``random.seed(name); random.choice(colors)``, so a measure looks the same in every
        # subplot, but without resetting the notebook-wide ``random`` state on each trace.
        color = random.Random(name).choice(colors)
        fig.add_trace(go.Scatter(
            y=measure,
            x=np.arange(len(measure)),
            mode='lines',
            name=name,
            visible="legendonly",
            legendgroup=name,
            showlegend=row == 1 and col == 1,
            line={'color': color},
        ), row=row, col=col)
    

<a id='plot_one'></a>
## Compare classifiers on one iteration
Plot the results of one or more classifiers. Left-click to select one; hold `Ctrl` and left-click to select more than one.
Only one iteration can be selected at a time. To see results averaged across all iterations instead,
go to [Compare classifiers for all iterations (one or more metrics)](#plot_mean).

In [None]:
def plot_multiple_em_results(args):
    """Draw the EM measure curves of one or more runs in a two-column subplot grid.

    Args:
        args: a run descriptor or a list of them (as delivered by the iteration
            ``Select`` widget); ``None`` means nothing is selected and is ignored.

    Rebinds the module-level ``fig`` to the newly displayed ``FigureWidget``. It first
    closes the previously displayed widget to release its resources.
    """
    if args is None:
        return
    if not isinstance(args, list):
        args = [args]
    global fig
    if isinstance(fig, go.FigureWidget):
        fig.close()

    progress = IntProgress(min=0, max=100, description="Loading: 0%")
    display(progress)
    # Two runs per row; an empty selection still gets a single (empty) row.
    n_rows = math.ceil(len(args) / 2) or 1
    fig = make_subplots(rows=n_rows, cols=2,
                        subplot_titles=[arg.clf_name for arg in args])

    for i, run in enumerate(args):
        plot_em_results(run, fig, row=i // 2 + 1, col=i % 2 + 1)
        # Compute the percentage from the loop index: accumulating 100/n as floats was
        # truncated by the integer trait and could stop short of 100%.
        percent = round(100 * (i + 1) / len(args))
        progress.value = percent
        progress.description = f"Loading: {percent}%"

    # 400 px per subplot row. The former 400*log2(n) did not track the row count
    # (e.g. 8 runs -> 4 rows squeezed into 1200 px).
    fig.update_layout(autosize=False, width=900, height=400 * n_rows)
    fig.update_yaxes(automargin=True)
    fig.update_xaxes(automargin=True)
    progress.close()
    fig = go.FigureWidget(fig)
    display(fig)
    
    
def select_random_execution(clf_names, data):
    if not clf_names:
        return 
    filtered_list = sorted(filter(lambda el: el.clf_name in clf_names, data), key=lambda k: k.it)
    it_list = list()
    for it, group in itertools.groupby(filtered_list, lambda el: el.it):
        it_list.append((f"Iteration: {it}", list(group)))
        
    wid = widgets.Select(options=it_list, description="Random run")
    w_disp = interactive(plot_multiple_em_results, args=wid)
    display(w_disp)
        
    
fig = None  # last displayed FigureWidget; the plotting callbacks close it before redrawing

em_data = list(load_em_history())
# One selector entry per label: "<clf_name>", or "<clf_name>_<dataset>" when the run has a dataset.
classifiers = set(
    f"{run.clf_name}{'_' + run.dataset if run.dataset else ''}"
    for run in em_data
)

w = widgets.SelectMultiple(
    options=sorted(classifiers),
    description="Classifiers",
)
w_disp = interactive(
    select_random_execution,
    clf_names=w,
    data=fixed(em_data),
)
display(w_disp)


<a id='plot_mean'></a>
## Compare classifiers for all iterations (one or more metrics)
In this section you can select one or more metrics and plot each metric's mean across all experimental iterations,
for every classifier. To compare classifiers on a single iteration instead, go to [Compare classifiers on one iteration](#plot_one).

In [3]:
def plot_metrics(metrics, data):
    if not metrics:
        return 
    global fig
    if fig is not None and type(fig) is go.FigureWidget:
        fig.close()
        
    fig = make_subplots(rows=math.ceil(len(metrics) / 2), cols=2, subplot_titles=[labels_dict[m] for m in metrics], shared_yaxes=True)
    progress = IntProgress(min=0, max=100, description="Loading: 0%")
    display(progress)
    metrics_clf = [(clf_name, get_measures_mean_across_iterations(measures, metrics)) for clf_name, measures in data]
        
    i = 0
    for metric in metrics:
        measures = list(map(lambda el: (el[0], el[1].__getattribute__(metric)), metrics_clf))
        row = math.ceil((i+1) / 2)
        col = 1 if (i+1) % 2 != 0 else 2
        picked_colors = set()
        max_len = max(len(m[1]) for _, m in measures)
        for clf_name, measure in measures:
            measure = measure[1]
            random.seed(a=hash(clf_name))
            color = random.choice(colors)
            
            # Keep selecting a color until we find one we haven't already used
            while color in picked_colors:
                color = random.choice(colors)
            picked_colors.add(color)
            fig.add_trace(go.Scatter(
                y=np.pad(measure, (0, max_len - len(measures)), mode='edge'),
                x=np.array(range(max_len)),
                mode='lines',
                name=clf_name,
                visible="legendonly",
                legendgroup=clf_name,
                showlegend=row == 1 and col ==1,
                line={'color': color},
            ), row=row, col=col)
            progress.value += 100 / len(metrics)
            progress.description = f"Loading: {progress.value}%"
        i += 1
    
    fig.update_layout(autosize=False, width=900, 
                      height=400*(math.log2(len(metrics)) if len(metrics) > 1 else 1))
    fig.update_yaxes(automargin=True)
    fig.update_xaxes(automargin=True)
    progress.close()
    fig = go.FigureWidget(fig)
    display(fig)
                    
                
# Drop the handle to the previous section's figure so that plot_metrics does not close it.
fig = None
# NOTE(review): pickle.load can execute arbitrary code; this must be a project-generated file.
with open('./pickles/measures_all_em_rcv1_30-01-20.pkl', 'rb') as f:
    data = pickle.load(f)

# Options are (display label, metric attribute) pairs, sorted by label.
w = widgets.SelectMultiple(
    options=sorted((label, key) for key, label in labels_dict.items()),
    description="Metrics",
)
w_disp = interactive(plot_metrics, metrics=w, data=fixed(data))
display(w_disp)

interactive(children=(SelectMultiple(description='Metrics', options=(('Abs. err.', 'abs_errors'), ('Brier', 'b…