In [9]:
import logging
import numpy  as np
import panel  as pn
import xarray as xr

from glob    import glob
from pathlib import Path

import hvplot.xarray

logging.basicConfig(level=logging.WARN)

import mlky
from mlky import Sect

from sudsaq.ml.explain import (
    Dataset, 
    Explanation
)
self = Sect()
self.run = '/data/MLIA_active_data/data_SUDSAQ/models/bias/gattaca.v4.bias-median'

pn.extension('plotly', loading_spinner='dots', loading_color='#00aa41', template='bootstrap')
pn.param.ParamMethod.loading_indicator = True


In [10]:
def getAvailable():
    files = glob(f'{self.run}/**/**/test.explanation.nc')
    order = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec']
    self.avail = {
        'months': set(), 
        'years' : set()
    }
    for file in files:
        month, year = file.split('/')[-3:-1]
        self.avail.months.update([month])
        self.avail.years.update([year])
    self.avail.months = sorted(self.avail.months, key=lambda month: order.index(month))
    self.avail.years = sorted(list(self.avail.years), key=lambda year: int(year))

getAvailable()

In [11]:
def loadByFiles():
    if not hasattr(loadByFiles, 'cache'):
        loadByFiles.cache = {}
        
    files  = []
    combos = list(product(self.months.values(), self.years.values()))

    key = str(combos)
    if key in loadByFiles.cache:
        return loadByFiles.cache[key]
    
    for month, year in combos:
        files += glob(f'{self.run}/{month.lower()}/{year}/test.explanation.nc')

    if not files:
        print(f'No files found for combinations: {combos}')
        return

    self.ds = xr.open_mfdataset(files, mode='r', lock=False)
    self.ds = Dataset(self.ds)

    loadByFiles.cache[key] = self.ds

def checkBounds():
    """
    """
    if self.lat.upper <= self.lat.lower:
        print('Upper bound latitude must be less than the lower bound')
    
    if self.lon.upper <= self.lon.lower:
        print('Upper bound longitude must be less than the lower bound')

def cache(key):
    """
    WIP
    """
    def decorator(func):
        def execute(*args, **kwargs):
            cache[key] = func(*args, **kwargs)
            return cache[key]
        return execute
    
    cache = {}
    key   = mlky.replace(key)

    if key in cache:
        return cache[key]
    return decorator


@cache(key=('((${.lat.lower}, ${.lat.upper}), (${.lon.lower}, ${.lon.upper}))'))
def loadRegion():
    """
    """
    checkBounds()
    loadByFiles()
    
    self.region = (self.ds
        .sel(lat=slice(self.lat.lower, self.lat.upper))
        .sel(lon=slice(self.lon.lower, self.lon.upper))
    ).load()

def plotRegion():
    """
    """
    checkBounds()
    loadByFiles()

    lats = (self.lat.lower <= self.ds.lat) & (self.ds.lat <= self.lat.upper)
    lons = (self.lon.lower <= self.ds.lon) & (self.ds.lon <= self.lon.upper)
    mask = lats * lons
    plot = mask.hvplot.quadmesh('lon', 'lat', global_extent=True, coastline=True)

    return pn.panel(plot)


RecursionError: maximum recursion depth exceeded

In [None]:
import time
import shap
import matplotlib.pyplot as plt
from functools import partial
from itertools import product


def btnList(func, name, options, lbl, lblw=100, **kwargs):
    def group_select(*_, group, values=None):
        group.set_param('value', group.values if values is None else values)
    
    def select(selected):
        self.months = selected

    group = pn.widgets.CheckButtonGroup(name=name, options=options, **kwargs)
    pn.bind(func, group, watch=True)

    all = pn.widgets.Button(
        width       = 16,
        name        = 'All', 
        button_type = 'primary'
    )
    pn.bind(partial(group_select, group=group), all, watch=True)
    none = pn.widgets.Button(
        width       = 32,
        name        = 'None', 
        button_type = 'danger'
    )
    pn.bind(partial(group_select, group=group, values=[]), none, watch=True)

    lbl = pn.pane.Markdown(lbl, width=lblw)

    return pn.Row(lbl, all, none, group, sizing_mode='stretch_width')


def plotWaterfall(key):
    """
    """
    loadRegion()
    plt.close('all')
    
    data = self[key].to_explanation(auto=True).mean(0)
    plot = shap.plots.waterfall(data, max_display=20, show=False)
    plt.tight_layout()
    pane = pn.pane.Matplotlib(plot)
    
    return pane

def generator():
    def click(event):
        if event:
            return self.generate()
        return self.default

    btn = pn.widgets.Button(
        name        = 'Generate', 
        button_type = 'primary'
    )
    btn.click = lambda: btn.param.set_param('value', True) # Simulates a click event
    bind = pn.bind(click, btn)
    plot = pn.param.ParamFunction(bind, loading_indicator=True)

    # column = pn.Column(btn, plot, sizing_mode='stretch_width')

    return btn, plot

def regionSel():
    """
    """
    def sel(value, dim, boundary):
        self[dim][boundary] = value

    # Populate defaults
    self['lat'] = {'upper': 90, 'lower': -90}
    self['lon'] = {'upper': 180, 'lower': -180}

    # Initialize the grid and label
    lbl  = pn.pane.Markdown('## Select region boundaries: ')
    
    grid = pn.GridSpec(sizing_mode='scale_width', height=100)
    grid = pn.GridSpec(width=320, height=200)

    # Upper Lat
    grid[0, 1] = widget = pn.widgets.FloatInput(name='Upper Latitude', value=self.lat.upper, step=1e-1, start=self.lat.lower, end=self.lat.upper)
    bind = pn.bind(sel, widget, dim='lat', boundary='upper', watch=True)

    # Lower Lon
    grid[1, 0] = widget = pn.widgets.FloatInput(name='Lower Longitude', value=self.lon.lower, step=1e-1, start=self.lon.lower, end=self.lon.upper)
    bind = pn.bind(sel, widget, dim='lon', boundary='lower', watch=True)

    # Lower Lat
    grid[2, 1] = widget = pn.widgets.FloatInput(name='Lower Latitude', value=self.lat.lower, step=1e-1, start=self.lat.lower, end=self.lat.upper)
    bind = pn.bind(sel, widget, dim='lat', boundary='lower', watch=True)

    # Upper Lon
    grid[1, 2] = widget = pn.widgets.FloatInput(name='Upper Longitude', value=self.lon.upper, step=1e-1, start=self.lon.lower, end=self.lon.upper)
    bind = pn.bind(sel, widget, dim='lon', boundary='upper', watch=True)

    # Extra invisible spacer helps with, well, spacing
    grid[3, :] = pn.Spacer()
    
    return pn.Row(lbl, grid)

months = btnList(
    name        = 'Select Months',
    options     = self.avail.years.values(),
    button_type = 'success',
    width       = 30,
    lbl  = '## Select years(s):',
    lblw = 200,
    func = lambda selected: setattr(self, 'years', selected)
)
years = btnList(
    name        = 'Select Years',
    options     = self.avail.months.values(),
    button_type = 'success',
    width       = 30,
    lbl  = '## Select month(s):',
    lblw = 200,
    func = lambda selected: setattr(self, 'months', selected)
)
self.default = pn.pane.PNG('/home/mambauser/suds-air-quality/notebooks/howtoSHAPwaterfalls.png', sizing_mode='scale_width')

# gbtn is the generator button, components 
gbtn, plot = generator()


def genButtons(generator):
    def genRegion(event):
        if event:
            self.generate = plotRegion
            generator.click()

    def genWaterfall(event):
        if event:
            self.generate = partial(plotWaterfall, key='region')
            generator.click()

    buttons = []

    btn = pn.widgets.Button(name='Region')
    pn.bind(genRegion, btn, watch=True)
    buttons.append(btn)

    btn = pn.widgets.Button(name='Waterfall')
    pn.bind(genWaterfall, btn, watch=True)
    buttons.append(btn)
    
    return pn.Row(*buttons)

genButts = genButtons(gbtn)

pn.Column(months, years, regionSel(), genButts, plot, sizing_mode='stretch_width')