# Explainability using SHAP

This notebook details the process of explaining a model's predictions using the [SHAP](https://shap.readthedocs.io/en/latest/index.html) library.

To do this, we will first create a synthetic dataset related to the vital signs of a cohort of people. Some of these people are suffering from an infection, while others are healthy. The data set contains the following features:

* systolic pressure
* diastolic pressure
* daily average body temperature
* current body temperature
* daily average respiration rate
* current respiration rate
* daily average pulse rate
* current pulse rate
* whether the person has the infection (target variable)

We are using a synthetic dataset as the relationships between these features can be controlled. The explanations are therefore easier to relate to the data.

### Import libraries

This notebook requires the SHAP and xgboost libraries, which are not installed by default. Use the commands below to install them.

In [None]:
!pip install shap==0.40.0;
!pip install xgboost==1.5.0;

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, plot_confusion_matrix
from xgboost import XGBClassifier

np.random.seed(0)

CORAL = "#FA7268"
GREY = "#9C9C9C"
TEAL = "#05C5C5"
YELLOW = "#F4C41A"

### Generate synthetic data

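To create the data, let us first define the relationships between the features. Here $\mathcal{N}(\mu, \sigma)$ denotes a normal distribution with mean $\mu$ and standard deviation $\sigma$.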

* systolic pressure - sampled from $\mathcal{N}(105, 10)$
* diastolic pressure - 'systolic pressure' $\times\ \mathcal{N}(\frac{2}{3}, 0.01)$
* daily average body temperature - sampled from $\mathcal{N}(36.7, 0.05)$
* current body temperature - 'daily average body temperature' + $\mathcal{N}(0, 0.01)$
* daily average respiration rate - sampled from $\mathcal{N}(14, 0.8)$
* current respiration rate - 'daily average respiration rate' + $\mathcal{N}(0, 0.5)$ _if healthy, else_ 'daily average respiration rate' + $\mathcal{N}(1.5, 0.4)$
* daily average pulse rate - 'daily average respiration rate' $\times\ \mathcal{N}(5.5, 0.2)$
* current pulse rate - 'daily average pulse rate' + $\mathcal{N}(0, 4)$ _if healthy, else_ 'daily average pulse rate' + $\mathcal{N}(10, 5)$
* whether the person has the infection (target variable) - 0 if healthy, 1 if infected

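The following cell defines these parameters as constants, along with two functions: `generate_healthy`, which samples healthy individuals, and `generate_infected`, which samples infected individuals.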

In [None]:
# pressure
SYSTOLIC_PRESSURE = 105
SYSTOLIC_PRESSURE_STD = 10
SYSTOLIC_TO_DIASTOLIC_FACTOR = 2./3.
SYSTOLIC_TO_DIASTOLIC_FACTOR_STD = 0.01

# body temperature
MEAN_BODY_TEMP = 36.7
MEAN_BODY_TEMP_STD = 0.05
BODY_TEMP_NOISE = 0
BODY_TEMP_NOISE_STD = 0.01

# respiration rate
MEAN_RESPIRATION_RATE = 14
MEAN_RESPIRATION_RATE_STD = 0.8
RESPIRATION_RATE_NOISE = 0
RESPIRATION_RATE_NOISE_STD = 0.5
RESPIRATION_RATE_NOISE_INFECTED = 1.5
RESPIRATION_RATE_NOISE_STD_INFECTED = 0.4

# pulse rate
RESPIRATION_TO_PULSE_FACTOR = 5.5
RESPIRATION_TO_PULSE_FACTOR_NOISE = 0.2
PULSE_RATE_NOISE = 0
PULSE_RATE_NOISE_STD = 4
PULSE_RATE_NOISE_INFECTED = 10
PULSE_RATE_NOISE_STD_INFECTED = 5

def generate_healthy(n_patients):
    """Generates healthy patients

    Parameters
    ----------
    n_patients: int
        Number of patients to generate

    Returns
    -------
    patients: pd.DataFrame
        Data frame of shape (n_patients, 9), one row per patient
    """
    # pressure
    systolic_pressure = np.random.normal(
        SYSTOLIC_PRESSURE, SYSTOLIC_PRESSURE_STD, n_patients
    )
    diastolic_pressure = (
        systolic_pressure * np.random.normal(
            SYSTOLIC_TO_DIASTOLIC_FACTOR,
            SYSTOLIC_TO_DIASTOLIC_FACTOR_STD,
            n_patients
        )
    )
    
    # body temperature
    daily_avg_body_temperature = np.random.normal(
        MEAN_BODY_TEMP,
        MEAN_BODY_TEMP_STD,
        n_patients
    )
    body_temperature = (
        daily_avg_body_temperature 
        + np.random.normal(
            BODY_TEMP_NOISE,
            BODY_TEMP_NOISE_STD,
            n_patients
        )
    )
    
    # respiration rate
    daily_avg_respiration_rate = np.random.normal(
        MEAN_RESPIRATION_RATE,
        MEAN_RESPIRATION_RATE_STD,
        n_patients
    )
    respiration_rate = (
        daily_avg_respiration_rate 
        + np.random.normal(
            RESPIRATION_RATE_NOISE,
            RESPIRATION_RATE_NOISE_STD,
            n_patients
        )
    )
    
    # pulse rate
    daily_avg_pulse_rate = (
        daily_avg_respiration_rate 
        * np.random.normal(
            RESPIRATION_TO_PULSE_FACTOR,
            RESPIRATION_TO_PULSE_FACTOR_NOISE,
            n_patients
        )
    )
    pulse_rate = (
        daily_avg_pulse_rate 
        + np.random.normal(
            PULSE_RATE_NOISE,
            PULSE_RATE_NOISE_STD,
            n_patients
        )
    )
    
    infection = np.zeros(n_patients)
    
    patients = pd.DataFrame(
        np.stack(
            [
                systolic_pressure,
                diastolic_pressure,
                daily_avg_body_temperature,
                body_temperature,
                daily_avg_respiration_rate,
                respiration_rate,
                daily_avg_pulse_rate,
                pulse_rate,
                infection,
            ],
            axis=1
        ),
        columns=[
            "systolic_pressure",
            "diastolic_pressure",
            "daily_avg_body_temperature",
            "body_temperature",
            "daily_avg_respiration_rate",
            "respiration_rate",
            "daily_avg_pulse_rate",
            "pulse_rate",
            "infection",
        ]
    )

    return patients


def generate_infected(n_patients):
    """Generates patients infected with illness that
    raises respiration and pulse

    Parameters
    ----------
    n_patients: int
        Number of patients to generate

    Returns
    -------
    patients: pd.DataFrame
        Data frame of shape (n_patients, 9), one row per patient
    """
    systolic_pressure = np.random.normal(
        SYSTOLIC_PRESSURE, SYSTOLIC_PRESSURE_STD, n_patients
    )
    diastolic_pressure = (
        systolic_pressure * np.random.normal(
            SYSTOLIC_TO_DIASTOLIC_FACTOR,
            SYSTOLIC_TO_DIASTOLIC_FACTOR_STD,
            n_patients
        )
    )
    
    # body temperature
    daily_avg_body_temperature = np.random.normal(
        MEAN_BODY_TEMP,
        MEAN_BODY_TEMP_STD,
        n_patients
    )
    body_temperature = (
        daily_avg_body_temperature 
        + np.random.normal(
            BODY_TEMP_NOISE,
            BODY_TEMP_NOISE_STD,
            n_patients
        )
    )
    
    # respiration rate
    daily_avg_respiration_rate = np.random.normal(
        MEAN_RESPIRATION_RATE,
        MEAN_RESPIRATION_RATE_STD,
        n_patients
    )
    respiration_rate = (
        daily_avg_respiration_rate 
        + np.random.normal(
            RESPIRATION_RATE_NOISE_INFECTED,
            RESPIRATION_RATE_NOISE_STD_INFECTED,
            n_patients
        )
    )
    
    # pulse rate
    daily_avg_pulse_rate = (
        daily_avg_respiration_rate 
        * np.random.normal(
            RESPIRATION_TO_PULSE_FACTOR,
            RESPIRATION_TO_PULSE_FACTOR_NOISE,
            n_patients
        )
    )
    pulse_rate = (
        daily_avg_pulse_rate 
        + np.random.normal(
            PULSE_RATE_NOISE_INFECTED,
            PULSE_RATE_NOISE_STD_INFECTED,
            n_patients
        )
    )
    
    infection = np.ones(n_patients)
    
    patients = pd.DataFrame(
        np.stack(
            [
                systolic_pressure,
                diastolic_pressure,
                daily_avg_body_temperature,
                body_temperature,
                daily_avg_respiration_rate,
                respiration_rate,
                daily_avg_pulse_rate,
                pulse_rate,
                infection
            ],
            axis=1
        ),
        columns=[
            "systolic_pressure",
            "diastolic_pressure",
            "daily_avg_body_temperature",
            "body_temperature",
            "daily_avg_respiration_rate",
            "respiration_rate",
            "daily_avg_pulse_rate",
            "pulse_rate",
            "infection",
        ]
    )

    return patients

We will use these functions to generate a dataset containing 1000 individuals, 100 of whom are suffering from the infection. We will then split this dataset into a training set and a test set that we can use to train and evaluate our model.

In [None]:
healthy = generate_healthy(900)
infected = generate_infected(100)

data = pd.concat((healthy, infected)).reset_index(drop=True)

# split the data into a training and test set
X_train, X_test, y_train, y_test = train_test_split(
    data.drop("infection", axis=1),
    data["infection"].values,
    test_size=0.2,
    random_state=0,
)

## Exploration of the data

Now that we have seen how the data is generated, let's explore the relationships between the features in the dataset and the target variable.

**Use the functions below to explore the data and understand the relationship between the features and the target variable**

In [None]:
def plot_histogram(X, y, feature):
    """
    Plot a histogram of a feature,
    coloured by the target variable
    
    Parameters
    ----------
    X: pd.DataFrame
        Dataset containing the feature to plot
    y: np.array
        Target variable
    feature: str
        Name of feature of interest
    """
    plt.figure(figsize=(7, 5))
    plt.hist(
        X.loc[y==0, feature],
        bins=30,
        color=GREY,
        histtype="step",
        linewidth=2,
        range=(np.min(X[feature]), np.max(X[feature])),
        label="Healthy"
    )
    plt.hist(
        X.loc[y==1, feature],
        bins=30,
        color=CORAL,
        histtype="step",
        linewidth=2,
        range=(np.min(X[feature]), np.max(X[feature])),
        label="Infected"
    )
    plt.xlabel(" ".join(feature.split("_")))
    plt.ylabel("Frequency")
    plt.legend()
    plt.show()
    

def plot_scatter(X, y, feature1, feature2):
    """
    Plot a scatter plot of feature1 vs feature2,
    where the points are coloured by the target variable
    
    Parameters
    ----------
    X: pd.DataFrame
        Dataset containing feature1 and feature2
    y: np.array
        Target variable
    feature1: str
        Name of feature of interest
    feature2: str
        Name of other feature of interest
    """
    plt.figure(figsize=(7, 5))
    plt.plot(
        X.loc[y==0, feature1],
        X.loc[y==0, feature2],
        ".",
        label="healthy",
        color=GREY,
    )
    plt.plot(
        X.loc[y==1, feature1],
        X.loc[y==1, feature2],
        ".",
        label="infected",
        color=CORAL,
    )
    plt.xlabel(feature1)
    plt.ylabel(feature2)
    plt.legend()
    plt.show()

In [None]:
# plot histograms for all features
for feature in X_train.columns:
    plot_histogram(X_train, y_train, feature)

In [None]:
# plot scatter plot of daily_avg_respiration_rate vs respiration_rate
plot_scatter(X_train, y_train, "daily_avg_respiration_rate", "respiration_rate")

Here we can see that the infection raises the current respiration rate above the individual's daily average, beyond the normal variation that we see in healthy individuals.

In [None]:
# plot scatter plot of daily_avg_pulse_rate vs pulse_rate
plot_scatter(X_train, y_train, "daily_avg_pulse_rate", "pulse_rate")

We can also see that the infection has a similar effect on the pulse rate: infected individuals have a current pulse rate that is elevated above their daily average.

In [None]:
plot_scatter(X_train, y_train, "respiration_rate", "pulse_rate")

Infected individuals can also be somewhat separated based on their pulse rate and respiration rate alone, but some infected individuals fall within the normal healthy range, and so would be hard to spot on these features alone.

In [None]:
plot_scatter(X_train, y_train, "daily_avg_respiration_rate", "daily_avg_pulse_rate")

However, just by looking at the daily averages of pulse and respiration rate, we cannot discern between healthy and infected individuals.

In [None]:
plot_scatter(X_train, y_train, "systolic_pressure", "body_temperature")

The infection has no effect on either of the pressure features or the body temperature features.

Based on our exploration of the data, we would expect the model to rely mostly on the _respiration_rate_ and _pulse_rate_ features. However, these two features alone are not enough to accurately predict whether an individual is infected. To separate the two groups, we must compare each individual's current respiration rate and pulse rate to their daily average. When we do this, we can see that the current values of infected individuals are elevated outside of the usual distribution.

We would therefore expect the most important features in any model to be _respiration_rate_, _pulse_rate_, _daily_avg_respiration_rate_ and _daily_avg_pulse_rate_.
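As a rough illustration of why the daily averages matter, the sketch below resamples respiration rates from the distributions defined earlier and compares how well a single threshold separates the two groups, first on the raw current rate and then on its deviation from the daily average. The seed, the use of `np.random.default_rng`, and the `sample_respiration` and `best_threshold_accuracy` helpers are assumptions made for this example and are not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily for this sketch


def sample_respiration(n, noise_mean, noise_std):
    """Sample (daily average, current) respiration rates for n people."""
    daily_avg = rng.normal(14, 0.8, n)
    current = daily_avg + rng.normal(noise_mean, noise_std, n)
    return daily_avg, current


def best_threshold_accuracy(healthy, infected):
    """Balanced accuracy of the best rule 'value > t means infected'."""
    candidates = np.sort(np.concatenate([healthy, infected]))
    best = 0.0
    for t in candidates:
        tnr = np.mean(healthy <= t)  # healthy correctly below the threshold
        tpr = np.mean(infected > t)  # infected correctly above the threshold
        best = max(best, (tnr + tpr) / 2)
    return best


avg_h, cur_h = sample_respiration(900, 0.0, 0.5)  # healthy
avg_i, cur_i = sample_respiration(100, 1.5, 0.4)  # infected

raw_score = best_threshold_accuracy(cur_h, cur_i)
diff_score = best_threshold_accuracy(cur_h - avg_h, cur_i - avg_i)

print(f"raw respiration rate:         {raw_score:.3f}")
print(f"deviation from daily average: {diff_score:.3f}")
```

The deviation removes the person-to-person spread in the daily average, so the threshold on the difference should separate the groups noticeably better than the threshold on the raw rate. A tree-based model has no such difference feature and has to approximate it by splitting on both columns.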

## Train a model

Now we will train a model to predict which individuals are infected and which are healthy. We will use xgboost in this demonstration.

In [None]:
# Train the model
clf = XGBClassifier(
    n_estimators=100,
    max_depth=3,
    random_state=0,
    use_label_encoder=False,
    eval_metric="logloss",
    n_jobs=-1,
)
clf.fit(X_train, y_train.astype(np.int32))

As this dataset is very simple, we can see that our model performs very well, getting 99% accuracy on the test set.

In [None]:
# show model metrics
print("Accuracy Score: ", accuracy_score(y_test, clf.predict(X_test)))
plot_confusion_matrix(clf, X_test, y_test, normalize="true")
plt.show()
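Because only about 10% of the individuals are infected, accuracy on its own can look flattering. The sketch below uses hypothetical labels with the same 9:1 ratio (not the notebook's actual predictions) to show that a rule which always predicts "healthy" already reaches 90% accuracy while finding no infected individuals. The per-class recall computed here corresponds to the diagonal of a confusion matrix normalised over the true labels, as in the plot above.

```python
import numpy as np

# Hypothetical test labels with the same 9:1 class ratio as the dataset
y_true = np.array([0] * 180 + [1] * 20)

# A trivial "model" that always predicts healthy
y_pred = np.zeros_like(y_true)

accuracy = np.mean(y_pred == y_true)

# Recall for each class: fraction of that class predicted correctly
recall_healthy = np.mean(y_pred[y_true == 0] == 0)
recall_infected = np.mean(y_pred[y_true == 1] == 1)

print(f"accuracy:        {accuracy:.2f}")         # 0.90
print(f"recall healthy:  {recall_healthy:.2f}")   # 1.00
print(f"recall infected: {recall_infected:.2f}")  # 0.00
```

This is why it is worth reading the accuracy score together with the row-normalised confusion matrix rather than in isolation.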

We can investigate the performance further by looking at the predictions in terms of the predicted probability of each individual being infected. By plotting a histogram of the predicted probabilities for each class, we can see that the model is very confident in its predictions and only makes a few mistakes.

In [None]:
def plot_proba_scores(ax, scores_list):
    """
    Plot a histogram of the predicted probabilities for each class
    """
    bins = 20
    for scores, label, color in zip(
        scores_list,
        [
            "Healthy",
            "Infected",
        ],
        [GREY, CORAL],
    ):
        ax.hist(
            scores,
            bins=bins,
            label=label,
            color=color,
            histtype="step",
            linewidth=2,
            range=(0, 1),
        )
    ax.set_xlabel("Predicted probability of infection")
    ax.set_ylabel("Frequency")
    ax.set_yscale("log")
    ax.legend()

predicted_healthy = clf.predict_proba(X_test[y_test==0])[:, 1]
predicted_infected = clf.predict_proba(X_test[y_test==1])[:, 1]

fig, ax = plt.subplots(1, 1)
plot_proba_scores(ax, [predicted_healthy, predicted_infected])
plt.show()

We can also compare the predicted and ground truth scatter plots, to visually evaluate the performance.

In [None]:
def plot_scatter_prediction(clf, X_test, feature1, feature2):
    """
    Plot a scatter plot coloured by predicted probability of being infected
    """
    plt.figure(figsize=(8.8, 5))
    scores = clf.predict_proba(X_test)[:, 1]
    plt.scatter(
        X_test[feature1],
        X_test[feature2],
        marker=".",
        c=scores,
        vmin=0,
        vmax=1,
    )
    plt.colorbar()
    plt.xlabel(feature1)
    plt.ylabel(feature2)
    plt.show()

plot_scatter(X_test, y_test, "daily_avg_respiration_rate", "respiration_rate")
plot_scatter_prediction(clf, X_test, "daily_avg_respiration_rate", "respiration_rate")

## Explanations 

### Feature importance in xgboost

The importance of each feature in an xgboost model is related to the improvement in performance given by splitting on that feature, accounting for the number of data points that are affected by that split. The improvements across all of the splits on that feature, across all of the trees in the model, are then averaged to create a measure of importance.

For a more detailed discussion of this, see this [blog post](https://machinelearningmastery.com/feature-importance-and-feature-selection-with-xgboost-in-python/).

However, feature importance only provides an estimate of relative importance, and does not indicate how much that feature influenced the prediction. These importances are _global_ i.e. they show the importance of features across the whole dataset.
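To make the averaging concrete, the sketch below computes a gain-style importance from a small, made-up list of splits. The split records and gain values are hypothetical, and this is not xgboost's implementation; it only mirrors the description above by averaging the improvement over every split that uses a feature. The final normalisation step is an assumption added here so that the values read as relative importances.

```python
from collections import defaultdict

# Hypothetical (feature, gain) records, one per split across all trees
splits = [
    ("respiration_rate", 40.0),
    ("respiration_rate", 20.0),
    ("daily_avg_respiration_rate", 15.0),
    ("pulse_rate", 10.0),
    ("pulse_rate", 6.0),
    ("systolic_pressure", 1.0),
]

gains = defaultdict(list)
for feature, gain in splits:
    gains[feature].append(gain)

# Average gain per feature over every split that used it
mean_gain = {f: sum(g) / len(g) for f, g in gains.items()}

# Normalise so the importances sum to 1 (relative, not absolute)
total = sum(mean_gain.values())
importance = {f: g / total for f, g in mean_gain.items()}

for feature, value in sorted(importance.items(), key=lambda kv: -kv[1]):
    print(f"{feature:28s} {value:.3f}")
```

Note that nothing in this calculation records the direction of a feature's effect or its contribution to any single prediction, which is the gap that SHAP addresses.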


In [None]:
values = sorted(clf.feature_importances_)
names = [x for _, x in sorted(zip(clf.feature_importances_, data.columns[:-1]))]

plt.barh(names, values, color=CORAL)
plt.xlabel("Importance")
plt.show()

### Explanations with SHAP 

We will now use SHAP to explain the predictions made by this model. Because we are using a small dataset, we can use the Exact explainer API, which calculates the full Shapley value sum. When datasets are larger, the SHAP library offers alternative methods such as KernelSHAP, which is also model agnostic, or TreeSHAP, which is specifically for tree-based models.

In [None]:
explainer = shap.explainers.Exact(clf.predict_proba, X_test)
shap_values = explainer(X_test)
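For intuition about what an exact Shapley computation involves, here is a minimal brute-force sketch on a toy model with three features. The `exact_shapley` function, the toy `model`, and the random background data are all assumptions made for illustration; this is not how the SHAP library implements its Exact explainer, which is more efficient than this loop. Features outside a coalition are replaced by rows from a background set and the model output is averaged over them.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_shapley(f, x, background):
    """Brute-force Shapley values of f at x, relative to a background set."""
    n = len(x)

    def value(coalition):
        # Keep features in the coalition fixed at x, fill the rest
        # from the background rows, and average the model output
        idx = np.array(coalition, dtype=int)
        masked = background.copy()
        masked[:, idx] = x[idx]
        return f(masked).mean()

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi, value(())


# Toy model (hypothetical): an interaction plus a linear term
def model(X):
    return X[:, 0] * X[:, 1] + 2.0 * X[:, 2]


rng = np.random.default_rng(0)
background = rng.normal(size=(50, 3))
x = np.array([1.0, 2.0, 3.0])

phi, base_value = exact_shapley(model, x, background)

# Additivity: base value + sum of Shapley values equals the prediction
print(phi, base_value)
print(np.isclose(base_value + phi.sum(), model(x[None, :])[0]))
```

The loop visits every coalition of the other features for each feature, so the cost grows exponentially with the number of features. That is manageable for the eight features used here, but it is one reason approximate methods exist.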

First we will look at the _global_ explanation provided by SHAP. Here we can see that it has identified the four features we expected to be important to the model.

In [None]:
shap.summary_plot(
    shap_values.values[:,:, 1],
    X_test,
    plot_type="bar",
    show=False
)
plt.xlim(0, 0.7)
plt.show()

However, we can also split this view by looking at subsections of the data. In the next two figures we explain predictions made on healthy individuals and infected individuals. We can see that the Shapley values of each feature are much larger for the infected individuals. We can investigate this further in the next plots.

In [None]:
shap.summary_plot(
    shap_values.values[y_test==0,:, 1],
    X_test.loc[y_test==0, :],
    plot_type="bar",
    show=False
)
plt.xlim(0, 0.7)
plt.show()

In [None]:
shap.summary_plot(
    shap_values.values[y_test==1,:, 1],
    X_test.loc[y_test==1, :],
    plot_type="bar",
    show=False
)
plt.xlim(0, 0.7)
plt.show()

In [None]:
def highlight_point(X, feature1, feature2, index):
    """Highlight a point in a scatter plot between two features"""
    # plot every row of X, then overlay the row at position `index`
    plt.plot(X.loc[:, feature1], X.loc[:, feature2], ".", color=GREY)
    plt.plot(
        X.iloc[index].loc[feature1],
        X.iloc[index].loc[feature2],
        "o",
        color=CORAL
    )
    plt.xlabel(feature1)
    plt.ylabel(feature2)
    plt.show()

If we look at the _local_ explanation for an infected individual, we can see how the Shapley values are related to the predicted probability of being infected. In the top plot, we see the absolute Shapley value for each feature. In the next plot, we can see how these Shapley values sum to the predicted probability of infection for that data point.
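The additive relationship shown in the waterfall plot can be written compactly. The notation is introduced here and is not taken from the notebook: assume $f(x)$ is the predicted probability of infection for data point $x$, $\phi_0$ is the base value (the average model output over the background data given to the explainer), and $\phi_i$ is the Shapley value of feature $i$ out of $M = 8$ features. Then:

$$
f(x) = \phi_0 + \sum_{i=1}^{M} \phi_i
$$

Positive $\phi_i$ push the prediction above the base value and negative $\phi_i$ pull it below, which is what the waterfall plot shows for each feature.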

In [None]:
index = np.where(y_test == 1)[0][0]

print(
    "Predicted probability: ",
    clf.predict_proba(X_test.iloc[index].values.reshape(1, -1))[0][1]
)

shap.summary_plot(
    shap_values.values[index, :, 1].reshape(1, -1),
    X_test.iloc[[index]],  # features of the explained row only
    plot_type="bar",
    show=False
)
plt.xlim(0, 0.7)
plt.show()

shap.plots.waterfall(shap_values[index][:, 1])
plt.show()

highlight_point(X_test, "daily_avg_respiration_rate", "respiration_rate", index=index)
highlight_point(X_test, "respiration_rate", "pulse_rate", index=index)
highlight_point(X_test, "daily_avg_respiration_rate", "pulse_rate", index=index)    

We can do the same for a _local_ explanation of a healthy individual. Here we can see that the low values of _respiration_rate_ and _daily_avg_respiration_rate_ decrease the probability of this individual being infected.

In [None]:
index = np.where(y_test == 0)[0][0]

print(
    "Predicted probability: ",
    clf.predict_proba(X_test.iloc[index].values.reshape(1, -1))[0][1]
)

shap.summary_plot(
    shap_values.values[index, :, 1].reshape(1, -1),
    X_test.iloc[[index]],  # features of the explained row only
    plot_type="bar",
    show=False
)
plt.xlim(0, 0.7)
plt.show()

shap.plots.waterfall(shap_values[index][:, 1])
plt.show()

highlight_point(X_test, "daily_avg_respiration_rate", "respiration_rate", index=index)

We can also look at the errors that the model made. Here we look at a false negative prediction. We can see that the predicted probability is close to 0.5, meaning that the model is uncertain about this data point. The two most important features are _respiration_rate_ and _daily_avg_respiration_rate_.

In [None]:

index = np.argwhere(clf.predict(X_test) != y_test)[1][0]

print("Class: ", y_test[index])
print(
    "Predicted probability: ",
    clf.predict_proba(X_test.iloc[index].values.reshape(1, -1))[0][1]
)

shap.summary_plot(
    shap_values.values[index, :, 1].reshape(1, -1),
    X_test.iloc[[index]],  # features of the explained row only
    plot_type="bar",
    show=False
)
plt.xlim(0, 0.6)
plt.show()

shap.plots.waterfall(shap_values[index][:, 1])
plt.show()

highlight_point(X_test, "daily_avg_respiration_rate", "respiration_rate", index=index)
highlight_point(X_test, "daily_avg_respiration_rate", "pulse_rate", index=index)

In this example, we can see that despite _respiration_rate_ and _daily_avg_respiration_rate_ both being high, the prediction is uncertain. We can look at the relationship between the value of each feature and the Shapley value using `summary_plot`. Here we can see that the model is picking up on combinations of high _respiration_rate_ with low _daily_avg_respiration_rate_.

In [None]:
shap.summary_plot(shap_values.values[..., 1], X_test)

## Summary

In this example we've explored a simple synthetic dataset that highlights how Shapley values can be used to explain predictions made by machine learning models.