# HParam Search Example

In [None]:
# Move up one directory. Presumably this puts us at the repository root so
# that the `src.` imports in the next cell resolve; this assumes the notebook
# lives one level below the root (TODO confirm).
%cd ..

In [None]:
from darts.metrics import mae
from src.pipeline.pipeline import ExperimentPipeline
from darts.models import RNNModel
from src.pipeline.experiment import Experiment, HyperParameter, BayesOptHyperParameter
from darts.dataprocessing.transformers import Scaler
from darts.dataprocessing import Pipeline
from darts.models.forecasting.forecasting_model import LocalForecastingModel

We first define our dataset, the forecasting algorithm (a darts model class) and the preprocessing pipeline.

In [None]:
# Dataset to run the experiment on.
# Available choices: 'exchange_rate', 'traffic', 'electricity'.
# NOTE(review): the original note says "only uses first covariate" — presumably
# only the first component of a multivariate series is used; confirm in the
# pipeline code.
dataset = "traffic"

In [None]:
# The darts forecasting model *class* (not an instance); it is handed to the
# Experiment below together with its hyperparameters.
model: type = RNNModel

In [None]:
# Preprocessing for the series: a darts Pipeline holding a single Scaler
# with its default settings.
preprocessing = Pipeline([Scaler()])

Then we define our hyperparameters, which can either be fixed with `HyperParameter` or sampled from a distribution with `BayesOptHyperParameter`.

For more information on which sampling methods are available, see the Optuna `Trial` documentation [here](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html).

In [None]:
# Hyperparameters forwarded to the model. Every entry here is fixed; to let
# the search tune one, replace its HyperParameter with a BayesOptHyperParameter.
hparams = [
    HyperParameter(name='model', value='LSTM'),
    HyperParameter(name='hidden_dim', value=20),
    # 24 * 7 steps: presumably one week if the data is hourly — confirm for
    # the chosen dataset.
    HyperParameter(name='input_chunk_length', value=24 * 7),
    HyperParameter(name='n_epochs', value=10),
    HyperParameter(name='optimizer_kwargs', value={"lr": 1e-3}),
]

Once this is done, we can define our experiment.

In [None]:
# Bundle data, preprocessing, model class, hyperparameters and evaluation
# settings into a single Experiment specification.
params = Experiment(
    dataset=dataset,
    preprocessing=preprocessing,
    model=model,
    hyper_parameters=hparams,
    metric=mae,
    # forecast horizon, in time steps
    horizon=24 * 7,
    # time budget for the Optuna hyperparameter search, in seconds
    optuna_timeout=60,
    # number of validation samples for the backtest, i.e. len(valid_dataset)
    n_backtest=100,
    # number of samples used by .fit(), i.e. len(train_dataset)
    n_train_samples=500,
)

Then we run the experiment and keep its logs in the notebook output.

In [None]:
# Training produces a flood of UserWarnings, so silence that category
# for the rest of the session.
import warnings

warnings.simplefilter(action="ignore", category=UserWarning)

In [None]:
# Build the experiment pipeline from the specification above and execute it.
# run() is the cell's last expression, so its return value (if any) is shown.
pipeline: ExperimentPipeline = ExperimentPipeline(params)
pipeline.run()

In [None]:
# Retrain a final model on train + valid, forecast `horizon` steps past the end,
# and overlay the forecast on the first `horizon` steps of the test split.
#
# NOTE(review): the hyperparameters below are copied by hand from `hparams`;
# keep the two in sync. Values sampled by a BayesOptHyperParameter during the
# search are NOT picked up here.
# NOTE(review): this rebinds `model` from the RNNModel class (used by the
# Experiment cell) to a fitted instance, so re-running the Experiment cell after
# this one would pass an instance instead of the class.
horizon = 24 * 7  # same value as the Experiment's `horizon`

series = pipeline.data['train'].append(pipeline.data['valid'])

final_hparams = {
    'model': 'LSTM',
    'hidden_dim': 20,
    'input_chunk_length': 24 * 7,
    'n_epochs': 10,
    'optimizer_kwargs': {"lr": 1e-3},
}
model = RNNModel(**final_hparams)

model.fit(series)

preds = model.predict(n=horizon)
preds.plot(label='prediction')

pipeline.data['test'][:horizon].plot(label='truth')

In [None]:
# Full backtest on the test data. The train + valid `series` is extended with
# the test split, and forecasting is intended to start at the first test point
# (`start=len(series)`); forecasts are scored with `mae`.
# `retrain` is True only for darts LocalForecastingModel instances, which must
# be refit for each forecast; otherwise the already-fitted `model` is reused.
# NOTE(review): the `+ 1e-9` offset looks like a guard against zeros for
# percentage-based metrics; its effect on MAE is negligible — confirm intent.
model.backtest(
    metric=mae,
    retrain=isinstance(model, LocalForecastingModel),
    forecast_horizon=24 * 7,
    start=len(series),
    series=series.append(pipeline.data['test']) + 1e-9,
)