## Initialisation

In [None]:
from itertools import product
from tempfile import TemporaryDirectory
from time import time

import numpy as np
from dask.distributed import Client, as_completed
from joblib import parallel_backend
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold, cross_val_score
from tqdm import tqdm

from wildfires.dask_cx1 import get_client
from wildfires.logging_config import enable_logging

# Turn on debug-level log output through the project's logging configuration
# (`wildfires.logging_config.enable_logging`) for the rest of the notebook.
enable_logging(level="debug")


class Time:
    """Context manager that prints the wall-clock time spent in its block.

    The timer can optionally be bound with ``as`` so the measured duration
    stays available after the block has finished::

        with Time("fit") as timer:
            ...
        timer.elapsed  # seconds, as a float

    Args:
        name: Label inserted into the printed message.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, name=""):
        self.name = name
        # Timestamp (from `time()`) taken on entry; None until entered.
        self.start = None
        # Seconds between entry and exit; None until the block has exited.
        self.elapsed = None

    def __enter__(self):
        """Record the start time and return the timer itself."""
        self.start = time()
        # Returning self makes `with Time(...) as t` bind the timer, not None.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Store and print the elapsed time.

        Nothing is returned, so an exception raised inside the block is
        reported after the timing message and then propagates unchanged.
        """
        self.elapsed = time() - self.start
        print(f"Time taken for {self.name}: {self.elapsed}")

## Create a LocalCluster for demonstration purposes

In [None]:
# Number of jobs used for scoring that happens locally.
local_n_jobs = 1

# Demonstration set-up: a single worker with a few threads. The bare `client`
# expression on the last line lets the notebook display the client summary.
threads_per_worker = 3
client = Client(
    n_workers=1,
    threads_per_worker=threads_per_worker,
)
client

### Or use an existing distributed cluster

In [None]:
# Used for local scoring.
# NOTE(review): 32 presumably matches the cores available on the machine that
# runs this notebook alongside the distributed cluster — confirm before reuse.
local_n_jobs = 32

# Obtain the client from the project's `wildfires.dask_cx1.get_client` helper
# rather than starting a new cluster here. This rebinds `client`, so it
# replaces the demonstration client if that cell was also executed.
client = get_client()
client

### Define Common Parameters and Data

In [None]:
# Define the common training and test data.
# Seeding NumPy's global RNG makes X and y identical on every run.
np.random.seed(1)
X = np.random.random((int(1e3), 40))
# The target depends on the first two feature columns only, plus uniform noise.
y = X[:, 0] + X[:, 1] + np.random.random((X.shape[0],))

# Define the number of splits.
n_splits = 5
kf = KFold(n_splits=n_splits)

# Define the parameter space.
# Each key maps to the list of candidate values to search over.
parameters_RF = {
    "n_estimators": [50],
    "max_depth": [6, 9, 12],
    "min_samples_split": [2],
    "min_samples_leaf": [1, 5, 10],
}

# Parameters passed to the estimator constructor, independent of the search.
default_param_dict = {
    "random_state": 1,
    "bootstrap": True,
    # Consider all features at every split. This is what "auto" meant for
    # forest regressors; the string was deprecated in scikit-learn 1.1 and
    # removed in 1.3, whereas the fraction 1.0 is accepted by old and new
    # releases alike.
    "max_features": 1.0,
}

## Cached Timed Dask RF Grid Search

In [None]:
from wildfires.dask_cx1 import (
    DaskRandomForestRegressor,
    fit_dask_sub_est_random_search_cv,
)

# Time three consecutive searches that all share one temporary cache directory.
# NOTE(review): presumably the runs after the first can reuse results cached in
# `tempdir` — confirm against `fit_dask_sub_est_random_search_cv`.
with Time("Custom Dask random grid search"), TemporaryDirectory() as tempdir:
    for _ in range(3):
        results, fit_est = fit_dask_sub_est_random_search_cv(
            DaskRandomForestRegressor(**default_param_dict),
            X,
            y,
            parameters_RF,
            client,
            cache_dir=tempdir,
            local_n_jobs=local_n_jobs,
            max_time="6s",
            n_iter=None,
            n_splits=n_splits,
            random_state=0,
            refit=True,
            return_train_score=True,
            verbose=True,
        )
        # Report how many results this repetition produced.
        print("Nr. of results:", len(results))