## Setup

In [None]:
import traceback

from common import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Load the train/test split and the cross-validation output from their caches.
# Per the section heading these were produced by the 'model' notebook.
# `rf` is the model object used for all plots below.
X_train, X_test, y_train, y_test = data_split_cache.load()
results, rf = cross_val_cache.load()

#### Get Dask Client

In [None]:
# Obtain the Dask client that the plotting cells below use to scatter data
# and submit tasks.
client = get_client()
# Bare expression so the notebook displays the client as the cell output.
client

## ALE Plotting

### Worldwide

In [None]:
def save_ale_plot_1d(model, X_train, column):
    """Create the 1D ALE plot for a single feature and save it.

    Args:
        model: Model handed to `ale_plot`.
        X_train: Training data handed to `ale_plot`.
        column: Name of the feature to plot; also used as the saved figure's
            name.

    Returns:
        None. The figure is written via `figure_saver` into the "ale"
        sub-directory.
    """
    # Fixed plotting configuration shared by every feature.
    ale_options = dict(
        bins=20,
        monte_carlo=True,
        monte_carlo_rep=10000,
        monte_carlo_ratio=100,
        plot_quantiles=False,
        quantile_axis=True,
        rugplot_lim=0,
        scilim=0.6,
    )
    figure, axis = ale_plot(model, X_train, column, **ale_options)

    # Rotate the x tick labels by 45 degrees.
    axis.xaxis.set_tick_params(rotation=45)

    figure_saver.save_figure(figure, column, sub_directory="ale")


# Single-feature trial run on the first 1000 training rows (disabled):
# save_ale_plot_1d(rf, X_train[:1000], 'Dry Day Period -12 - 0 Month')

# Serial alternative to the Dask fan-out below (disabled):
# for column in tqdm(
#             X_train.columns,
#             unit="ALE plot",
#             desc="Calculating 1D ALE plots",
#             smoothing=0,
#             position=0,
#         ):
#     rf.n_jobs = 30
#     save_ale_plot_1d(rf, X_train, column)

# Scatter the model and the training data once, broadcast to the workers, so
# that every task below receives the same futures as arguments.
rf_fut = client.scatter(rf, broadcast=True)
X_fut = client.scatter(X_train, broadcast=True)

# One plotting task per feature column.
ale_fs = [
    client.submit(save_ale_plot_1d, rf_fut, X_fut, column)
    #           for column in ['Dry Day Period -12 - 0 Month']
    for column in X_train.columns
]

# Track progress as tasks finish and report any task that failed.
for ale_f in tqdm(
    dask.distributed.as_completed(ale_fs),
    total=len(ale_fs),
    unit="plot",
    desc="Calculating 1D ALE plots",
    smoothing=0,
    position=0,
):
    if ale_f.status == "error":
        # `Future.result()` re-raises the task's exception, which would abort
        # this loop at the first failure without printing anything. Report the
        # exception and its traceback instead, then continue with the
        # remaining plots.
        print(ale_f.key, "failed:", repr(ale_f.exception()))
        print("".join(traceback.format_tb(ale_f.traceback())))

## PDP Plotting

### Worldwide

In [None]:
def save_pdp_plot_1d(model, X_train, column):
    """Create the 1D PDP plot for a single feature and save it.

    Args:
        model: Model handed to `pdp.pdp_isolate`.
        X_train: Training data; its columns are passed as the model features.
        column: Name of the feature to plot; also used as the saved figure's
            name.

    Returns:
        None. The figure is written via `figure_saver` into the "pdp"
        sub-directory.
    """
    pdp_isolate_out = pdp.pdp_isolate(
        # Use the model that was passed in. Referring to the notebook-level
        # `rf` here would ignore the argument, and with it the scattered model
        # future that the caller submits to the workers.
        model=model,
        dataset=X_train,
        model_features=X_train.columns,
        feature=column,
        num_grid_points=20,
    )
    fig, axes = pdp.pdp_plot(
        pdp_isolate_out,
        column,
        plot_lines=True,
        frac_to_plot=0.2,
        x_quantile=True,
        center=True,
        figsize=(7, 5),
    )
    # Rotate the x tick labels of the PDP axis by 45 degrees.
    axes["pdp_ax"].xaxis.set_tick_params(rotation=45)
    figure_saver.save_figure(fig, column, sub_directory="pdp")


# Set `n_jobs` on the model before it is scattered below.
rf.n_jobs = 8
# Single-feature trial run on the first 1000 training rows (disabled):
# save_pdp_plot_1d(rf, X_train[:1000], 'Dry Day Period -12 - 0 Month')

# Scatter the model and the training data once, broadcast to the workers, so
# that every task below receives the same futures as arguments.
rf_fut = client.scatter(rf, broadcast=True)
X_fut = client.scatter(X_train, broadcast=True)

# One plotting task per feature column.
pdp_fs = [
    client.submit(save_pdp_plot_1d, rf_fut, X_fut, column)
    #           for column in ['Dry Day Period -12 - 0 Month']
    for column in X_train.columns
]

# Track progress as tasks finish and report any task that failed.
for pdp_f in tqdm(
    dask.distributed.as_completed(pdp_fs),
    total=len(pdp_fs),
    unit="plot",
    desc="Calculating 1D PDP plots",
    smoothing=0,
    position=0,
):
    if pdp_f.status == "error":
        # `Future.result()` re-raises the task's exception, which would abort
        # this loop at the first failure without printing anything. Report the
        # exception and its traceback instead, then continue with the
        # remaining plots.
        print(pdp_f.key, "failed:", repr(pdp_f.exception()))
        print("".join(traceback.format_tb(pdp_f.traceback())))