## Setup

In [None]:
from common import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Load the cached train/test split. The cache objects are not defined in this
# notebook; presumably they come from the `common` star import - TODO confirm.
X_train, X_test, y_train, y_test = data_split_cache.load()
# Load the cached cross-validation results together with `rf`, the model whose
# importances are analysed below.
results, rf = cross_val_cache.load()

### ELI5 Permutation Importances (PFI)

In [None]:
import cloudpickle
import eli5
from eli5.sklearn import PermutationImportance
from joblib import Parallel, delayed, parallel_backend

from wildfires.dask_cx1 import get_parallel_backend

# Cache named "perm_importance" inside CACHE_DIR; it wraps `get_perm_importance`
# below. cloudpickle is passed as the pickler - presumably because the default
# pickler cannot serialise the cached object; TODO confirm.
perm_importance_cache = SimpleCache(
    "perm_importance", cache_dir=CACHE_DIR, pickler=cloudpickle
)

# NOTE: Computing the permutation importances does not seem to work with the
# dask parallel backend - the backend gets bypassed and every available core on
# the machine is used if this is attempted.


@perm_importance_cache
def get_perm_importance():
    """Fit ELI5 permutation importances for `rf` on the training data.

    Side effect: sets `rf.n_jobs` to 30 on the notebook-level model before
    fitting, and that value stays in place afterwards.

    Returns:
        Whatever `PermutationImportance(rf).fit(X_train, y_train)` returns
        (used further down via its `results_` attribute).
    """
    rf.n_jobs = 30
    importance = eli5.sklearn.PermutationImportance(rf)
    return importance.fit(X_train, y_train)


# worker = list(client.scheduler_info()['workers'])[0]
# perm_importance = client.run(get_perm_importance, workers=[worker])

# Obtain the permutation importances through the cached function, then tabulate
# the weights using the training column names as feature names.
perm_importance = get_perm_importance()
perm_df = eli5.explain_weights_df(
    perm_importance,
    feature_names=list(X_train.columns),
)

#### VIF Calculation

In [None]:
# Cache named "train_vif" inside CACHE_DIR; it wraps `get_vifs` below.
train_vif_cache = SimpleCache("train_vif", cache_dir=CACHE_DIR)


@train_vif_cache
def get_vifs():
    """Run `vif` on the training features with verbose output enabled.

    The call is wrapped by `train_vif_cache`.

    Returns:
        The result of `vif(X_train, verbose=True)`; the following cell code
        expects it to have a "Name" column.
    """
    train_vifs = vif(X_train, verbose=True)
    return train_vifs


# Index the VIF table by its "Name" column and transpose, so that each name
# becomes a column (matching the layout of the importance frames plotted below).
vifs = get_vifs().set_index("Name", drop=True).T

## Individual Tree Importances - Gini vs PFI vs SHAP vs VIF

In [None]:
# Number of leading columns (in the Gini ordering) shown in every panel.
N_col = 20

# Four stacked panels sharing one x axis: Gini, PFI, SHAP and VIF.
fig, axes = plt.subplots(nrows=4, ncols=1, sharex=True, figsize=(7, 12))
ax, ax2, ax3, ax4 = axes

# Gini values: one row per tree in `rf`, one column per training feature.
ind_trees_gini = pd.DataFrame(
    data=[tree.feature_importances_ for tree in rf],
    columns=X_train.columns,
)
# Rank the features by their mean importance across the trees. This ordering
# is reused by the other panels so that all panels share the same x axis.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(columns=mean_importances.index)
sns.boxplot(data=ind_trees_gini.iloc[:, :N_col], ax=ax)
# ax.set_title("Gini Importances")
ax.set_ylabel("Gini Importance (MSE)\n")

# PFI values: the rows come from `perm_importance.results_`, with one column per
# training feature, re-ordered to match the Gini importances.
pfi_ind = pd.DataFrame(
    perm_importance.results_, columns=X_train.columns
).reindex(columns=mean_importances.index)

sns.boxplot(data=pfi_ind.iloc[:, :N_col], ax=ax2)
# ax2.set_title("PFI Importances")
ax2.set_ylabel("PFI Importance\n")

# SHAP values, loaded from a cache whose name embeds `total_samples`.
total_samples = 20000
tree_path_dependent_shap_cache = SimpleCache(
    f"tree_path_dependent_shap_{total_samples}", cache_dir=CACHE_DIR
)
shap_values = tree_path_dependent_shap_cache.load()

# Mean absolute SHAP value along axis 0, labelled with the training columns.
mean_abs_shap = np.abs(shap_values).mean(axis=0)
mean_shap_importances = pd.DataFrame(
    mean_abs_shap, index=X_train.columns, columns=["SHAP Importance"]
)
# NOTE(review): this sort only affects the intermediate column order - the
# reindex just below imposes the Gini ordering regardless.
mean_shap_importances = mean_shap_importances.sort_values(
    "SHAP Importance", ascending=False
).T

# Use the same column ordering as for the Gini importances.
mean_shap_importances = mean_shap_importances.reindex(columns=mean_importances.index)

# The frame has a single row, so each "box" is drawn from one value per feature.
sns.boxplot(data=mean_shap_importances.iloc[:, :N_col], ax=ax3)
ax3.set_ylabel("SHAP Importance\n")

# VIFs, with the columns put in the same order as the Gini importances.
# Note that this rebinds the notebook-level `vifs` to the re-ordered frame.
vifs = vifs.reindex(columns=mean_importances.index)
sns.boxplot(data=vifs.iloc[:, :N_col], ax=ax4)
ax4.set_ylabel("VIF\n")

# The panels share their x axis, so only the bottom panel's tick labels are
# visible; rotate those so the feature names do not overlap.
_ = axes[-1].set_xticklabels(
    axes[-1].get_xticklabels(), rotation=45, ha="right"
)

# Every panel gets a faint major grid and has its y tick labels hidden.
for _ax in axes:
    _ax.grid(which="major", alpha=0.3)
    _ax.tick_params(labelleft=False)

# fig.suptitle("Gini, PFI, SHAP, VIF")
plt.tight_layout()
# Applied after tight_layout so the top margin is not overridden by it.
plt.subplots_adjust(top=0.91)
figure_saver.save_figure(fig, "feature_importances_gini_pfi_shap")