# Arize Tutorial: SHAP Value For Neural Networks

Let's get started on using Arize! ✨

Arize helps you visualize your model performance, understand drift & data quality issues, and share insights learned from your models.

**SHAP (SHapley Additive exPlanations)** is a game-theoretic approach to explaining the output of any machine learning model.
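
To make the game-theoretic idea concrete, here is a minimal sketch of exact Shapley values for a tiny cooperative game. The `toy_value` payoff function and the player names `a`, `b`, `c` are hypothetical and only stand in for "features contributing to a prediction"; the SHAP package uses far more efficient approximations than this brute-force enumeration.

```python
from itertools import permutations


def shapley_values(players, value):
    """Exact Shapley values: average each player's marginal contribution over all orderings."""
    totals = {p: 0.0 for p in players}
    orderings = list(permutations(players))
    for order in orderings:
        coalition = frozenset()
        for p in order:
            with_p = coalition | {p}
            # Marginal contribution of p when joining the current coalition
            totals[p] += value(with_p) - value(coalition)
            coalition = with_p
    return {p: total / len(orderings) for p, total in totals.items()}


def toy_value(coalition):
    """Hypothetical payoff: 'a' and 'b' contribute alone and share an interaction bonus."""
    payoff = 0.0
    if "a" in coalition:
        payoff += 10.0
    if "b" in coalition:
        payoff += 20.0
    if "a" in coalition and "b" in coalition:
        payoff += 6.0
    return payoff


players = ["a", "b", "c"]
phi = shapley_values(players, toy_value)
print(phi)

# Additivity: the attributions sum to value(all players) - value(no players)
total_gain = toy_value(frozenset(players)) - toy_value(frozenset())
print(sum(phi.values()), total_gain)
```

The interaction bonus is split evenly between `a` and `b`, and `c`, which never changes the payoff, receives zero. The same additivity property is what lets SHAP values be read as per-feature contributions to a single prediction.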

For neural network models, we can use `GradientExplainer` from the `SHAP` package to generate SHAP values ([API reference](https://shap-lrjball.readthedocs.io/en/latest/generated/shap.GradientExplainer.html)).

This demo consists of three parts.

1.   Train a neural network on tabular data using `tf.keras`
2.   Generate SHAP values using `shap.GradientExplainer`
3.   Log the predictions and SHAP values to the Arize platform


# 1. Download Data and Train Model
For this demo, we use the pulsar classification dataset from UCI ([link](https://archive.ics.uci.edu/ml/datasets/HTRU2)).

In [52]:
import pandas as pd

df = pd.read_csv("https://storage.googleapis.com/arize-assets/fixtures/UCI/HTRU_2.zip")
features = df.columns.drop("class")
df

| Unnamed: 0 | mean_ip | std_ip | kurtosis_ip | skewness_ip | mean_dm | std_dm | kurtosis_dm | skewness_dm | class |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 140.562500 | 55.683782 | -0.234571 | -0.699648 | 3.199833 | 19.110426 | 7.975532 | 74.242225 | 0 |
| 1 | 102.507812 | 58.882430 | 0.465318 | -0.515088 | 1.677258 | 14.860146 | 10.576487 | 127.393580 | 0 |
| 2 | 103.015625 | 39.341649 | 0.323328 | 1.051164 | 3.121237 | 21.744669 | 7.735822 | 63.171909 | 0 |
| 3 | 136.750000 | 57.178449 | -0.068415 | -0.636238 | 3.642977 | 20.959280 | 6.896499 | 53.593661 | 0 |
| 4 | 88.726562 | 40.672225 | 0.600866 | 1.123492 | 1.178930 | 11.468720 | 14.269573 | 252.567306 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 17893 | 136.429688 | 59.847421 | -0.187846 | -0.738123 | 1.296823 | 12.166062 | 15.450260 | 285.931022 | 0 |
| 17894 | 122.554688 | 49.485605 | 0.127978 | 0.323061 | 16.409699 | 44.626893 | 2.945244 | 8.297092 | 0 |
| 17895 | 119.335938 | 59.935939 | 0.159363 | -0.743025 | 21.430602 | 58.872000 | 2.499517 | 4.595173 | 0 |
| 17896 | 114.507812 | 53.902400 | 0.201161 | -0.024789 | 1.946488 | 13.381731 | 10.007967 | 134.238910 | 0 |


Split data into train and test sets; standardize feature variables.

In [53]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    df[features], df["class"], test_size=1000
)

from sklearn.preprocessing import StandardScaler

sc = StandardScaler()
X_train2 = sc.fit_transform(X_train)
X_test2 = sc.transform(X_test)
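
The cell above fits the scaler on the training set only and then reuses those statistics for the test set. As a rough sketch of what that means, the snippet below reproduces the arithmetic with the standard library on a single hypothetical column; `fit_scaler` and `transform` are illustrative helper names, not scikit-learn functions. `StandardScaler` uses the population standard deviation (`ddof=0`), which is why `pstdev` is used here.

```python
from statistics import fmean, pstdev


def fit_scaler(column):
    """Learn the mean and population standard deviation from training data only."""
    return fmean(column), pstdev(column)


def transform(column, mean, std):
    """Apply previously learned statistics to any data (train or test)."""
    return [(x - mean) / std for x in column]


# Hypothetical values for one feature column
train = [2.0, 4.0, 6.0, 8.0]
test = [5.0, 10.0]

mean, std = fit_scaler(train)
train_scaled = transform(train, mean, std)
test_scaled = transform(test, mean, std)  # reuses the training statistics

print(train_scaled)
print(test_scaled)
```

The scaled training column has zero mean and unit variance by construction. The test column generally does not, which is expected: reusing the training statistics keeps information about the test set out of preprocessing.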

Train a neural network using `tf.keras`.

In [54]:
import tensorflow as tf

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Input(shape=(len(features),)),
        tf.keras.layers.Dense(16, activation="relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)

model.compile(
    optimizer="Adam", loss="binary_crossentropy", metrics=[tf.keras.metrics.AUC()]
)
model.fit(X_train2, y_train)
model.evaluate(X_test2, y_test)



[0.09806293249130249, 0.9669120907783508]

# 2. Generate SHAP Values
Install the SHAP package.

In [55]:
!pip install -q shap
import shap

Use `GradientExplainer` to generate SHAP values from neural network models. ([API reference](https://shap-lrjball.readthedocs.io/en/latest/generated/shap.GradientExplainer.html))

In [None]:
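# Use the scaled training set as the background distribution
e = shap.GradientExplainer(model, X_train2)
# Assumes a SHAP version where shap_values returns a list with one array per
# model output; [0] selects the array for the single sigmoid output
shap_values = pd.DataFrame(e.shap_values(X_test2)[0], columns=features)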
shap_values
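
`GradientExplainer` is documented as using expected gradients, an extension of integrated gradients. The following is a simplified, deterministic sketch of that idea on a hypothetical logistic model with hand-written gradients. The weights `W`, bias `B`, and the `background` samples are made up for illustration, and the real explainer samples reference points and interpolation factors randomly rather than enumerating them on a grid.

```python
import math

# Hypothetical toy model: logistic regression with fixed weights
W = [1.5, -2.0, 0.5]
B = 0.1


def f(x):
    """Model output: sigmoid of a linear score."""
    z = sum(w * xi for w, xi in zip(W, x)) + B
    return 1.0 / (1.0 + math.exp(-z))


def grad_f(x):
    """Analytic gradient of the sigmoid output with respect to each input."""
    p = f(x)
    return [p * (1.0 - p) * w for w in W]


def expected_gradients(x, background, steps=200):
    """Average (x - ref) * gradient along the straight path from each reference to x."""
    n = len(x)
    attributions = [0.0] * n
    for ref in background:
        for k in range(steps):
            alpha = (k + 0.5) / steps  # midpoint rule on [0, 1]
            point = [r + alpha * (xi - r) for xi, r in zip(x, ref)]
            g = grad_f(point)
            for i in range(n):
                attributions[i] += (x[i] - ref[i]) * g[i]
    scale = 1.0 / (len(background) * steps)
    return [a * scale for a in attributions]


background = [[0.0, 0.0, 0.0], [1.0, 0.5, -1.0], [-0.5, 1.0, 2.0]]
x = [1.0, -1.0, 0.5]

phi = expected_gradients(x, background)
baseline = sum(f(r) for r in background) / len(background)

print(phi)
# The attributions should approximately sum to f(x) minus the mean background output
print(sum(phi), f(x) - baseline)
```

In this sketch the attributions sum (up to numerical integration error) to the gap between the prediction and the average prediction over the background data, which is the property that makes per-feature SHAP values interpretable as contributions.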

# 3. Log Predictions to Arize
We'll use the following helper functions to generate prediction IDs and timestamps to simulate a production environment.

In [26]:
import uuid
import numpy as np
from datetime import datetime, timedelta

# Prediction ID is required for logging any dataset
def generate_prediction_ids(X):
    # One unique ID per row of the dataset passed in
    return pd.Series((str(uuid.uuid4()) for _ in range(len(X))))


# OPTIONAL: We can directly specify when inferences were made
def simulate_production_timestamps(X, days=30):
    t = datetime.now()
    current_t, earlier_t = t.timestamp(), (t - timedelta(days=days)).timestamp()
    return pd.Series(np.linspace(earlier_t, current_t, num=len(X)))

Assemble a pandas DataFrame.

In [19]:
y_test_score = model.predict(X_test2).flatten()
prediction_label = list(
    map(lambda x: "pulsar" if x > 0.5 else "non-pulsar", y_test_score)
)
actual_label = list(map(lambda x: "pulsar" if x > 0.5 else "non-pulsar", y_test))

shap_values_column_names_mapping = {f"{feat}": f"{feat}_shap" for feat in features}

production_dataset = pd.concat(
    [
        X_test.reset_index(drop=True),
        pd.DataFrame(
            {
                "prediction_id": generate_prediction_ids(X_test),
                "prediction_ts": simulate_production_timestamps(X_test),
                "prediction_label": prediction_label,
                "prediction_score": y_test_score,
                "actual_label": actual_label,
            }
        ),
        shap_values.rename(columns=shap_values_column_names_mapping),
    ],
    axis=1,
)
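
One detail in the cell above is easy to miss: `pd.concat(..., axis=1)` aligns rows by index label, not by position, which is why `X_test` is passed through `reset_index(drop=True)` first. The sketch below illustrates this with two hypothetical toy frames. `left` stands in for a frame that kept its shuffled index after `train_test_split`, and `right` for a freshly built frame such as `shap_values`.

```python
import pandas as pd

# Hypothetical frame with a shuffled index (like X_test after the split)
left = pd.DataFrame({"feature": [1.0, 2.0, 3.0]}, index=[7, 9, 5])
# Hypothetical frame with a default 0..n-1 index (like shap_values)
right = pd.DataFrame({"feature_shap": [0.1, 0.2, 0.3]})

# Without resetting the index, no labels match and rows are padded with NaN
misaligned = pd.concat([left, right], axis=1)

# Resetting the index makes the rows line up by position
aligned = pd.concat([left.reset_index(drop=True), right], axis=1)

print(misaligned.shape, int(misaligned.isna().sum().sum()))
print(aligned.shape, int(aligned.isna().sum().sum()))
```

A quick check such as `production_dataset.isna().sum()` before logging is a cheap way to catch this kind of misalignment.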

Initialize the Arize client.

In [None]:
!pip install -q arize
from arize.pandas.logger import Client, Schema
from arize.utils.types import ModelTypes, Environments

SPACE_KEY = "SPACE_KEY"
API_KEY = "API_KEY"

# Fail fast if the placeholder credentials were not replaced
if SPACE_KEY == "SPACE_KEY" or API_KEY == "API_KEY":
    raise ValueError("❌ NEED TO CHANGE SPACE_KEY AND/OR API_KEY")

arize_client = Client(space_key=SPACE_KEY, api_key=API_KEY)
print("✅ Import and Setup Arize Client Done! Now we can start using Arize!")

Log data to Arize.

In [None]:
# Define a Schema() object for Arize to pick up data from the correct columns for logging
production_schema = Schema(
    prediction_id_column_name="prediction_id",  # REQUIRED
    timestamp_column_name="prediction_ts",
    prediction_label_column_name="prediction_label",
    prediction_score_column_name="prediction_score",
    actual_label_column_name="actual_label",
    feature_column_names=list(features),  # plain list of names instead of a pandas Index
    shap_values_column_names=shap_values_column_names_mapping,
)

# arize_client.log returns a Response object from Python's requests module
response = arize_client.log(
    dataframe=production_dataset,
    schema=production_schema,
    model_id="pulsar",
    model_type=ModelTypes.SCORE_CATEGORICAL,
    environment=Environments.PRODUCTION,
)

# If successful, the server will return a status_code of 200
if response.status_code != 200:
    print(
        f"❌ logging failed with response code {response.status_code}, {response.text}"
    )
else:
    print(
        f"✅ You have successfully logged {len(production_dataset)} data points to Arize!"
    )

# Conclusion
You now know how to log SHAP values for neural networks to the Arize platform. Go to [Arize](https://app.arize.com/) to analyze and monitor the logged SHAP values.

### Overview
Arize is an end-to-end ML observability and model monitoring platform. The platform is designed to help ML engineers and data science practitioners surface and fix issues with ML models in production faster with:
- Automated ML monitoring and model monitoring
- Workflows to troubleshoot model performance
- Real-time visualizations for model performance monitoring, data quality monitoring, and drift monitoring
- Model prediction cohort analysis
- Pre-deployment model validation
- Integrated model explainability

### Website
Visit Us At: https://arize.com/model-monitoring/

### Additional Resources
- [What is ML observability?](https://arize.com/what-is-ml-observability/)
- [Playbook to model monitoring in production](https://arize.com/the-playbook-to-monitor-your-models-performance-in-production/)
- [Using statistical distance metrics for ML monitoring and observability](https://arize.com/using-statistical-distance-metrics-for-machine-learning-observability/)
- [ML infrastructure tools for data preparation](https://arize.com/ml-infrastructure-tools-for-data-preparation/)
- [ML infrastructure tools for model building](https://arize.com/ml-infrastructure-tools-for-model-building/)
- [ML infrastructure tools for production](https://arize.com/ml-infrastructure-tools-for-production-part-1/)
- [ML infrastructure tools for model deployment and model serving](https://arize.com/ml-infrastructure-tools-for-production-part-2-model-deployment-and-serving/)
- [ML infrastructure tools for ML monitoring and observability](https://arize.com/ml-infrastructure-tools-ml-observability/)

Visit the [Arize Blog](https://arize.com/blog) and [Resource Center](https://arize.com/resource-hub/) for more resources on ML observability and model monitoring.