Skip to content

Commit

Permalink
Refactor to simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
dsherry committed May 28, 2020
1 parent 3936c3e commit 0b81bd3
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions evalml/automl/automl_algorithm/iterative_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,15 @@ def next_batch(self):
raise StopIteration('No more batches available.')
next_batch = []
if self._batch_number == 0:
next_batch = [self._init_pipeline(cls, {}) for cls in self.allowed_pipelines]
next_batch = [pipeline_class(parameters=self._transform_parameters(pipeline_class, {}))
for pipeline_class in self.allowed_pipelines]
else:
_, pipeline_class = self._pop_best_in_batch()
if pipeline_class is None:
raise AutoMLAlgorithmException('Some results are needed before the next automl batch can be computed.')
for i in range(self.pipelines_per_batch):
proposed_parameters = self._tuners[pipeline_class.name].propose()
next_batch.append(self._init_pipeline(pipeline_class, proposed_parameters))
next_batch.append(pipeline_class(parameters=self._transform_parameters(pipeline_class, proposed_parameters)))
self._pipeline_number += len(next_batch)
self._batch_number += 1
return next_batch
Expand Down Expand Up @@ -95,11 +96,6 @@ def _pop_best_in_batch(self):
best_idx = idx
return self._first_batch_results.pop(best_idx)

def _init_pipeline(self, pipeline_class, parameters):
"""Given a pipeline class and a parameters dict, return a pipeline instance ready for training."""
parameters = self._transform_parameters(pipeline_class, parameters)
return pipeline_class(parameters=parameters)

def _transform_parameters(self, pipeline_class, proposed_parameters):
"""Given a pipeline parameters dict, make sure n_jobs and number_features are set."""
parameters = {}
Expand Down

0 comments on commit 0b81bd3

Please sign in to comment.