
# SHAP Plots with `treeinterpreter` and Random Forest

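In this notebook, we demonstrate how to create SHAP-style plots from the output of the `treeinterpreter` package. SHAP (SHapley Additive exPlanations) plots are a popular way to interpret machine learning models. They show how much each feature contributes to a specific prediction. `treeinterpreter` decomposes the predictions of tree-based models (e.g., Random Forests) into a bias term plus one contribution per feature. It does this by following each sample's decision path (the Saabas method), so its contributions are not exact Shapley values. The decomposition is additive in the same way, though: bias + sum of contributions = prediction. That additivity is what lets us wrap the `treeinterpreter` output in a `shap.Explanation` object and reuse SHAP's plotting functions to visualize the impact of each feature.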
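To make the path-based decomposition concrete, here is a minimal sketch of the Saabas attribution for a scikit-learn `RandomForestClassifier`. It uses only NumPy and the fitted trees' `tree_` arrays. It is our own illustration of the idea, not `treeinterpreter`'s source code, and the helper name `saabas_decompose` is hypothetical.

Along each sample's decision path, every split credits the change in the node's class distribution to the feature it tests. The root distribution (the bias) plus these credited changes therefore telescopes to the leaf distribution, which is the tree's predicted probabilities. Averaging over all trees gives a forest-level decomposition with the same shapes that `ti.predict` returns.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def saabas_decompose(forest, X):
    """Hypothetical helper: path-based (Saabas) decomposition of a fitted
    RandomForestClassifier's predicted probabilities.

    Returns (prediction, bias, contributions) with shapes
    (n_samples, n_classes), (n_samples, n_classes), (n_samples, n_features, n_classes).
    """
    n_samples, n_features = X.shape
    n_classes = forest.n_classes_
    bias = np.zeros((n_samples, n_classes))
    contributions = np.zeros((n_samples, n_features, n_classes))

    for estimator in forest.estimators_:
        tree = estimator.tree_
        # Class distribution at every node, normalized so each row sums to 1
        # (older scikit-learn versions store raw counts, newer ones store fractions)
        values = tree.value[:, 0, :]
        values = values / values.sum(axis=1, keepdims=True)
        paths = estimator.decision_path(X)  # sparse (n_samples, n_nodes) indicator matrix

        bias += values[0]  # root-node distribution, broadcast to every sample
        for i in range(n_samples):
            # A child's node id is always larger than its parent's, so sorting gives root-to-leaf order
            path = np.sort(paths.indices[paths.indptr[i]:paths.indptr[i + 1]])
            for parent, child in zip(path[:-1], path[1:]):
                # Credit the change in class distribution to the feature split on at the parent
                contributions[i, tree.feature[parent]] += values[child] - values[parent]

    n_trees = len(forest.estimators_)
    bias /= n_trees
    contributions /= n_trees
    prediction = bias + contributions.sum(axis=1)
    return prediction, bias, contributions


data = load_iris()
X, y = data.data, data.target
model = RandomForestClassifier(n_estimators=20, random_state=0).fit(X, y)
prediction, bias, contributions = saabas_decompose(model, X)

print(contributions.shape)  # (150, 4, 3)
# The additive decomposition reproduces the forest's predicted probabilities
print(np.abs(prediction - model.predict_proba(X)).max())  # close to 0 (floating-point error only)
```

The forest here uses `n_estimators=20` only to keep the sketch fast. The same reasoning applies to any number of trees, because the forest's `predict_proba` is the average of its trees' `predict_proba`.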

We'll walk through examples on two datasets (Iris and Wine). We'll also show how to customize the plots, including plotting different classes and changing the x-axis label.

### Prerequisites
Before running the examples, make sure the following libraries are installed:
- `shap`
- `treeinterpreter`
- `scikit-learn`
- `numpy`
- `matplotlib`
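A quick way to confirm the prerequisites before running the cells is to ask Python's import system whether each module can be found. Note that `scikit-learn` is imported as `sklearn`. This is a small convenience sketch that uses only the standard library. The `REQUIRED_MODULES` list is our own choice, and it includes `numpy` and `matplotlib` because the code cells import them as well.

```python
import importlib.util

# Import names of the packages used in this notebook (scikit-learn is imported as "sklearn")
REQUIRED_MODULES = ["numpy", "matplotlib", "sklearn", "shap", "treeinterpreter"]


def missing_modules(names):
    """Return the module names that cannot be found on this Python installation."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All prerequisites are installed.")
```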


In [None]:
# Import necessary libraries
import numpy as np
import shap  # Library for SHAP plots
from treeinterpreter import treeinterpreter as ti  # Library to interpret tree-based models
from sklearn.datasets import load_wine, load_iris  # Datasets for example purposes
from sklearn.ensemble import RandomForestClassifier  # Model used in this example
import matplotlib.pyplot as plt  # For plotting



In [None]:
## Example 1: SHAP Plots for Iris Dataset

# Load dataset and train a RandomForest model (random_state fixed for reproducibility)
data = load_iris()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=0)
model.fit(X, y)

# Use treeinterpreter to get the prediction, bias, and contributions
# 'contributions' has shape (n_samples, n_features, n_classes), and for every sample
# prediction = bias + contributions.sum(axis=1)
prediction, bias, contributions = ti.predict(model, X)

# SHAP plots expect 2D values (samples, features), so select one class (class 0)
shap_values = contributions[:, :, 0]  # Choose class 0 for visualization

# Wrap the treeinterpreter output in a SHAP Explanation object so SHAP's plotting functions can use it
# (no shap.Explainer is needed: the values come from treeinterpreter, not from SHAP)
shap_object = shap.Explanation(
    values=shap_values,
    base_values=bias[:, 0],  # Base values should match the selected class
    data=X,
    feature_names=data.feature_names
)

# Generate various SHAP plots
shap.summary_plot(shap_object.values, shap_object.data, feature_names=shap_object.feature_names)
shap.waterfall_plot(shap_object[0])  # Example for the first instance

# Generate a SHAP bar plot, showing mean absolute values across all instances
mean_abs_shap_values = np.abs(shap_object.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values, feature_names=shap_object.feature_names)
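The bar plot above needs only one number per feature: the mean absolute contribution for the selected class. You may also want that ranking as text, for example to log it or to compare runs. The same computation can then be sorted directly with NumPy.

`rank_features` is a hypothetical helper. The small array in the usage example is made up purely to show the shapes. In the notebook you would pass the `contributions` array returned by `ti.predict`.

```python
import numpy as np


def rank_features(contributions, feature_names, class_index):
    """Hypothetical helper: rank features by mean |contribution| for one class.

    contributions has shape (n_samples, n_features, n_classes), as returned by ti.predict
    for a classifier. Returns a list of (feature_name, mean_abs_value), largest first.
    """
    mean_abs = np.abs(contributions[:, :, class_index]).mean(axis=0)
    order = np.argsort(mean_abs)[::-1]  # feature indices sorted by decreasing importance
    return [(feature_names[j], float(mean_abs[j])) for j in order]


# Illustrative array: 2 samples, 3 features, 2 classes (values made up to show the shapes)
toy_contributions = np.array([
    [[0.10, -0.10], [-0.30, 0.30], [0.02, -0.02]],
    [[-0.20, 0.20], [0.10, -0.10], [0.00, 0.00]],
])
for name, value in rank_features(toy_contributions, ["a", "b", "c"], class_index=0):
    print(f"{name}: {value:.3f}")  # prints b: 0.200, then a: 0.150, then c: 0.010
```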


## Example 1: SHAP Plots for Iris Dataset

def run_example_1():
    """
    This example demonstrates how to generate SHAP plots for a RandomForest model trained on the Iris dataset.
    We use the `treeinterpreter` package to break down the model's predictions into individual feature contributions
    and visualize them using SHAP plots.
    """
    
    # Load the Iris dataset and train a RandomForest model
    data = load_iris()  # Iris dataset, commonly used for classification tasks
    X, y = data.data, data.target  # Features and target variable

    # Create and train a RandomForestClassifier
    model = RandomForestClassifier(random_state=0)  # Instantiate the model (fixed seed for reproducibility)
    model.fit(X, y)  # Train the model on the entire dataset

    # Use treeinterpreter to get the prediction, bias, and contributions
    # `ti.predict` returns:
    # - `prediction`: the predicted class probabilities, shape (n_samples, n_classes)
    # - `bias`: the class distribution at each tree's root, averaged over the forest, shape (n_samples, n_classes)
    # - `contributions`: the contribution of each feature, shape (n_samples, n_features, n_classes)
    # For every sample: prediction = bias + contributions.sum(axis=1)
    prediction, bias, contributions = ti.predict(model, X)

    # SHAP plots expect a 2D array (samples, features), so we slice
    # `contributions` to keep only the values for class 0
    shap_values = contributions[:, :, 0]  # Choose class 0 for visualization

    # Wrap the treeinterpreter output in a SHAP `Explanation` object, the structured data that SHAP's
    # plotting functions consume; no shap.Explainer is needed because the values come from treeinterpreter
    shap_object = shap.Explanation(
        values=shap_values,  # The contributions for the selected class
        base_values=bias[:, 0],  # Base values must match the selected class (class 0)
        data=X,  # The input features
        feature_names=data.feature_names  # The names of the features
    )

    # Generate various SHAP plots
    # Summary (beeswarm) plot: one dot per instance and feature, colored by the feature's value
    shap.summary_plot(shap_object.values, shap_object.data, feature_names=shap_object.feature_names)

    # Waterfall plot shows the contribution of each feature to the prediction for a single instance
    shap.waterfall_plot(shap_object[0])  # Example for the first instance in the dataset

    # Generate a SHAP bar plot, showing mean absolute values across all instances
    # Bar plot is useful to visualize the overall feature importance
    mean_abs_shap_values = np.abs(shap_object.values).mean(axis=0)
    shap.bar_plot(mean_abs_shap_values, feature_names=shap_object.feature_names)

# Run the example
run_example_1()
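The waterfall plot relies on the decomposition being additive. For each instance, the base value plus the sum of the feature contributions should equal the model output for that class. `treeinterpreter` is designed to satisfy this, and the check is cheap to run before plotting. A mismatch, for example base values taken from a different class than the contributions, would make the waterfall misleading.

`max_additivity_error` is a hypothetical helper. The toy arrays are made up, and the commented usage line assumes the `prediction`, `bias`, and `contributions` arrays returned by `ti.predict` for a classifier.

```python
import numpy as np


def max_additivity_error(prediction, bias, contributions):
    """Hypothetical helper: largest absolute gap between prediction and bias + summed contributions.

    Shapes (classifier case): prediction and bias (n_samples, n_classes),
    contributions (n_samples, n_features, n_classes).
    """
    reconstructed = bias + contributions.sum(axis=1)  # sum over the feature axis
    return float(np.abs(prediction - reconstructed).max())


# Toy arrays (1 sample, 2 features, 2 classes) built to be additive:
# 0.5 + 0.2 + 0.1 = 0.8 and 0.5 - 0.2 - 0.1 = 0.2
bias = np.array([[0.5, 0.5]])
contributions = np.array([[[0.2, -0.2], [0.1, -0.1]]])
prediction = np.array([[0.8, 0.2]])
print(max_additivity_error(prediction, bias, contributions))  # tiny, floating-point error only

# In the notebook, after `prediction, bias, contributions = ti.predict(model, X)`:
# assert max_additivity_error(prediction, bias, contributions) < 1e-6
```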


In [None]:
## Example 2: Customizing SHAP Plots for Multiple Classes

# Load the Wine dataset and train a RandomForest model
data = load_wine()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=0)  # fixed seed for reproducibility
model.fit(X, y)

# Use treeinterpreter to get the prediction, bias, and contributions
prediction, bias, contributions = ti.predict(model, X)

# Select the per-feature contributions for two different classes
shap_values_class_0 = contributions[:, :, 0]  # Class 0
shap_values_class_1 = contributions[:, :, 1]  # Class 1

# Base values (bias terms) for each class
base_values_class_0 = bias[:, 0]
base_values_class_1 = bias[:, 1]

# Create SHAP Explanation objects for each class (no shap.Explainer is needed for plotting)
shap_object_class_0 = shap.Explanation(
    values=shap_values_class_0,
    base_values=base_values_class_0,
    data=X,
    feature_names=data.feature_names
)

shap_object_class_1 = shap.Explanation(
    values=shap_values_class_1,
    base_values=base_values_class_1,
    data=X,
    feature_names=data.feature_names
)

# Plotting SHAP visuals for Class 0 with custom x-axis label
print("Class 0 SHAP Visualizations:")
shap.summary_plot(shap_object_class_0.values, shap_object_class_0.data, feature_names=shap_object_class_0.feature_names, show=False)
plt.gca().set_xlabel("Custom ----> LABEL 0")  # Modify the x-axis label
plt.show()  # Display the plot with the updated label
shap.waterfall_plot(shap_object_class_0[0], max_display=10)
mean_abs_shap_values_class_0 = np.abs(shap_object_class_0.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values_class_0, feature_names=shap_object_class_0.feature_names)

# Plotting SHAP visuals for Class 1 with custom x-axis label
print("\nClass 1 SHAP Visualizations:")
shap.summary_plot(shap_object_class_1.values, shap_object_class_1.data, feature_names=shap_object_class_1.feature_names, show=False)
plt.gca().set_xlabel("Custom ----> LABEL 1")  # Modify the x-axis label
plt.show()  # Display the plot with the updated label
shap.waterfall_plot(shap_object_class_1[0], max_display=10)
mean_abs_shap_values_class_1 = np.abs(shap_object_class_1.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values_class_1, feature_names=shap_object_class_1.feature_names)
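When comparing the per-class plots, keep in mind that the classes are not independent. For a classifier, `treeinterpreter` works with class probabilities, and every node's class distribution sums to 1. As a result, the bias rows sum to 1, and for each instance and feature the contributions summed over all classes are 0.

In a two-class problem, the class-1 plots are therefore mirror images of the class-0 plots. With the three Wine classes, a feature that pushes toward class 0 must push away from class 1 and/or class 2 by the same total amount.

The hypothetical helper below checks both properties. The toy arrays are made up; in the notebook you would pass the `bias` and `contributions` from `ti.predict`.

```python
import numpy as np


def check_class_sums(bias, contributions, atol=1e-8):
    """Hypothetical helper: verify the cross-class properties of a probability decomposition.

    bias: (n_samples, n_classes); contributions: (n_samples, n_features, n_classes).
    """
    bias_sums_to_one = np.allclose(bias.sum(axis=1), 1.0, atol=atol)
    contributions_cancel = np.allclose(contributions.sum(axis=2), 0.0, atol=atol)
    return bias_sums_to_one and contributions_cancel


# Toy arrays: 1 sample, 2 features, 3 classes
bias = np.array([[0.3, 0.4, 0.3]])
contributions = np.array([[[0.25, -0.15, -0.10], [0.05, -0.05, 0.0]]])
print(check_class_sums(bias, contributions))  # True
```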


## Example 2: Customizing SHAP Plots for Multiple Classes

def run_example_2():
    """
    This example demonstrates how to generate SHAP plots for different classes of a RandomForest model
    trained on the Wine dataset. We customize the SHAP plots to display results for two different classes.
    """

    # Load the Wine dataset and train a RandomForest model
    data = load_wine()  # Load the wine dataset, which has multiple classes
    X, y = data.data, data.target  # Features and target variable

    # Create and train a RandomForestClassifier
    model = RandomForestClassifier(random_state=0)  # Instantiate the model (fixed seed for reproducibility)
    model.fit(X, y)  # Train the model on the entire dataset

    # Use treeinterpreter to get the prediction, bias, and contributions
    # `ti.predict` returns:
    # - `prediction`: the predicted class probabilities, shape (n_samples, n_classes)
    # - `bias`: the class distribution at each tree's root, averaged over the forest, shape (n_samples, n_classes)
    # - `contributions`: the contribution of each feature, shape (n_samples, n_features, n_classes)
    prediction, bias, contributions = ti.predict(model, X)

    # Select the contributions for two different classes (class 0 and class 1)
    # `contributions` is a 3D array (samples, features, classes)
    shap_values_class_0 = contributions[:, :, 0]  # Contributions for Class 0
    shap_values_class_1 = contributions[:, :, 1]  # Contributions for Class 1

    # Base values for each class
    # The base values represent the bias term for the specific class
    base_values_class_0 = bias[:, 0]  # Bias for Class 0
    base_values_class_1 = bias[:, 1]  # Bias for Class 1

    # Create SHAP Explanation objects for each class
    # These objects only hold data for SHAP's plotting functions; no shap.Explainer is needed

    # SHAP object for Class 0
    shap_object_class_0 = shap.Explanation(
        values=shap_values_class_0,  # SHAP values for Class 0
        base_values=base_values_class_0,  # Base values for Class 0
        data=X,  # Input features
        feature_names=data.feature_names  # Feature names
    )

    # SHAP object for Class 1
    shap_object_class_1 = shap.Explanation(
        values=shap_values_class_1,  # SHAP values for Class 1
        base_values=base_values_class_1,  # Base values for Class 1
        data=X,  # Input features
        feature_names=data.feature_names  # Feature names
    )

    # Plotting SHAP visuals for Class 0 with a custom x-axis label
    print("Class 0 SHAP Visualizations:")
    # Summary plot for Class 0
    shap.summary_plot(shap_object_class_0.values, shap_object_class_0.data, feature_names=shap_object_class_0.feature_names, show=False)
    plt.gca().set_xlabel("Custom ----> LABEL 0")  # Modify the x-axis label
    plt.show()  # Display the plot with the updated label

    # Waterfall plot for the first instance in Class 0
    shap.waterfall_plot(shap_object_class_0[0], max_display=10)

    # Bar plot for Class 0 showing mean absolute SHAP values
    mean_abs_shap_values_class_0 = np.abs(shap_object_class_0.values).mean(axis=0)
    shap.bar_plot(mean_abs_shap_values_class_0, feature_names=shap_object_class_0.feature_names)

    # Plotting SHAP visuals for Class 1 with a custom x-axis label
    print("\nClass 1 SHAP Visualizations:")
    # Summary plot for Class 1
    shap.summary_plot(shap_object_class_1.values, shap_object_class_1.data, feature_names=shap_object_class_1.feature_names, show=False)
    plt.gca().set_xlabel("Custom ----> LABEL 1")  # Modify the x-axis label
    plt.show()  # Display the plot with the updated label

    # Waterfall plot for the first instance in Class 1
    shap.waterfall_plot(shap_object_class_1[0], max_display=10)

    # Bar plot for Class 1 showing mean absolute SHAP values
    mean_abs_shap_values_class_1 = np.abs(shap_object_class_1.values).mean(axis=0)
    shap.bar_plot(mean_abs_shap_values_class_1, feature_names=shap_object_class_1.feature_names)

# Run the example
run_example_2()
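The waterfall plots above explain instance 0 for each class. The Wine data is ordered by class, so instance 0 belongs to class 0, and its class-1 waterfall most likely explains a low class-1 probability. To explain a typical member of a class instead, you could pick the instance with the highest predicted probability for that class.

A sketch follows, with `most_confident_instance` as a hypothetical helper and a made-up toy `prediction` array. In the notebook, `prediction` is the first array returned by `ti.predict`.

```python
import numpy as np


def most_confident_instance(prediction, class_index):
    """Hypothetical helper: row index with the highest predicted probability for class_index."""
    return int(np.argmax(prediction[:, class_index]))


# Toy predicted probabilities: 3 samples, 3 classes
prediction = np.array([
    [0.90, 0.05, 0.05],
    [0.10, 0.80, 0.10],
    [0.20, 0.30, 0.50],
])
idx = most_confident_instance(prediction, class_index=1)
print(idx)  # 1

# Assumed notebook usage:
# shap.waterfall_plot(shap_object_class_1[most_confident_instance(prediction, 1)], max_display=10)
```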


In [None]:
import numpy as np
import shap
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from treeinterpreter import treeinterpreter as ti
import matplotlib.pyplot as plt

# Load dataset and train the model
data = load_iris()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=0)  # fixed seed for reproducibility
model.fit(X, y)

# Use treeinterpreter to get the prediction, bias, and contributions
prediction, bias, contributions = ti.predict(model, X)

# contributions.shape is (n_samples, n_features, n_classes)
# We reduce the dimensionality by selecting one class
shap_values = contributions[:, :, 0]  # Choose class 0 for visualization

# Create a SHAP Explanation object (no shap.Explainer is needed for plotting)
shap_object = shap.Explanation(
    values=shap_values,
    base_values=bias[:, 0],  # Base values should match the selected class
    data=X,
    feature_names=data.feature_names
)

# Generate SHAP summary plot (beeswarm plot) and modify the x-axis label directly
shap.summary_plot(shap_object.values, shap_object.data, feature_names=shap_object.feature_names, show=False)
plt.gca().set_xlabel("CUSTOM ------> CUSTOM ----->")  # Modify the x-axis label
plt.show()  # Display the plot with the updated label

# Generate SHAP waterfall plot for the first instance
shap.waterfall_plot(shap_object[0])

# For the bar plot, extract the mean absolute values across all instances
mean_abs_shap_values = np.abs(shap_object.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values, feature_names=shap_object.feature_names)
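The custom-label pattern is repeated in several cells below: plot with `show=False`, edit the current axes, then call `plt.show()`. One way to avoid the repetition is to wrap it in a small helper. `show_with_xlabel` is hypothetical, and the commented usage lines assume the `shap_object` built in the cell above. The pattern works because `shap.summary_plot(..., show=False)` leaves its figure open as the current matplotlib figure.

```python
import matplotlib.pyplot as plt


def show_with_xlabel(xlabel, show=True):
    """Hypothetical helper: set the x-axis label of the current matplotlib axes, then display it.

    Call it right after a plotting function that was given show=False.
    """
    ax = plt.gca()
    ax.set_xlabel(xlabel)
    if show:
        plt.show()
    return ax


# Assumed usage in the notebook:
# shap.summary_plot(shap_object.values, shap_object.data,
#                   feature_names=shap_object.feature_names, show=False)
# show_with_xlabel("treeinterpreter contribution to P(class 0)")
```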

## Example 2

In [None]:
# Load the Wine dataset
data = load_wine()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=0)  # fixed seed for reproducibility
model.fit(X, y)

# Use treeinterpreter to get the prediction, bias, and contributions
prediction, bias, contributions = ti.predict(model, X)

# contributions.shape is (n_samples, n_features, n_classes)
# We reduce the dimensionality by selecting one class
shap_values = contributions[:, :, 0]  # Choose class 0 for visualization

# Create a SHAP Explanation object (no shap.Explainer is needed for plotting)
shap_object = shap.Explanation(
    values=shap_values,
    base_values=bias[:, 0],  # Base values should match the selected class
    data=X,
    feature_names=data.feature_names
)

# Generate SHAP plots
shap.summary_plot(shap_object.values, shap_object.data, feature_names=shap_object.feature_names)
shap.waterfall_plot(shap_object[0])  # Example for the first instance

# For the bar plot, extract the mean absolute values across all instances
mean_abs_shap_values = np.abs(shap_object.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values, feature_names=shap_object.feature_names)

## Example 2 with Custom X-Axis Labels

In [None]:
import numpy as np
import shap
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from treeinterpreter import treeinterpreter as ti
import matplotlib.pyplot as plt

# Load the Wine dataset and train the model
data = load_wine()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=0)  # fixed seed for reproducibility
model.fit(X, y)

# Use treeinterpreter to get the prediction, bias, and contributions
prediction, bias, contributions = ti.predict(model, X)

# contributions.shape is (n_samples, n_features, n_classes)
# We reduce the dimensionality by selecting one class
shap_values = contributions[:, :, 0]  # Choose class 0 for visualization

# Create a SHAP Explanation object (no shap.Explainer is needed for plotting)
shap_object = shap.Explanation(
    values=shap_values,
    base_values=bias[:, 0],  # Base values should match the selected class
    data=X,
    feature_names=data.feature_names
)

# Generate SHAP summary plot (beeswarm plot) and modify the x-axis label directly
shap.summary_plot(shap_object.values, shap_object.data, feature_names=shap_object.feature_names, show=False)
plt.gca().set_xlabel("Custom -----> CUSTOM -----> CUSTOM")  # Modify the x-axis label
plt.show()  # Display the plot with the updated label

# Generate SHAP waterfall plot for the first instance
shap.waterfall_plot(shap_object[0])

# For the bar plot, extract the mean absolute values across all instances
mean_abs_shap_values = np.abs(shap_object.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values, feature_names=shap_object.feature_names)

## Example 3

In [None]:
# Load the Wine dataset
data = load_wine()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=0)  # fixed seed for reproducibility
model.fit(X, y)

# Use treeinterpreter to get the prediction, bias, and contributions
prediction, bias, contributions = ti.predict(model, X)

# Select the per-feature contributions for two different classes
shap_values_class_0 = contributions[:, :, 0]  # Class 0
shap_values_class_1 = contributions[:, :, 1]  # Class 1

# Base values (bias terms) for each class
base_values_class_0 = bias[:, 0]
base_values_class_1 = bias[:, 1]

# Create SHAP Explanation objects for each class (no shap.Explainer is needed for plotting)
shap_object_class_0 = shap.Explanation(
    values=shap_values_class_0,
    base_values=base_values_class_0,
    data=X,
    feature_names=data.feature_names
)

shap_object_class_1 = shap.Explanation(
    values=shap_values_class_1,
    base_values=base_values_class_1,
    data=X,
    feature_names=data.feature_names
)

# Plotting SHAP visuals for Class 0
print("Class 0 SHAP Visualizations:")
shap.summary_plot(shap_object_class_0.values, shap_object_class_0.data, feature_names=shap_object_class_0.feature_names)
shap.waterfall_plot(shap_object_class_0[0], max_display=10)
mean_abs_shap_values_class_0 = np.abs(shap_object_class_0.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values_class_0, feature_names=shap_object_class_0.feature_names)

# Plotting SHAP visuals for Class 1
print("\nClass 1 SHAP Visualizations:")
shap.summary_plot(shap_object_class_1.values, shap_object_class_1.data, feature_names=shap_object_class_1.feature_names)
shap.waterfall_plot(shap_object_class_1[0], max_display=10)
mean_abs_shap_values_class_1 = np.abs(shap_object_class_1.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values_class_1, feature_names=shap_object_class_1.feature_names)
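The per-class cells repeat the same steps for each class index. Only `contributions[:, :, k]` and `bias[:, k]` change, so the slicing can be done in a loop over every class. That loop also covers class 2, which the cells above skip.

Here is a sketch under that assumption. `split_by_class` is a hypothetical helper and the toy arrays are made up. In the notebook, each `(values, base_values)` pair would be passed to `shap.Explanation` together with `X` and `data.feature_names`.

```python
import numpy as np


def split_by_class(bias, contributions):
    """Hypothetical helper: map each class index to its (values, base_values) pair.

    values: contributions[:, :, k] with shape (n_samples, n_features)
    base_values: bias[:, k] with shape (n_samples,)
    """
    n_classes = contributions.shape[2]
    return {k: (contributions[:, :, k], bias[:, k]) for k in range(n_classes)}


# Toy arrays: 4 samples, 2 features, 3 classes (random values, shapes only)
rng = np.random.default_rng(0)
toy_contributions = rng.normal(size=(4, 2, 3))
toy_bias = np.full((4, 3), 1 / 3)

for k, (values, base_values) in split_by_class(toy_bias, toy_contributions).items():
    print(k, values.shape, base_values.shape)  # first line: 0 (4, 2) (4,)

# Assumed notebook usage, one Explanation per class:
# for k, (values, base_values) in split_by_class(bias, contributions).items():
#     explanation = shap.Explanation(values=values, base_values=base_values,
#                                    data=X, feature_names=data.feature_names)
#     shap.summary_plot(explanation.values, explanation.data, feature_names=explanation.feature_names)
```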

## Example 3 with Custom X-Axis Labels

In [None]:
import numpy as np
import shap
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from treeinterpreter import treeinterpreter as ti
import matplotlib.pyplot as plt

# Load the Wine dataset and train the model
data = load_wine()
X, y = data.data, data.target
model = RandomForestClassifier(random_state=0)  # fixed seed for reproducibility
model.fit(X, y)

# Use treeinterpreter to get the prediction, bias, and contributions
prediction, bias, contributions = ti.predict(model, X)

# Select the per-feature contributions for two different classes
shap_values_class_0 = contributions[:, :, 0]  # Class 0
shap_values_class_1 = contributions[:, :, 1]  # Class 1

# Base values (bias terms) for each class
base_values_class_0 = bias[:, 0]
base_values_class_1 = bias[:, 1]

# Create SHAP Explanation objects for each class (no shap.Explainer is needed for plotting)
shap_object_class_0 = shap.Explanation(
    values=shap_values_class_0,
    base_values=base_values_class_0,
    data=X,
    feature_names=data.feature_names
)

shap_object_class_1 = shap.Explanation(
    values=shap_values_class_1,
    base_values=base_values_class_1,
    data=X,
    feature_names=data.feature_names
)

# Plotting SHAP visuals for Class 0
print("Class 0 SHAP Visualizations:")
shap.summary_plot(shap_object_class_0.values, shap_object_class_0.data, feature_names=shap_object_class_0.feature_names, show=False)
plt.gca().set_xlabel("Custom ----> LABEL 0")  # Modify the x-axis label
plt.show()  # Display the plot with the updated label
shap.waterfall_plot(shap_object_class_0[0], max_display=10)
mean_abs_shap_values_class_0 = np.abs(shap_object_class_0.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values_class_0, feature_names=shap_object_class_0.feature_names)

# Plotting SHAP visuals for Class 1
print("\nClass 1 SHAP Visualizations:")
shap.summary_plot(shap_object_class_1.values, shap_object_class_1.data, feature_names=shap_object_class_1.feature_names, show=False)
plt.gca().set_xlabel("Custom ----> LABEL 1")  # Modify the x-axis label
plt.show()  # Display the plot with the updated label
shap.waterfall_plot(shap_object_class_1[0], max_display=10)
mean_abs_shap_values_class_1 = np.abs(shap_object_class_1.values).mean(axis=0)
shap.bar_plot(mean_abs_shap_values_class_1, feature_names=shap_object_class_1.feature_names)