#### Setup

In [None]:
from specific import *

### Specify the experiments to compare

In [None]:
# Names of the experiments to compare; this order is also the order in which
# they are loaded, checked and plotted further down.
experiments = (
    "15_most_important",
    "no_temporal_shifts",
    "fire_seasonality_paper",
)

### Load data

In [None]:
# Load the data for every experiment listed in `experiments`.
# The result is indexed by experiment name below, and the "master_mask",
# "model" and "X_train" entries of each experiment are read later in this file.
experiment_data = load_experiment_data(experiments)

### Check that the masks are aligned

In [None]:
# Collect the master mask of every experiment (same order as `experiments`).
comp_masks = [experiment_data[experiment]["master_mask"] for experiment in experiments]

# Compare every mask against the first one. `np.array_equal` is used instead of
# `np.all(a == b)` so that masks with different shapes are reported as a
# mismatch rather than being silently broadcast against each other (or failing
# with an unrelated comparison error).
mismatched_experiments = [
    experiment
    for experiment, comp_mask in zip(experiments[1:], comp_masks[1:])
    if not np.array_equal(comp_masks[0], comp_mask)
]
assert not mismatched_experiments, (
    f"Master masks of {mismatched_experiments} differ from that of "
    f"'{experiments[0]}'."
)

## Combining multiple ALE plots across models

In [None]:
# X_train.columns

In [None]:
# Human-readable experiment labels, used as the first part of each legend entry.
experiment_names = {
    "15_most_important": "With Lag",
    "no_temporal_shifts": "Unlagged",
    "fire_seasonality_paper": "All",
}

# Shared styling for de-emphasised lines: mid-grey, semi-transparent, dashed.
default_kwargs = dict(c="0.5", alpha=0.5, linestyle="--")


def _styled(colour, marker):
    """Return plot kwargs for a line with its own colour and marker."""
    return {"c": colour, "marker": marker}


def _muted(marker):
    """Return plot kwargs for a de-emphasised line that only varies in marker."""
    return {"marker": marker, **default_kwargs}


# Plot kwargs per experiment, keyed by the shortened feature name. Features
# without an entry here are plotted without any extra styling.
feature_kwargs_map = {
    "fire_seasonality_paper": {
        # Dry days.
        "Dry Days": _styled("C0", "o"),
        "Dry Days 1 M": _muted("v"),
        "Dry Days 3 M": _styled("C3", "x"),
        "Dry Days 6 M": _muted("^"),
        "Dry Days 9 M": _muted(">"),
        # SIF.
        "SIF": _styled("C0", "o"),
        "SIF 1 M": _muted("v"),
        "SIF 3 M": _muted("^"),
        "SIF 6 M": _styled("C3", "x"),
        "SIF 9 M": _muted(">"),
        # FAPAR.
        "FAPAR": _styled("C0", "o"),
        "FAPAR 1 M": _muted("v"),
        "FAPAR 3 M": _muted("^"),
        "FAPAR 6 M": _muted("<"),
        "FAPAR 9 M": _muted(">"),
        # LAI.
        "LAI": _styled("C0", "o"),
        "LAI 1 M": _muted("v"),
        "LAI 3 M": _styled("C3", "x"),
        "LAI 6 M": _muted("^"),
        "LAI 9 M": _muted(">"),
        # VOD.
        "VOD": _styled("C0", "o"),
        "VOD 1 M": _styled("C1", "v"),
        "VOD 3 M": _styled("C2", "x"),
        "VOD 6 M": _styled("C3", "^"),
        "VOD 9 M": _muted(">"),
    },
    "15_most_important": {
        "Dry Days": _styled("C1", "o"),
        "Dry Days 3 M": _styled("C4", "x"),
        "SIF": _styled("C1", "o"),
        "SIF 6 M": _styled("C4", "x"),
        "FAPAR": _styled("C1", "o"),
        "LAI": _styled("C1", "o"),
        "LAI 3 M": _styled("C4", "x"),
        "VOD": _styled("C4", "o"),
        "VOD 1 M": _styled("C5", "v"),
        "VOD 3 M": _styled("C6", "x"),
        "VOD 6 M": _styled("C7", "^"),
    },
    "no_temporal_shifts": {
        "Dry Days": _styled("C2", "o"),
        "SIF": _styled("C2", "o"),
        "FAPAR": _styled("C2", "o"),
        "LAI": _styled("C2", "o"),
        "VOD": _styled("C8", "o"),
    },
}

# Marker size overrides keyed by marker symbol. Markers without an entry
# (currently every marker, since the mapping is empty) use a size of 4.
markersizes = {}

for feature_name in tqdm(
    ("Dry Day Period", "SIF", "FAPAR", "LAI", "VOD Ku-band"),
    desc="Multiple shift ALE plots",
):
    # Figure name: lower-cased base feature name with spaces as underscores.
    feature_slug = feature_name.replace(" ", "_").lower()
    fig_name = f"comp_{feature_slug}_ale_shifts"

    # The base feature first, followed by its -1, -3, -6 and -9 month variants.
    candidate_features = [feature_name]
    candidate_features.extend(
        f"{feature_name} {m} Month" for m in (-1, -3, -6, -9)
    )

    # One (model, training data, column) triple and one kwargs dict per line.
    model_X_cols = []
    plot_kwargs_list = []

    for experiment in experiments:
        data = experiment_data[experiment]

        for feature in candidate_features:
            # Features absent from this experiment's X_train are skipped.
            if feature in data["X_train"]:
                model_X_cols.append((data["model"], data["X_train"], feature))

                short_feature = shorten_features(feature)

                # Start from the legend label, then layer any styling defined
                # for this experiment and shortened feature name on top.
                plot_kwargs = {
                    "label": f"{experiment_names[experiment]} - {short_feature}"
                }
                plot_kwargs.update(
                    feature_kwargs_map[experiment].get(short_feature, {})
                )
                # Marker size depends on the marker chosen above (if any).
                plot_kwargs["ms"] = markersizes.get(plot_kwargs.get("marker"), 4)

                plot_kwargs_list.append(plot_kwargs)

    # At least one experiment must contain the feature (or a shifted variant).
    assert model_X_cols and plot_kwargs_list

    axis_label = f"{shorten_features(feature_name)}"

    multi_model_ale_plot_1d(
        model_X_cols,
        plot_kwargs_list,
        fig_name=fig_name,
        n_jobs=get_ncpus(),
        verbose=True,
        xlabel=axis_label,
        title=f"First-order ALE for {axis_label}",
        figure_saver=figure_saver,
    )