## Setup

In [None]:
from specific import *

In [None]:
# Unpack the six objects returned by `get_offset_data()` (a helper made available by
# `from specific import *`).
# NOTE(review): the names suggest target/predictor data plus several masks; the exact
# types are defined outside this notebook - confirm there.
(
    endog_data,
    exog_data,
    master_mask,
    filled_datasets,
    masked_datasets,
    land_mask,
) = get_offset_data()

### Retrieve previous results from the 'model' notebook

In [None]:
# Load the stored train/test split instead of recomputing it in this notebook.
X_train, X_test, y_train, y_test = data_split_cache.load()
# NOTE(review): presumably the model fitted in the 'model' notebook - confirm `get_model`.
rf = get_model()

### SHAP values

In [None]:
# Load previously stored SHAP values. Later cells index this as
# `shap_values[:, feature_index]`, i.e. one column per feature.
shap_values = shap_cache.load()

### BA in the train and validation sets

Valid elements are located where `master_mask` is `False`.

In [None]:
# Histogram of the absolute SHAP values (all features pooled), on log-log axes.
plt.hist(np.abs(shap_values).ravel(), bins=5000)
plt.gca().set_xscale("log")
plt.gca().set_yscale("log")

### Calculate 2D masked array SHAP values

In [None]:
# Result is used below as a dict keyed by measure name (e.g. "masked_shap_arrs"), each
# entry holding at least a "data" sequence with one 2D masked array per column of X_train.
map_shap_results = calculate_2d_masked_shap_values(X_train, master_mask, shap_values)

### Plotting maps of SHAP values

In [None]:
# Plot the SHAP maps for all samples; figures go under "shap_maps/normal".
plot_shap_value_maps(
    X_train, map_shap_results, map_figure_saver, directory=Path("shap_maps") / "normal",
)

### High BA month mask

In [None]:
# Masked array of the target values built from `endog_data` and `master_mask`.
# NOTE(review): the "BA" naming suggests burned area - confirm against `get_offset_data`.
target_ba = get_masked_array(endog_data.values, master_mask)
# Reductions along axis 0. `mean_ba` is indexed as `mean_ba[i, j]` by later cells;
# `max_ba` is not referenced again in the visible part of the notebook.
mean_ba = np.ma.mean(target_ba, axis=0)
max_ba = np.ma.max(target_ba, axis=0)

In [None]:
# Output directory for the diagnostic figure saved at the end of this cell.
root_dir = Path("weighted_shap_maps") / "high_ba"

# `high_ba_mask` follows the masked-array convention: True marks samples to ignore.
# A sample is kept only if it exceeds twice the per-pixel mean AND the per-pixel maximum
# exceeds 1e-2. Both statistics are reduced along axis 0. The maximum used to be taken
# over the whole array, which yields a single scalar and therefore could not
# discriminate between pixels.
high_ba_mask = (
    ~(
        (target_ba > (2 * np.ma.mean(target_ba, axis=0)))
        & (np.ma.max(target_ba, axis=0) > 1e-2)
    )
).data
# Samples that are invalid in the target are always ignored.
high_ba_mask |= target_ba.mask
# Ignore a pixel completely if fewer than 9 of its samples are masked at this point.
# NOTE(review): together with the hard-coded 12 below this assumes 12 samples along
# axis 0 - confirm.
high_ba_mask |= np.sum(high_ba_mask, axis=0) < 9

# Number of ignored samples per pixel.
high_ba_sum = np.sum(high_ba_mask, axis=0)

# Number of retained samples per pixel; pixels without any retained sample are masked.
plot_sum = np.ma.MaskedArray(12 - high_ba_sum, mask=np.isclose(high_ba_sum, 12))

# One colour bin per integer count, with the bin edges halfway between integers.
boundaries = np.arange(np.min(plot_sum), np.max(plot_sum) + 2) - 0.5
fig, cbar = cube_plotting(
    plot_sum,
    boundaries=boundaries,
    return_cbar=True,
    colorbar_kwargs={"label": "nr. valid samples"},
    fig=plt.figure(figsize=(5.1, 2.6)),
)
# Put integer tick labels at the centre of each bin.
tick_centres = get_centres(boundaries)
cbar.set_ticks(tick_centres)
cbar.set_ticklabels(list(map(int, tick_centres)))

map_figure_saver.save_figure(fig, "high_ba_n_valid", sub_directory=root_dir)

#### Calculate SHAP results only for those voxels with high BA

In [None]:
# Same computation as for `map_shap_results`, with `high_ba_mask` passed as an
# additional mask.
high_ba_map_shap_results = calculate_2d_masked_shap_values(
    X_train, master_mask, shap_values, additional_mask=high_ba_mask
)

#### Plotting corresponding maps of high BA SHAP values

In [None]:
# Plot the SHAP maps restricted by the high BA mask; figures go under "shap_maps/high_ba".
plot_shap_value_maps(
    X_train,
    high_ba_map_shap_results,
    map_figure_saver,
    directory=Path("shap_maps") / "high_ba",
)

### Rank SHAP masks

In [None]:
max_month = 9

# Mark areas where at least this fraction of |SHAP| has been plotted previously.
thres = 0.7

close_figs = True


def param_iter():
    """Yield every rank-map configuration, showing one progress bar per nesting level.

    Each item has the form
    ``((exc_name, exclude_inst), feature_name, (filter_name, shap_results), shap_measure)``.
    """
    for inst_option in tqdm(
        [("with_inst", False), ("no_inst", True)], desc="Exclude inst."
    ):
        for feature in tqdm(
            ["VOD Ku-band", "SIF", "FAPAR", "LAI", "Dry Day Period"], desc="Feature"
        ):
            for results_option in tqdm(
                [("normal", map_shap_results), ("high_ba", high_ba_map_shap_results)],
                desc="SHAP results",
            ):
                for measure in tqdm(
                    ["masked_max_shap_arrs", "masked_abs_shap_arrs"],
                    desc="SHAP measure",
                ):
                    yield inst_option, feature, results_option, measure


# For every configuration, rank the lags of one feature by |SHAP| at each pixel and save
# one map per rank showing which lag occupies that rank.
for (
    (exc_name, exclude_inst),
    feature_name,
    (filter_name, shap_results),
    shap_measure,
) in islice(param_iter(), 0, None):
    short_feature = shorten_features(feature_name)

    sub_directory = Path("rank_shap_maps") / filter_name / shap_measure / exc_name

    # Columns belonging to the current feature (limited by `max_month`) and their lags.
    filtered = np.array(filter_by_month(X_train.columns, feature_name, max_month))
    lags = np.array(
        [get_lag(feature, target_feature=feature_name) for feature in filtered]
    )

    # Ensure lags are sorted consistently.
    lag_sort_inds = np.argsort(lags)
    filtered = tuple(filtered[lag_sort_inds])
    lags = tuple(lags[lag_sort_inds])

    if exclude_inst and 0 in lags:
        # After sorting, the lag 0 ("Inst.") column has to be the first entry.
        assert lags[0] == 0
        lags = lags[1:]
        filtered = filtered[1:]

    n_features = len(filtered)

    # There is no point plotting this map for a single feature or less since we are
    # interested in a comparison between different feature ranks.
    if n_features <= 1:
        continue

    # One 2D map per lag, stored in lag order. Copies are taken because the masks are
    # overwritten with the shared mask below.
    selected_data = np.empty(n_features, dtype=object)
    for i, col in enumerate(X_train.columns):
        if col in filtered:
            selected_data[lags.index(get_lag(col))] = shap_results[shap_measure][
                "data"
            ][i].copy()

    # A pixel is only used if it is valid for every lag.
    shared_mask = reduce(np.logical_or, (data.mask for data in selected_data))
    for data in selected_data:
        data.mask = shared_mask

    stacked_abs = np.abs(np.vstack([data.data[np.newaxis] for data in selected_data]))
    # Indices in descending order.
    sort_indices = np.argsort(stacked_abs, axis=0)[::-1]

    # Maintain the same colors even if fewer colors are used.
    colors = [lag_color_dict[lag] for lag in lags]

    cmap, norm = from_levels_and_colors(
        levels=np.arange(n_features + 1), colors=colors, extend="neither",
    )

    # Colorbar tick labels: one per lag, identical for every rank map of this feature.
    labels = [f"{lag} M" if lag else "Inst." for lag in lags]

    # Hatch configuration for the (currently disabled) overlay below. It does not depend
    # on the rank, so it is set once per configuration.
    style = 2

    if style == 1:
        # Stippling for significant areas.
        mpl.rc("hatch", linewidth=0.2)
        hatches = ["." * 6, None]
    else:
        # Hatching for insignificant areas.
        mpl.rc("hatch", linewidth=0.1)
        hatches = ["/" * 14, None]

    sum_shap = np.ma.MaskedArray(np.sum(stacked_abs, axis=0), mask=shared_mask)
    # Running fraction of the total |SHAP| covered by the ranks plotted so far.
    already_plotted = np.zeros_like(sum_shap)

    # zip() stops at the shorter input, so at most five ranks are plotted.
    for i, rank in zip(
        tqdm(range(n_features), desc="Plotting", leave=False),
        ["1st", "2nd", "3rd", "4th", "5th"],
    ):
        cube = dummy_lat_lon_cube(np.ma.MaskedArray(sort_indices[i], mask=shared_mask))

        fig, ax = plt.subplots(
            figsize=(5.1, 2.6), subplot_kw={"projection": ccrs.Robinson()}
        )

        # XXX: Temporary, since this takes an exorbitant amount of time to render.
        #         if np.any(already_plotted >= thres):
        #             ax.contourf(
        #                 cube.coord("longitude").points,
        #                 cube.coord("latitude").points,
        #                 already_plotted,
        #                 transform=ccrs.PlateCarree(),
        #                 colors="none",
        #                 zorder=4,
        #                 levels=[thres, 1],
        #                 hatches=hatches,
        #             )

        fig, cbar = cube_plotting(
            cube,
            title=f"{rank} |SHAP {short_feature} Lag| - thres: {thres * 100:0.0f}%",
            fig=fig,
            ax=ax,
            cmap=cmap,
            norm=norm,
            return_cbar=True,
            colorbar_kwargs={"label": short_feature},
            coastline_kwargs={"linewidth": 0.3},
        )

        # Label the colorbar using the feature names.
        cbar.set_ticks(np.arange(n_features) + 0.5)
        cbar.set_ticklabels(labels)

        # |SHAP| of the lag occupying the current rank at each pixel.
        data = np.take_along_axis(stacked_abs, sort_indices[i : i + 1], axis=0)[0]
        already_plotted += data / sum_shap

        map_figure_saver.save_figure(
            fig, f"{rank}_{short_feature}", sub_directory=sub_directory
        )
        # Honour the `close_figs` switch defined at the top of the cell.
        if close_figs:
            plt.close(fig)

### Use significant peak finding algorithm to determine mean timing of maximum impact using SHAP values

#### Investigate the role of the significance parameters

In [None]:
max_month = 9

# Mark areas where at least this fraction of |SHAP| has been plotted previously.
thres = 0.7


def param_iter():
    """Yield ``(exclude_inst, feature_name)`` pairs with a progress bar per level."""
    for inst_flag in tqdm([False, True], desc="Exclude inst."):
        for feature in tqdm(
            ["VOD Ku-band", "SIF", "FAPAR", "LAI", "Dry Day Period"], desc="Feature"
        ):
            yield inst_flag, feature


# For each configuration, sweep the two significance parameters and plot the percentage
# of valid pixels for which `significant_peak` reports a significant peak.
for exclude_inst, feature_name in islice(param_iter(), 0, None):
    # Columns belonging to the current feature (limited by `max_month`) and their lags.
    filtered = np.array(filter_by_month(X_train.columns, feature_name, max_month))
    lags = np.array(
        [get_lag(feature, target_feature=feature_name) for feature in filtered]
    )

    # Ensure lags are sorted consistently.
    lag_sort_inds = np.argsort(lags)
    filtered = tuple(filtered[lag_sort_inds])
    lags = tuple(lags[lag_sort_inds])

    if exclude_inst and 0 in lags:
        # After sorting, the lag 0 column has to be the first entry.
        assert lags[0] == 0
        lags = lags[1:]
        filtered = filtered[1:]

    n_features = len(filtered)

    # There is no point plotting this map for a single feature or less since we are
    # interested in a comparison between different feature ranks.
    if n_features <= 1:
        continue

    short_feature = shorten_features(feature_name)

    # One 2D map per lag, stored in lag order. Copies are taken because the masks are
    # overwritten with the shared mask below.
    selected_data = np.empty(n_features, dtype=object)
    for i, col in enumerate(X_train.columns):
        if col in filtered:
            selected_data[lags.index(get_lag(col))] = map_shap_results[
                "masked_abs_shap_arrs"
            ]["data"][i].copy()

    shared_mask = reduce(np.logical_or, (data.mask for data in selected_data))
    for data in selected_data:
        data.mask = shared_mask

    stacked_shaps = np.vstack([data.data[np.newaxis] for data in selected_data])

    # Calculate the significance of the global maxima for each of the valid pixels.

    # Valid indices are recorded in 'shared_mask'.

    valid_i, valid_j = np.where(~shared_mask)
    total_valid = len(valid_i)

    def get_true_significances(kwargs):
        """Count the valid pixels for which `significant_peak` returns True.

        Args:
            kwargs: Dict containing "ptp_threshold_factor" (multiplied by the per-pixel
                `mean_ba` to obtain `ptp_threshold`); all remaining entries are
                forwarded to `significant_peak` as keyword arguments. The dict is not
                modified.

        Returns:
            Number of valid pixels with a True result, 0 if there are none.
        """
        # Work on a copy so the caller's dict keeps its "ptp_threshold_factor" entry,
        # which is read again when the results are grouped for plotting.
        peak_kwargs = dict(kwargs)
        ptp_threshold_factor = peak_kwargs.pop("ptp_threshold_factor")

        significant = []
        for i, j in zip(valid_i, valid_j):
            ptp_threshold = ptp_threshold_factor * mean_ba[i, j]
            significant.append(
                significant_peak(
                    stacked_shaps[:, i, j], ptp_threshold=ptp_threshold, **peak_kwargs
                )
            )
        counts = dict(zip(*np.unique(significant, return_counts=True)))
        # Strict parameter combinations may leave no significant pixel at all, in which
        # case there is no `True` key.
        return counts.get(True, 0)

    with concurrent.futures.ProcessPoolExecutor(max_workers=get_ncpus()) as executor:
        diff_thresholds = np.linspace(0.01, 0.99, 15)
        ptp_threshold_factors = np.linspace(0, 1.5, 8)

        # One task per (diff_threshold, ptp_threshold_factor) combination.
        kwargs_list = [
            {
                "diff_threshold": diff_threshold,
                "ptp_threshold_factor": ptp_threshold_factor,
            }
            for diff_threshold, ptp_threshold_factor in product(
                diff_thresholds, ptp_threshold_factors
            )
        ]

        fs = [executor.submit(get_true_significances, kwargs) for kwargs in kwargs_list]
        # Only used to display progress; results are collected in submission order below.
        for f in tqdm(
            concurrent.futures.as_completed(fs),
            total=len(fs),
            desc="Varying significances",
            smoothing=0,
        ):
            pass

        plt.figure(figsize=(12, 8))

        perc_sigs = 100 * np.array([f.result() for f in fs]) / total_valid

        # One curve per ptp threshold factor, plotted against the diff threshold.
        for ptp_threshold_factor in ptp_threshold_factors:
            ptp_indices = [
                i
                for i, kwargs in enumerate(kwargs_list)
                if np.isclose(kwargs["ptp_threshold_factor"], ptp_threshold_factor)
            ]
            plt.plot(
                diff_thresholds,
                perc_sigs[ptp_indices],
                label=f"ptp_thres: {ptp_threshold_factor:0.2f}",
            )

        plt.title(f"{short_feature} - Exclude Inst: {exclude_inst}")
        plt.xlabel("Diff threshold")
        plt.ylabel("% significant")
        plt.grid(linestyle="--", alpha=0.4)
        plt.legend(loc="best")
#### Plot sig. maps

In [None]:
# Significance parameters used by the weighted SHAP peak cell below.
diff_threshold = 0.5
ptp_threshold_factor = 0.12  # relative to the mean (multiplied by `mean_ba` per pixel)

### Plot weighted SHAP peak locations

In [None]:
max_month = 9
close_figs = True


def param_iter():
    """Yield every weighted-peak configuration, with a progress bar per nesting level.

    Each item has the form
    ``((exc_name, exclude_inst), feature_name, (filter_name, shap_results), shap_measure)``.
    """
    for inst_option in tqdm(
        [("with_inst", False), ("no_inst", True)], desc="Exclude inst."
    ):
        for feature in tqdm(
            ["VOD Ku-band", "SIF", "FAPAR", "LAI", "Dry Day Period"], desc="Feature"
        ):
            for results_option in tqdm(
                [("normal", map_shap_results), ("high_ba", high_ba_map_shap_results)],
                desc="SHAP results",
            ):
                for measure in tqdm(
                    [
                        "masked_max_shap_arrs",
                        "masked_abs_shap_arrs",
                        "masked_shap_arrs",
                    ],
                    desc="SHAP measure",
                ):
                    yield inst_option, feature, results_option, measure


# For every configuration, map the |SHAP|-weighted mean lag at all pixels whose lag
# series passes the `significant_peak` test.
for (
    (exc_name, exclude_inst),
    feature_name,
    (filter_name, shap_results),
    shap_measure,
) in islice(param_iter(), 0, None):
    short_feature = shorten_features(feature_name)

    sub_directory = Path("weighted_shap_maps") / filter_name / shap_measure / exc_name

    # Columns belonging to the current feature (limited by `max_month`) and their lags.
    filtered = np.array(filter_by_month(X_train.columns, feature_name, max_month))
    lags = np.array(
        [get_lag(feature, target_feature=feature_name) for feature in filtered]
    )

    # Ensure lags are sorted consistently.
    lag_sort_inds = np.argsort(lags)
    filtered = tuple(filtered[lag_sort_inds])
    lags = tuple(lags[lag_sort_inds])

    if exclude_inst and 0 in lags:
        # After sorting, the lag 0 column has to be the first entry.
        assert lags[0] == 0
        lags = lags[1:]
        filtered = filtered[1:]

    n_features = len(filtered)

    # There is no point plotting this map for a single feature or less since we are
    # interested in a comparison between different feature ranks.
    if n_features <= 1:
        continue

    # One 2D map per lag, stored in lag order. Copies are taken because the masks are
    # overwritten with the shared mask below.
    selected_data = np.empty(n_features, dtype=object)
    for i, col in enumerate(X_train.columns):
        if col in filtered:
            selected_data[lags.index(get_lag(col))] = shap_results[shap_measure][
                "data"
            ][i].copy()

    shared_mask = reduce(np.logical_or, (data.mask for data in selected_data))
    for data in selected_data:
        data.mask = shared_mask

    stacked_shaps = np.vstack([data.data[np.newaxis] for data in selected_data])

    # Calculate the significance of the global maxima for each of the valid pixels.

    # Valid indices are recorded in 'shared_mask'.

    valid_i, valid_j = np.where(~shared_mask)
    total_valid = len(valid_i)

    # Loop invariants, computed once instead of for every pixel.
    lag_values = np.array(lags)
    abs_shaps = np.abs(stacked_shaps)

    # Fully masked to begin with; assigning to an element below unmasks it.
    max_positions = np.ma.MaskedArray(
        np.zeros_like(shared_mask, dtype=np.float64), mask=True
    )
    for i, j in zip(tqdm(valid_i, desc="Evaluating maxima", smoothing=0), valid_j):
        ptp_threshold = ptp_threshold_factor * mean_ba[i, j]
        if significant_peak(
            stacked_shaps[:, i, j],
            diff_threshold=diff_threshold,
            ptp_threshold=ptp_threshold,
        ):
            # If the maximum is significant, go on to calculate the weighted average
            # lag, using |SHAP| as the weights.
            weights = abs_shaps[:, i, j]
            max_positions[i, j] = np.sum(lag_values * weights) / np.sum(weights)

    #
    fig = cube_plotting(
        max_positions,
        title=f"|SHAP| Weighted Maximum {short_feature} - Exclude Inst: {exclude_inst}",
        fig=plt.figure(figsize=(5.1, 2.6)),
        colorbar_kwargs={"label": short_feature, "format": "%0.1f"},
        coastline_kwargs={"linewidth": 0.3},
        #         boundaries=np.arange(0, max(lags)+1),
    )
    map_figure_saver.save_figure(
        fig, f"weighted_shap_{short_feature}", sub_directory=sub_directory
    )
    if close_figs:
        # Close the figure that was just saved rather than whichever one is current.
        plt.close(fig)
    #

    # Plot example series.

#     plt.figure(figsize=(12, 8))
#     plt.title(f"{short_feature} - Exclude Inst: {exclude_inst}")
#     for i in np.random.RandomState(0).choice(total_valid, 100, False):
#         plot_i = valid_i[i]
#         plot_j = valid_j[i]
#         plot_data = stacked_shaps[:, plot_i, plot_j]
#         plt.plot(lags, plot_data / np.max(plot_data), c="C0", alpha=0.4)
#     plt.xlabel(lags)
#     plt.ylabel("|SHAP|")
#     plt.grid(linestyle="--", alpha=0.4)
# #     plt.legend(loc='best')

### Analyse SHAP -- BA relationship for selected regions

In [None]:
# Compare BA and SHAP values of a single feature for all pixels within a
# latitude/longitude box.

# Colormap used to colour the individual pixel series in the plots below.
ba_cmap = plt.get_cmap("inferno")

fig, ax = plt.subplots(
    1, 1, figsize=(14, 6), subplot_kw=dict(projection=ccrs.Robinson())
)

feature_name = "Dry Days"
# Position of the feature among the shortened column names; used both for the mean SHAP
# map and for selecting the raw SHAP values.
X_i = list(map(shorten_features, X_train.columns)).index(feature_name)
results_dict = map_shap_results["masked_shap_arrs"]
fig = cube_plotting(
    results_dict["data"][X_i],
    fig=fig,
    ax=ax,
    title=f"Mean SHAP value for '{shorten_features(feature_name)}'",
    nbins=7,
    vmin=results_dict["vmin"],
    vmax=results_dict["vmax"],
    log=True,
    log_auto_bins=False,
    extend="neither",
    min_edge=1e-3,
    cmap="Spectral_r",
    cmap_midpoint=0,
    cmap_symmetric=True,
    colorbar_kwargs={
        "format": "%0.1e",
        "label": f"SHAP ('{shorten_features(feature_name)}')",
    },
    coastline_kwargs={"linewidth": 0.3},
)

# lat_range = (-5, 4)
lat_range = (-5, 0)
lon_range = (12.5, 26.5)

# Outline the selected region on the global map.
ax.set_global()
ax.add_patch(
    mpatches.Rectangle(
        xy=[min(lon_range), min(lat_range)],
        width=np.ptp(lon_range),
        height=np.ptp(lat_range),
        facecolor="none",
        edgecolor="blue",
        alpha=0.8,
        lw=2,
        transform=ccrs.PlateCarree(),
    )
)


# _ = ax.gridlines()

# Training-set BA restricted to the selected region.
ba_train_cube = dummy_lat_lon_cube(get_mm_data(y_train.values, master_mask, "train"))
ba_subset = ba_train_cube.intersection(latitude=lat_range).intersection(
    longitude=lon_range
)
cube_plotting(ba_subset, log=True)

mm_valid_indices, mm_valid_train_indices, mm_valid_val_indices = get_mm_indices(
    master_mask
)
# Only as many training indices as there are SHAP samples.
mm_kind_indices = mm_valid_train_indices[: shap_values.shape[0]]

# Convert 1D shap values into 3D array (time, lat, lon).
masked_shap_comp = np.ma.MaskedArray(
    np.zeros_like(master_mask, dtype=np.float64), mask=np.ones_like(master_mask)
)
masked_shap_comp.ravel()[mm_kind_indices] = shap_values[:, X_i]

# if additional_mask is not None:
#     masked_shap_comp.mask |= match_shape(additional_mask, masked_shap_comp.shape)

# Consistency check against the mean map computed earlier.
assert np.all(
    np.isclose(results_dict["data"][X_i], np.ma.mean(masked_shap_comp, axis=0))
)

shap_cube = dummy_lat_lon_cube(masked_shap_comp)
shap_subset = shap_cube.intersection(latitude=lat_range).intersection(
    longitude=lon_range
)

# Pixel positions within the region (all axes after the first).
indices = list(np.ndindex(ba_subset.shape[1:]))
# Colours are spread over [0, 1]; the max() guards against a division by zero when the
# region contains a single pixel.
colour_scale = max(len(indices) - 1, 1)
shap_label = f"SHAP({feature_name})"

# BA series of every pixel.
plt.figure()
for i, index in enumerate(tqdm(indices, desc="Plotting")):
    plt.plot(
        ba_subset.data[(slice(None), *index)],
        marker="o",
        alpha=0.4,
        c=ba_cmap(i / colour_scale),
    )
_ = plt.ylabel("BA")

# SHAP series of every pixel.
plt.figure()
for i, index in enumerate(tqdm(indices, desc="Plotting")):
    plt.plot(
        shap_subset.data[(slice(None), *index)],
        marker="o",
        alpha=0.4,
        c=ba_cmap(i / colour_scale),
    )
_ = plt.ylabel(shap_label)

# BA against SHAP for every pixel.
plt.figure(figsize=(15, 15))
for i, index in enumerate(tqdm(indices, desc="Plotting")):
    plt.plot(
        shap_subset.data[(slice(None), *index)],
        ba_subset.data[(slice(None), *index)],
        marker="o",
        alpha=0.4,
        c=ba_cmap(i / colour_scale),
        linestyle="",
    )
plt.ylabel("BA")
_ = plt.xlabel(shap_label)

In [None]:
# Compare BA and SHAP values of a single feature for all pixels within a
# latitude/longitude box.

# Colormap used to colour the individual pixel series in the plots below. Defined here
# so that the cell does not rely on another cell having been run first.
ba_cmap = plt.get_cmap("inferno")

fig, ax = plt.subplots(
    1, 1, figsize=(14, 6), subplot_kw=dict(projection=ccrs.Robinson())
)

feature_name = "Dry Days"
# Position of the feature among the shortened column names; used both for the mean SHAP
# map and for selecting the raw SHAP values.
X_i = list(map(shorten_features, X_train.columns)).index(feature_name)
results_dict = map_shap_results["masked_shap_arrs"]
fig = cube_plotting(
    results_dict["data"][X_i],
    fig=fig,
    ax=ax,
    title=f"Mean SHAP value for '{shorten_features(feature_name)}'",
    nbins=7,
    vmin=results_dict["vmin"],
    vmax=results_dict["vmax"],
    log=True,
    log_auto_bins=False,
    extend="neither",
    min_edge=1e-3,
    cmap="Spectral_r",
    cmap_midpoint=0,
    cmap_symmetric=True,
    colorbar_kwargs={
        "format": "%0.1e",
        "label": f"SHAP ('{shorten_features(feature_name)}')",
    },
    coastline_kwargs={"linewidth": 0.3},
)

# lat_range = (-5, 4)
lat_range = (6, 11)
lon_range = (12.5, 26.5)

# Outline the selected region on the global map.
ax.set_global()
ax.add_patch(
    mpatches.Rectangle(
        xy=[min(lon_range), min(lat_range)],
        width=np.ptp(lon_range),
        height=np.ptp(lat_range),
        facecolor="none",
        edgecolor="blue",
        alpha=0.8,
        lw=2,
        transform=ccrs.PlateCarree(),
    )
)


# _ = ax.gridlines()

# Training-set BA restricted to the selected region.
ba_train_cube = dummy_lat_lon_cube(get_mm_data(y_train.values, master_mask, "train"))
ba_subset = ba_train_cube.intersection(latitude=lat_range).intersection(
    longitude=lon_range
)
cube_plotting(ba_subset, log=True)

mm_valid_indices, mm_valid_train_indices, mm_valid_val_indices = get_mm_indices(
    master_mask
)
# Only as many training indices as there are SHAP samples.
mm_kind_indices = mm_valid_train_indices[: shap_values.shape[0]]

# Convert 1D shap values into 3D array (time, lat, lon).
masked_shap_comp = np.ma.MaskedArray(
    np.zeros_like(master_mask, dtype=np.float64), mask=np.ones_like(master_mask)
)
masked_shap_comp.ravel()[mm_kind_indices] = shap_values[:, X_i]

# if additional_mask is not None:
#     masked_shap_comp.mask |= match_shape(additional_mask, masked_shap_comp.shape)

# Consistency check against the mean map computed earlier.
assert np.all(
    np.isclose(results_dict["data"][X_i], np.ma.mean(masked_shap_comp, axis=0))
)

shap_cube = dummy_lat_lon_cube(masked_shap_comp)
shap_subset = shap_cube.intersection(latitude=lat_range).intersection(
    longitude=lon_range
)

# Pixel positions within the region (all axes after the first).
indices = list(np.ndindex(ba_subset.shape[1:]))
# Colours are spread over [0, 1]; the max() guards against a division by zero when the
# region contains a single pixel.
colour_scale = max(len(indices) - 1, 1)
shap_label = f"SHAP({feature_name})"

# BA series of every pixel.
plt.figure()
for i, index in enumerate(tqdm(indices, desc="Plotting")):
    plt.plot(
        ba_subset.data[(slice(None), *index)],
        marker="o",
        alpha=0.4,
        c=ba_cmap(i / colour_scale),
    )
_ = plt.ylabel("BA")

# SHAP series of every pixel.
plt.figure()
for i, index in enumerate(tqdm(indices, desc="Plotting")):
    plt.plot(
        shap_subset.data[(slice(None), *index)],
        marker="o",
        alpha=0.4,
        c=ba_cmap(i / colour_scale),
    )
_ = plt.ylabel(shap_label)

# BA against SHAP for every pixel.
plt.figure(figsize=(15, 15))
for i, index in enumerate(tqdm(indices, desc="Plotting")):
    plt.plot(
        shap_subset.data[(slice(None), *index)],
        ba_subset.data[(slice(None), *index)],
        marker="o",
        alpha=0.4,
        c=ba_cmap(i / colour_scale),
        linestyle="",
    )
plt.ylabel("BA")
_ = plt.xlabel(shap_label)

### Categorisation into multiple peaks

#### Test the effect of parameters on the peak distribution

In [None]:
max_month = 9

# Mark areas where at least this fraction of |SHAP| has been plotted previously.
thres = 0.7


def param_iter():
    """Yield ``(exclude_inst, feature_name)`` pairs with a progress bar per level."""
    for inst_flag in tqdm([False], desc="Exclude inst."):
        for feature in tqdm(["FAPAR", "Dry Day Period"], desc="Feature"):
            yield inst_flag, feature


# Sweep both significance parameters and record, for each combination, how often each
# distinct `significant_peak(..., strict=False)` result occurs among the valid pixels.
# `islice(..., 0, 1)` restricts the run to the first configuration from `param_iter`.
for exclude_inst, feature_name in islice(param_iter(), 0, 1):
    # Columns belonging to the current feature (limited by `max_month`) and their lags.
    filtered = np.array(filter_by_month(X_train.columns, feature_name, max_month))
    lags = np.array(
        [get_lag(feature, target_feature=feature_name) for feature in filtered]
    )

    # Ensure lags are sorted consistently.
    lag_sort_inds = np.argsort(lags)
    filtered = tuple(filtered[lag_sort_inds])
    lags = tuple(lags[lag_sort_inds])

    if exclude_inst and 0 in lags:
        # After sorting, the lag 0 column has to be the first entry.
        assert lags[0] == 0
        lags = lags[1:]
        filtered = filtered[1:]

    n_features = len(filtered)

    # There is no point plotting this map for a single feature or less since we are
    # interested in a comparison between different feature ranks.
    if n_features <= 1:
        continue

    short_feature = shorten_features(feature_name)

    # One 2D map per lag, stored in lag order. Copies are taken because the masks are
    # overwritten with the shared mask below.
    selected_data = np.empty(n_features, dtype=object)
    for i, col in enumerate(X_train.columns):
        if col in filtered:
            selected_data[lags.index(get_lag(col))] = map_shap_results[
                "masked_shap_arrs"
            ]["data"][i].copy()

    shared_mask = reduce(np.logical_or, (data.mask for data in selected_data))
    for data in selected_data:
        data.mask = shared_mask

    stacked_shaps = np.vstack([data.data[np.newaxis] for data in selected_data])

    # Calculate the significance of the global maxima for each of the valid pixels.

    # Valid indices are recorded in 'shared_mask'.

    valid_i, valid_j = np.where(~shared_mask)
    total_valid = len(valid_i)

    # NOTE(review): `max_positions` is not used anywhere else in this cell.
    max_positions = np.ma.MaskedArray(
        np.zeros_like(shared_mask, dtype=np.float64), mask=True
    )

    def get_n_peaks(ptp_threshold_factor, diff_threshold):
        """Count how often each `significant_peak` result occurs over the valid pixels.

        `ptp_threshold_factor` is multiplied by the per-pixel `mean_ba` to obtain the
        `ptp_threshold` argument; `diff_threshold` is passed through unchanged.

        Returns a dict mapping each unique result to its number of occurrences.
        NOTE(review): the next cell calls `len()` on the keys, so the results are
        presumably tuples of peaks - confirm that `np.unique` handles them as intended.
        """
        n_peaks = []

        for i, j in zip(valid_i, valid_j):
            ptp_threshold = ptp_threshold_factor * mean_ba[i, j]
            peaks_i = significant_peak(
                stacked_shaps[:, i, j],
                diff_threshold=diff_threshold,
                ptp_threshold=ptp_threshold,
                strict=False,
            )
            n_peaks.append(peaks_i)
        return dict(zip(*np.unique(n_peaks, return_counts=True)))

    with concurrent.futures.ProcessPoolExecutor(max_workers=get_ncpus()) as executor:
        # Rounded so that the values can serve as readable dictionary keys.
        diff_thresholds = np.round(np.linspace(0.01, 0.99, 8), 2)
        ptp_threshold_factors = np.round(np.linspace(0, 0.75, 12), 2)

        # One task per (ptp_threshold_factor, diff_threshold) combination.
        fs = [
            executor.submit(get_n_peaks, ptp_threshold_factor, diff_threshold)
            for ptp_threshold_factor, diff_threshold in product(
                ptp_threshold_factors, diff_thresholds
            )
        ]
        # Only used to display progress; results are collected in submission order below.
        for f in tqdm(
            concurrent.futures.as_completed(fs),
            total=len(fs),
            desc="Varying significances",
            smoothing=0,
        ):
            pass

        # Pair each future with its parameters by repeating the same product() order.
        results = {}
        for (f, (ptp_threshold_factor, diff_threshold)) in zip(
            fs, product(ptp_threshold_factors, diff_thresholds)
        ):
            results[(ptp_threshold_factor, diff_threshold)] = f.result()

In [None]:
# For every parameter combination, total the pixel counts by number of peaks, i.e. by
# the length of each key in the corresponding `results` entry.
n_peaks_dict = {}
for key, vals in results.items():
    all_tuples = list(vals)
    all_n_peaks = np.array([len(tup) for tup in all_tuples])
    all_n = np.array([vals[tup] for tup in all_tuples])

    n_peaks_dict[key] = {
        n: np.sum(all_n[all_n_peaks == n]) for n in np.unique(all_n_peaks)
    }

In [None]:
# Rows are indexed by the (ptp_thres_f, diff_thres) keys of `n_peaks_dict`; the columns
# are the peak counts. One bar plot is drawn per ptp threshold factor.
df = pd.DataFrame(n_peaks_dict).T
df.index.names = ["ptp_thres_f", "diff_thres"]
_ = df.groupby("ptp_thres_f").plot(kind="bar")

In [None]:
# Same table as in the previous cell, but with one bar plot per diff threshold.
df = pd.DataFrame(n_peaks_dict).T
df.index.names = ["ptp_thres_f", "diff_thres"]
_ = df.groupby("diff_thres").plot(kind="bar")

#### Carry out run

In [None]:
# Load the ESA CCI landcover PFT dataset, restrict it to 2010-01 .. 2015-01, regrid it
# and replace it with its mean dataset.
# NOTE(review): the target grid of `regrid()` is whatever its defaults are - confirm.
pfts = ESA_CCI_Landcover_PFT()
pfts.limit_months(start=PartialDateTime(2010, 1), end=PartialDateTime(2015, 1))
pfts.regrid()
pfts = pfts.get_mean_dataset()

In [None]:
# Plot the "pftHerb" variable. `inplace=False` is passed so the selection is returned
# rather than applied to the wrapped dataset (presumably - confirm in `Datasets`).
pft_cube = Datasets(pfts).select_variables("pftHerb", inplace=False).cube
fig = cube_plotting(pft_cube, fig=plt.figure(figsize=(12, 5)))

In [None]:
pft_names = ("pftShrubBD", "pftShrubBE", "pftShrubNE")
pft_cubes = []

# Plot every shrub PFT on its own figure while collecting the cubes.
for pft_name in pft_names:
    pft_cube = Datasets(pfts).select_variables(pft_name, inplace=False).cube
    pft_cubes.append(pft_cube)
    fig = cube_plotting(pft_cube, fig=plt.figure(figsize=(12, 5)))

# Add the collected cubes together from left to right and plot the total.
summed_cube = pft_cubes[0]
for extra_cube in pft_cubes[1:]:
    summed_cube = summed_cube + extra_cube
fig = cube_plotting(
    summed_cube,
    title=f"Sum of {', '.join(pft_names)}",
    fig=plt.figure(figsize=(12, 5)),
)

# For visual comparison with the provided "ShrubAll" variable.
pft_cube = Datasets(pfts).select_variables("ShrubAll", inplace=False).cube
fig = cube_plotting(pft_cube, fig=plt.figure(figsize=(12, 5)))

In [None]:
pft_names = ("pftTreeBD", "pftTreeBE", "pftTreeND", "pftTreeNE")
pft_cubes = []

# Plot every tree PFT on its own figure while collecting the cubes.
for pft_name in pft_names:
    pft_cube = Datasets(pfts).select_variables(pft_name, inplace=False).cube
    pft_cubes.append(pft_cube)
    fig = cube_plotting(pft_cube, fig=plt.figure(figsize=(12, 5)))

# Add the collected cubes together from left to right and plot the total.
summed_cube = pft_cubes[0]
for extra_cube in pft_cubes[1:]:
    summed_cube = summed_cube + extra_cube
fig = cube_plotting(
    summed_cube,
    title=f"Sum of {', '.join(pft_names)}",
    fig=plt.figure(figsize=(12, 5)),
)

# For visual comparison with the provided "TreeAll" variable.
pft_cube = Datasets(pfts).select_variables("TreeAll", inplace=False).cube
fig = cube_plotting(pft_cube, fig=plt.figure(figsize=(12, 5)))

In [None]:
pft_names = ("ShrubAll", "TreeAll")
pft_cubes = []

# Plot both aggregate variables on their own figures while collecting the cubes.
for pft_name in pft_names:
    pft_cube = Datasets(pfts).select_variables(pft_name, inplace=False).cube
    pft_cubes.append(pft_cube)
    fig = cube_plotting(pft_cube, fig=plt.figure(figsize=(12, 5)))

# Add the collected cubes together from left to right and plot the total.
summed_cube = pft_cubes[0]
for extra_cube in pft_cubes[1:]:
    summed_cube = summed_cube + extra_cube
fig = cube_plotting(
    summed_cube,
    title=f"Sum of {', '.join(pft_names)}",
    fig=plt.figure(figsize=(12, 5)),
)

In [None]:
max_month = 9
close_figs = True


def param_iter():
    """Generate every parameter combination of the SHAP peak analysis.

    The four nested loops are each wrapped in their own `tqdm` progress bar.
    The SHAP result dictionaries are read from the module-level names
    `map_shap_results` and `high_ba_map_shap_results`.

    Yields:
        tuple: `((exc_name, exclude_inst), feature_name,
        (filter_name, shap_results), shap_measure)`.

    """
    inst_options = [("with_inst", False), ("no_inst", True)]
    feature_names = ["VOD Ku-band", "SIF", "FAPAR", "LAI", "Dry Day Period"]
    shap_measures = [
        "masked_max_shap_arrs",
        "masked_abs_shap_arrs",
        "masked_shap_arrs",
    ]

    for inst_option in tqdm(inst_options, desc="Exclude inst."):
        for feature_name in tqdm(feature_names, desc="Feature"):
            # Looked up here (not at the top) so that the global SHAP results
            # are resolved at the same point as before.
            shap_sources = [
                ("normal", map_shap_results),
                ("high_ba", high_ba_map_shap_results),
            ]
            for shap_source in tqdm(shap_sources, desc="SHAP results"):
                for shap_measure in tqdm(shap_measures, desc="SHAP measure"):
                    yield inst_option, feature_name, shap_source, shap_measure


from collections import Counter

# For every parameter combination: find the significant peaks in the per-pixel
# SHAP-vs-lag profiles of one feature, map and tabulate them, and finally compare
# the most frequent peak combinations against the PFT cubes in `pfts`.
for (
    (exc_name, exclude_inst),
    feature_name,
    (filter_name, shap_results),
    shap_measure,
) in islice(param_iter(), 0, None):
    short_feature = shorten_features(feature_name)

    sub_directory = (
        Path("shap_peaks") / filter_name / shap_measure / short_feature / exc_name
    )

    filtered = np.array(filter_by_month(X_train.columns, feature_name, max_month))
    lags = np.array(
        [get_lag(feature, target_feature=feature_name) for feature in filtered]
    )

    # Ensure lags are sorted consistently.
    lag_sort_inds = np.argsort(lags)
    filtered = tuple(filtered[lag_sort_inds])
    lags = tuple(lags[lag_sort_inds])

    if exclude_inst and 0 in lags:
        # The lags are sorted, so the instantaneous (lag 0) entry comes first.
        assert lags[0] == 0
        lags = lags[1:]
        filtered = filtered[1:]

    n_features = len(filtered)

    # There is no point plotting this map for a single feature or less since we are
    # interested in a comparison between different feature ranks.
    if n_features <= 1:
        continue

    # Place the SHAP map of every selected column at its lag-sorted position.
    # `filtered` and `lags` share their ordering, so the position of a column name
    # in `filtered` is the position of its lag; no second `get_lag` call is needed.
    lag_slots = {str(name): slot for slot, name in enumerate(filtered)}
    selected_data = np.empty(n_features, dtype=object)
    for i, col in enumerate(X_train.columns):
        if col in lag_slots:
            selected_data[lag_slots[col]] = shap_results[shap_measure]["data"][
                i
            ].copy()

    # A pixel is only used if it is valid for every one of the selected lags.
    shared_mask = reduce(np.logical_or, (data.mask for data in selected_data))
    for data in selected_data:
        data.mask = shared_mask

    # Stack along a new leading axis: the first index selects the lag position.
    stacked_shaps = np.vstack([data.data[np.newaxis] for data in selected_data])

    # Calculate the significance of the global maxima for each of the valid pixels.
    # Valid indices are recorded in 'shared_mask'.
    valid_i, valid_j = np.where(~shared_mask)
    total_valid = len(valid_i)

    # NOTE(review): `total_valid` and `max_positions` are not read anywhere in this
    # loop; they are kept in case a later cell relies on them.
    max_positions = np.ma.MaskedArray(
        np.zeros_like(shared_mask, dtype=np.float64), mask=True
    )

    peak_indices = []

    for i, j in zip(tqdm(valid_i, desc="Evaluating maxima", smoothing=0), valid_j):
        ptp_threshold = ptp_threshold_factor * mean_ba[i, j]
        peaks_i = significant_peak(
            stacked_shaps[:, i, j],
            diff_threshold=diff_threshold,
            ptp_threshold=ptp_threshold,
            strict=False,
        )

        # Label every peak with its lag and the sign of the SHAP value at that lag,
        # e.g. "3(+)". The peaks are ordered by position (i.e. by time) rather than
        # by the order in which `significant_peak` returned them.
        peak_indices.append(
            tuple(
                f"{lags[p_i]}({'+' if stacked_shaps[p_i, i, j] > 0 else '-'})"
                for p_i in sorted(peaks_i)
            )
        )

    # Bar chart: how many valid pixels have 0, 1, 2, ... peaks.
    n_peaks_values, n_peaks_counts = np.unique(
        [len(indices) for indices in peak_indices], return_counts=True
    )
    pd.Series(dict(zip(n_peaks_values, n_peaks_counts))).plot.bar(
        ax=plt.subplots(figsize=(6, 4))[1],
        title=f"{short_feature}, exclude inst: {exclude_inst}",
        rot=0,
    )
    figure_saver.save_figure(plt.gcf(), "n_peaks_distr", sub_directory=sub_directory)
    if close_figs:
        plt.close()

    # Map of the number of peaks per pixel (assigning an element unmasks it).
    peaks_arr = np.ma.MaskedArray(
        np.zeros_like(shared_mask, dtype=np.float64), mask=True
    )
    for i, j, indices in zip(valid_i, valid_j, peak_indices):
        peaks_arr[i, j] = len(indices)

    fig, cbar = cube_plotting(
        peaks_arr,
        title=f"Nr. Peaks {short_feature}, exclude inst: {exclude_inst}",
        boundaries=np.arange(0, 4) - 0.5,
        fig=plt.figure(figsize=(5.1, 2.6)),
        coastline_kwargs={"linewidth": 0.3},
        colorbar_kwargs={"label": "nr. peaks", "format": "%0.1f"},
        return_cbar=True,
    )

    # Ticks at the class centres 0, 1, 2 and one at the upper boundary (2.5),
    # which carries the ">2" label.
    tick_pos = np.arange(4, dtype=np.float64)
    tick_pos[3] -= 0.5

    cbar.set_ticks(tick_pos)

    tick_labels = list(map(str, range(3))) + [">2"]
    cbar.set_ticklabels(tick_labels)

    map_figure_saver.save_figure(
        fig, f"nr_shap_peaks_map_{short_feature}", sub_directory=sub_directory
    )
    if close_figs:
        plt.close()

    # Only pixels with exactly 1 or 2 peaks are analysed further.
    masked_peaks = peaks_arr.copy()
    masked_peaks.mask |= (peaks_arr.data == 0) | (peaks_arr.data > 2)

    cmap, norm = from_levels_and_colors(
        levels=np.arange(1, 4) - 0.5, colors=["C1", "C2"], extend="neither",
    )

    fig, cbar = cube_plotting(
        masked_peaks,
        title=f"Nr. Peaks {short_feature}, exclude inst: {exclude_inst}",
        fig=plt.figure(figsize=(5.1, 2.6)),
        coastline_kwargs={"linewidth": 0.3},
        colorbar_kwargs={"label": "nr. peaks", "format": "%0.1f"},
        return_cbar=True,
        cmap=cmap,
        norm=norm,
    )

    # The levels are 0.5, 1.5, 2.5, so the two classes are centred on 1 and 2.
    # The number of ticks has to match the number of labels: a mismatch raises a
    # ValueError in current Matplotlib versions.
    tick_pos = np.arange(1, 3, dtype=np.float64)
    cbar.set_ticks(tick_pos)

    tick_labels = list(map(str, range(1, 3)))
    cbar.set_ticklabels(tick_labels)

    map_figure_saver.save_figure(
        fig, f"filtered_nr_shap_peaks_map_{short_feature}", sub_directory=sub_directory
    )
    if close_figs:
        plt.close()

    # Peak labels of the pixels that survived the 1-or-2-peaks filter.
    valid_peak_indices = [
        peaks_i
        for i, j, peaks_i in zip(valid_i, valid_j, peak_indices)
        if not masked_peaks.mask[i, j]
    ]

    if not valid_peak_indices:
        # Nothing left to analyse for this parameter combination.
        print(
            f"No pixels with 1 or 2 peaks: {short_feature}, "
            f"exclude inst: {exclude_inst}"
        )
        continue

    # The filter above guarantees 1 or 2 peaks per pixel, but not that both
    # counts actually occur.
    assert {len(indices) for indices in valid_peak_indices} <= {1, 2}

    # Bar chart: number of pixels with 1 and with 2 peaks.
    n_peaks_values, n_peaks_counts = np.unique(
        [len(indices) for indices in valid_peak_indices], return_counts=True
    )
    pd.Series(dict(zip(n_peaks_values, n_peaks_counts))).plot.bar(
        ax=plt.subplots(figsize=(6, 4))[1],
        title=f"{short_feature}, exclude inst: {exclude_inst}",
        rot=0,
    )
    figure_saver.save_figure(
        plt.gcf(), "filtered_n_peaks_distr", sub_directory=sub_directory
    )
    if close_figs:
        plt.close()

    # Count how often each peak combination (a tuple of 1 or 2 labels) occurs.
    # The tuples have different lengths, so they cannot be passed through
    # `np.unique` / `np.asarray`: recent NumPy versions refuse ragged input, and
    # equal-length input would be flattened into individual labels. Sorting the
    # keys keeps the ordering `np.unique` would have produced.
    peaks_dict = dict(sorted(Counter(valid_peak_indices).items()))

    total_counts = np.sum(list(peaks_dict.values()))
    relative_counts_dict = {key: val / total_counts for key, val in peaks_dict.items()}

    # Horizontal bar chart of the relative frequency of every combination.
    fig = plt.figure(figsize=(7, 0.3 * len(relative_counts_dict) + 0.4))
    pd.Series(
        {", ".join(k): v for k, v in relative_counts_dict.items()}
    ).sort_values().plot.barh(
        fontsize=12, title=f"{short_feature}, exclude inst: {exclude_inst}",
    )
    figure_saver.save_figure(plt.gcf(), "peak_comb_distr", sub_directory=sub_directory)
    if close_figs:
        plt.close()

    # Keep at most `max_n_peaks` of the most frequent combinations, each of which
    # additionally has to exceed the fraction `min_frac`.
    keys = list(relative_counts_dict)
    values = np.asarray(list(relative_counts_dict.values()))

    max_n_peaks = 6
    min_frac = 0.05  # Require at least this fraction per entry.

    sorted_indices = np.argsort(values)  # Ascending: most frequent entries last.
    cumulative_fractions = np.cumsum(values[sorted_indices])

    mask = np.ones_like(cumulative_fractions, dtype=np.bool_)
    mask[:-max_n_peaks] = False  # no more than `max_n_peaks`.
    mask &= values[sorted_indices] > min_frac

    print(
        f"Remaining fraction: {short_feature}, exclude inst: {exclude_inst}",
        np.sum(values[sorted_indices][mask]),
    )

    # The keys stay plain tuples in a Python list (see the ragged-input note above).
    thres_counts_dict = {
        keys[idx]: values[idx] for idx, keep in zip(sorted_indices, mask) if keep
    }
    print(f"{short_feature}, exclude inst: {exclude_inst}", thres_counts_dict)

    peak_keys = list(thres_counts_dict)

    if not peak_keys:
        # No combination is frequent enough; the maps and PFT statistics below
        # would be empty.
        print(
            f"No peak combination above min_frac: {short_feature}, "
            f"exclude inst: {exclude_inst}"
        )
        continue

    # Map every pixel to the rank of its combination within `peak_keys`; pixels
    # with any other combination stay masked.
    combination_rank = {key: rank for rank, key in enumerate(peak_keys)}

    imp_peaks = np.ma.MaskedArray(
        np.zeros_like(shared_mask, dtype=np.float64), mask=True
    )

    for i, j, indices in zip(valid_i, valid_j, peak_indices):
        rank = combination_rank.get(indices)
        if rank is not None:
            imp_peaks[i, j] = rank

    # One colour per retained combination, with boundaries half-way between ranks.
    boundaries = np.arange(len(peak_keys) + 1) - 0.5

    cmap, norm = from_levels_and_colors(
        levels=boundaries,
        colors=[plt.get_cmap("tab10")(i) for i in range(len(peak_keys))],
        extend="neither",
    )

    fig, cbar = cube_plotting(
        imp_peaks,
        title=f"Peak Distr. {short_feature}, exclude inst: {exclude_inst}",
        fig=plt.figure(figsize=(5.1, 2.6)),
        coastline_kwargs={"linewidth": 0.3},
        colorbar_kwargs={"label": "peak combination"},
        return_cbar=True,
        cmap=cmap,
        norm=norm,
    )

    tick_pos = np.arange(len(peak_keys), dtype=np.float64)
    cbar.set_ticks(tick_pos)

    cbar.set_ticklabels(peak_keys)

    map_figure_saver.save_figure(
        fig, f"shap_peak_distr_map_{short_feature}", sub_directory=sub_directory
    )
    if close_figs:
        plt.close()

    # Every retained combination has to occur at least once on the map.
    assert len(np.unique(imp_peaks.data[~imp_peaks.mask])) == len(peak_keys)

    # Collect, for every (combination, PFT) pair, the PFT values of the pixels
    # showing that combination.
    results = {}
    for comb_i in tqdm(
        np.unique(imp_peaks.data[~imp_peaks.mask]), desc="Peak combination"
    ):
        # Compare against the raw data; masked pixels are excluded explicitly via
        # `imp_peaks.mask` below. Does not depend on the PFT cube, hence computed
        # once per combination.
        is_combination = np.isclose(imp_peaks.data, comb_i)
        combination_label = str(peak_keys[int(comb_i)])
        for pft_cube in pfts:
            selection = ~(pft_cube.data.mask | imp_peaks.mask) & is_combination
            results[(combination_label, pft_cube.name())] = pft_cube.data.data[
                selection
            ]

    # Columns form a 2-level index (combination, PFT). Wrapping in `pd.Series`
    # pads the differently sized samples with NaN.
    df = pd.DataFrame({key: pd.Series(vals) for key, vals in results.items()})
    df.columns.names = ["peak_combination", "pft"]

    for level in [0, 1]:
        # One boxplot panel per value of the current column level.
        df.groupby(axis=1, level=level).boxplot(
            subplots=True,
            layout=(len(df.columns.levels[level]), 1),
            figsize=(10, 5 * len(df.columns.levels[level])),
            rot=30,
        )
        plt.tight_layout()
        figure_saver.save_figure(
            plt.gcf(),
            f"boxplots_level_{df.columns.names[level]}",
            sub_directory=sub_directory,
        )
        if close_figs:
            plt.close()

        # One violin plot per value of the current column level.
        for key in df.columns.levels[level]:
            fig = plt.figure()
            sns.violinplot(data=df.xs(key, level=level, axis="columns"))
            plt.title(key)
            figure_saver.save_figure(
                fig,
                f"violin_level_{df.columns.names[level]}_{key}",
                sub_directory=sub_directory,
            )
            if close_figs:
                plt.close()