## Setup

In [None]:
from specific import *

### Get shifted data

In [None]:
# Unpack the six objects returned by `get_offset_data()`. In the cells visible
# below only `exog_data` (used as the predictors) and `endog_data` (used as the
# target) are referenced; the masks and dataset collections are unpacked but
# not used in this part of the notebook.
(
    endog_data,
    exog_data,
    master_mask,
    filled_datasets,
    masked_datasets,
    land_mask,
) = get_offset_data()

In [None]:
# Client that is handed to `dask_fit_combinations` further down.
# NOTE(review): presumably a Dask distributed client - confirm against
# `get_client` in `specific`.
client = get_client()
# Bare expression so the notebook renders the client as the cell output.
client

### Define the training and test data

In [None]:
@data_split_cache
def get_split_data():
    """Split the predictors and the target into training and test subsets.

    `exog_data` is passed as the first array and `endog_data` as the second,
    with shuffling enabled, `test_size=0.3` and a fixed `random_state` of 1.

    NOTE(review): `data_split_cache` and `train_test_split` come from the star
    import; the decorator presumably caches the result - confirm.

    Returns:
        The tuple `(X_train, X_test, y_train, y_test)`.
    """
    split_options = {"random_state": 1, "shuffle": True, "test_size": 0.3}
    # `tuple(...)` keeps the return type a 4-tuple, as before.
    return tuple(train_test_split(exog_data, endog_data, **split_options))


# Bind the four split subsets at notebook level; they are passed to
# `dask_fit_combinations` in the fitting cell below.
X_train, X_test, y_train, y_test = get_split_data()

## Fit combinations

In [None]:
# Display the predictor column names; the feature names used in the cells
# below are asserted to be present in `exog_data`.
exog_data.columns

In [None]:
# Base names of the vegetation predictors.
veg_features = ["VOD Ku-band 3NN", "LAI 3NN", "SIF 3NN", "FAPAR 3NN"]
# Column-name suffixes: "" selects the plain column and " -N Month" the column
# carrying that suffix (presumably the N-month shifted variable).
shifts = ["", *[f" -{x} Month" for x in [1, 3, 6, 9]]]
# One inner list per suffix, holding every vegetation feature with that suffix.
veg_lags = [[veg_feature + shift for veg_feature in veg_features] for shift in shifts]
# Fail early, and name the offending columns, if any generated feature name is
# absent from the predictors.
missing_features = [
    feature for unpacked in veg_lags for feature in unpacked if feature not in exog_data
]
assert not missing_features, f"Features missing from exog_data: {missing_features}"
# Bare expression so the notebook displays the nested list.
veg_lags

In [None]:
# Candidate predictor sets: the ten fixed predictors followed by one feature
# drawn from each inner list of `veg_lags` (every such choice is enumerated).
combinations = [
    (
        "Dry Day Period",
        "Max Temp",
        "pftCrop",
        "popd",
        "Diurnal Temp Range",
        "Dry Day Period -3 Month",
        "AGB Tree",
        "Dry Day Period -1 Month",
        "SWI(1) 3NN",
        "Dry Day Period -9 Month",
    )
    + tuple(veg_choice)
    for veg_choice in product(*veg_lags)
]

# Ten fixed predictors plus one per entry of `veg_lags` must give 15 features.
assert not any(len(candidate) != 15 for candidate in combinations)

# Display how many combinations will be fitted.
len(combinations)

In [None]:
# Evaluate every feature combination with a random forest built from
# `param_dict`. The cells below read `scores` as a mapping whose values contain
# "train_score" and "test_score" entries, each holding "r2" and "mse".
# NOTE(review): presumably one model is fitted per combination and results are
# cached under `CACHE_DIR` - confirm against `dask_fit_combinations`.
scores = dask_fit_combinations(
    DaskRandomForestRegressor(**param_dict),
    X_train,
    y_train,
    X_test,
    y_test,
    client,
    combinations,
    # Two fewer than `get_ncpus()` reports, but never fewer than one.
    local_n_jobs=max(get_ncpus() - 2, 1),
    verbose=True,
    cache_dir=CACHE_DIR,
)

In [None]:
# Keys of `scores` in iteration order. Assuming `scores` is a dict-like
# mapping, position i here lines up with position i of the arrays built from
# `scores.values()` in the cells below.
keys = list(scores)

In [None]:
# Test-set R2 and MSE for every entry of `scores`, in iteration order.
r2_scores, mse_scores = (
    np.array([fit["test_score"][metric] for fit in scores.values()])
    for metric in ("r2", "mse")
)

# Both metrics have to single out the same best entry.
assert np.argmin(mse_scores) == np.argmax(r2_scores)

# Positions ordered from highest to lowest test R2.
indices = np.flip(np.argsort(r2_scores))
fig, ax = plt.subplots()
# Ten best entries: R2 on the left axis, MSE on a twinned right axis.
ax.plot(r2_scores[indices[:10]])
ax2 = ax.twinx()
ax2.plot(mse_scores[indices[:10]], c="C1")

# Display the best test R2 together with the key that achieved it.
r2_scores.max(), keys[r2_scores.argmax()]

In [None]:
# Training-set R2 and MSE for every entry of `scores`, in iteration order.
train_r2_scores, train_mse_scores = (
    np.array([fit["train_score"][metric] for fit in scores.values()])
    for metric in ("r2", "mse")
)

# Both metrics have to single out the same best entry.
assert np.argmin(train_mse_scores) == np.argmax(train_r2_scores)

# Positions ordered from highest to lowest training R2.
train_indices = np.flip(np.argsort(train_r2_scores))
fig, ax = plt.subplots()
# Ten best entries: R2 on the left axis, MSE on a twinned right axis.
ax.plot(train_r2_scores[train_indices[:10]])
ax2 = ax.twinx()
ax2.plot(train_mse_scores[train_indices[:10]], c="C1")

# Display the best training R2 together with the key that achieved it.
train_r2_scores.max(), keys[train_r2_scores.argmax()]