## Setup

In [None]:
from specific import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Cached train/test split and fitted random forest from the 'model' notebook.
(X_train, X_test, y_train, y_test) = data_split_cache.load()
rf = get_model()

### ELI5 Permutation Importances (PFI)

In [None]:
perm_importance_cache = SimpleCache(
    "perm_importance", cache_dir=CACHE_DIR, pickler=cloudpickle
)

# Does not seem to work with the dask parallel backend - it gets bypassed
# and every available core on the machine is used up if attempted.


@perm_importance_cache
def get_perm_importance():
    """Fit ELI5 permutation importances for `rf` on the training data.

    `rf.n_jobs` is set to 30 only for the duration of the fit and restored
    afterwards. Without the restore, the change would persist on the shared
    `rf` object, but only when this function body actually runs (presumably
    not when the cache decorator serves a stored result). Later cells would
    then see a model whose configuration depends on the cache state.
    """
    previous_n_jobs = rf.n_jobs
    rf.n_jobs = 30
    try:
        return eli5.sklearn.PermutationImportance(rf).fit(X_train, y_train)
    finally:
        rf.n_jobs = previous_n_jobs


# worker = list(client.scheduler_info()['workers'])[0]
# perm_importance = client.run(get_perm_importance, workers=[worker])

perm_importance = get_perm_importance()
perm_df = eli5.explain_weights_df(perm_importance, feature_names=list(X_train.columns))

#### VIF Calculation

In [None]:
train_vif_cache = SimpleCache("train_vif", cache_dir=CACHE_DIR)


@train_vif_cache
def get_vifs():
    """Variance inflation factors of the training features (cached)."""
    return vif(X_train, verbose=True)


# Use the "Name" column as labels and transpose, giving one column per feature.
vifs = get_vifs().set_index("Name").T

#### LOCO results - loaded from the LOCO notebook

In [None]:
loco_cache = SimpleCache("loco_results", cache_dir=CACHE_DIR)
loco_results = loco_cache.load()
# The entry keyed by the empty string is the reference run (presumably no
# column left out); assumes `loco_results` is a dict.
baseline_mse = loco_results[""]["mse"]

# Change in MSE relative to the baseline, one column per left-out column.
loco_df = pd.DataFrame(
    {
        left_out: [result["mse"] - baseline_mse]
        for left_out, result in loco_results.items()
        if left_out
    }
)
loco_df.columns.name = "Name"
loco_df.index = ["LOCO (MSE)"]

## Individual Tree Importances - Gini vs PFI vs SHAP

SHAP values are loaded from the shap notebook.

In [None]:
def plot_importances(df, ax=None):
    """Draw horizontal box plots of `df`'s columns, highest column mean first.

    Parameters
    ----------
    df : pandas.DataFrame
        One column per feature; rows are importance samples (may be a single row).
    ax : matplotlib Axes, optional
        Axes to draw on. A new 5x12 inch figure is created when omitted.
    """
    ordered = df.reindex(df.mean().sort_values(ascending=False).index, axis=1)

    if ax is None:
        _, ax = plt.subplots(figsize=(5, 12))
    sns.boxplot(data=ordered, orient="h", ax=ax).grid(which="both")

### Gini

In [None]:
ind_trees_gini = pd.DataFrame(
    [estimator.feature_importances_ for estimator in rf],
    columns=X_train.columns,
)
# Column order reused by every importance measure below: descending mean Gini
# importance over the individual trees.
mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
ind_trees_gini = ind_trees_gini.reindex(columns=mean_importances.index)
shorten_columns(ind_trees_gini, inplace=True)


def gini_plot(ax, N_col):
    """Box plot of the per-tree Gini importances of the first `N_col` columns onto `ax`."""
    first_columns = ind_trees_gini.iloc[:, :N_col]
    sns.boxplot(data=first_columns, ax=ax)
    ax.set(ylabel="Gini Importance (MSE)\n")

In [None]:
# Per-tree Gini importances, highest mean first.
plot_importances(df=ind_trees_gini)

### PFI

In [None]:
# One row per entry of `results_`, columns put in the same order as the Gini
# importances!
pfi_ind = pd.DataFrame(perm_importance.results_, columns=X_train.columns).reindex(
    mean_importances.index, axis=1
)
shorten_columns(pfi_ind, inplace=True)


def pfi_plot(ax, N_col):
    """Box plot of the permutation importances of the first `N_col` columns onto `ax`."""
    columns_to_plot = pfi_ind.iloc[:, :N_col]
    sns.boxplot(data=columns_to_plot, ax=ax)
    ax.set(ylabel="PFI Importance\n")

In [None]:
# Permutation importances, highest mean first.
plot_importances(df=pfi_ind)

### SHAP

In [None]:
shap_values = shap_cache.load()

mean_abs_shap = np.mean(np.abs(shap_values), axis=0)

# Single-row frame ("SHAP Importance") with one column per feature, in the
# same column order as the Gini importances!
mean_shap_importances = pd.DataFrame(
    np.reshape(mean_abs_shap, (1, -1)),
    index=["SHAP Importance"],
    columns=X_train.columns,
).reindex(mean_importances.index, axis=1)

shorten_columns(mean_shap_importances, inplace=True)


def shap_plot(ax, N_col):
    """Draw the mean |SHAP| values of the first `N_col` columns onto `ax`."""
    leading = mean_shap_importances.iloc[:, :N_col]
    sns.boxplot(data=leading, ax=ax)
    ax.set(ylabel="SHAP Importance\n")

In [None]:
# Mean |SHAP| per feature (a single value per box).
plot_importances(df=mean_shap_importances)

### LOCO

In [None]:
# Same column order as the Gini importances.
loco_df = loco_df.reindex(columns=mean_importances.index)
shorten_columns(loco_df, inplace=True)


def loco_plot(ax, N_col):
    """Draw the LOCO MSE change of the first `N_col` columns onto `ax`."""
    selected = loco_df.iloc[:, :N_col]
    sns.boxplot(data=selected, ax=ax)
    ax.set(ylabel="LOCO (MSE)\n")

In [None]:
# LOCO MSE change per feature.
plot_importances(df=loco_df)

### VIF

In [None]:
# Use the same column order as the Gini importances!
vifs = vifs.reindex(columns=mean_importances.index)
shorten_columns(vifs, inplace=True)


def vif_plot(ax, N_col):
    """Box plot of the variance inflation factors of the first `N_col` columns onto `ax`."""
    subset = vifs.iloc[:, :N_col]
    sns.boxplot(data=subset, ax=ax)
    ax.set(ylabel="VIF\n")

In [None]:
# Variance inflation factor per feature.
plot_importances(df=vifs)

### ALE 1D

In [None]:
world_ale_1d_cache = SimpleCache("world_ale_1d", cache_dir=CACHE_DIR)
ptp_values, mc_ptp_values = world_ale_1d_cache.load()

# Re-index according to the same ordering as for the Gini importances!
# `reindex` returns a new frame, so its result must be assigned (it used to be
# discarded, leaving the original column order). Features without an ALE
# value become NaN columns. `rename_axis` sets the column-axis name on a new
# Index instead of mutating the (possibly shared) `mean_importances.index`.
ale_1d_df = (
    pd.DataFrame(ptp_values, index=["ALE 1D (PTP)"])
    .reindex(mean_importances.index, axis=1)
    .rename_axis("Name", axis=1)
)

ale_1d_mc_df = (
    pd.DataFrame(mc_ptp_values, index=["ALE 1D MC (PTP)"])
    .reindex(mean_importances.index, axis=1)
    .rename_axis("Name", axis=1)
)

shorten_columns(ale_1d_df, inplace=True)
shorten_columns(ale_1d_mc_df, inplace=True)


def ale_1d_plot(ax, N_col):
    """Box plot of the 1D ALE PTP values of the first `N_col` columns onto `ax`."""
    sns.boxplot(data=ale_1d_df.iloc[:, :N_col], ax=ax)
    ax.set(ylabel="ALE 1D\n")


def ale_1d_mc_plot(ax, N_col):
    """Box plot of the "MC" 1D ALE PTP values of the first `N_col` columns onto `ax`."""
    sns.boxplot(data=ale_1d_mc_df.iloc[:, :N_col], ax=ax)
    ax.set(ylabel="ALE 1D MC\n")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(10, 12))
for ax, ale_df, title in zip(axes, (ale_1d_df, ale_1d_mc_df), ("ALE 1D", "ALE 1D MC")):
    plot_importances(ale_df, ax=ax)
    ax.set_title(title)
    # Feature names already appear as the (horizontal) box plot's y tick labels.
    ax.set_ylabel("")

plt.tight_layout()

### ALE 2D - very cursory analysis

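Each feature's score is the sum of the 2D ALE peak-to-peak (PTP) values of every feature pair it belongs to, so this does not take into account which of the 2 variables is responsible for the interaction.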

In [None]:
world_ale_2d_cache = SimpleCache("world_ale_2d", cache_dir=CACHE_DIR)
ptp_2d_values = world_ale_2d_cache.load()

# Sum, for each feature, the 2D ALE PTP values of every pair containing it.
# This does not attribute the interaction to either member of a pair.
interaction_data = defaultdict(float)
for feature in X_train.columns:
    for feature_pair, ptp_2d_value in ptp_2d_values.items():
        if feature in feature_pair:
            interaction_data[feature] += ptp_2d_value

# Re-index according to the same ordering as for the Gini importances!
# `reindex` returns a new frame, so its result must be assigned (it used to be
# discarded). Features that appear in no pair become NaN columns.
# `rename_axis` sets the column-axis name on a new Index rather than mutating
# the (possibly shared) `mean_importances.index`.
ale_2d_df = (
    pd.DataFrame(interaction_data, index=["ALE 2D (PTP)"])
    .reindex(mean_importances.index, axis=1)
    .rename_axis("Name", axis=1)
)

shorten_columns(ale_2d_df, inplace=True)


def ale_2d_plot(ax, N_col):
    """Box plot of the summed 2D ALE PTP values of the first `N_col` columns onto `ax`."""
    sns.boxplot(data=ale_2d_df.iloc[:, :N_col], ax=ax)
    ax.set(ylabel="ALE 2D\n")

In [None]:
# Summed 2D ALE interaction PTP per feature.
plot_importances(df=ale_2d_df)

### Combining the plots

In [None]:
N_col = 20

plot_funcs = (
    gini_plot,
    pfi_plot,
    shap_plot,
    loco_plot,
    #     ale_1d_plot,
    #     ale_1d_mc_plot,
    #     ale_2d_plot,
    #     vif_plot,
)

# squeeze=False always returns a 2-D array of Axes. Taking its single column
# keeps `axes` a 1-D array even when only one plot function is selected; a
# bare Axes object would break the zip and the `axes[-1]` indexing below.
fig, axes = plt.subplots(
    len(plot_funcs),
    1,
    sharex=True,
    figsize=(7, 1.8 + 2 * len(plot_funcs)),
    squeeze=False,
)
axes = axes[:, 0]

for plot_func, ax in zip(plot_funcs, axes):
    plot_func(ax, N_col)

# Rotate the last x axis labels (the only visible ones).
axes[-1].set_xticklabels(axes[-1].get_xticklabels(), rotation=45, ha="right")

for _ax in axes:
    _ax.grid(which="major", alpha=0.4, linestyle="--")
    _ax.tick_params(labelleft=False)
    _ax.yaxis.get_major_formatter().set_scientific(False)

for _ax in axes[:-1]:
    _ax.set_xlabel("")

plt.tight_layout()
plt.subplots_adjust(top=0.91)
# The file name lists the measures shown, e.g. "feature_importances_gini_pfi_shap_loco".
figure_saver.save_figure(
    fig,
    "_".join(
        (
            "feature_importances",
            *(func.__name__.split("_plot")[0] for func in plot_funcs),
        )
    ),
)

In [None]:
# Mean importance per feature for each measure, sorted in descending order.
importances = {
    measure: per_feature.mean().sort_values(ascending=False)
    for measure, per_feature in (
        ("Gini", ind_trees_gini),
        ("PFI", pfi_ind),
        ("SHAP", mean_shap_importances),
        ("LOCO", loco_df),
        ("ALE 1D", ale_1d_df),
        ("ALE 1D MC", ale_1d_mc_df),
        ("ALE 2D", ale_2d_df),
        ("VIF", vifs),
    )
}

In [None]:
# One column per importance measure listing feature names by rank (row = rank).
table_str = np.column_stack([ranking.index.values for ranking in importances.values()])

In [None]:
def transform(x):
    """Min-max scale `x` into [0, 1].

    NaNs (e.g. features missing from a reindexed measure) are ignored when
    finding the range and stay NaN in the output. A constant input maps to
    all zeros instead of dividing by zero. An empty input is returned as an
    empty float array.

    Parameters
    ----------
    x : array-like
        Values to rescale.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as `x`.
    """
    x = np.asanyarray(x, dtype=float)
    if x.size == 0:
        return x
    shifted = x - np.nanmin(x)
    span = np.nanmax(shifted)
    if span == 0:
        # Every finite value equals the minimum: map it to 0, keep NaNs as NaN.
        return np.where(np.isnan(shifted), shifted, 0.0)
    return shifted / span

In [None]:
# 4 groups of variables - vegetation, landcover, human, meteorological

# HSLuv hue range (start, end) per division; values beyond 360 are wrapped below.
divisions = {
    "vegetation": (70, 150),  # 4 + 4 x 7: 32.
    "landcover": (150, 230),  # 4: 4.
    "human": (230, 270),  #  2: 2.
    "meteorology": (270, 430),  # 5 + 7: 12.
}

# Number of base variables (palette size) per division.
division_members = {
    "vegetation": 4,
    "landcover": 4,
    "human": 2,
    "meteorology": 5,
}

division_names = {
    category: shorten_features(features)
    for category, features in feature_categories.items()
}

# Parallel lists indexed together: base variable name, its hue, and its
# position within its division scaled to [0, 1].
var_keys = []
var_H_vals = []
factors = []

for division in divisions:
    names = division_names[division]
    # Use this division's own size. Previously the vegetation count was used
    # for every division, so `var_H_vals`/`factors` drifted out of alignment
    # with `var_keys` for divisions of a different size (e.g. "human").
    n_members = len(names)
    var_keys.extend(names)
    var_H_vals.extend(
        np.linspace(*divisions[division], n_members, endpoint=False) % 360
    )
    factors.extend(np.linspace(0, 1, n_members))


shifts = [0, 1, 3, 6, 9, 12, 18, 24]


def combined_get_colors(x):
    """Map a 2-D array of feature names to HSLuv-derived RGB colours.

    A name ending in "<sep><optional char><digits>M" is split into its base
    name and shift. Other names use shift 0. The hue comes from the base
    name's entry in `var_keys`/`var_H_vals`. Saturation and value decrease
    with the shift's position in `shifts` and with the base name's `factors`
    entry.

    Returns a nested list (rows of colour tuples) with the same layout as `x`.
    """
    assert len(x.shape) == 2
    out = []
    for row in x:
        out_row = []
        for name in row:
            # Raw string: avoids invalid-escape warnings for `\s`/`\d`.
            # `\D?` (instead of `.{,1}`) stops the optional prefix character
            # from consuming the first digit of a multi-digit shift
            # (e.g. "X 12M" was parsed as shift 2).
            match_obj = re.search(r"(.*)\s\D?(\d+)M", name)
            if match_obj:
                base_name = match_obj.group(1)
                shift = int(match_obj.group(2))
            else:
                base_name = name
                shift = 0
            index = var_keys.index(base_name)
            shift_frac = shifts.index(shift) / (len(shifts) - 1)

            H = var_H_vals[index]
            S = 1.0 - 0.3 * shift_frac - factors[index] * 0.2
            V = 0.85 - 0.55 * shift_frac - factors[index] * 0.06

            out_row.append(hsluv_to_rgb((H, S * 100, V * 100)))
        out.append(out_row)
    return out


# Define separate functions for each of the categories on their own: members of
# the category get a Set1 palette colour (one per base variable); every other
# cell is white.
ind_get_color_funcs = []
for division in divisions:

    def get_colors(x, division=division):
        """Colour a 2-D array of feature names, highlighting only `division`.

        `division` is bound as a default argument so each function keeps its
        own category instead of the loop's final value. Returns a nested list
        (rows of colour tuples) with the same layout as `x`.
        """
        assert len(x.shape) == 2
        out = []
        for x_i in x:
            out.append([])
            for x_ij in x_i:
                # Raw string: avoids invalid-escape warnings for `\s`/`\d`.
                # `\D?` (instead of `.{,1}`) stops the optional prefix
                # character from consuming the first digit of a multi-digit
                # shift (e.g. "X 12M" was parsed as shift 2).
                match_obj = re.search(r"(.*)\s\D?(\d+)M", x_ij)
                if match_obj:
                    x_ij_mod = match_obj.group(1)
                    shift = int(match_obj.group(2))
                else:
                    x_ij_mod = x_ij
                    shift = 0

                if x_ij_mod not in division_names[division]:
                    out[-1].append((1, 1, 1))
                else:
                    index = division_names[division].index(x_ij_mod)
                    # Larger shifts get a smaller `desat` factor.
                    desat = 0.85 - 0.7 * (shifts.index(shift) / (len(shifts) - 1))
                    out[-1].append(
                        sns.color_palette(
                            "Set1", n_colors=division_members[division], desat=desat
                        )[index]
                    )
        return out

    ind_get_color_funcs.append(get_colors)


# Line colour / style per importance measure; the three ALE variants share a
# colour and are distinguished by line style.
color_dict = {
    "Gini": "C0",
    "PFI": "C1",
    "SHAP": "C2",
    "LOCO": "C3",
    "ALE 1D": "C4",
    "ALE 1D MC": "C4",
    "ALE 2D": "C4",
    "VIF": "C5",
}

ls_dict = {measure: "-" for measure in color_dict}
ls_dict.update({"ALE 1D MC": "--", "ALE 2D": "-."})


def table_importance_plot(x, **kwargs):
    """Plot importances `x`, rescaled by `transform`, against table row positions on `axes[1]`."""
    axes[1].plot(transform(x), np.linspace(1, 0, len(table_str)), **kwargs)


for get_colors, suffix in zip(
    (combined_get_colors, *ind_get_color_funcs),
    ("combined", *divisions),
):
    fig = plt.figure(figsize=(12, 18))
    spec = fig.add_gridspec(ncols=2, nrows=1, width_ratios=[3, 1])
    axes = [fig.add_subplot(s) for s in spec]
    table_ax, rank_ax = axes

    # Left: ranked feature-name table, cells coloured by `get_colors`.
    table_ax.set_axis_off()
    table = table_ax.table(
        table_str,
        loc="left",
        rowLabels=range(1, len(table_str) + 1),
        bbox=[0, 0, 1, 1],
        colLabels=list(importances.keys()),
        cellColours=get_colors(table_str),
    )
    table.auto_set_font_size(False)
    table.set_fontsize(8)

    # Right: one line per importance measure, aligned with the table rows.
    markers = ["+", "x", "|", "_", "1", "2", "3", "4", "d"]
    for (importance_measure, importance_values), marker in zip(
        importances.items(), markers
    ):
        table_importance_plot(
            importance_values,
            label=importance_measure,
            marker=marker,
            c=color_dict[importance_measure],
            ls=ls_dict[importance_measure],
            ms=8,
        )

    rank_ax.yaxis.set_label_position("right")
    rank_ax.yaxis.tick_right()

    n_rows = table_str.shape[0]
    cell_height = 1 / (n_rows + 1)
    rank_ax.set_ylim(-cell_height / 2, 1 + 1.5 * cell_height)
    rank_ax.set_yticks(np.linspace(1, 0, n_rows))
    rank_ax.set_yticklabels(range(1, n_rows + 1))

    rank_ax.set_xlim(0, 1)
    rank_ax.set_xticks([0, 1])
    rank_ax.set_xticklabels([0, 1])
    rank_ax.set_xticks(np.linspace(0, 1, 8), minor=True)

    rank_ax.grid(alpha=0.4, linestyle="--")
    rank_ax.grid(which="minor", axis="x", alpha=0.4, linestyle="--")

    rank_ax.legend(loc="best")

    plt.tight_layout()

    figure_saver.save_figure(
        fig, "_".join(("feature_importance_breakdown", suffix)).strip("_")
    )

In [None]:
unique_str = np.unique(table_str)
# Use the combined HSLuv colour scheme explicitly. After the plotting loop
# above, `get_colors` is merely the last per-division function (Set1 palette /
# white), whose colours do not follow the division hue ranges that the HSLuv
# overlay classifies by.
colors = combined_get_colors(unique_str.reshape(1, -1))[0]

In [None]:
def hsluv_conv(hsv):
    """Convert a 2-D grid of HSLuv triples into an array of RGB colours."""
    return np.array([[hsluv_to_rgb(pixel) for pixel in row] for row in hsv])


# HSLuv reference image: hue along x, third channel (labelled V) along y,
# second channel fixed at 100.
V, H = np.mgrid[0:1:100j, 0:1:100j]
S = np.ones_like(V)
HSV = np.dstack((H * 360, S * 100, V * 100))
RGB = hsluv_conv(HSV)

plt.figure(figsize=(20, 20))
plt.imshow(RGB, origin="lower", extent=[0, 360, 0, 100], aspect=2)
plt.xlabel("H")
plt.ylabel("V")

# Overlay each colour at its (H, V) position; the marker shows which division's
# hue range contains it. If no range matches, the loop ends without `break`
# and `marker` keeps its last value ("|", meteorology). Meteorology hues
# wrapped past 360 land in that case.
division_markers = ["+", "x", "_", "|"]
for color in colors:
    h, s, v = rgb_to_hsluv(color)
    for (division, values), marker in zip(divisions.items(), division_markers):
        if values[0] - 1e-5 < h < values[1] + 1e-5:
            break
    plt.plot(h, v, marker=marker, linestyle="", c="k")

## Choose the 15 most important features using the above metrics

In [None]:
list(importances.keys())

In [None]:
methods = ["Gini", "PFI", "LOCO", "SHAP"]
combined = plot_and_list_importances(importances, methods, N=None)

print("Top 15:\n")
print("\n".join(list(combined.index[:15])))

# Combined ranking with every vegetation-derived variable removed.
no_veg = [
    name
    for name in list(combined.index)
    if all(
        veg not in name
        for veg in shorten_features(feature_categories["vegetation"])
    )
]
print(f"\nAll without vegetation: {len(no_veg)}\n")
print("\n".join(no_veg))

short_lags = [lag for lag in lags if int(lag) < 12]

# Slots left out of 15 once the short lags are counted.
n_remain = 15 - len(short_lags)
print("\nN short lags:", len(short_lags), "\n")

print(f"\nTop {n_remain} without vegetation:\n")
print("\n".join(no_veg[:n_remain]))

In [None]:
# Table of the combined importances, numbered from 1 in the order returned by
# `plot_and_list_importances`, stored in the shelve under this project's name.
c2 = combined.reset_index()
c2.columns = ["Variable", "Importance"]
c2.index = [position + 1 for position in range(len(c2))]
with shelve.open(fi_shelve_file) as db:
    db[PROJECT_DIR.name] = c2
c2