In [1]:
import json

import dash
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dash import dcc, html, Input, Output

In [2]:
# Dataset to visualize (alternative: "Mouse Bladder").
dataset_name = "Human Pancreas"
base_dir = f"dash_data/{dataset_name}"

# Per-point labels for each method plus the UMAP1/UMAP2 columns used for the scatter.
df = pd.read_csv(f"{base_dir}/umap_labels.csv")
# UMAP coordinates indexed by the constraint pairs when drawing segments.
xy = np.load(f"{base_dir}/umap_coords.npy")

# Constraint pairs: each entry is an (i, j) pair of row indices into `xy`.
with open(f"{base_dir}/mustlink.json", "r", encoding="utf-8") as f:
    ml_draw = json.load(f)

with open(f"{base_dir}/cannotlink.json", "r", encoding="utf-8") as f:
    cl_draw = json.load(f)

# Points are plotted from df["UMAP1"/"UMAP2"] while constraint segments index rows of `xy`;
# fail fast if the two files cannot describe the same set of points.
assert {"UMAP1", "UMAP2"} <= set(df.columns), "umap_labels.csv must contain UMAP1 and UMAP2 columns"
assert xy.ndim == 2 and xy.shape[1] >= 2, f"expected (n_points, 2) UMAP coordinates, got shape {xy.shape}"
assert len(xy) == len(df), f"umap_coords.npy has {len(xy)} rows but umap_labels.csv has {len(df)}"

In [3]:
# ------------------ Dash App ------------------
app = dash.Dash(__name__)
server = app.server  # underlying Flask server, exposed for WSGI hosting

# Dropdown label -> column of `df` holding that method's per-point labels.
# Insertion order is the order shown in the dropdown.
method_to_col = {
    "Ground Truth": "GT",
    "KMeans": "KMeans",
    "DEC": "DEC",
    "scDCC": "scDCC",
}
# Derived from the mapping so the dropdown options and column lookup cannot drift apart.
methods = list(method_to_col)

# Fail at load time rather than inside the callback if a method's label column is missing.
_missing_cols = [c for c in method_to_col.values() if c not in df.columns]
assert not _missing_cols, f"umap_labels.csv is missing label columns: {_missing_cols}"

def colors_for_labels(labels, palette=None):
    """Map every label to a categorical color, one color per distinct label.

    Labels are compared as strings, so e.g. ``1`` and ``"1"`` share a color.
    Colors are assigned in order of first appearance and cycle through
    ``palette`` when there are more distinct labels than colors.

    Parameters
    ----------
    labels : pandas.Series or numpy.ndarray
        Per-point labels; anything supporting ``.astype(str)``.
    palette : sequence of str, optional
        Colors to cycle through. Defaults to Plotly's qualitative
        Alphabet + Dark24 + Set3 palettes concatenated.

    Returns
    -------
    list of str
        One color per element of ``labels``, in the same order.

    Raises
    ------
    ValueError
        If ``palette`` is empty.
    """
    if palette is None:
        qualitative = px.colors.qualitative
        palette = qualitative.Alphabet + qualitative.Dark24 + qualitative.Set3
    if len(palette) == 0:
        raise ValueError("palette must contain at least one color")

    label_strs = labels.astype(str)
    lut = {lab: palette[i % len(palette)] for i, lab in enumerate(pd.unique(label_strs))}
    return [lut[lab] for lab in label_strs]

def segs(pairs, coords=None):
    """Flatten index pairs into Plotly line-segment coordinate lists.

    Each pair ``(i, j)`` appends ``[x_i, x_j, None]`` to ``xs`` and
    ``[y_i, y_j, None]`` to ``ys``. The ``None`` separators let a single
    ``mode="lines"`` trace draw many disconnected segments.

    Parameters
    ----------
    pairs : iterable of (int, int)
        Row indices into ``coords``.
    coords : array-like of shape (n_points, >=2), optional
        Point coordinates; column 0 is x, column 1 is y. Defaults to the
        module-level ``xy`` array.

    Returns
    -------
    (list, list)
        ``(xs, ys)`` ready to pass to ``go.Scatter(x=xs, y=ys)``.

    Raises
    ------
    IndexError
        If any index lies outside ``[0, n_points)``. Negative indices are
        rejected explicitly: NumPy would otherwise wrap them to the end of
        the array and silently draw a segment to the wrong point.
    """
    pts = np.asarray(xy if coords is None else coords)
    n_points = len(pts)
    xs, ys = [], []
    for i, j in pairs:
        for idx in (i, j):
            if not 0 <= idx < n_points:
                raise IndexError(f"constraint index {idx} out of range for {n_points} points")
        xs += [pts[i, 0], pts[j, 0], None]
        ys += [pts[i, 1], pts[j, 1], None]
    return xs, ys

# Page layout: title, controls row (method dropdown + constraint checkboxes), and the UMAP graph.
# Style dict keys are camelCase, as Dash/React expects for inline styles.
app.layout = html.Div([
    html.H2(f"UMAP Viewer: {dataset_name}", style={'textAlign': 'center'}),
    html.Div([
        html.Label("Select Method:"),
        dcc.Dropdown(
            id='method-dropdown',
            options=[{'label': m, 'value': m} for m in methods],
            value='Ground Truth',
            # A cleared Dropdown reports value=None; the plot always needs a concrete method.
            clearable=False,
            style={'width': '250px'}
        ),
        html.Label("Show Constraints:", style={'marginLeft': '30px'}),
        dcc.Checklist(
            id='constraint-toggle',
            options=[
                {'label': 'Must-Link', 'value': 'ML'},
                {'label': 'Cannot-Link', 'value': 'CL'}
            ],
            value=[],
            inline=True
        )
    ], style={'display': 'flex', 'justifyContent': 'center', 'alignItems': 'center'}),

    dcc.Graph(id='umap-plot', config={'displayModeBar': True}, style={'height': '90vh'})
])

@app.callback(
    Output('umap-plot', 'figure'),
    Input('method-dropdown', 'value'),
    Input('constraint-toggle', 'value')
)
def update_plot(method, constraint_opts):
    """Build the UMAP figure for the selected method and constraint overlays.

    Parameters
    ----------
    method : str or None
        Dropdown selection; expected to be a key of ``method_to_col``.
        ``None`` or an unknown value falls back to ``methods[0]`` instead of
        raising ``KeyError``.
    constraint_opts : list of str or None
        Checked overlays: ``'ML'`` draws must-link segments, ``'CL'`` draws
        cannot-link segments. ``None`` is treated as no overlays.

    Returns
    -------
    plotly.graph_objects.Figure
        Scatter of all points colored by the method's labels, plus one line
        trace per selected constraint type.
    """
    # Guard against a cleared/invalid dropdown value rather than crashing the callback.
    if method not in method_to_col:
        method = methods[0]
    shown = set(constraint_opts or [])
    col = method_to_col[method]
    fig = go.Figure()

    # NOTE(review): points come from df["UMAP1"/"UMAP2"] while segments come from `xy`
    # (via segs); these are assumed to hold the same coordinates row-for-row — confirm.
    fig.add_trace(go.Scattergl(
        x=df["UMAP1"], y=df["UMAP2"],
        mode="markers",
        marker=dict(size=4, opacity=0.85, color=colors_for_labels(df[col])),
        text=df[col],
        hovertemplate=(
            f"Method: {method}<br>Label: %{{text}}<br>UMAP1: %{{x:.2f}}, UMAP2: %{{y:.2f}}<extra></extra>"
        ),
        name=method,
    ))

    # (checkbox value, index pairs, line color, legend name); traces are added in this order.
    constraint_layers = (
        ('ML', ml_draw, "crimson", "Must-Link"),
        ('CL', cl_draw, "royalblue", "Cannot-Link"),
    )
    for key, pairs, color, legend_name in constraint_layers:
        if key not in shown:
            continue
        xs, ys = segs(pairs)
        fig.add_trace(go.Scatter(
            x=xs, y=ys, mode="lines",
            line=dict(width=1.2, color=color),
            name=legend_name
        ))

    fig.update_layout(
        title=f"UMAP — {method}",
        xaxis_title="UMAP-1",
        yaxis_title="UMAP-2",
        template="plotly_white",
        autosize=True,
        margin=dict(l=10, r=10, t=50, b=10),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="left", x=0)
    )
    return fig

if __name__ == '__main__':
    # In a notebook kernel __name__ is '__main__', so running this cell starts the dev server.
    # Browse to http://127.0.0.1:8051/
    run_options = {"port": 8051, "debug": True}
    app.run(**run_options)