In [None]:
import numpy as np
from matplotlib import cm
DIM = (70, 70)  # spatial size of the hyperspectral map (first two array axes)
MAX_NUM_CLASSES = 9  # upper bound on user-defined classes (Set1 has 9 colours)
CLASS_COLORS = cm.Set1  # colormap for class labels / segmentation
INTENSITY_COLORS = cm.Reds  # colormap for the total-intensity image
RANGE_SLIDER_COLORS = ['red']  # line colour(s) of the mean-spectrum plot

# Plotly layout options shared by the graphs: dark theme, transparent
# backgrounds and no outer margins.
GRAPH_STYLE = {
    'template': 'plotly_dark',
    'plot_bgcolor': 'rgba(0, 0, 0, 0)',
    'paper_bgcolor': 'rgba(0, 0, 0, 0)',
    'margin': {'l': 0, 'r': 0, 'b': 0, 't': 0},
}

# TODO (open problems): class button colours; model output display

# data

In [None]:
# Wavelength axis, one entry per spectral channel.
# np.load opens and closes the file itself (the previous open() handles were never closed).
calibration = np.load('data/X_labels.npy')
# Measured data, dimensions are (index of measurement, wavelength).
X = np.load('data/X.npy')

# Make the hyperspectral map with axes (DIM[0], DIM[1], wavelength).
# reshape raises if the number of measurements does not match DIM, whereas the
# in-place ndarray.resize would silently zero-pad or truncate the data.
X = X.reshape(DIM + (X.shape[1],))
# Input data has snake (boustrophedon) indexing: every other row is stored
# reversed, so flip those rows back.
# NOTE(review): assumes the even rows (0, 2, ...) are the reversed ones — confirm against the scan setup.
X[::2, :] = X[::2, ::-1]

# define class labels and colors
num_classes = 4
if num_classes > MAX_NUM_CLASSES:
    raise RuntimeError(f'Classes supported up to {MAX_NUM_CLASSES}!')

In [None]:
class DummyModel:
    """Placeholder segmentation model that produces a fixed, hand-made label map.

    The training data is ignored; the model only exists so the app can be wired
    up end-to-end before a real model is available. It produces labels 1-4.
    """

    def __init__(self) -> None:
        pass

    def fit(self, X, y, *args, **kwargs):
        """Build the fixed label map sized to the spatial dimensions of ``X``.

        Args:
            X: hyperspectral map; only ``X.shape[:2]`` is used (this replaces the
               former dependency on the global ``DIM``, which equals that shape
               for the map used in this notebook).
            y: manual labels (ignored).

        Returns:
            ``self``, so ``fit(...).predict(...)`` can be chained.
        """
        labels = np.zeros(np.shape(X)[:2])
        labels[:, ::2] = 1    # even columns
        labels[:, 1::2] = 2   # odd columns
        labels[:30, :30] = 3  # top-left block
        labels[50:, 50:] = 4  # bottom-right block
        self.y = labels
        return self

    def predict(self, X):
        """Return the label map built by the most recent :meth:`fit` call.

        ``X`` is ignored. Calling this before ``fit`` raises ``AttributeError``.
        """
        return self.y

# models

In [None]:
# Available segmentation models, declared once as (display name, model) pairs so
# the names and the model instances cannot drift out of sync.
_MODEL_CHOICES = [
    ('Dummy model', DummyModel()),
]
models = [model for _name, model in _MODEL_CHOICES]

# maps model names to indices over <models> array
model_names = {name: i for i, (name, _model) in enumerate(_MODEL_CHOICES)}

# layout

In [None]:
import dash_bootstrap_components as dbc
import plotly.express as px
import json
from dash import html, dcc, no_update, ctx
from jupyter_dash import JupyterDash
from dash import Input, Output, State
from dash.exceptions import PreventUpdate

# our modules you can modify
import libs_tools.dash.custom_components as cc
from libs_tools.visualization import plot_spectra, plot_map

In [None]:
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.DARKLY])
app.title = 'LIBS Segmentation'

# Short text-based explanation of the app (placeholder content for now).
introduction = dbc.Card(
    children=[
        dbc.CardHeader('Introduction'),
        dbc.CardBody('Introduction goes here'),
    ]
)

# Hyperspectral image together with its drawing controls:
#   * `class_button` selects the interaction mode; update_manual_labels ignores
#     strokes for values < 0, resets the labels for -2, and otherwise paints the
#     stroke with the selected value (0 = "Clear" erases);
#   * `width` is the brush width in pixels (2 when left empty). min/step keep
#     the browser input to positive integers.
image_panel = dbc.Card([
    dbc.CardHeader('Image panel'),
    dbc.CardBody([
        dbc.Row([
            dbc.Col([
                dbc.Card([
                    dbc.CardBody(dbc.RadioItems(
                        id="class_button",
                        className="btn-group",
                        inputClassName="btn-check",
                        labelClassName="btn btn-outline-primary",
                        labelCheckedClassName="active",
                        options=[
                            {"label": "Reset", "value": -2},  # TODO reset should have one button (not be a mode)
                            {"label": "Zoom", "value": -1},
                            {"label": "Clear", "value": 0},
                            *({'label': f'Class {i}', 'value': i} for i in range(1, num_classes + 1)),
                        ],
                        value=1
                    )),
                ]),
            ]),

            dbc.Col([
                dbc.Card([
                    dbc.CardBody(dcc.Input(
                        id='width',
                        type='number',
                        min=1,
                        step=1,
                        placeholder='Brush width (2)'
                    )),
                ]),
            ]),
        ]),
        dbc.Row([
            dbc.Card(dbc.CardBody(dcc.Graph(
                id='x_map',
                config={
                    'displayModeBar': False
                },
            ))),
        ])
    ])
])

# Placeholder for application-level features.
# TODO: saving and loading past work, themes?, colour scales?
application_panel = dbc.Card(
    children=[dbc.CardHeader('Application panel')],
)

options = [
    {'label': model_name, 'value': model_index}
    for model_name, model_index in model_names.items()
]
# Controls the segmentation model and the output display: `model_btn` toggles
# between intensity image and segmentation (by click parity in update_X_map),
# `model_identifier` picks the model.
model_panel = dbc.Card([
    dbc.CardHeader('Model panel'),
    dbc.Row(children=[
        dbc.Col(width=6, children=[dbc.Button('Show segmentation', id='model_btn')]),
        dbc.Col(width=6, children=[
            dbc.Select(
                id='model_identifier',
                # No initial value: the first model's name is shown as placeholder,
                # and update_X_map treats an empty selection as model 0.
                placeholder=options[0]['label'],
                options=options,
            ),
        ]),
    ]),
])

# Spectrum of the pixel currently hovered on the image. TODO: add support for clicking.
selected_spectra = dbc.Card(children=[
    dbc.CardHeader('Currently selected spectrum'),
    dbc.CardBody(children=[dcc.Graph(id='point_plot')]),
])

# Mean spectrum over all pixels. Its relayoutData (the zoomed x-range) is read by
# update_X_map to choose which wavelengths are summed into the intensity image;
# the y-axis is fixed so only the wavelength range can be changed.
fig = plot_spectra([X.mean(axis=(0, 1))], calibration=calibration, colormap=RANGE_SLIDER_COLORS)
fig.update_layout(yaxis={'fixedrange': True}, **GRAPH_STYLE)

range_slider = dbc.Card(children=[
    dbc.CardHeader('Mean spectrum (resize to change how the total intensity is calculated)'),
    dbc.CardBody(children=[dcc.Graph(id='range_slider', figure=fig)]),
])

# Auxiliary components: the manual-label store, test outputs, the URL tracker
# and the hidden div that receives the browser window size.
meta = html.Div(
    children=[
        # TODO storage type? currently loses data on reload
        dcc.Store(id='manual_labels', data=np.zeros(DIM)),
        html.Div(id='test'),
        html.Div(id='test2', style={'display': 'none'}),
        dcc.Location(id='url'),
        html.Div(id='screen_resolution', style={'display': 'none'}),
    ],
    # TODO style = no-display
)

# Page layout: introduction on top; image and application panels on the left
# (width 7), model, mean-spectrum and hovered-spectrum panels on the right
# (width 4); auxiliary components (`meta`) at the bottom.
app.layout = html.Div(children=[
    dbc.Container(
        fluid=True,
        children=[
            dbc.Row(justify='evenly', children=[dbc.Col(introduction)]),
            html.Br(),
            dbc.Row(
                justify="evenly",
                children=[
                    dbc.Col(width=7, children=[
                        dbc.Row([dbc.Col(image_panel)]),
                        html.Br(),
                        dbc.Row([application_panel]),
                    ]),
                    dbc.Col(width=4, children=[
                        dbc.Row([model_panel]),
                        html.Br(),
                        dbc.Row([range_slider]),
                        html.Br(),
                        dbc.Row([selected_spectra]),
                    ]),
                ],
            ),
            dbc.Row([dbc.Col([meta])]),
        ],
    ),
])

# callbacks

In [None]:
def mouse_path_to_indices(path):
    """Convert a drawn path string such as ``"M1.2,3.6L5,6"`` into integer points.

    The path is split on ``L``; ``M``/``Z`` markers are stripped and each
    ``x,y`` pair is rounded with ``np.rint`` (half to even).

    Returns:
        list of ``(x, y)`` tuples of Python ints, in drawing order.
    """
    pairs = []
    for segment in path.split("L"):
        coords = segment.replace("M", "").replace("Z", "")
        pairs.append(coords.split(","))
    rounded = np.rint(np.array(pairs, dtype=float)).astype(int)
    return [tuple(point) for point in rounded.tolist()]

In [None]:
from PIL import Image, ImageDraw
from matplotlib import cm

# Report the browser window's inner size as a JSON string
# '{"height": h, "width": w}' into the hidden `screen_resolution` div.
# update_X_map parses it with json.loads to size the hyperspectral image.
# It only re-runs when the `url` href changes (presumably including the initial
# page load); resizing the window does not update it.
app.clientside_callback(
    """
    function(href) {
        var w = window.innerWidth;
        var h = window.innerHeight;
        return JSON.stringify({'height': h, 'width': w});
    }
    """,
    Output('screen_resolution', 'children'),
    Input('url', 'href')
)


@app.callback(
    Output('manual_labels', 'data'),
    Input('manual_labels', 'data'),
    Input('class_button', 'value'),
    Input('width', 'value'),
    Input('x_map', 'relayoutData'),
    prevent_initial_call=True,
)
def update_manual_labels(memory, mode, width, relayout):
    """Rasterise the most recently drawn stroke into the manual label map.

    Args:
        memory: current label map from the ``manual_labels`` store (converted with
            ``np.array``; expected shape ``DIM``).
        mode: ``class_button`` value. -2 resets the map, other negative values
            (Zoom) do not draw, 0 ("Clear") paints 0 and ``1..num_classes`` paint
            that class.
        width: brush width from the ``width`` input; empty/None falls back to 2 px.
        relayout: ``relayoutData`` of ``x_map``; a drawn stroke is read from the
            last entry of its ``'shapes'`` list.

    Returns:
        The updated label map as an array.

    Raises:
        PreventUpdate: when the trigger is not a new stroke on ``x_map`` or the
            current mode does not draw.
    """
    if mode == -2:
        return np.zeros(DIM)
    if ctx.triggered_id != 'x_map' or not relayout or mode is None or mode < 0:
        raise PreventUpdate
    shapes = relayout.get('shapes')
    if not shapes or 'path' not in shapes[-1]:
        # relayout events without a drawn path (e.g. zoom) or with an empty shape
        # list carry nothing to rasterise
        raise PreventUpdate
    img = Image.fromarray(np.array(memory))
    draw = ImageDraw.Draw(img)
    node_coords = mouse_path_to_indices(shapes[-1]['path'])
    # non-positive widths would produce a degenerate brush; clamp to >= 1 px
    brush_width = max(1, int(width)) if width else 2
    # TODO bug leaves little holes
    draw.line(node_coords, fill=mode, width=brush_width, joint='curve')
    return np.asarray(img)


@app.callback(
    Output('test', 'children'),
    Input('x_map', 'figure'),
)
def clear_test_output(inp):
    """Blank the ``test`` div whenever the ``x_map`` figure changes.

    Named distinctly so it does not shadow the module-level name of the
    ``update_manual_labels`` label-drawing callback (Dash registers callbacks at
    decoration time, so only the Python binding was affected).
    """
    return ''


def _selected_wavelength_mask(wave_range):
    """Turn ``range_slider`` relayoutData into a boolean mask over ``calibration``.

    Returns ``None`` ("use every wavelength") when no explicit x-range is present:
    on autorange/autosize, before any interaction, or for unrelated relayout events
    (which previously raised KeyError). Accepts both the
    ``'xaxis.range[0]'``/``'xaxis.range[1]'`` form and the list form
    ``'xaxis.range': [low, high]`` that plotly may also emit.
    """
    if not wave_range or 'xaxis.autorange' in wave_range or 'autosize' in wave_range:
        return None
    if 'xaxis.range[0]' in wave_range and 'xaxis.range[1]' in wave_range:
        bounds = (wave_range['xaxis.range[0]'], wave_range['xaxis.range[1]'])
    elif 'xaxis.range' in wave_range:
        bounds = wave_range['xaxis.range']
    else:
        return None
    low, high = sorted(float(bound) for bound in bounds)
    return (calibration >= low) & (calibration <= high)


def _normalized(values):
    """Scale ``values`` by their maximum; a zero maximum (e.g. an empty wavelength
    selection) yields zeros instead of NaNs from 0/0."""
    peak = values.max()
    if peak == 0:
        return np.zeros_like(values, dtype=float)
    return values / peak


@app.callback(
    Output('x_map', 'figure'),
    Input('range_slider', 'relayoutData'),
    Input('manual_labels', 'data'),
    Input('screen_resolution', 'children'),
    Input('class_button', 'value'),
    Input('model_btn', 'n_clicks'),
    Input('model_identifier', 'value'),
)
def update_X_map(wave_range, manual_labels, screen_resolution, mode, show_segment_btn, model_identifier):
    """Render the hyperspectral map with the manual labels painted on top.

    The display mode is toggled by the parity of ``model_btn`` clicks:
      * no/even clicks: total intensity over the wavelengths selected on
        ``range_slider``, coloured with ``INTENSITY_COLORS``;
      * odd clicks: segmentation of the selected model, coloured with ``CLASS_COLORS``.
    Pixels with a manual label > 0 are always drawn in their class colour.

    Args:
        wave_range: relayoutData of the ``range_slider`` graph (may be None).
        manual_labels: label map from the ``manual_labels`` store.
        screen_resolution: JSON string ``{"height", "width"}`` of the browser
            window, or None if not reported yet (plotly's default size is used then).
        mode: ``class_button`` value; negative values (Reset/Zoom) navigate,
            0 ("Clear") and class values draw strokes.
        show_segment_btn: ``n_clicks`` of ``model_btn`` (None before the first click).
        model_identifier: index into ``models`` (None/empty selects model 0).

    Returns:
        plotly figure for ``x_map``.
    """
    manual_labels = np.array(manual_labels)
    # broadcast manual labels to the 4 RGBA channels so np.where can pick per pixel
    mask = np.repeat(manual_labels[:, :, np.newaxis], 4, axis=2)
    label_rgba = CLASS_COLORS(manual_labels / (num_classes + 1), alpha=1.) * 255

    show_segmentation = show_segment_btn is not None and show_segment_btn % 2 != 0
    if show_segmentation:
        # fit/predict only when the segmentation is actually displayed
        model_index = int(model_identifier) if model_identifier else 0
        y = models[model_index].fit(X, manual_labels).predict(X)
        background_rgba = CLASS_COLORS(y / (num_classes + 1), alpha=.8) * 255
    else:
        wavelengths = _selected_wavelength_mask(wave_range)
        spectra = X if wavelengths is None else X[:, :, wavelengths]
        background_rgba = INTENSITY_COLORS(_normalized(spectra.sum(axis=2)), alpha=1.) * 255
    img = np.where(mask > 0, label_rgba, background_rgba)

    fig = px.imshow(img=img, labels={})
    # 'none' hides the tooltip but, unlike 'skip', still fires hover events,
    # which drive the `point_plot` callback.
    fig.update_traces(hoverinfo='none', hovertemplate=None)
    layout = dict(
        GRAPH_STYLE,
        margin=dict(l=0, r=0, b=0, t=0, pad=0),
        # Reset (-2) and Zoom (-1) navigate; Clear (0) and the classes draw strokes
        # (update_manual_labels ignores strokes for negative modes).
        dragmode='zoom' if mode is not None and mode < 0 else 'drawopenpath',
        newshape=dict(opacity=0),  # TODO shapes are currently just hidden but not deleted
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        uirevision='None',  # keeps zoom state across figure updates
        shapes=[],  # TODO this does not remove the shapes!
    )
    if screen_resolution:
        resolution = json.loads(screen_resolution)
        side = int(min(resolution['height'] * .9, resolution['width'] * .7))
        layout.update(width=side, height=side)
    fig.update_layout(**layout)
    fig.update_shapes(editable=False)

    return fig


@app.callback(
    Output('point_plot', 'figure'),
    Input('x_map', 'hoverData'),
)
def update_point_plot(hover):
    """Plot the spectrum of the pixel currently hovered on the ``x_map`` image.

    Args:
        hover: ``hoverData`` of ``x_map``. When it is None or has no points, pixel
            (0, 0) is shown.

    Returns:
        plotly figure from ``plot_spectra`` styled with ``GRAPH_STYLE``.
    """
    row, col = 0, 0
    if hover and hover.get('points'):
        point = hover['points'][0]
        # On the image, x is the column and y the row, while X is indexed
        # [row, column, wavelength]; the former X[x, y] transposed the map.
        col, row = int(round(point['x'])), int(round(point['y']))
        # clamp so an edge hover can never index outside the map
        row = min(max(row, 0), X.shape[0] - 1)
        col = min(max(col, 0), X.shape[1] - 1)
    fig = plot_spectra([X[row, col, :]], calibration=calibration)
    fig.update_layout(**GRAPH_STYLE)
    return fig

# run

In [None]:
if __name__ == "__main__":
    # Start the Dash server only when this code is run directly, not when imported
    # as a module. debug=True enables Dash's development tools.
    app.run_server(debug=True)