In [None]:
# %run model_fitting.py --experiment ALL --experiment ALL_NN --experiment CURR --experiment 15VEG_FAPAR local
# %run calculate_pfi.py --experiment ALL --experiment ALL_NN --experiment CURR --experiment 15VEG_FAPAR local
# %run calculate_loco_values.py --experiment ALL --experiment ALL_NN --experiment CURR --experiment 15VEG_FAPAR local
# %run calculate_shap_values.py --experiment ALL --experiment ALL_NN --experiment CURR --experiment 15VEG_FAPAR local

In [None]:
# Run the model-fitting and importance-calculation scripts (PFI, LOCO and SHAP,
# judging by the script names) for the CLIM experiment only; the invocation for
# the full set of experiments is commented out in the previous cell.
# NOTE(review): `%run` copies each script's globals into this namespace, which is
# presumably where `models`, `Experiment`, `get_experiment_split_data` and the
# `shap_importances` / `pfi_importances` / `loco_importances` dicts used below come
# from -- none of them are defined in this notebook. Confirm before refactoring.
%run model_fitting.py --experiment CLIM local
%run calculate_pfi.py --experiment CLIM local
%run calculate_loco_values.py --experiment CLIM local
%run calculate_shap_values.py --experiment CLIM local

In [None]:
import gc
import string
from collections import defaultdict
from copy import deepcopy
from enum import Enum

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.core.display import HTML, display
from wildfires.utils import shorten_features

from empirical_fire_modelling import variable
from empirical_fire_modelling.model import get_gini_importances
from empirical_fire_modelling.plotting import figure_saver
from empirical_fire_modelling.utils import tqdm, transform_series_sum_norm

In [None]:
# Gini importances (mean and std, via `get_gini_importances`) for every fitted
# experiment. Only the training split is needed, so the test split is discarded.
gini_importances = {}
for exp in tqdm(models, desc="Gini"):
    X_train, _, y_train, _ = get_experiment_split_data(exp)
    gini_mean, gini_std = get_gini_importances(X_train, y_train)
    gini_importances[exp] = pd.DataFrame(
        {"mean GINI": gini_mean, "std GINI": gini_std}
    )
    # Explicit collection between experiments, presumably to bound peak memory.
    gc.collect()

In [None]:
class Metric(Enum):
    """Feature-importance metrics compared in this notebook (iteration order matters)."""

    GINI = 1
    SHAP = 2
    PFI = 3
    LOCO = 4

In [None]:
# One row per metric: (metric, per-experiment importance tables, column holding
# the mean importance, column holding its std -- None when there is none).
_metric_columns = (
    (Metric.GINI, gini_importances, "mean GINI", "std GINI"),
    (Metric.SHAP, shap_importances, "test mean SHAP", "test std SHAP"),
    (Metric.PFI, pfi_importances, "test weight", "test std"),
    # LOCO has no std column, so it is left out of `std_keys`.
    (Metric.LOCO, loco_importances, "test score", None),
)

importance_data = {metric: tables for metric, tables, _, _ in _metric_columns}
importance_keys = {metric: mean_col for metric, _, mean_col, _ in _metric_columns}
std_keys = {
    metric: std_col
    for metric, _, _, std_col in _metric_columns
    if std_col is not None
}

In [None]:
def _normalise_importances(exp):
    """Sum-normalise every metric's importances for one experiment.

    Parameters
    ----------
    exp
        Experiment key into each ``importance_data[metric]`` mapping.

    Returns
    -------
    means : dict
        ``Metric`` -> normalised mean-importance Series.
    stds : dict
        ``Metric`` -> normalised std Series, or ``None`` for metrics that have
        no std column (those absent from ``std_keys``).
    """
    means = {}
    stds = {}
    for metric in Metric:
        table = importance_data[metric][exp]
        if metric in std_keys:
            means[metric], stds[metric] = transform_series_sum_norm(
                x=table[importance_keys[metric]], y=table[std_keys[metric]]
            )
        else:
            means[metric] = transform_series_sum_norm(x=table[importance_keys[metric]])
            stds[metric] = None
    return means, stds


def _str_labelled(series, labels):
    """Return a new Series aligned to ``labels`` whose index labels are ``str``.

    Works on a copy, so the input Series (possibly a DataFrame column) is never
    mutated.
    """
    aligned = series.reindex(labels).copy()
    aligned.index = labels.map(str)
    return aligned


combined_data = {}
plot_data = defaultdict(dict)

for exp in tqdm(models, desc="Experiment"):
    transformed_importances, transformed_importances_std = _normalise_importances(exp)

    # Combined ranking: sum of the normalised importances over all metrics.
    combined = None
    for importance_s in transformed_importances.values():
        if combined is None:
            combined = importance_s.copy()
        else:
            combined += importance_s
    combined = combined.sort_values(ascending=False)

    transformed_importances = pd.DataFrame(transformed_importances).reindex(
        combined.index, axis=0
    )

    combined.name = f"{exp.name} (combined)"
    combined_data[exp] = combined.copy()

    # Plot-ready data: each Series ordered by the combined ranking and given
    # string labels, which matplotlib plots as categorical x values.
    for metric in Metric:
        std_s = transformed_importances_std[metric]
        plot_data[exp][metric.name] = {
            "mean": _str_labelled(transformed_importances[metric], combined.index),
            "std": None if std_s is None else _str_labelled(std_s, combined.index),
        }

In [None]:
# One figure per experiment: every metric's normalised importances (ordered by
# the combined ranking), with a shaded +/- std band where the metric has a std.
for exp, metric_data in plot_data.items():
    n_features = importance_data[Metric.GINI][exp].shape[0]
    # Figure width scales with the number of features (8 inches per 15).
    fig, ax = plt.subplots(figsize=(8 * n_features / 15, 5))
    ax.set_title(exp.name)
    for metric_name, transformed in metric_data.items():
        ax.plot(transformed["mean"], label=metric_name)
        if transformed["std"] is not None:
            # Add shaded region to illustrate the std.
            ax.fill_between(
                transformed["mean"].index,
                transformed["mean"] - transformed["std"],
                transformed["mean"] + transformed["std"],
                label=f"{metric_name} std",
                alpha=0.1,
            )
    # Rotate tick labels once, after all metrics (which share the same
    # categories) have been plotted.
    _ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right")
    ax.set_ylabel("Importance")
    ax.legend()

In [None]:
# Plain-text dump of each experiment's combined ranking, blank line after each.
for combined in combined_data.values():
    print(combined, end="\n\n")

In [None]:
def get_html_rep(data, name):
    """Render ``data`` as a one-column HTML table with 1-based row labels.

    Parameters
    ----------
    data : iterable
        Items to list; each is converted with ``str``.
    name : str
        Column header.

    Returns
    -------
    str
        The DataFrame's notebook HTML representation.
    """
    labels = [str(item) for item in data]
    table = pd.DataFrame({name: labels}, index=range(1, len(labels) + 1))
    return table._repr_html_()

In [None]:
# HTML summary per experiment: the top-ranked features overall plus the split
# into vegetation / non-vegetation features, all in combined-importance order.
veg_categories = variable.feature_categories[variable.Category.VEGETATION]

print_data = []
for exp, combined in combined_data.items():
    features = list(combined.index)  # already sorted by combined importance
    # A plain boolean list rather than a numpy mask: needs no numpy and also
    # works for an empty index (np.array([]) is float and cannot take `~`).
    is_veg = [bool(variable.match_factory(var, veg_categories)) for var in features]
    veg_features = [var for var, veg in zip(features, is_veg) if veg]
    non_veg_features = [var for var, veg in zip(features, is_veg) if not veg]
    print_data.append(
        f"""<h1>{exp.name}</h1>
        <h2>Top 15</h2>
        {get_html_rep(features[:15], 'Top 15')}
        <h2>Veg Features</h2>
        {get_html_rep(veg_features, 'Veg')}
        <h2>Non-Veg Features</h2>
        {get_html_rep(non_veg_features, 'Non-Veg')}
        <h2>Top 10 Non-Veg features</h2>
        {get_html_rep(non_veg_features[:10], 'Top 10 Non-Veg')}"""
    )
display(HTML("".join(print_data)))

### Comparative plotting of importances

In [None]:
# Comparative figure: the TOP_N highest-ranked features (combined ranking) for
# three experiments, one panel each, with every metric overlaid.
TOP_N = 15  # number of top-ranked features shown per panel

comparison_experiments = [
    Experiment.ALL,
    Experiment.ALL_NN,
    Experiment["15VEG_FAPAR"],
]
fig, axes = plt.subplots(len(comparison_experiments), 1, figsize=(5, 7.8))

for exp, ax, title in zip(comparison_experiments, axes, string.ascii_lowercase):
    # `.get` instead of `[]`: indexing the defaultdict would silently insert an
    # empty entry for an experiment that was never processed.
    metric_data = plot_data.get(exp, {})
    ax.set_title(f"({title}) {exp.name}")
    for metric_name, transformed in metric_data.items():
        top_mean = transformed["mean"][:TOP_N]
        # Shortened names, computed once and shared by the line and the band.
        top_labels = [shorten_features(str(s)) for s in top_mean.index]
        ax.plot(
            top_labels,
            top_mean.values,
            label=(metric_name if ax is axes[0] else None),
        )
        if transformed["std"] is not None:
            # Shaded region illustrating mean +/- std.
            top_std = transformed["std"][:TOP_N]
            ax.fill_between(
                top_labels,
                top_mean - top_std,
                top_mean + top_std,
                alpha=0.1,
            )
    # Rotate once, after every metric (same categories) has been plotted.
    _ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=34, ha="right")

    if ax is axes[0]:
        ax.legend()

    ax.set_ylabel("Importance")
    # 0.15 padding either side of the first and last categorical positions.
    ax.set_xlim(-0.15, TOP_N - 1 + 0.15)

fig.subplots_adjust(hspace=0.45)
fig.align_labels()

figure_saver.save_figure(fig, "model_comp_importances")

In [None]:
# One panel per metric for the ALL experiment: every feature ranked by that
# metric's own importances, with a dashed marker after the TOP_N-th feature.
TOP_N = 15  # position of the dashed "top features" boundary

fig, axes = plt.subplots(4, 1, figsize=(8.15, 8))

exp = Experiment.ALL

# `.get` instead of `[]`: indexing the defaultdict would silently insert an
# empty entry if this experiment was never processed.
metric_data = plot_data.get(exp, {})

for ax, (metric_name, transformed) in zip(axes, metric_data.items()):
    ax.set_title(f"{metric_name}")

    # `sort_values` returns a new Series, so no defensive copy is needed.
    sorted_mean = transformed["mean"].sort_values(ascending=False)
    labels = [shorten_features(str(s)) for s in sorted_mean.index]

    ax.plot(
        labels,
        sorted_mean.values,
        label=(metric_name if ax is axes[0] else None),
    )

    if transformed["std"] is not None:
        sorted_std = transformed["std"].reindex(sorted_mean.index)

        # Add shaded region to illustrate the std.
        ax.fill_between(
            labels,
            sorted_mean - sorted_std,
            sorted_mean + sorted_std,
            alpha=0.1,
        )
    _ = plt.setp(ax.xaxis.get_majorticklabels(), rotation=52, ha="right")

    ax.set_ylabel("Importance")
    # Span every feature (not a hard-coded 50) with 0.15 padding either side.
    ax.set_xlim(-0.15, len(labels) - 1 + 0.15)
    # Freeze the y-limits so the dashed marker spans the full axis height.
    ax.set_ylim(ax.get_ylim())
    ax.vlines(TOP_N - 1, *ax.get_ylim(), linestyle="--", color="C3")

fig.subplots_adjust(hspace=0.95)
fig.align_labels()

figure_saver.save_figure(fig, "importance_metric_comp")

### Table of combined importances

In [None]:
from wildfires.utils import shorten_features

In [None]:
# One column per experiment listing its features (shortened names) in
# combined-importance order; shorter columns are padded with NaN at the bottom.
combined_table_df = pd.DataFrame(
    {
        exp.name: pd.Series(shorten_features(list(map(str, combined.index))))
        for exp, combined in combined_data.items()
    }
)
# 1-based ranks, derived from the table length instead of assuming 50 features.
combined_table_df.index = pd.RangeIndex(1, len(combined_table_df) + 1)
combined_table_df

In [None]:
# Post-process the LaTeX table: swap the booktabs rule commands for the
# \tophline / \middlehline / \bottomhline variants (presumably required by the
# target document class -- confirm) and make the Delta character LaTeX-safe.
# Applied in order.
latex_substitutions = (
    ("toprule", "tophline"),
    ("midrule", "middlehline"),
    ("bottomrule", "bottomhline"),
    ("Δ", r"\ensuremath{\Delta}"),
)

latex_df = combined_table_df.to_latex(na_rep="")
for old, new in latex_substitutions:
    latex_df = latex_df.replace(old, new)
print(latex_df)