In [1]:
import os
import json
import warnings
import itertools
import numpy as np
import pandas as pd
import polar_diagrams
from sklearn.datasets import load_wine
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
from sklearn.model_selection import GridSearchCV

from dash import Dash, dcc, html, Input, Output, callback, State, ctx, Patch
import dash_bootstrap_components as dbc
from dash.exceptions import PreventUpdate
import plotly.graph_objects as go


_INT_CHART_WIDTH = 1400
_INT_CHART_HEIGHT = 500

In [2]:
# Load the wine features, give the OD280/OD315 column a shorter name and
# leave out 'proline'. The frame is built in one chain so that re-running the
# cell never depends on an earlier in-place mutation.
df_wine_data = (
    load_wine(return_X_y=False, as_frame=True)['data']
    .rename(columns={'od280/od315_of_diluted_wines': 'od_diluted'})
    .drop(columns=['proline'])
)
df_wine_data

Unnamed: 0,alcohol,malic_acid,ash,alcalinity_of_ash,magnesium,total_phenols,flavanoids,nonflavanoid_phenols,proanthocyanins,color_intensity,hue,od_diluted
0,14.23,1.71,2.43,15.6,127.0,2.80,3.06,0.28,2.29,5.64,1.04,3.92
1,13.20,1.78,2.14,11.2,100.0,2.65,2.76,0.26,1.28,4.38,1.05,3.40
2,13.16,2.36,2.67,18.6,101.0,2.80,3.24,0.30,2.81,5.68,1.03,3.17
3,14.37,1.95,2.50,16.8,113.0,3.85,3.49,0.24,2.18,7.80,0.86,3.45
4,13.24,2.59,2.87,21.0,118.0,2.80,2.69,0.39,1.82,4.32,1.04,2.93
...,...,...,...,...,...,...,...,...,...,...,...,...
173,13.71,5.65,2.45,20.5,95.0,1.68,0.61,0.52,1.06,7.70,0.64,1.74
174,13.40,3.91,2.48,23.0,102.0,1.80,0.75,0.43,1.41,7.30,0.70,1.56
175,13.27,4.28,2.26,20.0,120.0,1.59,0.69,0.43,1.35,10.20,0.59,1.56
176,13.17,2.59,2.37,20.0,120.0,1.65,0.68,0.53,1.46,9.30,0.60,1.62


In [10]:
def _grid_search(df_left_input, string_reference_model, list_measures):

    # We save the row with the reference model
    # TODO: Implement check for duplicate reference models in the library
    df_reference_row = df_left_input.loc[
        df_left_input['Model'] == string_reference_model]
    # We remove the reference row from the dataframe
    df_input_no_reference = df_left_input.drop(
        df_reference_row.index)[list_measures]

    list_min_samples = np.arange(2, 15, step=2)
    list_epsilons = np.linspace(0.01, 5, num=50)
    list_hyperparam = list(itertools.product(list_epsilons, list_min_samples))

    list_scores = []
    list_labels_over_runs = []

    for i, (float_eps, int_min_samples) in enumerate(list_hyperparam):
        constructor_DBSCAN = DBSCAN(
            eps=float_eps, min_samples=int_min_samples, n_jobs=-1)
        constructor_DBSCAN.fit_predict(df_input_no_reference)
        list_labels = constructor_DBSCAN.labels_

        # We check if we have all outliers or all elements in seperate clusters
        # These are the edge cases which we do not want
        if len(set(list_labels)) == 1 or (
                len(set(list_labels)) == len(list_labels)):
            continue

        list_scores.append(
            silhouette_score(df_input_no_reference, list_labels))
        list_labels_over_runs.append(list_labels)

    int_best_score_index = np.argmax(list_scores)
    np_array_best_labels = list(list_labels_over_runs[int_best_score_index])
    # We add the label for the reference model at the same place that model
    # was before we removed the entire row it was contained in
    # We add a value of df_input.shape[0] because the model must not be a part
    # of any cluster
    np_array_best_labels.insert(
        df_reference_row.index.values[0], df_left_input.shape[0])

    tuple_best_hyperparam = list_hyperparam[int_best_score_index]

    return tuple_best_hyperparam, np_array_best_labels


def _tuple_group_left_dataframe(df_left_input, string_reference_model):
    dict_aggregate_rules = {'Model': '; '.join}
    for i in df_left_input.columns.to_list():
        if i not in ['Model', 'Label']:
            dict_aggregate_rules[i] = 'mean'

    df_grouped_rows = df_left_input.groupby(
        'Label', as_index=False, sort=False).agg(
        dict_aggregate_rules).reset_index(drop=True)
    df_grouped_rows['Cluster Count'] = [
        int_i.count('; ') + 1 for int_i in list(df_grouped_rows['Model'])]

    dict_model_cluster_correspondence = {}
    list_new_model_names = []
    for int_i, str_model in enumerate(df_grouped_rows['Model'].to_list()):

        # The code below names only Clusters those traces that have multiple
        # model names. If there is only one model name, then it is left as is
        # and not named 'Cluster number'
        #if '; ' in str_model:
        #    list_new_model_names.append('Cluster ' + str(int_i + 1))
        #else:
        #    list_new_model_names.append(str_model)

        # The code below names only Clusters those traces that are different
        # than the reference model
        if string_reference_model == str_model:
            list_new_model_names.append(str_model)
        else:
            list_new_model_names.append('Cluster ' + str(int_i + 1))

        for str_one_model in str_model.split('; '):
            if str_one_model == string_reference_model:
                dict_model_cluster_correspondence[
                    str_one_model] = string_reference_model
            else:
                dict_model_cluster_correspondence[
                    str_one_model] = 'Cluster ' + str(int_i + 1)

    df_grouped_rows['Model'] = list_new_model_names

    return df_grouped_rows, dict_model_cluster_correspondence


def _chart_create_left_chart(df_grouped_data, string_reference_model,
                            string_diagram_type, string_mid_type):
    """Create the overview chart in which every cluster is one marker.

    The diagram is produced by polar_diagrams and then every trace, except
    the one of the reference model, is restyled into a transparent circle
    that shows the number of models in the cluster as bold text.

    Args:
        df_grouped_data (pandas.DataFrame): Grouped rows with a
            'Cluster Count' column. The code looks the count up with the
            trace position, so row i is assumed to belong to trace i
            (TODO confirm against polar_diagrams).
        string_reference_model (str): Name of the reference model.
        string_diagram_type (str): Forwarded to polar_diagrams.
        string_mid_type (str): Forwarded to polar_diagrams.

    Returns:
        plotly.graph_objects.Figure: The restyled chart.
    """
    chart_left = polar_diagrams.polar_diagrams._chart_create_diagram(
        [df_grouped_data],
        string_reference_model=string_reference_model,
        string_diagram_type=string_diagram_type,
        string_mid_type=string_mid_type,
        bool_normalized_measures=False)

    dict_left = chart_left.to_dict()
    for int_i, dict_trace in enumerate(dict_left['data']):
        # The trace name is read as '<prefix>. <model name>'. We split only
        # on the first '. ' so that a model name which itself contains '. '
        # is still compared in full
        if dict_trace['name'].split('. ', 1)[1] == string_reference_model:
            continue

        cluster_count = df_grouped_data['Cluster Count'][int_i]

        dict_trace['mode'] = 'markers+text'
        dict_trace['text'] = '<b>' + str(cluster_count) + '</b>'
        # Fully transparent fill, so only the outline and the text are seen
        dict_trace['marker']['color'] = 'rgba(100,100,100,0)'
        # The marker grows by one unit per model in the cluster
        dict_trace['marker']['size'] = 20 + int(cluster_count)
        dict_trace['marker']['line']['color'] = 'rgba(0,0,0,0.5)'

    return go.Figure(dict_left)


def _tuple_create_initial_left_diagram(df_input, string_reference_model,
                                       string_diagram_type, string_mid_type):
    """Create the left overview chart with the clustered models.

    The diagram properties are calculated with polar_diagrams, the models are
    clustered on the measures that belong to the chosen diagram and the
    clusters are drawn as one marker each.

    Args:
        df_input (pandas.DataFrame): Input data passed to polar_diagrams.
        string_reference_model (str): Name of the reference model.
        string_diagram_type (str): 'taylor' selects the Taylor diagram
            properties, any other value selects the MI diagram properties.
        string_mid_type (str): 'scaled' selects the scaled measures, any
            other value the normalized ones. Ignored for 'taylor' when
            choosing the measures.

    Returns:
        tuple: The left chart and the dictionary that maps every model name
        to the name of its cluster.
    """
    # We first choose how the properties are calculated and which three
    # measures are used as features for the clustering
    if string_diagram_type == 'taylor':
        function_properties = polar_diagrams.df_calculate_td_properties
        list_relevant_measures = ['Standard Deviation', 'Correlation', 'CRMSE']
    elif string_mid_type == 'scaled':
        function_properties = polar_diagrams.df_calculate_mid_properties
        list_relevant_measures = ['Entropy', 'Scaled MI', 'VI']
    else:
        function_properties = polar_diagrams.df_calculate_mid_properties
        list_relevant_measures = ['Root Entropy', 'Normalized MI', 'RVI']

    df_left_input = function_properties(df_input, string_reference_model)

    # The best hyperparameters are not needed here, only the labels
    _, np_array_labels = _grid_search(
        df_left_input, string_reference_model, list_relevant_measures)
    df_left_input['Label'] = np_array_labels

    df_left_grouped, dict_model_cluster = _tuple_group_left_dataframe(
        df_left_input, string_reference_model)

    chart_left = _chart_create_left_chart(
        df_left_grouped, string_reference_model, string_diagram_type,
        string_mid_type)
    # The left chart takes a third of the total width
    chart_left = chart_left.update_layout(
        dragmode='zoom', clickmode='event+select', hovermode=False,
        width=round(_INT_CHART_WIDTH/3),
        height=_INT_CHART_HEIGHT-100,
        margin={'l': 40, 'r': 40})

    return chart_left, dict_model_cluster


def _tuple_create_initial_right_diagram(df_input, string_reference_model,
                                        string_diagram_type, string_mid_type):
    """Create the right (detailed) chart and collect its runtime warnings.

    Args:
        df_input (pandas.DataFrame): Input data passed to polar_diagrams.
        string_reference_model (str): Name of the reference model.
        string_diagram_type (str): 'mid' creates the MI diagram, any other
            value creates the Taylor diagram.
        string_mid_type (str): Forwarded to the MI diagram only.

    Returns:
        tuple: The right chart and a string with the numbered, de-duplicated
        RuntimeWarning texts that were recorded while the chart was created
        (empty string if there were none).
    """
    with warnings.catch_warnings(record=True) as list_warning_caught:
        # The 'default' action is installed for the duration of the block.
        # Recorded warnings are collected in list_warning_caught instead of
        # being printed
        warnings.simplefilter("default")

        if string_diagram_type == 'mid':
            chart_right = polar_diagrams.chart_create_mi_diagram(
                df_input, string_reference_model=string_reference_model,
                string_mid_type=string_mid_type)
        else:
            chart_right = polar_diagrams.chart_create_taylor_diagram(
                df_input, string_reference_model=string_reference_model)

        chart_right = chart_right.update_layout(
            dragmode='select', clickmode='event+select',
            width=int(_INT_CHART_WIDTH),
            height=_INT_CHART_HEIGHT*1.4,
            margin={'l': 150, 'r': 40})

    string_warnings = ''
    int_i = 1
    for warning_tmp in list_warning_caught:
        # We read the string form of the recorded warning directly. This
        # yields the same text as replacing warnings.formatwarning with
        # str(), but leaves the process-wide warning formatter untouched
        string_formatted = str(warning_tmp)
        if 'RuntimeWarning' not in string_formatted:
            continue

        # The slice skips the first 11 characters (the '{message : ' prefix
        # of CPython's WarningMessage string form) and caps the text length.
        # Escaped line breaks are replaced with spaces
        string_one_warning = string_formatted[11:208].replace('\\n', ' ')
        if string_one_warning in string_warnings:
            continue

        string_warnings += str(int_i) + '. ' + string_one_warning
        int_i += 1

    return chart_right, string_warnings


def _tuple_style_both_diagrams(chart_left, chart_right):
    """Align the axes of the left chart with the right one and style both.

    Both figures are updated in place and also returned.

    Args:
        chart_left (plotly.graph_objects.Figure): The overview chart.
        chart_right (plotly.graph_objects.Figure): The detailed chart whose
            polar axes are copied to the left chart.

    Returns:
        tuple: ``(chart_left, chart_right)`` after styling.
    """
    polar_right = chart_right['layout']['polar']
    angularaxis_right = polar_right["angularaxis"]

    # We use the same radial and angular axis range for both diagrams. This
    # fixes the edge cases where we have different axis ranges because of the
    # left overview diagram. This diagram can have for example the angular axis
    # 0-90 and not 0-180 as the right diagram because of the aggregation of
    # some models during clustering (thus aggregating their coordinates).
    # The order of the keys is kept on purpose: the whole angular axis is
    # copied first and its single properties are adjusted afterwards
    dict_left_layout = {
        'title': None,
        'polar_radialaxis_range': polar_right["radialaxis"]["range"],
        'polar_radialaxis_ticklen': 0,
        'polar_radialaxis_showticklabels': False,
        'polar_radialaxis_linewidth': 0.5,
        'polar_radialaxis_layer': 'below traces',
        'polar_radialaxis_autorange': False,
        'polar_radialaxis_rangemode': 'normal',
        'polar_radialaxis_title': None,
        'polar_angularaxis': angularaxis_right,
        'polar_angularaxis_layer': 'below traces',
        'polar_angularaxis_ticklen': 0,
        'polar_angularaxis_showticklabels': False,
        'polar_angularaxis_linewidth': 0.5,
        'polar_sector': [0, angularaxis_right['tickvals'][0]],
    }
    chart_left.update_layout(**dict_left_layout)

    # We vertically orient the legend of the right diagram
    chart_right.update_layout(
        legend_orientation='v', legend_x=-0.3, legend_y=1)

    return chart_left, chart_right


def app_create_dashboard(df_input, string_reference_model,
                         string_diagram_type='taylor',
                         string_mid_type='normalized'):
    """Build the dashboard with both diagrams and start the Dash server.

    The left column shows the clustered overview chart and, if there are any,
    the warnings raised while the right chart was created. The right column
    shows the detailed chart. The maximal radius and angle of the left chart
    are stored in the module-level names ``_FLOAT_MAX_R`` and
    ``_INT_MAX_THETA``, which the zoom callback reads.

    Args:
        df_input (pandas.DataFrame): Input data passed to polar_diagrams.
        string_reference_model (str): Name of the reference model.
        string_diagram_type (str): Either 'taylor' or 'mid'.
        string_mid_type (str): Either 'scaled' or 'normalized'. Only
            validated when ``string_diagram_type`` is 'mid'.

    Returns:
        None

    Raises:
        ValueError: If ``string_diagram_type`` or ``string_mid_type`` is not
            one of the valid values.
    """
    global _FLOAT_MAX_R, _INT_MAX_THETA

    string_title = "Polar Diagrams Dashboard"
    dash_app = Dash(
        string_title,
        external_stylesheets=[dbc.themes.BOOTSTRAP],
        meta_tags=[{"name": "viewport", "content": "width=device-width"}])
    dash_app.title = string_title
    dash_app.css.config.serve_locally = True
    dash_app.scripts.config.serve_locally = True

    # ====================================================
    # TODO: Raise an exception if a list of dataframes is provided where the
    # TODO: second data set is not with scalar values
    # TODO: We don't want to support two-version model functionality
    # ====================================================
    list_valid_diagram_types = ['taylor', 'mid']
    list_valid_mid_types = ['scaled', 'normalized']

    if string_diagram_type not in list_valid_diagram_types:
        raise ValueError(
            'string_diagram_type not in ' + str(list_valid_diagram_types))

    bool_invalid_mid_type = string_mid_type not in list_valid_mid_types
    if string_diagram_type == 'mid' and bool_invalid_mid_type:
        raise ValueError(
            'string_mid_type not in ' + str(list_valid_mid_types))

    tuple_diagram_arguments = (
        df_input, string_reference_model, string_diagram_type,
        string_mid_type)
    chart_left, dict_model_cluster = _tuple_create_initial_left_diagram(
        *tuple_diagram_arguments)
    chart_right, string_warnings = _tuple_create_initial_right_diagram(
        *tuple_diagram_arguments)
    chart_left, chart_right = _tuple_style_both_diagrams(
        chart_left, chart_right)

    # The callback needs the full extent of the left chart to draw the
    # selection and to reset the radial range
    _FLOAT_MAX_R = chart_left['layout']['polar']['radialaxis']['range'][1]
    _INT_MAX_THETA = chart_left['layout']['polar']['angularaxis'][
        'tickvals'][0]

    row_title = dbc.Row(
        html.Div(
            html.H1(string_title),
            style={"font-family": 'open sans',
                   'margin-top': 50,
                   'margin-bottom': 50}))

    graph_left = dcc.Graph(
        id="chart-left",
        figure=chart_left,
        config={
            'modeBarButtonsToRemove': [
                'zoom', 'select', 'pan', 'lasso', 'zoomIn',
                'zoomOut', 'autoScale', 'resetScale'],
            'staticPlot': False,
            'displaylogo': False,
            'showAxisDragHandles': False})

    # The alert is only opened when there is something to show
    alert_warnings = html.Div(
        dbc.Alert(
            string_warnings,
            color="warning",
            is_open=string_warnings != ''))

    column_left = dbc.Col(
        [graph_left, alert_warnings],
        width=3,
        align='start',
        style={'border': '1px solid', 'margin-left': 0, 'margin-right': 0})

    graph_right = dcc.Graph(
        id="chart-right",
        figure=chart_right,
        config={
            'modeBarButtonsToRemove': [
                'zoom', 'pan', 'lasso', 'zoomIn',
                'zoomOut', 'select', 'autoScale',
                'resetScale'],
            'displaylogo': False,
            'showAxisDragHandles': False})

    column_right = dbc.Col(
        graph_right,
        width=True,
        align='start',
        style={'border': '1px solid'})

    row_charts = dbc.Row(
        [column_left, column_right],
        className="g-0",
        justify="center")

    dash_app.layout = dbc.Container([row_title, row_charts], fluid=True)

    dash_app.run(debug=True, jupyter_mode='external')

'''
@callback(
    Output(component_id="chart-left", component_property="figure",
           allow_duplicate=True),
    Output(component_id="chart-right", component_property="figure",
           allow_duplicate=True),
    Input(component_id="chart-left", component_property="restyleData"),
    State('chart-left', 'figure'),
    State('chart-right', 'figure'),
    prevent_initial_call=True,
)
def _list_update_legends(list_legend_points, dict_left, dict_right):
    # ====================================================
    # TODO: Combine the two callbacks by using the following context property
    # TODO: list(ctx.triggered_prop_ids.keys())[0].split('.')[1]
    # TODO: This will either return restyleData or relayoutData
    # ====================================================
    chart_left_updated = Patch()
    chart_right_updated = Patch()
    if list_legend_points:
        # Legend click gives the following output
        # [{"visible": ["legendonly"]}, [10]]
        # [{"visible": [true]}, [1]]

        for int_i, int_legend_point in enumerate(list_legend_points[1]):
            for int_j, dict_one_trace in enumerate(dict_right['data']):
                if dict_one_trace['name'].startswith(
                        str(int_legend_point) + '.'):
                    if isinstance(
                        list_legend_points[0]['visible'][int_i], bool) and (
                        list_legend_points[0]['visible'][int_i], bool == True):
                        chart_right_updated['data'][int_j][
                            'visible'] = True
                    else:
                        chart_right_updated['data'][int_j][
                            'visible'] = False
    else:
        raise PreventUpdate

    return chart_left_updated, chart_right_updated
'''

@callback(
    Output(component_id="chart-left", component_property="figure"),
    Output(component_id="chart-right", component_property="figure"),
    Input(component_id="chart-left", component_property="relayoutData"),
    State('chart-left', 'figure'),
    State('chart-right', 'figure'),
    prevent_initial_call=True
)
def _list_update_zooms(dict_selected_range, dict_left, dict_right):
    """Mirror a radial zoom on the left chart onto the right chart.

    When the left chart reports a new radial range, the right chart is zoomed
    to that range, while the left chart is reset to its full range and shows
    the chosen range as a dotted 'Selection' band. A relayout that also
    carries the rotation or the radial angle is treated as a double click:
    the old selection is removed and no new one is drawn.

    Args:
        dict_selected_range (dict): The relayoutData of the left chart.
        dict_left (dict): The current figure of the left chart.
        dict_right (dict): The current figure of the right chart (unused).

    Returns:
        tuple: Partial updates (Patch objects) for the left and the right
        figure.

    Raises:
        PreventUpdate: If the relayout event carries no radial range.
    """
    # ==============================================
    # TODO: Update radial and angular axis of the right chart with the gray
    # TODO: color to indicate the success of zooming
    # ==============================================

    if not dict_selected_range or (
            'polar.radialaxis.range' not in dict_selected_range):
        raise PreventUpdate

    chart_left_updated = Patch()
    chart_right_updated = Patch()

    list_radial_range = dict_selected_range['polar.radialaxis.range']

    # We remove every existing Selection trace. The deletions are issued from
    # the highest index down, because each deletion shifts the traces that
    # come after it and ascending indices would then point at the wrong ones
    list_selection_indices = [
        int_i for int_i, dict_trace in enumerate(dict_left['data'])
        if dict_trace.get('name') == 'Selection']
    for int_i in reversed(list_selection_indices):
        del chart_left_updated['data'][int_i]

    # Here we check if double click was not detected. If it was detected
    # we just had to remove the Selection trace, which we did above.
    # If it was not detected, that means we have to create a new Selection
    # {  'polar.angularaxis.rotation': 0,
    #    'polar.radialaxis.angle': 0,
    #    'polar.radialaxis.range': [0, 16.353330541878254]
    # }
    bool_double_click = (
        'polar.angularaxis.rotation' in dict_selected_range or
        'polar.radialaxis.angle' in dict_selected_range)
    if not bool_double_click:
        # We create a circular rectangle: 60 points along the inner arc,
        # 60 points back along the outer arc and one point to close the line
        int_points = 60
        list_alpha = np.linspace(0, _INT_MAX_THETA, int_points).tolist()
        list_selection_theta = list_alpha + list_alpha[::-1] + [list_alpha[0]]
        list_selection_r = (
            [list_radial_range[0]] * int_points +
            [list_radial_range[1]] * int_points +
            [list_radial_range[0]])

        chart_left_updated['data'].append(
            go.Scatterpolar(
                r=list_selection_r,
                theta=list_selection_theta,
                name='Selection',
                fill='toself',
                mode='lines',
                showlegend=False,
                line=dict(color='lightgrey', dash='dot', width=2)))

    # Both charts get a fixed radial range from now on
    for chart_updated in (chart_left_updated, chart_right_updated):
        chart_updated['layout']['polar']["radialaxis"]["autorange"] = False
        chart_updated['layout']['polar']["radialaxis"]['rangemode'] = 'normal'

    # The left chart always shows the full range, only the right one zooms
    chart_left_updated['layout']['polar']["radialaxis"][
        "range"] = [0, _FLOAT_MAX_R]
    chart_right_updated['layout']['polar']["radialaxis"][
        "range"] = [list_radial_range[0], list_radial_range[1]]

    return chart_left_updated, chart_right_updated


# Start the dashboard for the wine data as a scaled MI diagram with
# 'alcohol' as the reference model
app_create_dashboard(
    df_input=df_wine_data,
    string_reference_model='alcohol',
    string_diagram_type='mid',
    string_mid_type='scaled')

{'alcohol': 'alcohol', 'malic_acid': 'Cluster 2', 'flavanoids': 'Cluster 2', 'color_intensity': 'Cluster 2', 'od_diluted': 'Cluster 2', 'ash': 'Cluster 3', 'hue': 'Cluster 3', 'alcalinity_of_ash': 'Cluster 4', 'magnesium': 'Cluster 4', 'total_phenols': 'Cluster 5', 'proanthocyanins': 'Cluster 5', 'nonflavanoid_phenols': 'Cluster 6'}


ValueError: too many values to unpack (expected 2)