## Setup

In [None]:
from specific import *

### Retrieve previous results from the 'model' notebook

In [None]:
# Reload the train/test split (from `data_split_cache`) and the fitted model
# `rf` (via `get_model`) produced by the 'model' notebook.
X_train, X_test, y_train, y_test = data_split_cache.load()
rf = get_model()

#### Get Dask Client

In [None]:
# Obtain the Dask distributed client used for all cluster work below.
client = get_client()
# Alternative: start a local single-worker cluster that exposes the custom
# "threads" resource requested by the tasks in this notebook:
# client = Client(n_workers=1, threads_per_worker=8, resources={'threads': 8})
# NOTE(review): presumably makes the project's shared code path importable on
# the workers -- confirm against `add_common_path`.
add_common_path(client)
# Display the client as the cell output.
client

In [None]:
# Broadcast the model and the training data to every worker once, so that the
# many tasks submitted below can reference them through these futures.
rf_fut, X_fut = (client.scatter(obj, broadcast=True) for obj in (rf, X_train))

## ALE Plotting

### Worldwide

In [None]:
world_ale_1d_cache = SimpleCache("world_ale_1d", cache_dir=CACHE_DIR)

# Uncomment to discard the stored results and force a recomputation.
# world_ale_1d_cache.clear()


@world_ale_1d_cache
def get_world_ale_1d():
    """Compute the Monte-Carlo 1D ALE plot for every feature on the Dask cluster.

    One task per column of `X_train` runs `save_ale_plot_1d_with_ptp` on the
    scattered model (`rf_fut`) and training data (`X_fut`) with 1000
    Monte-Carlo replicates. Each task reserves `n_threads` units of the
    workers' custom "threads" resource and is given the same `n_jobs`.

    NOTE(review): `world_ale_1d_cache` presumably persists the returned tuple
    under CACHE_DIR, so the computation only reruns after `.clear()`.

    Returns
    -------
    ptp_values : dict
        Column name -> first value returned by `save_ale_plot_1d_with_ptp`.
    mc_ptp_values : dict
        Column name -> second value returned by `save_ale_plot_1d_with_ptp`.

    Raises
    ------
    Exception
        Whatever the first failed task (in column order) raised. Every failure
        is printed by column name before this happens, once all tasks finish.
    """
    n_threads = 8
    columns = list(X_train.columns)
    ale_fs = [
        client.submit(
            add_common_path_deco(save_ale_plot_1d_with_ptp),
            model=rf_fut,
            X_train=X_fut,
            column=column,
            n_jobs=n_threads,
            monte_carlo_rep=1000,
            resources={"threads": n_threads},
            pure=False,
        )
        for column in columns
    ]
    # Map each future back to its column so failures can be reported by name
    # (pure=False gives every task its own key).
    future_columns = dict(zip(ale_fs, columns))

    for ale_f in tqdm(
        dask.distributed.as_completed(ale_fs),
        total=len(ale_fs),
        unit="plot",
        desc="Calculating 1D ALE plots",
        smoothing=0,
    ):
        if ale_f.status == "error":
            # Report rather than call `.result()`: that re-raises immediately,
            # so nothing would be printed and later failures would stay hidden.
            print(f"ALE failed for {future_columns[ale_f]!r}: {ale_f.exception()!r}")

    ptp_values = {}
    mc_ptp_values = {}
    # `.result()` re-raises the first failure (in column order) with its
    # original exception type; otherwise it unpacks the two returned values.
    for column, ale_f in zip(columns, ale_fs):
        ptp_values[column], mc_ptp_values[column] = ale_f.result()
    return ptp_values, mc_ptp_values


ptp_values, mc_ptp_values = get_world_ale_1d()

### Non-Monte-Carlo ALE runs (computed manually, only to produce the plots)

In [None]:
# Re-run the 1D ALE computation for every feature without Monte-Carlo
# replicates, purely to produce the plots: nothing is cached and the task
# results are not collected.
n_threads = 8
ale_fs = [
    client.submit(
        add_common_path_deco(save_ale_plot_1d_with_ptp),
        model=rf_fut,
        X_train=X_fut,
        column=column,
        n_jobs=n_threads,
        monte_carlo=False,
        resources={"threads": n_threads},
        pure=False,
    )
    for column in X_train.columns
]
# Map each future back to its column so failures can be reported by name.
ale_columns = dict(zip(ale_fs, X_train.columns))

for ale_f in tqdm(
    dask.distributed.as_completed(ale_fs),
    total=len(ale_fs),
    unit="plot",
    desc="Calculating 1D Non-MC ALE plots",
    smoothing=0,
):
    if ale_f.status == "error":
        # Report and carry on: `.result()` here would re-raise immediately, so
        # nothing would be printed and progress/failure reporting would stop.
        print(f"Non-MC ALE failed for {ale_columns[ale_f]!r}: {ale_f.exception()!r}")

## PDP Plotting

### Worldwide

In [None]:
# Time a single 1D PDP computation for one feature in this process.
%time save_pdp_plot_1d(rf, X_train, 'Dry Day Period -12 - 0 Month', n_jobs=32)

### Locally, one feature at a time

In [None]:
# Compute every 1D PDP plot in this process, one feature at a time. Each call
# returns its figures first; close them straight away so open matplotlib
# figures do not accumulate over the loop.
pdp_progress = tqdm(
    X_train.columns,
    unit="plot",
    desc="Calculating 1D PDP plots",
    smoothing=0.05,
)
for column in pdp_progress:
    figs, pdp_isolate_out, data_file = save_pdp_plot_1d(
        rf, X_train, column, n_jobs=32
    )
    for fig in figs:
        plt.close(fig)

### Using a Dask distributed Cluster

In [None]:
%%time
n_threads = 8
pdp_fs = [
    client.submit(
        add_common_path_deco(save_pdp_plot_1d),
        model=rf_fut,
        X_train=X_fut,
        column=column,
        n_jobs=n_threads,
        resources={"threads": n_threads},
        pure=False,
    )
    for column in X_train.columns
]

for pdp_f in tqdm(
    dask.distributed.as_completed(pdp_fs),
    total=len(pdp_fs),
    unit="plot",
    desc="Calculating 1D PDP plots",
    smoothing=0.04,
):
    if pdp_f.status == "error":
        print(pdp_f.result())

## Combining Multiple ALE plots

In [None]:
short_X_train = shorten_columns(X_train)
short_X_train.columns = repl_fill_names(short_X_train.columns)

# Keep only the variables whose name occurs in more than two (shortened)
# column names -- presumably those with several shifted versions to compare.
candidate_features = ("Dry Days", "SIF", "FAPAR", "LAI", "VOD")
multi_shift_features = [
    name
    for name in candidate_features
    if sum(name in c for c in short_X_train.columns) > 2
]

for feature in tqdm(multi_shift_features, desc="Multiple shift ALE plots"):
    # One ALE plot per variable, overlaying its columns with `get_lag(c) <= 9`.
    shift_columns = [
        c for c in short_X_train.columns if feature in c and get_lag(c) <= 9
    ]
    plot_name = feature.replace(" ", "_").lower() + "_ale_shifts"
    multi_ale_plot_1d(
        rf,
        short_X_train,
        shift_columns,
        plot_name,
        n_jobs=get_ncpus(),
        verbose=False,
        xlabel=feature,
        # Optional title (currently disabled):
        # title=f"First-order ALE for {feature}",
        figure_saver=figure_saver,
        CACHE_DIR=CACHE_DIR,
    )

In [None]:
short_X_train = shorten_columns(X_train)
short_X_train.columns = repl_fill_names(short_X_train.columns)

# Two-panel figure: shifted-ALE curves for FAPAR (a) and Dry Days (b).
fig, axes = plt.subplots(1, 2, figsize=(9.1, 3.4))
features = ("FAPAR", "Dry Days")
panel_labels = ("(a)", "(b)")

for feature, ax, title in zip(features, axes, panel_labels):
    lagged_columns = [
        c for c in short_X_train.columns if feature in c and get_lag(c) <= 9
    ]
    multi_ale_plot_1d(
        rf,
        short_X_train,
        lagged_columns,
        fig=fig,
        ax=ax,
        n_jobs=get_ncpus(),
        verbose=False,
        xlabel=feature,
        # Optional title (currently disabled):
        # title=f"First-order ALE for {feature}",
        CACHE_DIR=CACHE_DIR,
        # Not saved per panel; the combined figure is saved once below.
        figure_saver=None,
        x_rotation=30,
    )
    # Panel label placed just above and to the left of each subplot's corner.
    ax.text(-0.09, 1.05, title, transform=ax.transAxes, fontsize=11)

# Only the left panel gets a y-axis label.
axes[0].set_ylabel("ALE")
fig.tight_layout()

combined_name = "__".join(features).replace(" ", "_").lower() + "_ale_shifts"
figure_saver.save_figure(fig, combined_name, sub_directory="multi_ale")