## Setup

In [None]:
from specific import *

In [None]:
# Unpack the six objects returned by `get_offset_data()` into notebook-level names.
# NOTE(review): `master_mask` is treated below as a boolean mask in which False
# marks valid elements - confirm against `get_offset_data`.
endog_data, exog_data, master_mask, filled_datasets, masked_datasets, land_mask = (
    get_offset_data()
)

### Retrieve previous results from the 'model' notebook

In [None]:
# Restore the train/test split and the cross-validation output from their caches
# (both are expected to have been populated by the 'model' notebook).
(X_train, X_test, y_train, y_test) = data_split_cache.load()
(results, rf) = cross_val_cache.load()

### SHAP values - see PBS (array) job folder 'shap'!

#### Load the SHAP values (keeping track of missing chunks)

In [None]:
max_index = 995  # Maximum job array index (inclusive).
job_samples = 2000  # Samples per job.
total_samples = (max_index + 1) * job_samples  # Capacity of the whole job array.

# All chunks live in the same cache sub-directory.
shap_cache_dir = os.path.join(CACHE_DIR, "shap")

# Load the individual data chunks, remembering which job indices had no cached data.
shap_chunks = []
missing = []
last_loaded = None
for index in tqdm(range(max_index + 1), desc="Loading chunks"):
    try:
        chunk = SimpleCache(
            f"tree_path_dependent_shap_{index}_{job_samples}",
            cache_dir=shap_cache_dir,
            verbose=0,
        ).load()
    except NoCachedDataError:
        missing.append(index)
    else:
        shap_chunks.append(chunk)
        last_loaded = index

if missing:
    print("missing:", missing)
    print("nr missing:", len(missing))

    # Chunks are stacked in job order with missing jobs skipped, so every row that
    # follows a gap moves towards the start of `shap_values`. Code that addresses
    # rows by position (e.g. `valid_train_indices[: shap_values.shape[0]]`) assumes
    # that there are no such gaps, so make them visible here.
    gaps = [index for index in missing if last_loaded is not None and index < last_loaded]
    if gaps:
        print(
            "warning: missing chunks precede loaded chunks, rows after these job "
            "indices are shifted:",
            gaps,
        )

if not shap_chunks:
    # `np.vstack` would fail on an empty list with a far less helpful message.
    raise ValueError(
        f"No SHAP chunks could be loaded from '{shap_cache_dir}' "
        f"(tried job indices 0 to {max_index})."
    )

shap_values = np.vstack(shap_chunks)

# Sanity check: report how much of the job array's capacity has been loaded.
print(f"loaded SHAP rows: {shap_values.shape[0]} (job array capacity: {total_samples})")

### BA in the train and test sets

Valid elements are located where `master_mask` is `False`.

In [None]:
# Flat positions of all elements that are not masked out by `master_mask`.
valid_indices = np.flatnonzero(~master_mask.ravel())

# Split the flat positions into a train and a test part (fixed seed, 30% test).
# NOTE(review): presumably these arguments mirror the split that produced
# `y_train` / `y_test` - confirm against the 'model' notebook.
valid_train_indices, valid_test_indices = train_test_split(
    valid_indices, test_size=0.3, shuffle=True, random_state=1
)

# Two independent, fully masked float arrays shaped like `master_mask`.
masked_train_data, masked_test_data = (
    np.ma.MaskedArray(
        np.zeros_like(master_mask, dtype=np.float64), mask=np.ones_like(master_mask)
    )
    for _ in range(2)
)

# Writing through the flattened view fills in the given positions of each array.
masked_train_data.ravel()[valid_train_indices] = y_train.values
masked_test_data.ravel()[valid_test_indices] = y_test.values

In [None]:
# Three stacked maps (overall / train / test BA) drawn with identical settings.
with map_figure_saver("train_test_set_overall_ba_comp"):
    fig, axes = plt.subplots(
        nrows=3,
        ncols=1,
        figsize=(5.1, 8.4),
        constrained_layout=True,
        subplot_kw=dict(projection=ccrs.Robinson()),
    )
    # Plot settings common to all panels; the titles are set on the axes instead.
    shared_kwargs = dict(
        boundaries=[0, 4e-6, 1e-5, 1e-4, 1e-3, 1e-2, 8e-2],
        extend="max",
        cmap="inferno",
        colorbar_kwargs=dict(format="%0.1e", label="Fractional BA"),
        coastline_kwargs=dict(linewidth=0.3),
        title="",
    )
    # (axes title, data) for each panel, from top to bottom.
    panels = (
        ("Mean Overall GFED4 BA", get_masked_array(endog_data.values, master_mask)),
        ("Mean Train Set GFED4 BA", masked_train_data),
        ("Mean Test Set GFED4 BA", masked_test_data),
    )
    for ax, (panel_title, panel_data) in zip(axes, panels):
        ax.set_title(panel_title)
        cube_plotting(panel_data, ax=ax, fig=fig, **shared_kwargs)

In [None]:
# Histogram of the absolute SHAP values in the first column, on log-log axes.
first_column_abs_shap = np.absolute(shap_values[:, 0])
plt.hist(first_column_abs_shap, bins=5000)
for set_axis_scale in (plt.xscale, plt.yscale):
    set_axis_scale("log")

### Calculate 2D masked array SHAP values

In [None]:
# Only as many training positions as there are SHAP rows can be filled in; the
# slice is the same for every feature, so take it once.
shap_sample_indices = valid_train_indices[: shap_values.shape[0]]

masked_shap_arrs = []
for i, feature in enumerate(tqdm(X_train.columns, desc="Computing SHAP values")):
    # Start from a fully masked array shaped like `master_mask` and fill in this
    # feature's SHAP values (column `i`) through the flattened view.
    masked_shap_comp = np.ma.MaskedArray(
        np.zeros_like(master_mask, dtype=np.float64), mask=np.ones_like(master_mask)
    )
    masked_shap_comp.ravel()[shap_sample_indices] = shap_values[:, i]

    # Average over the first axis.
    avg_shap_comp = np.ma.mean(masked_shap_comp, axis=0)
    masked_shap_arrs.append(avg_shap_comp)

# Per-feature extremes, and from those the overall range across all features.
vmins = [np.min(avg_arr) for avg_arr in masked_shap_arrs]
vmaxs = [np.max(avg_arr) for avg_arr in masked_shap_arrs]

vmin = min(vmins)
vmax = max(vmaxs)

### Plotting maps of SHAP values

In [None]:
# Settings shared by every map; `vmin` / `vmax` are the overall SHAP range so that
# all maps use the same colour scale.
shap_map_kwargs = dict(
    cmap="Spectral_r",
    nbins=7,
    cmap_midpoint=0,
    cmap_symmetric=True,
    vmin=vmin,
    vmax=vmax,
    log=True,
    log_auto_bins=False,
    min_edge=1e-3,
    extend="neither",
)

# One map per feature, saved under the 'shap_map' sub-directory.
for i, feature in enumerate(tqdm(X_train.columns, desc="Mapping SHAP values")):
    short_name = shorten_features(feature)
    fig = cube_plotting(
        masked_shap_arrs[i],
        fig=plt.figure(figsize=(5.1, 2.8)),
        title=f"Mean SHAP value for '{short_name}'",
        colorbar_kwargs={"format": "%0.1e", "label": f"SHAP ('{short_name}')"},
        coastline_kwargs={"linewidth": 0.3},
        **shap_map_kwargs,
    )
    map_figure_saver.save_figure(
        fig, f"shap_value_map_{feature}", sub_directory="shap_map"
    )

### Find common mask for a range of 2D SHAP arrays to compare

In [None]:
# Passed to `filter_by_month` when selecting the columns to compare.
# NOTE(review): presumably the largest lag (in months) that is included - confirm.
max_month = 9

# Cells in which the ranks plotted so far already account for at least this
# fraction of the summed |SHAP| are hatched on the following maps.
thres = 0.7


def param_iter():
    """Yield every `(exclude_inst, feature_name)` combination to plot.

    The outer loop runs over `exclude_inst` (False, then True), the inner loop over
    the feature names; both loops display a progress bar.
    """
    feature_names = ["VOD Ku-band", "SIF", "FAPAR", "LAI", "Dry Day Period"]
    for exclude_inst in tqdm([False, True], desc="Exclude inst."):
        yield from (
            (exclude_inst, feature_name)
            for feature_name in tqdm(feature_names, desc="Feature")
        )


def rank_label(position):
    """Return the English ordinal string for a 1-based rank.

    Args:
        position (int): Rank, starting at 1.

    Returns:
        str: The ordinal, e.g. "1st", "2nd", "3rd", "4th", "11th", "21st".
    """
    if 10 <= position % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(position % 10, "th")
    return f"{position}{suffix}"


# `masked_shap_arrs` was filled in the order of `X_train.columns`, so a column's
# position in `X_train.columns` is also the position of its SHAP array.
shap_arr_index = {col: i for i, col in enumerate(X_train.columns)}

# For each parameter combination, map which lag of the feature has the largest,
# second largest, ... mean |SHAP| in each cell.
for exclude_inst, feature_name in param_iter():
    sub_directory = "rank_shap_map_no_inst" if exclude_inst else "rank_shap_map"

    filtered = np.array(filter_by_month(X_train.columns, feature_name, max_month))
    lags = np.array(
        [get_lag(feature, target_feature=feature_name) for feature in filtered]
    )

    # Ensure lags are sorted consistently.
    lag_sort_inds = np.argsort(lags)
    filtered = tuple(filtered[lag_sort_inds])
    lags = tuple(lags[lag_sort_inds])

    if exclude_inst and 0 in lags:
        # After sorting, a lag of 0 has to be the first entry.
        assert lags[0] == 0
        lags = lags[1:]
        filtered = filtered[1:]

    n_features = len(filtered)

    # There is no point plotting this map for a single feature or less since we are
    # interested in a comparison between different feature ranks.
    if n_features <= 1:
        continue

    # Copies of the SHAP arrays in the same (lag-sorted) order as `filtered` and
    # `lags`. Looking them up by column name keeps the three sequences aligned
    # without having to derive the lag from the column name a second time.
    selected_data = [
        masked_shap_arrs[shap_arr_index[feature]].copy() for feature in filtered
    ]

    # A cell is only compared if it is valid for all of the selected features.
    shared_mask = reduce(np.logical_or, (data.mask for data in selected_data))
    for data in selected_data:
        data.mask = shared_mask

    stacked_abs = np.abs(np.vstack([data.data[np.newaxis] for data in selected_data]))
    # Indices in descending order.
    sort_indices = np.argsort(stacked_abs, axis=0)[::-1]

    # Maintain the same colors even if fewer colors are used.
    colors = [lag_color_dict[lag] for lag in lags]

    cmap, norm = from_levels_and_colors(
        levels=np.arange(n_features + 1), colors=colors, extend="neither",
    )

    short_feature = shorten_features(feature_name)

    sum_shap = np.ma.MaskedArray(np.sum(stacked_abs, axis=0), mask=shared_mask)
    # Running fraction of the summed |SHAP| covered by the ranks plotted so far.
    already_plotted = np.zeros_like(sum_shap)

    # Colorbar tick labels - these only depend on the lags, not on the rank.
    labels = [f"{lag} M" if lag else "Inst." for lag in lags]

    # The hatch style is the same for every rank, so configure it once per feature.
    style = 2

    if style == 1:
        # Stippling for significant areas.
        mpl.rc("hatch", linewidth=0.2)
        hatches = ["." * 6, None]
    else:
        # Hatching for insignificant areas.
        mpl.rc("hatch", linewidth=0.1)
        hatches = ["/" * 14, None]

    for i in tqdm(range(n_features), desc="Plotting", leave=False):
        # Generated ordinals (instead of a fixed list) so that no rank is dropped
        # when more than five features are compared.
        rank = rank_label(i + 1)

        cube = dummy_lat_lon_cube(np.ma.MaskedArray(sort_indices[i], mask=shared_mask))

        fig, ax = plt.subplots(
            figsize=(5.1, 2.6), subplot_kw={"projection": ccrs.Robinson()}
        )

        # Hatch the cells where the previously plotted ranks already account for
        # at least `thres` of the summed |SHAP|.
        if np.any(already_plotted >= thres):
            ax.contourf(
                cube.coord("longitude").points,
                cube.coord("latitude").points,
                already_plotted,
                transform=ccrs.PlateCarree(),
                colors="none",
                zorder=4,
                levels=[thres, 1],
                hatches=hatches,
            )

        fig, cbar = cube_plotting(
            cube,
            title=f"{rank} |SHAP {short_feature} Lag| - thres: {thres * 100:0.0f}%",
            fig=fig,
            ax=ax,
            cmap=cmap,
            norm=norm,
            return_cbar=True,
            colorbar_kwargs={"label": short_feature},
            coastline_kwargs={"linewidth": 0.3},
        )

        # Label the colorbar using the feature names.
        cbar.set_ticks(np.arange(n_features) + 0.5)
        cbar.set_ticklabels(labels)

        # Add this rank's share of the summed |SHAP| to the running total.
        rank_abs = np.take_along_axis(stacked_abs, sort_indices[i : i + 1], axis=0)[0]
        already_plotted += rank_abs / sum_shap

        map_figure_saver.save_figure(
            fig, f"{rank}_{short_feature}", sub_directory=sub_directory
        )
        plt.close(fig)