# EIS Fire Visualizations Notebook 1 v1.1.1
### Version: 06.23.21
EIS – Fire Notebook 1 (Visualizations) creates interactive visualizations of IMERG and other raster data. The raster data are shown together with vector layers, with added functionality including (but not limited to) time-series averages and animations.

In [None]:
from datetime import datetime
from dask.distributed import Client
import io
import numpy as np
import pandas
import os
import requests
import random
import s3fs
import warnings
import xarray as xr
import zipfile

import cartopy.crs as ccrs
import colorcet as cc
import geopandas as gpd
import geoviews as gv
import holoviews as hv
import hvplot.pandas
import hvplot.xarray
from holoviews import opts, streams
from holoviews.plotting.links import DataLink
from ipyleaflet import Map, basemaps, basemap_to_tiles, DrawControl
import metpy.calc as mpc
import metpy
import panel as pn
import rioxarray as rxr
from shapely.geometry import mapping, Point

from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from sqlalchemy import create_engine, MetaData, Table, inspect, and_, between, distinct
from sqlalchemy import select, func, distinct, between, and_, or_, not_, Integer

# Notebook-wide configuration; these calls change global state for the whole session.
# Ask xarray to use its HTML display style for objects shown in cell output.
xr.set_options(display_style="html")
# Silence every warning for the rest of the session.
warnings.filterwarnings('ignore')
pn.extension()
# Enable the loading indicator on Panel's ParamMethod panes.
pn.param.ParamMethod.loading_indicator = True
# NOTE(review): pn.extension is called a second time, now with a default sizing mode;
# the bare call above looks redundant -- confirm before merging the two calls.
pn.extension(sizing_mode="stretch_width")

### Datasets

Add, modify, or remove Zarr datasets below, following the convention:
`"DATASET" : "path/to/zarr/file.zarr"`

In [None]:
# Dataset label -> location of its Zarr store ("bucket/key", resolved through S3 below).
# Insertion order is the order in which the stores are listed and opened.
raster_filepath = {
    "IMERG FWI": "dh-eis-fire-usw2-shared/imerg-fwi.zarr",
    "GEOS FP CONUS": "dh-eis-fire-usw2-shared/geos-fp-zarr/conus.zarr",
    "QFED": "dh-eis-fire-usw2-shared/qfed.zarr/",
    # NOTE(review): this entry ends in "/zarr" rather than ".zarr" like the others -- confirm the path.
    "GEOS FWI": "dh-eis-fire-usw2-shared/geos-fwi/zarr",
}

In [None]:
# TROPOMI DATA
# The TROPOMI stores live on the local file system under base_path; each entry
# below is (dataset label, store name relative to base_path).
base_path = '/home/jovyan/efs/eis-fire-tropomi'
tropomi_stores = (
    ('TROPOMI AEROSOL-INDEX', "tropomi_aer_ai.zarr"),
    ('TROPOMI CO', 'tropomi_co.zarr'),
)
# Dataset label -> full path of the Zarr store.
tropomi_filepath = {
    label: os.path.join(base_path, store) for label, store in tropomi_stores
}

## User-defined variables

Add, modify, or remove variables below. Please follow conventions outlined in comments.

In [None]:
# ---
# Variables from raster datasets to use, keyed by dataset label.
# Follow {'DATASET': ['var1', 'var2'],
#         'DATASET': ['var1', 'var2']}
# Every dataset label used here must also be a key of raster_filepath,
# because the variables are looked up in the store opened under that label.
# ---
raster_variables = {
    'IMERG FWI': ['IMERG.FINAL.v6_FWI'],
    'GEOS FP CONUS': ['T2M'],
    'QFED': ['co.biomass'],
    'GEOS FWI': ['GEOS-5_FWI'],
}


raster_variable_to_show = 'IMERG.FINAL.v6_FWI'  # 'T2M'

# If desired, slice time to date_start and date_end.
date_start = '2020-06-01'
date_end = '2020-11-30'
# datetime equivalents of the two strings above. They are derived from the
# strings so that editing date_start/date_end cannot leave them out of sync.
d_s = pandas.to_datetime(date_start).to_pydatetime()
d_e = pandas.to_datetime(date_end).to_pydatetime()


# See the cartopy.crs module for more projection options.
projection = ccrs.PlateCarree()

In [None]:
# Common iterables
# Time window applied to the datasets through .sel(time=tbounds).
tbounds = slice(date_start, date_end)
# Colours the time-series plots pick from.
line_color_iter = ['red', 'blue', 'green', 'black', 'indianred', 'grey', 'maroon', 'orange', 'gold', 'darkgreen', 'darkslategrey',
                   'steelblue', 'purple', 'crimson']
# Palette names, and the colorcet palette behind each name. cmaps is built from
# cmaps_str so the two cannot drift apart (the list used to be written out twice,
# once with a duplicated 'kbc').
cmaps_str = ['kbc', 'fire', 'bgy', 'bgyw', 'bmy', 'gray']
cmaps = {n: cc.palette[n] for n in cmaps_str}

# Valid data are masks 10-15, 20-25, 30-35
goesmasks = tuple(mask for start in (10, 20, 30) for mask in range(start, start + 6))

# Read in shape files
psa_shp = gpd.read_file('./shape_files/NIFC_PSA/National_Predictive_Service_Areas_(PSA)_Boundaries.shp')
fires = gpd.read_file('./shape_files/YANG_CA_FIRES/yang-ca-fires.json')

print('Loading in and merging all datasets')
# Map to S3 location
s3 = s3fs.S3FileSystem(anon=False)
dataset_path = {}
for dataset, path in raster_filepath.items():
    print(dataset, path)
    dataset_path[dataset] = s3.get_mapper(path)

# Open every store, first with consolidated metadata and, if that fails,
# without it.
zarrDict = {}
for mission, s3_file in dataset_path.items():
    print(mission, s3_file)
    try:
        zarrDict[mission] = xr.open_zarr(s3_file, consolidated=True)
    except Exception as err:
        # Only ordinary errors trigger the fallback; KeyboardInterrupt and
        # SystemExit are no longer swallowed. The reason is printed so a real
        # problem (credentials, wrong path) is visible instead of silent.
        print('Consolidated open failed for {}: {!r}; retrying with consolidated=False'.format(mission, err))
        zarrDict[mission] = xr.open_zarr(s3_file, consolidated=False)
# datasetSubs: dataset label -> {variable name -> array taken from the opened store}
# rasterAttrs: variable name -> that variable's attribute dict, tagged with its dataset label
datasetSubs = {}
rasterAttrs = {}
for mission, dataset in raster_variables.items():
    for subdataset in dataset:
        perMission = datasetSubs.setdefault(mission, {})
        selectedAttrs = zarrDict[mission][subdataset].attrs
        selectedAttrs['mission'] = mission
        rasterAttrs[subdataset] = selectedAttrs
        perMission[subdataset] = zarrDict[mission][subdataset]
        
#---
# Dataset:subdataset resampling to daily temporal resolution.
# T2M is processed first because its daily array is the target passed to
# interp_like for every other variable.
#---
geosFp = datasetSubs['GEOS FP CONUS']
geosFp['T2M'] = geosFp['T2M'].resample(time='1D').mean("time").sel(time=tbounds)
targetGrid = geosFp['T2M']

geosFwi = datasetSubs['GEOS FWI']
dailyGeosFwi = geosFwi['GEOS-5_FWI'].resample(time='1D').mean('time').sel(time=tbounds)
geosFwi['GEOS-5_FWI'] = dailyGeosFwi.interp_like(targetGrid).drop('forecast_lag')

imergFwi = datasetSubs['IMERG FWI']
imergFwi['IMERG.FINAL.v6_FWI'] = imergFwi['IMERG.FINAL.v6_FWI'].sel(time=tbounds).interp_like(targetGrid)

qfed = datasetSubs['QFED']
qfed['co.biomass'] = qfed['co.biomass'].sel(time=tbounds).interp_like(targetGrid)

# Merge every selected variable into one Dataset with a single xr.merge call.
# The previous loop merged the first array with itself and then re-merged the
# growing dataset once per variable.
arraysToMerge = []
for mission, dataDict in datasetSubs.items():
    for subd, array in dataDict.items():
        print(subd)
        arraysToMerge.append(array)
# mergedDataset stays None when nothing was selected, as before.
mergedDataset = xr.merge(arraysToMerge) if arraysToMerge else None
# Drop the reference to the per-dataset dictionaries; they are not used after this point.
datasetSubs = None

# ---
# Handling TROPOMI data
# ---
print('Opening and merging all TROPOMI data.')
tropomiDict = {}
for mission, path in tropomi_filepath.items():
    print(mission, path)
    tropomiDict[mission] = xr.open_zarr(path)

# Every variable of a store except lat/lon/time is treated as a data product.
# Filtering (instead of list.remove) means a store that lacks one of those
# names no longer raises ValueError.
coordNames = ('lat', 'lon', 'time')
tropomiVars = {
    mission: [name for name in store.variables if name not in coordNames]
    for mission, store in tropomiDict.items()
}

# Record each product's attributes (tagged with its dataset label) in
# rasterAttrs, then keep its daily mean restricted to tbounds.
tropomiSubs = {}
for mission, names in tropomiVars.items():
    for subdataset in names:
        product = tropomiDict[mission][subdataset]
        rasterAttrs[subdataset] = product.attrs
        rasterAttrs[subdataset]['mission'] = mission
        tropomiSubs.setdefault(mission, {})[subdataset] = (
            product.resample(time='1D').mean('time').sel(time=tbounds)
        )

# One merge over all products; the previous loop merged the first array with itself.
tropoArrays = [array for products in tropomiSubs.values() for array in products.values()]
# tropoDataset stays None when there is nothing to merge, as before.
tropoDataset = xr.merge(tropoArrays) if tropoArrays else None
tropomiSubs = None

### Functions to query and read shape layers and data
General helper functions included.

In [None]:
def constructDataDict(mergedDataset=None, attrsDict=None):
    """Construct a meta-data dictionary to accompany all dataset:variable combinations.

    Args:
        mergedDataset: Optional dataset. When truthy, every name in
            ``mergedDataset.variables`` after the first three is described
            using that variable's ``attrs``.
            NOTE(review): the ``[3:]`` slice presumably skips three
            coordinate variables -- confirm for every dataset passed in.
        attrsDict: Mapping of variable name to attribute dict. Used when
            ``mergedDataset`` is falsy.

    Returns:
        dict keyed by ``'<long_name>:<mission>'`` (the variable name stands
        in for a missing ``long_name``). Each value holds ``metadata``,
        ``units`` (``''`` when absent), ``collapse`` (True for 'co' and
        'co.biomass'), ``shortName``, ``tropomiSet`` and ``derived``
        (always False).

    Raises:
        ValueError: if ``mergedDataset`` is falsy and ``attrsDict`` is None.
        KeyError: if a variable's attributes have no ``'mission'`` entry.
    """
    useDataset = bool(mergedDataset)
    if useDataset:
        variableList = list(mergedDataset.variables)[3:]
    elif attrsDict is not None:
        variableList = list(attrsDict)
    else:
        raise ValueError('constructDataDict needs either mergedDataset or attrsDict')

    varListForSum = ('co', 'co.biomass')
    # Loop-invariant, so it is built once. It is only read from the
    # module-level tropoDataset when there is at least one variable to
    # describe, matching the original's lazy access.
    tropomiNames = set(list(tropoDataset.variables)[3:]) if variableList else set()

    dataDict = dict()
    for var in variableList:
        metadata = mergedDataset[var].attrs if useDataset else attrsDict[var]
        # 'mission' is mandatory; 'long_name' and 'units' have fallbacks.
        mission = metadata['mission']
        longName = '{}:{}'.format(metadata.get('long_name', var), mission)
        dataDict[longName] = {
            'metadata': metadata,
            'units': metadata.get('units', ''),
            'collapse': var in varListForSum,
            'shortName': var,
            'tropomiSet': var in tropomiNames,
            'derived': False
        }
    return dataDict

def clip_to_shape(raster, shape_file):
    """Clip ``raster`` to the geometries of ``shape_file`` and return the result.

    ``raster`` is modified in place before clipping: 'lon'/'lat' are declared
    as its x/y spatial dimensions and "epsg:4326" is written as its CRS. The
    geometries are passed to ``rio.clip`` together with ``shape_file.crs``,
    with ``drop=False``.
    """
    raster.rio.set_spatial_dims(x_dim="lon", y_dim="lat", inplace=True)
    raster.rio.write_crs("epsg:4326", inplace=True)
    geometries = shape_file.geometry.apply(mapping)
    return raster.rio.clip(geometries, shape_file.crs, drop=False)

def collapseTo1D(datarray):
    """Sum ``datarray`` over axis 1 twice and return the result.

    For an array whose axes are (time, lat, lon) this removes lat and then
    lon, leaving one summed value per time step.
    """
    return datarray.sum(axis=1).sum(axis=1)

def collapseMean(datarray):
    """Load ``datarray``, then average it over 'lat' and afterwards over 'lon'."""
    loaded = datarray.load()
    return loaded.mean(['lat']).mean(['lon'])

def plotTS(datarray, lat, lon):
    """Load ``datarray`` and select the grid cell nearest to ``lat``/``lon``.

    The selection goes through the array's ``.interactive`` accessor, and
    the resulting interactive object is returned to the caller.
    """
    loaded = datarray.load()
    return loaded.interactive.sel(lon=lon, lat=lat, method='nearest')

def updateTitle(lat, lon):
    """Return a Markdown pane naming the selected point as (lon, lat), rounded to 4 decimals."""
    latText, lonText = (str(round(value, 4)) for value in (lat, lon))
    heading = '## Point Selected: ({}, {})'.format(lonText, latText)
    return pn.pane.Markdown(heading, width=800)

def timeAveragedCallBack(target, event):
    """Callback: enable ``target`` only while the new value is 'Time Averaged'."""
    isTimeAveraged = event.new == 'Time Averaged'
    target.disabled = not isTimeAveraged

def sequentialDisable(target, event):
    """Callback: disable ``target`` while the new value is 'Time Averaged', enable it otherwise."""
    isTimeAveraged = event.new == 'Time Averaged'
    target.disabled = bool(isTimeAveraged)

def aeSitesCallback(target, event):
    """Callback for the Aeronet-related widgets: enable ``target`` only when the new value equals True."""
    switchedOn = event.new == True
    target.disabled = not switchedOn

### Construct metadata dictionary

In [None]:
%%time
# Describe every variable from the attributes collected while the stores were
# opened; the dictionary keys double as the options of the variable selector.
dataDict = constructDataDict(attrsDict=rasterAttrs)
totalVars = [*dataDict]

### Rechunk and persist data to memory

With derived datasets added, persisting the data to memory takes about 10 minutes; without derived datasets it takes about 2 minutes.

In [None]:
%%time
# Rechunk both datasets so the whole time axis of each variable sits in one chunk.
mergedDataset = mergedDataset.chunk({'time':-1})
tropoDataset = tropoDataset.chunk({'time':-1})
print('Peristing merged data')
mergedDataset = mergedDataset.persist()
print('Persisting derived data')
# NOTE(review): no derived variables are persisted in this cell, despite the message above.
# The four lines below build latLonGrid: the first time step of the TROPOMI
# 'aerosol_index_354_388' variable, renamed to 'pl' and multiplied by zero.
# visuaPanel plots it fully transparent as the base map, so do not comment these lines out.
latLonGrid = tropoDataset['aerosol_index_354_388'].isel(time=0)
latLonGrid = latLonGrid.rename('pl')
latLonGrid = latLonGrid * 0
latLonGrid = latLonGrid.persist()
print('Persisting TROPOMI data')
tropoDataset = tropoDataset.persist()

# Visualization 1 - Interactive time series

This is a basic interactive plot. The control panel on the left gives users many options, including choosing colormaps, which state to get data from, and more.

The controls on the right allow you to zoom, move around, and save the visual. Hover over the visual to see variables.

In [None]:
# Daily rasters versus rasters averaged over the time step entered below.
timeAveragedSequentialRadioGroup = pn.widgets.RadioButtonGroup(
    name='Visualization Type',
    options=['Daily', 'Time Averaged'],
    value='Daily',
    button_type='success',
)

# Basemap tiles drawn underneath the data.
mapLayers = pn.widgets.RadioButtonGroup(
    name='Base map layer',
    options=['Open Street Map', 'Satellite Imagery'],
    value='Satellite Imagery',
    button_type='success',
)

bufferInput = pn.widgets.FloatInput(name='Buffer (degrees)', value=0.1)

# The time-step box starts disabled; the link re-evaluates timeAveragedCallBack
# whenever the visualization type changes.
timeStepInput = pn.widgets.TextInput(name='Time Step', value='7D', disabled=True)
timeAveragedSequentialRadioGroup.link(
    timeStepInput,
    callbacks={'value': timeAveragedCallBack},
)

toggleCSVExport = pn.widgets.Button(
    name='Export to time series to CSV',
    button_type='success',
)

# List desired default variables here. Each entry is used as a key of dataDict
# by the dashboard, so it should be one of totalVars.
defaultVars = [
    'IMERG.FINAL.v6 Fire Weather Index:IMERG FWI',
    'Vertically integrated CO column:TROPOMI CO',
    'CO2 Biomass Emissions:QFED',
]
addTimeSeries = pn.widgets.MultiChoice(
    name='Time Series Variables',
    value=defaultVars,
    options=totalVars,
)

# Edit title and subtitles here
titleText = '# EIS FIRE - Mult-variable time series visualization'
subtitleText = (
    'This interactive dashboard displays raster data in an interactive format.'
    ' This raster data may be clipped to individual states via user-controls. '
    'In addition this dashboard displays various time-series of user-given data products. '
    'These time-series are averaged over individual state polygons. All TROPOMI data is unfiltered L2 gridded data'
    '\n <b>Make sure dashboard is done loading before making changed to left-side control bar</b>'
)
title = pn.pane.Markdown(titleText, width=600)
subtitle = pn.pane.Markdown(subtitleText)

In [None]:
# The function is re-run with the current widget values whenever one of the listed
# dependencies changes; values are passed by keyword, so the parameter order below
# does not have to match the order in the decorator.
@pn.depends(addTS=addTimeSeries.param.value, exportToCSV=toggleCSVExport.param.value, rasterType=timeAveragedSequentialRadioGroup.param.value,
            step=timeStepInput, mapLayer=mapLayers.param.value, buffer=bufferInput.param.value)
def visuaPanel(addTS, exportToCSV, rasterType, step, buffer, mapLayer):
    """Build the map card and the time-series card for the current widget values.

    Args:
        addTS: Selected time-series entries; each one is used as a key of the
            module-level ``dataDict``.
        exportToCSV: When truthy, every plotted series is also written to
            ``'<shortName>.csv'``.
        rasterType: ``'Time Averaged'`` turns on resampling by ``step``;
            any other value plots the data as stored.
        step: Value passed to ``resample(time=step)`` when time averaging is on.
        buffer: Accepted from the buffer widget but not used in this function.
        mapLayer: ``'Open Street Map'`` selects the ``'OSM'`` tiles, anything
            else selects ``'ESRI'``.

    Returns:
        A ``pn.Column`` holding the map card followed by the time-series card.
    """
    timeAve = True if rasterType=='Time Averaged' else False
    tile = 'OSM' if mapLayer=='Open Street Map' else 'ESRI' # Choose basemap style.
    # NOTE(review): hv.Points receives this tuple as (x values, y values), so the marker
    # starts at x=40, y=-110, outside the x range set on the next line. This looks like
    # (lat, lon) order -- confirm.
    data = ([40], [-110]) # Starting point selection.
    points = hv.Points(data).redim.range(x=(-130, -70), y=(0, 60)) # Limit to U.S.
    
    # Captures user-input and plots onto the map. No functionality attached. 
    points = points.opts(active_tools=['point_draw'], 
                         color='red',
                         size=10, 
                         marker='^')
    # point_stream is not read again in this function.
    point_stream = streams.PointDraw(data=points.columns(), num_objects=1, source=points) # Captures user input for marker.
    
    # Plotting basemap based off of empty xarray to give us a nice lat/lon grid with ESRI tiles.
    # latLonGrid is a module-level array; alpha=0 makes it invisible so only tiles and coastline show.
    baseMap = latLonGrid.hvplot(alpha=0, 
                                xlim=(-130, -70), 
                                ylim=(0, 60), 
                                frame_width=800, 
                                tiles=tile, 
                                frame_height=600, 
                                coastline=True, 
                                colorbar=False).opts(opts.Image(xlabel='Longitude',
                                                                ylabel='Latitude'))
    overlay = (baseMap * points)
    
    # Tap stream on the base map image; x/y below are the initial values (-120.8762, 37.6069)
    # used until the user taps the map.
    stream = hv.streams.Tap(source=baseMap.Image.I, x=-120.8762, y=37.6069) 
    
    # Plotting time series
    col = pn.Column(pn.Card(overlay))
    
    # Make a nice grid to populate with time series.
    rowDict= {}
    tsCard = pn.Card()
    # Heading bound to the tap stream's x/y parameters.
    tsCard.append(pn.bind(updateTitle, lat=stream.param.y, lon=stream.param.x))
    
    # Twenty empty rows, keyed 1..20, are created up front and filled below.
    for i in range(20):
        row = pn.Row()
        rowDict[i+1] = row
        tsCard.append(row)
    
    for i, ts in enumerate(addTS):
        # `2 // 2` is evaluated first, so this is (i + 2) // 2: two plots per row, rows
        # numbered from 1. With only 20 rows, a 41st series (i == 40) would raise KeyError.
        rowIdx = ((i+1) + 2 // 2) // 2 # Find which row to put the TS in. 
        # The line colour is drawn at random each time the panel is rebuilt.
        color=line_color_iter[random.randint(0, len(line_color_iter)-1)]
        if (not 'aod' in ts): # Make sure we aren't plotting any Aeronet data yet. 
            varShortName = dataDict[ts]['shortName'] # Variable short name from which to index with.
            tropo = dataDict[ts]['tropomiSet'] # A bool to let us know if it's tropomi data.
            derived = dataDict[ts]['derived'] # If it's a derived dataset. 
            
            if tropo:
                rasterTS = tropoDataset[varShortName]
            elif(derived):
                # NOTE(review): derivedDataset is not defined in the visible part of this
                # file. constructDataDict sets 'derived' to False for every entry, so this
                # branch is presumably never taken -- confirm before relying on it.
                rasterTS = derivedDataset[varShortName]
            else:
                rasterTS = mergedDataset[varShortName]
            
            rasterTS = rasterTS.resample(time=step).mean(dim='time') if timeAve else rasterTS # Resample to user-given time-step.
            ds = plotTS(rasterTS, lat=stream.param.y, lon=stream.param.x) # Dataset given a raster and lat, lon. 
            ylabel = '{} ({})'.format(varShortName, dataDict[ts]['units'])
            addTimeSeriesPlot = ds.hvplot('time', 
                                          color=color, 
                                          title='{} {}'.format(ts, dataDict[ts]['units']),
                                         ylabel=ylabel)
            
            # The file name is relative, so the CSV goes to the current working directory.
            if exportToCSV:
                ds.to_pandas().to_csv('{}.csv'.format(varShortName))
            
            rowDict[rowIdx].append(addTimeSeriesPlot.dmap())
    
    col.append(tsCard)
    return col

In [None]:
# Header: title with the subtitle underneath.
titleRow = pn.Row(pn.Column(title, subtitle))

# Collapsible section holding the basemap selector.
accordion = pn.Accordion(
    ('Layers', pn.Column(mapLayers)),
    header_background='#0059b3',
    header_color='white',
    active_header_background='#339966',
)

# Left-hand control bar, widgets listed top to bottom.
controlColumn = pn.Column(
    timeAveragedSequentialRadioGroup,
    accordion,
    bufferInput,
    timeStepInput,
    addTimeSeries,
    toggleCSVExport,
)
widgetBox = pn.WidgetBox(controlColumn, height=800)

# Controls on the left, the reactive visuaPanel output on the right.
bodyRow = pn.Row(widgetBox, visuaPanel)
dashboard = pn.Column(titleRow, bodyRow, background='WhiteSmoke')
print('Rendering')
# Last expression of the cell: displays the dashboard in the notebook.
dashboard