In [None]:
!pip install spatialpandas easydev colormap colorcet duckdb dask_geopandas

In [None]:
import sys
# Prepend project directories (relative to the notebook's working directory)
# so the `evaluation` package and the modules inside `evaluation/` and
# `evaluation/queries/` can be imported by their bare names below.
sys.path.insert(0, '../../')
sys.path.insert(0, '../../evaluation/')
sys.path.insert(0, '../../evaluation/queries/')

from evaluation import utils, config
import queries 
import dask_geopandas
import duckdb as ddb
import pandas as pd
import geopandas as gpd
# NOTE(review): these two imports rebind `config` and `utils`, replacing the
# names imported from `evaluation` above with whatever top-level modules the
# modified sys.path resolves — confirm this is intended.
import config
import utils

In [None]:
# Load the gage table through the project helper; the next cell merges it with
# the computed metrics on `nwm_feature_id`. (Presumably a GeoDataFrame of USGS
# gage locations — TODO confirm against utils.get_usgs_gages.)
gdf = utils.get_usgs_gages()
gdf.head(3)  # preview the first three rows

In [None]:
# Build the metrics query, grouped per feature and per forecast reference time.
query = queries.calculate_nwm_feature_metrics(
    config.MEDIUM_RANGE_1_PARQUET,
    config.USGS_PARQUET,
    group_by=["nwm_feature_id", "reference_time"],
    order_by=["observed_average"],
    filters=[
        # "1 = 1" is always true, i.e. this filter keeps every row.
        {
            "column": "1",
            "operator": "=",
            "value": 1
        },
    ]
)
df = ddb.query(query).to_df()

# Attach the metrics to the gage geometries and order them by reference time
# while the data is still an in-memory GeoDataFrame. The previous code called
# `sort_values(..., inplace=True)` on the dask collection; dask returns a new
# collection instead of mutating, and the result was never assigned, so the
# frame was never actually sorted. The index is reset because
# `from_geopandas` orders rows by index (sort=True by default) when it
# partitions, which would otherwise undo the sort.
merged = (
    gdf.merge(df, on="nwm_feature_id")
    .sort_values("reference_time")
    .reset_index(drop=True)
)
gdf_map = dask_geopandas.from_geopandas(merged, npartitions=16)

# Drop the intermediate names; later cells only use `gdf_map`.
# NOTE: because `gdf` is deleted, re-running this cell requires re-running the
# cell that loads the gages first.
del df
del gdf
del merged
gdf_map.head(3)

In [None]:
%%time
import holoviews as hv, geoviews as gv, param, dask.dataframe as dd, cartopy.crs as crs
import panel as pn
from datetime import datetime as dt
from bokeh.models import HoverTool
import datashader as ds
from spatialpandas import GeoSeries, GeoDataFrame
from colormap import rgb2hex
import logging
from shapely.geometry import Point
import dask
import geopandas 
from evaluation import utils, config
import queries # need to fix path to use original queries
import dask_geopandas
from shapely.ops import nearest_points

from colorcet import cm
from holoviews.operation.datashader import rasterize, shade, regrid, inspect_points, dynspread, spread
from holoviews.operation.datashader import (
    datashade, inspect_polygons
)
from holoviews.operation import decimate
from holoviews.streams import RangeXY, Pipe, Tap, Selection1D
from holoviews.util.transform import easting_northing_to_lon_lat
import pandas as pd

# Activate the bokeh plotting backend (without the HoloViews logo banner).
hv.extension('bokeh', logo=False)

# Shared plot options; unpacked into the map tile layer's `.opts(...)` call.
opts = {
    'width': 800,
    'height': 500,
    'show_grid': False,
}


class NWMExplorer(param.Parameterized):
    """Interactive explorer for the per-feature metrics held in ``gdf_map``.

    The class exposes three user-facing parameters (``map_view``, ``measure``
    and ``time``), renders the features for the selected reference time as
    coloured map points, and plots the forecast/observed timeseries of the
    feature nearest to the last tap on the map.

    All class-level statistics are computed from the module-level ``gdf_map``
    once, when the class body is executed.
    """

    # Metric columns of ``gdf_map`` that can drive the point colour.
    _MEASURES = ['bias', 'max_forecast_delta', 'observed_variance',
                 'forecast_variance', 'observed_average', 'forecast_average']

    # Single dask pass: reference-time bounds, the distinct reference times,
    # and then (min, max) for every measure in `_MEASURES` order.
    _min_time, _max_time, time_list, *_extrema = dask.compute(
        gdf_map.reference_time.min(),
        gdf_map.reference_time.max(),
        gdf_map.reference_time.unique(),
        *[stat
          for column in _MEASURES
          for stat in (gdf_map[column].min(), gdf_map[column].max())],
    )
    # measure name -> (global min, global max); used to pin the colour scale so
    # it stays comparable when stepping through reference times.
    _measure_ranges = dict(zip(_MEASURES, zip(_extrema[0::2], _extrema[1::2])))
    del _extrema

    map_view   = param.ObjectSelector(default='regular', objects=['rasterize','regular'])
    measure    = param.ObjectSelector(default='bias', objects=list(_MEASURES))
    # `unique()` on a dask series gives no ordering guarantee, so sort the
    # options explicitly; otherwise the player widget steps through reference
    # times out of chronological order.
    time       = param.ObjectSelector(default=_min_time, objects=sorted(time_list))
    # Rows (in their original CRS) for the currently selected reference time.
    # Always a plain geopandas GeoDataFrame: empty until get_points() has run.
    _rslt = geopandas.GeoDataFrame(columns=list(gdf_map.columns), geometry='geometry')
    # Class-level tap stream so it can be referenced by the @pn.depends
    # decorator below at class-definition time.
    _tap_stream = Tap(transient=False)
    # Executed once, when the class body is evaluated.
    pn.extension(loading_spinner='dots', loading_color='#00aa41', sizing_mode="stretch_width")

    @param.depends('measure','time')
    def get_points(self):
        """Return the map points for the selected reference time and measure.

        Side effect: stores the selected rows (un-projected, computed) on
        ``self._rslt`` so the tap handler can look up the nearest feature
        without re-running the dask graph.

        Returns:
            gv.Points in Web Mercator coloured by ``self.measure``, with the
            colour range fixed to the measure's global (min, max).
        """
        selected = gdf_map[gdf_map['reference_time'] == self.time].compute()
        self._rslt = selected
        projected = selected.to_crs("EPSG:3857")

        points = gv.Points(GeoDataFrame(projected), #hover functionality needs spatialpandas dataframe to work
                           crs=crs.GOOGLE_MERCATOR, #needed for tooltips to work
                           vdims=[self.measure, 'nwm_feature_id'])
        points.opts(size=3,
                    colorbar=True,
                    cmap='viridis',
                    color=self.measure,
                    tools=['hover', 'tap'])

        bounds = self._measure_ranges.get(self.measure)
        if bounds is not None:
            points = points.redim.range(**{self.measure: bounds})
        return points

    @param.depends('measure','time')
    def view(self):
        """Return tiles overlaid with (at most 3000 sampled) individual points."""
        points = self.get_points()
        points = decimate(points, dynamic=False, max_samples=3000)
        tiles = gv.tile_sources.StamenTerrain().apply.opts(alpha=0.15, **opts)
        # Taps on the point layer feed the timeseries panel.
        self._tap_stream.source = points
        return tiles * points

    def view2(self):
        """Return tiles overlaid with a datashader-rasterized point layer.

        The raster aggregates with ``min`` of the selected measure, is spread
        by 3 px for visibility, and carries an ``inspect_points`` layer that
        provides hover/tap on the rasterized data.
        """
        points = hv.DynamicMap(self.get_points)
        tiles = gv.tile_sources.StamenTerrain().apply.opts(alpha=0.15, **opts)
        agg = rasterize(points,
                        x_sampling=1,
                        y_sampling=1,
                        width=800,
                        height=500,
                        aggregator=ds.min(self.measure))
        agg.opts(colorbar=True, alpha=0.7)
        agg = spread(agg, px=3)
        tooltips=[('nwm_feature_id', '@nwm_feature_id'), ('gage_id', '@gage_id'), (self.measure, '@' + self.measure)]
        hover_tool = HoverTool(tooltips=tooltips)
        hover = inspect_points(agg).opts(fill_color='yellow', tools=[hover_tool,'tap']).opts(alpha=0.9)

        # Taps on the raster layer feed the timeseries panel.
        self._tap_stream.source = agg
        return tiles * agg * hover

    @param.depends('map_view','measure','time')
    def select_view(self):
        """Dispatch to the rasterized or the regular map according to ``map_view``."""
        if self.map_view == 'rasterize':
            return self.view2()
        return self.view()

    def get_nwm_feature_timeseries(self, nwm_feature_id):
        """Query the joined forecast/observed timeseries for one feature.

        Args:
            nwm_feature_id: value matched against the ``nwm_feature_id`` column.

        Returns:
            pandas.DataFrame with the query result, restricted to the
            currently selected reference time (``self.time``).
        """
        query = queries.get_joined_nwm_feature_timeseries(
            config.MEDIUM_RANGE_1_PARQUET,
            config.USGS_PARQUET,
            filters=[
                {
                    "column": "reference_time",
                    "operator": "=",
                    "value": str(self.time)
                },
                {
                    "column": "nwm_feature_id",
                    "operator": "=",
                    "value": nwm_feature_id
                },
            ]
        )

        return ddb.query(query).to_df()

    @pn.depends(_tap_stream.param.x,_tap_stream.param.y)
    def plot_nwm_feature_timeseries(self,x,y):
        """Plot the timeseries of the feature nearest to the last map tap.

        Args:
            x, y: tap coordinates from the tap stream (Web Mercator
                easting/northing, converted to lon/lat here), or ``None``
                before the first tap.

        Returns:
            A one-column hv.Layout of the forecast/observed curves above the
            data table. Before any tap, or when no rows are loaded for the
            current reference time, an empty placeholder layout is returned.
        """
        candidates = self._rslt
        if x is None or y is None or candidates.empty:
            # Nothing tapped yet (or nothing to match against): placeholder.
            placeholder = {'reference_time': [0], 'value_time': [0], 'nwm_feature_id': [0]}
            tbl = hv.Table(pd.DataFrame(data=placeholder))
            curves = hv.Curve(pd.DataFrame(columns=['value_time', 'forecast_value']))
        else:
            lon, lat = easting_northing_to_lon_lat(x, y)
            tapped = Point(lon, lat)
            # Snap the tap to the closest feature geometry, then select the
            # row(s) carrying exactly that geometry.
            nearest_geom = nearest_points(tapped, candidates.geometry.unary_union)[1]
            rslt = candidates[candidates.geometry == nearest_geom]
            feature_id = rslt['nwm_feature_id'].iloc[0]

            catchment_df = self.get_nwm_feature_timeseries(feature_id)
            catchment_df = catchment_df.sort_values("value_time")
            tbl = hv.Table(catchment_df)
            label = "nwm_feature_id={nwm_feature_id} | reference_time = {reference_time}".format(nwm_feature_id=feature_id, reference_time=self.time)
            curves = hv.Curve(catchment_df, "value_time", "forecast_value").opts(tools=['hover']) * hv.Curve(catchment_df, "value_time", "observed_value").opts(tools=['hover'])
            curves = curves.relabel(label)
        tbl.opts(width=1200)
        curves.opts(width=1200)
        return hv.Layout(curves + tbl).cols(1)


# Instantiate the explorer and assemble the dashboard: the map next to its
# parameter widgets (with `time` rendered as a player), and the tap-driven
# timeseries panel underneath.
nwm = NWMExplorer(name="NWM Explorer")
controls = pn.Param(nwm.param, widgets={'time': pn.widgets.DiscretePlayer})
map_row = pn.Row(nwm.select_view, controls)
timeseries_pane = pn.panel(nwm.plot_nwm_feature_timeseries, loading_indicator=True)
pn.Column(map_row, timeseries_pane).servable()

