In [None]:
import numpy as np

# data

In [None]:
from src.visualization import IndexType

# Pixel ordering of the flattened spectra. The callbacks convert map coordinates to row
# indices with id_from_snake_index, so only horizontal-snake ordering is supported.
INDEX = IndexType.HORIZONTAL_SNAKE
if INDEX != IndexType.HORIZONTAL_SNAKE:
    # NotImplemented is a sentinel value, not an exception; raising it would only
    # produce a confusing TypeError.
    raise NotImplementedError(f'Unsupported index type: {INDEX}')
# Map shape; the dataset holds DIM[0] * DIM[1] spectra (one per map position).
DIM = (266, 500)

In [None]:
from src.preprocessing import match_wavelengths

"""
s1_wavelengths = np.load(open('data/s1_wavelengths.npy', 'rb'))
s2_wavelengths = np.load(open('data/s2_wavelengths.npy', 'rb'))
s1_spectra = np.load(open('data/s1.npy', 'rb'))
s2_spectra = np.load(open('data/s2.npy', 'rb'))
"""

N = DIM[0] * DIM[1]
X_labels = np.arange(1000)
y_labels = np.arange(10)
y_true = np.random.uniform(0., 1., (N, y_labels.shape[0]))
y_pred = np.random.uniform(0., 1., (N, y_labels.shape[0]))
X = np.random.beta(0.5, 0.5, (N, X_labels.shape[0]))

# layout

In [None]:
import dash_bootstrap_components as dbc
from dash import html
from dash import dcc
from jupyter_dash import JupyterDash
from src.rowwise_metrics import rowwise_cosine, rowwise_euclid, rowwise_kl_divergence
from src.custom_components import *
from src.visualization import plot_spectra

# Error functions offered in the error-map dropdown, keyed by the label shown to the user.
METRICS = dict([
    ('cosine distance', rowwise_cosine),
    ('euclidean distance', rowwise_euclid),
    ('kullback-leibler divergence', rowwise_kl_divergence),
])

# Companion stylesheet from dash-bootstrap-templates, loaded alongside the SLATE theme.
dbc_css = "https://cdn.jsdelivr.net/gh/AnnMarieW/dash-bootstrap-templates/dbc.min.css"
app = JupyterDash(
    __name__,
    external_stylesheets=[dbc.themes.SLATE, dbc_css],
)
app.title = 'LIBS Error Analysis'
# Page layout.
#   Top row: target map, error map, summary table with a reset button.
#   Bottom row: model input and model output for the selected points.
# The hidden <selected_points> div holds the current selection shared by the callbacks.
app.layout = place_in_container([

    # intensity map & error map
    # error table
    dbc.Row([
        dbc.Col([
            make_plot(
                dbc.Row([
                    dbc.Col([
                        # title and tooltip are filled in by the update_y_map_title callback
                        make_tooltip_title('', 'y_map_title', ''),
                    ]),
                    dbc.Col([
                        make_toggle('Suppress Outliers', 'y_map_outlier_toggle'),
                    ]),
                    dbc.Col([
                        make_toggle('Predicted / Ground Truth', 'y_map_toggle'),
                    ]),
                ]),
                'y_map',
            ),
        ], width=5),
        dbc.Col([
            make_plot(
                dbc.Row([
                    dbc.Col([
                        make_tooltip_title('Error Map', 'error_map_title', 'Results of the error function for each true-predicted pair.'),
                    ]),
                    dbc.Col([
                        make_dropdown('error_metric_dropdown', list(METRICS.keys())),
                    ]),
                ]),
                'error_map',
            ),
        ], width=5),
        dbc.Col([
            dbc.Row([
                make_error_table('Summary Table', 'Summary statistics of the error function results.', 'error_table')
            ]),
            dbc.Row([
                make_button('Reset Selection', 'reset_button')
            ]),
        ], width=2),
    ], align='center'),
    html.Br(),

    # input and output
    dbc.Row([
        dbc.Col([
            make_plot(
                make_tooltip_title(
                    'Model Input',
                    'X_plot_title',
                    # implicit concatenation: a backslash continuation inside the literal
                    # would embed the next line's indentation in the tooltip text
                    'Spectra fed to the model to gain the predictions of the selected spectra. Only the last spectrum '
                    'is displayed along with the average of all selected spectra.',
                ),
                'X_plot',
            )
        ], width=7),
        dbc.Col([
            make_plot(
                make_tooltip_title(
                    'Model Output and the Ground Truth',
                    'y_plot_title',
                    'Output of the model compared to the ground truth. Only the label of the last spectrum is displayed '
                    'along with the average of all selected spectra.'
                ),
                'y_plot',
            )
        ], width=5),
    ], align='center'),
    html.Br(),

    # Hidden store for the selected map coordinates (JSON list of [x, y]), written by
    # update_selected_points and read by the plotting callbacks.
    html.Div(id='selected_points', style={'display': 'none'}),
])

# callbacks

In [None]:
import plotly.express as px
import plotly.graph_objects as go
import json
from src.visualization import plot_map, intensity_map
from dash import Input, Output, State
from src.visualization import id_from_snake_index, plot_spectra
from dash import callback_context

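# Dash callbacks keep no server-side state, so the selected points are persisted as JSON in the hidden <selected_points> div.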

def load_points(memory):
    """Decode the JSON selection stored in the hidden <selected_points> div.

    Args:
        memory: JSON string (or a falsy value when nothing has been stored yet).

    Returns:
        The decoded list of [x, y] points; an empty list when ``memory`` is falsy or
        decodes to JSON ``null``.
    """
    if not memory:
        return []
    decoded = json.loads(memory)
    return [] if decoded is None else decoded
    

@app.callback(
    Output('selected_points', 'children'),
    Input('y_map', 'clickData'),
    Input('error_map', 'clickData'),
    Input('reset_button', 'n_clicks'),
    State('selected_points', 'children'),
    prevent_initial_call=True)
def update_selected_points(y_map_click, error_map_click, reset, memory):
    """Toggle the clicked map coordinate in the stored selection, or clear it on reset.

    Args:
        y_map_click: clickData of the target map.
        error_map_click: clickData of the error map.
        reset: n_clicks of the reset button (only used as a trigger).
        memory: current JSON selection from the <selected_points> div.

    Returns:
        The new selection as a JSON list of [x, y] points, or ``dash.no_update`` when the
        trigger carries no usable click.
    """
    from dash import no_update

    triggered_id = callback_context.triggered[0]['prop_id'].split('.')[0]
    if triggered_id == 'reset_button':
        return json.dumps([])

    click = {'y_map': y_map_click, 'error_map': error_map_click}.get(triggered_id)
    if not click or not click.get('points'):
        # Unknown trigger or empty payload: previously this fell through with
        # `coordinates` unbound and raised UnboundLocalError.
        return no_update

    clicked = click['points'][0]
    coordinates = [clicked['x'], clicked['y']]

    # Clicking an already-selected point deselects it; otherwise it becomes the newest
    # (last) selection.
    points = load_points(memory)
    if coordinates in points:
        points.remove(coordinates)
    else:
        points.append(coordinates)

    return json.dumps(points)

In [None]:
# hyperspectral images

@app.callback(
    Output('y_map_title', 'children'),
    Output('y_map_title_tooltip', 'children'),
    Input('y_map_toggle', 'on'),
)
def update_y_map_title(y_pred_selected):
    """Return (title, tooltip text) of the target map for the predicted/ground-truth toggle.

    The tooltip texts use implicit string concatenation; the previous backslash
    continuations embedded a run of indentation spaces in the middle of each sentence.
    """
    if y_pred_selected:
        return (
            'Predicted Map',
            'The model predictions plotted for each measured coordinate. Click on the map to see '
            'more information about a given point. Use the toggle to see the ground truth.',
        )
    return (
        'Ground Truth Map',
        'True labels plotted for each measured coordinate. Click on the map to see more '
        'information about a given point. Use the toggle to see the model predictions.',
    )


def _add_selection_overlay(fig, memory):
    """Outline every selected coordinate on a map figure and return the figure.

    ``memory`` is the JSON selection from the hidden <selected_points> div. The most
    recently selected point is drawn in blue, all earlier ones in red.
    """
    points = load_points(memory)
    if points:
        xs, ys = zip(*points)
        fig.add_trace(
            go.Scatter(
                x=xs,
                y=ys,
                mode='markers',
                marker_color=['red'] * (len(points) - 1) + ['blue'],
                marker_symbol='square-open',
                marker_line_width=2,
                marker_size=20,
                # the overlay is not data: keep it out of the legend on both maps
                # (previously only the target map set this)
                showlegend=False,
            )
        )
    return fig


def _style_map(fig):
    """Apply the dark, transparent, equal-aspect layout shared by both maps."""
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        yaxis=dict(scaleanchor='x'),
        margin=dict(l=0, r=0, b=0, t=0),
        # a constant uirevision keeps the user's zoom/pan when the figure is rebuilt
        uirevision='None',
    )
    return fig


@app.callback(
    Output('y_map', 'figure'),
    Input('y_map_outlier_toggle', 'on'),
    Input('y_map_toggle', 'on'),
    Input('selected_points', 'children'),
)
def update_y_map(suppress_outliers, y_pred_selected, memory):
    """Render the predicted or ground-truth map with the current selection overlaid."""
    fig = intensity_map(
        y_pred if y_pred_selected else y_true,
        dim=DIM,
        index_type=INDEX,
        suppress_outliers=suppress_outliers,
        colorscale=px.colors.sequential.haline,
    )
    _add_selection_overlay(fig, memory)
    return _style_map(fig)


@app.callback(
    Output('error_map', 'figure'),
    Input('error_metric_dropdown', 'value'),
    Input('selected_points', 'children'),
)
def update_error_map(metric, memory):
    """Render the per-position error map for the chosen metric with the selection overlaid."""
    fig = plot_map(
        METRICS[metric](y_true, y_pred),
        dim=DIM,
        index_type=INDEX,
        colorscale=px.colors.sequential.deep,
    )
    _add_selection_overlay(fig, memory)
    return _style_map(fig)

In [None]:
# individual input and output plot

@app.callback(
    Output('X_plot', 'figure'),
    Input('selected_points', 'children'),
)
def update_X_plot(memory):
    """Plot the model input of the most recently selected point.

    When more than one point is selected, the mean input over the whole selection is
    added as a second curve. With an empty selection, point (0, 0) is shown.
    """
    selection = load_points(memory) or [[0, 0]]
    row_ids = [id_from_snake_index(cx, cy, DIM) for cx, cy in selection]
    last_x, last_y = selection[-1]

    spectra = [X[row_ids[-1]]]
    if len(selection) > 1:
        spectra.append(X[row_ids].mean(axis=0))

    fig = plot_spectra(
        spectra,
        calibration=X_labels,
        labels=[f'({last_x}, {last_y})', 'selected mean'],
        colormap=px.colors.qualitative.Set1,
    )
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=50),
    )
    return fig


@app.callback(
    Output('y_plot', 'figure'),
    Input('selected_points', 'children'),
)
def update_y_plot(memory, cmap=px.colors.qualitative.Set1, opacity=.7):
    """Compare ground truth and model output for the selected points.

    For each of ``y_true`` and ``y_pred``, draws the mean over the selection (only when
    more than one point is selected) followed by the values at the most recently
    selected point. With an empty selection, point (0, 0) is shown. Colours are taken
    from ``cmap`` starting at index 1.
    """
    selection = load_points(memory) or [[0, 0]]
    row_ids = [id_from_snake_index(cx, cy, DIM) for cx, cy in selection]
    last_x, last_y = selection[-1]
    show_mean = len(selection) > 1

    fig = go.Figure()
    color_idx = 0
    for source, prefix in ((y_true, 'true'), (y_pred, 'predicted')):
        curves = []
        if show_mean:
            curves.append((source[row_ids].mean(axis=0), 'mean'))
        curves.append((source[row_ids[-1]], f'at ({last_x}, {last_y})'))

        for values, suffix in curves:
            color_idx += 1
            fig.add_trace(
                go.Scatter(
                    x=y_labels,
                    y=values,
                    name=f'{prefix} {suffix}',
                    line={'color': cmap[color_idx % len(cmap)]},
                    opacity=opacity,
                )
            )

    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=50),
        xaxis_title='labels',
        yaxis_title='relative confidence',
    )
    return fig

In [None]:
# summary table

def _format_stat(value):
    """Round to 4 significant digits and render like a Python float (e.g. '0.1235', '1.235e-05')."""
    return '%s' % float('%.4g' % value)


@app.callback(
    Output('error_table', 'children'),
    Input('error_map', 'figure'),
)
def update_error_table(fig):
    """Summarise the displayed error map as table rows (mean, median, min, max).

    Reads the ``z`` values of the first trace of the error-map figure. That trace is
    assumed to be the map produced by plot_map, since the selection overlay is added
    after it.
    """
    # Coerce to float: non-finite errors may come back from the browser as JSON null,
    # which would otherwise yield an object array that the reductions cannot handle.
    # The NaN-aware reductions then ignore those entries.
    values = np.asarray(fig['data'][0]['z'], dtype=float).ravel()
    stats = (
        ('mean error', np.nanmean),
        ('median error', np.nanmedian),
        ('min error', np.nanmin),
        ('max error', np.nanmax),
    )
    rows = [
        html.Tr([html.Td(label), html.Td(_format_stat(stat_fn(values)))])
        for label, stat_fn in stats
    ]
    return [html.Tbody(rows)]

# Run

In [None]:
if __name__ == "__main__":
    app.run_server(debug=True)