## Setup

In [None]:
from common import *
from common import _get_centres, _sci_format, _second_order_ale_quant

### Retrieve previous results from the 'model' notebook

In [None]:
# Load the cached train/test split and the cached cross-validation output.
# NOTE(review): both caches are presumably populated by the 'model' notebook —
# confirm that notebook has been run first.
X_train, X_test, y_train, y_test = data_split_cache.load()
# `rf` is later used as the fitted model (its `predict` and `n_jobs` are accessed).
results, rf = cross_val_cache.load()

#### Get Dask Client

In [None]:
# Obtain the Dask client used below to scatter data and submit the ALE jobs.
client = get_client()
# Bare expression: shown as the cell output when run in a notebook.
client

## ALE Plotting

In [None]:
def _add_cell_patches(ax, edges_x, edges_y, cell_indices, **patch_kwargs):
    """Draw one rectangle on `ax` per `(i, j)` grid cell, spanning its bin edges."""
    for i, j in cell_indices:
        ax.add_patch(
            Rectangle(
                [edges_x[i], edges_y[j]],
                edges_x[i + 1] - edges_x[i],
                edges_y[j + 1] - edges_y[j],
                **patch_kwargs,
            )
        )


def save_ale_2d(predictor, train_set, features, bins=40, coverage=1):
    """Plot and save the second-order ALE of two features.

    Args:
        predictor: Prediction callable forwarded to `_second_order_ale_quant`.
        train_set: Training data; only `.shape[0]` and row slicing are used here
            before it is handed to `_second_order_ale_quant`.
        features: The two feature names. Used for the ALE computation, the axis
            labels, the title and the output filename.
        bins: Number of bins requested from `_second_order_ale_quant`.
        coverage: Fraction of `train_set` to use. If below 1, only the first
            `int(n_rows * coverage)` rows are kept.

    Returns:
        The tuple `(ale, quantiles_list, samples_grid)` as returned by
        `_second_order_ale_quant`.

    The figure is saved via `figure_saver.save_figure` in the `2d_ale`
    sub-directory, named after the two features joined by `__`.
    """
    if coverage < 1:
        # This should be ok if `train_set` is randomised, as it usually is.
        train_set = train_set[: int(train_set.shape[0] * coverage)]

    ale, quantiles_list, samples_grid = _second_order_ale_quant(
        predictor, train_set, features, bins=bins, return_samples_grid=True
    )

    fig, ax = plt.subplots(figsize=(7.5, 4.5))

    # Quantile axis transformation: axes listed here are plotted against the bin
    # index (evenly spaced bins) and only labelled with the real quantile centres.
    quantile_axis_list = ("x", "y")

    plotting_quantiles_list = []
    for axis, quantiles in zip(("x", "y"), quantiles_list):
        if axis in quantile_axis_list:
            inds = np.arange(len(quantiles))
            plotting_quantiles_list.append(inds)
            ax.set(**{f"{axis}ticks": _get_centres(inds)})
            ax.set(
                **{
                    f"{axis}ticklabels": _sci_format(
                        _get_centres(quantiles), scilim=0.6
                    )
                }
            )
        else:
            plotting_quantiles_list.append(quantiles)

    edges_x, edges_y = plotting_quantiles_list
    centres_list = [_get_centres(quantiles) for quantiles in plotting_quantiles_list]

    # Interpolate the ALE (defined at the bin centres) onto a finer regular grid.
    n_x, n_y = 50, 50
    x = np.linspace(centres_list[0][0], centres_list[0][-1], n_x)
    y = np.linspace(centres_list[1][0], centres_list[1][-1], n_y)

    X, Y = np.meshgrid(x, y, indexing="xy")
    # NOTE(review): `scipy.interpolate.interp2d` is deprecated and has been removed
    # from newer SciPy releases - confirm the pinned SciPy version or migrate.
    ale_interp = scipy.interpolate.interp2d(centres_list[0], centres_list[1], ale.T)

    CF = ax.contourf(X, Y, ale_interp(x, y), levels=30, alpha=0.85,)

    # Do not autoscale, so that boxes at the edges (contourf only plots the bin
    # centres, not their edges) don't enlarge the plot. Such boxes include markings for
    # invalid cells, or hatched boxes for valid cells.
    # Use the Axes method instead of `plt.autoscale`: pyplot acts on the global
    # "current" axes, which may belong to another figure when several plots are
    # produced concurrently.
    ax.autoscale(False)

    # Add hatching for the significant cells. These have at least `min_samples` samples.
    # By default, calculate this as the number of samples in each bin if everything
    # was equally distributed, divided by 10.
    min_samples = (train_set.shape[0] / reduce(mul, map(len, centres_list))) / 10
    _add_cell_patches(
        ax,
        edges_x,
        edges_y,
        zip(*np.where(samples_grid >= min_samples)),
        linewidth=0,
        fill=None,
        hatch=".",
        alpha=0.4,
    )

    if np.any(ale.mask):
        # Add rectangles to indicate cells without samples.
        _add_cell_patches(
            ax,
            edges_x,
            edges_y,
            zip(*np.where(ale.mask)),
            linewidth=1,
            edgecolor="k",
            facecolor="none",
            alpha=0.4,
        )

    fig.colorbar(CF, format="%.0e", pad=0.03, aspect=32, shrink=0.85)
    ax.set_xlabel(features[0])
    ax.set_ylabel(features[1])
    nbins_str = "x".join([str(len(centres)) for centres in centres_list])
    fig.suptitle(
        f"Second-order ALE of {features[0]} and {features[1]}\n"
        f"Bins: {nbins_str} (Hatching: Sig., Boxes: Missing)"
    )
    # Adjust this figure explicitly rather than pyplot's current figure.
    fig.subplots_adjust(top=0.89)
    ax.xaxis.set_tick_params(rotation=45)

    figure_saver.save_figure(fig, "__".join(features), sub_directory="2d_ale")
    return ale, quantiles_list, samples_grid

### Worldwide

In [None]:
def save_ale_and_get_importance(columns, model, train_set, coverage=0.02):
    """Save the 2D ALE plot for `columns` and return the ALE peak-to-peak range.

    Args:
        columns: The pair of feature names to compute the second-order ALE for.
        model: Fitted model; its `predict` method is used and its `n_jobs`
            attribute is set to 8 (note that this mutates the passed object).
        train_set: Training data passed on to `save_ale_2d`.
        coverage: Fraction of `train_set` that `save_ale_2d` uses.

    Returns:
        `max - min` of the ALE over the significant cells, i.e. cells holding at
        least a tenth of the samples expected under a uniform distribution, or
        `None` if no cell is significant.
    """
    model.n_jobs = 8
    ale, quantiles_list, samples_grid = save_ale_2d(
        model.predict, train_set, columns, bins=20, coverage=coverage,
    )

    # `save_ale_2d` only uses the first `coverage` fraction of the rows, so the
    # threshold must be based on that number of samples, not on the full set.
    n_samples = train_set.shape[0]
    if coverage < 1:
        n_samples = int(n_samples * coverage)

    # Each quantile list holds the bin edges, so there is one fewer bin than edges.
    n_cells = reduce(mul, (len(quantiles) - 1 for quantiles in quantiles_list))
    min_samples = (n_samples / n_cells) / 10

    # Same criterion (`>=`) as the hatching of significant cells in `save_ale_2d`.
    significant = samples_grid >= min_samples
    if not np.any(significant):
        # No significant cells: max/min of an empty selection would raise.
        return None

    selected = ale[significant]
    return np.ma.max(selected) - np.ma.min(selected)

# XXX: Local trial
# save_ale_and_get_importance(columns_list[0], rf, X_train)

# Every unordered pair of training columns gets its own 2D ALE plot.
columns_list = list(combinations(X_train.columns, 2))

# Send the model and the training data to all workers once, instead of
# shipping them along with every submitted task.
print("Scattering")
rf_fut = client.scatter(rf, broadcast=True)
X_fut = client.scatter(X_train, broadcast=True)
print("Finished scattering")

ale_fs = [
    client.submit(save_ale_and_get_importance, columns, rf_fut, X_fut)
    for columns in columns_list
]

for ale_f in tqdm(
    dask.distributed.as_completed(ale_fs),
    total=len(ale_fs),
    unit="plot",
    desc="Calculating 2D ALE plots",
    smoothing=0,
    position=0,
):
    if ale_f.status == "error":
        # `result()` would re-raise the remote exception and abort this loop;
        # `exception()` returns it so that it can be reported instead.
        print(ale_f.exception())


# Collect the peak-to-peak values in submission order. Futures that did not
# finish successfully are recorded as None so that one failing pair does not
# discard the results of all the others.
ptp_values = {
    columns: (ale_f.result() if ale_f.status == "finished" else None)
    for columns, ale_f in zip(columns_list, ale_fs)
}

In [None]:
# Ignore and count None values, then plot a histogram of the ptp values.
valid_ptp_values = {
    columns: ptp for columns, ptp in ptp_values.items() if ptp is not None
}
filtered_columns_list = list(valid_ptp_values)
filtered_ptp_values = list(valid_ptp_values.values())

print(
    f"Ignored {len(ptp_values) - len(valid_ptp_values)} of {len(ptp_values)} "
    "column pairs without a ptp value"
)

_ = plt.hist(filtered_ptp_values, bins=20)

# Rank the column pairs by their ALE peak-to-peak value, largest first.
pdp_results = pd.Series(filtered_ptp_values, index=filtered_columns_list)
pdp_results.sort_values(inplace=True, ascending=False)
print(pdp_results.head(20))