## Setup

In [None]:
from specific import *

In [None]:
# Unpack the six objects returned by `get_offset_data()`, in the order returned.
# `master_mask` is used further down to locate valid (unmasked) grid elements.
endog_data, exog_data, master_mask, filled_datasets, masked_datasets, land_mask = (
    get_offset_data()
)

### Retrieve previous results from the 'model' notebook

In [None]:
# Reload the train/test split stored in `data_split_cache`.
X_train, X_test, y_train, y_test = data_split_cache.load()
# Reload the cross-validation results together with `rf` from `cross_val_cache`
# (presumably the fitted random forest model - confirm in the 'model' notebook).
results, rf = cross_val_cache.load()

### SHAP values - see PBS (array) job folder 'shap'!

#### Load the SHAP values (keeping track of missing chunks)

In [None]:
max_index = 995  # Maximum job array index (inclusive).
job_samples = 2000  # Samples per job.
# Row count expected if every job yields `job_samples` rows (sanity-check reference).
total_samples = (max_index + 1) * job_samples

# Every chunk lives in the same cache directory, so build the path only once.
shap_cache_dir = os.path.join(CACHE_DIR, "shap")

# Load the individual data chunks, recording the job indices without cached data.
shap_chunks = []
missing = []
for index in tqdm(range(max_index + 1), desc="Loading chunks"):
    try:
        shap_chunks.append(
            SimpleCache(
                f"tree_path_dependent_shap_{index}_{job_samples}",
                cache_dir=shap_cache_dir,
                verbose=0,
            ).load()
        )
    except NoCachedDataError:
        missing.append(index)

if missing:
    print("missing:", missing)
    print("nr missing:", len(missing))

# `np.vstack` fails with an opaque message on an empty sequence, so report the
# actual problem (no chunk could be loaded) explicitly. The exception type is the
# same one `np.vstack` would have raised.
if not shap_chunks:
    raise ValueError(
        f"No SHAP chunks could be loaded from '{shap_cache_dir}' "
        f"(job indices 0-{max_index}, {job_samples} samples per job)."
    )

# Chunks are stacked in job-index order with missing jobs skipped, so rows that
# follow a missing chunk are shifted up relative to their job's position.
shap_values = np.vstack(shap_chunks)

### Burned area (BA) in the train and test sets

Valid elements are located where `master_mask` is `False`.

In [None]:
# Flat (raveled) positions of the elements where `master_mask` is False.
valid_indices = np.flatnonzero(~master_mask.ravel())

# Split the valid positions into train and test subsets with a fixed seed.
# NOTE(review): presumably the same split parameters as used for the cached
# `X_train` / `y_train` - confirm against the 'model' notebook.
valid_train_indices, valid_test_indices = train_test_split(
    valid_indices, test_size=0.3, shuffle=True, random_state=1
)


def _scatter_onto_master_grid(flat_positions, values):
    """Return an array shaped like `master_mask` holding `values` at `flat_positions`.

    The array starts out as float64 zeros with every element masked; `values` are
    then written through the raveled array at the given flat positions.
    """
    grid = np.ma.MaskedArray(
        np.zeros_like(master_mask, dtype=np.float64),
        mask=np.ones_like(master_mask),
    )
    # Relies on `ravel()` sharing data and mask with `grid`, so the assignment
    # lands in (and unmasks) the corresponding elements of `grid` itself.
    grid.ravel()[flat_positions] = values
    return grid


masked_train_data = _scatter_onto_master_grid(valid_train_indices, y_train.values)
masked_test_data = _scatter_onto_master_grid(valid_test_indices, y_test.values)

In [None]:
# Draw three vertically stacked Robinson-projection panels (overall, train set,
# test set) that share the same colour boundaries, colormap and colorbar setup.
with figure_saver("train_test_set_overall_ba_comp"):
    fig, axes = plt.subplots(
        nrows=3,
        ncols=1,
        figsize=(5.1, 8.4),
        subplot_kw={"projection": ccrs.Robinson()},
        constrained_layout=True,
    )
    # Options common to all three panels. The title is left empty here because
    # each panel title is set directly on its axis before plotting.
    shared_kwargs = dict(
        boundaries=[0, 4e-6, 1e-5, 1e-4, 1e-3, 1e-2, 8e-2],
        extend="max",
        cmap="inferno",
        colorbar_kwargs={"format": "%0.1e", "label": "Fractional BA"},
        coastline_kwargs={"linewidth": 0.3},
        title="",
    )
    panels = (
        ("Mean Overall GFED4 BA", get_masked_array(endog_data.values, master_mask)),
        ("Mean Train Set GFED4 BA", masked_train_data),
        ("Mean Test Set GFED4 BA", masked_test_data),
    )
    for panel_ax, (panel_title, panel_data) in zip(axes, panels):
        panel_ax.set_title(panel_title)
        cube_plotting(panel_data, ax=panel_ax, fig=fig, **shared_kwargs)

In [None]:
# Histogram of the absolute SHAP values in column 0 of `shap_values`, drawn on
# the current axes with logarithmic x and y scales.
abs_shap_hist_ax = plt.gca()
abs_shap_hist_ax.hist(np.abs(shap_values[:, 0]), bins=5000)
abs_shap_hist_ax.set_xscale("log")
abs_shap_hist_ax.set_yscale("log")

In [None]:
masked_shap_arrs = []
vmins = []
vmaxs = []

# Flat grid position for every row of `shap_values`. This is identical for all
# features, so it is computed once instead of once per loop iteration.
if missing:
    # The chunks were stacked with the missing jobs skipped, so rows after a gap
    # no longer sit at their original training-row position. Rebuild each row's
    # training-row number from its job index and the actual chunk lengths.
    # NOTE(review): assumes job `k` processed the training rows starting at
    # `k * job_samples` - confirm against the PBS array job script.
    skipped_jobs = set(missing)
    loaded_jobs = [k for k in range(max_index + 1) if k not in skipped_jobs]
    shap_train_rows = np.concatenate(
        [
            # `np.atleast_2d` mirrors how `np.vstack` counted rows per chunk.
            k * job_samples + np.arange(np.atleast_2d(chunk).shape[0])
            for k, chunk in zip(loaded_jobs, shap_chunks)
        ]
    )
    shap_target_indices = valid_train_indices[shap_train_rows]
else:
    # Nothing missing: rows of `shap_values` line up with the leading entries
    # of `valid_train_indices`.
    shap_target_indices = valid_train_indices[: shap_values.shape[0]]

for i, feature in enumerate(tqdm(X_train.columns, desc="Computing SHAP values")):
    # Start from a fully masked float64 array shaped like `master_mask`, then
    # write this feature's SHAP values at their grid positions.
    masked_shap_comp = np.ma.MaskedArray(
        np.zeros_like(master_mask, dtype=np.float64), mask=np.ones_like(master_mask)
    )
    masked_shap_comp.ravel()[shap_target_indices] = shap_values[:, i]
    # Masked mean over the first axis; masked elements do not contribute.
    avg_shap_comp = np.ma.mean(masked_shap_comp, axis=0)
    masked_shap_arrs.append(avg_shap_comp)
    vmins.append(np.min(avg_shap_comp))
    vmaxs.append(np.max(avg_shap_comp))

# Global extremes across all features.
vmin = min(vmins)
vmax = max(vmaxs)

In [None]:
# Plot options that are identical for every feature map. `vmin` / `vmax` are the
# shared limits, so all maps use the same colour range.
shap_map_options = dict(
    cmap="Spectral_r",
    nbins=7,
    cmap_midpoint=0,
    cmap_symmetric=True,
    vmin=vmin,
    vmax=vmax,
    log=True,
    log_auto_bins=False,
    min_edge=1e-3,
    extend="neither",
)

# One map per feature, each on its own figure and saved under 'shap_map'.
for i, feature in enumerate(tqdm(X_train.columns, desc="Mapping SHAP values")):
    short_name = shorten_features(feature)
    fig = cube_plotting(
        masked_shap_arrs[i],
        fig=plt.figure(figsize=(5.1, 2.8)),
        title=f"Mean SHAP value for '{short_name}'",
        colorbar_kwargs={"format": "%0.1e", "label": f"SHAP ('{short_name}')"},
        coastline_kwargs={"linewidth": 0.3},
        **shap_map_options,
    )
    figure_saver.save_figure(fig, f"shap_value_map_{feature}", sub_directory="shap_map")