# In [1]:
from __future__ import annotations

from pathlib import Path
import numpy as np
import pandas as pd

from flamekit.io_fronts import Case, load_fronts

# Dash
from dash import Dash, dcc, html, Input, Output, State, ctx
import plotly.express as px

# =========================
# USER SETTINGS
# =========================
# Which snapshot / front to load (all forwarded to Case below).
TIME_STEP: int = 210
PHI: float = 0.40
LAT_SIZE: str = "100"  # kept as a string; passed straight to Case
ISOLEVELS: list[float] = [4.9]  # only ISOLEVELS[0] is plotted

# Where the iso-contour data lives.
BASE_DIR: Path = Path("../data/isocontours")
POST: bool = True

# Column names expected in the loaded front table.
TARGET: str = "DW_FDS"
XCOL: str = "x"
YCOL: str = "y"
DIFFVEL: str = "FDS_diff_velocity_term"

# Random downsampling cap so the browser stays responsive
# (set MAX_POINTS to None to keep every point); seed makes it reproducible.
MAX_POINTS = 150_000
RANDOM_SEED: int = 0

# =========================
# LOAD DATA
# =========================
# Locate the snapshot described by the user settings.
case = Case(
    base_dir=BASE_DIR,
    phi=PHI,
    lat_size=LAT_SIZE,
    time_step=TIME_STEP,
    post=POST,
)

# Only the first requested iso-level is plotted.
fronts = load_fronts(case, ISOLEVELS)
c_val = float(ISOLEVELS[0])
front = fronts[c_val].copy()

# Fail early if the front table lacks any column the plots rely on.
need = [XCOL, YCOL, TARGET, DIFFVEL]
missing = [col for col in need if col not in front.columns]
if missing:
    raise ValueError(f"Missing columns: {missing}")

# Keep only rows whose needed columns are all finite (+/-inf is treated as NaN).
df = front[need].replace([np.inf, -np.inf], np.nan)
df = df.dropna(axis=0, how="any")
df = df.reset_index(drop=True)

# Reproducible random downsample keeps the browser responsive.
if MAX_POINTS is not None and len(df) > MAX_POINTS:
    df = df.sample(n=MAX_POINTS, random_state=RANDOM_SEED)
    df = df.reset_index(drop=True)

# IMPORTANT: stable point id used to link selections between the two plots.
# After the index reset above it equals each row's position.
df = df.assign(pid=np.arange(len(df), dtype=int))

# =========================
# FIGURE BUILDERS
# =========================
def make_fig_xy(selected_pids: set[int] | None):
    """Build the spatial x-y scatter of the front, coloured by ``TARGET``.

    Args:
        selected_pids: values of ``df["pid"]`` to highlight. ``None`` or an
            empty set draws every point with uniform size and opacity.

    Returns:
        A Plotly Express figure in lasso drag mode with equal x/y scaling.
        Each point's ``customdata[0]`` is its ``pid`` so selection events can
        be mapped back to rows of ``df``.
    """
    fig = px.scatter(
        df,
        x=XCOL, y=YCOL,
        color=TARGET,
        color_continuous_scale="Viridis",
        render_mode="webgl",  # fast
        hover_data={"pid": True, TARGET: ":.4e", XCOL: ":.4f", YCOL: ":.4f"},
        # Put pid first in customdata so lasso/box selections carry the stable id.
        custom_data=["pid"],
        title=f"xy coloured by {TARGET} (t={TIME_STEP}, T={c_val:.2f})",
    )
    fig.update_layout(
        dragmode="lasso",
        margin=dict(l=50, r=20, t=60, b=45),
        height=560,
    )
    fig.update_yaxes(scaleanchor="x", scaleratio=1)

    if selected_pids:
        # Boolean array aligned with df rows. Assumes the figure has a single
        # trace whose points follow df row order (continuous colour) — confirm
        # if a categorical colour is ever used.
        is_selected = df["pid"].isin(selected_pids).to_numpy()
        # Fade unselected points. Per-point opacity must go on
        # ``marker.opacity``: the trace-level ``opacity`` property accepts a
        # single number only, so an array there raises ValueError.
        fig.update_traces(
            selector=dict(mode="markers"),
            marker=dict(
                size=np.where(is_selected, 7, 4),
                opacity=np.where(is_selected, 0.95, 0.10),
            ),
        )
    else:
        fig.update_traces(marker=dict(size=4), opacity=0.85)

    return fig


def make_fig_dv(selected_pids: set[int] | None):
    """Build the ``DIFFVEL`` vs ``TARGET`` scatter, coloured by ``TARGET``.

    Args:
        selected_pids: values of ``df["pid"]`` to highlight. ``None`` or an
            empty set draws every point with uniform size and opacity.

    Returns:
        A Plotly Express figure in lasso drag mode. Its colour bar is hidden
        so only the x-y plot shows one. Each point's ``customdata[0]`` is its
        ``pid`` so selection events can be mapped back to rows of ``df``.
    """
    fig = px.scatter(
        df,
        x=DIFFVEL, y=TARGET,
        color=TARGET,
        color_continuous_scale="Viridis",
        render_mode="webgl",
        hover_data={"pid": True, TARGET: ":.4e", DIFFVEL: ":.4e"},
        # Put pid first in customdata so lasso/box selections carry the stable id.
        custom_data=["pid"],
        title=f"{DIFFVEL} vs {TARGET} (t={TIME_STEP}, T={c_val:.2f})",
    )
    fig.update_layout(
        dragmode="lasso",
        margin=dict(l=35, r=40, t=60, b=45),
        height=560,
        coloraxis_showscale=False,  # keep one colorbar on the left
    )

    if selected_pids:
        # Boolean array aligned with df rows. Assumes a single trace whose
        # points follow df row order (continuous colour) — confirm if changed.
        is_selected = df["pid"].isin(selected_pids).to_numpy()
        # Per-point opacity must go on ``marker.opacity``; the trace-level
        # ``opacity`` property accepts a single number only.
        fig.update_traces(
            selector=dict(mode="markers"),
            marker=dict(
                size=np.where(is_selected, 7, 4),
                opacity=np.where(is_selected, 0.95, 0.10),
            ),
        )
    else:
        fig.update_traces(marker=dict(size=4), opacity=0.85)

    return fig


def extract_selected_pids(selectedData) -> set[int]:
    """Return the stable point ids (``pid``) in a Dash ``selectedData`` payload.

    Args:
        selectedData: the ``selectedData`` prop of a ``dcc.Graph``. It is
            falsy when nothing is selected, otherwise a dict with a
            ``"points"`` list.

    Returns:
        A set of ints, one per selected point. A point's ``customdata``
        (its first entry when it is a list) is used as the pid when present.
        Otherwise ``pointIndex`` is used, which equals the pid here because
        ``pid`` is the row position in ``df``. Points carrying neither are
        skipped.
    """
    if not selectedData or "points" not in selectedData:
        return set()

    pids: set[int] = set()
    for point in selectedData["points"]:
        custom = point.get("customdata")
        # customdata arrives as one row of the trace's customdata (a list);
        # pid is placed first in it by the figure builders.
        if isinstance(custom, (list, tuple)):
            custom = custom[0] if custom else None
        if custom is not None:
            pids.add(int(custom))
        elif "pointIndex" in point:
            # Fallback: trace-local index, which matches df row order.
            pids.add(int(point["pointIndex"]))
    return pids


# =========================
# DASH APP
# =========================
app = Dash(__name__)


def _build_layout():
    """Assemble the page: a toolbar row, the selection store, and two graphs.

    The toolbar holds the title, the "Clear selection" button (``btn-clear``)
    and the selection counter (``sel-count``). ``store-selected`` starts as an
    empty list. Both graphs start with no highlighted selection.
    """

    def graph(graph_id, figure):
        # Each graph takes an equal share of the flex row.
        return dcc.Graph(
            id=graph_id,
            figure=figure,
            clear_on_unhover=True,
            style={"flex": "1"},
            config={"displayModeBar": True},
        )

    title = html.H3("Linked selection (PyCharm-safe)", style={"margin": "8px 0"})
    clear_button = html.Button("Clear selection", id="btn-clear", n_clicks=0)
    counter = html.Div(id="sel-count", style={"marginLeft": "10px"})
    toolbar = html.Div(
        style={"display": "flex", "alignItems": "center", "gap": "12px"},
        children=[title, clear_button, counter],
    )

    store = dcc.Store(id="store-selected", data=[])

    xy_graph = graph("g-xy", make_fig_xy(None))
    dv_graph = graph("g-dv", make_fig_dv(None))
    graphs_row = html.Div(
        style={"display": "flex", "gap": "10px"},
        children=[xy_graph, dv_graph],
    )

    return html.Div(
        style={"fontFamily": "Arial", "padding": "8px"},
        children=[toolbar, store, graphs_row],
    )


app.layout = _build_layout()

@app.callback(
    Output("store-selected", "data"),
    Input("g-xy", "selectedData"),
    Input("g-dv", "selectedData"),
    Input("btn-clear", "n_clicks"),
    State("store-selected", "data"),
)
def update_selection(sel_xy, sel_dv, n_clear, current):
    """Store the pids selected in whichever input triggered the callback.

    Args:
        sel_xy: ``selectedData`` of the x-y graph.
        sel_dv: ``selectedData`` of the DIFFVEL-vs-TARGET graph.
        n_clear: click count of the clear button (only the trigger matters).
        current: the list currently held in ``store-selected``.

    Returns:
        ``[]`` for a clear-button press, the sorted pids of the triggering
        graph's selection, or ``current`` unchanged for any other trigger.
    """
    source = ctx.triggered_id
    if source == "btn-clear":
        return []

    # A new selection on either graph replaces the stored one (no union).
    for graph_id, selection in (("g-xy", sel_xy), ("g-dv", sel_dv)):
        if source == graph_id:
            # sorted() yields a JSON-friendly list in a stable order.
            return sorted(extract_selected_pids(selection))

    return current


@app.callback(
    Output("g-xy", "figure"),
    Output("g-dv", "figure"),
    Output("sel-count", "children"),
    Input("store-selected", "data"),
)
def redraw_figs(selected_list):
    """Rebuild both figures and the counter text when the stored selection changes.

    Args:
        selected_list: pid list from ``store-selected`` (may be empty or None).

    Returns:
        Tuple of (x-y figure, DIFFVEL figure, "Selected points: N" text).
    """
    selected = set(selected_list or ())
    # The builders treat None as "no selection", so pass None for an empty set.
    highlight = selected or None
    xy_fig = make_fig_xy(highlight)
    dv_fig = make_fig_dv(highlight)
    return xy_fig, dv_fig, f"Selected points: {len(selected)}"


if __name__ == "__main__":
    # Serve on localhost only (http://127.0.0.1:8050); debug mode off.
    app.run(host="127.0.0.1", port=8050, debug=False)
