# Analysis of model explainability using SHAP

This notebook contains code to analyze the predictions of the ML4ShipTelemetry models. [SHAP](https://github.com/shap/shap/tree/master) is used for explaining the impact of the input features on the model output.

Both quality flag classification and the regression models are analyzed.

The primary goal is to identify which features play the strongest role in the models, and to gauge whether the models have overfitted to unexpected features or use features that domain experts consider reasonable.



As a bonus, we create interactive maps of the data using [KeplerGL](https://kepler.gl/).

In [1]:
import numpy as np
import pandas as pd
import joblib
import matplotlib.pyplot as plt
import shap
import os

## Configurations

In [2]:
# Raise pandas' display limits (rows, columns, total character width)
for option_name, option_value in (
    ('display.max_rows', 500),
    ('display.max_columns', 100),
    ('display.width', 1000),
):
    pd.set_option(option_name, option_value)

In [3]:
# Locations of the stored models and data, plus export and sampling settings

# Root of the ml4shiptelemetry package: */ml4shiptelemetry/
base_path = '/path/to/ml4shiptelemetry/'

# Folder that SHAP files are written to and read from, e.g. one folder above ml4shiptelemetry.
shap_path = '/path/to/shap_folder/'

# Data file (.npz format) created by running main(). Relative or absolute path.
# NOTE: despite the name, this points to a file, not a directory.
data_dir = os.path.join(base_path, 'data', 'proc_files_n_test_files_3_n_neighbours_0.npz')

# Trained classification models (.pkl format) created by running main(). Relative or absolute path.
model_path = os.path.join(base_path, 'output', 'model_classification_n_neighbours_0.pkl')

# Trained regression models (.pkl format) created by running main(). Relative or absolute path.
model_reg_path = os.path.join(base_path, 'output', 'model_regression_n_neighbours_0.pkl')

# Figure export: on/off switch and the file extension appended to every figure name
export_figs, export_figs_format = True, '.png'

# Seeded generator so that the random sampling below is deterministic
rng = np.random.default_rng(seed=1011)

## Load data and models

Load data and models and perform some simple postprocessing to classify the predictions into true positives, etc.

In [4]:
# Load the preprocessed arrays from the .npz archive.
# For .npz archives np.load returns an NpzFile that keeps the underlying file
# open and reads each array on access. The context manager closes the file as
# soon as every array has been pulled into memory.
# After this block `data` refers to the closed archive and must not be indexed again.
with np.load(data_dir) as data:
    # Training split: features, regression targets, classification targets
    x = data['x']
    y_reg = data['y_reg']
    y_class = data['y_class']
    # Test split, same layout as the training split
    x_test = data['x_test']
    y_test_reg = data['y_test_reg']
    y_test_class = data['y_test_class']
    # Names of the target columns and of the feature columns
    targets_reg = data['targets_reg']
    targets_class = data['targets_class']
    feature_names = data['feature_names'].tolist()

In [5]:
def _to_frame(features, labels_class, labels_reg, is_train):
    """Build one dataframe holding features, class targets, regression targets and a split marker.

    Columns are added in this order: `feature_names`, the names in `targets_class`,
    the names in `targets_reg`, and finally the boolean column 'train'.
    """
    frame = pd.DataFrame(data=features, columns=feature_names)
    frame[targets_class.tolist()] = labels_class
    frame[targets_reg.tolist()] = labels_reg
    frame['train'] = is_train
    return frame


# One frame per split
df_train = _to_frame(x, y_class, y_reg, True)
df_test = _to_frame(x_test, y_test_class, y_test_reg, False)

# Training rows followed by test rows in a single frame with a fresh 0..n-1 index
dff = pd.concat([df_train, df_test], axis=0, ignore_index=True)

In [None]:
# Load the trained models from disk.
# NOTE: joblib.load unpickles arbitrary objects; only point it at files you created yourself.

# Classification: one entry per quality flag; the fitted estimator is the
# `.classifier` attribute of the object stored under the 'classifier' key.
models = joblib.load(model_path)
model_temp, model_sal = (
    models[flag_name]['classifier'].classifier for flag_name in ('Temp_Flag', 'Sal_Flag')
)

# Regression: the fitted estimator is the `.regressor` attribute of the stored object.
model_reg = joblib.load(model_reg_path).regressor


In [None]:
# Calculate additional performance metrics for the classification models.
# For each quality flag a dataframe with the test features, the true label, the
# prediction and the prediction type (tp/fp/fn/tn) is stored in `df_dict`.

# Column positions of the two flags inside y_test_class
temp_ind = targets_class.tolist().index('Temp_Flag')
sal_ind = targets_class.tolist().index('Sal_Flag')

flag_models = {
    'Temp_Flag': (model_temp, temp_ind),
    'Sal_Flag': (model_sal, sal_ind),
}

df_dict = {}
for flag, (model, ind) in flag_models.items():
    # True label. Kept under its own name so that the training labels `y_class`
    # loaded earlier are not overwritten (re-running earlier cells keeps working).
    y_true = y_test_class[:, ind]

    # Predicted label and predicted probability.
    # NOTE(review): column 1 of predict_proba is taken as the probability of
    # class 1 - this assumes model.classes_ == [0, 1]; confirm.
    y_pred_proba = model.predict_proba(x_test)[:, 1]
    y_pred = model.predict(x_test)

    # Classify predictions into true positives, false positives, ...
    # A *positive* prediction is y_pred == 1. A false positive is therefore a
    # wrong prediction of 1 and a false negative a wrong prediction of 0.
    tp = (y_pred == 1) & (y_pred == y_true)
    fp = (y_pred == 1) & (y_pred != y_true)
    fn = (y_pred == 0) & (y_pred != y_true)
    tn = (y_pred == 0) & (y_pred == y_true)

    # Absolute difference between the true label and the predicted probability
    prob_diff = np.abs(y_true - y_pred_proba)

    # Build dataframe with test data as well as the prediction metrics
    dfe = pd.DataFrame(data=x_test, columns=feature_names)
    dfe['true_class'] = y_true
    dfe['pred'] = y_pred
    dfe['prob'] = y_pred_proba
    dfe['prob_diff'] = prob_diff
    dfe['tp'] = tp
    dfe['fp'] = fp
    dfe['fn'] = fn
    dfe['tn'] = tn
    # Single integer code per row: fp -> 1, tp -> 2, fn -> -1, tn -> -2
    dfe['pred_type'] = (1*dfe['fp'] + 2*dfe['tp'] - 1*dfe['fn'] - 2*dfe['tn']).astype(int)
    df_dict[flag] = dfe

## Create KeplerGL map

In [8]:
# (Optional) Export an interactive KeplerGl map of all data positions to HTML
create_kepler_map = False

if create_kepler_map:
    # Imported here so that keplergl is only needed when the map is requested
    from keplergl import KeplerGl

    # Split marker, position and all classification/regression targets
    map_columns = ['train', 'SYS.STR.PosLat', 'SYS.STR.PosLon'] + targets_class.tolist() + targets_reg.tolist()
    map_frame = dff[map_columns].rename(columns={'SYS.STR.PosLat': 'latitude', 'SYS.STR.PosLon': 'longitude'})

    map1 = KeplerGl(data={'df': map_frame})
    map1.save_to_html(file_name=os.path.join(base_path, 'imgs', 'map_of_data_positions.html'))

# XAI with SHAP

In [9]:
# Load existing SHAP explainers if they exist, otherwise build and store them


def _load_or_build_explainer(model, file_name):
    """Return the explainer cached as `file_name` in `shap_path`, building it if needed.

    Parameters
    ----------
    model : fitted tree-based model passed to shap.TreeExplainer
    file_name : name of the cache file inside `shap_path`

    Returns
    -------
    The explainer loaded from the cache, or a freshly built shap.TreeExplainer
    (which is then written to the cache).
    """
    cache_file = os.path.join(shap_path, file_name)
    try:
        return joblib.load(cache_file)
    except FileNotFoundError:
        # No cache yet: build the explainer below
        pass
    except Exception as err:
        # The cache exists but cannot be read: say so and rebuild instead of
        # failing. Unlike a bare `except:`, this lets KeyboardInterrupt through.
        print(f'Could not load {cache_file} ({err!r}); rebuilding the explainer.')

    explainer = shap.TreeExplainer(model, feature_names=feature_names)
    # Make sure the target folder exists before writing the cache
    os.makedirs(os.path.dirname(cache_file) or '.', exist_ok=True)
    joblib.dump(explainer, cache_file)
    return explainer


# Temp. classification shap model
expl_temp = _load_or_build_explainer(model_temp, 'explainer_temp.pkl')

# Sal. classification shap model
expl_sal = _load_or_build_explainer(model_sal, 'explainer_sal.pkl')

# Regression shap model (now read from `shap_path`, the same folder it is written to)
expl_reg = _load_or_build_explainer(model_reg, 'explainer_reg.pkl')

In [10]:
# Pick the test samples used for the SHAP plots: for each flag, all negative
# samples plus equally many randomly drawn positive samples.


def _balanced_samples(flag_ind):
    """Return (positives, negatives) test rows for the flag in column `flag_ind`.

    All rows with label 0 are kept. Equally many rows are drawn at random from
    the rows with label 1. rng.choice is called with its default settings,
    i.e. sampling with replacement, so a positive row may be drawn more than once.
    NOTE(review): an earlier comment mentioned that all negative samples have
    latitude < 40 and that positives are sampled "from this range"; no latitude
    filter is applied here - confirm whether one was intended.
    """
    is_positive = y_test_class[:, flag_ind] == 1
    negatives = x_test[y_test_class[:, flag_ind] == 0, :]
    drawn_rows = rng.choice(is_positive.sum(), size=negatives.shape[0])
    positives = x_test[is_positive, :][drawn_rows, :]
    return positives, negatives


shap_dict = {}

# Temperature (sampled first so that the order of rng draws stays fixed)
x_test_temp_pos, x_test_temp_neg = _balanced_samples(temp_ind)
shap_dict['Temp_Flag'] = {'pos': x_test_temp_pos, 'neg': x_test_temp_neg}

# Salinity
x_test_sal_pos, x_test_sal_neg = _balanced_samples(sal_ind)
shap_dict['Sal_Flag'] = {'pos': x_test_sal_pos, 'neg': x_test_sal_neg}

# Regression: random rows of the test data, again drawn with replacement
n_sample = 10000
x_test_reg = x_test[rng.choice(len(x_test), size=n_sample), :]
shap_dict['reg'] = x_test_reg

In [11]:
# Calculate SHAP values, re-using stored results where available


def _load_or_compute_shap_values(explainer, samples, file_name):
    """Return the SHAP values cached as `file_name` in `shap_path`, computing them if needed.

    Parameters
    ----------
    explainer : callable SHAP explainer; `explainer(samples)` yields the SHAP values
    samples : input rows to explain
    file_name : name of the cache file inside `shap_path`

    Returns
    -------
    The cached object if it could be loaded, otherwise the freshly computed
    result of `explainer(samples)` (which is then written to the cache).

    Notes
    -----
    The cache is keyed by file name only. If `samples` changes, delete the
    cached file to force a recalculation.
    """
    cache_file = os.path.join(shap_path, file_name)
    try:
        return joblib.load(cache_file)
    except FileNotFoundError:
        # No cache yet: compute below
        pass
    except Exception as err:
        # The cache exists but cannot be read: say so and recompute instead of
        # failing. Unlike a bare `except:`, this lets KeyboardInterrupt through.
        print(f'Could not load {cache_file} ({err!r}); recalculating SHAP values.')

    values = explainer(samples)
    # Make sure the target folder exists so the computed values are not lost
    # when writing the cache
    os.makedirs(os.path.dirname(cache_file) or '.', exist_ok=True)
    joblib.dump(values, cache_file)
    return values


# Temperature flags: negative samples, then positive samples
shap_values_temp_neg = _load_or_compute_shap_values(expl_temp, shap_dict['Temp_Flag']['neg'], 'shap_values_temp_neg.pkl')
shap_values_temp_pos = _load_or_compute_shap_values(expl_temp, shap_dict['Temp_Flag']['pos'], 'shap_values_temp_pos.pkl')

# Salinity flags: negative samples, then positive samples
shap_values_sal_neg = _load_or_compute_shap_values(expl_sal, shap_dict['Sal_Flag']['neg'], 'shap_values_sal_neg.pkl')
shap_values_sal_pos = _load_or_compute_shap_values(expl_sal, shap_dict['Sal_Flag']['pos'], 'shap_values_sal_pos.pkl')

# Regression
shap_values_reg = _load_or_compute_shap_values(expl_reg, shap_dict['reg'], 'shap_values_reg.pkl')


In [12]:
# Regression target name -> short tag used in exported figure file names
file_name_map = {
    'Temp [°C]': 'temp',
    'Cond [S/m]': 'cond',
    'Temp_int [°C]': 'temp_int',
    'Sal': 'sal',
}

## Temperature

Create SHAP value plots for the temperature quality flag classification, showing the most important features in terms of contribution to the prediction outcome.

In [None]:
# Combine SHAP values from negative and positive classes: extract SHAP values, base values and data
comb_vals = np.concatenate([shap_values_temp_neg.values, shap_values_temp_pos.values], axis=0)
comb_base_vals = np.concatenate([shap_values_temp_neg.base_values, shap_values_temp_pos.base_values], axis=0)
comb_data = np.concatenate([shap_values_temp_neg.data, shap_values_temp_pos.data], axis=0)

# Create a SHAP Explanation from the extracted information.
# It gets its own name: binding it to `expl_temp` would replace the
# TreeExplainer of that name and break re-running the SHAP value calculation.
shap_values_temp_comb = shap.Explanation(values=comb_vals, base_values=comb_base_vals, data=comb_data, feature_names=feature_names)

# Plot feature importance.
# Index 1 on the last axis selects one model output - presumably class 1 ("good" flag); confirm.
fig = plt.figure()
ax = shap.plots.beeswarm(shap_values_temp_comb[:, :, 1], show=False)
ax.set_title('Temperature flag')
if export_figs:
    fig.savefig(os.path.join(base_path, 'imgs', 'shap_temp_flag'+export_figs_format), bbox_inches='tight', transparent=False)
plt.show()

In [None]:
# Beeswarm of the SHAP values for the negative-class test samples only ("bad" flags)
temp_bad_values = shap_values_temp_neg[:, :, 1]

fig = plt.figure()
ax = shap.plots.beeswarm(temp_bad_values, show=False)
ax.set_title('Temperature, bad flags')

if export_figs:
    temp_bad_file = os.path.join(base_path, 'imgs', 'shap_temp_flag_bad' + export_figs_format)
    fig.savefig(temp_bad_file, bbox_inches='tight', transparent=False)

plt.show()

- Flow in the mid range increases the likelihood of predicting the positive class (good flag). Low or high flow pushes the prediction toward the negative class (bad flag).
- A high month of year pushes toward the positive class (good flag). However, since we only have bad flags in January, November and December, this is likely not a genuine month effect. The ship moves approximately between -20 and 20 latitude, i.e. around the equator the whole time. The effect is therefore more likely due to sampling than to the month itself; the bad flags probably occur toward the end of the measurement sequence.
- 

In [None]:
# Beeswarm of the SHAP values for the positive-class test samples only ("good" flags)
temp_good_values = shap_values_temp_pos[:, :, 1]

fig = plt.figure()
ax = shap.plots.beeswarm(temp_good_values, show=False)
ax.set_title('Temperature, good flags')

if export_figs:
    temp_good_file = os.path.join(base_path, 'imgs', 'shap_temp_flag_good' + export_figs_format)
    fig.savefig(temp_good_file, bbox_inches='tight', transparent=False)

plt.show()

## Salinity

Create SHAP value plots for the salinity quality flag classification, showing the most important features in terms of contribution to the prediction outcome.

In [None]:
# Combine SHAP values from negative and positive classes: extract SHAP values, base values and data
comb_vals = np.concatenate([shap_values_sal_neg.values, shap_values_sal_pos.values], axis=0)
comb_base_vals = np.concatenate([shap_values_sal_neg.base_values, shap_values_sal_pos.base_values], axis=0)
comb_data = np.concatenate([shap_values_sal_neg.data, shap_values_sal_pos.data], axis=0)

# Create a SHAP Explanation from the extracted information.
# It gets its own name: binding it to `expl_sal` would replace the
# TreeExplainer of that name and break re-running the SHAP value calculation.
shap_values_sal_comb = shap.Explanation(values=comb_vals, base_values=comb_base_vals, data=comb_data, feature_names=feature_names)

# Plot feature importance.
# Index 1 on the last axis selects one model output - presumably class 1 ("good" flag); confirm.
fig = plt.figure()
ax = shap.plots.beeswarm(shap_values_sal_comb[:, :, 1], show=False)
ax.set_title('Salinity flag')
if export_figs:
    fig.savefig(os.path.join(base_path, 'imgs', 'shap_sal_flag'+export_figs_format), bbox_inches='tight', transparent=False)
plt.show()

In [None]:
# Beeswarm of the SHAP values for the negative-class test samples only ("bad" flags)
sal_bad_values = shap_values_sal_neg[:, :, 1]

fig = plt.figure()
ax = shap.plots.beeswarm(sal_bad_values, show=False)
ax.set_title('Salinity, bad flags')

if export_figs:
    sal_bad_file = os.path.join(base_path, 'imgs', 'shap_sal_flag_bad' + export_figs_format)
    fig.savefig(sal_bad_file, bbox_inches='tight', transparent=False)

plt.show()

In [None]:
# Beeswarm of the SHAP values for the positive-class test samples only ("good" flags)
sal_good_values = shap_values_sal_pos[:, :, 1]

fig = plt.figure()
ax = shap.plots.beeswarm(sal_good_values, show=False)
ax.set_title('Salinity, good flags')

if export_figs:
    sal_good_file = os.path.join(base_path, 'imgs', 'shap_sal_flag_good' + export_figs_format)
    fig.savefig(sal_good_file, bbox_inches='tight', transparent=False)

plt.show()

## Regression

In [None]:
# One beeswarm plot per regression target; the last axis of the SHAP values
# is indexed by the target's position in `targets_reg`
for target_pos, target_name in enumerate(targets_reg):
    target_values = shap_values_reg[:, :, target_pos]

    fig = plt.figure()
    ax = shap.plots.beeswarm(target_values, show=False)
    ax.set_title(target_name)

    if export_figs:
        # File name is built from the short tag registered in file_name_map
        target_file = os.path.join(base_path, 'imgs', f'shap_{file_name_map[target_name]}' + export_figs_format)
        fig.savefig(target_file, bbox_inches='tight', transparent=False)

    plt.show()
