## Setup

In [None]:
from specific import *

### Get shifted data

In [None]:
# Unpack the six objects returned by get_offset_data().
# `exog_data` and `endog_data` are used further down as the predictors and the
# target of the train/test split; the remaining four names are not referenced
# in the cells visible here.
(
    endog_data,
    exog_data,
    master_mask,
    filled_datasets,
    masked_datasets,
    land_mask,
) = get_offset_data()

In [None]:
# Obtain the client object that is later handed to the grid search.
# NOTE(review): presumably a Dask distributed client - confirm in get_client().
client = get_client()
# Bare expression: in a notebook this shows the client's repr as the cell output.
client

### Define the training and test data

In [None]:
@data_split_cache
def get_split_data():
    """Split the module-level predictors and target into train and test sets.

    `exog_data` is used as the predictor set and `endog_data` as the target.
    The split is shuffled with a fixed seed (`random_state=1`) and 30% of the
    samples are held out for testing.

    Returns:
        A 4-tuple `(X_train, X_test, y_train, y_test)`.

    NOTE(review): `data_split_cache` presumably caches this result between
    runs - confirm against its definition.
    """
    train_X, test_X, train_y, test_y = train_test_split(
        exog_data,
        endog_data,
        test_size=0.3,
        shuffle=True,
        random_state=1,
    )
    return train_X, test_X, train_y, test_y


X_train, X_test, y_train, y_test = get_split_data()

In [None]:
# Passed to the grid search below as its `n_splits` argument.
n_splits = 5

# Define the parameter space.
# Every hyperparameter has two candidate values.
parameters_RF = dict(
    n_estimators=[100, 600],
    max_depth=[8, 14],
    min_samples_split=[2, 10],
    min_samples_leaf=[1, 4],
    max_leaf_nodes=[1000, None],
)

# Keyword arguments used to construct the base random forest estimator; they
# are not part of the search space above.
# NOTE(review): scikit-learn deprecated and later removed
# `max_features="auto"`; confirm that the estimator used here still accepts it.
default_param_dict = dict(
    random_state=1,
    bootstrap=True,
    max_features="auto",
)

## Hyperparameter optimisation

In [None]:
@cross_val_cache
def run_cross_val():
    """Run the grid search over `parameters_RF` on the training data.

    A `DaskRandomForestRegressor` built from `default_param_dict` is passed to
    `fit_dask_rf_grid_search_cv` together with the training arrays, the number
    of splits, the parameter grid and the module-level `client`.

    Returns:
        The pair `(results, rf)` returned by `fit_dask_rf_grid_search_cv`.

    NOTE(review): with `refit=True`, `rf` is presumably the estimator refitted
    with the best parameters, and `cross_val_cache` presumably caches the
    outcome - confirm both against their definitions.
    """
    base_estimator = DaskRandomForestRegressor(**default_param_dict)
    cv_results, fitted_rf = fit_dask_rf_grid_search_cv(
        base_estimator,
        X_train.values,
        y_train.values,
        n_splits,
        parameters_RF,
        client,
        verbose=True,
        return_train_score=True,
        refit=True,
        local_n_jobs=30,
    )
    return cv_results, fitted_rf


results, rf = run_cross_val()

## Hyperparameter Search Visualisation

In [None]:
# Flatten the cross-validation results into column-wise lists with one row per
# individual score. `results` maps a tuple of (parameter, value) pairs to a
# mapping of score category -> sequence of scores.
hyperparams = defaultdict(list)

for param_tuples, param_results in results.items():
    # Track the number of rows added for this parameter combination
    # explicitly. Reading `scores` after the inner loop would raise a
    # NameError (first iteration) or silently reuse the previous
    # combination's length if `param_results` were empty.
    n_rows = 0
    for category, scores in param_results.items():
        hyperparams[category].extend(scores)
        n_rows = len(scores)

    # Repeat each parameter value so that it lines up with the scores above.
    for param, param_value in param_tuples:
        hyperparams[param].extend([param_value] * n_rows)

In [None]:
# Turn the column lists into a DataFrame with one row per individual score.
hyperparams = pd.DataFrame(hyperparams)
# Score column names, taken from the last entry iterated over in `results`.
score_keys = list(param_results)
# Every other column is a hyperparameter. Build the list in column order
# instead of via a set difference: set iteration order is arbitrary and may
# change between interpreter runs, which would make `param_keys[:2]` and the
# order of the groupby levels irreproducible.
param_keys = [column for column in hyperparams.columns if column not in score_keys]
# Replace missing values (such as the `None` candidate of `max_leaf_nodes`)
# with -1. pandas' groupby skips rows whose group key is NaN by default, so
# without this those parameter combinations would be left out of the means.
hyperparams.fillna(-1, inplace=True)

In [None]:
# Average every score column over the rows sharing a parameter combination.
means = hyperparams.groupby(param_keys).mean()
# Display the combination(s) whose mean "test_scores" equals the maximum.
means.loc[means["test_scores"] == means["test_scores"].max()]

In [None]:
# Box plots of every score column, grouped by the first two entries of
# `param_keys`.
hyperparams.boxplot(column=score_keys, by=param_keys[:2])

In [None]:
# Reshape to long format: the parameter columns are kept as identifiers, and
# each score column becomes rows with its name in "category" and its value in
# "score".
melted = hyperparams.melt(
    id_vars=param_keys,
    value_vars=score_keys,
    var_name="category",
    value_name="score",
)
melted

In [None]:
# One figure per hyperparameter: score distribution for each parameter value,
# split by score category, saved under the "hyperparameters" sub-directory.
for param_key in param_keys:
    fig = plt.figure(figsize=(7, 6))
    ax = sns.boxplot(data=melted, x=param_key, y="score", hue="category")
    ax.set_ylabel("R2 Score")
    figure_saver.save_figure(fig, param_key, sub_directory="hyperparameters")