### Setup

In [None]:
from specific import *

### Specify the experiments to compare

In [None]:
# Experiments to compare. `zorders` is paired with `experiments` by position;
# the resulting mapping supplies the "zorder" plotting keyword further down.
experiments = ("all", "15_most_important", "no_temporal_shifts")
zorders = [4, 3, 2]
experiment_zorder_dict = dict(zip(experiments, zorders))

### Load data

In [None]:
# Load the data for the experiments named in `experiments`. The result is
# indexed by experiment name in the cells below, e.g.
# `experiment_data[experiment]["master_mask"]` and
# `experiment_data[experiment]["X_train"]`.
experiment_data = load_experiment_data(experiments)

### Check that the masks are aligned

In [None]:
# Sanity check: every experiment is expected to use an identical master mask.
comp_masks = [experiment_data[experiment]["master_mask"] for experiment in experiments]
# `np.array_equal` requires matching shapes as well as matching elements, so
# masks that would merely broadcast against the first one are rejected too
# (an elementwise `==` followed by `np.all` would let those through).
assert all(
    np.array_equal(comp_masks[0], comp_mask) for comp_mask in comp_masks[1:]
), "The master masks of the compared experiments are not identical."

## Combining multiple ALE plots across models

In [None]:
# For each experiment, print its name followed by its sorted training feature
# names (one per line); a blank line follows both the name and the listing.
for experiment in experiments:
    print(experiment, end="\n\n")
    print(
        "\n".join(sort_features(experiment_data[experiment]["X_train"].columns)),
        end="\n\n",
    )

In [None]:
# For every (feature, lag) combination, draw one 1D ALE comparison plot across
# all experiments whose training data contains that feature.
for feature_name in tqdm(
    ("Dry Day Period", "SIF 3NN", "FAPAR 3NN", "LAI 3NN", "VOD Ku-band 3NN"),
    desc="Features",
):
    for lag in tqdm([0, 1, 3, 6, 9], desc="Lags"):
        # A lag of 0 uses the bare feature name; any other lag appends a
        # "-<lag> Month" suffix to it.
        feature = f"{feature_name} {-lag} Month" if lag else feature_name

        short_feature = shorten_features(feature)

        # One entry per experiment that contains this feature: the data
        # required to calculate the ALEs plus the per-experiment plot styling.
        model_X_cols_kwargs = [
            (
                experiment,
                experiment_data,
                feature,
                dict(
                    label=experiment_name_dict[experiment],
                    c=experiment_color_dict[experiment],
                    marker=experiment_marker_dict[experiment],
                    zorder=experiment_zorder_dict[experiment],
                ),
            )
            for experiment in experiments
            if feature in experiment_data[experiment]["X_train"]
        ]

        # We need at least two models for a comparison.
        if len(model_X_cols_kwargs) < 2:
            continue

        # Build the figure name from the shortened feature name: spaces become
        # underscores and everything is lower-cased.
        feature_slug = short_feature.replace(" ", "_").lower()
        fig_name = f"comp_{feature_slug}_ale_shifts"

        multi_model_ale_plot_1d(
            model_X_cols_kwargs,
            fig_name=fig_name,
            n_jobs=get_ncpus(),
            verbose=True,
            xlabel=short_feature,
            # title=f"First-order ALE for {short_feature}",
            figure_saver=figure_saver,
            figsize=(9, 1.5),
        )