In [None]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
from wildfires.dask_cx1 import CachedResults, DaskRandomForestRegressor

from empirical_fire_modelling.configuration import CACHE_DIR, n_splits

#### Load cached results

In [None]:
# Open the cache of fitted-estimator results and collect the train and test
# scores of every hyperparameter combination found in it. `results` is
# iterated below as {param_tuples: {score_category: [score per split, ...]}}.
cached = CachedResults(
    estimator_class=DaskRandomForestRegressor,
    n_splits=n_splits,
    cache_dir=CACHE_DIR,
)
results = cached.collate_scores(train_scores=True)

In [None]:
# The combination `CachedResults` reports as best (the ranking criterion is
# defined by `get_best_params`).
best_params = cached.get_best_params()
best_params

#### Visualise the different hyperparameter combinations

In [None]:
# Collapse the per-split scores into one row per hyperparameter combination.
# Resulting columns: `<category>` (mean over the splits), `<category>_std`
# (standard deviation over the splits), plus one column per hyperparameter.
hyperparams = defaultdict(list)

for param_tuples, param_results in results.items():
    # A combination is only kept if *every* score category has a result for
    # each of the `n_splits` splits. Checking all categories before appending
    # anything guarantees that no partial row is added; a partial row would
    # leave the columns with different lengths and break `pd.DataFrame` below.
    incomplete = {
        category: len(scores)
        for category, scores in param_results.items()
        if len(scores) != n_splits
    }
    if incomplete:
        for category, n_scores in incomplete.items():
            print(param_tuples, category, n_scores)
        continue  # Do not append anything for this combination.

    for category, scores in param_results.items():
        hyperparams[category].append(np.mean(scores))
        hyperparams[category + "_std"].append(np.std(scores))
    for param, param_value in param_tuples:
        hyperparams[param].append(param_value)

In [None]:
hyperparams = pd.DataFrame(hyperparams)
# Score categories are exactly the columns that have a matching "<key>_std"
# column. Deriving them from the frame itself avoids depending on the loop
# variable `param_results` leaking out of the previous cell (which is only the
# last combination visited, and is undefined when `results` is empty).
score_keys = [
    key for key in hyperparams.columns if key + "_std" in hyperparams.columns
]
score_std_keys = [score_key + "_std" for score_key in score_keys]
# Sorted so that the order of the per-parameter plots below is reproducible;
# iteration order of a set of strings changes between interpreter runs.
param_keys = sorted(
    set(hyperparams.columns) - set(score_keys) - set(score_std_keys)
)
# Replace missing values with -1 so they can be grouped and plotted alongside
# numeric values (presumably unset parameters such as None — verify).
# Reassigned rather than modified in place so re-running the cell is safe.
hyperparams = hyperparams.fillna(-1)

In [None]:
# Restrict to combinations with a mean test R2 above this threshold and look
# at the gap between their mean train and mean test R2.
GAP_MIN_TEST_SCORE = 0.68

hyperparams_gap = hyperparams[hyperparams["test_score"] > GAP_MIN_TEST_SCORE].copy()
hyperparams_gap["gap"] = hyperparams_gap["train_score"] - hyperparams_gap["test_score"]
print("Nr. of params:", len(hyperparams_gap))

fig, ax = plt.subplots()
ax.plot(
    hyperparams_gap["test_score"], hyperparams_gap["gap"], linestyle="", marker="o"
)
ax.set(
    xlabel="test score (R2)",
    ylabel="R2 gap",
    title="R2 gap (train - test) vs. test score",
)
ax.grid(linestyle="--", alpha=0.4)
hyperparams_gap.sort_values(by="gap")

In [None]:
# The 20 combinations with the highest mean test score.
top_by_test_score = hyperparams.sort_values(by="test_score", ascending=False)
top_by_test_score.head(20)

In [None]:
# Distribution of every score category, grouped by the
# (min_samples_split, n_estimators) pair. Bound to a named variable so the
# Axes repr is not echoed and IPython's `_` output cache is left untouched.
score_boxplot_axes = hyperparams.boxplot(
    column=score_keys,
    by=["min_samples_split", "n_estimators"],
)

In [None]:
# Long-format table of the mean scores for combinations with test R2 > 0.63:
# one row per (combination, score category), ready for seaborn's `hue`.
well_performing = hyperparams[hyperparams["test_score"] > 0.63]
melted = well_performing.drop(columns=score_std_keys).melt(
    id_vars=param_keys,
    value_vars=score_keys,
    var_name="category",
    value_name="score",
)
melted

#### Visualise the effect of individual parameters

In [None]:
from alepython.ale import _sci_format

# One figure per hyperparameter: mean R2 score distribution for each of its
# values, split by score category (train / test).
for param_key in param_keys:
    fig, ax = plt.subplots(figsize=(9, 6))
    sns.boxplot(x=param_key, y="score", hue="category", data=melted, ax=ax)
    ax.set(ylabel="R2 Score")
    ax.grid(which="both", alpha=0.4, linestyle="--")

    if param_key != "ccp_alpha":
        continue

    # Re-render the categorical ccp_alpha tick labels through `_sci_format`
    # (presumably scientific notation) and rotate them to keep them legible.
    tick_values = np.array(
        [float(label.get_text()) for label in ax.xaxis.get_ticklabels()]
    )
    ax.xaxis.set_ticklabels(_sci_format(tick_values))
    ax.xaxis.set_tick_params(rotation=45)

#### Standard deviations of the scores across splits

In [None]:
# Long-format table of the per-split score standard deviations for
# combinations with test R2 > 0.63: one row per (combination, std category).
well_performing = hyperparams[hyperparams["test_score"] > 0.63]
melted_std = well_performing.drop(columns=score_keys).melt(
    id_vars=param_keys,
    value_vars=score_std_keys,
    var_name="category",
    value_name="score_std",
)
melted_std

In [None]:
from alepython.ale import _sci_format

# One figure per hyperparameter: distribution of the standard deviation of the
# R2 score across splits for each of its values, split by score category.
for param_key in param_keys:
    fig, ax = plt.subplots(figsize=(9, 6))
    sns.boxplot(x=param_key, y="score_std", hue="category", data=melted_std, ax=ax)
    # The plotted quantity is the std of the score, not the score itself
    # (the label was previously copy-pasted from the mean-score plots).
    ax.set(ylabel="R2 Score Std. Dev. (across splits)")
    ax.grid(which="both", alpha=0.4, linestyle="--")

    if param_key == "ccp_alpha":
        # Re-render the categorical ccp_alpha tick labels through `_sci_format`
        # (presumably scientific notation) and rotate them to keep them legible.
        tick_values = np.array(
            [float(label.get_text()) for label in ax.xaxis.get_ticklabels()]
        )
        ax.xaxis.set_ticklabels(_sci_format(tick_values))
        ax.xaxis.set_tick_params(rotation=45)