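# Demo Plots

Static and interactive plots of Clearwater-Riverine mesh output (the `concentration` variable, labelled as water temperature in °C below) for the `plan28_testTSM_pwrPlnt_May2022` run.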

## Setup

In [None]:
import xarray as xr
import holoviews as hv
import geopandas as gpd

from clearwater_riverine.variables import (
    NUMBER_OF_REAL_CELLS,
    CONCENTRATION,
)
import numpy as np
import geopandas as gpd
import pandas as pd
from shapely.geometry import Polygon 
import geoviews as gv
import matplotlib.pyplot as plt
from IPython.display import display

# Activate the HoloViews Bokeh backend so hv/gv objects render interactively.
hv.extension('bokeh')

In [None]:
# Open the full-run mesh output stored as Zarr.
# NOTE(review): hard-coded absolute Windows path — consider a configurable DATA_DIR.
ds = xr.open_zarr('W:/2ERDC12 - Clearwater/Clearwater_testing_TSM/plan28_testTSM_pwrPlnt_May2022/full_test_output/mesh_output_full_2023_12_20.zarr')

In [None]:
# Load the lazily opened Zarr data into memory; later cells index `ds` repeatedly.
ds = ds.compute()

In [None]:
# Load a previously saved GeoDataFrame (presumably the file written by the save cell
# in the "Create GDF" section).
# NOTE(review): that section rebuilds and overwrites `gdf`, so on Restart & Run All
# this read is redundant; the save cell also uses a relative path rather than this
# absolute W:/ path — confirm both point at the same file.
gdf = gpd.read_parquet('W:/2ERDC12 - Clearwater/Clearwater_testing_TSM/plan28_testTSM_pwrPlnt_May2022/full_test_output/mesh_output_full_gdf.parquet')

## Create the GeoDataFrame

In [None]:
from clearwater_riverine.variables import (
    NUMBER_OF_REAL_CELLS,
    CONCENTRATION,
)
import numpy as np
import geopandas as gpd
import pandas as pd
from shapely.geometry import Polygon 
import geoviews as gv
import matplotlib.pyplot as plt
from IPython.display import display


In [None]:
crs = 'EPSG:26916'  # NAD83 / UTM zone 16N: CRS of ds.node_x / ds.node_y

# Only the first `nreal_index` faces are treated as real mesh cells.
# NOTE(review): the `+ 1` assumes the attribute is the last real-cell index
# rather than a count — confirm against clearwater_riverine.
nreal_index = ds.attrs[NUMBER_OF_REAL_CELLS] + 1
real_face_node_connectivity = ds.face_nodes[0:nreal_index]

# Extract plain NumPy arrays once; indexing xarray objects inside the
# per-cell loop is far slower than indexing ndarrays.
node_x = ds.node_x.values
node_y = ds.node_y.values

# Turn real mesh cells into polygons. Each face_nodes row lists a cell's node
# indices; NaN entries (presumably padding for cells with fewer nodes) are dropped.
polygon_list = []
for cell in real_face_node_connectivity.values:
    indices = cell[~np.isnan(cell)].astype(int)
    polygon_list.append(Polygon(list(zip(node_x[indices], node_y[indices]))))

poly_gdf = gpd.GeoDataFrame(
    {
        'nface': ds.nface[0:nreal_index].values,
        'geometry': polygon_list,
    },
    crs=crs,
)
# Reproject to WGS84 longitude/latitude for web-map (GeoViews tile) plotting.
poly_gdf = poly_gdf.to_crs('EPSG:4326')

In [None]:
# Flatten concentration and volume of the real cells into a long table (one row
# per combination of the variables' dimensions, e.g. time x nface), then attach
# each cell's polygon by `nface`.
df_from_array = (
    ds[['concentration', 'volume']]
    .isel(nface=slice(0, nreal_index))
    .to_dataframe()
    .reset_index()
)

df_merged = gpd.GeoDataFrame(
    pd.merge(
        df_from_array,
        poly_gdf,
        on='nface',
        how='left',
    ).rename(columns={'nface': 'cell', 'time': 'datetime'}),
    # Merging a plain DataFrame with a GeoDataFrame returns a plain DataFrame,
    # so name the geometry column and CRS explicitly instead of relying on
    # them being inferred.
    geometry='geometry',
    crs=poly_gdf.crs,
)
gdf = df_merged

In [None]:
# Cache the rebuilt GeoDataFrame so later sessions can load it directly.
# NOTE(review): this relative path differs from the absolute W:/ path read in the
# Setup section — confirm they refer to the same file.
gdf.to_parquet('../plan28_testTSM_pwrPlnt_May2022/full_test_output/mesh_output_full_gdf.parquet')

## Basic Plot

In [None]:
def plot(
    ds: xr.Dataset,
    gdf: gpd.geodataframe.GeoDataFrame,
    clim: tuple = (None, None),
    time_index_range: tuple = (0, -1)):
    """Build an interactive concentration map with a time slider.

    Parameters
    ----------
    ds : xr.Dataset
        Model output; only the ``Units`` attribute of ``ds[CONCENTRATION]`` is
        read, for the colour-bar label.
    gdf : GeoDataFrame
        Long table with ``datetime``, ``cell``, ``concentration`` and polygon
        geometry columns.
    clim : tuple
        ``(min, max)`` colour limits; ``None`` leaves that limit to HoloViews.
    time_index_range : tuple
        ``(start, stop)`` positional slice over ``gdf.datetime.unique()``.
        NOTE: the default ``(0, -1)`` omits the final time step; pass
        ``(0, None)`` to include it.

    Returns
    -------
    hv.DynamicMap
        Polygons coloured by concentration over a CartoLight basemap, keyed
        on ``datetime``.
    """
    mn_val = clim[0]
    mval = clim[1]
    # Identical for every frame, so look it up once rather than per time step.
    units = ds[CONCENTRATION].Units

    def map_generator(datetime):
        """Render the mesh polygons for a single time step."""
        ras_sub_df = gdf[gdf.datetime == datetime]
        ras_map = gv.Polygons(
            ras_sub_df,
            vdims=['concentration', 'cell']).opts(
                height = 500,
                width = 500,
                color='concentration',
                colorbar = True,
                cmap = 'OrRd',
                clim = (mn_val, mval),
                line_width = 0.1,
                tools = ['hover'],
                clabel = f"Concentration ({units})"
        )
        return ras_map * gv.tile_sources.CartoLight()

    dmap = hv.DynamicMap(map_generator, kdims=['datetime'])
    start, stop = time_index_range[0], time_index_range[1]
    return dmap.redim.values(datetime=gdf.datetime.unique()[start:stop])

In [None]:
# Interactive map with a fixed 13-25 colour scale (matching the static maps below).
# NOTE(review): `curve` holds a DynamicMap, not an hv.Curve.
curve = plot(ds, gdf, clim=(13,25))

In [None]:
# Display the interactive map.
curve

In [None]:
# Overall extent of every polygon in gdf. `total_bounds` computes the
# per-geometry bounds once, instead of rebuilding the full bounds table (one row
# per cell per time step) four times.
minx, miny, maxx, maxy = gdf.total_bounds

In [None]:
# Print the extent in the order minx, maxx, miny, maxy.
# NOTE(review): presumably lon/lat degrees, since the polygons are reprojected to
# EPSG:4326 when gdf is built above.
print(minx, maxx, miny, maxy)

In [None]:
# Inspect the first rows of the long-format GeoDataFrame.
gdf.head()

## Basic Map

In [None]:
# Outline of the wet mesh (cells with positive volume) at time index 18000.
date_value = ds.time.isel(time=18000).values
is_wet = (gdf.datetime == date_value) & (gdf.volume > 0)
ax = plt.gca()
gdf[is_wet].plot(
    facecolor='lightskyblue',
    edgecolor='white',
    linewidth=0.3,
    ax=ax,
)

# Hide ticks and the axes frame so only the mesh is drawn.
ax.set_xticks([])
ax.set_yticks([])
for side in ('top', 'right', 'bottom', 'left'):
    ax.spines[side].set_visible(False)

plt.show()


## Timeseries Plot

In [None]:
def make_timeseries_plot(ls_of_cells, time_index=18000, end_index=-60):
    """Map the wet mesh with highlighted cells and return their time series.

    Draws (and shows) a matplotlib map of the cells with positive volume at
    ``time_index``, outlining each selected cell in its colour. Then builds a
    HoloViews overlay of each selected cell's concentration over
    ``ds.time[time_index:end_index]``. Reads the notebook-level ``ds`` and ``gdf``.

    Parameters
    ----------
    ls_of_cells : dict
        Maps a cell index (``nface`` position / ``gdf.cell`` value) to a colour.
    time_index : int
        Time index of the map snapshot and first point of each curve.
    end_index : int
        Exclusive stop index of each curve (negative counts from the end).

    Returns
    -------
    hv.Overlay
        One coloured curve per selected cell.
    """
    date_value = ds.time.isel(time=time_index).values
    # Filter by time once; the per-cell outlines then search only this snapshot
    # instead of rescanning the full multi-timestep table for every cell.
    snapshot = gdf[gdf.datetime == date_value]

    ax = plt.gca()
    snapshot[snapshot.volume > 0].plot(
        facecolor='lightskyblue',
        edgecolor='white',
        linewidth=0.3,
        ax=ax,
    )
    for cell, colour in ls_of_cells.items():
        snapshot[snapshot.cell == cell].plot(
            facecolor='none',
            edgecolor=colour,
            linewidth=1,
            ax=ax,
        )

    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    plt.show()

    curves = [
        hv.Curve(
            ds.concentration.isel(nface=cell, time=slice(time_index, end_index)),
        ).opts(
            height=400,
            width=800,
            ylabel='Temperature (C)',
            color=colour,
        )
        for cell, colour in ls_of_cells.items()
    ]
    return hv.Overlay(curves)

In [None]:
# Cells to highlight, keyed by cell index (used as `nface` / `gdf.cell`), mapped to
# the colour of both the map outline and the time-series curve.
# NOTE(review): despite the name, this is a dict, not a list.
cell_list = {
    273: '#FF5733',
    299: '#33FF57',
    311: '#3366FF',
    148: '#FFFF33',
    169: '#33FFFF',
    325: '#FF33A1',
    263: '#A133FF',
    # 137: '#FF3333',
}

In [None]:
# Map with every selected cell outlined, followed by their overlaid time series.
make_timeseries_plot(cell_list)

### Step through the cells one at a time

In [None]:
# Grow the selection by one cell per iteration so each figure adds a single
# highlighted cell and curve, for stepping through in a presentation.
temp_dict = {}
for cell, colour in cell_list.items():
    print(cell)
    temp_dict[cell] = colour
    c = make_timeseries_plot(temp_dict)
    display(c)

## Map

In [None]:
from datetime import datetime

In [None]:
def map_plot(time_index, vmin=13, vmax=25, cmap='OrRd'):
    """Show a static concentration map of the wet cells at one time step.

    Wet cells are those with ``volume > 0``. Reads the notebook-level ``ds``
    and ``gdf``.

    Parameters
    ----------
    time_index : int
        Positional index into ``ds.time``.
    vmin, vmax : float
        Colour-scale limits (defaults match the other maps in this notebook).
    cmap : str
        Matplotlib colormap name.
    """
    date_value = ds.time.isel(time=time_index).values
    gdf[(gdf.datetime == date_value) & (gdf.volume > 0)].plot(
        column='concentration',
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
    )
    plt.xticks([])
    plt.yticks([])
    # Title formatted as 'YYYY-MM-DD HH:MM:SS'.
    plt.title(np.datetime_as_string(date_value, unit='s').replace('T', ' '))
    plt.show()

In [None]:
# Static concentration map at time index 32400.
map_plot(32400)

In [None]:
def make_conc_plot(time_index, ls_of_cells, start_index=18000, end_index=-60,
                   vmin=13, vmax=25):
    """Show a concentration map at one time and return the cells' time series.

    Draws (and shows) a matplotlib map of wet cells (``volume > 0``) coloured
    by concentration at ``time_index``. Each selected cell is outlined in its
    colour, and the map is titled with the timestamp. Then builds a HoloViews
    overlay of each selected cell's concentration over
    ``ds.time[start_index:end_index]``, with a dashed vertical line at the
    mapped time. Reads the notebook-level ``ds`` and ``gdf``.

    Parameters
    ----------
    time_index : int
        Positional index into ``ds.time`` for the map and the marker line.
    ls_of_cells : dict
        Maps a cell index (``nface`` / ``gdf.cell``) to a colour.
    start_index, end_index : int
        Positional slice of ``ds.time`` plotted for each curve.
    vmin, vmax : float
        Colour-scale limits of the map.

    Returns
    -------
    hv.Overlay
        The cell curves overlaid with the time-marker line.
    """
    date_value = ds.time.isel(time=time_index).values
    # Filter by time once and reuse it for the map and every outline, instead of
    # rescanning the full multi-timestep table per selected cell.
    snapshot = gdf[gdf.datetime == date_value]

    # GeoDataFrame.plot creates a new figure here and returns its axes.
    ax = snapshot[snapshot.volume > 0].plot(
        column='concentration',
        cmap='OrRd',
        vmin=vmin,
        vmax=vmax,
    )
    for cell, colour in ls_of_cells.items():
        snapshot[snapshot.cell == cell].plot(
            facecolor='none',
            edgecolor=colour,
            linewidth=1,
            ax=ax,
        )

    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    # Title formatted as 'YYYY-MM-DD HH:MM:SS'.
    ax.set_title(np.datetime_as_string(date_value, unit='s').replace('T', ' '))
    plt.show()

    curves = [
        hv.Curve(
            ds.concentration.isel(nface=cell, time=slice(start_index, end_index)),
        ).opts(
            height=400,
            width=800,
            ylabel='Temperature (C)',
            color=colour,
        )
        for cell, colour in ls_of_cells.items()
    ]
    vline = hv.VLine(date_value).opts(line_width=6, line_dash='dashed')
    return hv.Overlay(curves) * vline

In [None]:
# Map at time index 32400 with cell 273 outlined in red, plus its time series.
make_conc_plot(32400, {273:'red'})

In [None]:
def make_conc_plot_alternative(time_index, ls_of_cells, start_index=18000,
                               xlim_end_index=-100, vmin=13, vmax=25):
    """Show a concentration map and return time series that end at the mapped time.

    Draws (and shows) a matplotlib map of wet cells (``volume > 0``) coloured
    by concentration at ``time_index``, outlining each selected cell and titling
    the map with the timestamp. Each returned curve covers
    ``ds.time[start_index:time_index]`` on a fixed x-range, so successive calls
    with increasing ``time_index`` show the series growing on a shared axis.
    Reads the notebook-level ``ds`` and ``gdf``.

    Parameters
    ----------
    time_index : int
        Positional index into ``ds.time`` for the map; exclusive end of each curve.
    ls_of_cells : dict
        Maps a cell index (``nface`` / ``gdf.cell``) to a colour.
    start_index : int
        First time index of each curve and left edge of the x-range.
    xlim_end_index : int
        Time index of the right edge of the x-range.
    vmin, vmax : float
        Colour-scale limits of the map.

    Returns
    -------
    hv.Overlay
        One coloured curve per selected cell.
    """
    date_value = ds.time.isel(time=time_index).values
    # Filter by time once and reuse it for the map and every outline.
    snapshot = gdf[gdf.datetime == date_value]

    # GeoDataFrame.plot creates a new figure here and returns its axes.
    ax = snapshot[snapshot.volume > 0].plot(
        column='concentration',
        cmap='OrRd',
        vmin=vmin,
        vmax=vmax,
    )
    for cell, colour in ls_of_cells.items():
        snapshot[snapshot.cell == cell].plot(
            facecolor='none',
            edgecolor=colour,
            linewidth=1,
            ax=ax,
        )

    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    # Title formatted as 'YYYY-MM-DD HH:MM:SS'.
    ax.set_title(np.datetime_as_string(date_value, unit='s').replace('T', ' '))
    plt.show()

    # Loop-invariant: the same fixed x-range is applied to every curve.
    xlim = (
        pd.to_datetime(ds.time[start_index].values),
        pd.to_datetime(ds.time[xlim_end_index].values),
    )
    curves = [
        hv.Curve(
            ds.concentration.isel(nface=cell, time=slice(start_index, time_index)),
        ).opts(
            height=400,
            width=800,
            ylabel='Temperature (C)',
            color=colour,
            xlim=xlim,
        )
        for cell, colour in ls_of_cells.items()
    ]
    return hv.Overlay(curves)

In [None]:
# Same map, with cell 273's series drawn only up to the mapped time.
make_conc_plot_alternative(32400, {273:'red'})

## Matplotlib

In [None]:
import matplotlib.dates as mdates

In [None]:
def conc_plot_plt(time_index, ls_of_cells, start_index=18000, end_index=-100,
                  vmin=13, vmax=25):
    """Static side-by-side figure: concentration map and cell time series.

    The left panel shows the concentration of wet cells (``volume > 0``) at
    ``time_index``, with each selected cell outlined in its colour. The right
    panel shows each selected cell's concentration over
    ``ds.time[start_index:end_index]`` in the same colour. A dotted vertical
    line marks the mapped time. Reads the notebook-level ``ds`` and ``gdf``.

    Parameters
    ----------
    time_index : int
        Positional index into ``ds.time`` for the map and the marker line.
    ls_of_cells : dict
        Maps a cell index (``nface`` / ``gdf.cell``) to a colour.
    start_index, end_index : int
        Positional slice of ``ds.time`` plotted in the right panel.
    vmin, vmax : float
        Colour-scale limits of the map.
    """
    fig, (map_ax, ts_ax) = plt.subplots(1, 2, width_ratios=[1, 2])

    date_value = ds.time.isel(time=time_index).values
    snapshot = gdf[gdf.datetime == date_value]

    snapshot[snapshot.volume > 0].plot(
        column='concentration',
        cmap='OrRd',
        vmin=vmin,
        vmax=vmax,
        ax=map_ax,
    )
    for cell, colour in ls_of_cells.items():
        snapshot[snapshot.cell == cell].plot(
            facecolor='none',
            edgecolor=colour,
            linewidth=1,
            ax=map_ax,
        )

    map_ax.set_xticks([])
    map_ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        map_ax.spines[side].set_visible(False)

    window = slice(start_index, end_index)
    times = ds.time.isel(time=window)
    for cell, colour in ls_of_cells.items():
        # Index by `cell`, not a fixed cell id, so each selection gets its own series.
        ts_ax.plot(
            times,
            ds.concentration.isel(nface=cell, time=window),
            color=colour,
        )
    # Draw the time marker once, not once per cell.
    ts_ax.axvline(date_value, linewidth=3, linestyle='dotted')

    fig.set_size_inches(20, 6)
    plt.show()

In [None]:
# Static matplotlib version of the map + time-series figure for cell 273.
conc_plot_plt(32400, {273:'red'})

## Reactivity

https://holoviews.org/reference/streams/bokeh/Tap.html

Start with the basic Tap-stream demo from the HoloViews reference linked above:

In [None]:
import panel as pn
import holoviews as hv

pn.extension()

# Minimal Tap-stream example: an empty Points element is the click target, and
# the stream's x/y (NaN until the first click) drive the text pane below.
points = hv.Points([])
stream = hv.streams.Tap(source=points, x=np.nan, y=np.nan)

@pn.depends(stream.param.x, stream.param.y)
def location(x, y):
    """Return a text pane showing the most recent click coordinates."""
    return pn.pane.Str(f'Click at {x:.2f}, {y:.2f}', width=200)

pn.Row(points, location)

Adapt the demo for this model: click a mesh cell on the concentration map to plot that cell's time series.

In [None]:
# `opts` is needed for the layout options at the end of this cell (it was never
# imported, so the final expression raised NameError); Point is for the hit-test.
from holoviews import opts
from shapely.geometry import Point

# Time step and colour limits of the clickable map.
time_index = 32400
mn_val = 13
mval = 25
date_value = ds.time.isel(time=time_index).values

ras_sub_df = gdf[gdf.datetime == date_value]
units = ds[CONCENTRATION].Units
ras_map = gv.Polygons(
    ras_sub_df,
    vdims=['concentration', 'cell']).opts(
        height = 800,
        width = 800,
        color='concentration',
        colorbar = True,
        cmap = 'OrRd',
        clim = (mn_val, mval),
        line_width = 0.1,
        tools = ['hover', 'tap'],
        clabel = f"Concentration ({units})"
)

# Clicks on the map update x/y. The initial values are only a placeholder until
# the first click. NOTE(review): y=0.00191 looks unlikely to fall inside the mesh.
tap_stream = hv.streams.Tap(source=ras_map, x=-86.99906, y=0.00191)


def tap_plot(x, y):
    """Return the concentration time series of the cell under the clicked point.

    Parameters
    ----------
    x, y : float
        Click coordinates from ``tap_stream``, in the same coordinate space as
        ``ras_sub_df``'s geometries.

    Returns
    -------
    hv.Curve
        The clicked cell's concentration over ``ds.time[18000:-100]``. If no
        cell lies under the point (e.g. the placeholder start position or a
        click outside the mesh), returns an empty placeholder curve instead of
        raising IndexError, which would break the DynamicMap.
    """
    # Test the clicked point directly rather than via `.cx[x:x, y:y]`, which
    # slices with a zero-area box.
    hits = (
        ras_sub_df[ras_sub_df.intersects(Point(x, y))]
        if x is not None and y is not None
        else ras_sub_df.iloc[:0]
    )
    if hits.empty:
        return hv.Curve([]).opts(
            title='Click a cell to plot its time series',
            height=800,
            width=800,
        )

    cell = int(hits['cell'].values[0])
    cs = ds.concentration.isel(
        nface=cell,
        time=slice(18000, -100),
    )
    mn = float(cs.min().values)
    mx = float(cs.max().values)

    return hv.Curve(cs).opts(
        ylim=(mn, mx),
        title=f'Time series for cell {cell}',
        height=800,
        width=800,
    )


tap_dmap = hv.DynamicMap(tap_plot, streams=[tap_stream])

# To serve as a Panel app instead: pn.Row(ras_map, tap_dmap).servable()
(ras_map + tap_dmap).opts(
    opts.Curve(framewise=True, yaxis='right', line_width=3)
)