# Explain Model Predictions with Amazon SageMaker Clarify

There are expanding business needs and legislative regulations that require explanations of _why_ a model made the decision it did. SageMaker Clarify uses SHAP (SHapley Additive exPlanations) to explain the contribution that each input feature makes to the final decision.

In [1]:
import boto3
import sagemaker
import pandas as pd
import numpy as np

sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()
region = boto3.Session().region_name

sm = boto3.Session().client(service_name="sagemaker", region_name=region)

In [2]:
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format='retina'

# Test data for explainability

We created the test data in JSON Lines format to match the model's input: each line holds a `features` array containing the review text and the product category.
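The file was prepared ahead of time. As a minimal sketch, assuming the source records are (review text, product category) pairs, a file in the same shape could be produced with the standard library. The helper name `write_explainability_jsonl` is hypothetical; compact separators are used so each line matches the layout printed by `head -n 1`.

```python
import json
from typing import Iterable, Tuple


def write_explainability_jsonl(records: Iterable[Tuple[str, str]], path: str) -> int:
    """Write (review_body, product_category) pairs as JSON Lines.

    Each output line is {"features":[review_body,product_category]}.
    Returns the number of lines written.
    """
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        for review_body, product_category in records:
            # Compact separators reproduce the format shown in the sample line
            line = json.dumps({"features": [review_body, product_category]}, separators=(",", ":"))
            f.write(line + "\n")
            count += 1
    return count
```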

In [3]:
test_data_explainability_path = "./data-clarify/test_data_explainability.jsonl"

In [4]:
!head -n 1 $test_data_explainability_path

{"features":["I have been using Quicken for years now and it does everything that I need it to accomplish for my personal finances.","Digital_Software"]}


# Upload the data

In [5]:
test_data_explainablity_s3_uri = sess.upload_data(
    bucket=bucket, key_prefix="bias/test_data_explainability", path=test_data_explainability_path
)
test_data_explainablity_s3_uri

's3://sagemaker-us-east-1-298039562326/bias/test_data_explainability/test_data_explainability.jsonl'

In [6]:
!aws s3 ls $test_data_explainablity_s3_uri

2021-05-15 20:12:04     190783 test_data_explainability.jsonl


In [7]:
%store test_data_explainablity_s3_uri

Stored 'test_data_explainablity_s3_uri' (str)


# List Pipeline Execution Steps


In [8]:
%store -r pipeline_name

In [9]:
print(pipeline_name)

BERT-pipeline-1621102217


In [10]:
%%time

import time
from pprint import pprint

executions_response = sm.list_pipeline_executions(PipelineName=pipeline_name)["PipelineExecutionSummaries"]
pipeline_execution_status = executions_response[0]["PipelineExecutionStatus"]
print(pipeline_execution_status)

# Poll the most recent execution until it is no longer running
while pipeline_execution_status == "Executing":
    # Wait between polls so the loop does not hammer (and get throttled by) the SageMaker API
    time.sleep(30)
    try:
        executions_response = sm.list_pipeline_executions(PipelineName=pipeline_name)["PipelineExecutionSummaries"]
        pipeline_execution_status = executions_response[0]["PipelineExecutionStatus"]
    except Exception as e:
        print("Please wait... ({})".format(e))

pprint(executions_response)

Succeeded
[{'PipelineExecutionArn': 'arn:aws:sagemaker:us-east-1:298039562326:pipeline/bert-pipeline-1621102217/execution/34f5wo3bayx8',
  'PipelineExecutionDisplayName': 'execution-1621102228350',
  'PipelineExecutionStatus': 'Succeeded',
  'StartTime': datetime.datetime(2021, 5, 15, 18, 10, 28, 189000, tzinfo=tzlocal())}]
CPU times: user 6.9 ms, sys: 7.81 ms, total: 14.7 ms
Wall time: 123 ms


In [11]:
pipeline_execution_status = executions_response[0]["PipelineExecutionStatus"]
print(pipeline_execution_status)

Succeeded


In [12]:
pipeline_execution_arn = executions_response[0]["PipelineExecutionArn"]
print(pipeline_execution_arn)

arn:aws:sagemaker:us-east-1:298039562326:pipeline/bert-pipeline-1621102217/execution/34f5wo3bayx8


In [13]:
from pprint import pprint

steps = sm.list_pipeline_execution_steps(PipelineExecutionArn=pipeline_execution_arn)

pprint(steps)

{'PipelineExecutionSteps': [{'EndTime': datetime.datetime(2021, 5, 15, 18, 48, 17, 743000, tzinfo=tzlocal()),
                             'Metadata': {'RegisterModel': {'Arn': 'arn:aws:sagemaker:us-east-1:298039562326:model-package/bert-reviews-1621102219/1'}},
                             'StartTime': datetime.datetime(2021, 5, 15, 18, 48, 16, 822000, tzinfo=tzlocal()),
                             'StepName': 'RegisterModel',
                             'StepStatus': 'Succeeded'},
                            {'EndTime': datetime.datetime(2021, 5, 15, 18, 48, 17, 529000, tzinfo=tzlocal()),
                             'Metadata': {'Model': {'Arn': 'arn:aws:sagemaker:us-east-1:298039562326:model/pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t'}},
                             'StartTime': datetime.datetime(2021, 5, 15, 18, 48, 16, 820000, tzinfo=tzlocal()),
                             'StepName': 'CreateModel',
                             'StepStatus': 'Succeeded'},
                  

# View Created Model

In [14]:
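# Find the model ARN produced by the pipeline's CreateModel step
model_arn = None
for execution_step in steps["PipelineExecutionSteps"]:
    if execution_step["StepName"] == "CreateModel":
        model_arn = execution_step["Metadata"]["Model"]["Arn"]
        break
if model_arn is None:
    raise ValueError("No CreateModel step found in {}".format(pipeline_execution_arn))
print(model_arn)

# The model name is the last segment of the ARN (".../model/<model-name>")
pipeline_model_name = model_arn.split("/")[-1]
print(pipeline_model_name)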

arn:aws:sagemaker:us-east-1:298039562326:model/pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t
pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t


# Set Up Model Explainability Analysis

In [15]:
from sagemaker import clarify

clarify_processor = clarify.SageMakerClarifyProcessor(
    role=role, instance_count=1, instance_type="ml.c5.2xlarge", sagemaker_session=sess
)

# Writing DataConfig and ModelConfig
A `DataConfig` object communicates basic information about data I/O to Clarify. We specify where to find the input dataset, where to store the output, the header names, the JSON key that holds each record's features (`features`), and the dataset type.

Similarly, the `ModelConfig` object communicates information about your trained model and how to send it requests. A `ModelPredictedLabelConfig` object can describe the format of your predictions; in this notebook we instead pass `model_scores="predicted_label"` to `run_explainability` to tell Clarify which field of the model output holds the prediction to explain.

**Note**: To avoid additional traffic to your production models, SageMaker Clarify sets up and tears down a dedicated endpoint (a "shadow endpoint") when processing. `ModelConfig` specifies the instance type and instance count for that endpoint.

## DataConfig

In [16]:
explainability_report_prefix = "bias/explainability-report-{}".format(pipeline_model_name)

explainability_output_path = "s3://{}/{}".format(bucket, explainability_report_prefix)

explainability_data_config = clarify.DataConfig(
    s3_data_input_path=test_data_explainablity_s3_uri,
    s3_output_path=explainability_output_path,
    headers=["review_body", "product_category"],
    features="features",
    dataset_type="application/jsonlines",
)
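With `dataset_type="application/jsonlines"`, `features="features"` names the key that holds each record's feature array, and `headers` labels the positions in that array. The sketch below is plain Python, not Clarify's own parser, and `label_features` is a hypothetical helper; it shows how the two headers pair up with the values of the first test record.

```python
import json
from typing import Dict, List

HEADERS: List[str] = ["review_body", "product_category"]


def label_features(jsonl_line: str, headers: List[str], features_key: str = "features") -> Dict[str, str]:
    """Pair each header name with the value at the same position in the record's feature array."""
    record = json.loads(jsonl_line)
    values = record[features_key]
    if len(values) != len(headers):
        raise ValueError(f"expected {len(headers)} features, got {len(values)}")
    return dict(zip(headers, values))


# First line of the test file, as printed by `head -n 1` earlier
line = (
    '{"features":["I have been using Quicken for years now and it does everything '
    'that I need it to accomplish for my personal finances.","Digital_Software"]}'
)
print(label_features(line, HEADERS))
```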

## ModelConfig

In [17]:
model_config = clarify.ModelConfig(
    model_name=pipeline_model_name,
    instance_type="ml.m5.4xlarge",
    instance_count=1,
    content_type="application/jsonlines",
    accept_type="application/jsonlines",
    content_template='{"features":$features}',
)
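`content_template` tells Clarify how to wrap each record's feature list into a request body for the dedicated endpoint it creates; per the SageMaker Python SDK documentation, the single `$features` placeholder is replaced by the features list. The sketch below imitates that substitution locally with Python's `string.Template` so you can preview the payload. It is an approximation for illustration, not Clarify's actual serializer, and `render_request` is a hypothetical name.

```python
import json
from string import Template
from typing import List

CONTENT_TEMPLATE = '{"features":$features}'  # same template passed to ModelConfig above


def render_request(features: List[str], template: str = CONTENT_TEMPLATE) -> str:
    """Approximate the JSON Lines request body built from one record's features."""
    # JSON-encode the list so strings are quoted and escaped correctly
    return Template(template).substitute(features=json.dumps(features))


payload = render_request(["ok", "Digital_Software"])
print(payload)  # → {"features":["ok", "Digital_Software"]}
assert json.loads(payload)["features"] == ["ok", "Digital_Software"]  # body is valid JSON
```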

## SHAPConfig

Here is more information about explainability and SHAP:
* https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-model-explainability.html
* https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-shapley-values.html
* https://papers.nips.cc/paper/2017/file/8a20a8621978632d76c43dfd28b67767-Paper.pdf
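A feature's Shapley value is its average marginal contribution to the prediction over all subsets ("coalitions") of the other features, with features outside a coalition replaced by baseline values:

$$\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(|N|-|S|-1)!}{|N|!}\,\bigl[f(S \cup \{i\}) - f(S)\bigr]$$

Clarify estimates this with Kernel SHAP over a sampled set of coalitions. For intuition only, the sketch below computes the exact values by brute-force enumeration for a toy two-feature model; `toy_model` and its inputs are hypothetical and unrelated to the BERT model from the pipeline.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence


def exact_shapley(
    predict: Callable[[List[float]], float],
    instance: Sequence[float],
    baseline: Sequence[float],
) -> List[float]:
    """Exact Shapley values by enumerating every coalition of features.

    Features outside a coalition take their baseline value. Cost is O(2^n),
    so this is only feasible for a handful of features.
    """
    n = len(instance)

    def value(coalition: set) -> float:
        x = [instance[j] if j in coalition else baseline[j] for j in range(n)]
        return predict(x)

    phis = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi += weight * (value(s | {i}) - value(s))
        phis.append(phi)
    return phis


def toy_model(x: List[float]) -> float:
    # Hypothetical score: linear terms plus an interaction between the two features
    return 2.0 * x[0] + 1.0 * x[1] + 0.5 * x[0] * x[1]


phi = exact_shapley(toy_model, instance=[1.0, 1.0], baseline=[0.0, 0.0])
print(phi)  # → [2.25, 1.25]
```

The attributions sum to f(instance) minus f(baseline) = 3.5, the additivity property that SHAP explanations rely on; the interaction term's 0.5 is split evenly between the two features.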

In [18]:
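shap_config = clarify.SHAPConfig(
    # Reference record used in place of "absent" features; same JSON Lines schema as the test data
    baseline=[{"features": ["ok", "Digital_Software"]}],
    # Size of the synthetic dataset Kernel SHAP generates per record; small values run faster but give noisier estimates
    num_samples=5,
    # Global importance = mean of absolute local SHAP values across records
    agg_method="mean_abs",
)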
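With `agg_method="mean_abs"`, the global importance reported for each feature is the mean of the absolute local SHAP values across all explained records. A minimal pure-Python sketch, using made-up local SHAP values rather than output from this job; `mean_abs` is a hypothetical helper name.

```python
from typing import Dict, List

FEATURES = ["review_body", "product_category"]

# Hypothetical local SHAP values: one row per explained record, one column per feature
local_shap: List[List[float]] = [
    [0.40, -0.05],
    [-0.30, 0.10],
    [0.20, 0.00],
]


def mean_abs(rows: List[List[float]], names: List[str]) -> Dict[str, float]:
    """Global importance per feature: mean of |local SHAP value| over all records."""
    n = len(rows)
    return {name: sum(abs(row[j]) for row in rows) / n for j, name in enumerate(names)}


global_importance = mean_abs(local_shap, FEATURES)
print({k: round(v, 3) for k, v in global_importance.items()})
```

Taking absolute values keeps positive and negative contributions from cancelling: a plain signed mean would give `review_body` only 0.1 here instead of 0.3.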

# Run Clarify Job

In [19]:
clarify_processor.run_explainability(
    model_config=model_config,
    model_scores="predicted_label",
    data_config=explainability_data_config,
    explainability_config=shap_config,
    wait=False,
    logs=False,
)


Job Name:  Clarify-Explainability-2021-05-15-20-12-05-386
Inputs:  [{'InputName': 'dataset', 'AppManaged': False, 'S3Input': {'S3Uri': 's3://sagemaker-us-east-1-298039562326/bias/test_data_explainability/test_data_explainability.jsonl', 'LocalPath': '/opt/ml/processing/input/data', 'S3DataType': 'S3Prefix', 'S3InputMode': 'File', 'S3DataDistributionType': 'FullyReplicated', 'S3CompressionType': 'None'}}, {'InputName': 'analysis_config', 'AppManaged': False, 'S3Input': {'S3Uri': 's3://sagemaker-us-east-1-298039562326/bias/explainability-report-pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t/analysis_config.json', 'LocalPath': '/opt/ml/processing/input/config', 'S3DataType': 'S3Prefix', 'S3InputMode': 'File', 'S3DataDistributionType': 'FullyReplicated', 'S3CompressionType': 'None'}}]
Outputs:  [{'OutputName': 'analysis_result', 'AppManaged': False, 'S3Output': {'S3Uri': 's3://sagemaker-us-east-1-298039562326/bias/explainability-report-pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t', 'Local

In [20]:
run_explainability_job_name = clarify_processor.latest_job.job_name
run_explainability_job_name

'Clarify-Explainability-2021-05-15-20-12-05-386'

In [21]:
from IPython.display import display, HTML

display(
    HTML(
        '<b>Review <a target="_blank" href="https://console.aws.amazon.com/sagemaker/home?region={}#/processing-jobs/{}">Processing Job</a></b>'.format(
            region, run_explainability_job_name
        )
    )
)

In [22]:
from IPython.display import display, HTML

display(
    HTML(
        '<b>Review <a target="_blank" href="https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/ProcessingJobs;prefix={};streamFilter=typeLogStreamPrefix">CloudWatch Logs</a> After About 5 Minutes</b>'.format(
            region, run_explainability_job_name
        )
    )
)

In [23]:
from IPython.display import display, HTML

display(
    HTML(
        '<b>Review <a target="_blank" href="https://s3.console.aws.amazon.com/s3/buckets/{}?prefix={}/">S3 Output Data</a> After The Processing Job Has Completed</b>'.format(
            bucket, explainability_report_prefix
        )
    )
)

In [24]:
from pprint import pprint

running_processor = sagemaker.processing.ProcessingJob.from_processing_name(
    processing_job_name=run_explainability_job_name, sagemaker_session=sess
)

processing_job_description = running_processor.describe()

pprint(processing_job_description)

{'AppSpecification': {'ImageUri': '205585389593.dkr.ecr.us-east-1.amazonaws.com/sagemaker-clarify-processing:1.0'},
 'CreationTime': datetime.datetime(2021, 5, 15, 20, 12, 5, 672000, tzinfo=tzlocal()),
 'LastModifiedTime': datetime.datetime(2021, 5, 15, 20, 12, 6, 39000, tzinfo=tzlocal()),
 'ProcessingInputs': [{'AppManaged': False,
                       'InputName': 'dataset',
                       'S3Input': {'LocalPath': '/opt/ml/processing/input/data',
                                   'S3CompressionType': 'None',
                                   'S3DataDistributionType': 'FullyReplicated',
                                   'S3DataType': 'S3Prefix',
                                   'S3InputMode': 'File',
                                   'S3Uri': 's3://sagemaker-us-east-1-298039562326/bias/test_data_explainability/test_data_explainability.jsonl'}},
                      {'AppManaged': False,
                       'InputName': 'analysis_config',
                       'S3I

In [25]:
running_processor.wait(logs=False)

.......................................................................................................................................................!

# Download Report From S3

In [26]:
!aws s3 ls $explainability_output_path/

                           PRE explanations_shap/
2021-05-15 20:24:39        339 analysis.json
2021-05-15 20:12:06        644 analysis_config.json
2021-05-15 20:24:39     292697 report.html
2021-05-15 20:24:39      19824 report.ipynb
2021-05-15 20:24:39      40183 report.pdf


In [27]:
!aws s3 cp --recursive $explainability_output_path ./explainability_report/

download: s3://sagemaker-us-east-1-298039562326/bias/explainability-report-pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t/analysis_config.json to explainability_report/analysis_config.json
download: s3://sagemaker-us-east-1-298039562326/bias/explainability-report-pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t/analysis.json to explainability_report/analysis.json
download: s3://sagemaker-us-east-1-298039562326/bias/explainability-report-pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t/explanations_shap/baseline.csv to explainability_report/explanations_shap/baseline.csv
download: s3://sagemaker-us-east-1-298039562326/bias/explainability-report-pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t/explanations_shap/out.csv to explainability_report/explanations_shap/out.csv
download: s3://sagemaker-us-east-1-298039562326/bias/explainability-report-pipelines-34f5wo3bayx8-createmodel-bbjbc6kn0t/report.html to explainability_report/report.html
download: s3://sagemaker-us-east-1-298039562326/bias/explainab
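`analysis.json` holds the global SHAP values that the report visualizes. Its exact nesting can vary with the Clarify container version and configuration, so rather than hard-coding a path, the sketch below searches the parsed JSON for any `global_shap_values` mapping and ranks the features by magnitude. The key name `global_shap_values` is an assumption based on the Clarify documentation's description of the analysis output; check it against your own file. The helper names are hypothetical.

```python
import json
import os
from typing import Any, Dict, Iterator, List, Tuple


def find_global_shap(node: Any, path: str = "") -> Iterator[Tuple[str, Dict[str, float]]]:
    """Yield (json_path, {feature: value}) for every 'global_shap_values' mapping in the tree."""
    if isinstance(node, dict):
        for key, child in node.items():
            child_path = f"{path}/{key}"
            if key == "global_shap_values" and isinstance(child, dict):
                yield child_path, child
            else:
                yield from find_global_shap(child, child_path)
    elif isinstance(node, list):
        for i, child in enumerate(node):
            yield from find_global_shap(child, f"{path}[{i}]")


def rank_features(shap_values: Dict[str, float]) -> List[Tuple[str, float]]:
    """Sort features by descending absolute global SHAP value."""
    return sorted(shap_values.items(), key=lambda kv: abs(kv[1]), reverse=True)


# Local copy made by `aws s3 cp --recursive` above; skipped if it is not present
REPORT_PATH = "./explainability_report/analysis.json"
if os.path.exists(REPORT_PATH):
    with open(REPORT_PATH) as f:
        analysis = json.load(f)
    for json_path, values in find_global_shap(analysis):
        print(json_path)
        for feature, value in rank_features(values):
            print(f"  {feature}: {value:.4f}")
```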

In [28]:
from IPython.display import display, HTML

display(HTML('<b>Review <a target="_blank" href="./explainability_report/report.html">Explainability Report</a></b>'))

# View the Explainability Report
As with the bias report, you can view the explainability report in SageMaker Studio under the Experiments tab.


<img src="img/explainability_detail.gif">

The Model Insights tab contains direct links to the report and model insights.

If you're not a Studio user yet, you can access this report, as with the bias report, in the S3 output location `explainability_output_path` (copied locally to `./explainability_report/` above).

# Release Resources

In [29]:
%%html

<p><b>Shutting down your kernel for this notebook to release resources.</b></p>
<button class="sm-command-button" data-commandlinker-command="kernelmenu:shutdown" style="display:none;">Shutdown Kernel</button>
        
<script>
try {
    els = document.getElementsByClassName("sm-command-button");
    els[0].click();
}
catch(err) {
    // NoOp
}    
</script>