Skip to content

Commit

Permalink
Merge a0c0c40 into 828f3d0
Browse files Browse the repository at this point in the history
  • Loading branch information
weixuanfu committed Nov 5, 2019
2 parents 828f3d0 + a0c0c40 commit be766df
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 13 deletions.
11 changes: 11 additions & 0 deletions tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,17 @@ def test_template_4():
assert issubclass(sklearn_pipeline.steps[2][1].__class__, ClassifierMixin)


def test_template_5():
"""Assert that TPOT rasie ValueError when template parameter is invalid."""

tpot_obj = TPOTClassifier(
random_state=42,
verbosity=0,
template='SelectPercentile-Transformer-Classifie' # a typ in Classifier
)
assert_raises(ValueError, tpot_obj._fit_init)


def test_fit_GroupKFold():
"""Assert that TPOT properly handles the group parameter when using GroupKFold."""
# This check tests if the darker digits images would generalize to the lighter ones.
Expand Down
32 changes: 19 additions & 13 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,22 +460,28 @@ def _add_operators(self):
ret_types.append(step_ret_type)
else:
step_ret_type = Output_Array
check_template = True
if step == 'CombineDFs':
self._pset.addPrimitive(CombineDFs(), [step_in_type, step_in_type], step_in_type)
elif main_type.count(step): # if the step is a main type
for operator in self.operators:
ops = [op for op in self.operators if op.type() == step]
for operator in ops:
arg_types = operator.parameter_types()[0][1:]
if operator.type() == step:
p_types = ([step_in_type] + arg_types, step_ret_type)
self._pset.addPrimitive(operator, *p_types)
self._import_hash_and_add_terminals(operator, arg_types)
else: # is the step is a specific operator
for operator in self.operators:
arg_types = operator.parameter_types()[0][1:]
if operator.__name__ == step:
p_types = ([step_in_type] + arg_types, step_ret_type)
self._pset.addPrimitive(operator, *p_types)
self._import_hash_and_add_terminals(operator, arg_types)
p_types = ([step_in_type] + arg_types, step_ret_type)
self._pset.addPrimitive(operator, *p_types)
self._import_hash_and_add_terminals(operator, arg_types)
else: # is the step is a specific operator or a wrong input
try:
operator = next(op for op in self.operators if op.__name__ == step)
except:
raise ValueError(
'An error occured while attempting to read the specified '
'template. Please check a step named {}'.format(step)
)
arg_types = operator.parameter_types()[0][1:]
p_types = ([step_in_type] + arg_types, step_ret_type)
self._pset.addPrimitive(operator, *p_types)
self._import_hash_and_add_terminals(operator, arg_types)
self.ret_types = [np.ndarray, Output_Array] + ret_types


Expand Down Expand Up @@ -1382,7 +1388,7 @@ def _evaluate_individuals(self, population, features, target, sample_weight=None
ind.fitness.values = (5000.,-float('inf'))

self._pareto_front.update(population)

self._pop = population
raise KeyboardInterrupt

Expand Down

0 comments on commit be766df

Please sign in to comment.