# In [1]:
import dash
from dash import dcc, html, Input, Output
import pandas as pd
import geopandas as gpd
import plotly.express as px
import plotly.graph_objects as go
from flask import Flask
import glob
import os

# === Setup ===
# The Flask instance is created explicitly, kept in `server`, and handed to Dash.
server = Flask(__name__)
app = dash.Dash(__name__, server=server, suppress_callback_exceptions=True)
app.title = "Wasting Prediction Dashboard"

# === Load datasets ===
# Trend table read by every callback (per Ward and time_period: observed
# prevalence, 1-/3-month predictions, CI bounds, alert flags). Dates are parsed
# and ward names stripped so they join cleanly with the ward geometries.
trend_df = pd.read_csv("data/clean_trend_wide_with_CI.csv")
trend_df = trend_df.assign(
    time_period=pd.to_datetime(trend_df["time_period"]),
    Ward=trend_df["Ward"].str.strip(),
)

# Ward polygons; names stripped to match trend_df["Ward"]. Each geometry is also
# kept as a plain GeoJSON-style mapping for building FeatureCollections later.
gdf = gpd.read_file("data/Kenya_wards_with_counties.geojson")
gdf["Ward"] = gdf["Ward"].str.strip()
gdf["geometry_json"] = [shape.__geo_interface__ for shape in gdf["geometry"]]

# County boundaries, reprojected to the ward layer's CRS, drawn as map outlines.
# NOTE(review): the maps treat coordinates as lon/lat, so gdf is presumably
# already in WGS84 (EPSG:4326) — confirm.
county_shapes = gpd.read_file("data/ken_admbnda_adm1_iebc_20191031.shp")
counties = county_shapes.to_crs(gdf.crs)
del county_shapes

# Dropdown choices: every month present in the data as "YYYY-MM" strings
# (oldest first), and every county name (alphabetical).
available_months = [
    str(period)
    for period in sorted(trend_df["time_period"].dropna().dt.to_period("M").unique())
]
county_list = sorted(trend_df["County"].dropna().unique())

# === Layout ===
def _small_text_style(**extra):
    """Return the 12px caption style used for map explanations, plus *extra* keys."""
    return {"fontSize": "12px", **extra}


def _month_dropdown(component_id, style):
    """Build a month picker over every available month, latest one preselected."""
    return dcc.Dropdown(
        id=component_id,
        options=[{"label": m, "value": m} for m in available_months],
        value=available_months[-1],
        style=style,
    )


# Map tab: reference-month picker that drives both maps.
_map_month_picker = html.Div(
    [
        html.Label("Reference Date (Month):"),
        _month_dropdown("month-select", {"width": "300px"}),
    ],
    style={"margin": "20px"},
)

# Map tab: caption explaining how wards are flagged on the alert map.
_alert_criteria = html.Div([
    html.P(
        "Wards are flagged for alert based on the following criteria:",
        style=_small_text_style(marginBottom="4px"),
    ),
    html.Ul([
        html.Li(
            "Observed prevalence is ≥ 10%, and both the 3-month predicted trend and the past 2-month observed trend are increasing or stable.",
            style=_small_text_style(),
        ),
        html.Li(
            "Or, observed prevalence is ≥ 15% regardless of trends.",
            style=_small_text_style(),
        ),
    ]),
    html.P(
        "Hover over a ward to see predicted prevalence and confidence intervals.",
        style=_small_text_style(marginTop="6px"),
    ),
])

# Map tab: the two side-by-side maps.
_alert_map_panel = html.Div(
    [html.H4("Alert Map (3-Month Horizon)"), _alert_criteria, dcc.Graph(id="trend-map")],
    style={"width": "49%", "display": "inline-block"},
)
_prevalence_map_panel = html.Div(
    [
        html.H4("Observed Prevalence Map"),
        html.Div("Observed prevalence at selected date.", style=_small_text_style()),
        dcc.Graph(id="prevalence-map"),
    ],
    style={"width": "49%", "display": "inline-block", "marginLeft": "1%"},
)

# Map tab: county -> ward pickers and the single-ward time series.
_ward_timeseries_section = html.Div(
    [
        html.H2("Observed vs Predicted Time Series", style={"textAlign": "center"}),
        html.Label("Select County:"),
        dcc.Dropdown(
            id="county-select",
            options=[{"label": c, "value": c} for c in county_list],
            value=county_list[0],
            style={"width": "300px"},
        ),
        html.Label("Select Ward:"),
        dcc.Dropdown(id="ward-select", options=[], value=None, style={"width": "300px"}),
        dcc.Graph(id="ward-timeseries", config={"displayModeBar": True}),
    ],
    style={"margin": "20px"},
)

_map_tab = dcc.Tab(
    label="Map & Single Ward Time Series",
    value="map-tab",
    children=[
        _map_month_picker,
        html.Div([_alert_map_panel, _prevalence_map_panel]),
        html.Hr(),
        _ward_timeseries_section,
    ],
)

# Alerts tab: its own month picker plus a container filled by a callback.
_alerts_tab = dcc.Tab(
    label="All Alert Wards Time Series",
    value="alerts-tab",
    children=[
        html.Div([
            html.Label("Reference Date (Month):"),
            _month_dropdown("month-select-alerts", {"width": "300px", "margin": "20px"}),
            html.Div(id="alert-timeseries-plots"),
        ])
    ],
)

app.layout = html.Div([
    html.H1("3-Month Wasting Prediction Alert Dashboard", style={"textAlign": "center"}),
    dcc.Tabs(id="main-tabs", value="map-tab", children=[_map_tab, _alerts_tab]),
])


# === Callbacks ===

@app.callback(
    Output("ward-select", "options"),
    Output("ward-select", "value"),
    Input("county-select", "value")
)
def update_ward_dropdown(county):
    """Populate the ward dropdown for ``county`` and preselect its first ward.

    Returns an ``(options, value)`` pair: one option per distinct, non-null
    ward in the county (alphabetical), and the first ward name — or ``None``
    when the county has no wards.
    """
    wards_in_county = trend_df.loc[trend_df["County"] == county, "Ward"]
    ward_names = sorted(wards_in_county.dropna().unique())
    options = [{"label": name, "value": name} for name in ward_names]
    first_ward = ward_names[0] if ward_names else None
    return options, first_ward



@app.callback(
    Output("trend-map", "figure"),
    Input("month-select", "value")
)
def update_trend_map(month_str):
    """Render the 3-month alert choropleth for the selected reference month.

    Wards are coloured by ``alert_flag``. Wards with no row for the month
    count as not alerting. Each ward carries a hover card showing county,
    predicted 3-month trend, predicted value, its 95% CI and, when longer
    than one month, the consecutive-alert streak. County outlines are drawn
    on top.

    Parameters
    ----------
    month_str : str or None
        Reference month as a ``"YYYY-MM"`` string from ``month-select``.
        ``None`` (dropdown cleared) leaves the current figure untouched.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    from dash.exceptions import PreventUpdate

    if not month_str:
        # The dropdown is clearable; with no month there is nothing to draw.
        raise PreventUpdate

    month = pd.Period(month_str).to_timestamp()
    columns = [
        "Ward", "County", "predicted_trend_CI_3mo", "predicted_value_3mo",
        "lower_bound_3mo", "upper_bound_3mo", "alert_flag", "consecutive_alerts",
    ]
    month_trends = trend_df.loc[trend_df["time_period"] == month, columns]

    # NOTE(review): the join is on ward name only; if two counties share a
    # ward name their rows would cross-match — confirm names are unique.
    merged = gdf[["Ward", "geometry_json"]].merge(month_trends, on="Ward", how="left")
    # Normalise *after* the left merge: unmatched wards come back as NaN, and
    # NaN is truthy, so filling before the merge let them show "ALERT" in the
    # hover text and form an extra colour category. eq(True) maps NaN -> False.
    merged["alert_flag"] = merged["alert_flag"].eq(True)

    def fmt_number(value):
        """Format a number to 3 decimals, or "N/A" when missing."""
        return "N/A" if pd.isna(value) else f"{value:.3f}"

    def fmt_text(value):
        """Return *value* for display, or "N/A" when missing (instead of "nan")."""
        return "N/A" if pd.isna(value) else value

    def format_hover(row):
        """Build the HTML hover card for one merged ward row."""
        lb, ub = row["lower_bound_3mo"], row["upper_bound_3mo"]
        ci_str = "N/A" if pd.isna(lb) or pd.isna(ub) else f"[{lb:.3f} – {ub:.3f}]"
        alert = "⚠️ ALERT<br>" if row["alert_flag"] else ""
        streak_len = row["consecutive_alerts"]
        streak = (
            f"Consecutive months in alert: {int(streak_len)}"
            if pd.notna(streak_len) and streak_len > 1
            else ""
        )
        return (
            f"<b>County:</b> {fmt_text(row['County'])}<br>"
            f"<b>Ward:</b> {row['Ward']}<br>"
            f"<b>Predicted Trend (3mo):</b> {fmt_text(row['predicted_trend_CI_3mo'])}<br>"
            f"<b>Predicted (3mo):</b> {fmt_number(row['predicted_value_3mo'])}<br>"
            f"<b>95% CI:</b> {ci_str}<br>"
            f"{alert}"
            f"{streak}"
        )

    merged["hover_label"] = merged.apply(format_hover, axis=1)

    # Only the ward name is needed in feature properties (it is the
    # featureidkey); the hover text reaches the trace through custom_data.
    geojson = {
        "type": "FeatureCollection",
        "features": [
            {"type": "Feature", "geometry": geom, "properties": {"Ward": ward}}
            for ward, geom in zip(merged["Ward"], merged["geometry_json"])
        ],
    }

    # NOTE: Plotly reports choropleth_mapbox as deprecated in favour of
    # choropleth_map; switching presumably also needs a Dash build whose
    # bundled plotly.js supports the MapLibre "map" traces — verify first.
    fig = px.choropleth_mapbox(
        merged,
        geojson=geojson,
        locations="Ward",
        featureidkey="properties.Ward",
        color="alert_flag",
        color_discrete_map={False: "lightblue", True: "brown"},
        category_orders={"alert_flag": [False, True]},  # stable legend order
        custom_data=["hover_label"],
        mapbox_style="carto-positron",
        zoom=5.5,
        center={"lat": 0.5, "lon": 37},
    )
    # Applied before the outline trace exists, so it only touches the choropleth.
    fig.update_traces(hovertemplate="%{customdata[0]}<extra></extra>")

    # Draw every county ring in ONE line trace, rings separated by None gaps,
    # instead of one trace per polygon. The drawing is the same, with far fewer
    # traces to serialise and render.
    lons, lats = [], []
    for geom in counties.geometry:
        for poly in getattr(geom, "geoms", [geom]):
            ring_lon, ring_lat = poly.exterior.xy
            lons.extend(ring_lon)
            lats.extend(ring_lat)
            lons.append(None)
            lats.append(None)
    fig.add_trace(go.Scattermapbox(
        lon=lons, lat=lats, mode="lines",
        line=dict(color="black"), hoverinfo="skip", showlegend=False,
    ))

    fig.update_layout(margin=dict(r=0, t=0, l=0, b=0))
    return fig




@app.callback(
    Output("prevalence-map", "figure"),
    Input("month-select", "value")
)
def update_prevalence_map(month_str):
    """Render the observed-prevalence choropleth for the selected month.

    Each ward is coloured by its observed prevalence band for ``month_str``:
    below 0.10, 0.10 up to (not including) 0.15, at least 0.15, or "No Data"
    when the ward has no observation. County outlines are drawn on top.

    Parameters
    ----------
    month_str : str or None
        Reference month as a ``"YYYY-MM"`` string from ``month-select``.
        ``None`` (dropdown cleared) leaves the current figure untouched.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    from dash.exceptions import PreventUpdate

    if not month_str:
        # The dropdown is clearable; with no month there is nothing to draw.
        raise PreventUpdate

    month = pd.Period(month_str).to_timestamp()
    month_df = trend_df.loc[trend_df["time_period"] == month, ["Ward", "observed"]]
    # NOTE(review): joined on ward name only — confirm names are unique across counties.
    merged = gdf[["Ward", "geometry"]].merge(month_df, on="Ward", how="left")

    def classify(val):
        """Map an observed prevalence to its legend band.

        Thresholds are inclusive at the lower edge (>= 0.10, >= 0.15) to
        match the alert criteria shown beside the alert map, so a ward at
        exactly 0.10 or 0.15 lands in the higher band.
        """
        if pd.isna(val):
            return "No Data"
        if val < 0.10:
            return "0–0.10"
        if val < 0.15:
            return "0.10–0.15"
        return "≥0.15"

    merged["prevalence_group"] = merged["observed"].apply(classify)
    color_map = {
        "0–0.10": "lightblue",
        "0.10–0.15": "lightcoral",
        "≥0.15": "brown",
        "No Data": "lightgray",
    }

    fig = px.choropleth_mapbox(
        merged, geojson=merged.geometry, locations=merged.index,
        color="prevalence_group", color_discrete_map=color_map,
        # Legend in band order rather than order of first appearance.
        category_orders={"prevalence_group": list(color_map)},
        hover_name="Ward", hover_data={"observed": ":.3f"},
        mapbox_style="carto-positron", zoom=5.5, center={"lat": 0.5, "lon": 37},
    )

    # All county rings in one line trace, separated by None gaps, rather than
    # one trace per polygon.
    lons, lats = [], []
    for geom in counties.geometry:
        for poly in getattr(geom, "geoms", [geom]):
            ring_lon, ring_lat = poly.exterior.xy
            lons.extend(ring_lon)
            lats.extend(ring_lat)
            lons.append(None)
            lats.append(None)
    fig.add_trace(go.Scattermapbox(
        lon=lons, lat=lats, mode="lines",
        line=dict(color="black"), hoverinfo="skip", showlegend=False,
    ))

    fig.update_layout(margin=dict(r=0, t=0, l=0, b=0))
    return fig

@app.callback(
    Output("ward-timeseries", "figure"),
    Input("ward-select", "value")
)
def update_timeseries(ward):
    """Plot observed prevalence against 1- and 3-month predictions for one ward.

    Parameters
    ----------
    ward : str or None
        Ward name from ``ward-select``. ``None`` or a ward with no rows
        yields an empty figure.

    Returns
    -------
    plotly.graph_objects.Figure
        The observed series plus, for each horizon, the predicted line and a
        shaded confidence band between its lower and upper bound columns.
    """
    df = trend_df[trend_df["Ward"] == ward].sort_values("time_period")
    if df.empty:
        return go.Figure()

    x = df["time_period"]
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=df["observed"], mode="lines+markers",
                             name="Observed", line=dict(color="black")))

    def add_horizon(horizon, line_color, dash_style, band_color, band_fill):
        """Add the predicted line and CI band for one horizon ("1mo" / "3mo")."""
        band = f"{horizon} CI"
        fig.add_trace(go.Scatter(x=x, y=df[f"predicted_value_{horizon}"], mode="lines",
                                 name=f"Predicted ({horizon})",
                                 line=dict(color=line_color, dash=dash_style)))
        # Lower bound must come right before the upper bound: fill="tonexty"
        # shades down to the previously added trace. Both bounds share a
        # legendgroup (and a name, so hover no longer says "trace N"), so one
        # legend click hides the whole band, not just the upper edge.
        fig.add_trace(go.Scatter(x=x, y=df[f"lower_bound_{horizon}"], mode="lines",
                                 line=dict(color=band_color), name=band,
                                 legendgroup=band, showlegend=False))
        fig.add_trace(go.Scatter(x=x, y=df[f"upper_bound_{horizon}"], mode="lines",
                                 fill="tonexty", line=dict(color=band_color),
                                 fillcolor=band_fill, name=band, legendgroup=band))

    add_horizon("1mo", "green", "dash", "lightgreen", "rgba(0,255,0,0.1)")
    add_horizon("3mo", "orange", "dot", "navajowhite", "rgba(255,165,0,0.15)")

    title = f"{ward} ({df['County'].iloc[0]})"
    fig.update_layout(title=title, xaxis_title="Date", yaxis_title="Wasting Prevalence")
    return fig

# The "Observed vs Predicted Time Series" controls (county-select, ward-select,
# ward-timeseries) are defined once, inside app.layout on the map tab.

@app.callback(
    Output("alert-timeseries-plots", "children"),
    Input("month-select-alerts", "value")
)
def display_alert_ward_timeseries(month_str):
    """Return one compact time-series graph per ward alerting in ``month_str``.

    A ward is included when its row for the selected month has
    ``alert_flag == True``. Its full history (observed values, 3-month
    prediction and 3-month CI band) is plotted, with wards in alphabetical
    order.

    Parameters
    ----------
    month_str : str or None
        Reference month (``"YYYY-MM"``) from ``month-select-alerts``.
        ``None`` (dropdown cleared) leaves the current plots untouched.

    Returns
    -------
    list of dcc.Graph, or html.P
        The graphs, or a placeholder paragraph when no ward is flagged.
    """
    from dash.exceptions import PreventUpdate

    if not month_str:
        # The dropdown is clearable; with no month there is nothing to plot.
        raise PreventUpdate

    month = pd.Period(month_str).to_timestamp()
    flagged_now = (trend_df["time_period"] == month) & trend_df["alert_flag"].eq(True)
    alert_wards = trend_df.loc[flagged_now, "Ward"].unique()

    # Select all alerting wards' history once and split it with a sorted
    # groupby. This avoids rescanning trend_df per ward, and NaN ward names
    # are dropped instead of breaking sorted().
    history = trend_df[trend_df["Ward"].isin(alert_wards)]

    plots = []
    for ward, ward_df in history.groupby("Ward", sort=True):
        ward_df = ward_df.sort_values("time_period")
        x = ward_df["time_period"]
        county = ward_df["County"].iloc[0]
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=x, y=ward_df["observed"], mode="lines+markers",
                                 name="Observed", line=dict(color="black")))
        fig.add_trace(go.Scatter(x=x, y=ward_df["predicted_value_3mo"], mode="lines",
                                 name="Predicted (3mo)", line=dict(color="orange", dash="dot")))
        # Lower bound right before upper so fill="tonexty" shades the band; the
        # shared legendgroup makes the legend toggle hide both bounds together.
        fig.add_trace(go.Scatter(x=x, y=ward_df["lower_bound_3mo"], mode="lines",
                                 line=dict(color="navajowhite"), name="3mo CI",
                                 legendgroup="3mo CI", showlegend=False))
        fig.add_trace(go.Scatter(x=x, y=ward_df["upper_bound_3mo"], mode="lines",
                                 fill="tonexty", line=dict(color="navajowhite"),
                                 fillcolor="rgba(255,165,0,0.15)", name="3mo CI",
                                 legendgroup="3mo CI"))
        fig.update_layout(title=f"{ward} ({county})", height=300, margin=dict(t=40, b=20))
        plots.append(dcc.Graph(figure=fig, config={"displayModeBar": False}))

    if not plots:
        return html.P("No wards flagged for alert at this date.", style={"margin": "20px", "fontStyle": "italic"})

    return plots

# === Run app ===
if __name__ == "__main__":
    # Debug mode enables Dash dev tools and the Flask/Werkzeug interactive
    # debugger, which can run code from the browser. Because the server binds
    # to every interface (0.0.0.0), debug is opt-in: set DASH_DEBUG=1 to enable.
    debug_mode = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes"}
    app.run(debug=debug_mode, host="0.0.0.0", port=8080)



# Captured notebook output (Plotly deprecation warnings):
# *choropleth_mapbox* is deprecated! Use *choropleth_map* instead. Learn more at: https://plotly.com/python/mapbox-to-maplibre/
#
# *choropleth_mapbox* is deprecated! Use *choropleth_map* instead. Learn more at: https://plotly.com/python/mapbox-to-maplibre/

