### Generate the Official cb-WoFS Explainability Graphics 

In [1]:
# Box and Whisker plot for the top 5 predictors 
# pink line for a given example. 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib
from display_names import to_display_name, to_units
import sys
sys.path.insert(0, '/home/monte.flora/python_packages/WoF_post')
sys.path.insert(0, '/home/monte.flora/python_packages/wofs_ml_severe')

from wofs_ml_severe.io.load_ml_models import load_ml_model
from wofs_ml_severe.io.io import load_ml_data
from wofs.post.utils import load_yaml
import matplotlib.ticker as ticker

from sklearn.impute import SimpleImputer

lookup_file: /home/monte.flora/python_packages/WoF_post/wofs/data/psadilookup.dat


In [2]:
import math
import json 

class cbWoFSExplainabilityGraphics:
    def __init__(self, X_train, y_train):
        
        # Only get where the examples are matched to reports.
        inds = np.where(y_train>0)[0]
    
        X_train_subset = X_train.iloc[inds, :]
        X_train_subset.reset_index(drop=True, inplace=True)
        
        #X_train_subset['mid_level_lapse_rate_ens_mean_spatial_mean'] /= -2.7
        #X_train_subset['low_level_lapse_rate_ens_mean_spatial_mean'] /= -3.0
        
        # Convert mid-level temps?  
        
        self.X_train = X_train_subset
        self.features = X_train_subset.columns 
        
        self.max_min_val_dict = { }
        
    def _round(self, value, mode):
        
        def round_to_nearest_fifth(x):    
            if x < 0:
                return -round_to_nearest_fifth(-x)
            elif x < 1:
                return x
            else:
                if mode=='upper': 
                    return math.ceil(x / 5,) * 5
                else:
                    return math.floor(x / 5,) * 5
        
        if value == 0.0:
            return 0.0, 0
    
        # Find the order of magnitude (oom)
        oom = int(math.log10(abs(value)))

        round_to_fifth = False
        if oom >= 0:
            # Positive Order of Mag. 
            if oom in [0,1]: 
                round_int = 1
            elif oom > 1:
                round_int = 0
                round_to_fifth = True
        else:
            # Negative Order of Mag.
            if oom == -1:
                round_int = 2
            else:
                round_int = 3
 
        # Round to the nearest 5 for >=10 
        if round_to_fifth:
            return round_to_nearest_fifth(round(value, round_int)), round_int
        else:
            return round(value, round_int), round_int

    def create_global(self, features, target):
        """Create the global explainability graphic"""
        f, axes = plt.subplots(dpi=192, nrows=5, 
                           figsize=(800/192, 800/192))
        for ax, feature in zip(axes, features):
            self.create_local(feature, ax=ax, f=f)
        
        title = f"""
        Training Set Distribution\n(All {target}-Producing Storms)\nfor the Top 5 Predictors (out of 113)"""
    
        f.suptitle(title, 
               fontsize=8, y=1.10)
        
        axes[0].set_title('Red numbers and vertical bars\nshow current values for this object', 
                      fontsize=6, pad=12, color='red')
        
        plt.subplots_adjust(hspace=1.4)
        
        return f, axes 
                
    def create_local(self, feature, f=None, ax=None):
        """Create box-and-whisker graphic for a single feature"""  
        units = to_units(feature)
        pretty_name = to_display_name(feature)
    
        if ax is None:
            f, ax = plt.subplots(dpi=192, nrows=1, 
                           figsize=(800/192, 100/192))
        
        # Despine and only leave the bottom side. 
        for side in ['top', 'right', 'left']: 
            ax.spines[side].set_visible(False)

        # Create the box-and-whiskers 
        whis=[0.1, 99.9]
        box_plot = ax.boxplot(x=self.X_train[feature], vert=False, 
                              whis=whis, patch_artist=True, 
                              widths=0.3, showfliers=False )
        
        # Create a title for the feature name. 
        ax.annotate(f'{pretty_name} ({units})', xy=(1.0, 1.15),
                    xycoords='axes fraction', fontsize=6, ha='right', color = 'k', fontweight='bold')
        
        # Remove y tick labels 
        ax.set_yticks([],)
        ax.tick_params(axis='x', labelsize=9, size=8)
        min_val, _ = self._round(np.nanpercentile(self.X_train[feature], whis[0]), 'lower')
        max_val, round_int = self._round(np.nanpercentile(self.X_train[feature], whis[-1]), 'upper')
        
        self.max_min_val_dict[feature] = {'max_val' : max_val,
                                          'min_val' : min_val, 
                                          'round_int' : round_int}
        
        # Identify pretty tick ranges 
        ax.set_xlim(min_val, max_val)
        
        # set the tick locator for the x-axis
        ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=7))
        
        rng = list(ax.get_xticks())
        if round_int == 0: 
            levels = [f"{v:.0f}" for v in rng]
        elif round_int == 1: 
            levels = [f"{v:.1f}" for v in rng]
        elif round_int == 2: 
            levels = [f"{v:.2f}" for v in rng]
        else:
            levels = [f"{v:.3f}" for v in rng]
        
        levels[0] = ''; levels[-1] = ''
        ax.set_xticks(rng)
        ax.set_xticklabels(labels=levels)

        # fill with colors
        color = 'xkcd:medium blue'
        for patch in box_plot['boxes']:
            patch.set_facecolor(color)
        for line in box_plot['medians']:
            line.set_color('k')

        return f, ax
    
    def save_local(self, fig, feature, target):
        plt.savefig(
            f"new_graphics/{feature.lower().replace(' ', '_')}_{target}_explainability_background.png", 
            format="png", dpi=192, bbox_inches="tight")
        plt.close(fig) 
    
    def save_global(self, fig, target):
        plt.savefig(
            f"new_graphics/{target}_global_explainability_background.png", 
            format="png", dpi=192, bbox_inches="tight")
        plt.close(fig) 
    
    
    
    def save_json(self, target): 
        with open(f"../json/min_max_vals_{target}.json", "w") as outfile:
            json.dump(self.max_min_val_dict, outfile)

### Create the Explainability Graphic Per Feature For Each Hazard

1. Only using the severe hail dataset 

In [3]:
def get_target_str(target):
    # Initialize the kwargs for the hyperparameter optimization.
    if isinstance(target, list):
        if 'sig_severe' in target[0]:
            target = 'all_sig_severe'
        else:
            target = 'all_severe'
   
    return target 

### Create Individual Panels 

In [4]:
top_preds = {'tornado_severe_0km' : 
            ['wz_0to2_instant__time_max__amp_ens_mean_spatial_perc_90', 
             'shear_v_0to6__ens_mean__spatial_mean',
             'buoyancy__time_min__amp_ens_mean_spatial_perc_10',
             '10-500m_bulkshear__time_max__amp_ens_mean_spatial_perc_90',
             'v_10__ens_mean__spatial_mean',
            ],
            'hail_severe_0km' : 
                    [
              'dbz_3to5__time_max__ens_mean__spatial_mean',
              'comp_dz__time_max__amp_ens_mean_spatial_perc_90',  
              'td_850__ens_mean__spatial_mean',
              '10-500m_bulkshear__time_max__ens_mean__spatial_mean',
              'w_down__time_min__amp_ens_mean_spatial_perc_10'
    
            ],
            'wind_severe_0km' : 
                    [
                  'v_10__ens_mean__spatial_mean',
                  'ws_80__time_max__amp_ens_mean_spatial_perc_90',
                  'comp_dz__time_max__amp_ens_mean_spatial_perc_90',
                  'div_10m__time_min__ens_mean__spatial_mean',
                  'buoyancy__time_min__amp_ens_mean_spatial_perc_10',  
                    ],
             
             'all_severe' : 
             [
                'comp_dz__time_max__amp_ens_mean_spatial_perc_90',
                'div_10m__time_min__ens_std__spatial_mean',
                'ctt__time_min__amp_ens_mean_spatial_perc_10',
                'hailcast__time_max__ens_mean__spatial_mean',
                '10-500m_bulkshear__time_max__ens_mean__spatial_mean',
                ],
             'all_sig_severe': [
                'low_level_lapse_rate__ens_mean__spatial_mean',
                'ctt__time_min__amp_ens_mean_spatial_perc_10',
                '10-500m_bulkshear__time_max__ens_mean__spatial_mean',
                'hailcast__time_max__ens_mean__spatial_mean',
                'comp_dz__time_max__amp_ens_mean_spatial_perc_90', 
                ],
             
            }

In [6]:
TOP_FEATURES = {'tornado' : 
            ['wz_0to2_instant__time_max__amp_ens_mean_spatial_perc_90', 
             'shear_v_0to6__ens_mean__spatial_mean',
             'buoyancy__time_min__amp_ens_mean_spatial_perc_10',
             '10-500m_bulkshear__time_max__amp_ens_mean_spatial_perc_90',
             'v_10__ens_mean__spatial_mean',
            ],
            'hail' : 
                    [
              'dbz_3to5__time_max__ens_mean__spatial_mean',
              'comp_dz__time_max__amp_ens_mean_spatial_perc_90',  
              'td_850__ens_mean__spatial_mean',
              '10-500m_bulkshear__time_max__ens_mean__spatial_mean',
              'w_down__time_min__amp_ens_mean_spatial_perc_10'
    
            ],
            'wind' : 
                    [
                  'v_10__ens_mean__spatial_mean',
                  'ws_80__time_max__amp_ens_mean_spatial_perc_90',
                  'comp_dz__time_max__amp_ens_mean_spatial_perc_90',
                  'div_10m__time_min__ens_mean__spatial_mean',
                  'buoyancy__time_min__amp_ens_mean_spatial_perc_10',  
                    ],
             
             'all_severe' : 
             [
                'comp_dz__time_max__amp_ens_mean_spatial_perc_90',
                'div_10m__time_min__ens_std__spatial_mean',
                'ctt__time_min__amp_ens_mean_spatial_perc_10',
                'hailcast__time_max__ens_mean__spatial_mean',
                '10-500m_bulkshear__time_max__ens_mean__spatial_mean',
                ],
             'all_sig_severe': [
                'low_level_lapse_rate__ens_mean__spatial_mean',
                'ctt__time_min__amp_ens_mean_spatial_perc_10',
                '10-500m_bulkshear__time_max__ens_mean__spatial_mean',
                'hailcast__time_max__ens_mean__spatial_mean',
                'comp_dz__time_max__amp_ens_mean_spatial_perc_90', 
                ],
             
            }

In [7]:
dataframe = load_ml_data('wind_severe_0km', 
                 lead_time = 'first_hour', 
                 mode = None, 
                 baseline=False,
                 return_only_df=True, 
                 load_reduced=True, 
                 base_path = '/work/mflora/ML_DATA/DATA',
                )

#['label', 'obj_centroid_x', 'obj_centroid_y']

data = {hazard : dataframe[preds] for hazard, preds in TOP_FEATURES.items()}


Only keeping warm season cases for the official training!


In [10]:
new_dataframe = pd.DataFrame(data, index=TOP_FEATURES.keys())

ValueError: Data must be 1-dimensional

### Create the top 5 predictor global panel 

In [5]:
%matplotlib inline

TITLES = {'wind_severe_0km' : 'Severe Wind', 
          'hail_severe_0km' : 'Severe Hail', 
          'tornado_severe_0km' : 'Tornado', 
          'all_severe' : 'Any Severe', 
          'all_sig_severe' : 'Any Sig. Severe', 
         }


targets = ['wind_severe_0km', 
           'hail_severe_0km', 
           'tornado_severe_0km', 
           ['wind_severe_0km', 'hail_severe_0km', 'tornado_severe_0km'],
           ['wind_sig_severe_0km', 'hail_sig_severe_0km', 'tornado_sig_severe_0km']
          ]

for target in targets: 
    target_str = get_target_str(target)
    X_train, y_train, metadata = load_ml_data(target, 
                 lead_time = 'first_hour', 
                 mode = None, 
                 baseline=False,
                 return_only_df=False, 
                 load_reduced=True, 
                 base_path = '/work/mflora/ML_DATA/DATA',
                )

    # Impute missing values. 
    X_train = pd.DataFrame(SimpleImputer().fit_transform(X_train), columns=X_train.columns)


    explainer = cbWoFSExplainabilityGraphics(X_train.astype(float), y_train)
    features = top_preds[target_str]
    fig, _ = explainer.create_global(features, target=TITLES[target_str])
    
    explainer.save_global(fig, target_str)

Only keeping warm season cases for the official training!
Only keeping warm season cases for the official training!
Only keeping warm season cases for the official training!
Only keeping warm season cases for the official training!
Only keeping warm season cases for the official training!
