In [None]:
#| default_exp common._base_auto

In [None]:
#| hide
# Development convenience: re-import edited modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

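# Base AutoModel class

`BaseAuto` wraps a model class with Ray Tune hyperparameter search. `fit` trains one model per sampled configuration and picks the configuration with the lowest validation loss. It then refits the model with that configuration, and `predict` forecasts with the refitted model.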

In [None]:
#| hide
from fastcore.test import test_eq
from nbdev.showdoc import show_doc

In [None]:
#| export
from copy import deepcopy
from os import cpu_count

import torch
from pytorch_lightning.callbacks import TQDMProgressBar
from ray import air, tune
from ray.tune.integration.pytorch_lightning import TuneReportCallback
from ray.tune.search.basic_variant import BasicVariantGenerator

from neuralforecast.losses.pytorch import MAE

In [None]:
#| exporti
def train_tune(config_step, cls_model, dataset, val_size, test_size):
    """Ray Tune trainable: fit a single model built from one sampled config.

    Parameters
    ----------
    config_step : dict
        Hyperparameters sampled by Tune for this trial, passed as keyword
        arguments to `cls_model`. An optional 'callbacks' entry is appended
        after the built-in progress-bar and reporting callbacks.
    cls_model : class
        Model class to instantiate.
    dataset :
        Dataset forwarded to `model.fit`.
    val_size, test_size : int
        Split sizes forwarded to `model.fit`.

    The model's 'ptl/val_loss' metric is reported to Tune under the name
    'loss' at the end of every validation run. Returns None.
    """
    progress_bar = TQDMProgressBar()
    reporter = TuneReportCallback({"loss": "ptl/val_loss"}, on="validation_end")
    trial_callbacks = [progress_bar, reporter]
    trial_callbacks.extend(config_step.get('callbacks', []))

    # Build a new dict rather than mutating the trial config handed to us by Tune.
    model_kwargs = dict(config_step, callbacks=trial_callbacks)
    model = cls_model(**model_kwargs)
    model.fit(dataset, val_size=val_size, test_size=test_size)

In [None]:
#| exporti
def tune_model(
        cls_model, 
        dataset, 
        val_size, 
        test_size,
        cpus,
        gpus,
        verbose,
        num_samples, 
        search_alg, 
        config
    ):
    """Run a Ray Tune search that minimises the validation loss of `cls_model`.

    Each trial runs `train_tune` with a configuration sampled from `config`
    (the Tune `param_space`) and reports its validation loss as "loss".

    Parameters
    ----------
    cls_model : class
        Model class trained in every trial.
    dataset :
        Dataset passed unchanged to every trial.
    val_size, test_size : int
        Split sizes used by every trial's `fit`.
    cpus, gpus : int
        Per-trial resources. When `gpus > 0` only GPUs are requested and
        `cpus` is not included in the request; otherwise `cpus` CPUs.
    verbose :
        Verbosity forwarded to `air.RunConfig`.
    num_samples : int
        Number of trials to run.
    search_alg :
        Tune search algorithm used to sample configurations.
    config : dict
        Search space / fixed hyperparameters.

    Returns
    -------
    The results object returned by `tuner.fit()`; callers use its
    `get_best_result()` to retrieve the winning configuration.
    """
    trainable = tune.with_parameters(
        train_tune,
        cls_model=cls_model,
        dataset=dataset,
        val_size=val_size,
        test_size=test_size,
    )

    # NOTE(review): with GPUs available, `cpus` is not part of the request.
    resources = {'gpu': gpus} if gpus > 0 else {'cpu': cpus}
    resourced_trainable = tune.with_resources(trainable, resources)

    run_config = air.RunConfig(verbose=verbose)
    tune_config = tune.TuneConfig(
        metric="loss",
        mode="min",
        num_samples=num_samples,
        search_alg=search_alg,
    )
    tuner = tune.Tuner(
        resourced_trainable,
        run_config=run_config,
        tune_config=tune_config,
        param_space=config,
    )
    return tuner.fit()

In [None]:
#| export
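# NOTE(review): this exported cell contains no code — the "if None, overwrite with default" logic it describes is not implemented here. TODO: implement it or remove the cell.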


In [None]:
#| export
class BaseAuto:
    """Hyperparameter-tuned wrapper around a forecasting model class.

    `fit` runs a Ray Tune search over `config`: every trial trains
    `cls_model` on `dataset` and reports its validation loss. `fit` then
    retrains `cls_model` with the best configuration found and stores it as
    `self.model`. `predict` delegates to that refitted model.

    Parameters
    ----------
    cls_model : class
        Model class, instantiated as `cls_model(**config)`.
    h : int
        Forecast horizon. Injected into the search space as `config['h']`
        and used as the validation size when `fit` gets `val_size <= 0`.
    config : dict
        Ray Tune search space / fixed hyperparameters. The caller's dict is
        not modified; a shallow copy (with 'h' set) is stored.
    search_alg :
        Ray Tune search algorithm. It is deep-copied on every `fit`, so the
        shared default instance is never used directly.
    num_samples : int
        Number of Tune trials.
    cpus, gpus : int
        Per-trial resources: GPUs are requested when `gpus > 0`, otherwise
        `cpus` CPUs. Note that the defaults are evaluated once, when the class
        is defined.
    refit_wo_val : bool
        If True, the final refit is done with `val_size=0`.
    verbose :
        Verbosity forwarded to Ray's run configuration.
    """

    def __init__(self, 
                 cls_model,
                 h,
                 config, 
                 search_alg=BasicVariantGenerator(random_state=1),
                 num_samples=10,
                 cpus=cpu_count(),
                 gpus=torch.cuda.device_count(),
                 refit_wo_val=False,
                 verbose=False):
        # Copy instead of writing 'h' into the caller's dict, so a config
        # object reused for another model is left untouched.
        self.config = {**config, 'h': h}
        self.cls_model = cls_model
        self.h = h
        self.num_samples = num_samples
        self.search_alg = search_alg
        self.cpus = cpus
        self.gpus = gpus
        self.refit_wo_val = refit_wo_val
        self.verbose = verbose
        # Only build the default MAE loss when the config does not supply one.
        self.loss = self.config['loss'] if 'loss' in self.config else MAE()

    def fit(self, dataset, val_size=0, test_size=0):
        """Tune hyperparameters on `dataset`, then refit the best model.

        Parameters
        ----------
        dataset :
            Training dataset passed to every trial and to the final refit.
        val_size : int
            Validation window size. Model selection needs a validation split,
            so a non-positive value is replaced by the horizon `h`.
        test_size : int
            Test window size forwarded to every `fit` call.

        Sets `self.model` (refitted best model) and `self.results` (Tune results).
        """
        val_size = val_size if val_size > 0 else self.h
        # Work on a fresh copy so the stored search algorithm (possibly the
        # shared default) is not the object handed to Tune.
        search_alg = deepcopy(self.search_alg)
        results = tune_model(
            cls_model=self.cls_model,
            dataset=dataset,
            val_size=val_size, 
            test_size=test_size, 
            cpus=self.cpus,
            gpus=self.gpus,
            verbose=self.verbose,
            num_samples=self.num_samples, 
            search_alg=search_alg, 
            config=self.config
        )
        best_config = results.get_best_result().config
        self.model = self.cls_model(**best_config)
        # With refit_wo_val the final model is trained without a validation split.
        refit_val_size = 0 if self.refit_wo_val else val_size
        self.model.fit(
            dataset=dataset, 
            val_size=refit_val_size, 
            test_size=test_size,
        )
        self.results = results

    def predict(self, dataset, step_size=1, **data_kwargs):
        """Forecast `dataset` with the model refitted by `fit`.

        `step_size` and any extra keyword arguments are forwarded to the
        underlying model's `predict`.

        Raises
        ------
        AttributeError
            If called before `fit`.
        """
        if not hasattr(self, 'model'):
            raise AttributeError(
                f"{type(self).__name__}.predict called before fit; call fit(dataset) first."
            )
        return self.model.predict(dataset=dataset, 
                                  step_size=step_size, **data_kwargs)

In [None]:
#| hide
import logging
import warnings
# Silence all Python warnings, including any emitted while importing below,
# so the demo output stays readable.
warnings.filterwarnings("ignore")

import pytorch_lightning as pl
# Only show pytorch_lightning log records at ERROR level or above.
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Re-applied after importing pytorch_lightning, presumably because the import can
# install its own warning filters ahead of ours.
# NOTE(review): confirm this, and drop one of the two calls if it is redundant.
warnings.filterwarnings("ignore")

In [None]:
#| hide
import pandas as pd
from neuralforecast.models.mlp import MLP
from neuralforecast.utils import AirPassengersDF as Y_df
from neuralforecast.tsdataset import TimeSeriesDataset

# Train on observations up to 1959-12-31 and hold out the rest for testing.
# `.copy()` makes both frames independent of Y_df, so adding a forecast column
# to Y_test_df later does not write into a view (SettingWithCopyWarning).
Y_train_df = Y_df[Y_df.ds<='1959-12-31'].copy() # 132 train
Y_test_df = Y_df[Y_df.ds>'1959-12-31'].copy()   # 12 test

dataset, *_ = TimeSeriesDataset.from_df(Y_train_df)
# `tune.choice` entries are sampled per trial; plain values are fixed.
# The horizon 'h' is injected by BaseAuto from its `h` argument below.
config = {
    "hidden_size": tune.choice([512]),
    "num_layers": tune.choice([3, 4]),
    "input_size": 12,
    "max_epochs": 10
}
auto = BaseAuto(h=12, cls_model=MLP, config=config, num_samples=2, cpus=1, gpus=0)
auto.fit(dataset=dataset)
y_hat = auto.predict(dataset=dataset)

In [None]:
#| hide
# `.assign` returns a new frame, avoiding a column write into a slice of Y_df.
Y_test_df = Y_test_df.assign(AutoMLP=y_hat)

(
    pd.concat([Y_train_df, Y_test_df])
    .drop('unique_id', axis=1)
    .set_index('ds')
    .plot(title='AirPassengers: actuals and AutoMLP forecast')
);