## Setup

In [None]:
from common import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Restore the train/test split and the cross-validation outputs (results and the
# fitted model `rf`) written by the 'model' notebook.
(X_train, X_test, y_train, y_test) = data_split_cache.load()
(results, rf) = cross_val_cache.load()

#### Get Dask Client

In [None]:
client = get_client()
# Alternative: a local single-worker cluster. Tasks below are submitted with
# resources={"threads": ...}, so any worker must advertise a "threads" resource:
#   client = Client(n_workers=1, threads_per_worker=8, resources={'threads': 8})
client

In [None]:
# Broadcast the model and the training data to every worker once; submitted tasks
# then receive these futures instead of re-serialising the objects each time.
rf_fut, X_fut = [client.scatter(obj, broadcast=True) for obj in (rf, X_train)]

## ALE Plotting

### Worldwide

In [None]:
def save_ale_plot_1d_with_ptp(
    model,
    X_train,
    column,
    n_jobs=8,
    monte_carlo_rep=1000,
    monte_carlo_ratio=100,
    verbose=False,
    monte_carlo=True,
):
    """Save the first-order ALE plot of `column` and return its peak-to-peak range.

    Parameters
    ----------
    model
        Fitted model passed to `ale_plot`. Its `n_jobs` attribute is overwritten
        with `n_jobs`.
    X_train
        Training features passed to `ale_plot`.
    column : str
        Feature to plot; also used as the saved figure's name.
    n_jobs : int
        Thread count for the joblib "threading" backend and for `model.n_jobs`.
    monte_carlo_rep, monte_carlo_ratio
        Forwarded to `ale_plot`.
    verbose : bool
        Forwarded to `ale_plot`.
    monte_carlo : bool
        Whether `ale_plot` computes Monte-Carlo replicates. Selects the figure
        sub-directory ("ale" or "ale_non_mc") and the return format.

    Returns
    -------
    float or (float, float)
        `np.ptp` of the ALE values (`data[1]` from `ale_plot`). If `monte_carlo`
        is True, a tuple that also contains the peak-to-peak range across all
        Monte-Carlo ALE values.

    Notes
    -----
    The figure is closed once it has been saved. The function is submitted as a
    Dask task and never returns the figure, so leaving it open would only
    accumulate figures in the worker's pyplot state.
    """
    model.n_jobs = n_jobs
    with parallel_backend("threading", n_jobs=n_jobs):
        # Make sure the plot is drawn onto a new figure.
        blank_fig, _ = plt.subplots(figsize=(7.5, 4.5))
        out = ale_plot(
            model,
            X_train,
            column,
            bins=20,
            monte_carlo=monte_carlo,
            monte_carlo_rep=monte_carlo_rep,
            monte_carlo_ratio=monte_carlo_ratio,
            plot_quantiles=True,
            quantile_axis=True,
            rugplot_lim=0,
            scilim=0.6,
            return_data=True,
            return_mc_data=True,
            verbose=verbose,
        )
    if monte_carlo:
        fig, axes, data, mc_data = out
    else:
        fig, axes, data = out

    for ax_key in ("ale", "quantiles_x"):
        axes[ax_key].xaxis.set_tick_params(rotation=45)

    sub_dir = "ale" if monte_carlo else "ale_non_mc"
    figure_saver.save_figure(fig, column, sub_directory=sub_dir)

    # Release the figure(s) now that they are saved.
    plt.close(fig)
    if blank_fig is not fig:
        # ale_plot returned a different figure from the one created above.
        plt.close(blank_fig)

    if not monte_carlo:
        return np.ptp(data[1])

    # Pool the (flattened) ALE values of every Monte-Carlo replicate in a single
    # concatenation instead of re-allocating the array on every iteration.
    mc_ales = np.concatenate([np.ravel(mc_ale) for _, mc_ale in mc_data])
    return np.ptp(data[1]), np.ptp(mc_ales)

In [None]:
world_ale_1d_cache = SimpleCache("world_ale_1d", cache_dir=CACHE_DIR)

# To force recomputation, clear the cache first:
# world_ale_1d_cache.clear()


@world_ale_1d_cache
def get_world_ale_1d():
    """Compute and save Monte-Carlo 1D ALE plots for every feature on the Dask cluster.

    One task per column of `X_train` is submitted (each requesting `n_threads`
    "threads" resources). Every failed task is reported with its column name;
    the first failure is then re-raised while collecting results, so the function
    never returns partial results.

    Returns
    -------
    ptp_values : dict
        Column name -> peak-to-peak range of the ALE values.
    mc_ptp_values : dict
        Column name -> peak-to-peak range across all Monte-Carlo ALE values.
    """
    n_threads = 8
    ale_fs = [
        client.submit(
            save_ale_plot_1d_with_ptp,
            model=rf_fut,
            X_train=X_fut,
            column=column,
            n_jobs=n_threads,
            monte_carlo_rep=1000,
            resources={"threads": n_threads},
        )
        for column in X_train.columns
    ]
    # Map task keys back to feature names so that failures can be attributed.
    column_by_key = {
        ale_f.key: column for ale_f, column in zip(ale_fs, X_train.columns)
    }

    for ale_f in tqdm(
        dask.distributed.as_completed(ale_fs),
        total=len(ale_fs),
        unit="plot",
        desc="Calculating 1D ALE plots",
        smoothing=0,
    ):
        if ale_f.status == "error":
            # `.result()` would re-raise here (printing nothing) and abort the
            # loop at the first failure; report each failure instead.
            print(
                f"ALE plot for '{column_by_key[ale_f.key]}' failed: "
                f"{ale_f.exception()!r}"
            )

    ptp_values = {}
    mc_ptp_values = {}

    # `.result()` re-raises the exception of any failed task.
    for column, ale_f in zip(X_train.columns, ale_fs):
        ptp_values[column], mc_ptp_values[column] = ale_f.result()
    return ptp_values, mc_ptp_values


ptp_values, mc_ptp_values = get_world_ale_1d()

### Non-Monte-Carlo ALE runs (not cached; only the plots are needed)

In [None]:
# Non-Monte-Carlo ALE plots for every feature, computed on the Dask cluster.
n_threads = 8
ale_fs = [
    client.submit(
        save_ale_plot_1d_with_ptp,
        model=rf_fut,
        X_train=X_fut,
        column=column,
        n_jobs=n_threads,
        monte_carlo=False,
        resources={"threads": n_threads},
    )
    for column in X_train.columns
]
# Map task keys back to feature names so that failures can be attributed.
column_by_key = {ale_f.key: column for ale_f, column in zip(ale_fs, X_train.columns)}

failed_columns = []
for ale_f in tqdm(
    dask.distributed.as_completed(ale_fs),
    total=len(ale_fs),
    unit="plot",
    desc="Calculating 1D Non-MC ALE plots",
    smoothing=0,
):
    if ale_f.status == "error":
        # `.result()` would re-raise here (printing nothing) and abort the loop at
        # the first failure; report every failure and raise once all are done.
        failed_columns.append(column_by_key[ale_f.key])
        print(f"ALE plot for '{failed_columns[-1]}' failed: {ale_f.exception()!r}")

if failed_columns:
    raise RuntimeError(f"Non-MC ALE plots failed for columns: {failed_columns}")

## PDP Plotting

### Worldwide

In [None]:
from alepython.ale import _sci_format


def save_pdp_plot_1d(model, X_train, column, n_jobs):
    """Compute (or load from cache) the 1D PDP of `column` and save two figures.

    Parameters
    ----------
    model
        Fitted model passed to `pdp.pdp_isolate`. Its `n_jobs` attribute is set
        to `n_jobs` when the PDP has to be computed.
    X_train
        Training features (must provide `.columns`).
    column : str
        Feature to analyse; also used as the cache file name and figure name.
    n_jobs : int
        Thread count for the joblib "threading" backend and for `model.n_jobs`.

    Returns
    -------
    (fig_ice, fig_no_ice)
        The PDP figure with ICE lines (saved to "pdp") and the centred PDP
        without ICE lines (saved to "pdp_no_ice"). Both are left open; the
        caller is responsible for closing them.
    pdp_isolate_out
        The `pdp.pdp_isolate` result, freshly computed or loaded from the cache.
    data_file : str
        Path of the pickle cache for this column.
    """
    data_file = os.path.join(CACHE_DIR, "pdp_data", column)

    if not os.path.isfile(data_file):
        model.n_jobs = n_jobs
        with parallel_backend("threading", n_jobs=n_jobs):
            pdp_isolate_out = pdp.pdp_isolate(
                model=model,
                dataset=X_train,
                model_features=X_train.columns,
                feature=column,
                num_grid_points=20,
            )
        os.makedirs(os.path.dirname(data_file), exist_ok=True)
        # Write to a temporary file and atomically move it into place, so an
        # interrupted run never leaves a truncated pickle that later loads fail on.
        tmp_file = f"{data_file}.{os.getpid()}.tmp"
        with open(tmp_file, "wb") as f:
            pickle.dump((column, pdp_isolate_out), f, -1)
        os.replace(tmp_file, data_file)
    else:
        # NOTE: unpickling can execute arbitrary code; only load caches written
        # by this function.
        with open(data_file, "rb") as f:
            column, pdp_isolate_out = pickle.load(f)

    # With ICEs.
    fig_ice, axes_ice = pdp.pdp_plot(
        pdp_isolate_out,
        column,
        plot_lines=True,
        center=True,
        frac_to_plot=1000,
        x_quantile=True,
        figsize=(7, 5),
    )
    axes_ice["pdp_ax"].xaxis.set_tick_params(rotation=45)
    figure_saver.save_figure(fig_ice, column, sub_directory="pdp")

    # Without ICEs: the PDP centred on its first grid point, drawn through the
    # explicit Axes rather than pyplot's implicit "current figure".
    fig_no_ice, ax = plt.subplots(figsize=(7.5, 4.5))
    n_grid = len(pdp_isolate_out.pdp)
    ax.plot(pdp_isolate_out.pdp - pdp_isolate_out.pdp[0], marker="o")
    ax.set_xticks(range(n_grid))
    ax.set_xticklabels(
        _sci_format(pdp_isolate_out.feature_grids, scilim=0.6), rotation=45
    )
    ax.set_xlabel(f"{column}")
    ax.set_title(f"PDP of feature '{column}'\nBins: {n_grid}")
    ax.grid(alpha=0.4, linestyle="--")
    figure_saver.save_figure(fig_no_ice, column, sub_directory="pdp_no_ice")
    return (fig_ice, fig_no_ice), pdp_isolate_out, data_file

In [None]:
# Time one feature end to end: PDP computation (or cache load) plus both figures.
%time save_pdp_plot_1d(rf, X_train, 'Dry Day Period -12 - 0 Month', n_jobs=32)

### Sequentially, on the local machine

In [None]:
# Compute (or load) and save every PDP in this process. The returned figures are
# closed straight away so they do not accumulate in memory.
pdp_progress = tqdm(
    X_train.columns, unit="plot", desc="Calculating 1D PDP plots", smoothing=0.05
)
for column in pdp_progress:
    (fig_ice, fig_no_ice), pdp_isolate_out, data_file = save_pdp_plot_1d(
        rf, X_train, column, n_jobs=32
    )
    plt.close(fig_ice)
    plt.close(fig_no_ice)

### Using a Dask distributed cluster

In [None]:
%%time
n_threads = 8
pdp_fs = [
    client.submit(
        save_pdp_plot_1d,
        model=rf_fut,
        X_train=X_fut,
        column=column,
        n_jobs=n_threads,
        resources={"threads": n_threads},
    )
    for column in X_train.columns
]

for pdp_f in tqdm(
    dask.distributed.as_completed(pdp_fs),
    total=len(pdp_fs),
    unit="plot",
    desc="Calculating 1D PDP plots",
    smoothing=0.04,
):
    if pdp_f.status == "error":
        print(pdp_f.result())

## Combining Multiple ALE plots

In [None]:
from tqdm.auto import tqdm

from alepython.ale import _sci_format, first_order_ale_quant


def multi_ale_plot_1d(
    model,
    X_train,
    columns,
    fig_name,
    xlabel=None,
    ylabel=None,
    title=None,
    n_jobs=8,
    verbose=False,
    bins=20,
):
    """Overlay the first-order ALE curves of several features on one shared axis.

    Each feature's ALE is computed with `first_order_ale_quant`. The per-feature
    quantiles are averaged into one shared set of x-axis quantiles, with the
    extremes set to the overall min/max. Each curve is then placed on that axis
    by linear interpolation. The figure is saved to the "multi_ale" sub-directory
    under `fig_name` and left open.

    Parameters
    ----------
    model
        Fitted model; its `predict` method is used and its `n_jobs` attribute is
        set to `n_jobs`.
    X_train
        Training features passed to `first_order_ale_quant`.
    columns : iterable of str
        Features to plot. Any iterable is accepted; it is materialised once.
    fig_name : str
        Name under which the figure is saved.
    xlabel, ylabel, title : str, optional
        Axis labels and figure title.
    n_jobs : int
        Thread count for the joblib "threading" backend and for `model.n_jobs`.
    verbose : bool
        Whether to show a progress bar over the features.
    bins : int
        Forwarded to `first_order_ale_quant` (default 20).

    Raises
    ------
    ValueError
        If `columns` is empty.
    """
    # Materialise once: `columns` is iterated twice (ALE computation and plotting),
    # so a one-shot iterator would otherwise leave nothing to plot.
    columns = list(columns)
    if not columns:
        raise ValueError("At least one column is required to plot ALEs.")

    # Make sure the plot is drawn onto a new figure.
    fig, ax = plt.subplots(figsize=(7.5, 4.5))
    model.n_jobs = n_jobs
    quantile_list = []
    ale_list = []
    with parallel_backend("threading", n_jobs=n_jobs):
        for feature in tqdm(
            columns, desc="Calculating feature ALEs", disable=not verbose
        ):
            quantiles, ale = first_order_ale_quant(
                model.predict, X_train, feature, bins=bins
            )
            quantile_list.append(quantiles)
            ale_list.append(ale)

    # Construct shared quantiles from the individual ones to minimise interpolation.
    # Stacking to (n_features, n_quantiles) fails if the quantile lengths differ.
    combined_quantiles = np.vstack(quantile_list)

    final_quantiles = np.mean(combined_quantiles, axis=0)
    # Account for extrema.
    final_quantiles[0] = np.min(combined_quantiles)
    final_quantiles[-1] = np.max(combined_quantiles)

    # Integer tick positions of the shared quantiles.
    mod_quantiles = np.arange(len(final_quantiles))
    for feature, feature_quantiles, ale in zip(columns, quantile_list, ale_list):
        # Interpolate each feature's quantiles relative to the shared quantiles.
        # np.interp assumes `final_quantiles` is increasing (presumably true, since
        # it is a mean of per-feature quantiles; verify for unusual features).
        ax.plot(
            np.interp(feature_quantiles, final_quantiles, mod_quantiles),
            ale,
            marker="o",
            ms=3,
            label=feature,
        )

    ax.legend(loc="best")
    ax.set_xticks(mod_quantiles)
    ax.set_xticklabels(_sci_format(final_quantiles, scilim=0.6))
    ax.xaxis.set_tick_params(rotation=45)
    ax.grid(alpha=0.4, linestyle="--")

    fig.suptitle(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    figure_saver.save_figure(fig, fig_name, sub_directory="multi_ale")

In [None]:
# List the available feature names; the shifted-feature groups below are built from these.
X_train.columns

In [None]:
# For each base feature, overlay its ALE with those of its month-shifted variants.
SHIFTED_FEATURES = ("Dry Day Period", "SIF", "FAPAR", "LAI")
MONTH_SHIFTS = (-1, -3, -6, -9)

for feature in tqdm(SHIFTED_FEATURES, desc="Multiple shift ALE plots"):
    shifted_columns = (f"{feature}",) + tuple(
        f"{feature} {m} Month" for m in MONTH_SHIFTS
    )
    fig_name = f'{feature.replace(" ", "_").lower()}_ale_shifts'
    multi_ale_plot_1d(
        rf,
        X_train,
        shifted_columns,
        fig_name,
        xlabel=f"{feature}",
        title=f"First-order ALE for {feature}",
        n_jobs=32,
        verbose=True,
    )