## PAVICS Ouranos - Custom climate portraits
### To begin :
1. Hide all code (Optional) : Click View -> Collapse all code
2. Click Run -> Restart kernel and Run all Cells

In [2]:
%matplotlib inline
import sys
import copy
import glob
import ipywidgets as widgets
import ipython_blocking
from ipywidgets import FloatSlider, BoundedFloatText, BoundedIntText,HBox, VBox, AppLayout
from IPython.display import HTML, Javascript, display, display_markdown
import os
import ipyleaflet
import geopandas as gpd
import pandas as pd
import xclim.subset as sub
import xclim.atmos as atmos
import xclim.ensembles as ens
import xarray as xr
import threddsclient as tds
from siphon.catalog import TDSCatalog as TDS
import numpy as np
from dask.diagnostics import ProgressBar
import tempfile
import shutil
#from bqplot import LinearScale, Axis, Lines, Figure, Toolbar, PanZoom
from bokeh.plotting import figure, output_file, output_notebook, show
from bokeh.io import push_notebook
from dask.distributed import Client
import warnings
warnings.filterwarnings('ignore')

p_wid = 900
p_height = 400


# Subfunctions
def write_netcdf(dsSub,outfile):
    """Write ``dsSub`` to ``outfile`` as netCDF via a scratch directory.

    The dataset is first written to ``out.nc`` inside a fresh temporary
    directory below ``nc_data/tmp`` and then moved to ``outfile``.  A dask
    distributed ``Client`` with four workers exists for the duration of the
    write and is closed afterwards, even when the write fails.

    Parameters
    ----------
    dsSub : object exposing ``to_netcdf(path)``
        Dataset to write (the callers in this file pass the result of
        ``sub.subset_bbox``).
    outfile : str
        Destination path of the final netCDF file.
    """
    scratch_root = os.path.join('nc_data', 'tmp')
    client = Client(n_workers=4)
    try:
        with ProgressBar():
            # Start from an empty scratch area; whatever is in there is a
            # leftover of an earlier write.
            if os.path.exists(scratch_root):
                shutil.rmtree(scratch_root)
            os.makedirs(scratch_root)
            tmpdir = tempfile.mkdtemp(dir=scratch_root)
            try:
                dsSub.to_netcdf(os.path.join(tmpdir, 'out.nc'))
                shutil.move(os.path.join(tmpdir, 'out.nc'), outfile)
            finally:
                # Remove the scratch directory whether or not the write worked.
                shutil.rmtree(tmpdir, ignore_errors=True)
    finally:
        # Previously the client was left running when the write raised.
        client.close()

def calc_tas(ds):
    """Add a ``tas`` entry to ``ds`` as the mean of ``tasmax`` and ``tasmin``.

    ``tas`` receives a copy of the ``tasmax`` attributes with ``long_name``
    and ``cell_methods`` replaced.  ``ds`` is modified in place and returned.
    """
    tmax = ds['tasmax']
    ds['tas'] = 0.5 * (tmax + ds['tasmin'])

    # Build the attribute set first, then attach it in one assignment.
    tas_attrs = dict(tmax.attrs)
    tas_attrs['long_name'] = 'Daily Mean Near-Surface Air Temperature'
    tas_attrs['cell_methods'] = 'time: mean'
    ds['tas'].attrs = tas_attrs
    return ds

def load_data(outfolder):
    """Open the previously extracted subset files of a region.

    Reads ``subset_Obs.nc`` and, for each RCP and each entry of the
    module-level ``mods`` list, ``subset_<model>_<rcp>.nc`` from
    ``outfolder``.  ``calc_tas`` is applied to every dataset and progress is
    reported on the module-level ``prog1`` widget.

    Parameters
    ----------
    outfolder : str
        Folder that holds the subset files.

    Returns
    -------
    dict
        ``{'Obs': dataset, 'rcp45': {model: dataset}, 'rcp85': {model: dataset}}``
    """
    dsCC = {}
    prog1.layout.visibility = 'visible'

    # Observations
    outfile = os.path.join(outfolder, 'subset_Obs.nc')
    dsCC['Obs'] = calc_tas(xr.open_dataset(outfile, chunks=dict(lat=10, lon=10)))

    prog1.value = 5
    # The label is built from prog1 itself; it used to read the value of the
    # unrelated ``prog`` widget.
    prog1.description = f'{prog1.value} % '

    # Scenario simulations
    rcps = ['rcp45', 'rcp85']
    i = 0
    for r in rcps:
        dsCC[r] = {}
        for m in mods:
            outfile = os.path.join(outfolder, f'subset_{m}_{r}.nc')
            dsCC[r][m] = calc_tas(xr.open_dataset(outfile))

            prog1.value = int((i+1)/((len(rcps)+1)*len(mods))*100)
            prog1.description = f'{prog1.value} % '
            i += 1

    prog1.value = 100
    prog1.description = f'{prog1.value} % '
    prog1.layout.visibility = 'hidden'
    return dsCC


def extract_data(outfolder,bounds):
    """Subset observations and simulations to a bounding box and open them.

    The bounds are saved as ``<new region name>_bounds.csv`` in ``outfolder``.
    For the observation dataset at the module-level ``url`` and for each
    model in ``mods`` / each RCP found by crawling ``baseurl_CC``, a subset
    file is written with ``write_netcdf`` unless it already exists.  Every
    subset file is then opened and passed through ``calc_tas``.

    Parameters
    ----------
    outfolder : str
        Folder receiving the subset files; created when missing.
    bounds : mapping
        Must provide ``minx``, ``maxx``, ``miny`` and ``maxy`` entries that
        ``float()`` can convert.

    Returns
    -------
    dict
        ``{'Obs': dataset, 'rcp45': {model: dataset}, 'rcp85': {model: dataset}}``
    """
    if not os.path.exists(outfolder):
        os.makedirs(outfolder)
    pd.DataFrame(bounds).to_csv(os.path.join(outfolder,f'{newregname.value}_bounds.csv'))
    lon_bnds = (float(bounds['minx']), float(bounds['maxx']))
    lat_bnds = (float(bounds['miny']), float(bounds['maxy']))
    print([lon_bnds, lat_bnds])
    dsCC = {}

    # Observations
    outfile = os.path.join(outfolder,'subset_Obs.nc')
    if not os.path.exists(outfile):
        print('writing obs to subsetted .nc file')
        with xr.open_dataset(url,chunks=dict(time=31, lat=102, lon=267), drop_variables=['ts','time_vectors'],) as dsObs:
            write_netcdf(sub.subset_bbox(dsObs, lon_bnds = lon_bnds, lat_bnds = lat_bnds), outfile)
    else:
        print('subsetted netcdf file already exists' )

    dsCC['Obs'] = calc_tas(xr.open_dataset(outfile))

    # Scenario simulations
    rcps=['rcp45', 'rcp85']
    for r in rcps:
        dsCC[r] = {}
        for m in mods:
            # Fix the target path before looking at the catalogue.  It used to
            # be assigned inside the loop below, so an empty crawl result made
            # the open_dataset call further down silently reuse the file of
            # the previous model (or the observations).
            outfile = os.path.join(outfolder,f'subset_{m}_{r}.nc')
            datasets = [ds.opendap_url() for ds in tds.crawl(baseurl_CC.replace('catalog.html', f'{m}/catalog.html'), depth=10) if r in ds.name and '.ncml' in ds.name]
            for d in datasets:
                if not os.path.exists(outfile):
                    print(f'subsetting {m} {r} to new .nc file')
                    with xr.open_dataset(d,chunks=dict(time=256, lat=16*3, lon=16*3),drop_variables=['ts','time_vectors']) as ds1:
                        write_netcdf(sub.subset_bbox(ds1, lon_bnds = lon_bnds, lat_bnds = lat_bnds), outfile)
                else:
                    print('subsetted netcdf file already exists' )

            dsCC[r][m] = calc_tas(xr.open_dataset(outfile))
    return dsCC

def get_rectangle():
    """Display a map on which the user draws the rectangle of a new region.

    Returns
    -------
    tuple of (dict, dict)
        ``rectangle`` and ``bounds``.  Both are empty when the function
        returns; the draw callback fills them in place once a shape has been
        created (``rectangle`` with the drawn GeoJSON, ``bounds`` with the
        bounds table computed by geopandas).
    """
    rectangle = {}
    bounds = {}
    output = widgets.Output(layout={'border': '1px solid black'})

    leaflet_map = ipyleaflet.Map(
        center=canada_center,
        basemap=ipyleaflet.basemaps.CartoDB.Positron,
        zoom=8,
    )

    drawer = ipyleaflet.DrawControl()
    # Disable some drawing inputs by giving them empty option dicts.
    for tool in ('polyline', 'circlemarker', 'polygon'):
        setattr(drawer, tool, {})
    drawer.rectangle = {
        "shapeOptions": dict(fillColor="#4ae", color="#4ae", fillOpacity=0.3)
    }

    def on_draw_event(control, action, geo_json):
        # Parameter names are kept as in the original callback in case the
        # draw control passes them by keyword.
        if action != "created":
            return
        # Note from the original author: the map cannot be closed or removed
        # from the output inside this callback, because it keeps the focus
        # and the Jupyter keyboard input gets messed up.
        with output:
            print("*User selected 1 rectangle*")
            rectangle.update(geo_json)
            bounds.update(gpd.GeoDataFrame.from_features([rectangle]).bounds)

    drawer.on_draw(on_draw_event)
    leaflet_map.add_control(drawer)

    with output:
        print("Select a rectangle:")
        display(leaflet_map)

    display(output)

    return rectangle, bounds

def select_region(b):
    """Show the new-region name field only while 'New Region' is selected.

    Observer callback of ``RegSelect``; the argument ``b`` is not used.
    """
    if RegSelect.value == 'New Region':
        state = ('visible', 'auto')
    else:
        state = ('hidden', '0')
    # Assigned left to right: visibility first, then width.
    newregname.layout.visibility, newregname.layout.width = state


def create_map():
    """Build a map showing the stored bounding box of the current region.

    Reads the first ``*_bounds.csv`` found in ``./nc_data/<region_name>``
    (``region_name`` is a module-level variable) and draws its first row as
    a rectangle on a map.

    Returns
    -------
    ipyleaflet.Map
        Map with the rectangle added as a layer.

    Raises
    ------
    FileNotFoundError
        If the region folder contains no bounds file.
    """
    pattern = os.path.join('.','nc_data',region_name,'*_bounds.csv')
    matches = glob.glob(pattern)
    if not matches:
        # Used to surface as a bare IndexError from ``glob(...)[0]``.
        raise FileNotFoundError(f'no bounds file matching {pattern}')

    # Take scalars from the first row instead of calling float() on whole
    # column arrays.
    row = pd.read_csv(matches[0]).iloc[0]
    rect_bnds = (
        (float(row['miny']), float(row['minx'])),
        (float(row['maxy']), float(row['maxx'])),
    )

    m1 = ipyleaflet.Map(
        center=canada_center,
        basemap=ipyleaflet.basemaps.CartoDB.Positron,
        zoom=8,
    )
    m1.add_layer(ipyleaflet.Rectangle(bounds=rect_bnds))
    return m1

def download_offline(b):
    """Save the current figure as a standalone HTML file and show it.

    Click handler of the "Html" button; ``b`` is not used.  The file is
    placed in ``offline_figures`` and named after the figure title with a
    few characters removed or replaced.
    """
    if not os.path.exists('./offline_figures'):
        os.makedirs('./offline_figures')

    # Same substitutions, in the same order, as one chained replace call.
    stem = fig.title.text
    for old, new in ((' : ', ''), ('(', ''), (')', ''), ('/', ' per ')):
        stem = stem.replace(old, new)

    outhtml = os.path.join('offline_figures', f"{stem}.html")
    output_file(outhtml)
    show(fig)
    

def create_args(ds):
    """Assemble the keyword arguments for the currently selected index.

    Reads the selection from the module-level widgets and the ``index``
    table.

    Parameters
    ----------
    ds : mapping
        Source of the input variables listed under ``vars`` in the index
        definition.

    Returns
    -------
    dict
        The input variables, ``freq`` when the season is a known one, and
        whichever of ``thresh``, ``thresh_tasmax``, ``thresh_tasmin`` and
        ``window`` the index definition declares.
    """
    spec = index[TypeList.value][Index_list.value]['args']

    argin = {name: ds[name] for name in spec['vars']}

    # Resampling frequency per season choice.
    freq_by_season = {
        'Yearly': 'YS',
        'Winter': 'QS-DEC',
        'Spring': 'QS-DEC',
        'Summer': 'QS-DEC',
        'Fall': 'QS-DEC',
        'Apr-Sept': '6MS-APR',
    }
    seas = seaslist.value
    if seas in freq_by_season:
        argin['freq'] = freq_by_season[seas]

    # Thresholds are passed as "<value> <units>" strings.
    if 'thresh' in spec:
        argin['thresh'] = f"{threshlist.value} {spec['thresh']['units']}"
    if 'threshTmax' in spec:
        argin['thresh_tasmax'] = f"{threshTmaxlist.value} {spec['threshTmax']['units']}"
    if 'threshTmin' in spec:
        argin['thresh_tasmin'] = f"{threshTminlist.value} {spec['threshTmin']['units']}"
    if 'window' in spec:
        argin['window'] = windowlist.value

    return argin

def update_index_list(change):
    """Refresh the index and season dropdowns after a category change.

    Observer callback; ``change`` is not used.
    """
    category = index[TypeList.value]
    Index_list.options = list(category)
    # Index_list.value is read only after its options were replaced.
    seaslist.options = category[Index_list.value]['args']['seasons']
    

def calc_index(ds, func_str ,out_str ,to_netcdf=False):
    """Compute one climate index, or load it from its cached netCDF file.

    The cache path is
    ``<outfolder>/ClimateIndicatorOutput/<tag>/<freq>/<region>_<tag>_<out_str>.nc``
    where ``<tag>`` is the indicator identifier followed by every argument
    that is not a DataArray (``freq`` excluded).

    NOTE: units such as 'mm/day' in the index table end up inside ``<tag>``,
    so the path gains extra directory levels wherever '/' is a path
    separator.

    Parameters
    ----------
    ds : dataset
        Input data handed to ``create_args``.
    func_str : str
        Name of the indicator looked up on ``atmos``.
    out_str : str
        Suffix identifying the data source in the file name.
    to_netcdf : bool, optional
        Write a freshly computed result to the cache path.

    Returns
    -------
    dataset
        The cached or newly computed index.
    """
    argsin = create_args(ds)
    func = getattr(atmos, func_str)
    var_str = func.identifier

    # Every argument but the frequency, in alphabetical order.
    arg_names = sorted(argsin)
    arg_names.remove('freq')
    for name in arg_names:
        value = argsin[name]
        if not isinstance(value, xr.DataArray):
            var_str = f'{var_str}_{name}_{value}'

    reg_str = RegSelect.value.replace(' ','_')
    outfile = os.path.join(
        outfolder,
        'ClimateIndicatorOutput',
        var_str,
        argsin['freq'],
        f'{reg_str}_{var_str}_{out_str}.nc',
    )

    if os.path.exists(outfile):
        # Cached result: open undecoded, then replace the time coordinate by
        # its decoded version.
        out = xr.open_dataset(outfile, decode_times=False)
        out['time'] = xr.decode_cf(out).time
        return out

    out = func(**argsin).to_dataset()
    out.attrs=ds.attrs
    if not to_netcdf:
        return out

    target_dir = os.path.dirname(outfile)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # set compression level 9=highest to save space on jupyter instance
    comp = dict(zlib=True, complevel=9)
    encoding = dict.fromkeys(out.data_vars, comp)
    encoding["time"] = {"dtype": "single"}  # Opendap wants floats in the time variable
    out.to_netcdf(outfile,  format="NETCDF4", encoding=encoding)

    return out


def _sync_bounded_input(widget, args, key, has_step=True):
    """Show ``widget`` configured from ``args[key]``, or hide it.

    Returns True when ``key`` is declared in ``args`` (widget shown) and
    False otherwise (widget hidden and its limits widened).
    """
    if key not in args:
        widget.layout.width = '0'
        widget.min = -999999
        widget.max = 9999999
        widget.layout.visibility = 'hidden'
        return False

    spec = args[key]
    widget.layout.width = 'auto'
    widget.min = spec['min']
    widget.max = spec['max']
    if has_step:
        widget.step = spec['step']
    # Fall back to the declared default unless the current value lies
    # strictly between the new limits.
    if not (widget.min < widget.value < widget.max):
        widget.value = spec['value']
    widget.layout.visibility = 'visible'
    return True


def update_graph(change):
    """Show, hide and re-range the parameter inputs for the selected index.

    Observer callback; ``change`` is not used.  Each of the window, thresh,
    threshTmax and threshTmin inputs is shown when the index definition
    declares the matching argument and hidden otherwise.
    """
    args = index[TypeList.value][Index_list.value]['args']

    # The window input has no step setting.
    _sync_bounded_input(windowlist, args, 'window', has_step=False)
    _sync_bounded_input(threshlist, args, 'thresh')

    # The generic threshold input is hidden whenever one of the Tmax/Tmin
    # specific inputs is shown.
    if _sync_bounded_input(threshTmaxlist, args, 'threshTmax'):
        threshlist.layout.visibility = 'hidden'
    if _sync_bounded_input(threshTminlist, args, 'threshTmin'):
        threshlist.layout.visibility = 'hidden'

def create_trace(out,trendflag):
    """Build the x/y series of one plot trace from an index result.

    The data is restricted to the season chosen in ``seaslist`` (all time
    steps for 'Yearly') and averaged over ``lon`` and ``lat``.

    Parameters
    ----------
    out : xarray object with ``time``, ``lon`` and ``lat``
        Index result to plot.
    trendflag : bool
        When true, ``y`` is replaced by a cubic polynomial fitted to the
        non-NaN values and evaluated at every ``x``.

    Returns
    -------
    dict
        ``{'x': years, 'y': values}``

    Raises
    ------
    ValueError
        If the selected season is neither 'Yearly' nor one of the four
        calendar seasons (this used to end in an UnboundLocalError).
    """
    season_codes = {'Winter': 'DJF', 'Spring': 'MAM', 'Summer': 'JJA', 'Fall': 'SON'}
    season = seaslist.value

    if season == 'Yearly':
        selected = out
    elif season in season_codes:
        selected = out.sel(time=(out.time.dt.season == season_codes[season]))
    else:
        raise ValueError(f'create_trace does not support season {season!r}')

    x = selected.time.dt.year
    y = selected.mean(dim=['lon','lat'])

    if trendflag:
        valid = np.where(~np.isnan(y))
        poly = np.polyfit(x[valid], y[valid], 3)
        y = xr.DataArray(np.polyval(poly, x))
    return dict(x=x, y=y)
        
def poly_trend(x,y):
    """Return a polynomial trend of ``y`` over ``x``, evaluated at every ``x``.

    Polynomials of degree 4, 3, 2 and 1 are tried in that order on the
    non-NaN samples of ``y``; the first fit that does not raise is used.

    Parameters
    ----------
    x, y : array-like
        Sample positions and values; both must support indexing with the
        result of ``np.where``.

    Returns
    -------
    xr.DataArray or None
        The fitted values, or ``None`` when no degree could be fitted.
    """
    for degree in range(4, 0, -1):
        try:
            valid = np.where(~np.isnan(y))
            coeffs = np.polyfit(x[valid], y[valid], degree)
            trend = xr.DataArray(np.polyval(coeffs, x))
        except Exception:
            # Best effort: fall back to the next lower degree.  Unlike the
            # former bare ``except:`` this lets KeyboardInterrupt and
            # SystemExit through.
            continue
        return trend
    return None

def clear_graph(change):
    """Blank every trace and ask the user to recalculate.

    Observer callback; ``change`` is not used.
    """
    fig.title.text = 'Select options and click "Calculate" button ...'
    for renderer in traces.values():
        source = renderer.data_source.data
        # Adding NaN turns every y value into NaN.
        source['y'] = source['y'] + np.nan
    push_notebook()
        
def recalculate_vals(change):
    """Recompute the selected climate index and refresh the figure.

    Click handler of the "Calculate" button (also called once directly);
    ``change`` is not used.  For the observations and for each RCP ensemble
    in ``dsCC`` the index is computed with ``calc_index``, reduced to a
    regional mean series for the selected season, and written into the
    matching traces.  Ensemble traces show polynomial trends
    (``poly_trend``) instead of the raw series.

    Raises
    ------
    NotImplementedError
        If the selected season is neither 'Yearly' nor one of the four
        calendar seasons.  Such a choice ('Apr-Sept') used to run the whole
        computation and then fail with a KeyError, because no series was
        ever extracted for it.
    """
    season_codes = {'Winter': 'DJF', 'Spring': 'MAM', 'Summer': 'JJA', 'Fall': 'SON'}
    season = seaslist.value
    if season != 'Yearly' and season not in season_codes:
        raise NotImplementedError(f'plotting season {season!r} is not supported')
    seas1 = season_codes.get(season)  # None for 'Yearly'

    func_str = index[TypeList.value][Index_list.value]['func']
    prog.value = 0
    prog.layout.visibility = 'visible'
    fig.title.text = 'Calculating climate index ...'
    for t in traces.keys():
        traces[t].data_source.data['y'] = traces[t].data_source.data['y'] + np.nan
    push_notebook()

    out = {}
    x = {}
    y = {}
    y_upper = {}
    y_lower = {}
    prog.value = 5
    prog.description = f'{prog.value} % '
    for i, r in enumerate(dsCC.keys()):
        if r == 'Obs':
            out[r] = calc_index(dsCC[r], func_str, r, to_netcdf=save_index.value)
        else:
            members = [
                calc_index(dsCC[r][m], func_str, f'{m}_{r}', to_netcdf=save_index.value)
                for m in dsCC[r].keys()
            ]
            out[r] = ens.ensemble_mean_std_max_min(ens.create_ensemble(members))
        var_names = list(out[r].data_vars)

        # Restrict to the chosen season once, then derive all series from it.
        if seas1 is None:
            selected = out[r]
        else:
            selected = out[r].sel(time=(out[r].time.dt.season == seas1))

        x[r] = selected.time.dt.year.values
        y[r] = selected[var_names[0]].mean(dim=['lon','lat']).values
        if r != 'Obs':
            # Third and fourth ensemble statistics serve as upper/lower
            # bound (presumably max and min, going by the name of
            # ensemble_mean_std_max_min; confirm against xclim).
            y_upper[r] = selected[var_names[2]].mean(dim=['lon','lat']).values
            y_lower[r] = selected[var_names[3]].mean(dim=['lon','lat']).values

        prog.value = int((i+1)/len(dsCC.keys())*100)
        prog.description = f'{prog.value} % '

    for t in list(traces.keys()):
        r = t.split('_')[0]
        if 'Obs' in t:
            traces[t].data_source.data['y'] = y[r]
            traces[t].data_source.data['x'] = x[r]
            continue

        upper = poly_trend(x[r], y_upper[r])
        lower = poly_trend(x[r], y_lower[r])
        if 'bounds' in t:
            # Closed outline: upper trend forwards, lower trend backwards.
            x_all = np.append(x[r], np.flip(x[r]))
            y_all = np.append(upper, np.flip(lower))
        else:
            # Three NaN-separated segments: upper, mean and lower trend.
            middle = poly_trend(x[r], y[r])
            x_all = np.tile(np.append(x[r], np.nan), [1, 3]).flatten()
            y_all = np.asarray((np.append(upper, np.nan), np.append(middle, np.nan), np.append(lower, np.nan))).flatten()

        traces[t].data_source.data['y'] = y_all
        traces[t].data_source.data['x'] = x_all

    # Title and axis labels come from the first rcp45 variable.
    labelled = out['rcp45'][var_names[0]]
    heading = labelled.description.replace('Seasonal', f'{season} ({seas1})').split(':')[0]
    fig.title.text = f"{region_name} : {heading}"
    fig.xaxis.axis_label = 'year'
    fig.yaxis.axis_label = labelled.units
    prog.layout.visibility = 'hidden'
    push_notebook()



# Initial map centre as a (lat, lon) pair.
canada_center = (45.4292, -73.2959)
os.makedirs('nc_data', exist_ok=True)
dd = widgets.Button(description='Continue')

# Every sub-folder of nc_data is a region, except the 'tmp' scratch folder
# used by write_netcdf and Jupyter's checkpoint folder.  The names are
# compared exactly: a substring test on 'tmp' would hide any region whose
# name merely contains those letters.
options = sorted(
    o for o in os.listdir('./nc_data')
    if os.path.isdir(os.path.join('./nc_data', o))
    and o not in ('tmp', '.ipynb_checkpoints')
)
working_list = ['New Region']
working_list.extend(options)

RegSelect = widgets.Dropdown(
    description='Select region or create new',
    value= None,
    options=working_list
)
newregname = widgets.Text(
                    value='',
                    placeholder='Enter region name',
                    description='new region:',
                    disabled=False)

# The name field stays collapsed until 'New Region' is chosen (see select_region).
newregname.layout.visibility= 'hidden'
newregname.layout.width = '0'
RegSelect.observe(select_region)
box = widgets.HBox(children =[RegSelect, newregname, dd])

# Progress bar updated by load_data.
prog1 = widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    step=0.1,
    description='',
    bar_style='info',
    orientation='horizontal')

# Progress bar updated by recalculate_vals.
prog =widgets.FloatProgress(
    value=0,
    min=0,
    max=100,
    step=0.1,
    description='Loading : ',
    bar_style='info',
    orientation='horizontal')

display_markdown(
      '### Select an existing region or create a new one', raw=True)

display(box)

### Select an existing region or create a new one

HBox(children=(Dropdown(description='Select region or create new', options=('New Region', 'Montreal', 'TestReg…

In [None]:
%block dd

In [4]:
# cell 3
# Won't actually be executed until the user chooses an option in the dd widget
if RegSelect.value is None:
    # The dropdown starts without a selection; continuing would leave
    # region_name as None and fail later while building file paths.
    raise Exception('No region selected: choose an existing region or "New Region"')

if RegSelect.value == 'New Region':
    region_name = newregname.value
    newreg_flag = True
    if not region_name.strip():
        # An empty name would make the region folder nc_data itself.
        raise Exception('Please enter a name for the new region')
    if 'tmp' == region_name:
        raise Exception('"tmp" is a reserved subfolder name please use a different name')
else:
    region_name = RegSelect.value
    newreg_flag = False

dbounds = widgets.Button(description='Load Data')
m1 = None

if not newreg_flag:
    # Existing region: show its stored rectangle, centred on the map.
    m1 = create_map()
    display_markdown(
      '### Click "load data" to continue', raw=True)
    display(m1)
    dbounds.description=f'Load Data : {region_name}'
    m1.center = list(np.vstack(m1.layers[1].bounds).mean(axis=0))
else:
    # New region: let the user draw its rectangle.
    display_markdown(
      '### Draw rectangle and click "extract data" to continue', raw=True)
    rectangle, bounds = get_rectangle()
    dbounds.description=f'Extract Data : {region_name}'


display(HBox([dbounds,prog1]))


### Click "load data" to continue

Map(basemap={'url': 'http://c.basemaps.cartocdn.com/light_all/{z}/{x}/{y}.png', 'max_zoom': 20, 'attribution':…

HBox(children=(Button(description='Load Data : Montreal', style=ButtonStyle()), FloatProgress(value=0.0, bar_s…

In [6]:
%blockrun dbounds

In [7]:
# Folder of the selected region and the remote data sources.
outfolder = os.path.join('.', 'nc_data', region_name)
url = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/dodsC/birdhouse/nrcan/nrcan_canada_daily/nrcan_canada_daily-allvars-agg.ncml'
baseurl_CC = 'https://pavics.ouranos.ca/twitcher/ows/proxy/thredds/catalog/birdhouse/testdata/ouranos/cb-oura-1.0_rechunk/catalog.html'

# One entry per catalogue reference below the scenario catalogue.
mods = list(TDS(baseurl_CC).catalog_refs)

if not newreg_flag:
    dsCC = load_data(outfolder)
else:
    dsCC = extract_data(outfolder, bounds)
    # A new region has no map yet; build it from the bounds file just written.
    if m1 is None:
        m1 = create_map()

# Centre the map on the region rectangle and retire the load/extract button.
m1.center = list(np.vstack(m1.layers[1].bounds).mean(axis=0))
dbounds.layout.visibility = 'hidden'
# Table of the climate indices offered in the interface:
#   index[category][label] = {'func': <name looked up on atmos>, 'args': {...}}
# 'args' lists the input variables, the selectable seasons and, where the
# index takes them, the settings of its threshold / window inputs.
index = {}

# Temperature indices
index['Temperature'] = {}

args = {
    'vars': {'tas': 'tas'},
    'seasons': ['Yearly', 'Winter', 'Spring', 'Summer', 'Fall'],
}
index['Temperature']['Average Temperature'] = {'func': 'tg_mean', 'args': args}

args = {
    'vars': {'tasmax': 'tasmax'},
    'seasons': ['Yearly', 'Winter', 'Spring', 'Summer', 'Fall'],
    'thresh': {'value': 25, 'min': -40, 'max': 40, 'step': 1, 'units': 'degC'},
}
index['Temperature']['Days w/Tmax > value'] = {'func': 'tx_days_above', 'args': args}

args = {
    'vars': {'tas': 'tas'},
    'seasons': ['Yearly'],
    'thresh': {'value': 5, 'min': -10, 'max': 20, 'step': 1, 'units': 'degC'},
}
index['Temperature']['Growing degree days > value'] = {'func': 'growing_degree_days', 'args': args}

args = {
    'vars': {'tasmax': 'tasmax', 'tasmin': 'tasmin'},
    'seasons': ['Yearly', 'Winter', 'Spring', 'Summer', 'Fall'],
    'threshTmax': {'value': 0, 'min': 0, 'max': 10, 'step': 1, 'units': 'degC'},
    'threshTmin': {'value': 0, 'min': -10, 'max': 0, 'step': 1, 'units': 'degC'},
}
index['Temperature']['Freeze-Thaw events'] = {'func': 'daily_freezethaw_cycles', 'args': args}


# Precipitation indices
index['Precipitation'] = {}

args = {
    'vars': {'pr': 'pr'},
    'seasons': ['Yearly', 'Winter', 'Spring', 'Summer', 'Fall'],
}
index['Precipitation']['Total Precipitation'] = {'func': 'precip_accumulation', 'args': args}

args = {
    'vars': {'pr': 'pr'},
    'seasons': ['Yearly', 'Winter', 'Spring', 'Summer', 'Fall'],
    'thresh': {'value': 1, 'min': 1, 'max': 25, 'step': 1, 'units': 'mm/day'},
}
index['Precipitation']['Wetdays > value'] = {'func': 'wetdays', 'args': args}

args = {
    'vars': {'pr': 'pr'},
    'seasons': ['Yearly', 'Apr-Sept', 'Winter', 'Spring', 'Summer', 'Fall'],
    'window': {'value': 1, 'min': 1, 'max': 20, 'step': 1},
}
index['Precipitation']['Max n-day precipitation'] = {'func': 'max_n_day_precipitation_amount', 'args': args}


# Button that exports the figure to a standalone HTML file (download_offline).
html_download = widgets.Button(
    description='Html',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Download Offline',
    icon='download'
)

# Category -> climate index -> season selection, all driven by the index table.
TypeList = widgets.Dropdown(
    description='Category',
    value='Temperature',
    options=list(index)
)

TypeList.layout.width = 'auto'

Index_list = widgets.Dropdown(
    description='Climate Index',
    value=list(index[TypeList.value])[0],
    options=list(index[TypeList.value])
)

seaslist = widgets.Dropdown(
    description='Season',
    value= index[TypeList.value][Index_list.value]['args']['seasons'][0],
    options=index[TypeList.value][Index_list.value]['args']['seasons']
)


# Optional numeric inputs.  They start with wide limits; update_graph sets
# the real limits when the selected index declares the matching argument.
windowlist = BoundedIntText(
    value=1,
    min=-99999,
    max=99999,
    step=0,
    description='Days'
)

threshlist = BoundedFloatText(
    value=1,
    min=-99999,
    max=99999,
    step=0,
    description='Thresh'
)

threshTmaxlist = BoundedFloatText(
    value=0,
    min=-99999,
    max=99999,
    step=0,
    description='Thresh Tmax'
)


threshTminlist = BoundedFloatText(
    value=0,
    min=-99999,
    max=99999,
    step=0,
    description='Thresh Tmin'
)

# Button that triggers recalculate_vals.  The icon name is given without the
# 'fa-' prefix, like the 'download' icon above.
calc_button = widgets.Button(
    description='Calculate',
    disabled=False,
    button_style='', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Calculate climate index',
    icon='refresh'
)

# When ticked, computed indices are also written to netCDF (see calc_index).
save_index = widgets.Checkbox(
    value=False,
    indent=False,
    description='save to netcdf',
    disabled=False
)

# Initial state: progress bar and all optional inputs collapsed and hidden.
prog.layout.visibility = 'hidden'
windowlist.layout.width = '0'
windowlist.layout.visibility = 'hidden'
threshlist.layout.width = '0'
threshlist.layout.visibility = 'hidden'
threshTmaxlist.layout.width = '0'
threshTmaxlist.layout.visibility = 'hidden'
threshTminlist.layout.width = '0'
threshTminlist.layout.visibility = 'hidden'

# Observers, grouped per widget.  Within one widget the registration order
# is the original one: lists first, then inputs, then clearing the graph.
TypeList.observe(update_index_list, names='value')
TypeList.observe(update_graph, names='value')
TypeList.observe(clear_graph, names='value')

Index_list.observe(update_graph, names='value')
Index_list.observe(clear_graph, names='value')

# Any change of a parameter only blanks the graph until "Calculate" is clicked.
seaslist.observe(clear_graph, names='value')
threshlist.observe(clear_graph, names='value')
threshTmaxlist.observe(clear_graph, names='value')
threshTminlist.observe(clear_graph, names='value')
windowlist.observe(clear_graph, names='value')

# Buttons
html_download.on_click(download_offline)
calc_button.on_click(recalculate_vals)



from ipywidgets import TwoByTwoLayout
#

# Figure initialisation
from bokeh.models import HoverTool, CrosshairTool, CustomJS
from bokeh.layouts import row
from bokeh import events
fig = figure(plot_width=p_wid, plot_height=p_height, title='Test', toolbar_location="below")
output_notebook()
func_str = index[TypeList.value][Index_list.value]['func']

# Create one renderer per data source so that recalculate_vals only has to
# swap the data afterwards:
#   'Obs'           line of the observations
#   '<rcp>'         line of the ensemble mean
#   '<rcp>_bounds'  patch between the upper and lower ensemble statistic
out = {}
traces = {}
for r in dsCC.keys():

    if r == 'Obs':
        out[r] = calc_index(dsCC[r], func_str, r)
        trace = create_trace(out[r], trendflag=False)

        # NOTE(review): x is passed for both coordinates here.  The values
        # only act as a placeholder, since recalculate_vals assigns x and y
        # of every trace.
        traces[r]= fig.line(trace['x'].values,trace['x'].values,line_width=1.5, color="#54565B", legend_label=r, name=r)
    else:
        if r == 'rcp45':
            col1 = '#476EF3'
        else:
            col1 = '#DC300B'

        out[r] = []
        for m in dsCC[r].keys():
            out_str = f'{m}_{r}'
            out[r].append(calc_index(dsCC[r][m],func_str, out_str))
        out[r] = ens.ensemble_mean_std_max_min(ens.create_ensemble(out[r]))
        vars = list(out[r].data_vars)
        trace = create_trace(out[r][vars[0]], trendflag=True)
        trace_upper = create_trace(out[r][vars[2]],  trendflag=True)
        trace_lower = create_trace(out[r][vars[3]],  trendflag=True)
        traces[r] = fig.line(trace['x'].values,trace['y'].values, line_width=1.5,color=col1,  legend_label=r, name=r)

        # Closed outline for the patch: upper trend forwards, lower trend
        # backwards.  (Two earlier x_all/y_all variants that were overwritten
        # before use have been removed.)
        x_all =  np.append(trace_upper['x'].values,np.flip(trace_lower['x'].values))
        y_all =  np.append(trace_upper['y'].values,np.flip(trace_lower['y'].values))
        traces[f'{r}_bounds'] = fig.patch(x_all, y_all, alpha=0.5, color=col1, line_width=1.0,  legend_label=f'{r} bounds', name=f'{r} bounds')
        traces[f'{r}_bounds'].level = 'underlay'


# Interactive tools: hover read-out of trace name / year / value and a
# vertical crosshair.
fig.add_tools(HoverTool(tooltips = [('Name', '$name'),('Year', '@x'),('Value', '@y')]))
fig.add_tools(CrosshairTool(dimensions='height',line_alpha=0.3))
# Legend settings; click_policy "hide" presumably lets a click on a legend
# entry hide the matching trace (confirm against the Bokeh documentation).
legend = fig.legend[0]
legend.click_policy = "hide"
legend.location = "top_left"
legend.label_text_font_size = '8pt'

# Toggles the legend visibility; registered below for double-tap events on
# the figure.  The function is handed to CustomJS.from_py_func, so its body
# is left untouched.
def show_hide_legend(legend=fig.legend[0]):
    legend.visible = not legend.visible
fig.js_on_event(events.DoubleTap, CustomJS.from_py_func(show_hide_legend))
# Output widget that will hold the Bokeh figure inside the layout.
outw = widgets.Output()

# Assemble and display the application layout: category selector, figure and
# progress bar in the centre, region map on the right, controls in the footer.
controls1 = widgets.HBox([ Index_list, seaslist, threshlist, threshTmaxlist, threshTminlist, windowlist, widgets.HBox([calc_button, save_index])])
app = AppLayout(header=None,
          left_sidebar=None,
          center=widgets.VBox([ TypeList,outw,prog]),
          right_sidebar=m1,
          footer=widgets.VBox([controls1, html_download]))
display(app)
# Fill the traces for the default selection (the argument is ignored by
# recalculate_vals).
recalculate_vals(1)
# plot into widget
with outw:
    h = show(fig, notebook_handle=True)
  

AppLayout(children=(VBox(children=(HBox(children=(Dropdown(description='Climate Index', options=('Average Temp…

In [8]:
#os.getcwd()

In [9]:
#shutil.rmtree('./nc_data/Montreal/ClimateIndicatorOutput')