Skip to content

Commit

Permalink
[timeseries] Fix loading of Tabular models failing if predictor moved…
Browse files Browse the repository at this point in the history
… to a different directory (#4171)
  • Loading branch information
shchur committed May 7, 2024
1 parent 964e114 commit 309a58e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
Expand Up @@ -88,6 +88,27 @@ def __init__(
self._train_target_median: Optional[float] = None
self._non_boolean_real_covariates: List[str] = []

@property
def tabular_predictor_path(self) -> str:
    """Return the path of the ``tabular_predictor`` subdirectory under ``self.path``."""
    predictor_dir_name = "tabular_predictor"
    return os.path.join(self.path, predictor_dir_name)

def save(self, path: Optional[str] = None, verbose: bool = True) -> str:
    """Save the model with the TabularPredictor temporarily detached.

    The predictor held by the ``"mean"`` model is set to ``None`` while the
    parent ``save`` runs and is put back afterwards, so the in-memory model
    is left unchanged by saving.

    Parameters
    ----------
    path
        Forwarded unchanged to the parent ``save``.
    verbose
        Forwarded unchanged to the parent ``save``.

    Returns
    -------
    str
        The value returned by the parent ``save``.

    Raises
    ------
    AssertionError
        If no ``"mean"`` model is present in ``self._mlf.models_``.
    """
    assert "mean" in self._mlf.models_, "TabularPredictor must be trained before saving"
    tabular_predictor = self._mlf.models_["mean"].predictor
    # NOTE(review): presumably the predictor is persisted separately under
    # `tabular_predictor_path` and re-attached in `load` -- confirm.
    self._mlf.models_["mean"].predictor = None
    try:
        return super().save(path=path, verbose=verbose)
    finally:
        # Restore the predictor even if the parent save raises; otherwise the
        # live model would be left without a predictor.
        self._mlf.models_["mean"].predictor = tabular_predictor

@classmethod
def load(
    cls, path: str, reset_paths: bool = True, load_oof: bool = False, verbose: bool = True
) -> "AbstractTimeSeriesModel":
    """Load the model via the parent ``load`` and re-attach its TabularPredictor.

    All arguments are forwarded unchanged to the parent ``load``. The
    TabularPredictor is then loaded from the loaded model's
    ``tabular_predictor_path`` and assigned to the ``"mean"`` model.

    Raises
    ------
    AssertionError
        If the loaded model has no ``"mean"`` entry in ``_mlf.models_``.
    """
    loaded = super().load(
        path=path,
        reset_paths=reset_paths,
        load_oof=load_oof,
        verbose=verbose,
    )
    fitted_models = loaded._mlf.models_
    assert "mean" in fitted_models, "Loaded model doesn't have a trained TabularPredictor"
    # The predictor location is derived from the loaded model's own path.
    predictor_dir = loaded.tabular_predictor_path
    fitted_models["mean"].predictor = TabularPredictor.load(predictor_dir)
    return loaded

def preprocess(self, data: TimeSeriesDataFrame, is_train: bool = False, **kwargs) -> Any:
if is_train:
# All-NaN series are removed; partially-NaN series in train_data are handled inside _generate_train_val_dfs
Expand Down Expand Up @@ -295,7 +316,7 @@ def _fit(

estimator = TabularEstimator(
predictor_init_kwargs={
"path": os.path.join(self.path, "tabular_predictor"),
"path": self.tabular_predictor_path,
"verbosity": verbosity - 2,
"label": MLF_TARGET,
**self._get_extra_tabular_init_kwargs(),
Expand Down
25 changes: 25 additions & 0 deletions timeseries/tests/unittests/models/test_mlforecast.py
@@ -1,3 +1,6 @@
import os
import shutil
import tempfile
from unittest import mock

import numpy as np
Expand Down Expand Up @@ -252,3 +255,25 @@ def test_given_train_data_has_nans_when_fit_called_then_nan_rows_removed_from_tr
model.fit(train_data=data)
train_df, val_df = model._generate_train_val_dfs(model.preprocess(data, is_train=True))
assert len(train_df) + len(val_df) == len(data.dropna())


@pytest.mark.parametrize("model_type", TESTABLE_MODELS)
@pytest.mark.parametrize("eval_metric", ["WAPE", "WQL"])
def test_when_trained_model_moved_to_different_folder_then_loaded_model_can_predict(model_type, eval_metric):
    """Fit and save a model, move its directory elsewhere, then check the reloaded model predicts."""
    data = DUMMY_TS_DATAFRAME.copy().sort_index()
    # Context managers remove both temporary directories when the test ends,
    # including on failure; bare mkdtemp() calls would leave them on disk.
    with tempfile.TemporaryDirectory() as old_model_dir, tempfile.TemporaryDirectory() as new_model_dir:
        model = model_type(
            path=old_model_dir,
            freq=data.freq,
            eval_metric=eval_metric,
            quantile_levels=[0.1, 0.5, 0.9],
            prediction_length=3,
            hyperparameters={"differences": []},
        )
        model.fit(train_data=data)
        model.save()
        # new_model_dir already exists, so shutil.move places the model directory inside it.
        shutil.move(model.path, new_model_dir)
        loaded_model = model_type.load(os.path.join(new_model_dir, model.name))
        predictions = loaded_model.predict(data)
    assert isinstance(predictions, TimeSeriesDataFrame)

0 comments on commit 309a58e

Please sign in to comment.