In [None]:
%%capture
!pip install spatialpandas easydev colormap colorcet duckdb dask_geopandas nb_black

In [None]:
import geopandas as gpd
import holoviews as hv
import geoviews as gv
import hvplot.pandas
from holoviews import opts
from bokeh.models import HoverTool
from bokeh.plotting import figure

In [None]:
import os
import sys

#DATA_DIR = "/Users/ctownsend/projects/hydro_data/data/studies"
# adding project dirs to path so code may be referenced from the notebook
sys.path.insert(0, '../../')
sys.path.insert(0, '../../evaluation/')
sys.path.insert(0, '../../evaluation/queries/')

from evaluation import utils, config
import queries # need to fix path to use original queries
import dask_geopandas
import duckdb as ddb
import spatialpandas as spd
from holoviews.operation.datashader import (
    rasterize, shade, regrid, inspect_points
)
from holoviews.operation.datashader import (
    datashade, inspect_polygons
)
import datashader as ds
import cartopy.crs as ccrs
from dask.distributed import Client, LocalCluster
import dask.dataframe as dd
import panel as pn
from holoviews import streams
import numpy as np
from shapely.geometry import Point
from panel.interact import interact, fixed

In [None]:
# Load the HUC10 parquet file through the project helper and reproject it to
# EPSG:3857, the same projection the map layers below declare
# (crs=ccrs.GOOGLE_MERCATOR). Later cells read the `huc10` and `name` columns
# of this frame and run point-in-polygon lookups against its geometry.
gdf = utils.parquet_to_gdf(config.HUC10_PARQUET_FILEPATH).to_crs("EPSG:3857")
# Optional geometry simplification, currently disabled:
# gdf["geometry"] = gdf["geometry"].simplify(tolerance=100)

In [None]:
def get_catchment_details(catchment_id, time=None):
    """Return catchment bias joined onto the HUC10 geometries.

    Builds a filter list, asks ``queries.calculate_catchment_metrics`` for the
    SQL, runs it through DuckDB and merges the result onto the module-level
    ``gdf`` (``gdf.huc10`` == ``catchment_id``).

    Parameters
    ----------
    catchment_id : str
        A 2-character id is used as a prefix (``like '<id>%'``). An id longer
        than 3 characters is matched exactly. Any other length (0, 1 or 3
        characters) adds no catchment filter at all.
    time : optional
        When not None, results are restricted to rows whose
        ``reference_time`` equals ``f"{time}"``.

    Returns
    -------
    spatialpandas.GeoDataFrame
        Columns ``geometry``, ``name``, ``catchment_id`` and ``bias``; the
        ``name`` and ``catchment_id`` columns are converted to categoricals.
    """
    filters = []
    if len(catchment_id) == 2:
        # Two characters: treat the id as a prefix.
        filters.append(
            {
                "column": "catchment_id",
                "operator": "like",
                "value": "" + catchment_id + "%",
            }
        )
    elif len(catchment_id) > 3:
        filters.append(
            {
                "column": "catchment_id",
                "operator": "=",
                "value": catchment_id,
            }
        )

    if time is not None:
        filters.append(
            {
                "column": "reference_time",
                "operator": "=",
                "value": f"{time}",
            }
        )

    query = queries.calculate_catchment_metrics(
        config.MEDIUM_RANGE_FORCING_PARQUET,
        config.FORCING_ANALYSIS_ASSIM_PARQUET,
        group_by=["reference_time, catchment_id"],
        order_by=["reference_time, catchment_id"],
        filters=filters,
    )
    df = ddb.query(query).to_df()

    gdf_map = gdf.merge(df, left_on="huc10", right_on="catchment_id")
    # Take an explicit copy of the column subset so the assignments below act
    # on an independent frame instead of triggering pandas'
    # chained-assignment (SettingWithCopy) warning.
    gdf_map = gdf_map[["geometry", "name", "catchment_id", "bias"]].copy()
    gdf_map["name"] = gdf_map["name"].astype("category")
    gdf_map["catchment_id"] = gdf_map["catchment_id"].astype("category")
    return spd.GeoDataFrame(gdf_map)

In [None]:
def get_joined_catchment_timeseries(catchment_id, time):
    """Run the joined-timeseries query for one catchment and reference time.

    The SQL comes from ``queries.get_joined_catchment_timeseries`` for the
    medium-range forcing and analysis-assim parquet locations in ``config``,
    filtered to ``reference_time = time`` and ``catchment_id = catchment_id``.
    Returns whatever ``ddb.query(...).to_df()`` produces for that SQL.
    """
    reference_time_filter = {
        "column": "reference_time",
        "operator": "=",
        "value": time,
    }
    catchment_filter = {
        "column": "catchment_id",
        "operator": "=",
        "value": catchment_id,
    }
    sql = queries.get_joined_catchment_timeseries(
        config.MEDIUM_RANGE_FORCING_PARQUET,
        config.FORCING_ANALYSIS_ASSIM_PARQUET,
        filters=[reference_time_filter, catchment_filter],
    )
    return ddb.query(sql).to_df()

In [None]:
# Activate the Bokeh backend for HoloViews without printing the logo.
hv.extension("bokeh", logo=False)

# Plot options shared by the tile layer and the rasterized polygons.
# NOTE: this rebinds `opts`, shadowing the `from holoviews import opts`
# import above; later cells rely on `opts` being this plain dict (`**opts`).
opts = {
    "width": 700,
    "height": 500,
    "show_grid": False,
}

In [None]:
# cluster = LocalCluster()
# client = Client(cluster)
# cluster

In [None]:
%%time
# Fixed reference time used by table() and plot() below when they fetch a
# catchment's timeseries. map_overlay() takes its own `reference_time`
# argument, which shadows this module-level value inside that function.
reference_time = "2022-12-25 00:00:00"

# OpenStreetMap base tiles, sized with the shared plot options.
tiles = gv.tile_sources.OSM.opts(**opts)

# Tap stream; map_overlay() assigns its source, and the Plot/Table tabs are
# bound to its x/y parameters.
stream = streams.Tap()

def map_overlay(reference_time):
    """Build the shaded bias layer plus its hover/tap overlay for one time.

    Fetches catchment details for the ``"01"`` id prefix at ``reference_time``,
    wraps them in ``gv.Polygons``, rasterizes with the mean of ``bias`` and
    shades the result. An ``inspect_polygons`` layer with a hover tool and the
    ``tap`` tool is overlaid on top.

    Side effect: the module-level ``stream`` is re-pointed at the newly built
    raster layer.

    Returns the overlay ``rast_polygons * hover``; base tiles are not
    included, so the caller composes them.
    """
    spd_map = get_catchment_details("01", reference_time)

    polygons = gv.Polygons(
        spd_map,
        crs=ccrs.GOOGLE_MERCATOR,
        vdims=['bias', 'name', 'catchment_id'],
    )

    rast_polygons = shade(
        rasterize(
            polygons,
            aggregator=ds.mean("bias"),
            precompute=True,
        )
    ).opts(**opts)

    tooltips = [
        ('Name', '@name'),
        ('Catchment ID', '@catchment_id'),
        ("bias", "@bias"),
    ]
    hover_tool = HoverTool(tooltips=tooltips)
    hover = inspect_polygons(
        rast_polygons
    ).opts(
        fill_color='yellow',
        tools=[hover_tool, 'tap'],
    ).opts(alpha=0.5)

    # Taps are read from the raster layer of the most recently built overlay.
    stream.source = rast_polygons

    return rast_polygons * hover


def table(x, y):
    """Return a table panel of forecast vs. observed values for the tapped catchment.

    ``x`` and ``y`` are the tap coordinates. They are tested against the
    geometries of the module-level ``gdf``, and the first containing row
    supplies the catchment id and name.

    Returns None when no tap has happened yet (either coordinate is None) or
    when the tapped point lies outside every geometry in ``gdf``. Otherwise
    returns a 700-wide Panel object wrapping an ``hv.Table`` of
    ``value_time``, ``forecast_value`` and ``observed_value``.

    NOTE(review): the timeseries is fetched for the module-level
    ``reference_time``, not for the time currently selected on the player
    widget - confirm that is intended.
    """
    if x is None or y is None:
        return None

    pnt = Point(x, y)
    rslt = gdf[gdf.contains(pnt)]
    if rslt.empty:
        # Tap landed outside every catchment; indexing .iloc[0] would raise.
        return None

    catchment_id = rslt["huc10"].iloc[0]
    catchment_name = rslt["name"].iloc[0]

    df = get_joined_catchment_timeseries(catchment_id, reference_time)[
        ["value_time", "forecast_value", "observed_value"]
    ]
    title = f"{catchment_name} | {catchment_id} | {reference_time}"
    return pn.panel(hv.Table(df).relabel(title), width=700)
    

def plot(x, y):
    """Return a forecast vs. observed curve overlay for the tapped catchment.

    ``x`` and ``y`` are the tap coordinates. They are tested against the
    geometries of the module-level ``gdf``, and the first containing row
    supplies the catchment id.

    Returns None when no tap has happened yet (either coordinate is None) or
    when the tapped point lies outside every geometry in ``gdf``. Otherwise
    returns a 700-wide Panel object wrapping the overlay of the orange
    ``forecast_value`` curve and the blue ``observed_value`` curve over
    ``value_time``.

    NOTE(review): the timeseries is fetched for the module-level
    ``reference_time``, not for the time currently selected on the player
    widget - confirm that is intended.
    """
    if x is None or y is None:
        return None

    pnt = Point(x, y)
    rslt = gdf[gdf.contains(pnt)]
    if rslt.empty:
        # Tap landed outside every catchment; indexing .iloc[0] would raise.
        return None

    catchment_id = rslt["huc10"].iloc[0]

    df = get_joined_catchment_timeseries(catchment_id, reference_time)[
        ["value_time", "forecast_value", "observed_value"]
    ]

    forecast_val = hv.Curve(
        df, "value_time", "forecast_value", label="forecast_value"
    )
    forecast_val.opts(tools=["hover"], color="orange")
    observed_val = hv.Curve(
        df, "value_time", "observed_value", label="observed_value"
    )
    observed_val.opts(tools=["hover"], color="blue")

    viz = forecast_val * observed_val
    return pn.panel(viz, width=700)
    

def get_reference_times():
    """List the distinct ``reference_time`` values found in the parquet files.

    Queries every ``*.parquet`` file under
    ``config.MEDIUM_RANGE_FORCING_PARQUET`` through DuckDB and returns the
    ``time`` column of the result as a Python list, ordered ascending by
    ``reference_time``.
    """
    sql = f"""select distinct(reference_time)as time
        from '{config.MEDIUM_RANGE_FORCING_PARQUET}/*.parquet'
        order by reference_time asc"""
    result = ddb.query(sql).to_df()
    return result["time"].tolist()

# Player widget that steps through every reference time found in the data.
reference_times = get_reference_times()
reference_time_slider = pn.widgets.DiscretePlayer(
    name='Discrete Player',
    options=list(reference_times),
    value=reference_times[0],
)


def _map_view(selected_time):
    """Compose the base tiles with the catchment overlay for ``selected_time``."""
    return tiles * map_overlay(selected_time)


# The map is bound to the player's value so it is rebuilt whenever the
# selected reference time changes. The previous
# `reference_time_slider.link(map_overlay, value='object')` passed a plain
# function as the link target, so the map was built once and never updated.
layout = pn.Column(
    pn.Row(reference_time_slider),
    pn.Row(pn.bind(_map_view, reference_time_slider.param.value)),
    pn.Row(
        pn.Tabs(
            ("Plot", pn.bind(plot, x=stream.param.x, y=stream.param.y)),
            ("Table", pn.bind(table, x=stream.param.x, y=stream.param.y)),
        )
    ),
)
layout.servable()