In [None]:
%%capture
!pip install spatialpandas easydev colormap colorcet duckdb dask_geopandas nb_black

In [None]:
import geopandas as gpd
import pandas as pd
import holoviews as hv
import geoviews as gv
import hvplot.pandas
from holoviews import opts
from bokeh.models import HoverTool
from bokeh.plotting import figure

In [None]:
import os
import sys

# adding project dirs to path so code may be referenced from the notebook
sys.path.insert(0, '../../')
sys.path.insert(0, '../../evaluation/')
sys.path.insert(0, '../../evaluation/queries/')

from evaluation import utils, config
import queries # need to fix path to use original queries
import dask_geopandas
import duckdb as ddb
import spatialpandas as spd
from holoviews.operation.datashader import (
    rasterize, shade, regrid, inspect_points
)
from holoviews.operation.datashader import (
    datashade, inspect_polygons
)
import datashader as ds
import cartopy.crs as ccrs
from dask.distributed import Client, LocalCluster
import dask.dataframe as dd
import panel as pn
from holoviews import streams
import numpy as np
from shapely.geometry import Point
from panel.interact import interact, fixed
from collections import OrderedDict as odict
import colorcet

In [None]:
# Tolerance passed to GeoSeries.simplify for the HUC10 outlines. It is expressed in
# the units of the target CRS (EPSG:3857 / Web Mercator, i.e. metres).
# NOTE(review): presumably chosen to keep polygon rendering cheap — confirm.
HUC10_SIMPLIFY_TOLERANCE = 100

# HUC10 polygons, reprojected to Web Mercator so they line up with the tile basemap.
gdf = utils.parquet_to_gdf(config.HUC10_PARQUET_FILEPATH).to_crs("EPSG:3857")
gdf["geometry"] = gdf["geometry"].simplify(tolerance=HUC10_SIMPLIFY_TOLERANCE)

In [None]:
def get_catchment_details(catchment_id: str, reference_time:pd.Timestamp=None):
    """Query per-catchment metrics and attach them to the HUC10 polygons.

    Parameters
    ----------
    catchment_id : str
        HUC prefix matched with ``catchment_id LIKE '<prefix>%'``, or ``"all"``
        to keep every row whose ``catchment_id`` is not the empty string.
    reference_time : pd.Timestamp, optional
        When not ``None``, only rows with this ``reference_time`` are kept.

    Returns
    -------
    spatialpandas.GeoDataFrame
        The module-level ``gdf`` merged with the metrics on
        ``huc10 == catchment_id``, with ``name`` and ``catchment_id`` cast to
        ``category``.
    """
    if catchment_id == "all":
        id_filter = {"column": "catchment_id", "operator": "<>", "value": ""}
    else:
        id_filter = {"column": "catchment_id", "operator": "like", "value": f"{catchment_id}%"}

    filters = [id_filter]
    if reference_time is not None:
        filters.append(
            {"column": "reference_time", "operator": "=", "value": f"{reference_time}"}
        )

    # NOTE(review): group_by/order_by are single comma-joined strings; this relies on
    # queries.calculate_catchment_metrics splicing them into SQL verbatim.
    metrics_sql = queries.calculate_catchment_metrics(
        config.MEDIUM_RANGE_FORCING_PARQUET,
        config.FORCING_ANALYSIS_ASSIM_PARQUET,
        group_by=["reference_time, catchment_id"],
        order_by=["reference_time, catchment_id"],
        filters=filters,
    )
    metrics_df = ddb.query(metrics_sql).to_df()

    merged = gdf.merge(metrics_df, left_on="huc10", right_on="catchment_id")
    for column in ("name", "catchment_id"):
        merged[column] = merged[column].astype("category")

    return spd.GeoDataFrame(merged)

In [None]:
def get_joined_catchment_timeseries(catchment_id: str, reference_time: pd.Timestamp):
    """Return forecast and observed values joined for one catchment and reference time.

    Both filters are equality matches: ``reference_time`` is compared using its
    string form, and ``catchment_id`` is compared exactly (no prefix matching).
    The result is the DuckDB query output as a pandas DataFrame.
    """
    exact_filters = [
        {"column": "reference_time", "operator": "=", "value": f"{reference_time}"},
        {"column": "catchment_id", "operator": "=", "value": catchment_id},
    ]
    timeseries_sql = queries.get_joined_catchment_timeseries(
        config.MEDIUM_RANGE_FORCING_PARQUET,
        config.FORCING_ANALYSIS_ASSIM_PARQUET,
        filters=exact_filters,
    )
    return ddb.query(timeseries_sql).to_df()

In [None]:
hv.extension('bokeh', logo=False)

# Shared Bokeh sizing options, splatted into `.opts(**opts)` calls later on.
# NOTE(review): this rebinds `opts`, shadowing `from holoviews import opts` from the
# import cell, so `opts.Polygons(...)`-style calls would break after this cell.
# Consider renaming it (e.g. PLOT_OPTS) together with every use.
opts = {"width": 700, "height": 500, "show_grid": False}

In [None]:
# cluster = LocalCluster()
# client = Client(cluster)
# cluster

In [None]:
%%time

# Tap stream: its x/y (the tapped location on the map) drive the Plot and Table tabs.
# Its source is attached to the rasterized catchment layer further down this cell.
stream = streams.Tap()


def get_basemap():
    """Return the OpenStreetMap tile layer with the shared sizing options applied."""
    return gv.tile_sources.OSM.opts(**opts)

    
def get_catchments(catchment_id, reference_time, measure):
    """Build catchment polygons coloured by ``measure`` for one HUC and reference time.

    Parameters
    ----------
    catchment_id : str
        HUC prefix or ``"all"``, forwarded to ``get_catchment_details``.
    reference_time : pd.Timestamp
        Reference time to select, forwarded to ``get_catchment_details``.
    measure : str
        Metric column used as the primary value dimension.

    Returns
    -------
    gv.Polygons
        Polygons in Web Mercator with ``measure``, ``name`` and ``catchment_id``
        as value dimensions. When ``measure`` has data, its range is pinned to the
        observed min/max of this selection.
    """
    spd_map = get_catchment_details(catchment_id, reference_time)

    catchments = gv.Polygons(
        spd_map,
        crs=ccrs.GOOGLE_MERCATOR,
        vdims=[measure, 'name', 'catchment_id'],
    )

    # To zoom to the selection: total_bounds is ordered (minx, miny, maxx, maxy),
    # so x uses indices 0/2 and y uses 1/3.
    # minx, miny, maxx, maxy = spd_map["geometry"].total_bounds
    # catchments = catchments.opts(xlim=(minx, maxx), ylim=(miny, maxy))

    measure_min = spd_map[measure].min()
    measure_max = spd_map[measure].max()
    # An empty selection (or an all-NaN measure) yields NaN min/max. In that case,
    # skip the explicit range rather than handing NaN bounds to the colour mapper.
    if pd.notna(measure_min) and pd.notna(measure_max):
        catchments = catchments.redim.range(**{measure: (measure_min, measure_max)})

    return catchments


def _find_catchment(x, y):
    """Look up the HUC10 catchment containing the map point ``(x, y)``.

    The point is tested against the module-level ``gdf`` geometries, so ``x``/``y``
    must be in that frame's CRS (reprojected to EPSG:3857 when it is loaded).

    Returns
    -------
    tuple or None
        ``(huc10, name)`` of the first containing polygon, or ``None`` when no
        polygon contains the point (e.g. a tap outside the loaded catchments).
    """
    matches = gdf[gdf.contains(Point(x, y))]
    if matches.empty:
        return None
    return matches["huc10"].iloc[0], matches["name"].iloc[0]


def _catchment_timeseries(catchment_id, reference_time):
    """Fetch value_time/forecast_value/observed_value for one catchment and reference time."""
    return get_joined_catchment_timeseries(catchment_id, reference_time)[
        ["value_time", "forecast_value", "observed_value"]
    ]


def _catchment_title(catchment_name, catchment_id, reference_time):
    """Title shared by the Plot and Table tabs."""
    return f"{catchment_name} | {catchment_id} | {reference_time}"


def _no_catchment_pane(x, y):
    """Placeholder shown when a tap does not land inside any catchment."""
    return pn.pane.Markdown(f"No catchment found at ({x:.0f}, {y:.0f}).", width=700)


def get_table(x, y, reference_time):
    """Tabulate the tapped catchment's forecast and observed values.

    Parameters
    ----------
    x, y : float or None
        Tap location from the Tap stream; ``None`` until the map is tapped.
    reference_time : pd.Timestamp
        Reference time currently selected on the player widget.

    Returns
    -------
    panel object or None
        ``None`` before any tap, a Markdown notice when the tap is outside every
        catchment, otherwise a table titled ``"<name> | <huc10> | <reference_time>"``.
    """
    if x is None or y is None:
        return None
    found = _find_catchment(x, y)
    if found is None:
        # Previously .iloc[0] on an empty selection raised IndexError here.
        return _no_catchment_pane(x, y)
    catchment_id, catchment_name = found

    df = _catchment_timeseries(catchment_id, reference_time)
    title = _catchment_title(catchment_name, catchment_id, reference_time)
    # width is a plotting option, so it is applied via .opts() rather than passed
    # to the element constructor.
    table = hv.Table(df, label=title).opts(width=700)
    return pn.panel(table, width=700)


def get_plot(x, y, reference_time):
    """Overlay the tapped catchment's forecast (orange) and observed (blue) curves.

    Parameters
    ----------
    x, y : float or None
        Tap location from the Tap stream; ``None`` until the map is tapped.
    reference_time : pd.Timestamp
        Reference time currently selected on the player widget.

    Returns
    -------
    panel object or None
        ``None`` before any tap, a Markdown notice when the tap is outside every
        catchment, otherwise the curve overlay titled
        ``"<name> | <huc10> | <reference_time>"``.
    """
    if x is None or y is None:
        return None
    found = _find_catchment(x, y)
    if found is None:
        return _no_catchment_pane(x, y)
    catchment_id, catchment_name = found

    df = _catchment_timeseries(catchment_id, reference_time)

    forecast_val = hv.Curve(df, "value_time", "forecast_value", label="forecast_value")
    forecast_val.opts(tools=["hover"], color="orange")
    observed_val = hv.Curve(df, "value_time", "observed_value", label="observed_value")
    observed_val.opts(tools=["hover"], color="blue")

    viz = (forecast_val * observed_val).relabel(
        _catchment_title(catchment_name, catchment_id, reference_time)
    )
    return pn.panel(viz, width=700)
    

def get_reference_times():
    """Return the distinct reference times in the medium-range forcing parquet, ascending."""
    sql = f"""select distinct(reference_time)as time
        from '{config.MEDIUM_RANGE_FORCING_PARQUET}/*.parquet'
        order by reference_time asc"""
    return ddb.query(sql).to_df()["time"].tolist()


def get_reference_time_slider():
    """Build a DiscretePlayer stepping through the reference times, starting at the earliest."""
    times = get_reference_times()
    return pn.widgets.DiscretePlayer(
        name='Discrete Player',
        options=list(times),
        value=times[0],
        interval=5000,  # milliseconds between steps while playing
    )


def get_reference_time(reference_time):
    """Render the selected reference time as an HTML pane sized to its content."""
    text = f"{reference_time}"
    return pn.pane.HTML(text, width_policy="fit")


def get_measure_selector():
    """Dropdown choosing the metric column that colours the catchments (default ``bias``)."""
    choices = [
        "bias",
        "max_forecast_delta",
        "observed_variance",
        "forecast_variance",
        "observed_average",
        "forecast_average",
    ]
    return pn.widgets.Select(
        name='Measure',
        options=choices,
        value=choices[0],
        width_policy="fit",
    )


def get_cmap_selector():
    """Dropdown mapping a few colorcet palette names to their colour lists."""
    palette_names = ('fire', 'bgy', 'bgyw', 'bmy', 'gray', 'kbc')
    palettes = odict((name, colorcet.palette[name]) for name in palette_names)
    return pn.widgets.Select(name='Colormap', options=palettes)


def get_aggregator(measure):
    """Return a datashader reduction taking the mean of column ``measure`` per pixel."""
    reduction = ds.mean(measure)
    return reduction


def get_huc_selector():
    """Dropdown of HUC2 region prefixes "01"–"18" plus "all"; defaults to "01"."""
    hucs = ["all"] + [f"{region:02d}" for region in range(1, 19)]
    return pn.widgets.Select(name='HUC2', options=hucs, value="01")

# def get_map_overlay(catchment_id, reference_time, measure):
#     catchment_polygons = get_catchments(
#         catchment_id=catchment_id, 
#         reference_time=reference_time, 
#         measure=measure
#     )
#     aggregator = get_measure(measure=measure)
#     raster_catchments = rasterize(catchment_polygons, aggregator=aggregator, precompute=True).opts(**opts, colorbar=True, cmap="fire")

#     return raster_catchments


# Control widgets.
reference_time_slider = get_reference_time_slider()
measure_selector = get_measure_selector()
# NOTE(review): cmap_selector is created but its layout entry is commented out and
# the raster below hard-codes cmap="fire".
cmap_selector = get_cmap_selector()
huc_selector = get_huc_selector()


# Rebuild the catchment polygons whenever the HUC, reference time or measure changes.
catchment_polygons = pn.bind(
    get_catchments, 
    catchment_id=huc_selector.param.value, 
    reference_time=reference_time_slider.param.value, 
    measure=measure_selector.param.value
)
# The datashader reduction follows the selected measure too.
aggregator = pn.bind(get_aggregator, measure=measure_selector.param.value)
# Rasterize the dynamic polygons with the bound aggregator and shared sizing options.
raster_catchments = rasterize(hv.DynamicMap(catchment_polygons), aggregator=aggregator, precompute=True).opts(**opts, colorbar=True, cmap="fire")

# raster_catchments = pn.bind(
#     get_map_overlay, 
#     catchment_id=huc_selector.param.value, 
#     reference_time=reference_time_slider.param.value, 
#     measure=measure_selector.param.value
# )

# Yellow, semi-transparent highlight of the polygon under the pointer, with hover/tap tools.
hover = inspect_polygons(raster_catchments).opts(fill_color='yellow', tools=["hover", 'tap']).opts(alpha=0.5)

# Taps on the rasterized layer update stream.x / stream.y, which feed get_plot/get_table.
stream.source = raster_catchments

# NOTE(review): enables Panel's loading indicator on ParamMethod panes; confirm it also
# covers the pn.bind-produced panes used below.
pn.param.ParamMethod.loading_indicator = True

# Header (logo + title), then a row with the controls on the left and, on the right:
# the selected reference time, the map (OSM tiles * rasterized catchments * hover
# highlight), and Plot/Table tabs for the tapped catchment.
layout = pn.Column(
    pn.Row(
        pn.pane.PNG('https://ciroh.ua.edu/wp-content/uploads/2022/08/CIROHLogo_200x200.png', width=100),
        pn.pane.Markdown(
            """
            # CIROH Integrated Evaluation and Exploration System
            ## After-action Analysis
            """,
            width=800
        )
    ),
    pn.Row(
        pn.Column(
            reference_time_slider,
            measure_selector,
            huc_selector,
            # cmap_selector,
        ),
        pn.Column(
            pn.bind(get_reference_time, reference_time=reference_time_slider.param.value),
            pn.Row(
                hv.DynamicMap(get_basemap) * raster_catchments * hover
            ),
            pn.Row(
                pn.Tabs(
                    ("Plot", pn.bind(get_plot, x=stream.param.x, y=stream.param.y, reference_time=reference_time_slider.param.value)),
                    ("Table", pn.bind(get_table, x=stream.param.x, y=stream.param.y, reference_time=reference_time_slider.param.value))
                )
            )
        )
    )
)
layout.servable()