Skip to content

Commit

Permalink
fix issue 3899, call self.save after update _time_fit_training, add t… (
Browse files Browse the repository at this point in the history
autogluon#3900)

Co-authored-by: Oleksandr Shchur <shchuro@amazon.com>
  • Loading branch information
2 people authored and LennartPurucker committed Jun 1, 2024
1 parent 0f82ad2 commit 75be8c4
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
9 changes: 9 additions & 0 deletions full_install.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

python -m pip install -e common/[tests]
python -m pip install -e core/[all,tests]
python -m pip install -e features/
python -m pip install -e tabular/[all,tests]
python -m pip install -e multimodal/[tests]
python -m pip install -e timeseries/[all,tests]
python -m pip install -e eda/
python -m pip install -e autogluon/
2 changes: 1 addition & 1 deletion timeseries/src/autogluon/timeseries/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def _fit(
excluded_model_types=kwargs.get("excluded_model_types"),
time_limit=time_limit,
)
self.save_trainer(trainer=self.trainer)

self._time_fit_training = time.time() - time_start
self.save()

def _align_covariates_with_forecast_index(
self,
Expand Down
37 changes: 26 additions & 11 deletions timeseries/tests/unittests/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,18 @@ def test_when_fit_summary_is_called_then_all_keys_and_models_are_included(
assert len(fit_summary[key]) == num_models


EXPECTED_INFO_KEYS = [
"path",
"version",
"time_fit_training",
"time_limit",
"best_model",
"best_model_score_val",
"num_models_trained",
"model_info",
]


@pytest.mark.parametrize(
"hyperparameters, num_models",
[
Expand All @@ -421,23 +433,26 @@ def test_when_fit_summary_is_called_then_all_keys_and_models_are_included(
def test_when_info_is_called_then_all_keys_and_models_are_included(temp_model_path, hyperparameters, num_models):
predictor = TimeSeriesPredictor(path=temp_model_path)
predictor.fit(DUMMY_TS_DATAFRAME, hyperparameters=hyperparameters)
expected_keys = [
"path",
"version",
"time_fit_training",
"time_limit",
"best_model",
"best_model_score_val",
"num_models_trained",
"model_info",
]
info = predictor.info()
for key in expected_keys:
for key in EXPECTED_INFO_KEYS:
assert key in info

assert len(info["model_info"]) == num_models


def test_when_predictor_is_loaded_then_info_works(temp_model_path):
predictor = TimeSeriesPredictor(path=temp_model_path, prediction_length=2)
predictor.fit(train_data=DUMMY_TS_DATAFRAME, hyperparameters=DUMMY_HYPERPARAMETERS)
predictor.save()
del predictor
predictor = TimeSeriesPredictor.load(temp_model_path)
info = predictor.info()
for key in EXPECTED_INFO_KEYS:
assert key in info

assert len(info["model_info"]) == len(DUMMY_HYPERPARAMETERS) + 1 # + 1 for ensemble


def test_when_train_data_contains_nans_then_predictor_can_fit(temp_model_path):
predictor = TimeSeriesPredictor(path=temp_model_path)
df = DUMMY_TS_DATAFRAME.copy()
Expand Down

0 comments on commit 75be8c4

Please sign in to comment.