#### Setup

In [None]:
from specific import *

### Specify the experiments to compare

In [None]:
# Main experiments to compare, paired with their plotting z-orders
# (higher z-order is drawn on top).
experiments = ["all", "15_most_important", "no_temporal_shifts", "best_top_15"]
zorders = [7, 6, 5, 4]
experiment_zorder_dict = dict(zip(experiments, zorders))

# Plot styling (label, colour, marker, z-order) for each main experiment.
experiment_plot_kwargs = {}
for experiment in experiments:
    experiment_plot_kwargs[experiment] = dict(
        label=experiment_name_dict[experiment],
        c=experiment_color_dict[experiment],
        marker=experiment_marker_dict[experiment],
        zorder=experiment_zorder_dict[experiment],
    )

# Single-variable experiments are styled per *group*: every member of a group
# gets the same colour, marker and z-order, and only the label differs. This is
# intended to be unambiguous because, presumably, members of a group do not
# share features and so do not appear together in the same ALE plot. The
# exception is the dry-day-period feature, where group members may clash.
_single_variable_groups = (
    # (members, index into experiment_colors, marker, zorder)
    (("fapar_only", "lai_only", "sif_only", "vod_only"), 4, "|", 3),
    (
        (
            "lagged_fapar_only",
            "lagged_lai_only",
            "lagged_sif_only",
            "lagged_vod_only",
        ),
        5,
        "^",
        2,
    ),
)
for _members, _color_index, _marker, _zorder in _single_variable_groups:
    for experiment in _members:
        experiments.append(experiment)
        experiment_plot_kwargs[experiment] = dict(
            label=experiment_name_dict[experiment],
            c=experiment_colors[_color_index],
            marker=_marker,
            zorder=_zorder,
        )

### Load data

In [None]:
# Load the stored results for every experiment in `experiments`, skipping the
# entries named below. Presumably these are large datasets not needed for the
# ALE comparison; confirm against `load_experiment_data`.
_skipped_entries = (
    "endog_data",
    "exog_data",
    "filled_datasets",
    "masked_datasets",
    "land_mask",
)
experiment_data = load_experiment_data(experiments, ignore=_skipped_entries)

### Check that all experiments share the same master mask

In [None]:
# Every experiment must use an identical `master_mask`, presumably so that the
# results compared below refer to the same samples.
# `np.array_equal` also compares shapes, so masks with different (but
# broadcastable) shapes are reported as mismatches instead of being silently
# broadcast by `==`. An explicit raise is used instead of `assert`, because
# assertions are stripped when Python runs with -O.
comp_masks = [experiment_data[experiment]["master_mask"] for experiment in experiments]
mismatched = [
    experiment
    for experiment, comp_mask in zip(experiments[1:], comp_masks[1:])
    if not np.array_equal(comp_masks[0], comp_mask)
]
if mismatched:
    raise ValueError(
        f"master_mask of {mismatched} does not match that of {experiments[0]!r}."
    )

## Combine 1D ALE plots across multiple models

In [None]:
# Debugging aid (disabled): uncomment to print each experiment's name followed
# by its training feature column names, ordered with `sort_features`.
# for experiment in experiments:
#     print(experiment)
#     print()
#     print("\n".join(sort_features(experiment_data[experiment]["X_train"].columns)))
#     print()

In [None]:
# Legend position (forwarded as `legend_bbox`) for each feature's combined ALE
# figure. All features currently share the same position.
legend_bboxes = dict.fromkeys(
    ("Dry Day Period", "FAPAR", "LAI", "SIF", "VOD Ku-band"),
    (0.5, 1.02),
)

# One `multi_model_ale_1d` call per feature, comparing all loaded experiments.
# tqdm infers the progress-bar total from the length of `legend_bboxes.items()`.
for feature, legend_bbox in tqdm(legend_bboxes.items(), desc="Features"):
    multi_model_ale_1d(
        # Presumably maps the feature name to its column name(s) in the data.
        get_filled_names(feature),
        experiment_data,
        experiment_plot_kwargs,
        n_jobs=get_ncpus(),
        verbose=False,
        legend_bbox=legend_bbox,
        figure_saver=figure_saver,
    )