## Setup

In [None]:
from common import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Load the cached train/test split. `X_train` is used throughout this notebook
# as the source of feature (column) names.
X_train, X_test, y_train, y_test = data_split_cache.load()
# Load the cached `results` object together with `rf`, the fitted ensemble that
# is iterated tree-by-tree further down.
results, rf = cross_val_cache.load()

### ELI5 Permutation Importances (PFI)

In [None]:
import cloudpickle
import eli5
from eli5.sklearn import PermutationImportance
from joblib import Parallel, delayed, parallel_backend

from wildfires.dask_cx1 import get_parallel_backend

perm_importance_cache = SimpleCache(
    "perm_importance", cache_dir=CACHE_DIR, pickler=cloudpickle
)

# NOTE: this does not seem to work with the dask parallel backend - the backend
# gets bypassed and every available core on the machine is used up if attempted.


@perm_importance_cache
def get_perm_importance():
    """Fit ELI5 permutation importances for `rf` on the training data.

    `rf.n_jobs` is set to 30 on the shared `rf` object before fitting. Note that
    this assignment only happens when the function body actually runs.

    Returns
    -------
    The value returned by `PermutationImportance(rf).fit(X_train, y_train)`.
    """
    rf.n_jobs = 30
    estimator = PermutationImportance(rf)
    return estimator.fit(X_train, y_train)


# worker = list(client.scheduler_info()['workers'])[0]
# perm_importance = client.run(get_perm_importance, workers=[worker])

perm_importance = get_perm_importance()
# Tabular summary of the fitted importances, labelled with the training columns.
perm_df = eli5.explain_weights_df(
    perm_importance,
    feature_names=list(X_train.columns),
)

#### VIF Calculation

In [None]:
train_vif_cache = SimpleCache("train_vif", cache_dir=CACHE_DIR)


@train_vif_cache
def get_vifs():
    """Return `vif(X_train, verbose=True)` for the training features."""
    return vif(X_train, verbose=True)


# Index the result by its "Name" column, then transpose so that the former
# "Name" entries become the column labels (matching the other importance frames).
vifs = get_vifs().set_index("Name", drop=True).T

#### LOCO Calculation - from the LOCO notebook

In [None]:
loco_cache = SimpleCache("loco_results", cache_dir=CACHE_DIR)
loco_results = loco_cache.load()

# The entry stored under the empty-string key provides the reference MSE.
baseline_mse = loco_results[""]["mse"]

# One column per remaining (truthy) key, holding that entry's MSE minus the
# reference MSE.
loco_df = pd.DataFrame(
    {
        name: [scores["mse"] - baseline_mse]
        for name, scores in loco_results.items()
        if name
    }
)
loco_df.index = ["LOCO (MSE)"]
loco_df.columns.name = "Name"

## Individual Tree Importances - Gini vs PFI vs SHAP

SHAP values are loaded from the shap notebook.

In [None]:
def plot_importances(df, ax=None):
    """Draw horizontal box plots of the columns of `df`, ordered by column mean.

    Columns are arranged by descending mean before plotting, and grid lines are
    enabled for both major and minor ticks.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per plotted box.
    ax : matplotlib axes, optional
        Axes to draw on. If omitted, a new figure of size (5, 12) is created.
    """
    column_order = df.mean().sort_values(ascending=False).index
    ordered = df.reindex(column_order, axis=1)

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 12))

    drawn_ax = sns.boxplot(data=ordered, orient="h", ax=ax)
    drawn_ax.grid(which="both")

### Gini

In [None]:
# One row per member of `rf`, one column per training feature.
ind_trees_gini = pd.DataFrame(
    [member.feature_importances_ for member in rf],
    columns=X_train.columns,
)

# `mean_importances` fixes the column ordering that the other importance
# measures in this notebook are re-indexed to.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(mean_importances.index, axis=1)
shorten_columns(ind_trees_gini, inplace=True)


def gini_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `ind_trees_gini` on `ax`."""
    leading_columns = ind_trees_gini.iloc[:, :N_col]
    sns.boxplot(data=leading_columns, ax=ax)
    # ax.set_title("Gini Importances")
    ax.set_ylabel("Gini Importance (MSE)\n")

In [None]:
# Horizontal box plots of the per-tree Gini importances, sorted by mean.
plot_importances(ind_trees_gini)

### PFI

In [None]:
# One row per entry of `perm_importance.results_` (presumably one per
# permutation iteration - confirm against the ELI5 documentation).
# Re-index according to the same ordering as for the Gini importances!
pfi_ind = pd.DataFrame(
    perm_importance.results_,
    columns=X_train.columns,
).reindex(mean_importances.index, axis=1)
shorten_columns(pfi_ind, inplace=True)


def pfi_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `pfi_ind` on `ax`."""
    leading_columns = pfi_ind.iloc[:, :N_col]
    sns.boxplot(data=leading_columns, ax=ax)
    # ax.set_title("PFI Importances")
    ax.set_ylabel("PFI Importance\n")

In [None]:
# Horizontal box plots of the permutation importances, sorted by mean.
plot_importances(pfi_ind)

### SHAP

In [None]:
max_index = 995  # Maximum job array index (inclusive).
job_samples = 2000  # Samples per job.
total_samples = (max_index + 1) * job_samples  # Sanity check.

# Load the individual data chunks (one cache file per job index) and stack
# them row-wise into a single array.
shap_chunks = [
    SimpleCache(
        f"tree_path_dependent_shap_{index}_{job_samples}",
        cache_dir=os.path.join(CACHE_DIR, "shap"),
        verbose=0,
    ).load()
    for index in tqdm(range(max_index + 1), desc="Loading chunks")
]
shap_values = np.vstack(shap_chunks)

# Mean absolute SHAP value per column of `shap_values`.
mean_abs_shap = np.abs(shap_values).mean(axis=0)

# Single-row frame (features as columns).
# Re-index according to the same ordering as for the Gini importances!
mean_shap_importances = (
    pd.DataFrame(mean_abs_shap, index=X_train.columns, columns=["SHAP Importance"])
    .sort_values("SHAP Importance", ascending=False)
    .T.reindex(mean_importances.index, axis=1)
)

shorten_columns(mean_shap_importances, inplace=True)


def shap_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `mean_shap_importances` on `ax`."""
    leading_columns = mean_shap_importances.iloc[:, :N_col]
    sns.boxplot(data=leading_columns, ax=ax)
    ax.set_ylabel("SHAP Importance\n")

In [None]:
# Horizontal plot of the mean absolute SHAP values, sorted by value.
plot_importances(mean_shap_importances)

### LOCO

In [None]:
# Re-index according to the same ordering as for the Gini importances!
# This must happen before the column names are shortened, since
# `mean_importances` was computed from the unshortened column names.
loco_df = loco_df.reindex(mean_importances.index, axis=1)
shorten_columns(loco_df, inplace=True)


def loco_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `loco_df` on `ax`."""
    leading_columns = loco_df.iloc[:, :N_col]
    sns.boxplot(data=leading_columns, ax=ax)
    ax.set_ylabel("LOCO (MSE)\n")

In [None]:
# Horizontal plot of the LOCO MSE differences, sorted by value.
plot_importances(loco_df)

### VIF

In [None]:
# Re-index according to the same ordering as for the Gini importances!
# This must happen before the column names are shortened, since
# `mean_importances` was computed from the unshortened column names.
vifs = vifs.reindex(mean_importances.index, axis=1)
shorten_columns(vifs, inplace=True)


def vif_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `vifs` on `ax`."""
    leading_columns = vifs.iloc[:, :N_col]
    sns.boxplot(data=leading_columns, ax=ax)
    ax.set_ylabel("VIF\n")

In [None]:
# Horizontal plot of the VIF values, sorted by value.
plot_importances(vifs)

### ALE 1D

In [None]:
world_ale_1d_cache = SimpleCache("world_ale_1d", cache_dir=CACHE_DIR)
ptp_values, mc_ptp_values = world_ale_1d_cache.load()

# Single-row frames with one column per key of the cached dictionaries.
ale_1d_df = pd.DataFrame(ptp_values, index=["ALE 1D (PTP)"])
ale_1d_mc_df = pd.DataFrame(mc_ptp_values, index=["ALE 1D MC (PTP)"])

# Re-index according to the same ordering as for the Gini importances!
# `DataFrame.reindex` returns a new frame and never modifies its input, so the
# result has to be assigned back. Without the assignment the original column
# order is kept and the first `N_col` columns used by the plot functions below
# would not correspond to the features shown for the other measures.
# NOTE(review): this assumes the cached dictionaries are keyed by the original
# (unshortened) `X_train` column names - labels without a match become NaN.
ale_1d_df = ale_1d_df.reindex(mean_importances.index, axis=1)
ale_1d_mc_df = ale_1d_mc_df.reindex(mean_importances.index, axis=1)

# Label the columns after re-indexing, because the re-indexed frame takes its
# column index (including the index name) from `mean_importances.index`.
ale_1d_df.columns.name = "Name"
ale_1d_mc_df.columns.name = "Name"

shorten_columns(ale_1d_df, inplace=True)
shorten_columns(ale_1d_mc_df, inplace=True)


def ale_1d_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `ale_1d_df` on `ax`."""
    sns.boxplot(data=ale_1d_df.iloc[:, :N_col], ax=ax)
    ax.set(ylabel="ALE 1D\n")


def ale_1d_mc_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `ale_1d_mc_df` on `ax`."""
    sns.boxplot(data=ale_1d_mc_df.iloc[:, :N_col], ax=ax)
    ax.set(ylabel="ALE 1D MC\n")

In [None]:
# Side-by-side comparison of the two 1D ALE measures. Each panel is sorted by
# its own values inside `plot_importances`.
fig, axes = plt.subplots(1, 2, figsize=(10, 12))

panels = ((ale_1d_df, "ALE 1D"), (ale_1d_mc_df, "ALE 1D MC"))
for ax, (frame, title) in zip(axes, panels):
    plot_importances(frame, ax=ax)
    ax.set_title(title)
    ax.set_ylabel("")

plt.tight_layout()

### ALE 2D - very cursory analysis

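For each feature, this sums the 2D ALE peak-to-peak values of every feature pair that the feature appears in. It does not take into account which of the two variables in a pair is the one responsible for the interaction.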

In [None]:
world_ale_2d_cache = SimpleCache("world_ale_2d", cache_dir=CACHE_DIR)
ptp_2d_values = world_ale_2d_cache.load()

# For every training feature, accumulate the cached value of each feature pair
# that the feature appears in.
interaction_data = defaultdict(float)
for feature in X_train.columns:
    for feature_pair, ptp_2d_value in ptp_2d_values.items():
        if feature in feature_pair:
            interaction_data[feature] += ptp_2d_value

ale_2d_df = pd.DataFrame(interaction_data, index=["ALE 2D (PTP)"])

# Re-index according to the same ordering as for the Gini importances!
# `DataFrame.reindex` returns a new frame and never modifies its input, so the
# result has to be assigned back. Without the assignment the original column
# order is kept and the first `N_col` columns used by `ale_2d_plot` would not
# correspond to the features shown for the other measures.
# Features that occur in no cached pair are absent from `interaction_data` and
# therefore become NaN columns here.
ale_2d_df = ale_2d_df.reindex(mean_importances.index, axis=1)

# Label the columns after re-indexing, because the re-indexed frame takes its
# column index (including the index name) from `mean_importances.index`.
ale_2d_df.columns.name = "Name"

shorten_columns(ale_2d_df, inplace=True)


def ale_2d_plot(ax, N_col):
    """Box-plot the first `N_col` columns of `ale_2d_df` on `ax`."""
    sns.boxplot(data=ale_2d_df.iloc[:, :N_col], ax=ax)
    ax.set(ylabel="ALE 2D\n")

In [None]:
# Horizontal plot of the summed 2D ALE values, sorted by value.
plot_importances(ale_2d_df)

### Combining the plots

In [None]:
N_col = 20  # Number of leading columns shown in every panel.

# One panel per importance measure, drawn top to bottom in this order.
plot_funcs = (
    gini_plot,
    pfi_plot,
    shap_plot,
    loco_plot,
    ale_1d_plot,
    ale_1d_mc_plot,
    ale_2d_plot,
    vif_plot,
)

fig, axes = plt.subplots(len(plot_funcs), 1, sharex=True, figsize=(7, 20))

for ax, plot_func in zip(axes, plot_funcs):
    plot_func(ax, N_col)

# Rotate the last x axis labels (the only visible ones).
bottom_ax = axes[-1]
bottom_ax.set_xticklabels(bottom_ax.get_xticklabels(), rotation=45, ha="right")

last_position = len(axes) - 1
for position, _ax in enumerate(axes):
    _ax.grid(which="major", alpha=0.4, linestyle="--")
    _ax.tick_params(labelleft=False)
    # Only the bottom panel keeps its x axis label.
    if position != last_position:
        _ax.set_xlabel("")

# fig.suptitle("Gini, PFI, SHAP, VIF")
plt.tight_layout()
plt.subplots_adjust(top=0.91)
figure_saver.save_figure(fig, "feature_importances")

In [None]:
# Map each measure's label to a Series of per-feature means, sorted from the
# largest to the smallest value.
importances = {
    label: frame.mean().sort_values(ascending=False)
    for label, frame in {
        "Gini": ind_trees_gini,
        "PFI": pfi_ind,
        "SHAP": mean_shap_importances,
        "LOCO": loco_df,
        "VIF": vifs,
        "ALE 1D": ale_1d_df,
        "ALE 1D MC": ale_1d_mc_df,
        "ALE 2D": ale_2d_df,
    }.items()
}

In [None]:
# Stack the sorted feature names of every measure and transpose, so that each
# column belongs to one measure and each row to one rank.
# NOTE(review): this assumes every Series in `importances` has the same length -
# confirm, since a ragged input would not produce a 2D array.
table_str = np.array([df.index.values for df in importances.values()]).T

In [None]:
fig = plt.figure(figsize=(15, 20))
spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=[3, 1])
axes = [fig.add_subplot(cell) for cell in spec]
table_ax, curve_ax = axes

# Left panel: the feature ranking table, one column per importance measure.
table_ax.set_axis_off()
table_ax.table(
    table_str,
    loc="left",
    rowLabels=range(1, len(table_str) + 1),
    bbox=[0, 0, 1, 1],
    colLabels=list(importances.keys()),
)

# Right panel: the sorted values of the first entry in `importances`, plotted
# against evenly spaced positions running from 1 down to 0.
first_measure = next(iter(importances.values()))
curve_ax.plot(first_measure.values, np.linspace(1, 0, len(table_str)))
curve_ax.yaxis.set_label_position("right")
curve_ax.yaxis.tick_right()
plt.tight_layout()