In [95]:
from src.utils.config_loader import load_yaml_config
from src.tuning.param_grid import generate_param_grid
from pprint import pprint as pp

# --- Settings to edit for your own run ---
# Path of the YAML tuning config to read.
CONFIG_PATH = "../../src/tuning/configs/modset2.yaml"
# Name of the model whose params are under test (should match the label used in the config).
MODEL_NAME = "VAR"

config = load_yaml_config(CONFIG_PATH)
param_grid = generate_param_grid(MODEL_NAME, config)

# Pretty-print the generated grid so the candidate values can be checked by eye.
pp(param_grid)

{'maxlags': [5, 10, 15]}


In [96]:
# For testing only: this can be used to simplify the dictionary returned by `get_params()`
# so that the param grid has only one value for each parameter
def simplify_dict(d, index=0):
    """Collapse every list-valued entry of a param grid to one candidate.

    For testing only: afterwards the grid has a single value for each
    list-valued parameter.

    Parameters
    ----------
    d : dict
        Mapping of parameter name to candidate values. It is modified in
        place; entries whose value is not a list are left untouched, and
        so are empty lists (there is no candidate to keep).
    index : int, default 0
        Position of the candidate to keep from each non-empty list. Pass a
        different index to test another value of each parameter.

    Returns
    -------
    dict
        The same object as ``d``, with each non-empty list replaced by a
        one-element list.

    Raises
    ------
    IndexError
        If ``index`` is out of range for one of the non-empty lists.
    """
    for key, value in d.items():
        # Skip empty lists instead of failing on value[index].
        if isinstance(value, list) and value:
            try:
                chosen = value[index]
            except IndexError:
                raise IndexError(
                    f"index {index} is out of range for parameter {key!r} "
                    f"with {len(value)} candidate value(s)"
                ) from None
            # Re-assigning an existing key while iterating is safe: the set
            # of keys does not change.
            d[key] = [chosen]
    return d


# Use this version of the function instead when testing DartsXGBModel (uncomment it to replace the one above):
# def simplify_dict(d):
#     for key, value in d.items():
#         if isinstance(value, list):
#             if key == 'lags_past_covariates':
#                 # For lags_past_covariates: must be positive integer or list
#                 valid_value = next(
#                     (v for v in value if v is not None and (isinstance(v, int) or isinstance(v, list))),
#                     3  # Default to 3 if no valid value found
#                 )
#                 d[key] = [valid_value]
#             elif key == 'lags_future_covariates':
#                 # For lags_future_covariates: must be tuple or list
#                 valid_value = next(
#                     (v for v in value if v is not None and isinstance(v, list)),
#                     [3, 3]  # Default to [3, 3] if no valid value found
#                 )
#                 d[key] = [valid_value]
#             elif key == 'lags':
#                 # For lags: similar to lags_past_covariates
#                 valid_value = next(
#                     (v for v in value if v is not None and (isinstance(v, int) or isinstance(v, list))),
#                     3  # Default to 3 if no valid value found
#                 )
#                 d[key] = [valid_value]
#             elif key == 'add_encoders':
#                 # For add_encoders: take the first non-None dictionary value
#                 valid_value = next(
#                     (v for v in value if v is not None and isinstance(v, dict)),
#                     {"cyclic": {"future": ["month"]}}  # Default encoder if no valid value found
#                 )
#                 d[key] = [valid_value]
#             else:
#                 # For other parameters, take first non-None value if available
#                 d[key] = [value[1] if value[0] is None and len(value) > 1 else value[0]]
#     return d

In [97]:
# To speed up testing we will only test one value for each parameter.
# Note: `simplify_dict` edits the dict in place and `param_grid` is rebound to
# the result, so the full grid printed earlier is no longer available after
# this cell runs; every later cell sees the reduced grid.
param_grid = simplify_dict(param_grid)
print(param_grid)

{'maxlags': [5]}


In [98]:
from sktime.forecasting.model_selection import ForecastingGridSearchCV
from sktime.split import ExpandingSlidingWindowSplitter
from sktime.performance_metrics.forecasting import MeanSquaredError
from sktime.datasets import load_airline
import pandas as pd
import numpy as np
from sktime.forecasting.var import VAR  # Change this to the model you want to test

# Not the actual data we will use, just a placeholder for simple testing.
# The forecaster below is fit on a two-column frame, so a second series is
# built by adding Gaussian noise to the univariate airline data.
SEED = 42  # fixed so the synthetic series, and therefore the results, are reproducible
rng = np.random.default_rng(SEED)

y_airline = load_airline()  # univariate
y2 = y_airline + 10 * rng.standard_normal(len(y_airline))
y_multi = pd.DataFrame({"airline": y_airline, "airline2": y2})


# Forecast horizon: one to six steps ahead.
fh = [1, 2, 3, 4, 5, 6]
cv = ExpandingSlidingWindowSplitter(
    fh=fh, initial_window=12, step_length=12, max_expanding_window_length=24 * 12
)

# Specify the forecaster you generated the param grid for
forecaster = VAR()

gscv = ForecastingGridSearchCV(
    forecaster=forecaster,
    # Uses whatever `param_grid` currently holds; run the `simplify_dict` cell
    # first if only one set of values should be tested.
    param_grid=param_grid,
    cv=cv,
    scoring=MeanSquaredError(square_root=True),
    # Raise errors so we can see what params are causing errors
    error_score="raise",
)
gscv.fit(y_multi, fh=fh)
y_pred = gscv.predict(fh)

In [99]:
# Display the forecasts returned by `gscv.predict(fh)` in the previous cell.
y_pred

Unnamed: 0_level_0,airline,airline2
Period,Unnamed: 1_level_1,Unnamed: 2_level_1
1961-01,489.279425,486.546922
1961-02,505.165205,498.169038
1961-03,515.562224,513.889995
1961-04,498.213958,496.290932
1961-05,473.93355,473.563491
1961-06,462.10474,461.462882


In [100]:
# Show the parameter values the fitted grid search recorded as its best combination.
gscv.best_params_

{'maxlags': 5}