In [1]:
import numpy as np

# Data (synthetic placeholders until real spectra are loaded)

In [2]:
from src.visualization import IndexType

# Scan order used to lay the measurement rows out as a 2-D map. Only the
# horizontal-snake order is supported: the click handlers map (x, y) back to a
# row index with id_from_snake_index.
INDEX = IndexType.HORIZONTAL_SNAKE
if INDEX != IndexType.HORIZONTAL_SNAKE:
    # `NotImplemented` is a sentinel value, not an exception class; raising it
    # fails with "TypeError: exceptions must derive from BaseException".
    raise NotImplementedError(
        f'Only IndexType.HORIZONTAL_SNAKE is supported, got {INDEX!r}'
    )
# Map dimensions; the map holds DIM[0] * DIM[1] measurement positions.
# NOTE(review): presumably (rows, columns) — confirm against intensity_map / plot_map.
DIM = (266, 500)

In [3]:
from src.preprocessing import match_wavelengths

"""
s1_wavelengths = np.load(open('data/s1_wavelengths.npy', 'rb'))
s2_wavelengths = np.load(open('data/s2_wavelengths.npy', 'rb'))
s1_spectra = np.load(open('data/s1.npy', 'rb'))
s2_spectra = np.load(open('data/s2.npy', 'rb'))
"""

N = DIM[0] * DIM[1]
X_labels = np.arange(1000)
y_labels = np.arange(10)
y_true = np.random.uniform(0., 1., (N, y_labels.shape[0]))
y_pred = np.random.uniform(0., 1., (N, y_labels.shape[0]))
X = np.random.beta(0.5, 0.5, (N, X_labels.shape[0]))

# Layout

In [4]:
import dash_bootstrap_components as dbc
from dash import html
from dash import dcc
from jupyter_dash import JupyterDash
from src.rowwise_metrics import rowwise_cosine, rowwise_euclid, rowwise_kl_divergence
from src.custom_components import place_in_container, make_dropdown, make_plot, make_error_table, make_tooltip_title, make_toggle
from src.visualization import plot_spectra

# Row-wise error metrics offered in the error-map dropdown, keyed by the label
# shown to the user. The error-map callback calls each one as metric(y_true, y_pred).
METRICS = dict([
    ('cosine distance', rowwise_cosine),
    ('euclidean distance', rowwise_euclid),
    ('kullback-leibler divergence', rowwise_kl_divergence),
])

# TODO: what if a real calibration is used? (X_labels is currently a placeholder np.arange)

dbc_css = "https://cdn.jsdelivr.net/gh/AnnMarieW/dash-bootstrap-templates/dbc.min.css"
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.SLATE, dbc_css])
app.title = 'LIBS Error Analysis'
app.layout = \
place_in_container([

    # intensity map & error map
    # error table
    dbc.Row([
        dbc.Col([
            make_plot(
                dbc.Row([
                    dbc.Col([
                        make_tooltip_title('Original map', 'y_map_title', 'TBD'),
                    ]),
                    dbc.Col([
                        make_toggle('', 'Suppress outliers', 'y_map_outlier_toggle'),
                    ]),
                    dbc.Col([
                        make_toggle('y_true', 'y_pred', 'y_map_toggle'),
                    ]),
                ]),
                'y_map',
            ),
        ], width=5),
        dbc.Col([
            make_plot(
                dbc.Row([
                    dbc.Col([
                        make_tooltip_title('Error map', 'error_map_title', 'TBD'),
                    ]),
                    dbc.Col([
                        make_dropdown('error_metric_dropdown', list(METRICS.keys())),
                    ]),
                ]),
                'error_map',
            ),
        ], width=5),
        dbc.Col([
            make_error_table('Summary table', 'error_table')
        ], width=2),
    ], align='center'),
    html.Br(),

    # input and output
    dbc.Row([
        dbc.Col([
            make_plot(
                make_tooltip_title('Model input', 'X_plot_title', 'TBD'),
                'X_plot',
            )
        ], width=7),
        dbc.Col([
            make_plot(
                make_tooltip_title('Model output and expected output', 'y_plot_title', 'TBD'),
                'y_plot',
            )
        ], width=5),
    ], align='center',),
    html.Br(),
])

# Callbacks

In [5]:
import plotly.express as px
from src.visualization import plot_map, intensity_map
from dash import Input, Output

@app.callback(
    Output('y_map_title', 'children'),
    Input('y_map_toggle', 'on'),
)
def update_y_map_title(y_pred_selected):
    """Retitle the intensity map to match the y_true / y_pred toggle state."""
    if not y_pred_selected:
        return 'Expected map'
    return 'Predicted map'


@app.callback(
    Output('y_map', 'figure'),
    Input('y_map_outlier_toggle', 'on'),
    Input('y_map_toggle', 'on'),
)
def update_y_map(suppress_outliers, y_pred_selected):
    """Render the intensity map of either the predicted or the expected outputs.

    Parameters
    ----------
    suppress_outliers : bool
        State of the "Suppress outliers" toggle, forwarded to intensity_map.
    y_pred_selected : bool
        Toggle state; truthy shows y_pred, falsy shows y_true.
    """
    shown = y_true
    if y_pred_selected:
        shown = y_pred
    figure = intensity_map(
        shown,
        dim=DIM,
        index_type=INDEX,
        suppress_outliers=suppress_outliers,
        colorscale=px.colors.sequential.haline,
    )
    transparent = 'rgba(0, 0, 0, 0)'
    # Transparent backgrounds let the dark page theme show through; the square
    # aspect (scaleanchor) keeps map pixels undistorted.
    figure.update_layout(
        template='plotly_dark',
        plot_bgcolor=transparent,
        paper_bgcolor=transparent,
        yaxis={'scaleanchor': 'x'},
        margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
    )
    return figure


@app.callback(
    Output('error_map', 'figure'),
    Input('error_metric_dropdown', 'value'),
)
def update_error_map(metric):
    """Render the per-position error between y_true and y_pred as a map.

    Parameters
    ----------
    metric : str or None
        Label of the selected entry in METRICS. A Dash dropdown reports None
        when it has no selection (e.g. after the user clears it), which used
        to raise KeyError here; any value not in METRICS now falls back to
        the first metric.

    TODO KLD currently not working
    """
    if metric not in METRICS:
        metric = next(iter(METRICS))
    errors = METRICS[metric](y_true, y_pred)

    fig = plot_map(
        errors,
        dim=DIM,
        index_type=INDEX,
        colorscale=px.colors.sequential.deep,
    )
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        yaxis=dict(scaleanchor='x'),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig

In [6]:
import plotly.graph_objects as go
from src.visualization import id_from_snake_index, plot_spectra
from dash import callback_context

def _clicked_point(y_map_click, error_map_click):
    """Return the (x, y) map coordinate of the most recent click on either map.

    Uses ``callback_context`` to determine which of the two ``clickData``
    inputs fired. Falls back to (0, 0) on the initial call (nothing clicked
    yet) and when the triggering ``clickData`` is empty, instead of failing
    on ``None['points']``.
    """
    triggered = callback_context.triggered
    trigger_id = triggered[0]['prop_id'].split('.')[0] if triggered else ''
    click = {'y_map': y_map_click, 'error_map': error_map_click}.get(trigger_id)
    if not click or not click.get('points'):
        return 0, 0
    point = click['points'][0]
    return point['x'], point['y']


@app.callback(
    Output('X_plot', 'figure'),
    Input('y_map', 'clickData'),
    Input('error_map', 'clickData'),
)
def update_X_plot(y_map_click, error_map_click):
    """Plot the model input X at the map position last clicked on either map.

    TODO doesn't have to be for spectra... change axis labels
    """
    x, y = _clicked_point(y_map_click, error_map_click)
    # Row of X corresponding to map coordinate (x, y) under the snake index.
    idx = id_from_snake_index(x, y, DIM)

    fig = plot_spectra(
        [X[idx]],
        calibration=X_labels,
        labels=[f'({x}, {y})'],
        colormap=px.colors.qualitative.Set1,
    )
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=50),
    )
    return fig


@app.callback(
    Output('y_plot', 'figure'),
    Input('y_map', 'clickData'),
    Input('error_map', 'clickData'),
)
def update_y_plot(y_map_click, error_map_click, cmap=px.colors.qualitative.Set1, opacity=.7):
    """Overlay expected (y_true) and predicted (y_pred) outputs at the clicked position.

    ``cmap`` and ``opacity`` are not Dash inputs; their defaults are always
    used when Dash invokes the callback.
    """
    x, y = _clicked_point(y_map_click, error_map_click)
    idx = id_from_snake_index(x, y, DIM)

    fig = go.Figure()
    for i, (values, label) in enumerate(zip([y_true, y_pred], ['original', 'predicted'])):
        fig.add_trace(
            go.Scatter(
                x=y_labels,
                y=values[idx],
                name=label + f' at ({x}, {y})',
                line={'color': cmap[i % len(cmap)]},
                opacity=opacity,
            )
        )
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=50),
    )
    return fig

In [7]:
# (row label, reducer) pairs shown in the summary table, in display order.
_ERROR_SUMMARY_STATS = (
    ('mean error', np.mean),
    ('median error', np.median),
    ('min error', np.min),
    ('max error', np.max),
)


@app.callback(
    Output('error_table', 'children'),
    Input('error_map', 'figure'),
)
def update_error_table(fig):
    """Summarise the error map's values as mean/median/min/max table rows.

    Parameters
    ----------
    fig : dict or None
        The error map figure as Dash passes it back (a plain dict); the
        values are read from the first trace's ``z``.

    Returns
    -------
    list
        A single ``html.Tbody``. It is empty when the figure has no traces or
        no values, instead of raising IndexError/ValueError.
    """
    traces = (fig or {}).get('data') or []
    if not traces:
        return [html.Tbody([])]
    # dtype=float maps JSON nulls (None) to NaN; an object array would make
    # the numpy reductions below raise TypeError.
    values = np.asarray(traces[0].get('z', []), dtype=float).ravel()
    if values.size == 0:
        return [html.Tbody([])]

    # '%.4g' rounds to 4 significant digits; float() then drops exponent notation.
    rows = [
        html.Tr([html.Td(label), html.Td('%s' % float('%.4g' % reduce(values)))])
        for label, reduce in _ERROR_SUMMARY_STATS
    ]
    return [html.Tbody(rows)]

# Run

In [8]:
if __name__ == "__main__":
    app.run_server(debug=True)

Dash app running on http://127.0.0.1:8050/
