In [1]:
"""
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
%cd drive/Shareddrives/LIBS/LIBS_library
"""

import numpy as np

# Data: load both scans and match their wavelengths

In [2]:
from src.visualization import IndexType

# Grid shape shared by every map plot and by the click -> sample-id lookup.
# NOTE(review): assumed to be (rows, cols) of the scan grid — confirm against src.visualization.
DIM = (266, 500)
# Point ordering used to place each spectrum on the map grid; presumably the snake
# (alternating-direction rows) acquisition order — TODO confirm.
INDEX = IndexType.HORIZONTAL_SNAKE

In [3]:
from src.preprocessing import match_wavelengths

# Load the wavelength axes and spectra of the two scans. Passing the path directly lets
# np.load open and close the file itself. The previous `np.load(open(path, 'rb'))` form
# left every file handle open.
s1_wavelengths = np.load('data/s1_wavelengths.npy')
s2_wavelengths = np.load('data/s2_wavelengths.npy')
s1_spectra = np.load('data/s1.npy')
s2_spectra = np.load('data/s2.npy')

# `left`/`right` are the two scans after matching; `calibration` is the shared label axis
# that every later plot and crop uses for both of them.
left, right, calibration = match_wavelengths(s1_spectra, s2_spectra, s1_wavelengths, s2_wavelengths)
# Only the matched copies are used below, so release the raw arrays.
del s1_spectra, s2_spectra, s1_wavelengths, s2_wavelengths

# Layout: maps, controls and spectrum viewer

In [4]:
import dash
import dash_bootstrap_components as dbc
from dash import html
from dash import dcc
from jupyter_dash import JupyterDash
from src.rowwise_metrics import rowwise_cosine, rowwise_euclid, rowwise_kl_divergence

# Display name -> comparison function for the metric dropdown. Each function is called as
# metric(cropped_left, cropped_right). The dropdown lists the names in insertion order and
# defaults to the first entry.
METRICS = dict([
    ('cosine distance', rowwise_cosine),
    ('euclidean distance', rowwise_euclid),
    ('kullback-leibler divergence', rowwise_kl_divergence),
])


def make_map(name, title):
    """Build a card holding a titled graph with the Plotly mode bar hidden.

    Args:
        name: component id assigned to the inner ``dcc.Graph`` (the callback target).
        title: heading text shown above the graph.

    Returns:
        An ``html.Div`` wrapping a single ``dbc.Card``.
    """
    heading = html.H4(title, className="card-title")
    graph = dcc.Graph(id=name, config={'displayModeBar': False})
    card = dbc.Card(dbc.CardBody([heading, graph]))
    return html.Div([card])


# TODO more ranges
def make_slider():
    """Build the wavelength range slider card.

    The slider spans the module-level ``calibration`` from its first to its last entry,
    starts fully open, moves in steps of 0.2, and always shows its tooltip below the handles.

    Returns:
        An ``html.Div`` wrapping a card that holds the ``wavelength_slider`` RangeSlider.
    """
    lowest, highest = calibration[0], calibration[-1]
    slider = dcc.RangeSlider(
        id='wavelength_slider',
        min=lowest,
        max=highest,
        value=[lowest, highest],
        step=.2,
        marks=None,
        tooltip={"placement": "bottom", "always_visible": True},
    )
    return html.Div([dbc.Card(dbc.CardBody([slider]))])


def make_spectra():
    """Build the card holding the empty ``spectra`` graph.

    The ``update_sample`` callback fills in this graph's figure.

    Returns:
        An ``html.Div`` wrapping a single ``dbc.Card``.
    """
    body = dbc.CardBody([dcc.Graph(id='spectra')])
    return html.Div([dbc.Card(body)])


dbc_css = "https://cdn.jsdelivr.net/gh/AnnMarieW/dash-bootstrap-templates/dbc.min.css"
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.SLATE, dbc_css])
app.title = 'LIBS Transfer library'
app.layout = html.Div([
    dbc.Card(
        dbc.CardBody([
            dbc.Row([
                dbc.Col([
                    make_map('left_map', 'Left/Right map') 
                ], width=6),
                dbc.Col([
                    make_map('difference_map', 'Point-wise difference') 
                ], width=6),
            ], align='center'),
            html.Br(),

            dbc.Row([
                dbc.Col([
                    dbc.Card(
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    html.Div(
                                        [
                                            dbc.RadioItems(
                                                options=[
                                                    {"label": "left map", "value": 1},
                                                    {"label": "right map", "value": 2},
                                                ],
                                                value=1,
                                                id="map_toggle",
                                                inline=True,
                                            ),
                                        ]
                                    )
                                ])
                            ])
                        ])
                    )
                ], width=2),
                dbc.Col([
                    make_slider()
                ], width=8),
                dbc.Col([
                    dbc.Card(
                        dbc.CardBody([
                            dbc.Row([
                                dbc.Col([
                                    dcc.Dropdown(
                                        options=list(METRICS.keys()),
                                        value=list(METRICS.keys())[0],
                                        id='metric_dropdown',
                                        className="dbc",
                                    )
                                ])
                            ])
                        ])
                    )
                ], width=2),
            ], align='center'),     
            html.Br(),

            dbc.Row([
                dbc.Col([
                    make_spectra()
                ], width=12),
            ], align='center', className="h-40",),    
        ]), color = 'dark',
    )
])

# Callbacks: keep the maps and the spectrum plot in sync with the controls

In [5]:
import plotly.express as px
from src.visualization import plot_map, intensity_map
from src.preprocessing import LabelCropp
from dash import Input, Output


@app.callback(
    Output('left_map', 'figure'),
    Input("wavelength_slider", "value"),
    Input("map_toggle", "value"),
)
def update_left_map(slider, map_num):
    """Render the intensity map of the selected scan over the slider's wavelength window.

    Args:
        slider: ``[start, end]`` value of the wavelength RangeSlider.
        map_num: value of the ``map_toggle`` radio. 1 selects ``left``; any other value
            selects ``right``.

    Returns:
        The figure built by ``intensity_map``, restyled with a transparent dark theme.
    """
    scan = left if map_num == 1 else right
    fig = intensity_map(
        scan,
        dim=DIM,
        index_type=INDEX,
        start=slider[0],
        end=slider[1],
        calibration=calibration,
        suppress_outliers=True,
        colorscale=px.colors.sequential.haline,
    )
    transparent = 'rgba(0, 0, 0, 0)'
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor=transparent,
        paper_bgcolor=transparent,
        # Lock the aspect ratio so map pixels render square.
        yaxis=dict(scaleanchor='x'),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig


@app.callback(
    Output('difference_map', 'figure'),
    Input("wavelength_slider", "value"),
    Input("metric_dropdown", "value"),
)
def update_right_map(slider, metric_key):
    """Map the point-wise distance between the left and right scans over a wavelength window.

    Args:
        slider: ``[start, end]`` value of the wavelength RangeSlider. Both scans are cropped
            to this range of ``calibration`` labels before they are compared.
        metric_key: key into ``METRICS``, chosen in the metric dropdown.

    Returns:
        The figure built by ``plot_map``, restyled with a transparent dark theme.
    """
    cropper = LabelCropp(label_from=slider[0], label_to=slider[1], labels=calibration)
    # fit_transform is applied to each scan in turn, left first, then right.
    left_window, right_window = [cropper.fit_transform(scan) for scan in (left, right)]
    distances = METRICS[metric_key](left_window, right_window)
    fig = plot_map(
        distances,
        dim=DIM,
        index_type=INDEX,
        colorscale=px.colors.sequential.deep,
    )
    transparent = 'rgba(0, 0, 0, 0)'
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor=transparent,
        paper_bgcolor=transparent,
        # Lock the aspect ratio so map pixels render square.
        yaxis=dict(scaleanchor='x'),
        margin=dict(l=0, r=0, b=0, t=0),
    )
    return fig

In [6]:
from src.visualization import id_from_snake_index, plot_spectra
from dash import callback_context

@app.callback(
    Output('spectra', 'figure'),
    Input('left_map', 'clickData'),
    Input('difference_map', 'clickData'),
)
def update_sample(left_click, right_click):
    """Plot the left and right spectra of the most recently clicked map point.

    Args:
        left_click: clickData of the ``left_map`` graph, or None if it has not been clicked.
        right_click: clickData of the ``difference_map`` graph, or None if it has not been clicked.

    Returns:
        A Plotly figure with both spectra over ``calibration``, styled with a transparent
        dark theme.

    Behavior:
        The point defaults to (x=0, y=0) until either map has been clicked. If the
        triggering input carries no click data, or the trigger is not one of the two maps,
        the click data of the other map is used instead. The earlier version raised
        TypeError or UnboundLocalError in those cases.
    """
    triggered = callback_context.triggered
    trigger = triggered[0]['prop_id'].split('.')[0] if triggered else ''
    clicks_by_source = {'left_map': left_click, 'difference_map': right_click}

    click = clicks_by_source.get(trigger)
    if click is None:
        # No usable data from the trigger (e.g. the initial call): fall back to whichever
        # map has click data, preferring the difference map.
        click = right_click if right_click is not None else left_click

    if click is None:
        x, y = 0, 0
    else:
        point = click['points'][0]
        x, y = point['x'], point['y']

    # NOTE(review): this always uses snake indexing, independent of INDEX. Revisit if
    # INDEX is ever changed away from a snake layout.
    sample_id = id_from_snake_index(x, y, DIM)

    fig = plot_spectra(
        np.vstack((left[sample_id], right[sample_id])),
        calibration=calibration,
        labels=['left spectrum', 'right spectrum'],
        colormap=px.colors.qualitative.Set1,
        title=f"Sampled spectra for id={sample_id} (x={x}, y={y})",
    )
    fig.update_layout(
        template='plotly_dark',
        plot_bgcolor='rgba(0, 0, 0, 0)',
        paper_bgcolor='rgba(0, 0, 0, 0)',
        margin=dict(l=0, r=0, b=0, t=50),
    )
    return fig

In [7]:
if __name__ == "__main__":
    app.run_server(debug=True)

Dash app running on http://127.0.0.1:8050/
