In [1]:
# Reload edited project modules (e.g. utils.common) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import configparser
import json 
import os
import sys
import warnings
import boto3
import sagemaker
import numpy as np
import pandas as pd
from bokeh.io import export_svgs, output_notebook
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, NumeralTickFormatter
from sagemaker import clarify
from sagemaker import s3
sys.path.append("..")

from utils.common import (
    dump_pickle,
    load_pickle,
)

# Notebook-wide setup, in the original order:
output_notebook()  # render Bokeh figures inline
np.random.seed(seed=42)  # legacy global NumPy RNG, used by np.random.choice when sampling rows
warnings.simplefilter("ignore")  # suppress every warning category from here on

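## Explaining Predictions

Run a SageMaker Clarify explainability job (Kernel SHAP) on a random sample of the training data, then plot the features with the largest global SHAP values.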

In [3]:
# Project settings. ConfigParser.read() silently skips files it cannot open, so check its
# return value instead of failing later with an opaque KeyError on the "proj" section.
CONFIG_PATH = os.path.join("..", "conf", "config.ini")
config = configparser.ConfigParser()
if not config.read(CONFIG_PATH):
    raise FileNotFoundError(f"Could not read config file: {CONFIG_PATH}")

default_bucket = config["proj"]["s3_default_bucket"]
base_job_prefix = config["proj"]["s3_base_job_prefix"]
role = config["proj"]["iam_role"]
# NOTE(review): eval() executes arbitrary code taken from config.ini. It is kept for
# compatibility with existing config values and is only acceptable because the file is
# project-controlled; prefer config.getfloat/getint (or ast.literal_eval) for plain literals.
sampling_rate = eval(config["model"]["sampling_rate"])
num_samples = eval(config["model"]["num_samples"])

boto_session = boto3.Session()
sagemaker_session = sagemaker.Session()
caller_identity = boto_session.client("sts").get_caller_identity()
# Index (not .get) so a malformed response fails here rather than yielding "...::None:...".
account_id = caller_identity["Account"]

# `iam_role` may hold a bare role name (expanded under service-role/) or a full role ARN,
# which is used unchanged. The partition (aws, aws-cn, aws-us-gov, ...) is taken from the
# caller's own ARN so the expanded ARN is valid outside the standard AWS partition too.
if not role.startswith("arn:"):
    partition = caller_identity["Arn"].split(":")[1]
    role = f"arn:{partition}:iam::{account_id}:role/service-role/{role}"

# An empty s3_default_bucket falls back to the SageMaker session's default bucket.
if not default_bucket:
    default_bucket = sagemaker_session.default_bucket()

PROC_DATA_PATH = os.path.join("..", "proc_data")
MODEL_PATH = os.path.join("..", "models")
IMG_PATH = os.path.join("..", "img")

# Artifacts presumably saved by an earlier notebook: the CSV column names (used as Clarify
# headers), the SageMaker model name, and the SHAP baseline.
columns, model_name, baseline = load_pickle(os.path.join(MODEL_PATH, "clarify.pkl"))

In [4]:
# Build the rows Clarify will explain: the whole re-sampled training set when
# sampling_rate == 1.0, otherwise a uniform random subset drawn without replacement
# (reproducible via the np.random.seed call in the setup cell).
if not 0.0 < sampling_rate <= 1.0:
    raise ValueError(f"sampling_rate must be in (0, 1], got {sampling_rate!r}")

# ndmin=2 keeps a single-row file as a (1, n_cols) matrix instead of a flat vector,
# so indexing below always selects whole rows.
arr_re_train = np.loadtxt(
    os.path.join(PROC_DATA_PATH, "re_train", "arr_re_train.csv"),
    delimiter=",",
    ndmin=2,
)
if sampling_rate == 1.0:
    arr_sampled_train = arr_re_train.copy()
else:
    n_rows = arr_re_train.shape[0]
    # An int population draws from range(n_rows): identical draws to np.arange(n_rows).
    row_idx = np.random.choice(n_rows, int(n_rows * sampling_rate), replace=False)
    arr_sampled_train = arr_re_train[row_idx]

# Fail fast instead of launching a long Clarify job on an empty dataset.
if arr_sampled_train.shape[0] == 0:
    raise ValueError(
        "No rows to explain: arr_re_train.csv is empty or sampling_rate is too small."
    )

sampled_train_dir = os.path.join(PROC_DATA_PATH, "sampled_train")
sampled_train_csv = os.path.join(sampled_train_dir, "sampled_train.csv")
os.makedirs(sampled_train_dir, exist_ok=True)
# "%.17g" round-trips float64 exactly and still writes integral values without a decimal
# part (1.0 -> "1"); an integer format such as "%i" would truncate fractional features.
# No header row is written; Clarify gets the column names from DataConfig(headers=...).
np.savetxt(sampled_train_csv, arr_sampled_train, delimiter=",", fmt="%.17g")

_ = s3.S3Uploader.upload(
    sampled_train_csv,
    f"s3://{default_bucket}/{base_job_prefix}/sampled_train/sampled_train.csv",
    sagemaker_session=sagemaker_session,
)

In [5]:
# Kernel SHAP settings: baseline comes from clarify.pkl; per-row SHAP values are saved in
# addition to the aggregated (mean absolute) global values.
shap_config = clarify.SHAPConfig(
    seed=42,
    baseline=baseline,
    num_samples=num_samples,
    use_logit=False,
    agg_method="mean_abs",
    save_local_shap_values=True,
)

# The SageMaker model Clarify queries for predictions, exchanging CSV in both directions.
model_config = clarify.ModelConfig(
    model_name=model_name,
    content_type="text/csv",
    accept_type="text/csv",
    instance_type="ml.c5.2xlarge",
    instance_count=2,
)

# Input: the header-less sampled CSV under .../sampled_train, with column names supplied
# via `headers` and "isFraud" as the label column. Output: the report under .../clarify-expl.
data_config = clarify.DataConfig(
    dataset_type="text/csv",
    headers=columns,
    label="isFraud",
    s3_output_path=f"s3://{default_bucket}/{base_job_prefix}/clarify-expl",
    s3_data_input_path=f"s3://{default_bucket}/{base_job_prefix}/sampled_train",
)

In [6]:
%%time
%%capture
clarify_processor = clarify.SageMakerClarifyProcessor(
    role=role,
    instance_count=2,
    instance_type="ml.c5.4xlarge",
    sagemaker_session=sagemaker_session,
)

clarify_processor.run_explainability(
    data_config=data_config,
    model_config=model_config,
    explainability_config=shap_config,
)

CPU times: user 9.97 s, sys: 955 ms, total: 10.9 s
Wall time: 1h 1min 54s


In [7]:
# Fetch the Clarify report and the per-row SHAP values into the local models folder,
# in the same order as before: analysis.json first, then explanations_shap/out.csv.
for clarify_artifact in ("analysis.json", "explanations_shap/out.csv"):
    s3.S3Downloader.download(
        s3_uri=f"s3://{default_bucket}/{base_job_prefix}/clarify-expl/{clarify_artifact}",
        local_path=MODEL_PATH,
        sagemaker_session=sagemaker_session,
    )

In [8]:
# Load the global (dataset-level) SHAP values from the downloaded Clarify report.
with open(os.path.join(MODEL_PATH, "analysis.json"), encoding="utf-8") as file:
    data = json.load(file)
global_shap_values = data["explanations"]["kernel_shap"]["label0"]["global_shap_values"]

# Keep the `num_features` largest values, then reverse: Bokeh lays out categorical y
# factors bottom-to-top, so this puts the most important feature at the top of the chart.
num_features = 25
global_shap_values = (
    pd.Series(global_shap_values)
    .sort_values(ascending=False)
    .iloc[:num_features]
    .iloc[::-1]
)
features = global_shap_values.index.str.upper().tolist()
shap_values = global_shap_values.values.tolist()
# The model may have fewer than `num_features` features; size the figure and the title
# by what is actually plotted rather than by the requested maximum.
num_plotted = len(features)

source = ColumnDataSource(dict(features=features, shap_values=shap_values))
p = figure(
    y_range=features,
    width=500,
    height=20 * num_plotted + 50,
    title=f"Global Shapley Value by Feature (Top {num_plotted})",
)

p.hbar(y="features", right="shap_values", height=0.8, source=source)

p.x_range.start = 0.0
p.xaxis.formatter = NumeralTickFormatter(format="0.00 %")
p.axis.minor_tick_line_color = None
p.ygrid.grid_line_color = None
p.title.align = "center"
p.title.text_font_size = "11pt"

show(p)

# Switch to the SVG backend for file export; create the image folder if it is missing.
p.output_backend = "svg"
os.makedirs(IMG_PATH, exist_ok=True)
_ = export_svgs(p, filename=os.path.join(IMG_PATH, "global_shap_values.svg"))