## Setup

In [None]:
from common import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Train/test split cached by the 'model' notebook.
(X_train, X_test, y_train, y_test) = data_split_cache.load()
# Cross-validation results and the model `rf` whose SHAP values are computed below.
(results, rf) = cross_val_cache.load()

### SHAP values

#### Normal SHAP values

Throughput: approximately 15 samples/s.

In [None]:
total_samples = 20000
chunk_size = 100
cores = 30

tree_path_dependent_shap_cache = SimpleCache(
    f"tree_path_dependent_shap_{total_samples}", cache_dir=CACHE_DIR
)

# tree_path_dependent_shap_cache.clear()


@tree_path_dependent_shap_cache
def tree_path_dependent_shap():
    """Compute SHAP values for the first `total_samples` rows of `X_train`.

    The rows are split into contiguous chunks. Each chunk is passed to
    `get_shap_values(rf, chunk)` in a pool of `cores` worker processes, and
    the per-chunk results are stacked back together in row order.

    The module-level `total_samples`, `chunk_size`, `cores`, `rf` and
    `X_train` are read at call time (not bound at definition time).

    Returns:
        The per-chunk results of `get_shap_values` combined with `np.vstack`.
    """
    if chunk_size is None:
        chunk_nr = cores
    else:
        # Use the next highest multiple of `cores` for the number of chunks so
        # the work divides evenly across the worker processes.
        chunk_nr = math.ceil((total_samples / chunk_size) / cores) * cores

    # `np.unique` removes duplicate edges (i.e. empty chunks), which occur
    # when there are more chunks than samples.
    chunk_edges = np.unique(np.linspace(0, total_samples, chunk_nr + 1, dtype=np.int64))
    with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as executor:
        shap_fs = [
            # `.iloc` makes the positional row slice explicit, independent of
            # the dtype of `X_train`'s index.
            executor.submit(get_shap_values, rf, X_train.iloc[i:j])
            for i, j in zip(chunk_edges[:-1], chunk_edges[1:])
        ]

        # Context manager so the progress bar is closed even if a worker raises.
        with tqdm(
            total=total_samples,
            unit="sample",
            desc="Calculating SHAP values",
            smoothing=0,
            position=0,
        ) as shap_prog:
            for shap_f in concurrent.futures.as_completed(shap_fs):
                shap_prog.update(shap_f.result().shape[0])

    # Stack in submission order (not completion order) to preserve row order.
    return np.vstack([shap_f.result() for shap_f in shap_fs])


shap_values = tree_path_dependent_shap()

In [None]:
with figure_saver("SHAP"):
    # Slice the features by the number of computed SHAP rows rather than by the
    # global `total_samples`, which later cells may reassign. `shap.summary_plot`
    # requires the SHAP values and the features to have matching row counts.
    shap.summary_plot(
        shap_values,
        X_train[: len(shap_values)],
        title="SHAP Feature Importances",
        show=False,
    )

In [None]:
# Mean absolute SHAP value per feature (averaged over the sample axis 0).
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
# Build from a column dict so that "shap" keeps a numeric dtype. Building from a
# list of rows and transposing would produce object-dtype columns.
mean_shap_importances = pd.DataFrame(
    {"column": X_train.columns, "shap": mean_abs_shap}
).sort_values("shap", ascending=False)
mean_shap_importances

#### Interaction SHAP values

Very slow: approximately 150 s per sample!

In [None]:
# Obtain a client via `get_client` (from `common`; presumably a dask.distributed
# client -- confirm in `common`) and display it as the cell output.
client = get_client()
client

In [None]:
total_samples = 600
cores = 30

# Two samples per chunk, since interaction values are very slow to compute.
# The final edge is always `total_samples`, so an odd sample count does not
# silently drop the last sample (which `np.arange(0, total_samples + 1, 2)` would).
chunk_edges = np.append(np.arange(0, total_samples, 2), total_samples)

tree_path_dependent_shap_interact_cache = SimpleCache(
    f"tree_path_dependent_shap_interact_{total_samples}", cache_dir=CACHE_DIR
)

# tree_path_dependent_shap_interact_cache.clear()

get_interact_shap_values = partial(get_shap_values, interaction=True)


@tree_path_dependent_shap_interact_cache
def tree_path_dependent_shap_interact(chunk_edges=chunk_edges):
    """Compute SHAP interaction values for the first `total_samples` rows of `X_train`.

    Each chunk of rows is passed to `get_interact_shap_values(rf, chunk)` in a
    pool of `cores` worker processes, and the per-chunk results are stacked
    back together in row order.

    Args:
        chunk_edges: Increasing row positions delimiting the chunks; chunk k
            covers rows `chunk_edges[k]:chunk_edges[k + 1]`. Defaults to the
            edges computed in this cell, bound at definition time. If None,
            edges are derived from `total_samples`, `chunk_size` and `cores`.

    Returns:
        The per-chunk results of `get_interact_shap_values` combined with
        `np.vstack`.
    """
    if chunk_edges is None:
        # NOTE: `chunk_size` is not defined in this cell; it is read from the
        # notebook's global namespace (set by an earlier cell).
        if chunk_size is None:
            chunk_nr = cores
        else:
            # Use the next highest multiple of `cores` for the number of chunks.
            chunk_nr = math.ceil((total_samples / chunk_size) / cores) * cores

        # `np.unique` removes duplicate edges (empty chunks).
        chunk_edges = np.unique(
            np.linspace(0, total_samples, chunk_nr + 1, dtype=np.int64)
        )
    with concurrent.futures.ProcessPoolExecutor(max_workers=cores) as executor:
        shap_fs = [
            # `.iloc` makes the positional row slice explicit.
            executor.submit(get_interact_shap_values, rf, X_train.iloc[i:j])
            for i, j in zip(chunk_edges[:-1], chunk_edges[1:])
        ]

        # Context manager so the progress bar is closed even if a worker raises.
        with tqdm(
            total=total_samples,
            unit="sample",
            desc="Calculating SHAP interaction values",
            smoothing=0,
            position=0,
        ) as shap_prog:
            for shap_f in concurrent.futures.as_completed(shap_fs):
                shap_prog.update(shap_f.result().shape[0])

    # Stack in submission order (not completion order) to preserve row order.
    return np.vstack([shap_f.result() for shap_f in shap_fs])


shap_interact_values = tree_path_dependent_shap_interact()

# worker = list(client.scheduler_info()['workers'])[1]
# shap_interact_values = client.run(tree_path_dependent_shap_interact, workers=[worker])

In [None]:
# Summary plot of the SHAP interaction values. The feature rows are sliced to
# match the number of computed interaction rows, as `shap.summary_plot`
# requires the values and the features to have the same number of rows.
shap.summary_plot(
    shap_interact_values,
    X_train[: len(shap_interact_values)],
    title="SHAP Feature Importances",
)