In [None]:
%matplotlib inline

def draw(data, ax=None, figsize=(13, 7), title=None, coastlines=True, gridlines=True, **kwargs):
    """
    Plot a lon/lat field with ``pcolormesh`` on a cartopy map.

    Parameters
    ----------
    data : object
        Anything exposing ``.plot.pcolormesh`` (e.g. an xarray DataArray); it is
        drawn with ``x='lon'`` and ``y='lat'``.
    ax : Axes, optional
        Axes to draw on. When omitted, a new ``figsize`` figure with a single
        PlateCarree axes is created, importing pyplot / cartopy into the
        notebook globals first if they are not already there.
    figsize : tuple, default (13, 7)
        Size of the new figure; only used when ``ax`` is None.
    title : str, optional
        Axes title; skipped when falsy.
    coastlines, gridlines : bool, default True
        Toggle coastlines and unlabelled dim-gray gridlines.
    **kwargs
        Forwarded to ``data.plot.pcolormesh``.

    Returns
    -------
    The axes that were drawn on.
    """
    # Lazy imports must land in module globals so later cells / calls reuse them.
    global plt, ccrs

    if ax is None:
        if 'plt' not in globals():
            import matplotlib.pyplot as plt
        if 'ccrs' not in globals():
            import cartopy.crs as ccrs
        plt.figure(figsize=figsize)
        ax = plt.subplot(111, projection=ccrs.PlateCarree())

    data.plot.pcolormesh(x='lon', y='lat', ax=ax, **kwargs)

    # Optional decorations, applied in a fixed order
    if title:
        ax.set_title(title)
    if coastlines:
        ax.coastlines()
    if gridlines:
        ax.gridlines(draw_labels=False, color='dimgray', linewidth=0.5)
    return ax

In [None]:
import logging
from pathlib import Path
from types import SimpleNamespace

# NOTE(review): no visible cell creates `self`; this notebook appears to prototype
# methods of a 'UI' object. Fall back to a plain namespace so the cell survives
# Restart & Run All -- TODO confirm against the real object.
if 'self' not in globals():
    self = SimpleNamespace()

# TODO: absolute, user-specific paths; consider making these configurable.
self.file = '/Users/jamesmo/projects/suds-air-quality/.local/data/shap/v4/jul/2015/test.explanation.nc'
# A Path (not a str) because preload() builds output files with the `/` operator.
self.outp = Path('/Users/jamesmo/projects/suds-air-quality/research/james/local')
self.log = logging.getLogger('UI')

self

In [None]:
def load(file):
    """
    Load a SHAP explanation NetCDF file into the notebook's ``self`` state.

    On success this sets three attributes:

    - ``self.ds``: the loaded Dataset.
    - ``self.ex``: the Dataset stacked over lat/lon/time into a single ``loc``
      dimension, transposed, and converted with ``to_explanation()``.
    - ``self.file``: the path that was loaded.

    On failure the exception is logged and the previous state is left untouched.

    Parameters
    ----------
    file : str or path-like
        Path to the file to read with xarray.
    """
    self.log.info(f'Loading file: {file}')
    try:
        # load_dataset reads into memory and closes the file handle, unlike
        # open_dataset(...).load() which leaves the handle open.
        ds = Dataset(xr.load_dataset(file))
        ex = ds.stack(loc=['lat', 'lon', 'time']).transpose().to_explanation()
    except Exception:
        # Best-effort: log and continue, but let KeyboardInterrupt/SystemExit through.
        self.log.exception(f'Failed to load file: {file}')
    else:
        # Commit together so a failure cannot leave ds / ex / file mismatched.
        self.ds   = ds
        self.ex   = ex
        self.file = file
    self.log.debug('Finished load()')

# Load the file configured in the setup cell and display the explanation.
# load() only logs failures, so `self.ex` may be stale (or missing) if it failed.
load(self.file)
self.ex

In [None]:
import cartopy.crs       as ccrs
import matplotlib.pyplot as plt
import shap

%matplotlib inline

from tqdm import tqdm


def geospatial(variable, save=None, **params):
    """
    Map the SHAP values of one input variable on a PlateCarree projection.

    Parameters
    ----------
    variable : str
        Value of the ``variable`` coordinate of ``self.mds['values']`` to draw.
    save : str or path-like, optional
        Output file for ``plt.savefig``; skipped when falsy.
    **params
        Forwarded to xarray's ``plot.pcolormesh`` (e.g. ``cmap``, ``levels``).

    All open figures are closed before returning, so callers should rely on ``save``.
    """
    plt.figure(figsize=(20, 10))
    ax = plt.subplot(111, projection=ccrs.PlateCarree())

    # NOTE(review): `self.mds` is not set by any visible cell (load() only sets
    # self.ds / self.ex); it presumably holds time-averaged values -- TODO confirm.
    shap_values = self.mds['values']
    shap_values.sel(variable=variable).plot.pcolormesh(x='lon', y='lat', ax=ax, **params)

    ax.set_title(f'Mean SHAP Values for {variable}')
    ax.coastlines()
    ax.gridlines(draw_labels=['bottom', 'left'], color='dimgray', linewidth=0.5)

    plt.tight_layout()
    if save:
        plt.savefig(save)
    # Release every figure so batch generation does not accumulate them
    plt.close('all')

def bar(save=None, **params):
    """
    Draw SHAP's bar plot for ``self.ex``, titled 'Mean Absolute SHAP Value'.

    Parameters
    ----------
    save : str or path-like, optional
        Output file for ``plt.savefig``; skipped when falsy.
    **params
        Forwarded to ``shap.plots.bar`` (e.g. ``max_display``).

    All open figures are closed before returning.
    """
    # show=False keeps the figure open so it can be retitled and saved
    shap.plots.bar(self.ex, show=False, **params)
    plt.title('Mean Absolute SHAP Value')

    plt.tight_layout()
    if save:
        plt.savefig(save)
    plt.close('all')

def summary(save=None, **params):
    """
    Draw SHAP's summary plot for ``self.ex`` and optionally save it.

    Parameters
    ----------
    save : str or path-like, optional
        Output file for ``plt.savefig``; skipped when falsy.
    **params
        Forwarded to ``shap.summary_plot`` (e.g. ``max_display``).

    All open figures are closed before returning.
    """
    # show=False leaves the figure open so it can be laid out and saved
    shap.summary_plot(self.ex, show=False, **params)

    plt.tight_layout()
    if save:
        plt.savefig(save)
    plt.close('all')


# Dispatch table used by preload(): config.preload key -> plotting function
plotters = dict(
    bar        = bar,
    summary    = summary,
    geospatial = geospatial,
)

def preload():
    """
    Render and save every plot combination listed under ``config.preload``.

    For each plotter key (looked up in ``plotters``), the parameter grid is
    expanded with ``parameterize``. Each combination is saved to
    ``<self.outp>/<key>.<create_name(params)>.png``. Files that already exist
    are skipped, so a re-run only renders missing plots.
    """
    from pathlib import Path

    # Accept self.outp as str or Path: `str / str` raises TypeError.
    outp = Path(self.outp)
    # savefig does not create missing parent directories.
    outp.mkdir(parents=True, exist_ok=True)

    for key, grid in config.preload.items():
        combos = list(parameterize(grid))
        for combo in tqdm(combos, desc=f'Generating {key} plots'):
            file = outp / f'{key}.{create_name(combo)}.png'
            if not file.exists():
                plotters[key](save=file, **combo)

# NOTE(review): on a fresh kernel this call fails -- `config`, `parameterize` and
# `create_name` are defined in cells further down; run those first.
preload()

In [None]:
config = Config("""
preload:
    bar:
        max_display: [null, 10, 20]
    summary:
        max_display: [null, 10]
    geospatial:
        variable: [momo.co, momo.ps, momo.t, momo.no2]
        cmap: viridis
        levels: 20
""")

In [None]:
from mlky import Section as S

def parameterize(params, skip=None):
    """
    Expand a parameter grid into every combination of its options.

    Turns
        {param1: [opt1, opt2, ...], param2: [opt1, ...], param3: scalar, ...}
    into a stream of
        {param1: opt_i, param2: opt_j, param3: scalar, ...}
    for every (i, j, ...). List-valued parameters are iterated over; any other
    value is carried forward as-is into every combination. The first key varies
    slowest and the last key varies fastest.

    Parameters
    ----------
    params : mapping
        The grid, e.g. one plotter entry of ``config.preload``.
    skip : list, optional
        Keys already handled by an outer recursion level; callers normally omit it.

    Yields
    ------
    Section-like objects built with ``S`` (mlky Section), one per combination.
    Nothing is yielded when no unskipped key remains (including an empty
    ``params``), or when a list-valued parameter is an empty list.
    """
    skip = [] if skip is None else list(skip)

    # Find the first key not yet handled by an outer recursion level
    for key in params:
        if key not in skip:
            break
    else:
        # All keys processed: a generator yields nothing by simply returning
        return

    # Recurse to get every combination of the remaining keys
    others = list(parameterize(params, skip + [key]))
    value  = params[key]

    # A scalar behaves like a single-option list
    options = value if isinstance(value, list) else [value]
    for option in options:
        if others:
            for other in others:
                # Fresh Section per combination rather than reusing one across `+`
                yield S({key: option}) + other
        else:
            yield S({key: option})

def create_name(params):
    """
    Build a file-name fragment from a parameter mapping.

    Each item becomes ``key=value`` (values rendered with ``str``, so None ->
    'None'), joined with '.' in the mapping's iteration order.
    An empty mapping gives ''.
    """
    parts = []
    for key, val in params.items():
        parts.append(f'{key}={val}')
    return '.'.join(parts)

In [None]:
# First bar-plot parameter combination, for interactive inspection.
# The bar grid is defined under config.preload in the config cell.
params = list(parameterize(config.preload.bar))[0]

In [None]:
def create_name(params):
    for key, val in params.items():