### Cumulative contributions to clc_32

**For Figure 10**

In [1]:
import os
import shap
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline, BSpline

import importlib
importlib.reload(matplotlib)
importlib.reload(plt)

<module 'matplotlib.pyplot' from '/pf/b/b309170/work/b309170/conda/envs/clouds113/lib/python3.7/site-packages/matplotlib/pyplot.py'>

*Load SHAP files*

In [3]:
r2b4_shap_files = []
for file in os.listdir('../shap_values'):
    if file.startswith('r2b4_shap'):
        r2b4_shap_files.append('../shap_values/'+file)
        
r2b4_file_count = len(r2b4_shap_files)

In [4]:
r2b5_shap_files = []
for file in os.listdir('../shap_values'):
    if file.startswith('r2b5_shap'):
        r2b5_shap_files.append('../shap_values/'+file)
        
r2b5_file_count = len(r2b5_shap_files)

In [5]:
# Append the averages of all shap value files we have for r2b4 and r2b5 each
r2b4_shap_values = [np.load(r2b4_shap_files[i]) for i in range(r2b4_file_count)]
r2b4_shap_means = [np.mean(r2b4_shap_values[i], axis=0) for i in range(r2b4_file_count)]

r2b5_shap_values = [np.load(r2b5_shap_files[i]) for i in range(r2b5_file_count)]
r2b5_shap_means = [np.mean(r2b5_shap_values[i], axis=0) for i in range(r2b5_file_count)]

*Feature names*

In [6]:
# R2B4
r2b4_feature_names = []
feat_names = ['qv', 'qc', 'qi', 'temp', 'pres', 'rho', 'zg']
for s in feat_names:
    for i in range(21, 48):
        r2b4_feature_names.append('%s_%d'%(s, i))
        
r2b4_feature_names.append('fr_lake')
r2b4_feature_names = np.array(r2b4_feature_names)

remove_fields = [27, 162, 163, 164]
r2b4_feature_names = np.delete(r2b4_feature_names, remove_fields)

In [7]:
# R2B5
r2b5_feature_names = []
feat_names = ['qv', 'qc', 'qi', 'temp', 'pres', 'zg']
for s in feat_names:
    for i in range(21, 48):
        r2b5_feature_names.append('%s_%d'%(s, i))
        
r2b5_feature_names.append('fr_land')
r2b5_feature_names = np.array(r2b5_feature_names)

remove_fields = [27, 28, 29, 30, 31, 32, 135, 136, 137]
r2b5_feature_names = np.delete(r2b5_feature_names, remove_fields)

In [8]:
# Intersecting and unique features
features_intersect = np.intersect1d(r2b4_feature_names, r2b5_feature_names)

only_in_r2b5 = set(r2b5_feature_names).difference(set(features_intersect))
only_in_r2b4 = set(r2b4_feature_names).difference(set(features_intersect))

In [9]:
# For every feature in features_intersect, we extract the means 
r2b4_shap_means_intersect = []
r2b5_shap_means_intersect = []

for s in features_intersect:
    feature_ind_r2b4 = np.where(r2b4_feature_names==s)[0][0]
    feature_ind_r2b5 = np.where(r2b5_feature_names==s)[0][0]
    
    r2b4_shap_means_intersect.append([r2b4_shap_means[i][feature_ind_r2b4] for i in range(r2b4_file_count)])
    r2b5_shap_means_intersect.append([r2b5_shap_means[i][feature_ind_r2b5] for i in range(r2b5_file_count)])

# List with as many entries as there are intersecting features. Each entry has as many entries as there are shap value files.
assert len(r2b4_shap_means_intersect) == len(r2b5_shap_means_intersect) == len(features_intersect)
assert len(r2b4_shap_means_intersect[0]) == r2b4_file_count
assert len(r2b5_shap_means_intersect[0]) == r2b5_file_count

r2b4_shap_means_intersect = np.array(r2b4_shap_means_intersect)
r2b5_shap_means_intersect = np.array(r2b5_shap_means_intersect)

assert r2b4_shap_means_intersect.shape == (len(features_intersect), r2b4_file_count)
assert r2b5_shap_means_intersect.shape == (len(features_intersect), r2b5_file_count)

In [10]:
np.sum(r2b5_shap_means_intersect[:2], axis=0)

array([-0.01372258,  0.00787697, -0.00492056,  0.00322776, -0.01075318,
       -0.01436394, -0.0176911 ,  0.00218605,  0.000322  , -0.01064747])

*Cumulative plot*

In [11]:
# Sum over feature contributions from the entire column
r2b4_var_type = []
r2b5_var_type = []

for var in ['qv', 'qi', 'qc', 'zg', 'pres', 'temp']:
    inds = [ind for ind in np.arange(len(features_intersect)) if features_intersect[ind].startswith(var)]
    r2b4_var_type.append(np.sum(r2b4_shap_means_intersect[inds], axis=0))
    r2b5_var_type.append(np.sum(r2b5_shap_means_intersect[inds], axis=0))
    
r2b4_var_type = np.array(r2b4_var_type)
r2b5_var_type = np.array(r2b5_var_type)

assert r2b4_var_type.shape == (6, r2b4_file_count)
assert r2b5_var_type.shape == (6, r2b5_file_count)

In [12]:
# Errorbars: Maximum possible deviation
# R2B4
r2b4_err_lower = r2b4_var_type[:, 0] - np.min(r2b4_var_type, axis=1)
r2b4_err_upper = np.max(r2b4_var_type, axis=1) - r2b4_var_type[:, 0]

# R2B5
r2b5_err_lower = r2b5_var_type[:, 0] - np.min(r2b5_var_type, axis=1)
r2b5_err_upper = np.max(r2b5_var_type, axis=1) - r2b5_var_type[:, 0]

*SHAP vertical profiles for qi and qv*

In [13]:
inds = [ind for ind in np.arange(len(features_intersect)) if features_intersect[ind].startswith('qv')]

i = 21
qv_r2b4 = []
qv_r2b5 = []

for ind in inds:
    qv_r2b4.append(r2b4_shap_means_intersect[ind])
    qv_r2b5.append(r2b5_shap_means_intersect[ind])
    i += 1

In [14]:
inds = [ind for ind in np.arange(len(features_intersect)) if features_intersect[ind].startswith('qi')]

i = 21
qi_r2b4 = []
qi_r2b5 = []

for ind in inds:
    qi_r2b4.append(r2b4_shap_means_intersect[ind])
    qi_r2b5.append(r2b5_shap_means_intersect[ind])
    i += 1

*SHAP Dependence Plots*

Conditional expectation line is computed over all seeds. <br>
Point cloud is shown only for one seed.

In [15]:
# To provide the Dependence Plot the corresponding 10000 NARVAL samples
r2b4_narval_r2b5_samples = np.load('../shap_values/r2b4_narval_r2b5_samples_layer_32_seed_100_train_samples_10000_narval_samples_10000.npy')
r2b5_narval_r2b5_samples = np.load('../shap_values/r2b5_narval_r2b5_samples_layer_32_seed_100_train_samples_7931_narval_samples_10000-constructed_base_value.npy')

# The 10000 NARVAL samples should be seed-independent
r2b4_narval_samples = ['../shap_values/'+file for file in os.listdir('../shap_values') if file.startswith('r2b4_narval_r2b5_samples')]
r2b5_narval_samples = ['../shap_values/'+file for file in os.listdir('../shap_values') if file.startswith('r2b5_narval_r2b5_samples')]

for i in range(len(r2b4_narval_samples)):
    assert np.all(np.load(r2b4_narval_samples[0]) == np.load(r2b4_narval_samples[i]))
for i in range(len(r2b5_narval_samples)):
    assert np.all(np.load(r2b5_narval_samples[0]) == np.load(r2b5_narval_samples[i]))

In [16]:
# Needed to scale back to native units 

# R2B4 Column-based model
r2b4_feature_means = np.array([2.62572183e-06,2.72625252e-06,2.74939600e-06,3.30840599e-06,6.62808605e-06,1.75788934e-05,4.56026919e-05,1.05190041e-04,2.05702805e-04,3.57870694e-04,5.71860616e-04,8.86342854e-04,1.40607454e-03,2.11394275e-03,2.96908898e-03,3.83956666e-03,4.85640761e-03,6.05059066e-03,7.37936039e-03,8.88779732e-03,1.05374548e-02,1.20163575e-02,1.32316365e-02,1.40249843e-02,1.44862015e-02,1.47169496e-02,1.49353026e-02,4.10339294e-14,1.09916165e-10,5.08967307e-11,9.79269311e-14,7.81782591e-13,1.59702138e-12,1.06302286e-08,1.03287141e-07,2.32342195e-07,4.52571159e-07,9.59800950e-07,2.75292262e-06,5.47922031e-06,6.96345062e-06,7.10544829e-06,8.49121303e-06,1.14876828e-05,1.62598283e-05,2.54900781e-05,3.60999973e-05,3.30096121e-05,1.50384025e-05,3.37482390e-06,9.94423396e-07,3.95924469e-07,2.27436437e-07,1.47661800e-14,4.78581565e-11,6.02759292e-09,7.85422277e-08,3.42838766e-07,1.03181587e-06,2.10645844e-06,2.66487045e-06,2.04870326e-06,1.01504965e-06,4.92335725e-07,2.89430485e-07,1.73665966e-07,6.58006285e-08,1.47246476e-08,2.46884148e-09,2.97776000e-10,2.23559883e-11,1.53999974e-12,9.41478240e-13,7.94546431e-13,5.45907918e-13,2.42190024e-13,1.03934147e-13,3.65539123e-14,1.55304439e-14,1.05358904e-14,2.08525301e+02,2.02078330e+02,1.96095922e+02,1.96231880e+02,2.02855933e+02,2.11649673e+02,2.21128411e+02,2.30533497e+02,2.39352824e+02,2.47584803e+02,2.54661191e+02,2.60677478e+02,2.65862196e+02,2.70101506e+02,2.74180293e+02,2.77942434e+02,2.81490486e+02,2.84592019e+02,2.87187378e+02,2.89183826e+02,2.90680284e+02,2.92060146e+02,2.93733091e+02,2.95405966e+02,2.96851675e+02,2.97902688e+02,2.98445713e+02,4.89214262e+03,6.41523961e+03,8.33266520e+03,1.08240088e+04,1.37001631e+04,1.70600136e+04,2.07632553e+04,2.48067011e+04,2.90956820e+04,3.37499929e+04,3.85120640e+04,4.34081851e+04,4.86049928e+04,5.36237056e+04,5.88910085e+04,6.39346849e+04,6.89441862e+04,7.37692055e+04,7.83390226e+04,8.26346116e+04,8.66219041e+04,9.01095619e+04,9.33174822e+04,9.59079770e+04,9.79586311e+04,9.94225005e+04,1.00212068e+05,8.18133756e-02,1.10740562e-01,1.48196795e-01,1.92120442e-01,2.35155442e-01,2.80640720e-01,3.26932112e-01,3.74686426e-01,4.23310942e-01,4.74706421e-01,5.26574353e-01,5.79757246e-01,6.36319453e-01,6.90718806e-01,7.46907098e-01,7.99472206e-01,8.50742657e-01,8.99725050e-01,9.46097691e-01,9.90226000e-01,1.03168333e+00,1.06722030e+00,1.09810057e+00,1.12163306e+00,1.13970514e+00,1.15248459e+00,1.15929123e+00,1.61339519e+04,1.47406270e+04,1.34213890e+04,1.21742147e+04,1.09970932e+04,9.88816591e+03,8.84571032e+03,7.86812123e+03,6.95389299e+03,6.10160326e+03,5.30990022e+03,4.57749545e+03,3.90316351e+03,3.28574924e+03,2.72418251e+03,2.21749985e+03,1.76487258e+03,1.36564177e+03,1.01936294e+03,7.25868116e+02,4.85363829e+02,2.98613214e+02,1.67518124e+02,9.62005988e+01,2.51278402e-03])
r2b4_feature_stds = np.array([1.47224807e-07,2.51491754e-07,2.82187441e-07,5.93724945e-07,1.51594298e-06,6.30296895e-06,2.22274936e-05,6.22211930e-05,1.43226310e-04,2.86338200e-04,4.85240662e-04,7.58182385e-04,1.13414912e-03,1.54548518e-03,1.89771471e-03,2.12619203e-03,2.33918584e-03,2.55623874e-03,2.79223044e-03,2.96787488e-03,2.98239542e-03,2.88134285e-03,2.89833854e-03,3.00368460e-03,3.09352534e-03,3.13736360e-03,3.16548101e-03,9.42811508e-12,3.44228546e-09,1.59084922e-09,5.93648639e-12,3.42805209e-11,8.49979059e-11,1.42841729e-07,8.71018616e-07,1.61722240e-06,2.73813894e-06,4.98860589e-06,1.09156013e-05,1.92056078e-05,2.58635856e-05,2.67312436e-05,2.88542767e-05,3.27361856e-05,3.80907041e-05,4.95573286e-05,6.64222088e-05,6.35284725e-05,4.02976359e-05,1.60662423e-05,8.31400034e-06,4.98668318e-06,4.18798497e-06,4.72870218e-12,9.65155904e-09,1.53906014e-07,1.30802262e-06,3.95656270e-06,7.01268576e-06,9.21713269e-06,9.56682609e-06,7.50242259e-06,4.54526978e-06,2.74400532e-06,1.68254503e-06,9.94899145e-07,4.90686500e-07,1.74423527e-07,6.26854136e-08,1.35744721e-08,2.72159858e-09,2.39259137e-10,6.37191578e-11,5.72720348e-11,4.81701846e-11,2.14967902e-11,1.01020924e-11,4.65900766e-12,3.62080895e-12,3.08853822e-12,2.14124189e+00,2.04636219e+00,2.28842093e+00,2.18477413e+00,1.60619664e+00,1.23447911e+00,1.36168420e+00,1.59257820e+00,1.77424088e+00,1.85501643e+00,1.85519231e+00,1.74569793e+00,1.50736838e+00,1.37280694e+00,1.30119982e+00,1.30107249e+00,1.42479076e+00,1.62621509e+00,1.88375876e+00,2.13026250e+00,2.32127917e+00,2.48017927e+00,2.58637194e+00,2.59273129e+00,2.57883413e+00,2.58848551e+00,2.63916745e+00,5.25025104e+01,6.13379273e+01,6.86279405e+01,8.39005677e+01,1.08798431e+02,1.33773907e+02,1.50246438e+02,1.57891625e+02,1.59368430e+02,1.59393957e+02,1.63436704e+02,1.78941090e+02,2.18262201e+02,2.74998258e+02,3.62779581e+02,4.62465472e+02,5.85167856e+02,7.29868327e+02,8.80813572e+02,1.05716175e+03,1.23096824e+03,1.39147497e+03,1.57507360e+03,1.69639911e+03,1.79506877e+03,1.86809404e+03,1.89542576e+03,9.42444728e-04,1.10053780e-03,1.74702325e-03,2.61821961e-03,2.72889747e-03,2.27853670e-03,1.76545251e-03,1.59704425e-03,1.94238096e-03,2.47019694e-03,3.01007451e-03,3.43331975e-03,3.85832590e-03,4.47209741e-03,5.18871128e-03,5.98561751e-03,7.30225606e-03,9.21774294e-03,1.16181468e-02,1.47943226e-02,1.80870903e-02,2.07868612e-02,2.28726988e-02,2.36245337e-02,2.39326331e-02,2.39750201e-02,2.37309135e-02,4.37962105e-01,1.18129232e+00,1.95876778e+00,3.13352980e+00,4.84460444e+00,7.25070491e+00,1.05220551e+01,1.48284694e+01,2.03242942e+01,2.71311569e+01,3.53196384e+01,4.48911206e+01,5.57614000e+01,6.77483165e+01,8.05664007e+01,9.38318753e+01,1.07080624e+02,1.19799754e+02,1.31470485e+02,1.41617295e+02,1.49856634e+02,1.55939200e+02,1.60153490e+02,1.61897063e+02,1.10524044e-02])

# R2B5 Column-based (fold 2) model
r2b5_feature_means = np.array([2.57681365e-06,2.60161901e-06,2.86229890e-06,3.49524686e-06,6.32444387e-06,1.62852938e-05,4.26197236e-05,1.00492283e-04,2.10850387e-04,3.96992495e-04,6.62768743e-04,1.00639902e-03,1.42273038e-03,1.89269379e-03,2.42406883e-03,2.97704256e-03,3.52303812e-03,4.15430913e-03,4.89285256e-03,5.71192194e-03,6.58451740e-03,7.47955824e-03,8.42949837e-03,9.18162558e-03,9.58900058e-03,9.80246788e-03,9.98071441e-03,2.57897497e-16,1.24502901e-08,5.43912468e-07,1.97554777e-06,2.10205332e-06,3.45718981e-06,4.17987790e-06,4.89876027e-06,6.03250921e-06,6.71487544e-06,7.71281746e-06,9.96528417e-06,1.40351017e-05,1.87534642e-05,2.15523809e-05,1.77725032e-05,1.10700238e-05,6.98113679e-06,5.98240074e-06,8.03857856e-06,1.55278994e-05,1.98903187e-13,1.45240003e-10,2.39426913e-08,5.63226688e-07,3.10209365e-06,6.64324795e-06,8.83422658e-06,9.89681102e-06,9.97096463e-06,7.74324652e-06,4.95774608e-06,2.61087000e-06,1.29680563e-06,7.46596833e-07,4.94444102e-07,3.51674311e-07,2.61199355e-07,2.03219747e-07,1.66907845e-07,1.42871199e-07,1.25114261e-07,1.11956533e-07,1.02782118e-07,9.86031894e-08,9.95790399e-08,1.06733810e-07,1.26921172e-07,2.10924633e+02,2.07944695e+02,2.05115507e+02,2.03204784e+02,2.06103772e+02,2.12329817e+02,2.19299382e+02,2.26348890e+02,2.33352039e+02,2.40105681e+02,2.46401637e+02,2.52153555e+02,2.57207037e+02,2.61575645e+02,2.65446543e+02,2.68951996e+02,2.72093136e+02,2.74765728e+02,2.76963041e+02,2.78775116e+02,2.80398659e+02,2.81959850e+02,2.83501227e+02,2.84935364e+02,2.86119192e+02,2.86867707e+02,2.87046277e+02,4.78805278e+03,6.25615004e+03,8.06726288e+03,1.03500805e+04,1.30603494e+04,1.61944127e+04,1.97232230e+04,2.36181577e+04,2.78401230e+04,3.23377105e+04,3.70511232e+04,4.19785078e+04,4.70365400e+04,5.21124420e+04,5.72512536e+04,6.23517142e+04,6.72989145e+04,7.20972394e+04,7.66740332e+04,8.09510300e+04,8.49437983e+04,8.85136468e+04,9.16490946e+04,9.42529147e+04,9.63348759e+04,9.77633315e+04,9.86144363e+04,1.61343240e+04,1.47416307e+04,1.34230525e+04,1.21768751e+04,1.10012039e+04,9.89431495e+03,8.85470770e+03,7.88104473e+03,6.97198713e+03,6.12617252e+03,5.34218664e+03,4.61854836e+03,3.95376191e+03,3.34629894e+03,2.79465640e+03,2.29750295e+03,1.85381761e+03,1.46282067e+03,1.12390793e+03,8.36771545e+02,6.01482480e+02,4.18667943e+02,2.90324051e+02,2.20122534e+02,2.57179068e-01])
r2b5_feature_stds = np.array([1.66577356e-07,2.69438906e-07,6.32166532e-07,1.46870734e-06,2.84939866e-06,8.72797379e-06,2.96195352e-05,8.32385500e-05,1.93655438e-04,3.82345501e-04,6.27888913e-04,9.30858552e-04,1.27418047e-03,1.61904466e-03,1.95753088e-03,2.23604988e-03,2.49372225e-03,2.83062031e-03,3.22013981e-03,3.62381600e-03,4.05060687e-03,4.53912094e-03,5.14120557e-03,5.61150119e-03,5.82135854e-03,5.92232391e-03,6.02114792e-03,1.93770206e-12,1.94386132e-07,2.96883744e-06,8.75974976e-06,1.02724976e-05,1.44929996e-05,1.65663508e-05,1.81326398e-05,2.09805520e-05,2.41254125e-05,2.82129201e-05,3.57128254e-05,4.76374494e-05,5.95853155e-05,6.60615445e-05,5.80449728e-05,4.41472861e-05,3.63224833e-05,3.84500230e-05,5.48299167e-05,1.01230094e-04,2.84426774e-10,8.67755936e-08,2.57513880e-06,8.53816046e-06,1.97356234e-05,2.80242488e-05,3.05461589e-05,3.17141059e-05,3.22684724e-05,2.65101493e-05,1.83177779e-05,1.05168506e-05,6.02113023e-06,4.48008643e-06,3.22093921e-06,2.22409748e-06,1.62946826e-06,1.31793070e-06,1.14119306e-06,1.02620335e-06,9.44881472e-07,8.91631794e-07,8.56775098e-07,8.35271824e-07,8.19217124e-07,8.09879379e-07,8.37114763e-07,4.61938080e+00,5.32560366e+00,6.59828260e+00,8.11723979e+00,6.41348334e+00,3.53782199e+00,3.22568870e+00,5.51353694e+00,7.69704358e+00,9.19195458e+00,1.00629480e+01,1.04422426e+01,1.05152774e+01,1.05114012e+01,1.05471048e+01,1.07545816e+01,1.10865116e+01,1.13965606e+01,1.17078707e+01,1.20696061e+01,1.24923716e+01,1.28886157e+01,1.32697644e+01,1.37077879e+01,1.42377515e+01,1.47618886e+01,1.53836576e+01,1.65668010e+02,2.26139751e+02,3.23181791e+02,4.81891512e+02,6.96129876e+02,9.23059045e+02,1.13288883e+03,1.30878137e+03,1.44327690e+03,1.53977206e+03,1.61213693e+03,1.68432353e+03,1.78325130e+03,1.93151619e+03,2.15397041e+03,2.45262514e+03,2.80472214e+03,3.20634251e+03,3.63386944e+03,4.07078962e+03,4.52036282e+03,4.94232873e+03,5.33456701e+03,5.66290433e+03,5.92823270e+03,6.11058960e+03,6.23266007e+03,1.66988637e+00,4.50375687e+00,7.46700178e+00,1.19425370e+01,1.84570049e+01,2.76286540e+01,4.08392722e+01,6.01544270e+01,8.66284809e+01,1.20094159e+02,1.59157271e+02,2.01903044e+02,2.46894999e+02,2.92578673e+02,3.37234060e+02,3.79561515e+02,4.19219129e+02,4.56022726e+02,4.89322766e+02,5.18456198e+02,5.42710799e+02,5.61319067e+02,5.74407679e+02,5.79710837e+02,4.23033734e-01])

In [17]:
def conditional_line(shap_values, narval_samples, feature_names, feature, eps = 1e-4):
    '''
        shap_values: nd_array, indexed by [L][M, N]
        narval_samples: nd_array, indexed by [M, N]
        feature_names: nd_array of strings, indexed by [N]
        feature: String that is in feature_names
    '''
    feature_ind = np.where(feature_names=='%s_32'%feature)[0][0]
    
    xvals = narval_samples[:, feature_ind]
    yvals = np.mean(np.array(shap_values), axis=0)[:, feature_ind] # Average of shap values over all seeds
    
    k_max = int(np.floor(max(xvals)/eps))
    
    b = []
    for k in range(k_max): # Stop after we reached the maximum value for x
        b.append(np.mean([yvals[i] for i in range(len(yvals)) if k*eps <= xvals[i] < (k+1)*eps])) # Basically using bins here
        
    # Corresponding x-values
#     a = eps*np.arange(k_max) + eps/2
    a = eps*np.arange(k_max)
    
    # We have nans if there are no points in k*eps <= xvals[i] < (k+1)*eps. We simply remove these
    a_new = [a[i] for i in range(len(b)) if ~np.isnan(b[i])]
    b_new = [b[i] for i in range(len(b)) if ~np.isnan(b[i])]
    
    # We use a spline of degree 3 to draw a smooth line
    xnew = np.linspace(min(a_new), max(a_new), 200) 
    spl = make_interp_spline(a_new, b_new, k=3)
    y_smooth = spl(xnew)
    
    return xnew, y_smooth

In [18]:
fig = plt.figure(figsize=(18,5))
# plt.subplots_adjust(bottom=0.1)

label_size=20

## First plot
ax = fig.add_subplot(131)
# Scale back to native units
feature_ind_r2b4 = np.where(r2b4_feature_names=='qi_32')[0][0]
r2b4_narval_r2b5_samples_qi = r2b4_narval_r2b5_samples*r2b4_feature_stds[feature_ind_r2b4] + r2b4_feature_means[feature_ind_r2b4]
feature_ind_r2b5 = np.where(r2b5_feature_names=='qi_32')[0][0]
r2b5_narval_r2b5_samples_qi = r2b5_narval_r2b5_samples*r2b5_feature_stds[feature_ind_r2b5] + r2b5_feature_means[feature_ind_r2b5]

# The narval samples should be the same in their original unnormalized space
assert np.all(np.abs(r2b5_narval_r2b5_samples_qi[:, feature_ind_r2b5] - r2b4_narval_r2b5_samples_qi[:, feature_ind_r2b4]) < 1e-10)

# Average SHAP values
r2b4_mean = np.mean(np.array(r2b4_shap_values)[:, :, feature_ind_r2b4], dtype=np.float64)
r2b5_mean = np.mean(np.array(r2b5_shap_values)[:, :, feature_ind_r2b5], dtype=np.float64)

# Put the one with the larger range second
sdp = shap.dependence_plot(feature_ind_r2b4, r2b4_shap_values[0], features=r2b4_narval_r2b5_samples_qi, ax=ax,
                     feature_names=r2b4_feature_names, interaction_index=None, show=False, color='blue', dot_size=5, alpha=0.6) 
sdp = shap.dependence_plot(feature_ind_r2b5, r2b5_shap_values[0], features=r2b5_narval_r2b5_samples_qi, ax=ax,
                     feature_names=r2b5_feature_names, interaction_index=None, show=False, color='orange', dot_size=5, alpha=0.6) 
# It's the same as:
# sdp = shap.dependence_plot(0, r2b5_shap_values[0][:, feature_ind_r2b5:(feature_ind_r2b5+1)], features=r2b5_narval_r2b5_samples_qi[:, feature_ind_r2b5:(feature_ind_r2b5+1)], 
#                              ax=ax, feature_names=r2b5_feature_names, interaction_index=None, show=False, color='orange', dot_size=5)

# Plot showing averages
qi_min = np.min(r2b4_narval_r2b5_samples_qi[:, feature_ind_r2b4])
qi_max = np.max(r2b4_narval_r2b5_samples_qi[:, feature_ind_r2b4])
plt.plot([qi_min, qi_max], [r2b4_mean, r2b4_mean], 'b--', linewidth=1.5)
plt.plot([qi_min, qi_max], [r2b5_mean, r2b5_mean], color='orange', linestyle='--', linewidth=1.5)
# Legend
ax.annotate('NARVAL R2B4 model', xy=(0.5,0.84),xycoords='axes fraction', color='blue', fontsize=14)
ax.annotate('QUBICC R2B5 model', xy=(0.5,0.9),xycoords='axes fraction', color='orange', fontsize=14)
plt.xlabel('$q_i$_32 [kg/kg]', fontsize=label_size)
plt.ylabel('SHAP values for clc_32', fontsize=label_size)

# Conditional averages. The choice of eps has a large influence on the plot
m = 50
xnew, y_smooth = conditional_line(r2b4_shap_values, r2b4_narval_r2b5_samples_qi, r2b4_feature_names, 'qi', eps = 5*1e-6) # eps = 2*1e-5
ax.plot(xnew[:m], y_smooth[:m], linewidth=4)
xnew, y_smooth = conditional_line(r2b5_shap_values, r2b5_narval_r2b5_samples_qi, r2b5_feature_names, 'qi', eps = 5*1e-6)
ax.plot(xnew[:m], y_smooth[:m], linewidth=4)

plt.ylim((-6.279482202575369, 87.43461904261041)) # Taken from the qv plot

## Second plot
ax_2 = fig.add_subplot(132)
# Scale back to native units
feature_ind_r2b4 = np.where(r2b4_feature_names=='qv_32')[0][0]
r2b4_narval_r2b5_samples_qv = r2b4_narval_r2b5_samples*r2b4_feature_stds[feature_ind_r2b4] + r2b4_feature_means[feature_ind_r2b4]
feature_ind_r2b5 = np.where(r2b5_feature_names=='qv_32')[0][0]
r2b5_narval_r2b5_samples_qv = r2b5_narval_r2b5_samples*r2b5_feature_stds[feature_ind_r2b5] + r2b5_feature_means[feature_ind_r2b5]

# The narval samples should be the same in their original unnormalized space
assert np.all(np.abs(r2b5_narval_r2b5_samples_qv[:, feature_ind_r2b5] - r2b4_narval_r2b5_samples_qv[:, feature_ind_r2b4]) < 1e-10)

# Average SHAP values
r2b4_mean = np.mean(np.array(r2b4_shap_values)[:, :, feature_ind_r2b4], dtype=np.float64)
r2b5_mean = np.mean(np.array(r2b5_shap_values)[:, :, feature_ind_r2b5], dtype=np.float64)

# Put the one with the larger range second
sdp_2 = shap.dependence_plot(feature_ind_r2b4, r2b4_shap_values[0], features=r2b4_narval_r2b5_samples_qv, ax=ax_2,
                     feature_names=r2b4_feature_names, interaction_index=None, show=False, color='blue', dot_size=5, alpha=0.6)
sdp_2 = shap.dependence_plot(feature_ind_r2b5, r2b5_shap_values[0], features=r2b5_narval_r2b5_samples_qv, ax=ax_2,
                     feature_names=r2b5_feature_names, interaction_index=None, show=False, color='orange', dot_size=5, alpha=0.6)

# Plot showing averages
qv_min = np.min(r2b4_narval_r2b5_samples_qv[:, feature_ind_r2b4])
qv_max = np.max(r2b4_narval_r2b5_samples_qv[:, feature_ind_r2b4])
plt.plot([qv_min, qv_max], [r2b4_mean, r2b4_mean], 'b--', linewidth=1.5)
plt.plot([qv_min, qv_max], [r2b5_mean, r2b5_mean], color='orange', linestyle='--', linewidth=1.5)

# Conditional averages. The choice of eps has a large influence on the plot
xnew, y_smooth = conditional_line(r2b4_shap_values, r2b4_narval_r2b5_samples_qv, r2b4_feature_names, 'qv', eps = 4*1e-4)
ax_2.plot(xnew, y_smooth, linewidth=4)
xnew, y_smooth = conditional_line(r2b5_shap_values, r2b5_narval_r2b5_samples_qv, r2b5_feature_names, 'qv', eps = 4*1e-4)
ax_2.plot(xnew, y_smooth, linewidth=4)

plt.gca().ticklabel_format(axis='x', style='sci', scilimits=(-2,2))
# ax_2.xaxis.set_major_formatter(FormatStrFormatter('%E'))
plt.xlabel('$q_v$_32 [kg/kg]', fontsize=label_size)
plt.ylabel(' ')

# plt.savefig('figures/shap_dependence_plots.pdf')

Mean of empty slice.
invalid value encountered in double_scalars
Mean of empty slice.
invalid value encountered in double_scalars


Text(0, 0.5, ' ')

In [19]:
# plt.plot(qv_diffs, np.arange(27, 27+len(qi_diffs)), 'bo')
line_r2b5 = plt.plot(qv_r2b5, np.arange(21, 21+len(qv_r2b5)), '.', color='orange')
line_r2b4 = plt.plot(qv_r2b4, np.arange(21, 21+len(qv_r2b5)), '.', color='blue')
plt.ylabel('Vertical layer')
plt.xlabel('SHAP value for qv')
plt.title('qv')
# plt.xlabel('SHAP value difference')
plt.legend([line_r2b5[0], line_r2b4[0]], ['R2B5 QUBICC model', 'R2B4 NARVAL model'])
plt.grid(b=True)
# plt.legend(['NARVAL - QUBICC qi'])
plt.gca().invert_yaxis()

In [20]:
# plt.plot(qi_diffs, np.arange(27, 27+len(qi_diffs)), 'bo')
line_r2b5 = plt.plot(qi_r2b5, np.arange(21, 21+len(qi_r2b5)), '.', color='orange')
line_r2b4 = plt.plot(qi_r2b4, np.arange(21, 21+len(qi_r2b5)), '.', color='blue')
plt.ylabel('Vertical layer')
plt.xlabel('SHAP value for qi')
plt.title('qi')
# plt.xlabel('SHAP value difference')
plt.legend([line_r2b5[0], line_r2b4[0]], ['R2B5 QUBICC model', 'R2B4 NARVAL model'])
plt.grid(b=True)
# plt.legend(['NARVAL - QUBICC qi'])
plt.gca().invert_yaxis()

In [21]:
# Bars show minimum and maximum value
x_labels = ['qv','qi', 'qc', 'zg', 'pres','temp']
x = np.arange(len(x_labels)) # Label locations!
width = 0.4

fig = plt.figure()

ax = fig.add_subplot(111, ylabel='Sum of SHAP values', title='Contributions to clc_32 from the entire column')

ax.axhline(np.sum(r2b5_var_type[:, 0]), xmin=0, xmax=1, color='orange', linewidth=1)
ax.bar(np.arange(len(x_labels))-width/2, r2b5_var_type[:, 0], width=width, color='orange', yerr=np.array([r2b5_err_lower, r2b5_err_upper]),\
       align='center', alpha=0.5, ecolor='black', capsize=5)
ax.bar(np.arange(len(x_labels))+width/2, r2b4_var_type[:, 0], width=width, color='blue', yerr=np.array([r2b4_err_lower, r2b4_err_upper]),\
       align='center', alpha=0.5, ecolor='black', capsize=5)

ax.set_xticks(x)
ax.set_xticklabels(x_labels)
ax.legend(['R2B5 QUBICC model bias', 'R2B5 QUBICC model', 'R2B4 NARVAL model'])

ax.axhline(0, xmin=0, xmax=1, color='gray', linewidth=.5, ls='--')

# plt.savefig('figures/shap_clc_32_cumulative.pdf', bbox_inches='tight')

<matplotlib.lines.Line2D at 0x2ad5a26239d0>

**All in one plot**

In [22]:
# matplotlib.rcParams # To see all parameters of matplotlib

In [24]:
import matplotlib

# Increase the general font size in plots
size_plots_label = 22
matplotlib.rcParams['legend.fontsize'] = size_plots_label
matplotlib.rcParams['axes.labelsize'] = size_plots_label # For an axes xlabel and ylabel
matplotlib.rcParams['axes.titlesize'] = size_plots_label+2 # For an axes xlabel and ylabel
matplotlib.rcParams['xtick.labelsize'] = size_plots_label
matplotlib.rcParams['ytick.labelsize'] = size_plots_label

# Averaged over the NARVAL region
zg_mean_narval = [20785,19153,17604,16134,14741,13422,12175,10998,9890,8848,
                  7871,6958,6107,5317,4587,3915,3300,2741,2237,1787,1390,1046,
                  754,515,329,199,128] # in meters

zg_mean_narval = np.round(np.array(zg_mean_narval)/1000, decimals=1) # in kilometers

# Averaged globally
zg_mean_qubicc = [20785,19153,17604,16134,14742,13424,12178,11002,9896,8857,
                  7885,6977,6133,5351,4630,3968,3363,2814,2320,1878,1490,1153,
                  867,634,452,324,254] # in meters
zg_mean_qubicc = np.round(np.array(zg_mean_qubicc)/1000, decimals=1) # in kilometers

green='#004D40'
red='#D81B60'
blue='#1E88E5'

In [40]:
fig = plt.figure(figsize=(30,11))
# plt.subplots_adjust(bottom=0.1)

# # Increase the general font size
# matplotlib.rcParams['legend.fontsize'] = 'x-large'
# matplotlib.rcParams['axes.labelsize'] = 'xx-large' # For an axes xlabel and ylabel
# matplotlib.rcParams['xtick.labelsize'] = 'xx-large'
# matplotlib.rcParams['ytick.labelsize'] = 'xx-large'

# label_size=20 # For the dependence plots

## First plot
ax1 = fig.add_subplot(121, ylabel='$\Sigma$(SHAP values) / |Samples|')
# Bars show minimum and maximum value
x_labels = ['$q_v$','$q_i$', '$q_c$', '$z_g$', '$p$','$T$']
x = np.arange(len(x_labels)) # Label locations!
width = 0.4
ax1.axhline(np.sum(r2b5_var_type[:, 0]), xmin=0, xmax=1, color='orange', linewidth=1)
ax1.bar(np.arange(len(x_labels))-width/2, r2b5_var_type[:, 0], width=width, color='orange', yerr=np.array([r2b5_err_lower, r2b5_err_upper]),\
       align='center', alpha=0.5, ecolor='black', capsize=5)
ax1.bar(np.arange(len(x_labels))+width/2, r2b4_var_type[:, 0], width=width, color='blue', yerr=np.array([r2b4_err_lower, r2b4_err_upper]),\
       align='center', alpha=0.5, ecolor='black', capsize=5)
ax1.set_xticks(x)
ax1.set_xticklabels(x_labels)
ax1.set_title(r'$\bf{(a)}$ Summed SHAP values from the entire grid column ', fontsize=size_plots_label, pad=12)
ax1.legend(['R2B5 QUBICC model bias', 'R2B5 QUBICC model', 'R2B4 NARVAL model'])
ax1.axhline(0, xmin=0, xmax=1, color='gray', linewidth=.5, ls='--')

## Second plot
ax2 = fig.add_subplot(243, ylabel='Vertical layer')
# plt.plot(qv_diffs, np.arange(27, 27+len(qi_diffs)), 'bo')
line_r2b5 = ax2.plot(qv_r2b5, np.arange(21, 21+len(qv_r2b5)), '.', color='orange')
line_r2b4 = ax2.plot(qv_r2b4, np.arange(21, 21+len(qv_r2b5)), '.', color='blue')
# plt.xlabel('SHAP value difference')
ax2.legend([line_r2b5[0], line_r2b4[0]], ['R2B5 QUBICC model', 'R2B4 NARVAL model'], markerscale=3) # markerscale makes dots larger/readable in the legend!
ax2.grid(b=True)
ax2.set_title(r'$\bf{(b)}$ Mean SHAP values of $q_v$ per layer', fontsize=size_plots_label, pad=12)
# plt.legend(['NARVAL - QUBICC qi'])
plt.gca().invert_yaxis()

## Third plot
ax3 = fig.add_subplot(244)
# plt.plot(qi_diffs, np.arange(27, 27+len(qi_diffs)), 'bo')
line_r2b5 = ax3.plot(qi_r2b5, np.arange(21, 21+len(qi_r2b5)), '.', color='orange')
line_r2b4 = ax3.plot(qi_r2b4, np.arange(21, 21+len(qi_r2b5)), '.', color='blue')
# plt.xlabel('SHAP value difference')
# ax3.legend([line_r2b5[0], line_r2b4[0]], ['R2B5 QUBICC model', 'R2B4 NARVAL model'])
ax3.grid(b=True)
ax3.set_title(r'$\bf{(c)}$ Mean SHAP values of $q_i$ per layer', fontsize=size_plots_label, pad=12)
# plt.legend(['NARVAL - QUBICC qi'])
plt.gca().invert_yaxis()

## Forth plot
ax4 = fig.add_subplot(247)
# Scale back to native units
feature_ind_r2b4 = np.where(r2b4_feature_names=='qv_32')[0][0]
r2b4_narval_r2b5_samples_qv = r2b4_narval_r2b5_samples*r2b4_feature_stds[feature_ind_r2b4] + r2b4_feature_means[feature_ind_r2b4]
feature_ind_r2b5 = np.where(r2b5_feature_names=='qv_32')[0][0]
r2b5_narval_r2b5_samples_qv = r2b5_narval_r2b5_samples*r2b5_feature_stds[feature_ind_r2b5] + r2b5_feature_means[feature_ind_r2b5]
# Average SHAP values
r2b4_mean = np.mean(np.array(r2b4_shap_values)[:, :, feature_ind_r2b4], dtype=np.float64)
r2b5_mean = np.mean(np.array(r2b5_shap_values)[:, :, feature_ind_r2b5], dtype=np.float64)

# Put the one with the larger range second
sdp_2 = shap.dependence_plot(feature_ind_r2b4, r2b4_shap_values[0], features=r2b4_narval_r2b5_samples_qv, ax=ax4,
                     feature_names=r2b4_feature_names, interaction_index=None, show=False, color='blue', dot_size=5, alpha=0.7)
sdp_2 = shap.dependence_plot(feature_ind_r2b5, r2b5_shap_values[0], features=r2b5_narval_r2b5_samples_qv, ax=ax4,
                     feature_names=r2b5_feature_names, interaction_index=None, show=False, color='orange', dot_size=5, alpha=0.7)

# Plot showing averages
qv_min = np.min(r2b4_narval_r2b5_samples_qv[:, feature_ind_r2b4])
qv_max = np.max(r2b4_narval_r2b5_samples_qv[:, feature_ind_r2b4])
plt.plot([qv_min, qv_max], [r2b4_mean, r2b4_mean], 'b--', linewidth=1.5)
plt.plot([qv_min, qv_max], [r2b5_mean, r2b5_mean], color='orange', linestyle='--', linewidth=1.5)

# Conditional averages. The choice of eps has a large influence on the plot
xnew, y_smooth = conditional_line(r2b4_shap_values, r2b4_narval_r2b5_samples_qv, r2b4_feature_names, 'qv', eps = 4*1e-4)
ax4.plot(xnew, y_smooth, linewidth=4)
xnew, y_smooth = conditional_line(r2b5_shap_values, r2b5_narval_r2b5_samples_qv, r2b5_feature_names, 'qv', eps = 4*1e-4)
ax4.plot(xnew, y_smooth, linewidth=4)

plt.gca().ticklabel_format(axis='x', style='sci', scilimits=(-2,2))
# ax_2.xaxis.set_major_formatter(FormatStrFormatter('%E'))
# Legend
ax4.annotate(r'$\bf{(d)}$', xy=(0.1,0.84),xycoords='axes fraction', fontsize=size_plots_label)
# ax4.annotate('NARVAL R2B4 model', xy=(0.1,0.84),xycoords='axes fraction', color='blue', fontsize=14)
# ax4.annotate('QUBICC R2B5 model', xy=(0.1,0.9),xycoords='axes fraction', color='orange', fontsize=14)
ax4.set_xlabel('$q_v$_32 [kg/kg]', fontsize=size_plots_label)
ax4.set_ylabel('SHAP value', fontsize=size_plots_label)
ax4.tick_params(labelsize=size_plots_label)
qv_ylim = plt.ylim()

## Fifth plot
ax5 = fig.add_subplot(248)
# Scale back to native units
feature_ind_r2b4 = np.where(r2b4_feature_names=='qi_32')[0][0]
r2b4_narval_r2b5_samples_qi = r2b4_narval_r2b5_samples*r2b4_feature_stds[feature_ind_r2b4] + r2b4_feature_means[feature_ind_r2b4]
feature_ind_r2b5 = np.where(r2b5_feature_names=='qi_32')[0][0]
r2b5_narval_r2b5_samples_qi = r2b5_narval_r2b5_samples*r2b5_feature_stds[feature_ind_r2b5] + r2b5_feature_means[feature_ind_r2b5]
# Average SHAP values
r2b4_mean = np.mean(np.array(r2b4_shap_values)[:, :, feature_ind_r2b4], dtype=np.float64)
r2b5_mean = np.mean(np.array(r2b5_shap_values)[:, :, feature_ind_r2b5], dtype=np.float64)

# Put the one with the larger range second
sdp = shap.dependence_plot(feature_ind_r2b4, r2b4_shap_values[0], features=r2b4_narval_r2b5_samples_qi, ax=ax5,
                     feature_names=r2b4_feature_names, interaction_index=None, show=False, color='blue', dot_size=5, alpha=0.7, xmax="percentile(99.85)") # Better to cut off at a high percentile
sdp = shap.dependence_plot(feature_ind_r2b5, r2b5_shap_values[0], features=r2b5_narval_r2b5_samples_qi, ax=ax5,
                     feature_names=r2b5_feature_names, interaction_index=None, show=False, color='orange', dot_size=5, alpha=0.7, xmax="percentile(99.85)") 

# Plot showing averages
qi_min = np.min(r2b4_narval_r2b5_samples_qi[:, feature_ind_r2b4])
qi_max = np.max(r2b4_narval_r2b5_samples_qi[:, feature_ind_r2b4])
ax5.plot([qi_min, qi_max], [r2b4_mean, r2b4_mean], 'b--', linewidth=1.5)
ax5.plot([qi_min, qi_max], [r2b5_mean, r2b5_mean], color='orange', linestyle='--', linewidth=1.5) 

# Conditional averages. The choice of eps has a large influence on the plot
m = 50
xnew, y_smooth = conditional_line(r2b4_shap_values, r2b4_narval_r2b5_samples_qi, r2b4_feature_names, 'qi', eps = 8*1e-6) # eps = 2*1e-5
ax5.plot(xnew[:m], y_smooth[:m], linewidth=4)
xnew, y_smooth = conditional_line(r2b5_shap_values, r2b5_narval_r2b5_samples_qi, r2b5_feature_names, 'qi', eps = 8*1e-6)
ax5.plot(xnew[:m], y_smooth[:m], linewidth=4)

# Legend
ax5.annotate(r'$\bf{(e)}$', xy=(0.7,0.84),xycoords='axes fraction', fontsize=size_plots_label)
ax5.set_xlabel('$q_i$_32 [kg/kg]', fontsize=size_plots_label)
ax5.set_ylabel('', fontsize=size_plots_label)
ax5.tick_params(labelsize=size_plots_label)
plt.ylim(qv_ylim)

plt.savefig('figures/shap_clc_32_all_plots.pdf')
# plt.show()

Mean of empty slice.
invalid value encountered in double_scalars
Mean of empty slice.
invalid value encountered in double_scalars


All subplots pertain only SHAP values for clc_32!

For each input feature, the SHAP values were first averaged over all 10000 NARVAL R02B05 samples.
If we now focus on a specific feature, such as q_v, then we can draw a plot like the one shown in **b)**.
We do this for every seed, so we can see some error bars in **(b)**.
If we sum up all SHAP values shown in **(b)**, we receive the first bar in plot **(a)**, including the error bars.

Now why is the Model bias the sum of all SHAP values divided by the number of samples? <br>
-> Well it's because we want to explain the average bias of the model on layer 32, not the one produced if we multiply the bias by the number of parameters!

**d) and e):** <br>
Each dot shows one sample: A pair of input feature and corresponding SHAP value 
Thick lines show the average SHAP values conditioned on small bins (respectively cover around 1/10 of the size of the range of input values).
We do not plot it for qi_32 > 2 kg/kg due to the low density of points there. The NARVAL samples we evaluate SHAP on are exactly the same for the QUBICC and the NARVAL model.