In [None]:
%%capture
!pip install spatialpandas easydev colormap colorcet duckdb dask_geopandas nb_black

In [None]:
# Load the `lab_black` IPython extension (presumably provided by the `nb_black`
# package installed above; it auto-formats cells with Black in JupyterLab).
%load_ext lab_black

In [None]:
import os
import sys

# DATA_DIR = "/Users/ctownsend/projects/hydro_data/data/studies"
# Put the project directories on sys.path so that `evaluation` and `queries`
# can be imported below. The paths are relative, so they resolve against the
# notebook's current working directory. Each insert goes to position 0, so the
# last one listed ends up first on the search path.
sys.path.insert(0, "../../")
sys.path.insert(0, "../../evaluation/")
sys.path.insert(0, "../../evaluation/queries/")

from evaluation import utils, config
import queries  # need to fix path to use original queries
import dask_geopandas
import duckdb as ddb

In [None]:
import holoviews as hv, geoviews as gv, param, dask.dataframe as dd, cartopy.crs as crs
import panel as pn
from datetime import datetime as dt
from bokeh.models import HoverTool

# import datetime as dt
import datashader as ds
import spatialpandas as spd
from colormap import rgb2hex
import logging
from shapely.geometry import Point
import dask
import geopandas
from evaluation import utils, config
import queries  # need to fix path to use original queries
import dask_geopandas

from colorcet import cm
from holoviews.operation.datashader import rasterize, shade, regrid, inspect_points
from holoviews.operation.datashader import datashade, inspect_polygons
from holoviews.streams import RangeXY, Pipe, Tap, Selection1D
from holoviews.util.transform import easting_northing_to_lon_lat
import pandas as pd

In [None]:
%%time
hv.extension('bokeh', logo=False)
opts = dict(width=700,
            height=500,
            #xaxis=None,
            #yaxis=None,
            #bgcolor='black',
            show_grid=False)
cmaps = ['fire','bgy','bgyw','bmy','gray','kbc']

In [None]:
class HydroExplorer(param.Parameterized):
    """Interactive map + time-series explorer for catchment forcing metrics.

    The class exposes three user-tunable parameters (``measure``, ``huc2`` and
    ``time``). ``map_plot`` renders the selected measure per basin on top of an
    OSM tile layer; tapping the map feeds a shared ``Tap`` stream that drives
    the time-series curve and table for the basin under the pointer.

    Note that the basin geometries, the list of reference times, the Panel
    extension and the Bokeh renderer are all set up while the class body is
    executed, i.e. at class-definition time rather than per instance.
    """

    renderer = hv.renderer("bokeh")

    # Basin polygons, reprojected to EPSG:3857 (Web Mercator) to match the
    # GOOGLE_MERCATOR CRS used for the map polygons in `get_polygon`.
    _basins_gdf = utils.parquet_to_gdf(config.HUC10_PARQUET_FILEPATH).to_crs(
        "EPSG:3857"
    )

    # Empty placeholder; replaced on the instance by `get_polygon` with the
    # most recent per-catchment metrics.
    _rslt = geopandas.GeoDataFrame(
        columns=[
            "geometry",
            "huc10",
            "name",
            "reference_time",
            "catchment_id",
            "value_time",
            "forecast_value",
            "observed_value",
            "forecast_average",
            "observed_average",
        ],
        geometry="geometry",
    )
    pn.extension(
        loading_spinner="dots", loading_color="#00aa41", sizing_mode="stretch_width"
    )

    # Columns shown by `joined_catchment_timeseries_table`. Also used to build
    # the empty table so its schema is the same with or without a selection.
    _TABLE_COLUMNS = [
        "reference_time",
        "value_time",
        "catchment_id",
        "forecast_value",
        "configuration",
        "measurement_unit",
        "variable_name",
        "observed_value",
        "lead_time",
    ]

    # Min/max of the currently selected measure; set by `_set_measure_min_max`.
    _measure_min = None
    _measure_max = None

    def _get_defaults():
        """Return (first, last, all) distinct reference times from the forecast parquet.

        Called once, directly from the class body below (hence no ``self``).
        """
        query = f"""select distinct(reference_time)as _time
            from '{config.MEDIUM_RANGE_FORCING_PARQUET}/*.parquet'
            order by reference_time asc"""
        time_df = ddb.query(query).to_df()

        return (
            time_df["_time"].iloc[0],
            time_df["_time"].iloc[-1],
            time_df._time.tolist(),
        )

    _min_time, _max_time, time_list = _get_defaults()

    measure = param.ObjectSelector(
        default="bias",
        objects=[
            "bias",
            "max_forecast_delta",
            "observed_variance",
            "forecast_variance",
            "observed_average",
            "forecast_average",
        ],
    )
    huc2 = param.ObjectSelector(
        default="01",
        objects=[
            "all",
            "01",
            "02",
            "03",
            "04",
            "05",
            "06",
            "07",
            "08",
            "09",
            "10",
            "11",
            "12",
            "13",
            "14",
            "15",
            "16",
            "17",
            "18",
        ],
    )
    time = param.ObjectSelector(default=_min_time, objects=list(time_list))

    def _set_measure_min_max(self):
        """Cache the min and max of the selected measure over `self._rslt`."""
        self._measure_min = self._rslt[self.measure].min()
        self._measure_max = self._rslt[self.measure].max()

    # Shared tap stream; its source is set to the rasterized map in `map_plot`.
    _tap_stream = Tap(transient=False)

    def _basins_at(self, x, y):
        """Return the rows of `_basins_gdf` whose geometry contains (x, y).

        When ``x`` is None (no tap received yet) the point falls back to
        (0, 0), which presumably lies outside every basin so the result is
        empty. Coordinates are assumed to be in the CRS of `_basins_gdf`.
        """
        if x is None:
            x, y = 0, 0
        point = Point(x, y)
        return self._basins_gdf[self._basins_gdf.contains(point)]

    def get_catchment_details(self, catchment_id, _time=None):
        """Query per-catchment metrics and join them onto the basin geometries.

        Args:
            catchment_id: Selector string. A 2-character value is used as a
                prefix match (``like '<id>%'``); a value longer than 3
                characters is matched exactly; anything else (e.g. ``"all"``)
                adds no catchment filter.
            _time: Optional reference time; when given, results are restricted
                to that ``reference_time``.

        Returns:
            A spatialpandas GeoDataFrame of basins merged with their metrics.
        """
        filters = []
        if len(catchment_id) == 2:
            filters.append(
                {
                    "column": "catchment_id",
                    "operator": "like",
                    "value": f"{catchment_id}%",
                }
            )
        elif len(catchment_id) > 3:
            filters.append(
                {"column": "catchment_id", "operator": "=", "value": catchment_id}
            )

        if _time is not None:
            filters.append(
                {
                    "column": "reference_time",
                    "operator": "=",
                    "value": str(_time),
                }
            )

        # NOTE(review): group_by/order_by are passed as a single comma-joined
        # string inside a list; confirm this is what the query builder expects.
        query = queries.calculate_catchment_metrics(
            config.MEDIUM_RANGE_FORCING_PARQUET,
            config.FORCING_ANALYSIS_ASSIM_PARQUET,
            group_by=["reference_time, catchment_id"],
            order_by=["reference_time, catchment_id"],
            filters=filters,
        )
        metrics_df = ddb.query(query).to_df()
        gdf_map = self._basins_gdf.merge(
            metrics_df, left_on="huc10", right_on="catchment_id"
        )
        gdf_map["name"] = gdf_map["name"].astype("category")
        gdf_map["catchment_id"] = gdf_map["catchment_id"].astype("category")
        return spd.GeoDataFrame(gdf_map)

    def get_polygon(self):
        """Build the basin polygons for the current `huc2`, `time` and `measure`.

        Side effects: stores the query result on `self._rslt` and refreshes the
        cached measure min/max.
        """
        rslt_df = self.get_catchment_details(self.huc2, self.time)
        self._rslt = rslt_df
        self._set_measure_min_max()

        polygon = gv.Polygons(
            rslt_df,
            crs=crs.GOOGLE_MERCATOR,  # needed for tooltips to work
            vdims=[self.measure, "name", "catchment_id"],
        )

        # Explicit missing-value check: a plain truthiness test would skip the
        # range whenever a bound is exactly 0, and would let NaN bounds (empty
        # result set) through.
        lower, upper = self._measure_min, self._measure_max
        if pd.notna(lower) and pd.notna(upper):
            polygon = polygon.redim.range(**{self.measure: (lower, upper)})

        return polygon

    @param.depends("huc2", "measure", "time")
    def map_plot(self):
        """Return the map: OSM tiles, rasterized basin measure and hover layer.

        Re-evaluated whenever `huc2`, `measure` or `time` changes. Also points
        the shared tap stream at the rasterized layer so that map clicks drive
        the time-series views.
        """
        polygons = hv.DynamicMap(self.get_polygon)

        tooltips = [
            ("Name", "@name"),
            ("Catchment ID", "@catchment_id"),
            (self.measure, "@" + self.measure),
        ]
        hover_tool = HoverTool(tooltips=tooltips)

        shaded = rasterize(polygons, aggregator=ds.max(self.measure))
        shaded.opts(alpha=0.75, colorbar=True)

        hover = (
            inspect_polygons(shaded)
            .opts(fill_color="yellow", tools=[hover_tool, "tap"])
            .opts(alpha=0.9)
        )

        self._tap_stream.source = shaded
        tiles = gv.tile_sources.OSM.opts(**opts)

        return (tiles * shaded * hover).opts(width=700, height=500)

    def get_joined_catchment_timeseries(self, catchment_id, _time):
        """Return the joined forecast/observed time series for one catchment.

        Args:
            catchment_id: Catchment identifier, matched exactly.
            _time: Reference time, matched exactly (converted with ``str``).

        Returns:
            A pandas DataFrame with the query result.
        """
        query = queries.get_joined_catchment_timeseries(
            config.MEDIUM_RANGE_FORCING_PARQUET,
            config.FORCING_ANALYSIS_ASSIM_PARQUET,
            filters=[
                {
                    "column": "reference_time",
                    "operator": "=",
                    "value": str(_time),
                },
                {"column": "catchment_id", "operator": "=", "value": catchment_id},
            ],
        )
        return ddb.query(query).to_df()

    def plot_joined_catchment_timeseries(self, x, y):
        """Plot forecast vs. observed values for the basin containing (x, y).

        Args:
            x, y: Tap coordinates (None before the first tap).

        Returns:
            An overlay of two curves titled with the basin name/id and the
            current reference time, or an empty overlay with a prompt title
            when no basin contains the point.
        """
        label = "Select catchment to see timeseries"
        basins = self._basins_at(x, y)

        if len(basins) > 0:
            basin_name = basins["name"].iloc[0]
            basin_id = basins["huc10"].iloc[0]
            label = f"{basin_name} ({basin_id}) | reference_time: {self.time}"
            data = self.get_joined_catchment_timeseries(basin_id, self.time)
        else:
            # The basin frame has none of the time-series columns, so use an
            # empty, correctly typed frame to keep the curves constructible.
            data = pd.DataFrame(
                {
                    "value_time": pd.Series(dtype="datetime64[ns]"),
                    "forecast_value": pd.Series(dtype="float64"),
                    "observed_value": pd.Series(dtype="float64"),
                }
            )

        forecast_val = hv.Curve(
            data, "value_time", "forecast_value", label="forecast_value"
        )
        forecast_val.opts(tools=["hover"], color="orange")
        observed_val = hv.Curve(
            data, "value_time", "observed_value", label="observed_value"
        )
        observed_val.opts(tools=["hover"], color="blue")

        plot = forecast_val * observed_val
        # `relabel` returns a new object, so its result must be the one returned.
        return plot.opts(width=1200).relabel(label)

    def joined_catchment_timeseries_table(self, x, y):
        """Return a table of the joined time series for the basin at (x, y).

        The table always has the `_TABLE_COLUMNS` columns; it is empty when no
        basin contains the point.
        """
        basins = self._basins_at(x, y)

        if len(basins) > 0:
            data = self.get_joined_catchment_timeseries(
                basins["huc10"].iloc[0], self.time
            )
            data = data[self._TABLE_COLUMNS]
        else:
            data = pd.DataFrame(columns=self._TABLE_COLUMNS)
        return hv.Table(data)

    def get_joined_catchment_timeseries_table_dmap(self):
        """Return a DynamicMap of the time-series table, driven by map taps."""
        return hv.DynamicMap(
            self.joined_catchment_timeseries_table, streams=[self._tap_stream]
        )

    def plot_joined_catchment_timeseries_dmap(self):
        """Return a DynamicMap of the time-series plot, driven by map taps."""
        return hv.DynamicMap(
            self.plot_joined_catchment_timeseries, streams=[self._tap_stream]
        )

In [None]:
hydro = HydroExplorer(name="Map Explorer")

# Top row: the map next to the parameter widgets (time shown as a player).
map_row = pn.Row(
    pn.panel(hydro.map_plot, loading_indicator=True),
    pn.Param(hydro.param, widgets={"time": pn.widgets.DiscretePlayer}),
    sizing_mode="stretch_both",
)

# Middle: forecast vs. observed curves for the tapped catchment.
timeseries_section = pn.Column(
    pn.panel(hydro.plot_joined_catchment_timeseries_dmap, loading_indicator=True)
)

# Bottom: the same time series as a table.
table_section = pn.panel(
    hydro.get_joined_catchment_timeseries_table_dmap, loading_indicator=True
)

layout = pn.Column(map_row, timeseries_section, table_section).servable()

# The player interval has to be set on the widget after it is created; setting
# it from inside the class had no effect because the widget type is only bound
# at run time. Example (disabled):
# layout[0][1].widgets['time'].interval = 4000
layout