#### Setup

In [None]:
from specific import *

### Specify the experiments to compare

In [None]:
experiment_zorder_dict = {
    # Drawing order passed as matplotlib's `zorder` for each experiment
    # (higher values are drawn on top of lower ones).
    "all": 7,
    "15_most_important": 6,
    "no_temporal_shifts": 5,
    # Plain single-variable experiments share one level...
    **dict.fromkeys(("fapar_only", "lai_only", "sif_only", "vod_only"), 3),
    # ...and their lagged counterparts share a lower one.
    **dict.fromkeys(
        (
            "lagged_fapar_only",
            "lagged_lai_only",
            "lagged_sif_only",
            "lagged_vod_only",
        ),
        2,
    ),
    "best_top_15": 4,
}

# Update the colour and marker dicts for the single-variable experiments.
#
# The same colour/marker is mirrored across all single-parameter experiments
# (one pair for the plain ones, another for the lagged ones), since there can be
# no clashes between them (except for the dry days case).
for single_var_exps, colour, marker in (
    (
        ("fapar_only", "lai_only", "sif_only", "vod_only"),
        experiment_colors[4],
        "|",
    ),
    (
        (
            "lagged_fapar_only",
            "lagged_lai_only",
            "lagged_sif_only",
            "lagged_vod_only",
        ),
        experiment_colors[5],
        "^",
    ),
):
    experiment_color_dict.update(dict.fromkeys(single_var_exps, colour))
    experiment_marker_dict.update(dict.fromkeys(single_var_exps, marker))
# Keep the notebook namespace free of the loop variables.
del single_var_exps, colour, marker

# Per-experiment plotting keyword arguments (legend label, colour, marker and
# drawing order), keyed by experiment name.
experiment_plot_kwargs = {
    experiment: {
        "label": label,
        "c": experiment_color_dict[experiment],
        "marker": experiment_marker_dict[experiment],
        "zorder": experiment_zorder_dict[experiment],
    }
    for experiment, label in experiment_name_dict.items()
}

### Load data

In [None]:
# Load the stored results for every experiment named in `experiment_name_dict`.
# The entries below are passed as `ignore` - presumably so that these large
# datasets, which the plots in this notebook do not use, are skipped (confirm
# against `load_experiment_data`).
experiment_data = load_experiment_data(
    [*experiment_name_dict],
    ignore=(
        "endog_data",
        "exog_data",
        "filled_datasets",
        "masked_datasets",
        "land_mask",
    ),
)

### Check that all experiments share the same master mask

In [None]:
# The comparisons below assume that every experiment uses the same master mask.
comp_masks = [
    experiment_data[experiment]["master_mask"] for experiment in experiment_name_dict
]
# `np.array_equal` compares shapes as well as values. `np.all(a == b)` would
# broadcast, so it could accept masks of different (broadcast-compatible)
# shapes, and it warns or errors on incompatible ones depending on NumPy.
mismatched_experiments = [
    experiment
    for experiment, comp_mask in zip(experiment_name_dict, comp_masks)
    if not np.array_equal(comp_masks[0], comp_mask)
]
if mismatched_experiments:
    raise ValueError(
        f"Master masks of {mismatched_experiments} differ from that of "
        f"{next(iter(experiment_name_dict))!r}."
    )

## Combining multiple ALE plots across models

In [None]:
# Legend anchor position for each feature's multi-model ALE plot. All features
# currently share the same position.
legend_bboxes = dict.fromkeys(
    ("Dry Day Period", "FAPAR", "LAI", "SIF", "VOD Ku-band"),
    (0.5, 1.02),
)

# One multi-model ALE figure per feature. `figure_saver` is handed over so the
# figure can be saved.
for feature, legend_bbox in tqdm(
    legend_bboxes.items(),
    total=len(legend_bboxes),
    desc="Features",
):
    filled_feature_names = get_filled_names(feature)
    multi_model_ale_1d(
        filled_feature_names,
        experiment_data,
        experiment_plot_kwargs,
        n_jobs=get_ncpus(),
        verbose=False,
        legend_bbox=legend_bbox,
        figure_saver=figure_saver,
    )

### Combine the lagged ALE plots of pairs of variables into single multi-model figures

In [None]:
# Combined figures: for each pair of vegetation variables, plot the ALE of each
# variable (one column per variable) at every lag in `lags` (one row per lag),
# comparing the experiments in `template_comp_exps`.
#
# "xxx" is a placeholder for the short, lower-case variable name used in
# experiment names, e.g. "fapar_only" -> "xxx_only" -> "lai_only".
template_comp_exps = [
    exp.replace("fapar", "xxx")
    for exp in sort_experiments(
        [
            "all",
            "best_top_15",
            "15_most_important",
            "no_temporal_shifts",
            "fapar_only",
            "lagged_fapar_only",
        ]
    )
]
lags = (0, 1, 3, 6, 9)
rotation = 25
# Matches the lag suffix (e.g. "3M") in an x-axis label. A raw string avoids the
# invalid "\d" escape sequence, and `\d+` captures multi-digit lags in full
# (a single `\d` would turn "12M" into "2M").
lag_label_pattern = re.compile(r"(\d+M)")

for comp_vars in [["FAPAR", "LAI"], ["SIF", "VOD Ku-band"]]:
    # Check that the required data is available: every lagged version of each
    # variable must be present in at least one of the compared experiments.
    for unfilled_comp_var, comp_var in zip(comp_vars, get_filled_names(comp_vars)):
        # Experiment names only use the first word of the variable name
        # (e.g. "vod_only" for "VOD Ku-band"), so "vod ku-band_only" must not be
        # generated.
        exp_var_name = unfilled_comp_var.split()[0].lower()
        comp_exps = [exp.replace("xxx", exp_var_name) for exp in template_comp_exps]
        for lag in lags:
            feature = f"{comp_var} {-lag} Month" if lag else comp_var
            if not any(
                feature in experiment_data[exp]["X_train"].columns for exp in comp_exps
            ):
                raise ValueError(
                    f"Feature {feature!r} is missing from all of {comp_exps}."
                )

    # Generic legend labels that apply to every vegetation variable: the FAPAR
    # experiment names with "FAPAR" replaced by "X".
    legend_labels = [
        experiment_name_dict[template_exp.replace("xxx", "fapar")].replace("FAPAR", "X")
        for template_exp in template_comp_exps
    ]

    # One row per lag, one column per variable.
    n_rows = len(lags)
    fig, axes = plt.subplots(
        n_rows,
        2,
        sharex="col",
        figsize=np.array([3.525, 0.9793]) * np.array([2, n_rows]),
    )

    # Left column: first variable (legend enabled here).
    multi_model_ale_1d(
        get_filled_names(comp_vars[0]),
        experiment_data,
        experiment_plot_kwargs,
        n_jobs=get_ncpus(),
        verbose=False,
        legend_bbox=(0.5, 1.01),
        fig=fig,
        axes=axes[:, 0:1],
        legend_labels=legend_labels,
        lags=lags,
        rotation=rotation,
    )
    # Right column: second variable (legend disabled).
    multi_model_ale_1d(
        get_filled_names(comp_vars[1]),
        experiment_data,
        experiment_plot_kwargs,
        n_jobs=get_ncpus(),
        verbose=False,
        legend=False,
        fig=fig,
        axes=axes[:, 1:2],
        lags=lags,
        rotation=rotation,
    )

    # Only the left column keeps y labels. Each is annotated with the lag parsed
    # from the panel's x label, which must happen before the x labels are
    # cleared below.
    for ax in axes[:, 1]:
        ax.set_ylabel("")
    for ax in axes[:, 0]:
        lag_match = lag_label_pattern.search(ax.get_xlabel())
        lag_m = f" ({lag_match.group(1)})" if lag_match else ""
        ax.set_ylabel(f"ALE{lag_m} (BA)")

    # Replace the per-panel x labels with one label (with units) under each
    # column.
    for ax in axes.flatten():
        ax.set_xlabel("")
    for ax, xlabel_var in zip(axes[-1], shorten_features(comp_vars)):
        ax.set_xlabel(add_units(xlabel_var), va="center_baseline", labelpad=6)

    # Panel letters "(a)", "(b)", ... above each subplot, row by row.
    for ax, title in zip(axes.flatten(), ascii_lowercase):
        ax.text(0.5, 1.06, f"({title})", transform=ax.transAxes)

    fig.tight_layout(h_pad=0.4, w_pad=0.5)
    fig.align_labels()

    shortened_vars = [
        var.replace(" ", "_").lower()
        for var in shorten_features(repl_fill_names(comp_vars))
    ]

    figure_saver.save_figure(
        fig,
        f"{'__'.join(shortened_vars)}_ale_comp",
        sub_directory="ale_comp",
    )