In [None]:
# Box-and-whisker plots for the top 5 predictors, with a
# red vertical line marking the value for a given example.
import seaborn as sns
sns.set_theme({'axes.edgecolor':'black'})
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib
#plt.rcParams['xtick.bottom'] = True

import sys
sys.path.append('/Users/monte.flora/Desktop/road_surface/')
sys.path.append('/Users/monte.flora/ml_workflow/')
sys.path.append('../python_scripts')
from display_names import to_readable_names, to_color, get_units
from probsr_config import PREDICTOR_COLUMNS, TARGET_COLUMN


sys.path.append('/Users/monte.flora/scikit-explain/')
sys.path.append('../python_scripts')
import skexplain
import shap

import seaborn as sns
sns.set_theme(style="whitegrid")

In [None]:
# Top-5 predictor column names for each hazard, keyed by the hazard's
# display name (the same key format the plotting loop builds from the
# target string).
top_preds = {
    'Tornado': [
        'wz_0to2_time_max_ens_mean_of_90th',
        'hailcast_time_max_ens_mean_of_90th',
        'uh_2to5_time_max_ens_mean_of_90th',
        'comp_dz_time_max_ens_mean_of_90th',
        'cape_ml_ens_mean_spatial_mean',
    ],
    'Severe Hail': [
        'dbz_3to5km_max_time_max_ens_mean_of_90th',
        'w_up_time_max_ens_mean_of_90th',
        'major_axis_length',
        'comp_dz_time_max_ens_mean_of_90th',
        'low_level_lapse_rate_ens_mean_spatial_mean',
    ],
    'Severe Wind': [
        'ws_80_time_max_ens_mean_of_90th',
        'comp_dz_time_max_ens_mean_of_90th',
        'minor_axis_length',
        'wz_0to2_time_max_ens_mean_of_90th',
        'mid_level_lapse_rate_ens_mean_spatial_mean',
    ],
}

# Path of the pickled first-hour logistic-regression file for each hazard.
models = {
    'Tornado':
        '../models/LogisticRegression_first_hour_tornado_under_standard_.pkl',
    'Severe Hail':
        '../models/LogisticRegression_first_hour_severe_hail_under_standard_.pkl',
    'Severe Wind':
        '../models/LogisticRegression_first_hour_severe_wind_under_standard_.pkl',
}

In [None]:
def rounding(v):
    """Round a number to a display-friendly precision for plot labels.

    The precision is chosen from the magnitude of ``v`` so that the sign
    never affects whether a value is returned:

    * ``|v| >= 100``  -> nearest integer (returned as ``int``)
    * ``|v| < 0.1``   -> three decimal places
    * otherwise       -> one decimal place

    Parameters
    ----------
    v : int or float
        Value to round.

    Returns
    -------
    int or float
        The rounded value. NaN is returned unchanged (as NaN).
    """
    magnitude = abs(v)
    # `>=` (not `>`) so that exactly 100 is covered rather than falling
    # through every branch and yielding None.
    if magnitude >= 100:
        return int(round(v))
    if magnitude < 0.1:
        return round(v, 3)
    return round(v, 1)

def create_interpret_graphic(X_train, top_preds, target, add_red_lines=False,
                             current_object=None, rngs=None, local=False):
    """Create interpretability graphic.

    Draws a five-row figure with one horizontal box plot per predictor
    (whiskers span the 0th-100th percentile of the data) and labels each
    row with the predictor's readable name and units.

    Parameters
    ----------
    X_train : column-indexable table (e.g. a pandas DataFrame)
        Background data; ``X_train[c]`` supplies the values of each box plot.
    top_preds : iterable of str
        Predictor column names, one per row. The figure has five rows, so
        only the first five names are drawn.
    target : str
        Hazard name inserted into the figure title.
    add_red_lines : bool, default False
        When True, annotate each row with the value of ``current_object``
        and draw a red vertical line at that value.
    current_object : mapping or Series, optional
        Predictor values of the example to highlight, looked up by
        predictor name. Required when ``add_red_lines`` is True.
    rngs : mapping, optional
        Maps a predictor name to its x tick positions. Rows without an
        entry keep matplotlib's default ticks and limits. When omitted, a
        notebook-level variable named ``rngs`` is used if one exists.
    local : bool, default False
        When True, tag the top row with a 'LOCAL' label.

    Returns
    -------
    (matplotlib.figure.Figure, array of matplotlib.axes.Axes)

    Raises
    ------
    ValueError
        If ``add_red_lines`` is True but ``current_object`` is not given.
    """
    if add_red_lines and current_object is None:
        raise ValueError("current_object is required when add_red_lines=True")

    if rngs is None:
        # Backward compatibility: fall back to a notebook-level `rngs`
        # if another cell defined one; otherwise no custom ticks are set.
        rngs = globals().get('rngs')
        if rngs is None:
            rngs = {}

    top_preds = list(top_preds)
    units = {t: get_units(t) for t in top_preds}
    pretty_names = {t: to_readable_names(t) for t in top_preds}

    f, axes = plt.subplots(dpi=192, nrows=5,
                           figsize=(800/192, 800/192))
    sns.despine(fig=f, ax=axes, top=True, right=True, left=True, bottom=False, offset=None, trim=False)

    if local:
        # Axes.annotate requires an explicit position for the text.
        axes[0].annotate('LOCAL', xy=(0.0, 1.15), xycoords='axes fraction',
                         fontsize=6, ha='left', color='k', fontweight='bold')

    color = 'xkcd:medium blue'
    for ax, c in zip(axes, top_preds):
        box_plot = ax.boxplot(x=X_train[c], vert=False,
                              whis=[0, 100], patch_artist=True, widths=0.3, showfliers=False)

        # fill with colors
        for patch in box_plot['boxes']:
            patch.set_facecolor(color)
        for line in box_plot['medians']:
            line.set_color('k')

        # Object-shape predictors ('axis' in the name) are not ensemble
        # statistics, so they do not get the 'Ens. Mean' prefix.
        label = pretty_names[c].split('(')[0] + '(' + units[c] + ')'
        if 'axis' not in c:
            label = 'Ens. Mean ' + label
        ax.annotate(label, xy=(1.0, 1.15),
                    xycoords='axes fraction', fontsize=6, ha='right', color='k', fontweight='bold')

        # Remove y tick labels
        ax.set_yticks([],)
        ax.tick_params(axis='x', labelsize=8)
        ax.grid(False)

        ticks = rngs[c] if c in rngs else None
        if ticks is not None and len(ticks) > 0:
            ticks = list(ticks)
            ax.set_xticks(ticks)
            # Convert to strings first, then blank the two end labels; this
            # also works when the ticks come from a numeric numpy array.
            tick_labels = [str(t) for t in ticks]
            tick_labels[0] = ''
            tick_labels[-1] = ''
            ax.set_xticklabels(tick_labels)
            ax.set_xlim([min(ticks), max(ticks)])

    plt.subplots_adjust(hspace=1.2)

    title = f'Training Set Distribution\n(All {target}-Producing Storms)\nfor the Top 5 Predictors (out of 113)'

    f.suptitle(title,
               fontsize=8, y=1.05)
    axes[0].set_title('Red numbers and vertical bars\nshow current values for this object',
                      fontsize=6, pad=12, color='red')

    if add_red_lines:
        for ax, c in zip(axes, top_preds):
            v = current_object[c]
            text = rounding(v)
            if text is None:
                # Fall back to the raw value if rounding() yields no label.
                text = v
            ax.annotate(text, xy=(0.9, 0.7), xycoords='axes fraction', fontsize=10, color='red')
            # plot vertical lines
            ax.axvline(x=v, color='red', zorder=5)

    return f, axes

def append_interpret_graphic(axes, current_object, data_vars):
    """Append the current example's values to an interpretability graphic.

    For each (axis, predictor) pair, writes the rounded value in red on the
    axis and draws a red vertical line at the example's actual value.

    Parameters
    ----------
    axes : sequence of matplotlib.axes.Axes
        Axes to draw on, one per predictor.
    current_object : mapping or Series
        Predictor values of the example, looked up as ``current_object[c]``.
    data_vars : sequence of str
        Predictor names, in the same order as ``axes``.

    Returns
    -------
    The ``axes`` object that was passed in.
    """
    for ax, c in zip(axes, data_vars):
        raw_value = current_object[c]
        label = rounding(raw_value)
        if label is None:
            # Fall back to the raw value if rounding() yields no label.
            label = raw_value
        ax.annotate(label, xy=(0.9, 0.7), xycoords='axes fraction', fontsize=10, color='red')
        # Draw the line at the true value; the rounded number is only for
        # the text label.
        ax.axvline(x=raw_value, color='red', zorder=5)

    return axes

time = 'first_hour'
targets = ['severe_wind']
# When True, rank predictors for one high-probability example using SHAP
# contributions; when False, use the fixed lists in `top_preds`.
local = False

# Divisors applied to the lapse-rate columns before plotting.
# NOTE(review): presumably these convert a layer temperature difference
# into a positive lapse rate (layer depth in km) - confirm.
LAPSE_RATE_DIVISORS = {
    'mid_level_lapse_rate_ens_mean_spatial_mean': -2.7,
    'low_level_lapse_rate_ens_mean_spatial_mean': -3.0,
}

for target in targets:
    # 'severe_wind' -> 'Severe Wind', the key format used by `top_preds`.
    key = target.replace('_', ' ').title()

    # I AM LOADING THE TESTING DATASETS!!!!!
    X_train = pd.read_pickle(f'../datasets/{time}_testing_matched_to_{target}_0km_dataset').astype(float)
    y = X_train[f'matched_to_{target}_0km'].values
    #idx = np.where(y == 1)[0][3]
    idx = 13838
    if local:
        data = joblib.load(f'../models/LogisticRegression_first_hour_{target}_under_standard_.pkl')
        model = data['model']
        features = data['features']
        X_train = X_train[features].copy()

        pred = model.predict_proba(X_train)[:, 1]
        # Sixth example whose predicted probability exceeds 0.8.
        idx = np.where(pred > 0.8)[0][5]
        print(idx)

        # The model sees the unscaled features, so build the explanation
        # before the lapse-rate columns are rescaled below.
        single_example = X_train.iloc[[idx]]
        explainer = skexplain.ExplainToolkit(estimators=('LR', model), X=single_example,)

        shap_kwargs = {'masker':
              shap.maskers.Partition(X_train, max_samples=500, clustering="correlation"),
              'algorithm': 'auto'}

        results = explainer.local_contributions(method='shap', shap_kwargs=shap_kwargs)
        contrib_names = [c for c in results.columns if 'contrib' in c and 'Bias' not in c]
        df = results[contrib_names]
        # Rank predictors by absolute contribution, largest first.
        contrib_magnitudes = df.abs().values[0]
        names = df.columns
        order = np.argsort(contrib_magnitudes)[::-1]
        names_sorted = names[order]
        top_features = [n.replace('_contrib', '') for n in list(names_sorted[:5])]

    else:
        top_features = top_preds[key]

    for column, divisor in LAPSE_RATE_DIVISORS.items():
        if column in X_train.columns:
            X_train[column] = X_train[column] / divisor

    # Extract the example only after rescaling so its values share the
    # units of the plotted distribution; copy to detach it from X_train.
    current_obj = X_train.iloc[idx, :].copy()

    # Keep only the rows matched to the hazard for the background distribution.
    inds = np.where(y > 0)[0]
    X_train = X_train.iloc[inds, :].reset_index(drop=True)

    f, axes = create_interpret_graphic(X_train, top_features, target=key, add_red_lines=False)
    append_interpret_graphic(axes, current_obj, data_vars=top_features)
    for ax in axes:
        ax.grid(False)

    #plt.savefig(f"{key.lower().replace(' ', '_')}_explainability_background.png", format="png", dpi=200) 