diff --git a/tests/tpot_tests.py b/tests/tpot_tests.py index 0088833c..d8f878e8 100644 --- a/tests/tpot_tests.py +++ b/tests/tpot_tests.py @@ -964,12 +964,39 @@ def test_fit_4(): tpot_obj.generations == 20 tpot_obj.fit(training_features, training_target) - + assert tpot_obj._pop == [] assert isinstance(tpot_obj._optimized_pipeline, creator.Individual) assert not (tpot_obj._start_datetime is None) def test_fit_5(): + """Assert that the TPOT fit function provides an optimized pipeline with max_time_mins of 2 second with warm_start=True.""" + tpot_obj = TPOTClassifier( + random_state=42, + population_size=2, + generations=None, + verbosity=0, + max_time_mins=3/60., + config_dict='TPOT light', + warm_start=True + ) + tpot_obj._fit_init() + assert tpot_obj.generations == 1000000 + + # reset generations to 20 just in case that the failed test may take too much time + tpot_obj.generations == 20 + + tpot_obj.fit(training_features, training_target) + assert tpot_obj._pop != [] + assert isinstance(tpot_obj._optimized_pipeline, creator.Individual) + assert not (tpot_obj._start_datetime is None) + # rerun it + tpot_obj.fit(training_features, training_target) + assert tpot_obj._pop != [] + + + +def test_fit_6(): """Assert that the TPOT fit function provides an optimized pipeline with pandas DataFrame""" tpot_obj = TPOTClassifier( random_state=42, diff --git a/tpot/base.py b/tpot/base.py index 7bfc56e5..9b003ce7 100644 --- a/tpot/base.py +++ b/tpot/base.py @@ -752,10 +752,6 @@ def pareto_eq(ind1, ind2): per_generation_function=self._check_periodic_pipeline ) - # store population for the next call - if self.warm_start: - self._pop = pop - # Allow for certain exceptions to signal a premature fit() cancellation except (KeyboardInterrupt, SystemExit, StopIteration) as e: if self.verbosity > 0: @@ -763,6 +759,9 @@ def pareto_eq(ind1, ind2): self._pbar.write('{}\nTPOT closed prematurely. Will use the current best pipeline.'.format(e), file=self._file) finally: + # clean population for the next call if warm_start=False + if not self.warm_start: + self._pop = [] # keep trying 10 times in case weird things happened like multiple CTRL+C or exceptions attempts = 10 for attempt in range(attempts): @@ -1383,6 +1382,8 @@ def _evaluate_individuals(self, population, features, target, sample_weight=None ind.fitness.values = (5000.,-float('inf')) self._pareto_front.update(population) + + self._pop = population raise KeyboardInterrupt self._update_evaluated_individuals_(result_score_list, eval_individuals_str, operator_counts, stats_dicts)