Skip to content

Commit

Permalink
[Rebase]
Browse files Browse the repository at this point in the history
  • Loading branch information
franchuterivera committed May 3, 2021
1 parent c45087b commit 452c629
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions test/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):

# Internal dataset has expected settings
assert estimator.dataset.task_type == 'tabular_classification'
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 3
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 5
assert estimator.resampling_strategy == resampling_strategy
assert estimator.dataset.resampling_strategy == resampling_strategy
assert len(estimator.dataset.splits) == expected_num_splits
Expand Down Expand Up @@ -140,7 +140,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
estimator.seed, successful_num_run, run_key.budget)
assert isinstance(model, VotingClassifier)
assert len(model.estimators_) == 3
assert len(model.estimators_) == 5
assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
torch.nn.Module)
else:
Expand Down Expand Up @@ -243,7 +243,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):

# Internal dataset has expected settings
assert estimator.dataset.task_type == 'tabular_regression'
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 3
expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 5
assert estimator.resampling_strategy == resampling_strategy
assert estimator.dataset.resampling_strategy == resampling_strategy
assert len(estimator.dataset.splits) == expected_num_splits
Expand Down Expand Up @@ -310,7 +310,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
estimator.seed, successful_num_run, run_key.budget)
assert isinstance(model, VotingRegressor)
assert len(model.estimators_) == 3
assert len(model.estimators_) == 5
assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
torch.nn.Module)
else:
Expand Down

0 comments on commit 452c629

Please sign in to comment.