In [None]:
# Run the upstream pipeline scripts in this kernel for the ALL, ALL_NN and CURR
# experiments: model fitting first, then (per the script names) PFI, LOCO and
# SHAP importance calculation.
# NOTE(review): names used below but never defined in this notebook — `models`,
# `get_experiment_split_data`, `pfi_importances`, `loco_importances`,
# `shap_importances` (and possibly `np` / `plt`) — presumably come from these
# scripts; confirm. The meaning of the trailing `local` argument is defined by
# the scripts' CLIs, not visible here.
# Comments are kept on separate lines: text after a magic is passed to the script.
%run model_fitting.py --experiment ALL --experiment ALL_NN --experiment CURR local
%run calculate_pfi.py --experiment ALL --experiment ALL_NN --experiment CURR local
%run calculate_loco_values.py --experiment ALL --experiment ALL_NN --experiment CURR local
%run calculate_shap_values.py --experiment ALL --experiment ALL_NN --experiment CURR local

from collections import defaultdict

In [None]:
from enum import Enum

import numpy as np
import pandas as pd

from empirical_fire_modelling import variable
from empirical_fire_modelling.utils import tqdm, transform_series_sum_norm

In [None]:
# Per-tree impurity-based (Gini) feature importances for every fitted model,
# summarised per feature as the across-tree mean and standard deviation.
# Result: gini_importances[exp] is a DataFrame indexed by feature, sorted by
# descending mean, with columns "mean GINI" and "std GINI".
gini_importances = {}
for exp, rf in tqdm(models.items(), desc="Gini"):
    # Only the training features' column names are used from the split.
    X_train, X_test, y_train, y_test = get_experiment_split_data(exp)

    # One row per tree, one column per feature.
    # NOTE(review): assumes iterating `rf` yields the individual fitted trees
    # (each exposing `feature_importances_`) — confirm.
    ind_trees_gini = pd.DataFrame(
        [tree.feature_importances_ for tree in rf],
        columns=X_train.columns,
    )
    mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
    # `std()` returns a Series (single axis), so reindex along its only axis to
    # put the std values in the same (sorted-by-mean) order as the means.
    std_importances = ind_trees_gini.std().reindex(mean_importances.index)
    gini_importances[exp] = pd.DataFrame(
        {"mean GINI": mean_importances, "std GINI": std_importances}
    )

In [None]:
class Metric(Enum):
    """Feature-importance metrics compared in this notebook (values 1..4)."""

    GINI = 1
    SHAP = 2
    PFI = 3
    LOCO = 4

In [None]:
# For each importance metric: the per-experiment importance tables, and the
# column within those tables holding the importance value compared below.
_metric_sources = (
    (Metric.GINI, gini_importances, "mean GINI"),
    (Metric.SHAP, shap_importances, "test mean SHAP"),
    (Metric.PFI, pfi_importances, "test weight"),
    (Metric.LOCO, loco_importances, "test score"),
)
importance_data = {metric: tables for metric, tables, _ in _metric_sources}
importance_keys = {metric: column for metric, _, column in _metric_sources}
# Matching spread columns (currently unused): GINI -> "std GINI",
# SHAP -> "std SHAP", PFI -> "test std". LOCO has no std column.

In [None]:
# For each experiment: normalise every metric's importances, rank features by
# the sum of the normalised metrics, then plot each normalised metric along
# that shared ranking (one figure per experiment).
combined_data = {}  # experiment -> combined normalised importance, sorted descending
plot_data = defaultdict(dict)  # experiment -> {metric name: plotted Series}

for exp in tqdm(models, desc="Experiment"):
    # Figure width grows linearly with the number of features (8 in per 15).
    n_features = importance_data[Metric.GINI][exp].shape[0]
    fig, ax = plt.subplots(figsize=(8 * n_features / 15, 5))
    ax.set_title(exp.name)

    transformed_importances = {}

    # Calculation. Sort according to combined metric after normalisation.
    combined = None
    for importance_metric in Metric:
        importance_s = transform_series_sum_norm(
            importance_data[importance_metric][exp][importance_keys[importance_metric]]
        )
        transformed_importances[importance_metric] = importance_s
        if combined is None:
            # Copy so the in-place additions below leave the first metric's
            # normalised Series untouched.
            combined = importance_s.copy()
        else:
            # NOTE(review): pandas in-place `+=` aligns the result back to
            # `combined`'s existing index (the first metric, GINI), so features
            # absent from GINI are dropped — confirm this is intended.
            combined += importance_s
    combined = combined.sort_values(ascending=False)

    # One column per metric, rows in combined-ranking order.
    transformed_importances = pd.DataFrame(transformed_importances).reindex(
        combined.index, axis=0
    )

    combined.name = f"{exp.name} (combined)"
    combined_data[exp] = combined.copy()

    # Plotting.
    for importance_metric in Metric:
        # Explicit copy with string labels for the x-axis ticks, leaving
        # `transformed_importances` itself unmodified.
        transformed = transformed_importances[importance_metric].copy()
        transformed.index = transformed.index.map(str)

        # Keep the plotted data itself (the Metric member was stored before,
        # which carried no information beyond its own key).
        plot_data[exp][importance_metric.name] = transformed

        ax.plot(transformed, label=importance_metric.name)
    # Rotate tick labels once, after all lines have been drawn.
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha="right")
    ax.legend()

In [None]:
# Show each experiment's combined ranking, each followed by a blank line.
for combined in combined_data.values():
    print(combined, end="\n\n")

In [None]:
from IPython.core.display import HTML, display

In [None]:
def get_html_rep(data, name):
    """Render an iterable as a one-column HTML table.

    Args:
        data: Any iterable; each item is converted with ``str``.
        name: Column heading for the table.

    Returns:
        The HTML produced by ``DataFrame._repr_html_()`` for a frame whose
        single column ``name`` holds the stringified items, with rows labelled
        1..n in iteration order.
    """
    labels = [str(item) for item in data]
    row_numbers = list(range(1, len(labels) + 1))
    table = pd.DataFrame(labels, columns=[name], index=row_numbers)
    return table._repr_html_()

In [None]:
# Build one HTML report covering every experiment: the top features by combined
# importance, plus the ranking split into vegetation / non-vegetation features.
# `combined_data` values are expected to be sorted by descending combined
# importance, so slicing an index gives the top-ranked features.
print_data = []
for exp, combined in combined_data.items():
    # True where a ranked feature matches the VEGETATION category.
    # NOTE(review): assumes `match_factory` returns a bool-like value per
    # feature — confirm. `dtype=bool` keeps `~veg_mask` valid when `combined`
    # is empty (an empty list would otherwise become a float array, which
    # cannot be inverted).
    veg_mask = np.array(
        [
            variable.match_factory(
                var, variable.feature_categories[variable.Category.VEGETATION]
            )
            for var in combined.index
        ],
        dtype=bool,
    )
    ranked = combined.index
    non_veg = ranked[~veg_mask]
    print_data.append(
        f"""<h1>{exp.name}</h1>
        <h2>Top 15</h2>
        {get_html_rep(ranked[:15], 'Top 15')}
        <h2>Veg Features</h2>
        {get_html_rep(ranked[veg_mask], 'Veg')}
        <h2>Non-Veg Features</h2>
        {get_html_rep(non_veg, 'Non-Veg')}
        <h2>Top 10 Non-Veg features</h2>
        {get_html_rep(non_veg[:10], 'Top 10 Non-Veg')}"""
    )
display(HTML("".join(print_data)))