In [None]:
%run model_fitting.py local
%run calculate_pfi.py local
%run calculate_loco_values.py local
%run calculate_shap_values.py local

from enum import Enum

In [None]:
import pandas as pd

from empirical_fire_modelling.utils import tqdm, transform_series_sum_norm

In [None]:
# Per-experiment impurity-based ("GINI") importances from each fitted forest.
# Importances are gathered tree by tree so that both the mean and the spread
# across trees can be reported. Result: gini_importances[exp] is a DataFrame
# indexed by feature name, sorted by descending "mean GINI", with a matching
# "std GINI" column.
gini_importances = {}
for exp, rf in tqdm(models.items(), desc="Gini"):
    # Only X_train's column names are used below; the 4-way unpack is kept so a
    # change in the split helper's return shape is still caught here.
    X_train, X_test, y_train, y_test = get_experiment_split_data(exp)

    # One row per tree, one column per feature. Iterating `rf` yields its fitted
    # sub-estimators (presumably a scikit-learn forest, whose trees expose
    # impurity-based `feature_importances_` — confirm in model_fitting.py).
    ind_trees_gini = pd.DataFrame(
        [tree.feature_importances_ for tree in rf],
        columns=X_train.columns,
    )
    mean_importances = ind_trees_gini.mean().sort_values(ascending=False)
    # `.std()` returns a Series, which only has an index axis. Reindexing it to
    # the mean's descending order makes both Series share an identical index,
    # so the DataFrame built below keeps that order instead of a union order.
    std_importances = ind_trees_gini.std().reindex(mean_importances.index)
    gini_importances[exp] = pd.DataFrame(
        {"mean GINI": mean_importances, "std GINI": std_importances}
    )

In [None]:
class Metric(Enum):
    """Feature-importance metrics compared in this notebook.

    Members are numbered 1-4 in declaration order, exactly as the functional
    ``Enum("Metric", [...])`` form numbers them; iterating the enum yields
    GINI, SHAP, PFI, LOCO in that order.
    """

    GINI = 1
    SHAP = 2
    PFI = 3
    LOCO = 4

In [None]:
# One row per metric: the per-experiment importance tables for that metric,
# and the column in those tables holding the metric's central estimate.
_metric_sources = (
    (Metric.GINI, gini_importances, "mean GINI"),
    (Metric.SHAP, shap_importances, "mean SHAP"),
    (Metric.PFI, pfi_importances, "test weight"),
    (Metric.LOCO, loco_importances, "test score"),
)
importance_data = {metric: tables for metric, tables, _ in _metric_sources}
importance_keys = {metric: column for metric, _, column in _metric_sources}
del _metric_sources

# Spread columns are deliberately not mapped. An earlier draft paired GINI with
# "std GINI", SHAP with "std SHAP" and PFI with "test std". LOCO has no
# corresponding std column, so only central estimates are tracked.

In [None]:
# For each experiment, draw one figure overlaying the feature importances from
# every metric in `Metric`. Each metric's importance column is passed through
# `transform_series_sum_norm` (presumably rescaling it to sum to 1 — confirm in
# empirical_fire_modelling.utils) so the metrics share a scale. Features are
# ordered along the x-axis by the sum of these normalised values, highest first.
# NOTE(review): `Experiment`, `plt`, `shap_importances` etc. are not defined in
# the cells shown; presumably they come from the `%run` scripts — confirm.
for exp in tqdm(Experiment, desc="Experiment"):
    # Width scales with the number of features (rows of the GINI table), at
    # 8 inches per 15 features; the height is fixed at 5 inches.
    plt.figure(figsize=(8 * importance_data[Metric.GINI][exp].shape[0] / 15, 5))
    plt.title(exp.name)

    # Metric -> normalised importance Series (indexed by feature).
    transformed_importances = {}

    # Calculation. Sort according to combined metric after normalisation.
    combined = None
    for importance_metric in Metric:
        importance_df = transform_series_sum_norm(
            importance_data[importance_metric][exp][importance_keys[importance_metric]]
        )
        transformed_importances[importance_metric] = importance_df
        # The first metric seeds the running total; the copy keeps the stored
        # Series from being modified by the in-place `+=` below.
        if combined is None:
            combined = importance_df.copy()
        else:
            # NOTE(review): `+=` aligns on feature labels. A feature missing
            # from any metric becomes NaN here, and sort_values places it last.
            combined += importance_df
    combined.sort_values(ascending=False, inplace=True)

    # Columns are metrics, rows are features in descending combined order.
    transformed_importances = pd.DataFrame(transformed_importances).reindex(
        combined.index, axis=0
    )

    # Plotting.
    for importance_metric in Metric:
        transformed = transformed_importances[importance_metric]
        # Feature labels are converted to strings so matplotlib uses the index
        # as categorical x positions (presumably the labels are not plain str).
        # NOTE(review): `.copy()` is redundant; `Index.map` already returns a
        # new Index.
        transformed.index = transformed.copy().index.map(str)
        # The label is only displayed if a legend is drawn for this figure.
        plt.plot(transformed, label=importance_metric.name)

        # Rotate x tick labels so long feature names stay readable.
        # NOTE(review): this runs once per metric but only needs to run once
        # per figure; consider moving it out of the metric loop.
        _ = plt.setp(plt.gca().xaxis.get_majorticklabels(), rotation=45, ha="right")