## Setup

In [None]:
from common import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Train/test split as cached by the 'model' notebook.
(X_train, X_test,
 y_train, y_test) = data_split_cache.load()
# Cross-validation results together with the model `rf` (used below for predictions).
(results, rf) = cross_val_cache.load()

### SHAP values - via PBS array jobs

#### Normal SHAP values - single thread timing

Roughly 2 s per sample when run on a single thread (measured timings below).

In [None]:
# Measured single-threaded wall time of the plain SHAP value computation for
# increasing numbers of samples (100 samples took 3 min 18 s).
fig, ax = plt.subplots()
ax.plot([10, 30, 50, 100], [19.6, 57.2, 95, (3 * 60) + 18], marker="o")
ax.set_xlabel("Samples")
ax.set_ylabel("Time (s)")
# Trailing ';' suppresses the Text repr that would otherwise be echoed.
ax.set_title("Single-thread SHAP timing");

### SHAP values - see PBS (array) job folder 'shap'!

In [None]:
# Scan the PBS output logs of the 'shap' array job and collect the array index
# (the digits inside "[...]" in each log file name) of every killed job.
killed_jobs = []
out_dir = PROJECT_DIR / Path("shap/output")
for out in tqdm(os.listdir(out_dir)):
    with open(os.path.join(out_dir, out), "r") as f:
        content = f.read()
    if "job killed" in content:
        # Raw string: "\[" and "\d" are invalid escapes in a normal string literal.
        match = re.search(r"\[(\d+)\]", out)
        if match is None:
            raise ValueError(f"No job array index found in log file name {out!r}.")
        killed_jobs.append(match.group(1))
print(len(killed_jobs), "killed jobs")

In [None]:
print(sorted(int(job_id) for job_id in killed_jobs))  # killed array indices, ascending

#### Actually load the SHAP values

In [None]:
max_index = 995  # Maximum job array index (inclusive).
job_samples = 2000  # Samples per job.
total_samples = (max_index + 1) * job_samples  # Upper bound on the number of rows.

# Load the individual data chunks (one cache entry per array job) and stack them.
shap_chunk_dir = os.path.join(CACHE_DIR, "shap")
shap_chunks = [
    SimpleCache(
        f"tree_path_dependent_shap_{index}_{job_samples}",
        cache_dir=shap_chunk_dir,
        verbose=0,
    ).load()
    for index in tqdm(range(max_index + 1), desc="Loading chunks")
]
shap_values = np.vstack(shap_chunks)

# Sanity check: the SHAP rows must line up with the feature rows that are
# plotted against them below (X_train[:total_samples]).
assert shap_values.shape[0] == len(X_train[:total_samples]), (
    f"Loaded {shap_values.shape[0]} SHAP rows, "
    f"expected {len(X_train[:total_samples])}."
)

In [None]:
%%time
with figure_saver("SHAP"):
    shap.summary_plot(
        shap_values,
        shorten_columns(X_train[:total_samples]),
        title="SHAP Feature Importances",
        show=False,
    )

In [None]:
# Global feature importance: mean absolute SHAP value of each feature.
mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
# Build the frame column-wise so "shap" keeps a float dtype (building it
# row-wise and transposing yields object columns), and sort without inplace.
mean_shap_importances = pd.DataFrame(
    {"column": shorten_features(X_train.columns), "shap": mean_abs_shap}
).sort_values("shap", ascending=False)
mean_shap_importances

### Interaction SHAP values - see PBS (array) job folder 'shap_interaction'!

Roughly 150 s per sample — about 75 times slower than the plain SHAP values (~2 s per sample).

In [None]:
# Scan the PBS output logs of the 'shap_interaction' array job and collect the
# array index (the digits inside "[...]" in each log file name) of every killed job.
killed_jobs = []
out_dir = PROJECT_DIR / Path("shap_interaction/output")
for out in tqdm(os.listdir(out_dir)):
    with open(os.path.join(out_dir, out), "r") as f:
        content = f.read()
    if "job killed" in content:
        # Raw string: "\[" and "\d" are invalid escapes in a normal string literal.
        match = re.search(r"\[(\d+)\]", out)
        if match is None:
            raise ValueError(f"No job array index found in log file name {out!r}.")
        killed_jobs.append(match.group(1))
print(len(killed_jobs), "killed jobs")

In [None]:
print(sorted(map(int, killed_jobs)))  # killed array indices, ascending

In [None]:
max_index = 5999  # Maximum job array index (inclusive).
job_samples = 50  # Samples per job.
total_interact_samples = (max_index + 1) * job_samples  # Upper bound on the number of rows.

# NOTE(review): the cache appears to be keyed only by its name, so changing
# `max_index` or `job_samples` presumably will not invalidate a previously
# cached result — confirm against SimpleCache and clear it manually if needed.
shap_interact_cache = SimpleCache(
    "shap_interact_cache", cache_dir=CACHE_DIR / Path("shap_interaction")
)


@shap_interact_cache
def load_shap_interact_from_chunks():
    """Load and stack the per-job SHAP interaction value chunks.

    Reads one cached chunk for each array job index in ``0..max_index``, using
    the module-level ``max_index`` and ``job_samples``, and returns the chunks
    concatenated along their first axis.
    """
    chunk_dir = os.path.join(CACHE_DIR, "shap_interaction")
    shap_interact_chunks = [
        SimpleCache(
            f"tree_path_dependent_shap_interact_{index}_{job_samples}",
            cache_dir=chunk_dir,
            verbose=0,
        ).load()
        for index in tqdm(range(max_index + 1), desc="Loading chunks")
    ]
    return np.vstack(shap_interact_chunks)


shap_interact_values = load_shap_interact_from_chunks()

# Sanity check: the interaction rows must line up with the feature rows that
# are plotted against them below (X_train[:total_interact_samples]).
assert shap_interact_values.shape[0] == len(X_train[:total_interact_samples]), (
    f"Loaded {shap_interact_values.shape[0]} SHAP interaction rows, "
    f"expected {len(X_train[:total_interact_samples])}."
)

In [None]:
with figure_saver("SHAP_interaction"):
    # Summary plot of the SHAP interaction values against the matching feature rows.
    shap.summary_plot(
        shap_interact_values,
        features=shorten_columns(X_train[:total_interact_samples]),
        show=False,
        title="SHAP Feature Interactions",
    )

In [None]:
with figure_saver("SHAP_interaction_compact"):
    # Same interaction summary, drawn with the "compact_dot" plot type.
    shap.summary_plot(
        shap_interact_values,
        features=shorten_columns(X_train[:total_interact_samples]),
        plot_type="compact_dot",
        show=False,
        title="SHAP Feature Interactions",
    )

In [None]:
mean_interact = np.abs(shap_interact_values).mean(axis=0)  # mean |interaction| over samples

In [None]:
# Heatmap of the mean |SHAP interaction| for every feature pair, on a log scale.
feature_labels = shorten_features(X_train.columns)
fig, ax = plt.subplots(figsize=(10, 10))
ax = sns.heatmap(
    np.log(mean_interact),
    ax=ax,
    square=True,
    xticklabels=feature_labels,
    yticklabels=feature_labels,
)
ax.tick_params(axis="x", rotation=90)
ax.tick_params(axis="y", rotation=0)
_ = fig.suptitle("log(SHAP Interaction Values)")

### Get the most significant interactions

In [None]:
def get_highest_interact_index(index, mean_interact):
    """Return the column with the largest value in row ``index`` of ``mean_interact``.

    The diagonal entry (``index`` itself) is excluded, so the result is the
    feature interacting most strongly with feature ``index``.
    """
    row = mean_interact[index]
    ranked = (-row).argsort()  # column indices, largest value first
    partners = ranked[ranked != index]
    return partners[0]

In [None]:
# Same heatmap with the diagonal and upper triangle set to NaN (seaborn leaves
# NaN cells blank), so each feature pair is shown only once.
masked_mean_interact = mean_interact.copy()
masked_mean_interact[np.triu_indices(len(mean_interact))] = np.nan

feature_labels = shorten_features(X_train.columns)
fig, ax = plt.subplots(figsize=(10, 10))
ax = sns.heatmap(
    np.log(masked_mean_interact),
    ax=ax,
    square=True,
    xticklabels=feature_labels,
    yticklabels=feature_labels,
)
ax.tick_params(axis="x", rotation=90)
ax.tick_params(axis="y", rotation=0)
_ = fig.suptitle("log(SHAP Interaction Values)")

In [None]:
interact_indices = np.tril_indices(mean_interact.shape[0])
interact_values = mean_interact[interact_indices]

# Map each unordered pair of (shortened) feature names to its mean |SHAP
# interaction|; diagonal entries (self-interactions) are skipped.
pair_strengths = {
    tuple(
        sorted(
            (
                shorten_features(X_train.columns[row]),
                shorten_features(X_train.columns[col]),
            )
        )
    ): value
    for row, col, value in zip(*interact_indices, interact_values)
    if row != col
}

interact_data = pd.Series(pair_strengths).sort_values(ascending=False)
interact_data.head(20)

### Compare the approximate and 'exact' interactions

In [None]:
# For every feature, compare the partner chosen by shap's fast approximate
# interaction heuristic with the partner that has the largest mean
# |SHAP interaction| value (excluding the feature itself).
try:
    # Newer shap releases moved this helper from `shap.common` to `shap.utils`.
    from shap.utils import approximate_interactions
except ImportError:
    from shap.common import approximate_interactions

X_approx = X_train[:total_samples]  # Hoisted: identical for every feature.

length = 18
header = f"{'Feature':>{length}} : {'Approx':>{length}} {'Non Approx':>{length}}"
print(header)
print("-" * len(header))

for i in range(len(X_train.columns)):
    approx = approximate_interactions(X_train.columns[i], shap_values, X_approx)[0]
    non_approx = get_highest_interact_index(i, mean_interact)
    features = shorten_features(
        X_train.columns[index] for index in (i, approx, non_approx)
    )
    print(f"{features[0]:>{length}} : {features[1]:>{length}} {features[2]:>{length}}")

### Dependence plots with the approximate dependence metric

In [None]:
def approx_dependence_plot(i, shap_values=shap_values, X=X_train[:total_samples]):
    """Save a SHAP dependence plot for feature ``i``.

    The colouring (interaction) feature is left to ``shap.dependence_plot``'s
    default choice. The figure is saved under the "shap_dependence_approx"
    sub-directory and then closed.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        shap.dependence_plot(i, shap_values, X, alpha=0.1, ax=ax)
        figure_saver.save_figure(
            fig,
            f"shap_dependence_{X_train.columns[i]}_approx",
            sub_directory="shap_dependence_approx",
        )
    finally:
        # Each worker process draws many figures; close them to free memory.
        plt.close(fig)


with concurrent.futures.ProcessPoolExecutor(max_workers=32) as executor:
    plot_fs = [
        executor.submit(approx_dependence_plot, i) for i in range(len(X_train.columns))
    ]

    for plot_f in tqdm(
        concurrent.futures.as_completed(plot_fs),
        total=len(plot_fs),
        desc="Plotting approx SHAP dependence plots",
    ):
        # Re-raise any worker exception; without this, failed plots go unnoticed.
        plot_f.result()

In [None]:
def dependence_plot(
    i, shap_values=shap_values, X=X_train[:total_samples], mean_interact=mean_interact
):
    """Save a SHAP dependence plot for feature ``i``, coloured by its strongest partner.

    The interaction (colouring) feature is the one with the largest mean
    |SHAP interaction| value with feature ``i`` (via
    ``get_highest_interact_index``). The figure is saved under the
    "shap_dependence" sub-directory and then closed.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        shap.dependence_plot(
            i,
            shap_values,
            X,
            interaction_index=get_highest_interact_index(i, mean_interact),
            alpha=0.1,
            ax=ax,
        )
        figure_saver.save_figure(
            fig, f"shap_dependence_{X_train.columns[i]}", sub_directory="shap_dependence"
        )
    finally:
        # Each worker process draws many figures; close them to free memory.
        plt.close(fig)


with concurrent.futures.ProcessPoolExecutor(max_workers=32) as executor:
    plot_fs = [executor.submit(dependence_plot, i) for i in range(len(X_train.columns))]

    for plot_f in tqdm(
        concurrent.futures.as_completed(plot_fs),
        total=len(plot_fs),
        desc="Plotting SHAP dependence plots",
    ):
        # Re-raise any worker exception; without this, failed plots go unnoticed.
        plot_f.result()

### SHAP force plot

#### force_plot memory usage

In [None]:
# GB needed for n * (n - 1) / 2 float64 (8-byte) values with n = 50,000 samples,
# i.e. a condensed pairwise matrix. NOTE(review): presumably the cost of
# force_plot's sample ordering over that many samples — confirm.
50_000 * (50_000 - 1) // 2 * 8 / 1e9

In [None]:
shap.initjs()  # Load shap's JavaScript into the notebook so the interactive force_plot below renders.

In [None]:
# Use 32 workers for the prediction below (note: this persistently changes `rf`).
rf.n_jobs = 32
N = int(1e3)  # Number of samples shown in the force plot.
X_force = X_train[:N]
shap_force = shap_values[:N]
# SHAP additivity: prediction = base_value + sum(shap_values). Recover the
# explainer's base value from it, instead of using the mean prediction of these
# N samples, which would shift every plotted output away from rf's prediction.
base_value = np.mean(rf.predict(X_force) - shap_force.sum(axis=1))
shap.force_plot(base_value, shap_force, shorten_columns(X_force))