Skip to content

Commit

Permalink
Merge pull request #949 from weixuanfu/fix-maxtime-warmstart
Browse files Browse the repository at this point in the history
Fix the bug that warm_start is not working without default max_time_mins
  • Loading branch information
weixuanfu committed Nov 5, 2019
2 parents f6303fe + 73cec4e commit f8d9ce8
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
29 changes: 28 additions & 1 deletion tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,12 +964,39 @@ def test_fit_4():
tpot_obj.generations == 20

tpot_obj.fit(training_features, training_target)

assert tpot_obj._pop == []
assert isinstance(tpot_obj._optimized_pipeline, creator.Individual)
assert not (tpot_obj._start_datetime is None)


def test_fit_5():
"""Assert that the TPOT fit function provides an optimized pipeline with max_time_mins of 2 second with warm_start=True."""
tpot_obj = TPOTClassifier(
random_state=42,
population_size=2,
generations=None,
verbosity=0,
max_time_mins=3/60.,
config_dict='TPOT light',
warm_start=True
)
tpot_obj._fit_init()
assert tpot_obj.generations == 1000000

# reset generations to 20 just in case that the failed test may take too much time
tpot_obj.generations == 20

tpot_obj.fit(training_features, training_target)
assert tpot_obj._pop != []
assert isinstance(tpot_obj._optimized_pipeline, creator.Individual)
assert not (tpot_obj._start_datetime is None)
# rerun it
tpot_obj.fit(training_features, training_target)
assert tpot_obj._pop != []



def test_fit_6():
"""Assert that the TPOT fit function provides an optimized pipeline with pandas DataFrame"""
tpot_obj = TPOTClassifier(
random_state=42,
Expand Down
9 changes: 5 additions & 4 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,17 +752,16 @@ def pareto_eq(ind1, ind2):
per_generation_function=self._check_periodic_pipeline
)

# store population for the next call
if self.warm_start:
self._pop = pop

# Allow for certain exceptions to signal a premature fit() cancellation
except (KeyboardInterrupt, SystemExit, StopIteration) as e:
if self.verbosity > 0:
self._pbar.write('', file=self._file)
self._pbar.write('{}\nTPOT closed prematurely. Will use the current best pipeline.'.format(e),
file=self._file)
finally:
# clean population for the next call if warm_start=False
if not self.warm_start:
self._pop = []
# keep trying 10 times in case weird things happened like multiple CTRL+C or exceptions
attempts = 10
for attempt in range(attempts):
Expand Down Expand Up @@ -1383,6 +1382,8 @@ def _evaluate_individuals(self, population, features, target, sample_weight=None
ind.fitness.values = (5000.,-float('inf'))

self._pareto_front.update(population)

self._pop = population
raise KeyboardInterrupt

self._update_evaluated_individuals_(result_score_list, eval_individuals_str, operator_counts, stats_dicts)
Expand Down

0 comments on commit f8d9ce8

Please sign in to comment.