In [None]:
import numpy as np
from src.constants import *

# Number of paintable segmentation classes: one "Class i" radio option per class
# in the image panel, and the cluster count of the KMeans baseline model.
NUM_CLASSES = 4

# data

In [None]:
# Load the wavelength calibration and the raw measurements. Passing the path
# straight to np.load lets NumPy open *and close* the file; np.load(open(...))
# would leave both file handles open for the lifetime of the kernel.
calibration = np.load('data/calibration.npy')  # wavelengths, one value per spectral channel
X = np.load('data/X.npy')  # measured data, dimensions are (index of measurement, wavelength)

# make hyperspectral map: (measurement, wavelength) -> DIM + (wavelength,)
# NOTE(review): ndarray.resize silently zero-pads / truncates if the number of
# measurements differs from prod(DIM); confirm the data always matches DIM.
X.resize(DIM + (calibration.shape[0],))
# input data has snake index: the scan direction alternates between rows, so
# every other row (here the even ones) is flipped back along the column axis
X[::2, :] = X[::2, ::-1]

# cache the actual input to the models: one row per pixel (row-major order),
# one column per wavelength
X_in = X.reshape((-1, calibration.shape[0]))

# Wrappers

In [None]:
class Wrapper:
    """Adapter that trains a model only on the ``(x, y)`` pairs a predicate accepts.

    ``fit`` hands the wrapped model the filtered subset (e.g. only pixels that
    carry a manual label), while ``predict`` runs on every input unchanged.

    Args:
        model: estimator exposing ``fit(X, y)`` and ``predict(X)``.
        bool_function: predicate called with one ``(x, y)`` tuple per sample;
            pairs for which it returns a truthy value are used for training.
    """

    def __init__(self, model, bool_function) -> None:
        self.model = model
        self.predicate = bool_function

    def fit(self, X, y):
        """Fit the wrapped model on the pairs of ``X`` and ``y`` kept by the predicate.

        Returns:
            self, so calls can be chained (``wrapper.fit(X, y).predict(X)``).

        Raises:
            ValueError: if the predicate rejects every pair, i.e. there is
                nothing to train on (for example no manual labels drawn yet).
        """
        kept = [pair for pair in zip(X, y) if self.predicate(pair)]
        if not kept:
            # zip(*[]) would otherwise fail with an opaque unpacking error
            raise ValueError('Wrapper.fit: no samples passed the filtering predicate; '
                             'nothing to train on.')
        X_, y_ = zip(*kept)

        self.model.fit(X_, y_)

        return self

    def predict(self, X):
        """Predict with the wrapped model on all of ``X`` (no filtering)."""
        return self.model.predict(X)

# models

In [None]:
from sklearn.svm import SVC
from sklearn.semi_supervised import SelfTrainingClassifier

# Semi-supervised option: an SVC with probability estimates enabled, wrapped in
# sklearn's SelfTrainingClassifier (which pseudo-labels unlabelled samples, marked
# -1, while fitting).
# NOTE(review): neither `svc` nor `self_training_model` appears in the `models`
# list used by the app — presumably an experiment; confirm before relying on it.
svc = SVC(probability=True, gamma="auto")
self_training_model = SelfTrainingClassifier(svc)

In [None]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
# TODO knn


In [None]:
def filtering_predicate(pair):
    """Keep a ``(sample, label)`` pair only if it carries a manual label.

    ``-1`` marks unlabelled pixels (the "Clear" value and the initial fill of
    the label store), so those pairs are rejected. Every other label value,
    including the "Ignore" value ``-2``, is kept.
    NOTE(review): confirm that "Ignore" pixels are really meant to be trained on.
    """
    label = pair[1]
    return label != -1

In [None]:
from sklearn.cluster import KMeans

# Candidate segmentation models as (display name, estimator) pairs. Keeping each
# name next to its estimator means the UI labels cannot drift out of sync with
# the indices used by `calculate_model_output`. Supervised estimators are wrapped
# so they only train on pixels that carry a manual label; KMeans ignores labels.
# Stochastic estimators get a fixed random_state so retraining is reproducible.
_named_models = [
    ('Naive KMeans', KMeans(n_clusters=NUM_CLASSES, n_init='auto', random_state=0)),
    ('SVC', Wrapper(SVC(random_state=0), filtering_predicate)),
    ('KNN', Wrapper(KNeighborsClassifier(n_jobs=-1), filtering_predicate)),
    ('Random Forest', Wrapper(RandomForestClassifier(max_depth=3, random_state=0), filtering_predicate)),
    ('Gradient Boosting', Wrapper(GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, max_depth=1, random_state=0), filtering_predicate)),
    ('MLP', Wrapper(MLPClassifier(hidden_layer_sizes=(150, 100, 50), max_iter=300, activation='relu', solver='adam', random_state=1), filtering_predicate)),
    ('MLP_bigger', Wrapper(MLPClassifier(hidden_layer_sizes=(256, 128, 64, 32), max_iter=300, activation='relu', solver='adam', random_state=1), filtering_predicate)),
]

models = [model for _, model in _named_models]

# maps model names to indices over <models> array
model_names = {name: i for i, (name, _) in enumerate(_named_models)}


# layout

In [None]:
import dash_bootstrap_components as dbc
import plotly.express as px
import json
from dash import html, dcc, no_update, ctx
from jupyter_dash import JupyterDash
from dash import Input, Output
from dash.exceptions import PreventUpdate

# our modules you can modify
import libs_tools.dash.custom_components as cc
from libs_tools.visualization import plot_spectra, plot_map

In [None]:
# The Dash application, run from the notebook kernel through JupyterDash and
# styled with the FLATLY Bootstrap theme.
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.FLATLY])
app.title = 'LIBS Segmentation'

# short text-based explanation of the app (placeholder text for now)
introduction = dbc.Card([
    dbc.CardHeader('Introduction'),
    dbc.CardBody('Introduction goes here'),
])

# hyperspectral image, along with the drawing panel
image_panel = dbc.Card([
    dbc.CardHeader('Image panel'),
    dbc.CardBody([
        dbc.Row([
            dbc.Col([
                dbc.Card([
                    # Drawing mode selector, styled as a Bootstrap button group.
                    # -2 ("Ignore"), -1 ("Clear") and 0..NUM_CLASSES-1 are painted
                    # into the label map as-is by update_manual_labels; -3 and -4
                    # switch the map to zoom, and -4 additionally resets all labels.
                    dbc.CardBody(dbc.RadioItems(
                        id="mode_button",
                        className="btn-group",
                        inputClassName="btn-check",
                        labelClassName="btn btn-outline-primary",
                        labelCheckedClassName="active",
                        options=[
                            {"label": "Reset", "value": -4},  # TODO reset should be a separate button (not be a mode)
                            {"label": "Zoom", "value": -3},
                            {"label": "Clear", "value": -1},
                             {"label": "Ignore", "value": -2}, ] + [
                            {'label': f'Class {i}', 'value': i} for i in range(NUM_CLASSES)
                        ],
                        value=0
                    )),
                ]),
            ]),

            dbc.Col([
                dbc.Card([
                    # brush width in pixels; update_manual_labels uses 2 when empty
                    dbc.CardBody(dcc.Input(
                        id='width',
                        type='number',
                        placeholder='Brush width (2)'
                    )),
                ]),
            ]),
        ]),
        dbc.Row([
            # the hyperspectral map itself (figure built by update_X_map);
            # the plotly mode bar is hidden
            dbc.Card(dbc.CardBody(dcc.Graph(
                id='x_map',
                config={
                    'displayModeBar': False
                },
            ))),
        ])
    ])
])

# saving and loading past work, themes?, colorscales?
# Rows are wrapped in a CardBody like every other card's content; a Row placed
# directly in a Card lets its negative gutters spill past the card edges.
application_panel = dbc.Card([
    dbc.CardHeader('Application panel'),
    dbc.CardBody(
        dbc.Row([
            dbc.Col([dbc.Button('Download Manual Labels', id='save_labels')]),
            # dcc.Upload turns the wrapped button into a file picker; the file
            # arrives in `load_labels.contents` (handled by upload_labels)
            dbc.Col(dcc.Upload(dbc.Button('Upload Manual Labels'), id='load_labels')),
            dbc.Col(dbc.Button('Download Segmentation', id='save_output')),
        ]),
    ),
])

options = [{'label': name, 'value': val} for name, val in model_names.items()]
# controls the segmentation model and the output display
model_panel = dbc.Card([
    dbc.CardHeader('Model panel'),
    dbc.CardBody(
        dbc.Row([
            # starts disabled; disable_show_segmentation enables it after the first training
            dbc.Col([dbc.Button('Show Segmentation', id='show_output_btn', disabled=True)], width=4),
            dbc.Col([dbc.Button('Train Model', id='retrain_btn')], width=4),
            # no initial value: calculate_model_output falls back to model 0,
            # whose name is therefore shown as the placeholder
            dbc.Col([dbc.Select(
                id='model_identifier',
                placeholder=options[0]['label'],
                options=options,
            )], width=4),
        ]),
    ),
])

# spectrum of the pixel currently hovered on the map (figure built by
# update_point_plot), TODO add support for clicking
selected_spectra = dbc.Card([
    dbc.CardHeader('Currently selected spectrum'),
    dbc.CardBody([
        dcc.Graph(id='point_plot'),
    ])
])

# Mean spectrum over all pixels (averaged over both spatial axes). Zooming its
# x-axis emits relayoutData that update_X_map uses to choose the wavelength
# window summed into the intensity image; the y-axis is fixed so only the
# wavelength range can change. Backgrounds are transparent to blend with the card.
# NOTE(review): RANGE_SLIDER_COLORS comes from the `src.constants` star import.
fig = plot_spectra([X.mean(axis=(0, 1))], calibration=calibration, colormap=RANGE_SLIDER_COLORS)
fig.update_layout(
    template='plotly_white',
    yaxis=dict(fixedrange=True,),
    plot_bgcolor= 'rgba(0, 0, 0, 0)',
    paper_bgcolor= 'rgba(0, 0, 0, 0)',
    margin=dict(l=0, r=0, b=0, t=0,),
)

range_slider = dbc.Card([
    dbc.CardHeader('Mean spectrum (resize to change how the total intensity is calculated)'),
    dbc.CardBody([
        dcc.Graph(id='range_slider', figure=fig),
    ])
])

# Non-visual state and helper components used by the callbacks.
meta = html.Div(
    [
        # label map of shape DIM, -1 (unlabelled) everywhere until the user paints
        dcc.Store(id='manual_labels', data=np.zeros((DIM)) - 1), # TODO storage type? currently loses data on reload
        # latest prediction of calculate_model_output; None until a model is trained
        dcc.Store(id='model_output', data=None), # TODO storage type? currently loses data on reload
        # output target of the placeholder `idk` callback
        html.Div(id='test'),
        # its href triggers the clientside callback that measures the window
        dcc.Location(id='url'),
        # hidden; holds the JSON window size written by that clientside callback
        html.Div(id='screen_resolution', style={'display': 'none'}),
        # target of download_files
        dcc.Download(id='download'),
    ],
    # TODO style = no-display
)

# Page layout: introduction on top; image panel and application panel in a left
# column (width 7), model panel, hovered spectrum and mean-spectrum plot in a
# right column (width 4); the `meta` components go last.
app.layout = html.Div([
    dbc.Container([
        dbc.Row([
            dbc.Col(introduction),
        ], justify='evenly'),
        html.Br(),
        dbc.Row([
            dbc.Col([
                dbc.Row([
                    dbc.Col(image_panel)
                ]),
                html.Br(),
                dbc.Row([
                    application_panel
                ])
            ], width=7),
            dbc.Col([
                dbc.Row([
                    model_panel
                ]),
                html.Br(),
                dbc.Row([
                    selected_spectra
                ]),
                html.Br(),
                dbc.Row([
                    range_slider
                ]),
            ], width=4)
        ], justify="evenly",),
        dbc.Row([
            dbc.Col([meta])
        ])
    ], fluid=True)
])

# callbacks

In [None]:
def mouse_path_to_indices(path):
    """Turn a Plotly SVG path string into a list of integer ``(x, y)`` points.

    The path has the form ``"Mx0,y0Lx1,y1L...[Z]"`` as produced by the
    ``drawopenpath`` drag mode. Every vertex is rounded to the nearest integer
    (ties to even, as ``np.rint`` does) and returned as a tuple of Python ints.
    """
    vertices = []
    for segment in path.split("L"):
        # drop the move-to / close-path markers, keep "x,y"
        cleaned = segment.replace("M", "").replace("Z", "")
        vertices.append(cleaned.split(","))
    rounded = np.rint(np.array(vertices, dtype=float)).astype(int)
    return [tuple(vertex) for vertex in rounded.tolist()]

In [None]:
from PIL import Image, ImageDraw
from matplotlib import cm
from base64 import b64decode

# get screen resolution (to manually resize the hyperspectral image)
# Runs in the browser whenever `url.href` changes. It writes the window's inner
# size as a JSON string {'height': ..., 'width': ...} into the hidden
# `screen_resolution` div, which update_X_map reads to size the map figure.
app.clientside_callback(
    """
    function(href) {
        var w = window.innerWidth;
        var h = window.innerHeight;
        return JSON.stringify({'height': h, 'width': w});
    }
    """,
    Output('screen_resolution', 'children'),
    Input('url', 'href')
)


@app.callback(
    Output('manual_labels', 'data'),
    Input('manual_labels', 'data'),
    Input('mode_button', 'value'),
    Input('width', 'value'),
    Input('x_map', 'relayoutData'),
    prevent_initial_call=True,
)
def update_manual_labels(memory, mode, width, relayout):
    """Burn the most recently drawn brush stroke into the manual label map.

    Args:
        memory: current label map from the Store (shape DIM, -1 = unlabelled).
        mode: selected mode_button value. -4 (Reset) returns a fresh all -1 map,
            -3 (Zoom) paints nothing, any other value is painted as-is.
        width: brush width in pixels from the `width` input; 2 when empty/0.
        relayout: x_map relayoutData; a newly drawn path is the last entry of
            its 'shapes' list.

    Returns:
        The updated label map as an array.

    Raises:
        PreventUpdate: when the trigger is not a new stroke drawn on x_map.
    """
    if mode == -4:
        return np.zeros(DIM) - 1
    # Only a new shape on the map changes the labels. relayoutData may also come
    # from a zoom/pan, or carry an empty 'shapes' list, which has no stroke to paint.
    if ctx.triggered_id != 'x_map' or not relayout or not relayout.get('shapes') or mode < -2:
        raise PreventUpdate
    img = Image.fromarray(np.array(memory))
    draw = ImageDraw.Draw(img)
    node_coords = mouse_path_to_indices(relayout['shapes'][-1]['path'])
    # TODO bug leaves little holes
    draw.line(node_coords, fill=mode, width=int(width) if width else 2, joint='curve')
    return np.asarray(img)


# TODO delete
# Placeholder callback: fires on uploads to `load_labels` and always writes an
# empty string into the hidden `test` div. The upload itself is handled by
# upload_labels.
@app.callback(
    Output('test', 'children'),
    Input('load_labels', 'contents'),
)
def idk(inp):
    """No-op: ignore `inp` and clear the `test` div."""
    return ''


@app.callback(
    Output('retrain_btn', 'outline'),
    Input('retrain_btn', 'n_clicks'),
    Input('manual_labels', 'data'),
    Input('model_identifier', 'value'),
    prevent_initial_call=True,
)
def highlight_retrain_btn(*args, **kwargs):
    """Toggle the outline style of the "Train Model" button.

    A click on the button itself switches it to the outline look; any change to
    the manual labels or to the selected model switches it back to solid, so the
    button stands out again. The input values themselves are not used.
    """
    clicked_train = ctx.triggered_id == 'retrain_btn'
    return clicked_train


@app.callback(
    Output('download', 'data'),
    Input('save_labels', 'n_clicks'),
    Input('save_output', 'n_clicks'),
    Input('manual_labels', 'data'),
    Input('model_output', 'data'),
    prevent_initial_call=True,
)
def download_files(l_click, s_click, manual_labels, model_out):
    """Send the manual labels or the latest segmentation to the browser as JSON.

    Only the two download buttons start a download. The stores are Inputs too, so
    changes to them also fire this callback, which then does nothing.

    Raises:
        PreventUpdate: for any other trigger, or when "Download Segmentation" is
            clicked before a model has produced output (would save ``null``).
    """
    if ctx.triggered_id == 'save_labels':
        return {'content': json.dumps(manual_labels), 'filename':'manual_labels.json'}
    if ctx.triggered_id == 'save_output' and model_out is not None:
        return {'content': json.dumps(model_out), 'filename':'segmentation_mask.json'}
    raise PreventUpdate


@app.callback(
    Output('manual_labels', 'data', allow_duplicate=True),
    Input('load_labels', 'contents'),
    prevent_initial_call=True,
)
def upload_labels(upload):
    """Replace the manual labels with those from an uploaded JSON file.

    Args:
        upload: `dcc.Upload.contents`, a data URL of the form
            ``"data:<mime type>;base64,<payload>"``.

    Returns:
        The decoded JSON label map (as written by "Download Manual Labels").

    Raises:
        PreventUpdate: when the upload is empty/cleared.
    """
    if not upload:
        raise PreventUpdate
    # Only the part after the first comma is base64. Decoding the header as well
    # shifts the base64 alignment (its letters are valid base64 characters),
    # which corrupts or rejects the payload depending on the MIME type.
    payload = upload.split(',', 1)[-1]
    # utf-8-sig also accepts files saved with a byte-order mark
    return json.loads(b64decode(payload).decode('utf-8-sig'))


@app.callback(
    Output('show_output_btn', 'disabled'),
    Input('retrain_btn', 'n_clicks'),
    prevent_initial_call=True,
)
def disable_show_segmentation(click):
    """Keep "Show Segmentation" disabled until "Train Model" has been clicked.

    Returns True (disabled) while the Train button's n_clicks is still None,
    False once it has a value.
    """
    return click is None


@app.callback(
    Output('model_output', 'data'),
    Input('retrain_btn', 'n_clicks'),
    Input('manual_labels', 'data'),
    Input('model_identifier', 'value'),
)
def calculate_model_output(_, labels, model_identifier):
    """Train the selected model on the manual labels and predict every pixel.

    Only a click on "Train Model" retrains. Label or model-selection changes also
    fire this callback (they are Inputs), but are ignored.

    Args:
        _: n_clicks of the Train button (only used as the trigger).
        labels: manual label map (shape DIM), -1 for unlabelled pixels.
        model_identifier: index into `models` from the dropdown, converted with
            int() (may arrive as a string); falsy/None selects model 0.

    Returns:
        Predicted class per pixel, reshaped to DIM.
    """
    if ctx.triggered_id != 'retrain_btn':
        raise PreventUpdate

    model_index = int(model_identifier) if model_identifier else 0

    # one label per row of X_in (pixels in row-major order); -1 = unknown.
    # Wrapped supervised models drop the -1 rows before fitting; KMeans ignores y.
    y_in = np.array(labels).flatten()

    chosen_model = models[model_index]
    predictions = chosen_model.fit(X_in, y_in).predict(X_in)
    return predictions.reshape(DIM)


def _wave_window(wave_range):
    """Extract the wavelength window chosen on the mean-spectrum plot.

    Args:
        wave_range: relayoutData of the `range_slider` graph (may be None).

    Returns:
        ``(low, high)`` as floats, or None when the full spectrum should be used.
        Plotly reports an axis zoom as 'xaxis.range[0]' / 'xaxis.range[1]' and a
        range-slider drag as a single 'xaxis.range' list. Autorange/autosize
        events and unrelated relayouts (e.g. a drag-mode change) carry no window.
    """
    if not wave_range or 'xaxis.autorange' in wave_range or 'autosize' in wave_range:
        return None
    if 'xaxis.range[0]' in wave_range and 'xaxis.range[1]' in wave_range:
        return float(wave_range['xaxis.range[0]']), float(wave_range['xaxis.range[1]'])
    if 'xaxis.range' in wave_range:
        low, high = wave_range['xaxis.range']
        return float(low), float(high)
    return None


@app.callback(
    Output('x_map', 'figure'),
    Input('range_slider', 'relayoutData'),
    Input('manual_labels', 'data'),
    Input('screen_resolution', 'children'),
    Input('mode_button', 'value'),
    Input('show_output_btn', 'n_clicks'),
    Input('model_output', 'data'),
)
def update_X_map(wave_range, manual_labels, screen_resolution, mode, show_segment_btn, y):
    """Render the intensity image (or the segmentation) with manual labels on top.

    Args:
        wave_range: relayoutData of the mean-spectrum plot; selects the
            wavelength window summed into the intensity image.
        manual_labels: label map (shape DIM). Pixels >= 0 are painted in their
            class colour, -2 ("Ignore") in flat grey, -1 shows the background.
        screen_resolution: JSON string {'height', 'width'} from the clientside
            callback, used to size the square figure.
        mode: mode_button value; values < -2 select zoom, others free drawing.
        show_segment_btn: n_clicks of "Show Segmentation"; odd counts show the
            model output instead of the intensity image.
        y: model output (shape DIM), or None before any training.

    Returns:
        The plotly figure for `x_map`.

    Raises:
        PreventUpdate: until the window size has been reported.
    """
    if screen_resolution is None:
        raise PreventUpdate

    # unpack input values
    manual_labels = np.array(manual_labels)
    screen_resolution = json.loads(screen_resolution)

    # broadcast manual labels to the 4 RGBA channels produced by the colormaps
    mask = np.repeat(manual_labels[:, :, np.newaxis], 4, axis=2)
    label_colors = cm.Set1(manual_labels / NUM_CLASSES, alpha=1.) * 255

    # choose one of two main modes; without a model output there is nothing to
    # show as segmentation, so the image is shown instead
    show_segmentation = (show_segment_btn is not None and show_segment_btn % 2 == 1
                         and y is not None)
    if show_segmentation:
        y = np.array(y)
        img = np.where(mask >= 0, label_colors, cm.Set1(y / NUM_CLASSES, alpha=.8) * 255)
    else:
        # show image: total intensity per pixel inside the selected wavelength window
        window = _wave_window(wave_range)
        if window is None:
            values = X.sum(axis=2)
        else:
            low, high = window
            values = X[:, :, (calibration >= low) & (calibration <= high)].sum(axis=2)
        # an empty window (or an all-zero map) has max 0; don't divide by it
        peak = values.max() or 1.0
        img = np.where(mask >= 0, label_colors, cm.Reds(values / peak, alpha=1.) * 255)

    # "Ignore" pixels are drawn in flat grey
    img = np.where(mask == -2, 128, img)

    # generate plot; the figure is square and sized to fit the browser window
    side = int(min(screen_resolution['height'] * .9, screen_resolution['width'] * .7))
    fig = px.imshow(img=img, labels={})
    # NOTE(review): hover labels are suppressed here, but update_point_plot reads
    # x_map hoverData — confirm hover events still fire with these settings.
    fig.update_traces(
        hovertemplate='<',
        hoverinfo='skip',
    )
    fig.update_layout(
        template='plotly_white',
        plot_bgcolor= 'rgba(0, 0, 0, 0)',
        paper_bgcolor= 'rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=0, pad=0),
        dragmode='zoom' if mode < -2 else 'drawopenpath',
        newshape=dict(opacity=0),  # TODO shapes are currently just hidden but not deleted
        # keeps whatever axis ranges px.imshow chose for this image
        xaxis=dict(visible=False, range=fig['layout']['xaxis']['range'] if fig else None),
        yaxis=dict(visible=False, range=fig['layout']['yaxis']['range'] if fig else None),
        width=side,
        height=side,
        uirevision='None',
        shapes=[],  # TODO this does not remove the shapes!
    )
    fig.update_shapes(editable=False)

    return fig


@app.callback(
    Output('point_plot', 'figure'),
    Input('x_map', 'hoverData'),
)
def update_point_plot(hover):
    """Plot the spectrum of the pixel under the mouse on the hyperspectral map.

    Plotly reports a hovered image pixel as x = column, y = row, while `X` is laid
    out as X[row, column, wavelength] (the same array px.imshow displays in
    update_X_map). The lookup is therefore X[y, x]. Before the first hover event
    pixel (0, 0) is shown.
    """
    if hover is not None:
        point = hover['points'][0]
        col, row = int(point['x']), int(point['y'])
    else:
        col, row = 0, 0
    fig = plot_spectra([X[row, col, :]], calibration=calibration)
    fig.update_layout(
        template='plotly_white',
        plot_bgcolor= 'rgba(0, 0, 0, 0)',
        paper_bgcolor= 'rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=0,),
    )
    return fig

# run

In [None]:
# Start the Dash server. __name__ is "__main__" inside a Jupyter kernel as well,
# so running this cell launches the app; debug=True enables Dash's dev tools.
if __name__ == "__main__":
    app.run_server(debug=True)