#### Setup

In [None]:
from specific import *

### Specify the experiments to compare

In [None]:
# Experiment identifiers whose results are loaded and compared in this notebook.
experiments = (
    "15_most_important",
    "no_temporal_shifts",
    "fire_seasonality_paper",
)

### Load data

In [None]:
# Load the stored results for every experiment in `experiments`. The result is
# indexed by experiment name further down, e.g.
# experiment_data[name]["master_mask"], ["X_train"] and ["model"].
# NOTE(review): the exact container type is defined by `load_experiment_data`
# (imported from `specific`) — confirm there.
experiment_data = load_experiment_data(experiments)

### Check that the masks are aligned

In [None]:
# All experiments must share the same master mask. Otherwise the ALE curves
# compared below would be computed over different sets of samples.
comp_masks = [experiment_data[experiment]["master_mask"] for experiment in experiments]
# `np.array_equal` also requires identical shapes. `np.all(a == b)` would
# broadcast compatible but different shapes (e.g. (1, N) vs (M, N)) and could
# wrongly report misaligned masks as equal.
misaligned_experiments = [
    experiment
    for experiment, comp_mask in zip(experiments[1:], comp_masks[1:])
    if not np.array_equal(comp_masks[0], comp_mask)
]
if misaligned_experiments:
    raise ValueError(
        f"Master masks differ from that of {experiments[0]!r} for: "
        f"{misaligned_experiments}"
    )

## Combining multiple ALE plots across models

In [None]:
# X_train.columns

In [None]:
# 'Pretty' experiment names mapping, used to build the plot legend labels.
experiment_names = {
    "15_most_important": "With Lagged",
    "no_temporal_shifts": "Only Unlagged",
    "fire_seasonality_paper": "All",
}

# NOTE: only the third experiment ("fire_seasonality_paper") is plotted here.
plotted_experiments = experiments[2:]

# For each base feature, overlay the ALE curves of the unlagged feature and
# each of its "<name> <m> Month" shifted variants in a single figure.
for feature_name in tqdm(
    ("Dry Day Period", "SIF", "FAPAR", "LAI", "VOD Ku-band"),
    desc="Multiple shift ALE plots",
):
    model_X_cols = []
    plot_kwargs_list = []

    fig_name = f'comp_{feature_name.replace(" ", "_").lower()}_ale_shifts'
    for experiment in plotted_experiments:
        X_train = experiment_data[experiment]["X_train"]
        for feature in (
            feature_name,
            *(f"{feature_name} {m} Month" for m in (-1, -3, -6, -9)),
        ):
            # Skip variants that this experiment's training data does not contain.
            if feature not in X_train:
                continue

            model_X_cols.append(
                (
                    experiment_data[experiment]["model"],
                    X_train,
                    feature,
                )
            )
            plot_kwargs_list.append(
                {
                    "label": f"{experiment_names[experiment]} - {shorten_features(feature)}"
                }
            )

    # Both lists are appended in lockstep, so checking one of them suffices.
    if not model_X_cols:
        raise ValueError(
            f"No columns for {feature_name!r} (or its shifts) found in "
            f"experiments {plotted_experiments}."
        )

    multi_model_ale_plot_1d(
        model_X_cols,
        plot_kwargs_list,
        fig_name=fig_name,
        n_jobs=get_ncpus(),
        verbose=True,
        # Use the base feature name. The inner-loop variable `feature` would
        # still hold the last candidate tried (e.g. "... -9 Month") at this point.
        xlabel=feature_name,
        title=f"First-order ALE for {feature_name}",
        figure_saver=figure_saver,
    )